aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/AggregateFunctions/AggregateFunctionGroupBitmap.h
blob: a32bb330884b279fdbf97383ea434737f6d6a7ed (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
#pragma once

#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypesNumber.h>
#include <Common/assert_cast.h>

// TODO include this last because of a broken roaring header. See the comment inside.
#include <AggregateFunctions/AggregateFunctionGroupBitmapData.h>


namespace DB
{

/// Aggregate function over a single numeric argument (read as ColumnVector<T>):
/// every value is added to the state's `roaring_bitmap_with_small_set`, and the
/// result is that bitmap's size(). The function name is taken from Data::name().
template <typename T, typename Data>
class AggregateFunctionBitmap final : public IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>
{
    using Base = IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>;

public:
    explicit AggregateFunctionBitmap(const DataTypePtr & type)
        : Base({type}, {}, createResultType())
    {
    }

    String getName() const override { return Data::name(); }

    /// The result column uses the same numeric type T as the argument.
    static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }

    bool allocatesMemoryInArena() const override { return false; }

    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
    {
        const auto & values = assert_cast<const ColumnVector<T> &>(*columns[0]).getData();
        this->data(place).roaring_bitmap_with_small_set.add(values[row_num]);
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
    {
        auto & target = this->data(place).roaring_bitmap_with_small_set;
        const auto & source = this->data(rhs).roaring_bitmap_with_small_set;
        target.merge(source);
    }

    /// The state is written as the bitmap alone; the version argument is ignored.
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
    {
        const auto & bitmap = this->data(place).roaring_bitmap_with_small_set;
        bitmap.write(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
    {
        auto & bitmap = this->data(place).roaring_bitmap_with_small_set;
        bitmap.read(buf);
    }

    /// Appends the bitmap's size() to the result column.
    /// NOTE(review): the size is narrowed to T via static_cast, so counts that do
    /// not fit in a narrow T cannot be represented exactly — confirm this is intended.
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
    {
        auto & result = assert_cast<ColumnVector<T> &>(to).getData();
        const auto cardinality = this->data(place).roaring_bitmap_with_small_set.size();
        result.push_back(static_cast<T>(cardinality));
    }
};


/// Second-level aggregate function: its single argument is a column of
/// AggregateFunctionBitmap states (ColumnAggregateFunction). The first state seen
/// is merged into the accumulator as-is; every later one is combined via
/// `Policy::apply`. The result is the size() of the accumulated bitmap, and the
/// function name comes from `Policy::name`.
template <typename T, typename Data, typename Policy>
class AggregateFunctionBitmapL2 final : public IAggregateFunctionDataHelper<Data, AggregateFunctionBitmapL2<T, Data, Policy>>
{
private:
    using Base = IAggregateFunctionDataHelper<Data, AggregateFunctionBitmapL2<T, Data, Policy>>;

    /// Revision from which getVersionFromRevision() reports state version 1,
    /// the version that also serializes the `init` flag.
    static constexpr size_t STATE_VERSION_1_MIN_REVISION = 54455;

    /// Folds `rhs` into `lhs`. While `lhs` is not yet initialized, the incoming
    /// bitmap is merged in directly and `lhs` is marked initialized; afterwards
    /// the policy-specific combination is used.
    static void foldInto(Data & lhs, const Data & rhs)
    {
        if (lhs.init)
        {
            Policy::apply(lhs, rhs);
            return;
        }

        lhs.init = true;
        lhs.roaring_bitmap_with_small_set.merge(rhs.roaring_bitmap_with_small_set);
    }

public:
    explicit AggregateFunctionBitmapL2(const DataTypePtr & type)
        : Base({type}, {}, createResultType())
    {
    }

    String getName() const override { return Policy::name; }

    static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }

    bool allocatesMemoryInArena() const override { return false; }

    /// The intermediate state is reported with the same type as the argument column.
    DataTypePtr getStateType() const override { return this->argument_types.at(0); }

    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
    {
        const auto & states = assert_cast<const ColumnAggregateFunction &>(*columns[0]).getData();
        /// Unlike merge(), the incoming state's `init` flag is not consulted here:
        /// AggregateFunctionBitmap never sets it on the states it produces.
        foldInto(this->data(place), this->data(states[row_num]));
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
    {
        const Data & other = this->data(rhs);

        /// A partial state that was never initialized contributes nothing.
        if (other.init)
            foldInto(this->data(place), other);
    }

    bool isVersioned() const override { return true; }

    size_t getDefaultVersion() const override { return 1; }

    size_t getVersionFromRevision(size_t revision) const override
    {
        return revision < STATE_VERSION_1_MIN_REVISION ? 0 : 1;
    }

    /// Version >= 1 writes the `init` flag (writeBoolText) before the bitmap;
    /// version 0 writes only the bitmap. A missing version means getDefaultVersion().
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
    {
        const Data & state = this->data(place);

        if (version.value_or(getDefaultVersion()) >= 1)
            DB::writeBoolText(state.init, buf);

        state.roaring_bitmap_with_small_set.write(buf);
    }

    /// Mirror of serialize(): reads the `init` flag only for version >= 1.
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena *) const override
    {
        Data & state = this->data(place);

        if (version.value_or(getDefaultVersion()) >= 1)
            DB::readBoolText(state.init, buf);

        state.roaring_bitmap_with_small_set.read(buf);
    }

    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
    {
        auto & result = assert_cast<ColumnVector<T> &>(to).getData();
        const auto cardinality = this->data(place).roaring_bitmap_with_small_set.size();
        result.push_back(static_cast<T>(cardinality));
    }
};


/// Combining policy for groupBitmapAnd: folds the incoming state's bitmap into
/// the accumulator in place via `rb_and`.
template <typename Data>
class BitmapAndPolicy
{
public:
    static constexpr const char * name = "groupBitmapAnd";

    static void apply(Data & lhs, const Data & rhs)
    {
        auto & accumulator = lhs.roaring_bitmap_with_small_set;
        const auto & incoming = rhs.roaring_bitmap_with_small_set;
        accumulator.rb_and(incoming);
    }
};

/// Combining policy for groupBitmapOr: folds the incoming state's bitmap into
/// the accumulator in place via `rb_or`.
template <typename Data>
class BitmapOrPolicy
{
public:
    static constexpr const char * name = "groupBitmapOr";

    static void apply(Data & lhs, const Data & rhs)
    {
        auto & accumulator = lhs.roaring_bitmap_with_small_set;
        const auto & incoming = rhs.roaring_bitmap_with_small_set;
        accumulator.rb_or(incoming);
    }
};

/// Combining policy for groupBitmapXor: folds the incoming state's bitmap into
/// the accumulator in place via `rb_xor`.
template <typename Data>
class BitmapXorPolicy
{
public:
    static constexpr const char * name = "groupBitmapXor";

    static void apply(Data & lhs, const Data & rhs)
    {
        auto & accumulator = lhs.roaring_bitmap_with_small_set;
        const auto & incoming = rhs.roaring_bitmap_with_small_set;
        accumulator.rb_xor(incoming);
    }
};

/// AggregateFunctionBitmapL2 bound to each combining policy. The policy decides how
/// states are combined and also supplies the name that getName() returns.
template <typename T, typename Data> using AggregateFunctionBitmapL2And = AggregateFunctionBitmapL2<T, Data, BitmapAndPolicy<Data>>;
template <typename T, typename Data> using AggregateFunctionBitmapL2Or = AggregateFunctionBitmapL2<T, Data, BitmapOrPolicy<Data>>;
template <typename T, typename Data> using AggregateFunctionBitmapL2Xor = AggregateFunctionBitmapL2<T, Data, BitmapXorPolicy<Data>>;

}