#pragma once

#include "buffered_io.h"
#include "class_factory.h"

#include <library/cpp/containers/2d_array/2d_array.h>

#include <util/generic/hash_set.h>
#include <util/generic/buffer.h>
#include <util/generic/list.h>
#include <util/generic/maybe.h>
#include <util/generic/bitmap.h>
#include <util/generic/variant.h>
#include <util/generic/ylimits.h>
#include <util/memory/blob.h>
#include <util/digest/murmur.h>
#include <util/system/compiler.h>

#include <array>
#include <bitset>
#include <list>
#include <string>

#ifdef _MSC_VER
#pragma warning(disable : 4127)
#endif

// Direction/mode selector for binary saver streams. The numeric values are explicit and
// should be treated as fixed, since they may be persisted or compared by callers.
enum ESaverMode {
    SAVER_MODE_READ = 1,             // reading mode
    SAVER_MODE_WRITE = 2,            // writing mode
    SAVER_MODE_WRITE_COMPRESSED = 3, // writing mode, compressed variant (handled outside this header)
};

namespace NBinSaverInternals {
    // Tag type used to rank otherwise-ambiguous overloads explicitly.
    // TOverloadPriority<P> derives from TOverloadPriority<P - 1>, all the way down to
    // TOverloadPriority<0>. Passing TOverloadPriority<N>() therefore binds best to the overload
    // taking the highest P <= N (an exact match beats a derived-to-base conversion), and falls
    // back to lower P when higher-priority overloads are removed by SFINAE.
    template <int P>
    struct TOverloadPriority;

    // Recursion terminator: the lowest priority level, with no base.
    template <>
    struct TOverloadPriority<0> {
    };

    // Each level inherits from the next-lower one.
    template <int P>
    struct TOverloadPriority: TOverloadPriority<P - 1> {
    };
}

//////////////////////////////////////////////////////////////////////////
struct IBinSaver {
public:
    typedef unsigned char chunk_id;
    typedef ui32 TStoredSize; // changing this will break compatibility

private:
    // This overload is required to avoid infinite recursion when overriding serialization in derived classes:
    // struct B {
    //  virtual int operator &(IBinSaver& f) {
    //      return 0;
    //  }
    // };
    //
    // struct D : B {
    //  int operator &(IBinSaver& f) override {
    //      f.Add(0, static_cast<B*>(this));
    //      return 0;
    //  }
    // };
    template <class T, typename = decltype(std::declval<T*>()->T::operator&(std::declval<IBinSaver&>()))>
    void CallObjectSerialize(T* p, NBinSaverInternals::TOverloadPriority<2>) { // highest priority -  will be resolved first if enabled
                                           // Note: p->operator &(*this) would lead to infinite recursion
        p->T::operator&(*this);
    }

    template <class T, typename = decltype(std::declval<T&>() & std::declval<IBinSaver&>())>
    void CallObjectSerialize(T* p, NBinSaverInternals::TOverloadPriority<1>) { // lower priority - will be resolved second if enabled
        (*p) & (*this);
    }

    template <class T>
    void CallObjectSerialize(T* p, NBinSaverInternals::TOverloadPriority<0>) { // lower priority - will be resolved last
#if (!defined(_MSC_VER))
        // broken in clang16 for some types
        // In MSVC __has_trivial_copy returns false to enums, primitive types and arrays.
        // static_assert(__is_trivially_copyable(T), "Class is nontrivial copyable, you must define operator&, see");
#endif
        DataChunk(p, sizeof(T));
    }

    // vector
    template <class T, class TA>
    void DoVector(TVector<T, TA>& data) {
        TStoredSize nSize;
        if (IsReading()) {
            data.clear();
            Add(2, &nSize);
            data.resize(nSize);
        } else {
            nSize = data.size();
            CheckOverflow(nSize, data.size());
            Add(2, &nSize);
        }
        for (TStoredSize i = 0; i < nSize; i++)
            Add(1, &data[i]);
    }

    template <class T, int N>
    void DoArray(T (&data)[N]) {
        for (size_t i = 0; i < N; i++) {
            Add(1, &(data[i]));
        }
    }

    template <typename TLarge>
    void CheckOverflow(TStoredSize nSize, TLarge origSize) {
        if (nSize != origSize) {
            fprintf(stderr, "IBinSaver: object size is too large to be serialized (%" PRIu32 " != %" PRIu64 ")\n", nSize, (ui64)origSize);
            abort();
        }
    }

    template <class T, class TA>
    void DoDataVector(TVector<T, TA>& data) {
        TStoredSize nSize = data.size();
        CheckOverflow(nSize, data.size());
        Add(1, &nSize);
        if (IsReading()) {
            data.clear();
            data.resize(nSize);
        }
        if (nSize > 0)
            DataChunk(&data[0], sizeof(T) * nSize);
    }

    template <class AM>
    void DoAnyMap(AM& data) {
        if (IsReading()) {
            data.clear();
            TStoredSize nSize;
            Add(3, &nSize);
            TVector<typename AM::key_type, typename std::allocator_traits<typename AM::allocator_type>::template rebind_alloc<typename AM::key_type>> indices;
            indices.resize(nSize);
            for (TStoredSize i = 0; i < nSize; ++i)
                Add(1, &indices[i]);
            for (TStoredSize i = 0; i < nSize; ++i)
                Add(2, &data[indices[i]]);
        } else {
            TStoredSize nSize = data.size();
            CheckOverflow(nSize, data.size());
            Add(3, &nSize);

            TVector<typename AM::key_type, typename std::allocator_traits<typename AM::allocator_type>::template rebind_alloc<typename AM::key_type>> indices;
            indices.resize(nSize);
            TStoredSize i = 1;
            for (auto pos = data.begin(); pos != data.end(); ++pos, ++i)
                indices[nSize - i] = pos->first;
            for (TStoredSize j = 0; j < nSize; ++j)
                Add(1, &indices[j]);
            for (TStoredSize j = 0; j < nSize; ++j)
                Add(2, &data[indices[j]]);
        }
    }

    // hash_multimap
    template <class AMM>
    void DoAnyMultiMap(AMM& data) {
        if (IsReading()) {
            data.clear();
            TStoredSize nSize;
            Add(3, &nSize);
            TVector<typename AMM::key_type, typename std::allocator_traits<typename AMM::allocator_type>::template rebind_alloc<typename AMM::key_type>> indices;
            indices.resize(nSize);
            for (TStoredSize i = 0; i < nSize; ++i)
                Add(1, &indices[i]);
            for (TStoredSize i = 0; i < nSize; ++i) {
                std::pair<typename AMM::key_type, typename AMM::mapped_type> valToInsert;
                valToInsert.first = indices[i];
                Add(2, &valToInsert.second);
                data.insert(valToInsert);
            }
        } else {
            TStoredSize nSize = data.size();
            CheckOverflow(nSize, data.size());
            Add(3, &nSize);
            for (auto pos = data.begin(); pos != data.end(); ++pos)
                Add(1, (typename AMM::key_type*)(&pos->first));
            for (auto pos = data.begin(); pos != data.end(); ++pos)
                Add(2, &pos->second);
        }
    }

    template <class T>
    void DoAnySet(T& data) {
        if (IsReading()) {
            data.clear();
            TStoredSize nSize;
            Add(2, &nSize);
            for (TStoredSize i = 0; i < nSize; ++i) {
                typename T::value_type member;
                Add(1, &member);
                data.insert(member);
            }
        } else {
            TStoredSize nSize = data.size();
            CheckOverflow(nSize, data.size());
            Add(2, &nSize);
            for (const auto& elem : data) {
                auto member = elem;
                Add(1, &member);
            }
        }
    }

    // 2D array
    template <class T>
    void Do2DArray(TArray2D<T>& a) {
        int nXSize = a.GetXSize(), nYSize = a.GetYSize();
        Add(1, &nXSize);
        Add(2, &nYSize);
        if (IsReading())
            a.SetSizes(nXSize, nYSize);
        for (int i = 0; i < nXSize * nYSize; i++)
            Add(3, &a[i / nXSize][i % nXSize]);
    }
    template <class T>
    void Do2DArrayData(TArray2D<T>& a) {
        int nXSize = a.GetXSize(), nYSize = a.GetYSize();
        Add(1, &nXSize);
        Add(2, &nYSize);
        if (IsReading())
            a.SetSizes(nXSize, nYSize);
        if (nXSize * nYSize > 0)
            DataChunk(&a[0][0], sizeof(T) * nXSize * nYSize);
    }
    // strings
    template <class TStringType>
    void DataChunkStr(TStringType& data, i64 elemSize) {
        if (bRead) {
            TStoredSize nCount = 0;
            File.Read(&nCount, sizeof(TStoredSize));
            data.resize(nCount);
            if (nCount)
                File.Read(&*data.begin(), nCount * elemSize);
        } else {
            TStoredSize nCount = data.size();
            CheckOverflow(nCount, data.size());
            File.Write(&nCount, sizeof(TStoredSize));
            File.Write(data.c_str(), nCount * elemSize);
        }
    }
    void DataChunkString(std::string& data) {
        DataChunkStr(data, sizeof(char));
    }
    void DataChunkStroka(TString& data) {
        DataChunkStr(data, sizeof(TString::char_type));
    }
    void DataChunkWtroka(TUtf16String& data) {
        DataChunkStr(data, sizeof(wchar16));
    }

    void DataChunk(void* pData, i64 nSize) {
        i64 chunkSize = 1 << 30;
        for (i64 offset = 0; offset < nSize; offset += chunkSize) {
            void* ptr = (char*)pData + offset;
            i64 size = offset + chunkSize < nSize ? chunkSize : (nSize - offset);
            if (bRead)
                File.Read(ptr, size);
            else
                File.Write(ptr, size);
        }
    }

    // storing/loading pointers to objects
    void StoreObject(IObjectBase* pObject);
    IObjectBase* LoadObject();

    bool bRead;
    TBufferedStream<> File;
    // maps objects addresses during save(first) to addresses during load(second) - during loading
    // or serves as a sign that some object has been already stored - during storing
    bool StableOutput;

    typedef THashMap<void*, ui32> PtrIdHash;
    TAutoPtr<PtrIdHash> PtrIds;

    typedef THashMap<ui64, TPtr<IObjectBase>> CObjectsHash;
    TAutoPtr<CObjectsHash> Objects;

    TVector<IObjectBase*> ObjectQueue;

public:
    bool IsReading() {
        return bRead;
    }
    void AddRawData(const chunk_id, void* pData, i64 nSize) {
        DataChunk(pData, nSize);
    }

    // return type of Add() is used to detect specialized serializer (see HasNonTrivialSerializer below)
    template <class T>
    char Add(const chunk_id, T* p) {
        CallObjectSerialize(p, NBinSaverInternals::TOverloadPriority<2>());
        return 0;
    }
    int Add(const chunk_id, std::string* pStr) {
        DataChunkString(*pStr);
        return 0;
    }
    int Add(const chunk_id, TString* pStr) {
        DataChunkStroka(*pStr);
        return 0;
    }
    int Add(const chunk_id, TUtf16String* pStr) {
        DataChunkWtroka(*pStr);
        return 0;
    }
    int Add(const chunk_id, TBlob* blob) {
        if (bRead) {
            ui64 size = 0;
            File.Read(&size, sizeof(size));
            TBuffer buffer;
            buffer.Advance(size);
            if (size > 0)
                File.Read(buffer.Data(), buffer.Size());
            (*blob) = TBlob::FromBuffer(buffer);
        } else {
            const ui64 size = blob->Size();
            File.Write(&size, sizeof(size));
            File.Write(blob->Data(), blob->Size());
        }
        return 0;
    }
    template <class T1, class TA>
    int Add(const chunk_id, TVector<T1, TA>* pVec) {
        if (HasNonTrivialSerializer<T1>(0u))
            DoVector(*pVec);
        else
            DoDataVector(*pVec);
        return 0;
    }

    template <class T, int N>
    int Add(const chunk_id, T (*pVec)[N]) {
        if (HasNonTrivialSerializer<T>(0u))
            DoArray(*pVec);
        else
            DataChunk(pVec, sizeof(*pVec));
        return 0;
    }

    template <class T1, class T2, class T3, class T4>
    int Add(const chunk_id, TMap<T1, T2, T3, T4>* pMap) {
        DoAnyMap(*pMap);
        return 0;
    }
    template <class T1, class T2, class T3, class T4, class T5>
    int Add(const chunk_id, THashMap<T1, T2, T3, T4, T5>* pHash) {
        DoAnyMap(*pHash);
        return 0;
    }
    template <class T1, class T2, class T3, class T4, class T5>
    int Add(const chunk_id, THashMultiMap<T1, T2, T3, T4, T5>* pHash) {
        DoAnyMultiMap(*pHash);
        return 0;
    }
    template <class K, class L, class A>
    int Add(const chunk_id, TSet<K, L, A>* pSet) {
        DoAnySet(*pSet);
        return 0;
    }
    template <class T1, class T2, class T3, class T4>
    int Add(const chunk_id, THashSet<T1, T2, T3, T4>* pHash) {
        DoAnySet(*pHash);
        return 0;
    }

    template <class T1>
    int Add(const chunk_id, TArray2D<T1>* pArr) {
        if (HasNonTrivialSerializer<T1>(0u))
            Do2DArray(*pArr);
        else
            Do2DArrayData(*pArr);
        return 0;
    }
    template <class T1>
    int Add(const chunk_id, TList<T1>* pList) {
        TList<T1>& data = *pList;
        if (IsReading()) {
            int nSize;
            Add(2, &nSize);
            data.clear();
            data.insert(data.begin(), nSize, T1());
        } else {
            int nSize = data.size();
            Add(2, &nSize);
        }
        int i = 1;
        for (typename TList<T1>::iterator k = data.begin(); k != data.end(); ++k, ++i)
            Add(i + 2, &(*k));
        return 0;
    }
    template <class T1, class T2>
    int Add(const chunk_id, std::pair<T1, T2>* pData) {
        Add(1, &(pData->first));
        Add(2, &(pData->second));
        return 0;
    }

    template <class T1, size_t N>
    int Add(const chunk_id, std::array<T1, N>* pData) {
        if (HasNonTrivialSerializer<T1>(0u)) {
            for (size_t i = 0; i < N; ++i)
                Add(1, &(*pData)[i]);
        } else {
            DataChunk((void*)pData->data(), pData->size() * sizeof(T1));
        }
        return 0;
    }

    template <size_t N>
    int Add(const chunk_id, std::bitset<N>* pData) {
        if (IsReading()) {
            std::string s;
            Add(1, &s);
            *pData = std::bitset<N>(s);
        } else {
            std::string s = pData->template to_string<char, std::char_traits<char>, std::allocator<char>>();
            Add(1, &s);
        }
        return 0;
    }

    int Add(const chunk_id, TDynBitMap* pData) {
        if (IsReading()) {
            ui64 count = 0;
            Add(1, &count);
            pData->Clear();
            pData->Reserve(count * sizeof(TDynBitMap::TChunk) * 8);
            for (ui64 i = 0; i < count; ++i) {
                TDynBitMap::TChunk chunk = 0;
                Add(i + 1, &chunk);
                if (i > 0) {
                    pData->LShift(8 * sizeof(TDynBitMap::TChunk));
                }
                pData->Or(chunk);
            }
        } else {
            ui64 count = pData->GetChunkCount();
            Add(1, &count);
            for (ui64 i = 0; i < count; ++i) {
                // Write in reverse order
                TDynBitMap::TChunk chunk = pData->GetChunks()[count - i - 1];
                Add(i + 1, &chunk);
            }
        }
        return 0;
    }

    template <class TVariantClass>
    struct TLoadFromTypeFromListHelper {
        template <class T0, class... TTail>
        static void Do(IBinSaver& binSaver, ui32 typeIndex, TVariantClass* pData) {
            if constexpr (sizeof...(TTail) == 0) {
                Y_ASSERT(typeIndex == 0);
                T0 chunk;
                binSaver.Add(2, &chunk);
                *pData = std::move(chunk);
            } else {
                if (typeIndex == 0) {
                    Do<T0>(binSaver, 0, pData);
                } else {
                    Do<TTail...>(binSaver, typeIndex - 1, pData);
                }
            }
        }
    };

    template <class... TVariantTypes>
    int Add(const chunk_id, std::variant<TVariantTypes...>* pData) {
        static_assert(std::variant_size_v<std::variant<TVariantTypes...>> < Max<ui32>());

        ui32 index;
        if (IsReading()) {
            Add(1, &index);
            TLoadFromTypeFromListHelper<std::variant<TVariantTypes...>>::template Do<TVariantTypes...>(
                *this,
                index,
                pData
            );
        } else {
            index = pData->index(); // type cast is safe because of static_assert check above
            Add(1, &index);
            std::visit([&](auto& dst) -> void { Add(2, &dst); }, *pData);
        }
        return 0;
    }


    void AddPolymorphicBase(chunk_id, IObjectBase* pObject) {
        (*pObject) & (*this);
    }

    template <class T1, class T2>
    void DoPtr(TPtrBase<T1, T2>* pData) {
        if (pData && pData->Get()) {
        }
        if (IsReading())
            pData->Set(CastToUserObject(LoadObject(), (T1*)nullptr));
        else
            StoreObject(pData->GetBarePtr());
    }
    template <class T, class TPolicy>
    int Add(const chunk_id, TMaybe<T, TPolicy>* pData) {
        TMaybe<T, TPolicy>& data = *pData;
        if (IsReading()) {
            bool defined = false;
            Add(1, &defined);
            if (defined) {
                data = T();
                Add(2, data.Get());
            }
        } else {
            bool defined = data.Defined();
            Add(1, &defined);
            if (defined) {
                Add(2, data.Get());
            }
        }
        return 0;
    }

    template <typename TOne>
    void AddMulti(TOne& one) {
        Add(0, &one);
    }

    template <typename THead, typename... TTail>
    void AddMulti(THead& head, TTail&... tail) {
        Add(0, &head);
        AddMulti(tail...);
    }

    template <class T, typename = decltype(std::declval<T&>() & std::declval<IBinSaver&>())>
    static bool HasNonTrivialSerializer(ui32) {
        return true;
    }

    template <class T>
    static bool HasNonTrivialSerializer(...) {
        return sizeof(std::declval<IBinSaver*>()->Add(0, std::declval<T*>())) != 1;
    }

public:
    IBinSaver(IBinaryStream& stream, bool _bRead, bool stableOutput = false)
        : bRead(_bRead)
        , File(_bRead, stream)
        , StableOutput(stableOutput)
    {
    }
    virtual ~IBinSaver();
    bool IsValid() const {
        return File.IsValid();
    }
};

// Out-of-line definition of TPtrBase's serialization operator, which is declared elsewhere
// (with TPtrBase itself). It is defined here because it needs the complete IBinSaver type.
// Serialization is delegated to IBinSaver::DoPtr; the result is always 0.
template <class TUserObj, class TRef>
int TPtrBase<TUserObj, TRef>::operator&(IBinSaver& saver) {
    saver.DoPtr(this);
    return 0;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Factory of save/load-capable classes, keyed by a numeric type id (defined out of line).
extern TClassFactory<IObjectBase>* pSaverClasses;
// Called before every registration below; presumably ensures pSaverClasses is initialised (confirm).
void StartRegisterSaveload();

// Registers TReg under type id `num` when constructed. It is meant to be instantiated as a static
// object by the REGISTER_SAVELOAD_* macros. TReg must provide a static NewSaveLoadNullItem.
template <class TReg>
struct TRegisterSaveLoadType {
    TRegisterSaveLoadType(int num) {
        StartRegisterSaveload();
        pSaverClasses->RegisterType(num, TReg::NewSaveLoadNullItem, static_cast<TReg*>(nullptr));
    }
};

// Registers `name` with BASIC_REGISTER_CLASS and with the saver registry. The type id is derived
// from the class name as MurmurHash<int> over the string literal #name, including its
// terminating NUL (sizeof(#name) counts it). Renaming the class therefore changes the registered
// id (and, presumably, the id recorded in saved data — confirm in StoreObject).
// `name` is token-pasted into the static object's name, so it must be a single identifier.
#define Y_BINSAVER_REGISTER(name) \
    BASIC_REGISTER_CLASS(name)    \
    static TRegisterSaveLoadType<name> init##name(MurmurHash<int>(#name, sizeof(#name)));

// Same as Y_BINSAVER_REGISTER, but with an explicit numeric type id N.
// Both `name` and N are token-pasted into the static object's name.
#define REGISTER_SAVELOAD_CLASS(N, name) \
    BASIC_REGISTER_CLASS(name)           \
    static TRegisterSaveLoadType<name> init##name##N(N);

// Registration of class-template instantiations under type id N.
// Unlike REGISTER_SAVELOAD_CLASS, these do not invoke BASIC_REGISTER_CLASS: per the original note,
// using TObj/TRef on a forward-declared template class will not work.
// Registering with the same id more than once is allowed.
// className, T, T1, T2, T3 and N are token-pasted into identifiers, so each template argument
// must be a single identifier token (e.g. `int`, not `unsigned int` or `ns::TFoo`).
// The typedef in the TEMPL2/TEMPL3 variants only names the instantiation; its purpose is not
// visible here (NOTE(review): possibly to force the template arguments to be checked — confirm).
#define REGISTER_SAVELOAD_TEMPL1_CLASS(N, className, T) \
    static TRegisterSaveLoadType<className<T>> init##className##T##N(N);

#define REGISTER_SAVELOAD_TEMPL2_CLASS(N, className, T1, T2)    \
    typedef className<T1, T2> temp##className##T1##_##T2##temp; \
    static TRegisterSaveLoadType<className<T1, T2>> init##className##T1##_##T2##N(N);

#define REGISTER_SAVELOAD_TEMPL3_CLASS(N, className, T1, T2, T3)           \
    typedef className<T1, T2, T3> temp##className##T1##_##T2##_##T3##temp; \
    static TRegisterSaveLoadType<className<T1, T2, T3>> init##className##T1##_##T2##_##T3##N(N);

// Namespace-qualified variants of REGISTER_SAVELOAD_CLASS: register nmspace::className
// (or nmspace1::nmspace2::className) under type id N.
// The namespace(s), className, T and N are token-pasted into the name of the static registration
// object, so each must be a single identifier token. The object name must include className
// (not a literal `name`, which is not a parameter of these macros). Otherwise different classes
// from the same namespace registered with the same id in one translation unit would produce
// clashing definitions.
#define REGISTER_SAVELOAD_NM_CLASS(N, nmspace, className) \
    BASIC_REGISTER_CLASS(nmspace::className)              \
    static TRegisterSaveLoadType<nmspace::className> init_##nmspace##_##className##N(N);

#define REGISTER_SAVELOAD_NM2_CLASS(N, nmspace1, nmspace2, className) \
    BASIC_REGISTER_CLASS(nmspace1::nmspace2::className)              \
    static TRegisterSaveLoadType<nmspace1::nmspace2::className> init_##nmspace1##_##nmspace2##_##className##N(N);

#define REGISTER_SAVELOAD_TEMPL1_NM_CLASS(N, nmspace, className, T)       \
    typedef nmspace::className<T> temp_init##nmspace##className##T##temp; \
    BASIC_REGISTER_CLASS(nmspace::className<T>)                           \
    static TRegisterSaveLoadType<nmspace::className<T>> temp_init##nmspace##_##className##T##N(N);

// Registers `cls` (which may be namespace-qualified) under type id N. `name` is used only to
// build the static registration object's identifier (init<name><N>). This lets classes whose
// spelling is not a single identifier be registered.
#define REGISTER_SAVELOAD_CLASS_NAME(N, cls, name) \
    BASIC_REGISTER_CLASS(cls)                      \
    static TRegisterSaveLoadType<cls> init##name##N(N);

// Registers ns::cls under type id N. The static object is named init_<pref>_<cls><N>, so `pref`
// and `cls` must be single identifiers.
#define REGISTER_SAVELOAD_CLASS_NS_PREF(N, cls, ns, pref) \
    REGISTER_SAVELOAD_CLASS_NAME(N, ns ::cls, _##pref##_##cls)

// Defines `int operator&(IBinSaver& f)` in the enclosing class, serializing the listed members in
// argument order via IBinSaver::AddMulti (reading and writing use the same order). At least one
// member must be listed (AddMulti has no zero-argument form). The trailing Y_SEMICOLON_GUARD
// presumably allows a semicolon after the macro use — confirm its definition.
#define SAVELOAD(...)             \
    int operator&(IBinSaver& f) { \
        f.AddMulti(__VA_ARGS__);  \
        return 0;                 \
    } Y_SEMICOLON_GUARD

// Like SAVELOAD, but marked `override` (the base must declare a virtual operator&). The base's
// members are NOT serialized.
#define SAVELOAD_OVERRIDE_WITHOUT_BASE(...) \
    int operator&(IBinSaver& f) override {  \
        f.AddMulti(__VA_ARGS__);            \
        return 0;                           \
    } Y_SEMICOLON_GUARD

// Overriding serializer that first runs `base`'s operator& (a qualified call, so it does not
// re-dispatch virtually), then the listed members.
#define SAVELOAD_OVERRIDE(base, ...)       \
    int operator&(IBinSaver& f) override { \
        base::operator&(f);                \
        f.AddMulti(__VA_ARGS__);           \
        return 0;                          \
    } Y_SEMICOLON_GUARD

// Like SAVELOAD, but first serializes the base class through TBase::operator&. The enclosing
// class must define a `TBase` alias naming its base.
#define SAVELOAD_BASE(...)        \
    int operator&(IBinSaver& f) { \
        TBase::operator&(f);      \
        f.AddMulti(__VA_ARGS__);  \
        return 0;                 \
    } Y_SEMICOLON_GUARD