diff options
author | vvvv <vvvv@ydb.tech> | 2022-10-10 23:20:20 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2022-10-10 23:20:20 +0300 |
commit | f970a39bdd247dd49a1204c59ba07966d2886905 (patch) | |
tree | cd8af96e552a2c05626e67d510fe1c2aadc8db06 | |
parent | a6b8cecadad0531eaf559b57c797d7d71bae12d7 (diff) | |
download | ydb-f970a39bdd247dd49a1204c59ba07966d2886905.tar.gz |
implementation of Aggregate phases in peephole
-rw-r--r-- | ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp | 6 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_core.cpp | 2 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_list.cpp | 34 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_aggregate_expander.cpp | 259 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_aggregate_expander.h | 6 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_expr_constraint.cpp | 6 | ||||
-rw-r--r-- | ydb/library/yql/sql/v1/aggregation.cpp | 148 | ||||
-rw-r--r-- | ydb/library/yql/sql/v1/insert.cpp | 5 | ||||
-rw-r--r-- | ydb/library/yql/sql/v1/node.cpp | 51 | ||||
-rw-r--r-- | ydb/library/yql/sql/v1/node.h | 12 | ||||
-rw-r--r-- | ydb/library/yql/sql/v1/select.cpp | 12 | ||||
-rw-r--r-- | ydb/library/yql/sql/v1/sql.cpp | 2 |
12 files changed, 380 insertions, 163 deletions
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp index 5e2482d699c..40c096b7f32 100644 --- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp +++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp @@ -6121,6 +6121,12 @@ struct TPeepHoleRules { static constexpr std::initializer_list<TExtPeepHoleOptimizerMap::value_type> CommonStageExtRulesInit = { {"Aggregate", &ExpandAggregatePeephole}, + {"AggregateCombine", &ExpandAggregatePeephole}, + {"AggregateCombineState", &ExpandAggregatePeephole}, + {"AggregateMergeState", &ExpandAggregatePeephole}, + {"AggregateMergeFinalize", &ExpandAggregatePeephole}, + {"AggregateMergeManyFinalize", &ExpandAggregatePeephole}, + {"AggregateFinalize", &ExpandAggregatePeephole}, }; static constexpr std::initializer_list<TPeepHoleOptimizerMap::value_type> SimplifyStageRulesInit = { diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp index 3c458a0ff49..f6383d58a78 100644 --- a/ydb/library/yql/core/type_ann/type_ann_core.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp @@ -11444,6 +11444,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot> Functions["AggregateMergeState"] = &AggregateWrapper; Functions["AggregateFinalize"] = &AggregateWrapper; Functions["AggregateMergeFinalize"] = &AggregateWrapper; + Functions["AggregateMergeManyFinalize"] = &AggregateWrapper; Functions["AggOverState"] = &AggOverStateWrapper; Functions["SqlAggregateAll"] = &SqlAggregateAllWrapper; Functions["CountedAggregateAll"] = &CountedAggregateAllWrapper; @@ -11507,6 +11508,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot> Functions["PgOr"] = &PgBoolOpWrapper; Functions["PgNot"] = &PgBoolOpWrapper; Functions["PgAggregationTraits"] = &PgAggregationTraitsWrapper; + Functions["PgAggregationTraitsOverState"] = &PgAggregationTraitsWrapper; Functions["PgWindowTraits"] = &PgAggregationTraitsWrapper; Functions["PgAggregationTraitsTuple"] = &PgAggregationTraitsWrapper; Functions["PgWindowTraitsTuple"] = &PgAggregationTraitsWrapper; diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp index e5a4b3ba4bb..d12f33f801c 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp @@ -4398,21 +4398,21 @@ namespace { return IGraphTransformer::TStatus::Error; } - if (!overState) { - if (lambdaUpdate->Head().ChildrenSize() == 2U) { - if (!UpdateLambdaAllArgumentsTypes(lambdaUpdate, { itemType, combineStateType }, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; - } - } else { - if (!UpdateLambdaAllArgumentsTypes(lambdaUpdate, { itemType, combineStateType, ui32Type }, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; - } + if (lambdaUpdate->Head().ChildrenSize() == 2U) { + if (!UpdateLambdaAllArgumentsTypes(lambdaUpdate, { itemType, combineStateType }, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; } - - if (!lambdaUpdate->GetTypeAnn()) { - return IGraphTransformer::TStatus::Repeat; + } else { + if (!UpdateLambdaAllArgumentsTypes(lambdaUpdate, { itemType, combineStateType, ui32Type }, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; } + } + + if (!lambdaUpdate->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + if (!overState) { if (!IsSameAnnotation(*lambdaUpdate->GetTypeAnn(), *combineStateType)) { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaUpdate->Pos()), TStringBuilder() << "Mismatch update lambda result type, expected: " << *combineStateType << ", but got: " << *lambdaUpdate->GetTypeAnn())); @@ -4774,6 +4774,11 @@ namespace { const TTypeAnnotationNode* distinctColumnType = nullptr; if (child->ChildrenSize() == 3) { + if (suffix != "" && suffix != "Finalize") { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(child->Pos()), TStringBuilder() << "DISTINCT aggregation is not supported for mode: " << suffix)); + return IGraphTransformer::TStatus::Error; + } + if (!EnsureAtom(*child->Child(2), ctx.Expr)) { return IGraphTransformer::TStatus::Error; } @@ -5116,6 +5121,11 @@ namespace { TStringBuilder() << "Unsupported column type: " << lambdaTypeSlot)); return IGraphTransformer::TStatus::Error; } + + if (isOptional) { + sumResultType = ctx.Expr.MakeType<TOptionalExprType>(sumResultType); + } + input->SetTypeAnn(sumResultType); } else if (IsNull(*lambda->GetTypeAnn())) { input->SetTypeAnn(ctx.Expr.MakeType<TNullExprType>()); diff --git a/ydb/library/yql/core/yql_aggregate_expander.cpp b/ydb/library/yql/core/yql_aggregate_expander.cpp index 5fc67082331..afd95c720e2 100644 --- a/ydb/library/yql/core/yql_aggregate_expander.cpp +++ b/ydb/library/yql/core/yql_aggregate_expander.cpp @@ -1,4 +1,5 @@ #include "yql_aggregate_expander.h" +#include "yql_aggregate_expander.h" #include <ydb/library/yql/core/yql_expr_optimize.h> #include <ydb/library/yql/core/yql_expr_type_annotation.h> @@ -8,6 +9,8 @@ namespace NYql { TExprNode::TPtr TAggregateExpander::ExpandAggregate() { + Suffix = Node->Content(); + YQL_ENSURE(Suffix.SkipPrefix("Aggregate")); AggList = Node->HeadPtr(); KeyColumns = Node->ChildPtr(1); AggregatedColumns = Node->Child(2); @@ -29,6 +32,13 @@ TExprNode::TPtr TAggregateExpander::ExpandAggregate() } } + if (Suffix == "Finalize") { + EffectiveCompact = true; + Suffix = ""; + } else if (Suffix != "") { + EffectiveCompact = false; + } + OriginalRowType = GetSeqItemType(Node->Head().GetTypeAnn())->Cast<TStructExprType>(); RowItems = OriginalRowType->GetItems(); @@ -39,11 +49,20 @@ TExprNode::TPtr TAggregateExpander::ExpandAggregate() bool needPickle = IsNeedPickle(keyItemTypes); auto keyExtractor = GetKeyExtractor(needPickle); CollectColumnsSpecs(); - + + if (Suffix == "MergeState" || Suffix == "MergeFinalize" || Suffix == "MergeManyFinalize") { + return GeneratePostAggregate(AggList, keyExtractor); + } + TExprNode::TPtr preAgg = GeneratePartialAggregate(keyExtractor, keyItemTypes, needPickle); if (EffectiveCompact || !preAgg) { preAgg = std::move(AggList); } + + if (Suffix == "Combine" || Suffix == "CombineState") { + return preAgg; + } + return GeneratePostAggregate(preAgg, keyExtractor); } @@ -76,7 +95,7 @@ bool TAggregateExpander::CollectTraits() { bool allTraitsCollected = true; for (ui32 index = 0; index < AggregatedColumns->ChildrenSize(); ++index) { auto trait = AggregatedColumns->Child(index)->ChildPtr(1); - if (trait->IsCallable("AggApply")) { + if (trait->IsCallable({ "AggApply", "AggApplyState" })) { trait = ExpandAggApply(trait); allTraitsCollected = false; } @@ -341,8 +360,97 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregate(const TExprNode::TP return partialAgg; } +std::function<TExprNodeBuilder& (TExprNodeBuilder&)> TAggregateExpander::GetPartialAggArgExtractor(ui32 i, bool deserialize) { + return [&, i, deserialize](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + auto trait = Traits[i]; + auto extractorLambda = trait->Child(1); + auto loadLambda = trait->Child(4); + if (Suffix == "CombineState") { + if (deserialize) { + parent.Apply(*loadLambda) + .With(0) + .Apply(*extractorLambda) + .With(0) + .Callable("CastStruct") + .Arg(0, "item") + .Add(1, ExpandType(Node->Pos(), *extractorLambda->Head().Head().GetTypeAnn(), Ctx)) + .Seal() + .Done() + .Seal() + .Done() + .Seal(); + } else { + parent.Apply(*extractorLambda) + .With(0) + .Callable("CastStruct") + .Arg(0, "item") + .Add(1, ExpandType(Node->Pos(), *extractorLambda->Head().Head().GetTypeAnn(), Ctx)) + .Seal() + .Done() + .Seal(); + } + } else { + parent.Callable("CastStruct") + .Arg(0, "item") + .Add(1, ExpandType(Node->Pos(), *extractorLambda->Head().Head().GetTypeAnn(), Ctx)) + .Seal(); + } + + return parent; + }; +} + +TExprNode::TPtr TAggregateExpander::GetFinalAggStateExtractor(ui32 i) { + auto trait = Traits[i]; + if (Suffix.StartsWith("Merge")) { + auto lambda = trait->ChildPtr(1); + if (!Suffix.StartsWith("MergeMany")) { + return lambda; + } + + if (lambda->Tail().IsCallable("Unwrap")) { + return Ctx.Builder(Node->Pos()) + .Lambda() + .Param("item") + .ApplyPartial(lambda->HeadPtr(), lambda->Tail().HeadPtr()) + .With(0, "item") + .Seal() + .Seal() + .Build(); + } else { + return Ctx.Builder(Node->Pos()) + .Lambda() + .Param("item") + .Callable("Just") + .Apply(0, *lambda) + .With(0, "item") + .Seal() + .Seal() + .Seal() + .Build(); + } + } + + bool aggregateOnly = (Suffix != ""); + const auto& columnNames = aggregateOnly ? FinalColumnNames : InitialColumnNames; + return Ctx.Builder(Node->Pos()) + .Lambda() + .Param("item") + .Callable("Member")
+ .Arg(0, "item")
+ .Add(1, columnNames[i])
+ .Seal() + .Seal() + .Build(); +} + TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const TExprNode::TPtr& keyExtractor, const TExprNode::TPtr& pickleTypeNode) { + bool combineOnly = Suffix == "Combine" || Suffix == "CombineState"; + const auto& columnNames = combineOnly ? FinalColumnNames : InitialColumnNames; + auto initLambdaIndex = (Suffix == "CombineState") ? 4 : 1; + auto updateLambdaIndex = (Suffix == "CombineState") ? 5 : 2; + auto combineInit = Ctx.Builder(Node->Pos()) .Lambda() .Param("key") @@ -352,28 +460,22 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const ui32 ndx = 0; for (ui32 i: NonDistinctColumns) { auto trait = Traits[i]; - auto initLambda = trait->Child(1); + auto initLambda = trait->Child(initLambdaIndex); if (initLambda->Head().ChildrenSize() == 1) { parent.List(ndx++) - .Add(0, InitialColumnNames[i]) + .Add(0, columnNames[i]) .Apply(1, *initLambda) .With(0) - .Callable("CastStruct") - .Arg(0, "item") - .Add(1, ExpandType(Node->Pos(), *initLambda->Head().Head().GetTypeAnn(), Ctx)) - .Seal() + .Do(GetPartialAggArgExtractor(i, false)) .Done() .Seal() .Seal(); } else { parent.List(ndx++) - .Add(0, InitialColumnNames[i]) + .Add(0, columnNames[i]) .Apply(1, *initLambda) .With(0) - .Callable("CastStruct") - .Arg(0, "item") - .Add(1, ExpandType(Node->Pos(), *initLambda->Head().Head().GetTypeAnn(), Ctx)) - .Seal() + .Do(GetPartialAggArgExtractor(i, false)) .Done() .With(1) .Callable("Uint32") @@ -400,39 +502,33 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const ui32 ndx = 0; for (ui32 i: NonDistinctColumns) { auto trait = Traits[i]; - auto updateLambda = trait->Child(2); + auto updateLambda = trait->Child(updateLambdaIndex); if (updateLambda->Head().ChildrenSize() == 2) { parent.List(ndx++) - .Add(0, InitialColumnNames[i]) + .Add(0, columnNames[i]) .Apply(1, *updateLambda) .With(0) - .Callable("CastStruct") - .Arg(0, "item") - .Add(1, ExpandType(Node->Pos(), *updateLambda->Head().Head().GetTypeAnn(), Ctx)) - .Seal() + .Do(GetPartialAggArgExtractor(i, true)) .Done() .With(1) - .Callable("Member") - .Arg(0, "state") - .Add(1, InitialColumnNames[i]) + .Callable("Member")
+ .Arg(0, "state")
+ .Add(1, columnNames[i])
.Seal() .Done() .Seal() .Seal(); } else { parent.List(ndx++) - .Add(0, InitialColumnNames[i]) + .Add(0, columnNames[i]) .Apply(1, *updateLambda) .With(0) - .Callable("CastStruct") - .Arg(0, "item") - .Add(1, ExpandType(Node->Pos(), *updateLambda->Head().Head().GetTypeAnn(), Ctx)) - .Seal() + .Do(GetPartialAggArgExtractor(i, true)) .Done() .With(1) - .Callable("Member") - .Arg(0, "state") - .Add(1, InitialColumnNames[i]) + .Callable("Member")
+ .Arg(0, "state")
+ .Add(1, columnNames[i])
.Seal() .Done() .With(2) @@ -457,10 +553,10 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const .Callable("Just") .Callable(0, "AsStruct") .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { - for (ui32 i = 0; i < InitialColumnNames.size(); ++i) { + for (ui32 i = 0; i < columnNames.size(); ++i) { if (NonDistinctColumns.find(i) == NonDistinctColumns.end()) { parent.List(i) - .Add(0, InitialColumnNames[i]) + .Add(0, columnNames[i]) .Add(1, NothingStates[i]) .Seal(); } else { @@ -468,13 +564,13 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const auto saveLambda = trait->Child(3); if (!DistinctFields.empty()) { parent.List(i) - .Add(0, InitialColumnNames[i]) + .Add(0, columnNames[i]) .Callable(1, "Just") .Apply(0, *saveLambda) .With(0) .Callable("Member") .Arg(0, "state") - .Add(1, InitialColumnNames[i]) + .Add(1, columnNames[i]) .Seal() .Done() .Seal() @@ -482,12 +578,12 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const .Seal(); } else { parent.List(i) - .Add(0, InitialColumnNames[i]) + .Add(0, columnNames[i]) .Apply(1, *saveLambda) .With(0) .Callable("Member") .Arg(0, "state") - .Add(1, InitialColumnNames[i]) + .Add(1, columnNames[i]) .Seal() .Done() .Seal() @@ -500,7 +596,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { ui32 pos = 0; for (ui32 i = 0; i < KeyColumns->ChildrenSize(); ++i) { - auto listBuilder = parent.List(InitialColumnNames.size() + i); + auto listBuilder = parent.List(columnNames.size() + i); listBuilder.Add(0, KeyColumns->ChildPtr(i)); if (KeyColumns->ChildrenSize() > 1) { if (pickleTypeNode) { @@ -1006,7 +1102,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregate(const TExprNode::TPtr& .Seal() .Seal().Build(); - if (KeyColumns->ChildrenSize() == 0 && !HaveSessionSetting) { + if (KeyColumns->ChildrenSize() == 0 && !HaveSessionSetting && (Suffix == "" || Suffix.EndsWith("Finalize"))) { return MakeSingleGroupRow(*Node, postAgg, Ctx); } @@ -1086,6 +1182,9 @@ TExprNode::TPtr TAggregateExpander::GenerateCondenseSwitch(const TExprNode::TPtr TExprNode::TPtr TAggregateExpander::GeneratePostAggregateInitPhase() { + bool aggregateOnly = (Suffix != ""); + const auto& columnNames = aggregateOnly ? FinalColumnNames : InitialColumnNames; + ui32 index = 0U; return Ctx.Builder(Node->Pos()) .Lambda() @@ -1115,31 +1214,30 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateInitPhase() return parent; }) .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { - for (ui32 i = 0; i < InitialColumnNames.size(); ++i) { + for (ui32 i = 0; i < columnNames.size(); ++i) { auto child = AggregatedColumns->Child(i); auto trait = Traits[i]; if (!EffectiveCompact) { auto loadLambda = trait->Child(4); + auto extractorLambda = GetFinalAggStateExtractor(i); - if (!DistinctFields.empty()) { + if (!DistinctFields.empty() || Suffix == "MergeManyFinalize") { parent.List(index++) - .Add(0, InitialColumnNames[i]) + .Add(0, columnNames[i]) .Callable(1, "Map") - .Callable(0, "Member") - .Arg(0, "item") - .Add(1, InitialColumnNames[i]) + .Apply(0, *extractorLambda) + .With(0, "item") .Seal() .Add(1, loadLambda) .Seal() .Seal(); } else { parent.List(index++) - .Add(0, InitialColumnNames[i]) + .Add(0, columnNames[i]) .Apply(1, *loadLambda) .With(0) - .Callable("Member") - .Arg(0, "item") - .Add(1, InitialColumnNames[i]) + .Apply(*extractorLambda) + .With(0, "item") .Seal() .Done() .Seal(); @@ -1188,7 +1286,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateInitPhase() const bool isFirst = *Distinct2Columns[distinctField->Content()].begin() == i; if (isFirst) { parent.List(index++) - .Add(0, InitialColumnNames[i]) + .Add(0, columnNames[i]) .List(1) .Callable(0, "NamedApply") .Add(0, UdfSetCreate[distinctField->Content()]) @@ -1226,13 +1324,13 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateInitPhase() .Seal(); } else { parent.List(index++) - .Add(0, InitialColumnNames[i]) + .Add(0, columnNames[i]) .Do(initApply) .Seal(); } } else { parent.List(index++) - .Add(0, InitialColumnNames[i]) + .Add(0, columnNames[i]) .Do(initApply) .Seal(); } @@ -1247,6 +1345,9 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateInitPhase() TExprNode::TPtr TAggregateExpander::GeneratePostAggregateSavePhase() { + bool aggregateOnly = (Suffix != ""); + const auto& columnNames = aggregateOnly ? FinalColumnNames : InitialColumnNames; + ui32 index = 0U; return Ctx.Builder(Node->Pos()) .Lambda() @@ -1280,12 +1381,12 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateSavePhase() return parent; }) .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { - for (ui32 i = 0; i < InitialColumnNames.size(); ++i) { + for (ui32 i = 0; i < columnNames.size(); ++i) { auto child = AggregatedColumns->Child(i); auto trait = Traits[i]; - auto finishLambda = trait->Child(6); + auto finishLambda = (Suffix == "MergeState") ? trait->Child(3) : trait->Child(6); - if (!EffectiveCompact && !DistinctFields.empty()) { + if (!EffectiveCompact && (!DistinctFields.empty() || Suffix == "MergeManyFinalize")) { if (child->Head().IsAtom()) { parent.List(index++) .Add(0, FinalColumnNames[i]) @@ -1293,7 +1394,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateSavePhase() .Callable(0, "Map") .Callable(0, "Member") .Arg(0, "state") - .Add(1, InitialColumnNames[i]) + .Add(1, columnNames[i]) .Seal() .Add(1, finishLambda) .Seal() @@ -1309,7 +1410,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateSavePhase() .Callable(0, "Map") .Callable(0, "Member") .Arg(0, "state") - .Add(1, InitialColumnNames[i]) + .Add(1, columnNames[i]) .Seal() .Add(1, finishLambda) .Seal() @@ -1327,14 +1428,14 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateSavePhase() parent.Callable("Nth") .Callable(0, "Member") .Arg(0, "state") - .Add(1, InitialColumnNames[i]) + .Add(1, columnNames[i]) .Seal() .Atom(1, "1", TNodeFlags::Default) .Seal(); } else { parent.Callable("Member") .Arg(0, "state") - .Add(1, InitialColumnNames[i]) + .Add(1, columnNames[i]) .Seal(); } @@ -1377,6 +1478,9 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateSavePhase() TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase() { + bool aggregateOnly = (Suffix != ""); + const auto& columnNames = aggregateOnly ? FinalColumnNames : InitialColumnNames; + ui32 index = 0U; return Ctx.Builder(Node->Pos()) .Lambda() @@ -1407,41 +1511,40 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase() return parent; }) .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { - for (ui32 i = 0; i < InitialColumnNames.size(); ++i) { + for (ui32 i = 0; i < columnNames.size(); ++i) { auto child = AggregatedColumns->Child(i); auto trait = Traits[i]; if (!EffectiveCompact) { auto loadLambda = trait->Child(4); auto mergeLambda = trait->Child(5); + auto extractorLambda = GetFinalAggStateExtractor(i); - if (!DistinctFields.empty()) { + if (!DistinctFields.empty() || Suffix == "MergeManyFinalize") { parent.List(index++) - .Add(0, InitialColumnNames[i]) + .Add(0, columnNames[i]) .Callable(1, "OptionalReduce") .Callable(0, "Map") - .Callable(0, "Member") - .Arg(0, "item") - .Add(1, InitialColumnNames[i]) + .Apply(0, extractorLambda) + .With(0, "item") .Seal() .Add(1, loadLambda) .Seal() .Callable(1, "Member") .Arg(0, "state") - .Add(1, InitialColumnNames[i]) + .Add(1, columnNames[i]) .Seal() .Add(2, mergeLambda) .Seal() .Seal(); } else { parent.List(index++) - .Add(0, InitialColumnNames[i]) + .Add(0, columnNames[i]) .Apply(1, *mergeLambda) .With(0) .Apply(*loadLambda) .With(0) - .Callable("Member") - .Arg(0, "item") - .Add(1, InitialColumnNames[i]) + .Apply(extractorLambda) + .With(0, "item") .Seal() .Done() .Seal() @@ -1449,7 +1552,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase() .With(1) .Callable("Member") .Arg(0, "state") - .Add(1, InitialColumnNames[i]) + .Add(1, columnNames[i]) .Seal() .Done() .Seal() @@ -1486,14 +1589,14 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase() parent.Callable("Nth") .Callable(0, "Member") .Arg(0, "state") - .Add(1, InitialColumnNames[i]) + .Add(1, columnNames[i]) .Seal() .Atom(1, "1", TNodeFlags::Default) .Seal(); } else { parent.Callable("Member") .Arg(0, "state") - .Add(1, InitialColumnNames[i]) + .Add(1, columnNames[i]) .Seal(); } @@ -1527,7 +1630,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase() .Callable(0, "Nth") .Callable(0, "Member") .Arg(0, "state") - .Add(1, InitialColumnNames[distinctIndex]) + .Add(1, columnNames[distinctIndex]) .Seal() .Atom(1, "0", TNodeFlags::Default) .Seal() @@ -1556,7 +1659,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase() }; parent.List(index++) - .Add(0, InitialColumnNames[i]) + .Add(0, columnNames[i]) .Callable(1, "If") .Callable(0, "NamedApply") .Add(0, UdfWasChanged[distinctField->Content()]) @@ -1567,7 +1670,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase() .Callable(0, "Nth") .Callable(0, "Member") .Arg(0, "state") - .Add(1, InitialColumnNames[distinctIndex]) + .Add(1, columnNames[distinctIndex]) .Seal() .Atom(1, "0", TNodeFlags::Default) .Seal() @@ -1608,13 +1711,13 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase() }) .Callable(2, "Member") .Arg(0, "state") - .Add(1, InitialColumnNames[i]) + .Add(1, columnNames[i]) .Seal() .Seal() .Seal(); } else { parent.List(index++) - .Add(0, InitialColumnNames[i]) + .Add(0, columnNames[i]) .Do(updateApply) .Seal(); } @@ -1627,4 +1730,4 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase() .Build(); } -} // namespace NYql
\ No newline at end of file +} // namespace NYql diff --git a/ydb/library/yql/core/yql_aggregate_expander.h b/ydb/library/yql/core/yql_aggregate_expander.h index 552bdee2229..c39481012dc 100644 --- a/ydb/library/yql/core/yql_aggregate_expander.h +++ b/ydb/library/yql/core/yql_aggregate_expander.h @@ -64,6 +64,9 @@ private: TExprNode::TPtr GeneratePostAggregateSavePhase(); TExprNode::TPtr GeneratePostAggregateMergePhase(); + std::function<TExprNodeBuilder& (TExprNodeBuilder&)> GetPartialAggArgExtractor(ui32 i, bool deserialize); + TExprNode::TPtr GetFinalAggStateExtractor(ui32 i); + private: static constexpr TStringBuf SessionStartMemberName = "_yql_group_session_start"; @@ -73,6 +76,7 @@ private: bool AllowPickle; bool ForceCompact; bool CompactForDistinct; + TStringBuf Suffix; TSessionWindowParams SessionWindowParams; TExprNode::TPtr AggList; @@ -110,4 +114,4 @@ inline TExprNode::TPtr ExpandAggregatePeephole(const TExprNode::TPtr& node, TExp return aggExpander.ExpandAggregate(); } -}
\ No newline at end of file +} diff --git a/ydb/library/yql/core/yql_expr_constraint.cpp b/ydb/library/yql/core/yql_expr_constraint.cpp index 48e8920af51..d995acfd799 100644 --- a/ydb/library/yql/core/yql_expr_constraint.cpp +++ b/ydb/library/yql/core/yql_expr_constraint.cpp @@ -211,6 +211,12 @@ public: Functions["WideCombiner"] = &TCallableConstraintTransformer::InheriteEmptyFromInput; Functions["WideCondense1"] = &TCallableConstraintTransformer::WideCondense1Wrap; Functions["Aggregate"] = &TCallableConstraintTransformer::AggregateWrap; + Functions["AggregateCombine"] = &TCallableConstraintTransformer::AggregateWrap; + Functions["AggregateCombineState"] = &TCallableConstraintTransformer::AggregateWrap; + Functions["AggregateMergeState"] = &TCallableConstraintTransformer::AggregateWrap; + Functions["AggregateMergeFinalize"] = &TCallableConstraintTransformer::AggregateWrap; + Functions["AggregateMergeManyFinalize"] = &TCallableConstraintTransformer::AggregateWrap; + Functions["AggregateFinalize"] = &TCallableConstraintTransformer::AggregateWrap; Functions["Fold"] = &TCallableConstraintTransformer::FoldWrap; Functions["Fold1"] = &TCallableConstraintTransformer::FoldWrap; Functions["WithContext"] = &TCallableConstraintTransformer::CopyAllFrom<0>; diff --git a/ydb/library/yql/sql/v1/aggregation.cpp b/ydb/library/yql/sql/v1/aggregation.cpp index b9157664b01..c141c859c17 100644 --- a/ydb/library/yql/sql/v1/aggregation.cpp +++ b/ydb/library/yql/sql/v1/aggregation.cpp @@ -114,23 +114,29 @@ protected: return Factory; } - TNodePtr GetExtractor() const override { - return BuildLambda(Pos, Y("row"), Y("PersistableRepr", Expr)); + TNodePtr GetExtractor(bool many, TContext& ctx) const override { + Y_UNUSED(ctx); + return BuildLambda(Pos, Y("row"), Y("PersistableRepr", many ? Y("Unwrap", Expr) : Expr)); } - TNodePtr GetApply(const TNodePtr& type) const override { + TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const override { + auto extractor = GetExtractor(many, ctx); + if (!extractor) { + return nullptr; + } + if (!Multi) { if (!DynamicFactory && !AggApplyName.empty()) { - return Y("AggApply", Q(AggApplyName), Y("ListItemType", type), BuildLambda(Pos, Y("row"), Y("PersistableRepr", Expr))); + return Y("AggApply", Q(AggApplyName), Y("ListItemType", type), extractor); } return Y("Apply", Factory, (DynamicFactory ? Y("ListItemType", type) : type), - BuildLambda(Pos, Y("row"), Y("PersistableRepr", Expr))); + extractor); } return Y("MultiAggregate", Y("ListItemType", type), - GetExtractor(), + extractor, Factory); } @@ -288,12 +294,16 @@ private: return new TKeyPayloadAggregationFactory(Pos, Name, Func, AggMode); } - TNodePtr GetExtractor() const final { - return BuildLambda(Pos, Y("row"), Payload); + TNodePtr GetExtractor(bool many, TContext& ctx) const final { + Y_UNUSED(ctx); + return BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Payload) : Payload); } - TNodePtr GetApply(const TNodePtr& type) const final { - auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Key), BuildLambda(Pos, Y("row"), Payload)); + TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final { + Y_UNUSED(ctx); + auto apply = Y("Apply", Factory, type, + BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Key) : Key), + BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Payload) : Payload)); AddFactoryArguments(apply); return apply; } @@ -384,12 +394,16 @@ private: return new TPayloadPredicateAggregationFactory(Pos, Name, Func, AggMode); } - TNodePtr GetExtractor() const final { - return BuildLambda(Pos, Y("row"), Payload); + TNodePtr GetExtractor(bool many, TContext& ctx) const final { + Y_UNUSED(ctx); + return BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Payload) : Payload); } - TNodePtr GetApply(const TNodePtr& type) const final { - return Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Payload), BuildLambda(Pos, Y("row"), Predicate)); + TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final { + Y_UNUSED(ctx); + return Y("Apply", Factory, type, + BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Payload) : Payload), + BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Predicate) : Predicate)); } std::vector<ui32> GetFactoryColumnIndices() const final { @@ -466,13 +480,15 @@ private: return new TTwoArgsAggregationFactory(Pos, Name, Func, AggMode); } - TNodePtr GetExtractor() const final { - return BuildLambda(Pos, Y("row"), One); + TNodePtr GetExtractor(bool many, TContext& ctx) const final { + Y_UNUSED(ctx); + return BuildLambda(Pos, Y("row"), many ? Y("Unwrap", One) : One); } - TNodePtr GetApply(const TNodePtr& type) const final { + TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final { + Y_UNUSED(ctx); auto tuple = Q(Y(One, Two)); - return Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), tuple)); + return Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), many ? Y("Unwrap", tuple) : tuple)); } bool DoInit(TContext& ctx, ISource* src) final { @@ -563,8 +579,11 @@ private: return new THistogramAggregationFactory(Pos, Name, Func, AggMode); } - TNodePtr GetApply(const TNodePtr& type) const final { - auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Expr), BuildLambda(Pos, Y("row"), Weight)); + TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final { + Y_UNUSED(ctx); + auto apply = Y("Apply", Factory, type, + BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Expr) : Expr), + BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Weight) : Weight)); AddFactoryArguments(apply); return apply; } @@ -639,9 +658,10 @@ private: return new TLinearHistogramAggregationFactory(Pos, Name, Func, AggMode); } - TNodePtr GetApply(const TNodePtr& type) const final { + TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final { + Y_UNUSED(ctx); return Y("Apply", Factory, type, - BuildLambda(Pos, Y("row"), Expr), + BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Expr) : Expr), BinSize, Minimum, Maximum); } @@ -734,7 +754,8 @@ private: return new TPercentileFactory(Pos, Name, Func, AggMode); } - TNodePtr GetApply(const TNodePtr& type) const final { + TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final { + Y_UNUSED(ctx); TNodePtr percentiles(Percentiles.cbegin()->second); if (Percentiles.size() > 1U) { @@ -745,16 +766,16 @@ private: percentiles = Q(percentiles); } - return Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Expr), percentiles); + return Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Expr) : Expr), percentiles); } void AddFactoryArguments(TNodePtr& apply) const final { apply = L(apply, FactoryPercentile); } - TNodePtr AggregationTraits(const TNodePtr& type, bool overState) const final { + std::pair<TNodePtr, bool> AggregationTraits(const TNodePtr& type, bool overState, bool many, TContext& ctx) const final { if (Percentiles.empty()) - return TNodePtr(); + return { TNodePtr(), true }; TNodePtr names(Q(Percentiles.cbegin()->first)); @@ -767,9 +788,19 @@ private: const bool distinct = AggMode == EAggregateMode::Distinct; const auto listType = distinct ? Y("ListType", Y("StructMemberType", Y("ListItemType", type), BuildQuotedAtom(Pos, DistinctKey))) : type; - return distinct ? - Q(Y(names, WrapIfOverState(GetApply(listType), overState), BuildQuotedAtom(Pos, DistinctKey))) : - Q(Y(names, WrapIfOverState(GetApply(listType), overState))); + auto apply = GetApply(listType, many, ctx); + if (!apply) { + return { TNodePtr(), false }; + } + + auto wrapped = WrapIfOverState(apply, overState, many, ctx); + if (!wrapped) { + return { TNodePtr(), false }; + } + + return { distinct ? + Q(Y(names, wrapped, BuildQuotedAtom(Pos, DistinctKey))) : + Q(Y(names, wrapped)), true }; } bool DoInit(TContext& ctx, ISource* src) final { @@ -857,7 +888,8 @@ private: return new TTopFreqFactory(Pos, Name, Func, AggMode); } - TNodePtr GetApply(const TNodePtr& type) const final { + TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final { + Y_UNUSED(ctx); TPair topFreqs(TopFreqs.cbegin()->second); if (TopFreqs.size() > 1U) { @@ -868,7 +900,7 @@ private: topFreqs = { Q(topFreqs.first), Q(topFreqs.second) }; } - auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Expr), topFreqs.first, topFreqs.second); + auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Expr) : Expr), topFreqs.first, topFreqs.second); return apply; } @@ -876,9 +908,9 @@ private: apply = L(apply, TopFreqFactoryParams.first, TopFreqFactoryParams.second); } - TNodePtr AggregationTraits(const TNodePtr& type, bool overState) const final { + std::pair<TNodePtr, bool> AggregationTraits(const TNodePtr& type, bool overState, bool many, TContext& ctx) const final { if (TopFreqs.empty()) - return TNodePtr(); + return { TNodePtr(), true }; TNodePtr names(Q(TopFreqs.cbegin()->first)); @@ -891,9 +923,19 @@ private: const bool distinct = AggMode == EAggregateMode::Distinct; const auto listType = distinct ? Y("ListType", Y("StructMemberType", Y("ListItemType", type), BuildQuotedAtom(Pos, DistinctKey))) : type; - return distinct ? - Q(Y(names, WrapIfOverState(GetApply(listType), overState), BuildQuotedAtom(Pos, DistinctKey))) : - Q(Y(names, WrapIfOverState(GetApply(listType), overState))); + auto apply = GetApply(listType, many, ctx); + if (!apply) { + return { nullptr, false }; + } + + auto wrapped = WrapIfOverState(apply, overState, many, ctx); + if (!wrapped) { + return { nullptr, false }; + } + + return { distinct ? + Q(Y(names, wrapped, BuildQuotedAtom(Pos, DistinctKey))) : + Q(Y(names, wrapped)), true }; } bool DoInit(TContext& ctx, ISource* src) final { @@ -971,12 +1013,15 @@ private: return new TTopAggregationFactory(Pos, Name, Func, AggMode); } - TNodePtr GetApply(const TNodePtr& type) const final { + TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final { + Y_UNUSED(ctx); TNodePtr apply; if (HasKey) { - apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Key), BuildLambda(Pos, Y("row"), Payload)); + apply = Y("Apply", Factory, type, + BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Key) : Key), + BuildLambda(Pos, Y("row"), many ? Y("Payload", Payload) : Payload)); } else { - apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Payload)); + apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Payload) : Payload)); } AddFactoryArguments(apply); return apply; @@ -1083,8 +1128,9 @@ private: return new TCountDistinctEstimateAggregationFactory(Pos, Name, Func, AggMode); } - TNodePtr GetApply(const TNodePtr& type) const final { - auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Expr)); + TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final { + Y_UNUSED(ctx); + auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Expr) : Expr)); AddFactoryArguments(apply); return apply; } @@ -1156,8 +1202,9 @@ private: return new TListAggregationFactory(Pos, Name, Func, AggMode); } - TNodePtr GetApply(const TNodePtr& type) const final { - auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Expr)); + TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final { + Y_UNUSED(ctx); + auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Expr) : Expr)); AddFactoryArguments(apply); return apply; } @@ -1213,8 +1260,9 @@ private: return new TUserDefinedAggregationFactory(Pos, Name, Func, AggMode); } - TNodePtr GetApply(const TNodePtr& type) const final { - auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Expr)); + TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final { + Y_UNUSED(ctx); + auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Expr) : Expr)); AddFactoryArguments(apply); return apply; } @@ -1297,7 +1345,15 @@ public: return ret; } - TNodePtr GetApply(const TNodePtr& type) const final { + TNodePtr GetExtractor(bool many, TContext& ctx) const override { + Y_UNUSED(many); + ctx.Error() << "Partial aggregation by PostgreSQL function isn't supported"; + return nullptr; + } + + TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final { + Y_UNUSED(many); + Y_UNUSED(ctx); return Y(AggMode == EAggregateMode::OverWindow ? "PgWindowTraits" : "PgAggregationTraits", Q(PgFunc), Y("ListItemType", type), Lambda); } diff --git a/ydb/library/yql/sql/v1/insert.cpp b/ydb/library/yql/sql/v1/insert.cpp index 4f4c8630181..0620eb4eab4 100644 --- a/ydb/library/yql/sql/v1/insert.cpp +++ b/ydb/library/yql/sql/v1/insert.cpp @@ -51,9 +51,10 @@ public: return nullptr; } - TNodePtr BuildAggregation(const TString& label) override { + std::pair<TNodePtr, bool> BuildAggregation(const TString& label, TContext& ctx) override { Y_UNUSED(label); - return nullptr; + Y_UNUSED(ctx); + return { nullptr, true }; } protected: diff --git a/ydb/library/yql/sql/v1/node.cpp b/ydb/library/yql/sql/v1/node.cpp index 9ab50224c20..8a1e56b4c94 100644 --- a/ydb/library/yql/sql/v1/node.cpp +++ b/ydb/library/yql/sql/v1/node.cpp @@ -1246,20 +1246,35 @@ TAstNode* IAggregation::Translate(TContext& ctx) const { return nullptr; } -TNodePtr IAggregation::AggregationTraits(const TNodePtr& type, bool overState) const { +std::pair<TNodePtr, bool> IAggregation::AggregationTraits(const TNodePtr& type, bool overState, bool many, TContext& ctx) const { const bool distinct = AggMode == EAggregateMode::Distinct; const auto listType = distinct ? Y("ListType", Y("StructMemberType", Y("ListItemType", type), BuildQuotedAtom(Pos, DistinctKey))) : type; - return distinct ? - Q(Y(Q(Name), WrapIfOverState(GetApply(listType), overState), BuildQuotedAtom(Pos, DistinctKey))) : - Q(Y(Q(Name), WrapIfOverState(GetApply(listType), overState))); + auto apply = GetApply(listType, many, ctx); + if (!apply) { + return { nullptr, false }; + } + + auto wrapped = WrapIfOverState(apply, overState, many, ctx); + if (!wrapped) { + return { nullptr, false }; + } + + return { distinct ? + Q(Y(Q(Name), wrapped, BuildQuotedAtom(Pos, DistinctKey))) : + Q(Y(Q(Name), wrapped)), true }; } -TNodePtr IAggregation::WrapIfOverState(const TNodePtr& input, bool overState) const { +TNodePtr IAggregation::WrapIfOverState(const TNodePtr& input, bool overState, bool many, TContext& ctx) const { if (!overState) { return input; } - return Y("AggOverState", GetExtractor(), BuildLambda(Pos, Y(), input)); + auto extractor = GetExtractor(many, ctx); + if (!extractor) { + return nullptr; + } + + return Y(ToString("AggOverState"), extractor, BuildLambda(Pos, Y(), input)); } void IAggregation::AddFactoryArguments(TNodePtr& apply) const { @@ -1270,9 +1285,9 @@ std::vector<ui32> IAggregation::GetFactoryColumnIndices() const { return {0u}; } -TNodePtr IAggregation::WindowTraits(const TNodePtr& type) const { +TNodePtr IAggregation::WindowTraits(const TNodePtr& type, TContext& ctx) const { YQL_ENSURE(AggMode == EAggregateMode::OverWindow, "Windows traits is unavailable"); - return Q(Y(Q(Name), GetApply(type))); + return Q(Y(Q(Name), GetApply(type, false, ctx))); } ISource::ISource(TPosition pos) @@ -1759,9 +1774,9 @@ bool ISource::SetSamplingRate(TContext& ctx, TNodePtr samplingRate) { return true; } -TNodePtr ISource::BuildAggregation(const TString& label) { +std::pair<TNodePtr, bool> ISource::BuildAggregation(const TString& label, TContext& ctx) { if (GroupKeys.empty() && Aggregations.empty() && !IsCompositeSource() && !HoppingWindowSpec) { - return nullptr; + return { nullptr, true }; } auto keysTuple = Y(); @@ -1786,10 +1801,16 @@ TNodePtr ISource::BuildAggregation(const TString& label) { const auto listType = Y("TypeOf", label); auto aggrArgs = Y(); - const bool overState = GroupBySuffix == "CombineState" || GroupBySuffix == "MergeState" || GroupBySuffix == "MergeFinalize"; + const bool overState = GroupBySuffix == "CombineState" || GroupBySuffix == "MergeState" || + GroupBySuffix == "MergeFinalize" || GroupBySuffix == "MergeManyFinalize"; for (const auto& aggr: Aggregations) { - if (const auto traits = aggr->AggregationTraits(listType, overState)) { - aggrArgs = L(aggrArgs, traits); + auto res = aggr->AggregationTraits(listType, overState, GroupBySuffix == "MergeManyFinalize", ctx); + if (!res.second) { + return { nullptr, false }; + } + + if (res.first) { + aggrArgs = L(aggrArgs, res.first); } } @@ -1819,7 +1840,7 @@ TNodePtr ISource::BuildAggregation(const TString& label) { Q(Y(BuildQuotedAtom(Pos, SessionWindow->GetLabel()), sessionWindow->BuildTraits(label)))))); } - return Y("AssumeColumnOrderPartial", Y("Aggregate" + GroupBySuffix, label, Q(keysTuple), Q(aggrArgs), Q(options)), Q(keysTuple)); + return { Y("AssumeColumnOrderPartial", Y("Aggregate" + GroupBySuffix, label, Q(keysTuple), Q(aggrArgs), Q(options)), Q(keysTuple)), true }; } TMaybe<TString> ISource::FindColumnMistype(const TString& name) const { @@ -1990,7 +2011,7 @@ TNodePtr ISource::BuildCalcOverWindow(TContext& ctx, const TString& label) { YQL_ENSURE(frameType); auto callOnFrame = Y(frameType, BuildWindowFrame(*spec->Frame, spec->IsCompact)); for (auto& agg : aggs) { - auto winTraits = agg->WindowTraits(listType); + auto winTraits = agg->WindowTraits(listType, ctx); callOnFrame = L(callOnFrame, winTraits); } for (auto& func : funcs) { diff --git a/ydb/library/yql/sql/v1/node.h b/ydb/library/yql/sql/v1/node.h index 2d779dfc90c..6cdc0098ecc 100644 --- a/ydb/library/yql/sql/v1/node.h +++ b/ydb/library/yql/sql/v1/node.h @@ -754,7 +754,7 @@ namespace NSQLTranslationV1 { virtual bool InitAggr(TContext& ctx, bool isFactory, ISource* src, TAstListNode& node, const TVector<TNodePtr>& exprs) = 0; - virtual TNodePtr AggregationTraits(const TNodePtr& type, bool overState) const; + virtual std::pair<TNodePtr, bool> AggregationTraits(const TNodePtr& type, bool overState, bool many, TContext& ctx) const; virtual TNodePtr AggregationTraitsFactory() const = 0; @@ -762,7 +762,7 @@ namespace NSQLTranslationV1 { virtual void AddFactoryArguments(TNodePtr& apply) const; - virtual TNodePtr WindowTraits(const TNodePtr& type) const; + virtual TNodePtr WindowTraits(const TNodePtr& type, TContext& ctx) const; const TString& GetName() const; @@ -772,13 +772,13 @@ namespace NSQLTranslationV1 { virtual void Join(IAggregation* aggr); private: - virtual TNodePtr GetApply(const TNodePtr& type) const = 0; + virtual TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const = 0; protected: IAggregation(TPosition pos, const TString& name, const TString& func, EAggregateMode mode); TAstNode* Translate(TContext& ctx) const override; - TNodePtr WrapIfOverState(const TNodePtr& input, bool overState) const; - virtual TNodePtr GetExtractor() const = 0; + TNodePtr WrapIfOverState(const TNodePtr& input, bool overState, bool many, TContext& ctx) const; + virtual TNodePtr GetExtractor(bool many, TContext& ctx) const = 0; TString Name; TString Func; @@ -860,7 +860,7 @@ namespace NSQLTranslationV1 { virtual TNodePtr BuildPreaggregatedMap(TContext& ctx); virtual TNodePtr BuildPreFlattenMap(TContext& ctx); virtual TNodePtr BuildPrewindowMap(TContext& ctx); - virtual TNodePtr BuildAggregation(const TString& label); + virtual std::pair<TNodePtr, bool> BuildAggregation(const TString& label, TContext& ctx); virtual TNodePtr BuildCalcOverWindow(TContext& ctx, const TString& label); virtual TNodePtr BuildSort(TContext& ctx, const TString& label); virtual TNodePtr BuildCleanupColumns(TContext& ctx, const TString& label); diff --git a/ydb/library/yql/sql/v1/select.cpp b/ydb/library/yql/sql/v1/select.cpp index b9c3d448a3a..ac37571185c 100644 --- a/ydb/library/yql/sql/v1/select.cpp +++ b/ydb/library/yql/sql/v1/select.cpp @@ -242,9 +242,10 @@ public: return nullptr; } - TNodePtr BuildAggregation(const TString& label) override { + std::pair<TNodePtr, bool> BuildAggregation(const TString& label, TContext& ctx) override { Y_UNUSED(label); - return nullptr; + Y_UNUSED(ctx); + return { nullptr, true }; } TPtr DoClone() const final { @@ -1552,7 +1553,12 @@ public: } src->FinishColumns(); - Aggregate = src->BuildAggregation("core"); + auto aggRes = src->BuildAggregation("core", ctx); + if (!aggRes.second) { + return false; + } + + Aggregate = aggRes.first; if (src->IsFlattenByColumns() || src->IsFlattenColumns()) { Flatten = src->IsFlattenByColumns() ? src->BuildFlattenByColumns("row") : diff --git a/ydb/library/yql/sql/v1/sql.cpp b/ydb/library/yql/sql/v1/sql.cpp index 53f345695ca..1594a3095b9 100644 --- a/ydb/library/yql/sql/v1/sql.cpp +++ b/ydb/library/yql/sql/v1/sql.cpp @@ -7224,6 +7224,8 @@ bool TGroupByClause::Build(const TRule_group_by_clause& node, bool stream) { Suffix = "Finalize"; } else if (mode == "mergefinalize") { Suffix = "MergeFinalize"; + } else if (mode == "mergemanyfinalize") { + Suffix = "MergeManyFinalize"; } else { Ctx.Error() << "Unsupported group by mode: " << mode; Ctx.IncrementMonCounter("sql_errors", "GroupByModeUnknown"); |