aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/AggregateFunctions/AggregateFunctionAvgWeighted.cpp
blob: e840005facff6f966105de7d617e097bc3f77ac8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include <memory>
#include <type_traits>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionAvgWeighted.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/FactoryHelpers.h>

namespace DB
{
struct Settings;

namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}

namespace
{
bool allowTypes(const DataTypePtr& left, const DataTypePtr& right) noexcept
{
    const WhichDataType l_dt(left), r_dt(right);

    constexpr auto allow = [](WhichDataType t)
    {
        return t.isInt() || t.isUInt() || t.isFloat() || t.isDecimal();
    };

    return allow(l_dt) && allow(r_dt);
}

/// Expands to a `switch` over `which.idx` (a variable named `which` must be in scope at the point of use).
/// CASE_FOR is a function-like macro that is invoked once per supported numeric type name and must
/// expand to a complete `case ...: return ...` clause (the trailing semicolon is supplied here).
/// Any type index that is not listed makes the enclosing function return nullptr.
#define AT_SWITCH(CASE_FOR) \
    switch (which.idx) \
    { \
        CASE_FOR(Int8); \
        CASE_FOR(Int16); \
        CASE_FOR(Int32); \
        CASE_FOR(Int64); \
        CASE_FOR(Int128); \
        CASE_FOR(Int256); \
        CASE_FOR(UInt8); \
        CASE_FOR(UInt16); \
        CASE_FOR(UInt32); \
        CASE_FOR(UInt64); \
        CASE_FOR(UInt128); \
        CASE_FOR(UInt256); \
        CASE_FOR(Decimal32); \
        CASE_FOR(Decimal64); \
        CASE_FOR(Decimal128); \
        CASE_FOR(Decimal256); \
        CASE_FOR(Float32); \
        CASE_FOR(Float64); \
        default: \
            return nullptr; \
    }

/// Second dispatch stage. The first template argument `First` is already fixed at compile time;
/// the second one is chosen from the runtime type index of `second_type`.
/// The remaining arguments are perfectly forwarded to the AggregateFunctionAvgWeighted constructor.
/// Returns an object allocated with `new` (ownership passes to the caller),
/// or nullptr when the type index is not one of those listed in AT_SWITCH.
template <class First, class ... TArgs>
IAggregateFunction * create(const IDataType & second_type, TArgs && ... args)
{
    /// The variable must be called `which`: AT_SWITCH switches on `which.idx`.
    const WhichDataType which(second_type);

#define MAKE_FOR_SECOND(Second) \
    case TypeIndex::Second: \
        return new AggregateFunctionAvgWeighted<First, Second>(std::forward<TArgs>(args)...)

    AT_SWITCH(MAKE_FOR_SECOND)

#undef MAKE_FOR_SECOND
}

/// First dispatch stage: turns the runtime type index of `first_type` into a compile-time type
/// and hands over to create<Type, TArgs...>(second_type, ...), which resolves the second type.
/// The dispatch is written by hand because there are no helper templates
/// for a binary decimal/numeric function.
/// Returns nullptr when the type index of `first_type` is not one of those listed in AT_SWITCH.
template <class... TArgs>
IAggregateFunction * create(const IDataType & first_type, const IDataType & second_type, TArgs && ... args)
{
    /// The variable must be called `which`: AT_SWITCH switches on `which.idx`.
    const WhichDataType which(first_type);

#define DISPATCH_ON_FIRST(Resolved) \
    case TypeIndex::Resolved: \
        return create<Resolved, TArgs...>(second_type, std::forward<TArgs>(args)...)

    AT_SWITCH(DISPATCH_ON_FIRST)

#undef DISPATCH_ON_FIRST
}

/** Creator callback for the `avgWeighted` aggregate function.
  *
  * name           - function name; passed to the argument checks and used in the error message.
  * argument_types - argument types; element 0 is the value type, element 1 is the weight type.
  *                  Validated with assertBinary (presumably: exactly two arguments).
  * parameters     - validated with assertNoParameters (presumably: must be empty).
  * The Settings pointer is not used.
  *
  * Throws ILLEGAL_TYPE_OF_ARGUMENT when allowTypes() rejects the pair of argument types.
  * Returns the object produced by create(); create() yields nullptr for a type index
  * that is not listed in AT_SWITCH, in which case the returned pointer is empty.
  */
AggregateFunctionPtr
createAggregateFunctionAvgWeighted(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
    assertNoParameters(name, parameters);
    assertBinary(name, argument_types);

    /// Bind by reference instead of copying the smart pointers:
    /// `argument_types` outlives this call, so no extra ownership is needed.
    const DataTypePtr & data_type = argument_types[0];
    const DataTypePtr & data_type_weight = argument_types[1];

    if (!allowTypes(data_type, data_type_weight))
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                        "Types {} and {} are non-conforming as arguments for aggregate function {}",
                        data_type->getName(), data_type_weight->getName(), name);

    AggregateFunctionPtr ptr;

    const bool left_decimal = isDecimal(data_type);
    const bool right_decimal = isDecimal(data_type_weight);

    /// We multiply value by weight, so actual scale of numerator is <scale of value> + <scale of weight>.
    /// The extra arguments after `argument_types` are forwarded as-is to the AggregateFunctionAvgWeighted
    /// constructor (presumably numerator scale, then denominator scale - confirm against the header).
    /// getDecimalScale is only ever called on a type that is known to be decimal.
    if (left_decimal && right_decimal)
        ptr.reset(create(*data_type, *data_type_weight,
            argument_types,
            getDecimalScale(*data_type) + getDecimalScale(*data_type_weight), getDecimalScale(*data_type_weight)));
    else if (left_decimal)
        /// Only the value is decimal: a single scale (that of the value) is passed.
        ptr.reset(create(*data_type, *data_type_weight, argument_types,
            getDecimalScale(*data_type)));
    else if (right_decimal)
        /// Only the weight is decimal: its scale is passed for both scale arguments.
        ptr.reset(create(*data_type, *data_type_weight, argument_types,
            getDecimalScale(*data_type_weight), getDecimalScale(*data_type_weight)));
    else
        /// Neither argument is decimal: no scale arguments are passed.
        ptr.reset(create(*data_type, *data_type_weight, argument_types));

    return ptr;
}
}

/// Registers createAggregateFunctionAvgWeighted in `factory` under the name "avgWeighted",
/// passing AggregateFunctionFactory::CaseSensitive as the case-sensitivity mode.
void registerAggregateFunctionAvgWeighted(AggregateFunctionFactory & factory)
{
    factory.registerFunction(
        "avgWeighted",
        createAggregateFunctionAvgWeighted,
        AggregateFunctionFactory::CaseSensitive);
}
}