| author    | Devtools Arcadia <arcadia-devtools@yandex-team.ru>           | 2022-02-07 18:08:42 +0300 |
|-----------|---------------------------------------------------------------|---------------------------|
| committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net>  | 2022-02-07 18:08:42 +0300 |
| commit    | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)              |                           |
| tree      | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/compproto/huff.h |                  |
| download  | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz           |                           |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/compproto/huff.h')
-rw-r--r-- | library/cpp/compproto/huff.h | 402 |
1 files changed, 402 insertions, 0 deletions
```diff
diff --git a/library/cpp/compproto/huff.h b/library/cpp/compproto/huff.h
new file mode 100644
index 00000000000..fa5c139189d
--- /dev/null
+++ b/library/cpp/compproto/huff.h
@@ -0,0 +1,402 @@
+#pragma once
+
+#include <util/system/defaults.h>
+#include <util/generic/yexception.h>
+#include <util/generic/ptr.h>
+#include <util/generic/vector.h>
+#include <util/generic/algorithm.h>
+#include <utility>
+
+#include <queue>
+
+#include "compressor.h"
+
+namespace NCompProto {
+    template <size_t CacheSize, typename TEntry>
+    struct TCache {
+        ui32 CacheKey[CacheSize];
+        TEntry CacheVal[CacheSize];
+        size_t Hits;
+        size_t Misses;
+        ui32 Hash(ui32 key) {
+            return key % CacheSize;
+        }
+        TCache() {
+            Hits = 0;
+            Misses = 0;
+            Clear();
+        }
+        void Clear() {
+            for (size_t i = 0; i < CacheSize; ++i) {
+                ui32 j = 0;
+                for (; Hash(j) == i; ++j)
+                    ;
+                CacheKey[i] = j;
+            }
+        }
+    };
+
+    struct TCode {
+        i64 Probability;
+        ui32 Start;
+        ui32 Bits;
+        ui32 Prefix;
+        ui32 PrefLength;
+        TCode(i64 probability = 0, ui32 start = 0, ui32 bits = 0)
+            : Probability(probability)
+            , Start(start)
+            , Bits(bits)
+        {
+        }
+
+        bool operator<(const TCode& code) const {
+            return Probability < code.Probability;
+        }
+
+        bool operator>(const TCode& code) const {
+            return Probability > code.Probability;
+        }
+    };
+
+    struct TAccum {
+        struct TTable {
+            TAutoPtr<TTable> Tables[16];
+            i64 Counts[16];
+            TTable(const TTable& other) {
+                for (size_t i = 0; i < 16; ++i) {
+                    Counts[i] = other.Counts[i];
+                    if (other.Tables[i].Get()) {
+                        Tables[i].Reset(new TTable(*other.Tables[i].Get()));
+                    }
+                }
+            }
+            TTable() {
+                for (auto& count : Counts)
+                    count = 0;
+            }
+
+            i64 GetCellCount(size_t i) {
+                i64 count = Counts[i];
+                if (Tables[i].Get()) {
+                    for (size_t j = 0; j < 16; ++j) {
+                        count += Tables[i]->GetCellCount(j);
+                    }
+                }
+                return count;
+            }
+
+            i64 GetCount() {
+                i64 count = 0;
+                for (size_t j = 0; j < 16; ++j) {
+                    count += GetCellCount(j);
+                }
+                return count;
+            }
+
+            void GenerateFreqs(TVector<std::pair<i64, TCode>>& codes, int depth, int termDepth, ui32 code, i64 cnt) {
+                if (depth == termDepth) {
+                    for (size_t i = 0; i < 16; ++i) {
+                        i64 iCount = GetCellCount(i);
+                        if (Tables[i].Get()) {
+                            Counts[i] = iCount;
+                            Tables[i].Reset(nullptr);
+                        }
+
+                        if (iCount > cnt || (termDepth == 0 && iCount > 0)) {
+                            std::pair<i64, TCode> codep;
+                            codep.first = iCount;
+                            codep.second.Probability = iCount;
+                            codep.second.Start = code + (i << (28 - depth));
+                            codep.second.Bits = 28 - depth;
+                            codes.push_back(codep);
+                            Counts[i] = 0;
+                        }
+                    }
+                }
+                for (size_t i = 0; i < 16; ++i) {
+                    if (Tables[i].Get()) {
+                        Tables[i]->GenerateFreqs(codes, depth + 4, termDepth, code + (i << (28 - depth)), cnt);
+                    }
+                }
+            }
+        };
+
+        TTable Root;
+        int TableCount;
+        i64 Total;
+        ui64 Max;
+
+        TAccum() {
+            TableCount = 0;
+            Total = 0;
+            Max = 0;
+        }
+
+        void GenerateFreqs(TVector<std::pair<i64, TCode>>& codes, int mul) const {
+            TTable root(Root);
+
+            for (int i = 28; i > 0; i -= 4) {
+                root.GenerateFreqs(codes, 0, i, 0, Total / mul);
+            }
+
+            i64 iCount = root.GetCount();
+            if (iCount == 0)
+                return;
+            std::pair<i64, TCode> codep;
+            codep.first = iCount;
+            codep.second.Probability = iCount;
+            codep.second.Start = 0;
+            ui32 bits = 0;
+            while (1) {
+                if ((1ULL << bits) > Max)
+                    break;
+                ++bits;
+            }
+            codep.second.Bits = bits;
+            codes.push_back(codep);
+        }
+
+        TCache<256, i64*> Cache;
+
+        void AddMap(ui32 value, i64 weight = 1) {
+            ui32 index = Cache.Hash(value);
+            if (Cache.CacheKey[index] == value) {
+                Cache.CacheVal[index][0] += weight;
+                return;
+            }
+            TTable* root = &Root;
+            for (size_t i = 0; i < 15; ++i) {
+                ui32 index2 = (value >> (28 - i * 4)) & 0xf;
+                if (!root->Tables[index2].Get()) {
+                    if (TableCount < 1024) {
+                        ++TableCount;
+                        root->Tables[index2].Reset(new TTable);
+                    } else {
+                        Cache.CacheKey[index2] = value;
+                        Cache.CacheVal[index2] = &root->Counts[index2];
+                        root->Counts[index2] += weight;
+                        return;
+                    }
+                }
+                root = root->Tables[index2].Get();
+            }
+
+            Cache.CacheKey[index] = value;
+            Cache.CacheVal[index] = &root->Counts[value & 0xf];
+            root->Counts[value & 0xf] += weight;
+        }
+
+        void Add(ui32 value, i64 weight = 1) {
+            Max = ::Max(Max, (ui64)value);
+            Total += weight;
+            AddMap(value, weight);
+        };
+    };
+
+    struct THuffNode {
+        i64 Weight;
+        i64 Priority;
+        THuffNode* Nodes[2];
+        TCode* Code;
+        THuffNode(i64 weight, i64 priority, TCode* code)
+            : Weight(weight)
+            , Priority(priority)
+            , Code(code)
+        {
+            Nodes[0] = nullptr;
+            Nodes[1] = nullptr;
+        }
+
+        void BuildPrefixes(ui32 depth, ui32 prefix) {
+            if (Code) {
+                Code->Prefix = prefix;
+                Code->PrefLength = depth;
+                return;
+            }
+            Nodes[0]->BuildPrefixes(depth + 1, prefix + (0UL << depth));
+            Nodes[1]->BuildPrefixes(depth + 1, prefix + (1UL << depth));
+        }
+
+        i64 Iterate(size_t depth) const {
+            if (Code) {
+                return (depth + Code->Bits) * Code->Probability;
+            }
+            return Nodes[0]->Iterate(depth + 1) + Nodes[1]->Iterate(depth + 1);
+        }
+
+        size_t Depth() const {
+            if (Code) {
+                return 0;
+            }
+            return Max(Nodes[0]->Depth(), Nodes[1]->Depth()) + 1;
+        }
+    };
+
+    struct THLess {
+        bool operator()(const THuffNode* a, const THuffNode* b) {
+            if (a->Weight > b->Weight)
+                return 1;
+            if (a->Weight == b->Weight && a->Priority > b->Priority)
+                return 1;
+            return 0;
+        }
+    };
+
+    inline i64 BuildHuff(TVector<TCode>& codes) {
+        TVector<TSimpleSharedPtr<THuffNode>> hold;
+        std::priority_queue<THuffNode*, TVector<THuffNode*>, THLess> nodes;
+        i64 ret = 0;
+
+        int priority = 0;
+        for (size_t i = 0; i < codes.size(); ++i) {
+            TSimpleSharedPtr<THuffNode> node(new THuffNode(codes[i].Probability, priority++, &codes[i]));
+            hold.push_back(node);
+            nodes.push(node.Get());
+        }
+
+        while (nodes.size() > 1) {
+            THuffNode* nodea = nodes.top();
+            nodes.pop();
+            THuffNode* nodeb = nodes.top();
+            nodes.pop();
+            TSimpleSharedPtr<THuffNode> node(new THuffNode(nodea->Weight + nodeb->Weight, priority++, nullptr));
+            node->Nodes[0] = nodea;
+            node->Nodes[1] = nodeb;
+            hold.push_back(node);
+            nodes.push(node.Get());
+        }
+
+        if (nodes.size()) {
+            THuffNode* node = nodes.top();
+            node->BuildPrefixes(0, 0);
+            ret = node->Iterate(0);
+        }
+
+        return ret;
+    };
+
+    struct TCoderEntry {
+        ui32 MinValue;
+        ui16 Prefix;
+        ui8 PrefixBits;
+        ui8 AllBits;
+
+        ui64 MaxValue() const {
+            return MinValue + (1ULL << (AllBits - PrefixBits));
+        }
+    };
+
+    inline i64 Analyze(const TAccum& acc, TVector<TCoderEntry>& retCodes) {
+        i64 ret;
+        for (int k = 256; k > 0; --k) {
+            retCodes.clear();
+            TVector<std::pair<i64, TCode>> pairs;
+            acc.GenerateFreqs(pairs, k);
+            TVector<TCode> codes;
+            for (size_t i = 0; i < pairs.size(); ++i) {
+                codes.push_back(pairs[i].second);
+            }
+
+            StableSort(codes.begin(), codes.end(), std::greater<TCode>());
+
+            ret = BuildHuff(codes);
+            bool valid = true;
+            for (size_t i = 0; i < codes.size(); ++i) {
+                TCoderEntry code;
+                code.MinValue = codes[i].Start;
+                code.Prefix = codes[i].Prefix;
+                code.PrefixBits = codes[i].PrefLength;
+                if (code.PrefixBits > 6)
+                    valid = false;
+                code.AllBits = code.PrefixBits + codes[i].Bits;
+                retCodes.push_back(code);
+            }
+            if (valid)
+                return ret;
+        }
+
+        return ret;
+    }
+
+    struct TComparer {
+        bool operator()(const TCoderEntry& e0, const TCoderEntry& e1) const {
+            return e0.AllBits < e1.AllBits;
+        }
+    };
+
+    struct TCoder {
+        TVector<TCoderEntry> Entries;
+        void Normalize() {
+            TComparer comp;
+            StableSort(Entries.begin(), Entries.end(), comp);
+        }
+        TCoder() {
+            InitDefault();
+        }
+        void InitDefault() {
+            ui64 cum = 0;
+            Cache.Clear();
+            Entries.clear();
+            ui16 b = 1;
+            for (ui16 i = 0; i < 40; ++i) {
+                ui16 bits = Min(b, (ui16)(32));
+                b = (b * 16) / 10 + 1;
+                if (b > 32)
+                    b = 32;
+                TCoderEntry entry;
+                entry.PrefixBits = i + 1;
+                entry.AllBits = entry.PrefixBits + bits;
+                entry.MinValue = (ui32)Min(cum, (ui64)(ui32)(-1));
+                cum += (1ULL << bits);
+                entry.Prefix = ((1UL << i) - 1);
+                Entries.push_back(entry);
+                if (cum > (ui32)(-1)) {
+                    return;
+                }
+            };
+        }
+
+        TCache<1024, TCoderEntry> Cache;
+
+        ui64 RealCode(ui32 value, const TCoderEntry& entry, size_t& length) {
+            length = entry.AllBits;
+            return (ui64(value - entry.MinValue) << entry.PrefixBits) + entry.Prefix;
+        }
+
+        bool Empty() const {
+            return Entries.empty();
+        }
+        const TCoderEntry& GetEntry(ui32 code, ui8& id) const {
+            for (size_t i = 0; i < Entries.size(); ++i) {
+                const TCoderEntry& entry = Entries[i];
+                ui32 prefMask = (1UL << entry.PrefixBits) - 1UL;
+                if (entry.Prefix == (code & prefMask)) {
+                    id = ui8(i);
+                    return entry;
+                }
+            }
+            ythrow yexception() << "bad entry";
+            return Entries[0];
+        }
+
+        ui64 Code(ui32 entry, size_t& length) {
+            ui32 index = Cache.Hash(entry);
+            if (Cache.CacheKey[index] == entry) {
+                ++Cache.Hits;
+                return RealCode(entry, Cache.CacheVal[index], length);
+            }
+            ++Cache.Misses;
+            for (size_t i = 0; i < Entries.size(); ++i) {
+                if (entry >= Entries[i].MinValue && entry < Entries[i].MaxValue()) {
+                    Cache.CacheKey[index] = entry;
+                    Cache.CacheVal[index] = Entries[i];
+                    return RealCode(entry, Cache.CacheVal[index], length);
+                }
+            }
+
+            ythrow yexception() << "bad huff tree";
+            return 0;
+        }
+    };
+
+}
```
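
For readers of this diff, here is a minimal usage sketch of the API the new header introduces; it is not part of the commit. It assumes an Arcadia-style build where `library/cpp/compproto` and the `util` headers are on the include path, and the value stream in `main` is made up for illustration: `TAccum::Add` gathers frequency statistics, `Analyze` turns them into a table of `TCoderEntry` prefix codes, and `TCoder::Code` emits the code bits and bit length for a concrete value.

```cpp
// Hypothetical example, not from the commit: exercises TAccum, Analyze and TCoder.
#include <library/cpp/compproto/huff.h>

#include <util/generic/vector.h>
#include <util/stream/output.h>

int main() {
    using namespace NCompProto;

    // 1. Accumulate frequency statistics over a skewed, made-up value stream.
    TAccum accum;
    for (ui32 i = 0; i < 1000; ++i) {
        accum.Add(i % 16); // small values dominate, so they should get short codes
    }

    // 2. Derive a prefix-code table from the accumulated statistics.
    TVector<TCoderEntry> entries;
    const i64 estimatedBits = Analyze(accum, entries);

    // 3. Encode a value with the derived table.
    TCoder coder;            // the constructor installs a default table...
    coder.Entries = entries; // ...which we replace with the analyzed one
    coder.Normalize();       // keep entries sorted by total code length

    size_t length = 0;
    const ui64 code = coder.Code(7, length);

    Cout << "value 7 -> code bits " << code << ", length " << length
         << " (estimated stream size: " << estimatedBits << " bits)" << Endl;
    return 0;
}
```

The same `entries` table is what the rest of `compproto` would persist alongside the compressed stream, since `TCoder::GetEntry` needs it to decode the prefixes later.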