aboutsummaryrefslogblamecommitdiffstats
path: root/library/cpp/codecs/codecs.h
blob: cc5e72b2850900b3dd34a59458fe13083bef7998 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12










                                      
                             



                               
                                                                
 
                 
 
                                                      
 




                                       
 



                                             
 

                                                                      
 


                                                         
 


                                                  
 


                                     
 
                                                                                  
 
                                                                                      

                                                                                         
                                                            

                                                                                  
 
                                    
 




                                                      
 



                                             

                                        
 


                                                    
 




                                                                                      
 
                                                      
 
                                                
 
                                                
 

                                                           
 
                                                   
 



                                                       


                                                                                    
 



                                                 
 




                                                                                 
 





                                                                                               
 

                                                                
                                             
                     
 

                                                                 
 


                                                 
 



                                                    
 

                                    
 
                                          
                                      
         
 




                                                             
 

                                       
 
                                          
                                      
         
 





                                                
 

                                          
 


                                                             
 

                                                              
 


                                                 
 
                                             
 
                           
 






                                                                                                                                 
 

                                          
 
                                                                
 
           

                                                                   
                                                  
 
                                             
 


                                                    
 


                                                                           
 
#pragma once

#include "sample.h"

#include <util/generic/bt_exception.h>
#include <util/generic/hash.h>
#include <util/generic/ptr.h>
#include <util/generic/singleton.h>

#include <util/stream/input.h>
#include <util/stream/output.h>

#include <util/string/cast.h>
#include <util/string/vector.h>
#include <util/system/tls.h>
#include <util/ysaveload.h>

namespace NCodecs {
    /// Exception type for codec errors; derives from TWithBackTrace<yexception>.
    class TCodecException: public TWithBackTrace<yexception> {
    };

    class ICodec; // forward declaration for the handle aliases below

    // Intrusive (reference-counted) handles to codec instances.
    using TCodecPtr = TIntrusivePtr<ICodec>;
    using TCodecConstPtr = TIntrusiveConstPtr<ICodec>;

    /// Static properties of a codec, exposed through ICodec::Traits().
    /// Derived codecs adjust these in their constructors.
    struct TCodecTraits {
        // Sample size requested when learning from an iterator range
        // (passed to GetSample by ICodec::Learn / ICodec::LearnX).
        ui32 RecommendedSampleSize = 0;
        ui16 SizeOfInputElement = 1;
        // Encode-size estimate: input * SizeOnEncodeMultiplier + SizeOnEncodeAddition.
        ui8 SizeOnEncodeMultiplier = 1;
        ui8 SizeOnEncodeAddition = 0;
        // Decode-size estimate: input * SizeOnDecodeMultiplier.
        ui8 SizeOnDecodeMultiplier = 1;

        // When true, ICodec::AlreadyTrained() reports false until training completes.
        bool NeedsTraining = false;
        bool PreservesPrefixGrouping = false;
        bool Irreversible = false;
        bool PaddingBit = false; // bool literal, consistent with the other flags
        bool AssumesStructuredInput = false;

        /// Rough upper estimate of the encoded size for `sz` input bytes.
        size_t ApproximateSizeOnEncode(size_t sz) const {
            return sz * static_cast<size_t>(SizeOnEncodeMultiplier) + static_cast<size_t>(SizeOnEncodeAddition);
        }

        /// Rough estimate of the decoded size for `sz` encoded bytes.
        size_t ApproximateSizeOnDecode(size_t sz) const {
            return sz * static_cast<size_t>(SizeOnDecodeMultiplier);
        }
    };

    /// Abstract codec interface.
    ///
    /// A codec transforms a byte string into a TBuffer (Encode) and back
    /// (Decode). Codecs whose traits set NeedsTraining must learn from a
    /// sample before AlreadyTrained() reports true. Instances are
    /// reference-counted (TAtomicRefCount) and normally held via TCodecPtr.
    class ICodec: public TAtomicRefCount<ICodec> {
    protected:
        // Set by LearnX / TryToLearn / SetTrained; read by AlreadyTrained().
        bool Trained = false;
        // Traits reported by Traits(); filled in by derived-class constructors.
        TCodecTraits MyTraits;

    public:
        /// Returns a copy of this codec's traits.
        TCodecTraits Traits() const {
            return MyTraits;
        }

        /// Name of the codec (or its variant) to be used in the codec registry.
        virtual TString GetName() const = 0;

        /// Encodes the input into the output buffer; the return value is the
        /// number of free bits in the last output byte.
        virtual ui8 /*free bits in last byte*/ Encode(TStringBuf, TBuffer&) const = 0;

        /// Convenience overload: encodes the whole contents of `input`.
        virtual ui8 Encode(const TBuffer& input, TBuffer& output) const {
            const char* const first = input.Data();
            const TStringBuf bytes(first, first + input.Size());
            return Encode(bytes, output);
        }

        /// Decodes the input into the output buffer.
        virtual void Decode(TStringBuf, TBuffer&) const = 0;

        /// Convenience overload: decodes the whole contents of `input`.
        virtual void Decode(const TBuffer& input, TBuffer& output) const {
            const char* const first = input.Data();
            const TStringBuf bytes(first, first + input.Size());
            Decode(bytes, output);
        }

        virtual ~ICodec() = default;

        /// True when the codec is usable: it was trained, or never needed training.
        virtual bool AlreadyTrained() const {
            if (Trained) {
                return true;
            }
            return !Traits().NeedsTraining;
        }

        virtual void SetTrained(bool t) {
            Trained = t;
        }

        /// Attempts to learn from `r`; stores and returns whether it succeeded.
        bool TryToLearn(ISequenceReader& r) {
            const bool succeeded = DoTryToLearn(r);
            Trained = succeeded;
            return succeeded;
        }

        /// Learns from `r` with a sample-size multiplier of 1.
        void Learn(ISequenceReader& r) {
            constexpr double defaultMultiplier = 1.0;
            LearnX(r, defaultMultiplier);
        }

        /// Learns from [beg, end), viewing elements through IterToStringBuf.
        template <class TIter>
        void Learn(TIter beg, TIter end) {
            Learn(beg, end, IterToStringBuf<TIter>);
        }

        /// Draws a sample from [beg, end) via GetSample, sized by
        /// Traits().RecommendedSampleSize and using `getter` to view elements,
        /// then learns from it.
        template <class TIter, class TGetter>
        void Learn(TIter beg, TIter end, TGetter getter) {
            const auto sampleSize = Traits().RecommendedSampleSize;
            auto sample = GetSample(beg, end, sampleSize, getter);
            TSimpleSequenceReader<TBuffer> reader{sample};
            Learn(reader);
        }

        static TCodecPtr GetInstance(TStringBuf name);

        static TVector<TString> GetCodecsList();

        static TString GetNameSafe(TCodecPtr p);

        static void Store(IOutputStream* out, TCodecPtr p);
        static TCodecPtr Restore(IInputStream* in);
        static TCodecPtr RestoreFromString(TStringBuf);

    protected:
        /// Codec-specific learning; every concrete codec implements this.
        virtual void DoLearn(ISequenceReader&) = 0;

        /// Default: delegate to DoLearn and report success unconditionally.
        virtual bool DoTryToLearn(ISequenceReader& reader) {
            DoLearn(reader);
            return true;
        }

        // so the pipeline codec will know to adjust the sample for the subcodecs;
        // the base implementation ignores the multiplier.
        virtual void DoLearnX(ISequenceReader& reader, double /*sampleSizeMultiplier*/) {
            DoLearn(reader);
        }

        // Serialization hooks for codec state; no-ops by default.
        virtual void Save(IOutputStream*) const {
        }
        virtual void Load(IInputStream*) {
        }
        friend class TPipelineCodec;

    public:
        /// Learns from `r`, forwarding `sampleSizeMult` to DoLearnX (so the
        /// pipeline codec can adjust the sample for its sub-codecs), then marks
        /// the codec as trained.
        void LearnX(ISequenceReader& r, double sampleSizeMult) {
            DoLearnX(r, sampleSizeMult);
            Trained = true;
        }

        /// Range form of LearnX: the sample size is the recommended size scaled
        /// by `sampleSizeMult`.
        template <class TIter>
        void LearnX(TIter beg, TIter end, double sampleSizeMult) {
            const double scaledSampleSize = Traits().RecommendedSampleSize * sampleSizeMult;
            auto sample = GetSample(beg, end, scaledSampleSize);
            TSimpleSequenceReader<TBuffer> reader{sample};
            LearnX(reader, sampleSizeMult);
        }
    };

    /// Identity codec base: Encode copies the input bytes verbatim and Decode
    /// performs the same copy. Learning is a no-op; subclasses provide GetName().
    class TBasicTrivialCodec: public ICodec {
    public:
        ui8 Encode(TStringBuf in, TBuffer& out) const override {
            const char* const bytes = in.data();
            const size_t count = in.size();
            out.Assign(bytes, count);
            return 0; // whole bytes copied: no free bits in the last byte
        }

        void Decode(TStringBuf in, TBuffer& out) const override {
            // Undoing an identity encoding is the same copy; goes through the
            // (virtual) Encode.
            (void)Encode(in, out);
        }

    protected:
        void DoLearn(ISequenceReader& /*sample*/) override {
            // nothing to learn
        }
    };

    /// Identity codec named "trivial"; does not require training.
    class TTrivialCodec: public TBasicTrivialCodec {
    public:
        TTrivialCodec() {
            // Bytes are copied unchanged, so prefix grouping is preserved.
            MyTraits.PreservesPrefixGrouping = true;
        }

        /// Registry name of this codec.
        static TStringBuf MyName() {
            return TStringBuf("trivial");
        }

        TString GetName() const override {
            const TStringBuf name = MyName();
            return ToString(name);
        }
    };

    /// Identity codec named "trivial-trainable". It sets NeedsTraining, so
    /// AlreadyTrained() is false until it has been trained, even though its
    /// learning step does nothing.
    class TTrivialTrainableCodec: public TBasicTrivialCodec {
    public:
        TTrivialTrainableCodec() {
            MyTraits.NeedsTraining = true;
            // Bytes are copied unchanged, so prefix grouping is preserved.
            MyTraits.PreservesPrefixGrouping = true;
        }

        /// Registry name of this codec.
        static TStringBuf MyName() {
            return TStringBuf("trivial-trainable");
        }

        TString GetName() const override {
            const TStringBuf name = MyName();
            return ToString(name);
        }
    };

    class TNullCodec: public ICodec {
    public:
        TNullCodec() {
            MyTraits.Irreversible = true;
            MyTraits.SizeOnDecodeMultiplier = 0;
            MyTraits.SizeOnEncodeMultiplier = 0;
        }

        TString GetName() const override {
            return "null";
        }

        ui8 Encode(TStringBuf, TBuffer& out) const override {
            out.Clear();
            return 0;
        }

        void Decode(TStringBuf, TBuffer& out) const override {
            out.Clear();
        }

    protected:
        void DoLearn(ISequenceReader&) override {
        }
    };

    /// Codec that chains several sub-codecs into one.
    ///
    /// Encode, Decode, AddCodec, AlreadyTrained, SetTrained, DoLearnX and
    /// Save/Load are defined out of line (not visible in this header).
    class TPipelineCodec: public ICodec {
        using TPipeline = TVector<TCodecPtr>;

        TPipeline Pipeline;
        TString MyName;

    public:
        /// Builds a pipeline from up to four codecs, handing each argument to
        /// AddCodec in order c0..c3. NOTE(review): null arguments are forwarded
        /// as well; presumably AddCodec skips them — confirm in the implementation.
        explicit TPipelineCodec(TCodecPtr c0 = nullptr, TCodecPtr c1 = nullptr, TCodecPtr c2 = nullptr, TCodecPtr c3 = nullptr) {
            // Starts true; presumably AddCodec narrows it per sub-codec — confirm.
            MyTraits.PreservesPrefixGrouping = true;
            AddCodec(c0);
            AddCodec(c1);
            AddCodec(c2);
            AddCodec(c3);
        }

        /// Returns the stored pipeline name (presumably assembled by AddCodec).
        TString GetName() const override {
            return MyName;
        }

        ui8 Encode(TStringBuf in, TBuffer& out) const override;
        void Decode(TStringBuf in, TBuffer& out) const override;

        /*
         * Add codecs in the following order:
         * uncompressed -> codec0 | codec1 | ... | codecN -> compressed
         */
        TPipelineCodec& AddCodec(TCodecPtr codec);

        bool AlreadyTrained() const override;
        void SetTrained(bool t) override;

    protected:
        // Plain learning is pipeline learning with a sample-size multiplier of 1.
        void DoLearn(ISequenceReader& in) override {
            DoLearnX(in, 1);
        }

        void DoLearnX(ISequenceReader& in, double sampleSizeMult) override;
        void Save(IOutputStream* out) const override;
        void Load(IInputStream* in) override;
    };

}