blob: 43288319c8bf10dcd9a6d4e9dcfe17c641c92f87 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
  | 
#pragma once
#include <util/generic/vector.h>
#include <util/generic/string.h>
#include <util/random/mersenne.h>
#include <util/random/shuffle.h>
// One pool entry: a vector of feature values plus two scalar fields.
// NOTE(review): presumably Goal is the regression target and Weight the
// per-instance weight -- confirm against the code that fills them in.
struct TInstance {
    TVector<double> Features;

    // Scalars carry default member initializers so that a default-initialized
    // `TInstance x;` never holds indeterminate values. Zero is what
    // value-initialization (`TInstance()`) already produced, so existing
    // callers observe no change.
    double Goal = 0.0;
    double Weight = 0.0;

    // Builds an instance from a textual features line.
    // Only declared here; the expected line format is defined by the
    // implementation and is not visible in this header.
    static TInstance FromFeaturesString(const TString& featuresString);
};
// A pool of instances: a TVector<TInstance> extended with loading,
// cross-validation and "injure" helpers.
//
// Every member function is only declared in this header. Behavioural notes
// below that go beyond the signatures are inferred from names and are marked
// as such -- confirm them against the definitions.
struct TPool: TVector<TInstance> {
    // Selects which kind of iteration a TCVIterator performs.
    // NOTE(review): presumably "learn" walks everything outside the test fold
    // and "test" walks the test fold only -- confirm in TakeCurrent().
    enum EIteratorType {
        LearnIterator,
        TestIterator,
    };

    // Forward iterator over a parent pool, parameterised by a fold count,
    // an iterator type and a currently selected test fold.
    //
    // The reference member makes the class non-assignable.
    // NOTE(review): Current is a raw pointer that presumably points into
    // InstanceFoldNumbers; if so, a copied iterator would keep pointing at the
    // source object's storage -- verify copy behaviour in the implementation.
    class TCVIterator {
        // Data members: declaration order is kept as-is because it fixes the
        // order in which the constructor initializes them.
        const TPool& ParentPool; // borrowed, not owned
        size_t FoldsCount;
        EIteratorType IteratorType;
        size_t TestFoldNumber;
        TVector<size_t> InstanceFoldNumbers;
        const size_t* Current;
        TMersenne<ui64> RandomGenerator;

        // Internal stepping helpers used by the public interface.
        void Advance();
        bool TakeCurrent() const;

    public:
        TCVIterator(const TPool& parentPool, size_t foldsCount, EIteratorType iteratorType);

        // NOTE(review): presumably re-draws the instance-to-fold assignment
        // using RandomGenerator -- confirm.
        void ResetShuffle();

        // Chooses the fold treated as the test fold.
        void SetTestFold(size_t testFoldNumber);

        // Tells whether the iterator can currently be dereferenced.
        // NOTE(review): exact end-of-range condition is in the definition.
        bool IsValid() const;

        // Access to the instance at the current position.
        const TInstance& operator*() const;
        const TInstance* operator->() const;

        // Pre-increment; returns this iterator.
        TCVIterator& operator++();
    };

    // Fills the pool from the file at featuresPath.
    // NOTE(review): presumably one instance per line, parsed with
    // TInstance::FromFeaturesString -- confirm.
    void ReadFromFeatures(const TString& featuresPath);

    // Creates a cross-validation iterator over this pool with the given number
    // of folds and iterator type.
    TCVIterator CrossValidationIterator(size_t foldsCount, EIteratorType iteratorType) const;

    // Returns a new pool derived from this one using a factor and an offset.
    // NOTE(review): the transformation itself lives in the definition.
    // NOTE(review): `injureFactir` looks like a typo for `injureFactor`; the
    // spelling is kept so the declaration stays in sync with the definition.
    TPool InjurePool(double injureFactir, double injureOffset) const;
};
  |