aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/linear_regression/benchmark/pool.cpp
blob: 2460b177ca2621c4cf83918057b8b8fc8b010dd8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include "pool.h" 
 
#include <util/string/cast.h> 
#include <util/stream/file.h> 
 
// Parses one tab-separated line of a features file into a TInstance.
//
// Field layout, in order:
//   1. query id  - skipped
//   2. goal      - stored in Goal
//   3. url       - skipped
//   4. weight    - stored in Weight
//   5+ features  - every remaining field is appended to Features
//
// All numeric fields are converted with FromString.
TInstance TInstance::FromFeaturesString(const TString& featuresString) {
    TStringBuf remaining(featuresString);

    remaining.NextTok('\t'); // query id, not used
    const TStringBuf goalToken = remaining.NextTok('\t');
    remaining.NextTok('\t'); // url, not used
    const TStringBuf weightToken = remaining.NextTok('\t');

    TInstance parsed;
    parsed.Goal = FromString(goalToken);
    parsed.Weight = FromString(weightToken);

    // Everything left in the buffer is a feature value.
    while (remaining) {
        const TStringBuf featureToken = remaining.NextTok('\t');
        parsed.Features.push_back(FromString(featureToken));
    }

    return parsed;
}
 
// Creates a cross-validation iterator over parentPool.
//
// foldsCount   - number of folds used when instances are assigned to folds.
// iteratorType - selects whether the iterator yields learn or test instances
//                (see TakeCurrent()).
//
// The fold-number storage is sized to one entry per pool instance; the fold
// assignment itself is produced later by ResetShuffle().
TPool::TCVIterator::TCVIterator(const TPool& parentPool, const size_t foldsCount, const EIteratorType iteratorType)
    : ParentPool(parentPool)
    , FoldsCount(foldsCount)
    , IteratorType(iteratorType)
    , InstanceFoldNumbers(parentPool.size())
{
}
 
void TPool::TCVIterator::ResetShuffle() { 
    TVector<size_t> instanceNumbers(ParentPool.size());
    for (size_t instanceNumber = 0; instanceNumber < ParentPool.size(); ++instanceNumber) { 
        instanceNumbers[instanceNumber] = instanceNumber; 
    } 
    Shuffle(instanceNumbers.begin(), instanceNumbers.end(), RandomGenerator); 
 
    for (size_t instancePosition = 0; instancePosition < ParentPool.size(); ++instancePosition) { 
        InstanceFoldNumbers[instanceNumbers[instancePosition]] = instancePosition % FoldsCount; 
    } 
    Current = InstanceFoldNumbers.begin(); 
} 
 
void TPool::TCVIterator::SetTestFold(const size_t testFoldNumber) { 
    TestFoldNumber = testFoldNumber; 
    Current = InstanceFoldNumbers.begin(); 
    Advance(); 
} 
 
// Returns true while the iterator points at a fold-number entry, i.e. it has
// not run past the last instance.
bool TPool::TCVIterator::IsValid() const {
    const bool exhausted = (Current == InstanceFoldNumbers.end());
    return !exhausted;
}
 
// Returns the pool instance the iterator currently points at.
//
// The iterator walks InstanceFoldNumbers, so its offset from the beginning of
// that vector is the index of the instance in the parent pool. No validity
// check is done here; callers are expected to test IsValid() first.
const TInstance& TPool::TCVIterator::operator*() const {
    const auto instanceIndex = Current - InstanceFoldNumbers.begin();
    return ParentPool[instanceIndex];
}
 
// Returns a pointer to the pool instance the iterator currently points at;
// same element as operator*().
const TInstance* TPool::TCVIterator::operator->() const {
    return &operator*();
}
 
// Pre-increment: moves to the next instance accepted by TakeCurrent(), or to
// the end if there is none, and returns this iterator.
TPool::TCVIterator& TPool::TCVIterator::operator++() {
    this->Advance();
    return *this;
}
 
// Steps forward at least once (when not already at the end) and keeps
// stepping until the iterator reaches an instance accepted by TakeCurrent()
// or runs off the end.
void TPool::TCVIterator::Advance() {
    if (!IsValid()) {
        return; // already exhausted, nothing to do
    }

    do {
        ++Current;
    } while (IsValid() && !TakeCurrent());
}
 
// Tells whether the instance under the iterator belongs to the set this
// iterator is meant to yield:
//   LearnIterator - instances whose fold differs from TestFoldNumber;
//   TestIterator  - instances whose fold equals TestFoldNumber.
// Any other iterator type accepts nothing.
bool TPool::TCVIterator::TakeCurrent() const {
    if (IteratorType == LearnIterator) {
        return *Current != TestFoldNumber;
    }
    if (IteratorType == TestIterator) {
        return *Current == TestFoldNumber;
    }
    return false;
}
 
void TPool::ReadFromFeatures(const TString& featuresPath) {
    TFileInput featuresIn(featuresPath);
    TString featuresString;
    while (featuresIn.ReadLine(featuresString)) { 
        this->push_back(TInstance::FromFeaturesString(featuresString)); 
    } 
} 
 
// Creates a cross-validation iterator over this pool with the given number of
// folds and iterator type (learn or test).
TPool::TCVIterator TPool::CrossValidationIterator(const size_t foldsCount, const EIteratorType iteratorType) const {
    return TCVIterator(*this, foldsCount, iteratorType);
}
 
// Returns a copy of this pool in which every feature value and every goal has
// been replaced by value * injureFactor + injureOffset. Other instance fields
// (e.g. Weight) are copied as is; this pool itself is not modified.
TPool TPool::InjurePool(const double injureFactor, const double injureOffset) const {
    TPool result(*this);

    for (auto& item : result) {
        item.Goal = item.Goal * injureFactor + injureOffset;
        for (auto& value : item.Features) {
            value = value * injureFactor + injureOffset;
        }
    }

    return result;
}