diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/codecs/zstd_dict_codec.cpp | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/codecs/zstd_dict_codec.cpp')
-rw-r--r-- | library/cpp/codecs/zstd_dict_codec.cpp | 281 |
1 files changed, 281 insertions, 0 deletions
diff --git a/library/cpp/codecs/zstd_dict_codec.cpp b/library/cpp/codecs/zstd_dict_codec.cpp new file mode 100644 index 0000000000..c42a2879e6 --- /dev/null +++ b/library/cpp/codecs/zstd_dict_codec.cpp @@ -0,0 +1,281 @@ +#include "zstd_dict_codec.h" + +#include <library/cpp/packers/packers.h> + +#include <util/generic/ptr.h> +#include <util/generic/refcount.h> +#include <util/generic/noncopyable.h> +#include <util/string/builder.h> +#include <util/system/src_location.h> +#include <util/ysaveload.h> + +#define ZDICT_STATIC_LINKING_ONLY + +#include <contrib/libs/zstd/include/zdict.h> +#include <contrib/libs/zstd/include/zstd.h> +#include <contrib/libs/zstd/include/zstd_errors.h> + +// See IGNIETFERRO-320 for possible bugs + +namespace NCodecs { + class TZStdDictCodec::TImpl: public TAtomicRefCount<TZStdDictCodec::TImpl> { + template <class T, size_t Deleter(T*)> + class TPtrHolder : TMoveOnly { + T* Ptr = nullptr; + + public: + TPtrHolder() = default; + + TPtrHolder(T* dict) + : Ptr(dict) + { + } + + T* Get() { + return Ptr; + } + + const T* Get() const { + return Ptr; + } + + void Reset(T* dict) { + Dispose(); + Ptr = dict; + } + + void Dispose() { + if (Ptr) { + Deleter(Ptr); + Ptr = nullptr; + } + } + + ~TPtrHolder() { + Dispose(); + } + }; + + using TCDict = TPtrHolder<ZSTD_CDict, ZSTD_freeCDict>; + using TDDict = TPtrHolder<ZSTD_DDict, ZSTD_freeDDict>; + using TCCtx = TPtrHolder<ZSTD_CCtx, ZSTD_freeCCtx>; + using TDCtx = TPtrHolder<ZSTD_DCtx, ZSTD_freeDCtx>; + + using TSizePacker = NPackers::TPacker<ui64>; + + public: + static const ui32 SampleSize = (1 << 22) * 5; + + explicit TImpl(ui32 comprLevel) + : CompressionLevel(comprLevel) + { + const size_t zeroSz = TSizePacker().MeasureLeaf(0); + Zero.Resize(zeroSz); + TSizePacker().PackLeaf(Zero.data(), 0, zeroSz); + } + + ui32 GetCompressionLevel() const { + return CompressionLevel; + } + + ui8 Encode(TStringBuf in, TBuffer& outbuf) const { + outbuf.Clear(); + + if (in.empty()) { + return 0; + } + + TSizePacker packer; + + const char* rawBeg = in.data(); + const size_t rawSz = in.size(); + + const size_t szSz = packer.MeasureLeaf(rawSz); + const size_t maxDatSz = ZSTD_compressBound(rawSz); + + outbuf.Resize(szSz + maxDatSz); + packer.PackLeaf(outbuf.data(), rawSz, szSz); + + TCCtx ctx{CheckPtr(ZSTD_createCCtx(), __LOCATION__)}; + const size_t resSz = CheckSize(ZSTD_compress_usingCDict( + ctx.Get(), outbuf.data() + szSz, maxDatSz, rawBeg, rawSz, CDict.Get()), + __LOCATION__); + + if (resSz < rawSz) { + outbuf.Resize(resSz + szSz); + } else { + outbuf.Resize(Zero.size() + rawSz); + memcpy(outbuf.data(), Zero.data(), Zero.size()); + memcpy(outbuf.data() + Zero.size(), rawBeg, rawSz); + } + return 0; + } + + void Decode(TStringBuf in, TBuffer& outbuf) const { + outbuf.Clear(); + + if (in.empty()) { + return; + } + + TSizePacker packer; + + const char* rawBeg = in.data(); + size_t rawSz = in.size(); + + const size_t szSz = packer.SkipLeaf(rawBeg); + ui64 datSz = 0; + packer.UnpackLeaf(rawBeg, datSz); + + rawBeg += szSz; + rawSz -= szSz; + + if (!datSz) { + outbuf.Resize(rawSz); + memcpy(outbuf.data(), rawBeg, rawSz); + } else { + // size_t zSz = ZSTD_getDecompressedSize(rawBeg, rawSz); + // Y_ENSURE_EX(datSz == zSz, TCodecException() << datSz << " != " << zSz); + outbuf.Resize(datSz); + TDCtx ctx{CheckPtr(ZSTD_createDCtx(), __LOCATION__)}; + CheckSize(ZSTD_decompress_usingDDict( + ctx.Get(), outbuf.data(), outbuf.size(), rawBeg, rawSz, DDict.Get()), + __LOCATION__); + outbuf.Resize(datSz); + } + } + + bool Learn(ISequenceReader& in, bool throwOnError) { + TBuffer data; + TVector<size_t> lens; + + data.Reserve(2 * SampleSize); + TStringBuf r; + while (in.NextRegion(r)) { + if (!r) { + continue; + } + data.Append(r.data(), r.size()); + lens.push_back(r.size()); + } + + ZDICT_legacy_params_t params; + memset(¶ms, 0, sizeof(params)); + params.zParams.compressionLevel = 1; + params.zParams.notificationLevel = 1; + Dict.Resize(Max<size_t>(1 << 20, data.Size() + 16 * lens.size())); + + if (!lens) { + Dict.Reset(); + } else { + size_t trainResult = ZDICT_trainFromBuffer_legacy( + Dict.data(), Dict.size(), data.Data(), const_cast<const size_t*>(&lens[0]), lens.size(), params); + if (ZSTD_isError(trainResult)) { + if (!throwOnError) { + return false; + } + CheckSize(trainResult, __LOCATION__); + } + Dict.Resize(trainResult); + Dict.ShrinkToFit(); + } + InitContexts(); + return true; + } + + void Save(IOutputStream* out) const { + ::Save(out, Dict); + } + + void Load(IInputStream* in) { + ::Load(in, Dict); + InitContexts(); + } + + void InitContexts() { + CDict.Reset(CheckPtr(ZSTD_createCDict(Dict.data(), Dict.size(), CompressionLevel), __LOCATION__)); + DDict.Reset(CheckPtr(ZSTD_createDDict(Dict.data(), Dict.size()), __LOCATION__)); + } + + static size_t CheckSize(size_t sz, TSourceLocation loc) { + if (ZSTD_isError(sz)) { + ythrow TCodecException() << loc << " " << ZSTD_getErrorName(sz) << " (code " << (int)ZSTD_getErrorCode(sz) << ")"; + } + return sz; + } + + template <class T> + static T* CheckPtr(T* t, TSourceLocation loc) { + Y_ENSURE_EX(t, TCodecException() << loc << " " + << "unexpected nullptr"); + return t; + } + + private: + ui32 CompressionLevel = 1; + + TBuffer Zero; + TBuffer Dict; + + TCDict CDict; + TDDict DDict; + }; + + TZStdDictCodec::TZStdDictCodec(ui32 comprLevel) + : Impl(new TImpl(comprLevel)) + { + MyTraits.NeedsTraining = true; + MyTraits.SizeOnEncodeMultiplier = 2; + MyTraits.SizeOnDecodeMultiplier = 10; + MyTraits.RecommendedSampleSize = TImpl::SampleSize; // same as for solar + } + + TZStdDictCodec::~TZStdDictCodec() { + } + + TString TZStdDictCodec::GetName() const { + return TStringBuilder() << MyName() << "-" << Impl->GetCompressionLevel(); + } + + ui8 TZStdDictCodec::Encode(TStringBuf in, TBuffer& out) const { + return Impl->Encode(in, out); + } + + void TZStdDictCodec::Decode(TStringBuf in, TBuffer& out) const { + Impl->Decode(in, out); + } + + void TZStdDictCodec::DoLearn(ISequenceReader& in) { + Impl = new TImpl(Impl->GetCompressionLevel()); + Impl->Learn(in, true/*throwOnError*/); + } + + bool TZStdDictCodec::DoTryToLearn(ISequenceReader& in) { + Impl = new TImpl(Impl->GetCompressionLevel()); + return Impl->Learn(in, false/*throwOnError*/); + } + + void TZStdDictCodec::Save(IOutputStream* out) const { + Impl->Save(out); + } + + void TZStdDictCodec::Load(IInputStream* in) { + Impl->Load(in); + } + + TVector<TString> TZStdDictCodec::ListCompressionNames() { + TVector<TString> res; + for (int i = 1; i <= ZSTD_maxCLevel(); ++i) { + res.emplace_back(TStringBuilder() << MyName() << "-" << i); + } + return res; + } + + int TZStdDictCodec::ParseCompressionName(TStringBuf name) { + int c = 0; + TryFromString(name.After('-'), c); + Y_ENSURE_EX(name.Before('-') == MyName() && c > 0 && c <= ZSTD_maxCLevel(), TCodecException() << "invald codec name" << name); + return c; + } + +} |