diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/binsaver | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/binsaver')
-rw-r--r-- | library/cpp/binsaver/bin_saver.cpp | 81 | ||||
-rw-r--r-- | library/cpp/binsaver/bin_saver.h | 646 | ||||
-rw-r--r-- | library/cpp/binsaver/blob_io.cpp | 1 | ||||
-rw-r--r-- | library/cpp/binsaver/blob_io.h | 47 | ||||
-rw-r--r-- | library/cpp/binsaver/buffered_io.cpp | 39 | ||||
-rw-r--r-- | library/cpp/binsaver/buffered_io.h | 134 | ||||
-rw-r--r-- | library/cpp/binsaver/class_factory.h | 105 | ||||
-rw-r--r-- | library/cpp/binsaver/mem_io.cpp | 1 | ||||
-rw-r--r-- | library/cpp/binsaver/mem_io.h | 212 | ||||
-rw-r--r-- | library/cpp/binsaver/ut/binsaver_ut.cpp | 198 | ||||
-rw-r--r-- | library/cpp/binsaver/ut/ya.make | 11 | ||||
-rw-r--r-- | library/cpp/binsaver/ut_util/README.md | 1 | ||||
-rw-r--r-- | library/cpp/binsaver/ut_util/ut_util.cpp | 1 | ||||
-rw-r--r-- | library/cpp/binsaver/ut_util/ut_util.h | 71 | ||||
-rw-r--r-- | library/cpp/binsaver/ut_util/ya.make | 14 | ||||
-rw-r--r-- | library/cpp/binsaver/util_stream_io.cpp | 1 | ||||
-rw-r--r-- | library/cpp/binsaver/util_stream_io.h | 86 | ||||
-rw-r--r-- | library/cpp/binsaver/ya.make | 18 |
18 files changed, 1667 insertions, 0 deletions
diff --git a/library/cpp/binsaver/bin_saver.cpp b/library/cpp/binsaver/bin_saver.cpp
new file mode 100644
index 0000000000..fe0775af9f
--- /dev/null
+++ b/library/cpp/binsaver/bin_saver.cpp
@@ -0,0 +1,81 @@
+#include "bin_saver.h"
+
+#include <cstdint>
+
+// Global registry of IObjectBase-derived classes used to save/load objects by type id.
+// Stays null until StartRegisterSaveload() is called for the first time.
+TClassFactory<IObjectBase>* pSaverClasses;
+
+// Allocates the global class registry on first use; subsequent calls are no-ops.
+void StartRegisterSaveload() {
+    if (!pSaverClasses)
+        pSaverClasses = new TClassFactory<IObjectBase>;
+}
+
+// Namespace-scope guard object: its destructor releases the global class registry.
+struct SBasicChunkInit {
+    ~SBasicChunkInit() {
+        // delete on a null pointer is a no-op, so no explicit check is needed
+        delete pSaverClasses;
+    }
+} initSaver;
+
+//////////////////////////////////////////////////////////////////////////
+// Stores a reference to pObject in the stream.
+//
+// A 64-bit object id is emitted first (0 for a null pointer). With StableOutput the id
+// is a sequence number assigned in first-seen order, otherwise it is derived from the
+// object's address. The first time a non-zero id is seen, the object is appended to
+// ObjectQueue and its registered type id is emitted right after the object id.
+// Aborts the process when the object's class has not been registered.
+void IBinSaver::StoreObject(IObjectBase* pObject) {
+    if (pObject) {
+        Y_ASSERT(pSaverClasses && pSaverClasses->GetObjectTypeID(pObject) != -1 && "trying to save unregistered object");
+    }
+
+    // The address is used only as an opaque identity. Convert it through uintptr_t:
+    // subtracting a null pointer from an object pointer is undefined behaviour.
+    ui64 ptrId = static_cast<ui64>(reinterpret_cast<uintptr_t>(pObject));
+    if (StableOutput) {
+        ui32 id = 0;
+        if (pObject) {
+            if (!PtrIds.Get())
+                PtrIds.Reset(new PtrIdHash);
+            PtrIdHash::iterator pFound = PtrIds->find(pObject);
+            if (pFound != PtrIds->end())
+                id = pFound->second;
+            else {
+                // ids start at 1 because 0 is reserved for the null pointer
+                id = PtrIds->ysize() + 1;
+                PtrIds->insert(std::make_pair(pObject, id));
+            }
+        }
+        ptrId = id;
+    }
+
+    DataChunk(&ptrId, sizeof(ptrId));
+    if (!Objects.Get())
+        Objects.Reset(new CObjectsHash);
+    if (ptrId != 0 && Objects->find(ptrId) == Objects->end()) {
+        ObjectQueue.push_back(pObject);
+        // create an entry for this id so the object is not queued a second time
+        (*Objects)[ptrId];
+        // No registry at all means no class was ever registered, so report the same
+        // "unregistered object" error instead of dereferencing a null pointer.
+        int typeId = pSaverClasses ? pSaverClasses->GetObjectTypeID(pObject) : -1;
+        if (typeId == -1) {
+            fprintf(stderr, "IBinSaver: trying to save unregistered object\n");
+            abort();
+        }
+        DataChunk(&typeId, sizeof(typeId));
+    }
+}
+
+// Loads an object reference previously written by StoreObject().
+// Reads the object id; 0 yields nullptr, an id that is already present in Objects
+// yields the object stored there. Otherwise the type id is read and a new object
+// is created through the global class registry.
+IObjectBase* IBinSaver::LoadObject() {
+    ui64 ptrId = 0;
+    DataChunk(&ptrId, sizeof(ptrId));
+    if (ptrId != 0) {
+        if (!Objects.Get())
+            Objects.Reset(new CObjectsHash);
+        CObjectsHash::iterator pFound = Objects->find(ptrId);
+        if (pFound != Objects->end())
+            return pFound->second;
+        int typeId = 0;
+        DataChunk(&typeId, sizeof(typeId));
+        // A missing registry is handled like an unknown type id (null result below).
+        IObjectBase* pObj = pSaverClasses ? pSaverClasses->CreateObject(typeId) : nullptr;
+        Y_ASSERT(pObj != nullptr);
+        if (pObj == nullptr) {
fprintf(stderr, "IBinSaver: trying to load unregistered object\n"); + abort(); + } + (*Objects)[ptrId] = pObj; + ObjectQueue.push_back(pObj); + return pObj; + } + return nullptr; +} + +IBinSaver::~IBinSaver() { + for (size_t i = 0; i < ObjectQueue.size(); ++i) { + AddPolymorphicBase(1, ObjectQueue[i]); + } +} diff --git a/library/cpp/binsaver/bin_saver.h b/library/cpp/binsaver/bin_saver.h new file mode 100644 index 0000000000..412424889f --- /dev/null +++ b/library/cpp/binsaver/bin_saver.h @@ -0,0 +1,646 @@ +#pragma once + +#include "buffered_io.h" +#include "class_factory.h" + +#include <library/cpp/containers/2d_array/2d_array.h> + +#include <util/generic/hash_set.h> +#include <util/generic/buffer.h> +#include <util/generic/list.h> +#include <util/generic/maybe.h> +#include <util/generic/bitmap.h> +#include <util/generic/variant.h> +#include <util/generic/ylimits.h> +#include <util/memory/blob.h> +#include <util/digest/murmur.h> + +#include <array> +#include <bitset> +#include <list> +#include <string> + +#ifdef _MSC_VER +#pragma warning(disable : 4127) +#endif + +enum ESaverMode { + SAVER_MODE_READ = 1, + SAVER_MODE_WRITE = 2, + SAVER_MODE_WRITE_COMPRESSED = 3, +}; + +namespace NBinSaverInternals { + // This lets explicitly control the overload resolution priority + // The higher P means higher priority in overload resolution order + template <int P> + struct TOverloadPriority : TOverloadPriority <P-1> { + }; + + template <> + struct TOverloadPriority<0> { + }; +} + +////////////////////////////////////////////////////////////////////////// +struct IBinSaver { +public: + typedef unsigned char chunk_id; + typedef ui32 TStoredSize; // changing this will break compatibility + +private: + // This overload is required to avoid infinite recursion when overriding serialization in derived classes: + // struct B { + // virtual int operator &(IBinSaver& f) { + // return 0; + // } + // }; + // + // struct D : B { + // int operator &(IBinSaver& f) override { + // f.Add(0, 
static_cast<B*>(this)); + // return 0; + // } + // }; + template <class T, typename = decltype(std::declval<T*>()->T::operator&(std::declval<IBinSaver&>()))> + void CallObjectSerialize(T* p, NBinSaverInternals::TOverloadPriority<2>) { // highest priority - will be resolved first if enabled + // Note: p->operator &(*this) would lead to infinite recursion + p->T::operator&(*this); + } + + template <class T, typename = decltype(std::declval<T&>() & std::declval<IBinSaver&>())> + void CallObjectSerialize(T* p, NBinSaverInternals::TOverloadPriority<1>) { // lower priority - will be resolved second if enabled + (*p) & (*this); + } + + template <class T> + void CallObjectSerialize(T* p, NBinSaverInternals::TOverloadPriority<0>) { // lower priority - will be resolved last +#if (!defined(_MSC_VER)) + // In MSVC __has_trivial_copy returns false to enums, primitive types and arrays. + static_assert(__has_trivial_copy(T), "Class is nontrivial copyable, you must define operator&, see"); +#endif + DataChunk(p, sizeof(T)); + } + + // vector + template <class T, class TA> + void DoVector(TVector<T, TA>& data) { + TStoredSize nSize; + if (IsReading()) { + data.clear(); + Add(2, &nSize); + data.resize(nSize); + } else { + nSize = data.size(); + CheckOverflow(nSize, data.size()); + Add(2, &nSize); + } + for (TStoredSize i = 0; i < nSize; i++) + Add(1, &data[i]); + } + + template <class T, int N> + void DoArray(T (&data)[N]) { + for (size_t i = 0; i < N; i++) { + Add(1, &(data[i])); + } + } + + template <typename TLarge> + void CheckOverflow(TStoredSize nSize, TLarge origSize) { + if (nSize != origSize) { + fprintf(stderr, "IBinSaver: object size is too large to be serialized (%" PRIu32 " != %" PRIu64 ")\n", nSize, (ui64)origSize); + abort(); + } + } + + template <class T, class TA> + void DoDataVector(TVector<T, TA>& data) { + TStoredSize nSize = data.size(); + CheckOverflow(nSize, data.size()); + Add(1, &nSize); + if (IsReading()) { + data.clear(); + data.resize(nSize); + } + if 
(nSize > 0) + DataChunk(&data[0], sizeof(T) * nSize); + } + + template <class AM> + void DoAnyMap(AM& data) { + if (IsReading()) { + data.clear(); + TStoredSize nSize; + Add(3, &nSize); + TVector<typename AM::key_type, typename std::allocator_traits<typename AM::allocator_type>::template rebind_alloc<typename AM::key_type>> indices; + indices.resize(nSize); + for (TStoredSize i = 0; i < nSize; ++i) + Add(1, &indices[i]); + for (TStoredSize i = 0; i < nSize; ++i) + Add(2, &data[indices[i]]); + } else { + TStoredSize nSize = data.size(); + CheckOverflow(nSize, data.size()); + Add(3, &nSize); + + TVector<typename AM::key_type, typename std::allocator_traits<typename AM::allocator_type>::template rebind_alloc<typename AM::key_type>> indices; + indices.resize(nSize); + TStoredSize i = 1; + for (auto pos = data.begin(); pos != data.end(); ++pos, ++i) + indices[nSize - i] = pos->first; + for (TStoredSize j = 0; j < nSize; ++j) + Add(1, &indices[j]); + for (TStoredSize j = 0; j < nSize; ++j) + Add(2, &data[indices[j]]); + } + } + + // hash_multimap + template <class AMM> + void DoAnyMultiMap(AMM& data) { + if (IsReading()) { + data.clear(); + TStoredSize nSize; + Add(3, &nSize); + TVector<typename AMM::key_type, typename std::allocator_traits<typename AMM::allocator_type>::template rebind_alloc<typename AMM::key_type>> indices; + indices.resize(nSize); + for (TStoredSize i = 0; i < nSize; ++i) + Add(1, &indices[i]); + for (TStoredSize i = 0; i < nSize; ++i) { + std::pair<typename AMM::key_type, typename AMM::mapped_type> valToInsert; + valToInsert.first = indices[i]; + Add(2, &valToInsert.second); + data.insert(valToInsert); + } + } else { + TStoredSize nSize = data.size(); + CheckOverflow(nSize, data.size()); + Add(3, &nSize); + for (auto pos = data.begin(); pos != data.end(); ++pos) + Add(1, (typename AMM::key_type*)(&pos->first)); + for (auto pos = data.begin(); pos != data.end(); ++pos) + Add(2, &pos->second); + } + } + + template <class T> + void DoAnySet(T& data) { + 
if (IsReading()) { + data.clear(); + TStoredSize nSize; + Add(2, &nSize); + for (TStoredSize i = 0; i < nSize; ++i) { + typename T::value_type member; + Add(1, &member); + data.insert(member); + } + } else { + TStoredSize nSize = data.size(); + CheckOverflow(nSize, data.size()); + Add(2, &nSize); + for (const auto& elem : data) { + auto member = elem; + Add(1, &member); + } + } + } + + // 2D array + template <class T> + void Do2DArray(TArray2D<T>& a) { + int nXSize = a.GetXSize(), nYSize = a.GetYSize(); + Add(1, &nXSize); + Add(2, &nYSize); + if (IsReading()) + a.SetSizes(nXSize, nYSize); + for (int i = 0; i < nXSize * nYSize; i++) + Add(3, &a[i / nXSize][i % nXSize]); + } + template <class T> + void Do2DArrayData(TArray2D<T>& a) { + int nXSize = a.GetXSize(), nYSize = a.GetYSize(); + Add(1, &nXSize); + Add(2, &nYSize); + if (IsReading()) + a.SetSizes(nXSize, nYSize); + if (nXSize * nYSize > 0) + DataChunk(&a[0][0], sizeof(T) * nXSize * nYSize); + } + // strings + template <class TStringType> + void DataChunkStr(TStringType& data, i64 elemSize) { + if (bRead) { + TStoredSize nCount = 0; + File.Read(&nCount, sizeof(TStoredSize)); + data.resize(nCount); + if (nCount) + File.Read(&*data.begin(), nCount * elemSize); + } else { + TStoredSize nCount = data.size(); + CheckOverflow(nCount, data.size()); + File.Write(&nCount, sizeof(TStoredSize)); + File.Write(data.c_str(), nCount * elemSize); + } + } + void DataChunkString(std::string& data) { + DataChunkStr(data, sizeof(char)); + } + void DataChunkStroka(TString& data) { + DataChunkStr(data, sizeof(TString::char_type)); + } + void DataChunkWtroka(TUtf16String& data) { + DataChunkStr(data, sizeof(wchar16)); + } + + void DataChunk(void* pData, i64 nSize) { + i64 chunkSize = 1 << 30; + for (i64 offset = 0; offset < nSize; offset += chunkSize) { + void* ptr = (char*)pData + offset; + i64 size = offset + chunkSize < nSize ? 
chunkSize : (nSize - offset); + if (bRead) + File.Read(ptr, size); + else + File.Write(ptr, size); + } + } + + // storing/loading pointers to objects + void StoreObject(IObjectBase* pObject); + IObjectBase* LoadObject(); + + bool bRead; + TBufferedStream<> File; + // maps objects addresses during save(first) to addresses during load(second) - during loading + // or serves as a sign that some object has been already stored - during storing + bool StableOutput; + + typedef THashMap<void*, ui32> PtrIdHash; + TAutoPtr<PtrIdHash> PtrIds; + + typedef THashMap<ui64, TPtr<IObjectBase>> CObjectsHash; + TAutoPtr<CObjectsHash> Objects; + + TVector<IObjectBase*> ObjectQueue; + +public: + bool IsReading() { + return bRead; + } + void AddRawData(const chunk_id, void* pData, i64 nSize) { + DataChunk(pData, nSize); + } + + // return type of Add() is used to detect specialized serializer (see HasNonTrivialSerializer below) + template <class T> + char Add(const chunk_id, T* p) { + CallObjectSerialize(p, NBinSaverInternals::TOverloadPriority<2>()); + return 0; + } + int Add(const chunk_id, std::string* pStr) { + DataChunkString(*pStr); + return 0; + } + int Add(const chunk_id, TString* pStr) { + DataChunkStroka(*pStr); + return 0; + } + int Add(const chunk_id, TUtf16String* pStr) { + DataChunkWtroka(*pStr); + return 0; + } + int Add(const chunk_id, TBlob* blob) { + if (bRead) { + ui64 size = 0; + File.Read(&size, sizeof(size)); + TBuffer buffer; + buffer.Advance(size); + if (size > 0) + File.Read(buffer.Data(), buffer.Size()); + (*blob) = TBlob::FromBuffer(buffer); + } else { + const ui64 size = blob->Size(); + File.Write(&size, sizeof(size)); + File.Write(blob->Data(), blob->Size()); + } + return 0; + } + template <class T1, class TA> + int Add(const chunk_id, TVector<T1, TA>* pVec) { + if (HasNonTrivialSerializer<T1>(0u)) + DoVector(*pVec); + else + DoDataVector(*pVec); + return 0; + } + + template <class T, int N> + int Add(const chunk_id, T (*pVec)[N]) { + if 
(HasNonTrivialSerializer<T>(0u)) + DoArray(*pVec); + else + DataChunk(pVec, sizeof(*pVec)); + return 0; + } + + template <class T1, class T2, class T3, class T4> + int Add(const chunk_id, TMap<T1, T2, T3, T4>* pMap) { + DoAnyMap(*pMap); + return 0; + } + template <class T1, class T2, class T3, class T4, class T5> + int Add(const chunk_id, THashMap<T1, T2, T3, T4, T5>* pHash) { + DoAnyMap(*pHash); + return 0; + } + template <class T1, class T2, class T3, class T4, class T5> + int Add(const chunk_id, THashMultiMap<T1, T2, T3, T4, T5>* pHash) { + DoAnyMultiMap(*pHash); + return 0; + } + template <class K, class L, class A> + int Add(const chunk_id, TSet<K, L, A>* pSet) { + DoAnySet(*pSet); + return 0; + } + template <class T1, class T2, class T3, class T4> + int Add(const chunk_id, THashSet<T1, T2, T3, T4>* pHash) { + DoAnySet(*pHash); + return 0; + } + + template <class T1> + int Add(const chunk_id, TArray2D<T1>* pArr) { + if (HasNonTrivialSerializer<T1>(0u)) + Do2DArray(*pArr); + else + Do2DArrayData(*pArr); + return 0; + } + template <class T1> + int Add(const chunk_id, TList<T1>* pList) { + TList<T1>& data = *pList; + if (IsReading()) { + int nSize; + Add(2, &nSize); + data.clear(); + data.insert(data.begin(), nSize, T1()); + } else { + int nSize = data.size(); + Add(2, &nSize); + } + int i = 1; + for (typename TList<T1>::iterator k = data.begin(); k != data.end(); ++k, ++i) + Add(i + 2, &(*k)); + return 0; + } + template <class T1, class T2> + int Add(const chunk_id, std::pair<T1, T2>* pData) { + Add(1, &(pData->first)); + Add(2, &(pData->second)); + return 0; + } + + template <class T1, size_t N> + int Add(const chunk_id, std::array<T1, N>* pData) { + if (HasNonTrivialSerializer<T1>(0u)) { + for (size_t i = 0; i < N; ++i) + Add(1, &(*pData)[i]); + } else { + DataChunk((void*)pData->data(), pData->size() * sizeof(T1)); + } + return 0; + } + + template <size_t N> + int Add(const chunk_id, std::bitset<N>* pData) { + if (IsReading()) { + std::string s; + Add(1, &s); 
+ *pData = std::bitset<N>(s); + } else { + std::string s = pData->template to_string<char, std::char_traits<char>, std::allocator<char>>(); + Add(1, &s); + } + return 0; + } + + int Add(const chunk_id, TDynBitMap* pData) { + if (IsReading()) { + ui64 count = 0; + Add(1, &count); + pData->Clear(); + pData->Reserve(count * sizeof(TDynBitMap::TChunk) * 8); + for (ui64 i = 0; i < count; ++i) { + TDynBitMap::TChunk chunk = 0; + Add(i + 1, &chunk); + if (i > 0) { + pData->LShift(8 * sizeof(TDynBitMap::TChunk)); + } + pData->Or(chunk); + } + } else { + ui64 count = pData->GetChunkCount(); + Add(1, &count); + for (ui64 i = 0; i < count; ++i) { + // Write in reverse order + TDynBitMap::TChunk chunk = pData->GetChunks()[count - i - 1]; + Add(i + 1, &chunk); + } + } + return 0; + } + + template <class TVariantClass> + struct TLoadFromTypeFromListHelper { + template <class T0, class... TTail> + static void Do(IBinSaver& binSaver, ui32 typeIndex, TVariantClass* pData) { + if constexpr (sizeof...(TTail) == 0) { + Y_ASSERT(typeIndex == 0); + T0 chunk; + binSaver.Add(2, &chunk); + *pData = std::move(chunk); + } else { + if (typeIndex == 0) { + Do<T0>(binSaver, 0, pData); + } else { + Do<TTail...>(binSaver, typeIndex - 1, pData); + } + } + } + }; + + template <class... 
TVariantTypes> + int Add(const chunk_id, std::variant<TVariantTypes...>* pData) { + static_assert(std::variant_size_v<std::variant<TVariantTypes...>> < Max<ui32>()); + + ui32 index; + if (IsReading()) { + Add(1, &index); + TLoadFromTypeFromListHelper<std::variant<TVariantTypes...>>::template Do<TVariantTypes...>( + *this, + index, + pData + ); + } else { + index = pData->index(); // type cast is safe because of static_assert check above + Add(1, &index); + std::visit([&](auto& dst) -> void { Add(2, &dst); }, *pData); + } + return 0; + } + + + void AddPolymorphicBase(chunk_id, IObjectBase* pObject) { + (*pObject) & (*this); + } + + template <class T1, class T2> + void DoPtr(TPtrBase<T1, T2>* pData) { + if (pData && pData->Get()) { + } + if (IsReading()) + pData->Set(CastToUserObject(LoadObject(), (T1*)nullptr)); + else + StoreObject(pData->GetBarePtr()); + } + template <class T, class TPolicy> + int Add(const chunk_id, TMaybe<T, TPolicy>* pData) { + TMaybe<T, TPolicy>& data = *pData; + if (IsReading()) { + bool defined = false; + Add(1, &defined); + if (defined) { + data = T(); + Add(2, data.Get()); + } + } else { + bool defined = data.Defined(); + Add(1, &defined); + if (defined) { + Add(2, data.Get()); + } + } + return 0; + } + + template <typename TOne> + void AddMulti(TOne& one) { + Add(0, &one); + } + + template <typename THead, typename... TTail> + void AddMulti(THead& head, TTail&... tail) { + Add(0, &head); + AddMulti(tail...); + } + + template <class T, typename = decltype(std::declval<T&>() & std::declval<IBinSaver&>())> + static bool HasNonTrivialSerializer(ui32) { + return true; + } + + template <class T> + static bool HasNonTrivialSerializer(...) 
{ + return sizeof(std::declval<IBinSaver*>()->Add(0, std::declval<T*>())) != 1; + } + +public: + IBinSaver(IBinaryStream& stream, bool _bRead, bool stableOutput = false) + : bRead(_bRead) + , File(_bRead, stream) + , StableOutput(stableOutput) + { + } + virtual ~IBinSaver(); + bool IsValid() const { + return File.IsValid(); + } +}; + +// realisation of forward declared serialisation operator +template <class TUserObj, class TRef> +int TPtrBase<TUserObj, TRef>::operator&(IBinSaver& f) { + f.DoPtr(this); + return 0; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +extern TClassFactory<IObjectBase>* pSaverClasses; +void StartRegisterSaveload(); + +template <class TReg> +struct TRegisterSaveLoadType { + TRegisterSaveLoadType(int num) { + StartRegisterSaveload(); + pSaverClasses->RegisterType(num, TReg::NewSaveLoadNullItem, (TReg*)nullptr); + } +}; + +#define Y_BINSAVER_REGISTER(name) \ + BASIC_REGISTER_CLASS(name) \ + static TRegisterSaveLoadType<name> init##name(MurmurHash<int>(#name, sizeof(#name))); + +#define REGISTER_SAVELOAD_CLASS(N, name) \ + BASIC_REGISTER_CLASS(name) \ + static TRegisterSaveLoadType<name> init##name##N(N); + +// using TObj/TRef on forward declared templ class will not work +// but multiple registration with same id is allowed +#define REGISTER_SAVELOAD_TEMPL1_CLASS(N, className, T) \ + static TRegisterSaveLoadType<className<T>> init##className##T##N(N); + +#define REGISTER_SAVELOAD_TEMPL2_CLASS(N, className, T1, T2) \ + typedef className<T1, T2> temp##className##T1##_##T2##temp; \ + static TRegisterSaveLoadType<className<T1, T2>> init##className##T1##_##T2##N(N); + +#define REGISTER_SAVELOAD_TEMPL3_CLASS(N, className, T1, T2, T3) \ + typedef className<T1, T2, T3> temp##className##T1##_##T2##_##T3##temp; \ + static TRegisterSaveLoadType<className<T1, T2, T3>> init##className##T1##_##T2##_##T3##N(N); + +#define REGISTER_SAVELOAD_NM_CLASS(N, nmspace, className) \ + 
BASIC_REGISTER_CLASS(nmspace::className) \ + static TRegisterSaveLoadType<nmspace::className> init_##nmspace##_##name##N(N); + +#define REGISTER_SAVELOAD_NM2_CLASS(N, nmspace1, nmspace2, className) \ + BASIC_REGISTER_CLASS(nmspace1::nmspace2::className) \ + static TRegisterSaveLoadType<nmspace1::nmspace2::className> init_##nmspace1##_##nmspace2##_##name##N(N); + +#define REGISTER_SAVELOAD_TEMPL1_NM_CLASS(N, nmspace, className, T) \ + typedef nmspace::className<T> temp_init##nmspace##className##T##temp; \ + BASIC_REGISTER_CLASS(nmspace::className<T>) \ + static TRegisterSaveLoadType<nmspace::className<T>> temp_init##nmspace##_##name##T##N(N); + +#define REGISTER_SAVELOAD_CLASS_NAME(N, cls, name) \ + BASIC_REGISTER_CLASS(cls) \ + static TRegisterSaveLoadType<cls> init##name##N(N); + +#define REGISTER_SAVELOAD_CLASS_NS_PREF(N, cls, ns, pref) \ + REGISTER_SAVELOAD_CLASS_NAME(N, ns ::cls, _##pref##_##cls) + +#define SAVELOAD(...) \ + int operator&(IBinSaver& f) { \ + f.AddMulti(__VA_ARGS__); \ + return 0; \ + } + +#define SAVELOAD_OVERRIDE_WITHOUT_BASE(...) \ + int operator&(IBinSaver& f) override { \ + f.AddMulti(__VA_ARGS__); \ + return 0; \ + } + +#define SAVELOAD_OVERRIDE(base, ...) \ + int operator&(IBinSaver& f) override { \ + base::operator&(f); \ + f.AddMulti(__VA_ARGS__); \ + return 0; \ + } + +#define SAVELOAD_BASE(...) 
\
+    int operator&(IBinSaver& f) { \
+        TBase::operator&(f);      \
+        f.AddMulti(__VA_ARGS__);  \
+        return 0;                 \
+    }
diff --git a/library/cpp/binsaver/blob_io.cpp b/library/cpp/binsaver/blob_io.cpp
new file mode 100644
index 0000000000..ff10349e6f
--- /dev/null
+++ b/library/cpp/binsaver/blob_io.cpp
@@ -0,0 +1 @@
+#include "blob_io.h"
diff --git a/library/cpp/binsaver/blob_io.h b/library/cpp/binsaver/blob_io.h
new file mode 100644
index 0000000000..abe518ef30
--- /dev/null
+++ b/library/cpp/binsaver/blob_io.h
@@ -0,0 +1,47 @@
+#pragma once
+
+#include "bin_saver.h"
+#include "buffered_io.h"
+
+#include <util/memory/blob.h>
+
+// Read-only IBinaryStream over the contents of a TBlob.
+// Writing is not supported: WriteImpl() asserts and reports 0 bytes written.
+class TYaBlobStream: public IBinaryStream {
+    TBlob Blob;
+    i64 Pos; // offset of the next unread byte inside Blob
+
+    int WriteImpl(const void*, int) override {
+        Y_ASSERT(0);
+        return 0;
+    }
+    // Copies up to `size` bytes starting at Pos; returns the number of bytes copied,
+    // which is smaller than `size` when the end of the blob is reached.
+    int ReadImpl(void* userBuffer, int size) override {
+        if (size == 0)
+            return 0;
+        i64 res = Min<i64>(Blob.Length() - Pos, size);
+        if (res)
+            memcpy(userBuffer, ((const char*)Blob.Data()) + Pos, res);
+        Pos += res;
+        return res;
+    }
+    bool IsValid() const override {
+        return true;
+    }
+    bool IsFailed() const override {
+        return false;
+    }
+
+public:
+    TYaBlobStream(const TBlob& blob)
+        : Blob(blob)
+        , Pos(0)
+    {
+    }
+};
+
+// Loads `c` from `data`: the saver is constructed in read mode, so despite the
+// name this function only deserializes.
+template <class T>
+inline void SerializeBlob(const TBlob& data, T& c) {
+    TYaBlobStream f(data);
+    {
+        IBinSaver bs(f, true);
+        bs.Add(1, &c);
+    }
+}
diff --git a/library/cpp/binsaver/buffered_io.cpp b/library/cpp/binsaver/buffered_io.cpp
new file mode 100644
index 0000000000..dd88b04bc5
--- /dev/null
+++ b/library/cpp/binsaver/buffered_io.cpp
@@ -0,0 +1,39 @@
+#include "buffered_io.h"
+
+// Writes `size` bytes from userBuffer in pieces that fit into an int, because
+// WriteImpl() takes an int length. Returns the number of bytes actually written.
+i64 IBinaryStream::LongWrite(const void* userBuffer, i64 size) {
+    Y_VERIFY(size >= 0, "IBinaryStream::Write() called with a negative buffer size.");
+
+    const char* src = static_cast<const char*>(userBuffer);
+    i64 leftToWrite = size;
+    while (leftToWrite != 0) {
+        int writeSz = static_cast<int>(Min<i64>(leftToWrite, std::numeric_limits<int>::max()));
+        int written = WriteImpl(src, writeSz);
+        Y_ASSERT(written <= writeSz);
+        leftToWrite -= written;
+        // Assumption: if WriteImpl(buf, writeSz) returns < writeSz, the stream is
+        // full and there's no sense in continuing.
+        if (written < writeSz)
+            break;
+        // Move on to the next piece; without this every iteration would write the
+        // first piece of the buffer again.
+        src += written;
+    }
+    Y_ASSERT(size >= leftToWrite);
+    return size - leftToWrite;
+}
+
+// Reads `size` bytes into userBuffer in pieces that fit into an int, because
+// ReadImpl() takes an int length. On a short read the unread tail of userBuffer is
+// zero-filled. Returns the number of bytes actually read.
+i64 IBinaryStream::LongRead(void* userBuffer, i64 size) {
+    Y_VERIFY(size >= 0, "IBinaryStream::Read() called with a negative buffer size.");
+
+    char* dst = static_cast<char*>(userBuffer);
+    i64 leftToRead = size;
+    while (leftToRead != 0) {
+        int readSz = static_cast<int>(Min<i64>(leftToRead, std::numeric_limits<int>::max()));
+        int read = ReadImpl(dst, readSz);
+        Y_ASSERT(read <= readSz);
+        leftToRead -= read;
+        // Move past the bytes just read so the next piece does not overwrite them;
+        // dst now equals userBuffer + (size - leftToRead).
+        dst += read;
+        // Assumption: if ReadImpl(buf, readSz) returns < readSz, the stream is
+        // exhausted and there's no sense in continuing.
+        if (read < readSz) {
+            memset(dst, 0, leftToRead);
+            break;
+        }
+    }
+    Y_ASSERT(size >= leftToRead);
+    return size - leftToRead;
+}
diff --git a/library/cpp/binsaver/buffered_io.h b/library/cpp/binsaver/buffered_io.h
new file mode 100644
index 0000000000..75465c9c5c
--- /dev/null
+++ b/library/cpp/binsaver/buffered_io.h
@@ -0,0 +1,134 @@
+#pragma once
+
+#include <util/system/yassert.h>
+#include <util/generic/utility.h>
+#include <util/generic/ylimits.h>
+#include <string.h>
+
+// Abstract byte stream. Implementations provide WriteImpl()/ReadImpl() with int
+// lengths; Write()/Read() accept 64-bit lengths and split oversized requests.
+struct IBinaryStream {
+    virtual ~IBinaryStream() = default;
+    ;
+
+    // Returns the number of bytes written.
+    // NOTE(review): a negative size is forwarded to WriteImpl() here; the negative-size
+    // check in LongWrite() is only reached for sizes above Max<int>().
+    inline i64 Write(const void* userBuffer, i64 size) {
+        if (size <= Max<int>()) {
+            return WriteImpl(userBuffer, static_cast<int>(size));
+        } else {
+            return LongWrite(userBuffer, size);
+        }
+    }
+
+    // Returns the number of bytes read.
+    inline i64 Read(void* userBuffer, i64 size) {
+        if (size <= Max<int>()) {
+            return ReadImpl(userBuffer, static_cast<int>(size));
+        } else {
+            return LongRead(userBuffer, size);
+        }
+    }
+
+    virtual bool IsValid() const = 0;
+    virtual bool IsFailed() const = 0;
+
+private:
+    virtual int WriteImpl(const void* userBuffer, int size) = 0;
+    virtual int ReadImpl(void* userBuffer, int size) = 0;
+
+    i64 LongRead(void* userBuffer, i64 size);
+    i64 LongWrite(const void* userBuffer, i64 size);
+};
+
+// Fixed-size buffer of N_SIZE bytes in front of an IBinaryStream.
+// The direction (reading or writing) is chosen at construction time.
+template <int N_SIZE = 16384>
+class TBufferedStream {
+    char Buf[N_SIZE];
+    i64 Pos, BufSize; // Pos: cursor inside Buf; BufSize: number of valid bytes when reading
+    IBinaryStream& Stream;
+    bool bIsReading, bIsEof, bFailed;
+
+    // Slow path of Read(): the request cannot be served from the buffered bytes alone.
+    // Once EOF has been seen, requests are answered with zero bytes.
+    void ReadComplex(void* userBuffer, i64 size) {
+        if (bIsEof) {
+            memset(userBuffer, 0, size);
+            return;
+        }
+        char* dst = (char*)userBuffer;
+        // hand out whatever is still buffered first
+        i64 leftBytes = BufSize - Pos;
+        memcpy(dst, Buf + Pos, leftBytes);
+        dst += leftBytes;
+        size -= leftBytes;
+        Pos = BufSize = 0;
+        if (size > N_SIZE) {
+            // larger than the buffer: read straight into the caller's memory
+            i64 n = Stream.Read(dst, size);
+            bFailed = Stream.IsFailed();
+            if (n != size) {
+                bIsEof = true;
+                memset(dst + n, 0, size - n);
+            }
+        } else {
+            // refill the buffer and serve the remainder through the normal path
+            BufSize = Stream.Read(Buf, N_SIZE);
+            bFailed = Stream.IsFailed();
+            if (BufSize == 0)
+                bIsEof = true;
+            Read(dst, size);
+        }
+    }
+
+    // Slow path of Write(): flushes the buffer, then either writes a large request
+    // directly to the stream or buffers a small one.
+    void WriteComplex(const void* userBuffer, i64 size) {
+        Flush();
+        if (size >= N_SIZE) {
+            Stream.Write(userBuffer, size);
+            bFailed = Stream.IsFailed();
+        } else
+            Write(userBuffer, size);
+    }
+
+    // private no-op assignment; presumably meant to forbid assignment from outside the class
+    void operator=(const TBufferedStream&) {
+    }
+
+public:
+    TBufferedStream(bool bRead, IBinaryStream& stream)
+        : Pos(0)
+        , BufSize(0)
+        , Stream(stream)
+        , bIsReading(bRead)
+        , bIsEof(false)
+        , bFailed(false)
+    {
+    }
+    // In writing mode the bytes still buffered are flushed on destruction.
+    ~TBufferedStream() {
+        if (!bIsReading)
+            Flush();
+    }
+    // Writes the buffered bytes to the underlying stream; does nothing in reading mode.
+    void Flush() {
+        Y_ASSERT(!bIsReading);
+        if (bIsReading)
+            return;
+        Stream.Write(Buf, Pos);
+        bFailed = Stream.IsFailed();
+        Pos = 0;
+    }
+    bool IsEof() const {
+        return bIsEof;
+    }
+    inline void Read(void* userBuffer, i64 size) {
+        Y_ASSERT(bIsReading);
+        // fast path: the whole request is already buffered
+        if (!bIsEof && size + Pos <= BufSize) {
+            memcpy(userBuffer, Buf + Pos, size);
+            Pos += size;
+            return;
+        }
+        ReadComplex(userBuffer, size);
+    }
+    inline void Write(const void* userBuffer, i64 size) {
+        Y_ASSERT(!bIsReading);
+        // fast path: the request fits into the free part of the buffer
+        if (Pos + size < N_SIZE) {
+            memcpy(Buf + Pos, userBuffer, size);
+            Pos += size;
+            return;
+        }
+        WriteComplex(userBuffer, size);
+    }
+    bool IsValid() const {
+        return Stream.IsValid();
+    }
+    // Failure state reported by the underlying stream after the most recent stream access.
+    bool IsFailed() const {
+        return bFailed;
+    }
+};
diff --git 
a/library/cpp/binsaver/class_factory.h b/library/cpp/binsaver/class_factory.h
new file mode 100644
index 0000000000..e83512331b
--- /dev/null
+++ b/library/cpp/binsaver/class_factory.h
@@ -0,0 +1,105 @@
+#pragma once
+
+#include <cstdio>
+#include <cstdlib>
+#include <typeinfo>
+#include <util/generic/hash.h>
+#include <util/generic/vector.h>
+#include <util/ysafeptr.h>
+
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// factory is using RTTI
+// objects should inherit T and T must have at least 1 virtual function
+//
+// Keeps two maps: type id -> creation function, and std::type_info -> type id.
+template <class T>
+class TClassFactory {
+public:
+    typedef const std::type_info* VFT;
+
+private:
+    typedef T* (*newFunc)();
+    typedef THashMap<int, newFunc> CTypeNewHash; // typeID->newFunc()
+    typedef THashMap<VFT, int> CTypeIndexHash;   // vftable->typeID
+
+    CTypeIndexHash typeIndex;
+    CTypeNewHash typeInfo;
+
+    void RegisterTypeBase(int nTypeID, newFunc func, VFT vft);
+    static VFT GetObjectType(T* pObject) {
+        return &typeid(*pObject);
+    }
+    // Returns the type id registered for `t`, or -1 when the type is unknown.
+    // Lookup is by type_info address first; on a miss the registered entries are
+    // compared by value and a match is cached under the new address.
+    int VFT2TypeID(VFT t) {
+        CTypeIndexHash::iterator i = typeIndex.find(t);
+        if (i != typeIndex.end())
+            return i->second;
+        for (i = typeIndex.begin(); i != typeIndex.end(); ++i) {
+            if (*i->first == *t) {
+                // Copy the id before inserting: the insertion modifies the map that
+                // `i` points into, so `i` must not be read afterwards.
+                const int typeId = i->second;
+                typeIndex[t] = typeId;
+                return typeId;
+            }
+        }
+        return -1;
+    }
+
+public:
+    // Registers TT under nTypeID; the pointer argument only carries the type.
+    template <class TT>
+    void RegisterType(int nTypeID, newFunc func, TT*) {
+        RegisterTypeBase(nTypeID, func, &typeid(TT));
+    }
+    // Registers the dynamic type of the object produced by func() under nTypeID.
+    void RegisterTypeSafe(int nTypeID, newFunc func) {
+        TPtr<T> pObj = func();
+        VFT vft = GetObjectType(pObj);
+        RegisterTypeBase(nTypeID, func, vft);
+    }
+    // Creates an object of the class registered under nTypeID.
+    // Returns nullptr for an unknown id. The lookup uses find() so that an unknown id
+    // does not add a null entry to the registry.
+    T* CreateObject(int nTypeID) {
+        typename CTypeNewHash::const_iterator found = typeInfo.find(nTypeID);
+        if (found == typeInfo.end() || !found->second)
+            return nullptr;
+        return found->second();
+    }
+    // Returns the type id registered for the dynamic type of pObject, or -1.
+    int GetObjectTypeID(T* pObject) {
+        return VFT2TypeID(GetObjectType(pObject));
+    }
+    // Returns the type id registered for the static type TT, or -1.
+    template <class TT>
+    int GetTypeID(TT* p = 0) {
+        (void)p;
+        return VFT2TypeID(&typeid(TT));
+    }
+
+    // Replaces the contents of typeIds with all registered type ids (unordered).
+    void GetAllTypeIDs(TVector<int>& typeIds) const {
+        typeIds.clear();
+        for (typename CTypeNewHash::const_iterator iter = typeInfo.begin();
+             iter != typeInfo.end();
+             ++iter) {
+            typeIds.push_back(iter->first);
+        }
+    }
+};
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// Records nTypeID <-> vft in both maps.
+// Aborts when nTypeID is already taken by a different class, or when the class is
+// already registered under a different id. Registering the same class again under
+// the same id is allowed.
+template <class T>
+void TClassFactory<T>::RegisterTypeBase(int nTypeID, newFunc func, VFT vft) {
+    typename CTypeNewHash::iterator registered = typeInfo.find(nTypeID);
+    // a null stored function cannot be called for the comparison below, so skip it
+    if (registered != typeInfo.end() && registered->second) {
+        TObj<IObjectBase> o1 = registered->second();
+        TObj<IObjectBase> o2 = func();
+
+        // bind to references first: typeid(*ptr) directly triggers a clang warning
+        auto& o1v = *o1;
+        auto& o2v = *o2;
+
+        if (typeid(o1v) != typeid(o2v)) {
+            fprintf(stderr, "IBinSaver: Type ID 0x%08X has been already used\n", nTypeID);
+            abort();
+        }
+    }
+
+    CTypeIndexHash::iterator typeIndexIt = typeIndex.find(vft);
+    if (typeIndexIt != typeIndex.end() && nTypeID != typeIndexIt->second) {
+        fprintf(stderr, "IBinSaver: class (Type ID 0x%08X) has been already registered (Type ID 0x%08X)\n", nTypeID, typeIndexIt->second);
+        abort();
+    }
+    typeIndex[vft] = nTypeID;
+    typeInfo[nTypeID] = func;
+}
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// macro for registering CFundament derivatives
+#define REGISTER_CLASS(factory, N, name) factory.RegisterType(N, name::New##name, (name*)0);
+#define REGISTER_TEMPL_CLASS(factory, N, name, className) factory.RegisterType(N, name::New##className, (name*)0);
+#define REGISTER_CLASS_NM(factory, N, name, nmspace) factory.RegisterType(N, nmspace::name::New##name, (nmspace::name*)0);
diff --git a/library/cpp/binsaver/mem_io.cpp b/library/cpp/binsaver/mem_io.cpp
new file mode 100644
index 0000000000..82316606b6
--- /dev/null
+++ b/library/cpp/binsaver/mem_io.cpp
@@ -0,0 +1 @@
+#include "mem_io.h"
diff --git a/library/cpp/binsaver/mem_io.h b/library/cpp/binsaver/mem_io.h
new file mode 100644
index 0000000000..2a9e36fe68
--- /dev/null
+++ b/library/cpp/binsaver/mem_io.h
@@ -0,0 +1,212 @@
+#pragma once
+
+#include "bin_saver.h"
+
+namespace NMemIoInternals {
+    class TMemoryStream: public IBinaryStream {
+
TVector<char>& Data; + ui64 Pos; + + public: + TMemoryStream(TVector<char>* data, ui64 pos = 0) + : Data(*data) + , Pos(pos) + { + } + ~TMemoryStream() override { + } // keep gcc happy + + bool IsValid() const override { + return true; + } + bool IsFailed() const override { + return false; + } + + private: + int WriteImpl(const void* userBuffer, int size) override { + if (size == 0) + return 0; + Y_ASSERT(size > 0); + if (Pos + size > Data.size()) + Data.yresize(Pos + size); + memcpy(&Data[Pos], userBuffer, size); + Pos += size; + return size; + } + int ReadImpl(void* userBuffer, int size) override { + if (size == 0) + return 0; + Y_ASSERT(size > 0); + int res = Min(Data.size() - Pos, (ui64)size); + if (res) + memcpy(userBuffer, &Data[Pos], res); + Pos += res; + return res; + } + }; + + template <class T> + inline void SerializeMem(bool bRead, TVector<char>* data, T& c, bool stableOutput = false) { + if (IBinSaver::HasNonTrivialSerializer<T>(0u)) { + TMemoryStream f(data); + { + IBinSaver bs(f, bRead, stableOutput); + bs.Add(1, &c); + } + } else { + if (bRead) { + Y_ASSERT(data->size() == sizeof(T)); + c = *reinterpret_cast<T*>(&(*data)[0]); + } else { + data->yresize(sizeof(T)); + *reinterpret_cast<T*>(&(*data)[0]) = c; + } + } + } + + //////////////////////////////////////////////////////////////////////////// + class THugeMemoryStream: public IBinaryStream { + TVector<TVector<char>>& Data; + i64 Block, Pos; + bool ShrinkOnRead; + + enum { + MAX_BLOCK_SIZE = 1024 * 1024 // Aligned with cache size + }; + + public: + THugeMemoryStream(TVector<TVector<char>>* data, bool shrinkOnRead = false) + : Data(*data) + , Block(0) + , Pos(0) + , ShrinkOnRead(shrinkOnRead) + { + Y_ASSERT(!data->empty()); + } + + ~THugeMemoryStream() override { + } // keep gcc happy + + bool IsValid() const override { + return true; + } + bool IsFailed() const override { + return false; + } + + private: + int WriteImpl(const void* userDataArg, int sizeArg) override { + if (sizeArg == 0) + return 
0; + const char* userData = (const char*)userDataArg; + i64 size = sizeArg; + i64 newSize = Pos + size; + if (newSize > Data[Block].ysize()) { + while (newSize > MAX_BLOCK_SIZE) { + int maxWrite = MAX_BLOCK_SIZE - Pos; + Data[Block].yresize(MAX_BLOCK_SIZE); + if (maxWrite) { + memcpy(&Data[Block][Pos], userData, maxWrite); + userData += maxWrite; + size -= maxWrite; + } + ++Block; + Pos = 0; + Data.resize(Block + 1); + newSize = Pos + size; + } + Data[Block].yresize(newSize); + } + if (size) { + memcpy(&Data[Block][Pos], userData, size); + } + Pos += size; + return sizeArg; + } + int ReadImpl(void* userDataArg, int sizeArg) override { + if (sizeArg == 0) + return 0; + + char* userData = (char*)userDataArg; + i64 size = sizeArg; + i64 rv = 0; + while (size > 0) { + int curBlockSize = Data[Block].ysize(); + int maxRead = 0; + if (Pos + size > curBlockSize) { + maxRead = curBlockSize - Pos; + if (maxRead) { + memcpy(userData, &Data[Block][Pos], maxRead); + userData += maxRead; + size -= maxRead; + rv += maxRead; + } + if (Block + 1 == Data.ysize()) { + memset(userData, 0, size); + return rv; + } + if (ShrinkOnRead) { + TVector<char>().swap(Data[Block]); + } + ++Block; + Pos = 0; + } else { + memcpy(userData, &Data[Block][Pos], size); + Pos += size; + rv += size; + return rv; + } + } + return rv; + } + }; + + template <class T> + inline void SerializeMem(bool bRead, TVector<TVector<char>>* data, T& c, bool stableOutput = false) { + if (data->empty()) { + data->resize(1); + } + THugeMemoryStream f(data); + { + IBinSaver bs(f, bRead, stableOutput); + bs.Add(1, &c); + } + } +} + +template <class T> +inline void SerializeMem(const TVector<char>& data, T& c) { + if (IBinSaver::HasNonTrivialSerializer<T>(0u)) { + TVector<char> tmp(data); + SerializeFromMem(&tmp, c); + } else { + Y_ASSERT(data.size() == sizeof(T)); + c = *reinterpret_cast<const T*>(&data[0]); + } +} + +template <class T, class D> +inline void SerializeToMem(D* data, T& c, bool stableOutput = false) { + 
NMemIoInternals::SerializeMem(false, data, c, stableOutput); +} + +template <class T, class D> +inline void SerializeFromMem(D* data, T& c, bool stableOutput = false) { + NMemIoInternals::SerializeMem(true, data, c, stableOutput); +} + +// Frees memory in (*data)[i] immediately upon its deserialization, thus keeps low overall memory consumption for data + object. +template <class T> +inline void SerializeFromMemShrinkInput(TVector<TVector<char>>* data, T& c) { + if (data->empty()) { + data->resize(1); + } + NMemIoInternals::THugeMemoryStream f(data, true); + { + IBinSaver bs(f, true, false); + bs.Add(1, &c); + } + data->resize(0); + data->shrink_to_fit(); +} diff --git a/library/cpp/binsaver/ut/binsaver_ut.cpp b/library/cpp/binsaver/ut/binsaver_ut.cpp new file mode 100644 index 0000000000..37eba5406f --- /dev/null +++ b/library/cpp/binsaver/ut/binsaver_ut.cpp @@ -0,0 +1,198 @@ +#include <library/cpp/binsaver/util_stream_io.h> +#include <library/cpp/binsaver/mem_io.h> +#include <library/cpp/binsaver/bin_saver.h> +#include <library/cpp/binsaver/ut_util/ut_util.h> +#include <library/cpp/testing/unittest/registar.h> + +#include <util/stream/buffer.h> +#include <util/generic/map.h> + +struct TBinarySerializable { + ui32 Data = 0; +}; + +struct TNonBinarySerializable { + ui32 Data = 0; + TString StrData; +}; + +struct TCustomSerializer { + ui32 Data = 0; + TString StrData; + SAVELOAD(StrData, Data); +}; + +struct TCustomOuterSerializer { + ui32 Data = 0; + TString StrData; +}; + +void operator&(TCustomOuterSerializer& s, IBinSaver& f); + +struct TCustomOuterSerializerTmpl { + ui32 Data = 0; + TString StrData; +}; + +struct TCustomOuterSerializerTmplDerived: public TCustomOuterSerializerTmpl { + ui32 Data = 0; + TString StrData; +}; + +struct TMoveOnlyType { + ui32 Data = 0; + + TMoveOnlyType() = default; + TMoveOnlyType(TMoveOnlyType&&) = default; + + bool operator==(const TMoveOnlyType& obj) const { + return Data == obj.Data; + } +}; + +struct TTypeWithArray { + ui32 
Data = 1; + TString Array[2][2]{{"test", "data"}, {"and", "more"}}; + + SAVELOAD(Data, Array); + bool operator==(const TTypeWithArray& obj) const { + return Data == obj.Data && std::equal(std::begin(Array[0]), std::end(Array[0]), obj.Array[0]) && std::equal(std::begin(Array[1]), std::end(Array[1]), obj.Array[1]); + } +}; + +template <typename T, typename = std::enable_if_t<std::is_base_of<TCustomOuterSerializerTmpl, T>::value>> +int operator&(T& s, IBinSaver& f); + +static bool operator==(const TBlob& l, const TBlob& r) { + return TStringBuf(l.AsCharPtr(), l.Size()) == TStringBuf(r.AsCharPtr(), r.Size()); +} + +Y_UNIT_TEST_SUITE(BinSaver){ + Y_UNIT_TEST(HasTrivialSerializer){ + UNIT_ASSERT(!IBinSaver::HasNonTrivialSerializer<TBinarySerializable>(0u)); +UNIT_ASSERT(!IBinSaver::HasNonTrivialSerializer<TNonBinarySerializable>(0u)); +UNIT_ASSERT(IBinSaver::HasNonTrivialSerializer<TCustomSerializer>(0u)); +UNIT_ASSERT(IBinSaver::HasNonTrivialSerializer<TCustomOuterSerializer>(0u)); +UNIT_ASSERT(IBinSaver::HasNonTrivialSerializer<TCustomOuterSerializerTmpl>(0u)); +UNIT_ASSERT(IBinSaver::HasNonTrivialSerializer<TCustomOuterSerializerTmplDerived>(0u)); +UNIT_ASSERT(IBinSaver::HasNonTrivialSerializer<TVector<TCustomSerializer>>(0u)); +} + + +Y_UNIT_TEST(TestStroka) { + TestBinSaverSerialization(TString("QWERTY")); +} + +Y_UNIT_TEST(TestMoveOnlyType) { + TestBinSaverSerializationToBuffer(TMoveOnlyType()); +} + +Y_UNIT_TEST(TestVectorStrok) { + TestBinSaverSerialization(TVector<TString>{"A", "B", "C"}); +} + +Y_UNIT_TEST(TestCArray) { + TestBinSaverSerialization(TTypeWithArray()); +} + +Y_UNIT_TEST(TestSets) { + TestBinSaverSerialization(THashSet<TString>{"A", "B", "C"}); + TestBinSaverSerialization(TSet<TString>{"A", "B", "C"}); +} + +Y_UNIT_TEST(TestMaps) { + TestBinSaverSerialization(THashMap<TString, ui32>{{"A", 1}, {"B", 2}, {"C", 3}}); + TestBinSaverSerialization(TMap<TString, ui32>{{"A", 1}, {"B", 2}, {"C", 3}}); +} + +Y_UNIT_TEST(TestBlob) { + 
TestBinSaverSerialization(TBlob::FromStringSingleThreaded("qwerty")); +} + +Y_UNIT_TEST(TestVariant) { + { + using T = std::variant<TString, int>; + + TestBinSaverSerialization(T(TString(""))); + TestBinSaverSerialization(T(0)); + } + { + using T = std::variant<TString, int, float>; + + TestBinSaverSerialization(T(TString("ask"))); + TestBinSaverSerialization(T(12)); + TestBinSaverSerialization(T(0.64f)); + } +} + +Y_UNIT_TEST(TestPod) { + struct TPod { + ui32 A = 5; + ui64 B = 7; + bool operator==(const TPod& other) const { + return A == other.A && B == other.B; + } + }; + TestBinSaverSerialization(TPod()); + TPod custom; + custom.A = 25; + custom.B = 37; + TestBinSaverSerialization(custom); + TestBinSaverSerialization(TVector<TPod>{custom}); +} + +Y_UNIT_TEST(TestSubPod) { + struct TPod { + struct TSub { + ui32 X = 10; + bool operator==(const TSub& other) const { + return X == other.X; + } + }; + TVector<TSub> B; + int operator&(IBinSaver& f) { + f.Add(0, &B); + return 0; + } + bool operator==(const TPod& other) const { + return B == other.B; + } + }; + TestBinSaverSerialization(TPod()); + TPod::TSub sub; + sub.X = 1; + TPod custom; + custom.B = {sub}; + TestBinSaverSerialization(TVector<TPod>{custom}); +} + +Y_UNIT_TEST(TestMemberAndOpIsMain) { + struct TBase { + TString S; + virtual int operator&(IBinSaver& f) { + f.Add(0, &S); + return 0; + } + virtual ~TBase() = default; + }; + + struct TDerived: public TBase { + int A = 0; + int operator&(IBinSaver& f)override { + f.Add(0, static_cast<TBase*>(this)); + f.Add(0, &A); + return 0; + } + bool operator==(const TDerived& other) const { + return A == other.A && S == other.S; + } + }; + + TDerived obj; + obj.S = "TString"; + obj.A = 42; + + TestBinSaverSerialization(obj); +} +} +; diff --git a/library/cpp/binsaver/ut/ya.make b/library/cpp/binsaver/ut/ya.make new file mode 100644 index 0000000000..43dc20bff7 --- /dev/null +++ b/library/cpp/binsaver/ut/ya.make @@ -0,0 +1,11 @@ +UNITTEST_FOR(library/cpp/binsaver) + 
+OWNER(gulin) + +SRCS( + binsaver_ut.cpp +) + +PEERDIR(library/cpp/binsaver/ut_util) + +END() diff --git a/library/cpp/binsaver/ut_util/README.md b/library/cpp/binsaver/ut_util/README.md new file mode 100644 index 0000000000..41641cd1e1 --- /dev/null +++ b/library/cpp/binsaver/ut_util/README.md @@ -0,0 +1 @@ +Library for testing BinSaver serialization.
\ No newline at end of file diff --git a/library/cpp/binsaver/ut_util/ut_util.cpp b/library/cpp/binsaver/ut_util/ut_util.cpp new file mode 100644 index 0000000000..4cd8daa931 --- /dev/null +++ b/library/cpp/binsaver/ut_util/ut_util.cpp @@ -0,0 +1 @@ +#include "ut_util.h" diff --git a/library/cpp/binsaver/ut_util/ut_util.h b/library/cpp/binsaver/ut_util/ut_util.h new file mode 100644 index 0000000000..52e7bcf8e1 --- /dev/null +++ b/library/cpp/binsaver/ut_util/ut_util.h @@ -0,0 +1,71 @@ +#pragma once + +#include <library/cpp/binsaver/bin_saver.h> +#include <library/cpp/binsaver/mem_io.h> +#include <library/cpp/binsaver/util_stream_io.h> + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/vector.h> +#include <util/stream/buffer.h> + +#include <functional> + + +/* comparerChecksInside == true means comparer uses UNIT_ASSERT... inside itself + * comparerChecksInside == false means comparer returns true if its arguments are equal + */ + +template <class T, class TComparer = std::equal_to<T>, bool comparerChecksInside = false> +void UnitTestCheckWithComparer(const T& lhs, const T& rhs, const TComparer& comparer) { + if constexpr (comparerChecksInside) { + comparer(lhs, rhs); + } else { + UNIT_ASSERT(comparer(lhs, rhs)); + } +} + + +/* comparerChecksInside == true means comparer uses UNIT_ASSERT... 
inside itself + * comparerChecksInside == false means comparer returns true if its arguments are equal + */ + +template <typename T, typename TComparer = std::equal_to<T>, bool comparerChecksInside = false> +void TestBinSaverSerializationToBuffer(const T& original, const TComparer& comparer = TComparer()) { + TBufferOutput out; + { + TYaStreamOutput yaOut(out); + + IBinSaver f(yaOut, false, false); + f.Add(0, const_cast<T*>(&original)); + } + TBufferInput in(out.Buffer()); + T restored; + { + TYaStreamInput yaIn(in); + IBinSaver f(yaIn, true, false); + f.Add(0, &restored); + } + UnitTestCheckWithComparer<T, TComparer, comparerChecksInside>(original, restored, comparer); +} + +template <typename T, typename TComparer = std::equal_to<T>, bool comparerChecksInside = false> +void TestBinSaverSerializationToVector(const T& original, const TComparer& comparer = TComparer()) { + TVector<char> out; + SerializeToMem(&out, *const_cast<T*>(&original)); + T restored; + SerializeFromMem(&out, restored); + UnitTestCheckWithComparer<T, TComparer, comparerChecksInside>(original, restored, comparer); + + TVector<TVector<char>> out2D; + SerializeToMem(&out2D, *const_cast<T*>(&original)); + T restored2D; + SerializeFromMem(&out2D, restored2D); + UnitTestCheckWithComparer<T, TComparer, comparerChecksInside>(original, restored2D, comparer); +} + +template <typename T, typename TComparer = std::equal_to<T>, bool comparerChecksInside = false> +void TestBinSaverSerialization(const T& original, const TComparer& comparer = TComparer()) { + TestBinSaverSerializationToBuffer<T, TComparer, comparerChecksInside>(original, comparer); + TestBinSaverSerializationToVector<T, TComparer, comparerChecksInside>(original, comparer); +} diff --git a/library/cpp/binsaver/ut_util/ya.make b/library/cpp/binsaver/ut_util/ya.make new file mode 100644 index 0000000000..7e60f13ef3 --- /dev/null +++ b/library/cpp/binsaver/ut_util/ya.make @@ -0,0 +1,14 @@ +LIBRARY() + +OWNER(gulin) + +SRCS( + ut_util.cpp +) + 
+PEERDIR( + library/cpp/binsaver + library/cpp/testing/unittest +) + +END() diff --git a/library/cpp/binsaver/util_stream_io.cpp b/library/cpp/binsaver/util_stream_io.cpp new file mode 100644 index 0000000000..a2a79a2fe7 --- /dev/null +++ b/library/cpp/binsaver/util_stream_io.cpp @@ -0,0 +1 @@ +#include "util_stream_io.h" diff --git a/library/cpp/binsaver/util_stream_io.h b/library/cpp/binsaver/util_stream_io.h new file mode 100644 index 0000000000..d65d630b93 --- /dev/null +++ b/library/cpp/binsaver/util_stream_io.h @@ -0,0 +1,86 @@ +#pragma once + +#include "bin_saver.h" + +#include <util/stream/input.h> +#include <util/stream/output.h> +#include <util/stream/file.h> + +class TYaStreamInput: public IBinaryStream { + IInputStream& Stream; + + int WriteImpl(const void*, int) override { + Y_ASSERT(0); + return 0; + } + int ReadImpl(void* userBuffer, int size) override { + return (int)Stream.Read(userBuffer, (size_t)size); + } + bool IsValid() const override { + return true; + } + bool IsFailed() const override { + return false; + } + +public: + TYaStreamInput(IInputStream& stream) + : Stream(stream) + { + } +}; + +template <class T> +inline void SerializeFromStream(IInputStream& stream, T& c) { + TYaStreamInput f(stream); + { + IBinSaver bs(f, true); + bs.Add(1, &c); + } +} + +template <class T> +inline void SerializeFromFile(const TString& fileName, T& c) { + TIFStream in(fileName); + SerializeFromStream(in, c); +} + +class TYaStreamOutput: public IBinaryStream { + IOutputStream& Stream; + + int WriteImpl(const void* what, int size) override { + Stream.Write(what, (size_t)size); + return size; + } + int ReadImpl(void*, int) override { + Y_ASSERT(0); + return 0; + } + bool IsValid() const override { + return true; + } + bool IsFailed() const override { + return false; + } + +public: + TYaStreamOutput(IOutputStream& stream) + : Stream(stream) + { + } +}; + +template <class T> +inline void SerializeToArcadiaStream(IOutputStream& stream, T& c) { + TYaStreamOutput 
f(stream); + { + IBinSaver bs(f, false); + bs.Add(1, &c); + } +} + +template <class T> +inline void SerializeToFile(const TString& fileName, T& c) { + TOFStream out(fileName); + SerializeToArcadiaStream(out, c); +} diff --git a/library/cpp/binsaver/ya.make b/library/cpp/binsaver/ya.make new file mode 100644 index 0000000000..9693c54639 --- /dev/null +++ b/library/cpp/binsaver/ya.make @@ -0,0 +1,18 @@ +LIBRARY() + +OWNER(gulin) + +SRCS( + class_factory.h + bin_saver.cpp + blob_io.cpp + buffered_io.cpp + mem_io.cpp + util_stream_io.cpp +) + +PEERDIR( + library/cpp/containers/2d_array +) + +END() |