diff options
author | deniskhalikov <deniskhalikov@yandex-team.com> | 2025-04-23 16:46:47 +0300 |
---|---|---|
committer | deniskhalikov <deniskhalikov@yandex-team.com> | 2025-04-23 17:11:51 +0300 |
commit | ac324b9ce470cd7bec4d18dbe5e77495e96f92b8 (patch) | |
tree | 9e62d48a6ac47f3273979629d1213a8249088dc5 | |
parent | 7c0632d935742fed09b7e3c49c5677e9bc3320b3 (diff) | |
download | ydb-ac324b9ce470cd7bec4d18dbe5e77495e96f92b8.tar.gz |
Fix FilterPushdownOverJoinOptionalSide
commit_hash:065881aad5e9f774c9709037fdfd30b5d3c77d51
-rw-r--r-- | yql/essentials/core/common_opt/yql_flatmap_over_join.cpp | 206 | ||||
-rw-r--r-- | yt/yql/tests/sql/suites/join/left_join_right_pushdown_nested_left.sql | 1 |
2 files changed, 133 insertions, 74 deletions
diff --git a/yql/essentials/core/common_opt/yql_flatmap_over_join.cpp b/yql/essentials/core/common_opt/yql_flatmap_over_join.cpp index 52107e4b311..5cadc3a2941 100644 --- a/yql/essentials/core/common_opt/yql_flatmap_over_join.cpp +++ b/yql/essentials/core/common_opt/yql_flatmap_over_join.cpp @@ -317,6 +317,24 @@ void CountLabelsInputUsage(TExprNode::TPtr joinTree, THashMap<TString, int>& cou } } +void CollectJoinLabels(TExprNode::TPtr joinTree, THashSet<TString> &labels) { + if (joinTree->IsAtom()) { + labels.emplace(joinTree->Content()); + } else { + CollectJoinLabels(joinTree->ChildPtr(1), labels); + CollectJoinLabels(joinTree->ChildPtr(2), labels); + } +} + +void DecrementCountLabelsInputUsage(TExprNode::TPtr joinTree, THashMap<TString, int>& counters) { + if (joinTree->IsAtom()) { + counters[joinTree->Content()]--; + } else { + DecrementCountLabelsInputUsage(joinTree->ChildPtr(1), counters); + DecrementCountLabelsInputUsage(joinTree->ChildPtr(2), counters); + } +} + // returns the path to join child std::pair<TExprNode::TPtr, TExprNode::TPtr> IsRightSideForLeftJoin( const TExprNode::TPtr& joinTree, const TJoinLabels& labels, ui32 inputIndex, const TExprNode::TPtr& parent = nullptr @@ -350,6 +368,55 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> IsRightSideForLeftJoin( return {nullptr, nullptr}; } +// Maps the given `labelNames` collected from join tree to `joinLabels` associated with `EquiJoin`. +TVector<std::pair<THashSet<TString>, TExprNode::TPtr>> MapLabelNamesToJoinLabels(const TVector<std::pair<THashSet<TString>, TExprNode::TPtr>>& joinLabels, + const THashSet<TString>& labelNames) { + const ui32 joinLabelSize = joinLabels.size(); + TVector<bool> taken(joinLabelSize, false); + TVector<std::pair<THashSet<TString>, TExprNode::TPtr>> result; + + // We could have a situation with multiple labels associated with one set of join keys, so we want to match it ones. + for (const auto& labelName : labelNames) { + for (ui32 i = 0; i < joinLabelSize; ++i) { + const auto& labelNamesSet = joinLabels[i].first; + if (!taken[i] && labelNamesSet.count(labelName)) { + result.push_back(joinLabels[i]); + taken[i] = true; + } + } + } + return result; +} + +// Combines labels from the given `labels` vector to one hash set. +THashSet<TString> CombineLabels(const TVector<std::pair<THashSet<TString>, TExprNode::TPtr>>& labels) { + THashSet<TString> combinedResult; + for (const auto &[labelNames, _] : labels) { + combinedResult.insert(labelNames.begin(), labelNames.end()); + } + return combinedResult; +} + +// Creates a list from the given `labels`. +TExprNode::TPtr CreateLabelList(const THashSet<TString>& labels, TExprContext& ctx, const TPositionHandle& position) { + TExprNode::TListType newKeys; + for (const auto& label : labels) { + newKeys.push_back(ctx.NewAtom(position, label)); + } + return ctx.NewList(position, std::move(newKeys)); +} + +TExprNode::TPtr RemoveJoinKeysFromElimination(const TExprNode& settings, TExprContext& ctx) { + TExprNode::TListType updated; + for (auto setting : settings.Children()) { + if (setting->ChildrenSize() == 3 && setting->Child(0)->Content() == "rename" && setting->Child(2)->Content() == "") { + continue; + } + updated.push_back(setting); + } + return ctx.NewList(settings.Pos(), std::move(updated)); +} + TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TExprNode::TPtr predicate, const TSet<TStringBuf>& usedFields, TExprNode::TPtr args, const TJoinLabels& labels, ui32 inputIndex, const TMap<TStringBuf, TVector<TStringBuf>>& renameMap, bool ordered, bool skipNulls, TExprContext& ctx, @@ -391,9 +458,21 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx } THashMap<TString, TExprNode::TPtr> equiJoinLabels; + // Stores labels as hash set and associated join input. + TVector<std::pair<THashSet<TString>, TExprNode::TPtr>> joinLabels; for (size_t i = 0; i < equiJoin->ChildrenSize() - 2; i++) { auto label = equiJoin->Child(i); - equiJoinLabels.emplace(label->Child(1)->Content(), label->ChildPtr(0)); + THashSet<TString> labelsName; + if (auto value = TMaybeNode<TCoAtom>(label->Child(1))) { + labelsName.emplace(value.Cast().Value()); + equiJoinLabels.emplace(value.Cast().Value(), label->ChildPtr(0)); + } else if (auto tuple = TMaybeNode<TCoAtomList>(label->Child(1))) { + for (const auto& value : tuple.Cast()) { + labelsName.emplace(value.Value()); + equiJoinLabels.emplace(value.Value(), label->ChildPtr(0)); + } + } + joinLabels.push_back({labelsName, label->ChildPtr(0)}); } THashMap<TString, int> joinLabelCounters; @@ -401,16 +480,24 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx auto [leftJoinTree, parentJoinPtr] = IsRightSideForLeftJoin(joinTree, labels, inputIndex); YQL_ENSURE(leftJoinTree); - joinLabelCounters[leftJoinTree->Child(1)->Content()]--; - joinLabelCounters[leftJoinTree->Child(2)->Content()]--; + // Left child of the `leftJoinTree` could be a tree, need to walk and decrement them all, the do not need be at fina EquiJoin. + DecrementCountLabelsInputUsage(leftJoinTree, joinLabelCounters); auto leftJoinSettings = equiJoin->ChildPtr(equiJoin->ChildrenSize() - 1); + auto newLeftJoinSettings = RemoveJoinKeysFromElimination(*leftJoinSettings, ctx); auto innerJoinTree = ctx.ChangeChild(*leftJoinTree, 0, ctx.NewAtom(leftJoinTree->Pos(), "Inner")); auto leftOnlyJoinTree = ctx.ChangeChild(*leftJoinTree, 0, ctx.NewAtom(leftJoinTree->Pos(), "LeftOnly")); - THashMap<TString, int> leftSideJoinLabels; - CountLabelsInputUsage(leftJoinTree->Child(1), leftSideJoinLabels); + // Collect join labels for left child of the `Left` join tree, they are used in `EquiJoin` for `Left Only` and `Inner`. + THashSet<TString> leftLabelsNoRightChild; + CollectJoinLabels(leftJoinTree->Child(1), leftLabelsNoRightChild); + auto leftJoinLabelsNoRightChild = MapLabelNamesToJoinLabels(joinLabels, leftLabelsNoRightChild); + + // Collect join labels for the full `Left` join tree, the are used list of labels associated with result of `EquiJoin`. + THashSet<TString> leftLabelsFull; + CollectJoinLabels(leftJoinTree, leftLabelsFull); + auto leftJoinLabelsFull = MapLabelNamesToJoinLabels(joinLabels, leftLabelsFull); YQL_ENSURE(leftJoinTree->Child(2)->IsAtom()); auto rightSideInput = equiJoinLabels.at(leftJoinTree->Child(2)->Content()); @@ -436,11 +523,20 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx auto innerJoin = ctx.Builder(pos) .Callable("EquiJoin") .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { - for (const auto& [labelName, _] : leftSideJoinLabels) { - parent.List(i++) - .Add(0, equiJoinLabels.at(labelName)) - .Atom(1, labelName) - .Seal(); + for (const auto& [labelNames, input] : leftJoinLabelsNoRightChild) { + if (labelNames.size() == 1) { + parent.List(i++) + .Add(0, input) + .Atom(1, *labelNames.begin()) + .Seal(); + } else { + // Create a label list if them more than 1. + auto labelList = CreateLabelList(labelNames, ctx, pos); + parent.List(i++) + .Add(0, input) + .Add(1, labelList) + .Seal(); + } } return parent; }) @@ -449,7 +545,7 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx .Atom(1, innerJoinTree->ChildRef(2)->Content()) .Seal() .Add(i++, innerJoinTree) - .Add(i++, leftJoinSettings) + .Add(i++, newLeftJoinSettings) .Seal() .Build(); @@ -459,11 +555,20 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx auto leftOnlyJoin = ctx.Builder(pos) .Callable("EquiJoin") .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { - for (const auto& [labelName, _] : leftSideJoinLabels) { - parent.List(i++) - .Add(0, equiJoinLabels.at(labelName)) - .Atom(1, labelName) - .Seal(); + for (const auto& [labelNames, input] : leftJoinLabelsNoRightChild) { + if (labelNames.size() == 1) { + parent.List(i++) + .Add(0, input) + .Atom(1, *labelNames.begin()) + .Seal(); + } else { + // Create a label list if them more than 1. + auto labelList = CreateLabelList(labelNames, ctx, pos); + parent.List(i++) + .Add(0, input) + .Add(1, labelList) + .Seal(); + } } return parent; }) @@ -472,7 +577,7 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx .Atom(1, leftOnlyJoinTree->ChildRef(2)->Content()) .Seal() .Add(i++, leftOnlyJoinTree) - .Add(i++, leftJoinSettings) + .Add(i++, newLeftJoinSettings) .Seal() .Build(); @@ -495,25 +600,6 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx return unionAll; } - THashSet <TString> joinColumns; - for (const auto& [labelName, _] : leftSideJoinLabels) { - auto tableName = labels.FindInputIndex(labelName); - YQL_ENSURE(tableName); - for (auto column : labels.Inputs[*tableName].EnumerateAllColumns()) { - joinColumns.emplace(std::move(column)); - } - } - auto rightSideTableName = labels.FindInputIndex(innerJoinTree->Child(2)->Content()); - YQL_ENSURE(rightSideTableName); - for (auto column : labels.Inputs[*rightSideTableName].EnumerateAllColumns()) { - joinColumns.emplace(std::move(column)); - } - - auto newJoinLabel = ctx.Builder(pos) - .Atom("__yql_right_side_pushdown_input_label") - .Build(); - - TExprNode::TPtr remJoinKeys; bool changedLeftSide = false; if (leftJoinTree == parentJoinPtr->ChildPtr(1)) { @@ -523,36 +609,14 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx remJoinKeys = parentJoinPtr->ChildPtr(4); } - TExprNode::TListType newKeys; - newKeys.reserve(remJoinKeys->ChildrenSize()); - - for (ui32 i = 0; i < remJoinKeys->ChildrenSize(); i += 2) { - auto table = remJoinKeys->ChildPtr(i); - auto column = remJoinKeys->ChildPtr(i + 1); - - YQL_ENSURE(table->IsAtom()); - YQL_ENSURE(column->IsAtom()); - - auto fcn = FullColumnName(table->Content(), column->Content()); - - if (joinColumns.contains(fcn)) { - newKeys.push_back(newJoinLabel); - newKeys.push_back(ctx.NewAtom(column->Pos(), fcn)); - } else { - newKeys.push_back(table); - newKeys.push_back(column); - } - } - - auto newKeysList = ctx.NewList(remJoinKeys->Pos(), std::move(newKeys)); - + auto parentJoinLabel = remJoinKeys->ChildPtr(0); auto newParentJoin = ctx.Builder(joinTree->Pos()) .List() .Add(0, parentJoinPtr->ChildPtr(0)) - .Add(1, changedLeftSide ? newJoinLabel : parentJoinPtr->ChildPtr(1)) - .Add(2, !changedLeftSide ? newJoinLabel : parentJoinPtr->ChildPtr(2)) - .Add(3, changedLeftSide ? newKeysList : parentJoinPtr->ChildPtr(3)) - .Add(4, !changedLeftSide ? newKeysList : parentJoinPtr->ChildPtr(4)) + .Add(1, changedLeftSide ? parentJoinLabel : parentJoinPtr->ChildPtr(1)) + .Add(2, !changedLeftSide ? parentJoinLabel : parentJoinPtr->ChildPtr(2)) + .Add(3, parentJoinPtr->ChildPtr(3)) + .Add(4, parentJoinPtr->ChildPtr(4)) .Add(5, parentJoinPtr->ChildPtr(5)) .Seal() .Build(); @@ -568,19 +632,13 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx } return parent; }) - .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { - for (const auto& column : joinColumns) { - parent.List(i++) - .Atom(0, "rename") - .Atom(1, FullColumnName("__yql_right_side_pushdown_input_label", column)) - .Atom(2, column) - .Seal(); - } - return parent; - }) .Seal() .Build(); + // Combine join labels from left tree and associate them with result of `EquiJoin` from above. + auto combinedLabelList = CombineLabels(leftJoinLabelsFull); + auto combinedJoinLabels = CreateLabelList(combinedLabelList, ctx, pos); + i = 0; auto newEquiJoin = ctx.Builder(pos) .Callable("EquiJoin") @@ -598,7 +656,7 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx }) .List(i++) .Add(0, unionAll) - .Add(1, newJoinLabel) + .Add(1, combinedJoinLabels) .Seal() .Add(i++, newJoinTree) .Add(i++, newJoinSettings) diff --git a/yt/yql/tests/sql/suites/join/left_join_right_pushdown_nested_left.sql b/yt/yql/tests/sql/suites/join/left_join_right_pushdown_nested_left.sql index da815ce13b7..e9370235643 100644 --- a/yt/yql/tests/sql/suites/join/left_join_right_pushdown_nested_left.sql +++ b/yt/yql/tests/sql/suites/join/left_join_right_pushdown_nested_left.sql @@ -1,4 +1,5 @@ PRAGMA FilterPushdownOverJoinOptionalSide; +PRAGMA config.flags("OptimizerFlags", "FuseEquiJoinsInputMultiLabels", "PullUpFlatMapOverJoinMultipleLabels"); SELECT t1.Key1, t1.Key2, t1.Fk1, t1.Value, t2.Key, t2.Value, t3.Value |