#pragma once
#include "buffered_io.h"
#include "class_factory.h"
#include <library/cpp/containers/2d_array/2d_array.h>
#include <util/generic/hash_set.h>
#include <util/generic/buffer.h>
#include <util/generic/list.h>
#include <util/generic/maybe.h>
#include <util/generic/bitmap.h>
#include <util/generic/variant.h>
#include <util/generic/ylimits.h>
#include <util/memory/blob.h>
#include <util/digest/murmur.h>
#include <array>
#include <bitset>
#include <list>
#include <string>
#ifdef _MSC_VER
#pragma warning(disable : 4127)
#endif
// Serializer mode constants.
// NOTE(review): nothing in the visible part of this header refers to ESaverMode;
// IBinSaver picks its direction from a plain bool constructor argument instead.
// The numeric values are explicit - presumably other code relies on them, confirm before changing.
enum ESaverMode {
SAVER_MODE_READ = 1,
SAVER_MODE_WRITE = 2,
SAVER_MODE_WRITE_COMPRESSED = 3,
};
namespace NBinSaverInternals {
    // Tag type for ranking overloads that would otherwise be equally viable.
    // TOverloadPriority<N> derives from TOverloadPriority<N - 1>, so an argument of type
    // TOverloadPriority<N> converts most cheaply to the overload taking the highest level:
    // the higher the level, the earlier the overload is picked.
    template <int Level>
    struct TOverloadPriority;

    // Bottom of the inheritance chain - the lowest priority.
    template <>
    struct TOverloadPriority<0> {
    };

    template <int Level>
    struct TOverloadPriority: public TOverloadPriority<Level - 1> {
    };
}
//////////////////////////////////////////////////////////////////////////
struct IBinSaver {
public:
typedef unsigned char chunk_id;
typedef ui32 TStoredSize; // changing this will break compatibility
private:
// Serialization dispatch for a single object. The three overloads below are ranked with
// NBinSaverInternals::TOverloadPriority; the generic Add() passes TOverloadPriority<2>, so the
// first overload that survives SFINAE in the order 2, 1, 0 is the one that gets called.
//
// This overload is required to avoid infinite recursion when overriding serialization in derived classes:
// struct B {
// virtual int operator &(IBinSaver& f) {
// return 0;
// }
// };
//
// struct D : B {
// int operator &(IBinSaver& f) override {
// f.Add(0, static_cast<B*>(this));
// return 0;
// }
// };
// Enabled only when the qualified member call p->T::operator&(IBinSaver&) is well-formed.
template <class T, typename = decltype(std::declval<T*>()->T::operator&(std::declval<IBinSaver&>()))>
void CallObjectSerialize(T* p, NBinSaverInternals::TOverloadPriority<2>) { // highest priority - will be resolved first if enabled
// Note: p->operator &(*this) would lead to infinite recursion
p->T::operator&(*this);
}
// Enabled when the expression `T& & IBinSaver&` is well-formed but the overload above is not.
template <class T, typename = decltype(std::declval<T&>() & std::declval<IBinSaver&>())>
void CallObjectSerialize(T* p, NBinSaverInternals::TOverloadPriority<1>) { // lower priority - will be resolved second if enabled
(*p) & (*this);
}
// Fallback without any operator&: the object is transferred as sizeof(T) raw bytes.
template <class T>
void CallObjectSerialize(T* p, NBinSaverInternals::TOverloadPriority<0>) { // lower priority - will be resolved last
#if (!defined(_MSC_VER))
// In MSVC __has_trivial_copy returns false to enums, primitive types and arrays.
// The check therefore runs only on non-MSVC compilers and rejects types that are not trivially copyable.
static_assert(__has_trivial_copy(T), "Class is nontrivial copyable, you must define operator&, see");
#endif
DataChunk(p, sizeof(T));
}
// vector
// Serializes a vector element by element: the element count goes first (as TStoredSize),
// then every element is passed through Add().
// On load the vector is cleared and resized to the stored count before the elements are read.
template <class T, class TA>
void DoVector(TVector<T, TA>& data) {
    // Zero-initialized like the size locals in DataChunkStr/TBlob loading, so that a read
    // which does not fill it can never feed an indeterminate value to resize().
    TStoredSize nSize = 0;
    if (IsReading()) {
        data.clear();
        Add(2, &nSize);
        data.resize(nSize);
    } else {
        nSize = data.size();
        // CheckOverflow aborts when the real size does not fit into TStoredSize.
        CheckOverflow(nSize, data.size());
        Add(2, &nSize);
    }
    for (TStoredSize i = 0; i < nSize; i++)
        Add(1, &data[i]);
}
// Passes every element of a fixed-size C array through Add(), in index order.
// No element count is stored: N is part of the type.
template <class T, int N>
void DoArray(T (&data)[N]) {
    for (T& item : data) {
        Add(1, &item);
    }
}
// Verifies that a container size survived narrowing to TStoredSize.
// nSize    - the size after conversion to TStoredSize
// origSize - the size as reported by the container
// Prints a diagnostic to stderr and aborts the process when the two differ.
template <typename TLarge>
void CheckOverflow(TStoredSize nSize, TLarge origSize) {
    if (nSize == origSize) {
        return;
    }
    fprintf(stderr, "IBinSaver: object size is too large to be serialized (%" PRIu32 " != %" PRIu64 ")\n", nSize, (ui64)origSize);
    abort();
}
// Serializes a vector as one raw memory chunk: the element count (TStoredSize) followed by
// sizeof(T) * count bytes taken directly from the vector storage.
// The overflow check is applied only when saving; on load the previous content of `data`
// is irrelevant and must not be able to trigger an abort.
template <class T, class TA>
void DoDataVector(TVector<T, TA>& data) {
    TStoredSize nSize = 0;
    if (IsReading()) {
        Add(1, &nSize);
        data.clear();
        data.resize(nSize);
    } else {
        nSize = data.size();
        CheckOverflow(nSize, data.size());
        Add(1, &nSize);
    }
    // An empty vector has no storage to address, so the payload is skipped.
    if (nSize > 0)
        DataChunk(&data[0], sizeof(T) * nSize);
}
template <class AM>
void DoAnyMap(AM& data) {
if (IsReading()) {
data.clear();
TStoredSize nSize;
Add(3, &nSize);
TVector<typename AM::key_type, typename std::allocator_traits<typename AM::allocator_type>::template rebind_alloc<typename AM::key_type>> indices;
indices.resize(nSize);
for (TStoredSize i = 0; i < nSize; ++i)
Add(1, &indices[i]);
for (TStoredSize i = 0; i < nSize; ++i)
Add(2, &data[indices[i]]);
} else {
TStoredSize nSize = data.size();
CheckOverflow(nSize, data.size());
Add(3, &nSize);
TVector<typename AM::key_type, typename std::allocator_traits<typename AM::allocator_type>::template rebind_alloc<typename AM::key_type>> indices;
indices.resize(nSize);
TStoredSize i = 1;
for (auto pos = data.begin(); pos != data.end(); ++pos, ++i)
indices[nSize - i] = pos->first;
for (TStoredSize j = 0; j < nSize; ++j)
Add(1, &indices[j]);
for (TStoredSize j = 0; j < nSize; ++j)
Add(2, &data[indices[j]]);
}
}
// hash_multimap
template <class AMM>
void DoAnyMultiMap(AMM& data) {
if (IsReading()) {
data.clear();
TStoredSize nSize;
Add(3, &nSize);
TVector<typename AMM::key_type, typename std::allocator_traits<typename AMM::allocator_type>::template rebind_alloc<typename AMM::key_type>> indices;
indices.resize(nSize);
for (TStoredSize i = 0; i < nSize; ++i)
Add(1, &indices[i]);
for (TStoredSize i = 0; i < nSize; ++i) {
std::pair<typename AMM::key_type, typename AMM::mapped_type> valToInsert;
valToInsert.first = indices[i];
Add(2, &valToInsert.second);
data.insert(valToInsert);
}
} else {
TStoredSize nSize = data.size();
CheckOverflow(nSize, data.size());
Add(3, &nSize);
for (auto pos = data.begin(); pos != data.end(); ++pos)
Add(1, (typename AMM::key_type*)(&pos->first));
for (auto pos = data.begin(); pos != data.end(); ++pos)
Add(2, &pos->second);
}
}
// Serializes a set-like container (used for TSet and THashSet).
// Layout: element count (TStoredSize) followed by the elements in iteration order.
// On load the container is cleared and every element read is inserted.
template <class T>
void DoAnySet(T& data) {
    if (IsReading()) {
        data.clear();
        // Zero-initialized so that a read which does not fill it cannot produce a garbage loop bound.
        TStoredSize nSize = 0;
        Add(2, &nSize);
        for (TStoredSize i = 0; i < nSize; ++i) {
            typename T::value_type member;
            Add(1, &member);
            // The local is not used afterwards, so it can be moved into the container.
            data.insert(std::move(member));
        }
    } else {
        TStoredSize nSize = data.size();
        CheckOverflow(nSize, data.size());
        Add(2, &nSize);
        for (const auto& elem : data) {
            // Set elements are immutable while Add() needs a mutable pointer - serialize a copy.
            auto member = elem;
            Add(1, &member);
        }
    }
}
// 2D array
// Element-wise serialization: X size and Y size (both as int), then every element through
// Add(), row after row (a[y][x] with x changing fastest).
// Rows and columns are walked with separate counters instead of one linear index, so no
// nXSize * nYSize product is formed and large arrays cannot overflow int.
template <class T>
void Do2DArray(TArray2D<T>& a) {
    int nXSize = a.GetXSize(), nYSize = a.GetYSize();
    Add(1, &nXSize);
    Add(2, &nYSize);
    if (IsReading())
        a.SetSizes(nXSize, nYSize);
    for (int y = 0; y < nYSize; ++y) {
        for (int x = 0; x < nXSize; ++x)
            Add(3, &a[y][x]);
    }
}
// Raw serialization of a 2D array: X size and Y size (both as int), then the whole element
// storage starting at a[0][0] as a single memory chunk.
template <class T>
void Do2DArrayData(TArray2D<T>& a) {
    int nXSize = a.GetXSize(), nYSize = a.GetYSize();
    Add(1, &nXSize);
    Add(2, &nYSize);
    if (IsReading())
        a.SetSizes(nXSize, nYSize);
    // Computed in 64 bits: an int product would overflow for large arrays and could make
    // the guard below skip the payload.
    const i64 nTotal = static_cast<i64>(nXSize) * nYSize;
    if (nTotal > 0)
        DataChunk(&a[0][0], static_cast<i64>(sizeof(T)) * nTotal);
}
// strings
template <class TStringType>
void DataChunkStr(TStringType& data, i64 elemSize) {
if (bRead) {
TStoredSize nCount = 0;
File.Read(&nCount, sizeof(TStoredSize));
data.resize(nCount);
if (nCount)
File.Read(&*data.begin(), nCount * elemSize);
} else {
TStoredSize nCount = data.size();
CheckOverflow(nCount, data.size());
File.Write(&nCount, sizeof(TStoredSize));
File.Write(data.c_str(), nCount * elemSize);
}
}
void DataChunkString(std::string& data) {
DataChunkStr(data, sizeof(char));
}
void DataChunkStroka(TString& data) {
DataChunkStr(data, sizeof(TString::char_type));
}
void DataChunkWtroka(TUtf16String& data) {
DataChunkStr(data, sizeof(wchar16));
}
void DataChunk(void* pData, i64 nSize) {
i64 chunkSize = 1 << 30;
for (i64 offset = 0; offset < nSize; offset += chunkSize) {
void* ptr = (char*)pData + offset;
i64 size = offset + chunkSize < nSize ? chunkSize : (nSize - offset);
if (bRead)
File.Read(ptr, size);
else
File.Write(ptr, size);
}
}
// storing/loading pointers to objects
// Both functions are only declared here; DoPtr() calls them for saving and loading respectively.
void StoreObject(IObjectBase* pObject);
IObjectBase* LoadObject();
// Direction flag returned by IsReading(): true - data is read from the stream, false - written to it.
bool bRead;
// Buffered wrapper over the stream passed to the constructor; raw reads and writes go through it.
TBufferedStream<> File;
// maps objects addresses during save(first) to addresses during load(second) - during loading
// or serves as a sign that some object has been already stored - during storing
// NOTE(review): the two lines above presumably describe PtrIds / Objects below rather than
// StableOutput; none of them is used in the visible part of this header - confirm in the implementation.
// StableOutput is initialized from the constructor's stableOutput argument.
bool StableOutput;
typedef THashMap<void*, ui32> PtrIdHash;
TAutoPtr<PtrIdHash> PtrIds;
typedef THashMap<ui64, TPtr<IObjectBase>> CObjectsHash;
TAutoPtr<CObjectsHash> Objects;
TVector<IObjectBase*> ObjectQueue;
public:
bool IsReading() {
return bRead;
}
// Transfers nSize raw bytes at pData without any size prefix; the chunk id is ignored.
// Direction (read or write) follows the saver mode, see DataChunk().
void AddRawData(const chunk_id, void* pData, i64 nSize) {
DataChunk(pData, nSize);
}
// return type of Add() is used to detect specialized serializer (see HasNonTrivialSerializer below)
template <class T>
char Add(const chunk_id, T* p) {
CallObjectSerialize(p, NBinSaverInternals::TOverloadPriority<2>());
return 0;
}
int Add(const chunk_id, std::string* pStr) {
DataChunkString(*pStr);
return 0;
}
int Add(const chunk_id, TString* pStr) {
DataChunkStroka(*pStr);
return 0;
}
int Add(const chunk_id, TUtf16String* pStr) {
DataChunkWtroka(*pStr);
return 0;
}
int Add(const chunk_id, TBlob* blob) {
if (bRead) {
ui64 size = 0;
File.Read(&size, sizeof(size));
TBuffer buffer;
buffer.Advance(size);
if (size > 0)
File.Read(buffer.Data(), buffer.Size());
(*blob) = TBlob::FromBuffer(buffer);
} else {
const ui64 size = blob->Size();
File.Write(&size, sizeof(size));
File.Write(blob->Data(), blob->Size());
}
return 0;
}
// TVector: element-wise when the element type has its own serializer, otherwise the whole
// storage is transferred as one raw chunk. The chunk id is ignored.
template <class T1, class TA>
int Add(const chunk_id, TVector<T1, TA>* pVec) {
    const bool elementWise = HasNonTrivialSerializer<T1>(0u);
    if (!elementWise) {
        DoDataVector(*pVec);
        return 0;
    }
    DoVector(*pVec);
    return 0;
}
// Fixed-size C array: element-wise when the element type has its own serializer, otherwise
// the array is transferred as sizeof(array) raw bytes. The chunk id is ignored.
template <class T, int N>
int Add(const chunk_id, T (*pVec)[N]) {
    const bool elementWise = HasNonTrivialSerializer<T>(0u);
    if (!elementWise) {
        DataChunk(pVec, sizeof(*pVec));
        return 0;
    }
    DoArray(*pVec);
    return 0;
}
template <class T1, class T2, class T3, class T4>
int Add(const chunk_id, TMap<T1, T2, T3, T4>* pMap) {
DoAnyMap(*pMap);
return 0;
}
template <class T1, class T2, class T3, class T4, class T5>
int Add(const chunk_id, THashMap<T1, T2, T3, T4, T5>* pHash) {
DoAnyMap(*pHash);
return 0;
}
template <class T1, class T2, class T3, class T4, class T5>
int Add(const chunk_id, THashMultiMap<T1, T2, T3, T4, T5>* pHash) {
DoAnyMultiMap(*pHash);
return 0;
}
template <class K, class L, class A>
int Add(const chunk_id, TSet<K, L, A>* pSet) {
DoAnySet(*pSet);
return 0;
}
template <class T1, class T2, class T3, class T4>
int Add(const chunk_id, THashSet<T1, T2, T3, T4>* pHash) {
DoAnySet(*pHash);
return 0;
}
// TArray2D: element-wise when the element type has its own serializer, otherwise the
// element storage is transferred as one raw chunk. The chunk id is ignored.
template <class T1>
int Add(const chunk_id, TArray2D<T1>* pArr) {
    const bool elementWise = HasNonTrivialSerializer<T1>(0u);
    if (!elementWise) {
        Do2DArrayData(*pArr);
        return 0;
    }
    Do2DArray(*pArr);
    return 0;
}
// TList: the element count is stored as int, followed by the elements in list order.
// On load the list is cleared and refilled with default-constructed elements which are then
// read in place. The chunk id is ignored.
template <class T1>
int Add(const chunk_id, TList<T1>* pList) {
    TList<T1>& data = *pList;
    // Zero-initialized so that a read which does not fill it cannot request a garbage element count.
    int nSize = 0;
    if (IsReading()) {
        Add(2, &nSize);
        data.clear();
        data.insert(data.begin(), nSize, T1());
    } else {
        nSize = data.size();
        Add(2, &nSize);
    }
    // Elements get consecutive chunk ids starting from 3, as before.
    int i = 1;
    for (T1& item : data) {
        Add(i + 2, &item);
        ++i;
    }
    return 0;
}
// std::pair: `first` is transferred before `second`. The chunk id is ignored.
template <class T1, class T2>
int Add(const chunk_id, std::pair<T1, T2>* pData) {
    std::pair<T1, T2>& entry = *pData;
    Add(1, &entry.first);
    Add(2, &entry.second);
    return 0;
}
// std::array: element-wise when the element type has its own serializer, otherwise the
// N elements are transferred as one raw chunk. No size is stored. The chunk id is ignored.
template <class T1, size_t N>
int Add(const chunk_id, std::array<T1, N>* pData) {
    if (!HasNonTrivialSerializer<T1>(0u)) {
        DataChunk((void*)pData->data(), pData->size() * sizeof(T1));
        return 0;
    }
    for (T1& item : *pData) {
        Add(1, &item);
    }
    return 0;
}
// std::bitset: transferred through its std::string representation (to_string() when saving,
// the string constructor of std::bitset when loading). The chunk id is ignored.
template <size_t N>
int Add(const chunk_id, std::bitset<N>* pData) {
    const bool loading = IsReading();
    std::string text;
    if (!loading) {
        text = pData->to_string();
    }
    Add(1, &text);
    if (loading) {
        *pData = std::bitset<N>(text);
    }
    return 0;
}
int Add(const chunk_id, TDynBitMap* pData) {
if (IsReading()) {
ui64 count = 0;
Add(1, &count);
pData->Clear();
pData->Reserve(count * sizeof(TDynBitMap::TChunk) * 8);
for (ui64 i = 0; i < count; ++i) {
TDynBitMap::TChunk chunk = 0;
Add(i + 1, &chunk);
if (i > 0) {
pData->LShift(8 * sizeof(TDynBitMap::TChunk));
}
pData->Or(chunk);
}
} else {
ui64 count = pData->GetChunkCount();
Add(1, &count);
for (ui64 i = 0; i < count; ++i) {
// Write in reverse order
TDynBitMap::TChunk chunk = pData->GetChunks()[count - i - 1];
Add(i + 1, &chunk);
}
}
return 0;
}
// Loads the variant alternative selected by a runtime index.
// Do<T0, TTail...> walks the type list: index 0 selects T0, any other index is passed on to
// the tail with the index decremented. The selected type is default-constructed, read with
// Add() and move-assigned into the variant.
template <class TVariantClass>
struct TLoadFromTypeFromListHelper {
    template <class T0, class... TTail>
    static void Do(IBinSaver& binSaver, ui32 typeIndex, TVariantClass* pData) {
        if constexpr (sizeof...(TTail) != 0) {
            if (typeIndex != 0) {
                Do<TTail...>(binSaver, typeIndex - 1, pData);
                return;
            }
            Do<T0>(binSaver, 0, pData);
        } else {
            // Single remaining candidate: the index is expected to have reached zero.
            Y_ASSERT(typeIndex == 0);
            T0 loaded;
            binSaver.Add(2, &loaded);
            *pData = std::move(loaded);
        }
    }
};
// std::variant: the index of the active alternative (ui32) followed by the alternative itself.
// On load the index selects the type to construct and read. The chunk id is ignored.
template <class... TVariantTypes>
int Add(const chunk_id, std::variant<TVariantTypes...>* pData) {
    static_assert(std::variant_size_v<std::variant<TVariantTypes...>> < Max<ui32>());

    // Zero-initialized so that a read which does not fill it cannot select an alternative
    // from an indeterminate value.
    ui32 index = 0;
    if (IsReading()) {
        Add(1, &index);
        TLoadFromTypeFromListHelper<std::variant<TVariantTypes...>>::template Do<TVariantTypes...>(
            *this,
            index,
            pData
        );
    } else {
        index = pData->index(); // type cast is safe because of static_assert check above
        Add(1, &index);
        std::visit([&](auto& dst) -> void { Add(2, &dst); }, *pData);
    }
    return 0;
}
// Serializes the object behind pObject by applying operator& to it with this saver.
// The chunk id is ignored.
void AddPolymorphicBase(chunk_id, IObjectBase* pObject) {
    IObjectBase& object = *pObject;
    object & (*this);
}
// Serializes a smart pointer: on load the object returned by LoadObject() is cast to the
// user type and assigned to the pointer; on save the bare pointer is handed to StoreObject().
// pData must not be null - it is dereferenced unconditionally.
// (A former empty `if (pData && pData->Get()) {}` statement was dropped: it had no body and
// its null check contradicted the unconditional dereference below.)
template <class T1, class T2>
void DoPtr(TPtrBase<T1, T2>* pData) {
    if (IsReading())
        pData->Set(CastToUserObject(LoadObject(), static_cast<T1*>(nullptr)));
    else
        StoreObject(pData->GetBarePtr());
}
// TMaybe: a bool "defined" flag followed by the value when the flag is set.
// On load a set flag default-constructs the value and reads into it; a cleared flag empties
// the destination so that a value it held before loading cannot survive and break the
// save/load round trip. The chunk id is ignored.
template <class T, class TPolicy>
int Add(const chunk_id, TMaybe<T, TPolicy>* pData) {
    TMaybe<T, TPolicy>& data = *pData;
    if (IsReading()) {
        bool defined = false;
        Add(1, &defined);
        if (defined) {
            data = T();
            Add(2, data.Get());
        } else {
            data.Clear();
        }
    } else {
        bool defined = data.Defined();
        Add(1, &defined);
        if (defined) {
            Add(2, data.Get());
        }
    }
    return 0;
}
// Serializes any number of objects, left to right, each through Add() with chunk id 0.
// This is what the SAVELOAD* macros expand to.

// Single-argument form.
template <typename TOne>
void AddMulti(TOne& one) {
    Add(0, &one);
}
// Variadic form: the head first, then every remaining argument in order.
template <typename THead, typename... TTail>
void AddMulti(THead& head, TTail&... tail) {
    Add(0, &head);
    (static_cast<void>(Add(0, &tail)), ...);
}
// Tells the container overloads of Add() whether elements of type T must be serialized one
// by one (true) or may be transferred as a single raw memory chunk (false).
// Callers pass 0u: the ui32 overload is preferred over the `...` one whenever it is enabled,
// which is the case when the expression `T& & IBinSaver&` is well-formed.
template <class T, typename = decltype(std::declval<T&>() & std::declval<IBinSaver&>())>
static bool HasNonTrivialSerializer(ui32) {
return true;
}
// Fallback: looks at the return type of Add(chunk_id, T*). The generic template Add returns
// char, the dedicated overloads in this header return int, so a result size other than 1
// means a dedicated overload exists for T.
template <class T>
static bool HasNonTrivialSerializer(...) {
return sizeof(std::declval<IBinSaver*>()->Add(0, std::declval<T*>())) != 1;
}
public:
// Binds the saver to `stream`.
// _bRead       - true to read from the stream, false to write to it; also passed to the buffered stream wrapper
// stableOutput - stored in StableOutput (not used in the visible part of this header)
IBinSaver(IBinaryStream& stream, bool _bRead, bool stableOutput = false)
: bRead(_bRead)
, File(_bRead, stream)
, StableOutput(stableOutput)
{
}
// Declared only; the definition is not in the visible part of this header.
virtual ~IBinSaver();
// Reports the validity state of the underlying buffered stream.
bool IsValid() const {
return File.IsValid();
}
};
// Definition of the serialization operator that TPtrBase only forward-declares:
// the pointer hands itself to IBinSaver::DoPtr() and always returns 0.
template <class TUserObj, class TRef>
int TPtrBase<TUserObj, TRef>::operator&(IBinSaver& f) {
f.DoPtr(this);
return 0;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Class factory used by the registration helper below; defined elsewhere.
extern TClassFactory<IObjectBase>* pSaverClasses;
// Called before every registration; defined elsewhere.
void StartRegisterSaveload();

// Registration helper: constructing an instance registers TReg in pSaverClasses under the
// numeric id `num`. The REGISTER_SAVELOAD_* macros create static instances of it.
template <class TReg>
struct TRegisterSaveLoadType {
    TRegisterSaveLoadType(int num) {
        StartRegisterSaveload();
        // The null pointer only carries the type to register.
        pSaverClasses->RegisterType(num, TReg::NewSaveLoadNullItem, static_cast<TReg*>(nullptr));
    }
};
// Registers class `name` for save/load with an id derived from its name:
// MurmurHash<int> over the stringized name. sizeof(#name) counts the terminating NUL,
// so the terminator is part of the hashed bytes.
#define Y_BINSAVER_REGISTER(name) \
BASIC_REGISTER_CLASS(name) \
static TRegisterSaveLoadType<name> init##name(MurmurHash<int>(#name, sizeof(#name)));
// Registers class `name` for save/load under the explicit numeric id N.
// The static registrar is named init<name><N>.
#define REGISTER_SAVELOAD_CLASS(N, name) \
BASIC_REGISTER_CLASS(name) \
static TRegisterSaveLoadType<name> init##name##N(N);
// Registration of class template instantiations under the numeric id N.
// Unlike REGISTER_SAVELOAD_CLASS these macros do not expand BASIC_REGISTER_CLASS.
// The template arguments are pasted into the registrar name, so they must be plain identifiers.
// using TObj/TRef on forward declared templ class will not work
// but multiple registration with same id is allowed
#define REGISTER_SAVELOAD_TEMPL1_CLASS(N, className, T) \
static TRegisterSaveLoadType<className<T>> init##className##T##N(N);
// Two template arguments; additionally declares a typedef of the instantiation.
#define REGISTER_SAVELOAD_TEMPL2_CLASS(N, className, T1, T2) \
typedef className<T1, T2> temp##className##T1##_##T2##temp; \
static TRegisterSaveLoadType<className<T1, T2>> init##className##T1##_##T2##N(N);
// Three template arguments; additionally declares a typedef of the instantiation.
#define REGISTER_SAVELOAD_TEMPL3_CLASS(N, className, T1, T2, T3) \
typedef className<T1, T2, T3> temp##className##T1##_##T2##_##T3##temp; \
static TRegisterSaveLoadType<className<T1, T2, T3>> init##className##T1##_##T2##_##T3##N(N);
// Registration of classes that live in a namespace, under the numeric id N.
// NOTE(review): the registrar identifiers below paste the literal token `name`, which is not
// a parameter of these macros, so className never becomes part of the identifier. Two
// registrations with the same namespace(s) and the same N in one translation unit therefore
// produce the same identifier. Pasting className instead would reject callers that pass a
// qualified or templated class name, so this is left unchanged - confirm what is intended.

// One namespace level: nmspace::className.
#define REGISTER_SAVELOAD_NM_CLASS(N, nmspace, className) \
BASIC_REGISTER_CLASS(nmspace::className) \
static TRegisterSaveLoadType<nmspace::className> init_##nmspace##_##name##N(N);
// Two namespace levels: nmspace1::nmspace2::className.
#define REGISTER_SAVELOAD_NM2_CLASS(N, nmspace1, nmspace2, className) \
BASIC_REGISTER_CLASS(nmspace1::nmspace2::className) \
static TRegisterSaveLoadType<nmspace1::nmspace2::className> init_##nmspace1##_##nmspace2##_##name##N(N);
// Class template with one argument inside a namespace: nmspace::className<T>.
#define REGISTER_SAVELOAD_TEMPL1_NM_CLASS(N, nmspace, className, T) \
typedef nmspace::className<T> temp_init##nmspace##className##T##temp; \
BASIC_REGISTER_CLASS(nmspace::className<T>) \
static TRegisterSaveLoadType<nmspace::className<T>> temp_init##nmspace##_##name##T##N(N);
// Registers `cls` under the numeric id N; `name` is used only to build the identifier of the
// static registrar (init<name><N>), so `cls` itself may be any type expression.
#define REGISTER_SAVELOAD_CLASS_NAME(N, cls, name) \
BASIC_REGISTER_CLASS(cls) \
static TRegisterSaveLoadType<cls> init##name##N(N);
// Registers ns::cls under the id N, building the registrar name from `pref` and `cls`.
#define REGISTER_SAVELOAD_CLASS_NS_PREF(N, cls, ns, pref) \
REGISTER_SAVELOAD_CLASS_NAME(N, ns ::cls, _##pref##_##cls)
// Helpers that generate `int operator&(IBinSaver&)` inside a class body.
// Every variant passes the listed members to IBinSaver::AddMulti() in the given order and returns 0.

// Plain (non-override) operator for the listed members.
#define SAVELOAD(...) \
int operator&(IBinSaver& f) { \
f.AddMulti(__VA_ARGS__); \
return 0; \
}
// Same as SAVELOAD, but the operator is marked `override`; the base class part is not serialized.
#define SAVELOAD_OVERRIDE_WITHOUT_BASE(...) \
int operator&(IBinSaver& f) override { \
f.AddMulti(__VA_ARGS__); \
return 0; \
}
// Override that serializes the `base` part first (base::operator&), then the listed members.
#define SAVELOAD_OVERRIDE(base, ...) \
int operator&(IBinSaver& f) override { \
base::operator&(f); \
f.AddMulti(__VA_ARGS__); \
return 0; \
}
// Non-override variant that serializes the base part through a type named TBase,
// which must be visible in the class where the macro is used.
#define SAVELOAD_BASE(...) \
int operator&(IBinSaver& f) { \
TBase::operator&(f); \
f.AddMulti(__VA_ARGS__); \
return 0; \
}