aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authordeniskhalikov <deniskhalikov@yandex-team.com>2025-04-23 16:46:47 +0300
committerdeniskhalikov <deniskhalikov@yandex-team.com>2025-04-23 17:11:51 +0300
commitac324b9ce470cd7bec4d18dbe5e77495e96f92b8 (patch)
tree9e62d48a6ac47f3273979629d1213a8249088dc5
parent7c0632d935742fed09b7e3c49c5677e9bc3320b3 (diff)
downloadydb-ac324b9ce470cd7bec4d18dbe5e77495e96f92b8.tar.gz
Fix FilterPushdownOverJoinOptionalSide
commit_hash:065881aad5e9f774c9709037fdfd30b5d3c77d51
-rw-r--r--yql/essentials/core/common_opt/yql_flatmap_over_join.cpp206
-rw-r--r--yt/yql/tests/sql/suites/join/left_join_right_pushdown_nested_left.sql1
2 files changed, 133 insertions, 74 deletions
diff --git a/yql/essentials/core/common_opt/yql_flatmap_over_join.cpp b/yql/essentials/core/common_opt/yql_flatmap_over_join.cpp
index 52107e4b311..5cadc3a2941 100644
--- a/yql/essentials/core/common_opt/yql_flatmap_over_join.cpp
+++ b/yql/essentials/core/common_opt/yql_flatmap_over_join.cpp
@@ -317,6 +317,24 @@ void CountLabelsInputUsage(TExprNode::TPtr joinTree, THashMap<TString, int>& cou
}
}
+void CollectJoinLabels(TExprNode::TPtr joinTree, THashSet<TString> &labels) {
+ if (joinTree->IsAtom()) {
+ labels.emplace(joinTree->Content());
+ } else {
+ CollectJoinLabels(joinTree->ChildPtr(1), labels);
+ CollectJoinLabels(joinTree->ChildPtr(2), labels);
+ }
+}
+
+void DecrementCountLabelsInputUsage(TExprNode::TPtr joinTree, THashMap<TString, int>& counters) {
+ if (joinTree->IsAtom()) {
+ counters[joinTree->Content()]--;
+ } else {
+ DecrementCountLabelsInputUsage(joinTree->ChildPtr(1), counters);
+ DecrementCountLabelsInputUsage(joinTree->ChildPtr(2), counters);
+ }
+}
+
// returns the path to join child
std::pair<TExprNode::TPtr, TExprNode::TPtr> IsRightSideForLeftJoin(
const TExprNode::TPtr& joinTree, const TJoinLabels& labels, ui32 inputIndex, const TExprNode::TPtr& parent = nullptr
@@ -350,6 +368,55 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> IsRightSideForLeftJoin(
return {nullptr, nullptr};
}
+// Maps the given `labelNames` collected from join tree to `joinLabels` associated with `EquiJoin`.
+TVector<std::pair<THashSet<TString>, TExprNode::TPtr>> MapLabelNamesToJoinLabels(const TVector<std::pair<THashSet<TString>, TExprNode::TPtr>>& joinLabels,
+ const THashSet<TString>& labelNames) {
+ const ui32 joinLabelSize = joinLabels.size();
+ TVector<bool> taken(joinLabelSize, false);
+ TVector<std::pair<THashSet<TString>, TExprNode::TPtr>> result;
+
+ // Several labels can be associated with a single join input, so each input must be matched only once.
+ for (const auto& labelName : labelNames) {
+ for (ui32 i = 0; i < joinLabelSize; ++i) {
+ const auto& labelNamesSet = joinLabels[i].first;
+ if (!taken[i] && labelNamesSet.count(labelName)) {
+ result.push_back(joinLabels[i]);
+ taken[i] = true;
+ }
+ }
+ }
+ return result;
+}
+
+// Combines labels from the given `labels` vector to one hash set.
+THashSet<TString> CombineLabels(const TVector<std::pair<THashSet<TString>, TExprNode::TPtr>>& labels) {
+ THashSet<TString> combinedResult;
+ for (const auto &[labelNames, _] : labels) {
+ combinedResult.insert(labelNames.begin(), labelNames.end());
+ }
+ return combinedResult;
+}
+
+// Creates a list from the given `labels`.
+TExprNode::TPtr CreateLabelList(const THashSet<TString>& labels, TExprContext& ctx, const TPositionHandle& position) {
+ TExprNode::TListType newKeys;
+ for (const auto& label : labels) {
+ newKeys.push_back(ctx.NewAtom(position, label));
+ }
+ return ctx.NewList(position, std::move(newKeys));
+}
+
+TExprNode::TPtr RemoveJoinKeysFromElimination(const TExprNode& settings, TExprContext& ctx) {
+ TExprNode::TListType updated;
+ for (auto setting : settings.Children()) {
+ if (setting->ChildrenSize() == 3 && setting->Child(0)->Content() == "rename" && setting->Child(2)->Content() == "") {
+ continue;
+ }
+ updated.push_back(setting);
+ }
+ return ctx.NewList(settings.Pos(), std::move(updated));
+}
+
TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TExprNode::TPtr predicate,
const TSet<TStringBuf>& usedFields, TExprNode::TPtr args, const TJoinLabels& labels,
ui32 inputIndex, const TMap<TStringBuf, TVector<TStringBuf>>& renameMap, bool ordered, bool skipNulls, TExprContext& ctx,
@@ -391,9 +458,21 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
}
THashMap<TString, TExprNode::TPtr> equiJoinLabels;
+ // Stores labels as hash set and associated join input.
+ TVector<std::pair<THashSet<TString>, TExprNode::TPtr>> joinLabels;
for (size_t i = 0; i < equiJoin->ChildrenSize() - 2; i++) {
auto label = equiJoin->Child(i);
- equiJoinLabels.emplace(label->Child(1)->Content(), label->ChildPtr(0));
+ THashSet<TString> labelsName;
+ if (auto value = TMaybeNode<TCoAtom>(label->Child(1))) {
+ labelsName.emplace(value.Cast().Value());
+ equiJoinLabels.emplace(value.Cast().Value(), label->ChildPtr(0));
+ } else if (auto tuple = TMaybeNode<TCoAtomList>(label->Child(1))) {
+ for (const auto& value : tuple.Cast()) {
+ labelsName.emplace(value.Value());
+ equiJoinLabels.emplace(value.Value(), label->ChildPtr(0));
+ }
+ }
+ joinLabels.push_back({labelsName, label->ChildPtr(0)});
}
THashMap<TString, int> joinLabelCounters;
@@ -401,16 +480,24 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
auto [leftJoinTree, parentJoinPtr] = IsRightSideForLeftJoin(joinTree, labels, inputIndex);
YQL_ENSURE(leftJoinTree);
- joinLabelCounters[leftJoinTree->Child(1)->Content()]--;
- joinLabelCounters[leftJoinTree->Child(2)->Content()]--;
+ // The left child of `leftJoinTree` can itself be a join tree, so walk it and decrement every label: they are not needed in the final EquiJoin.
+ DecrementCountLabelsInputUsage(leftJoinTree, joinLabelCounters);
auto leftJoinSettings = equiJoin->ChildPtr(equiJoin->ChildrenSize() - 1);
+ auto newLeftJoinSettings = RemoveJoinKeysFromElimination(*leftJoinSettings, ctx);
auto innerJoinTree = ctx.ChangeChild(*leftJoinTree, 0, ctx.NewAtom(leftJoinTree->Pos(), "Inner"));
auto leftOnlyJoinTree = ctx.ChangeChild(*leftJoinTree, 0, ctx.NewAtom(leftJoinTree->Pos(), "LeftOnly"));
- THashMap<TString, int> leftSideJoinLabels;
- CountLabelsInputUsage(leftJoinTree->Child(1), leftSideJoinLabels);
+ // Collect join labels for the left child of the `Left` join tree; they are used as inputs of the `LeftOnly` and `Inner` `EquiJoin`s.
+ THashSet<TString> leftLabelsNoRightChild;
+ CollectJoinLabels(leftJoinTree->Child(1), leftLabelsNoRightChild);
+ auto leftJoinLabelsNoRightChild = MapLabelNamesToJoinLabels(joinLabels, leftLabelsNoRightChild);
+
+ // Collect join labels for the full `Left` join tree; they form the list of labels associated with the result of `EquiJoin`.
+ THashSet<TString> leftLabelsFull;
+ CollectJoinLabels(leftJoinTree, leftLabelsFull);
+ auto leftJoinLabelsFull = MapLabelNamesToJoinLabels(joinLabels, leftLabelsFull);
YQL_ENSURE(leftJoinTree->Child(2)->IsAtom());
auto rightSideInput = equiJoinLabels.at(leftJoinTree->Child(2)->Content());
@@ -436,11 +523,20 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
auto innerJoin = ctx.Builder(pos)
.Callable("EquiJoin")
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
- for (const auto& [labelName, _] : leftSideJoinLabels) {
- parent.List(i++)
- .Add(0, equiJoinLabels.at(labelName))
- .Atom(1, labelName)
- .Seal();
+ for (const auto& [labelNames, input] : leftJoinLabelsNoRightChild) {
+ if (labelNames.size() == 1) {
+ parent.List(i++)
+ .Add(0, input)
+ .Atom(1, *labelNames.begin())
+ .Seal();
+ } else {
+ // Create a label list if there is more than one label.
+ auto labelList = CreateLabelList(labelNames, ctx, pos);
+ parent.List(i++)
+ .Add(0, input)
+ .Add(1, labelList)
+ .Seal();
+ }
}
return parent;
})
@@ -449,7 +545,7 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
.Atom(1, innerJoinTree->ChildRef(2)->Content())
.Seal()
.Add(i++, innerJoinTree)
- .Add(i++, leftJoinSettings)
+ .Add(i++, newLeftJoinSettings)
.Seal()
.Build();
@@ -459,11 +555,20 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
auto leftOnlyJoin = ctx.Builder(pos)
.Callable("EquiJoin")
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
- for (const auto& [labelName, _] : leftSideJoinLabels) {
- parent.List(i++)
- .Add(0, equiJoinLabels.at(labelName))
- .Atom(1, labelName)
- .Seal();
+ for (const auto& [labelNames, input] : leftJoinLabelsNoRightChild) {
+ if (labelNames.size() == 1) {
+ parent.List(i++)
+ .Add(0, input)
+ .Atom(1, *labelNames.begin())
+ .Seal();
+ } else {
+ // Create a label list if there is more than one label.
+ auto labelList = CreateLabelList(labelNames, ctx, pos);
+ parent.List(i++)
+ .Add(0, input)
+ .Add(1, labelList)
+ .Seal();
+ }
}
return parent;
})
@@ -472,7 +577,7 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
.Atom(1, leftOnlyJoinTree->ChildRef(2)->Content())
.Seal()
.Add(i++, leftOnlyJoinTree)
- .Add(i++, leftJoinSettings)
+ .Add(i++, newLeftJoinSettings)
.Seal()
.Build();
@@ -495,25 +600,6 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
return unionAll;
}
- THashSet <TString> joinColumns;
- for (const auto& [labelName, _] : leftSideJoinLabels) {
- auto tableName = labels.FindInputIndex(labelName);
- YQL_ENSURE(tableName);
- for (auto column : labels.Inputs[*tableName].EnumerateAllColumns()) {
- joinColumns.emplace(std::move(column));
- }
- }
- auto rightSideTableName = labels.FindInputIndex(innerJoinTree->Child(2)->Content());
- YQL_ENSURE(rightSideTableName);
- for (auto column : labels.Inputs[*rightSideTableName].EnumerateAllColumns()) {
- joinColumns.emplace(std::move(column));
- }
-
- auto newJoinLabel = ctx.Builder(pos)
- .Atom("__yql_right_side_pushdown_input_label")
- .Build();
-
-
TExprNode::TPtr remJoinKeys;
bool changedLeftSide = false;
if (leftJoinTree == parentJoinPtr->ChildPtr(1)) {
@@ -523,36 +609,14 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
remJoinKeys = parentJoinPtr->ChildPtr(4);
}
- TExprNode::TListType newKeys;
- newKeys.reserve(remJoinKeys->ChildrenSize());
-
- for (ui32 i = 0; i < remJoinKeys->ChildrenSize(); i += 2) {
- auto table = remJoinKeys->ChildPtr(i);
- auto column = remJoinKeys->ChildPtr(i + 1);
-
- YQL_ENSURE(table->IsAtom());
- YQL_ENSURE(column->IsAtom());
-
- auto fcn = FullColumnName(table->Content(), column->Content());
-
- if (joinColumns.contains(fcn)) {
- newKeys.push_back(newJoinLabel);
- newKeys.push_back(ctx.NewAtom(column->Pos(), fcn));
- } else {
- newKeys.push_back(table);
- newKeys.push_back(column);
- }
- }
-
- auto newKeysList = ctx.NewList(remJoinKeys->Pos(), std::move(newKeys));
-
+ auto parentJoinLabel = remJoinKeys->ChildPtr(0);
auto newParentJoin = ctx.Builder(joinTree->Pos())
.List()
.Add(0, parentJoinPtr->ChildPtr(0))
- .Add(1, changedLeftSide ? newJoinLabel : parentJoinPtr->ChildPtr(1))
- .Add(2, !changedLeftSide ? newJoinLabel : parentJoinPtr->ChildPtr(2))
- .Add(3, changedLeftSide ? newKeysList : parentJoinPtr->ChildPtr(3))
- .Add(4, !changedLeftSide ? newKeysList : parentJoinPtr->ChildPtr(4))
+ .Add(1, changedLeftSide ? parentJoinLabel : parentJoinPtr->ChildPtr(1))
+ .Add(2, !changedLeftSide ? parentJoinLabel : parentJoinPtr->ChildPtr(2))
+ .Add(3, parentJoinPtr->ChildPtr(3))
+ .Add(4, parentJoinPtr->ChildPtr(4))
.Add(5, parentJoinPtr->ChildPtr(5))
.Seal()
.Build();
@@ -568,19 +632,13 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
}
return parent;
})
- .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
- for (const auto& column : joinColumns) {
- parent.List(i++)
- .Atom(0, "rename")
- .Atom(1, FullColumnName("__yql_right_side_pushdown_input_label", column))
- .Atom(2, column)
- .Seal();
- }
- return parent;
- })
.Seal()
.Build();
+ // Combine join labels from left tree and associate them with result of `EquiJoin` from above.
+ auto combinedLabelList = CombineLabels(leftJoinLabelsFull);
+ auto combinedJoinLabels = CreateLabelList(combinedLabelList, ctx, pos);
+
i = 0;
auto newEquiJoin = ctx.Builder(pos)
.Callable("EquiJoin")
@@ -598,7 +656,7 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
})
.List(i++)
.Add(0, unionAll)
- .Add(1, newJoinLabel)
+ .Add(1, combinedJoinLabels)
.Seal()
.Add(i++, newJoinTree)
.Add(i++, newJoinSettings)
diff --git a/yt/yql/tests/sql/suites/join/left_join_right_pushdown_nested_left.sql b/yt/yql/tests/sql/suites/join/left_join_right_pushdown_nested_left.sql
index da815ce13b7..e9370235643 100644
--- a/yt/yql/tests/sql/suites/join/left_join_right_pushdown_nested_left.sql
+++ b/yt/yql/tests/sql/suites/join/left_join_right_pushdown_nested_left.sql
@@ -1,4 +1,5 @@
PRAGMA FilterPushdownOverJoinOptionalSide;
+PRAGMA config.flags("OptimizerFlags", "FuseEquiJoinsInputMultiLabels", "PullUpFlatMapOverJoinMultipleLabels");
SELECT t1.Key1, t1.Key2, t1.Fk1, t1.Value, t2.Key, t2.Value, t3.Value