diff options
author | chertus <azuikov@ydb.tech> | 2023-01-11 21:16:04 +0300 |
---|---|---|
committer | chertus <azuikov@ydb.tech> | 2023-01-11 21:16:04 +0300 |
commit | d060e0812a3a1e493d574621c56883b5916e4814 (patch) | |
tree | da989824769f90e46a08bf59f1e26380bb9a90f7 | |
parent | 8d1eca571344506fd5ab42bb2cd221e58149adbc (diff) | |
download | ydb-d060e0812a3a1e493d574621c56883b5916e4814.tar.gz |
fix GROUP BY with empty batch
3 files changed, 45 insertions, 16 deletions
diff --git a/ydb/core/formats/ut_program_step.cpp b/ydb/core/formats/ut_program_step.cpp index fe9a1ea79f..a043779ef3 100644 --- a/ydb/core/formats/ut_program_step.cpp +++ b/ydb/core/formats/ut_program_step.cpp @@ -46,13 +46,17 @@ size_t FilterTestUnary(std::vector<std::shared_ptr<arrow::Array>> args, EOperati step->Assignes = {TAssign("res1", op1, {"x"}), TAssign("res2", op2, {"res1", "z"})}; step->Filters = {"res2"}; step->Projection = {"res1", "res2"}; - UNIT_ASSERT(ApplyProgram(batch, TProgram({step}), GetCustomExecContext()).ok()); + auto status = ApplyProgram(batch, TProgram({step}), GetCustomExecContext()); + if (!status.ok()) { + Cerr << status.ToString() << "\n"; + } + UNIT_ASSERT(status.ok()); UNIT_ASSERT(batch->ValidateFull().ok()); UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 2); return batch->num_rows(); } -void SumGroupBy(bool nullable, ui32 numKeys = 1) { +void SumGroupBy(bool nullable, ui32 numKeys = 1, bool emptySrc = false) { std::optional<double> null; if (nullable) { null = 0; @@ -61,8 +65,14 @@ void SumGroupBy(bool nullable, ui32 numKeys = 1) { auto schema = std::make_shared<arrow::Schema>(std::vector{ std::make_shared<arrow::Field>("x", arrow::int16()), std::make_shared<arrow::Field>("y", arrow::uint32())}); - auto batch = arrow::RecordBatch::Make(schema, 4, std::vector{NumVecToArray(arrow::int16(), {-1, 0, 0, -1}, null), - NumVecToArray(arrow::uint32(), {1, 0, 0, 1}, null)}); + std::shared_ptr<arrow::RecordBatch> batch; + if (emptySrc) { + batch = arrow::RecordBatch::Make(schema, 0, std::vector{NumVecToArray(arrow::int16(), {}), + NumVecToArray(arrow::uint32(), {})}); + } else { + batch = arrow::RecordBatch::Make(schema, 4, std::vector{NumVecToArray(arrow::int16(), {-1, 0, 0, -1}, null), + NumVecToArray(arrow::uint32(), {1, 0, 0, 1}, null)}); + } UNIT_ASSERT(batch->ValidateFull().ok()); auto step = std::make_shared<TProgramStep>(); @@ -75,10 +85,14 @@ void SumGroupBy(bool nullable, ui32 numKeys = 1) { step->GroupByKeys.push_back("y"); } - UNIT_ASSERT(ApplyProgram(batch, TProgram({step}), GetCustomExecContext()).ok()); + auto status = ApplyProgram(batch, TProgram({step}), GetCustomExecContext()); + if (!status.ok()) { + Cerr << status.ToString() << "\n"; + } + UNIT_ASSERT(status.ok()); UNIT_ASSERT(batch->ValidateFull().ok()); UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), numKeys + 2); - UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 2); + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), (emptySrc ? 0 : 2)); UNIT_ASSERT_EQUAL(batch->column(0)->type_id(), arrow::Type::INT64); UNIT_ASSERT_EQUAL(batch->column(1)->type_id(), arrow::Type::UINT64); UNIT_ASSERT_EQUAL(batch->column(2)->type_id(), arrow::Type::INT16); @@ -86,6 +100,10 @@ void SumGroupBy(bool nullable, ui32 numKeys = 1) { UNIT_ASSERT_EQUAL(batch->column(3)->type_id(), arrow::Type::UINT32); } + if (emptySrc) { + return; + } + auto& sumX = static_cast<arrow::Int64Array&>(*batch->column(0)); auto& sumY = static_cast<arrow::UInt64Array&>(*batch->column(1)); auto& colX = static_cast<arrow::Int16Array&>(*batch->column(2)); @@ -310,10 +328,16 @@ Y_UNIT_TEST_SUITE(ProgramStep) { Y_UNIT_TEST(SumGroupBy) { SumGroupBy(true); SumGroupBy(true, 2); + + SumGroupBy(true, 1, true); + SumGroupBy(true, 2, true); } Y_UNIT_TEST(SumGroupByNotNull) { SumGroupBy(false); SumGroupBy(false, 2); + + SumGroupBy(false, 1, true); + SumGroupBy(false, 2, true); } } diff --git a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionWrapper.h b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionWrapper.h index 2269737ea3..ff05557525 100644 --- a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionWrapper.h +++ b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionWrapper.h @@ -68,6 +68,8 @@ public: } auto batch = arrow::RecordBatch::Make(std::make_shared<arrow::Schema>(fields), num_rows, columns); + if (!batch) + return arrow::Status::Invalid("Wrong aggregation arguments: cannot make batch"); AggregateDescription description { .function = getHouseFunction(types), @@ -141,8 +143,9 @@ public: columns.reserve(needed_columns.size()); fields.reserve(needed_columns.size()); - int64_t num_rows = 0; - for (int i = 0; i < opts->schema->num_fields(); ++i) { + std::optional<int64_t> num_rows; + for (int i = 0; i < opts->schema->num_fields(); ++i) + { auto& datum = args[i]; auto& field = opts->schema->field(i); @@ -151,7 +154,7 @@ public: if (datum.is_array()) { - if (num_rows && num_rows != datum.mutable_array()->length) + if (num_rows && *num_rows != datum.mutable_array()->length) return arrow::Status::Invalid("Arrays have different length"); num_rows = datum.mutable_array()->length; } @@ -161,7 +164,8 @@ public: if (!num_rows) // All datums are scalars num_rows = 1; - for (int i = 0; i < opts->schema->num_fields(); ++i) { + for (int i = 0; i < opts->schema->num_fields(); ++i) + { auto& datum = args[i]; auto& field = opts->schema->field(i); @@ -171,9 +175,9 @@ public: if (datum.is_scalar()) { // TODO: better GROUP BY over scalars - auto res = arrow::MakeArrayFromScalar(*datum.scalar(), num_rows); + auto res = arrow::MakeArrayFromScalar(*datum.scalar(), *num_rows); if (!res.ok()) - return arrow::Status::Invalid("Bad scalar: '" + field->name() + "'"); + return arrow::Status::Invalid("Bad scalar for '" + field->name() + "', " + res.status().ToString()); columns.push_back(*res); } else @@ -182,7 +186,7 @@ public: fields.push_back(field); } - batch = arrow::RecordBatch::Make(std::make_shared<arrow::Schema>(fields), num_rows, columns); + batch = arrow::RecordBatch::Make(std::make_shared<arrow::Schema>(fields), *num_rows, columns); if (!batch) return arrow::Status::Invalid("Wrong GROUP BY arguments: cannot make batch"); } @@ -240,7 +244,7 @@ public: AggregatingBlockInputStream agg_stream(input_stream, agg_params, true); auto result_batch = agg_stream.read(); - if (!result_batch || result_batch->num_rows() == 0) + if (!result_batch || (batch->num_rows() && !result_batch->num_rows())) return arrow::Status::Invalid("unexpected arrgerate result"); if (agg_stream.read()) return arrow::Status::Invalid("unexpected second batch in aggregate result"); diff --git a/ydb/library/arrow_clickhouse/DataStreams/OneBlockInputStream.h b/ydb/library/arrow_clickhouse/DataStreams/OneBlockInputStream.h index 6735022c46..1664d79db7 100644 --- a/ydb/library/arrow_clickhouse/DataStreams/OneBlockInputStream.h +++ b/ydb/library/arrow_clickhouse/DataStreams/OneBlockInputStream.h @@ -18,8 +18,9 @@ public: explicit OneBlockInputStream(Block block_) : block(std::move(block_)) { - if (!block->Validate().ok()) - throw Exception("Bad batch in OneBlockInputStream"); + auto status = block->Validate(); + if (!status.ok()) + throw Exception(std::string("Bad batch in OneBlockInputStream: ") + status.ToString()); } String getName() const override { return "One"; } |