diff options
author | ivanmorozov <ivanmorozov@yandex-team.com> | 2023-01-20 09:02:12 +0300 |
---|---|---|
committer | ivanmorozov <ivanmorozov@yandex-team.com> | 2023-01-20 09:02:12 +0300 |
commit | 2db3d52777ce3feefdfc7ee592b95935ed59452a (patch) | |
tree | ae721ea1ae00d14c0f58034fb0a94dcd0f1f9339 | |
parent | e1b2f2bb4261482bac275b4c9fbb36775fec2501 (diff) | |
download | ydb-2db3d52777ce3feefdfc7ee592b95935ed59452a.tar.gz |
rh hash speed up
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h | 96 |
1 files changed, 72 insertions, 24 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h b/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h index f18a853468..fea98570e2 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h +++ b/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h @@ -10,13 +10,43 @@ namespace NKikimr { namespace NMiniKQL { +template <class TKey> +struct TRobinHoodCacheHashUsageDetector { + static constexpr bool UseCache = !std::is_arithmetic<TKey>::value; +}; + //TODO: only POD key & payloads are now supported -template <typename TKey, typename TEqual, typename THash, typename TAllocator, typename TDeriv> +template <typename TKey, typename TEqual, typename THash, typename TAllocator, typename TDeriv, bool CacheHash = TRobinHoodCacheHashUsageDetector<TKey>::UseCache> class TRobinHoodHashBase { protected: THash HashLocal; TEqual EqualLocal; - using TPSLStorage = i32; + template <bool CacheHashForPSL> + struct TPSLStorageImpl; + + template <> + struct TPSLStorageImpl<true> { + i32 Distance = -1; + ui32 Hash = 0; + TPSLStorageImpl() = default; + TPSLStorageImpl(const ui64 hash) + : Distance(0) + , Hash(hash& Max<ui32>()) { + + } + }; + + template <> + struct TPSLStorageImpl<false> { + i32 Distance = -1; + TPSLStorageImpl() = default; + TPSLStorageImpl(const ui64 /*hash*/) + : Distance(0) { + + } + }; + + using TPSLStorage = TPSLStorageImpl<CacheHash>; explicit TRobinHoodHashBase(const ui64 initialCapacity, THash hash, TEqual equal) : HashLocal(std::move(hash)) @@ -42,7 +72,7 @@ protected: public: // returns iterator Y_FORCE_INLINE char* Insert(TKey key, bool& isNew) { - auto ret = InsertImpl(key, isNew, Capacity, Data, DataEnd); + auto ret = InsertImpl(key, HashLocal(key), isNew, Capacity, Data, DataEnd); Size += isNew ? 1 : 0; return ret; } @@ -54,10 +84,14 @@ public: } } + ui64 GetCapacity() const { + return Capacity; + } + void Clear() { char* ptr = Data; for (ui64 i = 0; i < Capacity; ++i) { - GetPSL(ptr) = -1; + GetPSL(ptr).Distance = -1; ptr += AsDeriv().GetCellSize(); } Size = 0; @@ -96,7 +130,7 @@ public: } bool IsValid(const char* ptr) { - return GetPSL(ptr) >= 0; + return GetPSL(ptr).Distance >= 0; } static const TPSLStorage& GetPSL(const char* ptr) { @@ -124,58 +158,66 @@ public: } private: - Y_FORCE_INLINE char* InsertImpl(TKey key, bool& isNew, ui64 capacity, char* data, char* dataEnd) { + Y_FORCE_INLINE char* InsertImpl(TKey key, const ui64 hash, bool& isNew, ui64 capacity, char* data, char* dataEnd) { isNew = false; - ui64 bucket = (SelfHash ^ HashLocal(key)) & (capacity - 1); + TPSLStorage psl(hash); + ui64 bucket = (SelfHash ^ hash) & (capacity - 1); char* ptr = data + AsDeriv().GetCellSize() * bucket; - TPSLStorage distance = 0; char* returnPtr; typename TDeriv::TPayloadStore tmpPayload; for (;;) { - if (GetPSL(ptr) < 0) { + auto& pslPtr = GetPSL(ptr); + if (pslPtr.Distance < 0) { isNew = true; - GetPSL(ptr) = distance; + pslPtr = psl; GetKey(ptr) = key; return ptr; } - if (EqualLocal(GetKey(ptr), key)) { - return ptr; + if constexpr (CacheHash) { + if (pslPtr.Hash == psl.Hash && EqualLocal(GetKey(ptr), key)) { + return ptr; + } + } else { + if (EqualLocal(GetKey(ptr), key)) { + return ptr; + } } - if (distance > GetPSL(ptr)) { + if (psl.Distance > pslPtr.Distance) { // swap keys & state returnPtr = ptr; - std::swap(distance, GetPSL(ptr)); + std::swap(psl, pslPtr); std::swap(key, GetKey(ptr)); AsDeriv().SavePayload(GetPayload(ptr), tmpPayload); isNew = true; - ++distance; + ++psl.Distance; AdvancePointer(ptr, data, dataEnd); break; } - ++distance; + ++psl.Distance; AdvancePointer(ptr, data, dataEnd); } for (;;) { - if (GetPSL(ptr) < 0) { - GetPSL(ptr) = distance; + auto& pslPtr = GetPSL(ptr); + if (pslPtr.Distance < 0) { + pslPtr = psl; GetKey(ptr) = key; AsDeriv().RestorePayload(GetMutablePayload(ptr), tmpPayload); return returnPtr; // for original key } - if (distance > GetPSL(ptr)) { + if (psl.Distance > pslPtr.Distance) { // swap keys & state - std::swap(distance, GetPSL(ptr)); + std::swap(psl, pslPtr); std::swap(key, GetKey(ptr)); AsDeriv().SwapPayload(GetMutablePayload(ptr), tmpPayload); } - ++distance; + ++psl.Distance; AdvancePointer(ptr, data, dataEnd); } } @@ -189,12 +231,18 @@ private: }; for (auto iter = Begin(); iter != End(); Advance(iter)) { - if (GetPSL(iter) < 0) { + if (GetPSL(iter).Distance < 0) { continue; } bool isNew; - auto newIter = InsertImpl(GetKey(iter), isNew, newCapacity, newData, newDataEnd); + auto& key = GetKey(iter); + char* newIter = nullptr; + if constexpr (CacheHash) { + newIter = InsertImpl(key, GetPSL(iter).Hash, isNew, newCapacity, newData, newDataEnd); + } else { + newIter = InsertImpl(key, HashLocal(key), isNew, newCapacity, newData, newDataEnd); + } Y_ASSERT(isNew); AsDeriv().CopyPayload(GetMutablePayload(newIter), GetPayload(iter)); } @@ -227,7 +275,7 @@ private: dataEnd = data + bytes; char* ptr = data; for (ui64 i = 0; i < capacity; ++i) { - GetPSL(ptr) = -1; + GetPSL(ptr).Distance = -1; ptr += AsDeriv().GetCellSize(); } } |