1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
|
#include <memory>
#include <type_traits>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionAvgWeighted.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/FactoryHelpers.h>
namespace DB
{
/// Forward declaration: this file only passes a `const Settings *` through the creator signature.
struct Settings;
/// Error codes used in this file. They are only declared here (`extern`); the definitions live elsewhere.
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
namespace
{
bool allowTypes(const DataTypePtr& left, const DataTypePtr& right) noexcept
{
const WhichDataType l_dt(left), r_dt(right);
constexpr auto allow = [](WhichDataType t)
{
return t.isInt() || t.isUInt() || t.isFloat() || t.isDecimal();
};
return allow(l_dt) && allow(r_dt);
}
/// Expands to a `switch` over `which.idx` that applies CASE to every supported numeric type name.
/// A variable named `which` must be in scope at the expansion site.
/// CASE(Type) is expected to expand to a `case` label whose statement returns from the
/// enclosing function; for any type index not listed here the enclosing function returns nullptr.
#define AT_SWITCH(CASE) \
    switch (which.idx) \
    { \
        CASE(Int8); \
        CASE(Int16); \
        CASE(Int32); \
        CASE(Int64); \
        CASE(Int128); \
        CASE(Int256); \
        CASE(UInt8); \
        CASE(UInt16); \
        CASE(UInt32); \
        CASE(UInt64); \
        CASE(UInt128); \
        CASE(UInt256); \
        CASE(Decimal32); \
        CASE(Decimal64); \
        CASE(Decimal128); \
        CASE(Decimal256); \
        CASE(Float32); \
        CASE(Float64); \
        default: \
            return nullptr; \
    }
/// Second dispatch step. The type of the first argument is already fixed as the template
/// parameter First; this overload selects the second template argument from the runtime type
/// index of second_type and instantiates AggregateFunctionAvgWeighted<First, Second>,
/// forwarding args to its constructor.
/// Returns a raw pointer allocated with `new` (the caller takes ownership), or nullptr when
/// second_type is not one of the types listed in AT_SWITCH.
template <class First, class ... TArgs>
IAggregateFunction * create(const IDataType & second_type, TArgs && ... args)
{
    /// AT_SWITCH refers to this variable by name.
    const WhichDataType which(second_type);

#define MAKE_WITH_SECOND(Second) \
    case TypeIndex::Second: \
        return new AggregateFunctionAvgWeighted<First, Second>(std::forward<TArgs>(args)...)

    AT_SWITCH(MAKE_WITH_SECOND)

#undef MAKE_WITH_SECOND
}
/// First dispatch step. Maps the runtime type index of first_type to a compile-time type and
/// hands over to the single-type overload of create(), which dispatches on second_type.
/// Returns nullptr when first_type (or, further down, second_type) is not listed in AT_SWITCH.
///
/// Helper functions are not used here because there are no templates for a binary
/// decimal/numeric function.
template <class... TArgs>
IAggregateFunction * create(const IDataType & first_type, const IDataType & second_type, TArgs && ... args)
{
    /// AT_SWITCH refers to this variable by name.
    const WhichDataType which(first_type);

#define DISPATCH_ON_FIRST(FirstType) \
    case TypeIndex::FirstType: \
        return create<FirstType, TArgs...>(second_type, std::forward<TArgs>(args)...)

    AT_SWITCH(DISPATCH_ON_FIRST)

#undef DISPATCH_ON_FIRST
}
/// Creator for the weighted-average aggregate function, as registered in the factory.
///
/// @param name            Function name as requested by the caller; used in error messages.
/// @param argument_types  Types of the two arguments: [0] is the value, [1] is the weight.
/// @param parameters      Function parameters; checked by assertNoParameters.
/// @return                The aggregate function instantiated for the given pair of types; never null.
/// @throws Exception(ILLEGAL_TYPE_OF_ARGUMENT) if either type is not an integer, float or decimal,
///         or if no implementation could be instantiated for the pair.
///         assertNoParameters / assertBinary presumably throw on a wrong parameter or argument
///         count (their definitions are not visible here).
AggregateFunctionPtr
createAggregateFunctionAvgWeighted(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
    assertNoParameters(name, parameters);
    assertBinary(name, argument_types);

    /// Bind by reference: argument_types outlives this call, so copying the shared pointers
    /// (and touching their reference counters) is unnecessary.
    const auto & data_type = argument_types[0];
    const auto & data_type_weight = argument_types[1];

    if (!allowTypes(data_type, data_type_weight))
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
            "Types {} and {} are non-conforming as arguments for aggregate function {}",
            data_type->getName(), data_type_weight->getName(), name);

    AggregateFunctionPtr ptr;

    const bool left_decimal = isDecimal(data_type);
    const bool right_decimal = isDecimal(data_type_weight);

    /// We multiply value by weight, so actual scale of numerator is <scale of value> + <scale of weight>
    if (left_decimal && right_decimal)
        ptr.reset(create(*data_type, *data_type_weight,
            argument_types,
            getDecimalScale(*data_type) + getDecimalScale(*data_type_weight), getDecimalScale(*data_type_weight)));
    else if (left_decimal)
        ptr.reset(create(*data_type, *data_type_weight, argument_types,
            getDecimalScale(*data_type)));
    else if (right_decimal)
        /// The value is not a decimal here, so only the weight contributes to the numerator scale.
        ptr.reset(create(*data_type, *data_type_weight, argument_types,
            getDecimalScale(*data_type_weight), getDecimalScale(*data_type_weight)));
    else
        ptr.reset(create(*data_type, *data_type_weight, argument_types));

    /// create() returns nullptr for a type index outside its switch. allowTypes() is meant to
    /// rule that out, but the two lists are maintained separately, so never hand out a null function.
    if (!ptr)
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
            "Types {} and {} are non-conforming as arguments for aggregate function {}",
            data_type->getName(), data_type_weight->getName(), name);

    return ptr;
}
}
/// Registers the creator above in the given factory under the name "avgWeighted".
/// The registration is case-sensitive.
void registerAggregateFunctionAvgWeighted(AggregateFunctionFactory & factory)
{
    const auto case_sensitiveness = AggregateFunctionFactory::CaseSensitive;

    factory.registerFunction(
        "avgWeighted",
        createAggregateFunctionAvgWeighted,
        case_sensitiveness);
}
}
|