diff options
| author | vvvv <[email protected]> | 2022-12-12 19:13:23 +0300 |
|---|---|---|
| committer | vvvv <[email protected]> | 2022-12-12 19:13:23 +0300 |
| commit | 749d81a237f2ca42a7697f93ae5fb0fce07e1cb6 (patch) | |
| tree | 4f7ba1cd689f72ed9fc16dfa983cbce3e38c2ae6 | |
| parent | 79fddfef7319acf25a962af6959614ed1a1539f9 (diff) | |
use set if only keys are used
| -rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp | 154 |
1 files changed, 104 insertions, 50 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp index 46b5fd1b6da..fdbe826fda1 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp @@ -478,10 +478,10 @@ TStringBuf GetKeyView(const TSSOKey& key) { return key.AsView(); } -template <typename TKey> -class TBlockCombineHashedWrapper : public TStatefulWideFlowComputationNode<TBlockCombineHashedWrapper<TKey>> { +template <typename TKey, bool UseSet> +class TBlockCombineHashedWrapper : public TStatefulWideFlowComputationNode<TBlockCombineHashedWrapper<TKey, UseSet>> { public: - using TSelf = TBlockCombineHashedWrapper<TKey>; + using TSelf = TBlockCombineHashedWrapper<TKey, UseSet>; using TBase = TStatefulWideFlowComputationNode<TSelf>; TBlockCombineHashedWrapper(TComputationMutables& mutables, @@ -541,29 +541,40 @@ public: auto str = out.Finish(); TKey key = MakeKey<TKey>(str); bool isNew; - auto iter = s.HashMap_->Insert(key, isNew); - char* ptr = (char*)s.HashMap_->GetPayload(iter); - if (isNew) { - for (size_t i = 0; i < s.Aggs_.size(); ++i) { - if (output[Keys_.size() + i]) { - s.Aggs_[i]->InitKey(ptr, s.Values_.data(), row); + if constexpr (UseSet) { + auto iter = s.HashSet_->Insert(key, isNew); + if (isNew) { + if constexpr (std::is_same<TKey, TSSOKey>::value) { + MoveKeyToArena(s.HashSet_->GetKey(iter), s.Arena_); } - ptr += s.Aggs_[i]->StateSize; + s.HashSet_->CheckGrow(); } + } else { + auto iter = s.HashMap_->Insert(key, isNew); + char* ptr = (char*)s.HashMap_->GetPayload(iter); + if (isNew) { + for (size_t i = 0; i < s.Aggs_.size(); ++i) { + if (output[Keys_.size() + i]) { + s.Aggs_[i]->InitKey(ptr, s.Values_.data(), row); + } - if constexpr (std::is_same<TKey, TSSOKey>::value) { - MoveKeyToArena(s.HashMap_->GetKey(iter), s.Arena_); - } + ptr += s.Aggs_[i]->StateSize; + } - s.HashMap_->CheckGrow(); - } else { - for (size_t i = 0; i < s.Aggs_.size(); ++i) { - if (output[Keys_.size() + i]) { - s.Aggs_[i]->UpdateKey(ptr, s.Values_.data(), row); + if constexpr (std::is_same<TKey, TSSOKey>::value) { + MoveKeyToArena(s.HashMap_->GetKey(iter), s.Arena_); } - ptr += s.Aggs_[i]->StateSize; + s.HashMap_->CheckGrow(); + } else { + for (size_t i = 0; i < s.Aggs_.size(); ++i) { + if (output[Keys_.size() + i]) { + s.Aggs_[i]->UpdateKey(ptr, s.Values_.data(), row); + } + + ptr += s.Aggs_[i]->StateSize; + } } } } @@ -574,35 +585,61 @@ public: } // export results, TODO: split by batches - auto size = s.HashMap_->GetSize(); + ui64 size; + if constexpr (UseSet) { + size = s.HashSet_->GetSize(); + } else { + size = s.HashMap_->GetSize(); + } + TVector<std::unique_ptr<IKeyColumnBuilder>> keyBuilders; for (const auto& ks : KeySerializers_) { keyBuilders.emplace_back(ks->MakeBuilder(size, ctx)); } - TVector<std::unique_ptr<IAggColumnBuilder>> aggBuilders; - for (const auto& a : s.Aggs_) { - aggBuilders.emplace_back(a->MakeBuilder(size)); - } + if constexpr (UseSet) { + for (auto iter = s.HashSet_->Begin(); iter != s.HashSet_->End(); s.HashSet_->Advance(iter)) { + if (s.HashSet_->GetPSL(iter) < 0) { + continue; + } - for (auto iter = s.HashMap_->Begin(); iter != s.HashMap_->End(); s.HashMap_->Advance(iter)) { - if (s.HashMap_->GetPSL(iter) < 0) { - continue; + const TKey& key = s.HashSet_->GetKey(iter); + TInputBuffer in(GetKeyView<TKey>(key)); + for (auto& kb : keyBuilders) { + kb->Add(in); + } } + } else { + TVector<std::unique_ptr<IAggColumnBuilder>> aggBuilders; + for (const auto& a : s.Aggs_) { + aggBuilders.emplace_back(a->MakeBuilder(size)); + } + + for (auto iter = s.HashMap_->Begin(); iter != s.HashMap_->End(); s.HashMap_->Advance(iter)) { + if (s.HashMap_->GetPSL(iter) < 0) { + continue; + } + + const TKey& key = s.HashMap_->GetKey(iter); + auto ptr = (const char*)s.HashMap_->GetPayload(iter); + TInputBuffer in(GetKeyView<TKey>(key)); + for (auto& kb : keyBuilders) { + kb->Add(in); + } - const TKey& key = s.HashMap_->GetKey(iter); - auto ptr = (const char*)s.HashMap_->GetPayload(iter); - TInputBuffer in(GetKeyView<TKey>(key)); - for (auto& kb : keyBuilders) { - kb->Add(in); + for (size_t i = 0; i < s.Aggs_.size(); ++i) { + if (output[Keys_.size() + i]) { + aggBuilders[i]->Add(ptr); + } + + ptr += s.Aggs_[i]->StateSize; + } } for (size_t i = 0; i < s.Aggs_.size(); ++i) { if (output[Keys_.size() + i]) { - aggBuilders[i]->Add(ptr); + *output[Keys_.size() + i] = aggBuilders[i]->Build(); } - - ptr += s.Aggs_[i]->StateSize; } } @@ -612,12 +649,6 @@ public: } } - for (size_t i = 0; i < s.Aggs_.size(); ++i) { - if (output[Keys_.size() + i]) { - *output[Keys_.size() + i] = aggBuilders[i]->Build(); - } - } - MKQL_ENSURE(output[OutputWidth_ - 1], "Block size should not be marked as unused"); *output[OutputWidth_ - 1] = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(size))); return EFetchResult::One; @@ -636,6 +667,7 @@ private: bool HasValues_ = false; ui32 TotalStateSize_ = 0; std::unique_ptr<TRobinHoodHashMap<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>> HashMap_; + std::unique_ptr<TRobinHoodHashSet<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>> HashSet_; TPagedArena Arena_; TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const TVector<TAggParams>& params, TComputationContext& ctx) @@ -654,7 +686,12 @@ private: TotalStateSize_ += Aggs_.back()->StateSize; } - HashMap_ = std::make_unique<TRobinHoodHashMap<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>>(TotalStateSize_); + if constexpr (UseSet) { + MKQL_ENSURE(params.empty(), "Only keys are supported"); + HashSet_ = std::make_unique<TRobinHoodHashSet<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>>(); + } else { + HashMap_ = std::make_unique<TRobinHoodHashMap<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>>(TotalStateSize_); + } } }; @@ -698,6 +735,27 @@ void FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, TVector<TAggPa } } +template <bool UseSet> +IComputationNode* MakeBlockCombineHashedWrapper( + ui32 totalKeysSize, + TComputationMutables& mutables, + IComputationWideFlowNode* flow, + std::optional<ui32> filterColumn, + size_t width, + const std::vector<TKeyParams>& keys, + std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers, + TVector<TAggParams>&& aggsParams) { + if (totalKeysSize <= sizeof(ui32)) { + return new TBlockCombineHashedWrapper<ui32, UseSet>(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams)); + } + + if (totalKeysSize <= sizeof(ui64)) { + return new TBlockCombineHashedWrapper<ui64, UseSet>(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams)); + } + + return new TBlockCombineHashedWrapper<TSSOKey, UseSet>(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams)); +} + } IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNodeFactoryContext& ctx) { @@ -838,15 +896,11 @@ IComputationNode* WrapBlockCombineHashed(TCallable& callable, const TComputation } } - if (totalKeysSize <= sizeof(ui32)) { - return new TBlockCombineHashedWrapper<ui32>(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams)); - } - - if (totalKeysSize <= sizeof(ui64)) { - return new TBlockCombineHashedWrapper<ui64>(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams)); + if (aggsParams.size() == 0) { + return MakeBlockCombineHashedWrapper<true>(totalKeysSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams)); + } else { + return MakeBlockCombineHashedWrapper<false>(totalKeysSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams)); } - - return new TBlockCombineHashedWrapper<TSSOKey>(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams)); } } |
