about summary refs log tree commit diff stats
diff options
context:
space:
mode:
author a-romanov <Anton.Romanov@ydb.tech> 2023-03-07 21:51:04 +0300
committer a-romanov <Anton.Romanov@ydb.tech> 2023-03-07 21:51:04 +0300
commit 242a9ac0056883cc899c7fb56c27eb44cf23d96f (patch)
tree 3fbe2b0b92471995711f869c558b84cfd1a1e849
parent ee1ab05c311bc7204bd88a00e9db01331a078961 (diff)
download ydb-242a9ac0056883cc899c7fb56c27eb44cf23d96f.tar.gz
YQL-15748 Fix push cast join keys for left and right to single stage.
-rw-r--r-- ydb/library/yql/dq/opt/dq_opt_join.cpp 46
-rw-r--r-- ydb/library/yql/dq/opt/dq_opt_phy.cpp 62
-rw-r--r-- ydb/library/yql/dq/opt/dq_opt_phy.h 2
3 files changed, 72 insertions, 38 deletions
diff --git a/ydb/library/yql/dq/opt/dq_opt_join.cpp b/ydb/library/yql/dq/opt/dq_opt_join.cpp
index 47058f2cd5..aca678f719 100644
--- a/ydb/library/yql/dq/opt/dq_opt_join.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_join.cpp
@@ -904,17 +904,16 @@ TExprNode::TPtr SqueezeJoinInputToDict(TExprNode::TPtr&& input, size_t width, co
using TModifyKeysList = std::vector<std::tuple<TCoAtom, TCoAtom, ui32, const TTypeAnnotationNode*>>;
template<bool LeftOrRight>
-bool PrepareJoinSide(
- TDqConnection& connection,
+TCoLambda PrepareJoinSide(
+ TPositionHandle pos,
const std::map<std::string_view, ui32>& columns,
const std::vector<TCoAtom>& keys,
TModifyKeysList& remap,
bool filter,
TExprNode::TListType& keysList,
- TExprContext& ctx,
- IOptimizationContext& optCtx) {
+ TExprContext& ctx) {
- TCoArgument inputArg{ctx.NewArgument(connection.Pos(), "flow")};
+ TCoArgument inputArg{ctx.NewArgument(pos, "flow")};
auto preprocess = ctx.Builder(inputArg.Pos())
.Callable("Map")
.Add(0, inputArg.Ptr())
@@ -940,7 +939,7 @@ bool PrepareJoinSide(
.Arg(0, "row")
.Add(1, std::get<0>(key).Ptr())
.Seal()
- .Add(1, ExpandType(connection.Pos(), *std::get<const TTypeAnnotationNode*>(key), ctx))
+ .Add(1, ExpandType(pos, *std::get<const TTypeAnnotationNode*>(key), ctx))
.Seal()
.Seal();
}
@@ -970,22 +969,15 @@ bool PrepareJoinSide(
}
}
- const auto lambda = Build<TCoLambda>(ctx, preprocess->Pos())
- .Args({inputArg})
- .Body(std::move(preprocess))
- .Done();
-
- if (const auto cn = DqPushLambdaToStageUnionAll(connection.Cast<TDqCnUnionAll>(), lambda, {}, ctx, optCtx))
- connection = cn.Cast();
- else
- return false;
-
for (auto& key : remap) {
const auto index = std::get<ui32>(key);
keysList[index] = ctx.ChangeChild(*keysList[index], LeftOrRight ? TDqJoinKeyTuple::idx_LeftColumn : TDqJoinKeyTuple::idx_RightColumn, std::get<1>(key).Ptr());
}
- return true;
+ return Build<TCoLambda>(ctx, preprocess->Pos())
+ .Args({inputArg})
+ .Body(std::move(preprocess))
+ .Done();
}
TExprNode::TPtr ReplaceJoinOnSide(TExprNode::TPtr&& input, const TTypeAnnotationNode& resutType, const std::string_view& tableName, TExprContext& ctx) {
@@ -1102,16 +1094,24 @@ TExprBase DqBuildHashJoin(const TDqJoin& join, EHashJoinMode mode, TExprContext&
if (!remapLeft.empty() || !remapRight.empty()) {
auto joinKeys = join.JoinKeys().Ref().ChildrenList();
- auto connLeft = join.LeftInput().Cast<TDqConnection>();
- auto connRight = join.RightInput().Cast<TDqConnection>();
+ auto connLeft = join.LeftInput().Cast<TDqCnUnionAll>();
+ auto connRight = join.RightInput().Cast<TDqCnUnionAll>();
+
+ std::vector<std::pair<TDqCnUnionAll, TCoLambda>> remaps;
+
+ if (!remapLeft.empty())
+ remaps.emplace_back(connLeft, PrepareJoinSide<true>(connLeft.Pos(), leftNames, leftJoinKeys, remapLeft, filter || rightKind, joinKeys, ctx));
+
+ if (!remapRight.empty())
+ remaps.emplace_back(connRight, PrepareJoinSide<false>(connRight.Pos(), rightNames, rightJoinKeys, remapRight, filter || leftKind, joinKeys, ctx));
+
+ DqPushLambdasToStagesUnionAll(remaps, ctx, optCtx);
if (!remapLeft.empty())
- if (!PrepareJoinSide<true>(connLeft, leftNames, leftJoinKeys, remapLeft, filter || rightKind, joinKeys, ctx, optCtx))
- return join;
+ connLeft = remaps.front().first;
if (!remapRight.empty())
- if (!PrepareJoinSide<false>(connRight, rightNames, rightJoinKeys, remapRight, filter || leftKind, joinKeys, ctx, optCtx))
- return join;
+ connRight = remaps.back().first;
const auto& items = GetSeqItemType(*join.Ref().GetTypeAnn()).Cast<TStructExprType>()->GetItems();
TExprNode::TListType fields(items.size());
diff --git a/ydb/library/yql/dq/opt/dq_opt_phy.cpp b/ydb/library/yql/dq/opt/dq_opt_phy.cpp
index 6d940cbb7c..65d3ad30ae 100644
--- a/ydb/library/yql/dq/opt/dq_opt_phy.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_phy.cpp
@@ -546,26 +546,15 @@ TMaybeNode<TDqStage> DqPushFlatMapInnerConnectionsToStageInput(TCoFlatMapBase& f
.Done();
}
-} // namespace
-
-TMaybeNode<TDqStage> DqPushLambdaToStage(const TDqStage& stage, const TCoAtom& outputIndex, const TCoLambda& lambda,
+TMaybeNode<TDqStage> DqPushLambdasToStage(const TDqStage& stage, const std::map<ui32, TCoLambda>& lambdas,
const TVector<TDqConnection>& lambdaInputs, TExprContext& ctx, IOptimizationContext& optCtx)
{
- YQL_CLOG(TRACE, CoreDq) << "stage #" << stage.Ref().UniqueId() << ": " << PrintDqStageOnly(stage, ctx)
- << ", add lambda to output #" << outputIndex.Value();
-
- if (IsDqDependsOnStage(lambda, stage)) {
- YQL_CLOG(TRACE, CoreDq) << "Lambda " << lambda.Ref().Dump() << " depends on stage: " << PrintDqStageOnly(stage, ctx);
- return {};
- }
-
auto program = stage.Program();
- ui32 index = FromString<ui32>(outputIndex.Value());
ui32 branchesCount = GetStageOutputsCount(stage);
TExprNode::TPtr newProgram;
if (branchesCount == 1) {
- newProgram = ctx.FuseLambdas(lambda.Ref(), program.Ref());
+ newProgram = ctx.FuseLambdas(lambdas.at(0U).Ref(), program.Ref());
} else {
auto dqReplicate = program.Body().Cast<TDqReplicate>();
@@ -584,8 +573,8 @@ TMaybeNode<TDqStage> DqPushLambdaToStage(const TDqStage& stage, const TCoAtom& o
YQL_ENSURE(branchLambda.Args().Size() == 1);
TExprNode::TPtr newBranchProgram;
- if (index == i) {
- newBranchProgram = ctx.FuseLambdas(lambda.Ref(), branchLambda.Ref());
+ if (const auto it = lambdas.find(i); lambdas.cend() != it) {
+ newBranchProgram = ctx.FuseLambdas(it->second.Ref(), branchLambda.Ref());
} else {
newBranchProgram = ctx.DeepCopyLambda(branchLambda.Ref());
}
@@ -632,6 +621,23 @@ TMaybeNode<TDqStage> DqPushLambdaToStage(const TDqStage& stage, const TCoAtom& o
return newStage;
}
+} // namespace
+
+TMaybeNode<TDqStage> DqPushLambdaToStage(const TDqStage& stage, const TCoAtom& outputIndex, const TCoLambda& lambda,
+ const TVector<TDqConnection>& lambdaInputs, TExprContext& ctx, IOptimizationContext& optCtx)
+{
+ YQL_CLOG(TRACE, CoreDq) << "stage #" << stage.Ref().UniqueId() << ": " << PrintDqStageOnly(stage, ctx)
+ << ", add lambda to output #" << outputIndex.Value();
+
+ if (IsDqDependsOnStage(lambda, stage)) {
+ YQL_CLOG(TRACE, CoreDq) << "Lambda " << lambda.Ref().Dump() << " depends on stage: " << PrintDqStageOnly(stage, ctx);
+ return {};
+ }
+
+ const auto index = FromString<ui32>(outputIndex.Value());
+ return DqPushLambdasToStage(stage, {{index, lambda}}, lambdaInputs, ctx, optCtx);
+}
+
TExprNode::TPtr DqBuildPushableStage(const NNodes::TDqConnection& connection, TExprContext& ctx) {
auto stage = connection.Output().Stage().Cast<TDqStage>();
auto program = stage.Program();
@@ -675,6 +681,32 @@ TMaybeNode<TDqConnection> DqPushLambdaToStageUnionAll(const TDqConnection& conne
return TDqConnection(ctx.ChangeChild(connection.Ref(), TDqConnection::idx_Output, output.Ptr()));
}
+void DqPushLambdasToStagesUnionAll(std::vector<std::pair<TDqCnUnionAll, TCoLambda>>& items, TExprContext& ctx, IOptimizationContext& optCtx)
+{
+ TNodeMap<std::pair<std::map<ui32, TCoLambda>, TDqStage>> map(items.size());
+
+ for (const auto& item: items) {
+ const auto& output = item.first.Output();
+ const auto index = FromString<ui32>(output.Index().Value());
+ const auto ins = map.emplace(output.Stage().Raw(), std::make_pair(std::map<ui32, TCoLambda>(), output.Stage().Cast<TDqStage>())).first;
+ ins->second.first.emplace(index, item.second);
+ }
+
+ for (auto& item : map) {
+ item.second.second = DqPushLambdasToStage(item.second.second, item.second.first, {}, ctx, optCtx).Cast();
+ }
+
+ for (auto& item: items) {
+ const auto& output = item.first.Output();
+ item.first = Build<TDqCnUnionAll>(ctx, item.first.Pos())
+ .Output<TDqOutput>()
+ .Stage(map.find(output.Stage().Raw())->second.second)
+ .Index(output.Index())
+ .Build()
+ .Done();
+ }
+}
+
TExprBase DqPushSkipNullMembersToStage(TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx,
const TParentsMap& parentsMap, bool allowStageMultiUsage)
{
diff --git a/ydb/library/yql/dq/opt/dq_opt_phy.h b/ydb/library/yql/dq/opt/dq_opt_phy.h
index ca5a0bd4e9..f21d70f587 100644
--- a/ydb/library/yql/dq/opt/dq_opt_phy.h
+++ b/ydb/library/yql/dq/opt/dq_opt_phy.h
@@ -19,6 +19,8 @@ TExprNode::TPtr DqBuildPushableStage(const NNodes::TDqConnection& connection, TE
NNodes::TMaybeNode<NNodes::TDqConnection> DqPushLambdaToStageUnionAll(const NNodes::TDqConnection& connection, const NNodes::TCoLambda& lambda,
const TVector<NNodes::TDqConnection>& lambdaInputs, TExprContext& ctx, IOptimizationContext& optCtx);
+void DqPushLambdasToStagesUnionAll(std::vector<std::pair<NNodes::TDqCnUnionAll, NNodes::TCoLambda>>& items, TExprContext& ctx, IOptimizationContext& optCtx);
+
NNodes::TExprBase DqPushSkipNullMembersToStage(NNodes::TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx,
const TParentsMap& parentsMap, bool allowStageMultiUsage = true);