#include "compact_hash.h"
#include <library/cpp/testing/unittest/registar.h>
#include <util/generic/bitmap.h>
#include <util/generic/ylimits.h>
#include <util/generic/hash_set.h>
#include <util/generic/maybe.h>
#include <util/generic/xrange.h>
#include <util/random/shuffle.h>
using namespace NKikimr;
using namespace NKikimr::NCHash;
Y_UNIT_TEST_SUITE(TCompactHashTest) {
template <typename TItem>
void TestListPoolPagesImpl(size_t listSize, ui16 countOfLists, size_t expectedListCapacity, ui32 expectedMark) {
using TPool = TListPool<TItem>;
TAlignedPagePool pagePool(__LOCATION__);
TPool pool(pagePool);
UNIT_ASSERT(countOfLists > 1);
THashSet<TItem*> lists;
void* pageAddr = nullptr;
for (size_t i = 0; i < countOfLists / 2; ++i) {
TItem* l = pool.template GetList<TItem>(listSize);
UNIT_ASSERT_VALUES_EQUAL(expectedMark, TPool::GetMark(l));
UNIT_ASSERT_VALUES_EQUAL(listSize, TListPoolBase::GetPartListSize(l));
UNIT_ASSERT_VALUES_EQUAL(expectedListCapacity, TListPoolBase::GetListCapacity(l));
if (0 == i) {
pageAddr = TAlignedPagePool::GetPageStart(l);
} else {
// All lists are from the same page
UNIT_ASSERT_VALUES_EQUAL(pageAddr, TAlignedPagePool::GetPageStart(l));
}
UNIT_ASSERT(lists.insert(l).second);
}
// Return all lists except one
while (lists.size() > 1) {
auto it = lists.begin();
pool.ReturnList(*it);
lists.erase(it);
}
for (size_t i = 1; i < countOfLists; ++i) {
TItem* l = pool.template GetList<TItem>(listSize);
UNIT_ASSERT_VALUES_EQUAL(expectedMark, TPool::GetMark(l));
UNIT_ASSERT_VALUES_EQUAL(listSize, TListPoolBase::GetPartListSize(l));
UNIT_ASSERT_VALUES_EQUAL(expectedListCapacity, TListPoolBase::GetListCapacity(l));
UNIT_ASSERT_VALUES_EQUAL(pageAddr, TAlignedPagePool::GetPageStart(l));
UNIT_ASSERT(lists.insert(l).second);
}
for (auto l: lists) {
pool.ReturnList(l);
}
THashSet<TItem*> lists2;
for (size_t i = 0; i < countOfLists; ++i) {
TItem* l = pool.template GetList<TItem>(listSize);
// All lists are from the same page
UNIT_ASSERT_VALUES_EQUAL(pageAddr, TAlignedPagePool::GetPageStart(l));
UNIT_ASSERT(1 == lists.erase(l));
UNIT_ASSERT(lists2.insert(l).second);
}
TItem* l = pool.template GetList<TItem>(listSize); // New page
UNIT_ASSERT_VALUES_UNEQUAL(pageAddr, TAlignedPagePool::GetPageStart(l));
UNIT_ASSERT_VALUES_EQUAL(listSize, TListPoolBase::GetPartListSize(l));
UNIT_ASSERT_VALUES_EQUAL(expectedListCapacity, TListPoolBase::GetListCapacity(l));
pool.ReturnList(l);
for (auto l: lists2) {
pool.ReturnList(l);
}
}
// Exercises IncrementList on a list that starts at GetMaxListSize<TItem>() elements:
// the first increment must yield a LARGE_MARK list of listSize + 1 elements; after the
// part size is forced to the list capacity, the next increment must leave a part of
// size 1 while the full list size becomes capacity + 1.
template <typename TItem>
void TestListPoolLargeImpl() {
    using TPool = TListPool<TItem>;

    TAlignedPagePool pagePool(__LOCATION__);
    TPool pool(pagePool);

    const size_t maxListSize = TListPoolBase::GetMaxListSize<TItem>();
    TItem* list = pool.template GetList<TItem>(maxListSize);

    // First increment goes past the maximum list size.
    pool.template IncrementList<TItem>(list);
    UNIT_ASSERT_VALUES_EQUAL((ui32)TListPoolBase::LARGE_MARK, TListPoolBase::GetMark(list));
    UNIT_ASSERT_VALUES_EQUAL(maxListSize + 1, TListPoolBase::GetPartListSize(list));

    // Pretend the current part is full, then increment once more.
    TListPoolBase::SetPartListSize(list, TListPoolBase::GetListCapacity(list));
    pool.template IncrementList<TItem>(list);
    UNIT_ASSERT_VALUES_EQUAL(1, TListPoolBase::GetPartListSize(list));
    UNIT_ASSERT_VALUES_EQUAL(1 + TListPoolBase::GetListCapacity(list), TListPoolBase::GetFullListSize(list));

    pool.ReturnList(list);
}
// Page reuse scenario for ui8 items: lists of 2 elements, expected capacity 2, SMALL_MARK.
Y_UNIT_TEST(TestListPoolSmallPagesByte) {
    const size_t listSize = 2;
    const ui16 listsCount = TListPoolBase::GetSmallPageCapacity<ui8>(listSize);
    TestListPoolPagesImpl<ui8>(listSize, listsCount, listSize, TListPoolBase::SMALL_MARK);
}
// Page reuse scenario for ui8 items with the first list size above MAX_SMALL_LIST_SIZE;
// the expected capacity is FastClp2(listSize) and the expected mark is MEDIUM_MARK.
Y_UNIT_TEST(TestListPoolMediumPagesByte) {
    const size_t listSize = TListPoolBase::MAX_SMALL_LIST_SIZE + 1;
    const ui16 listsCount = TListPoolBase::GetMediumPageCapacity<ui8>(listSize);
    const size_t expectedCapacity = FastClp2(listSize);
    TestListPoolPagesImpl<ui8>(listSize, listsCount, expectedCapacity, TListPoolBase::MEDIUM_MARK);
}
// Runs the large-list scenario (TestListPoolLargeImpl) with ui8 items.
Y_UNIT_TEST(TestListPoolLargPagesByte) {
TestListPoolLargeImpl<ui8>();
}
// Page reuse scenario for ui64 items: lists of 2 elements, expected capacity 2, SMALL_MARK.
Y_UNIT_TEST(TestListPoolSmallPagesUi64) {
    const size_t listSize = 2;
    const ui16 listsCount = TListPoolBase::GetSmallPageCapacity<ui64>(listSize);
    TestListPoolPagesImpl<ui64>(listSize, listsCount, listSize, TListPoolBase::SMALL_MARK);
}
// Page reuse scenario for ui64 items with the first list size above MAX_SMALL_LIST_SIZE;
// the expected capacity is FastClp2(listSize) and the expected mark is MEDIUM_MARK.
Y_UNIT_TEST(TestListPoolMediumPagesUi64) {
    const size_t listSize = TListPoolBase::MAX_SMALL_LIST_SIZE + 1;
    const ui16 listsCount = TListPoolBase::GetMediumPageCapacity<ui64>(listSize);
    const size_t expectedCapacity = FastClp2(listSize);
    TestListPoolPagesImpl<ui64>(listSize, listsCount, expectedCapacity, TListPoolBase::MEDIUM_MARK);
}
// Runs the large-list scenario (TestListPoolLargeImpl) with ui64 items.
Y_UNIT_TEST(TestListPoolLargPagesUi64) {
TestListPoolLargeImpl<ui64>();
}
// Item type holding an array of 256 ui8 values; used as the element type
// by the *Obj list pool tests in this suite.
struct TItem {
ui8 A[256];
};
// Page reuse scenario for TItem objects: lists of 2 elements, expected capacity 2, SMALL_MARK.
Y_UNIT_TEST(TestListPoolSmallPagesObj) {
    const size_t listSize = 2;
    const ui16 listsCount = TListPoolBase::GetSmallPageCapacity<TItem>(listSize);
    TestListPoolPagesImpl<TItem>(listSize, listsCount, listSize, TListPoolBase::SMALL_MARK);
}
// Page reuse scenario for TItem objects with the first list size above MAX_SMALL_LIST_SIZE;
// the expected capacity is FastClp2(listSize) and the expected mark is MEDIUM_MARK.
Y_UNIT_TEST(TestListPoolMediumPagesObj) {
    const size_t listSize = TListPoolBase::MAX_SMALL_LIST_SIZE + 1;
    const ui16 listsCount = TListPoolBase::GetMediumPageCapacity<TItem>(listSize);
    const size_t expectedCapacity = FastClp2(listSize);
    TestListPoolPagesImpl<TItem>(listSize, listsCount, expectedCapacity, TListPoolBase::MEDIUM_MARK);
}
// Runs the large-list scenario (TestListPoolLargeImpl) with TItem objects.
Y_UNIT_TEST(TestListPoolLargPagesObj) {
TestListPoolLargeImpl<TItem>();
}
// Hash functor used by the hash/set tests: the hash of a number is its remainder modulo 13,
// so at most 13 distinct hash values are produced.
struct TItemHash {
    template <typename T>
    size_t operator() (T num) const {
        // Kept as an int constant so the arithmetic conversions match a plain literal.
        constexpr int modulus = 13;
        return num % modulus;
    }
};
template <typename TItem>
void TestHashImpl() {
const ui32 elementsCount = 32;
const ui64 sumKeysTarget = elementsCount * (elementsCount - 1) / 2;
const ui32 addition = 20;
const ui64 sumValuesTarget = sumKeysTarget + addition * elementsCount;
TAlignedPagePool pagePool(__LOCATION__);
TCompactHash<TItem, TItem, TItemHash> hash(pagePool);
TVector<TItem> elements(elementsCount);
std::iota(elements.begin(), elements.end(), 0);
Shuffle(elements.begin(), elements.end());
for (TItem i: elements) {
hash.Insert(i, i + addition);
}
{
decltype(hash) hash2(std::move(hash));
decltype(hash) hash3(pagePool);
hash3.Swap(hash2);
hash = hash3;
}
for (TItem i: elements) {
UNIT_ASSERT(hash.Has(i));
UNIT_ASSERT(hash.Find(i).Ok());
UNIT_ASSERT_VALUES_EQUAL(i + addition, hash.Find(i).Get().second);
}
UNIT_ASSERT(!hash.Has(elementsCount + 1));
UNIT_ASSERT(!hash.Has(elementsCount + 10));
UNIT_ASSERT_VALUES_EQUAL(elementsCount, hash.Size());
UNIT_ASSERT_VALUES_EQUAL(elementsCount, hash.UniqSize());
ui64 sumKeys = 0;
ui64 sumValues = 0;
for (auto it = hash.Iterate(); it.Ok(); ++it) {
UNIT_ASSERT_VALUES_EQUAL(it.Get().first + addition, it.Get().second);
sumKeys += it.Get().first;
sumValues += it.Get().second;
}
UNIT_ASSERT_VALUES_EQUAL(sumKeys, sumKeysTarget);
UNIT_ASSERT_VALUES_EQUAL(sumValues, sumValuesTarget);
}
template <typename TItem>
void TestMultiHashImpl() {
const ui32 keysCount = 10;
const ui64 elementsCount = keysCount * (keysCount + 1) / 2;
TAlignedPagePool pagePool(__LOCATION__);
TCompactMultiHash<TItem, TItem, TItemHash> hash(pagePool);
TVector<TItem> keys(keysCount);
std::iota(keys.begin(), keys.end(), 0);
Shuffle(keys.begin(), keys.end());
ui64 sumKeysTarget = 0;
ui64 sumValuesTarget = 0;
for (TItem k: keys) {
sumKeysTarget += k;
for (TItem i = 0; i < k + 1; ++i) {
hash.Insert(k, i);
sumValuesTarget += i;
}
}
{
decltype(hash) hash2(std::move(hash));
decltype(hash) hash3(pagePool);
hash3.Swap(hash2);
hash = hash3;
}
for (TItem k: keys) {
UNIT_ASSERT(hash.Has(k));
UNIT_ASSERT_VALUES_EQUAL(k + 1, hash.Count(k));
auto it = hash.Find(k);
UNIT_ASSERT(it.Ok());
TDynBitMap set;
for (; it.Ok(); ++it) {
UNIT_ASSERT_VALUES_EQUAL(k, it.GetKey());
UNIT_ASSERT(!set.Test(it.GetValue()));
set.Set(it.GetValue());
}
UNIT_ASSERT_VALUES_EQUAL(set.Count(), k + 1);
}
UNIT_ASSERT(!hash.Has(keysCount + 1));
UNIT_ASSERT(!hash.Has(keysCount + 10));
UNIT_ASSERT(!hash.Find(keysCount + 1).Ok());
UNIT_ASSERT_VALUES_EQUAL(elementsCount, hash.Size());
UNIT_ASSERT_VALUES_EQUAL(keysCount, hash.UniqSize());
ui64 sumKeys = 0;
ui64 sumValues = 0;
TMaybe<TItem> prevKey;
for (auto it = hash.Iterate(); it.Ok(); ++it) {
const auto key = it.GetKey();
if (prevKey != key) {
sumKeys += key;
for (auto valIt = it.MakeCurrentKeyIter(); valIt.Ok(); ++valIt) {
UNIT_ASSERT_VALUES_EQUAL(key, valIt.GetKey());
sumValues += valIt.GetValue();
}
prevKey = key;
}
}
UNIT_ASSERT_VALUES_EQUAL(sumKeys, sumKeysTarget);
UNIT_ASSERT_VALUES_EQUAL(sumValues, sumValuesTarget);
// Test large lists
TItem val = 0;
for (size_t i = 0; i < TListPoolBase::GetLargePageCapacity<TItem>(); ++i) {
hash.Insert(keysCount, val++);
}
TItem check = 0;
for (auto i = hash.Find(keysCount); i.Ok(); ++i) {
UNIT_ASSERT_VALUES_EQUAL(keysCount, i.GetKey());
UNIT_ASSERT_VALUES_EQUAL(check++, i.GetValue());
}
UNIT_ASSERT_VALUES_EQUAL(check, val);
UNIT_ASSERT_VALUES_EQUAL(TListPoolBase::GetLargePageCapacity<TItem>(), hash.Count(keysCount));
for (size_t i = 0; i < TListPoolBase::GetLargePageCapacity<TItem>() + 1; ++i) {
hash.Insert(keysCount, val++);
}
{
decltype(hash) hash2(std::move(hash));
decltype(hash) hash3(pagePool);
hash3.Swap(hash2);
hash = hash3;
}
check = 0;
for (auto i = hash.Find(keysCount); i.Ok(); ++i) {
UNIT_ASSERT_VALUES_EQUAL(keysCount, i.GetKey());
UNIT_ASSERT_VALUES_EQUAL(check++, i.GetValue());
}
UNIT_ASSERT_VALUES_EQUAL(check, val);
UNIT_ASSERT_VALUES_EQUAL(2 * TListPoolBase::GetLargePageCapacity<TItem>() + 1, hash.Count(keysCount));
}
template <typename TItem>
void TestSetImpl() {
const ui32 elementsCount = 32;
const ui64 sumKeysTarget = elementsCount * (elementsCount - 1) / 2;
TAlignedPagePool pagePool(__LOCATION__);
TCompactHashSet<TItem, TItemHash> hash(pagePool);
TVector<TItem> elements(elementsCount);
std::iota(elements.begin(), elements.end(), 0);
Shuffle(elements.begin(), elements.end());
for (TItem i: elements) {
hash.Insert(i);
}
{
decltype(hash) hash2(std::move(hash));
decltype(hash) hash3(pagePool);
hash3.Swap(hash2);
hash = hash3;
}
for (TItem i: elements) {
UNIT_ASSERT(hash.Has(i));
}
UNIT_ASSERT(!hash.Has(elementsCount + 1));
UNIT_ASSERT(!hash.Has(elementsCount + 10));
UNIT_ASSERT_VALUES_EQUAL(elementsCount, hash.Size());
UNIT_ASSERT_VALUES_EQUAL(elementsCount, hash.UniqSize());
ui64 sumKeys = 0;
for (auto i = hash.Iterate(); i.Ok(); ++i) {
sumKeys += *i;
}
UNIT_ASSERT_VALUES_EQUAL(sumKeys, sumKeysTarget);
}
// Runs the TCompactHash scenario (TestHashImpl) with ui8 keys and values.
Y_UNIT_TEST(TestHashByte) {
TestHashImpl<ui8>();
}
// Runs the TCompactMultiHash scenario (TestMultiHashImpl) with ui8 keys and values.
Y_UNIT_TEST(TestMultiHashByte) {
TestMultiHashImpl<ui8>();
}
// Runs the TCompactHashSet scenario (TestSetImpl) with ui8 keys.
Y_UNIT_TEST(TestSetByte) {
TestSetImpl<ui8>();
}
// Runs the TCompactHash scenario (TestHashImpl) with ui16 keys and values.
Y_UNIT_TEST(TestHashUi16) {
TestHashImpl<ui16>();
}
// Runs the TCompactMultiHash scenario (TestMultiHashImpl) with ui16 keys and values.
Y_UNIT_TEST(TestMultiHashUi16) {
TestMultiHashImpl<ui16>();
}
// Runs the TCompactHashSet scenario (TestSetImpl) with ui16 keys.
Y_UNIT_TEST(TestSetUi16) {
TestSetImpl<ui16>();
}
// Runs the TCompactHash scenario (TestHashImpl) with ui64 keys and values.
Y_UNIT_TEST(TestHashUi64) {
TestHashImpl<ui64>();
}
// Runs the TCompactMultiHash scenario (TestMultiHashImpl) with ui64 keys and values.
Y_UNIT_TEST(TestMultiHashUi64) {
TestMultiHashImpl<ui64>();
}
// Runs the TCompactHashSet scenario (TestSetImpl) with ui64 keys.
Y_UNIT_TEST(TestSetUi64) {
TestSetImpl<ui64>();
}
// Hash functor for the stress test: the hash of a number is its remainder modulo Param.
//
// param is used as a divisor in operator(), so it must be non-zero.
// The constructor is explicit so that a plain size_t is never silently converted
// into a hasher; construct it as TStressHash(n).
struct TStressHash {
    explicit TStressHash(size_t param)
        : Param(param)
    {
    }

    // Returns num % Param.
    template <typename T>
    size_t operator() (T num) const {
        return num % Param;
    }

    // Divisor fixed at construction time.
    const size_t Param;
};
// Stress scenario for TCompactHashSet<ui64, TStressHash>: for every list size in [2, 17)
// the hasher uses GetSmallPageCapacity<ui64>(listSize) as its modulus, and sets are filled
// with 0..count-1 for counts stepping from 1 up to buckets * listSize in 1/16 increments.
// Each set must report Size() == UniqSize() == count.
Y_UNIT_TEST(TestStressSmallLists) {
    TAlignedPagePool pagePool(__LOCATION__);

    for (size_t listSize : xrange<size_t>(2, 17, 1)) {
        const size_t buckets = TListPoolBase::GetSmallPageCapacity<ui64>(listSize);
        const size_t maxElements = buckets * listSize;
        const size_t countStep = maxElements / 16;

        for (size_t count : xrange<size_t>(1, maxElements + 1, countStep)) {
            TCompactHashSet<ui64, TStressHash> hash(pagePool, maxElements, TStressHash(buckets));
            for (auto value : xrange(count)) {
                hash.Insert(value);
            }
            UNIT_ASSERT_VALUES_EQUAL(count, hash.Size());
            UNIT_ASSERT_VALUES_EQUAL(count, hash.UniqSize());
            //hash.PrintStat(Cerr);
        }
    }
}
}