#pragma once

#include "comptrie_builder.h"
#include "comptrie_trie.h"
#include "comptrie_impl.h"
#include <library/cpp/packers/packers.h>

#include <util/system/yassert.h>
#include <util/generic/vector.h>
#include <util/generic/deque.h>
#include <util/stream/str.h>

#include <cstring>

// Aho-Corasick algorithm implementation using CompactTrie implementation of Sedgewick's T-trie

namespace NCompactTrie {
    // Suffix-link record embedded in every serialized node of the pattern searcher trie.
    //
    // Both fields are byte offsets relative to the start of the node that stores the
    // record (CalculateSuffixLinks writes them as `suffix - node`). A freshly
    // constructed record holds zeros.
    struct TSuffixLink {
        // Offset to the node chosen as this node's suffix; 0 when the suffix is the node itself.
        ui64 NextSuffixOffset;
        // Offset used by the searcher to hop along suffixes while collecting matches.
        ui64 NextSuffixWithDataOffset;

        TSuffixLink(ui64 nextSuffixOffset = 0, ui64 nextSuffixWithDataOffset = 0)
            : NextSuffixOffset(nextSuffixOffset)
            , NextSuffixWithDataOffset(nextSuffixWithDataOffset)
        {
        }
    };

    // Sizes, in bytes, of the flags field and of the label field of a serialized node.
    constexpr size_t FLAGS_SIZE = sizeof(char);
    constexpr size_t SYMBOL_SIZE = sizeof(char);
} // namespace NCompactTrie

// Builder of the serialized data consumed by TCompactPatternSearcher.
//
// Patterns are accumulated through the TCompactTrieBuilder base (used as a protected
// base, so only the methods re-exported below are public). The node layout differs
// from the base one: every serialized node carries an NCompactTrie::TSuffixLink, and
// an arc label of N bytes is written as N separate nodes. Suffix links are written as
// zeros first and filled in by CalculateSuffixLinks() when the trie is saved.
template <class T = char, class D = ui64, class S = TCompactTriePacker<D>>
class TCompactPatternSearcherBuilder : protected TCompactTrieBuilder<T, D, S> {
public:
    typedef T TSymbol;
    typedef D TData;
    typedef S TPacker;

    typedef typename TCompactTrieKeySelector<TSymbol>::TKey TKey;
    typedef typename TCompactTrieKeySelector<TSymbol>::TKeyBuf TKeyBuf;

    typedef TCompactTrieBuilder<T, D, S> TBase;

public:
    // Installs the pattern-searcher specific implementation into the base builder.
    TCompactPatternSearcherBuilder() {
        TBase::Impl = MakeHolder<TCompactPatternSearcherBuilderImpl>();
    }

    // Adds a pattern with its associated value; forwards to the implementation's AddEntry().
    bool Add(const TSymbol* key, size_t keyLength, const TData& value) {
        return TBase::Impl->AddEntry(key, keyLength, value);
    }
    bool Add(const TKeyBuf& key, const TData& value) {
        return Add(key.data(), key.size(), value);
    }

    // Looks a pattern up among the added ones; forwards to the implementation's FindEntry().
    bool Find(const TSymbol* key, size_t keyLength, TData* value) const {
        return TBase::Impl->FindEntry(key, keyLength, value);
    }
    bool Find(const TKeyBuf& key, TData* value = nullptr) const {
        return Find(key.data(), key.size(), value);
    }

    // Serializes the trie, fills in the suffix links and writes the result to `os`.
    // Returns the size reported by MeasureByteSize().
    size_t Save(IOutputStream& os) const {
        size_t trieSize = TBase::Impl->MeasureByteSize();
        TBufferOutput serializedTrie(trieSize);
        TBase::Impl->Save(serializedTrie);

        // Bind by reference: the suffix links are patched directly in the stream's
        // buffer, so the serialized trie is not copied before being written out.
        auto& serializedTrieBuffer = serializedTrie.Buffer();
        CalculateSuffixLinks(
            serializedTrieBuffer.Data(),
            serializedTrieBuffer.Data() + serializedTrieBuffer.Size()
        );

        os.Write(serializedTrieBuffer.Data(), serializedTrieBuffer.Size());
        return trieSize;
    }

    // Serializes into memory and returns the result as a blob.
    TBlob Save() const {
        TBufferStream buffer;
        Save(buffer);
        return TBlob::FromStream(buffer);
    }

    // Serializes into the file `fileName`; returns the same value as Save(IOutputStream&).
    size_t SaveToFile(const TString& fileName) const {
        TFileOutput out(fileName);
        return Save(out);
    }

    size_t MeasureByteSize() const {
        return TBase::Impl->MeasureByteSize();
    }

private:
    // Fills the TSuffixLink of every node of the serialized trie in [trieStart, trieEnd).
    void CalculateSuffixLinks(char* trieStart, const char* trieEnd) const;

protected:
    // Overrides arc measuring and saving of the base builder so that each label byte
    // becomes a node of its own with room for a TSuffixLink.
    class TCompactPatternSearcherBuilderImpl : public TBase::TCompactTrieBuilderImpl {
    public:
        typedef typename TBase::TCompactTrieBuilderImpl TImplBase;

        TCompactPatternSearcherBuilderImpl(
            TCompactTrieBuilderFlags flags = CTBF_NONE,
            TPacker packer = TPacker(),
            IAllocator* alloc = TDefaultAllocator::Instance()
        ) : TImplBase(flags, packer, alloc) {
        }

        // Computes the serialized size of `arc` together with its subtree and its
        // left/right siblings (whose sizes are passed in), and stores the resulting
        // sibling offsets into arc->LeftOffset / arc->RightOffset.
        ui64 ArcMeasure(
            const typename TImplBase::TArc* arc,
            size_t leftSize,
            size_t rightSize
        ) const override {
            using namespace NCompactTrie;

            // The first label byte: flags + symbol + suffix link + leaf value (if any).
            size_t coreSize = SYMBOL_SIZE + FLAGS_SIZE +
                sizeof(TSuffixLink) +
                this->NodeMeasureLeafValue(arc->Node);
            size_t treeSize = this->NodeMeasureSubtree(arc->Node);

            // Every further label byte is saved as a separate node without sibling offsets.
            if (arc->Label.Length() > 0)
                treeSize += (SYMBOL_SIZE + FLAGS_SIZE + sizeof(TSuffixLink)) *
                    (arc->Label.Length() - 1);

            // Triple measurements are needed because the space needed to store the offset
            // shall be added to the offset itself. Hence three iterations.
            size_t leftOffsetSize = 0;
            size_t rightOffsetSize = 0;
            for (size_t iteration = 0; iteration < 3; ++iteration) {
                leftOffsetSize = leftSize ? MeasureOffset(
                    coreSize + treeSize + leftOffsetSize + rightOffsetSize) : 0;
                rightOffsetSize = rightSize ? MeasureOffset(
                    coreSize + treeSize + leftSize + leftOffsetSize + rightOffsetSize) : 0;
            }

            coreSize += leftOffsetSize + rightOffsetSize;
            // Siblings are laid out after this arc's own subtree: left first, then right.
            arc->LeftOffset = leftSize ? coreSize + treeSize : 0;
            arc->RightOffset = rightSize ? coreSize + treeSize + leftSize : 0;

            return coreSize + treeSize + leftSize + rightSize;
        }

        // Writes the nodes of `arc` itself (one node per label byte, followed by the
        // leaf value) and returns the number of bytes written.
        ui64 ArcSaveSelf(const typename TImplBase::TArc* arc, IOutputStream& os) const override {
            using namespace NCompactTrie;

            ui64 written = 0;

            size_t leftOffsetSize = MeasureOffset(arc->LeftOffset);
            size_t rightOffsetSize = MeasureOffset(arc->RightOffset);

            size_t labelLen = arc->Label.Length();

            for (size_t labelPos = 0; labelPos < labelLen; ++labelPos) {
                char flags = 0;

                // Only the first byte of the label carries the sibling offsets.
                if (labelPos == 0) {
                    flags |= (leftOffsetSize << MT_LEFTSHIFT);
                    flags |= (rightOffsetSize << MT_RIGHTSHIFT);
                }

                if (labelPos == labelLen - 1) {
                    // The last byte reflects the state of the node the arc leads to.
                    if (arc->Node->IsFinal())
                        flags |= MT_FINAL;
                    if (!arc->Node->IsLast())
                        flags |= MT_NEXT;
                } else {
                    // Intermediate bytes are always followed by the next label byte.
                    flags |= MT_NEXT;
                }

                os.Write(&flags, 1);
                os.Write(&arc->Label.AsCharPtr()[labelPos], 1);
                written += 2;

                // Zero-filled placeholder; the real values are computed after saving.
                TSuffixLink suffixlink;
                os.Write(&suffixlink, sizeof(TSuffixLink));
                written += sizeof(TSuffixLink);

                if (labelPos == 0) {
                    written += ArcSaveOffset(arc->LeftOffset, os);
                    written += ArcSaveOffset(arc->RightOffset, os);
                }
            }

            written += this->NodeSaveLeafValue(arc->Node, os);
            return written;
        }
    };
};


// One occurrence of a pattern reported by TCompactPatternSearcher::SearchMatches().
template <class T>
struct TPatternMatch {
    // Position in the searched text at which the match ends
    // (SearchMatches stores `position - text` here).
    ui64 End;
    // Value associated with the matched pattern.
    T Data;

    TPatternMatch(ui64 end, const T& data)
        : End(end), Data(data)
    {
    }
};


// Searches a text for all patterns stored in data produced by
// TCompactPatternSearcherBuilder. The data is accessed through a TCompactTrie
// constructed over it.
template <class T = char, class D = ui64, class S = TCompactTriePacker<D>>
class TCompactPatternSearcher {
public:
    using TSymbol = T;
    using TData = D;
    using TPacker = S;

    using TKey = typename TCompactTrieKeySelector<TSymbol>::TKey;
    using TKeyBuf = typename TCompactTrieKeySelector<TSymbol>::TKeyBuf;

    using TTrie = TCompactTrie<TSymbol, TData, TPacker>;

public:
    // Creates a searcher over a default-constructed trie.
    TCompactPatternSearcher()
    {
    }

    // Creates a searcher over serialized data held by `data`.
    explicit TCompactPatternSearcher(const TBlob& data)
        : Trie(data)
    {
    }

    // Creates a searcher over `size` bytes of serialized data starting at `data`.
    TCompactPatternSearcher(const char* data, size_t size)
        : Trie(data, size)
    {
    }

    // Returns the matches of the stored patterns found in `text`.
    TVector<TPatternMatch<TData>> SearchMatches(const TSymbol* text, size_t textSize) const;

    // Convenience overload taking the text as a key buffer.
    TVector<TPatternMatch<TData>> SearchMatches(const TKeyBuf& text) const {
        return SearchMatches(text.data(), text.size());
    }

private:
    TTrie Trie;
};

////////////////////
// Implementation //
////////////////////

namespace {

template <class TData, class TPacker>
char ReadNode(
    char* nodeStart,
    char*& leftSibling,
    char*& rightSibling,
    char*& directChild,
    NCompactTrie::TSuffixLink*& suffixLink,
    TPacker packer = TPacker()
) {
    char* dataPos = nodeStart;
    char flags = *(dataPos++);

    Y_ASSERT(!NCompactTrie::IsEpsilonLink(flags)); // Epsilon links are not allowed

    char label = *(dataPos++);

    suffixLink = (NCompactTrie::TSuffixLink*)dataPos;
    dataPos += sizeof(NCompactTrie::TSuffixLink);

    { // Left branch
        size_t offsetLength = NCompactTrie::LeftOffsetLen(flags);
        size_t leftOffset = NCompactTrie::UnpackOffset(dataPos, offsetLength);
        leftSibling = leftOffset ? (nodeStart + leftOffset) : nullptr;

        dataPos += offsetLength;
    }


    { // Right branch
        size_t offsetLength = NCompactTrie::RightOffsetLen(flags);
        size_t rightOffset = NCompactTrie::UnpackOffset(dataPos, offsetLength);
        rightSibling = rightOffset ? (nodeStart + rightOffset) : nullptr;

        dataPos += offsetLength;
    }

    directChild = nullptr;
    if (flags & NCompactTrie::MT_NEXT) {
        directChild = dataPos;
        if (flags & NCompactTrie::MT_FINAL) {
            directChild += packer.SkipLeaf(directChild);
        }
    }

    return label;
}

// Decodes the header of the read-only serialized node at `nodeStart`.
//
// Outputs:
//   leftSibling / rightSibling - nodes reached through the left / right offsets,
//                                nullptr when the corresponding offset is zero;
//   directChild                - first byte after the node (and after its leaf value,
//                                when MT_FINAL is set) if MT_NEXT is set, nullptr otherwise;
//   data                       - start of the packed leaf value if MT_FINAL is set,
//                                nullptr otherwise;
//   suffixLink                 - copy of the TSuffixLink stored inside the node.
// Returns the label byte of the node.
//
// Callers pass the same variable for several outputs they do not need, so the order
// in which the outputs are assigned below is kept stable.
template <class TData, class TPacker>
char ReadNodeConst(
    const char* nodeStart,
    const char*& leftSibling,
    const char*& rightSibling,
    const char*& directChild,
    const char*& data,
    NCompactTrie::TSuffixLink& suffixLink,
    TPacker packer = TPacker()
) {
    const char* dataPos = nodeStart;
    char flags = *(dataPos++);

    Y_ASSERT(!NCompactTrie::IsEpsilonLink(flags)); // Epsilon links are not allowed

    char label = *(dataPos++);

    // The record sits at an arbitrary byte offset inside the trie, so it is copied out
    // bytewise instead of being dereferenced through a (possibly misaligned) pointer.
    std::memcpy(&suffixLink, dataPos, sizeof(NCompactTrie::TSuffixLink));
    dataPos += sizeof(NCompactTrie::TSuffixLink);

    { // Left branch
        size_t offsetLength = NCompactTrie::LeftOffsetLen(flags);
        size_t leftOffset = NCompactTrie::UnpackOffset(dataPos, offsetLength);
        leftSibling = leftOffset ? (nodeStart + leftOffset) : nullptr;

        dataPos += offsetLength;
    }


    { // Right branch
        size_t offsetLength = NCompactTrie::RightOffsetLen(flags);
        size_t rightOffset = NCompactTrie::UnpackOffset(dataPos, offsetLength);
        rightSibling = rightOffset ? (nodeStart + rightOffset) : nullptr;

        dataPos += offsetLength;
    }

    // The leaf value, when present, immediately follows the offsets.
    data = nullptr;
    if (flags & NCompactTrie::MT_FINAL) {
        data = dataPos;
    }
    // The child, when present, follows the (optional) leaf value.
    directChild = nullptr;
    if (flags & NCompactTrie::MT_NEXT) {
        directChild = dataPos;
        if (flags & NCompactTrie::MT_FINAL) {
            directChild += packer.SkipLeaf(directChild);
        }
    }

    return label;
}

// Looks for the node labelled `label` among the sibling nodes reachable from `dataPos`
// through left/right offsets (labels are compared as unsigned bytes).
//
// On success `dataPos` points to the start of the matching node and true is returned.
// Otherwise (null start, missing branch, or position at/after `dataEnd`) `dataPos` is
// set to nullptr and false is returned; a null `dataPos` on entry is left as is.
Y_FORCE_INLINE bool Advance(
    const char*& dataPos,
    const char* const dataEnd,
    char label
) {
    if (dataPos == nullptr) {
        return false;
    }

    const unsigned char wanted = (unsigned char)label;

    while (dataPos < dataEnd) {
        const char* const nodeStart = dataPos;
        const char flags = nodeStart[0];
        const unsigned char current = (unsigned char)nodeStart[1];

        if (wanted == current) {
            dataPos = nodeStart;
            return true;
        }

        // Offsets follow the flags byte, the label byte and the suffix link record:
        // the left offset comes first, the right offset right after it.
        const char* offsetPos = nodeStart + 2 + sizeof(NCompactTrie::TSuffixLink);
        size_t offsetLength = NCompactTrie::LeftOffsetLen(flags);
        if (wanted > current) {
            offsetPos += offsetLength;
            offsetLength = NCompactTrie::RightOffsetLen(flags);
        }

        const size_t offset = NCompactTrie::UnpackOffset(offsetPos, offsetLength);
        if (!offset) {
            // No branch in the required direction: the label is absent.
            break;
        }

        dataPos = nodeStart + offset;
    }

    // Either the branch is missing or we ran past the end of the data.
    dataPos = nullptr;
    return false;
}

} // anonymous

// Fills in the TSuffixLink of every node of the serialized trie occupying
// [trieStart, trieEnd). Nodes are visited from a deque: siblings of the current node
// are pushed to the front and its direct child to the back, so the suffix links of a
// node's parent (and of the nodes on the parent's suffix chain) are read after they
// have been written.
//
// For each node:
//   NextSuffixOffset         = (suffix - node), 0 when the node is its own suffix;
//   NextSuffixWithDataOffset = (suffix - node), plus the suffix's own
//                              NextSuffixWithDataOffset when the suffix holds no data.
template <class T, class D, class S>
void TCompactPatternSearcherBuilder<T, D, S>::CalculateSuffixLinks(
    char* trieStart,
    const char* trieEnd
) const {
    // Queue element: a node together with the node it was reached from as a direct
    // child (nullptr for nodes reachable from trieStart through sibling offsets only).
    struct TBfsElement {
        char* Node;
        const char* Parent;

        TBfsElement(char* node, const char* parent)
            : Node(node)
            , Parent(parent)
        {
        }
    };

    TDeque<TBfsElement> bfsQueue;
    // Nothing to do for an absent or empty trie.
    if (trieStart && trieStart != trieEnd) {
        bfsQueue.emplace_back(trieStart, nullptr);
    }

    while (!bfsQueue.empty()) {
        auto front = bfsQueue.front();
        char* node = front.Node;
        const char* parent = front.Parent;
        bfsQueue.pop_front();

        char* leftSibling;
        char* rightSibling;
        char* directChild;
        NCompactTrie::TSuffixLink* suffixLink;

        char label = ReadNode<TData, TPacker>(
            node,
            leftSibling,
            rightSibling,
            directChild,
            suffixLink
        );

        const char* suffix;

        if (parent == nullptr) {
            // Parentless nodes are their own suffix (stored offset is 0).
            suffix = node;
        } else {
            // Walk the suffix chain of the parent until a node with a child labelled
            // `label` is found.
            const char* parentOfSuffix = parent;
            const char* temp;
            do {
                NCompactTrie::TSuffixLink parentOfSuffixSuffixLink;

                ReadNodeConst<TData, TPacker>(
                    parentOfSuffix,
                    /*left*/temp,
                    /*right*/temp,
                    /*direct*/temp,
                    /*data*/temp,
                    parentOfSuffixSuffixLink
                );
                if (parentOfSuffixSuffixLink.NextSuffixOffset == 0) {
                    // End of the chain: try the nodes reachable from trieStart; if
                    // none carries `label`, the node becomes its own suffix.
                    suffix = trieStart;
                    if (!Advance(suffix, trieEnd, label)) {
                        suffix = node;
                    }
                    break;
                }
                parentOfSuffix += parentOfSuffixSuffixLink.NextSuffixOffset;

                NCompactTrie::TSuffixLink tempSuffixLink;
                ReadNodeConst<TData, TPacker>(
                    parentOfSuffix,
                    /*left*/temp,
                    /*right*/temp,
                    /*direct*/suffix,
                    /*data*/temp,
                    tempSuffixLink
                );

                if (suffix == nullptr) {
                    // No children here. `continue` jumps to the loop condition, where
                    // Advance() returns false for a null position, so the walk goes on.
                    continue;
                }
            } while (!Advance(suffix, trieEnd, label));
        }

        suffixLink->NextSuffixOffset = suffix - node;

        NCompactTrie::TSuffixLink suffixSuffixLink;
        const char* suffixData;
        const char* temp;
        ReadNodeConst<TData, TPacker>(
            suffix,
            /*left*/temp,
            /*right*/temp,
            /*direct*/temp,
            suffixData,
            suffixSuffixLink
        );
        suffixLink->NextSuffixWithDataOffset = suffix - node;
        if (suffixData == nullptr) {
            // The suffix holds no value: chain through its own link (read above, before
            // this node's link is modified).
            suffixLink->NextSuffixWithDataOffset += suffixSuffixLink.NextSuffixWithDataOffset;
        }

        // The child goes to the back of the queue with the current node as its parent.
        if (directChild) {
            bfsQueue.emplace_back(directChild, node);
        }

        // Siblings share the current node's parent and are processed before anything
        // already queued.
        if (leftSibling) {
            bfsQueue.emplace_front(leftSibling, parent);
        }

        if (rightSibling) {
            bfsQueue.emplace_front(rightSibling, parent);
        }
    }
}


// Scans `text` (textSize symbols) and returns every match of a stored pattern.
//
// Each symbol is fed to the automaton byte by byte, most significant byte first.
// For every match a TPatternMatch is appended whose End is the zero-based index of the
// text symbol on which the match ends and whose Data is the unpacked value of the
// pattern. Matches are reported in the order they are encountered: by end position,
// and for one end position from the current node along its suffix chain.
template<class T, class D, class S>
TVector<TPatternMatch<D>> TCompactPatternSearcher<T, D, S>::SearchMatches(
    const TSymbol* text,
    size_t textSize
) const {
    // Sink for ReadNodeConst() outputs that are not needed at a particular call site.
    const char* temp;
    NCompactTrie::TSuffixLink tempSuffixLink;

    const auto& trieData = Trie.Data();
    const char* trieStart = trieData.AsCharPtr();
    size_t dataSize = trieData.Length();
    const char* trieEnd = trieStart + dataSize;

    // Automaton state: the last matched node (nullptr for "nothing matched") and the
    // position from which the next label is looked up.
    const char* lastNode = nullptr;
    const char* currentSubtree = trieStart;

    TVector<TPatternMatch<TData>> matches;

    for (const TSymbol* position = text; position < text + textSize; ++position) {
        TSymbol symbol = *position;
        for (i64 i = (i64)NCompactTrie::ExtraBits<TSymbol>(); i >= 0; i -= 8) {
            char label = (char)(symbol >> i);

            // Find first suffix extendable by label
            while (true) {
                const char* nextLastNode = currentSubtree;
                if (Advance(nextLastNode, trieEnd, label)) {
                    lastNode = nextLastNode;
                    ReadNodeConst<TData, TPacker>(
                        lastNode,
                        /*left*/temp,
                        /*right*/temp,
                        currentSubtree,
                        /*data*/temp,
                        tempSuffixLink
                    );
                    break;
                } else {
                    if (lastNode == nullptr) {
                        // Already at the start state and the label is absent there.
                        break;
                    }
                }

                NCompactTrie::TSuffixLink suffixLink;
                ReadNodeConst<TData, TPacker>(
                    lastNode,
                    /*left*/temp,
                    /*right*/temp,
                    /*direct*/temp,
                    /*data*/temp,
                    suffixLink
                );
                if (suffixLink.NextSuffixOffset == 0) {
                    // No shorter suffix is stored: fall back to the start state.
                    lastNode = nullptr;
                    currentSubtree = trieStart;
                    continue;
                }
                lastNode += suffixLink.NextSuffixOffset;
                ReadNodeConst<TData, TPacker>(
                    lastNode,
                    /*left*/temp,
                    /*right*/temp,
                    currentSubtree,
                    /*data*/temp,
                    tempSuffixLink
                );
            }

            // Report matches only after the last byte of a symbol. A value reached in
            // the middle of a multi-byte symbol corresponds to pattern bytes found at a
            // position that is not aligned to symbol boundaries, i.e. not to an
            // occurrence of the pattern in `text`. For one-byte symbols i is always 0.
            // NOTE(review): assumes the builder stores keys as whole symbols, most
            // significant byte first, mirroring the byte loop above - confirm.
            if (i != 0) {
                continue;
            }

            // Iterate through all suffixes
            const char* suffix = lastNode;
            while (suffix != nullptr) {
                const char* nodeData;
                NCompactTrie::TSuffixLink suffixLink;
                ReadNodeConst<TData, TPacker>(
                    suffix,
                    /*left*/temp,
                    /*right*/temp,
                    /*direct*/temp,
                    nodeData,
                    suffixLink
                );
                if (nodeData != nullptr) {
                    TData data;
                    Trie.GetPacker().UnpackLeaf(nodeData, data);
                    matches.emplace_back(
                        position - text,
                        data
                    );
                }
                if (suffixLink.NextSuffixOffset == 0) {
                    break;
                }
                suffix += suffixLink.NextSuffixWithDataOffset;
            }
        }
    }

    return matches;
}