diff options
author | chertus <azuikov@ydb.tech> | 2023-01-23 15:29:29 +0300 |
---|---|---|
committer | chertus <azuikov@ydb.tech> | 2023-01-23 15:29:29 +0300 |
commit | 08d7aec1386d15c53ded0a9239d592e74fce66df (patch) | |
tree | 1be1e267cb268540bc7ad00289a7ffd7bb627e63 | |
parent | 0e91c7a392b2e6d46a4dde191a29073dfc8d434b (diff) | |
download | ydb-08d7aec1386d15c53ded0a9239d592e74fce66df.tar.gz |
fix nulls in SSA aggregates
8 files changed, 283 insertions, 125 deletions
diff --git a/ydb/core/formats/ut_program_step.cpp b/ydb/core/formats/ut_program_step.cpp index a043779ef3..f89b9f0f7a 100644 --- a/ydb/core/formats/ut_program_step.cpp +++ b/ydb/core/formats/ut_program_step.cpp @@ -56,68 +56,189 @@ size_t FilterTestUnary(std::vector<std::shared_ptr<arrow::Array>> args, EOperati return batch->num_rows(); } -void SumGroupBy(bool nullable, ui32 numKeys = 1, bool emptySrc = false) { - std::optional<double> null; - if (nullable) { - null = 0; +enum class ETest { + DEFAULT, + EMPTY, + ONE_VALUE +}; + +struct TSumData { + static std::shared_ptr<arrow::RecordBatch> Data(ETest test, + std::shared_ptr<arrow::Schema>& schema, + bool nullable) + { + std::optional<double> null; + if (nullable) { + null = 0; + } + + if (test == ETest::DEFAULT) { + return arrow::RecordBatch::Make(schema, 4, std::vector{NumVecToArray(arrow::int16(), {-1, 0, 0, -1}, null), + NumVecToArray(arrow::uint32(), {1, 0, 0, 1}, null)}); + } else if (test == ETest::EMPTY) { + return arrow::RecordBatch::Make(schema, 0, std::vector{NumVecToArray(arrow::int16(), {}), + NumVecToArray(arrow::uint32(), {})}); + } else if (test == ETest::ONE_VALUE) { + return arrow::RecordBatch::Make(schema, 1, std::vector{NumVecToArray(arrow::int16(), {1}), + NumVecToArray(arrow::uint32(), {0}, null)}); + } + return {}; } + static void CheckResult(ETest test, const std::shared_ptr<arrow::RecordBatch>& batch, ui32 numKeys, bool nullable) { + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), numKeys + 2); + UNIT_ASSERT_EQUAL(batch->column(0)->type_id(), arrow::Type::INT64); + UNIT_ASSERT_EQUAL(batch->column(1)->type_id(), arrow::Type::UINT64); + UNIT_ASSERT_EQUAL(batch->column(2)->type_id(), arrow::Type::INT16); + if (numKeys == 2) { + UNIT_ASSERT_EQUAL(batch->column(3)->type_id(), arrow::Type::UINT32); + } + + if (test == ETest::EMPTY) { + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 0); + return; + } + + auto& aggX = static_cast<arrow::Int64Array&>(*batch->column(0)); + auto& aggY = static_cast<arrow::UInt64Array&>(*batch->column(1)); + auto& colX = static_cast<arrow::Int16Array&>(*batch->column(2)); + + if (test == ETest::ONE_VALUE) { + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + + UNIT_ASSERT_VALUES_EQUAL(aggX.Value(0), 1); + if (nullable) { + UNIT_ASSERT(aggY.IsNull(0)); + } else { + UNIT_ASSERT(!aggY.IsNull(0)); + UNIT_ASSERT_VALUES_EQUAL(aggY.Value(0), 0); + } + return; + } + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 2); + + for (ui32 row = 0; row < 2; ++row) { + if (colX.IsNull(row)) { + UNIT_ASSERT(aggX.IsNull(row)); + UNIT_ASSERT(aggY.IsNull(row)); + } else { + UNIT_ASSERT(!aggX.IsNull(row)); + UNIT_ASSERT(!aggY.IsNull(row)); + if (colX.Value(row) == 0) { + UNIT_ASSERT_VALUES_EQUAL(aggX.Value(row), 0); + UNIT_ASSERT_VALUES_EQUAL(aggY.Value(row), 0); + } else if (colX.Value(row) == -1) { + UNIT_ASSERT_VALUES_EQUAL(aggX.Value(row), -2); + UNIT_ASSERT_VALUES_EQUAL(aggY.Value(row), 2); + } else { + UNIT_ASSERT(false); + } + } + } + } +}; + +struct TMinMaxSomeData { + static std::shared_ptr<arrow::RecordBatch> Data(ETest /*test*/, + std::shared_ptr<arrow::Schema>& schema, + bool nullable) + { + std::optional<double> null; + if (nullable) { + null = 0; + } + + return arrow::RecordBatch::Make(schema, 1, std::vector{NumVecToArray(arrow::int16(), {1}), + NumVecToArray(arrow::uint32(), {0}, null)}); + } + + static void CheckResult(ETest /*test*/, const std::shared_ptr<arrow::RecordBatch>& batch, ui32 numKeys, + bool nullable) { + UNIT_ASSERT_VALUES_EQUAL(numKeys, 1); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), numKeys + 2); + UNIT_ASSERT_EQUAL(batch->column(0)->type_id(), arrow::Type::INT16); + UNIT_ASSERT_EQUAL(batch->column(1)->type_id(), arrow::Type::UINT32); + UNIT_ASSERT_EQUAL(batch->column(2)->type_id(), arrow::Type::INT16); + + auto& aggX = static_cast<arrow::Int16Array&>(*batch->column(0)); + auto& aggY = static_cast<arrow::UInt32Array&>(*batch->column(1)); + auto& colX = static_cast<arrow::Int16Array&>(*batch->column(2)); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + + UNIT_ASSERT_VALUES_EQUAL(colX.Value(0), 1); + UNIT_ASSERT_VALUES_EQUAL(aggX.Value(0), 1); + if (nullable) { + UNIT_ASSERT(aggY.IsNull(0)); + } else { + UNIT_ASSERT(!aggY.IsNull(0)); + UNIT_ASSERT_VALUES_EQUAL(aggY.Value(0), 0); + } + return; + } +}; + +void GroupByXY(bool nullable, ui32 numKeys, ETest test = ETest::DEFAULT, + EAggregate aggFunc = EAggregate::Sum) { auto schema = std::make_shared<arrow::Schema>(std::vector{ std::make_shared<arrow::Field>("x", arrow::int16()), std::make_shared<arrow::Field>("y", arrow::uint32())}); + std::shared_ptr<arrow::RecordBatch> batch; - if (emptySrc) { - batch = arrow::RecordBatch::Make(schema, 0, std::vector{NumVecToArray(arrow::int16(), {}), - NumVecToArray(arrow::uint32(), {})}); - } else { - batch = arrow::RecordBatch::Make(schema, 4, std::vector{NumVecToArray(arrow::int16(), {-1, 0, 0, -1}, null), - NumVecToArray(arrow::uint32(), {1, 0, 0, 1}, null)}); + switch (aggFunc) { + case EAggregate::Sum: + batch = TSumData::Data(test, schema, nullable); + break; + case EAggregate::Min: + case EAggregate::Max: + case EAggregate::Some: + batch = TMinMaxSomeData::Data(test, schema, nullable); + break; + default: + break; } - UNIT_ASSERT(batch->ValidateFull().ok()); + UNIT_ASSERT(batch); + auto status = batch->ValidateFull(); + if (!status.ok()) { + Cerr << status.ToString() << "\n"; + } + UNIT_ASSERT(status.ok()); auto step = std::make_shared<TProgramStep>(); step->GroupBy = { - TAggregateAssign("sum_x", EAggregate::Sum, {"x"}), - TAggregateAssign("sum_y", EAggregate::Sum, {"y"}) + TAggregateAssign("agg_x", aggFunc, {"x"}), + TAggregateAssign("agg_y", aggFunc, {"y"}) }; step->GroupByKeys.push_back("x"); if (numKeys == 2) { step->GroupByKeys.push_back("y"); } - auto status = ApplyProgram(batch, TProgram({step}), GetCustomExecContext()); + status = ApplyProgram(batch, TProgram({step}), GetCustomExecContext()); if (!status.ok()) { Cerr << status.ToString() << "\n"; } UNIT_ASSERT(status.ok()); - UNIT_ASSERT(batch->ValidateFull().ok()); - UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), numKeys + 2); - UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), (emptySrc ? 0 : 2)); - UNIT_ASSERT_EQUAL(batch->column(0)->type_id(), arrow::Type::INT64); - UNIT_ASSERT_EQUAL(batch->column(1)->type_id(), arrow::Type::UINT64); - UNIT_ASSERT_EQUAL(batch->column(2)->type_id(), arrow::Type::INT16); - if (numKeys == 2) { - UNIT_ASSERT_EQUAL(batch->column(3)->type_id(), arrow::Type::UINT32); - } - if (emptySrc) { - return; + status = batch->ValidateFull(); + if (!status.ok()) { + Cerr << status.ToString() << "\n"; } + UNIT_ASSERT(status.ok()); - auto& sumX = static_cast<arrow::Int64Array&>(*batch->column(0)); - auto& sumY = static_cast<arrow::UInt64Array&>(*batch->column(1)); - auto& colX = static_cast<arrow::Int16Array&>(*batch->column(2)); - - for (ui32 row = 0; row < 2; ++row) { - if (colX.IsNull(row) || colX.Value(row) == 0) { - UNIT_ASSERT_VALUES_EQUAL(sumX.Value(row), 0); - UNIT_ASSERT_VALUES_EQUAL(sumY.Value(row), 0); - } else if (colX.Value(row) == -1) { - UNIT_ASSERT_VALUES_EQUAL(sumX.Value(row), -2); - UNIT_ASSERT_VALUES_EQUAL(sumY.Value(row), 2); - } else { - UNIT_ASSERT(false); - } + switch (aggFunc) { + case EAggregate::Sum: + TSumData::CheckResult(test, batch, numKeys, nullable); + break; + case EAggregate::Min: + case EAggregate::Max: + case EAggregate::Some: + TMinMaxSomeData::CheckResult(test, batch, numKeys, nullable); + break; + default: + break; } } @@ -326,18 +447,34 @@ Y_UNIT_TEST_SUITE(ProgramStep) { } Y_UNIT_TEST(SumGroupBy) { - SumGroupBy(true); - SumGroupBy(true, 2); + GroupByXY(true, 1); + GroupByXY(true, 2); - SumGroupBy(true, 1, true); - SumGroupBy(true, 2, true); + GroupByXY(true, 1, ETest::EMPTY); + GroupByXY(true, 2, ETest::EMPTY); + + GroupByXY(true, 1, ETest::ONE_VALUE); } Y_UNIT_TEST(SumGroupByNotNull) { - SumGroupBy(false); - SumGroupBy(false, 2); + GroupByXY(false, 1); + GroupByXY(false, 2); + + GroupByXY(false, 1, ETest::EMPTY); + GroupByXY(false, 2, ETest::EMPTY); + + GroupByXY(false, 1, ETest::ONE_VALUE); + } + + Y_UNIT_TEST(MinMaxSomeGroupBy) { + GroupByXY(true, 1, ETest::ONE_VALUE, EAggregate::Min); + GroupByXY(true, 1, ETest::ONE_VALUE, EAggregate::Max); + GroupByXY(true, 1, ETest::ONE_VALUE, EAggregate::Some); + } - SumGroupBy(false, 1, true); - SumGroupBy(false, 2, true); + Y_UNIT_TEST(MinMaxSomeGroupByNotNull) { + GroupByXY(false, 1, ETest::ONE_VALUE, EAggregate::Min); + GroupByXY(false, 1, ETest::ONE_VALUE, EAggregate::Max); + GroupByXY(false, 1, ETest::ONE_VALUE, EAggregate::Some); } } diff --git a/ydb/core/kqp/ut/olap/kqp_olap_ut.cpp b/ydb/core/kqp/ut/olap/kqp_olap_ut.cpp index 1b1e01eb56..2f01118c1d 100644 --- a/ydb/core/kqp/ut/olap/kqp_olap_ut.cpp +++ b/ydb/core/kqp/ut/olap/kqp_olap_ut.cpp @@ -2137,8 +2137,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) { } Y_UNIT_TEST(Aggregation_Avg_NullGroupBy) { - // Wait for KIKIMR-16831 fix - return; TAggregationTestCase testCase; testCase.SetQuery(R"( SELECT @@ -2159,8 +2157,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) { } Y_UNIT_TEST(Aggregation_Avg_NullMixGroupBy) { - // Wait for KIKIMR-16831 fix - return; TAggregationTestCase testCase; testCase.SetQuery(R"( SELECT @@ -2181,8 +2177,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) { } Y_UNIT_TEST(Aggregation_Avg_GroupByNull) { - // Wait for KIKIMR-16831 fix - return; TAggregationTestCase testCase; testCase.SetQuery(R"( SELECT @@ -2203,8 +2197,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) { } Y_UNIT_TEST(Aggregation_Avg_GroupByNullMix) { - // Wait for KIKIMR-16831 fix - return; TAggregationTestCase testCase; testCase.SetQuery(R"( SELECT @@ -2297,8 +2289,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) { } Y_UNIT_TEST(Aggregation_Sum_NullGroupBy) { - // Wait for KIKIMR-16831 fix - return; TAggregationTestCase testCase; testCase.SetQuery(R"( SELECT @@ -2319,8 +2309,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) { } Y_UNIT_TEST(Aggregation_Sum_NullMixGroupBy) { - // Wait for KIKIMR-16831 fix - return; TAggregationTestCase testCase; testCase.SetQuery(R"( SELECT @@ -2341,8 +2329,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) { } Y_UNIT_TEST(Aggregation_Sum_GroupByNull) { - // Wait for KIKIMR-16831 fix - return; TAggregationTestCase testCase; testCase.SetQuery(R"( SELECT @@ -2363,8 +2349,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) { } Y_UNIT_TEST(Aggregation_Sum_GroupByNullMix) { - // Wait for KIKIMR-16831 fix - return; TAggregationTestCase testCase; testCase.SetQuery(R"( SELECT @@ -2510,8 +2494,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) { } Y_UNIT_TEST(Aggregation_Some_Null) { - // Wait for KIKIMR-16831 fix - return; TAggregationTestCase testCase; testCase.SetQuery(R"( SELECT SOME(level) FROM `/Root/tableWithNulls` WHERE id > 5 @@ -2546,8 +2528,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) { } Y_UNIT_TEST(Aggregation_Some_NullGroupBy) { - // Wait for KIKIMR-16831 fix - return; TAggregationTestCase testCase; testCase.SetQuery(R"( SELECT @@ -2568,8 +2548,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) { } Y_UNIT_TEST(Aggregation_Some_NullMixGroupBy) { - // Wait for KIKIMR-16831 fix - return; TAggregationTestCase testCase; testCase.SetQuery(R"( SELECT @@ -2590,8 +2568,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) { } Y_UNIT_TEST(Aggregation_Some_GroupByNullMix) { - // Wait for KIKIMR-16831 fix - return; TAggregationTestCase testCase; testCase.SetQuery(R"( SELECT @@ -2612,8 +2588,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) { } Y_UNIT_TEST(Aggregation_Some_GroupByNull) { - // Wait for KIKIMR-16831 fix - return; TAggregationTestCase testCase; testCase.SetQuery(R"( SELECT diff --git a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionAvg.h b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionAvg.h index 598afb54a3..508abe059f 100644 --- a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionAvg.h +++ b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionAvg.h @@ -119,16 +119,17 @@ public: size_t row_end, AggregateDataPtr __restrict place, const IColumn ** columns, - Arena *, - ssize_t if_argument_pos) const final + Arena *) const final { AggregateFunctionSumData<Numerator> sum_data; const auto & column = assert_cast<const ColumnType &>(*columns[0]); - if (if_argument_pos >= 0) + if (auto * flags = column.null_bitmap_data()) { - const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).raw_values(); - sum_data.addManyConditional(column.raw_values(), flags, row_begin, row_end); - this->data(place).denominator += countBytesInFilter(flags, row_begin, row_end); + auto * condition_map = flags + column.offset(); + auto length = row_end - row_begin; + + sum_data.addManyConditional(column.raw_values(), condition_map, row_begin, row_end); + this->data(place).denominator += arrow::internal::CountSetBits(condition_map, row_begin, length); } else { diff --git a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionCount.h b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionCount.h index 0994cadc1f..8432027620 100644 --- a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionCount.h +++ b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionCount.h @@ -18,6 +18,24 @@ namespace CH struct AggregateFunctionCountData { UInt64 count = 0; + bool has_value = false; + + bool has() const + { + return has_value; + } + + void inc() + { + has_value = true; + ++count; + } + + void add(UInt64 value) + { + has_value = true; + count += value; + } }; @@ -38,7 +56,7 @@ public: void add(AggregateDataPtr __restrict place, const IColumn **, size_t, Arena *) const override { - ++data(place).count; + data(place).inc(); } void addBatchSinglePlace( @@ -46,35 +64,32 @@ public: size_t row_end, AggregateDataPtr __restrict place, const IColumn ** columns, - Arena *, - ssize_t if_argument_pos) const override + Arena *) const override { - if (if_argument_pos >= 0) + const auto & column = *columns[0]; + if (auto * flags = column.null_bitmap_data()) { - const auto & filter_column = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]); - const auto & flags = filter_column.raw_values(); - data(place).count += countBytesInFilter(flags, row_begin, row_end); + auto * condition_map = flags + column.offset(); + auto length = row_end - row_begin; + data(place).add(arrow::internal::CountSetBits(condition_map, row_begin, length)); } else { - data(place).count += row_end - row_begin; + data(place).add(row_end - row_begin); } } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { - data(place).count += data(rhs).count; + data(place).add(data(rhs).count); } void insertResultInto(AggregateDataPtr __restrict place, MutableColumn & to, Arena *) const override { - assert_cast<MutableColumnUInt64 &>(to).Append(data(place).count).ok(); - } - - /// Reset the state to specified value. This function is not the part of common interface. - void set(AggregateDataPtr __restrict place, UInt64 new_count) const - { - data(place).count = new_count; + if (data(place).has()) + assert_cast<MutableColumnUInt64 &>(to).Append(data(place).count).ok(); + else + assert_cast<MutableColumnUInt64 &>(to).AppendNull().ok(); } }; diff --git a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionMinMaxAny.h b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionMinMaxAny.h index c8ce2884c7..f323ea50d2 100644 --- a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionMinMaxAny.h +++ b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionMinMaxAny.h @@ -38,7 +38,7 @@ public: if (has()) assert_cast<MutableColumnType &>(to).Append(value).ok(); else - assert_cast<MutableColumnType &>(to).AppendEmptyValue().ok(); + assert_cast<MutableColumnType &>(to).AppendNull().ok(); } void change(const IColumn & column, size_t row_num, Arena *) @@ -201,7 +201,7 @@ public: if (has()) assert_cast<MutableColumnType &>(to).Append(getData(), size).ok(); else - assert_cast<MutableColumnType &>(to).AppendEmptyValue().ok(); + assert_cast<MutableColumnType &>(to).AppendNull().ok(); } /// Assuming to.has() @@ -572,20 +572,20 @@ public: size_t row_end, AggregateDataPtr place, const IColumn ** columns, - Arena * arena, - ssize_t if_argument_pos) const override + Arena * arena) const override { if constexpr (is_any) if (this->data(place).has()) return; - if (if_argument_pos >= 0) + + const auto & column = *columns[0]; + if (column.null_bitmap_data()) { - const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).raw_values(); for (size_t i = row_begin; i < row_end; ++i) { - if (flags[i]) + if (column.IsValid(i)) { - this->data(place).changeIfBetter(*columns[0], i, arena); + this->data(place).changeIfBetter(column, i, arena); if constexpr (is_any) break; } @@ -595,7 +595,7 @@ public: { for (size_t i = row_begin; i < row_end; ++i) { - this->data(place).changeIfBetter(*columns[0], i, arena); + this->data(place).changeIfBetter(column, i, arena); if constexpr (is_any) break; } diff --git a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionSum.h b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionSum.h index f539323d6a..575c627a19 100644 --- a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionSum.h +++ b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionSum.h @@ -33,9 +33,16 @@ struct AggregateFunctionSumData { using Impl = AggregateFunctionSumAddOverflowImpl<T>; T sum{}; + bool has_value = false; + + bool has() const + { + return has_value; + } void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(T value) { + has_value = true; Impl::add(sum, value); } @@ -84,7 +91,7 @@ struct AggregateFunctionSumData { addManyImpl(ptr, start, end); } - +#if 0 template <typename Value, bool add_if_zero> void NO_SANITIZE_UNDEFINED NO_INLINE addManyConditionalInternalImpl( const Value * __restrict ptr, @@ -172,7 +179,34 @@ struct AggregateFunctionSumData { return addManyConditionalInternal<Value, false>(ptr, cond_map, start, end); } +#else + template <typename Value> + void NO_SANITIZE_UNDEFINED NO_INLINE addManyConditionalImpl( + const Value * __restrict ptr, + const uint8_t * __restrict condition_map, + size_t start, + size_t end) /// NOLINT + { + // TODO: optimize + + const auto * end_ptr = ptr + end; + ptr += start; + while (ptr < end_ptr) + { + if (arrow::BitUtil::GetBit(condition_map, start)) + Impl::add(sum, *ptr); + ++ptr; + ++start; + } + } + + template <typename Value> + void ALWAYS_INLINE addManyConditional(const Value * __restrict ptr, const uint8_t * __restrict cond_map, size_t start, size_t end) + { + return addManyConditionalImpl<Value>(ptr, cond_map, start, end); + } +#endif void NO_SANITIZE_UNDEFINED merge(const AggregateFunctionSumData & rhs) { Impl::add(sum, rhs.sum); @@ -219,14 +253,13 @@ public: size_t row_end, AggregateDataPtr __restrict place, const IColumn ** columns, - Arena *, - ssize_t if_argument_pos) const override + Arena *) const override { const auto & column = assert_cast<const ColumnType &>(*columns[0]); - if (if_argument_pos >= 0) + if (auto * flags = column.null_bitmap_data()) { - const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).raw_values(); - this->data(place).addManyConditional(column.raw_values(), flags, row_begin, row_end); + auto * condition_map = flags + column.offset(); + this->data(place).addManyConditional(column.raw_values(), condition_map, row_begin, row_end); } else { @@ -241,7 +274,10 @@ public: void insertResultInto(AggregateDataPtr __restrict place, MutableColumn & to, Arena *) const override { - assert_cast<MutableColumnType &>(to).Append(this->data(place).get()).ok(); + if (this->data(place).has()) + assert_cast<MutableColumnType &>(to).Append(this->data(place).get()).ok(); + else + assert_cast<MutableColumnType &>(to).AppendNull().ok(); } }; diff --git a/ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.h b/ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.h index e6b753d95d..f4f21463a1 100644 --- a/ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.h +++ b/ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.h @@ -117,8 +117,7 @@ public: AggregateDataPtr * places, size_t place_offset, const IColumn ** columns, - Arena * arena, - ssize_t if_argument_pos = -1) const = 0; + Arena * arena) const = 0; virtual void mergeBatch( size_t row_begin, @@ -135,8 +134,7 @@ public: size_t row_end, AggregateDataPtr __restrict place, const IColumn ** columns, - Arena * arena, - ssize_t if_argument_pos = -1) const = 0; + Arena * arena) const = 0; /** The case when the aggregation key is UInt8 * and pointers to aggregation states are stored in AggregateDataPtr[256] lookup table. @@ -204,15 +202,13 @@ public: AggregateDataPtr * places, size_t place_offset, const IColumn ** columns, - Arena * arena, - ssize_t if_argument_pos = -1) const override + Arena * arena) const override { - if (if_argument_pos >= 0) + if (columns && columns[0]->null_bitmap_data()) { - const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).raw_values(); for (size_t i = row_begin; i < row_end; ++i) { - if (flags[i] && places[i]) + if (columns[0]->IsValid(i) && places[i]) static_cast<const Derived *>(this)->add(places[i] + place_offset, columns, i, arena); } } @@ -242,15 +238,13 @@ public: size_t row_end, AggregateDataPtr __restrict place, const IColumn ** columns, - Arena * arena, - ssize_t if_argument_pos = -1) const override + Arena * arena) const override { - if (if_argument_pos >= 0) + if (columns && columns[0]->null_bitmap_data()) { - const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).raw_values(); for (size_t i = row_begin; i < row_end; ++i) { - if (flags[i]) + if (columns[0]->IsValid(i)) static_cast<const Derived *>(this)->add(place, columns, i, arena); } } diff --git a/ydb/library/arrow_clickhouse/arrow_clickhouse_types.h b/ydb/library/arrow_clickhouse/arrow_clickhouse_types.h index 9142899fb1..493280e862 100644 --- a/ydb/library/arrow_clickhouse/arrow_clickhouse_types.h +++ b/ydb/library/arrow_clickhouse/arrow_clickhouse_types.h @@ -10,6 +10,7 @@ #include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/util/bitmap.h> #include <common/StringRef.h> #include <common/extended_types.h> |