diff options
author | vvvv <vvvv@ydb.tech> | 2023-05-22 16:51:52 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2023-05-22 16:51:52 +0300 |
commit | 7b06ae00f14aff82006eeeb5375af1a79ef2cf00 (patch) | |
tree | e7601642d8ab52b8ad9a00da28dbc8750acc2330 | |
parent | f86980eca696fad8264c575b94b8e0e2aa26be29 (diff) | |
download | ydb-7b06ae00f14aff82006eeeb5375af1a79ef2cf00.tar.gz |
Upper part of block PG aggregations
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_blocks.cpp | 4 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_list.cpp | 117 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_aggregate_expander.cpp | 71 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_expr_type_annotation.cpp | 48 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_expr_type_annotation.h | 3 |
5 files changed, 163 insertions, 80 deletions
diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp index 3167d31745e..1f7343585f8 100644 --- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp @@ -530,7 +530,7 @@ bool ValidateBlockAggs(TPositionHandle pos, const TTypeAnnotationNode::TListType return false; } - if (agg->ChildrenSize() != agg->Head().ChildrenSize()) { + if (agg->ChildrenSize() + (overState ? 1 : 0) != agg->Head().ChildrenSize()) { ctx.AddError(TIssue(ctx.GetPosition(pos), "Different amount of input arguments")); return false; } @@ -548,7 +548,7 @@ bool ValidateBlockAggs(TPositionHandle pos, const TTypeAnnotationNode::TListType return false; } - auto applyArgType = agg->Head().Child(i)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); + auto applyArgType = agg->Head().Child(i + (overState ? 1 : 0))->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); auto expectedType = many ? ctx.MakeType<TOptionalExprType>(applyArgType) : applyArgType; if (!IsSameAnnotation(*inputItems[argColumnIndex], *expectedType)) { ctx.AddError(TIssue(ctx.GetPosition(pos), TStringBuilder() << diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp index 0de23081a7d..0e259b8e31c 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp @@ -23,42 +23,6 @@ namespace { return x->GetTypeAnn() && x->GetTypeAnn()->GetKind() == ETypeAnnotationKind::EmptyList; }; - const TTypeAnnotationNode* GetOriginalResultType(TPositionHandle pos, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx) { - if (!EnsureStructType(pos, *originalExtractorType, ctx)) { - return nullptr; - } - - auto structType = originalExtractorType->Cast<TStructExprType>(); - if (structType->GetSize() != 1) { - ctx.AddError(TIssue(ctx.GetPosition(pos), - TStringBuilder() << "Expected struct with one member")); - return nullptr; - } - - auto type = structType->GetItems()[0]->GetItemType(); - if (isMany) { - if (type->GetKind() != ETypeAnnotationKind::Optional) { - ctx.AddError(TIssue(ctx.GetPosition(pos), - TStringBuilder() << "Expected optional state")); - return nullptr; - } - - type = type->Cast<TOptionalExprType>()->GetItemType(); - } - - return type; - } - - bool ApplyOriginalType(TExprNode::TPtr input, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx) { - auto type = GetOriginalResultType(input->Pos(), isMany, originalExtractorType, ctx); - if (!type) { - return false; - } - - input->SetTypeAnn(type); - return true; - } - TExprNode::TPtr RewriteMultiAggregate(const TExprNode& node, TExprContext& ctx) { auto exprLambda = node.Child(1); const TStructExprType* structType = nullptr; @@ -5439,7 +5403,7 @@ namespace { IGraphTransformer::TStatus AggBlockApplyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { Y_UNUSED(output); const bool overState = input->Content().EndsWith("State"); - if (!EnsureMinArgsCount(*input, 1, ctx.Expr)) { + if (!EnsureMinArgsCount(*input, overState ? 2 : 1, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } @@ -5447,23 +5411,38 @@ namespace { return IGraphTransformer::TStatus::Error; } - auto name = input->Child(0)->Content(); - ui32 expectedArgs; - if (name == "count_all") { - expectedArgs = overState ? 2 : 1; - } else if (name == "count" || name == "sum" || name == "avg" || name == "min" || name == "max" || name == "some") { - expectedArgs = 2; - } else { - ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), - TStringBuilder() << "Unsupported agg name: " << name)); - return IGraphTransformer::TStatus::Error; + const TTypeAnnotationNode* originalType = nullptr; + if (overState && !input->Child(1)->IsCallable("Void")) { + if (auto status = EnsureTypeRewrite(input->ChildRef(1), ctx.Expr); status != IGraphTransformer::TStatus::Ok) { + return status; + } + + originalType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); } - if (!EnsureArgsCount(*input, expectedArgs, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; + auto name = input->Child(0)->Content(); + if (!name.StartsWith("pg_")) { + ui32 expectedArgs; + if (name == "count_all") { + expectedArgs = overState ? 2 : 1; + } else if (name == "count" || name == "sum" || name == "avg" || name == "min" || name == "max" || name == "some") { + expectedArgs = 2; + } else { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Unsupported agg name: " << name)); + return IGraphTransformer::TStatus::Error; + } + + if (overState) { + ++expectedArgs; + } + + if (!EnsureArgsCount(*input, expectedArgs, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } } - for (ui32 i = 1; i < expectedArgs; ++i) { + for (ui32 i = overState ? 2 : 1; i < input->ChildrenSize(); ++i) { if (auto status = EnsureTypeRewrite(input->ChildRef(i), ctx.Expr); status != IGraphTransformer::TStatus::Ok) { return status; } @@ -5472,7 +5451,7 @@ namespace { if (name == "count_all" || name == "count") { const TTypeAnnotationNode* retType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64); if (overState) { - auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); + auto itemType = input->Child(2)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); if (!IsSameAnnotation(*itemType, *retType)) { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), TStringBuilder() << "Mismatch count type, expected: " << *itemType << ", but got: " << *retType)); @@ -5482,7 +5461,7 @@ namespace { input->SetTypeAnn(retType); } else if (name == "sum") { - auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); + auto itemType = input->Child(overState ? 2 : 1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); const TTypeAnnotationNode* retType; if (!GetSumResultType(input->Pos(), *itemType, retType, ctx.Expr)) { return IGraphTransformer::TStatus::Error; @@ -5498,7 +5477,7 @@ namespace { input->SetTypeAnn(retType); } else if (name == "avg") { - auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); + auto itemType = input->Child(overState ? 2 : 1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); const TTypeAnnotationNode* retType; if (!overState) { if (!GetAvgResultType(input->Pos(), *itemType, retType, ctx.Expr)) { @@ -5512,7 +5491,7 @@ namespace { input->SetTypeAnn(retType); } else if (name == "min" || name == "max") { - auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); + auto itemType = input->Child(overState ? 2 : 1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); const TTypeAnnotationNode* retType; if (!GetMinMaxResultType(input->Pos(), *itemType, retType, ctx.Expr)) { return IGraphTransformer::TStatus::Error; @@ -5528,9 +5507,37 @@ namespace { input->SetTypeAnn(retType); } else if (name == "some") { - auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); + auto itemType = input->Child(overState ? 2 : 1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); const TTypeAnnotationNode* retType = itemType; input->SetTypeAnn(retType); + } else if (name.StartsWith("pg_")) { + auto func = name.SubStr(3); + TVector<ui32> argTypes; + for (ui32 i = 1 + (overState ? 1 : 0); i < input->ChildrenSize(); ++i) { + argTypes.push_back(input->Child(i)->GetTypeAnn()->Cast<TTypeExprType>()->GetType()->Cast<TPgExprType>()->GetId()); + } + + const NPg::TAggregateDesc* aggDescPtr; + if (overState) { + YQL_ENSURE(argTypes.size() == 1); + if (!originalType) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Partial aggreation is not supported for: " << name)); + return IGraphTransformer::TStatus::Error; + } + + auto resultType = originalType->Cast<TPgExprType>()->GetId(); + aggDescPtr = &NPg::LookupAggregation(TString(func), argTypes[0], resultType); + } else { + aggDescPtr = &NPg::LookupAggregation(TString(func), argTypes); + } + + if (overState) { + input->SetTypeAnn(originalType); + } else { + auto stateType = NPg::LookupProc(aggDescPtr->SerializeFuncId ? aggDescPtr->SerializeFuncId : aggDescPtr->TransFuncId).ResultType; + input->SetTypeAnn(ctx.Expr.MakeType<TPgExprType>(stateType)); + } } else { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), TStringBuilder() << "Unsupported agg name: " << name)); diff --git a/ydb/library/yql/core/yql_aggregate_expander.cpp b/ydb/library/yql/core/yql_aggregate_expander.cpp index 8ac57ad21d4..a5c1cacd623 100644 --- a/ydb/library/yql/core/yql_aggregate_expander.cpp +++ b/ydb/library/yql/core/yql_aggregate_expander.cpp @@ -570,20 +570,22 @@ TExprNode::TPtr TAggregateExpander::MakeInputBlocks(const TExprNode::TPtr& strea auto trait = AggregatedColumns->Child(index)->ChildPtr(1); TVector<const TTypeAnnotationNode*> allTypes; + const TTypeAnnotationNode* originalType = nullptr; + if (overState && !trait->Child(3)->IsCallable("Void")) { + auto originalExtractorType = trait->Child(3)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); + originalType = GetOriginalResultType(trait->Pos(), many, originalExtractorType, Ctx); + YQL_ENSURE(originalType); + } + + ui32 argsCount = trait->Child(2)->ChildrenSize() - 1; if (!overState && trait->Child(0)->Content() == "count_all") { - // 0 columns - aggs.push_back(Ctx.Builder(Node->Pos()) - .List() - .Callable(0, "AggBlockApply") - .Atom(0, trait->Child(0)->Content()) - .Seal() - .Seal() - .Build()); + argsCount = 0; } - else { - // 1 column - auto root = trait->Child(2)->TailPtr(); - auto rowArg = &trait->Child(2)->Head().Head(); + + auto rowArg = &trait->Child(2)->Head().Head(); + TVector<TExprNode::TPtr> roots; + for (ui32 i = 1; i < argsCount + 1; ++i) { + auto root = trait->Child(2)->ChildPtr(i); allTypes.push_back(root->GetTypeAnn()); auto status = OptimizeExpr(root, root, [&](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr { @@ -598,17 +600,46 @@ TExprNode::TPtr TAggregateExpander::MakeInputBlocks(const TExprNode::TPtr& strea }, Ctx, TOptimizeExprSettings(&TypesCtx)); YQL_ENSURE(status.Level != IGraphTransformer::TStatus::Error); + roots.push_back(root); + } - aggs.push_back(Ctx.Builder(Node->Pos()) - .List() - .Callable(0, TString("AggBlockApply") + (overState ? "State" : "")) - .Atom(0, trait->Child(0)->Content()) - .Add(1, ExpandType(Node->Pos(), *trait->Child(2)->GetTypeAnn(), Ctx)) - .Seal() - .Atom(1, ToString(extractorRoots.size())) + aggs.push_back(Ctx.Builder(Node->Pos()) + .List() + .Callable(0, TString("AggBlockApply") + (overState ? "State" : "")) + .Atom(0, trait->Child(0)->Content()) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + if (overState) { + if (originalType) { + parent.Add(1, ExpandType(Node->Pos(), *originalType, Ctx)); + } else { + parent + .Callable(1, "Null") + .Seal(); + } + } + + return parent; + }) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + for (ui32 i = 1; i < argsCount + 1; ++i) { + parent.Add(i + (overState ? 1 : 0), ExpandType(Node->Pos(), *trait->Child(2)->Child(i)->GetTypeAnn(), Ctx)); + return parent; + } + + return parent; + }) .Seal() - .Build()); + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + for (ui32 i = 1; i < argsCount + 1; ++i) { + parent.Atom(i, ToString(extractorRoots.size() + i - 1)); + } + + return parent; + }) + .Seal() + .Build()); + for (auto root : roots) { if (many) { if (root->IsCallable("Unwrap")) { root = root->HeadPtr(); diff --git a/ydb/library/yql/core/yql_expr_type_annotation.cpp b/ydb/library/yql/core/yql_expr_type_annotation.cpp index 12135cbbf70..79ca51ece68 100644 --- a/ydb/library/yql/core/yql_expr_type_annotation.cpp +++ b/ydb/library/yql/core/yql_expr_type_annotation.cpp @@ -5569,9 +5569,15 @@ const TTypeAnnotationNode* AggApplySerializedStateType(const TExprNode::TPtr& in } else if (name.StartsWith("pg_")) { auto func = name.SubStr(3); TVector<ui32> argTypes; - bool needRetype = false; - auto status = ExtractPgTypesFromMultiLambda(input->ChildRef(2), argTypes, needRetype, ctx); - YQL_ENSURE(status == IGraphTransformer::TStatus::Ok); + if (input->Content().StartsWith("AggBlock")) { + for (ui32 i = 1; i < input->ChildrenSize(); ++i) { + argTypes.push_back(input->Child(i)->GetTypeAnn()->Cast<TTypeExprType>()->GetType()->Cast<TPgExprType>()->GetId()); + } + } else { + bool needRetype = false; + auto status = ExtractPgTypesFromMultiLambda(input->ChildRef(2), argTypes, needRetype, ctx); + YQL_ENSURE(status == IGraphTransformer::TStatus::Ok); + } const NPg::TAggregateDesc& aggDesc = NPg::LookupAggregation(TString(func), argTypes); const auto& procDesc = NPg::LookupProc(aggDesc.SerializeFuncId ? aggDesc.SerializeFuncId : aggDesc.TransFuncId); @@ -6036,5 +6042,41 @@ TExprNode::TPtr ExpandPgAggregationTraits(TPositionHandle pos, const NPg::TAggre } } +const TTypeAnnotationNode* GetOriginalResultType(TPositionHandle pos, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx) { + if (!EnsureStructType(pos, *originalExtractorType, ctx)) { + return nullptr; + } + + auto structType = originalExtractorType->Cast<TStructExprType>(); + if (structType->GetSize() != 1) { + ctx.AddError(TIssue(ctx.GetPosition(pos), + TStringBuilder() << "Expected struct with one member")); + return nullptr; + } + + auto type = structType->GetItems()[0]->GetItemType(); + if (isMany) { + if (type->GetKind() != ETypeAnnotationKind::Optional) { + ctx.AddError(TIssue(ctx.GetPosition(pos), + TStringBuilder() << "Expected optional state")); + return nullptr; + } + + type = type->Cast<TOptionalExprType>()->GetItemType(); + } + + return type; +} + +bool ApplyOriginalType(TExprNode::TPtr input, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx) { + auto type = GetOriginalResultType(input->Pos(), isMany, originalExtractorType, ctx); + if (!type) { + return false; + } + + input->SetTypeAnn(type); + return true; +} + } // NYql diff --git a/ydb/library/yql/core/yql_expr_type_annotation.h b/ydb/library/yql/core/yql_expr_type_annotation.h index 6daf85dcca4..a2b665e562b 100644 --- a/ydb/library/yql/core/yql_expr_type_annotation.h +++ b/ydb/library/yql/core/yql_expr_type_annotation.h @@ -324,4 +324,7 @@ IGraphTransformer::TStatus ExtractPgTypesFromMultiLambda(TExprNode::TPtr& lambda TExprNode::TPtr ExpandPgAggregationTraits(TPositionHandle pos, const NPg::TAggregateDesc& aggDesc, bool onWindow, const TExprNode::TPtr& lambda, const TVector<ui32>& argTypes, const TTypeAnnotationNode* itemType, TExprContext& ctx); +const TTypeAnnotationNode* GetOriginalResultType(TPositionHandle pos, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx); +bool ApplyOriginalType(TExprNode::TPtr input, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx); + } |