aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/streams/zstd/zstd.cpp
blob: 0e2cf159dbfd6c7c7479eed4305f019435a81e9a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
#include "zstd.h"

#include <util/generic/buffer.h>
#include <util/generic/yexception.h>

#define ZSTD_STATIC_LINKING_ONLY
#include <contrib/libs/zstd/include/zstd.h> 

namespace {
    // Converts a zstd return code into an exception: when `code` denotes an
    // error, throws yexception whose message is `op` followed by zstd's error name.
    inline void CheckError(const char* op, size_t code) {
        if (!::ZSTD_isError(code)) {
            return;
        }
        ythrow yexception() << op << TStringBuf(" zstd error: ") << ::ZSTD_getErrorName(code);
    }

    // THolder destruction policy: releases a ZSTD_CStream via ZSTD_freeCStream.
    struct DestroyZCStream {
        static void Destroy(::ZSTD_CStream* stream) noexcept {
            ::ZSTD_freeCStream(stream);
        }
    };

    // THolder destruction policy: releases a ZSTD_DStream via ZSTD_freeDStream.
    struct DestroyZDStream {
        static void Destroy(::ZSTD_DStream* stream) noexcept {
            ::ZSTD_freeDStream(stream);
        }
    };
} // namespace

/// Streaming zstd compressor: feeds user data to a ZSTD_CStream and writes the
/// compressed output to `Slave_` through a reusable buffer whose capacity is
/// reserved from ZSTD_CStreamOutSize().
class TZstdCompress::TImpl {
public:
    /// @param slave   Destination for compressed bytes; only the pointer is stored (not owned).
    /// @param quality Compression level handed to ZSTD_initCStream.
    /// @throws yexception if the stream cannot be allocated or initialized.
    TImpl(IOutputStream* slave, int quality)
        : Slave_(slave)
        , ZCtx_(::ZSTD_createCStream())
        , Buffer_(::ZSTD_CStreamOutSize())  // do reserve
    {
        Y_ENSURE(nullptr != ZCtx_.Get(), "Failed to allocate ZSTD_CStream");
        Y_ENSURE(0 != Buffer_.Capacity(), "ZSTD_CStreamOutSize was too small");
        CheckError("init", ::ZSTD_initCStream(ZCtx_.Get(), quality));
    }

    /// Compresses `size` bytes from `buffer`, forwarding whatever compressed
    /// output zstd produces along the way to the slave stream. zstd may keep
    /// part of the data buffered internally until Flush() or Finish().
    /// @throws yexception on a zstd error.
    void Write(const void* buffer, size_t size) {
        ::ZSTD_inBuffer zIn{buffer, size, 0};
        auto zOut = OutBuffer();

        while (zIn.pos < zIn.size) {
            CheckError("compress", ::ZSTD_compressStream(ZCtx_.Get(), &zOut, &zIn));
            // Emptying zOut after every call guarantees the next call has room to make progress.
            DoWrite(zOut);
        }
    }

    /// Pushes all data currently buffered inside the compressor to the slave stream.
    /// @throws yexception on a zstd error.
    void Flush() {
        // A single ZSTD_flushStream call may not emit everything if the output
        // buffer fills up; it returns the number of bytes still pending, so keep
        // calling until it reports 0.
        Drain("flush", ::ZSTD_flushStream);
    }

    /// Ends the current zstd frame (ZSTD_endStream) and writes everything that
    /// remains to the slave stream.
    /// @throws yexception on a zstd error.
    void Finish() {
        Drain("finish", ::ZSTD_endStream);
    }

private:
    /// Returns an empty output window spanning the whole reserved buffer.
    ::ZSTD_outBuffer OutBuffer() {
        return {Buffer_.Data(), Buffer_.Capacity(), 0};
    }

    /// Calls `step` (ZSTD_flushStream / ZSTD_endStream) repeatedly, writing each
    /// produced chunk to the slave, until it returns 0 (nothing left to emit).
    template <class TStep>
    void Drain(const char* op, TStep step) {
        auto zOut = OutBuffer();
        size_t remaining = 0;
        do {
            remaining = step(ZCtx_.Get(), &zOut);
            CheckError(op, remaining);
            DoWrite(zOut);
        } while (0 != remaining);  // zero means there is no more bytes to flush
    }

    /// Writes the bytes accumulated in `buffer` to the slave and rewinds it for reuse.
    void DoWrite(::ZSTD_outBuffer& buffer) {
        Slave_->Write(buffer.dst, buffer.pos);
        buffer.pos = 0;
    }

private:
    IOutputStream* Slave_;                           // destination of compressed bytes; not owned
    THolder<::ZSTD_CStream, DestroyZCStream> ZCtx_;  // compression context, freed via ZSTD_freeCStream
    TBuffer Buffer_;                                 // scratch space for compressed output
};

// Builds the compressing stream: all zstd state (context, output buffer, the
// `slave` pointer and compression `quality`) lives in the TImpl object.
TZstdCompress::TZstdCompress(IOutputStream* slave, int quality)
    : Impl_(new TImpl(slave, quality))
{
}

// Completes the zstd frame on a best-effort basis. A destructor must not let
// exceptions escape, so any error raised by Finish() is deliberately dropped;
// callers that care about such errors should call Finish() themselves.
TZstdCompress::~TZstdCompress() {
    try {
        Finish();
    } catch (...) {
        // intentionally ignored: there is no way to report the failure from here
    }
}

// Stream write hook: compresses `size` bytes from `buffer`.
// DoFinish() releases Impl_, so writing after finishing throws.
void TZstdCompress::DoWrite(const void* buffer, size_t size) {
    Y_ENSURE(Impl_, "Cannot use stream after finish.");

    Impl_->Write(buffer, size);
}

// Stream flush hook: emits the compressed data buffered so far.
// DoFinish() releases Impl_, so flushing after finishing throws.
void TZstdCompress::DoFlush() {
    Y_ENSURE(Impl_, "Cannot use stream after finish.");

    Impl_->Flush();
}

// Stream finish hook; safe to call any number of times.
// The first call takes ownership of Impl_ (leaving the member empty) before
// finishing the frame, so the implementation is destroyed and later calls are
// no-ops even if Finish() throws.
void TZstdCompress::DoFinish() {
    if (!Impl_) {
        return;  // already finished
    }

    auto impl = std::move(Impl_);
    impl->Finish();
}

////////////////////////////////////////////////////////////////////////////////

/// Streaming zstd decompressor: pulls compressed bytes from `Slave_` into
/// `Buffer_` and decodes them with a ZSTD_DStream. When a frame ends, the
/// stream is re-initialized so that concatenated frames decode back to back.
class TZstdDecompress::TImpl {
public:
    /// @param slave      Source of compressed bytes; only the pointer is stored (not owned).
    /// @param bufferSize Capacity reserved for the compressed-input buffer.
    /// @throws yexception if the stream cannot be allocated or initialized, or if
    ///         the reserved buffer has zero capacity.
    TImpl(IInputStream* slave, size_t bufferSize)
        : Slave_(slave)
        , ZCtx_(::ZSTD_createDStream())
        , Buffer_(bufferSize)  // do reserve
        , Offset_(0)
        , FrameInProgress_(false)
    {
        Y_ENSURE(nullptr != ZCtx_.Get(), "Failed to allocate ZSTD_DStream");
        Y_ENSURE(0 != Buffer_.Capacity(), "Buffer size was too small");
        // Explicitly start a decompression session, mirroring the compressor side.
        CheckError("init", ::ZSTD_initDStream(ZCtx_.Get()));
    }

    /// Decompresses up to `size` bytes into `buffer`.
    /// @return number of bytes produced; fewer than `size` (possibly 0) only
    ///         once the slave stream is exhausted.
    /// @throws yexception on corrupt input, or when the slave ends in the middle
    ///         of a frame. Frame state is kept across calls, so truncation is
    ///         detected even if the frame was started by an earlier Read().
    size_t Read(void* buffer, size_t size) {
        Y_ASSERT(size > 0);

        ::ZSTD_outBuffer zOut{buffer, size, 0};
        ::ZSTD_inBuffer zIn{Buffer_.Data(), Buffer_.Size(), Offset_};

        while (zOut.pos != zOut.size) {
            if (zIn.pos == zIn.size) {
                zIn.size = Slave_->Read(Buffer_.Data(), Buffer_.Capacity());
                Buffer_.Resize(zIn.size);
                // Keep Offset_ in sync with the (now refilled or emptied) buffer
                // before anything below can throw.
                zIn.pos = Offset_ = 0;
                if (0 == zIn.size) {
                    // end of stream, need to check that there is no uncompleted blocks
                    Y_ENSURE(!FrameInProgress_, "Incomplete block");
                    break;
                }
            }
            const size_t returnCode = ::ZSTD_decompressStream(ZCtx_.Get(), &zOut, &zIn);
            CheckError("decompress", returnCode);
            // Non-zero means the current frame still needs input or flushing.
            FrameInProgress_ = (0 != returnCode);
            if (!FrameInProgress_) {
                // The frame is over, prepare to (maybe) start a new frame
                CheckError("init", ::ZSTD_initDStream(ZCtx_.Get()));
            }
        }
        Offset_ = zIn.pos;
        return zOut.pos;
    }

private:
    IInputStream* Slave_;                            // source of compressed bytes; not owned
    THolder<::ZSTD_DStream, DestroyZDStream> ZCtx_;  // decompression context, freed via ZSTD_freeDStream
    TBuffer Buffer_;                                 // compressed bytes last read from Slave_
    size_t  Offset_;                                 // first byte of Buffer_ not yet consumed by zstd
    bool FrameInProgress_;                           // a frame was started but not completely decoded yet
};

// Builds the decompressing stream over `slave`; the compressed-input buffer is
// reserved from `bufferSize`, and all zstd state lives in the TImpl object.
TZstdDecompress::TZstdDecompress(IInputStream* slave, size_t bufferSize)
    : Impl_(new TImpl(slave, bufferSize))
{
}

// Defined in this file, where TImpl is a complete type
// (presumably required because Impl_ owns a TImpl).
TZstdDecompress::~TZstdDecompress() = default;

// Stream read hook: forwards straight to the implementation, which returns the
// number of decompressed bytes written into `buffer`.
size_t TZstdDecompress::DoRead(void* buffer, size_t size) {
    auto& impl = *Impl_;
    return impl.Read(buffer, size);
}