aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/linear_regression/benchmark/pool.h
blob: 88140b7dd1b14bcace56d8dfb9cb6ba5ba8d3b10 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#pragma once

#include <util/generic/vector.h>
#include <util/generic/string.h> 

#include <util/random/mersenne.h>
#include <util/random/shuffle.h>

/// A single sample of the pool: a feature vector together with its goal
/// (target) value and weight.
struct TInstance {
    /// Feature values of the sample.
    TVector<double> Features;
    /// Goal (target) value of the sample.
    /// Zero-initialised so that a default-constructed instance never carries
    /// indeterminate values; this matches value-initialisation (`TInstance{}`).
    double Goal = 0.0;
    /// Weight of the sample. Zero-initialised for the same reason as Goal.
    double Weight = 0.0;

    /// Builds an instance from a textual features record.
    /// NOTE(review): presumably one line of the file read by
    /// TPool::ReadFromFeatures; the exact format is defined in pool.cpp — confirm there.
    static TInstance FromFeaturesString(const TString& featuresString);
};

/// A collection of instances with helpers for loading, cross-validation and
/// perturbation. Publicly inherits TVector<TInstance>, so the instances are
/// accessed through the ordinary vector interface.
struct TPool: public TVector<TInstance> {
    /// Selects which side of a cross-validation split a TCVIterator walks.
    enum EIteratorType {
        LearnIterator, ///< presumably: instances outside the test fold — confirm in pool.cpp
        TestIterator,  ///< presumably: instances inside the test fold — confirm in pool.cpp
    };

    /// Iterates over the instances of a parent pool that belong to one side
    /// (learn or test) of a cross-validation split with FoldsCount folds.
    ///
    /// The iterator stores a reference to the parent pool, so the pool must
    /// outlive every iterator created over it.
    class TCVIterator {
    private:
        const TPool& ParentPool;

        size_t FoldsCount;

        EIteratorType IteratorType;
        size_t TestFoldNumber;

        // NOTE(review): looks like one fold number per instance of ParentPool,
        // with Current pointing into this vector while iterating — confirm
        // against pool.cpp.
        TVector<size_t> InstanceFoldNumbers;
        const size_t* Current;

        // Source of randomness for the fold assignment (used by ResetShuffle,
        // presumably — confirm in pool.cpp).
        TMersenne<ui64> RandomGenerator;

    public:
        TCVIterator(const TPool& parentPool,
                    const size_t foldsCount,
                    const EIteratorType iteratorType);

        /// Presumably re-draws the assignment of instances to folds and
        /// restarts iteration — implementation lives in pool.cpp.
        void ResetShuffle();

        /// Chooses which fold is treated as the test fold.
        void SetTestFold(const size_t testFoldNumber);

        /// Whether the iterator currently points at an instance
        /// (presumably false once iteration is exhausted).
        bool IsValid() const;

        /// Access to the current instance of the parent pool.
        const TInstance& operator*() const;
        const TInstance* operator->() const;

        /// Moves to the next instance on the selected side of the split.
        TCVIterator& operator++();

    private:
        // Helpers implemented in pool.cpp; presumably Advance steps Current
        // and TakeCurrent decides whether the current instance belongs to the
        // side selected by IteratorType.
        void Advance();
        bool TakeCurrent() const;
    };

    /// Loads instances from the file at featuresPath.
    /// NOTE(review): presumably parses each record with
    /// TInstance::FromFeaturesString — confirm in pool.cpp.
    void ReadFromFeatures(const TString& featuresPath);

    /// Creates a cross-validation iterator over this pool with foldsCount folds,
    /// walking the side of the split selected by iteratorType.
    TCVIterator CrossValidationIterator(const size_t foldsCount, const EIteratorType iteratorType) const;

    /// Returns a perturbed copy of this pool; the pool itself is not modified
    /// (the method is const). The exact perturbation formula controlled by
    /// injureFactor and injureOffset is defined in pool.cpp.
    TPool InjurePool(const double injureFactor, const double injureOffset) const;
};