#pragma once
#include <util/system/defaults.h>
#include <util/generic/yexception.h>
#include <util/generic/ptr.h>
#include <util/generic/vector.h>
#include <util/generic/algorithm.h>
#include <utility>
#include <queue>
#include "compressor.h"
namespace NCompProto {
// Fixed-size direct-mapped cache keyed by ui32: the slot for a key is
// Hash(key) and each slot remembers the last key/value stored there.
// An empty slot is represented by a key that hashes to a different slot, so a
// lookup of the form `CacheKey[Hash(k)] == k` can never match it and CacheVal
// does not need to be initialized.
template <size_t CacheSize, typename TEntry>
struct TCache {
    // With fewer than two slots no key hashes "elsewhere", so an empty slot
    // could not be represented (and CacheSize == 0 would divide by zero).
    static_assert(CacheSize >= 2, "TCache requires at least two slots");

    ui32 CacheKey[CacheSize];
    TEntry CacheVal[CacheSize];
    // Statistics counters; they are maintained by the users of the cache.
    size_t Hits;
    size_t Misses;

    // Maps a key to its slot index.
    ui32 Hash(ui32 key) const {
        return key % CacheSize;
    }

    TCache()
        : Hits(0)
        , Misses(0)
    {
        Clear();
    }

    // Marks every slot as empty. Hits and Misses are left untouched.
    void Clear() {
        for (size_t i = 0; i < CacheSize; ++i) {
            // Smallest key that does not hash to slot i: key 0 belongs to
            // slot 0, so slot 0 gets key 1 and every other slot gets key 0.
            CacheKey[i] = (i == 0) ? 1 : 0;
        }
    }
};
// One candidate code: a range of values starting at Start whose members are
// distinguished by Bits low bits, together with the prefix assigned to it.
// Ordering operators compare by Probability only.
struct TCode {
    i64 Probability; // weight of the range; used as the Huffman node weight
    ui32 Start;      // first value of the range
    ui32 Bits;       // number of value bits stored after the prefix
    ui32 Prefix;     // prefix bit pattern, filled in by THuffNode::BuildPrefixes
    ui32 PrefLength; // prefix length in bits, filled in by THuffNode::BuildPrefixes

    // Prefix and PrefLength start at zero so that copying a TCode before the
    // prefixes are built never reads indeterminate values.
    TCode(i64 probability = 0, ui32 start = 0, ui32 bits = 0)
        : Probability(probability)
        , Start(start)
        , Bits(bits)
        , Prefix(0)
        , PrefLength(0)
    {
    }

    bool operator<(const TCode& code) const {
        return Probability < code.Probability;
    }

    bool operator>(const TCode& code) const {
        return Probability > code.Probability;
    }
};
// Accumulates a weighted histogram of ui32 values and converts it into
// candidate code ranges. Values are stored in a 16-ary trie indexed by 4-bit
// nibbles, most significant nibble first.
struct TAccum {
    // One trie level: 16 cells, each with its own counter and an optional
    // child table refining that cell by the next nibble.
    struct TTable {
        TAutoPtr<TTable> Tables[16];
        i64 Counts[16];

        // Deep copy: child tables are cloned recursively.
        TTable(const TTable& other) {
            for (size_t i = 0; i < 16; ++i) {
                Counts[i] = other.Counts[i];
                if (other.Tables[i].Get()) {
                    Tables[i].Reset(new TTable(*other.Tables[i].Get()));
                }
            }
        }

        TTable() {
            for (auto& count : Counts)
                count = 0;
        }

        // Weight of cell i including everything stored below it.
        i64 GetCellCount(size_t i) {
            i64 count = Counts[i];
            if (Tables[i].Get()) {
                for (size_t j = 0; j < 16; ++j) {
                    count += Tables[i]->GetCellCount(j);
                }
            }
            return count;
        }

        // Weight of the whole table including all descendants.
        i64 GetCount() {
            i64 count = 0;
            for (size_t j = 0; j < 16; ++j) {
                count += GetCellCount(j);
            }
            return count;
        }

        // Emits codes for the cells of the tables located at depth termDepth.
        //   codes     - output list; each emitted pair is (count, code)
        //   depth     - number of value bits consumed on the way to this table
        //   termDepth - depth of the tables whose cells are examined
        //   code      - value bits accumulated on the way to this table
        //   cnt       - a cell is emitted only if its total count exceeds cnt
        // At termDepth every child table is folded into its parent cell; an
        // emitted cell covers 2^(28 - depth) values and its counter is zeroed
        // so that the weight is not reported again on a later pass.
        void GenerateFreqs(TVector<std::pair<i64, TCode>>& codes, int depth, int termDepth, ui32 code, i64 cnt) {
            if (depth == termDepth) {
                for (size_t i = 0; i < 16; ++i) {
                    i64 iCount = GetCellCount(i);
                    if (Tables[i].Get()) {
                        Counts[i] = iCount;
                        Tables[i].Reset(nullptr);
                    }
                    if (iCount > cnt || (termDepth == 0 && iCount > 0)) {
                        std::pair<i64, TCode> codep;
                        codep.first = iCount;
                        codep.second.Probability = iCount;
                        codep.second.Start = code + (i << (28 - depth));
                        codep.second.Bits = 28 - depth;
                        codes.push_back(codep);
                        Counts[i] = 0;
                    }
                }
            }
            for (size_t i = 0; i < 16; ++i) {
                if (Tables[i].Get()) {
                    Tables[i]->GenerateFreqs(codes, depth + 4, termDepth, code + (i << (28 - depth)), cnt);
                }
            }
        }
    };

    TTable Root;
    int TableCount; // number of child tables allocated so far (capped at 1024)
    i64 Total;      // sum of all weights passed to Add
    ui64 Max;       // largest value passed to Add

    TAccum() {
        TableCount = 0;
        Total = 0;
        Max = 0;
    }

    // Cache holds raw pointers into this object's own trie, so a copy must not
    // inherit them: it starts with an empty cache over its own cloned tables.
    TAccum(const TAccum& other)
        : Root(other.Root)
        , TableCount(other.TableCount)
        , Total(other.Total)
        , Max(other.Max)
    {
    }

    // Appends candidate codes to `codes`, working on a copy of the trie.
    // Tables are processed from the deepest level (28) up to level 4 with the
    // threshold Total / mul; whatever weight remains afterwards is reported as
    // one catch-all code starting at 0 and wide enough to hold Max.
    // `mul` must be non-zero.
    void GenerateFreqs(TVector<std::pair<i64, TCode>>& codes, int mul) const {
        TTable root(Root);
        for (int i = 28; i > 0; i -= 4) {
            root.GenerateFreqs(codes, 0, i, 0, Total / mul);
        }
        i64 iCount = root.GetCount();
        if (iCount == 0)
            return;
        std::pair<i64, TCode> codep;
        codep.first = iCount;
        codep.second.Probability = iCount;
        codep.second.Start = 0;
        // Smallest bit width whose range [0, 2^bits) contains Max.
        ui32 bits = 0;
        while ((1ULL << bits) <= Max) {
            ++bits;
        }
        codep.second.Bits = bits;
        codes.push_back(codep);
    }

    // Maps a recently added value to the counter it was accumulated into.
    TCache<256, i64*> Cache;

    // Adds `weight` to the counter for `value`, creating trie tables on demand.
    // Once 1024 tables exist no more are created and the weight is accumulated
    // in the deepest existing cell on the value's path.
    void AddMap(ui32 value, i64 weight = 1) {
        const ui32 index = Cache.Hash(value);
        if (Cache.CacheKey[index] == value) {
            Cache.CacheVal[index][0] += weight;
            return;
        }
        TTable* root = &Root;
        // A 32-bit value has eight nibbles: seven descents consume the nibbles
        // at shifts 28..4 and the last nibble selects the leaf counter. The
        // shift amount therefore always stays within the width of ui32.
        for (ui32 shift = 28; shift > 0; shift -= 4) {
            const ui32 index2 = (value >> shift) & 0xf;
            if (!root->Tables[index2].Get()) {
                if (TableCount < 1024) {
                    ++TableCount;
                    root->Tables[index2].Reset(new TTable);
                } else {
                    // Remember the counter in the slot the lookup above uses.
                    Cache.CacheKey[index] = value;
                    Cache.CacheVal[index] = &root->Counts[index2];
                    root->Counts[index2] += weight;
                    return;
                }
            }
            root = root->Tables[index2].Get();
        }
        Cache.CacheKey[index] = value;
        Cache.CacheVal[index] = &root->Counts[value & 0xf];
        root->Counts[value & 0xf] += weight;
    }

    // Records one occurrence of `value` with the given weight and updates the
    // running Total and Max.
    void Add(ui32 value, i64 weight = 1) {
        Max = ::Max(Max, (ui64)value);
        Total += weight;
        AddMap(value, weight);
    }
};
struct THuffNode {
i64 Weight;
i64 Priority;
THuffNode* Nodes[2];
TCode* Code;
THuffNode(i64 weight, i64 priority, TCode* code)
: Weight(weight)
, Priority(priority)
, Code(code)
{
Nodes[0] = nullptr;
Nodes[1] = nullptr;
}
void BuildPrefixes(ui32 depth, ui32 prefix) {
if (Code) {
Code->Prefix = prefix;
Code->PrefLength = depth;
return;
}
Nodes[0]->BuildPrefixes(depth + 1, prefix + (0UL << depth));
Nodes[1]->BuildPrefixes(depth + 1, prefix + (1UL << depth));
}
i64 Iterate(size_t depth) const {
if (Code) {
return (depth + Code->Bits) * Code->Probability;
}
return Nodes[0]->Iterate(depth + 1) + Nodes[1]->Iterate(depth + 1);
}
size_t Depth() const {
if (Code) {
return 0;
}
return Max(Nodes[0]->Depth(), Nodes[1]->Depth()) + 1;
}
};
// Ordering for the Huffman priority queue: returns true when `a` is heavier
// than `b`, or equally heavy with a larger Priority. Used with
// std::priority_queue this puts the lightest node (lowest Priority on ties)
// on top.
struct THLess {
    bool operator()(const THuffNode* a, const THuffNode* b) {
        if (a->Weight != b->Weight) {
            return a->Weight > b->Weight;
        }
        return a->Priority > b->Priority;
    }
};
// Builds a Huffman tree over `codes` (weights taken from Probability) and
// stores the resulting Prefix / PrefLength in each element.
// Returns the value computed by THuffNode::Iterate for the root, or 0 when
// `codes` is empty.
inline i64 BuildHuff(TVector<TCode>& codes) {
    // Owns every node for the duration of the call; the queue only borrows.
    TVector<TSimpleSharedPtr<THuffNode>> owned;
    std::priority_queue<THuffNode*, TVector<THuffNode*>, THLess> heap;
    int nextPriority = 0;

    // One leaf per code, numbered in input order.
    for (TCode& code : codes) {
        TSimpleSharedPtr<THuffNode> leaf(new THuffNode(code.Probability, nextPriority++, &code));
        owned.push_back(leaf);
        heap.push(leaf.Get());
    }

    // Repeatedly join the two nodes on top of the queue under a new parent.
    while (heap.size() > 1) {
        THuffNode* first = heap.top();
        heap.pop();
        THuffNode* second = heap.top();
        heap.pop();

        TSimpleSharedPtr<THuffNode> parent(new THuffNode(first->Weight + second->Weight, nextPriority++, nullptr));
        parent->Nodes[0] = first;
        parent->Nodes[1] = second;
        owned.push_back(parent);
        heap.push(parent.Get());
    }

    if (heap.empty()) {
        return 0;
    }
    THuffNode* root = heap.top();
    root->BuildPrefixes(0, 0);
    return root->Iterate(0);
}
// Compact description of one code: values in [MinValue, MaxValue()) are
// written as a Prefix of PrefixBits bits followed by the offset from MinValue;
// AllBits is the total length of the code.
struct TCoderEntry {
    ui32 MinValue;
    ui16 Prefix;
    ui8 PrefixBits;
    ui8 AllBits;

    // Exclusive upper bound of the covered range:
    // MinValue + 2^(AllBits - PrefixBits).
    ui64 MaxValue() const {
        const int payloadBits = AllBits - PrefixBits;
        const ui64 rangeSize = 1ULL << payloadBits;
        return MinValue + rangeSize;
    }
};
// Derives a coding table from the statistics in `acc`.
// Tries divisors 256, 255, ..., 1 for TAccum::GenerateFreqs and keeps the
// first table in which every prefix is at most 6 bits long. `retCodes` is
// overwritten with the table of the last attempt.
// Returns the value BuildHuff reported for that attempt.
inline i64 Analyze(const TAccum& acc, TVector<TCoderEntry>& retCodes) {
    i64 encodedBits = 0;
    for (int mul = 256; mul > 0; --mul) {
        retCodes.clear();

        TVector<std::pair<i64, TCode>> freqs;
        acc.GenerateFreqs(freqs, mul);

        TVector<TCode> codes;
        for (const auto& freq : freqs) {
            codes.push_back(freq.second);
        }
        // Heaviest codes first; equal weights keep their generation order.
        StableSort(codes.begin(), codes.end(), std::greater<TCode>());
        encodedBits = BuildHuff(codes);

        bool prefixesFit = true;
        for (const TCode& src : codes) {
            TCoderEntry entry;
            entry.MinValue = src.Start;
            entry.Prefix = src.Prefix;
            entry.PrefixBits = src.PrefLength;
            entry.AllBits = entry.PrefixBits + src.Bits;
            if (entry.PrefixBits > 6) {
                prefixesFit = false;
            }
            retCodes.push_back(entry);
        }
        if (prefixesFit) {
            break;
        }
    }
    return encodedBits;
}
// Orders coder entries by total code length, shortest first.
struct TComparer {
    bool operator()(const TCoderEntry& e0, const TCoderEntry& e1) const {
        return e1.AllBits > e0.AllBits;
    }
};
struct TCoder {
TVector<TCoderEntry> Entries;
void Normalize() {
TComparer comp;
StableSort(Entries.begin(), Entries.end(), comp);
}
TCoder() {
InitDefault();
}
void InitDefault() {
ui64 cum = 0;
Cache.Clear();
Entries.clear();
ui16 b = 1;
for (ui16 i = 0; i < 40; ++i) {
ui16 bits = Min(b, (ui16)(32));
b = (b * 16) / 10 + 1;
if (b > 32)
b = 32;
TCoderEntry entry;
entry.PrefixBits = i + 1;
entry.AllBits = entry.PrefixBits + bits;
entry.MinValue = (ui32)Min(cum, (ui64)(ui32)(-1));
cum += (1ULL << bits);
entry.Prefix = ((1UL << i) - 1);
Entries.push_back(entry);
if (cum > (ui32)(-1)) {
return;
}
}
}
TCache<1024, TCoderEntry> Cache;
ui64 RealCode(ui32 value, const TCoderEntry& entry, size_t& length) {
length = entry.AllBits;
return (ui64(value - entry.MinValue) << entry.PrefixBits) + entry.Prefix;
}
bool Empty() const {
return Entries.empty();
}
const TCoderEntry& GetEntry(ui32 code, ui8& id) const {
for (size_t i = 0; i < Entries.size(); ++i) {
const TCoderEntry& entry = Entries[i];
ui32 prefMask = (1UL << entry.PrefixBits) - 1UL;
if (entry.Prefix == (code & prefMask)) {
id = ui8(i);
return entry;
}
}
ythrow yexception() << "bad entry";
return Entries[0];
}
ui64 Code(ui32 entry, size_t& length) {
ui32 index = Cache.Hash(entry);
if (Cache.CacheKey[index] == entry) {
++Cache.Hits;
return RealCode(entry, Cache.CacheVal[index], length);
}
++Cache.Misses;
for (size_t i = 0; i < Entries.size(); ++i) {
if (entry >= Entries[i].MinValue && entry < Entries[i].MaxValue()) {
Cache.CacheKey[index] = entry;
Cache.CacheVal[index] = Entries[i];
return RealCode(entry, Cache.CacheVal[index], length);
}
}
ythrow yexception() << "bad huff tree";
return 0;
}
};
}