aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/AggregateFunctions/AggregateFunctionAvgWeighted.h
blob: 5a3869032cadcbcb78b2a5917029950d505c787f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#pragma once

#include <type_traits>
#include <AggregateFunctions/AggregateFunctionAvg.h>

namespace DB
{
/// Forward declaration only; nothing in the visible part of this header uses the definition.
struct Settings;

/// Maps an argument type T of avgWeighted to the type used for accumulation:
///  - if is_decimal<T>: Decimal256 stays Decimal256, any other decimal becomes Decimal128;
///  - otherwise, if DecimalOrExtendedInt<T> holds: Float64 (see the inline note below);
///  - otherwise: NearestFieldType<T>.
template <typename T>
using AvgWeightedFieldType = std::conditional_t<is_decimal<T>,
    std::conditional_t<std::is_same_v<T, Decimal256>, Decimal256, Decimal128>,
    std::conditional_t<DecimalOrExtendedInt<T>,
        Float64, // no way to do UInt128 * UInt128, better cast to Float64
        NearestFieldType<T>>>;

/// Chooses between AvgWeightedFieldType<T> and AvgWeightedFieldType<U> by sizeof:
/// the T-derived type wins only when it is strictly larger; on equal sizes the U-derived type is used.
template <typename T, typename U>
using MaxFieldType = std::conditional_t<(sizeof(AvgWeightedFieldType<T>) > sizeof(AvgWeightedFieldType<U>)),
    AvgWeightedFieldType<T>, AvgWeightedFieldType<U>>;

/** Aggregate function "avgWeighted" over a (Value, Weight) pair of columns.
  *
  * For every row the state's numerator grows by value * weight and its denominator grows by weight.
  * The numerator type is MaxFieldType<Value, Weight>, the denominator type is AvgWeightedFieldType<Weight>.
  * The state itself and the rest of the interface come from AggregateFunctionAvgBase (defined elsewhere).
  */
template <typename Value, typename Weight>
class AggregateFunctionAvgWeighted final :
    public AggregateFunctionAvgBase<
        MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>
{
public:
    using Base = AggregateFunctionAvgBase<
        MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>;
    using Base::Base;

    using Numerator = typename Base::Numerator;
    using Denominator = typename Base::Denominator;
    using Fraction = typename Base::Fraction;

    /// Accumulates row `row_num` into the state at `place`.
    /// columns[0] is read as the value column, columns[1] as the weight column.
    void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
    {
        const auto & value_column = static_cast<const ColumnVectorOrDecimal<Value> &>(*columns[0]);
        const auto & weight_column = static_cast<const ColumnVectorOrDecimal<Weight> &>(*columns[1]);

        const auto & row_value = value_column.getData()[row_num];
        const auto & row_weight = weight_column.getData()[row_num];

        /// Both factors are converted to Numerator before multiplying, so the product is computed in that type.
        this->data(place).numerator += static_cast<Numerator>(row_value) * static_cast<Numerator>(row_weight);

        /// The weight is converted separately for the denominator, which may have a different type.
        this->data(place).denominator += static_cast<Denominator>(row_weight);
    }

    String getName() const override
    {
        return "avgWeighted";
    }

#if USE_EMBEDDED_COMPILER

    /// Compilable only when the base class allows it and Weight can be a native type.
    bool isCompilable() const override
    {
        const bool base_is_compilable = Base::isCompilable();
        const bool weight_is_native = canBeNativeType<Weight>();

        return base_is_compilable && weight_is_native;
    }

    /// Emits IR performing the same update as add():
    /// numerator += value * weight, then denominator += weight.
    /// arguments[0] is the value, arguments[1] is the weight.
    void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override
    {
        auto & b = static_cast<llvm::IRBuilder<> &>(builder);

        /// Numerator part. The numerator is accessed directly through aggregate_data_ptr.
        auto * numerator_ir_type = toNativeType<Numerator>(b);
        auto * numerator_before = b.CreateLoad(numerator_ir_type, aggregate_data_ptr);

        auto numerator_data_type = toNativeDataType<Numerator>();
        auto * value_as_numerator = nativeCast(b, arguments[0], numerator_data_type);
        auto * weight_as_numerator = nativeCast(b, arguments[1], numerator_data_type);

        /// Integer and floating-point operands need different multiply / add instructions.
        llvm::Value * weighted_value = nullptr;
        if (value_as_numerator->getType()->isIntegerTy())
            weighted_value = b.CreateMul(value_as_numerator, weight_as_numerator);
        else
            weighted_value = b.CreateFMul(value_as_numerator, weight_as_numerator);

        llvm::Value * numerator_after = nullptr;
        if (numerator_ir_type->isIntegerTy())
            numerator_after = b.CreateAdd(numerator_before, weighted_value);
        else
            numerator_after = b.CreateFAdd(numerator_before, weighted_value);

        b.CreateStore(numerator_after, aggregate_data_ptr);

        /// Denominator part. Its address is aggregate_data_ptr advanced by offsetof(Fraction, denominator),
        /// computed with an i8 element type so the index is a byte offset.
        auto * denominator_ir_type = toNativeType<Denominator>(b);

        static constexpr size_t denominator_byte_offset = offsetof(Fraction, denominator);
        auto * denominator_slot = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_ptr, denominator_byte_offset);

        auto * weight_as_denominator = nativeCast(b, arguments[1], toNativeDataType<Denominator>());

        auto * denominator_before = b.CreateLoad(denominator_ir_type, denominator_slot);

        llvm::Value * denominator_after = nullptr;
        if (denominator_ir_type->isIntegerTy())
            denominator_after = b.CreateAdd(denominator_before, weight_as_denominator);
        else
            denominator_after = b.CreateFAdd(denominator_before, weight_as_denominator);

        b.CreateStore(denominator_after, denominator_slot);
    }

#endif

};
}