aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/linear_regression/welford.cpp
diff options
context:
space:
mode:
authorDevtools Arcadia <arcadia-devtools@yandex-team.ru>2022-02-07 18:08:42 +0300
committerDevtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net>2022-02-07 18:08:42 +0300
commit1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
treee26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/linear_regression/welford.cpp
downloadydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/linear_regression/welford.cpp')
-rw-r--r--library/cpp/linear_regression/welford.cpp107
1 files changed, 107 insertions, 0 deletions
diff --git a/library/cpp/linear_regression/welford.cpp b/library/cpp/linear_regression/welford.cpp
new file mode 100644
index 0000000000..e27b1994f6
--- /dev/null
+++ b/library/cpp/linear_regression/welford.cpp
@@ -0,0 +1,107 @@
+#include "welford.h"
+
+#include <util/generic/ymath.h>
+
+// Scales the accumulated sum of weights by `value`.
+// Only SumWeights is modified here; Mean keeps its current value.
+void TMeanCalculator::Multiply(const double value) {
+    SumWeights *= value;
+}
+
+// Folds one weighted observation into the running mean.
+// SumWeights grows by `weight`; when the resulting total is zero the mean is
+// left as is, otherwise it moves by weight * (value - Mean) / total.
+void TMeanCalculator::Add(const double value, const double weight /*= 1.*/) {
+    SumWeights += weight;
+
+    const auto totalWeight = SumWeights.Get();
+    if (!totalWeight) {
+        return; // zero total: skip the division
+    }
+
+    Mean += weight * (value - Mean) / totalWeight;
+}
+
+// Takes one weighted observation back out of the running mean.
+// SumWeights shrinks by `weight`; when the remaining total is zero the mean is
+// left as is, otherwise it moves by -weight * (value - Mean) / remaining total.
+void TMeanCalculator::Remove(const double value, const double weight /*= 1.*/) {
+    SumWeights -= weight;
+
+    const auto remainingWeight = SumWeights.Get();
+    if (!remainingWeight) {
+        return; // zero total: skip the division
+    }
+
+    Mean -= weight * (value - Mean) / remainingWeight;
+}
+
+// Returns the current value of the running mean.
+double TMeanCalculator::GetMean() const {
+    return this->Mean;
+}
+
+// Returns the accumulated sum of weights as a double.
+double TMeanCalculator::GetSumWeights() const {
+    const double totalWeight = SumWeights.Get();
+    return totalWeight;
+}
+
+// Discards all accumulated state by assigning a value-initialized
+// TMeanCalculator over this object.
+void TMeanCalculator::Reset() {
+    *this = TMeanCalculator();
+}
+
+// Folds one weighted pair (firstValue, secondValue) into the running statistics.
+// SumWeights grows by `weight`; when the resulting total is zero nothing else
+// changes. Otherwise the update order matters:
+//   1. FirstValueMean is moved to its new value;
+//   2. Covariation is advanced using the *updated* first mean and the
+//      *not yet updated* second mean;
+//   3. SecondValueMean is moved to its new value.
+void TCovariationCalculator::Add(const double firstValue, const double secondValue, const double weight /*= 1.*/) {
+    SumWeights += weight;
+
+    const auto totalWeight = SumWeights.Get();
+    if (!totalWeight) {
+        return; // zero total: skip the division
+    }
+
+    FirstValueMean += weight * (firstValue - FirstValueMean) / totalWeight;
+
+    // Distance to the second mean before it is updated; used for both steps below.
+    const auto secondDelta = secondValue - SecondValueMean;
+    Covariation += weight * (firstValue - FirstValueMean) * secondDelta;
+    SecondValueMean += weight * secondDelta / totalWeight;
+}
+
+// Takes one weighted pair (firstValue, secondValue) back out of the running statistics.
+// SumWeights shrinks by `weight`; when the remaining total is zero nothing else
+// changes. Otherwise the update order matters:
+//   1. FirstValueMean is moved to its new value;
+//   2. Covariation is reduced using the *updated* first mean and the
+//      *not yet updated* second mean;
+//   3. SecondValueMean is moved to its new value.
+void TCovariationCalculator::Remove(const double firstValue, const double secondValue, const double weight /*= 1.*/) {
+    SumWeights -= weight;
+
+    const auto remainingWeight = SumWeights.Get();
+    if (!remainingWeight) {
+        return; // zero total: skip the division
+    }
+
+    FirstValueMean -= weight * (firstValue - FirstValueMean) / remainingWeight;
+
+    // Distance to the second mean before it is updated; used for both steps below.
+    const auto secondDelta = secondValue - SecondValueMean;
+    Covariation -= weight * (firstValue - FirstValueMean) * secondDelta;
+    SecondValueMean -= weight * secondDelta / remainingWeight;
+}
+
+// Returns the running mean of the first values.
+double TCovariationCalculator::GetFirstValueMean() const {
+    return this->FirstValueMean;
+}
+
+// Returns the running mean of the second values.
+double TCovariationCalculator::GetSecondValueMean() const {
+    return this->SecondValueMean;
+}
+
+// Returns the accumulated Covariation value exactly as stored; no division
+// by the sum of weights is applied here.
+double TCovariationCalculator::GetCovariation() const {
+    return this->Covariation;
+}
+
+// Returns the accumulated sum of weights as a double.
+double TCovariationCalculator::GetSumWeights() const {
+    const double totalWeight = SumWeights.Get();
+    return totalWeight;
+}
+
+// Discards all accumulated state by assigning a value-initialized
+// TCovariationCalculator over this object.
+void TCovariationCalculator::Reset() {
+    *this = TCovariationCalculator();
+}
+
+// Folds one weighted observation into the running mean and deviation.
+// The mean is sampled before and after MeanCalculator.Add(), and Deviation
+// grows by weight * (value - meanBefore) * (value - meanAfter).
+void TDeviationCalculator::Add(const double value, const double weight /*= 1.*/) {
+    const double meanBefore = MeanCalculator.GetMean();
+    MeanCalculator.Add(value, weight);
+    const double meanAfter = MeanCalculator.GetMean();
+
+    Deviation += weight * (value - meanBefore) * (value - meanAfter);
+}
+
+// Takes one weighted observation back out of the running mean and deviation.
+// The mean is sampled before and after MeanCalculator.Remove(), and Deviation
+// shrinks by weight * (value - meanBefore) * (value - meanAfter).
+void TDeviationCalculator::Remove(const double value, const double weight /*= 1.*/) {
+    const double meanBefore = MeanCalculator.GetMean();
+    MeanCalculator.Remove(value, weight);
+    const double meanAfter = MeanCalculator.GetMean();
+
+    Deviation -= weight * (value - meanBefore) * (value - meanAfter);
+}
+
+// Returns the running mean, as reported by the embedded MeanCalculator.
+double TDeviationCalculator::GetMean() const {
+    const double mean = MeanCalculator.GetMean();
+    return mean;
+}
+
+// Returns the accumulated Deviation value exactly as stored; no division
+// by the sum of weights is applied here (GetStdDev() does that).
+double TDeviationCalculator::GetDeviation() const {
+    return this->Deviation;
+}
+
+// Returns sqrt(GetDeviation() / GetSumWeights()), or 0 when the sum of
+// weights is zero.
+//
+// Deviation is maintained by incremental additions and subtractions in
+// Add()/Remove(), so floating-point rounding can leave it marginally below
+// zero (for example after removing observations until all remaining ones are
+// equal). sqrt() of such a ratio would be NaN, so a negative ratio is
+// reported as 0 instead. A NaN ratio is not masked: it fails the `< 0`
+// comparison and still propagates through sqrt().
+double TDeviationCalculator::GetStdDev() const {
+    const double sumWeights = GetSumWeights();
+    if (!sumWeights) {
+        return 0.;
+    }
+
+    const double variance = GetDeviation() / sumWeights;
+    if (variance < 0.) {
+        return 0.; // rounding residue, see the comment above
+    }
+    return sqrt(variance);
+}
+
+// Returns the sum of weights, as reported by the embedded MeanCalculator.
+double TDeviationCalculator::GetSumWeights() const {
+    const double totalWeight = MeanCalculator.GetSumWeights();
+    return totalWeight;
+}
+
+// Discards all accumulated state by assigning a value-initialized
+// TDeviationCalculator over this object.
+void TDeviationCalculator::Reset() {
+    *this = TDeviationCalculator();
+}