diff options
author | a-romanov <Anton.Romanov@ydb.tech> | 2023-03-07 21:51:04 +0300 |
---|---|---|
committer | a-romanov <Anton.Romanov@ydb.tech> | 2023-03-07 21:51:04 +0300 |
commit | 242a9ac0056883cc899c7fb56c27eb44cf23d96f (patch) | |
tree | 3fbe2b0b92471995711f869c558b84cfd1a1e849 | |
parent | ee1ab05c311bc7204bd88a00e9db01331a078961 (diff) | |
download | ydb-242a9ac0056883cc899c7fb56c27eb44cf23d96f.tar.gz |
YQL-15748 Fix push cast join keys for left and right to single stage.
-rw-r--r-- | ydb/library/yql/dq/opt/dq_opt_join.cpp | 46 | ||||
-rw-r--r-- | ydb/library/yql/dq/opt/dq_opt_phy.cpp | 62 | ||||
-rw-r--r-- | ydb/library/yql/dq/opt/dq_opt_phy.h | 2 |
3 files changed, 72 insertions, 38 deletions
diff --git a/ydb/library/yql/dq/opt/dq_opt_join.cpp b/ydb/library/yql/dq/opt/dq_opt_join.cpp index 47058f2cd5..aca678f719 100644 --- a/ydb/library/yql/dq/opt/dq_opt_join.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_join.cpp @@ -904,17 +904,16 @@ TExprNode::TPtr SqueezeJoinInputToDict(TExprNode::TPtr&& input, size_t width, co using TModifyKeysList = std::vector<std::tuple<TCoAtom, TCoAtom, ui32, const TTypeAnnotationNode*>>; template<bool LeftOrRight> -bool PrepareJoinSide( - TDqConnection& connection, +TCoLambda PrepareJoinSide( + TPositionHandle pos, const std::map<std::string_view, ui32>& columns, const std::vector<TCoAtom>& keys, TModifyKeysList& remap, bool filter, TExprNode::TListType& keysList, - TExprContext& ctx, - IOptimizationContext& optCtx) { + TExprContext& ctx) { - TCoArgument inputArg{ctx.NewArgument(connection.Pos(), "flow")}; + TCoArgument inputArg{ctx.NewArgument(pos, "flow")}; auto preprocess = ctx.Builder(inputArg.Pos()) .Callable("Map") .Add(0, inputArg.Ptr()) @@ -940,7 +939,7 @@ bool PrepareJoinSide( .Arg(0, "row") .Add(1, std::get<0>(key).Ptr()) .Seal() - .Add(1, ExpandType(connection.Pos(), *std::get<const TTypeAnnotationNode*>(key), ctx)) + .Add(1, ExpandType(pos, *std::get<const TTypeAnnotationNode*>(key), ctx)) .Seal() .Seal(); } @@ -970,22 +969,15 @@ bool PrepareJoinSide( } } - const auto lambda = Build<TCoLambda>(ctx, preprocess->Pos()) - .Args({inputArg}) - .Body(std::move(preprocess)) - .Done(); - - if (const auto cn = DqPushLambdaToStageUnionAll(connection.Cast<TDqCnUnionAll>(), lambda, {}, ctx, optCtx)) - connection = cn.Cast(); - else - return false; - for (auto& key : remap) { const auto index = std::get<ui32>(key); keysList[index] = ctx.ChangeChild(*keysList[index], LeftOrRight ? TDqJoinKeyTuple::idx_LeftColumn : TDqJoinKeyTuple::idx_RightColumn, std::get<1>(key).Ptr()); } - return true; + return Build<TCoLambda>(ctx, preprocess->Pos()) + .Args({inputArg}) + .Body(std::move(preprocess)) + .Done(); } TExprNode::TPtr ReplaceJoinOnSide(TExprNode::TPtr&& input, const TTypeAnnotationNode& resutType, const std::string_view& tableName, TExprContext& ctx) { @@ -1102,16 +1094,24 @@ TExprBase DqBuildHashJoin(const TDqJoin& join, EHashJoinMode mode, TExprContext& if (!remapLeft.empty() || !remapRight.empty()) { auto joinKeys = join.JoinKeys().Ref().ChildrenList(); - auto connLeft = join.LeftInput().Cast<TDqConnection>(); - auto connRight = join.RightInput().Cast<TDqConnection>(); + auto connLeft = join.LeftInput().Cast<TDqCnUnionAll>(); + auto connRight = join.RightInput().Cast<TDqCnUnionAll>(); + + std::vector<std::pair<TDqCnUnionAll, TCoLambda>> remaps; + + if (!remapLeft.empty()) + remaps.emplace_back(connLeft, PrepareJoinSide<true>(connLeft.Pos(), leftNames, leftJoinKeys, remapLeft, filter || rightKind, joinKeys, ctx)); + + if (!remapRight.empty()) + remaps.emplace_back(connRight, PrepareJoinSide<false>(connRight.Pos(), rightNames, rightJoinKeys, remapRight, filter || leftKind, joinKeys, ctx)); + + DqPushLambdasToStagesUnionAll(remaps, ctx, optCtx); if (!remapLeft.empty()) - if (!PrepareJoinSide<true>(connLeft, leftNames, leftJoinKeys, remapLeft, filter || rightKind, joinKeys, ctx, optCtx)) - return join; + connLeft = remaps.front().first; if (!remapRight.empty()) - if (!PrepareJoinSide<false>(connRight, rightNames, rightJoinKeys, remapRight, filter || leftKind, joinKeys, ctx, optCtx)) - return join; + connRight = remaps.back().first; const auto& items = GetSeqItemType(*join.Ref().GetTypeAnn()).Cast<TStructExprType>()->GetItems(); TExprNode::TListType fields(items.size()); diff --git a/ydb/library/yql/dq/opt/dq_opt_phy.cpp b/ydb/library/yql/dq/opt/dq_opt_phy.cpp index 6d940cbb7c..65d3ad30ae 100644 --- a/ydb/library/yql/dq/opt/dq_opt_phy.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_phy.cpp @@ -546,26 +546,15 @@ TMaybeNode<TDqStage> DqPushFlatMapInnerConnectionsToStageInput(TCoFlatMapBase& f .Done(); } -} // namespace - -TMaybeNode<TDqStage> DqPushLambdaToStage(const TDqStage& stage, const TCoAtom& outputIndex, const TCoLambda& lambda, +TMaybeNode<TDqStage> DqPushLambdasToStage(const TDqStage& stage, const std::map<ui32, TCoLambda>& lambdas, const TVector<TDqConnection>& lambdaInputs, TExprContext& ctx, IOptimizationContext& optCtx) { - YQL_CLOG(TRACE, CoreDq) << "stage #" << stage.Ref().UniqueId() << ": " << PrintDqStageOnly(stage, ctx) - << ", add lambda to output #" << outputIndex.Value(); - - if (IsDqDependsOnStage(lambda, stage)) { - YQL_CLOG(TRACE, CoreDq) << "Lambda " << lambda.Ref().Dump() << " depends on stage: " << PrintDqStageOnly(stage, ctx); - return {}; - } - auto program = stage.Program(); - ui32 index = FromString<ui32>(outputIndex.Value()); ui32 branchesCount = GetStageOutputsCount(stage); TExprNode::TPtr newProgram; if (branchesCount == 1) { - newProgram = ctx.FuseLambdas(lambda.Ref(), program.Ref()); + newProgram = ctx.FuseLambdas(lambdas.at(0U).Ref(), program.Ref()); } else { auto dqReplicate = program.Body().Cast<TDqReplicate>(); @@ -584,8 +573,8 @@ TMaybeNode<TDqStage> DqPushLambdaToStage(const TDqStage& stage, const TCoAtom& o YQL_ENSURE(branchLambda.Args().Size() == 1); TExprNode::TPtr newBranchProgram; - if (index == i) { - newBranchProgram = ctx.FuseLambdas(lambda.Ref(), branchLambda.Ref()); + if (const auto it = lambdas.find(i); lambdas.cend() != it) { + newBranchProgram = ctx.FuseLambdas(it->second.Ref(), branchLambda.Ref()); } else { newBranchProgram = ctx.DeepCopyLambda(branchLambda.Ref()); } @@ -632,6 +621,23 @@ TMaybeNode<TDqStage> DqPushLambdaToStage(const TDqStage& stage, const TCoAtom& o return newStage; } +} // namespace + +TMaybeNode<TDqStage> DqPushLambdaToStage(const TDqStage& stage, const TCoAtom& outputIndex, const TCoLambda& lambda, + const TVector<TDqConnection>& lambdaInputs, TExprContext& ctx, IOptimizationContext& optCtx) +{ + YQL_CLOG(TRACE, CoreDq) << "stage #" << stage.Ref().UniqueId() << ": " << PrintDqStageOnly(stage, ctx) + << ", add lambda to output #" << outputIndex.Value(); + + if (IsDqDependsOnStage(lambda, stage)) { + YQL_CLOG(TRACE, CoreDq) << "Lambda " << lambda.Ref().Dump() << " depends on stage: " << PrintDqStageOnly(stage, ctx); + return {}; + } + + const auto index = FromString<ui32>(outputIndex.Value()); + return DqPushLambdasToStage(stage, {{index, lambda}}, lambdaInputs, ctx, optCtx); +} + TExprNode::TPtr DqBuildPushableStage(const NNodes::TDqConnection& connection, TExprContext& ctx) { auto stage = connection.Output().Stage().Cast<TDqStage>(); auto program = stage.Program(); @@ -675,6 +681,32 @@ TMaybeNode<TDqConnection> DqPushLambdaToStageUnionAll(const TDqConnection& conne return TDqConnection(ctx.ChangeChild(connection.Ref(), TDqConnection::idx_Output, output.Ptr())); } +void DqPushLambdasToStagesUnionAll(std::vector<std::pair<TDqCnUnionAll, TCoLambda>>& items, TExprContext& ctx, IOptimizationContext& optCtx) +{ + TNodeMap<std::pair<std::map<ui32, TCoLambda>, TDqStage>> map(items.size()); + + for (const auto& item: items) { + const auto& output = item.first.Output(); + const auto index = FromString<ui32>(output.Index().Value()); + const auto ins = map.emplace(output.Stage().Raw(), std::make_pair(std::map<ui32, TCoLambda>(), output.Stage().Cast<TDqStage>())).first; + ins->second.first.emplace(index, item.second); + } + + for (auto& item : map) { + item.second.second = DqPushLambdasToStage(item.second.second, item.second.first, {}, ctx, optCtx).Cast(); + } + + for (auto& item: items) { + const auto& output = item.first.Output(); + item.first = Build<TDqCnUnionAll>(ctx, item.first.Pos()) + .Output<TDqOutput>() + .Stage(map.find(output.Stage().Raw())->second.second) + .Index(output.Index()) + .Build() + .Done(); + } +} + TExprBase DqPushSkipNullMembersToStage(TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx, const TParentsMap& parentsMap, bool allowStageMultiUsage) { diff --git a/ydb/library/yql/dq/opt/dq_opt_phy.h b/ydb/library/yql/dq/opt/dq_opt_phy.h index ca5a0bd4e9..f21d70f587 100644 --- a/ydb/library/yql/dq/opt/dq_opt_phy.h +++ b/ydb/library/yql/dq/opt/dq_opt_phy.h @@ -19,6 +19,8 @@ TExprNode::TPtr DqBuildPushableStage(const NNodes::TDqConnection& connection, TE NNodes::TMaybeNode<NNodes::TDqConnection> DqPushLambdaToStageUnionAll(const NNodes::TDqConnection& connection, const NNodes::TCoLambda& lambda, const TVector<NNodes::TDqConnection>& lambdaInputs, TExprContext& ctx, IOptimizationContext& optCtx); +void DqPushLambdasToStagesUnionAll(std::vector<std::pair<NNodes::TDqCnUnionAll, NNodes::TCoLambda>>& items, TExprContext& ctx, IOptimizationContext& optCtx); + NNodes::TExprBase DqPushSkipNullMembersToStage(NNodes::TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx, const TParentsMap& parentsMap, bool allowStageMultiUsage = true); |