aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/linear_regression/benchmark/pool.h
blob: 4dcf7d7e9e9d5a558f967b3db84eeec4fe1c0fb3 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#pragma once

#include <util/generic/vector.h>
#include <util/generic/string.h>

#include <util/random/mersenne.h>
#include <util/random/shuffle.h>

// One sample of a pool: a vector of feature values plus two scalar fields.
struct TInstance {
    // Feature values of this sample.
    TVector<double> Features;

    // NOTE(review): presumably the regression target value — confirm against the implementation.
    // Explicitly zeroed so that a default-constructed instance never carries an
    // indeterminate double (0.0 is also what value-initialization yields).
    double Goal = 0.0;

    // NOTE(review): presumably the sample weight — confirm against the implementation.
    // Zeroed for the same reason as Goal; code that builds instances is expected
    // to assign the real value.
    double Weight = 0.0;

    // Creates an instance from its textual form.
    // The parsing is defined outside this header.
    // NOTE(review): the expected layout of featuresString is not visible here — see the definition.
    static TInstance FromFeaturesString(const TString& featuresString);
};

// A set of instances. Publicly derives from TVector<TInstance>, so the complete
// vector interface (size, indexing, iteration, ...) is available on a pool.
struct TPool: public TVector<TInstance> {
    // Tag handed to TCVIterator / CrossValidationIterator.
    // NOTE(review): presumably chooses between the learn part and the test part
    // of a cross-validation split — confirm against the implementation.
    enum EIteratorType {
        LearnIterator,
        TestIterator,
    };

    // Forward iterator-like object over the instances of a pool, configured with
    // a fold count and an iterator type. All member functions are defined outside
    // this header.
    class TCVIterator {
    private:
        // Pool being iterated. Held by reference, so the pool has to outlive the iterator.
        const TPool& ParentPool;

        // Number of folds requested at construction.
        size_t FoldsCount;

        EIteratorType IteratorType;
        // NOTE(review): presumably the fold selected by SetTestFold().
        // Defaulted so the value is never indeterminate if the constructor
        // (defined elsewhere) does not set it; a constructor mem-initializer
        // still takes precedence over this default.
        size_t TestFoldNumber = 0;

        // NOTE(review): presumably one fold index per pool instance — confirm in the implementation.
        TVector<size_t> InstanceFoldNumbers;
        // NOTE(review): presumably points into InstanceFoldNumbers at the current position.
        // Defaulted to nullptr for the same reason as TestFoldNumber.
        const size_t* Current = nullptr;

        TMersenne<ui64> RandomGenerator;

    public:
        // Creates an iterator over parentPool for the given number of folds and iterator type.
        TCVIterator(const TPool& parentPool,
                    const size_t foldsCount,
                    const EIteratorType iteratorType);

        // NOTE(review): presumably re-randomizes the assignment of instances to folds
        // using RandomGenerator — confirm in the implementation.
        void ResetShuffle();

        // Selects the fold identified by testFoldNumber.
        // NOTE(review): presumably also restarts the iteration — confirm in the implementation.
        void SetTestFold(const size_t testFoldNumber);

        // Tells whether the iterator currently refers to an instance.
        // NOTE(review): exact end-of-range condition is defined in the implementation.
        bool IsValid() const;

        // Access to the instance the iterator currently refers to.
        const TInstance& operator*() const;
        const TInstance* operator->() const;
        // Pre-increment; returns a reference to this iterator.
        TPool::TCVIterator& operator++();

    private:
        // Internal helpers used by the public interface; semantics are defined
        // in the implementation file.
        void Advance();
        bool TakeCurrent() const;
    };

    // Fills the pool from the file at featuresPath.
    // NOTE(review): file format is defined by the implementation, presumably one
    // TInstance::FromFeaturesString line per instance — confirm.
    void ReadFromFeatures(const TString& featuresPath);

    // Returns a cross-validation iterator over this pool for the given fold count
    // and iterator type.
    TCVIterator CrossValidationIterator(const size_t foldsCount, const EIteratorType iteratorType) const;

    // Returns a new pool derived from this one using the given factor and offset;
    // this pool is not modified (the method is const).
    // NOTE(review): the exact transformation is defined in the implementation file.
    TPool InjurePool(const double injureFactor, const double injureOffset) const;
};