diff options
| author | vitalyisaev <[email protected]> | 2023-11-14 09:58:56 +0300 |
|---|---|---|
| committer | vitalyisaev <[email protected]> | 2023-11-14 10:20:20 +0300 |
| commit | c2b2dfd9827a400a8495e172a56343462e3ceb82 (patch) | |
| tree | cd4e4f597d01bede4c82dffeb2d780d0a9046bd0 /contrib/clickhouse/src/Functions/isIPAddressContainedIn.cpp | |
| parent | d4ae8f119e67808cb0cf776ba6e0cf95296f2df7 (diff) | |
YQ Connector: move tests from yql to ydb (OSS)
Перенос папки с тестами на Коннектор из папки yql в папку ydb (синхронизируется с github).
Diffstat (limited to 'contrib/clickhouse/src/Functions/isIPAddressContainedIn.cpp')
| -rw-r--r-- | contrib/clickhouse/src/Functions/isIPAddressContainedIn.cpp | 253 |
1 files changed, 253 insertions, 0 deletions
diff --git a/contrib/clickhouse/src/Functions/isIPAddressContainedIn.cpp b/contrib/clickhouse/src/Functions/isIPAddressContainedIn.cpp new file mode 100644 index 00000000000..abbcb0a5e37 --- /dev/null +++ b/contrib/clickhouse/src/Functions/isIPAddressContainedIn.cpp @@ -0,0 +1,253 @@ +#include <Columns/ColumnConst.h> +#include <Columns/ColumnString.h> +#include <Columns/ColumnsNumber.h> +#include <Common/IPv6ToBinary.h> +#include <Common/formatIPv6.h> +#include <DataTypes/DataTypeNullable.h> +#include <DataTypes/DataTypesNumber.h> +#include <Functions/IFunction.h> +#include <Functions/FunctionFactory.h> +#include <Functions/FunctionHelpers.h> +#include <variant> +#include <charconv> + + +#include <Common/logger_useful.h> +namespace DB::ErrorCodes +{ + extern const int CANNOT_PARSE_TEXT; + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} + +namespace +{ + +class IPAddressVariant +{ +public: + + explicit IPAddressVariant(std::string_view address_str) + { + UInt32 v4; + if (DB::parseIPv4whole(address_str.begin(), address_str.end(), reinterpret_cast<unsigned char *>(&v4))) + { + addr = v4; + } + else + { + addr = IPv6AddrType(); + bool success = DB::parseIPv6whole(address_str.begin(), address_str.end(), std::get<IPv6AddrType>(addr).data()); + if (!success) + throw DB::Exception(DB::ErrorCodes::CANNOT_PARSE_TEXT, "Neither IPv4 nor IPv6 address: '{}'", address_str); + } + } + + UInt32 asV4() const + { + if (const auto * val = std::get_if<IPv4AddrType>(&addr)) + return *val; + return 0; + } + + const uint8_t * asV6() const + { + if (const auto * val = std::get_if<IPv6AddrType>(&addr)) + return val->data(); + return nullptr; + } + +private: + using IPv4AddrType = UInt32; + using IPv6AddrType = std::array<uint8_t, IPV6_BINARY_LENGTH>; + + std::variant<IPv4AddrType, IPv6AddrType> addr; +}; + +struct IPAddressCIDR +{ + IPAddressVariant address; + UInt8 prefix; +}; + +IPAddressCIDR parseIPWithCIDR(std::string_view cidr_str) +{ + size_t pos_slash = cidr_str.find('/'); + + if (pos_slash == 0) + throw DB::Exception(DB::ErrorCodes::CANNOT_PARSE_TEXT, "Error parsing IP address with prefix: {}", std::string(cidr_str)); + if (pos_slash == std::string_view::npos) + throw DB::Exception(DB::ErrorCodes::CANNOT_PARSE_TEXT, "The text does not contain '/': {}", std::string(cidr_str)); + + std::string_view addr_str = cidr_str.substr(0, pos_slash); + IPAddressVariant addr(addr_str); + + uint8_t prefix = 0; + auto prefix_str = cidr_str.substr(pos_slash+1); + + const auto * prefix_str_end = prefix_str.data() + prefix_str.size(); + auto [parse_end, parse_error] = std::from_chars(prefix_str.data(), prefix_str_end, prefix); + uint8_t max_prefix = (addr.asV6() ? IPV6_BINARY_LENGTH : IPV4_BINARY_LENGTH) * 8; + bool has_error = parse_error != std::errc() || parse_end != prefix_str_end || prefix > max_prefix; + if (has_error) + throw DB::Exception(DB::ErrorCodes::CANNOT_PARSE_TEXT, "The CIDR has a malformed prefix bits: {}", std::string(cidr_str)); + + return {addr, static_cast<UInt8>(prefix)}; +} + +inline bool isAddressInRange(const IPAddressVariant & address, const IPAddressCIDR & cidr) +{ + if (const auto * cidr_v6 = cidr.address.asV6()) + { + if (const auto * addr_v6 = address.asV6()) + return DB::matchIPv6Subnet(addr_v6, cidr_v6, cidr.prefix); + } + else + { + if (!address.asV6()) + return DB::matchIPv4Subnet(address.asV4(), cidr.address.asV4(), cidr.prefix); + } + return false; +} + +} + +namespace DB +{ + class FunctionIsIPAddressContainedIn : public IFunction + { + public: + static constexpr auto name = "isIPAddressInRange"; + String getName() const override { return name; } + static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionIsIPAddressContainedIn>(); } + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & /* return_type */, size_t input_rows_count) const override + { + const IColumn * col_addr = arguments[0].column.get(); + const IColumn * col_cidr = arguments[1].column.get(); + + if (const auto * col_addr_const = checkAndGetAnyColumnConst(col_addr)) + { + if (const auto * col_cidr_const = checkAndGetAnyColumnConst(col_cidr)) + return executeImpl(*col_addr_const, *col_cidr_const, input_rows_count); + else + return executeImpl(*col_addr_const, *col_cidr, input_rows_count); + } + else + { + if (const auto * col_cidr_const = checkAndGetAnyColumnConst(col_cidr)) + return executeImpl(*col_addr, *col_cidr_const, input_rows_count); + else + return executeImpl(*col_addr, *col_cidr, input_rows_count); + } + } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if (arguments.size() != 2) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Number of arguments for function {} doesn't match: passed {}, should be 2", + getName(), arguments.size()); + + const DataTypePtr & addr_type = arguments[0]; + const DataTypePtr & prefix_type = arguments[1]; + + if (!isString(addr_type) || !isString(prefix_type)) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The arguments of function {} must be String", getName()); + + return std::make_shared<DataTypeUInt8>(); + } + + size_t getNumberOfArguments() const override { return 2; } + bool useDefaultImplementationForNulls() const override { return false; } + + private: + /// Like checkAndGetColumnConst() but this function doesn't + /// care about the type of data column. + static const ColumnConst * checkAndGetAnyColumnConst(const IColumn * column) + { + if (!column || !isColumnConst(*column)) + return nullptr; + + return assert_cast<const ColumnConst *>(column); + } + + /// Both columns are constant. + static ColumnPtr executeImpl( + const ColumnConst & col_addr_const, + const ColumnConst & col_cidr_const, + size_t input_rows_count) + { + const auto & col_addr = col_addr_const.getDataColumn(); + const auto & col_cidr = col_cidr_const.getDataColumn(); + + const auto addr = IPAddressVariant(col_addr.getDataAt(0).toView()); + const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(0).toView()); + + ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(1); + ColumnUInt8::Container & vec_res = col_res->getData(); + + vec_res[0] = isAddressInRange(addr, cidr) ? 1 : 0; + + return ColumnConst::create(std::move(col_res), input_rows_count); + } + + /// Address is constant. + static ColumnPtr executeImpl(const ColumnConst & col_addr_const, const IColumn & col_cidr, size_t input_rows_count) + { + const auto & col_addr = col_addr_const.getDataColumn(); + + const auto addr = IPAddressVariant(col_addr.getDataAt(0).toView()); + + ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(input_rows_count); + ColumnUInt8::Container & vec_res = col_res->getData(); + + for (size_t i = 0; i < input_rows_count; ++i) + { + const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i).toView()); + vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0; + } + return col_res; + } + + /// CIDR is constant. + static ColumnPtr executeImpl(const IColumn & col_addr, const ColumnConst & col_cidr_const, size_t input_rows_count) + { + const auto & col_cidr = col_cidr_const.getDataColumn(); + + const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(0).toView()); + + ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(input_rows_count); + ColumnUInt8::Container & vec_res = col_res->getData(); + for (size_t i = 0; i < input_rows_count; ++i) + { + const auto addr = IPAddressVariant(col_addr.getDataAt(i).toView()); + vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0; + } + return col_res; + } + + /// Neither are constant. + static ColumnPtr executeImpl(const IColumn & col_addr, const IColumn & col_cidr, size_t input_rows_count) + { + ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(input_rows_count); + ColumnUInt8::Container & vec_res = col_res->getData(); + + for (size_t i = 0; i < input_rows_count; ++i) + { + const auto addr = IPAddressVariant(col_addr.getDataAt(i).toView()); + const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i).toView()); + + vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0; + } + + return col_res; + } + }; + + REGISTER_FUNCTION(IsIPAddressContainedIn) + { + factory.registerFunction<FunctionIsIPAddressContainedIn>(); + } +} |
