aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/codecs/sample.h
blob: 8aa62c7abd9c1f6c5fdb1634ca4c17f7352c6c35 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#pragma once

#include <library/cpp/deprecated/accessors/accessors.h>

#include <util/generic/buffer.h>
#include <util/generic/vector.h>
#include <util/random/fast.h>
#include <util/random/shuffle.h>

#include <functional>
#include <type_traits>

namespace NCodecs {
    /// Abstract source of byte regions, consumed one region per NextRegion() call.
    class ISequenceReader {
    public:
        /// Advances to the next region. TSimpleSequenceReader (in this header) sets `s`
        /// to the next region and returns true, or returns false once every region has
        /// been consumed; presumably all implementations follow that contract — confirm
        /// against other implementers/callers.
        virtual bool NextRegion(TStringBuf& s) = 0;

        /// Virtual so implementations can be destroyed through an ISequenceReader pointer.
        virtual ~ISequenceReader() = default;
    };

    /// Builds a TStringBuf spanning [NAccessors::Begin(t), NAccessors::End(t)).
    ///
    /// No bytes are copied: the result views the storage of `t` and is only valid
    /// while that storage is. Passing a temporary therefore yields a view that
    /// dangles once the calling full-expression ends.
    template <class TValue>
    TStringBuf ValueToStringBuf(TValue&& t) {
        const auto first = NAccessors::Begin(t);
        const auto last = NAccessors::End(t);
        return TStringBuf{first, last};
    }

    /// Returns ValueToStringBuf() of the element `iter` refers to.
    ///
    /// NOTE(review): if dereferencing `iter` produces a temporary rather than a
    /// reference into a container, the returned view points at storage that is
    /// destroyed before the caller can use it.
    template <class TIter>
    TStringBuf IterToStringBuf(TIter iter) {
        auto&& element = *iter;
        return ValueToStringBuf(element);
    }

    /// ISequenceReader that walks the elements of a vector in index order.
    ///
    /// Each successful NextRegion() call yields ValueToStringBuf() of the next
    /// element. Only a reference to the vector is stored, so the vector (and its
    /// elements, which the yielded views refer to) must stay alive and unchanged
    /// while the reader and its regions are in use.
    template <class TItem>
    class TSimpleSequenceReader: public ISequenceReader {
        const TVector<TItem>& Items;
        size_t Idx = 0; // position of the next element to hand out

    public:
        TSimpleSequenceReader(const TVector<TItem>& items)
            : Items(items)
        {
        }

        /// Sets `s` to the next element's bytes and returns true; returns false
        /// (leaving `s` untouched) once all elements have been returned.
        bool NextRegion(TStringBuf& s) override {
            if (Idx < Items.size()) {
                const TItem& item = Items[Idx++];
                s = ValueToStringBuf(item);
                return true;
            }
            return false;
        }
    };

    /// Sums getter(it).size() over every iterator `it` in [begin, end).
    ///
    /// @param getter callable invoked once per position with the iterator itself;
    ///               its result must provide size(). With a getter returning
    ///               TStringBuf the sum is the total input size in bytes.
    /// @return the accumulated size (0 for an empty range).
    template <class TIter, class TGetter>
    size_t GetInputSize(TIter begin, TIter end, TGetter getter) {
        size_t sum = 0;
        // `begin` is a by-value copy, so it can serve directly as the cursor.
        while (begin != end) {
            sum += getter(begin).size();
            ++begin;
        }
        return sum;
    }

    /// Total size of the elements in [begin, end), each viewed through IterToStringBuf.
    template <class TIter>
    size_t GetInputSize(TIter begin, TIter end) {
        const auto toRegion = &IterToStringBuf<TIter>;
        return GetInputSize(begin, end, toRegion);
    }

    /// Draws a pseudo-random subset of the regions in [begin, end) and returns copies of them.
    ///
    /// Every region getter(iter) is kept independently with probability
    /// sampleSizeBytes / totalBytes, where totalBytes is the summed size of all regions
    /// (computed via GetInputSize). The expected number of sampled bytes is therefore
    /// about sampleSizeBytes. If that ratio is >= 1, every region is kept. Kept regions
    /// are copied into TBuffers and then shuffled.
    ///
    /// The RNG is seeded with fixed constants, so the same input always produces the
    /// same sample in the same order.
    ///
    /// @param sampleSizeBytes requested sample size in bytes (a target, not a hard cap).
    /// @param getter called as getter(iter); its result must convert to TStringBuf.
    ///               It is called once per element while totalling sizes, and again
    ///               for each element that gets selected.
    /// @return the selected regions, copied and shuffled.
    template <class TIter, class TGetter>
    TVector<TBuffer> GetSample(TIter begin, TIter end, size_t sampleSizeBytes, TGetter getter) {
        // Fixed seed: sampling is reproducible across runs.
        TFastRng64 rng{0x1ce1f2e507541a05, 0x07d45659, 0x7b8771030dd9917e, 0x2d6636ce};

        const size_t totalBytes = GetInputSize(begin, end, getter);
        // Max(1, ...) keeps the division defined when every region is empty.
        const double sampleProb = static_cast<double>(sampleSizeBytes) / Max<size_t>(1, totalBytes);
        // Loop-invariant: when the request covers the whole input, keep every region
        // without drawing random numbers for the selection.
        const bool takeAll = sampleProb >= 1;

        TVector<TBuffer> result;
        for (TIter iter = begin; iter != end; ++iter) {
            if (takeAll || rng.GenRandReal1() < sampleProb) {
                const TStringBuf reg = getter(iter);
                result.emplace_back(reg.data(), reg.size());
            }
        }
        Shuffle(result.begin(), result.end(), rng);
        return result;
    }

    /// GetSample() with IterToStringBuf as the getter, i.e. sampling the bytes of
    /// each element *iter in [begin, end).
    template <class TIter>
    TVector<TBuffer> GetSample(TIter begin, TIter end, size_t sampleSizeBytes) {
        const auto toRegion = &IterToStringBuf<TIter>;
        return GetSample(begin, end, sampleSizeBytes, toRegion);
    }

}