aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/AggregateFunctions/AggregateFunctionMap.cpp
blob: b957b541fe10086d5109760d051a2a9baa9182c4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include "AggregateFunctionMap.h"
#include "AggregateFunctions/AggregateFunctionCombinatorFactory.h"
#include "Functions/FunctionHelpers.h"

namespace DB
{
namespace ErrorCodes
{
    /// Error codes thrown by the Map combinator in this file (declared extern; defined in another translation unit).
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}

/// The `-Map` aggregate function combinator (e.g. `sumMap`, `minMap`, `maxMap`).
///
/// With a single Map(K, V) argument, the nested aggregate function receives the map's
/// value type V, and an AggregateFunctionMap instantiation keyed on K is built.
///
/// For backward compatibility it also accepts the legacy "mapped arrays" forms:
/// a single Tuple whose elements are all arrays, or two or more Array arguments.
/// In that mode only sum/min/max are supported, and they are redirected to the
/// `<name>MappedArrays` aggregate functions.
class AggregateFunctionCombinatorMap final : public IAggregateFunctionCombinator
{
public:
    String getName() const override { return "Map"; }

    /// Derives the argument types passed to the nested aggregate function.
    ///
    ///   Map(K, V)                           -> {V}
    ///   Tuple(Array(K), Array(V), ...)      -> {V}   (legacy form, exactly one tuple argument)
    ///   Array(K), Array(V), ...             -> {V}   (legacy form, two or more array arguments)
    ///
    /// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH if there are no arguments, or if a Map argument
    /// is followed by extra arguments. Throws ILLEGAL_TYPE_OF_ARGUMENT for any other shape.
    DataTypes transformArguments(const DataTypes & arguments) const override
    {
        if (arguments.empty())
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
                "Incorrect number of arguments for aggregate function with {} suffix", getName());

        const auto * map_type = checkAndGetDataType<DataTypeMap>(arguments[0].get());
        if (map_type)
        {
            if (arguments.size() > 1)
                throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "{} combinator takes only one map argument", getName());

            return DataTypes({map_type->getValueType()});
        }

        // The legacy forms are only recognised here so that transformAggregateFunction can
        // redirect them to the *MappedArrays functions.
        // The predicate takes the pointer by const reference: it only inspects the type, so
        // copying the shared_ptr (an atomic refcount inc/dec per element) is pure overhead.
        auto check_func = [](const DataTypePtr & t) { return t->getTypeId() == TypeIndex::Array; };

        const auto * tup_type = checkAndGetDataType<DataTypeTuple>(arguments[0].get());
        if (tup_type)
        {
            const auto & types = tup_type->getElements();
            bool arrays_match = arguments.size() == 1 && types.size() >= 2 && std::all_of(types.begin(), types.end(), check_func);
            if (arrays_match)
            {
                // Element 1 of the tuple holds the values; the nested function aggregates their element type.
                const auto * val_array_type = assert_cast<const DataTypeArray *>(types[1].get());
                return DataTypes({val_array_type->getNestedType()});
            }
        }
        else
        {
            bool arrays_match = arguments.size() >= 2 && std::all_of(arguments.begin(), arguments.end(), check_func);
            if (arrays_match)
            {
                // Argument 1 holds the values; the nested function aggregates their element type.
                const auto * val_array_type = assert_cast<const DataTypeArray *>(arguments[1].get());
                return DataTypes({val_array_type->getNestedType()});
            }
        }

        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} requires map as argument", getName());
    }

    /// Builds the combined aggregate function.
    ///
    /// For a Map argument, instantiates AggregateFunctionMap on a key representation chosen
    /// from the map's key type. Several logical types share an instantiation:
    /// Enum8 -> Int8, Enum16 -> Int16, Date -> UInt16, DateTime -> UInt32, FixedString -> String.
    /// Any other key type throws ILLEGAL_TYPE_OF_ARGUMENT.
    ///
    /// For the legacy mapped-arrays forms (already validated by transformArguments), only
    /// sum/min/max are supported. They are resolved through AggregateFunctionFactory as
    /// `<name>MappedArrays` with the original arguments and parameters. Any other nested
    /// function throws ILLEGAL_TYPE_OF_ARGUMENT.
    AggregateFunctionPtr transformAggregateFunction(
        const AggregateFunctionPtr & nested_function,
        const AggregateFunctionProperties &,
        const DataTypes & arguments,
        const Array & params) const override
    {
        const auto * map_type = checkAndGetDataType<DataTypeMap>(arguments[0].get());
        if (map_type)
        {
            const auto & key_type = map_type->getKeyType();

            switch (key_type->getTypeId())
            {
                case TypeIndex::Enum8:
                case TypeIndex::Int8:
                    return std::make_shared<AggregateFunctionMap<Int8>>(nested_function, arguments);
                case TypeIndex::Enum16:
                case TypeIndex::Int16:
                    return std::make_shared<AggregateFunctionMap<Int16>>(nested_function, arguments);
                case TypeIndex::Int32:
                    return std::make_shared<AggregateFunctionMap<Int32>>(nested_function, arguments);
                case TypeIndex::Int64:
                    return std::make_shared<AggregateFunctionMap<Int64>>(nested_function, arguments);
                case TypeIndex::Int128:
                    return std::make_shared<AggregateFunctionMap<Int128>>(nested_function, arguments);
                case TypeIndex::Int256:
                    return std::make_shared<AggregateFunctionMap<Int256>>(nested_function, arguments);
                case TypeIndex::UInt8:
                    return std::make_shared<AggregateFunctionMap<UInt8>>(nested_function, arguments);
                case TypeIndex::Date:
                case TypeIndex::UInt16:
                    return std::make_shared<AggregateFunctionMap<UInt16>>(nested_function, arguments);
                case TypeIndex::DateTime:
                case TypeIndex::UInt32:
                    return std::make_shared<AggregateFunctionMap<UInt32>>(nested_function, arguments);
                case TypeIndex::UInt64:
                    return std::make_shared<AggregateFunctionMap<UInt64>>(nested_function, arguments);
                case TypeIndex::UInt128:
                    return std::make_shared<AggregateFunctionMap<UInt128>>(nested_function, arguments);
                case TypeIndex::UInt256:
                    return std::make_shared<AggregateFunctionMap<UInt256>>(nested_function, arguments);
                case TypeIndex::UUID:
                    return std::make_shared<AggregateFunctionMap<UUID>>(nested_function, arguments);
                case TypeIndex::IPv4:
                    return std::make_shared<AggregateFunctionMap<IPv4>>(nested_function, arguments);
                case TypeIndex::IPv6:
                    return std::make_shared<AggregateFunctionMap<IPv6>>(nested_function, arguments);
                case TypeIndex::FixedString:
                case TypeIndex::String:
                    return std::make_shared<AggregateFunctionMap<String>>(nested_function, arguments);
                default:
                    throw Exception(
                        ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                        "Map key type {} is not supported by combinator {}", key_type->getName(), getName());
            }
        }

        // Tuple of arrays or plain arrays (checked in transformArguments): redirect
        // sum/min/max to the *MappedArrays functions to preserve the old behavior.
        const auto nested_func_name = nested_function->getName();
        if (nested_func_name != "sum" && nested_func_name != "min" && nested_func_name != "max")
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregation '{}Map' is not implemented for mapped arrays",
                             nested_func_name);

        AggregateFunctionProperties out_properties;
        auto & aggr_func_factory = AggregateFunctionFactory::instance();
        return aggr_func_factory.get(nested_func_name + "MappedArrays", arguments, params, out_properties);
    }
};

/// Makes the `-Map` combinator available to the aggregate function combinator factory.
void registerAggregateFunctionCombinatorMap(AggregateFunctionCombinatorFactory & factory)
{
    auto map_combinator = std::make_shared<AggregateFunctionCombinatorMap>();
    factory.registerCombinator(map_combinator);
}

}