diff options
author | pavelvelikhov <pavelvelikhov@yandex-team.com> | 2023-08-08 13:13:25 +0300 |
---|---|---|
committer | pavelvelikhov <pavelvelikhov@yandex-team.com> | 2023-08-08 15:03:00 +0300 |
commit | 46501e2a1aa36cdac2de76ebefac31d1c00a1ff0 (patch) | |
tree | 3e5533f39a70e3d7ca5f7ead775c2a66ce90fe52 | |
parent | 94d58dee6279337ceef3aaab04b7ae2225584323 (diff) | |
download | ydb-46501e2a1aa36cdac2de76ebefac31d1c00a1ff0.tar.gz |
First version of cost based optimizer
Added feature flags and moved debug output to logger
Working version, moved stats to type ann
Updated CBO
Initial commit of CBO
32 files changed, 1632 insertions, 0 deletions
diff --git a/ydb/core/kqp/host/kqp_runner.cpp b/ydb/core/kqp/host/kqp_runner.cpp index ecbffe376e3..b46b12b476a 100644 --- a/ydb/core/kqp/host/kqp_runner.cpp +++ b/ydb/core/kqp/host/kqp_runner.cpp @@ -4,6 +4,8 @@ #include <ydb/core/kqp/query_compiler/kqp_query_compiler.h> #include <ydb/core/kqp/opt/kqp_opt.h> #include <ydb/core/kqp/opt/logical/kqp_opt_log.h> +#include <ydb/core/kqp/opt/kqp_statistics_transformer.h> + #include <ydb/core/kqp/opt/physical/kqp_opt_phy.h> #include <ydb/core/kqp/opt/peephole/kqp_opt_peephole.h> #include <ydb/core/kqp/opt/kqp_query_plan.h> @@ -89,6 +91,7 @@ public: .Add(CreateKqpCheckQueryTransformer(), "CheckKqlQuery") .AddPostTypeAnnotation(/* forSubgraph */ true) .AddCommonOptimization() + .Add(CreateKqpStatisticsTransformer(*typesCtx, Config), "Statistics") .Add(CreateKqpLogOptTransformer(OptimizeCtx, *typesCtx, Config), "LogicalOptimize") .Add(CreateLogicalDataProposalsInspector(*typesCtx), "ProvidersLogicalOptimize") .Add(CreateKqpPhyOptTransformer(OptimizeCtx, *typesCtx), "KqpPhysicalOptimize") diff --git a/ydb/core/kqp/opt/CMakeLists.darwin-x86_64.txt b/ydb/core/kqp/opt/CMakeLists.darwin-x86_64.txt index e21e3a2928e..9c3d4e474fd 100644 --- a/ydb/core/kqp/opt/CMakeLists.darwin-x86_64.txt +++ b/ydb/core/kqp/opt/CMakeLists.darwin-x86_64.txt @@ -44,6 +44,7 @@ target_sources(core-kqp-opt PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_opt_range_legacy.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_blocks_transformer.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_plan.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_statistics_transformer.cpp ) generate_enum_serilization(core-kqp-opt ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_plan.h diff --git a/ydb/core/kqp/opt/CMakeLists.linux-aarch64.txt b/ydb/core/kqp/opt/CMakeLists.linux-aarch64.txt index 2f01eb2c6d3..3cb516fc560 100644 --- a/ydb/core/kqp/opt/CMakeLists.linux-aarch64.txt +++ b/ydb/core/kqp/opt/CMakeLists.linux-aarch64.txt @@ -45,6 +45,7 @@ target_sources(core-kqp-opt PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_opt_range_legacy.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_blocks_transformer.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_plan.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_statistics_transformer.cpp ) generate_enum_serilization(core-kqp-opt ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_plan.h diff --git a/ydb/core/kqp/opt/CMakeLists.linux-x86_64.txt b/ydb/core/kqp/opt/CMakeLists.linux-x86_64.txt index 2f01eb2c6d3..3cb516fc560 100644 --- a/ydb/core/kqp/opt/CMakeLists.linux-x86_64.txt +++ b/ydb/core/kqp/opt/CMakeLists.linux-x86_64.txt @@ -45,6 +45,7 @@ target_sources(core-kqp-opt PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_opt_range_legacy.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_blocks_transformer.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_plan.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_statistics_transformer.cpp ) generate_enum_serilization(core-kqp-opt ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_plan.h diff --git a/ydb/core/kqp/opt/CMakeLists.windows-x86_64.txt b/ydb/core/kqp/opt/CMakeLists.windows-x86_64.txt index e21e3a2928e..9c3d4e474fd 100644 --- a/ydb/core/kqp/opt/CMakeLists.windows-x86_64.txt +++ b/ydb/core/kqp/opt/CMakeLists.windows-x86_64.txt @@ -44,6 +44,7 @@ target_sources(core-kqp-opt PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_opt_range_legacy.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_blocks_transformer.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_plan.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_statistics_transformer.cpp ) generate_enum_serilization(core-kqp-opt ${CMAKE_SOURCE_DIR}/ydb/core/kqp/opt/kqp_query_plan.h diff --git a/ydb/core/kqp/opt/kqp_statistics_transformer.cpp b/ydb/core/kqp/opt/kqp_statistics_transformer.cpp new file mode 100644 index 00000000000..c1dfd05a3a0 --- /dev/null +++ b/ydb/core/kqp/opt/kqp_statistics_transformer.cpp @@ -0,0 +1,157 @@ +#include "kqp_statistics_transformer.h" +#include <ydb/library/yql/utils/log/log.h> + + +using namespace NYql; +using namespace NYql::NNodes; +using namespace NKikimr::NKqp; + +namespace { + +/** + * Helper method to fetch statistics from type annotation context +*/ +std::shared_ptr<TOptimizerStatistics> GetStats( const TExprNode* input, TTypeAnnotationContext* typeCtx ) { + + return typeCtx->StatisticsMap.Value(input, std::shared_ptr<TOptimizerStatistics>(nullptr)); +} + +/** + * Helper method to set statistics in type annotation context +*/ +void SetStats( const TExprNode* input, TTypeAnnotationContext* typeCtx, std::shared_ptr<TOptimizerStatistics> stats ) { + + typeCtx->StatisticsMap[input] = stats; +} + +/** + * Helper method to get cost from type annotation context + * Doesn't check if the cost is in the mapping +*/ +std::optional<double> GetCost( const TExprNode* input, TTypeAnnotationContext* typeCtx ) { + return typeCtx->StatisticsMap[input]->Cost; +} + +/** + * Helper method to set the cost in type annotation context +*/ +void SetCost( const TExprNode* input, TTypeAnnotationContext* typeCtx, std::optional<double> cost ) { + typeCtx->StatisticsMap[input]->Cost = cost; +} +} + +/** + * For Flatmap we check the input and fetch the statistcs and cost from below + * Then we analyze the filter predicate and compute it's selectivity and apply it + * to the result. +*/ +void InferStatisticsForFlatMap(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { + + auto inputNode = TExprBase(input); + auto flatmap = inputNode.Cast<TCoFlatMap>(); + if (!IsPredicateFlatMap(flatmap.Lambda().Body().Ref())) { + return; + } + + auto flatmapInput = flatmap.Input(); + auto inputStats = GetStats(flatmapInput.Raw(), typeCtx); + + if (! inputStats ) { + return; + } + + // Selectivity is the fraction of tuples that are selected by this predicate + // Currently we just set the number to 10% before we have statistics and parse + // the predicate + double selectivity = 0.1; + + auto outputStats = TOptimizerStatistics(inputStats->Nrows * selectivity, inputStats->Ncols); + + SetStats(input.Get(), typeCtx, std::make_shared<TOptimizerStatistics>(outputStats) ); + SetCost(input.Get(), typeCtx, GetCost(flatmapInput.Raw(), typeCtx)); +} + +/** + * Infer statistics and costs for SkipNullMembers + * We don't have a good idea at this time how many nulls will be discarded, so we just return the + * input statistics. +*/ +void InferStatisticsForSkipNullMembers(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { + + auto inputNode = TExprBase(input); + auto skipNullMembers = inputNode.Cast<TCoSkipNullMembers>(); + auto skipNullMembersInput = skipNullMembers.Input(); + + auto inputStats = GetStats(skipNullMembersInput.Raw(), typeCtx); + if (!inputStats) { + return; + } + + SetStats( input.Get(), typeCtx, inputStats ); + SetCost( input.Get(), typeCtx, GetCost( skipNullMembersInput.Raw(), typeCtx ) ); +} + +/** + * Compute statistics and cost for read table + * Currently we just make up a number for the cardinality (100000) and set cost to 0 +*/ +void InferStatisticsForReadTable(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { + + YQL_CLOG(TRACE, CoreDq) << "Infer statistics for read table"; + + auto outputStats = TOptimizerStatistics(100000, 5, 0.0); + SetStats( input.Get(), typeCtx, std::make_shared<TOptimizerStatistics>(outputStats) ); +} + +/** + * Compute sstatistics for index lookup + * Currently we just make up a number for cardinality (5) and set cost to 0 +*/ +void InferStatisticsForIndexLookup(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { + + auto outputStats = TOptimizerStatistics(5, 5, 0.0); + SetStats( input.Get(), typeCtx, std::make_shared<TOptimizerStatistics>(outputStats) ); +} + +/** + * DoTransform method matches operators and callables in the query DAG and + * uses pre-computed statistics and costs of the children to compute their cost. +*/ +IGraphTransformer::TStatus TKqpStatisticsTransformer::DoTransform(TExprNode::TPtr input, + TExprNode::TPtr& output, TExprContext& ctx) { + + output = input; + if (!Config->HasOptEnableCostBasedOptimization()) { + return IGraphTransformer::TStatus::Ok; + } + + TOptimizeExprSettings settings(nullptr); + + auto ret = OptimizeExpr(input, output, [*this](const TExprNode::TPtr& input, TExprContext& ctx) { + Y_UNUSED(ctx); + auto output = input; + + if (TCoFlatMap::Match(input.Get())){ + InferStatisticsForFlatMap(input, typeCtx); + } + else if(TCoSkipNullMembers::Match(input.Get())){ + InferStatisticsForSkipNullMembers(input, typeCtx); + } + else if(TKqlReadTableBase::Match(input.Get()) || TKqlReadTableRangesBase::Match(input.Get())){ + InferStatisticsForReadTable(input, typeCtx); + } + else if(TKqlLookupTableBase::Match(input.Get()) || TKqlLookupIndexBase::Match(input.Get())){ + InferStatisticsForIndexLookup(input, typeCtx); + } + + return output; + }, ctx, settings); + + return ret; +} + +TAutoPtr<IGraphTransformer> NKikimr::NKqp::CreateKqpStatisticsTransformer(TTypeAnnotationContext& typeCtx, + const TKikimrConfiguration::TPtr& config) { + + return THolder<IGraphTransformer>(new TKqpStatisticsTransformer(typeCtx, config)); +}
\ No newline at end of file diff --git a/ydb/core/kqp/opt/kqp_statistics_transformer.h b/ydb/core/kqp/opt/kqp_statistics_transformer.h new file mode 100644 index 00000000000..78d982f9981 --- /dev/null +++ b/ydb/core/kqp/opt/kqp_statistics_transformer.h @@ -0,0 +1,45 @@ +#pragma once + +#include <ydb/library/yql/core/yql_statistics.h> + +#include <ydb/core/kqp/common/kqp_yql.h> +#include <ydb/library/yql/core/yql_graph_transformer.h> +#include <ydb/library/yql/core/yql_expr_optimize.h> +#include <ydb/library/yql/core/yql_expr_type_annotation.h> +#include <ydb/core/kqp/provider/yql_kikimr_provider_impl.h> +#include <ydb/library/yql/core/yql_opt_utils.h> + +namespace NKikimr { +namespace NKqp { + +using namespace NYql; +using namespace NYql::NNodes; + +/*** + * Statistics transformer is a transformer that propagates statistics and costs from + * the leaves of the plan DAG up to the root of the DAG. It handles a number of operators, + * but will simply stop propagation if in encounters an operator that it has no rules for. + * One of such operators is EquiJoin, but there is a special rule to handle EquiJoin. +*/ +class TKqpStatisticsTransformer : public TSyncTransformerBase { + + TTypeAnnotationContext* typeCtx; + const TKikimrConfiguration::TPtr& Config; + + public: + TKqpStatisticsTransformer(TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config) : + typeCtx(&typeCtx), Config(config) {} + + // Main method of the transformer + IGraphTransformer::TStatus DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) final; + + // Rewind currently does nothing + void Rewind() { + + } +}; + +TAutoPtr<IGraphTransformer> CreateKqpStatisticsTransformer(TTypeAnnotationContext& typeCtx, + const TKikimrConfiguration::TPtr& config); +} +} diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp index dd887361f5a..ea54f36881c 100644 --- a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp +++ b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp @@ -33,6 +33,7 @@ public: AddHandler(0, &TCoTake::Match, HNDL(RewriteTakeSortToTopSort)); AddHandler(0, &TCoFlatMap::Match, HNDL(RewriteSqlInToEquiJoin)); AddHandler(0, &TCoFlatMap::Match, HNDL(RewriteSqlInCompactToJoin)); + AddHandler(0, &TCoEquiJoin::Match, HNDL(OptimizeEquiJoinWithCosts)); AddHandler(0, &TCoEquiJoin::Match, HNDL(RewriteEquiJoin)); AddHandler(0, &TDqJoin::Match, HNDL(JoinToIndexLookup)); AddHandler(0, &TCoCalcOverWindowBase::Match, HNDL(ExpandWindowFunctions)); @@ -123,6 +124,12 @@ protected: return output; } + TMaybeNode<TExprBase> OptimizeEquiJoinWithCosts(TExprBase node, TExprContext& ctx) { + TExprBase output = DqOptimizeEquiJoinWithCosts(node, ctx, TypesCtx, Config->HasOptEnableCostBasedOptimization()); + DumpAppliedRule("OptimizeEquiJoinWithCosts", node.Ptr(), output.Ptr(), ctx); + return output; + } + TMaybeNode<TExprBase> RewriteEquiJoin(TExprBase node, TExprContext& ctx) { TExprBase output = DqRewriteEquiJoin(node, KqpCtx.Config->GetHashJoinMode(), ctx); DumpAppliedRule("RewriteEquiJoin", node.Ptr(), output.Ptr(), ctx); diff --git a/ydb/core/kqp/opt/ya.make b/ydb/core/kqp/opt/ya.make index d04d95ec68f..943fe63957c 100644 --- a/ydb/core/kqp/opt/ya.make +++ b/ydb/core/kqp/opt/ya.make @@ -12,6 +12,7 @@ SRCS( kqp_opt_range_legacy.cpp kqp_query_blocks_transformer.cpp kqp_query_plan.cpp + kqp_statistics_transformer.cpp ) PEERDIR( diff --git a/ydb/core/kqp/provider/yql_kikimr_settings.cpp b/ydb/core/kqp/provider/yql_kikimr_settings.cpp index 115c0ee1e83..6144c18bb72 100644 --- a/ydb/core/kqp/provider/yql_kikimr_settings.cpp +++ b/ydb/core/kqp/provider/yql_kikimr_settings.cpp @@ -63,6 +63,7 @@ TKikimrConfiguration::TKikimrConfiguration() { REGISTER_SETTING(*this, OptEnablePredicateExtract); REGISTER_SETTING(*this, OptEnableOlapPushdown); REGISTER_SETTING(*this, OptUseFinalizeByKey); + REGISTER_SETTING(*this, OptEnableCostBasedOptimization); /* Runtime */ REGISTER_SETTING(*this, ScanQuery); @@ -124,6 +125,11 @@ bool TKikimrSettings::HasOptUseFinalizeByKey() const { return GetOptionalFlagValue(OptUseFinalizeByKey.Get()) == EOptionalFlag::Enabled; } +bool TKikimrSettings::HasOptEnableCostBasedOptimization() const { + return GetOptionalFlagValue(OptEnableCostBasedOptimization.Get()) == EOptionalFlag::Enabled; +} + + EOptionalFlag TKikimrSettings::GetOptPredicateExtract() const { return GetOptionalFlagValue(OptEnablePredicateExtract.Get()); } diff --git a/ydb/core/kqp/provider/yql_kikimr_settings.h b/ydb/core/kqp/provider/yql_kikimr_settings.h index aabf39f8170..081d64354a3 100644 --- a/ydb/core/kqp/provider/yql_kikimr_settings.h +++ b/ydb/core/kqp/provider/yql_kikimr_settings.h @@ -56,6 +56,7 @@ struct TKikimrSettings { NCommon::TConfSetting<bool, false> OptEnablePredicateExtract; NCommon::TConfSetting<bool, false> OptEnableOlapPushdown; NCommon::TConfSetting<bool, false> OptUseFinalizeByKey; + NCommon::TConfSetting<bool, false> OptEnableCostBasedOptimization; /* Runtime */ NCommon::TConfSetting<bool, true> ScanQuery; @@ -75,6 +76,8 @@ struct TKikimrSettings { bool HasOptDisableSqlInToJoin() const; bool HasOptEnableOlapPushdown() const; bool HasOptUseFinalizeByKey() const; + bool HasOptEnableCostBasedOptimization() const; + EOptionalFlag GetOptPredicateExtract() const; EOptionalFlag GetUseLlvm() const; NDq::EHashJoinMode GetHashJoinMode() const; diff --git a/ydb/core/kqp/ut/join/CMakeLists.darwin-x86_64.txt b/ydb/core/kqp/ut/join/CMakeLists.darwin-x86_64.txt index 46214236eea..44f17858ecd 100644 --- a/ydb/core/kqp/ut/join/CMakeLists.darwin-x86_64.txt +++ b/ydb/core/kqp/ut/join/CMakeLists.darwin-x86_64.txt @@ -35,6 +35,7 @@ target_sources(ydb-core-kqp-ut-join PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_flip_join_ut.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_index_lookup_join_ut.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_join_ut.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp ) set_property( TARGET diff --git a/ydb/core/kqp/ut/join/CMakeLists.linux-aarch64.txt b/ydb/core/kqp/ut/join/CMakeLists.linux-aarch64.txt index b4bbb7b0a1f..2f8f1cfe4d2 100644 --- a/ydb/core/kqp/ut/join/CMakeLists.linux-aarch64.txt +++ b/ydb/core/kqp/ut/join/CMakeLists.linux-aarch64.txt @@ -38,6 +38,7 @@ target_sources(ydb-core-kqp-ut-join PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_flip_join_ut.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_index_lookup_join_ut.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_join_ut.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp ) set_property( TARGET diff --git a/ydb/core/kqp/ut/join/CMakeLists.linux-x86_64.txt b/ydb/core/kqp/ut/join/CMakeLists.linux-x86_64.txt index e7bde2ff331..6d8f2e29856 100644 --- a/ydb/core/kqp/ut/join/CMakeLists.linux-x86_64.txt +++ b/ydb/core/kqp/ut/join/CMakeLists.linux-x86_64.txt @@ -39,6 +39,7 @@ target_sources(ydb-core-kqp-ut-join PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_flip_join_ut.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_index_lookup_join_ut.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_join_ut.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp ) set_property( TARGET diff --git a/ydb/core/kqp/ut/join/CMakeLists.windows-x86_64.txt b/ydb/core/kqp/ut/join/CMakeLists.windows-x86_64.txt index f80baee42f2..4ea7727bf85 100644 --- a/ydb/core/kqp/ut/join/CMakeLists.windows-x86_64.txt +++ b/ydb/core/kqp/ut/join/CMakeLists.windows-x86_64.txt @@ -28,6 +28,7 @@ target_sources(ydb-core-kqp-ut-join PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_flip_join_ut.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_index_lookup_join_ut.cpp ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_join_ut.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp ) set_property( TARGET diff --git a/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp b/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp new file mode 100644 index 00000000000..3265ebd4ea1 --- /dev/null +++ b/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp @@ -0,0 +1,281 @@ +#include <ydb/core/kqp/ut/common/kqp_ut_common.h> + +#include <ydb/public/sdk/cpp/client/ydb_proto/accessor.h> + +#include <util/string/printf.h> + +namespace NKikimr { +namespace NKqp { + +using namespace NYdb; +using namespace NYdb::NTable; + +/** + * A basic join order test. We define 5 tables sharing the same + * key attribute and construct various full clique join queries +*/ +static void CreateSampleTable(TSession session) { + UNIT_ASSERT(session.ExecuteSchemeQuery(R"( + CREATE TABLE `/Root/R` ( + id Int32, + payload1 String, + PRIMARY KEY (id) + ); + )").GetValueSync().IsSuccess()); + + UNIT_ASSERT(session.ExecuteSchemeQuery(R"( + CREATE TABLE `/Root/S` ( + id Int32, + payload2 String, + PRIMARY KEY (id) + ); + )").GetValueSync().IsSuccess()); + + UNIT_ASSERT(session.ExecuteSchemeQuery(R"( + CREATE TABLE `/Root/T` ( + id Int32, + payload3 String, + PRIMARY KEY (id) + ); + )").GetValueSync().IsSuccess()); + + UNIT_ASSERT(session.ExecuteSchemeQuery(R"( + CREATE TABLE `/Root/U` ( + id Int32, + payload4 String, + PRIMARY KEY (id) + ); + )").GetValueSync().IsSuccess()); + + UNIT_ASSERT(session.ExecuteSchemeQuery(R"( + CREATE TABLE `/Root/V` ( + id Int32, + payload5 String, + PRIMARY KEY (id) + ); + )").GetValueSync().IsSuccess()); + + UNIT_ASSERT(session.ExecuteDataQuery(R"( + + REPLACE INTO `/Root/R` (id, payload1) VALUES + (1, "blah"); + + REPLACE INTO `/Root/S` (id, payload2) VALUES + (1, "blah"); + + REPLACE INTO `/Root/T` (id, payload3) VALUES + (1, "blah"); + + REPLACE INTO `/Root/U` (id, payload4) VALUES + (1, "blah"); + + REPLACE INTO `/Root/V` (id, payload5) VALUES + (1, "blah"); + )", TTxControl::BeginTx().CommitTx()).GetValueSync().IsSuccess()); +} + +static TKikimrRunner GetKikimrWithJoinSettings(){ + TVector<NKikimrKqp::TKqpSetting> settings; + NKikimrKqp::TKqpSetting setting; + setting.SetName("OptEnableCostBasedOptimization"); + setting.SetValue("true"); + settings.push_back(setting); + + return TKikimrRunner(settings); +} + + +Y_UNIT_TEST_SUITE(KqpJoinOrder) { + Y_UNIT_TEST(FiveWayJoin) { + + auto kikimr = GetKikimrWithJoinSettings(); + auto db = kikimr.GetTableClient(); + auto session = db.CreateSession().GetValueSync().GetSession(); + + CreateSampleTable(session); + + /* join with parameters */ + { + const TString query = Q_(R"( + SELECT * + FROM `/Root/R` as R + INNER JOIN + `/Root/S` as S + ON R.id = S.id + INNER JOIN + `/Root/T` as T + ON S.id = T.id + INNER JOIN + `/Root/U` as U + ON T.id = U.id + INNER JOIN + `/Root/V` as V + ON U.id = V.id + )"); + + auto result = session.ExplainDataQuery(query).ExtractValueSync(); + + UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::SUCCESS); + + NJson::TJsonValue plan; + NJson::ReadJsonTree(result.GetPlan(), &plan, true); + Cout << result.GetPlan(); + } + } + + Y_UNIT_TEST(FiveWayJoinWithPreds) { + + auto kikimr = GetKikimrWithJoinSettings(); + auto db = kikimr.GetTableClient(); + auto session = db.CreateSession().GetValueSync().GetSession(); + + CreateSampleTable(session); + + /* join with parameters */ + { + const TString query = Q_(R"( + SELECT * + FROM `/Root/R` as R + INNER JOIN + `/Root/S` as S + ON R.id = S.id + INNER JOIN + `/Root/T` as T + ON S.id = T.id + INNER JOIN + `/Root/U` as U + ON T.id = U.id + INNER JOIN + `/Root/V` as V + ON U.id = V.id + WHERE R.payload1 = 'blah' AND V.payload5 = 'blah' + )"); + + auto result = session.ExplainDataQuery(query).ExtractValueSync(); + + UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::SUCCESS); + + NJson::TJsonValue plan; + NJson::ReadJsonTree(result.GetPlan(), &plan, true); + Cout << result.GetPlan(); + } + } + + Y_UNIT_TEST(FiveWayJoinWithComplexPreds) { + + auto kikimr = GetKikimrWithJoinSettings(); + auto db = kikimr.GetTableClient(); + auto session = db.CreateSession().GetValueSync().GetSession(); + + CreateSampleTable(session); + + /* join with parameters */ + { + const TString query = Q_(R"( + SELECT * + FROM `/Root/R` as R + INNER JOIN + `/Root/S` as S + ON R.id = S.id + INNER JOIN + `/Root/T` as T + ON S.id = T.id + INNER JOIN + `/Root/U` as U + ON T.id = U.id + INNER JOIN + `/Root/V` as V + ON U.id = V.id + WHERE R.payload1 = 'blah' AND V.payload5 = 'blah' AND ( S.payload2 || T.payload3 = U.payload4 ) + )"); + + auto result = session.ExplainDataQuery(query).ExtractValueSync(); + + UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::SUCCESS); + + NJson::TJsonValue plan; + NJson::ReadJsonTree(result.GetPlan(), &plan, true); + Cout << result.GetPlan(); + } + } + + Y_UNIT_TEST(FiveWayJoinWithComplexPreds2) { + + auto kikimr = GetKikimrWithJoinSettings(); + auto db = kikimr.GetTableClient(); + auto session = db.CreateSession().GetValueSync().GetSession(); + + CreateSampleTable(session); + + /* join with parameters */ + { + const TString query = Q_(R"( + SELECT * + FROM `/Root/R` as R + INNER JOIN + `/Root/S` as S + ON R.id = S.id + INNER JOIN + `/Root/T` as T + ON S.id = T.id + INNER JOIN + `/Root/U` as U + ON T.id = U.id + INNER JOIN + `/Root/V` as V + ON U.id = V.id + WHERE (R.payload1 || V.payload5 = 'blah') AND U.payload4 = 'blah' + )"); + + auto result = session.ExplainDataQuery(query).ExtractValueSync(); + + UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::SUCCESS); + + NJson::TJsonValue plan; + NJson::ReadJsonTree(result.GetPlan(), &plan, true); + Cout << result.GetPlan(); + } + } + + Y_UNIT_TEST(FiveWayJoinWithPredsAndEquiv) { + + auto kikimr = GetKikimrWithJoinSettings(); + auto db = kikimr.GetTableClient(); + auto session = db.CreateSession().GetValueSync().GetSession(); + + CreateSampleTable(session); + + /* join with parameters */ + { + const TString query = Q_(R"( + SELECT * + FROM `/Root/R` as R + INNER JOIN + `/Root/S` as S + ON R.id = S.id + INNER JOIN + `/Root/T` as T + ON S.id = T.id + INNER JOIN + `/Root/U` as U + ON T.id = U.id + INNER JOIN + `/Root/V` as V + ON U.id = V.id + WHERE R.payload1 = 'blah' AND V.payload5 = 'blah' AND R.id = 1 + )"); + + auto result = session.ExplainDataQuery(query).ExtractValueSync(); + + UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::SUCCESS); + + NJson::TJsonValue plan; + NJson::ReadJsonTree(result.GetPlan(), &plan, true); + Cout << result.GetPlan(); + } + } +} + +} +} + diff --git a/ydb/core/kqp/ut/join/ya.make b/ydb/core/kqp/ut/join/ya.make index ee9c559a46e..9897a40d2e2 100644 --- a/ydb/core/kqp/ut/join/ya.make +++ b/ydb/core/kqp/ut/join/ya.make @@ -16,6 +16,7 @@ SRCS( kqp_flip_join_ut.cpp kqp_index_lookup_join_ut.cpp kqp_join_ut.cpp + kqp_join_order_ut.cpp ) PEERDIR( diff --git a/ydb/library/yql/core/CMakeLists.darwin-x86_64.txt b/ydb/library/yql/core/CMakeLists.darwin-x86_64.txt index e195fe28fd7..0b4b5ebc1d8 100644 --- a/ydb/library/yql/core/CMakeLists.darwin-x86_64.txt +++ b/ydb/library/yql/core/CMakeLists.darwin-x86_64.txt @@ -98,6 +98,7 @@ target_sources(library-yql-core PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_rewrite_io.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_utils.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_window.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_statistics.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_type_annotation.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_type_helpers.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_udf_index.cpp diff --git a/ydb/library/yql/core/CMakeLists.linux-aarch64.txt b/ydb/library/yql/core/CMakeLists.linux-aarch64.txt index b4486180fb5..66ef852bbe6 100644 --- a/ydb/library/yql/core/CMakeLists.linux-aarch64.txt +++ b/ydb/library/yql/core/CMakeLists.linux-aarch64.txt @@ -99,6 +99,7 @@ target_sources(library-yql-core PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_rewrite_io.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_utils.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_window.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_statistics.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_type_annotation.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_type_helpers.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_udf_index.cpp diff --git a/ydb/library/yql/core/CMakeLists.linux-x86_64.txt b/ydb/library/yql/core/CMakeLists.linux-x86_64.txt index b4486180fb5..66ef852bbe6 100644 --- a/ydb/library/yql/core/CMakeLists.linux-x86_64.txt +++ b/ydb/library/yql/core/CMakeLists.linux-x86_64.txt @@ -99,6 +99,7 @@ target_sources(library-yql-core PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_rewrite_io.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_utils.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_window.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_statistics.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_type_annotation.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_type_helpers.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_udf_index.cpp diff --git a/ydb/library/yql/core/CMakeLists.windows-x86_64.txt b/ydb/library/yql/core/CMakeLists.windows-x86_64.txt index e195fe28fd7..0b4b5ebc1d8 100644 --- a/ydb/library/yql/core/CMakeLists.windows-x86_64.txt +++ b/ydb/library/yql/core/CMakeLists.windows-x86_64.txt @@ -98,6 +98,7 @@ target_sources(library-yql-core PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_rewrite_io.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_utils.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_opt_window.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_statistics.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_type_annotation.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_type_helpers.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_udf_index.cpp diff --git a/ydb/library/yql/core/ya.make b/ydb/library/yql/core/ya.make index 1dbac92b4cf..7665acc9da3 100644 --- a/ydb/library/yql/core/ya.make +++ b/ydb/library/yql/core/ya.make @@ -37,6 +37,7 @@ SRCS( yql_opt_utils.h yql_opt_window.cpp yql_opt_window.h + yql_statistics.cpp yql_type_annotation.cpp yql_type_annotation.h yql_type_helpers.cpp diff --git a/ydb/library/yql/core/yql_statistics.cpp b/ydb/library/yql/core/yql_statistics.cpp new file mode 100644 index 00000000000..8abb9a59b5d --- /dev/null +++ b/ydb/library/yql/core/yql_statistics.cpp @@ -0,0 +1,14 @@ +#include "yql_statistics.h" + +using namespace NYql; + +std::ostream& operator<<(std::ostream& os, const TOptimizerStatistics& s) { + os << "Nrows: " << s.Nrows << ", Ncols: " << s.Ncols; + os << "Cost: "; + if (s.Cost.has_value()){ + os << s.Cost.value(); + } else { + os << "none"; + } + return os; +}
\ No newline at end of file diff --git a/ydb/library/yql/core/yql_statistics.h b/ydb/library/yql/core/yql_statistics.h new file mode 100644 index 00000000000..615dce6359b --- /dev/null +++ b/ydb/library/yql/core/yql_statistics.h @@ -0,0 +1,26 @@ +#pragma once + +#include <optional> +#include <iostream> + +namespace NYql { + +/** + * Optimizer Statistics struct records per-table and per-column statistics + * for the current operator in the plan. Currently, only Nrows and Ncols are + * recorded. + * Cost is also included in statistics, as its updated concurrently with statistics + * all of the time. Cost is optional, so it could be missing. +*/ +struct TOptimizerStatistics { + double Nrows; + int Ncols; + std::optional<double> Cost; + + TOptimizerStatistics() : Cost(std::nullopt) {} + TOptimizerStatistics(double nrows,int ncols): Nrows(nrows), Ncols(ncols), Cost(std::nullopt) {} + TOptimizerStatistics(double nrows,int ncols, double cost): Nrows(nrows), Ncols(ncols), Cost(cost) {} + + friend std::ostream& operator<<(std::ostream& os, const TOptimizerStatistics& s); +}; +} diff --git a/ydb/library/yql/core/yql_type_annotation.h b/ydb/library/yql/core/yql_type_annotation.h index f3f5e595f39..6a1670fd192 100644 --- a/ydb/library/yql/core/yql_type_annotation.h +++ b/ydb/library/yql/core/yql_type_annotation.h @@ -5,6 +5,7 @@ #include "yql_udf_resolver.h" #include "yql_user_data_storage.h" #include "yql_arrow_resolver.h" +#include "yql_statistics.h" #include <ydb/library/yql/public/udf/udf_validate.h> #include <ydb/library/yql/core/credentials/yql_credentials.h> @@ -184,6 +185,7 @@ struct TUdfCachedInfo { }; struct TTypeAnnotationContext: public TThrRefBase { + THashMap<const TExprNode*, std::shared_ptr<TOptimizerStatistics>> StatisticsMap; TIntrusivePtr<ITimeProvider> TimeProvider; TIntrusivePtr<IRandomProvider> RandomProvider; THashMap<TString, TIntrusivePtr<IDataProvider>> DataSourceMap; diff --git a/ydb/library/yql/dq/opt/CMakeLists.darwin-x86_64.txt b/ydb/library/yql/dq/opt/CMakeLists.darwin-x86_64.txt index 20e4f10070f..6ac7a6e01e7 100644 --- a/ydb/library/yql/dq/opt/CMakeLists.darwin-x86_64.txt +++ b/ydb/library/yql/dq/opt/CMakeLists.darwin-x86_64.txt @@ -31,4 +31,5 @@ target_sources(yql-dq-opt PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_peephole.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy_finalizing.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp ) diff --git a/ydb/library/yql/dq/opt/CMakeLists.linux-aarch64.txt b/ydb/library/yql/dq/opt/CMakeLists.linux-aarch64.txt index bc0a22c399e..54ab33f05f8 100644 --- a/ydb/library/yql/dq/opt/CMakeLists.linux-aarch64.txt +++ b/ydb/library/yql/dq/opt/CMakeLists.linux-aarch64.txt @@ -32,4 +32,5 @@ target_sources(yql-dq-opt PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_peephole.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy_finalizing.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp ) diff --git a/ydb/library/yql/dq/opt/CMakeLists.linux-x86_64.txt b/ydb/library/yql/dq/opt/CMakeLists.linux-x86_64.txt index bc0a22c399e..54ab33f05f8 100644 --- a/ydb/library/yql/dq/opt/CMakeLists.linux-x86_64.txt +++ b/ydb/library/yql/dq/opt/CMakeLists.linux-x86_64.txt @@ -32,4 +32,5 @@ target_sources(yql-dq-opt PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_peephole.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy_finalizing.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp ) diff --git a/ydb/library/yql/dq/opt/CMakeLists.windows-x86_64.txt b/ydb/library/yql/dq/opt/CMakeLists.windows-x86_64.txt index 20e4f10070f..6ac7a6e01e7 100644 --- a/ydb/library/yql/dq/opt/CMakeLists.windows-x86_64.txt +++ b/ydb/library/yql/dq/opt/CMakeLists.windows-x86_64.txt @@ -31,4 +31,5 @@ target_sources(yql-dq-opt PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_peephole.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy_finalizing.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_phy.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp ) diff --git a/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp b/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp new file mode 100644 index 00000000000..03d328ae202 --- /dev/null +++ b/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp @@ -0,0 +1,1064 @@ +#include "dq_opt_join.h" +#include "dq_opt_phy.h" + +#include <ydb/library/yql/core/yql_join.h> +#include <ydb/library/yql/core/yql_opt_utils.h> +#include <ydb/library/yql/dq/type_ann/dq_type_ann.h> +#include <ydb/library/yql/utils/log/log.h> +#include <ydb/library/yql/providers/common/provider/yql_provider.h> +#include <ydb/library/yql/core/yql_type_helpers.h> +#include <ydb/library/yql/core/yql_statistics.h> + +#include <library/cpp/disjoint_sets/disjoint_sets.h> + + +#include <bitset> +#include <set> +#include <unordered_map> +#include <unordered_set> +#include <vector> +#include <queue> +#include <memory> +#include <sstream> + +namespace NYql::NDq { + + +using namespace NYql::NNodes; + + +/** + * Join column is a struct that records the relation label and + * attribute name, used in join conditions +*/ +struct TJoinColumn { + TString RelName; + TString AttributeName; + + TJoinColumn(TString relName, TString attributeName) : RelName(relName), + AttributeName(attributeName) {} + + bool operator == (const TJoinColumn& other) const { + return RelName == other.RelName && AttributeName == other.AttributeName; + } + + struct HashFunction + { + size_t operator()(const TJoinColumn& c) const + { + return THash<TString>{}(c.RelName) ^ THash<TString>{}(c.AttributeName); + } + }; +}; + +bool operator < (const TJoinColumn& c1, const TJoinColumn& c2) { + if (c1.RelName < c2.RelName){ + return true; + } else if (c1.RelName == c2.RelName) { + return c1.AttributeName < c2.AttributeName; + } + return false; +} + +/** + * Edge structure records an edge in a Join graph. + * - from is the integer id of the source vertex of the graph + * - to is the integer id of the target vertex of the graph + * - joinConditions records the set of join conditions of this edge +*/ +struct TEdge { + int From; + int To; + std::set<std::pair<TJoinColumn, TJoinColumn>> JoinConditions; + + TEdge(int f, int t): From(f), To(t) {} + + TEdge(int f, int t, std::pair<TJoinColumn, TJoinColumn> cond): From(f), To(t) { + JoinConditions.insert(cond); + } + + TEdge(int f, int t, std::set<std::pair<TJoinColumn, TJoinColumn>> conds): From(f), To(t), + JoinConditions(conds) {} + + bool operator==(const TEdge& other) const + { + return From==other.From && To==other.To; + } + + struct HashFunction + { + size_t operator()(const TEdge& e) const + { + return e.From + e.To; + } + }; +}; + +/** + * Graph is a data structure for the join graph + * It is an undirected graph, with two edges per connection (from,to) and (to,from) + * It needs to be constructed with addNode and addEdge methods, since its + * keeping various indexes updated. + * The graph also needs to be reordered with the breadth-first search method, and + * the reordering is recorded in bfsMapping (original rel indexes can be recovered from + * this mapping) +*/ +template <int N> +struct TGraph { + // set of edges of the graph + std::unordered_set<TEdge,TEdge::HashFunction> Edges; + + // neightborgh index + TVector<std::bitset<N>> EdgeIdx; + + // number of nodes in a graph + int NNodes; + + // mapping from rel label to node in the graph + THashMap<TString,int> ScopeMapping; + + // mapping from node in the graph to rel label + TVector<TString> RevScopeMapping; + + // Breadth-first-search mapping + TVector<int> BfsMapping; + + // Empty graph constructor intializes indexes to size N + TGraph() : EdgeIdx(N), RevScopeMapping(N) {} + + // Add a node to a graph with a rel label + void AddNode(int nodeId, TString scope){ + NNodes = nodeId + 1; + ScopeMapping[scope] = nodeId; + RevScopeMapping[nodeId] = scope; + } + + // Add an edge to the graph, if the edge is already in the graph + // (we check both directions), no action is taken. Otherwise we + // insert two edges, the forward edge with original joinConditions + // and a reverse edge with swapped joinConditions + void AddEdge(TEdge e){ + if (Edges.contains(e) || Edges.contains(TEdge(e.To, e.From))) { + return; + } + + Edges.insert(e); + std::set<std::pair<TJoinColumn, TJoinColumn>> swappedSet; + for (auto c : e.JoinConditions){ + swappedSet.insert(std::make_pair(c.second, c.first)); + } + Edges.insert(TEdge(e.To,e.From,swappedSet)); + + EdgeIdx[e.From].set(e.To); + EdgeIdx[e.To].set(e.From); + } + + // Find a node by the rel scope + int FindNode(TString scope){ + return ScopeMapping[scope]; + } + + // Return a bitset of node's neighbors + inline std::bitset<N> FindNeighbors(int fromVertex) + { + return EdgeIdx[fromVertex]; + } + + // Return a bitset of node's neigbors with itself included + inline std::bitset<N> FindNeighborsWithSelf(int fromVertex) + { + std::bitset<N> res = FindNeighbors(fromVertex); + res.set(fromVertex); + return res; + } + + // Find an edge that connects two subsets of graph's nodes + // We are guaranteed to find a match + TEdge FindCrossingEdge(const std::bitset<N>& S1, const std::bitset<N>& S2) { + for(int i=0; i<NNodes; i++){ + if (!S1[i]) { + continue; + } + for (int j=0; j<NNodes; j++) { + if (!S2[j]) { + continue; + } + if ((FindNeighborsWithSelf(i) & FindNeighborsWithSelf(j)) != 0) { + auto it = Edges.find(TEdge(i, j)); + Y_VERIFY_DEBUG(it != Edges.end()); + return *it; + } + } + } + Y_ENSURE(false,"Connecting edge not found!"); + return TEdge(-1,-1); + } + + /** + * Create a union-set from the join conditions to record the equivalences. + * Then use the equivalence set to compute transitive closure of the graph. + * Transitive closure means that if we have an edge from (1,2) with join + * condition R.A = S.A and we have an edge from (2,3) with join condition + * S.A = T.A, we will find out that the join conditions form an equivalence set + * and add an edge (1,3) with join condition R.A = T.A. + */ + void ComputeTransitiveClosure(const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions) { + std::set<TJoinColumn> columnSet; + for (auto [ leftCondition, rightCondition ] : joinConditions) { + columnSet.insert( leftCondition ); + columnSet.insert( rightCondition ); + } + std::vector<TJoinColumn> columns; + for (auto c : columnSet ) { + columns.push_back(c); + } + + THashMap<TJoinColumn, int, TJoinColumn::HashFunction> indexMapping; + for (size_t i=0; i<columns.size(); i++) { + indexMapping[ columns[i] ] = i; + } + + TDisjointSets ds = TDisjointSets( columns.size() ); + for (auto [ leftCondition, rightCondition ] : joinConditions ) { + int leftIndex = indexMapping[ leftCondition ]; + int rightIndex = indexMapping[ rightCondition ]; + ds.UnionSets(leftIndex,rightIndex); + } + + for (size_t i=0;i<columns.size();i++) { + for (size_t j=0;j<i;j++) { + if (ds.CanonicSetElement(i) == ds.CanonicSetElement(j)) { + TJoinColumn left = columns[i]; + TJoinColumn right = columns[j]; + int leftNodeId = ScopeMapping[ left.RelName ]; + int rightNodeId = ScopeMapping[ right.RelName ]; + + if (! Edges.contains(TEdge(leftNodeId,rightNodeId)) && + ! Edges.contains(TEdge(rightNodeId,leftNodeId))) { + AddEdge(TEdge(leftNodeId,rightNodeId,std::make_pair(left, right))); + } else { + TEdge e1 = *Edges.find(TEdge(leftNodeId,rightNodeId)); + if (!e1.JoinConditions.contains(std::make_pair(left, right))) { + e1.JoinConditions.insert(std::make_pair(left, right)); + } + + TEdge e2 = *Edges.find(TEdge(rightNodeId,leftNodeId)); + if (!e2.JoinConditions.contains(std::make_pair(right, left))) { + e2.JoinConditions.insert(std::make_pair(right, left)); + } + } + } + } + } + } + + /** + * Reorder the graph by doing a breadth first search from node 0. + * This is required by the DPccp algorithm. + */ + TGraph<N> BfsReorder() { + std::set<int> visited; + std::queue<int> queue; + TVector<int> bfsMapping(NNodes); + + queue.push(0); + int lastVisited = 0; + + while(! queue.empty()) { + int curr = queue.front(); + queue.pop(); + if (visited.contains(curr)) { + continue; + } + bfsMapping[curr] = lastVisited; + lastVisited++; + visited.insert(curr); + + std::bitset<N> neighbors = FindNeighbors(curr); + for (int i=0; i<NNodes; i++ ) { + if (neighbors[i]){ + if (!visited.contains(i)) { + queue.push(i); + } + } + } + } + + TGraph<N> res; + + for (int i=0;i<NNodes;i++) { + res.AddNode(i,RevScopeMapping[bfsMapping[i]]); + } + + for (const TEdge& e : Edges){ + res.AddEdge( TEdge(bfsMapping[e.From], bfsMapping[e.To], e.JoinConditions) ); + } + + res.BfsMapping = bfsMapping; + + return res; + } + + /** + * Print the graph + */ + void PrintGraph(std::stringstream& stream) { + stream << "Join Graph:\n"; + stream << "nNodes: " << NNodes << ", nEdges: " << Edges.size() << "\n"; + + for(int i=0;i<NNodes;i++) { + stream << "Node:" << i << "," << RevScopeMapping[i] << "\n"; + } + for (const TEdge& e: Edges ) { + stream << "Edge: " << e.From << " -> " << e.To << "\n"; + for (auto p : e.JoinConditions) { + stream << p.first.RelName << "." + << p.first.AttributeName << "=" + << p.second.RelName << "." + << p.second.AttributeName << "\n"; + } + } + } +}; + +/** + * Fetch join conditions from the equi-join tree +*/ +void ComputeJoinConditions(const TCoEquiJoinTuple& joinTuple, + std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions) { + if (joinTuple.LeftScope().Maybe<TCoEquiJoinTuple>()) { + ComputeJoinConditions( joinTuple.LeftScope().Cast<TCoEquiJoinTuple>(), joinConditions ); + } + + if (joinTuple.RightScope().Maybe<TCoEquiJoinTuple>()) { + ComputeJoinConditions( joinTuple.RightScope().Cast<TCoEquiJoinTuple>(), joinConditions ); + } + + size_t joinKeysCount = joinTuple.LeftKeys().Size() / 2; + for (size_t i = 0; i < joinKeysCount; ++i) { + size_t keyIndex = i * 2; + + auto leftScope = joinTuple.LeftKeys().Item(keyIndex).StringValue(); + auto leftColumn = joinTuple.LeftKeys().Item(keyIndex + 1).StringValue(); + auto rightScope = joinTuple.RightKeys().Item(keyIndex).StringValue(); + auto rightColumn = joinTuple.RightKeys().Item(keyIndex + 1).StringValue(); + + joinConditions.insert( std::make_pair( TJoinColumn(leftScope, leftColumn), + TJoinColumn(rightScope, rightColumn))); + } +} + +/** + * OptimizerNodes are the internal representations of operators inside the + * Cost-based optimizer. Currently we only support RelOptimizerNode - a node that + * is an input relation to the equi-join, and JoinOptimizerNode - an inner join + * that connects two sets of relations. +*/ +enum EOptimizerNodeKind: ui32 +{ + RelNodeType, + JoinNodeType +}; + +/** + * BaseOptimizerNode is a base class for the internal optimizer nodes + * It records a pointer to statistics and records the current cost of the + * operator tree, rooted at this node +*/ +struct IBaseOptimizerNode { + EOptimizerNodeKind Kind; + std::shared_ptr<TOptimizerStatistics> Stats; + + IBaseOptimizerNode(EOptimizerNodeKind k) : Kind(k) {} + IBaseOptimizerNode(EOptimizerNodeKind k, std::shared_ptr<TOptimizerStatistics> s) : + Kind(k), Stats(s) {} + + virtual void Print(std::stringstream& stream, int ntabs=0)=0; +}; + +/** + * RelOptimizerNode adds a label to base class + * This is the label assinged to the input by equi-Join +*/ +struct TRelOptimizerNode : public IBaseOptimizerNode { + TString Label; + + TRelOptimizerNode(TString label, std::shared_ptr<TOptimizerStatistics> stats) : + IBaseOptimizerNode(RelNodeType, stats), Label(label) { } + virtual ~TRelOptimizerNode() {} + + virtual void Print(std::stringstream& stream, int ntabs=0) { + for (int i=0;i<ntabs;i++){ + stream << "\t"; + } + stream << "Rel: " << Label << "\n"; + + for (int i=0;i<ntabs;i++){ + stream << "\t"; + } + stream << Stats << "\n"; + } +}; + +/** + * JoinOptimizerNode records the left and right arguments of the join + * as well as the set of join conditions. + * It also has methods to compute the statistics and cost of a join, + * based on pre-computed costs and statistics of the children. +*/ +struct TJoinOptimizerNode : public IBaseOptimizerNode { + std::shared_ptr<IBaseOptimizerNode> LeftArg; + std::shared_ptr<IBaseOptimizerNode> RightArg; + std::set<std::pair<TJoinColumn, TJoinColumn>> JoinConditions; + + TJoinOptimizerNode(std::shared_ptr<IBaseOptimizerNode> left, std::shared_ptr<IBaseOptimizerNode> right, + const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions) : + IBaseOptimizerNode(JoinNodeType), LeftArg(left), RightArg(right), JoinConditions(joinConditions) {} + virtual ~TJoinOptimizerNode() {} + + /** + * Compute and set the statistics for this node. + * Currently we have a very rough calculation of statistics + */ + void ComputeStatistics() { + double newCard = 0.2 * LeftArg->Stats->Nrows * RightArg->Stats->Nrows; + int newNCols = LeftArg->Stats->Ncols + RightArg->Stats->Ncols; + Stats = std::shared_ptr<TOptimizerStatistics>(new TOptimizerStatistics(newCard,newNCols)); + } + + /** + * Compute the cost of the join based on statistics and costs of children + * Again, we only have a rought calculation at this time + */ + double ComputeCost() { + Y_ENSURE(LeftArg->Stats->Cost.has_value() && RightArg->Stats->Cost.has_value(), + "Missing values for costs in join computation"); + + return 2.0 * LeftArg->Stats->Nrows + RightArg->Stats->Nrows + + Stats->Nrows + + LeftArg->Stats->Cost.value() + RightArg->Stats->Cost.value(); + } + + /** + * Print out the join tree, rooted at this node + */ + virtual void Print(std::stringstream& stream, int ntabs=0) { + for (int i=0;i<ntabs;i++){ + stream << "\t"; + } + + stream << "Join: "; + for (auto c : JoinConditions){ + stream << c.first.RelName << "." << c.first.AttributeName + << "=" << c.second.RelName << "." + << c.second.AttributeName << ", "; + } + stream << "\n"; + + for (int i=0;i<ntabs;i++){ + stream << "\t"; + } + + stream << Stats << "\n"; + + LeftArg->Print(stream, ntabs+1); + RightArg->Print(stream, ntabs+1); + } +}; + + +/** + * Create a new join and compute its statistics and cost +*/ +std::shared_ptr<TJoinOptimizerNode> MakeJoin(std::shared_ptr<IBaseOptimizerNode> left, + std::shared_ptr<IBaseOptimizerNode> right, const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions) { + + auto res = std::shared_ptr<TJoinOptimizerNode>(new TJoinOptimizerNode(left, right, joinConditions)); + res->ComputeStatistics(); + res->Stats->Cost = res->ComputeCost(); + return res; +} + +struct pair_hash { + template <class T1, class T2> + std::size_t operator () (const std::pair<T1,T2> &p) const { + auto h1 = std::hash<T1>{}(p.first); + auto h2 = std::hash<T2>{}(p.second); + + // Mainly for demonstration purposes, i.e. works but is overly simple + // In the real world, use sth. like boost.hash_combine + return h1 ^ h2; + } +}; + +/** + * DPcpp (Dynamic Programming with connected complement pairs) is a graph-aware + * join eumeration algorithm that only considers CSGs (Connected Sub-Graphs) of + * the join graph and computes CMPs (Complement pairs) that are also connected + * subgraphs of the join graph. It enumerates CSGs in the order, such that subsets + * are enumerated first and no duplicates are ever enumerated. Then, for each emitted + * CSG it computes the complements with the same conditions - they much already be + * present in the dynamic programming table and no pair should be enumerated twice. + * + * The DPccp solver is templated by the largest number of joins we can process, this + * is in turn used by bitsets that represent sets of relations. +*/ +template <int N> +class TDPccpSolver { + public: + + // Construct the DPccp solver based on the join graph and data about input relations + TDPccpSolver(TGraph<N>& g, std::vector<std::shared_ptr<TRelOptimizerNode>> rels): + Graph(g), Rels(rels) { + NNodes = g.NNodes; + } + + // Run DPccp algorithm and produce the join tree in CBO's internal representation + std::shared_ptr<TJoinOptimizerNode> Solve(); + + private: + + // Compute the next subset of relations, given by the final bitset + std::bitset<N> NextBitset(const std::bitset<N>& current, const std::bitset<N>& final); + + // Print the set of relations in a bitset + void PrintBitset(std::stringstream& stream, const std::bitset<N>& s, std::string name, int ntabs=0); + + // Dynamic programming table that records optimal join subtrees + THashMap<std::bitset<N>, std::shared_ptr<IBaseOptimizerNode>, std::hash<std::bitset<N>>> DpTable; + + // REMOVE: Sanity check table that tracks that we don't consider the same pair twice + THashMap<std::pair<std::bitset<N>, std::bitset<N>>, bool, pair_hash> CheckTable; + + // number of nodes in a graph + int NNodes; + + // Join graph + TGraph<N>& Graph; + + // List of input relations to DPccp + std::vector<std::shared_ptr<TRelOptimizerNode>> Rels; + + // Emit connected subgraph + void EmitCsg(const std::bitset<N>&, int=0); + + // Enumerate subgraphs recursively + void EnumerateCsgRec(const std::bitset<N>&, const std::bitset<N>&,int=0); + + // Emit the final pair of CSG and CMP - compute the join and record it in the + // DP table + void EmitCsgCmp(const std::bitset<N>&, const std::bitset<N>&,int=0); + + // Enumerate complement pairs recursively + void EnumerateCmpRec(const std::bitset<N>&, const std::bitset<N>&, const std::bitset<N>&,int=0); + + // Compute the neighbors of a set of nodes, excluding the nodes in exclusion set + std::bitset<N> Neighbors(const std::bitset<N>&, const std::bitset<N>&); + + // Create an exclusion set that contains all the nodes of the graph that are smaller or equal to + // the smallest node in the provided bitset + std::bitset<N> MakeBiMin(const std::bitset<N>&); + + // Create an exclusion set that contains all the nodes of the bitset that are smaller or equal to + // the provided integer + std::bitset<N> MakeB(const std::bitset<N>&,int); +}; + +// Print tabs +void PrintTabs(std::stringstream& stream, int ntabs) { + + for (int i=0;i<ntabs;i++) + stream << "\t"; +} + +// Print a set of nodes in the graph given by this bitset +template <int N> void TDPccpSolver<N>::PrintBitset(std::stringstream& stream, + const std::bitset<N>& s, std::string name, int ntabs) { + + PrintTabs(stream, ntabs); + + stream << name << ": " << "{"; + for (int i=0;i<NNodes;i++) + if (s[i]) + stream << i << ","; + + stream <<"}\n"; +} + +// Compute neighbors of a set of nodes S, exclusing the exclusion set X +template<int N> std::bitset<N> TDPccpSolver<N>::Neighbors(const std::bitset<N>& S, const std::bitset<N>& X) { + + std::bitset<N> res; + + for (int i=0;i<Graph.NNodes;i++) { + if (S[i]) { + std::bitset<N> n = Graph.FindNeighbors(i); + res = res | n; + } + } + + res = res & ~ X; + return res; +} + +// Run the entire DPccp algorithm and compute the optimal join tree +template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve() +{ + // Process singleton sets + for (int i=NNodes-1;i>=0;i--) { + std::bitset<N> s; + s.set(i); + DpTable[s] = Rels[Graph.BfsMapping[i]]; + } + + // Expand singleton sets + for (int i=NNodes-1;i>=0;i--) { + std::bitset<N> s; + s.set(i); + EmitCsg(s); + EnumerateCsgRec(s, MakeBiMin(s)); + } + + // Return the entry of the dpTable that corresponds to the full + // set of nodes in the graph + std::bitset<N> V; + for (int i=0;i<NNodes;i++) { + V.set(i); + } + + Y_ENSURE(DpTable.contains(V), "Final relset not in dptable"); + return std::dynamic_pointer_cast<TJoinOptimizerNode>(DpTable[V]); +} + +/** + * EmitCsg emits Connected SubGraphs + * First it iterates through neighbors of the initial set S and emits pairs + * (S,S2), where S2 is the neighbor of S. Then it recursively emits complement pairs +*/ + template <int N> void TDPccpSolver<N>::EmitCsg(const std::bitset<N>& S, int ntabs) { + std::bitset<N> X = S | MakeBiMin(S); + std::bitset<N> Ns = Neighbors(S, X); + + if (Ns==std::bitset<N>()) { + return; + } + + for (int i=NNodes-1;i>=0;i--) { + if (Ns[i]) { + std::bitset<N> S2; + S2.set(i); + EmitCsgCmp(S,S2,ntabs+1); + EnumerateCmpRec(S, S2, X|MakeB(Ns,i), ntabs+1); + } + } + } + + /** + * Enumerates connected subgraphs + * First it emits CSGs that are created by adding neighbors of S to S + * Then it recurses on the S fused with its neighbors. + */ + template <int N> void TDPccpSolver<N>::EnumerateCsgRec(const std::bitset<N>& S, const std::bitset<N>& X, int ntabs) { + + std::bitset<N> Ns = Neighbors(S,X); + + if (Ns == std::bitset<N>()) { + return; + } + + std::bitset<N> prev; + std::bitset<N> next; + + while(true) { + next = NextBitset(prev,Ns); + EmitCsg(S | next ); + if (next == Ns) { + break; + } + prev = next; + } + + prev.reset(); + while(true) { + next = NextBitset(prev,Ns); + EnumerateCsgRec(S | next, X | Ns , ntabs+1); + if (next==Ns) { + break; + } + prev = next; + } + } + +/*** + * Enumerates complement pairs + * First it emits the pairs (S1,S2+next) where S2+next is the set of relation sets + * that are obtained by adding S2's neighbors to itself + * Then it recusrses into pairs (S1,S2+next) +*/ + template <int N> void TDPccpSolver<N>::EnumerateCmpRec(const std::bitset<N>& S1, + const std::bitset<N>& S2, const std::bitset<N>& X, int ntabs) { + + std::bitset<N> Ns = Neighbors(S2, X); + + if (Ns==std::bitset<N>()) { + return; + } + + std::bitset<N> prev; + std::bitset<N> next; + + while(true) { + next = NextBitset(prev, Ns); + EmitCsgCmp(S1,S2|next, ntabs+1); + if (next==Ns) { + break; + } + prev = next; + } + + prev.reset(); + while(true) { + next = NextBitset(prev, Ns); + EnumerateCmpRec(S1, S2|next, X|Ns, ntabs+1); + if (next==Ns) { + break; + } + prev = next; + } + } + +/** + * Emit a single CSG + CMP pair +*/ +template <int N> void TDPccpSolver<N>::EmitCsgCmp(const std::bitset<N>& S1, const std::bitset<N>& S2, int ntabs) { + + Y_UNUSED(ntabs); + // Here we actually build the join and choose and compare the + // new plan to what's in the dpTable, if it there + + Y_ENSURE(DpTable.contains(S1),"DP Table does not contain S1"); + Y_ENSURE(DpTable.contains(S2),"DP Table does not conaint S2"); + + std::bitset<N> joined = S1 | S2; + + if (! DpTable.contains(joined)) { + TEdge e1 = Graph.FindCrossingEdge(S1, S2); + DpTable[joined] = MakeJoin(DpTable[S1], DpTable[S2], e1.JoinConditions); + TEdge e2 = Graph.FindCrossingEdge(S2, S1); + std::shared_ptr<TJoinOptimizerNode> newJoin = + MakeJoin(DpTable[S2], DpTable[S1], e2.JoinConditions); + if (newJoin->Stats->Cost.value() < DpTable[joined]->Stats->Cost.value()){ + DpTable[joined] = newJoin; + } + } else { + TEdge e1 = Graph.FindCrossingEdge(S1, S2); + std::shared_ptr<TJoinOptimizerNode> newJoin1 = + MakeJoin(DpTable[S1], DpTable[S2], e1.JoinConditions); + TEdge e2 = Graph.FindCrossingEdge(S2, S1); + std::shared_ptr<TJoinOptimizerNode> newJoin2 = + MakeJoin(DpTable[S2], DpTable[S1], e2.JoinConditions); + if (newJoin1->Stats->Cost.value() < DpTable[joined]->Stats->Cost.value()){ + DpTable[joined] = newJoin1; + } + if (newJoin2->Stats->Cost.value() < DpTable[joined]->Stats->Cost.value()){ + DpTable[joined] = newJoin2; + } + } + + auto pair = std::make_pair(S1, S2); + Y_ENSURE (!CheckTable.contains(pair), "Check table already contains pair S1|S2"); + + CheckTable[ std::pair<std::bitset<N>,std::bitset<N>>(S1, S2) ] = true; +} + +/** + * Create an exclusion set that contains all the nodes of the graph that are smaller or equal to + * the smallest node in the provided bitset +*/ +template <int N> std::bitset<N> TDPccpSolver<N>::MakeBiMin(const std::bitset<N>& S) { + std::bitset<N> res; + + for (int i=0;i<NNodes; i++) { + if (S[i]) { + for (int j=0; j<=i; j++) { + res.set(j); + } + break; + } + } + return res; +} + +/** + * Create an exclusion set that contains all the nodes of the bitset that are smaller or equal to + * the provided integer +*/ +template <int N> std::bitset<N> TDPccpSolver<N>::MakeB(const std::bitset<N>& S, int x) { + std::bitset<N> res; + + for (int i=0; i<NNodes; i++) { + if (S[i] && i<=x) { + res.set(i); + } + } + + return res; +} + +/** + * Compute the next subset of relations, given by the final bitset +*/ +template <int N> std::bitset<N> TDPccpSolver<N>::NextBitset(const std::bitset<N>& prev, const std::bitset<N>& final) { + if (prev==final) + return final; + + std::bitset<N> res = prev; + + bool carry = true; + for (int i=0; i<NNodes; i++) + { + if (!carry) { + break; + } + + if (!final[i]) { + continue; + } + + if (res[i]==1 && carry) { + res.reset(i); + } else if (res[i]==0 && carry) + { + res.set(i); + carry = false; + } + } + + return res; + + // TODO: We can optimize this with a few long integer operations, + // but it will only work for 64 bit bitsets + // return std::bitset<N>((prev | ~final).to_ulong() + 1) & final; +} + +/** + * Build a join tree that will replace the original join tree in equiJoin + * TODO: Add join implementations here +*/ +TExprBase BuildTree(TExprContext& ctx, const TCoEquiJoin& equiJoin, + std::shared_ptr<TJoinOptimizerNode>& reorderResult) { + + // Create dummy left and right arg that will be overwritten + TExprBase leftArg(equiJoin); + TExprBase rightArg(equiJoin); + + // Build left argument of the join + if (reorderResult->LeftArg->Kind == RelNodeType) { + std::shared_ptr<TRelOptimizerNode> rel = + std::dynamic_pointer_cast<TRelOptimizerNode>(reorderResult->LeftArg); + leftArg = BuildAtom(rel->Label, equiJoin.Pos(), ctx); + } else { + std::shared_ptr<TJoinOptimizerNode> join = + std::dynamic_pointer_cast<TJoinOptimizerNode>(reorderResult->LeftArg); + leftArg = BuildTree(ctx,equiJoin,join); + } + // Build right argument of the join + if (reorderResult->RightArg->Kind == RelNodeType) { + std::shared_ptr<TRelOptimizerNode> rel = + std::dynamic_pointer_cast<TRelOptimizerNode>(reorderResult->RightArg); + rightArg = BuildAtom(rel->Label, equiJoin.Pos(), ctx); + } else { + std::shared_ptr<TJoinOptimizerNode> join = + std::dynamic_pointer_cast<TJoinOptimizerNode>(reorderResult->RightArg); + rightArg = BuildTree(ctx,equiJoin,join); + } + + TVector<TExprBase> leftJoinColumns; + TVector<TExprBase> rightJoinColumns; + + // Build join conditions + for( auto pair : reorderResult->JoinConditions) { + leftJoinColumns.push_back(BuildAtom(pair.first.RelName, equiJoin.Pos(), ctx)); + leftJoinColumns.push_back(BuildAtom(pair.first.AttributeName, equiJoin.Pos(), ctx)); + rightJoinColumns.push_back(BuildAtom(pair.second.RelName, equiJoin.Pos(), ctx)); + rightJoinColumns.push_back(BuildAtom(pair.second.AttributeName, equiJoin.Pos(), ctx)); + } + + TVector<TExprBase> options; + + // Build the final output + return Build<TCoEquiJoinTuple>(ctx,equiJoin.Pos()) + .Type(BuildAtom("Inner",equiJoin.Pos(),ctx)) + .LeftScope(leftArg) + .RightScope(rightArg) + .LeftKeys() + .Add(leftJoinColumns) + .Build() + .RightKeys() + .Add(rightJoinColumns) + .Build() + .Options() + .Add(options) + .Build() + .Done(); +} + +/** + * Rebuild the equiJoinOperator with a new tree, that was obtained by optimizing join order +*/ +TExprBase RearrangeEquiJoinTree(TExprContext& ctx, const TCoEquiJoin& equiJoin, + std::shared_ptr<TJoinOptimizerNode> reorderResult) { + TVector<TExprBase> joinArgs; + for (size_t i=0; i<equiJoin.ArgCount() - 2; i++){ + joinArgs.push_back(equiJoin.Arg(i)); + } + + joinArgs.push_back(BuildTree(ctx,equiJoin,reorderResult)); + + joinArgs.push_back(equiJoin.Arg(equiJoin.ArgCount()-1)); + + return Build<TCoEquiJoin>(ctx, equiJoin.Pos()) + .Add(joinArgs) + .Done(); +} + +/** + * Check if all joins in the equiJoin tree are Inner Joins + * FIX: This is a temporary solution, need to be able to process all types of joins in the future +*/ +bool AllInnerJoins(const TCoEquiJoinTuple& joinTuple) { + if (joinTuple.Type() != "Inner") { + return false; + } + if (joinTuple.LeftScope().Maybe<TCoEquiJoinTuple>()) { + if (! AllInnerJoins(joinTuple.LeftScope().Cast<TCoEquiJoinTuple>())) { + return false; + } + } + + if (joinTuple.RightScope().Maybe<TCoEquiJoinTuple>()) { + if (! AllInnerJoins(joinTuple.RightScope().Cast<TCoEquiJoinTuple>())) { + return false; + } + } + return true; +} + +/** + * Main routine that checks: + * 1. Do we have an equiJoin + * 2. Is the cost already computed + * 3. FIX: Are all joins InnerJoins + * 4. Are all the costs of equiJoin inputs computed? + * + * Then it extracts join conditions from the join tree, constructs a join graph and + * optimizes it with the DPccp algorithm +*/ +TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, + bool ruleEnabled) { + + if (!ruleEnabled) { + return node; + } + + if (!node.Maybe<TCoEquiJoin>()) { + return node; + } + auto equiJoin = node.Cast<TCoEquiJoin>(); + YQL_ENSURE(equiJoin.ArgCount() >= 4); + + if (typesCtx.StatisticsMap.contains(equiJoin.Raw()) && + typesCtx.StatisticsMap[ equiJoin.Raw()]->Cost.has_value()) { + + return node; + } + + if (! AllInnerJoins(equiJoin.Arg( equiJoin.ArgCount() - 2).Cast<TCoEquiJoinTuple>())) { + return node; + } + + YQL_CLOG(TRACE, CoreDq) << "Optimizing join with costs"; + + TVector<std::shared_ptr<TRelOptimizerNode>> rels; + + // Check that statistics for all inputs of equiJoin were computed + // The arguments of the EquiJoin are 1..n-2, n-2 is the actual join tree + // of the EquiJoin and n-1 argument are the parameters to EquiJoin + for (size_t i = 0; i < equiJoin.ArgCount() - 2; ++i) { + auto input = equiJoin.Arg(i).Cast<TCoEquiJoinInput>(); + auto joinArg = input.List(); + + if (!typesCtx.StatisticsMap.contains( joinArg.Raw() )) { + return node; + } + + if (!typesCtx.StatisticsMap[ joinArg.Raw() ]->Cost.has_value()) { + return node; + } + + auto scope = input.Scope(); + if (!scope.Maybe<TCoAtom>()){ + return node; + } + + auto label = scope.Cast<TCoAtom>().StringValue(); + auto stats = typesCtx.StatisticsMap[ joinArg.Raw() ]; + rels.push_back( std::shared_ptr<TRelOptimizerNode>( new TRelOptimizerNode( label, stats ))); + } + + YQL_CLOG(TRACE, CoreDq) << "All statistics for join in place"; + + std::set<std::pair<TJoinColumn, TJoinColumn>> joinConditions; + + // EquiJoin argument n-2 is the actual join tree, represented as TCoEquiJoinTuple + ComputeJoinConditions(equiJoin.Arg( equiJoin.ArgCount() - 2).Cast<TCoEquiJoinTuple>(), joinConditions ); + + // construct a graph out of join conditions + TGraph<64> joinGraph; + for (size_t i=0; i<rels.size(); i++) { + joinGraph.AddNode(i,rels[i]->Label); + } + + for (auto cond : joinConditions ) { + int fromNode = joinGraph.FindNode(cond.first.RelName); + int toNode = joinGraph.FindNode(cond.second.RelName); + joinGraph.AddEdge(TEdge(fromNode,toNode,cond)); + } + + if (NYql::NLog::YqlLogger().NeedToLog(NYql::NLog::EComponent::ProviderKqp, NYql::NLog::ELevel::TRACE)) { + std::stringstream str; + str << "Initial join graph:\n"; + joinGraph.PrintGraph(str); + YQL_CLOG(TRACE, CoreDq) << str.str(); + } + + // make a transitive closure of the graph and reorder the graph via BFS + joinGraph.ComputeTransitiveClosure(joinConditions); + + if (NYql::NLog::YqlLogger().NeedToLog(NYql::NLog::EComponent::ProviderKqp, NYql::NLog::ELevel::TRACE)) { + std::stringstream str; + str << "Join graph after transitive closure:\n"; + joinGraph.PrintGraph(str); + YQL_CLOG(TRACE, CoreDq) << str.str(); + } + + joinGraph = joinGraph.BfsReorder(); + + // feed the graph to DPccp algorithm + TDPccpSolver<64> solver(joinGraph,rels); + std::shared_ptr<TJoinOptimizerNode> result = solver.Solve(); + + if (NYql::NLog::YqlLogger().NeedToLog(NYql::NLog::EComponent::ProviderKqp, NYql::NLog::ELevel::TRACE)) { + std::stringstream str; + str << "Join tree after cost based optimization:\n"; + result->Print(str); + YQL_CLOG(TRACE, CoreDq) << str.str(); + } + + // rewrite the join tree and record the output statistics + TExprBase res = RearrangeEquiJoinTree(ctx,equiJoin,result); + typesCtx.StatisticsMap[ res.Raw() ] = result->Stats; + return res; +} + +} diff --git a/ydb/library/yql/dq/opt/dq_opt_log.h b/ydb/library/yql/dq/opt/dq_opt_log.h index 6070de70c76..6a0dd1b4d88 100644 --- a/ydb/library/yql/dq/opt/dq_opt_log.h +++ b/ydb/library/yql/dq/opt/dq_opt_log.h @@ -18,6 +18,10 @@ NNodes::TExprBase DqRewriteAggregate(NNodes::TExprBase node, TExprContext& ctx, NNodes::TExprBase DqRewriteTakeSortToTopSort(NNodes::TExprBase node, TExprContext& ctx, const TParentsMap& parents); +NNodes::TExprBase DqOptimizeEquiJoinWithCosts(const NNodes::TExprBase& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, bool isRuleEnabled); + +NNodes::TExprBase DqRewriteEquiJoin(const NNodes::TExprBase& node, TExprContext& ctx); + NNodes::TExprBase DqEnforceCompactPartition(NNodes::TExprBase node, NNodes::TExprList frames, TExprContext& ctx); NNodes::TExprBase DqExpandWindowFunctions(NNodes::TExprBase node, TExprContext& ctx, bool enforceCompact); diff --git a/ydb/library/yql/dq/opt/ya.make b/ydb/library/yql/dq/opt/ya.make index 074847336ab..52bd6534b6b 100644 --- a/ydb/library/yql/dq/opt/ya.make +++ b/ydb/library/yql/dq/opt/ya.make @@ -19,6 +19,7 @@ SRCS( dq_opt_peephole.cpp dq_opt_phy_finalizing.cpp dq_opt_phy.cpp + dq_opt_join_cost_based.cpp ) YQL_LAST_ABI_VERSION() |