aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author aozeritsky <aozeritsky@ydb.tech> 2023-09-23 16:07:12 +0300
committer aozeritsky <aozeritsky@ydb.tech> 2023-09-23 16:22:25 +0300
commit eaff9b03851a5dc5a6f1584598c252ae585ea7a4 (patch)
tree 0905ab643805c40d3f3ae3dad68fb8aecbebb808
parent 936ea741c05fc55a8afd90a88503a44353490b85 (diff)
download ydb-eaff9b03851a5dc5a6f1584598c252ae585ea7a4.tar.gz
Can use generic cbo interface in dq
-rw-r--r-- ydb/library/yql/dq/opt/CMakeLists.darwin-x86_64.txt 1
-rw-r--r-- ydb/library/yql/dq/opt/CMakeLists.linux-aarch64.txt 1
-rw-r--r-- ydb/library/yql/dq/opt/CMakeLists.linux-x86_64.txt 1
-rw-r--r-- ydb/library/yql/dq/opt/CMakeLists.windows-x86_64.txt 1
-rw-r--r-- ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp 214
-rw-r--r-- ydb/library/yql/dq/opt/ya.make 1
6 files changed, 219 insertions, 0 deletions
diff --git a/ydb/library/yql/dq/opt/CMakeLists.darwin-x86_64.txt b/ydb/library/yql/dq/opt/CMakeLists.darwin-x86_64.txt
index b3ced877418..a8a812b9acd 100644
--- a/ydb/library/yql/dq/opt/CMakeLists.darwin-x86_64.txt
+++ b/ydb/library/yql/dq/opt/CMakeLists.darwin-x86_64.txt
@@ -34,4 +34,5 @@ target_sources(yql-dq-opt PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_stat.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp
)
diff --git a/ydb/library/yql/dq/opt/CMakeLists.linux-aarch64.txt b/ydb/library/yql/dq/opt/CMakeLists.linux-aarch64.txt
index 087c54c61e6..c559bb7f28e 100644
--- a/ydb/library/yql/dq/opt/CMakeLists.linux-aarch64.txt
+++ b/ydb/library/yql/dq/opt/CMakeLists.linux-aarch64.txt
@@ -35,4 +35,5 @@ target_sources(yql-dq-opt PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_stat.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp
)
diff --git a/ydb/library/yql/dq/opt/CMakeLists.linux-x86_64.txt b/ydb/library/yql/dq/opt/CMakeLists.linux-x86_64.txt
index 087c54c61e6..c559bb7f28e 100644
--- a/ydb/library/yql/dq/opt/CMakeLists.linux-x86_64.txt
+++ b/ydb/library/yql/dq/opt/CMakeLists.linux-x86_64.txt
@@ -35,4 +35,5 @@ target_sources(yql-dq-opt PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_stat.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp
)
diff --git a/ydb/library/yql/dq/opt/CMakeLists.windows-x86_64.txt b/ydb/library/yql/dq/opt/CMakeLists.windows-x86_64.txt
index b3ced877418..a8a812b9acd 100644
--- a/ydb/library/yql/dq/opt/CMakeLists.windows-x86_64.txt
+++ b/ydb/library/yql/dq/opt/CMakeLists.windows-x86_64.txt
@@ -34,4 +34,5 @@ target_sources(yql-dq-opt PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_stat.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp
)
diff --git a/ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp b/ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp
new file mode 100644
index 00000000000..b4bf3be73ff
--- /dev/null
+++ b/ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp
@@ -0,0 +1,214 @@
+#include "dq_opt_join.h"
+#include "dq_opt_phy.h"
+
+#include <ydb/library/yql/core/cbo/cbo_optimizer.h>
+#include <ydb/library/yql/utils/log/log.h>
+#include <ydb/library/yql/core/yql_opt_utils.h>
+
+namespace NYql::NDq {
+
+using namespace NYql::NNodes;
+
+namespace {
+
+// Scratch state for one cost-based reordering of a TCoEquiJoin. It collects
+// the optimizer input (relations and equality classes) from the expression
+// tree and converts the optimizer's result back into an EquiJoin tuple tree.
+//
+// Relation ids and variable ids are 1-based throughout; every per-relation
+// vector below is therefore indexed with `id - 1`.
+//
+// NOTE(review): labels and column names are kept as TStringBuf views, so this
+// object presumably must not outlive the expression nodes they point into -
+// confirm against the callers.
+struct TState {
+    IOptimizer::TInput Input;
+    IOptimizer::TOutput Result;
+    std::vector<TStringBuf> Tables; // relId -> table
+    std::vector<THashMap<TStringBuf, int>> VarIds; // relId -> (column -> varId)
+    THashMap<TStringBuf, std::vector<int>> Table2RelIds; // table -> relIds registered under that label
+    std::vector<std::vector<std::tuple<TStringBuf, TStringBuf>>> Var2TableCol; // relId, varId -> table, col
+    TPositionHandle Pos;
+
+    // Remembers the position of the join; every node built later reuses it.
+    explicit TState(const TCoEquiJoin& join)
+        : Pos(join.Pos())
+    { }
+
+    // Returns the 1-based variable id of `column` inside relation `relId`,
+    // allocating a new target variable (and a matching Var2TableCol slot) the
+    // first time the column is seen.
+    int GetVarId(int relId, TStringBuf column) {
+        YQL_ENSURE(relId >= 1 && static_cast<size_t>(relId) <= VarIds.size(), "Unknown relation id");
+
+        auto& columnIds = VarIds[relId - 1];
+        if (auto known = columnIds.find(column); known != columnIds.end()) {
+            return known->second;
+        }
+
+        auto& targetVars = Input.Rels[relId - 1].TargetVars;
+        const int varId = static_cast<int>(targetVars.size()) + 1;
+        columnIds[column] = varId;
+        targetVars.emplace_back();
+        Var2TableCol[relId - 1].emplace_back();
+        return varId;
+    }
+
+    // Registers one join input under `label` and copies its row count and
+    // cost from `stat` into the optimizer input.
+    // NOTE(review): `stat->Cost` is dereferenced unconditionally; presumably
+    // the caller only reports relations whose cost is set - confirm.
+    void CollectRel(TStringBuf label, auto stat) {
+        Input.Rels.emplace_back();
+        // Keep all per-relation tables the same length as Input.Rels, so that
+        // GetVarId can index VarIds with relId - 1.
+        VarIds.emplace_back();
+        Var2TableCol.emplace_back();
+        const int relId = static_cast<int>(Input.Rels.size());
+        Input.Rels.back().Rows = stat->Nrows;
+        Input.Rels.back().TotalCost = *stat->Cost;
+        Tables.emplace_back(label);
+        Table2RelIds[label].emplace_back(relId);
+    }
+
+    // Walks an EquiJoin tuple tree and adds one equality class per
+    // (left relation, right relation) pair for every key pair of every tuple.
+    // Children used: 1 = left scope, 2 = right scope, 3 = left keys,
+    // 4 = right keys; child 0 (the join type) is not read here.
+    void CollectJoins(TExprNode::TPtr joinTuple) {
+        auto leftScope = joinTuple->Child(1);
+        auto rightScope = joinTuple->Child(2);
+        auto leftKeys = joinTuple->Child(3);
+        auto rightKeys = joinTuple->Child(4);
+
+        YQL_ENSURE(leftKeys->ChildrenSize() == rightKeys->ChildrenSize());
+        // Keys are read as flat (table, column) pairs, so the count must be even.
+        YQL_ENSURE(leftKeys->ChildrenSize() % 2 == 0);
+        for (ui32 i = 0; i < leftKeys->ChildrenSize(); i += 2) {
+            auto ltable = leftKeys->Child(i)->Content();
+            auto lcolumn = leftKeys->Child(i + 1)->Content();
+            auto rtable = rightKeys->Child(i)->Content();
+            auto rcolumn = rightKeys->Child(i + 1)->Content();
+
+            const size_t nclasses = Input.EqClasses.size();
+            for (auto lrelId : Table2RelIds[ltable]) {
+                for (auto rrelId : Table2RelIds[rtable]) {
+                    auto lvarId = GetVarId(lrelId, lcolumn);
+                    auto rvarId = GetVarId(rrelId, rcolumn);
+
+                    Var2TableCol[lrelId - 1][lvarId - 1] = std::make_tuple(ltable, lcolumn);
+                    Var2TableCol[rrelId - 1][rvarId - 1] = std::make_tuple(rtable, rcolumn);
+
+                    IOptimizer::TEq eqClass;
+                    eqClass.Vars.reserve(2);
+                    eqClass.Vars.emplace_back(std::make_tuple(lrelId, lvarId));
+                    eqClass.Vars.emplace_back(std::make_tuple(rrelId, rvarId));
+
+                    Input.EqClasses.emplace_back(std::move(eqClass));
+                }
+            }
+
+            // Every key pair must resolve to at least one pair of collected
+            // relations; otherwise a label is unknown to Table2RelIds.
+            YQL_ENSURE(nclasses != Input.EqClasses.size());
+        }
+
+        // Nested tuples on either side are processed recursively.
+        if (TMaybeNode<TCoEquiJoinTuple>(leftScope)) {
+            CollectJoins(leftScope);
+        }
+        if (TMaybeNode<TCoEquiJoinTuple>(rightScope)) {
+            CollectJoins(rightScope);
+        }
+    }
+
+    // Converts optimizer variables back into a flat [table, column, ...]
+    // sequence of atoms - the same layout CollectJoins reads join keys in.
+    TVector<TExprBase> MakeLabel(TExprContext& ctx, const std::vector<IOptimizer::TVarId>& vars) const {
+        TVector<TExprBase> label;
+        label.reserve(vars.size() * 2);
+
+        for (auto [relId, varId] : vars) {
+            const auto& [table, column] = Var2TableCol[relId - 1][varId - 1];
+
+            label.emplace_back(BuildAtom(table, Pos, ctx));
+            label.emplace_back(BuildAtom(column, Pos, ctx));
+        }
+
+        return label;
+    }
+
+    // Recursively turns node `nodeId` of Result.Nodes into an expression:
+    // a leaf (Outer == Inner == -1) becomes the atom with the relation's
+    // label, an inner node becomes a TCoEquiJoinTuple over its two subtrees.
+    // Fails via YQL_ENSURE on an out-of-range id, a half-connected node or an
+    // unsupported join type.
+    TExprBase BuildTree(TExprContext& ctx, int nodeId) {
+        YQL_ENSURE(nodeId >= 0 && static_cast<size_t>(nodeId) < Result.Nodes.size(), "Wrong CBO node");
+        const IOptimizer::TJoinNode& node = Result.Nodes[nodeId];
+
+        if (node.Outer == -1 && node.Inner == -1) {
+            // leaf
+            YQL_ENSURE(node.Rels.size() == 1);
+            return BuildAtom(Tables[node.Rels[0] - 1], Pos, ctx);
+        }
+
+        // Exactly one missing child is not a valid tree node.
+        YQL_ENSURE(node.Outer != -1 && node.Inner != -1, "Wrong CBO node");
+
+        TString joinKind;
+        switch (node.Mode) {
+            case IOptimizer::EJoinType::Inner:
+                joinKind = "Inner";
+                break;
+            case IOptimizer::EJoinType::Left:
+                joinKind = "Left";
+                break;
+            case IOptimizer::EJoinType::Right:
+                joinKind = "Right";
+                break;
+            default:
+                YQL_ENSURE(false, "Unsupported join type");
+                break;
+        }
+
+        // No per-tuple options are emitted.
+        TVector<TExprBase> options;
+        return Build<TCoEquiJoinTuple>(ctx, Pos)
+            .Type(BuildAtom(joinKind, Pos, ctx))
+            .LeftScope(BuildTree(ctx, node.Outer))
+            .RightScope(BuildTree(ctx, node.Inner))
+            .LeftKeys()
+                .Add(MakeLabel(ctx, node.LeftVars))
+                .Build()
+            .RightKeys()
+                .Add(MakeLabel(ctx, node.RightVars))
+                .Build()
+            .Options()
+                .Add(options)
+                .Build()
+            .Done();
+    }
+};
+
+} // namespace
+
+// Reorders the join tree of a TCoEquiJoin using the optimizer produced by
+// `optFactory`.
+//
+// Returns `node` unchanged when:
+//   * `ruleEnabled` is false;
+//   * `node` is not a TCoEquiJoin;
+//   * typesCtx.StatisticsMap already holds statistics with a cost for this
+//     node (presumably it was costed by an earlier pass - confirm);
+//   * HasOnlyOneJoinType(..., "Inner") rejects the join tree;
+//   * DqCollectJoinRelationsWithStats reports failure.
+// Otherwise returns a new TCoEquiJoin whose arguments are copied from the
+// original, except the second-to-last one (the join tree), which is rebuilt
+// from the optimizer's result.
+//
+// `optFactory` returns a raw pointer; this function takes ownership of it and
+// destroys the optimizer before returning.
+TExprBase DqOptimizeEquiJoinWithCosts(
+    const TExprBase& node,
+    TExprContext& ctx,
+    TTypeAnnotationContext& typesCtx,
+    const std::function<IOptimizer*(IOptimizer::TInput&&)> optFactory,
+    bool ruleEnabled)
+{
+    if (!ruleEnabled) {
+        return node;
+    }
+
+    if (!node.Maybe<TCoEquiJoin>()) {
+        return node;
+    }
+    auto equiJoin = node.Cast<TCoEquiJoin>();
+    YQL_ENSURE(equiJoin.ArgCount() >= 4);
+
+    auto maybeStat = typesCtx.StatisticsMap.find(equiJoin.Raw());
+    if (maybeStat != typesCtx.StatisticsMap.end() &&
+        maybeStat->second->Cost.has_value()) {
+        return node;
+    }
+
+    // The join tree is the second-to-last argument; the last one is passed
+    // through untouched below.
+    const auto joinTree = equiJoin.Arg(equiJoin.ArgCount() - 2).Ptr();
+    if (!HasOnlyOneJoinType(*joinTree, "Inner")) {
+        return node;
+    }
+
+    YQL_CLOG(TRACE, CoreDq) << "Optimizing join with costs";
+
+    TState state(equiJoin);
+    // collect Rels
+    if (!DqCollectJoinRelationsWithStats(typesCtx, equiJoin, [&](auto label, auto stat) {
+        state.CollectRel(label, stat);
+    })) {
+        return node;
+    }
+
+    state.CollectJoins(joinTree);
+    state.Input.Normalize();
+
+    // state.Input is moved into the optimizer and must not be read afterwards;
+    // BuildTree only uses state.Result and the lookup tables.
+    std::unique_ptr<IOptimizer> opt(optFactory(std::move(state.Input)));
+    YQL_ENSURE(opt != nullptr, "Optimizer factory returned no optimizer");
+    state.Result = opt->JoinSearch();
+
+    TVector<TExprBase> joinArgs;
+    joinArgs.reserve(equiJoin.ArgCount());
+    for (size_t i = 0; i < equiJoin.ArgCount() - 2; i++) {
+        joinArgs.push_back(equiJoin.Arg(i));
+    }
+
+    // Node 0 of the result is used as the root of the new join tree.
+    joinArgs.push_back(state.BuildTree(ctx, 0));
+    joinArgs.push_back(equiJoin.Arg(equiJoin.ArgCount() - 1));
+
+    return Build<TCoEquiJoin>(ctx, equiJoin.Pos())
+        .Add(joinArgs)
+        .Done();
+}
+
+} // namespace NYql::NDq
+
diff --git a/ydb/library/yql/dq/opt/ya.make b/ydb/library/yql/dq/opt/ya.make
index 4d920e0d8dc..d7932027985 100644
--- a/ydb/library/yql/dq/opt/ya.make
+++ b/ydb/library/yql/dq/opt/ya.make
@@ -21,6 +21,7 @@ SRCS(
dq_opt_phy.cpp
dq_opt_stat.cpp
dq_opt_join_cost_based.cpp
+ dq_opt_join_cost_based_generic.cpp
)
YQL_LAST_ABI_VERSION()