aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/AggregateFunctions/AggregateFunctionNull.cpp
blob: 3d3d7af30260085fa14b4fc79c34e0bf27b9fafc (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#include <DataTypes/DataTypeNullable.h>
#include <AggregateFunctions/AggregateFunctionNull.h>
#include <AggregateFunctions/AggregateFunctionNothing.h>
#include <AggregateFunctions/AggregateFunctionCount.h>
#include <AggregateFunctions/AggregateFunctionState.h>
#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>


namespace DB
{

namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}

namespace
{

/// The internal -Null combinator.
/// It strips Nullable from the argument types seen by the nested function.
/// It then wraps the nested function in an adapter that deals with NULL rows.
/// If some argument can only ever be NULL, it substitutes an AggregateFunctionNothing.
class AggregateFunctionCombinatorNull final : public IAggregateFunctionCombinator
{
public:
    String getName() const override { return "Null"; }

    bool isForInternalUsageOnly() const override { return true; }

    /// Returns the argument types with their Nullable wrapper removed.
    /// Nullable(Nothing) is the exception: it is passed through unchanged because it is handled
    /// separately and must not collapse into plain Nothing.
    DataTypes transformArguments(const DataTypes & arguments) const override
    {
        DataTypes stripped;
        stripped.reserve(arguments.size());
        for (const auto & type : arguments)
            stripped.push_back(type->onlyNull() ? type : removeNullable(type));
        return stripped;
    }

    /// Builds the function that is actually executed for a call with Nullable arguments.
    /// Throws ILLEGAL_TYPE_OF_ARGUMENT when no argument is Nullable.
    AggregateFunctionPtr transformAggregateFunction(
        const AggregateFunctionPtr & nested_function,
        const AggregateFunctionProperties & properties,
        const DataTypes & arguments,
        const Array & params) const override
    {
        /// Argument positions that the nested function declares it accepts as Nullable(Nothing).
        std::unordered_set<size_t> allowed_only_null;
        if (nested_function)
            allowed_only_null = nested_function->getArgumentsThatCanBeOnlyNull();

        bool any_nullable = false;
        bool any_only_null = false;
        /// The scan stops at the first argument that is only-NULL and not explicitly allowed to be.
        for (size_t idx = 0; idx < arguments.size() && !any_only_null; ++idx)
        {
            if (!arguments[idx]->isNullable())
                continue;

            any_nullable = true;
            any_only_null = arguments[idx]->onlyNull() && !allowed_only_null.contains(idx);
        }

        if (!any_nullable)
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function combinator 'Null' "
                            "requires at least one argument to be Nullable");

        if (any_only_null)
        {
            /// Some argument is always NULL, so the result is fixed.
            /// Functions that return a default on all-NULL input (per the original comment, count and uniq)
            /// get a UInt64 result; every other function gets Nullable(Nothing).
            if (!properties.returns_default_when_only_null)
                return std::make_shared<AggregateFunctionNothing>(arguments, params, std::make_shared<DataTypeNullable>(std::make_shared<DataTypeNothing>()));
            return std::make_shared<AggregateFunctionNothing>(arguments, params, std::make_shared<DataTypeUInt64>());
        }

        assert(nested_function);

        /// A nested function may provide its own NULL adapter; that takes precedence.
        if (auto own_adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params, properties))
            return own_adapter;

        /// For a -State function, apply -Null to the function it wraps and then re-wrap the result.
        /// A Nullable AggregateFunctionState makes no sense and would break the management of
        /// aggregate function states.
        if (const auto * state_function = typeid_cast<const AggregateFunctionState *>(nested_function.get()))
        {
            auto inner = transformAggregateFunction(state_function->getNestedFunction(), properties, arguments, params);
            return std::make_shared<AggregateFunctionState>(inner, inner->getArgumentTypes(), inner->getParameters());
        }

        const bool result_is_nullable = !properties.returns_default_when_only_null
            && nested_function->getResultType()->canBeInsideNullable();
        const bool keep_flag = result_is_nullable || properties.returns_default_when_only_null;

        if (arguments.size() != 1)
        {
            if (result_is_nullable)
                return std::make_shared<AggregateFunctionNullVariadic<true, true>>(nested_function, arguments, params);

            /// NOTE(review): unlike the unary path below, the variadic adapter is always built with a
            /// second template argument of `true`, even when keep_flag is false. This looks like a
            /// copy-paste slip. However, that argument appears to control serialization of the
            /// presence flag, so switching to <false, false> would presumably change the format of
            /// stored states. It is kept as-is; confirm compatibility before changing it.
            return std::make_shared<AggregateFunctionNullVariadic<false, true>>(nested_function, arguments, params);
        }

        if (result_is_nullable)
            return std::make_shared<AggregateFunctionNullUnary<true, true>>(nested_function, arguments, params);
        if (keep_flag)
            return std::make_shared<AggregateFunctionNullUnary<false, true>>(nested_function, arguments, params);
        return std::make_shared<AggregateFunctionNullUnary<false, false>>(nested_function, arguments, params);
    }
};

}

/// Registers the internal -Null combinator with the given combinator factory.
void registerAggregateFunctionCombinatorNull(AggregateFunctionCombinatorFactory & factory)
{
    auto null_combinator = std::make_shared<AggregateFunctionCombinatorNull>();
    factory.registerCombinator(null_combinator);
}

}