diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/linear_regression/benchmark/pool.cpp | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/linear_regression/benchmark/pool.cpp')
-rw-r--r-- | library/cpp/linear_regression/benchmark/pool.cpp | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/library/cpp/linear_regression/benchmark/pool.cpp b/library/cpp/linear_regression/benchmark/pool.cpp new file mode 100644 index 0000000000..7f2c6a7004 --- /dev/null +++ b/library/cpp/linear_regression/benchmark/pool.cpp @@ -0,0 +1,109 @@ +#include "pool.h" + +#include <util/string/cast.h> +#include <util/stream/file.h> + +TInstance TInstance::FromFeaturesString(const TString& featuresString) { + TInstance instance; + + TStringBuf featuresStringBuf(featuresString); + + featuresStringBuf.NextTok('\t'); // query id + instance.Goal = FromString(featuresStringBuf.NextTok('\t')); + featuresStringBuf.NextTok('\t'); // url + instance.Weight = FromString(featuresStringBuf.NextTok('\t')); + + while (featuresStringBuf) { + instance.Features.push_back(FromString(featuresStringBuf.NextTok('\t'))); + } + + return instance; +} + +TPool::TCVIterator::TCVIterator(const TPool& parentPool, const size_t foldsCount, const EIteratorType iteratorType) + : ParentPool(parentPool) + , FoldsCount(foldsCount) + , IteratorType(iteratorType) + , InstanceFoldNumbers(ParentPool.size()) +{ +} + +void TPool::TCVIterator::ResetShuffle() { + TVector<size_t> instanceNumbers(ParentPool.size()); + for (size_t instanceNumber = 0; instanceNumber < ParentPool.size(); ++instanceNumber) { + instanceNumbers[instanceNumber] = instanceNumber; + } + Shuffle(instanceNumbers.begin(), instanceNumbers.end(), RandomGenerator); + + for (size_t instancePosition = 0; instancePosition < ParentPool.size(); ++instancePosition) { + InstanceFoldNumbers[instanceNumbers[instancePosition]] = instancePosition % FoldsCount; + } + Current = InstanceFoldNumbers.begin(); +} + +void TPool::TCVIterator::SetTestFold(const size_t testFoldNumber) { + TestFoldNumber = testFoldNumber; + Current = InstanceFoldNumbers.begin(); + Advance(); +} + +bool TPool::TCVIterator::IsValid() const { + return Current != InstanceFoldNumbers.end(); +} + +const TInstance& TPool::TCVIterator::operator*() const { + return ParentPool[Current - InstanceFoldNumbers.begin()]; +} + +const TInstance* TPool::TCVIterator::operator->() const { + return &ParentPool[Current - InstanceFoldNumbers.begin()]; +} + +TPool::TCVIterator& TPool::TCVIterator::operator++() { + Advance(); + return *this; +} + +void TPool::TCVIterator::Advance() { + while (IsValid()) { + ++Current; + if (IsValid() && TakeCurrent()) { + break; + } + } +} + +bool TPool::TCVIterator::TakeCurrent() const { + switch (IteratorType) { + case LearnIterator: + return *Current != TestFoldNumber; + case TestIterator: + return *Current == TestFoldNumber; + } + return false; +} + +void TPool::ReadFromFeatures(const TString& featuresPath) { + TFileInput featuresIn(featuresPath); + TString featuresString; + while (featuresIn.ReadLine(featuresString)) { + this->push_back(TInstance::FromFeaturesString(featuresString)); + } +} + +TPool::TCVIterator TPool::CrossValidationIterator(const size_t foldsCount, const EIteratorType iteratorType) const { + return TPool::TCVIterator(*this, foldsCount, iteratorType); +} + +TPool TPool::InjurePool(const double injureFactor, const double injureOffset) const { + TPool injuredPool(*this); + + for (TInstance& instance : injuredPool) { + for (double& feature : instance.Features) { + feature = feature * injureFactor + injureOffset; + } + instance.Goal = instance.Goal * injureFactor + injureOffset; + } + + return injuredPool; +} |