aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/linear_regression/welford.h
diff options
context:
space:
mode:
authorDevtools Arcadia <arcadia-devtools@yandex-team.ru>2022-02-07 18:08:42 +0300
committerDevtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net>2022-02-07 18:08:42 +0300
commit1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
treee26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/linear_regression/welford.h
downloadydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/linear_regression/welford.h')
-rw-r--r--library/cpp/linear_regression/welford.h77
1 files changed, 77 insertions, 0 deletions
diff --git a/library/cpp/linear_regression/welford.h b/library/cpp/linear_regression/welford.h
new file mode 100644
index 0000000000..ee865d6693
--- /dev/null
+++ b/library/cpp/linear_regression/welford.h
@@ -0,0 +1,77 @@
+#pragma once
+
+#include <library/cpp/accurate_accumulate/accurate_accumulate.h>
+
+#include <util/ysaveload.h>
+
+// accurately computes (w_1 * x_1 + w_2 * x_2 + ... + w_n * x_n) / (w_1 + w_2 + ... + w_n)
+class TMeanCalculator {
+private:
+ double Mean = 0.;
+ TKahanAccumulator<double> SumWeights;
+
+public:
+ Y_SAVELOAD_DEFINE(Mean, SumWeights);
+
+ void Multiply(const double value);
+ void Add(const double value, const double weight = 1.);
+ void Remove(const double value, const double weight = 1.);
+ double GetMean() const;
+ double GetSumWeights() const;
+ void Reset();
+
+ bool operator<(const TMeanCalculator& other) const {
+ return Mean < other.Mean;
+ }
+
+ bool operator>(const TMeanCalculator& other) const {
+ return Mean > other.Mean;
+ }
+};
+
+// accurately computes (w_1 * x_1 * y_1 + w_2 * x_2 * y_2 + ... + w_n * x_n * y_n) / (w_1 + w_2 + ... + w_n)
+class TCovariationCalculator {
+private:
+ double Covariation = 0.;
+
+ double FirstValueMean = 0.;
+ double SecondValueMean = 0.;
+
+ TKahanAccumulator<double> SumWeights;
+
+public:
+ Y_SAVELOAD_DEFINE(Covariation, FirstValueMean, SecondValueMean, SumWeights);
+
+ void Add(const double firstValue, const double secondValue, const double weight = 1.);
+ void Remove(const double firstValue, const double secondValue, const double weight = 1.);
+
+ double GetFirstValueMean() const;
+ double GetSecondValueMean() const;
+
+ double GetCovariation() const;
+
+ double GetSumWeights() const;
+
+ void Reset();
+};
+
+// accurately computes (w_1 * x_1 * x_1 + w_2 * x_2 * x_2 + ... + w_n * x_n * x_n) / (w_1 + w_2 + ... + w_n)
+class TDeviationCalculator {
+private:
+ double Deviation = 0.;
+ TMeanCalculator MeanCalculator;
+
+public:
+ Y_SAVELOAD_DEFINE(Deviation, MeanCalculator);
+
+ void Add(const double value, const double weight = 1.);
+ void Remove(const double value, const double weight = 1.);
+
+ double GetMean() const;
+ double GetDeviation() const;
+ double GetStdDev() const;
+
+ double GetSumWeights() const;
+
+ void Reset();
+};