| author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
| --- | --- | --- |
| committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
| commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
| tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/containers | |
| download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz | |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/containers')
170 files changed, 18685 insertions, 0 deletions
diff --git a/library/cpp/containers/2d_array/2d_array.cpp b/library/cpp/containers/2d_array/2d_array.cpp new file mode 100644 index 00000000000..03115c7e2f9 --- /dev/null +++ b/library/cpp/containers/2d_array/2d_array.cpp @@ -0,0 +1 @@ +#include "2d_array.h" diff --git a/library/cpp/containers/2d_array/2d_array.h b/library/cpp/containers/2d_array/2d_array.h new file mode 100644 index 00000000000..9e246506370 --- /dev/null +++ b/library/cpp/containers/2d_array/2d_array.h @@ -0,0 +1,125 @@ +#pragma once + +#include <util/system/yassert.h> +#include <util/generic/algorithm.h> + +#ifdef _DEBUG +template <class T> +struct TBoundCheck { + T* Data; + size_t Size; + TBoundCheck(T* d, size_t s) { + Data = d; + Size = s; + } + T& operator[](size_t i) const { + Y_ASSERT(i >= 0 && i < Size); + return Data[i]; + } +}; +#endif + +template <class T> +class TArray2D { +private: + typedef T* PT; + T* Data; + T** PData; + size_t XSize; + size_t YSize; + +private: + void Copy(const TArray2D& a) { + XSize = a.XSize; + YSize = a.YSize; + Create(); + for (size_t i = 0; i < XSize * YSize; i++) + Data[i] = a.Data[i]; + } + void Destroy() { + delete[] Data; + delete[] PData; + } + void Create() { + Data = new T[XSize * YSize]; + PData = new PT[YSize]; + for (size_t i = 0; i < YSize; i++) + PData[i] = Data + i * XSize; + } + +public: + TArray2D(size_t xsize = 1, size_t ysize = 1) { + XSize = xsize; + YSize = ysize; + Create(); + } + TArray2D(const TArray2D& a) { + Copy(a); + } + TArray2D& operator=(const TArray2D& a) { + Destroy(); + Copy(a); + return *this; + } + ~TArray2D() { + Destroy(); + } + void SetSizes(size_t xsize, size_t ysize) { + if (XSize == xsize && YSize == ysize) + return; + Destroy(); + XSize = xsize; + YSize = ysize; + Create(); + } + void Clear() { + SetSizes(1, 1); + } +#ifdef _DEBUG + TBoundCheck<T> operator[](size_t i) const { + Y_ASSERT(i < YSize); + return TBoundCheck<T>(PData[i], XSize); + } +#else + T* operator[](size_t i) const { + Y_ASSERT(i < YSize); + return PData[i]; + } +#endif + size_t GetXSize() const { + return XSize; + } + size_t GetYSize() const { + return YSize; + } + void FillZero() { + memset(Data, 0, sizeof(T) * XSize * YSize); + } + void FillEvery(const T& a) { + for (size_t i = 0; i < XSize * YSize; i++) + Data[i] = a; + } + void Swap(TArray2D& a) { + std::swap(Data, a.Data); + std::swap(PData, a.PData); + std::swap(XSize, a.XSize); + std::swap(YSize, a.YSize); + } +}; + +template <class T> +inline bool operator==(const TArray2D<T>& a, const TArray2D<T>& b) { + if (a.GetXSize() != b.GetXSize() || a.GetYSize() != b.GetYSize()) + return false; + for (size_t y = 0; y < a.GetYSize(); ++y) { + for (size_t x = 0; x < a.GetXSize(); ++x) + if (a[y][x] != b[y][x]) + return false; + } + return true; +} + +template <class T> +inline bool operator!=(const TArray2D<T>& a, const TArray2D<T>& b) { + return !(a == b); +} diff --git a/library/cpp/containers/2d_array/ya.make b/library/cpp/containers/2d_array/ya.make new file mode 100644 index 00000000000..71d56b902f3 --- /dev/null +++ b/library/cpp/containers/2d_array/ya.make @@ -0,0 +1,9 @@ +LIBRARY() + +OWNER(kirillovs) + +SRCS( + 2d_array.cpp +) + +END() diff --git a/library/cpp/containers/atomizer/atomizer.cpp b/library/cpp/containers/atomizer/atomizer.cpp new file mode 100644 index 00000000000..7a5f781d992 --- /dev/null +++ b/library/cpp/containers/atomizer/atomizer.cpp @@ -0,0 +1 @@ +#include "atomizer.h" diff --git a/library/cpp/containers/atomizer/atomizer.h b/library/cpp/containers/atomizer/atomizer.h new file mode 100644 index 
00000000000..5e40f47ab93 --- /dev/null +++ b/library/cpp/containers/atomizer/atomizer.h @@ -0,0 +1,200 @@ +#pragma once + +#include <library/cpp/containers/str_map/str_map.h> + +#include <util/generic/vector.h> +#include <util/generic/utility.h> + +#include <utility> +#include <cstdio> + +template <class HashFcn = THash<const char*>, class EqualTo = TEqualTo<const char*>> +class atomizer; + +template <class T, class HashFcn = THash<const char*>, class EqualTo = TEqualTo<const char*>> +class super_atomizer; + +template <class HashFcn, class EqualTo> +class atomizer: public string_hash<ui32, HashFcn, EqualTo> { +private: + TVector<const char*> order; + +public: + using iterator = typename string_hash<ui32, HashFcn, EqualTo>::iterator; + using const_iterator = typename string_hash<ui32, HashFcn, EqualTo>::const_iterator; + using value_type = typename string_hash<ui32, HashFcn, EqualTo>::value_type; + using size_type = typename string_hash<ui32, HashFcn, EqualTo>::size_type; + using pool_size_type = typename string_hash<ui32, HashFcn, EqualTo>::pool_size_type; + + using string_hash<ui32, HashFcn, EqualTo>::pool; + using string_hash<ui32, HashFcn, EqualTo>::size; + using string_hash<ui32, HashFcn, EqualTo>::find; + using string_hash<ui32, HashFcn, EqualTo>::end; + using string_hash<ui32, HashFcn, EqualTo>::insert_copy; + using string_hash<ui32, HashFcn, EqualTo>::clear_hash; + + atomizer() { + order.reserve(HASH_SIZE_DEFAULT); + } + atomizer(size_type hash_size, pool_size_type pool_size) + : string_hash<ui32, HashFcn, EqualTo>(hash_size, pool_size) + { + order.reserve(hash_size); + } + ~atomizer() = default; + ui32 string_to_atom(const char* key) { + const char* old_begin = pool.Begin(); + const char* old_end = pool.End(); + std::pair<iterator, bool> ins = insert_copy(key, ui32(size() + 1)); + if (ins.second) { // new? + if (pool.Begin() != old_begin) // repoint? + for (TVector<const char*>::iterator ptr = order.begin(); ptr != order.end(); ++ptr) + if (old_begin <= *ptr && *ptr < old_end) // from old pool? 
+ *ptr += pool.Begin() - old_begin; + order.push_back((*ins.first).first); // copy of 'key' + } + return (ui32)(*ins.first).second; + } + + ui32 perm_string_to_atom(const char* key) { + value_type val(key, ui32(size() + 1)); + std::pair<iterator, bool> ins = this->insert(val); + if (ins.second) + order.push_back((*ins.first).first); // == copy of 'key' + return (ui32)(*ins.first).second; // == size()+1 + } + ui32 find_atom(const char* key) const { + const_iterator it = find(key); + if (it == end()) + return 0; // INVALID_ATOM + else + return (ui32)(*it).second; + } + const char* get_atom_name(ui32 atom) const { + if (atom && atom <= size()) + return order[atom - 1]; + return nullptr; + } + void clear_atomizer() { + clear_hash(); + order.clear(); + } + void SaveC2N(FILE* f) const { // we write sorted file + for (ui32 i = 0; i < order.size(); i++) + if (order[i]) + fprintf(f, "%d\t%s\n", i + 1, order[i]); + } + void LoadC2N(FILE* f) { // but can read unsorted one + long k, km = 0; + char buf[1000]; + char* s; + while (fgets(buf, 1000, f)) { + k = strtol(buf, &s, 10); + char* endl = strchr(s, '\n'); + if (endl) + *endl = 0; + if (k > 0 && k != LONG_MAX) { + km = Max(km, k); + insert_copy(++s, ui32(k)); + } + } + order.resize(km); + memset(&order[0], 0, order.size()); // if some atoms are absent + for (const_iterator I = this->begin(); I != end(); ++I) + order[(*I).second - 1] = (*I).first; + } +}; + +template <class T, class HashFcn, class EqualTo> +class super_atomizer: public string_hash<ui32, HashFcn, EqualTo> { +private: + using TOrder = TVector<std::pair<const char*, T>>; + TOrder order; + +public: + using iterator = typename string_hash<ui32, HashFcn, EqualTo>::iterator; + using const_iterator = typename string_hash<ui32, HashFcn, EqualTo>::const_iterator; + using value_type = typename string_hash<ui32, HashFcn, EqualTo>::value_type; + using size_type = typename string_hash<ui32, HashFcn, EqualTo>::size_type; + using pool_size_type = typename string_hash<ui32, HashFcn, EqualTo>::pool_size_type; + + using o_iterator = typename TOrder::iterator; + using o_const_iterator = typename TOrder::const_iterator; + using o_value_type = typename TOrder::value_type; + + using string_hash<ui32, HashFcn, EqualTo>::pool; + using string_hash<ui32, HashFcn, EqualTo>::size; + using string_hash<ui32, HashFcn, EqualTo>::find; + using string_hash<ui32, HashFcn, EqualTo>::end; + using string_hash<ui32, HashFcn, EqualTo>::insert_copy; + using string_hash<ui32, HashFcn, EqualTo>::clear_hash; + + super_atomizer() { + order.reserve(HASH_SIZE_DEFAULT); + } + super_atomizer(size_type hash_size, pool_size_type pool_size) + : string_hash<ui32, HashFcn, EqualTo>(hash_size, pool_size) + { + order.reserve(hash_size); + } + ~super_atomizer() = default; + ui32 string_to_atom(const char* key, const T* atom_data = NULL) { + const char* old_begin = pool.Begin(); + const char* old_end = pool.End(); + std::pair<iterator, bool> ins = insert_copy(key, ui32(size() + 1)); + if (ins.second) { // new? + if (pool.Begin() != old_begin) // repoint? + for (typename TOrder::iterator ptr = order.begin(); ptr != order.end(); ++ptr) + if (old_begin <= (*ptr).first && (*ptr).first < old_end) // from old pool? + (*ptr).first += pool.Begin() - old_begin; + order.push_back(std::pair<const char*, T>((*ins.first).first, atom_data ? 
*atom_data : T())); + } + return (*ins.first).second; + } + + ui32 perm_string_to_atom(const char* key, const T* atom_data = NULL) { + value_type val(key, ui32(size() + 1)); + std::pair<iterator, bool> ins = this->insert(val); + if (ins.second) + order.push_back(std::pair<const char*, T>((*ins.first).first, atom_data ? *atom_data : T())); + return (*ins.first).second; // == size()+1 + } + ui32 find_atom(const char* key) const { + const_iterator it = find(key); + if (it == end()) + return 0; // INVALID_ATOM + else + return (*it).second; + } + const char* get_atom_name(ui32 atom) const { + if (atom && atom <= size()) + return order[atom - 1].first; + return nullptr; + } + const T* get_atom_data(ui32 atom) const { + if (atom && atom <= size()) + return &order[atom - 1].second; + return NULL; + } + T* get_atom_data(ui32 atom) { + if (atom && atom <= size()) + return &order[atom - 1].second; + return NULL; + } + o_iterator o_begin() { + return order.begin(); + } + o_iterator o_end() { + return order.end(); + } + o_const_iterator o_begin() const { + return order.begin(); + } + o_const_iterator o_end() const { + return order.end(); + } + void clear_atomizer() { + clear_hash(); + order.clear(); + } +}; diff --git a/library/cpp/containers/atomizer/ya.make b/library/cpp/containers/atomizer/ya.make new file mode 100644 index 00000000000..55165a3b672 --- /dev/null +++ b/library/cpp/containers/atomizer/ya.make @@ -0,0 +1,13 @@ +LIBRARY() + +OWNER(g:util) + +PEERDIR( + library/cpp/containers/str_map +) + +SRCS( + atomizer.cpp +) + +END() diff --git a/library/cpp/containers/bitseq/bititerator.h b/library/cpp/containers/bitseq/bititerator.h new file mode 100644 index 00000000000..52dadd37982 --- /dev/null +++ b/library/cpp/containers/bitseq/bititerator.h @@ -0,0 +1,138 @@ +#pragma once + +#include "traits.h" + +#include <library/cpp/pop_count/popcount.h> + +template <typename T> +class TBitIterator { +public: + using TWord = T; + using TTraits = TBitSeqTraits<TWord>; + +public: + TBitIterator(const T* data = nullptr) + : Current(0) + , Mask(0) + , Data(data) + { + } + + /// Get the word next to the one we are currenlty iterating over. + const TWord* NextWord() const { + return Data; + } + + /// Get the next bit without moving the iterator. + bool Peek() const { + return Mask ? (Current & Mask) : (*Data & 1); + } + + /// Get the next bit and move forward. + /// TODO: Implement inversed iteration as well. + bool Next() { + if (!Mask) { + Current = *Data++; + Mask = 1; + } + const bool bit = Current & Mask; + Mask <<= 1; + return bit; + } + + /// Get the next count bits without moving the iterator. + TWord Peek(ui8 count) const { + if (!count) + return 0; + Y_VERIFY_DEBUG(count <= TTraits::NumBits); + + if (!Mask) + return *Data & TTraits::ElemMask(count); + + auto usedBits = (size_t)PopCount(Mask - 1); + TWord result = Current >> usedBits; + auto leftInCurrent = TTraits::NumBits - usedBits; + if (count <= leftInCurrent) + return result & TTraits::ElemMask(count); + + count -= leftInCurrent; + result |= (*Data & TTraits::ElemMask(count)) << leftInCurrent; + return result; + } + + /// Get the next count bits and move forward by count bits. 
+ TWord Read(ui8 count) { + if (!count) + return 0; + Y_VERIFY_DEBUG(count <= TTraits::NumBits); + + if (!Mask) { + Current = *Data++; + Mask = 1 << count; + return Current & TTraits::ElemMask(count); + } + + auto usedBits = (size_t)PopCount(Mask - 1); + TWord result = Current >> usedBits; + auto leftInCurrent = TTraits::NumBits - usedBits; + if (count < leftInCurrent) { + Mask <<= count; + return result & TTraits::ElemMask(count); + } + + count -= leftInCurrent; + if (count) { + Current = *Data++; + Mask = 1 << count; + result |= (Current & TTraits::ElemMask(count)) << leftInCurrent; + } else { + Mask = 0; + } + + return result; + } + + /// Move the iterator forward by count bits. + void Forward(int count) { + if (!count) + return; + + int leftInCurrent = (size_t)PopCount(~(Mask - 1)); + if (count < leftInCurrent) { + Mask <<= count; + return; + } + + count -= leftInCurrent; + Data += count >> TTraits::DivShift; + auto remainder = count & TTraits::ModMask; + + if (remainder) { + Current = *Data++; + Mask = 1 << remainder; + } else { + Current = 0; + Mask = 0; + } + } + + /// Skip trailing bits of the current word and move by count words forward. + void Align(int count = 0) { + Current = 0; + if (Mask) + Mask = 0; + Data += count; + } + + /// Initialize the iterator. + void Reset(const TWord* data) { + Current = 0; + Mask = 0; + Data = data; + } + +private: + TWord Current; + TWord Mask; + const TWord* Data; +}; diff --git a/library/cpp/containers/bitseq/bititerator_ut.cpp b/library/cpp/containers/bitseq/bititerator_ut.cpp new file mode 100644 index 00000000000..ed0925866f6 --- /dev/null +++ b/library/cpp/containers/bitseq/bititerator_ut.cpp @@ -0,0 +1,109 @@ +#include "bititerator.h" + +#include <library/cpp/testing/unittest/registar.h> +#include <util/generic/vector.h> + +Y_UNIT_TEST_SUITE(TBitIteratorTest) { + TVector<ui16> GenWords() { + TVector<ui16> words(1, 0); + for (ui16 word = 1; word; ++word) + words.push_back(word); + return words; + } + + template <typename TWord> + void AssertPeekRead(TBitIterator<TWord> & iter, ui8 count, TWord expected) { + auto peek = iter.Peek(count); + auto read = iter.Read(count); + UNIT_ASSERT_EQUAL(peek, read); + UNIT_ASSERT_EQUAL(peek, expected); + } + + Y_UNIT_TEST(TestNextAndPeek) { + const auto& words = GenWords(); + + TBitIterator<ui16> iter(words.data()); + ui16 word = 0; + for (int i = 0; i != (1 << 16); ++i, ++word) { + for (int bit = 0; bit != 16; ++bit) { + auto peek = iter.Peek(); + auto next = iter.Next(); + UNIT_ASSERT_EQUAL(peek, next); + UNIT_ASSERT_EQUAL(peek, (word >> bit) & 1); + } + UNIT_ASSERT_EQUAL(iter.NextWord(), words.data() + i + 1); + } + + UNIT_ASSERT_EQUAL(iter.NextWord(), words.data() + words.size()); + } + + Y_UNIT_TEST(TestAlignedReadAndPeek) { + const auto& words = GenWords(); + + TBitIterator<ui16> iter(words.data()); + ui16 word = 0; + for (int i = 0; i != (1 << 16); ++i, ++word) { + AssertPeekRead(iter, 16, word); + UNIT_ASSERT_EQUAL(iter.NextWord(), words.data() + i + 1); + } + + UNIT_ASSERT_EQUAL(iter.NextWord(), words.data() + words.size()); + } + + Y_UNIT_TEST(TestForward) { + TVector<ui32> words; + words.push_back((1 << 10) | (1 << 20) | (1 << 25)); + words.push_back(1 | (1 << 5) | (1 << 6) | (1 << 30)); + for (int i = 0; i < 3; ++i) + words.push_back(0); + words.push_back(1 << 10); + + TBitIterator<ui32> iter(words.data()); + UNIT_ASSERT(!iter.Next()); + UNIT_ASSERT(!iter.Next()); + UNIT_ASSERT(!iter.Next()); + iter.Forward(6); + UNIT_ASSERT(!iter.Next()); + UNIT_ASSERT(iter.Next()); + 
UNIT_ASSERT(!iter.Next()); + iter.Forward(8); + UNIT_ASSERT(iter.Next()); + iter.Forward(4); + UNIT_ASSERT(iter.Next()); + iter.Forward(5); + UNIT_ASSERT(!iter.Next()); + UNIT_ASSERT(iter.Next()); + iter.Forward(4); + UNIT_ASSERT(iter.Next()); + + iter.Reset(words.data()); + iter.Forward(38); + UNIT_ASSERT(iter.Next()); + UNIT_ASSERT(!iter.Next()); + UNIT_ASSERT_EQUAL(iter.NextWord(), words.data() + 2); + + iter.Forward(24 + 32 * 3 + 9); + UNIT_ASSERT(!iter.Next()); + UNIT_ASSERT(iter.Next()); + UNIT_ASSERT(!iter.Next()); + UNIT_ASSERT_EQUAL(iter.NextWord(), words.data() + 6); + } + + Y_UNIT_TEST(TestUnalignedReadAndPeek) { + TVector<ui32> words; + words.push_back((1 << 10) | (1 << 20) | (1 << 25)); + words.push_back(1 | (1 << 5) | (1 << 6) | (1 << 30)); + for (int i = 0; i < 5; ++i) + words.push_back(1 | (1 << 10)); + + TBitIterator<ui32> iter(words.data()); + AssertPeekRead(iter, 5, ui32(0)); + AssertPeekRead(iter, 7, ui32(1 << 5)); + AssertPeekRead(iter, 21, ui32((1 << 8) | (1 << 13) | (1 << 20))); + AssertPeekRead(iter, 32, (words[1] >> 1) | (1 << 31)); + iter.Forward(8); + UNIT_ASSERT(!iter.Next()); + UNIT_ASSERT(iter.Next()); + UNIT_ASSERT(!iter.Next()); + } +} diff --git a/library/cpp/containers/bitseq/bitvector.cpp b/library/cpp/containers/bitseq/bitvector.cpp new file mode 100644 index 00000000000..05cb3a881df --- /dev/null +++ b/library/cpp/containers/bitseq/bitvector.cpp @@ -0,0 +1 @@ +#include "bitvector.h" diff --git a/library/cpp/containers/bitseq/bitvector.h b/library/cpp/containers/bitseq/bitvector.h new file mode 100644 index 00000000000..3f8fd81ee57 --- /dev/null +++ b/library/cpp/containers/bitseq/bitvector.h @@ -0,0 +1,158 @@ +#pragma once + +#include "traits.h" + +#include <library/cpp/pop_count/popcount.h> + +#include <util/generic/vector.h> +#include <util/ysaveload.h> + +template <typename T> +class TReadonlyBitVector; + +template <typename T> +class TBitVector { +public: + using TWord = T; + using TTraits = TBitSeqTraits<TWord>; + +private: + friend class TReadonlyBitVector<T>; + ui64 Size_; + TVector<TWord> Data_; + +public: + TBitVector() + : Size_(0) + , Data_(0) + { + } + + TBitVector(ui64 size) + : Size_(size) + , Data_(static_cast<size_t>((Size_ + TTraits::ModMask) >> TTraits::DivShift), 0) + { + } + + virtual ~TBitVector() = default; + + void Clear() { + Size_ = 0; + Data_.clear(); + } + + void Resize(ui64 size) { + Size_ = size; + Data_.resize((Size_ + TTraits::ModMask) >> TTraits::DivShift); + } + + void Swap(TBitVector& other) { + DoSwap(Size_, other.Size_); + DoSwap(Data_, other.Data_); + } + + bool Set(ui64 pos) { + Y_ASSERT(pos < Size_); + TWord& val = Data_[pos >> TTraits::DivShift]; + if (val & TTraits::BitMask(pos & TTraits::ModMask)) + return false; + val |= TTraits::BitMask(pos & TTraits::ModMask); + return true; + } + + bool Test(ui64 pos) const { + return TTraits::Test(Data(), pos, Size_); + } + + void Reset(ui64 pos) { + Y_ASSERT(pos < Size_); + Data_[pos >> TTraits::DivShift] &= ~TTraits::BitMask(pos & TTraits::ModMask); + } + + TWord Get(ui64 pos, ui8 width, TWord mask) const { + return TTraits::Get(Data(), pos, width, mask, Size_); + } + + TWord Get(ui64 pos, ui8 width) const { + return Get(pos, width, TTraits::ElemMask(width)); + } + + void Set(ui64 pos, TWord value, ui8 width, TWord mask) { + if (!width) + return; + Y_ASSERT((pos + width) <= Size_); + size_t word = pos >> TTraits::DivShift; + TWord shift1 = pos & TTraits::ModMask; + TWord shift2 = TTraits::NumBits - shift1; + Data_[word] &= ~(mask << shift1); + Data_[word] |= (value & 
mask) << shift1; + if (shift2 < width) { + Data_[word + 1] &= ~(mask >> shift2); + Data_[word + 1] |= (value & mask) >> shift2; + } + } + + void Set(ui64 pos, TWord value, ui8 width) { + Set(pos, value, width, TTraits::ElemMask(width)); + } + + void Append(TWord value, ui8 width, TWord mask) { + if (!width) + return; + if (Data_.size() * TTraits::NumBits < Size_ + width) { + Data_.push_back(0); + } + Size_ += width; + Set(Size_ - width, value, width, mask); + } + + void Append(TWord value, ui8 width) { + Append(value, width, TTraits::ElemMask(width)); + } + + size_t Count() const { + size_t count = 0; + for (size_t i = 0; i < Data_.size(); ++i) { + count += (size_t)PopCount(Data_[i]); + } + return count; + } + + ui64 Size() const { + return Size_; + } + + size_t Words() const { + return Data_.size(); + } + + const TWord* Data() const { + return Data_.data(); + } + + void Save(IOutputStream* out) const { + ::Save(out, Size_); + ::Save(out, Data_); + } + + void Load(IInputStream* inp) { + ::Load(inp, Size_); + ::Load(inp, Data_); + } + + ui64 Space() const { + return CHAR_BIT * (sizeof(Size_) + + Data_.size() * sizeof(TWord)); + } + + void Print(IOutputStream& out, size_t truncate = 128) { + for (size_t i = 0; i < Data_.size() && i < truncate; ++i) { + for (int j = TTraits::NumBits - 1; j >= 0; --j) { + size_t pos = TTraits::NumBits * i + j; + out << (pos < Size_ && Test(pos) ? '1' : '0'); + } + out << " "; + } + out << Endl; + } +}; diff --git a/library/cpp/containers/bitseq/bitvector_ut.cpp b/library/cpp/containers/bitseq/bitvector_ut.cpp new file mode 100644 index 00000000000..6137adab1e8 --- /dev/null +++ b/library/cpp/containers/bitseq/bitvector_ut.cpp @@ -0,0 +1,86 @@ +#include "bitvector.h" +#include "readonly_bitvector.h" + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/memory/blob.h> +#include <util/stream/buffer.h> + +Y_UNIT_TEST_SUITE(TBitVectorTest) { + Y_UNIT_TEST(TestEmpty) { + TBitVector<ui64> v64; + UNIT_ASSERT_EQUAL(v64.Size(), 0); + UNIT_ASSERT_EQUAL(v64.Words(), 0); + + TBitVector<ui32> v32(0); + UNIT_ASSERT_EQUAL(v32.Size(), 0); + UNIT_ASSERT_EQUAL(v32.Words(), 0); + } + + Y_UNIT_TEST(TestOneWord) { + TBitVector<ui32> v; + v.Append(1, 1); + v.Append(0, 1); + v.Append(1, 3); + v.Append(10, 4); + v.Append(100500, 20); + + UNIT_ASSERT_EQUAL(v.Get(0, 1), 1); + UNIT_ASSERT(v.Test(0)); + UNIT_ASSERT_EQUAL(v.Get(1, 1), 0); + UNIT_ASSERT_EQUAL(v.Get(2, 3), 1); + UNIT_ASSERT_EQUAL(v.Get(5, 4), 10); + UNIT_ASSERT_EQUAL(v.Get(9, 20), 100500); + + v.Reset(0); + v.Set(9, 1234, 15); + UNIT_ASSERT_EQUAL(v.Get(0, 1), 0); + UNIT_ASSERT(!v.Test(0)); + UNIT_ASSERT_EQUAL(v.Get(9, 15), 1234); + + UNIT_ASSERT_EQUAL(v.Size(), 29); + UNIT_ASSERT_EQUAL(v.Words(), 1); + } + + Y_UNIT_TEST(TestManyWords) { + static const int BITS = 10; + TBitVector<ui64> v; + + for (int i = 0, end = (1 << BITS); i < end; ++i) + v.Append(i, BITS); + + UNIT_ASSERT_EQUAL(v.Size(), BITS * (1 << BITS)); + UNIT_ASSERT_EQUAL(v.Words(), (v.Size() + 63) / 64); + for (int i = 0, end = (1 << BITS); i < end; ++i) + UNIT_ASSERT_EQUAL(v.Get(i * BITS, BITS), (ui64)i); + } + + Y_UNIT_TEST(TestMaxWordSize) { + TBitVector<ui32> v; + for (int i = 0; i < 100; ++i) + v.Append(i, 32); + + for (int i = 0; i < 100; ++i) + UNIT_ASSERT_EQUAL(v.Get(i * 32, 32), (ui32)i); + + v.Set(10 * 32, 100500, 32); + UNIT_ASSERT_EQUAL(v.Get(10 * 32, 32), 100500); + } + + Y_UNIT_TEST(TestReadonlyVector) { + TBitVector<ui64> v(100); + for (ui64 i = 0; i < v.Size(); ++i) { + if (i % 3 == 0) { + v.Set(i); + } + } + TBufferStream bs; + 
TReadonlyBitVector<ui64>::SaveForReadonlyAccess(&bs, v); + const auto blob = TBlob::FromBuffer(bs.Buffer()); + TReadonlyBitVector<ui64> rv; + rv.LoadFromBlob(blob); + for (ui64 i = 0; i < rv.Size(); ++i) { + UNIT_ASSERT_VALUES_EQUAL(rv.Test(i), i % 3 == 0); + } + } +} diff --git a/library/cpp/containers/bitseq/readonly_bitvector.cpp b/library/cpp/containers/bitseq/readonly_bitvector.cpp new file mode 100644 index 00000000000..891aa7cde28 --- /dev/null +++ b/library/cpp/containers/bitseq/readonly_bitvector.cpp @@ -0,0 +1 @@ +#include "readonly_bitvector.h" diff --git a/library/cpp/containers/bitseq/readonly_bitvector.h b/library/cpp/containers/bitseq/readonly_bitvector.h new file mode 100644 index 00000000000..8612739c3f7 --- /dev/null +++ b/library/cpp/containers/bitseq/readonly_bitvector.h @@ -0,0 +1,76 @@ +#pragma once + +#include "bitvector.h" +#include "traits.h" + +#include <util/memory/blob.h> + +#include <cstring> + +template <typename T> +class TReadonlyBitVector { +public: + using TWord = T; + using TTraits = TBitSeqTraits<TWord>; + + TReadonlyBitVector() + : Size_() + , Data_() + { + } + + explicit TReadonlyBitVector(const TBitVector<T>& vector) + : Size_(vector.Size_) + , Data_(vector.Data_.data()) + { + } + + bool Test(ui64 pos) const { + return TTraits::Test(Data_, pos, Size_); + } + + TWord Get(ui64 pos, ui8 width, TWord mask) const { + return TTraits::Get(Data_, pos, width, mask, Size_); + } + + TWord Get(ui64 pos, ui8 width) const { + return Get(pos, width, TTraits::ElemMask(width)); + } + + ui64 Size() const { + return Size_; + } + + const T* Data() const { + return Data_; + } + + static void SaveForReadonlyAccess(IOutputStream* out, const TBitVector<T>& bv) { + ::Save(out, bv.Size_); + ::Save(out, static_cast<ui64>(bv.Data_.size())); + ::SavePodArray(out, bv.Data_.data(), bv.Data_.size()); + } + + virtual TBlob LoadFromBlob(const TBlob& blob) { + size_t read = 0; + auto cursor = [&]() { return blob.AsUnsignedCharPtr() + read; }; + auto readToPtr = [&](auto* ptr) { + memcpy(ptr, cursor(), sizeof(*ptr)); + read += sizeof(*ptr); + }; + + readToPtr(&Size_); + + ui64 wordCount{}; + readToPtr(&wordCount); + + Data_ = reinterpret_cast<const T*>(cursor()); + read += wordCount * sizeof(T); + + return blob.SubBlob(read, blob.Size()); + } + +private: + ui64 Size_; + const T* Data_; +}; diff --git a/library/cpp/containers/bitseq/traits.h b/library/cpp/containers/bitseq/traits.h new file mode 100644 index 00000000000..2330b1b4f29 --- /dev/null +++ b/library/cpp/containers/bitseq/traits.h @@ -0,0 +1,49 @@ +#pragma once + +#include <util/generic/bitops.h> +#include <util/generic/typetraits.h> +#include <util/system/yassert.h> + +template <typename TWord> +struct TBitSeqTraits { + static constexpr ui8 NumBits = CHAR_BIT * sizeof(TWord); + static constexpr TWord ModMask = static_cast<TWord>(NumBits - 1); + static constexpr TWord DivShift = MostSignificantBitCT(NumBits); + + static inline TWord ElemMask(ui8 count) { + // NOTE: Shifting by the type's length is UB, so we need this workaround. 
+ if (Y_LIKELY(count)) + return TWord(-1) >> (NumBits - count); + return 0; + } + + static inline TWord BitMask(ui8 pos) { + return TWord(1) << pos; + } + + static size_t NumOfWords(size_t bits) { + return (bits + NumBits - 1) >> DivShift; + } + + static bool Test(const TWord* data, ui64 pos, ui64 size) { + Y_ASSERT(pos < size); + return data[pos >> DivShift] & BitMask(pos & ModMask); + } + + static TWord Get(const TWord* data, ui64 pos, ui8 width, TWord mask, ui64 size) { + if (!width) + return 0; + Y_ASSERT((pos + width) <= size); + size_t word = pos >> DivShift; + TWord shift1 = pos & ModMask; + TWord shift2 = NumBits - shift1; + TWord res = data[word] >> shift1 & mask; + if (shift2 < width) { + res |= data[word + 1] << shift2 & mask; + } + return res; + } + + static_assert(std::is_unsigned<TWord>::value, "Expected std::is_unsigned<T>::value."); + static_assert((NumBits & (NumBits - 1)) == 0, "NumBits should be a power of 2."); +}; diff --git a/library/cpp/containers/bitseq/ut/ya.make b/library/cpp/containers/bitseq/ut/ya.make new file mode 100644 index 00000000000..7155e82c06e --- /dev/null +++ b/library/cpp/containers/bitseq/ut/ya.make @@ -0,0 +1,10 @@ +UNITTEST_FOR(library/cpp/containers/bitseq) + +OWNER(g:util) + +SRCS( + bititerator_ut.cpp + bitvector_ut.cpp +) + +END() diff --git a/library/cpp/containers/bitseq/ya.make b/library/cpp/containers/bitseq/ya.make new file mode 100644 index 00000000000..7090956c557 --- /dev/null +++ b/library/cpp/containers/bitseq/ya.make @@ -0,0 +1,15 @@ +LIBRARY() + +OWNER(g:util) + +PEERDIR( + util/draft + library/cpp/pop_count +) + +SRCS( + bitvector.cpp + readonly_bitvector.cpp +) + +END() diff --git a/library/cpp/containers/compact_vector/compact_vector.cpp b/library/cpp/containers/compact_vector/compact_vector.cpp new file mode 100644 index 00000000000..cca77643e94 --- /dev/null +++ b/library/cpp/containers/compact_vector/compact_vector.cpp @@ -0,0 +1 @@ +#include "compact_vector.h" diff --git a/library/cpp/containers/compact_vector/compact_vector.h b/library/cpp/containers/compact_vector/compact_vector.h new file mode 100644 index 00000000000..dbe7473f0cc --- /dev/null +++ b/library/cpp/containers/compact_vector/compact_vector.h @@ -0,0 +1,209 @@ +#pragma once + +#include <util/generic/yexception.h> +#include <util/generic/utility.h> +#include <util/memory/alloc.h> +#include <util/stream/output.h> +#include <util/system/yassert.h> + +#include <cstdlib> + +// vector that is 8 bytes when empty (TVector is 24 bytes) + +template <typename T> +class TCompactVector { +private: + typedef TCompactVector<T> TThis; + + // XXX: make header independent on T and introduce nullptr + struct THeader { + size_t Size; + size_t Capacity; + }; + + T* Ptr; + + THeader* Header() { + return ((THeader*)Ptr) - 1; + } + + const THeader* Header() const { + return ((THeader*)Ptr) - 1; + } + +public: + typedef T* TIterator; + typedef const T* TConstIterator; + + typedef TIterator iterator; + typedef TConstIterator const_iterator; + + TCompactVector() + : Ptr(nullptr) + { + } + + TCompactVector(const TThis& that) + : Ptr(nullptr) + { + Reserve(that.Size()); + for (TConstIterator i = that.Begin(); i != that.End(); ++i) { + PushBack(*i); + } + } + + ~TCompactVector() { + for (size_t i = 0; i < Size(); ++i) { + try { + (*this)[i].~T(); + } catch (...) 
{ + } + } + if (Ptr) + free(Header()); + } + + TIterator Begin() { + return Ptr; + } + + TIterator End() { + return Ptr + Size(); + } + + TConstIterator Begin() const { + return Ptr; + } + + TConstIterator End() const { + return Ptr + Size(); + } + + iterator begin() { + return Begin(); + } + + const_iterator begin() const { + return Begin(); + } + + iterator end() { + return End(); + } + + const_iterator end() const { + return End(); + } + + void Swap(TThis& that) { + DoSwap(Ptr, that.Ptr); + } + + void Reserve(size_t newCapacity) { + if (newCapacity <= Capacity()) { + } else if (Ptr == nullptr) { + void* mem = ::malloc(sizeof(THeader) + newCapacity * sizeof(T)); + if (mem == nullptr) + ythrow yexception() << "out of memory"; + Ptr = (T*)(((THeader*)mem) + 1); + Header()->Size = 0; + Header()->Capacity = newCapacity; + } else { + TThis copy; + size_t realNewCapacity = Max(Capacity() * 2, newCapacity); + copy.Reserve(realNewCapacity); + for (TConstIterator it = Begin(); it != End(); ++it) { + copy.PushBack(*it); + } + Swap(copy); + } + } + + size_t Size() const { + return Ptr ? Header()->Size : 0; + } + + size_t size() const { + return Size(); + } + + bool Empty() const { + return Size() == 0; + } + + bool empty() const { + return Empty(); + } + + size_t Capacity() const { + return Ptr ? Header()->Capacity : 0; + } + + void PushBack(const T& elem) { + Reserve(Size() + 1); + new (Ptr + Size()) T(elem); + ++(Header()->Size); + } + + T& Back() { + return *(End() - 1); + } + + const T& Back() const { + return *(End() - 1); + } + + T& back() { + return Back(); + } + + const T& back() const { + return Back(); + } + + TIterator Insert(TIterator pos, const T& elem) { + Y_ASSERT(pos >= Begin()); + Y_ASSERT(pos <= End()); + + size_t posn = pos - Begin(); + if (pos == End()) { + PushBack(elem); + } else { + Y_ASSERT(Size() > 0); + + Reserve(Size() + 1); + + PushBack(*(End() - 1)); + + for (size_t i = Size() - 2; i + 1 > posn; --i) { + (*this)[i + 1] = (*this)[i]; + } + + (*this)[posn] = elem; + } + return Begin() + posn; + } + + iterator insert(iterator pos, const T& elem) { + return Insert(pos, elem); + } + + void Clear() { + TThis clean; + Swap(clean); + } + + void clear() { + Clear(); + } + + T& operator[](size_t index) { + Y_ASSERT(index < Size()); + return Ptr[index]; + } + + const T& operator[](size_t index) const { + Y_ASSERT(index < Size()); + return Ptr[index]; + } +}; diff --git a/library/cpp/containers/compact_vector/compact_vector_ut.cpp b/library/cpp/containers/compact_vector/compact_vector_ut.cpp new file mode 100644 index 00000000000..7d413d65759 --- /dev/null +++ b/library/cpp/containers/compact_vector/compact_vector_ut.cpp @@ -0,0 +1,46 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "compact_vector.h" + +Y_UNIT_TEST_SUITE(TCompactVectorTest) { + Y_UNIT_TEST(TestSimple1) { + } + + Y_UNIT_TEST(TestSimple) { + TCompactVector<ui32> vector; + for (ui32 i = 0; i < 10000; ++i) { + vector.PushBack(i + 20); + UNIT_ASSERT_VALUES_EQUAL(i + 1, vector.Size()); + } + for (ui32 i = 0; i < 10000; ++i) { + UNIT_ASSERT_VALUES_EQUAL(i + 20, vector[i]); + } + } + + Y_UNIT_TEST(TestInsert) { + TCompactVector<ui32> vector; + + for (ui32 i = 0; i < 10; ++i) { + vector.PushBack(i + 2); + } + + vector.Insert(vector.Begin(), 99); + + UNIT_ASSERT_VALUES_EQUAL(11u, vector.Size()); + UNIT_ASSERT_VALUES_EQUAL(99u, vector[0]); + for (ui32 i = 0; i < 10; ++i) { + UNIT_ASSERT_VALUES_EQUAL(i + 2, vector[i + 1]); + } + + vector.Insert(vector.Begin() + 3, 77); + + UNIT_ASSERT_VALUES_EQUAL(12u, 
vector.Size()); + UNIT_ASSERT_VALUES_EQUAL(99u, vector[0]); + UNIT_ASSERT_VALUES_EQUAL(2u, vector[1]); + UNIT_ASSERT_VALUES_EQUAL(3u, vector[2]); + UNIT_ASSERT_VALUES_EQUAL(77u, vector[3]); + UNIT_ASSERT_VALUES_EQUAL(4u, vector[4]); + UNIT_ASSERT_VALUES_EQUAL(5u, vector[5]); + UNIT_ASSERT_VALUES_EQUAL(11u, vector[11]); + } +} diff --git a/library/cpp/containers/compact_vector/ut/ya.make b/library/cpp/containers/compact_vector/ut/ya.make new file mode 100644 index 00000000000..5e655bc619f --- /dev/null +++ b/library/cpp/containers/compact_vector/ut/ya.make @@ -0,0 +1,11 @@ +UNITTEST() + +OWNER(nga) + +SRCDIR(library/cpp/containers/compact_vector) + +SRCS( + compact_vector_ut.cpp +) + +END() diff --git a/library/cpp/containers/compact_vector/ya.make b/library/cpp/containers/compact_vector/ya.make new file mode 100644 index 00000000000..6c23e8d0c11 --- /dev/null +++ b/library/cpp/containers/compact_vector/ya.make @@ -0,0 +1,9 @@ +LIBRARY() + +OWNER(nga) + +SRCS( + compact_vector.cpp +) + +END() diff --git a/library/cpp/containers/comptrie/README.md b/library/cpp/containers/comptrie/README.md new file mode 100644 index 00000000000..43c298e2c83 --- /dev/null +++ b/library/cpp/containers/comptrie/README.md @@ -0,0 +1,232 @@
+Compact trie
+=============
+
+The comptrie library is a fast and very tightly packed
+implementation of a prefix tree (Sedgewick's T-trie, that is a ternary tree,
+see https://www.cs.princeton.edu/~rs/strings/paper.pdf,
+https://www.cs.upc.edu/~ps/downloads/tst/tst.html). It contains tools for creating, optimizing, and serializing trees, accessing them by key, and performing
+various searches. Because it is template-based and performance-oriented, a significant
+part of the library consists of inline functions, so if you don't need all the
+features of the library, consider including a more specific header file instead of the top-level
+comptrie.h file.
+
+Description of the data structure
+---------------------------------
+
+A prefix tree is an implementation of the map data structure
+for cases when keys are sequences of characters. The nodes of this tree
+contain characters and values. The key that corresponds to a
+value is a sequence of all the characters on nodes along the path from the root to the
+node with this value. It follows that a tree node can have as many children as
+there are different characters, which is quite a lot. For compact storage and
+quick access, the children are ordered as a binary search tree, which should be
+balanced if possible (characters have a defined order that must be
+preserved). Thus, our tree is ternary: each node has a reference to a subtree
+of all children and two references to subtrees of siblings. One of these subtrees contains the
+younger siblings, and the other contains the elder siblings.
+
+The library implements tree optimization by merging identical subtrees, which means
+the tree becomes a DAG (Directed Acyclic Graph, that is,
+an oriented graph without oriented cycles).
+
+The main class TCompactTrie is defined in comptrie_trie.h and is templatized:
+- The first parameter of the template is the character type. It should be an
+integer type, which means that arithmetical operations must be defined for it.
+- The second parameter of the template is the value type.
+- The third parameter is the packer class, which packs values in order to quickly and compactly
+serialize the value type to a contiguous memory buffer, deserialize it
+back, and quickly determine its size using the pointer to the beginning of this
+memory buffer. Good packers have already been written for most types, and they are available in
+library/cpp/packers. For more information, please refer to the documentation for these packers.
+
+The set.h file defines a modification for cases when keys must be stored
+without values.
+
+When a tree is built from scratch, the value corresponding to an empty key is
+assigned to a single-character key '\0'. So in a tree with the 'char' character type,
+the empty key and the '\0' key are bound together. For a subtree received from
+a call to FindTails, this restriction no longer exists.
+
+Creating trees
+--------------
+
+Building a tree from a list of key-value pairs is performed by the
+TCompactTrieBuilder class described in the comptrie_builder.h file.
+
+This class allows you to add words to a tree one at a time, merge a complete
+subtree, and also use an unfinished tree as a map.
+
+An important optimization is the prefix-grouped mode, used when you need to add keys
+in a certain order (for details, see the comments in the header file). The resulting tree is compactly packed while keys are being added, and the memory consumption is approximately the same as for
+the completed tree. For the default mode, compact packing is performed only at the
+very end, and the data consumes quite a lot of memory up until that point.
+
+Optimizing trees
+----------------
+
+After a tree is created, there are two optimizing operations that can be applied:
+ - Minimization to a DAG by merging equal subtrees.
+ - Fast memory layout.
+The functions that implement these operations are declared in the comptrie_builder.h file. The first
+optimization is implemented by the CompactTrieMinimize function, and the second is implemented by
+CompactTrieMakeFastLayout. You can perform both at once by calling the
+CompactTrieMinimizeAndMakeFastLayout function.
+
+### Minimization ###
+
+Minimization to a DAG requires quite a lot of time and a large amount of
+memory to store an array of subtrees, but it can reduce the size of the tree several
+times over (an example is a language model that has many low-frequency
+phrases with repeated last words and frequency counts). However, if you know
+in advance that there are no duplicate values in the tree, you don't need to waste time on it, since the minimization
+won't have any effect on the tree.
+
+### Fast memory layout ###
+
+The second optimization function results in fewer cache misses, but it causes the
+tree to grow in size. Our experience has shown a 5% gain
+in speed for some tries. The algorithm consumes about three times more memory than
+the amount required for the source tree. So if the machine has enough memory to
+assemble a tree, it does not necessarily mean that it has enough memory to run
+the algorithm. To learn about the theory behind this algorithm, read the comments before the declaration of the CompactTrieMinimize function.
+
+Serializing trees
+-----------------
+
+The tree resides in memory as a sequence of nodes. Links to other nodes are always
+counted relative to the position of the current node. This allows you to save a
+tree to disk as it is and then re-load it using mmap(). The TCompactTrie class has the
+TBlob constructor for reading a tree from disk. The TCompactTrieBuilder class has
+Save/SaveToFile methods for writing a built tree to a stream or a file.
+
+Accessing trees
+---------------
+
+As a rule, all methods that accept a key as input have two variants:
+- One takes the key in the format: pointer to the beginning of the key, length.
+- The other takes a high-level type like TStringBuf.
+
+You can get a value for a key in the tree: TCompactTrie::Find returns
+false if there is no key, and TCompactTrie::Get throws an exception. You can use FindPrefix methods to find the longest key prefix in a tree and get the corresponding value for it.
+You can also use a single FindPhrases request to get values for all the beginnings of
+a phrase with a given word delimiter.
+
+An important operation that distinguishes a tree from a simple map is implemented in the FindTails method,
+which allows you to obtain a subtree consisting of all possible extensions of the
+given prefix.
+
+Iterators for trees
+-------------------
+
+First of all, there is a typical map iterator over all key-value pairs called
+TConstIterator. A tree has three methods that return it: Begin, End, and
+UpperBound. The latter takes a key as input and returns an iterator to the
+smallest key that is not smaller than the input key.
+
+The rest of the iterators are not so widely used, and thus are located in
+separate files.
+
+TPrefixIterator is defined in the prefix_iterator.h file. It allows
+iteration over all the prefixes of the given key that are present in the tree.
+
+TSearchIterator is defined in the search_iterator.h file. It allows you to enter
+a key in a tree one character at a time and see where it ends up. The following character can
+be selected depending on the current result. You can also copy the iterator and
+proceed on two different paths. You can actually achieve the same result with
+repeated use of the FindTails method, but the authors of this iterator claim
+that they obtained a performance gain with it.
+
+Appendix. Memory implementation details
+---------------------------------------
+
+*If you are not going to modify the library, then you do not need to read further.*
+
+First, if the character type has a size larger than 1 byte, then all keys that use these characters are converted to byte strings in the big-endian way. This
+means that character bytes are written in a string from the most significant
+to the least significant from left to right. Thus it is reduced to the case when
+the character in use is 'char'.
+
+The tree resides in memory as a series of consecutive nodes. The nodes can have different
+sizes, so the only way to identify the boundaries of nodes is by traversing the entire
+tree.
+
+### Node structure ###
+
+The structure of a node, as can be understood from thoughtfully reading the
+LeapByte function in comptrie_impl.h, is the following:
+- The first byte is for service flags.
+- The second byte is a character (unless it is the ε-link type of node
+  described below, which has from 1 to 7 bytes of offset distance from the
+  beginning of this node to the content node, and nothing else).
+
+Thus, the size of any node is at least 2 bytes. All other elements of a node
+are optional. Next there is from 0 to 7 bytes of the packed offset from the beginning
+of this node to the beginning of the root node of a subtree with the younger
+siblings. It is followed by 0 to 7 bytes of the packed offset from the beginning of this
+node to the beginning of the root node of a subtree with the elder siblings.
+Next comes the packed value in this node. Its size is not limited, but you may
+recall that the packer allows you to quickly determine this size using a pointer
+to the beginning of the packed value. Then, if the service flags indicate
+that the tree has children, there is a root node of the subtree of children.
+
+The packed offset is restricted to 7 bytes, and this gives us a limit on the largest
+possible size of a tree. You need to study the packer code to understand
+the exact limit.
+
+All packed offsets are nonnegative, meaning that roots of subtrees with
+siblings and the node pointed to by the ε-link must be located
+strictly to the right of the current node in memory. This does not allow placement of
+finite state machines with oriented cycles in the comptrie. But it does allow you to
+efficiently pack the comptrie from right to left.
+
+### Service flags ###
+
+The byte of service flags contains (as shown by the constants at the beginning of
+the comptrie_impl.h file):
+- 1 bit of MT_NEXT, indicating whether this node has children.
+- 1 bit of MT_FINAL, indicating if there is a value in this node.
+- 3 bits of MT_SIZEMASK, indicating the size of the packed offset to a subtree
+  with elder siblings.
+- 3 bits of MT_SIZEMASK << MT_LEFTSHIFT, indicating the size of the packed
+  offset to a subtree with younger siblings.
+If one of these subtrees is not present, then the size of the corresponding
+packed offset is 0, and vice versa.
+
+### ε-links ###
+
+These nodes only occur if we optimized a tree into a DAG and got two nodes with
+merged subtrees of children. Since the offset to the subtree of children can't be
+specified and the root of this subtree should lie just after the value, we have
+to add a node of the ε-link type, which contains the offset to the root of the subtree of
+children and nothing more. This applies to all nodes that have equal subtrees of children,
+except the rightmost node. The size of this offset is set in 3 bits of MT_SIZEMASK
+flags for a node.
+
+As the implementation of the IsEpsilonLink function in
+comptrie_impl.h demonstrates, the ε-link differs from other nodes in that it does not have the MT_NEXT flag or the MT_FINAL
+flag, so it can always be
+identified by the flags. Of course, the best programming practice is to call the
+function itself instead of examining the flags.
+
+Note that the ε-link flags do not use the MT_SIZEMASK <<
+MT_LEFTSHIFT bits, which allows us to start using ε-links for some other purpose.
+
+Pattern Searcher
+================
+
+This is an implementation of the Aho-Corasick algorithm on the compact trie structure.
+In order to create a pattern searcher one must fill a TCompactPatternSearcherBuilder
+with patterns and call SaveAsPatternSearcher or SaveToFileAsPatternSearcher.
+Then TCompactPatternSearcher must be created from the builder output.
+
+### Implementation details ###
+
+The Aho-Corasick algorithm stores a suffix link in each node.
+A suffix link of a node is the offset (relative to this node) of the longest suffix
+of the string this node represents that is present in the trie.
+The current implementation also stores a shortcut link to the longest suffix
+for which the corresponding node in the trie is a final node.
+These two links are stored as an NCompactTrie::TSuffixLink structure of two 64-bit
+integers.
+In a trie layout these links are stored for each node right after the two bytes
+containing service flags and a symbol.
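The build, serialize, and query flow described in the README above can be sketched as follows. This is a minimal illustrative example and not part of the commit: the key strings are made up, and the calls used (TCompactTrieBuilder::Add and Save, CompactTrieMinimizeAndMakeFastLayout, the (data, size) TCompactTrie constructor, and Find with a pointer/length pair) are assumed from the headers and the benchmark code added in this diff.

```cpp
#include <library/cpp/containers/comptrie/comptrie_builder.h>
#include <library/cpp/containers/comptrie/comptrie_trie.h>

#include <util/stream/buffer.h>

#include <cstring>

int main() {
    // Build the trie incrementally; CTBF_PREFIX_GROUPED could be passed instead of
    // the default CTBF_NONE when the keys are known to arrive prefix-grouped.
    TCompactTrieBuilder<char, ui32> builder;
    builder.Add("january", 1); // hypothetical keys, for illustration only
    builder.Add("janitor", 2);
    builder.Add("march", 3);

    // Minimize to a DAG, apply the fast memory layout, and serialize into a buffer;
    // builder.Save(out) alone would serialize without either optimization.
    TBufferOutput out;
    CompactTrieMinimizeAndMakeFastLayout(out, builder);

    // The reader works directly over the serialized bytes (mmap-friendly).
    TCompactTrie<char, ui32> trie(out.Buffer().Data(), out.Buffer().Size());

    // Find() reports a missing key via its return value; Get() would throw instead.
    ui32 value = 0;
    const char* key = "jan";
    if (!trie.Find(key, strlen(key), &value)) {
        // "jan" itself is not a key; FindLongestPrefix/FindTails cover prefix queries.
    }
    return trie.Find("march", 5, &value) && value == 3 ? 0 : 1;
}
```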
diff --git a/library/cpp/containers/comptrie/array_with_size.h b/library/cpp/containers/comptrie/array_with_size.h new file mode 100644 index 00000000000..36e61c74103 --- /dev/null +++ b/library/cpp/containers/comptrie/array_with_size.h @@ -0,0 +1,67 @@ +#pragma once + +#include <util/generic/ptr.h> +#include <util/generic/noncopyable.h> +#include <util/generic/utility.h> +#include <util/system/sys_alloc.h> + +template <typename T> +class TArrayWithSizeHolder : TNonCopyable { + typedef TArrayWithSizeHolder<T> TThis; + + T* Data; + +public: + TArrayWithSizeHolder() + : Data(nullptr) + { + } + + ~TArrayWithSizeHolder() { + if (!Data) + return; + for (size_t i = 0; i < Size(); ++i) { + try { + Data[i].~T(); + } catch (...) { + } + } + y_deallocate(((size_t*)Data) - 1); + } + + void Swap(TThis& copy) { + DoSwap(Data, copy.Data); + } + + void Resize(size_t newSize) { + if (newSize == Size()) + return; + TThis copy; + copy.Data = (T*)(((size_t*)y_allocate(sizeof(size_t) + sizeof(T) * newSize)) + 1); + // does not handle constructor exceptions properly + for (size_t i = 0; i < Min(Size(), newSize); ++i) { + new (copy.Data + i) T(Data[i]); + } + for (size_t i = Min(Size(), newSize); i < newSize; ++i) { + new (copy.Data + i) T; + } + ((size_t*)copy.Data)[-1] = newSize; + Swap(copy); + } + + size_t Size() const { + return Data ? ((size_t*)Data)[-1] : 0; + } + + bool Empty() const { + return Size() == 0; + } + + T* Get() { + return Data; + } + + const T* Get() const { + return Data; + } +}; diff --git a/library/cpp/containers/comptrie/benchmark/main.cpp b/library/cpp/containers/comptrie/benchmark/main.cpp new file mode 100644 index 00000000000..6e42dad18ac --- /dev/null +++ b/library/cpp/containers/comptrie/benchmark/main.cpp @@ -0,0 +1,260 @@ +#include <library/cpp/testing/benchmark/bench.h> + +#include <library/cpp/containers/comptrie/comptrie_trie.h> +#include <library/cpp/containers/comptrie/comptrie_builder.h> +#include <library/cpp/containers/comptrie/search_iterator.h> +#include <library/cpp/containers/comptrie/pattern_searcher.h> + +#include <library/cpp/on_disk/aho_corasick/writer.h> +#include <library/cpp/on_disk/aho_corasick/reader.h> +#include <library/cpp/on_disk/aho_corasick/helpers.h> + +#include <library/cpp/containers/dense_hash/dense_hash.h> + +#include <util/stream/file.h> +#include <util/generic/algorithm.h> +#include <util/random/fast.h> +#include <util/random/shuffle.h> + +///////////////// +// COMMON DATA // +///////////////// + +const size_t MAX_PATTERN_LENGTH = 11; + +TVector<TString> letters = { + "Π°", "Π±", "Π²", "Π³", "Π΄", "Π΅", "Ρ", "ΠΆ", "Π·", "ΠΈ", "ΠΉ", + "ΠΊ", "Π»", "ΠΌ", "Π½", "ΠΎ", "ΠΏ", "Ρ", "Ρ", "Ρ", "Ρ", "Ρ", + "Ρ
", "Ρ", "Ρ", "ΠΆ", "Ρ", "Ρ", "Ρ", "Ρ", "Ρ", "Ρ", "Ρ" +}; + +TString GenerateOneString( + TFastRng<ui64>& rng, + size_t maxLength, + const TVector<TString>& sequences +) { + size_t length = rng.GenRand() % maxLength + 1; + TString result; + while (result.size() < length) { + result += sequences[rng.GenRand() % sequences.size()]; + } + return result; +} + +TVector<TString> GenerateStrings( + TFastRng<ui64>& rng, + size_t num, + size_t maxLength, + const TVector<TString>& sequences +) { + TVector<TString> strings; + while (strings.size() < num) { + strings.push_back(GenerateOneString(rng, maxLength, sequences)); + } + return strings; +} + +struct TDatasetInstance { + TDatasetInstance(const TVector<TString>& sequences) { + TFastRng<ui64> rng(0); + + TVector<TString> prefixes = GenerateStrings(rng, /*num*/10, /*maxLength*/3, sequences); + prefixes.push_back(""); + + TVector<TString> roots = GenerateStrings(rng, /*num*/1000, /*maxLength*/5, sequences); + + TVector<TString> suffixes = GenerateStrings(rng, /*num*/10, /*maxLength*/3, sequences); + suffixes.push_back(""); + + TVector<TString> dictionary; + for (const auto& root : roots) { + for (const auto& prefix : prefixes) { + for (const auto& suffix : suffixes) { + dictionary.push_back(prefix + root + suffix); + Y_ASSERT(dictionary.back().size() < MAX_PATTERN_LENGTH); + } + } + } + Shuffle(dictionary.begin(), dictionary.end()); + + Patterns.assign(dictionary.begin(), dictionary.begin() + 10'000); + + for (size_t sampleIdx = 0; sampleIdx < /*samplesNum*/1'000'000; ++sampleIdx) { + Samples.emplace_back(); + size_t wordsNum = rng.GenRand() % 10; + for (size_t wordIdx = 0; wordIdx < wordsNum; ++wordIdx) { + if (wordIdx > 0) { + Samples.back() += " "; + } + Samples.back() += dictionary[rng.GenRand() % dictionary.size()]; + } + } + }; + + TString GetSample(size_t iteration) const { + TFastRng<ui64> rng(iteration); + return Samples[rng.GenRand() % Samples.size()]; + } + + + TVector<TString> Patterns; + TVector<TString> Samples; +}; + +static const TDatasetInstance dataset(letters); + +////////////////////////// +// NEW PATTERN SEARCHER // +////////////////////////// + +struct TPatternSearcherInstance { + TPatternSearcherInstance() { + TCompactPatternSearcherBuilder<char, ui32> builder; + + for (ui32 patternId = 0; patternId < dataset.Patterns.size(); ++patternId) { + builder.Add(dataset.Patterns[patternId], patternId); + } + + TBufferOutput buffer; + builder.Save(buffer); + + Instance.Reset( + new TCompactPatternSearcher<char, ui32>( + buffer.Buffer().Data(), + buffer.Buffer().Size() + ) + ); + } + + THolder<TCompactPatternSearcher<char, ui32>> Instance; +}; + +static const TPatternSearcherInstance patternSearcherInstance; + +Y_CPU_BENCHMARK(PatternSearcher, iface) { + TVector<TVector<std::pair<ui32, ui32>>> result; + for (size_t iteration = 0; iteration < iface.Iterations(); ++iteration) { + result.emplace_back(); + TString testString = dataset.GetSample(iteration); + auto matches = patternSearcherInstance.Instance->SearchMatches(testString); + for (auto& match : matches) { + result.back().emplace_back(match.End, match.Data); + } + } +} + +////////////////////// +// OLD AHO CORASICK // +////////////////////// + +struct TAhoCorasickInstance { + TAhoCorasickInstance() { + TAhoCorasickBuilder<TString, ui32> builder; + + for (ui32 patternId = 0; patternId < dataset.Patterns.size(); ++patternId) { + builder.AddString(dataset.Patterns[patternId], patternId); + } + + TBufferOutput buffer; + builder.SaveToStream(&buffer); + + Instance.Reset(new 
TDefaultMappedAhoCorasick(TBlob::FromBuffer(buffer.Buffer()))); + }; + + THolder<TDefaultMappedAhoCorasick> Instance; +}; + +static const TAhoCorasickInstance ahoCorasickInstance; + +Y_CPU_BENCHMARK(AhoCorasick, iface) { + TVector<TDeque<std::pair<ui32, ui32>>> result; + for (size_t iteration = 0; iteration < iface.Iterations(); ++iteration) { + result.emplace_back(); + TString testString = dataset.GetSample(iteration); + auto matches = ahoCorasickInstance.Instance->AhoSearch(testString); + result.push_back(matches); + } +} + +//////////////////////////////// +// COMPTRIE + SIMPLE MATCHING // +//////////////////////////////// + +struct TCompactTrieInstance { + TCompactTrieInstance() { + TCompactTrieBuilder<char, ui32> builder; + + for (ui32 patternId = 0; patternId < dataset.Patterns.size(); ++patternId) { + builder.Add(dataset.Patterns[patternId], patternId); + } + + + TBufferOutput buffer; + CompactTrieMinimizeAndMakeFastLayout(buffer, builder); + + Instance.Reset(new TCompactTrie<char, ui32>( + buffer.Buffer().Data(), + buffer.Buffer().Size() + )); + } + + THolder<TCompactTrie<char, ui32>> Instance; +}; + +static const TCompactTrieInstance compactTrieInstance; + +Y_CPU_BENCHMARK(ComptrieSimple, iface) { + TVector<TVector<std::pair<ui32, ui32>>> result; + for (size_t iteration = 0; iteration < iface.Iterations(); ++iteration) { + result.emplace_back(); + TString testString = dataset.GetSample(iteration); + for (ui32 startPos = 0; startPos < testString.size(); ++startPos) { + TSearchIterator<TCompactTrie<char, ui32>> iter(*(compactTrieInstance.Instance)); + for (ui32 position = startPos; position < testString.size(); ++position) { + if (!iter.Advance(testString[position])) { + break; + } + ui32 answer; + if (iter.GetValue(&answer)) { + result.back().emplace_back(position, answer); + } + } + } + } +} + +//////////////// +// DENSE_HASH // +//////////////// + +struct TDenseHashInstance { + TDenseHashInstance() { + for (ui32 patternId = 0; patternId < dataset.Patterns.size(); ++patternId) { + Instance[dataset.Patterns[patternId]] = patternId; + } + } + + TDenseHash<TString, ui32> Instance; +}; + +static const TDenseHashInstance denseHashInstance; + +Y_CPU_BENCHMARK(DenseHash, iface) { + TVector<TVector<std::pair<ui32, ui32>>> result; + for (size_t iteration = 0; iteration < iface.Iterations(); ++iteration) { + result.emplace_back(); + TString testString = dataset.GetSample(iteration); + for (size_t start = 0; start < testString.size(); ++start) { + for ( + size_t length = 1; + length <= MAX_PATTERN_LENGTH && start + length <= testString.size(); + ++length + ) { + auto value = denseHashInstance.Instance.find(testString.substr(start, length)); + if (value != denseHashInstance.Instance.end()) { + result.back().emplace_back(start + length - 1, value->second); + } + } + } + } +} diff --git a/library/cpp/containers/comptrie/benchmark/ya.make b/library/cpp/containers/comptrie/benchmark/ya.make new file mode 100644 index 00000000000..16fa19530d8 --- /dev/null +++ b/library/cpp/containers/comptrie/benchmark/ya.make @@ -0,0 +1,14 @@ +Y_BENCHMARK() + +OWNER(smirnovpavel) + +SRCS( + main.cpp +) + +PEERDIR( + library/cpp/containers/comptrie + util +) + +END() diff --git a/library/cpp/containers/comptrie/chunked_helpers_trie.h b/library/cpp/containers/comptrie/chunked_helpers_trie.h new file mode 100644 index 00000000000..cfa35f5ba2a --- /dev/null +++ b/library/cpp/containers/comptrie/chunked_helpers_trie.h @@ -0,0 +1,218 @@ +#pragma once + +#include <library/cpp/on_disk/chunks/chunked_helpers.h> + 
+#include "comptrie.h" + +class TTrieSet { +private: + TCompactTrie<char> Trie; + +public: + TTrieSet(const TBlob& blob) + : Trie(blob) + { + } + + bool Has(const char* key) const { + return Trie.Find(key, strlen(key)); + } + + bool FindLongestPrefix(const char* key, size_t keylen, size_t* prefixLen) { + return Trie.FindLongestPrefix(key, keylen, prefixLen); + } +}; + +template <bool sorted = false> +class TTrieSetWriter { +private: + TCompactTrieBuilder<char> Builder; + +public: + TTrieSetWriter(bool isSorted = sorted) + : Builder(isSorted ? CTBF_PREFIX_GROUPED : CTBF_NONE) + { + } + + void Add(const char* key, size_t keylen) { + Builder.Add(key, keylen, 0); + assert(Has(((TString)key).substr(0, keylen).data())); + } + + void Add(const char* key) { + Add(key, strlen(key)); + } + + bool Has(const char* key) const { + ui64 dummy; + return Builder.Find(key, strlen(key), &dummy); + } + + void Save(IOutputStream& out) const { + Builder.Save(out); + } + + void Clear() { + Builder.Clear(); + } +}; + +template <bool isWriter, bool sorted = false> +struct TTrieSetG; + +template <bool sorted> +struct TTrieSetG<false, sorted> { + typedef TTrieSet T; +}; + +template <bool sorted> +struct TTrieSetG<true, sorted> { + typedef TTrieSetWriter<sorted> T; +}; + +template <typename T> +class TTrieMap { +private: + TCompactTrie<char> Trie; + static_assert(sizeof(T) <= sizeof(ui64), "expect sizeof(T) <= sizeof(ui64)"); + +public: + TTrieMap(const TBlob& blob) + : Trie(blob) + { + } + + bool Get(const char* key, T* value) const { + ui64 trieValue; + if (Trie.Find(key, strlen(key), &trieValue)) { + *value = ReadUnaligned<T>(&trieValue); + return true; + } else { + return false; + } + } + + T Get(const char* key, T def = T()) const { + ui64 trieValue; + if (Trie.Find(key, strlen(key), &trieValue)) { + return ReadUnaligned<T>(&trieValue); + } else { + return def; + } + } + + const TCompactTrie<char>& GetTrie() const { + return Trie; + } +}; + +template <typename T, bool sorted = false> +class TTrieMapWriter { +private: + typedef TCompactTrieBuilder<char> TBuilder; + TBuilder Builder; + static_assert(sizeof(T) <= sizeof(ui64), "expect sizeof(T) <= sizeof(ui64)"); +#ifndef NDEBUG + bool IsSorted; +#endif + +public: + TTrieMapWriter(bool isSorted = sorted) + : Builder(isSorted ? 
CTBF_PREFIX_GROUPED : CTBF_NONE) +#ifndef NDEBUG + , IsSorted(isSorted) +#endif + { + } + + void Add(const char* key, const T& value) { + ui64 intValue = 0; + memcpy(&intValue, &value, sizeof(T)); + Builder.Add(key, strlen(key), intValue); +#ifndef NDEBUG + /* + if (!IsSorted) { + T test; + assert(Get(key, &test) && value == test); + } + */ +#endif + } + + void Add(const TString& s, const T& value) { + ui64 intValue = 0; + memcpy(&intValue, &value, sizeof(T)); + Builder.Add(s.data(), s.size(), intValue); + } + + bool Get(const char* key, T* value) const { + ui64 trieValue; + if (Builder.Find(key, strlen(key), &trieValue)) { + *value = ReadUnaligned<T>(&trieValue); + return true; + } else { + return false; + } + } + + T Get(const char* key, T def = (T)0) const { + ui64 trieValue; + if (Builder.Find(key, strlen(key), &trieValue)) { + return ReadUnaligned<T>(&trieValue); + } else { + return def; + } + } + + void Save(IOutputStream& out, bool minimize = false) const { + if (minimize) { + CompactTrieMinimize<TBuilder>(out, Builder, false); + } else { + Builder.Save(out); + } + } + + void Clear() { + Builder.Clear(); + } +}; + +template <typename T> +class TTrieSortedMapWriter { +private: + typedef std::pair<TString, T> TValue; + typedef TVector<TValue> TValues; + TValues Values; + +public: + TTrieSortedMapWriter() = default; + + void Add(const char* key, const T& value) { + Values.push_back(TValue(key, value)); + } + + void Save(IOutputStream& out) { + Sort(Values.begin(), Values.end()); + TTrieMapWriter<T, true> writer; + for (typename TValues::const_iterator toValue = Values.begin(); toValue != Values.end(); ++toValue) + writer.Add(toValue->first.data(), toValue->second); + writer.Save(out); + } + + void Clear() { + Values.clear(); + } +}; + +template <typename X, bool isWriter, bool sorted = false> +struct TTrieMapG; + +template <typename X, bool sorted> +struct TTrieMapG<X, false, sorted> { + typedef TTrieMap<X> T; +}; + +template <typename X, bool sorted> +struct TTrieMapG<X, true, sorted> { + typedef TTrieMapWriter<X, sorted> T; +}; diff --git a/library/cpp/containers/comptrie/comptrie.cpp b/library/cpp/containers/comptrie/comptrie.cpp new file mode 100644 index 00000000000..4556e5b5719 --- /dev/null +++ b/library/cpp/containers/comptrie/comptrie.cpp @@ -0,0 +1,8 @@ +#include "comptrie_impl.h" +#include "comptrie.h" +#include "array_with_size.h" +#include "comptrie_trie.h" +#include "comptrie_builder.h" +#include "protopacker.h" +#include "set.h" +#include "chunked_helpers_trie.h" diff --git a/library/cpp/containers/comptrie/comptrie.h b/library/cpp/containers/comptrie/comptrie.h new file mode 100644 index 00000000000..f77024327e0 --- /dev/null +++ b/library/cpp/containers/comptrie/comptrie.h @@ -0,0 +1,4 @@ +#pragma once + +#include "comptrie_trie.h" +#include "comptrie_builder.h" diff --git a/library/cpp/containers/comptrie/comptrie_builder.cpp b/library/cpp/containers/comptrie/comptrie_builder.cpp new file mode 100644 index 00000000000..28a1e41dd24 --- /dev/null +++ b/library/cpp/containers/comptrie/comptrie_builder.cpp @@ -0,0 +1 @@ +#include "comptrie_builder.h" diff --git a/library/cpp/containers/comptrie/comptrie_builder.h b/library/cpp/containers/comptrie/comptrie_builder.h new file mode 100644 index 00000000000..cf7d2e39a34 --- /dev/null +++ b/library/cpp/containers/comptrie/comptrie_builder.h @@ -0,0 +1,159 @@ +#pragma once + +#include "comptrie_packer.h" +#include "minimize.h" +#include "key_selector.h" + +#include <util/stream/file.h> + +// 
-------------------------------------------------------------------------------------- +// Data Builder +// To build the data buffer, we first create an automaton in memory. The automaton +// is created incrementally. It actually helps a lot to have the input data prefix-grouped +// by key; otherwise, memory consumption becomes a tough issue. +// NOTE: building and serializing the automaton may be lengthy, and takes lots of memory. + +// PREFIX_GROUPED means that if we, while constructing a trie, add to the builder two keys with the same prefix, +// then all the keys that we add between these two also have the same prefix. +// Actually in this mode the builder can accept even more freely ordered input, +// but for input as above it is guaranteed to work. +enum ECompactTrieBuilderFlags { + CTBF_NONE = 0, + CTBF_PREFIX_GROUPED = 1 << 0, + CTBF_VERBOSE = 1 << 1, + CTBF_UNIQUE = 1 << 2, +}; + +using TCompactTrieBuilderFlags = ECompactTrieBuilderFlags; + +inline TCompactTrieBuilderFlags operator|(TCompactTrieBuilderFlags first, TCompactTrieBuilderFlags second) { + return static_cast<TCompactTrieBuilderFlags>(static_cast<int>(first) | second); +} + +inline TCompactTrieBuilderFlags& operator|=(TCompactTrieBuilderFlags& first, TCompactTrieBuilderFlags second) { + return first = first | second; +} + +template <typename T> +class TArrayWithSizeHolder; + +template <class T = char, class D = ui64, class S = TCompactTriePacker<D>> +class TCompactTrieBuilder { +public: + typedef T TSymbol; + typedef D TData; + typedef S TPacker; + typedef typename TCompactTrieKeySelector<TSymbol>::TKey TKey; + typedef typename TCompactTrieKeySelector<TSymbol>::TKeyBuf TKeyBuf; + + explicit TCompactTrieBuilder(TCompactTrieBuilderFlags flags = CTBF_NONE, TPacker packer = TPacker(), IAllocator* alloc = TDefaultAllocator::Instance()); + + // All Add.. methods return true if it was a new key, false if the key already existed. + + bool Add(const TSymbol* key, size_t keylen, const TData& value); + bool Add(const TKeyBuf& key, const TData& value) { + return Add(key.data(), key.size(), value); + } + + // add already serialized data + bool AddPtr(const TSymbol* key, size_t keylen, const char* data); + bool AddPtr(const TKeyBuf& key, const char* data) { + return AddPtr(key.data(), key.size(), data); + } + + bool AddSubtreeInFile(const TSymbol* key, size_t keylen, const TString& filename); + bool AddSubtreeInFile(const TKeyBuf& key, const TString& filename) { + return AddSubtreeInFile(key.data(), key.size(), filename); + } + + bool AddSubtreeInBuffer(const TSymbol* key, size_t keylen, TArrayWithSizeHolder<char>&& buffer); + bool AddSubtreeInBuffer(const TKeyBuf& key, TArrayWithSizeHolder<char>&& buffer) { + return AddSubtreeInBuffer(key.data(), key.size(), std::move(buffer)); + } + + bool Find(const TSymbol* key, size_t keylen, TData* value) const; + bool Find(const TKeyBuf& key, TData* value = nullptr) const { + return Find(key.data(), key.size(), value); + } + + bool FindLongestPrefix(const TSymbol* key, size_t keylen, size_t* prefixLen, TData* value = nullptr) const; + bool FindLongestPrefix(const TKeyBuf& key, size_t* prefixLen, TData* value = nullptr) const { + return FindLongestPrefix(key.data(), key.size(), prefixLen, value); + } + + size_t Save(IOutputStream& os) const; + size_t SaveAndDestroy(IOutputStream& os); + size_t SaveToFile(const TString& fileName) const { + TFixedBufferFileOutput out(fileName); + return Save(out); + } + + void Clear(); // Returns all memory to the system and resets the builder state. 
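+
+    // Illustrative usage sketch (the keys, values and the ui32 payload type here are
+    // hypothetical): building a small trie from prefix-grouped input and serializing it.
+    //
+    //     TCompactTrieBuilder<char, ui32> builder(CTBF_PREFIX_GROUPED);
+    //     builder.Add("bar", 1);
+    //     builder.Add("foo", 2);
+    //     builder.Add("foobar", 3); // shares the "foo" prefix with the key added just before it
+    //     TBufferOutput out;
+    //     builder.Save(out);        // the serialized buffer can later be loaded by TCompactTrie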
+ + size_t GetEntryCount() const; + size_t GetNodeCount() const; + + // Exact output file size in bytes. + size_t MeasureByteSize() const { + return Impl->MeasureByteSize(); + } + +protected: + class TCompactTrieBuilderImpl; + THolder<TCompactTrieBuilderImpl> Impl; +}; + +//---------------------------------------------------------------------------------------------------------------------- +// Minimize the trie. The result is equivalent to the original +// trie, except that it takes less space (and has marginally lower +// performance, because of eventual epsilon links). +// The algorithm is as follows: starting from the largest pieces, we find +// nodes that have identical continuations (Daciuk's right language), +// and repack the trie. Repacking is done in-place, so memory is less +// of an issue; however, it may take considerable time. + +// IMPORTANT: never try to reminimize an already minimized trie or a trie with fast layout. +// Because of non-local structure and epsilon links, it won't work +// as you expect it to, and can destroy the trie in the making. +// If you want both minimization and fast layout, do the minimization first. + +template <class TPacker> +size_t CompactTrieMinimize(IOutputStream& os, const char* data, size_t datalength, bool verbose = false, const TPacker& packer = TPacker(), NCompactTrie::EMinimizeMode mode = NCompactTrie::MM_DEFAULT); + +template <class TTrieBuilder> +size_t CompactTrieMinimize(IOutputStream& os, const TTrieBuilder& builder, bool verbose = false); + +//---------------------------------------------------------------------------------------------------------------- +// Lay the trie in memory in such a way that there are less cache misses when jumping from root to leaf. +// The trie becomes about 2% larger, but the access became about 25% faster in our experiments. +// Can be called on minimized and non-minimized tries, in the first case in requires half a trie more memory. +// Calling it the second time on the same trie does nothing. +// +// The algorithm is based on van Emde Boas layout as described in the yandex data school lectures on external memory algoritms +// by Maxim Babenko and Ivan Puzyrevsky. The difference is that when we cut the tree into levels +// two nodes connected by a forward link are put into the same level (because they usually lie near each other in the original tree). +// The original paper (describing the layout in Section 2.1) is: +// Michael A. Bender, Erik D. Demaine, Martin Farach-Colton. Cache-Oblivious B-Trees +// SIAM Journal on Computing, volume 35, number 2, 2005, pages 341-358. +// Available on the web: http://erikdemaine.org/papers/CacheObliviousBTrees_SICOMP/ +// Or: Michael A. Bender, Erik D. Demaine, and Martin Farach-Colton. Cache-Oblivious B-Trees +// Proceedings of the 41st Annual Symposium +// on Foundations of Computer Science (FOCS 2000), Redondo Beach, California, November 12-14, 2000, pages 399-409. +// Available on the web: http://erikdemaine.org/papers/FOCS2000b/ +// (there is not much difference between these papers, actually). 
+// +template <class TPacker> +size_t CompactTrieMakeFastLayout(IOutputStream& os, const char* data, size_t datalength, bool verbose = false, const TPacker& packer = TPacker()); + +template <class TTrieBuilder> +size_t CompactTrieMakeFastLayout(IOutputStream& os, const TTrieBuilder& builder, bool verbose = false); + +// Composition of minimization and fast layout +template <class TPacker> +size_t CompactTrieMinimizeAndMakeFastLayout(IOutputStream& os, const char* data, size_t datalength, bool verbose = false, const TPacker& packer = TPacker()); + +template <class TTrieBuilder> +size_t CompactTrieMinimizeAndMakeFastLayout(IOutputStream& os, const TTrieBuilder& builder, bool verbose = false); + +// Implementation details moved here. +#include "comptrie_builder.inl" diff --git a/library/cpp/containers/comptrie/comptrie_builder.inl b/library/cpp/containers/comptrie/comptrie_builder.inl new file mode 100644 index 00000000000..f273fa65710 --- /dev/null +++ b/library/cpp/containers/comptrie/comptrie_builder.inl @@ -0,0 +1,1121 @@ +#pragma once + +#include "comptrie_impl.h" +#include "comptrie_trie.h" +#include "make_fast_layout.h" +#include "array_with_size.h" + +#include <library/cpp/containers/compact_vector/compact_vector.h> + +#include <util/memory/alloc.h> +#include <util/memory/blob.h> +#include <util/memory/pool.h> +#include <util/memory/tempbuf.h> +#include <util/memory/smallobj.h> +#include <util/generic/algorithm.h> +#include <util/generic/buffer.h> +#include <util/generic/strbuf.h> + +#include <util/system/align.h> +#include <util/stream/buffer.h> + +#define CONSTEXPR_MAX2(a, b) (a) > (b) ? (a) : (b) +#define CONSTEXPR_MAX3(a, b, c) CONSTEXPR_MAX2(CONSTEXPR_MAX2(a, b), c) + +// TCompactTrieBuilder::TCompactTrieBuilderImpl + +template <class T, class D, class S> +class TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl { +protected: + TMemoryPool Pool; + size_t PayloadSize; + THolder<TFixedSizeAllocator> NodeAllocator; + class TNode; + class TArc; + TNode* Root; + TCompactTrieBuilderFlags Flags; + size_t EntryCount; + size_t NodeCount; + TPacker Packer; + + enum EPayload { + DATA_ABSENT, + DATA_INSIDE, + DATA_MALLOCED, + DATA_IN_MEMPOOL, + }; + +protected: + void ConvertSymbolArrayToChar(const TSymbol* key, size_t keylen, TTempBuf& buf, size_t ckeylen) const; + void NodeLinkTo(TNode* thiz, const TBlob& label, TNode* node); + TNode* NodeForwardAdd(TNode* thiz, const char* label, size_t len, size_t& passed, size_t* nodeCount); + bool FindEntryImpl(const char* key, size_t keylen, TData* value) const; + bool FindLongestPrefixImpl(const char* keyptr, size_t keylen, size_t* prefixLen, TData* value) const; + + size_t NodeMeasureSubtree(TNode* thiz) const; + ui64 NodeSaveSubtree(TNode* thiz, IOutputStream& os) const; + ui64 NodeSaveSubtreeAndDestroy(TNode* thiz, IOutputStream& osy); + void NodeBufferSubtree(TNode* thiz); + + size_t NodeMeasureLeafValue(TNode* thiz) const; + ui64 NodeSaveLeafValue(TNode* thiz, IOutputStream& os) const; + + virtual ui64 ArcMeasure(const TArc* thiz, size_t leftsize, size_t rightsize) const; + + virtual ui64 ArcSaveSelf(const TArc* thiz, IOutputStream& os) const; + ui64 ArcSave(const TArc* thiz, IOutputStream& os) const; + ui64 ArcSaveAndDestroy(const TArc* thiz, IOutputStream& os); + +public: + TCompactTrieBuilderImpl(TCompactTrieBuilderFlags flags, TPacker packer, IAllocator* alloc); + virtual ~TCompactTrieBuilderImpl(); + + void DestroyNode(TNode* node); + void NodeReleasePayload(TNode* thiz); + + char* AddEntryForData(const TSymbol* key, size_t keylen, 
size_t dataLen, bool& isNewAddition); + TNode* AddEntryForSomething(const TSymbol* key, size_t keylen, bool& isNewAddition); + + bool AddEntry(const TSymbol* key, size_t keylen, const TData& value); + bool AddEntryPtr(const TSymbol* key, size_t keylen, const char* value); + bool AddSubtreeInFile(const TSymbol* key, size_t keylen, const TString& fileName); + bool AddSubtreeInBuffer(const TSymbol* key, size_t keylen, TArrayWithSizeHolder<char>&& buffer); + bool FindEntry(const TSymbol* key, size_t keylen, TData* value) const; + bool FindLongestPrefix(const TSymbol* key, size_t keylen, size_t* prefixlen, TData* value) const; + + size_t Save(IOutputStream& os) const; + size_t SaveAndDestroy(IOutputStream& os); + + void Clear(); + + // lies if some key was added at least twice + size_t GetEntryCount() const; + size_t GetNodeCount() const; + + size_t MeasureByteSize() const { + return NodeMeasureSubtree(Root); + } +}; + +template <class T, class D, class S> +class TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::TArc { +public: + TBlob Label; + TNode* Node; + mutable size_t LeftOffset; + mutable size_t RightOffset; + + TArc(const TBlob& lbl, TNode* nd); +}; + +template <class T, class D, class S> +class TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::TNode { +public: + typedef typename TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl TBuilderImpl; + typedef typename TBuilderImpl::TArc TArc; + + struct ISubtree { + virtual ~ISubtree() = default; + virtual bool IsLast() const = 0; + virtual ui64 Measure(const TBuilderImpl* builder) const = 0; + virtual ui64 Save(const TBuilderImpl* builder, IOutputStream& os) const = 0; + virtual ui64 SaveAndDestroy(TBuilderImpl* builder, IOutputStream& os) = 0; + virtual void Destroy(TBuilderImpl*) { } + + // Tries to find key in subtree. 
+ // Returns next node to find the key in (to avoid recursive calls) + // If it has end result, writes it to @value and @result arguments and returns nullptr + virtual const TNode* Find(TStringBuf& key, TData* value, bool& result, const TPacker& packer) const = 0; + virtual const TNode* FindLongestPrefix(TStringBuf& key, TData* value, bool& result, const TPacker& packer) const = 0; + }; + + class TArcSet: public ISubtree, public TCompactVector<TArc> { + public: + typedef typename TCompactVector<TArc>::iterator iterator; + typedef typename TCompactVector<TArc>::const_iterator const_iterator; + + TArcSet() { + Y_ASSERT(reinterpret_cast<ISubtree*>(this) == static_cast<void*>(this)); // This assumption is used in TNode::Subtree() + } + + iterator Find(char ch); + const_iterator Find(char ch) const; + void Add(const TBlob& s, TNode* node); + + bool IsLast() const override { + return this->Empty(); + } + + const TNode* Find(TStringBuf& key, TData* value, bool& result, const TPacker& packer) const override; + const TNode* FindLongestPrefix(TStringBuf& key, TData* value, bool& result, const TPacker& packer) const override { + return Find(key, value, result, packer); + } + + ui64 Measure(const TBuilderImpl* builder) const override { + return MeasureRange(builder, 0, this->size()); + } + + ui64 MeasureRange(const TBuilderImpl* builder, size_t from, size_t to) const { + if (from >= to) + return 0; + + size_t median = (from + to) / 2; + size_t leftsize = (size_t)MeasureRange(builder, from, median); + size_t rightsize = (size_t)MeasureRange(builder, median + 1, to); + + return builder->ArcMeasure(&(*this)[median], leftsize, rightsize); + } + + ui64 Save(const TBuilderImpl* builder, IOutputStream& os) const override { + return SaveRange(builder, 0, this->size(), os); + } + + ui64 SaveAndDestroy(TBuilderImpl* builder, IOutputStream& os) override { + ui64 result = SaveRangeAndDestroy(builder, 0, this->size(), os); + Destroy(builder); + return result; + } + + ui64 SaveRange(const TBuilderImpl* builder, size_t from, size_t to, IOutputStream& os) const { + if (from >= to) + return 0; + + size_t median = (from + to) / 2; + + ui64 written = builder->ArcSave(&(*this)[median], os); + written += SaveRange(builder, from, median, os); + written += SaveRange(builder, median + 1, to, os); + return written; + } + + ui64 SaveRangeAndDestroy(TBuilderImpl* builder, size_t from, size_t to, IOutputStream& os) { + if (from >= to) + return 0; + + size_t median = (from + to) / 2; + + ui64 written = builder->ArcSaveAndDestroy(&(*this)[median], os); + written += SaveRangeAndDestroy(builder, from, median, os); + written += SaveRangeAndDestroy(builder, median + 1, to, os); + return written; + } + + void Destroy(TBuilderImpl* builder) override { + // Delete all nodes down the stream. 
+ for (iterator it = this->begin(); it != this->end(); ++it) { + builder->DestroyNode(it->Node); + } + this->clear(); + } + + ~TArcSet() override { + Y_ASSERT(this->empty()); + } + + }; + + struct TBufferedSubtree: public ISubtree { + TArrayWithSizeHolder<char> Buffer; + + TBufferedSubtree() { + Y_ASSERT(reinterpret_cast<ISubtree*>(this) == static_cast<void*>(this)); // This assumption is used in TNode::Subtree() + } + + bool IsLast() const override { + return Buffer.Empty(); + } + + const TNode* Find(TStringBuf& key, TData* value, bool& result, const TPacker& packer) const override { + if (Buffer.Empty()) { + result = false; + return nullptr; + } + + TCompactTrie<char, D, S> trie(Buffer.Get(), Buffer.Size(), packer); + result = trie.Find(key.data(), key.size(), value); + + return nullptr; + } + + const TNode* FindLongestPrefix(TStringBuf& key, TData* value, bool& result, const TPacker& packer) const override { + if (Buffer.Empty()) { + result = false; + return nullptr; + } + + TCompactTrie<char, D, S> trie(Buffer.Get(), Buffer.Size(), packer); + size_t prefixLen = 0; + result = trie.FindLongestPrefix(key.data(), key.size(), &prefixLen, value); + key = key.SubStr(prefixLen); + + return nullptr; + } + + ui64 Measure(const TBuilderImpl*) const override { + return Buffer.Size(); + } + + ui64 Save(const TBuilderImpl*, IOutputStream& os) const override { + os.Write(Buffer.Get(), Buffer.Size()); + return Buffer.Size(); + } + + ui64 SaveAndDestroy(TBuilderImpl* builder, IOutputStream& os) override { + ui64 result = Save(builder, os); + TArrayWithSizeHolder<char>().Swap(Buffer); + return result; + } + }; + + struct TSubtreeInFile: public ISubtree { + struct TData { + TString FileName; + ui64 Size; + }; + THolder<TData> Data; + + TSubtreeInFile(const TString& fileName) { + // stupid API + TFile file(fileName, RdOnly); + i64 size = file.GetLength(); + if (size < 0) + ythrow yexception() << "unable to get file " << fileName.Quote() << " size for unknown reason"; + Data.Reset(new TData); + Data->FileName = fileName; + Data->Size = size; + + Y_ASSERT(reinterpret_cast<ISubtree*>(this) == static_cast<void*>(this)); // This assumption is used in TNode::Subtree() + } + + bool IsLast() const override { + return Data->Size == 0; + } + + const TNode* Find(TStringBuf& key, typename TCompactTrieBuilder::TData* value, bool& result, const TPacker& packer) const override { + if (!Data) { + result = false; + return nullptr; + } + + TCompactTrie<char, D, S> trie(TBlob::FromFile(Data->FileName), packer); + result = trie.Find(key.data(), key.size(), value); + return nullptr; + } + + const TNode* FindLongestPrefix(TStringBuf& key, typename TCompactTrieBuilder::TData* value, bool& result, const TPacker& packer) const override { + if (!Data) { + result = false; + return nullptr; + } + + TCompactTrie<char, D, S> trie(TBlob::FromFile(Data->FileName), packer); + size_t prefixLen = 0; + result = trie.FindLongestPrefix(key.data(), key.size(), &prefixLen, value); + key = key.SubStr(prefixLen); + + return nullptr; + } + + ui64 Measure(const TBuilderImpl*) const override { + return Data->Size; + } + + ui64 Save(const TBuilderImpl*, IOutputStream& os) const override { + TUnbufferedFileInput is(Data->FileName); + ui64 written = TransferData(&is, &os); + if (written != Data->Size) + ythrow yexception() << "file " << Data->FileName.Quote() << " size changed"; + return written; + } + + ui64 SaveAndDestroy(TBuilderImpl* builder, IOutputStream& os) override { + return Save(builder, os); + } + }; + + union { + char 
ArcsData[CONSTEXPR_MAX3(sizeof(TArcSet), sizeof(TBufferedSubtree), sizeof(TSubtreeInFile))]; + union { + void* Data1; + long long int Data2; + } Aligner; + }; + + inline ISubtree* Subtree() { + return reinterpret_cast<ISubtree*>(ArcsData); + } + + inline const ISubtree* Subtree() const { + return reinterpret_cast<const ISubtree*>(ArcsData); + } + + EPayload PayloadType; + + inline const char* PayloadPtr() const { + return ((const char*) this) + sizeof(TNode); + } + + inline char* PayloadPtr() { + return ((char*) this) + sizeof(TNode); + } + + // *Payload() + inline const char*& PayloadAsPtr() const { + const char** payload = (const char**) PayloadPtr(); + return *payload; + } + + inline char*& PayloadAsPtr() { + char** payload = (char**) PayloadPtr(); + return *payload; + } + + inline const char* GetPayload() const { + switch (PayloadType) { + case DATA_INSIDE: + return PayloadPtr(); + case DATA_MALLOCED: + case DATA_IN_MEMPOOL: + return PayloadAsPtr(); + case DATA_ABSENT: + default: + abort(); + } + } + + inline char* GetPayload() { + const TNode* thiz = this; + return const_cast<char*>(thiz->GetPayload()); // const_cast is to avoid copy-paste style + } + + bool IsFinal() const { + return PayloadType != DATA_ABSENT; + } + + bool IsLast() const { + return Subtree()->IsLast(); + } + + inline void* operator new(size_t, TFixedSizeAllocator& pool) { + return pool.Allocate(); + } + + inline void operator delete(void* ptr, TFixedSizeAllocator& pool) noexcept { + pool.Release(ptr); + } + + TNode() + : PayloadType(DATA_ABSENT) + { + new (Subtree()) TArcSet; + } + + ~TNode() { + Subtree()->~ISubtree(); + Y_ASSERT(PayloadType == DATA_ABSENT); + } + +}; + +// TCompactTrieBuilder + +template <class T, class D, class S> +TCompactTrieBuilder<T, D, S>::TCompactTrieBuilder(TCompactTrieBuilderFlags flags, TPacker packer, IAllocator* alloc) + : Impl(new TCompactTrieBuilderImpl(flags, packer, alloc)) +{ +} + +template <class T, class D, class S> +bool TCompactTrieBuilder<T, D, S>::Add(const TSymbol* key, size_t keylen, const TData& value) { + return Impl->AddEntry(key, keylen, value); +} + +template <class T, class D, class S> +bool TCompactTrieBuilder<T, D, S>::AddPtr(const TSymbol* key, size_t keylen, const char* value) { + return Impl->AddEntryPtr(key, keylen, value); +} + +template <class T, class D, class S> +bool TCompactTrieBuilder<T, D, S>::AddSubtreeInFile(const TSymbol* key, size_t keylen, const TString& fileName) { + return Impl->AddSubtreeInFile(key, keylen, fileName); +} + +template <class T, class D, class S> +bool TCompactTrieBuilder<T, D, S>::AddSubtreeInBuffer(const TSymbol* key, size_t keylen, TArrayWithSizeHolder<char>&& buffer) { + return Impl->AddSubtreeInBuffer(key, keylen, std::move(buffer)); +} + +template <class T, class D, class S> +bool TCompactTrieBuilder<T, D, S>::Find(const TSymbol* key, size_t keylen, TData* value) const { + return Impl->FindEntry(key, keylen, value); +} + +template <class T, class D, class S> +bool TCompactTrieBuilder<T, D, S>::FindLongestPrefix( + const TSymbol* key, size_t keylen, size_t* prefixlen, TData* value) const { + return Impl->FindLongestPrefix(key, keylen, prefixlen, value); +} + +template <class T, class D, class S> +size_t TCompactTrieBuilder<T, D, S>::Save(IOutputStream& os) const { + return Impl->Save(os); +} + +template <class T, class D, class S> +size_t TCompactTrieBuilder<T, D, S>::SaveAndDestroy(IOutputStream& os) { + return Impl->SaveAndDestroy(os); +} + +template <class T, class D, class S> +void TCompactTrieBuilder<T, D, S>::Clear() { + 
Impl->Clear(); +} + +template <class T, class D, class S> +size_t TCompactTrieBuilder<T, D, S>::GetEntryCount() const { + return Impl->GetEntryCount(); +} + +template <class T, class D, class S> +size_t TCompactTrieBuilder<T, D, S>::GetNodeCount() const { + return Impl->GetNodeCount(); +} + +// TCompactTrieBuilder::TCompactTrieBuilderImpl + +template <class T, class D, class S> +TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::TCompactTrieBuilderImpl(TCompactTrieBuilderFlags flags, TPacker packer, IAllocator* alloc) + : Pool(1000000, TMemoryPool::TLinearGrow::Instance(), alloc) + , PayloadSize(sizeof(void*)) // XXX: find better value + , NodeAllocator(new TFixedSizeAllocator(sizeof(TNode) + PayloadSize, alloc)) + , Flags(flags) + , EntryCount(0) + , NodeCount(1) + , Packer(packer) +{ + Root = new (*NodeAllocator) TNode; +} + +template <class T, class D, class S> +TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::~TCompactTrieBuilderImpl() { + DestroyNode(Root); +} + +template <class T, class D, class S> +void TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::ConvertSymbolArrayToChar( + const TSymbol* key, size_t keylen, TTempBuf& buf, size_t buflen) const { + char* ckeyptr = buf.Data(); + + for (size_t i = 0; i < keylen; ++i) { + TSymbol label = key[i]; + for (int j = (int)NCompactTrie::ExtraBits<TSymbol>(); j >= 0; j -= 8) { + Y_ASSERT(ckeyptr < buf.Data() + buflen); + *(ckeyptr++) = (char)(label >> j); + } + } + + buf.Proceed(buflen); + Y_ASSERT(ckeyptr == buf.Data() + buf.Filled()); +} + +template <class T, class D, class S> +void TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::DestroyNode(TNode* thiz) { + thiz->Subtree()->Destroy(this); + NodeReleasePayload(thiz); + thiz->~TNode(); + NodeAllocator->Release(thiz); +} + +template <class T, class D, class S> +void TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::NodeReleasePayload(TNode* thiz) { + switch (thiz->PayloadType) { + case DATA_ABSENT: + case DATA_INSIDE: + case DATA_IN_MEMPOOL: + break; + case DATA_MALLOCED: + delete[] thiz->PayloadAsPtr(); + break; + default: + abort(); + } + thiz->PayloadType = DATA_ABSENT; +} + +template <class T, class D, class S> +bool TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::AddEntry( + const TSymbol* key, size_t keylen, const TData& value) { + size_t datalen = Packer.MeasureLeaf(value); + + bool isNewAddition = false; + char* place = AddEntryForData(key, keylen, datalen, isNewAddition); + Packer.PackLeaf(place, value, datalen); + return isNewAddition; +} + +template <class T, class D, class S> +bool TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::AddEntryPtr( + const TSymbol* key, size_t keylen, const char* value) { + size_t datalen = Packer.SkipLeaf(value); + + bool isNewAddition = false; + char* place = AddEntryForData(key, keylen, datalen, isNewAddition); + memcpy(place, value, datalen); + return isNewAddition; +} + +template <class T, class D, class S> +bool TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::AddSubtreeInFile( + const TSymbol* key, size_t keylen, const TString& fileName) { + typedef typename TNode::ISubtree ISubtree; + typedef typename TNode::TSubtreeInFile TSubtreeInFile; + + bool isNewAddition = false; + TNode* node = AddEntryForSomething(key, keylen, isNewAddition); + node->Subtree()->Destroy(this); + node->Subtree()->~ISubtree(); + + new (node->Subtree()) TSubtreeInFile(fileName); + return isNewAddition; +} + +template <class T, class D, class S> +bool TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::AddSubtreeInBuffer( + 
const TSymbol* key, size_t keylen, TArrayWithSizeHolder<char>&& buffer) { + + typedef typename TNode::TBufferedSubtree TBufferedSubtree; + + bool isNewAddition = false; + TNode* node = AddEntryForSomething(key, keylen, isNewAddition); + node->Subtree()->Destroy(this); + node->Subtree()->~ISubtree(); + + auto subtree = new (node->Subtree()) TBufferedSubtree(); + subtree->Buffer.Swap(buffer); + + return isNewAddition; +} + +template <class T, class D, class S> +typename TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::TNode* + TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::AddEntryForSomething( + const TSymbol* key, size_t keylen, bool& isNewAddition) { + using namespace NCompactTrie; + + EntryCount++; + + if (Flags & CTBF_VERBOSE) + ShowProgress(EntryCount); + + TNode* current = Root; + size_t passed; + + // Special case of empty key: replace it by 1-byte "\0" key. + size_t ckeylen = keylen ? keylen * sizeof(TSymbol) : 1; + TTempBuf ckeybuf(ckeylen); + if (keylen == 0) { + ckeybuf.Append("\0", 1); + } else { + ConvertSymbolArrayToChar(key, keylen, ckeybuf, ckeylen); + } + + char* ckey = ckeybuf.Data(); + + TNode* next; + while ((ckeylen > 0) && (next = NodeForwardAdd(current, ckey, ckeylen, passed, &NodeCount)) != nullptr) { + current = next; + ckeylen -= passed; + ckey += passed; + } + + if (ckeylen != 0) { + //new leaf + NodeCount++; + TNode* leaf = new (*NodeAllocator) TNode(); + NodeLinkTo(current, TBlob::Copy(ckey, ckeylen), leaf); + current = leaf; + } + isNewAddition = (current->PayloadType == DATA_ABSENT); + if ((Flags & CTBF_UNIQUE) && !isNewAddition) + ythrow yexception() << "Duplicate key"; + return current; +} + +template <class T, class D, class S> +char* TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::AddEntryForData(const TSymbol* key, size_t keylen, + size_t datalen, bool& isNewAddition) { + TNode* current = AddEntryForSomething(key, keylen, isNewAddition); + NodeReleasePayload(current); + if (datalen <= PayloadSize) { + current->PayloadType = DATA_INSIDE; + } else if (Flags & CTBF_PREFIX_GROUPED) { + current->PayloadType = DATA_MALLOCED; + current->PayloadAsPtr() = new char[datalen]; + } else { + current->PayloadType = DATA_IN_MEMPOOL; + current->PayloadAsPtr() = (char*) Pool.Allocate(datalen); // XXX: allocate unaligned + } + return current->GetPayload(); +} + +template <class T, class D, class S> +bool TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::FindEntry(const TSymbol* key, size_t keylen, TData* value) const { + using namespace NCompactTrie; + + if (!keylen) { + const char zero = '\0'; + return FindEntryImpl(&zero, 1, value); + } else { + size_t ckeylen = keylen * sizeof(TSymbol); + TTempBuf ckeybuf(ckeylen); + ConvertSymbolArrayToChar(key, keylen, ckeybuf, ckeylen); + return FindEntryImpl(ckeybuf.Data(), ckeylen, value); + } +} + +template <class T, class D, class S> +bool TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::FindEntryImpl(const char* keyptr, size_t keylen, TData* value) const { + const TNode* node = Root; + bool result = false; + TStringBuf key(keyptr, keylen); + while (key && (node = node->Subtree()->Find(key, value, result, Packer))) { + } + return result; +} + +template <class T, class D, class S> +bool TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::FindLongestPrefix( + const TSymbol* key, size_t keylen, size_t* prefixlen, TData* value) const { + using namespace NCompactTrie; + + if (!keylen) { + const char zero = '\0'; + const bool ret = FindLongestPrefixImpl(&zero, 1, prefixlen, value); + if (ret && 
prefixlen) + *prefixlen = 0; // empty key found + return ret; + } else { + size_t ckeylen = keylen * sizeof(TSymbol); + TTempBuf ckeybuf(ckeylen); + ConvertSymbolArrayToChar(key, keylen, ckeybuf, ckeylen); + bool ret = FindLongestPrefixImpl(ckeybuf.Data(), ckeylen, prefixlen, value); + if (ret && prefixlen && *prefixlen == 1 && ckeybuf.Data()[0] == '\0') + *prefixlen = 0; // if we have found empty key, set prefixlen to zero + else if (!ret) // try to find value with empty key, because empty key is prefix of a every key + ret = FindLongestPrefix(nullptr, 0, prefixlen, value); + + if (ret && prefixlen) + *prefixlen /= sizeof(TSymbol); + + return ret; + } +} + +template <class T, class D, class S> +bool TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::FindLongestPrefixImpl(const char* keyptr, size_t keylen, size_t* prefixLen, TData* value) const { + const TNode* node = Root; + const TNode* lastFinalNode = nullptr; + bool endResult = false; + TStringBuf key(keyptr, keylen); + TStringBuf keyTail = key; + TStringBuf lastFinalKeyTail; + while (keyTail && (node = node->Subtree()->FindLongestPrefix(keyTail, value, endResult, Packer))) { + if (endResult) // no more ways to find prefix and prefix has been found + break; + + if (node->IsFinal()) { + lastFinalNode = node; + lastFinalKeyTail = keyTail; + } + } + if (!endResult && lastFinalNode) { + if (value) + Packer.UnpackLeaf(lastFinalNode->GetPayload(), *value); + keyTail = lastFinalKeyTail; + endResult = true; + } + if (endResult && prefixLen) + *prefixLen = keyTail ? key.size() - keyTail.size() : key.size(); + return endResult; +} + +template <class T, class D, class S> +void TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::Clear() { + DestroyNode(Root); + Pool.Clear(); + NodeAllocator.Reset(new TFixedSizeAllocator(sizeof(TNode) + PayloadSize, TDefaultAllocator::Instance())); + Root = new (*NodeAllocator) TNode; + EntryCount = 0; + NodeCount = 1; +} + +template <class T, class D, class S> +size_t TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::Save(IOutputStream& os) const { + const size_t len = NodeMeasureSubtree(Root); + if (len != NodeSaveSubtree(Root, os)) + ythrow yexception() << "something wrong"; + + return len; +} + +template <class T, class D, class S> +size_t TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::SaveAndDestroy(IOutputStream& os) { + const size_t len = NodeMeasureSubtree(Root); + if (len != NodeSaveSubtreeAndDestroy(Root, os)) + ythrow yexception() << "something wrong"; + + return len; +} + +template <class T, class D, class S> +size_t TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::GetEntryCount() const { + return EntryCount; +} + +template <class T, class D, class S> +size_t TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::GetNodeCount() const { + return NodeCount; +} + +template <class T, class D, class S> +typename TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::TNode* + TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::NodeForwardAdd( + TNode* thiz, const char* label, size_t len, size_t& passed, size_t* nodeCount) { + typename TNode::TArcSet* arcSet = dynamic_cast<typename TNode::TArcSet*>(thiz->Subtree()); + if (!arcSet) + ythrow yexception() << "Bad input order - expected input strings to be prefix-grouped."; + + typename TNode::TArcSet::iterator it = arcSet->Find(*label); + + if (it != arcSet->end()) { + const char* arcLabel = it->Label.AsCharPtr(); + size_t arcLabelLen = it->Label.Length(); + + for (passed = 0; (passed < len) && (passed < arcLabelLen) && (label[passed] 
== arcLabel[passed]); ++passed) { + //just count + } + + if (passed < arcLabelLen) { + (*nodeCount)++; + TNode* node = new (*NodeAllocator) TNode(); + NodeLinkTo(node, it->Label.SubBlob(passed, arcLabelLen), it->Node); + + it->Node = node; + it->Label = it->Label.SubBlob(passed); + } + + return it->Node; + } + + return nullptr; +} + +template <class T, class D, class S> +void TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::NodeLinkTo(TNode* thiz, const TBlob& label, TNode* node) { + typename TNode::TArcSet* arcSet = dynamic_cast<typename TNode::TArcSet*>(thiz->Subtree()); + if (!arcSet) + ythrow yexception() << "Bad input order - expected input strings to be prefix-grouped."; + + // Buffer the node at the last arc + if ((Flags & CTBF_PREFIX_GROUPED) && !arcSet->empty()) + NodeBufferSubtree(arcSet->back().Node); + + arcSet->Add(label, node); +} + +template <class T, class D, class S> +size_t TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::NodeMeasureSubtree(TNode* thiz) const { + return (size_t)thiz->Subtree()->Measure(this); +} + +template <class T, class D, class S> +ui64 TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::NodeSaveSubtree(TNode* thiz, IOutputStream& os) const { + return thiz->Subtree()->Save(this, os); +} + +template <class T, class D, class S> +ui64 TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::NodeSaveSubtreeAndDestroy(TNode* thiz, IOutputStream& os) { + return thiz->Subtree()->SaveAndDestroy(this, os); +} + +template <class T, class D, class S> +void TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::NodeBufferSubtree(TNode* thiz) { + typedef typename TNode::TArcSet TArcSet; + + TArcSet* arcSet = dynamic_cast<TArcSet*>(thiz->Subtree()); + if (!arcSet) + return; + + size_t bufferLength = (size_t)arcSet->Measure(this); + TArrayWithSizeHolder<char> buffer; + buffer.Resize(bufferLength); + + TMemoryOutput bufout(buffer.Get(), buffer.Size()); + + ui64 written = arcSet->Save(this, bufout); + Y_ASSERT(written == bufferLength); + + arcSet->Destroy(this); + arcSet->~TArcSet(); + + typename TNode::TBufferedSubtree* bufferedArcSet = new (thiz->Subtree()) typename TNode::TBufferedSubtree; + + bufferedArcSet->Buffer.Swap(buffer); +} + +template <class T, class D, class S> +size_t TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::NodeMeasureLeafValue(TNode* thiz) const { + if (!thiz->IsFinal()) + return 0; + + return Packer.SkipLeaf(thiz->GetPayload()); +} + +template <class T, class D, class S> +ui64 TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::NodeSaveLeafValue(TNode* thiz, IOutputStream& os) const { + if (!thiz->IsFinal()) + return 0; + + size_t len = Packer.SkipLeaf(thiz->GetPayload()); + os.Write(thiz->GetPayload(), len); + return len; +} + +// TCompactTrieBuilder::TCompactTrieBuilderImpl::TNode::TArc + +template <class T, class D, class S> +TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::TArc::TArc(const TBlob& lbl, TNode* nd) + : Label(lbl) + , Node(nd) + , LeftOffset(0) + , RightOffset(0) +{} + +template <class T, class D, class S> +ui64 TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::ArcMeasure( + const TArc* thiz, size_t leftsize, size_t rightsize) const { + using namespace NCompactTrie; + + size_t coresize = 2 + NodeMeasureLeafValue(thiz->Node); // 2 == (char + flags) + size_t treesize = NodeMeasureSubtree(thiz->Node); + + if (thiz->Label.Length() > 0) + treesize += 2 * (thiz->Label.Length() - 1); + + // Triple measurements are needed because the space needed to store the offset + // shall be added to the offset 
itself. Hence three iterations. + size_t leftoffsetsize = leftsize ? MeasureOffset(coresize + treesize) : 0; + size_t rightoffsetsize = rightsize ? MeasureOffset(coresize + treesize + leftsize) : 0; + leftoffsetsize = leftsize ? MeasureOffset(coresize + treesize + leftoffsetsize + rightoffsetsize) : 0; + rightoffsetsize = rightsize ? MeasureOffset(coresize + treesize + leftsize + leftoffsetsize + rightoffsetsize) : 0; + leftoffsetsize = leftsize ? MeasureOffset(coresize + treesize + leftoffsetsize + rightoffsetsize) : 0; + rightoffsetsize = rightsize ? MeasureOffset(coresize + treesize + leftsize + leftoffsetsize + rightoffsetsize) : 0; + + coresize += leftoffsetsize + rightoffsetsize; + thiz->LeftOffset = leftsize ? coresize + treesize : 0; + thiz->RightOffset = rightsize ? coresize + treesize + leftsize : 0; + + return coresize + treesize + leftsize + rightsize; +} + +template <class T, class D, class S> +ui64 TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::ArcSaveSelf(const TArc* thiz, IOutputStream& os) const { + using namespace NCompactTrie; + + ui64 written = 0; + + size_t leftoffsetsize = MeasureOffset(thiz->LeftOffset); + size_t rightoffsetsize = MeasureOffset(thiz->RightOffset); + + size_t labelLen = thiz->Label.Length(); + + for (size_t i = 0; i < labelLen; ++i) { + char flags = 0; + + if (i == 0) { + flags |= (leftoffsetsize << MT_LEFTSHIFT); + flags |= (rightoffsetsize << MT_RIGHTSHIFT); + } + + if (i == labelLen-1) { + if (thiz->Node->IsFinal()) + flags |= MT_FINAL; + + if (!thiz->Node->IsLast()) + flags |= MT_NEXT; + } else { + flags |= MT_NEXT; + } + + os.Write(&flags, 1); + os.Write(&thiz->Label.AsCharPtr()[i], 1); + written += 2; + + if (i == 0) { + written += ArcSaveOffset(thiz->LeftOffset, os); + written += ArcSaveOffset(thiz->RightOffset, os); + } + } + + written += NodeSaveLeafValue(thiz->Node, os); + return written; +} + +template <class T, class D, class S> +ui64 TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::ArcSave(const TArc* thiz, IOutputStream& os) const { + ui64 written = ArcSaveSelf(thiz, os); + written += NodeSaveSubtree(thiz->Node, os); + return written; +} + +template <class T, class D, class S> +ui64 TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::ArcSaveAndDestroy(const TArc* thiz, IOutputStream& os) { + ui64 written = ArcSaveSelf(thiz, os); + written += NodeSaveSubtreeAndDestroy(thiz->Node, os); + return written; +} + +// TCompactTrieBuilder::TCompactTrieBuilderImpl::TNode::TArcSet + +template <class T, class D, class S> +typename TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::TNode::TArcSet::iterator + TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::TNode::TArcSet::Find(char ch) { + using namespace NCompTriePrivate; + iterator it = LowerBound(this->begin(), this->end(), ch, TCmp()); + + if (it != this->end() && it->Label[0] == (unsigned char)ch) { + return it; + } + + return this->end(); +} + +template <class T, class D, class S> +typename TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::TNode::TArcSet::const_iterator + TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::TNode::TArcSet::Find(char ch) const { + using namespace NCompTriePrivate; + const_iterator it = LowerBound(this->begin(), this->end(), ch, TCmp()); + + if (it != this->end() && it->Label[0] == (unsigned char)ch) { + return it; + } + + return this->end(); +} + +template <class T, class D, class S> +void TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::TNode::TArcSet::Add(const TBlob& s, TNode* node) { + using namespace NCompTriePrivate; 
+ this->insert(LowerBound(this->begin(), this->end(), s[0], TCmp()), TArc(s, node)); +} + +template <class T, class D, class S> +const typename TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::TNode* + TCompactTrieBuilder<T, D, S>::TCompactTrieBuilderImpl::TNode::TArcSet::Find( + TStringBuf& key, TData* value, bool& result, const TPacker& packer) const { + result = false; + if (!key) + return nullptr; + + const const_iterator it = Find(key[0]); + if (it != this->end()) { + const char* const arcLabel = it->Label.AsCharPtr(); + const size_t arcLabelLen = it->Label.Length(); + if (key.size() >= arcLabelLen && memcmp(key.data(), arcLabel, arcLabelLen) == 0) { + const TStringBuf srcKey = key; + key = key.SubStr(arcLabelLen); + const TNode* const node = it->Node; + if (srcKey.size() == arcLabelLen) { + // unpack value of it->Node, if it has value + if (!node->IsFinal()) + return nullptr; + + if (value) + packer.UnpackLeaf(node->GetPayload(), *value); + + result = true; + return nullptr; + } + + // find in subtree + return node; + } + } + + return nullptr; +} + +// Different + +//---------------------------------------------------------------------------------------------------------------------- +// Minimize the trie. The result is equivalent to the original +// trie, except that it takes less space (and has marginally lower +// performance, because of eventual epsilon links). +// The algorithm is as follows: starting from the largest pieces, we find +// nodes that have identical continuations (Daciuk's right language), +// and repack the trie. Repacking is done in-place, so memory is less +// of an issue; however, it may take considerable time. + +// IMPORTANT: never try to reminimize an already minimized trie or a trie with fast layout. +// Because of non-local structure and epsilon links, it won't work +// as you expect it to, and can destroy the trie in the making. + +template <class TPacker> +size_t CompactTrieMinimize(IOutputStream& os, const char* data, size_t datalength, bool verbose /*= false*/, const TPacker& packer /*= TPacker()*/, NCompactTrie::EMinimizeMode mode) { + using namespace NCompactTrie; + return CompactTrieMinimizeImpl(os, data, datalength, verbose, &packer, mode); +} + +template <class TTrieBuilder> +size_t CompactTrieMinimize(IOutputStream& os, const TTrieBuilder& builder, bool verbose /*=false*/) { + TBufferStream buftmp; + size_t len = builder.Save(buftmp); + return CompactTrieMinimize<typename TTrieBuilder::TPacker>(os, buftmp.Buffer().Data(), len, verbose); +} + +//---------------------------------------------------------------------------------------------------------------- +// Lay the trie in memory in such a way that there are less cache misses when jumping from root to leaf. +// The trie becomes about 2% larger, but the access became about 25% faster in our experiments. +// Can be called on minimized and non-minimized tries, in the first case in requires half a trie more memory. +// Calling it the second time on the same trie does nothing. +// +// The algorithm is based on van Emde Boas layout as described in the yandex data school lectures on external memory algoritms +// by Maxim Babenko and Ivan Puzyrevsky. The difference is that when we cut the tree into levels +// two nodes connected by a forward link are put into the same level (because they usually lie near each other in the original tree). +// The original paper (describing the layout in Section 2.1) is: +// Michael A. Bender, Erik D. Demaine, Martin Farach-Colton. 
Cache-Oblivious B-Trees +// SIAM Journal on Computing, volume 35, number 2, 2005, pages 341-358. +// Available on the web: http://erikdemaine.org/papers/CacheObliviousBTrees_SICOMP/ +// Or: Michael A. Bender, Erik D. Demaine, and Martin Farach-Colton. Cache-Oblivious B-Trees +// Proceedings of the 41st Annual Symposium +// on Foundations of Computer Science (FOCS 2000), Redondo Beach, California, November 12-14, 2000, pages 399-409. +// Available on the web: http://erikdemaine.org/papers/FOCS2000b/ +// (there is not much difference between these papers, actually). +// +template <class TPacker> +size_t CompactTrieMakeFastLayout(IOutputStream& os, const char* data, size_t datalength, bool verbose /*= false*/, const TPacker& packer /*= TPacker()*/) { + using namespace NCompactTrie; + return CompactTrieMakeFastLayoutImpl(os, data, datalength, verbose, &packer); +} + +template <class TTrieBuilder> +size_t CompactTrieMakeFastLayout(IOutputStream& os, const TTrieBuilder& builder, bool verbose /*=false*/) { + TBufferStream buftmp; + size_t len = builder.Save(buftmp); + return CompactTrieMakeFastLayout<typename TTrieBuilder::TPacker>(os, buftmp.Buffer().Data(), len, verbose); +} + +template <class TPacker> +size_t CompactTrieMinimizeAndMakeFastLayout(IOutputStream& os, const char* data, size_t datalength, bool verbose/*=false*/, const TPacker& packer/*= TPacker()*/) { + TBufferStream buftmp; + size_t len = CompactTrieMinimize(buftmp, data, datalength, verbose, packer); + return CompactTrieMakeFastLayout(os, buftmp.Buffer().Data(), len, verbose, packer); +} + +template <class TTrieBuilder> +size_t CompactTrieMinimizeAndMakeFastLayout(IOutputStream& os, const TTrieBuilder& builder, bool verbose /*=false*/) { + TBufferStream buftmp; + size_t len = CompactTrieMinimize(buftmp, builder, verbose); + return CompactTrieMakeFastLayout<typename TTrieBuilder::TPacker>(os, buftmp.Buffer().Data(), len, verbose); +} + diff --git a/library/cpp/containers/comptrie/comptrie_impl.cpp b/library/cpp/containers/comptrie/comptrie_impl.cpp new file mode 100644 index 00000000000..a116ab6d1ef --- /dev/null +++ b/library/cpp/containers/comptrie/comptrie_impl.cpp @@ -0,0 +1,39 @@ +#include "comptrie_impl.h" + +#include <util/system/rusage.h> +#include <util/stream/output.h> + +// Unpack the leaf value. The algorithm can store up to 8 full bytes in leafs. 
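+// Worked example for the offset helpers below (values are illustrative): offsets are written
+// big-endian in the minimal number of bytes, so MeasureOffset(0) == 0, MeasureOffset(200) == 1
+// and MeasureOffset(500) == 2, while PackOffset(buf, 500) stores the two bytes 0x01 0xF4
+// (500 == 0x01F4) and returns 2.
+//
+//     char buf[8];
+//     size_t len = NCompactTrie::PackOffset(buf, 500); // len == 2, buf[0] == 0x01, buf[1] == '\xF4'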
+ +namespace NCompactTrie { + size_t MeasureOffset(size_t offset) { + int n = 0; + + while (offset) { + offset >>= 8; + ++n; + } + + return n; + } + + size_t PackOffset(char* buffer, size_t offset) { + size_t len = MeasureOffset(offset); + size_t i = len; + + while (i--) { + buffer[i] = (char)(offset & 0xFF); + offset >>= 8; + } + + return len; + } + + void ShowProgress(size_t n) { + if (n % 1000000 == 0) + Cerr << n << ", RSS=" << (TRusage::Get().MaxRss >> 20) << "mb" << Endl; + else if (n % 20000 == 0) + Cerr << "."; + } + +} diff --git a/library/cpp/containers/comptrie/comptrie_impl.h b/library/cpp/containers/comptrie/comptrie_impl.h new file mode 100644 index 00000000000..f41c38311a4 --- /dev/null +++ b/library/cpp/containers/comptrie/comptrie_impl.h @@ -0,0 +1,221 @@ +#pragma once + +#include <util/stream/output.h> + +#ifndef COMPTRIE_DATA_CHECK +#define COMPTRIE_DATA_CHECK 1 +#endif + +// NCompactTrie + +namespace NCompactTrie { + const char MT_FINAL = '\x80'; + const char MT_NEXT = '\x40'; + const char MT_SIZEMASK = '\x07'; + const size_t MT_LEFTSHIFT = 3; + const size_t MT_RIGHTSHIFT = 0; + + Y_FORCE_INLINE size_t UnpackOffset(const char* p, size_t len); + size_t MeasureOffset(size_t offset); + size_t PackOffset(char* buffer, size_t offset); + static inline ui64 ArcSaveOffset(size_t offset, IOutputStream& os); + Y_FORCE_INLINE char LeapByte(const char*& datapos, const char* dataend, char label); + + template <class T> + inline static size_t ExtraBits() { + return (sizeof(T) - 1) * 8; + } + + static inline bool IsEpsilonLink(const char flags) { + return !(flags & (MT_FINAL | MT_NEXT)); + } + + static inline void TraverseEpsilon(const char*& datapos) { + const char flags = *datapos; + if (!IsEpsilonLink(flags)) { + return; + } + const size_t offsetlength = flags & MT_SIZEMASK; + const size_t offset = UnpackOffset(datapos + 1, offsetlength); + Y_ASSERT(offset); + datapos += offset; + } + + static inline size_t LeftOffsetLen(const char flags) { + return (flags >> MT_LEFTSHIFT) & MT_SIZEMASK; + } + + static inline size_t RightOffsetLen(const char flags) { + return flags & MT_SIZEMASK; + } + + void ShowProgress(size_t n); // just print dots +} + +namespace NCompTriePrivate { + template <typename TChar> + struct TStringForChar { + }; + + template <> + struct TStringForChar<char> { + typedef TString TResult; + }; + + template <> + struct TStringForChar<wchar16> { + typedef TUtf16String TResult; + }; + + template <> + struct TStringForChar<wchar32> { + typedef TUtf32String TResult; + }; + +} + +namespace NCompTriePrivate { + struct TCmp { + template <class T> + inline bool operator()(const T& l, const T& r) { + return (unsigned char)(l.Label[0]) < (unsigned char)(r.Label[0]); + } + + template <class T> + inline bool operator()(const T& l, char r) { + return (unsigned char)(l.Label[0]) < (unsigned char)r; + } + }; +} + +namespace NCompactTrie { + static inline ui64 ArcSaveOffset(size_t offset, IOutputStream& os) { + using namespace NCompactTrie; + + if (!offset) + return 0; + + char buf[16]; + size_t len = PackOffset(buf, offset); + os.Write(buf, len); + return len; + } + + // Unpack the offset to the next node. The encoding scheme can store offsets + // up to 7 bytes; whether they fit into size_t is another issue. + Y_FORCE_INLINE size_t UnpackOffset(const char* p, size_t len) { + size_t result = 0; + + while (len--) + result = ((result << 8) | (*(p++) & 0xFF)); + + return result; + } + + // Auxiliary function: consumes one character from the input. 
Advances the data pointer + // to the position immediately preceding the value for the link just traversed (if any); + // returns flags associated with the link. If no arc with the required label is present, + // zeroes the data pointer. + Y_FORCE_INLINE char LeapByte(const char*& datapos, const char* dataend, char label) { + while (datapos < dataend) { + size_t offsetlength, offset; + const char* startpos = datapos; + char flags = *(datapos++); + + if (IsEpsilonLink(flags)) { + // Epsilon link - jump to the specified offset without further checks. + // These links are created during minimization: original uncompressed + // tree does not need them. (If we find a way to package 3 offset lengths + // into 1 byte, we could get rid of them; but it looks like they do no harm. + Y_ASSERT(datapos < dataend); + offsetlength = flags & MT_SIZEMASK; + offset = UnpackOffset(datapos, offsetlength); + if (!offset) + break; + datapos = startpos + offset; + + continue; + } + + char ch = *(datapos++); + + // Left branch + offsetlength = LeftOffsetLen(flags); + if ((unsigned char)label < (unsigned char)ch) { + offset = UnpackOffset(datapos, offsetlength); + if (!offset) + break; + + datapos = startpos + offset; + + continue; + } + + datapos += offsetlength; + + // Right branch + offsetlength = RightOffsetLen(flags); + if ((unsigned char)label > (unsigned char)ch) { + offset = UnpackOffset(datapos, offsetlength); + + if (!offset) + break; + + datapos = startpos + offset; + + continue; + } + + // Got a match; return position right before the contents for the label + datapos += offsetlength; + return flags; + } + + // if we got here, we're past the dataend - bail out ASAP + datapos = nullptr; + return 0; + } + + // Auxiliary function: consumes one (multibyte) symbol from the input. + // Advances the data pointer to the root of the subtrie beginning after the symbol, + // zeroes it if this subtrie is empty. + // If there is a value associated with the symbol, makes the value pointer point to it, + // otherwise sets it to nullptr. + // Returns true if the symbol was succesfully found in the trie, false otherwise. 
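+ // Worked flag-byte example (illustrative): for a single-character arc label whose node is final
+ // and has a continuation, stored with a 1-byte left offset and a 2-byte right offset, the flag
+ // byte is MT_FINAL | MT_NEXT | (1 << MT_LEFTSHIFT) | (2 << MT_RIGHTSHIFT), i.e. byte 0xCA, so
+ // LeftOffsetLen gives 1, RightOffsetLen gives 2, and IsEpsilonLink is false for it.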
+ template <typename TSymbol, class TPacker> + Y_FORCE_INLINE bool Advance(const char*& datapos, const char* const dataend, const char*& value, + TSymbol label, TPacker packer) { + Y_ASSERT(datapos < dataend); + char flags = MT_NEXT; + for (int i = (int)ExtraBits<TSymbol>(); i >= 0; i -= 8) { + flags = LeapByte(datapos, dataend, (char)(label >> i)); + if (!datapos) { + return false; // no such arc + } + + value = nullptr; + + Y_ASSERT(datapos <= dataend); + if ((flags & MT_FINAL)) { + value = datapos; + datapos += packer.SkipLeaf(datapos); + } + + if (!(flags & MT_NEXT)) { + if (i == 0) { + datapos = nullptr; + return true; + } + return false; // no further way + } + + TraverseEpsilon(datapos); + if (i == 0) { // last byte, and got a match + return true; + } + } + + return false; + } + +} diff --git a/library/cpp/containers/comptrie/comptrie_packer.h b/library/cpp/containers/comptrie/comptrie_packer.h new file mode 100644 index 00000000000..0341eeeae38 --- /dev/null +++ b/library/cpp/containers/comptrie/comptrie_packer.h @@ -0,0 +1,21 @@ +#pragma once + +#include <library/cpp/packers/packers.h> + +template <class T> +class TCompactTriePacker { +public: + void UnpackLeaf(const char* p, T& t) const { + NPackers::TPacker<T>().UnpackLeaf(p, t); + } + void PackLeaf(char* buffer, const T& data, size_t computedSize) const { + NPackers::TPacker<T>().PackLeaf(buffer, data, computedSize); + } + size_t MeasureLeaf(const T& data) const { + return NPackers::TPacker<T>().MeasureLeaf(data); + } + size_t SkipLeaf(const char* p) const // this function better be fast because it is very frequently used + { + return NPackers::TPacker<T>().SkipLeaf(p); + } +}; diff --git a/library/cpp/containers/comptrie/comptrie_trie.h b/library/cpp/containers/comptrie/comptrie_trie.h new file mode 100644 index 00000000000..40ec1e52b32 --- /dev/null +++ b/library/cpp/containers/comptrie/comptrie_trie.h @@ -0,0 +1,663 @@ +#pragma once + +#include "comptrie_impl.h" +#include "comptrie_packer.h" +#include "opaque_trie_iterator.h" +#include "leaf_skipper.h" +#include "key_selector.h" + +#include <util/generic/buffer.h> +#include <util/generic/ptr.h> +#include <util/generic/vector.h> +#include <util/generic/yexception.h> +#include <util/memory/blob.h> +#include <util/stream/input.h> +#include <utility> + +template <class T, class D, class S> +class TCompactTrieBuilder; + +namespace NCompactTrie { + template <class TTrie> + class TFirstSymbolIterator; +} + +template <class TTrie> +class TSearchIterator; + +template <class TTrie> +class TPrefixIterator; + +// in case of <char> specialization cannot distinguish between "" and "\0" keys +template <class T = char, class D = ui64, class S = TCompactTriePacker<D>> +class TCompactTrie { +public: + typedef T TSymbol; + typedef D TData; + typedef S TPacker; + + typedef typename TCompactTrieKeySelector<TSymbol>::TKey TKey; + typedef typename TCompactTrieKeySelector<TSymbol>::TKeyBuf TKeyBuf; + + typedef std::pair<TKey, TData> TValueType; + typedef std::pair<size_t, TData> TPhraseMatch; + typedef TVector<TPhraseMatch> TPhraseMatchVector; + + typedef TCompactTrieBuilder<T, D, S> TBuilder; + +protected: + TBlob DataHolder; + const char* EmptyValue = nullptr; + TPacker Packer; + NCompactTrie::TPackerLeafSkipper<TPacker> Skipper = &Packer; // This should be true for every constructor. 
+ +public: + TCompactTrie() = default; + + TCompactTrie(const char* d, size_t len, TPacker packer); + TCompactTrie(const char* d, size_t len) + : TCompactTrie{d, len, TPacker{}} { + } + + TCompactTrie(const TBlob& data, TPacker packer); + explicit TCompactTrie(const TBlob& data) + : TCompactTrie{data, TPacker{}} { + } + + // Skipper should be initialized with &Packer, not with &other.Packer, so you have to redefine these. + TCompactTrie(const TCompactTrie& other); + TCompactTrie(TCompactTrie&& other) noexcept; + TCompactTrie& operator=(const TCompactTrie& other); + TCompactTrie& operator=(TCompactTrie&& other) noexcept; + + explicit operator bool() const { + return !IsEmpty(); + } + + void Init(const char* d, size_t len, TPacker packer = TPacker()); + void Init(const TBlob& data, TPacker packer = TPacker()); + + bool IsInitialized() const; + bool IsEmpty() const; + + bool Find(const TSymbol* key, size_t keylen, TData* value = nullptr) const; + bool Find(const TKeyBuf& key, TData* value = nullptr) const { + return Find(key.data(), key.size(), value); + } + + TData Get(const TSymbol* key, size_t keylen) const { + TData value; + if (!Find(key, keylen, &value)) + ythrow yexception() << "key " << TKey(key, keylen).Quote() << " not found in trie"; + return value; + } + TData Get(const TKeyBuf& key) const { + return Get(key.data(), key.size()); + } + TData GetDefault(const TKeyBuf& key, const TData& def) const { + TData value; + if (!Find(key.data(), key.size(), &value)) + return def; + else + return value; + } + + const TBlob& Data() const { + return DataHolder; + }; + + const NCompactTrie::ILeafSkipper& GetSkipper() const { + return Skipper; + } + + TPacker GetPacker() const { + return Packer; + } + + bool HasCorrectSkipper() const { + return Skipper.GetPacker() == &Packer; + } + + void FindPhrases(const TSymbol* key, size_t keylen, TPhraseMatchVector& matches, TSymbol separator = TSymbol(' ')) const; + void FindPhrases(const TKeyBuf& key, TPhraseMatchVector& matches, TSymbol separator = TSymbol(' ')) const { + return FindPhrases(key.data(), key.size(), matches, separator); + } + bool FindLongestPrefix(const TSymbol* key, size_t keylen, size_t* prefixLen, TData* value = nullptr, bool* hasNext = nullptr) const; + bool FindLongestPrefix(const TKeyBuf& key, size_t* prefixLen, TData* value = nullptr, bool* hasNext = nullptr) const { + return FindLongestPrefix(key.data(), key.size(), prefixLen, value, hasNext); + } + + // Return trie, containing all tails for the given key + inline TCompactTrie<T, D, S> FindTails(const TSymbol* key, size_t keylen) const; + TCompactTrie<T, D, S> FindTails(const TKeyBuf& key) const { + return FindTails(key.data(), key.size()); + } + bool FindTails(const TSymbol* key, size_t keylen, TCompactTrie<T, D, S>& res) const; + bool FindTails(const TKeyBuf& key, TCompactTrie<T, D, S>& res) const { + return FindTails(key.data(), key.size(), res); + } + + // same as FindTails(&key, 1), a bit faster + // return false, if no arc with @label exists + inline bool FindTails(TSymbol label, TCompactTrie<T, D, S>& res) const; + + class TConstIterator { + private: + typedef NCompactTrie::TOpaqueTrieIterator TOpaqueTrieIterator; + typedef NCompactTrie::TOpaqueTrie TOpaqueTrie; + friend class TCompactTrie; + TConstIterator(const TOpaqueTrie& trie, const char* emptyValue, bool atend, TPacker packer); // only usable from Begin() and End() methods + TConstIterator(const TOpaqueTrie& trie, const char* emptyValue, const TKeyBuf& key, TPacker packer); // only usable from UpperBound() method + + 
public: + TConstIterator() = default; + bool IsEmpty() const { + return !Impl; + } // Almost no other method can be called. + + bool operator==(const TConstIterator& other) const; + bool operator!=(const TConstIterator& other) const; + TConstIterator& operator++(); + TConstIterator operator++(int /*unused*/); + TConstIterator& operator--(); + TConstIterator operator--(int /*unused*/); + TValueType operator*(); + + TKey GetKey() const; + size_t GetKeySize() const; + TData GetValue() const; + void GetValue(TData& data) const; + const char* GetValuePtr() const; + + private: + TPacker Packer; + TCopyPtr<TOpaqueTrieIterator> Impl; + }; + + TConstIterator Begin() const; + TConstIterator begin() const; + TConstIterator End() const; + TConstIterator end() const; + + // Returns an iterator pointing to the smallest key in the trie >= the argument. + // TODO: misleading name. Should be called LowerBound for consistency with stl. + // No. It is the STL that has a misleading name. + // LowerBound of X cannot be greater than X. + TConstIterator UpperBound(const TKeyBuf& key) const; + + void Print(IOutputStream& os); + + size_t Size() const; + + friend class NCompactTrie::TFirstSymbolIterator<TCompactTrie>; + friend class TSearchIterator<TCompactTrie>; + friend class TPrefixIterator<TCompactTrie>; + +protected: + explicit TCompactTrie(const char* emptyValue); + TCompactTrie(const TBlob& data, const char* emptyValue, TPacker packer = TPacker()); + + bool LookupLongestPrefix(const TSymbol* key, size_t keylen, size_t& prefixLen, const char*& valuepos, bool& hasNext) const; + bool LookupLongestPrefix(const TSymbol* key, size_t keylen, size_t& prefixLen, const char*& valuepos) const { + bool hasNext; + return LookupLongestPrefix(key, keylen, prefixLen, valuepos, hasNext); + } + void LookupPhrases(const char* datapos, size_t len, const TSymbol* key, size_t keylen, TVector<TPhraseMatch>& matches, TSymbol separator) const; +}; + +template <class T = char, class D = ui64, class S = TCompactTriePacker<D>> +class TCompactTrieHolder: public TCompactTrie<T, D, S>, NNonCopyable::TNonCopyable { +private: + typedef TCompactTrie<T, D, S> TBase; + TArrayHolder<char> Storage; + +public: + TCompactTrieHolder(IInputStream& is, size_t len); +}; + +//------------------------// +// Implementation section // +//------------------------// + +// TCompactTrie + +template <class T, class D, class S> +TCompactTrie<T, D, S>::TCompactTrie(const TBlob& data, TPacker packer) + : DataHolder(data) + , Packer(packer) +{ + Init(data, packer); +} + +template <class T, class D, class S> +TCompactTrie<T, D, S>::TCompactTrie(const char* d, size_t len, TPacker packer) + : Packer(packer) +{ + Init(d, len, packer); +} + +template <class T, class D, class S> +TCompactTrie<T, D, S>::TCompactTrie(const char* emptyValue) + : EmptyValue(emptyValue) +{ +} + +template <class T, class D, class S> +TCompactTrie<T, D, S>::TCompactTrie(const TBlob& data, const char* emptyValue, TPacker packer) + : DataHolder(data) + , EmptyValue(emptyValue) + , Packer(packer) +{ +} + +template <class T, class D, class S> +TCompactTrie<T, D, S>::TCompactTrie(const TCompactTrie& other) + : DataHolder(other.DataHolder) + , EmptyValue(other.EmptyValue) + , Packer(other.Packer) +{ +} + +template <class T, class D, class S> +TCompactTrie<T, D, S>::TCompactTrie(TCompactTrie&& other) noexcept + : DataHolder(std::move(other.DataHolder)) + , EmptyValue(std::move(other.EmptyValue)) + , Packer(std::move(other.Packer)) +{ +} + +template <class T, class D, class S> +TCompactTrie<T, D, S>& 
TCompactTrie<T, D, S>::operator=(const TCompactTrie& other) { + if (this != &other) { + DataHolder = other.DataHolder; + EmptyValue = other.EmptyValue; + Packer = other.Packer; + } + return *this; +} + +template <class T, class D, class S> +TCompactTrie<T, D, S>& TCompactTrie<T, D, S>::operator=(TCompactTrie&& other) noexcept { + if (this != &other) { + DataHolder = std::move(other.DataHolder); + EmptyValue = std::move(other.EmptyValue); + Packer = std::move(other.Packer); + } + return *this; +} + +template <class T, class D, class S> +void TCompactTrie<T, D, S>::Init(const char* d, size_t len, TPacker packer) { + Init(TBlob::NoCopy(d, len), packer); +} + +template <class T, class D, class S> +void TCompactTrie<T, D, S>::Init(const TBlob& data, TPacker packer) { + using namespace NCompactTrie; + + DataHolder = data; + Packer = packer; + + const char* datapos = DataHolder.AsCharPtr(); + size_t len = DataHolder.Length(); + if (!len) + return; + + const char* const dataend = datapos + len; + + const char* emptypos = datapos; + char flags = LeapByte(emptypos, dataend, 0); + if (emptypos && (flags & MT_FINAL)) { + Y_ASSERT(emptypos <= dataend); + EmptyValue = emptypos; + } +} + +template <class T, class D, class S> +bool TCompactTrie<T, D, S>::IsInitialized() const { + return DataHolder.Data() != nullptr; +} + +template <class T, class D, class S> +bool TCompactTrie<T, D, S>::IsEmpty() const { + return DataHolder.Size() == 0 && EmptyValue == nullptr; +} + +template <class T, class D, class S> +bool TCompactTrie<T, D, S>::Find(const TSymbol* key, size_t keylen, TData* value) const { + size_t prefixLen = 0; + const char* valuepos = nullptr; + bool hasNext; + if (!LookupLongestPrefix(key, keylen, prefixLen, valuepos, hasNext) || prefixLen != keylen) + return false; + if (value) + Packer.UnpackLeaf(valuepos, *value); + return true; +} + +template <class T, class D, class S> +void TCompactTrie<T, D, S>::FindPhrases(const TSymbol* key, size_t keylen, TPhraseMatchVector& matches, TSymbol separator) const { + LookupPhrases(DataHolder.AsCharPtr(), DataHolder.Length(), key, keylen, matches, separator); +} + +template <class T, class D, class S> +inline TCompactTrie<T, D, S> TCompactTrie<T, D, S>::FindTails(const TSymbol* key, size_t keylen) const { + TCompactTrie<T, D, S> ret; + FindTails(key, keylen, ret); + return ret; +} + +template <class T, class D, class S> +bool TCompactTrie<T, D, S>::FindTails(const TSymbol* key, size_t keylen, TCompactTrie<T, D, S>& res) const { + using namespace NCompactTrie; + + size_t len = DataHolder.Length(); + + if (!key || !len) + return false; + + if (!keylen) { + res = *this; + return true; + } + + const char* datastart = DataHolder.AsCharPtr(); + const char* datapos = datastart; + const char* const dataend = datapos + len; + + const TSymbol* keyend = key + keylen; + const char* value = nullptr; + + while (key != keyend) { + T label = *(key++); + if (!NCompactTrie::Advance(datapos, dataend, value, label, Packer)) + return false; + + if (key == keyend) { + if (datapos) { + Y_ASSERT(datapos >= datastart); + res = TCompactTrie<T, D, S>(TBlob::NoCopy(datapos, dataend - datapos), value); + } else { + res = TCompactTrie<T, D, S>(value); + } + return true; + } else if (!datapos) { + return false; // No further way + } + } + + return false; +} + +template <class T, class D, class S> +inline bool TCompactTrie<T, D, S>::FindTails(TSymbol label, TCompactTrie<T, D, S>& res) const { + using namespace NCompactTrie; + + const size_t len = DataHolder.Length(); + if (!len) + return 
false; + + const char* datastart = DataHolder.AsCharPtr(); + const char* dataend = datastart + len; + const char* datapos = datastart; + const char* value = nullptr; + + if (!NCompactTrie::Advance(datapos, dataend, value, label, Packer)) + return false; + + if (datapos) { + Y_ASSERT(datapos >= datastart); + res = TCompactTrie<T, D, S>(TBlob::NoCopy(datapos, dataend - datapos), value); + } else { + res = TCompactTrie<T, D, S>(value); + } + + return true; +} + +template <class T, class D, class S> +typename TCompactTrie<T, D, S>::TConstIterator TCompactTrie<T, D, S>::Begin() const { + NCompactTrie::TOpaqueTrie self(DataHolder.AsCharPtr(), DataHolder.Length(), Skipper); + return TConstIterator(self, EmptyValue, false, Packer); +} + +template <class T, class D, class S> +typename TCompactTrie<T, D, S>::TConstIterator TCompactTrie<T, D, S>::begin() const { + return Begin(); +} + +template <class T, class D, class S> +typename TCompactTrie<T, D, S>::TConstIterator TCompactTrie<T, D, S>::End() const { + NCompactTrie::TOpaqueTrie self(DataHolder.AsCharPtr(), DataHolder.Length(), Skipper); + return TConstIterator(self, EmptyValue, true, Packer); +} + +template <class T, class D, class S> +typename TCompactTrie<T, D, S>::TConstIterator TCompactTrie<T, D, S>::end() const { + return End(); +} + +template <class T, class D, class S> +size_t TCompactTrie<T, D, S>::Size() const { + size_t res = 0; + for (TConstIterator it = Begin(); it != End(); ++it) + ++res; + return res; +} + +template <class T, class D, class S> +typename TCompactTrie<T, D, S>::TConstIterator TCompactTrie<T, D, S>::UpperBound(const TKeyBuf& key) const { + NCompactTrie::TOpaqueTrie self(DataHolder.AsCharPtr(), DataHolder.Length(), Skipper); + return TConstIterator(self, EmptyValue, key, Packer); +} + +template <class T, class D, class S> +void TCompactTrie<T, D, S>::Print(IOutputStream& os) { + typedef typename ::TCompactTrieKeySelector<T>::TKeyBuf TSBuffer; + for (TConstIterator it = Begin(); it != End(); ++it) { + os << TSBuffer((*it).first.data(), (*it).first.size()) << "\t" << (*it).second << Endl; + } +} + +template <class T, class D, class S> +bool TCompactTrie<T, D, S>::FindLongestPrefix(const TSymbol* key, size_t keylen, size_t* prefixLen, TData* value, bool* hasNext) const { + const char* valuepos = nullptr; + size_t tempPrefixLen = 0; + bool tempHasNext; + bool found = LookupLongestPrefix(key, keylen, tempPrefixLen, valuepos, tempHasNext); + if (prefixLen) + *prefixLen = tempPrefixLen; + if (found && value) + Packer.UnpackLeaf(valuepos, *value); + if (hasNext) + *hasNext = tempHasNext; + return found; +} + +template <class T, class D, class S> +bool TCompactTrie<T, D, S>::LookupLongestPrefix(const TSymbol* key, size_t keylen, size_t& prefixLen, const char*& valuepos, bool& hasNext) const { + using namespace NCompactTrie; + + const char* datapos = DataHolder.AsCharPtr(); + size_t len = DataHolder.Length(); + + prefixLen = 0; + hasNext = false; + bool found = false; + + if (EmptyValue) { + valuepos = EmptyValue; + found = true; + } + + if (!key || !len) + return found; + + const char* const dataend = datapos + len; + + const T* keyend = key + keylen; + while (key != keyend) { + T label = *(key++); + for (i64 i = (i64)ExtraBits<TSymbol>(); i >= 0; i -= 8) { + const char flags = LeapByte(datapos, dataend, (char)(label >> i)); + if (!datapos) { + return found; // no such arc + } + + Y_ASSERT(datapos <= dataend); + if ((flags & MT_FINAL)) { + prefixLen = keylen - (keyend - key) - (i ? 
1 : 0); + valuepos = datapos; + hasNext = flags & MT_NEXT; + found = true; + + if (!i && key == keyend) { // last byte, and got a match + return found; + } + datapos += Packer.SkipLeaf(datapos); // skip intermediate leaf nodes + } + + if (!(flags & MT_NEXT)) { + return found; // no further way + } + } + } + + return found; +} + +template <class T, class D, class S> +void TCompactTrie<T, D, S>::LookupPhrases( + const char* datapos, size_t len, const TSymbol* key, size_t keylen, + TVector<TPhraseMatch>& matches, TSymbol separator) const { + using namespace NCompactTrie; + + matches.clear(); + + if (!key || !len) + return; + + const T* const keystart = key; + const T* const keyend = key + keylen; + const char* const dataend = datapos + len; + while (datapos && key != keyend) { + T label = *(key++); + const char* value = nullptr; + if (!Advance(datapos, dataend, value, label, Packer)) { + return; + } + if (value && (key == keyend || *key == separator)) { + size_t matchlength = (size_t)(key - keystart); + D data; + Packer.UnpackLeaf(value, data); + matches.push_back(TPhraseMatch(matchlength, data)); + } + } +} + +// TCompactTrieHolder + +template <class T, class D, class S> +TCompactTrieHolder<T, D, S>::TCompactTrieHolder(IInputStream& is, size_t len) + : Storage(new char[len]) +{ + if (is.Load(Storage.Get(), len) != len) { + ythrow yexception() << "bad data load"; + } + TBase::Init(Storage.Get(), len); +} + +//---------------------------------------------------------------------------------------------------------------- +// TCompactTrie::TConstIterator + +template <class T, class D, class S> +TCompactTrie<T, D, S>::TConstIterator::TConstIterator(const TOpaqueTrie& trie, const char* emptyValue, bool atend, TPacker packer) + : Packer(packer) + , Impl(new TOpaqueTrieIterator(trie, emptyValue, atend)) +{ +} + +template <class T, class D, class S> +TCompactTrie<T, D, S>::TConstIterator::TConstIterator(const TOpaqueTrie& trie, const char* emptyValue, const TKeyBuf& key, TPacker packer) + : Packer(packer) + , Impl(new TOpaqueTrieIterator(trie, emptyValue, true)) +{ + Impl->UpperBound<TSymbol>(key); +} + +template <class T, class D, class S> +bool TCompactTrie<T, D, S>::TConstIterator::operator==(const TConstIterator& other) const { + if (!Impl) + return !other.Impl; + if (!other.Impl) + return false; + return *Impl == *other.Impl; +} + +template <class T, class D, class S> +bool TCompactTrie<T, D, S>::TConstIterator::operator!=(const TConstIterator& other) const { + return !operator==(other); +} + +template <class T, class D, class S> +typename TCompactTrie<T, D, S>::TConstIterator& TCompactTrie<T, D, S>::TConstIterator::operator++() { + Impl->Forward(); + return *this; +} + +template <class T, class D, class S> +typename TCompactTrie<T, D, S>::TConstIterator TCompactTrie<T, D, S>::TConstIterator::operator++(int /*unused*/) { + TConstIterator copy(*this); + Impl->Forward(); + return copy; +} + +template <class T, class D, class S> +typename TCompactTrie<T, D, S>::TConstIterator& TCompactTrie<T, D, S>::TConstIterator::operator--() { + Impl->Backward(); + return *this; +} + +template <class T, class D, class S> +typename TCompactTrie<T, D, S>::TConstIterator TCompactTrie<T, D, S>::TConstIterator::operator--(int /*unused*/) { + TConstIterator copy(*this); + Impl->Backward(); + return copy; +} + +template <class T, class D, class S> +typename TCompactTrie<T, D, S>::TValueType TCompactTrie<T, D, S>::TConstIterator::operator*() { + return TValueType(GetKey(), GetValue()); +} + +template <class T, class D, 
class S> +typename TCompactTrie<T, D, S>::TKey TCompactTrie<T, D, S>::TConstIterator::GetKey() const { + return Impl->GetKey<TSymbol>(); +} + +template <class T, class D, class S> +size_t TCompactTrie<T, D, S>::TConstIterator::GetKeySize() const { + return Impl->MeasureKey<TSymbol>(); +} + +template <class T, class D, class S> +const char* TCompactTrie<T, D, S>::TConstIterator::GetValuePtr() const { + return Impl->GetValuePtr(); +} + +template <class T, class D, class S> +typename TCompactTrie<T, D, S>::TData TCompactTrie<T, D, S>::TConstIterator::GetValue() const { + D data; + GetValue(data); + return data; +} + +template <class T, class D, class S> +void TCompactTrie<T, D, S>::TConstIterator::GetValue(typename TCompactTrie<T, D, S>::TData& data) const { + const char* ptr = GetValuePtr(); + if (ptr) { + Packer.UnpackLeaf(ptr, data); + } else { + data = typename TCompactTrie<T, D, S>::TData(); + } +} diff --git a/library/cpp/containers/comptrie/comptrie_ut.cpp b/library/cpp/containers/comptrie/comptrie_ut.cpp new file mode 100644 index 00000000000..74bee09b5d6 --- /dev/null +++ b/library/cpp/containers/comptrie/comptrie_ut.cpp @@ -0,0 +1,1791 @@ +#include <util/random/shuffle.h> +#include <library/cpp/testing/unittest/registar.h> + +#include <util/stream/output.h> +#include <utility> + +#include <util/charset/wide.h> +#include <util/generic/algorithm.h> +#include <util/generic/buffer.h> +#include <util/generic/map.h> +#include <util/generic/vector.h> +#include <util/generic/ptr.h> +#include <util/generic/ylimits.h> + +#include <util/folder/dirut.h> + +#include <util/random/random.h> +#include <util/random/fast.h> + +#include <util/string/hex.h> +#include <util/string/cast.h> + +#include "comptrie.h" +#include "set.h" +#include "first_symbol_iterator.h" +#include "search_iterator.h" +#include "pattern_searcher.h" + +#include <array> +#include <iterator> + + +class TCompactTrieTest: public TTestBase { +private: + UNIT_TEST_SUITE(TCompactTrieTest); + UNIT_TEST(TestTrie8); + UNIT_TEST(TestTrie16); + UNIT_TEST(TestTrie32); + + UNIT_TEST(TestFastTrie8); + UNIT_TEST(TestFastTrie16); + UNIT_TEST(TestFastTrie32); + + UNIT_TEST(TestMinimizedTrie8); + UNIT_TEST(TestMinimizedTrie16); + UNIT_TEST(TestMinimizedTrie32); + + UNIT_TEST(TestFastMinimizedTrie8); + UNIT_TEST(TestFastMinimizedTrie16); + UNIT_TEST(TestFastMinimizedTrie32); + + UNIT_TEST(TestTrieIterator8); + UNIT_TEST(TestTrieIterator16); + UNIT_TEST(TestTrieIterator32); + + UNIT_TEST(TestMinimizedTrieIterator8); + UNIT_TEST(TestMinimizedTrieIterator16); + UNIT_TEST(TestMinimizedTrieIterator32); + + UNIT_TEST(TestPhraseSearch); + UNIT_TEST(TestAddGet); + UNIT_TEST(TestEmpty); + UNIT_TEST(TestUninitializedNonEmpty); + UNIT_TEST(TestRandom); + UNIT_TEST(TestFindTails); + UNIT_TEST(TestPrefixGrouped); + UNIT_TEST(CrashTestPrefixGrouped); + UNIT_TEST(TestMergeFromFile); + UNIT_TEST(TestMergeFromBuffer); + UNIT_TEST(TestUnique); + UNIT_TEST(TestAddRetValue); + UNIT_TEST(TestClear); + + UNIT_TEST(TestIterateEmptyKey); + + UNIT_TEST(TestTrieSet); + + UNIT_TEST(TestTrieForVectorInt64); + UNIT_TEST(TestTrieForListInt64); + UNIT_TEST(TestTrieForSetInt64); + + UNIT_TEST(TestTrieForVectorStroka); + UNIT_TEST(TestTrieForListStroka); + UNIT_TEST(TestTrieForSetStroka); + + UNIT_TEST(TestTrieForVectorWtroka); + UNIT_TEST(TestTrieForVectorFloat); + UNIT_TEST(TestTrieForVectorDouble); + + UNIT_TEST(TestTrieForListVectorInt64); + UNIT_TEST(TestTrieForPairWtrokaVectorInt64); + + UNIT_TEST(TestEmptyValueOutOfOrder); + UNIT_TEST(TestFindLongestPrefixWithEmptyValue); 
+ + UNIT_TEST(TestSearchIterChar); + UNIT_TEST(TestSearchIterWchar); + UNIT_TEST(TestSearchIterWchar32) + + UNIT_TEST(TestCopyAndAssignment); + + UNIT_TEST(TestFirstSymbolIterator8); + UNIT_TEST(TestFirstSymbolIterator16); + UNIT_TEST(TestFirstSymbolIterator32); + UNIT_TEST(TestFirstSymbolIteratorChar32); + + UNIT_TEST(TestArrayPacker); + + UNIT_TEST(TestBuilderFindLongestPrefix); + UNIT_TEST(TestBuilderFindLongestPrefixWithEmptyValue); + + UNIT_TEST(TestPatternSearcherEmpty); + UNIT_TEST(TestPatternSearcherSimple); + UNIT_TEST(TestPatternSearcherRandom); + + UNIT_TEST_SUITE_END(); + + static const char* SampleData[]; + + template <class T> + void CreateTrie(IOutputStream& out, bool minimize, bool useFastLayout); + + template <class T> + void CheckData(const char* src, size_t len); + + template <class T> + void CheckUpperBound(const char* src, size_t len); + + template <class T> + void CheckIterator(const char* src, size_t len); + + template <class T> + void TestTrie(bool minimize, bool useFastLayout); + + template <class T> + void TestTrieIterator(bool minimize); + + template <class T, bool minimize> + void TestRandom(const size_t n, const size_t maxKeySize); + + void TestFindTailsImpl(const TString& prefix); + + void TestUniqueImpl(bool isPrefixGrouped); + + TVector<TUtf16String> GetSampleKeys(size_t nKeys) const; + template <class TContainer> + TVector<TContainer> GetSampleVectorData(size_t nValues); + template <class TContainer> + TVector<TContainer> GetSampleTextVectorData(size_t nValues); + template <class T> + void CheckEquality(const T& value1, const T& value2) const; + template <class TContainer> + void TestTrieWithContainers(const TVector<TUtf16String>& keys, const TVector<TContainer>& sampleData, TString methodName); + + template <typename TChar> + void TestSearchIterImpl(); + + template <class TTrie> + void TestFirstSymbolIteratorForTrie(const TTrie& trie, const TStringBuf& narrowAnswers); + + template <typename TSymbol> + void TestFirstSymbolIterator(); + + template <class T> + class TIntPacker; + template <class T> + class TDummyPacker; + class TStrokaPacker; + +public: + void TestPackers(); + + void TestTrie8(); + void TestTrie16(); + void TestTrie32(); + + void TestFastTrie8(); + void TestFastTrie16(); + void TestFastTrie32(); + + void TestMinimizedTrie8(); + void TestMinimizedTrie16(); + void TestMinimizedTrie32(); + + void TestFastMinimizedTrie8(); + void TestFastMinimizedTrie16(); + void TestFastMinimizedTrie32(); + + void TestTrieIterator8(); + void TestTrieIterator16(); + void TestTrieIterator32(); + + void TestMinimizedTrieIterator8(); + void TestMinimizedTrieIterator16(); + void TestMinimizedTrieIterator32(); + + void TestPhraseSearch(); + void TestAddGet(); + void TestEmpty(); + void TestUninitializedNonEmpty(); + void TestRandom(); + void TestFindTails(); + void TestPrefixGrouped(); + void CrashTestPrefixGrouped(); + void TestMergeFromFile(); + void TestMergeFromBuffer(); + void TestUnique(); + void TestAddRetValue(); + void TestClear(); + + void TestIterateEmptyKey(); + + void TestTrieSet(); + + void TestTrieForVectorInt64(); + void TestTrieForListInt64(); + void TestTrieForSetInt64(); + + void TestTrieForVectorStroka(); + void TestTrieForListStroka(); + void TestTrieForSetStroka(); + + void TestTrieForVectorWtroka(); + void TestTrieForVectorFloat(); + void TestTrieForVectorDouble(); + + void TestTrieForListVectorInt64(); + void TestTrieForPairWtrokaVectorInt64(); + + void TestEmptyValueOutOfOrder(); + void TestFindLongestPrefixWithEmptyValue(); + + void 
TestSearchIterChar(); + void TestSearchIterWchar(); + void TestSearchIterWchar32(); + + void TestCopyAndAssignment(); + + void TestFirstSymbolIterator8(); + void TestFirstSymbolIterator16(); + void TestFirstSymbolIterator32(); + void TestFirstSymbolIteratorChar32(); + + void TestArrayPacker(); + + void TestBuilderFindLongestPrefix(); + void TestBuilderFindLongestPrefix(size_t keysCount, double branchProbability, bool isPrefixGrouped, bool hasEmptyKey); + void TestBuilderFindLongestPrefixWithEmptyValue(); + + void TestPatternSearcherOnDataset( + const TVector<TString>& patterns, + const TVector<TString>& samples + ); + void TestPatternSearcherEmpty(); + void TestPatternSearcherSimple(); + void TestPatternSearcherRandom( + size_t patternsNum, + size_t patternMaxLength, + size_t strMaxLength, + int maxChar, + TFastRng<ui64>& rng + ); + void TestPatternSearcherRandom(); +}; + +UNIT_TEST_SUITE_REGISTRATION(TCompactTrieTest); + +const char* TCompactTrieTest::SampleData[] = { + "", + "a", "b", "c", "d", + "aa", "ab", "ac", "ad", + "aaa", "aab", "aac", "aad", + "aba", "abb", "abc", "abd", + "fba", "fbb", "fbc", "fbd", + "fbbaa", + "c\x85\xA4\xBF" // Just something outside ASCII. +}; + +template <class T> +typename TCompactTrie<T>::TKey MakeWideKey(const char* str, size_t len) { + typename TCompactTrie<T>::TKey buffer; + for (size_t i = 0; i < len; i++) { + unsigned int ch = (str[i] & 0xFF); + buffer.push_back((T)(ch | ch << 8 | ch << 16 | ch << 24)); + } + return buffer; +} + +template <class T> +typename TCompactTrie<T>::TKey MakeWideKey(const TString& str) { + return MakeWideKey<T>(str.c_str(), str.length()); +} + +template <class T> +typename TCompactTrie<T>::TKey MakeWideKey(const TStringBuf& buf) { + return MakeWideKey<T>(buf.data(), buf.length()); +} + +template <class T> +void TCompactTrieTest::CreateTrie(IOutputStream& out, bool minimize, bool useFastLayout) { + TCompactTrieBuilder<T> builder; + + for (auto& i : SampleData) { + size_t len = strlen(i); + + builder.Add(MakeWideKey<T>(i, len), len * 2); + } + + TBufferOutput tmp2; + IOutputStream& currentOutput = useFastLayout ? tmp2 : out; + if (minimize) { + TBufferOutput buftmp; + builder.Save(buftmp); + CompactTrieMinimize<TCompactTriePacker<ui64>>(currentOutput, buftmp.Buffer().Data(), buftmp.Buffer().Size(), false); + } else { + builder.Save(currentOutput); + } + if (useFastLayout) { + CompactTrieMakeFastLayout<TCompactTriePacker<T>>(out, tmp2.Buffer().Data(), tmp2.Buffer().Size(), false); + } +} + +// Iterates over all strings of length <= 4 made of letters a-g. 
+static bool LexicographicStep(TString& s) { + if (s.length() < 4) { + s += "a"; + return true; + } + while (!s.empty() && s.back() == 'g') + s.pop_back(); + if (s.empty()) + return false; + char last = s.back(); + last++; + s.pop_back(); + s.push_back(last); + return true; +} + +template <class T> +void TCompactTrieTest::CheckUpperBound(const char* data, size_t datalen) { + TCompactTrie<T> trie(data, datalen); + typedef typename TCompactTrie<T>::TKey TKey; + typedef typename TCompactTrie<T>::TData TData; + + TString key; + do { + const TKey wideKey = MakeWideKey<T>(key); + typename TCompactTrie<T>::TConstIterator it = trie.UpperBound(wideKey); + UNIT_ASSERT_C(it == trie.End() || it.GetKey() >= wideKey, "key=" + key); + TData data; + const bool found = trie.Find(wideKey, &data); + if (found) + UNIT_ASSERT_C(it.GetKey() == wideKey && it.GetValue() == data, "key=" + key); + if (it != trie.Begin()) + UNIT_ASSERT_C((--it).GetKey() < wideKey, "key=" + key); + } while (LexicographicStep(key)); +} + +template <class T> +void TCompactTrieTest::CheckData(const char* data, size_t datalen) { + TCompactTrie<T> trie(data, datalen); + + UNIT_ASSERT_VALUES_EQUAL(Y_ARRAY_SIZE(SampleData), trie.Size()); + + for (auto& i : SampleData) { + size_t len = strlen(i); + ui64 value = 0; + size_t prefixLen = 0; + + typename TCompactTrie<T>::TKey key = MakeWideKey<T>(i, len); + UNIT_ASSERT(trie.Find(key, &value)); + UNIT_ASSERT_EQUAL(len * 2, value); + UNIT_ASSERT(trie.FindLongestPrefix(key, &prefixLen, &value)); + UNIT_ASSERT_EQUAL(len, prefixLen); + UNIT_ASSERT_EQUAL(len * 2, value); + + TString badkey("bb"); + badkey += i; + key = MakeWideKey<T>(badkey); + UNIT_ASSERT(!trie.Find(key)); + value = 123; + UNIT_ASSERT(!trie.Find(key, &value)); + UNIT_ASSERT_EQUAL(123, value); + UNIT_ASSERT(trie.FindLongestPrefix(key, &prefixLen, &value)); + UNIT_ASSERT_EQUAL(1, prefixLen); + UNIT_ASSERT_EQUAL(2, value); + + badkey = i; + badkey += "x"; + key = MakeWideKey<T>(badkey); + UNIT_ASSERT(!trie.Find(key)); + value = 1234; + UNIT_ASSERT(!trie.Find(key, &value)); + UNIT_ASSERT_EQUAL(1234, value); + UNIT_ASSERT(trie.FindLongestPrefix(key, &prefixLen, &value)); + UNIT_ASSERT_EQUAL(len, prefixLen); + UNIT_ASSERT_EQUAL(len * 2, value); + UNIT_ASSERT(trie.FindLongestPrefix(key, &prefixLen, nullptr)); + UNIT_ASSERT_EQUAL(len, prefixLen); + } + + TString testkey("fbbaa"); + typename TCompactTrie<T>::TKey key = MakeWideKey<T>(testkey); + ui64 value = 0; + size_t prefixLen = 0; + UNIT_ASSERT(trie.FindLongestPrefix(key.data(), testkey.length() - 1, &prefixLen, &value)); + UNIT_ASSERT_EQUAL(prefixLen, 3); + UNIT_ASSERT_EQUAL(6, value); + + testkey = "fbbax"; + key = MakeWideKey<T>(testkey); + UNIT_ASSERT(trie.FindLongestPrefix(key, &prefixLen, &value)); + UNIT_ASSERT_EQUAL(prefixLen, 3); + UNIT_ASSERT_EQUAL(6, value); + + value = 12345678; + UNIT_ASSERT(!trie.Find(key, &value)); + UNIT_ASSERT_EQUAL(12345678, value); //Failed Find() should not change value +} + +template <class T> +void TCompactTrieTest::CheckIterator(const char* data, size_t datalen) { + typedef typename TCompactTrie<T>::TKey TKey; + typedef typename TCompactTrie<T>::TValueType TValue; + TMap<TKey, ui64> stored; + + for (auto& i : SampleData) { + size_t len = strlen(i); + + stored[MakeWideKey<T>(i, len)] = len * 2; + } + + TCompactTrie<T> trie(data, datalen); + TVector<TValue> items; + typename TCompactTrie<T>::TConstIterator it = trie.Begin(); + size_t entry_count = 0; + TMap<TKey, ui64> received; + while (it != trie.End()) { + UNIT_ASSERT_VALUES_EQUAL(it.GetKeySize(), 
it.GetKey().size()); + received.insert(*it); + items.push_back(*it); + entry_count++; + it++; + } + TMap<TKey, ui64> received2; + for (std::pair<TKey, ui64> x : trie) { + received2.insert(x); + } + UNIT_ASSERT(entry_count == stored.size()); + UNIT_ASSERT(received == stored); + UNIT_ASSERT(received2 == stored); + + std::reverse(items.begin(), items.end()); + typename TCompactTrie<T>::TConstIterator revIt = trie.End(); + typename TCompactTrie<T>::TConstIterator const begin = trie.Begin(); + typename TCompactTrie<T>::TConstIterator emptyIt; + size_t pos = 0; + while (revIt != begin) { + revIt--; + UNIT_ASSERT(*revIt == items[pos]); + pos++; + } + // Checking the assignment operator. + revIt = begin; + UNIT_ASSERT(revIt == trie.Begin()); + UNIT_ASSERT(!revIt.IsEmpty()); + UNIT_ASSERT(revIt != emptyIt); + UNIT_ASSERT(revIt != trie.End()); + ++revIt; // Call a method that uses Skipper. + revIt = emptyIt; + UNIT_ASSERT(revIt == emptyIt); + UNIT_ASSERT(revIt.IsEmpty()); + UNIT_ASSERT(revIt != trie.End()); + // Checking the move assignment operator. + revIt = trie.Begin(); + UNIT_ASSERT(revIt == trie.Begin()); + UNIT_ASSERT(!revIt.IsEmpty()); + UNIT_ASSERT(revIt != emptyIt); + UNIT_ASSERT(revIt != trie.End()); + ++revIt; // Call a method that uses Skipper. + revIt = typename TCompactTrie<T>::TConstIterator(); + UNIT_ASSERT(revIt == emptyIt); + UNIT_ASSERT(revIt.IsEmpty()); + UNIT_ASSERT(revIt != trie.End()); +} + +template <class T> +void TCompactTrieTest::TestTrie(bool minimize, bool useFastLayout) { + TBufferOutput bufout; + CreateTrie<T>(bufout, minimize, useFastLayout); + CheckData<T>(bufout.Buffer().Data(), bufout.Buffer().Size()); + CheckUpperBound<T>(bufout.Buffer().Data(), bufout.Buffer().Size()); +} + +template <class T> +void TCompactTrieTest::TestTrieIterator(bool minimize) { + TBufferOutput bufout; + CreateTrie<T>(bufout, minimize, false); + CheckIterator<T>(bufout.Buffer().Data(), bufout.Buffer().Size()); +} + +void TCompactTrieTest::TestTrie8() { + TestTrie<char>(false, false); +} +void TCompactTrieTest::TestTrie16() { + TestTrie<wchar16>(false, false); +} +void TCompactTrieTest::TestTrie32() { + TestTrie<wchar32>(false, false); +} + +void TCompactTrieTest::TestFastTrie8() { + TestTrie<char>(false, true); +} +void TCompactTrieTest::TestFastTrie16() { + TestTrie<wchar16>(false, true); +} +void TCompactTrieTest::TestFastTrie32() { + TestTrie<wchar32>(false, true); +} + +void TCompactTrieTest::TestMinimizedTrie8() { + TestTrie<char>(true, false); +} +void TCompactTrieTest::TestMinimizedTrie16() { + TestTrie<wchar16>(true, false); +} +void TCompactTrieTest::TestMinimizedTrie32() { + TestTrie<wchar32>(true, false); +} + +void TCompactTrieTest::TestFastMinimizedTrie8() { + TestTrie<char>(true, true); +} +void TCompactTrieTest::TestFastMinimizedTrie16() { + TestTrie<wchar16>(true, true); +} +void TCompactTrieTest::TestFastMinimizedTrie32() { + TestTrie<wchar32>(true, true); +} + +void TCompactTrieTest::TestTrieIterator8() { + TestTrieIterator<char>(false); +} +void TCompactTrieTest::TestTrieIterator16() { + TestTrieIterator<wchar16>(false); +} +void TCompactTrieTest::TestTrieIterator32() { + TestTrieIterator<wchar32>(false); +} + +void TCompactTrieTest::TestMinimizedTrieIterator8() { + TestTrieIterator<char>(true); +} +void TCompactTrieTest::TestMinimizedTrieIterator16() { + TestTrieIterator<wchar16>(true); +} +void TCompactTrieTest::TestMinimizedTrieIterator32() { + TestTrieIterator<wchar32>(true); +} + +void TCompactTrieTest::TestPhraseSearch() { + static const char* phrases[] = {"ab", "ab 
cd", "ab cd ef"}; + static const char* const goodphrase = "ab cd ef gh"; + static const char* const badphrase = "cd ef gh ab"; + TBufferOutput bufout; + + TCompactTrieBuilder<char> builder; + for (size_t i = 0; i < Y_ARRAY_SIZE(phrases); i++) { + builder.Add(phrases[i], strlen(phrases[i]), i); + } + builder.Save(bufout); + + TCompactTrie<char> trie(bufout.Buffer().Data(), bufout.Buffer().Size()); + TVector<TCompactTrie<char>::TPhraseMatch> matches; + trie.FindPhrases(goodphrase, strlen(goodphrase), matches); + + UNIT_ASSERT(matches.size() == Y_ARRAY_SIZE(phrases)); + for (size_t i = 0; i < Y_ARRAY_SIZE(phrases); i++) { + UNIT_ASSERT(matches[i].first == strlen(phrases[i])); + UNIT_ASSERT(matches[i].second == i); + } + + trie.FindPhrases(badphrase, strlen(badphrase), matches); + UNIT_ASSERT(matches.size() == 0); +} + +void TCompactTrieTest::TestAddGet() { + TCompactTrieBuilder<char> builder; + builder.Add("abcd", 4, 1); + builder.Add("acde", 4, 2); + ui64 dummy; + UNIT_ASSERT(builder.Find("abcd", 4, &dummy)); + UNIT_ASSERT(1 == dummy); + UNIT_ASSERT(builder.Find("acde", 4, &dummy)); + UNIT_ASSERT(2 == dummy); + UNIT_ASSERT(!builder.Find("fgdgfacde", 9, &dummy)); + UNIT_ASSERT(!builder.Find("ab", 2, &dummy)); +} + +void TCompactTrieTest::TestEmpty() { + TCompactTrieBuilder<char> builder; + ui64 dummy = 12345; + size_t prefixLen; + UNIT_ASSERT(!builder.Find("abc", 3, &dummy)); + TBufferOutput bufout; + builder.Save(bufout); + + TCompactTrie<char> trie(bufout.Buffer().Data(), bufout.Buffer().Size()); + UNIT_ASSERT(!trie.Find("abc", 3, &dummy)); + UNIT_ASSERT(!trie.Find("", 0, &dummy)); + UNIT_ASSERT(!trie.FindLongestPrefix("abc", 3, &prefixLen, &dummy)); + UNIT_ASSERT(!trie.FindLongestPrefix("", 0, &prefixLen, &dummy)); + UNIT_ASSERT_EQUAL(12345, dummy); + + UNIT_ASSERT(trie.Begin() == trie.End()); + + TCompactTrie<> trieNull; + + UNIT_ASSERT(!trieNull.Find(" ", 1)); + + TCompactTrie<>::TPhraseMatchVector matches; + trieNull.FindPhrases(" ", 1, matches); // just to be sure it doesn't crash + + UNIT_ASSERT(trieNull.Begin() == trieNull.End()); +} + +void TCompactTrieTest::TestUninitializedNonEmpty() { + TBufferOutput bufout; + CreateTrie<char>(bufout, false, false); + TCompactTrie<char> trie(bufout.Buffer().Data(), bufout.Buffer().Size()); + typedef TCompactTrie<char>::TKey TKey; + typedef TCompactTrie<char>::TConstIterator TIter; + + TCompactTrie<char> tails = trie.FindTails("abd", 3); // A trie that has empty value and no data. 
+ UNIT_ASSERT(!tails.IsEmpty()); + UNIT_ASSERT(!tails.IsInitialized()); + const TKey wideKey = MakeWideKey<char>("c", 1); + TIter it = tails.UpperBound(wideKey); + UNIT_ASSERT(it == tails.End()); + UNIT_ASSERT(it != tails.Begin()); + --it; + UNIT_ASSERT(it == tails.Begin()); + ++it; + UNIT_ASSERT(it == tails.End()); +} + +static char RandChar() { + return char(RandomNumber<size_t>() % 256); +} + +static TString RandStr(const size_t max) { + size_t len = RandomNumber<size_t>() % max; + TString key; + for (size_t j = 0; j < len; ++j) + key += RandChar(); + return key; +} + +template <class T, bool minimize> +void TCompactTrieTest::TestRandom(const size_t n, const size_t maxKeySize) { + const TStringBuf EMPTY_KEY = TStringBuf("", 1); + TCompactTrieBuilder<char, typename T::TData, T> builder; + typedef TMap<TString, typename T::TData> TKeys; + TKeys keys; + + typename T::TData dummy; + for (size_t i = 0; i < n; ++i) { + const TString key = RandStr(maxKeySize); + if (key != EMPTY_KEY && keys.find(key) == keys.end()) { + const typename T::TData val = T::Data(key); + keys[key] = val; + UNIT_ASSERT_C(!builder.Find(key.data(), key.size(), &dummy), "key = " << HexEncode(TString(key))); + builder.Add(key.data(), key.size(), val); + UNIT_ASSERT_C(builder.Find(key.data(), key.size(), &dummy), "key = " << HexEncode(TString(key))); + UNIT_ASSERT(dummy == val); + } + } + + TBufferStream stream; + size_t len = builder.Save(stream); + TCompactTrie<char, typename T::TData, T> trie(stream.Buffer().Data(), len); + + TBufferStream buftmp; + if (minimize) { + CompactTrieMinimize<T>(buftmp, stream.Buffer().Data(), len, false); + } + TCompactTrie<char, typename T::TData, T> trieMin(buftmp.Buffer().Data(), buftmp.Buffer().Size()); + + TCompactTrieBuilder<char, typename T::TData, T> prefixGroupedBuilder(CTBF_PREFIX_GROUPED); + + for (typename TKeys::const_iterator i = keys.begin(), mi = keys.end(); i != mi; ++i) { + UNIT_ASSERT(!prefixGroupedBuilder.Find(i->first.c_str(), i->first.size(), &dummy)); + UNIT_ASSERT(trie.Find(i->first.c_str(), i->first.size(), &dummy)); + UNIT_ASSERT(dummy == i->second); + if (minimize) { + UNIT_ASSERT(trieMin.Find(i->first.c_str(), i->first.size(), &dummy)); + UNIT_ASSERT(dummy == i->second); + } + + prefixGroupedBuilder.Add(i->first.c_str(), i->first.size(), dummy); + UNIT_ASSERT(prefixGroupedBuilder.Find(i->first.c_str(), i->first.size(), &dummy)); + + for (typename TKeys::const_iterator j = keys.begin(), end = keys.end(); j != end; ++j) { + typename T::TData valFound; + if (j->first <= i->first) { + UNIT_ASSERT(prefixGroupedBuilder.Find(j->first.c_str(), j->first.size(), &valFound)); + UNIT_ASSERT_VALUES_EQUAL(j->second, valFound); + } else { + UNIT_ASSERT(!prefixGroupedBuilder.Find(j->first.c_str(), j->first.size(), &valFound)); + } + } + } + + TBufferStream prefixGroupedBuffer; + prefixGroupedBuilder.Save(prefixGroupedBuffer); + + UNIT_ASSERT_VALUES_EQUAL(stream.Buffer().Size(), prefixGroupedBuffer.Buffer().Size()); + UNIT_ASSERT(0 == memcmp(stream.Buffer().Data(), prefixGroupedBuffer.Buffer().Data(), stream.Buffer().Size())); +} + +void TCompactTrieTest::TestRandom() { + TestRandom<TIntPacker<ui64>, true>(1000, 1000); + TestRandom<TIntPacker<int>, true>(100, 100); + TestRandom<TDummyPacker<ui64>, true>(0, 0); + TestRandom<TDummyPacker<ui64>, true>(100, 3); + TestRandom<TDummyPacker<ui64>, true>(100, 100); + TestRandom<TStrokaPacker, true>(100, 100); +} + +void TCompactTrieTest::TestFindTailsImpl(const TString& prefix) { + TCompactTrieBuilder<> builder; + + TMap<TString, ui64> 
input; + + for (auto& i : SampleData) { + TString temp = i; + ui64 val = temp.size() * 2; + builder.Add(temp.data(), temp.size(), val); + if (temp.StartsWith(prefix)) { + input[temp.substr(prefix.size())] = val; + } + } + + typedef TCompactTrie<> TTrie; + + TBufferStream stream; + size_t len = builder.Save(stream); + TTrie trie(stream.Buffer().Data(), len); + + TTrie subtrie = trie.FindTails(prefix.data(), prefix.size()); + + TMap<TString, ui64> output; + + for (TTrie::TConstIterator i = subtrie.Begin(), mi = subtrie.End(); i != mi; ++i) { + TTrie::TValueType val = *i; + output[TString(val.first.data(), val.first.size())] = val.second; + } + UNIT_ASSERT(input.size() == output.size()); + UNIT_ASSERT(input == output); + + TBufferStream buftmp; + CompactTrieMinimize<TTrie::TPacker>(buftmp, stream.Buffer().Data(), len, false); + TTrie trieMin(buftmp.Buffer().Data(), buftmp.Buffer().Size()); + + subtrie = trieMin.FindTails(prefix.data(), prefix.size()); + output.clear(); + + for (TTrie::TConstIterator i = subtrie.Begin(), mi = subtrie.End(); i != mi; ++i) { + TTrie::TValueType val = *i; + output[TString(val.first.data(), val.first.size())] = val.second; + } + UNIT_ASSERT(input.size() == output.size()); + UNIT_ASSERT(input == output); +} + +void TCompactTrieTest::TestPrefixGrouped() { + TBuffer b1b; + TCompactTrieBuilder<char, ui32> b1(CTBF_PREFIX_GROUPED); + const char* data[] = { + "Kazan", + "Moscow", + "Monino", + "Murmansk", + "Fryanovo", + "Fryazino", + "Fryazevo", + "Tumen", + }; + + for (size_t i = 0; i < Y_ARRAY_SIZE(data); ++i) { + ui32 val = strlen(data[i]) + 1; + b1.Add(data[i], strlen(data[i]), val); + for (size_t j = 0; j < Y_ARRAY_SIZE(data); ++j) { + ui32 mustHave = strlen(data[j]) + 1; + ui32 found = 0; + if (j <= i) { + UNIT_ASSERT(b1.Find(data[j], strlen(data[j]), &found)); + UNIT_ASSERT_VALUES_EQUAL(mustHave, found); + } else { + UNIT_ASSERT(!b1.Find(data[j], strlen(data[j]), &found)); + } + } + } + + { + TBufferOutput b1bo(b1b); + b1.Save(b1bo); + } + + TCompactTrie<char, ui32> t1(TBlob::FromBuffer(b1b)); + + //t1.Print(Cerr); + + for (auto& i : data) { + ui32 v; + UNIT_ASSERT(t1.Find(i, strlen(i), &v)); + UNIT_ASSERT_VALUES_EQUAL(strlen(i) + 1, v); + } +} + +void TCompactTrieTest::CrashTestPrefixGrouped() { + TCompactTrieBuilder<char, ui32> builder(CTBF_PREFIX_GROUPED); + const char* data[] = { + "Fryazino", + "Fryanovo", + "Monino", + "", + "Fryazevo", + }; + bool wasException = false; + try { + for (size_t i = 0; i < Y_ARRAY_SIZE(data); ++i) { + builder.Add(data[i], strlen(data[i]), i + 1); + } + } catch (const yexception& e) { + wasException = true; + UNIT_ASSERT(strstr(e.what(), "Bad input order - expected input strings to be prefix-grouped.")); + } + UNIT_ASSERT_C(wasException, "CrashTestPrefixGrouped"); +} + +void TCompactTrieTest::TestMergeFromFile() { + { + TCompactTrieBuilder<> b; + b.Add("yandex", 12); + b.Add("google", 13); + b.Add("mail", 14); + TUnbufferedFileOutput out(GetSystemTempDir() + "/TCompactTrieTest-TestMerge-ru"); + b.Save(out); + } + + { + TCompactTrieBuilder<> b; + b.Add("yandex", 112); + b.Add("google", 113); + b.Add("yahoo", 114); + TUnbufferedFileOutput out(GetSystemTempDir() + "/TCompactTrieTest-TestMerge-com"); + b.Save(out); + } + + { + TCompactTrieBuilder<> b; + UNIT_ASSERT(b.AddSubtreeInFile("com.", GetSystemTempDir() + "/TCompactTrieTest-TestMerge-com")); + UNIT_ASSERT(b.Add("org.kernel", 22)); + UNIT_ASSERT(b.AddSubtreeInFile("ru.", GetSystemTempDir() + "/TCompactTrieTest-TestMerge-ru")); + TUnbufferedFileOutput out(GetSystemTempDir() + 
"/TCompactTrieTest-TestMerge-res"); + b.Save(out); + } + + TCompactTrie<> trie(TBlob::FromFileSingleThreaded(GetSystemTempDir() + "/TCompactTrieTest-TestMerge-res")); + UNIT_ASSERT_VALUES_EQUAL(12u, trie.Get("ru.yandex")); + UNIT_ASSERT_VALUES_EQUAL(13u, trie.Get("ru.google")); + UNIT_ASSERT_VALUES_EQUAL(14u, trie.Get("ru.mail")); + UNIT_ASSERT_VALUES_EQUAL(22u, trie.Get("org.kernel")); + UNIT_ASSERT_VALUES_EQUAL(112u, trie.Get("com.yandex")); + UNIT_ASSERT_VALUES_EQUAL(113u, trie.Get("com.google")); + UNIT_ASSERT_VALUES_EQUAL(114u, trie.Get("com.yahoo")); + + unlink((GetSystemTempDir() + "/TCompactTrieTest-TestMerge-res").data()); + unlink((GetSystemTempDir() + "/TCompactTrieTest-TestMerge-com").data()); + unlink((GetSystemTempDir() + "/TCompactTrieTest-TestMerge-ru").data()); +} + +void TCompactTrieTest::TestMergeFromBuffer() { + TArrayWithSizeHolder<char> buffer1; + { + TCompactTrieBuilder<> b; + b.Add("aaaaa", 1); + b.Add("bbbbb", 2); + b.Add("ccccc", 3); + buffer1.Resize(b.MeasureByteSize()); + TMemoryOutput out(buffer1.Get(), buffer1.Size()); + b.Save(out); + } + + TArrayWithSizeHolder<char> buffer2; + { + TCompactTrieBuilder<> b; + b.Add("aaaaa", 10); + b.Add("bbbbb", 20); + b.Add("ccccc", 30); + b.Add("xxxxx", 40); + b.Add("yyyyy", 50); + buffer2.Resize(b.MeasureByteSize()); + TMemoryOutput out(buffer2.Get(), buffer2.Size()); + b.Save(out); + } + + { + TCompactTrieBuilder<> b; + UNIT_ASSERT(b.AddSubtreeInBuffer("com.", std::move(buffer1))); + UNIT_ASSERT(b.Add("org.upyachka", 42)); + UNIT_ASSERT(b.AddSubtreeInBuffer("ru.", std::move(buffer2))); + TUnbufferedFileOutput out(GetSystemTempDir() + "/TCompactTrieTest-TestMergeFromBuffer-res"); + b.Save(out); + } + + TCompactTrie<> trie(TBlob::FromFileSingleThreaded(GetSystemTempDir() + "/TCompactTrieTest-TestMergeFromBuffer-res")); + UNIT_ASSERT_VALUES_EQUAL(10u, trie.Get("ru.aaaaa")); + UNIT_ASSERT_VALUES_EQUAL(20u, trie.Get("ru.bbbbb")); + UNIT_ASSERT_VALUES_EQUAL(40u, trie.Get("ru.xxxxx")); + UNIT_ASSERT_VALUES_EQUAL(42u, trie.Get("org.upyachka")); + UNIT_ASSERT_VALUES_EQUAL(1u, trie.Get("com.aaaaa")); + UNIT_ASSERT_VALUES_EQUAL(2u, trie.Get("com.bbbbb")); + UNIT_ASSERT_VALUES_EQUAL(3u, trie.Get("com.ccccc")); + + unlink((GetSystemTempDir() + "/TCompactTrieTest-TestMergeFromBuffer-res").data()); +} + +void TCompactTrieTest::TestUnique() { + TestUniqueImpl(false); + TestUniqueImpl(true); +} + +void TCompactTrieTest::TestUniqueImpl(bool isPrefixGrouped) { + TCompactTrieBuilder<char, ui32> builder(CTBF_UNIQUE | (isPrefixGrouped ? 
CTBF_PREFIX_GROUPED : CTBF_NONE)); + const char* data[] = { + "Kazan", + "Moscow", + "Monino", + "Murmansk", + "Fryanovo", + "Fryazino", + "Fryazevo", + "Fry", + "Tumen", + }; + for (size_t i = 0; i < Y_ARRAY_SIZE(data); ++i) { + UNIT_ASSERT_C(builder.Add(data[i], strlen(data[i]), i + 1), i); + } + bool wasException = false; + try { + builder.Add(data[4], strlen(data[4]), 20); + } catch (const yexception& e) { + wasException = true; + UNIT_ASSERT(strstr(e.what(), "Duplicate key")); + } + UNIT_ASSERT_C(wasException, "TestUnique"); +} + +void TCompactTrieTest::TestAddRetValue() { + TCompactTrieBuilder<char, ui32> builder; + const char* data[] = { + "Kazan", + "Moscow", + "Monino", + "Murmansk", + "Fryanovo", + "Fryazino", + "Fryazevo", + "Fry", + "Tumen", + }; + for (size_t i = 0; i < Y_ARRAY_SIZE(data); ++i) { + UNIT_ASSERT(builder.Add(data[i], strlen(data[i]), i + 1)); + UNIT_ASSERT(!builder.Add(data[i], strlen(data[i]), i + 2)); + ui32 value; + UNIT_ASSERT(builder.Find(data[i], strlen(data[i]), &value)); + UNIT_ASSERT(value == i + 2); + } +} + +void TCompactTrieTest::TestClear() { + TCompactTrieBuilder<char, ui32> builder; + const char* data[] = { + "Kazan", + "Moscow", + "Monino", + "Murmansk", + "Fryanovo", + "Fryazino", + "Fryazevo", + "Fry", + "Tumen", + }; + for (size_t i = 0; i < Y_ARRAY_SIZE(data); ++i) { + builder.Add(data[i], strlen(data[i]), i + 1); + } + UNIT_ASSERT(builder.GetEntryCount() == Y_ARRAY_SIZE(data)); + builder.Clear(); + UNIT_ASSERT(builder.GetEntryCount() == 0); + UNIT_ASSERT(builder.GetNodeCount() == 1); +} + +void TCompactTrieTest::TestFindTails() { + TestFindTailsImpl("aa"); + TestFindTailsImpl("bb"); + TestFindTailsImpl("fb"); + TestFindTailsImpl("fbc"); + TestFindTailsImpl("fbbaa"); +} + +template <class T> +class TCompactTrieTest::TDummyPacker: public TNullPacker<T> { +public: + static T Data(const TString&) { + T data; + TNullPacker<T>().UnpackLeaf(nullptr, data); + return data; + } + + typedef T TData; +}; + +class TCompactTrieTest::TStrokaPacker: public TCompactTriePacker<TString> { +public: + typedef TString TData; + + static TString Data(const TString& str) { + return str; + } +}; + +template <class T> +class TCompactTrieTest::TIntPacker: public TCompactTriePacker<T> { +public: + typedef T TData; + + static TData Data(const TString&) { + return RandomNumber<std::make_unsigned_t<T>>(); + } +}; + +void TCompactTrieTest::TestIterateEmptyKey() { + TBuffer trieBuffer; + { + TCompactTrieBuilder<char, ui32> builder; + UNIT_ASSERT(builder.Add("", 1)); + TBufferStream trieBufferO(trieBuffer); + builder.Save(trieBufferO); + } + TCompactTrie<char, ui32> trie(TBlob::FromBuffer(trieBuffer)); + ui32 val; + UNIT_ASSERT(trie.Find("", &val)); + UNIT_ASSERT(val == 1); + TCompactTrie<char, ui32>::TConstIterator it = trie.Begin(); + UNIT_ASSERT(it.GetKey().empty()); + UNIT_ASSERT(it.GetValue() == 1); +} + +void TCompactTrieTest::TestTrieSet() { + TBuffer buffer; + { + TCompactTrieSet<char>::TBuilder builder; + UNIT_ASSERT(builder.Add("a", 0)); + UNIT_ASSERT(builder.Add("ab", 1)); + UNIT_ASSERT(builder.Add("abc", 1)); + UNIT_ASSERT(builder.Add("abcd", 0)); + UNIT_ASSERT(!builder.Add("abcd", 1)); + + TBufferStream stream(buffer); + builder.Save(stream); + } + + TCompactTrieSet<char> set(TBlob::FromBuffer(buffer)); + UNIT_ASSERT(set.Has("a")); + UNIT_ASSERT(set.Has("ab")); + UNIT_ASSERT(set.Has("abc")); + UNIT_ASSERT(set.Has("abcd")); + UNIT_ASSERT(!set.Has("abcde")); + UNIT_ASSERT(!set.Has("aa")); + UNIT_ASSERT(!set.Has("b")); + UNIT_ASSERT(!set.Has("")); + + 
TCompactTrieSet<char> tails; + UNIT_ASSERT(set.FindTails("a", tails)); + UNIT_ASSERT(tails.Has("b")); + UNIT_ASSERT(tails.Has("bcd")); + UNIT_ASSERT(!tails.Has("ab")); + UNIT_ASSERT(!set.Has("")); + + TCompactTrieSet<char> empty; + UNIT_ASSERT(set.FindTails("abcd", empty)); + UNIT_ASSERT(!empty.Has("a")); + UNIT_ASSERT(!empty.Has("b")); + UNIT_ASSERT(!empty.Has("c")); + UNIT_ASSERT(!empty.Has("d")); + UNIT_ASSERT(!empty.Has("d")); + + UNIT_ASSERT(empty.Has("")); // contains only empty string +} + +// Tests for trie with vector (list, set) values + +TVector<TUtf16String> TCompactTrieTest::GetSampleKeys(size_t nKeys) const { + Y_ASSERT(nKeys <= 10); + TString sampleKeys[] = {"a", "b", "ac", "bd", "abe", "bcf", "deg", "ah", "xy", "abc"}; + TVector<TUtf16String> result; + for (size_t i = 0; i < nKeys; i++) + result.push_back(ASCIIToWide(sampleKeys[i])); + return result; +} + +template <class TContainer> +TVector<TContainer> TCompactTrieTest::GetSampleVectorData(size_t nValues) { + TVector<TContainer> data; + for (size_t i = 0; i < nValues; i++) { + data.push_back(TContainer()); + for (size_t j = 0; j < i; j++) + data[i].insert(data[i].end(), (typename TContainer::value_type)((j == 3) ? 0 : (1 << (j * 5)))); + } + return data; +} + +template <class TContainer> +TVector<TContainer> TCompactTrieTest::GetSampleTextVectorData(size_t nValues) { + TVector<TContainer> data; + for (size_t i = 0; i < nValues; i++) { + data.push_back(TContainer()); + for (size_t j = 0; j < i; j++) + data[i].insert(data[i].end(), TString("abc") + ToString<size_t>(j)); + } + return data; +} + +template <class T> +void TCompactTrieTest::CheckEquality(const T& value1, const T& value2) const { + UNIT_ASSERT_VALUES_EQUAL(value1, value2); +} + +template <> +void TCompactTrieTest::CheckEquality<TVector<i64>>(const TVector<i64>& value1, const TVector<i64>& value2) const { + UNIT_ASSERT_VALUES_EQUAL(value1.size(), value2.size()); + for (size_t i = 0; i < value1.size(); i++) + UNIT_ASSERT_VALUES_EQUAL(value1[i], value2[i]); +} + +template <class TContainer> +void TCompactTrieTest::TestTrieWithContainers(const TVector<TUtf16String>& keys, const TVector<TContainer>& sampleData, TString methodName) { + TString fileName = GetSystemTempDir() + "/TCompactTrieTest-TestTrieWithContainers-" + methodName; + + TCompactTrieBuilder<wchar16, TContainer> b; + for (size_t i = 0; i < keys.size(); i++) { + b.Add(keys[i], sampleData[i]); + } + TUnbufferedFileOutput out(fileName); + b.Save(out); + + TCompactTrie<wchar16, TContainer> trie(TBlob::FromFileSingleThreaded(fileName)); + for (size_t i = 0; i < keys.size(); i++) { + TContainer value = trie.Get(keys[i]); + UNIT_ASSERT_VALUES_EQUAL(value.size(), sampleData[i].size()); + typename TContainer::const_iterator p = value.begin(); + typename TContainer::const_iterator p1 = sampleData[i].begin(); + for (; p != value.end(); p++, p1++) + CheckEquality<typename TContainer::value_type>(*p, *p1); + } + + unlink(fileName.data()); +} + +template <> +void TCompactTrieTest::TestTrieWithContainers<std::pair<TUtf16String, TVector<i64>>>(const TVector<TUtf16String>& keys, const TVector<std::pair<TUtf16String, TVector<i64>>>& sampleData, TString methodName) { + typedef std::pair<TUtf16String, TVector<i64>> TContainer; + TString fileName = GetSystemTempDir() + "/TCompactTrieTest-TestTrieWithContainers-" + methodName; + + TCompactTrieBuilder<wchar16, TContainer> b; + for (size_t i = 0; i < keys.size(); i++) { + b.Add(keys[i], sampleData[i]); + } + TUnbufferedFileOutput out(fileName); + b.Save(out); + + 
TCompactTrie<wchar16, TContainer> trie(TBlob::FromFileSingleThreaded(fileName)); + for (size_t i = 0; i < keys.size(); i++) { + TContainer value = trie.Get(keys[i]); + CheckEquality<TContainer::first_type>(value.first, sampleData[i].first); + CheckEquality<TContainer::second_type>(value.second, sampleData[i].second); + } + + unlink(fileName.data()); +} + +void TCompactTrieTest::TestTrieForVectorInt64() { + TestTrieWithContainers<TVector<i64>>(GetSampleKeys(10), GetSampleVectorData<TVector<i64>>(10), "v-i64"); +} + +void TCompactTrieTest::TestTrieForListInt64() { + TestTrieWithContainers<TList<i64>>(GetSampleKeys(10), GetSampleVectorData<TList<i64>>(10), "l-i64"); +} + +void TCompactTrieTest::TestTrieForSetInt64() { + TestTrieWithContainers<TSet<i64>>(GetSampleKeys(10), GetSampleVectorData<TSet<i64>>(10), "s-i64"); +} + +void TCompactTrieTest::TestTrieForVectorStroka() { + TestTrieWithContainers<TVector<TString>>(GetSampleKeys(10), GetSampleTextVectorData<TVector<TString>>(10), "v-str"); +} + +void TCompactTrieTest::TestTrieForListStroka() { + TestTrieWithContainers<TList<TString>>(GetSampleKeys(10), GetSampleTextVectorData<TList<TString>>(10), "l-str"); +} + +void TCompactTrieTest::TestTrieForSetStroka() { + TestTrieWithContainers<TSet<TString>>(GetSampleKeys(10), GetSampleTextVectorData<TSet<TString>>(10), "s-str"); +} + +void TCompactTrieTest::TestTrieForVectorWtroka() { + TVector<TVector<TString>> data = GetSampleTextVectorData<TVector<TString>>(10); + TVector<TVector<TUtf16String>> wData; + for (size_t i = 0; i < data.size(); i++) { + wData.push_back(TVector<TUtf16String>()); + for (size_t j = 0; j < data[i].size(); j++) + wData[i].push_back(UTF8ToWide(data[i][j])); + } + TestTrieWithContainers<TVector<TUtf16String>>(GetSampleKeys(10), wData, "v-wtr"); +} + +void TCompactTrieTest::TestTrieForVectorFloat() { + TestTrieWithContainers<TVector<float>>(GetSampleKeys(10), GetSampleVectorData<TVector<float>>(10), "v-float"); +} + +void TCompactTrieTest::TestTrieForVectorDouble() { + TestTrieWithContainers<TVector<double>>(GetSampleKeys(10), GetSampleVectorData<TVector<double>>(10), "v-double"); +} + +void TCompactTrieTest::TestTrieForListVectorInt64() { + TVector<i64> tmp; + tmp.push_back(0); + TList<TVector<i64>> dataElement(5, tmp); + TVector<TList<TVector<i64>>> data(10, dataElement); + TestTrieWithContainers<TList<TVector<i64>>>(GetSampleKeys(10), data, "l-v-i64"); +} + +void TCompactTrieTest::TestTrieForPairWtrokaVectorInt64() { + TVector<TUtf16String> keys = GetSampleKeys(10); + TVector<TVector<i64>> values = GetSampleVectorData<TVector<i64>>(10); + TVector<std::pair<TUtf16String, TVector<i64>>> data; + for (size_t i = 0; i < 10; i++) + data.push_back(std::pair<TUtf16String, TVector<i64>>(keys[i] + u"_v", values[i])); + TestTrieWithContainers<std::pair<TUtf16String, TVector<i64>>>(keys, data, "pair-str-v-i64"); +} + +void TCompactTrieTest::TestEmptyValueOutOfOrder() { + TBufferOutput buffer; + using TSymbol = ui32; + { + TCompactTrieBuilder<TSymbol, ui32> builder; + TSymbol key = 1; + builder.Add(&key, 1, 10); + builder.Add(nullptr, 0, 14); + builder.Save(buffer); + } + { + TCompactTrie<TSymbol, ui32> trie(buffer.Buffer().Data(), buffer.Buffer().Size()); + UNIT_ASSERT(trie.Find(nullptr, 0)); + } +} + +void TCompactTrieTest::TestFindLongestPrefixWithEmptyValue() { + TBufferOutput buffer; + { + TCompactTrieBuilder<wchar16, ui32> builder; + builder.Add(u"", 42); + builder.Add(u"yandex", 271828); + builder.Add(u"ya", 31415); + builder.Save(buffer); + } + { + TCompactTrie<wchar16, ui32> 
trie(buffer.Buffer().Data(), buffer.Buffer().Size()); + size_t prefixLen = 123; + ui32 value = 0; + + UNIT_ASSERT(trie.FindLongestPrefix(u"google", &prefixLen, &value)); + UNIT_ASSERT(prefixLen == 0); + UNIT_ASSERT(value == 42); + + UNIT_ASSERT(trie.FindLongestPrefix(u"yahoo", &prefixLen, &value)); + UNIT_ASSERT(prefixLen == 2); + UNIT_ASSERT(value == 31415); + } +} + +template <typename TChar> +struct TConvertKey { + static inline TString Convert(const TStringBuf& key) { + return ToString(key); + } +}; + +template <> +struct TConvertKey<wchar16> { + static inline TUtf16String Convert(const TStringBuf& key) { + return UTF8ToWide(key); + } +}; + +template <> +struct TConvertKey<wchar32> { + static inline TUtf32String Convert(const TStringBuf& key) { + return TUtf32String::FromUtf8(key); + } +}; + +template <class TSearchIter, class TKeyBuf> +static void MoveIter(TSearchIter& iter, const TKeyBuf& key) { + for (size_t i = 0; i < key.length(); ++i) { + UNIT_ASSERT(iter.Advance(key[i])); + } +} + +template <typename TChar> +void TCompactTrieTest::TestSearchIterImpl() { + TBufferOutput buffer; + { + TCompactTrieBuilder<TChar, ui32> builder; + TStringBuf data[] = { + TStringBuf("abaab"), + TStringBuf("abcdef"), + TStringBuf("abbbc"), + TStringBuf("bdfaa"), + }; + for (size_t i = 0; i < Y_ARRAY_SIZE(data); ++i) { + builder.Add(TConvertKey<TChar>::Convert(data[i]), i + 1); + } + builder.Save(buffer); + } + + TCompactTrie<TChar, ui32> trie(buffer.Buffer().Data(), buffer.Buffer().Size()); + ui32 value = 0; + auto iter(MakeSearchIterator(trie)); + MoveIter(iter, TConvertKey<TChar>::Convert(TStringBuf("abc"))); + UNIT_ASSERT(!iter.GetValue(&value)); + + iter = MakeSearchIterator(trie); + MoveIter(iter, TConvertKey<TChar>::Convert(TStringBuf("abbbc"))); + UNIT_ASSERT(iter.GetValue(&value)); + UNIT_ASSERT_EQUAL(value, 3); + + iter = MakeSearchIterator(trie); + UNIT_ASSERT(iter.Advance(TConvertKey<TChar>::Convert(TStringBuf("bdfa")))); + UNIT_ASSERT(!iter.GetValue(&value)); + + iter = MakeSearchIterator(trie); + UNIT_ASSERT(iter.Advance(TConvertKey<TChar>::Convert(TStringBuf("bdfaa")))); + UNIT_ASSERT(iter.GetValue(&value)); + UNIT_ASSERT_EQUAL(value, 4); + + UNIT_ASSERT(!MakeSearchIterator(trie).Advance(TChar('z'))); + UNIT_ASSERT(!MakeSearchIterator(trie).Advance(TConvertKey<TChar>::Convert(TStringBuf("cdf")))); + UNIT_ASSERT(!MakeSearchIterator(trie).Advance(TConvertKey<TChar>::Convert(TStringBuf("abca")))); +} + +void TCompactTrieTest::TestSearchIterChar() { + TestSearchIterImpl<char>(); +} + +void TCompactTrieTest::TestSearchIterWchar() { + TestSearchIterImpl<wchar16>(); +} + +void TCompactTrieTest::TestSearchIterWchar32() { + TestSearchIterImpl<wchar32>(); +} + +void TCompactTrieTest::TestCopyAndAssignment() { + TBufferOutput bufout; + typedef TCompactTrie<> TTrie; + CreateTrie<char>(bufout, false, false); + TTrie trie(bufout.Buffer().Data(), bufout.Buffer().Size()); + TTrie copy(trie); + UNIT_ASSERT(copy.HasCorrectSkipper()); + TTrie assign; + assign = trie; + UNIT_ASSERT(assign.HasCorrectSkipper()); + TTrie move(std::move(trie)); + UNIT_ASSERT(move.HasCorrectSkipper()); + TTrie moveAssign; + moveAssign = TTrie(bufout.Buffer().Data(), bufout.Buffer().Size()); + UNIT_ASSERT(moveAssign.HasCorrectSkipper()); +} + +template <class TTrie> +void TCompactTrieTest::TestFirstSymbolIteratorForTrie(const TTrie& trie, const TStringBuf& narrowAnswers) { + NCompactTrie::TFirstSymbolIterator<TTrie> it; + it.SetTrie(trie, trie.GetSkipper()); + typename TTrie::TKey answers = MakeWideKey<typename 
TTrie::TSymbol>(narrowAnswers); + auto answer = answers.begin(); + for (; !it.AtEnd(); it.MakeStep(), ++answer) { + UNIT_ASSERT(answer != answers.end()); + UNIT_ASSERT(it.GetKey() == *answer); + } + UNIT_ASSERT(answer == answers.end()); +} + +template <class TSymbol> +void TCompactTrieTest::TestFirstSymbolIterator() { + TBufferOutput bufout; + typedef TCompactTrie<TSymbol> TTrie; + CreateTrie<TSymbol>(bufout, false, false); + TTrie trie(bufout.Buffer().Data(), bufout.Buffer().Size()); + TStringBuf rootAnswers = "abcdf"; + TestFirstSymbolIteratorForTrie(trie, rootAnswers); + TStringBuf aAnswers = "abcd"; + TestFirstSymbolIteratorForTrie(trie.FindTails(MakeWideKey<TSymbol>("a", 1)), aAnswers); +} + +void TCompactTrieTest::TestFirstSymbolIterator8() { + TestFirstSymbolIterator<char>(); +} + +void TCompactTrieTest::TestFirstSymbolIterator16() { + TestFirstSymbolIterator<wchar16>(); +} + +void TCompactTrieTest::TestFirstSymbolIterator32() { + TestFirstSymbolIterator<ui32>(); +} + +void TCompactTrieTest::TestFirstSymbolIteratorChar32() { + TestFirstSymbolIterator<wchar32>(); +} + + +void TCompactTrieTest::TestArrayPacker() { + using TDataInt = std::array<int, 2>; + const std::pair<TString, TDataInt> dataXxx{"xxx", {{15, 16}}}; + const std::pair<TString, TDataInt> dataYyy{"yyy", {{20, 30}}}; + + TCompactTrieBuilder<char, TDataInt> trieBuilderOne; + trieBuilderOne.Add(dataXxx.first, dataXxx.second); + trieBuilderOne.Add(dataYyy.first, dataYyy.second); + + TBufferOutput bufferOne; + trieBuilderOne.Save(bufferOne); + + const TCompactTrie<char, TDataInt> trieOne(bufferOne.Buffer().Data(), bufferOne.Buffer().Size()); + UNIT_ASSERT_VALUES_EQUAL(dataXxx.second, trieOne.Get(dataXxx.first)); + UNIT_ASSERT_VALUES_EQUAL(dataYyy.second, trieOne.Get(dataYyy.first)); + + using TDataStroka = std::array<TString, 2>; + const std::pair<TString, TDataStroka> dataZzz{"zzz", {{"hello", "there"}}}; + const std::pair<TString, TDataStroka> dataWww{"www", {{"half", "life"}}}; + + TCompactTrieBuilder<char, TDataStroka> trieBuilderTwo; + trieBuilderTwo.Add(dataZzz.first, dataZzz.second); + trieBuilderTwo.Add(dataWww.first, dataWww.second); + + TBufferOutput bufferTwo; + trieBuilderTwo.Save(bufferTwo); + + const TCompactTrie<char, TDataStroka> trieTwo(bufferTwo.Buffer().Data(), bufferTwo.Buffer().Size()); + UNIT_ASSERT_VALUES_EQUAL(dataZzz.second, trieTwo.Get(dataZzz.first)); + UNIT_ASSERT_VALUES_EQUAL(dataWww.second, trieTwo.Get(dataWww.first)); +} + +void TCompactTrieTest::TestBuilderFindLongestPrefix() { + const size_t sizes[] = {10, 100}; + const double branchProbabilities[] = {0.01, 0.1, 0.5, 0.9, 0.99}; + for (size_t size : sizes) { + for (double branchProbability : branchProbabilities) { + TestBuilderFindLongestPrefix(size, branchProbability, false, false); + TestBuilderFindLongestPrefix(size, branchProbability, false, true); + TestBuilderFindLongestPrefix(size, branchProbability, true, false); + TestBuilderFindLongestPrefix(size, branchProbability, true, true); + } + } +} + +void TCompactTrieTest::TestBuilderFindLongestPrefix(size_t keysCount, double branchProbability, bool isPrefixGrouped, bool hasEmptyKey) { + TVector<TString> keys; + TString keyToAdd; + for (size_t i = 0; i < keysCount; ++i) { + const size_t prevKeyLen = keyToAdd.Size(); + // add two random chars to prev key + keyToAdd += RandChar(); + keyToAdd += RandChar(); + const bool changeBranch = prevKeyLen && RandomNumber<double>() < branchProbability; + if (changeBranch) { + const size_t branchPlace = RandomNumber<size_t>(prevKeyLen + 1); // random place 
in [0, prevKeyLen] + *(keyToAdd.begin() + branchPlace) = RandChar(); + } + keys.push_back(keyToAdd); + } + + if (isPrefixGrouped) + Sort(keys.begin(), keys.end()); + else + Shuffle(keys.begin(), keys.end()); + + TCompactTrieBuilder<char, TString> builder(isPrefixGrouped ? CTBF_PREFIX_GROUPED : CTBF_NONE); + const TString EMPTY_VALUE = "empty"; + if (hasEmptyKey) + builder.Add(nullptr, 0, EMPTY_VALUE); + + for (size_t i = 0; i < keysCount; ++i) { + const TString& key = keys[i]; + + for (size_t j = 0; j < keysCount; ++j) { + const TString& otherKey = keys[j]; + const bool exists = j < i; + size_t expectedSize = 0; + if (exists) { + expectedSize = otherKey.size(); + } else { + size_t max = 0; + for (size_t k = 0; k < i; ++k) + if (keys[k].Size() < otherKey.Size() && keys[k].Size() > max && otherKey.StartsWith(keys[k])) + max = keys[k].Size(); + expectedSize = max; + } + + size_t prefixSize = 0xfcfcfc; + TString value = "abcd"; + const bool expectedResult = hasEmptyKey || expectedSize != 0; + UNIT_ASSERT_VALUES_EQUAL_C(expectedResult, builder.FindLongestPrefix(otherKey.data(), otherKey.size(), &prefixSize, &value), "otherKey = " << HexEncode(otherKey)); + if (expectedResult) { + UNIT_ASSERT_VALUES_EQUAL(expectedSize, prefixSize); + if (expectedSize) { + UNIT_ASSERT_VALUES_EQUAL(TStringBuf(otherKey).SubStr(0, prefixSize), value); + } else { + UNIT_ASSERT_VALUES_EQUAL(EMPTY_VALUE, value); + } + } else { + UNIT_ASSERT_VALUES_EQUAL("abcd", value); + UNIT_ASSERT_VALUES_EQUAL(0xfcfcfc, prefixSize); + } + + for (int c = 0; c < 10; ++c) { + TString extendedKey = otherKey; + extendedKey += RandChar(); + size_t extendedPrefixSize = 0xdddddd; + TString extendedValue = "dcba"; + UNIT_ASSERT_VALUES_EQUAL(expectedResult, builder.FindLongestPrefix(extendedKey.data(), extendedKey.size(), &extendedPrefixSize, &extendedValue)); + if (expectedResult) { + UNIT_ASSERT_VALUES_EQUAL(value, extendedValue); + UNIT_ASSERT_VALUES_EQUAL(prefixSize, extendedPrefixSize); + } else { + UNIT_ASSERT_VALUES_EQUAL("dcba", extendedValue); + UNIT_ASSERT_VALUES_EQUAL(0xdddddd, extendedPrefixSize); + } + } + } + builder.Add(key.data(), key.size(), key); + } + + TBufferOutput buffer; + builder.Save(buffer); +} + +void TCompactTrieTest::TestBuilderFindLongestPrefixWithEmptyValue() { + TCompactTrieBuilder<wchar16, ui32> builder; + builder.Add(u"", 42); + builder.Add(u"yandex", 271828); + builder.Add(u"ya", 31415); + + size_t prefixLen = 123; + ui32 value = 0; + + UNIT_ASSERT(builder.FindLongestPrefix(u"google", &prefixLen, &value)); + UNIT_ASSERT_VALUES_EQUAL(prefixLen, 0); + UNIT_ASSERT_VALUES_EQUAL(value, 42); + + UNIT_ASSERT(builder.FindLongestPrefix(u"yahoo", &prefixLen, &value)); + UNIT_ASSERT_VALUES_EQUAL(prefixLen, 2); + UNIT_ASSERT_VALUES_EQUAL(value, 31415); + + TBufferOutput buffer; + builder.Save(buffer); +} + +void TCompactTrieTest::TestPatternSearcherEmpty() { + TCompactPatternSearcherBuilder<char, ui32> builder; + + TBufferOutput searcherData; + builder.Save(searcherData); + + TCompactPatternSearcher<char, ui32> searcher( + searcherData.Buffer().Data(), + searcherData.Buffer().Size() + ); + + UNIT_ASSERT(searcher.SearchMatches("a").empty()); + UNIT_ASSERT(searcher.SearchMatches("").empty()); + UNIT_ASSERT(searcher.SearchMatches("abc").empty()); +} + +void TCompactTrieTest::TestPatternSearcherOnDataset( + const TVector<TString>& patterns, + const TVector<TString>& samples +) { + TCompactPatternSearcherBuilder<char, ui32> builder; + + for (size_t patternIdx = 0; patternIdx < patterns.size(); ++patternIdx) { + 
builder.Add(patterns[patternIdx], patternIdx); + } + + TBufferOutput searcherData; + builder.Save(searcherData); + + TCompactPatternSearcher<char, ui32> searcher( + searcherData.Buffer().Data(), + searcherData.Buffer().Size() + ); + + for (const auto& sample : samples) { + const auto matches = searcher.SearchMatches(sample); + + size_t matchesNum = 0; + THashSet<TString> processedPatterns; + for (const auto& pattern : patterns) { + if (pattern.Empty() || processedPatterns.contains(pattern)) { + continue; + } + for (size_t start = 0; start + pattern.Size() <= sample.Size(); ++start) { + matchesNum += (pattern == sample.substr(start, pattern.Size())); + } + processedPatterns.insert(pattern); + } + UNIT_ASSERT_VALUES_EQUAL(matchesNum, matches.size()); + + + TSet<std::pair<size_t, ui32>> foundMatches; + for (const auto& match : matches) { + std::pair<size_t, ui32> matchParams(match.End, match.Data); + UNIT_ASSERT(!foundMatches.contains(matchParams)); + foundMatches.insert(matchParams); + + const auto& pattern = patterns[match.Data]; + UNIT_ASSERT_VALUES_EQUAL( + sample.substr(match.End - pattern.size() + 1, pattern.size()), + pattern + ); + } + } +} + +void TCompactTrieTest::TestPatternSearcherSimple() { + TestPatternSearcherOnDataset( + { // patterns + "abcd", + "abc", + "ab", + "a", + "" + }, + { // samples + "abcde", + "abcd", + "abc", + "ab", + "a", + "" + } + ); + TestPatternSearcherOnDataset( + { // patterns + "a" + "ab", + "abcd", + }, + { // samples + "abcde", + "abcd", + "abc", + "ab", + "a", + "" + } + ); + TestPatternSearcherOnDataset( + { // patterns + "aaaa", + "aaa", + "aa", + "a", + }, + { // samples + "aaaaaaaaaaaa" + } + ); + TestPatternSearcherOnDataset( + { // patterns + "aa", "ab", "ac", "ad", "ae", "af", + "ba", "bb", "bc", "bd", "be", "bf", + "ca", "cb", "cc", "cd", "ce", "cf", + "da", "db", "dc", "dd", "de", "df", + "ea", "eb", "ec", "ed", "ee", "ef", + "fa", "fb", "fc", "fd", "fe", "ff" + }, + { // samples + "dcabafeebfdcbacddacadbaabecdbaeffecdbfabcdcabcfaefecdfebacfedacefbdcacfeb", + "abcdefabcdefabcdefabcdefabcdefabcdefabcdefabcdefabcdefabcdefancdefancdef", + "fedcbafedcbafedcbafedcbafedcbafedcbafedcbafedcbafedcbafedcbafedcbafedcba", + "", + "a", "b", "c", "d", "e", "f", + "aa", "ab", "ac", "ad", "ae", "af", + "ba", "bb", "bc", "bd", "be", "bf", + "ca", "cb", "cc", "cd", "ce", "cf", + "da", "db", "dc", "dd", "de", "df", + "ea", "eb", "ec", "ed", "ee", "ef", + "fa", "fb", "fc", "fd", "fe", "ff" + } + ); +} + +static char RandChar( + TFastRng<ui64>& rng, + int maxChar +) { + return static_cast<char>(rng.GenRand() % (maxChar + 1)); +} + +static TString RandStr( + TFastRng<ui64>& rng, + size_t maxLength, + int maxChar, + bool nonEmpty = false +) { + Y_ASSERT(maxLength > 0); + + size_t length; + if (nonEmpty) { + length = rng.GenRand() % maxLength + 1; + } else { + length = rng.GenRand() % (maxLength + 1); + } + + TString result; + while (result.size() < length) { + result += RandChar(rng, maxChar); + } + + return result; +} + +void TCompactTrieTest::TestPatternSearcherRandom( + size_t patternsNum, + size_t patternMaxLength, + size_t strMaxLength, + int maxChar, + TFastRng<ui64>& rng +) { + auto patternToSearch = RandStr(rng, patternMaxLength, maxChar, /*nonEmpty*/true); + + TVector<TString> patterns = {patternToSearch}; + while (patterns.size() < patternsNum) { + patterns.push_back(RandStr(rng, patternMaxLength, maxChar, /*nonEmpty*/true)); + } + + auto filler = RandStr(rng, strMaxLength - patternToSearch.Size() + 1, maxChar); + size_t leftFillerSize = rng.GenRand() % 
(filler.size() + 1); + auto leftFiller = filler.substr(0, leftFillerSize); + auto rightFiller = filler.substr(leftFillerSize, filler.size() - leftFillerSize); + auto sample = leftFiller + patternToSearch + rightFiller; + + TestPatternSearcherOnDataset(patterns, {sample}); +} + +void TCompactTrieTest::TestPatternSearcherRandom() { + TFastRng<ui64> rng(0); + for (size_t patternMaxLen : {1, 2, 10}) { + for (size_t strMaxLen : TVector<size_t>{patternMaxLen, 2 * patternMaxLen, 10}) { + for (int maxChar : {0, 1, 5, 255}) { + for (size_t patternsNum : {1, 10}) { + for (size_t testIdx = 0; testIdx < 3; ++testIdx) { + TestPatternSearcherRandom( + patternsNum, + patternMaxLen, + strMaxLen, + maxChar, + rng + ); + } + } + } + } + } +} diff --git a/library/cpp/containers/comptrie/first_symbol_iterator.h b/library/cpp/containers/comptrie/first_symbol_iterator.h new file mode 100644 index 00000000000..d06135f06f2 --- /dev/null +++ b/library/cpp/containers/comptrie/first_symbol_iterator.h @@ -0,0 +1,61 @@ +#pragma once + +#include "opaque_trie_iterator.h" +#include <util/generic/ptr.h> + +namespace NCompactTrie { + // Iterates over possible first symbols in a trie. + // Allows one to get the symbol and the subtrie starting from it. + template <class TTrie> + class TFirstSymbolIterator { + public: + using TSymbol = typename TTrie::TSymbol; + using TData = typename TTrie::TData; + + void SetTrie(const TTrie& trie, const ILeafSkipper& skipper) { + Trie = trie; + Impl.Reset(new TOpaqueTrieIterator( + TOpaqueTrie(Trie.Data().AsCharPtr(), Trie.Data().Size(), skipper), + nullptr, + false, + sizeof(TSymbol))); + if (Impl->MeasureKey<TSymbol>() == 0) { + MakeStep(); + } + } + + const TTrie& GetTrie() const { + return Trie; + } + + bool AtEnd() const { + return *Impl == TOpaqueTrieIterator(Impl->GetTrie(), nullptr, true, sizeof(TSymbol)); + } + + TSymbol GetKey() const { + return Impl->GetKey<TSymbol>()[0]; + } + + TTrie GetTails() const { + const TNode& node = Impl->GetNode(); + const size_t forwardOffset = node.GetForwardOffset(); + const char* emptyValue = node.IsFinal() ? 
Trie.Data().AsCharPtr() + node.GetLeafOffset() : nullptr; + if (forwardOffset) { + const char* start = Trie.Data().AsCharPtr() + forwardOffset; + TBlob body = TBlob::NoCopy(start, Trie.Data().Size() - forwardOffset); + return TTrie(body, emptyValue, Trie.GetPacker()); + } else { + return TTrie(emptyValue); + } + } + + void MakeStep() { + Impl->Forward(); + } + + private: + TTrie Trie; + TCopyPtr<TOpaqueTrieIterator> Impl; + }; + +} diff --git a/library/cpp/containers/comptrie/key_selector.h b/library/cpp/containers/comptrie/key_selector.h new file mode 100644 index 00000000000..60466cef715 --- /dev/null +++ b/library/cpp/containers/comptrie/key_selector.h @@ -0,0 +1,29 @@ +#pragma once + +#include <util/generic/vector.h> +#include <util/generic/string.h> +#include <util/generic/strbuf.h> + +template <class T> +struct TCompactTrieKeySelector { + typedef TVector<T> TKey; + typedef TVector<T> TKeyBuf; +}; + +template <class TChar> +struct TCompactTrieCharKeySelector { + typedef TBasicString<TChar> TKey; + typedef TBasicStringBuf<TChar> TKeyBuf; +}; + +template <> +struct TCompactTrieKeySelector<char>: public TCompactTrieCharKeySelector<char> { +}; + +template <> +struct TCompactTrieKeySelector<wchar16>: public TCompactTrieCharKeySelector<wchar16> { +}; + +template <> +struct TCompactTrieKeySelector<wchar32>: public TCompactTrieCharKeySelector<wchar32> { +}; diff --git a/library/cpp/containers/comptrie/leaf_skipper.h b/library/cpp/containers/comptrie/leaf_skipper.h new file mode 100644 index 00000000000..39592589487 --- /dev/null +++ b/library/cpp/containers/comptrie/leaf_skipper.h @@ -0,0 +1,56 @@ +#pragma once + +#include <cstddef> + +namespace NCompactTrie { + class ILeafSkipper { + public: + virtual size_t SkipLeaf(const char* p) const = 0; + virtual ~ILeafSkipper() = default; + }; + + template <class TPacker> + class TPackerLeafSkipper: public ILeafSkipper { + private: + const TPacker* Packer; + + public: + TPackerLeafSkipper(const TPacker* packer) + : Packer(packer) + { + } + + size_t SkipLeaf(const char* p) const override { + return Packer->SkipLeaf(p); + } + + // For test purposes. + const TPacker* GetPacker() const { + return Packer; + } + }; + + // The data you need to traverse the trie without unpacking the values. 
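// Editorial sketch (not part of the original header): ILeafSkipper/TPackerLeafSkipper above and the
// TOpaqueTrie struct declared just below are typically combined the way CompactTrieMinimizeImpl and
// CompactTrieMakeFastLayoutImpl further down in this diff do it -- assuming a TPacker type whose
// SkipLeaf(const char*) returns the byte length of a packed value:
//
//     TPackerLeafSkipper<TPacker> skipper(&packer);    // wraps the packer
//     TOpaqueTrie opaque(data, dataLength, skipper);   // raw trie bytes + skipper, no values unpacked
//     // 'opaque' can now be handed to RawCompactTrieMinimizeImpl or RawCompactTrieFastLayoutImpl.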
+ struct TOpaqueTrie { + const char* Data; + size_t Length; + const ILeafSkipper& SkipFunction; + + TOpaqueTrie(const char* data, size_t dataLength, const ILeafSkipper& skipFunction) + : Data(data) + , Length(dataLength) + , SkipFunction(skipFunction) + { + } + + bool operator==(const TOpaqueTrie& other) const { + return Data == other.Data && + Length == other.Length && + &SkipFunction == &other.SkipFunction; + } + + bool operator!=(const TOpaqueTrie& other) const { + return !(*this == other); + } + }; +} diff --git a/library/cpp/containers/comptrie/loader/loader.cpp b/library/cpp/containers/comptrie/loader/loader.cpp new file mode 100644 index 00000000000..4811e9d5216 --- /dev/null +++ b/library/cpp/containers/comptrie/loader/loader.cpp @@ -0,0 +1 @@ +#include "loader.h" diff --git a/library/cpp/containers/comptrie/loader/loader.h b/library/cpp/containers/comptrie/loader/loader.h new file mode 100644 index 00000000000..ee10e9b451e --- /dev/null +++ b/library/cpp/containers/comptrie/loader/loader.h @@ -0,0 +1,22 @@ +#pragma once + +#include <library/cpp/archive/yarchive.h> +#include <util/generic/string.h> +#include <util/generic/ptr.h> +#include <util/generic/yexception.h> +#include <util/memory/blob.h> + +template <class TrieType, size_t N> +TrieType LoadTrieFromArchive(const TString& key, + const unsigned char (&data)[N], + bool ignoreErrors = false) { + TArchiveReader archive(TBlob::NoCopy(data, sizeof(data))); + if (archive.Has(key)) { + TAutoPtr<IInputStream> trie = archive.ObjectByKey(key); + return TrieType(TBlob::FromStream(*trie)); + } + if (!ignoreErrors) { + ythrow yexception() << "Resource " << key << " not found"; + } + return TrieType(); +} diff --git a/library/cpp/containers/comptrie/loader/loader_ut.cpp b/library/cpp/containers/comptrie/loader/loader_ut.cpp new file mode 100644 index 00000000000..345063a31e4 --- /dev/null +++ b/library/cpp/containers/comptrie/loader/loader_ut.cpp @@ -0,0 +1,30 @@ +#include <library/cpp/testing/unittest/registar.h> +#include <library/cpp/containers/comptrie/comptrie.h> +#include <library/cpp/containers/comptrie/loader/loader.h> + +using TDummyTrie = TCompactTrie<char, i32>; + +namespace { + const unsigned char DATA[] = { +#include "data.inc" + }; +} + +Y_UNIT_TEST_SUITE(ArchiveLoaderTests) { + Y_UNIT_TEST(BaseTest) { + TDummyTrie trie = LoadTrieFromArchive<TDummyTrie>("/dummy.trie", DATA, true); + UNIT_ASSERT_EQUAL(trie.Size(), 3); + + const TString TrieKyes[3] = { + "zero", "one", "two"}; + i32 val = -1; + for (i32 i = 0; i < 3; ++i) { + UNIT_ASSERT(trie.Find(TrieKyes[i].data(), TrieKyes[i].size(), &val)); + UNIT_ASSERT_EQUAL(i, val); + } + + UNIT_CHECK_GENERATED_EXCEPTION( + LoadTrieFromArchive<TDummyTrie>("/noname.trie", DATA), + yexception); + } +} diff --git a/library/cpp/containers/comptrie/loader/ut/dummy.trie b/library/cpp/containers/comptrie/loader/ut/dummy.trie Binary files differnew file mode 100644 index 00000000000..4af18add2ff --- /dev/null +++ b/library/cpp/containers/comptrie/loader/ut/dummy.trie diff --git a/library/cpp/containers/comptrie/loader/ut/ya.make b/library/cpp/containers/comptrie/loader/ut/ya.make new file mode 100644 index 00000000000..6c0334d3ea7 --- /dev/null +++ b/library/cpp/containers/comptrie/loader/ut/ya.make @@ -0,0 +1,18 @@ +UNITTEST_FOR(library/cpp/containers/comptrie/loader) + +OWNER(my34) + +ARCHIVE( + NAME data.inc + dummy.trie +) + +SRCS( + loader_ut.cpp +) + +PEERDIR( + library/cpp/containers/comptrie/loader +) + +END() diff --git a/library/cpp/containers/comptrie/loader/ya.make 
b/library/cpp/containers/comptrie/loader/ya.make new file mode 100644 index 00000000000..1e23e442a01 --- /dev/null +++ b/library/cpp/containers/comptrie/loader/ya.make @@ -0,0 +1,15 @@ +LIBRARY() + +OWNER(my34) + +SRCS( + loader.h + loader.cpp +) + +PEERDIR( + library/cpp/archive + library/cpp/containers/comptrie +) + +END() diff --git a/library/cpp/containers/comptrie/make_fast_layout.cpp b/library/cpp/containers/comptrie/make_fast_layout.cpp new file mode 100644 index 00000000000..ade78d78994 --- /dev/null +++ b/library/cpp/containers/comptrie/make_fast_layout.cpp @@ -0,0 +1,467 @@ +#include "make_fast_layout.h" +#include "node.h" +#include "writeable_node.h" +#include "write_trie_backwards.h" +#include "comptrie_impl.h" + +#include <util/generic/hash.h> +#include <util/generic/utility.h> + +// Lay the trie in memory in such a way that there are fewer cache misses when jumping from root to leaf. +// The trie becomes about 2% larger, but access became about 25% faster in our experiments. +// Can be called on minimized and non-minimized tries; in the first case it requires half a trie more memory. +// Calling it a second time on the same trie does nothing. +// +// The algorithm is based on the van Emde Boas layout as described in the Yandex data school lectures on external memory algorithms +// by Maxim Babenko and Ivan Puzyrevsky. The difference is that when we cut the tree into levels +// two nodes connected by a forward link are put into the same level (because they usually lie near each other in the original tree). +// The original paper (describing the layout in Section 2.1) is: +// Michael A. Bender, Erik D. Demaine, Martin Farach-Colton. Cache-Oblivious B-Trees // SIAM Journal on Computing, volume 35, number 2, 2005, pages 341–358. +// Available on the web: http://erikdemaine.org/papers/CacheObliviousBTrees_SICOMP/ +// Or: Michael A. Bender, Erik D. Demaine, and Martin Farach-Colton. Cache-Oblivious B-Trees // Proceedings of the 41st Annual Symposium +// on Foundations of Computer Science (FOCS 2000), Redondo Beach, California, November 12–14, 2000, pages 399–409. +// Available on the web: http://erikdemaine.org/papers/FOCS2000b/ +// (there is not much difference between these papers, actually). 
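Editorial aside, not part of the commit: a minimal sketch of how the fast layout is applied to an already serialized trie through CompactTrieMakeFastLayoutImpl from make_fast_layout.h (added later in this diff). The helper name MakeFastLayout and the assumption that 'packer' matches the packer the trie was built with are illustrative only.

    #include <library/cpp/containers/comptrie/make_fast_layout.h>
    #include <util/generic/buffer.h>
    #include <util/stream/output.h>

    template <class TPacker>
    size_t MakeFastLayout(const TBuffer& packedTrie, IOutputStream& out, const TPacker& packer) {
        // Rewrites the same trie in the cache-friendlier order described above;
        // returns the size of the relaid-out trie (about 2% larger than the input).
        return NCompactTrie::CompactTrieMakeFastLayoutImpl(
            out, packedTrie.Data(), packedTrie.Size(), /*verbose=*/false, &packer);
    }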
+// +namespace NCompactTrie { + static size_t FindSupportingPowerOf2(size_t n) { + size_t result = 1ull << (8 * sizeof(size_t) - 1); + while (result > n) { + result >>= 1; + } + return result; + } + + namespace { + class TTrieNodeSet { + public: + TTrieNodeSet() = default; + + explicit TTrieNodeSet(const TOpaqueTrie& trie) + : Body(trie.Length / (8 * MinNodeSize) + 1, 0) + { + } + + bool Has(size_t offset) const { + const size_t reducedOffset = ReducedOffset(offset); + return OffsetCell(reducedOffset) & OffsetMask(reducedOffset); + } + + void Add(size_t offset) { + const size_t reducedOffset = ReducedOffset(offset); + OffsetCell(reducedOffset) |= OffsetMask(reducedOffset); + } + + void Remove(size_t offset) { + const size_t reducedOffset = ReducedOffset(offset); + OffsetCell(reducedOffset) &= ~OffsetMask(reducedOffset); + } + + void Swap(TTrieNodeSet& other) { + Body.swap(other.Body); + } + + private: + static const size_t MinNodeSize = 2; + TVector<ui8> Body; + + static size_t ReducedOffset(size_t offset) { + return offset / MinNodeSize; + } + static ui8 OffsetMask(size_t reducedOffset) { + return 1 << (reducedOffset % 8); + } + ui8& OffsetCell(size_t reducedOffset) { + return Body.at(reducedOffset / 8); + } + const ui8& OffsetCell(size_t reducedOffset) const { + return Body.at(reducedOffset / 8); + } + }; + + //--------------------------------------------------------------------- + + class TTrieNodeCounts { + public: + TTrieNodeCounts() = default; + + explicit TTrieNodeCounts(const TOpaqueTrie& trie) + : Body(trie.Length / MinNodeSize, 0) + , IsTree(false) + { + } + + size_t Get(size_t offset) const { + return IsTree ? 1 : Body.at(offset / MinNodeSize); + } + + void Inc(size_t offset) { + if (IsTree) { + return; + } + ui8& count = Body.at(offset / MinNodeSize); + if (count != MaxCount) { + ++count; + } + } + + size_t Dec(size_t offset) { + if (IsTree) { + return 0; + } + ui8& count = Body.at(offset / MinNodeSize); + Y_ASSERT(count > 0); + if (count != MaxCount) { + --count; + } + return count; + } + + void Swap(TTrieNodeCounts& other) { + Body.swap(other.Body); + ::DoSwap(IsTree, other.IsTree); + } + + void SetTreeMode() { + IsTree = true; + Body = TVector<ui8>(); + } + + private: + static const size_t MinNodeSize = 2; + static const ui8 MaxCount = 255; + + TVector<ui8> Body; + bool IsTree = false; + }; + + //---------------------------------------------------------- + + class TOffsetIndex { + public: + // In all methods: + // Key --- offset from the beginning of the old trie. + // Value --- offset from the end of the new trie. + + explicit TOffsetIndex(TTrieNodeCounts& counts) { + ParentCounts.Swap(counts); + } + + void Add(size_t key, size_t value) { + Data[key] = value; + } + + size_t Size() const { + return Data.size(); + } + + size_t Get(size_t key) { + auto pos = Data.find(key); + if (pos == Data.end()) { + ythrow yexception() << "Bad node walking order: trying to get node offset too early or too many times!"; + } + size_t result = pos->second; + if (ParentCounts.Dec(key) == 0) { + // We don't need this offset any more. 
+ Data.erase(pos); + } + return result; + } + + private: + TTrieNodeCounts ParentCounts; + THashMap<size_t, size_t> Data; + }; + + //--------------------------------------------------------------------------------------- + + class TTrieMeasurer { + public: + TTrieMeasurer(const TOpaqueTrie& trie, bool verbose); + void Measure(); + + size_t GetDepth() const { + return Depth; + } + + size_t GetNodeCount() const { + return NodeCount; + } + + size_t GetUnminimizedNodeCount() const { + return UnminimizedNodeCount; + } + + bool IsTree() const { + return NodeCount == UnminimizedNodeCount; + } + + TTrieNodeCounts& GetParentCounts() { + return ParentCounts; + } + + const TOpaqueTrie& GetTrie() const { + return Trie; + } + + private: + const TOpaqueTrie& Trie; + size_t Depth; + TTrieNodeCounts ParentCounts; + size_t NodeCount; + size_t UnminimizedNodeCount; + const bool Verbose; + + // returns depth, increments NodeCount. + size_t MeasureSubtrie(size_t rootOffset, bool isNewPath); + }; + + TTrieMeasurer::TTrieMeasurer(const TOpaqueTrie& trie, bool verbose) + : Trie(trie) + , Depth(0) + , ParentCounts(trie) + , NodeCount(0) + , UnminimizedNodeCount(0) + , Verbose(verbose) + { + Y_ASSERT(Trie.Data); + } + + void TTrieMeasurer::Measure() { + if (Verbose) { + Cerr << "Measuring the trie..." << Endl; + } + NodeCount = 0; + UnminimizedNodeCount = 0; + Depth = MeasureSubtrie(0, true); + if (IsTree()) { + ParentCounts.SetTreeMode(); + } + if (Verbose) { + Cerr << "Unminimized node count: " << UnminimizedNodeCount << Endl; + Cerr << "Trie depth: " << Depth << Endl; + Cerr << "Node count: " << NodeCount << Endl; + } + } + + // A chain of nodes linked by forward links + // is considered one node with many left and right children + // for depth measuring here and in + // TVanEmdeBoasReverseNodeEnumerator::FindDescendants. 
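// (Editorial clarification, not in the original source: concretely, when MeasureSubtrie follows
// node.GetForwardOffset() it stays at the same depth, and only the recursive calls through
// GetLeftOffset()/GetRightOffset() add 1, so a chain A -> B -> C joined purely by forward links
// contributes a single level.)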
+ size_t TTrieMeasurer::MeasureSubtrie(size_t rootOffset, bool isNewPath) { + Y_ASSERT(rootOffset < Trie.Length); + TNode node(Trie.Data, rootOffset, Trie.SkipFunction); + size_t depth = 0; + for (;;) { + ++UnminimizedNodeCount; + if (Verbose) { + ShowProgress(UnminimizedNodeCount); + } + if (isNewPath) { + if (ParentCounts.Get(node.GetOffset()) > 0) { + isNewPath = false; + } else { + ++NodeCount; + } + ParentCounts.Inc(node.GetOffset()); + } + if (node.GetLeftOffset()) { + depth = Max(depth, 1 + MeasureSubtrie(node.GetLeftOffset(), isNewPath)); + } + if (node.GetRightOffset()) { + depth = Max(depth, 1 + MeasureSubtrie(node.GetRightOffset(), isNewPath)); + } + if (node.GetForwardOffset()) { + node = TNode(Trie.Data, node.GetForwardOffset(), Trie.SkipFunction); + } else { + break; + } + } + return depth; + } + + //-------------------------------------------------------------------------------------- + + using TLevelNodes = TVector<size_t>; + + struct TLevel { + size_t Depth; + TLevelNodes Nodes; + + explicit TLevel(size_t depth) + : Depth(depth) + { + } + }; + + //---------------------------------------------------------------------------------------- + + class TVanEmdeBoasReverseNodeEnumerator: public TReverseNodeEnumerator { + public: + TVanEmdeBoasReverseNodeEnumerator(TTrieMeasurer& measurer, bool verbose) + : Fresh(true) + , Trie(measurer.GetTrie()) + , Depth(measurer.GetDepth()) + , MaxIndexSize(0) + , BackIndex(measurer.GetParentCounts()) + , ProcessedNodes(measurer.GetTrie()) + , Verbose(verbose) + { + } + + bool Move() override { + if (!Fresh) { + CentralWord.pop_back(); + } + if (CentralWord.empty()) { + return MoveCentralWordStart(); + } + return true; + } + + const TNode& Get() const { + return CentralWord.back(); + } + + size_t GetLeafLength() const override { + return Get().GetLeafLength(); + } + + // Returns recalculated offset from the end of the current node. + size_t PrepareOffset(size_t absoffset, size_t resultLength) { + if (!absoffset) + return NPOS; + return resultLength - BackIndex.Get(absoffset); + } + + size_t RecreateNode(char* buffer, size_t resultLength) override { + TWriteableNode newNode(Get(), Trie.Data); + newNode.ForwardOffset = PrepareOffset(Get().GetForwardOffset(), resultLength); + newNode.LeftOffset = PrepareOffset(Get().GetLeftOffset(), resultLength); + newNode.RightOffset = PrepareOffset(Get().GetRightOffset(), resultLength); + + const size_t len = newNode.Pack(buffer); + ProcessedNodes.Add(Get().GetOffset()); + BackIndex.Add(Get().GetOffset(), resultLength + len); + MaxIndexSize = Max(MaxIndexSize, BackIndex.Size()); + return len; + } + + private: + bool Fresh; + TOpaqueTrie Trie; + size_t Depth; + size_t MaxIndexSize; + + TVector<TLevel> Trace; + TOffsetIndex BackIndex; + TVector<TNode> CentralWord; + TTrieNodeSet ProcessedNodes; + + const bool Verbose; + + private: + bool IsVisited(size_t offset) const { + return ProcessedNodes.Has(offset); + } + + void AddCascade(size_t step) { + Y_ASSERT(!(step & (step - 1))); // Should be a power of 2. + while (step > 0) { + size_t root = Trace.back().Nodes.back(); + TLevel level(Trace.back().Depth + step); + Trace.push_back(level); + size_t actualStep = FindSupportingPowerOf2(FindDescendants(root, step, Trace.back().Nodes)); + if (actualStep != step) { + // Retry with a smaller step. 
+ Y_ASSERT(actualStep < step); + step = actualStep; + Trace.pop_back(); + } else { + step /= 2; + } + } + } + + void FillCentralWord() { + CentralWord.clear(); + CentralWord.push_back(TNode(Trie.Data, Trace.back().Nodes.back(), Trie.SkipFunction)); + // Do not check for epsilon links, as the traversal order now is different. + while (CentralWord.back().GetForwardOffset() && !IsVisited(CentralWord.back().GetForwardOffset())) { + CentralWord.push_back(TNode(Trie.Data, CentralWord.back().GetForwardOffset(), Trie.SkipFunction)); + } + } + + bool MoveCentralWordStart() { + do { + if (Fresh) { + TLevel root(0); + Trace.push_back(root); + Trace.back().Nodes.push_back(0); + const size_t sectionDepth = FindSupportingPowerOf2(Depth); + AddCascade(sectionDepth); + Fresh = false; + } else { + Trace.back().Nodes.pop_back(); + if (Trace.back().Nodes.empty() && Trace.size() == 1) { + if (Verbose) { + Cerr << "Max index size: " << MaxIndexSize << Endl; + Cerr << "Current index size: " << BackIndex.Size() << Endl; + } + // We just popped the root. + return false; + } + size_t lastStep = Trace.back().Depth - Trace[Trace.size() - 2].Depth; + if (Trace.back().Nodes.empty()) { + Trace.pop_back(); + } + AddCascade(lastStep / 2); + } + } while (IsVisited(Trace.back().Nodes.back())); + FillCentralWord(); + return true; + } + + // Returns the maximal depth it has reached while searching. + // This is a method because it needs OffsetIndex to skip processed nodes. + size_t FindDescendants(size_t rootOffset, size_t depth, TLevelNodes& result) const { + if (depth == 0) { + result.push_back(rootOffset); + return 0; + } + size_t actualDepth = 0; + TNode node(Trie.Data, rootOffset, Trie.SkipFunction); + for (;;) { + if (node.GetLeftOffset() && !IsVisited(node.GetLeftOffset())) { + actualDepth = Max(actualDepth, + 1 + FindDescendants(node.GetLeftOffset(), depth - 1, result)); + } + if (node.GetRightOffset() && !IsVisited(node.GetRightOffset())) { + actualDepth = Max(actualDepth, + 1 + FindDescendants(node.GetRightOffset(), depth - 1, result)); + } + if (node.GetForwardOffset() && !IsVisited(node.GetForwardOffset())) { + node = TNode(Trie.Data, node.GetForwardOffset(), Trie.SkipFunction); + } else { + break; + } + } + return actualDepth; + } + }; + + } // Anonymous namespace. + + //----------------------------------------------------------------------------------- + + size_t RawCompactTrieFastLayoutImpl(IOutputStream& os, const TOpaqueTrie& trie, bool verbose) { + if (!trie.Data || !trie.Length) { + return 0; + } + TTrieMeasurer mes(trie, verbose); + mes.Measure(); + TVanEmdeBoasReverseNodeEnumerator enumerator(mes, verbose); + return WriteTrieBackwards(os, enumerator, verbose); + } + +} diff --git a/library/cpp/containers/comptrie/make_fast_layout.h b/library/cpp/containers/comptrie/make_fast_layout.h new file mode 100644 index 00000000000..b8fab5d65b8 --- /dev/null +++ b/library/cpp/containers/comptrie/make_fast_layout.h @@ -0,0 +1,20 @@ +#pragma once + +#include "leaf_skipper.h" +#include <cstddef> + +class IOutputStream; + +namespace NCompactTrie { + // Return value: size of the resulting trie. + size_t RawCompactTrieFastLayoutImpl(IOutputStream& os, const NCompactTrie::TOpaqueTrie& trie, bool verbose); + + // Return value: size of the resulting trie. 
+ template <class TPacker> + size_t CompactTrieMakeFastLayoutImpl(IOutputStream& os, const char* data, size_t datalength, bool verbose, const TPacker* packer) { + TPackerLeafSkipper<TPacker> skipper(packer); + TOpaqueTrie trie(data, datalength, skipper); + return RawCompactTrieFastLayoutImpl(os, trie, verbose); + } + +} diff --git a/library/cpp/containers/comptrie/minimize.cpp b/library/cpp/containers/comptrie/minimize.cpp new file mode 100644 index 00000000000..80d0b25217d --- /dev/null +++ b/library/cpp/containers/comptrie/minimize.cpp @@ -0,0 +1,359 @@ +#include "minimize.h" +#include "node.h" +#include "writeable_node.h" +#include "write_trie_backwards.h" +#include "comptrie_impl.h" + +#include <util/generic/hash.h> +#include <util/generic/algorithm.h> + +namespace NCompactTrie { + // Minimize the trie. The result is equivalent to the original + // trie, except that it takes less space (and has marginally lower + // performance, because of eventual epsilon links). + // The algorithm is as follows: starting from the largest pieces, we find + // nodes that have identical continuations (Daciuk's right language), + // and repack the trie. Repacking is done in-place, so memory is less + // of an issue; however, it may take considerable time. + + // IMPORTANT: never try to reminimize an already minimized trie or a trie with fast layout. + // Because of non-local structure and epsilon links, it won't work + // as you expect it to, and can destroy the trie in the making. + + namespace { + using TOffsetList = TVector<size_t>; + using TPieceIndex = THashMap<size_t, TOffsetList>; + + using TSizePair = std::pair<size_t, size_t>; + using TSizePairVector = TVector<TSizePair>; + using TSizePairVectorVector = TVector<TSizePairVector>; + + class TOffsetMap { + protected: + TSizePairVectorVector Data; + + public: + TOffsetMap() { + Data.reserve(0x10000); + } + + void Add(size_t key, size_t value) { + size_t hikey = key & 0xFFFF; + + if (Data.size() <= hikey) + Data.resize(hikey + 1); + + TSizePairVector& sublist = Data[hikey]; + + for (auto& it : sublist) { + if (it.first == key) { + it.second = value; + + return; + } + } + + sublist.push_back(TSizePair(key, value)); + } + + bool Contains(size_t key) const { + return (Get(key) != 0); + } + + size_t Get(size_t key) const { + size_t hikey = key & 0xFFFF; + + if (Data.size() <= hikey) + return 0; + + const TSizePairVector& sublist = Data[hikey]; + + for (const auto& it : sublist) { + if (it.first == key) + return it.second; + } + + return 0; + } + }; + + class TOffsetDeltas { + protected: + TSizePairVector Data; + + public: + void Add(size_t key, size_t value) { + if (Data.empty()) { + if (key == value) + return; // no offset + } else { + TSizePair last = Data.back(); + + if (key <= last.first) { + Cerr << "Trouble: elements to offset delta list added in wrong order" << Endl; + + return; + } + + if (last.first + value == last.second + key) + return; // same offset + } + + Data.push_back(TSizePair(key, value)); + } + + size_t Get(size_t key) const { + if (Data.empty()) + return key; // difference is zero; + + if (key < Data.front().first) + return key; + + // Binary search for the highest entry in the list that does not exceed the key + size_t from = 0; + size_t to = Data.size() - 1; + + while (from < to) { + size_t midpoint = (from + to + 1) / 2; + + if (key < Data[midpoint].first) + to = midpoint - 1; + else + from = midpoint; + } + + TSizePair entry = Data[from]; + + return key - entry.first + entry.second; + } + }; + + class TPieceComparer { + private: 
+ const char* Data; + const size_t Length; + + public: + TPieceComparer(const char* buf, size_t len) + : Data(buf) + , Length(len) + { + } + + bool operator()(size_t p1, const size_t p2) { + int res = memcmp(Data + p1, Data + p2, Length); + + if (res) + return (res > 0); + + return (p1 > p2); // the pieces are sorted in the reverse order of appearance + } + }; + + struct TBranchPoint { + TNode Node; + int Selector; + + public: + TBranchPoint() + : Selector(0) + { + } + + TBranchPoint(const char* data, size_t offset, const ILeafSkipper& skipFunction) + : Node(data, offset, skipFunction) + , Selector(0) + { + } + + bool IsFinal() const { + return Node.IsFinal(); + } + + // NextNode returns child nodes, starting from the last node: Right, then Left, then Forward + size_t NextNode(const TOffsetMap& mergedNodes) { + while (Selector < 3) { + size_t nextOffset = 0; + + switch (++Selector) { + case 1: + if (Node.GetRightOffset()) + nextOffset = Node.GetRightOffset(); + break; + + case 2: + if (Node.GetLeftOffset()) + nextOffset = Node.GetLeftOffset(); + break; + + case 3: + if (Node.GetForwardOffset()) + nextOffset = Node.GetForwardOffset(); + break; + + default: + break; + } + + if (nextOffset && !mergedNodes.Contains(nextOffset)) + return nextOffset; + } + return 0; + } + }; + + class TMergingReverseNodeEnumerator: public TReverseNodeEnumerator { + private: + bool Fresh; + TOpaqueTrie Trie; + const TOffsetMap& MergeMap; + TVector<TBranchPoint> Trace; + TOffsetDeltas OffsetIndex; + + public: + TMergingReverseNodeEnumerator(const TOpaqueTrie& trie, const TOffsetMap& mergers) + : Fresh(true) + , Trie(trie) + , MergeMap(mergers) + { + } + + bool Move() override { + if (Fresh) { + Trace.push_back(TBranchPoint(Trie.Data, 0, Trie.SkipFunction)); + Fresh = false; + } else { + if (Trace.empty()) + return false; + + Trace.pop_back(); + + if (Trace.empty()) + return false; + } + + size_t nextnode = Trace.back().NextNode(MergeMap); + + while (nextnode) { + Trace.push_back(TBranchPoint(Trie.Data, nextnode, Trie.SkipFunction)); + nextnode = Trace.back().NextNode(MergeMap); + } + + return (!Trace.empty()); + } + + const TNode& Get() const { + return Trace.back().Node; + } + size_t GetLeafLength() const override { + return Get().GetLeafLength(); + } + + // Returns recalculated offset from the end of the current node + size_t PrepareOffset(size_t absoffset, size_t minilength) { + if (!absoffset) + return NPOS; + + if (MergeMap.Contains(absoffset)) + absoffset = MergeMap.Get(absoffset); + return minilength - OffsetIndex.Get(Trie.Length - absoffset); + } + + size_t RecreateNode(char* buffer, size_t resultLength) override { + TWriteableNode newNode(Get(), Trie.Data); + newNode.ForwardOffset = PrepareOffset(Get().GetForwardOffset(), resultLength); + newNode.LeftOffset = PrepareOffset(Get().GetLeftOffset(), resultLength); + newNode.RightOffset = PrepareOffset(Get().GetRightOffset(), resultLength); + + if (!buffer) + return newNode.Measure(); + + const size_t len = newNode.Pack(buffer); + OffsetIndex.Add(Trie.Length - Get().GetOffset(), resultLength + len); + + return len; + } + }; + + } + + static void AddPiece(TPieceIndex& index, size_t offset, size_t len) { + index[len].push_back(offset); + } + + static TOffsetMap FindEquivalentSubtries(const TOpaqueTrie& trie, bool verbose, size_t minMergeSize) { + // Tree nodes, arranged by span length. + // When all nodes of a given size are considered, they pop off the queue. + TPieceIndex subtries; + TOffsetMap merger; + // Start walking the trie from head. 
+ AddPiece(subtries, 0, trie.Length); + + size_t counter = 0; + // Now consider all nodes with sizeable continuations + for (size_t curlen = trie.Length; curlen >= minMergeSize && !subtries.empty(); curlen--) { + TPieceIndex::iterator iit = subtries.find(curlen); + + if (iit == subtries.end()) + continue; // fast forward to the next available length value + + TOffsetList& batch = iit->second; + TPieceComparer comparer(trie.Data, curlen); + Sort(batch.begin(), batch.end(), comparer); + + TOffsetList::iterator it = batch.begin(); + while (it != batch.end()) { + if (verbose) + ShowProgress(++counter); + + size_t offset = *it; + + // Fill the array with the subnodes of the element + TNode node(trie.Data, offset, trie.SkipFunction); + size_t end = offset + curlen; + if (size_t rightOffset = node.GetRightOffset()) { + AddPiece(subtries, rightOffset, end - rightOffset); + end = rightOffset; + } + if (size_t leftOffset = node.GetLeftOffset()) { + AddPiece(subtries, leftOffset, end - leftOffset); + end = leftOffset; + } + if (size_t forwardOffset = node.GetForwardOffset()) { + AddPiece(subtries, forwardOffset, end - forwardOffset); + } + + while (++it != batch.end()) { + // Find next different; until then, just add the offsets to the list of merged nodes. + size_t nextoffset = *it; + + if (memcmp(trie.Data + offset, trie.Data + nextoffset, curlen)) + break; + + merger.Add(nextoffset, offset); + } + } + + subtries.erase(curlen); + } + if (verbose) { + Cerr << counter << Endl; + } + return merger; + } + + size_t RawCompactTrieMinimizeImpl(IOutputStream& os, TOpaqueTrie& trie, bool verbose, size_t minMergeSize, EMinimizeMode mode) { + if (!trie.Data || !trie.Length) { + return 0; + } + + TOffsetMap merger = FindEquivalentSubtries(trie, verbose, minMergeSize); + TMergingReverseNodeEnumerator enumerator(trie, merger); + + if (mode == MM_DEFAULT) + return WriteTrieBackwards(os, enumerator, verbose); + else + return WriteTrieBackwardsNoAlloc(os, enumerator, trie, mode); + } + +} diff --git a/library/cpp/containers/comptrie/minimize.h b/library/cpp/containers/comptrie/minimize.h new file mode 100644 index 00000000000..baaa431d044 --- /dev/null +++ b/library/cpp/containers/comptrie/minimize.h @@ -0,0 +1,29 @@ +#pragma once + +#include "leaf_skipper.h" +#include <cstddef> + +class IOutputStream; + +namespace NCompactTrie { + size_t MeasureOffset(size_t offset); + + enum EMinimizeMode { + MM_DEFAULT, // alollocate new memory for minimized tree + MM_NOALLOC, // minimize tree in the same buffer + MM_INPLACE // do not write tree to the stream, but move to the buffer beginning + }; + + // Return value: size of the minimized trie. + size_t RawCompactTrieMinimizeImpl(IOutputStream& os, TOpaqueTrie& trie, bool verbose, size_t minMergeSize, EMinimizeMode mode); + + // Return value: size of the minimized trie. 
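// Editorial sketch (not part of the original header), assuming a packer type with the
// SkipLeaf(const char*) method required by TPackerLeafSkipper, and packedData/packedSize
// holding the serialized trie:
//
//     TBufferOutput minimized;
//     size_t newSize = NCompactTrie::CompactTrieMinimizeImpl(
//         minimized, packedData, packedSize, /*verbose=*/false, &packer, NCompactTrie::MM_DEFAULT);
//
// MM_NOALLOC and MM_INPLACE avoid allocating a fresh buffer and instead reuse the input buffer,
// as described in the EMinimizeMode comments above.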
+ template <class TPacker> + size_t CompactTrieMinimizeImpl(IOutputStream& os, const char* data, size_t datalength, bool verbose, const TPacker* packer, EMinimizeMode mode) { + TPackerLeafSkipper<TPacker> skipper(packer); + size_t minmerge = MeasureOffset(datalength); + TOpaqueTrie trie(data, datalength, skipper); + return RawCompactTrieMinimizeImpl(os, trie, verbose, minmerge, mode); + } + +} diff --git a/library/cpp/containers/comptrie/node.cpp b/library/cpp/containers/comptrie/node.cpp new file mode 100644 index 00000000000..5fd22f15ec3 --- /dev/null +++ b/library/cpp/containers/comptrie/node.cpp @@ -0,0 +1,79 @@ +#include "node.h" +#include "leaf_skipper.h" +#include "comptrie_impl.h" + +#include <util/system/yassert.h> +#include <util/generic/yexception.h> + +namespace NCompactTrie { + TNode::TNode() + : Offset(0) + , LeafLength(0) + , CoreLength(0) + , Label(0) + { + for (auto& offset : Offsets) { + offset = 0; + } + } + + // We believe that epsilon links are found only on the forward-position and that afer jumping an epsilon link you come to an ordinary node. + + TNode::TNode(const char* data, size_t offset, const ILeafSkipper& skipFunction) + : Offset(offset) + , LeafLength(0) + , CoreLength(0) + , Label(0) + { + for (auto& anOffset : Offsets) { + anOffset = 0; + } + if (!data) + return; // empty constructor + + const char* datapos = data + offset; + char flags = *(datapos++); + Y_ASSERT(!IsEpsilonLink(flags)); + Label = *(datapos++); + + size_t leftsize = LeftOffsetLen(flags); + size_t& leftOffset = Offsets[D_LEFT]; + leftOffset = UnpackOffset(datapos, leftsize); + if (leftOffset) { + leftOffset += Offset; + } + datapos += leftsize; + + size_t rightsize = RightOffsetLen(flags); + size_t& rightOffset = Offsets[D_RIGHT]; + rightOffset = UnpackOffset(datapos, rightsize); + if (rightOffset) { + rightOffset += Offset; + } + datapos += rightsize; + + if (flags & MT_FINAL) { + Offsets[D_FINAL] = datapos - data; + LeafLength = skipFunction.SkipLeaf(datapos); + } + + CoreLength = 2 + leftsize + rightsize + LeafLength; + if (flags & MT_NEXT) { + size_t& forwardOffset = Offsets[D_NEXT]; + forwardOffset = Offset + CoreLength; + // There might be an epsilon link at the forward position. + // If so, set the ForwardOffset to the value that points to the link's end. + const char* forwardPos = data + forwardOffset; + const char forwardFlags = *forwardPos; + if (IsEpsilonLink(forwardFlags)) { + // Jump through the epsilon link. + size_t epsilonOffset = UnpackOffset(forwardPos + 1, forwardFlags & MT_SIZEMASK); + if (!epsilonOffset) { + ythrow yexception() << "Corrupted epsilon link"; + } + forwardOffset += epsilonOffset; + } + } + } + +} diff --git a/library/cpp/containers/comptrie/node.h b/library/cpp/containers/comptrie/node.h new file mode 100644 index 00000000000..d6f4317db09 --- /dev/null +++ b/library/cpp/containers/comptrie/node.h @@ -0,0 +1,80 @@ +#pragma once + +#include <cstddef> + +namespace NCompactTrie { + class ILeafSkipper; + + enum TDirection { + D_LEFT, + D_FINAL, + D_NEXT, + D_RIGHT, + D_MAX + }; + + inline TDirection& operator++(TDirection& direction) { + direction = static_cast<TDirection>(direction + 1); + return direction; + } + + inline TDirection& operator--(TDirection& direction) { + direction = static_cast<TDirection>(direction - 1); + return direction; + } + + class TNode { + public: + TNode(); + // Processes epsilon links and sets ForwardOffset to correct value. Assumes an epsilon link doesn't point to an epsilon link. 
+ TNode(const char* data, size_t offset, const ILeafSkipper& skipFunction); + + size_t GetOffset() const { + return Offset; + } + + size_t GetLeafOffset() const { + return Offsets[D_FINAL]; + } + size_t GetLeafLength() const { + return LeafLength; + } + size_t GetCoreLength() const { + return CoreLength; + } + + size_t GetOffsetByDirection(TDirection direction) const { + return Offsets[direction]; + } + + size_t GetForwardOffset() const { + return Offsets[D_NEXT]; + } + size_t GetLeftOffset() const { + return Offsets[D_LEFT]; + } + size_t GetRightOffset() const { + return Offsets[D_RIGHT]; + } + char GetLabel() const { + return Label; + } + + bool IsFinal() const { + return GetLeafOffset() != 0; + } + + bool HasEpsilonLinkForward() const { + return GetForwardOffset() > Offset + CoreLength; + } + + private: + size_t Offsets[D_MAX]; + size_t Offset; + size_t LeafLength; + size_t CoreLength; + + char Label; + }; + +} diff --git a/library/cpp/containers/comptrie/opaque_trie_iterator.cpp b/library/cpp/containers/comptrie/opaque_trie_iterator.cpp new file mode 100644 index 00000000000..5fd3914be6e --- /dev/null +++ b/library/cpp/containers/comptrie/opaque_trie_iterator.cpp @@ -0,0 +1,231 @@ +#include "opaque_trie_iterator.h" +#include "comptrie_impl.h" +#include "node.h" + +namespace NCompactTrie { + TOpaqueTrieIterator::TOpaqueTrieIterator(const TOpaqueTrie& trie, const char* emptyValue, bool atend, + size_t maxKeyLength) + : Trie(trie) + , EmptyValue(emptyValue) + , AtEmptyValue(emptyValue && !atend) + , MaxKeyLength(maxKeyLength) + { + if (!AtEmptyValue && !atend) + Forward(); + } + + bool TOpaqueTrieIterator::operator==(const TOpaqueTrieIterator& rhs) const { + return (Trie == rhs.Trie && + Forks == rhs.Forks && + EmptyValue == rhs.EmptyValue && + AtEmptyValue == rhs.AtEmptyValue && + MaxKeyLength == rhs.MaxKeyLength); + } + + bool TOpaqueTrieIterator::HasMaxKeyLength() const { + return MaxKeyLength != size_t(-1) && MeasureNarrowKey() == MaxKeyLength; + } + + bool TOpaqueTrieIterator::Forward() { + if (AtEmptyValue) { + AtEmptyValue = false; + bool res = Forward(); // TODO delete this after format change + if (res && MeasureNarrowKey() != 0) { + return res; // there was not "\0" key + } + // otherwise we are skipping "\0" key + } + + if (!Trie.Length) + return false; + + if (Forks.Empty()) { + TFork fork(Trie.Data, 0, Trie.Length, Trie.SkipFunction); + Forks.Push(fork); + } else { + TFork* topFork = &Forks.Top(); + while (!topFork->NextDirection()) { + if (topFork->Node.GetOffset() >= Trie.Length) + return false; + Forks.Pop(); + if (Forks.Empty()) + return false; + topFork = &Forks.Top(); + } + } + + Y_ASSERT(!Forks.Empty()); + while (Forks.Top().CurrentDirection != D_FINAL && !HasMaxKeyLength()) { + TFork nextFork = Forks.Top().NextFork(Trie.SkipFunction); + Forks.Push(nextFork); + } + TFork& top = Forks.Top(); + static_assert(D_FINAL < D_NEXT, "relative order of NEXT and FINAL directions has changed"); + if (HasMaxKeyLength() && top.CurrentDirection == D_FINAL && top.HasDirection(D_NEXT)) { + top.NextDirection(); + } + return true; + } + + bool TOpaqueTrieIterator::Backward() { + if (AtEmptyValue) + return false; + + if (!Trie.Length) { + if (EmptyValue) { + // A trie that has only the empty value; + // we are not at the empty value, so move to it. + AtEmptyValue = true; + return true; + } else { + // Empty trie. 
+ return false; + } + } + + if (Forks.Empty()) { + TFork fork(Trie.Data, 0, Trie.Length, Trie.SkipFunction); + fork.LastDirection(); + Forks.Push(fork); + } else { + TFork* topFork = &Forks.Top(); + while (!topFork->PrevDirection()) { + if (topFork->Node.GetOffset() >= Trie.Length) + return false; + Forks.Pop(); + if (!Forks.Empty()) { + topFork = &Forks.Top(); + } else { + // When there are no more forks, + // we have to iterate over the empty value. + if (!EmptyValue) + return false; + AtEmptyValue = true; + return true; + } + } + } + + Y_ASSERT(!Forks.Empty()); + while (Forks.Top().CurrentDirection != D_FINAL && !HasMaxKeyLength()) { + TFork nextFork = Forks.Top().NextFork(Trie.SkipFunction); + nextFork.LastDirection(); + Forks.Push(nextFork); + } + TFork& top = Forks.Top(); + static_assert(D_FINAL < D_NEXT, "relative order of NEXT and FINAL directions has changed"); + if (HasMaxKeyLength() && top.CurrentDirection == D_NEXT && top.HasDirection(D_FINAL)) { + top.PrevDirection(); + } + if (MeasureNarrowKey() == 0) { + // This is the '\0' key, skip it and get to the EmptyValue. + AtEmptyValue = true; + Forks.Clear(); + } + return true; + } + + const char* TOpaqueTrieIterator::GetValuePtr() const { + if (!Forks.Empty()) { + const TFork& lastFork = Forks.Top(); + Y_ASSERT(lastFork.Node.IsFinal() && lastFork.CurrentDirection == D_FINAL); + return Trie.Data + lastFork.GetValueOffset(); + } + Y_ASSERT(AtEmptyValue); + return EmptyValue; + } + + //------------------------------------------------------------------------- + + TString TForkStack::GetKey() const { + if (HasEmptyKey()) { + return TString(); + } + + TString result(Key); + if (TopHasLabelInKey()) { + result.append(Top().GetLabel()); + } + return result; + } + + bool TForkStack::HasEmptyKey() const { + // Special case: if we get a single zero label, treat it as an empty key + // TODO delete this after format change + if (TopHasLabelInKey()) { + return Key.size() == 0 && Top().GetLabel() == '\0'; + } else { + return Key.size() == 1 && Key[0] == '\0'; + } + } + + size_t TForkStack::MeasureKey() const { + size_t result = Key.size() + (TopHasLabelInKey() ? 
1 : 0); + if (result == 1 && HasEmptyKey()) { + return 0; + } + return result; + } + + //------------------------------------------------------------------------- + + TFork::TFork(const char* data, size_t offset, size_t limit, const ILeafSkipper& skipper) + : Node(data, offset, skipper) + , Data(data) + , Limit(limit) + , CurrentDirection(TDirection(0)) + { +#if COMPTRIE_DATA_CHECK + if (Node.GetOffset() >= Limit - 1) + ythrow yexception() << "gone beyond the limit, data is corrupted"; +#endif + while (CurrentDirection < D_MAX && !HasDirection(CurrentDirection)) { + ++CurrentDirection; + } + } + + bool TFork::operator==(const TFork& rhs) const { + return (Data == rhs.Data && + Node.GetOffset() == rhs.Node.GetOffset() && + CurrentDirection == rhs.CurrentDirection); + } + + inline bool TFork::NextDirection() { + do { + ++CurrentDirection; + } while (CurrentDirection < D_MAX && !HasDirection(CurrentDirection)); + return CurrentDirection < D_MAX; + } + + inline bool TFork::PrevDirection() { + if (CurrentDirection == TDirection(0)) { + return false; + } + do { + --CurrentDirection; + } while (CurrentDirection > 0 && !HasDirection(CurrentDirection)); + return HasDirection(CurrentDirection); + } + + void TFork::LastDirection() { + CurrentDirection = D_MAX; + PrevDirection(); + } + + bool TFork::SetDirection(TDirection direction) { + if (!HasDirection(direction)) { + return false; + } + CurrentDirection = direction; + return true; + } + + char TFork::GetLabel() const { + return Node.GetLabel(); + } + + size_t TFork::GetValueOffset() const { + return Node.GetLeafOffset(); + } + +} diff --git a/library/cpp/containers/comptrie/opaque_trie_iterator.h b/library/cpp/containers/comptrie/opaque_trie_iterator.h new file mode 100644 index 00000000000..195da3c1918 --- /dev/null +++ b/library/cpp/containers/comptrie/opaque_trie_iterator.h @@ -0,0 +1,266 @@ +#pragma once + +#include "comptrie_impl.h" +#include "node.h" +#include "key_selector.h" +#include "leaf_skipper.h" + +#include <util/generic/vector.h> +#include <util/generic/yexception.h> + +namespace NCompactTrie { + class ILeafSkipper; + + class TFork { // Auxiliary class for a branching point in the iterator + public: + TNode Node; + const char* Data; + size_t Limit; // valid data is in range [Data + Node.GetOffset(), Data + Limit) + TDirection CurrentDirection; + + public: + TFork(const char* data, size_t offset, size_t limit, const ILeafSkipper& skipper); + + bool operator==(const TFork& rhs) const; + + bool HasLabelInKey() const { + return CurrentDirection == D_NEXT || CurrentDirection == D_FINAL; + } + + bool NextDirection(); + bool PrevDirection(); + void LastDirection(); + + bool HasDirection(TDirection direction) const { + return Node.GetOffsetByDirection(direction); + } + // If the fork doesn't have the specified direction, + // leaves the direction intact and returns false. + // Otherwise returns true. 
+ bool SetDirection(TDirection direction); + TFork NextFork(const ILeafSkipper& skipper) const; + + char GetLabel() const; + size_t GetValueOffset() const; + }; + + inline TFork TFork::NextFork(const ILeafSkipper& skipper) const { + Y_ASSERT(CurrentDirection != D_FINAL); + size_t offset = Node.GetOffsetByDirection(CurrentDirection); + return TFork(Data, offset, Limit, skipper); + } + + //------------------------------------------------------------------------------------------------ + class TForkStack { + public: + void Push(const TFork& fork) { + if (TopHasLabelInKey()) { + Key.push_back(Top().GetLabel()); + } + Forks.push_back(fork); + } + + void Pop() { + Forks.pop_back(); + if (TopHasLabelInKey()) { + Key.pop_back(); + } + } + + TFork& Top() { + return Forks.back(); + } + const TFork& Top() const { + return Forks.back(); + } + + bool Empty() const { + return Forks.empty(); + } + + void Clear() { + Forks.clear(); + Key.clear(); + } + + bool operator==(const TForkStack& other) const { + return Forks == other.Forks; + } + bool operator!=(const TForkStack& other) const { + return !(*this == other); + } + + TString GetKey() const; + size_t MeasureKey() const; + + private: + TVector<TFork> Forks; + TString Key; + + private: + bool TopHasLabelInKey() const { + return !Empty() && Top().HasLabelInKey(); + } + bool HasEmptyKey() const; + }; + + //------------------------------------------------------------------------------------------------ + + template <class TSymbol> + struct TConvertRawKey { + typedef typename TCompactTrieKeySelector<TSymbol>::TKey TKey; + static TKey Get(const TString& rawkey) { + TKey result; + const size_t sz = rawkey.size(); + result.reserve(sz / sizeof(TSymbol)); + for (size_t i = 0; i < sz; i += sizeof(TSymbol)) { + TSymbol sym = 0; + for (size_t j = 0; j < sizeof(TSymbol); j++) { + if (sizeof(TSymbol) <= 1) + sym = 0; + else + sym <<= 8; + if (i + j < sz) + sym |= TSymbol(0x00FF & rawkey[i + j]); + } + result.push_back(sym); + } + return result; + } + + static size_t Size(size_t rawsize) { + return rawsize / sizeof(TSymbol); + } + }; + + template <> + struct TConvertRawKey<char> { + static TString Get(const TString& rawkey) { + return rawkey; + } + + static size_t Size(size_t rawsize) { + return rawsize; + } + }; + + //------------------------------------------------------------------------------------------------ + class TOpaqueTrieIterator { // Iterator stuff. Stores a stack of visited forks. + public: + TOpaqueTrieIterator(const TOpaqueTrie& trie, const char* emptyValue, bool atend, + size_t maxKeyLength = size_t(-1)); + + bool operator==(const TOpaqueTrieIterator& rhs) const; + bool operator!=(const TOpaqueTrieIterator& rhs) const { + return !(*this == rhs); + } + + bool Forward(); + bool Backward(); + + template <class TSymbol> + bool UpperBound(const typename TCompactTrieKeySelector<TSymbol>::TKeyBuf& key); // True if matched exactly. + + template <class TSymbol> + typename TCompactTrieKeySelector<TSymbol>::TKey GetKey() const { + return TConvertRawKey<TSymbol>::Get(GetNarrowKey()); + } + + template <class TSymbol> + size_t MeasureKey() const { + return TConvertRawKey<TSymbol>::Size(MeasureNarrowKey()); + } + + TString GetNarrowKey() const { + return Forks.GetKey(); + } + size_t MeasureNarrowKey() const { + return Forks.MeasureKey(); + } + + const char* GetValuePtr() const; // 0 if none + const TNode& GetNode() const { // Could be called for non-empty key and not AtEnd. 
+ return Forks.Top().Node; + } + const TOpaqueTrie& GetTrie() const { + return Trie; + } + + private: + TOpaqueTrie Trie; + TForkStack Forks; + const char* const EmptyValue; + bool AtEmptyValue; + const size_t MaxKeyLength; + + private: + bool HasMaxKeyLength() const; + + template <class TSymbol> + int LongestPrefix(const typename TCompactTrieKeySelector<TSymbol>::TKeyBuf& key); // Used in UpperBound. + }; + + template <class TSymbol> + int TOpaqueTrieIterator::LongestPrefix(const typename TCompactTrieKeySelector<TSymbol>::TKeyBuf& key) { + Forks.Clear(); + TFork next(Trie.Data, 0, Trie.Length, Trie.SkipFunction); + for (size_t i = 0; i < key.size(); i++) { + TSymbol symbol = key[i]; + const bool isLastSymbol = (i + 1 == key.size()); + for (i64 shift = (i64)NCompactTrie::ExtraBits<TSymbol>(); shift >= 0; shift -= 8) { + const unsigned char label = (unsigned char)(symbol >> shift); + const bool isLastByte = (isLastSymbol && shift == 0); + do { + Forks.Push(next); + TFork& top = Forks.Top(); + if (label < (unsigned char)top.GetLabel()) { + if (!top.SetDirection(D_LEFT)) + return 1; + } else if (label > (unsigned char)top.GetLabel()) { + if (!top.SetDirection(D_RIGHT)) { + Forks.Pop(); // We don't pass this fork on the way to the upper bound. + return -1; + } + } else if (isLastByte) { // Here and below label == top.GetLabel(). + if (top.SetDirection(D_FINAL)) { + return 0; // Skip the NextFork() call at the end of the cycle. + } else { + top.SetDirection(D_NEXT); + return 1; + } + } else if (!top.SetDirection(D_NEXT)) { + top.SetDirection(D_FINAL); + return -1; + } + next = top.NextFork(Trie.SkipFunction); + } while (Forks.Top().CurrentDirection != D_NEXT); // Proceed to the next byte. + } + } + // We get here only if the key was empty. + Forks.Push(next); + return 1; + } + + template <class TSymbol> + bool TOpaqueTrieIterator::UpperBound(const typename TCompactTrieKeySelector<TSymbol>::TKeyBuf& key) { + Forks.Clear(); + if (key.empty() && EmptyValue) { + AtEmptyValue = true; + return true; + } else { + AtEmptyValue = false; + } + const int defect = LongestPrefix<TSymbol>(key); + if (defect > 0) { + // Continue the constructed forks with the smallest key possible. 
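+            // A positive defect means the position reached by LongestPrefix() already compares
+            // greater than the sought key, so finishing the fork chain with its smallest possible
+            // continuation yields the first key that is >= the argument; a negative defect means
+            // the position compares smaller, and Forward() below steps to the next key after it.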
+ while (Forks.Top().CurrentDirection != D_FINAL) { + TFork next = Forks.Top().NextFork(Trie.SkipFunction); + Forks.Push(next); + } + } else if (defect < 0) { + Forward(); + } + return defect == 0; + } + +} diff --git a/library/cpp/containers/comptrie/pattern_searcher.h b/library/cpp/containers/comptrie/pattern_searcher.h new file mode 100644 index 00000000000..caab51dc1c5 --- /dev/null +++ b/library/cpp/containers/comptrie/pattern_searcher.h @@ -0,0 +1,606 @@ +#pragma once + +#include "comptrie_builder.h" +#include "comptrie_trie.h" +#include "comptrie_impl.h" +#include <library/cpp/packers/packers.h> + +#include <util/system/yassert.h> +#include <util/generic/vector.h> +#include <util/generic/deque.h> +#include <util/stream/str.h> + +// Aho-Corasick algorithm implementation using CompactTrie implementation of Sedgewick's T-trie + +namespace NCompactTrie { + struct TSuffixLink { + ui64 NextSuffixOffset; + ui64 NextSuffixWithDataOffset; + + TSuffixLink(ui64 nextSuffixOffset = 0, ui64 nextSuffixWithDataOffset = 0) + : NextSuffixOffset(nextSuffixOffset) + , NextSuffixWithDataOffset(nextSuffixWithDataOffset) + { + } + }; + + const size_t FLAGS_SIZE = sizeof(char); + const size_t SYMBOL_SIZE = sizeof(char); +}; + +template <class T = char, class D = ui64, class S = TCompactTriePacker<D>> +class TCompactPatternSearcherBuilder : protected TCompactTrieBuilder<T, D, S> { +public: + typedef T TSymbol; + typedef D TData; + typedef S TPacker; + + typedef typename TCompactTrieKeySelector<TSymbol>::TKey TKey; + typedef typename TCompactTrieKeySelector<TSymbol>::TKeyBuf TKeyBuf; + + typedef TCompactTrieBuilder<T, D, S> TBase; + +public: + TCompactPatternSearcherBuilder() { + TBase::Impl = MakeHolder<TCompactPatternSearcherBuilderImpl>(); + } + + bool Add(const TSymbol* key, size_t keyLength, const TData& value) { + return TBase::Impl->AddEntry(key, keyLength, value); + } + bool Add(const TKeyBuf& key, const TData& value) { + return Add(key.data(), key.size(), value); + } + + bool Find(const TSymbol* key, size_t keyLength, TData* value) const { + return TBase::Impl->FindEntry(key, keyLength, value); + } + bool Find(const TKeyBuf& key, TData* value = nullptr) const { + return Find(key.data(), key.size(), value); + } + + size_t Save(IOutputStream& os) const { + size_t trieSize = TBase::Impl->MeasureByteSize(); + TBufferOutput serializedTrie(trieSize); + TBase::Impl->Save(serializedTrie); + + auto serializedTrieBuffer = serializedTrie.Buffer(); + CalculateSuffixLinks( + serializedTrieBuffer.Data(), + serializedTrieBuffer.Data() + serializedTrieBuffer.Size() + ); + + os.Write(serializedTrieBuffer.Data(), serializedTrieBuffer.Size()); + return trieSize; + } + + TBlob Save() const { + TBufferStream buffer; + Save(buffer); + return TBlob::FromStream(buffer); + } + + size_t SaveToFile(const TString& fileName) const { + TFileOutput out(fileName); + return Save(out); + } + + size_t MeasureByteSize() const { + return TBase::Impl->MeasureByteSize(); + } + +private: + void CalculateSuffixLinks(char* trieStart, const char* trieEnd) const; + +protected: + class TCompactPatternSearcherBuilderImpl : public TBase::TCompactTrieBuilderImpl { + public: + typedef typename TBase::TCompactTrieBuilderImpl TImplBase; + + TCompactPatternSearcherBuilderImpl( + TCompactTrieBuilderFlags flags = CTBF_NONE, + TPacker packer = TPacker(), + IAllocator* alloc = TDefaultAllocator::Instance() + ) : TImplBase(flags, packer, alloc) { + } + + ui64 ArcMeasure( + const typename TImplBase::TArc* arc, + size_t leftSize, + size_t rightSize + ) 
const override { + using namespace NCompactTrie; + + size_t coreSize = SYMBOL_SIZE + FLAGS_SIZE + + sizeof(TSuffixLink) + + this->NodeMeasureLeafValue(arc->Node); + size_t treeSize = this->NodeMeasureSubtree(arc->Node); + + if (arc->Label.Length() > 0) + treeSize += (SYMBOL_SIZE + FLAGS_SIZE + sizeof(TSuffixLink)) * + (arc->Label.Length() - 1); + + // Triple measurements are needed because the space needed to store the offset + // shall be added to the offset itself. Hence three iterations. + size_t leftOffsetSize = 0; + size_t rightOffsetSize = 0; + for (size_t iteration = 0; iteration < 3; ++iteration) { + leftOffsetSize = leftSize ? MeasureOffset( + coreSize + treeSize + leftOffsetSize + rightOffsetSize) : 0; + rightOffsetSize = rightSize ? MeasureOffset( + coreSize + treeSize + leftSize + leftOffsetSize + rightOffsetSize) : 0; + } + + coreSize += leftOffsetSize + rightOffsetSize; + arc->LeftOffset = leftSize ? coreSize + treeSize : 0; + arc->RightOffset = rightSize ? coreSize + treeSize + leftSize : 0; + + return coreSize + treeSize + leftSize + rightSize; + } + + ui64 ArcSaveSelf(const typename TImplBase::TArc* arc, IOutputStream& os) const override { + using namespace NCompactTrie; + + ui64 written = 0; + + size_t leftOffsetSize = MeasureOffset(arc->LeftOffset); + size_t rightOffsetSize = MeasureOffset(arc->RightOffset); + + size_t labelLen = arc->Label.Length(); + + for (size_t labelPos = 0; labelPos < labelLen; ++labelPos) { + char flags = 0; + + if (labelPos == 0) { + flags |= (leftOffsetSize << MT_LEFTSHIFT); + flags |= (rightOffsetSize << MT_RIGHTSHIFT); + } + + if (labelPos == labelLen - 1) { + if (arc->Node->IsFinal()) + flags |= MT_FINAL; + if (!arc->Node->IsLast()) + flags |= MT_NEXT; + } else { + flags |= MT_NEXT; + } + + os.Write(&flags, 1); + os.Write(&arc->Label.AsCharPtr()[labelPos], 1); + written += 2; + + TSuffixLink suffixlink; + os.Write(&suffixlink, sizeof(TSuffixLink)); + written += sizeof(TSuffixLink); + + if (labelPos == 0) { + written += ArcSaveOffset(arc->LeftOffset, os); + written += ArcSaveOffset(arc->RightOffset, os); + } + } + + written += this->NodeSaveLeafValue(arc->Node, os); + return written; + } + }; +}; + + +template <class T> +struct TPatternMatch { + ui64 End; + T Data; + + TPatternMatch(ui64 end, const T& data) + : End(end) + , Data(data) + { + } +}; + + +template <class T = char, class D = ui64, class S = TCompactTriePacker<D>> +class TCompactPatternSearcher { +public: + typedef T TSymbol; + typedef D TData; + typedef S TPacker; + + typedef typename TCompactTrieKeySelector<TSymbol>::TKey TKey; + typedef typename TCompactTrieKeySelector<TSymbol>::TKeyBuf TKeyBuf; + + typedef TCompactTrie<TSymbol, TData, TPacker> TTrie; +public: + TCompactPatternSearcher() + { + } + + explicit TCompactPatternSearcher(const TBlob& data) + : Trie(data) + { + } + + TCompactPatternSearcher(const char* data, size_t size) + : Trie(data, size) + { + } + + TVector<TPatternMatch<TData>> SearchMatches(const TSymbol* text, size_t textSize) const; + TVector<TPatternMatch<TData>> SearchMatches(const TKeyBuf& text) const { + return SearchMatches(text.data(), text.size()); + } +private: + TTrie Trie; +}; + +//////////////////// +// Implementation // +//////////////////// + +namespace { + +template <class TData, class TPacker> +char ReadNode( + char* nodeStart, + char*& leftSibling, + char*& rightSibling, + char*& directChild, + NCompactTrie::TSuffixLink*& suffixLink, + TPacker packer = TPacker() +) { + char* dataPos = nodeStart; + char flags = *(dataPos++); + + 
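+    // Node layout produced by ArcSaveSelf(): a flags byte, the label byte, a raw TSuffixLink,
+    // then (only for the first label of an arc) the packed left and right offsets, and finally
+    // the leaf value when MT_FINAL is set. The flags byte stores the offset widths and the
+    // MT_FINAL/MT_NEXT bits.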
Y_ASSERT(!NCompactTrie::IsEpsilonLink(flags)); // Epsilon links are not allowed + + char label = *(dataPos++); + + suffixLink = (NCompactTrie::TSuffixLink*)dataPos; + dataPos += sizeof(NCompactTrie::TSuffixLink); + + { // Left branch + size_t offsetLength = NCompactTrie::LeftOffsetLen(flags); + size_t leftOffset = NCompactTrie::UnpackOffset(dataPos, offsetLength); + leftSibling = leftOffset ? (nodeStart + leftOffset) : nullptr; + + dataPos += offsetLength; + } + + + { // Right branch + size_t offsetLength = NCompactTrie::RightOffsetLen(flags); + size_t rightOffset = NCompactTrie::UnpackOffset(dataPos, offsetLength); + rightSibling = rightOffset ? (nodeStart + rightOffset) : nullptr; + + dataPos += offsetLength; + } + + directChild = nullptr; + if (flags & NCompactTrie::MT_NEXT) { + directChild = dataPos; + if (flags & NCompactTrie::MT_FINAL) { + directChild += packer.SkipLeaf(directChild); + } + } + + return label; +} + +template <class TData, class TPacker> +char ReadNodeConst( + const char* nodeStart, + const char*& leftSibling, + const char*& rightSibling, + const char*& directChild, + const char*& data, + NCompactTrie::TSuffixLink& suffixLink, + TPacker packer = TPacker() +) { + const char* dataPos = nodeStart; + char flags = *(dataPos++); + + Y_ASSERT(!NCompactTrie::IsEpsilonLink(flags)); // Epsilon links are not allowed + + char label = *(dataPos++); + + suffixLink = *((NCompactTrie::TSuffixLink*)dataPos); + dataPos += sizeof(NCompactTrie::TSuffixLink); + + { // Left branch + size_t offsetLength = NCompactTrie::LeftOffsetLen(flags); + size_t leftOffset = NCompactTrie::UnpackOffset(dataPos, offsetLength); + leftSibling = leftOffset ? (nodeStart + leftOffset) : nullptr; + + dataPos += offsetLength; + } + + + { // Right branch + size_t offsetLength = NCompactTrie::RightOffsetLen(flags); + size_t rightOffset = NCompactTrie::UnpackOffset(dataPos, offsetLength); + rightSibling = rightOffset ? 
(nodeStart + rightOffset) : nullptr; + + dataPos += offsetLength; + } + + data = nullptr; + if (flags & NCompactTrie::MT_FINAL) { + data = dataPos; + } + directChild = nullptr; + if (flags & NCompactTrie::MT_NEXT) { + directChild = dataPos; + if (flags & NCompactTrie::MT_FINAL) { + directChild += packer.SkipLeaf(directChild); + } + } + + return label; +} + +Y_FORCE_INLINE bool Advance( + const char*& dataPos, + const char* const dataEnd, + char label +) { + if (dataPos == nullptr) { + return false; + } + + while (dataPos < dataEnd) { + size_t offsetLength, offset; + const char* startPos = dataPos; + char flags = *(dataPos++); + char symbol = *(dataPos++); + dataPos += sizeof(NCompactTrie::TSuffixLink); + + // Left branch + offsetLength = NCompactTrie::LeftOffsetLen(flags); + if ((unsigned char)label < (unsigned char)symbol) { + offset = NCompactTrie::UnpackOffset(dataPos, offsetLength); + if (!offset) + break; + + dataPos = startPos + offset; + continue; + } + + dataPos += offsetLength; + + // Right branch + offsetLength = NCompactTrie::RightOffsetLen(flags); + if ((unsigned char)label > (unsigned char)symbol) { + offset = NCompactTrie::UnpackOffset(dataPos, offsetLength); + if (!offset) + break; + + dataPos = startPos + offset; + continue; + } + + dataPos = startPos; + return true; + } + + // if we got here, we're past the dataend - bail out ASAP + dataPos = nullptr; + return false; +} + +} // anonymous + +template <class T, class D, class S> +void TCompactPatternSearcherBuilder<T, D, S>::CalculateSuffixLinks( + char* trieStart, + const char* trieEnd +) const { + struct TBfsElement { + char* Node; + const char* Parent; + + TBfsElement(char* node, const char* parent) + : Node(node) + , Parent(parent) + { + } + }; + + TDeque<TBfsElement> bfsQueue; + if (trieStart && trieStart != trieEnd) { + bfsQueue.emplace_back(trieStart, nullptr); + } + + while (!bfsQueue.empty()) { + auto front = bfsQueue.front(); + char* node = front.Node; + const char* parent = front.Parent; + bfsQueue.pop_front(); + + char* leftSibling; + char* rightSibling; + char* directChild; + NCompactTrie::TSuffixLink* suffixLink; + + char label = ReadNode<TData, TPacker>( + node, + leftSibling, + rightSibling, + directChild, + suffixLink + ); + + const char* suffix; + + if (parent == nullptr) { + suffix = node; + } else { + const char* parentOfSuffix = parent; + const char* temp; + do { + NCompactTrie::TSuffixLink parentOfSuffixSuffixLink; + + ReadNodeConst<TData, TPacker>( + parentOfSuffix, + /*left*/temp, + /*right*/temp, + /*direct*/temp, + /*data*/temp, + parentOfSuffixSuffixLink + ); + if (parentOfSuffixSuffixLink.NextSuffixOffset == 0) { + suffix = trieStart; + if (!Advance(suffix, trieEnd, label)) { + suffix = node; + } + break; + } + parentOfSuffix += parentOfSuffixSuffixLink.NextSuffixOffset; + + NCompactTrie::TSuffixLink tempSuffixLink; + ReadNodeConst<TData, TPacker>( + parentOfSuffix, + /*left*/temp, + /*right*/temp, + /*direct*/suffix, + /*data*/temp, + tempSuffixLink + ); + + if (suffix == nullptr) { + continue; + } + } while (!Advance(suffix, trieEnd, label)); + } + + suffixLink->NextSuffixOffset = suffix - node; + + NCompactTrie::TSuffixLink suffixSuffixLink; + const char* suffixData; + const char* temp; + ReadNodeConst<TData, TPacker>( + suffix, + /*left*/temp, + /*right*/temp, + /*direct*/temp, + suffixData, + suffixSuffixLink + ); + suffixLink->NextSuffixWithDataOffset = suffix - node; + if (suffixData == nullptr) { + suffixLink->NextSuffixWithDataOffset += suffixSuffixLink.NextSuffixWithDataOffset; + } + + 
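+        // Queue discipline: direct children go to the back of the queue with the current node
+        // as their parent, while left/right siblings go to the front and keep the current
+        // parent, so every label branching off one parent is processed before the traversal
+        // descends a level.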
if (directChild) { + bfsQueue.emplace_back(directChild, node); + } + + if (leftSibling) { + bfsQueue.emplace_front(leftSibling, parent); + } + + if (rightSibling) { + bfsQueue.emplace_front(rightSibling, parent); + } + } +} + + +template<class T, class D, class S> +TVector<TPatternMatch<D>> TCompactPatternSearcher<T, D, S>::SearchMatches( + const TSymbol* text, + size_t textSize +) const { + const char* temp; + NCompactTrie::TSuffixLink tempSuffixLink; + + const auto& trieData = Trie.Data(); + const char* trieStart = trieData.AsCharPtr(); + size_t dataSize = trieData.Length(); + const char* trieEnd = trieStart + dataSize; + + const char* lastNode = nullptr; + const char* currentSubtree = trieStart; + + TVector<TPatternMatch<TData>> matches; + + for (const TSymbol* position = text; position < text + textSize; ++position) { + TSymbol symbol = *position; + for (i64 i = (i64)NCompactTrie::ExtraBits<TSymbol>(); i >= 0; i -= 8) { + char label = (char)(symbol >> i); + + // Find first suffix extendable by label + while (true) { + const char* nextLastNode = currentSubtree; + if (Advance(nextLastNode, trieEnd, label)) { + lastNode = nextLastNode; + ReadNodeConst<TData, TPacker>( + lastNode, + /*left*/temp, + /*right*/temp, + currentSubtree, + /*data*/temp, + tempSuffixLink + ); + break; + } else { + if (lastNode == nullptr) { + break; + } + } + + NCompactTrie::TSuffixLink suffixLink; + ReadNodeConst<TData, TPacker>( + lastNode, + /*left*/temp, + /*right*/temp, + /*direct*/temp, + /*data*/temp, + suffixLink + ); + if (suffixLink.NextSuffixOffset == 0) { + lastNode = nullptr; + currentSubtree = trieStart; + continue; + } + lastNode += suffixLink.NextSuffixOffset; + ReadNodeConst<TData, TPacker>( + lastNode, + /*left*/temp, + /*right*/temp, + currentSubtree, + /*data*/temp, + tempSuffixLink + ); + } + + // Iterate through all suffixes + const char* suffix = lastNode; + while (suffix != nullptr) { + const char* nodeData; + NCompactTrie::TSuffixLink suffixLink; + ReadNodeConst<TData, TPacker>( + suffix, + /*left*/temp, + /*right*/temp, + /*direct*/temp, + nodeData, + suffixLink + ); + if (nodeData != nullptr) { + TData data; + Trie.GetPacker().UnpackLeaf(nodeData, data); + matches.emplace_back( + position - text, + data + ); + } + if (suffixLink.NextSuffixOffset == 0) { + break; + } + suffix += suffixLink.NextSuffixWithDataOffset; + } + } + } + + return matches; +} diff --git a/library/cpp/containers/comptrie/prefix_iterator.cpp b/library/cpp/containers/comptrie/prefix_iterator.cpp new file mode 100644 index 00000000000..5d4dfa3500a --- /dev/null +++ b/library/cpp/containers/comptrie/prefix_iterator.cpp @@ -0,0 +1 @@ +#include "prefix_iterator.h" diff --git a/library/cpp/containers/comptrie/prefix_iterator.h b/library/cpp/containers/comptrie/prefix_iterator.h new file mode 100644 index 00000000000..b369bb4f425 --- /dev/null +++ b/library/cpp/containers/comptrie/prefix_iterator.h @@ -0,0 +1,88 @@ +#pragma once + +#include "comptrie_trie.h" + +// Iterates over all prefixes of the given key in the trie. 
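// Editor's illustrative sketch (not part of the original header): enumerating the prefixes
// of a query that are present in a char-keyed trie via the iterator defined below. The ui32
// value type, the default packer and the function name are assumptions made for the example.
ui32 SumOfPrefixValues(const TCompactTrie<char, ui32>& trie, TStringBuf query) {
    ui32 sum = 0;
    for (auto it = MakePrefixIterator(trie, query.data(), query.size()); it; ++it) {
        ui32 value = 0;
        it.GetValue(value); // value stored under the prefix of length it.GetPrefixLen()
        sum += value;
    }
    return sum;
}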
+template <class TTrie> +class TPrefixIterator { +public: + using TSymbol = typename TTrie::TSymbol; + using TPacker = typename TTrie::TPacker; + using TData = typename TTrie::TData; + +private: + const TTrie& Trie; + const TSymbol* key; + size_t keylen; + const TSymbol* keyend; + size_t prefixLen; + const char* valuepos; + const char* datapos; + const char* dataend; + TPacker Packer; + const char* EmptyValue; + bool result; + + bool Next(); + +public: + TPrefixIterator(const TTrie& trie, const TSymbol* aKey, size_t aKeylen) + : Trie(trie) + , key(aKey) + , keylen(aKeylen) + , keyend(aKey + aKeylen) + , prefixLen(0) + , valuepos(nullptr) + , datapos(trie.DataHolder.AsCharPtr()) + , dataend(datapos + trie.DataHolder.Length()) + { + result = Next(); + } + + operator bool() const { + return result; + } + + TPrefixIterator& operator++() { + result = Next(); + return *this; + } + + size_t GetPrefixLen() const { + return prefixLen; + } + + void GetValue(TData& to) const { + Trie.Packer.UnpackLeaf(valuepos, to); + } +}; + +template <class TTrie> +bool TPrefixIterator<TTrie>::Next() { + using namespace NCompactTrie; + if (!key || datapos == dataend) + return false; + + if ((key == keyend - keylen) && !valuepos && Trie.EmptyValue) { + valuepos = Trie.EmptyValue; + return true; + } + + while (datapos && key != keyend) { + TSymbol label = *(key++); + if (!Advance(datapos, dataend, valuepos, label, Packer)) { + return false; + } + if (valuepos) { // There is a value at the end of this symbol. + prefixLen = keylen - (keyend - key); + return true; + } + } + + return false; +} + +template <class TTrie> +TPrefixIterator<TTrie> MakePrefixIterator(const TTrie& trie, const typename TTrie::TSymbol* key, size_t keylen) { + return TPrefixIterator<TTrie>(trie, key, keylen); +} diff --git a/library/cpp/containers/comptrie/protopacker.h b/library/cpp/containers/comptrie/protopacker.h new file mode 100644 index 00000000000..3e15866dc54 --- /dev/null +++ b/library/cpp/containers/comptrie/protopacker.h @@ -0,0 +1,29 @@ +#pragma once + +#include <util/stream/mem.h> +#include <util/ysaveload.h> + +template <class Proto> +class TProtoPacker { +public: + TProtoPacker() = default; + + void UnpackLeaf(const char* p, Proto& entry) const { + TMemoryInput in(p + sizeof(ui32), SkipLeaf(p) - sizeof(ui32)); + entry.ParseFromArcadiaStream(&in); + } + void PackLeaf(char* p, const Proto& entry, size_t size) const { + TMemoryOutput out(p, size + sizeof(ui32)); + Save<ui32>(&out, size); + entry.SerializeToArcadiaStream(&out); + } + size_t MeasureLeaf(const Proto& entry) const { + return entry.ByteSize() + sizeof(ui32); + } + size_t SkipLeaf(const char* p) const { + TMemoryInput in(p, sizeof(ui32)); + ui32 size; + Load<ui32>(&in, size); + return size; + } +}; diff --git a/library/cpp/containers/comptrie/search_iterator.cpp b/library/cpp/containers/comptrie/search_iterator.cpp new file mode 100644 index 00000000000..eb915235744 --- /dev/null +++ b/library/cpp/containers/comptrie/search_iterator.cpp @@ -0,0 +1 @@ +#include "search_iterator.h" diff --git a/library/cpp/containers/comptrie/search_iterator.h b/library/cpp/containers/comptrie/search_iterator.h new file mode 100644 index 00000000000..247f7e59363 --- /dev/null +++ b/library/cpp/containers/comptrie/search_iterator.h @@ -0,0 +1,140 @@ +#pragma once + +#include "comptrie_trie.h" +#include "first_symbol_iterator.h" + +#include <util/str_stl.h> +#include <util/digest/numeric.h> +#include <util/digest/multi.h> + +// Iterator for incremental searching. 
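// Editor's illustrative sketch (not part of the original header): checking symbol by symbol
// whether "foo" is a key of a char-keyed trie. The ui32 value type, the function name and
// TKeyBuf being TStringBuf for char tries are assumptions made for the example.
bool FindFoo(const TCompactTrie<char, ui32>& trie, ui32& valueOut) {
    auto it = MakeSearchIterator(trie);
    if (!it.Advance('f') || !it.Advance(TStringBuf("oo"))) {
        return false; // some prefix of "foo" is missing from the trie
    }
    return it.GetValue(&valueOut); // true only if the reached state is final, see the contract below
}
// The exact contract of Advance() and GetValue() is spelled out below.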
+// All Advance() methods shift the iterator by the specified key part or character.
+// Each subsequent Advance() call continues the search from the previous state.
+// Advance() returns 'true' if the specified key part exists in the trie and
+// 'false' otherwise; once it has returned 'false', all subsequent calls
+// also return 'false'.
+// If the current iterator state is final, GetValue() returns 'true' and the
+// associated value.
+
+template <class TTrie>
+class TSearchIterator {
+public:
+    using TData = typename TTrie::TData;
+    using TSymbol = typename TTrie::TSymbol;
+    using TKeyBuf = typename TTrie::TKeyBuf;
+
+    TSearchIterator() = default;
+
+    explicit TSearchIterator(const TTrie& trie)
+        : Trie(&trie)
+        , DataPos(trie.DataHolder.AsCharPtr())
+        , DataEnd(DataPos + trie.DataHolder.Length())
+        , ValuePos(trie.EmptyValue)
+    {
+    }
+
+    explicit TSearchIterator(const TTrie& trie, const TTrie& subTrie)
+        : Trie(&trie)
+        , DataPos(subTrie.Data().AsCharPtr())
+        , DataEnd(trie.DataHolder.AsCharPtr() + trie.DataHolder.Length())
+        , ValuePos(subTrie.EmptyValue)
+    {
+    }
+
+    bool operator==(const TSearchIterator& other) const {
+        Y_ASSERT(Trie && other.Trie);
+        return Trie == other.Trie &&
+               DataPos == other.DataPos &&
+               DataEnd == other.DataEnd &&
+               ValuePos == other.ValuePos;
+    }
+    bool operator!=(const TSearchIterator& other) const {
+        return !(*this == other);
+    }
+
+    inline bool Advance(TSymbol label) {
+        Y_ASSERT(Trie);
+        if (DataPos == nullptr || DataPos >= DataEnd) {
+            return false;
+        }
+        return NCompactTrie::Advance(DataPos, DataEnd, ValuePos, label, Trie->Packer);
+    }
+    inline bool Advance(const TKeyBuf& key) {
+        return Advance(key.data(), key.size());
+    }
+    bool Advance(const TSymbol* key, size_t keylen);
+    bool GetValue(TData* value = nullptr) const;
+    bool HasValue() const;
+    inline size_t GetHash() const;
+
+private:
+    const TTrie* Trie = nullptr;
+    const char* DataPos = nullptr;
+    const char* DataEnd = nullptr;
+    const char* ValuePos = nullptr;
+};
+
+template <class TTrie>
+inline TSearchIterator<TTrie> MakeSearchIterator(const TTrie& trie) {
+    return TSearchIterator<TTrie>(trie);
+}
+
+template <class TTrie>
+struct THash<TSearchIterator<TTrie>> {
+    inline size_t operator()(const TSearchIterator<TTrie>& item) {
+        return item.GetHash();
+    }
+};
+
+//----------------------------------------------------------------------------
+
+template <class TTrie>
+bool TSearchIterator<TTrie>::Advance(const TSymbol* key, size_t keylen) {
+    Y_ASSERT(Trie);
+    if (!key || DataPos == nullptr || DataPos >= DataEnd) {
+        return false;
+    }
+    if (!keylen) {
+        return true;
+    }
+
+    const TSymbol* keyend = key + keylen;
+    while (key != keyend && DataPos != nullptr) {
+        if (!NCompactTrie::Advance(DataPos, DataEnd, ValuePos, *(key++), Trie->Packer)) {
+            return false;
+        }
+        if (key == keyend) {
+            return true;
+        }
+    }
+    return false;
+}
+
+template <class TTrie>
+bool TSearchIterator<TTrie>::GetValue(TData* value) const {
+    Y_ASSERT(Trie);
+    bool result = false;
+    if (value) {
+        if (ValuePos) {
+            result = true;
+            Trie->Packer.UnpackLeaf(ValuePos, *value);
+        }
+    }
+    return result;
+}
+
+template <class TTrie>
+bool TSearchIterator<TTrie>::HasValue() const {
+    Y_ASSERT(Trie);
+    return ValuePos;
+}
+
+template <class TTrie>
+inline size_t TSearchIterator<TTrie>::GetHash() const {
+    Y_ASSERT(Trie);
+    return MultiHash(
+        static_cast<const void*>(Trie),
+        static_cast<const void*>(DataPos),
+        static_cast<const void*>(DataEnd),
+        static_cast<const void*>(ValuePos));
+}
diff --git 
a/library/cpp/containers/comptrie/set.h b/library/cpp/containers/comptrie/set.h new file mode 100644 index 00000000000..acd43338f0a --- /dev/null +++ b/library/cpp/containers/comptrie/set.h @@ -0,0 +1,40 @@ +#pragma once + +#include "comptrie_trie.h" + +template <typename T = char> +class TCompactTrieSet: public TCompactTrie<T, ui8, TNullPacker<ui8>> { +public: + typedef TCompactTrie<T, ui8, TNullPacker<ui8>> TBase; + + using typename TBase::TBuilder; + using typename TBase::TKey; + using typename TBase::TKeyBuf; + using typename TBase::TSymbol; + + TCompactTrieSet() = default; + + explicit TCompactTrieSet(const TBlob& data) + : TBase(data) + { + } + + template <typename D> + explicit TCompactTrieSet(const TCompactTrie<T, D, TNullPacker<D>>& trie) + : TBase(trie.Data()) // should be binary compatible for any D + { + } + + TCompactTrieSet(const char* data, size_t len) + : TBase(data, len) + { + } + + bool Has(const typename TBase::TKeyBuf& key) const { + return TBase::Find(key.data(), key.size()); + } + + bool FindTails(const typename TBase::TKeyBuf& key, TCompactTrieSet<T>& res) const { + return TBase::FindTails(key, res); + } +}; diff --git a/library/cpp/containers/comptrie/ut/ya.make b/library/cpp/containers/comptrie/ut/ya.make new file mode 100644 index 00000000000..c4f4666009a --- /dev/null +++ b/library/cpp/containers/comptrie/ut/ya.make @@ -0,0 +1,9 @@ +UNITTEST_FOR(library/cpp/containers/comptrie) + +OWNER(alzobnin) + +SRCS( + comptrie_ut.cpp +) + +END() diff --git a/library/cpp/containers/comptrie/write_trie_backwards.cpp b/library/cpp/containers/comptrie/write_trie_backwards.cpp new file mode 100644 index 00000000000..fd8c28b0ed9 --- /dev/null +++ b/library/cpp/containers/comptrie/write_trie_backwards.cpp @@ -0,0 +1,110 @@ +#include "write_trie_backwards.h" + +#include "comptrie_impl.h" +#include "leaf_skipper.h" + +#include <util/generic/buffer.h> +#include <util/generic/vector.h> + +namespace NCompactTrie { + size_t WriteTrieBackwards(IOutputStream& os, TReverseNodeEnumerator& enumerator, bool verbose) { + if (verbose) { + Cerr << "Writing down the trie..." << Endl; + } + + // Rewrite everything from the back, removing unused pieces. 
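+        // The enumerator yields nodes in reverse order, so every recreated node is copied in
+        // front of the data written so far; fixed-size chunks are allocated on demand and the
+        // chunk list is flushed back-to-front once enumeration is finished.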
+ const size_t chunksize = 0x10000; + TVector<char*> resultData; + + resultData.push_back(new char[chunksize]); + char* chunkend = resultData.back() + chunksize; + + size_t resultLength = 0; + size_t chunkLength = 0; + + size_t counter = 0; + TBuffer bufferHolder; + while (enumerator.Move()) { + if (verbose) + ShowProgress(++counter); + + size_t bufferLength = 64 + enumerator.GetLeafLength(); // never know how big leaf data can be + bufferHolder.Clear(); + bufferHolder.Resize(bufferLength); + char* buffer = bufferHolder.Data(); + + size_t nodelength = enumerator.RecreateNode(buffer, resultLength); + Y_ASSERT(nodelength <= bufferLength); + + resultLength += nodelength; + + if (chunkLength + nodelength <= chunksize) { + chunkLength += nodelength; + memcpy(chunkend - chunkLength, buffer, nodelength); + } else { // allocate a new chunk + memcpy(chunkend - chunksize, buffer + nodelength - (chunksize - chunkLength), chunksize - chunkLength); + chunkLength = chunkLength + nodelength - chunksize; + + resultData.push_back(new char[chunksize]); + chunkend = resultData.back() + chunksize; + + while (chunkLength > chunksize) { // allocate a new chunks + chunkLength -= chunksize; + memcpy(chunkend - chunksize, buffer + chunkLength, chunksize); + + resultData.push_back(new char[chunksize]); + chunkend = resultData.back() + chunksize; + } + + memcpy(chunkend - chunkLength, buffer, chunkLength); + } + } + + if (verbose) + Cerr << counter << Endl; + + // Write the whole thing down + while (!resultData.empty()) { + char* chunk = resultData.back(); + os.Write(chunk + chunksize - chunkLength, chunkLength); + chunkLength = chunksize; + delete[] chunk; + resultData.pop_back(); + } + + return resultLength; + } + + size_t WriteTrieBackwardsNoAlloc(IOutputStream& os, TReverseNodeEnumerator& enumerator, TOpaqueTrie& trie, EMinimizeMode mode) { + char* data = const_cast<char*>(trie.Data); + char* end = data + trie.Length; + char* pos = end; + + TVector<char> buf(64); + while (enumerator.Move()) { + size_t nodeLength = enumerator.RecreateNode(nullptr, end - pos); + if (nodeLength > buf.size()) + buf.resize(nodeLength); + + size_t realLength = enumerator.RecreateNode(buf.data(), end - pos); + Y_ASSERT(realLength == nodeLength); + + pos -= nodeLength; + memcpy(pos, buf.data(), nodeLength); + } + + switch (mode) { + case MM_NOALLOC: + os.Write(pos, end - pos); + break; + case MM_INPLACE: + memmove(data, pos, end - pos); + break; + default: + Y_VERIFY(false); + } + + return end - pos; + } + +} diff --git a/library/cpp/containers/comptrie/write_trie_backwards.h b/library/cpp/containers/comptrie/write_trie_backwards.h new file mode 100644 index 00000000000..634e6b811a4 --- /dev/null +++ b/library/cpp/containers/comptrie/write_trie_backwards.h @@ -0,0 +1,23 @@ +#pragma once + +#include "minimize.h" + +#include <util/generic/vector.h> +#include <util/stream/output.h> +#include <cstddef> + +namespace NCompactTrie { + class TReverseNodeEnumerator { + public: + virtual ~TReverseNodeEnumerator() = default; + virtual bool Move() = 0; + virtual size_t GetLeafLength() const = 0; + virtual size_t RecreateNode(char* buffer, size_t resultLength) = 0; + }; + + struct TOpaqueTrie; + + size_t WriteTrieBackwards(IOutputStream& os, TReverseNodeEnumerator& enumerator, bool verbose); + size_t WriteTrieBackwardsNoAlloc(IOutputStream& os, TReverseNodeEnumerator& enumerator, TOpaqueTrie& trie, EMinimizeMode mode); + +} diff --git a/library/cpp/containers/comptrie/writeable_node.cpp b/library/cpp/containers/comptrie/writeable_node.cpp new file 
mode 100644 index 00000000000..404003dbbd2 --- /dev/null +++ b/library/cpp/containers/comptrie/writeable_node.cpp @@ -0,0 +1,96 @@ +#include "writeable_node.h" +#include "node.h" +#include "comptrie_impl.h" + +namespace NCompactTrie { + TWriteableNode::TWriteableNode() + : LeafPos(nullptr) + , LeafLength(0) + , ForwardOffset(NPOS) + , LeftOffset(NPOS) + , RightOffset(NPOS) + , Label(0) + { + } + + static size_t GetOffsetFromEnd(const TNode& node, size_t absOffset) { + return absOffset ? absOffset - node.GetOffset() - node.GetCoreLength() : NPOS; + } + + TWriteableNode::TWriteableNode(const TNode& node, const char* data) + : LeafPos(node.IsFinal() ? data + node.GetLeafOffset() : nullptr) + , LeafLength(node.GetLeafLength()) + , ForwardOffset(GetOffsetFromEnd(node, node.GetForwardOffset())) + , LeftOffset(GetOffsetFromEnd(node, node.GetLeftOffset())) + , RightOffset(GetOffsetFromEnd(node, node.GetRightOffset())) + , Label(node.GetLabel()) + { + } + + size_t TWriteableNode::Measure() const { + size_t len = 2 + LeafLength; + size_t fwdLen = 0; + size_t lastLen = 0; + size_t lastFwdLen = 0; + // Now, increase all the offsets by the length and recalculate everything, until it converges + do { + lastLen = len; + lastFwdLen = fwdLen; + + len = 2 + LeafLength; + len += MeasureOffset(LeftOffset != NPOS ? LeftOffset + lastLen : 0); + len += MeasureOffset(RightOffset != NPOS ? RightOffset + lastLen : 0); + + // Relative forward offset of 0 means we don't need extra length for an epsilon link. + // But an epsilon link means we need an extra 1 for the flags and the forward offset is measured + // from the start of the epsilon link, not from the start of our node. + if (ForwardOffset != NPOS && ForwardOffset != 0) { + fwdLen = MeasureOffset(ForwardOffset + lastFwdLen) + 1; + len += fwdLen; + } + + } while (lastLen != len || lastFwdLen != fwdLen); + + return len; + } + + size_t TWriteableNode::Pack(char* buffer) const { + const size_t length = Measure(); + + char flags = 0; + if (LeafPos) { + flags |= MT_FINAL; + } + if (ForwardOffset != NPOS) { + flags |= MT_NEXT; + } + + const size_t leftOffset = LeftOffset != NPOS ? LeftOffset + length : 0; + const size_t rightOffset = RightOffset != NPOS ? 
RightOffset + length : 0; + const size_t leftOffsetSize = MeasureOffset(leftOffset); + const size_t rightOffsetSize = MeasureOffset(rightOffset); + flags |= (leftOffsetSize << MT_LEFTSHIFT); + flags |= (rightOffsetSize << MT_RIGHTSHIFT); + + buffer[0] = flags; + buffer[1] = Label; + size_t usedLen = 2; + usedLen += PackOffset(buffer + usedLen, leftOffset); + usedLen += PackOffset(buffer + usedLen, rightOffset); + + if (LeafPos && LeafLength) { + memcpy(buffer + usedLen, LeafPos, LeafLength); + usedLen += LeafLength; + } + + if (ForwardOffset != NPOS && ForwardOffset != 0) { + const size_t fwdOffset = ForwardOffset + length - usedLen; + size_t fwdOffsetSize = MeasureOffset(fwdOffset); + buffer[usedLen++] = (char)(fwdOffsetSize & MT_SIZEMASK); + usedLen += PackOffset(buffer + usedLen, fwdOffset); + } + Y_ASSERT(usedLen == length); + return usedLen; + } + +} diff --git a/library/cpp/containers/comptrie/writeable_node.h b/library/cpp/containers/comptrie/writeable_node.h new file mode 100644 index 00000000000..5454e579ef0 --- /dev/null +++ b/library/cpp/containers/comptrie/writeable_node.h @@ -0,0 +1,26 @@ +#pragma once + +#include <cstddef> + +namespace NCompactTrie { + class TNode; + + class TWriteableNode { + public: + const char* LeafPos; + size_t LeafLength; + + size_t ForwardOffset; + size_t LeftOffset; + size_t RightOffset; + char Label; + + TWriteableNode(); + TWriteableNode(const TNode& node, const char* data); + + // When you call this, the offsets should be relative to the end of the node. Use NPOS to indicate an absent offset. + size_t Pack(char* buffer) const; + size_t Measure() const; + }; + +} diff --git a/library/cpp/containers/comptrie/ya.make b/library/cpp/containers/comptrie/ya.make new file mode 100644 index 00000000000..81352da4b25 --- /dev/null +++ b/library/cpp/containers/comptrie/ya.make @@ -0,0 +1,35 @@ +LIBRARY() + +OWNER(velavokr) + +SRCS( + array_with_size.h + chunked_helpers_trie.h + comptrie.h + comptrie_packer.h + comptrie_trie.h + first_symbol_iterator.h + key_selector.h + leaf_skipper.h + set.h + comptrie.cpp + comptrie_builder.cpp + comptrie_impl.cpp + make_fast_layout.cpp + minimize.cpp + node.cpp + opaque_trie_iterator.cpp + prefix_iterator.cpp + search_iterator.cpp + write_trie_backwards.cpp + writeable_node.cpp +) + +PEERDIR( + library/cpp/packers + library/cpp/containers/compact_vector + library/cpp/on_disk/chunks + util/draft +) + +END() diff --git a/library/cpp/containers/disjoint_interval_tree/disjoint_interval_tree.cpp b/library/cpp/containers/disjoint_interval_tree/disjoint_interval_tree.cpp new file mode 100644 index 00000000000..7334a43c362 --- /dev/null +++ b/library/cpp/containers/disjoint_interval_tree/disjoint_interval_tree.cpp @@ -0,0 +1 @@ +#include "disjoint_interval_tree.h" diff --git a/library/cpp/containers/disjoint_interval_tree/disjoint_interval_tree.h b/library/cpp/containers/disjoint_interval_tree/disjoint_interval_tree.h new file mode 100644 index 00000000000..1f899c99913 --- /dev/null +++ b/library/cpp/containers/disjoint_interval_tree/disjoint_interval_tree.h @@ -0,0 +1,272 @@ +#pragma once + +#include <util/generic/map.h> +#include <util/system/yassert.h> + +#include <type_traits> + +template <class T> +class TDisjointIntervalTree { +private: + static_assert(std::is_integral<T>::value, "expect std::is_integral<T>::value"); + + using TTree = TMap<T, T>; // [key, value) + using TIterator = typename TTree::iterator; + using TConstIterator = typename TTree::const_iterator; + using TReverseIterator = typename TTree::reverse_iterator; + 
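+    // Each map entry {begin, end} represents the half-open interval [begin, end); the stored
+    // intervals are kept disjoint (adjacent ranges are merged on insertion), and NumElements
+    // caches the total number of integers covered by all intervals.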
using TThis = TDisjointIntervalTree<T>; + + TTree Tree; + size_t NumElements; + +public: + TDisjointIntervalTree() + : NumElements() + { + } + + void Insert(const T t) { + InsertInterval(t, t + 1); + } + + // we assume that none of elements from [begin, end) belong to tree. + void InsertInterval(const T begin, const T end) { + InsertIntervalImpl(begin, end); + NumElements += (size_t)(end - begin); + } + + bool Has(const T t) const { + return const_cast<TThis*>(this)->FindContaining(t) != Tree.end(); + } + + bool Intersects(const T begin, const T end) { + if (Empty()) { + return false; + } + + TIterator l = Tree.lower_bound(begin); + if (l != Tree.end()) { + if (l->first < end) { + return true; + } else if (l != Tree.begin()) { + --l; + return l->second > begin; + } else { + return false; + } + } else { + auto last = Tree.rbegin(); + return begin < last->second; + } + } + + TConstIterator FindContaining(const T t) const { + return const_cast<TThis*>(this)->FindContaining(t); + } + + // Erase element. Returns true when element has been deleted, otherwise false. + bool Erase(const T t) { + TIterator n = FindContaining(t); + if (n == Tree.end()) { + return false; + } + + --NumElements; + + T& begin = const_cast<T&>(n->first); + T& end = const_cast<T&>(n->second); + + // Optimization hack. + if (t == begin) { + if (++begin == end) { // OK to change key since intervals do not intersect. + Tree.erase(n); + return true; + } + + } else if (t == end - 1) { + --end; + + } else { + const T e = end; + end = t; + InsertIntervalImpl(t + 1, e); + } + + Y_ASSERT(begin < end); + return true; + } + + // Erase interval. Returns number of elements removed from set. + size_t EraseInterval(const T begin, const T end) { + Y_ASSERT(begin < end); + + if (Empty()) { + return 0; + } + + size_t elementsRemoved = 0; + + TIterator completelyRemoveBegin = Tree.lower_bound(begin); + if ((completelyRemoveBegin != Tree.end() && completelyRemoveBegin->first > begin && completelyRemoveBegin != Tree.begin()) + || completelyRemoveBegin == Tree.end()) { + // Look at the interval. It could contain [begin, end). + TIterator containingBegin = completelyRemoveBegin; + --containingBegin; + if (containingBegin->first < begin && begin < containingBegin->second) { // Contains begin. + if (containingBegin->second > end) { // Contains end. + const T prevEnd = containingBegin->second; + Y_ASSERT(containingBegin->second - begin <= NumElements); + + Y_ASSERT(containingBegin->second - containingBegin->first > end - begin); + containingBegin->second = begin; + InsertIntervalImpl(end, prevEnd); + + elementsRemoved = end - begin; + NumElements -= elementsRemoved; + return elementsRemoved; + } else { + elementsRemoved += containingBegin->second - begin; + containingBegin->second = begin; + } + } + } + + TIterator completelyRemoveEnd = completelyRemoveBegin != Tree.end() ? Tree.lower_bound(end) : Tree.end(); + if (completelyRemoveEnd != Tree.end() && completelyRemoveEnd != Tree.begin() && completelyRemoveEnd->first != end) { + TIterator containingEnd = completelyRemoveEnd; + --containingEnd; + if (containingEnd->second > end) { + T& leftBorder = const_cast<T&>(containingEnd->first); + + Y_ASSERT(leftBorder < end); + + --completelyRemoveEnd; // Don't remove the whole interval. + + // Optimization hack. + elementsRemoved += end - leftBorder; + leftBorder = end; // OK to change key since intervals do not intersect. 
+ } + } + + for (TIterator i = completelyRemoveBegin; i != completelyRemoveEnd; ++i) { + elementsRemoved += i->second - i->first; + } + + Tree.erase(completelyRemoveBegin, completelyRemoveEnd); + + Y_ASSERT(elementsRemoved <= NumElements); + NumElements -= elementsRemoved; + + return elementsRemoved; + } + + void Swap(TDisjointIntervalTree& rhv) { + Tree.swap(rhv.Tree); + std::swap(NumElements, rhv.NumElements); + } + + void Clear() { + Tree.clear(); + NumElements = 0; + } + + bool Empty() const { + return Tree.empty(); + } + + size_t GetNumElements() const { + return NumElements; + } + + size_t GetNumIntervals() const { + return Tree.size(); + } + + T Min() const { + Y_ASSERT(!Empty()); + return Tree.begin()->first; + } + + T Max() const { + Y_ASSERT(!Empty()); + return Tree.rbegin()->second; + } + + TConstIterator begin() const { + return Tree.begin(); + } + + TConstIterator end() const { + return Tree.end(); + } + +private: + void InsertIntervalImpl(const T begin, const T end) { + Y_ASSERT(begin < end); + Y_ASSERT(!Intersects(begin, end)); + + TIterator l = Tree.lower_bound(begin); + TIterator p = Tree.end(); + if (l != Tree.begin()) { + p = l; + --p; + } + +#ifndef NDEBUG + TIterator u = Tree.upper_bound(begin); + Y_VERIFY_DEBUG(u == Tree.end() || u->first >= end, "Trying to add [%" PRIu64 ", %" PRIu64 ") which intersects with existing [%" PRIu64 ", %" PRIu64 ")", begin, end, u->first, u->second); + Y_VERIFY_DEBUG(l == Tree.end() || l == u, "Trying to add [%" PRIu64 ", %" PRIu64 ") which intersects with existing [%" PRIu64 ", %" PRIu64 ")", begin, end, l->first, l->second); + Y_VERIFY_DEBUG(p == Tree.end() || p->second <= begin, "Trying to add [%" PRIu64 ", %" PRIu64 ") which intersects with existing [%" PRIu64 ", %" PRIu64 ")", begin, end, p->first, p->second); +#endif + + // try to extend interval + if (p != Tree.end() && p->second == begin) { + p->second = end; + //Try to merge 2 intervals - p and next one if possible + auto next = p; + // Next is not Tree.end() here. + ++next; + if (next != Tree.end() && next->first == end) { + p->second = next->second; + Tree.erase(next); + } + // Maybe new interval extends right interval + } else if (l != Tree.end() && end == l->first) { + T& leftBorder = const_cast<T&>(l->first); + // Optimization hack. + leftBorder = begin; // OK to change key since intervals do not intersect. + } else { + Tree.insert(std::make_pair(begin, end)); + } + } + + TIterator FindContaining(const T t) { + TIterator l = Tree.lower_bound(t); + if (l != Tree.end()) { + if (l->first == t) { + return l; + } + Y_ASSERT(l->first > t); + + if (l == Tree.begin()) { + return Tree.end(); + } + + --l; + Y_ASSERT(l->first != t); + + if (l->first < t && t < l->second) { + return l; + } + + } else if (!Tree.empty()) { // l is larger than Begin of any interval, but maybe it belongs to last interval? 
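+            // Only the right-most interval can still contain t at this point; note that for a
+            // reverse iterator r, (++r).base() is the forward iterator to the element r points at.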
+ TReverseIterator last = Tree.rbegin(); + Y_ASSERT(last->first != t); + + if (last->first < t && t < last->second) { + return (++last).base(); + } + } + return Tree.end(); + } +}; diff --git a/library/cpp/containers/disjoint_interval_tree/ut/disjoint_interval_tree_ut.cpp b/library/cpp/containers/disjoint_interval_tree/ut/disjoint_interval_tree_ut.cpp new file mode 100644 index 00000000000..8474ae89b04 --- /dev/null +++ b/library/cpp/containers/disjoint_interval_tree/ut/disjoint_interval_tree_ut.cpp @@ -0,0 +1,279 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include <library/cpp/containers/disjoint_interval_tree/disjoint_interval_tree.h> + +Y_UNIT_TEST_SUITE(DisjointIntervalTreeTest) { + Y_UNIT_TEST(GenericTest) { + TDisjointIntervalTree<ui64> tree; + tree.Insert(1); + tree.Insert(50); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 2); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 2); + + tree.InsertInterval(10, 30); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 3); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 22); + + UNIT_ASSERT_VALUES_EQUAL(tree.Min(), 1); + UNIT_ASSERT_VALUES_EQUAL(tree.Max(), 51); + + tree.Erase(20); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 4); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 21); + + tree.Clear(); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 0); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 0); + } + + Y_UNIT_TEST(MergeIntervalsTest) { + TDisjointIntervalTree<ui64> tree; + tree.Insert(5); + + // Insert interval from right side. + tree.Insert(6); + + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 1); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 2); + + { + auto begin = tree.begin(); + UNIT_ASSERT_VALUES_EQUAL(begin->first, 5); + UNIT_ASSERT_VALUES_EQUAL(begin->second, 7); + + ++begin; + UNIT_ASSERT_EQUAL(begin, tree.end()); + } + + // Insert interval from left side. + tree.InsertInterval(2, 5); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 1); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 5); + + { + auto begin = tree.begin(); + UNIT_ASSERT_VALUES_EQUAL(begin->first, 2); + UNIT_ASSERT_VALUES_EQUAL(begin->second, 7); + } + + // Merge all intervals. + { + TDisjointIntervalTree<ui64> tree; + tree.InsertInterval(0, 3); + tree.InsertInterval(6, 10); + tree.InsertInterval(3, 6); + + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 1); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 10); + + auto begin = tree.begin(); + UNIT_ASSERT_VALUES_EQUAL(begin->first, 0); + UNIT_ASSERT_VALUES_EQUAL(begin->second, 10); + } + + } + + Y_UNIT_TEST(EraseIntervalTest) { + // 1. Remove from empty tree. + { + TDisjointIntervalTree<ui64> tree; + + UNIT_ASSERT_VALUES_EQUAL(tree.EraseInterval(1, 3), 0); + } + + // 2. No such interval in set. + { + TDisjointIntervalTree<ui64> tree; + tree.InsertInterval(5, 10); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 1); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 5); + + UNIT_ASSERT_VALUES_EQUAL(tree.EraseInterval(1, 3), 0); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 1); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 5); + + UNIT_ASSERT_VALUES_EQUAL(tree.EraseInterval(20, 30), 0); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 1); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 5); + } + + // 3. Remove the whole tree. 
+ { + TDisjointIntervalTree<ui64> tree; + tree.InsertInterval(5, 10); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 1); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 5); + + UNIT_ASSERT_VALUES_EQUAL(tree.EraseInterval(0, 100), 5); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 0); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 0); + UNIT_ASSERT(tree.Empty()); + } + + // 4. Remove the whole tree with borders specified exactly as in tree. + { + TDisjointIntervalTree<ui64> tree; + tree.InsertInterval(5, 10); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 1); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 5); + + UNIT_ASSERT_VALUES_EQUAL(tree.EraseInterval(5, 10), 5); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 0); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 0); + UNIT_ASSERT(tree.Empty()); + } + + // 5. Specify left border exactly as in existing interval. + { + TDisjointIntervalTree<ui64> tree; + tree.InsertInterval(5, 10); + tree.InsertInterval(15, 20); + tree.InsertInterval(25, 30); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 3); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 15); + + UNIT_ASSERT_VALUES_EQUAL(tree.EraseInterval(15, 100500), 10); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 1); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 5); + } + + // 6. Specify left border somewhere in existing interval. + { + TDisjointIntervalTree<ui64> tree; + tree.InsertInterval(5, 10); + tree.InsertInterval(15, 20); + tree.InsertInterval(25, 30); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 3); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 15); + + UNIT_ASSERT_VALUES_EQUAL(tree.EraseInterval(16, 100500), 9); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 2); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 6); + } + + // 7. Remove from the center of existing interval. + { + TDisjointIntervalTree<ui64> tree; + tree.InsertInterval(5, 10); + tree.InsertInterval(15, 20); + tree.InsertInterval(25, 30); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 3); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 15); + + UNIT_ASSERT_VALUES_EQUAL(tree.EraseInterval(17, 19), 2); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 4); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 13); + + UNIT_ASSERT(tree.Has(16)); + UNIT_ASSERT(tree.Has(19)); + } + + // 8. Remove from the center of the only existing interval. + { + TDisjointIntervalTree<ui64> tree; + tree.InsertInterval(15, 20); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 1); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 5); + + UNIT_ASSERT_VALUES_EQUAL(tree.EraseInterval(17, 19), 2); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 2); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 3); + + UNIT_ASSERT(tree.Has(16)); + UNIT_ASSERT(tree.Has(19)); + } + + // 9. Specify borders between existing intervals. 
+ { + TDisjointIntervalTree<ui64> tree; + tree.InsertInterval(5, 10); + tree.InsertInterval(15, 20); + tree.InsertInterval(25, 30); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 3); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 15); + + UNIT_ASSERT_VALUES_EQUAL(tree.EraseInterval(10, 15), 0); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 3); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 15); + + UNIT_ASSERT_VALUES_EQUAL(tree.EraseInterval(13, 15), 0); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 3); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 15); + + UNIT_ASSERT_VALUES_EQUAL(tree.EraseInterval(10, 13), 0); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 3); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 15); + } + + // 10. Specify right border exactly as in existing interval. + { + TDisjointIntervalTree<ui64> tree; + tree.InsertInterval(5, 10); + tree.InsertInterval(15, 20); + tree.InsertInterval(25, 30); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 3); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 15); + + UNIT_ASSERT_VALUES_EQUAL(tree.EraseInterval(0, 20), 10); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 1); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 5); + } + + // 11. Specify right border somewhere in existing interval. + { + TDisjointIntervalTree<ui64> tree; + tree.InsertInterval(5, 10); + tree.InsertInterval(15, 20); + tree.InsertInterval(25, 30); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 3); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 15); + + UNIT_ASSERT_VALUES_EQUAL(tree.EraseInterval(2, 17), 7); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumIntervals(), 2); + UNIT_ASSERT_VALUES_EQUAL(tree.GetNumElements(), 8); + } + } + + Y_UNIT_TEST(IntersectsTest) { + { + TDisjointIntervalTree<ui64> tree; + UNIT_ASSERT(!tree.Intersects(1, 2)); + } + + { + TDisjointIntervalTree<ui64> tree; + tree.InsertInterval(5, 10); + + UNIT_ASSERT(tree.Intersects(5, 10)); + UNIT_ASSERT(tree.Intersects(5, 6)); + UNIT_ASSERT(tree.Intersects(9, 10)); + UNIT_ASSERT(tree.Intersects(6, 8)); + UNIT_ASSERT(tree.Intersects(1, 8)); + UNIT_ASSERT(tree.Intersects(8, 15)); + UNIT_ASSERT(tree.Intersects(3, 14)); + + UNIT_ASSERT(!tree.Intersects(3, 5)); + UNIT_ASSERT(!tree.Intersects(10, 13)); + } + + { + TDisjointIntervalTree<ui64> tree; + tree.InsertInterval(5, 10); + tree.InsertInterval(20, 30); + + UNIT_ASSERT(tree.Intersects(5, 10)); + UNIT_ASSERT(tree.Intersects(5, 6)); + UNIT_ASSERT(tree.Intersects(9, 10)); + UNIT_ASSERT(tree.Intersects(6, 8)); + UNIT_ASSERT(tree.Intersects(1, 8)); + UNIT_ASSERT(tree.Intersects(8, 15)); + UNIT_ASSERT(tree.Intersects(3, 14)); + UNIT_ASSERT(tree.Intersects(18, 21)); + UNIT_ASSERT(tree.Intersects(3, 50)); + + UNIT_ASSERT(!tree.Intersects(3, 5)); + UNIT_ASSERT(!tree.Intersects(10, 13)); + UNIT_ASSERT(!tree.Intersects(15, 18)); + } + } +} diff --git a/library/cpp/containers/disjoint_interval_tree/ut/ya.make b/library/cpp/containers/disjoint_interval_tree/ut/ya.make new file mode 100644 index 00000000000..6736ce0c2bd --- /dev/null +++ b/library/cpp/containers/disjoint_interval_tree/ut/ya.make @@ -0,0 +1,12 @@ +UNITTEST_FOR(library/cpp/containers/disjoint_interval_tree) + +OWNER( + dcherednik + galaxycrab +) + +SRCS( + disjoint_interval_tree_ut.cpp +) + +END() diff --git a/library/cpp/containers/disjoint_interval_tree/ya.make b/library/cpp/containers/disjoint_interval_tree/ya.make new file mode 100644 index 00000000000..b4f5a52a67f --- /dev/null +++ b/library/cpp/containers/disjoint_interval_tree/ya.make @@ -0,0 +1,10 @@ 
+OWNER( + dcherednik + galaxycrab +) + +LIBRARY() + +SRCS(disjoint_interval_tree.cpp) + +END() diff --git a/library/cpp/containers/flat_hash/benchmark/flat_hash_benchmark.cpp b/library/cpp/containers/flat_hash/benchmark/flat_hash_benchmark.cpp new file mode 100644 index 00000000000..040cff3fffa --- /dev/null +++ b/library/cpp/containers/flat_hash/benchmark/flat_hash_benchmark.cpp @@ -0,0 +1,180 @@ +#include <library/cpp/containers/flat_hash/flat_hash.h> + +#include <library/cpp/containers/dense_hash/dense_hash.h> +#include <library/cpp/testing/benchmark/bench.h> + +#include <util/random/random.h> +#include <util/generic/xrange.h> +#include <util/generic/hash.h> + +namespace { + +template <class Map, size_t elemCount, class... Args> +void RunLookupPositiveScalarKeysBench(::NBench::NCpu::TParams& iface, Args&&... args) { + using key_type = i32; + static_assert(std::is_same_v<typename Map::key_type, key_type>); + Map hm(std::forward<Args>(args)...); + + TVector<i32> keys(elemCount); + for (auto& k : keys) { + k = RandomNumber<ui32>(std::numeric_limits<i32>::max()); + hm.emplace(k, 0); + } + + for (const auto i : xrange(iface.Iterations())) { + Y_UNUSED(i); + for (const auto& k : keys) { + Y_DO_NOT_OPTIMIZE_AWAY(hm[k]); + } + } +} + +constexpr size_t TEST1_ELEM_COUNT = 10; +constexpr size_t TEST2_ELEM_COUNT = 1000; +constexpr size_t TEST3_ELEM_COUNT = 1000000; + +} + +/* *********************************** TEST1 *********************************** + * Insert TEST1_ELEM_COUNT positive integers and than make lookup. + * No init size provided for tables. + * key_type - i32 + */ + +Y_CPU_BENCHMARK(Test1_fh_TFlatHashMap_LinearProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TFlatHashMap<i32, int>, TEST1_ELEM_COUNT>(iface); +} + +/* +Y_CPU_BENCHMARK(Test1_fh_TFlatHashMap_QuadraticProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TFlatHashMap<i32, int, THash<i32>, + std::equal_to<i32>, NFlatHash::TQuadraticProbing>, TEST1_ELEM_COUNT>(iface); +} +*/ + +Y_CPU_BENCHMARK(Test1_fh_TFlatHashMap_DenseProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TFlatHashMap<i32, int, THash<i32>, + std::equal_to<i32>, NFlatHash::TDenseProbing>, TEST1_ELEM_COUNT>(iface); +} + + +Y_CPU_BENCHMARK(Test1_fh_TDenseHashMapStaticMarker_LinearProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TDenseHashMapStaticMarker<i32, int, -1, THash<i32>, + std::equal_to<i32>, NFlatHash::TLinearProbing>, TEST1_ELEM_COUNT>(iface); +} + +/* +Y_CPU_BENCHMARK(Test1_fh_TDenseHashMapStaticMarker_QuadraticProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TDenseHashMapStaticMarker<i32, int, -1, THash<i32>, + std::equal_to<i32>, NFlatHash::TQuadraticProbing>, TEST1_ELEM_COUNT>(iface); +} +*/ + +Y_CPU_BENCHMARK(Test1_fh_TDenseHashMapStaticMarker_DenseProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TDenseHashMapStaticMarker<i32, int, -1>, TEST1_ELEM_COUNT>(iface); +} + + +Y_CPU_BENCHMARK(Test1_foreign_TDenseHash, iface) { + RunLookupPositiveScalarKeysBench<TDenseHash<i32, int>, TEST1_ELEM_COUNT>(iface, (i32)-1); +} + +Y_CPU_BENCHMARK(Test1_foreign_THashMap, iface) { + RunLookupPositiveScalarKeysBench<THashMap<i32, int>, TEST1_ELEM_COUNT>(iface); +} + +/* *********************************** TEST2 *********************************** + * Insert TEST2_ELEM_COUNT positive integers and than make lookup. + * No init size provided for tables. 
+ * key_type - i32 + */ + +Y_CPU_BENCHMARK(Test2_fh_TFlatHashMap_LinearProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TFlatHashMap<i32, int>, TEST2_ELEM_COUNT>(iface); +} + +/* +Y_CPU_BENCHMARK(Test2_fh_TFlatHashMap_QuadraticProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TFlatHashMap<i32, int, THash<i32>, + std::equal_to<i32>, NFlatHash::TQuadraticProbing>, TEST2_ELEM_COUNT>(iface); +} +*/ + +Y_CPU_BENCHMARK(Test2_fh_TFlatHashMap_DenseProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TFlatHashMap<i32, int, THash<i32>, + std::equal_to<i32>, NFlatHash::TDenseProbing>, TEST2_ELEM_COUNT>(iface); +} + + +Y_CPU_BENCHMARK(Test2_fh_TDenseHashMapStaticMarker_LinearProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TDenseHashMapStaticMarker<i32, int, -1, THash<i32>, + std::equal_to<i32>, NFlatHash::TLinearProbing>, TEST2_ELEM_COUNT>(iface); +} + +/* +Y_CPU_BENCHMARK(Test2_fh_TDenseHashMapStaticMarker_QuadraticProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TDenseHashMapStaticMarker<i32, int, -1, THash<i32>, + std::equal_to<i32>, NFlatHash::TQuadraticProbing>, TEST2_ELEM_COUNT>(iface); +} +*/ + +Y_CPU_BENCHMARK(Test2_fh_TDenseHashMapStaticMarker_DenseProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TDenseHashMapStaticMarker<i32, int, -1>, TEST2_ELEM_COUNT>(iface); +} + + +Y_CPU_BENCHMARK(Test2_foreign_TDenseHash, iface) { + RunLookupPositiveScalarKeysBench<TDenseHash<i32, int>, TEST2_ELEM_COUNT>(iface, (i32)-1); +} + +Y_CPU_BENCHMARK(Test2_foreign_THashMap, iface) { + RunLookupPositiveScalarKeysBench<THashMap<i32, int>, TEST2_ELEM_COUNT>(iface); +} + +/* *********************************** TEST3 *********************************** + * Insert TEST3_ELEM_COUNT positive integers and then make lookup. + * No init size provided for tables.
+ * key_type - i32 + */ + +Y_CPU_BENCHMARK(Test3_fh_TFlatHashMap_LinearProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TFlatHashMap<i32, int>, TEST3_ELEM_COUNT>(iface); +} + +/* +Y_CPU_BENCHMARK(Test3_fh_TFlatHashMap_QuadraticProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TFlatHashMap<i32, int, THash<i32>, + std::equal_to<i32>, NFlatHash::TQuadraticProbing>, TEST3_ELEM_COUNT>(iface); +} +*/ + +Y_CPU_BENCHMARK(Test3_fh_TFlatHashMap_DenseProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TFlatHashMap<i32, int, THash<i32>, + std::equal_to<i32>, NFlatHash::TDenseProbing>, TEST3_ELEM_COUNT>(iface); +} + + +Y_CPU_BENCHMARK(Test3_fh_TDenseHashMapStaticMarker_LinearProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TDenseHashMapStaticMarker<i32, int, -1, THash<i32>, + std::equal_to<i32>, NFlatHash::TLinearProbing>, TEST3_ELEM_COUNT>(iface); +} + +/* +Y_CPU_BENCHMARK(Test3_fh_TDenseHashMapStaticMarker_QuadraticProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TDenseHashMapStaticMarker<i32, int, -1, THash<i32>, + std::equal_to<i32>, NFlatHash::TQuadraticProbing>, TEST3_ELEM_COUNT>(iface); +} +*/ + +Y_CPU_BENCHMARK(Test3_fh_TDenseHashMapStaticMarker_DenseProbing, iface) { + RunLookupPositiveScalarKeysBench<NFH::TDenseHashMapStaticMarker<i32, int, -1>, TEST3_ELEM_COUNT>(iface); +} + + +Y_CPU_BENCHMARK(Test3_foreign_TDenseHash, iface) { + RunLookupPositiveScalarKeysBench<TDenseHash<i32, int>, TEST3_ELEM_COUNT>(iface, (i32)-1); +} + +Y_CPU_BENCHMARK(Test3_foreign_THashMap, iface) { + RunLookupPositiveScalarKeysBench<THashMap<i32, int>, TEST3_ELEM_COUNT>(iface); +} diff --git a/library/cpp/containers/flat_hash/benchmark/ya.make b/library/cpp/containers/flat_hash/benchmark/ya.make new file mode 100644 index 00000000000..6f9aedf50d5 --- /dev/null +++ b/library/cpp/containers/flat_hash/benchmark/ya.make @@ -0,0 +1,13 @@ +Y_BENCHMARK() + +OWNER(tender-bum) + +SRCS( + flat_hash_benchmark.cpp +) + +PEERDIR( + library/cpp/containers/flat_hash +) + +END() diff --git a/library/cpp/containers/flat_hash/flat_hash.cpp b/library/cpp/containers/flat_hash/flat_hash.cpp new file mode 100644 index 00000000000..2a16398bd48 --- /dev/null +++ b/library/cpp/containers/flat_hash/flat_hash.cpp @@ -0,0 +1 @@ +#include "flat_hash.h" diff --git a/library/cpp/containers/flat_hash/flat_hash.h b/library/cpp/containers/flat_hash/flat_hash.h new file mode 100644 index 00000000000..582b8ae8f5d --- /dev/null +++ b/library/cpp/containers/flat_hash/flat_hash.h @@ -0,0 +1,120 @@ +#pragma once + +#include <library/cpp/containers/flat_hash/lib/map.h> +#include <library/cpp/containers/flat_hash/lib/containers.h> +#include <library/cpp/containers/flat_hash/lib/probings.h> +#include <library/cpp/containers/flat_hash/lib/set.h> +#include <library/cpp/containers/flat_hash/lib/size_fitters.h> +#include <library/cpp/containers/flat_hash/lib/expanders.h> + +#include <util/str_stl.h> + +namespace NPrivate { + +template <class Key, class T, class Hash, class KeyEqual, class Probing, class Alloc> +using TFlatHashMapImpl = NFlatHash::TMap<Key, T, Hash, KeyEqual, + NFlatHash::TFlatContainer<std::pair<const Key, T>, Alloc>, + Probing, NFlatHash::TAndSizeFitter, + NFlatHash::TSimpleExpander>; + +template <class Key, class T, auto emptyMarker, class Hash, class KeyEqual, class Probing, class Alloc> +using TDenseHashMapImpl = + NFlatHash::TMap<Key, T, Hash, KeyEqual, + NFlatHash::TDenseContainer<std::pair<const Key, T>, + NFlatHash::NMap::TStaticValueMarker<emptyMarker, T>, + Alloc>, + Probing, 
NFlatHash::TAndSizeFitter, NFlatHash::TSimpleExpander>; + + +template <class T, class Hash, class KeyEqual, class Probing, class Alloc> +using TFlatHashSetImpl = NFlatHash::TSet<T, Hash, KeyEqual, + NFlatHash::TFlatContainer<T, Alloc>, + Probing, NFlatHash::TAndSizeFitter, + NFlatHash::TSimpleExpander>; + +template <class T, auto emptyMarker, class Hash, class KeyEqual, class Probing, class Alloc> +using TDenseHashSetImpl = + NFlatHash::TSet<T, Hash, KeyEqual, + NFlatHash::TDenseContainer<T, NFlatHash::NSet::TStaticValueMarker<emptyMarker>, Alloc>, + Probing, NFlatHash::TAndSizeFitter, NFlatHash::TSimpleExpander>; + +} // namespace NPrivate + +namespace NFH { + +/* flat_map: Fast and highly customizable hash map. + * + * Most features would be available soon. + * Until that time we strongly insist on using only class aliases listed below. + */ + +/* Simpliest open addressing hash map. + * Uses additional array to denote status of every bucket. + * Default probing is linear. + * Currently available probings: + * * TLinearProbing + * * TQuadraticProbing + * * TDenseProbing + */ +template <class Key, + class T, + class Hash = THash<Key>, + class KeyEqual = std::equal_to<>, + class Probing = NFlatHash::TLinearProbing, + class Alloc = std::allocator<std::pair<const Key, T>>> +using TFlatHashMap = NPrivate::TFlatHashMapImpl<Key, T, Hash, KeyEqual, Probing, Alloc>; + +/* Open addressing table with user specified marker for empty buckets. + * Currently available probings: + * * TLinearProbing + * * TQuadraticProbing + * * TDenseProbing + */ +template <class Key, + class T, + auto emptyMarker, + class Hash = THash<Key>, + class KeyEqual = std::equal_to<>, + class Probing = NFlatHash::TDenseProbing, + class Alloc = std::allocator<std::pair<const Key, T>>> +using TDenseHashMapStaticMarker = NPrivate::TDenseHashMapImpl<Key, T, emptyMarker, + Hash, KeyEqual, Probing, Alloc>; + + +/* flat_set: Fast and highly customizable hash set. + * + * Most features would be available soon. + * Until that time we strongly insist on using only class aliases listed below. + */ + +/* Simpliest open addressing hash map. + * Uses additional array to denote status of every bucket. + * Default probing is linear. + * Currently available probings: + * * TLinearProbing + * * TQuadraticProbing + * * TDenseProbing + */ +template <class T, + class Hash = THash<T>, + class KeyEqual = std::equal_to<>, + class Probing = NFlatHash::TLinearProbing, + class Alloc = std::allocator<T>> +using TFlatHashSet = NPrivate::TFlatHashSetImpl<T, Hash, KeyEqual, Probing, Alloc>; + +/* Open addressing table with user specified marker for empty buckets. 
+ * Currently available probings: + * * TLinearProbing + * * TQuadraticProbing + * * TDenseProbing + */ +template <class T, + auto emptyMarker, + class Hash = THash<T>, + class KeyEqual = std::equal_to<>, + class Probing = NFlatHash::TDenseProbing, + class Alloc = std::allocator<T>> +using TDenseHashSetStaticMarker = NPrivate::TDenseHashSetImpl<T, emptyMarker, + Hash, KeyEqual, Probing, Alloc>; + +} // namespace NFH diff --git a/library/cpp/containers/flat_hash/lib/concepts/concepts.cpp b/library/cpp/containers/flat_hash/lib/concepts/concepts.cpp new file mode 100644 index 00000000000..63eed9acdd8 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/concepts/concepts.cpp @@ -0,0 +1,4 @@ +#include "container.h" +#include "iterator.h" +#include "size_fitter.h" +#include "value_marker.h" diff --git a/library/cpp/containers/flat_hash/lib/concepts/container.h b/library/cpp/containers/flat_hash/lib/concepts/container.h new file mode 100644 index 00000000000..eac1803b59e --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/concepts/container.h @@ -0,0 +1,66 @@ +#pragma once + +#include <type_traits> + +/* Concepts: + * Container + * RemovalContainer + */ +namespace NFlatHash::NConcepts { + +#define DCV(type) std::declval<type>() +#define DCT(object) decltype(object) + +template <class T, class = void> +struct Container : std::false_type {}; + +template <class T> +struct Container<T, std::void_t< + typename T::value_type, + typename T::size_type, + typename T::difference_type, + DCT(DCV(T).Node(DCV(typename T::size_type))), + DCT(DCV(const T).Node(DCV(typename T::size_type))), + DCT(DCV(const T).Size()), + DCT(DCV(const T).Taken()), + DCT(DCV(const T).Empty()), + DCT(DCV(const T).IsEmpty(DCV(typename T::size_type))), + DCT(DCV(const T).IsTaken(DCV(typename T::size_type))), + DCT(DCV(T).Swap(DCV(T&))), + DCT(DCV(const T).Clone(DCV(typename T::size_type)))>> + : std::conjunction<std::is_same<DCT(DCV(T).Node(DCV(typename T::size_type))), + typename T::value_type&>, + std::is_same<DCT(DCV(const T).Node(DCV(typename T::size_type))), + const typename T::value_type&>, + std::is_same<DCT(DCV(const T).Size()), typename T::size_type>, + std::is_same<DCT(DCV(const T).Taken()), typename T::size_type>, + std::is_same<DCT(DCV(const T).Empty()), typename T::size_type>, + std::is_same<DCT(DCV(const T).IsEmpty(DCV(typename T::size_type))), bool>, + std::is_same<DCT(DCV(const T).IsTaken(DCV(typename T::size_type))), bool>, + std::is_same<DCT(DCV(const T).Clone(DCV(typename T::size_type))), T>, + std::is_copy_constructible<T>, + std::is_move_constructible<T>, + std::is_copy_assignable<T>, + std::is_move_assignable<T>> {}; + +template <class T> +constexpr bool ContainerV = Container<T>::value; + +template <class T, class = void> +struct RemovalContainer : std::false_type {}; + +template <class T> +struct RemovalContainer<T, std::void_t< + DCT(DCV(T).DeleteNode(DCV(typename T::size_type))), + DCT(DCV(const T).IsDeleted(DCV(typename T::size_type)))>> + : std::conjunction<Container<T>, + std::is_same<DCT(DCV(const T).IsDeleted(DCV(typename T::size_type))), + bool>> {}; + +template <class T> +constexpr bool RemovalContainerV = RemovalContainer<T>::value; + +#undef DCV +#undef DCT + +} // namespace NFlatHash::NConcepts diff --git a/library/cpp/containers/flat_hash/lib/concepts/iterator.h b/library/cpp/containers/flat_hash/lib/concepts/iterator.h new file mode 100644 index 00000000000..b9c1c24c827 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/concepts/iterator.h @@ -0,0 +1,20 @@ +#pragma once + +#include 
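The comments above document the public aliases but do not show a call site. A short usage sketch of both map flavours (the concrete keys and values are mine; the one constraint implied by the dense design is that the emptyMarker key, -1 here, must never be inserted as a real key):

#include <library/cpp/containers/flat_hash/flat_hash.h>

#include <util/generic/string.h>

void FlatHashAliasesSketch() {
    // Flat map: a per-bucket status array tracks emptiness, so any key value is allowed.
    NFH::TFlatHashMap<i32, TString> names;
    names.emplace(1, "one");
    names[2] = "two";

    // Dense map: emptiness is encoded by the marker key, so -1 is reserved.
    NFH::TDenseHashMapStaticMarker<i32, int, -1> dense;
    dense.emplace(7, 777);

    // The set aliases follow the same pattern.
    NFH::TFlatHashSet<i32> seen;
    seen.insert(42);
}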
<iterator> + +/* Concepts: + * Iterator + */ +namespace NFlatHash::NConcepts { + +template <class T, class = void> +struct Iterator : std::false_type {}; + +template <class T> +struct Iterator<T, std::void_t<typename std::iterator_traits<T>::iterator_category>> + : std::true_type {}; + +template <class T> +constexpr bool IteratorV = Iterator<T>::value; + +} // namespace NFlatHash::NConcepts diff --git a/library/cpp/containers/flat_hash/lib/concepts/size_fitter.h b/library/cpp/containers/flat_hash/lib/concepts/size_fitter.h new file mode 100644 index 00000000000..83d1d313049 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/concepts/size_fitter.h @@ -0,0 +1,34 @@ +#pragma once + +#include <type_traits> + +/* Concepts: + * SizeFitter + */ +namespace NFlatHash::NConcepts { + +#define DCV(type) std::declval<type>() +#define DCT(object) decltype(object) + +template <class T, class = void> +struct SizeFitter : std::false_type {}; + +template <class T> +struct SizeFitter<T, std::void_t< + DCT(DCV(const T).EvalIndex(DCV(size_t), DCV(size_t))), + DCT(DCV(const T).EvalSize(DCV(size_t))), + DCT(DCV(T).Update(DCV(size_t)))>> + : std::conjunction<std::is_same<DCT(DCV(const T).EvalIndex(DCV(size_t), DCV(size_t))), size_t>, + std::is_same<DCT(DCV(const T).EvalSize(DCV(size_t))), size_t>, + std::is_copy_constructible<T>, + std::is_move_constructible<T>, + std::is_copy_assignable<T>, + std::is_move_assignable<T>> {}; + +template <class T> +constexpr bool SizeFitterV = SizeFitter<T>::value; + +#undef DCV +#undef DCT + +} // namespace NFlatHash::NConcepts diff --git a/library/cpp/containers/flat_hash/lib/concepts/value_marker.h b/library/cpp/containers/flat_hash/lib/concepts/value_marker.h new file mode 100644 index 00000000000..9d1e9b210a7 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/concepts/value_marker.h @@ -0,0 +1,34 @@ +#pragma once + +#include <type_traits> + +/* Concepts: + * ValueMarker + */ +namespace NFlatHash::NConcepts { + +#define DCV(type) std::declval<type>() +#define DCT(object) decltype(object) + +template <class T, class = void> +struct ValueMarker : std::false_type {}; + +template <class T> +struct ValueMarker<T, std::void_t< + typename T::value_type, + DCT(DCV(const T).Create()), + DCT(DCV(const T).Equals(DCV(const typename T::value_type&)))>> + : std::conjunction<std::is_constructible<typename T::value_type, DCT(DCV(const T).Create())>, + std::is_same<DCT(DCV(const T).Equals(DCV(const typename T::value_type&))), bool>, + std::is_copy_constructible<T>, + std::is_move_constructible<T>, + std::is_copy_assignable<T>, + std::is_move_assignable<T>> {}; + +template <class T> +constexpr bool ValueMarkerV = ValueMarker<T>::value; + +#undef DCV +#undef DCT + +} // namespace NFlatHash::NConcepts diff --git a/library/cpp/containers/flat_hash/lib/concepts/ya.make b/library/cpp/containers/flat_hash/lib/concepts/ya.make new file mode 100644 index 00000000000..f82fc1d51c1 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/concepts/ya.make @@ -0,0 +1,9 @@ +LIBRARY() + +OWNER(tender-bum) + +SRCS( + concepts.cpp +) + +END() diff --git a/library/cpp/containers/flat_hash/lib/containers.cpp b/library/cpp/containers/flat_hash/lib/containers.cpp new file mode 100644 index 00000000000..0853c23fc1b --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/containers.cpp @@ -0,0 +1 @@ +#include "containers.h" diff --git a/library/cpp/containers/flat_hash/lib/containers.h b/library/cpp/containers/flat_hash/lib/containers.h new file mode 100644 index 00000000000..82008f2f9cf --- /dev/null 
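All four trait headers follow the same C++17 detection idiom: the primary template is std::false_type, and a partial specialization guarded by std::void_t lists every expression the concept requires, then checks the result types with std::is_same. They are meant to be consumed through static_assert, as containers.h does below. A small sketch (TMinusOneMarker is an illustrative type of mine, not part of the library):

#include <library/cpp/containers/flat_hash/lib/concepts/iterator.h>
#include <library/cpp/containers/flat_hash/lib/concepts/value_marker.h>

#include <util/generic/vector.h>

// Iterator detection: anything for which std::iterator_traits names a category.
static_assert(NFlatHash::NConcepts::IteratorV<TVector<int>::iterator>);
static_assert(NFlatHash::NConcepts::IteratorV<int*>);
static_assert(!NFlatHash::NConcepts::IteratorV<int>);

// ValueMarker detection: the marker can Create() an "empty" value and test for it.
struct TMinusOneMarker {
    using value_type = int;
    value_type Create() const { return -1; }
    bool Equals(const value_type& v) const { return v == -1; }
};
static_assert(NFlatHash::NConcepts::ValueMarkerV<TMinusOneMarker>);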
+++ b/library/cpp/containers/flat_hash/lib/containers.h @@ -0,0 +1,314 @@ +#pragma once + +#include "concepts/container.h" +#include "value_markers.h" + +#include <util/system/yassert.h> + +#include <util/generic/vector.h> +#include <util/generic/typetraits.h> +#include <util/generic/utility.h> + +#include <optional> + +namespace NFlatHash { + +/* FLAT CONTAINER */ + +template <class T, class Alloc = std::allocator<T>> +class TFlatContainer { +public: + using value_type = T; + using size_type = size_t; + using difference_type = ptrdiff_t; + using allocator_type = Alloc; + using pointer = typename std::allocator_traits<allocator_type>::pointer; + using const_pointer = typename std::allocator_traits<allocator_type>::const_pointer; + +private: + class TCage { + enum ENodeStatus { + NS_EMPTY, + NS_TAKEN, + NS_DELETED + }; + + public: + TCage() noexcept = default; + + TCage(const TCage&) = default; + TCage(TCage&&) = default; + + TCage& operator=(const TCage& rhs) { + switch (rhs.Status_) { + case NS_TAKEN: + if constexpr (std::is_copy_assignable_v<value_type>) { + Value_ = rhs.Value_; + } else { + Value_.emplace(rhs.Value()); + } + break; + case NS_EMPTY: + case NS_DELETED: + if (Value_.has_value()) { + Value_.reset(); + } + break; + default: + Y_VERIFY(false, "Not implemented"); + } + Status_ = rhs.Status_; + return *this; + } + // We never call it since all the TCage's are stored in vector + TCage& operator=(TCage&& rhs) = delete; + + template <class... Args> + void Emplace(Args&&... args) { + Y_ASSERT(Status_ == NS_EMPTY); + Value_.emplace(std::forward<Args>(args)...); + Status_ = NS_TAKEN; + } + + void Reset() noexcept { + Y_ASSERT(Status_ == NS_TAKEN); + Value_.reset(); + Status_ = NS_DELETED; + } + + value_type& Value() { + Y_ASSERT(Status_ == NS_TAKEN); + return *Value_; + } + + const value_type& Value() const { + Y_ASSERT(Status_ == NS_TAKEN); + return *Value_; + } + + bool IsEmpty() const noexcept { return Status_ == NS_EMPTY; } + bool IsTaken() const noexcept { return Status_ == NS_TAKEN; } + bool IsDeleted() const noexcept { return Status_ == NS_DELETED; } + + ENodeStatus Status() const noexcept { return Status_; } + + private: + std::optional<value_type> Value_; + ENodeStatus Status_ = NS_EMPTY; + }; + +public: + explicit TFlatContainer(size_type initSize, const allocator_type& alloc = {}) + : Buckets_(initSize, alloc) + , Taken_(0) + , Empty_(initSize) {} + + TFlatContainer(const TFlatContainer&) = default; + TFlatContainer(TFlatContainer&& rhs) + : Buckets_(std::move(rhs.Buckets_)) + , Taken_(rhs.Taken_) + , Empty_(rhs.Empty_) + { + rhs.Taken_ = 0; + rhs.Empty_ = 0; + } + + TFlatContainer& operator=(const TFlatContainer&) = default; + TFlatContainer& operator=(TFlatContainer&&) = default; + + value_type& Node(size_type idx) { return Buckets_[idx].Value(); } + const value_type& Node(size_type idx) const { return Buckets_[idx].Value(); } + + size_type Size() const noexcept { return Buckets_.size(); } + size_type Taken() const noexcept { return Taken_; } + size_type Empty() const noexcept { return Empty_; } + + template <class... Args> + void InitNode(size_type idx, Args&&... 
args) { + Buckets_[idx].Emplace(std::forward<Args>(args)...); + ++Taken_; + --Empty_; + } + + void DeleteNode(size_type idx) noexcept { + Buckets_[idx].Reset(); + --Taken_; + } + + bool IsEmpty(size_type idx) const { return Buckets_[idx].IsEmpty(); } + bool IsTaken(size_type idx) const { return Buckets_[idx].IsTaken(); } + bool IsDeleted(size_type idx) const { return Buckets_[idx].IsDeleted(); } + + void Swap(TFlatContainer& rhs) noexcept { + DoSwap(Buckets_, rhs.Buckets_); + DoSwap(Taken_, rhs.Taken_); + DoSwap(Empty_, rhs.Empty_); + } + + TFlatContainer Clone(size_type newSize) const { return TFlatContainer(newSize, Buckets_.get_allocator()); } + +private: + TVector<TCage, allocator_type> Buckets_; + size_type Taken_; + size_type Empty_; +}; + +static_assert(NConcepts::ContainerV<TFlatContainer<int>>); +static_assert(NConcepts::RemovalContainerV<TFlatContainer<int>>); + +/* DENSE CONTAINER */ + +template <class T, class EmptyMarker = NSet::TEqValueMarker<T>, class Alloc = std::allocator<T>> +class TDenseContainer { + static_assert(NConcepts::ValueMarkerV<EmptyMarker>); + +public: + using value_type = T; + using size_type = size_t; + using difference_type = ptrdiff_t; + using allocator_type = Alloc; + using pointer = typename std::allocator_traits<allocator_type>::pointer; + using const_pointer = typename std::allocator_traits<allocator_type>::const_pointer; + +public: + TDenseContainer(size_type initSize, EmptyMarker emptyMarker = {}, const allocator_type& alloc = {}) + : Buckets_(initSize, emptyMarker.Create(), alloc) + , Taken_(0) + , EmptyMarker_(std::move(emptyMarker)) {} + + TDenseContainer(const TDenseContainer&) = default; + TDenseContainer(TDenseContainer&&) = default; + + TDenseContainer& operator=(const TDenseContainer& rhs) { + Taken_ = rhs.Taken_; + EmptyMarker_ = rhs.EmptyMarker_; + if constexpr (std::is_copy_assignable_v<value_type>) { + Buckets_ = rhs.Buckets_; + } else { + auto tmp = rhs.Buckets_; + Buckets_.swap(tmp); + } + return *this; + } + TDenseContainer& operator=(TDenseContainer&&) = default; + + value_type& Node(size_type idx) { return Buckets_[idx]; } + const value_type& Node(size_type idx) const { return Buckets_[idx]; } + + size_type Size() const noexcept { return Buckets_.size(); } + size_type Taken() const noexcept { return Taken_; } + size_type Empty() const noexcept { return Size() - Taken(); } + + template <class... Args> + void InitNode(size_type idx, Args&&... 
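TFlatContainer is the storage backend that TTable (defined later in this diff) drives through the Container/RemovalContainer protocol: buckets are addressed by index, InitNode and DeleteNode flip the per-bucket status, and Clone produces an empty container of a new size for rehashing. A sketch of that protocol, poking the container directly purely for illustration (the aliases in flat_hash.h are the intended entry point):

#include <library/cpp/containers/flat_hash/lib/containers.h>

#include <util/system/yassert.h>

void ContainerProtocolSketch() {
    NFlatHash::TFlatContainer<int> c(8);     // 8 empty buckets
    Y_ASSERT(c.Size() == 8 && c.Taken() == 0 && c.Empty() == 8);

    c.InitNode(3, 42);                       // construct a value in bucket 3
    Y_ASSERT(c.IsTaken(3) && c.Node(3) == 42);
    Y_ASSERT(c.Taken() == 1 && c.Empty() == 7);

    c.DeleteNode(3);                         // leaves a tombstone, not an empty bucket
    Y_ASSERT(c.IsDeleted(3) && !c.IsEmpty(3));
    Y_ASSERT(c.Taken() == 0 && c.Empty() == 7);  // Empty() never counts tombstones

    auto bigger = c.Clone(16);               // same allocator, new size, no data copied
    Y_ASSERT(bigger.Size() == 16 && bigger.Taken() == 0);
}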
args) { + Node(idx).~value_type(); + new (&Buckets_[idx]) value_type(std::forward<Args>(args)...); + ++Taken_; + } + + bool IsEmpty(size_type idx) const { return EmptyMarker_.Equals(Buckets_[idx]); } + bool IsTaken(size_type idx) const { return !IsEmpty(idx); } + + void Swap(TDenseContainer& rhs) + noexcept(noexcept(DoSwap(std::declval<EmptyMarker&>(), std::declval<EmptyMarker&>()))) + { + DoSwap(Buckets_, rhs.Buckets_); + DoSwap(EmptyMarker_, rhs.EmptyMarker_); + DoSwap(Taken_, rhs.Taken_); + } + + TDenseContainer Clone(size_type newSize) const { return { newSize, EmptyMarker_, GetAllocator() }; } + +protected: + allocator_type GetAllocator() const { + return Buckets_.get_allocator(); + } + +protected: + TVector<value_type, allocator_type> Buckets_; + size_type Taken_; + EmptyMarker EmptyMarker_; +}; + +static_assert(NConcepts::ContainerV<TDenseContainer<int>>); +static_assert(!NConcepts::RemovalContainerV<TDenseContainer<int>>); + +template <class T, class DeletedMarker = NSet::TEqValueMarker<T>, + class EmptyMarker = NSet::TEqValueMarker<T>, class Alloc = std::allocator<T>> +class TRemovalDenseContainer : private TDenseContainer<T, EmptyMarker, Alloc> { +private: + static_assert(NConcepts::ValueMarkerV<DeletedMarker>); + + using TBase = TDenseContainer<T, EmptyMarker>; + +public: + using typename TBase::value_type; + using typename TBase::size_type; + using typename TBase::difference_type; + using typename TBase::allocator_type; + using typename TBase::pointer; + using typename TBase::const_pointer; + +public: + TRemovalDenseContainer( + size_type initSize, + DeletedMarker deletedMarker = {}, + EmptyMarker emptyMarker = {}, + const allocator_type& alloc = {}) + : TBase(initSize, std::move(emptyMarker), alloc) + , DeletedMarker_(std::move(deletedMarker)) + , Empty_(initSize) {} + + TRemovalDenseContainer(const TRemovalDenseContainer&) = default; + TRemovalDenseContainer(TRemovalDenseContainer&&) = default; + + TRemovalDenseContainer& operator=(const TRemovalDenseContainer&) = default; + TRemovalDenseContainer& operator=(TRemovalDenseContainer&&) = default; + + using TBase::Node; + using TBase::Size; + using TBase::Taken; + using TBase::InitNode; + using TBase::IsEmpty; + + size_type Empty() const noexcept { return Empty_; } + + template <class... Args> + void InitNode(size_type idx, Args&&... 
args) { + TBase::InitNode(idx, std::forward<Args>(args)...); + --Empty_; + } + + void DeleteNode(size_type idx) { + if constexpr (!std::is_trivially_destructible_v<value_type>) { + TBase::Node(idx).~value_type(); + } + new (&TBase::Node(idx)) value_type(DeletedMarker_.Create()); + --TBase::Taken_; + } + + bool IsTaken(size_type idx) const { return !IsDeleted(idx) && TBase::IsTaken(idx); } + bool IsDeleted(size_type idx) const { return DeletedMarker_.Equals(Node(idx)); } + + void Swap(TRemovalDenseContainer& rhs) + noexcept(noexcept(std::declval<TBase>().Swap(std::declval<TBase&>())) && + noexcept(DoSwap(std::declval<DeletedMarker&>(), std::declval<DeletedMarker&>()))) + { + TBase::Swap(rhs); + DoSwap(DeletedMarker_, rhs.DeletedMarker_); + DoSwap(Empty_, rhs.Empty_); + } + + TRemovalDenseContainer Clone(size_type newSize) const { + return { newSize, DeletedMarker_, TBase::EmptyMarker_, TBase::GetAllocator() }; + } + +private: + DeletedMarker DeletedMarker_; + size_type Empty_; +}; + +static_assert(NConcepts::ContainerV<TRemovalDenseContainer<int>>); +static_assert(NConcepts::RemovalContainerV<TRemovalDenseContainer<int>>); + +} // namespace NFlatHash diff --git a/library/cpp/containers/flat_hash/lib/expanders.cpp b/library/cpp/containers/flat_hash/lib/expanders.cpp new file mode 100644 index 00000000000..6bed3c72f3a --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/expanders.cpp @@ -0,0 +1 @@ +#include "expanders.h" diff --git a/library/cpp/containers/flat_hash/lib/expanders.h b/library/cpp/containers/flat_hash/lib/expanders.h new file mode 100644 index 00000000000..25b10e6bf1e --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/expanders.h @@ -0,0 +1,25 @@ +#pragma once + +#include <utility> + +namespace NFlatHash { + +struct TSimpleExpander { + static constexpr bool NeedGrow(size_t size, size_t buckets) noexcept { + return size >= buckets / 2; + } + + static constexpr bool WillNeedGrow(size_t size, size_t buckets) noexcept { + return NeedGrow(size + 1, buckets); + } + + static constexpr size_t EvalNewSize(size_t buckets) noexcept { + return buckets * 2; + } + + static constexpr size_t SuitableSize(size_t size) noexcept { + return size * 2 + 1; + } +}; + +} // namespace NFlatHash diff --git a/library/cpp/containers/flat_hash/lib/fuzz/dense_map_fuzz/fuzz.cpp b/library/cpp/containers/flat_hash/lib/fuzz/dense_map_fuzz/fuzz.cpp new file mode 100644 index 00000000000..9b4cb4c9836 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/fuzz/dense_map_fuzz/fuzz.cpp @@ -0,0 +1,60 @@ +#include <library/cpp/containers/flat_hash/lib/map.h> +#include <library/cpp/containers/flat_hash/lib/containers.h> +#include <library/cpp/containers/flat_hash/lib/probings.h> +#include <library/cpp/containers/flat_hash/lib/size_fitters.h> +#include <library/cpp/containers/flat_hash/lib/expanders.h> + +#include <library/cpp/containers/flat_hash/lib/fuzz/fuzz_common/fuzz_common.h> + +#include <util/generic/hash.h> +#include <util/generic/xrange.h> +#include <util/generic/bt_exception.h> + +using namespace NFlatHash; + +namespace { + +template <class Key, class T> +using TDenseModMap = NFlatHash::TMap<Key, + T, + THash<Key>, + std::equal_to<Key>, + TRemovalDenseContainer<std::pair<const Key, T>, + NMap::TEqValueMarker<Key, T>, + NMap::TEqValueMarker<Key, T>>, + TDenseProbing, + TAndSizeFitter, + TSimpleExpander>; + +NFuzz::EActionType EvalType(ui8 data) { + return static_cast<NFuzz::EActionType>((data >> 5) & 0b111); +} + +ui8 EvalKey(ui8 data) { + return data & 0b11111; +} + +ui8 EvalValue() { + return 
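TSimpleExpander above is the whole growth policy: a table grows once the occupied buckets reach half of the capacity (a maximum load factor of 0.5), growth doubles the bucket count, and a table sized for n expected elements asks for 2n+1 buckets. Every member is constexpr, so the policy can be spelled out as compile-time checks:

#include <library/cpp/containers/flat_hash/lib/expanders.h>

using NFlatHash::TSimpleExpander;

static_assert(!TSimpleExpander::NeedGrow(3, 8));       // 3 of 8 taken: still fine
static_assert(TSimpleExpander::NeedGrow(4, 8));        // 4 of 8 taken: time to grow
static_assert(TSimpleExpander::WillNeedGrow(3, 8));    // the next insert would tip it over
static_assert(TSimpleExpander::EvalNewSize(8) == 16);  // growth doubles the bucket count
static_assert(TSimpleExpander::SuitableSize(5) == 11); // reserving for 5 elements wants 11 buckets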
RandomNumber<ui8>(); +} + +} // namespace + +#include <util/datetime/base.h> + +extern "C" int LLVMFuzzerTestOneInput(const ui8* const wireData, const size_t wireSize) { + THashMap<ui8, ui8> etalon; + // We assume, that markers can't be produced by EvalKey function. + TDenseModMap<ui8, ui8> testee(8, + (1 << 5), // Deleted marker + (1 << 6)); // Empty marker + + for (auto i : xrange(wireSize)) { + auto data = wireData[i]; + + NFuzz::MakeAction(etalon, testee, EvalKey(data), EvalValue(), EvalType(data)); + NFuzz::CheckInvariants(etalon, testee); + } + + return 0; +} diff --git a/library/cpp/containers/flat_hash/lib/fuzz/dense_map_fuzz/ya.make b/library/cpp/containers/flat_hash/lib/fuzz/dense_map_fuzz/ya.make new file mode 100644 index 00000000000..3a5d3d6d8cd --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/fuzz/dense_map_fuzz/ya.make @@ -0,0 +1,19 @@ +FUZZ() + +OWNER( + tender-bum +) + +SRCS( + fuzz.cpp +) + +PEERDIR( + library/cpp/containers/flat_hash/lib/fuzz/fuzz_common +) + +SIZE(LARGE) + +TAG(ya:fat) + +END() diff --git a/library/cpp/containers/flat_hash/lib/fuzz/flat_map_fuzz/fuzz.cpp b/library/cpp/containers/flat_hash/lib/fuzz/flat_map_fuzz/fuzz.cpp new file mode 100644 index 00000000000..7fb73af0e9f --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/fuzz/flat_map_fuzz/fuzz.cpp @@ -0,0 +1,53 @@ +#include <library/cpp/containers/flat_hash/lib/map.h> +#include <library/cpp/containers/flat_hash/lib/containers.h> +#include <library/cpp/containers/flat_hash/lib/probings.h> +#include <library/cpp/containers/flat_hash/lib/size_fitters.h> +#include <library/cpp/containers/flat_hash/lib/expanders.h> + +#include <library/cpp/containers/flat_hash/lib/fuzz/fuzz_common/fuzz_common.h> + +#include <util/generic/hash.h> +#include <util/generic/xrange.h> +#include <util/generic/bt_exception.h> + +using namespace NFlatHash; + +namespace { + +template <class Key, class T> +using TFlatLinearModMap = NFlatHash::TMap<Key, + T, + THash<Key>, + std::equal_to<Key>, + TFlatContainer<std::pair<const Key, T>>, + TLinearProbing, + TAndSizeFitter, + TSimpleExpander>; + +NFuzz::EActionType EvalType(ui8 data) { + return static_cast<NFuzz::EActionType>((data >> 5) & 0b111); +} + +ui8 EvalKey(ui8 data) { + return data & 0b11111; +} + +ui8 EvalValue() { + return RandomNumber<ui8>(); +} + +} // namespace + +extern "C" int LLVMFuzzerTestOneInput(const ui8* const wireData, const size_t wireSize) { + THashMap<ui8, ui8> etalon; + TFlatLinearModMap<ui8, ui8> testee; + + for (auto i : xrange(wireSize)) { + auto data = wireData[i]; + + NFuzz::MakeAction(etalon, testee, EvalKey(data), EvalValue(), EvalType(data)); + NFuzz::CheckInvariants(etalon, testee); + } + + return 0; +} diff --git a/library/cpp/containers/flat_hash/lib/fuzz/flat_map_fuzz/ya.make b/library/cpp/containers/flat_hash/lib/fuzz/flat_map_fuzz/ya.make new file mode 100644 index 00000000000..3a5d3d6d8cd --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/fuzz/flat_map_fuzz/ya.make @@ -0,0 +1,19 @@ +FUZZ() + +OWNER( + tender-bum +) + +SRCS( + fuzz.cpp +) + +PEERDIR( + library/cpp/containers/flat_hash/lib/fuzz/fuzz_common +) + +SIZE(LARGE) + +TAG(ya:fat) + +END() diff --git a/library/cpp/containers/flat_hash/lib/fuzz/fuzz_common/fuzz_common.cpp b/library/cpp/containers/flat_hash/lib/fuzz/fuzz_common/fuzz_common.cpp new file mode 100644 index 00000000000..efc2973d18c --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/fuzz/fuzz_common/fuzz_common.cpp @@ -0,0 +1 @@ +#include "fuzz_common.h" diff --git 
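Both fuzz targets decode every input byte the same way: the top three bits select one of the eight NFuzz::EActionType actions and the low five bits form the key, which keeps the key space down to 0..31 so that collisions, overwrites and erases of existing keys happen often; the mapped value is drawn from RandomNumber<ui8>() and is therefore not fuzzer-controlled. A worked decoding example (the action names are listed in fuzz_common.h just below):

#include <util/system/types.h>
#include <util/system/yassert.h>

void DecodeExample() {
    const ui8 data = 0xAB;                 // 0b101'01011
    Y_ASSERT(((data >> 5) & 0b111) == 5);  // EvalType -> the sixth action, AT_ITERATORS
    Y_ASSERT((data & 0b11111) == 11);      // EvalKey  -> key 11
}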
a/library/cpp/containers/flat_hash/lib/fuzz/fuzz_common/fuzz_common.h b/library/cpp/containers/flat_hash/lib/fuzz/fuzz_common/fuzz_common.h new file mode 100644 index 00000000000..71a123d9cf6 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/fuzz/fuzz_common/fuzz_common.h @@ -0,0 +1,223 @@ +#pragma once + +#include <util/generic/bt_exception.h> +#include <util/generic/vector.h> +#include <util/generic/xrange.h> + +#include <util/random/random.h> + +namespace NFlatHash::NFuzz { + +#define FUZZ_ASSERT(cond) \ + Y_ENSURE_EX(cond, TWithBackTrace<yexception>() << Y_STRINGIZE(cond) << " assertion failed ") + +#define FUZZ_ASSERT_THROW(cond, exc) \ + try { \ + cond; \ + FUZZ_ASSERT(false); \ + } catch (const exc&) { \ + } catch (...) { \ + FUZZ_ASSERT(false); \ + } + +enum EActionType { + AT_INSERT, + AT_CLEAR, + AT_REHASH, + AT_ATOP, + AT_AT, + AT_ITERATORS, + AT_ERASE, + AT_FIND +}; + +template <class EtalonMap, class TesteeMap, class Key, class Value> +void MakeAction(EtalonMap& etalon, TesteeMap& testee, Key&& key, Value&& value, EActionType type) { + switch (type) { + case AT_INSERT: { + auto itEt = etalon.insert({ key, value }); + if (itEt.second) { + FUZZ_ASSERT(!testee.contains(key)); + auto size = testee.size(); + auto bucket_count = testee.bucket_count(); + + auto itTs = testee.insert(std::make_pair(key, value)); + FUZZ_ASSERT(itTs.second); + FUZZ_ASSERT(itTs.first->first == key); + FUZZ_ASSERT(itTs.first->second == value); + FUZZ_ASSERT(size + 1 == testee.size()); + FUZZ_ASSERT(bucket_count <= testee.bucket_count()); + } else { + FUZZ_ASSERT(testee.contains(key)); + auto size = testee.size(); + auto bucket_count = testee.bucket_count(); + + auto itTs = testee.insert(std::make_pair(key, value)); + FUZZ_ASSERT(!itTs.second); + FUZZ_ASSERT(itTs.first->first == key); + FUZZ_ASSERT(itTs.first->second == itEt.first->second); + FUZZ_ASSERT(size == testee.size()); + FUZZ_ASSERT(bucket_count == testee.bucket_count()); + } + break; + } + case AT_CLEAR: { + auto bucket_count = testee.bucket_count(); + testee.clear(); + for (const auto& v : etalon) { + FUZZ_ASSERT(!testee.contains(v.first)); + } + FUZZ_ASSERT(testee.empty()); + FUZZ_ASSERT(testee.size() == 0); + FUZZ_ASSERT(testee.bucket_count() == bucket_count); + FUZZ_ASSERT(testee.load_factor() < std::numeric_limits<float>::epsilon()); + + etalon.clear(); + break; + } + case AT_REHASH: { + testee.rehash(key); + FUZZ_ASSERT(testee.bucket_count() >= key); + break; + } + case AT_ATOP: { + if (etalon.contains(key)) { + FUZZ_ASSERT(testee.contains(key)); + auto size = testee.size(); + auto bucket_count = testee.bucket_count(); + + FUZZ_ASSERT(testee[key] == etalon[key]); + + FUZZ_ASSERT(size == testee.size()); + FUZZ_ASSERT(bucket_count == testee.bucket_count()); + } else { + FUZZ_ASSERT(!testee.contains(key)); + auto size = testee.size(); + auto bucket_count = testee.bucket_count(); + + FUZZ_ASSERT(testee[key] == etalon[key]); + + FUZZ_ASSERT(size + 1 == testee.size()); + FUZZ_ASSERT(bucket_count <= testee.bucket_count()); + } + auto size = testee.size(); + auto bucket_count = testee.bucket_count(); + + etalon[key] = value; + testee[key] = value; + FUZZ_ASSERT(testee[key] == etalon[key]); + FUZZ_ASSERT(testee[key] == value); + + FUZZ_ASSERT(size == testee.size()); + FUZZ_ASSERT(bucket_count == testee.bucket_count()); + break; + } + case AT_AT: { + auto size = testee.size(); + auto bucket_count = testee.bucket_count(); + if (etalon.contains(key)) { + FUZZ_ASSERT(testee.contains(key)); + + FUZZ_ASSERT(testee.at(key) == etalon.at(key)); + 
testee.at(key) = value; + etalon.at(key) = value; + FUZZ_ASSERT(testee.at(key) == etalon.at(key)); + } else { + FUZZ_ASSERT(!testee.contains(key)); + FUZZ_ASSERT_THROW(testee.at(key) = value, std::out_of_range); + FUZZ_ASSERT(!testee.contains(key)); + } + FUZZ_ASSERT(size == testee.size()); + FUZZ_ASSERT(bucket_count == testee.bucket_count()); + break; + } + case AT_ITERATORS: { + auto itBeginTs = testee.begin(); + auto itEndTs = testee.end(); + FUZZ_ASSERT((size_t)std::distance(itBeginTs, itEndTs) == testee.size()); + FUZZ_ASSERT(std::distance(itBeginTs, itEndTs) == + std::distance(etalon.begin(), etalon.end())); + FUZZ_ASSERT(std::distance(testee.cbegin(), testee.cend()) == + std::distance(etalon.cbegin(), etalon.cend())); + break; + } + case AT_ERASE: { + if (etalon.contains(key)) { + FUZZ_ASSERT(testee.contains(key)); + auto size = testee.size(); + auto bucket_count = testee.bucket_count(); + + auto itTs = testee.find(key); + FUZZ_ASSERT(itTs->first == key); + FUZZ_ASSERT(itTs->second == etalon.at(key)); + + testee.erase(itTs); + FUZZ_ASSERT(size - 1 == testee.size()); + FUZZ_ASSERT(bucket_count == testee.bucket_count()); + etalon.erase(key); + } else { + FUZZ_ASSERT(!testee.contains(key)); + } + break; + } + case AT_FIND: { + auto itEt = etalon.find(key); + if (itEt != etalon.end()) { + FUZZ_ASSERT(testee.contains(key)); + + auto itTs = testee.find(key); + FUZZ_ASSERT(itTs != testee.end()); + FUZZ_ASSERT(itTs->first == key); + FUZZ_ASSERT(itTs->second == itEt->second); + + itTs->second = value; + itEt->second = value; + } else { + FUZZ_ASSERT(!testee.contains(key)); + + auto itTs = testee.find(key); + FUZZ_ASSERT(itTs == testee.end()); + } + break; + } + }; +} + +template <class EtalonMap, class TesteeMap> +void CheckInvariants(const EtalonMap& etalon, const TesteeMap& testee) { + using value_type = std::pair<typename TesteeMap::key_type, + typename TesteeMap::mapped_type>; + using size_type = typename TesteeMap::size_type; + + TVector<value_type> etalonVals{ etalon.begin(), etalon.end() }; + std::sort(etalonVals.begin(), etalonVals.end()); + TVector<value_type> testeeVals{ testee.begin(), testee.end() }; + std::sort(testeeVals.begin(), testeeVals.end()); + + FUZZ_ASSERT(testeeVals == etalonVals); + + FUZZ_ASSERT(testee.size() == etalon.size()); + FUZZ_ASSERT(testee.empty() == etalon.empty()); + FUZZ_ASSERT(testee.load_factor() < 0.5f + std::numeric_limits<float>::epsilon()); + FUZZ_ASSERT(testee.bucket_count() > testee.size()); + + size_type buckets = 0; + for (auto b : xrange(testee.bucket_count())) { + buckets += testee.bucket_size(b); + } + FUZZ_ASSERT(buckets == testee.size()); + + for (const auto& v : etalon) { + auto key = v.first; + auto value = v.second; + + FUZZ_ASSERT(testee.contains(key)); + FUZZ_ASSERT(testee.count(key) == 1); + + auto it = testee.find(key); + FUZZ_ASSERT(it->first == key); + FUZZ_ASSERT(it->second == value); + } +} + +} // namespace NFlatHash::NFuzz diff --git a/library/cpp/containers/flat_hash/lib/fuzz/fuzz_common/ya.make b/library/cpp/containers/flat_hash/lib/fuzz/fuzz_common/ya.make new file mode 100644 index 00000000000..ecb590e1163 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/fuzz/fuzz_common/ya.make @@ -0,0 +1,11 @@ +LIBRARY() + +OWNER(tender-bum) + +SRCS(fuzz_common.cpp) + +PEERDIR( + library/cpp/containers/flat_hash/lib +) + +END() diff --git a/library/cpp/containers/flat_hash/lib/fuzz/ya.make b/library/cpp/containers/flat_hash/lib/fuzz/ya.make new file mode 100644 index 00000000000..dbf2183be5a --- /dev/null +++ 
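MakeAction and CheckInvariants form a differential test: every decoded action is applied both to THashMap (the etalon) and to the table under test, and after each step the full contents, size, emptiness, load factor (which must stay below 0.5 under the default expander) and per-bucket counts are re-compared. The same harness can also be reused outside libFuzzer as a quick randomized smoke test; a sketch, with my own choice of NFH::TFlatHashMap as the testee:

#include <library/cpp/containers/flat_hash/flat_hash.h>
#include <library/cpp/containers/flat_hash/lib/fuzz/fuzz_common/fuzz_common.h>

#include <util/generic/hash.h>
#include <util/random/random.h>

void RandomizedSmokeTest(size_t steps) {
    THashMap<ui8, ui8> etalon;
    NFH::TFlatHashMap<ui8, ui8> testee;

    for (size_t i = 0; i < steps; ++i) {
        const ui8 data = RandomNumber<ui8>();
        const auto type = static_cast<NFlatHash::NFuzz::EActionType>((data >> 5) & 0b111);
        const ui8 key = data & 0b11111;

        NFlatHash::NFuzz::MakeAction(etalon, testee, key, RandomNumber<ui8>(), type);
        NFlatHash::NFuzz::CheckInvariants(etalon, testee);
    }
}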
b/library/cpp/containers/flat_hash/lib/fuzz/ya.make @@ -0,0 +1,7 @@ +OWNER(tender-bum) + +RECURSE( + flat_map_fuzz + dense_map_fuzz + fuzz_common +) diff --git a/library/cpp/containers/flat_hash/lib/iterator.cpp b/library/cpp/containers/flat_hash/lib/iterator.cpp new file mode 100644 index 00000000000..7c5c206cc36 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/iterator.cpp @@ -0,0 +1 @@ +#include "iterator.h" diff --git a/library/cpp/containers/flat_hash/lib/iterator.h b/library/cpp/containers/flat_hash/lib/iterator.h new file mode 100644 index 00000000000..f6b1e74355d --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/iterator.h @@ -0,0 +1,99 @@ +#pragma once + +#include "concepts/container.h" + +#include <util/system/yassert.h> + +#include <iterator> + +namespace NFlatHash { + +template <class Container, class T> +class TIterator { +private: + static_assert(NConcepts::ContainerV<std::decay_t<Container>>); + +public: + using iterator_category = std::forward_iterator_tag; + using value_type = T; + using difference_type = ptrdiff_t; + using pointer = typename std::add_pointer<T>::type; + using reference = typename std::add_lvalue_reference<T>::type; + +private: + using size_type = typename Container::size_type; + +public: + TIterator(Container* cont) + : Cont_(cont) + , Idx_(0) + { + if (!cont->IsTaken(Idx_)) { + Next(); + } + } + + TIterator(Container* cont, size_type idx) + : Cont_(cont) + , Idx_(idx) {} + + template <class C, class U, class = std::enable_if_t<std::is_convertible<C*, Container*>::value && + std::is_convertible<U, T>::value>> + TIterator(const TIterator<C, U>& rhs) + : Cont_(rhs.Cont_) + , Idx_(rhs.Idx_) {} + + TIterator(const TIterator&) = default; + + TIterator& operator=(const TIterator&) = default; + + TIterator& operator++() { + Next(); + return *this; + } + TIterator operator++(int) { + auto idx = Idx_; + Next(); + return { Cont_, idx }; + } + + reference operator*() { + return Cont_->Node(Idx_); + } + + pointer operator->() { + return &Cont_->Node(Idx_); + } + + const pointer operator->() const { + return &Cont_->Node(Idx_); + } + + bool operator==(const TIterator& rhs) const noexcept { + Y_ASSERT(Cont_ == rhs.Cont_); + return Idx_ == rhs.Idx_; + } + + bool operator!=(const TIterator& rhs) const noexcept { + return !operator==(rhs); + } + +private: + void Next() { + // Container provider ensures that it's not empty. 
+ do { + ++Idx_; + } while (Idx_ != Cont_->Size() && !Cont_->IsTaken(Idx_)); + } + +private: + template <class C, class U> + friend class TIterator; + + Container* Cont_ = nullptr; + +protected: + size_type Idx_ = 0; +}; + +} // namespace NFlatHash diff --git a/library/cpp/containers/flat_hash/lib/map.cpp b/library/cpp/containers/flat_hash/lib/map.cpp new file mode 100644 index 00000000000..b323fbb46d6 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/map.cpp @@ -0,0 +1 @@ +#include "map.h" diff --git a/library/cpp/containers/flat_hash/lib/map.h b/library/cpp/containers/flat_hash/lib/map.h new file mode 100644 index 00000000000..f77c318a615 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/map.h @@ -0,0 +1,233 @@ +#pragma once + +#include "table.h" +#include "concepts/iterator.h" + +#include <util/generic/algorithm.h> +#include <util/generic/mapfindptr.h> + +namespace NFlatHash { + +namespace NPrivate { + +struct TMapKeyGetter { + template <class T> + static constexpr auto& Apply(T& t) noexcept { return t.first; }; + + template <class T> + static constexpr const auto& Apply(const T& t) noexcept { return t.first; }; +}; + +} // namespace NPrivate + +template <class Key, + class T, + class Hash, + class KeyEqual, + class Container, + class Probing, + class SizeFitter, + class Expander> +class TMap : private TTable<Hash, + KeyEqual, + Container, + NPrivate::TMapKeyGetter, + Probing, + SizeFitter, + Expander>, + public TMapOps<TMap<Key, + T, + Hash, + KeyEqual, + Container, + Probing, + SizeFitter, + Expander>> +{ +private: + using TBase = TTable<Hash, + KeyEqual, + Container, + NPrivate::TMapKeyGetter, + Probing, + SizeFitter, + Expander>; + + static_assert(std::is_same<std::pair<const Key, T>, typename Container::value_type>::value); + +public: + using key_type = Key; + using mapped_type = T; + using typename TBase::value_type; + using typename TBase::size_type; + using typename TBase::difference_type; + using typename TBase::hasher; + using typename TBase::key_equal; + using typename TBase::reference; + using typename TBase::const_reference; + using typename TBase::iterator; + using typename TBase::const_iterator; + using typename TBase::allocator_type; + using typename TBase::pointer; + using typename TBase::const_pointer; + +private: + static constexpr size_type INIT_SIZE = 8; + +public: + TMap() : TBase(INIT_SIZE) {} + + template <class... Rest> + TMap(size_type initSize, Rest&&... rest) : TBase(initSize, std::forward<Rest>(rest)...) {} + + template <class I, class... Rest> + TMap(I first, I last, + std::enable_if_t<NConcepts::IteratorV<I>, size_type> initSize = INIT_SIZE, + Rest&&... rest) + : TBase(initSize, std::forward<Rest>(rest)...) + { + insert(first, last); + } + + template <class... Rest> + TMap(std::initializer_list<value_type> il, size_type initSize = INIT_SIZE, Rest&&... rest) + : TBase(initSize, std::forward<Rest>(rest)...) 
+ { + insert(il.begin(), il.end()); + } + + TMap(std::initializer_list<value_type> il, size_type initSize = INIT_SIZE) + : TBase(initSize) + { + insert(il.begin(), il.end()); + } + + TMap(const TMap&) = default; + TMap(TMap&&) = default; + + TMap& operator=(const TMap&) = default; + TMap& operator=(TMap&&) = default; + + // Iterators + using TBase::begin; + using TBase::cbegin; + using TBase::end; + using TBase::cend; + + // Capacity + using TBase::empty; + using TBase::size; + + // Modifiers + using TBase::clear; + using TBase::insert; + using TBase::emplace; + using TBase::emplace_hint; + using TBase::erase; + using TBase::swap; + + template <class V> + std::pair<iterator, bool> insert_or_assign(const key_type& k, V&& v) { + return InsertOrAssignImpl(k, std::forward<V>(v)); + } + template <class V> + std::pair<iterator, bool> insert_or_assign(key_type&& k, V&& v) { + return InsertOrAssignImpl(std::move(k), std::forward<V>(v)); + } + + template <class V> + iterator insert_or_assign(const_iterator, const key_type& k, V&& v) { // TODO(tender-bum) + return insert_or_assign(k, std::forward<V>(v)).first; + } + template <class V> + iterator insert_or_assign(const_iterator, key_type&& k, V&& v) { // TODO(tender-bum) + return insert_or_assign(std::move(k), std::forward<V>(v)).first; + } + + template <class... Args> + std::pair<iterator, bool> try_emplace(const key_type& key, Args&&... args) { + return TryEmplaceImpl(key, std::forward<Args>(args)...); + } + template <class... Args> + std::pair<iterator, bool> try_emplace(key_type&& key, Args&&... args) { + return TryEmplaceImpl(std::move(key), std::forward<Args>(args)...); + } + + template <class... Args> + iterator try_emplace(const_iterator, const key_type& key, Args&&... args) { // TODO(tender-bum) + return try_emplace(key, std::forward<Args>(args)...).first; + } + template <class... Args> + iterator try_emplace(const_iterator, key_type&& key, Args&&... args) { // TODO(tender-bum) + return try_emplace(std::move(key), std::forward<Args>(args)...).first; + } + + // Lookup + using TBase::count; + using TBase::find; + using TBase::contains; + + template <class K> + mapped_type& at(const K& key) { + auto it = find(key); + if (it == end()) { + throw std::out_of_range{ "no such key in map" }; + } + return it->second; + } + + template <class K> + const mapped_type& at(const K& key) const { return const_cast<TMap*>(this)->at(key); } + + template <class K> + Y_FORCE_INLINE mapped_type& operator[](K&& key) { + return TBase::TryCreate(key, [&](size_type idx) { + TBase::Buckets_.InitNode(idx, std::forward<K>(key), mapped_type{}); + }).first->second; + } + + // Bucket interface + using TBase::bucket_count; + using TBase::bucket_size; + + // Hash policy + using TBase::load_factor; + using TBase::rehash; + using TBase::reserve; + + // Observers + using TBase::hash_function; + using TBase::key_eq; + + friend bool operator==(const TMap& lhs, const TMap& rhs) { + return lhs.size() == rhs.size() && AllOf(lhs, [&rhs](const auto& v) { + auto it = rhs.find(v.first); + return it != rhs.end() && *it == v; + }); + } + + friend bool operator!=(const TMap& lhs, const TMap& rhs) { return !(lhs == rhs); } + +private: + template <class K, class... Args> + std::pair<iterator, bool> TryEmplaceImpl(K&& key, Args&&... 
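The block above gives TMap the familiar std::unordered_map surface; the visible differences are that at() is templated on the key type and that the hint-taking overloads currently ignore the hint (the TODO markers). A short sketch of the members defined here:

#include <library/cpp/containers/flat_hash/flat_hash.h>

#include <util/generic/string.h>
#include <util/system/yassert.h>

void MapApiSketch() {
    NFH::TFlatHashMap<int, TString> m;

    m.try_emplace(1, "one");       // constructs the mapped value in place
    m.try_emplace(1, "uno");       // key already present: no effect, returns {it, false}

    m.insert_or_assign(2, "two");  // inserts...
    m.insert_or_assign(2, "TWO");  // ...then overwrites the mapped value

    m[3] = "three";                // operator[] default-constructs, then assigns

    Y_ASSERT(m.at(1) == "one");
    Y_ASSERT(m.at(2) == "TWO");
    Y_ASSERT(m.contains(3) && m.size() == 3);
}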
args) { + return TBase::TryCreate(key, [&](size_type idx) { + TBase::Buckets_.InitNode( + idx, + std::piecewise_construct, + std::forward_as_tuple(std::forward<K>(key)), + std::forward_as_tuple(std::forward<Args>(args)...)); + }); + } + + template <class K, class V> + std::pair<iterator, bool> InsertOrAssignImpl(K&& key, V&& v) { + auto p = try_emplace(std::forward<K>(key), std::forward<V>(v)); + if (!p.second) { + p.first->second = std::forward<V>(v); + } + return p; + } +}; + +} // namespace NFlatHash diff --git a/library/cpp/containers/flat_hash/lib/probings.cpp b/library/cpp/containers/flat_hash/lib/probings.cpp new file mode 100644 index 00000000000..f10c6af1132 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/probings.cpp @@ -0,0 +1 @@ +#include "probings.h" diff --git a/library/cpp/containers/flat_hash/lib/probings.h b/library/cpp/containers/flat_hash/lib/probings.h new file mode 100644 index 00000000000..886be59cffc --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/probings.h @@ -0,0 +1,45 @@ +#pragma once + +#include <type_traits> + +namespace NFlatHash { + +class TLinearProbing { +public: + template <class SizeFitter, class F> + static auto FindBucket(SizeFitter sf, size_t idx, size_t sz, F f) { + idx = sf.EvalIndex(idx, sz); + while (!f(idx)) { + idx = sf.EvalIndex(++idx, sz); + } + return idx; + } +}; + +class TQuadraticProbing { +public: + template <class SizeFitter, class F> + static auto FindBucket(SizeFitter sf, size_t idx, size_t sz, F f) { + idx = sf.EvalIndex(idx, sz); + size_t k = 0; + while (!f(idx)) { + idx = sf.EvalIndex(idx + 2 * ++k - 1, sz); + } + return idx; + } +}; + +class TDenseProbing { +public: + template <class SizeFitter, class F> + static auto FindBucket(SizeFitter sf, size_t idx, size_t sz, F f) { + idx = sf.EvalIndex(idx, sz); + size_t k = 0; + while (!f(idx)) { + idx = sf.EvalIndex(idx + ++k, sz); + } + return idx; + } +}; + +} // NFlatHash diff --git a/library/cpp/containers/flat_hash/lib/set.cpp b/library/cpp/containers/flat_hash/lib/set.cpp new file mode 100644 index 00000000000..aa2f9c58e1c --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/set.cpp @@ -0,0 +1 @@ +#include "set.h" diff --git a/library/cpp/containers/flat_hash/lib/set.h b/library/cpp/containers/flat_hash/lib/set.h new file mode 100644 index 00000000000..5266293c6c3 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/set.h @@ -0,0 +1,147 @@ +#pragma once + +#include "table.h" +#include "concepts/iterator.h" + +#include <util/generic/algorithm.h> + +namespace NFlatHash { + +namespace NPrivate { + +struct TSimpleKeyGetter { + template <class T> + static constexpr auto& Apply(T& t) noexcept { return t; }; + + template <class T> + static constexpr const auto& Apply(const T& t) noexcept { return t; }; +}; + +} // namespace NPrivate + +template <class Key, + class Hash, + class KeyEqual, + class Container, + class Probing, + class SizeFitter, + class Expander> +class TSet : private TTable<Hash, + KeyEqual, + Container, + NPrivate::TSimpleKeyGetter, + Probing, + SizeFitter, + Expander, + std::add_const> +{ +private: + using TBase = TTable<Hash, + KeyEqual, + Container, + NPrivate::TSimpleKeyGetter, + Probing, + SizeFitter, + Expander, + std::add_const>; + + static_assert(std::is_same_v<Key, typename Container::value_type>); + +public: + using key_type = Key; + using typename TBase::value_type; + using typename TBase::size_type; + using typename TBase::difference_type; + using typename TBase::hasher; + using typename TBase::key_equal; + using typename 
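The three probing policies above differ only in the step sequence they feed to the size fitter: TLinearProbing advances by 1, TQuadraticProbing adds successive odd numbers so the probed offsets are the perfect squares 1, 4, 9, ..., and TDenseProbing adds 1, 2, 3, ... giving the triangular offsets 1, 3, 6, .... A sketch that records the visited buckets using the trivial TModSizeFitter from size_fitters.h (FirstProbes is my helper name, not part of the library):

#include <library/cpp/containers/flat_hash/lib/probings.h>
#include <library/cpp/containers/flat_hash/lib/size_fitters.h>

#include <util/generic/vector.h>

template <class Probing>
TVector<size_t> FirstProbes(size_t start, size_t tableSize, size_t count) {
    TVector<size_t> visited;
    NFlatHash::TModSizeFitter fitter;
    // FindBucket keeps probing until the callback returns true; here we simply
    // stop after `count` probes to expose the visiting order.
    Probing::FindBucket(fitter, start, tableSize, [&](size_t idx) {
        visited.push_back(idx);
        return visited.size() >= count;
    });
    return visited;
}

// FirstProbes<NFlatHash::TLinearProbing>(0, 64, 4)    yields {0, 1, 2, 3}
// FirstProbes<NFlatHash::TQuadraticProbing>(0, 64, 4) yields {0, 1, 4, 9}
// FirstProbes<NFlatHash::TDenseProbing>(0, 64, 4)     yields {0, 1, 3, 6}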
TBase::reference; + using typename TBase::const_reference; + using typename TBase::iterator; + using typename TBase::const_iterator; + using typename TBase::allocator_type; + using typename TBase::pointer; + using typename TBase::const_pointer; + +private: + static constexpr size_type INIT_SIZE = 8; + +public: + TSet() : TBase(INIT_SIZE) {} + + template <class... Rest> + TSet(size_type initSize, Rest&&... rest) : TBase(initSize, std::forward<Rest>(rest)...) {} + + template <class I, class... Rest> + TSet(I first, I last, + std::enable_if_t<NConcepts::IteratorV<I>, size_type> initSize = INIT_SIZE, + Rest&&... rest) + : TBase(initSize, std::forward<Rest>(rest)...) + { + insert(first, last); + } + + template <class... Rest> + TSet(std::initializer_list<value_type> il, size_type initSize = INIT_SIZE, Rest&&... rest) + : TBase(initSize, std::forward<Rest>(rest)...) + { + insert(il.begin(), il.end()); + } + + TSet(std::initializer_list<value_type> il, size_type initSize = INIT_SIZE) + : TBase(initSize) + { + insert(il.begin(), il.end()); + } + + TSet(const TSet&) = default; + TSet(TSet&&) = default; + + TSet& operator=(const TSet&) = default; + TSet& operator=(TSet&&) = default; + + // Iterators + using TBase::begin; + using TBase::cbegin; + using TBase::end; + using TBase::cend; + + // Capacity + using TBase::empty; + using TBase::size; + + // Modifiers + using TBase::clear; + using TBase::insert; + using TBase::emplace; + using TBase::emplace_hint; + using TBase::erase; + using TBase::swap; + + // Lookup + using TBase::count; + using TBase::find; + using TBase::contains; + + // Bucket interface + using TBase::bucket_count; + using TBase::bucket_size; + + // Hash policy + using TBase::load_factor; + using TBase::rehash; + using TBase::reserve; + + // Observers + using TBase::hash_function; + using TBase::key_eq; + + friend bool operator==(const TSet& lhs, const TSet& rhs) { + return lhs.size() == rhs.size() && AllOf(lhs, [&rhs](const auto& v) { + return rhs.contains(v); + }); + } + + friend bool operator!=(const TSet& lhs, const TSet& rhs) { return !(lhs == rhs); } +}; + +} // namespace NFlatHash diff --git a/library/cpp/containers/flat_hash/lib/size_fitters.cpp b/library/cpp/containers/flat_hash/lib/size_fitters.cpp new file mode 100644 index 00000000000..f1431c27e3c --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/size_fitters.cpp @@ -0,0 +1 @@ +#include "size_fitters.h" diff --git a/library/cpp/containers/flat_hash/lib/size_fitters.h b/library/cpp/containers/flat_hash/lib/size_fitters.h new file mode 100644 index 00000000000..86bd617342d --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/size_fitters.h @@ -0,0 +1,47 @@ +#pragma once + +#include "concepts/size_fitter.h" + +#include <util/system/yassert.h> +#include <util/generic/bitops.h> + +namespace NFlatHash { + +class TAndSizeFitter { +public: + size_t EvalIndex(size_t hs, size_t sz) const noexcept { + Y_ASSERT(Mask_ == sz - 1); + return (hs & Mask_); + } + + size_t EvalSize(size_t sz) const noexcept { + return FastClp2(sz); + } + + void Update(size_t sz) noexcept { + Y_ASSERT((sz & (sz - 1)) == 0); + Mask_ = sz - 1; + } + +private: + size_t Mask_ = 0; +}; + +static_assert(NConcepts::SizeFitterV<TAndSizeFitter>); + +class TModSizeFitter { +public: + constexpr size_t EvalIndex(size_t hs, size_t sz) const noexcept { + return hs % sz; + } + + constexpr size_t EvalSize(size_t sz) const noexcept { + return sz; + } + + constexpr void Update(size_t) noexcept {} +}; + +static_assert(NConcepts::SizeFitterV<TModSizeFitter>); + +} // 
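TAndSizeFitter is what keeps the default tables power-of-two sized: EvalSize rounds the requested bucket count up with FastClp2, Update caches size - 1 as a bit mask, and EvalIndex then reduces a hash with a single AND instead of the modulo TModSizeFitter performs. A small sketch of that contract (it assumes Update is always called with the size obtained from EvalSize, which is how TTable below uses it):

#include <library/cpp/containers/flat_hash/lib/size_fitters.h>

#include <util/system/yassert.h>

void SizeFitterSketch() {
    NFlatHash::TAndSizeFitter fitter;

    const size_t buckets = fitter.EvalSize(100);  // next power of two
    Y_ASSERT(buckets == 128);

    fitter.Update(buckets);                       // caches the mask 127
    Y_ASSERT(fitter.EvalIndex(1000, buckets) == (1000 & 127));
    Y_ASSERT(fitter.EvalIndex(1000, buckets) == 1000 % 128);  // identical for powers of two
}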
NFlatHash diff --git a/library/cpp/containers/flat_hash/lib/table.cpp b/library/cpp/containers/flat_hash/lib/table.cpp new file mode 100644 index 00000000000..e89d72ad94f --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/table.cpp @@ -0,0 +1 @@ +#include "table.h" diff --git a/library/cpp/containers/flat_hash/lib/table.h b/library/cpp/containers/flat_hash/lib/table.h new file mode 100644 index 00000000000..b84a052be75 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/table.h @@ -0,0 +1,314 @@ +#pragma once + +#include "iterator.h" +#include "concepts/container.h" +#include "concepts/size_fitter.h" + +#include <util/generic/utility.h> + +#include <functional> + +namespace NFlatHash { + +namespace NPrivate { + +template <class T> +struct TTypeIdentity { using type = T; }; + +} // namespace NPrivate + +template < + class Hash, + class KeyEqual, + class Container, + class KeyGetter, + class Probing, + class SizeFitter, + class Expander, + // Used in the TSet to make iterator behave as const_iterator + template <class> class IteratorModifier = NPrivate::TTypeIdentity> +class TTable { +private: + static_assert(NConcepts::ContainerV<Container>); + static_assert(NConcepts::SizeFitterV<SizeFitter>); + + template <class C, class V> + class TIteratorImpl : public TIterator<C, V> { + private: + using TBase = TIterator<C, V>; + friend class TTable; + + using TBase::TBase; + + public: + TIteratorImpl() : TBase(nullptr, 0) {} + }; + +public: + using value_type = typename Container::value_type; + using size_type = typename Container::size_type; + using difference_type = typename Container::difference_type; + using hasher = Hash; + using key_equal = KeyEqual; + + using reference = value_type&; + using const_reference = const value_type&; + + using iterator = TIteratorImpl<typename IteratorModifier<Container>::type, + typename IteratorModifier<value_type>::type>; + using const_iterator = TIteratorImpl<const Container, const value_type>; + using allocator_type = typename Container::allocator_type; + using pointer = typename Container::pointer; + using const_pointer = typename Container::const_pointer; + +private: + TTable(Container buckets) + : Buckets_(std::move(buckets)) + { + SizeFitter_.Update(bucket_count()); + } + + static constexpr size_type INIT_SIZE = 8; + +public: + template <class... Rest> + TTable(size_type initSize, Rest&&... rest) + : Buckets_(initSize == 0 ? INIT_SIZE : SizeFitter_.EvalSize(initSize), + std::forward<Rest>(rest)...) 
+ { + SizeFitter_.Update(bucket_count()); + } + + TTable(const TTable&) = default; + TTable(TTable&& rhs) + : SizeFitter_(std::move(rhs.SizeFitter_)) + , Buckets_(std::move(rhs.Buckets_)) + , Hasher_(std::move(rhs.Hasher_)) + , KeyEqual_(std::move(rhs.KeyEqual_)) + { + TTable tmp{ Buckets_.Clone(INIT_SIZE) }; + tmp.swap(rhs); + } + + TTable& operator=(const TTable&) = default; + TTable& operator=(TTable&& rhs) { + TTable tmp(std::move(rhs)); + swap(tmp); + return *this; + } + + // Iterators + iterator begin() { return &Buckets_; } + const_iterator begin() const { return const_cast<TTable*>(this)->begin(); } + const_iterator cbegin() const { return begin(); } + + iterator end() { return { &Buckets_, bucket_count() }; } + const_iterator end() const { return const_cast<TTable*>(this)->end(); } + const_iterator cend() const { return end(); } + + // Capacity + bool empty() const noexcept { return size() == 0; } + size_type size() const noexcept { return Buckets_.Taken(); } + + // Modifiers + void clear() { + Container tmp(Buckets_.Clone(bucket_count())); + Buckets_.Swap(tmp); + } + + std::pair<iterator, bool> insert(const value_type& value) { return InsertImpl(value); } + std::pair<iterator, bool> insert(value_type&& value) { return InsertImpl(std::move(value)); } + + template <class T> + std::enable_if_t<!std::is_same_v<std::decay_t<T>, value_type>, + std::pair<iterator, bool>> insert(T&& value) { + return insert(value_type(std::forward<T>(value))); + } + + iterator insert(const_iterator, const value_type& value) { // TODO(tender-bum) + return insert(value).first; + } + iterator insert(const_iterator, value_type&& value) { // TODO(tender-bum) + return insert(std::move(value)).first; + } + + template <class T> + iterator insert(const_iterator, T&& value) { // TODO(tender-bum) + return insert(value_type(std::forward<T>(value))).first; + } + + template <class InputIt> + void insert(InputIt first, InputIt last) { + while (first != last) { + insert(*first++); + } + } + + void insert(std::initializer_list<value_type> il) { + insert(il.begin(), il.end()); + } + + template <class... Args> + std::pair<iterator, bool> emplace(Args&&... args) { + return insert(value_type(std::forward<Args>(args)...)); + } + + template <class... Args> + iterator emplace_hint(const_iterator, Args&&... args) { // TODO(tender-bum) + return emplace(std::forward<Args>(args)...).first; + } + + void erase(const_iterator pos) { + static_assert(NConcepts::RemovalContainerV<Container>, + "That kind of table doesn't allow erasing. 
Use another table instead."); + if constexpr (NConcepts::RemovalContainerV<Container>) { + Buckets_.DeleteNode(pos.Idx_); + } + } + + void erase(const_iterator f, const_iterator l) { + while (f != l) { + auto nxt = f; + ++nxt; + erase(f); + f = nxt; + } + } + + template <class K> + std::enable_if_t<!std::is_convertible_v<K, iterator> && !std::is_convertible_v<K, const_iterator>, + size_type> erase(const K& key) { + auto it = find(key); + if (it != end()) { + erase(it); + return 1; + } + return 0; + } + + void swap(TTable& rhs) + noexcept(noexcept(std::declval<Container>().Swap(std::declval<Container&>()))) + { + DoSwap(SizeFitter_, rhs.SizeFitter_); + Buckets_.Swap(rhs.Buckets_); + DoSwap(Hasher_, rhs.Hasher_); + DoSwap(KeyEqual_, rhs.KeyEqual_); + } + + // Lookup + template <class K> + size_type count(const K& key) const { return contains(key); } + + template <class K> + iterator find(const K& key) { + size_type hs = hash_function()(key); + auto idx = FindProperBucket(hs, key); + if (Buckets_.IsTaken(idx)) { + return { &Buckets_, idx }; + } + return end(); + } + + template <class K> + const_iterator find(const K& key) const { return const_cast<TTable*>(this)->find(key); } + + template <class K> + bool contains(const K& key) const { + size_type hs = hash_function()(key); + return Buckets_.IsTaken(FindProperBucket(hs, key)); + } + + // Bucket interface + size_type bucket_count() const noexcept { return Buckets_.Size(); } + size_type bucket_size(size_type idx) const { return Buckets_.IsTaken(idx); } + + // Hash policy + float load_factor() const noexcept { + return (float)(bucket_count() - Buckets_.Empty()) / bucket_count(); + } + + void rehash(size_type sz) { + if (sz != 0) { + auto newBuckets = SizeFitter_.EvalSize(sz); + size_type occupied = bucket_count() - Buckets_.Empty(); + if (Expander::NeedGrow(occupied, newBuckets)) { + newBuckets = Max(newBuckets, SizeFitter_.EvalSize(Expander::SuitableSize(size()))); + } + RehashImpl(newBuckets); + } else { + RehashImpl(SizeFitter_.EvalSize(Expander::SuitableSize(size()))); + } + } + + void reserve(size_type sz) { rehash(sz); } // TODO(tender-bum) + + // Observers + constexpr auto hash_function() const noexcept { return Hasher_; } + constexpr auto key_eq() const noexcept { return KeyEqual_; } + +public: + template <class T> + std::pair<iterator, bool> InsertImpl(T&& value) { + return TryCreate(KeyGetter::Apply(value), [&](size_type idx) { + Buckets_.InitNode(idx, std::forward<T>(value)); + }); + } + + template <class T, class F> + Y_FORCE_INLINE std::pair<iterator, bool> TryCreate(const T& key, F nodeInit) { + size_type hs = hash_function()(key); + size_type idx = FindProperBucket(hs, key); + if (!Buckets_.IsTaken(idx)) { + if (Expander::WillNeedGrow(bucket_count() - Buckets_.Empty(), bucket_count())) { + RehashImpl(); + idx = FindProperBucket(hs, key); + } + nodeInit(idx); + return { iterator{ &Buckets_, idx }, true }; + } + return { iterator{ &Buckets_, idx }, false }; + } + + template <class K> + size_type FindProperBucket(size_type hs, const K& key) const { + return Probing::FindBucket(SizeFitter_, hs, bucket_count(), [&](size_type idx) { + if constexpr (NConcepts::RemovalContainerV<Container>) { + return Buckets_.IsEmpty(idx) || + Buckets_.IsTaken(idx) && key_eq()(KeyGetter::Apply(Buckets_.Node(idx)), key); + } else { + return Buckets_.IsEmpty(idx) || key_eq()(KeyGetter::Apply(Buckets_.Node(idx)), key); + } + }); + } + + void RehashImpl() { + if constexpr (NConcepts::RemovalContainerV<Container>) { + size_type occupied = bucket_count() - 
Buckets_.Empty(); + if (size() < occupied / 2) { + rehash(bucket_count()); // Just clearing all deleted elements + } else { + RehashImpl(SizeFitter_.EvalSize(Expander::EvalNewSize(bucket_count()))); + } + } else { + RehashImpl(SizeFitter_.EvalSize(Expander::EvalNewSize(bucket_count()))); + } + } + + void RehashImpl(size_type newSize) { + TTable tmp = Buckets_.Clone(newSize); + for (auto& value : *this) { + size_type hs = hash_function()(KeyGetter::Apply(value)); + tmp.Buckets_.InitNode( + tmp.FindProperBucket(hs, KeyGetter::Apply(value)), std::move_if_noexcept(value)); + } + swap(tmp); + } + +public: + SizeFitter SizeFitter_; + Container Buckets_; + hasher Hasher_; + key_equal KeyEqual_; +}; + +} // namespace NFlatHash diff --git a/library/cpp/containers/flat_hash/lib/ut/containers_ut.cpp b/library/cpp/containers/flat_hash/lib/ut/containers_ut.cpp new file mode 100644 index 00000000000..b17b30fa80b --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/ut/containers_ut.cpp @@ -0,0 +1,410 @@ +#include <library/cpp/containers/flat_hash/lib/containers.h> + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/algorithm.h> +#include <util/random/random.h> +#include <util/random/shuffle.h> + +using namespace NFlatHash; + +namespace { + constexpr size_t INIT_SIZE = 128; + + struct TDummy { + static size_t Count; + + TDummy() { ++Count; } + TDummy(const TDummy&) { ++Count; } + ~TDummy() { --Count; } + }; + size_t TDummy::Count = 0; + + struct TAlmostDummy { + static size_t Count; + + TAlmostDummy(int j = 0) : Junk(j) { ++Count; } + TAlmostDummy(const TAlmostDummy& d) : Junk(d.Junk) { ++Count; } + ~TAlmostDummy() { --Count; } + + bool operator==(const TAlmostDummy& r) const { return Junk == r.Junk; }; + bool operator!=(const TAlmostDummy& r) const { return !operator==(r); }; + + int Junk; + }; + size_t TAlmostDummy::Count = 0; + + struct TNotSimple { + enum class EType { + Value, + Empty, + Deleted + } Type_ = EType::Value; + + TString Junk = "something"; // to prevent triviality propagation + int Value = 0; + + static int CtorCalls; + static int DtorCalls; + static int CopyCtorCalls; + static int MoveCtorCalls; + + TNotSimple() { + ++CtorCalls; + } + explicit TNotSimple(int value) + : Value(value) + { + ++CtorCalls; + } + + TNotSimple(const TNotSimple& rhs) { + ++CopyCtorCalls; + Value = rhs.Value; + Type_ = rhs.Type_; + } + TNotSimple(TNotSimple&& rhs) { + ++MoveCtorCalls; + Value = rhs.Value; + Type_ = rhs.Type_; + } + + ~TNotSimple() { + ++DtorCalls; + } + + TNotSimple& operator=(const TNotSimple& rhs) { + ++CopyCtorCalls; + Value = rhs.Value; + Type_ = rhs.Type_; + return *this; + } + TNotSimple& operator=(TNotSimple&& rhs) { + ++MoveCtorCalls; + Value = rhs.Value; + Type_ = rhs.Type_; + return *this; + } + + static TNotSimple Empty() { + TNotSimple ret; + ret.Type_ = EType::Empty; + return ret; + } + + static TNotSimple Deleted() { + TNotSimple ret; + ret.Type_ = EType::Deleted; + return ret; + } + + bool operator==(const TNotSimple& rhs) const noexcept { + return Value == rhs.Value; + } + + static void ResetStats() { + CtorCalls = 0; + DtorCalls = 0; + CopyCtorCalls = 0; + MoveCtorCalls = 0; + } + }; + + int TNotSimple::CtorCalls = 0; + int TNotSimple::DtorCalls = 0; + int TNotSimple::CopyCtorCalls = 0; + int TNotSimple::MoveCtorCalls = 0; + + struct TNotSimpleEmptyMarker { + using value_type = TNotSimple; + + value_type Create() const { + return TNotSimple::Empty(); + } + + bool Equals(const value_type& rhs) const { + return rhs.Type_ == TNotSimple::EType::Empty; 
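
The TTable above is deliberately policy-based: hashing, key equality, bucket storage, probing, size fitting and growth are all injected as template parameters, and insertion funnels through TryCreate (hash the key, locate a bucket with FindProperBucket, rehash first if the expander reports the table is too full). A minimal sketch of how the pieces compose into a set of ints, mirroring the aliases built in table_ut.cpp further down; TSimpleExpander comes from expanders.h as used by those tests, and the key getter here is a local stand-in for the tests' TSimpleKeyGetter, so treat both as assumptions:

#include <library/cpp/containers/flat_hash/lib/table.h>
#include <library/cpp/containers/flat_hash/lib/containers.h>
#include <library/cpp/containers/flat_hash/lib/probings.h>
#include <library/cpp/containers/flat_hash/lib/size_fitters.h>
#include <library/cpp/containers/flat_hash/lib/expanders.h>

#include <util/generic/hash.h> // for THash; the unit tests get it transitively

// Identity key getter for a set-like table: the stored value is the key itself.
struct TIntKeyGetter {
    static constexpr int& Apply(int& v) noexcept { return v; }
    static constexpr const int& Apply(const int& v) noexcept { return v; }
};

// Open-addressing set of ints: flat storage, linear probing, power-of-two sizing.
using TIntSet = NFlatHash::TTable<
    THash<int>,                       // Hash
    std::equal_to<int>,               // KeyEqual
    NFlatHash::TFlatContainer<int>,   // Container
    TIntKeyGetter,                    // KeyGetter
    NFlatHash::TLinearProbing,        // Probing
    NFlatHash::TAndSizeFitter,        // SizeFitter
    NFlatHash::TSimpleExpander>;      // Expander (name taken from table_ut.cpp)

void TableSketch() {
    TIntSet s{ 32 };                      // initial bucket count, adjusted by the size fitter
    auto [it, inserted] = s.insert(42);   // TryCreate: hash -> FindProperBucket -> InitNode
    auto again = s.insert(42);            // same bucket found, nothing inserted
    // inserted == true, again.second == false, s.size() == 1, s.contains(42) == true
    (void)it;
    (void)inserted;
    (void)again;
}
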
+ } + }; + + struct TNotSimpleDeletedMarker { + using value_type = TNotSimple; + + value_type Create() const { + return TNotSimple::Deleted(); + } + + bool Equals(const value_type& rhs) const { + return rhs.Type_ == TNotSimple::EType::Deleted; + } + }; + + template <class Container> + void CheckContainersEqual(const Container& a, const Container& b) { + UNIT_ASSERT_EQUAL(a.Size(), b.Size()); + UNIT_ASSERT_EQUAL(a.Taken(), b.Empty()); + for (typename Container::size_type i = 0; i < a.Size(); ++i) { + if (a.IsTaken(i)) { + UNIT_ASSERT(b.IsTaken(i)); + UNIT_ASSERT_EQUAL(a.Node(i), b.Node(i)); + } + } + } + + template <class Container, class... Args> + void SmokingTest(typename Container::size_type size, Args&&... args) { + using size_type = typename Container::size_type; + using value_type = typename Container::value_type; + + Container cont(size, std::forward<Args>(args)...); + UNIT_ASSERT_EQUAL(cont.Size(), size); + UNIT_ASSERT_EQUAL(cont.Taken(), 0); + + for (size_type i = 0; i < cont.Size(); ++i) { + UNIT_ASSERT(cont.IsEmpty(i)); + UNIT_ASSERT(!cont.IsTaken(i)); + } + + // Filling the container till half + TVector<size_type> toInsert(cont.Size()); + Iota(toInsert.begin(), toInsert.end(), 0); + Shuffle(toInsert.begin(), toInsert.end()); + toInsert.resize(toInsert.size() / 2); + for (auto i : toInsert) { + UNIT_ASSERT(cont.IsEmpty(i)); + UNIT_ASSERT(!cont.IsTaken(i)); + value_type value(RandomNumber<size_type>(cont.Size())); + cont.InitNode(i, value); + UNIT_ASSERT(!cont.IsEmpty(i)); + UNIT_ASSERT(cont.IsTaken(i)); + UNIT_ASSERT_EQUAL(cont.Node(i), value); + } + UNIT_ASSERT_EQUAL(cont.Taken(), toInsert.size()); + + // Copy construction test + auto cont2 = cont; + CheckContainersEqual(cont, cont2); + + // Copy assignment test + cont2 = cont2.Clone(0); + UNIT_ASSERT_EQUAL(cont2.Size(), 0); + UNIT_ASSERT_EQUAL(cont2.Taken(), 0); + + // Copy assignment test + cont2 = cont; + CheckContainersEqual(cont, cont2); + + // Move construction test + auto cont3 = std::move(cont2); + UNIT_ASSERT_EQUAL(cont2.Size(), 0); + CheckContainersEqual(cont, cont3); + + // Move assignment test + cont2 = std::move(cont3); + UNIT_ASSERT_EQUAL(cont3.Size(), 0); + CheckContainersEqual(cont, cont2); + } +} + +Y_UNIT_TEST_SUITE(TFlatContainerTest) { + Y_UNIT_TEST(SmokingTest) { + SmokingTest<TFlatContainer<int>>(INIT_SIZE); + } + + Y_UNIT_TEST(SmokingTestNotSimpleType) { + TNotSimple::ResetStats(); + SmokingTest<TFlatContainer<TNotSimple>>(INIT_SIZE); + + UNIT_ASSERT_EQUAL(TNotSimple::CtorCalls + TNotSimple::CopyCtorCalls + TNotSimple::MoveCtorCalls, + TNotSimple::DtorCalls); + UNIT_ASSERT_EQUAL(TNotSimple::CtorCalls, INIT_SIZE / 2 /* created while filling */); + UNIT_ASSERT_EQUAL(TNotSimple::DtorCalls, INIT_SIZE / 2 /* removed filling temporary */ + + INIT_SIZE / 2 /* removed while cloning */ + + INIT_SIZE /* 3 containers dtors */); + UNIT_ASSERT_EQUAL(TNotSimple::CopyCtorCalls, INIT_SIZE / 2 /* 3 created while filling */ + + INIT_SIZE / 2 /* created while copy constructing */ + + INIT_SIZE / 2/* created while copy assigning */); + UNIT_ASSERT_EQUAL(TNotSimple::MoveCtorCalls, 0); + } + + Y_UNIT_TEST(DummyHalfSizeTest) { + using TContainer = TFlatContainer<TDummy>; + using size_type = typename TContainer::size_type; + + { + TContainer cont(INIT_SIZE); + UNIT_ASSERT_EQUAL(TDummy::Count, 0); + + TVector<size_type> toInsert(cont.Size()); + Iota(toInsert.begin(), toInsert.end(), 0); + Shuffle(toInsert.begin(), toInsert.end()); + toInsert.resize(toInsert.size() / 2); + for (auto i : toInsert) { + UNIT_ASSERT(cont.IsEmpty(i)); 
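
TNotSimpleEmptyMarker and TNotSimpleDeletedMarker above illustrate the value-marker protocol that the dense containers rely on: Create() produces the sentinel stored in unused buckets, and Equals() recognizes it again. A sketch of a hand-written marker in the same shape (the ready-made NSet::TStaticValueMarker and NSet::TEqValueMarker from value_markers.h later in this diff cover the common cases):

// Sentinel marker: free buckets hold -1, the same convention as the
// NSet::TStaticValueMarker<-1> used throughout these tests.
struct TMinusOneMarker {
    using value_type = int;

    value_type Create() const {
        return -1;            // value written into every free bucket
    }

    bool Equals(const value_type& rhs) const {
        return rhs == -1;     // does this bucket still hold the sentinel?
    }
};

// By analogy with the static marker: TDenseContainer<int, TMinusOneMarker> cont(128);
// -1 is then reserved and must not be stored as a real element.
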
+ UNIT_ASSERT(!cont.IsTaken(i)); + cont.InitNode(i); + UNIT_ASSERT_EQUAL(TDummy::Count, cont.Taken()); + UNIT_ASSERT(!cont.IsEmpty(i)); + UNIT_ASSERT(cont.IsTaken(i)); + } + UNIT_ASSERT_EQUAL(cont.Taken(), cont.Size() / 2); + UNIT_ASSERT_EQUAL(TDummy::Count, cont.Taken()); + } + UNIT_ASSERT_EQUAL(TDummy::Count, 0); + } + + Y_UNIT_TEST(DeleteTest) { + using TContainer = TFlatContainer<TDummy>; + using size_type = typename TContainer::size_type; + + TContainer cont(INIT_SIZE); + auto idx = RandomNumber<size_type>(INIT_SIZE); + UNIT_ASSERT(!cont.IsTaken(idx)); + UNIT_ASSERT(!cont.IsDeleted(idx)); + UNIT_ASSERT_EQUAL(TDummy::Count, 0); + + cont.InitNode(idx); + UNIT_ASSERT_EQUAL(cont.Taken(), 1); + UNIT_ASSERT(cont.IsTaken(idx)); + UNIT_ASSERT(!cont.IsDeleted(idx)); + UNIT_ASSERT_EQUAL(TDummy::Count, 1); + + cont.DeleteNode(idx); + UNIT_ASSERT(!cont.IsTaken(idx)); + UNIT_ASSERT(cont.IsDeleted(idx)); + UNIT_ASSERT_EQUAL(TDummy::Count, 0); + } +} + +Y_UNIT_TEST_SUITE(TDenseContainerTest) { + Y_UNIT_TEST(SmokingTest) { + SmokingTest<TDenseContainer<int, NSet::TStaticValueMarker<-1>>>(INIT_SIZE); + } + + Y_UNIT_TEST(NotSimpleTypeSmokingTest) { + TNotSimple::ResetStats(); + SmokingTest<TDenseContainer<TNotSimple, TNotSimpleEmptyMarker>>(INIT_SIZE); + + UNIT_ASSERT_EQUAL(TNotSimple::CtorCalls + TNotSimple::CopyCtorCalls + TNotSimple::MoveCtorCalls, + TNotSimple::DtorCalls); + UNIT_ASSERT_EQUAL(TNotSimple::CtorCalls, INIT_SIZE / 2 /* created while filling */ + + 2 /* created by empty marker */); + UNIT_ASSERT_EQUAL(TNotSimple::DtorCalls, 1 /* removed empty marker temporary */ + + INIT_SIZE /* half removed while resetting in container, + half removed inserted temporary */ + + INIT_SIZE /* removed while cloning */ + + 1 /* removed empty marker temporary */ + + INIT_SIZE * 2 /* 3 containers dtors */); + UNIT_ASSERT_EQUAL(TNotSimple::CopyCtorCalls, INIT_SIZE /* created while constructing */ + + INIT_SIZE / 2 /* 3 created while filling */ + + INIT_SIZE /* created while copy constructing */ + + INIT_SIZE /* created while copy assigning */); + UNIT_ASSERT_EQUAL(TNotSimple::MoveCtorCalls, 0); + } + + Y_UNIT_TEST(RemovalContainerSmokingTest) { + SmokingTest<TRemovalDenseContainer<int, NSet::TStaticValueMarker<-1>, + NSet::TStaticValueMarker<-2>>>(INIT_SIZE); + } + + Y_UNIT_TEST(NotSimpleTypeRemovalContainerSmokingTest) { + TNotSimple::ResetStats(); + SmokingTest<TRemovalDenseContainer<TNotSimple, TNotSimpleEmptyMarker, + TNotSimpleDeletedMarker>>(INIT_SIZE); + + UNIT_ASSERT_EQUAL(TNotSimple::CtorCalls + TNotSimple::CopyCtorCalls + TNotSimple::MoveCtorCalls, + TNotSimple::DtorCalls); + UNIT_ASSERT_EQUAL(TNotSimple::CtorCalls, INIT_SIZE / 2 /* created while filling */ + + 2 /* created by empty marker */); + UNIT_ASSERT_EQUAL(TNotSimple::DtorCalls, 1 /* removed empty marker temporary */ + + INIT_SIZE /* half removed while resetting in container, + half removed inserted temporary */ + + INIT_SIZE /* removed while cloning */ + + 1 /* removed empty marker temporary */ + + INIT_SIZE * 2 /* 3 containers dtors */); + UNIT_ASSERT_EQUAL(TNotSimple::CopyCtorCalls, INIT_SIZE /* created while constructing */ + + INIT_SIZE / 2 /* 3 created while filling */ + + INIT_SIZE /* created while copy constructing */ + + INIT_SIZE /* created while copy assigning */); + UNIT_ASSERT_EQUAL(TNotSimple::MoveCtorCalls, 0); + } + + Y_UNIT_TEST(DummyHalfSizeTest) { + using TContainer = TDenseContainer<TAlmostDummy, NSet::TEqValueMarker<TAlmostDummy>>; + using size_type = typename TContainer::size_type; + + { + TContainer cont(INIT_SIZE, 
TAlmostDummy{-1}); + UNIT_ASSERT_EQUAL(TAlmostDummy::Count, cont.Size() + 1); // 1 for empty marker + + TVector<size_type> toInsert(cont.Size()); + Iota(toInsert.begin(), toInsert.end(), 0); + Shuffle(toInsert.begin(), toInsert.end()); + toInsert.resize(toInsert.size() / 2); + for (auto i : toInsert) { + UNIT_ASSERT(cont.IsEmpty(i)); + UNIT_ASSERT(!cont.IsTaken(i)); + cont.InitNode(i, (int)RandomNumber<size_type>(cont.Size())); + UNIT_ASSERT_EQUAL(TAlmostDummy::Count, cont.Size() + 1); + UNIT_ASSERT(!cont.IsEmpty(i)); + UNIT_ASSERT(cont.IsTaken(i)); + } + UNIT_ASSERT_EQUAL(cont.Taken(), toInsert.size()); + UNIT_ASSERT_EQUAL(TAlmostDummy::Count, cont.Size() + 1); + } + UNIT_ASSERT_EQUAL(TAlmostDummy::Count, 0); + } + + Y_UNIT_TEST(DeleteTest) { + using TContainer = TRemovalDenseContainer<TAlmostDummy, NSet::TEqValueMarker<TAlmostDummy>, + NSet::TEqValueMarker<TAlmostDummy>>; + using size_type = typename TContainer::size_type; + + TContainer cont(INIT_SIZE, TAlmostDummy{ -2 }, TAlmostDummy{ -1 }); + auto idx = RandomNumber<size_type>(cont.Size()); + UNIT_ASSERT(!cont.IsTaken(idx)); + UNIT_ASSERT(!cont.IsDeleted(idx)); + UNIT_ASSERT_EQUAL(TAlmostDummy::Count, cont.Size() + 2); // 2 for markers + + cont.InitNode(idx, (int)RandomNumber<size_type>(cont.Size())); + UNIT_ASSERT_EQUAL(cont.Taken(), 1); + UNIT_ASSERT(cont.IsTaken(idx)); + UNIT_ASSERT(!cont.IsDeleted(idx)); + UNIT_ASSERT_EQUAL(TAlmostDummy::Count, cont.Size() + 2); + + cont.DeleteNode(idx); + UNIT_ASSERT(!cont.IsTaken(idx)); + UNIT_ASSERT(cont.IsDeleted(idx)); + UNIT_ASSERT_EQUAL(TAlmostDummy::Count, cont.Size() + 2); + } + + Y_UNIT_TEST(FancyInitsTest) { + { + using TContainer = TDenseContainer<int>; + TContainer cont{ INIT_SIZE, -1 }; + } + { + using TContainer = TDenseContainer<int, NSet::TStaticValueMarker<-1>>; + TContainer cont{ INIT_SIZE }; + static_assert(!std::is_constructible_v<TContainer, size_t, int>); + } + { + using TContainer = TDenseContainer<int, NSet::TEqValueMarker<int>>; + TContainer cont{ INIT_SIZE, -1 }; + TContainer cont2{ INIT_SIZE, NSet::TEqValueMarker<int>{ -1 } }; + } + { + using TContainer = TRemovalDenseContainer<int>; + TContainer cont{ INIT_SIZE, -1, -2 }; + TContainer cont2{ INIT_SIZE, NSet::TEqValueMarker<int>{ -1 }, + NSet::TEqValueMarker<int>{ -2 } }; + } + { + using TContainer = TRemovalDenseContainer<int, NSet::TStaticValueMarker<-1>, + NSet::TStaticValueMarker<-1>>; + TContainer cont{ INIT_SIZE }; + static_assert(!std::is_constructible_v<TContainer, size_t, int>); + static_assert(!std::is_constructible_v<TContainer, size_t, int, int>); + } + } +} diff --git a/library/cpp/containers/flat_hash/lib/ut/iterator_ut.cpp b/library/cpp/containers/flat_hash/lib/ut/iterator_ut.cpp new file mode 100644 index 00000000000..0b77bf043f3 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/ut/iterator_ut.cpp @@ -0,0 +1,85 @@ +#include <library/cpp/containers/flat_hash/lib/iterator.h> +#include <library/cpp/containers/flat_hash/lib/containers.h> + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/random/random.h> +#include <util/generic/algorithm.h> + +using namespace NFlatHash; + +namespace { + constexpr size_t INIT_SIZE = 128; + + template <class Container> + void SmokingTest(Container& cont) { + using value_type = typename Container::value_type; + using iterator = TIterator<Container, value_type>; + using size_type = typename Container::size_type; + + iterator f(&cont), l(&cont, cont.Size()); + UNIT_ASSERT_EQUAL(f, l); + UNIT_ASSERT_EQUAL((size_type)std::distance(f, l), cont.Taken()); + 
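
The assertion above states the central iterator invariant: TIterator visits only taken buckets, so the distance from begin to end equals Container::Taken() rather than Container::Size(). A traversal therefore needs no emptiness checks of its own; a small sketch for streamable value types (Cout and Endl from util/stream/output.h are assumed here):

#include <library/cpp/containers/flat_hash/lib/iterator.h>
#include <util/stream/output.h>

// Print every stored value; empty (and, for removal containers, deleted)
// buckets are skipped by the iterator itself.
template <class Container>
void DumpTaken(Container& cont) {
    using value_type = typename Container::value_type;
    NFlatHash::TIterator<Container, value_type> it(&cont), last(&cont, cont.Size());
    for (; it != last; ++it) {
        Cout << *it << Endl;
    }
}
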
+ TVector<std::pair<size_type, value_type>> toAdd{ + { 0, (int)RandomNumber<size_type>(INIT_SIZE) }, + { 1 + RandomNumber<size_type>(INIT_SIZE - 2), (int)RandomNumber<size_type>(INIT_SIZE) }, + { INIT_SIZE - 1, (int)RandomNumber<size_type>(INIT_SIZE) } + }; + + for (const auto& p : toAdd) { + UNIT_ASSERT(cont.IsEmpty(p.first)); + cont.InitNode(p.first, p.second); + } + UNIT_ASSERT_EQUAL(cont.Size(), INIT_SIZE); + f = iterator{ &cont }; + l = iterator{ &cont, INIT_SIZE }; + UNIT_ASSERT_UNEQUAL(f, l); + UNIT_ASSERT_EQUAL((size_type)std::distance(f, l), cont.Taken()); + + TVector<value_type> added(f, l); + UNIT_ASSERT(::Equal(toAdd.begin(), toAdd.end(), added.begin(), [](const auto& p, auto v) { + return p.second == v; + })); + } + + template <class Container> + void ConstTest(Container& cont) { + using value_type = typename Container::value_type; + using iterator = TIterator<Container, value_type>; + using const_iterator = TIterator<const Container, const value_type>; + + iterator it{ &cont, INIT_SIZE / 2 }; + const_iterator cit1{ it }; + const_iterator cit2{ &cont, INIT_SIZE / 2 }; + + UNIT_ASSERT_EQUAL(cit1, cit2); + + static_assert(std::is_same<decltype(*it), value_type&>::value); + static_assert(std::is_same<decltype(*cit1), const value_type&>::value); + } +} + +Y_UNIT_TEST_SUITE(TIteratorTest) { + Y_UNIT_TEST(SmokingTest) { + { + TFlatContainer<int> cont(INIT_SIZE); + SmokingTest(cont); + } + { + TDenseContainer<int, NSet::TStaticValueMarker<-1>> cont(INIT_SIZE); + SmokingTest(cont); + } + } + + Y_UNIT_TEST(ConstTest) { + { + TFlatContainer<int> cont(INIT_SIZE); + ConstTest(cont); + } + { + TDenseContainer<int, NSet::TStaticValueMarker<-1>> cont(INIT_SIZE); + ConstTest(cont); + } + } +} diff --git a/library/cpp/containers/flat_hash/lib/ut/probings_ut.cpp b/library/cpp/containers/flat_hash/lib/ut/probings_ut.cpp new file mode 100644 index 00000000000..593f8cbb1bb --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/ut/probings_ut.cpp @@ -0,0 +1,34 @@ +#include <library/cpp/containers/flat_hash/lib/probings.h> + +#include <library/cpp/testing/unittest/registar.h> + +using namespace NFlatHash; + +namespace { + struct TDummySizeFitter { + constexpr auto EvalIndex(size_t idx, size_t) const { + return idx; + } + }; + + constexpr TDummySizeFitter SIZE_FITTER; + + auto atLeast13 = [](size_t idx) { return idx >= 13; }; +} + +Y_UNIT_TEST_SUITE(TProbingsTest) { + Y_UNIT_TEST(LinearProbingTest) { + using TProbing = TLinearProbing; + UNIT_ASSERT_EQUAL(TProbing::FindBucket(SIZE_FITTER, 1, 0, atLeast13), 13); + } + + Y_UNIT_TEST(QuadraticProbingTest) { + using TProbing = TQuadraticProbing; + UNIT_ASSERT_EQUAL(TProbing::FindBucket(SIZE_FITTER, 1, 0, atLeast13), 17); + } + + Y_UNIT_TEST(DenseProbingTest) { + using TProbing = TDenseProbing; + UNIT_ASSERT_EQUAL(TProbing::FindBucket(SIZE_FITTER, 1, 0, atLeast13), 16); + } +} diff --git a/library/cpp/containers/flat_hash/lib/ut/size_fitters_ut.cpp b/library/cpp/containers/flat_hash/lib/ut/size_fitters_ut.cpp new file mode 100644 index 00000000000..4167947ece2 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/ut/size_fitters_ut.cpp @@ -0,0 +1,51 @@ +#include <library/cpp/containers/flat_hash/lib/size_fitters.h> + +#include <library/cpp/testing/unittest/registar.h> + +using namespace NFlatHash; + +Y_UNIT_TEST_SUITE(TAndSizeFitterTest) { + Y_UNIT_TEST(EvalSizeTest) { + TAndSizeFitter sf; + UNIT_ASSERT_EQUAL(sf.EvalSize(5), 8); + UNIT_ASSERT_EQUAL(sf.EvalSize(8), 8); + UNIT_ASSERT_EQUAL(sf.EvalSize(13), 16); + UNIT_ASSERT_EQUAL(sf.EvalSize(25), 32); 
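
The buckets expected by TProbingsTest (13, 17 and 16 for a start position of 1 and a predicate accepting indices >= 13) are consistent with the usual probe sequences: linear probing tries h, h+1, h+2, ...; quadratic probing tries h, h+1, h+4, h+9, ...; the dense scheme tries triangular offsets h, h+1, h+3, h+6, h+10, h+15, ... The formulas below are reconstructed from those expectations alone, since probings.h itself is not part of this diff; likewise TAndSizeFitter rounds the requested size up to a power of two precisely so that EvalIndex can reduce a hash with a mask instead of a modulo:

#include <cstddef>

// Reconstruction for illustration only (an assumption about probings.h /
// size_fitters.h, written to match the expected values in the tests).
inline size_t LinearProbe(size_t h, size_t i)    { return h + i; }               // 1, 2, 3, ..., 13
inline size_t QuadraticProbe(size_t h, size_t i) { return h + i * i; }           // 1, 2, 5, 10, 17
inline size_t DenseProbe(size_t h, size_t i)     { return h + i * (i + 1) / 2; } // 1, 2, 4, 7, 11, 16

inline size_t AndFitterIndex(size_t hash, size_t sizePow2) {
    return hash & (sizePow2 - 1);   // valid only because EvalSize returned a power of two
}
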
+ for (size_t i = 1; i < 100; ++i) { + UNIT_ASSERT_EQUAL(sf.EvalSize(i), FastClp2(i)); + } + } + + Y_UNIT_TEST(EvalIndexTest) { + TAndSizeFitter sf; + for (size_t j = 1; j < 10; ++j) { + sf.Update(1 << j); + for (size_t i = 0; i < 100; ++i) { + UNIT_ASSERT_EQUAL(sf.EvalIndex(i, 1 << j), i & ((1 << j) - 1)); + } + } + } +} + +Y_UNIT_TEST_SUITE(TModSizeFitterTest) { + Y_UNIT_TEST(EvalSizeTest) { + TModSizeFitter sf; + UNIT_ASSERT_EQUAL(sf.EvalSize(5), 5); + UNIT_ASSERT_EQUAL(sf.EvalSize(8), 8); + UNIT_ASSERT_EQUAL(sf.EvalSize(13), 13); + UNIT_ASSERT_EQUAL(sf.EvalSize(25), 25); + for (size_t i = 1; i < 100; ++i) { + UNIT_ASSERT_EQUAL(sf.EvalSize(i), i); + } + } + + Y_UNIT_TEST(EvalIndexTest) { + TModSizeFitter sf; + for (size_t j = 1; j < 10; ++j) { + sf.Update(1 << j); // just for integrity + for (size_t i = 0; i < 100; ++i) { + UNIT_ASSERT_EQUAL(sf.EvalIndex(i, 1 << j), i % (1 << j)); + } + } + } +} diff --git a/library/cpp/containers/flat_hash/lib/ut/table_ut.cpp b/library/cpp/containers/flat_hash/lib/ut/table_ut.cpp new file mode 100644 index 00000000000..ea511e2c6af --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/ut/table_ut.cpp @@ -0,0 +1,411 @@ +#include <library/cpp/containers/flat_hash/lib/containers.h> +#include <library/cpp/containers/flat_hash/lib/expanders.h> +#include <library/cpp/containers/flat_hash/lib/probings.h> +#include <library/cpp/containers/flat_hash/lib/size_fitters.h> +#include <library/cpp/containers/flat_hash/lib/table.h> + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/algorithm.h> +#include <util/random/random.h> +#include <util/random/shuffle.h> + +using namespace NFlatHash; + +namespace { + template <class T> + struct TJustType { + using type = T; + }; + + template <class... Ts> + struct TTypePack {}; + + template <class F, class... Ts> + constexpr void ForEachType(F&& f, TTypePack<Ts...>) { + ApplyToMany(std::forward<F>(f), TJustType<Ts>{}...); + } + +/* Usage example: + * + * TForEachType<int, float, TString>::Apply([](auto t) { + * using T = GET_TYPE(t); + * }); + * So T would be: + * int on #0 iteration + * float on #1 iteration + * TString on #2 iteration + */ +#define GET_TYPE(ti) typename decltype(ti)::type + + constexpr size_t INIT_SIZE = 32; + constexpr size_t BIG_INIT_SIZE = 128; + + template <class T> + struct TSimpleKeyGetter { + static constexpr T& Apply(T& t) { return t; } + static constexpr const T& Apply(const T& t) { return t; } + }; + + template <class T, + class KeyEqual = std::equal_to<T>, + class ValueEqual = std::equal_to<T>, + class KeyGetter = TSimpleKeyGetter<T>, + class F, + class... 
Containers> + void ForEachTable(F f, TTypePack<Containers...> cs) { + ForEachType([&](auto p) { + using TProbing = GET_TYPE(p); + + ForEachType([&](auto sf) { + using TSizeFitter = GET_TYPE(sf); + + ForEachType([&](auto t) { + using TContainer = GET_TYPE(t); + static_assert(std::is_same_v<typename TContainer::value_type, T>); + + using TTable = TTable<THash<T>, + KeyEqual, + TContainer, + KeyGetter, + TProbing, + TSizeFitter, + TSimpleExpander>; + + f(TJustType<TTable>{}); + }, cs); + }, TTypePack<TAndSizeFitter, TModSizeFitter>{}); + }, TTypePack<TLinearProbing, TQuadraticProbing, TDenseProbing>{}); + } + + using TAtomContainers = TTypePack<TFlatContainer<int>, + TDenseContainer<int, NSet::TStaticValueMarker<-1>>>; + using TContainers = TTypePack<TFlatContainer<int>, + TDenseContainer<int, NSet::TStaticValueMarker<-1>>>; + using TRemovalContainers = TTypePack<TFlatContainer<int>, + TRemovalDenseContainer<int, NSet::TStaticValueMarker<-2>, + NSet::TStaticValueMarker<-1>>>; +} + +Y_UNIT_TEST_SUITE(TCommonTableAtomsTest) { + Y_UNIT_TEST(InitTest) { + ForEachTable<int>([](auto t) { + GET_TYPE(t) table{ INIT_SIZE }; + + UNIT_ASSERT(table.empty()); + UNIT_ASSERT_EQUAL(table.size(), 0); + UNIT_ASSERT_EQUAL(table.bucket_count(), INIT_SIZE); + UNIT_ASSERT_EQUAL(table.bucket_size(RandomNumber<size_t>(INIT_SIZE)), 0); + }, TAtomContainers{}); + } + + Y_UNIT_TEST(IteratorTest) { + ForEachTable<int>([](auto t) { + GET_TYPE(t) table{ INIT_SIZE }; + + auto first = table.begin(); + auto last = table.end(); + UNIT_ASSERT_EQUAL(first, last); + UNIT_ASSERT_EQUAL(std::distance(first, last), 0); + + auto cFirst = table.cbegin(); + auto cLast = table.cend(); + UNIT_ASSERT_EQUAL(cFirst, cLast); + UNIT_ASSERT_EQUAL(std::distance(cFirst, cLast), 0); + }, TAtomContainers{}); + } + + Y_UNIT_TEST(ContainsAndCountTest) { + ForEachTable<int>([](auto t) { + GET_TYPE(t) table{ INIT_SIZE }; + + for (int i = 0; i < 100; ++i) { + UNIT_ASSERT_EQUAL(table.count(i), 0); + UNIT_ASSERT(!table.contains(i)); + } + }, TAtomContainers{}); + } + + Y_UNIT_TEST(FindTest) { + ForEachTable<int>([](auto t) { + GET_TYPE(t) table{ INIT_SIZE }; + + for (int i = 0; i < 100; ++i) { + auto it = table.find(i); + UNIT_ASSERT_EQUAL(it, table.end()); + } + }, TAtomContainers{}); + } +} + +Y_UNIT_TEST_SUITE(TCommonTableTest) { + Y_UNIT_TEST(InsertTest) { + ForEachTable<int>([](auto t) { + GET_TYPE(t) table{ INIT_SIZE }; + + UNIT_ASSERT(table.empty()); + UNIT_ASSERT_EQUAL(table.size(), 0); + + int toInsert = RandomNumber<size_t>(100); + UNIT_ASSERT_EQUAL(table.count(toInsert), 0); + UNIT_ASSERT(!table.contains(toInsert)); + + auto p = table.insert(toInsert); + UNIT_ASSERT_EQUAL(p.first, table.begin()); + UNIT_ASSERT(p.second); + + UNIT_ASSERT(!table.empty()); + UNIT_ASSERT_EQUAL(table.size(), 1); + UNIT_ASSERT_EQUAL(table.count(toInsert), 1); + UNIT_ASSERT(table.contains(toInsert)); + + auto it = table.find(toInsert); + UNIT_ASSERT_UNEQUAL(it, table.end()); + UNIT_ASSERT_EQUAL(it, table.begin()); + UNIT_ASSERT_EQUAL(*it, toInsert); + + auto p2 = table.insert(toInsert); + UNIT_ASSERT_EQUAL(p.first, p2.first); + UNIT_ASSERT(!p2.second); + + UNIT_ASSERT_EQUAL(table.size(), 1); + UNIT_ASSERT_EQUAL(table.count(toInsert), 1); + UNIT_ASSERT(table.contains(toInsert)); + }, TContainers{}); + } + + Y_UNIT_TEST(ClearTest) { + ForEachTable<int>([](auto t) { + GET_TYPE(t) table{ INIT_SIZE }; + + TVector<int> toInsert(INIT_SIZE); + Iota(toInsert.begin(), toInsert.end(), 0); + ShuffleRange(toInsert); + toInsert.resize(INIT_SIZE / 3); + + for (auto i : toInsert) { 
+ auto p = table.insert(i); + UNIT_ASSERT_EQUAL(*p.first, i); + UNIT_ASSERT(p.second); + } + UNIT_ASSERT_EQUAL(table.size(), toInsert.size()); + UNIT_ASSERT_EQUAL((size_t)std::distance(table.begin(), table.end()), toInsert.size()); + + for (auto i : toInsert) { + UNIT_ASSERT(table.contains(i)); + UNIT_ASSERT_EQUAL(table.count(i), 1); + } + + auto bc = table.bucket_count(); + table.clear(); + UNIT_ASSERT(table.empty()); + UNIT_ASSERT_EQUAL(table.bucket_count(), bc); + + for (auto i : toInsert) { + UNIT_ASSERT(!table.contains(i)); + UNIT_ASSERT_EQUAL(table.count(i), 0); + } + + table.insert(toInsert.front()); + UNIT_ASSERT(!table.empty()); + }, TContainers{}); + } + + Y_UNIT_TEST(CopyMoveTest) { + ForEachTable<int>([](auto t) { + GET_TYPE(t) table{ INIT_SIZE }; + + TVector<int> toInsert(INIT_SIZE); + Iota(toInsert.begin(), toInsert.end(), 0); + ShuffleRange(toInsert); + toInsert.resize(INIT_SIZE / 3); + + for (auto i : toInsert) { + auto p = table.insert(i); + UNIT_ASSERT_EQUAL(*p.first, i); + UNIT_ASSERT(p.second); + } + UNIT_ASSERT_EQUAL(table.size(), toInsert.size()); + UNIT_ASSERT_EQUAL((size_t)std::distance(table.begin(), table.end()), toInsert.size()); + + for (auto i : toInsert) { + UNIT_ASSERT(table.contains(i)); + UNIT_ASSERT_EQUAL(table.count(i), 1); + } + + // Copy construction test + auto table2 = table; + UNIT_ASSERT_EQUAL(table2.size(), table.size()); + UNIT_ASSERT_EQUAL((size_t)std::distance(table2.begin(), table2.end()), table.size()); + for (auto i : table) { + UNIT_ASSERT(table2.contains(i)); + UNIT_ASSERT_EQUAL(table2.count(i), 1); + } + + table2.clear(); + UNIT_ASSERT(table2.empty()); + + // Copy assignment test + table2 = table; + UNIT_ASSERT_EQUAL(table2.size(), table.size()); + UNIT_ASSERT_EQUAL((size_t)std::distance(table2.begin(), table2.end()), table.size()); + for (auto i : table) { + UNIT_ASSERT(table2.contains(i)); + UNIT_ASSERT_EQUAL(table2.count(i), 1); + } + + // Move construction test + auto table3 = std::move(table2); + UNIT_ASSERT(table2.empty()); + UNIT_ASSERT(table2.bucket_count() > 0); + + UNIT_ASSERT_EQUAL(table3.size(), table.size()); + UNIT_ASSERT_EQUAL((size_t)std::distance(table3.begin(), table3.end()), table.size()); + for (auto i : table) { + UNIT_ASSERT(table3.contains(i)); + UNIT_ASSERT_EQUAL(table3.count(i), 1); + } + + table2.insert(toInsert.front()); + UNIT_ASSERT(!table2.empty()); + UNIT_ASSERT_EQUAL(table2.size(), 1); + UNIT_ASSERT_UNEQUAL(table2.bucket_count(), 0); + + // Move assignment test + table2 = std::move(table3); + UNIT_ASSERT(table3.empty()); + UNIT_ASSERT(table3.bucket_count() > 0); + + UNIT_ASSERT_EQUAL(table2.size(), table.size()); + UNIT_ASSERT_EQUAL((size_t)std::distance(table2.begin(), table2.end()), table.size()); + for (auto i : table) { + UNIT_ASSERT(table2.contains(i)); + UNIT_ASSERT_EQUAL(table2.count(i), 1); + } + + table3.insert(toInsert.front()); + UNIT_ASSERT(!table3.empty()); + UNIT_ASSERT_EQUAL(table3.size(), 1); + UNIT_ASSERT_UNEQUAL(table3.bucket_count(), 0); + }, TContainers{}); + } + + Y_UNIT_TEST(RehashTest) { + ForEachTable<int>([](auto t) { + GET_TYPE(t) table{ INIT_SIZE }; + + TVector<int> toInsert(INIT_SIZE); + Iota(toInsert.begin(), toInsert.end(), 0); + ShuffleRange(toInsert); + toInsert.resize(INIT_SIZE / 3); + + for (auto i : toInsert) { + table.insert(i); + } + + auto bc = table.bucket_count(); + table.rehash(bc * 2); + UNIT_ASSERT(bc * 2 <= table.bucket_count()); + + UNIT_ASSERT_EQUAL(table.size(), toInsert.size()); + UNIT_ASSERT_EQUAL((size_t)std::distance(table.begin(), table.end()), 
toInsert.size()); + for (auto i : toInsert) { + UNIT_ASSERT(table.contains(i)); + UNIT_ASSERT_EQUAL(table.count(i), 1); + } + + TVector<int> tmp(table.begin(), table.end()); + Sort(toInsert.begin(), toInsert.end()); + Sort(tmp.begin(), tmp.end()); + + UNIT_ASSERT_VALUES_EQUAL(tmp, toInsert); + + table.rehash(0); + UNIT_ASSERT_EQUAL(table.size(), toInsert.size()); + UNIT_ASSERT(table.bucket_count() > table.size()); + + table.clear(); + UNIT_ASSERT(table.empty()); + table.rehash(INIT_SIZE); + UNIT_ASSERT(table.bucket_count() >= INIT_SIZE); + + table.rehash(0); + UNIT_ASSERT(table.bucket_count() > 0); + }, TContainers{}); + } + + Y_UNIT_TEST(EraseTest) { + ForEachTable<int>([](auto t) { + GET_TYPE(t) table{ INIT_SIZE }; + + int value = RandomNumber<ui32>(); + table.insert(value); + UNIT_ASSERT_EQUAL(table.size(), 1); + UNIT_ASSERT_EQUAL(table.count(value), 1); + + auto it = table.find(value); + table.erase(it); + + UNIT_ASSERT_EQUAL(table.size(), 0); + UNIT_ASSERT_EQUAL(table.count(value), 0); + + table.insert(value); + UNIT_ASSERT_EQUAL(table.size(), 1); + UNIT_ASSERT_EQUAL(table.count(value), 1); + + table.erase(value); + + UNIT_ASSERT_EQUAL(table.size(), 0); + UNIT_ASSERT_EQUAL(table.count(value), 0); + + table.insert(value); + UNIT_ASSERT_EQUAL(table.size(), 1); + UNIT_ASSERT_EQUAL(table.count(value), 1); + + table.erase(table.find(value), table.end()); + + UNIT_ASSERT_EQUAL(table.size(), 0); + UNIT_ASSERT_EQUAL(table.count(value), 0); + + table.erase(value); + + UNIT_ASSERT_EQUAL(table.size(), 0); + UNIT_ASSERT_EQUAL(table.count(value), 0); + }, TRemovalContainers{}); + } + + Y_UNIT_TEST(EraseBigTest) { + ForEachTable<int>([](auto t) { + GET_TYPE(t) table{ BIG_INIT_SIZE }; + + for (int i = 0; i < 1000; ++i) { + for (int j = 0; j < static_cast<int>(BIG_INIT_SIZE); ++j) { + table.emplace(j); + } + for (int j = 0; j < static_cast<int>(BIG_INIT_SIZE); ++j) { + table.erase(j); + } + } + UNIT_ASSERT(table.bucket_count() <= BIG_INIT_SIZE * 8); + }, TRemovalContainers{}); + } + + Y_UNIT_TEST(ConstructWithSizeTest) { + ForEachTable<int>([](auto t) { + GET_TYPE(t) table{ 1000 }; + UNIT_ASSERT(table.bucket_count() >= 1000); + + int value = RandomNumber<ui32>(); + table.insert(value); + UNIT_ASSERT_EQUAL(table.size(), 1); + UNIT_ASSERT_EQUAL(table.count(value), 1); + UNIT_ASSERT(table.bucket_count() >= 1000); + + table.rehash(10); + UNIT_ASSERT_EQUAL(table.size(), 1); + UNIT_ASSERT_EQUAL(table.count(value), 1); + UNIT_ASSERT(table.bucket_count() < 1000); + UNIT_ASSERT(table.bucket_count() >= 10); + }, TContainers{}); + } +} diff --git a/library/cpp/containers/flat_hash/lib/ut/ya.make b/library/cpp/containers/flat_hash/lib/ut/ya.make new file mode 100644 index 00000000000..04d65a8c6e6 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/ut/ya.make @@ -0,0 +1,17 @@ +UNITTEST() + +OWNER(tender-bum) + +SRCS( + size_fitters_ut.cpp + probings_ut.cpp + containers_ut.cpp + iterator_ut.cpp + table_ut.cpp +) + +PEERDIR( + library/cpp/containers/flat_hash/lib +) + +END() diff --git a/library/cpp/containers/flat_hash/lib/value_markers.cpp b/library/cpp/containers/flat_hash/lib/value_markers.cpp new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/value_markers.cpp diff --git a/library/cpp/containers/flat_hash/lib/value_markers.h b/library/cpp/containers/flat_hash/lib/value_markers.h new file mode 100644 index 00000000000..99351586b50 --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/value_markers.h @@ -0,0 +1,130 @@ +#pragma once + +#include 
"concepts/value_marker.h" + +#include <type_traits> +#include <tuple> + +namespace NFlatHash { + +namespace NSet { + +template <auto Value> +struct TStaticValueMarker { + using value_type = decltype(Value); + + constexpr auto Create() const noexcept { + return Value; + } + + template <class U> + bool Equals(const U& rhs) const { + return Value == rhs; + } +}; + +static_assert(NConcepts::ValueMarkerV<TStaticValueMarker<5>>); + +template <class T> +class TEqValueMarker { +public: + using value_type = T; + + template <class V, class = std::enable_if_t<std::is_constructible_v<T, std::decay_t<V>>>> + TEqValueMarker(V&& v) : Value_(std::forward<V>(v)) {} + + TEqValueMarker(const TEqValueMarker&) = default; + TEqValueMarker(TEqValueMarker&&) = default; + + TEqValueMarker& operator=(const TEqValueMarker&) = default; + TEqValueMarker& operator=(TEqValueMarker&&) = default; + + const T& Create() const noexcept { + return Value_; + } + + template <class U> + bool Equals(const U& rhs) const { + return Value_ == rhs; + } + +private: + T Value_; +}; + +static_assert(NConcepts::ValueMarkerV<TEqValueMarker<int>>); + +} // namespace NSet + +namespace NMap { + +template <auto Key, class T> +class TStaticValueMarker { + static_assert(std::is_default_constructible_v<T>); + +public: + using value_type = std::pair<decltype(Key), T>; + + TStaticValueMarker() = default; + + TStaticValueMarker(const TStaticValueMarker&) {} + TStaticValueMarker(TStaticValueMarker&&) {} + + TStaticValueMarker& operator=(const TStaticValueMarker&) noexcept { return *this; } + TStaticValueMarker& operator=(TStaticValueMarker&&) noexcept { return *this; } + + std::pair<decltype(Key), const T&> Create() const noexcept { return { Key, Value_ }; } + + template <class U> + bool Equals(const U& rhs) const { + return Key == rhs.first; + } + +private: + T Value_; +}; + +static_assert(NConcepts::ValueMarkerV<TStaticValueMarker<5, int>>); + +template <class Key, class T> +class TEqValueMarker { + static_assert(std::is_default_constructible_v<T>); + +public: + using value_type = std::pair<Key, T>; + + template <class V, class = std::enable_if_t<std::is_constructible_v<Key, std::decay_t<V>>>> + TEqValueMarker(V&& v) : Key_(std::forward<V>(v)) {} + + TEqValueMarker(const TEqValueMarker& vm) + : Key_(vm.Key_) {} + TEqValueMarker(TEqValueMarker&& vm) noexcept(std::is_nothrow_move_constructible_v<Key> + && std::is_nothrow_constructible_v<T>) + : Key_(std::move(vm.Key_)) {} + + TEqValueMarker& operator=(const TEqValueMarker& vm) { + Key_ = vm.Key_; + return *this; + } + TEqValueMarker& operator=(TEqValueMarker&& vm) noexcept(std::is_nothrow_move_assignable_v<Key>) { + Key_ = std::move(vm.Key_); + return *this; + } + + auto Create() const noexcept { return std::tie(Key_, Value_); } + + template <class U> + bool Equals(const U& rhs) const { + return Key_ == rhs.first; + } + +private: + Key Key_; + T Value_; +}; + +static_assert(NConcepts::ValueMarkerV<TEqValueMarker<int, int>>); + +} // namespace NMap + +} // namespace NFlatHash diff --git a/library/cpp/containers/flat_hash/lib/ya.make b/library/cpp/containers/flat_hash/lib/ya.make new file mode 100644 index 00000000000..afaa69110bf --- /dev/null +++ b/library/cpp/containers/flat_hash/lib/ya.make @@ -0,0 +1,17 @@ +LIBRARY() + +OWNER(tender-bum) + +SRCS( + containers.cpp + expanders.cpp + iterator.cpp + map.cpp + probings.cpp + set.cpp + size_fitters.cpp + table.cpp + value_markers.cpp +) + +END() diff --git a/library/cpp/containers/flat_hash/ut/flat_hash_ut.cpp 
b/library/cpp/containers/flat_hash/ut/flat_hash_ut.cpp new file mode 100644 index 00000000000..2b9d6a1dc2a --- /dev/null +++ b/library/cpp/containers/flat_hash/ut/flat_hash_ut.cpp @@ -0,0 +1,272 @@ +#include <library/cpp/containers/flat_hash/flat_hash.h> + +#include <library/cpp/testing/unittest/registar.h> + +using namespace NFH; + +namespace { + +constexpr size_t TEST_INIT_SIZE = 25; +constexpr std::initializer_list<int> SET_INPUT_SAMPLE{1, 2, 3, 4, 5}; +const std::initializer_list<std::pair<const int, TString>> MAP_INPUT_SAMPLE{ + {1, "a"}, + {2, "b"}, + {3, "c"}, + {4, "d"}, + {5, "e"} +}; + +} // namespace + +template <class Map> +class TMapTest : public TTestBase { + void AllocatorTest(); + + void SmokingTest() { + Map mp; + mp.emplace(5, "abc"); + + UNIT_ASSERT_EQUAL(mp.size(), 1); + UNIT_ASSERT(mp.contains(5)); + + auto it = mp.find(5); + UNIT_ASSERT_EQUAL(mp.begin(), it); + UNIT_ASSERT(!mp.empty()); + } + + void CopyConstructionTest() { + Map st(MAP_INPUT_SAMPLE); + auto st2 = st; + + UNIT_ASSERT(!st.empty()); + UNIT_ASSERT(!st2.empty()); + UNIT_ASSERT_EQUAL(st, st2); + } + + void MoveConstructionTest() { + Map st(MAP_INPUT_SAMPLE); + auto st2 = std::move(st); + + UNIT_ASSERT(st.empty()); + UNIT_ASSERT(!st2.empty()); + UNIT_ASSERT_UNEQUAL(st, st2); + } + + void CopyAssignmentTest() { + Map st(MAP_INPUT_SAMPLE); + Map st2; + UNIT_ASSERT_UNEQUAL(st, st2); + UNIT_ASSERT(st2.empty()); + + st2 = st; + UNIT_ASSERT_EQUAL(st, st2); + UNIT_ASSERT(!st2.empty()); + } + + void DoubleCopyAssignmentTest() { + Map st(MAP_INPUT_SAMPLE); + Map st2; + UNIT_ASSERT_UNEQUAL(st, st2); + UNIT_ASSERT(st2.empty()); + + st2 = st; + UNIT_ASSERT_EQUAL(st, st2); + UNIT_ASSERT(!st2.empty()); + + st2 = st; + UNIT_ASSERT_EQUAL(st, st2); + UNIT_ASSERT(!st2.empty()); + } + + void MoveAssignmentTest() { + Map st(MAP_INPUT_SAMPLE); + Map st2; + UNIT_ASSERT_UNEQUAL(st, st2); + UNIT_ASSERT(st2.empty()); + + st2 = std::move(st); + UNIT_ASSERT_UNEQUAL(st, st2); + UNIT_ASSERT(!st2.empty()); + UNIT_ASSERT(st.empty()); + } + + void InsertOrAssignTest() { + Map mp; + + auto p = mp.insert_or_assign(5, "abc"); + UNIT_ASSERT_EQUAL(p.first, mp.begin()); + UNIT_ASSERT(p.second); + UNIT_ASSERT_EQUAL(p.first->first, 5); + UNIT_ASSERT_EQUAL(p.first->second, "abc"); + + auto p2 = mp.insert_or_assign(5, "def"); + UNIT_ASSERT_EQUAL(p.first, p2.first); + UNIT_ASSERT(!p2.second); + UNIT_ASSERT_EQUAL(p2.first->first, 5); + UNIT_ASSERT_EQUAL(p2.first->second, "def"); + } + + void TryEmplaceTest() { + Map mp; + + auto p = mp.try_emplace(5, "abc"); + UNIT_ASSERT_EQUAL(p.first, mp.begin()); + UNIT_ASSERT(p.second); + UNIT_ASSERT_EQUAL(p.first->first, 5); + UNIT_ASSERT_EQUAL(p.first->second, "abc"); + + auto p2 = mp.try_emplace(5, "def"); + UNIT_ASSERT_EQUAL(p.first, p2.first); + UNIT_ASSERT(!p2.second); + UNIT_ASSERT_EQUAL(p2.first->first, 5); + UNIT_ASSERT_EQUAL(p.first->second, "abc"); + } + + UNIT_TEST_SUITE_DEMANGLE(TMapTest); + UNIT_TEST(AllocatorTest); + UNIT_TEST(SmokingTest); + UNIT_TEST(CopyConstructionTest); + UNIT_TEST(MoveConstructionTest); + UNIT_TEST(CopyAssignmentTest); + UNIT_TEST(DoubleCopyAssignmentTest); + UNIT_TEST(MoveAssignmentTest); + UNIT_TEST(InsertOrAssignTest); + UNIT_TEST(TryEmplaceTest); + UNIT_TEST_SUITE_END(); +}; + +template <> +void TMapTest<TFlatHashMap<int, TString>>::AllocatorTest() { + using Map = TFlatHashMap<int, TString>; + Map mp(3, typename Map::allocator_type()); +} + +template <> +void TMapTest<TDenseHashMapStaticMarker<int, TString, -1>>::AllocatorTest() { + using Map = 
TDenseHashMapStaticMarker<int, TString, -1>; + Map mp(3, NFlatHash::NMap::TStaticValueMarker<-1, TString>(), typename Map::allocator_type()); +} + +using TFlatHashMapTest = TMapTest<TFlatHashMap<int, TString>>; +using TDenseHashMapTest = TMapTest<TDenseHashMapStaticMarker<int, TString, -1>>; + +UNIT_TEST_SUITE_REGISTRATION(TFlatHashMapTest); +UNIT_TEST_SUITE_REGISTRATION(TDenseHashMapTest); + + +template <class Set> +class TSetTest : public TTestBase { + void AllocatorTest(); + void DefaultConstructTest() { + Set st; + + UNIT_ASSERT(st.empty()); + UNIT_ASSERT_EQUAL(st.size(), 0); + UNIT_ASSERT(st.bucket_count() > 0); + UNIT_ASSERT_EQUAL(st.begin(), st.end()); + UNIT_ASSERT(st.load_factor() < std::numeric_limits<float>::epsilon()); + } + + void InitCapacityConstructTest() { + Set st(TEST_INIT_SIZE); + + UNIT_ASSERT(st.empty()); + UNIT_ASSERT_EQUAL(st.size(), 0); + UNIT_ASSERT(st.bucket_count() >= TEST_INIT_SIZE); + UNIT_ASSERT_EQUAL(st.begin(), st.end()); + UNIT_ASSERT(st.load_factor() < std::numeric_limits<float>::epsilon()); + } + + void IteratorsConstructTest() { + Set st(SET_INPUT_SAMPLE.begin(), SET_INPUT_SAMPLE.end()); + + UNIT_ASSERT(!st.empty()); + UNIT_ASSERT_EQUAL(st.size(), SET_INPUT_SAMPLE.size()); + UNIT_ASSERT(st.bucket_count() >= st.size()); + UNIT_ASSERT_UNEQUAL(st.begin(), st.end()); + UNIT_ASSERT_EQUAL(static_cast<size_t>(std::distance(st.begin(), st.end())), st.size()); + UNIT_ASSERT(st.load_factor() > 0); + } + + void InitializerListConstructTest() { + Set st(SET_INPUT_SAMPLE); + + UNIT_ASSERT(!st.empty()); + UNIT_ASSERT(st.size() > 0); + UNIT_ASSERT(st.bucket_count() > 0); + UNIT_ASSERT_UNEQUAL(st.begin(), st.end()); + UNIT_ASSERT_EQUAL(static_cast<size_t>(std::distance(st.begin(), st.end())), st.size()); + UNIT_ASSERT(st.load_factor() > 0); + } + + void CopyConstructionTest() { + Set st(SET_INPUT_SAMPLE); + auto st2 = st; + + UNIT_ASSERT(!st.empty()); + UNIT_ASSERT(!st2.empty()); + UNIT_ASSERT_EQUAL(st, st2); + } + + void MoveConstructionTest() { + Set st(SET_INPUT_SAMPLE); + auto st2 = std::move(st); + + UNIT_ASSERT(st.empty()); + UNIT_ASSERT(!st2.empty()); + UNIT_ASSERT_UNEQUAL(st, st2); + } + + void CopyAssignmentTest() { + Set st(SET_INPUT_SAMPLE); + Set st2; + UNIT_ASSERT_UNEQUAL(st, st2); + UNIT_ASSERT(st2.empty()); + + st2 = st; + UNIT_ASSERT_EQUAL(st, st2); + UNIT_ASSERT(!st2.empty()); + } + + void MoveAssignmentTest() { + Set st(SET_INPUT_SAMPLE); + Set st2; + UNIT_ASSERT_UNEQUAL(st, st2); + UNIT_ASSERT(st2.empty()); + + st2 = std::move(st); + UNIT_ASSERT_UNEQUAL(st, st2); + UNIT_ASSERT(!st2.empty()); + UNIT_ASSERT(st.empty()); + } + + UNIT_TEST_SUITE_DEMANGLE(TSetTest); + UNIT_TEST(AllocatorTest); + UNIT_TEST(DefaultConstructTest); + UNIT_TEST(InitCapacityConstructTest); + UNIT_TEST(IteratorsConstructTest); + UNIT_TEST(InitializerListConstructTest); + UNIT_TEST(CopyConstructionTest); + UNIT_TEST(MoveConstructionTest); + UNIT_TEST(CopyAssignmentTest); + UNIT_TEST(MoveAssignmentTest); + UNIT_TEST_SUITE_END(); +}; + +template <> +void TSetTest<TFlatHashSet<int>>::AllocatorTest() { + using Map = TFlatHashSet<int>; + Map mp(3, typename Map::allocator_type()); +} + +template <> +void TSetTest<TDenseHashSetStaticMarker<int, -1>>::AllocatorTest() { + using Map = TDenseHashSetStaticMarker<int, -1>; + Map mp(3, NFlatHash::NSet::TStaticValueMarker<-1>(), typename Map::allocator_type()); +} + +using TFlatHashSetTest = TSetTest<TFlatHashSet<int, THash<int>>>; +using TDenseHashSetTest = TSetTest<TDenseHashSetStaticMarker<int, -1>>; + 
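
The flat_hash_ut.cpp tests exercise the public NFH aliases: TFlatHashMap and TFlatHashSet over the flat container, and TDenseHashMapStaticMarker and TDenseHashSetStaticMarker over the dense container with a compile-time empty marker. A minimal usage sketch assembled only from calls those tests make:

#include <library/cpp/containers/flat_hash/flat_hash.h>

#include <util/generic/string.h>

void FlatHashSketch() {
    NFH::TFlatHashMap<int, TString> mp;
    mp.emplace(5, "abc");
    mp.insert_or_assign(5, "def");   // key present -> value overwritten with "def"
    mp.try_emplace(5, "ghi");        // key present -> no-op, value stays "def"

    // Dense variant: -1 marks empty buckets, so it is reserved and should not be used as a key.
    NFH::TDenseHashSetStaticMarker<int, -1> st{ 1, 2, 3 };
    bool hasTwo = st.contains(2);    // true
    (void)hasTwo;
}
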
+UNIT_TEST_SUITE_REGISTRATION(TFlatHashSetTest); +UNIT_TEST_SUITE_REGISTRATION(TDenseHashSetTest); diff --git a/library/cpp/containers/flat_hash/ut/ya.make b/library/cpp/containers/flat_hash/ut/ya.make new file mode 100644 index 00000000000..1d33d361208 --- /dev/null +++ b/library/cpp/containers/flat_hash/ut/ya.make @@ -0,0 +1,13 @@ +UNITTEST() + +OWNER(tender-bum) + +SRCS( + flat_hash_ut.cpp +) + +PEERDIR( + library/cpp/containers/flat_hash +) + +END() diff --git a/library/cpp/containers/flat_hash/ya.make b/library/cpp/containers/flat_hash/ya.make new file mode 100644 index 00000000000..612e2c1cdea --- /dev/null +++ b/library/cpp/containers/flat_hash/ya.make @@ -0,0 +1,13 @@ +LIBRARY() + +OWNER(tender-bum) + +PEERDIR( + library/cpp/containers/flat_hash/lib +) + +SRCS( + flat_hash.cpp +) + +END() diff --git a/library/cpp/containers/intrusive_avl_tree/avltree.cpp b/library/cpp/containers/intrusive_avl_tree/avltree.cpp new file mode 100644 index 00000000000..dd27c7df419 --- /dev/null +++ b/library/cpp/containers/intrusive_avl_tree/avltree.cpp @@ -0,0 +1 @@ +#include "avltree.h" diff --git a/library/cpp/containers/intrusive_avl_tree/avltree.h b/library/cpp/containers/intrusive_avl_tree/avltree.h new file mode 100644 index 00000000000..a58c63b07c5 --- /dev/null +++ b/library/cpp/containers/intrusive_avl_tree/avltree.h @@ -0,0 +1,754 @@ +#pragma once + +#include <util/generic/noncopyable.h> + +template <class T, class C> +struct TAvlTreeItem; + +template <class T, class C> +class TAvlTree: public TNonCopyable { + using TTreeItem = TAvlTreeItem<T, C>; + friend struct TAvlTreeItem<T, C>; + + static inline const T* AsT(const TTreeItem* item) noexcept { + return (const T*)item; + } + + static inline T* AsT(TTreeItem* item) noexcept { + return (T*)item; + } + + template <class TTreeItem, class TValue> + class TIteratorBase { + public: + inline TIteratorBase(TTreeItem* p, const TAvlTree* t) noexcept + : Ptr_(p) + , Tree_(t) + { + } + + inline bool IsEnd() const noexcept { + return Ptr_ == nullptr; + } + + inline bool IsBegin() const noexcept { + return Ptr_ == nullptr; + } + + inline bool IsFirst() const noexcept { + return Ptr_ && Ptr_ == Tree_->Head_; + } + + inline bool IsLast() const noexcept { + return Ptr_ && Ptr_ == Tree_->Tail_; + } + + inline TValue& operator*() const noexcept { + return *AsT(Ptr_); + } + + inline TValue* operator->() const noexcept { + return AsT(Ptr_); + } + + inline TTreeItem* Inc() noexcept { + return Ptr_ = FindNext(Ptr_); + } + + inline TTreeItem* Dec() noexcept { + return Ptr_ = FindPrev(Ptr_); + } + + inline TIteratorBase& operator++() noexcept { + Inc(); + return *this; + } + + inline TIteratorBase operator++(int) noexcept { + TIteratorBase ret(*this); + Inc(); + return ret; + } + + inline TIteratorBase& operator--() noexcept { + Dec(); + return *this; + } + + inline TIteratorBase operator--(int) noexcept { + TIteratorBase ret(*this); + Dec(); + return ret; + } + + inline TIteratorBase Next() const noexcept { + return ConstructNext(*this); + } + + inline TIteratorBase Prev() const noexcept { + return ConstructPrev(*this); + } + + inline bool operator==(const TIteratorBase& r) const noexcept { + return Ptr_ == r.Ptr_; + } + + inline bool operator!=(const TIteratorBase& r) const noexcept { + return Ptr_ != r.Ptr_; + } + + private: + inline static TIteratorBase ConstructNext(const TIteratorBase& i) noexcept { + return TIterator(FindNext(i.Ptr_), i.Tree_); + } + + inline static TIteratorBase ConstructPrev(const TIteratorBase& i) noexcept { + return 
TIterator(FindPrev(i.Ptr_), i.Tree_); + } + + inline static TIteratorBase FindPrev(TTreeItem* el) noexcept { + if (el->Left_ != nullptr) { + el = el->Left_; + + while (el->Right_ != nullptr) { + el = el->Right_; + } + } else { + while (true) { + TTreeItem* last = el; + el = el->Parent_; + + if (el == nullptr || el->Right_ == last) { + break; + } + } + } + + return el; + } + + static TTreeItem* FindNext(TTreeItem* el) { + if (el->Right_ != nullptr) { + el = el->Right_; + + while (el->Left_) { + el = el->Left_; + } + } else { + while (true) { + TTreeItem* last = el; + el = el->Parent_; + + if (el == nullptr || el->Left_ == last) { + break; + } + } + } + + return el; + } + + private: + TTreeItem* Ptr_; + const TAvlTree* Tree_; + }; + + using TConstIterator = TIteratorBase<const TTreeItem, const T>; + using TIterator = TIteratorBase<TTreeItem, T>; + + static inline TConstIterator ConstructFirstConst(const TAvlTree* t) noexcept { + return TConstIterator(t->Head_, t); + } + + static inline TIterator ConstructFirst(const TAvlTree* t) noexcept { + return TIterator(t->Head_, t); + } + + static inline TConstIterator ConstructLastConst(const TAvlTree* t) noexcept { + return TConstIterator(t->Tail_, t); + } + + static inline TIterator ConstructLast(const TAvlTree* t) noexcept { + return TIterator(t->Tail_, t); + } + + static inline bool Compare(const TTreeItem& l, const TTreeItem& r) { + return C::Compare(*AsT(&l), *AsT(&r)); + } + +public: + using const_iterator = TConstIterator; + using iterator = TIterator; + + inline TAvlTree() noexcept + : Root_(nullptr) + , Head_(nullptr) + , Tail_(nullptr) + { + } + + inline ~TAvlTree() noexcept { + Clear(); + } + + inline void Clear() noexcept { + for (iterator it = Begin(); it != End();) { + (it++)->TTreeItem::Unlink(); + } + } + + inline T* Insert(TTreeItem* el, TTreeItem** lastFound = nullptr) noexcept { + el->Unlink(); + el->Tree_ = this; + + TTreeItem* curEl = Root_; + TTreeItem* parentEl = nullptr; + TTreeItem* lastLess = nullptr; + + while (true) { + if (curEl == nullptr) { + AttachRebal(el, parentEl, lastLess); + + if (lastFound != nullptr) { + *lastFound = el; + } + + return AsT(el); + } + + if (Compare(*el, *curEl)) { + parentEl = lastLess = curEl; + curEl = curEl->Left_; + } else if (Compare(*curEl, *el)) { + parentEl = curEl; + curEl = curEl->Right_; + } else { + if (lastFound != nullptr) { + *lastFound = curEl; + } + + return nullptr; + } + } + } + + inline T* Find(const TTreeItem* el) const noexcept { + TTreeItem* curEl = Root_; + + while (curEl) { + if (Compare(*el, *curEl)) { + curEl = curEl->Left_; + } else if (Compare(*curEl, *el)) { + curEl = curEl->Right_; + } else { + return AsT(curEl); + } + } + + return nullptr; + } + + inline T* LowerBound(const TTreeItem* el) const noexcept { + TTreeItem* curEl = Root_; + TTreeItem* lowerBound = nullptr; + + while (curEl) { + if (Compare(*el, *curEl)) { + lowerBound = curEl; + curEl = curEl->Left_; + } else if (Compare(*curEl, *el)) { + curEl = curEl->Right_; + } else { + return AsT(curEl); + } + } + + return AsT(lowerBound); + } + + inline T* Erase(TTreeItem* el) noexcept { + if (el->Tree_ == this) { + return this->EraseImpl(el); + } + + return nullptr; + } + + inline T* EraseImpl(TTreeItem* el) noexcept { + el->Tree_ = nullptr; + + TTreeItem* replacement; + TTreeItem* fixfrom; + long lheight, rheight; + + if (el->Right_) { + replacement = el->Right_; + + while (replacement->Left_) { + replacement = replacement->Left_; + } + + if (replacement->Parent_ == el) { + fixfrom = replacement; + } else { + 
fixfrom = replacement->Parent_; + } + + if (el == Head_) { + Head_ = replacement; + } + + RemoveEl(replacement, replacement->Right_); + ReplaceEl(el, replacement); + } else if (el->Left_) { + replacement = el->Left_; + + while (replacement->Right_) { + replacement = replacement->Right_; + } + + if (replacement->Parent_ == el) { + fixfrom = replacement; + } else { + fixfrom = replacement->Parent_; + } + + if (el == Tail_) { + Tail_ = replacement; + } + + RemoveEl(replacement, replacement->Left_); + ReplaceEl(el, replacement); + } else { + fixfrom = el->Parent_; + + if (el == Head_) { + Head_ = el->Parent_; + } + + if (el == Tail_) { + Tail_ = el->Parent_; + } + + RemoveEl(el, nullptr); + } + + if (fixfrom == nullptr) { + return AsT(el); + } + + RecalcHeights(fixfrom); + + TTreeItem* ub = FindFirstUnbalEl(fixfrom); + + while (ub) { + lheight = ub->Left_ ? ub->Left_->Height_ : 0; + rheight = ub->Right_ ? ub->Right_->Height_ : 0; + + if (rheight > lheight) { + ub = ub->Right_; + lheight = ub->Left_ ? ub->Left_->Height_ : 0; + rheight = ub->Right_ ? ub->Right_->Height_ : 0; + + if (rheight > lheight) { + ub = ub->Right_; + } else if (rheight < lheight) { + ub = ub->Left_; + } else { + ub = ub->Right_; + } + } else { + ub = ub->Left_; + lheight = ub->Left_ ? ub->Left_->Height_ : 0; + rheight = ub->Right_ ? ub->Right_->Height_ : 0; + if (rheight > lheight) { + ub = ub->Right_; + } else if (rheight < lheight) { + ub = ub->Left_; + } else { + ub = ub->Left_; + } + } + + fixfrom = Rebalance(ub); + ub = FindFirstUnbalEl(fixfrom); + } + + return AsT(el); + } + + inline const_iterator First() const noexcept { + return ConstructFirstConst(this); + } + + inline const_iterator Last() const noexcept { + return ConstructLastConst(this); + } + + inline const_iterator Begin() const noexcept { + return First(); + } + + inline const_iterator End() const noexcept { + return const_iterator(nullptr, this); + } + + inline const_iterator begin() const noexcept { + return Begin(); + } + + inline const_iterator end() const noexcept { + return End(); + } + + inline const_iterator cbegin() const noexcept { + return Begin(); + } + + inline const_iterator cend() const noexcept { + return End(); + } + + inline iterator First() noexcept { + return ConstructFirst(this); + } + + inline iterator Last() noexcept { + return ConstructLast(this); + } + + inline iterator Begin() noexcept { + return First(); + } + + inline iterator End() noexcept { + return iterator(nullptr, this); + } + + inline iterator begin() noexcept { + return Begin(); + } + + inline iterator end() noexcept { + return End(); + } + + inline bool Empty() const noexcept { + return const_cast<TAvlTree*>(this)->Begin() == const_cast<TAvlTree*>(this)->End(); + } + + inline explicit operator bool() const noexcept { + return !this->Empty(); + } + + template <class Functor> + inline void ForEach(Functor& f) { + iterator it = Begin(); + + while (!it.IsEnd()) { + iterator next = it; + ++next; + f(*it); + it = next; + } + } + +private: + inline TTreeItem* Rebalance(TTreeItem* n) noexcept { + long lheight, rheight; + + TTreeItem* a; + TTreeItem* b; + TTreeItem* c; + TTreeItem* t1; + TTreeItem* t2; + TTreeItem* t3; + TTreeItem* t4; + + TTreeItem* p = n->Parent_; + TTreeItem* gp = p->Parent_; + TTreeItem* ggp = gp->Parent_; + + if (gp->Right_ == p) { + if (p->Right_ == n) { + a = gp; + b = p; + c = n; + t1 = gp->Left_; + t2 = p->Left_; + t3 = n->Left_; + t4 = n->Right_; + } else { + a = gp; + b = n; + c = p; + t1 = gp->Left_; + t2 = n->Left_; + t3 = n->Right_; + t4 = 
p->Right_; + } + } else { + if (p->Right_ == n) { + a = p; + b = n; + c = gp; + t1 = p->Left_; + t2 = n->Left_; + t3 = n->Right_; + t4 = gp->Right_; + } else { + a = n; + b = p; + c = gp; + t1 = n->Left_; + t2 = n->Right_; + t3 = p->Right_; + t4 = gp->Right_; + } + } + + if (ggp == nullptr) { + Root_ = b; + } else if (ggp->Left_ == gp) { + ggp->Left_ = b; + } else { + ggp->Right_ = b; + } + + b->Parent_ = ggp; + b->Left_ = a; + a->Parent_ = b; + b->Right_ = c; + c->Parent_ = b; + a->Left_ = t1; + + if (t1 != nullptr) { + t1->Parent_ = a; + } + + a->Right_ = t2; + + if (t2 != nullptr) { + t2->Parent_ = a; + } + + c->Left_ = t3; + + if (t3 != nullptr) { + t3->Parent_ = c; + } + + c->Right_ = t4; + + if (t4 != nullptr) { + t4->Parent_ = c; + } + + lheight = a->Left_ ? a->Left_->Height_ : 0; + rheight = a->Right_ ? a->Right_->Height_ : 0; + a->Height_ = (lheight > rheight ? lheight : rheight) + 1; + + lheight = c->Left_ ? c->Left_->Height_ : 0; + rheight = c->Right_ ? c->Right_->Height_ : 0; + c->Height_ = (lheight > rheight ? lheight : rheight) + 1; + + lheight = a->Height_; + rheight = c->Height_; + b->Height_ = (lheight > rheight ? lheight : rheight) + 1; + + RecalcHeights(ggp); + + return ggp; + } + + inline void RecalcHeights(TTreeItem* el) noexcept { + long lheight, rheight, new_height; + + while (el) { + lheight = el->Left_ ? el->Left_->Height_ : 0; + rheight = el->Right_ ? el->Right_->Height_ : 0; + + new_height = (lheight > rheight ? lheight : rheight) + 1; + + if (new_height == el->Height_) { + return; + } else { + el->Height_ = new_height; + } + + el = el->Parent_; + } + } + + inline TTreeItem* FindFirstUnbalGP(TTreeItem* el) noexcept { + long lheight, rheight, balanceProp; + TTreeItem* gp; + + if (el == nullptr || el->Parent_ == nullptr || el->Parent_->Parent_ == nullptr) { + return nullptr; + } + + gp = el->Parent_->Parent_; + + while (gp != nullptr) { + lheight = gp->Left_ ? gp->Left_->Height_ : 0; + rheight = gp->Right_ ? gp->Right_->Height_ : 0; + balanceProp = lheight - rheight; + + if (balanceProp < -1 || balanceProp > 1) { + return el; + } + + el = el->Parent_; + gp = gp->Parent_; + } + + return nullptr; + } + + inline TTreeItem* FindFirstUnbalEl(TTreeItem* el) noexcept { + if (el == nullptr) { + return nullptr; + } + + while (el) { + const long lheight = el->Left_ ? el->Left_->Height_ : 0; + const long rheight = el->Right_ ? 
el->Right_->Height_ : 0; + const long balanceProp = lheight - rheight; + + if (balanceProp < -1 || balanceProp > 1) { + return el; + } + + el = el->Parent_; + } + + return nullptr; + } + + inline void ReplaceEl(TTreeItem* el, TTreeItem* replacement) noexcept { + TTreeItem* parent = el->Parent_; + TTreeItem* left = el->Left_; + TTreeItem* right = el->Right_; + + replacement->Left_ = left; + + if (left) { + left->Parent_ = replacement; + } + + replacement->Right_ = right; + + if (right) { + right->Parent_ = replacement; + } + + replacement->Parent_ = parent; + + if (parent) { + if (parent->Left_ == el) { + parent->Left_ = replacement; + } else { + parent->Right_ = replacement; + } + } else { + Root_ = replacement; + } + + replacement->Height_ = el->Height_; + } + + inline void RemoveEl(TTreeItem* el, TTreeItem* filler) noexcept { + TTreeItem* parent = el->Parent_; + + if (parent) { + if (parent->Left_ == el) { + parent->Left_ = filler; + } else { + parent->Right_ = filler; + } + } else { + Root_ = filler; + } + + if (filler) { + filler->Parent_ = parent; + } + + return; + } + + inline void AttachRebal(TTreeItem* el, TTreeItem* parentEl, TTreeItem* lastLess) { + el->Parent_ = parentEl; + el->Left_ = nullptr; + el->Right_ = nullptr; + el->Height_ = 1; + + if (parentEl != nullptr) { + if (lastLess == parentEl) { + parentEl->Left_ = el; + } else { + parentEl->Right_ = el; + } + + if (Head_->Left_ == el) { + Head_ = el; + } + + if (Tail_->Right_ == el) { + Tail_ = el; + } + } else { + Root_ = el; + Head_ = Tail_ = el; + } + + RecalcHeights(parentEl); + + TTreeItem* ub = FindFirstUnbalGP(el); + + if (ub != nullptr) { + Rebalance(ub); + } + } + +private: + TTreeItem* Root_; + TTreeItem* Head_; + TTreeItem* Tail_; +}; + +template <class T, class C> +struct TAvlTreeItem: public TNonCopyable { +public: + using TTree = TAvlTree<T, C>; + friend class TAvlTree<T, C>; + friend typename TAvlTree<T, C>::TConstIterator; + friend typename TAvlTree<T, C>::TIterator; + + inline TAvlTreeItem() noexcept + : Left_(nullptr) + , Right_(nullptr) + , Parent_(nullptr) + , Height_(0) + , Tree_(nullptr) + { + } + + inline ~TAvlTreeItem() noexcept { + Unlink(); + } + + inline void Unlink() noexcept { + if (Tree_) { + Tree_->EraseImpl(this); + } + } + +private: + TAvlTreeItem* Left_; + TAvlTreeItem* Right_; + TAvlTreeItem* Parent_; + long Height_; + TTree* Tree_; +}; diff --git a/library/cpp/containers/intrusive_avl_tree/ut/avltree_ut.cpp b/library/cpp/containers/intrusive_avl_tree/ut/avltree_ut.cpp new file mode 100644 index 00000000000..cab2365ccec --- /dev/null +++ b/library/cpp/containers/intrusive_avl_tree/ut/avltree_ut.cpp @@ -0,0 +1,103 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include <library/cpp/containers/intrusive_avl_tree/avltree.h> + +class TAvlTreeTest: public TTestBase { + UNIT_TEST_SUITE(TAvlTreeTest); + UNIT_TEST(TestLowerBound); + UNIT_TEST(TestIterator); + UNIT_TEST_SUITE_END(); + +private: + void TestLowerBound(); + void TestIterator(); + + class TIt; + struct TItCompare { + static inline bool Compare(const TIt& l, const TIt& r) noexcept; + }; + + class TIt: public TAvlTreeItem<TIt, TItCompare> { + public: + TIt(int val = 0) + : Val(val) + { + } + + int Val; + }; + + using TIts = TAvlTree<TIt, TItCompare>; +}; + +inline bool TAvlTreeTest::TItCompare::Compare(const TIt& l, const TIt& r) noexcept { + return l.Val < r.Val; +} + +UNIT_TEST_SUITE_REGISTRATION(TAvlTreeTest); + +void TAvlTreeTest::TestLowerBound() { + TIts its; + TIt it1(5); + TIt it2(2); + TIt it3(10); + TIt it4(879); + TIt 
it5(1); + TIt it6(52); + TIt it7(4); + TIt it8(5); + its.Insert(&it1); + its.Insert(&it2); + its.Insert(&it3); + its.Insert(&it4); + its.Insert(&it5); + its.Insert(&it6); + its.Insert(&it7); + its.Insert(&it8); + + TIt it_zero(0); + TIt it_large(1000); + UNIT_ASSERT_EQUAL(its.LowerBound(&it3), &it3); + UNIT_ASSERT_EQUAL(its.LowerBound(&it_zero), &it5); + UNIT_ASSERT_EQUAL(its.LowerBound(&it_large), nullptr); +} + +void TAvlTreeTest::TestIterator() { + TIts its; + TIt it1(1); + TIt it2(2); + TIt it3(3); + TIt it4(4); + TIt it5(5); + TIt it6(6); + TIt it7(7); + + its.Insert(&it3); + its.Insert(&it1); + its.Insert(&it7); + its.Insert(&it5); + its.Insert(&it4); + its.Insert(&it6); + its.Insert(&it2); + + TVector<int> res; + for (const TIt& i : its) { + res.push_back(i.Val); + } + + TVector<int> expected{1, 2, 3, 4, 5, 6, 7}; + UNIT_ASSERT_EQUAL(res, expected); + + res.clear(); + for (TIt& i : its) { + res.push_back(i.Val); + } + UNIT_ASSERT_EQUAL(res, expected); + + res.clear(); + const TIts* constIts = &its; + for (TIts::const_iterator i = constIts->begin(); i != constIts->end(); ++i) { + res.push_back(i->Val); + } + UNIT_ASSERT_EQUAL(res, expected); +} diff --git a/library/cpp/containers/intrusive_avl_tree/ut/ya.make b/library/cpp/containers/intrusive_avl_tree/ut/ya.make new file mode 100644 index 00000000000..87920306d78 --- /dev/null +++ b/library/cpp/containers/intrusive_avl_tree/ut/ya.make @@ -0,0 +1,12 @@ +UNITTEST_FOR(library/cpp/containers/intrusive_avl_tree) + +OWNER( + pg + g:util +) + +SRCS( + avltree_ut.cpp +) + +END() diff --git a/library/cpp/containers/intrusive_avl_tree/ya.make b/library/cpp/containers/intrusive_avl_tree/ya.make new file mode 100644 index 00000000000..6b061f27609 --- /dev/null +++ b/library/cpp/containers/intrusive_avl_tree/ya.make @@ -0,0 +1,12 @@ +LIBRARY() + +OWNER( + pg + g:util +) + +SRCS( + avltree.cpp +) + +END() diff --git a/library/cpp/containers/intrusive_rb_tree/fuzz/rb_tree_fuzzing.cpp b/library/cpp/containers/intrusive_rb_tree/fuzz/rb_tree_fuzzing.cpp new file mode 100644 index 00000000000..92370760b59 --- /dev/null +++ b/library/cpp/containers/intrusive_rb_tree/fuzz/rb_tree_fuzzing.cpp @@ -0,0 +1,65 @@ +#include <library/cpp/containers/intrusive_rb_tree/rb_tree.h> + +#include <util/generic/deque.h> +#include <stdint.h> +#include <stddef.h> + +struct TCmp { + template <class T> + static inline bool Compare(const T& l, const T& r) { + return l.N < r.N; + } + + template <class T> + static inline bool Compare(const T& l, ui8 r) { + return l.N < r; + } + + template <class T> + static inline bool Compare(ui8 l, const T& r) { + return l < r.N; + } +}; + +class TNode: public TRbTreeItem<TNode, TCmp> { +public: + inline TNode(ui8 n) noexcept + : N(n) + { + } + + ui8 N; +}; + +using TTree = TRbTree<TNode, TCmp>; + +extern "C" int LLVMFuzzerTestOneInput(const ui8* data, size_t size) { + TDeque<TNode> records; + const ui8 half = 128u; + TTree tree; + for (size_t i = 0; i < size; ++i) { + if (data[i] / half == 0) { + records.emplace_back(data[i] % half); + tree.Insert(&records.back()); + } else { + auto* ptr = tree.Find(data[i] % half); + if (ptr != nullptr) { + tree.Erase(ptr); + } + } + auto check = [](const TNode& node) { + size_t childrens = 1; + if (node.Left_) { + Y_ENSURE(static_cast<const TNode*>(node.Left_)->N <= node.N); + childrens += node.Left_->Children_; + } + if (node.Right_) { + Y_ENSURE(node.N <= static_cast<const TNode*>(node.Right_)->N); + childrens += node.Right_->Children_; + } + Y_ENSURE(childrens == node.Children_); + }; + 
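// Consistency check after every mutation: ForEach visits each node in order and
// the lambda above verifies the ordering against both children and that the
// cached subtree size (Children_) equals 1 plus the children's counts.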
tree.ForEach(check); + } + return 0; +} diff --git a/library/cpp/containers/intrusive_rb_tree/fuzz/ya.make b/library/cpp/containers/intrusive_rb_tree/fuzz/ya.make new file mode 100644 index 00000000000..61be9919e6d --- /dev/null +++ b/library/cpp/containers/intrusive_rb_tree/fuzz/ya.make @@ -0,0 +1,20 @@ +FUZZ() + +OWNER( + g:util + mikari +) + +SIZE(LARGE) + +TAG(ya:fat) + +PEERDIR( + library/cpp/containers/intrusive_rb_tree +) + +SRCS( + rb_tree_fuzzing.cpp +) + +END() diff --git a/library/cpp/containers/intrusive_rb_tree/rb_tree.cpp b/library/cpp/containers/intrusive_rb_tree/rb_tree.cpp new file mode 100644 index 00000000000..535b536c41f --- /dev/null +++ b/library/cpp/containers/intrusive_rb_tree/rb_tree.cpp @@ -0,0 +1 @@ +#include "rb_tree.h" diff --git a/library/cpp/containers/intrusive_rb_tree/rb_tree.h b/library/cpp/containers/intrusive_rb_tree/rb_tree.h new file mode 100644 index 00000000000..0259452a145 --- /dev/null +++ b/library/cpp/containers/intrusive_rb_tree/rb_tree.h @@ -0,0 +1,818 @@ +#pragma once + +#include <util/generic/utility.h> +#include <util/generic/yexception.h> + +using TRbTreeColorType = bool; + +#define RBTreeRed false +#define RBTreeBlack true + +struct TRbTreeNodeBase { + using TColorType = TRbTreeColorType; + using TBasePtr = TRbTreeNodeBase*; + + TColorType Color_; + TBasePtr Parent_; + TBasePtr Left_; + TBasePtr Right_; + size_t Children_; + + inline TRbTreeNodeBase() noexcept { + ReInitNode(); + } + + inline void ReInitNode() noexcept { + Color_ = RBTreeBlack; + Parent_ = nullptr; + Left_ = nullptr; + Right_ = nullptr; + Children_ = 1; + } + + static TBasePtr MinimumNode(TBasePtr x) { + while (x->Left_ != nullptr) + x = x->Left_; + + return x; + } + + static TBasePtr MaximumNode(TBasePtr x) { + while (x->Right_ != nullptr) + x = x->Right_; + + return x; + } + + static TBasePtr ByIndex(TBasePtr x, size_t index) { + if (x->Left_ != nullptr) { + if (index < x->Left_->Children_) + return ByIndex(x->Left_, index); + index -= x->Left_->Children_; + } + if (0 == index) + return x; + if (!x->Right_) + ythrow yexception() << "index not found"; + return ByIndex(x->Right_, index - 1); + } +}; + +struct TRbTreeBaseIterator; + +template <class TDummy> +class TRbGlobal { +public: + using TBasePtr = TRbTreeNodeBase*; + + static void Rebalance(TBasePtr x, TBasePtr& root); + static TBasePtr RebalanceForErase(TBasePtr z, TBasePtr& root, TBasePtr& leftmost, TBasePtr& rightmost); + static void DecrementChildrenUntilRoot(TBasePtr x, TBasePtr root); + static void RecalcChildren(TBasePtr x); + + static TBasePtr IncrementNode(TBasePtr); + static TBasePtr DecrementNode(TBasePtr); + static void RotateLeft(TBasePtr x, TBasePtr& root); + static void RotateRight(TBasePtr x, TBasePtr& root); +}; + +using TRbGlobalInst = TRbGlobal<bool>; + +struct TRbTreeBaseIterator { + using TBasePtr = TRbTreeNodeBase*; + TBasePtr Node_; + + inline TRbTreeBaseIterator(TBasePtr x = nullptr) noexcept + : Node_(x) + { + } +}; + +template <class TValue, class TTraits> +struct TRbTreeIterator: public TRbTreeBaseIterator { + using TReference = typename TTraits::TReference; + using TPointer = typename TTraits::TPointer; + using TSelf = TRbTreeIterator<TValue, TTraits>; + using TBasePtr = TRbTreeNodeBase*; + + inline TRbTreeIterator() noexcept = default; + + template <class T1> + inline TRbTreeIterator(const T1& x) noexcept + : TRbTreeBaseIterator(x) + { + } + + inline TReference operator*() const noexcept { + return *static_cast<TValue*>(Node_); + } + + inline TPointer operator->() const noexcept { + return 
static_cast<TValue*>(Node_); + } + + inline TSelf& operator++() noexcept { + Node_ = TRbGlobalInst::IncrementNode(Node_); + return *this; + } + + inline TSelf operator++(int) noexcept { + TSelf tmp = *this; + ++(*this); + return tmp; + } + + inline TSelf& operator--() noexcept { + Node_ = TRbGlobalInst::DecrementNode(Node_); + return *this; + } + + inline TSelf operator--(int) noexcept { + TSelf tmp = *this; + --(*this); + return tmp; + } + + template <class T1> + inline bool operator==(const T1& rhs) const noexcept { + return Node_ == rhs.Node_; + } + + template <class T1> + inline bool operator!=(const T1& rhs) const noexcept { + return Node_ != rhs.Node_; + } +}; + +template <class TValue, class TCmp> +class TRbTree { + struct TCmpAdaptor: public TCmp { + inline TCmpAdaptor() noexcept = default; + + inline TCmpAdaptor(const TCmp& cmp) noexcept + : TCmp(cmp) + { + } + + template <class T1, class T2> + inline bool operator()(const T1& l, const T2& r) const { + return TCmp::Compare(l, r); + } + }; + + struct TNonConstTraits { + using TReference = TValue&; + using TPointer = TValue*; + }; + + struct TConstTraits { + using TReference = const TValue&; + using TPointer = const TValue*; + }; + + using TNodeBase = TRbTreeNodeBase; + using TBasePtr = TRbTreeNodeBase*; + using TColorType = TRbTreeColorType; + +public: + class TRealNode: public TNodeBase { + public: + inline TRealNode() + : Tree_(nullptr) + { + } + + inline ~TRealNode() { + UnLink(); + } + + inline void UnLink() noexcept { + if (Tree_) { + Tree_->EraseImpl(this); + ReInitNode(); + Tree_ = nullptr; + } + } + + inline void SetRbTreeParent(TRbTree* parent) noexcept { + Tree_ = parent; + } + + inline TRbTree* ParentTree() const noexcept { + return Tree_; + } + + private: + TRbTree* Tree_; + }; + + using TIterator = TRbTreeIterator<TValue, TNonConstTraits>; + using TConstIterator = TRbTreeIterator<TValue, TConstTraits>; + + inline TRbTree() noexcept { + Init(); + } + + inline TRbTree(const TCmp& cmp) noexcept + : KeyCompare_(cmp) + { + Init(); + } + + inline void Init() noexcept { + Data_.Color_ = RBTreeRed; + Data_.Parent_ = nullptr; + Data_.Left_ = &Data_; + Data_.Right_ = &Data_; + Data_.Children_ = 0; + } + + struct TDestroy { + inline void operator()(TValue& v) const noexcept { + v.SetRbTreeParent(nullptr); + v.ReInitNode(); + } + }; + + inline ~TRbTree() { + ForEachNoOrder(TDestroy()); + } + + inline void Clear() noexcept { + ForEachNoOrder(TDestroy()); + Init(); + } + + template <class F> + inline void ForEachNoOrder(const F& f) { + ForEachNoOrder(Root(), f); + } + + template <class F> + inline void ForEachNoOrder(TNodeBase* n, const F& f) { + if (n && n != &Data_) { + ForEachNoOrder(n->Left_, f); + ForEachNoOrder(n->Right_, f); + f(ValueNode(n)); + } + } + + inline TIterator Begin() noexcept { + return LeftMost(); + } + + inline TConstIterator Begin() const noexcept { + return LeftMost(); + } + + inline TIterator End() noexcept { + return &this->Data_; + } + + inline TConstIterator End() const noexcept { + return const_cast<TBasePtr>(&this->Data_); + } + + inline bool Empty() const noexcept { + return this->Begin() == this->End(); + } + + inline explicit operator bool() const noexcept { + return !this->Empty(); + } + + inline TIterator Insert(TValue* val) { + return Insert(*val); + } + + inline TIterator Insert(TValue& val) { + val.UnLink(); + + TBasePtr y = &this->Data_; + TBasePtr x = Root(); + + while (x != nullptr) { + ++(x->Children_); + y = x; + + if (KeyCompare_(ValueNode(&val), ValueNode(x))) { + x = LeftNode(x); + } 
else { + x = RightNode(x); + } + } + + return InsertImpl(y, &val, x); + } + + template <class F> + inline void ForEach(F& f) { + TIterator it = Begin(); + + while (it != End()) { + f(*it++); + } + } + + inline void Erase(TValue& val) noexcept { + val.UnLink(); + } + + inline void Erase(TValue* val) noexcept { + Erase(*val); + } + + inline void Erase(TIterator pos) noexcept { + Erase(*pos); + } + + inline void EraseImpl(TNodeBase* val) noexcept { + TRbGlobalInst::RebalanceForErase(val, this->Data_.Parent_, this->Data_.Left_, this->Data_.Right_); + } + + template <class T1> + inline TValue* Find(const T1& k) const { + TBasePtr y = nullptr; + TBasePtr x = Root(); // Current node. + + while (x != nullptr) + if (!KeyCompare_(ValueNode(x), k)) + y = x, x = LeftNode(x); + else + x = RightNode(x); + + if (y) { + if (KeyCompare_(k, ValueNode(y))) { + y = nullptr; + } + } + + return static_cast<TValue*>(y); + } + + size_t GetIndex(TBasePtr x) const { + size_t index = 0; + + if (x->Left_ != nullptr) { + index += x->Left_->Children_; + } + + while (x != nullptr && x->Parent_ != nullptr && x->Parent_ != const_cast<TBasePtr>(&this->Data_)) { + if (x->Parent_->Right_ == x && x->Parent_->Left_ != nullptr) { + index += x->Parent_->Left_->Children_; + } + if (x->Parent_->Right_ == x) { + index += 1; + } + x = x->Parent_; + } + + return index; + } + + template <class T1> + inline TBasePtr LowerBound(const T1& k) const { + TBasePtr y = const_cast<TBasePtr>(&this->Data_); /* Last node which is not less than k. */ + TBasePtr x = Root(); /* Current node. */ + + while (x != nullptr) + if (!KeyCompare_(ValueNode(x), k)) + y = x, x = LeftNode(x); + else + x = RightNode(x); + + return y; + } + + template <class T1> + inline TBasePtr UpperBound(const T1& k) const { + TBasePtr y = const_cast<TBasePtr>(&this->Data_); /* Last node which is greater than k. */ + TBasePtr x = Root(); /* Current node. 
*/ + + while (x != nullptr) + if (KeyCompare_(k, ValueNode(x))) + y = x, x = LeftNode(x); + else + x = RightNode(x); + + return y; + } + + template <class T1> + inline size_t LessCount(const T1& k) const { + auto x = LowerBound(k); + if (x == const_cast<TBasePtr>(&this->Data_)) { + if (const auto root = Root()) { + return root->Children_; + } else { + return 0; + } + } else { + return GetIndex(x); + } + } + + template <class T1> + inline size_t NotLessCount(const T1& k) const { + return Root()->Children_ - LessCount<T1>(k); + } + + template <class T1> + inline size_t GreaterCount(const T1& k) const { + auto x = UpperBound(k); + if (x == const_cast<TBasePtr>(&this->Data_)) { + return 0; + } else { + return Root()->Children_ - GetIndex(x); + } + } + + template <class T1> + inline size_t NotGreaterCount(const T1& k) const { + return Root()->Children_ - GreaterCount<T1>(k); + } + + TValue* ByIndex(size_t index) { + return static_cast<TValue*>(TRbTreeNodeBase::ByIndex(Root(), index)); + } + +private: + // CRP 7/10/00 inserted argument on_right, which is another hint (meant to + // act like on_left and ignore a portion of the if conditions -- specify + // on_right != nullptr to bypass comparison as false or on_left != nullptr to bypass + // comparison as true) + TIterator InsertImpl(TRbTreeNodeBase* parent, TRbTreeNodeBase* val, TRbTreeNodeBase* on_left = nullptr, TRbTreeNodeBase* on_right = nullptr) { + ValueNode(val).SetRbTreeParent(this); + TBasePtr new_node = val; + + if (parent == &this->Data_) { + LeftNode(parent) = new_node; + // also makes LeftMost() = new_node + Root() = new_node; + RightMost() = new_node; + } else if (on_right == nullptr && + // If on_right != nullptr, the remainder fails to false + (on_left != nullptr || + // If on_left != nullptr, the remainder succeeds to true + KeyCompare_(ValueNode(val), ValueNode(parent)))) + { + LeftNode(parent) = new_node; + if (parent == LeftMost()) + // maintain LeftMost() pointing to min node + LeftMost() = new_node; + } else { + RightNode(parent) = new_node; + if (parent == RightMost()) + // maintain RightMost() pointing to max node + RightMost() = new_node; + } + ParentNode(new_node) = parent; + TRbGlobalInst::Rebalance(new_node, this->Data_.Parent_); + return new_node; + } + + TBasePtr Root() const { + return this->Data_.Parent_; + } + + TBasePtr LeftMost() const { + return this->Data_.Left_; + } + + TBasePtr RightMost() const { + return this->Data_.Right_; + } + + TBasePtr& Root() { + return this->Data_.Parent_; + } + + TBasePtr& LeftMost() { + return this->Data_.Left_; + } + + TBasePtr& RightMost() { + return this->Data_.Right_; + } + + static TBasePtr& LeftNode(TBasePtr x) { + return x->Left_; + } + + static TBasePtr& RightNode(TBasePtr x) { + return x->Right_; + } + + static TBasePtr& ParentNode(TBasePtr x) { + return x->Parent_; + } + + static TValue& ValueNode(TBasePtr x) { + return *static_cast<TValue*>(x); + } + + static TBasePtr MinimumNode(TBasePtr x) { + return TRbTreeNodeBase::MinimumNode(x); + } + + static TBasePtr MaximumNode(TBasePtr x) { + return TRbTreeNodeBase::MaximumNode(x); + } + +private: + TCmpAdaptor KeyCompare_; + TNodeBase Data_; +}; + +template <class TValue, class TCmp> +class TRbTreeItem: public TRbTree<TValue, TCmp>::TRealNode { +}; + +template <class TDummy> +void TRbGlobal<TDummy>::RotateLeft(TRbTreeNodeBase* x, TRbTreeNodeBase*& root) { + TRbTreeNodeBase* y = x->Right_; + x->Right_ = y->Left_; + if (y->Left_ != nullptr) + y->Left_->Parent_ = x; + y->Parent_ = x->Parent_; + + if (x == root) + root = y; + 
else if (x == x->Parent_->Left_) + x->Parent_->Left_ = y; + else + x->Parent_->Right_ = y; + y->Left_ = x; + x->Parent_ = y; + y->Children_ = x->Children_; + x->Children_ = ((x->Left_) ? x->Left_->Children_ : 0) + ((x->Right_) ? x->Right_->Children_ : 0) + 1; +} + +template <class TDummy> +void TRbGlobal<TDummy>::RotateRight(TRbTreeNodeBase* x, TRbTreeNodeBase*& root) { + TRbTreeNodeBase* y = x->Left_; + x->Left_ = y->Right_; + if (y->Right_ != nullptr) + y->Right_->Parent_ = x; + y->Parent_ = x->Parent_; + + if (x == root) + root = y; + else if (x == x->Parent_->Right_) + x->Parent_->Right_ = y; + else + x->Parent_->Left_ = y; + y->Right_ = x; + x->Parent_ = y; + y->Children_ = x->Children_; + x->Children_ = ((x->Left_) ? x->Left_->Children_ : 0) + ((x->Right_) ? x->Right_->Children_ : 0) + 1; +} + +template <class TDummy> +void TRbGlobal<TDummy>::Rebalance(TRbTreeNodeBase* x, TRbTreeNodeBase*& root) { + x->Color_ = RBTreeRed; + while (x != root && x->Parent_->Color_ == RBTreeRed) { + if (x->Parent_ == x->Parent_->Parent_->Left_) { + TRbTreeNodeBase* y = x->Parent_->Parent_->Right_; + if (y && y->Color_ == RBTreeRed) { + x->Parent_->Color_ = RBTreeBlack; + y->Color_ = RBTreeBlack; + x->Parent_->Parent_->Color_ = RBTreeRed; + x = x->Parent_->Parent_; + } else { + if (x == x->Parent_->Right_) { + x = x->Parent_; + RotateLeft(x, root); + } + x->Parent_->Color_ = RBTreeBlack; + x->Parent_->Parent_->Color_ = RBTreeRed; + RotateRight(x->Parent_->Parent_, root); + } + } else { + TRbTreeNodeBase* y = x->Parent_->Parent_->Left_; + if (y && y->Color_ == RBTreeRed) { + x->Parent_->Color_ = RBTreeBlack; + y->Color_ = RBTreeBlack; + x->Parent_->Parent_->Color_ = RBTreeRed; + x = x->Parent_->Parent_; + } else { + if (x == x->Parent_->Left_) { + x = x->Parent_; + RotateRight(x, root); + } + x->Parent_->Color_ = RBTreeBlack; + x->Parent_->Parent_->Color_ = RBTreeRed; + RotateLeft(x->Parent_->Parent_, root); + } + } + } + root->Color_ = RBTreeBlack; +} + +template <class TDummy> +void TRbGlobal<TDummy>::RecalcChildren(TRbTreeNodeBase* x) { + x->Children_ = ((x->Left_) ? x->Left_->Children_ : 0) + ((x->Right_) ? x->Right_->Children_ : 0) + 1; +} + +template <class TDummy> +void TRbGlobal<TDummy>::DecrementChildrenUntilRoot(TRbTreeNodeBase* x, TRbTreeNodeBase* root) { + auto* ptr = x; + --ptr->Children_; + while (ptr != root) { + ptr = ptr->Parent_; + --ptr->Children_; + } +} + +template <class TDummy> +TRbTreeNodeBase* TRbGlobal<TDummy>::RebalanceForErase(TRbTreeNodeBase* z, + TRbTreeNodeBase*& root, + TRbTreeNodeBase*& leftmost, + TRbTreeNodeBase*& rightmost) { + TRbTreeNodeBase* y = z; + TRbTreeNodeBase* x; + TRbTreeNodeBase* x_parent; + + if (y->Left_ == nullptr) // z has at most one non-null child. y == z. + x = y->Right_; // x might be null. + else { + if (y->Right_ == nullptr) // z has exactly one non-null child. y == z. + x = y->Left_; // x is not null. + else { // z has two non-null children. Set y to + y = TRbTreeNodeBase::MinimumNode(y->Right_); // z's successor. x might be null. + x = y->Right_; + } + } + + if (y != z) { + // relink y in place of z. 
y is z's successor + z->Left_->Parent_ = y; + y->Left_ = z->Left_; + if (y != z->Right_) { + x_parent = y->Parent_; + if (x) + x->Parent_ = y->Parent_; + y->Parent_->Left_ = x; // y must be a child of mLeft + y->Right_ = z->Right_; + z->Right_->Parent_ = y; + } else + x_parent = y; + if (root == z) + root = y; + else if (z->Parent_->Left_ == z) + z->Parent_->Left_ = y; + else + z->Parent_->Right_ = y; + y->Parent_ = z->Parent_; + DoSwap(y->Color_, z->Color_); + + RecalcChildren(y); + if (x_parent != y) { + --x_parent->Children_; + } + if (x_parent != root) { + DecrementChildrenUntilRoot(x_parent->Parent_, root); + } + y = z; + // y now points to node to be actually deleted + } else { + // y == z + x_parent = y->Parent_; + if (x) + x->Parent_ = y->Parent_; + if (root == z) + root = x; + else { + if (z->Parent_->Left_ == z) + z->Parent_->Left_ = x; + else + z->Parent_->Right_ = x; + DecrementChildrenUntilRoot(z->Parent_, root); // we lost y + } + + if (leftmost == z) { + if (z->Right_ == nullptr) // z->mLeft must be null also + leftmost = z->Parent_; + // makes leftmost == _M_header if z == root + else + leftmost = TRbTreeNodeBase::MinimumNode(x); + } + if (rightmost == z) { + if (z->Left_ == nullptr) // z->mRight must be null also + rightmost = z->Parent_; + // makes rightmost == _M_header if z == root + else // x == z->mLeft + rightmost = TRbTreeNodeBase::MaximumNode(x); + } + } + + if (y->Color_ != RBTreeRed) { + while (x != root && (x == nullptr || x->Color_ == RBTreeBlack)) + if (x == x_parent->Left_) { + TRbTreeNodeBase* w = x_parent->Right_; + if (w->Color_ == RBTreeRed) { + w->Color_ = RBTreeBlack; + x_parent->Color_ = RBTreeRed; + RotateLeft(x_parent, root); + w = x_parent->Right_; + } + if ((w->Left_ == nullptr || + w->Left_->Color_ == RBTreeBlack) && + (w->Right_ == nullptr || + w->Right_->Color_ == RBTreeBlack)) + { + w->Color_ = RBTreeRed; + x = x_parent; + x_parent = x_parent->Parent_; + } else { + if (w->Right_ == nullptr || w->Right_->Color_ == RBTreeBlack) { + if (w->Left_) + w->Left_->Color_ = RBTreeBlack; + w->Color_ = RBTreeRed; + RotateRight(w, root); + w = x_parent->Right_; + } + w->Color_ = x_parent->Color_; + x_parent->Color_ = RBTreeBlack; + if (w->Right_) + w->Right_->Color_ = RBTreeBlack; + RotateLeft(x_parent, root); + break; + } + } else { + // same as above, with mRight <-> mLeft. 
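// Here x is a right child, so its sibling w lives on the left and the
// rotations below mirror the branch above.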
+ TRbTreeNodeBase* w = x_parent->Left_; + if (w->Color_ == RBTreeRed) { + w->Color_ = RBTreeBlack; + x_parent->Color_ = RBTreeRed; + RotateRight(x_parent, root); + w = x_parent->Left_; + } + if ((w->Right_ == nullptr || + w->Right_->Color_ == RBTreeBlack) && + (w->Left_ == nullptr || + w->Left_->Color_ == RBTreeBlack)) + { + w->Color_ = RBTreeRed; + x = x_parent; + x_parent = x_parent->Parent_; + } else { + if (w->Left_ == nullptr || w->Left_->Color_ == RBTreeBlack) { + if (w->Right_) + w->Right_->Color_ = RBTreeBlack; + w->Color_ = RBTreeRed; + RotateLeft(w, root); + w = x_parent->Left_; + } + w->Color_ = x_parent->Color_; + x_parent->Color_ = RBTreeBlack; + if (w->Left_) + w->Left_->Color_ = RBTreeBlack; + RotateRight(x_parent, root); + break; + } + } + if (x) + x->Color_ = RBTreeBlack; + } + return y; +} + +template <class TDummy> +TRbTreeNodeBase* TRbGlobal<TDummy>::DecrementNode(TRbTreeNodeBase* Node_) { + if (Node_->Color_ == RBTreeRed && Node_->Parent_->Parent_ == Node_) + Node_ = Node_->Right_; + else if (Node_->Left_ != nullptr) { + Node_ = TRbTreeNodeBase::MaximumNode(Node_->Left_); + } else { + TBasePtr y = Node_->Parent_; + while (Node_ == y->Left_) { + Node_ = y; + y = y->Parent_; + } + Node_ = y; + } + return Node_; +} + +template <class TDummy> +TRbTreeNodeBase* TRbGlobal<TDummy>::IncrementNode(TRbTreeNodeBase* Node_) { + if (Node_->Right_ != nullptr) { + Node_ = TRbTreeNodeBase::MinimumNode(Node_->Right_); + } else { + TBasePtr y = Node_->Parent_; + while (Node_ == y->Right_) { + Node_ = y; + y = y->Parent_; + } + // check special case: This is necessary if mNode is the + // _M_head and the tree contains only a single node y. In + // that case parent, left and right all point to y! + if (Node_->Right_ != y) + Node_ = y; + } + return Node_; +} + +#undef RBTreeRed +#undef RBTreeBlack diff --git a/library/cpp/containers/intrusive_rb_tree/rb_tree_ut.cpp b/library/cpp/containers/intrusive_rb_tree/rb_tree_ut.cpp new file mode 100644 index 00000000000..c34ed1fd9b4 --- /dev/null +++ b/library/cpp/containers/intrusive_rb_tree/rb_tree_ut.cpp @@ -0,0 +1,298 @@ +#include "rb_tree.h" + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/random/fast.h> +#include <util/random/easy.h> +#include <util/random/shuffle.h> + +class TRedBlackTreeTest: public TTestBase { + struct TCmp { + template <class T> + static inline bool Compare(const T& l, const T& r) { + return l.N < r.N; + } + + template <class T> + static inline bool Compare(const T& l, int r) { + return l.N < r; + } + + template <class T> + static inline bool Compare(int l, const T& r) { + return l < r.N; + } + }; + + class TNode: public TRbTreeItem<TNode, TCmp> { + public: + inline TNode(int n) noexcept + : N(n) + { + } + + int N; + }; + + using TTree = TRbTree<TNode, TCmp>; + + UNIT_TEST_SUITE(TRedBlackTreeTest); + UNIT_TEST(TestEmpty) + UNIT_TEST(TestInsert) + UNIT_TEST(TestErase) + UNIT_TEST(TestFind) + UNIT_TEST(TestStress) + UNIT_TEST(TestGettingIndexWithDifferentValues) + UNIT_TEST(TestCheckChildrenAfterErase) + UNIT_TEST(TestGettingIndexWithDifferentValuesAfterErase) + UNIT_TEST(TestGettingIndexWithEqualValues) + UNIT_TEST(TestLessCountOnEmptyTree) + UNIT_TEST_SUITE_END(); + +private: + inline void TestStress() { + TVector<TSimpleSharedPtr<TNode>> nodes; + + for (int i = 0; i < 1000; ++i) { + nodes.push_back(new TNode(i)); + } + + TTree tree; + TReallyFastRng32 rnd(Random()); + + for (size_t i = 0; i < 1000000; ++i) { + tree.Insert(nodes[rnd.Uniform(nodes.size())].Get()); + } + + for (TTree::TConstIterator it 
= tree.Begin(); it != tree.End();) { + const int v1 = it->N; + + if (++it == tree.End()) { + break; + } + + const int v2 = it->N; + + UNIT_ASSERT(v1 < v2); + } + } + + inline void TestGettingIndexWithDifferentValues() { + TVector<TSimpleSharedPtr<TNode>> nodes; + size_t N = 1000; + + for (size_t i = 0; i < N; ++i) { + nodes.push_back(new TNode(int(i))); + } + + TTree tree; + Shuffle(nodes.begin(), nodes.end()); + + for (size_t i = 0; i < N; ++i) { + tree.Insert(nodes[i].Get()); + } + + for (size_t i = 0; i < N; ++i) { + UNIT_ASSERT_EQUAL(tree.LessCount(i), i); + UNIT_ASSERT_EQUAL(tree.NotGreaterCount(i), i + 1); + UNIT_ASSERT_EQUAL(tree.GreaterCount(i), N - i - 1); + UNIT_ASSERT_EQUAL(tree.NotLessCount(i), N - i); + + auto nodePtr = tree.Find(i); + UNIT_ASSERT_EQUAL(tree.GetIndex(nodePtr), i); + UNIT_ASSERT_EQUAL(tree.GetIndex(nodes[i].Get()), static_cast<size_t>(nodes[i]->N)); + } + } + + inline void TestCheckChildrenAfterErase() { + TVector<TSimpleSharedPtr<TNode>> nodes; + size_t N = 1000; + + for (size_t i = 0; i < N; ++i) { + nodes.push_back(new TNode(int(i))); + } + + TTree tree; + Shuffle(nodes.begin(), nodes.end()); + + for (size_t i = 0; i < N; ++i) { + tree.Insert(nodes[i].Get()); + } + auto checker = [](const TTree& tree) { + for (auto node = tree.Begin(); node != tree.End(); ++node) { + size_t childrens = 1; + if (node->Left_) { + childrens += node->Left_->Children_; + } + if (node->Right_) { + childrens += node->Right_->Children_; + } + UNIT_ASSERT_VALUES_EQUAL(childrens, node->Children_); + } + }; + + for (auto node : nodes) { + tree.Erase(node.Get()); + checker(tree); + } + } + + inline void TestGettingIndexWithDifferentValuesAfterErase() { + TVector<TSimpleSharedPtr<TNode>> nodes; + size_t N = 1000; + + for (size_t i = 0; i < N; ++i) { + nodes.push_back(new TNode(int(i))); + } + + TTree tree; + Shuffle(nodes.begin(), nodes.end()); + + for (size_t i = 0; i < N; ++i) { + tree.Insert(nodes[i].Get()); + } + { + size_t index = 0; + for (auto node = tree.Begin(); node != tree.End(); ++node, ++index) { + UNIT_ASSERT_VALUES_EQUAL(tree.GetIndex(&*node), index); + UNIT_ASSERT_VALUES_EQUAL(tree.ByIndex(index)->N, node->N); + UNIT_ASSERT_VALUES_EQUAL(node->N, index); + } + } + + for (size_t i = 1; i < N; i += 2) { + auto* node = tree.Find(i); + UNIT_ASSERT_VALUES_EQUAL(node->N, i); + tree.Erase(node); + } + { + size_t index = 0; + for (auto node = tree.Begin(); node != tree.End(); ++node, ++index) { + UNIT_ASSERT_VALUES_EQUAL(tree.GetIndex(&*node), index); + UNIT_ASSERT_VALUES_EQUAL(tree.ByIndex(index)->N, node->N); + UNIT_ASSERT_VALUES_EQUAL(node->N, 2 * index); + } + } + } + + inline void TestGettingIndexWithEqualValues() { + TVector<TSimpleSharedPtr<TNode>> nodes; + size_t N = 1000; + + for (size_t i = 0; i < N; ++i) { + nodes.push_back(new TNode(0)); + } + + TTree tree; + + for (size_t i = 0; i < N; ++i) { + tree.Insert(nodes[i].Get()); + } + + for (size_t i = 0; i < N; ++i) { + UNIT_ASSERT_EQUAL(tree.LessCount(nodes[i]->N), 0); + UNIT_ASSERT_EQUAL(tree.NotGreaterCount(nodes[i]->N), N); + UNIT_ASSERT_EQUAL(tree.GreaterCount(nodes[i]->N), 0); + UNIT_ASSERT_EQUAL(tree.NotLessCount(nodes[i]->N), N); + + UNIT_ASSERT_EQUAL(tree.LessCount(*nodes[i].Get()), 0); + UNIT_ASSERT_EQUAL(tree.NotGreaterCount(*nodes[i].Get()), N); + UNIT_ASSERT_EQUAL(tree.GreaterCount(*nodes[i].Get()), 0); + UNIT_ASSERT_EQUAL(tree.NotLessCount(*nodes[i].Get()), N); + } + } + + inline void TestFind() { + TTree tree; + + { + TNode n1(1); + TNode n2(2); + TNode n3(3); + + tree.Insert(n1); + tree.Insert(n2); + 
tree.Insert(n3); + + UNIT_ASSERT_EQUAL(tree.Find(1)->N, 1); + UNIT_ASSERT_EQUAL(tree.Find(2)->N, 2); + UNIT_ASSERT_EQUAL(tree.Find(3)->N, 3); + + UNIT_ASSERT(!tree.Find(0)); + UNIT_ASSERT(!tree.Find(4)); + UNIT_ASSERT(!tree.Find(1234567)); + } + + UNIT_ASSERT(tree.Empty()); + } + + inline void TestEmpty() { + TTree tree; + + UNIT_ASSERT(tree.Empty()); + UNIT_ASSERT_EQUAL(tree.Begin(), tree.End()); + } + + inline void TestInsert() { + TTree tree; + + { + TNode n1(1); + TNode n2(2); + TNode n3(3); + + tree.Insert(n1); + tree.Insert(n2); + tree.Insert(n3); + + TTree::TConstIterator it = tree.Begin(); + + UNIT_ASSERT_EQUAL((it++)->N, 1); + UNIT_ASSERT_EQUAL((it++)->N, 2); + UNIT_ASSERT_EQUAL((it++)->N, 3); + UNIT_ASSERT_EQUAL(it, tree.End()); + } + + UNIT_ASSERT(tree.Empty()); + } + + inline void TestErase() { + TTree tree; + + { + TNode n1(1); + TNode n2(2); + TNode n3(3); + + tree.Insert(n1); + tree.Insert(n2); + tree.Insert(n3); + + TTree::TIterator it = tree.Begin(); + + tree.Erase(it++); + + UNIT_ASSERT_EQUAL(it, tree.Begin()); + UNIT_ASSERT_EQUAL(it->N, 2); + + tree.Erase(it++); + + UNIT_ASSERT_EQUAL(it, tree.Begin()); + UNIT_ASSERT_EQUAL(it->N, 3); + + tree.Erase(it++); + + UNIT_ASSERT_EQUAL(it, tree.Begin()); + UNIT_ASSERT_EQUAL(it, tree.End()); + } + + UNIT_ASSERT(tree.Empty()); + } + + inline void TestLessCountOnEmptyTree() { + TTree tree; + UNIT_ASSERT_VALUES_EQUAL(0, tree.LessCount(TNode(1))); + } +}; + +UNIT_TEST_SUITE_REGISTRATION(TRedBlackTreeTest); diff --git a/library/cpp/containers/intrusive_rb_tree/ut/ya.make b/library/cpp/containers/intrusive_rb_tree/ut/ya.make new file mode 100644 index 00000000000..6f1e3b38ee0 --- /dev/null +++ b/library/cpp/containers/intrusive_rb_tree/ut/ya.make @@ -0,0 +1,12 @@ +UNITTEST_FOR(library/cpp/containers/intrusive_rb_tree) + +OWNER( + pg + g:util +) + +SRCS( + rb_tree_ut.cpp +) + +END() diff --git a/library/cpp/containers/intrusive_rb_tree/ya.make b/library/cpp/containers/intrusive_rb_tree/ya.make new file mode 100644 index 00000000000..2e5eddcfbe8 --- /dev/null +++ b/library/cpp/containers/intrusive_rb_tree/ya.make @@ -0,0 +1,12 @@ +LIBRARY() + +OWNER( + pg + g:util +) + +SRCS( + rb_tree.cpp +) + +END() diff --git a/library/cpp/containers/paged_vector/paged_vector.cpp b/library/cpp/containers/paged_vector/paged_vector.cpp new file mode 100644 index 00000000000..e354caf09d4 --- /dev/null +++ b/library/cpp/containers/paged_vector/paged_vector.cpp @@ -0,0 +1 @@ +#include "paged_vector.h" diff --git a/library/cpp/containers/paged_vector/paged_vector.h b/library/cpp/containers/paged_vector/paged_vector.h new file mode 100644 index 00000000000..6a3657d3ea5 --- /dev/null +++ b/library/cpp/containers/paged_vector/paged_vector.h @@ -0,0 +1,432 @@ +#pragma once + +#include <util/generic/ptr.h> +#include <util/generic/vector.h> +#include <util/generic/yexception.h> + +#include <iterator> + +namespace NPagedVector { + template <class T, ui32 PageSize = 1u << 20, class A = std::allocator<T>> + class TPagedVector; + + namespace NPrivate { + template <class T, class TT, ui32 PageSize, class A> + struct TPagedVectorIterator { + private: + friend class TPagedVector<TT, PageSize, A>; + typedef TPagedVector<TT, PageSize, A> TVec; + typedef TPagedVectorIterator<T, TT, PageSize, A> TSelf; + size_t Offset; + TVec* Vector; + + template <class T1, class TT1, ui32 PageSize1, class A1> + friend struct TPagedVectorIterator; + + public: + TPagedVectorIterator() + : Offset() + , Vector() + { + } + + TPagedVectorIterator(TVec* vector, size_t offset) + : Offset(offset) + 
, Vector(vector) + { + } + + template <class T1, class TT1, ui32 PageSize1, class A1> + TPagedVectorIterator(const TPagedVectorIterator<T1, TT1, PageSize1, A1>& it) + : Offset(it.Offset) + , Vector(it.Vector) + { + } + + T& operator*() const { + return (*Vector)[Offset]; + } + + T* operator->() const { + return &(**this); + } + + template <class T1, class TT1, ui32 PageSize1, class A1> + bool operator==(const TPagedVectorIterator<T1, TT1, PageSize1, A1>& it) const { + return Offset == it.Offset; + } + + template <class T1, class TT1, ui32 PageSize1, class A1> + bool operator!=(const TPagedVectorIterator<T1, TT1, PageSize1, A1>& it) const { + return !(*this == it); + } + + template <class T1, class TT1, ui32 PageSize1, class A1> + bool operator<(const TPagedVectorIterator<T1, TT1, PageSize1, A1>& it) const { + return Offset < it.Offset; + } + + template <class T1, class TT1, ui32 PageSize1, class A1> + bool operator<=(const TPagedVectorIterator<T1, TT1, PageSize1, A1>& it) const { + return Offset <= it.Offset; + } + + template <class T1, class TT1, ui32 PageSize1, class A1> + bool operator>(const TPagedVectorIterator<T1, TT1, PageSize1, A1>& it) const { + return !(*this <= it); + } + + template <class T1, class TT1, ui32 PageSize1, class A1> + bool operator>=(const TPagedVectorIterator<T1, TT1, PageSize1, A1>& it) const { + return !(*this < it); + } + + template <class T1, class TT1, ui32 PageSize1, class A1> + ptrdiff_t operator-(const TPagedVectorIterator<T1, TT1, PageSize1, A1>& it) const { + return Offset - it.Offset; + } + + TSelf& operator+=(ptrdiff_t off) { + Offset += off; + return *this; + } + + TSelf& operator-=(ptrdiff_t off) { + return this->operator+=(-off); + } + + TSelf& operator++() { + return this->operator+=(1); + } + + TSelf& operator--() { + return this->operator+=(-1); + } + + TSelf operator++(int) { + TSelf it = *this; + this->operator+=(1); + return it; + } + + TSelf operator--(int) { + TSelf it = *this; + this->operator+=(-1); + return it; + } + + TSelf operator+(ptrdiff_t off) { + TSelf res = *this; + res += off; + return res; + } + + TSelf operator-(ptrdiff_t off) { + return this->operator+(-off); + } + + size_t GetOffset() { + return Offset; + } + }; + } +} + +namespace std { + template <class T, class TT, ui32 PageSize, class A> + struct iterator_traits<NPagedVector::NPrivate::TPagedVectorIterator<T, TT, PageSize, A>> { + typedef ptrdiff_t difference_type; + typedef T value_type; + typedef T* pointer; + typedef T& reference; + typedef random_access_iterator_tag iterator_category; + }; + +} + +namespace NPagedVector { + //2-level radix tree + template <class T, ui32 PageSize, class A> + class TPagedVector: private TVector<TSimpleSharedPtr<TVector<T, A>>, A> { + static_assert(PageSize, "expect PageSize"); + + typedef TVector<T, A> TPage; + typedef TVector<TSimpleSharedPtr<TPage>, A> TPages; + typedef TPagedVector<T, PageSize, A> TSelf; + + public: + typedef NPrivate::TPagedVectorIterator<T, T, PageSize, A> iterator; + typedef NPrivate::TPagedVectorIterator<const T, T, PageSize, A> const_iterator; + typedef std::reverse_iterator<iterator> reverse_iterator; + typedef std::reverse_iterator<const_iterator> const_reverse_iterator; + typedef T value_type; + typedef value_type& reference; + typedef const value_type& const_reference; + + TPagedVector() = default; + + template <typename TIter> + TPagedVector(TIter b, TIter e) { + append(b, e); + } + + iterator begin() { + return iterator(this, 0); + } + + const_iterator begin() const { + return const_iterator((TSelf*)this, 
0); + } + + iterator end() { + return iterator(this, size()); + } + + const_iterator end() const { + return const_iterator((TSelf*)this, size()); + } + + reverse_iterator rbegin() { + return reverse_iterator(end()); + } + + const_reverse_iterator rbegin() const { + return const_reverse_iterator(end()); + } + + reverse_iterator rend() { + return reverse_iterator(begin()); + } + + const_reverse_iterator rend() const { + return const_reverse_iterator(begin()); + } + + void swap(TSelf& v) { + TPages::swap((TPages&)v); + } + + private: + static size_t PageNumber(size_t idx) { + return idx / PageSize; + } + + static size_t InPageIndex(size_t idx) { + return idx % PageSize; + } + + static size_t Index(size_t pnum, size_t poff) { + return pnum * PageSize + poff; + } + + TPage& PageAt(size_t pnum) const { + return *TPages::at(pnum); + } + + TPage& CurrentPage() const { + return *TPages::back(); + } + + size_t CurrentPageSize() const { + return TPages::empty() ? 0 : CurrentPage().size(); + } + + size_t NPages() const { + return TPages::size(); + } + + void AllocateNewPage() { + TPages::push_back(new TPage()); + CurrentPage().reserve(PageSize); + } + + void MakeNewPage() { + AllocateNewPage(); + CurrentPage().resize(PageSize); + } + + void PrepareAppend() { + if (TPages::empty() || CurrentPage().size() + 1 > PageSize) + AllocateNewPage(); + } + + public: + size_t size() const { + return empty() ? 0 : (NPages() - 1) * PageSize + CurrentPage().size(); + } + + bool empty() const { + return TPages::empty() || 1 == NPages() && CurrentPage().empty(); + } + + explicit operator bool() const noexcept { + return !empty(); + } + + void emplace_back() { + PrepareAppend(); + CurrentPage().emplace_back(); + } + + void push_back(const_reference t) { + PrepareAppend(); + CurrentPage().push_back(t); + } + + void pop_back() { + if (CurrentPage().empty()) + TPages::pop_back(); + CurrentPage().pop_back(); + } + + template <typename TIter> + void append(TIter b, TIter e) { + size_t sz = e - b; + size_t sz1 = Min<size_t>(sz, PageSize - CurrentPageSize()); + size_t sz2 = (sz - sz1) / PageSize; + size_t sz3 = (sz - sz1) % PageSize; + + if (sz1) { + PrepareAppend(); + TPage& p = CurrentPage(); + p.insert(p.end(), b, b + sz1); + } + + for (size_t i = 0; i < sz2; ++i) { + AllocateNewPage(); + TPage& p = CurrentPage(); + p.insert(p.end(), b + sz1 + i * PageSize, b + sz1 + (i + 1) * PageSize); + } + + if (sz3) { + AllocateNewPage(); + TPage& p = CurrentPage(); + p.insert(p.end(), b + sz1 + sz2 * PageSize, e); + } + } + + iterator erase(iterator it) { + size_t pnum = PageNumber(it.Offset); + size_t pidx = InPageIndex(it.Offset); + + if (CurrentPage().empty()) + TPages::pop_back(); + + for (size_t p = NPages() - 1; p > pnum; --p) { + PageAt(p - 1).push_back(PageAt(p).front()); + PageAt(p).erase(PageAt(p).begin()); + } + + PageAt(pnum).erase(PageAt(pnum).begin() + pidx); + return it; + } + + iterator erase(iterator b, iterator e) { + // todo : suboptimal! + while (b != e) { + b = erase(b); + --e; + } + + return b; + } + + iterator insert(iterator it, const value_type& v) { + size_t pnum = PageNumber(it.Offset); + size_t pidx = InPageIndex(it.Offset); + + PrepareAppend(); + + for (size_t p = NPages() - 1; p > pnum; --p) { + PageAt(p).insert(PageAt(p).begin(), PageAt(p - 1).back()); + PageAt(p - 1).pop_back(); + } + + PageAt(pnum).insert(PageAt(pnum).begin() + pidx, v); + return it; + } + + template <typename TIter> + void insert(iterator it, TIter b, TIter e) { + // todo : suboptimal! 
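// Each insert() below shifts one element into the front of every later page,
// so inserting a range of m elements costs roughly O(m * size()).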
+ for (; b != e; ++b, ++it) + it = insert(it, *b); + } + + reference front() { + return TPages::front()->front(); + } + + const_reference front() const { + return TPages::front()->front(); + } + + reference back() { + return CurrentPage().back(); + } + + const_reference back() const { + return CurrentPage().back(); + } + + void clear() { + TPages::clear(); + } + + void resize(size_t sz) { + if (sz == size()) + return; + + const size_t npages = NPages(); + const size_t newwholepages = sz / PageSize; + const size_t pagepart = sz % PageSize; + const size_t newpages = newwholepages + bool(pagepart); + + if (npages && newwholepages >= npages) + CurrentPage().resize(PageSize); + + if (newpages < npages) + TPages::resize(newpages); + else + for (size_t i = npages; i < newpages; ++i) + MakeNewPage(); + + if (pagepart) + CurrentPage().resize(pagepart); + + Y_VERIFY(sz == size(), "%" PRIu64 " %" PRIu64, (ui64)sz, (ui64)size()); + } + + reference at(size_t idx) { + return TPages::at(PageNumber(idx))->at(InPageIndex(idx)); + } + + const_reference at(size_t idx) const { + return TPages::at(PageNumber(idx))->at(InPageIndex(idx)); + } + + reference operator[](size_t idx) { + return TPages::operator[](PageNumber(idx))->operator[](InPageIndex(idx)); + } + + const_reference operator[](size_t idx) const { + return TPages::operator[](PageNumber(idx))->operator[](InPageIndex(idx)); + } + + friend bool operator==(const TSelf& a, const TSelf& b) { + return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin()); + } + + friend bool operator<(const TSelf& a, const TSelf& b) { + return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end()); + } + }; + + namespace NPrivate { + typedef std::is_same<std::random_access_iterator_tag, std::iterator_traits< + TPagedVector<ui32>::iterator>::iterator_category> + TIteratorCheck; + static_assert(TIteratorCheck::value, "expect TIteratorCheck::Result"); + } + +} diff --git a/library/cpp/containers/paged_vector/ut/paged_vector_ut.cpp b/library/cpp/containers/paged_vector/ut/paged_vector_ut.cpp new file mode 100644 index 00000000000..e867808ee41 --- /dev/null +++ b/library/cpp/containers/paged_vector/ut/paged_vector_ut.cpp @@ -0,0 +1,378 @@ +#include <library/cpp/containers/paged_vector/paged_vector.h> +#include <library/cpp/testing/unittest/registar.h> + +#include <stdexcept> + +class TPagedVectorTest: public TTestBase { + UNIT_TEST_SUITE(TPagedVectorTest); + UNIT_TEST(Test0) + UNIT_TEST(Test1) + UNIT_TEST(Test2) + UNIT_TEST(Test3) + UNIT_TEST(Test4) + UNIT_TEST(Test5) + UNIT_TEST(Test6) + UNIT_TEST(Test7) + UNIT_TEST(TestAt) + UNIT_TEST(TestAutoRef) + UNIT_TEST(TestIterators) + //UNIT_TEST(TestEbo) + UNIT_TEST_SUITE_END(); + +private: + // Copy-paste of STLPort tests + void Test0() { + using NPagedVector::TPagedVector; + TPagedVector<int, 16> v1; // Empty vector of integers. + + UNIT_ASSERT(v1.empty() == true); + UNIT_ASSERT(v1.size() == 0); + + for (size_t i = 0; i < 256; ++i) { + v1.resize(i + 1); + UNIT_ASSERT_VALUES_EQUAL(v1.size(), i + 1); + } + + for (size_t i = 256; i-- > 0;) { + v1.resize(i); + UNIT_ASSERT_VALUES_EQUAL(v1.size(), i); + } + } + + void Test1() { + using NPagedVector::TPagedVector; + TPagedVector<int, 3> v1; // Empty vector of integers. + + UNIT_ASSERT(v1.empty() == true); + UNIT_ASSERT(v1.size() == 0); + + // UNIT_ASSERT(v1.max_size() == INT_MAX / sizeof(int)); + // cout << "max_size = " << v1.max_size() << endl; + v1.push_back(42); // Add an integer to the vector. 
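// The first push_back also allocates the first page: pages are created lazily.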
+ + UNIT_ASSERT(v1.size() == 1); + + UNIT_ASSERT(v1[0] == 42); + + { + TPagedVector<TPagedVector<int, 3>, 3> vect; + vect.resize(10); + UNIT_ASSERT(vect.size() == 10); + TPagedVector<TPagedVector<int, 3>, 3>::iterator it(vect.begin()), end(vect.end()); + for (; it != end; ++it) { + UNIT_ASSERT((*it).empty()); + UNIT_ASSERT((*it).size() == 0); + UNIT_ASSERT((*it).begin() == (*it).end()); + } + } + } + + void Test2() { + using NPagedVector::TPagedVector; + TPagedVector<double, 3> v1; // Empty vector of doubles. + v1.push_back(32.1); + v1.push_back(40.5); + v1.push_back(45.5); + v1.push_back(33.4); + TPagedVector<double, 3> v2; // Another empty vector of doubles. + v2.push_back(3.56); + + UNIT_ASSERT(v1.size() == 4); + UNIT_ASSERT(v1[0] == 32.1); + UNIT_ASSERT(v1[1] == 40.5); + UNIT_ASSERT(v1[2] == 45.5); + UNIT_ASSERT(v1[3] == 33.4); + + UNIT_ASSERT(v2.size() == 1); + UNIT_ASSERT(v2[0] == 3.56); + v1.swap(v2); // Swap the vector's contents. + + UNIT_ASSERT(v1.size() == 1); + UNIT_ASSERT(v1[0] == 3.56); + + UNIT_ASSERT(v2.size() == 4); + UNIT_ASSERT(v2[0] == 32.1); + UNIT_ASSERT(v2[1] == 40.5); + UNIT_ASSERT(v2[2] == 45.5); + UNIT_ASSERT(v2[3] == 33.4); + + v2 = v1; // Assign one vector to another. + + UNIT_ASSERT(v2.size() == 1); + UNIT_ASSERT(v2[0] == 3.56); + + v2.pop_back(); + UNIT_ASSERT(v2.size() == 0); + UNIT_ASSERT(v2.empty()); + } + + void Test3() { + using NPagedVector::TPagedVector; + TPagedVector<char, 1> v1; + + v1.push_back('h'); + v1.push_back('i'); + + UNIT_ASSERT(v1.size() == 2); + UNIT_ASSERT(v1[0] == 'h'); + UNIT_ASSERT(v1[1] == 'i'); + + TPagedVector<char, 1> v2; + v2.resize(v1.size()); + + for (size_t i = 0; i < v1.size(); ++i) + v2[i] = v1[i]; + + v2[1] = 'o'; // Replace second character. + + UNIT_ASSERT(v2.size() == 2); + UNIT_ASSERT(v2[0] == 'h'); + UNIT_ASSERT(v2[1] == 'o'); + + UNIT_ASSERT((v1 == v2) == false); + + UNIT_ASSERT((v1 < v2) == true); + } + + void Test4() { + using NPagedVector::TPagedVector; + TPagedVector<int, 3> v; + v.resize(4); + + v[0] = 1; + v[1] = 4; + v[2] = 9; + v[3] = 16; + + UNIT_ASSERT(v.front() == 1); + UNIT_ASSERT(v.back() == 16); + + v.push_back(25); + + UNIT_ASSERT(v.back() == 25); + UNIT_ASSERT(v.size() == 5); + + v.pop_back(); + + UNIT_ASSERT(v.back() == 16); + UNIT_ASSERT(v.size() == 4); + } + + void Test5() { + int array[] = {1, 4, 9, 16}; + + typedef NPagedVector::TPagedVector<int, 3> TVectorType; + TVectorType v(array, array + 4); + + UNIT_ASSERT(v.size() == 4); + + UNIT_ASSERT(v[0] == 1); + UNIT_ASSERT(v[1] == 4); + UNIT_ASSERT(v[2] == 9); + UNIT_ASSERT(v[3] == 16); + } + + void Test6() { + int array[] = {1, 4, 9, 16, 25, 36}; + + typedef NPagedVector::TPagedVector<int, 3> TVectorType; + TVectorType v(array, array + 6); + TVectorType::iterator vit; + + UNIT_ASSERT_VALUES_EQUAL(v.size(), 6u); + UNIT_ASSERT(v[0] == 1); + UNIT_ASSERT(v[1] == 4); + UNIT_ASSERT(v[2] == 9); + UNIT_ASSERT(v[3] == 16); + UNIT_ASSERT(v[4] == 25); + UNIT_ASSERT(v[5] == 36); + + vit = v.erase(v.begin()); // Erase first element. + UNIT_ASSERT(*vit == 4); + + UNIT_ASSERT(v.size() == 5); + UNIT_ASSERT(v[0] == 4); + UNIT_ASSERT(v[1] == 9); + UNIT_ASSERT(v[2] == 16); + UNIT_ASSERT(v[3] == 25); + UNIT_ASSERT(v[4] == 36); + + vit = v.erase(v.end() - 1); // Erase last element. + UNIT_ASSERT(vit == v.end()); + + UNIT_ASSERT(v.size() == 4); + UNIT_ASSERT(v[0] == 4); + UNIT_ASSERT(v[1] == 9); + UNIT_ASSERT(v[2] == 16); + UNIT_ASSERT(v[3] == 25); + + v.erase(v.begin() + 1, v.end() - 1); // Erase all but first and last. 
+ + UNIT_ASSERT(v.size() == 2); + UNIT_ASSERT(v[0] == 4); + UNIT_ASSERT(v[1] == 25); + } + + void Test7() { + int array1[] = {1, 4, 25}; + int array2[] = {9, 16}; + + typedef NPagedVector::TPagedVector<int, 3> TVectorType; + + TVectorType v(array1, array1 + 3); + TVectorType::iterator vit; + vit = v.insert(v.begin(), 0); // Insert before first element. + UNIT_ASSERT_VALUES_EQUAL(*vit, 0); + + vit = v.insert(v.end(), 36); // Insert after last element. + UNIT_ASSERT(*vit == 36); + + UNIT_ASSERT(v.size() == 5); + UNIT_ASSERT(v[0] == 0); + UNIT_ASSERT(v[1] == 1); + UNIT_ASSERT(v[2] == 4); + UNIT_ASSERT(v[3] == 25); + UNIT_ASSERT(v[4] == 36); + + // Insert contents of array2 before fourth element. + v.insert(v.begin() + 3, array2, array2 + 2); + + UNIT_ASSERT(v.size() == 7); + + UNIT_ASSERT(v[0] == 0); + UNIT_ASSERT(v[1] == 1); + UNIT_ASSERT(v[2] == 4); + UNIT_ASSERT(v[3] == 9); + UNIT_ASSERT(v[4] == 16); + UNIT_ASSERT(v[5] == 25); + UNIT_ASSERT(v[6] == 36); + + v.clear(); + UNIT_ASSERT(v.empty()); + } + + void TestAt() { + using NPagedVector::TPagedVector; + TPagedVector<int, 3> v; + TPagedVector<int, 3> const& cv = v; + + v.push_back(10); + UNIT_ASSERT(v.at(0) == 10); + v.at(0) = 20; + UNIT_ASSERT(cv.at(0) == 20); + + for (;;) { + try { + v.at(1) = 20; + UNIT_ASSERT(false); + } catch (std::out_of_range const&) { + return; + } catch (...) { + UNIT_ASSERT(false); + } + } + } + + void TestAutoRef() { + using NPagedVector::TPagedVector; + typedef TPagedVector<int, 3> TVec; + TVec ref; + for (int i = 0; i < 5; ++i) { + ref.push_back(i); + } + + TPagedVector<TVec, 3> v_v_int; + v_v_int.push_back(ref); + v_v_int.push_back(v_v_int[0]); + v_v_int.push_back(ref); + v_v_int.push_back(v_v_int[0]); + v_v_int.push_back(v_v_int[0]); + v_v_int.push_back(ref); + + TPagedVector<TVec, 3>::iterator vvit(v_v_int.begin()), vvitEnd(v_v_int.end()); + for (; vvit != vvitEnd; ++vvit) { + UNIT_ASSERT(*vvit == ref); + } + } + + struct Point { + int x, y; + }; + + struct PointEx: public Point { + PointEx() + : builtFromBase(false) + { + } + PointEx(const Point&) + : builtFromBase(true) + { + } + + bool builtFromBase; + }; + + void TestIterators() { + using NPagedVector::TPagedVector; + TPagedVector<int, 3> vint; + vint.resize(10); + TPagedVector<int, 3> const& crvint = vint; + + UNIT_ASSERT(vint.begin() == vint.begin()); + UNIT_ASSERT(crvint.begin() == vint.begin()); + UNIT_ASSERT(vint.begin() == crvint.begin()); + UNIT_ASSERT(crvint.begin() == crvint.begin()); + + UNIT_ASSERT(vint.begin() != vint.end()); + UNIT_ASSERT(crvint.begin() != vint.end()); + UNIT_ASSERT(vint.begin() != crvint.end()); + UNIT_ASSERT(crvint.begin() != crvint.end()); + + UNIT_ASSERT(vint.rbegin() == vint.rbegin()); + // Not Standard: + //UNIT_ASSERT(vint.rbegin() == crvint.rbegin()); + //UNIT_ASSERT(crvint.rbegin() == vint.rbegin()); + UNIT_ASSERT(crvint.rbegin() == crvint.rbegin()); + + UNIT_ASSERT(vint.rbegin() != vint.rend()); + // Not Standard: + //UNIT_ASSERT(vint.rbegin() != crvint.rend()); + //UNIT_ASSERT(crvint.rbegin() != vint.rend()); + UNIT_ASSERT(crvint.rbegin() != crvint.rend()); + } + + /* This test check a potential issue with empty base class + * optimization. Some compilers (VC6) do not implement it + * correctly resulting ina wrong behavior. */ + void TestEbo() { + using NPagedVector::TPagedVector; + // We use heap memory as test failure can corrupt vector internal + // representation making executable crash on vector destructor invocation. 
+ // We prefer a simple memory leak, internal corruption should be reveal + // by size or capacity checks. + typedef TPagedVector<int, 3> V; + V* pv1 = new V; + + pv1->resize(1); + pv1->at(0) = 1; + + V* pv2 = new V; + + pv2->resize(10); + for (int i = 0; i < 10; ++i) + pv2->at(i) = 2; + + pv1->swap(*pv2); + + UNIT_ASSERT(pv1->size() == 10); + UNIT_ASSERT((*pv1)[5] == 2); + + UNIT_ASSERT(pv2->size() == 1); + UNIT_ASSERT((*pv2)[0] == 1); + + delete pv2; + delete pv1; + } +}; + +UNIT_TEST_SUITE_REGISTRATION(TPagedVectorTest); diff --git a/library/cpp/containers/paged_vector/ut/ya.make b/library/cpp/containers/paged_vector/ut/ya.make new file mode 100644 index 00000000000..74cfe5fb4ad --- /dev/null +++ b/library/cpp/containers/paged_vector/ut/ya.make @@ -0,0 +1,13 @@ +UNITTEST() + +OWNER(velavokr) + +PEERDIR( + library/cpp/containers/paged_vector +) + +SRCS( + paged_vector_ut.cpp +) + +END() diff --git a/library/cpp/containers/paged_vector/ya.make b/library/cpp/containers/paged_vector/ya.make new file mode 100644 index 00000000000..e14548bc2c4 --- /dev/null +++ b/library/cpp/containers/paged_vector/ya.make @@ -0,0 +1,9 @@ +LIBRARY() + +OWNER(velavokr) + +SRCS( + paged_vector.cpp +) + +END() diff --git a/library/cpp/containers/ring_buffer/ring_buffer.cpp b/library/cpp/containers/ring_buffer/ring_buffer.cpp new file mode 100644 index 00000000000..799dad631be --- /dev/null +++ b/library/cpp/containers/ring_buffer/ring_buffer.cpp @@ -0,0 +1 @@ +#include "ring_buffer.h" diff --git a/library/cpp/containers/ring_buffer/ring_buffer.h b/library/cpp/containers/ring_buffer/ring_buffer.h new file mode 100644 index 00000000000..41220dcf6bf --- /dev/null +++ b/library/cpp/containers/ring_buffer/ring_buffer.h @@ -0,0 +1,81 @@ +#pragma once + +#include <util/generic/vector.h> +#include <util/system/yassert.h> + +template <typename T> +class TSimpleRingBuffer { +public: + TSimpleRingBuffer(size_t maxSize) + : MaxSize(maxSize) + { + Items.reserve(MaxSize); + } + + TSimpleRingBuffer(const TSimpleRingBuffer&) = default; + TSimpleRingBuffer(TSimpleRingBuffer&&) = default; + + TSimpleRingBuffer& operator=(const TSimpleRingBuffer&) = default; + TSimpleRingBuffer& operator=(TSimpleRingBuffer&&) = default; + + // First available item + size_t FirstIndex() const { + return Begin; + } + + size_t AvailSize() const { + return Items.size(); + } + + // Total number of items inserted + size_t TotalSize() const { + return FirstIndex() + AvailSize(); + } + + bool IsAvail(size_t index) const { + return index >= FirstIndex() && index < TotalSize(); + } + + const T& operator[](size_t index) const { + Y_ASSERT(IsAvail(index)); + return Items[RealIndex(index)]; + } + + T& operator[](size_t index) { + Y_ASSERT(IsAvail(index)); + return Items[RealIndex(index)]; + } + + void PushBack(const T& t) { + if (Items.size() < MaxSize) { + Items.push_back(t); + } else { + Items[RealIndex(Begin)] = t; + Begin += 1; + } + } + + void Clear() { + Items.clear(); + Begin = 0; + } + +private: + size_t RealIndex(size_t index) const { + return index % MaxSize; + } + +private: + size_t MaxSize; + size_t Begin = 0; + TVector<T> Items; +}; + +template <typename T, size_t maxSize> +class TStaticRingBuffer: public TSimpleRingBuffer<T> { +public: + TStaticRingBuffer() + : TSimpleRingBuffer<T>(maxSize) + { + } +}; diff --git a/library/cpp/containers/ring_buffer/ya.make b/library/cpp/containers/ring_buffer/ya.make new file mode 100644 index 00000000000..51333978f7f --- /dev/null +++ b/library/cpp/containers/ring_buffer/ya.make @@ -0,0 +1,9 @@ +OWNER(mowgli) + 
+LIBRARY() + +SRCS( + ring_buffer.cpp +) + +END() diff --git a/library/cpp/containers/sorted_vector/sorted_vector.cpp b/library/cpp/containers/sorted_vector/sorted_vector.cpp new file mode 100644 index 00000000000..56aaf69ddef --- /dev/null +++ b/library/cpp/containers/sorted_vector/sorted_vector.cpp @@ -0,0 +1 @@ +#include "sorted_vector.h" diff --git a/library/cpp/containers/sorted_vector/sorted_vector.h b/library/cpp/containers/sorted_vector/sorted_vector.h new file mode 100644 index 00000000000..123539af9e2 --- /dev/null +++ b/library/cpp/containers/sorted_vector/sorted_vector.h @@ -0,0 +1,492 @@ +#pragma once + +#include <util/system/defaults.h> +#include <util/generic/hash.h> +#include <util/generic/vector.h> +#include <util/generic/algorithm.h> +#include <util/generic/mapfindptr.h> +#include <util/ysaveload.h> +#include <utility> + +#include <initializer_list> + +namespace NSorted { + namespace NPrivate { + template <class TPredicate> + struct TEqual { + template<typename TValueType1, typename TValueType2> + inline bool operator()(const TValueType1& l, const TValueType2& r) const { + TPredicate comp; + return comp(l, r) == comp(r, l); + } + }; + + template <typename TValueType, class TPredicate, class TKeyExtractor> + struct TKeyCompare { + inline bool operator()(const TValueType& l, const TValueType& r) const { + TKeyExtractor extractKey; + return TPredicate()(extractKey(l), extractKey(r)); + } + template<typename TKeyType> + inline bool operator()(const TKeyType& l, const TValueType& r) const { + return TPredicate()(l, TKeyExtractor()(r)); + } + template<typename TKeyType> + inline bool operator()(const TValueType& l, const TKeyType& r) const { + return TPredicate()(TKeyExtractor()(l), r); + } + }; + + template <typename TValueType, class TPredicate> + struct TKeyCompare<TValueType, TPredicate, TIdentity> { + template <typename TValueType1, typename TValueType2> + inline bool operator()(const TValueType1& l, const TValueType2& r) const { + return TPredicate()(l, r); + } + }; + + } + + // Sorted vector, which is order by the key. The key is extracted from the value by the provided key-extractor + template <typename TValueType, typename TKeyType = TValueType, class TKeyExtractor = TIdentity, + class TPredicate = TLess<TKeyType>, class A = std::allocator<TValueType>> + class TSortedVector: public TVector<TValueType, A> { + private: + typedef TVector<TValueType, A> TBase; + typedef NPrivate::TKeyCompare<TValueType, TPredicate, TKeyExtractor> TKeyCompare; + typedef NPrivate::TEqual<TKeyCompare> TValueEqual; + typedef NPrivate::TEqual<TPredicate> TKeyEqual; + + public: + typedef TValueType value_type; + typedef TKeyType key_type; + typedef typename TBase::iterator iterator; + typedef typename TBase::const_iterator const_iterator; + typedef typename TBase::size_type size_type; + + public: + inline TSortedVector() + : TBase() + { + } + + inline explicit TSortedVector(size_type count) + : TBase(count) + { + } + + inline TSortedVector(size_type count, const value_type& val) + : TBase(count, val) + { + } + + inline TSortedVector(std::initializer_list<value_type> il) + : TBase(il) + { + Sort(); + } + + inline TSortedVector(std::initializer_list<value_type> il, const typename TBase::allocator_type& a) + : TBase(il, a) + { + Sort(); + } + + template <class TIter> + inline TSortedVector(TIter first, TIter last) + : TBase(first, last) + { + Sort(); + } + + // Inserts non-unique value in the proper position according to the key-sort order. 
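// Duplicate keys are allowed; the new value lands next to the existing equal ones.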
+ // Returns iterator, which points to the inserted value + inline iterator Insert(const value_type& value) { + return TBase::insert(LowerBound(TKeyExtractor()(value)), value); + } + + // STL-compatible synonym + Y_FORCE_INLINE iterator insert(const value_type& value) { + return this->Insert(value); + } + + // Inserts non-unique value range in the proper position according to the key-sort order. + template <class TIter> + inline void Insert(TIter first, TIter last) { + TBase::insert(TBase::end(), first, last); + Sort(); + } + + // STL-compatible synonym + template <class TIter> + Y_FORCE_INLINE void insert(TIter first, TIter last) { + this->Insert(first, last); + } + + // Inserts unique value in the proper position according to the key-sort order, + // if the value with the same key doesn't exist. Returns <iterator, bool> pair, + // where the first member is the pointer to the inserted/existing value, and the + // second member indicates either the value is inserted or not. + inline std::pair<iterator, bool> InsertUnique(const value_type& value) { + iterator i = LowerBound(TKeyExtractor()(value)); + if (i == TBase::end() || !TValueEqual()(*i, value)) + return std::pair<iterator, bool>(TBase::insert(i, value), true); + else + return std::pair<iterator, bool>(i, false); + } + + // STL-compatible synonym + Y_FORCE_INLINE std::pair<iterator, bool> insert_unique(const value_type& value) { + return this->InsertUnique(value); + } + + // Inserts unique value range in the proper position according to the key-sort order. + template <class TIter> + inline void InsertUnique(TIter first, TIter last) { + TBase::insert(TBase::end(), first, last); + Sort(); + MakeUnique(); + } + + // STL-compatible synonym + template <class TIter> + Y_FORCE_INLINE void insert_unique(TIter first, TIter last) { + this->InsertUnique(first, last); + } + + // Inserts unique value in the proper position according to the key-sort order. + // If the value with the same key already exists, then it is replaced with the new one. 
+ // Returns iterator, which points to the inserted value + inline iterator InsertOrReplace(const value_type& value) { + iterator i = ::LowerBound(TBase::begin(), TBase::end(), value, TKeyCompare()); + if (i == TBase::end() || !TValueEqual()(*i, value)) + return TBase::insert(i, value); + else + return TBase::insert(TBase::erase(i), value); + } + + // STL-compatible synonym + Y_FORCE_INLINE iterator insert_or_replace(const value_type& value) { + return this->InsertOrReplace(value); + } + + Y_FORCE_INLINE void Sort() { + ::Sort(TBase::begin(), TBase::end(), TKeyCompare()); + } + + // STL-compatible synonym + Y_FORCE_INLINE void sort() { + this->Sort(); + } + + Y_FORCE_INLINE void Sort(iterator from, iterator to) { + ::Sort(from, to, TKeyCompare()); + } + + // STL-compatible synonym + Y_FORCE_INLINE void sort(iterator from, iterator to) { + this->Sort(from, to); + } + + inline void MakeUnique() { + TBase::erase(::Unique(TBase::begin(), TBase::end(), TValueEqual()), TBase::end()); + } + + // STL-compatible synonym + Y_FORCE_INLINE void make_unique() { + this->MakeUnique(); + } + + template<class K> + inline const_iterator Find(const K& key) const { + const_iterator i = LowerBound(key); + if (i == TBase::end() || !TKeyEqual()(TKeyExtractor()(*i), key)) + return TBase::end(); + else + return i; + } + + // STL-compatible synonym + template<class K> + Y_FORCE_INLINE const_iterator find(const K& key) const { + return this->Find(key); + } + + template<class K> + inline iterator Find(const K& key) { + iterator i = LowerBound(key); + if (i == TBase::end() || !TKeyEqual()(TKeyExtractor()(*i), key)) + return TBase::end(); + else + return i; + } + + // STL-compatible synonym + template<class K> + Y_FORCE_INLINE iterator find(const K& key) { + return this->Find(key); + } + + template<class K> + Y_FORCE_INLINE bool Has(const K& key) const { + return this->find(key) != TBase::end(); + } + + template<class K> + Y_FORCE_INLINE bool has(const K& key) const { + return this->Has(key); + } + + template<class K> + Y_FORCE_INLINE iterator LowerBound(const K& key) { + return ::LowerBound(TBase::begin(), TBase::end(), key, TKeyCompare()); + } + + // STL-compatible synonym + template<class K> + Y_FORCE_INLINE iterator lower_bound(const K& key) { + return this->LowerBound(key); + } + + template<class K> + Y_FORCE_INLINE const_iterator LowerBound(const K& key) const { + return ::LowerBound(TBase::begin(), TBase::end(), key, TKeyCompare()); + } + + // STL-compatible synonym + template<class K> + Y_FORCE_INLINE const_iterator lower_bound(const K& key) const { + return this->LowerBound(key); + } + + template<class K> + Y_FORCE_INLINE iterator UpperBound(const K& key) { + return ::UpperBound(TBase::begin(), TBase::end(), key, TKeyCompare()); + } + + // STL-compatible synonym + template<class K> + Y_FORCE_INLINE iterator upper_bound(const K& key) { + return this->UpperBound(key); + } + + template<class K> + Y_FORCE_INLINE const_iterator UpperBound(const K& key) const { + return ::UpperBound(TBase::begin(), TBase::end(), key, TKeyCompare()); + } + + // STL-compatible synonym + template<class K> + Y_FORCE_INLINE const_iterator upper_bound(const K& key) const { + return this->UpperBound(key); + } + + template<class K> + Y_FORCE_INLINE std::pair<iterator, iterator> EqualRange(const K& key) { + return std::equal_range(TBase::begin(), TBase::end(), key, TKeyCompare()); + } + + // STL-compatible synonym + template<class K> + Y_FORCE_INLINE std::pair<iterator, iterator> equal_range(const K& key) { + return this->EqualRange(key); + 
} + + template<class K> + Y_FORCE_INLINE std::pair<const_iterator, const_iterator> EqualRange(const K& key) const { + return std::equal_range(TBase::begin(), TBase::end(), key, TKeyCompare()); + } + + // STL-compatible synonym + template<class K> + Y_FORCE_INLINE std::pair<const_iterator, const_iterator> equal_range(const K& key) const { + return this->EqualRange(key); + } + + template<class K> + inline void Erase(const K& key) { + std::pair<iterator, iterator> res = EqualRange(key); + TBase::erase(res.first, res.second); + } + + // STL-compatible synonym + Y_FORCE_INLINE void erase(const key_type& key) { + this->Erase(key); + } + + template<class K> + inline size_t count(const K& key) const { + const std::pair<const_iterator, const_iterator> range = this->EqualRange(key); + return std::distance(range.first, range.second); + } + + using TBase::erase; + }; + + // The simplified map (a.k.a TFlatMap, flat_map), which is implemented by the sorted-vector. + // This structure has the side-effect: if you keep a reference to an existing element + // and then inserts a new one, the existing reference can be broken (due to reallocation). + // Please keep this in mind when using this structure. + template <typename TKeyType, typename TValueType, class TPredicate = TLess<TKeyType>, class A = std::allocator<TValueType>> + class TSimpleMap: + public TSortedVector<std::pair<TKeyType, TValueType>, TKeyType, TSelect1st, TPredicate, A>, + public TMapOps<TSimpleMap<TKeyType, TValueType, TPredicate, A>> + { + private: + typedef TSortedVector<std::pair<TKeyType, TValueType>, TKeyType, TSelect1st, TPredicate, A> TBase; + + public: + typedef typename TBase::value_type value_type; + typedef typename TBase::key_type key_type; + typedef typename TBase::iterator iterator; + typedef typename TBase::const_iterator const_iterator; + typedef typename TBase::size_type size_type; + + public: + inline TSimpleMap() + : TBase() + { + } + + template <class TIter> + inline TSimpleMap(TIter first, TIter last) + : TBase(first, last) + { + TBase::MakeUnique(); + } + + inline TSimpleMap(std::initializer_list<value_type> il) + : TBase(il) + { + TBase::MakeUnique(); + } + + inline TValueType& Get(const TKeyType& key) { + typename TBase::iterator i = TBase::LowerBound(key); + if (i == TBase::end() || key != i->first) + return TVector<std::pair<TKeyType, TValueType>, A>::insert(i, std::make_pair(key, TValueType()))->second; + else + return i->second; + } + + template<class K> + inline const TValueType& Get(const K& key, const TValueType& def) const { + typename TBase::const_iterator i = TBase::Find(key); + return i != TBase::end() ? i->second : def; + } + + template<class K> + Y_FORCE_INLINE TValueType& operator[](const K& key) { + return Get(key); + } + + template<class K> + const TValueType& at(const K& key) const { + const auto i = TBase::Find(key); + if (i == TBase::end()) { + throw std::out_of_range("NSorted::TSimpleMap: missing key"); + } + + return i->second; + } + + template<class K> + TValueType& at(const K& key) { + return const_cast<TValueType&>( + const_cast<const TSimpleMap<TKeyType, TValueType, TPredicate, A>*>(this)->at(key)); + } + }; + + // The simplified set (a.k.a TFlatSet, flat_set), which is implemented by the sorted-vector. + // This structure has the same side-effect as TSimpleMap. 
+ // The value type must have TValueType(TKeyType) constructor in order to use [] operator + template <typename TValueType, typename TKeyType = TValueType, class TKeyExtractor = TIdentity, + class TPredicate = TLess<TKeyType>, class A = std::allocator<TValueType>> + class TSimpleSet: public TSortedVector<TValueType, TKeyType, TKeyExtractor, TPredicate, A> { + private: + typedef TSortedVector<TValueType, TKeyType, TKeyExtractor, TPredicate, A> TBase; + + public: + typedef typename TBase::value_type value_type; + typedef typename TBase::key_type key_type; + typedef typename TBase::iterator iterator; + typedef typename TBase::const_iterator const_iterator; + typedef typename TBase::size_type size_type; + typedef NPrivate::TEqual<TPredicate> TKeyEqual; + + public: + inline TSimpleSet() + : TBase() + { + } + + template <class TIter> + inline TSimpleSet(TIter first, TIter last) + : TBase(first, last) + { + TBase::MakeUnique(); + } + + inline TSimpleSet(std::initializer_list<value_type> il) + : TBase(il) + { + TBase::MakeUnique(); + } + + // The method expects that there is a TValueType(TKeyType) constructor available + inline TValueType& Get(const TKeyType& key) { + typename TBase::iterator i = TBase::LowerBound(key); + if (i == TBase::end() || !TKeyEqual()(TKeyExtractor()(*i), key)) + i = TVector<TValueType, A>::insert(i, TValueType(key)); + return *i; + } + + template<class K> + inline const TValueType& Get(const K& key, const TValueType& def) const { + typename TBase::const_iterator i = TBase::Find(key); + return i != TBase::end() ? *i : def; + } + + template<class K> + Y_FORCE_INLINE TValueType& operator[](const K& key) { + return Get(key); + } + + // Inserts value with unique key. Returns <iterator, bool> pair, + // where the first member is the pointer to the inserted/existing value, and the + // second member indicates either the value is inserted or not. + Y_FORCE_INLINE std::pair<iterator, bool> Insert(const TValueType& value) { + return TBase::InsertUnique(value); + } + + // STL-compatible synonym + Y_FORCE_INLINE std::pair<iterator, bool> insert(const TValueType& value) { + return TBase::InsertUnique(value); + } + + // Inserts value range with unique keys. 
+ template <class TIter> + Y_FORCE_INLINE void Insert(TIter first, TIter last) { + TBase::InsertUnique(first, last); + } + + // STL-compatible synonym + template <class TIter> + Y_FORCE_INLINE void insert(TIter first, TIter last) { + TBase::InsertUnique(first, last); + } + }; + +} + +template <typename V, typename K, class E, class P, class A> +class TSerializer<NSorted::TSortedVector<V, K, E, P, A>>: public TVectorSerializer<NSorted::TSortedVector<V, K, E, P, A>> { +}; + +template <typename K, typename V, class P, class A> +class TSerializer<NSorted::TSimpleMap<K, V, P, A>>: public TVectorSerializer<NSorted::TSimpleMap<K, V, P, A>> { +}; + +template <typename V, typename K, class E, class P, class A> +class TSerializer<NSorted::TSimpleSet<V, K, E, P, A>>: public TVectorSerializer<NSorted::TSimpleSet<V, K, E, P, A>> { +}; diff --git a/library/cpp/containers/sorted_vector/sorted_vector_ut.cpp b/library/cpp/containers/sorted_vector/sorted_vector_ut.cpp new file mode 100644 index 00000000000..893862f098a --- /dev/null +++ b/library/cpp/containers/sorted_vector/sorted_vector_ut.cpp @@ -0,0 +1,24 @@ +#include "sorted_vector.h" + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/string.h> +#include <util/generic/strbuf.h> + +Y_UNIT_TEST_SUITE(TestSimpleMap) { + + Y_UNIT_TEST(TestFindPrt) { + NSorted::TSimpleMap<TString, TString> map( + {std::make_pair(TString("a"), TString("a")), std::make_pair(TString("b"), TString("b"))}); + + UNIT_ASSERT_VALUES_UNEQUAL(map.FindPtr(TString("a")), nullptr); + UNIT_ASSERT_VALUES_EQUAL(map.FindPtr(TString("c")), nullptr); + + UNIT_ASSERT_VALUES_UNEQUAL(map.FindPtr(TStringBuf("a")), nullptr); + UNIT_ASSERT_VALUES_EQUAL(map.FindPtr(TStringBuf("c")), nullptr); + + UNIT_ASSERT_VALUES_UNEQUAL(map.FindPtr("a"), nullptr); + UNIT_ASSERT_VALUES_EQUAL(map.FindPtr("c"), nullptr); + } + +} diff --git a/library/cpp/containers/sorted_vector/ut/ya.make b/library/cpp/containers/sorted_vector/ut/ya.make new file mode 100644 index 00000000000..eb8a5b4beff --- /dev/null +++ b/library/cpp/containers/sorted_vector/ut/ya.make @@ -0,0 +1,10 @@ +UNITTEST_FOR(library/cpp/containers/sorted_vector) + +OWNER(udovichenko-r) + + +SRCS( + sorted_vector_ut.cpp +) + +END() diff --git a/library/cpp/containers/sorted_vector/ya.make b/library/cpp/containers/sorted_vector/ya.make new file mode 100644 index 00000000000..1975c5dc908 --- /dev/null +++ b/library/cpp/containers/sorted_vector/ya.make @@ -0,0 +1,11 @@ +LIBRARY() + +OWNER(udovichenko-r) + +SRCS( + sorted_vector.cpp +) + +END() + +RECURSE_FOR_TESTS(ut) diff --git a/library/cpp/containers/stack_array/range_ops.cpp b/library/cpp/containers/stack_array/range_ops.cpp new file mode 100644 index 00000000000..f1b5e3af0de --- /dev/null +++ b/library/cpp/containers/stack_array/range_ops.cpp @@ -0,0 +1 @@ +#include "range_ops.h" diff --git a/library/cpp/containers/stack_array/range_ops.h b/library/cpp/containers/stack_array/range_ops.h new file mode 100644 index 00000000000..1d40341aa19 --- /dev/null +++ b/library/cpp/containers/stack_array/range_ops.h @@ -0,0 +1,52 @@ +#pragma once + +#include <util/generic/typetraits.h> + +#include <new> + +namespace NRangeOps { + template <class T, bool isTrivial> + struct TRangeOpsBase { + static inline void DestroyRange(T* b, T* e) noexcept { + while (e > b) { + (--e)->~T(); + } + } + + static inline void InitializeRange(T* b, T* e) { + T* c = b; + + try { + for (; c < e; ++c) { + new (c) T(); + } + } catch (...) 
{ + DestroyRange(b, c); + + throw; + } + } + }; + + template <class T> + struct TRangeOpsBase<T, true> { + static inline void DestroyRange(T*, T*) noexcept { + } + + static inline void InitializeRange(T*, T*) noexcept { + } + }; + + template <class T> + using TRangeOps = TRangeOpsBase<T, TTypeTraits<T>::IsPod>; + + template <class T> + static inline void DestroyRange(T* b, T* e) noexcept { + TRangeOps<T>::DestroyRange(b, e); + } + + template <class T> + static inline void InitializeRange(T* b, T* e) { + TRangeOps<T>::InitializeRange(b, e); + } +} diff --git a/library/cpp/containers/stack_array/stack_array.cpp b/library/cpp/containers/stack_array/stack_array.cpp new file mode 100644 index 00000000000..68e8b097baa --- /dev/null +++ b/library/cpp/containers/stack_array/stack_array.cpp @@ -0,0 +1 @@ +#include "stack_array.h" diff --git a/library/cpp/containers/stack_array/stack_array.h b/library/cpp/containers/stack_array/stack_array.h new file mode 100644 index 00000000000..28e49bfc3c2 --- /dev/null +++ b/library/cpp/containers/stack_array/stack_array.h @@ -0,0 +1,40 @@ +#pragma once + +#include "range_ops.h" + +#include <util/generic/array_ref.h> +#include <util/system/defaults.h> /* For alloca. */ + +namespace NStackArray { + /** + * A stack-allocated array. Should be used instead of οΏ½ variable length + * arrays that are not part of C++ standard. + * + * Example usage: + * @code + * void f(int size) { + * // T array[size]; // Wrong! + * TStackArray<T> array(ALLOC_ON_STACK(T, size)); // Right! + * // ... + * } + * @endcode + * + * Note that it is generally a *VERY BAD* idea to use this in inline methods + * as those might be called from a loop, and then stack overflow is in the cards. + */ + template <class T> + class TStackArray: public TArrayRef<T> { + public: + inline TStackArray(void* data, size_t len) + : TArrayRef<T>((T*)data, len) + { + NRangeOps::InitializeRange(this->begin(), this->end()); + } + + inline ~TStackArray() { + NRangeOps::DestroyRange(this->begin(), this->end()); + } + }; +} + +#define ALLOC_ON_STACK(type, n) alloca(sizeof(type) * (n)), (n) diff --git a/library/cpp/containers/stack_array/ut/tests_ut.cpp b/library/cpp/containers/stack_array/ut/tests_ut.cpp new file mode 100644 index 00000000000..3e96384f0e7 --- /dev/null +++ b/library/cpp/containers/stack_array/ut/tests_ut.cpp @@ -0,0 +1,94 @@ +#include <library/cpp/containers/stack_array/stack_array.h> +#include <library/cpp/testing/unittest/registar.h> + +Y_UNIT_TEST_SUITE(TestStackArray) { + using namespace NStackArray; + + static inline void* FillWithTrash(void* d, size_t l) { + memset(d, 0xCC, l); + + return d; + } + +#define ALLOC(type, len) FillWithTrash(alloca(sizeof(type) * len), sizeof(type) * len), len + + Y_UNIT_TEST(Test1) { + TStackArray<ui32> s(ALLOC(ui32, 10)); + + UNIT_ASSERT_VALUES_EQUAL(s.size(), 10); + + for (size_t i = 0; i < s.size(); ++i) { + UNIT_ASSERT_VALUES_EQUAL(s[i], 0xCCCCCCCC); + } + + for (auto&& x : s) { + UNIT_ASSERT_VALUES_EQUAL(x, 0xCCCCCCCC); + } + + for (size_t i = 0; i < s.size(); ++i) { + s[i] = i; + } + + size_t ss = 0; + + for (auto&& x : s) { + ss += x; + } + + UNIT_ASSERT_VALUES_EQUAL(ss, 45); + } + + static int N1 = 0; + + struct TX1 { + inline TX1() { + ++N1; + } + + inline ~TX1() { + --N1; + } + }; + + Y_UNIT_TEST(Test2) { + { + TStackArray<TX1> s(ALLOC(TX1, 10)); + + UNIT_ASSERT_VALUES_EQUAL(N1, 10); + } + + UNIT_ASSERT_VALUES_EQUAL(N1, 0); + } + + static int N2 = 0; + static int N3 = 0; + + struct TX2 { + inline TX2() { + if (N2 >= 5) { + ythrow yexception() << "ups"; + } 
+ + ++N3; + ++N2; + } + + inline ~TX2() { + --N2; + } + }; + + Y_UNIT_TEST(Test3) { + bool haveException = false; + + try { + TStackArray<TX2> s(ALLOC_ON_STACK(TX2, 10)); + } catch (...) { + haveException = true; + } + + UNIT_ASSERT(haveException); + UNIT_ASSERT_VALUES_EQUAL(N2, 0); + UNIT_ASSERT_VALUES_EQUAL(N3, 5); + } +} diff --git a/library/cpp/containers/stack_array/ut/ya.make b/library/cpp/containers/stack_array/ut/ya.make new file mode 100644 index 00000000000..7db7340073b --- /dev/null +++ b/library/cpp/containers/stack_array/ut/ya.make @@ -0,0 +1,9 @@ +UNITTEST_FOR(library/cpp/containers/stack_array) + +OWNER(pg) + +SRCS( + tests_ut.cpp +) + +END() diff --git a/library/cpp/containers/stack_array/ya.make b/library/cpp/containers/stack_array/ya.make new file mode 100644 index 00000000000..9bc0afc66ca --- /dev/null +++ b/library/cpp/containers/stack_array/ya.make @@ -0,0 +1,10 @@ +LIBRARY() + +OWNER(pg) + +SRCS( + range_ops.cpp + stack_array.cpp +) + +END() diff --git a/library/cpp/containers/stack_vector/stack_vec.cpp b/library/cpp/containers/stack_vector/stack_vec.cpp new file mode 100644 index 00000000000..21c0ab3f11e --- /dev/null +++ b/library/cpp/containers/stack_vector/stack_vec.cpp @@ -0,0 +1 @@ +#include "stack_vec.h" diff --git a/library/cpp/containers/stack_vector/stack_vec.h b/library/cpp/containers/stack_vector/stack_vec.h new file mode 100644 index 00000000000..fcc5d9a2a50 --- /dev/null +++ b/library/cpp/containers/stack_vector/stack_vec.h @@ -0,0 +1,212 @@ +#pragma once + +#include <util/generic/vector.h> +#include <util/ysaveload.h> + +#include <type_traits> + +// A vector preallocated on the stack. +// After exceeding the preconfigured stack space falls back to the heap. +// Publicly inherits TVector, but disallows swap (and hence shrink_to_fit, also operator= is reimplemented via copying). +// +// Inspired by: http://qt-project.org/doc/qt-4.8/qvarlengtharray.html#details + +template <typename T, size_t CountOnStack = 256, bool UseFallbackAlloc = true, class Alloc = std::allocator<T>> +class TStackVec; + +template <typename T, class Alloc = std::allocator<T>> +using TSmallVec = TStackVec<T, 16, true, Alloc>; + +template <typename T, size_t CountOnStack = 256> +using TStackOnlyVec = TStackVec<T, CountOnStack, false>; + +namespace NPrivate { + template <class Alloc, class StackAlloc, typename T, typename U> + struct TRebind { + typedef TReboundAllocator<Alloc, U> other; + }; + + template <class Alloc, class StackAlloc, typename T> + struct TRebind<Alloc, StackAlloc, T, T> { + typedef StackAlloc other; + }; + + template <typename T, size_t CountOnStack, bool UseFallbackAlloc, class Alloc = std::allocator<T>> + class TStackBasedAllocator: public Alloc { + public: + typedef TStackBasedAllocator<T, CountOnStack, UseFallbackAlloc, Alloc> TSelf; + + using typename Alloc::difference_type; + using typename Alloc::size_type; + using typename Alloc::value_type; + + template <class U> + struct rebind: public ::NPrivate::TRebind<Alloc, TSelf, T, U> { + }; + + public: + TStackBasedAllocator() = default; + + template < + typename... TArgs, + typename = std::enable_if_t< + std::is_constructible_v<Alloc, TArgs...> + > + > + TStackBasedAllocator(TArgs&&... args) + : Alloc(std::forward<TArgs>(args)...) + {} + + T* allocate(size_type n) { + if (!IsStorageUsed && CountOnStack >= n) { + IsStorageUsed = true; + return reinterpret_cast<T*>(&StackBasedStorage[0]); + } else { + if constexpr (!UseFallbackAlloc) { + Y_FAIL( + "Stack storage overflow. 
Capacity: %d, requested: %d", (int)CountOnStack, int(n)); + } + return FallbackAllocator().allocate(n); + } + } + + void deallocate(T* p, size_type n) { + if (p >= reinterpret_cast<T*>(&StackBasedStorage[0]) && + p < reinterpret_cast<T*>(&StackBasedStorage[CountOnStack])) { + Y_VERIFY(IsStorageUsed); + IsStorageUsed = false; + } else { + FallbackAllocator().deallocate(p, n); + } + } + + private: + std::aligned_storage_t<sizeof(T), alignof(T)> StackBasedStorage[CountOnStack]; + bool IsStorageUsed = false; + + private: + Alloc& FallbackAllocator() noexcept { + return static_cast<Alloc&>(*this); + } + }; +} + +template <typename T, size_t CountOnStack, bool UseFallbackAlloc, class Alloc> +class TStackVec: public TVector<T, ::NPrivate::TStackBasedAllocator<T, CountOnStack, UseFallbackAlloc, TReboundAllocator<Alloc, T>>> { +private: + using TBase = TVector<T, ::NPrivate::TStackBasedAllocator<T, CountOnStack, UseFallbackAlloc, TReboundAllocator<Alloc, T>>>; + using TAllocator = typename TBase::allocator_type; + +public: + using typename TBase::const_iterator; + using typename TBase::const_reverse_iterator; + using typename TBase::iterator; + using typename TBase::reverse_iterator; + using typename TBase::size_type; + using typename TBase::value_type; + +public: + TStackVec(const TAllocator& alloc = TAllocator()) + : TBase(alloc) + { + TBase::reserve(CountOnStack); + } + + explicit TStackVec(size_type count, const TAllocator& alloc = TAllocator()) + : TBase(alloc) + { + if (count <= CountOnStack) { + TBase::reserve(CountOnStack); + } + TBase::resize(count); + } + + TStackVec(size_type count, const T& val, const TAllocator& alloc = TAllocator()) + : TBase(alloc) + { + if (count <= CountOnStack) { + TBase::reserve(CountOnStack); + } + TBase::assign(count, val); + } + + TStackVec(const TStackVec& src) + : TStackVec(src.begin(), src.end()) + { + } + + template <class A> + TStackVec(const TVector<T, A>& src) + : TStackVec(src.begin(), src.end()) + { + } + + TStackVec(std::initializer_list<T> il, const TAllocator& alloc = TAllocator()) + : TStackVec(il.begin(), il.end(), alloc) + { + } + + template <class TIter> + TStackVec(TIter first, TIter last, const TAllocator& alloc = TAllocator()) + : TBase(alloc) + { + // NB(eeight) Since we want to call 'reserve' here, we cannot just delegate to TVector ctor. + // The best way to insert values afterwards is to call TVector::insert. However there is a caveat. + // In order to call this ctor of TVector, T needs to be just move-constructible. Insert however + // requires T to be move-assignable. + TBase::reserve(CountOnStack); + if constexpr (std::is_move_assignable_v<T>) { + // Fast path + TBase::insert(TBase::end(), first, last); + } else { + // Slow path. 
+ for (; first != last; ++first) { + TBase::push_back(*first); + } + } + } + +public: + void swap(TStackVec&) = delete; + void shrink_to_fit() = delete; + + TStackVec& operator=(const TStackVec& src) { + TBase::assign(src.begin(), src.end()); + return *this; + } + + template <class A> + TStackVec& operator=(const TVector<T, A>& src) { + TBase::assign(src.begin(), src.end()); + return *this; + } + + TStackVec& operator=(std::initializer_list<T> il) { + TBase::assign(il.begin(), il.end()); + return *this; + } +}; + +template <typename T, size_t CountOnStack, class Alloc> +class TSerializer<TStackVec<T, CountOnStack, true, Alloc>>: public TVectorSerializer<TStackVec<T, CountOnStack, true, Alloc>> { +}; + +template <typename T, size_t CountOnStack, class Alloc> +class TSerializer<TStackVec<T, CountOnStack, false, Alloc>> { +public: + static void Save(IOutputStream* rh, const TStackVec<T, CountOnStack, false, Alloc>& v) { + if constexpr (CountOnStack < 256) { + ::Save(rh, (ui8)v.size()); + } else { + ::Save(rh, v.size()); + } + ::SaveArray(rh, v.data(), v.size()); + } + + static void Load(IInputStream* rh, TStackVec<T, CountOnStack, false, Alloc>& v) { + std::conditional_t<CountOnStack < 256, ui8, size_t> size; + ::Load(rh, size); + v.resize(size); + ::LoadPodArray(rh, v.data(), v.size()); + } +}; diff --git a/library/cpp/containers/stack_vector/stack_vec_ut.cpp b/library/cpp/containers/stack_vector/stack_vec_ut.cpp new file mode 100644 index 00000000000..19f9677781c --- /dev/null +++ b/library/cpp/containers/stack_vector/stack_vec_ut.cpp @@ -0,0 +1,144 @@ +#include "stack_vec.h" + +#include <library/cpp/testing/unittest/registar.h> + +namespace { + struct TNotCopyAssignable { + const int Value; + }; + + static_assert(std::is_copy_constructible_v<TNotCopyAssignable>); + static_assert(!std::is_copy_assignable_v<TNotCopyAssignable>); + + template <class T, size_t JunkPayloadSize> + struct TThickAlloc: public std::allocator<T> { + template <class U> + struct rebind { + using other = TThickAlloc<U, JunkPayloadSize>; + }; + + char Junk[JunkPayloadSize]{sizeof(T)}; + }; + + template <class T> + struct TStatefulAlloc: public std::allocator<T> { + using TBase = std::allocator<T>; + + template <class U> + struct rebind { + using other = TStatefulAlloc<U>; + }; + + TStatefulAlloc(size_t* allocCount) + : AllocCount(allocCount) + {} + + size_t* AllocCount; + + T* allocate(size_t n) + { + *AllocCount += 1; + return TBase::allocate(n); + } + }; +} + +Y_UNIT_TEST_SUITE(TStackBasedVectorTest) { + Y_UNIT_TEST(TestCreateEmpty) { + TStackVec<int> ints; + UNIT_ASSERT_EQUAL(ints.size(), 0); + } + + Y_UNIT_TEST(TestCreateNonEmpty) { + TStackVec<int> ints(5); + UNIT_ASSERT_EQUAL(ints.size(), 5); + + for (size_t i = 0; i < ints.size(); ++i) { + UNIT_ASSERT_EQUAL(ints[i], 0); + } + } + + Y_UNIT_TEST(TestReallyOnStack) { + const TStackVec<int> vec(5); + + UNIT_ASSERT( + (const char*)&vec <= (const char*)&vec[0] && + (const char*)&vec[0] <= (const char*)&vec + sizeof(vec) + ); + } + + Y_UNIT_TEST(TestFallback) { + TSmallVec<int> ints; + for (int i = 0; i < 14; ++i) { + ints.push_back(i); + } + + for (size_t i = 0; i < ints.size(); ++i) { + UNIT_ASSERT_EQUAL(ints[i], (int)i); + } + + for (int i = 14; i < 20; ++i) { + ints.push_back(i); + } + + for (size_t i = 0; i < ints.size(); ++i) { + UNIT_ASSERT_EQUAL(ints[i], (int)i); + } + + TSmallVec<int> ints2 = ints; + + for (size_t i = 0; i < ints2.size(); ++i) { + UNIT_ASSERT_EQUAL(ints2[i], (int)i); + } + + TSmallVec<int> ints3; + ints3 = ints2; + + for (size_t i = 0; i < 
ints3.size(); ++i) { + UNIT_ASSERT_EQUAL(ints3[i], (int)i); + } + } + + Y_UNIT_TEST(TestCappedSize) { + TStackVec<int, 8, false> ints; + ints.push_back(1); + ints.push_back(2); + + auto intsCopy = ints; + UNIT_ASSERT_VALUES_EQUAL(intsCopy.capacity(), 8); + + for (int i = 2; i != 8; ++i) { + intsCopy.push_back(i); + } + // Just verify that the program did not crash. + } + + Y_UNIT_TEST(TestCappedSizeWithNotCopyAssignable) { + TStackVec<TNotCopyAssignable, 8, false> values; + values.push_back({1}); + values.push_back({2}); + + auto valuesCopy = values; + UNIT_ASSERT_VALUES_EQUAL(valuesCopy.capacity(), 8); + + for (int i = 2; i != 8; ++i) { + valuesCopy.push_back({i}); + } + // Just verify that the program did not crash. + } + + Y_UNIT_TEST(TestCustomAllocSize) { + constexpr size_t n = 16384; + using TVec = TStackVec<size_t, 1, true, TThickAlloc<size_t, n>>; + UNIT_ASSERT_LT(sizeof(TVec), 1.5 * n); + } + + Y_UNIT_TEST(TestStatefulAlloc) { + size_t count = 0; + TStackVec<size_t, 1, true, TStatefulAlloc<size_t>> vec{{ &count }}; + for (size_t i = 0; i < 5; ++i) { + vec.push_back(1); + } + UNIT_ASSERT_VALUES_EQUAL(count, 3); + } +} diff --git a/library/cpp/containers/stack_vector/ut/ya.make b/library/cpp/containers/stack_vector/ut/ya.make new file mode 100644 index 00000000000..1d704969545 --- /dev/null +++ b/library/cpp/containers/stack_vector/ut/ya.make @@ -0,0 +1,11 @@ +UNITTEST() + +OWNER(ilnurkh) + +SRCDIR(library/cpp/containers/stack_vector) + +SRCS( + stack_vec_ut.cpp +) + +END() diff --git a/library/cpp/containers/stack_vector/ya.make b/library/cpp/containers/stack_vector/ya.make new file mode 100644 index 00000000000..cfb63295ec5 --- /dev/null +++ b/library/cpp/containers/stack_vector/ya.make @@ -0,0 +1,11 @@ +LIBRARY() + +OWNER(ilnurkh) + +SRCS( + stack_vec.cpp +) + +END() + +RECURSE_FOR_TESTS(ut) diff --git a/library/cpp/containers/str_map/str_map.cpp b/library/cpp/containers/str_map/str_map.cpp new file mode 100644 index 00000000000..58c65babda6 --- /dev/null +++ b/library/cpp/containers/str_map/str_map.cpp @@ -0,0 +1 @@ +#include "str_map.h" diff --git a/library/cpp/containers/str_map/str_map.h b/library/cpp/containers/str_map/str_map.h new file mode 100644 index 00000000000..31b00d1b997 --- /dev/null +++ b/library/cpp/containers/str_map/str_map.h @@ -0,0 +1,205 @@ +#pragma once + +#include <util/memory/segmented_string_pool.h> +#include <util/generic/map.h> +#include <util/generic/hash.h> +#include <util/generic/buffer.h> +#include <util/str_stl.h> // less<> and equal_to<> for const char* +#include <utility> +#include <util/generic/noncopyable.h> + +template <class T, class HashFcn = THash<const char*>, class EqualTo = TEqualTo<const char*>, class Alloc = std::allocator<const char*>> +class string_hash; + +template <class T, class HashFcn = THash<const char*>, class EqualTo = TEqualTo<const char*>> +class segmented_string_hash; + +template <class Map> +inline std::pair<typename Map::iterator, bool> +pool_insert(Map* m, const char* key, const typename Map::mapped_type& data, TBuffer& pool) { + std::pair<typename Map::iterator, bool> ins = m->insert(typename Map::value_type(key, data)); + if (ins.second) { // new? + size_t buflen = strlen(key) + 1; // strlen??? + const char* old_pool = pool.Begin(); + pool.Append(key, buflen); + if (pool.Begin() != old_pool) // repoint? 
+ for (typename Map::iterator it = m->begin(); it != m->end(); ++it) + if ((*it).first != key) + const_cast<const char*&>((*it).first) += (pool.Begin() - old_pool); + const_cast<const char*&>((*ins.first).first) = pool.End() - buflen; + } + return ins; +} + +#define HASH_SIZE_DEFAULT 100 +#define AVERAGEWORD_BUF 10 + +template <class T, class HashFcn, class EqualTo, class Alloc> +class string_hash: public THashMap<const char*, T, HashFcn, EqualTo, Alloc> { +protected: + TBuffer pool; + +public: + using yh = THashMap<const char*, T, HashFcn, EqualTo, Alloc>; + using iterator = typename yh::iterator; + using const_iterator = typename yh::const_iterator; + using mapped_type = typename yh::mapped_type; + using size_type = typename yh::size_type; + using pool_size_type = typename yh::size_type; + string_hash() { + pool.Reserve(HASH_SIZE_DEFAULT * AVERAGEWORD_BUF); // reserve here + } + string_hash(size_type hash_size, pool_size_type pool_size) + : THashMap<const char*, T, HashFcn, EqualTo, Alloc>(hash_size) + { + pool.Reserve(pool_size); // reserve here + } + + std::pair<iterator, bool> insert_copy(const char* key, const mapped_type& data) { + return ::pool_insert(this, key, data, pool); + } + + void clear_hash() { + yh::clear(); + pool.Clear(); + } + pool_size_type pool_size() const { + return pool.Size(); + } + + string_hash(const string_hash& sh) + : THashMap<const char*, T, HashFcn, EqualTo, Alloc>() + { + for (const_iterator i = sh.begin(); i != sh.end(); ++i) + insert_copy((*i).first, (*i).second); + } + /* May be faster? + string_hash(const string_hash& sh) + : THashMap<const char *, T, HashFcn, EqualTo>(sh) + { + pool = sh.pool; + size_t delta = pool.begin() - sh.pool.begin(); + for (iterator i = begin(); i != end(); ++i) + (const char*&)(*i).first += delta; + } + */ + string_hash& operator=(const string_hash& sh) { + if (&sh != this) { + clear_hash(); + for (const_iterator i = sh.begin(); i != sh.end(); ++i) + insert_copy((*i).first, (*i).second); + } + return *this; + } + + mapped_type& operator[](const char* key) { + iterator I = yh::find(key); + if (I == yh::end()) + I = insert_copy(key, mapped_type()).first; + return (*I).second; + } +}; + +template <class C, class T, class HashFcn, class EqualTo> +class THashWithSegmentedPoolForKeys: protected THashMap<const C*, T, HashFcn, EqualTo>, TNonCopyable { +protected: + segmented_pool<C> pool; + +public: + using yh = THashMap<const C*, T, HashFcn, EqualTo>; + using iterator = typename yh::iterator; + using const_iterator = typename yh::const_iterator; + using mapped_type = typename yh::mapped_type; + using size_type = typename yh::size_type; + using key_type = typename yh::key_type; + using value_type = typename yh::value_type; + + THashWithSegmentedPoolForKeys(size_type hash_size = HASH_SIZE_DEFAULT, size_t segsize = HASH_SIZE_DEFAULT * AVERAGEWORD_BUF, bool afs = false) + : yh(hash_size) + , pool(segsize) + { + if (afs) + pool.alloc_first_seg(); + } + + std::pair<iterator, bool> insert_copy(const C* key, size_t keylen, const mapped_type& data) { + std::pair<iterator, bool> ins = this->insert(value_type(key, data)); + if (ins.second) // new? 
+ (const C*&)(*ins.first).first = pool.append(key, keylen); + return ins; + } + + void clear_hash() { + yh::clear(); + pool.restart(); + } + + size_t pool_size() const { + return pool.size(); + } + + size_t size() const { + return yh::size(); + } + + bool empty() const { + return yh::empty(); + } + + iterator begin() { + return yh::begin(); + } + + iterator end() { + return yh::end(); + } + + const_iterator begin() const { + return yh::begin(); + } + + const_iterator end() const { + return yh::end(); + } + + iterator find(const key_type& key) { + return yh::find(key); + } + + const_iterator find(const key_type& key) const { + return yh::find(key); + } + + const yh& get_THashMap() const { + return static_cast<const yh&>(*this); + } +}; + +template <class T, class HashFcn, class EqualTo> +class segmented_string_hash: public THashWithSegmentedPoolForKeys<char, T, HashFcn, EqualTo> { +public: + using Base = THashWithSegmentedPoolForKeys<char, T, HashFcn, EqualTo>; + using iterator = typename Base::iterator; + using const_iterator = typename Base::const_iterator; + using mapped_type = typename Base::mapped_type; + using size_type = typename Base::size_type; + using key_type = typename Base::key_type; + using value_type = typename Base::value_type; + +public: + segmented_string_hash(size_type hash_size = HASH_SIZE_DEFAULT, size_t segsize = HASH_SIZE_DEFAULT * AVERAGEWORD_BUF, bool afs = false) + : Base(hash_size, segsize, afs) + { + } + + std::pair<iterator, bool> insert_copy(const char* key, const mapped_type& data) { + return Base::insert_copy(key, strlen(key) + 1, data); + } + + mapped_type& operator[](const char* key) { + iterator I = Base::find(key); + if (I == Base::end()) + I = insert_copy(key, mapped_type()).first; + return (*I).second; + } +}; diff --git a/library/cpp/containers/str_map/ya.make b/library/cpp/containers/str_map/ya.make new file mode 100644 index 00000000000..b834159cda8 --- /dev/null +++ b/library/cpp/containers/str_map/ya.make @@ -0,0 +1,9 @@ +LIBRARY() + +OWNER(pg) + +SRCS( + str_map.cpp +) + +END() diff --git a/library/cpp/containers/top_keeper/README.md b/library/cpp/containers/top_keeper/README.md new file mode 100644 index 00000000000..f160fb1c015 --- /dev/null +++ b/library/cpp/containers/top_keeper/README.md @@ -0,0 +1,26 @@ +TopKeeper - ΡΡΡΡΠΊΡΡΡΠ° Π΄Π°Π½Π½ΡΡ
for maintaining a "top M from stream" selection
+Used to select the largest / smallest elements in a single pass (useful for filtering)
+
+Suppose the input stream consists of N elements, of which the M with the largest values have to be filtered out.
+Algorithm (for the "top max M" case):
+1) Allocate a vector of size 2 * M
+2a) If the vector is less than half full - append the next element and update the minimum
+2b) Otherwise - compare the new element with the current minimum; if it is larger, append it to the vector without updating the minimum, otherwise discard it
+3) Once the vector is full - run a Partition Sort with the M-th element as the separator
+4) As a result, every value in the left half is larger than any value in the right half, and position M holds exactly the M-th element of the sorted sequence
+5) Drop the elements of the right half
+
+Complexity:
+Per M insertions there is one PartitionSort (amortized cost estimate - O(M)) plus the removal of M elements, which gives O(1) operations per insertion on average. For the TLimitedHeap algorithm (library/cpp/containers/limited_heap) the corresponding estimate is O(log(M)).
+
+Tests:
+On random data streams TopKeeper and LimitedHeap perform the same number of comparisons (this is because the former updates its minimum once per M insertions, while the latter updates it on every insertion). The worst case for LimitedHeap is a sorted sequence (every element gets inserted); on such a stream of 2 000 000 000 ints TopKeeper wins by a large margin.
+
+Applicability:
+It is always worth using instead of LimitedHeap (it is never worse, and in the worst case it is better).
+Limitation - it does not support the "interleaved insertions / extractions of elements" usage pattern (Partition Sorts would run too often).
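A minimal usage sketch (an illustration assumed from the TTopKeeper interface and unit tests shown later in this diff, not part of the original README): keep the three largest values of a stream and read them back smallest-first via GetNext()/Pop().

    #include <library/cpp/containers/top_keeper/top_keeper.h>
    #include <util/stream/output.h>

    int main() {
        // The default comparator is TGreater<int>, so the keeper retains the largest values.
        TTopKeeper<int> top(3);

        for (int x : {5, 1, 9, 7, 3, 8}) {
            top.Insert(x);
        }

        top.Finalize(); // optional: GetNext()/Pop() below would finalize implicitly

        // The kept top-3 elements come out smallest-first: 7, 8, 9.
        while (!top.IsEmpty()) {
            Cout << top.GetNext() << Endl;
            top.Pop();
        }
        return 0;
    }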
ΠΎΠ΄ΠΈΡΡ Partiotion SortΡ) +ΠΠ»Ρ ΡΡΠΎΠ³ΠΎ, ΠΊΠΎΠ³Π΄Π° Π΄ΠΎΠ±Π°Π²Π»Π΅Π½ΠΈΠ΅ ΡΠ»Π΅ΠΌΠ΅Π½ΡΠΎΠ² Π·Π°ΠΊΠΎΠ½ΡΠ΅Π½ΠΎ, Π΄ΠΎΠ»ΠΆΠ΅Π½ Π²ΡΠ·ΡΠ²Π°ΡΡΡΡ ΠΌΠ΅ΡΠΎΠ΄ Finalize(). ΠΠ»Ρ ΡΠΏΡΠΎΡΠ΅Π½ΠΈΡ ΠΈΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°Π½ΠΈΡ Π΄ΠΎΠ±Π°Π²Π»Π΅Π½ Π°Π²ΡΠΎΠΌΠ°ΡΠΈΡΠ΅ΡΠΊΠΈΠΉ Finalize() Π½Π° GetNext() / Pop(). Π’Π΅ΠΌ Π½Π΅ ΠΌΠ΅Π½Π΅Π΅ ΡΠ²Π½ΡΠΉ Π²ΡΠ·ΠΎΠ² Finalize() ΠΏΠΎ-ΠΏΡΠ΅ΠΆΠ½Π΅ΠΌΡ Π²ΠΎΠ·ΠΌΠΎΠΆΠ΅Π½ - ΡΠ°ΠΊ ΠΌΠΎΠΆΠ½ΠΎ ΠΊΠΎΠ½ΡΡΠΎΠ»Π»ΠΈΡΠΎΠ²Π°ΡΡ ΠΌΠΎΠΌΠ΅Π½Ρ Π²ΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΡ ΡΡΡΠ΄ΠΎΡΠΌΠΊΠΎΠΉ ΠΎΠΏΠ΅ΡΠ°ΡΠΈΠΈ NthElement(). ΠΠΎΡΠ»Π΅ ΡΠΎΠ³ΠΎ, ΠΊΠ°ΠΊ Π²ΡΠ΅ ΡΠ»Π΅ΠΌΠ΅Π½ΡΡ ΠΈΠ·Π²Π»Π΅ΡΠ΅Π½Ρ TopKeeper ΠΌΠΎΠΆΠ½ΠΎ ΠΏΠ΅ΡΠ΅ΠΈΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°ΡΡ (Π΄Π»Ρ ΡΡΠΎΠ³ΠΎ ΠΆΠ΅ ΡΠ»ΡΠΆΠΈΡ ΠΌΠ΅ΡΠΎΠ΄ Reset()). +Π ΡΠΈΡΡΠ°ΡΠΈΠΈ ΠΊΠΎΠ³Π΄Π° Π½ΡΠΆΠ½Ρ ΡΠ΅ΡΠ΅Π΄ΡΡΡΠΈΠ΅ΡΡ Π΄ΠΎΠ±Π°Π²Π»Π΅Π½ΠΈΡ / ΠΈΠ·Π²Π»Π΅ΡΠ΅Π½ΠΈΡ - ΠΈΡΠΏΠΎΠ»ΡΠ·ΡΠΉΡΠ΅ LimitedHeap + +ΠΡΠΈΠΌΠ΅ΡΡ ΠΈΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°Π½ΠΈΡ: +library/cpp/containers/top_keeper/ut diff --git a/library/cpp/containers/top_keeper/top_keeper.cpp b/library/cpp/containers/top_keeper/top_keeper.cpp new file mode 100644 index 00000000000..29544bf227c --- /dev/null +++ b/library/cpp/containers/top_keeper/top_keeper.cpp @@ -0,0 +1 @@ +#include "top_keeper.h" diff --git a/library/cpp/containers/top_keeper/top_keeper.h b/library/cpp/containers/top_keeper/top_keeper.h new file mode 100644 index 00000000000..2f282b5a9e1 --- /dev/null +++ b/library/cpp/containers/top_keeper/top_keeper.h @@ -0,0 +1,256 @@ +#pragma once + +#include <util/generic/vector.h> +#include <util/generic/algorithm.h> +#include <util/generic/maybe.h> +#include <util/str_stl.h> + +template <class T, class TComparator = TGreater<T>, bool sort = true, class Alloc = std::allocator<T>> +class TTopKeeper { +private: + class TVectorWithMin { + private: + TVector<T, Alloc> Internal; + size_t HalfMaxSize; + TComparator Comparer; + size_t MinElementIndex; + + private: + void Reserve() { + Internal.reserve(2 * HalfMaxSize); + } + + template <class UT> + bool Insert(UT&& value) noexcept { + if (Y_UNLIKELY(0 == HalfMaxSize)) { + return false; + } + + if (Internal.size() < HalfMaxSize) { + if (Internal.empty() || Comparer(Internal[MinElementIndex], value)) { + MinElementIndex = Internal.size(); + Internal.push_back(std::forward<UT>(value)); + return true; + } + } else if (!Comparer(value, Internal[MinElementIndex])) { + return false; + } + + Internal.push_back(std::forward<UT>(value)); + + if (Internal.size() == (HalfMaxSize << 1)) { + Partition(); + } + + return true; + } + + public: + using value_type = T; + + TVectorWithMin(const size_t halfMaxSize, const TComparator& comp) + : HalfMaxSize(halfMaxSize) + , Comparer(comp) + { + Reserve(); + } + + template <class TAllocParam> + TVectorWithMin(const size_t halfMaxSize, const TComparator& comp, TAllocParam&& param) + : Internal(std::forward<TAllocParam>(param)) + , HalfMaxSize(halfMaxSize) + , Comparer(comp) + { + Reserve(); + } + + void SortAccending() { + Sort(Internal.begin(), Internal.end(), Comparer); + } + + void Partition() { + if (Y_UNLIKELY(HalfMaxSize == 0)) { + return; + } + if (Y_LIKELY(Internal.size() >= HalfMaxSize)) { + NthElement(Internal.begin(), Internal.begin() + HalfMaxSize - 1, Internal.end(), Comparer); + Internal.erase(Internal.begin() + HalfMaxSize, Internal.end()); + + //we should update MinElementIndex cause we just altered Internal + MinElementIndex = HalfMaxSize - 1; + } + } + + bool Push(const T& value) { + return Insert(value); + } + + bool Push(T&& value) { + return Insert(std::move(value)); + } + + template <class... TArgs> + bool Emplace(TArgs&&... 
args) { + return Insert(T(std::forward<TArgs>(args)...)); // TODO: make it "real" emplace, not that fake one + } + + void SetMaxSize(size_t newHalfMaxSize) { + HalfMaxSize = newHalfMaxSize; + Reserve(); + Partition(); + } + + size_t GetSize() const { + return Internal.size(); + } + + const auto& GetInternal() const { + return Internal; + } + + auto Extract() { + using std::swap; + + decltype(Internal) values; + swap(Internal, values); + Reset(); + return values; + } + + const T& Back() const { + return Internal.back(); + } + + void Pop() { + Internal.pop_back(); + } + + void Reset() { + Internal.clear(); + //MinElementIndex will reset itself when we start adding new values + } + }; + + void CheckNotFinalized() { + Y_ENSURE(!Finalized, "Cannot insert after finalizing (Pop() / GetNext() / Finalize())! " + "Use TLimitedHeap for this scenario"); + } + + size_t MaxSize; + const TComparator Comparer; + TVectorWithMin Internal; + bool Finalized; + +public: + TTopKeeper() + : MaxSize(0) + , Comparer() + , Internal(0, Comparer) + , Finalized(false) + { + } + + TTopKeeper(size_t maxSize, const TComparator& comp = TComparator()) + : MaxSize(maxSize) + , Comparer(comp) + , Internal(maxSize, comp) + , Finalized(false) + { + } + + template <class TAllocParam> + TTopKeeper(size_t maxSize, const TComparator& comp, TAllocParam&& param) + : MaxSize(maxSize) + , Comparer(comp) + , Internal(maxSize, comp, std::forward<TAllocParam>(param)) + , Finalized(false) + { + } + + void Finalize() { + if (Y_LIKELY(Finalized)) { + return; + } + Internal.Partition(); + if (sort) { + Internal.SortAccending(); + } + Finalized = true; + } + + const T& GetNext() { + Y_ENSURE(!IsEmpty(), "Trying GetNext from empty heap!"); + Finalize(); + return Internal.Back(); + } + + void Pop() { + Y_ENSURE(!IsEmpty(), "Trying Pop from empty heap!"); + Finalize(); + Internal.Pop(); + if (IsEmpty()) { + Reset(); + } + } + + T ExtractOne() { + Y_ENSURE(!IsEmpty(), "Trying ExtractOne from empty heap!"); + Finalize(); + auto value = std::move(Internal.Back()); + Internal.Pop(); + if (IsEmpty()) { + Reset(); + } + return value; + } + + auto Extract() { + Finalize(); + return Internal.Extract(); + } + + bool Insert(const T& value) { + CheckNotFinalized(); + return Internal.Push(value); + } + + bool Insert(T&& value) { + CheckNotFinalized(); + return Internal.Push(std::move(value)); + } + + template <class... TArgs> + bool Emplace(TArgs&&... args) { + CheckNotFinalized(); + return Internal.Emplace(std::forward<TArgs>(args)...); + } + + const auto& GetInternal() { + Finalize(); + return Internal.GetInternal(); + } + + bool IsEmpty() const { + return Internal.GetSize() == 0; + } + + size_t GetSize() const { + return Min(Internal.GetSize(), MaxSize); + } + + size_t GetMaxSize() const { + return MaxSize; + } + + void SetMaxSize(size_t newMaxSize) { + Y_ENSURE(!Finalized, "Cannot resize after finalizing (Pop() / GetNext() / Finalize())! " + "Use TLimitedHeap for this scenario"); + MaxSize = newMaxSize; + Internal.SetMaxSize(newMaxSize); + } + + void Reset() { + Internal.Reset(); + Finalized = false; + } +}; diff --git a/library/cpp/containers/top_keeper/top_keeper/README.md b/library/cpp/containers/top_keeper/top_keeper/README.md new file mode 100644 index 00000000000..f160fb1c015 --- /dev/null +++ b/library/cpp/containers/top_keeper/top_keeper/README.md @@ -0,0 +1,26 @@ +TopKeeper - ΡΡΡΡΠΊΡΡΡΠ° Π΄Π°Π½Π½ΡΡ
for maintaining a "top M from stream" selection
+Used to select the largest / smallest elements in a single pass (useful for filtering)
+
+Suppose the input stream consists of N elements, of which the M with the largest values have to be filtered out.
+Algorithm (for the "top max M" case):
+1) Allocate a vector of size 2 * M
+2a) If the vector is less than half full - append the next element and update the minimum
+2b) Otherwise - compare the new element with the current minimum; if it is larger, append it to the vector without updating the minimum, otherwise discard it
+3) Once the vector is full - run a Partition Sort with the M-th element as the separator
+4) As a result, every value in the left half is larger than any value in the right half, and position M holds exactly the M-th element of the sorted sequence
+5) Drop the elements of the right half
+
+Complexity:
+Per M insertions there is one PartitionSort (amortized cost estimate - O(M)) plus the removal of M elements, which gives O(1) operations per insertion on average. For the TLimitedHeap algorithm (library/cpp/containers/limited_heap) the corresponding estimate is O(log(M)).
+
+Tests:
+On random data streams TopKeeper and LimitedHeap perform the same number of comparisons (this is because the former updates its minimum once per M insertions, while the latter updates it on every insertion). The worst case for LimitedHeap is a sorted sequence (every element gets inserted); on such a stream of 2 000 000 000 ints TopKeeper wins by a large margin.
+
+Applicability:
+It is always worth using instead of LimitedHeap (it is never worse, and in the worst case it is better).
+Limitation - it does not support the "interleaved insertions / extractions of elements" usage pattern (Partition Sorts would run too often).
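Another assumed sketch, drawn only from the TTopKeeper interface above (the helper name TopTwoByFirst is made up for illustration): Emplace() builds elements in place, and Extract() finalizes the keeper and moves the retained values out in one call.

    #include <library/cpp/containers/top_keeper/top_keeper.h>
    #include <util/generic/vector.h>
    #include <utility>

    // Hypothetical helper: returns the two largest pairs (lexicographic comparison via the default TGreater).
    TVector<std::pair<int, int>> TopTwoByFirst(const TVector<std::pair<int, int>>& input) {
        TTopKeeper<std::pair<int, int>> keeper(2);
        for (const auto& p : input) {
            keeper.Emplace(p.first, p.second); // forwards the arguments to the pair constructor
        }
        // Extract() finalizes and moves out the kept elements, leaving the keeper empty;
        // Reset() would be needed before inserting into it again.
        return keeper.Extract();
    }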
ΠΎΠ΄ΠΈΡΡ Partiotion SortΡ) +ΠΠ»Ρ ΡΡΠΎΠ³ΠΎ, ΠΊΠΎΠ³Π΄Π° Π΄ΠΎΠ±Π°Π²Π»Π΅Π½ΠΈΠ΅ ΡΠ»Π΅ΠΌΠ΅Π½ΡΠΎΠ² Π·Π°ΠΊΠΎΠ½ΡΠ΅Π½ΠΎ, Π΄ΠΎΠ»ΠΆΠ΅Π½ Π²ΡΠ·ΡΠ²Π°ΡΡΡΡ ΠΌΠ΅ΡΠΎΠ΄ Finalize(). ΠΠ»Ρ ΡΠΏΡΠΎΡΠ΅Π½ΠΈΡ ΠΈΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°Π½ΠΈΡ Π΄ΠΎΠ±Π°Π²Π»Π΅Π½ Π°Π²ΡΠΎΠΌΠ°ΡΠΈΡΠ΅ΡΠΊΠΈΠΉ Finalize() Π½Π° GetNext() / Pop(). Π’Π΅ΠΌ Π½Π΅ ΠΌΠ΅Π½Π΅Π΅ ΡΠ²Π½ΡΠΉ Π²ΡΠ·ΠΎΠ² Finalize() ΠΏΠΎ-ΠΏΡΠ΅ΠΆΠ½Π΅ΠΌΡ Π²ΠΎΠ·ΠΌΠΎΠΆΠ΅Π½ - ΡΠ°ΠΊ ΠΌΠΎΠΆΠ½ΠΎ ΠΊΠΎΠ½ΡΡΠΎΠ»Π»ΠΈΡΠΎΠ²Π°ΡΡ ΠΌΠΎΠΌΠ΅Π½Ρ Π²ΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΡ ΡΡΡΠ΄ΠΎΡΠΌΠΊΠΎΠΉ ΠΎΠΏΠ΅ΡΠ°ΡΠΈΠΈ NthElement(). ΠΠΎΡΠ»Π΅ ΡΠΎΠ³ΠΎ, ΠΊΠ°ΠΊ Π²ΡΠ΅ ΡΠ»Π΅ΠΌΠ΅Π½ΡΡ ΠΈΠ·Π²Π»Π΅ΡΠ΅Π½Ρ TopKeeper ΠΌΠΎΠΆΠ½ΠΎ ΠΏΠ΅ΡΠ΅ΠΈΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°ΡΡ (Π΄Π»Ρ ΡΡΠΎΠ³ΠΎ ΠΆΠ΅ ΡΠ»ΡΠΆΠΈΡ ΠΌΠ΅ΡΠΎΠ΄ Reset()). +Π ΡΠΈΡΡΠ°ΡΠΈΠΈ ΠΊΠΎΠ³Π΄Π° Π½ΡΠΆΠ½Ρ ΡΠ΅ΡΠ΅Π΄ΡΡΡΠΈΠ΅ΡΡ Π΄ΠΎΠ±Π°Π²Π»Π΅Π½ΠΈΡ / ΠΈΠ·Π²Π»Π΅ΡΠ΅Π½ΠΈΡ - ΠΈΡΠΏΠΎΠ»ΡΠ·ΡΠΉΡΠ΅ LimitedHeap + +ΠΡΠΈΠΌΠ΅ΡΡ ΠΈΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°Π½ΠΈΡ: +library/cpp/containers/top_keeper/ut diff --git a/library/cpp/containers/top_keeper/top_keeper/top_keeper.cpp b/library/cpp/containers/top_keeper/top_keeper/top_keeper.cpp new file mode 100644 index 00000000000..29544bf227c --- /dev/null +++ b/library/cpp/containers/top_keeper/top_keeper/top_keeper.cpp @@ -0,0 +1 @@ +#include "top_keeper.h" diff --git a/library/cpp/containers/top_keeper/top_keeper/top_keeper.h b/library/cpp/containers/top_keeper/top_keeper/top_keeper.h new file mode 100644 index 00000000000..2f282b5a9e1 --- /dev/null +++ b/library/cpp/containers/top_keeper/top_keeper/top_keeper.h @@ -0,0 +1,256 @@ +#pragma once + +#include <util/generic/vector.h> +#include <util/generic/algorithm.h> +#include <util/generic/maybe.h> +#include <util/str_stl.h> + +template <class T, class TComparator = TGreater<T>, bool sort = true, class Alloc = std::allocator<T>> +class TTopKeeper { +private: + class TVectorWithMin { + private: + TVector<T, Alloc> Internal; + size_t HalfMaxSize; + TComparator Comparer; + size_t MinElementIndex; + + private: + void Reserve() { + Internal.reserve(2 * HalfMaxSize); + } + + template <class UT> + bool Insert(UT&& value) noexcept { + if (Y_UNLIKELY(0 == HalfMaxSize)) { + return false; + } + + if (Internal.size() < HalfMaxSize) { + if (Internal.empty() || Comparer(Internal[MinElementIndex], value)) { + MinElementIndex = Internal.size(); + Internal.push_back(std::forward<UT>(value)); + return true; + } + } else if (!Comparer(value, Internal[MinElementIndex])) { + return false; + } + + Internal.push_back(std::forward<UT>(value)); + + if (Internal.size() == (HalfMaxSize << 1)) { + Partition(); + } + + return true; + } + + public: + using value_type = T; + + TVectorWithMin(const size_t halfMaxSize, const TComparator& comp) + : HalfMaxSize(halfMaxSize) + , Comparer(comp) + { + Reserve(); + } + + template <class TAllocParam> + TVectorWithMin(const size_t halfMaxSize, const TComparator& comp, TAllocParam&& param) + : Internal(std::forward<TAllocParam>(param)) + , HalfMaxSize(halfMaxSize) + , Comparer(comp) + { + Reserve(); + } + + void SortAccending() { + Sort(Internal.begin(), Internal.end(), Comparer); + } + + void Partition() { + if (Y_UNLIKELY(HalfMaxSize == 0)) { + return; + } + if (Y_LIKELY(Internal.size() >= HalfMaxSize)) { + NthElement(Internal.begin(), Internal.begin() + HalfMaxSize - 1, Internal.end(), Comparer); + Internal.erase(Internal.begin() + HalfMaxSize, Internal.end()); + + //we should update MinElementIndex cause we just altered Internal + MinElementIndex = HalfMaxSize - 1; + } + } + + bool Push(const T& value) { + return Insert(value); + } + + bool Push(T&& value) { + return Insert(std::move(value)); + } + + template <class... 
TArgs> + bool Emplace(TArgs&&... args) { + return Insert(T(std::forward<TArgs>(args)...)); // TODO: make it "real" emplace, not that fake one + } + + void SetMaxSize(size_t newHalfMaxSize) { + HalfMaxSize = newHalfMaxSize; + Reserve(); + Partition(); + } + + size_t GetSize() const { + return Internal.size(); + } + + const auto& GetInternal() const { + return Internal; + } + + auto Extract() { + using std::swap; + + decltype(Internal) values; + swap(Internal, values); + Reset(); + return values; + } + + const T& Back() const { + return Internal.back(); + } + + void Pop() { + Internal.pop_back(); + } + + void Reset() { + Internal.clear(); + //MinElementIndex will reset itself when we start adding new values + } + }; + + void CheckNotFinalized() { + Y_ENSURE(!Finalized, "Cannot insert after finalizing (Pop() / GetNext() / Finalize())! " + "Use TLimitedHeap for this scenario"); + } + + size_t MaxSize; + const TComparator Comparer; + TVectorWithMin Internal; + bool Finalized; + +public: + TTopKeeper() + : MaxSize(0) + , Comparer() + , Internal(0, Comparer) + , Finalized(false) + { + } + + TTopKeeper(size_t maxSize, const TComparator& comp = TComparator()) + : MaxSize(maxSize) + , Comparer(comp) + , Internal(maxSize, comp) + , Finalized(false) + { + } + + template <class TAllocParam> + TTopKeeper(size_t maxSize, const TComparator& comp, TAllocParam&& param) + : MaxSize(maxSize) + , Comparer(comp) + , Internal(maxSize, comp, std::forward<TAllocParam>(param)) + , Finalized(false) + { + } + + void Finalize() { + if (Y_LIKELY(Finalized)) { + return; + } + Internal.Partition(); + if (sort) { + Internal.SortAccending(); + } + Finalized = true; + } + + const T& GetNext() { + Y_ENSURE(!IsEmpty(), "Trying GetNext from empty heap!"); + Finalize(); + return Internal.Back(); + } + + void Pop() { + Y_ENSURE(!IsEmpty(), "Trying Pop from empty heap!"); + Finalize(); + Internal.Pop(); + if (IsEmpty()) { + Reset(); + } + } + + T ExtractOne() { + Y_ENSURE(!IsEmpty(), "Trying ExtractOne from empty heap!"); + Finalize(); + auto value = std::move(Internal.Back()); + Internal.Pop(); + if (IsEmpty()) { + Reset(); + } + return value; + } + + auto Extract() { + Finalize(); + return Internal.Extract(); + } + + bool Insert(const T& value) { + CheckNotFinalized(); + return Internal.Push(value); + } + + bool Insert(T&& value) { + CheckNotFinalized(); + return Internal.Push(std::move(value)); + } + + template <class... TArgs> + bool Emplace(TArgs&&... args) { + CheckNotFinalized(); + return Internal.Emplace(std::forward<TArgs>(args)...); + } + + const auto& GetInternal() { + Finalize(); + return Internal.GetInternal(); + } + + bool IsEmpty() const { + return Internal.GetSize() == 0; + } + + size_t GetSize() const { + return Min(Internal.GetSize(), MaxSize); + } + + size_t GetMaxSize() const { + return MaxSize; + } + + void SetMaxSize(size_t newMaxSize) { + Y_ENSURE(!Finalized, "Cannot resize after finalizing (Pop() / GetNext() / Finalize())! 
" + "Use TLimitedHeap for this scenario"); + MaxSize = newMaxSize; + Internal.SetMaxSize(newMaxSize); + } + + void Reset() { + Internal.Reset(); + Finalized = false; + } +}; diff --git a/library/cpp/containers/top_keeper/top_keeper/ut/top_keeper_ut.cpp b/library/cpp/containers/top_keeper/top_keeper/ut/top_keeper_ut.cpp new file mode 100644 index 00000000000..a938279025d --- /dev/null +++ b/library/cpp/containers/top_keeper/top_keeper/ut/top_keeper_ut.cpp @@ -0,0 +1,222 @@ +#include <library/cpp/containers/limited_heap/limited_heap.h> +#include <library/cpp/containers/top_keeper/top_keeper.h> +#include <library/cpp/testing/unittest/registar.h> +#include <util/random/random.h> + +static ui32 seed = 3; +ui32 Rnd() { + seed = seed * 5 + 1; + return seed; +} + +/* + * Tests for TTopKeeper + */ +Y_UNIT_TEST_SUITE(TTopKeeperTest) { + // Tests correctness on usual examples + Y_UNIT_TEST(CorrectnessTest) { + int m = 20000; + + TLimitedHeap<std::pair<int, int>> h1(m); + TTopKeeper<std::pair<int, int>> h2(m); + + int n = 20000000; + while (n--) { + int r = int(Rnd()); + + h1.Insert({r, -r}); + h2.Emplace(r, -r); + } + + h2.Finalize(); + + UNIT_ASSERT_EQUAL(h1.GetSize(), h2.GetSize()); + + while (!h1.IsEmpty()) { + UNIT_ASSERT_EQUAL(h1.GetMin(), h2.GetNext()); + h1.PopMin(); + h2.Pop(); + } + } + + // Tests on zero-size correctness + Y_UNIT_TEST(ZeroSizeCorrectnes) { + TTopKeeper<int> h(0); + for (int i = 0; i < 100; ++i) { + h.Insert(i % 10 + i / 10); + } + h.Finalize(); + UNIT_ASSERT(h.IsEmpty()); + } + + // Tests SetMaxSize behaviour + Y_UNIT_TEST(SetMaxSizeTest) { + int m = 20000; + TLimitedHeap<int> h1(m); + TTopKeeper<int> h2(m); + + int n = 20000000; + while (n--) { + int r = int(Rnd()); + + h1.Insert(r); + h2.Insert(r); + } + + h1.SetMaxSize(m / 3); + h2.SetMaxSize(m / 3); + h2.Finalize(); + + UNIT_ASSERT_EQUAL(h1.GetSize(), h2.GetSize()); + + while (!h1.IsEmpty()) { + UNIT_ASSERT_EQUAL(h1.GetMin(), h2.GetNext()); + h1.PopMin(); + h2.Pop(); + } + } + + // Tests reuse behavior + Y_UNIT_TEST(ReuseTest) { + int m = 20000; + TLimitedHeap<int> h1(m); + TTopKeeper<int> h2(m); + + int n = 20000000; + while (n--) { + int r = int(Rnd()); + + h1.Insert(r); + h2.Insert(r); + } + + UNIT_ASSERT_EQUAL(h1.GetSize(), h2.GetSize()); + + while (!h1.IsEmpty()) { + UNIT_ASSERT_EQUAL(h1.GetMin(), h2.GetNext()); + h1.PopMin(); + h2.Pop(); + } + + n = 20000000; + while (n--) { + int r = int(Rnd()); + + h1.Insert(r); + h2.Insert(r); + } + + UNIT_ASSERT_EQUAL(h1.GetSize(), h2.GetSize()); + + while (!h1.IsEmpty()) { + UNIT_ASSERT_EQUAL(h1.GetMin(), h2.GetNext()); + h1.PopMin(); + h2.Pop(); + } + } + + // Tests reset behavior + Y_UNIT_TEST(ResetTest) { + int m = 20000; + TLimitedHeap<int> h1(m); + TTopKeeper<int> h2(m); + + int n = 20000000; + while (n--) { + int r = int(Rnd()); + + h1.Insert(r); + h2.Insert(r); + } + + UNIT_ASSERT_EQUAL(h1.GetSize(), h2.GetSize()); + + for (int i = 0; i < m / 2; ++i) { + UNIT_ASSERT_EQUAL(h1.GetMin(), h2.GetNext()); + h1.PopMin(); + h2.Pop(); + } + + h2.Reset(); + while (!h1.IsEmpty()) { + h1.PopMin(); + } + + n = 20000000; + while (n--) { + int r = int(Rnd()); + + h1.Insert(r); + h2.Insert(r); + } + + UNIT_ASSERT_EQUAL(h1.GetSize(), h2.GetSize()); + + while (!h1.IsEmpty()) { + UNIT_ASSERT_EQUAL(h1.GetMin(), h2.GetNext()); + h1.PopMin(); + h2.Pop(); + } + } + + Y_UNIT_TEST(PreRegressionTest) { + typedef std::pair<float, unsigned int> TElementType; + + const size_t randomTriesCount = 128; + for (size_t i1 = 0; i1 < randomTriesCount; ++i1) { + const size_t desiredElementsCount = 
RandomNumber<size_t>(5) + 1; + TLimitedHeap<TElementType> h1(desiredElementsCount); + TTopKeeper<TElementType> h2(desiredElementsCount); + + const size_t elementsToInsert = RandomNumber<size_t>(10) + desiredElementsCount; + UNIT_ASSERT_C(desiredElementsCount <= elementsToInsert, "Test internal invariant is broken"); + for (size_t i2 = 0; i2 < elementsToInsert; ++i2) { + const auto f = RandomNumber<float>(); + const auto id = RandomNumber<unsigned int>(); + + h1.Insert(TElementType(f, id)); + h2.Insert(TElementType(f, id)); + } + + h2.Finalize(); + + //we inserted enough elements to guarantee this outcome + UNIT_ASSERT_EQUAL(h1.GetSize(), desiredElementsCount); + UNIT_ASSERT_EQUAL(h2.GetSize(), desiredElementsCount); + + const auto n = h2.GetSize(); + for (size_t i3 = 0; i3 < n; ++i3) { + UNIT_ASSERT_EQUAL(h1.GetMin(), h2.GetNext()); + h1.PopMin(); + h2.Pop(); + } + } + } + + Y_UNIT_TEST(CopyKeeperRegressionCase) { + using TKeeper = TTopKeeper<float>; + TVector<TKeeper> v(2, TKeeper(200)); + auto& k = v[1]; + for (size_t i = 0; i < 100; ++i) { + k.Insert(RandomNumber<float>()); + } + k.Finalize(); + } + + Y_UNIT_TEST(ExtractTest) { + TTopKeeper<size_t> keeper(100); + for (size_t i = 0; i < 100; ++i) { + keeper.Insert(i); + } + + auto values = keeper.Extract(); + UNIT_ASSERT_EQUAL(values.size(), 100); + Sort(values); + + for (size_t i = 0; i < 100; ++i) { + UNIT_ASSERT_EQUAL(values[i], i); + } + + UNIT_ASSERT(keeper.IsEmpty()); + } +} diff --git a/library/cpp/containers/top_keeper/top_keeper/ut/ya.make b/library/cpp/containers/top_keeper/top_keeper/ut/ya.make new file mode 100644 index 00000000000..8553389e170 --- /dev/null +++ b/library/cpp/containers/top_keeper/top_keeper/ut/ya.make @@ -0,0 +1,12 @@ +UNITTEST_FOR(library/cpp/containers/top_keeper) + +OWNER( + mbusel + rmplstiltskin +) + +SRCS( + top_keeper_ut.cpp +) + +END() diff --git a/library/cpp/containers/top_keeper/top_keeper/ya.make b/library/cpp/containers/top_keeper/top_keeper/ya.make new file mode 100644 index 00000000000..79be94ae2bf --- /dev/null +++ b/library/cpp/containers/top_keeper/top_keeper/ya.make @@ -0,0 +1,9 @@ +LIBRARY() + +OWNER(mbusel) + +SRCS( + top_keeper.cpp +) + +END() diff --git a/library/cpp/containers/top_keeper/ut/top_keeper_ut.cpp b/library/cpp/containers/top_keeper/ut/top_keeper_ut.cpp new file mode 100644 index 00000000000..a938279025d --- /dev/null +++ b/library/cpp/containers/top_keeper/ut/top_keeper_ut.cpp @@ -0,0 +1,222 @@ +#include <library/cpp/containers/limited_heap/limited_heap.h> +#include <library/cpp/containers/top_keeper/top_keeper.h> +#include <library/cpp/testing/unittest/registar.h> +#include <util/random/random.h> + +static ui32 seed = 3; +ui32 Rnd() { + seed = seed * 5 + 1; + return seed; +} + +/* + * Tests for TTopKeeper + */ +Y_UNIT_TEST_SUITE(TTopKeeperTest) { + // Tests correctness on usual examples + Y_UNIT_TEST(CorrectnessTest) { + int m = 20000; + + TLimitedHeap<std::pair<int, int>> h1(m); + TTopKeeper<std::pair<int, int>> h2(m); + + int n = 20000000; + while (n--) { + int r = int(Rnd()); + + h1.Insert({r, -r}); + h2.Emplace(r, -r); + } + + h2.Finalize(); + + UNIT_ASSERT_EQUAL(h1.GetSize(), h2.GetSize()); + + while (!h1.IsEmpty()) { + UNIT_ASSERT_EQUAL(h1.GetMin(), h2.GetNext()); + h1.PopMin(); + h2.Pop(); + } + } + + // Tests on zero-size correctness + Y_UNIT_TEST(ZeroSizeCorrectnes) { + TTopKeeper<int> h(0); + for (int i = 0; i < 100; ++i) { + h.Insert(i % 10 + i / 10); + } + h.Finalize(); + UNIT_ASSERT(h.IsEmpty()); + } + + // Tests SetMaxSize behaviour + 
Y_UNIT_TEST(SetMaxSizeTest) { + int m = 20000; + TLimitedHeap<int> h1(m); + TTopKeeper<int> h2(m); + + int n = 20000000; + while (n--) { + int r = int(Rnd()); + + h1.Insert(r); + h2.Insert(r); + } + + h1.SetMaxSize(m / 3); + h2.SetMaxSize(m / 3); + h2.Finalize(); + + UNIT_ASSERT_EQUAL(h1.GetSize(), h2.GetSize()); + + while (!h1.IsEmpty()) { + UNIT_ASSERT_EQUAL(h1.GetMin(), h2.GetNext()); + h1.PopMin(); + h2.Pop(); + } + } + + // Tests reuse behavior + Y_UNIT_TEST(ReuseTest) { + int m = 20000; + TLimitedHeap<int> h1(m); + TTopKeeper<int> h2(m); + + int n = 20000000; + while (n--) { + int r = int(Rnd()); + + h1.Insert(r); + h2.Insert(r); + } + + UNIT_ASSERT_EQUAL(h1.GetSize(), h2.GetSize()); + + while (!h1.IsEmpty()) { + UNIT_ASSERT_EQUAL(h1.GetMin(), h2.GetNext()); + h1.PopMin(); + h2.Pop(); + } + + n = 20000000; + while (n--) { + int r = int(Rnd()); + + h1.Insert(r); + h2.Insert(r); + } + + UNIT_ASSERT_EQUAL(h1.GetSize(), h2.GetSize()); + + while (!h1.IsEmpty()) { + UNIT_ASSERT_EQUAL(h1.GetMin(), h2.GetNext()); + h1.PopMin(); + h2.Pop(); + } + } + + // Tests reset behavior + Y_UNIT_TEST(ResetTest) { + int m = 20000; + TLimitedHeap<int> h1(m); + TTopKeeper<int> h2(m); + + int n = 20000000; + while (n--) { + int r = int(Rnd()); + + h1.Insert(r); + h2.Insert(r); + } + + UNIT_ASSERT_EQUAL(h1.GetSize(), h2.GetSize()); + + for (int i = 0; i < m / 2; ++i) { + UNIT_ASSERT_EQUAL(h1.GetMin(), h2.GetNext()); + h1.PopMin(); + h2.Pop(); + } + + h2.Reset(); + while (!h1.IsEmpty()) { + h1.PopMin(); + } + + n = 20000000; + while (n--) { + int r = int(Rnd()); + + h1.Insert(r); + h2.Insert(r); + } + + UNIT_ASSERT_EQUAL(h1.GetSize(), h2.GetSize()); + + while (!h1.IsEmpty()) { + UNIT_ASSERT_EQUAL(h1.GetMin(), h2.GetNext()); + h1.PopMin(); + h2.Pop(); + } + } + + Y_UNIT_TEST(PreRegressionTest) { + typedef std::pair<float, unsigned int> TElementType; + + const size_t randomTriesCount = 128; + for (size_t i1 = 0; i1 < randomTriesCount; ++i1) { + const size_t desiredElementsCount = RandomNumber<size_t>(5) + 1; + TLimitedHeap<TElementType> h1(desiredElementsCount); + TTopKeeper<TElementType> h2(desiredElementsCount); + + const size_t elementsToInsert = RandomNumber<size_t>(10) + desiredElementsCount; + UNIT_ASSERT_C(desiredElementsCount <= elementsToInsert, "Test internal invariant is broken"); + for (size_t i2 = 0; i2 < elementsToInsert; ++i2) { + const auto f = RandomNumber<float>(); + const auto id = RandomNumber<unsigned int>(); + + h1.Insert(TElementType(f, id)); + h2.Insert(TElementType(f, id)); + } + + h2.Finalize(); + + //we inserted enough elements to guarantee this outcome + UNIT_ASSERT_EQUAL(h1.GetSize(), desiredElementsCount); + UNIT_ASSERT_EQUAL(h2.GetSize(), desiredElementsCount); + + const auto n = h2.GetSize(); + for (size_t i3 = 0; i3 < n; ++i3) { + UNIT_ASSERT_EQUAL(h1.GetMin(), h2.GetNext()); + h1.PopMin(); + h2.Pop(); + } + } + } + + Y_UNIT_TEST(CopyKeeperRegressionCase) { + using TKeeper = TTopKeeper<float>; + TVector<TKeeper> v(2, TKeeper(200)); + auto& k = v[1]; + for (size_t i = 0; i < 100; ++i) { + k.Insert(RandomNumber<float>()); + } + k.Finalize(); + } + + Y_UNIT_TEST(ExtractTest) { + TTopKeeper<size_t> keeper(100); + for (size_t i = 0; i < 100; ++i) { + keeper.Insert(i); + } + + auto values = keeper.Extract(); + UNIT_ASSERT_EQUAL(values.size(), 100); + Sort(values); + + for (size_t i = 0; i < 100; ++i) { + UNIT_ASSERT_EQUAL(values[i], i); + } + + UNIT_ASSERT(keeper.IsEmpty()); + } +} diff --git a/library/cpp/containers/top_keeper/ut/ya.make b/library/cpp/containers/top_keeper/ut/ya.make 
new file mode 100644 index 00000000000..42cfdd6f133 --- /dev/null +++ b/library/cpp/containers/top_keeper/ut/ya.make @@ -0,0 +1,12 @@ +UNITTEST_FOR(library/cpp/containers/top_keeper) + +OWNER( + ilnurkh + rmplstiltskin +) + +SRCS( + top_keeper_ut.cpp +) + +END() diff --git a/library/cpp/containers/top_keeper/ya.make b/library/cpp/containers/top_keeper/ya.make new file mode 100644 index 00000000000..ed206a1df98 --- /dev/null +++ b/library/cpp/containers/top_keeper/ya.make @@ -0,0 +1,13 @@ +LIBRARY() + +OWNER(ilnurkh) + +SRCS( + top_keeper.cpp +) + +END() + +RECURSE_FOR_TESTS(ut) + + diff --git a/library/cpp/containers/ya.make b/library/cpp/containers/ya.make new file mode 100644 index 00000000000..4b1b315e6a5 --- /dev/null +++ b/library/cpp/containers/ya.make @@ -0,0 +1,71 @@ +RECURSE( + 2d_array + absl_flat_hash + absl_tstring_flat_hash + atomizer + bitseq + bitseq/ut + compact_vector + compact_vector/ut + comptrie + comptrie/loader + comptrie/loader/ut + comptrie/ut + comptrie/benchmark + concurrent_hash + concurrent_hash_set + concurrent_hash_set/ut + dense_hash + dense_hash/dense_hash_benchmark + dense_hash/ut + dictionary + dictionary/ut + disjoint_interval_tree + disjoint_interval_tree/ut + ext_priority_queue + ext_priority_queue/ut + fast_trie + fast_trie/ut + flat_hash + flat_hash/benchmark + flat_hash/lib + flat_hash/lib/concepts + flat_hash/lib/fuzz + flat_hash/lib/ut + flat_hash/ut + hash_trie + heap_dict + heap_dict/benchmark + heap_dict/ut + intrusive_avl_tree + intrusive_avl_tree/ut + intrusive_hash + intrusive_hash/ut + intrusive_rb_tree + intrusive_rb_tree/fuzz + intrusive_rb_tree/ut + limited_heap + mh_heap + mh_heap/ut + paged_vector + paged_vector/ut + rarefied_array + ring_buffer + safe_vector + safe_vector/ut + segmented_pool_container + sorted_vector + sorted_vector/ut + spars_ar + stack_array + stack_array/ut + stack_vector + str_hash + str_map + top_keeper + top_keeper/ut + two_level_hash + two_level_hash/ut + vp_tree + vp_tree/ut +) |
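For readers skimming the hunks above, the following standalone sketch illustrates the intended TTopKeeper workflow (insert, finalize, then drain); it is not part of the commit. The concrete values and the claim that the default comparator retains the greatest elements are assumptions inferred from the TLimitedHeap comparison in the unit tests shown above, not something this hunk states explicitly.

    #include <library/cpp/containers/top_keeper/top_keeper.h>

    #include <util/stream/output.h>

    int main() {
        // Keep the 3 "best" of the values pushed below.
        // Assumption: the default comparator (defined earlier in top_keeper.h,
        // outside this hunk) keeps the greatest elements, matching the
        // TLimitedHeap-based tests above.
        TTopKeeper<int> keeper(3);

        const int values[] = {5, 1, 9, 7, 3, 8};
        for (int x : values) {
            keeper.Insert(x); // must happen before finalization
        }

        // Finalize() partitions (and, if enabled, sorts) the retained elements;
        // it is also triggered implicitly by GetNext() / Pop() / Extract().
        keeper.Finalize();

        // GetNext()/Pop() drain the kept elements starting from the weakest of
        // the retained top-3; under the assumption above this prints 7, 8, 9.
        while (!keeper.IsEmpty()) {
            Cout << keeper.GetNext() << Endl;
            keeper.Pop();
        }

        // Extract() would instead hand back the underlying vector in one call
        // and reset the keeper for reuse; Reset() clears it without extracting.
        return 0;
    }

The keeper defers all ordering work to Finalize() (a partition plus an optional sort over TVectorWithMin), which is why every Insert()/Emplace() must precede the first Pop()/GetNext(); once drained or Reset(), it can be refilled, as the ReuseTest and ResetTest cases above exercise.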