aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorpavelvelikhov <pavelvelikhov@yandex-team.com>2023-08-08 13:13:25 +0300
committerpavelvelikhov <pavelvelikhov@yandex-team.com>2023-08-08 15:03:00 +0300
commit46501e2a1aa36cdac2de76ebefac31d1c00a1ff0 (patch)
tree3e5533f39a70e3d7ca5f7ead775c2a66ce90fe52
parent94d58dee6279337ceef3aaab04b7ae2225584323 (diff)
downloadydb-46501e2a1aa36cdac2de76ebefac31d1c00a1ff0.tar.gz
First version of cost based optimizer
Added feature flags and moved debug output to logger Working version, moved stats to type ann Updated CBO Initial commit of CBO
-rw-r--r--ydb/core/kqp/host/kqp_runner.cpp3
-rw-r--r--ydb/core/kqp/opt/CMakeLists.darwin-x86_64.txt1
-rw-r--r--ydb/core/kqp/opt/CMakeLists.linux-aarch64.txt1
-rw-r--r--ydb/core/kqp/opt/CMakeLists.linux-x86_64.txt1
-rw-r--r--ydb/core/kqp/opt/CMakeLists.windows-x86_64.txt1
-rw-r--r--ydb/core/kqp/opt/kqp_statistics_transformer.cpp157
-rw-r--r--ydb/core/kqp/opt/kqp_statistics_transformer.h45
-rw-r--r--ydb/core/kqp/opt/logical/kqp_opt_log.cpp7
-rw-r--r--ydb/core/kqp/opt/ya.make1
-rw-r--r--ydb/core/kqp/provider/yql_kikimr_settings.cpp6
-rw-r--r--ydb/core/kqp/provider/yql_kikimr_settings.h3
-rw-r--r--ydb/core/kqp/ut/join/CMakeLists.darwin-x86_64.txt1
-rw-r--r--ydb/core/kqp/ut/join/CMakeLists.linux-aarch64.txt1
-rw-r--r--ydb/core/kqp/ut/join/CMakeLists.linux-x86_64.txt1
-rw-r--r--ydb/core/kqp/ut/join/CMakeLists.windows-x86_64.txt1
-rw-r--r--ydb/core/kqp/ut/join/kqp_join_order_ut.cpp281
-rw-r--r--ydb/core/kqp/ut/join/ya.make1
-rw-r--r--ydb/library/yql/core/CMakeLists.darwin-x86_64.txt1
-rw-r--r--ydb/library/yql/core/CMakeLists.linux-aarch64.txt1
-rw-r--r--ydb/library/yql/core/CMakeLists.linux-x86_64.txt1
-rw-r--r--ydb/library/yql/core/CMakeLists.windows-x86_64.txt1
-rw-r--r--ydb/library/yql/core/ya.make1
-rw-r--r--ydb/library/yql/core/yql_statistics.cpp14
-rw-r--r--ydb/library/yql/core/yql_statistics.h26
-rw-r--r--ydb/library/yql/core/yql_type_annotation.h2
-rw-r--r--ydb/library/yql/dq/opt/CMakeLists.darwin-x86_64.txt1
-rw-r--r--ydb/library/yql/dq/opt/CMakeLists.linux-aarch64.txt1
-rw-r--r--ydb/library/yql/dq/opt/CMakeLists.linux-x86_64.txt1
-rw-r--r--ydb/library/yql/dq/opt/CMakeLists.windows-x86_64.txt1
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp1064
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_log.h4
-rw-r--r--ydb/library/yql/dq/opt/ya.make1
32 files changed, 1632 insertions, 0 deletions
diff --git a/ydb/core/kqp/host/kqp_runner.cpp b/ydb/core/kqp/host/kqp_runner.cpp
index ecbffe376e3..b46b12b476a 100644
--- a/ydb/core/kqp/host/kqp_runner.cpp
+++ b/ydb/core/kqp/host/kqp_runner.cpp
@@ -4,6 +4,8 @@
#include <ydb/core/kqp/query_compiler/kqp_query_compiler.h>
#include <ydb/core/kqp/opt/kqp_opt.h>
#include <ydb/core/kqp/opt/logical/kqp_opt_log.h>
+#include <ydb/core/kqp/opt/kqp_statistics_transformer.h>
+
#include <ydb/core/kqp/opt/physical/kqp_opt_phy.h>
#include <ydb/core/kqp/opt/peephole/kqp_opt_peephole.h>
#include <ydb/core/kqp/opt/kqp_query_plan.h>
@@ -89,6 +91,7 @@ public:
.Add(CreateKqpCheckQueryTransformer(), "CheckKqlQuery")
.AddPostTypeAnnotation(/* forSubgraph */ true)
.AddCommonOptimization()
+ .Add(CreateKqpStatisticsTransformer(*typesCtx, Config), "Statistics")
.Add(CreateKqpLogOptTransformer(OptimizeCtx, *typesCtx, Config), "LogicalOptimize")
.Add(CreateLogicalDataProposalsInspector(*typesCtx), "ProvidersLogicalOptimize")
.Add(CreateKqpPhyOptTransformer(OptimizeCtx, *typesCtx), "KqpPhysicalOptimize")
diff --git a/ydb/core/kqp/opt/CMakeLists.darwin-x86_64.txt b/ydb/core/kqp/opt/CMakeLists.darwin-x86_64.txt
index e21e3a2928e..9c3d4e474fd 100644
--- a/ydb/core/kqp/opt/CMakeLists.darwin-x86_64.txt
+++ b/ydb/core/kqp/opt/CMakeLists.darwin-x86_64.txt
@@ -44,6 +44,7 @@ target_sources(core-kqp-opt PRIVATE
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_opt_range_legacy.cpp
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_blocks_transformer.cpp
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_plan.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_statistics_transformer.cpp
)
generate_enum_serilization(core-kqp-opt
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_plan.h
diff --git a/ydb/core/kqp/opt/CMakeLists.linux-aarch64.txt b/ydb/core/kqp/opt/CMakeLists.linux-aarch64.txt
index 2f01eb2c6d3..3cb516fc560 100644
--- a/ydb/core/kqp/opt/CMakeLists.linux-aarch64.txt
+++ b/ydb/core/kqp/opt/CMakeLists.linux-aarch64.txt
@@ -45,6 +45,7 @@ target_sources(core-kqp-opt PRIVATE
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_opt_range_legacy.cpp
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_blocks_transformer.cpp
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_plan.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_statistics_transformer.cpp
)
generate_enum_serilization(core-kqp-opt
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_plan.h
diff --git a/ydb/core/kqp/opt/CMakeLists.linux-x86_64.txt b/ydb/core/kqp/opt/CMakeLists.linux-x86_64.txt
index 2f01eb2c6d3..3cb516fc560 100644
--- a/ydb/core/kqp/opt/CMakeLists.linux-x86_64.txt
+++ b/ydb/core/kqp/opt/CMakeLists.linux-x86_64.txt
@@ -45,6 +45,7 @@ target_sources(core-kqp-opt PRIVATE
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_opt_range_legacy.cpp
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_blocks_transformer.cpp
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_plan.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_statistics_transformer.cpp
)
generate_enum_serilization(core-kqp-opt
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_plan.h
diff --git a/ydb/core/kqp/opt/CMakeLists.windows-x86_64.txt b/ydb/core/kqp/opt/CMakeLists.windows-x86_64.txt
index e21e3a2928e..9c3d4e474fd 100644
--- a/ydb/core/kqp/opt/CMakeLists.windows-x86_64.txt
+++ b/ydb/core/kqp/opt/CMakeLists.windows-x86_64.txt
@@ -44,6 +44,7 @@ target_sources(core-kqp-opt PRIVATE
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_opt_range_legacy.cpp
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_blocks_transformer.cpp
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_plan.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_statistics_transformer.cpp
)
generate_enum_serilization(core-kqp-opt
${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_plan.h
diff --git a/ydb/core/kqp/opt/kqp_statistics_transformer.cpp b/ydb/core/kqp/opt/kqp_statistics_transformer.cpp
new file mode 100644
index 00000000000..c1dfd05a3a0
--- /dev/null
+++ b/ydb/core/kqp/opt/kqp_statistics_transformer.cpp
@@ -0,0 +1,157 @@
+#include "kqp_statistics_transformer.h"
+#include <ydb/library/yql/utils/log/log.h>
+
+
+using namespace NYql;
+using namespace NYql::NNodes;
+using namespace NKikimr::NKqp;
+
+namespace {
+
+/**
+ * Looks up the statistics recorded for `input` in the type annotation context.
+ * Yields an empty shared_ptr when the node has no entry in StatisticsMap.
+*/
+std::shared_ptr<TOptimizerStatistics> GetStats(const TExprNode* input, TTypeAnnotationContext* typeCtx) {
+    // Fallback handed to Value() for nodes without an entry.
+    const std::shared_ptr<TOptimizerStatistics> missing;
+    return typeCtx->StatisticsMap.Value(input, missing);
+}
+
+/**
+ * Records `stats` for `input` in the type annotation context, replacing
+ * whatever was stored for that node before.
+*/
+void SetStats(const TExprNode* input, TTypeAnnotationContext* typeCtx, std::shared_ptr<TOptimizerStatistics> stats) {
+    auto& slot = typeCtx->StatisticsMap[input];
+    slot = stats;
+}
+
+/**
+ * Helper method to get cost from type annotation context.
+ * Returns std::nullopt when the node has no statistics recorded (or the
+ * recorded entry is empty), instead of inserting a null entry into the
+ * map and dereferencing it.
+*/
+std::optional<double> GetCost(const TExprNode* input, TTypeAnnotationContext* typeCtx) {
+    // find() keeps the lookup read-only; operator[] would default-insert a null pointer.
+    const auto it = typeCtx->StatisticsMap.find(input);
+    if (it == typeCtx->StatisticsMap.end() || !it->second) {
+        return std::nullopt;
+    }
+    return it->second->Cost;
+}
+
+/**
+ * Helper method to set the cost in type annotation context.
+ * The cost lives inside the node's statistics record, so this is a no-op
+ * when no statistics have been recorded for `input` yet (previously this
+ * case inserted a null entry and dereferenced it).
+*/
+void SetCost(const TExprNode* input, TTypeAnnotationContext* typeCtx, std::optional<double> cost) {
+    const auto it = typeCtx->StatisticsMap.find(input);
+    if (it == typeCtx->StatisticsMap.end() || !it->second) {
+        // Nothing to attach the cost to; callers are expected to SetStats first.
+        return;
+    }
+    it->second->Cost = cost;
+}
+}
+
+/**
+ * For FlatMap we look at the input, fetch its statistics and cost,
+ * and scale the row count by the selectivity of the filter predicate.
+ * Nothing is recorded when the FlatMap is not a predicate FlatMap or when
+ * the input carries no statistics.
+*/
+void InferStatisticsForFlatMap(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+    const auto flatmap = TExprBase(input).Cast<TCoFlatMap>();
+    const bool isFilter = IsPredicateFlatMap(flatmap.Lambda().Body().Ref());
+    if (!isFilter) {
+        return;
+    }
+
+    const TExprNode* source = flatmap.Input().Raw();
+    const auto sourceStats = GetStats(source, typeCtx);
+    if (!sourceStats) {
+        return;
+    }
+
+    // Selectivity is the fraction of tuples that pass the predicate.
+    // Until we have statistics and parse the predicate, it is fixed at 10%.
+    const double selectivity = 0.1;
+
+    SetStats(input.Get(), typeCtx,
+        std::make_shared<TOptimizerStatistics>(sourceStats->Nrows * selectivity, sourceStats->Ncols));
+    // The filter itself is free for now: the cost is carried over from the input.
+    SetCost(input.Get(), typeCtx, GetCost(source, typeCtx));
+}
+
+/**
+ * Infer statistics and cost for SkipNullMembers.
+ * We cannot yet estimate how many rows are dropped because of nulls, so the
+ * input statistics and cost are passed through unchanged. Nothing is
+ * recorded when the input carries no statistics.
+*/
+void InferStatisticsForSkipNullMembers(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+    const auto skipNulls = TExprBase(input).Cast<TCoSkipNullMembers>();
+    const TExprNode* source = skipNulls.Input().Raw();
+
+    const auto sourceStats = GetStats(source, typeCtx);
+    if (sourceStats) {
+        // The same statistics object is shared between the input and this node.
+        SetStats(input.Get(), typeCtx, sourceStats);
+        SetCost(input.Get(), typeCtx, GetCost(source, typeCtx));
+    }
+}
+
+/**
+ * Compute statistics and cost for a table read.
+ * Placeholder values for now: cardinality 100000, 5 columns, cost 0.
+*/
+void InferStatisticsForReadTable(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+    YQL_CLOG(TRACE, CoreDq) << "Infer statistics for read table";
+
+    SetStats(input.Get(), typeCtx, std::make_shared<TOptimizerStatistics>(100000, 5, 0.0));
+}
+
+/**
+ * Compute statistics for an index lookup.
+ * Placeholder values for now: cardinality 5, 5 columns, cost 0.
+*/
+void InferStatisticsForIndexLookup(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+    SetStats(input.Get(), typeCtx, std::make_shared<TOptimizerStatistics>(5, 5, 0.0));
+}
+
+/**
+ * DoTransform method matches operators and callables in the query DAG and
+ * uses pre-computed statistics and costs of the children to compute their cost.
+ *
+ * The expression graph itself is never rewritten: every visited node is
+ * returned unchanged and only the statistics map of the type annotation
+ * context is updated. When cost based optimization is not enabled the
+ * transformer does nothing and reports Ok.
+*/
+IGraphTransformer::TStatus TKqpStatisticsTransformer::DoTransform(TExprNode::TPtr input,
+    TExprNode::TPtr& output, TExprContext& ctx) {
+
+    output = input;
+    if (!Config->HasOptEnableCostBasedOptimization()) {
+        return IGraphTransformer::TStatus::Ok;
+    }
+
+    TOptimizeExprSettings settings(nullptr);
+
+    // Capture `this` rather than `*this`: the lambda only lives for the duration
+    // of OptimizeExpr, so copying the whole transformer per call is unnecessary.
+    return OptimizeExpr(input, output, [this](const TExprNode::TPtr& node, TExprContext& nodeCtx) {
+        Y_UNUSED(nodeCtx);
+
+        if (TCoFlatMap::Match(node.Get())) {
+            InferStatisticsForFlatMap(node, typeCtx);
+        } else if (TCoSkipNullMembers::Match(node.Get())) {
+            InferStatisticsForSkipNullMembers(node, typeCtx);
+        } else if (TKqlReadTableBase::Match(node.Get()) || TKqlReadTableRangesBase::Match(node.Get())) {
+            InferStatisticsForReadTable(node, typeCtx);
+        } else if (TKqlLookupTableBase::Match(node.Get()) || TKqlLookupIndexBase::Match(node.Get())) {
+            InferStatisticsForIndexLookup(node, typeCtx);
+        }
+
+        // Operators without a rule are left without statistics.
+        return node;
+    }, ctx, settings);
+}
+
+TAutoPtr<IGraphTransformer> NKikimr::NKqp::CreateKqpStatisticsTransformer(TTypeAnnotationContext& typeCtx,
+    const TKikimrConfiguration::TPtr& config) {
+
+    // Factory: builds the statistics transformer and hands ownership to the caller.
+    auto* const transformer = new TKqpStatisticsTransformer(typeCtx, config);
+    return THolder<IGraphTransformer>(transformer);
+} \ No newline at end of file
diff --git a/ydb/core/kqp/opt/kqp_statistics_transformer.h b/ydb/core/kqp/opt/kqp_statistics_transformer.h
new file mode 100644
index 00000000000..78d982f9981
--- /dev/null
+++ b/ydb/core/kqp/opt/kqp_statistics_transformer.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#include <ydb/library/yql/core/yql_statistics.h>
+
+#include <ydb/core/kqp/common/kqp_yql.h>
+#include <ydb/library/yql/core/yql_graph_transformer.h>
+#include <ydb/library/yql/core/yql_expr_optimize.h>
+#include <ydb/library/yql/core/yql_expr_type_annotation.h>
+#include <ydb/core/kqp/provider/yql_kikimr_provider_impl.h>
+#include <ydb/library/yql/core/yql_opt_utils.h>
+
+namespace NKikimr {
+namespace NKqp {
+
+using namespace NYql;
+using namespace NYql::NNodes;
+
+/***
+ * Statistics transformer is a transformer that propagates statistics and costs from
+ * the leaves of the plan DAG up to the root of the DAG. It handles a number of operators,
+ * but will simply stop propagation if it encounters an operator that it has no rules for.
+ * One of such operators is EquiJoin, but there is a special rule to handle EquiJoin.
+*/
+class TKqpStatisticsTransformer : public TSyncTransformerBase {
+
+    // Not owned; statistics are written into its StatisticsMap.
+    TTypeAnnotationContext* typeCtx;
+    // Held by value so the transformer keeps the configuration alive instead of
+    // referring to a pointer object owned by the caller.
+    const TKikimrConfiguration::TPtr Config;
+
+    public:
+        TKqpStatisticsTransformer(TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config) :
+            typeCtx(&typeCtx), Config(config) {}
+
+        // Main method of the transformer
+        IGraphTransformer::TStatus DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) final;
+
+        // Rewind currently does nothing
+        void Rewind() {
+
+        }
+};
+
+TAutoPtr<IGraphTransformer> CreateKqpStatisticsTransformer(TTypeAnnotationContext& typeCtx,
+ const TKikimrConfiguration::TPtr& config);
+}
+}
diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp
index dd887361f5a..ea54f36881c 100644
--- a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp
+++ b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp
@@ -33,6 +33,7 @@ public:
AddHandler(0, &TCoTake::Match, HNDL(RewriteTakeSortToTopSort));
AddHandler(0, &TCoFlatMap::Match, HNDL(RewriteSqlInToEquiJoin));
AddHandler(0, &TCoFlatMap::Match, HNDL(RewriteSqlInCompactToJoin));
+ AddHandler(0, &TCoEquiJoin::Match, HNDL(OptimizeEquiJoinWithCosts));
AddHandler(0, &TCoEquiJoin::Match, HNDL(RewriteEquiJoin));
AddHandler(0, &TDqJoin::Match, HNDL(JoinToIndexLookup));
AddHandler(0, &TCoCalcOverWindowBase::Match, HNDL(ExpandWindowFunctions));
@@ -123,6 +124,12 @@ protected:
return output;
}
+ TMaybeNode<TExprBase> OptimizeEquiJoinWithCosts(TExprBase node, TExprContext& ctx) {
+     // Handler for EquiJoin: delegates to the DQ cost based join rule, passing the CBO feature flag.
+     const bool cboEnabled = Config->HasOptEnableCostBasedOptimization();
+     TExprBase rewritten = DqOptimizeEquiJoinWithCosts(node, ctx, TypesCtx, cboEnabled);
+     DumpAppliedRule("OptimizeEquiJoinWithCosts", node.Ptr(), rewritten.Ptr(), ctx);
+     return rewritten;
+ }
+
TMaybeNode<TExprBase> RewriteEquiJoin(TExprBase node, TExprContext& ctx) {
TExprBase output = DqRewriteEquiJoin(node, KqpCtx.Config->GetHashJoinMode(), ctx);
DumpAppliedRule("RewriteEquiJoin", node.Ptr(), output.Ptr(), ctx);
diff --git a/ydb/core/kqp/opt/ya.make b/ydb/core/kqp/opt/ya.make
index d04d95ec68f..943fe63957c 100644
--- a/ydb/core/kqp/opt/ya.make
+++ b/ydb/core/kqp/opt/ya.make
@@ -12,6 +12,7 @@ SRCS(
kqp_opt_range_legacy.cpp
kqp_query_blocks_transformer.cpp
kqp_query_plan.cpp
+ kqp_statistics_transformer.cpp
)
PEERDIR(
diff --git a/ydb/core/kqp/provider/yql_kikimr_settings.cpp b/ydb/core/kqp/provider/yql_kikimr_settings.cpp
index 115c0ee1e83..6144c18bb72 100644
--- a/ydb/core/kqp/provider/yql_kikimr_settings.cpp
+++ b/ydb/core/kqp/provider/yql_kikimr_settings.cpp
@@ -63,6 +63,7 @@ TKikimrConfiguration::TKikimrConfiguration() {
REGISTER_SETTING(*this, OptEnablePredicateExtract);
REGISTER_SETTING(*this, OptEnableOlapPushdown);
REGISTER_SETTING(*this, OptUseFinalizeByKey);
+ REGISTER_SETTING(*this, OptEnableCostBasedOptimization);
/* Runtime */
REGISTER_SETTING(*this, ScanQuery);
@@ -124,6 +125,11 @@ bool TKikimrSettings::HasOptUseFinalizeByKey() const {
return GetOptionalFlagValue(OptUseFinalizeByKey.Get()) == EOptionalFlag::Enabled;
}
+bool TKikimrSettings::HasOptEnableCostBasedOptimization() const {
+    // True only when the setting resolves to Enabled.
+    const auto flag = GetOptionalFlagValue(OptEnableCostBasedOptimization.Get());
+    return flag == EOptionalFlag::Enabled;
+}
+
+
EOptionalFlag TKikimrSettings::GetOptPredicateExtract() const {
return GetOptionalFlagValue(OptEnablePredicateExtract.Get());
}
diff --git a/ydb/core/kqp/provider/yql_kikimr_settings.h b/ydb/core/kqp/provider/yql_kikimr_settings.h
index aabf39f8170..081d64354a3 100644
--- a/ydb/core/kqp/provider/yql_kikimr_settings.h
+++ b/ydb/core/kqp/provider/yql_kikimr_settings.h
@@ -56,6 +56,7 @@ struct TKikimrSettings {
NCommon::TConfSetting<bool, false> OptEnablePredicateExtract;
NCommon::TConfSetting<bool, false> OptEnableOlapPushdown;
NCommon::TConfSetting<bool, false> OptUseFinalizeByKey;
+ NCommon::TConfSetting<bool, false> OptEnableCostBasedOptimization;
/* Runtime */
NCommon::TConfSetting<bool, true> ScanQuery;
@@ -75,6 +76,8 @@ struct TKikimrSettings {
bool HasOptDisableSqlInToJoin() const;
bool HasOptEnableOlapPushdown() const;
bool HasOptUseFinalizeByKey() const;
+ bool HasOptEnableCostBasedOptimization() const;
+
EOptionalFlag GetOptPredicateExtract() const;
EOptionalFlag GetUseLlvm() const;
NDq::EHashJoinMode GetHashJoinMode() const;
diff --git a/ydb/core/kqp/ut/join/CMakeLists.darwin-x86_64.txt b/ydb/core/kqp/ut/join/CMakeLists.darwin-x86_64.txt
index 46214236eea..44f17858ecd 100644
--- a/ydb/core/kqp/ut/join/CMakeLists.darwin-x86_64.txt
+++ b/ydb/core/kqp/ut/join/CMakeLists.darwin-x86_64.txt
@@ -35,6 +35,7 @@ target_sources(ydb-core-kqp-ut-join PRIVATE
${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_flip_join_ut.cpp
${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_index_lookup_join_ut.cpp
${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_join_ut.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp
)
set_property(
TARGET
diff --git a/ydb/core/kqp/ut/join/CMakeLists.linux-aarch64.txt b/ydb/core/kqp/ut/join/CMakeLists.linux-aarch64.txt
index b4bbb7b0a1f..2f8f1cfe4d2 100644
--- a/ydb/core/kqp/ut/join/CMakeLists.linux-aarch64.txt
+++ b/ydb/core/kqp/ut/join/CMakeLists.linux-aarch64.txt
@@ -38,6 +38,7 @@ target_sources(ydb-core-kqp-ut-join PRIVATE
${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_flip_join_ut.cpp
${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_index_lookup_join_ut.cpp
${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_join_ut.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp
)
set_property(
TARGET
diff --git a/ydb/core/kqp/ut/join/CMakeLists.linux-x86_64.txt b/ydb/core/kqp/ut/join/CMakeLists.linux-x86_64.txt
index e7bde2ff331..6d8f2e29856 100644
--- a/ydb/core/kqp/ut/join/CMakeLists.linux-x86_64.txt
+++ b/ydb/core/kqp/ut/join/CMakeLists.linux-x86_64.txt
@@ -39,6 +39,7 @@ target_sources(ydb-core-kqp-ut-join PRIVATE
${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_flip_join_ut.cpp
${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_index_lookup_join_ut.cpp
${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_join_ut.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp
)
set_property(
TARGET
diff --git a/ydb/core/kqp/ut/join/CMakeLists.windows-x86_64.txt b/ydb/core/kqp/ut/join/CMakeLists.windows-x86_64.txt
index f80baee42f2..4ea7727bf85 100644
--- a/ydb/core/kqp/ut/join/CMakeLists.windows-x86_64.txt
+++ b/ydb/core/kqp/ut/join/CMakeLists.windows-x86_64.txt
@@ -28,6 +28,7 @@ target_sources(ydb-core-kqp-ut-join PRIVATE
${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_flip_join_ut.cpp
${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_index_lookup_join_ut.cpp
${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_join_ut.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp
)
set_property(
TARGET
diff --git a/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp b/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp
new file mode 100644
index 00000000000..3265ebd4ea1
--- /dev/null
+++ b/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp
@@ -0,0 +1,281 @@
+#include <ydb/core/kqp/ut/common/kqp_ut_common.h>
+
+#include <ydb/public/sdk/cpp/client/ydb_proto/accessor.h>
+
+#include <util/string/printf.h>
+
+namespace NKikimr {
+namespace NKqp {
+
+using namespace NYdb;
+using namespace NYdb::NTable;
+
+/**
+ * Creates the fixture for the basic join order tests: five tables
+ * (R, S, T, U, V) under /Root that share the same Int32 key attribute `id`
+ * and each carry one String payload column, then inserts a single row
+ * (id = 1) into every table. The tests construct various five-way join
+ * queries over these tables. Every step is asserted to succeed.
+*/
+static void CreateSampleTable(TSession session) {
+ UNIT_ASSERT(session.ExecuteSchemeQuery(R"(
+ CREATE TABLE `/Root/R` (
+ id Int32,
+ payload1 String,
+ PRIMARY KEY (id)
+ );
+ )").GetValueSync().IsSuccess());
+
+ UNIT_ASSERT(session.ExecuteSchemeQuery(R"(
+ CREATE TABLE `/Root/S` (
+ id Int32,
+ payload2 String,
+ PRIMARY KEY (id)
+ );
+ )").GetValueSync().IsSuccess());
+
+ UNIT_ASSERT(session.ExecuteSchemeQuery(R"(
+ CREATE TABLE `/Root/T` (
+ id Int32,
+ payload3 String,
+ PRIMARY KEY (id)
+ );
+ )").GetValueSync().IsSuccess());
+
+ UNIT_ASSERT(session.ExecuteSchemeQuery(R"(
+ CREATE TABLE `/Root/U` (
+ id Int32,
+ payload4 String,
+ PRIMARY KEY (id)
+ );
+ )").GetValueSync().IsSuccess());
+
+ UNIT_ASSERT(session.ExecuteSchemeQuery(R"(
+ CREATE TABLE `/Root/V` (
+ id Int32,
+ payload5 String,
+ PRIMARY KEY (id)
+ );
+ )").GetValueSync().IsSuccess());
+
+ UNIT_ASSERT(session.ExecuteDataQuery(R"(
+
+ REPLACE INTO `/Root/R` (id, payload1) VALUES
+ (1, "blah");
+
+ REPLACE INTO `/Root/S` (id, payload2) VALUES
+ (1, "blah");
+
+ REPLACE INTO `/Root/T` (id, payload3) VALUES
+ (1, "blah");
+
+ REPLACE INTO `/Root/U` (id, payload4) VALUES
+ (1, "blah");
+
+ REPLACE INTO `/Root/V` (id, payload5) VALUES
+ (1, "blah");
+ )", TTxControl::BeginTx().CommitTx()).GetValueSync().IsSuccess());
+}
+
+static TKikimrRunner GetKikimrWithJoinSettings(){
+    // Single KQP setting that switches the cost based optimizer on.
+    NKikimrKqp::TKqpSetting cboSetting;
+    cboSetting.SetName("OptEnableCostBasedOptimization");
+    cboSetting.SetValue("true");
+
+    TVector<NKikimrKqp::TKqpSetting> settings;
+    settings.push_back(cboSetting);
+    return TKikimrRunner(settings);
+}
+
+
+/**
+ * Shared driver for the join order tests. Starts a Kikimr instance with the
+ * cost based optimizer enabled, creates the sample tables, explains `query`,
+ * asserts that the explain succeeded, parses the returned plan as JSON and
+ * prints the plan.
+ */
+static void ExplainJoinQuery(const TString& query) {
+    auto kikimr = GetKikimrWithJoinSettings();
+    auto db = kikimr.GetTableClient();
+    auto session = db.CreateSession().GetValueSync().GetSession();
+
+    CreateSampleTable(session);
+
+    auto result = session.ExplainDataQuery(query).ExtractValueSync();
+
+    UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::SUCCESS);
+
+    // The parsed tree is not inspected; parsing only checks that the plan is JSON.
+    NJson::TJsonValue plan;
+    NJson::ReadJsonTree(result.GetPlan(), &plan, true);
+    Cout << result.GetPlan();
+}
+
+Y_UNIT_TEST_SUITE(KqpJoinOrder) {
+    /* plain five-way inner join, no filter predicates */
+    Y_UNIT_TEST(FiveWayJoin) {
+        ExplainJoinQuery(Q_(R"(
+ SELECT *
+ FROM `/Root/R` as R
+ INNER JOIN
+ `/Root/S` as S
+ ON R.id = S.id
+ INNER JOIN
+ `/Root/T` as T
+ ON S.id = T.id
+ INNER JOIN
+ `/Root/U` as U
+ ON T.id = U.id
+ INNER JOIN
+ `/Root/V` as V
+ ON U.id = V.id
+ )"));
+    }
+
+    /* five-way join with single-table filter predicates */
+    Y_UNIT_TEST(FiveWayJoinWithPreds) {
+        ExplainJoinQuery(Q_(R"(
+ SELECT *
+ FROM `/Root/R` as R
+ INNER JOIN
+ `/Root/S` as S
+ ON R.id = S.id
+ INNER JOIN
+ `/Root/T` as T
+ ON S.id = T.id
+ INNER JOIN
+ `/Root/U` as U
+ ON T.id = U.id
+ INNER JOIN
+ `/Root/V` as V
+ ON U.id = V.id
+ WHERE R.payload1 = 'blah' AND V.payload5 = 'blah'
+ )"));
+    }
+
+    /* five-way join with an additional predicate spanning three tables */
+    Y_UNIT_TEST(FiveWayJoinWithComplexPreds) {
+        ExplainJoinQuery(Q_(R"(
+ SELECT *
+ FROM `/Root/R` as R
+ INNER JOIN
+ `/Root/S` as S
+ ON R.id = S.id
+ INNER JOIN
+ `/Root/T` as T
+ ON S.id = T.id
+ INNER JOIN
+ `/Root/U` as U
+ ON T.id = U.id
+ INNER JOIN
+ `/Root/V` as V
+ ON U.id = V.id
+ WHERE R.payload1 = 'blah' AND V.payload5 = 'blah' AND ( S.payload2 || T.payload3 = U.payload4 )
+ )"));
+    }
+
+    /* five-way join with a predicate spanning the first and the last table */
+    Y_UNIT_TEST(FiveWayJoinWithComplexPreds2) {
+        ExplainJoinQuery(Q_(R"(
+ SELECT *
+ FROM `/Root/R` as R
+ INNER JOIN
+ `/Root/S` as S
+ ON R.id = S.id
+ INNER JOIN
+ `/Root/T` as T
+ ON S.id = T.id
+ INNER JOIN
+ `/Root/U` as U
+ ON T.id = U.id
+ INNER JOIN
+ `/Root/V` as V
+ ON U.id = V.id
+ WHERE (R.payload1 || V.payload5 = 'blah') AND U.payload4 = 'blah'
+ )"));
+    }
+
+    /* five-way join with filter predicates and a constant on the join key */
+    Y_UNIT_TEST(FiveWayJoinWithPredsAndEquiv) {
+        ExplainJoinQuery(Q_(R"(
+ SELECT *
+ FROM `/Root/R` as R
+ INNER JOIN
+ `/Root/S` as S
+ ON R.id = S.id
+ INNER JOIN
+ `/Root/T` as T
+ ON S.id = T.id
+ INNER JOIN
+ `/Root/U` as U
+ ON T.id = U.id
+ INNER JOIN
+ `/Root/V` as V
+ ON U.id = V.id
+ WHERE R.payload1 = 'blah' AND V.payload5 = 'blah' AND R.id = 1
+ )"));
+    }
+}
+
+}
+}
+
diff --git a/ydb/core/kqp/ut/join/ya.make b/ydb/core/kqp/ut/join/ya.make
index ee9c559a46e..9897a40d2e2 100644
--- a/ydb/core/kqp/ut/join/ya.make
+++ b/ydb/core/kqp/ut/join/ya.make
@@ -16,6 +16,7 @@ SRCS(
kqp_flip_join_ut.cpp
kqp_index_lookup_join_ut.cpp
kqp_join_ut.cpp
+ kqp_join_order_ut.cpp
)
PEERDIR(
diff --git a/ydb/library/yql/core/CMakeLists.darwin-x86_64.txt b/ydb/library/yql/core/CMakeLists.darwin-x86_64.txt
index e195fe28fd7..0b4b5ebc1d8 100644
--- a/ydb/library/yql/core/CMakeLists.darwin-x86_64.txt
+++ b/ydb/library/yql/core/CMakeLists.darwin-x86_64.txt
@@ -98,6 +98,7 @@ target_sources(library-yql-core PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_rewrite_io.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_utils.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_window.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_statistics.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_type_annotation.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_type_helpers.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_udf_index.cpp
diff --git a/ydb/library/yql/core/CMakeLists.linux-aarch64.txt b/ydb/library/yql/core/CMakeLists.linux-aarch64.txt
index b4486180fb5..66ef852bbe6 100644
--- a/ydb/library/yql/core/CMakeLists.linux-aarch64.txt
+++ b/ydb/library/yql/core/CMakeLists.linux-aarch64.txt
@@ -99,6 +99,7 @@ target_sources(library-yql-core PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_rewrite_io.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_utils.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_window.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_statistics.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_type_annotation.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_type_helpers.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_udf_index.cpp
diff --git a/ydb/library/yql/core/CMakeLists.linux-x86_64.txt b/ydb/library/yql/core/CMakeLists.linux-x86_64.txt
index b4486180fb5..66ef852bbe6 100644
--- a/ydb/library/yql/core/CMakeLists.linux-x86_64.txt
+++ b/ydb/library/yql/core/CMakeLists.linux-x86_64.txt
@@ -99,6 +99,7 @@ target_sources(library-yql-core PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_rewrite_io.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_utils.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_window.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_statistics.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_type_annotation.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_type_helpers.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_udf_index.cpp
diff --git a/ydb/library/yql/core/CMakeLists.windows-x86_64.txt b/ydb/library/yql/core/CMakeLists.windows-x86_64.txt
index e195fe28fd7..0b4b5ebc1d8 100644
--- a/ydb/library/yql/core/CMakeLists.windows-x86_64.txt
+++ b/ydb/library/yql/core/CMakeLists.windows-x86_64.txt
@@ -98,6 +98,7 @@ target_sources(library-yql-core PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_rewrite_io.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_utils.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_window.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_statistics.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_type_annotation.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_type_helpers.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_udf_index.cpp
diff --git a/ydb/library/yql/core/ya.make b/ydb/library/yql/core/ya.make
index 1dbac92b4cf..7665acc9da3 100644
--- a/ydb/library/yql/core/ya.make
+++ b/ydb/library/yql/core/ya.make
@@ -37,6 +37,7 @@ SRCS(
yql_opt_utils.h
yql_opt_window.cpp
yql_opt_window.h
+ yql_statistics.cpp
yql_type_annotation.cpp
yql_type_annotation.h
yql_type_helpers.cpp
diff --git a/ydb/library/yql/core/yql_statistics.cpp b/ydb/library/yql/core/yql_statistics.cpp
new file mode 100644
index 00000000000..8abb9a59b5d
--- /dev/null
+++ b/ydb/library/yql/core/yql_statistics.cpp
@@ -0,0 +1,14 @@
+#include "yql_statistics.h"
+
+using namespace NYql;
+
+namespace NYql {
+
+/**
+ * Writes the statistics as "Nrows: <n>, Ncols: <n>, Cost: <cost|none>".
+ * Defined inside NYql so that it is the function the friend declaration in
+ * TOptimizerStatistics refers to.
+ */
+std::ostream& operator<<(std::ostream& os, const TOptimizerStatistics& s) {
+    os << "Nrows: " << s.Nrows << ", Ncols: " << s.Ncols << ", Cost: ";
+    if (s.Cost.has_value()) {
+        os << s.Cost.value();
+    } else {
+        os << "none";
+    }
+    return os;
+}
+
+} \ No newline at end of file
diff --git a/ydb/library/yql/core/yql_statistics.h b/ydb/library/yql/core/yql_statistics.h
new file mode 100644
index 00000000000..615dce6359b
--- /dev/null
+++ b/ydb/library/yql/core/yql_statistics.h
@@ -0,0 +1,26 @@
+#pragma once
+
+#include <optional>
+#include <iostream>
+
+namespace NYql {
+
+/**
+ * Optimizer Statistics struct records per-table and per-column statistics
+ * for the current operator in the plan. Currently, only Nrows and Ncols are
+ * recorded.
+ * Cost is also included in statistics, as it is updated together with the
+ * statistics all of the time. Cost is optional, so it could be missing.
+*/
+struct TOptimizerStatistics {
+    // Zero-initialized so a default-constructed object never exposes indeterminate values.
+    double Nrows = 0;
+    int Ncols = 0;
+    std::optional<double> Cost;
+
+    TOptimizerStatistics() = default;
+    TOptimizerStatistics(double nrows, int ncols) : Nrows(nrows), Ncols(ncols) {}
+    TOptimizerStatistics(double nrows, int ncols, double cost) : Nrows(nrows), Ncols(ncols), Cost(cost) {}
+
+    friend std::ostream& operator<<(std::ostream& os, const TOptimizerStatistics& s);
+};
+}
diff --git a/ydb/library/yql/core/yql_type_annotation.h b/ydb/library/yql/core/yql_type_annotation.h
index f3f5e595f39..6a1670fd192 100644
--- a/ydb/library/yql/core/yql_type_annotation.h
+++ b/ydb/library/yql/core/yql_type_annotation.h
@@ -5,6 +5,7 @@
#include "yql_udf_resolver.h"
#include "yql_user_data_storage.h"
#include "yql_arrow_resolver.h"
+#include "yql_statistics.h"
#include <ydb/library/yql/public/udf/udf_validate.h>
#include <ydb/library/yql/core/credentials/yql_credentials.h>
@@ -184,6 +185,7 @@ struct TUdfCachedInfo {
};
struct TTypeAnnotationContext: public TThrRefBase {
+ THashMap<const TExprNode*, std::shared_ptr<TOptimizerStatistics>> StatisticsMap;
TIntrusivePtr<ITimeProvider> TimeProvider;
TIntrusivePtr<IRandomProvider> RandomProvider;
THashMap<TString, TIntrusivePtr<IDataProvider>> DataSourceMap;
diff --git a/ydb/library/yql/dq/opt/CMakeLists.darwin-x86_64.txt b/ydb/library/yql/dq/opt/CMakeLists.darwin-x86_64.txt
index 20e4f10070f..6ac7a6e01e7 100644
--- a/ydb/library/yql/dq/opt/CMakeLists.darwin-x86_64.txt
+++ b/ydb/library/yql/dq/opt/CMakeLists.darwin-x86_64.txt
@@ -31,4 +31,5 @@ target_sources(yql-dq-opt PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_peephole.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy_finalizing.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
)
diff --git a/ydb/library/yql/dq/opt/CMakeLists.linux-aarch64.txt b/ydb/library/yql/dq/opt/CMakeLists.linux-aarch64.txt
index bc0a22c399e..54ab33f05f8 100644
--- a/ydb/library/yql/dq/opt/CMakeLists.linux-aarch64.txt
+++ b/ydb/library/yql/dq/opt/CMakeLists.linux-aarch64.txt
@@ -32,4 +32,5 @@ target_sources(yql-dq-opt PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_peephole.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy_finalizing.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
)
diff --git a/ydb/library/yql/dq/opt/CMakeLists.linux-x86_64.txt b/ydb/library/yql/dq/opt/CMakeLists.linux-x86_64.txt
index bc0a22c399e..54ab33f05f8 100644
--- a/ydb/library/yql/dq/opt/CMakeLists.linux-x86_64.txt
+++ b/ydb/library/yql/dq/opt/CMakeLists.linux-x86_64.txt
@@ -32,4 +32,5 @@ target_sources(yql-dq-opt PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_peephole.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy_finalizing.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
)
diff --git a/ydb/library/yql/dq/opt/CMakeLists.windows-x86_64.txt b/ydb/library/yql/dq/opt/CMakeLists.windows-x86_64.txt
index 20e4f10070f..6ac7a6e01e7 100644
--- a/ydb/library/yql/dq/opt/CMakeLists.windows-x86_64.txt
+++ b/ydb/library/yql/dq/opt/CMakeLists.windows-x86_64.txt
@@ -31,4 +31,5 @@ target_sources(yql-dq-opt PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_peephole.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy_finalizing.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
)
diff --git a/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp b/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
new file mode 100644
index 00000000000..03d328ae202
--- /dev/null
+++ b/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
@@ -0,0 +1,1064 @@
+#include "dq_opt_join.h"
+#include "dq_opt_phy.h"
+
+#include <ydb/library/yql/core/yql_join.h>
+#include <ydb/library/yql/core/yql_opt_utils.h>
+#include <ydb/library/yql/dq/type_ann/dq_type_ann.h>
+#include <ydb/library/yql/utils/log/log.h>
+#include <ydb/library/yql/providers/common/provider/yql_provider.h>
+#include <ydb/library/yql/core/yql_type_helpers.h>
+#include <ydb/library/yql/core/yql_statistics.h>
+
+#include <library/cpp/disjoint_sets/disjoint_sets.h>
+
+
+#include <bitset>
+#include <set>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include <queue>
+#include <memory>
+#include <sstream>
+
+namespace NYql::NDq {
+
+
+using namespace NYql::NNodes;
+
+
+/**
+ * A join column names one side of a join condition:
+ * the label of the relation and the attribute inside that relation.
+*/
+struct TJoinColumn {
+    TString RelName;
+    TString AttributeName;
+
+    TJoinColumn(TString relName, TString attributeName)
+        : RelName(relName)
+        , AttributeName(attributeName)
+    {
+    }
+
+    // Columns are equal when both the relation label and the attribute name match
+    bool operator == (const TJoinColumn& other) const {
+        if (RelName != other.RelName) {
+            return false;
+        }
+        return AttributeName == other.AttributeName;
+    }
+
+    // Hasher for hash containers: XOR of the hashes of the two strings
+    struct HashFunction
+    {
+        size_t operator()(const TJoinColumn& c) const
+        {
+            const size_t relHash = THash<TString>{}(c.RelName);
+            const size_t attributeHash = THash<TString>{}(c.AttributeName);
+            return relHash ^ attributeHash;
+        }
+    };
+};
+
+// Lexicographic ordering of join columns: first by relation label,
+// then by attribute name. Needed to keep TJoinColumn in ordered containers.
+bool operator < (const TJoinColumn& c1, const TJoinColumn& c2) {
+    if (c1.RelName != c2.RelName) {
+        return c1.RelName < c2.RelName;
+    }
+    return c1.AttributeName < c2.AttributeName;
+}
+
+/**
+ * A directed edge of the join graph.
+ * - From is the integer id of the source vertex
+ * - To is the integer id of the target vertex
+ * - JoinConditions is the set of join conditions attached to this edge
+ *
+ * Equality and hashing look only at (From, To); the join conditions
+ * do not take part in the identity of an edge.
+*/
+struct TEdge {
+    int From;
+    int To;
+    std::set<std::pair<TJoinColumn, TJoinColumn>> JoinConditions;
+
+    // Edge without join conditions (also used as a lookup key)
+    TEdge(int f, int t)
+        : From(f)
+        , To(t)
+    {
+    }
+
+    // Edge with a single join condition
+    TEdge(int f, int t, std::pair<TJoinColumn, TJoinColumn> cond)
+        : From(f)
+        , To(t)
+    {
+        JoinConditions.insert(cond);
+    }
+
+    // Edge with a whole set of join conditions
+    TEdge(int f, int t, std::set<std::pair<TJoinColumn, TJoinColumn>> conds)
+        : From(f)
+        , To(t)
+        , JoinConditions(conds)
+    {
+    }
+
+    bool operator==(const TEdge& other) const
+    {
+        if (From != other.From) {
+            return false;
+        }
+        return To == other.To;
+    }
+
+    // Hasher consistent with operator==: depends only on the two endpoints
+    struct HashFunction
+    {
+        size_t operator()(const TEdge& e) const
+        {
+            const int endpointSum = e.From + e.To;
+            return endpointSum;
+        }
+    };
+};
+
+/**
+ * Graph is a data structure for the join graph.
+ * It is an undirected graph, stored as two directed edges per connection:
+ * (from,to) and (to,from), the latter with swapped join conditions.
+ * It must be built through AddNode and AddEdge, since these methods keep
+ * the auxiliary indexes (EdgeIdx, ScopeMapping, RevScopeMapping) in sync.
+ * Before it is handed to DPccp the graph has to be renumbered in
+ * breadth-first order with BfsReorder; the renumbering is recorded in
+ * BfsMapping so that the original rel indexes can be recovered.
+ *
+ * N is the capacity of the graph: node ids must lie in [0, N).
+*/
+template <int N>
+struct TGraph {
+    using TJoinConditions = std::set<std::pair<TJoinColumn, TJoinColumn>>;
+
+    // set of edges of the graph
+    std::unordered_set<TEdge,TEdge::HashFunction> Edges;
+
+    // neighbor index: bit j of EdgeIdx[i] is set when i and j are connected
+    TVector<std::bitset<N>> EdgeIdx;
+
+    // number of nodes in the graph (0 for an empty graph)
+    int NNodes = 0;
+
+    // mapping from rel label to node in the graph
+    THashMap<TString,int> ScopeMapping;
+
+    // mapping from node in the graph to rel label
+    TVector<TString> RevScopeMapping;
+
+    // Breadth-first-search mapping, filled in by BfsReorder on the graph it
+    // returns: BfsMapping[newId] is the id the node had before reordering
+    TVector<int> BfsMapping;
+
+    // Empty graph constructor initializes the indexes to size N
+    TGraph() : EdgeIdx(N), RevScopeMapping(N) {}
+
+    // Add a node to the graph with a rel label.
+    // Nodes are expected to be added with increasing ids: NNodes becomes nodeId + 1.
+    void AddNode(int nodeId, TString scope){
+        NNodes = nodeId + 1;
+        ScopeMapping[scope] = nodeId;
+        RevScopeMapping[nodeId] = scope;
+    }
+
+    // Add an edge to the graph. Two directed edges are stored: the forward
+    // edge with the original join conditions and the reverse edge with
+    // swapped join conditions.
+    // If the connection is already present (in either direction) no new edge
+    // is created; instead the join conditions of e are merged into the
+    // existing pair of edges, so that a second predicate between the same
+    // two relations is not lost.
+    void AddEdge(TEdge e){
+        TJoinConditions swappedSet;
+        for (const auto& c : e.JoinConditions){
+            swappedSet.insert(std::make_pair(c.second, c.first));
+        }
+
+        if (Edges.contains(e) || Edges.contains(TEdge(e.To, e.From))) {
+            MergeJoinConditions(e.From, e.To, e.JoinConditions);
+            MergeJoinConditions(e.To, e.From, swappedSet);
+            return;
+        }
+
+        Edges.insert(e);
+        Edges.insert(TEdge(e.To,e.From,swappedSet));
+
+        EdgeIdx[e.From].set(e.To);
+        EdgeIdx[e.To].set(e.From);
+    }
+
+    // Add conds to the join conditions of the stored edge (from,to).
+    // Elements of an unordered_set are immutable in place, so the node is
+    // extracted, updated and re-inserted; the hash and equality of an edge
+    // depend only on its endpoints, so its identity is preserved.
+    // Does nothing when the edge is not present.
+    void MergeJoinConditions(int from, int to, const TJoinConditions& conds) {
+        auto node = Edges.extract(TEdge(from, to));
+        if (node.empty()) {
+            return;
+        }
+        node.value().JoinConditions.insert(conds.begin(), conds.end());
+        Edges.insert(std::move(node));
+    }
+
+    // Find a node by the rel scope
+    // NOTE: an unknown scope is inserted into ScopeMapping and resolves to node 0
+    int FindNode(TString scope){
+        return ScopeMapping[scope];
+    }
+
+    // Return a bitset of node's neighbors
+    inline std::bitset<N> FindNeighbors(int fromVertex)
+    {
+        return EdgeIdx[fromVertex];
+    }
+
+    // Return a bitset of node's neighbors with itself included
+    inline std::bitset<N> FindNeighborsWithSelf(int fromVertex)
+    {
+        std::bitset<N> res = FindNeighbors(fromVertex);
+        res.set(fromVertex);
+        return res;
+    }
+
+    // Find an edge that connects two subsets of graph's nodes: the first
+    // pair (i in S1, j in S2) that is directly connected wins.
+    // Throws (Y_ENSURE) when the two subsets are not connected by any edge.
+    // NOTE(review): only the conditions of this single edge are returned,
+    // even if several edges cross between S1 and S2 - confirm that this is
+    // sufficient for cyclic join graphs.
+    TEdge FindCrossingEdge(const std::bitset<N>& S1, const std::bitset<N>& S2) {
+        for(int i=0; i<NNodes; i++){
+            if (!S1[i]) {
+                continue;
+            }
+            for (int j=0; j<NNodes; j++) {
+                if (!S2[j]) {
+                    continue;
+                }
+                // i and j must be adjacent; sharing a common neighbor is
+                // not enough, since then there is no edge (i,j) to return
+                if (EdgeIdx[i].test(j)) {
+                    auto it = Edges.find(TEdge(i, j));
+                    Y_VERIFY_DEBUG(it != Edges.end());
+                    return *it;
+                }
+            }
+        }
+        Y_ENSURE(false,"Connecting edge not found!");
+        return TEdge(-1,-1);
+    }
+
+    /**
+     * Create a union-set from the join conditions to record the equivalences.
+     * Then use the equivalence set to compute transitive closure of the graph.
+     * Transitive closure means that if we have an edge from (1,2) with join
+     * condition R.A = S.A and we have an edge from (2,3) with join condition
+     * S.A = T.A, we will find out that the join conditions form an equivalence set
+     * and add an edge (1,3) with join condition R.A = T.A.
+     * If the two relations are already connected, the derived condition is
+     * merged into the existing edges (see AddEdge).
+     */
+    void ComputeTransitiveClosure(const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions) {
+        // Collect the distinct columns mentioned by the join conditions
+        std::set<TJoinColumn> columnSet;
+        for (const auto& [ leftCondition, rightCondition ] : joinConditions) {
+            columnSet.insert( leftCondition );
+            columnSet.insert( rightCondition );
+        }
+        std::vector<TJoinColumn> columns(columnSet.begin(), columnSet.end());
+
+        // Give every column a dense index usable by the disjoint-set structure
+        THashMap<TJoinColumn, int, TJoinColumn::HashFunction> indexMapping;
+        for (size_t i=0; i<columns.size(); i++) {
+            indexMapping[ columns[i] ] = i;
+        }
+
+        // Every join condition puts its two columns into the same equivalence class
+        TDisjointSets ds = TDisjointSets( columns.size() );
+        for (const auto& [ leftCondition, rightCondition ] : joinConditions ) {
+            int leftIndex = indexMapping[ leftCondition ];
+            int rightIndex = indexMapping[ rightCondition ];
+            ds.UnionSets(leftIndex,rightIndex);
+        }
+
+        // Connect every pair of equivalent columns that belong to different relations
+        for (size_t i=0;i<columns.size();i++) {
+            for (size_t j=0;j<i;j++) {
+                if (ds.CanonicSetElement(i) != ds.CanonicSetElement(j)) {
+                    continue;
+                }
+
+                const TJoinColumn& left = columns[i];
+                const TJoinColumn& right = columns[j];
+                int leftNodeId = ScopeMapping[ left.RelName ];
+                int rightNodeId = ScopeMapping[ right.RelName ];
+
+                // Two equivalent columns of the same relation do not form a
+                // join edge; adding one would create a self-loop in the graph
+                if (leftNodeId == rightNodeId) {
+                    continue;
+                }
+
+                // Creates the pair of edges, or merges the condition into
+                // the already existing pair
+                AddEdge(TEdge(leftNodeId,rightNodeId,std::make_pair(left, right)));
+            }
+        }
+    }
+
+    /**
+     * Reorder the graph by doing a breadth first search from node 0.
+     * This is required by the DPccp algorithm.
+     *
+     * Returns a new graph in which node k is the k-th node visited by the
+     * search; labels and edges are renumbered consistently. BfsMapping of
+     * the returned graph maps a new node id back to the id in this graph.
+     * Nodes that are not reachable from node 0 are numbered after the
+     * reachable ones, by restarting the search from the smallest unvisited id.
+     */
+    TGraph<N> BfsReorder() {
+        TVector<int> newToOld;              // newToOld[newId] = oldId
+        TVector<int> oldToNew(NNodes, -1);  // oldToNew[oldId] = newId, -1 = not visited yet
+        std::queue<int> queue;
+
+        for (int start=0; start<NNodes; start++) {
+            if (oldToNew[start] != -1) {
+                continue;
+            }
+            queue.push(start);
+
+            while(! queue.empty()) {
+                int curr = queue.front();
+                queue.pop();
+                if (oldToNew[curr] != -1) {
+                    continue;
+                }
+                oldToNew[curr] = static_cast<int>(newToOld.size());
+                newToOld.push_back(curr);
+
+                const std::bitset<N>& neighbors = EdgeIdx[curr];
+                for (int i=0; i<NNodes; i++ ) {
+                    if (neighbors[i] && oldToNew[i] == -1) {
+                        queue.push(i);
+                    }
+                }
+            }
+        }
+
+        TGraph<N> res;
+
+        // New node newId carries the label of the old node it came from
+        for (int newId=0; newId<NNodes; newId++) {
+            res.AddNode(newId, RevScopeMapping[newToOld[newId]]);
+        }
+
+        // Edges are renumbered with the same old -> new mapping as the nodes
+        for (const TEdge& e : Edges){
+            res.AddEdge( TEdge(oldToNew[e.From], oldToNew[e.To], e.JoinConditions) );
+        }
+
+        res.BfsMapping = newToOld;
+
+        return res;
+    }
+
+    /**
+     * Print the graph
+     */
+    void PrintGraph(std::stringstream& stream) {
+        stream << "Join Graph:\n";
+        stream << "nNodes: " << NNodes << ", nEdges: " << Edges.size() << "\n";
+
+        for(int i=0;i<NNodes;i++) {
+            stream << "Node:" << i << "," << RevScopeMapping[i] << "\n";
+        }
+        for (const TEdge& e: Edges ) {
+            stream << "Edge: " << e.From << " -> " << e.To << "\n";
+            for (const auto& p : e.JoinConditions) {
+                stream << p.first.RelName << "."
+                    << p.first.AttributeName << "="
+                    << p.second.RelName << "."
+                    << p.second.AttributeName << "\n";
+            }
+        }
+    }
+};
+
+/**
+ * Collect the join conditions of an equi-join tree.
+ * The tree is walked recursively (left subtree, right subtree, then the
+ * tuple itself) and every key pair found is added to joinConditions.
+ * Keys are stored as flat lists of (scope, column) atoms, so each join
+ * key occupies two consecutive items.
+*/
+void ComputeJoinConditions(const TCoEquiJoinTuple& joinTuple,
+    std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions) {
+
+    // Nested join trees first
+    if (joinTuple.LeftScope().Maybe<TCoEquiJoinTuple>()) {
+        ComputeJoinConditions( joinTuple.LeftScope().Cast<TCoEquiJoinTuple>(), joinConditions );
+    }
+    if (joinTuple.RightScope().Maybe<TCoEquiJoinTuple>()) {
+        ComputeJoinConditions( joinTuple.RightScope().Cast<TCoEquiJoinTuple>(), joinConditions );
+    }
+
+    // Then the keys of this tuple, two atoms at a time
+    const size_t flatKeyCount = (joinTuple.LeftKeys().Size() / 2) * 2;
+    for (size_t keyIndex = 0; keyIndex < flatKeyCount; keyIndex += 2) {
+        TJoinColumn leftSide(
+            joinTuple.LeftKeys().Item(keyIndex).StringValue(),
+            joinTuple.LeftKeys().Item(keyIndex + 1).StringValue());
+        TJoinColumn rightSide(
+            joinTuple.RightKeys().Item(keyIndex).StringValue(),
+            joinTuple.RightKeys().Item(keyIndex + 1).StringValue());
+
+        joinConditions.insert( std::make_pair( leftSide, rightSide ));
+    }
+}
+
+/**
+ * OptimizerNodes are the internal representation of operators inside the
+ * cost-based optimizer. Two kinds are currently supported:
+ * - RelNodeType: a node that is an input relation of the equi-join
+ * - JoinNodeType: an inner join that connects two sets of relations
+*/
+enum EOptimizerNodeKind: ui32
+{
+    RelNodeType = 0,
+    JoinNodeType = 1
+};
+
+/**
+ * BaseOptimizerNode is a base class for the internal optimizer nodes.
+ * It records the kind of the node and a pointer to the statistics
+ * (which also carry the cost) of the operator tree rooted at this node.
+*/
+struct IBaseOptimizerNode {
+    EOptimizerNodeKind Kind;
+    std::shared_ptr<TOptimizerStatistics> Stats;
+
+    // Node without statistics; they are expected to be filled in later
+    IBaseOptimizerNode(EOptimizerNodeKind k) : Kind(k) {}
+    // Node with already known statistics
+    IBaseOptimizerNode(EOptimizerNodeKind k, std::shared_ptr<TOptimizerStatistics> s) :
+        Kind(k), Stats(s) {}
+
+    // Nodes are owned and passed around through pointers to this base class,
+    // so destruction through the base must reach the derived destructor
+    virtual ~IBaseOptimizerNode() = default;
+
+    // Print the subtree rooted at this node, indented by ntabs tabs
+    virtual void Print(std::stringstream& stream, int ntabs=0)=0;
+};
+
+/**
+ * RelOptimizerNode adds a label to the base class.
+ * This is the label assigned to the input by the equi-join.
+*/
+struct TRelOptimizerNode : public IBaseOptimizerNode {
+    TString Label;
+
+    TRelOptimizerNode(TString label, std::shared_ptr<TOptimizerStatistics> stats) :
+        IBaseOptimizerNode(RelNodeType, stats), Label(label) { }
+    virtual ~TRelOptimizerNode() {}
+
+    // Print the label of the relation and its statistics,
+    // each on its own line indented by ntabs tabs
+    virtual void Print(std::stringstream& stream, int ntabs=0) {
+        for (int i=0;i<ntabs;i++){
+            stream << "\t";
+        }
+        stream << "Rel: " << Label << "\n";
+
+        for (int i=0;i<ntabs;i++){
+            stream << "\t";
+        }
+        // Stream the statistics themselves: streaming the shared_ptr would
+        // only print the address of the object
+        if (Stats) {
+            stream << *Stats;
+        }
+        stream << "\n";
+    }
+};
+
+/**
+ * JoinOptimizerNode records the left and right arguments of the join
+ * as well as the set of join conditions.
+ * It also has methods to compute the statistics and cost of a join,
+ * based on pre-computed costs and statistics of the children.
+*/
+struct TJoinOptimizerNode : public IBaseOptimizerNode {
+    std::shared_ptr<IBaseOptimizerNode> LeftArg;
+    std::shared_ptr<IBaseOptimizerNode> RightArg;
+    std::set<std::pair<TJoinColumn, TJoinColumn>> JoinConditions;
+
+    TJoinOptimizerNode(std::shared_ptr<IBaseOptimizerNode> left, std::shared_ptr<IBaseOptimizerNode> right,
+        const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions) :
+        IBaseOptimizerNode(JoinNodeType), LeftArg(left), RightArg(right), JoinConditions(joinConditions) {}
+    virtual ~TJoinOptimizerNode() {}
+
+    /**
+     * Compute and set the statistics for this node.
+     * Currently we have a very rough calculation of statistics:
+     * the cardinality is 0.2 * |left| * |right| and the number of columns
+     * is the sum of the columns of both arguments.
+     * Both arguments must already have statistics.
+     */
+    void ComputeStatistics() {
+        double newCard = 0.2 * LeftArg->Stats->Nrows * RightArg->Stats->Nrows;
+        int newNCols = LeftArg->Stats->Ncols + RightArg->Stats->Ncols;
+        Stats = std::make_shared<TOptimizerStatistics>(newCard, newNCols);
+    }
+
+    /**
+     * Compute the cost of the join based on statistics and costs of children.
+     * Again, we only have a rough calculation at this time.
+     * Requires ComputeStatistics to have been called, and throws (Y_ENSURE)
+     * when one of the children has no cost.
+     */
+    double ComputeCost() {
+        Y_ENSURE(LeftArg->Stats->Cost.has_value() && RightArg->Stats->Cost.has_value(),
+            "Missing values for costs in join computation");
+
+        return 2.0 * LeftArg->Stats->Nrows + RightArg->Stats->Nrows
+            + Stats->Nrows
+            + LeftArg->Stats->Cost.value() + RightArg->Stats->Cost.value();
+    }
+
+    /**
+     * Print out the join tree, rooted at this node
+     */
+    virtual void Print(std::stringstream& stream, int ntabs=0) {
+        for (int i=0;i<ntabs;i++){
+            stream << "\t";
+        }
+
+        stream << "Join: ";
+        for (const auto& c : JoinConditions){
+            stream << c.first.RelName << "." << c.first.AttributeName
+                << "=" << c.second.RelName << "."
+                << c.second.AttributeName << ", ";
+        }
+        stream << "\n";
+
+        for (int i=0;i<ntabs;i++){
+            stream << "\t";
+        }
+
+        // Stream the statistics themselves: streaming the shared_ptr would
+        // only print the address of the object
+        if (Stats) {
+            stream << *Stats;
+        }
+        stream << "\n";
+
+        LeftArg->Print(stream, ntabs+1);
+        RightArg->Print(stream, ntabs+1);
+    }
+};
+
+
+/**
+ * Build a join node over the two arguments and fill in its statistics
+ * and its cost. Statistics have to be computed first, because the cost
+ * formula reads the cardinality of the join itself.
+*/
+std::shared_ptr<TJoinOptimizerNode> MakeJoin(std::shared_ptr<IBaseOptimizerNode> left,
+    std::shared_ptr<IBaseOptimizerNode> right, const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions) {
+
+    std::shared_ptr<TJoinOptimizerNode> joinNode(new TJoinOptimizerNode(left, right, joinConditions));
+
+    joinNode->ComputeStatistics();
+    const double joinCost = joinNode->ComputeCost();
+    joinNode->Stats->Cost = joinCost;
+
+    return joinNode;
+}
+
+// Hasher for std::pair: XOR of the std::hash values of both members.
+struct pair_hash {
+    template <class T1, class T2>
+    std::size_t operator () (const std::pair<T1,T2> &p) const {
+        // Deliberately simple combination: it works, but a stronger mix
+        // (something like boost.hash_combine) would be preferable
+        const std::size_t firstHash = std::hash<T1>{}(p.first);
+        const std::size_t secondHash = std::hash<T2>{}(p.second);
+        return firstHash ^ secondHash;
+    }
+};
+
+/**
+ * DPccp (Dynamic Programming with connected complement pairs) is a graph-aware
+ * join enumeration algorithm that only considers CSGs (Connected Sub-Graphs) of
+ * the join graph and computes CMPs (Complement pairs) that are also connected
+ * subgraphs of the join graph. It enumerates CSGs in such an order that subsets
+ * are enumerated first and no duplicates are ever enumerated. Then, for each emitted
+ * CSG it computes the complements under the same conditions - they must already be
+ * present in the dynamic programming table and no pair should be enumerated twice.
+ *
+ * The DPccp solver is templated by the largest number of relations it can process;
+ * this number is the size of the bitsets that represent sets of relations.
+*/
+template <int N>
+class TDPccpSolver {
+    public:
+
+    // Construct the DPccp solver from the join graph and the input relations.
+    // The graph is kept by reference and has to outlive the solver.
+    TDPccpSolver(TGraph<N>& g, std::vector<std::shared_ptr<TRelOptimizerNode>> rels)
+        : NNodes(g.NNodes)
+        , Graph(g)
+        , Rels(rels)
+    {
+    }
+
+    // Run the DPccp algorithm and produce the join tree in CBO's internal representation
+    std::shared_ptr<TJoinOptimizerNode> Solve();
+
+    private:
+
+    // Compute the subset that follows `current` in the enumeration of the subsets of `final`
+    std::bitset<N> NextBitset(const std::bitset<N>& current, const std::bitset<N>& final);
+
+    // Print the set of relations contained in a bitset
+    void PrintBitset(std::stringstream& stream, const std::bitset<N>& s, std::string name, int ntabs=0);
+
+    // Dynamic programming table: best join subtree found so far for every set of relations
+    THashMap<std::bitset<N>, std::shared_ptr<IBaseOptimizerNode>, std::hash<std::bitset<N>>> DpTable;
+
+    // REMOVE: Sanity check table that tracks that we don't consider the same pair twice
+    THashMap<std::pair<std::bitset<N>, std::bitset<N>>, bool, pair_hash> CheckTable;
+
+    // Number of nodes in the graph
+    int NNodes;
+
+    // Join graph
+    TGraph<N>& Graph;
+
+    // List of input relations to DPccp
+    std::vector<std::shared_ptr<TRelOptimizerNode>> Rels;
+
+    // Emit connected subgraph
+    void EmitCsg(const std::bitset<N>& S, int ntabs=0);
+
+    // Enumerate subgraphs recursively
+    void EnumerateCsgRec(const std::bitset<N>& S, const std::bitset<N>& X, int ntabs=0);
+
+    // Emit the final pair of CSG and CMP - compute the join and record it in the
+    // DP table
+    void EmitCsgCmp(const std::bitset<N>& S1, const std::bitset<N>& S2, int ntabs=0);
+
+    // Enumerate complement pairs recursively
+    void EnumerateCmpRec(const std::bitset<N>& S1, const std::bitset<N>& S2, const std::bitset<N>& X, int ntabs=0);
+
+    // Compute the neighbors of a set of nodes, excluding the nodes in the exclusion set
+    std::bitset<N> Neighbors(const std::bitset<N>& S, const std::bitset<N>& X);
+
+    // Create an exclusion set that contains all the nodes of the graph that are smaller or equal to
+    // the smallest node in the provided bitset
+    std::bitset<N> MakeBiMin(const std::bitset<N>& S);
+
+    // Create an exclusion set that contains all the nodes of the bitset that are smaller or equal to
+    // the provided integer
+    std::bitset<N> MakeB(const std::bitset<N>& S, int x);
+};
+
+// Write ntabs tab characters to the stream (nothing when ntabs <= 0)
+void PrintTabs(std::stringstream& stream, int ntabs) {
+    int remaining = ntabs;
+    while (remaining > 0) {
+        stream << "\t";
+        --remaining;
+    }
+}
+
+// Print the ids of the graph nodes contained in the bitset, in the form
+// "<name>: {i,j,}" preceded by ntabs tabs
+template <int N> void TDPccpSolver<N>::PrintBitset(std::stringstream& stream,
+    const std::bitset<N>& s, std::string name, int ntabs) {
+
+    PrintTabs(stream, ntabs);
+    stream << name << ": " << "{";
+
+    for (int node = 0; node < NNodes; ++node) {
+        if (!s[node]) {
+            continue;
+        }
+        stream << node << ",";
+    }
+
+    stream <<"}\n";
+}
+
+// Compute the neighbors of a set of nodes S, excluding the exclusion set X:
+// the union of the neighbor sets of all members of S, minus X
+template<int N> std::bitset<N> TDPccpSolver<N>::Neighbors(const std::bitset<N>& S, const std::bitset<N>& X) {
+
+    std::bitset<N> reachable;
+
+    for (int node = 0; node < Graph.NNodes; ++node) {
+        if (!S[node]) {
+            continue;
+        }
+        reachable |= Graph.FindNeighbors(node);
+    }
+
+    return reachable & ~X;
+}
+
+// Run the entire DPccp algorithm and compute the optimal join tree.
+// Throws (Y_ENSURE) if no plan was found for the full set of relations.
+template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve()
+{
+    // Bitset that contains exactly one node
+    const auto singleton = [](int node) {
+        std::bitset<N> s;
+        s.set(node);
+        return s;
+    };
+
+    // Seed the DP table with the singleton sets; the relation of a graph
+    // node is looked up in Rels through Graph.BfsMapping
+    for (int i = NNodes - 1; i >= 0; --i) {
+        DpTable[singleton(i)] = Rels[Graph.BfsMapping[i]];
+    }
+
+    // Expand the singleton sets, from the largest node id down to 0
+    for (int i = NNodes - 1; i >= 0; --i) {
+        const std::bitset<N> s = singleton(i);
+        EmitCsg(s);
+        EnumerateCsgRec(s, MakeBiMin(s));
+    }
+
+    // The answer is the DP table entry for the full set of graph nodes
+    std::bitset<N> allNodes;
+    for (int i = 0; i < NNodes; ++i) {
+        allNodes.set(i);
+    }
+
+    Y_ENSURE(DpTable.contains(allNodes), "Final relset not in dptable");
+    return std::dynamic_pointer_cast<TJoinOptimizerNode>(DpTable[allNodes]);
+}
+
+/**
+ * EmitCsg emits Connected SubGraphs.
+ * It walks the neighbors of the set S (largest node id first) and for every
+ * neighbor emits the pair (S,S2), where S2 is that single neighbor; after each
+ * pair it recursively enumerates the larger complements that grow from S2.
+*/
+ template <int N> void TDPccpSolver<N>::EmitCsg(const std::bitset<N>& S, int ntabs) {
+    const std::bitset<N> X = S | MakeBiMin(S);
+    const std::bitset<N> Ns = Neighbors(S, X);
+
+    if (Ns.none()) {
+        return;
+    }
+
+    for (int i = NNodes - 1; i >= 0; --i) {
+        if (!Ns[i]) {
+            continue;
+        }
+
+        std::bitset<N> S2;
+        S2.set(i);
+
+        EmitCsgCmp(S, S2, ntabs+1);
+        EnumerateCmpRec(S, S2, X | MakeB(Ns, i), ntabs+1);
+    }
+ }
+
+ /**
+  * Enumerates connected subgraphs.
+  * First pass: for every non-empty subset of the neighbors of S (outside X),
+  * emit the CSG made of S plus that subset.
+  * Second pass: recurse on each of those enlarged sets, with all the
+  * neighbors added to the exclusion set.
+  */
+ template <int N> void TDPccpSolver<N>::EnumerateCsgRec(const std::bitset<N>& S, const std::bitset<N>& X, int ntabs) {
+
+    const std::bitset<N> Ns = Neighbors(S,X);
+
+    if (Ns.none()) {
+        return;
+    }
+
+    // NextBitset walks through the subsets of Ns and finishes with Ns itself
+    for (std::bitset<N> subset = NextBitset(std::bitset<N>(), Ns); ; subset = NextBitset(subset, Ns)) {
+        EmitCsg(S | subset);
+        if (subset == Ns) {
+            break;
+        }
+    }
+
+    for (std::bitset<N> subset = NextBitset(std::bitset<N>(), Ns); ; subset = NextBitset(subset, Ns)) {
+        EnumerateCsgRec(S | subset, X | Ns, ntabs+1);
+        if (subset == Ns) {
+            break;
+        }
+    }
+ }
+
+/***
+ * Enumerates complement pairs.
+ * First pass: emit the pairs (S1,S2+next), where next runs over the
+ * non-empty subsets of the neighbors of S2 (outside X).
+ * Second pass: recurse into each pair (S1,S2+next), with all the
+ * neighbors added to the exclusion set.
+*/
+ template <int N> void TDPccpSolver<N>::EnumerateCmpRec(const std::bitset<N>& S1,
+    const std::bitset<N>& S2, const std::bitset<N>& X, int ntabs) {
+
+    const std::bitset<N> Ns = Neighbors(S2, X);
+
+    if (Ns.none()) {
+        return;
+    }
+
+    // NextBitset walks through the subsets of Ns and finishes with Ns itself
+    for (std::bitset<N> subset = NextBitset(std::bitset<N>(), Ns); ; subset = NextBitset(subset, Ns)) {
+        EmitCsgCmp(S1, S2 | subset, ntabs+1);
+        if (subset == Ns) {
+            break;
+        }
+    }
+
+    for (std::bitset<N> subset = NextBitset(std::bitset<N>(), Ns); ; subset = NextBitset(subset, Ns)) {
+        EnumerateCmpRec(S1, S2 | subset, X | Ns, ntabs+1);
+        if (subset == Ns) {
+            break;
+        }
+    }
+ }
+
+/**
+ * Emit a single CSG + CMP pair.
+ * Both join orders (S1 x S2 and S2 x S1) are built and the cheapest plan
+ * for the union of the two sets is kept in the DP table.
+ * Throws (Y_ENSURE) if either set has no plan yet, or if the same pair
+ * is emitted twice.
+*/
+template <int N> void TDPccpSolver<N>::EmitCsgCmp(const std::bitset<N>& S1, const std::bitset<N>& S2, int ntabs) {
+
+    Y_UNUSED(ntabs);
+
+    Y_ENSURE(DpTable.contains(S1),"DP Table does not contain S1");
+    Y_ENSURE(DpTable.contains(S2),"DP Table does not conaint S2");
+
+    const std::bitset<N> joined = S1 | S2;
+    const bool alreadyPlanned = DpTable.contains(joined);
+
+    // Candidate S1 x S2: taken unconditionally when there is no plan yet,
+    // otherwise it competes with the plan in the table further down
+    const TEdge forwardEdge = Graph.FindCrossingEdge(S1, S2);
+    std::shared_ptr<TJoinOptimizerNode> forwardJoin =
+        MakeJoin(DpTable[S1], DpTable[S2], forwardEdge.JoinConditions);
+    if (!alreadyPlanned) {
+        DpTable[joined] = forwardJoin;
+    }
+
+    // Candidate S2 x S1
+    const TEdge backwardEdge = Graph.FindCrossingEdge(S2, S1);
+    std::shared_ptr<TJoinOptimizerNode> backwardJoin =
+        MakeJoin(DpTable[S2], DpTable[S1], backwardEdge.JoinConditions);
+
+    if (alreadyPlanned && forwardJoin->Stats->Cost.value() < DpTable[joined]->Stats->Cost.value()) {
+        DpTable[joined] = forwardJoin;
+    }
+    if (backwardJoin->Stats->Cost.value() < DpTable[joined]->Stats->Cost.value()) {
+        DpTable[joined] = backwardJoin;
+    }
+
+    // Sanity check: no pair may be considered twice
+    const std::pair<std::bitset<N>, std::bitset<N>> emittedPair(S1, S2);
+    Y_ENSURE (!CheckTable.contains(emittedPair), "Check table already contains pair S1|S2");
+
+    CheckTable[emittedPair] = true;
+}
+
+/**
+ * Create an exclusion set that contains all the nodes of the graph that are
+ * smaller or equal to the smallest node in the provided bitset.
+ * The result is empty when the bitset has no node below NNodes.
+*/
+template <int N> std::bitset<N> TDPccpSolver<N>::MakeBiMin(const std::bitset<N>& S) {
+    std::bitset<N> res;
+
+    // Locate the smallest node of S
+    int smallest = 0;
+    while (smallest < NNodes && !S[smallest]) {
+        ++smallest;
+    }
+
+    if (smallest < NNodes) {
+        for (int node = 0; node <= smallest; ++node) {
+            res.set(node);
+        }
+    }
+
+    return res;
+}
+
+/**
+ * Create an exclusion set that contains all the nodes of the bitset that are
+ * smaller or equal to the provided integer.
+*/
+template <int N> std::bitset<N> TDPccpSolver<N>::MakeB(const std::bitset<N>& S, int x) {
+    std::bitset<N> res;
+
+    for (int node = 0; node < NNodes && node <= x; ++node) {
+        if (S[node]) {
+            res.set(node);
+        }
+    }
+
+    return res;
+}
+
+/**
+ * Compute the next subset of relations, given by the final bitset.
+ * This is a binary increment of prev restricted to the positions that are
+ * set in final: set positions are cleared from the lowest one upwards until
+ * a clear position is found, which is then set.
+ * Once prev has reached final, final is returned.
+*/
+template <int N> std::bitset<N> TDPccpSolver<N>::NextBitset(const std::bitset<N>& prev, const std::bitset<N>& final) {
+    if (prev==final) {
+        return final;
+    }
+
+    std::bitset<N> res = prev;
+
+    for (int i = 0; i < NNodes; ++i) {
+        // positions outside of final do not take part in the counting
+        if (!final[i]) {
+            continue;
+        }
+
+        if (res[i]) {
+            // the carry travels on to the next position
+            res.reset(i);
+        } else {
+            // the carry is absorbed here, the increment is complete
+            res.set(i);
+            break;
+        }
+    }
+
+    return res;
+
+    // TODO: We can optimize this with a few long integer operations,
+    // but it will only work for 64 bit bitsets
+    // return std::bitset<N>((prev | ~final).to_ulong() + 1) & final;
+}
+
+/**
+ * Build a join tree that will replace the original join tree in equiJoin.
+ * The optimizer tree is converted recursively: a relation node becomes an
+ * atom with its label, a join node becomes a nested "Inner" TCoEquiJoinTuple.
+ * TODO: Add join implementations here
+*/
+TExprBase BuildTree(TExprContext& ctx, const TCoEquiJoin& equiJoin,
+    std::shared_ptr<TJoinOptimizerNode>& reorderResult) {
+
+    // Convert one argument of the join into an expression
+    const auto buildArg = [&ctx, &equiJoin](const std::shared_ptr<IBaseOptimizerNode>& arg) -> TExprBase {
+        if (arg->Kind == RelNodeType) {
+            std::shared_ptr<TRelOptimizerNode> rel = std::dynamic_pointer_cast<TRelOptimizerNode>(arg);
+            return BuildAtom(rel->Label, equiJoin.Pos(), ctx);
+        }
+        std::shared_ptr<TJoinOptimizerNode> join = std::dynamic_pointer_cast<TJoinOptimizerNode>(arg);
+        return BuildTree(ctx, equiJoin, join);
+    };
+
+    // Left argument first, then the right one
+    TExprBase leftArg = buildArg(reorderResult->LeftArg);
+    TExprBase rightArg = buildArg(reorderResult->RightArg);
+
+    // Join conditions: every key is a (relation label, attribute name) pair of atoms
+    TVector<TExprBase> leftJoinColumns;
+    TVector<TExprBase> rightJoinColumns;
+
+    for (const auto& condition : reorderResult->JoinConditions) {
+        const TJoinColumn& leftColumn = condition.first;
+        const TJoinColumn& rightColumn = condition.second;
+
+        leftJoinColumns.push_back(BuildAtom(leftColumn.RelName, equiJoin.Pos(), ctx));
+        leftJoinColumns.push_back(BuildAtom(leftColumn.AttributeName, equiJoin.Pos(), ctx));
+        rightJoinColumns.push_back(BuildAtom(rightColumn.RelName, equiJoin.Pos(), ctx));
+        rightJoinColumns.push_back(BuildAtom(rightColumn.AttributeName, equiJoin.Pos(), ctx));
+    }
+
+    // No join options are generated
+    TVector<TExprBase> options;
+
+    // Build the final output
+    return Build<TCoEquiJoinTuple>(ctx,equiJoin.Pos())
+        .Type(BuildAtom("Inner",equiJoin.Pos(),ctx))
+        .LeftScope(leftArg)
+        .RightScope(rightArg)
+        .LeftKeys()
+            .Add(leftJoinColumns)
+            .Build()
+        .RightKeys()
+            .Add(rightJoinColumns)
+            .Build()
+        .Options()
+            .Add(options)
+            .Build()
+        .Done();
+}
+
+/**
+ * Rebuild the equiJoin operator around a new join tree obtained by
+ * optimizing the join order: the inputs are kept as they are, the join
+ * tree is replaced and the last argument is copied over unchanged.
+*/
+TExprBase RearrangeEquiJoinTree(TExprContext& ctx, const TCoEquiJoin& equiJoin,
+    std::shared_ptr<TJoinOptimizerNode> reorderResult) {
+
+    const size_t inputCount = equiJoin.ArgCount() - 2;
+    const size_t lastArgIndex = equiJoin.ArgCount() - 1;
+
+    TVector<TExprBase> joinArgs;
+
+    // the inputs of the join
+    for (size_t argIndex = 0; argIndex < inputCount; ++argIndex) {
+        joinArgs.push_back(equiJoin.Arg(argIndex));
+    }
+    // the new join tree
+    joinArgs.push_back(BuildTree(ctx,equiJoin,reorderResult));
+    // the trailing argument of the original equiJoin
+    joinArgs.push_back(equiJoin.Arg(lastArgIndex));
+
+    return Build<TCoEquiJoin>(ctx, equiJoin.Pos())
+        .Add(joinArgs)
+        .Done();
+}
+
+/**
+ * Check if all joins in the equiJoin tree are Inner Joins
+ * FIX: This is a temporary solution, need to be able to process all types of joins in the future
+*/
+bool AllInnerJoins(const TCoEquiJoinTuple& joinTuple) {
+    // the join at this level
+    if (joinTuple.Type() != "Inner") {
+        return false;
+    }
+
+    // a nested join tree on the left
+    if (joinTuple.LeftScope().Maybe<TCoEquiJoinTuple>() &&
+        !AllInnerJoins(joinTuple.LeftScope().Cast<TCoEquiJoinTuple>())) {
+        return false;
+    }
+
+    // a nested join tree on the right
+    if (joinTuple.RightScope().Maybe<TCoEquiJoinTuple>() &&
+        !AllInnerJoins(joinTuple.RightScope().Cast<TCoEquiJoinTuple>())) {
+        return false;
+    }
+
+    return true;
+}
+
+/**
+ * Main routine of the cost based join reordering. It checks:
+ * 1. Do we have an equiJoin
+ * 2. Is the cost already computed
+ * 3. FIX: Are all joins InnerJoins
+ * 4. Does the number of inputs fit into the join graph
+ * 5. Are all the costs of equiJoin inputs computed?
+ *
+ * Then it extracts join conditions from the join tree, constructs a join graph and
+ * optimizes it with the DPccp algorithm.
+ *
+ * Returns the original node unchanged whenever one of the checks fails;
+ * otherwise returns the rebuilt equiJoin and records its statistics in
+ * typesCtx.StatisticsMap.
+*/
+TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx,
+    bool ruleEnabled) {
+
+    // Capacity of the join graph and of the bitsets used by the DPccp solver
+    constexpr int MaxJoinRels = 64;
+
+    if (!ruleEnabled) {
+        return node;
+    }
+
+    if (!node.Maybe<TCoEquiJoin>()) {
+        return node;
+    }
+    auto equiJoin = node.Cast<TCoEquiJoin>();
+    YQL_ENSURE(equiJoin.ArgCount() >= 4);
+
+    // The cost is already known: this join has been processed before
+    if (auto it = typesCtx.StatisticsMap.find(equiJoin.Raw());
+        it != typesCtx.StatisticsMap.end() && it->second->Cost.has_value()) {
+
+        return node;
+    }
+
+    if (! AllInnerJoins(equiJoin.Arg( equiJoin.ArgCount() - 2).Cast<TCoEquiJoinTuple>())) {
+        return node;
+    }
+
+    // The graph and the solver index their nodes in fixed-size structures;
+    // a join with more inputs than that cannot be represented and is left as is
+    const size_t inputCount = equiJoin.ArgCount() - 2;
+    if (inputCount > static_cast<size_t>(MaxJoinRels)) {
+        return node;
+    }
+
+    YQL_CLOG(TRACE, CoreDq) << "Optimizing join with costs";
+
+    // Trace output goes to the CoreDq component, so that is the component
+    // whose log level decides whether the dumps below are worth building
+    const bool traceEnabled =
+        NYql::NLog::YqlLogger().NeedToLog(NYql::NLog::EComponent::CoreDq, NYql::NLog::ELevel::TRACE);
+
+    TVector<std::shared_ptr<TRelOptimizerNode>> rels;
+
+    // Check that statistics for all inputs of equiJoin were computed
+    // The first n-2 arguments of the EquiJoin are its inputs, argument n-2 is
+    // the actual join tree and argument n-1 holds the parameters of the EquiJoin
+    for (size_t i = 0; i < inputCount; ++i) {
+        auto input = equiJoin.Arg(i).Cast<TCoEquiJoinInput>();
+        auto joinArg = input.List();
+
+        auto statsIt = typesCtx.StatisticsMap.find( joinArg.Raw() );
+        if (statsIt == typesCtx.StatisticsMap.end()) {
+            return node;
+        }
+
+        auto stats = statsIt->second;
+        if (!stats->Cost.has_value()) {
+            return node;
+        }
+
+        auto scope = input.Scope();
+        if (!scope.Maybe<TCoAtom>()){
+            return node;
+        }
+
+        auto label = scope.Cast<TCoAtom>().StringValue();
+        rels.push_back( std::make_shared<TRelOptimizerNode>( label, stats ));
+    }
+
+    YQL_CLOG(TRACE, CoreDq) << "All statistics for join in place";
+
+    std::set<std::pair<TJoinColumn, TJoinColumn>> joinConditions;
+
+    // EquiJoin argument n-2 is the actual join tree, represented as TCoEquiJoinTuple
+    ComputeJoinConditions(equiJoin.Arg( equiJoin.ArgCount() - 2).Cast<TCoEquiJoinTuple>(), joinConditions );
+
+    // construct a graph out of join conditions
+    TGraph<MaxJoinRels> joinGraph;
+    for (size_t i=0; i<rels.size(); i++) {
+        joinGraph.AddNode(i,rels[i]->Label);
+    }
+
+    for (const auto& cond : joinConditions ) {
+        int fromNode = joinGraph.FindNode(cond.first.RelName);
+        int toNode = joinGraph.FindNode(cond.second.RelName);
+        joinGraph.AddEdge(TEdge(fromNode,toNode,cond));
+    }
+
+    if (traceEnabled) {
+        std::stringstream str;
+        str << "Initial join graph:\n";
+        joinGraph.PrintGraph(str);
+        YQL_CLOG(TRACE, CoreDq) << str.str();
+    }
+
+    // make a transitive closure of the graph and reorder the graph via BFS
+    joinGraph.ComputeTransitiveClosure(joinConditions);
+
+    if (traceEnabled) {
+        std::stringstream str;
+        str << "Join graph after transitive closure:\n";
+        joinGraph.PrintGraph(str);
+        YQL_CLOG(TRACE, CoreDq) << str.str();
+    }
+
+    joinGraph = joinGraph.BfsReorder();
+
+    // feed the graph to DPccp algorithm
+    TDPccpSolver<MaxJoinRels> solver(joinGraph,rels);
+    std::shared_ptr<TJoinOptimizerNode> result = solver.Solve();
+
+    if (traceEnabled) {
+        std::stringstream str;
+        str << "Join tree after cost based optimization:\n";
+        result->Print(str);
+        YQL_CLOG(TRACE, CoreDq) << str.str();
+    }
+
+    // rewrite the join tree and record the output statistics
+    TExprBase res = RearrangeEquiJoinTree(ctx,equiJoin,result);
+    typesCtx.StatisticsMap[ res.Raw() ] = result->Stats;
+    return res;
+}
+
+}
diff --git a/ydb/library/yql/dq/opt/dq_opt_log.h b/ydb/library/yql/dq/opt/dq_opt_log.h
index 6070de70c76..6a0dd1b4d88 100644
--- a/ydb/library/yql/dq/opt/dq_opt_log.h
+++ b/ydb/library/yql/dq/opt/dq_opt_log.h
@@ -18,6 +18,10 @@ NNodes::TExprBase DqRewriteAggregate(NNodes::TExprBase node, TExprContext& ctx,
NNodes::TExprBase DqRewriteTakeSortToTopSort(NNodes::TExprBase node, TExprContext& ctx, const TParentsMap& parents);
+NNodes::TExprBase DqOptimizeEquiJoinWithCosts(const NNodes::TExprBase& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, bool isRuleEnabled);
+
+NNodes::TExprBase DqRewriteEquiJoin(const NNodes::TExprBase& node, TExprContext& ctx);
+
NNodes::TExprBase DqEnforceCompactPartition(NNodes::TExprBase node, NNodes::TExprList frames, TExprContext& ctx);
NNodes::TExprBase DqExpandWindowFunctions(NNodes::TExprBase node, TExprContext& ctx, bool enforceCompact);
diff --git a/ydb/library/yql/dq/opt/ya.make b/ydb/library/yql/dq/opt/ya.make
index 074847336ab..52bd6534b6b 100644
--- a/ydb/library/yql/dq/opt/ya.make
+++ b/ydb/library/yql/dq/opt/ya.make
@@ -19,6 +19,7 @@ SRCS(
dq_opt_peephole.cpp
dq_opt_phy_finalizing.cpp
dq_opt_phy.cpp
+ dq_opt_join_cost_based.cpp
)
YQL_LAST_ABI_VERSION()