diff options
author | Vitaly Stoyan <vitstn@gmail.com> | 2022-06-29 15:20:42 +0300 |
---|---|---|
committer | Vitaly Stoyan <vitstn@gmail.com> | 2022-06-29 15:20:42 +0300 |
commit | 5bc762a70a35cb6aa5ca6547538736c38e7bc064 (patch) | |
tree | 758c8959d91d5c17f6a042d53d45c7b359bd6083 | |
parent | 529a9604e2a54a886109bbd03f04ccf1847540ec (diff) | |
download | ydb-5bc762a70a35cb6aa5ca6547538736c38e7bc064.tar.gz |
YQL-13966 strict equijoin
ref:0ecafb2e87a88272aec5e778164d09ece2117de7
-rw-r--r-- | ydb/library/yql/core/common_opt/yql_co_flow2.cpp | 144 | ||||
-rw-r--r-- | ydb/library/yql/core/common_opt/yql_co_pgselect.cpp | 108 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_join.cpp | 168 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_join.h | 6 |
4 files changed, 289 insertions, 137 deletions
diff --git a/ydb/library/yql/core/common_opt/yql_co_flow2.cpp b/ydb/library/yql/core/common_opt/yql_co_flow2.cpp index 0300178d89..7a67e5bc02 100644 --- a/ydb/library/yql/core/common_opt/yql_co_flow2.cpp +++ b/ydb/library/yql/core/common_opt/yql_co_flow2.cpp @@ -9,7 +9,6 @@ #include <ydb/library/yql/core/yql_type_helpers.h> #include <ydb/library/yql/utils/log/log.h> -#include <util/string/type.h> namespace NYql { namespace { @@ -114,40 +113,6 @@ TExprNode::TPtr AggregateSubsetFieldsAnalyzer(const TCoAggregate& node, TExprCon return ret; } -void GatherAndTerms(TExprNode::TPtr&& predicate, TExprNode::TListType& andTerms) { - if (!predicate->IsCallable("And")) { - andTerms.emplace_back(std::move(predicate)); - return; - } - - for (auto& child : predicate->ChildrenList()) { - GatherAndTerms(std::move(child), andTerms); - } -} - -TExprNode::TPtr FuseAndTerms(TPositionHandle position, const TExprNode::TListType& andTerms, TExprNode::TPtr exclude, TExprContext& ctx) { - TExprNode::TPtr prevAndNode = nullptr; - TNodeSet added; - for (const auto& otherAndTerm : andTerms) { - if (otherAndTerm == exclude) { - continue; - } - - if (!added.insert(otherAndTerm.Get()).second) { - continue; - } - - if (!prevAndNode) { - prevAndNode = otherAndTerm; - } - else { - prevAndNode = ctx.NewCallable(position, "And", { prevAndNode, otherAndTerm }); - } - } - - return prevAndNode; -} - TExprNode::TPtr ConstantPredicatePushdownOverEquiJoin(TExprNode::TPtr equiJoin, TExprNode::TPtr predicate, bool ordered, TExprContext& ctx) { auto lambda = ctx.Builder(predicate->Pos()) .Lambda() @@ -650,30 +615,12 @@ private: TExprContext& Ctx; }; -TExprNode::TPtr DecayCrossJoinIntoInner(TExprNode::TPtr equiJoin, TExprNode::TPtr predicate, +TExprNode::TPtr DecayCrossJoinIntoInner(TExprNode::TPtr equiJoin, const TExprNode::TPtr& predicate, const TJoinLabels& labels, ui32 index1, ui32 index2, const TExprNode& row, const THashMap<TString, TString>& backRenameMap, const TParentsMap& parentsMap, TExprContext& ctx) { YQL_ENSURE(index1 != index2); - bool withCoalesce = false; - if (predicate->IsCallable("Coalesce")) { - if (predicate->Tail().IsCallable("Bool") && IsFalse(predicate->Tail().Head().Content())) { - withCoalesce = true; - predicate = predicate->HeadPtr(); - } else { - return equiJoin; - } - } - - TExprNode::TPtr left, right; - if (predicate->IsCallable("==")) { - left = predicate->ChildPtr(0); - right = predicate->ChildPtr(1); - } else if (predicate->IsCallable("FromPg") && predicate->Head().IsCallable("PgResolvedOp") && - (predicate->Head().Head().Content() == "=")) { - left = predicate->Head().ChildPtr(2); - right = predicate->Head().ChildPtr(3); - } else { + if (!IsEquality(predicate, left, right)) { return equiJoin; } @@ -729,84 +676,6 @@ TExprNode::TPtr DecayCrossJoinIntoInner(TExprNode::TPtr equiJoin, TExprNode::TPt return ctx.ChangeChild(*equiJoin, inputsCount, std::move(newJoinTree)); } -TExprNode::TPtr PreparePredicate(TExprNode::TPtr predicate, TExprContext& ctx) { - if (!predicate->IsCallable("Or")) { - return predicate; - } - - if (predicate->ChildrenSize() == 1) { - return predicate->HeadPtr(); - } - - // try to extract common And parts from Or - TVector<TExprNode::TListType> andParts; - for (ui32 i = 0; i < predicate->ChildrenSize(); ++i) { - TExprNode::TListType res; - GatherAndTerms(predicate->ChildPtr(i), res); - andParts.emplace_back(std::move(res)); - } - - THashMap<const TExprNode*, ui32> commonParts; - for (ui32 j = 0; j < andParts[0].size(); ++j) { - commonParts[andParts[0][j].Get()] = j; - } - - for (ui32 i = 1; i < andParts.size(); ++i) { - THashSet<const TExprNode*> found; - for (ui32 j = 0; j < andParts[i].size(); ++j) { - found.insert(andParts[i][j].Get()); - } - - // remove - for (auto it = commonParts.begin(); it != commonParts.end();) { - if (found.contains(it->first)) { - ++it; - } - else { - commonParts.erase(it++); - } - } - } - - if (commonParts.size() == 0) { - return predicate; - } - - // rebuild commonParts in order of original And - TVector<ui32> idx; - for (const auto& x : commonParts) { - idx.push_back(x.second); - } - - Sort(idx); - TExprNode::TListType andArgs; - for (ui32 i : idx) { - andArgs.push_back(andParts[0][i]); - } - - TExprNode::TListType orArgs; - for (ui32 i = 0; i < andParts.size(); ++i) { - TExprNode::TListType restAndArgs; - for (ui32 j = 0; j < andParts[i].size(); ++j) { - if (commonParts.contains(andParts[i][j].Get())) { - continue; - } - - restAndArgs.push_back(andParts[i][j]); - } - - if (restAndArgs.size() >= 1) { - orArgs.push_back(ctx.NewCallable(predicate->Pos(), "And", std::move(restAndArgs))); - } - } - - if (orArgs.size() >= 1) { - andArgs.push_back(ctx.NewCallable(predicate->Pos(), "Or", std::move(orArgs))); - } - - return ctx.NewCallable(predicate->Pos(), "And", std::move(andArgs)); -} - TExprNode::TPtr FlatMapOverEquiJoin(const TCoFlatMapBase& node, TExprContext& ctx, const TParentsMap& parentsMap) { auto equiJoin = node.Input(); auto structType = equiJoin.Ref().GetTypeAnn()->Cast<TListExprType>()->GetItemType() @@ -924,7 +793,8 @@ TExprNode::TPtr FlatMapOverEquiJoin(const TCoFlatMapBase& node, TExprContext& ct predicate = PreparePredicate(predicate, ctx); TExprNode::TListType andTerms; - GatherAndTerms(std::move(predicate), andTerms); + bool isPg; + GatherAndTerms(predicate, andTerms, isPg); TExprNode::TPtr ret; TExprNode::TPtr extraPredicate; auto joinSettings = equiJoin.Ref().Child(equiJoin.Ref().ChildrenSize() - 1); @@ -951,7 +821,7 @@ TExprNode::TPtr FlatMapOverEquiJoin(const TCoFlatMapBase& node, TExprContext& ct if (inputs.size() == 0) { YQL_CLOG(DEBUG, Core) << "ConstantPredicatePushdownOverEquiJoin"; ret = ConstantPredicatePushdownOverEquiJoin(equiJoin.Ptr(), andTerm, ordered, ctx); - extraPredicate = FuseAndTerms(node.Pos(), andTerms, andTerm, ctx); + extraPredicate = FuseAndTerms(node.Pos(), andTerms, andTerm, isPg, ctx); break; } @@ -961,7 +831,7 @@ TExprNode::TPtr FlatMapOverEquiJoin(const TCoFlatMapBase& node, TExprContext& ct if (newJoin != equiJoin.Ptr()) { YQL_CLOG(DEBUG, Core) << "SingleInputPredicatePushdownOverEquiJoin"; ret = newJoin; - extraPredicate = FuseAndTerms(node.Pos(), andTerms, andTerm, ctx); + extraPredicate = FuseAndTerms(node.Pos(), andTerms, andTerm, isPg, ctx); break; } } @@ -972,7 +842,7 @@ TExprNode::TPtr FlatMapOverEquiJoin(const TCoFlatMapBase& node, TExprContext& ct if (newJoin != equiJoin.Ptr()) { YQL_CLOG(DEBUG, Core) << "DecayCrossJoinIntoInner"; ret = newJoin; - extraPredicate = FuseAndTerms(node.Pos(), andTerms, andTerm, ctx); + extraPredicate = FuseAndTerms(node.Pos(), andTerms, andTerm, isPg, ctx); break; } } diff --git a/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp b/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp index c6f351cd60..d479160001 100644 --- a/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp +++ b/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp @@ -1104,6 +1104,67 @@ bool GatherJoinInputs(const TExprNode& root, const TExprNode& row, ui32 rightInp return true; } +TExprNode::TPtr BuildEquiJoin(TPositionHandle pos, TStringBuf joinType, const TExprNode::TPtr& left, const TExprNode::TPtr& right, + const TExprNode::TListType& leftColumns, const TExprNode::TListType& rightColumns, TExprContext& ctx) { + auto join = ctx.Builder(pos) + .Callable("EquiJoin") + .List(0) + .Add(0, left) + .Atom(1, "a") + .Seal() + .List(1) + .Add(0, right) + .Atom(1, "b") + .Seal() + .List(2) + .Atom(0, to_title(TString(joinType))) + .Atom(1, "a") + .Atom(2, "b") + .List(3) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder & { + for (ui32 i = 0; i < leftColumns.size(); ++i) { + parent.Atom(2 * i, "a"); + parent.Add(2* i + 1, leftColumns[i]); + } + + return parent; + }) + .Seal() + .List(4) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder & { + for (ui32 i = 0; i < rightColumns.size(); ++i) { + parent.Atom(2 * i, "b"); + parent.Add(2 * i + 1, rightColumns[i]); + } + + return parent; + }) + .Seal() + .List(5) + .Seal() + .Seal() + .List(3) + .Seal() + .Seal() + .Build(); + + return ctx.Builder(pos) + .Callable("Map") + .Add(0, join) + .Lambda(1) + .Param("row") + .Callable("DivePrefixMembers") + .Arg(0, "row") + .List(1) + .Atom(0, "a.") + .Atom(1, "b.") + .Seal() + .Seal() + .Seal() + .Seal() + .Build(); +} + std::tuple<TVector<ui32>, TExprNode::TListType> BuildJoinGroups(TPositionHandle pos, const TExprNode::TListType& cleanedInputs, const TExprNode::TPtr& joinOps, const THashMap<TString, ui32>& memberToInput, TExprContext& ctx, TOptimizeContext& optCtx) { TVector<ui32> groupForIndex; @@ -1154,6 +1215,53 @@ std::tuple<TVector<ui32>, TExprNode::TListType> BuildJoinGroups(TPositionHandle current = BuildSingleInputPredicateJoin(pos, reverseJoinType, predicate, with, current, ctx); continue; + } else if (hasLeftInput && hasRightInput) { + auto newPredicate = PreparePredicate(predicate->TailPtr(), ctx); + TExprNode::TListType andTerms; + bool isPg; + GatherAndTerms(std::move(newPredicate), andTerms, isPg); + bool bad = false; + TExprNode::TListType leftColumns; + TExprNode::TListType rightColumns; + for (auto& andTerm : andTerms) { + TExprNode::TPtr left, right; + if (!IsEquality(andTerm, left, right)) { + bad = true; + break; + } + + bool leftOnLeft; + if (left->IsCallable("Member") && &left->Head() == &predicate->Head().Head()) { + auto inputPtr = memberToInput.FindPtr(left->Child(1)->Content()); + YQL_ENSURE(inputPtr); + leftOnLeft = (*inputPtr < inputIndex - 1); + (leftOnLeft ? leftColumns : rightColumns).push_back(left->ChildPtr(1)); + } else { + bad = true; + break; + } + + bool rightOnRight; + if (right->IsCallable("Member") && &right->Head() == &predicate->Head().Head()) { + auto inputPtr = memberToInput.FindPtr(right->Child(1)->Content()); + YQL_ENSURE(inputPtr); + rightOnRight = (*inputPtr == inputIndex - 1); + (rightOnRight ? rightColumns : leftColumns).push_back(right->ChildPtr(1)); + } else { + bad = true; + break; + } + + if (leftOnLeft != rightOnRight) { + bad = true; + break; + } + } + + if (!bad) { + current = BuildEquiJoin(pos, joinType, current, with, leftColumns, rightColumns, ctx); + continue; + } } } diff --git a/ydb/library/yql/core/yql_join.cpp b/ydb/library/yql/core/yql_join.cpp index c9324e5a51..df9421aeb2 100644 --- a/ydb/library/yql/core/yql_join.cpp +++ b/ydb/library/yql/core/yql_join.cpp @@ -4,6 +4,7 @@ #include <util/string/cast.h> #include <util/string/join.h> +#include <util/string/type.h> namespace NYql { @@ -1555,4 +1556,171 @@ TExprNode::TPtr MakeCrossJoin(TPositionHandle pos, TExprNode::TPtr left, TExprNo .Build(); } +TExprNode::TPtr PreparePredicate(TExprNode::TPtr predicate, TExprContext& ctx) { + auto originalPredicate = predicate; + bool isPg = false; + if (predicate->IsCallable("ToPg")) { + isPg = true; + predicate = predicate->ChildPtr(0); + } + + if (!predicate->IsCallable("Or")) { + return originalPredicate; + } + + if (predicate->ChildrenSize() == 1) { + return originalPredicate; + } + + // try to extract common And parts from Or + TVector<TExprNode::TListType> andParts; + for (ui32 i = 0; i < predicate->ChildrenSize(); ++i) { + TExprNode::TListType res; + bool isPg; + GatherAndTerms(predicate->ChildPtr(i), res, isPg); + YQL_ENSURE(!isPg); // direct child for Or + andParts.emplace_back(std::move(res)); + } + + THashMap<const TExprNode*, ui32> commonParts; + for (ui32 j = 0; j < andParts[0].size(); ++j) { + commonParts[andParts[0][j].Get()] = j; + } + + for (ui32 i = 1; i < andParts.size(); ++i) { + THashSet<const TExprNode*> found; + for (ui32 j = 0; j < andParts[i].size(); ++j) { + found.insert(andParts[i][j].Get()); + } + + // remove + for (auto it = commonParts.begin(); it != commonParts.end();) { + if (found.contains(it->first)) { + ++it; + } else { + commonParts.erase(it++); + } + } + } + + if (commonParts.size() == 0) { + return originalPredicate; + } + + // rebuild commonParts in order of original And + TVector<ui32> idx; + for (const auto& x : commonParts) { + idx.push_back(x.second); + } + + Sort(idx); + TExprNode::TListType andArgs; + for (ui32 i : idx) { + andArgs.push_back(andParts[0][i]); + } + + TExprNode::TListType orArgs; + for (ui32 i = 0; i < andParts.size(); ++i) { + TExprNode::TListType restAndArgs; + for (ui32 j = 0; j < andParts[i].size(); ++j) { + if (commonParts.contains(andParts[i][j].Get())) { + continue; + } + + restAndArgs.push_back(andParts[i][j]); + } + + if (restAndArgs.size() >= 1) { + orArgs.push_back(ctx.NewCallable(predicate->Pos(), "And", std::move(restAndArgs))); + } + } + + if (orArgs.size() >= 1) { + andArgs.push_back(ctx.NewCallable(predicate->Pos(), "Or", std::move(orArgs))); + } + + auto ret = ctx.NewCallable(predicate->Pos(), "And", std::move(andArgs)); + if (isPg) { + ret = ctx.NewCallable(predicate->Pos(), "ToPg", { ret }); + } + + return ret; +} + +void GatherAndTermsImpl(const TExprNode::TPtr& predicate, TExprNode::TListType& andTerms) { + if (!predicate->IsCallable("And")) { + andTerms.emplace_back(predicate); + return; + } + + for (ui32 i = 0; i < predicate->ChildrenSize(); ++i) { + GatherAndTermsImpl(predicate->ChildPtr(i), andTerms); + } +} + +void GatherAndTerms(const TExprNode::TPtr& predicate, TExprNode::TListType& andTerms, bool& isPg) { + isPg = false; + if (predicate->IsCallable("ToPg")) { + isPg = true; + GatherAndTermsImpl(predicate->HeadPtr(), andTerms); + } else { + GatherAndTermsImpl(predicate, andTerms); + } +} + +TExprNode::TPtr FuseAndTerms(TPositionHandle position, const TExprNode::TListType& andTerms, const TExprNode::TPtr& exclude, bool isPg, TExprContext& ctx) { + TExprNode::TPtr prevAndNode = nullptr; + TNodeSet added; + for (const auto& otherAndTerm : andTerms) { + if (otherAndTerm == exclude) { + continue; + } + + if (!added.insert(otherAndTerm.Get()).second) { + continue; + } + + if (!prevAndNode) { + prevAndNode = otherAndTerm; + } else { + prevAndNode = ctx.NewCallable(position, "And", { prevAndNode, otherAndTerm }); + } + } + + if (isPg) { + return ctx.NewCallable(position, "ToPg", { prevAndNode }); + } else { + return prevAndNode; + } +} + +bool IsEquality(TExprNode::TPtr predicate, TExprNode::TPtr& left, TExprNode::TPtr& right) { + if (predicate->IsCallable("Coalesce")) { + if (predicate->Tail().IsCallable("Bool") && IsFalse(predicate->Tail().Head().Content())) { + predicate = predicate->HeadPtr(); + } else { + return false; + } + } + + if (predicate->IsCallable("FromPg")) { + predicate = predicate->HeadPtr(); + } + + if (predicate->IsCallable("==")) { + left = predicate->ChildPtr(0); + right = predicate->ChildPtr(1); + return true; + } + + if (predicate->IsCallable("PgResolvedOp") && + (predicate->Head().Content() == "=")) { + left = predicate->ChildPtr(2); + right = predicate->ChildPtr(3); + return true; + } + + return false; +} + } // namespace NYql diff --git a/ydb/library/yql/core/yql_join.h b/ydb/library/yql/core/yql_join.h index 69d5f2b2d2..4c07b8ade0 100644 --- a/ydb/library/yql/core/yql_join.h +++ b/ydb/library/yql/core/yql_join.h @@ -144,4 +144,10 @@ TExprNode::TPtr MakeDictForJoin(TExprNode::TPtr&& list, bool payload, bool multi TExprNode::TPtr MakeCrossJoin(TPositionHandle pos, TExprNode::TPtr left, TExprNode::TPtr right, TExprContext& ctx); +TExprNode::TPtr PreparePredicate(TExprNode::TPtr predicate, TExprContext& ctx); +void GatherAndTerms(const TExprNode::TPtr& predicate, TExprNode::TListType& andTerms, bool& isPg); +TExprNode::TPtr FuseAndTerms(TPositionHandle position, const TExprNode::TListType& andTerms, const TExprNode::TPtr& exclude, bool isPg, TExprContext& ctx); + +bool IsEquality(TExprNode::TPtr predicate, TExprNode::TPtr& left, TExprNode::TPtr& right); + } |