aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/AggregateFunctions/AggregateFunctionSimpleState.h
blob: 3af7d71395a4dbf3091caa3d116468b5f3206952 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#pragma once

#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <DataTypes/DataTypeCustomSimpleAggregateFunction.h>
#include <DataTypes/DataTypeFactory.h>

namespace DB
{
struct Settings;

/** Not an aggregate function, but an adapter of aggregate functions.
  * Aggregate functions with the `SimpleState` suffix are almost identical to the corresponding ones,
  * except the return type becomes DataTypeCustomSimpleAggregateFunction.
  */
class AggregateFunctionSimpleState final : public IAggregateFunctionHelper<AggregateFunctionSimpleState>
{
public:
    /// Wraps `nested_`. The result type passed to the base helper is computed by createResultType();
    /// every state-handling call is forwarded to the wrapped function.
    AggregateFunctionSimpleState(AggregateFunctionPtr nested_, const DataTypes & arguments_, const Array & params_)
        : IAggregateFunctionHelper<AggregateFunctionSimpleState>(arguments_, params_, createResultType(nested_, params_))
        , nested_func(nested_)
    {
    }

    /// Name of the wrapped function with the "SimpleState" suffix appended.
    String getName() const override
    {
        return nested_func->getName() + "SimpleState";
    }

    /// Builds the result type: the nested function's result type, re-created through the type factory
    /// and customized with a DataTypeCustomSimpleAggregateFunction name.
    /// The nested function is first passed through checkSupportedFunctions().
    static DataTypePtr createResultType(const AggregateFunctionPtr & nested_, const Array & params_)
    {
        DataTypeCustomSimpleAggregateFunction::checkSupportedFunctions(nested_);

        auto & type_factory = DataTypeFactory::instance();

        // The argument type is looked up by name so that we get a separate instance
        // and do not end up with a recursive reference.
        DataTypes promoted_arguments{type_factory.get(nested_->getResultType()->getName())};

        // SimpleAggregates require arg_type = return_type, so the function is re-created
        // with its own result type as the (promoted) argument type.
        AggregateFunctionProperties properties;
        auto promoted_function = AggregateFunctionFactory::instance().get(
            nested_->getName(), promoted_arguments, nested_->getParameters(), properties);

        // A second, independent instance of the type: this one gets customized below.
        auto customized_type = type_factory.get(nested_->getResultType()->getName());

        DataTypeCustomNamePtr simple_aggregate_name = std::make_unique<DataTypeCustomSimpleAggregateFunction>(
            promoted_function, DataTypes{nested_->getResultType()}, params_);
        auto customization = std::make_unique<DataTypeCustomDesc>(std::move(simple_aggregate_name), nullptr);
        customized_type->setCustomization(std::move(customization));

        return customized_type;
    }

    /// Versioning and state-related properties are taken from the wrapped function.
    bool isVersioned() const override { return nested_func->isVersioned(); }

    size_t getDefaultVersion() const override { return nested_func->getDefaultVersion(); }

    bool isState() const override { return nested_func->isState(); }

    size_t getVersionFromRevision(size_t revision) const override { return nested_func->getVersionFromRevision(revision); }

    /// State lifetime management is forwarded to the wrapped function.
    void create(AggregateDataPtr __restrict place) const override
    {
        nested_func->create(place);
    }

    void destroy(AggregateDataPtr __restrict place) const noexcept override
    {
        nested_func->destroy(place);
    }

    void destroyUpToState(AggregateDataPtr __restrict place) const noexcept override
    {
        nested_func->destroyUpToState(place);
    }

    bool hasTrivialDestructor() const override
    {
        return nested_func->hasTrivialDestructor();
    }

    /// Size and alignment of the state are those reported by the wrapped function.
    size_t sizeOfData() const override
    {
        return nested_func->sizeOfData();
    }

    size_t alignOfData() const override
    {
        return nested_func->alignOfData();
    }

    /// Accumulation, merging and (de)serialization are forwarded unchanged.
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
    {
        nested_func->add(place, columns, row_num, arena);
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
    {
        nested_func->merge(place, rhs, arena);
    }

    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
    {
        nested_func->serialize(place, buf, version);
    }

    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
    {
        nested_func->deserialize(place, buf, version, arena);
    }

    /// The final value is produced by the wrapped function's insertResultInto().
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
    {
        nested_func->insertResultInto(place, to, arena);
    }

    bool allocatesMemoryInArena() const override
    {
        return nested_func->allocatesMemoryInArena();
    }

    /// Gives access to the wrapped function.
    AggregateFunctionPtr getNestedFunction() const override
    {
        return nested_func;
    }

private:
    /// The aggregate function being adapted.
    AggregateFunctionPtr nested_func;
};

}