diff options
author | vvvv <vvvv@ydb.tech> | 2022-10-06 20:01:44 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2022-10-06 20:01:44 +0300 |
commit | 8a0c721548ffb3053cb6a3e08c0a9027712e2894 (patch) | |
tree | 10c89ae4e33b6eccabecb0d580263dbc6ab881cc | |
parent | 68fd753127853234cd0812750f9bd10c81ab5fa9 (diff) | |
download | ydb-8a0c721548ffb3053cb6a3e08c0a9027712e2894.tar.gz |
typecheck for explicit phases of GROUP BY
Ниже приведён пример запроса, на котором проходит typecheck. Для агрегационных функций с несколькими параметрами в режиме over state сам state берётся из первого аргумента; при этом параметры-литералы (например, limit для max_by) используются в том смысле, что их наличие влияет на тип state.
Percentile — особая функция: для неё в MergeFinalize из одного state можно построить много срезов по разным значениям percentile.
%%(sql)
-- Example query for the explicit GROUP BY phases. The same set of aggregates
-- is run through `combine`, `combinestate`, `mergestate` and `mergefinalize`
-- (each phase reading the previous phase's result), and then once more through
-- `finalize` on the raw rows. After every phase the row type of $p is printed
-- with FormatType(TypeOf($p)).
-- The pragma below is intentionally left commented out.
--pragma EmitAggApply;

-- Phase 1: `combine` over the raw rows (columns key, value).
$p =
SELECT
key,count(value) as a,avg(value) as b,percentile(value,0.1) as c,
max_by(value,value) as e, sum_if(value,value>0) as f
FROM AS_TABLE([<|key: 1, value: 2|>])
GROUP BY
key
with combine
;
$p = PROCESS $p;
-- Print the row type produced by the `combine` phase.
select FormatType(TypeOf($p));

-- Phase 2: `combinestate` over the result of phase 1. The aggregate arguments
-- are now the columns a..f from the previous phase (aggregation over state).
-- For the two-argument functions (max_by, sum_if) the state is taken from the
-- first argument.
$p = SELECT
key,count(a) as a,avg(b) as b,percentile(c,0.1) as c,max_by(e,e) as e,
sum_if(f,f) as f
FROM $p
GROUP BY
key
with combinestate
;
$p = PROCESS $p;
-- Print the row type produced by the `combinestate` phase.
select FormatType(TypeOf($p));

-- Phase 3: `mergestate` over the result of phase 2, with the same column list
-- as in phase 2.
$p = SELECT
key,count(a) as a,avg(b) as b,percentile(c,0.1) as c,max_by(e,e) as e,
sum_if(f,f) as f
FROM $p
GROUP BY
key
with mergestate
;
$p = PROCESS $p;
-- Print the row type produced by the `mergestate` phase.
select FormatType(TypeOf($p));

-- Phase 4: `mergefinalize` over the result of phase 3. percentile is requested
-- twice (0.1 as c, 0.2 as d) from the same input column c: several percentile
-- slices are built from one state.
$p = SELECT
key,count(a) as a,avg(b) as b,percentile(c,0.1) as c,percentile(c,0.2) as d,max_by(e,e) as e,
sum_if(f,f) as f
FROM $p
GROUP BY
key
with mergefinalize
;
$p = PROCESS $p;
-- Print the row type produced by the `mergefinalize` phase.
select FormatType(TypeOf($p));

-- Separate run: `finalize` applied directly to the raw rows (not to the output
-- of the phases above), with the same result columns a..f as phase 4.
$p =
SELECT
key,count(value) as a,avg(value) as b,percentile(value,0.1) as c,percentile(value,0.2) as d, max_by(value,value) as e,
sum_if(value,value>0) as f
FROM AS_TABLE([<|key: 1, value: 2|>])
GROUP BY
key
with finalize
;
$p = PROCESS $p;
-- Print the row type produced by the `finalize` phase.
select FormatType(TypeOf($p));
%%
-rw-r--r-- | ydb/library/yql/core/expr_nodes/yql_expr_nodes.json | 30 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_core.cpp | 7 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_list.cpp | 314 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_list.h | 2 | ||||
-rw-r--r-- | ydb/library/yql/sql/v1/aggregation.cpp | 30 | ||||
-rw-r--r-- | ydb/library/yql/sql/v1/node.cpp | 20 | ||||
-rw-r--r-- | ydb/library/yql/sql/v1/node.h | 4 |
7 files changed, 303 insertions, 104 deletions
diff --git a/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json b/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json index 5c5625782d9..716d0b80682 100644 --- a/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json +++ b/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json @@ -2156,6 +2156,36 @@ {"Index": 0, "Name": "Input", "Type": "TExprBase"}, {"Index": 1, "Name": "Name", "Type": "TCoAtom"} ] + }, + { + "Name": "TCoAggApplyBase", + "Base": "TCallable", + "Match": {"Type": "CallableBase"}, + "Builder": {"Generate": "None"}, + "Children": [ + {"Index": 0, "Name": "Name", "Type": "TCoAtom"}, + {"Index": 1, "Name": "InputType", "Type": "TExprBase"}, + {"Index": 2, "Name": "Extractor", "Type": "TCoLambda"} + ] + }, + { + "Name": "TCoAggApply", + "Base": "TCoAggApplyBase", + "Match": {"Type": "Callable", "Name": "AggApply"} + }, + { + "Name": "TCoAggApplyState", + "Base": "TCoAggApplyBase", + "Match": {"Type": "Callable", "Name": "AggApplyState"} + }, + { + "Name": "TCoAggOverState", + "Base": "TCallable", + "Match": {"Type": "Callable", "Name": "AggOverState"}, + "Children": [ + {"Index": 0, "Name": "Extractor", "Type": "TCoLambda"}, + {"Index": 1, "Name": "Trait", "Type": "TCoLambda"} + ] } ] } diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp index 2957ae2fe55..3c458a0ff49 100644 --- a/ydb/library/yql/core/type_ann/type_ann_core.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp @@ -11439,9 +11439,16 @@ template <NKikimr::NUdf::EDataSlot DataSlot> Functions["AggregationTraits"] = &AggregationTraitsWrapper; Functions["MultiAggregate"] = &MultiAggregateWrapper; Functions["Aggregate"] = &AggregateWrapper; + Functions["AggregateCombine"] = &AggregateWrapper; + Functions["AggregateCombineState"] = &AggregateWrapper; + Functions["AggregateMergeState"] = &AggregateWrapper; + Functions["AggregateFinalize"] = &AggregateWrapper; + Functions["AggregateMergeFinalize"] = &AggregateWrapper; + 
Functions["AggOverState"] = &AggOverStateWrapper; Functions["SqlAggregateAll"] = &SqlAggregateAllWrapper; Functions["CountedAggregateAll"] = &CountedAggregateAllWrapper; Functions["AggApply"] = &AggApplyWrapper; + Functions["AggApplyState"] = &AggApplyWrapper; Functions["WinOnRows"] = &WinOnWrapper; Functions["WinOnGroups"] = &WinOnWrapper; Functions["WinOnRange"] = &WinOnWrapper; diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp index a27aead2027..e5a4b3ba4bb 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp @@ -4344,7 +4344,6 @@ namespace { } IGraphTransformer::TStatus AggregationTraitsWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { - Y_UNUSED(output); if (!EnsureArgsCount(*input, 8, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } @@ -4359,12 +4358,13 @@ namespace { return IGraphTransformer::TStatus::Repeat; } - auto status = ConvertToLambda(input->ChildRef(1), ctx.Expr, 1, 2); - status = status.Combine(ConvertToLambda(input->ChildRef(2), ctx.Expr, 2, 3)); - status = status.Combine(ConvertToLambda(input->ChildRef(3), ctx.Expr, 1)); - status = status.Combine(ConvertToLambda(input->ChildRef(4), ctx.Expr, 1)); - status = status.Combine(ConvertToLambda(input->ChildRef(5), ctx.Expr, 2)); - status = status.Combine(ConvertToLambda(input->ChildRef(6), ctx.Expr, 1)); + IGraphTransformer::TStatus status = IGraphTransformer::TStatus::Ok; + status = status.Combine(ConvertToLambda(input->ChildRef(1), ctx.Expr, 1, 2)); // init + status = status.Combine(ConvertToLambda(input->ChildRef(2), ctx.Expr, 2, 3)); // update + status = status.Combine(ConvertToLambda(input->ChildRef(3), ctx.Expr, 1)); // save + status = status.Combine(ConvertToLambda(input->ChildRef(4), ctx.Expr, 1)); // load + status = status.Combine(ConvertToLambda(input->ChildRef(5), ctx.Expr, 2)); // merge + status = 
status.Combine(ConvertToLambda(input->ChildRef(6), ctx.Expr, 1)); // finish if (status.Level != IGraphTransformer::TStatus::Ok) { return status; } @@ -4373,6 +4373,9 @@ namespace { return IGraphTransformer::TStatus::Error; } + auto& lambdaUpdate = input->ChildRef(2); + const bool overState = lambdaUpdate->Tail().IsCallable("Void"); + auto& lambdaInit = input->ChildRef(1); auto ui32Type = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint32); @@ -4380,8 +4383,7 @@ namespace { if (!UpdateLambdaAllArgumentsTypes(lambdaInit, { itemType }, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } - } - else { + } else { if (!UpdateLambdaAllArgumentsTypes(lambdaInit, { itemType, ui32Type }, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } @@ -4396,61 +4398,96 @@ namespace { return IGraphTransformer::TStatus::Error; } - auto& lambdaUpdate = input->ChildRef(2); - if (lambdaUpdate->Head().ChildrenSize() == 2U) { - if (!UpdateLambdaAllArgumentsTypes(lambdaUpdate, { itemType, combineStateType }, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; - } - } - else { - if (!UpdateLambdaAllArgumentsTypes(lambdaUpdate, { itemType, combineStateType, ui32Type }, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; + if (!overState) { + if (lambdaUpdate->Head().ChildrenSize() == 2U) { + if (!UpdateLambdaAllArgumentsTypes(lambdaUpdate, { itemType, combineStateType }, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + } else { + if (!UpdateLambdaAllArgumentsTypes(lambdaUpdate, { itemType, combineStateType, ui32Type }, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } } - } - if (!lambdaUpdate->GetTypeAnn()) { - return IGraphTransformer::TStatus::Repeat; - } + if (!lambdaUpdate->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } - if (!IsSameAnnotation(*lambdaUpdate->GetTypeAnn(), *combineStateType)) { - ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaUpdate->Pos()), TStringBuilder() << "Mismatch update lambda result type, expected: 
" - << *combineStateType << ", but got: " << *lambdaUpdate->GetTypeAnn())); - return IGraphTransformer::TStatus::Error; + if (!IsSameAnnotation(*lambdaUpdate->GetTypeAnn(), *combineStateType)) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaUpdate->Pos()), TStringBuilder() << "Mismatch update lambda result type, expected: " + << *combineStateType << ", but got: " << *lambdaUpdate->GetTypeAnn())); + return IGraphTransformer::TStatus::Error; + } } auto& lambdaMerge = input->ChildRef(5); const bool noSaveLoad = lambdaMerge->Tail().IsCallable("Void"); - - auto& lambdaSave = input->ChildRef(3); - if (!UpdateLambdaAllArgumentsTypes(lambdaSave, { combineStateType }, ctx.Expr)) { + if (overState && noSaveLoad) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaMerge->Pos()), "Merge handler should be specified because of aggregation over states")); return IGraphTransformer::TStatus::Error; } - if (!lambdaSave->GetTypeAnn()) { - return IGraphTransformer::TStatus::Repeat; - } + const TTypeAnnotationNode* reduceStateType = nullptr; + if (!overState) { + auto& lambdaSave = input->ChildRef(3); + if (!UpdateLambdaAllArgumentsTypes(lambdaSave, { combineStateType }, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } - auto savedType = lambdaSave->GetTypeAnn(); - if (!noSaveLoad && !EnsurePersistableType(lambdaSave->Pos(), *savedType, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; - } + if (!lambdaSave->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } - auto& lambdaLoad = input->ChildRef(4); - if (!UpdateLambdaAllArgumentsTypes(lambdaLoad, { savedType }, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; - } + auto savedType = lambdaSave->GetTypeAnn(); + if (!noSaveLoad && !EnsurePersistableType(lambdaSave->Pos(), *savedType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } - if (!lambdaLoad->GetTypeAnn()) { - return IGraphTransformer::TStatus::Repeat; - } + auto& lambdaLoad = input->ChildRef(4); + if 
(!UpdateLambdaAllArgumentsTypes(lambdaLoad, { savedType }, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } - if (!IsSameAnnotation(*lambdaUpdate->GetTypeAnn(), *lambdaLoad->GetTypeAnn())) { - ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaUpdate->Pos()), TStringBuilder() << "Mismatch state type after load, expected: " - << *lambdaUpdate->GetTypeAnn() << ", but got: " << *lambdaLoad->GetTypeAnn())); - return IGraphTransformer::TStatus::Error; + if (!lambdaLoad->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + + auto& lambdaUpdate = input->ChildRef(2); + if (!IsSameAnnotation(*lambdaUpdate->GetTypeAnn(), *lambdaLoad->GetTypeAnn())) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaUpdate->Pos()), TStringBuilder() << "Mismatch state type after load, expected: " + << *lambdaUpdate->GetTypeAnn() << ", but got: " << *lambdaLoad->GetTypeAnn())); + return IGraphTransformer::TStatus::Error; + } + + reduceStateType = lambdaLoad->GetTypeAnn(); + } else { + auto& lambdaLoad = input->ChildRef(4); + if (!UpdateLambdaAllArgumentsTypes(lambdaLoad, { combineStateType }, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!lambdaLoad->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + + reduceStateType = lambdaLoad->GetTypeAnn(); + + auto& lambdaSave = input->ChildRef(3); + if (!UpdateLambdaAllArgumentsTypes(lambdaSave, { reduceStateType }, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!lambdaSave->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + + if (!IsSameAnnotation(*lambdaSave->GetTypeAnn(), *combineStateType)) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaSave->Pos()), TStringBuilder() << "Mismatch serialized state type after save, expected: " + << *combineStateType << ", but got: " << *lambdaSave->GetTypeAnn())); + return IGraphTransformer::TStatus::Error; + } } - auto reduceStateType = lambdaLoad->GetTypeAnn(); if 
(!UpdateLambdaAllArgumentsTypes(lambdaMerge, { reduceStateType, reduceStateType }, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } @@ -4460,7 +4497,7 @@ namespace { } if (!noSaveLoad && !IsSameAnnotation(*lambdaMerge->GetTypeAnn(), *reduceStateType)) { - ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaUpdate->Pos()), TStringBuilder() << "Mismatch merge lambda result type, expected: " + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaMerge->Pos()), TStringBuilder() << "Mismatch merge lambda result type, expected: " << *reduceStateType << ", but got: " << *lambdaMerge->GetTypeAnn())); return IGraphTransformer::TStatus::Error; } @@ -4531,6 +4568,8 @@ namespace { } IGraphTransformer::TStatus AggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { + TStringBuf suffix = input->Content(); + YQL_ENSURE(suffix.SkipPrefix("Aggregate")); if (!EnsureMinArgsCount(*input, 3, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } @@ -4726,7 +4765,7 @@ namespace { } } - const bool isAggApply = child->Child(1)->IsCallable("AggApply"); + const bool isAggApply = child->Child(1)->IsCallable({ "AggApply", "AggApplyState" }); const bool isTraits = child->Child(1)->IsCallable("AggregationTraits"); if (!isAggApply && !isTraits) { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(child->Child(1)->Pos()), "Expected aggregation traits")); @@ -4754,58 +4793,71 @@ namespace { } } - auto finishType = isAggApply ? child->Child(1)->GetTypeAnn() : child->Child(1)->Child(6)->GetTypeAnn(); - bool isOptional = finishType->GetKind() == ETypeAnnotationKind::Optional; - if (child->Head().IsList()) { - if (isOptional) { - finishType = finishType->Cast<TOptionalExprType>()->GetItemType(); - } + if (suffix == "" || suffix.EndsWith("Finalize")) { + auto finishType = isAggApply ? 
child->Child(1)->GetTypeAnn() : child->Child(1)->Child(6)->GetTypeAnn(); + bool isOptional = finishType->GetKind() == ETypeAnnotationKind::Optional; + if (child->Head().IsList()) { + if (isOptional) { + finishType = finishType->Cast<TOptionalExprType>()->GetItemType(); + } - const auto tupleType = finishType->Cast<TTupleExprType>(); - if (tupleType->GetSize() != child->Head().ChildrenSize()) { - ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(child->Child(1)->Child(6)->Pos()), - TStringBuilder() << "Expected tuple type of size: " << child->Head().ChildrenSize() << ", but got: " << tupleType->GetSize())); - return IGraphTransformer::TStatus::Error; - } + const auto tupleType = finishType->Cast<TTupleExprType>(); + if (tupleType->GetSize() != child->Head().ChildrenSize()) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(child->Child(1)->Child(6)->Pos()), + TStringBuilder() << "Expected tuple type of size: " << child->Head().ChildrenSize() << ", but got: " << tupleType->GetSize())); + return IGraphTransformer::TStatus::Error; + } - for (ui32 index = 0; index < tupleType->GetSize(); ++index) { - const auto item = tupleType->GetItems()[index]; - rowColumns.push_back(ctx.Expr.MakeType<TItemExprType>( - child->Head().Child(index)->Content(), - (isOptional || (input->Child(1)->ChildrenSize() == 0 && !isHopping)) && - item->GetKind() != ETypeAnnotationKind::Optional ? ctx.Expr.MakeType<TOptionalExprType>(item) : item)); - } - } else { - const TTypeAnnotationNode* defValType; - bool isDefNull; - if (isTraits) { - auto defVal = child->Child(1)->Child(7); - isDefNull = defVal->IsCallable("Null"); - defValType = defVal->GetTypeAnn(); + for (ui32 index = 0; index < tupleType->GetSize(); ++index) { + const auto item = tupleType->GetItems()[index]; + rowColumns.push_back(ctx.Expr.MakeType<TItemExprType>( + child->Head().Child(index)->Content(), + (isOptional || (input->Child(1)->ChildrenSize() == 0 && !isHopping)) && + item->GetKind() != ETypeAnnotationKind::Optional ? 
ctx.Expr.MakeType<TOptionalExprType>(item) : item)); + } } else { - auto name = child->Child(1)->Child(0)->Content(); - if (name == "count" || name == "count_all") { - isDefNull = false; - defValType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64); + const TTypeAnnotationNode* defValType; + bool isDefNull; + if (isTraits) { + auto defVal = child->Child(1)->Child(7); + isDefNull = defVal->IsCallable("Null"); + defValType = defVal->GetTypeAnn(); } else { - isDefNull = true; - defValType = ctx.Expr.MakeType<TNullExprType>(); + auto name = child->Child(1)->Child(0)->Content(); + if (name == "count" || name == "count_all") { + isDefNull = false; + defValType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64); + } else { + isDefNull = true; + defValType = ctx.Expr.MakeType<TNullExprType>(); + } } - } - if (isDefNull && !isOptional && !isHopping && input->Child(1)->ChildrenSize() == 0) { - if (finishType->GetKind() != ETypeAnnotationKind::Null && - finishType->GetKind() != ETypeAnnotationKind::Pg) { - finishType = ctx.Expr.MakeType<TOptionalExprType>(finishType); + if (isDefNull && !isOptional && !isHopping && input->Child(1)->ChildrenSize() == 0) { + if (finishType->GetKind() != ETypeAnnotationKind::Null && + finishType->GetKind() != ETypeAnnotationKind::Pg) { + finishType = ctx.Expr.MakeType<TOptionalExprType>(finishType); + } + } else if (!isDefNull && defValType->GetKind() != ETypeAnnotationKind::Optional + && finishType->GetKind() == ETypeAnnotationKind::Optional) { + finishType = finishType->Cast<TOptionalExprType>()->GetItemType(); + } else if (!isDefNull && finishType->GetKind() == ETypeAnnotationKind::Null && input->Child(1)->ChildrenSize() == 0) { + finishType = defValType; } - } else if (!isDefNull && defValType->GetKind() != ETypeAnnotationKind::Optional - && finishType->GetKind() == ETypeAnnotationKind::Optional) { - finishType = finishType->Cast<TOptionalExprType>()->GetItemType(); - } else if (!isDefNull && finishType->GetKind() == 
ETypeAnnotationKind::Null && input->Child(1)->ChildrenSize() == 0) { - finishType = defValType; - } - rowColumns.push_back(ctx.Expr.MakeType<TItemExprType>(child->Head().Content(), finishType)); + rowColumns.push_back(ctx.Expr.MakeType<TItemExprType>(child->Head().Content(), finishType)); + } + } else if (suffix == "Combine" || suffix == "CombineState" || suffix == "MergeState") { + auto stateType = isAggApply ? AggApplySerializedStateType(child->ChildPtr(1), ctx.Expr) : child->Child(1)->Child(3)->GetTypeAnn(); + if (child->Head().IsList()) { + for (const auto& x : child->Head().Children()) { + rowColumns.push_back(ctx.Expr.MakeType<TItemExprType>(x->Content(), stateType)); + } + } else { + rowColumns.push_back(ctx.Expr.MakeType<TItemExprType>(child->Head().Content(), stateType)); + } + } else { + YQL_ENSURE(false, "Unknown aggregation mode: " << suffix); } } @@ -4820,6 +4872,70 @@ namespace { return IGraphTransformer::TStatus::Ok; } + IGraphTransformer::TStatus AggOverStateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { + if (!EnsureArgsCount(*input, 2, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + auto& lambda1 = input->ChildRef(0); + auto status = ConvertToLambda(lambda1, ctx.Expr, 1); + if (status.Level != IGraphTransformer::TStatus::Ok) { + return status; + } + + auto& lambda2 = input->ChildRef(1); + status = ConvertToLambda(lambda2, ctx.Expr, 0); + if (status.Level != IGraphTransformer::TStatus::Ok) { + return status; + } + + auto root = lambda2->TailPtr(); + if (root->IsCallable("AggApply")) { + if (!EnsureArgsCount(*root, 3, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + // extractor for state, not initial value itself + output = ctx.Expr.Builder(input->Pos()) + .Callable("AggApplyState") + .Add(0, root->ChildPtr(0)) + .Add(1, root->ChildPtr(1)) + .Add(2, input->ChildPtr(0)) + .Seal() + .Build(); + + return IGraphTransformer::TStatus::Repeat; + } else if 
(root->IsCallable("AggregationTraits")) { + // make Void update handler + if (!EnsureArgsCount(*root, 8, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + output = ctx.Expr.Builder(input->Pos()) + .Callable("AggregationTraits") + .Add(0, root->ChildPtr(0)) + .Add(1, input->ChildPtr(0)) // extractor for state, not initial value itself + .Lambda(2) + .Param("item") + .Param("state") + .Callable("Void") + .Seal() + .Seal() + .Add(3, root->ChildPtr(3)) + .Add(4, root->ChildPtr(4)) + .Add(5, root->ChildPtr(5)) + .Add(6, root->ChildPtr(6)) + .Add(7, root->ChildPtr(7)) + .Seal() + .Build(); + + return IGraphTransformer::TStatus::Repeat; + } else { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(root->Pos()), "Expected aggregation traits")); + return IGraphTransformer::TStatus::Error; + } + } + IGraphTransformer::TStatus SqlAggregateAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { if (!EnsureArgsCount(*input, 1, ctx.Expr)) { return IGraphTransformer::TStatus::Error; @@ -4938,6 +5054,16 @@ namespace { return IGraphTransformer::TStatus::Ok; } + const TTypeAnnotationNode* AggApplySerializedStateType(const TExprNode::TPtr& input, TExprContext& ctx) { + Y_UNUSED(ctx); + auto name = input->Child(0)->Content(); + if (name == "count" || name == "count_all" || name == "sum") { + return input->GetTypeAnn(); + } else { + YQL_ENSURE(false, "Unknown AggApply: " << name); + } + } + IGraphTransformer::TStatus AggApplyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { Y_UNUSED(output); if (!EnsureArgsCount(*input, 3, ctx.Expr)) { diff --git a/ydb/library/yql/core/type_ann/type_ann_list.h b/ydb/library/yql/core/type_ann/type_ann_list.h index c6f1ccf8cc8..09238b78a35 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.h +++ b/ydb/library/yql/core/type_ann/type_ann_list.h @@ -10,6 +10,7 @@ namespace NTypeAnnImpl { IGraphTransformer::TStatus InferPositionalUnionType(TPositionHandle pos, const 
TExprNode::TListType& children, TColumnOrder& resultColumnOrder, const TStructExprType*& resultStructType, TExtContext& ctx); TExprNode::TPtr ExpandToWindowTraits(const TExprNode& input, TExprContext& ctx); + const TTypeAnnotationNode* AggApplySerializedStateType(const TExprNode::TPtr& input, TExprContext& ctx); IGraphTransformer::TStatus FilterWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); template <bool InverseCondition> @@ -93,6 +94,7 @@ namespace NTypeAnnImpl { IGraphTransformer::TStatus AggregationTraitsWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus MultiAggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus AggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); + IGraphTransformer::TStatus AggOverStateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus SqlAggregateAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus CountedAggregateAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus AggApplyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); diff --git a/ydb/library/yql/sql/v1/aggregation.cpp b/ydb/library/yql/sql/v1/aggregation.cpp index 1bf9ff4457d..b9157664b01 100644 --- a/ydb/library/yql/sql/v1/aggregation.cpp +++ b/ydb/library/yql/sql/v1/aggregation.cpp @@ -114,6 +114,10 @@ protected: return Factory; } + TNodePtr GetExtractor() const override { + return BuildLambda(Pos, Y("row"), Y("PersistableRepr", Expr)); + } + TNodePtr GetApply(const TNodePtr& type) const override { if (!Multi) { if (!DynamicFactory && !AggApplyName.empty()) { @@ -126,7 +130,7 @@ protected: return Y("MultiAggregate", Y("ListItemType", type), - BuildLambda(Pos, Y("row"), Y("PersistableRepr", Expr)), + 
GetExtractor(), Factory); } @@ -284,6 +288,10 @@ private: return new TKeyPayloadAggregationFactory(Pos, Name, Func, AggMode); } + TNodePtr GetExtractor() const final { + return BuildLambda(Pos, Y("row"), Payload); + } + TNodePtr GetApply(const TNodePtr& type) const final { auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Key), BuildLambda(Pos, Y("row"), Payload)); AddFactoryArguments(apply); @@ -376,6 +384,10 @@ private: return new TPayloadPredicateAggregationFactory(Pos, Name, Func, AggMode); } + TNodePtr GetExtractor() const final { + return BuildLambda(Pos, Y("row"), Payload); + } + TNodePtr GetApply(const TNodePtr& type) const final { return Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Payload), BuildLambda(Pos, Y("row"), Predicate)); } @@ -454,6 +466,10 @@ private: return new TTwoArgsAggregationFactory(Pos, Name, Func, AggMode); } + TNodePtr GetExtractor() const final { + return BuildLambda(Pos, Y("row"), One); + } + TNodePtr GetApply(const TNodePtr& type) const final { auto tuple = Q(Y(One, Two)); return Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), tuple)); @@ -736,7 +752,7 @@ private: apply = L(apply, FactoryPercentile); } - TNodePtr AggregationTraits(const TNodePtr& type) const final { + TNodePtr AggregationTraits(const TNodePtr& type, bool overState) const final { if (Percentiles.empty()) return TNodePtr(); @@ -751,7 +767,9 @@ private: const bool distinct = AggMode == EAggregateMode::Distinct; const auto listType = distinct ? Y("ListType", Y("StructMemberType", Y("ListItemType", type), BuildQuotedAtom(Pos, DistinctKey))) : type; - return distinct ? Q(Y(names, GetApply(listType), BuildQuotedAtom(Pos, DistinctKey))) : Q(Y(names, GetApply(listType))); + return distinct ? 
+ Q(Y(names, WrapIfOverState(GetApply(listType), overState), BuildQuotedAtom(Pos, DistinctKey))) : + Q(Y(names, WrapIfOverState(GetApply(listType), overState))); } bool DoInit(TContext& ctx, ISource* src) final { @@ -858,7 +876,7 @@ private: apply = L(apply, TopFreqFactoryParams.first, TopFreqFactoryParams.second); } - TNodePtr AggregationTraits(const TNodePtr& type) const final { + TNodePtr AggregationTraits(const TNodePtr& type, bool overState) const final { if (TopFreqs.empty()) return TNodePtr(); @@ -873,7 +891,9 @@ private: const bool distinct = AggMode == EAggregateMode::Distinct; const auto listType = distinct ? Y("ListType", Y("StructMemberType", Y("ListItemType", type), BuildQuotedAtom(Pos, DistinctKey))) : type; - return distinct ? Q(Y(names, GetApply(listType), BuildQuotedAtom(Pos, DistinctKey))) : Q(Y(names, GetApply(listType))); + return distinct ? + Q(Y(names, WrapIfOverState(GetApply(listType), overState), BuildQuotedAtom(Pos, DistinctKey))) : + Q(Y(names, WrapIfOverState(GetApply(listType), overState))); } bool DoInit(TContext& ctx, ISource* src) final { diff --git a/ydb/library/yql/sql/v1/node.cpp b/ydb/library/yql/sql/v1/node.cpp index 8a407d0768e..9ab50224c20 100644 --- a/ydb/library/yql/sql/v1/node.cpp +++ b/ydb/library/yql/sql/v1/node.cpp @@ -1246,10 +1246,20 @@ TAstNode* IAggregation::Translate(TContext& ctx) const { return nullptr; } -TNodePtr IAggregation::AggregationTraits(const TNodePtr& type) const { +TNodePtr IAggregation::AggregationTraits(const TNodePtr& type, bool overState) const { const bool distinct = AggMode == EAggregateMode::Distinct; const auto listType = distinct ? Y("ListType", Y("StructMemberType", Y("ListItemType", type), BuildQuotedAtom(Pos, DistinctKey))) : type; - return distinct ? Q(Y(Q(Name), GetApply(listType), BuildQuotedAtom(Pos, DistinctKey))): Q(Y(Q(Name), GetApply(listType))); + return distinct ? 
+ Q(Y(Q(Name), WrapIfOverState(GetApply(listType), overState), BuildQuotedAtom(Pos, DistinctKey))) : + Q(Y(Q(Name), WrapIfOverState(GetApply(listType), overState))); +} + +TNodePtr IAggregation::WrapIfOverState(const TNodePtr& input, bool overState) const { + if (!overState) { + return input; + } + + return Y("AggOverState", GetExtractor(), BuildLambda(Pos, Y(), input)); } void IAggregation::AddFactoryArguments(TNodePtr& apply) const { @@ -1776,13 +1786,15 @@ TNodePtr ISource::BuildAggregation(const TString& label) { const auto listType = Y("TypeOf", label); auto aggrArgs = Y(); + const bool overState = GroupBySuffix == "CombineState" || GroupBySuffix == "MergeState" || GroupBySuffix == "MergeFinalize"; for (const auto& aggr: Aggregations) { - if (const auto traits = aggr->AggregationTraits(listType)) + if (const auto traits = aggr->AggregationTraits(listType, overState)) { aggrArgs = L(aggrArgs, traits); + } } auto options = Y(); - if (CompactGroupBy) { + if (CompactGroupBy || GroupBySuffix == "Finalize") { options = L(options, Q(Y(Q("compact")))); } diff --git a/ydb/library/yql/sql/v1/node.h b/ydb/library/yql/sql/v1/node.h index 1a23499c79a..2d779dfc90c 100644 --- a/ydb/library/yql/sql/v1/node.h +++ b/ydb/library/yql/sql/v1/node.h @@ -754,7 +754,7 @@ namespace NSQLTranslationV1 { virtual bool InitAggr(TContext& ctx, bool isFactory, ISource* src, TAstListNode& node, const TVector<TNodePtr>& exprs) = 0; - virtual TNodePtr AggregationTraits(const TNodePtr& type) const; + virtual TNodePtr AggregationTraits(const TNodePtr& type, bool overState) const; virtual TNodePtr AggregationTraitsFactory() const = 0; @@ -777,6 +777,8 @@ namespace NSQLTranslationV1 { protected: IAggregation(TPosition pos, const TString& name, const TString& func, EAggregateMode mode); TAstNode* Translate(TContext& ctx) const override; + TNodePtr WrapIfOverState(const TNodePtr& input, bool overState) const; + virtual TNodePtr GetExtractor() const = 0; TString Name; TString Func; |