aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/codecs/sample.h
blob: bce37e6a2c2ebc884dc0ab47c78297b874423cda (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#pragma once 
 
#include <library/cpp/deprecated/accessors/accessors.h>
 
#include <util/generic/buffer.h> 
#include <util/generic/vector.h> 
#include <util/random/fast.h> 
#include <util/random/shuffle.h> 
 
#include <functional> 
#include <type_traits> 
 
namespace NCodecs { 
    /// Pull-style source of byte regions: each NextRegion() call hands out one region.
    ///
    /// TSimpleSequenceReader in this header returns true while it produced a region
    /// and false once it is exhausted; other implementations presumably follow the
    /// same convention — NOTE(review): confirm against other implementers.
    class ISequenceReader {
    public:
        /// Writes the next region into `region`.
        /// @return whether a region was produced by this call.
        virtual bool NextRegion(TStringBuf& region) = 0;

        // Virtual so that concrete readers can be destroyed through an ISequenceReader pointer.
        virtual ~ISequenceReader() = default;
    };
 
    /// Views the character range of `t` as a non-owning TStringBuf.
    ///
    /// The range boundaries come from NAccessors::Begin(t) / NAccessors::End(t), so any
    /// type those accessors accept can be passed.
    ///
    /// NOTE(review): nothing is copied. If `t` is a temporary that owns its characters
    /// (e.g. a string returned by value), the returned view dangles once that temporary
    /// is destroyed — use the result only within the same full-expression in that case.
    template <class TValue>
    TStringBuf ValueToStringBuf(TValue&& t) {
        const auto rangeBegin = NAccessors::Begin(t);
        const auto rangeEnd = NAccessors::End(t);
        return TStringBuf{rangeBegin, rangeEnd};
    }
 
    /// Default getter used by GetInputSize / GetSample: views the element `iter`
    /// points at as a TStringBuf via ValueToStringBuf.
    ///
    /// NOTE(review): if `*iter` yields a temporary that owns its data, the returned
    /// view dangles after this call returns.
    template <class TIter>
    TStringBuf IterToStringBuf(TIter iter) {
        // ValueToStringBuf only uses its argument as an lvalue, so binding the
        // dereferenced element to a named reference first changes nothing.
        auto&& element = *iter;
        return ValueToStringBuf(element);
    }
 
    /// ISequenceReader over an in-memory TVector: returns each element in order,
    /// viewed through ValueToStringBuf.
    ///
    /// Only a reference to the vector is stored, so the vector must outlive the reader.
    template <class TItem>
    class TSimpleSequenceReader: public ISequenceReader {
        const TVector<TItem>& Items; // borrowed, not owned
        size_t Idx = 0;              // position of the next element to hand out

    public:
        TSimpleSequenceReader(const TVector<TItem>& items)
            : Items(items)
        {
        }

        /// Stores a view of the next element in `s` and advances past it.
        /// @return false, leaving `s` untouched, once every element has been returned.
        bool NextRegion(TStringBuf& s) override {
            if (Idx < Items.size()) {
                const TItem& current = Items[Idx++];
                s = ValueToStringBuf(current);
                return true;
            }
            return false;
        }
    };
 
    /// Sums getter(iter).size() over every iterator in [begin, end).
    ///
    /// @param getter  callable invoked once per iterator; its result must provide size().
    /// @return the accumulated size (0 for an empty range).
    template <class TIter, class TGetter>
    size_t GetInputSize(TIter begin, TIter end, TGetter getter) {
        size_t accumulated = 0;
        while (begin != end) {
            accumulated += getter(begin).size();
            ++begin;
        }
        return accumulated;
    }
 
    /// GetInputSize overload that measures each element with IterToStringBuf,
    /// i.e. sums the sizes of the TStringBuf views of *iter over [begin, end).
    template <class TIter>
    size_t GetInputSize(TIter begin, TIter end) {
        // Deduces a plain function pointer, exactly what the original call passed.
        auto defaultGetter = IterToStringBuf<TIter>;
        return GetInputSize(begin, end, defaultGetter);
    }
 
    /// Picks a pseudo-random subset of the regions in [begin, end) and copies each
    /// selected region into its own TBuffer.
    ///
    /// Each region is kept independently with probability
    /// sampleSizeBytes / max(1, total input size), so the expected total size of the
    /// sample is about `sampleSizeBytes`. When that ratio is >= 1, every region is kept.
    /// The RNG seed is fixed, and the same RNG then shuffles the result, so the
    /// selection and order are reproducible for the same input.
    ///
    /// @param getter  callable invoked as getter(iter) returning the region for `iter`;
    ///                called once per element to measure the input, then again for
    ///                every element that gets selected.
    template <class TIter, class TGetter>
    TVector<TBuffer> GetSample(TIter begin, TIter end, size_t sampleSizeBytes, TGetter getter) {
        // Fixed seed: sampling is deterministic across runs.
        TFastRng64 rng{0x1ce1f2e507541a05, 0x07d45659, 0x7b8771030dd9917e, 0x2d6636ce};

        const size_t inputBytes = GetInputSize(begin, end, getter);
        // Max(1, ...) avoids dividing by zero when the input is empty.
        const double keepProb = static_cast<double>(sampleSizeBytes) / Max<size_t>(1, inputBytes);
        const bool keepAll = keepProb >= 1;

        TVector<TBuffer> picked;
        for (; begin != end; ++begin) {
            // Short-circuit: the RNG is not advanced when every region is kept.
            const bool keep = keepAll || rng.GenRandReal1() < keepProb;
            if (keep) {
                const TStringBuf region = getter(begin);
                picked.emplace_back(region.data(), region.size());
            }
        }
        Shuffle(picked.begin(), picked.end(), rng);
        return picked;
    }
 
    /// GetSample overload that reads each region with IterToStringBuf (a view of *iter).
    template <class TIter>
    TVector<TBuffer> GetSample(TIter begin, TIter end, size_t sampleSizeBytes) {
        // Deduces a plain function pointer, exactly what the original call passed.
        auto defaultGetter = IterToStringBuf<TIter>;
        return GetSample(begin, end, sampleSizeBytes, defaultGetter);
    }
 
}