diff options
author | robot-piglet <robot-piglet@yandex-team.com> | 2023-12-02 01:45:21 +0300 |
---|---|---|
committer | robot-piglet <robot-piglet@yandex-team.com> | 2023-12-02 02:42:50 +0300 |
commit | 9c43d58f75cf086b744cf4fe2ae180e8f37e4a0c (patch) | |
tree | 9f88a486917d371d099cd712efd91b4c122d209d /contrib/python/marisa-trie/marisa/grimoire/vector | |
parent | 32fb6dda1feb24f9ab69ece5df0cb9ec238ca5e6 (diff) | |
download | ydb-9c43d58f75cf086b744cf4fe2ae180e8f37e4a0c.tar.gz |
Intermediate changes
Diffstat (limited to 'contrib/python/marisa-trie/marisa/grimoire/vector')
6 files changed, 1662 insertions, 0 deletions
diff --git a/contrib/python/marisa-trie/marisa/grimoire/vector/bit-vector.cc b/contrib/python/marisa-trie/marisa/grimoire/vector/bit-vector.cc new file mode 100644 index 0000000000..a5abc69319 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/vector/bit-vector.cc @@ -0,0 +1,825 @@ +#include "pop-count.h" +#include "bit-vector.h" + +namespace marisa { +namespace grimoire { +namespace vector { +namespace { + +const UInt8 SELECT_TABLE[8][256] = { + { + 7, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 6, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 7, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 6, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0 + }, + { + 7, 7, 7, 1, 7, 2, 2, 1, 7, 3, 3, 1, 3, 2, 2, 1, + 7, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, + 7, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1, + 5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, + 7, 6, 6, 1, 6, 2, 2, 1, 6, 3, 3, 1, 3, 2, 2, 1, + 6, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, + 6, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1, + 5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, + 7, 7, 7, 1, 7, 2, 2, 1, 7, 3, 3, 1, 3, 2, 2, 1, + 7, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, + 7, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1, + 5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, + 7, 6, 6, 1, 6, 2, 2, 1, 6, 3, 3, 1, 3, 2, 2, 1, + 6, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, + 6, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1, + 5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1 + }, + { + 7, 7, 7, 7, 7, 7, 7, 2, 7, 7, 7, 3, 7, 3, 3, 2, + 7, 7, 7, 4, 7, 4, 4, 2, 7, 4, 4, 3, 4, 3, 3, 2, + 7, 7, 7, 5, 7, 5, 5, 2, 7, 5, 5, 3, 5, 3, 3, 2, + 7, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2, + 7, 7, 7, 6, 7, 6, 6, 2, 7, 6, 6, 3, 6, 3, 3, 2, + 7, 6, 6, 4, 6, 4, 4, 2, 6, 4, 4, 3, 4, 3, 3, 2, + 7, 6, 6, 5, 6, 5, 5, 2, 6, 5, 5, 3, 5, 3, 3, 2, + 6, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2, + 7, 7, 7, 7, 7, 7, 7, 2, 7, 7, 7, 3, 7, 3, 3, 2, + 7, 7, 7, 4, 7, 4, 4, 2, 7, 4, 4, 3, 4, 3, 3, 2, + 7, 7, 7, 5, 7, 5, 5, 2, 7, 5, 5, 3, 5, 3, 3, 2, + 7, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2, + 7, 7, 7, 6, 7, 6, 6, 2, 7, 6, 6, 3, 6, 3, 3, 2, + 7, 6, 6, 4, 6, 4, 4, 2, 6, 4, 4, 3, 4, 3, 3, 2, + 7, 6, 6, 5, 6, 5, 5, 2, 6, 5, 5, 3, 5, 3, 3, 2, + 6, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2 + }, + { + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 3, + 7, 7, 7, 7, 7, 7, 7, 4, 7, 7, 7, 4, 7, 4, 4, 3, + 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 3, + 7, 7, 7, 5, 7, 5, 5, 4, 7, 5, 5, 4, 5, 4, 4, 3, + 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 3, + 7, 7, 7, 6, 7, 6, 6, 4, 7, 6, 6, 4, 6, 4, 4, 3, + 7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 3, + 7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 3, + 7, 7, 7, 7, 7, 7, 7, 4, 7, 7, 7, 4, 7, 4, 4, 3, + 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 3, + 7, 7, 7, 5, 7, 5, 5, 4, 7, 5, 5, 4, 5, 4, 4, 3, + 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 3, + 7, 7, 7, 6, 7, 6, 6, 4, 7, 6, 6, 4, 6, 4, 4, 3, + 7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 3, + 7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3 + }, + { + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 4, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, + 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 4, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, + 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 4, + 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5, + 7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 4, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 4, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, + 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 4, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, + 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 4, + 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5, + 7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 4 + }, + { + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, + 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, + 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5 + }, + { + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6 + }, + { + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 + } +}; + +#if MARISA_WORD_SIZE == 64 +const UInt64 MASK_55 = 0x5555555555555555ULL; +const UInt64 MASK_33 = 0x3333333333333333ULL; +const UInt64 MASK_0F = 0x0F0F0F0F0F0F0F0FULL; +const UInt64 MASK_01 = 0x0101010101010101ULL; +const UInt64 MASK_80 = 0x8080808080808080ULL; + +std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt64 unit) { + UInt64 counts; + { + #if defined(MARISA_X64) && defined(MARISA_USE_SSSE3) + __m128i lower_nibbles = _mm_cvtsi64_si128(unit & 0x0F0F0F0F0F0F0F0FULL); + __m128i upper_nibbles = _mm_cvtsi64_si128(unit & 0xF0F0F0F0F0F0F0F0ULL); + upper_nibbles = _mm_srli_epi32(upper_nibbles, 4); + + __m128i lower_counts = + _mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0); + lower_counts = _mm_shuffle_epi8(lower_counts, lower_nibbles); + __m128i upper_counts = + _mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0); + upper_counts = _mm_shuffle_epi8(upper_counts, upper_nibbles); + + counts = _mm_cvtsi128_si64(_mm_add_epi8(lower_counts, upper_counts)); + #else // defined(MARISA_X64) && defined(MARISA_USE_SSSE3) + counts = unit - ((unit >> 1) & MASK_55); + counts = (counts & MASK_33) + ((counts >> 2) & MASK_33); + counts = (counts + (counts >> 4)) & MASK_0F; + #endif // defined(MARISA_X64) && defined(MARISA_USE_SSSE3) + counts *= MASK_01; + } + + #if defined(MARISA_X64) && defined(MARISA_USE_POPCNT) + UInt8 skip; + { + __m128i x = _mm_cvtsi64_si128((i + 1) * MASK_01); + __m128i y = _mm_cvtsi64_si128(counts); + x = _mm_cmpgt_epi8(x, y); + skip = (UInt8)PopCount::count(_mm_cvtsi128_si64(x)); + } + #else // defined(MARISA_X64) && defined(MARISA_USE_POPCNT) + const UInt64 x = (counts | MASK_80) - ((i + 1) * MASK_01); + #ifdef _MSC_VER + unsigned long skip; + ::_BitScanForward64(&skip, (x & MASK_80) >> 7); + #else // _MSC_VER + const int skip = ::__builtin_ctzll((x & MASK_80) >> 7); + #endif // _MSC_VER + #endif // defined(MARISA_X64) && defined(MARISA_USE_POPCNT) + + bit_id += skip; + unit >>= skip; + i -= ((counts << 8) >> skip) & 0xFF; + + return bit_id + SELECT_TABLE[i][unit & 0xFF]; +} +#else // MARISA_WORD_SIZE == 64 + #ifdef MARISA_USE_SSE2 +const UInt8 POPCNT_TABLE[256] = { + 0, 8, 8, 16, 8, 16, 16, 24, 8, 16, 16, 24, 16, 24, 24, 32, + 8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40, + 8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40, + 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, + 8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40, + 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, + 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, + 24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56, + 8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40, + 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, + 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, + 24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56, + 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, + 24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56, + 24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56, + 32, 40, 40, 48, 40, 48, 48, 56, 40, 48, 48, 56, 48, 56, 56, 64 +}; + +std::size_t select_bit(std::size_t i, std::size_t bit_id, + UInt32 unit_lo, UInt32 unit_hi) { + __m128i unit; + { + __m128i lower_dword = _mm_cvtsi32_si128(unit_lo); + __m128i upper_dword = _mm_cvtsi32_si128(unit_hi); + upper_dword = _mm_slli_si128(upper_dword, 4); + unit = _mm_or_si128(lower_dword, upper_dword); + } + + __m128i counts; + { + #ifdef MARISA_USE_SSSE3 + __m128i lower_nibbles = _mm_set1_epi8(0x0F); + lower_nibbles = _mm_and_si128(lower_nibbles, unit); + __m128i upper_nibbles = _mm_set1_epi8((UInt8)0xF0); + upper_nibbles = _mm_and_si128(upper_nibbles, unit); + upper_nibbles = _mm_srli_epi32(upper_nibbles, 4); + + __m128i lower_counts = + _mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0); + lower_counts = _mm_shuffle_epi8(lower_counts, lower_nibbles); + __m128i upper_counts = + _mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0); + upper_counts = _mm_shuffle_epi8(upper_counts, upper_nibbles); + + counts = _mm_add_epi8(lower_counts, upper_counts); + #else // MARISA_USE_SSSE3 + __m128i x = _mm_srli_epi32(unit, 1); + x = _mm_and_si128(x, _mm_set1_epi8(0x55)); + x = _mm_sub_epi8(unit, x); + + __m128i y = _mm_srli_epi32(x, 2); + y = _mm_and_si128(y, _mm_set1_epi8(0x33)); + x = _mm_and_si128(x, _mm_set1_epi8(0x33)); + x = _mm_add_epi8(x, y); + + y = _mm_srli_epi32(x, 4); + x = _mm_add_epi8(x, y); + counts = _mm_and_si128(x, _mm_set1_epi8(0x0F)); + #endif // MARISA_USE_SSSE3 + } + + __m128i accumulated_counts; + { + __m128i x = counts; + x = _mm_slli_si128(x, 1); + __m128i y = counts; + y = _mm_add_epi32(y, x); + + x = y; + y = _mm_slli_si128(y, 2); + x = _mm_add_epi32(x, y); + + y = x; + x = _mm_slli_si128(x, 4); + y = _mm_add_epi32(y, x); + + accumulated_counts = _mm_set_epi32(0x7F7F7F7FU, 0x7F7F7F7FU, 0, 0); + accumulated_counts = _mm_or_si128(accumulated_counts, y); + } + + UInt8 skip; + { + __m128i x = _mm_set1_epi8((UInt8)(i + 1)); + x = _mm_cmpgt_epi8(x, accumulated_counts); + skip = POPCNT_TABLE[_mm_movemask_epi8(x)]; + } + + UInt8 byte; + { + #ifdef _MSC_VER + __declspec(align(16)) UInt8 unit_bytes[16]; + __declspec(align(16)) UInt8 accumulated_counts_bytes[16]; + #else // _MSC_VER + UInt8 unit_bytes[16] __attribute__ ((aligned (16))); + UInt8 accumulated_counts_bytes[16] __attribute__ ((aligned (16))); + #endif // _MSC_VER + accumulated_counts = _mm_slli_si128(accumulated_counts, 1); + _mm_store_si128(reinterpret_cast<__m128i *>(unit_bytes), unit); + _mm_store_si128(reinterpret_cast<__m128i *>(accumulated_counts_bytes), + accumulated_counts); + + bit_id += skip; + byte = unit_bytes[skip / 8]; + i -= accumulated_counts_bytes[skip / 8]; + } + + return bit_id + SELECT_TABLE[i][byte]; +} + #endif // MARISA_USE_SSE2 +#endif // MARISA_WORD_SIZE == 64 + +} // namespace + +#if MARISA_WORD_SIZE == 64 + +std::size_t BitVector::rank1(std::size_t i) const { + MARISA_DEBUG_IF(ranks_.empty(), MARISA_STATE_ERROR); + MARISA_DEBUG_IF(i > size_, MARISA_BOUND_ERROR); + + const RankIndex &rank = ranks_[i / 512]; + std::size_t offset = rank.abs(); + switch ((i / 64) % 8) { + case 1: { + offset += rank.rel1(); + break; + } + case 2: { + offset += rank.rel2(); + break; + } + case 3: { + offset += rank.rel3(); + break; + } + case 4: { + offset += rank.rel4(); + break; + } + case 5: { + offset += rank.rel5(); + break; + } + case 6: { + offset += rank.rel6(); + break; + } + case 7: { + offset += rank.rel7(); + break; + } + } + offset += PopCount::count(units_[i / 64] & ((1ULL << (i % 64)) - 1)); + return offset; +} + +std::size_t BitVector::select0(std::size_t i) const { + MARISA_DEBUG_IF(select0s_.empty(), MARISA_STATE_ERROR); + MARISA_DEBUG_IF(i >= num_0s(), MARISA_BOUND_ERROR); + + const std::size_t select_id = i / 512; + MARISA_DEBUG_IF((select_id + 1) >= select0s_.size(), MARISA_BOUND_ERROR); + if ((i % 512) == 0) { + return select0s_[select_id]; + } + std::size_t begin = select0s_[select_id] / 512; + std::size_t end = (select0s_[select_id + 1] + 511) / 512; + if (begin + 10 >= end) { + while (i >= ((begin + 1) * 512) - ranks_[begin + 1].abs()) { + ++begin; + } + } else { + while (begin + 1 < end) { + const std::size_t middle = (begin + end) / 2; + if (i < (middle * 512) - ranks_[middle].abs()) { + end = middle; + } else { + begin = middle; + } + } + } + const std::size_t rank_id = begin; + i -= (rank_id * 512) - ranks_[rank_id].abs(); + + const RankIndex &rank = ranks_[rank_id]; + std::size_t unit_id = rank_id * 8; + if (i < (256U - rank.rel4())) { + if (i < (128U - rank.rel2())) { + if (i >= (64U - rank.rel1())) { + unit_id += 1; + i -= 64 - rank.rel1(); + } + } else if (i < (192U - rank.rel3())) { + unit_id += 2; + i -= 128 - rank.rel2(); + } else { + unit_id += 3; + i -= 192 - rank.rel3(); + } + } else if (i < (384U - rank.rel6())) { + if (i < (320U - rank.rel5())) { + unit_id += 4; + i -= 256 - rank.rel4(); + } else { + unit_id += 5; + i -= 320 - rank.rel5(); + } + } else if (i < (448U - rank.rel7())) { + unit_id += 6; + i -= 384 - rank.rel6(); + } else { + unit_id += 7; + i -= 448 - rank.rel7(); + } + + return select_bit(i, unit_id * 64, ~units_[unit_id]); +} + +std::size_t BitVector::select1(std::size_t i) const { + MARISA_DEBUG_IF(select1s_.empty(), MARISA_STATE_ERROR); + MARISA_DEBUG_IF(i >= num_1s(), MARISA_BOUND_ERROR); + + const std::size_t select_id = i / 512; + MARISA_DEBUG_IF((select_id + 1) >= select1s_.size(), MARISA_BOUND_ERROR); + if ((i % 512) == 0) { + return select1s_[select_id]; + } + std::size_t begin = select1s_[select_id] / 512; + std::size_t end = (select1s_[select_id + 1] + 511) / 512; + if (begin + 10 >= end) { + while (i >= ranks_[begin + 1].abs()) { + ++begin; + } + } else { + while (begin + 1 < end) { + const std::size_t middle = (begin + end) / 2; + if (i < ranks_[middle].abs()) { + end = middle; + } else { + begin = middle; + } + } + } + const std::size_t rank_id = begin; + i -= ranks_[rank_id].abs(); + + const RankIndex &rank = ranks_[rank_id]; + std::size_t unit_id = rank_id * 8; + if (i < rank.rel4()) { + if (i < rank.rel2()) { + if (i >= rank.rel1()) { + unit_id += 1; + i -= rank.rel1(); + } + } else if (i < rank.rel3()) { + unit_id += 2; + i -= rank.rel2(); + } else { + unit_id += 3; + i -= rank.rel3(); + } + } else if (i < rank.rel6()) { + if (i < rank.rel5()) { + unit_id += 4; + i -= rank.rel4(); + } else { + unit_id += 5; + i -= rank.rel5(); + } + } else if (i < rank.rel7()) { + unit_id += 6; + i -= rank.rel6(); + } else { + unit_id += 7; + i -= rank.rel7(); + } + + return select_bit(i, unit_id * 64, units_[unit_id]); +} + +#else // MARISA_WORD_SIZE == 64 + +std::size_t BitVector::rank1(std::size_t i) const { + MARISA_DEBUG_IF(ranks_.empty(), MARISA_STATE_ERROR); + MARISA_DEBUG_IF(i > size_, MARISA_BOUND_ERROR); + + const RankIndex &rank = ranks_[i / 512]; + std::size_t offset = rank.abs(); + switch ((i / 64) % 8) { + case 1: { + offset += rank.rel1(); + break; + } + case 2: { + offset += rank.rel2(); + break; + } + case 3: { + offset += rank.rel3(); + break; + } + case 4: { + offset += rank.rel4(); + break; + } + case 5: { + offset += rank.rel5(); + break; + } + case 6: { + offset += rank.rel6(); + break; + } + case 7: { + offset += rank.rel7(); + break; + } + } + if (((i / 32) & 1) == 1) { + offset += PopCount::count(units_[(i / 32) - 1]); + } + offset += PopCount::count(units_[i / 32] & ((1U << (i % 32)) - 1)); + return offset; +} + +std::size_t BitVector::select0(std::size_t i) const { + MARISA_DEBUG_IF(select0s_.empty(), MARISA_STATE_ERROR); + MARISA_DEBUG_IF(i >= num_0s(), MARISA_BOUND_ERROR); + + const std::size_t select_id = i / 512; + MARISA_DEBUG_IF((select_id + 1) >= select0s_.size(), MARISA_BOUND_ERROR); + if ((i % 512) == 0) { + return select0s_[select_id]; + } + std::size_t begin = select0s_[select_id] / 512; + std::size_t end = (select0s_[select_id + 1] + 511) / 512; + if (begin + 10 >= end) { + while (i >= ((begin + 1) * 512) - ranks_[begin + 1].abs()) { + ++begin; + } + } else { + while (begin + 1 < end) { + const std::size_t middle = (begin + end) / 2; + if (i < (middle * 512) - ranks_[middle].abs()) { + end = middle; + } else { + begin = middle; + } + } + } + const std::size_t rank_id = begin; + i -= (rank_id * 512) - ranks_[rank_id].abs(); + + const RankIndex &rank = ranks_[rank_id]; + std::size_t unit_id = rank_id * 16; + if (i < (256U - rank.rel4())) { + if (i < (128U - rank.rel2())) { + if (i >= (64U - rank.rel1())) { + unit_id += 2; + i -= 64 - rank.rel1(); + } + } else if (i < (192U - rank.rel3())) { + unit_id += 4; + i -= 128 - rank.rel2(); + } else { + unit_id += 6; + i -= 192 - rank.rel3(); + } + } else if (i < (384U - rank.rel6())) { + if (i < (320U - rank.rel5())) { + unit_id += 8; + i -= 256 - rank.rel4(); + } else { + unit_id += 10; + i -= 320 - rank.rel5(); + } + } else if (i < (448U - rank.rel7())) { + unit_id += 12; + i -= 384 - rank.rel6(); + } else { + unit_id += 14; + i -= 448 - rank.rel7(); + } + +#ifdef MARISA_USE_SSE2 + return select_bit(i, unit_id * 32, ~units_[unit_id], ~units_[unit_id + 1]); +#else // MARISA_USE_SSE2 + UInt32 unit = ~units_[unit_id]; + PopCount count(unit); + if (i >= count.lo32()) { + ++unit_id; + i -= count.lo32(); + unit = ~units_[unit_id]; + count = PopCount(unit); + } + + std::size_t bit_id = unit_id * 32; + if (i < count.lo16()) { + if (i >= count.lo8()) { + bit_id += 8; + unit >>= 8; + i -= count.lo8(); + } + } else if (i < count.lo24()) { + bit_id += 16; + unit >>= 16; + i -= count.lo16(); + } else { + bit_id += 24; + unit >>= 24; + i -= count.lo24(); + } + return bit_id + SELECT_TABLE[i][unit & 0xFF]; +#endif // MARISA_USE_SSE2 +} + +std::size_t BitVector::select1(std::size_t i) const { + MARISA_DEBUG_IF(select1s_.empty(), MARISA_STATE_ERROR); + MARISA_DEBUG_IF(i >= num_1s(), MARISA_BOUND_ERROR); + + const std::size_t select_id = i / 512; + MARISA_DEBUG_IF((select_id + 1) >= select1s_.size(), MARISA_BOUND_ERROR); + if ((i % 512) == 0) { + return select1s_[select_id]; + } + std::size_t begin = select1s_[select_id] / 512; + std::size_t end = (select1s_[select_id + 1] + 511) / 512; + if (begin + 10 >= end) { + while (i >= ranks_[begin + 1].abs()) { + ++begin; + } + } else { + while (begin + 1 < end) { + const std::size_t middle = (begin + end) / 2; + if (i < ranks_[middle].abs()) { + end = middle; + } else { + begin = middle; + } + } + } + const std::size_t rank_id = begin; + i -= ranks_[rank_id].abs(); + + const RankIndex &rank = ranks_[rank_id]; + std::size_t unit_id = rank_id * 16; + if (i < rank.rel4()) { + if (i < rank.rel2()) { + if (i >= rank.rel1()) { + unit_id += 2; + i -= rank.rel1(); + } + } else if (i < rank.rel3()) { + unit_id += 4; + i -= rank.rel2(); + } else { + unit_id += 6; + i -= rank.rel3(); + } + } else if (i < rank.rel6()) { + if (i < rank.rel5()) { + unit_id += 8; + i -= rank.rel4(); + } else { + unit_id += 10; + i -= rank.rel5(); + } + } else if (i < rank.rel7()) { + unit_id += 12; + i -= rank.rel6(); + } else { + unit_id += 14; + i -= rank.rel7(); + } + +#ifdef MARISA_USE_SSE2 + return select_bit(i, unit_id * 32, units_[unit_id], units_[unit_id + 1]); +#else // MARISA_USE_SSE2 + UInt32 unit = units_[unit_id]; + PopCount count(unit); + if (i >= count.lo32()) { + ++unit_id; + i -= count.lo32(); + unit = units_[unit_id]; + count = PopCount(unit); + } + + std::size_t bit_id = unit_id * 32; + if (i < count.lo16()) { + if (i >= count.lo8()) { + bit_id += 8; + unit >>= 8; + i -= count.lo8(); + } + } else if (i < count.lo24()) { + bit_id += 16; + unit >>= 16; + i -= count.lo16(); + } else { + bit_id += 24; + unit >>= 24; + i -= count.lo24(); + } + return bit_id + SELECT_TABLE[i][unit & 0xFF]; +#endif // MARISA_USE_SSE2 +} + +#endif // MARISA_WORD_SIZE == 64 + +void BitVector::build_index(const BitVector &bv, + bool enables_select0, bool enables_select1) { + ranks_.resize((bv.size() / 512) + (((bv.size() % 512) != 0) ? 1 : 0) + 1); + + std::size_t num_0s = 0; + std::size_t num_1s = 0; + + for (std::size_t i = 0; i < bv.size(); ++i) { + if ((i % 64) == 0) { + const std::size_t rank_id = i / 512; + switch ((i / 64) % 8) { + case 0: { + ranks_[rank_id].set_abs(num_1s); + break; + } + case 1: { + ranks_[rank_id].set_rel1(num_1s - ranks_[rank_id].abs()); + break; + } + case 2: { + ranks_[rank_id].set_rel2(num_1s - ranks_[rank_id].abs()); + break; + } + case 3: { + ranks_[rank_id].set_rel3(num_1s - ranks_[rank_id].abs()); + break; + } + case 4: { + ranks_[rank_id].set_rel4(num_1s - ranks_[rank_id].abs()); + break; + } + case 5: { + ranks_[rank_id].set_rel5(num_1s - ranks_[rank_id].abs()); + break; + } + case 6: { + ranks_[rank_id].set_rel6(num_1s - ranks_[rank_id].abs()); + break; + } + case 7: { + ranks_[rank_id].set_rel7(num_1s - ranks_[rank_id].abs()); + break; + } + } + } + + if (bv[i]) { + if (enables_select1 && ((num_1s % 512) == 0)) { + select1s_.push_back(static_cast<UInt32>(i)); + } + ++num_1s; + } else { + if (enables_select0 && ((num_0s % 512) == 0)) { + select0s_.push_back(static_cast<UInt32>(i)); + } + ++num_0s; + } + } + + if ((bv.size() % 512) != 0) { + const std::size_t rank_id = (bv.size() - 1) / 512; + switch (((bv.size() - 1) / 64) % 8) { + case 0: { + ranks_[rank_id].set_rel1(num_1s - ranks_[rank_id].abs()); + } + case 1: { + ranks_[rank_id].set_rel2(num_1s - ranks_[rank_id].abs()); + } + case 2: { + ranks_[rank_id].set_rel3(num_1s - ranks_[rank_id].abs()); + } + case 3: { + ranks_[rank_id].set_rel4(num_1s - ranks_[rank_id].abs()); + } + case 4: { + ranks_[rank_id].set_rel5(num_1s - ranks_[rank_id].abs()); + } + case 5: { + ranks_[rank_id].set_rel6(num_1s - ranks_[rank_id].abs()); + } + case 6: { + ranks_[rank_id].set_rel7(num_1s - ranks_[rank_id].abs()); + break; + } + } + } + + size_ = bv.size(); + num_1s_ = bv.num_1s(); + + ranks_.back().set_abs(num_1s); + if (enables_select0) { + select0s_.push_back(static_cast<UInt32>(bv.size())); + select0s_.shrink(); + } + if (enables_select1) { + select1s_.push_back(static_cast<UInt32>(bv.size())); + select1s_.shrink(); + } +} + +} // namespace vector +} // namespace grimoire +} // namespace marisa diff --git a/contrib/python/marisa-trie/marisa/grimoire/vector/bit-vector.h b/contrib/python/marisa-trie/marisa/grimoire/vector/bit-vector.h new file mode 100644 index 0000000000..56f99ed699 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/vector/bit-vector.h @@ -0,0 +1,180 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_VECTOR_BIT_VECTOR_H_ +#define MARISA_GRIMOIRE_VECTOR_BIT_VECTOR_H_ + +#include "rank-index.h" +#include "vector.h" + +namespace marisa { +namespace grimoire { +namespace vector { + +class BitVector { + public: +#if MARISA_WORD_SIZE == 64 + typedef UInt64 Unit; +#else // MARISA_WORD_SIZE == 64 + typedef UInt32 Unit; +#endif // MARISA_WORD_SIZE == 64 + + BitVector() + : units_(), size_(0), num_1s_(0), ranks_(), select0s_(), select1s_() {} + + void build(bool enables_select0, bool enables_select1) { + BitVector temp; + temp.build_index(*this, enables_select0, enables_select1); + units_.shrink(); + temp.units_.swap(units_); + swap(temp); + } + + void map(Mapper &mapper) { + BitVector temp; + temp.map_(mapper); + swap(temp); + } + void read(Reader &reader) { + BitVector temp; + temp.read_(reader); + swap(temp); + } + void write(Writer &writer) const { + write_(writer); + } + + void disable_select0() { + select0s_.clear(); + } + void disable_select1() { + select1s_.clear(); + } + + void push_back(bool bit) { + MARISA_THROW_IF(size_ == MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + if (size_ == (MARISA_WORD_SIZE * units_.size())) { + units_.resize(units_.size() + (64 / MARISA_WORD_SIZE), 0); + } + if (bit) { + units_[size_ / MARISA_WORD_SIZE] |= + (Unit)1 << (size_ % MARISA_WORD_SIZE); + ++num_1s_; + } + ++size_; + } + + bool operator[](std::size_t i) const { + MARISA_DEBUG_IF(i >= size_, MARISA_BOUND_ERROR); + return (units_[i / MARISA_WORD_SIZE] + & ((Unit)1 << (i % MARISA_WORD_SIZE))) != 0; + } + + std::size_t rank0(std::size_t i) const { + MARISA_DEBUG_IF(ranks_.empty(), MARISA_STATE_ERROR); + MARISA_DEBUG_IF(i > size_, MARISA_BOUND_ERROR); + return i - rank1(i); + } + std::size_t rank1(std::size_t i) const; + + std::size_t select0(std::size_t i) const; + std::size_t select1(std::size_t i) const; + + std::size_t num_0s() const { + return size_ - num_1s_; + } + std::size_t num_1s() const { + return num_1s_; + } + + bool empty() const { + return size_ == 0; + } + std::size_t size() const { + return size_; + } + std::size_t total_size() const { + return units_.total_size() + ranks_.total_size() + + select0s_.total_size() + select1s_.total_size(); + } + std::size_t io_size() const { + return units_.io_size() + (sizeof(UInt32) * 2) + ranks_.io_size() + + select0s_.io_size() + select1s_.io_size(); + } + + void clear() { + BitVector().swap(*this); + } + void swap(BitVector &rhs) { + units_.swap(rhs.units_); + marisa::swap(size_, rhs.size_); + marisa::swap(num_1s_, rhs.num_1s_); + ranks_.swap(rhs.ranks_); + select0s_.swap(rhs.select0s_); + select1s_.swap(rhs.select1s_); + } + + private: + Vector<Unit> units_; + std::size_t size_; + std::size_t num_1s_; + Vector<RankIndex> ranks_; + Vector<UInt32> select0s_; + Vector<UInt32> select1s_; + + void build_index(const BitVector &bv, + bool enables_select0, bool enables_select1); + + void map_(Mapper &mapper) { + units_.map(mapper); + { + UInt32 temp_size; + mapper.map(&temp_size); + size_ = temp_size; + } + { + UInt32 temp_num_1s; + mapper.map(&temp_num_1s); + MARISA_THROW_IF(temp_num_1s > size_, MARISA_FORMAT_ERROR); + num_1s_ = temp_num_1s; + } + ranks_.map(mapper); + select0s_.map(mapper); + select1s_.map(mapper); + } + + void read_(Reader &reader) { + units_.read(reader); + { + UInt32 temp_size; + reader.read(&temp_size); + size_ = temp_size; + } + { + UInt32 temp_num_1s; + reader.read(&temp_num_1s); + MARISA_THROW_IF(temp_num_1s > size_, MARISA_FORMAT_ERROR); + num_1s_ = temp_num_1s; + } + ranks_.read(reader); + select0s_.read(reader); + select1s_.read(reader); + } + + void write_(Writer &writer) const { + units_.write(writer); + writer.write((UInt32)size_); + writer.write((UInt32)num_1s_); + ranks_.write(writer); + select0s_.write(writer); + select1s_.write(writer); + } + + // Disallows copy and assignment. + BitVector(const BitVector &); + BitVector &operator=(const BitVector &); +}; + +} // namespace vector +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_VECTOR_BIT_VECTOR_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/vector/flat-vector.h b/contrib/python/marisa-trie/marisa/grimoire/vector/flat-vector.h new file mode 100644 index 0000000000..14b25d7d82 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/vector/flat-vector.h @@ -0,0 +1,206 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_VECTOR_FLAT_VECTOR_H_ +#define MARISA_GRIMOIRE_VECTOR_FLAT_VECTOR_H_ + +#include "vector.h" + +namespace marisa { +namespace grimoire { +namespace vector { + +class FlatVector { + public: +#if MARISA_WORD_SIZE == 64 + typedef UInt64 Unit; +#else // MARISA_WORD_SIZE == 64 + typedef UInt32 Unit; +#endif // MARISA_WORD_SIZE == 64 + + FlatVector() : units_(), value_size_(0), mask_(0), size_(0) {} + + void build(const Vector<UInt32> &values) { + FlatVector temp; + temp.build_(values); + swap(temp); + } + + void map(Mapper &mapper) { + FlatVector temp; + temp.map_(mapper); + swap(temp); + } + void read(Reader &reader) { + FlatVector temp; + temp.read_(reader); + swap(temp); + } + void write(Writer &writer) const { + write_(writer); + } + + UInt32 operator[](std::size_t i) const { + MARISA_DEBUG_IF(i >= size_, MARISA_BOUND_ERROR); + + const std::size_t pos = i * value_size_; + const std::size_t unit_id = pos / MARISA_WORD_SIZE; + const std::size_t unit_offset = pos % MARISA_WORD_SIZE; + + if ((unit_offset + value_size_) <= MARISA_WORD_SIZE) { + return (UInt32)(units_[unit_id] >> unit_offset) & mask_; + } else { + return (UInt32)((units_[unit_id] >> unit_offset) + | (units_[unit_id + 1] << (MARISA_WORD_SIZE - unit_offset))) & mask_; + } + } + + std::size_t value_size() const { + return value_size_; + } + UInt32 mask() const { + return mask_; + } + + bool empty() const { + return size_ == 0; + } + std::size_t size() const { + return size_; + } + std::size_t total_size() const { + return units_.total_size(); + } + std::size_t io_size() const { + return units_.io_size() + (sizeof(UInt32) * 2) + sizeof(UInt64); + } + + void clear() { + FlatVector().swap(*this); + } + void swap(FlatVector &rhs) { + units_.swap(rhs.units_); + marisa::swap(value_size_, rhs.value_size_); + marisa::swap(mask_, rhs.mask_); + marisa::swap(size_, rhs.size_); + } + + private: + Vector<Unit> units_; + std::size_t value_size_; + UInt32 mask_; + std::size_t size_; + + void build_(const Vector<UInt32> &values) { + UInt32 max_value = 0; + for (std::size_t i = 0; i < values.size(); ++i) { + if (values[i] > max_value) { + max_value = values[i]; + } + } + + std::size_t value_size = 0; + while (max_value != 0) { + ++value_size; + max_value >>= 1; + } + + std::size_t num_units = values.empty() ? 0 : (64 / MARISA_WORD_SIZE); + if (value_size != 0) { + num_units = (std::size_t)( + (((UInt64)value_size * values.size()) + (MARISA_WORD_SIZE - 1)) + / MARISA_WORD_SIZE); + num_units += num_units % (64 / MARISA_WORD_SIZE); + } + + units_.resize(num_units); + if (num_units > 0) { + units_.back() = 0; + } + + value_size_ = value_size; + if (value_size != 0) { + mask_ = MARISA_UINT32_MAX >> (32 - value_size); + } + size_ = values.size(); + + for (std::size_t i = 0; i < values.size(); ++i) { + set(i, values[i]); + } + } + + void map_(Mapper &mapper) { + units_.map(mapper); + { + UInt32 temp_value_size; + mapper.map(&temp_value_size); + MARISA_THROW_IF(temp_value_size > 32, MARISA_FORMAT_ERROR); + value_size_ = temp_value_size; + } + { + UInt32 temp_mask; + mapper.map(&temp_mask); + mask_ = temp_mask; + } + { + UInt64 temp_size; + mapper.map(&temp_size); + MARISA_THROW_IF(temp_size > MARISA_SIZE_MAX, MARISA_SIZE_ERROR); + size_ = (std::size_t)temp_size; + } + } + + void read_(Reader &reader) { + units_.read(reader); + { + UInt32 temp_value_size; + reader.read(&temp_value_size); + MARISA_THROW_IF(temp_value_size > 32, MARISA_FORMAT_ERROR); + value_size_ = temp_value_size; + } + { + UInt32 temp_mask; + reader.read(&temp_mask); + mask_ = temp_mask; + } + { + UInt64 temp_size; + reader.read(&temp_size); + MARISA_THROW_IF(temp_size > MARISA_SIZE_MAX, MARISA_SIZE_ERROR); + size_ = (std::size_t)temp_size; + } + } + + void write_(Writer &writer) const { + units_.write(writer); + writer.write((UInt32)value_size_); + writer.write((UInt32)mask_); + writer.write((UInt64)size_); + } + + void set(std::size_t i, UInt32 value) { + MARISA_DEBUG_IF(i >= size_, MARISA_BOUND_ERROR); + MARISA_DEBUG_IF(value > mask_, MARISA_RANGE_ERROR); + + const std::size_t pos = i * value_size_; + const std::size_t unit_id = pos / MARISA_WORD_SIZE; + const std::size_t unit_offset = pos % MARISA_WORD_SIZE; + + units_[unit_id] &= ~((Unit)mask_ << unit_offset); + units_[unit_id] |= (Unit)(value & mask_) << unit_offset; + if ((unit_offset + value_size_) > MARISA_WORD_SIZE) { + units_[unit_id + 1] &= + ~((Unit)mask_ >> (MARISA_WORD_SIZE - unit_offset)); + units_[unit_id + 1] |= + (Unit)(value & mask_) >> (MARISA_WORD_SIZE - unit_offset); + } + } + + // Disallows copy and assignment. + FlatVector(const FlatVector &); + FlatVector &operator=(const FlatVector &); +}; + +} // namespace vector +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_VECTOR_FLAT_VECTOR_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/vector/pop-count.h b/contrib/python/marisa-trie/marisa/grimoire/vector/pop-count.h new file mode 100644 index 0000000000..6d04cf831d --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/vector/pop-count.h @@ -0,0 +1,111 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_VECTOR_POP_COUNT_H_ +#define MARISA_GRIMOIRE_VECTOR_POP_COUNT_H_ + +#include "../intrin.h" + +namespace marisa { +namespace grimoire { +namespace vector { + +#if MARISA_WORD_SIZE == 64 + +class PopCount { + public: + explicit PopCount(UInt64 x) : value_() { + x = (x & 0x5555555555555555ULL) + ((x & 0xAAAAAAAAAAAAAAAAULL) >> 1); + x = (x & 0x3333333333333333ULL) + ((x & 0xCCCCCCCCCCCCCCCCULL) >> 2); + x = (x & 0x0F0F0F0F0F0F0F0FULL) + ((x & 0xF0F0F0F0F0F0F0F0ULL) >> 4); + x *= 0x0101010101010101ULL; + value_ = x; + } + + std::size_t lo8() const { + return (std::size_t)(value_ & 0xFFU); + } + std::size_t lo16() const { + return (std::size_t)((value_ >> 8) & 0xFFU); + } + std::size_t lo24() const { + return (std::size_t)((value_ >> 16) & 0xFFU); + } + std::size_t lo32() const { + return (std::size_t)((value_ >> 24) & 0xFFU); + } + std::size_t lo40() const { + return (std::size_t)((value_ >> 32) & 0xFFU); + } + std::size_t lo48() const { + return (std::size_t)((value_ >> 40) & 0xFFU); + } + std::size_t lo56() const { + return (std::size_t)((value_ >> 48) & 0xFFU); + } + std::size_t lo64() const { + return (std::size_t)((value_ >> 56) & 0xFFU); + } + + static std::size_t count(UInt64 x) { +#if defined(MARISA_X64) && defined(MARISA_USE_POPCNT) + #ifdef _MSC_VER + return __popcnt64(x); + #else // _MSC_VER + return _mm_popcnt_u64(x); + #endif // _MSC_VER +#else // defined(MARISA_X64) && defined(MARISA_USE_POPCNT) + return PopCount(x).lo64(); +#endif // defined(MARISA_X64) && defined(MARISA_USE_POPCNT) + } + + private: + UInt64 value_; +}; + +#else // MARISA_WORD_SIZE == 64 + +class PopCount { + public: + explicit PopCount(UInt32 x) : value_() { + x = (x & 0x55555555U) + ((x & 0xAAAAAAAAU) >> 1); + x = (x & 0x33333333U) + ((x & 0xCCCCCCCCU) >> 2); + x = (x & 0x0F0F0F0FU) + ((x & 0xF0F0F0F0U) >> 4); + x *= 0x01010101U; + value_ = x; + } + + std::size_t lo8() const { + return value_ & 0xFFU; + } + std::size_t lo16() const { + return (value_ >> 8) & 0xFFU; + } + std::size_t lo24() const { + return (value_ >> 16) & 0xFFU; + } + std::size_t lo32() const { + return (value_ >> 24) & 0xFFU; + } + + static std::size_t count(UInt32 x) { +#ifdef MARISA_USE_POPCNT + #ifdef _MSC_VER + return __popcnt(x); + #else // _MSC_VER + return _mm_popcnt_u32(x); + #endif // _MSC_VER +#else // MARISA_USE_POPCNT + return PopCount(x).lo32(); +#endif // MARISA_USE_POPCNT + } + + private: + UInt32 value_; +}; + +#endif // MARISA_WORD_SIZE == 64 + +} // namespace vector +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_VECTOR_POP_COUNT_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/vector/rank-index.h b/contrib/python/marisa-trie/marisa/grimoire/vector/rank-index.h new file mode 100644 index 0000000000..2403709954 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/vector/rank-index.h @@ -0,0 +1,83 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_VECTOR_RANK_INDEX_H_ +#define MARISA_GRIMOIRE_VECTOR_RANK_INDEX_H_ + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace vector { + +class RankIndex { + public: + RankIndex() : abs_(0), rel_lo_(0), rel_hi_(0) {} + + void set_abs(std::size_t value) { + MARISA_DEBUG_IF(value > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + abs_ = (UInt32)value; + } + void set_rel1(std::size_t value) { + MARISA_DEBUG_IF(value > 64, MARISA_RANGE_ERROR); + rel_lo_ = (UInt32)((rel_lo_ & ~0x7FU) | (value & 0x7FU)); + } + void set_rel2(std::size_t value) { + MARISA_DEBUG_IF(value > 128, MARISA_RANGE_ERROR); + rel_lo_ = (UInt32)((rel_lo_ & ~(0xFFU << 7)) | ((value & 0xFFU) << 7)); + } + void set_rel3(std::size_t value) { + MARISA_DEBUG_IF(value > 192, MARISA_RANGE_ERROR); + rel_lo_ = (UInt32)((rel_lo_ & ~(0xFFU << 15)) | ((value & 0xFFU) << 15)); + } + void set_rel4(std::size_t value) { + MARISA_DEBUG_IF(value > 256, MARISA_RANGE_ERROR); + rel_lo_ = (UInt32)((rel_lo_ & ~(0x1FFU << 23)) | ((value & 0x1FFU) << 23)); + } + void set_rel5(std::size_t value) { + MARISA_DEBUG_IF(value > 320, MARISA_RANGE_ERROR); + rel_hi_ = (UInt32)((rel_hi_ & ~0x1FFU) | (value & 0x1FFU)); + } + void set_rel6(std::size_t value) { + MARISA_DEBUG_IF(value > 384, MARISA_RANGE_ERROR); + rel_hi_ = (UInt32)((rel_hi_ & ~(0x1FFU << 9)) | ((value & 0x1FFU) << 9)); + } + void set_rel7(std::size_t value) { + MARISA_DEBUG_IF(value > 448, MARISA_RANGE_ERROR); + rel_hi_ = (UInt32)((rel_hi_ & ~(0x1FFU << 18)) | ((value & 0x1FFU) << 18)); + } + + std::size_t abs() const { + return abs_; + } + std::size_t rel1() const { + return rel_lo_ & 0x7FU; + } + std::size_t rel2() const { + return (rel_lo_ >> 7) & 0xFFU; + } + std::size_t rel3() const { + return (rel_lo_ >> 15) & 0xFFU; + } + std::size_t rel4() const { + return (rel_lo_ >> 23) & 0x1FFU; + } + std::size_t rel5() const { + return rel_hi_ & 0x1FFU; + } + std::size_t rel6() const { + return (rel_hi_ >> 9) & 0x1FFU; + } + std::size_t rel7() const { + return (rel_hi_ >> 18) & 0x1FFU; + } + + private: + UInt32 abs_; + UInt32 rel_lo_; + UInt32 rel_hi_; +}; + +} // namespace vector +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_VECTOR_RANK_INDEX_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/vector/vector.h b/contrib/python/marisa-trie/marisa/grimoire/vector/vector.h new file mode 100644 index 0000000000..148cc8b491 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/vector/vector.h @@ -0,0 +1,257 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_VECTOR_VECTOR_H_ +#define MARISA_GRIMOIRE_VECTOR_VECTOR_H_ + +#include <new> + +#include "../io.h" + +namespace marisa { +namespace grimoire { +namespace vector { + +template <typename T> +class Vector { + public: + Vector() + : buf_(), objs_(NULL), const_objs_(NULL), + size_(0), capacity_(0), fixed_(false) {} + ~Vector() { + if (objs_ != NULL) { + for (std::size_t i = 0; i < size_; ++i) { + objs_[i].~T(); + } + } + } + + void map(Mapper &mapper) { + Vector temp; + temp.map_(mapper); + swap(temp); + } + + void read(Reader &reader) { + Vector temp; + temp.read_(reader); + swap(temp); + } + + void write(Writer &writer) const { + write_(writer); + } + + void push_back(const T &x) { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + MARISA_DEBUG_IF(size_ == max_size(), MARISA_SIZE_ERROR); + reserve(size_ + 1); + new (&objs_[size_]) T(x); + ++size_; + } + + void pop_back() { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + MARISA_DEBUG_IF(size_ == 0, MARISA_STATE_ERROR); + objs_[--size_].~T(); + } + + // resize() assumes that T's placement new does not throw an exception. + void resize(std::size_t size) { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + reserve(size); + for (std::size_t i = size_; i < size; ++i) { + new (&objs_[i]) T; + } + for (std::size_t i = size; i < size_; ++i) { + objs_[i].~T(); + } + size_ = size; + } + + // resize() assumes that T's placement new does not throw an exception. + void resize(std::size_t size, const T &x) { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + reserve(size); + for (std::size_t i = size_; i < size; ++i) { + new (&objs_[i]) T(x); + } + for (std::size_t i = size; i < size_; ++i) { + objs_[i].~T(); + } + size_ = size; + } + + void reserve(std::size_t capacity) { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + if (capacity <= capacity_) { + return; + } + MARISA_DEBUG_IF(capacity > max_size(), MARISA_SIZE_ERROR); + std::size_t new_capacity = capacity; + if (capacity_ > (capacity / 2)) { + if (capacity_ > (max_size() / 2)) { + new_capacity = max_size(); + } else { + new_capacity = capacity_ * 2; + } + } + realloc(new_capacity); + } + + void shrink() { + MARISA_THROW_IF(fixed_, MARISA_STATE_ERROR); + if (size_ != capacity_) { + realloc(size_); + } + } + + void fix() { + MARISA_THROW_IF(fixed_, MARISA_STATE_ERROR); + fixed_ = true; + } + + const T *begin() const { + return const_objs_; + } + const T *end() const { + return const_objs_ + size_; + } + const T &operator[](std::size_t i) const { + MARISA_DEBUG_IF(i >= size_, MARISA_BOUND_ERROR); + return const_objs_[i]; + } + const T &front() const { + MARISA_DEBUG_IF(size_ == 0, MARISA_STATE_ERROR); + return const_objs_[0]; + } + const T &back() const { + MARISA_DEBUG_IF(size_ == 0, MARISA_STATE_ERROR); + return const_objs_[size_ - 1]; + } + + T *begin() { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + return objs_; + } + T *end() { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + return objs_ + size_; + } + T &operator[](std::size_t i) { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + MARISA_DEBUG_IF(i >= size_, MARISA_BOUND_ERROR); + return objs_[i]; + } + T &front() { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + MARISA_DEBUG_IF(size_ == 0, MARISA_STATE_ERROR); + return objs_[0]; + } + T &back() { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + MARISA_DEBUG_IF(size_ == 0, MARISA_STATE_ERROR); + return objs_[size_ - 1]; + } + + std::size_t size() const { + return size_; + } + std::size_t capacity() const { + return capacity_; + } + bool fixed() const { + return fixed_; + } + + bool empty() const { + return size_ == 0; + } + std::size_t total_size() const { + return sizeof(T) * size_; + } + std::size_t io_size() const { + return sizeof(UInt64) + ((total_size() + 7) & ~(std::size_t)0x07); + } + + void clear() { + Vector().swap(*this); + } + void swap(Vector &rhs) { + buf_.swap(rhs.buf_); + marisa::swap(objs_, rhs.objs_); + marisa::swap(const_objs_, rhs.const_objs_); + marisa::swap(size_, rhs.size_); + marisa::swap(capacity_, rhs.capacity_); + marisa::swap(fixed_, rhs.fixed_); + } + + static std::size_t max_size() { + return MARISA_SIZE_MAX / sizeof(T); + } + + private: + scoped_array<char> buf_; + T *objs_; + const T *const_objs_; + std::size_t size_; + std::size_t capacity_; + bool fixed_; + + void map_(Mapper &mapper) { + UInt64 total_size; + mapper.map(&total_size); + MARISA_THROW_IF(total_size > MARISA_SIZE_MAX, MARISA_SIZE_ERROR); + MARISA_THROW_IF((total_size % sizeof(T)) != 0, MARISA_FORMAT_ERROR); + const std::size_t size = (std::size_t)(total_size / sizeof(T)); + mapper.map(&const_objs_, size); + mapper.seek((std::size_t)((8 - (total_size % 8)) % 8)); + size_ = size; + fix(); + } + void read_(Reader &reader) { + UInt64 total_size; + reader.read(&total_size); + MARISA_THROW_IF(total_size > MARISA_SIZE_MAX, MARISA_SIZE_ERROR); + MARISA_THROW_IF((total_size % sizeof(T)) != 0, MARISA_FORMAT_ERROR); + const std::size_t size = (std::size_t)(total_size / sizeof(T)); + resize(size); + reader.read(objs_, size); + reader.seek((std::size_t)((8 - (total_size % 8)) % 8)); + } + void write_(Writer &writer) const { + writer.write((UInt64)total_size()); + writer.write(const_objs_, size_); + writer.seek((8 - (total_size() % 8)) % 8); + } + + // realloc() assumes that T's placement new does not throw an exception. + void realloc(std::size_t new_capacity) { + MARISA_DEBUG_IF(new_capacity > max_size(), MARISA_SIZE_ERROR); + + scoped_array<char> new_buf( + new (std::nothrow) char[sizeof(T) * new_capacity]); + MARISA_DEBUG_IF(new_buf.get() == NULL, MARISA_MEMORY_ERROR); + T *new_objs = reinterpret_cast<T *>(new_buf.get()); + + for (std::size_t i = 0; i < size_; ++i) { + new (&new_objs[i]) T(objs_[i]); + } + for (std::size_t i = 0; i < size_; ++i) { + objs_[i].~T(); + } + + buf_.swap(new_buf); + objs_ = new_objs; + const_objs_ = new_objs; + capacity_ = new_capacity; + } + + // Disallows copy and assignment. + Vector(const Vector &); + Vector &operator=(const Vector &); +}; + +} // namespace vector +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_VECTOR_VECTOR_H_ |