diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/linear_regression/welford.h | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/linear_regression/welford.h')
-rw-r--r-- | library/cpp/linear_regression/welford.h | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/library/cpp/linear_regression/welford.h b/library/cpp/linear_regression/welford.h new file mode 100644 index 0000000000..ee865d6693 --- /dev/null +++ b/library/cpp/linear_regression/welford.h @@ -0,0 +1,77 @@ +#pragma once + +#include <library/cpp/accurate_accumulate/accurate_accumulate.h> + +#include <util/ysaveload.h> + +// accurately computes (w_1 * x_1 + w_2 * x_2 + ... + w_n * x_n) / (w_1 + w_2 + ... + w_n) +class TMeanCalculator { +private: + double Mean = 0.; + TKahanAccumulator<double> SumWeights; + +public: + Y_SAVELOAD_DEFINE(Mean, SumWeights); + + void Multiply(const double value); + void Add(const double value, const double weight = 1.); + void Remove(const double value, const double weight = 1.); + double GetMean() const; + double GetSumWeights() const; + void Reset(); + + bool operator<(const TMeanCalculator& other) const { + return Mean < other.Mean; + } + + bool operator>(const TMeanCalculator& other) const { + return Mean > other.Mean; + } +}; + +// accurately computes (w_1 * x_1 * y_1 + w_2 * x_2 * y_2 + ... + w_n * x_n * y_n) / (w_1 + w_2 + ... + w_n) +class TCovariationCalculator { +private: + double Covariation = 0.; + + double FirstValueMean = 0.; + double SecondValueMean = 0.; + + TKahanAccumulator<double> SumWeights; + +public: + Y_SAVELOAD_DEFINE(Covariation, FirstValueMean, SecondValueMean, SumWeights); + + void Add(const double firstValue, const double secondValue, const double weight = 1.); + void Remove(const double firstValue, const double secondValue, const double weight = 1.); + + double GetFirstValueMean() const; + double GetSecondValueMean() const; + + double GetCovariation() const; + + double GetSumWeights() const; + + void Reset(); +}; + +// accurately computes (w_1 * x_1 * x_1 + w_2 * x_2 * x_2 + ... + w_n * x_n * x_n) / (w_1 + w_2 + ... + w_n) +class TDeviationCalculator { +private: + double Deviation = 0.; + TMeanCalculator MeanCalculator; + +public: + Y_SAVELOAD_DEFINE(Deviation, MeanCalculator); + + void Add(const double value, const double weight = 1.); + void Remove(const double value, const double weight = 1.); + + double GetMean() const; + double GetDeviation() const; + double GetStdDev() const; + + double GetSumWeights() const; + + void Reset(); +}; |