From 749d81a237f2ca42a7697f93ae5fb0fce07e1cb6 Mon Sep 17 00:00:00 2001 From: vvvv Date: Mon, 12 Dec 2022 19:13:23 +0300 Subject: use set if only keys are used --- .../yql/minikql/comp_nodes/mkql_block_agg.cpp | 156 ++++++++++++++------- 1 file changed, 105 insertions(+), 51 deletions(-) diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp index 46b5fd1b6da..fdbe826fda1 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp @@ -478,10 +478,10 @@ TStringBuf GetKeyView(const TSSOKey& key) { return key.AsView(); } -template -class TBlockCombineHashedWrapper : public TStatefulWideFlowComputationNode> { +template +class TBlockCombineHashedWrapper : public TStatefulWideFlowComputationNode> { public: - using TSelf = TBlockCombineHashedWrapper; + using TSelf = TBlockCombineHashedWrapper; using TBase = TStatefulWideFlowComputationNode; TBlockCombineHashedWrapper(TComputationMutables& mutables, @@ -541,29 +541,40 @@ public: auto str = out.Finish(); TKey key = MakeKey(str); bool isNew; - auto iter = s.HashMap_->Insert(key, isNew); - char* ptr = (char*)s.HashMap_->GetPayload(iter); - if (isNew) { - for (size_t i = 0; i < s.Aggs_.size(); ++i) { - if (output[Keys_.size() + i]) { - s.Aggs_[i]->InitKey(ptr, s.Values_.data(), row); + if constexpr (UseSet) { + auto iter = s.HashSet_->Insert(key, isNew); + if (isNew) { + if constexpr (std::is_same::value) { + MoveKeyToArena(s.HashSet_->GetKey(iter), s.Arena_); } - ptr += s.Aggs_[i]->StateSize; + s.HashSet_->CheckGrow(); } - - if constexpr (std::is_same::value) { - MoveKeyToArena(s.HashMap_->GetKey(iter), s.Arena_); - } - - s.HashMap_->CheckGrow(); } else { - for (size_t i = 0; i < s.Aggs_.size(); ++i) { - if (output[Keys_.size() + i]) { - s.Aggs_[i]->UpdateKey(ptr, s.Values_.data(), row); + auto iter = s.HashMap_->Insert(key, isNew); + char* ptr = (char*)s.HashMap_->GetPayload(iter); + if (isNew) { + for (size_t i = 0; i < s.Aggs_.size(); ++i) { + if (output[Keys_.size() + i]) { + s.Aggs_[i]->InitKey(ptr, s.Values_.data(), row); + } + + ptr += s.Aggs_[i]->StateSize; } - ptr += s.Aggs_[i]->StateSize; + if constexpr (std::is_same::value) { + MoveKeyToArena(s.HashMap_->GetKey(iter), s.Arena_); + } + + s.HashMap_->CheckGrow(); + } else { + for (size_t i = 0; i < s.Aggs_.size(); ++i) { + if (output[Keys_.size() + i]) { + s.Aggs_[i]->UpdateKey(ptr, s.Values_.data(), row); + } + + ptr += s.Aggs_[i]->StateSize; + } } } } @@ -574,35 +585,61 @@ public: } // export results, TODO: split by batches - auto size = s.HashMap_->GetSize(); + ui64 size; + if constexpr (UseSet) { + size = s.HashSet_->GetSize(); + } else { + size = s.HashMap_->GetSize(); + } + TVector> keyBuilders; for (const auto& ks : KeySerializers_) { keyBuilders.emplace_back(ks->MakeBuilder(size, ctx)); } - TVector> aggBuilders; - for (const auto& a : s.Aggs_) { - aggBuilders.emplace_back(a->MakeBuilder(size)); - } + if constexpr (UseSet) { + for (auto iter = s.HashSet_->Begin(); iter != s.HashSet_->End(); s.HashSet_->Advance(iter)) { + if (s.HashSet_->GetPSL(iter) < 0) { + continue; + } - for (auto iter = s.HashMap_->Begin(); iter != s.HashMap_->End(); s.HashMap_->Advance(iter)) { - if (s.HashMap_->GetPSL(iter) < 0) { - continue; + const TKey& key = s.HashSet_->GetKey(iter); + TInputBuffer in(GetKeyView(key)); + for (auto& kb : keyBuilders) { + kb->Add(in); + } } + } else { + TVector> aggBuilders; + for (const auto& a : s.Aggs_) { + aggBuilders.emplace_back(a->MakeBuilder(size)); + } + + for (auto iter = s.HashMap_->Begin(); iter != s.HashMap_->End(); s.HashMap_->Advance(iter)) { + if (s.HashMap_->GetPSL(iter) < 0) { + continue; + } + + const TKey& key = s.HashMap_->GetKey(iter); + auto ptr = (const char*)s.HashMap_->GetPayload(iter); + TInputBuffer in(GetKeyView(key)); + for (auto& kb : keyBuilders) { + kb->Add(in); + } - const TKey& key = s.HashMap_->GetKey(iter); - auto ptr = (const char*)s.HashMap_->GetPayload(iter); - TInputBuffer in(GetKeyView(key)); - for (auto& kb : keyBuilders) { - kb->Add(in); + for (size_t i = 0; i < s.Aggs_.size(); ++i) { + if (output[Keys_.size() + i]) { + aggBuilders[i]->Add(ptr); + } + + ptr += s.Aggs_[i]->StateSize; + } } for (size_t i = 0; i < s.Aggs_.size(); ++i) { if (output[Keys_.size() + i]) { - aggBuilders[i]->Add(ptr); + *output[Keys_.size() + i] = aggBuilders[i]->Build(); } - - ptr += s.Aggs_[i]->StateSize; } } @@ -612,12 +649,6 @@ public: } } - for (size_t i = 0; i < s.Aggs_.size(); ++i) { - if (output[Keys_.size() + i]) { - *output[Keys_.size() + i] = aggBuilders[i]->Build(); - } - } - MKQL_ENSURE(output[OutputWidth_ - 1], "Block size should not be marked as unused"); *output[OutputWidth_ - 1] = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(std::make_shared(size))); return EFetchResult::One; @@ -636,6 +667,7 @@ private: bool HasValues_ = false; ui32 TotalStateSize_ = 0; std::unique_ptr, std::hash, TMKQLAllocator>> HashMap_; + std::unique_ptr, std::hash, TMKQLAllocator>> HashSet_; TPagedArena Arena_; TState(TMemoryUsageInfo* memInfo, size_t width, std::optional filterColumn, const TVector& params, TComputationContext& ctx) @@ -654,7 +686,12 @@ private: TotalStateSize_ += Aggs_.back()->StateSize; } - HashMap_ = std::make_unique, std::hash, TMKQLAllocator>>(TotalStateSize_); + if constexpr (UseSet) { + MKQL_ENSURE(params.empty(), "Only keys are supported"); + HashSet_ = std::make_unique, std::hash, TMKQLAllocator>>(); + } else { + HashMap_ = std::make_unique, std::hash, TMKQLAllocator>>(TotalStateSize_); + } } }; @@ -698,6 +735,27 @@ void FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, TVector +IComputationNode* MakeBlockCombineHashedWrapper( + ui32 totalKeysSize, + TComputationMutables& mutables, + IComputationWideFlowNode* flow, + std::optional filterColumn, + size_t width, + const std::vector& keys, + std::vector>&& keySerializers, + TVector&& aggsParams) { + if (totalKeysSize <= sizeof(ui32)) { + return new TBlockCombineHashedWrapper(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams)); + } + + if (totalKeysSize <= sizeof(ui64)) { + return new TBlockCombineHashedWrapper(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams)); + } + + return new TBlockCombineHashedWrapper(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams)); +} + } IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNodeFactoryContext& ctx) { @@ -838,15 +896,11 @@ IComputationNode* WrapBlockCombineHashed(TCallable& callable, const TComputation } } - if (totalKeysSize <= sizeof(ui32)) { - return new TBlockCombineHashedWrapper(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams)); - } - - if (totalKeysSize <= sizeof(ui64)) { - return new TBlockCombineHashedWrapper(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams)); + if (aggsParams.size() == 0) { + return MakeBlockCombineHashedWrapper(totalKeysSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams)); + } else { + return MakeBlockCombineHashedWrapper(totalKeysSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams)); } - - return new TBlockCombineHashedWrapper(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams)); } } -- cgit v1.3