blob: 2460b177ca2621c4cf83918057b8b8fc8b010dd8 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
|
#include "pool.h"

#include <util/string/cast.h>
#include <util/stream/file.h>

#include <stdexcept>
/// Parses one tab-separated line of a features file into a TInstance.
///
/// Columns consumed, in order: query id (skipped), goal, url (skipped),
/// weight, and then every remaining column as one feature value.
/// Each value is converted with FromString; a missing or malformed column
/// presumably makes FromString throw — TODO confirm against util/string/cast.h.
TInstance TInstance::FromFeaturesString(const TString& featuresString) {
    TStringBuf rest(featuresString);
    TInstance parsed;

    rest.NextTok('\t'); // column 0: query id, not stored
    parsed.Goal = FromString(rest.NextTok('\t'));
    rest.NextTok('\t'); // column 2: url, not stored
    parsed.Weight = FromString(rest.NextTok('\t'));

    // Everything after the weight column is a feature value.
    while (rest) {
        const TStringBuf column = rest.NextTok('\t');
        parsed.Features.push_back(FromString(column));
    }
    return parsed;
}
/// Creates a cross-validation iterator over `parentPool`.
///
/// Every instance starts with a value-initialized (zero) fold number. Call
/// ResetShuffle() to spread instances across `foldsCount` folds, then
/// SetTestFold() to position the iterator on the first selected instance.
///
/// @param parentPool   pool to iterate (presumably stored by reference, so the
///                     iterator must not outlive it — TODO confirm in pool.h).
/// @param foldsCount   number of folds used by ResetShuffle().
/// @param iteratorType LearnIterator yields instances outside the test fold;
///                     TestIterator yields only instances of the test fold.
TPool::TCVIterator::TCVIterator(const TPool& parentPool, const size_t foldsCount, const EIteratorType iteratorType)
    : ParentPool(parentPool)
    , FoldsCount(foldsCount)
    , IteratorType(iteratorType)
    // Size from the constructor argument, not the ParentPool member: members are
    // initialized in declaration order, so reading ParentPool here would be
    // undefined if InstanceFoldNumbers were declared before it in the header.
    , InstanceFoldNumbers(parentPool.size())
{
    // Start exhausted so IsValid() is well-defined (false) until
    // ResetShuffle()/SetTestFold() positions the iterator.
    Current = InstanceFoldNumbers.end();
}
/// Randomly re-partitions the pool into FoldsCount folds and rewinds Current.
///
/// Instance indices are shuffled with RandomGenerator and then dealt out
/// round-robin (shuffled position p goes to fold p % FoldsCount), so fold sizes
/// differ by at most one. Current is rewound to the first instance without
/// checking whether it is selected; call SetTestFold() afterwards to position
/// the iterator.
///
/// @throws std::invalid_argument if FoldsCount is zero. Without this check the
///         round-robin assignment would perform an integer division by zero.
void TPool::TCVIterator::ResetShuffle() {
    if (FoldsCount == 0) {
        throw std::invalid_argument("TPool::TCVIterator::ResetShuffle: foldsCount must be positive");
    }
    TVector<size_t> instanceNumbers(ParentPool.size());
    for (size_t instanceNumber = 0; instanceNumber < ParentPool.size(); ++instanceNumber) {
        instanceNumbers[instanceNumber] = instanceNumber;
    }
    Shuffle(instanceNumbers.begin(), instanceNumbers.end(), RandomGenerator);
    // Deal the shuffled indices round-robin across the folds.
    for (size_t instancePosition = 0; instancePosition < ParentPool.size(); ++instancePosition) {
        InstanceFoldNumbers[instanceNumbers[instancePosition]] = instancePosition % FoldsCount;
    }
    Current = InstanceFoldNumbers.begin();
}
void TPool::TCVIterator::SetTestFold(const size_t testFoldNumber) {
TestFoldNumber = testFoldNumber;
Current = InstanceFoldNumbers.begin();
Advance();
}
/// True while the iterator still points at an instance, i.e. Current has not
/// reached the end of InstanceFoldNumbers.
bool TPool::TCVIterator::IsValid() const {
    const auto stop = InstanceFoldNumbers.end();
    return Current != stop;
}
/// Returns the instance the iterator currently points at.
/// Current's offset within InstanceFoldNumbers is also the instance's index in
/// ParentPool. Precondition: IsValid() (not checked here).
const TInstance& TPool::TCVIterator::operator*() const {
    const auto index = Current - InstanceFoldNumbers.begin();
    return ParentPool[index];
}
/// Member access to the current instance; same precondition as operator*.
const TInstance* TPool::TCVIterator::operator->() const {
    const TInstance& current = **this;
    return &current;
}
/// Pre-increment: moves to the next instance selected by TakeCurrent(), or to
/// the end if there is none. Returns *this.
TPool::TCVIterator& TPool::TCVIterator::operator++() {
    this->Advance();
    return *this;
}
/// Moves Current strictly forward to the next element for which TakeCurrent()
/// is true, or to the end. The element Current points at on entry is never
/// re-tested. Does nothing if the iterator is already exhausted.
void TPool::TCVIterator::Advance() {
    if (!IsValid()) {
        return;
    }
    do {
        ++Current;
    } while (IsValid() && !TakeCurrent());
}
/// Whether the instance at Current belongs to this iterator.
/// - TestIterator: only instances whose fold equals TestFoldNumber.
/// - LearnIterator: every other instance.
/// - Any other IteratorType value: nothing.
/// Precondition: IsValid(), since Current is dereferenced.
bool TPool::TCVIterator::TakeCurrent() const {
    switch (IteratorType) {
        case TestIterator:
            return *Current == TestFoldNumber;
        case LearnIterator:
            return !(*Current == TestFoldNumber);
    }
    return false;
}
void TPool::ReadFromFeatures(const TString& featuresPath) {
TFileInput featuresIn(featuresPath);
TString featuresString;
while (featuresIn.ReadLine(featuresString)) {
this->push_back(TInstance::FromFeaturesString(featuresString));
}
}
/// Convenience factory: returns a TCVIterator over this pool with the given
/// fold count and iterator type. ResetShuffle()/SetTestFold() still need to be
/// called on the result before iterating.
TPool::TCVIterator TPool::CrossValidationIterator(const size_t foldsCount, const EIteratorType iteratorType) const {
    return TCVIterator(*this, foldsCount, iteratorType);
}
/// Returns a copy of this pool with each instance's goal and every feature
/// mapped through x -> x * injureFactor + injureOffset.
/// Weights are not modified, and the original pool is left untouched.
TPool TPool::InjurePool(const double injureFactor, const double injureOffset) const {
    TPool result = *this;
    for (auto& instance : result) {
        instance.Goal = instance.Goal * injureFactor + injureOffset;
        for (auto& feature : instance.Features) {
            feature = feature * injureFactor + injureOffset;
        }
    }
    return result;
}
|