aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArtem Zuikov <chertus@gmail.com>2022-06-01 11:37:03 +0300
committerArtem Zuikov <chertus@gmail.com>2022-06-01 11:37:03 +0300
commitc051815bfe2221ba1b1ae457868e7ca59fd58fca (patch)
treea1d835d8d90b9f4f73a2c2b358cf3a44922d635e
parent45799c2715cdd8b97f99eba8864d365389056b3d (diff)
downloadydb-c051815bfe2221ba1b1ae457868e7ca59fd58fca.tar.gz
KIKIMR-14822: call basic aggregates from arrow library in ColumnShard
ref:94ad8071bd9d77eadbc386cca245b6fcf0eb4854
-rw-r--r--ydb/core/formats/program.cpp219
-rw-r--r--ydb/core/formats/program.h46
-rw-r--r--ydb/core/protos/ssa.proto8
-rw-r--r--ydb/core/tx/columnshard/columnshard_common.cpp163
-rw-r--r--ydb/core/tx/columnshard/ut_columnshard_read_write.cpp23
5 files changed, 346 insertions, 113 deletions
diff --git a/ydb/core/formats/program.cpp b/ydb/core/formats/program.cpp
index f8d8287ab5..10534a6435 100644
--- a/ydb/core/formats/program.cpp
+++ b/ydb/core/formats/program.cpp
@@ -151,9 +151,114 @@ const char * GetFunctionName(EOperation op) {
return "";
}
-void AddColumn(std::shared_ptr<TProgramStep::TDatumBatch>& batch,
- std::string field_name,
- const arrow::Datum& column) {
+EOperation ValidateOperation(EOperation op, ui32 argsSize) {
+ switch (op) {
+ case EOperation::Equal:
+ case EOperation::NotEqual:
+ case EOperation::Less:
+ case EOperation::LessEqual:
+ case EOperation::Greater:
+ case EOperation::GreaterEqual:
+ case EOperation::MatchSubstring:
+ case EOperation::And:
+ case EOperation::Or:
+ case EOperation::Xor:
+ case EOperation::Add:
+ case EOperation::Subtract:
+ case EOperation::Multiply:
+ case EOperation::Divide:
+ case EOperation::Modulo:
+ case EOperation::AddNotNull:
+ case EOperation::SubtractNotNull:
+ case EOperation::MultiplyNotNull:
+ case EOperation::DivideNotNull:
+ case EOperation::ModuloOrZero:
+ case EOperation::Gcd:
+ case EOperation::Lcm:
+ if (argsSize == 2) {
+ return op;
+ }
+ break;
+
+ case EOperation::CastBoolean:
+ case EOperation::CastInt8:
+ case EOperation::CastInt16:
+ case EOperation::CastInt32:
+ case EOperation::CastInt64:
+ case EOperation::CastUInt8:
+ case EOperation::CastUInt16:
+ case EOperation::CastUInt32:
+ case EOperation::CastUInt64:
+ case EOperation::CastFloat:
+ case EOperation::CastDouble:
+ case EOperation::CastBinary:
+ case EOperation::CastFixedSizeBinary:
+ case EOperation::CastString:
+ case EOperation::CastTimestamp:
+ case EOperation::IsValid:
+ case EOperation::IsNull:
+ case EOperation::BinaryLength:
+ case EOperation::Invert:
+ case EOperation::Abs:
+ case EOperation::Negate:
+ if (argsSize == 1) {
+ return op;
+ }
+ break;
+
+ case EOperation::Acosh:
+ case EOperation::Atanh:
+ case EOperation::Cbrt:
+ case EOperation::Cosh:
+ case EOperation::E:
+ case EOperation::Erf:
+ case EOperation::Erfc:
+ case EOperation::Exp:
+ case EOperation::Exp2:
+ case EOperation::Exp10:
+ case EOperation::Hypot:
+ case EOperation::Lgamma:
+ case EOperation::Pi:
+ case EOperation::Sinh:
+ case EOperation::Sqrt:
+ case EOperation::Tgamma:
+ case EOperation::Floor:
+ case EOperation::Ceil:
+ case EOperation::Trunc:
+ case EOperation::Round:
+ case EOperation::RoundBankers:
+ case EOperation::RoundToExp2:
+ return op; // TODO: check
+
+ default:
+ break;
+ }
+ return EOperation::Unspecified;
+}
+
+const char * GetFunctionName(EAggregate op) {
+ switch (op) {
+ case EAggregate::Any:
+ return "any";
+ case EAggregate::Count:
+ return "count";
+ case EAggregate::Min:
+ return "min_max";
+ case EAggregate::Max:
+ return "min_max";
+ case EAggregate::Sum:
+ return "sum";
+ case EAggregate::Avg:
+ return "mean";
+
+ default:
+ break;
+ }
+ return "";
+}
+
+
+void AddColumn(std::shared_ptr<TProgramStep::TDatumBatch>& batch, std::string field_name, const arrow::Datum& column) {
auto field = ::arrow::field(std::move(field_name), column.type());
Y_VERIFY(field != nullptr);
Y_VERIFY(field->type()->Equals(column.type()));
@@ -204,55 +309,12 @@ std::shared_ptr<arrow::Array> MakeConstantColumn(const arrow::Scalar& value, int
return *res;
}
-//firstly try to call function from custom registry, if fails call from default
-arrow::Result<arrow::Datum> CallFromCustomOrDefaultRegistry(EOperation funcId, const std::vector<arrow::Datum>& arguments, arrow::compute::ExecContext* ctx) {
- std::string funcName = GetFunctionName(funcId);
- if (ctx != nullptr && ctx->func_registry()->GetFunction(funcName).ok()) {
- return arrow::compute::CallFunction(GetFunctionName(funcId), arguments, ctx);
- } else {
- return arrow::compute::CallFunction(GetFunctionName(funcId), arguments);
- }
-}
-
-std::shared_ptr<arrow::Array> CallArrayFunction(EOperation funcId, const std::vector<std::string>& args,
- std::shared_ptr<arrow::RecordBatch> batch, arrow::compute::ExecContext* ctx) {
- std::vector<arrow::Datum> arguments;
- arguments.reserve(args.size());
-
- for (auto& colName : args) {
- auto column = batch->GetColumnByName(colName);
- Y_VERIFY(column);
- arguments.push_back(arrow::Datum(*column));
- }
- std::string funcName = GetFunctionName(funcId);
- arrow::Result<arrow::Datum> result;
- result = CallFromCustomOrDefaultRegistry(funcId, arguments, ctx);
- Y_VERIFY(result.ok());
- Y_VERIFY(result->is_array());
- return result->make_array();
-}
-
-
-std::shared_ptr<arrow::Scalar> CallScalarFunction(EOperation funcId, const std::vector<std::string>& args,
- std::shared_ptr<arrow::RecordBatch> batch, arrow::compute::ExecContext* ctx) {
- std::vector<arrow::Datum> arguments;
- arguments.reserve(args.size());
-
- for (auto& colName : args) {
- auto column = batch->GetColumnByName(colName);
- Y_VERIFY(column);
- arguments.push_back(arrow::Datum{column});
- }
- std::string funcName = GetFunctionName(funcId);
- arrow::Result<arrow::Datum> result;
- result = CallFromCustomOrDefaultRegistry(funcId, arguments, ctx);
- Y_VERIFY(result.ok());
- Y_VERIFY(result->is_scalar());
- return result->scalar();
-}
-
-arrow::Datum CallFunctionById(EOperation funcId, const std::vector<std::string>& args, const arrow::compute::FunctionOptions* funcOpts,
- std::shared_ptr<TProgramStep::TDatumBatch> batch, arrow::compute::ExecContext* ctx) {
+template <typename TOpId, typename TOptions>
+arrow::Datum CallFunctionById(TOpId funcId, const std::vector<std::string>& args,
+ const TOptions* funcOpts,
+ std::shared_ptr<TProgramStep::TDatumBatch> batch,
+ arrow::compute::ExecContext* ctx)
+{
std::vector<arrow::Datum> arguments;
arguments.reserve(args.size());
@@ -269,15 +331,27 @@ arrow::Datum CallFunctionById(EOperation funcId, const std::vector<std::string>&
} else {
result = arrow::compute::CallFunction(GetFunctionName(funcId), arguments, funcOpts);
}
- Y_VERIFY(result.ok());
+ Y_VERIFY_S(result.ok(), result.status().message());
return result.ValueOrDie();
}
-arrow::Datum CallFunctionByAssign(const TAssign& assign, std::shared_ptr<TProgramStep::TDatumBatch> batch, arrow::compute::ExecContext* ctx) {
+arrow::Datum CallFunctionByAssign(const TAssign& assign,
+ std::shared_ptr<TProgramStep::TDatumBatch> batch,
+ arrow::compute::ExecContext* ctx)
+{
return CallFunctionById(assign.GetOperation(), assign.GetArguments(), assign.GetFunctionOptions(), batch, ctx);
}
-void TProgramStep::ApplyAssignes(std::shared_ptr<TProgramStep::TDatumBatch>& batch, arrow::compute::ExecContext* ctx) const {
+arrow::Datum CallFunctionByAssign(const TAggregateAssign& assign,
+ std::shared_ptr<TProgramStep::TDatumBatch> batch,
+ arrow::compute::ExecContext* ctx)
+{
+ return CallFunctionById(assign.GetOperation(), assign.GetArguments(), &assign.GetAggregateOptions(), batch, ctx);
+}
+
+void TProgramStep::ApplyAssignes(std::shared_ptr<TProgramStep::TDatumBatch>& batch,
+ arrow::compute::ExecContext* ctx) const
+{
if (Assignes.empty()) {
return;
}
@@ -296,6 +370,40 @@ void TProgramStep::ApplyAssignes(std::shared_ptr<TProgramStep::TDatumBatch>& bat
//Y_VERIFY(batch->Validate().ok());
}
+void TProgramStep::ApplyAggregates(std::shared_ptr<TDatumBatch>& batch, arrow::compute::ExecContext* ctx) const {
+ if (GroupBy.empty()) {
+ return;
+ }
+
+ auto res = std::make_shared<TDatumBatch>();
+ res->rows = 1; // TODO
+ res->datums.reserve(GroupBy.size());
+
+ arrow::FieldVector fields;
+ fields.reserve(GroupBy.size());
+
+ for (auto& assign : GroupBy) {
+ res->datums.push_back(CallFunctionByAssign(assign, batch, ctx));
+ auto& column = res->datums.back();
+ Y_VERIFY_S(column.is_scalar(), TStringBuilder() << "Aggregate result is not a scalar.");
+
+ auto op = assign.GetOperation();
+ if (op == EAggregate::Min) {
+ const auto& minMax = column.scalar_as<arrow::StructScalar>();
+ column = minMax.value[0];
+ } else if (op == EAggregate::Max) {
+ const auto& minMax = column.scalar_as<arrow::StructScalar>();
+ column = minMax.value[1];
+ }
+
+ Y_VERIFY_S(column.type(), TStringBuilder() << "Aggregate result has no type.");
+ fields.emplace_back(std::make_shared<arrow::Field>(assign.GetName(), column.type()));
+ }
+
+ res->fields = std::make_shared<arrow::Schema>(fields);
+ batch = res;
+}
+
void TProgramStep::ApplyFilters(std::shared_ptr<TDatumBatch>& batch) const {
if (Filters.empty()) {
return;
@@ -384,6 +492,7 @@ void TProgramStep::Apply(std::shared_ptr<arrow::RecordBatch>& batch, arrow::comp
auto rb = ToTDatumBatch(batch);
ApplyAssignes(rb, ctx);
ApplyFilters(rb);
+ ApplyAggregates(rb, ctx);
ApplyProjection(rb);
batch = ToRecordBatch(rb);
}
diff --git a/ydb/core/formats/program.h b/ydb/core/formats/program.h
index f4b7d466ab..163f82a5c4 100644
--- a/ydb/core/formats/program.h
+++ b/ydb/core/formats/program.h
@@ -1,5 +1,6 @@
#pragma once
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/exec.h>
+#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api_aggregate.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h>
#include <util/system/types.h>
@@ -83,20 +84,32 @@ enum class EOperation {
RoundToExp2
};
+enum class EAggregate {
+ Unspecified = 0,
+ Any = 1,
+ Count = 2,
+ Min = 3,
+ Max = 4,
+ Sum = 5,
+ Avg = 6,
+};
+
const char * GetFunctionName(EOperation op);
+const char * GetFunctionName(EAggregate op);
+EOperation ValidateOperation(EOperation op, ui32 argsSize);
class TAssign {
public:
TAssign(const std::string& name, EOperation op, std::vector<std::string>&& args)
: Name(name)
- , Operation(op)
+ , Operation(ValidateOperation(op, args.size()))
, Arguments(std::move(args))
, FuncOpts(nullptr)
{}
TAssign(const std::string& name, EOperation op, std::vector<std::string>&& args, std::shared_ptr<arrow::compute::FunctionOptions> funcOpts)
: Name(name)
- , Operation(op)
+ , Operation(ValidateOperation(op, args.size()))
, Arguments(std::move(args))
, FuncOpts(funcOpts)
{}
@@ -165,6 +178,7 @@ public:
{}
bool IsConstant() const { return Operation == EOperation::Constant; }
+ bool IsOk() const { return Operation != EOperation::Unspecified; }
EOperation GetOperation() const { return Operation; }
const std::vector<std::string>& GetArguments() const { return Arguments; }
std::shared_ptr<arrow::Scalar> GetConstant() const { return Constant; }
@@ -179,6 +193,31 @@ private:
std::shared_ptr<arrow::compute::FunctionOptions> FuncOpts;
};
+class TAggregateAssign {
+public:
+ TAggregateAssign(const std::string& name, EAggregate op, std::string&& arg)
+ : Name(name)
+ , Operation(op)
+ , Arguments({std::move(arg)})
+ {
+ if (arg.empty()) {
+ op = EAggregate::Unspecified;
+ }
+ }
+
+ bool IsOk() const { return Operation != EAggregate::Unspecified; }
+ EAggregate GetOperation() const { return Operation; }
+ const std::vector<std::string>& GetArguments() const { return Arguments; }
+ const std::string& GetName() const { return Name; }
+ const arrow::compute::ScalarAggregateOptions& GetAggregateOptions() const { return ScalarOpts; }
+
+private:
+ std::string Name;
+ EAggregate Operation{EAggregate::Unspecified};
+ std::vector<std::string> Arguments;
+ arrow::compute::ScalarAggregateOptions ScalarOpts; // TODO: make correct options
+};
+
/// Group of commands that finishes with projection. Steps add locality for columns definition.
///
/// In step we have non-decreasing count of columns (line to line) till projection. So columns are either source
@@ -191,8 +230,8 @@ private:
struct TProgramStep {
std::vector<TAssign> Assignes;
std::vector<std::string> Filters; // List of filter columns. Implicit "Filter by (f1 AND f2 AND .. AND fn)"
+ std::vector<TAggregateAssign> GroupBy;
std::vector<std::string> Projection; // Step's result columns (remove others)
- // TODO: group by
struct TDatumBatch {
std::shared_ptr<arrow::Schema> fields;
@@ -207,6 +246,7 @@ struct TProgramStep {
void Apply(std::shared_ptr<arrow::RecordBatch>& batch, arrow::compute::ExecContext* ctx) const;
void ApplyAssignes(std::shared_ptr<TDatumBatch>& batch, arrow::compute::ExecContext* ctx) const;
+ void ApplyAggregates(std::shared_ptr<TDatumBatch>& batch, arrow::compute::ExecContext* ctx) const;
void ApplyFilters(std::shared_ptr<TDatumBatch>& batch) const;
void ApplyProjection(std::shared_ptr<arrow::RecordBatch>& batch) const;
void ApplyProjection(std::shared_ptr<TDatumBatch>& batch) const;
diff --git a/ydb/core/protos/ssa.proto b/ydb/core/protos/ssa.proto
index 0b6774cf17..88eaeee8b4 100644
--- a/ydb/core/protos/ssa.proto
+++ b/ydb/core/protos/ssa.proto
@@ -100,10 +100,10 @@ message TProgram {
AGG_UNSPECIFIED = 0;
AGG_ANY = 1;
AGG_COUNT = 2;
- //AGG_MIN = 3;
- //AGG_MAX = 4;
- //AGG_SUM = 5;
- //AGG_AVG = 6;
+ AGG_MIN = 3;
+ AGG_MAX = 4;
+ AGG_SUM = 5;
+ AGG_AVG = 6;
//AGG_VAR = 7;
//AGG_COVAR = 8;
//AGG_STDDEV = 9;
diff --git a/ydb/core/tx/columnshard/columnshard_common.cpp b/ydb/core/tx/columnshard/columnshard_common.cpp
index 0662e5115f..ecacebdc80 100644
--- a/ydb/core/tx/columnshard/columnshard_common.cpp
+++ b/ydb/core/tx/columnshard/columnshard_common.cpp
@@ -39,13 +39,13 @@ TString FromCells(const TConstArrayRef<TCell>& cells, const TVector<std::pair<TS
struct TContext {
const IColumnResolver& ColumnResolver;
- THashMap<ui32, TString> Sources;
+ mutable THashMap<ui32, TString> Sources;
explicit TContext(const IColumnResolver& columnResolver)
: ColumnResolver(columnResolver)
{}
- std::string GetName(ui32 columnId) {
+ std::string GetName(ui32 columnId) const {
TString name = ColumnResolver.GetColumnName(columnId, false);
if (name.Empty()) {
name = ToString(columnId);
@@ -56,14 +56,13 @@ struct TContext {
}
};
-NArrow::TAssign MakeFunction(TContext& info, const std::string& name,
+NArrow::TAssign MakeFunction(const TContext& info, const std::string& name,
const NKikimrSSA::TProgram::TAssignment::TFunction& func) {
using TId = NKikimrSSA::TProgram::TAssignment;
using EOperation = NArrow::EOperation;
using TAssign = NArrow::TAssign;
- auto& args = func.GetArguments();
- TVector<std::string> arguments;
+ std::vector<std::string> arguments;
for (auto& col : func.GetArguments()) {
ui32 columnId = col.GetId();
arguments.push_back(info.GetName(columnId));
@@ -71,66 +70,48 @@ NArrow::TAssign MakeFunction(TContext& info, const std::string& name,
switch (func.GetId()) {
case TId::FUNC_CMP_EQUAL:
- Y_VERIFY(args.size() == 2);
return TAssign(name, EOperation::Equal, std::move(arguments));
case TId::FUNC_CMP_NOT_EQUAL:
- Y_VERIFY(args.size() == 2);
return TAssign(name, EOperation::NotEqual, std::move(arguments));
case TId::FUNC_CMP_LESS:
- Y_VERIFY(args.size() == 2);
return TAssign(name, EOperation::Less, std::move(arguments));
case TId::FUNC_CMP_LESS_EQUAL:
- Y_VERIFY(args.size() == 2);
return TAssign(name, EOperation::LessEqual, std::move(arguments));
case TId::FUNC_CMP_GREATER:
- Y_VERIFY(args.size() == 2);
return TAssign(name, EOperation::Greater, std::move(arguments));
case TId::FUNC_CMP_GREATER_EQUAL:
- Y_VERIFY(args.size() == 2);
return TAssign(name, EOperation::GreaterEqual, std::move(arguments));
case TId::FUNC_IS_NULL:
- Y_VERIFY(args.size() == 1);
return TAssign(name, EOperation::IsNull, std::move(arguments));
case TId::FUNC_STR_LENGTH:
- Y_VERIFY(args.size() == 1);
return TAssign(name, EOperation::BinaryLength, std::move(arguments));
case TId::FUNC_STR_MATCH:
- Y_VERIFY(args.size() == 2);
return TAssign(name, EOperation::MatchSubstring, std::move(arguments));
case TId::FUNC_BINARY_NOT:
- Y_VERIFY(args.size() == 1);
return TAssign(name, EOperation::Invert, std::move(arguments));
case TId::FUNC_BINARY_AND:
- Y_VERIFY(args.size() == 2);
return TAssign(name, EOperation::And, std::move(arguments));
case TId::FUNC_BINARY_OR:
- Y_VERIFY(args.size() == 2);
return TAssign(name, EOperation::Or, std::move(arguments));
case TId::FUNC_BINARY_XOR:
- Y_VERIFY(args.size() == 2);
return TAssign(name, EOperation::Xor, std::move(arguments));
case TId::FUNC_MATH_ADD:
- Y_VERIFY(args.size() == 2);
return TAssign(name, EOperation::Add, std::move(arguments));
case TId::FUNC_MATH_SUBTRACT:
- Y_VERIFY(args.size() == 2);
return TAssign(name, EOperation::Subtract, std::move(arguments));
case TId::FUNC_MATH_MULTIPLY:
- Y_VERIFY(args.size() == 2);
return TAssign(name, EOperation::Multiply, std::move(arguments));
case TId::FUNC_MATH_DIVIDE:
- Y_VERIFY(args.size() == 2);
return TAssign(name, EOperation::Divide, std::move(arguments));
case TId::FUNC_CAST_TO_INT32:
{
- Y_VERIFY(args.size() == 1); // TODO: support CAST with OrDefault/OrNull logic (second argument is default value)
+ // TODO: support CAST with OrDefault/OrNull logic (second argument is default value)
auto castOpts = std::make_shared<arrow::compute::CastOptions>(false);
castOpts->to_type = std::make_shared<arrow::Int32Type>();
return TAssign(name, EOperation::CastInt32, std::move(arguments), castOpts);
}
case TId::FUNC_CAST_TO_TIMESTAMP:
{
- Y_VERIFY(args.size() == 1);
auto castOpts = std::make_shared<arrow::compute::CastOptions>(false);
castOpts->to_type = std::make_shared<arrow::TimestampType>(arrow::TimeUnit::MICRO);
return TAssign(name, EOperation::CastTimestamp, std::move(arguments), castOpts);
@@ -149,11 +130,12 @@ NArrow::TAssign MakeFunction(TContext& info, const std::string& name,
case TId::FUNC_UNSPECIFIED:
break;
}
- Y_VERIFY(false); // unexpected
+ return TAssign(name, EOperation::Unspecified, std::move(arguments));
}
NArrow::TAssign MakeConstant(const std::string& name, const NKikimrSSA::TProgram::TConstant& constant) {
using TId = NKikimrSSA::TProgram::TConstant;
+ using EOperation = NArrow::EOperation;
using TAssign = NArrow::TAssign;
switch (constant.GetValueCase()) {
@@ -182,9 +164,40 @@ NArrow::TAssign MakeConstant(const std::string& name, const NKikimrSSA::TProgram
return TAssign(name, std::string(str.data(), str.size()));
}
case TId::VALUE_NOT_SET:
- Y_VERIFY(false); // unexpected
break;
}
+ return TAssign(name, EOperation::Unspecified, {});
+}
+
+NArrow::TAggregateAssign MakeAggregate(const TContext& info, const std::string& name,
+ const NKikimrSSA::TProgram::TAggregateAssignment::TAggregateFunction& func)
+{
+ using TId = NKikimrSSA::TProgram::TAggregateAssignment;
+ using EAggregate = NArrow::EAggregate;
+ using TAggregateAssign = NArrow::TAggregateAssign;
+
+ if (func.ArgumentsSize() == 1) {
+ std::string argument = info.GetName(func.GetArguments()[0].GetId());
+
+ switch (func.GetId()) {
+ case TId::AGG_ANY:
+ return TAggregateAssign(name, EAggregate::Any, std::move(argument));
+ case TId::AGG_COUNT:
+ return TAggregateAssign(name, EAggregate::Count, std::move(argument));
+ case TId::AGG_MIN:
+ return TAggregateAssign(name, EAggregate::Min, std::move(argument));
+ case TId::AGG_MAX:
+ return TAggregateAssign(name, EAggregate::Max, std::move(argument));
+ case TId::AGG_SUM:
+ return TAggregateAssign(name, EAggregate::Sum, std::move(argument));
+ case TId::AGG_AVG:
+ return TAggregateAssign(name, EAggregate::Avg, std::move(argument));
+
+ case TId::AGG_UNSPECIFIED:
+ break;
+ }
+ }
+ return TAggregateAssign(name, EAggregate::Unspecified, {});
}
NArrow::TAssign MaterializeParameter(const std::string& name, const NKikimrSSA::TProgram::TParameter& parameter,
@@ -194,7 +207,7 @@ NArrow::TAssign MaterializeParameter(const std::string& name, const NKikimrSSA::
auto parameterName = parameter.GetName();
auto column = parameterValues->GetColumnByName(parameterName);
-
+#if 0
Y_VERIFY(
column,
"No parameter %s in serialized parameters.", parameterName.c_str()
@@ -203,11 +216,15 @@ NArrow::TAssign MaterializeParameter(const std::string& name, const NKikimrSSA::
column->length() == 1,
"Incorrect values count in parameter array"
);
-
+#else
+ if (!column || column->length() != 1) {
+ return TAssign(name, NArrow::EOperation::Unspecified, {});
+ }
+#endif
return TAssign(name, *column->GetScalar(0));
}
-void ExtractAssign(TContext& info, NArrow::TProgramStep& step, const NKikimrSSA::TProgram::TAssignment& assign,
+bool ExtractAssign(const TContext& info, NArrow::TProgramStep& step, const NKikimrSSA::TProgram::TAssignment& assign,
const std::shared_ptr<arrow::RecordBatch>& parameterValues)
{
using TId = NKikimrSSA::TProgram::TAssignment;
@@ -218,36 +235,87 @@ void ExtractAssign(TContext& info, NArrow::TProgramStep& step, const NKikimrSSA:
switch (assign.GetExpressionCase()) {
case TId::kFunction:
{
- step.Assignes.emplace_back(MakeFunction(info, columnName, assign.GetFunction()));
+ auto func = MakeFunction(info, columnName, assign.GetFunction());
+ if (!func.IsOk()) {
+ return false;
+ }
+ step.Assignes.emplace_back(std::move(func));
break;
}
case TId::kConstant:
{
- step.Assignes.emplace_back(MakeConstant(columnName, assign.GetConstant()));
+ auto cnst = MakeConstant(columnName, assign.GetConstant());
+ if (!cnst.IsConstant()) {
+ return false;
+ }
+ step.Assignes.emplace_back(std::move(cnst));
break;
}
case TId::kParameter:
{
- step.Assignes.emplace_back(MaterializeParameter(columnName, assign.GetParameter(), parameterValues));
+ auto param = MaterializeParameter(columnName, assign.GetParameter(), parameterValues);
+ if (!param.IsConstant()) {
+ return false;
+ }
+ step.Assignes.emplace_back(std::move(param));
break;
}
case TId::kExternalFunction:
case TId::kNull:
case TId::EXPRESSION_NOT_SET:
- Y_VERIFY(false); // not implemented
- break;
+ return false;
}
+ return true;
}
-void ExtractFilter(TContext& info, NArrow::TProgramStep& step, const NKikimrSSA::TProgram::TFilter& filter) {
+bool ExtractFilter(const TContext& info, NArrow::TProgramStep& step, const NKikimrSSA::TProgram::TFilter& filter) {
ui32 columnId = filter.GetPredicate().GetId();
+ if (!columnId) {
+ return false;
+ }
step.Filters.push_back(info.GetName(columnId));
+ return true;
}
-void ExtractProjection(TContext& info, NArrow::TProgramStep& step, const NKikimrSSA::TProgram::TProjection& projection) {
+bool ExtractProjection(const TContext& info, NArrow::TProgramStep& step,
+ const NKikimrSSA::TProgram::TProjection& projection) {
+ step.Projection.reserve(projection.ColumnsSize());
for (auto& col : projection.GetColumns()) {
step.Projection.push_back(info.GetName(col.GetId()));
}
+ return true;
+}
+
+bool ExtractGroupBy(const TContext& info, NArrow::TProgramStep& step, const NKikimrSSA::TProgram::TGroupBy& groupBy) {
+ if (!groupBy.AggregatesSize()) {
+ return false;
+ }
+#if 1 // TODO
+ if (groupBy.KeyColumnsSize()) {
+ return false;
+ }
+#endif
+
+ // It adds implicit projection with aggregates and keys. Remove non aggregated columns.
+ step.Projection.reserve(groupBy.KeyColumnsSize() + groupBy.AggregatesSize());
+ for (auto& col : groupBy.GetKeyColumns()) {
+ step.Projection.push_back(info.GetName(col.GetId()));
+ }
+
+ step.GroupBy.reserve(groupBy.AggregatesSize());
+ for (auto& agg : groupBy.GetAggregates()) {
+ auto& resColumn = agg.GetColumn();
+ TString columnName = ToString(resColumn.GetId());
+
+ auto func = MakeAggregate(info, columnName, agg.GetFunction());
+ if (!func.IsOk()) {
+ return false;
+ }
+ step.GroupBy.push_back(std::move(func));
+ step.Projection.push_back(columnName);
+ }
+
+ return true;
}
}
@@ -314,22 +382,31 @@ bool TReadDescription::AddProgram(const IColumnResolver& columnResolver, const N
for (auto& cmd : program.GetCommand()) {
switch (cmd.GetLineCase()) {
case TId::kAssign:
- ExtractAssign(info, *step, cmd.GetAssign(), ProgramParameters);
+ if (!ExtractAssign(info, *step, cmd.GetAssign(), ProgramParameters)) {
+ return false;
+ }
break;
case TId::kFilter:
- ExtractFilter(info, *step, cmd.GetFilter());
+ if (!ExtractFilter(info, *step, cmd.GetFilter())) {
+ return false;
+ }
break;
case TId::kProjection:
- ExtractProjection(info, *step, cmd.GetProjection());
+ if (!ExtractProjection(info, *step, cmd.GetProjection())) {
+ return false;
+ }
Program.push_back(step);
step = std::make_shared<NArrow::TProgramStep>();
break;
case TId::kGroupBy:
- // TODO
- return false; // not implemented
- case TId::LINE_NOT_SET:
- Y_VERIFY(false);
+ if (!ExtractGroupBy(info, *step, cmd.GetGroupBy())) {
+ return false;
+ }
+ Program.push_back(step);
+ step = std::make_shared<NArrow::TProgramStep>();
break;
+ case TId::LINE_NOT_SET:
+ return false;
}
}
diff --git a/ydb/core/tx/columnshard/ut_columnshard_read_write.cpp b/ydb/core/tx/columnshard/ut_columnshard_read_write.cpp
index da93bf78de..db3e42f55f 100644
--- a/ydb/core/tx/columnshard/ut_columnshard_read_write.cpp
+++ b/ydb/core/tx/columnshard/ut_columnshard_read_write.cpp
@@ -883,10 +883,19 @@ static NKikimrSSA::TProgram MakeSelect(TAssignment::EFunction compareId = TAssig
}
// SELECT some(timestamp), some(saved_at) FROM t
-NKikimrSSA::TProgram MakeSelectAggregates(TAggAssignment::EAggregateFunction aggId = TAggAssignment::AGG_ANY) {
+//
+// FIXME:
+// NotImplemented: Function any has no kernel matching input types (array[timestamp[us]])
+// NotImplemented: Function any has no kernel matching input types (array[string])
+// NotImplemented: Function min_max has no kernel matching input types (array[timestamp[us]])
+// NotImplemented: Function min_max has no kernel matching input types (array[string])
+//
+NKikimrSSA::TProgram MakeSelectAggregates(TAggAssignment::EAggregateFunction aggId = TAggAssignment::AGG_ANY,
+ std::vector<ui32> columnIds = {1, 9})
+{
NKikimrSSA::TProgram ssa;
- std::vector<ui32> columnIds = {1, 9};
+
ui32 tmpColumnId = 100;
auto* line1 = ssa.AddCommand();
@@ -1070,7 +1079,7 @@ void TestReadAggregate(const TVector<std::pair<TString, TTypeId>>& ydbSchema = T
std::vector<TString> programs;
{
- NKikimrSSA::TProgram ssa = MakeSelectAggregates(TAggAssignment::AGG_ANY);
+ NKikimrSSA::TProgram ssa = MakeSelectAggregates(TAggAssignment::AGG_COUNT);
TString serialized;
UNIT_ASSERT(ssa.SerializeToString(&serialized));
NKikimrSSA::TOlapProgram program;
@@ -1081,7 +1090,7 @@ void TestReadAggregate(const TVector<std::pair<TString, TTypeId>>& ydbSchema = T
}
{
- NKikimrSSA::TProgram ssa = MakeSelectAggregates(TAggAssignment::AGG_COUNT);
+ NKikimrSSA::TProgram ssa = MakeSelectAggregates(TAggAssignment::AGG_MIN, {5, 5});
TString serialized;
UNIT_ASSERT(ssa.SerializeToString(&serialized));
NKikimrSSA::TOlapProgram program;
@@ -1108,7 +1117,7 @@ void TestReadAggregate(const TVector<std::pair<TString, TTypeId>>& ydbSchema = T
auto& resRead = Proto(result);
UNIT_ASSERT_EQUAL(resRead.GetOrigin(), TTestTxConfig::TxTablet0);
UNIT_ASSERT_EQUAL(resRead.GetTxInitiator(), metaShard);
-#if 0 // TODO
+
{
UNIT_ASSERT_EQUAL(resRead.GetStatus(), NKikimrTxColumnShard::EResultStatus::SUCCESS);
UNIT_ASSERT_EQUAL(resRead.GetBatch(), 0);
@@ -1123,9 +1132,7 @@ void TestReadAggregate(const TVector<std::pair<TString, TTypeId>>& ydbSchema = T
UNIT_ASSERT(CheckColumns(readData[0], meta, {"100", "101"}, 1));
}
-#else
- UNIT_ASSERT_EQUAL(resRead.GetStatus(), NKikimrTxColumnShard::EResultStatus::ERROR);
-#endif
+
++i;
}
}