aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <vvvv@ydb.tech>2022-12-22 14:44:31 +0300
committervvvv <vvvv@ydb.tech>2022-12-22 14:44:31 +0300
commit0ecb69ad9d422ed676fd41df1741115ce62d40b6 (patch)
tree6966085b5b0ca58c6fed332d77cd06976a521bb5
parent0d26f36a6080267e930f1daf6b54476dce512707 (diff)
downloadydb-0ecb69ad9d422ed676fd41df1741115ce62d40b6.tar.gz
move big agg states to arena, adapters for std::unordered_map/std::unordered_set
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp483
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp10
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h6
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp10
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp20
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h154
-rw-r--r--ydb/library/yql/minikql/comp_nodes/ut/mkql_rh_hash_ut.cpp45
7 files changed, 582 insertions, 146 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
index 66aab94ef4..60531a0dbc 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
@@ -15,16 +15,211 @@
#include <arrow/array/array_primitive.h>
#include <arrow/array/builder_primitive.h>
+//#define USE_STD_UNORDERED
+
namespace NKikimr {
namespace NMiniKQL {
namespace {
-class TSSOKey {
+constexpr bool InlineAggState = false;
+
+#ifdef USE_STD_UNORDERED
+template <typename TKey, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>>
+class TDynamicHashMapImpl {
+ using TMapType = std::unordered_map<TKey, std::vector<char>, THash, TEqual>;
+ using const_iterator = typename TMapType::const_iterator;
+ using iterator = typename TMapType::iterator;
+public:
+ TDynamicHashMapImpl(size_t stateSize)
+ : StateSize_(stateSize)
+ {}
+
+ ui64 GetSize() const {
+ return Map_.size();
+ }
+
+ const_iterator Begin() const {
+ return Map_.begin();
+ }
+
+ const_iterator End() const {
+ return Map_.end();
+ }
+
+ bool IsValid(const_iterator iter) const {
+ return true;
+ }
+
+ void Advance(const_iterator& iter) const {
+ ++iter;
+ }
+
+ iterator Insert(const TKey& key, bool& isNew) {
+ auto res = Map_.emplace(key, std::vector<char>());
+ isNew = res.second;
+ if (isNew) {
+ res.first->second.resize(StateSize_);
+ }
+
+ return res.first;
+ }
+
+ const TKey& GetKey(const_iterator it) const {
+ return it->first;
+ }
+
+ char* GetMutablePayload(iterator it) const {
+ return it->second.data();
+ }
+
+ const char* GetPayload(const_iterator it) const {
+ return it->second.data();
+ }
+
+ void CheckGrow() {
+ }
+
+private:
+ const size_t StateSize_;
+ TMapType Map_;
+};
+
+template <typename TKey, typename TPayload, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>>
+class TFixedHashMapImpl {
+ using TMapType = std::unordered_map<TKey, TPayload, THash, TEqual>;
+ using const_iterator = typename TMapType::const_iterator;
+ using iterator = typename TMapType::iterator;
public:
+ ui64 GetSize() const {
+ return Map_.size();
+ }
+
+ const_iterator Begin() const {
+ return Map_.begin();
+ }
+
+ const_iterator End() const {
+ return Map_.end();
+ }
+
+ bool IsValid(const_iterator iter) const {
+ return true;
+ }
+
+ void Advance(const_iterator& iter) const {
+ ++iter;
+ }
+
+ iterator Insert(const TKey& key, bool& isNew) {
+ auto res = Map_.emplace(key, TPayload());
+ isNew = res.second;
+ return res.first;
+ }
+
+ const TKey& GetKey(const_iterator it) const {
+ return it->first;
+ }
+
+ char* GetMutablePayload(iterator it) const {
+ return (char*)&it->second;
+ }
+
+ const char* GetPayload(const_iterator it) const {
+ return (const char*)&it->second;
+ }
+
+ void CheckGrow() {
+ }
+
+private:
+ TMapType Map_;
+};
+
+template <typename TKey, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>>
+class THashSetImpl {
+ using TSetType = std::unordered_set<TKey, THash, TEqual>;
+ using const_iterator = typename TSetType::const_iterator;
+ using iterator = typename TSetType::iterator;
+public:
+ ui64 GetSize() const {
+ return Set_.size();
+ }
+
+ const_iterator Begin() const {
+ return Set_.begin();
+ }
+
+ const_iterator End() const {
+ return Set_.end();
+ }
+
+ bool IsValid(const_iterator iter) const {
+ return true;
+ }
+
+ void Advance(const_iterator& iter) const {
+ ++iter;
+ }
+
+ iterator Insert(const TKey& key, bool& isNew) {
+ auto res = Set_.emplace(key);
+ isNew = res.second;
+ return res.first;
+ }
+
+ void CheckGrow() {
+ }
+
+ const TKey& GetKey(const_iterator it) const {
+ return *it;
+ }
+
+ char* GetMutablePayload(iterator it) const {
+ Y_UNUSED(it);
+ return nullptr;
+ }
+
+ const char* GetPayload(const_iterator it) const {
+ Y_UNUSED(it);
+ return nullptr;
+ }
+
+private:
+ TSetType Set_;
+};
+
+#else
+#define TDynamicHashMapImpl TRobinHoodHashMap
+#define TFixedHashMapImpl TRobinHoodHashFixedMap
+#define THashSetImpl TRobinHoodHashSet
+#endif
+
+using TState8 = ui64;
+static_assert(sizeof(TState8) == 8);
+
+using TState16 = std::pair<ui64, ui64>;
+static_assert(sizeof(TState16) == 16);
+
+using TStateArena = void*;
+static_assert(sizeof(TStateArena) == sizeof(void*));
+
+class TSSOKey {
+private:
static constexpr size_t SSO_Length = 16;
static_assert(SSO_Length < 128); // should fit into 7 bits
+ struct TExternal {
+ ui64 Length_;
+ const char* Ptr_;
+ };
+
+ struct TInplace {
+ ui8 SmallLength_;
+ char Buffer_[SSO_Length];
+ };
+
+public:
static bool CanBeInplace(TStringBuf data) {
return data.Size() + 1 <= sizeof(TSSOKey);
}
@@ -54,9 +249,9 @@ public:
}
}
- void UpdateExternalPointer(const char *ptr) {
+ void UpdateExternalPointer(const char *ptr) const {
Y_ASSERT(!IsInplace());
- U.E.Ptr_ = ptr;
+ const_cast<TExternal&>(U.E).Ptr_ = ptr;
}
private:
@@ -67,14 +262,8 @@ private:
private:
union {
- struct TExternal {
- ui64 Length_;
- const char* Ptr_;
- } E;
- struct TInplace {
- ui8 SmallLength_;
- char Buffer_[SSO_Length];
- } I;
+ TExternal E;
+ TInplace I;
} U;
};
@@ -412,6 +601,7 @@ private:
ui32 totalStateSize = 0;
for (const auto& p : params) {
Aggs_.emplace_back(p.Prepared_->Make(ctx));
+ MKQL_ENSURE(Aggs_.back()->StateSize == p.Prepared_->StateSize, "State size mismatch");
totalStateSize += Aggs_.back()->StateSize;
}
@@ -462,7 +652,7 @@ TSSOKey MakeKey(TStringBuf s) {
}
}
-void MoveKeyToArena(TSSOKey& key, TPagedArena& arena) {
+void MoveKeyToArena(const TSSOKey& key, TPagedArena& arena) {
if (key.IsInplace()) {
return;
}
@@ -483,12 +673,14 @@ TStringBuf GetKeyView(const TSSOKey& key) {
return key.AsView();
}
-template <typename TKey, typename TAggregator, bool UseSet, bool UseFilter, bool Finalize, typename TDerived>
+template <typename TKey, typename TAggregator, typename TFixedAggState, bool UseSet, bool UseFilter, bool Finalize, typename TDerived>
class THashedWrapperBase : public TStatefulWideFlowComputationNode<TDerived> {
public:
- using TSelf = THashedWrapperBase<TKey, TAggregator, UseSet, UseFilter, Finalize, TDerived>;
+ using TSelf = THashedWrapperBase<TKey, TAggregator, TFixedAggState, UseSet, UseFilter, Finalize, TDerived>;
using TBase = TStatefulWideFlowComputationNode<TDerived>;
+ static constexpr bool UseArena = !InlineAggState && std::is_same<TFixedAggState, TStateArena>::value;
+
THashedWrapperBase(TComputationMutables& mutables,
IComputationWideFlowNode* flow,
std::optional<ui32> filterColumn,
@@ -574,8 +766,8 @@ public:
auto str = out.Finish();
TKey key = MakeKey<TKey>(str);
- bool isNew;
if constexpr (UseSet) {
+ bool isNew;
auto iter = s.HashSet_->Insert(key, isNew);
if (isNew) {
if constexpr (std::is_same<TKey, TSSOKey>::value) {
@@ -585,38 +777,10 @@ public:
s.HashSet_->CheckGrow();
}
} else {
- auto iter = s.HashMap_->Insert(key, isNew);
- char* ptr = (char*)s.HashMap_->GetPayload(iter);
- if (isNew) {
- for (size_t i = 0; i < s.Aggs_.size(); ++i) {
- if (output[Keys_.size() + i]) {
- if constexpr (Finalize) {
- s.Aggs_[i]->LoadState(ptr, s.Values_.data(), row);
- } else {
- s.Aggs_[i]->InitKey(ptr, s.Values_.data(), row);
- }
- }
-
- ptr += s.Aggs_[i]->StateSize;
- }
-
- if constexpr (std::is_same<TKey, TSSOKey>::value) {
- MoveKeyToArena(s.HashMap_->GetKey(iter), s.Arena_);
- }
-
- s.HashMap_->CheckGrow();
+ if (!InlineAggState) {
+ Insert(*s.HashFixedMap_, key, row, output, s);
} else {
- for (size_t i = 0; i < s.Aggs_.size(); ++i) {
- if (output[Keys_.size() + i]) {
- if constexpr (Finalize) {
- s.Aggs_[i]->UpdateState(ptr, s.Values_.data(), row);
- } else {
- s.Aggs_[i]->UpdateKey(ptr, s.Values_.data(), row);
- }
- }
-
- ptr += s.Aggs_[i]->StateSize;
- }
+ Insert(*s.HashMap_, key, row, output, s);
}
}
}
@@ -631,7 +795,11 @@ public:
if constexpr (UseSet) {
size = s.HashSet_->GetSize();
} else {
- size = s.HashMap_->GetSize();
+ if (!InlineAggState) {
+ size = s.HashFixedMap_->GetSize();
+ } else {
+ size = s.HashMap_->GetSize();
+ }
}
TVector<std::unique_ptr<IKeyColumnBuilder>> keyBuilders;
@@ -641,7 +809,7 @@ public:
if constexpr (UseSet) {
for (auto iter = s.HashSet_->Begin(); iter != s.HashSet_->End(); s.HashSet_->Advance(iter)) {
- if (s.HashSet_->GetPSL(iter) < 0) {
+ if (!s.HashSet_->IsValid(iter)) {
continue;
}
@@ -661,25 +829,10 @@ public:
}
}
- for (auto iter = s.HashMap_->Begin(); iter != s.HashMap_->End(); s.HashMap_->Advance(iter)) {
- if (s.HashMap_->GetPSL(iter) < 0) {
- continue;
- }
-
- const TKey& key = s.HashMap_->GetKey(iter);
- auto ptr = (const char*)s.HashMap_->GetPayload(iter);
- TInputBuffer in(GetKeyView<TKey>(key));
- for (auto& kb : keyBuilders) {
- kb->Add(in);
- }
-
- for (size_t i = 0; i < s.Aggs_.size(); ++i) {
- if (output[Keys_.size() + i]) {
- aggBuilders[i]->Add(ptr);
- }
-
- ptr += s.Aggs_[i]->StateSize;
- }
+ if (!InlineAggState) {
+ Iterate(*s.HashFixedMap_, keyBuilders, aggBuilders, output, s);
+ } else {
+ Iterate(*s.HashMap_, keyBuilders, aggBuilders, output, s);
}
for (size_t i = 0; i < s.Aggs_.size(); ++i) {
@@ -712,8 +865,9 @@ private:
bool IsFinished_ = false;
bool HasValues_ = false;
ui32 TotalStateSize_ = 0;
- std::unique_ptr<TRobinHoodHashMap<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>> HashMap_;
- std::unique_ptr<TRobinHoodHashSet<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>> HashSet_;
+ std::unique_ptr<TDynamicHashMapImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>> HashMap_;
+ std::unique_ptr<THashSetImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>> HashSet_;
+ std::unique_ptr<TFixedHashMapImpl<TKey, TFixedAggState, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>> HashFixedMap_;
TPagedArena Arena_;
TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const TVector<TAggParams<TAggregator>>& params, TComputationContext& ctx)
@@ -728,14 +882,19 @@ private:
for (const auto& p : params) {
Aggs_.emplace_back(p.Prepared_->Make(ctx));
+ MKQL_ENSURE(Aggs_.back()->StateSize == p.Prepared_->StateSize, "State size mismatch");
TotalStateSize_ += Aggs_.back()->StateSize;
}
if constexpr (UseSet) {
MKQL_ENSURE(params.empty(), "Only keys are supported");
- HashSet_ = std::make_unique<TRobinHoodHashSet<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>>();
+ HashSet_ = std::make_unique<THashSetImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>>();
} else {
- HashMap_ = std::make_unique<TRobinHoodHashMap<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>>(TotalStateSize_);
+ if (!InlineAggState) {
+ HashFixedMap_ = std::make_unique<TFixedHashMapImpl<TKey, TFixedAggState, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>>();
+ } else {
+ HashMap_ = std::make_unique<TDynamicHashMapImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>>(TotalStateSize_);
+ }
}
}
};
@@ -752,6 +911,92 @@ private:
return *static_cast<TState*>(state.AsBoxed().Get());
}
+ template <typename THash>
+ void Insert(THash& hash, const TKey& key, ui64 row, NUdf::TUnboxedValue*const* output, TState& s) const {
+ bool isNew;
+ auto iter = hash.Insert(key, isNew);
+ char* payload = (char*)hash.GetMutablePayload(iter);
+ char* ptr;
+
+ if (isNew) {
+ if constexpr (UseArena) {
+ ptr = (char*)s.Arena_.Alloc(s.TotalStateSize_);
+ *(char**)payload = ptr;
+ } else {
+ ptr = payload;
+ }
+
+ for (size_t i = 0; i < s.Aggs_.size(); ++i) {
+ if (output[Keys_.size() + i]) {
+ if constexpr (Finalize) {
+ s.Aggs_[i]->LoadState(ptr, s.Values_.data(), row);
+ } else {
+ s.Aggs_[i]->InitKey(ptr, s.Values_.data(), row);
+ }
+ }
+
+ ptr += s.Aggs_[i]->StateSize;
+ }
+
+ if constexpr (std::is_same<TKey, TSSOKey>::value) {
+ MoveKeyToArena(hash.GetKey(iter), s.Arena_);
+ }
+
+ hash.CheckGrow();
+ } else {
+ if constexpr (UseArena) {
+ ptr = *(char**)payload;
+ } else {
+ ptr = payload;
+ }
+
+ for (size_t i = 0; i < s.Aggs_.size(); ++i) {
+ if (output[Keys_.size() + i]) {
+ if constexpr (Finalize) {
+ s.Aggs_[i]->UpdateState(ptr, s.Values_.data(), row);
+ } else {
+ s.Aggs_[i]->UpdateKey(ptr, s.Values_.data(), row);
+ }
+ }
+
+ ptr += s.Aggs_[i]->StateSize;
+ }
+ }
+ }
+
+ template <typename THash>
+ void Iterate(THash& hash, const TVector<std::unique_ptr<IKeyColumnBuilder>>& keyBuilders,
+ const TVector<std::unique_ptr<IAggColumnBuilder>>& aggBuilders,
+ NUdf::TUnboxedValue*const* output, TState& s) const {
+ for (auto iter = hash.Begin(); iter != hash.End(); hash.Advance(iter)) {
+ if (!hash.IsValid(iter)) {
+ continue;
+ }
+
+ const TKey& key = hash.GetKey(iter);
+ auto payload = (const char*)hash.GetPayload(iter);
+ const char* ptr;
+ if constexpr (UseArena) {
+ ptr = *(const char**)payload;
+ } else {
+ ptr = payload;
+ }
+
+ TInputBuffer in(GetKeyView<TKey>(key));
+ for (auto& kb : keyBuilders) {
+ kb->Add(in);
+ }
+
+ for (size_t i = 0; i < s.Aggs_.size(); ++i) {
+ if (output[Keys_.size() + i]) {
+ aggBuilders[i]->Add(ptr);
+ }
+
+ ptr += s.Aggs_[i]->StateSize;
+ }
+ }
+ }
+
ui64 GetBatchLength(const NUdf::TUnboxedValue* columns) const {
return TArrowBlock::From(columns[Width_ - 1]).GetDatum().scalar_as<arrow::UInt64Scalar>().value;
}
@@ -766,11 +1011,11 @@ private:
std::vector<std::unique_ptr<IKeySerializer>> KeySerializers_;
};
-template <typename TKey, bool UseSet, bool UseFilter>
-class TBlockCombineHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorCombineKeys, UseSet, UseFilter, false, TBlockCombineHashedWrapper<TKey, UseSet, UseFilter>> {
+template <typename TKey, typename TFixedAggState, bool UseSet, bool UseFilter>
+class TBlockCombineHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter>> {
public:
- using TSelf = TBlockCombineHashedWrapper<TKey, UseSet, UseFilter>;
- using TBase = THashedWrapperBase<TKey, IBlockAggregatorCombineKeys, UseSet, UseFilter, false, TSelf>;
+ using TSelf = TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter>;
+ using TBase = THashedWrapperBase<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, TSelf>;
TBlockCombineHashedWrapper(TComputationMutables& mutables,
IComputationWideFlowNode* flow,
@@ -783,11 +1028,11 @@ public:
{}
};
-template <typename TKey, bool UseSet>
-class TBlockMergeFinalizeHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, UseSet, false, true, TBlockMergeFinalizeHashedWrapper<TKey, UseSet>> {
+template <typename TKey, typename TFixedAggState, bool UseSet>
+class TBlockMergeFinalizeHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet>> {
public:
- using TSelf = TBlockMergeFinalizeHashedWrapper<TKey, UseSet>;
- using TBase = THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, UseSet, false, true, TSelf>;
+ using TSelf = TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet>;
+ using TBase = THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, TSelf>;
TBlockMergeFinalizeHashedWrapper(TComputationMutables& mutables,
IComputationWideFlowNode* flow,
@@ -835,7 +1080,8 @@ std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorFinalizeKeys>> PrepareB
}
template <typename TAggregator>
-void FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, std::optional<ui32> filterColumn, TVector<TAggParams<TAggregator>>& aggsParams, const TTypeEnvironment& env) {
+ui32 FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, std::optional<ui32> filterColumn, TVector<TAggParams<TAggregator>>& aggsParams, const TTypeEnvironment& env) {
+ ui32 totalStateSize = 0;
for (ui32 i = 0; i < aggsVal->GetValuesCount(); ++i) {
auto aggVal = AS_VALUE(TTupleLiteral, aggsVal->GetValue(i));
auto name = AS_VALUE(TDataLiteral, aggVal->GetValue(0))->AsValue().AsStringRef();
@@ -847,13 +1093,38 @@ void FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, std::optional<
TAggParams<TAggregator> p;
p.Prepared_ = PrepareBlockAggregator<TAggregator>(GetBlockAggregatorFactory(name), tupleType, filterColumn, argColumns, env);
+ totalStateSize += p.Prepared_->StateSize;
aggsParams.emplace_back(std::move(p));
}
+
+ return totalStateSize;
+}
+
+template <bool UseSet, bool UseFilter, typename TKey>
+IComputationNode* MakeBlockCombineHashedWrapper(
+ ui32 totalStateSize,
+ TComputationMutables& mutables,
+ IComputationWideFlowNode* flow,
+ std::optional<ui32> filterColumn,
+ size_t width,
+ const std::vector<TKeyParams>& keys,
+ std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers,
+ TVector<TAggParams<IBlockAggregatorCombineKeys>>&& aggsParams) {
+ if (totalStateSize <= sizeof(TState8)) {
+ return new TBlockCombineHashedWrapper<TKey, TState8, UseSet, UseFilter>(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams));
+ }
+
+ if (totalStateSize <= sizeof(TState16)) {
+ return new TBlockCombineHashedWrapper<TKey, TState16, UseSet, UseFilter>(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams));
+ }
+
+ return new TBlockCombineHashedWrapper<TKey, TStateArena, UseSet, UseFilter>(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams));
}
template <bool UseSet, bool UseFilter>
IComputationNode* MakeBlockCombineHashedWrapper(
ui32 totalKeysSize,
+ ui32 totalStateSize,
TComputationMutables& mutables,
IComputationWideFlowNode* flow,
std::optional<ui32> filterColumn,
@@ -862,19 +1133,41 @@ IComputationNode* MakeBlockCombineHashedWrapper(
std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers,
TVector<TAggParams<IBlockAggregatorCombineKeys>>&& aggsParams) {
if (totalKeysSize <= sizeof(ui32)) {
- return new TBlockCombineHashedWrapper<ui32, UseSet, UseFilter>(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams));
+ return MakeBlockCombineHashedWrapper<UseSet, UseFilter, ui32>(totalStateSize, mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams));
}
if (totalKeysSize <= sizeof(ui64)) {
- return new TBlockCombineHashedWrapper<ui64, UseSet, UseFilter>(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams));
+ return MakeBlockCombineHashedWrapper<UseSet, UseFilter, ui64>(totalStateSize, mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams));
+ }
+
+ return MakeBlockCombineHashedWrapper<UseSet, UseFilter, TSSOKey>(totalStateSize, mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams));
+}
+
+template <typename TKey, bool UseSet>
+IComputationNode* MakeBlockMergeFinalizeHashedWrapper(
+ ui32 totalStateSize,
+ TComputationMutables& mutables,
+ IComputationWideFlowNode* flow,
+ size_t width,
+ const std::vector<TKeyParams>& keys,
+ std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers,
+ TVector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams) {
+
+ if (totalStateSize <= sizeof(TState8)) {
+ return new TBlockMergeFinalizeHashedWrapper<TKey, TState8, UseSet>(mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams));
+ }
+
+ if (totalStateSize <= sizeof(TState16)) {
+ return new TBlockMergeFinalizeHashedWrapper<TKey, TState16, UseSet>(mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams));
}
- return new TBlockCombineHashedWrapper<TSSOKey, UseSet, UseFilter>(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams));
+ return new TBlockMergeFinalizeHashedWrapper<TKey, TStateArena, UseSet>(mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams));
}
template <bool UseSet>
IComputationNode* MakeBlockMergeFinalizeHashedWrapper(
ui32 totalKeysSize,
+ ui32 totalStateSize,
TComputationMutables& mutables,
IComputationWideFlowNode* flow,
size_t width,
@@ -882,14 +1175,14 @@ IComputationNode* MakeBlockMergeFinalizeHashedWrapper(
std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers,
TVector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams) {
if (totalKeysSize <= sizeof(ui32)) {
- return new TBlockMergeFinalizeHashedWrapper<ui32, UseSet>(mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams));
+ return MakeBlockMergeFinalizeHashedWrapper<ui32, UseSet>(totalStateSize, mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams));
}
if (totalKeysSize <= sizeof(ui64)) {
- return new TBlockMergeFinalizeHashedWrapper<ui64, UseSet>(mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams));
+ return MakeBlockMergeFinalizeHashedWrapper<ui64, UseSet>(totalStateSize, mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams));
}
- return new TBlockMergeFinalizeHashedWrapper<TSSOKey, UseSet>(mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams));
+ return MakeBlockMergeFinalizeHashedWrapper<TSSOKey, UseSet>(totalStateSize, mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams));
}
void PrepareKeys(const std::vector<TKeyParams>& keys, ui32& totalKeysSize, std::vector<std::unique_ptr<IKeySerializer>>& keySerializers) {
@@ -1005,7 +1298,7 @@ IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNod
auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2));
TVector<TAggParams<IBlockAggregatorCombineAll>> aggsParams;
- FillAggParams<IBlockAggregatorCombineAll>(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env);
+ ui32 totalStateSize = FillAggParams<IBlockAggregatorCombineAll>(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env);
return new TBlockCombineAllWrapper(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), std::move(aggsParams));
}
@@ -1032,7 +1325,7 @@ IComputationNode* WrapBlockCombineHashed(TCallable& callable, const TComputation
auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(3));
TVector<TAggParams<IBlockAggregatorCombineKeys>> aggsParams;
- FillAggParams<IBlockAggregatorCombineKeys>(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env);
+ ui32 totalStateSize = FillAggParams<IBlockAggregatorCombineKeys>(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env);
ui32 totalKeysSize = 0;
std::vector<std::unique_ptr<IKeySerializer>> keySerializers;
@@ -1040,15 +1333,15 @@ IComputationNode* WrapBlockCombineHashed(TCallable& callable, const TComputation
if (filterColumn) {
if (aggsParams.size() == 0) {
- return MakeBlockCombineHashedWrapper<true, true>(totalKeysSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
+ return MakeBlockCombineHashedWrapper<true, true>(totalKeysSize, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
} else {
- return MakeBlockCombineHashedWrapper<false, true>(totalKeysSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
+ return MakeBlockCombineHashedWrapper<false, true>(totalKeysSize, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
}
} else {
if (aggsParams.size() == 0) {
- return MakeBlockCombineHashedWrapper<true, false>(totalKeysSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
+ return MakeBlockCombineHashedWrapper<true, false>(totalKeysSize, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
} else {
- return MakeBlockCombineHashedWrapper<false, false>(totalKeysSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
+ return MakeBlockCombineHashedWrapper<false, false>(totalKeysSize, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
}
}
}
@@ -1070,16 +1363,16 @@ IComputationNode* WrapBlockMergeFinalizeHashed(TCallable& callable, const TCompu
auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2));
TVector<TAggParams<IBlockAggregatorFinalizeKeys>> aggsParams;
- FillAggParams<IBlockAggregatorFinalizeKeys>(aggsVal, tupleType, {}, aggsParams, ctx.Env);
+ ui32 totalStateSize = FillAggParams<IBlockAggregatorFinalizeKeys>(aggsVal, tupleType, {}, aggsParams, ctx.Env);
ui32 totalKeysSize = 0;
std::vector<std::unique_ptr<IKeySerializer>> keySerializers;
PrepareKeys(keys, totalKeysSize, keySerializers);
if (aggsParams.size() == 0) {
- return MakeBlockMergeFinalizeHashedWrapper<true>(totalKeysSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
+ return MakeBlockMergeFinalizeHashedWrapper<true>(totalKeysSize, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
} else {
- return MakeBlockMergeFinalizeHashedWrapper<false>(totalKeysSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
+ return MakeBlockMergeFinalizeHashedWrapper<false>(totalKeysSize, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), keys, std::move(keySerializers), std::move(aggsParams));
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp
index e40ae152f1..55134d8a37 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp
@@ -261,8 +261,11 @@ public:
template <typename TTag>
class TPreparedCountAll : public TTag::TPreparedAggregator {
public:
+ using TBase = typename TTag::TPreparedAggregator;
+
TPreparedCountAll(std::optional<ui32> filterColumn, ui32 argColumn)
- : FilterColumn_(filterColumn)
+ : TBase(sizeof(TState))
+ , FilterColumn_(filterColumn)
, ArgColumn_(argColumn)
{}
@@ -278,8 +281,11 @@ private:
template <typename TTag>
class TPreparedCount : public TTag::TPreparedAggregator {
public:
+ using TBase = typename TTag::TPreparedAggregator;
+
TPreparedCount(std::optional<ui32> filterColumn, ui32 argColumn)
- : FilterColumn_(filterColumn)
+ : TBase(sizeof(TState))
+ , FilterColumn_(filterColumn)
, ArgColumn_(argColumn)
{}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h
index be65fc192d..0e734ccc6b 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h
@@ -87,6 +87,12 @@ public:
virtual ~IPreparedBlockAggregator() = default;
virtual std::unique_ptr<T> Make(TComputationContext& ctx) const = 0;
+
+ const ui32 StateSize;
+
+ explicit IPreparedBlockAggregator(ui32 stateSize)
+ : StateSize(stateSize)
+ {}
};
class IBlockAggregatorFactory {
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp
index 190ec70d0e..9f05a9299d 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp
@@ -437,9 +437,12 @@ private:
template <typename TTag, typename TIn, typename TInScalar, typename TBuilder, bool IsMin>
class TPreparedMinMaxBlockAggregatorNullableOrScalar : public TTag::TPreparedAggregator {
public:
+ using TBase = typename TTag::TPreparedAggregator;
+
TPreparedMinMaxBlockAggregatorNullableOrScalar(std::optional<ui32> filterColumn, ui32 argColumn,
const std::shared_ptr<arrow::DataType>& builderDataType)
- : FilterColumn_(filterColumn)
+ : TBase(sizeof(TState<TIn, IsMin>))
+ , FilterColumn_(filterColumn)
, ArgColumn_(argColumn)
, BuilderDataType_(builderDataType)
{}
@@ -457,9 +460,12 @@ private:
template <typename TTag, typename TIn, typename TInScalar, typename TBuilder, bool IsMin>
class TPreparedMinMaxBlockAggregator : public TTag::TPreparedAggregator {
public:
+ using TBase = typename TTag::TPreparedAggregator;
+
TPreparedMinMaxBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn,
const std::shared_ptr<arrow::DataType>& builderDataType)
- : FilterColumn_(filterColumn)
+ : TBase(sizeof(TSimpleState<TIn, IsMin>))
+ , FilterColumn_(filterColumn)
, ArgColumn_(argColumn)
, BuilderDataType_(builderDataType)
{}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp
index cf1df6f051..9d93077092 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp
@@ -705,9 +705,12 @@ private:
template <typename TTag, typename TIn, typename TSum, typename TBuilder, typename TInScalar>
class TPreparedSumBlockAggregatorNullableOrScalar : public TTag::TPreparedAggregator {
public:
+ using TBase = typename TTag::TPreparedAggregator;
+
TPreparedSumBlockAggregatorNullableOrScalar(std::optional<ui32> filterColumn, ui32 argColumn,
const std::shared_ptr<arrow::DataType>& builderDataType)
- : FilterColumn_(filterColumn)
+ : TBase(sizeof(TSumState<TSum>))
+ , FilterColumn_(filterColumn)
, ArgColumn_(argColumn)
, BuilderDataType_(builderDataType)
{}
@@ -725,9 +728,12 @@ private:
template <typename TTag, typename TIn, typename TSum, typename TBuilder, typename TInScalar>
class TPreparedSumBlockAggregator : public TTag::TPreparedAggregator {
public:
+ using TBase = typename TTag::TPreparedAggregator;
+
TPreparedSumBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn,
const std::shared_ptr<arrow::DataType>& builderDataType)
- : FilterColumn_(filterColumn)
+ : TBase(sizeof(TSumSimpleState<TSum>))
+ , FilterColumn_(filterColumn)
, ArgColumn_(argColumn)
, BuilderDataType_(builderDataType)
{}
@@ -825,9 +831,12 @@ public:
template <typename TTag, typename TIn, typename TInScalar>
class TPreparedAvgBlockAggregator : public TTag::TPreparedAggregator {
public:
+ using TBase = typename TTag::TPreparedAggregator;
+
TPreparedAvgBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn,
const std::shared_ptr<arrow::DataType>& builderDataType)
- : FilterColumn_(filterColumn)
+ : TBase(sizeof(TAvgState))
+ , FilterColumn_(filterColumn)
, ArgColumn_(argColumn)
, BuilderDataType_(builderDataType)
{}
@@ -844,8 +853,11 @@ private:
class TPreparedAvgBlockAggregatorOverState : public TFinalizeKeysTag::TPreparedAggregator {
public:
+ using TBase = TFinalizeKeysTag::TPreparedAggregator;
+
TPreparedAvgBlockAggregatorOverState(ui32 argColumn)
- : ArgColumn_(argColumn)
+ : TBase(sizeof(TAvgState))
+ , ArgColumn_(argColumn)
{}
std::unique_ptr<typename TFinalizeKeysTag::TAggregator> Make(TComputationContext& ctx) const final {
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h b/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h
index a37b84c33b..1d813b5167 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h
@@ -5,6 +5,7 @@
#include <vector>
#include <util/digest/city.h>
+#include <util/generic/scope.h>
namespace NKikimr {
namespace NMiniKQL {
@@ -15,19 +16,29 @@ class TRobinHoodHashBase {
protected:
using TPSLStorage = i32;
- using TVec = std::vector<char, TAllocator>;
-
explicit TRobinHoodHashBase(ui64 initialCapacity = 1u << 8)
: Capacity(initialCapacity)
+ , Allocator()
, SelfHash(GetSelfHash(this))
{
Y_ENSURE((Capacity & (Capacity - 1)) == 0);
}
+ ~TRobinHoodHashBase() {
+ if (Data) {
+ Allocator.deallocate(Data, DataEnd - Data);
+ }
+ }
+
+ TRobinHoodHashBase(const TRobinHoodHashBase&) = delete;
+ TRobinHoodHashBase(TRobinHoodHashBase&&) = delete;
+ void operator=(const TRobinHoodHashBase&) = delete;
+ void operator=(TRobinHoodHashBase&&) = delete;
+
public:
// returns iterator
Y_FORCE_INLINE char* Insert(TKey key, bool& isNew) {
- auto ret = InsertImpl(key, isNew, Capacity, Data);
+ auto ret = InsertImpl(key, isNew, Capacity, Data, DataEnd);
Size += isNew ? 1 : 0;
return ret;
}
@@ -44,19 +55,19 @@ public:
}
const char* Begin() const {
- return Data.data();
+ return Data;
}
const char* End() const {
- return Data.data() + Data.size();
+ return DataEnd;
}
char* Begin() {
- return Data.data();
+ return Data;
}
char* End() {
- return Data.data() + Data.size();
+ return DataEnd;
}
void Advance(char*& ptr) {
@@ -67,6 +78,10 @@ public:
ptr += AsDeriv().GetCellSize();
}
+ bool IsValid(const char* ptr) {
+ return GetPSL(ptr) >= 0;
+ }
+
static const TPSLStorage& GetPSL(const char* ptr) {
return *(const TPSLStorage*)ptr;
}
@@ -75,6 +90,10 @@ public:
return *(const TKey*)(ptr + sizeof(TPSLStorage));
}
+ static TKey& GetKey(char* ptr) {
+ return *(TKey*)(ptr + sizeof(TPSLStorage));
+ }
+
const void* GetPayload(const char* ptr) {
return AsDeriv().GetPayloadImpl(ptr);
}
@@ -83,21 +102,18 @@ public:
return *(TPSLStorage*)ptr;
}
- static TKey& GetKey(char* ptr) {
- return *(TKey*)(ptr + sizeof(TPSLStorage));
- }
-
- void* GetPayload(char* ptr) {
+ void* GetMutablePayload(char* ptr) {
return AsDeriv().GetPayloadImpl(ptr);
}
private:
- Y_FORCE_INLINE char* InsertImpl(TKey key, bool& isNew, ui64 capacity, TVec& data) {
+ Y_FORCE_INLINE char* InsertImpl(TKey key, bool& isNew, ui64 capacity, char* data, char* dataEnd) {
isNew = false;
ui64 bucket = (SelfHash ^ THash()(key)) & (capacity - 1);
- char* ptr = data.data() + AsDeriv().GetCellSize() * bucket;
+ char* ptr = data + AsDeriv().GetCellSize() * bucket;
TPSLStorage distance = 0;
char* returnPtr;
+ typename TDeriv::TPayloadStore tmpPayload;
for (;;) {
if (GetPSL(ptr) < 0) {
isNew = true;
@@ -115,23 +131,23 @@ private:
returnPtr = ptr;
std::swap(distance, GetPSL(ptr));
std::swap(key, GetKey(ptr));
- AsDeriv().SavePayload(GetPayload(ptr));
+ AsDeriv().SavePayload(GetPayload(ptr), tmpPayload);
isNew = true;
++distance;
- AdvancePointer(ptr, data);
+ AdvancePointer(ptr, data, dataEnd);
break;
}
++distance;
- AdvancePointer(ptr, data);
+ AdvancePointer(ptr, data, dataEnd);
}
for (;;) {
if (GetPSL(ptr) < 0) {
GetPSL(ptr) = distance;
GetKey(ptr) = key;
- AsDeriv().RestorePayload(GetPayload(ptr));
+ AsDeriv().RestorePayload(GetMutablePayload(ptr), tmpPayload);
return returnPtr; // for original key
}
@@ -139,36 +155,41 @@ private:
// swap keys & state
std::swap(distance, GetPSL(ptr));
std::swap(key, GetKey(ptr));
- AsDeriv().SwapPayload(GetPayload(ptr));
+ AsDeriv().SwapPayload(GetMutablePayload(ptr), tmpPayload);
}
++distance;
- AdvancePointer(ptr, data);
+ AdvancePointer(ptr, data, dataEnd);
}
}
void Grow() {
- TVec newData;
auto newCapacity = Capacity * 2;
- Allocate(newCapacity, newData);
+ char *newData, *newDataEnd;
+ Allocate(newCapacity, newData, newDataEnd);
+ Y_DEFER {
+ Allocator.deallocate(newData, newDataEnd - newData);
+ };
+
for (auto iter = Begin(); iter != End(); Advance(iter)) {
if (GetPSL(iter) < 0) {
continue;
}
bool isNew;
- auto newIter = InsertImpl(GetKey(iter), isNew, newCapacity, newData);
+ auto newIter = InsertImpl(GetKey(iter), isNew, newCapacity, newData, newDataEnd);
Y_ASSERT(isNew);
- AsDeriv().CopyPayload(GetPayload(newIter), GetPayload(iter));
+ AsDeriv().CopyPayload(GetMutablePayload(newIter), GetPayload(iter));
}
- Data.swap(newData);
Capacity = newCapacity;
+ std::swap(Data, newData);
+ std::swap(DataEnd, newDataEnd);
}
- void AdvancePointer(char*& ptr, TVec& data) const {
+ void AdvancePointer(char*& ptr, char* begin, char* end) const {
ptr += AsDeriv().GetCellSize();
- ptr = (ptr == data.data() + data.size()) ? data.data() : ptr;
+ ptr = (ptr == end) ? begin : ptr;
}
static ui64 GetSelfHash(void* self) {
@@ -179,13 +200,15 @@ private:
protected:
void Init() {
- Allocate(Capacity, Data);
+ Allocate(Capacity, Data, DataEnd);
}
private:
- void Allocate(ui64 capacity, TVec& data) const {
- data.resize(AsDeriv().GetCellSize() * capacity);
- char* ptr = data.data();
+ void Allocate(ui64 capacity, char*& data, char*& dataEnd) {
+ ui64 bytes = capacity * AsDeriv().GetCellSize();
+ data = Allocator.allocate(bytes);
+ dataEnd = data + bytes;
+ char* ptr = data;
for (ui64 i = 0; i < capacity; ++i) {
GetPSL(ptr) = -1;
ptr += AsDeriv().GetCellSize();
@@ -203,8 +226,10 @@ private:
private:
ui64 Size = 0;
ui64 Capacity;
- TVec Data;
+ TAllocator Allocator;
const ui64 SelfHash;
+ char* Data = nullptr;
+ char* DataEnd = nullptr;
};
template <typename TKey, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>>
@@ -212,6 +237,7 @@ class TRobinHoodHashMap : public TRobinHoodHashBase<TKey, TEqual, THash, TAlloca
public:
using TSelf = TRobinHoodHashMap<TKey, TEqual, THash, TAllocator>;
using TBase = TRobinHoodHashBase<TKey, TEqual, THash, TAllocator, TSelf>;
+ using TPayloadStore = int;
explicit TRobinHoodHashMap(ui32 payloadSize, ui64 initialCapacity = 1u << 8)
: TBase(initialCapacity)
@@ -239,15 +265,18 @@ public:
memcpy(dst, src, PayloadSize);
}
- void SavePayload(const void* p) {
+ void SavePayload(const void* p, int& store) {
+ Y_UNUSED(store);
memcpy(TmpPayload.data(), p, PayloadSize);
}
- void RestorePayload(void* p) {
+ void RestorePayload(void* p, const int& store) {
+ Y_UNUSED(store);
memcpy(p, TmpPayload.data(), PayloadSize);
}
- void SwapPayload(void* p) {
+ void SwapPayload(void* p, int& store) {
+ Y_UNUSED(store);
memcpy(TmpPayload2.data(), p, PayloadSize);
memcpy(p, TmpPayload.data(), PayloadSize);
TmpPayload2.swap(TmpPayload);
@@ -256,7 +285,50 @@ public:
private:
const ui32 CellSize;
const ui32 PayloadSize;
- typename TBase::TVec TmpPayload, TmpPayload2;
+ using TVec = std::vector<char, TAllocator>;
+ TVec TmpPayload, TmpPayload2;
+};
+
+template <typename TKey, typename TPayload, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>>
+class TRobinHoodHashFixedMap : public TRobinHoodHashBase<TKey, TEqual, THash, TAllocator, TRobinHoodHashFixedMap<TKey, TPayload, TEqual, THash, TAllocator>> {
+public:
+ using TSelf = TRobinHoodHashFixedMap<TKey, TPayload, TEqual, THash, TAllocator>;
+ using TBase = TRobinHoodHashBase<TKey, TEqual, THash, TAllocator, TSelf>;
+ using TPayloadStore = TPayload;
+
+ explicit TRobinHoodHashFixedMap(ui64 initialCapacity = 1u << 8)
+ : TBase(initialCapacity)
+ {
+ TBase::Init();
+ }
+
+ ui32 GetCellSize() const {
+ return sizeof(typename TBase::TPSLStorage) + sizeof(TKey) + sizeof(TPayload);
+ }
+
+ void* GetPayloadImpl(char* ptr) {
+ return ptr + sizeof(typename TBase::TPSLStorage) + sizeof(TKey);
+ }
+
+ const void* GetPayloadImpl(const char* ptr) {
+ return ptr + sizeof(typename TBase::TPSLStorage) + sizeof(TKey);
+ }
+
+ void CopyPayload(void* dst, const void* src) {
+ *(TPayload*)dst = *(const TPayload*)src;
+ }
+
+ void SavePayload(const void* p, TPayload& store) {
+ store = *(const TPayload*)p;
+ }
+
+ void RestorePayload(void* p, const TPayload& store) {
+ *(TPayload*)p = store;
+ }
+
+ void SwapPayload(void* p, TPayload& store) {
+ std::swap(*(TPayload*)p, store);
+ }
};
template <typename TKey, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>>
@@ -264,6 +336,7 @@ class TRobinHoodHashSet : public TRobinHoodHashBase<TKey, TEqual, THash, TAlloca
public:
using TSelf = TRobinHoodHashSet<TKey, TEqual, THash, TAllocator>;
using TBase = TRobinHoodHashBase<TKey, TEqual, THash, TAllocator, TSelf>;
+ using TPayloadStore = int;
explicit TRobinHoodHashSet(ui64 initialCapacity = 1u << 8)
: TBase(initialCapacity)
@@ -290,16 +363,19 @@ public:
Y_UNUSED(src);
}
- void SavePayload(const void* p) {
+ void SavePayload(const void* p, int& store) {
Y_UNUSED(p);
+ Y_UNUSED(store);
}
- void RestorePayload(void* p) {
+ void RestorePayload(void* p, const int& store) {
Y_UNUSED(p);
+ Y_UNUSED(store);
}
- void SwapPayload(void* p) {
+ void SwapPayload(void* p, int& store) {
Y_UNUSED(p);
+ Y_UNUSED(store);
}
};
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_rh_hash_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_rh_hash_ut.cpp
index d47bcb319a..ba81221b79 100644
--- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_rh_hash_ut.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_rh_hash_ut.cpp
@@ -21,17 +21,17 @@ Y_UNIT_TEST_SUITE(TMiniKQLRobinHoodHashTest) {
UNIT_ASSERT_VALUES_EQUAL(isNew, inserted);
it->second += i;
if (isNew) {
- *(i64*)rh.GetPayload(iter) = i;
+ *(i64*)rh.GetMutablePayload(iter) = i;
rh.CheckGrow();
} else {
- *(i64*)rh.GetPayload(iter) += i;
+ *(i64*)rh.GetMutablePayload(iter) += i;
}
UNIT_ASSERT_VALUES_EQUAL(h.size(), rh.GetSize());
}
for (auto it = rh.Begin(); it != rh.End(); rh.Advance(it)) {
- if (rh.GetPSL(it) < 0) {
+ if (!rh.IsValid(it)) {
continue;
}
@@ -45,6 +45,43 @@ Y_UNIT_TEST_SUITE(TMiniKQLRobinHoodHashTest) {
UNIT_ASSERT(h.empty());
}
+ Y_UNIT_TEST(FixedMap) {
+ TRobinHoodHashFixedMap<i32, i64> rh;
+ std::unordered_map<i32, i64> h;
+ for (ui64 i = 0; i < 10000; ++i) {
+ auto k = i % 1000;
+ auto [it, inserted] = h.emplace(k, 0);
+ bool isNew;
+ auto iter = rh.Insert(k, isNew);
+ UNIT_ASSERT_VALUES_EQUAL(rh.GetKey(iter), k);
+ UNIT_ASSERT_VALUES_EQUAL(isNew, inserted);
+ it->second += i;
+ if (isNew) {
+ *(i64*)rh.GetMutablePayload(iter) = i;
+ rh.CheckGrow();
+ } else {
+ *(i64*)rh.GetMutablePayload(iter) += i;
+ }
+
+ UNIT_ASSERT_VALUES_EQUAL(h.size(), rh.GetSize());
+ }
+
+ for (auto it = rh.Begin(); it != rh.End(); rh.Advance(it)) {
+ if (!rh.IsValid(it)) {
+ continue;
+ }
+
+ auto key = rh.GetKey(it);
+ auto hit = h.find(key);
+ UNIT_ASSERT(hit != h.end());
+ UNIT_ASSERT_VALUES_EQUAL(*(i64*)rh.GetPayload(it), hit->second);
+ h.erase(key);
+ }
+
+ UNIT_ASSERT(h.empty());
+ }
+
+
Y_UNIT_TEST(Set) {
TRobinHoodHashSet<i32> rh;
std::unordered_set<i32> h;
@@ -63,7 +100,7 @@ Y_UNIT_TEST_SUITE(TMiniKQLRobinHoodHashTest) {
}
for (auto it = rh.Begin(); it != rh.End(); rh.Advance(it)) {
- if (rh.GetPSL(it) < 0) {
+ if (!rh.IsValid(it)) {
continue;
}