aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/linear_regression/benchmark/pool.h
diff options
context:
space:
mode:
authorDevtools Arcadia <arcadia-devtools@yandex-team.ru>2022-02-07 18:08:42 +0300
committerDevtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net>2022-02-07 18:08:42 +0300
commit1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
treee26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/linear_regression/benchmark/pool.h
downloadydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/linear_regression/benchmark/pool.h')
-rw-r--r--library/cpp/linear_regression/benchmark/pool.h61
1 files changed, 61 insertions, 0 deletions
diff --git a/library/cpp/linear_regression/benchmark/pool.h b/library/cpp/linear_regression/benchmark/pool.h
new file mode 100644
index 0000000000..43288319c8
--- /dev/null
+++ b/library/cpp/linear_regression/benchmark/pool.h
@@ -0,0 +1,61 @@
+#pragma once
+
+#include <util/generic/vector.h>
+#include <util/generic/string.h>
+
+#include <util/random/mersenne.h>
+#include <util/random/shuffle.h>
+
+struct TInstance {
+ TVector<double> Features;
+ double Goal;
+ double Weight;
+
+ static TInstance FromFeaturesString(const TString& featuresString);
+};
+
+struct TPool: public TVector<TInstance> {
+ enum EIteratorType {
+ LearnIterator,
+ TestIterator,
+ };
+
+ class TCVIterator {
+ private:
+ const TPool& ParentPool;
+
+ size_t FoldsCount;
+
+ EIteratorType IteratorType;
+ size_t TestFoldNumber;
+
+ TVector<size_t> InstanceFoldNumbers;
+ const size_t* Current;
+
+ TMersenne<ui64> RandomGenerator;
+
+ public:
+ TCVIterator(const TPool& parentPool,
+ const size_t foldsCount,
+ const EIteratorType iteratorType);
+
+ void ResetShuffle();
+
+ void SetTestFold(const size_t testFoldNumber);
+
+ bool IsValid() const;
+
+ const TInstance& operator*() const;
+ const TInstance* operator->() const;
+ TPool::TCVIterator& operator++();
+
+ private:
+ void Advance();
+ bool TakeCurrent() const;
+ };
+
+ void ReadFromFeatures(const TString& featuresPath);
+ TCVIterator CrossValidationIterator(const size_t foldsCount, const EIteratorType iteratorType) const;
+
+ TPool InjurePool(const double injureFactir, const double injureOffset) const;
+};