diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/linear_regression/benchmark/pool.h | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/linear_regression/benchmark/pool.h')
-rw-r--r-- | library/cpp/linear_regression/benchmark/pool.h | 61 |
1 files changed, 61 insertions, 0 deletions
diff --git a/library/cpp/linear_regression/benchmark/pool.h b/library/cpp/linear_regression/benchmark/pool.h new file mode 100644 index 0000000000..43288319c8 --- /dev/null +++ b/library/cpp/linear_regression/benchmark/pool.h @@ -0,0 +1,61 @@ +#pragma once + +#include <util/generic/vector.h> +#include <util/generic/string.h> + +#include <util/random/mersenne.h> +#include <util/random/shuffle.h> + +struct TInstance { + TVector<double> Features; + double Goal; + double Weight; + + static TInstance FromFeaturesString(const TString& featuresString); +}; + +struct TPool: public TVector<TInstance> { + enum EIteratorType { + LearnIterator, + TestIterator, + }; + + class TCVIterator { + private: + const TPool& ParentPool; + + size_t FoldsCount; + + EIteratorType IteratorType; + size_t TestFoldNumber; + + TVector<size_t> InstanceFoldNumbers; + const size_t* Current; + + TMersenne<ui64> RandomGenerator; + + public: + TCVIterator(const TPool& parentPool, + const size_t foldsCount, + const EIteratorType iteratorType); + + void ResetShuffle(); + + void SetTestFold(const size_t testFoldNumber); + + bool IsValid() const; + + const TInstance& operator*() const; + const TInstance* operator->() const; + TPool::TCVIterator& operator++(); + + private: + void Advance(); + bool TakeCurrent() const; + }; + + void ReadFromFeatures(const TString& featuresPath); + TCVIterator CrossValidationIterator(const size_t foldsCount, const EIteratorType iteratorType) const; + + TPool InjurePool(const double injureFactir, const double injureOffset) const; +}; |