aboutsummaryrefslogblamecommitdiffstats
path: root/library/cpp/streams/zstd/zstd.cpp
blob: 29816f6d4c953b31b712999ac4ce1d118db4f59e (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11




                                    
                                           


                                                         
                                                                                                  































































































































































                                                                                       
#include "zstd.h"

#include <util/generic/buffer.h>
#include <util/generic/yexception.h>

#define ZSTD_STATIC_LINKING_ONLY
#include <contrib/libs/zstd/include/zstd.h>

namespace {
    // Throws yexception when `code` is a zstd error code; otherwise does nothing.
    // `op` names the failed operation and is used as the prefix of the message.
    inline void CheckError(const char* op, size_t code) {
        if (!::ZSTD_isError(code)) {
            return;
        }
        ythrow yexception() << op << TStringBuf(" zstd error: ") << ::ZSTD_getErrorName(code);
    }

    // Deleter policy for THolder: releases a compression stream via ZSTD_freeCStream.
    struct DestroyZCStream {
        static void Destroy(::ZSTD_CStream* stream) noexcept {
            ::ZSTD_freeCStream(stream);
        }
    };

    // Deleter policy for THolder: releases a decompression stream via ZSTD_freeDStream.
    struct DestroyZDStream {
        static void Destroy(::ZSTD_DStream* stream) noexcept {
            ::ZSTD_freeDStream(stream);
        }
    };
}

// Streaming zstd compressor used by TZstdCompress.
//
// Compressed bytes are staged in an internal buffer reserved with
// ZSTD_CStreamOutSize() bytes and forwarded to the slave stream after every
// zstd call. The slave is held as a raw pointer and is never deleted here.
class TZstdCompress::TImpl {
public:
    // Creates the zstd compression stream and initializes it with `quality`
    // as the compression level.
    // Throws if the stream cannot be allocated, if the recommended output
    // buffer size is zero, or if zstd reports an initialization error.
    TImpl(IOutputStream* slave, int quality)
        : Slave_(slave)
        , ZCtx_(::ZSTD_createCStream())
        , Buffer_(::ZSTD_CStreamOutSize())  // do reserve
    {
        Y_ENSURE(nullptr != ZCtx_.Get(), "Failed to allocate ZSTD_CStream");
        Y_ENSURE(0 != Buffer_.Capacity(), "ZSTD_CStreamOutSize was too small");
        CheckError("init", ::ZSTD_initCStream(ZCtx_.Get(), quality));
    }

    // Compresses `size` bytes starting at `buffer`, forwarding whatever zstd
    // emits to the slave. Loops until zstd has consumed the whole input.
    void Write(const void* buffer, size_t size) {
        ::ZSTD_inBuffer zIn{buffer, size, 0};
        auto zOut = OutBuffer();

        while (0 != zIn.size) {
            CheckError("compress", ::ZSTD_compressStream(ZCtx_.Get(), &zOut, &zIn));
            DoWrite(zOut);
            // forget about the data we already compressed
            zIn.src = static_cast<const unsigned char*>(zIn.src) + zIn.pos;
            zIn.size -= zIn.pos;
            zIn.pos = 0;
        }
    }

    // Pushes everything zstd has buffered internally to the slave.
    // ZSTD_flushStream returns the number of bytes it could not fit into the
    // output buffer, so keep draining until it reports zero; a single call
    // would leave that remainder stuck inside the compressor.
    void Flush() {
        auto zOut = OutBuffer();
        size_t remaining;
        do {
            remaining = ::ZSTD_flushStream(ZCtx_.Get(), &zOut);
            CheckError("flush", remaining);
            DoWrite(zOut);
        } while (0 != remaining);  // zero means the internal buffers are empty
    }

    // Ends the current frame: flushes pending data and writes the frame
    // epilogue, repeating until zstd reports nothing is left.
    void Finish() {
        auto zOut = OutBuffer();
        size_t returnCode;
        do {
            returnCode = ::ZSTD_endStream(ZCtx_.Get(), &zOut);
            CheckError("finish", returnCode);
            DoWrite(zOut);
        } while (0 != returnCode);  // zero means there is no more bytes to flush
    }

private:
    // Describes the whole staging buffer as an empty zstd output window.
    ::ZSTD_outBuffer OutBuffer() {
        return {Buffer_.Data(), Buffer_.Capacity(), 0};
    }

    // Forwards the bytes produced so far to the slave and rewinds the window
    // so the staging buffer can be reused by the next zstd call.
    void DoWrite(::ZSTD_outBuffer& buffer) {
        Slave_->Write(buffer.dst, buffer.pos);
        buffer.pos = 0;
    }
private:
    IOutputStream* Slave_;
    THolder<::ZSTD_CStream, DestroyZCStream> ZCtx_;
    TBuffer Buffer_;
};

// Builds the compressor implementation that writes to `slave`; `quality` is
// forwarded to TImpl, which passes it to ZSTD_initCStream.
// TImpl's constructor throws if the zstd stream cannot be created or initialized.
TZstdCompress::TZstdCompress(IOutputStream* slave, int quality)
    : Impl_(new TImpl(slave, quality)) {
}

// Finishes the stream on destruction as a best effort.
// Any exception raised by Finish() is deliberately swallowed so that it
// cannot escape the destructor.
TZstdCompress::~TZstdCompress() {
    try {
        Finish();
    } catch (...) {
    }
}

// Compresses `size` bytes from `buffer` through the implementation.
// Throws if the stream has already been finished (Impl_ is empty).
void TZstdCompress::DoWrite(const void* buffer, size_t size) {
    Y_ENSURE(Impl_, "Cannot use stream after finish.");
    Impl_->Write(buffer, size);
}

// Asks the implementation to flush data buffered inside the compressor.
// Throws if the stream has already been finished (Impl_ is empty).
void TZstdCompress::DoFlush() {
    Y_ENSURE(Impl_, "Cannot use stream after finish.");
    Impl_->Flush();
}

// Ends the compressed stream. Safe to call more than once: only the first
// call that still finds an implementation does any work.
void TZstdCompress::DoFinish() {
    // Finish should be idempotent
    if (!Impl_) {
        return;
    }
    // Take the implementation out of Impl_ before finishing, so the member is
    // already released by the time Finish() runs (and stays so if it throws).
    auto finishing = std::move(Impl_);
    finishing->Finish();
}

////////////////////////////////////////////////////////////////////////////////

// Streaming zstd decompressor used by TZstdDecompress.
//
// Compressed bytes are pulled from the slave into an internal buffer of
// `bufferSize` bytes and fed to zstd; Offset_ remembers how much of that
// buffer has already been consumed between Read() calls. The slave is held
// as a raw pointer and is never deleted here.
class TZstdDecompress::TImpl {
public:
    // Creates the zstd decompression stream and reserves the input buffer.
    // Throws if the stream cannot be allocated or if the buffer has zero capacity.
    TImpl(IInputStream* slave, size_t bufferSize)
        : Slave_(slave)
        , ZCtx_(::ZSTD_createDStream())
        , Buffer_(bufferSize)  // do reserve
        , Offset_(0)
        , LastReturnCode_(0)
    {
        Y_ENSURE(nullptr != ZCtx_.Get(), "Failed to allocate ZSTD_DStream");
        Y_ENSURE(0 != Buffer_.Capacity(), "Buffer size was too small");
    }

    // Decompresses up to `size` bytes into `buffer` and returns the number of
    // bytes produced; zero means the slave is exhausted on a frame boundary.
    // Throws on a zstd error, or with "Incomplete block" when the slave ends
    // in the middle of a frame.
    size_t Read(void* buffer, size_t size) {
        Y_ASSERT(size > 0);

        ::ZSTD_outBuffer zOut{buffer, size, 0};
        ::ZSTD_inBuffer zIn{Buffer_.Data(), Buffer_.Size(), Offset_};

        while (zOut.pos != zOut.size) {
            if (zIn.pos == zIn.size) {
                // Input window is used up: refill it from the slave.
                zIn.size = Slave_->Read(Buffer_.Data(), Buffer_.Capacity());
                Buffer_.Resize(zIn.size);
                zIn.pos = Offset_ = 0;
                if (0 == zIn.size) {
                    // end of stream, need to check that there is no uncompleted blocks.
                    // The status is kept across calls: a previous Read() may have
                    // stopped mid-frame only because its output buffer was full.
                    Y_ENSURE(0 == LastReturnCode_, "Incomplete block");
                    break;
                }
            }
            const size_t returnCode = ::ZSTD_decompressStream(ZCtx_.Get(), &zOut, &zIn);
            CheckError("decompress", returnCode);
            LastReturnCode_ = returnCode;
            if (0 == returnCode) {
                // The frame is over, prepare to (maybe) start a new frame
                CheckError("init", ::ZSTD_initDStream(ZCtx_.Get()));
            }
        }
        Offset_ = zIn.pos;
        return zOut.pos;
    }

private:
    IInputStream* Slave_;
    THolder<::ZSTD_DStream, DestroyZDStream> ZCtx_;
    TBuffer Buffer_;
    size_t  Offset_;
    // Result of the latest successful ZSTD_decompressStream call:
    // zero when the decoder sits on a frame boundary, non-zero while a frame
    // is still in progress.
    size_t  LastReturnCode_;
};

// Builds the decompressor implementation that reads compressed data from
// `slave` through an internal input buffer of `bufferSize` bytes.
// TImpl's constructor throws if the zstd stream cannot be created or the
// buffer has zero capacity.
TZstdDecompress::TZstdDecompress(IInputStream* slave, size_t bufferSize)
    : Impl_(new TImpl(slave, bufferSize)) {
}

// Defined in this file, where TImpl is a complete type, so that Impl_ can be destroyed.
TZstdDecompress::~TZstdDecompress() = default;

// Decompresses up to `size` bytes into `buffer` by delegating to the
// implementation; returns the number of bytes actually produced.
size_t TZstdDecompress::DoRead(void* buffer, size_t size) {
    return Impl_->Read(buffer, size);
}