aboutsummaryrefslogblamecommitdiffstats
path: root/library/cpp/streams/lzma/lzma.cpp
blob: f1942fa546e3e41d12afa4c21476c37ee32573d8 (plain) (tree)































                                                                                                           
                                                    
















                                                                   
                                   





                                     
                                           




                                                 
                                
             
                                                           





                                                   
                                             




                                                  
                                 
             
                                                                






                                             
                                                              



                                         
                             

         
                                   










































                                                        
                                                                     















                                                            
                                    








                                  
                                            
                                       






















                                                  
                              













                                               
                                                  











                                                                                                           
                                 


                                                 
                                                    











                                                                                                      
                                  




















                                                                               
                                           





































                                                                   
                                     

                                                  
                                                                      



                                                     
                                                                                  









                                                                                                
                                                    












                                                        
                             
















                                                                                                            
                                         










                                                                
 


                   

                                                  
       
                                           

                                  
 
                                                            
 







                                                                
 
                    
 
        
                         

                   
                                                                     
                                            
                    

                                                            
 




















                                                                        

                     
                         
 
                                                                              
 
                          
        
                          
                                              
 
                   
                                     
                        


                     
 
                           
  
                                                                


                                    
                                 









                                                                       
                                         
               
                       

     
                                                     
                                   

 
                                                       


                                     
                                     



                                                       
#include "lzma.h"

#include <util/stream/mem.h>
#include <util/system/context.h>
#include <util/generic/cast.h>
#include <util/memory/addstorage.h>
#include <util/generic/ptr.h>
#include <util/generic/intrlist.h>
#include <util/generic/scope.h>

extern "C" {
#include <contrib/libs/lzmasdk/LzmaEnc.h>
#include <contrib/libs/lzmasdk/LzmaDec.h>
}

namespace {
    // Allocator that tracks every block it hands out on an intrusive list.
    // Blocks can be released individually through Deallocate(); the list type
    // (TIntrusiveListWithAutoDelete) is expected to delete whatever is still
    // linked when the TMemoryGc itself goes away.
    class TMemoryGc {
        // List node; the user-visible bytes live in the additional storage
        // allocated together with the node.
        class TAllocation: public TIntrusiveListItem<TAllocation>, public TAdditionalStorage<TAllocation> {
        };

        // Largest single request that will be served (1 GiB).
        static constexpr size_t MaxAllocationSize = 1024 * 1024 * 1024;

    public:
        // Returns a block of `len` bytes, or nullptr when `len` exceeds
        // MaxAllocationSize.
        inline void* Allocate(size_t len) {
            if (len > MaxAllocationSize) {
                return nullptr;
            }

            TAllocation* const node = new (len) TAllocation;

            Allocs_.PushBack(node);

            return node->AdditionalData();
        }

        // Releases a block previously returned by Allocate(); nullptr is ignored.
        inline void Deallocate(void* ptr) noexcept {
            if (!ptr) {
                return;
            }

            delete TAllocation::ObjectFromData(ptr);
        }

    private:
        TIntrusiveListWithAutoDelete<TAllocation, TDelete> Allocs_;
    };

    template <class T>
    class TInverseFilter {
        class TTrampoLine: public ITrampoLine {
        public:
            inline TTrampoLine(TInverseFilter* parent)
                : Parent_(parent)
            {
            }

            void DoRun() override {
                Parent_->RunFilter();
            }

        private:
            TInverseFilter* Parent_;
        };

        class TInput: public IInputStream {
        public:
            inline TInput(TInverseFilter* parent)
                : Parent_(parent)
            {
            }

            ~TInput() override {
            }

            size_t DoRead(void* ptr, size_t len) override {
                return Parent_->ReadImpl(ptr, len);
            }

        private:
            TInverseFilter* Parent_;
        };

        class TOutput: public IOutputStream {
        public:
            inline TOutput(TInverseFilter* parent)
                : Parent_(parent)
            {
            }

            ~TOutput() override {
            }

            void DoWrite(const void* ptr, size_t len) override {
                Parent_->WriteImpl(ptr, len);
            }

        private:
            TInverseFilter* Parent_;
        };

    public:
        inline TInverseFilter(IOutputStream* slave, T* filter)
            : Slave_(slave)
            , Filter_(filter)
            , TrampoLine_(this)
            , FilterCtx_(FilterClosure())
            , Finished_(false)
            , In_(nullptr, 0)
        {
        }

        virtual ~TInverseFilter() {
            if (!UncaughtException()) {
                try {
                    Finish();
                } catch (...) {
                }
            } else {
                //rely on gc
            }
        }

        inline void Write(const void* ptr, size_t len) {
            In_.Reset(ptr, len);

            Y_DEFER {
                In_.Reset(0, 0);
            };

            while (In_.Avail()) {
                SwitchTo();
            }
        }

        inline void Finish() {
            if (!Finished_) {
                Finished_ = true;
                SwitchTo();
            }
        }

    private:
        inline void RunFilter() {
            try {
                TInput in(this);
                TOutput out(this);

                (*Filter_)(&in, &out);
            } catch (...) {
                Err_ = std::current_exception();
            }

            SwitchFrom();
        }

        inline TContClosure FilterClosure() {
            return {&TrampoLine_, TArrayRef(Stack_, sizeof(Stack_))};
        }

        inline size_t ReadImpl(void* ptr, size_t len) {
            while (!Finished_) {
                const size_t ret = In_.Read(ptr, len);

                if (ret) {
                    return ret;
                }

                SwitchFrom();
            }

            return 0;
        }

        inline void WriteImpl(const void* ptr, size_t len) {
            Y_ASSERT(!Out_.Avail());

            Out_.Reset(ptr, len);

            while (Out_.Avail()) {
                SwitchFrom();
            }
        }

        inline bool FlushImpl() {
            if (Out_.Avail()) {
                TransferData(&Out_, Slave_);
                Out_.Reset(nullptr, 0);

                return true;
            }

            return false;
        }

        inline void SwitchTo() {
            do {
                CurrentCtx_.SwitchTo(&FilterCtx_);

                if (Err_) {
                    Finished_ = true;

                    std::rethrow_exception(Err_);
                }
            } while (FlushImpl());
        }

        inline void SwitchFrom() {
            FilterCtx_.SwitchTo(&CurrentCtx_);
        }

    private:
        IOutputStream* Slave_;
        T* Filter_;
        TTrampoLine TrampoLine_;
        char Stack_[16 * 1024];
        TContMachineContext FilterCtx_;
        TContMachineContext CurrentCtx_;
        bool Finished_;
        TMemoryInput In_;
        TMemoryInput Out_;
        std::exception_ptr Err_;
    };

    class TLzma {
    public:
        class TLzmaInput: public ISeqInStream {
        public:
            inline TLzmaInput(IInputStream* slave)
                : Slave_(slave)
            {
                Read = ReadFunc;
            }

        private:
            static inline SRes ReadFunc(const ISeqInStream* p, void* ptr, size_t* len) {
                *len = const_cast<TLzmaInput*>(static_cast<const TLzmaInput*>(p))->Slave_->Read(ptr, *len);

                return SZ_OK;
            }

        private:
            IInputStream* Slave_;
        };

        class TLzmaOutput: public ISeqOutStream {
        public:
            inline TLzmaOutput(IOutputStream* slave)
                : Slave_(slave)
            {
                Write = WriteFunc;
            }

        private:
            static inline size_t WriteFunc(const ISeqOutStream* p, const void* ptr, size_t len) {
                const_cast<TLzmaOutput*>(static_cast<const TLzmaOutput*>(p))->Slave_->Write(ptr, len);

                return len;
            }

        private:
            IOutputStream* Slave_;
        };

        class TAlloc: public ISzAlloc {
        public:
            inline TAlloc() {
                Alloc = AllocFunc;
                Free = FreeFunc;
            }

        private:
            static void* AllocFunc(const ISzAlloc* t, size_t len) {
                return static_cast<TAlloc*>(((ISzAlloc*)t))->Gc_.Allocate(len);
            }

            static void FreeFunc(const ISzAlloc* t, void* p) {
                static_cast<TAlloc*>(((ISzAlloc*)t))->Gc_.Deallocate(p);
            }

        private:
            TMemoryGc Gc_;
        };

        inline ISzAlloc* Alloc() noexcept {
            return &Alloc_;
        }

        static inline void Check(SRes r) {
            if (r != SZ_OK) {
                ythrow yexception() << "lzma error(" << r << ")";
            }
        }

    private:
        TAlloc Alloc_;
    };

    class TLzmaCompressBase: public TLzma {
    public:
        inline TLzmaCompressBase(size_t level)
            : H_(LzmaEnc_Create(Alloc()))
        {
            if (!H_) {
                ythrow yexception() << "can not init lzma engine";
            }

            LzmaEncProps_Init(&Props_);

            Props_.level = level;
            Props_.dictSize = 0;
            Props_.lc = -1;
            Props_.lp = -1;
            Props_.pb = -1;
            Props_.fb = -1;
            Props_.numThreads = -1;
            Props_.writeEndMark = 1;

            Check(LzmaEnc_SetProps(H_, &Props_));
            size_t bufLen = sizeof(PropsBuf_);
            Zero(PropsBuf_);
            Check(LzmaEnc_WriteProperties(H_, PropsBuf_, &bufLen));
        }

        inline ~TLzmaCompressBase() {
            LzmaEnc_Destroy(H_, Alloc(), Alloc());
        }

        inline void operator()(IInputStream* in, IOutputStream* out) {
            TLzmaInput input(in);
            TLzmaOutput output(out);

            out->Write(PropsBuf_, sizeof(PropsBuf_));

            Check(LzmaEnc_Encode(H_, &output, &input, nullptr, Alloc(), Alloc()));
        }

    private:
        CLzmaEncHandle H_;
        CLzmaEncProps Props_;
        Byte PropsBuf_[LZMA_PROPS_SIZE];
    };
}

// Compressor implementation: the encoder (TLzmaCompressBase) is itself the
// filter driven by the TInverseFilter base. The encoder base is listed first,
// so it is constructed before and destroyed after the filter adapter that
// holds a pointer to it.
class TLzmaCompress::TImpl: public TLzmaCompressBase, public TInverseFilter<TLzmaCompressBase> {
    using TFilter = TInverseFilter<TLzmaCompressBase>;

public:
    inline TImpl(IOutputStream* slave, size_t level)
        : TLzmaCompressBase(level)
        , TFilter(slave, this)
    {
    }
};

// Common decoder logic. Derived classes supply compressed input by
// implementing Fill(), which must point [InBegin_, InEnd_) at fresh data.
class TLzmaDecompress::TImpl: public TLzma {
public:
    inline TImpl()
        : InBegin_(nullptr)
        , InEnd_(nullptr)
    {
        LzmaDec_Construct(&H_);
    }
    inline virtual ~TImpl() {
        LzmaDec_Free(&H_, Alloc());
    }

    // Decodes up to `len` bytes into `ptr` and returns the number of bytes
    // produced. Whenever the decoder reports that it needs more input, Fill()
    // is called; if Fill() cannot provide any, yexception is thrown.
    inline size_t Read(void* ptr, size_t len) {
        Byte* const outBegin = static_cast<Byte*>(ptr);
        Byte* const outEnd = outBegin + len;
        Byte* outCur = outBegin;

        for (;;) {
            // In/out parameters: sizes offered on entry, sizes used on return.
            size_t inUsed = InEnd_ - InBegin_;
            size_t outUsed = outEnd - outCur;
            ELzmaStatus status;

            Check(LzmaDec_DecodeToBuf(&H_, outCur, &outUsed, reinterpret_cast<const Byte*>(InBegin_), &inUsed, LZMA_FINISH_ANY, &status));

            InBegin_ += inUsed;
            outCur += outUsed;

            if (status != LZMA_STATUS_NEEDS_MORE_INPUT) {
                break;
            }

            Y_ASSERT(InEnd_ == InBegin_);
            if (!Fill()) {
                ythrow yexception() << "incomplete lzma stream";
            }
        }

        return outCur - outBegin;
    }

private:
    // Makes more compressed data available; returns false at end of input.
    virtual bool Fill() = 0;

protected:
    CLzmaDec H_;
    char* InBegin_;
    char* InEnd_;
};

// Decoder implementation that reads compressed data from an IInputStream
// through a fixed-size internal buffer.
class TLzmaDecompress::TImplStream: public TImpl {
public:
    // Loads the LZMA_PROPS_SIZE-byte properties header from `slave` and
    // prepares the decoder. Throws yexception if the header cannot be read in
    // full or the decoder rejects it.
    inline TImplStream(IInputStream* slave)
        : Slave_(slave)
    {
        Byte header[LZMA_PROPS_SIZE];
        const size_t loaded = Slave_->Load(header, sizeof(header));

        if (loaded != sizeof(header)) {
            ythrow yexception() << "can't read lzma header";
        }

        Check(LzmaDec_Allocate(&H_, header, sizeof(header), Alloc()));
        LzmaDec_Init(&H_);
    }

private:
    // Refills In_ from the slave stream; returns false when nothing was read.
    bool Fill() override {
        const size_t got = Slave_->Read(In_, sizeof(In_));

        InBegin_ = In_;
        InEnd_ = In_ + got;

        return got != 0;
    }

private:
    IInputStream* Slave_;
    char In_[4096];
};

// Decoder implementation that consumes compressed data directly from the
// chunks handed out by an IZeroCopyInput, without an intermediate buffer.
class TLzmaDecompress::TImplZeroCopy: public TLzmaDecompress::TImpl {
public:
    // Reads the LZMA_PROPS_SIZE-byte properties header from `in` and prepares
    // the decoder.
    //
    // Throws yexception ("can't read lzma header") if the input ends before
    // the header is complete, and whatever Check() throws if the decoder
    // rejects the header.
    inline TImplZeroCopy(IZeroCopyInput* in)
        : Input_(in)
    {
        if (!Fill()) {
            ythrow yexception() << "can't read lzma header";
        }

        char buf[LZMA_PROPS_SIZE];
        char* header;
        if (InEnd_ - InBegin_ >= LZMA_PROPS_SIZE) {
            // The whole header sits in the first chunk: use it in place.
            header = InBegin_;
            InBegin_ += LZMA_PROPS_SIZE;
        } else {
            // The first chunk is shorter than the header: assemble the header
            // piece by piece in a local buffer.
            const char* const end = buf + sizeof(buf);
            char* pos = buf;
            while (true) {
                const size_t left = end - pos;
                const size_t avail = InEnd_ - InBegin_;
                if (left <= avail) {
                    // The current chunk completes the header. `<=` matters:
                    // when the chunk ends exactly at the header boundary we
                    // must stop here instead of demanding one more chunk.
                    memcpy(pos, InBegin_, left);
                    InBegin_ += left;
                    break;
                }

                memcpy(pos, InBegin_, avail);
                pos += avail;
                if (!Fill()) {
                    ythrow yexception() << "can't read lzma header";
                }
            }
            header = buf;
        }

        Check(LzmaDec_Allocate(&H_, reinterpret_cast<const Byte*>(header), LZMA_PROPS_SIZE, Alloc()));

        LzmaDec_Init(&H_);
    }

private:
    // Fetches the next chunk from the input; returns false when the input
    // reports no more data.
    bool Fill() override {
        size_t size = Input_->Next(&InBegin_);

        if (size) {
            InEnd_ = InBegin_ + size;

            return true;
        }

        return false;
    }

    IZeroCopyInput* Input_;
};

// Creates a compressor that sends its output to `slave`, using `level` as the
// encoder's compression level.
TLzmaCompress::TLzmaCompress(IOutputStream* slave, size_t level)
    : Impl_(new TImpl(slave, level))
{
}

// Defined in this file, where TImpl is a complete type.
TLzmaCompress::~TLzmaCompress() = default;

// Compresses `len` bytes from `buf`. Throws yexception if the stream has
// already been finished.
void TLzmaCompress::DoWrite(const void* buf, size_t len) {
    if (Impl_) {
        Impl_->Write(buf, len);

        return;
    }

    ythrow yexception() << "can not write to finished lzma stream";
}

void TLzmaCompress::DoFinish() {
    THolder<TImpl> impl(Impl_.Release());

    if (impl) {
        impl->Finish();
    }
}

// Creates a decompressor that reads compressed data from `slave` through an
// internal buffer (TImplStream).
TLzmaDecompress::TLzmaDecompress(IInputStream* slave)
    : Impl_(new TImplStream(slave))
{
}

// Creates a decompressor that consumes compressed data directly from the
// chunks returned by `input` (TImplZeroCopy).
TLzmaDecompress::TLzmaDecompress(IZeroCopyInput* input)
    : Impl_(new TImplZeroCopy(input))
{
}

// Defined in this file, where TImpl is a complete type.
TLzmaDecompress::~TLzmaDecompress() = default;

// Decodes up to `len` bytes into `buf` by delegating to the implementation's
// Read(); returns the number of bytes produced.
size_t TLzmaDecompress::DoRead(void* buf, size_t len) {
    return Impl_->Read(buf, len);
}