aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorPavel Velikhov <pavelvelikhov@ydb.tech>2024-01-30 17:00:22 +0300
committerGitHub <noreply@github.com>2024-01-30 17:00:22 +0300
commit544c63168e883abfeab2830ba11bb8917fa6401b (patch)
tree78e8dec2e0b2bc716df6a14593f55eb43288ca5c
parent1c307df064397e34b75d3f813ff9bbb9af61ee51 (diff)
downloadydb-544c63168e883abfeab2830ba11bb8917fa6401b.tar.gz
Enabled multiple join implementations in CBO (#1396)
* Intermediate commit to save work * Enabled multiple join implementations in CBO * Addressed Alexey's comments, added opt level 2 * Fixed a TPCH11 bug, losing JoinKind * Fixed DQ CBO unit test
-rw-r--r--ydb/core/kqp/host/kqp_runner.cpp12
-rw-r--r--ydb/core/kqp/opt/kqp_query_plan.cpp3
-rw-r--r--ydb/core/kqp/opt/kqp_statistics_transformer.cpp9
-rw-r--r--ydb/core/kqp/opt/kqp_statistics_transformer.h7
-rw-r--r--ydb/core/kqp/opt/logical/kqp_opt_cbo.cpp164
-rw-r--r--ydb/core/kqp/opt/logical/kqp_opt_cbo.h37
-rw-r--r--ydb/core/kqp/opt/logical/kqp_opt_log.cpp15
-rw-r--r--ydb/core/kqp/opt/logical/kqp_opt_log.h4
-rw-r--r--ydb/core/kqp/opt/logical/kqp_opt_log_impl.h2
-rw-r--r--ydb/core/kqp/opt/logical/kqp_opt_log_join.cpp38
-rw-r--r--ydb/core/kqp/opt/logical/ya.make1
-rw-r--r--ydb/core/kqp/provider/yql_kikimr_settings.cpp6
-rw-r--r--ydb/core/kqp/provider/yql_kikimr_settings.h3
-rw-r--r--ydb/core/kqp/ut/join/kqp_join_order_ut.cpp509
-rw-r--r--ydb/library/yql/core/cbo/cbo_optimizer_new.cpp5
-rw-r--r--ydb/library/yql/core/cbo/cbo_optimizer_new.h59
-rw-r--r--ydb/library/yql/core/yql_cost_function.cpp12
-rw-r--r--ydb/library/yql/core/yql_cost_function.h13
-rw-r--r--ydb/library/yql/dq/opt/dq_cbo_ut.cpp10
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_join.h4
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp177
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp9
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_log.h14
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_stat.cpp8
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_stat.h5
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.cpp8
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.h4
-rw-r--r--ydb/library/yql/providers/dq/common/yql_dq_settings.h1
-rw-r--r--ydb/library/yql/providers/dq/provider/yql_dq_datasource.cpp2
-rw-r--r--ydb/library/yql/providers/dq/provider/yql_dq_statistics.cpp8
-rw-r--r--ydb/library/yql/providers/dq/provider/yql_dq_statistics.h3
31 files changed, 1011 insertions, 141 deletions
diff --git a/ydb/core/kqp/host/kqp_runner.cpp b/ydb/core/kqp/host/kqp_runner.cpp
index b73e6ebe86..6e0d9b7f98 100644
--- a/ydb/core/kqp/host/kqp_runner.cpp
+++ b/ydb/core/kqp/host/kqp_runner.cpp
@@ -6,6 +6,7 @@
#include <ydb/core/kqp/opt/logical/kqp_opt_log.h>
#include <ydb/core/kqp/opt/kqp_statistics_transformer.h>
#include <ydb/core/kqp/opt/kqp_constant_folding_transformer.h>
+#include <ydb/core/kqp/opt/logical/kqp_opt_cbo.h>
#include <ydb/core/kqp/opt/physical/kqp_opt_phy.h>
@@ -20,6 +21,8 @@
#include <ydb/library/yql/core/services/yql_transform_pipeline.h>
#include <ydb/library/yql/core/yql_opt_proposed_by_data.h>
+#include <ydb/library/yql/providers/dq/common/yql_dq_settings.h>
+
#include <util/generic/is_in.h>
namespace NKikimr {
@@ -145,6 +148,7 @@ public:
, OptimizeCtx(MakeIntrusive<TKqpOptimizeContext>(cluster, Config, sessionCtx->QueryPtr(),
sessionCtx->TablesPtr()))
, BuildQueryCtx(MakeIntrusive<TKqpBuildQueryContext>())
+ , Pctx(TKqpProviderContext(*OptimizeCtx, Config->CostBasedOptimizationLevel.Get().GetOrElse(TDqSettings::TDefault::CostBasedOptimizationLevel)))
{
CreateGraphTransformer(typesCtx, sessionCtx, funcRegistry);
}
@@ -266,8 +270,8 @@ private:
.AddPostTypeAnnotation(/* forSubgraph */ true)
.AddCommonOptimization()
.Add(CreateKqpConstantFoldingTransformer(OptimizeCtx, *typesCtx, Config), "ConstantFolding")
- .Add(CreateKqpStatisticsTransformer(OptimizeCtx, *typesCtx, Config), "Statistics")
- .Add(CreateKqpLogOptTransformer(OptimizeCtx, *typesCtx, Config), "LogicalOptimize")
+ .Add(CreateKqpStatisticsTransformer(OptimizeCtx, *typesCtx, Config, Pctx), "Statistics")
+ .Add(CreateKqpLogOptTransformer(OptimizeCtx, *typesCtx, Config, Pctx), "LogicalOptimize")
.Add(CreateLogicalDataProposalsInspector(*typesCtx), "ProvidersLogicalOptimize")
.Add(CreateKqpPhyOptTransformer(OptimizeCtx, *typesCtx), "KqpPhysicalOptimize")
.Add(CreatePhysicalDataProposalsInspector(*typesCtx), "ProvidersPhysicalOptimize")
@@ -300,7 +304,7 @@ private:
.AddTypeAnnotationTransformer(CreateKqpTypeAnnotationTransformer(Cluster, sessionCtx->TablesPtr(), *typesCtx, Config))
.AddPostTypeAnnotation()
.Add(CreateKqpBuildPhysicalQueryTransformer(OptimizeCtx, BuildQueryCtx), "BuildPhysicalQuery")
- .Add(CreateKqpStatisticsTransformer(OptimizeCtx, *typesCtx, Config), "Statistics")
+ .Add(CreateKqpStatisticsTransformer(OptimizeCtx, *typesCtx, Config, Pctx), "Statistics")
.Build(false);
auto physicalPeepholeTransformer = TTransformationPipeline(typesCtx)
@@ -364,6 +368,8 @@ private:
TIntrusivePtr<TKqpOptimizeContext> OptimizeCtx;
TIntrusivePtr<TKqpBuildQueryContext> BuildQueryCtx;
+ TKqpProviderContext Pctx;
+
TAutoPtr<IGraphTransformer> Transformer;
};
diff --git a/ydb/core/kqp/opt/kqp_query_plan.cpp b/ydb/core/kqp/opt/kqp_query_plan.cpp
index d2778ba244..86a1dfa0b1 100644
--- a/ydb/core/kqp/opt/kqp_query_plan.cpp
+++ b/ydb/core/kqp/opt/kqp_query_plan.cpp
@@ -14,6 +14,7 @@
#include <ydb/library/yql/dq/type_ann/dq_type_ann.h>
#include <ydb/library/yql/dq/tasks/dq_tasks_graph.h>
#include <ydb/library/yql/utils/plan/plan_utils.h>
+#include <ydb/library/yql/providers/dq/common/yql_dq_settings.h>
#include <library/cpp/json/writer/json.h>
#include <library/cpp/json/json_reader.h>
@@ -1357,7 +1358,7 @@ private:
}
void AddOptimizerEstimates(TOperator& op, const TExprBase& expr) {
- if (!SerializerCtx.Config->HasOptEnableCostBasedOptimization()) {
+ if (SerializerCtx.Config->CostBasedOptimizationLevel.Get().GetOrElse(TDqSettings::TDefault::CostBasedOptimizationLevel)==0) {
return;
}
diff --git a/ydb/core/kqp/opt/kqp_statistics_transformer.cpp b/ydb/core/kqp/opt/kqp_statistics_transformer.cpp
index a5594b43c8..ae12cba6c8 100644
--- a/ydb/core/kqp/opt/kqp_statistics_transformer.cpp
+++ b/ydb/core/kqp/opt/kqp_statistics_transformer.cpp
@@ -3,6 +3,9 @@
#include <ydb/library/yql/dq/opt/dq_opt_stat.h>
#include <ydb/library/yql/core/yql_cost_function.h>
+#include <ydb/library/yql/providers/dq/common/yql_dq_settings.h>
+
+
#include <charconv>
using namespace NYql;
@@ -187,7 +190,7 @@ IGraphTransformer::TStatus TKqpStatisticsTransformer::DoTransform(TExprNode::TPt
TExprNode::TPtr& output, TExprContext& ctx) {
output = input;
- if (!Config->HasOptEnableCostBasedOptimization()) {
+ if (Config->CostBasedOptimizationLevel.Get().GetOrElse(TDqSettings::TDefault::CostBasedOptimizationLevel) == 0) {
return IGraphTransformer::TStatus::Ok;
}
@@ -238,6 +241,6 @@ bool TKqpStatisticsTransformer::AfterLambdasSpecific(const TExprNode::TPtr& inpu
}
TAutoPtr<IGraphTransformer> NKikimr::NKqp::CreateKqpStatisticsTransformer(const TIntrusivePtr<TKqpOptimizeContext>& kqpCtx,
- TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config) {
- return THolder<IGraphTransformer>(new TKqpStatisticsTransformer(kqpCtx, typeCtx, config));
+ TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config, const TKqpProviderContext& pctx) {
+ return THolder<IGraphTransformer>(new TKqpStatisticsTransformer(kqpCtx, typeCtx, config, pctx));
}
diff --git a/ydb/core/kqp/opt/kqp_statistics_transformer.h b/ydb/core/kqp/opt/kqp_statistics_transformer.h
index 3f4d9a3a39..3c54c7ee76 100644
--- a/ydb/core/kqp/opt/kqp_statistics_transformer.h
+++ b/ydb/core/kqp/opt/kqp_statistics_transformer.h
@@ -5,6 +5,7 @@
#include <ydb/library/yql/core/yql_statistics.h>
#include <ydb/core/kqp/common/kqp_yql.h>
+#include <ydb/core/kqp/opt/logical/kqp_opt_cbo.h>
#include <ydb/library/yql/core/yql_graph_transformer.h>
#include <ydb/library/yql/core/yql_expr_optimize.h>
#include <ydb/library/yql/core/yql_expr_type_annotation.h>
@@ -33,8 +34,8 @@ class TKqpStatisticsTransformer : public NYql::NDq::TDqStatisticsTransformerBase
public:
TKqpStatisticsTransformer(const TIntrusivePtr<TKqpOptimizeContext>& kqpCtx, TTypeAnnotationContext& typeCtx,
- const TKikimrConfiguration::TPtr& config) :
- TDqStatisticsTransformerBase(&typeCtx),
+ const TKikimrConfiguration::TPtr& config, const TKqpProviderContext& pctx) :
+ TDqStatisticsTransformerBase(&typeCtx, pctx),
Config(config),
KqpCtx(*kqpCtx) {}
@@ -47,6 +48,6 @@ class TKqpStatisticsTransformer : public NYql::NDq::TDqStatisticsTransformerBase
};
TAutoPtr<IGraphTransformer> CreateKqpStatisticsTransformer(const TIntrusivePtr<TKqpOptimizeContext>& kqpCtx,
- TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config);
+ TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config, const TKqpProviderContext& pctx);
}
}
diff --git a/ydb/core/kqp/opt/logical/kqp_opt_cbo.cpp b/ydb/core/kqp/opt/logical/kqp_opt_cbo.cpp
new file mode 100644
index 0000000000..b421ba2975
--- /dev/null
+++ b/ydb/core/kqp/opt/logical/kqp_opt_cbo.cpp
@@ -0,0 +1,164 @@
+#include "kqp_opt_cbo.h"
+#include "kqp_opt_log_impl.h"
+
+#include <ydb/library/yql/core/yql_opt_utils.h>
+#include <ydb/library/yql/utils/log/log.h>
+
+
+namespace NKikimr::NKqp::NOpt {
+
+using namespace NYql;
+using namespace NYql::NCommon;
+using namespace NYql::NDq;
+using namespace NYql::NNodes;
+
+namespace {
+
+/**
+ * KQP-specific rule that checks whether a LookupJoin can be served by the relation `node`.
+ *
+ * @param node        Relation on the lookup side; must actually be a TKqpRelOptimizerNode,
+ *                    because its ExprNode is inspected.
+ * @param joinColumns Names of the columns of this relation that take part in the join condition.
+ * @param ctx         Provider context giving access to the KQP optimizer context and its config.
+ * @return true when a lookup join is considered applicable.
+ *
+ * The decision is made in this order:
+ *  1. Scan queries are rejected unless EnableKqpScanQueryStreamIdxLookupJoin is set.
+ *  2. If the first key column of the relation is one of the join columns, the join is applicable.
+ *  3. Otherwise the relation has to be a TKqlReadTable or a TKqlReadTableRangesBase read
+ *     (any FlatMap on top must be pass-through) that fixes a point prefix of the key, and the
+ *     key column that follows this prefix has to be one of the join columns.
+ */
+bool IsLookupJoinApplicableDetailed(const std::shared_ptr<NYql::TRelOptimizerNode>& node, const TVector<TString>& joinColumns, const TKqpProviderContext& ctx) {
+
+    auto rel = std::static_pointer_cast<TKqpRelOptimizerNode>(node);
+    auto expr = TExprBase(rel->Node);
+
+    if (ctx.KqpCtx.IsScanQuery() && !ctx.KqpCtx.Config->EnableKqpScanQueryStreamIdxLookupJoin) {
+        return false;
+    }
+
+    const auto& keyColumns = node->Stats->KeyColumns;
+
+    // Membership test. The search result has to be compared with end(): an iterator used
+    // directly as a condition does not tell whether the element was found.
+    const auto isJoinColumn = [&joinColumns](const TString& column) {
+        return std::find(joinColumns.begin(), joinColumns.end(), column) != joinColumns.end();
+    };
+
+    // Guard the index: a relation without key columns cannot match on its first key column.
+    if (!keyColumns.empty() && isJoinColumn(keyColumns[0])) {
+        return true;
+    }
+
+    // Number of leading key columns already fixed to a point by the read itself.
+    size_t prefixSize = 0;
+
+    auto readMatch = MatchRead<TKqlReadTable>(expr);
+
+    if (readMatch) {
+        if (readMatch->FlatMap && !IsPassthroughFlatMap(readMatch->FlatMap.Cast(), nullptr)) {
+            return false;
+        }
+
+        auto read = readMatch->Read.Cast<TKqlReadTable>();
+        TMaybeNode<TKqlKeyInc> maybeTablePrefix = GetRightTableKeyPrefix(read.Range());
+
+        if (!maybeTablePrefix) {
+            return false;
+        }
+
+        prefixSize = maybeTablePrefix.Cast().ArgCount();
+
+        if (!prefixSize) {
+            return true;
+        }
+    } else {
+        readMatch = MatchRead<TKqlReadTableRangesBase>(expr);
+        if (!readMatch) {
+            // Neither kind of table read was recognised.
+            return false;
+        }
+
+        if (readMatch->FlatMap && !IsPassthroughFlatMap(readMatch->FlatMap.Cast(), nullptr)) {
+            return false;
+        }
+
+        auto read = readMatch->Read.Cast<TKqlReadTableRangesBase>();
+        if (TCoVoid::Match(read.Ranges().Raw())) {
+            // Ranges is Void: no range restriction is attached to the read.
+            return true;
+        }
+
+        auto prompt = TKqpReadTableExplainPrompt::Parse(read);
+
+        // Every used key column has to be part of the point prefix ...
+        if (prompt.PointPrefixLen != prompt.UsedKeyColumns.size()) {
+            return false;
+        }
+
+        // ... and exactly one range is expected.
+        if (prompt.ExpectedMaxRanges != TMaybe<ui64>(1)) {
+            return false;
+        }
+
+        prefixSize = prompt.PointPrefixLen;
+    }
+
+    // The key column right after the fixed prefix (if there is one) must be joined on.
+    if (prefixSize < keyColumns.size() && !isJoinColumn(keyColumns[prefixSize])) {
+        return false;
+    }
+
+    return true;
+}
+
+/**
+ * Checks the preconditions for a LookupJoin into `right` and, if they hold, defers to
+ * IsLookupJoinApplicableDetailed.
+ *
+ * @param left           Left input of the join; not inspected.
+ * @param right          Right input; has to carry BaseTable statistics.
+ * @param joinConditions Pairs of (left column, right column) that are joined on.
+ * @param ctx            KQP provider context.
+ * @return false when the right side is not a base table, when there are more join conditions
+ *         than key columns, or when some right join column is not a key column; otherwise
+ *         the result of the detailed check.
+ */
+bool IsLookupJoinApplicable(std::shared_ptr<IBaseOptimizerNode> left,
+    std::shared_ptr<IBaseOptimizerNode> right,
+    const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions,
+    TKqpProviderContext& ctx) {
+
+    Y_UNUSED(left);
+
+    const auto& rightStats = right->Stats;
+
+    if (rightStats->Type != EStatisticsType::BaseTable) {
+        return false;
+    }
+    if (joinConditions.size() > rightStats->KeyColumns.size()) {
+        return false;
+    }
+
+    const auto& keyColumns = rightStats->KeyColumns;
+
+    TVector<TString> joinKeys;
+    joinKeys.reserve(joinConditions.size());
+
+    // Iterate by reference (no pair copies) and collect the right-side column names while
+    // verifying each of them is a key column.
+    for (const auto& condition : joinConditions) {
+        const auto& rightAttr = condition.second.AttributeName;
+
+        // The search result must be compared with end(); a bare iterator is not a "found" flag.
+        const auto found = std::find_if(keyColumns.begin(), keyColumns.end(),
+            [&rightAttr](const TString& s) {
+                return rightAttr == s;
+            });
+        if (found == keyColumns.end()) {
+            return false;
+        }
+
+        joinKeys.emplace_back(rightAttr);
+    }
+
+    return IsLookupJoinApplicableDetailed(std::static_pointer_cast<TRelOptimizerNode>(right), joinKeys, ctx);
+}
+
+}
+
+/**
+ * Decides whether the join algorithm `joinAlgo` may be considered for joining `left` with `right`.
+ *
+ *  - LookupJoin: at optimization level 2 it is rejected when the left input is estimated at
+ *    more than 10e3 (= 10 000) rows; otherwise the decision is delegated to IsLookupJoinApplicable.
+ *  - DictJoin:   the right input must be estimated below 10e5 (= 1 000 000) rows.
+ *  - MapJoin:    the right input must be estimated below 10e6 (= 10 000 000) rows.
+ *  - GraceJoin:  always applicable.
+ *
+ * NOTE(review): the literals are written as 10eN, i.e. ten times 10^N; confirm that these are
+ * the intended thresholds and not 1eN.
+ *
+ * @return true when the algorithm is applicable; false for any algorithm value that is not
+ *         handled by the switch.
+ */
+bool TKqpProviderContext::IsJoinApplicable(const std::shared_ptr<IBaseOptimizerNode>& left,
+    const std::shared_ptr<IBaseOptimizerNode>& right,
+    const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
+    EJoinAlgoType joinAlgo) {
+
+    switch (joinAlgo) {
+        case EJoinAlgoType::LookupJoin:
+            if (OptLevel == 2 && left->Stats->Nrows > 10e3) {
+                return false;
+            }
+            return IsLookupJoinApplicable(left, right, joinConditions, *this);
+
+        case EJoinAlgoType::DictJoin:
+            return right->Stats->Nrows < 10e5;
+        case EJoinAlgoType::MapJoin:
+            return right->Stats->Nrows < 10e6;
+        case EJoinAlgoType::GraceJoin:
+            return true;
+    }
+
+    // Reached only for a value outside the cases above. Returning explicitly avoids flowing
+    // off the end of a non-void function; an unknown algorithm is treated as not applicable.
+    return false;
+}
+
+/**
+ * Estimates the cost of joining two inputs with the given algorithm from their row counts.
+ *
+ *  - LookupJoin: -1 at optimization level 1 (a negative cost is lower than any cost produced
+ *    by the other formulas), otherwise the number of left rows.
+ *  - DictJoin:   left rows + 1.7 * right rows.
+ *  - MapJoin:    left rows + 1.8 * right rows.
+ *  - GraceJoin:  left rows + 2.0 * right rows.
+ *
+ * @param leftStats  Statistics of the left input (only Nrows is read).
+ * @param rightStats Statistics of the right input (only Nrows is read).
+ * @param joinAlgo   Join algorithm to cost.
+ * @return the estimated cost; a value not handled by the switch is costed like GraceJoin.
+ */
+double TKqpProviderContext::ComputeJoinCost(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, EJoinAlgoType joinAlgo) const {
+
+    switch (joinAlgo) {
+        case EJoinAlgoType::LookupJoin:
+            if (OptLevel == 1) {
+                return -1;
+            }
+            return leftStats.Nrows;
+        case EJoinAlgoType::DictJoin:
+            return leftStats.Nrows + 1.7 * rightStats.Nrows;
+        case EJoinAlgoType::MapJoin:
+            return leftStats.Nrows + 1.8 * rightStats.Nrows;
+        case EJoinAlgoType::GraceJoin:
+            break;
+    }
+
+    // GraceJoin. The return sits after the switch so that a value outside the cases above
+    // also gets a defined result instead of flowing off the end of a non-void function.
+    return leftStats.Nrows + 2.0 * rightStats.Nrows;
+}
+
+
+} \ No newline at end of file
diff --git a/ydb/core/kqp/opt/logical/kqp_opt_cbo.h b/ydb/core/kqp/opt/logical/kqp_opt_cbo.h
new file mode 100644
index 0000000000..13b6b0200e
--- /dev/null
+++ b/ydb/core/kqp/opt/logical/kqp_opt_cbo.h
@@ -0,0 +1,37 @@
+#pragma once
+
+#include <ydb/library/yql/ast/yql_expr.h>
+#include <ydb/library/yql/core/cbo/cbo_optimizer_new.h>
+
+#include <ydb/core/kqp/opt/kqp_opt.h>
+
+namespace NKikimr::NKqp::NOpt {
+
+/**
+ * Relation node used by KQP: a TRelOptimizerNode that additionally keeps
+ * the ExprNode it was built from.
+ */
+struct TKqpRelOptimizerNode : NYql::TRelOptimizerNode {
+    TKqpRelOptimizerNode(TString label, std::shared_ptr<NYql::TOptimizerStatistics> stats, const NYql::TExprNode::TPtr node)
+        : NYql::TRelOptimizerNode(label, stats)
+        , Node(node)
+    {
+    }
+
+    // Expression this relation was created for; never reassigned after construction.
+    const NYql::TExprNode::TPtr Node;
+};
+
+/**
+ * KQP implementation of IProviderContext: supplies the join cost function and the
+ * join-algorithm applicability check.
+ */
+struct TKqpProviderContext : NYql::IProviderContext {
+    TKqpProviderContext(const TKqpOptimizeContext& kqpCtx, const int optLevel)
+        : KqpCtx(kqpCtx)
+        , OptLevel(optLevel)
+    {
+    }
+
+    // Tells whether `joinAlgo` may be used to join `left` and `right` on `joinConditions`.
+    bool IsJoinApplicable(
+        const std::shared_ptr<NYql::IBaseOptimizerNode>& left,
+        const std::shared_ptr<NYql::IBaseOptimizerNode>& right,
+        const std::set<std::pair<NYql::NDq::TJoinColumn, NYql::NDq::TJoinColumn>>& joinConditions,
+        NYql::EJoinAlgoType joinAlgo) override;
+
+    // Cost estimate for joining inputs with the given statistics using `joinAlgo`.
+    double ComputeJoinCost(
+        const NYql::TOptimizerStatistics& leftStats,
+        const NYql::TOptimizerStatistics& rightStats,
+        NYql::EJoinAlgoType joinAlgo) const override;
+
+    // Held by reference: the optimize context has to outlive this object.
+    const TKqpOptimizeContext& KqpCtx;
+    // Cost-based optimization level this context was created with.
+    int OptLevel;
+};
+
+} \ No newline at end of file
diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp
index 96f7055653..5119d7769d 100644
--- a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp
+++ b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp
@@ -1,4 +1,5 @@
#include "kqp_opt_log_rules.h"
+#include "kqp_opt_cbo.h"
#include <ydb/core/kqp/common/kqp_yql.h>
#include <ydb/core/kqp/opt/kqp_opt_impl.h>
@@ -21,11 +22,12 @@ using namespace NYql::NNodes;
class TKqpLogicalOptTransformer : public TOptimizeTransformerBase {
public:
TKqpLogicalOptTransformer(TTypeAnnotationContext& typesCtx, const TIntrusivePtr<TKqpOptimizeContext>& kqpCtx,
- const TKikimrConfiguration::TPtr& config)
+ const TKikimrConfiguration::TPtr& config, TKqpProviderContext& pctx)
: TOptimizeTransformerBase(nullptr, NYql::NLog::EComponent::ProviderKqp, {})
, TypesCtx(typesCtx)
, KqpCtx(*kqpCtx)
, Config(config)
+ , Pctx(pctx)
{
#define HNDL(name) "KqpLogical-"#name, Hndl(&TKqpLogicalOptTransformer::name)
AddHandler(0, &TCoFlatMapBase::Match, HNDL(PushPredicateToReadTable));
@@ -134,7 +136,10 @@ protected:
TMaybeNode<TExprBase> OptimizeEquiJoinWithCosts(TExprBase node, TExprContext& ctx) {
auto maxDPccpDPTableSize = Config->MaxDPccpDPTableSize.Get().GetOrElse(TDqSettings::TDefault::MaxDPccpDPTableSize);
- TExprBase output = DqOptimizeEquiJoinWithCosts(node, ctx, TypesCtx, Config->HasOptEnableCostBasedOptimization(), maxDPccpDPTableSize);
+ TExprBase output = DqOptimizeEquiJoinWithCosts(node, ctx, TypesCtx, Config->CostBasedOptimizationLevel.Get().GetOrElse(TDqSettings::TDefault::CostBasedOptimizationLevel),
+ maxDPccpDPTableSize, Pctx, [](auto& rels, auto label, auto node, auto stat) {
+ rels.emplace_back(std::make_shared<TKqpRelOptimizerNode>(TString(label), stat, node));
+ });
DumpAppliedRule("OptimizeEquiJoinWithCosts", node.Ptr(), output.Ptr(), ctx);
return output;
}
@@ -269,12 +274,14 @@ private:
TTypeAnnotationContext& TypesCtx;
const TKqpOptimizeContext& KqpCtx;
const TKikimrConfiguration::TPtr& Config;
+ TKqpProviderContext& Pctx;
};
TAutoPtr<IGraphTransformer> CreateKqpLogOptTransformer(const TIntrusivePtr<TKqpOptimizeContext>& kqpCtx,
- TTypeAnnotationContext& typesCtx, const TKikimrConfiguration::TPtr& config)
+ TTypeAnnotationContext& typesCtx, const TKikimrConfiguration::TPtr& config,
+ TKqpProviderContext& pctx)
{
- return THolder<IGraphTransformer>(new TKqpLogicalOptTransformer(typesCtx, kqpCtx, config));
+ return THolder<IGraphTransformer>(new TKqpLogicalOptTransformer(typesCtx, kqpCtx, config, pctx));
}
} // namespace NKikimr::NKqp::NOpt
diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log.h b/ydb/core/kqp/opt/logical/kqp_opt_log.h
index e833934a54..3e5e9ed879 100644
--- a/ydb/core/kqp/opt/logical/kqp_opt_log.h
+++ b/ydb/core/kqp/opt/logical/kqp_opt_log.h
@@ -1,12 +1,14 @@
#pragma once
#include <ydb/core/kqp/opt/kqp_opt.h>
+#include <ydb/core/kqp/opt/logical/kqp_opt_cbo.h>
namespace NKikimr::NKqp::NOpt {
struct TKqpOptimizeContext;
TAutoPtr<NYql::IGraphTransformer> CreateKqpLogOptTransformer(const TIntrusivePtr<TKqpOptimizeContext>& kqpCtx,
- NYql::TTypeAnnotationContext& typesCtx, const NYql::TKikimrConfiguration::TPtr& config);
+ NYql::TTypeAnnotationContext& typesCtx, const NYql::TKikimrConfiguration::TPtr& config,
+ NKikimr::NKqp::NOpt::TKqpProviderContext& pctx);
} // namespace NKikimr::NKqp::NOpt
diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log_impl.h b/ydb/core/kqp/opt/logical/kqp_opt_log_impl.h
index 49f5c74e03..5ff32edc6e 100644
--- a/ydb/core/kqp/opt/logical/kqp_opt_log_impl.h
+++ b/ydb/core/kqp/opt/logical/kqp_opt_log_impl.h
@@ -22,6 +22,8 @@ TMaybe<TKqpMatchReadResult> MatchRead(NYql::NNodes::TExprBase node) {
return MatchRead(node, [] (NYql::NNodes::TExprBase node) { return node.Maybe<TRead>().IsValid(); });
}
+NYql::NNodes::TMaybeNode<NYql::NNodes::TKqlKeyInc> GetRightTableKeyPrefix(const NYql::NNodes::TKqlKeyRange& range);
+
} // NKikimr::NKqp::NOpt
diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log_join.cpp b/ydb/core/kqp/opt/logical/kqp_opt_log_join.cpp
index a9a248a197..47affa3e19 100644
--- a/ydb/core/kqp/opt/logical/kqp_opt_log_join.cpp
+++ b/ydb/core/kqp/opt/logical/kqp_opt_log_join.cpp
@@ -167,25 +167,6 @@ TDqJoin FlipLeftSemiJoin(const TDqJoin& join, TExprContext& ctx) {
.Done();
}
-TMaybeNode<TKqlKeyInc> GetRightTableKeyPrefix(const TKqlKeyRange& range) {
- if (!range.From().Maybe<TKqlKeyInc>() || !range.To().Maybe<TKqlKeyInc>()) {
- return {};
- }
- auto rangeFrom = range.From().Cast<TKqlKeyInc>();
- auto rangeTo = range.To().Cast<TKqlKeyInc>();
-
- if (rangeFrom.ArgCount() != rangeTo.ArgCount()) {
- return {};
- }
- for (ui32 i = 0; i < rangeFrom.ArgCount(); ++i) {
- if (rangeFrom.Arg(i).Raw() != rangeTo.Arg(i).Raw()) {
- return {};
- }
- }
-
- return rangeFrom;
-}
-
TExprBase BuildLookupIndex(TExprContext& ctx, const TPositionHandle pos,
const TKqpTable& table, const TCoAtomList& columns,
const TExprBase& keysToLookup, const TVector<TCoAtom>& skipNullColumns, const TString& indexName,
@@ -859,6 +840,25 @@ TMaybeNode<TExprBase> KqpJoinToIndexLookupImpl(const TDqJoin& join, TExprContext
} // anonymous namespace
+// Returns the inclusive lower bound of `range` when the range is a point prefix, i.e. both
+// bounds are TKqlKeyInc, have the same number of arguments and every argument pair is the
+// very same node. Returns an empty TMaybeNode otherwise.
+TMaybeNode<TKqlKeyInc> GetRightTableKeyPrefix(const TKqlKeyRange& range) {
+    const bool fromIsInclusive = bool(range.From().Maybe<TKqlKeyInc>());
+    if (!fromIsInclusive || !range.To().Maybe<TKqlKeyInc>()) {
+        return {};
+    }
+
+    auto lower = range.From().Cast<TKqlKeyInc>();
+    auto upper = range.To().Cast<TKqlKeyInc>();
+
+    const auto argCount = lower.ArgCount();
+    if (argCount != upper.ArgCount()) {
+        return {};
+    }
+
+    // Bounds are compared by node identity, not by value.
+    for (ui32 idx = 0; idx < argCount; ++idx) {
+        const bool sameNode = lower.Arg(idx).Raw() == upper.Arg(idx).Raw();
+        if (!sameNode) {
+            return {};
+        }
+    }
+
+    return lower;
+}
+
+
TExprBase KqpJoinToIndexLookup(const TExprBase& node, TExprContext& ctx, const TKqpOptimizeContext& kqpCtx)
{
if ((kqpCtx.IsScanQuery() && !kqpCtx.Config->EnableKqpScanQueryStreamIdxLookupJoin) || !node.Maybe<TDqJoin>()) {
diff --git a/ydb/core/kqp/opt/logical/ya.make b/ydb/core/kqp/opt/logical/ya.make
index 8b0b7ad5ab..017e5bf87b 100644
--- a/ydb/core/kqp/opt/logical/ya.make
+++ b/ydb/core/kqp/opt/logical/ya.make
@@ -12,6 +12,7 @@ SRCS(
kqp_opt_log_sqlin.cpp
kqp_opt_log_sqlin_compact.cpp
kqp_opt_log.cpp
+ kqp_opt_cbo.cpp
)
PEERDIR(
diff --git a/ydb/core/kqp/provider/yql_kikimr_settings.cpp b/ydb/core/kqp/provider/yql_kikimr_settings.cpp
index c3a4e769f0..b1dfad359e 100644
--- a/ydb/core/kqp/provider/yql_kikimr_settings.cpp
+++ b/ydb/core/kqp/provider/yql_kikimr_settings.cpp
@@ -65,7 +65,7 @@ TKikimrConfiguration::TKikimrConfiguration() {
REGISTER_SETTING(*this, OptEnableOlapProvideComputeSharding);
REGISTER_SETTING(*this, OptUseFinalizeByKey);
- REGISTER_SETTING(*this, OptEnableCostBasedOptimization);
+ REGISTER_SETTING(*this, CostBasedOptimizationLevel);
REGISTER_SETTING(*this, OptEnableConstantFolding);
REGISTER_SETTING(*this, MaxDPccpDPTableSize);
@@ -122,10 +122,6 @@ bool TKikimrSettings::HasOptUseFinalizeByKey() const {
return GetOptionalFlagValue(OptUseFinalizeByKey.Get()) != EOptionalFlag::Disabled;
}
-bool TKikimrSettings::HasOptEnableCostBasedOptimization() const {
- return GetOptionalFlagValue(OptEnableCostBasedOptimization.Get()) == EOptionalFlag::Enabled;
-}
-
bool TKikimrSettings::HasOptEnableConstantFolding() const {
return GetOptionalFlagValue(OptEnableConstantFolding.Get()) == EOptionalFlag::Enabled;
}
diff --git a/ydb/core/kqp/provider/yql_kikimr_settings.h b/ydb/core/kqp/provider/yql_kikimr_settings.h
index fd8f091800..f6fb2beb1e 100644
--- a/ydb/core/kqp/provider/yql_kikimr_settings.h
+++ b/ydb/core/kqp/provider/yql_kikimr_settings.h
@@ -58,7 +58,7 @@ struct TKikimrSettings {
NCommon::TConfSetting<bool, false> OptEnableOlapPushdown;
NCommon::TConfSetting<bool, false> OptEnableOlapProvideComputeSharding;
NCommon::TConfSetting<bool, false> OptUseFinalizeByKey;
- NCommon::TConfSetting<bool, false> OptEnableCostBasedOptimization;
+ NCommon::TConfSetting<ui32, false> CostBasedOptimizationLevel;
NCommon::TConfSetting<bool, false> OptEnableConstantFolding;
NCommon::TConfSetting<ui32, false> MaxDPccpDPTableSize;
@@ -81,7 +81,6 @@ struct TKikimrSettings {
bool HasOptEnableOlapPushdown() const;
bool HasOptEnableOlapProvideComputeSharding() const;
bool HasOptUseFinalizeByKey() const;
- bool HasOptEnableCostBasedOptimization() const;
bool HasOptEnableConstantFolding() const;
diff --git a/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp b/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp
index f2e780e406..227ec78d61 100644
--- a/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp
+++ b/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp
@@ -17,7 +17,7 @@ using namespace NYdb::NTable;
static void CreateSampleTable(TSession session) {
UNIT_ASSERT(session.ExecuteSchemeQuery(R"(
CREATE TABLE `/Root/R` (
- id Int32,
+ id Int32 not null,
payload1 String,
ts Date,
PRIMARY KEY (id)
@@ -26,7 +26,7 @@ static void CreateSampleTable(TSession session) {
UNIT_ASSERT(session.ExecuteSchemeQuery(R"(
CREATE TABLE `/Root/S` (
- id Int32,
+ id Int32 not null,
payload2 String,
PRIMARY KEY (id)
);
@@ -34,7 +34,7 @@ static void CreateSampleTable(TSession session) {
UNIT_ASSERT(session.ExecuteSchemeQuery(R"(
CREATE TABLE `/Root/T` (
- id Int32,
+ id Int32 not null,
payload3 String,
PRIMARY KEY (id)
);
@@ -42,7 +42,7 @@ static void CreateSampleTable(TSession session) {
UNIT_ASSERT(session.ExecuteSchemeQuery(R"(
CREATE TABLE `/Root/U` (
- id Int32,
+ id Int32 not null,
payload4 String,
PRIMARY KEY (id)
);
@@ -50,7 +50,7 @@ static void CreateSampleTable(TSession session) {
UNIT_ASSERT(session.ExecuteSchemeQuery(R"(
CREATE TABLE `/Root/V` (
- id Int32,
+ id Int32 not null,
payload5 String,
PRIMARY KEY (id)
);
@@ -73,15 +73,118 @@ static void CreateSampleTable(TSession session) {
REPLACE INTO `/Root/V` (id, payload5) VALUES
(1, "blah");
)", TTxControl::BeginTx().CommitTx()).GetValueSync().IsSuccess());
+
+ UNIT_ASSERT(session.ExecuteSchemeQuery(R"(
+ CREATE TABLE `/Root/customer` (
+ c_acctbal Double,
+ c_address String,
+ c_comment String,
+ c_custkey Int32, -- Identifier
+ c_mktsegment String ,
+ c_name String ,
+ c_nationkey Int32 , -- FK to N_NATIONKEY
+ c_phone String ,
+ PRIMARY KEY (c_custkey)
+)
+;
+
+CREATE TABLE `/Root/lineitem` (
+ l_comment String ,
+ l_commitdate Date ,
+ l_discount Double , -- it should be Decimal(12, 2)
+ l_extendedprice Double , -- it should be Decimal(12, 2)
+ l_linenumber Int32 ,
+ l_linestatus String ,
+ l_orderkey Int32 , -- FK to O_ORDERKEY
+ l_partkey Int32 , -- FK to P_PARTKEY, first part of the compound FK to (PS_PARTKEY, PS_SUPPKEY) with L_SUPPKEY
+ l_quantity Double , -- it should be Decimal(12, 2)
+ l_receiptdate Date ,
+ l_returnflag String ,
+ l_shipdate Date ,
+ l_shipinstruct String ,
+ l_shipmode String ,
+ l_suppkey Int32 , -- FK to S_SUPPKEY, second part of the compound FK to (PS_PARTKEY, PS_SUPPKEY) with L_PARTKEY
+ l_tax Double , -- it should be Decimal(12, 2)
+ PRIMARY KEY (l_orderkey, l_linenumber)
+)
+;
+
+CREATE TABLE `/Root/nation` (
+ n_comment String ,
+ n_name String ,
+ n_nationkey Int32 , -- Identifier
+ n_regionkey Int32 , -- FK to R_REGIONKEY
+ PRIMARY KEY(n_nationkey)
+)
+;
+
+CREATE TABLE `/Root/orders` (
+ o_clerk String ,
+ o_comment String ,
+ o_custkey Int32 , -- FK to C_CUSTKEY
+ o_orderdate Date ,
+ o_orderkey Int32 , -- Identifier
+ o_orderpriority String ,
+ o_orderstatus String ,
+ o_shippriority Int32 ,
+ o_totalprice Double , -- it should be Decimal(12, 2)
+ PRIMARY KEY (o_orderkey)
+)
+;
+
+CREATE TABLE `/Root/part` (
+ p_brand String ,
+ p_comment String ,
+ p_container String ,
+ p_mfgr String ,
+ p_name String ,
+ p_partkey Int32 , -- Identifier
+ p_retailprice Double , -- it should be Decimal(12, 2)
+ p_size Int32 ,
+ p_type String ,
+ PRIMARY KEY(p_partkey)
+)
+;
+
+CREATE TABLE `/Root/partsupp` (
+ ps_availqty Int32 ,
+ ps_comment String ,
+ ps_partkey Int32 , -- FK to P_PARTKEY
+ ps_suppkey Int32 , -- FK to S_SUPPKEY
+ ps_supplycost Double , -- it should be Decimal(12, 2)
+ PRIMARY KEY(ps_partkey, ps_suppkey)
+)
+;
+
+CREATE TABLE `/Root/region` (
+ r_comment String ,
+ r_name String ,
+ r_regionkey Int32 , -- Identifier
+ PRIMARY KEY(r_regionkey)
+)
+;
+
+CREATE TABLE `/Root/supplier` (
+ s_acctbal Double , -- it should be Decimal(12, 2)
+ s_address String ,
+ s_comment String ,
+ s_name String ,
+ s_nationkey Int32 , -- FK to N_NATIONKEY
+ s_phone String ,
+ s_suppkey Int32 , -- Identifier
+ PRIMARY KEY(s_suppkey)
+)
+;)").GetValueSync().IsSuccess());
+
}
static TKikimrRunner GetKikimrWithJoinSettings(){
TVector<NKikimrKqp::TKqpSetting> settings;
NKikimrKqp::TKqpSetting setting;
-
- setting.SetName("OptEnableCostBasedOptimization");
- setting.SetValue("true");
+
+ setting.SetName("CostBasedOptimizationLevel");
+ setting.SetValue("1");
settings.push_back(setting);
setting.SetName("OptEnableConstantFolding");
@@ -458,8 +561,394 @@ Y_UNIT_TEST_SUITE(KqpJoinOrder) {
Cout << result.GetPlan();
}
}
-}
+ // Runs EXPLAIN on a TPC-H Q21 style query (several joins, including LEFT SEMI and LEFT ONLY)
+ // against the tables created by CreateSampleTable, with the settings from
+ // GetKikimrWithJoinSettings. The test only requires the explain call to succeed and prints
+ // the plan; the shape of the plan is not asserted.
+ Y_UNIT_TEST(TPCH21) {
+
+ auto kikimr = GetKikimrWithJoinSettings();
+ auto db = kikimr.GetTableClient();
+ auto session = db.CreateSession().GetValueSync().GetSession();
+
+ CreateSampleTable(session);
+
+ /* join with parameters */
+ {
+ const TString query = Q_(R"(
+-- TPC-H/TPC-R Suppliers Who Kept Orders Waiting Query (Q21)
+-- TPC TPC-H Parameter Substitution (Version 2.17.2 build 0)
+-- using 1680793381 as a seed to the RNG
+
+$n = select n_nationkey from `/Root/nation`
+where n_name = 'EGYPT';
+
+$s = select s_name, s_suppkey from `/Root/supplier` as supplier
+join $n as nation
+on supplier.s_nationkey = nation.n_nationkey;
+
+$l = select l_suppkey, l_orderkey from `/Root/lineitem`
+where l_receiptdate > l_commitdate;
+
+$j1 = select s_name, l_suppkey, l_orderkey from $l as l1
+join $s as supplier
+on l1.l_suppkey = supplier.s_suppkey;
+
+-- exists
+$j2 = select l1.l_orderkey as l_orderkey, l1.l_suppkey as l_suppkey, l1.s_name as s_name, l2.l_receiptdate as l_receiptdate, l2.l_commitdate as l_commitdate from $j1 as l1
+join `/Root/lineitem` as l2
+on l1.l_orderkey = l2.l_orderkey
+where l2.l_suppkey <> l1.l_suppkey;
+
+$j2_1 = select s_name, l1.l_suppkey as l_suppkey, l1.l_orderkey as l_orderkey from $j1 as l1
+left semi join $j2 as l2
+on l1.l_orderkey = l2.l_orderkey;
+
+-- not exists
+$j2_2 = select l_orderkey from $j2 where l_receiptdate > l_commitdate;
+
+$j3 = select s_name, l_suppkey, l_orderkey from $j2_1 as l1
+left only join $j2_2 as l3
+on l1.l_orderkey = l3.l_orderkey;
+
+$j4 = select s_name from $j3 as l1
+join `/Root/orders` as orders
+on orders.o_orderkey = l1.l_orderkey
+where o_orderstatus = 'F';
+
+select s_name,
+ count(*) as numwait from $j4
+group by
+ s_name
+order by
+ numwait desc,
+ s_name
+limit 100;)");
+
+ // Only the plan is requested; the query itself is not executed here.
+ auto result = session.ExplainDataQuery(query).ExtractValueSync();
+
+ UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::SUCCESS);
+
+ // The plan text is parsed into a JSON tree and echoed to the test output; the parsed
+ // tree is not inspected any further.
+ NJson::TJsonValue plan;
+ NJson::ReadJsonTree(result.GetPlan(), &plan, true);
+ Cout << result.GetPlan();
+ }
+ }
+
+ // Runs EXPLAIN on a TPC-H Q5 style query written as a chain of five joins
+ // (customer, orders, lineitem, supplier, nation, region) against the tables created by
+ // CreateSampleTable, with the settings from GetKikimrWithJoinSettings. The test only
+ // requires the explain call to succeed and prints the plan; the shape of the plan is not
+ // asserted.
+ Y_UNIT_TEST(TPCH5) {
+
+ auto kikimr = GetKikimrWithJoinSettings();
+ auto db = kikimr.GetTableClient();
+ auto session = db.CreateSession().GetValueSync().GetSession();
+
+ CreateSampleTable(session);
+
+ /* join with parameters */
+ {
+ const TString query = Q_(R"(
+-- TPC-H/TPC-R Local Supplier Volume Query (Q5)
+-- TPC TPC-H Parameter Substitution (Version 2.17.2 build 0)
+-- using 1680793381 as a seed to the RNG
+
+$join1 = (
+select
+ o.o_orderkey as o_orderkey,
+ o.o_orderdate as o_orderdate,
+ c.c_nationkey as c_nationkey
+from
+ `/Root/customer` as c
+join
+ `/Root/orders` as o
+on
+ c.c_custkey = o.o_custkey
+);
+
+$join2 = (
+select
+ j.o_orderkey as o_orderkey,
+ j.o_orderdate as o_orderdate,
+ j.c_nationkey as c_nationkey,
+ l.l_extendedprice as l_extendedprice,
+ l.l_discount as l_discount,
+ l.l_suppkey as l_suppkey
+from
+ $join1 as j
+join
+ `/Root/lineitem` as l
+on
+ l.l_orderkey = j.o_orderkey
+);
+
+$join3 = (
+select
+ j.o_orderkey as o_orderkey,
+ j.o_orderdate as o_orderdate,
+ j.c_nationkey as c_nationkey,
+ j.l_extendedprice as l_extendedprice,
+ j.l_discount as l_discount,
+ j.l_suppkey as l_suppkey,
+ s.s_nationkey as s_nationkey
+from
+ $join2 as j
+join
+ `/Root/supplier` as s
+on
+ j.l_suppkey = s.s_suppkey
+);
+$join4 = (
+select
+ j.o_orderkey as o_orderkey,
+ j.o_orderdate as o_orderdate,
+ j.c_nationkey as c_nationkey,
+ j.l_extendedprice as l_extendedprice,
+ j.l_discount as l_discount,
+ j.l_suppkey as l_suppkey,
+ j.s_nationkey as s_nationkey,
+ n.n_regionkey as n_regionkey,
+ n.n_name as n_name
+from
+ $join3 as j
+join
+ `/Root/nation` as n
+on
+ j.s_nationkey = n.n_nationkey
+ and j.c_nationkey = n.n_nationkey
+);
+$join5 = (
+select
+ j.o_orderkey as o_orderkey,
+ j.o_orderdate as o_orderdate,
+ j.c_nationkey as c_nationkey,
+ j.l_extendedprice as l_extendedprice,
+ j.l_discount as l_discount,
+ j.l_suppkey as l_suppkey,
+ j.s_nationkey as s_nationkey,
+ j.n_regionkey as n_regionkey,
+ j.n_name as n_name,
+ r.r_name as r_name
+from
+ $join4 as j
+join
+ `/Root/region` as r
+on
+ j.n_regionkey = r.r_regionkey
+);
+$border = Date('1995-01-01');
+select
+ n_name,
+ sum(l_extendedprice * (1 - l_discount)) as revenue
+from
+ $join5
+where
+ r_name = 'AFRICA'
+ and CAST(o_orderdate AS Timestamp) >= $border
+ and CAST(o_orderdate AS Timestamp) < ($border + Interval('P365D'))
+group by
+ n_name
+order by
+ revenue desc;
+
+ )");
+
+ // Only the plan is requested; the query itself is not executed here.
+ auto result = session.ExplainDataQuery(query).ExtractValueSync();
+
+ UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::SUCCESS);
+
+ // The plan text is parsed into a JSON tree and echoed to the test output; the parsed
+ // tree is not inspected any further.
+ NJson::TJsonValue plan;
+ NJson::ReadJsonTree(result.GetPlan(), &plan, true);
+ Cout << result.GetPlan();
+ }
+ }
+
+ Y_UNIT_TEST(TPCH10) {
+
+ auto kikimr = GetKikimrWithJoinSettings();
+ auto db = kikimr.GetTableClient();
+ auto session = db.CreateSession().GetValueSync().GetSession();
+
+ CreateSampleTable(session);
+
+ /* TPC-H Q10: explain the multi-join query and print the resulting plan */
+ {
+ const TString query = Q_(R"(
+
+-- TPC-H/TPC-R Returned Item Reporting Query (Q10)
+-- TPC TPC-H Parameter Substitution (Version 2.17.2 build 0)
+-- using 1680793381 as a seed to the RNG
+
+$border = Date("1993-12-01");
+$join1 = (
+select
+ c.c_custkey as c_custkey,
+ c.c_name as c_name,
+ c.c_acctbal as c_acctbal,
+ c.c_address as c_address,
+ c.c_phone as c_phone,
+ c.c_comment as c_comment,
+ c.c_nationkey as c_nationkey,
+ o.o_orderkey as o_orderkey
+from
+ `/Root/customer` as c
+join
+ `/Root/orders` as o
+on
+ c.c_custkey = o.o_custkey
+where
+ cast(o.o_orderdate as timestamp) >= $border and
+ cast(o.o_orderdate as timestamp) < ($border + Interval("P90D"))
+);
+$join2 = (
+select
+ j.c_custkey as c_custkey,
+ j.c_name as c_name,
+ j.c_acctbal as c_acctbal,
+ j.c_address as c_address,
+ j.c_phone as c_phone,
+ j.c_comment as c_comment,
+ j.c_nationkey as c_nationkey,
+ l.l_extendedprice as l_extendedprice,
+ l.l_discount as l_discount
+from
+ $join1 as j
+join
+ `/Root/lineitem` as l
+on
+ l.l_orderkey = j.o_orderkey
+where
+ l.l_returnflag = 'R'
+);
+$join3 = (
+select
+ j.c_custkey as c_custkey,
+ j.c_name as c_name,
+ j.c_acctbal as c_acctbal,
+ j.c_address as c_address,
+ j.c_phone as c_phone,
+ j.c_comment as c_comment,
+ j.c_nationkey as c_nationkey,
+ j.l_extendedprice as l_extendedprice,
+ j.l_discount as l_discount,
+ n.n_name as n_name
+from
+ $join2 as j
+join
+ `/Root/nation` as n
+on
+ n.n_nationkey = j.c_nationkey
+);
+select
+ c_custkey,
+ c_name,
+ sum(l_extendedprice * (1 - l_discount)) as revenue,
+ c_acctbal,
+ n_name,
+ c_address,
+ c_phone,
+ c_comment
+from
+ $join3
+group by
+ c_custkey,
+ c_name,
+ c_acctbal,
+ c_phone,
+ n_name,
+ c_address,
+ c_comment
+order by
+ revenue desc
+limit 20;
+ )");
+
+ auto result = session.ExplainDataQuery(query).ExtractValueSync();
+
+ UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::SUCCESS);
+
+ NJson::TJsonValue plan;
+ NJson::ReadJsonTree(result.GetPlan(), &plan, true);
+ Cout << result.GetPlan();
+ }
+ }
+
+ Y_UNIT_TEST(TPCH11) {
+
+ auto kikimr = GetKikimrWithJoinSettings();
+ auto db = kikimr.GetTableClient();
+ auto session = db.CreateSession().GetValueSync().GetSession();
+
+ CreateSampleTable(session);
+
+ /* TPC-H Q11: explain the multi-join query and print the resulting plan */
+ {
+ const TString query = Q_(R"(
+
+-- TPC-H/TPC-R Important Stock Identification Query (Q11)
+-- TPC TPC-H Parameter Substitution (Version 2.17.2 build 0)
+-- using 1680793381 as a seed to the RNG
+
+$join1 = (
+select
+ ps.ps_partkey as ps_partkey,
+ ps.ps_supplycost as ps_supplycost,
+ ps.ps_availqty as ps_availqty,
+ s.s_nationkey as s_nationkey
+from
+ `/Root/partsupp` as ps
+join
+ `/Root/supplier` as s
+on
+ ps.ps_suppkey = s.s_suppkey
+);
+$join2 = (
+select
+ j.ps_partkey as ps_partkey,
+ j.ps_supplycost as ps_supplycost,
+ j.ps_availqty as ps_availqty,
+ j.s_nationkey as s_nationkey
+from
+ $join1 as j
+join
+ `/Root/nation` as n
+on
+ n.n_nationkey = j.s_nationkey
+where
+ n.n_name = 'CANADA'
+);
+$threshold = (
+select
+ sum(ps_supplycost * ps_availqty) * 0.0001000000 as threshold
+from
+ $join2
+);
+$values = (
+select
+ ps_partkey,
+ sum(ps_supplycost * ps_availqty) as value
+from
+ $join2
+group by
+ ps_partkey
+);
+
+select
+ v.ps_partkey as ps_partkey,
+ v.value as value
+from
+ $values as v
+cross join
+ $threshold as t
+where
+ v.value > t.threshold
+order by
+ value desc;
+ )");
+
+ auto result = session.ExplainDataQuery(query).ExtractValueSync();
+
+ UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::SUCCESS);
+
+ NJson::TJsonValue plan;
+ NJson::ReadJsonTree(result.GetPlan(), &plan, true);
+ Cout << result.GetPlan();
+ }
+ }
+}
}
}
-
diff --git a/ydb/library/yql/core/cbo/cbo_optimizer_new.cpp b/ydb/library/yql/core/cbo/cbo_optimizer_new.cpp
index 2ee3d1a0d9..c7ce233d76 100644
--- a/ydb/library/yql/core/cbo/cbo_optimizer_new.cpp
+++ b/ydb/library/yql/core/cbo/cbo_optimizer_new.cpp
@@ -63,12 +63,13 @@ void TRelOptimizerNode::Print(std::stringstream& stream, int ntabs) {
}
TJoinOptimizerNode::TJoinOptimizerNode(const std::shared_ptr<IBaseOptimizerNode>& left, const std::shared_ptr<IBaseOptimizerNode>& right,
- const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions, const EJoinKind joinType, bool nonReorderable) :
+ const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions, const EJoinKind joinType, const EJoinAlgoType joinAlgo, bool nonReorderable) :
IBaseOptimizerNode(JoinNodeType),
LeftArg(left),
RightArg(right),
JoinConditions(joinConditions),
- JoinType(joinType) {
+ JoinType(joinType),
+ JoinAlgo(joinAlgo) {
IsReorderable = (JoinType==EJoinKind::InnerJoin) && (nonReorderable==false);
}
diff --git a/ydb/library/yql/core/cbo/cbo_optimizer_new.h b/ydb/library/yql/core/cbo/cbo_optimizer_new.h
index 1f35a0a231..256252241e 100644
--- a/ydb/library/yql/core/cbo/cbo_optimizer_new.h
+++ b/ydb/library/yql/core/cbo/cbo_optimizer_new.h
@@ -49,8 +49,13 @@ struct IBaseOptimizerNode {
struct TRelOptimizerNode : public IBaseOptimizerNode {
TString Label;
+ // Temporary solution to check if a LookupJoin is possible in KQP
+ //void* Expr;
+
TRelOptimizerNode(TString label, std::shared_ptr<TOptimizerStatistics> stats) :
IBaseOptimizerNode(RelNodeType, stats), Label(label) { }
+ //TRelOptimizerNode(TString label, std::shared_ptr<TOptimizerStatistics> stats, const TExprNode::TPtr expr) :
+ // IBaseOptimizerNode(RelNodeType, stats), Label(label), Expr(expr) { }
virtual ~TRelOptimizerNode() {}
virtual TVector<TString> Labels();
@@ -74,6 +79,54 @@ enum EJoinKind: ui32
EJoinKind ConvertToJoinKind(const TString& joinString);
TString ConvertToJoinString(const EJoinKind kind);
+/**
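+ * This is a temporary structure for the KQP provider: it supplies the optimizer with
+ * a join cost estimate and a check of whether a given join algorithm is applicable.
+ * We will soon be supporting multiple providers and will need to design proper interfaces to pass provider-specific context to the optimizer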
+*/
+struct IProviderContext {
+ virtual ~IProviderContext() = default;
+
+ virtual double ComputeJoinCost(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, EJoinAlgoType joinAlgol) const = 0;
+
+ virtual bool IsJoinApplicable(const std::shared_ptr<IBaseOptimizerNode>& left,
+ const std::shared_ptr<IBaseOptimizerNode>& right,
+ const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
+ EJoinAlgoType joinAlgo) = 0;
+
+};
+
+/**
+ * Temporary default provider context: treats every join algorithm as applicable and uses the same cost formula for all of them
+*/
+
+struct TDummyProviderContext : public IProviderContext {
+ TDummyProviderContext() {}
+
+ double ComputeJoinCost(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, EJoinAlgoType joinAlgo) const override {
+ Y_UNUSED(joinAlgo);
+ return leftStats.Nrows + 2.0 * rightStats.Nrows;
+ }
+
+ bool IsJoinApplicable(const std::shared_ptr<IBaseOptimizerNode>& left,
+ const std::shared_ptr<IBaseOptimizerNode>& right,
+ const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
+ EJoinAlgoType joinAlgo) override {
+
+ Y_UNUSED(left);
+ Y_UNUSED(right);
+ Y_UNUSED(joinConditions);
+ Y_UNUSED(joinAlgo);
+
+ return true;
+ }
+
+ static const TDummyProviderContext& instance() {
+ static TDummyProviderContext staticContext;
+ return staticContext;
+ }
+
+};
/**
* JoinOptimizerNode records the left and right arguments of the join
@@ -86,16 +139,20 @@ struct TJoinOptimizerNode : public IBaseOptimizerNode {
std::shared_ptr<IBaseOptimizerNode> RightArg;
std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>> JoinConditions;
EJoinKind JoinType;
+ EJoinAlgoType JoinAlgo;
bool IsReorderable;
TJoinOptimizerNode(const std::shared_ptr<IBaseOptimizerNode>& left, const std::shared_ptr<IBaseOptimizerNode>& right,
- const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions, const EJoinKind joinType, bool nonReorderable=false);
+ const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions, const EJoinKind joinType, const EJoinAlgoType joinAlgo, bool nonReorderable=false);
virtual ~TJoinOptimizerNode() {}
virtual TVector<TString> Labels();
virtual void Print(std::stringstream& stream, int ntabs=0);
};
struct IOptimizerNew {
+ IProviderContext& Pctx;
+
+ IOptimizerNew(IProviderContext& ctx) : Pctx(ctx) {}
virtual ~IOptimizerNew() = default;
virtual std::shared_ptr<TJoinOptimizerNode> JoinSearch(const std::shared_ptr<TJoinOptimizerNode>& joinTree) = 0;
};
diff --git a/ydb/library/yql/core/yql_cost_function.cpp b/ydb/library/yql/core/yql_cost_function.cpp
index 5724c91e52..dcf395ca40 100644
--- a/ydb/library/yql/core/yql_cost_function.cpp
+++ b/ydb/library/yql/core/yql_cost_function.cpp
@@ -1,5 +1,7 @@
#include "yql_cost_function.h"
+#include <ydb/library/yql/core/cbo/cbo_optimizer_new.h>
+
using namespace NYql;
namespace {
@@ -16,6 +18,7 @@ bool IsPKJoin(const TOptimizerStatistics& stats, const TVector<TString>& joinKey
}
return true;
}
+
}
bool NDq::operator < (const NDq::TJoinColumn& c1, const NDq::TJoinColumn& c2) {
@@ -36,8 +39,7 @@ bool NDq::operator < (const NDq::TJoinColumn& c1, const NDq::TJoinColumn& c2) {
*/
TOptimizerStatistics NYql::ComputeJoinStats(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats,
- const TVector<TString>& leftJoinKeys, const TVector<TString>& rightJoinKeys, EJoinImplType joinImpl) {
- Y_UNUSED(joinImpl);
+ const TVector<TString>& leftJoinKeys, const TVector<TString>& rightJoinKeys, EJoinAlgoType joinAlgo, const IProviderContext& ctx) {
double newCard;
EStatisticsType outputType;
@@ -68,7 +70,7 @@ TOptimizerStatistics NYql::ComputeJoinStats(const TOptimizerStatistics& leftStat
int newNCols = leftStats.Ncols + rightStats.Ncols;
- double cost = leftStats.Nrows + 2.0 * rightStats.Nrows
+ double cost = ctx.ComputeJoinCost(leftStats, rightStats, joinAlgo)
+ newCard
+ leftStats.Cost + rightStats.Cost;
@@ -76,7 +78,7 @@ TOptimizerStatistics NYql::ComputeJoinStats(const TOptimizerStatistics& leftStat
}
TOptimizerStatistics NYql::ComputeJoinStats(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats,
- const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions, EJoinImplType joinImpl) {
+ const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions, EJoinAlgoType joinAlgo, const IProviderContext& ctx) {
TVector<TString> leftJoinKeys;
TVector<TString> rightJoinKeys;
@@ -86,5 +88,5 @@ TOptimizerStatistics NYql::ComputeJoinStats(const TOptimizerStatistics& leftStat
rightJoinKeys.emplace_back(c.second.AttributeName);
}
- return ComputeJoinStats(leftStats, rightStats, leftJoinKeys, rightJoinKeys, joinImpl);
+ return ComputeJoinStats(leftStats, rightStats, leftJoinKeys, rightJoinKeys, joinAlgo, ctx);
}
diff --git a/ydb/library/yql/core/yql_cost_function.h b/ydb/library/yql/core/yql_cost_function.h
index ae0b16de80..77adb6d3b1 100644
--- a/ydb/library/yql/core/yql_cost_function.h
+++ b/ydb/library/yql/core/yql_cost_function.h
@@ -14,6 +14,8 @@
*/
namespace NYql {
+struct IProviderContext;
+
namespace NDq {
/**
* Join column is a struct that records the relation label and
@@ -43,16 +45,19 @@ bool operator < (const TJoinColumn& c1, const TJoinColumn& c2);
}
-enum EJoinImplType {
+enum EJoinAlgoType {
DictJoin,
MapJoin,
- GraceJoin
+ GraceJoin,
+ LookupJoin
};
+static const EJoinAlgoType AllJoinAlgos[] = { DictJoin, MapJoin, GraceJoin, LookupJoin };
+
TOptimizerStatistics ComputeJoinStats(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats,
- const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions, EJoinImplType joinType);
+ const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions, EJoinAlgoType joinAlgo, const IProviderContext& ctx);
TOptimizerStatistics ComputeJoinStats(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats,
- const TVector<TString>& leftJoinKeys, const TVector<TString>& rightJoinKeys, EJoinImplType joinType);
+ const TVector<TString>& leftJoinKeys, const TVector<TString>& rightJoinKeys, EJoinAlgoType joinAlgo, const IProviderContext& ctx);
} \ No newline at end of file
diff --git a/ydb/library/yql/dq/opt/dq_cbo_ut.cpp b/ydb/library/yql/dq/opt/dq_cbo_ut.cpp
index b93e97543f..7a8700e506 100644
--- a/ydb/library/yql/dq/opt/dq_cbo_ut.cpp
+++ b/ydb/library/yql/dq/opt/dq_cbo_ut.cpp
@@ -140,16 +140,17 @@ Y_UNIT_TEST(RelCollector) {
.Done();
TTypeAnnotationContext typeCtx;
- UNIT_ASSERT(DqCollectJoinRelationsWithStats(typeCtx, equiJoin, [&](auto, auto) {}) == false);
+ TVector<std::shared_ptr<TRelOptimizerNode>> rels;
+ UNIT_ASSERT(DqCollectJoinRelationsWithStats(rels, typeCtx, equiJoin, [&](auto, auto, auto, auto) {}) == false);
typeCtx.StatisticsMap[tables[1].Ptr()->Child(0)] = std::make_shared<TOptimizerStatistics>(1, 1, 1);
- UNIT_ASSERT(DqCollectJoinRelationsWithStats(typeCtx, equiJoin, [&](auto, auto) {}) == false);
+ UNIT_ASSERT(DqCollectJoinRelationsWithStats(rels, typeCtx, equiJoin, [&](auto, auto, auto, auto) {}) == false);
typeCtx.StatisticsMap[tables[0].Ptr()->Child(0)] = std::make_shared<TOptimizerStatistics>(1, 1, 1);
typeCtx.StatisticsMap[tables[2].Ptr()->Child(0)] = std::make_shared<TOptimizerStatistics>(1, 1, 1);
TVector<TString> labels;
- UNIT_ASSERT(DqCollectJoinRelationsWithStats(typeCtx, equiJoin, [&](auto label, auto) { labels.emplace_back(label); }) == true);
+ UNIT_ASSERT(DqCollectJoinRelationsWithStats(rels, typeCtx, equiJoin, [&](auto, auto label, auto, auto) { labels.emplace_back(label); }) == true);
UNIT_ASSERT(labels.size() == 3);
UNIT_ASSERT_STRINGS_EQUAL(labels[0], "orders");
UNIT_ASSERT_STRINGS_EQUAL(labels[1], "customer");
@@ -167,7 +168,8 @@ Y_UNIT_TEST(RelCollectorBrokenEquiJoin) {
.Done();
TTypeAnnotationContext typeCtx;
- UNIT_ASSERT(DqCollectJoinRelationsWithStats(typeCtx, equiJoin, [&](auto, auto) {}) == false);
+ TVector<std::shared_ptr<TRelOptimizerNode>> rels;
+ UNIT_ASSERT(DqCollectJoinRelationsWithStats(rels, typeCtx, equiJoin, [&](auto, auto, auto, auto) {}) == false);
}
void _DqOptimizeEquiJoinWithCosts(const std::function<IOptimizer*(IOptimizer::TInput&&)>& optFactory) {
diff --git a/ydb/library/yql/dq/opt/dq_opt_join.h b/ydb/library/yql/dq/opt/dq_opt_join.h
index 8d83b2de64..9b9c071345 100644
--- a/ydb/library/yql/dq/opt/dq_opt_join.h
+++ b/ydb/library/yql/dq/opt/dq_opt_join.h
@@ -8,6 +8,7 @@
namespace NYql {
struct TOptimizerStatistics;
+struct TRelOptimizerNode;
namespace NDq {
@@ -25,9 +26,10 @@ NNodes::TExprBase DqBuildJoinDict(const NNodes::TDqJoin& join, TExprContext& ctx
NNodes::TDqJoin DqSuppressSortOnJoinInput(const NNodes::TDqJoin& node, TExprContext& ctx);
bool DqCollectJoinRelationsWithStats(
+ TVector<std::shared_ptr<TRelOptimizerNode>>& rels,
TTypeAnnotationContext& typesCtx,
const NNodes::TCoEquiJoin& equiJoin,
- const std::function<void(TStringBuf, const std::shared_ptr<TOptimizerStatistics>&)>& collector);
+ const std::function<void(TVector<std::shared_ptr<TRelOptimizerNode>>&, TStringBuf, const TExprNode::TPtr, const std::shared_ptr<TOptimizerStatistics>&)>& collector);
} // namespace NDq
} // namespace NYql
diff --git a/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp b/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
index 79d73f11de..38d824e9ae 100644
--- a/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
@@ -13,7 +13,6 @@
#include <ydb/library/yql/core/cbo/cbo_optimizer.h> //interface
#include <ydb/library/yql/core/cbo/cbo_optimizer_new.h> //new interface
-
#include <library/cpp/disjoint_sets/disjoint_sets.h>
@@ -98,10 +97,88 @@ void ComputeJoinConditions(const TCoEquiJoinTuple& joinTuple,
std::shared_ptr<TJoinOptimizerNode> MakeJoin(std::shared_ptr<IBaseOptimizerNode> left,
std::shared_ptr<IBaseOptimizerNode> right,
const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions,
- EJoinImplType joinImpl) {
+ EJoinKind joinKind,
+ EJoinAlgoType joinAlgo,
+ IProviderContext& ctx) {
+
+ auto res = std::make_shared<TJoinOptimizerNode>(left, right, joinConditions, joinKind, joinAlgo);
+ res->Stats = std::make_shared<TOptimizerStatistics>( ComputeJoinStats(*left->Stats, *right->Stats, joinConditions, joinAlgo, ctx));
+ return res;
+}
+
+/**
+ * Iterate over all join algorithms and pick the cheapest join that is applicable.
+ * Also considers commuting the join, i.e. both (left, right) and (right, left) argument orders
+*/
+std::shared_ptr<TJoinOptimizerNode> PickBestJoin(std::shared_ptr<IBaseOptimizerNode> left,
+ std::shared_ptr<IBaseOptimizerNode> right,
+ const std::set<std::pair<TJoinColumn, TJoinColumn>>& leftJoinConditions,
+ const std::set<std::pair<TJoinColumn, TJoinColumn>>& rightJoinConditions,
+ IProviderContext& ctx) {
+
+ auto res = std::shared_ptr<TJoinOptimizerNode>();
+
+ for ( auto joinAlgo : AllJoinAlgos ) {
+ auto p1 = ctx.IsJoinApplicable(left, right, leftJoinConditions, joinAlgo) ?
+ MakeJoin(left, right, leftJoinConditions, EJoinKind::InnerJoin, joinAlgo, ctx) :
+ std::shared_ptr<TJoinOptimizerNode>();
+ auto p2 = ctx.IsJoinApplicable(right, left, rightJoinConditions, joinAlgo) ?
+ MakeJoin(right, left, rightJoinConditions, EJoinKind::InnerJoin, joinAlgo, ctx) :
+ std::shared_ptr<TJoinOptimizerNode>();
+
+ if (p1) {
+ if (res) {
+ if (p1->Stats->Cost < res->Stats->Cost) {
+ res = p1;
+ }
+ } else {
+ res = p1;
+ }
+ }
+ if (p2) {
+ if (res) {
+ if (p2->Stats->Cost < res->Stats->Cost) {
+ res = p2;
+ }
+ } else {
+ res = p2;
+ }
+ }
+ }
- auto res = std::make_shared<TJoinOptimizerNode>(left, right, joinConditions, EJoinKind::InnerJoin);
- res->Stats = std::make_shared<TOptimizerStatistics>( ComputeJoinStats(*left->Stats, *right->Stats, joinConditions, joinImpl));
+ Y_ENSURE(res,"No join was chosen!");
+ return res;
+}
+
+/**
+ * Iterate over all join algorithms and pick the cheapest applicable join, keeping the given argument order and join kind
+*/
+std::shared_ptr<TJoinOptimizerNode> PickBestNonReorderabeJoin(std::shared_ptr<IBaseOptimizerNode> left,
+ std::shared_ptr<IBaseOptimizerNode> right,
+ const std::set<std::pair<TJoinColumn, TJoinColumn>>& leftJoinConditions,
+ EJoinKind joinKind,
+ IProviderContext& ctx) {
+
+ auto res = std::shared_ptr<TJoinOptimizerNode>();
+
+ for ( auto joinAlgo : AllJoinAlgos ) {
+ auto p = ctx.IsJoinApplicable(left, right, leftJoinConditions, joinAlgo) ?
+ MakeJoin(left, right, leftJoinConditions, joinKind, joinAlgo, ctx) :
+ std::shared_ptr<TJoinOptimizerNode>();
+
+ if (p) {
+ if (res) {
+ if (p->Stats->Cost < res->Stats->Cost) {
+ res = p;
+ }
+ } else {
+ res = p;
+ }
+ }
+
+ }
+
+ Y_ENSURE(res,"No join was chosen!");
return res;
}
@@ -309,8 +386,8 @@ class TDPccpSolver {
public:
// Construct the DPccp solver based on the join graph and data about input relations
- TDPccpSolver(TGraph<N>& g, TVector<std::shared_ptr<IBaseOptimizerNode>> rels):
- Graph(g), Rels(rels) {
+ TDPccpSolver(TGraph<N>& g, TVector<std::shared_ptr<IBaseOptimizerNode>> rels, IProviderContext& ctx):
+ Graph(g), Rels(rels), Pctx(ctx) {
NNodes = g.NNodes;
}
@@ -342,6 +419,10 @@ private:
// List of input relations to DPccp
TVector<std::shared_ptr<IBaseOptimizerNode>> Rels;
+
+ // Provider-specific context, passed on when picking the best join for a pair of subgraphs
+ // FIXME: This is a temporary structure that needs to be extended to multiple providers
+ IProviderContext& Pctx;
// Emit connected subgraph
void EmitCsg(const std::bitset<N>&, int=0);
@@ -548,34 +629,27 @@ template <int N> void TDPccpSolver<N>::EmitCsgCmp(const std::bitset<N>& S1, cons
std::bitset<N> joined = S1 | S2;
+ TEdge e1 = Graph.FindCrossingEdge(S1, S2);
+ TEdge e2 = Graph.FindCrossingEdge(S2, S1);
+ auto bestJoin = PickBestJoin(DpTable[S1], DpTable[S2], e1.JoinConditions, e2.JoinConditions, Pctx);
+
if (! DpTable.contains(joined)) {
- TEdge e1 = Graph.FindCrossingEdge(S1, S2);
- DpTable[joined] = MakeJoin(DpTable[S1], DpTable[S2], e1.JoinConditions, GraceJoin);
- TEdge e2 = Graph.FindCrossingEdge(S2, S1);
- std::shared_ptr<TJoinOptimizerNode> newJoin =
- MakeJoin(DpTable[S2], DpTable[S1], e2.JoinConditions, GraceJoin);
- if (newJoin->Stats->Cost < DpTable[joined]->Stats->Cost){
- DpTable[joined] = newJoin;
- }
+ DpTable[joined] = bestJoin;
} else {
- TEdge e1 = Graph.FindCrossingEdge(S1, S2);
- std::shared_ptr<TJoinOptimizerNode> newJoin1 =
- MakeJoin(DpTable[S1], DpTable[S2], e1.JoinConditions, GraceJoin);
- TEdge e2 = Graph.FindCrossingEdge(S2, S1);
- std::shared_ptr<TJoinOptimizerNode> newJoin2 =
- MakeJoin(DpTable[S2], DpTable[S1], e2.JoinConditions, GraceJoin);
- if (newJoin1->Stats->Cost < DpTable[joined]->Stats->Cost){
- DpTable[joined] = newJoin1;
- }
- if (newJoin2->Stats->Cost < DpTable[joined]->Stats->Cost){
- DpTable[joined] = newJoin2;
+ if (bestJoin->Stats->Cost < DpTable[joined]->Stats->Cost) {
+ DpTable[joined] = bestJoin;
}
}
+ /*
+ * This is a sanity check; it is disabled because it slows down the optimizer
+ *
+
auto pair = std::make_pair(S1, S2);
Y_ENSURE (!CheckTable.contains(pair), "Check table already contains pair S1|S2");
CheckTable[ std::pair<std::bitset<N>,std::bitset<N>>(S1, S2) ] = true;
+ */
}
/**
@@ -782,9 +856,10 @@ TExprBase RearrangeEquiJoinTree(TExprContext& ctx, const TCoEquiJoin& equiJoin,
}
bool DqCollectJoinRelationsWithStats(
+ TVector<std::shared_ptr<TRelOptimizerNode>>& rels,
TTypeAnnotationContext& typesCtx,
const TCoEquiJoin& equiJoin,
- const std::function<void(TStringBuf, const std::shared_ptr<TOptimizerStatistics>&)>& collector)
+ const std::function<void(TVector<std::shared_ptr<TRelOptimizerNode>>&, TStringBuf, const TExprNode::TPtr, const std::shared_ptr<TOptimizerStatistics>&)>& collector)
{
if (equiJoin.ArgCount() < 3) {
return false;
@@ -808,7 +883,7 @@ bool DqCollectJoinRelationsWithStats(
TStringBuf label = scope.Cast<TCoAtom>();
auto stats = maybeStat->second;
- collector(label, stats);
+ collector(rels, label, joinArg.Ptr(), stats);
}
return true;
}
@@ -861,7 +936,7 @@ std::shared_ptr<TJoinOptimizerNode> ConvertToJoinTree(const TCoEquiJoinTuple& jo
TJoinColumn(rightScope, rightColumn)));
}
- return std::make_shared<TJoinOptimizerNode>(left,right,joinConds,ConvertToJoinKind(joinTuple.Type().StringValue()));
+ return std::make_shared<TJoinOptimizerNode>(left, right, joinConds, ConvertToJoinKind(joinTuple.Type().StringValue()), EJoinAlgoType::DictJoin);
}
/**
@@ -919,14 +994,14 @@ void ExtractRelsAndJoinConditions(const std::shared_ptr<TJoinOptimizerNode>& joi
/**
* Recursively computes statistics for a join tree
*/
-void ComputeStatistics(const std::shared_ptr<TJoinOptimizerNode>& join) {
+void ComputeStatistics(const std::shared_ptr<TJoinOptimizerNode>& join, IProviderContext& ctx) {
if (join->LeftArg->Kind == EOptimizerNodeKind::JoinNodeType) {
- ComputeStatistics(static_pointer_cast<TJoinOptimizerNode>(join->LeftArg));
+ ComputeStatistics(static_pointer_cast<TJoinOptimizerNode>(join->LeftArg), ctx);
}
if (join->RightArg->Kind == EOptimizerNodeKind::JoinNodeType) {
- ComputeStatistics(static_pointer_cast<TJoinOptimizerNode>(join->RightArg));
+ ComputeStatistics(static_pointer_cast<TJoinOptimizerNode>(join->RightArg), ctx);
}
- join->Stats = std::make_shared<TOptimizerStatistics>(ComputeJoinStats(*join->LeftArg->Stats, *join->RightArg->Stats, join->JoinConditions, EJoinImplType::DictJoin));
+ join->Stats = std::make_shared<TOptimizerStatistics>(ComputeJoinStats(*join->LeftArg->Stats, *join->RightArg->Stats, join->JoinConditions, EJoinAlgoType::DictJoin, ctx));
}
/**
@@ -934,10 +1009,9 @@ void ComputeStatistics(const std::shared_ptr<TJoinOptimizerNode>& join) {
* The root of the subtree that needs to be optimized needs to be reorderable, otherwise we will
* only update the statistics for it and return it unchanged
*/
-std::shared_ptr<TJoinOptimizerNode> OptimizeSubtree(const std::shared_ptr<TJoinOptimizerNode>& joinTree, ui32 maxDPccpDPTableSize) {
+std::shared_ptr<TJoinOptimizerNode> OptimizeSubtree(const std::shared_ptr<TJoinOptimizerNode>& joinTree, ui32 maxDPccpDPTableSize, IProviderContext& ctx) {
if (!joinTree->IsReorderable) {
- joinTree->Stats = std::make_shared<TOptimizerStatistics>(ComputeJoinStats(*joinTree->LeftArg->Stats, *joinTree->RightArg->Stats, joinTree->JoinConditions, EJoinImplType::DictJoin));
- return joinTree;
+ return PickBestNonReorderabeJoin(joinTree->LeftArg, joinTree->RightArg, joinTree->JoinConditions, joinTree->JoinType, ctx);
}
TGraph<64> joinGraph;
@@ -954,7 +1028,7 @@ std::shared_ptr<TJoinOptimizerNode> OptimizeSubtree(const std::shared_ptr<TJoinO
// If that's the case - don't optimize the plan and just return it with
// computed statistics
if (rels.size() >= 64) {
- ComputeStatistics(joinTree);
+ ComputeStatistics(joinTree, ctx);
return joinTree;
}
@@ -981,12 +1055,12 @@ std::shared_ptr<TJoinOptimizerNode> OptimizeSubtree(const std::shared_ptr<TJoinO
YQL_CLOG(TRACE, CoreDq) << str.str();
}
- TDPccpSolver<64> solver(joinGraph,rels);
+ TDPccpSolver<64> solver(joinGraph, rels, ctx);
// Check that the dynamic table of DPccp is not too big
// If it is, just compute the statistics for the join tree and return it
if (solver.CountCC(maxDPccpDPTableSize) >= maxDPccpDPTableSize) {
- ComputeStatistics(joinTree);
+ ComputeStatistics(joinTree, ctx);
return joinTree;
}
@@ -1005,8 +1079,8 @@ std::shared_ptr<TJoinOptimizerNode> OptimizeSubtree(const std::shared_ptr<TJoinO
class TOptimizerNativeNew: public IOptimizerNew {
public:
- TOptimizerNativeNew(const ui32 maxDPccpDPTableSize)
- : MaxDPccpDPTableSize(maxDPccpDPTableSize) { }
+ TOptimizerNativeNew(IProviderContext& ctx, const ui32 maxDPccpDPTableSize)
+ : IOptimizerNew(ctx), MaxDPccpDPTableSize(maxDPccpDPTableSize) { }
std::shared_ptr<TJoinOptimizerNode> JoinSearch(const std::shared_ptr<TJoinOptimizerNode>& joinTree) override {
// Traverse the join tree and generate a list of non-orderable joins in a post-order
@@ -1016,16 +1090,16 @@ public:
// For all non-orderable joins, optimize the children
for( auto join : nonOrderables ) {
if (join->LeftArg->Kind == EOptimizerNodeKind::JoinNodeType) {
- join->LeftArg = OptimizeSubtree(static_pointer_cast<TJoinOptimizerNode>(join->LeftArg), MaxDPccpDPTableSize);
+ join->LeftArg = OptimizeSubtree(static_pointer_cast<TJoinOptimizerNode>(join->LeftArg), MaxDPccpDPTableSize, Pctx);
}
if (join->RightArg->Kind == EOptimizerNodeKind::JoinNodeType) {
- join->RightArg = OptimizeSubtree(static_pointer_cast<TJoinOptimizerNode>(join->RightArg), MaxDPccpDPTableSize);
+ join->RightArg = OptimizeSubtree(static_pointer_cast<TJoinOptimizerNode>(join->RightArg), MaxDPccpDPTableSize, Pctx);
}
- join->Stats = std::make_shared<TOptimizerStatistics>(ComputeJoinStats(*join->LeftArg->Stats, *join->RightArg->Stats, join->JoinConditions, EJoinImplType::DictJoin));
+ join->Stats = std::make_shared<TOptimizerStatistics>(ComputeJoinStats(*join->LeftArg->Stats, *join->RightArg->Stats, join->JoinConditions, EJoinAlgoType::DictJoin, Pctx));
}
// Optimize the root
- return OptimizeSubtree(joinTree, MaxDPccpDPTableSize);
+ return OptimizeSubtree(joinTree, MaxDPccpDPTableSize, Pctx);
}
const ui32 MaxDPccpDPTableSize;
@@ -1041,9 +1115,10 @@ public:
* and finally optimizes the root of the tree
*/
TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx,
- bool ruleEnabled, ui32 maxDPccpDPTableSize) {
+ ui32 optLevel, ui32 maxDPccpDPTableSize, IProviderContext& providerCtx,
+ const std::function<void(TVector<std::shared_ptr<TRelOptimizerNode>>&, TStringBuf, const TExprNode::TPtr, const std::shared_ptr<TOptimizerStatistics>&)>& providerCollect) {
- if (!ruleEnabled) {
+ if (optLevel==0) {
return node;
}
@@ -1065,9 +1140,8 @@ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx,
// Check that statistics for all inputs of equiJoin were computed
// The arguments of the EquiJoin are 1..n-2, n-2 is the actual join tree
// of the EquiJoin and n-1 argument are the parameters to EquiJoin
- if (!DqCollectJoinRelationsWithStats(typesCtx, equiJoin, [&](auto label, auto stat) {
- rels.emplace_back(std::make_shared<TRelOptimizerNode>(TString(label), stat));
- })) {
+
+ if (!DqCollectJoinRelationsWithStats(rels, typesCtx, equiJoin, providerCollect)){
return node;
}
@@ -1078,7 +1152,7 @@ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx,
// Generate an initial tree
auto joinTree = ConvertToJoinTree(joinTuple,rels);
- auto opt = TOptimizerNativeNew(maxDPccpDPTableSize);
+ auto opt = TOptimizerNativeNew(providerCtx, maxDPccpDPTableSize);
joinTree = opt.JoinSearch(joinTree);
// rewrite the join tree and record the output statistics
@@ -1097,7 +1171,8 @@ public:
}
TOutput JoinSearch() override {
- TDPccpSolver<64> solver(JoinGraph, Rels);
+ auto dummyProviderCtx = TDummyProviderContext();
+ TDPccpSolver<64> solver(JoinGraph, Rels, dummyProviderCtx);
std::shared_ptr<TJoinOptimizerNode> result = solver.Solve();
if (Log) {
std::stringstream str;
diff --git a/ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp b/ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp
index 75a1bbf1a1..591737af61 100644
--- a/ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp
@@ -157,11 +157,11 @@ TExprBase DqOptimizeEquiJoinWithCosts(
TExprContext& ctx,
TTypeAnnotationContext& typesCtx,
const std::function<IOptimizer*(IOptimizer::TInput&&)>& optFactory,
- bool ruleEnabled)
+ ui32 optLevel)
{
Y_UNUSED(ctx);
- if (!ruleEnabled) {
+ if (optLevel==0) {
return node;
}
@@ -184,7 +184,10 @@ TExprBase DqOptimizeEquiJoinWithCosts(
TState state(equiJoin);
// collect Rels
- if (!DqCollectJoinRelationsWithStats(typesCtx, equiJoin, [&](auto label, auto stat) {
+ TVector<std::shared_ptr<TRelOptimizerNode>> rels;
+ if (!DqCollectJoinRelationsWithStats(rels, typesCtx, equiJoin, [&](auto r, auto label, auto node, auto stat) {
+ Y_UNUSED(r);
+ Y_UNUSED(node);
state.CollectRel(label, stat);
})) {
return node;
diff --git a/ydb/library/yql/dq/opt/dq_opt_log.h b/ydb/library/yql/dq/opt/dq_opt_log.h
index 83061e1b3d..0c140b9d99 100644
--- a/ydb/library/yql/dq/opt/dq_opt_log.h
+++ b/ydb/library/yql/dq/opt/dq_opt_log.h
@@ -11,6 +11,9 @@
namespace NYql {
struct TTypeAnnotationContext;
struct TDqSettings;
+ struct IProviderContext;
+ struct TRelOptimizerNode;
+ struct TOptimizerStatistics;
}
namespace NYql::NDq {
@@ -19,14 +22,21 @@ NNodes::TExprBase DqRewriteAggregate(NNodes::TExprBase node, TExprContext& ctx,
NNodes::TExprBase DqRewriteTakeSortToTopSort(NNodes::TExprBase node, TExprContext& ctx, const TParentsMap& parents);
-NNodes::TExprBase DqOptimizeEquiJoinWithCosts(const NNodes::TExprBase& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, bool isRuleEnabled, ui32 maxDPccpDPTableSize);
+NNodes::TExprBase DqOptimizeEquiJoinWithCosts(
+ const NNodes::TExprBase& node,
+ TExprContext& ctx,
+ TTypeAnnotationContext& typesCtx,
+ ui32 optLevel,
+ ui32 maxDPccpDPTableSize,
+ IProviderContext& providerCtx,
+ const std::function<void(TVector<std::shared_ptr<TRelOptimizerNode>>&, TStringBuf, const TExprNode::TPtr, const std::shared_ptr<TOptimizerStatistics>&)>& providerCollect);
NNodes::TExprBase DqOptimizeEquiJoinWithCosts(
const NNodes::TExprBase& node,
TExprContext& ctx,
TTypeAnnotationContext& typesCtx,
const std::function<IOptimizer*(IOptimizer::TInput&&)>& optFactory,
- bool ruleEnabled);
+ ui32 optLevel);
NNodes::TExprBase DqRewriteEquiJoin(const NNodes::TExprBase& node, TExprContext& ctx);
diff --git a/ydb/library/yql/dq/opt/dq_opt_stat.cpp b/ydb/library/yql/dq/opt/dq_opt_stat.cpp
index 747bda5ca5..aed5076e9e 100644
--- a/ydb/library/yql/dq/opt/dq_opt_stat.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_stat.cpp
@@ -93,7 +93,7 @@ bool IsConstantExpr(const TExprNode::TPtr& input) {
* Compute statistics for map join
* FIX: Currently we treat all join the same from the cost perspective, need to refine cost function
*/
-void InferStatisticsForMapJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+void InferStatisticsForMapJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx, const IProviderContext& ctx) {
auto inputNode = TExprBase(input);
auto join = inputNode.Cast<TCoMapJoinCore>();
@@ -118,14 +118,14 @@ void InferStatisticsForMapJoin(const TExprNode::TPtr& input, TTypeAnnotationCont
}
typeCtx->SetStats(join.Raw(), std::make_shared<TOptimizerStatistics>(
- ComputeJoinStats(*leftStats, *rightStats, leftJoinKeys, rightJoinKeys, MapJoin)));
+ ComputeJoinStats(*leftStats, *rightStats, leftJoinKeys, rightJoinKeys, MapJoin, ctx)));
}
/**
* Compute statistics for grace join
* FIX: Currently we treat all join the same from the cost perspective, need to refine cost function
*/
-void InferStatisticsForGraceJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+void InferStatisticsForGraceJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx, const IProviderContext& ctx) {
auto inputNode = TExprBase(input);
auto join = inputNode.Cast<TCoGraceJoinCore>();
@@ -150,7 +150,7 @@ void InferStatisticsForGraceJoin(const TExprNode::TPtr& input, TTypeAnnotationCo
}
typeCtx->SetStats(join.Raw(), std::make_shared<TOptimizerStatistics>(
- ComputeJoinStats(*leftStats, *rightStats, leftJoinKeys, rightJoinKeys, GraceJoin)));
+ ComputeJoinStats(*leftStats, *rightStats, leftJoinKeys, rightJoinKeys, GraceJoin, ctx)));
}
/**
diff --git a/ydb/library/yql/dq/opt/dq_opt_stat.h b/ydb/library/yql/dq/opt/dq_opt_stat.h
index 7a5f954276..4f3497da4f 100644
--- a/ydb/library/yql/dq/opt/dq_opt_stat.h
+++ b/ydb/library/yql/dq/opt/dq_opt_stat.h
@@ -1,6 +1,7 @@
#include "dq_opt.h"
#include <ydb/library/yql/core/yql_type_annotation.h>
+#include <ydb/library/yql/core/cbo/cbo_optimizer_new.h>
namespace NYql::NDq {
@@ -14,8 +15,8 @@ void PropagateStatisticsToLambdaArgument(const TExprNode::TPtr& input, TTypeAnno
void PropagateStatisticsToStageArguments(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
void InferStatisticsForStage(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
void InferStatisticsForDqSource(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
-void InferStatisticsForGraceJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
-void InferStatisticsForMapJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
+void InferStatisticsForGraceJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx, const IProviderContext& ctx);
+void InferStatisticsForMapJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx, const IProviderContext& ctx);
double ComputePredicateSelectivity(const NNodes::TExprBase& input, const std::shared_ptr<TOptimizerStatistics>& stats);
bool NeedCalc(NNodes::TExprBase node);
bool IsConstantExpr(const TExprNode::TPtr& input);
diff --git a/ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.cpp b/ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.cpp
index 4284d30cec..f29e814444 100644
--- a/ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.cpp
@@ -7,8 +7,8 @@ namespace NYql::NDq {
using namespace NNodes;
-TDqStatisticsTransformerBase::TDqStatisticsTransformerBase(TTypeAnnotationContext* typeCtx)
- : TypeCtx(typeCtx)
+TDqStatisticsTransformerBase::TDqStatisticsTransformerBase(TTypeAnnotationContext* typeCtx, const IProviderContext& ctx)
+ : TypeCtx(typeCtx), Pctx(ctx)
{ }
IGraphTransformer::TStatus TDqStatisticsTransformerBase::DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) {
@@ -55,10 +55,10 @@ bool TDqStatisticsTransformerBase::BeforeLambdas(const TExprNode::TPtr& input, T
// Join matchers
else if(TCoMapJoinCore::Match(input.Get())) {
- InferStatisticsForMapJoin(input, TypeCtx);
+ InferStatisticsForMapJoin(input, TypeCtx, Pctx);
}
else if(TCoGraceJoinCore::Match(input.Get())) {
- InferStatisticsForGraceJoin(input, TypeCtx);
+ InferStatisticsForGraceJoin(input, TypeCtx, Pctx);
}
// Do nothing in case of EquiJoin, otherwise the EquiJoin rule won't fire
diff --git a/ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.h b/ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.h
index 45ff27acc3..8201832b64 100644
--- a/ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.h
+++ b/ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.h
@@ -2,12 +2,13 @@
#include <ydb/library/yql/core/yql_graph_transformer.h>
#include <ydb/library/yql/core/yql_type_annotation.h>
+#include <ydb/library/yql/core/cbo/cbo_optimizer_new.h>
namespace NYql::NDq {
class TDqStatisticsTransformerBase : public TSyncTransformerBase {
public:
- TDqStatisticsTransformerBase(TTypeAnnotationContext* typeCtx);
+ TDqStatisticsTransformerBase(TTypeAnnotationContext* typeCtx, const IProviderContext& ctx);
IGraphTransformer::TStatus DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) override;
void Rewind() override;
@@ -21,6 +22,7 @@ protected:
bool AfterLambdas(const TExprNode::TPtr& input, TExprContext& ctx);
TTypeAnnotationContext* TypeCtx;
+ const IProviderContext& Pctx;
};
} // namespace NYql::NDq
diff --git a/ydb/library/yql/providers/dq/common/yql_dq_settings.h b/ydb/library/yql/providers/dq/common/yql_dq_settings.h
index b50224a76e..3530a3178e 100644
--- a/ydb/library/yql/providers/dq/common/yql_dq_settings.h
+++ b/ydb/library/yql/providers/dq/common/yql_dq_settings.h
@@ -56,6 +56,7 @@ struct TDqSettings {
static constexpr bool ExportStats = false;
static constexpr ETaskRunnerStats TaskRunnerStats = ETaskRunnerStats::Basic;
static constexpr ESpillingEngine SpillingEngine = ESpillingEngine::Disable;
+ static constexpr ui32 CostBasedOptimizationLevel = 0;
static constexpr ui32 MaxDPccpDPTableSize = 10000U;
};
diff --git a/ydb/library/yql/providers/dq/provider/yql_dq_datasource.cpp b/ydb/library/yql/providers/dq/provider/yql_dq_datasource.cpp
index e91c019b5a..886d5787b4 100644
--- a/ydb/library/yql/providers/dq/provider/yql_dq_datasource.cpp
+++ b/ydb/library/yql/providers/dq/provider/yql_dq_datasource.cpp
@@ -50,7 +50,7 @@ public:
, ExecTransformer_([this, execTransformerFactory] () { return THolder<IGraphTransformer>(execTransformerFactory(State_)); })
, TypeAnnotationTransformer_([] () { return CreateDqsDataSourceTypeAnnotationTransformer(); })
, ConstraintsTransformer_([] () { return CreateDqDataSourceConstraintTransformer(); })
- , StatisticsTransformer_([this]() { return CreateDqsStatisticsTransformer(State_); })
+ , StatisticsTransformer_([this]() { return CreateDqsStatisticsTransformer(State_, TDummyProviderContext::instance()); })
{ }
TStringBuf GetName() const override {
diff --git a/ydb/library/yql/providers/dq/provider/yql_dq_statistics.cpp b/ydb/library/yql/providers/dq/provider/yql_dq_statistics.cpp
index 4d44e7cd03..36b3e46031 100644
--- a/ydb/library/yql/providers/dq/provider/yql_dq_statistics.cpp
+++ b/ydb/library/yql/providers/dq/provider/yql_dq_statistics.cpp
@@ -14,8 +14,8 @@ using namespace NNodes;
class TDqsStatisticsTransformer : public NDq::TDqStatisticsTransformerBase {
public:
- TDqsStatisticsTransformer(const TDqStatePtr& state)
- : NDq::TDqStatisticsTransformerBase(state->TypeCtx)
+ TDqsStatisticsTransformer(const TDqStatePtr& state, const IProviderContext& ctx)
+ : NDq::TDqStatisticsTransformerBase(state->TypeCtx, ctx)
, State(state)
{ }
@@ -55,8 +55,8 @@ private:
TDqStatePtr State;
};
-THolder<IGraphTransformer> CreateDqsStatisticsTransformer(TDqStatePtr state) {
- return MakeHolder<TDqsStatisticsTransformer>(state);
+THolder<IGraphTransformer> CreateDqsStatisticsTransformer(TDqStatePtr state, const IProviderContext& ctx) {
+ return MakeHolder<TDqsStatisticsTransformer>(state, ctx);
}
} // namespace NYql
diff --git a/ydb/library/yql/providers/dq/provider/yql_dq_statistics.h b/ydb/library/yql/providers/dq/provider/yql_dq_statistics.h
index 460b363bf4..6a5592c775 100644
--- a/ydb/library/yql/providers/dq/provider/yql_dq_statistics.h
+++ b/ydb/library/yql/providers/dq/provider/yql_dq_statistics.h
@@ -1,6 +1,7 @@
#pragma once
#include <ydb/library/yql/core/yql_graph_transformer.h>
+#include <ydb/library/yql/core/cbo/cbo_optimizer_new.h>
#include <util/generic/ptr.h>
@@ -9,6 +10,6 @@ namespace NYql {
struct TDqState;
using TDqStatePtr = TIntrusivePtr<TDqState>;
-THolder<IGraphTransformer> CreateDqsStatisticsTransformer(TDqStatePtr state);
+THolder<IGraphTransformer> CreateDqsStatisticsTransformer(TDqStatePtr state, const IProviderContext& ctx);
} // namespace NYql