diff options
| author | vitalyisaev <[email protected]> | 2023-11-14 09:58:56 +0300 |
|---|---|---|
| committer | vitalyisaev <[email protected]> | 2023-11-14 10:20:20 +0300 |
| commit | c2b2dfd9827a400a8495e172a56343462e3ceb82 (patch) | |
| tree | cd4e4f597d01bede4c82dffeb2d780d0a9046bd0 /contrib/clickhouse/src/AggregateFunctions/AggregateFunctionMap.h | |
| parent | d4ae8f119e67808cb0cf776ba6e0cf95296f2df7 (diff) | |
YQ Connector: move tests from yql to ydb (OSS)
Перенос папки с тестами на Коннектор из папки yql в папку ydb (синхронизируется с github).
Diffstat (limited to 'contrib/clickhouse/src/AggregateFunctions/AggregateFunctionMap.h')
| -rw-r--r-- | contrib/clickhouse/src/AggregateFunctions/AggregateFunctionMap.h | 350 |
1 files changed, 350 insertions, 0 deletions
diff --git a/contrib/clickhouse/src/AggregateFunctions/AggregateFunctionMap.h b/contrib/clickhouse/src/AggregateFunctions/AggregateFunctionMap.h new file mode 100644 index 00000000000..7b9bb088d8f --- /dev/null +++ b/contrib/clickhouse/src/AggregateFunctions/AggregateFunctionMap.h @@ -0,0 +1,350 @@ +#pragma once + +#include <unordered_map> +#include <base/sort.h> +#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h> +#include <AggregateFunctions/IAggregateFunction.h> +#include <Columns/ColumnFixedString.h> +#include <Columns/ColumnMap.h> +#include <Columns/ColumnString.h> +#include <Columns/ColumnTuple.h> +#include <Columns/ColumnVector.h> +#include <Core/ColumnWithTypeAndName.h> +#include <DataTypes/DataTypeArray.h> +#include <DataTypes/DataTypeMap.h> +#include <DataTypes/DataTypeTuple.h> +#include <DataTypes/DataTypesNumber.h> +#include <Functions/FunctionFactory.h> +#include <Functions/FunctionHelpers.h> +#include <IO/ReadHelpers.h> +#include <IO/WriteHelpers.h> +#include "DataTypes/Serializations/ISerialization.h" +#include <base/IPv4andIPv6.h> +#include "base/types.h" +#include <Common/formatIPv6.h> +#include <Common/Arena.h> +#include "AggregateFunctions/AggregateFunctionFactory.h" + +namespace DB +{ +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} + +template <typename KeyType> +struct AggregateFunctionMapCombinatorData +{ + using SearchType = KeyType; + std::unordered_map<KeyType, AggregateDataPtr> merged_maps; + + static void writeKey(KeyType key, WriteBuffer & buf) { writeBinary(key, buf); } + static void readKey(KeyType & key, ReadBuffer & buf) { readBinary(key, buf); } +}; + +template <> +struct AggregateFunctionMapCombinatorData<String> +{ + struct StringHash + { + using hash_type = std::hash<std::string_view>; + using is_transparent = void; + + size_t operator()(std::string_view str) const { return hash_type{}(str); } + }; + +#ifdef __cpp_lib_generic_unordered_lookup + using SearchType = std::string_view; +#else + using SearchType = std::string; +#endif + std::unordered_map<String, AggregateDataPtr, StringHash, std::equal_to<>> merged_maps; + + static void writeKey(String key, WriteBuffer & buf) + { + writeStringBinary(key, buf); + } + static void readKey(String & key, ReadBuffer & buf) + { + readStringBinary(key, buf); + } +}; + +/// Specialization for IPv6 - for historical reasons it should be stored as FixedString(16) +template <> +struct AggregateFunctionMapCombinatorData<IPv6> +{ + struct IPv6Hash + { + using hash_type = std::hash<IPv6>; + using is_transparent = void; + + size_t operator()(const IPv6 & ip) const { return hash_type{}(ip); } + }; + + using SearchType = IPv6; + std::unordered_map<IPv6, AggregateDataPtr, IPv6Hash, std::equal_to<>> merged_maps; + + static void writeKey(const IPv6 & key, WriteBuffer & buf) + { + writeIPv6Binary(key, buf); + } + static void readKey(IPv6 & key, ReadBuffer & buf) + { + readIPv6Binary(key, buf); + } +}; + +template <typename KeyType> +class AggregateFunctionMap final + : public IAggregateFunctionDataHelper<AggregateFunctionMapCombinatorData<KeyType>, AggregateFunctionMap<KeyType>> +{ +private: + DataTypePtr key_type; + AggregateFunctionPtr nested_func; + + using Data = AggregateFunctionMapCombinatorData<KeyType>; + using Base = IAggregateFunctionDataHelper<Data, AggregateFunctionMap<KeyType>>; + +public: + bool isState() const override + { + return nested_func->isState(); + } + + bool isVersioned() const override + { + return nested_func->isVersioned(); + } + + size_t getVersionFromRevision(size_t revision) const override + { + return nested_func->getVersionFromRevision(revision); + } + + size_t getDefaultVersion() const override + { + return nested_func->getDefaultVersion(); + } + + AggregateFunctionMap(AggregateFunctionPtr nested, const DataTypes & types) + : Base(types, nested->getParameters(), std::make_shared<DataTypeMap>(DataTypes{getKeyType(types, nested), nested->getResultType()})) + , nested_func(nested) + { + key_type = getKeyType(types, nested_func); + } + + String getName() const override { return nested_func->getName() + "Map"; } + + static DataTypePtr getKeyType(const DataTypes & types, const AggregateFunctionPtr & nested) + { + if (types.size() != 1) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Aggregate function {}Map requires one map argument, but {} found", nested->getName(), types.size()); + + const auto * map_type = checkAndGetDataType<DataTypeMap>(types[0].get()); + if (!map_type) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Aggregate function {}Map requires map as argument", nested->getName()); + + return map_type->getKeyType(); + } + + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + const auto & map_column = assert_cast<const ColumnMap &>(*columns[0]); + const auto & map_nested_tuple = map_column.getNestedData(); + const IColumn::Offsets & map_array_offsets = map_column.getNestedColumn().getOffsets(); + + const size_t offset = map_array_offsets[row_num - 1]; + const size_t size = (map_array_offsets[row_num] - offset); + + const auto & key_column = map_nested_tuple.getColumn(0); + const auto & val_column = map_nested_tuple.getColumn(1); + + auto & merged_maps = this->data(place).merged_maps; + + for (size_t i = 0; i < size; ++i) + { + typename Data::SearchType key; + + if constexpr (std::is_same_v<KeyType, String>) + { + StringRef key_ref; + if (key_type->getTypeId() == TypeIndex::FixedString) + key_ref = assert_cast<const ColumnFixedString &>(key_column).getDataAt(offset + i); + else if (key_type->getTypeId() == TypeIndex::IPv6) + key_ref = assert_cast<const ColumnIPv6 &>(key_column).getDataAt(offset + i); + else + key_ref = assert_cast<const ColumnString &>(key_column).getDataAt(offset + i); + +#ifdef __cpp_lib_generic_unordered_lookup + key = key_ref.toView(); +#else + key = key_ref.toString(); +#endif + } + else + { + key = assert_cast<const ColumnVector<KeyType> &>(key_column).getData()[offset + i]; + } + + AggregateDataPtr nested_place; + auto it = merged_maps.find(key); + + if (it == merged_maps.end()) + { + // create a new place for each key + nested_place = arena->alignedAlloc(nested_func->sizeOfData(), nested_func->alignOfData()); + nested_func->create(nested_place); + merged_maps.emplace(key, nested_place); + } + else + nested_place = it->second; + + const IColumn * nested_columns[1] = {&val_column}; + nested_func->add(nested_place, nested_columns, offset + i, arena); + } + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override + { + auto & merged_maps = this->data(place).merged_maps; + const auto & rhs_maps = this->data(rhs).merged_maps; + + for (const auto & elem : rhs_maps) + { + const auto & it = merged_maps.find(elem.first); + + AggregateDataPtr nested_place; + if (it == merged_maps.end()) + { + // elem.second cannot be copied since this it will be destroyed after merging, + // and lead to use-after-free. + nested_place = arena->alignedAlloc(nested_func->sizeOfData(), nested_func->alignOfData()); + nested_func->create(nested_place); + merged_maps.emplace(elem.first, nested_place); + } + else + { + nested_place = it->second; + } + + nested_func->merge(nested_place, elem.second, arena); + } + } + + template <bool up_to_state> + void destroyImpl(AggregateDataPtr __restrict place) const noexcept + { + AggregateFunctionMapCombinatorData<KeyType> & state = Base::data(place); + + for (const auto & [key, nested_place] : state.merged_maps) + { + if constexpr (up_to_state) + nested_func->destroyUpToState(nested_place); + else + nested_func->destroy(nested_place); + } + + state.~Data(); + } + + void destroy(AggregateDataPtr __restrict place) const noexcept override + { + destroyImpl<false>(place); + } + + bool hasTrivialDestructor() const override + { + return std::is_trivially_destructible_v<Data> && nested_func->hasTrivialDestructor(); + } + + void destroyUpToState(AggregateDataPtr __restrict place) const noexcept override + { + destroyImpl<true>(place); + } + + void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override + { + auto & merged_maps = this->data(place).merged_maps; + writeVarUInt(merged_maps.size(), buf); + + for (const auto & elem : merged_maps) + { + this->data(place).writeKey(elem.first, buf); + nested_func->serialize(elem.second, buf); + } + } + + void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override + { + auto & merged_maps = this->data(place).merged_maps; + UInt64 size; + + readVarUInt(size, buf); + for (UInt64 i = 0; i < size; ++i) + { + KeyType key; + AggregateDataPtr nested_place; + + this->data(place).readKey(key, buf); + nested_place = arena->alignedAlloc(nested_func->sizeOfData(), nested_func->alignOfData()); + nested_func->create(nested_place); + merged_maps.emplace(key, nested_place); + nested_func->deserialize(nested_place, buf, std::nullopt, arena); + } + } + + template <bool merge> + void insertResultIntoImpl(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const + { + auto & map_column = assert_cast<ColumnMap &>(to); + auto & nested_column = map_column.getNestedColumn(); + auto & nested_data_column = map_column.getNestedData(); + + auto & key_column = nested_data_column.getColumn(0); + auto & val_column = nested_data_column.getColumn(1); + + auto & merged_maps = this->data(place).merged_maps; + + // sort the keys + std::vector<KeyType> keys; + keys.reserve(merged_maps.size()); + for (auto & it : merged_maps) + { + keys.push_back(it.first); + } + ::sort(keys.begin(), keys.end()); + + // insert using sorted keys to result column + for (auto & key : keys) + { + key_column.insert(key); + if constexpr (merge) + nested_func->insertMergeResultInto(merged_maps[key], val_column, arena); + else + nested_func->insertResultInto(merged_maps[key], val_column, arena); + } + + IColumn::Offsets & res_offsets = nested_column.getOffsets(); + res_offsets.push_back(val_column.size()); + } + + void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override + { + insertResultIntoImpl<false>(place, to, arena); + } + + void insertMergeResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override + { + insertResultIntoImpl<true>(place, to, arena); + } + + bool allocatesMemoryInArena() const override { return true; } + + AggregateFunctionPtr getNestedFunction() const override { return nested_func; } +}; + +} |
