aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/linear_regression/benchmark/pool.cpp
blob: 5e014e575c4d32d0789fafed371c518096a63a48 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include "pool.h"

#include <util/string/cast.h>
#include <util/stream/file.h>

// Parses one tab-separated features line.
// Field order as consumed below: query id (skipped), goal, url (skipped),
// weight, then every remaining field is appended to Features.
TInstance TInstance::FromFeaturesString(const TString& featuresString) {
    TStringBuf rest(featuresString);
    // Pops the next tab-delimited field off the front of `rest`.
    auto nextField = [&rest]() {
        return rest.NextTok('\t');
    };

    TInstance instance;
    nextField(); // query id
    instance.Goal = FromString(nextField());
    nextField(); // url
    instance.Weight = FromString(nextField());

    // Everything after the header columns is a feature value.
    while (!rest.empty()) {
        instance.Features.push_back(FromString(nextField()));
    }

    return instance;
}

// Creates a cross-validation iterator over `parentPool`.
// One fold number is allocated per instance. All fold numbers start at 0
// (value-initialised) until ResetShuffle() assigns the real folds.
// The pool is held by reference, so it must outlive the iterator.
TPool::TCVIterator::TCVIterator(const TPool& parentPool, const size_t foldsCount, const EIteratorType iteratorType)
    : ParentPool(parentPool)
    , FoldsCount(foldsCount)
    , IteratorType(iteratorType)
    // Size from the parameter rather than the ParentPool member, so this stays
    // correct whatever the member declaration order in the header is.
    , InstanceFoldNumbers(parentPool.size())
{
}

void TPool::TCVIterator::ResetShuffle() {
    TVector<size_t> instanceNumbers(ParentPool.size()); 
    for (size_t instanceNumber = 0; instanceNumber < ParentPool.size(); ++instanceNumber) {
        instanceNumbers[instanceNumber] = instanceNumber;
    }
    Shuffle(instanceNumbers.begin(), instanceNumbers.end(), RandomGenerator);

    for (size_t instancePosition = 0; instancePosition < ParentPool.size(); ++instancePosition) {
        InstanceFoldNumbers[instanceNumbers[instancePosition]] = instancePosition % FoldsCount;
    }
    Current = InstanceFoldNumbers.begin();
}

void TPool::TCVIterator::SetTestFold(const size_t testFoldNumber) {
    TestFoldNumber = testFoldNumber;
    Current = InstanceFoldNumbers.begin();
    Advance();
}

// True while Current still refers to an instance, i.e. has not reached the end
// of InstanceFoldNumbers.
bool TPool::TCVIterator::IsValid() const {
    const bool exhausted = (Current == InstanceFoldNumbers.end());
    return !exhausted;
}

// Returns the pool instance the iterator currently points at.
// InstanceFoldNumbers is indexed in parallel with the pool, so Current's offset
// from begin() is the instance index.
const TInstance& TPool::TCVIterator::operator*() const {
    const auto instanceIndex = Current - InstanceFoldNumbers.begin();
    return ParentPool[instanceIndex];
}

// Pointer access to the current instance; defined in terms of operator*.
const TInstance* TPool::TCVIterator::operator->() const {
    const TInstance& current = **this;
    return &current;
}

// Prefix increment: moves to the next instance accepted by TakeCurrent(), or
// to the end if there is none.
TPool::TCVIterator& TPool::TCVIterator::operator++() {
    this->Advance();
    TPool::TCVIterator& self = *this;
    return self;
}

// Moves Current forward by at least one position, stopping at the first
// instance accepted by TakeCurrent() or at the end.
// No-op if the iterator is already exhausted.
void TPool::TCVIterator::Advance() {
    if (!IsValid()) {
        return;
    }
    do {
        ++Current;
    } while (IsValid() && !TakeCurrent());
}

bool TPool::TCVIterator::TakeCurrent() const {
    switch (IteratorType) {
        case LearnIterator:
            return *Current != TestFoldNumber;
        case TestIterator:
            return *Current == TestFoldNumber;
    }
    return false;
}

void TPool::ReadFromFeatures(const TString& featuresPath) { 
    TFileInput featuresIn(featuresPath); 
    TString featuresString; 
    while (featuresIn.ReadLine(featuresString)) {
        this->push_back(TInstance::FromFeaturesString(featuresString));
    }
}

// Factory for a cross-validation iterator over this pool.
// The returned iterator refers to *this, so the pool must outlive it.
TPool::TCVIterator TPool::CrossValidationIterator(const size_t foldsCount, const EIteratorType iteratorType) const {
    const TPool& self = *this;
    return TCVIterator(self, foldsCount, iteratorType);
}

// Returns a copy of the pool in which every feature value and every goal has
// been mapped through x -> x * injureFactor + injureOffset.
// Weights are copied unchanged. The original pool is not modified.
TPool TPool::InjurePool(const double injureFactor, const double injureOffset) const {
    const auto injure = [injureFactor, injureOffset](const double value) {
        return value * injureFactor + injureOffset;
    };

    TPool result(*this);
    for (auto& inst : result) {
        for (auto& value : inst.Features) {
            value = injure(value);
        }
        inst.Goal = injure(inst.Goal);
    }
    return result;
}