diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/linear_regression/welford.cpp | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/linear_regression/welford.cpp')
-rw-r--r-- | library/cpp/linear_regression/welford.cpp | 107 |
1 files changed, 107 insertions, 0 deletions
diff --git a/library/cpp/linear_regression/welford.cpp b/library/cpp/linear_regression/welford.cpp new file mode 100644 index 0000000000..e27b1994f6 --- /dev/null +++ b/library/cpp/linear_regression/welford.cpp @@ -0,0 +1,107 @@ +#include "welford.h" + +#include <util/generic/ymath.h> + +void TMeanCalculator::Multiply(const double value) { + SumWeights *= value; +} + +void TMeanCalculator::Add(const double value, const double weight /*= 1.*/) { + SumWeights += weight; + if (SumWeights.Get()) { + Mean += weight * (value - Mean) / SumWeights.Get(); + } +} + +void TMeanCalculator::Remove(const double value, const double weight /*= 1.*/) { + SumWeights -= weight; + if (SumWeights.Get()) { + Mean -= weight * (value - Mean) / SumWeights.Get(); + } +} + +double TMeanCalculator::GetMean() const { + return Mean; +} + +double TMeanCalculator::GetSumWeights() const { + return SumWeights.Get(); +} + +void TMeanCalculator::Reset() { + *this = TMeanCalculator(); +} + +void TCovariationCalculator::Add(const double firstValue, const double secondValue, const double weight /*= 1.*/) { + SumWeights += weight; + if (SumWeights.Get()) { + FirstValueMean += weight * (firstValue - FirstValueMean) / SumWeights.Get(); + Covariation += weight * (firstValue - FirstValueMean) * (secondValue - SecondValueMean); + SecondValueMean += weight * (secondValue - SecondValueMean) / SumWeights.Get(); + } +} + +void TCovariationCalculator::Remove(const double firstValue, const double secondValue, const double weight /*= 1.*/) { + SumWeights -= weight; + if (SumWeights.Get()) { + FirstValueMean -= weight * (firstValue - FirstValueMean) / SumWeights.Get(); + Covariation -= weight * (firstValue - FirstValueMean) * (secondValue - SecondValueMean); + SecondValueMean -= weight * (secondValue - SecondValueMean) / SumWeights.Get(); + } +} + +double TCovariationCalculator::GetFirstValueMean() const { + return FirstValueMean; +} + +double TCovariationCalculator::GetSecondValueMean() const { + return SecondValueMean; +} + +double TCovariationCalculator::GetCovariation() const { + return Covariation; +} + +double TCovariationCalculator::GetSumWeights() const { + return SumWeights.Get(); +} + +void TCovariationCalculator::Reset() { + *this = TCovariationCalculator(); +} + +void TDeviationCalculator::Add(const double value, const double weight /*= 1.*/) { + const double lastMean = MeanCalculator.GetMean(); + MeanCalculator.Add(value, weight); + Deviation += weight * (value - lastMean) * (value - MeanCalculator.GetMean()); +} + +void TDeviationCalculator::Remove(const double value, const double weight /*= 1.*/) { + const double lastMean = MeanCalculator.GetMean(); + MeanCalculator.Remove(value, weight); + Deviation -= weight * (value - lastMean) * (value - MeanCalculator.GetMean()); +} + +double TDeviationCalculator::GetMean() const { + return MeanCalculator.GetMean(); +} + +double TDeviationCalculator::GetDeviation() const { + return Deviation; +} + +double TDeviationCalculator::GetStdDev() const { + const double sumWeights = GetSumWeights(); + if (!sumWeights) { + return 0.; + } + return sqrt(GetDeviation() / sumWeights); +} + +double TDeviationCalculator::GetSumWeights() const { + return MeanCalculator.GetSumWeights(); +} + +void TDeviationCalculator::Reset() { + *this = TDeviationCalculator(); +} |