aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author vvvv <vvvv@ydb.tech> 2022-10-10 23:20:20 +0300
committer vvvv <vvvv@ydb.tech> 2022-10-10 23:20:20 +0300
commit f970a39bdd247dd49a1204c59ba07966d2886905 (patch)
tree cd8af96e552a2c05626e67d510fe1c2aadc8db06
parent a6b8cecadad0531eaf559b57c797d7d71bae12d7 (diff)
download ydb-f970a39bdd247dd49a1204c59ba07966d2886905.tar.gz
implementation of Aggregate phases in peephole
-rw-r--r-- ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp 6
-rw-r--r-- ydb/library/yql/core/type_ann/type_ann_core.cpp 2
-rw-r--r-- ydb/library/yql/core/type_ann/type_ann_list.cpp 34
-rw-r--r-- ydb/library/yql/core/yql_aggregate_expander.cpp 259
-rw-r--r-- ydb/library/yql/core/yql_aggregate_expander.h 6
-rw-r--r-- ydb/library/yql/core/yql_expr_constraint.cpp 6
-rw-r--r-- ydb/library/yql/sql/v1/aggregation.cpp 148
-rw-r--r-- ydb/library/yql/sql/v1/insert.cpp 5
-rw-r--r-- ydb/library/yql/sql/v1/node.cpp 51
-rw-r--r-- ydb/library/yql/sql/v1/node.h 12
-rw-r--r-- ydb/library/yql/sql/v1/select.cpp 12
-rw-r--r-- ydb/library/yql/sql/v1/sql.cpp 2
12 files changed, 380 insertions, 163 deletions
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
index 5e2482d699c..40c096b7f32 100644
--- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
+++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
@@ -6121,6 +6121,12 @@ struct TPeepHoleRules {
static constexpr std::initializer_list<TExtPeepHoleOptimizerMap::value_type> CommonStageExtRulesInit = {
{"Aggregate", &ExpandAggregatePeephole},
+ {"AggregateCombine", &ExpandAggregatePeephole},
+ {"AggregateCombineState", &ExpandAggregatePeephole},
+ {"AggregateMergeState", &ExpandAggregatePeephole},
+ {"AggregateMergeFinalize", &ExpandAggregatePeephole},
+ {"AggregateMergeManyFinalize", &ExpandAggregatePeephole},
+ {"AggregateFinalize", &ExpandAggregatePeephole},
};
static constexpr std::initializer_list<TPeepHoleOptimizerMap::value_type> SimplifyStageRulesInit = {
diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp
index 3c458a0ff49..f6383d58a78 100644
--- a/ydb/library/yql/core/type_ann/type_ann_core.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp
@@ -11444,6 +11444,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
Functions["AggregateMergeState"] = &AggregateWrapper;
Functions["AggregateFinalize"] = &AggregateWrapper;
Functions["AggregateMergeFinalize"] = &AggregateWrapper;
+ Functions["AggregateMergeManyFinalize"] = &AggregateWrapper;
Functions["AggOverState"] = &AggOverStateWrapper;
Functions["SqlAggregateAll"] = &SqlAggregateAllWrapper;
Functions["CountedAggregateAll"] = &CountedAggregateAllWrapper;
@@ -11507,6 +11508,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
Functions["PgOr"] = &PgBoolOpWrapper;
Functions["PgNot"] = &PgBoolOpWrapper;
Functions["PgAggregationTraits"] = &PgAggregationTraitsWrapper;
+ Functions["PgAggregationTraitsOverState"] = &PgAggregationTraitsWrapper;
Functions["PgWindowTraits"] = &PgAggregationTraitsWrapper;
Functions["PgAggregationTraitsTuple"] = &PgAggregationTraitsWrapper;
Functions["PgWindowTraitsTuple"] = &PgAggregationTraitsWrapper;
diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp
index e5a4b3ba4bb..d12f33f801c 100644
--- a/ydb/library/yql/core/type_ann/type_ann_list.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp
@@ -4398,21 +4398,21 @@ namespace {
return IGraphTransformer::TStatus::Error;
}
- if (!overState) {
- if (lambdaUpdate->Head().ChildrenSize() == 2U) {
- if (!UpdateLambdaAllArgumentsTypes(lambdaUpdate, { itemType, combineStateType }, ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
- }
- } else {
- if (!UpdateLambdaAllArgumentsTypes(lambdaUpdate, { itemType, combineStateType, ui32Type }, ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
- }
+ if (lambdaUpdate->Head().ChildrenSize() == 2U) {
+ if (!UpdateLambdaAllArgumentsTypes(lambdaUpdate, { itemType, combineStateType }, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
}
-
- if (!lambdaUpdate->GetTypeAnn()) {
- return IGraphTransformer::TStatus::Repeat;
+ } else {
+ if (!UpdateLambdaAllArgumentsTypes(lambdaUpdate, { itemType, combineStateType, ui32Type }, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
}
+ }
+
+ if (!lambdaUpdate->GetTypeAnn()) {
+ return IGraphTransformer::TStatus::Repeat;
+ }
+ if (!overState) {
if (!IsSameAnnotation(*lambdaUpdate->GetTypeAnn(), *combineStateType)) {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaUpdate->Pos()), TStringBuilder() << "Mismatch update lambda result type, expected: "
<< *combineStateType << ", but got: " << *lambdaUpdate->GetTypeAnn()));
@@ -4774,6 +4774,11 @@ namespace {
const TTypeAnnotationNode* distinctColumnType = nullptr;
if (child->ChildrenSize() == 3) {
+ if (suffix != "" && suffix != "Finalize") {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(child->Pos()), TStringBuilder() << "DISTINCT aggregation is not supported for mode: " << suffix));
+ return IGraphTransformer::TStatus::Error;
+ }
+
if (!EnsureAtom(*child->Child(2), ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
@@ -5116,6 +5121,11 @@ namespace {
TStringBuilder() << "Unsupported column type: " << lambdaTypeSlot));
return IGraphTransformer::TStatus::Error;
}
+
+ if (isOptional) {
+ sumResultType = ctx.Expr.MakeType<TOptionalExprType>(sumResultType);
+ }
+
input->SetTypeAnn(sumResultType);
} else if (IsNull(*lambda->GetTypeAnn())) {
input->SetTypeAnn(ctx.Expr.MakeType<TNullExprType>());
diff --git a/ydb/library/yql/core/yql_aggregate_expander.cpp b/ydb/library/yql/core/yql_aggregate_expander.cpp
index 5fc67082331..afd95c720e2 100644
--- a/ydb/library/yql/core/yql_aggregate_expander.cpp
+++ b/ydb/library/yql/core/yql_aggregate_expander.cpp
@@ -1,4 +1,5 @@
#include "yql_aggregate_expander.h"
+#include "yql_aggregate_expander.h"
#include <ydb/library/yql/core/yql_expr_optimize.h>
#include <ydb/library/yql/core/yql_expr_type_annotation.h>
@@ -8,6 +9,8 @@ namespace NYql {
TExprNode::TPtr TAggregateExpander::ExpandAggregate()
{
+ Suffix = Node->Content();
+ YQL_ENSURE(Suffix.SkipPrefix("Aggregate"));
AggList = Node->HeadPtr();
KeyColumns = Node->ChildPtr(1);
AggregatedColumns = Node->Child(2);
@@ -29,6 +32,13 @@ TExprNode::TPtr TAggregateExpander::ExpandAggregate()
}
}
+ if (Suffix == "Finalize") {
+ EffectiveCompact = true;
+ Suffix = "";
+ } else if (Suffix != "") {
+ EffectiveCompact = false;
+ }
+
OriginalRowType = GetSeqItemType(Node->Head().GetTypeAnn())->Cast<TStructExprType>();
RowItems = OriginalRowType->GetItems();
@@ -39,11 +49,20 @@ TExprNode::TPtr TAggregateExpander::ExpandAggregate()
bool needPickle = IsNeedPickle(keyItemTypes);
auto keyExtractor = GetKeyExtractor(needPickle);
CollectColumnsSpecs();
-
+
+ if (Suffix == "MergeState" || Suffix == "MergeFinalize" || Suffix == "MergeManyFinalize") {
+ return GeneratePostAggregate(AggList, keyExtractor);
+ }
+
TExprNode::TPtr preAgg = GeneratePartialAggregate(keyExtractor, keyItemTypes, needPickle);
if (EffectiveCompact || !preAgg) {
preAgg = std::move(AggList);
}
+
+ if (Suffix == "Combine" || Suffix == "CombineState") {
+ return preAgg;
+ }
+
return GeneratePostAggregate(preAgg, keyExtractor);
}
@@ -76,7 +95,7 @@ bool TAggregateExpander::CollectTraits() {
bool allTraitsCollected = true;
for (ui32 index = 0; index < AggregatedColumns->ChildrenSize(); ++index) {
auto trait = AggregatedColumns->Child(index)->ChildPtr(1);
- if (trait->IsCallable("AggApply")) {
+ if (trait->IsCallable({ "AggApply", "AggApplyState" })) {
trait = ExpandAggApply(trait);
allTraitsCollected = false;
}
@@ -341,8 +360,97 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregate(const TExprNode::TP
return partialAgg;
}
+std::function<TExprNodeBuilder& (TExprNodeBuilder&)> TAggregateExpander::GetPartialAggArgExtractor(ui32 i, bool deserialize) {
+ return [&, i, deserialize](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
+ auto trait = Traits[i];
+ auto extractorLambda = trait->Child(1);
+ auto loadLambda = trait->Child(4);
+ if (Suffix == "CombineState") {
+ if (deserialize) {
+ parent.Apply(*loadLambda)
+ .With(0)
+ .Apply(*extractorLambda)
+ .With(0)
+ .Callable("CastStruct")
+ .Arg(0, "item")
+ .Add(1, ExpandType(Node->Pos(), *extractorLambda->Head().Head().GetTypeAnn(), Ctx))
+ .Seal()
+ .Done()
+ .Seal()
+ .Done()
+ .Seal();
+ } else {
+ parent.Apply(*extractorLambda)
+ .With(0)
+ .Callable("CastStruct")
+ .Arg(0, "item")
+ .Add(1, ExpandType(Node->Pos(), *extractorLambda->Head().Head().GetTypeAnn(), Ctx))
+ .Seal()
+ .Done()
+ .Seal();
+ }
+ } else {
+ parent.Callable("CastStruct")
+ .Arg(0, "item")
+ .Add(1, ExpandType(Node->Pos(), *extractorLambda->Head().Head().GetTypeAnn(), Ctx))
+ .Seal();
+ }
+
+ return parent;
+ };
+}
+
+TExprNode::TPtr TAggregateExpander::GetFinalAggStateExtractor(ui32 i) {
+ auto trait = Traits[i];
+ if (Suffix.StartsWith("Merge")) {
+ auto lambda = trait->ChildPtr(1);
+ if (!Suffix.StartsWith("MergeMany")) {
+ return lambda;
+ }
+
+ if (lambda->Tail().IsCallable("Unwrap")) {
+ return Ctx.Builder(Node->Pos())
+ .Lambda()
+ .Param("item")
+ .ApplyPartial(lambda->HeadPtr(), lambda->Tail().HeadPtr())
+ .With(0, "item")
+ .Seal()
+ .Seal()
+ .Build();
+ } else {
+ return Ctx.Builder(Node->Pos())
+ .Lambda()
+ .Param("item")
+ .Callable("Just")
+ .Apply(0, *lambda)
+ .With(0, "item")
+ .Seal()
+ .Seal()
+ .Seal()
+ .Build();
+ }
+ }
+
+ bool aggregateOnly = (Suffix != "");
+ const auto& columnNames = aggregateOnly ? FinalColumnNames : InitialColumnNames;
+ return Ctx.Builder(Node->Pos())
+ .Lambda()
+ .Param("item")
+ .Callable("Member")
+ .Arg(0, "item")
+ .Add(1, columnNames[i])
+ .Seal()
+ .Seal()
+ .Build();
+}
+
TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const TExprNode::TPtr& keyExtractor, const TExprNode::TPtr& pickleTypeNode)
{
+ bool combineOnly = Suffix == "Combine" || Suffix == "CombineState";
+ const auto& columnNames = combineOnly ? FinalColumnNames : InitialColumnNames;
+ auto initLambdaIndex = (Suffix == "CombineState") ? 4 : 1;
+ auto updateLambdaIndex = (Suffix == "CombineState") ? 5 : 2;
+
auto combineInit = Ctx.Builder(Node->Pos())
.Lambda()
.Param("key")
@@ -352,28 +460,22 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const
ui32 ndx = 0;
for (ui32 i: NonDistinctColumns) {
auto trait = Traits[i];
- auto initLambda = trait->Child(1);
+ auto initLambda = trait->Child(initLambdaIndex);
if (initLambda->Head().ChildrenSize() == 1) {
parent.List(ndx++)
- .Add(0, InitialColumnNames[i])
+ .Add(0, columnNames[i])
.Apply(1, *initLambda)
.With(0)
- .Callable("CastStruct")
- .Arg(0, "item")
- .Add(1, ExpandType(Node->Pos(), *initLambda->Head().Head().GetTypeAnn(), Ctx))
- .Seal()
+ .Do(GetPartialAggArgExtractor(i, false))
.Done()
.Seal()
.Seal();
} else {
parent.List(ndx++)
- .Add(0, InitialColumnNames[i])
+ .Add(0, columnNames[i])
.Apply(1, *initLambda)
.With(0)
- .Callable("CastStruct")
- .Arg(0, "item")
- .Add(1, ExpandType(Node->Pos(), *initLambda->Head().Head().GetTypeAnn(), Ctx))
- .Seal()
+ .Do(GetPartialAggArgExtractor(i, false))
.Done()
.With(1)
.Callable("Uint32")
@@ -400,39 +502,33 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const
ui32 ndx = 0;
for (ui32 i: NonDistinctColumns) {
auto trait = Traits[i];
- auto updateLambda = trait->Child(2);
+ auto updateLambda = trait->Child(updateLambdaIndex);
if (updateLambda->Head().ChildrenSize() == 2) {
parent.List(ndx++)
- .Add(0, InitialColumnNames[i])
+ .Add(0, columnNames[i])
.Apply(1, *updateLambda)
.With(0)
- .Callable("CastStruct")
- .Arg(0, "item")
- .Add(1, ExpandType(Node->Pos(), *updateLambda->Head().Head().GetTypeAnn(), Ctx))
- .Seal()
+ .Do(GetPartialAggArgExtractor(i, true))
.Done()
.With(1)
- .Callable("Member")
- .Arg(0, "state")
- .Add(1, InitialColumnNames[i])
+ .Callable("Member")
+ .Arg(0, "state")
+ .Add(1, columnNames[i])
.Seal()
.Done()
.Seal()
.Seal();
} else {
parent.List(ndx++)
- .Add(0, InitialColumnNames[i])
+ .Add(0, columnNames[i])
.Apply(1, *updateLambda)
.With(0)
- .Callable("CastStruct")
- .Arg(0, "item")
- .Add(1, ExpandType(Node->Pos(), *updateLambda->Head().Head().GetTypeAnn(), Ctx))
- .Seal()
+ .Do(GetPartialAggArgExtractor(i, true))
.Done()
.With(1)
- .Callable("Member")
- .Arg(0, "state")
- .Add(1, InitialColumnNames[i])
+ .Callable("Member")
+ .Arg(0, "state")
+ .Add(1, columnNames[i])
.Seal()
.Done()
.With(2)
@@ -457,10 +553,10 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const
.Callable("Just")
.Callable(0, "AsStruct")
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
- for (ui32 i = 0; i < InitialColumnNames.size(); ++i) {
+ for (ui32 i = 0; i < columnNames.size(); ++i) {
if (NonDistinctColumns.find(i) == NonDistinctColumns.end()) {
parent.List(i)
- .Add(0, InitialColumnNames[i])
+ .Add(0, columnNames[i])
.Add(1, NothingStates[i])
.Seal();
} else {
@@ -468,13 +564,13 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const
auto saveLambda = trait->Child(3);
if (!DistinctFields.empty()) {
parent.List(i)
- .Add(0, InitialColumnNames[i])
+ .Add(0, columnNames[i])
.Callable(1, "Just")
.Apply(0, *saveLambda)
.With(0)
.Callable("Member")
.Arg(0, "state")
- .Add(1, InitialColumnNames[i])
+ .Add(1, columnNames[i])
.Seal()
.Done()
.Seal()
@@ -482,12 +578,12 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const
.Seal();
} else {
parent.List(i)
- .Add(0, InitialColumnNames[i])
+ .Add(0, columnNames[i])
.Apply(1, *saveLambda)
.With(0)
.Callable("Member")
.Arg(0, "state")
- .Add(1, InitialColumnNames[i])
+ .Add(1, columnNames[i])
.Seal()
.Done()
.Seal()
@@ -500,7 +596,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
ui32 pos = 0;
for (ui32 i = 0; i < KeyColumns->ChildrenSize(); ++i) {
- auto listBuilder = parent.List(InitialColumnNames.size() + i);
+ auto listBuilder = parent.List(columnNames.size() + i);
listBuilder.Add(0, KeyColumns->ChildPtr(i));
if (KeyColumns->ChildrenSize() > 1) {
if (pickleTypeNode) {
@@ -1006,7 +1102,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregate(const TExprNode::TPtr&
.Seal()
.Seal().Build();
- if (KeyColumns->ChildrenSize() == 0 && !HaveSessionSetting) {
+ if (KeyColumns->ChildrenSize() == 0 && !HaveSessionSetting && (Suffix == "" || Suffix.EndsWith("Finalize"))) {
return MakeSingleGroupRow(*Node, postAgg, Ctx);
}
@@ -1086,6 +1182,9 @@ TExprNode::TPtr TAggregateExpander::GenerateCondenseSwitch(const TExprNode::TPtr
TExprNode::TPtr TAggregateExpander::GeneratePostAggregateInitPhase()
{
+ bool aggregateOnly = (Suffix != "");
+ const auto& columnNames = aggregateOnly ? FinalColumnNames : InitialColumnNames;
+
ui32 index = 0U;
return Ctx.Builder(Node->Pos())
.Lambda()
@@ -1115,31 +1214,30 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateInitPhase()
return parent;
})
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
- for (ui32 i = 0; i < InitialColumnNames.size(); ++i) {
+ for (ui32 i = 0; i < columnNames.size(); ++i) {
auto child = AggregatedColumns->Child(i);
auto trait = Traits[i];
if (!EffectiveCompact) {
auto loadLambda = trait->Child(4);
+ auto extractorLambda = GetFinalAggStateExtractor(i);
- if (!DistinctFields.empty()) {
+ if (!DistinctFields.empty() || Suffix == "MergeManyFinalize") {
parent.List(index++)
- .Add(0, InitialColumnNames[i])
+ .Add(0, columnNames[i])
.Callable(1, "Map")
- .Callable(0, "Member")
- .Arg(0, "item")
- .Add(1, InitialColumnNames[i])
+ .Apply(0, *extractorLambda)
+ .With(0, "item")
.Seal()
.Add(1, loadLambda)
.Seal()
.Seal();
} else {
parent.List(index++)
- .Add(0, InitialColumnNames[i])
+ .Add(0, columnNames[i])
.Apply(1, *loadLambda)
.With(0)
- .Callable("Member")
- .Arg(0, "item")
- .Add(1, InitialColumnNames[i])
+ .Apply(*extractorLambda)
+ .With(0, "item")
.Seal()
.Done()
.Seal();
@@ -1188,7 +1286,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateInitPhase()
const bool isFirst = *Distinct2Columns[distinctField->Content()].begin() == i;
if (isFirst) {
parent.List(index++)
- .Add(0, InitialColumnNames[i])
+ .Add(0, columnNames[i])
.List(1)
.Callable(0, "NamedApply")
.Add(0, UdfSetCreate[distinctField->Content()])
@@ -1226,13 +1324,13 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateInitPhase()
.Seal();
} else {
parent.List(index++)
- .Add(0, InitialColumnNames[i])
+ .Add(0, columnNames[i])
.Do(initApply)
.Seal();
}
} else {
parent.List(index++)
- .Add(0, InitialColumnNames[i])
+ .Add(0, columnNames[i])
.Do(initApply)
.Seal();
}
@@ -1247,6 +1345,9 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateInitPhase()
TExprNode::TPtr TAggregateExpander::GeneratePostAggregateSavePhase()
{
+ bool aggregateOnly = (Suffix != "");
+ const auto& columnNames = aggregateOnly ? FinalColumnNames : InitialColumnNames;
+
ui32 index = 0U;
return Ctx.Builder(Node->Pos())
.Lambda()
@@ -1280,12 +1381,12 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateSavePhase()
return parent;
})
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
- for (ui32 i = 0; i < InitialColumnNames.size(); ++i) {
+ for (ui32 i = 0; i < columnNames.size(); ++i) {
auto child = AggregatedColumns->Child(i);
auto trait = Traits[i];
- auto finishLambda = trait->Child(6);
+ auto finishLambda = (Suffix == "MergeState") ? trait->Child(3) : trait->Child(6);
- if (!EffectiveCompact && !DistinctFields.empty()) {
+ if (!EffectiveCompact && (!DistinctFields.empty() || Suffix == "MergeManyFinalize")) {
if (child->Head().IsAtom()) {
parent.List(index++)
.Add(0, FinalColumnNames[i])
@@ -1293,7 +1394,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateSavePhase()
.Callable(0, "Map")
.Callable(0, "Member")
.Arg(0, "state")
- .Add(1, InitialColumnNames[i])
+ .Add(1, columnNames[i])
.Seal()
.Add(1, finishLambda)
.Seal()
@@ -1309,7 +1410,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateSavePhase()
.Callable(0, "Map")
.Callable(0, "Member")
.Arg(0, "state")
- .Add(1, InitialColumnNames[i])
+ .Add(1, columnNames[i])
.Seal()
.Add(1, finishLambda)
.Seal()
@@ -1327,14 +1428,14 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateSavePhase()
parent.Callable("Nth")
.Callable(0, "Member")
.Arg(0, "state")
- .Add(1, InitialColumnNames[i])
+ .Add(1, columnNames[i])
.Seal()
.Atom(1, "1", TNodeFlags::Default)
.Seal();
} else {
parent.Callable("Member")
.Arg(0, "state")
- .Add(1, InitialColumnNames[i])
+ .Add(1, columnNames[i])
.Seal();
}
@@ -1377,6 +1478,9 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateSavePhase()
TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase()
{
+ bool aggregateOnly = (Suffix != "");
+ const auto& columnNames = aggregateOnly ? FinalColumnNames : InitialColumnNames;
+
ui32 index = 0U;
return Ctx.Builder(Node->Pos())
.Lambda()
@@ -1407,41 +1511,40 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase()
return parent;
})
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
- for (ui32 i = 0; i < InitialColumnNames.size(); ++i) {
+ for (ui32 i = 0; i < columnNames.size(); ++i) {
auto child = AggregatedColumns->Child(i);
auto trait = Traits[i];
if (!EffectiveCompact) {
auto loadLambda = trait->Child(4);
auto mergeLambda = trait->Child(5);
+ auto extractorLambda = GetFinalAggStateExtractor(i);
- if (!DistinctFields.empty()) {
+ if (!DistinctFields.empty() || Suffix == "MergeManyFinalize") {
parent.List(index++)
- .Add(0, InitialColumnNames[i])
+ .Add(0, columnNames[i])
.Callable(1, "OptionalReduce")
.Callable(0, "Map")
- .Callable(0, "Member")
- .Arg(0, "item")
- .Add(1, InitialColumnNames[i])
+ .Apply(0, extractorLambda)
+ .With(0, "item")
.Seal()
.Add(1, loadLambda)
.Seal()
.Callable(1, "Member")
.Arg(0, "state")
- .Add(1, InitialColumnNames[i])
+ .Add(1, columnNames[i])
.Seal()
.Add(2, mergeLambda)
.Seal()
.Seal();
} else {
parent.List(index++)
- .Add(0, InitialColumnNames[i])
+ .Add(0, columnNames[i])
.Apply(1, *mergeLambda)
.With(0)
.Apply(*loadLambda)
.With(0)
- .Callable("Member")
- .Arg(0, "item")
- .Add(1, InitialColumnNames[i])
+ .Apply(extractorLambda)
+ .With(0, "item")
.Seal()
.Done()
.Seal()
@@ -1449,7 +1552,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase()
.With(1)
.Callable("Member")
.Arg(0, "state")
- .Add(1, InitialColumnNames[i])
+ .Add(1, columnNames[i])
.Seal()
.Done()
.Seal()
@@ -1486,14 +1589,14 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase()
parent.Callable("Nth")
.Callable(0, "Member")
.Arg(0, "state")
- .Add(1, InitialColumnNames[i])
+ .Add(1, columnNames[i])
.Seal()
.Atom(1, "1", TNodeFlags::Default)
.Seal();
} else {
parent.Callable("Member")
.Arg(0, "state")
- .Add(1, InitialColumnNames[i])
+ .Add(1, columnNames[i])
.Seal();
}
@@ -1527,7 +1630,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase()
.Callable(0, "Nth")
.Callable(0, "Member")
.Arg(0, "state")
- .Add(1, InitialColumnNames[distinctIndex])
+ .Add(1, columnNames[distinctIndex])
.Seal()
.Atom(1, "0", TNodeFlags::Default)
.Seal()
@@ -1556,7 +1659,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase()
};
parent.List(index++)
- .Add(0, InitialColumnNames[i])
+ .Add(0, columnNames[i])
.Callable(1, "If")
.Callable(0, "NamedApply")
.Add(0, UdfWasChanged[distinctField->Content()])
@@ -1567,7 +1670,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase()
.Callable(0, "Nth")
.Callable(0, "Member")
.Arg(0, "state")
- .Add(1, InitialColumnNames[distinctIndex])
+ .Add(1, columnNames[distinctIndex])
.Seal()
.Atom(1, "0", TNodeFlags::Default)
.Seal()
@@ -1608,13 +1711,13 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase()
})
.Callable(2, "Member")
.Arg(0, "state")
- .Add(1, InitialColumnNames[i])
+ .Add(1, columnNames[i])
.Seal()
.Seal()
.Seal();
} else {
parent.List(index++)
- .Add(0, InitialColumnNames[i])
+ .Add(0, columnNames[i])
.Do(updateApply)
.Seal();
}
@@ -1627,4 +1730,4 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase()
.Build();
}
-} // namespace NYql
\ No newline at end of file
+} // namespace NYql
diff --git a/ydb/library/yql/core/yql_aggregate_expander.h b/ydb/library/yql/core/yql_aggregate_expander.h
index 552bdee2229..c39481012dc 100644
--- a/ydb/library/yql/core/yql_aggregate_expander.h
+++ b/ydb/library/yql/core/yql_aggregate_expander.h
@@ -64,6 +64,9 @@ private:
TExprNode::TPtr GeneratePostAggregateSavePhase();
TExprNode::TPtr GeneratePostAggregateMergePhase();
+ std::function<TExprNodeBuilder& (TExprNodeBuilder&)> GetPartialAggArgExtractor(ui32 i, bool deserialize);
+ TExprNode::TPtr GetFinalAggStateExtractor(ui32 i);
+
private:
static constexpr TStringBuf SessionStartMemberName = "_yql_group_session_start";
@@ -73,6 +76,7 @@ private:
bool AllowPickle;
bool ForceCompact;
bool CompactForDistinct;
+ TStringBuf Suffix;
TSessionWindowParams SessionWindowParams;
TExprNode::TPtr AggList;
@@ -110,4 +114,4 @@ inline TExprNode::TPtr ExpandAggregatePeephole(const TExprNode::TPtr& node, TExp
return aggExpander.ExpandAggregate();
}
-}
\ No newline at end of file
+}
diff --git a/ydb/library/yql/core/yql_expr_constraint.cpp b/ydb/library/yql/core/yql_expr_constraint.cpp
index 48e8920af51..d995acfd799 100644
--- a/ydb/library/yql/core/yql_expr_constraint.cpp
+++ b/ydb/library/yql/core/yql_expr_constraint.cpp
@@ -211,6 +211,12 @@ public:
Functions["WideCombiner"] = &TCallableConstraintTransformer::InheriteEmptyFromInput;
Functions["WideCondense1"] = &TCallableConstraintTransformer::WideCondense1Wrap;
Functions["Aggregate"] = &TCallableConstraintTransformer::AggregateWrap;
+ Functions["AggregateCombine"] = &TCallableConstraintTransformer::AggregateWrap;
+ Functions["AggregateCombineState"] = &TCallableConstraintTransformer::AggregateWrap;
+ Functions["AggregateMergeState"] = &TCallableConstraintTransformer::AggregateWrap;
+ Functions["AggregateMergeFinalize"] = &TCallableConstraintTransformer::AggregateWrap;
+ Functions["AggregateMergeManyFinalize"] = &TCallableConstraintTransformer::AggregateWrap;
+ Functions["AggregateFinalize"] = &TCallableConstraintTransformer::AggregateWrap;
Functions["Fold"] = &TCallableConstraintTransformer::FoldWrap;
Functions["Fold1"] = &TCallableConstraintTransformer::FoldWrap;
Functions["WithContext"] = &TCallableConstraintTransformer::CopyAllFrom<0>;
diff --git a/ydb/library/yql/sql/v1/aggregation.cpp b/ydb/library/yql/sql/v1/aggregation.cpp
index b9157664b01..c141c859c17 100644
--- a/ydb/library/yql/sql/v1/aggregation.cpp
+++ b/ydb/library/yql/sql/v1/aggregation.cpp
@@ -114,23 +114,29 @@ protected:
return Factory;
}
- TNodePtr GetExtractor() const override {
- return BuildLambda(Pos, Y("row"), Y("PersistableRepr", Expr));
+ TNodePtr GetExtractor(bool many, TContext& ctx) const override {
+ Y_UNUSED(ctx);
+ return BuildLambda(Pos, Y("row"), Y("PersistableRepr", many ? Y("Unwrap", Expr) : Expr));
}
- TNodePtr GetApply(const TNodePtr& type) const override {
+ TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const override {
+ auto extractor = GetExtractor(many, ctx);
+ if (!extractor) {
+ return nullptr;
+ }
+
if (!Multi) {
if (!DynamicFactory && !AggApplyName.empty()) {
- return Y("AggApply", Q(AggApplyName), Y("ListItemType", type), BuildLambda(Pos, Y("row"), Y("PersistableRepr", Expr)));
+ return Y("AggApply", Q(AggApplyName), Y("ListItemType", type), extractor);
}
return Y("Apply", Factory, (DynamicFactory ? Y("ListItemType", type) : type),
- BuildLambda(Pos, Y("row"), Y("PersistableRepr", Expr)));
+ extractor);
}
return Y("MultiAggregate",
Y("ListItemType", type),
- GetExtractor(),
+ extractor,
Factory);
}
@@ -288,12 +294,16 @@ private:
return new TKeyPayloadAggregationFactory(Pos, Name, Func, AggMode);
}
- TNodePtr GetExtractor() const final {
- return BuildLambda(Pos, Y("row"), Payload);
+ TNodePtr GetExtractor(bool many, TContext& ctx) const final {
+ Y_UNUSED(ctx);
+ return BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Payload) : Payload);
}
- TNodePtr GetApply(const TNodePtr& type) const final {
- auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Key), BuildLambda(Pos, Y("row"), Payload));
+ TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final {
+ Y_UNUSED(ctx);
+ auto apply = Y("Apply", Factory, type,
+ BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Key) : Key),
+ BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Payload) : Payload));
AddFactoryArguments(apply);
return apply;
}
@@ -384,12 +394,16 @@ private:
return new TPayloadPredicateAggregationFactory(Pos, Name, Func, AggMode);
}
- TNodePtr GetExtractor() const final {
- return BuildLambda(Pos, Y("row"), Payload);
+ TNodePtr GetExtractor(bool many, TContext& ctx) const final {
+ Y_UNUSED(ctx);
+ return BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Payload) : Payload);
}
- TNodePtr GetApply(const TNodePtr& type) const final {
- return Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Payload), BuildLambda(Pos, Y("row"), Predicate));
+ TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final {
+ Y_UNUSED(ctx);
+ return Y("Apply", Factory, type,
+ BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Payload) : Payload),
+ BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Predicate) : Predicate));
}
std::vector<ui32> GetFactoryColumnIndices() const final {
@@ -466,13 +480,15 @@ private:
return new TTwoArgsAggregationFactory(Pos, Name, Func, AggMode);
}
- TNodePtr GetExtractor() const final {
- return BuildLambda(Pos, Y("row"), One);
+ TNodePtr GetExtractor(bool many, TContext& ctx) const final {
+ Y_UNUSED(ctx);
+ return BuildLambda(Pos, Y("row"), many ? Y("Unwrap", One) : One);
}
- TNodePtr GetApply(const TNodePtr& type) const final {
+ TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final {
+ Y_UNUSED(ctx);
auto tuple = Q(Y(One, Two));
- return Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), tuple));
+ return Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), many ? Y("Unwrap", tuple) : tuple));
}
bool DoInit(TContext& ctx, ISource* src) final {
@@ -563,8 +579,11 @@ private:
return new THistogramAggregationFactory(Pos, Name, Func, AggMode);
}
- TNodePtr GetApply(const TNodePtr& type) const final {
- auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Expr), BuildLambda(Pos, Y("row"), Weight));
+ TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final {
+ Y_UNUSED(ctx);
+ auto apply = Y("Apply", Factory, type,
+ BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Expr) : Expr),
+ BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Weight) : Weight));
AddFactoryArguments(apply);
return apply;
}
@@ -639,9 +658,10 @@ private:
return new TLinearHistogramAggregationFactory(Pos, Name, Func, AggMode);
}
- TNodePtr GetApply(const TNodePtr& type) const final {
+ TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final {
+ Y_UNUSED(ctx);
return Y("Apply", Factory, type,
- BuildLambda(Pos, Y("row"), Expr),
+ BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Expr) : Expr),
BinSize, Minimum, Maximum);
}
@@ -734,7 +754,8 @@ private:
return new TPercentileFactory(Pos, Name, Func, AggMode);
}
- TNodePtr GetApply(const TNodePtr& type) const final {
+ TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final {
+ Y_UNUSED(ctx);
TNodePtr percentiles(Percentiles.cbegin()->second);
if (Percentiles.size() > 1U) {
@@ -745,16 +766,16 @@ private:
percentiles = Q(percentiles);
}
- return Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Expr), percentiles);
+ return Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Expr) : Expr), percentiles);
}
void AddFactoryArguments(TNodePtr& apply) const final {
apply = L(apply, FactoryPercentile);
}
- TNodePtr AggregationTraits(const TNodePtr& type, bool overState) const final {
+ std::pair<TNodePtr, bool> AggregationTraits(const TNodePtr& type, bool overState, bool many, TContext& ctx) const final {
if (Percentiles.empty())
- return TNodePtr();
+ return { TNodePtr(), true };
TNodePtr names(Q(Percentiles.cbegin()->first));
@@ -767,9 +788,19 @@ private:
const bool distinct = AggMode == EAggregateMode::Distinct;
const auto listType = distinct ? Y("ListType", Y("StructMemberType", Y("ListItemType", type), BuildQuotedAtom(Pos, DistinctKey))) : type;
- return distinct ?
- Q(Y(names, WrapIfOverState(GetApply(listType), overState), BuildQuotedAtom(Pos, DistinctKey))) :
- Q(Y(names, WrapIfOverState(GetApply(listType), overState)));
+ auto apply = GetApply(listType, many, ctx);
+ if (!apply) {
+ return { TNodePtr(), false };
+ }
+
+ auto wrapped = WrapIfOverState(apply, overState, many, ctx);
+ if (!wrapped) {
+ return { TNodePtr(), false };
+ }
+
+ return { distinct ?
+ Q(Y(names, wrapped, BuildQuotedAtom(Pos, DistinctKey))) :
+ Q(Y(names, wrapped)), true };
}
bool DoInit(TContext& ctx, ISource* src) final {
@@ -857,7 +888,8 @@ private:
return new TTopFreqFactory(Pos, Name, Func, AggMode);
}
- TNodePtr GetApply(const TNodePtr& type) const final {
+ TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final {
+ Y_UNUSED(ctx);
TPair topFreqs(TopFreqs.cbegin()->second);
if (TopFreqs.size() > 1U) {
@@ -868,7 +900,7 @@ private:
topFreqs = { Q(topFreqs.first), Q(topFreqs.second) };
}
- auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Expr), topFreqs.first, topFreqs.second);
+ auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Expr) : Expr), topFreqs.first, topFreqs.second);
return apply;
}
@@ -876,9 +908,9 @@ private:
apply = L(apply, TopFreqFactoryParams.first, TopFreqFactoryParams.second);
}
- TNodePtr AggregationTraits(const TNodePtr& type, bool overState) const final {
+ std::pair<TNodePtr, bool> AggregationTraits(const TNodePtr& type, bool overState, bool many, TContext& ctx) const final {
if (TopFreqs.empty())
- return TNodePtr();
+ return { TNodePtr(), true };
TNodePtr names(Q(TopFreqs.cbegin()->first));
@@ -891,9 +923,19 @@ private:
const bool distinct = AggMode == EAggregateMode::Distinct;
const auto listType = distinct ? Y("ListType", Y("StructMemberType", Y("ListItemType", type), BuildQuotedAtom(Pos, DistinctKey))) : type;
- return distinct ?
- Q(Y(names, WrapIfOverState(GetApply(listType), overState), BuildQuotedAtom(Pos, DistinctKey))) :
- Q(Y(names, WrapIfOverState(GetApply(listType), overState)));
+ auto apply = GetApply(listType, many, ctx);
+ if (!apply) {
+ return { nullptr, false };
+ }
+
+ auto wrapped = WrapIfOverState(apply, overState, many, ctx);
+ if (!wrapped) {
+ return { nullptr, false };
+ }
+
+ return { distinct ?
+ Q(Y(names, wrapped, BuildQuotedAtom(Pos, DistinctKey))) :
+ Q(Y(names, wrapped)), true };
}
bool DoInit(TContext& ctx, ISource* src) final {
@@ -971,12 +1013,15 @@ private:
return new TTopAggregationFactory(Pos, Name, Func, AggMode);
}
- TNodePtr GetApply(const TNodePtr& type) const final {
+ TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final {
+ Y_UNUSED(ctx);
TNodePtr apply;
if (HasKey) {
- apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Key), BuildLambda(Pos, Y("row"), Payload));
+ apply = Y("Apply", Factory, type,
+ BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Key) : Key),
+ BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Payload) : Payload));
} else {
- apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Payload));
+ apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Payload) : Payload));
}
AddFactoryArguments(apply);
return apply;
@@ -1083,8 +1128,9 @@ private:
return new TCountDistinctEstimateAggregationFactory(Pos, Name, Func, AggMode);
}
- TNodePtr GetApply(const TNodePtr& type) const final {
- auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Expr));
+ TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final {
+ Y_UNUSED(ctx);
+ auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Expr) : Expr));
AddFactoryArguments(apply);
return apply;
}
@@ -1156,8 +1202,9 @@ private:
return new TListAggregationFactory(Pos, Name, Func, AggMode);
}
- TNodePtr GetApply(const TNodePtr& type) const final {
- auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Expr));
+ TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final {
+ Y_UNUSED(ctx);
+ auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Expr) : Expr));
AddFactoryArguments(apply);
return apply;
}
@@ -1213,8 +1260,9 @@ private:
return new TUserDefinedAggregationFactory(Pos, Name, Func, AggMode);
}
- TNodePtr GetApply(const TNodePtr& type) const final {
- auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Expr));
+ TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final {
+ Y_UNUSED(ctx);
+ auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), many ? Y("Unwrap", Expr) : Expr));
AddFactoryArguments(apply);
return apply;
}
@@ -1297,7 +1345,15 @@ public:
return ret;
}
- TNodePtr GetApply(const TNodePtr& type) const final {
+ TNodePtr GetExtractor(bool many, TContext& ctx) const override {
+ Y_UNUSED(many);
+ ctx.Error() << "Partial aggregation by PostgreSQL function isn't supported";
+ return nullptr;
+ }
+
+ TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const final {
+ Y_UNUSED(many);
+ Y_UNUSED(ctx);
return Y(AggMode == EAggregateMode::OverWindow ? "PgWindowTraits" : "PgAggregationTraits",
Q(PgFunc), Y("ListItemType", type), Lambda);
}
diff --git a/ydb/library/yql/sql/v1/insert.cpp b/ydb/library/yql/sql/v1/insert.cpp
index 4f4c8630181..0620eb4eab4 100644
--- a/ydb/library/yql/sql/v1/insert.cpp
+++ b/ydb/library/yql/sql/v1/insert.cpp
@@ -51,9 +51,10 @@ public:
return nullptr;
}
- TNodePtr BuildAggregation(const TString& label) override {
+ std::pair<TNodePtr, bool> BuildAggregation(const TString& label, TContext& ctx) override {
Y_UNUSED(label);
- return nullptr;
+ Y_UNUSED(ctx);
+ return { nullptr, true };
}
protected:
diff --git a/ydb/library/yql/sql/v1/node.cpp b/ydb/library/yql/sql/v1/node.cpp
index 9ab50224c20..8a1e56b4c94 100644
--- a/ydb/library/yql/sql/v1/node.cpp
+++ b/ydb/library/yql/sql/v1/node.cpp
@@ -1246,20 +1246,35 @@ TAstNode* IAggregation::Translate(TContext& ctx) const {
return nullptr;
}
-TNodePtr IAggregation::AggregationTraits(const TNodePtr& type, bool overState) const {
+std::pair<TNodePtr, bool> IAggregation::AggregationTraits(const TNodePtr& type, bool overState, bool many, TContext& ctx) const {
const bool distinct = AggMode == EAggregateMode::Distinct;
const auto listType = distinct ? Y("ListType", Y("StructMemberType", Y("ListItemType", type), BuildQuotedAtom(Pos, DistinctKey))) : type;
- return distinct ?
- Q(Y(Q(Name), WrapIfOverState(GetApply(listType), overState), BuildQuotedAtom(Pos, DistinctKey))) :
- Q(Y(Q(Name), WrapIfOverState(GetApply(listType), overState)));
+ auto apply = GetApply(listType, many, ctx);
+ if (!apply) {
+ return { nullptr, false };
+ }
+
+ auto wrapped = WrapIfOverState(apply, overState, many, ctx);
+ if (!wrapped) {
+ return { nullptr, false };
+ }
+
+ return { distinct ?
+ Q(Y(Q(Name), wrapped, BuildQuotedAtom(Pos, DistinctKey))) :
+ Q(Y(Q(Name), wrapped)), true };
}
-TNodePtr IAggregation::WrapIfOverState(const TNodePtr& input, bool overState) const {
+TNodePtr IAggregation::WrapIfOverState(const TNodePtr& input, bool overState, bool many, TContext& ctx) const {
if (!overState) {
return input;
}
- return Y("AggOverState", GetExtractor(), BuildLambda(Pos, Y(), input));
+ auto extractor = GetExtractor(many, ctx);
+ if (!extractor) {
+ return nullptr;
+ }
+
+ return Y(ToString("AggOverState"), extractor, BuildLambda(Pos, Y(), input));
}
void IAggregation::AddFactoryArguments(TNodePtr& apply) const {
@@ -1270,9 +1285,9 @@ std::vector<ui32> IAggregation::GetFactoryColumnIndices() const {
return {0u};
}
-TNodePtr IAggregation::WindowTraits(const TNodePtr& type) const {
+TNodePtr IAggregation::WindowTraits(const TNodePtr& type, TContext& ctx) const {
YQL_ENSURE(AggMode == EAggregateMode::OverWindow, "Windows traits is unavailable");
- return Q(Y(Q(Name), GetApply(type)));
+ return Q(Y(Q(Name), GetApply(type, false, ctx)));
}
ISource::ISource(TPosition pos)
@@ -1759,9 +1774,9 @@ bool ISource::SetSamplingRate(TContext& ctx, TNodePtr samplingRate) {
return true;
}
-TNodePtr ISource::BuildAggregation(const TString& label) {
+std::pair<TNodePtr, bool> ISource::BuildAggregation(const TString& label, TContext& ctx) {
if (GroupKeys.empty() && Aggregations.empty() && !IsCompositeSource() && !HoppingWindowSpec) {
- return nullptr;
+ return { nullptr, true };
}
auto keysTuple = Y();
@@ -1786,10 +1801,16 @@ TNodePtr ISource::BuildAggregation(const TString& label) {
const auto listType = Y("TypeOf", label);
auto aggrArgs = Y();
- const bool overState = GroupBySuffix == "CombineState" || GroupBySuffix == "MergeState" || GroupBySuffix == "MergeFinalize";
+ const bool overState = GroupBySuffix == "CombineState" || GroupBySuffix == "MergeState" ||
+ GroupBySuffix == "MergeFinalize" || GroupBySuffix == "MergeManyFinalize";
for (const auto& aggr: Aggregations) {
- if (const auto traits = aggr->AggregationTraits(listType, overState)) {
- aggrArgs = L(aggrArgs, traits);
+ auto res = aggr->AggregationTraits(listType, overState, GroupBySuffix == "MergeManyFinalize", ctx);
+ if (!res.second) {
+ return { nullptr, false };
+ }
+
+ if (res.first) {
+ aggrArgs = L(aggrArgs, res.first);
}
}
@@ -1819,7 +1840,7 @@ TNodePtr ISource::BuildAggregation(const TString& label) {
Q(Y(BuildQuotedAtom(Pos, SessionWindow->GetLabel()), sessionWindow->BuildTraits(label))))));
}
- return Y("AssumeColumnOrderPartial", Y("Aggregate" + GroupBySuffix, label, Q(keysTuple), Q(aggrArgs), Q(options)), Q(keysTuple));
+ return { Y("AssumeColumnOrderPartial", Y("Aggregate" + GroupBySuffix, label, Q(keysTuple), Q(aggrArgs), Q(options)), Q(keysTuple)), true };
}
TMaybe<TString> ISource::FindColumnMistype(const TString& name) const {
@@ -1990,7 +2011,7 @@ TNodePtr ISource::BuildCalcOverWindow(TContext& ctx, const TString& label) {
YQL_ENSURE(frameType);
auto callOnFrame = Y(frameType, BuildWindowFrame(*spec->Frame, spec->IsCompact));
for (auto& agg : aggs) {
- auto winTraits = agg->WindowTraits(listType);
+ auto winTraits = agg->WindowTraits(listType, ctx);
callOnFrame = L(callOnFrame, winTraits);
}
for (auto& func : funcs) {
diff --git a/ydb/library/yql/sql/v1/node.h b/ydb/library/yql/sql/v1/node.h
index 2d779dfc90c..6cdc0098ecc 100644
--- a/ydb/library/yql/sql/v1/node.h
+++ b/ydb/library/yql/sql/v1/node.h
@@ -754,7 +754,7 @@ namespace NSQLTranslationV1 {
virtual bool InitAggr(TContext& ctx, bool isFactory, ISource* src, TAstListNode& node, const TVector<TNodePtr>& exprs) = 0;
- virtual TNodePtr AggregationTraits(const TNodePtr& type, bool overState) const;
+ virtual std::pair<TNodePtr, bool> AggregationTraits(const TNodePtr& type, bool overState, bool many, TContext& ctx) const;
virtual TNodePtr AggregationTraitsFactory() const = 0;
@@ -762,7 +762,7 @@ namespace NSQLTranslationV1 {
virtual void AddFactoryArguments(TNodePtr& apply) const;
- virtual TNodePtr WindowTraits(const TNodePtr& type) const;
+ virtual TNodePtr WindowTraits(const TNodePtr& type, TContext& ctx) const;
const TString& GetName() const;
@@ -772,13 +772,13 @@ namespace NSQLTranslationV1 {
virtual void Join(IAggregation* aggr);
private:
- virtual TNodePtr GetApply(const TNodePtr& type) const = 0;
+ virtual TNodePtr GetApply(const TNodePtr& type, bool many, TContext& ctx) const = 0;
protected:
IAggregation(TPosition pos, const TString& name, const TString& func, EAggregateMode mode);
TAstNode* Translate(TContext& ctx) const override;
- TNodePtr WrapIfOverState(const TNodePtr& input, bool overState) const;
- virtual TNodePtr GetExtractor() const = 0;
+ TNodePtr WrapIfOverState(const TNodePtr& input, bool overState, bool many, TContext& ctx) const;
+ virtual TNodePtr GetExtractor(bool many, TContext& ctx) const = 0;
TString Name;
TString Func;
@@ -860,7 +860,7 @@ namespace NSQLTranslationV1 {
virtual TNodePtr BuildPreaggregatedMap(TContext& ctx);
virtual TNodePtr BuildPreFlattenMap(TContext& ctx);
virtual TNodePtr BuildPrewindowMap(TContext& ctx);
- virtual TNodePtr BuildAggregation(const TString& label);
+ virtual std::pair<TNodePtr, bool> BuildAggregation(const TString& label, TContext& ctx);
virtual TNodePtr BuildCalcOverWindow(TContext& ctx, const TString& label);
virtual TNodePtr BuildSort(TContext& ctx, const TString& label);
virtual TNodePtr BuildCleanupColumns(TContext& ctx, const TString& label);
diff --git a/ydb/library/yql/sql/v1/select.cpp b/ydb/library/yql/sql/v1/select.cpp
index b9c3d448a3a..ac37571185c 100644
--- a/ydb/library/yql/sql/v1/select.cpp
+++ b/ydb/library/yql/sql/v1/select.cpp
@@ -242,9 +242,10 @@ public:
return nullptr;
}
- TNodePtr BuildAggregation(const TString& label) override {
+ std::pair<TNodePtr, bool> BuildAggregation(const TString& label, TContext& ctx) override {
Y_UNUSED(label);
- return nullptr;
+ Y_UNUSED(ctx);
+ return { nullptr, true };
}
TPtr DoClone() const final {
@@ -1552,7 +1553,12 @@ public:
}
src->FinishColumns();
- Aggregate = src->BuildAggregation("core");
+ auto aggRes = src->BuildAggregation("core", ctx);
+ if (!aggRes.second) {
+ return false;
+ }
+
+ Aggregate = aggRes.first;
if (src->IsFlattenByColumns() || src->IsFlattenColumns()) {
Flatten = src->IsFlattenByColumns() ?
src->BuildFlattenByColumns("row") :
diff --git a/ydb/library/yql/sql/v1/sql.cpp b/ydb/library/yql/sql/v1/sql.cpp
index 53f345695ca..1594a3095b9 100644
--- a/ydb/library/yql/sql/v1/sql.cpp
+++ b/ydb/library/yql/sql/v1/sql.cpp
@@ -7224,6 +7224,8 @@ bool TGroupByClause::Build(const TRule_group_by_clause& node, bool stream) {
Suffix = "Finalize";
} else if (mode == "mergefinalize") {
Suffix = "MergeFinalize";
+ } else if (mode == "mergemanyfinalize") {
+ Suffix = "MergeManyFinalize";
} else {
Ctx.Error() << "Unsupported group by mode: " << mode;
Ctx.IncrementMonCounter("sql_errors", "GroupByModeUnknown");