diff options
author | vvvv <vvvv@ydb.tech> | 2022-12-19 20:34:44 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2022-12-19 20:34:44 +0300 |
commit | 0b60d14b135627201f472fceca644d6a7e01019f (patch) | |
tree | af0c5be2fcaf38dd2a17e3fcd134dcd233ae2e70 | |
parent | 81a58118787d9507ae39aeecaf003285bc3da5f6 (diff) | |
download | ydb-0b60d14b135627201f472fceca644d6a7e01019f.tar.gz |
support of final aggregation by keys (DQ/YT)
17 files changed, 513 insertions, 98 deletions
diff --git a/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json b/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json index 0c31c72546..8698b9d604 100644 --- a/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json +++ b/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json @@ -2260,6 +2260,15 @@ {"Index": 2, "Name": "Keys", "Type": "TCoAtomList"}, {"Index": 3, "Name": "Aggregations", "Type": "TExprList"} ] + }, + { + "Name": "TCoShuffleByKeys", + "Base": "TCoInputBase", + "Match": {"Type": "Callable", "Name": "ShuffleByKeys"}, + "Children": [ + {"Index": 1, "Name": "KeySelectorLambda", "Type": "TCoLambda"}, + {"Index": 2, "Name": "ListHandlerLambda", "Type": "TCoLambda"} + ] } ] } diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp index 20d6fae8d5..1dd2d05f8e 100644 --- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp +++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp @@ -68,13 +68,23 @@ TExprNode::TPtr Now0Arg(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnn } bool IsArgumentsOnlyLambda(const TExprNode& lambda, TVector<ui32>& argIndices) { + TNodeMap<ui32> args; + for (ui32 i = 0; i < lambda.Head().ChildrenSize(); ++i) { + args.insert(std::make_pair(lambda.Head().Child(i), i)); + } + for (ui32 i = 1; i < lambda.ChildrenSize(); ++i) { auto root = lambda.Child(i); - if (!root->IsArgument() || root->GetLambdaLevel() > 0) { + if (!root->IsArgument()) { + return false; + } + + auto it = args.find(root); + if (it == args.end()) { return false; } - argIndices.push_back(root->GetArgIndex()); + argIndices.push_back(it->second); } return true; @@ -2456,7 +2466,7 @@ TExprNode::TPtr ExpandMux(const TExprNode::TPtr& node, TExprContext& ctx) { return node; } -TExprNode::TPtr ExpandLMap(const TExprNode::TPtr& node, TExprContext& ctx) { +TExprNode::TPtr ExpandLMapOrShuffleByKeys(const TExprNode::TPtr& node, TExprContext& ctx) { YQL_CLOG(DEBUG, CorePeepHole) << "Expand " << node->Content(); return ctx.Builder(node->Pos()) .Callable("Collect") @@ -6678,8 +6688,9 @@ struct TPeepHoleRules { {"OrderedFilter", &ExpandFilter}, {"TakeWhile", &ExpandFilter<false>}, {"SkipWhile", &ExpandFilter<true>}, - {"LMap", &ExpandLMap}, - {"OrderedLMap", &ExpandLMap}, + {"LMap", &ExpandLMapOrShuffleByKeys}, + {"OrderedLMap", &ExpandLMapOrShuffleByKeys}, + {"ShuffleByKeys", &ExpandLMapOrShuffleByKeys}, {"ExpandMap", &OptimizeExpandMap}, {"MultiMap", &OptimizeMultiMap<false>}, {"OrderedMultiMap", &OptimizeMultiMap<true>}, diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp index bcf423eac7..e95cd3cd33 100644 --- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp @@ -251,8 +251,31 @@ IGraphTransformer::TStatus BlockBitCastWrapper(const TExprNode::TPtr& input, TEx return IGraphTransformer::TStatus::Ok; } -bool ValidateBlockAggs(TPositionHandle pos, const TTypeAnnotationNode::TListType inputItems, const TExprNode& aggs, - TTypeAnnotationNode::TListType& retMultiType, TExprContext& ctx) { +bool ValidateBlockKeys(TPositionHandle pos, const TTypeAnnotationNode::TListType& inputItems, + const TExprNode& keys, TTypeAnnotationNode::TListType& retMultiType, TExprContext& ctx) { + if (!EnsureTupleMinSize(keys, 1, ctx)) { + return IGraphTransformer::TStatus::Error; + } + + for (auto child : keys.Children()) { + if (!EnsureAtom(*child, ctx)) { + return false; + } + + ui32 keyColumnIndex; + if (!TryFromString(child->Content(), keyColumnIndex) || keyColumnIndex >= inputItems.size()) { + ctx.AddError(TIssue(ctx.GetPosition(pos), "Bad key column index")); + return false; + } + + retMultiType.push_back(inputItems[keyColumnIndex]); + } + + return true; +} + +bool ValidateBlockAggs(TPositionHandle pos, const TTypeAnnotationNode::TListType& inputItems, const TExprNode& aggs, + TTypeAnnotationNode::TListType& retMultiType, TExprContext& ctx, bool overState) { if (!EnsureTuple(aggs, ctx)) { return false; } @@ -262,8 +285,9 @@ bool ValidateBlockAggs(TPositionHandle pos, const TTypeAnnotationNode::TListType return false; } - if (!agg->Head().IsCallable("AggBlockApply")) { - ctx.AddError(TIssue(ctx.GetPosition(pos), "Expected AggBlockApply")); + auto expectedCallable = overState ? "AggBlockApplyState" : "AggBlockApply"; + if (!agg->Head().IsCallable(expectedCallable)) { + ctx.AddError(TIssue(ctx.GetPosition(pos), TStringBuilder() << "Expected: " << expectedCallable)); return false; } @@ -287,7 +311,8 @@ bool ValidateBlockAggs(TPositionHandle pos, const TTypeAnnotationNode::TListType } } - retMultiType.push_back(AggApplySerializedStateType(agg->HeadPtr(), ctx)); + auto retAggType = overState ? agg->HeadPtr()->GetTypeAnn() : AggApplySerializedStateType(agg->HeadPtr(), ctx); + retMultiType.push_back(retAggType); } return true; @@ -321,7 +346,7 @@ IGraphTransformer::TStatus BlockCombineAllWrapper(const TExprNode::TPtr& input, } TTypeAnnotationNode::TListType retMultiType; - if (!ValidateBlockAggs(input->Pos(), blockItemTypes, *input->Child(2), retMultiType, ctx.Expr)) { + if (!ValidateBlockAggs(input->Pos(), blockItemTypes, *input->Child(2), retMultiType, ctx.Expr, false)) { return IGraphTransformer::TStatus::Error; } @@ -362,21 +387,41 @@ IGraphTransformer::TStatus BlockCombineHashedWrapper(const TExprNode::TPtr& inpu } TTypeAnnotationNode::TListType retMultiType; - for (auto child : input->Child(2)->Children()) { - if (!EnsureAtom(*child, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; - } + if (!ValidateBlockKeys(input->Pos(), blockItemTypes, *input->Child(2), retMultiType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } - ui32 keyColumnIndex; - if (!TryFromString(child->Content(), keyColumnIndex) || keyColumnIndex >= blockItemTypes.size()) { - ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "Bad key column index")); - return IGraphTransformer::TStatus::Error; - } + if (!ValidateBlockAggs(input->Pos(), blockItemTypes, *input->Child(3), retMultiType, ctx.Expr, false)) { + return IGraphTransformer::TStatus::Error; + } + + for (auto& t : retMultiType) { + t = ctx.Expr.MakeType<TBlockExprType>(t); + } + + retMultiType.push_back(ctx.Expr.MakeType<TScalarExprType>(ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64))); + auto outputItemType = ctx.Expr.MakeType<TMultiExprType>(retMultiType); + input->SetTypeAnn(ctx.Expr.MakeType<TFlowExprType>(outputItemType)); + return IGraphTransformer::TStatus::Ok; +} - retMultiType.push_back(blockItemTypes[keyColumnIndex]); +IGraphTransformer::TStatus BlockMergeFinalizeHashedWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) { + Y_UNUSED(output); + if (!EnsureArgsCount(*input, 3U, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + TTypeAnnotationNode::TListType blockItemTypes; + if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + TTypeAnnotationNode::TListType retMultiType; + if (!ValidateBlockKeys(input->Pos(), blockItemTypes, *input->Child(1), retMultiType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; } - if (!ValidateBlockAggs(input->Pos(), blockItemTypes, *input->Child(3), retMultiType, ctx.Expr)) { + if (!ValidateBlockAggs(input->Pos(), blockItemTypes, *input->Child(2), retMultiType, ctx.Expr, true)) { return IGraphTransformer::TStatus::Error; } diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.h b/ydb/library/yql/core/type_ann/type_ann_blocks.h index b8fd6fb9b8..e461142643 100644 --- a/ydb/library/yql/core/type_ann/type_ann_blocks.h +++ b/ydb/library/yql/core/type_ann/type_ann_blocks.h @@ -16,6 +16,7 @@ namespace NTypeAnnImpl { IGraphTransformer::TStatus BlockBitCastWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); IGraphTransformer::TStatus BlockCombineAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); IGraphTransformer::TStatus BlockCombineHashedWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); + IGraphTransformer::TStatus BlockMergeFinalizeHashedWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); } // namespace NTypeAnnImpl } // namespace NYql diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp index 54931a0cfa..361f97dec5 100644 --- a/ydb/library/yql/core/type_ann/type_ann_core.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp @@ -11364,6 +11364,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot> Functions["Chain1Map"] = &Chain1MapWrapper; Functions["LMap"] = &LMapWrapper; Functions["OrderedLMap"] = &LMapWrapper; + Functions["ShuffleByKeys"] = &ShuffleByKeysWrapper; Functions["Struct"] = &StructWrapper; Functions["AddMember"] = &AddMemberWrapper; Functions["RemoveMember"] = &RemoveMemberWrapper<false>; @@ -11584,6 +11585,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot> Functions["AggApply"] = &AggApplyWrapper; Functions["AggApplyState"] = &AggApplyWrapper; Functions["AggBlockApply"] = &AggBlockApplyWrapper; + Functions["AggBlockApplyState"] = &AggBlockApplyWrapper; Functions["WinOnRows"] = &WinOnWrapper; Functions["WinOnGroups"] = &WinOnWrapper; Functions["WinOnRange"] = &WinOnWrapper; @@ -11792,6 +11794,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot> ExtFunctions["BlockBitCast"] = &BlockBitCastWrapper; ExtFunctions["BlockCombineAll"] = &BlockCombineAllWrapper; ExtFunctions["BlockCombineHashed"] = &BlockCombineHashedWrapper; + ExtFunctions["BlockMergeFinalizeHashed"] = &BlockMergeFinalizeHashedWrapper; Functions["AsRange"] = &AsRangeWrapper; Functions["RangeCreate"] = &RangeCreateWrapper; diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp index 040dbb9e04..210d5c9533 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp @@ -871,6 +871,74 @@ namespace { return IGraphTransformer::TStatus::Ok; } + IGraphTransformer::TStatus ShuffleByKeysWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { + Y_UNUSED(output); + if (!EnsureArgsCount(*input, 3, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!EnsureListType(input->Head(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + auto itemType = input->Head().GetTypeAnn()->Cast<TListExprType>()->GetItemType(); + auto& lambdaKeySelector = input->ChildRef(1); + auto status = ConvertToLambda(lambdaKeySelector, ctx.Expr, 1); + if (status.Level != IGraphTransformer::TStatus::Ok) { + return status; + } + + if (!UpdateLambdaAllArgumentsTypes(lambdaKeySelector, {itemType}, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!lambdaKeySelector->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + + auto keyType = lambdaKeySelector->GetTypeAnn(); + if (!EnsureHashableKey(lambdaKeySelector->Pos(), keyType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!EnsureEquatableKey(lambdaKeySelector->Pos(), keyType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + auto& lambdaHandler = input->ChildRef(2); + status = ConvertToLambda(lambdaHandler, ctx.Expr, 1); + if (status.Level != IGraphTransformer::TStatus::Ok) { + return status; + } + + auto handlerStreamType = ctx.Expr.MakeType<TStreamExprType>(itemType); + + if (!UpdateLambdaAllArgumentsTypes(lambdaHandler, { handlerStreamType }, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!lambdaHandler->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + + if (!EnsureSeqOrOptionalType(lambdaHandler->Pos(), *lambdaHandler->GetTypeAnn(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + auto retKind = lambdaHandler->GetTypeAnn()->GetKind(); + const TTypeAnnotationNode* retItemType; + if (retKind == ETypeAnnotationKind::List) { + retItemType = lambdaHandler->GetTypeAnn()->Cast<TListExprType>()->GetItemType(); + } else if (retKind == ETypeAnnotationKind::Optional) { + retItemType = lambdaHandler->GetTypeAnn()->Cast<TOptionalExprType>()->GetItemType(); + } else { + retItemType = lambdaHandler->GetTypeAnn()->Cast<TStreamExprType>()->GetItemType(); + } + + input->SetTypeAnn(ctx.Expr.MakeType<TListExprType>(retItemType)); + return IGraphTransformer::TStatus::Ok; + } + IGraphTransformer::TStatus FoldMapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { if (!EnsureArgsCount(*input, 3, ctx.Expr)) { return IGraphTransformer::TStatus::Error; @@ -5155,7 +5223,16 @@ namespace { } if (name == "count" || name == "count_all") { - input->SetTypeAnn(ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64)); + const TTypeAnnotationNode* retType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64); + if (overState) { + if (!IsSameAnnotation(*lambda->GetTypeAnn(), *retType)) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Mismatch count type, expected: " << *lambda->GetTypeAnn() << ", but got: " << *retType)); + return IGraphTransformer::TStatus::Error; + } + } + + input->SetTypeAnn(retType); } else if (name == "sum") { const TTypeAnnotationNode* retType; if (!GetSumResultType(input->Pos(), *lambda->GetTypeAnn(), retType, ctx.Expr)) { @@ -5178,42 +5255,8 @@ namespace { return IGraphTransformer::TStatus::Error; } } else { - auto itemType = lambda->GetTypeAnn(); - if (IsNull(*itemType)) { - retType = itemType; - } else { - bool isOptional = false; - if (itemType->GetKind() == ETypeAnnotationKind::Optional) { - isOptional = true; - itemType = itemType->Cast<TOptionalExprType>()->GetItemType(); - } - - if (!EnsureTupleTypeSize(lambda->Pos(), itemType, 2, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; - } - - auto tupleType = itemType->Cast<TTupleExprType>(); - auto sumType = tupleType->GetItems()[0]; - const TTypeAnnotationNode* sumTypeOut; - if (!GetSumResultType(input->Pos(), *sumType, sumTypeOut, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; - } - - if (!IsSameAnnotation(*sumType, *sumTypeOut)) { - ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), - TStringBuilder() << "Mismatch sum type, expected: " << *sumType << ", but got: " << *sumTypeOut)); - return IGraphTransformer::TStatus::Error; - } - - auto countType = tupleType->GetItems()[1]; - if (!EnsureSpecificDataType(lambda->Pos(), *countType, EDataSlot::Uint64, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; - } - - retType = sumType; - if (isOptional) { - retType = ctx.Expr.MakeType<TOptionalExprType>(retType); - } + if (!GetAvgResultTypeOverState(input->Pos(), *lambda->GetTypeAnn(), retType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; } } @@ -5240,6 +5283,14 @@ namespace { return IGraphTransformer::TStatus::Error; } + if (overState) { + if (!IsSameAnnotation(*lambda->GetTypeAnn(), *retType)) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Mismatch min/max type, expected: " << *lambda->GetTypeAnn() << ", but got: " << *retType)); + return IGraphTransformer::TStatus::Error; + } + } + input->SetTypeAnn(retType); } else { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), @@ -5252,6 +5303,7 @@ namespace { IGraphTransformer::TStatus AggBlockApplyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { Y_UNUSED(output); + const bool overState = input->Content().EndsWith("State"); if (!EnsureMinArgsCount(*input, 1, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } @@ -5263,7 +5315,7 @@ namespace { auto name = input->Child(0)->Content(); ui32 expectedArgs; if (name == "count_all") { - expectedArgs = 1; + expectedArgs = overState ? 2 : 1; } else if (name == "count" || name == "sum" || name == "avg" || name == "min" || name == "max") { expectedArgs = 2; } else { @@ -5283,7 +5335,17 @@ namespace { } if (name == "count_all" || name == "count") { - input->SetTypeAnn(ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64)); + const TTypeAnnotationNode* retType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64); + if (overState) { + auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); + if (!IsSameAnnotation(*itemType, *retType)) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Mismatch count type, expected: " << *itemType << ", but got: " << *retType)); + return IGraphTransformer::TStatus::Error; + } + } + + input->SetTypeAnn(retType); } else if (name == "sum") { auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); const TTypeAnnotationNode* retType; @@ -5291,12 +5353,26 @@ namespace { return IGraphTransformer::TStatus::Error; } + if (overState) { + if (!IsSameAnnotation(*itemType, *retType)) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Mismatch sum type, expected: " << *itemType << ", but got: " << *retType)); + return IGraphTransformer::TStatus::Error; + } + } + input->SetTypeAnn(retType); } else if (name == "avg") { auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); const TTypeAnnotationNode* retType; - if (!GetAvgResultType(input->Pos(), *itemType, retType, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; + if (!overState) { + if (!GetAvgResultType(input->Pos(), *itemType, retType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + } else { + if (!GetAvgResultTypeOverState(input->Pos(), *itemType, retType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } } input->SetTypeAnn(retType); @@ -5307,6 +5383,14 @@ namespace { return IGraphTransformer::TStatus::Error; } + if (overState) { + if (!IsSameAnnotation(*itemType, *retType)) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Mismatch min/max type, expected: " << *itemType << ", but got: " << *retType)); + return IGraphTransformer::TStatus::Error; + } + } + input->SetTypeAnn(retType); } else { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), diff --git a/ydb/library/yql/core/type_ann/type_ann_list.h b/ydb/library/yql/core/type_ann/type_ann_list.h index afeaa54f7b..6e04740df3 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.h +++ b/ydb/library/yql/core/type_ann/type_ann_list.h @@ -17,6 +17,7 @@ namespace NTypeAnnImpl { IGraphTransformer::TStatus MapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus MapNextWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus LMapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); + IGraphTransformer::TStatus ShuffleByKeysWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); template <bool Warn> IGraphTransformer::TStatus FlatMapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); template <bool Ordered> diff --git a/ydb/library/yql/core/yql_aggregate_expander.cpp b/ydb/library/yql/core/yql_aggregate_expander.cpp index 6ffce1fd3d..79f1dc8cc2 100644 --- a/ydb/library/yql/core/yql_aggregate_expander.cpp +++ b/ydb/library/yql/core/yql_aggregate_expander.cpp @@ -63,6 +63,13 @@ TExprNode::TPtr TAggregateExpander::ExpandAggregate() return ret; } } + + if (Suffix == "MergeFinalize") { + auto ret = TryGenerateBlockMergeFinalize(); + if (ret) { + return ret; + } + } } if (!allTraitsCollected) { @@ -492,14 +499,8 @@ TExprNode::TPtr TAggregateExpander::GetFinalAggStateExtractor(ui32 i) { .Build(); } -TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombineAllOrHashed() { - if (!TypesCtx.ArrowResolver) { - return nullptr; - } - - const bool hashed = (KeyColumns->ChildrenSize() > 0); - - auto streamArg = Ctx.NewArgument(Node->Pos(), "stream"); +TExprNode::TPtr TAggregateExpander::MakeInputBlocks(const TExprNode::TPtr& streamArg, TExprNode::TListType& keyIdxs, + TVector<TString>& outputColumns, TExprNode::TListType& aggs, bool overState) { auto flow = Ctx.NewCallable(Node->Pos(), "ToFlow", { streamArg }); TVector<TString> inputColumns; for (ui32 i = 0; i < RowType->GetSize(); ++i) { @@ -514,9 +515,6 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombineAllOrHashed() { } TExprNode::TListType extractorRoots; - TExprNode::TListType aggs; - TVector<TString> outputColumns; - TExprNode::TListType keyIdxs; TVector<const TTypeAnnotationNode*> allKeyTypes; for (ui32 index = 0; index < KeyColumns->ChildrenSize(); ++index) { auto keyName = KeyColumns->Child(index)->Content(); @@ -538,7 +536,7 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombineAllOrHashed() { for (ui32 index = 0; index < AggregatedColumns->ChildrenSize(); ++index) { auto trait = AggregatedColumns->Child(index)->ChildPtr(1); - if (trait->Child(0)->Content() == "count_all") { + if (!overState && trait->Child(0)->Content() == "count_all") { // 0 columns aggs.push_back(Ctx.Builder(Node->Pos()) .List() @@ -547,7 +545,8 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombineAllOrHashed() { .Seal() .Seal() .Build()); - } else { + } + else { // 1 column auto root = trait->Child(2)->TailPtr(); auto rowArg = &trait->Child(2)->Head().Head(); @@ -575,7 +574,7 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombineAllOrHashed() { aggs.push_back(Ctx.Builder(Node->Pos()) .List() - .Callable(0, "AggBlockApply") + .Callable(0, TString("AggBlockApply") + (overState ? "State" : "")) .Atom(0, trait->Child(0)->Content()) .Add(1, ExpandType(Node->Pos(), *trait->Child(2)->GetTypeAnn(), Ctx)) .Seal() @@ -592,6 +591,25 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombineAllOrHashed() { auto extractorLambda = Ctx.NewLambda(Node->Pos(), Ctx.NewArguments(Node->Pos(), std::move(extractorArgs)), std::move(extractorRoots)); auto mappedWideFlow = Ctx.NewCallable(Node->Pos(), "WideMap", { wideFlow, extractorLambda }); auto blocks = Ctx.NewCallable(Node->Pos(), "WideToBlocks", { mappedWideFlow }); + return blocks; +} + +TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombineAllOrHashed() { + if (!TypesCtx.ArrowResolver) { + return nullptr; + } + + const bool hashed = (KeyColumns->ChildrenSize() > 0); + + auto streamArg = Ctx.NewArgument(Node->Pos(), "stream"); + TExprNode::TListType keyIdxs; + TVector<TString> outputColumns; + TExprNode::TListType aggs; + auto blocks = MakeInputBlocks(streamArg, keyIdxs, outputColumns, aggs, false); + if (!blocks) { + return nullptr; + } + TExprNode::TPtr aggWideFlow; if (hashed) { aggWideFlow = Ctx.Builder(Node->Pos()) @@ -2234,4 +2252,65 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombine() { return TryGenerateBlockCombineAllOrHashed(); } +TExprNode::TPtr TAggregateExpander::TryGenerateBlockMergeFinalize() { + if (UsePartitionsByKeys) { + return nullptr; + } + + if (HaveSessionSetting || HaveDistinct) { + return nullptr; + } + + for (const auto& x : AggregatedColumns->Children()) { + auto trait = x->ChildPtr(1); + if (!trait->IsCallable("AggApplyState")) { + return nullptr; + } + } + + return TryGenerateBlockMergeFinalizeHashed(); +} + +TExprNode::TPtr TAggregateExpander::TryGenerateBlockMergeFinalizeHashed() { + if (!TypesCtx.ArrowResolver) { + return nullptr; + } + + if (KeyColumns->ChildrenSize() == 0) { + return nullptr; + } + + auto streamArg = Ctx.NewArgument(Node->Pos(), "stream"); + TExprNode::TListType keyIdxs; + TVector<TString> outputColumns; + TExprNode::TListType aggs; + auto blocks = MakeInputBlocks(streamArg, keyIdxs, outputColumns, aggs, true); + if (!blocks) { + return nullptr; + } + + auto aggWideFlow = Ctx.Builder(Node->Pos()) + .Callable("WideFromBlocks") + .Callable(0, "BlockMergeFinalizeHashed") + .Add(0, blocks) + .Add(1, Ctx.NewList(Node->Pos(), std::move(keyIdxs))) + .Add(2, Ctx.NewList(Node->Pos(), std::move(aggs))) + .Seal() + .Seal() + .Build(); + + auto finalFlow = MakeNarrowMap(Node->Pos(), outputColumns, aggWideFlow, Ctx); + auto root = Ctx.NewCallable(Node->Pos(), "FromFlow", { finalFlow }); + auto lambdaStream = Ctx.NewLambda(Node->Pos(), Ctx.NewArguments(Node->Pos(), { streamArg }), std::move(root)); + + auto keySelector = BuildKeySelector(Node->Pos(), *OriginalRowType, KeyColumns, Ctx); + return Ctx.Builder(Node->Pos()) + .Callable("ShuffleByKeys") + .Add(0, AggList) + .Add(1, keySelector) + .Add(2, lambdaStream) + .Seal() + .Build(); +} + } // namespace NYql diff --git a/ydb/library/yql/core/yql_aggregate_expander.h b/ydb/library/yql/core/yql_aggregate_expander.h index 63f1cc9dfc..7695c7cf6b 100644 --- a/ydb/library/yql/core/yql_aggregate_expander.h +++ b/ydb/library/yql/core/yql_aggregate_expander.h @@ -8,12 +8,13 @@ namespace NYql { class TAggregateExpander { public: - TAggregateExpander(bool allowPickle, const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, + TAggregateExpander(bool allowPickle, bool usePartitionsByKeys, const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, bool forceCompact = false, bool compactForDistinct = false, bool usePhases = false) : Node(node) , Ctx(ctx) , TypesCtx(typesCtx) , AllowPickle(allowPickle) + , UsePartitionsByKeys(usePartitionsByKeys) , ForceCompact(forceCompact) , CompactForDistinct(compactForDistinct) , UsePhases(usePhases) @@ -76,7 +77,11 @@ private: void GenerateInitForDistinct(TExprNodeBuilder& parent, ui32& ndx, const TIdxSet& indicies, const TExprNode::TPtr& distinctField); TExprNode::TPtr GenerateJustOverStates(const TExprNode::TPtr& input, const TIdxSet& indicies); TExprNode::TPtr TryGenerateBlockCombineAllOrHashed(); + TExprNode::TPtr TryGenerateBlockMergeFinalizeHashed(); TExprNode::TPtr TryGenerateBlockCombine(); + TExprNode::TPtr TryGenerateBlockMergeFinalize(); + TExprNode::TPtr MakeInputBlocks(const TExprNode::TPtr& streamArg, TExprNode::TListType& keyIdxs, + TVector<TString>& outputColumns, TExprNode::TListType& aggs, bool overState); private: static constexpr TStringBuf SessionStartMemberName = "_yql_group_session_start"; @@ -85,6 +90,7 @@ private: TExprContext& Ctx; TTypeAnnotationContext& TypesCtx; bool AllowPickle; + bool UsePartitionsByKeys; bool ForceCompact; bool CompactForDistinct; bool UsePhases; @@ -121,7 +127,7 @@ private: }; inline TExprNode::TPtr ExpandAggregatePeephole(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx) { - TAggregateExpander aggExpander(false, node, ctx, typesCtx, true); + TAggregateExpander aggExpander(false, true, node, ctx, typesCtx, true); return aggExpander.ExpandAggregate(); } diff --git a/ydb/library/yql/core/yql_expr_type_annotation.cpp b/ydb/library/yql/core/yql_expr_type_annotation.cpp index 040ff5dd49..008dcc700e 100644 --- a/ydb/library/yql/core/yql_expr_type_annotation.cpp +++ b/ydb/library/yql/core/yql_expr_type_annotation.cpp @@ -5405,10 +5405,10 @@ const TTypeAnnotationNode* AggApplySerializedStateType(const TExprNode::TPtr& in } } -bool GetSumResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx) { +bool GetSumResultType(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx) { bool isOptional; const TDataExprType* lambdaType; - if(IsDataOrOptionalOfData(&itemType, isOptional, lambdaType)) { + if(IsDataOrOptionalOfData(&inputType, isOptional, lambdaType)) { auto lambdaTypeSlot = lambdaType->GetSlot(); const TTypeAnnotationNode *sumResultType = nullptr; if (IsDataTypeSigned(lambdaTypeSlot)) { @@ -5432,28 +5432,28 @@ bool GetSumResultType(const TPositionHandle& pos, const TTypeAnnotationNode& ite retType = sumResultType; return true; - } else if (IsNull(itemType)) { + } else if (IsNull(inputType)) { retType = ctx.MakeType<TNullExprType>(); return true; } else { ctx.AddError(TIssue(ctx.GetPosition(pos), - TStringBuilder() << "Unsupported type: " << FormatType(&itemType) << ". Expected Data or Optional of Data.")); + TStringBuilder() << "Unsupported type: " << FormatType(&inputType) << ". Expected Data or Optional of Data.")); return false; } } -bool GetAvgResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx) { +bool GetAvgResultType(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx) { bool isOptional; const TDataExprType* lambdaType; - if(IsDataOrOptionalOfData(&itemType, isOptional, lambdaType)) { + if(IsDataOrOptionalOfData(&inputType, isOptional, lambdaType)) { auto lambdaTypeSlot = lambdaType->GetSlot(); const TTypeAnnotationNode *avgResultType = nullptr; if (IsDataTypeNumeric(lambdaTypeSlot)) { avgResultType = ctx.MakeType<TDataExprType>(EDataSlot::Double); } else if (IsDataTypeDecimal(lambdaTypeSlot)) { - avgResultType = &itemType; + avgResultType = &inputType; } else if (IsDataTypeInterval(lambdaTypeSlot)) { - avgResultType = &itemType; + avgResultType = &inputType; } else { ctx.AddError(TIssue(ctx.GetPosition(pos), TStringBuilder() << "Unsupported column type: " << lambdaTypeSlot)); @@ -5466,23 +5466,65 @@ bool GetAvgResultType(const TPositionHandle& pos, const TTypeAnnotationNode& ite retType = avgResultType; return true; - } else if (IsNull(itemType)) { + } else if (IsNull(inputType)) { retType = ctx.MakeType<TNullExprType>(); return true; } else { ctx.AddError(TIssue(ctx.GetPosition(pos), - TStringBuilder() << "Unsupported type: " << FormatType(&itemType) << ". Expected Data or Optional of Data.")); + TStringBuilder() << "Unsupported type: " << FormatType(&inputType) << ". Expected Data or Optional of Data.")); return false; } } -bool GetMinMaxResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx) { - if (!itemType.IsComparable()) { - ctx.AddError(TIssue(ctx.GetPosition(pos), TStringBuilder() << "Expected comparable type, but got: " << itemType)); +bool GetAvgResultTypeOverState(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx) { + if (IsNull(inputType)) { + retType = &inputType; + } else { + auto itemType = &inputType; + bool isOptional = false; + if (itemType->GetKind() == ETypeAnnotationKind::Optional) { + isOptional = true; + itemType = itemType->Cast<TOptionalExprType>()->GetItemType(); + } + + if (!EnsureTupleTypeSize(pos, itemType, 2, ctx)) { + return false; + } + + auto tupleType = itemType->Cast<TTupleExprType>(); + auto sumType = tupleType->GetItems()[0]; + const TTypeAnnotationNode* sumTypeOut; + if (!GetSumResultType(pos, *sumType, sumTypeOut, ctx)) { + return false; + } + + if (!IsSameAnnotation(*sumType, *sumTypeOut)) { + ctx.AddError(TIssue(ctx.GetPosition(pos), + TStringBuilder() << "Mismatch sum type, expected: " << *sumType << ", but got: " << *sumTypeOut)); + return false; + } + + auto countType = tupleType->GetItems()[1]; + if (!EnsureSpecificDataType(pos, *countType, EDataSlot::Uint64, ctx)) { + return false; + } + + retType = sumType; + if (isOptional) { + retType = ctx.MakeType<TOptionalExprType>(retType); + } + } + + return true; +} + +bool GetMinMaxResultType(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx) { + if (!inputType.IsComparable()) { + ctx.AddError(TIssue(ctx.GetPosition(pos), TStringBuilder() << "Expected comparable type, but got: " << inputType)); return false; } - retType = &itemType; + retType = &inputType; return true; } diff --git a/ydb/library/yql/core/yql_expr_type_annotation.h b/ydb/library/yql/core/yql_expr_type_annotation.h index 59a2ecaa35..e6a19bb723 100644 --- a/ydb/library/yql/core/yql_expr_type_annotation.h +++ b/ydb/library/yql/core/yql_expr_type_annotation.h @@ -302,8 +302,9 @@ bool EnsureBlockOrScalarType(TPositionHandle position, const TTypeAnnotationNode const TTypeAnnotationNode* GetBlockItemType(const TTypeAnnotationNode& type, bool& isScalar); const TTypeAnnotationNode* AggApplySerializedStateType(const TExprNode::TPtr& input, TExprContext& ctx); -bool GetSumResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx); -bool GetAvgResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx); -bool GetMinMaxResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx); +bool GetSumResultType(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx); +bool GetAvgResultType(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx); +bool GetAvgResultTypeOverState(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx); +bool GetMinMaxResultType(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx); } diff --git a/ydb/library/yql/dq/opt/dq_opt_log.cpp b/ydb/library/yql/dq/opt/dq_opt_log.cpp index b8c97743e0..27803e6e54 100644 --- a/ydb/library/yql/dq/opt/dq_opt_log.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_log.cpp @@ -18,7 +18,7 @@ TExprBase DqRewriteAggregate(TExprBase node, TExprContext& ctx, TTypeAnnotationC return node; } - TAggregateExpander aggExpander(true, node.Ptr(), ctx, typesCtx, false, compactForDistinct, usePhases); + TAggregateExpander aggExpander(true, false, node.Ptr(), ctx, typesCtx, false, compactForDistinct, usePhases); auto result = aggExpander.ExpandAggregate(); YQL_ENSURE(result); diff --git a/ydb/library/yql/dq/opt/dq_opt_phy.cpp b/ydb/library/yql/dq/opt/dq_opt_phy.cpp index 276ca6b748..72d8dab524 100644 --- a/ydb/library/yql/dq/opt/dq_opt_phy.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_phy.cpp @@ -926,6 +926,100 @@ TExprBase DqBuildPartitionStage(TExprBase node, TExprContext& ctx, const TParent return DqBuildPartitionsStageStub<TCoPartitionByKey>(std::move(node), ctx, parentsMap); } +TExprBase DqBuildShuffleStage(TExprBase node, TExprContext& ctx, const TParentsMap& parentsMap) { + auto shuffleInput = node.Maybe<TCoShuffleByKeys>().Input(); + if (!shuffleInput.Maybe<TDqCnUnionAll>()) { + return node; + } + + auto shuffle = node.Cast<TCoShuffleByKeys>(); + if (!IsDqPureExpr(shuffle.KeySelectorLambda()) || + !IsDqPureExpr(shuffle.ListHandlerLambda())) + { + return node; + } + + auto dqUnion = shuffle.Input().Cast<TDqCnUnionAll>(); + + if (!IsSingleConsumerConnection(dqUnion, parentsMap)) { + return node; + } + + auto keyLambda = shuffle.KeySelectorLambda(); + TVector<TExprBase> keyElements; + if (auto maybeTuple = keyLambda.Body().Maybe<TExprList>()) { + auto tuple = maybeTuple.Cast(); + for (const auto& element : tuple) { + keyElements.push_back(element); + } + } else { + keyElements.push_back(keyLambda.Body()); + } + + TVector<TCoAtom> keyColumns; + keyColumns.reserve(keyElements.size()); + for (auto& element : keyElements) { + if (!element.Maybe<TCoMember>()) { + return node; + } + + auto member = element.Cast<TCoMember>(); + if (member.Struct().Raw() != keyLambda.Args().Arg(0).Raw()) { + return node; + } + + keyColumns.push_back(member.Name()); + } + + if (keyColumns.empty()) { + return node; + } + + auto connection = Build<TDqCnHashShuffle>(ctx, node.Pos()) + .Output() + .Stage(dqUnion.Output().Stage()) + .Index(dqUnion.Output().Index()) + .Build() + .KeyColumns() + .Add(keyColumns) + .Build() + .Done(); + + TCoArgument programArg = Build<TCoArgument>(ctx, node.Pos()) + .Name("arg") + .Done(); + + TVector<TCoArgument> inputArgs; + TVector<TExprBase> inputConns; + + inputConns.push_back(connection); + inputArgs.push_back(programArg); + + auto handler = shuffle.ListHandlerLambda(); + + auto shuffleStage = Build<TDqStage>(ctx, node.Pos()) + .Inputs() + .Add(inputConns) + .Build() + .Program() + .Args(inputArgs) + .Body<TCoToStream>() + .Input<TExprApplier>() + .Apply(handler) + .With(handler.Args().Arg(0), programArg) + .Build() + .Build() + .Build() + .Settings(TDqStageSettings().BuildNode(ctx, node.Pos())) + .Done(); + + return Build<TDqCnUnionAll>(ctx, node.Pos()) + .Output() + .Stage(shuffleStage) + .Index().Build("0") + .Build() + .Done(); +} /* * Optimizer rule which handles a switch to scalar expression context for aggregation results. diff --git a/ydb/library/yql/dq/opt/dq_opt_phy.h b/ydb/library/yql/dq/opt/dq_opt_phy.h index e0b26e9429..482277bbe8 100644 --- a/ydb/library/yql/dq/opt/dq_opt_phy.h +++ b/ydb/library/yql/dq/opt/dq_opt_phy.h @@ -46,6 +46,8 @@ NNodes::TExprBase DqBuildPartitionsStage(NNodes::TExprBase node, TExprContext& c NNodes::TExprBase DqBuildPartitionStage(NNodes::TExprBase node, TExprContext& ctx, const TParentsMap& parentsMap); +NNodes::TExprBase DqBuildShuffleStage(NNodes::TExprBase node, TExprContext& ctx, const TParentsMap& parentsMap); + NNodes::TExprBase DqBuildAggregationResultStage(NNodes::TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx); diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h b/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h index 8f3744bdc3..a37b84c33b 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h +++ b/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h @@ -4,6 +4,8 @@ #include <util/generic/yexception.h> #include <vector> +#include <util/digest/city.h> + namespace NKikimr { namespace NMiniKQL { @@ -17,6 +19,7 @@ protected: explicit TRobinHoodHashBase(ui64 initialCapacity = 1u << 8) : Capacity(initialCapacity) + , SelfHash(GetSelfHash(this)) { Y_ENSURE((Capacity & (Capacity - 1)) == 0); } @@ -91,7 +94,7 @@ public: private: Y_FORCE_INLINE char* InsertImpl(TKey key, bool& isNew, ui64 capacity, TVec& data) { isNew = false; - ui64 bucket = THash()(key) & (capacity - 1); + ui64 bucket = (SelfHash ^ THash()(key)) & (capacity - 1); char* ptr = data.data() + AsDeriv().GetCellSize() * bucket; TPSLStorage distance = 0; char* returnPtr; @@ -168,6 +171,12 @@ private: ptr = (ptr == data.data() + data.size()) ? data.data() : ptr; } + static ui64 GetSelfHash(void* self) { + char buf[sizeof(void*)]; + *(void**)buf = self; + return CityHash64(buf, sizeof(buf)); + } + protected: void Init() { Allocate(Capacity, Data); @@ -195,6 +204,7 @@ private: ui64 Size = 0; ui64 Capacity; TVec Data; + const ui64 SelfHash; }; template <typename TKey, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>> diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp index 321e345287..dff429c515 100644 --- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp +++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp @@ -2453,6 +2453,28 @@ TMkqlCommonCallableCompiler::TShared::TShared() { return ctx.ProgramBuilder.BlockCombineHashed(arg, filterColumn, keys, aggs, returnType); }); + AddCallable("BlockMergeFinalizeHashed", [](const TExprNode& node, TMkqlBuildContext& ctx) { + auto arg = MkqlBuildExpr(*node.Child(0), ctx); + TVector<ui32> keys; + for (const auto& key : node.Child(1)->Children()) { + keys.push_back(FromString<ui32>(key->Content())); + } + + TVector<TAggInfo> aggs; + for (const auto& agg : node.Child(2)->Children()) { + TAggInfo info; + info.Name = TString(agg->Head().Head().Content()); + for (ui32 i = 1; i < agg->ChildrenSize(); ++i) { + info.ArgsColumns.push_back(FromString<ui32>(agg->Child(i)->Content())); + } + + aggs.push_back(info); + } + + auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + return ctx.ProgramBuilder.BlockMergeFinalizeHashed(arg, keys, aggs, returnType); + }); + AddCallable("BlockCompress", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto flow = MkqlBuildExpr(node.Head(), ctx); const auto index = FromString<ui32>(node.Child(1)->Content()); diff --git a/ydb/library/yql/providers/dq/opt/physical_optimize.cpp b/ydb/library/yql/providers/dq/opt/physical_optimize.cpp index e0c6f75cab..f89ed69a64 100644 --- a/ydb/library/yql/providers/dq/opt/physical_optimize.cpp +++ b/ydb/library/yql/providers/dq/opt/physical_optimize.cpp @@ -33,6 +33,7 @@ public: AddHandler(0, &TCoFlatMapBase::Match, HNDL(BuildFlatmapStage<false>)); AddHandler(0, &TCoCombineByKey::Match, HNDL(PushCombineToStage<false>)); AddHandler(0, &TCoPartitionsByKeys::Match, HNDL(BuildPartitionsStage)); + AddHandler(0, &TCoShuffleByKeys::Match, HNDL(BuildShuffleStage)); AddHandler(0, &TCoPartitionByKey::Match, HNDL(BuildPartitionStage)); AddHandler(0, &TCoAsList::Match, HNDL(BuildAggregationResultStage)); AddHandler(0, &TCoTopSort::Match, HNDL(BuildTopSortStage<false>)); @@ -272,6 +273,10 @@ protected: return DqBuildPartitionStage(node, ctx, *getParents()); } + TMaybeNode<TExprBase> BuildShuffleStage(TExprBase node, TExprContext& ctx, const TGetParents& getParents) { + return DqBuildShuffleStage(node, ctx, *getParents()); + } + TMaybeNode<TExprBase> BuildAggregationResultStage(TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx) { return DqBuildAggregationResultStage(node, ctx, optCtx); } |