aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author chertus <azuikov@ydb.tech> 2023-01-23 15:29:29 +0300
committer chertus <azuikov@ydb.tech> 2023-01-23 15:29:29 +0300
commit 08d7aec1386d15c53ded0a9239d592e74fce66df (patch)
tree 1be1e267cb268540bc7ad00289a7ffd7bb627e63
parent 0e91c7a392b2e6d46a4dde191a29073dfc8d434b (diff)
download ydb-08d7aec1386d15c53ded0a9239d592e74fce66df.tar.gz
fix nulls in SSA aggregates
-rw-r--r-- ydb/core/formats/ut_program_step.cpp 231
-rw-r--r-- ydb/core/kqp/ut/olap/kqp_olap_ut.cpp 26
-rw-r--r-- ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionAvg.h 13
-rw-r--r-- ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionCount.h 47
-rw-r--r-- ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionMinMaxAny.h 18
-rw-r--r-- ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionSum.h 50
-rw-r--r-- ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.h 22
-rw-r--r-- ydb/library/arrow_clickhouse/arrow_clickhouse_types.h 1
8 files changed, 283 insertions, 125 deletions
diff --git a/ydb/core/formats/ut_program_step.cpp b/ydb/core/formats/ut_program_step.cpp
index a043779ef3..f89b9f0f7a 100644
--- a/ydb/core/formats/ut_program_step.cpp
+++ b/ydb/core/formats/ut_program_step.cpp
@@ -56,68 +56,189 @@ size_t FilterTestUnary(std::vector<std::shared_ptr<arrow::Array>> args, EOperati
return batch->num_rows();
}
-void SumGroupBy(bool nullable, ui32 numKeys = 1, bool emptySrc = false) {
- std::optional<double> null;
- if (nullable) {
- null = 0;
+enum class ETest { // selects which input batch a group-by test case is built from
+ DEFAULT, // four-row batch in TSumData::Data
+ EMPTY, // zero-row batch in TSumData::Data
+ ONE_VALUE // single-row batch in TSumData::Data
+};
+
+struct TSumData { // input builder and result checker for the Sum group-by test
+ static std::shared_ptr<arrow::RecordBatch> Data(ETest test, // builds the (x: int16, y: uint32) input batch for `test`
+ std::shared_ptr<arrow::Schema>& schema,
+ bool nullable)
+ {
+ std::optional<double> null; // stays empty unless `nullable` is set
+ if (nullable) {
+ null = 0; // presumably NumVecToArray emits a null for elements equal to this value -- TODO confirm against its definition
+ }
+
+ if (test == ETest::DEFAULT) {
+ return arrow::RecordBatch::Make(schema, 4, std::vector{NumVecToArray(arrow::int16(), {-1, 0, 0, -1}, null),
+ NumVecToArray(arrow::uint32(), {1, 0, 0, 1}, null)});
+ } else if (test == ETest::EMPTY) {
+ return arrow::RecordBatch::Make(schema, 0, std::vector{NumVecToArray(arrow::int16(), {}),
+ NumVecToArray(arrow::uint32(), {})});
+ } else if (test == ETest::ONE_VALUE) {
+ return arrow::RecordBatch::Make(schema, 1, std::vector{NumVecToArray(arrow::int16(), {1}), // x is built without the `null` marker here
+ NumVecToArray(arrow::uint32(), {0}, null)});
+ }
+ return {}; // unknown ETest value: null batch
}
+ static void CheckResult(ETest test, const std::shared_ptr<arrow::RecordBatch>& batch, ui32 numKeys, bool nullable) { // asserts the aggregated batch for `test`
+ UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), numKeys + 2); // two aggregate columns followed by the key columns
+ UNIT_ASSERT_EQUAL(batch->column(0)->type_id(), arrow::Type::INT64); // sum over the int16 column x
+ UNIT_ASSERT_EQUAL(batch->column(1)->type_id(), arrow::Type::UINT64); // sum over the uint32 column y
+ UNIT_ASSERT_EQUAL(batch->column(2)->type_id(), arrow::Type::INT16); // key column x
+ if (numKeys == 2) {
+ UNIT_ASSERT_EQUAL(batch->column(3)->type_id(), arrow::Type::UINT32); // key column y
+ }
+
+ if (test == ETest::EMPTY) {
+ UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 0);
+ return;
+ }
+
+ auto& aggX = static_cast<arrow::Int64Array&>(*batch->column(0));
+ auto& aggY = static_cast<arrow::UInt64Array&>(*batch->column(1));
+ auto& colX = static_cast<arrow::Int16Array&>(*batch->column(2));
+
+ if (test == ETest::ONE_VALUE) {
+ UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1);
+
+ UNIT_ASSERT_VALUES_EQUAL(aggX.Value(0), 1);
+ if (nullable) {
+ UNIT_ASSERT(aggY.IsNull(0)); // the only y input is expected to be null, so its sum must be null
+ } else {
+ UNIT_ASSERT(!aggY.IsNull(0));
+ UNIT_ASSERT_VALUES_EQUAL(aggY.Value(0), 0);
+ }
+ return;
+ }
+
+ UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 2); // DEFAULT input has two distinct x keys
+
+ for (ui32 row = 0; row < 2; ++row) { // group order is not assumed; each row is identified by its key value
+ if (colX.IsNull(row)) {
+ UNIT_ASSERT(aggX.IsNull(row));
+ UNIT_ASSERT(aggY.IsNull(row));
+ } else {
+ UNIT_ASSERT(!aggX.IsNull(row));
+ UNIT_ASSERT(!aggY.IsNull(row));
+ if (colX.Value(row) == 0) {
+ UNIT_ASSERT_VALUES_EQUAL(aggX.Value(row), 0);
+ UNIT_ASSERT_VALUES_EQUAL(aggY.Value(row), 0);
+ } else if (colX.Value(row) == -1) {
+ UNIT_ASSERT_VALUES_EQUAL(aggX.Value(row), -2);
+ UNIT_ASSERT_VALUES_EQUAL(aggY.Value(row), 2);
+ } else {
+ UNIT_ASSERT(false); // any other key value is unexpected
+ }
+ }
+ }
+ }
+};
+
+struct TMinMaxSomeData { // input builder and result checker shared by the Min/Max/Some group-by tests
+ static std::shared_ptr<arrow::RecordBatch> Data(ETest /*test*/, // the test kind is ignored: always a single-row batch
+ std::shared_ptr<arrow::Schema>& schema,
+ bool nullable)
+ {
+ std::optional<double> null; // stays empty unless `nullable` is set
+ if (nullable) {
+ null = 0; // presumably NumVecToArray emits a null for elements equal to this value -- TODO confirm against its definition
+ }
+
+ return arrow::RecordBatch::Make(schema, 1, std::vector{NumVecToArray(arrow::int16(), {1}),
+ NumVecToArray(arrow::uint32(), {0}, null)});
+ }
+
+ static void CheckResult(ETest /*test*/, const std::shared_ptr<arrow::RecordBatch>& batch, ui32 numKeys, // asserts the one-row aggregated batch
+ bool nullable) {
+ UNIT_ASSERT_VALUES_EQUAL(numKeys, 1); // only the single-key layout is checked here
+
+ UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), numKeys + 2); // two aggregate columns followed by the key column
+ UNIT_ASSERT_EQUAL(batch->column(0)->type_id(), arrow::Type::INT16); // aggregate keeps the input type of x
+ UNIT_ASSERT_EQUAL(batch->column(1)->type_id(), arrow::Type::UINT32); // aggregate keeps the input type of y
+ UNIT_ASSERT_EQUAL(batch->column(2)->type_id(), arrow::Type::INT16); // key column x
+
+ auto& aggX = static_cast<arrow::Int16Array&>(*batch->column(0));
+ auto& aggY = static_cast<arrow::UInt32Array&>(*batch->column(1));
+ auto& colX = static_cast<arrow::Int16Array&>(*batch->column(2));
+
+ UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1);
+
+ UNIT_ASSERT_VALUES_EQUAL(colX.Value(0), 1);
+ UNIT_ASSERT_VALUES_EQUAL(aggX.Value(0), 1);
+ if (nullable) {
+ UNIT_ASSERT(aggY.IsNull(0)); // the only y input is expected to be null, so the aggregate must be null
+ } else {
+ UNIT_ASSERT(!aggY.IsNull(0));
+ UNIT_ASSERT_VALUES_EQUAL(aggY.Value(0), 0);
+ }
+ return;
+ }
+};
+
+void GroupByXY(bool nullable, ui32 numKeys, ETest test = ETest::DEFAULT, // runs `aggFunc` over x and y grouped by x (and y when numKeys == 2), then checks the result
+ EAggregate aggFunc = EAggregate::Sum) {
auto schema = std::make_shared<arrow::Schema>(std::vector{
std::make_shared<arrow::Field>("x", arrow::int16()),
std::make_shared<arrow::Field>("y", arrow::uint32())});
+
std::shared_ptr<arrow::RecordBatch> batch;
- if (emptySrc) {
- batch = arrow::RecordBatch::Make(schema, 0, std::vector{NumVecToArray(arrow::int16(), {}),
- NumVecToArray(arrow::uint32(), {})});
- } else {
- batch = arrow::RecordBatch::Make(schema, 4, std::vector{NumVecToArray(arrow::int16(), {-1, 0, 0, -1}, null),
- NumVecToArray(arrow::uint32(), {1, 0, 0, 1}, null)});
+ switch (aggFunc) { // pick the input fixture that matches the aggregate under test
+ case EAggregate::Sum:
+ batch = TSumData::Data(test, schema, nullable);
+ break;
+ case EAggregate::Min:
+ case EAggregate::Max:
+ case EAggregate::Some:
+ batch = TMinMaxSomeData::Data(test, schema, nullable);
+ break;
+ default:
+ break;
}
- UNIT_ASSERT(batch->ValidateFull().ok());
+ UNIT_ASSERT(batch); // batch stays null for aggregate kinds the switch above has no fixture for
+ auto status = batch->ValidateFull();
+ if (!status.ok()) {
+ Cerr << status.ToString() << "\n"; // print the reason before the assertion below fails
+ }
+ UNIT_ASSERT(status.ok());
auto step = std::make_shared<TProgramStep>();
step->GroupBy = {
- TAggregateAssign("sum_x", EAggregate::Sum, {"x"}),
- TAggregateAssign("sum_y", EAggregate::Sum, {"y"})
+ TAggregateAssign("agg_x", aggFunc, {"x"}),
+ TAggregateAssign("agg_y", aggFunc, {"y"})
};
step->GroupByKeys.push_back("x");
if (numKeys == 2) {
step->GroupByKeys.push_back("y");
}
- auto status = ApplyProgram(batch, TProgram({step}), GetCustomExecContext());
+ status = ApplyProgram(batch, TProgram({step}), GetCustomExecContext()); // the checks below read the aggregated result from `batch`
if (!status.ok()) {
Cerr << status.ToString() << "\n";
}
UNIT_ASSERT(status.ok());
- UNIT_ASSERT(batch->ValidateFull().ok());
- UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), numKeys + 2);
- UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), (emptySrc ? 0 : 2));
- UNIT_ASSERT_EQUAL(batch->column(0)->type_id(), arrow::Type::INT64);
- UNIT_ASSERT_EQUAL(batch->column(1)->type_id(), arrow::Type::UINT64);
- UNIT_ASSERT_EQUAL(batch->column(2)->type_id(), arrow::Type::INT16);
- if (numKeys == 2) {
- UNIT_ASSERT_EQUAL(batch->column(3)->type_id(), arrow::Type::UINT32);
- }
- if (emptySrc) {
- return;
+ status = batch->ValidateFull(); // validate the aggregated batch as well
+ if (!status.ok()) {
+ Cerr << status.ToString() << "\n";
}
+ UNIT_ASSERT(status.ok());
- auto& sumX = static_cast<arrow::Int64Array&>(*batch->column(0));
- auto& sumY = static_cast<arrow::UInt64Array&>(*batch->column(1));
- auto& colX = static_cast<arrow::Int16Array&>(*batch->column(2));
-
- for (ui32 row = 0; row < 2; ++row) {
- if (colX.IsNull(row) || colX.Value(row) == 0) {
- UNIT_ASSERT_VALUES_EQUAL(sumX.Value(row), 0);
- UNIT_ASSERT_VALUES_EQUAL(sumY.Value(row), 0);
- } else if (colX.Value(row) == -1) {
- UNIT_ASSERT_VALUES_EQUAL(sumX.Value(row), -2);
- UNIT_ASSERT_VALUES_EQUAL(sumY.Value(row), 2);
- } else {
- UNIT_ASSERT(false);
- }
+ switch (aggFunc) { // dispatch to the checker paired with the fixture chosen above
+ case EAggregate::Sum:
+ TSumData::CheckResult(test, batch, numKeys, nullable);
+ break;
+ case EAggregate::Min:
+ case EAggregate::Max:
+ case EAggregate::Some:
+ TMinMaxSomeData::CheckResult(test, batch, numKeys, nullable);
+ break;
+ default:
+ break;
}
}
@@ -326,18 +447,34 @@ Y_UNIT_TEST_SUITE(ProgramStep) {
}
Y_UNIT_TEST(SumGroupBy) {
- SumGroupBy(true);
- SumGroupBy(true, 2);
+ GroupByXY(true, 1); // nullable input, one key (x)
+ GroupByXY(true, 2); // nullable input, two keys (x, y)
- SumGroupBy(true, 1, true);
- SumGroupBy(true, 2, true);
+ GroupByXY(true, 1, ETest::EMPTY);
+ GroupByXY(true, 2, ETest::EMPTY);
+
+ GroupByXY(true, 1, ETest::ONE_VALUE); // single row whose y is expected to be null
}
Y_UNIT_TEST(SumGroupByNotNull) {
- SumGroupBy(false);
- SumGroupBy(false, 2);
+ GroupByXY(false, 1); // same cases as SumGroupBy, without the null marker
+ GroupByXY(false, 2);
+
+ GroupByXY(false, 1, ETest::EMPTY);
+ GroupByXY(false, 2, ETest::EMPTY);
+
+ GroupByXY(false, 1, ETest::ONE_VALUE);
+ }
+
+ Y_UNIT_TEST(MinMaxSomeGroupBy) {
+ GroupByXY(true, 1, ETest::ONE_VALUE, EAggregate::Min); // TMinMaxSomeData ignores the ETest value and always uses one row
+ GroupByXY(true, 1, ETest::ONE_VALUE, EAggregate::Max);
+ GroupByXY(true, 1, ETest::ONE_VALUE, EAggregate::Some);
+ }
- SumGroupBy(false, 1, true);
- SumGroupBy(false, 2, true);
+ Y_UNIT_TEST(MinMaxSomeGroupByNotNull) {
+ GroupByXY(false, 1, ETest::ONE_VALUE, EAggregate::Min);
+ GroupByXY(false, 1, ETest::ONE_VALUE, EAggregate::Max);
+ GroupByXY(false, 1, ETest::ONE_VALUE, EAggregate::Some);
}
}
diff --git a/ydb/core/kqp/ut/olap/kqp_olap_ut.cpp b/ydb/core/kqp/ut/olap/kqp_olap_ut.cpp
index 1b1e01eb56..2f01118c1d 100644
--- a/ydb/core/kqp/ut/olap/kqp_olap_ut.cpp
+++ b/ydb/core/kqp/ut/olap/kqp_olap_ut.cpp
@@ -2137,8 +2137,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
}
Y_UNIT_TEST(Aggregation_Avg_NullGroupBy) {
- // Wait for KIKIMR-16831 fix
- return;
TAggregationTestCase testCase;
testCase.SetQuery(R"(
SELECT
@@ -2159,8 +2157,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
}
Y_UNIT_TEST(Aggregation_Avg_NullMixGroupBy) {
- // Wait for KIKIMR-16831 fix
- return;
TAggregationTestCase testCase;
testCase.SetQuery(R"(
SELECT
@@ -2181,8 +2177,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
}
Y_UNIT_TEST(Aggregation_Avg_GroupByNull) {
- // Wait for KIKIMR-16831 fix
- return;
TAggregationTestCase testCase;
testCase.SetQuery(R"(
SELECT
@@ -2203,8 +2197,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
}
Y_UNIT_TEST(Aggregation_Avg_GroupByNullMix) {
- // Wait for KIKIMR-16831 fix
- return;
TAggregationTestCase testCase;
testCase.SetQuery(R"(
SELECT
@@ -2297,8 +2289,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
}
Y_UNIT_TEST(Aggregation_Sum_NullGroupBy) {
- // Wait for KIKIMR-16831 fix
- return;
TAggregationTestCase testCase;
testCase.SetQuery(R"(
SELECT
@@ -2319,8 +2309,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
}
Y_UNIT_TEST(Aggregation_Sum_NullMixGroupBy) {
- // Wait for KIKIMR-16831 fix
- return;
TAggregationTestCase testCase;
testCase.SetQuery(R"(
SELECT
@@ -2341,8 +2329,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
}
Y_UNIT_TEST(Aggregation_Sum_GroupByNull) {
- // Wait for KIKIMR-16831 fix
- return;
TAggregationTestCase testCase;
testCase.SetQuery(R"(
SELECT
@@ -2363,8 +2349,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
}
Y_UNIT_TEST(Aggregation_Sum_GroupByNullMix) {
- // Wait for KIKIMR-16831 fix
- return;
TAggregationTestCase testCase;
testCase.SetQuery(R"(
SELECT
@@ -2510,8 +2494,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
}
Y_UNIT_TEST(Aggregation_Some_Null) {
- // Wait for KIKIMR-16831 fix
- return;
TAggregationTestCase testCase;
testCase.SetQuery(R"(
SELECT SOME(level) FROM `/Root/tableWithNulls` WHERE id > 5
@@ -2546,8 +2528,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
}
Y_UNIT_TEST(Aggregation_Some_NullGroupBy) {
- // Wait for KIKIMR-16831 fix
- return;
TAggregationTestCase testCase;
testCase.SetQuery(R"(
SELECT
@@ -2568,8 +2548,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
}
Y_UNIT_TEST(Aggregation_Some_NullMixGroupBy) {
- // Wait for KIKIMR-16831 fix
- return;
TAggregationTestCase testCase;
testCase.SetQuery(R"(
SELECT
@@ -2590,8 +2568,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
}
Y_UNIT_TEST(Aggregation_Some_GroupByNullMix) {
- // Wait for KIKIMR-16831 fix
- return;
TAggregationTestCase testCase;
testCase.SetQuery(R"(
SELECT
@@ -2612,8 +2588,6 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
}
Y_UNIT_TEST(Aggregation_Some_GroupByNull) {
- // Wait for KIKIMR-16831 fix
- return;
TAggregationTestCase testCase;
testCase.SetQuery(R"(
SELECT
diff --git a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionAvg.h b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionAvg.h
index 598afb54a3..508abe059f 100644
--- a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionAvg.h
+++ b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionAvg.h
@@ -119,16 +119,17 @@ public:
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** columns,
- Arena *,
- ssize_t if_argument_pos) const final
+ Arena *) const final
{
AggregateFunctionSumData<Numerator> sum_data;
const auto & column = assert_cast<const ColumnType &>(*columns[0]);
- if (if_argument_pos >= 0)
+ if (auto * flags = column.null_bitmap_data())
{
- const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).raw_values();
- sum_data.addManyConditional(column.raw_values(), flags, row_begin, row_end);
- this->data(place).denominator += countBytesInFilter(flags, row_begin, row_end);
+ auto * condition_map = flags + column.offset();
+ auto length = row_end - row_begin;
+
+ sum_data.addManyConditional(column.raw_values(), condition_map, row_begin, row_end);
+ this->data(place).denominator += arrow::internal::CountSetBits(condition_map, row_begin, length);
}
else
{
diff --git a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionCount.h b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionCount.h
index 0994cadc1f..8432027620 100644
--- a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionCount.h
+++ b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionCount.h
@@ -18,6 +18,24 @@ namespace CH
struct AggregateFunctionCountData
{
UInt64 count = 0;
+ bool has_value = false; // set by inc()/add(); insertResultInto appends a null while it is still false
+
+ bool has() const // true once inc() or add() has been called on this state
+ {
+ return has_value;
+ }
+
+ void inc() // count one more row and mark the state as having a value
+ {
+ has_value = true;
+ ++count;
+ }
+
+ void add(UInt64 value) // add `value` rows; marks the state as having a value even when value == 0
+ {
+ has_value = true;
+ count += value;
+ }
};
@@ -38,7 +56,7 @@ public:
void add(AggregateDataPtr __restrict place, const IColumn **, size_t, Arena *) const override
{
- ++data(place).count;
+ data(place).inc(); // inc() bumps the counter and marks the state as having a value
}
void addBatchSinglePlace(
@@ -46,35 +64,32 @@ public:
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** columns,
- Arena *,
- ssize_t if_argument_pos) const override
+ Arena *) const override
{
- if (if_argument_pos >= 0)
+ const auto & column = *columns[0];
+ if (auto * flags = column.null_bitmap_data())
{
- const auto & filter_column = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]);
- const auto & flags = filter_column.raw_values();
- data(place).count += countBytesInFilter(flags, row_begin, row_end);
+ auto * condition_map = flags + column.offset();
+ auto length = row_end - row_begin;
+ data(place).add(arrow::internal::CountSetBits(condition_map, row_begin, length));
}
else
{
- data(place).count += row_end - row_begin;
+ data(place).add(row_end - row_begin);
}
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
{
- data(place).count += data(rhs).count;
+ data(place).add(data(rhs).count); // NOTE(review): add() sets has_value even when rhs.has() is false -- confirm that merging an empty state should make the result non-null
}
void insertResultInto(AggregateDataPtr __restrict place, MutableColumn & to, Arena *) const override
{
- assert_cast<MutableColumnUInt64 &>(to).Append(data(place).count).ok();
- }
-
- /// Reset the state to specified value. This function is not the part of common interface.
- void set(AggregateDataPtr __restrict place, UInt64 new_count) const
- {
- data(place).count = new_count;
+ if (data(place).has()) // a state that never saw inc()/add() produces a null instead of 0
+ assert_cast<MutableColumnUInt64 &>(to).Append(data(place).count).ok(); // the Append status is evaluated via ok() and then discarded
+ else
+ assert_cast<MutableColumnUInt64 &>(to).AppendNull().ok();
}
};
diff --git a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionMinMaxAny.h b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionMinMaxAny.h
index c8ce2884c7..f323ea50d2 100644
--- a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionMinMaxAny.h
+++ b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionMinMaxAny.h
@@ -38,7 +38,7 @@ public:
if (has())
assert_cast<MutableColumnType &>(to).Append(value).ok();
else
- assert_cast<MutableColumnType &>(to).AppendEmptyValue().ok();
+ assert_cast<MutableColumnType &>(to).AppendNull().ok();
}
void change(const IColumn & column, size_t row_num, Arena *)
@@ -201,7 +201,7 @@ public:
if (has())
assert_cast<MutableColumnType &>(to).Append(getData(), size).ok();
else
- assert_cast<MutableColumnType &>(to).AppendEmptyValue().ok();
+ assert_cast<MutableColumnType &>(to).AppendNull().ok();
}
/// Assuming to.has()
@@ -572,20 +572,20 @@ public:
size_t row_end,
AggregateDataPtr place,
const IColumn ** columns,
- Arena * arena,
- ssize_t if_argument_pos) const override
+ Arena * arena) const override
{
if constexpr (is_any)
if (this->data(place).has())
return;
- if (if_argument_pos >= 0)
+
+ const auto & column = *columns[0];
+ if (column.null_bitmap_data())
{
- const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).raw_values();
for (size_t i = row_begin; i < row_end; ++i)
{
- if (flags[i])
+ if (column.IsValid(i))
{
- this->data(place).changeIfBetter(*columns[0], i, arena);
+ this->data(place).changeIfBetter(column, i, arena);
if constexpr (is_any)
break;
}
@@ -595,7 +595,7 @@ public:
{
for (size_t i = row_begin; i < row_end; ++i)
{
- this->data(place).changeIfBetter(*columns[0], i, arena);
+ this->data(place).changeIfBetter(column, i, arena);
if constexpr (is_any)
break;
}
diff --git a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionSum.h b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionSum.h
index f539323d6a..575c627a19 100644
--- a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionSum.h
+++ b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionSum.h
@@ -33,9 +33,16 @@ struct AggregateFunctionSumData
{
using Impl = AggregateFunctionSumAddOverflowImpl<T>;
T sum{};
+ bool has_value = false;
+
+ bool has() const
+ {
+ return has_value;
+ }
void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(T value)
{
+ has_value = true;
Impl::add(sum, value);
}
@@ -84,7 +91,7 @@ struct AggregateFunctionSumData
{
addManyImpl(ptr, start, end);
}
-
+#if 0
template <typename Value, bool add_if_zero>
void NO_SANITIZE_UNDEFINED NO_INLINE addManyConditionalInternalImpl(
const Value * __restrict ptr,
@@ -172,7 +179,34 @@ struct AggregateFunctionSumData
{
return addManyConditionalInternal<Value, false>(ptr, cond_map, start, end);
}
+#else
+ template <typename Value> // Value: element type of the raw input buffer
+ void NO_SANITIZE_UNDEFINED NO_INLINE addManyConditionalImpl( // adds ptr[i] for every i in [start, end) whose bit in condition_map is set
+ const Value * __restrict ptr, // values, indexed by absolute row number
+ const uint8_t * __restrict condition_map, // bit-packed map read with GetBit at the same absolute row number
+ size_t start,
+ size_t end) /// NOLINT
+ {
+ // TODO: optimize
+
+ const auto * end_ptr = ptr + end;
+ ptr += start;
+ while (ptr < end_ptr)
+ {
+ if (arrow::BitUtil::GetBit(condition_map, start))
+ add(*ptr); // go through add() so has_value is set; Impl::add alone left the state reported as empty
+ ++ptr;
+ ++start;
+ }
+ }
+
+ template <typename Value>
+ void ALWAYS_INLINE addManyConditional(const Value * __restrict ptr, const uint8_t * __restrict cond_map, size_t start, size_t end) // thin forwarder kept under the name used by the disabled #if 0 implementation
+ {
+ return addManyConditionalImpl<Value>(ptr, cond_map, start, end); // cond_map is consumed bit by bit in the Impl
+ }
+#endif
void NO_SANITIZE_UNDEFINED merge(const AggregateFunctionSumData & rhs)
{
Impl::add(sum, rhs.sum);
@@ -219,14 +253,13 @@ public:
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** columns,
- Arena *,
- ssize_t if_argument_pos) const override
+ Arena *) const override
{
const auto & column = assert_cast<const ColumnType &>(*columns[0]);
- if (if_argument_pos >= 0)
+ if (auto * flags = column.null_bitmap_data())
{
- const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).raw_values();
- this->data(place).addManyConditional(column.raw_values(), flags, row_begin, row_end);
+ auto * condition_map = flags + column.offset();
+ this->data(place).addManyConditional(column.raw_values(), condition_map, row_begin, row_end);
}
else
{
@@ -241,7 +274,10 @@ public:
void insertResultInto(AggregateDataPtr __restrict place, MutableColumn & to, Arena *) const override
{
- assert_cast<MutableColumnType &>(to).Append(this->data(place).get()).ok();
+ if (this->data(place).has()) // has() reflects has_value, which AggregateFunctionSumData::add(T) sets
+ assert_cast<MutableColumnType &>(to).Append(this->data(place).get()).ok(); // the Append status is evaluated via ok() and then discarded
+ else
+ assert_cast<MutableColumnType &>(to).AppendNull().ok(); // no value was accumulated: emit a null rather than a zero sum
}
};
diff --git a/ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.h b/ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.h
index e6b753d95d..f4f21463a1 100644
--- a/ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.h
+++ b/ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.h
@@ -117,8 +117,7 @@ public:
AggregateDataPtr * places,
size_t place_offset,
const IColumn ** columns,
- Arena * arena,
- ssize_t if_argument_pos = -1) const = 0;
+ Arena * arena) const = 0;
virtual void mergeBatch(
size_t row_begin,
@@ -135,8 +134,7 @@ public:
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** columns,
- Arena * arena,
- ssize_t if_argument_pos = -1) const = 0;
+ Arena * arena) const = 0;
/** The case when the aggregation key is UInt8
* and pointers to aggregation states are stored in AggregateDataPtr[256] lookup table.
@@ -204,15 +202,13 @@ public:
AggregateDataPtr * places,
size_t place_offset,
const IColumn ** columns,
- Arena * arena,
- ssize_t if_argument_pos = -1) const override
+ Arena * arena) const override
{
- if (if_argument_pos >= 0)
+ if (columns && columns[0]->null_bitmap_data())
{
- const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).raw_values();
for (size_t i = row_begin; i < row_end; ++i)
{
- if (flags[i] && places[i])
+ if (columns[0]->IsValid(i) && places[i])
static_cast<const Derived *>(this)->add(places[i] + place_offset, columns, i, arena);
}
}
@@ -242,15 +238,13 @@ public:
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** columns,
- Arena * arena,
- ssize_t if_argument_pos = -1) const override
+ Arena * arena) const override
{
- if (if_argument_pos >= 0)
+ if (columns && columns[0]->null_bitmap_data())
{
- const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).raw_values();
for (size_t i = row_begin; i < row_end; ++i)
{
- if (flags[i])
+ if (columns[0]->IsValid(i))
static_cast<const Derived *>(this)->add(place, columns, i, arena);
}
}
diff --git a/ydb/library/arrow_clickhouse/arrow_clickhouse_types.h b/ydb/library/arrow_clickhouse/arrow_clickhouse_types.h
index 9142899fb1..493280e862 100644
--- a/ydb/library/arrow_clickhouse/arrow_clickhouse_types.h
+++ b/ydb/library/arrow_clickhouse/arrow_clickhouse_types.h
@@ -10,6 +10,7 @@
#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h>
+#include <contrib/libs/apache/arrow/cpp/src/arrow/util/bitmap.h>
#include <common/StringRef.h>
#include <common/extended_types.h>