diff options
author | aidarsamer <aidarsamer@ydb.tech> | 2022-07-21 19:26:18 +0300 |
---|---|---|
committer | aidarsamer <aidarsamer@ydb.tech> | 2022-07-21 19:26:18 +0300 |
commit | fc5f9cf5dc8228e32562168d4608e200b21371b2 (patch) | |
tree | e5e5328a9607b4eafbac8676ed5afa8cd5e0e3dc | |
parent | b839916a3502d942ce4844416a56f1e0c8856119 (diff) | |
download | ydb-fc5f9cf5dc8228e32562168d4608e200b21371b2.tar.gz |
Aggregate pushdown implementation in KQP
agg functions pushdown for OLAP
Implementation in KQP
34 files changed, 875 insertions, 131 deletions
diff --git a/ydb/core/formats/CMakeLists.txt b/ydb/core/formats/CMakeLists.txt index 993180c9e4..72e015389a 100644 --- a/ydb/core/formats/CMakeLists.txt +++ b/ydb/core/formats/CMakeLists.txt @@ -27,4 +27,5 @@ target_sources(ydb-core-formats PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/formats/func_cast.cpp ${CMAKE_SOURCE_DIR}/ydb/core/formats/merging_sorted_input_stream.cpp ${CMAKE_SOURCE_DIR}/ydb/core/formats/program.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/formats/ssa_program_optimizer.cpp ) diff --git a/ydb/core/formats/program.h b/ydb/core/formats/program.h index 163f82a5c4..800ec9bb42 100644 --- a/ydb/core/formats/program.h +++ b/ydb/core/formats/program.h @@ -4,6 +4,8 @@ #include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> #include <util/system/types.h> +#include <ydb/core/scheme_types/scheme_types_defs.h> + namespace NKikimr::NArrow { enum class EOperation { @@ -200,7 +202,8 @@ public: , Operation(op) , Arguments({std::move(arg)}) { - if (arg.empty()) { + if (arg.empty() && op != EAggregate::Count) { + // COUNT(*) doesn't have arguments op = EAggregate::Unspecified; } } @@ -260,4 +263,9 @@ inline void ApplyProgram(std::shared_ptr<arrow::RecordBatch>& batch, } } +struct TSsaProgramSteps { + std::vector<std::shared_ptr<TProgramStep>> Program; + THashMap<ui32, TString> ProgramSourceColumns; +}; + } diff --git a/ydb/core/formats/ssa_program_optimizer.cpp b/ydb/core/formats/ssa_program_optimizer.cpp new file mode 100644 index 0000000000..400fe09f18 --- /dev/null +++ b/ydb/core/formats/ssa_program_optimizer.cpp @@ -0,0 +1,36 @@ +#include "ssa_program_optimizer.h" + +namespace NKikimr::NSsaOptimizer { + +namespace { + +std::vector<std::shared_ptr<NArrow::TProgramStep>> ReplaceCountAll(const std::vector<std::shared_ptr<NArrow::TProgramStep>>& program, const NTable::TScheme::TTableSchema& tableSchema) { + std::vector<std::shared_ptr<NArrow::TProgramStep>> newProgram; + newProgram.reserve(program.size()); + + for (auto& step : program) { + 
newProgram.push_back(std::make_shared<NArrow::TProgramStep>(*step)); + for (size_t i = 0; i < step->GroupBy.size(); ++i) { + auto groupBy = step->GroupBy[i]; + auto& newGroupBy = newProgram.back()->GroupBy; + if (groupBy.GetOperation() == NArrow::EAggregate::Count + && groupBy.GetArguments().size() == 1 + && groupBy.GetArguments()[0] == "") { + // COUNT(*) + auto replaceIt = std::next(newGroupBy.begin(), i); + replaceIt = newGroupBy.erase(replaceIt); + auto pkColName = tableSchema.Columns.find(tableSchema.KeyColumns[0])->second.Name; + newGroupBy.emplace(replaceIt, groupBy.GetName(), groupBy.GetOperation(), std::string(pkColName)); + } + } + } + return newProgram; +} + +} // anonymous namespace + +std::vector<std::shared_ptr<NArrow::TProgramStep>> OptimizeProgram(const std::vector<std::shared_ptr<NArrow::TProgramStep>>& program, const NTable::TScheme::TTableSchema& tableSchema) { + return ReplaceCountAll(program, tableSchema); +} + +} // namespace NKikimr::NSsaOptimizer diff --git a/ydb/core/formats/ssa_program_optimizer.h b/ydb/core/formats/ssa_program_optimizer.h new file mode 100644 index 0000000000..12c48dd1db --- /dev/null +++ b/ydb/core/formats/ssa_program_optimizer.h @@ -0,0 +1,12 @@ +#pragma once + +#include "program.h" + +#include <ydb/core/protos/ssa.pb.h> +#include <ydb/core/tablet_flat/flat_dbase_scheme.h> + +namespace NKikimr::NSsaOptimizer { + +std::vector<std::shared_ptr<NArrow::TProgramStep>> OptimizeProgram(const std::vector<std::shared_ptr<NArrow::TProgramStep>>& program, const NTable::TScheme::TTableSchema& tableSchema); + +} diff --git a/ydb/core/kqp/compile/kqp_olap_compiler.cpp b/ydb/core/kqp/compile/kqp_olap_compiler.cpp index f880273706..7f5cd49122 100644 --- a/ydb/core/kqp/compile/kqp_olap_compiler.cpp +++ b/ydb/core/kqp/compile/kqp_olap_compiler.cpp @@ -9,6 +9,8 @@ using namespace NYql; using namespace NYql::NNodes; using namespace NKikimrSSA; +using EAggFunctionType = TProgram::TAggregateAssignment::EAggregateFunction; + constexpr ui32 
OLAP_PROGRAM_VERSION = 1; namespace { @@ -56,6 +58,14 @@ public: return Program.AddCommand()->MutableFilter(); } + TProgram::TGroupBy* CreateGroupBy() { + return Program.AddCommand()->MutableGroupBy(); + } + + TProgram::TProjection* CreateProjection() { + return Program.AddCommand()->MutableProjection(); + } + void AddParameterName(const TString& name) { ReadProto.AddOlapProgramParameterNames(name); } @@ -67,7 +77,14 @@ public: ReadProto.SetOlapProgram(programBytes); } + EAggFunctionType GetAggFuncType(const std::string& funcName) const { + YQL_ENSURE(AggFuncTypesMap.find(funcName) != AggFuncTypesMap.end()); + return AggFuncTypesMap.at(funcName); + } + private: + static std::unordered_map<std::string, EAggFunctionType> AggFuncTypesMap; + TCoArgument Row; TMap<TString, ui32> ReadColumns; ui32 MaxColumnId; @@ -75,6 +92,10 @@ private: NKqpProto::TKqpPhyOpReadOlapRanges& ReadProto; }; +std::unordered_map<std::string, EAggFunctionType> TKqpOlapCompileContext::AggFuncTypesMap = { + { "count", TProgram::TAggregateAssignment::AGG_COUNT }, + { "some", TProgram::TAggregateAssignment::AGG_ANY }, +}; TProgram::TAssignment* CompileCondition(const TExprBase& condition, TKqpOlapCompileContext& ctx); ui64 GetOrCreateColumnId(const TExprBase& node, TKqpOlapCompileContext& ctx); @@ -348,15 +369,58 @@ void CompileFilter(const TKqpOlapFilter& filterNode, TKqpOlapCompileContext& ctx filter->MutablePredicate()->SetId(condition->GetColumn().GetId()); } +void CompileAggregates(const TKqpOlapAgg& aggNode, TKqpOlapCompileContext& ctx) { + auto* groupBy = ctx.CreateGroupBy(); + auto* projection = ctx.CreateProjection(); + + for (auto keyCol : aggNode.KeyColumns()) { + auto aggKeyCol = groupBy->AddKeyColumns(); + auto keyColName = keyCol.StringValue(); + auto aggKeyColId = GetOrCreateColumnId(keyCol, ctx); + aggKeyCol->SetId(aggKeyColId); + aggKeyCol->SetName(keyColName); + + auto* projCol = projection->AddColumns(); + projCol->SetId(aggKeyColId); + projCol->SetName(keyColName); + } + + for 
(auto aggIt : aggNode.Aggregates()) { + auto aggKqp = aggIt.Cast<TKqpOlapAggOperation>(); + std::string aggColName = aggKqp.Name().StringValue().c_str(); + + auto* agg = groupBy->AddAggregates(); + auto aggColId = ctx.NewColumnId(); + auto* aggCol = agg->MutableColumn(); + aggCol->SetId(aggColId); + aggCol->SetName(aggColName.c_str()); + auto* projCol = projection->AddColumns(); + projCol->SetId(aggColId); + projCol->SetName(aggColName.c_str()); + + auto* aggFunc = agg->MutableFunction(); + aggFunc->SetId(ctx.GetAggFuncType(aggKqp.Type().StringValue().c_str())); + + if (aggKqp.Column() != "*") { + aggFunc->AddArguments()->SetId(GetOrCreateColumnId(aggKqp.Column(), ctx)); + } + } +} + void CompileOlapProgramImpl(TExprBase operation, TKqpOlapCompileContext& ctx) { if (operation.Raw() == ctx.GetRowExpr()) { return; } - if (auto maybeFilter = operation.Maybe<TKqpOlapFilter>()) { - CompileOlapProgramImpl(maybeFilter.Cast().Input(), ctx); - CompileFilter(maybeFilter.Cast(), ctx); - return; + if (auto maybeOlapOperation = operation.Maybe<TKqpOlapOperationBase>()) { + CompileOlapProgramImpl(maybeOlapOperation.Cast().Input(), ctx); + if (auto maybeFilter = operation.Maybe<TKqpOlapFilter>()) { + CompileFilter(maybeFilter.Cast(), ctx); + return; + } else if (auto maybeAgg = operation.Maybe<TKqpOlapAgg>()) { + CompileAggregates(maybeAgg.Cast(), ctx); + return; + } } YQL_ENSURE(operation.Maybe<TCallable>(), "Unexpected OLAP operation node type: " << operation.Ref().Type()); diff --git a/ydb/core/kqp/expr_nodes/kqp_expr_nodes.json b/ydb/core/kqp/expr_nodes/kqp_expr_nodes.json index 66d86ab9f1..6c3ebee989 100644 --- a/ydb/core/kqp/expr_nodes/kqp_expr_nodes.json +++ b/ydb/core/kqp/expr_nodes/kqp_expr_nodes.json @@ -442,6 +442,29 @@ ] }, { + "Name": "TKqpOlapAggOperation", + "Base": "TExprBase", + "Match": {"Type": "Tuple"}, + "Children": [ + {"Index": 0, "Name": "Name", "Type": "TCoAtom"}, + {"Index": 1, "Name": "Type", "Type": "TCoAtom"}, + {"Index": 2, "Name": "Column", "Type": 
"TCoAtom"} + ] + }, + { + "Name": "TKqpOlapAggOperationList", + "ListBase": "TKqpOlapAggOperation" + }, + { + "Name": "TKqpOlapAgg", + "Base": "TKqpOlapOperationBase", + "Match": {"Type": "Callable", "Name": "TKqpOlapAgg"}, + "Children": [ + {"Index": 1, "Name": "Aggregates", "Type": "TKqpOlapAggOperationList"}, + {"Index": 2, "Name": "KeyColumns", "Type": "TCoAtomList"} + ] + }, + { "Name": "TKqpEnsure", "Base": "TCallable", "Match": {"Type": "Callable", "Name": "KqpEnsure"}, diff --git a/ydb/core/kqp/opt/peephole/kqp_opt_peephole.cpp b/ydb/core/kqp/opt/peephole/kqp_opt_peephole.cpp index 55aca64df8..f44dbefc29 100644 --- a/ydb/core/kqp/opt/peephole/kqp_opt_peephole.cpp +++ b/ydb/core/kqp/opt/peephole/kqp_opt_peephole.cpp @@ -37,6 +37,7 @@ public: AddHandler(0, &TDqPhyJoinDict::Match, HNDL(RewriteDictJoin)); AddHandler(0, &TDqJoin::Match, HNDL(RewritePureJoin)); AddHandler(0, TOptimizeTransformerBase::Any(), HNDL(BuildWideReadTable)); + AddHandler(0, &TDqPhyLength::Match, HNDL(RewriteLength)); #undef HNDL } @@ -76,6 +77,12 @@ protected: DumpAppliedRule("BuildWideReadTable", node.Ptr(), output.Ptr(), ctx); return output; } + + TMaybeNode<TExprBase> RewriteLength(TExprBase node, TExprContext& ctx) { + TExprBase output = DqPeepholeRewriteLength(node, ctx); + DumpAppliedRule("RewriteLength", node.Ptr(), output.Ptr(), ctx); + return output; + } }; struct TKqpPeepholePipelineConfigurator : IPipelineConfigurator { diff --git a/ydb/core/kqp/opt/physical/CMakeLists.txt b/ydb/core/kqp/opt/physical/CMakeLists.txt index 6e5a345794..c0c3c5285f 100644 --- a/ydb/core/kqp/opt/physical/CMakeLists.txt +++ b/ydb/core/kqp/opt/physical/CMakeLists.txt @@ -22,6 +22,7 @@ target_link_libraries(kqp-opt-physical PUBLIC target_sources(kqp-opt-physical PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/physical/kqp_opt_phy_build_stage.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/physical/kqp_opt_phy_limit.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/physical/kqp_opt_phy_olap_agg.cpp 
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/physical/kqp_opt_phy_olap_filter.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/physical/kqp_opt_phy_precompute.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/physical/kqp_opt_phy_sort.cpp diff --git a/ydb/core/kqp/opt/physical/kqp_opt_phy.cpp b/ydb/core/kqp/opt/physical/kqp_opt_phy.cpp index c57f9cc6f2..6699dc3298 100644 --- a/ydb/core/kqp/opt/physical/kqp_opt_phy.cpp +++ b/ydb/core/kqp/opt/physical/kqp_opt_phy.cpp @@ -34,6 +34,8 @@ public: HNDL(RemoveRedundantSortByPk)); AddHandler(0, &TCoTake::Match, HNDL(ApplyLimitToReadTable)); AddHandler(0, &TCoFlatMap::Match, HNDL(PushOlapFilter)); + AddHandler(0, &TCoCombineByKey::Match, HNDL(PushOlapAggregate)); + AddHandler(0, &TDqPhyLength::Match, HNDL(PushOlapLength)); AddHandler(0, &TCoSkipNullMembers::Match, HNDL(PushSkipNullMembersToStage<false>)); AddHandler(0, &TCoExtractMembers::Match, HNDL(PushExtractMembersToStage<false>)); AddHandler(0, &TCoFlatMapBase::Match, HNDL(BuildFlatmapStage<false>)); @@ -134,6 +136,18 @@ protected: return output; } + TMaybeNode<TExprBase> PushOlapAggregate(TExprBase node, TExprContext& ctx) { + TExprBase output = KqpPushOlapAggregate(node, ctx, KqpCtx); + DumpAppliedRule("PushOlapAggregate", node.Ptr(), output.Ptr(), ctx); + return output; + } + + TMaybeNode<TExprBase> PushOlapLength(TExprBase node, TExprContext& ctx) { + TExprBase output = KqpPushOlapLength(node, ctx, KqpCtx); + DumpAppliedRule("PushOlapLength", node.Ptr(), output.Ptr(), ctx); + return output; + } + template <bool IsGlobal> TMaybeNode<TExprBase> PushSkipNullMembersToStage(TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx, const TGetParents& getParents) diff --git a/ydb/core/kqp/opt/physical/kqp_opt_phy_olap_agg.cpp b/ydb/core/kqp/opt/physical/kqp_opt_phy_olap_agg.cpp new file mode 100644 index 0000000000..6443ad1427 --- /dev/null +++ b/ydb/core/kqp/opt/physical/kqp_opt_phy_olap_agg.cpp @@ -0,0 +1,276 @@ +#include "kqp_opt_phy_rules.h" + +#include <ydb/core/kqp/common/kqp_yql.h> + 
+#include <ydb/library/yql/core/yql_opt_utils.h> + +#include <vector> + +namespace NKikimr::NKqp::NOpt { + +using namespace NYql; +using namespace NYql::NNodes; + +namespace { + +enum class EAggType { + Count, + Some +}; + +struct TAggInfo { + std::string AggName; + std::string ColName; + EAggType Type; +}; + +bool ContainsConstOnly(const TExprBase& node) { + return node.Maybe<TCoDataCtor>().IsValid(); +} + +bool ContainsSimpleColumnOnly(const TExprBase& node, const TExprBase& parent) { + if (!parent.Maybe<TCoInputBase>()) { + return false; + } + auto input = parent.Cast<TCoInputBase>().Input(); + if (auto maybeExprList = node.Maybe<TExprList>()) { + for (auto expr : maybeExprList.Cast()) { + if (!expr.Maybe<TCoMember>() || expr.Cast<TCoMember>().Struct().Raw() != input.Raw()) { + return false; + } + } + return true; + } + return node.Maybe<TCoMember>().IsValid() && node.Cast<TCoMember>().Struct().Raw() == input.Raw(); +} + +std::vector<std::string> GetGroupByCols(const TExprBase& keySelectorBody, const TExprBase& parent) { + std::vector<std::string> res; + if (!ContainsSimpleColumnOnly(keySelectorBody, parent)) { + YQL_CLOG(DEBUG, ProviderKqp) << "For aggregate push down optimization in GROUP BY column list should be Member callables only."; + return res; + } + if (auto maybeMember = keySelectorBody.Maybe<TCoMember>()) { + res.push_back(keySelectorBody.Cast<TCoMember>().Name().StringValue()); + } else if (auto maybeExprList = keySelectorBody.Maybe<TExprList>()) { + for (auto expr : maybeExprList.Cast()) { + res.push_back(expr.Cast<TCoMember>().Name().StringValue()); + } + } + return res; +} + +std::vector<TAggInfo> GetAggregationsFromInit(const TExprBase& node) { + std::vector<TAggInfo> res; + if (!node.Maybe<TCoAsStruct>()) { + return res; + } + for (auto item : node.Cast<TCoAsStruct>()) { + auto tuple = item.Cast<TCoNameValueTuple>(); + auto tupleValue = tuple.Value(); + if (tupleValue.Maybe<TCoAggrCountInit>()) { + auto aggrCntInit = 
tupleValue.Cast<TCoAggrCountInit>(); + if (aggrCntInit.Value().Maybe<TCoMember>()) { + TAggInfo aggInfo; + aggInfo.AggName = tuple.Name(); + aggInfo.Type = EAggType::Count; + aggInfo.ColName = aggrCntInit.Value().Cast<TCoMember>().Name(); + res.push_back(aggInfo); + } + } else { + YQL_CLOG(DEBUG, ProviderKqp) << "Unsupported aggregation type in init handler."; + res.clear(); + return res; + } + } + return res; +} + +std::vector<TAggInfo> GetAggregationsFromUpdate(const TExprBase& node) { + std::vector<TAggInfo> res; + if (!node.Maybe<TCoAsStruct>()) { + return res; + } + for (auto item : node.Cast<TCoAsStruct>()) { + auto tuple = item.Cast<TCoNameValueTuple>(); + auto tupleValue = tuple.Value(); + if (auto maybeAggrCntUpd = tupleValue.Maybe<TCoAggrCountUpdate>()) { + if (maybeAggrCntUpd.Cast().Value().Maybe<TCoMember>()) { + TAggInfo aggInfo; + aggInfo.Type = EAggType::Count; + aggInfo.AggName = tuple.Name(); + res.push_back(aggInfo); + } + } else { + YQL_CLOG(DEBUG, ProviderKqp) << "Unsupported aggregation type in update handler."; + res.clear(); + return res; + } + } + return res; +} + +} // anonymous namespace end + +TExprBase KqpPushOlapAggregate(TExprBase node, TExprContext& ctx, const TKqpOptimizeContext& kqpCtx) +{ + if (!kqpCtx.Config->PushOlapProcess()) { + return node; + } + + if (!node.Maybe<TCoCombineByKey>().Input().Maybe<TKqpReadOlapTableRanges>()) { + return node; + } + + auto combineKey = node.Cast<TCoCombineByKey>(); + auto read = combineKey.Input().Cast<TKqpReadOlapTableRanges>(); + + if (read.Process().Body().Raw() != read.Process().Args().Arg(0).Raw()) { + return node; + } + + auto keySelectorBody = combineKey.KeySelectorLambda().Cast<TCoLambda>().Body(); + if (!ContainsSimpleColumnOnly(keySelectorBody, combineKey) && !ContainsConstOnly(keySelectorBody)) { + return node; + } + auto aggKeyCols = Build<TCoAtomList>(ctx, node.Pos()); + auto groupByCols = GetGroupByCols(keySelectorBody, combineKey); + for (auto groupByCol : groupByCols) { + 
aggKeyCols.Add<TCoAtom>() + .Build(groupByCol) + .Done(); + } + + auto initHandlerBody = combineKey.InitHandlerLambda().Cast<TCoLambda>().Body(); + auto aggInits = GetAggregationsFromInit(initHandlerBody); + + auto updateHandlerBody = combineKey.UpdateHandlerLambda().Cast<TCoLambda>().Body(); + auto aggUpdates = GetAggregationsFromUpdate(updateHandlerBody); + + auto finishHandlerBody = combineKey.FinishHandlerLambda().Cast<TCoLambda>().Body(); + if (aggInits.empty() || aggInits.size() != aggUpdates.size()) { + return node; + } + + for (size_t i = 0; i != aggInits.size(); ++i) { + if (aggInits[i].Type != aggUpdates[i].Type) { + YQL_CLOG(DEBUG, ProviderKqp) << "Different aggregation type in init and update handlers in aggregate push-down optimization!"; + return node; + } + } + + auto aggs = Build<TKqpOlapAggOperationList>(ctx, node.Pos()); + // TODO: TMaybeNode<TKqpOlapAggOperation>; + for (size_t i = 0; i != aggInits.size(); ++i) { + std::string aggType; + switch (aggInits[i].Type) { + case EAggType::Count: + { + aggType = "count"; + break; + } + case EAggType::Some: + { + aggType = "some"; + break; + } + default: + { + YQL_ENSURE(false, "Unsupported type of aggregation!"); // add aggInits[i].Type + return node; + } + } + aggs.Add<TKqpOlapAggOperation>() + .Name().Build(aggInits[i].AggName) + .Type().Build(aggType) + .Column().Build(aggInits[i].ColName) + .Build() + .Done(); + } + + auto olapAgg = Build<TKqpOlapAgg>(ctx, node.Pos()) + .Input(read.Process().Body()) + .Aggregates(std::move(aggs.Done())) + .KeyColumns(std::move(aggKeyCols.Done())) + .Done(); + + auto newProcessLambda = Build<TCoLambda>(ctx, node.Pos()) + .Args({"row"}) + .Body<TExprApplier>() + .Apply(olapAgg) + .With(read.Process().Args().Arg(0), "row") + .Build() + .Done(); + + YQL_CLOG(INFO, ProviderKqp) << "Pushed OLAP lambda: " << KqpExprToPrettyString(newProcessLambda, ctx); + + auto newRead = Build<TKqpReadOlapTableRanges>(ctx, node.Pos()) + .Table(read.Table()) + .Ranges(read.Ranges()) + 
.Columns(read.Columns()) + .Settings(read.Settings()) + .ExplainPrompt(read.ExplainPrompt()) + .Process(newProcessLambda) + .Done(); + + return newRead; +} + +TExprBase KqpPushOlapLength(TExprBase node, TExprContext& ctx, const TKqpOptimizeContext& kqpCtx) +{ + if (!kqpCtx.Config->PushOlapProcess()) { + return node; + } + + if (!node.Maybe<TDqPhyLength>().Input().Maybe<TKqpReadOlapTableRanges>()) { + return node; + } + + auto dqPhyLength = node.Cast<TDqPhyLength>(); + auto read = dqPhyLength.Input().Cast<TKqpReadOlapTableRanges>(); + + if (read.Process().Body().Raw() != read.Process().Args().Arg(0).Raw()) { + return node; + } + + auto aggs = Build<TKqpOlapAggOperationList>(ctx, node.Pos()); + aggs.Add<TKqpOlapAggOperation>() + .Name(dqPhyLength.Name()) + .Type().Build("count") + .Column().Build("*") + .Build() + .Done(); + + auto olapAgg = Build<TKqpOlapAgg>(ctx, node.Pos()) + .Input(read.Process().Body()) + .Aggregates(std::move(aggs.Done())) + .KeyColumns(std::move( + Build<TCoAtomList>(ctx, node.Pos()) + .Done() + ) + ) + .Done(); + + auto newProcessLambda = Build<TCoLambda>(ctx, node.Pos()) + .Args({"row"}) + .Body<TExprApplier>() + .Apply(olapAgg) + .With(read.Process().Args().Arg(0), "row") + .Build() + .Done(); + + YQL_CLOG(INFO, ProviderKqp) << "Pushed OLAP lambda: " << KqpExprToPrettyString(newProcessLambda, ctx); + + auto newRead = Build<TKqpReadOlapTableRanges>(ctx, node.Pos()) + .Table(read.Table()) + .Ranges(read.Ranges()) + .Columns(read.Columns()) + .Settings(read.Settings()) + .ExplainPrompt(read.ExplainPrompt()) + .Process(newProcessLambda) + .Done(); + + return newRead; +} + +} // namespace NKikimr::NKqp::NOpt
\ No newline at end of file diff --git a/ydb/core/kqp/opt/physical/kqp_opt_phy_rules.h b/ydb/core/kqp/opt/physical/kqp_opt_phy_rules.h index bde2856792..54eb24f227 100644 --- a/ydb/core/kqp/opt/physical/kqp_opt_phy_rules.h +++ b/ydb/core/kqp/opt/physical/kqp_opt_phy_rules.h @@ -30,6 +30,12 @@ NYql::NNodes::TExprBase KqpApplyLimitToReadTable(NYql::NNodes::TExprBase node, N NYql::NNodes::TExprBase KqpPushOlapFilter(NYql::NNodes::TExprBase node, NYql::TExprContext& ctx, const TKqpOptimizeContext& kqpCtx, NYql::TTypeAnnotationContext& typesCtx); +NYql::NNodes::TExprBase KqpPushOlapAggregate(NYql::NNodes::TExprBase node, NYql::TExprContext& ctx, + const TKqpOptimizeContext& kqpCtx); + +NYql::NNodes::TExprBase KqpPushOlapLength(NYql::NNodes::TExprBase node, NYql::TExprContext& ctx, + const TKqpOptimizeContext& kqpCtx); + NYql::NNodes::TExprBase KqpFloatUpStage(NYql::NNodes::TExprBase node, NYql::TExprContext& ctx); NYql::NNodes::TExprBase KqpPropagatePrecomuteScalarRowset(NYql::NNodes::TExprBase node, NYql::TExprContext& ctx, diff --git a/ydb/core/kqp/prepare/kqp_query_plan.cpp b/ydb/core/kqp/prepare/kqp_query_plan.cpp index a363e5fe77..4ec82d87bf 100644 --- a/ydb/core/kqp/prepare/kqp_query_plan.cpp +++ b/ydb/core/kqp/prepare/kqp_query_plan.cpp @@ -790,7 +790,9 @@ private: })).Cast<TCoJoinDict>(); operatorId = Visit(flatMap, join, planNode); node = join.Ptr(); - } else if (auto maybeCondense = TMaybeNode<TCoCondense1>(node)) { + } else if (auto maybeCondense1 = TMaybeNode<TCoCondense1>(node)) { + operatorId = Visit(maybeCondense1.Cast(), planNode); + } else if (auto maybeCondense = TMaybeNode<TCoCondense>(node)) { operatorId = Visit(maybeCondense.Cast(), planNode); } else if (auto maybeCombiner = TMaybeNode<TCoCombineCore>(node)) { operatorId = Visit(maybeCombiner.Cast(), planNode); @@ -840,6 +842,13 @@ private: return AddOperator(planNode, "Aggregate", std::move(op)); } + ui32 Visit(const TCoCondense& /*condense*/, TQueryPlanNode& planNode) { + TOperator op; + 
op.Properties["Name"] = "Aggregate"; + + return AddOperator(planNode, "Aggregate", std::move(op)); + } + ui32 Visit(const TCoCombineCore& combiner, TQueryPlanNode& planNode) { TOperator op; op.Properties["Name"] = "Aggregate"; diff --git a/ydb/core/kqp/prepare/kqp_type_ann.cpp b/ydb/core/kqp/prepare/kqp_type_ann.cpp index 965efb09c0..bfa6d038d4 100644 --- a/ydb/core/kqp/prepare/kqp_type_ann.cpp +++ b/ydb/core/kqp/prepare/kqp_type_ann.cpp @@ -841,6 +841,85 @@ TStatus AnnotateOlapFilterExists(const TExprNode::TPtr& node, TExprContext& ctx) return TStatus::Ok; } +TStatus AnnotateOlapAgg(const TExprNode::TPtr& node, TExprContext& ctx) { + if (!EnsureArgsCount(*node, 3, ctx)) { + return TStatus::Error; + } + + auto* input = node->Child(TKqpOlapAgg::idx_Input); + + const TTypeAnnotationNode* itemType; + if (!EnsureNewSeqType<false, false, true>(*input, ctx, &itemType)) { + return TStatus::Error; + } + + if (!EnsureStructType(input->Pos(), *itemType, ctx)) { + return TStatus::Error; + } + + auto structType = itemType->Cast<TStructExprType>(); + + if (!EnsureTuple(*node->Child(TKqpOlapAgg::idx_Aggregates), ctx)) { + return TStatus::Error; + } + + TVector<const TItemExprType*> aggTypes; + for (auto agg : node->Child(TKqpOlapAgg::idx_Aggregates)->ChildrenList()) { + auto aggName = agg->Child(TKqpOlapAggOperation::idx_Name); + auto opType = agg->Child(TKqpOlapAggOperation::idx_Type); + auto colName = agg->Child(TKqpOlapAggOperation::idx_Column); + if (!EnsureAtom(*opType, ctx)) { + ctx.AddError(TIssue( + ctx.GetPosition(node->Pos()), + TStringBuilder() << "Expected operation type in OLAP aggregation, got: " << opType->Content() + )); + return TStatus::Error; + } + if (!EnsureAtom(*colName, ctx)) { + ctx.AddError(TIssue( + ctx.GetPosition(node->Pos()), + TStringBuilder() << "Expected column name in OLAP aggregation, got: " << colName->Content() + )); + return TStatus::Error; + } + if (!EnsureAtom(*aggName, ctx)) { + ctx.AddError(TIssue( + ctx.GetPosition(node->Pos()), + 
TStringBuilder() << "Expected aggregate column generated name in OLAP aggregation, got: " << aggName->Content() + )); + return TStatus::Error; + } + if (opType->Content() == "count") { + aggTypes.push_back(ctx.MakeType<TItemExprType>(aggName->Content(), ctx.MakeType<TDataExprType>(EDataSlot::Uint64))); + } else if (opType->Content() == "some") { + aggTypes.push_back(ctx.MakeType<TItemExprType>(aggName->Content(), structType->FindItemType(colName->Content()))); + } else { + ctx.AddError(TIssue( + ctx.GetPosition(node->Pos()), + TStringBuilder() << "Unsupported operation type in OLAP aggregation, got: " << opType->Content() + )); + return TStatus::Error; + } + } + + if (!EnsureTuple(*node->Child(TKqpOlapAgg::idx_KeyColumns), ctx)) { + return TStatus::Error; + } + for (auto keyCol : node->Child(TKqpOlapAgg::idx_KeyColumns)->ChildrenList()) { + if (!EnsureAtom(*keyCol, ctx)) { + ctx.AddError(TIssue( + ctx.GetPosition(node->Pos()), + TStringBuilder() << "Expected column name in OLAP key columns, got: " << keyCol->Content() + )); + return TStatus::Error; + } + aggTypes.push_back(ctx.MakeType<TItemExprType>(keyCol->Content(), structType->FindItemType(keyCol->Content()))); + } + + node->SetTypeAnn(MakeSequenceType(input->GetTypeAnn()->GetKind(), *ctx.MakeType<TStructExprType>(aggTypes), ctx)); + return TStatus::Ok; +} + TStatus AnnotateKqpTxInternalBinding(const TExprNode::TPtr& node, TExprContext& ctx) { if (!EnsureArgsCount(*node, 2, ctx)) { return TStatus::Error; @@ -1154,6 +1233,10 @@ TAutoPtr<IGraphTransformer> CreateKqpTypeAnnotationTransformer(const TString& cl return AnnotateOlapFilterExists(input, ctx); } + if (TKqpOlapAgg::Match(input.Get())) { + return AnnotateOlapAgg(input, ctx); + } + if (TKqpCnMapShard::Match(input.Get()) || TKqpCnShuffleShard::Match(input.Get())) { return AnnotateDqConnection(input, ctx); } diff --git a/ydb/core/kqp/runtime/kqp_read_table.cpp b/ydb/core/kqp/runtime/kqp_read_table.cpp index 70b33e12fa..8c0a2e82f7 100644 --- 
a/ydb/core/kqp/runtime/kqp_read_table.cpp +++ b/ydb/core/kqp/runtime/kqp_read_table.cpp @@ -70,69 +70,44 @@ void BuildKeyTupleCells(const TTupleType* tupleType, const TUnboxedValue& tupleV } void ParseReadColumns(const TType* readType, const TRuntimeNode& tagsNode, - TSmallVec<TKqpComputeContextBase::TColumn>& columns, TSmallVec<TKqpComputeContextBase::TColumn>& systemColumns) + TSmallVec<NTable::TTag>& columns, TSmallVec<NTable::TTag>& systemColumns) { - MKQL_ENSURE_S(readType); + MKQL_ENSURE_S(readType); MKQL_ENSURE_S(readType->GetKind() == TType::EKind::Flow); auto tags = AS_VALUE(TStructLiteral, tagsNode); MKQL_ENSURE_S(tags); - auto itemType = AS_TYPE(TFlowType, readType)->GetItemType(); - MKQL_ENSURE_S(itemType->GetKind() == TType::EKind::Struct); - auto structType = AS_TYPE(TStructType, itemType); - MKQL_ENSURE_S(tags->GetValuesCount() == structType->GetMembersCount()); + columns.reserve(tags->GetValuesCount()); - columns.reserve(structType->GetMembersCount()); - - for (ui32 i = 0; i < structType->GetMembersCount(); ++i) { - auto memberType = structType->GetMemberType(i); - if (memberType->GetKind() == TType::EKind::Optional) { - memberType = AS_TYPE(TOptionalType, memberType)->GetItemType(); - } - MKQL_ENSURE_S(memberType->GetKind() == TType::EKind::Data); + for (ui32 i = 0; i < tags->GetValuesCount(); ++i) { NTable::TTag columnId = AS_VALUE(TDataLiteral, tags->GetValue(i))->AsValue().Get<ui32>(); if (IsSystemColumn(columnId)) { - systemColumns.push_back({columnId, AS_TYPE(TDataType, memberType)->GetSchemeType()}); + systemColumns.push_back(columnId); } else { - columns.push_back({columnId, AS_TYPE(TDataType, memberType)->GetSchemeType()}); + columns.push_back(columnId); } } } -void ParseWideReadColumns(const TCallable& callable, const TRuntimeNode& tagsNode, - TSmallVec<TKqpComputeContextBase::TColumn>& columns, TSmallVec<TKqpComputeContextBase::TColumn>& systemColumns) +void ParseWideReadColumns(const TRuntimeNode& tagsNode, + TSmallVec<NTable::TTag>& 
columns, TSmallVec<NTable::TTag>& systemColumns) { auto tags = AS_VALUE(TStructLiteral, tagsNode); MKQL_ENSURE_S(tags); - TType* returnType = callable.GetType()->GetReturnType(); - MKQL_ENSURE_S(returnType->GetKind() == TType::EKind::Flow); - - auto itemType = AS_TYPE(TFlowType, returnType)->GetItemType(); - MKQL_ENSURE_S(itemType->GetKind() == TType::EKind::Tuple); - - auto tupleType = AS_TYPE(TTupleType, itemType); - MKQL_ENSURE_S(tags->GetValuesCount() == tupleType->GetElementsCount()); - - columns.reserve(tupleType->GetElementsCount()); + columns.reserve(tags->GetValuesCount()); - for (ui32 i = 0; i < tupleType->GetElementsCount(); ++i) { - auto memberType = tupleType->GetElementType(i); + for (ui32 i = 0; i < tags->GetValuesCount(); ++i) { - if (memberType->GetKind() == TType::EKind::Optional) { - memberType = AS_TYPE(TOptionalType, memberType)->GetItemType(); - } - MKQL_ENSURE_S(memberType->GetKind() == TType::EKind::Data); - - NTable::TTag columnId = AS_VALUE(TDataLiteral, tags->GetValue(i))->AsValue().Get<ui32>(); + NTable::TTag columnId = AS_VALUE(TDataLiteral, tags->GetValue(i))->AsValue().Get<ui32>();; if (IsSystemColumn(columnId)) { - systemColumns.push_back({columnId, AS_TYPE(TDataType, memberType)->GetSchemeType()}); - } else { - columns.push_back({columnId, AS_TYPE(TDataType, memberType)->GetSchemeType()}); + systemColumns.push_back(columnId); + } else if (columnId != TKeyDesc::EColumnIdInvalid) { + columns.push_back(columnId); } } } @@ -161,7 +136,7 @@ TParseReadTableResult ParseWideReadTable(TCallable& callable) { MKQL_ENSURE_S(result.ToTuple); result.ToInclusive = AS_VALUE(TDataLiteral, range->GetValue(3))->AsValue().Get<bool>(); - ParseWideReadColumns(callable, tagsNode, result.Columns, result.SystemColumns); + ParseWideReadColumns(tagsNode, result.Columns, result.SystemColumns); auto skipNullKeys = AS_VALUE(TListLiteral, callable.GetInput(3)); result.SkipNullKeys.reserve(skipNullKeys->GetItemsCount()); @@ -210,7 +185,7 @@ 
TParseReadTableRangesResult ParseWideReadTableRanges(TCallable& callable) { MKQL_ENSURE_S(result.Ranges); MKQL_ENSURE_S(result.Ranges->GetValuesCount() == 1); - ParseWideReadColumns(callable, tagsNode, result.Columns, result.SystemColumns); + ParseWideReadColumns(tagsNode, result.Columns, result.SystemColumns); auto limitNode = limit.GetNode(); diff --git a/ydb/core/kqp/runtime/kqp_read_table.h b/ydb/core/kqp/runtime/kqp_read_table.h index 99c9cc3f05..010751df3d 100644 --- a/ydb/core/kqp/runtime/kqp_read_table.h +++ b/ydb/core/kqp/runtime/kqp_read_table.h @@ -17,8 +17,8 @@ struct TParseReadTableResultBase { ui32 CallableId = 0; TTableId TableId; - TSmallVec<TKqpComputeContextBase::TColumn> Columns; - TSmallVec<TKqpComputeContextBase::TColumn> SystemColumns; + TSmallVec<NTable::TTag> Columns; + TSmallVec<NTable::TTag> SystemColumns; TSmallVec<bool> SkipNullKeys; TNode* ItemsLimit = nullptr; bool Reverse = false; @@ -36,7 +36,7 @@ struct TParseReadTableRangesResult : TParseReadTableResultBase { }; void ParseReadColumns(const TType* readType, const TRuntimeNode& tagsNode, - TSmallVec<TKqpComputeContextBase::TColumn>& columns, TSmallVec<TKqpComputeContextBase::TColumn>& systemColumns); + TSmallVec<NTable::TTag>& columns, TSmallVec<NTable::TTag>& systemColumns); TParseReadTableResult ParseWideReadTable(TCallable& callable); TParseReadTableRangesResult ParseWideReadTableRanges(TCallable& callable); diff --git a/ydb/core/kqp/ut/kqp_olap_ut.cpp b/ydb/core/kqp/ut/kqp_olap_ut.cpp index 0f4018de23..e85c27b93f 100644 --- a/ydb/core/kqp/ut/kqp_olap_ut.cpp +++ b/ydb/core/kqp/ut/kqp_olap_ut.cpp @@ -510,6 +510,19 @@ Y_UNIT_TEST_SUITE(KqpOlap) { }; } + void CheckPlanForAggregatePushdown(const TString& query, NYdb::NTable::TTableClient& tableClient) { + TStreamExecScanQuerySettings scanSettings; + scanSettings.Explain(true); + auto res = tableClient.StreamExecuteScanQuery(query, scanSettings).GetValueSync(); + UNIT_ASSERT_C(res.IsSuccess(), res.GetIssues().ToString()); + + auto 
planRes = CollectStreamResult(res); + auto ast = planRes.QueryStats->Getquery_ast(); + + UNIT_ASSERT_C(ast.find("TKqpOlapAgg") != std::string::npos, + TStringBuilder() << "Aggregate was not pushed down. Query: " << query); + } + Y_UNIT_TEST_TWIN(SimpleQueryOlap, UseSessionActor) { auto settings = TKikimrSettings() .SetWithSampleTables(false) @@ -1268,6 +1281,170 @@ Y_UNIT_TEST_SUITE(KqpOlap) { } } + Y_UNIT_TEST(AggregationPushdown) { + auto settings = TKikimrSettings() + .SetWithSampleTables(false) + .SetEnableOlapSchemaOperations(true); + TKikimrRunner kikimr(settings); + + // EnableDebugLogging(kikimr); + CreateTestOlapTable(kikimr); + auto tableClient = kikimr.GetTableClient(); + + { + WriteTestData(kikimr, "/Root/olapStore/olapTable", 10000, 3000000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 11000, 3001000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 12000, 3002000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 13000, 3003000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 14000, 3004000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 20000, 2000000, 7000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 30000, 1000000, 11000); + } + + { + TString query = R"( + --!syntax_v1 + PRAGMA Kikimr.KqpPushOlapProcess = "true"; + SELECT + COUNT(level) + FROM `/Root/olapStore/olapTable` + )"; + auto opStartTime = Now(); + auto it = tableClient.StreamExecuteScanQuery(query).GetValueSync(); + + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + TString result = StreamResultToYson(it); + Cerr << "!!!\nPushdown query execution time: " << (Now() - opStartTime).MilliSeconds() << "\n!!!\n"; + Cout << result << Endl; + CompareYson(result, R"([[23000u;]])"); + + // Check plan + CheckPlanForAggregatePushdown(query, tableClient); + } + } + + Y_UNIT_TEST(AggregationGroupByPushdown) { + // remove this return when GROUP BY will be implemented on columnshard + return; + + auto settings = 
TKikimrSettings() + .SetWithSampleTables(false) + .SetEnableOlapSchemaOperations(true); + TKikimrRunner kikimr(settings); + + // EnableDebugLogging(kikimr); + CreateTestOlapTable(kikimr); + auto tableClient = kikimr.GetTableClient(); + + { + WriteTestData(kikimr, "/Root/olapStore/olapTable", 10000, 3000000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 11000, 3001000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 12000, 3002000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 13000, 3003000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 14000, 3004000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 20000, 2000000, 7000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 30000, 1000000, 11000); + } + + { + TString query = R"( + --!syntax_v1 + PRAGMA Kikimr.KqpPushOlapProcess = "true"; + SELECT + level, COUNT(level) + FROM `/Root/olapStore/olapTable` + GROUP BY level + )"; + auto it = tableClient.StreamExecuteScanQuery(query).GetValueSync(); + + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + TString result = StreamResultToYson(it); + Cout << result << Endl; + CompareYson(result, R"([[23000u;]])"); + + // Check plan + CheckPlanForAggregatePushdown(query, tableClient); + } + } + + Y_UNIT_TEST(CountAllPushdown) { + // remove this return when COUNT(*) will be implemented on columnshard + return; + + auto settings = TKikimrSettings() + .SetWithSampleTables(false) + .SetEnableOlapSchemaOperations(true); + TKikimrRunner kikimr(settings); + + // EnableDebugLogging(kikimr); + CreateTestOlapTable(kikimr); + auto tableClient = kikimr.GetTableClient(); + + { + WriteTestData(kikimr, "/Root/olapStore/olapTable", 10000, 3000000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 11000, 3001000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 12000, 3002000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 13000, 3003000, 1000); + WriteTestData(kikimr, 
"/Root/olapStore/olapTable", 14000, 3004000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 20000, 2000000, 7000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 30000, 1000000, 11000); + } + + { + TString query = R"( + --!syntax_v1 + PRAGMA Kikimr.KqpPushOlapProcess = "true"; + SELECT + COUNT(*) + FROM `/Root/olapStore/olapTable` + )"; + auto it = tableClient.StreamExecuteScanQuery(query).GetValueSync(); + + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + TString result = StreamResultToYson(it); + Cout << result << Endl; + CompareYson(result, R"([[23000u;]])"); + + // Check plan + CheckPlanForAggregatePushdown(query, tableClient); + } + } + + Y_UNIT_TEST(CountAllNoPushdown) { + auto settings = TKikimrSettings() + .SetWithSampleTables(false) + .SetEnableOlapSchemaOperations(true); + TKikimrRunner kikimr(settings); + + // EnableDebugLogging(kikimr); + CreateTestOlapTable(kikimr); + auto tableClient = kikimr.GetTableClient(); + + { + WriteTestData(kikimr, "/Root/olapStore/olapTable", 10000, 3000000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 11000, 3001000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 12000, 3002000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 13000, 3003000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 14000, 3004000, 1000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 20000, 2000000, 7000); + WriteTestData(kikimr, "/Root/olapStore/olapTable", 30000, 1000000, 11000); + } + + { + auto it = tableClient.StreamExecuteScanQuery(R"( + --!syntax_v1 + SELECT + COUNT(*) + FROM `/Root/olapStore/olapTable` + )").GetValueSync(); + + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + TString result = StreamResultToYson(it); + Cout << result << Endl; + CompareYson(result, R"([[23000u;]])"); + } + } + Y_UNIT_TEST_TWIN(StatsSysView, UseSessionActor) { auto settings = TKikimrSettings() .SetWithSampleTables(false) diff --git 
a/ydb/core/kqp/ut/kqp_stats_ut.cpp b/ydb/core/kqp/ut/kqp_stats_ut.cpp index c330a42bf3..2228f3c660 100644 --- a/ydb/core/kqp/ut/kqp_stats_ut.cpp +++ b/ydb/core/kqp/ut/kqp_stats_ut.cpp @@ -353,10 +353,10 @@ Y_UNIT_TEST_TWIN(StatsProfile, UseSessionActor) { NJson::TJsonValue plan; NJson::ReadJsonTree(result.GetQueryPlan(), &plan, true); - auto node1 = FindPlanNodeByKv(plan, "Node Type", "TableFullScan"); + auto node1 = FindPlanNodeByKv(plan, "Node Type", "Aggregate-TableFullScan"); UNIT_ASSERT_EQUAL(node1.GetMap().at("Stats").GetMapSafe().at("ComputeNodes").GetArraySafe().size(), 2); - auto node2 = FindPlanNodeByKv(plan, "Node Type", "Limit"); + auto node2 = FindPlanNodeByKv(plan, "Node Type", "Aggregate-Limit"); UNIT_ASSERT_EQUAL(node2.GetMap().at("Stats").GetMapSafe().at("ComputeNodes").GetArraySafe().size(), 1); } diff --git a/ydb/core/protos/ssa.proto b/ydb/core/protos/ssa.proto index 88eaeee8b4..c1c2f3e56b 100644 --- a/ydb/core/protos/ssa.proto +++ b/ydb/core/protos/ssa.proto @@ -19,6 +19,7 @@ option java_package = "ru.yandex.kikimr.proto"; message TProgram { message TColumn { optional uint64 Id = 1; + optional string Name = 2; } message TConstant { diff --git a/ydb/core/tx/columnshard/columnshard__read.cpp b/ydb/core/tx/columnshard/columnshard__read.cpp index 8b9c731373..dcf041b7d0 100644 --- a/ydb/core/tx/columnshard/columnshard__read.cpp +++ b/ydb/core/tx/columnshard/columnshard__read.cpp @@ -4,6 +4,7 @@ #include "columnshard__index_scan.h" #include <ydb/core/tx/columnshard/engines/column_engine.h> #include <ydb/core/tx/columnshard/engines/indexed_read_data.h> +#include <ydb/core/formats/ssa_program_optimizer.h> namespace NKikimr::NColumnShard { @@ -204,10 +205,18 @@ bool TTxReadBase::ParseProgram(const TActorContext& ctx, NKikimrSchemeOp::EOlapP read.ProgramParameters = NArrow::DeserializeBatch(olapProgram.GetParameters(), schema); } - if (!read.AddProgram(columnResolver, program)) { + + auto ssaProgramSteps = read.AddProgram(columnResolver, program); + if 
(!ssaProgramSteps) { ErrorDescription = TStringBuilder() << "Wrong olap program"; return false; } + if (!ssaProgramSteps->Program.empty() && Self->PrimaryIndex) { + ssaProgramSteps->Program = NKikimr::NSsaOptimizer::OptimizeProgram(ssaProgramSteps->Program, Self->PrimaryIndex->GetIndexInfo()); + } + + read.Program = ssaProgramSteps->Program; + read.ProgramSourceColumns = ssaProgramSteps->ProgramSourceColumns; return true; } diff --git a/ydb/core/tx/columnshard/columnshard_common.cpp b/ydb/core/tx/columnshard/columnshard_common.cpp index ecacebdc80..14f826774b 100644 --- a/ydb/core/tx/columnshard/columnshard_common.cpp +++ b/ydb/core/tx/columnshard/columnshard_common.cpp @@ -196,6 +196,9 @@ NArrow::TAggregateAssign MakeAggregate(const TContext& info, const std::string& case TId::AGG_UNSPECIFIED: break; } + } else if (func.ArgumentsSize() == 0 && func.GetId() == TId::AGG_COUNT) { + // COUNT(*) case + return TAggregateAssign(name, EAggregate::Count, {}); } return TAggregateAssign(name, EAggregate::Unspecified, {}); } @@ -373,50 +376,51 @@ std::pair<TPredicate, TPredicate> RangePredicates(const TSerializedTableRange& r TPredicate(EOperation::Less, rightBorder, NArrow::MakeArrowSchema(rightColumns), toInclusive)); } -bool TReadDescription::AddProgram(const IColumnResolver& columnResolver, const NKikimrSSA::TProgram& program) +std::shared_ptr<NArrow::TSsaProgramSteps> TReadDescription::AddProgram(const IColumnResolver& columnResolver, const NKikimrSSA::TProgram& program) { using TId = NKikimrSSA::TProgram::TCommand; + auto programSteps = std::make_shared<NArrow::TSsaProgramSteps>(); TContext info(columnResolver); auto step = std::make_shared<NArrow::TProgramStep>(); for (auto& cmd : program.GetCommand()) { switch (cmd.GetLineCase()) { case TId::kAssign: if (!ExtractAssign(info, *step, cmd.GetAssign(), ProgramParameters)) { - return false; + return nullptr; } break; case TId::kFilter: if (!ExtractFilter(info, *step, cmd.GetFilter())) { - return false; + return nullptr; } 
break; case TId::kProjection: if (!ExtractProjection(info, *step, cmd.GetProjection())) { - return false; + return nullptr; } - Program.push_back(step); + programSteps->Program.push_back(step); step = std::make_shared<NArrow::TProgramStep>(); break; case TId::kGroupBy: if (!ExtractGroupBy(info, *step, cmd.GetGroupBy())) { - return false; + return nullptr; } - Program.push_back(step); + programSteps->Program.push_back(step); step = std::make_shared<NArrow::TProgramStep>(); break; case TId::LINE_NOT_SET: - return false; + return nullptr; } } // final step without final projection if (!step->Empty()) { - Program.push_back(step); + programSteps->Program.push_back(step); } - ProgramSourceColumns = std::move(info.Sources); - return true; + programSteps->ProgramSourceColumns = std::move(info.Sources); + return programSteps; } } diff --git a/ydb/core/tx/columnshard/columnshard_common.h b/ydb/core/tx/columnshard/columnshard_common.h index bbec34cbbd..80a9e4599c 100644 --- a/ydb/core/tx/columnshard/columnshard_common.h +++ b/ydb/core/tx/columnshard/columnshard_common.h @@ -48,7 +48,7 @@ struct TReadDescription { ui64 PlanStep = 0; ui64 TxId = 0; - bool AddProgram(const IColumnResolver& columnResolver, const NKikimrSSA::TProgram& program); + std::shared_ptr<NArrow::TSsaProgramSteps> AddProgram(const IColumnResolver& columnResolver, const NKikimrSSA::TProgram& program); }; } diff --git a/ydb/core/tx/columnshard/ut_columnshard_read_write.cpp b/ydb/core/tx/columnshard/ut_columnshard_read_write.cpp index db3e42f55f..0ec7aefce0 100644 --- a/ydb/core/tx/columnshard/ut_columnshard_read_write.cpp +++ b/ydb/core/tx/columnshard/ut_columnshard_read_write.cpp @@ -887,6 +887,7 @@ static NKikimrSSA::TProgram MakeSelect(TAssignment::EFunction compareId = TAssig // FIXME: // NotImplemented: Function any has no kernel matching input types (array[timestamp[us]]) // NotImplemented: Function any has no kernel matching input types (array[string]) +// NotImplemented: Function any has no kernel 
matching input types (array[int32]) // NotImplemented: Function min_max has no kernel matching input types (array[timestamp[us]]) // NotImplemented: Function min_max has no kernel matching input types (array[string]) // diff --git a/ydb/core/tx/datashard/datashard_kqp_compute.cpp b/ydb/core/tx/datashard/datashard_kqp_compute.cpp index 0e10561ab6..b83770b18e 100644 --- a/ydb/core/tx/datashard/datashard_kqp_compute.cpp +++ b/ydb/core/tx/datashard/datashard_kqp_compute.cpp @@ -14,14 +14,6 @@ namespace NMiniKQL { using namespace NTable; using namespace NUdf; -TSmallVec<TTag> ExtractTags(const TSmallVec<TKqpComputeContextBase::TColumn>& columns) { - TSmallVec<TTag> tags; - for (const auto& column : columns) { - tags.push_back(column.Tag); - } - return tags; -} - typedef IComputationNode* (*TCallableDatashardBuilderFunc)(TCallable& callable, const TComputationNodeFactoryContext& ctx, TKqpDatashardComputeContext& computeCtx); diff --git a/ydb/core/tx/datashard/datashard_kqp_compute.h b/ydb/core/tx/datashard/datashard_kqp_compute.h index a0e2072be5..d547030db8 100644 --- a/ydb/core/tx/datashard/datashard_kqp_compute.h +++ b/ydb/core/tx/datashard/datashard_kqp_compute.h @@ -120,8 +120,6 @@ public: IEngineFlatHost* Host = nullptr; }; -TSmallVec<NTable::TTag> ExtractTags(const TSmallVec<TKqpComputeContextBase::TColumn>& columns); - IComputationNode* WrapKqpWideReadTableRanges(TCallable& callable, const TComputationNodeFactoryContext& ctx, TKqpDatashardComputeContext& computeCtx); IComputationNode* WrapKqpLookupTable(TCallable& callable, const TComputationNodeFactoryContext& ctx, diff --git a/ydb/core/tx/datashard/datashard_kqp_lookup_table.cpp b/ydb/core/tx/datashard/datashard_kqp_lookup_table.cpp index ee179cd91c..e1938769fd 100644 --- a/ydb/core/tx/datashard/datashard_kqp_lookup_table.cpp +++ b/ydb/core/tx/datashard/datashard_kqp_lookup_table.cpp @@ -19,8 +19,8 @@ struct TParseLookupTableResult { TVector<ui32> KeyIndices; TVector<NUdf::TDataTypeId> KeyTypes; - 
TSmallVec<TKqpComputeContextBase::TColumn> Columns; - TSmallVec<TKqpComputeContextBase::TColumn> SystemColumns; + TSmallVec<NTable::TTag> Columns; + TSmallVec<NTable::TTag> SystemColumns; TSmallVec<bool> SkipNullKeys; }; @@ -88,8 +88,8 @@ public: , TypeEnv(typeEnv) , ParseResult(parseResult) , LookupKeysNode(lookupKeysNode) - , ColumnTags(ExtractTags(ParseResult.Columns)) - , SystemColumnTags(ExtractTags(ParseResult.SystemColumns)) + , ColumnTags(ParseResult.Columns) + , SystemColumnTags(ParseResult.SystemColumns) , ShardTableStats(ComputeCtx.GetDatashardCounters()) , TaskTableStats(ComputeCtx.GetTaskCounters(ComputeCtx.GetCurrentTaskId())) { @@ -173,8 +173,8 @@ public: , TypeEnv(typeEnv) , ParseResult(parseResult) , LookupKeysNode(lookupKeysNode) - , ColumnTags(ExtractTags(ParseResult.Columns)) - , SystemColumnTags(ExtractTags(ParseResult.SystemColumns)) + , ColumnTags(ParseResult.Columns) + , SystemColumnTags(ParseResult.SystemColumns) , ShardTableStats(ComputeCtx.GetDatashardCounters()) , TaskTableStats(ComputeCtx.GetTaskCounters(computeCtx.GetCurrentTaskId())) {} diff --git a/ydb/core/tx/datashard/datashard_kqp_read_table.cpp b/ydb/core/tx/datashard/datashard_kqp_read_table.cpp index 7cfa13420e..c57a8a3840 100644 --- a/ydb/core/tx/datashard/datashard_kqp_read_table.cpp +++ b/ydb/core/tx/datashard/datashard_kqp_read_table.cpp @@ -252,12 +252,12 @@ public: const TParseReadTableResult& parseResult, IComputationNode* fromNode, IComputationNode* toNode, IComputationNode* itemsLimit) : TKqpWideReadTableWrapperBase<IsReverse>(parseResult.TableId, computeCtx, typeEnv, - ExtractTags(parseResult.SystemColumns), parseResult.SkipNullKeys) + parseResult.SystemColumns, parseResult.SkipNullKeys) , ParseResult(parseResult) , FromNode(fromNode) , ToNode(toNode) , ItemsLimit(itemsLimit) - , ColumnTags(ExtractTags(parseResult.Columns)) + , ColumnTags(parseResult.Columns) { this->ShardTableStats.NSelectRange++; this->TaskTableStats.NSelectRange++; @@ -315,11 +315,11 @@ public: 
TKqpWideReadTableRangesWrapper(TKqpDatashardComputeContext& computeCtx, const TTypeEnvironment& typeEnv, const TParseReadTableRangesResult& parseResult, IComputationNode* rangesNode, IComputationNode* itemsLimit) : TKqpWideReadTableWrapperBase<IsReverse>(parseResult.TableId, computeCtx, typeEnv, - ExtractTags(parseResult.SystemColumns), parseResult.SkipNullKeys) + parseResult.SystemColumns, parseResult.SkipNullKeys) , ParseResult(parseResult) , RangesNode(rangesNode) , ItemsLimit(itemsLimit) - , ColumnTags(ExtractTags(parseResult.Columns)) {} + , ColumnTags(parseResult.Columns) {} private: EFetchResult ReadValue(TComputationContext& ctx, NUdf::TUnboxedValue* const* output) const final { diff --git a/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json b/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json index 65e5acba4f..c27cd98840 100644 --- a/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json +++ b/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json @@ -419,6 +419,23 @@ ] }, { + "Name": "TCoAggrCountInit", + "Base": "TCallable", + "Match": {"Type": "Callable", "Name": "AggrCountInit"}, + "Children": [ + {"Index": 0, "Name": "Value", "Type": "TExprBase"} + ] + }, + { + "Name": "TCoAggrCountUpdate", + "Base": "TCallable", + "Match": {"Type": "Callable", "Name": "AggrCountUpdate"}, + "Children": [ + {"Index": 0, "Name": "Value", "Type": "TExprBase"}, + {"Index": 1, "Name": "State", "Type": "TExprBase"} + ] + }, + { "Name": "TCoMin", "VarArgBase": "TExprBase", "Match": {"Type": "Callable", "Name": "Min"} diff --git a/ydb/library/yql/dq/expr_nodes/dq_expr_nodes.json b/ydb/library/yql/dq/expr_nodes/dq_expr_nodes.json index 46d880b3a5..11e2d47da4 100644 --- a/ydb/library/yql/dq/expr_nodes/dq_expr_nodes.json +++ b/ydb/library/yql/dq/expr_nodes/dq_expr_nodes.json @@ -245,6 +245,15 @@ {"Index": 1, "Name": "TransformName", "Type": "TExprBase"}, {"Index": 2, "Name": "Settings", "Type": "TCoNameValueTupleList"} ] + }, + { + "Name": "TDqPhyLength", + "Base": "TCallable", + 
"Match": {"Type": "Callable", "Name": "DqPhyLength"}, + "Children": [ + {"Index": 0, "Name": "Input", "Type": "TExprBase"}, + {"Index": 1, "Name": "Name", "Type": "TCoAtom"} + ] } ] } diff --git a/ydb/library/yql/dq/opt/dq_opt_peephole.cpp b/ydb/library/yql/dq/opt/dq_opt_peephole.cpp index 223cc2c9be..ea8864e0d7 100644 --- a/ydb/library/yql/dq/opt/dq_opt_peephole.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_peephole.cpp @@ -666,4 +666,36 @@ NNodes::TExprBase DqPeepholeDropUnusedInputs(const NNodes::TExprBase& node, TExp return NNodes::TExprBase(ctx.ChangeChildren(node.Ref(), std::move(children))); } +NNodes::TExprBase DqPeepholeRewriteLength(const NNodes::TExprBase& node, TExprContext& ctx) { + if (!node.Maybe<TDqPhyLength>()) { + return node; + } + + auto dqPhyLength = node.Cast<TDqPhyLength>(); + + auto zero = Build<TCoUint64>(ctx, node.Pos()) + .Literal().Build("0") + .Done(); + + return Build<TCoCondense>(ctx, node.Pos()) + .Input(dqPhyLength.Input()) + .State<TCoUint64>() + .Literal().Build("0") + .Build() + .SwitchHandler() + .Args({"item", "state"}) + .Body(MakeBool<false>(node.Pos(), ctx)) + .Build() + .UpdateHandler() + .Args({"item", "state"}) + .Body<TCoAggrAdd>() + .Left("state") + .Right<TCoUint64>() + .Literal().Build("1") + .Build() + .Build() + .Build() + .Done(); +} + } // namespace NYql::NDq diff --git a/ydb/library/yql/dq/opt/dq_opt_peephole.h b/ydb/library/yql/dq/opt/dq_opt_peephole.h index f403c8d39d..e2bf494c0e 100644 --- a/ydb/library/yql/dq/opt/dq_opt_peephole.h +++ b/ydb/library/yql/dq/opt/dq_opt_peephole.h @@ -13,5 +13,6 @@ NNodes::TExprBase DqPeepholeRewriteMapJoin(const NNodes::TExprBase& node, TExprC NNodes::TExprBase DqPeepholeRewriteReplicate(const NNodes::TExprBase& node, TExprContext& ctx); NNodes::TExprBase DqPeepholeRewritePureJoin(const NNodes::TExprBase& node, TExprContext& ctx); NNodes::TExprBase DqPeepholeDropUnusedInputs(const NNodes::TExprBase& node, TExprContext& ctx); +NNodes::TExprBase DqPeepholeRewriteLength(const 
NNodes::TExprBase& node, TExprContext& ctx); } // namespace NYql::NDq diff --git a/ydb/library/yql/dq/opt/dq_opt_phy.cpp b/ydb/library/yql/dq/opt/dq_opt_phy.cpp index f0b37b2276..358cf9b2e9 100644 --- a/ydb/library/yql/dq/opt/dq_opt_phy.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_phy.cpp @@ -1208,47 +1208,16 @@ TExprBase DqRewriteLengthOfStageOutput(TExprBase node, TExprContext& ctx, IOptim auto field = BuildAtom("_dq_agg_cnt", node.Pos(), ctx); - auto combineLambda = Build<TCoLambda>(ctx, node.Pos()) + auto dqLengthLambda = Build<TCoLambda>(ctx, node.Pos()) .Args({"stream"}) - .Body<TCoCombineByKey>() + .Body<TDqPhyLength>() .Input("stream") - .PreMapLambda() - .Args({"item"}) - .Body<TCoJust>() - .Input("item") - .Build() - .Build() - .KeySelectorLambda() - .Args({"item"}) - .Body(zero) - .Build() - .InitHandlerLambda() - .Args({"key", "item"}) - .Body<TCoUint64>() - .Literal().Build("1") - .Build() - .Build() - .UpdateHandlerLambda() - .Args({"key", "item", "state"}) - .Body<TCoInc>() - .Value("state") - .Build() - .Build() - .FinishHandlerLambda() - .Args({"key", "state"}) - .Body<TCoJust>() - .Input<TCoAsStruct>() - .Add<TCoNameValueTuple>() - .Name(field) - .Value("state") - .Build() - .Build() - .Build() - .Build() + .Name(field) .Build() .Done(); - auto result = DqPushLambdaToStageUnionAll(dqUnion, combineLambda, {}, ctx, optCtx); + + auto result = DqPushLambdaToStageUnionAll(dqUnion, dqLengthLambda, {}, ctx, optCtx); if (!result) { return node; } @@ -1270,10 +1239,7 @@ TExprBase DqRewriteLengthOfStageOutput(TExprBase node, TExprContext& ctx, IOptim .Args({"item", "state"}) .Body<TCoAggrAdd>() .Left("state") - .Right<TCoMember>() - .Struct("item") - .Name(field) - .Build() + .Right("item") .Build() .Build() .Build() diff --git a/ydb/library/yql/dq/type_ann/dq_type_ann.cpp b/ydb/library/yql/dq/type_ann/dq_type_ann.cpp index 5f46d1bb53..f4835abdad 100644 --- a/ydb/library/yql/dq/type_ann/dq_type_ann.cpp +++ b/ydb/library/yql/dq/type_ann/dq_type_ann.cpp @@ -842,6 
+842,21 @@ TStatus AnnotateDqPhyPrecompute(const TExprNode::TPtr& node, TExprContext& ctx) return TStatus::Ok; } +TStatus AnnotateDqPhyLength(const TExprNode::TPtr& node, TExprContext& ctx) { + if (!EnsureArgsCount(*node, 2, ctx)) { + return TStatus::Error; + } + auto* input = node->Child(TDqPhyLength::idx_Input); + auto* aggName = node->Child(TDqPhyLength::idx_Name); + + TVector<const TItemExprType*> aggTypes; + if (!EnsureAtom(*aggName, ctx)) { + return TStatus::Error; + } + node->SetTypeAnn(MakeSequenceType(input->GetTypeAnn()->GetKind(), *ctx.MakeType<TDataExprType>(EDataSlot::Uint64), ctx)); + return TStatus::Ok; +} + THolder<IGraphTransformer> CreateDqTypeAnnotationTransformer(TTypeAnnotationContext& typesCtx) { auto coreTransformer = CreateExtCallableTypeAnnotationTransformer(typesCtx); @@ -937,6 +952,10 @@ THolder<IGraphTransformer> CreateDqTypeAnnotationTransformer(TTypeAnnotationCont return AnnotateDqPhyPrecompute(input, ctx); } + if (TDqPhyLength::Match(input.Get())) { + return AnnotateDqPhyLength(input, ctx); + } + return coreTransformer->Transform(input, output, ctx); }); } diff --git a/ydb/library/yql/providers/dq/opt/dqs_opt.cpp b/ydb/library/yql/providers/dq/opt/dqs_opt.cpp index c171924714..7ce46e261e 100644 --- a/ydb/library/yql/providers/dq/opt/dqs_opt.cpp +++ b/ydb/library/yql/providers/dq/opt/dqs_opt.cpp @@ -58,6 +58,7 @@ namespace NYql::NDqs { PERFORM_RULE(DqPeepholeRewritePureJoin, node, ctx); PERFORM_RULE(DqPeepholeRewriteReplicate, node, ctx); PERFORM_RULE(DqPeepholeDropUnusedInputs, node, ctx); + PERFORM_RULE(DqPeepholeRewriteLength, node, ctx); return inputExpr; }, ctx, optSettings); }); diff --git a/ydb/tests/functional/canonical/canondata/test_sql.TestCanonicalFolder1.test_case_explain.script-script_/explain.script.plan b/ydb/tests/functional/canonical/canondata/test_sql.TestCanonicalFolder1.test_case_explain.script-script_/explain.script.plan index a0685b27cd..335217c5ae 100644 --- 
a/ydb/tests/functional/canonical/canondata/test_sql.TestCanonicalFolder1.test_case_explain.script-script_/explain.script.plan +++ b/ydb/tests/functional/canonical/canondata/test_sql.TestCanonicalFolder1.test_case_explain.script-script_/explain.script.plan @@ -73,8 +73,6 @@ "PlanNodeId": 1, "Operators": [ { - "GroupBy": "0", - "Aggregation": "Inc(state)", "Name": "Aggregate" }, { @@ -95,11 +93,14 @@ ], "Operators": [ { + "Name": "Aggregate" + }, + { "Name": "Limit", "Limit": "1001" } ], - "Node Type": "Limit" + "Node Type": "Aggregate-Limit" } ], "Node Type": "Precompute_0_0", @@ -189,8 +190,6 @@ "PlanNodeId": 1, "Operators": [ { - "GroupBy": "0", - "Aggregation": "Inc(state)", "Name": "Aggregate" }, { @@ -211,11 +210,14 @@ ], "Operators": [ { + "Name": "Aggregate" + }, + { "Name": "Limit", "Limit": "1001" } ], - "Node Type": "Limit" + "Node Type": "Aggregate-Limit" } ], "Node Type": "Precompute_0_0", |