diff options
author | Artem Zuikov <chertus@gmail.com> | 2022-06-01 11:37:03 +0300 |
---|---|---|
committer | Artem Zuikov <chertus@gmail.com> | 2022-06-01 11:37:03 +0300 |
commit | c051815bfe2221ba1b1ae457868e7ca59fd58fca (patch) | |
tree | a1d835d8d90b9f4f73a2c2b358cf3a44922d635e | |
parent | 45799c2715cdd8b97f99eba8864d365389056b3d (diff) | |
download | ydb-c051815bfe2221ba1b1ae457868e7ca59fd58fca.tar.gz |
KIKIMR-14822: call basic aggregates from arrow library in ColumnShard
ref:94ad8071bd9d77eadbc386cca245b6fcf0eb4854
-rw-r--r-- | ydb/core/formats/program.cpp | 219 | ||||
-rw-r--r-- | ydb/core/formats/program.h | 46 | ||||
-rw-r--r-- | ydb/core/protos/ssa.proto | 8 | ||||
-rw-r--r-- | ydb/core/tx/columnshard/columnshard_common.cpp | 163 | ||||
-rw-r--r-- | ydb/core/tx/columnshard/ut_columnshard_read_write.cpp | 23 |
5 files changed, 346 insertions, 113 deletions
diff --git a/ydb/core/formats/program.cpp b/ydb/core/formats/program.cpp index f8d8287ab5..10534a6435 100644 --- a/ydb/core/formats/program.cpp +++ b/ydb/core/formats/program.cpp @@ -151,9 +151,114 @@ const char * GetFunctionName(EOperation op) { return ""; } -void AddColumn(std::shared_ptr<TProgramStep::TDatumBatch>& batch, - std::string field_name, - const arrow::Datum& column) { +EOperation ValidateOperation(EOperation op, ui32 argsSize) { + switch (op) { + case EOperation::Equal: + case EOperation::NotEqual: + case EOperation::Less: + case EOperation::LessEqual: + case EOperation::Greater: + case EOperation::GreaterEqual: + case EOperation::MatchSubstring: + case EOperation::And: + case EOperation::Or: + case EOperation::Xor: + case EOperation::Add: + case EOperation::Subtract: + case EOperation::Multiply: + case EOperation::Divide: + case EOperation::Modulo: + case EOperation::AddNotNull: + case EOperation::SubtractNotNull: + case EOperation::MultiplyNotNull: + case EOperation::DivideNotNull: + case EOperation::ModuloOrZero: + case EOperation::Gcd: + case EOperation::Lcm: + if (argsSize == 2) { + return op; + } + break; + + case EOperation::CastBoolean: + case EOperation::CastInt8: + case EOperation::CastInt16: + case EOperation::CastInt32: + case EOperation::CastInt64: + case EOperation::CastUInt8: + case EOperation::CastUInt16: + case EOperation::CastUInt32: + case EOperation::CastUInt64: + case EOperation::CastFloat: + case EOperation::CastDouble: + case EOperation::CastBinary: + case EOperation::CastFixedSizeBinary: + case EOperation::CastString: + case EOperation::CastTimestamp: + case EOperation::IsValid: + case EOperation::IsNull: + case EOperation::BinaryLength: + case EOperation::Invert: + case EOperation::Abs: + case EOperation::Negate: + if (argsSize == 1) { + return op; + } + break; + + case EOperation::Acosh: + case EOperation::Atanh: + case EOperation::Cbrt: + case EOperation::Cosh: + case EOperation::E: + case EOperation::Erf: + case EOperation::Erfc: + case EOperation::Exp: + case EOperation::Exp2: + case EOperation::Exp10: + case EOperation::Hypot: + case EOperation::Lgamma: + case EOperation::Pi: + case EOperation::Sinh: + case EOperation::Sqrt: + case EOperation::Tgamma: + case EOperation::Floor: + case EOperation::Ceil: + case EOperation::Trunc: + case EOperation::Round: + case EOperation::RoundBankers: + case EOperation::RoundToExp2: + return op; // TODO: check + + default: + break; + } + return EOperation::Unspecified; +} + +const char * GetFunctionName(EAggregate op) { + switch (op) { + case EAggregate::Any: + return "any"; + case EAggregate::Count: + return "count"; + case EAggregate::Min: + return "min_max"; + case EAggregate::Max: + return "min_max"; + case EAggregate::Sum: + return "sum"; + case EAggregate::Avg: + return "mean"; + + default: + break; + } + return ""; +} + + +void AddColumn(std::shared_ptr<TProgramStep::TDatumBatch>& batch, std::string field_name, const arrow::Datum& column) { auto field = ::arrow::field(std::move(field_name), column.type()); Y_VERIFY(field != nullptr); Y_VERIFY(field->type()->Equals(column.type())); @@ -204,55 +309,12 @@ std::shared_ptr<arrow::Array> MakeConstantColumn(const arrow::Scalar& value, int return *res; } -//firstly try to call function from custom registry, if fails call from default -arrow::Result<arrow::Datum> CallFromCustomOrDefaultRegistry(EOperation funcId, const std::vector<arrow::Datum>& arguments, arrow::compute::ExecContext* ctx) { - std::string funcName = GetFunctionName(funcId); - if (ctx != nullptr && ctx->func_registry()->GetFunction(funcName).ok()) { - return arrow::compute::CallFunction(GetFunctionName(funcId), arguments, ctx); - } else { - return arrow::compute::CallFunction(GetFunctionName(funcId), arguments); - } -} - -std::shared_ptr<arrow::Array> CallArrayFunction(EOperation funcId, const std::vector<std::string>& args, - std::shared_ptr<arrow::RecordBatch> batch, arrow::compute::ExecContext* ctx) { - std::vector<arrow::Datum> arguments; - arguments.reserve(args.size()); - - for (auto& colName : args) { - auto column = batch->GetColumnByName(colName); - Y_VERIFY(column); - arguments.push_back(arrow::Datum(*column)); - } - std::string funcName = GetFunctionName(funcId); - arrow::Result<arrow::Datum> result; - result = CallFromCustomOrDefaultRegistry(funcId, arguments, ctx); - Y_VERIFY(result.ok()); - Y_VERIFY(result->is_array()); - return result->make_array(); -} - - -std::shared_ptr<arrow::Scalar> CallScalarFunction(EOperation funcId, const std::vector<std::string>& args, - std::shared_ptr<arrow::RecordBatch> batch, arrow::compute::ExecContext* ctx) { - std::vector<arrow::Datum> arguments; - arguments.reserve(args.size()); - - for (auto& colName : args) { - auto column = batch->GetColumnByName(colName); - Y_VERIFY(column); - arguments.push_back(arrow::Datum{column}); - } - std::string funcName = GetFunctionName(funcId); - arrow::Result<arrow::Datum> result; - result = CallFromCustomOrDefaultRegistry(funcId, arguments, ctx); - Y_VERIFY(result.ok()); - Y_VERIFY(result->is_scalar()); - return result->scalar(); -} - -arrow::Datum CallFunctionById(EOperation funcId, const std::vector<std::string>& args, const arrow::compute::FunctionOptions* funcOpts, - std::shared_ptr<TProgramStep::TDatumBatch> batch, arrow::compute::ExecContext* ctx) { +template <typename TOpId, typename TOptions> +arrow::Datum CallFunctionById(TOpId funcId, const std::vector<std::string>& args, + const TOptions* funcOpts, + std::shared_ptr<TProgramStep::TDatumBatch> batch, + arrow::compute::ExecContext* ctx) +{ std::vector<arrow::Datum> arguments; arguments.reserve(args.size()); @@ -269,15 +331,27 @@ arrow::Datum CallFunctionById(EOperation funcId, const std::vector<std::string>& } else { result = arrow::compute::CallFunction(GetFunctionName(funcId), arguments, funcOpts); } - Y_VERIFY(result.ok()); + Y_VERIFY_S(result.ok(), result.status().message()); return result.ValueOrDie(); } -arrow::Datum CallFunctionByAssign(const TAssign& assign, std::shared_ptr<TProgramStep::TDatumBatch> batch, arrow::compute::ExecContext* ctx) { +arrow::Datum CallFunctionByAssign(const TAssign& assign, + std::shared_ptr<TProgramStep::TDatumBatch> batch, + arrow::compute::ExecContext* ctx) +{ return CallFunctionById(assign.GetOperation(), assign.GetArguments(), assign.GetFunctionOptions(), batch, ctx); } -void TProgramStep::ApplyAssignes(std::shared_ptr<TProgramStep::TDatumBatch>& batch, arrow::compute::ExecContext* ctx) const { +arrow::Datum CallFunctionByAssign(const TAggregateAssign& assign, + std::shared_ptr<TProgramStep::TDatumBatch> batch, + arrow::compute::ExecContext* ctx) +{ + return CallFunctionById(assign.GetOperation(), assign.GetArguments(), &assign.GetAggregateOptions(), batch, ctx); +} + +void TProgramStep::ApplyAssignes(std::shared_ptr<TProgramStep::TDatumBatch>& batch, + arrow::compute::ExecContext* ctx) const +{ if (Assignes.empty()) { return; } @@ -296,6 +370,40 @@ void TProgramStep::ApplyAssignes(std::shared_ptr<TProgramStep::TDatumBatch>& bat //Y_VERIFY(batch->Validate().ok()); } +void TProgramStep::ApplyAggregates(std::shared_ptr<TDatumBatch>& batch, arrow::compute::ExecContext* ctx) const { + if (GroupBy.empty()) { + return; + } + + auto res = std::make_shared<TDatumBatch>(); + res->rows = 1; // TODO + res->datums.reserve(GroupBy.size()); + + arrow::FieldVector fields; + fields.reserve(GroupBy.size()); + + for (auto& assign : GroupBy) { + res->datums.push_back(CallFunctionByAssign(assign, batch, ctx)); + auto& column = res->datums.back(); + Y_VERIFY_S(column.is_scalar(), TStringBuilder() << "Aggregate result is not a scalar."); + + auto op = assign.GetOperation(); + if (op == EAggregate::Min) { + const auto& minMax = column.scalar_as<arrow::StructScalar>(); + column = minMax.value[0]; + } else if (op == EAggregate::Max) { + const auto& minMax = column.scalar_as<arrow::StructScalar>(); + column = minMax.value[1]; + } + + Y_VERIFY_S(column.type(), TStringBuilder() << "Aggregate result has no type."); + fields.emplace_back(std::make_shared<arrow::Field>(assign.GetName(), column.type())); + } + + res->fields = std::make_shared<arrow::Schema>(fields); + batch = res; +} + void TProgramStep::ApplyFilters(std::shared_ptr<TDatumBatch>& batch) const { if (Filters.empty()) { return; @@ -384,6 +492,7 @@ void TProgramStep::Apply(std::shared_ptr<arrow::RecordBatch>& batch, arrow::comp auto rb = ToTDatumBatch(batch); ApplyAssignes(rb, ctx); ApplyFilters(rb); + ApplyAggregates(rb, ctx); ApplyProjection(rb); batch = ToRecordBatch(rb); } diff --git a/ydb/core/formats/program.h b/ydb/core/formats/program.h index f4b7d466ab..163f82a5c4 100644 --- a/ydb/core/formats/program.h +++ b/ydb/core/formats/program.h @@ -1,5 +1,6 @@ #pragma once #include <contrib/libs/apache/arrow/cpp/src/arrow/compute/exec.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api_aggregate.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> #include <util/system/types.h> @@ -83,20 +84,32 @@ enum class EOperation { RoundToExp2 }; +enum class EAggregate { + Unspecified = 0, + Any = 1, + Count = 2, + Min = 3, + Max = 4, + Sum = 5, + Avg = 6, +}; + const char * GetFunctionName(EOperation op); +const char * GetFunctionName(EAggregate op); +EOperation ValidateOperation(EOperation op, ui32 argsSize); class TAssign { public: TAssign(const std::string& name, EOperation op, std::vector<std::string>&& args) : Name(name) - , Operation(op) + , Operation(ValidateOperation(op, args.size())) , Arguments(std::move(args)) , FuncOpts(nullptr) {} TAssign(const std::string& name, EOperation op, std::vector<std::string>&& args, std::shared_ptr<arrow::compute::FunctionOptions> funcOpts) : Name(name) - , Operation(op) + , Operation(ValidateOperation(op, args.size())) , Arguments(std::move(args)) , FuncOpts(funcOpts) {} @@ -165,6 +178,7 @@ public: {} bool IsConstant() const { return Operation == EOperation::Constant; } + bool IsOk() const { return Operation != EOperation::Unspecified; } EOperation GetOperation() const { return Operation; } const std::vector<std::string>& GetArguments() const { return Arguments; } std::shared_ptr<arrow::Scalar> GetConstant() const { return Constant; } @@ -179,6 +193,31 @@ private: std::shared_ptr<arrow::compute::FunctionOptions> FuncOpts; }; +class TAggregateAssign { +public: + TAggregateAssign(const std::string& name, EAggregate op, std::string&& arg) + : Name(name) + , Operation(op) + , Arguments({std::move(arg)}) + { + if (arg.empty()) { + op = EAggregate::Unspecified; + } + } + + bool IsOk() const { return Operation != EAggregate::Unspecified; } + EAggregate GetOperation() const { return Operation; } + const std::vector<std::string>& GetArguments() const { return Arguments; } + const std::string& GetName() const { return Name; } + const arrow::compute::ScalarAggregateOptions& GetAggregateOptions() const { return ScalarOpts; } + +private: + std::string Name; + EAggregate Operation{EAggregate::Unspecified}; + std::vector<std::string> Arguments; + arrow::compute::ScalarAggregateOptions ScalarOpts; // TODO: make correct options +}; + /// Group of commands that finishes with projection. Steps add locality for columns definition. /// /// In step we have non-decreasing count of columns (line to line) till projection. So columns are either source @@ -191,8 +230,8 @@ private: struct TProgramStep { std::vector<TAssign> Assignes; std::vector<std::string> Filters; // List of filter columns. Implicit "Filter by (f1 AND f2 AND .. AND fn)" + std::vector<TAggregateAssign> GroupBy; std::vector<std::string> Projection; // Step's result columns (remove others) - // TODO: group by struct TDatumBatch { std::shared_ptr<arrow::Schema> fields; @@ -207,6 +246,7 @@ struct TProgramStep { void Apply(std::shared_ptr<arrow::RecordBatch>& batch, arrow::compute::ExecContext* ctx) const; void ApplyAssignes(std::shared_ptr<TDatumBatch>& batch, arrow::compute::ExecContext* ctx) const; + void ApplyAggregates(std::shared_ptr<TDatumBatch>& batch, arrow::compute::ExecContext* ctx) const; void ApplyFilters(std::shared_ptr<TDatumBatch>& batch) const; void ApplyProjection(std::shared_ptr<arrow::RecordBatch>& batch) const; void ApplyProjection(std::shared_ptr<TDatumBatch>& batch) const; diff --git a/ydb/core/protos/ssa.proto b/ydb/core/protos/ssa.proto index 0b6774cf17..88eaeee8b4 100644 --- a/ydb/core/protos/ssa.proto +++ b/ydb/core/protos/ssa.proto @@ -100,10 +100,10 @@ message TProgram { AGG_UNSPECIFIED = 0; AGG_ANY = 1; AGG_COUNT = 2; - //AGG_MIN = 3; - //AGG_MAX = 4; - //AGG_SUM = 5; - //AGG_AVG = 6; + AGG_MIN = 3; + AGG_MAX = 4; + AGG_SUM = 5; + AGG_AVG = 6; //AGG_VAR = 7; //AGG_COVAR = 8; //AGG_STDDEV = 9; diff --git a/ydb/core/tx/columnshard/columnshard_common.cpp b/ydb/core/tx/columnshard/columnshard_common.cpp index 0662e5115f..ecacebdc80 100644 --- a/ydb/core/tx/columnshard/columnshard_common.cpp +++ b/ydb/core/tx/columnshard/columnshard_common.cpp @@ -39,13 +39,13 @@ TString FromCells(const TConstArrayRef<TCell>& cells, const TVector<std::pair<TS struct TContext { const IColumnResolver& ColumnResolver; - THashMap<ui32, TString> Sources; + mutable THashMap<ui32, TString> Sources; explicit TContext(const IColumnResolver& columnResolver) : ColumnResolver(columnResolver) {} - std::string GetName(ui32 columnId) { + std::string GetName(ui32 columnId) const { TString name = ColumnResolver.GetColumnName(columnId, false); if (name.Empty()) { name = ToString(columnId); @@ -56,14 +56,13 @@ struct TContext { } }; -NArrow::TAssign MakeFunction(TContext& info, const std::string& name, +NArrow::TAssign MakeFunction(const TContext& info, const std::string& name, const NKikimrSSA::TProgram::TAssignment::TFunction& func) { using TId = NKikimrSSA::TProgram::TAssignment; using EOperation = NArrow::EOperation; using TAssign = NArrow::TAssign; - auto& args = func.GetArguments(); - TVector<std::string> arguments; + std::vector<std::string> arguments; for (auto& col : func.GetArguments()) { ui32 columnId = col.GetId(); arguments.push_back(info.GetName(columnId)); @@ -71,66 +70,48 @@ NArrow::TAssign MakeFunction(TContext& info, const std::string& name, switch (func.GetId()) { case TId::FUNC_CMP_EQUAL: - Y_VERIFY(args.size() == 2); return TAssign(name, EOperation::Equal, std::move(arguments)); case TId::FUNC_CMP_NOT_EQUAL: - Y_VERIFY(args.size() == 2); return TAssign(name, EOperation::NotEqual, std::move(arguments)); case TId::FUNC_CMP_LESS: - Y_VERIFY(args.size() == 2); return TAssign(name, EOperation::Less, std::move(arguments)); case TId::FUNC_CMP_LESS_EQUAL: - Y_VERIFY(args.size() == 2); return TAssign(name, EOperation::LessEqual, std::move(arguments)); case TId::FUNC_CMP_GREATER: - Y_VERIFY(args.size() == 2); return TAssign(name, EOperation::Greater, std::move(arguments)); case TId::FUNC_CMP_GREATER_EQUAL: - Y_VERIFY(args.size() == 2); return TAssign(name, EOperation::GreaterEqual, std::move(arguments)); case TId::FUNC_IS_NULL: - Y_VERIFY(args.size() == 1); return TAssign(name, EOperation::IsNull, std::move(arguments)); case TId::FUNC_STR_LENGTH: - Y_VERIFY(args.size() == 1); return TAssign(name, EOperation::BinaryLength, std::move(arguments)); case TId::FUNC_STR_MATCH: - Y_VERIFY(args.size() == 2); return TAssign(name, EOperation::MatchSubstring, std::move(arguments)); case TId::FUNC_BINARY_NOT: - Y_VERIFY(args.size() == 1); return TAssign(name, EOperation::Invert, std::move(arguments)); case TId::FUNC_BINARY_AND: - Y_VERIFY(args.size() == 2); return TAssign(name, EOperation::And, std::move(arguments)); case TId::FUNC_BINARY_OR: - Y_VERIFY(args.size() == 2); return TAssign(name, EOperation::Or, std::move(arguments)); case TId::FUNC_BINARY_XOR: - Y_VERIFY(args.size() == 2); return TAssign(name, EOperation::Xor, std::move(arguments)); case TId::FUNC_MATH_ADD: - Y_VERIFY(args.size() == 2); return TAssign(name, EOperation::Add, std::move(arguments)); case TId::FUNC_MATH_SUBTRACT: - Y_VERIFY(args.size() == 2); return TAssign(name, EOperation::Subtract, std::move(arguments)); case TId::FUNC_MATH_MULTIPLY: - Y_VERIFY(args.size() == 2); return TAssign(name, EOperation::Multiply, std::move(arguments)); case TId::FUNC_MATH_DIVIDE: - Y_VERIFY(args.size() == 2); return TAssign(name, EOperation::Divide, std::move(arguments)); case TId::FUNC_CAST_TO_INT32: { - Y_VERIFY(args.size() == 1); // TODO: support CAST with OrDefault/OrNull logic (second argument is default value) + // TODO: support CAST with OrDefault/OrNull logic (second argument is default value) auto castOpts = std::make_shared<arrow::compute::CastOptions>(false); castOpts->to_type = std::make_shared<arrow::Int32Type>(); return TAssign(name, EOperation::CastInt32, std::move(arguments), castOpts); } case TId::FUNC_CAST_TO_TIMESTAMP: { - Y_VERIFY(args.size() == 1); auto castOpts = std::make_shared<arrow::compute::CastOptions>(false); castOpts->to_type = std::make_shared<arrow::TimestampType>(arrow::TimeUnit::MICRO); return TAssign(name, EOperation::CastTimestamp, std::move(arguments), castOpts); @@ -149,11 +130,12 @@ NArrow::TAssign MakeFunction(TContext& info, const std::string& name, case TId::FUNC_UNSPECIFIED: break; } - Y_VERIFY(false); // unexpected + return TAssign(name, EOperation::Unspecified, std::move(arguments)); } NArrow::TAssign MakeConstant(const std::string& name, const NKikimrSSA::TProgram::TConstant& constant) { using TId = NKikimrSSA::TProgram::TConstant; + using EOperation = NArrow::EOperation; using TAssign = NArrow::TAssign; switch (constant.GetValueCase()) { @@ -182,9 +164,40 @@ NArrow::TAssign MakeConstant(const std::string& name, const NKikimrSSA::TProgram return TAssign(name, std::string(str.data(), str.size())); } case TId::VALUE_NOT_SET: - Y_VERIFY(false); // unexpected break; } + return TAssign(name, EOperation::Unspecified, {}); +} + +NArrow::TAggregateAssign MakeAggregate(const TContext& info, const std::string& name, + const NKikimrSSA::TProgram::TAggregateAssignment::TAggregateFunction& func) +{ + using TId = NKikimrSSA::TProgram::TAggregateAssignment; + using EAggregate = NArrow::EAggregate; + using TAggregateAssign = NArrow::TAggregateAssign; + + if (func.ArgumentsSize() == 1) { + std::string argument = info.GetName(func.GetArguments()[0].GetId()); + + switch (func.GetId()) { + case TId::AGG_ANY: + return TAggregateAssign(name, EAggregate::Any, std::move(argument)); + case TId::AGG_COUNT: + return TAggregateAssign(name, EAggregate::Count, std::move(argument)); + case TId::AGG_MIN: + return TAggregateAssign(name, EAggregate::Min, std::move(argument)); + case TId::AGG_MAX: + return TAggregateAssign(name, EAggregate::Max, std::move(argument)); + case TId::AGG_SUM: + return TAggregateAssign(name, EAggregate::Sum, std::move(argument)); + case TId::AGG_AVG: + return TAggregateAssign(name, EAggregate::Avg, std::move(argument)); + + case TId::AGG_UNSPECIFIED: + break; + } + } + return TAggregateAssign(name, EAggregate::Unspecified, {}); } NArrow::TAssign MaterializeParameter(const std::string& name, const NKikimrSSA::TProgram::TParameter& parameter, @@ -194,7 +207,7 @@ NArrow::TAssign MaterializeParameter(const std::string& name, const NKikimrSSA:: auto parameterName = parameter.GetName(); auto column = parameterValues->GetColumnByName(parameterName); - +#if 0 Y_VERIFY( column, "No parameter %s in serialized parameters.", parameterName.c_str() @@ -203,11 +216,15 @@ NArrow::TAssign MaterializeParameter(const std::string& name, const NKikimrSSA:: column->length() == 1, "Incorrect values count in parameter array" ); - +#else + if (!column || column->length() != 1) { + return TAssign(name, NArrow::EOperation::Unspecified, {}); + } +#endif return TAssign(name, *column->GetScalar(0)); } -void ExtractAssign(TContext& info, NArrow::TProgramStep& step, const NKikimrSSA::TProgram::TAssignment& assign, +bool ExtractAssign(const TContext& info, NArrow::TProgramStep& step, const NKikimrSSA::TProgram::TAssignment& assign, const std::shared_ptr<arrow::RecordBatch>& parameterValues) { using TId = NKikimrSSA::TProgram::TAssignment; @@ -218,36 +235,87 @@ void ExtractAssign(TContext& info, NArrow::TProgramStep& step, const NKikimrSSA: switch (assign.GetExpressionCase()) { case TId::kFunction: { - step.Assignes.emplace_back(MakeFunction(info, columnName, assign.GetFunction())); + auto func = MakeFunction(info, columnName, assign.GetFunction()); + if (!func.IsOk()) { + return false; + } + step.Assignes.emplace_back(std::move(func)); break; } case TId::kConstant: { - step.Assignes.emplace_back(MakeConstant(columnName, assign.GetConstant())); + auto cnst = MakeConstant(columnName, assign.GetConstant()); + if (!cnst.IsConstant()) { + return false; + } + step.Assignes.emplace_back(std::move(cnst)); break; } case TId::kParameter: { - step.Assignes.emplace_back(MaterializeParameter(columnName, assign.GetParameter(), parameterValues)); + auto param = MaterializeParameter(columnName, assign.GetParameter(), parameterValues); + if (!param.IsConstant()) { + return false; + } + step.Assignes.emplace_back(std::move(param)); break; } case TId::kExternalFunction: case TId::kNull: case TId::EXPRESSION_NOT_SET: - Y_VERIFY(false); // not implemented - break; + return false; } + return true; } -void ExtractFilter(TContext& info, NArrow::TProgramStep& step, const NKikimrSSA::TProgram::TFilter& filter) { +bool ExtractFilter(const TContext& info, NArrow::TProgramStep& step, const NKikimrSSA::TProgram::TFilter& filter) { ui32 columnId = filter.GetPredicate().GetId(); + if (!columnId) { + return false; + } step.Filters.push_back(info.GetName(columnId)); + return true; } -void ExtractProjection(TContext& info, NArrow::TProgramStep& step, const NKikimrSSA::TProgram::TProjection& projection) { +bool ExtractProjection(const TContext& info, NArrow::TProgramStep& step, + const NKikimrSSA::TProgram::TProjection& projection) { + step.Projection.reserve(projection.ColumnsSize()); for (auto& col : projection.GetColumns()) { step.Projection.push_back(info.GetName(col.GetId())); } + return true; +} + +bool ExtractGroupBy(const TContext& info, NArrow::TProgramStep& step, const NKikimrSSA::TProgram::TGroupBy& groupBy) { + if (!groupBy.AggregatesSize()) { + return false; + } +#if 1 // TODO + if (groupBy.KeyColumnsSize()) { + return false; + } +#endif + + // It adds implicit projection with aggregates and keys. Remove non aggregated columns. + step.Projection.reserve(groupBy.KeyColumnsSize() + groupBy.AggregatesSize()); + for (auto& col : groupBy.GetKeyColumns()) { + step.Projection.push_back(info.GetName(col.GetId())); + } + + step.GroupBy.reserve(groupBy.AggregatesSize()); + for (auto& agg : groupBy.GetAggregates()) { + auto& resColumn = agg.GetColumn(); + TString columnName = ToString(resColumn.GetId()); + + auto func = MakeAggregate(info, columnName, agg.GetFunction()); + if (!func.IsOk()) { + return false; + } + step.GroupBy.push_back(std::move(func)); + step.Projection.push_back(columnName); + } + + return true; } } @@ -314,22 +382,31 @@ bool TReadDescription::AddProgram(const IColumnResolver& columnResolver, const N for (auto& cmd : program.GetCommand()) { switch (cmd.GetLineCase()) { case TId::kAssign: - ExtractAssign(info, *step, cmd.GetAssign(), ProgramParameters); + if (!ExtractAssign(info, *step, cmd.GetAssign(), ProgramParameters)) { + return false; + } break; case TId::kFilter: - ExtractFilter(info, *step, cmd.GetFilter()); + if (!ExtractFilter(info, *step, cmd.GetFilter())) { + return false; + } break; case TId::kProjection: - ExtractProjection(info, *step, cmd.GetProjection()); + if (!ExtractProjection(info, *step, cmd.GetProjection())) { + return false; + } Program.push_back(step); step = std::make_shared<NArrow::TProgramStep>(); break; case TId::kGroupBy: - // TODO - return false; // not implemented - case TId::LINE_NOT_SET: - Y_VERIFY(false); + if (!ExtractGroupBy(info, *step, cmd.GetGroupBy())) { + return false; + } + Program.push_back(step); + step = std::make_shared<NArrow::TProgramStep>(); break; + case TId::LINE_NOT_SET: + return false; } } diff --git a/ydb/core/tx/columnshard/ut_columnshard_read_write.cpp b/ydb/core/tx/columnshard/ut_columnshard_read_write.cpp index da93bf78de..db3e42f55f 100644 --- a/ydb/core/tx/columnshard/ut_columnshard_read_write.cpp +++ b/ydb/core/tx/columnshard/ut_columnshard_read_write.cpp @@ -883,10 +883,19 @@ static NKikimrSSA::TProgram MakeSelect(TAssignment::EFunction compareId = TAssig } // SELECT some(timestamp), some(saved_at) FROM t -NKikimrSSA::TProgram MakeSelectAggregates(TAggAssignment::EAggregateFunction aggId = TAggAssignment::AGG_ANY) { +// +// FIXME: +// NotImplemented: Function any has no kernel matching input types (array[timestamp[us]]) +// NotImplemented: Function any has no kernel matching input types (array[string]) +// NotImplemented: Function min_max has no kernel matching input types (array[timestamp[us]]) +// NotImplemented: Function min_max has no kernel matching input types (array[string]) +// +NKikimrSSA::TProgram MakeSelectAggregates(TAggAssignment::EAggregateFunction aggId = TAggAssignment::AGG_ANY, + std::vector<ui32> columnIds = {1, 9}) +{ NKikimrSSA::TProgram ssa; - std::vector<ui32> columnIds = {1, 9}; + ui32 tmpColumnId = 100; auto* line1 = ssa.AddCommand(); @@ -1070,7 +1079,7 @@ void TestReadAggregate(const TVector<std::pair<TString, TTypeId>>& ydbSchema = T std::vector<TString> programs; { - NKikimrSSA::TProgram ssa = MakeSelectAggregates(TAggAssignment::AGG_ANY); + NKikimrSSA::TProgram ssa = MakeSelectAggregates(TAggAssignment::AGG_COUNT); TString serialized; UNIT_ASSERT(ssa.SerializeToString(&serialized)); NKikimrSSA::TOlapProgram program; @@ -1081,7 +1090,7 @@ void TestReadAggregate(const TVector<std::pair<TString, TTypeId>>& ydbSchema = T } { - NKikimrSSA::TProgram ssa = MakeSelectAggregates(TAggAssignment::AGG_COUNT); + NKikimrSSA::TProgram ssa = MakeSelectAggregates(TAggAssignment::AGG_MIN, {5, 5}); TString serialized; UNIT_ASSERT(ssa.SerializeToString(&serialized)); NKikimrSSA::TOlapProgram program; @@ -1108,7 +1117,7 @@ void TestReadAggregate(const TVector<std::pair<TString, TTypeId>>& ydbSchema = T auto& resRead = Proto(result); UNIT_ASSERT_EQUAL(resRead.GetOrigin(), TTestTxConfig::TxTablet0); UNIT_ASSERT_EQUAL(resRead.GetTxInitiator(), metaShard); -#if 0 // TODO + { UNIT_ASSERT_EQUAL(resRead.GetStatus(), NKikimrTxColumnShard::EResultStatus::SUCCESS); UNIT_ASSERT_EQUAL(resRead.GetBatch(), 0); @@ -1123,9 +1132,7 @@ void TestReadAggregate(const TVector<std::pair<TString, TTypeId>>& ydbSchema = T UNIT_ASSERT(CheckColumns(readData[0], meta, {"100", "101"}, 1)); } -#else - UNIT_ASSERT_EQUAL(resRead.GetStatus(), NKikimrTxColumnShard::EResultStatus::ERROR); -#endif + ++i; } } |