aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/AggregateFunctions/AggregateFunctionUniqCombined.cpp
blob: 8c2cb6ea0de960770ccc8c7813adbb4cff094520 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#include <AggregateFunctions/AggregateFunctionUniqCombined.h>

#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/Helpers.h>

#include <Common/FieldVisitorConvertToNumber.h>

#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDate32.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeIPv4andIPv6.h>

#include <functional>


namespace DB
{

struct Settings;

namespace ErrorCodes
{
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ARGUMENT_OUT_OF_BOUND;
}

namespace
{
    /// Fixes the HyperLogLog precision K and the hash value type, and exposes the two
    /// uniqCombined class templates with those arguments already bound.
    template <UInt8 K, typename HashValueType>
    struct WithK
    {
        /// Implementation specialized for one argument whose values have type ValueType.
        template <typename ValueType>
        using AggregateFunction = AggregateFunctionUniqCombined<ValueType, K, HashValueType>;

        /// Generic implementation, used for a tuple argument, for several arguments,
        /// and as the fallback for single-argument types without a specialization.
        template <bool exact_hash, bool tuple_argument>
        using AggregateFunctionVariadic = AggregateFunctionUniqCombinedVariadic<exact_hash, tuple_argument, K, HashValueType>;
    };

    /// Creates the uniqCombined implementation for precision K and hash type HashValueType
    /// that fits the given argument types.
    ///
    /// For a single argument, createWithNumericType is tried first. If it returns nothing,
    /// Date, Date32, DateTime, String/FixedString, UUID, IPv4 and IPv6 each get a specialized
    /// instance, and a Tuple gets the variadic implementation in tuple mode. Every other case
    /// (several arguments, or one argument of another type) uses the variadic implementation.
    template <UInt8 K, typename HashValueType>
    AggregateFunctionPtr createAggregateFunctionWithK(const DataTypes & argument_types, const Array & params)
    {
        using Functions = WithK<K, HashValueType>;

        /// Only the exact hash function can cope with arguments that are not contiguous in memory.
        const bool exact_hash = !isAllArgumentsContiguousInMemory(argument_types);

        if (argument_types.size() == 1)
        {
            const IDataType & argument_type = *argument_types.front();

            AggregateFunctionPtr numeric(createWithNumericType<Functions::template AggregateFunction>(argument_type, argument_types, params));
            if (numeric)
                return numeric;

            WhichDataType which(argument_type);

            if (which.isDate())
                return std::make_shared<typename Functions::template AggregateFunction<DataTypeDate::FieldType>>(argument_types, params);
            if (which.isDate32())
                return std::make_shared<typename Functions::template AggregateFunction<DataTypeDate32::FieldType>>(argument_types, params);
            if (which.isDateTime())
                return std::make_shared<typename Functions::template AggregateFunction<DataTypeDateTime::FieldType>>(argument_types, params);
            if (which.isStringOrFixedString())
                return std::make_shared<typename Functions::template AggregateFunction<String>>(argument_types, params);
            if (which.isUUID())
                return std::make_shared<typename Functions::template AggregateFunction<DataTypeUUID::FieldType>>(argument_types, params);
            if (which.isIPv4())
                return std::make_shared<typename Functions::template AggregateFunction<DataTypeIPv4::FieldType>>(argument_types, params);
            if (which.isIPv6())
                return std::make_shared<typename Functions::template AggregateFunction<DataTypeIPv6::FieldType>>(argument_types, params);

            if (which.isTuple())
            {
                if (exact_hash)
                    return std::make_shared<typename Functions::template AggregateFunctionVariadic<true, true>>(argument_types, params);
                return std::make_shared<typename Functions::template AggregateFunctionVariadic<false, true>>(argument_types, params);
            }
        }

        /// The variadic implementation doubles as the generic fallback for a single argument of any other type.
        if (exact_hash)
            return std::make_shared<typename Functions::template AggregateFunctionVariadic<true, false>>(argument_types, params);
        return std::make_shared<typename Functions::template AggregateFunctionVariadic<false, false>>(argument_types, params);
    }

    /// Chooses the hash value width for precision K and forwards to createAggregateFunctionWithK.
    /// UInt64 is used when use_64_bit_hash is set (uniqCombined64); otherwise UInt32 (uniqCombined).
    template <UInt8 K>
    AggregateFunctionPtr createAggregateFunctionWithHashType(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params)
    {
        return use_64_bit_hash
            ? createAggregateFunctionWithK<K, UInt64>(argument_types, params)
            : createAggregateFunctionWithK<K, UInt32>(argument_types, params);
    }

    /// Factory shared by uniqCombined (use_64_bit_hash = false) and uniqCombined64 (use_64_bit_hash = true).
    ///
    /// @param use_64_bit_hash  use UInt64 rather than UInt32 as the hash value type.
    /// @param name             function name; used only in error messages.
    /// @param argument_types   argument types; at least one is required.
    /// @param params           optional single parameter: log2 of the number of HyperLogLog cells,
    ///                         in [12, 20]; defaults to 17.
    /// @throws Exception NUMBER_OF_ARGUMENTS_DOESNT_MATCH if more than one parameter or no argument is given.
    /// @throws Exception ARGUMENT_OUT_OF_BOUND if the precision parameter is outside [12, 20].
    AggregateFunctionPtr createAggregateFunctionUniqCombined(bool use_64_bit_hash,
        const std::string & name, const DataTypes & argument_types, const Array & params)
    {
        /// Supported precision range; the creators table below has one entry per value in it.
        constexpr UInt64 min_precision = 12;
        constexpr UInt64 max_precision = 20;

        /// Default precision, selected to be comparable in quality with the "uniq" aggregate function.
        UInt8 precision = 17;

        if (params.size() > 1)
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires one parameter or less.",
                name);

        if (params.size() == 1)
        {
            const UInt64 requested = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[0]);
            if (requested < min_precision || requested > max_precision)
                throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Parameter for aggregate function {} is out of range: [12, 20].",
                    name);
            /// Lossless: the value was just checked to lie within [12, 20].
            precision = static_cast<UInt8>(requested);
        }

        if (argument_types.empty())
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Incorrect number of arguments for aggregate function {}", name);

        /// One template instantiation per supported precision, indexed by (precision - min_precision).
        using Creator = AggregateFunctionPtr (*)(bool, const DataTypes &, const Array &);
        static constexpr Creator creators[] =
        {
            &createAggregateFunctionWithHashType<12>,
            &createAggregateFunctionWithHashType<13>,
            &createAggregateFunctionWithHashType<14>,
            &createAggregateFunctionWithHashType<15>,
            &createAggregateFunctionWithHashType<16>,
            &createAggregateFunctionWithHashType<17>,
            &createAggregateFunctionWithHashType<18>,
            &createAggregateFunctionWithHashType<19>,
            &createAggregateFunctionWithHashType<20>,
        };
        static_assert(sizeof(creators) / sizeof(creators[0]) == max_precision - min_precision + 1);

        return creators[precision - min_precision](use_64_bit_hash, argument_types, params);
    }

}

/// Registers uniqCombined and uniqCombined64 in the aggregate function factory.
/// Both go through createAggregateFunctionUniqCombined; they differ only in the
/// use_64_bit_hash flag (false for uniqCombined, true for uniqCombined64).
///
/// The unused `using namespace std::placeholders;` directive that used to live here was
/// removed: the creators are plain lambdas, so nothing needs std::bind placeholders.
void registerAggregateFunctionUniqCombined(AggregateFunctionFactory & factory)
{
    factory.registerFunction("uniqCombined",
        [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
        {
            return createAggregateFunctionUniqCombined(/* use_64_bit_hash = */ false, name, argument_types, parameters);
        });
    factory.registerFunction("uniqCombined64",
        [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
        {
            return createAggregateFunctionUniqCombined(/* use_64_bit_hash = */ true, name, argument_types, parameters);
        });
}

}