#include "huffman_codec.h"
#include <library/cpp/bit_io/bitinput.h>
#include <library/cpp/bit_io/bitoutput.h>

#include <util/generic/algorithm.h>
#include <util/generic/bitops.h>
#include <util/stream/buffer.h>
#include <util/stream/length.h>
#include <util/string/printf.h>

namespace NCodecs {
    template <typename T>
    struct TCanonicalCmp {
        bool operator()(const T& a, const T& b) const {
            if (a.CodeLength == b.CodeLength) {
                return a.Char < b.Char;
            } else {
                return a.CodeLength < b.CodeLength;
            }
        }
    };

    template <typename T>
    struct TByCharCmp {
        bool operator()(const T& a, const T& b) const {
            return a.Char < b.Char;
        }
    };

    struct TTreeEntry {
        static const ui32 InvalidBranch = (ui32)-1;

        ui64 Freq = 0;
        ui32 Branches[2]{InvalidBranch, InvalidBranch};

        ui32 CodeLength = 0;
        ui8 Char = 0;
        bool Invalid = false;

        TTreeEntry() = default;

        static bool ByFreq(const TTreeEntry& a, const TTreeEntry& b) {
            return a.Freq < b.Freq;
        }

        static bool ByFreqRev(const TTreeEntry& a, const TTreeEntry& b) {
            return a.Freq > b.Freq;
        }
    };

    using TCodeTree = TVector<TTreeEntry>;

    void InitTreeByFreqs(TCodeTree& tree, const ui64 freqs[256]) {
        tree.reserve(255 * 256 / 2); // worst case - balanced tree

        for (ui32 i = 0; i < 256; ++i) {
            tree.emplace_back();
            tree.back().Char = i;
            tree.back().Freq = freqs[i];
        }

        StableSort(tree.begin(), tree.end(), TTreeEntry::ByFreq);
    }

    void InitTree(TCodeTree& tree, ISequenceReader* in) {
        using namespace NPrivate;
        ui64 freqs[256];
        Zero(freqs);

        TStringBuf r;
        while (in->NextRegion(r)) {
            for (ui64 i = 0; i < r.size(); ++i)
                ++freqs[(ui8)r[i]];
        }

        InitTreeByFreqs(tree, freqs);
    }

    void CalculateCodeLengths(TCodeTree& tree) {
        Y_ENSURE(tree.size() == 256, " ");
        const ui32 firstbranch = tree.size();

        ui32 curleaf = 0;
        ui32 curbranch = firstbranch;

        // building code tree. two priority queues are combined in one.
        while (firstbranch - curleaf + tree.size() - curbranch >= 2) {
            TTreeEntry e;

            for (auto& branche : e.Branches) {
                ui32 br;

                if (curleaf >= firstbranch)
                    br = curbranch++;
                else if (curbranch >= tree.size())
                    br = curleaf++;
                else if (tree[curleaf].Freq < tree[curbranch].Freq)
                    br = curleaf++;
                else
                    br = curbranch++;

                Y_ENSURE(br < tree.size(), " ");
                branche = br;
                e.Freq += tree[br].Freq;
            }

            tree.push_back(e);
            PushHeap(tree.begin() + curbranch, tree.end(), TTreeEntry::ByFreqRev);
        }

        // computing code lengths
        for (ui64 i = tree.size() - 1; i >= firstbranch; --i) {
            TTreeEntry e = tree[i];

            for (auto branche : e.Branches)
                tree[branche].CodeLength = e.CodeLength + 1;
        }

        // chopping off the branches
        tree.resize(firstbranch);

        Sort(tree.begin(), tree.end(), TCanonicalCmp<TTreeEntry>());

        // simplification: we are stripping codes longer than 64 bits
        while (!tree.empty() && tree.back().CodeLength > 64)
            tree.pop_back();

        // will not compress
        if (tree.empty())
            return;

        // special invalid code word
        tree.back().Invalid = true;
    }

    struct TEncoderEntry {
        ui64 Code = 0;

        ui8 CodeLength = 0;
        ui8 Char = 0;
        ui8 Invalid = true;

        explicit TEncoderEntry(TTreeEntry e)
            : CodeLength(e.CodeLength)
            , Char(e.Char)
            , Invalid(e.Invalid)
        {
        }

        TEncoderEntry() = default;
    };

    struct TEncoderTable {
        TEncoderEntry Entries[256];

        void Save(IOutputStream* out) const {
            ui16 nval = 0;

            for (auto entrie : Entries)
                nval += !entrie.Invalid;

            ::Save(out, nval);

            for (auto entrie : Entries) {
                if (!entrie.Invalid) {
                    ::Save(out, entrie.Char);
                    ::Save(out, entrie.CodeLength);
                }
            }
        }

        void Load(IInputStream* in) {
            ui16 nval = 0;
            ::Load(in, nval);

            for (ui32 i = 0; i < 256; ++i)
                Entries[i].Char = i;

            for (ui32 i = 0; i < nval; ++i) {
                ui8 ch = 0;
                ui8 len = 0;
                ::Load(in, ch);
                ::Load(in, len);
                Entries[ch].CodeLength = len;
                Entries[ch].Invalid = false;
            }
        }
    };

    struct TDecoderEntry {
        ui32 NextTable : 10;
        ui32 Char : 8;
        ui32 Invalid : 1;
        ui32 Bad : 1;

        TDecoderEntry()
            : NextTable()
            , Char()
            , Invalid()
            , Bad()
        {
        }
    };

    struct TDecoderTable: public TIntrusiveListItem<TDecoderTable> {
        ui64 Length = 0;
        ui64 BaseCode = 0;

        TDecoderEntry Entries[256];

        TDecoderTable() {
            Zero(Entries);
        }
    };

    const int CACHE_BITS_COUNT = 16;
    class THuffmanCodec::TImpl: public TAtomicRefCount<TImpl> {
        TEncoderTable Encoder;
        TDecoderTable Decoder[256];

        TEncoderEntry Invalid;

        ui32 SubTablesNum;

        class THuffmanCache {
            struct TCacheEntry {
                int EndOffset : 24;
                int BitsLeft : 8;
            };
            TVector<char> DecodeCache;
            TVector<TCacheEntry> CacheEntries;
            const TImpl& Original;

        public:
            THuffmanCache(const THuffmanCodec::TImpl& encoder);

            void Decode(NBitIO::TBitInput& in, TBuffer& out) const;
        };

        THolder<THuffmanCache> Cache;

    public:
        TImpl()
            : SubTablesNum(1)
        {
            Invalid.CodeLength = 255;
        }

        ui8 Encode(TStringBuf in, TBuffer& out) const {
            out.Clear();

            if (in.empty()) {
                return 0;
            }

            out.Reserve(in.size() * 2);

            {
                NBitIO::TBitOutputVector<TBuffer> bout(&out);
                TStringBuf tin = in;

                // data is under compression
                bout.Write(1, 1);

                for (auto t : tin) {
                    const TEncoderEntry& ce = Encoder.Entries[(ui8)t];

                    bout.Write(ce.Code, ce.CodeLength);

                    if (ce.Invalid) {
                        bout.Write(t, 8);
                    }
                }

                // in canonical huffman coding there cannot be a code having no 0 in the suffix
                // and shorter than 8 bits.
                bout.Write((ui64)-1, bout.GetByteReminder());
                return bout.GetByteReminder();
            }
        }

        void Decode(TStringBuf in, TBuffer& out) const {
            out.Clear();

            if (in.empty()) {
                return;
            }

            NBitIO::TBitInput bin(in);
            ui64 f = 0;
            bin.ReadK<1>(f);

            // if data is uncompressed
            if (!f) {
                in.Skip(1);
                out.Append(in.data(), in.size());
            } else {
                out.Reserve(in.size() * 8);

                if (Cache.Get()) {
                    Cache->Decode(bin, out);
                } else {
                    while (ReadNextChar(bin, out)) {
                    }
                }
            }
        }

        Y_FORCE_INLINE int ReadNextChar(NBitIO::TBitInput& bin, TBuffer& out) const {
            const TDecoderTable* table = Decoder;
            TDecoderEntry e;

            int bitsRead = 0;
            while (true) {
                ui64 code = 0;

                if (Y_UNLIKELY(!bin.Read(code, table->Length)))
                    return 0;
                bitsRead += table->Length;

                if (Y_UNLIKELY(code < table->BaseCode))
                    return 0;

                code -= table->BaseCode;

                if (Y_UNLIKELY(code > 255))
                    return 0;

                e = table->Entries[code];

                if (Y_UNLIKELY(e.Bad))
                    return 0;

                if (e.NextTable) {
                    table = Decoder + e.NextTable;
                } else {
                    if (e.Invalid) {
                        code = 0;
                        bin.ReadK<8>(code);
                        bitsRead += 8;
                        out.Append((ui8)code);
                    } else {
                        out.Append((ui8)e.Char);
                    }

                    return bitsRead;
                }
            }

            Y_ENSURE(false, " could not decode input");
            return 0;
        }

        void GenerateEncoder(TCodeTree& tree) {
            const ui64 sz = tree.size();

            TEncoderEntry lastcode = Encoder.Entries[tree[0].Char] = TEncoderEntry(tree[0]);

            for (ui32 i = 1; i < sz; ++i) {
                const TTreeEntry& te = tree[i];
                TEncoderEntry& e = Encoder.Entries[te.Char];
                e = TEncoderEntry(te);

                e.Code = (lastcode.Code + 1) << (e.CodeLength - lastcode.CodeLength);
                lastcode = e;

                e.Code = ReverseBits(e.Code, e.CodeLength);

                if (e.Invalid)
                    Invalid = e;
            }

            for (auto& e : Encoder.Entries) {
                if (e.Invalid)
                    e = Invalid;

                Y_ENSURE(e.CodeLength, " ");
            }
        }

        void RegenerateEncoder() {
            for (auto& entrie : Encoder.Entries) {
                if (entrie.Invalid)
                    entrie.CodeLength = Invalid.CodeLength;
            }

            Sort(Encoder.Entries, Encoder.Entries + 256, TCanonicalCmp<TEncoderEntry>());

            TEncoderEntry lastcode = Encoder.Entries[0];

            for (ui32 i = 1; i < 256; ++i) {
                TEncoderEntry& e = Encoder.Entries[i];
                e.Code = (lastcode.Code + 1) << (e.CodeLength - lastcode.CodeLength);
                lastcode = e;

                e.Code = ReverseBits(e.Code, e.CodeLength);
            }

            for (auto& entrie : Encoder.Entries) {
                if (entrie.Invalid) {
                    Invalid = entrie;
                    break;
                }
            }

            Sort(Encoder.Entries, Encoder.Entries + 256, TByCharCmp<TEncoderEntry>());

            for (auto& entrie : Encoder.Entries) {
                if (entrie.Invalid)
                    entrie = Invalid;
            }
        }

        void BuildDecoder() {
            TEncoderTable enc = Encoder;
            Sort(enc.Entries, enc.Entries + 256, TCanonicalCmp<TEncoderEntry>());

            TEncoderEntry& e1 = enc.Entries[0];
            Decoder[0].BaseCode = e1.Code;
            Decoder[0].Length = e1.CodeLength;

            for (auto e2 : enc.Entries) {
                SetEntry(Decoder, e2.Code, e2.CodeLength, e2);
            }
            Cache.Reset(new THuffmanCache(*this));
        }

        void SetEntry(TDecoderTable* t, ui64 code, ui64 len, TEncoderEntry e) {
            Y_ENSURE(len >= t->Length, len << " < " << t->Length);

            ui64 idx = (code & MaskLowerBits(t->Length)) - t->BaseCode;
            TDecoderEntry& d = t->Entries[idx];

            if (len == t->Length) {
                Y_ENSURE(!d.NextTable, " ");

                d.Char = e.Char;
                d.Invalid = e.Invalid;
                return;
            }

            if (!d.NextTable) {
                Y_ENSURE(SubTablesNum < Y_ARRAY_SIZE(Decoder), " ");
                d.NextTable = SubTablesNum++;
                TDecoderTable* nt = Decoder + d.NextTable;
                nt->Length = Min<ui64>(8, len - t->Length);
                nt->BaseCode = (code >> t->Length) & MaskLowerBits(nt->Length);
            }

            SetEntry(Decoder + d.NextTable, code >> t->Length, len - t->Length, e);
        }

        void Learn(ISequenceReader* in) {
            {
                TCodeTree tree;
                InitTree(tree, in);
                CalculateCodeLengths(tree);
                Y_ENSURE(!tree.empty(), " ");
                GenerateEncoder(tree);
            }
            BuildDecoder();
        }

        void LearnByFreqs(const TArrayRef<std::pair<char, ui64>>& freqs) {
            TCodeTree tree;

            ui64 freqsArray[256];
            Zero(freqsArray);

            for (const auto& freq : freqs)
                freqsArray[static_cast<ui8>(freq.first)] += freq.second;

            InitTreeByFreqs(tree, freqsArray);
            CalculateCodeLengths(tree);

            Y_ENSURE(!tree.empty(), " ");

            GenerateEncoder(tree);
            BuildDecoder();
        }

        void Save(IOutputStream* out) {
            ::Save(out, Invalid.CodeLength);
            Encoder.Save(out);
        }

        void Load(IInputStream* in) {
            ::Load(in, Invalid.CodeLength);
            Encoder.Load(in);
            RegenerateEncoder();
            BuildDecoder();
        }
    };

    THuffmanCodec::TImpl::THuffmanCache::THuffmanCache(const THuffmanCodec::TImpl& codec)
        : Original(codec)
    {
        CacheEntries.resize(1 << CACHE_BITS_COUNT);
        DecodeCache.reserve(CacheEntries.size() * 2);
        char buffer[2];
        TBuffer decoded;
        for (size_t i = 0; i < CacheEntries.size(); i++) {
            buffer[1] = i >> 8;
            buffer[0] = i;
            NBitIO::TBitInput bin(buffer, buffer + sizeof(buffer));
            int totalBits = 0;
            while (true) {
                decoded.Resize(0);
                int bits = codec.ReadNextChar(bin, decoded);
                if (totalBits + bits > 16 || !bits) {
                    TCacheEntry e = {static_cast<int>(DecodeCache.size()), 16 - totalBits};
                    CacheEntries[i] = e;
                    break;
                }

                for (TBuffer::TConstIterator it = decoded.Begin(); it != decoded.End(); ++it) {
                    DecodeCache.push_back(*it);
                }
                totalBits += bits;
            }
        }
        DecodeCache.push_back(0);
        CacheEntries.shrink_to_fit();
        DecodeCache.shrink_to_fit();
    }

    void THuffmanCodec::TImpl::THuffmanCache::Decode(NBitIO::TBitInput& bin, TBuffer& out) const {
        int bits = 0;
        ui64 code = 0;
        while (!bin.Eof()) {
            ui64 f = 0;
            const int toRead = 16 - bits;
            if (toRead > 0 && bin.Read(f, toRead)) {
                code = (code >> (16 - bits)) | (f << bits);
                code &= 0xFFFF;
                TCacheEntry entry = CacheEntries[code];
                int start = code > 0 ? CacheEntries[code - 1].EndOffset : 0;
                out.Append((const char*)&DecodeCache[start], (const char*)&DecodeCache[entry.EndOffset]);
                bits = entry.BitsLeft;
            } else { // should never happen until there are exceptions or unaligned input
                bin.Back(bits);
                if (!Original.ReadNextChar(bin, out))
                    break;

                code = 0;
                bits = 0;
            }
        }
    }

    THuffmanCodec::THuffmanCodec()
        : Impl(new TImpl)
    {
        MyTraits.NeedsTraining = true;
        MyTraits.PreservesPrefixGrouping = true;
        MyTraits.PaddingBit = 1;
        MyTraits.SizeOnEncodeMultiplier = 2;
        MyTraits.SizeOnDecodeMultiplier = 8;
        MyTraits.RecommendedSampleSize = 1 << 21;
    }

    THuffmanCodec::~THuffmanCodec() = default;

    ui8 THuffmanCodec::Encode(TStringBuf in, TBuffer& bbb) const {
        if (Y_UNLIKELY(!Trained))
            ythrow TCodecException() << " not trained";

        return Impl->Encode(in, bbb);
    }

    void THuffmanCodec::Decode(TStringBuf in, TBuffer& bbb) const {
        Impl->Decode(in, bbb);
    }

    void THuffmanCodec::Save(IOutputStream* out) const {
        Impl->Save(out);
    }

    void THuffmanCodec::Load(IInputStream* in) {
        Impl->Load(in);
    }

    void THuffmanCodec::DoLearn(ISequenceReader& in) {
        Impl->Learn(&in);
    }

    void THuffmanCodec::LearnByFreqs(const TArrayRef<std::pair<char, ui64>>& freqs) {
        Impl->LearnByFreqs(freqs);
        Trained = true;
    }

}