aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVitaly Stoyan <vitstn@gmail.com>2022-06-29 15:20:42 +0300
committerVitaly Stoyan <vitstn@gmail.com>2022-06-29 15:20:42 +0300
commit5bc762a70a35cb6aa5ca6547538736c38e7bc064 (patch)
tree758c8959d91d5c17f6a042d53d45c7b359bd6083
parent529a9604e2a54a886109bbd03f04ccf1847540ec (diff)
downloadydb-5bc762a70a35cb6aa5ca6547538736c38e7bc064.tar.gz
YQL-13966 strict equijoin
ref:0ecafb2e87a88272aec5e778164d09ece2117de7
-rw-r--r--ydb/library/yql/core/common_opt/yql_co_flow2.cpp144
-rw-r--r--ydb/library/yql/core/common_opt/yql_co_pgselect.cpp108
-rw-r--r--ydb/library/yql/core/yql_join.cpp168
-rw-r--r--ydb/library/yql/core/yql_join.h6
4 files changed, 289 insertions, 137 deletions
diff --git a/ydb/library/yql/core/common_opt/yql_co_flow2.cpp b/ydb/library/yql/core/common_opt/yql_co_flow2.cpp
index 0300178d89..7a67e5bc02 100644
--- a/ydb/library/yql/core/common_opt/yql_co_flow2.cpp
+++ b/ydb/library/yql/core/common_opt/yql_co_flow2.cpp
@@ -9,7 +9,6 @@
#include <ydb/library/yql/core/yql_type_helpers.h>
#include <ydb/library/yql/utils/log/log.h>
-#include <util/string/type.h>
namespace NYql {
namespace {
@@ -114,40 +113,6 @@ TExprNode::TPtr AggregateSubsetFieldsAnalyzer(const TCoAggregate& node, TExprCon
return ret;
}
-void GatherAndTerms(TExprNode::TPtr&& predicate, TExprNode::TListType& andTerms) {
- if (!predicate->IsCallable("And")) {
- andTerms.emplace_back(std::move(predicate));
- return;
- }
-
- for (auto& child : predicate->ChildrenList()) {
- GatherAndTerms(std::move(child), andTerms);
- }
-}
-
-TExprNode::TPtr FuseAndTerms(TPositionHandle position, const TExprNode::TListType& andTerms, TExprNode::TPtr exclude, TExprContext& ctx) {
- TExprNode::TPtr prevAndNode = nullptr;
- TNodeSet added;
- for (const auto& otherAndTerm : andTerms) {
- if (otherAndTerm == exclude) {
- continue;
- }
-
- if (!added.insert(otherAndTerm.Get()).second) {
- continue;
- }
-
- if (!prevAndNode) {
- prevAndNode = otherAndTerm;
- }
- else {
- prevAndNode = ctx.NewCallable(position, "And", { prevAndNode, otherAndTerm });
- }
- }
-
- return prevAndNode;
-}
-
TExprNode::TPtr ConstantPredicatePushdownOverEquiJoin(TExprNode::TPtr equiJoin, TExprNode::TPtr predicate, bool ordered, TExprContext& ctx) {
auto lambda = ctx.Builder(predicate->Pos())
.Lambda()
@@ -650,30 +615,12 @@ private:
TExprContext& Ctx;
};
-TExprNode::TPtr DecayCrossJoinIntoInner(TExprNode::TPtr equiJoin, TExprNode::TPtr predicate,
+TExprNode::TPtr DecayCrossJoinIntoInner(TExprNode::TPtr equiJoin, const TExprNode::TPtr& predicate,
const TJoinLabels& labels, ui32 index1, ui32 index2, const TExprNode& row, const THashMap<TString, TString>& backRenameMap,
const TParentsMap& parentsMap, TExprContext& ctx) {
YQL_ENSURE(index1 != index2);
- bool withCoalesce = false;
- if (predicate->IsCallable("Coalesce")) {
- if (predicate->Tail().IsCallable("Bool") && IsFalse(predicate->Tail().Head().Content())) {
- withCoalesce = true;
- predicate = predicate->HeadPtr();
- } else {
- return equiJoin;
- }
- }
-
-
TExprNode::TPtr left, right;
- if (predicate->IsCallable("==")) {
- left = predicate->ChildPtr(0);
- right = predicate->ChildPtr(1);
- } else if (predicate->IsCallable("FromPg") && predicate->Head().IsCallable("PgResolvedOp") &&
- (predicate->Head().Head().Content() == "=")) {
- left = predicate->Head().ChildPtr(2);
- right = predicate->Head().ChildPtr(3);
- } else {
+ if (!IsEquality(predicate, left, right)) {
return equiJoin;
}
@@ -729,84 +676,6 @@ TExprNode::TPtr DecayCrossJoinIntoInner(TExprNode::TPtr equiJoin, TExprNode::TPt
return ctx.ChangeChild(*equiJoin, inputsCount, std::move(newJoinTree));
}
-TExprNode::TPtr PreparePredicate(TExprNode::TPtr predicate, TExprContext& ctx) {
- if (!predicate->IsCallable("Or")) {
- return predicate;
- }
-
- if (predicate->ChildrenSize() == 1) {
- return predicate->HeadPtr();
- }
-
- // try to extract common And parts from Or
- TVector<TExprNode::TListType> andParts;
- for (ui32 i = 0; i < predicate->ChildrenSize(); ++i) {
- TExprNode::TListType res;
- GatherAndTerms(predicate->ChildPtr(i), res);
- andParts.emplace_back(std::move(res));
- }
-
- THashMap<const TExprNode*, ui32> commonParts;
- for (ui32 j = 0; j < andParts[0].size(); ++j) {
- commonParts[andParts[0][j].Get()] = j;
- }
-
- for (ui32 i = 1; i < andParts.size(); ++i) {
- THashSet<const TExprNode*> found;
- for (ui32 j = 0; j < andParts[i].size(); ++j) {
- found.insert(andParts[i][j].Get());
- }
-
- // remove
- for (auto it = commonParts.begin(); it != commonParts.end();) {
- if (found.contains(it->first)) {
- ++it;
- }
- else {
- commonParts.erase(it++);
- }
- }
- }
-
- if (commonParts.size() == 0) {
- return predicate;
- }
-
- // rebuild commonParts in order of original And
- TVector<ui32> idx;
- for (const auto& x : commonParts) {
- idx.push_back(x.second);
- }
-
- Sort(idx);
- TExprNode::TListType andArgs;
- for (ui32 i : idx) {
- andArgs.push_back(andParts[0][i]);
- }
-
- TExprNode::TListType orArgs;
- for (ui32 i = 0; i < andParts.size(); ++i) {
- TExprNode::TListType restAndArgs;
- for (ui32 j = 0; j < andParts[i].size(); ++j) {
- if (commonParts.contains(andParts[i][j].Get())) {
- continue;
- }
-
- restAndArgs.push_back(andParts[i][j]);
- }
-
- if (restAndArgs.size() >= 1) {
- orArgs.push_back(ctx.NewCallable(predicate->Pos(), "And", std::move(restAndArgs)));
- }
- }
-
- if (orArgs.size() >= 1) {
- andArgs.push_back(ctx.NewCallable(predicate->Pos(), "Or", std::move(orArgs)));
- }
-
- return ctx.NewCallable(predicate->Pos(), "And", std::move(andArgs));
-}
-
TExprNode::TPtr FlatMapOverEquiJoin(const TCoFlatMapBase& node, TExprContext& ctx, const TParentsMap& parentsMap) {
auto equiJoin = node.Input();
auto structType = equiJoin.Ref().GetTypeAnn()->Cast<TListExprType>()->GetItemType()
@@ -924,7 +793,8 @@ TExprNode::TPtr FlatMapOverEquiJoin(const TCoFlatMapBase& node, TExprContext& ct
predicate = PreparePredicate(predicate, ctx);
TExprNode::TListType andTerms;
- GatherAndTerms(std::move(predicate), andTerms);
+ bool isPg;
+ GatherAndTerms(predicate, andTerms, isPg);
TExprNode::TPtr ret;
TExprNode::TPtr extraPredicate;
auto joinSettings = equiJoin.Ref().Child(equiJoin.Ref().ChildrenSize() - 1);
@@ -951,7 +821,7 @@ TExprNode::TPtr FlatMapOverEquiJoin(const TCoFlatMapBase& node, TExprContext& ct
if (inputs.size() == 0) {
YQL_CLOG(DEBUG, Core) << "ConstantPredicatePushdownOverEquiJoin";
ret = ConstantPredicatePushdownOverEquiJoin(equiJoin.Ptr(), andTerm, ordered, ctx);
- extraPredicate = FuseAndTerms(node.Pos(), andTerms, andTerm, ctx);
+ extraPredicate = FuseAndTerms(node.Pos(), andTerms, andTerm, isPg, ctx);
break;
}
@@ -961,7 +831,7 @@ TExprNode::TPtr FlatMapOverEquiJoin(const TCoFlatMapBase& node, TExprContext& ct
if (newJoin != equiJoin.Ptr()) {
YQL_CLOG(DEBUG, Core) << "SingleInputPredicatePushdownOverEquiJoin";
ret = newJoin;
- extraPredicate = FuseAndTerms(node.Pos(), andTerms, andTerm, ctx);
+ extraPredicate = FuseAndTerms(node.Pos(), andTerms, andTerm, isPg, ctx);
break;
}
}
@@ -972,7 +842,7 @@ TExprNode::TPtr FlatMapOverEquiJoin(const TCoFlatMapBase& node, TExprContext& ct
if (newJoin != equiJoin.Ptr()) {
YQL_CLOG(DEBUG, Core) << "DecayCrossJoinIntoInner";
ret = newJoin;
- extraPredicate = FuseAndTerms(node.Pos(), andTerms, andTerm, ctx);
+ extraPredicate = FuseAndTerms(node.Pos(), andTerms, andTerm, isPg, ctx);
break;
}
}
diff --git a/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp b/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp
index c6f351cd60..d479160001 100644
--- a/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp
+++ b/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp
@@ -1104,6 +1104,67 @@ bool GatherJoinInputs(const TExprNode& root, const TExprNode& row, ui32 rightInp
return true;
}
+// Builds a two-input EquiJoin over `left` (labelled "a") and `right` (labelled "b").
+// The join kind atom is `joinType` passed through to_title(); the key lists are made of
+// alternating (label, column) entries taken from `leftColumns` / `rightColumns`.
+// The join tree ends with an empty list and the callable ends with another empty list.
+// The join is then wrapped into a Map whose lambda applies DivePrefixMembers with the
+// prefixes "a." and "b." to every row.
+TExprNode::TPtr BuildEquiJoin(TPositionHandle pos, TStringBuf joinType, const TExprNode::TPtr& left, const TExprNode::TPtr& right,
+    const TExprNode::TListType& leftColumns, const TExprNode::TListType& rightColumns, TExprContext& ctx) {
+    // Emits the (label, column) pairs for one side of the join keys.
+    const auto emitKeys = [](TExprNodeBuilder& list, TStringBuf label, const TExprNode::TListType& columns) -> TExprNodeBuilder& {
+        ui32 slot = 0;
+        for (const auto& column : columns) {
+            list.Atom(slot++, label);
+            list.Add(slot++, column);
+        }
+
+        return list;
+    };
+
+    auto join = ctx.Builder(pos)
+        .Callable("EquiJoin")
+            .List(0)
+                .Add(0, left)
+                .Atom(1, "a")
+            .Seal()
+            .List(1)
+                .Add(0, right)
+                .Atom(1, "b")
+            .Seal()
+            .List(2)
+                .Atom(0, to_title(TString(joinType)))
+                .Atom(1, "a")
+                .Atom(2, "b")
+                .List(3)
+                    .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
+                        return emitKeys(parent, "a", leftColumns);
+                    })
+                .Seal()
+                .List(4)
+                    .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
+                        return emitKeys(parent, "b", rightColumns);
+                    })
+                .Seal()
+                .List(5)
+                .Seal()
+            .Seal()
+            .List(3)
+            .Seal()
+        .Seal()
+        .Build();
+
+    // Post-process every joined row with DivePrefixMembers.
+    return ctx.Builder(pos)
+        .Callable("Map")
+            .Add(0, join)
+            .Lambda(1)
+                .Param("row")
+                .Callable("DivePrefixMembers")
+                    .Arg(0, "row")
+                    .List(1)
+                        .Atom(0, "a.")
+                        .Atom(1, "b.")
+                    .Seal()
+                .Seal()
+            .Seal()
+        .Seal()
+        .Build();
+}
+
+
std::tuple<TVector<ui32>, TExprNode::TListType> BuildJoinGroups(TPositionHandle pos, const TExprNode::TListType& cleanedInputs,
const TExprNode::TPtr& joinOps, const THashMap<TString, ui32>& memberToInput, TExprContext& ctx, TOptimizeContext& optCtx) {
TVector<ui32> groupForIndex;
@@ -1154,6 +1215,53 @@ std::tuple<TVector<ui32>, TExprNode::TListType> BuildJoinGroups(TPositionHandle
current = BuildSingleInputPredicateJoin(pos, reverseJoinType, predicate, with, current, ctx);
continue;
+ } else if (hasLeftInput && hasRightInput) {
+ auto newPredicate = PreparePredicate(predicate->TailPtr(), ctx);
+ TExprNode::TListType andTerms;
+ bool isPg;
+ GatherAndTerms(std::move(newPredicate), andTerms, isPg);
+ bool bad = false;
+ TExprNode::TListType leftColumns;
+ TExprNode::TListType rightColumns;
+ for (auto& andTerm : andTerms) {
+ TExprNode::TPtr left, right;
+ if (!IsEquality(andTerm, left, right)) {
+ bad = true;
+ break;
+ }
+
+ bool leftOnLeft;
+ if (left->IsCallable("Member") && &left->Head() == &predicate->Head().Head()) {
+ auto inputPtr = memberToInput.FindPtr(left->Child(1)->Content());
+ YQL_ENSURE(inputPtr);
+ leftOnLeft = (*inputPtr < inputIndex - 1);
+ (leftOnLeft ? leftColumns : rightColumns).push_back(left->ChildPtr(1));
+ } else {
+ bad = true;
+ break;
+ }
+
+ bool rightOnRight;
+ if (right->IsCallable("Member") && &right->Head() == &predicate->Head().Head()) {
+ auto inputPtr = memberToInput.FindPtr(right->Child(1)->Content());
+ YQL_ENSURE(inputPtr);
+ rightOnRight = (*inputPtr == inputIndex - 1);
+ (rightOnRight ? rightColumns : leftColumns).push_back(right->ChildPtr(1));
+ } else {
+ bad = true;
+ break;
+ }
+
+ if (leftOnLeft != rightOnRight) {
+ bad = true;
+ break;
+ }
+ }
+
+ if (!bad) {
+ current = BuildEquiJoin(pos, joinType, current, with, leftColumns, rightColumns, ctx);
+ continue;
+ }
}
}
diff --git a/ydb/library/yql/core/yql_join.cpp b/ydb/library/yql/core/yql_join.cpp
index c9324e5a51..df9421aeb2 100644
--- a/ydb/library/yql/core/yql_join.cpp
+++ b/ydb/library/yql/core/yql_join.cpp
@@ -4,6 +4,7 @@
#include <util/string/cast.h>
#include <util/string/join.h>
+#include <util/string/type.h>
namespace NYql {
@@ -1555,4 +1556,171 @@ TExprNode::TPtr MakeCrossJoin(TPositionHandle pos, TExprNode::TPtr left, TExprNo
.Build();
}
+// Factors out the And-terms shared by every branch of a top-level Or, e.g.
+//   (A and B) or (A and C)  ->  A and (B or C)
+// A predicate wrapped in ToPg is unwrapped first and the rewritten predicate is wrapped
+// back into ToPg. The original node is returned unchanged when the (unwrapped) predicate
+// is not an Or, has fewer than two branches, or its branches share no term.
+// Terms are matched by node identity (pointer equality), not structurally.
+TExprNode::TPtr PreparePredicate(TExprNode::TPtr predicate, TExprContext& ctx) {
+    auto originalPredicate = predicate;
+    bool isPg = false;
+    if (predicate->IsCallable("ToPg")) {
+        isPg = true;
+        predicate = predicate->ChildPtr(0);
+    }
+
+    if (!predicate->IsCallable("Or")) {
+        return originalPredicate;
+    }
+
+    // Nothing to factor with a single branch; also keeps andParts[0] below in bounds.
+    if (predicate->ChildrenSize() < 2) {
+        return originalPredicate;
+    }
+
+    // try to extract common And parts from Or
+    TVector<TExprNode::TListType> andParts;
+    andParts.reserve(predicate->ChildrenSize());
+    for (ui32 i = 0; i < predicate->ChildrenSize(); ++i) {
+        TExprNode::TListType res;
+        bool branchIsPg;
+        GatherAndTerms(predicate->ChildPtr(i), res, branchIsPg);
+        YQL_ENSURE(!branchIsPg); // direct child for Or
+        andParts.emplace_back(std::move(res));
+    }
+
+    // Start with every term of the first branch (remembering its position there) and
+    // drop the ones that are missing from any other branch.
+    THashMap<const TExprNode*, ui32> commonParts;
+    for (ui32 j = 0; j < andParts[0].size(); ++j) {
+        commonParts[andParts[0][j].Get()] = j;
+    }
+
+    for (ui32 i = 1; i < andParts.size(); ++i) {
+        THashSet<const TExprNode*> found;
+        for (ui32 j = 0; j < andParts[i].size(); ++j) {
+            found.insert(andParts[i][j].Get());
+        }
+
+        // remove
+        for (auto it = commonParts.begin(); it != commonParts.end();) {
+            if (found.contains(it->first)) {
+                ++it;
+            } else {
+                commonParts.erase(it++);
+            }
+        }
+    }
+
+    if (commonParts.size() == 0) {
+        return originalPredicate;
+    }
+
+    // rebuild commonParts in order of original And
+    TVector<ui32> idx;
+    for (const auto& x : commonParts) {
+        idx.push_back(x.second);
+    }
+
+    Sort(idx);
+    TExprNode::TListType andArgs;
+    for (ui32 i : idx) {
+        andArgs.push_back(andParts[0][i]);
+    }
+
+    // Collect what is left of every branch once the common terms are removed.
+    TVector<TExprNode::TListType> residuals;
+    residuals.reserve(andParts.size());
+    bool someBranchFullyCommon = false;
+    for (const auto& branchTerms : andParts) {
+        TExprNode::TListType restAndArgs;
+        for (const auto& term : branchTerms) {
+            if (!commonParts.contains(term.Get())) {
+                restAndArgs.push_back(term);
+            }
+        }
+
+        someBranchFullyCommon = someBranchFullyCommon || restAndArgs.empty();
+        residuals.emplace_back(std::move(restAndArgs));
+    }
+
+    // A branch made only of common terms absorbs all the others:
+    //   A or (A and B) == A
+    // so the residual Or must be omitted entirely in that case. Dropping just the empty
+    // branch would turn the predicate into (A and B), which is stricter than the original.
+    if (!someBranchFullyCommon) {
+        TExprNode::TListType orArgs;
+        for (auto& restAndArgs : residuals) {
+            orArgs.push_back(ctx.NewCallable(predicate->Pos(), "And", std::move(restAndArgs)));
+        }
+
+        andArgs.push_back(ctx.NewCallable(predicate->Pos(), "Or", std::move(orArgs)));
+    }
+
+    auto ret = ctx.NewCallable(predicate->Pos(), "And", std::move(andArgs));
+    if (isPg) {
+        ret = ctx.NewCallable(predicate->Pos(), "ToPg", { ret });
+    }
+
+    return ret;
+}
+
+
+// Appends the leaves of a (possibly nested) And-tree to `andTerms`, visiting children in
+// order. A node that is not an And callable is appended as a single term.
+void GatherAndTermsImpl(const TExprNode::TPtr& predicate, TExprNode::TListType& andTerms) {
+    if (predicate->IsCallable("And")) {
+        const ui32 childCount = predicate->ChildrenSize();
+        for (ui32 index = 0; index != childCount; ++index) {
+            GatherAndTermsImpl(predicate->ChildPtr(index), andTerms);
+        }
+
+        return;
+    }
+
+    andTerms.push_back(predicate);
+}
+
+
+// Flattens `predicate` into its And-terms. `isPg` is set to true when the predicate is a
+// ToPg callable, in which case the terms are gathered from its first child; otherwise it
+// is set to false and the predicate itself is flattened.
+void GatherAndTerms(const TExprNode::TPtr& predicate, TExprNode::TListType& andTerms, bool& isPg) {
+    isPg = predicate->IsCallable("ToPg");
+    if (!isPg) {
+        GatherAndTermsImpl(predicate, andTerms);
+        return;
+    }
+
+    GatherAndTermsImpl(predicate->HeadPtr(), andTerms);
+}
+
+
+// Rebuilds a left-nested chain of binary And callables from `andTerms`, skipping the
+// `exclude` term and any term that was already added (terms are compared by pointer).
+// When `isPg` is set the resulting chain is wrapped into ToPg.
+// Returns a null pointer when no term remains, for both the pg and the non-pg case.
+TExprNode::TPtr FuseAndTerms(TPositionHandle position, const TExprNode::TListType& andTerms, const TExprNode::TPtr& exclude, bool isPg, TExprContext& ctx) {
+    TExprNode::TPtr prevAndNode = nullptr;
+    TNodeSet added;
+    for (const auto& otherAndTerm : andTerms) {
+        if (otherAndTerm == exclude) {
+            continue;
+        }
+
+        if (!added.insert(otherAndTerm.Get()).second) {
+            continue;
+        }
+
+        if (!prevAndNode) {
+            prevAndNode = otherAndTerm;
+        } else {
+            prevAndNode = ctx.NewCallable(position, "And", { prevAndNode, otherAndTerm });
+        }
+    }
+
+    // Nothing left besides the excluded term: do not build ToPg around a null child,
+    // report "no extra predicate" exactly like the non-pg path does.
+    if (!prevAndNode) {
+        return prevAndNode;
+    }
+
+    if (isPg) {
+        return ctx.NewCallable(position, "ToPg", { prevAndNode });
+    } else {
+        return prevAndNode;
+    }
+}
+
+
+// Recognizes an equality comparison and outputs its two operands.
+// Accepted shapes, after optionally peeling Coalesce(x, Bool(<false>)) and then FromPg(x):
+//   - "==" callable: operands are children 0 and 1;
+//   - PgResolvedOp whose first child has content "=": operands are children 2 and 3.
+// Returns false (leaving `left` / `right` untouched) for anything else, including a
+// Coalesce whose last argument is not a false Bool literal.
+bool IsEquality(TExprNode::TPtr predicate, TExprNode::TPtr& left, TExprNode::TPtr& right) {
+    // `predicate` keeps the tree alive while we walk it through a raw pointer.
+    const TExprNode* node = predicate.Get();
+
+    if (node->IsCallable("Coalesce")) {
+        const TExprNode& fallback = node->Tail();
+        const bool defaultsToFalse = fallback.IsCallable("Bool") && IsFalse(fallback.Head().Content());
+        if (!defaultsToFalse) {
+            return false;
+        }
+
+        node = &node->Head();
+    }
+
+    if (node->IsCallable("FromPg")) {
+        node = &node->Head();
+    }
+
+    ui32 leftIndex = 0;
+    ui32 rightIndex = 1;
+    if (!node->IsCallable("==")) {
+        const bool isPgEquals = node->IsCallable("PgResolvedOp") && (node->Head().Content() == "=");
+        if (!isPgEquals) {
+            return false;
+        }
+
+        leftIndex = 2;
+        rightIndex = 3;
+    }
+
+    left = node->ChildPtr(leftIndex);
+    right = node->ChildPtr(rightIndex);
+    return true;
+}
+
+
} // namespace NYql
diff --git a/ydb/library/yql/core/yql_join.h b/ydb/library/yql/core/yql_join.h
index 69d5f2b2d2..4c07b8ade0 100644
--- a/ydb/library/yql/core/yql_join.h
+++ b/ydb/library/yql/core/yql_join.h
@@ -144,4 +144,10 @@ TExprNode::TPtr MakeDictForJoin(TExprNode::TPtr&& list, bool payload, bool multi
TExprNode::TPtr MakeCrossJoin(TPositionHandle pos, TExprNode::TPtr left, TExprNode::TPtr right, TExprContext& ctx);
+TExprNode::TPtr PreparePredicate(TExprNode::TPtr predicate, TExprContext& ctx); // Hoists And-terms shared by all branches of an Or (optionally under ToPg).
+void GatherAndTerms(const TExprNode::TPtr& predicate, TExprNode::TListType& andTerms, bool& isPg); // Flattens nested And into andTerms; isPg reports a ToPg wrapper.
+TExprNode::TPtr FuseAndTerms(TPositionHandle position, const TExprNode::TListType& andTerms, const TExprNode::TPtr& exclude, bool isPg, TExprContext& ctx); // And-chain of the distinct terms except `exclude`; wrapped in ToPg when isPg.
+
+bool IsEquality(TExprNode::TPtr predicate, TExprNode::TPtr& left, TExprNode::TPtr& right); // True for "==" or PgResolvedOp "=" (optionally under Coalesce(.., false Bool) / FromPg); outputs the operands.
+
}