about | summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author ivanmorozov <ivanmorozov@yandex-team.com> 2023-01-20 09:02:12 +0300
committer ivanmorozov <ivanmorozov@yandex-team.com> 2023-01-20 09:02:12 +0300
commit 2db3d52777ce3feefdfc7ee592b95935ed59452a (patch)
tree ae721ea1ae00d14c0f58034fb0a94dcd0f1f9339
parent e1b2f2bb4261482bac275b4c9fbb36775fec2501 (diff)
download ydb-2db3d52777ce3feefdfc7ee592b95935ed59452a.tar.gz
rh hash speed up
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h 96
1 files changed, 72 insertions, 24 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h b/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h
index f18a853468..fea98570e2 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h
@@ -10,13 +10,43 @@
namespace NKikimr {
namespace NMiniKQL {
+template <class TKey>
+struct TRobinHoodCacheHashUsageDetector {
+ static constexpr bool UseCache = !std::is_arithmetic<TKey>::value;
+};
+
//TODO: only POD key & payloads are now supported
-template <typename TKey, typename TEqual, typename THash, typename TAllocator, typename TDeriv>
+template <typename TKey, typename TEqual, typename THash, typename TAllocator, typename TDeriv, bool CacheHash = TRobinHoodCacheHashUsageDetector<TKey>::UseCache>
class TRobinHoodHashBase {
protected:
THash HashLocal;
TEqual EqualLocal;
- using TPSLStorage = i32;
+ template <bool CacheHashForPSL>
+ struct TPSLStorageImpl;
+
+ template <>
+ struct TPSLStorageImpl<true> {
+ i32 Distance = -1;
+ ui32 Hash = 0;
+ TPSLStorageImpl() = default;
+ TPSLStorageImpl(const ui64 hash)
+ : Distance(0)
+ , Hash(hash& Max<ui32>()) {
+
+ }
+ };
+
+ template <>
+ struct TPSLStorageImpl<false> {
+ i32 Distance = -1;
+ TPSLStorageImpl() = default;
+ TPSLStorageImpl(const ui64 /*hash*/)
+ : Distance(0) {
+
+ }
+ };
+
+ using TPSLStorage = TPSLStorageImpl<CacheHash>;
explicit TRobinHoodHashBase(const ui64 initialCapacity, THash hash, TEqual equal)
: HashLocal(std::move(hash))
@@ -42,7 +72,7 @@ protected:
public:
// returns iterator
Y_FORCE_INLINE char* Insert(TKey key, bool& isNew) {
- auto ret = InsertImpl(key, isNew, Capacity, Data, DataEnd);
+ auto ret = InsertImpl(key, HashLocal(key), isNew, Capacity, Data, DataEnd);
Size += isNew ? 1 : 0;
return ret;
}
@@ -54,10 +84,14 @@ public:
}
}
+ ui64 GetCapacity() const {
+ return Capacity;
+ }
+
void Clear() {
char* ptr = Data;
for (ui64 i = 0; i < Capacity; ++i) {
- GetPSL(ptr) = -1;
+ GetPSL(ptr).Distance = -1;
ptr += AsDeriv().GetCellSize();
}
Size = 0;
@@ -96,7 +130,7 @@ public:
}
bool IsValid(const char* ptr) {
- return GetPSL(ptr) >= 0;
+ return GetPSL(ptr).Distance >= 0;
}
static const TPSLStorage& GetPSL(const char* ptr) {
@@ -124,58 +158,66 @@ public:
}
private:
- Y_FORCE_INLINE char* InsertImpl(TKey key, bool& isNew, ui64 capacity, char* data, char* dataEnd) {
+ Y_FORCE_INLINE char* InsertImpl(TKey key, const ui64 hash, bool& isNew, ui64 capacity, char* data, char* dataEnd) {
isNew = false;
- ui64 bucket = (SelfHash ^ HashLocal(key)) & (capacity - 1);
+ TPSLStorage psl(hash);
+ ui64 bucket = (SelfHash ^ hash) & (capacity - 1);
char* ptr = data + AsDeriv().GetCellSize() * bucket;
- TPSLStorage distance = 0;
char* returnPtr;
typename TDeriv::TPayloadStore tmpPayload;
for (;;) {
- if (GetPSL(ptr) < 0) {
+ auto& pslPtr = GetPSL(ptr);
+ if (pslPtr.Distance < 0) {
isNew = true;
- GetPSL(ptr) = distance;
+ pslPtr = psl;
GetKey(ptr) = key;
return ptr;
}
- if (EqualLocal(GetKey(ptr), key)) {
- return ptr;
+ if constexpr (CacheHash) {
+ if (pslPtr.Hash == psl.Hash && EqualLocal(GetKey(ptr), key)) {
+ return ptr;
+ }
+ } else {
+ if (EqualLocal(GetKey(ptr), key)) {
+ return ptr;
+ }
}
- if (distance > GetPSL(ptr)) {
+ if (psl.Distance > pslPtr.Distance) {
// swap keys & state
returnPtr = ptr;
- std::swap(distance, GetPSL(ptr));
+ std::swap(psl, pslPtr);
std::swap(key, GetKey(ptr));
AsDeriv().SavePayload(GetPayload(ptr), tmpPayload);
isNew = true;
- ++distance;
+ ++psl.Distance;
AdvancePointer(ptr, data, dataEnd);
break;
}
- ++distance;
+ ++psl.Distance;
AdvancePointer(ptr, data, dataEnd);
}
for (;;) {
- if (GetPSL(ptr) < 0) {
- GetPSL(ptr) = distance;
+ auto& pslPtr = GetPSL(ptr);
+ if (pslPtr.Distance < 0) {
+ pslPtr = psl;
GetKey(ptr) = key;
AsDeriv().RestorePayload(GetMutablePayload(ptr), tmpPayload);
return returnPtr; // for original key
}
- if (distance > GetPSL(ptr)) {
+ if (psl.Distance > pslPtr.Distance) {
// swap keys & state
- std::swap(distance, GetPSL(ptr));
+ std::swap(psl, pslPtr);
std::swap(key, GetKey(ptr));
AsDeriv().SwapPayload(GetMutablePayload(ptr), tmpPayload);
}
- ++distance;
+ ++psl.Distance;
AdvancePointer(ptr, data, dataEnd);
}
}
@@ -189,12 +231,18 @@ private:
};
for (auto iter = Begin(); iter != End(); Advance(iter)) {
- if (GetPSL(iter) < 0) {
+ if (GetPSL(iter).Distance < 0) {
continue;
}
bool isNew;
- auto newIter = InsertImpl(GetKey(iter), isNew, newCapacity, newData, newDataEnd);
+ auto& key = GetKey(iter);
+ char* newIter = nullptr;
+ if constexpr (CacheHash) {
+ newIter = InsertImpl(key, GetPSL(iter).Hash, isNew, newCapacity, newData, newDataEnd);
+ } else {
+ newIter = InsertImpl(key, HashLocal(key), isNew, newCapacity, newData, newDataEnd);
+ }
Y_ASSERT(isNew);
AsDeriv().CopyPayload(GetMutablePayload(newIter), GetPayload(iter));
}
@@ -227,7 +275,7 @@ private:
dataEnd = data + bytes;
char* ptr = data;
for (ui64 i = 0; i < capacity; ++i) {
- GetPSL(ptr) = -1;
+ GetPSL(ptr).Distance = -1;
ptr += AsDeriv().GetCellSize();
}
}