diff options
author | vvvv <vvvv@yandex-team.com> | 2024-11-07 04:19:26 +0300 |
---|---|---|
committer | vvvv <vvvv@yandex-team.com> | 2024-11-07 04:29:50 +0300 |
commit | 2661be00f3bc47590fda9218bf0386d6355c8c88 (patch) | |
tree | 3d316c07519191283d31c5f537efc6aabb42a2f0 /yql/essentials/minikql/compact_hash_ut.cpp | |
parent | cf2a23963ac10add28c50cc114fbf48953eca5aa (diff) | |
download | ydb-2661be00f3bc47590fda9218bf0386d6355c8c88.tar.gz |
Moved yql/minikql YQL-19206
init
[nodiff:caesar]
commit_hash:d1182ef7d430ccf7e4d37ed933c7126d7bd5d6e4
Diffstat (limited to 'yql/essentials/minikql/compact_hash_ut.cpp')
-rw-r--r-- | yql/essentials/minikql/compact_hash_ut.cpp | 404 |
1 files changed, 404 insertions, 0 deletions
diff --git a/yql/essentials/minikql/compact_hash_ut.cpp b/yql/essentials/minikql/compact_hash_ut.cpp new file mode 100644 index 0000000000..7c29e34e20 --- /dev/null +++ b/yql/essentials/minikql/compact_hash_ut.cpp @@ -0,0 +1,404 @@ +#include "compact_hash.h" + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/bitmap.h> +#include <util/generic/ylimits.h> +#include <util/generic/hash_set.h> +#include <util/generic/maybe.h> +#include <util/generic/xrange.h> +#include <util/random/shuffle.h> + +using namespace NKikimr; +using namespace NKikimr::NCHash; + +Y_UNIT_TEST_SUITE(TCompactHashTest) { + + template <typename TItem> + void TestListPoolPagesImpl(size_t listSize, ui16 countOfLists, size_t expectedListCapacity, ui32 expectedMark) { + using TPool = TListPool<TItem>; + TAlignedPagePool pagePool(__LOCATION__); + TPool pool(pagePool); + UNIT_ASSERT(countOfLists > 1); + + THashSet<TItem*> lists; + + void* pageAddr = nullptr; + for (size_t i = 0; i < countOfLists / 2; ++i) { + TItem* l = pool.template GetList<TItem>(listSize); + UNIT_ASSERT_VALUES_EQUAL(expectedMark, TPool::GetMark(l)); + UNIT_ASSERT_VALUES_EQUAL(listSize, TListPoolBase::GetPartListSize(l)); + UNIT_ASSERT_VALUES_EQUAL(expectedListCapacity, TListPoolBase::GetListCapacity(l)); + if (0 == i) { + pageAddr = TAlignedPagePool::GetPageStart(l); + } else { + // All lists are from the same page + UNIT_ASSERT_VALUES_EQUAL(pageAddr, TAlignedPagePool::GetPageStart(l)); + } + UNIT_ASSERT(lists.insert(l).second); + } + // Return all lists except one + while (lists.size() > 1) { + auto it = lists.begin(); + pool.ReturnList(*it); + lists.erase(it); + } + + for (size_t i = 1; i < countOfLists; ++i) { + TItem* l = pool.template GetList<TItem>(listSize); + UNIT_ASSERT_VALUES_EQUAL(expectedMark, TPool::GetMark(l)); + UNIT_ASSERT_VALUES_EQUAL(listSize, TListPoolBase::GetPartListSize(l)); + UNIT_ASSERT_VALUES_EQUAL(expectedListCapacity, TListPoolBase::GetListCapacity(l)); + UNIT_ASSERT_VALUES_EQUAL(pageAddr, TAlignedPagePool::GetPageStart(l)); + UNIT_ASSERT(lists.insert(l).second); + } + + for (auto l: lists) { + pool.ReturnList(l); + } + + THashSet<TItem*> lists2; + for (size_t i = 0; i < countOfLists; ++i) { + TItem* l = pool.template GetList<TItem>(listSize); + // All lists are from the same page + UNIT_ASSERT_VALUES_EQUAL(pageAddr, TAlignedPagePool::GetPageStart(l)); + UNIT_ASSERT(1 == lists.erase(l)); + UNIT_ASSERT(lists2.insert(l).second); + } + TItem* l = pool.template GetList<TItem>(listSize); // New page + UNIT_ASSERT_VALUES_UNEQUAL(pageAddr, TAlignedPagePool::GetPageStart(l)); + UNIT_ASSERT_VALUES_EQUAL(listSize, TListPoolBase::GetPartListSize(l)); + UNIT_ASSERT_VALUES_EQUAL(expectedListCapacity, TListPoolBase::GetListCapacity(l)); + pool.ReturnList(l); + for (auto l: lists2) { + pool.ReturnList(l); + } + } + + template <typename TItem> + void TestListPoolLargeImpl() { + using TPool = TListPool<TItem>; + TAlignedPagePool pagePool(__LOCATION__); + TPool pool(pagePool); + const size_t listSize = TListPoolBase::GetMaxListSize<TItem>(); + TItem* l = pool.template GetList<TItem>(listSize); + pool.template IncrementList<TItem>(l); + UNIT_ASSERT_VALUES_EQUAL((ui32)TListPoolBase::LARGE_MARK, TListPoolBase::GetMark(l)); + UNIT_ASSERT_VALUES_EQUAL(listSize + 1, TListPoolBase::GetPartListSize(l)); + TListPoolBase::SetPartListSize(l, TListPoolBase::GetListCapacity(l)); + pool.template IncrementList<TItem>(l); + UNIT_ASSERT_VALUES_EQUAL(1, TListPoolBase::GetPartListSize(l)); + UNIT_ASSERT_VALUES_EQUAL(1 + TListPoolBase::GetListCapacity(l), TListPoolBase::GetFullListSize(l)); + pool.ReturnList(l); + } + + Y_UNIT_TEST(TestListPoolSmallPagesByte) { + ui16 count = TListPoolBase::GetSmallPageCapacity<ui8>(2); + TestListPoolPagesImpl<ui8>(2, count, 2, TListPoolBase::SMALL_MARK); + } + + Y_UNIT_TEST(TestListPoolMediumPagesByte) { + size_t listSize = TListPoolBase::MAX_SMALL_LIST_SIZE + 1; + ui16 count = TListPoolBase::GetMediumPageCapacity<ui8>(listSize); + TestListPoolPagesImpl<ui8>(listSize, count, FastClp2(listSize), TListPoolBase::MEDIUM_MARK); + } + + Y_UNIT_TEST(TestListPoolLargPagesByte) { + TestListPoolLargeImpl<ui8>(); + } + + Y_UNIT_TEST(TestListPoolSmallPagesUi64) { + ui16 count = TListPoolBase::GetSmallPageCapacity<ui64>(2); + TestListPoolPagesImpl<ui64>(2, count, 2, TListPoolBase::SMALL_MARK); + } + + Y_UNIT_TEST(TestListPoolMediumPagesUi64) { + size_t listSize = TListPoolBase::MAX_SMALL_LIST_SIZE + 1; + ui16 count = TListPoolBase::GetMediumPageCapacity<ui64>(listSize); + TestListPoolPagesImpl<ui64>(listSize, count, FastClp2(listSize), TListPoolBase::MEDIUM_MARK); + } + + Y_UNIT_TEST(TestListPoolLargPagesUi64) { + TestListPoolLargeImpl<ui64>(); + } + + struct TItem { + ui8 A[256]; + }; + + Y_UNIT_TEST(TestListPoolSmallPagesObj) { + ui16 count = TListPoolBase::GetSmallPageCapacity<TItem>(2); + TestListPoolPagesImpl<TItem>(2, count, 2, TListPoolBase::SMALL_MARK); + } + + Y_UNIT_TEST(TestListPoolMediumPagesObj) { + size_t listSize = TListPoolBase::MAX_SMALL_LIST_SIZE + 1; + ui16 count = TListPoolBase::GetMediumPageCapacity<TItem>(listSize); + TestListPoolPagesImpl<TItem>(listSize, count, FastClp2(listSize), TListPoolBase::MEDIUM_MARK); + } + + Y_UNIT_TEST(TestListPoolLargPagesObj) { + TestListPoolLargeImpl<TItem>(); + } + + struct TItemHash { + template <typename T> + size_t operator() (T num) const { + return num % 13; + } + }; + + template <typename TItem> + void TestHashImpl() { + const ui32 elementsCount = 32; + const ui64 sumKeysTarget = elementsCount * (elementsCount - 1) / 2; + + const ui32 addition = 20; + const ui64 sumValuesTarget = sumKeysTarget + addition * elementsCount; + + TAlignedPagePool pagePool(__LOCATION__); + TCompactHash<TItem, TItem, TItemHash> hash(pagePool); + + TVector<TItem> elements(elementsCount); + std::iota(elements.begin(), elements.end(), 0); + Shuffle(elements.begin(), elements.end()); + for (TItem i: elements) { + hash.Insert(i, i + addition); + } + + { + decltype(hash) hash2(std::move(hash)); + decltype(hash) hash3(pagePool); + hash3.Swap(hash2); + hash = hash3; + } + + for (TItem i: elements) { + UNIT_ASSERT(hash.Has(i)); + UNIT_ASSERT(hash.Find(i).Ok()); + UNIT_ASSERT_VALUES_EQUAL(i + addition, hash.Find(i).Get().second); + } + + UNIT_ASSERT(!hash.Has(elementsCount + 1)); + UNIT_ASSERT(!hash.Has(elementsCount + 10)); + UNIT_ASSERT_VALUES_EQUAL(elementsCount, hash.Size()); + UNIT_ASSERT_VALUES_EQUAL(elementsCount, hash.UniqSize()); + + ui64 sumKeys = 0; + ui64 sumValues = 0; + for (auto it = hash.Iterate(); it.Ok(); ++it) { + UNIT_ASSERT_VALUES_EQUAL(it.Get().first + addition, it.Get().second); + + sumKeys += it.Get().first; + sumValues += it.Get().second; + } + UNIT_ASSERT_VALUES_EQUAL(sumKeys, sumKeysTarget); + UNIT_ASSERT_VALUES_EQUAL(sumValues, sumValuesTarget); + } + + template <typename TItem> + void TestMultiHashImpl() { + const ui32 keysCount = 10; + const ui64 elementsCount = keysCount * (keysCount + 1) / 2; + + TAlignedPagePool pagePool(__LOCATION__); + TCompactMultiHash<TItem, TItem, TItemHash> hash(pagePool); + + TVector<TItem> keys(keysCount); + std::iota(keys.begin(), keys.end(), 0); + Shuffle(keys.begin(), keys.end()); + + ui64 sumKeysTarget = 0; + ui64 sumValuesTarget = 0; + for (TItem k: keys) { + sumKeysTarget += k; + for (TItem i = 0; i < k + 1; ++i) { + hash.Insert(k, i); + sumValuesTarget += i; + } + } + + { + decltype(hash) hash2(std::move(hash)); + decltype(hash) hash3(pagePool); + hash3.Swap(hash2); + hash = hash3; + } + + for (TItem k: keys) { + UNIT_ASSERT(hash.Has(k)); + UNIT_ASSERT_VALUES_EQUAL(k + 1, hash.Count(k)); + auto it = hash.Find(k); + UNIT_ASSERT(it.Ok()); + TDynBitMap set; + for (; it.Ok(); ++it) { + UNIT_ASSERT_VALUES_EQUAL(k, it.GetKey()); + UNIT_ASSERT(!set.Test(it.GetValue())); + set.Set(it.GetValue()); + } + UNIT_ASSERT_VALUES_EQUAL(set.Count(), k + 1); + } + + UNIT_ASSERT(!hash.Has(keysCount + 1)); + UNIT_ASSERT(!hash.Has(keysCount + 10)); + UNIT_ASSERT(!hash.Find(keysCount + 1).Ok()); + + UNIT_ASSERT_VALUES_EQUAL(elementsCount, hash.Size()); + UNIT_ASSERT_VALUES_EQUAL(keysCount, hash.UniqSize()); + + ui64 sumKeys = 0; + ui64 sumValues = 0; + TMaybe<TItem> prevKey; + for (auto it = hash.Iterate(); it.Ok(); ++it) { + const auto key = it.GetKey(); + if (prevKey != key) { + sumKeys += key; + for (auto valIt = it.MakeCurrentKeyIter(); valIt.Ok(); ++valIt) { + UNIT_ASSERT_VALUES_EQUAL(key, valIt.GetKey()); + sumValues += valIt.GetValue(); + } + prevKey = key; + } + } + UNIT_ASSERT_VALUES_EQUAL(sumKeys, sumKeysTarget); + UNIT_ASSERT_VALUES_EQUAL(sumValues, sumValuesTarget); + + // Test large lists + TItem val = 0; + for (size_t i = 0; i < TListPoolBase::GetLargePageCapacity<TItem>(); ++i) { + hash.Insert(keysCount, val++); + } + + TItem check = 0; + for (auto i = hash.Find(keysCount); i.Ok(); ++i) { + UNIT_ASSERT_VALUES_EQUAL(keysCount, i.GetKey()); + UNIT_ASSERT_VALUES_EQUAL(check++, i.GetValue()); + } + UNIT_ASSERT_VALUES_EQUAL(check, val); + UNIT_ASSERT_VALUES_EQUAL(TListPoolBase::GetLargePageCapacity<TItem>(), hash.Count(keysCount)); + + for (size_t i = 0; i < TListPoolBase::GetLargePageCapacity<TItem>() + 1; ++i) { + hash.Insert(keysCount, val++); + } + + { + decltype(hash) hash2(std::move(hash)); + decltype(hash) hash3(pagePool); + hash3.Swap(hash2); + hash = hash3; + } + + check = 0; + for (auto i = hash.Find(keysCount); i.Ok(); ++i) { + UNIT_ASSERT_VALUES_EQUAL(keysCount, i.GetKey()); + UNIT_ASSERT_VALUES_EQUAL(check++, i.GetValue()); + } + UNIT_ASSERT_VALUES_EQUAL(check, val); + UNIT_ASSERT_VALUES_EQUAL(2 * TListPoolBase::GetLargePageCapacity<TItem>() + 1, hash.Count(keysCount)); + } + + template <typename TItem> + void TestSetImpl() { + const ui32 elementsCount = 32; + const ui64 sumKeysTarget = elementsCount * (elementsCount - 1) / 2; + + TAlignedPagePool pagePool(__LOCATION__); + TCompactHashSet<TItem, TItemHash> hash(pagePool); + + TVector<TItem> elements(elementsCount); + std::iota(elements.begin(), elements.end(), 0); + Shuffle(elements.begin(), elements.end()); + for (TItem i: elements) { + hash.Insert(i); + } + + { + decltype(hash) hash2(std::move(hash)); + decltype(hash) hash3(pagePool); + hash3.Swap(hash2); + hash = hash3; + } + + for (TItem i: elements) { + UNIT_ASSERT(hash.Has(i)); + } + + UNIT_ASSERT(!hash.Has(elementsCount + 1)); + UNIT_ASSERT(!hash.Has(elementsCount + 10)); + UNIT_ASSERT_VALUES_EQUAL(elementsCount, hash.Size()); + UNIT_ASSERT_VALUES_EQUAL(elementsCount, hash.UniqSize()); + + ui64 sumKeys = 0; + for (auto i = hash.Iterate(); i.Ok(); ++i) { + sumKeys += *i; + } + UNIT_ASSERT_VALUES_EQUAL(sumKeys, sumKeysTarget); + } + + Y_UNIT_TEST(TestHashByte) { + TestHashImpl<ui8>(); + } + + Y_UNIT_TEST(TestMultiHashByte) { + TestMultiHashImpl<ui8>(); + } + + Y_UNIT_TEST(TestSetByte) { + TestSetImpl<ui8>(); + } + + Y_UNIT_TEST(TestHashUi16) { + TestHashImpl<ui16>(); + } + + Y_UNIT_TEST(TestMultiHashUi16) { + TestMultiHashImpl<ui16>(); + } + + Y_UNIT_TEST(TestSetUi16) { + TestSetImpl<ui16>(); + } + + Y_UNIT_TEST(TestHashUi64) { + TestHashImpl<ui64>(); + } + + Y_UNIT_TEST(TestMultiHashUi64) { + TestMultiHashImpl<ui64>(); + } + + Y_UNIT_TEST(TestSetUi64) { + TestSetImpl<ui64>(); + } + + struct TStressHash { + TStressHash(size_t param) + : Param(param) + { + } + + template <typename T> + size_t operator() (T num) const { + return num % Param; + } + const size_t Param; + }; + + Y_UNIT_TEST(TestStressSmallLists) { + TAlignedPagePool pagePool(__LOCATION__); + for (size_t listSize: xrange<size_t>(2, 17, 1)) { + const size_t backets = TListPoolBase::GetSmallPageCapacity<ui64>(listSize); + const size_t elementsCount = backets * listSize; + + for (size_t count: xrange<size_t>(1, elementsCount + 1, elementsCount / 16)) { + TCompactHashSet<ui64, TStressHash> hash(pagePool, elementsCount, TStressHash(backets)); + for (auto i: xrange(count)) { + hash.Insert(i); + } + UNIT_ASSERT_VALUES_EQUAL(count, hash.Size()); + UNIT_ASSERT_VALUES_EQUAL(count, hash.UniqSize()); + //hash.PrintStat(Cerr); + } + } + } +} |