author    chertus <azuikov@ydb.tech>  2022-09-07 18:06:43 +0300
committer chertus <azuikov@ydb.tech>  2022-09-07 18:06:43 +0300
commit    8b69c23276bd26357a31974a33df6c025100ed97 (patch)
tree      3451ec60471f4bc75b5ddd67317f49d7ac7e7e89
parent    d982847c9208f52a3f05cfcb1646c7945ae57087 (diff)
download  ydb-8b69c23276bd26357a31974a33df6c025100ed97.tar.gz
GroupBy over CH primitives
 ydb/core/formats/custom_registry.cpp                                       |   2
 ydb/core/formats/program.cpp                                               | 181
 ydb/core/formats/program.h                                                 |   4
 ydb/core/kqp/ut/kqp_olap_ut.cpp                                            |   3
 ydb/core/tx/columnshard/columnshard_common.cpp                             |   9
 ydb/core/tx/columnshard/ut_columnshard_read_write.cpp                      | 187
 ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionWrapper.h | 154
 ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.cpp     |  31
 ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.h       |  36
 ydb/library/arrow_clickhouse/CMakeLists.txt                                |   1
 10 files changed, 486 insertions(+), 122 deletions(-)
diff --git a/ydb/core/formats/custom_registry.cpp b/ydb/core/formats/custom_registry.cpp
index 590e3c3e07f..40be5634f2d 100644
--- a/ydb/core/formats/custom_registry.cpp
+++ b/ydb/core/formats/custom_registry.cpp
@@ -68,6 +68,8 @@ static void RegisterHouseAggregates(cp::FunctionRegistry* registry) {
Y_VERIFY(registry->AddFunction(std::make_shared<CH::WrappedMax>(GetHouseFunctionName(EAggregate::Max))).ok());
Y_VERIFY(registry->AddFunction(std::make_shared<CH::WrappedSum>(GetHouseFunctionName(EAggregate::Sum))).ok());
//Y_VERIFY(registry->AddFunction(std::make_shared<CH::WrappedAvg>(GetHouseFunctionName(EAggregate::Avg))).ok());
+
+ Y_VERIFY(registry->AddFunction(std::make_shared<CH::ArrowGroupBy>(GetHouseGroupByName())).ok());
} catch (const std::exception& /*ex*/) {
Y_VERIFY(false);
}
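
The registration above makes "ch.group_by" visible only through this custom registry, not Arrow's default one, so callers must pass an ExecContext wired to it. A minimal sketch of the lookup path, assuming the caller already holds such a context:

    // Sketch: CallFunction resolves the name via ctx->func_registry();
    // with a null context it would consult Arrow's default registry and fail.
    arrow::Result<arrow::Datum> RunHouseGroupBy(const std::vector<arrow::Datum>& args,
                                                const CH::GroupByOptions& opts,
                                                arrow::compute::ExecContext* ctx) {
        return arrow::compute::CallFunction(GetHouseGroupByName(), args, &opts, ctx);
    }
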
diff --git a/ydb/core/formats/program.cpp b/ydb/core/formats/program.cpp
index 4dd463fffc0..fb2d8f0d592 100644
--- a/ydb/core/formats/program.cpp
+++ b/ydb/core/formats/program.cpp
@@ -6,6 +6,32 @@
#include "program.h"
#include "arrow_helpers.h"
+
+#ifndef WIN32
+#include <AggregateFunctions/IAggregateFunction.h>
+#else
+namespace CH {
+enum class AggFunctionId {
+ AGG_UNSPECIFIED = 0,
+ AGG_ANY = 1,
+ AGG_COUNT = 2,
+ AGG_MIN = 3,
+ AGG_MAX = 4,
+ AGG_SUM = 5,
+};
+struct GroupByOptions : public arrow::compute::ScalarAggregateOptions {
+ struct Assign {
+ AggFunctionId function = AggFunctionId::AGG_UNSPECIFIED;
+ std::string result_column;
+ std::vector<std::string> arguments;
+ };
+
+ std::shared_ptr<arrow::Schema> schema;
+ std::vector<Assign> assigns;
+};
+}
+#endif
+
#include <util/system/yassert.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h>
@@ -280,12 +306,34 @@ const char * GetHouseFunctionName(EAggregate op) {
namespace {
+CH::AggFunctionId GetHouseFunction(EAggregate op) {
+ switch (op) {
+ case EAggregate::Some:
+ return CH::AggFunctionId::AGG_ANY;
+ case EAggregate::Count:
+ return CH::AggFunctionId::AGG_COUNT;
+ case EAggregate::Min:
+ return CH::AggFunctionId::AGG_MIN;
+ case EAggregate::Max:
+ return CH::AggFunctionId::AGG_MAX;
+ case EAggregate::Sum:
+ return CH::AggFunctionId::AGG_SUM;
+#if 0 // TODO
+ case EAggregate::Avg:
+ return CH::AggFunctionId::AGG_AVG;
+#endif
+ default:
+ break;
+ }
+ return CH::AggFunctionId::AGG_UNSPECIFIED;
+}
+
arrow::Status AddColumn(
TProgramStep::TDatumBatch& batch,
const std::string& name,
arrow::Datum&& column)
{
- if (batch.fields->GetFieldIndex(name) != -1) {
+ if (batch.schema->GetFieldIndex(name) != -1) {
return arrow::Status::Invalid("Trying to add duplicate column '" + name + "'");
}
@@ -297,13 +345,13 @@ arrow::Status AddColumn(
return arrow::Status::Invalid("Wrong column length.");
}
- batch.fields = *batch.fields->AddField(batch.fields->num_fields(), field);
+ batch.schema = *batch.schema->AddField(batch.schema->num_fields(), field);
batch.datums.emplace_back(column);
return arrow::Status::OK();
}
arrow::Result<arrow::Datum> GetColumnByName(const TProgramStep::TDatumBatch& batch, const std::string& name) {
- int i = batch.fields->GetFieldIndex(name);
+ int i = batch.schema->GetFieldIndex(name);
if (i == -1) {
return arrow::Status::Invalid("Not found column '" + name + "' or duplicate");
} else {
@@ -334,7 +382,7 @@ std::shared_ptr<arrow::RecordBatch> ToRecordBatch(TProgramStep::TDatumBatch& bat
columns.push_back(col.make_array());
}
}
- return arrow::RecordBatch::Make(batch.fields, batch.rows, columns);
+ return arrow::RecordBatch::Make(batch.schema, batch.rows, columns);
}
template <bool houseFunction, typename TOpId, typename TOptions>
@@ -398,6 +446,18 @@ arrow::Result<arrow::Datum> CallHouseFunctionByAssign(
}
}
+CH::GroupByOptions::Assign GetGroupByAssign(const TAggregateAssign& assign) {
+ CH::GroupByOptions::Assign descr;
+ descr.function = GetHouseFunction(assign.GetOperation());
+ descr.result_column = assign.GetName();
+ descr.arguments.reserve(assign.GetArguments().size());
+
+ for (auto& colName : assign.GetArguments()) {
+ descr.arguments.push_back(colName);
+ }
+ return descr;
+}
+
}
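
GetGroupByAssign above flattens one TAggregateAssign into the kernel's wire format. As an illustration, an assign for COUNT(level) AS res_count (column names assumed here) would come out as:

    // Sketch: the Assign produced for COUNT(level) AS res_count (names assumed).
    CH::GroupByOptions::Assign descr;
    descr.function = CH::AggFunctionId::AGG_COUNT; // GetHouseFunction(EAggregate::Count)
    descr.result_column = "res_count";             // assign.GetName()
    descr.arguments = {"level"};                   // assign.GetArguments()
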
@@ -441,50 +501,85 @@ arrow::Status TProgramStep::ApplyAggregates(
return arrow::Status::OK();
}
+ ui32 numResultColumns = GroupBy.size() + GroupByKeys.size();
TDatumBatch res;
- res.rows = 1; // TODO
- res.datums.reserve(GroupBy.size());
+ res.datums.reserve(numResultColumns);
arrow::FieldVector fields;
- fields.reserve(GroupBy.size());
+ fields.reserve(numResultColumns);
- for (auto& assign : GroupBy) {
- auto funcResult = CallFunctionByAssign(assign, batch, ctx);
- if (!funcResult.ok()) {
- auto houseResult = CallHouseFunctionByAssign(assign, batch, ctx);
- if (!houseResult.ok()) {
- return funcResult.status();
+ if (GroupByKeys.empty()) {
+ for (auto& assign : GroupBy) {
+ auto funcResult = CallFunctionByAssign(assign, batch, ctx);
+ if (!funcResult.ok()) {
+ auto houseResult = CallHouseFunctionByAssign(assign, batch, ctx);
+ if (!houseResult.ok()) {
+ return funcResult.status();
+ }
+ funcResult = houseResult;
+ }
+
+ res.datums.push_back(*funcResult);
+ auto& column = res.datums.back();
+ if (!column.is_scalar()) {
+ return arrow::Status::Invalid("Aggregate result is not a scalar.");
+ }
+
+ if (column.scalar()->type->id() == arrow::Type::STRUCT) {
+ auto op = assign.GetOperation();
+ if (op == EAggregate::Min) {
+ const auto& minMax = column.scalar_as<arrow::StructScalar>();
+ column = minMax.value[0];
+ } else if (op == EAggregate::Max) {
+ const auto& minMax = column.scalar_as<arrow::StructScalar>();
+ column = minMax.value[1];
+ } else {
+ return arrow::Status::Invalid("Unexpected struct result for aggregate function.");
+ }
}
- funcResult = houseResult;
+
+ if (!column.type()) {
+ return arrow::Status::Invalid("Aggregate result has no type.");
+ }
+ fields.emplace_back(std::make_shared<arrow::Field>(assign.GetName(), column.type()));
}
- res.datums.push_back(*funcResult);
- auto& column = res.datums.back();
- if (!column.is_scalar()) {
- return arrow::Status::Invalid("Aggregate result is not a scalar.");
+ res.rows = 1;
+ } else {
+ CH::GroupByOptions funcOpts;
+ funcOpts.schema = batch.schema;
+ funcOpts.assigns.reserve(numResultColumns);
+
+ for (auto& assign : GroupBy) {
+ funcOpts.assigns.emplace_back(GetGroupByAssign(assign));
}
- if (column.scalar()->type->id() == arrow::Type::STRUCT) {
- auto op = assign.GetOperation();
- if (op == EAggregate::Min) {
- const auto& minMax = column.scalar_as<arrow::StructScalar>();
- column = minMax.value[0];
- } else if (op == EAggregate::Max) {
- const auto& minMax = column.scalar_as<arrow::StructScalar>();
- column = minMax.value[1];
- } else {
- return arrow::Status::Invalid("Unexpected struct result for aggregate function");
- }
+ for (auto& key : GroupByKeys) {
+ funcOpts.assigns.emplace_back(CH::GroupByOptions::Assign{
+ .result_column = key
+ });
}
- if (!column.type()) {
- return arrow::Status::Invalid("Aggregate result has no type.");
+ auto gbRes = arrow::compute::CallFunction(GetHouseGroupByName(), batch.datums, &funcOpts, ctx);
+ if (!gbRes.ok()) {
+ return gbRes.status();
}
- fields.emplace_back(std::make_shared<arrow::Field>(assign.GetName(), column.type()));
+ auto gbBatch = (*gbRes).record_batch();
+
+ for (auto& assign : funcOpts.assigns) {
+ auto column = gbBatch->GetColumnByName(assign.result_column);
+ if (!column) {
+ return arrow::Status::Invalid("No expected column in GROUP BY result.");
+ }
+ fields.emplace_back(std::make_shared<arrow::Field>(assign.result_column, column->type()));
+ res.datums.push_back(column);
+ }
+
+ res.rows = gbBatch->num_rows();
}
- res.fields = std::make_shared<arrow::Schema>(fields);
- batch = res;
+ res.schema = std::make_shared<arrow::Schema>(fields);
+ batch = std::move(res);
return arrow::Status::OK();
}
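
Putting the two branches together: without keys each aggregate collapses to a one-row batch of scalars, while with keys the whole step is delegated to the ch.group_by kernel in a single call. A condensed sketch of the keyed path for SUM(value) GROUP BY key, with column names assumed:

    // Sketch: the keyed path, assuming batch.datums lines up with batch.schema.
    CH::GroupByOptions funcOpts;
    funcOpts.schema = batch.schema;                           // input names and types
    funcOpts.assigns = {
        {CH::AggFunctionId::AGG_SUM, "sum_value", {"value"}}, // aggregate
        {CH::AggFunctionId::AGG_UNSPECIFIED, "key", {}},      // AGG_UNSPECIFIED marks a key
    };
    auto gbRes = arrow::compute::CallFunction(GetHouseGroupByName(),
                                              batch.datums, &funcOpts, ctx);
    // On success, (*gbRes).record_batch() holds one row per distinct key.
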
@@ -521,20 +616,22 @@ arrow::Status TProgramStep::ApplyFilters(TDatumBatch& batch) const {
std::unordered_set<std::string_view> neededColumns;
bool allColumns = Projection.empty() && GroupBy.empty();
- if (!GroupBy.empty()) {
+ if (!allColumns) {
for (auto& aggregate : GroupBy) {
for (auto& arg : aggregate.GetArguments()) {
neededColumns.insert(arg);
}
}
- } else if (!Projection.empty()) {
+ for (auto& key : GroupByKeys) {
+ neededColumns.insert(key);
+ }
for (auto& str : Projection) {
neededColumns.insert(str);
}
}
- for (int64_t i = 0; i < batch.fields->num_fields(); ++i) {
- bool needed = (allColumns || neededColumns.contains(batch.fields->field(i)->name()));
+ for (int64_t i = 0; i < batch.schema->num_fields(); ++i) {
+ bool needed = (allColumns || neededColumns.contains(batch.schema->field(i)->name()));
if (batch.datums[i].is_array() && needed) {
auto res = arrow::compute::Filter(batch.datums[i].make_array(), filter);
if (!res.ok()) {
@@ -567,17 +664,17 @@ arrow::Status TProgramStep::ApplyProjection(TDatumBatch& batch) const {
}
std::vector<std::shared_ptr<arrow::Field>> newFields;
std::vector<arrow::Datum> newDatums;
- for (int64_t i = 0; i < batch.fields->num_fields(); ++i) {
- auto& cur_field_name = batch.fields->field(i)->name();
+ for (int64_t i = 0; i < batch.schema->num_fields(); ++i) {
+ auto& cur_field_name = batch.schema->field(i)->name();
if (projSet.contains(cur_field_name)) {
- newFields.push_back(batch.fields->field(i));
+ newFields.push_back(batch.schema->field(i));
if (!newFields.back()) {
return arrow::Status::Invalid("Wrong projection.");
}
newDatums.push_back(batch.datums[i]);
}
}
- batch.fields = std::make_shared<arrow::Schema>(newFields);
+ batch.schema = std::make_shared<arrow::Schema>(newFields);
batch.datums = std::move(newDatums);
return arrow::Status::OK();
}
diff --git a/ydb/core/formats/program.h b/ydb/core/formats/program.h
index f52a0b6de16..166e4e121d0 100644
--- a/ydb/core/formats/program.h
+++ b/ydb/core/formats/program.h
@@ -99,6 +99,7 @@ enum class EAggregate {
const char * GetFunctionName(EOperation op);
const char * GetFunctionName(EAggregate op);
const char * GetHouseFunctionName(EAggregate op);
+inline const char * GetHouseGroupByName() { return "ch.group_by"; }
EOperation ValidateOperation(EOperation op, ui32 argsSize);
class TAssign {
@@ -235,10 +236,11 @@ struct TProgramStep {
std::vector<TAssign> Assignes;
std::vector<std::string> Filters; // List of filter columns. Implicit "Filter by (f1 AND f2 AND .. AND fn)"
std::vector<TAggregateAssign> GroupBy;
+ std::vector<std::string> GroupByKeys; // TODO: it's possible to use them without GROUP BY for DISTINCT
std::vector<std::string> Projection; // Step's result columns (remove others)
struct TDatumBatch {
- std::shared_ptr<arrow::Schema> fields;
+ std::shared_ptr<arrow::Schema> schema;
int64_t rows;
std::vector<arrow::Datum> datums;
};
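
For illustration, the scan query exercised in the kqp test below (SELECT level, COUNT(level) ... GROUP BY level) maps onto a step roughly as follows. A sketch only: the TAggregateAssign constructor shown is an assumption, since this patch does not show it:

    // Sketch: step for "SELECT level, COUNT(level) GROUP BY level".
    // TAggregateAssign(op, name, args) is assumed, not verbatim from this patch.
    TProgramStep step;
    step.GroupBy.push_back(TAggregateAssign(EAggregate::Count, "res_count", {"level"}));
    step.GroupByKeys.push_back("level");
    step.Projection = {"res_count", "level"}; // keep only the result columns
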
diff --git a/ydb/core/kqp/ut/kqp_olap_ut.cpp b/ydb/core/kqp/ut/kqp_olap_ut.cpp
index a8a86168fc4..0a9dccdd87a 100644
--- a/ydb/core/kqp/ut/kqp_olap_ut.cpp
+++ b/ydb/core/kqp/ut/kqp_olap_ut.cpp
@@ -1398,13 +1398,14 @@ Y_UNIT_TEST_SUITE(KqpOlap) {
level, COUNT(level)
FROM `/Root/olapStore/olapTable`
GROUP BY level
+ ORDER BY level
)";
auto it = tableClient.StreamExecuteScanQuery(query).GetValueSync();
UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString());
TString result = StreamResultToYson(it);
Cout << result << Endl;
- CompareYson(result, R"([[23000u;]])");
+ CompareYson(result, R"([[[0];4600u];[[1];4600u];[[2];4600u];[[3];4600u];[[4];4600u]])");
// Check plan
CheckPlanForAggregatePushdown(query, tableClient, { "TKqpOlapAgg" });
diff --git a/ydb/core/tx/columnshard/columnshard_common.cpp b/ydb/core/tx/columnshard/columnshard_common.cpp
index 18d578e1778..aed13a5a278 100644
--- a/ydb/core/tx/columnshard/columnshard_common.cpp
+++ b/ydb/core/tx/columnshard/columnshard_common.cpp
@@ -305,11 +305,6 @@ bool ExtractGroupBy(const TContext& info, NArrow::TProgramStep& step, const NKik
if (!groupBy.AggregatesSize()) {
return false;
}
-#if 1 // TODO
- if (groupBy.KeyColumnsSize()) {
- return false;
- }
-#endif
// It adds implicit projection with aggregates and keys. Remove non aggregated columns.
step.Projection.reserve(groupBy.KeyColumnsSize() + groupBy.AggregatesSize());
@@ -318,6 +313,7 @@ bool ExtractGroupBy(const TContext& info, NArrow::TProgramStep& step, const NKik
}
step.GroupBy.reserve(groupBy.AggregatesSize());
+ step.GroupByKeys.reserve(groupBy.KeyColumnsSize());
for (auto& agg : groupBy.GetAggregates()) {
auto& resColumn = agg.GetColumn();
TString columnName = info.GenerateName(resColumn);
@@ -329,6 +325,9 @@ bool ExtractGroupBy(const TContext& info, NArrow::TProgramStep& step, const NKik
step.GroupBy.push_back(std::move(func));
step.Projection.push_back(columnName);
}
+ for (auto& key : groupBy.GetKeyColumns()) {
+ step.GroupByKeys.push_back(info.GetName(key));
+ }
return true;
}
diff --git a/ydb/core/tx/columnshard/ut_columnshard_read_write.cpp b/ydb/core/tx/columnshard/ut_columnshard_read_write.cpp
index 06edad85ce8..7d5f9b2e51e 100644
--- a/ydb/core/tx/columnshard/ut_columnshard_read_write.cpp
+++ b/ydb/core/tx/columnshard/ut_columnshard_read_write.cpp
@@ -114,6 +114,12 @@ bool CheckTypedStrValues(const std::shared_ptr<arrow::Array>& array, const std::
bool CheckIntValues(const std::shared_ptr<arrow::Array>& array, const std::vector<int64_t>& expected) {
UNIT_ASSERT(array);
+ std::vector<std::string> expectedStr;
+ expectedStr.reserve(expected.size());
+ for (auto& val : expected) {
+ expectedStr.push_back(ToString(val));
+ }
+
switch (array->type()->id()) {
case arrow::Type::UINT8:
return CheckTypedIntValues<arrow::UInt8Type>(array, expected);
@@ -137,25 +143,20 @@ bool CheckIntValues(const std::shared_ptr<arrow::Array>& array, const std::vecto
case arrow::Type::DURATION:
return CheckTypedIntValues<arrow::DurationType>(array, expected);
- default:
- UNIT_ASSERT(false);
- break;
- }
- return true;
-}
-
-bool CheckStringValues(const std::shared_ptr<arrow::Array>& array, const std::vector<std::string>& expected) {
- UNIT_ASSERT(array);
+ case arrow::Type::FLOAT:
+ return CheckTypedIntValues<arrow::FloatType>(array, expected);
+ case arrow::Type::DOUBLE:
+ return CheckTypedIntValues<arrow::DoubleType>(array, expected);
- switch (array->type()->id()) {
case arrow::Type::STRING:
- return CheckTypedStrValues<arrow::StringArray>(array, expected);
+ return CheckTypedStrValues<arrow::StringArray>(array, expectedStr);
case arrow::Type::BINARY:
- return CheckTypedStrValues<arrow::BinaryArray>(array, expected);
+ return CheckTypedStrValues<arrow::BinaryArray>(array, expectedStr);
case arrow::Type::FIXED_SIZE_BINARY:
- return CheckTypedStrValues<arrow::FixedSizeBinaryArray>(array, expected);
+ return CheckTypedStrValues<arrow::FixedSizeBinaryArray>(array, expectedStr);
default:
+ Cerr << "type : " << array->type()->ToString() << "\n";
UNIT_ASSERT(false);
break;
}
@@ -185,7 +186,7 @@ bool CheckOrdered(const TString& blob, const TString& srtSchema) {
return true;
}
-bool CheckColumns(const std::shared_ptr<arrow::RecordBatch>& batch, const TVector<TString>& colNames, size_t rowsCount) {
+bool CheckColumns(const std::shared_ptr<arrow::RecordBatch>& batch, const std::vector<TString>& colNames, size_t rowsCount) {
UNIT_ASSERT(batch);
for (auto& col : colNames) {
if (!batch->GetColumnByName(col)) {
@@ -200,7 +201,7 @@ bool CheckColumns(const std::shared_ptr<arrow::RecordBatch>& batch, const TVecto
return true;
}
-bool CheckColumns(const TString& blob, const NKikimrTxColumnShard::TMetadata& meta, const TVector<TString>& colNames,
+bool CheckColumns(const TString& blob, const NKikimrTxColumnShard::TMetadata& meta, const std::vector<TString>& colNames,
size_t rowsCount = 100) {
auto schema = NArrow::DeserializeSchema(meta.GetSchema());
auto batch = NArrow::DeserializeBatch(blob, schema);
@@ -975,7 +976,7 @@ NKikimrSSA::TProgram MakeSelectAggregates(ui32 columnId, const std::vector<ui32>
auto* line1 = ssa.AddCommand();
auto* groupBy = line1->MutableGroupBy();
for (ui32 key : keys) {
- groupBy->AddKeyColumns()->SetId(key);
+ groupBy->AddKeyColumns()->SetId(key + 1);
}
//
auto* l1_agg1 = groupBy->AddAggregates();
@@ -1040,7 +1041,7 @@ NKikimrSSA::TProgram MakeSelectAggregatesWithFilter(ui32 columnId, ui32 filterCo
auto* line4 = ssa.AddCommand();
auto* groupBy = line4->MutableGroupBy();
for (ui32 key : keys) {
- groupBy->AddKeyColumns()->SetId(key);
+ groupBy->AddKeyColumns()->SetId(key + 1);
}
//
auto* l4_agg1 = groupBy->AddAggregates();
@@ -1210,7 +1211,18 @@ void TestReadWithProgram(const TVector<std::pair<TString, TTypeId>>& ydbSchema =
}
}
-void TestReadAggregate(const TVector<std::pair<TString, TTypeId>>& ydbSchema, const std::vector<ui32>& aggKeys = {}) {
+struct TReadAggregateResult {
+ ui32 NumRows = 1;
+
+ std::vector<int64_t> MinValues = {0};
+ std::vector<int64_t> MaxValues = {99};
+ std::vector<int64_t> Counts = {100};
+};
+
+void TestReadAggregate(const TVector<std::pair<TString, TTypeId>>& ydbSchema, const TString& testDataBlob,
+ bool addProjection, const std::vector<ui32>& aggKeys = {},
+ const TReadAggregateResult& expectedResult = {},
+ const TReadAggregateResult& expectedFiltered = {1, {1}, {1}, {1}}) {
TTestBasicRuntime runtime;
TTester::Setup(runtime);
@@ -1230,7 +1242,7 @@ void TestReadAggregate(const TVector<std::pair<TString, TTypeId>>& ydbSchema, co
SetupSchema(runtime, sender, tableId, ydbSchema);
{ // write some data
- bool ok = WriteData(runtime, sender, metaShard, writeId, tableId, MakeTestBlob({0, 100}, ydbSchema));
+ bool ok = WriteData(runtime, sender, metaShard, writeId, tableId, testDataBlob);
UNIT_ASSERT(ok);
ProposeCommit(runtime, sender, metaShard, txId, {writeId});
@@ -1240,29 +1252,26 @@ void TestReadAggregate(const TVector<std::pair<TString, TTypeId>>& ydbSchema, co
// TODO: write some into index
std::vector<TString> programs;
- THashSet<ui32> intResult;
- THashSet<ui32> strResult;
THashSet<ui32> isFiltered;
+ THashSet<ui32> checkResult;
THashSet<NScheme::TTypeId> intTypes = {
NTypeIds::Int8, NTypeIds::Int16, NTypeIds::Int32, NTypeIds::Int64,
NTypeIds::Uint8, NTypeIds::Uint16, NTypeIds::Uint32, NTypeIds::Uint64,
NTypeIds::Timestamp
};
THashSet<NScheme::TTypeId> strTypes = {
- NTypeIds::Utf8, NTypeIds::String, NTypeIds::Bytes
+ NTypeIds::Utf8, NTypeIds::String
//NTypeIds::Yson, NTypeIds::Json, NTypeIds::JsonDocument
};
ui32 prog = 0;
for (ui32 i = 0; i < ydbSchema.size(); ++i, ++prog) {
- if (intTypes.count(ydbSchema[i].second)) {
- intResult.insert(prog);
- }
- if (strTypes.count(ydbSchema[i].second)) {
- strResult.insert(prog);
+ if (intTypes.count(ydbSchema[i].second) ||
+ strTypes.count(ydbSchema[i].second)) {
+ checkResult.insert(prog);
}
- NKikimrSSA::TProgram ssa = MakeSelectAggregates(i + 1, aggKeys, i % 2);
+ NKikimrSSA::TProgram ssa = MakeSelectAggregates(i + 1, aggKeys, addProjection);
TString serialized;
UNIT_ASSERT(ssa.SerializeToString(&serialized));
NKikimrSSA::TOlapProgram program;
@@ -1274,14 +1283,12 @@ void TestReadAggregate(const TVector<std::pair<TString, TTypeId>>& ydbSchema, co
for (ui32 i = 0; i < ydbSchema.size(); ++i, ++prog) {
isFiltered.insert(prog);
- if (intTypes.count(ydbSchema[i].second)) {
- intResult.insert(prog);
- }
- if (strTypes.count(ydbSchema[i].second)) {
- strResult.insert(prog);
+ if (intTypes.count(ydbSchema[i].second) ||
+ strTypes.count(ydbSchema[i].second)) {
+ checkResult.insert(prog);
}
- NKikimrSSA::TProgram ssa = MakeSelectAggregatesWithFilter(i + 1, 4, aggKeys, i % 2);
+ NKikimrSSA::TProgram ssa = MakeSelectAggregatesWithFilter(i + 1, 4, aggKeys, addProjection);
TString serialized;
UNIT_ASSERT(ssa.SerializeToString(&serialized));
NKikimrSSA::TOlapProgram program;
@@ -1291,8 +1298,19 @@ void TestReadAggregate(const TVector<std::pair<TString, TTypeId>>& ydbSchema, co
UNIT_ASSERT(program.SerializeToString(&programs.back()));
}
+ std::vector<TString> namedColumns = {"res_min", "res_max", "res_some", "res_count"};
+ std::vector<TString> unnamedColumns = {"100", "101", "102", "103"};
+ if (!addProjection) {
+ for (auto& key : aggKeys) {
+ namedColumns.push_back(ydbSchema[key].first);
+ unnamedColumns.push_back(ydbSchema[key].first);
+ }
+ }
+
prog = 0;
for (auto& programText : programs) {
+ Cerr << "-- select program: " << prog << " is filtered: " << (int)isFiltered.count(prog) << "\n";
+
auto* readEvent = new TEvColumnShard::TEvRead(sender, metaShard, planStep, txId, tableId);
auto& readProto = Proto(readEvent);
@@ -1325,38 +1343,24 @@ void TestReadAggregate(const TVector<std::pair<TString, TTypeId>>& ydbSchema, co
UNIT_ASSERT(batch->ValidateFull().ok());
}
- if (aggKeys.empty()) {
- if (intResult.count(prog)) {
- if (isFiltered.count(prog)) {
- UNIT_ASSERT(CheckColumns(batch, {"res_min", "res_max", "res_some", "res_count"}, 1));
- UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("res_min"), {1}));
- UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("res_max"), {1}));
- UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("res_some"), {1}));
- UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("res_count"), {1}));
- } else {
- UNIT_ASSERT(CheckColumns(batch, {"100", "101", "102", "103"}, 1));
- UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("100"), {0}));
- UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("101"), {99}));
- UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("102"), {0}));
- UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("103"), {100}));
+ if (checkResult.count(prog)) {
+ if (isFiltered.count(prog)) {
+ UNIT_ASSERT(CheckColumns(batch, namedColumns, expectedFiltered.NumRows));
+ if (aggKeys.empty()) { // TODO: ORDER BY for compare
+ UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("res_min"), expectedFiltered.MinValues));
+ UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("res_max"), expectedFiltered.MaxValues));
+ UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("res_some"), expectedFiltered.MinValues));
}
- } else if (strResult.count(prog)) {
- if (isFiltered.count(prog)) {
- UNIT_ASSERT(CheckColumns(batch, {"res_min", "res_max", "res_some", "res_count"}, 1));
- UNIT_ASSERT(CheckStringValues(batch->GetColumnByName("res_min"), {"1"}));
- UNIT_ASSERT(CheckStringValues(batch->GetColumnByName("res_max"), {"1"}));
- UNIT_ASSERT(CheckStringValues(batch->GetColumnByName("res_some"), {"1"}));
- UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("res_count"), {1}));
- } else {
- UNIT_ASSERT(CheckColumns(batch, {"100", "101", "102", "103"}, 1));
- UNIT_ASSERT(CheckStringValues(batch->GetColumnByName("100"), {"0"}));
- UNIT_ASSERT(CheckStringValues(batch->GetColumnByName("101"), {"99"}));
- UNIT_ASSERT(CheckStringValues(batch->GetColumnByName("102"), {"0"}));
- UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("103"), {100}));
+ UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("res_count"), expectedFiltered.Counts));
+ } else {
+ UNIT_ASSERT(CheckColumns(batch, unnamedColumns, expectedResult.NumRows));
+ if (aggKeys.empty()) { // TODO: ORDER BY for compare
+ UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("100"), expectedResult.MinValues));
+ UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("101"), expectedResult.MaxValues));
+ UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("102"), expectedResult.MinValues));
}
+ UNIT_ASSERT(CheckIntValues(batch->GetColumnByName("103"), expectedResult.Counts));
}
- } else {
- // TODO
}
++prog;
@@ -1411,25 +1415,62 @@ Y_UNIT_TEST_SUITE(TColumnShardTestReadWrite) {
}
Y_UNIT_TEST(ReadAggregate) {
- TestReadAggregate(TTestSchema::YdbAllTypesSchema());
+ auto schema = TTestSchema::YdbAllTypesSchema();
+ auto testBlob = MakeTestBlob({0, 100}, schema);
+
+ TestReadAggregate(schema, testBlob, false);
+ TestReadAggregate(schema, testBlob, true);
}
-#if 0
Y_UNIT_TEST(ReadGroupBy) {
auto schema = TTestSchema::YdbAllTypesSchema();
- for (ui32 keyPos = 0; keyPos < schema.size(); ++keyPos) {
- TestReadAggregate(schema, {keyPos});
- }
+ auto testBlob = MakeTestBlob({0, 100}, schema);
- for (ui32 keyPos = 0; keyPos < schema.size() - 1; ++keyPos) {
- TestReadAggregate(schema, {keyPos, keyPos + 1});
+ std::vector<int64_t> counts;
+ counts.reserve(100);
+ for (int i = 0; i < 100; ++i) {
+ counts.push_back(1);
}
- for (ui32 keyPos = 0; keyPos < schema.size() - 2; ++keyPos) {
- TestReadAggregate(schema, {keyPos, keyPos + 1, keyPos + 2});
+ THashSet<NScheme::TTypeId> sameValTypes = {
+ NTypeIds::Yson, NTypeIds::Json, NTypeIds::JsonDocument
+ };
+
+ // TODO: query needs normalization to compare with expected
+ TReadAggregateResult resDefault = {100, {}, {}, counts};
+ TReadAggregateResult resFiltered = {1, {}, {}, {1}};
+ TReadAggregateResult resGrouped = {1, {}, {}, {100}};
+
+ for (ui32 key = 0; key < schema.size(); ++key) {
+ Cerr << "-- group by key: " << key << "\n";
+
+ // this type holds the same value in every row of the test batch, so the result collapses into a single group
+ if (sameValTypes.count(schema[key].second)) {
+ TestReadAggregate(schema, testBlob, (key % 2), {key}, resGrouped, resFiltered);
+ } else {
+ TestReadAggregate(schema, testBlob, (key % 2), {key}, resDefault, resFiltered);
+ }
+ }
+ for (ui32 key = 0; key < schema.size() - 1; ++key) {
+ Cerr << "-- group by key: " << key << ", " << key + 1 << "\n";
+ if (sameValTypes.count(schema[key].second) &&
+ sameValTypes.count(schema[key + 1].second)) {
+ TestReadAggregate(schema, testBlob, (key % 2), {key, key + 1}, resGrouped, resFiltered);
+ } else {
+ TestReadAggregate(schema, testBlob, (key % 2), {key, key + 1}, resDefault, resFiltered);
+ }
+ }
+ for (ui32 key = 0; key < schema.size() - 2; ++key) {
+ Cerr << "-- group by key: " << key << ", " << key + 1 << ", " << key + 2 << "\n";
+ if (sameValTypes.count(schema[key].second) &&
+ sameValTypes.count(schema[key + 1].second) &&
+ sameValTypes.count(schema[key + 2].second)) {
+ TestReadAggregate(schema, testBlob, (key % 2), {key, key + 1, key + 2}, resGrouped, resFiltered);
+ } else {
+ TestReadAggregate(schema, testBlob, (key % 2), {key, key + 1, key + 2}, resDefault, resFiltered);
+ }
}
}
-#endif
Y_UNIT_TEST(CompactionSplitGranule) {
TTestBasicRuntime runtime;
diff --git a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionWrapper.h b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionWrapper.h
index 732b07f8e3e..bf25ab585ef 100644
--- a/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionWrapper.h
+++ b/ydb/library/arrow_clickhouse/AggregateFunctions/AggregateFunctionWrapper.h
@@ -5,6 +5,8 @@
#include <DataStreams/OneBlockInputStream.h>
#include <DataStreams/AggregatingBlockInputStream.h>
+#include <unordered_set>
+
namespace CH
{
@@ -86,4 +88,156 @@ public:
}
};
+class ArrowGroupBy : public arrow::compute::ScalarAggregateFunction
+{
+public:
+ ArrowGroupBy(std::string name)
+ : arrow::compute::ScalarAggregateFunction(std::move(name), arrow::compute::Arity::VarArgs(1), nullptr)
+ {}
+
+ arrow::Result<arrow::Datum> Execute(
+ const std::vector<arrow::Datum>& args,
+ const arrow::compute::FunctionOptions* options,
+ arrow::compute::ExecContext* /*ctx*/) const override
+ {
+ if (args.empty())
+ return arrow::Status::Invalid("GROUP BY without arguments");
+ if (!options)
+ return arrow::Status::Invalid("GROUP BY without options");
+
+ auto* opts = dynamic_cast<const GroupByOptions*>(options);
+ if (!opts || !opts->schema)
+ return arrow::Status::Invalid("Wrong GROUP BY options");
+ if ((int)args.size() != opts->schema->num_fields())
+ return arrow::Status::Invalid("Wrong GROUP BY arguments count");
+
+ // Find needed columns
+ std::unordered_set<std::string> needed_columns;
+ needed_columns.reserve(opts->assigns.size());
+ for (auto& assign : opts->assigns)
+ {
+ if (assign.function != AggFunctionId::AGG_UNSPECIFIED)
+ {
+ for (auto& agg_arg : assign.arguments)
+ needed_columns.insert(agg_arg);
+ }
+ else
+ needed_columns.insert(assign.result_column);
+ }
+
+ // Make batch with needed columns
+ std::shared_ptr<arrow::RecordBatch> batch;
+ {
+ std::vector<std::shared_ptr<arrow::Array>> columns;
+ std::vector<std::shared_ptr<arrow::Field>> fields;
+ columns.reserve(needed_columns.size());
+ fields.reserve(needed_columns.size());
+
+ int64_t num_rows = 0;
+ for (int i = 0; i < opts->schema->num_fields(); ++i) {
+ auto& datum = args[i];
+ auto& field = opts->schema->field(i);
+
+ if (!needed_columns.count(field->name()))
+ continue;
+
+ if (datum.is_array())
+ {
+ if (num_rows && num_rows != datum.mutable_array()->length)
+ return arrow::Status::Invalid("Arrays have different length");
+ num_rows = datum.mutable_array()->length;
+ }
+ else if (!datum.is_scalar())
+ return arrow::Status::Invalid("Bad scalar: '" + field->name() + "'");
+ }
+ if (!num_rows) // All datums are scalars
+ num_rows = 1;
+
+ for (int i = 0; i < opts->schema->num_fields(); ++i) {
+ auto& datum = args[i];
+ auto& field = opts->schema->field(i);
+
+ if (!needed_columns.count(field->name()))
+ continue;
+
+ if (datum.is_scalar())
+ {
+ // TODO: better GROUP BY over scalars
+ auto res = arrow::MakeArrayFromScalar(*datum.scalar(), num_rows);
+ if (!res.ok())
+ return arrow::Status::Invalid("Bad scalar: '" + field->name() + "'");
+ columns.push_back(*res);
+ }
+ else
+ columns.push_back(datum.make_array());
+
+ fields.push_back(field);
+ }
+
+ batch = arrow::RecordBatch::Make(std::make_shared<arrow::Schema>(fields), num_rows, columns);
+ if (!batch)
+ return arrow::Status::Invalid("Wrong GROUP BY arguments: cannot make batch");
+ }
+
+ // Make aggregate descriptions
+ std::vector<AggregateDescription> descriptions;
+ ColumnNumbers keys;
+ {
+ descriptions.reserve(opts->assigns.size());
+ keys.reserve(opts->assigns.size());
+
+ auto& schema = batch->schema();
+
+ for (auto& assign : opts->assigns)
+ {
+ if (assign.function != AggFunctionId::AGG_UNSPECIFIED)
+ {
+ ColumnNumbers arg_positions;
+ arg_positions.reserve(assign.arguments.size());
+ DataTypes types;
+ types.reserve(assign.arguments.size());
+
+ for (auto& agg_arg : assign.arguments) {
+ int pos = schema->GetFieldIndex(agg_arg);
+ if (pos < 0)
+ return arrow::Status::Invalid("Unexpected aggregate function argument in GROUP BY");
+ arg_positions.push_back(pos);
+ types.push_back(schema->field(pos)->type());
+ }
+
+ AggregateFunctionPtr func = GetAggregateFunction(assign.function, types);
+ if (!func)
+ return arrow::Status::Invalid("Unexpected agregate function in GROUP BY");
+
+ descriptions.emplace_back(AggregateDescription{
+ .function = func,
+ .arguments = arg_positions,
+ .column_name = assign.result_column
+ });
+ } else {
+ int pos = schema->GetFieldIndex(assign.result_column);
+ if (pos < 0)
+ return arrow::Status::Invalid("Unexpected key in GROUP BY: '" + assign.result_column + "'");
+ keys.push_back(pos);
+ }
+ }
+ }
+
+ // GROUP BY
+
+ auto input_stream = std::make_shared<OneBlockInputStream>(batch);
+
+ Aggregator::Params agg_params(false, input_stream->getHeader(), keys, descriptions, false);
+ AggregatingBlockInputStream agg_stream(input_stream, agg_params, true);
+
+ auto result_batch = agg_stream.read();
+ if (!result_batch || result_batch->num_rows() == 0)
+ return arrow::Status::Invalid("unexpected arrgerate result");
+ if (agg_stream.read())
+ return arrow::Status::Invalid("unexpected second batch in aggregate result");
+
+ return arrow::Result<arrow::Datum>(result_batch);
+ }
+};
+
}
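
The wrapper's contract: one datum per field of opts->schema, in schema order, scalars broadcast to the common row count, and a RecordBatch wrapped in a Datum as the result. A self-contained sketch of driving it directly; the column names, values, and int64 layout are assumptions for illustration:

    // Sketch: calling the ArrowGroupBy kernel directly (assumed data).
    arrow::Result<arrow::Datum> ExampleGroupBy() {
        arrow::Int64Builder kb, vb;
        ARROW_RETURN_NOT_OK(kb.AppendValues({1, 1, 2}));
        ARROW_RETURN_NOT_OK(vb.AppendValues({10, 20, 30}));
        ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Array> keyArr, kb.Finish());
        ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Array> valArr, vb.Finish());

        CH::GroupByOptions opts;
        opts.schema = arrow::schema({arrow::field("key", arrow::int64()),
                                     arrow::field("value", arrow::int64())});
        opts.assigns = {
            {CH::AggFunctionId::AGG_SUM, "sum_value", {"value"}},
            {CH::AggFunctionId::AGG_UNSPECIFIED, "key", {}}, // key marker
        };

        // One datum per schema field, in schema order.
        std::vector<arrow::Datum> args{keyArr, valArr};
        CH::ArrowGroupBy func("ch.group_by");
        // Expected: two rows, {sum_value=30, key=1} and {sum_value=30, key=2}.
        return func.Execute(args, &opts, /*ctx=*/nullptr);
    }
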
diff --git a/ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.cpp b/ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.cpp
new file mode 100644
index 00000000000..87eccca5e41
--- /dev/null
+++ b/ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.cpp
@@ -0,0 +1,31 @@
+#include <AggregateFunctions/IAggregateFunction.h>
+#include <AggregateFunctions/AggregateFunctionMinMaxAny.h>
+#include <AggregateFunctions/AggregateFunctionCount.h>
+#include <AggregateFunctions/AggregateFunctionSum.h>
+#include <AggregateFunctions/AggregateFunctionAvg.h>
+
+namespace CH
+{
+
+AggregateFunctionPtr GetAggregateFunction(AggFunctionId id, const DataTypes & argument_types)
+{
+ switch (id) {
+ case AggFunctionId::AGG_ANY:
+ return WrappedAny("").getHouseFunction(argument_types);
+ case AggFunctionId::AGG_COUNT:
+ return WrappedCount("").getHouseFunction(argument_types);
+ case AggFunctionId::AGG_MIN:
+ return WrappedMin("").getHouseFunction(argument_types);
+ case AggFunctionId::AGG_MAX:
+ return WrappedMax("").getHouseFunction(argument_types);
+ case AggFunctionId::AGG_SUM:
+ return WrappedSum("").getHouseFunction(argument_types);
+ case AggFunctionId::AGG_AVG:
+ return WrappedAvg("").getHouseFunction(argument_types);
+ default:
+ break;
+ }
+ return {};
+}
+
+}
diff --git a/ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.h b/ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.h
index dc23c3a2a8f..264c44a4fc5 100644
--- a/ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.h
+++ b/ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.h
@@ -481,4 +481,40 @@ struct AggregateFunctionProperties
bool is_order_dependent = false;
};
+enum class AggFunctionId {
+ AGG_UNSPECIFIED = 0,
+ AGG_ANY = 1,
+ AGG_COUNT = 2,
+ AGG_MIN = 3,
+ AGG_MAX = 4,
+ AGG_SUM = 5,
+ AGG_AVG = 6,
+ //AGG_VAR = 7,
+ //AGG_COVAR = 8,
+ //AGG_STDDEV = 9,
+ //AGG_CORR = 10,
+ //AGG_ARG_MIN = 11,
+ //AGG_ARG_MAX = 12,
+ //AGG_COUNT_DISTINCT = 13,
+ //AGG_QUANTILES = 14,
+ //AGG_TOP_COUNT = 15,
+ //AGG_TOP_SUM = 16,
+};
+
+struct GroupByOptions : public arrow::compute::ScalarAggregateOptions {
+ // The result has to contain both the aggregates and the aggregate keys.
+ // A pair {AGG_UNSPECIFIED, result_column} marks a key.
+ // Keeping aggregates and keys in one vector fixes their order in the result.
+ struct Assign {
+ AggFunctionId function = AggFunctionId::AGG_UNSPECIFIED;
+ std::string result_column;
+ std::vector<std::string> arguments;
+ };
+
+ std::shared_ptr<arrow::Schema> schema; // types and names of input arguments
+ std::vector<Assign> assigns; // aggregates and keys in needed result order
+};
+
+AggregateFunctionPtr GetAggregateFunction(AggFunctionId, const DataTypes & argument_types);
+
}
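
GetAggregateFunction (implemented in the new IAggregateFunction.cpp above) bridges the wire-level AggFunctionId to a ClickHouse aggregate instance, returning a null pointer for unsupported ids. A sketch of resolving a sum over one int64 argument; DataTypes holding arrow type pointers matches its use in the kernel code above:

    // Sketch: resolving AGG_SUM for a single int64 argument.
    CH::DataTypes types = {arrow::int64()};
    CH::AggregateFunctionPtr func =
        CH::GetAggregateFunction(CH::AggFunctionId::AGG_SUM, types);
    if (!func) {
        // unsupported id/type combination
    }
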
diff --git a/ydb/library/arrow_clickhouse/CMakeLists.txt b/ydb/library/arrow_clickhouse/CMakeLists.txt
index 012192db9e2..afe895bcb27 100644
--- a/ydb/library/arrow_clickhouse/CMakeLists.txt
+++ b/ydb/library/arrow_clickhouse/CMakeLists.txt
@@ -27,6 +27,7 @@ target_link_libraries(ydb-library-arrow_clickhouse PUBLIC
library-arrow_clickhouse-DataStreams
)
target_sources(ydb-library-arrow_clickhouse PRIVATE
+ ${CMAKE_SOURCE_DIR}/ydb/library/arrow_clickhouse/AggregateFunctions/IAggregateFunction.cpp
${CMAKE_SOURCE_DIR}/ydb/library/arrow_clickhouse/Aggregator.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/udfs/common/clickhouse/client/base/common/mremap.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/udfs/common/clickhouse/client/base/common/getPageSize.cpp