aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <vvvv@ydb.tech>2022-10-06 20:01:44 +0300
committervvvv <vvvv@ydb.tech>2022-10-06 20:01:44 +0300
commit8a0c721548ffb3053cb6a3e08c0a9027712e2894 (patch)
tree10c89ae4e33b6eccabecb0d580263dbc6ab881cc
parent68fd753127853234cd0812750f9bd10c81ab5fa9 (diff)
downloadydb-8a0c721548ffb3053cb6a3e08c0a9027712e2894.tar.gz
typecheck for explicit phases of GROUP BY
Пример запроса, на котором проходит typecheck. Для агрегационных функций с несколькими параметрами в режиме over state сам state берётся из первого аргумента, при этом параметры-литералы (например, limit для max_by) используются в том смысле, что их наличие влияет на тип state. Percentile — особая функция: для неё в MergeFinalize можно построить из одного state много срезов по разным значениям percentile. %%(sql) --pragma EmitAggApply; $p = SELECT key,count(value) as a,avg(value) as b,percentile(value,0.1) as c, max_by(value,value) as e, sum_if(value,value>0) as f FROM AS_TABLE([<|key: 1, value: 2|>]) GROUP BY key with combine ; $p = PROCESS $p; select FormatType(TypeOf($p)); $p = SELECT key,count(a) as a,avg(b) as b,percentile(c,0.1) as c,max_by(e,e) as e, sum_if(f,f) as f FROM $p GROUP BY key with combinestate ; $p = PROCESS $p; select FormatType(TypeOf($p)); $p = SELECT key,count(a) as a,avg(b) as b,percentile(c,0.1) as c,max_by(e,e) as e, sum_if(f,f) as f FROM $p GROUP BY key with mergestate ; $p = PROCESS $p; select FormatType(TypeOf($p)); $p = SELECT key,count(a) as a,avg(b) as b,percentile(c,0.1) as c,percentile(c,0.2) as d,max_by(e,e) as e, sum_if(f,f) as f FROM $p GROUP BY key with mergefinalize ; $p = PROCESS $p; select FormatType(TypeOf($p)); $p = SELECT key,count(value) as a,avg(value) as b,percentile(value,0.1) as c,percentile(value,0.2) as d, max_by(value,value) as e, sum_if(value,value>0) as f FROM AS_TABLE([<|key: 1, value: 2|>]) GROUP BY key with finalize ; $p = PROCESS $p; select FormatType(TypeOf($p)); %%
-rw-r--r--ydb/library/yql/core/expr_nodes/yql_expr_nodes.json30
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_core.cpp7
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_list.cpp314
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_list.h2
-rw-r--r--ydb/library/yql/sql/v1/aggregation.cpp30
-rw-r--r--ydb/library/yql/sql/v1/node.cpp20
-rw-r--r--ydb/library/yql/sql/v1/node.h4
7 files changed, 303 insertions, 104 deletions
diff --git a/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json b/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json
index 5c5625782d9..716d0b80682 100644
--- a/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json
+++ b/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json
@@ -2156,6 +2156,36 @@
{"Index": 0, "Name": "Input", "Type": "TExprBase"},
{"Index": 1, "Name": "Name", "Type": "TCoAtom"}
]
+ },
+ {
+ "Name": "TCoAggApplyBase",
+ "Base": "TCallable",
+ "Match": {"Type": "CallableBase"},
+ "Builder": {"Generate": "None"},
+ "Children": [
+ {"Index": 0, "Name": "Name", "Type": "TCoAtom"},
+ {"Index": 1, "Name": "InputType", "Type": "TExprBase"},
+ {"Index": 2, "Name": "Extractor", "Type": "TCoLambda"}
+ ]
+ },
+ {
+ "Name": "TCoAggApply",
+ "Base": "TCoAggApplyBase",
+ "Match": {"Type": "Callable", "Name": "AggApply"}
+ },
+ {
+ "Name": "TCoAggApplyState",
+ "Base": "TCoAggApplyBase",
+ "Match": {"Type": "Callable", "Name": "AggApplyState"}
+ },
+ {
+ "Name": "TCoAggOverState",
+ "Base": "TCallable",
+ "Match": {"Type": "Callable", "Name": "AggOverState"},
+ "Children": [
+ {"Index": 0, "Name": "Extractor", "Type": "TCoLambda"},
+ {"Index": 1, "Name": "Trait", "Type": "TCoLambda"}
+ ]
}
]
}
diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp
index 2957ae2fe55..3c458a0ff49 100644
--- a/ydb/library/yql/core/type_ann/type_ann_core.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp
@@ -11439,9 +11439,16 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
Functions["AggregationTraits"] = &AggregationTraitsWrapper;
Functions["MultiAggregate"] = &MultiAggregateWrapper;
Functions["Aggregate"] = &AggregateWrapper;
+ Functions["AggregateCombine"] = &AggregateWrapper;
+ Functions["AggregateCombineState"] = &AggregateWrapper;
+ Functions["AggregateMergeState"] = &AggregateWrapper;
+ Functions["AggregateFinalize"] = &AggregateWrapper;
+ Functions["AggregateMergeFinalize"] = &AggregateWrapper;
+ Functions["AggOverState"] = &AggOverStateWrapper;
Functions["SqlAggregateAll"] = &SqlAggregateAllWrapper;
Functions["CountedAggregateAll"] = &CountedAggregateAllWrapper;
Functions["AggApply"] = &AggApplyWrapper;
+ Functions["AggApplyState"] = &AggApplyWrapper;
Functions["WinOnRows"] = &WinOnWrapper;
Functions["WinOnGroups"] = &WinOnWrapper;
Functions["WinOnRange"] = &WinOnWrapper;
diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp
index a27aead2027..e5a4b3ba4bb 100644
--- a/ydb/library/yql/core/type_ann/type_ann_list.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp
@@ -4344,7 +4344,6 @@ namespace {
}
IGraphTransformer::TStatus AggregationTraitsWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
- Y_UNUSED(output);
if (!EnsureArgsCount(*input, 8, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
@@ -4359,12 +4358,13 @@ namespace {
return IGraphTransformer::TStatus::Repeat;
}
- auto status = ConvertToLambda(input->ChildRef(1), ctx.Expr, 1, 2);
- status = status.Combine(ConvertToLambda(input->ChildRef(2), ctx.Expr, 2, 3));
- status = status.Combine(ConvertToLambda(input->ChildRef(3), ctx.Expr, 1));
- status = status.Combine(ConvertToLambda(input->ChildRef(4), ctx.Expr, 1));
- status = status.Combine(ConvertToLambda(input->ChildRef(5), ctx.Expr, 2));
- status = status.Combine(ConvertToLambda(input->ChildRef(6), ctx.Expr, 1));
+ IGraphTransformer::TStatus status = IGraphTransformer::TStatus::Ok;
+ status = status.Combine(ConvertToLambda(input->ChildRef(1), ctx.Expr, 1, 2)); // init
+ status = status.Combine(ConvertToLambda(input->ChildRef(2), ctx.Expr, 2, 3)); // update
+ status = status.Combine(ConvertToLambda(input->ChildRef(3), ctx.Expr, 1)); // save
+ status = status.Combine(ConvertToLambda(input->ChildRef(4), ctx.Expr, 1)); // load
+ status = status.Combine(ConvertToLambda(input->ChildRef(5), ctx.Expr, 2)); // merge
+ status = status.Combine(ConvertToLambda(input->ChildRef(6), ctx.Expr, 1)); // finish
if (status.Level != IGraphTransformer::TStatus::Ok) {
return status;
}
@@ -4373,6 +4373,9 @@ namespace {
return IGraphTransformer::TStatus::Error;
}
+ auto& lambdaUpdate = input->ChildRef(2);
+ const bool overState = lambdaUpdate->Tail().IsCallable("Void");
+
auto& lambdaInit = input->ChildRef(1);
auto ui32Type = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint32);
@@ -4380,8 +4383,7 @@ namespace {
if (!UpdateLambdaAllArgumentsTypes(lambdaInit, { itemType }, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
- }
- else {
+ } else {
if (!UpdateLambdaAllArgumentsTypes(lambdaInit, { itemType, ui32Type }, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
@@ -4396,61 +4398,96 @@ namespace {
return IGraphTransformer::TStatus::Error;
}
- auto& lambdaUpdate = input->ChildRef(2);
- if (lambdaUpdate->Head().ChildrenSize() == 2U) {
- if (!UpdateLambdaAllArgumentsTypes(lambdaUpdate, { itemType, combineStateType }, ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
- }
- }
- else {
- if (!UpdateLambdaAllArgumentsTypes(lambdaUpdate, { itemType, combineStateType, ui32Type }, ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
+ if (!overState) {
+ if (lambdaUpdate->Head().ChildrenSize() == 2U) {
+ if (!UpdateLambdaAllArgumentsTypes(lambdaUpdate, { itemType, combineStateType }, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+ } else {
+ if (!UpdateLambdaAllArgumentsTypes(lambdaUpdate, { itemType, combineStateType, ui32Type }, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
}
- }
- if (!lambdaUpdate->GetTypeAnn()) {
- return IGraphTransformer::TStatus::Repeat;
- }
+ if (!lambdaUpdate->GetTypeAnn()) {
+ return IGraphTransformer::TStatus::Repeat;
+ }
- if (!IsSameAnnotation(*lambdaUpdate->GetTypeAnn(), *combineStateType)) {
- ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaUpdate->Pos()), TStringBuilder() << "Mismatch update lambda result type, expected: "
- << *combineStateType << ", but got: " << *lambdaUpdate->GetTypeAnn()));
- return IGraphTransformer::TStatus::Error;
+ if (!IsSameAnnotation(*lambdaUpdate->GetTypeAnn(), *combineStateType)) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaUpdate->Pos()), TStringBuilder() << "Mismatch update lambda result type, expected: "
+ << *combineStateType << ", but got: " << *lambdaUpdate->GetTypeAnn()));
+ return IGraphTransformer::TStatus::Error;
+ }
}
auto& lambdaMerge = input->ChildRef(5);
const bool noSaveLoad = lambdaMerge->Tail().IsCallable("Void");
-
- auto& lambdaSave = input->ChildRef(3);
- if (!UpdateLambdaAllArgumentsTypes(lambdaSave, { combineStateType }, ctx.Expr)) {
+ if (overState && noSaveLoad) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaMerge->Pos()), "Merge handler should be specified because of aggregation over states"));
return IGraphTransformer::TStatus::Error;
}
- if (!lambdaSave->GetTypeAnn()) {
- return IGraphTransformer::TStatus::Repeat;
- }
+ const TTypeAnnotationNode* reduceStateType = nullptr;
+ if (!overState) {
+ auto& lambdaSave = input->ChildRef(3);
+ if (!UpdateLambdaAllArgumentsTypes(lambdaSave, { combineStateType }, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
- auto savedType = lambdaSave->GetTypeAnn();
- if (!noSaveLoad && !EnsurePersistableType(lambdaSave->Pos(), *savedType, ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
- }
+ if (!lambdaSave->GetTypeAnn()) {
+ return IGraphTransformer::TStatus::Repeat;
+ }
- auto& lambdaLoad = input->ChildRef(4);
- if (!UpdateLambdaAllArgumentsTypes(lambdaLoad, { savedType }, ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
- }
+ auto savedType = lambdaSave->GetTypeAnn();
+ if (!noSaveLoad && !EnsurePersistableType(lambdaSave->Pos(), *savedType, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
- if (!lambdaLoad->GetTypeAnn()) {
- return IGraphTransformer::TStatus::Repeat;
- }
+ auto& lambdaLoad = input->ChildRef(4);
+ if (!UpdateLambdaAllArgumentsTypes(lambdaLoad, { savedType }, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
- if (!IsSameAnnotation(*lambdaUpdate->GetTypeAnn(), *lambdaLoad->GetTypeAnn())) {
- ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaUpdate->Pos()), TStringBuilder() << "Mismatch state type after load, expected: "
- << *lambdaUpdate->GetTypeAnn() << ", but got: " << *lambdaLoad->GetTypeAnn()));
- return IGraphTransformer::TStatus::Error;
+ if (!lambdaLoad->GetTypeAnn()) {
+ return IGraphTransformer::TStatus::Repeat;
+ }
+
+ auto& lambdaUpdate = input->ChildRef(2);
+ if (!IsSameAnnotation(*lambdaUpdate->GetTypeAnn(), *lambdaLoad->GetTypeAnn())) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaUpdate->Pos()), TStringBuilder() << "Mismatch state type after load, expected: "
+ << *lambdaUpdate->GetTypeAnn() << ", but got: " << *lambdaLoad->GetTypeAnn()));
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ reduceStateType = lambdaLoad->GetTypeAnn();
+ } else {
+ auto& lambdaLoad = input->ChildRef(4);
+ if (!UpdateLambdaAllArgumentsTypes(lambdaLoad, { combineStateType }, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!lambdaLoad->GetTypeAnn()) {
+ return IGraphTransformer::TStatus::Repeat;
+ }
+
+ reduceStateType = lambdaLoad->GetTypeAnn();
+
+ auto& lambdaSave = input->ChildRef(3);
+ if (!UpdateLambdaAllArgumentsTypes(lambdaSave, { reduceStateType }, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!lambdaSave->GetTypeAnn()) {
+ return IGraphTransformer::TStatus::Repeat;
+ }
+
+ if (!IsSameAnnotation(*lambdaSave->GetTypeAnn(), *combineStateType)) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaSave->Pos()), TStringBuilder() << "Mismatch serialized state type after save, expected: "
+ << *combineStateType << ", but got: " << *lambdaSave->GetTypeAnn()));
+ return IGraphTransformer::TStatus::Error;
+ }
}
- auto reduceStateType = lambdaLoad->GetTypeAnn();
if (!UpdateLambdaAllArgumentsTypes(lambdaMerge, { reduceStateType, reduceStateType }, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
@@ -4460,7 +4497,7 @@ namespace {
}
if (!noSaveLoad && !IsSameAnnotation(*lambdaMerge->GetTypeAnn(), *reduceStateType)) {
- ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaUpdate->Pos()), TStringBuilder() << "Mismatch merge lambda result type, expected: "
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambdaMerge->Pos()), TStringBuilder() << "Mismatch merge lambda result type, expected: "
<< *reduceStateType << ", but got: " << *lambdaMerge->GetTypeAnn()));
return IGraphTransformer::TStatus::Error;
}
@@ -4531,6 +4568,8 @@ namespace {
}
IGraphTransformer::TStatus AggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
+ TStringBuf suffix = input->Content();
+ YQL_ENSURE(suffix.SkipPrefix("Aggregate"));
if (!EnsureMinArgsCount(*input, 3, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
@@ -4726,7 +4765,7 @@ namespace {
}
}
- const bool isAggApply = child->Child(1)->IsCallable("AggApply");
+ const bool isAggApply = child->Child(1)->IsCallable({ "AggApply", "AggApplyState" });
const bool isTraits = child->Child(1)->IsCallable("AggregationTraits");
if (!isAggApply && !isTraits) {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(child->Child(1)->Pos()), "Expected aggregation traits"));
@@ -4754,58 +4793,71 @@ namespace {
}
}
- auto finishType = isAggApply ? child->Child(1)->GetTypeAnn() : child->Child(1)->Child(6)->GetTypeAnn();
- bool isOptional = finishType->GetKind() == ETypeAnnotationKind::Optional;
- if (child->Head().IsList()) {
- if (isOptional) {
- finishType = finishType->Cast<TOptionalExprType>()->GetItemType();
- }
+ if (suffix == "" || suffix.EndsWith("Finalize")) {
+ auto finishType = isAggApply ? child->Child(1)->GetTypeAnn() : child->Child(1)->Child(6)->GetTypeAnn();
+ bool isOptional = finishType->GetKind() == ETypeAnnotationKind::Optional;
+ if (child->Head().IsList()) {
+ if (isOptional) {
+ finishType = finishType->Cast<TOptionalExprType>()->GetItemType();
+ }
- const auto tupleType = finishType->Cast<TTupleExprType>();
- if (tupleType->GetSize() != child->Head().ChildrenSize()) {
- ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(child->Child(1)->Child(6)->Pos()),
- TStringBuilder() << "Expected tuple type of size: " << child->Head().ChildrenSize() << ", but got: " << tupleType->GetSize()));
- return IGraphTransformer::TStatus::Error;
- }
+ const auto tupleType = finishType->Cast<TTupleExprType>();
+ if (tupleType->GetSize() != child->Head().ChildrenSize()) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(child->Child(1)->Child(6)->Pos()),
+ TStringBuilder() << "Expected tuple type of size: " << child->Head().ChildrenSize() << ", but got: " << tupleType->GetSize()));
+ return IGraphTransformer::TStatus::Error;
+ }
- for (ui32 index = 0; index < tupleType->GetSize(); ++index) {
- const auto item = tupleType->GetItems()[index];
- rowColumns.push_back(ctx.Expr.MakeType<TItemExprType>(
- child->Head().Child(index)->Content(),
- (isOptional || (input->Child(1)->ChildrenSize() == 0 && !isHopping)) &&
- item->GetKind() != ETypeAnnotationKind::Optional ? ctx.Expr.MakeType<TOptionalExprType>(item) : item));
- }
- } else {
- const TTypeAnnotationNode* defValType;
- bool isDefNull;
- if (isTraits) {
- auto defVal = child->Child(1)->Child(7);
- isDefNull = defVal->IsCallable("Null");
- defValType = defVal->GetTypeAnn();
+ for (ui32 index = 0; index < tupleType->GetSize(); ++index) {
+ const auto item = tupleType->GetItems()[index];
+ rowColumns.push_back(ctx.Expr.MakeType<TItemExprType>(
+ child->Head().Child(index)->Content(),
+ (isOptional || (input->Child(1)->ChildrenSize() == 0 && !isHopping)) &&
+ item->GetKind() != ETypeAnnotationKind::Optional ? ctx.Expr.MakeType<TOptionalExprType>(item) : item));
+ }
} else {
- auto name = child->Child(1)->Child(0)->Content();
- if (name == "count" || name == "count_all") {
- isDefNull = false;
- defValType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64);
+ const TTypeAnnotationNode* defValType;
+ bool isDefNull;
+ if (isTraits) {
+ auto defVal = child->Child(1)->Child(7);
+ isDefNull = defVal->IsCallable("Null");
+ defValType = defVal->GetTypeAnn();
} else {
- isDefNull = true;
- defValType = ctx.Expr.MakeType<TNullExprType>();
+ auto name = child->Child(1)->Child(0)->Content();
+ if (name == "count" || name == "count_all") {
+ isDefNull = false;
+ defValType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64);
+ } else {
+ isDefNull = true;
+ defValType = ctx.Expr.MakeType<TNullExprType>();
+ }
}
- }
- if (isDefNull && !isOptional && !isHopping && input->Child(1)->ChildrenSize() == 0) {
- if (finishType->GetKind() != ETypeAnnotationKind::Null &&
- finishType->GetKind() != ETypeAnnotationKind::Pg) {
- finishType = ctx.Expr.MakeType<TOptionalExprType>(finishType);
+ if (isDefNull && !isOptional && !isHopping && input->Child(1)->ChildrenSize() == 0) {
+ if (finishType->GetKind() != ETypeAnnotationKind::Null &&
+ finishType->GetKind() != ETypeAnnotationKind::Pg) {
+ finishType = ctx.Expr.MakeType<TOptionalExprType>(finishType);
+ }
+ } else if (!isDefNull && defValType->GetKind() != ETypeAnnotationKind::Optional
+ && finishType->GetKind() == ETypeAnnotationKind::Optional) {
+ finishType = finishType->Cast<TOptionalExprType>()->GetItemType();
+ } else if (!isDefNull && finishType->GetKind() == ETypeAnnotationKind::Null && input->Child(1)->ChildrenSize() == 0) {
+ finishType = defValType;
}
- } else if (!isDefNull && defValType->GetKind() != ETypeAnnotationKind::Optional
- && finishType->GetKind() == ETypeAnnotationKind::Optional) {
- finishType = finishType->Cast<TOptionalExprType>()->GetItemType();
- } else if (!isDefNull && finishType->GetKind() == ETypeAnnotationKind::Null && input->Child(1)->ChildrenSize() == 0) {
- finishType = defValType;
- }
- rowColumns.push_back(ctx.Expr.MakeType<TItemExprType>(child->Head().Content(), finishType));
+ rowColumns.push_back(ctx.Expr.MakeType<TItemExprType>(child->Head().Content(), finishType));
+ }
+ } else if (suffix == "Combine" || suffix == "CombineState" || suffix == "MergeState") {
+ auto stateType = isAggApply ? AggApplySerializedStateType(child->ChildPtr(1), ctx.Expr) : child->Child(1)->Child(3)->GetTypeAnn();
+ if (child->Head().IsList()) {
+ for (const auto& x : child->Head().Children()) {
+ rowColumns.push_back(ctx.Expr.MakeType<TItemExprType>(x->Content(), stateType));
+ }
+ } else {
+ rowColumns.push_back(ctx.Expr.MakeType<TItemExprType>(child->Head().Content(), stateType));
+ }
+ } else {
+ YQL_ENSURE(false, "Unknown aggregation mode: " << suffix);
}
}
@@ -4820,6 +4872,70 @@ namespace {
return IGraphTransformer::TStatus::Ok;
}
+ // Type-annotation handler for the AggOverState(extractor, trait) callable.
+ //
+ // The node is never annotated itself: it is rewritten into an equivalent
+ // aggregation description that consumes already-computed states, and
+ // Repeat is returned so the replacement node gets annotated instead.
+ //
+ //  * trait body is AggApply(name, inputType, _)
+ //      -> AggApplyState(name, inputType, extractor)
+ //  * trait body is AggregationTraits(c0, init, update, save, load, merge, finish, default)
+ //      -> AggregationTraits(c0, extractor, (item, state) -> Void(), save, load, merge, finish, default)
+ //    The Void update handler is the marker AggregationTraitsWrapper checks
+ //    (lambdaUpdate->Tail().IsCallable("Void")) to detect over-state mode.
+ //  * anything else -> "Expected aggregation traits" error.
+ //
+ // input  - the AggOverState node; child 0 is the state extractor, child 1 a
+ //          lambda whose tail is the trait (built with no parameters by
+ //          IAggregation::WrapIfOverState).
+ // output - receives the rewritten node on the Repeat paths.
+ // Returns Error on arity/shape violations, the ConvertToLambda status if it
+ // is not Ok, and Repeat after a successful rewrite.
+ IGraphTransformer::TStatus AggOverStateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
+ if (!EnsureArgsCount(*input, 2, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ // Child 0: extractor; the trailing 1 is presumably the expected lambda arity - confirm against ConvertToLambda.
+ auto& lambda1 = input->ChildRef(0);
+ auto status = ConvertToLambda(lambda1, ctx.Expr, 1);
+ if (status.Level != IGraphTransformer::TStatus::Ok) {
+ return status;
+ }
+
+ // Child 1: wrapper lambda around the trait, converted with 0 (presumably zero parameters).
+ auto& lambda2 = input->ChildRef(1);
+ status = ConvertToLambda(lambda2, ctx.Expr, 0);
+ if (status.Level != IGraphTransformer::TStatus::Ok) {
+ return status;
+ }
+
+ // The trait itself is the tail (body) of the wrapper lambda.
+ auto root = lambda2->TailPtr();
+ if (root->IsCallable("AggApply")) {
+ if (!EnsureArgsCount(*root, 3, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ // extractor for state, not initial value itself
+ output = ctx.Expr.Builder(input->Pos())
+ .Callable("AggApplyState")
+ .Add(0, root->ChildPtr(0))
+ .Add(1, root->ChildPtr(1))
+ .Add(2, input->ChildPtr(0))
+ .Seal()
+ .Build();
+
+ return IGraphTransformer::TStatus::Repeat;
+ } else if (root->IsCallable("AggregationTraits")) {
+ // make Void update handler
+ if (!EnsureArgsCount(*root, 8, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ // Children 3..7 (save, load, merge, finish, default value) are reused as is;
+ // only the init slot (1) and the update slot (2) are replaced.
+ output = ctx.Expr.Builder(input->Pos())
+ .Callable("AggregationTraits")
+ .Add(0, root->ChildPtr(0))
+ .Add(1, input->ChildPtr(0)) // extractor for state, not initial value itself
+ .Lambda(2)
+ .Param("item")
+ .Param("state")
+ .Callable("Void")
+ .Seal()
+ .Seal()
+ .Add(3, root->ChildPtr(3))
+ .Add(4, root->ChildPtr(4))
+ .Add(5, root->ChildPtr(5))
+ .Add(6, root->ChildPtr(6))
+ .Add(7, root->ChildPtr(7))
+ .Seal()
+ .Build();
+
+ return IGraphTransformer::TStatus::Repeat;
+ } else {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(root->Pos()), "Expected aggregation traits"));
+ return IGraphTransformer::TStatus::Error;
+ }
+ }
+
IGraphTransformer::TStatus SqlAggregateAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
if (!EnsureArgsCount(*input, 1, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
@@ -4938,6 +5054,16 @@ namespace {
return IGraphTransformer::TStatus::Ok;
}
+ // Returns the type of the serialized (intermediate) state produced by an
+ // AggApply / AggApplyState node. AggregateWrapper uses it as the column
+ // type for the Combine, CombineState and MergeState phases.
+ //
+ // input - the AggApply-like node; child 0 holds the aggregation function name.
+ // ctx   - currently unused.
+ //
+ // Only "count", "count_all" and "sum" are handled: for them the state type
+ // is the node's own type annotation. Any other name trips YQL_ENSURE with
+ // "Unknown AggApply: <name>" (presumably an exception - confirm against
+ // the YQL_ENSURE definition).
+ const TTypeAnnotationNode* AggApplySerializedStateType(const TExprNode::TPtr& input, TExprContext& ctx) {
+ Y_UNUSED(ctx);
+ auto name = input->Child(0)->Content();
+ if (name == "count" || name == "count_all" || name == "sum") {
+ return input->GetTypeAnn();
+ } else {
+ // No return after this point: relies on YQL_ENSURE(false, ...) not returning normally.
+ YQL_ENSURE(false, "Unknown AggApply: " << name);
+ }
+ }
+
IGraphTransformer::TStatus AggApplyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
Y_UNUSED(output);
if (!EnsureArgsCount(*input, 3, ctx.Expr)) {
diff --git a/ydb/library/yql/core/type_ann/type_ann_list.h b/ydb/library/yql/core/type_ann/type_ann_list.h
index c6f1ccf8cc8..09238b78a35 100644
--- a/ydb/library/yql/core/type_ann/type_ann_list.h
+++ b/ydb/library/yql/core/type_ann/type_ann_list.h
@@ -10,6 +10,7 @@ namespace NTypeAnnImpl {
IGraphTransformer::TStatus InferPositionalUnionType(TPositionHandle pos, const TExprNode::TListType& children,
TColumnOrder& resultColumnOrder, const TStructExprType*& resultStructType, TExtContext& ctx);
TExprNode::TPtr ExpandToWindowTraits(const TExprNode& input, TExprContext& ctx);
+ const TTypeAnnotationNode* AggApplySerializedStateType(const TExprNode::TPtr& input, TExprContext& ctx);
IGraphTransformer::TStatus FilterWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
template <bool InverseCondition>
@@ -93,6 +94,7 @@ namespace NTypeAnnImpl {
IGraphTransformer::TStatus AggregationTraitsWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus MultiAggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus AggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
+ IGraphTransformer::TStatus AggOverStateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus SqlAggregateAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus CountedAggregateAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus AggApplyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
diff --git a/ydb/library/yql/sql/v1/aggregation.cpp b/ydb/library/yql/sql/v1/aggregation.cpp
index 1bf9ff4457d..b9157664b01 100644
--- a/ydb/library/yql/sql/v1/aggregation.cpp
+++ b/ydb/library/yql/sql/v1/aggregation.cpp
@@ -114,6 +114,10 @@ protected:
return Factory;
}
+ // Builds the per-row extractor lambda: (row) -> PersistableRepr(Expr).
+ // Shared by the MultiAggregate call in GetApply and, through
+ // IAggregation::WrapIfOverState, by the AggOverState wrapper.
+ TNodePtr GetExtractor() const override {
+ return BuildLambda(Pos, Y("row"), Y("PersistableRepr", Expr));
+ }
+
TNodePtr GetApply(const TNodePtr& type) const override {
if (!Multi) {
if (!DynamicFactory && !AggApplyName.empty()) {
@@ -126,7 +130,7 @@ protected:
return Y("MultiAggregate",
Y("ListItemType", type),
- BuildLambda(Pos, Y("row"), Y("PersistableRepr", Expr)),
+ GetExtractor(),
Factory);
}
@@ -284,6 +288,10 @@ private:
return new TKeyPayloadAggregationFactory(Pos, Name, Func, AggMode);
}
+ // Builds the extractor lambda (row) -> Payload; the Key expression is not part of it.
+ TNodePtr GetExtractor() const final {
+ return BuildLambda(Pos, Y("row"), Payload);
+ }
+
TNodePtr GetApply(const TNodePtr& type) const final {
auto apply = Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Key), BuildLambda(Pos, Y("row"), Payload));
AddFactoryArguments(apply);
@@ -376,6 +384,10 @@ private:
return new TPayloadPredicateAggregationFactory(Pos, Name, Func, AggMode);
}
+ // Builds the extractor lambda (row) -> Payload; the Predicate expression is not part of it.
+ TNodePtr GetExtractor() const final {
+ return BuildLambda(Pos, Y("row"), Payload);
+ }
+
TNodePtr GetApply(const TNodePtr& type) const final {
return Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), Payload), BuildLambda(Pos, Y("row"), Predicate));
}
@@ -454,6 +466,10 @@ private:
return new TTwoArgsAggregationFactory(Pos, Name, Func, AggMode);
}
+ // Builds the extractor lambda (row) -> One, i.e. only the first of the two
+ // arguments; GetApply, by contrast, passes the (One, Two) tuple.
+ TNodePtr GetExtractor() const final {
+ return BuildLambda(Pos, Y("row"), One);
+ }
+
TNodePtr GetApply(const TNodePtr& type) const final {
auto tuple = Q(Y(One, Two));
return Y("Apply", Factory, type, BuildLambda(Pos, Y("row"), tuple));
@@ -736,7 +752,7 @@ private:
apply = L(apply, FactoryPercentile);
}
- TNodePtr AggregationTraits(const TNodePtr& type) const final {
+ TNodePtr AggregationTraits(const TNodePtr& type, bool overState) const final {
if (Percentiles.empty())
return TNodePtr();
@@ -751,7 +767,9 @@ private:
const bool distinct = AggMode == EAggregateMode::Distinct;
const auto listType = distinct ? Y("ListType", Y("StructMemberType", Y("ListItemType", type), BuildQuotedAtom(Pos, DistinctKey))) : type;
- return distinct ? Q(Y(names, GetApply(listType), BuildQuotedAtom(Pos, DistinctKey))) : Q(Y(names, GetApply(listType)));
+ return distinct ?
+ Q(Y(names, WrapIfOverState(GetApply(listType), overState), BuildQuotedAtom(Pos, DistinctKey))) :
+ Q(Y(names, WrapIfOverState(GetApply(listType), overState)));
}
bool DoInit(TContext& ctx, ISource* src) final {
@@ -858,7 +876,7 @@ private:
apply = L(apply, TopFreqFactoryParams.first, TopFreqFactoryParams.second);
}
- TNodePtr AggregationTraits(const TNodePtr& type) const final {
+ TNodePtr AggregationTraits(const TNodePtr& type, bool overState) const final {
if (TopFreqs.empty())
return TNodePtr();
@@ -873,7 +891,9 @@ private:
const bool distinct = AggMode == EAggregateMode::Distinct;
const auto listType = distinct ? Y("ListType", Y("StructMemberType", Y("ListItemType", type), BuildQuotedAtom(Pos, DistinctKey))) : type;
- return distinct ? Q(Y(names, GetApply(listType), BuildQuotedAtom(Pos, DistinctKey))) : Q(Y(names, GetApply(listType)));
+ return distinct ?
+ Q(Y(names, WrapIfOverState(GetApply(listType), overState), BuildQuotedAtom(Pos, DistinctKey))) :
+ Q(Y(names, WrapIfOverState(GetApply(listType), overState)));
}
bool DoInit(TContext& ctx, ISource* src) final {
diff --git a/ydb/library/yql/sql/v1/node.cpp b/ydb/library/yql/sql/v1/node.cpp
index 8a407d0768e..9ab50224c20 100644
--- a/ydb/library/yql/sql/v1/node.cpp
+++ b/ydb/library/yql/sql/v1/node.cpp
@@ -1246,10 +1246,20 @@ TAstNode* IAggregation::Translate(TContext& ctx) const {
return nullptr;
}
-TNodePtr IAggregation::AggregationTraits(const TNodePtr& type) const {
+// Builds the quoted per-aggregation tuple that ISource::BuildAggregation
+// appends to its aggregate argument list:
+//   (Name, apply)               for non-distinct aggregation,
+//   (Name, apply, DistinctKey)  for DISTINCT aggregation.
+// For DISTINCT the list type handed to GetApply is rebuilt as a list of the
+// DistinctKey member type of the input row. When overState is true, the
+// apply node is wrapped by WrapIfOverState (AggOverState); otherwise it is
+// used unchanged.
+TNodePtr IAggregation::AggregationTraits(const TNodePtr& type, bool overState) const {
const bool distinct = AggMode == EAggregateMode::Distinct;
const auto listType = distinct ? Y("ListType", Y("StructMemberType", Y("ListItemType", type), BuildQuotedAtom(Pos, DistinctKey))) : type;
- return distinct ? Q(Y(Q(Name), GetApply(listType), BuildQuotedAtom(Pos, DistinctKey))): Q(Y(Q(Name), GetApply(listType)));
+ return distinct ?
+ Q(Y(Q(Name), WrapIfOverState(GetApply(listType), overState), BuildQuotedAtom(Pos, DistinctKey))) :
+ Q(Y(Q(Name), WrapIfOverState(GetApply(listType), overState)));
+}
+
+// Optionally wraps an aggregation apply node for "over state" phases.
+// input     - the node produced by GetApply.
+// overState - when false, input is returned unchanged; when true, the result
+//             is AggOverState(GetExtractor(), <lambda with no parameters whose body is input>).
+TNodePtr IAggregation::WrapIfOverState(const TNodePtr& input, bool overState) const {
+ if (!overState) {
+ return input;
+ }
+
+ // Y() is an empty parameter list: the trait is passed as a zero-argument lambda.
+ return Y("AggOverState", GetExtractor(), BuildLambda(Pos, Y(), input));
}
void IAggregation::AddFactoryArguments(TNodePtr& apply) const {
@@ -1776,13 +1786,15 @@ TNodePtr ISource::BuildAggregation(const TString& label) {
const auto listType = Y("TypeOf", label);
auto aggrArgs = Y();
+ const bool overState = GroupBySuffix == "CombineState" || GroupBySuffix == "MergeState" || GroupBySuffix == "MergeFinalize";
for (const auto& aggr: Aggregations) {
- if (const auto traits = aggr->AggregationTraits(listType))
+ if (const auto traits = aggr->AggregationTraits(listType, overState)) {
aggrArgs = L(aggrArgs, traits);
+ }
}
auto options = Y();
- if (CompactGroupBy) {
+ if (CompactGroupBy || GroupBySuffix == "Finalize") {
options = L(options, Q(Y(Q("compact"))));
}
diff --git a/ydb/library/yql/sql/v1/node.h b/ydb/library/yql/sql/v1/node.h
index 1a23499c79a..2d779dfc90c 100644
--- a/ydb/library/yql/sql/v1/node.h
+++ b/ydb/library/yql/sql/v1/node.h
@@ -754,7 +754,7 @@ namespace NSQLTranslationV1 {
virtual bool InitAggr(TContext& ctx, bool isFactory, ISource* src, TAstListNode& node, const TVector<TNodePtr>& exprs) = 0;
- virtual TNodePtr AggregationTraits(const TNodePtr& type) const;
+ virtual TNodePtr AggregationTraits(const TNodePtr& type, bool overState) const;
virtual TNodePtr AggregationTraitsFactory() const = 0;
@@ -777,6 +777,8 @@ namespace NSQLTranslationV1 {
protected:
IAggregation(TPosition pos, const TString& name, const TString& func, EAggregateMode mode);
TAstNode* Translate(TContext& ctx) const override;
+ TNodePtr WrapIfOverState(const TNodePtr& input, bool overState) const;
+ virtual TNodePtr GetExtractor() const = 0;
TString Name;
TString Func;