aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author: chertus <azuikov@ydb.tech> 2023-01-11 21:16:04 +0300
committer: chertus <azuikov@ydb.tech> 2023-01-11 21:16:04 +0300
commit: d060e0812a3a1e493d574621c56883b5916e4814 (patch)
tree: da989824769f90e46a08bf59f1e26380bb9a90f7
parent: 8d1eca571344506fd5ab42bb2cd221e58149adbc (diff)
download: ydb-d060e0812a3a1e493d574621c56883b5916e4814.tar.gz
fix GROUP BY with empty batch
-rw-r--r-- ydb/core/formats/ut_program_step.cpp 36
-rw-r--r-- ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionWrapper.h 20
-rw-r--r-- ydb/library/arrow_clickhouse/DataStreams/OneBlockInputStream.h 5
3 files changed, 45 insertions, 16 deletions
diff --git a/ydb/core/formats/ut_program_step.cpp b/ydb/core/formats/ut_program_step.cpp
index fe9a1ea79f..a043779ef3 100644
--- a/ydb/core/formats/ut_program_step.cpp
+++ b/ydb/core/formats/ut_program_step.cpp
@@ -46,13 +46,17 @@ size_t FilterTestUnary(std::vector<std::shared_ptr<arrow::Array>> args, EOperati
step->Assignes = {TAssign("res1", op1, {"x"}), TAssign("res2", op2, {"res1", "z"})};
step->Filters = {"res2"};
step->Projection = {"res1", "res2"};
- UNIT_ASSERT(ApplyProgram(batch, TProgram({step}), GetCustomExecContext()).ok());
+ auto status = ApplyProgram(batch, TProgram({step}), GetCustomExecContext());
+ if (!status.ok()) {
+ Cerr << status.ToString() << "\n";
+ }
+ UNIT_ASSERT(status.ok());
UNIT_ASSERT(batch->ValidateFull().ok());
UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 2);
return batch->num_rows();
}
-void SumGroupBy(bool nullable, ui32 numKeys = 1) {
+void SumGroupBy(bool nullable, ui32 numKeys = 1, bool emptySrc = false) {
std::optional<double> null;
if (nullable) {
null = 0;
@@ -61,8 +65,14 @@ void SumGroupBy(bool nullable, ui32 numKeys = 1) {
auto schema = std::make_shared<arrow::Schema>(std::vector{
std::make_shared<arrow::Field>("x", arrow::int16()),
std::make_shared<arrow::Field>("y", arrow::uint32())});
- auto batch = arrow::RecordBatch::Make(schema, 4, std::vector{NumVecToArray(arrow::int16(), {-1, 0, 0, -1}, null),
- NumVecToArray(arrow::uint32(), {1, 0, 0, 1}, null)});
+ std::shared_ptr<arrow::RecordBatch> batch;
+ if (emptySrc) {
+ batch = arrow::RecordBatch::Make(schema, 0, std::vector{NumVecToArray(arrow::int16(), {}),
+ NumVecToArray(arrow::uint32(), {})});
+ } else {
+ batch = arrow::RecordBatch::Make(schema, 4, std::vector{NumVecToArray(arrow::int16(), {-1, 0, 0, -1}, null),
+ NumVecToArray(arrow::uint32(), {1, 0, 0, 1}, null)});
+ }
UNIT_ASSERT(batch->ValidateFull().ok());
auto step = std::make_shared<TProgramStep>();
@@ -75,10 +85,14 @@ void SumGroupBy(bool nullable, ui32 numKeys = 1) {
step->GroupByKeys.push_back("y");
}
- UNIT_ASSERT(ApplyProgram(batch, TProgram({step}), GetCustomExecContext()).ok());
+ auto status = ApplyProgram(batch, TProgram({step}), GetCustomExecContext());
+ if (!status.ok()) {
+ Cerr << status.ToString() << "\n";
+ }
+ UNIT_ASSERT(status.ok());
UNIT_ASSERT(batch->ValidateFull().ok());
UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), numKeys + 2);
- UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 2);
+ UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), (emptySrc ? 0 : 2));
UNIT_ASSERT_EQUAL(batch->column(0)->type_id(), arrow::Type::INT64);
UNIT_ASSERT_EQUAL(batch->column(1)->type_id(), arrow::Type::UINT64);
UNIT_ASSERT_EQUAL(batch->column(2)->type_id(), arrow::Type::INT16);
@@ -86,6 +100,10 @@ void SumGroupBy(bool nullable, ui32 numKeys = 1) {
UNIT_ASSERT_EQUAL(batch->column(3)->type_id(), arrow::Type::UINT32);
}
+ if (emptySrc) {
+ return;
+ }
+
auto& sumX = static_cast<arrow::Int64Array&>(*batch->column(0));
auto& sumY = static_cast<arrow::UInt64Array&>(*batch->column(1));
auto& colX = static_cast<arrow::Int16Array&>(*batch->column(2));
@@ -310,10 +328,16 @@ Y_UNIT_TEST_SUITE(ProgramStep) {
Y_UNIT_TEST(SumGroupBy) {
SumGroupBy(true);
SumGroupBy(true, 2);
+
+ SumGroupBy(true, 1, true);
+ SumGroupBy(true, 2, true);
}
Y_UNIT_TEST(SumGroupByNotNull) {
SumGroupBy(false);
SumGroupBy(false, 2);
+
+ SumGroupBy(false, 1, true);
+ SumGroupBy(false, 2, true);
}
}
diff --git a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionWrapper.h b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionWrapper.h
index 2269737ea3..ff05557525 100644
--- a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionWrapper.h
+++ b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionWrapper.h
@@ -68,6 +68,8 @@ public:
}
auto batch = arrow::RecordBatch::Make(std::make_shared<arrow::Schema>(fields), num_rows, columns);
+ if (!batch)
+ return arrow::Status::Invalid("Wrong aggregation arguments: cannot make batch");
AggregateDescription description {
.function = getHouseFunction(types),
@@ -141,8 +143,9 @@ public:
columns.reserve(needed_columns.size());
fields.reserve(needed_columns.size());
- int64_t num_rows = 0;
- for (int i = 0; i < opts->schema->num_fields(); ++i) {
+ std::optional<int64_t> num_rows;
+ for (int i = 0; i < opts->schema->num_fields(); ++i)
+ {
auto& datum = args[i];
auto& field = opts->schema->field(i);
@@ -151,7 +154,7 @@ public:
if (datum.is_array())
{
- if (num_rows && num_rows != datum.mutable_array()->length)
+ if (num_rows && *num_rows != datum.mutable_array()->length)
return arrow::Status::Invalid("Arrays have different length");
num_rows = datum.mutable_array()->length;
}
@@ -161,7 +164,8 @@ public:
if (!num_rows) // All datums are scalars
num_rows = 1;
- for (int i = 0; i < opts->schema->num_fields(); ++i) {
+ for (int i = 0; i < opts->schema->num_fields(); ++i)
+ {
auto& datum = args[i];
auto& field = opts->schema->field(i);
@@ -171,9 +175,9 @@ public:
if (datum.is_scalar())
{
// TODO: better GROUP BY over scalars
- auto res = arrow::MakeArrayFromScalar(*datum.scalar(), num_rows);
+ auto res = arrow::MakeArrayFromScalar(*datum.scalar(), *num_rows);
if (!res.ok())
- return arrow::Status::Invalid("Bad scalar: '" + field->name() + "'");
+ return arrow::Status::Invalid("Bad scalar for '" + field->name() + "', " + res.status().ToString());
columns.push_back(*res);
}
else
@@ -182,7 +186,7 @@ public:
fields.push_back(field);
}
- batch = arrow::RecordBatch::Make(std::make_shared<arrow::Schema>(fields), num_rows, columns);
+ batch = arrow::RecordBatch::Make(std::make_shared<arrow::Schema>(fields), *num_rows, columns);
if (!batch)
return arrow::Status::Invalid("Wrong GROUP BY arguments: cannot make batch");
}
@@ -240,7 +244,7 @@ public:
AggregatingBlockInputStream agg_stream(input_stream, agg_params, true);
auto result_batch = agg_stream.read();
- if (!result_batch || result_batch->num_rows() == 0)
+ if (!result_batch || (batch->num_rows() && !result_batch->num_rows()))
return arrow::Status::Invalid("unexpected arrgerate result");
if (agg_stream.read())
return arrow::Status::Invalid("unexpected second batch in aggregate result");
diff --git a/ydb/library/arrow_clickhouse/DataStreams/OneBlockInputStream.h b/ydb/library/arrow_clickhouse/DataStreams/OneBlockInputStream.h
index 6735022c46..1664d79db7 100644
--- a/ydb/library/arrow_clickhouse/DataStreams/OneBlockInputStream.h
+++ b/ydb/library/arrow_clickhouse/DataStreams/OneBlockInputStream.h
@@ -18,8 +18,9 @@ public:
explicit OneBlockInputStream(Block block_)
: block(std::move(block_))
{
- if (!block->Validate().ok())
- throw Exception("Bad batch in OneBlockInputStream");
+ auto status = block->Validate();
+ if (!status.ok())
+ throw Exception(std::string("Bad batch in OneBlockInputStream: ") + status.ToString());
}
String getName() const override { return "One"; }