blob: 88140b7dd1b14bcace56d8dfb9cb6ba5ba8d3b10 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
|
#pragma once
#include <util/generic/vector.h>
#include <util/generic/string.h>
#include <util/random/mersenne.h>
#include <util/random/shuffle.h>
/// One instance (example) of a pool: a feature vector plus two scalar values.
///
/// `Goal` and `Weight` default to 0.0 so that a default-constructed instance
/// never carries indeterminate values. The struct remains an aggregate, so
/// `TInstance{features, goal, weight}` still works.
struct TInstance {
    /// Feature values of the instance.
    TVector<double> Features;
    /// Presumably the target value to be predicted — confirm against callers.
    double Goal = 0.0;
    /// Presumably the per-instance weight used during learning — confirm against callers.
    double Weight = 0.0;

    /// Builds an instance from a single textual record.
    /// NOTE(review): the expected format is defined by the out-of-line
    /// implementation, presumably one line of the file read by
    /// TPool::ReadFromFeatures — verify there.
    static TInstance FromFeaturesString(const TString& featuresString);
};
/// A pool of instances stored as a vector of TInstance, with helpers to read
/// it from a features file, iterate it in cross-validation splits, and build
/// a modified copy of it.
///
/// NOTE(review): TPool publicly derives from TVector<TInstance>, whose
/// destructor is not virtual — never delete a TPool through a pointer to
/// its base.
struct TPool: public TVector<TInstance> {
    /// Selects which side of a cross-validation split a TCVIterator visits.
    enum EIteratorType {
        LearnIterator, ///< presumably: instances outside the current test fold
        TestIterator,  ///< presumably: instances inside the current test fold
    };

    /// Forward iterator over one side of a cross-validation split of a pool.
    ///
    /// Usage shape (from the declared interface): construct it, optionally
    /// call SetTestFold / ResetShuffle, then loop while IsValid(), reading the
    /// current instance via `*` / `->` and advancing with `++`.
    ///
    /// The iterator stores a reference to its parent pool, so the pool must
    /// outlive every iterator created from it.
    ///
    /// NOTE(review): `Current` is a raw pointer. If the implementation points
    /// it into `InstanceFoldNumbers`, then a *copied* iterator keeps pointing
    /// into the source object's buffer. A moved one is fine, because the vector
    /// buffer moves with it. Confirm against the .cpp before copying iterators.
    class TCVIterator {
    private:
        const TPool& ParentPool;
        // Scalars get defaults so they are never read indeterminate even if
        // the constructor leaves one of them unset.
        size_t FoldsCount = 0;
        EIteratorType IteratorType = LearnIterator;
        size_t TestFoldNumber = 0;
        TVector<size_t> InstanceFoldNumbers;
        const size_t* Current = nullptr;
        TMersenne<ui64> RandomGenerator;

    public:
        /// Creates an iterator over `parentPool` split into `foldsCount` folds,
        /// visiting the side selected by `iteratorType`.
        TCVIterator(const TPool& parentPool,
                    const size_t foldsCount,
                    const EIteratorType iteratorType);

        /// Presumably re-draws the instance-to-fold assignment using
        /// RandomGenerator — verify in the implementation.
        void ResetShuffle();

        /// Selects which fold is treated as the test fold.
        void SetTestFold(const size_t testFoldNumber);

        /// Returns true while the iterator refers to an instance.
        bool IsValid() const;

        /// Access to the current instance of the parent pool.
        const TInstance& operator*() const;
        const TInstance* operator->() const;

        /// Moves to the next instance on the selected side of the split.
        TCVIterator& operator++();

    private:
        void Advance();
        bool TakeCurrent() const;
    };

    /// Loads instances from the file at `featuresPath`.
    void ReadFromFeatures(const TString& featuresPath);

    /// Returns a cross-validation iterator over this pool.
    /// The returned iterator references *this, so the pool must outlive it.
    TCVIterator CrossValidationIterator(const size_t foldsCount, const EIteratorType iteratorType) const;

    /// Returns a new pool derived from this one; this pool itself is left
    /// unchanged (the method is const).
    /// NOTE(review): presumably perturbs instances according to
    /// `injureFactor` and `injureOffset` — see the definition for exact semantics.
    TPool InjurePool(const double injureFactor, const double injureOffset) const;
};
|