// Source: root/contrib/clickhouse/src/AggregateFunctions/QuantileTDigest.h
// (blob 915f6763e524ed4e1ed315318a4f6452b65ef2ea)
#pragma once

#include <cmath>
#include <Common/Exception.h>
#include <Common/RadixSort.h>
#include <Common/PODArray.h>
#include <Core/AccurateComparison.h>
#include <IO/WriteBuffer.h>
#include <IO/ReadBuffer.h>
#include <IO/VarInt.h>


namespace DB
{
struct Settings;

namespace ErrorCodes
{
    /// Thrown by QuantileTDigest::deserialize() for a centroid with an invalid weight.
    extern const int CANNOT_PARSE_INPUT_ASSERTION_FAILED;
    /// Thrown by QuantileTDigest::deserialize() when the serialized centroid count is too large.
    extern const int TOO_LARGE_ARRAY_SIZE;
    /// Thrown when a computed quantile cannot be represented in the requested result type.
    extern const int DECIMAL_OVERFLOW;
}


/** The algorithm was implemented by Alexei Borzenkov https://github.com/snaury
  * He is the author of the code and of half of the comments in this namespace,
  * except for merging, serialization and sorting, as well as the choice of types and other changes.
  * We thank Alexei Borzenkov for writing the original code.
  */

/** Implementation of the t-digest algorithm (https://github.com/tdunning/t-digest).
  * This variant is very similar to MergingDigest in Java; however, the decision whether
  * to merge two centroids is made using the original condition from the article
  * (a size constraint based on the approximate quantile of each
  * centroid, not on the distance along the curve between the positions of their boundaries). MergingDigest
  * in Java produces significantly fewer centroids than this variant, which
  * negatively affects accuracy at the same compression factor, but gives
  * size guarantees. The author himself, when this variant was proposed, said that
  * the size of the digest grows like O(log(n)), while the size of the Java version
  * does not depend on the expected number of points. Also, the Java variant
  * uses asin, which slows down the algorithm a bit.
  */
template <typename T>
class QuantileTDigest
{
    using Value = Float32;
    using Count = Float32;
    using BetterFloat = Float64; // For intermediate results and sum(Count). Must have better precision, than Count

    /** The centroid stores the weight of points around their mean value
      */
    struct Centroid
    {
        Value mean;
        Count count;

        Centroid() = default;

        explicit Centroid(Value mean_, Count count_)
            : mean(mean_)
            , count(count_)
        {}

        bool operator<(const Centroid & other) const
        {
            return mean < other.mean;
        }
    };


    /** :param epsilon: value \delta from the article - error in the range
      *                    quantile 0.5 (default is 0.01, i.e. 1%)
      *                    if you change epsilon, you must also change max_centroids
      * :param max_centroids: depends on epsilon, the better accuracy, the more centroids you need
      *                       to describe data with this accuracy. Read article before changing.
      * :param max_unmerged: when accumulating count of new points beyond this
      *                      value centroid compression is triggered
      *                      (default is 2048, the higher the value - the
      *                      more memory is required, but amortization of execution time increases)
      *                      Change freely anytime.
      */
    struct Params
    {
        Value epsilon = 0.01f;
        size_t max_centroids = 2048;
        size_t max_unmerged = 2048;
    };
    /** max_centroids_deserialize should be >= all max_centroids ever used in production.
     *  This is security parameter, preventing allocation of too much centroids in deserialize, so can be relatively large.
     */
    static constexpr size_t max_centroids_deserialize = 65536;

    static constexpr Params params{};

    static constexpr size_t bytes_in_arena = 128 - sizeof(PODArray<Centroid>) - sizeof(BetterFloat) - sizeof(size_t); // If alignment is imperfect, sizeof(TDigest) will be more than naively expected
    using Centroids = PODArrayWithStackMemory<Centroid, bytes_in_arena>;

    Centroids centroids;
    BetterFloat count = 0;
    size_t unmerged = 0;

    /// Linear interpolation at the point x on the line (x1, y1)..(x2, y2)
    static Value interpolate(Value x, Value x1, Value y1, Value x2, Value y2)
    {
        /// Symmetric interpolation for better results with infinities.
        double k = (x - x1) / (x2 - x1);
        return static_cast<Value>((1 - k) * y1 + k * y2);
    }

    struct RadixSortTraits
    {
        using Element = Centroid;
        using Result = Element;
        using Key = Value;
        using CountType = UInt32;
        using KeyBits = UInt32;

        static constexpr size_t PART_SIZE_BITS = 8;

        using Transform = RadixSortFloatTransform<KeyBits>;
        using Allocator = RadixSortAllocator;

        /// The function to get the key from an array element.
        static Key & extractKey(Element & elem) { return elem.mean; }
        static Result & extractResult(Element & elem) { return elem; }
    };

    /** Adds a centroid `c` to the digest
     * centroid must be valid, validity is checked in add(), deserialize() and is maintained by compress()
      */
    void addCentroid(const Centroid & c)
    {
        centroids.push_back(c);
        count += c.count;
        ++unmerged;
        if (unmerged > params.max_unmerged)
            compress();
    }

    inline bool canBeMerged(const BetterFloat & l_mean, const Value & r_mean)
    {
        return l_mean == r_mean || (!std::isinf(l_mean) && !std::isinf(r_mean));
    }

    void compressBrute()
    {
        if (centroids.size() <= params.max_centroids)
            return;
        const size_t batch_size = (centroids.size() + params.max_centroids - 1) / params.max_centroids; // at least 2

        auto l = centroids.begin();
        auto r = std::next(l);
        BetterFloat sum = 0;
        BetterFloat l_mean = l->mean; // We have high-precision temporaries for numeric stability
        BetterFloat l_count = l->count;
        size_t batch_pos = 0;

        for (; r != centroids.end(); ++r)
        {
            if (batch_pos < batch_size - 1)
            {
                /// The left column "eats" the right. Middle of the batch
                l_count += r->count;
                if (r->mean != l_mean) /// Handling infinities of the same sign well.
                {
                    l_mean += r->count * (r->mean - l_mean) / l_count; // Symmetric algo (M1*C1 + M2*C2)/(C1+C2) is numerically better, but slower
                }
                l->mean = static_cast<Value>(l_mean);
                l->count = static_cast<Value>(l_count);
                batch_pos += 1;
            }
            else
            {
                // End of the batch, start the next one
                if (!std::isnan(l->mean)) /// Skip writing batch result if we compressed something to nan.
                {
                    sum += l->count; // Not l_count, otherwise actual sum of elements will be different
                    ++l;
                }

                /// We skip all the values "eaten" earlier.
                *l = *r;
                l_mean = l->mean;
                l_count = l->count;
                batch_pos = 0;
            }
        }

        if (!std::isnan(l->mean))
        {
            count = sum + l_count; // Update count, it might be different due to += inaccuracy
            centroids.resize(l - centroids.begin() + 1);
        }
        else /// Skip writing last batch if (super unlikely) it's nan.
        {
            count = sum;
            centroids.resize(l - centroids.begin());
        }
        // Here centroids.size() <= params.max_centroids
    }

public:
    /** Performs compression of accumulated centroids
      * When merging, the invariant is retained to the maximum size of each
      * centroid that does not exceed `4 q (1 - q) \ delta N`.
      */
    void compress()
    {
        if (unmerged > 0 || centroids.size() > params.max_centroids)
        {
            // unmerged > 0 implies centroids.size() > 0, hence *l is valid below
            RadixSort<RadixSortTraits>::executeLSD(centroids.data(), centroids.size());

            /// A pair of consecutive bars of the histogram.
            auto l = centroids.begin();
            auto r = std::next(l);

            const BetterFloat count_epsilon_4 = count * params.epsilon * 4; // Compiler is unable to do this optimization
            BetterFloat sum = 0;
            BetterFloat l_mean = l->mean; // We have high-precision temporaries for numeric stability
            BetterFloat l_count = l->count;
            while (r != centroids.end())
            {
                /// N.B. We cannot merge all the same values into single centroids because this will lead to
                /// unbalanced compression and wrong results.
                /// For more information see: https://arxiv.org/abs/1902.04023

                /// The ratio of the part of the histogram to l, including the half l to the entire histogram. That is, what level quantile in position l.
                BetterFloat ql = (sum + l_count * 0.5) / count;
                BetterFloat err = ql * (1 - ql);

                /// The ratio of the portion of the histogram to l, including l and half r to the entire histogram. That is, what level is the quantile in position r.
                BetterFloat qr = (sum + l_count + r->count * 0.5) / count;
                BetterFloat err2 = qr * (1 - qr);

                if (err > err2)
                    err = err2;

                BetterFloat k = count_epsilon_4 * err;

                /** The ratio of the weight of the glued column pair to all values is not greater,
                  *  than epsilon multiply by a certain quadratic coefficient, which in the median is 1 (4 * 1/2 * 1/2),
                  *  and at the edges decreases and is approximately equal to the distance to the edge * 4.
                  */

                if (l_count + r->count <= k && canBeMerged(l_mean, r->mean))
                {
                    // it is possible to merge left and right
                    /// The left column "eats" the right.
                    l_count += r->count;
                    if (r->mean != l_mean) /// Handling infinities of the same sign well.
                    {
                        l_mean += r->count * (r->mean - l_mean) / l_count; // Symmetric algo (M1*C1 + M2*C2)/(C1+C2) is numerically better, but slower
                    }
                    l->mean = static_cast<Value>(l_mean);
                    l->count = static_cast<Value>(l_count);
                }
                else
                {
                    // not enough capacity, check the next pair
                    sum += l->count; // Not l_count, otherwise actual sum of elements will be different
                    ++l;

                    /// We skip all the values "eaten" earlier.
                    if (l != r)
                        *l = *r;
                    l_mean = l->mean;
                    l_count = l->count;
                }
                ++r;
            }
            count = sum + l_count; // Update count, it might be different due to += inaccuracy

            /// At the end of the loop, all values to the right of l were "eaten".
            centroids.resize(l - centroids.begin() + 1);
            unmerged = 0;
        }

        // Ensures centroids.size() < max_centroids, independent of unprovable floating point blackbox above
        compressBrute();
    }

    /** Adds to the digest a change in `x` with a weight of `cnt` (default 1)
      */
    void add(T x, UInt64 cnt = 1)
    {
        auto vx = static_cast<Value>(x);
        if (cnt == 0 || std::isnan(vx))
            return; // Count 0 breaks compress() assumptions, Nan breaks sort(). We treat them as no sample.
        addCentroid(Centroid{vx, static_cast<Count>(cnt)});
    }

    void merge(const QuantileTDigest & other)
    {
        for (const auto & c : other.centroids)
            addCentroid(c);
    }

    void serialize(WriteBuffer & buf)
    {
        compress();
        writeVarUInt(centroids.size(), buf);
        buf.write(reinterpret_cast<const char *>(centroids.data()), centroids.size() * sizeof(centroids[0]));
    }

    void deserialize(ReadBuffer & buf)
    {
        size_t size = 0;
        readVarUInt(size, buf);

        if (size > max_centroids_deserialize)
            throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large t-digest centroids size");

        count = 0;
        unmerged = 0;

        centroids.resize(size);
        // From now, TDigest will be in invalid state if exception is thrown.
        buf.readStrict(reinterpret_cast<char *>(centroids.data()), size * sizeof(centroids[0]));

        for (const auto & c : centroids)
        {
            if (c.count <= 0 || std::isnan(c.count)) // invalid count breaks compress()
                throw Exception(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "Invalid centroid {}:{}", c.count, std::to_string(c.mean));
            if (!std::isnan(c.mean))
            {
                count += c.count;
            }
        }

        auto it = std::remove_if(centroids.begin(), centroids.end(), [](Centroid & c) { return std::isnan(c.mean); });
        centroids.erase(it, centroids.end());

        compress(); // Allows reading/writing TDigests with different epsilon/max_centroids params
    }

    /** Calculates the quantile q [0, 1] based on the digest.
      * For an empty digest returns NaN.
      */
    template <typename ResultType>
    ResultType getImpl(Float64 level)
    {
        if (centroids.empty())
            return std::is_floating_point_v<ResultType> ? std::numeric_limits<ResultType>::quiet_NaN() : 0;

        compress();

        if (centroids.size() == 1)
            return checkOverflow<ResultType>(centroids.front().mean);

        Float64 x = level * count;
        Float64 prev_x = 0;
        Count sum = 0;
        Value prev_mean = centroids.front().mean;
        Count prev_count = centroids.front().count;

        for (const auto & c : centroids)
        {
            Float64 current_x = sum + c.count * 0.5;

            if (current_x >= x)
            {
                /// Special handling of singletons.
                Float64 left = prev_x + 0.5 * (prev_count == 1);
                Float64 right = current_x - 0.5 * (c.count == 1);

                if (x <= left)
                    return checkOverflow<ResultType>(prev_mean);
                else if (x >= right)
                    return checkOverflow<ResultType>(c.mean);
                else
                    return checkOverflow<ResultType>(interpolate(
                        static_cast<Value>(x),
                        static_cast<Value>(left),
                        prev_mean,
                        static_cast<Value>(right),
                        c.mean));
            }

            sum += c.count;
            prev_mean = c.mean;
            prev_count = c.count;
            prev_x = current_x;
        }

        return checkOverflow<ResultType>(centroids.back().mean);
    }

    /** Get multiple quantiles (`size` parts).
      * levels - an array of levels of the desired quantiles. They are in a random order.
      * levels_permutation - array-permutation levels. The i-th position will be the index of the i-th ascending level in the `levels` array.
      * result - the array where the results are added, in order of `levels`,
      */
    template <typename ResultType>
    void getManyImpl(const Float64 * levels, const size_t * levels_permutation, size_t size, ResultType * result)
    {
        if (centroids.empty())
        {
            for (size_t result_num = 0; result_num < size; ++result_num)
                result[result_num] = std::is_floating_point_v<ResultType> ? NAN : 0;
            return;
        }

        compress();

        if (centroids.size() == 1)
        {
            for (size_t result_num = 0; result_num < size; ++result_num)
                result[result_num] = static_cast<ResultType>(centroids.front().mean);
            return;
        }

        Float64 x = levels[levels_permutation[0]] * count;
        Float64 prev_x = 0;
        Count sum = 0;
        Value prev_mean = centroids.front().mean;
        Count prev_count = centroids.front().count;

        size_t result_num = 0;
        for (const auto & c : centroids)
        {
            Float64 current_x = sum + c.count * 0.5;

            if (current_x >= x)
            {
                /// Special handling of singletons.
                Float64 left = prev_x + 0.5 * (prev_count == 1);
                Float64 right = current_x - 0.5 * (c.count == 1);

                while (current_x >= x)
                {
                    if (x <= left)
                        result[levels_permutation[result_num]] = static_cast<ResultType>(prev_mean);
                    else if (x >= right)
                        result[levels_permutation[result_num]] = static_cast<ResultType>(c.mean);
                    else
                        result[levels_permutation[result_num]] = static_cast<ResultType>(interpolate(
                            static_cast<Value>(x), static_cast<Value>(left), prev_mean, static_cast<Value>(right), c.mean));

                    ++result_num;
                    if (result_num >= size)
                        return;

                    x = levels[levels_permutation[result_num]] * count;
                }
            }

            sum += c.count;
            prev_mean = c.mean;
            prev_count = c.count;
            prev_x = current_x;
        }

        auto rest_of_results = centroids.back().mean;
        for (; result_num < size; ++result_num)
            result[levels_permutation[result_num]] = static_cast<ResultType>(rest_of_results);
    }

    T get(Float64 level)
    {
        return getImpl<T>(level);
    }

    Float32 getFloat(Float64 level)
    {
        return getImpl<Float32>(level);
    }

    void getMany(const Float64 * levels, const size_t * indices, size_t size, T * result)
    {
        getManyImpl(levels, indices, size, result);
    }

    void getManyFloat(const Float64 * levels, const size_t * indices, size_t size, Float32 * result)
    {
        getManyImpl(levels, indices, size, result);
    }

private:
    template <typename ResultType>
    static ResultType checkOverflow(Value val)
    {
        ResultType result;
        if (accurate::convertNumeric(val, result))
            return result;
        throw DB::Exception(ErrorCodes::DECIMAL_OVERFLOW, "Numeric overflow");
    }
};

}