summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <[email protected]>2022-12-12 19:13:23 +0300
committervvvv <[email protected]>2022-12-12 19:13:23 +0300
commit749d81a237f2ca42a7697f93ae5fb0fce07e1cb6 (patch)
tree4f7ba1cd689f72ed9fc16dfa983cbce3e38c2ae6
parent79fddfef7319acf25a962af6959614ed1a1539f9 (diff)
use set if only keys are used
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp154
1 files changed, 104 insertions, 50 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
index 46b5fd1b6da..fdbe826fda1 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
@@ -478,10 +478,10 @@ TStringBuf GetKeyView(const TSSOKey& key) {
return key.AsView();
}
-template <typename TKey>
-class TBlockCombineHashedWrapper : public TStatefulWideFlowComputationNode<TBlockCombineHashedWrapper<TKey>> {
+template <typename TKey, bool UseSet>
+class TBlockCombineHashedWrapper : public TStatefulWideFlowComputationNode<TBlockCombineHashedWrapper<TKey, UseSet>> {
public:
- using TSelf = TBlockCombineHashedWrapper<TKey>;
+ using TSelf = TBlockCombineHashedWrapper<TKey, UseSet>;
using TBase = TStatefulWideFlowComputationNode<TSelf>;
TBlockCombineHashedWrapper(TComputationMutables& mutables,
@@ -541,29 +541,40 @@ public:
auto str = out.Finish();
TKey key = MakeKey<TKey>(str);
bool isNew;
- auto iter = s.HashMap_->Insert(key, isNew);
- char* ptr = (char*)s.HashMap_->GetPayload(iter);
- if (isNew) {
- for (size_t i = 0; i < s.Aggs_.size(); ++i) {
- if (output[Keys_.size() + i]) {
- s.Aggs_[i]->InitKey(ptr, s.Values_.data(), row);
+ if constexpr (UseSet) {
+ auto iter = s.HashSet_->Insert(key, isNew);
+ if (isNew) {
+ if constexpr (std::is_same<TKey, TSSOKey>::value) {
+ MoveKeyToArena(s.HashSet_->GetKey(iter), s.Arena_);
}
- ptr += s.Aggs_[i]->StateSize;
+ s.HashSet_->CheckGrow();
}
+ } else {
+ auto iter = s.HashMap_->Insert(key, isNew);
+ char* ptr = (char*)s.HashMap_->GetPayload(iter);
+ if (isNew) {
+ for (size_t i = 0; i < s.Aggs_.size(); ++i) {
+ if (output[Keys_.size() + i]) {
+ s.Aggs_[i]->InitKey(ptr, s.Values_.data(), row);
+ }
- if constexpr (std::is_same<TKey, TSSOKey>::value) {
- MoveKeyToArena(s.HashMap_->GetKey(iter), s.Arena_);
- }
+ ptr += s.Aggs_[i]->StateSize;
+ }
- s.HashMap_->CheckGrow();
- } else {
- for (size_t i = 0; i < s.Aggs_.size(); ++i) {
- if (output[Keys_.size() + i]) {
- s.Aggs_[i]->UpdateKey(ptr, s.Values_.data(), row);
+ if constexpr (std::is_same<TKey, TSSOKey>::value) {
+ MoveKeyToArena(s.HashMap_->GetKey(iter), s.Arena_);
}
- ptr += s.Aggs_[i]->StateSize;
+ s.HashMap_->CheckGrow();
+ } else {
+ for (size_t i = 0; i < s.Aggs_.size(); ++i) {
+ if (output[Keys_.size() + i]) {
+ s.Aggs_[i]->UpdateKey(ptr, s.Values_.data(), row);
+ }
+
+ ptr += s.Aggs_[i]->StateSize;
+ }
}
}
}
@@ -574,35 +585,61 @@ public:
}
// export results, TODO: split by batches
- auto size = s.HashMap_->GetSize();
+ ui64 size;
+ if constexpr (UseSet) {
+ size = s.HashSet_->GetSize();
+ } else {
+ size = s.HashMap_->GetSize();
+ }
+
TVector<std::unique_ptr<IKeyColumnBuilder>> keyBuilders;
for (const auto& ks : KeySerializers_) {
keyBuilders.emplace_back(ks->MakeBuilder(size, ctx));
}
- TVector<std::unique_ptr<IAggColumnBuilder>> aggBuilders;
- for (const auto& a : s.Aggs_) {
- aggBuilders.emplace_back(a->MakeBuilder(size));
- }
+ if constexpr (UseSet) {
+ for (auto iter = s.HashSet_->Begin(); iter != s.HashSet_->End(); s.HashSet_->Advance(iter)) {
+ if (s.HashSet_->GetPSL(iter) < 0) {
+ continue;
+ }
- for (auto iter = s.HashMap_->Begin(); iter != s.HashMap_->End(); s.HashMap_->Advance(iter)) {
- if (s.HashMap_->GetPSL(iter) < 0) {
- continue;
+ const TKey& key = s.HashSet_->GetKey(iter);
+ TInputBuffer in(GetKeyView<TKey>(key));
+ for (auto& kb : keyBuilders) {
+ kb->Add(in);
+ }
}
+ } else {
+ TVector<std::unique_ptr<IAggColumnBuilder>> aggBuilders;
+ for (const auto& a : s.Aggs_) {
+ aggBuilders.emplace_back(a->MakeBuilder(size));
+ }
+
+ for (auto iter = s.HashMap_->Begin(); iter != s.HashMap_->End(); s.HashMap_->Advance(iter)) {
+ if (s.HashMap_->GetPSL(iter) < 0) {
+ continue;
+ }
+
+ const TKey& key = s.HashMap_->GetKey(iter);
+ auto ptr = (const char*)s.HashMap_->GetPayload(iter);
+ TInputBuffer in(GetKeyView<TKey>(key));
+ for (auto& kb : keyBuilders) {
+ kb->Add(in);
+ }
- const TKey& key = s.HashMap_->GetKey(iter);
- auto ptr = (const char*)s.HashMap_->GetPayload(iter);
- TInputBuffer in(GetKeyView<TKey>(key));
- for (auto& kb : keyBuilders) {
- kb->Add(in);
+ for (size_t i = 0; i < s.Aggs_.size(); ++i) {
+ if (output[Keys_.size() + i]) {
+ aggBuilders[i]->Add(ptr);
+ }
+
+ ptr += s.Aggs_[i]->StateSize;
+ }
}
for (size_t i = 0; i < s.Aggs_.size(); ++i) {
if (output[Keys_.size() + i]) {
- aggBuilders[i]->Add(ptr);
+ *output[Keys_.size() + i] = aggBuilders[i]->Build();
}
-
- ptr += s.Aggs_[i]->StateSize;
}
}
@@ -612,12 +649,6 @@ public:
}
}
- for (size_t i = 0; i < s.Aggs_.size(); ++i) {
- if (output[Keys_.size() + i]) {
- *output[Keys_.size() + i] = aggBuilders[i]->Build();
- }
- }
-
MKQL_ENSURE(output[OutputWidth_ - 1], "Block size should not be marked as unused");
*output[OutputWidth_ - 1] = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(size)));
return EFetchResult::One;
@@ -636,6 +667,7 @@ private:
bool HasValues_ = false;
ui32 TotalStateSize_ = 0;
std::unique_ptr<TRobinHoodHashMap<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>> HashMap_;
+ std::unique_ptr<TRobinHoodHashSet<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>> HashSet_;
TPagedArena Arena_;
TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const TVector<TAggParams>& params, TComputationContext& ctx)
@@ -654,7 +686,12 @@ private:
TotalStateSize_ += Aggs_.back()->StateSize;
}
- HashMap_ = std::make_unique<TRobinHoodHashMap<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>>(TotalStateSize_);
+ if constexpr (UseSet) {
+ MKQL_ENSURE(params.empty(), "Only keys are supported");
+ HashSet_ = std::make_unique<TRobinHoodHashSet<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>>();
+ } else {
+ HashMap_ = std::make_unique<TRobinHoodHashMap<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>>(TotalStateSize_);
+ }
}
};
@@ -698,6 +735,27 @@ void FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, TVector<TAggPa
}
}
+template <bool UseSet>
+IComputationNode* MakeBlockCombineHashedWrapper(
+ ui32 totalKeysSize,
+ TComputationMutables& mutables,
+ IComputationWideFlowNode* flow,
+ std::optional<ui32> filterColumn,
+ size_t width,
+ const std::vector<TKeyParams>& keys,
+ std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers,
+ TVector<TAggParams>&& aggsParams) {
+ if (totalKeysSize <= sizeof(ui32)) {
+ return new TBlockCombineHashedWrapper<ui32, UseSet>(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams));
+ }
+
+ if (totalKeysSize <= sizeof(ui64)) {
+ return new TBlockCombineHashedWrapper<ui64, UseSet>(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams));
+ }
+
+ return new TBlockCombineHashedWrapper<TSSOKey, UseSet>(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams));
+}
+
}
IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
@@ -838,15 +896,11 @@ IComputationNode* WrapBlockCombineHashed(TCallable& callable, const TComputation
}
}
- if (totalKeysSize <= sizeof(ui32)) {
- return new TBlockCombineHashedWrapper<ui32>(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
- }
-
- if (totalKeysSize <= sizeof(ui64)) {
- return new TBlockCombineHashedWrapper<ui64>(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
+ if (aggsParams.size() == 0) {
+ return MakeBlockCombineHashedWrapper<true>(totalKeysSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
+ } else {
+ return MakeBlockCombineHashedWrapper<false>(totalKeysSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
}
-
- return new TBlockCombineHashedWrapper<TSSOKey>(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
}
}