aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/linear_regression/benchmark/pool.cpp
diff options
context:
space:
mode:
authorDevtools Arcadia <arcadia-devtools@yandex-team.ru>2022-02-07 18:08:42 +0300
committerDevtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net>2022-02-07 18:08:42 +0300
commit1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
treee26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/linear_regression/benchmark/pool.cpp
downloadydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/linear_regression/benchmark/pool.cpp')
-rw-r--r--library/cpp/linear_regression/benchmark/pool.cpp109
1 files changed, 109 insertions, 0 deletions
diff --git a/library/cpp/linear_regression/benchmark/pool.cpp b/library/cpp/linear_regression/benchmark/pool.cpp
new file mode 100644
index 0000000000..7f2c6a7004
--- /dev/null
+++ b/library/cpp/linear_regression/benchmark/pool.cpp
@@ -0,0 +1,109 @@
+#include "pool.h"
+
+#include <util/string/cast.h>
+#include <util/stream/file.h>
+
+TInstance TInstance::FromFeaturesString(const TString& featuresString) {
+ TInstance instance;
+
+ TStringBuf featuresStringBuf(featuresString);
+
+ featuresStringBuf.NextTok('\t'); // query id
+ instance.Goal = FromString(featuresStringBuf.NextTok('\t'));
+ featuresStringBuf.NextTok('\t'); // url
+ instance.Weight = FromString(featuresStringBuf.NextTok('\t'));
+
+ while (featuresStringBuf) {
+ instance.Features.push_back(FromString(featuresStringBuf.NextTok('\t')));
+ }
+
+ return instance;
+}
+
+TPool::TCVIterator::TCVIterator(const TPool& parentPool, const size_t foldsCount, const EIteratorType iteratorType)
+ : ParentPool(parentPool)
+ , FoldsCount(foldsCount)
+ , IteratorType(iteratorType)
+ , InstanceFoldNumbers(ParentPool.size())
+{
+}
+
+void TPool::TCVIterator::ResetShuffle() {
+ TVector<size_t> instanceNumbers(ParentPool.size());
+ for (size_t instanceNumber = 0; instanceNumber < ParentPool.size(); ++instanceNumber) {
+ instanceNumbers[instanceNumber] = instanceNumber;
+ }
+ Shuffle(instanceNumbers.begin(), instanceNumbers.end(), RandomGenerator);
+
+ for (size_t instancePosition = 0; instancePosition < ParentPool.size(); ++instancePosition) {
+ InstanceFoldNumbers[instanceNumbers[instancePosition]] = instancePosition % FoldsCount;
+ }
+ Current = InstanceFoldNumbers.begin();
+}
+
+void TPool::TCVIterator::SetTestFold(const size_t testFoldNumber) {
+ TestFoldNumber = testFoldNumber;
+ Current = InstanceFoldNumbers.begin();
+ Advance();
+}
+
+bool TPool::TCVIterator::IsValid() const {
+ return Current != InstanceFoldNumbers.end();
+}
+
+const TInstance& TPool::TCVIterator::operator*() const {
+ return ParentPool[Current - InstanceFoldNumbers.begin()];
+}
+
+const TInstance* TPool::TCVIterator::operator->() const {
+ return &ParentPool[Current - InstanceFoldNumbers.begin()];
+}
+
+TPool::TCVIterator& TPool::TCVIterator::operator++() {
+ Advance();
+ return *this;
+}
+
+void TPool::TCVIterator::Advance() {
+ while (IsValid()) {
+ ++Current;
+ if (IsValid() && TakeCurrent()) {
+ break;
+ }
+ }
+}
+
+bool TPool::TCVIterator::TakeCurrent() const {
+ switch (IteratorType) {
+ case LearnIterator:
+ return *Current != TestFoldNumber;
+ case TestIterator:
+ return *Current == TestFoldNumber;
+ }
+ return false;
+}
+
+void TPool::ReadFromFeatures(const TString& featuresPath) {
+ TFileInput featuresIn(featuresPath);
+ TString featuresString;
+ while (featuresIn.ReadLine(featuresString)) {
+ this->push_back(TInstance::FromFeaturesString(featuresString));
+ }
+}
+
+TPool::TCVIterator TPool::CrossValidationIterator(const size_t foldsCount, const EIteratorType iteratorType) const {
+ return TPool::TCVIterator(*this, foldsCount, iteratorType);
+}
+
+TPool TPool::InjurePool(const double injureFactor, const double injureOffset) const {
+ TPool injuredPool(*this);
+
+ for (TInstance& instance : injuredPool) {
+ for (double& feature : instance.Features) {
+ feature = feature * injureFactor + injureOffset;
+ }
+ instance.Goal = instance.Goal * injureFactor + injureOffset;
+ }
+
+ return injuredPool;
+}