aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorpavelvelikhov <pavelvelikhov@yandex-team.com>2023-09-22 16:11:12 +0300
committerpavelvelikhov <pavelvelikhov@yandex-team.com>2023-09-22 17:00:42 +0300
commit4ee466d5baecacc8e10b342caae9b57f8bf477af (patch)
tree6cc3ff243adb496a14c1945afc3c969f0cc67bdf
parent77fc18884ef08f62201e024b6c7b97084a358273 (diff)
downloadydb-4ee466d5baecacc8e10b342caae9b57f8bf477af.tar.gz
Compute optimizer statistics at physical level, added to Explain plan
Final commit
-rw-r--r--ydb/core/kqp/host/kqp_explain_prepared.cpp12
-rw-r--r--ydb/core/kqp/host/kqp_host_impl.h3
-rw-r--r--ydb/core/kqp/host/kqp_runner.cpp3
-rw-r--r--ydb/core/kqp/opt/kqp_query_plan.cpp86
-rw-r--r--ydb/core/kqp/opt/kqp_query_plan.h3
-rw-r--r--ydb/core/kqp/opt/kqp_statistics_transformer.cpp299
-rw-r--r--ydb/library/yql/core/CMakeLists.darwin-x86_64.txt1
-rw-r--r--ydb/library/yql/core/CMakeLists.linux-aarch64.txt1
-rw-r--r--ydb/library/yql/core/CMakeLists.linux-x86_64.txt1
-rw-r--r--ydb/library/yql/core/CMakeLists.windows-x86_64.txt1
-rw-r--r--ydb/library/yql/core/ya.make1
-rw-r--r--ydb/library/yql/core/yql_cost_function.cpp22
-rw-r--r--ydb/library/yql/core/yql_cost_function.h19
-rw-r--r--ydb/library/yql/core/yql_expr_optimize.cpp32
-rw-r--r--ydb/library/yql/core/yql_expr_optimize.h2
-rw-r--r--ydb/library/yql/core/yql_statistics.cpp4
-rw-r--r--ydb/library/yql/core/yql_statistics.h4
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp39
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_stat.cpp126
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_stat.h4
20 files changed, 593 insertions, 70 deletions
diff --git a/ydb/core/kqp/host/kqp_explain_prepared.cpp b/ydb/core/kqp/host/kqp_explain_prepared.cpp
index 695104abcd1..f72f49c3123 100644
--- a/ydb/core/kqp/host/kqp_explain_prepared.cpp
+++ b/ydb/core/kqp/host/kqp_explain_prepared.cpp
@@ -15,12 +15,14 @@ using namespace NThreading;
class TKqpExplainPreparedTransformer : public NYql::TGraphTransformerBase {
public:
TKqpExplainPreparedTransformer(TIntrusivePtr<IKqpGateway> gateway, const TString& cluster,
- TIntrusivePtr<TKqlTransformContext> transformCtx, const NMiniKQL::IFunctionRegistry* funcRegistry)
+ TIntrusivePtr<TKqlTransformContext> transformCtx, const NMiniKQL::IFunctionRegistry* funcRegistry,
+ TTypeAnnotationContext& typeCtx)
: Gateway(gateway)
, Cluster(cluster)
, TransformCtx(transformCtx)
, FuncRegistry(funcRegistry)
, CurrentTxIndex(0)
+ , TypeCtx(typeCtx)
{
TxAlloc = TransformCtx->QueryCtx->QueryData->GetAllocState();
}
@@ -57,7 +59,7 @@ public:
}
PhyQuerySetTxPlans(query, TKqpPhysicalQuery(input), std::move(TxResults),
- ctx, Cluster, TransformCtx->Tables, TransformCtx->Config);
+ ctx, Cluster, TransformCtx->Tables, TransformCtx->Config, TypeCtx);
query.SetQueryAst(KqpExprToPrettyString(*input, ctx));
return TStatus::Ok;
@@ -112,16 +114,18 @@ private:
NThreading::TFuture<IKqpGateway::TExecPhysicalResult> ExecuteFuture;
NThreading::TPromise<void> Promise;
TTxAllocatorState::TPtr TxAlloc;
+ TTypeAnnotationContext& TypeCtx;
};
TAutoPtr<IGraphTransformer> CreateKqpExplainPreparedTransformer(TIntrusivePtr<IKqpGateway> gateway,
const TString& cluster, TIntrusivePtr<TKqlTransformContext> transformCtx, const NMiniKQL::IFunctionRegistry* funcRegistry,
- TIntrusivePtr<ITimeProvider> timeProvider, TIntrusivePtr<IRandomProvider> randomProvider)
+ TIntrusivePtr<ITimeProvider> timeProvider, TIntrusivePtr<IRandomProvider> randomProvider,
+ TTypeAnnotationContext& typeCtx)
{
Y_UNUSED(randomProvider);
Y_UNUSED(timeProvider);
- return new TKqpExplainPreparedTransformer(gateway, cluster, transformCtx, funcRegistry);
+ return new TKqpExplainPreparedTransformer(gateway, cluster, transformCtx, funcRegistry, typeCtx);
}
} // namespace NKqp
diff --git a/ydb/core/kqp/host/kqp_host_impl.h b/ydb/core/kqp/host/kqp_host_impl.h
index b2972b01763..d6fb8986b07 100644
--- a/ydb/core/kqp/host/kqp_host_impl.h
+++ b/ydb/core/kqp/host/kqp_host_impl.h
@@ -249,7 +249,8 @@ TIntrusivePtr<IKqpRunner> CreateKqpRunner(TIntrusivePtr<IKqpGateway> gateway, co
TAutoPtr<NYql::IGraphTransformer> CreateKqpExplainPreparedTransformer(TIntrusivePtr<IKqpGateway> gateway,
const TString& cluster, TIntrusivePtr<TKqlTransformContext> transformCtx, const NMiniKQL::IFunctionRegistry* funcRegistry,
- TIntrusivePtr<ITimeProvider> timeProvider, TIntrusivePtr<IRandomProvider> randomProvider);
+ TIntrusivePtr<ITimeProvider> timeProvider, TIntrusivePtr<IRandomProvider> randomProvider,
+ NYql::TTypeAnnotationContext& typeCtx);
TAutoPtr<NYql::IGraphTransformer> CreateKqpTypeAnnotationTransformer(const TString& cluster,
TIntrusivePtr<NYql::TKikimrTablesData> tablesData, NYql::TTypeAnnotationContext& typesCtx,
diff --git a/ydb/core/kqp/host/kqp_runner.cpp b/ydb/core/kqp/host/kqp_runner.cpp
index f77295229d8..fa22700cafc 100644
--- a/ydb/core/kqp/host/kqp_runner.cpp
+++ b/ydb/core/kqp/host/kqp_runner.cpp
@@ -281,7 +281,7 @@ public:
PreparedExplainTransformer = TTransformationPipeline(typesCtx)
.Add(CreateKqpExplainPreparedTransformer(
- Gateway, Cluster, TransformCtx, &funcRegistry, timeProvider, randomProvider), "ExplainQuery")
+ Gateway, Cluster, TransformCtx, &funcRegistry, timeProvider, randomProvider, *typesCtx), "ExplainQuery")
.Build(false);
PhysicalOptimizeTransformer = CreateKqpQueryBlocksTransformer(TTransformationPipeline(typesCtx)
@@ -329,6 +329,7 @@ public:
.AddTypeAnnotationTransformer(CreateKqpTypeAnnotationTransformer(Cluster, sessionCtx->TablesPtr(), *typesCtx, Config))
.AddPostTypeAnnotation()
.Add(CreateKqpBuildPhysicalQueryTransformer(OptimizeCtx, BuildQueryCtx), "BuildPhysicalQuery")
+ .Add(CreateKqpStatisticsTransformer(OptimizeCtx, *typesCtx, Config), "Statistics")
.Build(false);
PhysicalPeepholeTransformer = TTransformationPipeline(typesCtx)
diff --git a/ydb/core/kqp/opt/kqp_query_plan.cpp b/ydb/core/kqp/opt/kqp_query_plan.cpp
index 3836800ca4b..0529b4735cc 100644
--- a/ydb/core/kqp/opt/kqp_query_plan.cpp
+++ b/ydb/core/kqp/opt/kqp_query_plan.cpp
@@ -66,13 +66,15 @@ struct TSerializerCtx {
TSerializerCtx(TExprContext& exprCtx, const TString& cluster,
const TIntrusivePtr<NYql::TKikimrTablesData> tablesData,
const TKikimrConfiguration::TPtr config, ui32 txCount,
- TVector<TVector<NKikimrMiniKQL::TResult>> pureTxResults)
+ TVector<TVector<NKikimrMiniKQL::TResult>> pureTxResults,
+ TTypeAnnotationContext& typeCtx)
: ExprCtx(exprCtx)
, Cluster(cluster)
, TablesData(tablesData)
, Config(config)
, TxCount(txCount)
, PureTxResults(std::move(pureTxResults))
+ , TypeCtx(typeCtx)
{}
TMap<TString, TTableInfo> Tables;
@@ -88,6 +90,7 @@ struct TSerializerCtx {
const TKikimrConfiguration::TPtr Config;
const ui32 TxCount;
TVector<TVector<NKikimrMiniKQL::TResult>> PureTxResults;
+ TTypeAnnotationContext& TypeCtx;
};
TString GetExprStr(const TExprBase& scalar, bool quoteStr = true) {
@@ -724,6 +727,8 @@ private:
op.Properties["Reverse"] = true;
}
+ AddOptimizerEstimates(op, sourceSettings);
+
SerializerCtx.Tables[table].Reads.push_back(readInfo);
if (readInfo.Type == EPlanTableReadType::Scan) {
@@ -833,6 +838,8 @@ private:
op.Properties["Reverse"] = true;
}
+ AddOptimizerEstimates(op, sourceSettings);
+
if (readInfo.Type == EPlanTableReadType::FullScan) {
op.Properties["Name"] = "TableFullScan";
AddOperator(planNode, "TableFullScan", std::move(op));
@@ -956,6 +963,17 @@ private:
})).Cast<TCoJoinDict>();
operatorId = Visit(flatMap, join, planNode);
node = join.Ptr();
+ } else if (auto maybeJoinDict = TMaybeNode<TCoGraceJoinCore>(node)) {
+ operatorId = Visit(maybeJoinDict.Cast(), planNode);
+ } else if (TMaybeNode<TCoFlatMapBase>(node).Lambda().Body().Maybe<TCoGraceJoinCore>() ||
+ TMaybeNode<TCoFlatMapBase>(node).Lambda().Body().Maybe<TCoMap>().Input().Maybe<TCoGraceJoinCore>()) {
+ auto flatMap = TMaybeNode<TCoFlatMapBase>(node).Cast();
+ auto join = TExprBase(FindNode(node, [](const TExprNode::TPtr& node) {
+ Y_ENSURE(!TMaybeNode<TDqConnection>(node).IsValid());
+ return TMaybeNode<TCoGraceJoinCore>(node).IsValid();
+ })).Cast<TCoGraceJoinCore>();
+ operatorId = Visit(flatMap, join, planNode);
+ node = join.Ptr();
} else if (auto maybeCondense1 = TMaybeNode<TCoCondense1>(node)) {
operatorId = Visit(maybeCondense1.Cast(), planNode);
} else if (auto maybeCondense = TMaybeNode<TCoCondense>(node)) {
@@ -1130,6 +1148,9 @@ private:
TOperator op;
op.Properties["Name"] = name;
+
+ AddOptimizerEstimates(op, join);
+
auto operatorId = AddOperator(planNode, name, std::move(op));
auto inputs = Visit(flatMap.Input().Ptr(), planNode);
@@ -1142,6 +1163,9 @@ private:
TOperator op;
op.Properties["Name"] = name;
+
+ AddOptimizerEstimates(op, join);
+
return AddOperator(planNode, name, std::move(op));
}
@@ -1150,6 +1174,9 @@ private:
TOperator op;
op.Properties["Name"] = name;
+
+ AddOptimizerEstimates(op, join);
+
auto operatorId = AddOperator(planNode, name, std::move(op));
auto inputs = Visit(flatMap.Input().Ptr(), planNode);
@@ -1162,6 +1189,34 @@ private:
TOperator op;
op.Properties["Name"] = name;
+
+ AddOptimizerEstimates(op, join);
+
+ return AddOperator(planNode, name, std::move(op));
+ }
+
+    // Adds a "<JoinKind>Join (Grace)" operator for a GraceJoinCore found under a FlatMap and
+    // links the operators produced from the FlatMap input as its inputs.
+    // Returns the id of the new operator.
+    ui32 Visit(const TCoFlatMapBase& flatMap, const TCoGraceJoinCore& join, TQueryPlanNode& planNode) {
+        const auto name = TStringBuilder() << join.JoinKind().Value() << "Join (Grace)";
+
+        TOperator op;
+        op.Properties["Name"] = name;
+
+        // The estimates have to be attached before the operator is moved into the plan,
+        // otherwise they are written to a moved-from object and never show up in Explain
+        AddOptimizerEstimates(op, join);
+
+        auto operatorId = AddOperator(planNode, name, std::move(op));
+
+        auto inputs = Visit(flatMap.Input().Ptr(), planNode);
+        planNode.Operators[operatorId].Inputs.insert(inputs.begin(), inputs.end());
+        return operatorId;
+    }
+
+ // Adds a "<JoinKind>Join (Grace)" operator for a standalone GraceJoinCore node.
+ // The operator gets a "Name" property plus the optimizer estimates for the join
+ // (when AddOptimizerEstimates decides to emit them). Returns the id of the new operator.
+ ui32 Visit(const TCoGraceJoinCore& join, TQueryPlanNode& planNode) {
+ const auto name = TStringBuilder() << join.JoinKind().Value() << "Join (Grace)";
+
+ TOperator op;
+ op.Properties["Name"] = name;
+
+ // Estimates are attached before the operator is moved into the plan below
+ AddOptimizerEstimates(op, join);
+
return AddOperator(planNode, name, std::move(op));
}
@@ -1176,6 +1231,22 @@ private:
return pred;
}
+    // Attaches the optimizer estimates for `expr` to the plan operator as the "E-Rows" and
+    // "E-Cost" properties. Nothing is added unless the config reports that cost based
+    // optimization is enabled. When no statistics are recorded for the node, or the
+    // statistics carry no cost, the corresponding property is set to "No estimate".
+    void AddOptimizerEstimates(TOperator& op, const TExprBase& expr) {
+        if (!SerializerCtx.Config->HasOptEnableCostBasedOptimization()) {
+            return;
+        }
+
+        const auto stats = SerializerCtx.TypeCtx.GetStats(expr.Raw());
+
+        if (stats) {
+            op.Properties["E-Rows"] = stats->Nrows;
+        } else {
+            op.Properties["E-Rows"] = "No estimate";
+        }
+
+        // Cost is optional: calling value() on an empty one would throw in the middle of
+        // building the Explain plan, so fall back to the textual placeholder instead
+        if (stats && stats->Cost.has_value()) {
+            op.Properties["E-Cost"] = stats->Cost.value();
+        } else {
+            op.Properties["E-Cost"] = "No estimate";
+        }
+    }
+
ui32 Visit(const TCoFilterBase& filter, TQueryPlanNode& planNode) {
TOperator op;
op.Properties["Name"] = "Filter";
@@ -1183,6 +1254,8 @@ private:
auto pred = ExtractPredicate(filter.Lambda());
op.Properties["Predicate"] = pred.Body;
+ AddOptimizerEstimates(op, filter);
+
if (filter.Limit()) {
op.Properties["Limit"] = PrettyExprStr(filter.Limit().Cast());
}
@@ -1205,6 +1278,8 @@ private:
columns.AppendValue(col.Value());
}
+ AddOptimizerEstimates(op, lookup);
+
SerializerCtx.Tables[table].Reads.push_back(readInfo);
planNode.NodeInfo["Tables"].AppendValue(op.Properties["Table"]);
return AddOperator(planNode, "TablePointLookup", std::move(op));
@@ -1306,6 +1381,8 @@ private:
op.Properties["SsaProgram"] = GetSsaProgramInJsonByTable(table, planNode.StageProto);
}
+ AddOptimizerEstimates(op, read);
+
ui32 operatorId;
if (readInfo.Type == EPlanTableReadType::FullScan) {
op.Properties["Name"] = "TableFullScan";
@@ -1444,6 +1521,8 @@ private:
SerializerCtx.Tables[table].Reads.push_back(readInfo);
+ AddOptimizerEstimates(op, read);
+
ui32 operatorId;
if (readInfo.Type == EPlanTableReadType::Scan) {
op.Properties["Name"] = "TableRangeScan";
@@ -1664,9 +1743,10 @@ TString SerializeTxPlans(const TVector<const TString>& txPlans, const TString co
// TODO(sk): check params from correlated subqueries // lookup join
void PhyQuerySetTxPlans(NKqpProto::TKqpPhyQuery& queryProto, const TKqpPhysicalQuery& query,
TVector<TVector<NKikimrMiniKQL::TResult>> pureTxResults, TExprContext& ctx, const TString& cluster,
- const TIntrusivePtr<NYql::TKikimrTablesData> tablesData, TKikimrConfiguration::TPtr config)
+ const TIntrusivePtr<NYql::TKikimrTablesData> tablesData, TKikimrConfiguration::TPtr config,
+ TTypeAnnotationContext& typeCtx)
{
- TSerializerCtx serializerCtx(ctx, cluster, tablesData, config, query.Transactions().Size(), std::move(pureTxResults));
+ TSerializerCtx serializerCtx(ctx, cluster, tablesData, config, query.Transactions().Size(), std::move(pureTxResults), typeCtx);
/* bindingName -> stage */
auto collectBindings = [&serializerCtx, &query] (auto id, const auto& phase) {
diff --git a/ydb/core/kqp/opt/kqp_query_plan.h b/ydb/core/kqp/opt/kqp_query_plan.h
index 5b08b9d4968..04ffe9516a5 100644
--- a/ydb/core/kqp/opt/kqp_query_plan.h
+++ b/ydb/core/kqp/opt/kqp_query_plan.h
@@ -34,7 +34,8 @@ enum class EPlanTableWriteType {
*/
void PhyQuerySetTxPlans(NKqpProto::TKqpPhyQuery& queryProto, const NYql::NNodes::TKqpPhysicalQuery& query,
TVector<TVector<NKikimrMiniKQL::TResult>> pureTxResults, NYql::TExprContext& ctx, const TString& cluster,
- const TIntrusivePtr<NYql::TKikimrTablesData> tablesData, NYql::TKikimrConfiguration::TPtr config);
+ const TIntrusivePtr<NYql::TKikimrTablesData> tablesData, NYql::TKikimrConfiguration::TPtr config,
+ NYql::TTypeAnnotationContext& typeCtx);
/*
* Fill stages in given txPlan with ExecutionStats fields. Each plan stage stores StageGuid which is
diff --git a/ydb/core/kqp/opt/kqp_statistics_transformer.cpp b/ydb/core/kqp/opt/kqp_statistics_transformer.cpp
index ab4b0a32188..3b55caf4355 100644
--- a/ydb/core/kqp/opt/kqp_statistics_transformer.cpp
+++ b/ydb/core/kqp/opt/kqp_statistics_transformer.cpp
@@ -1,7 +1,9 @@
#include "kqp_statistics_transformer.h"
#include <ydb/library/yql/utils/log/log.h>
#include <ydb/library/yql/dq/opt/dq_opt_stat.h>
+#include <ydb/library/yql/core/yql_cost_function.h>
+#include <charconv>
using namespace NYql;
using namespace NYql::NNodes;
@@ -11,7 +13,7 @@ using namespace NYql::NDq;
/**
* Compute statistics and cost for read table
* Currently we look up the number of rows and attributes in the statistics service
-*/
+ */
void InferStatisticsForReadTable(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx,
const TKqpOptimizeContext& kqpCtx) {
@@ -21,54 +23,246 @@ void InferStatisticsForReadTable(const TExprNode::TPtr& input, TTypeAnnotationCo
const TExprNode* path;
- if ( auto readTable = inputNode.Maybe<TKqlReadTableBase>()){
+ if (auto readTable = inputNode.Maybe<TKqlReadTableBase>()) {
path = readTable.Cast().Table().Path().Raw();
nAttrs = readTable.Cast().Columns().Size();
- } else if(auto readRanges = inputNode.Maybe<TKqlReadTableRangesBase>()){
+ } else if (auto readRanges = inputNode.Maybe<TKqlReadTableRangesBase>()) {
path = readRanges.Cast().Table().Path().Raw();
nAttrs = readRanges.Cast().Columns().Size();
} else {
- Y_ENSURE(false,"Invalid node type for InferStatisticsForReadTable");
+ Y_ENSURE(false, "Invalid node type for InferStatisticsForReadTable");
}
const auto& tableData = kqpCtx.Tables->ExistingTable(kqpCtx.Cluster, path->Content());
nRows = tableData.Metadata->RecordsCount;
+
YQL_CLOG(TRACE, CoreDq) << "Infer statistics for read table, nrows:" << nRows << ", nattrs: " << nAttrs;
auto outputStats = TOptimizerStatistics(nRows, nAttrs, 0.0);
- typeCtx->SetStats( input.Get(), std::make_shared<TOptimizerStatistics>(outputStats) );
+ typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats));
}
/**
- * Compute sstatistics for index lookup
+ * Infer statistics for KQP table
+ */
+void InferStatisticsForKqpTable(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx,
+    const TKqpOptimizeContext& kqpCtx) {
+
+    auto tablePath = TExprBase(input).Cast<TKqpTable>().Path();
+
+    // Both the row count and the number of columns are taken from the table metadata
+    const auto& tableData = kqpCtx.Tables->ExistingTable(kqpCtx.Cluster, tablePath.Value());
+    double rowCount = tableData.Metadata->RecordsCount;
+    int columnCount = tableData.Metadata->Columns.size();
+    YQL_CLOG(TRACE, CoreDq) << "Infer statistics for table: " << tablePath.Value() << ", nrows: " << rowCount << ", nattrs: " << columnCount;
+
+    // A bare table reference is recorded with zero cost
+    auto tableStats = TOptimizerStatistics(rowCount, columnCount, 0.0);
+    typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(tableStats));
+}
+
+/**
+ * Infer statistics for TableLookup
+ *
+ * When the lookup keys are an Iterator the lookup is treated as a full scan and inherits
+ * the row count of the table; if the table has no statistics, none are recorded.
+ * Any other lookup is estimated at a single row.
+ * We don't differentiate between a small range and a full scan at this time
+ */
+void InferStatisticsForLookupTable(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+    auto lookup = TExprBase(input).Cast<TKqlLookupTableBase>();
+
+    int columnCount = lookup.Columns().Size();
+    auto tableStats = typeCtx->GetStats(lookup.Table().Raw());
+
+    // Default: a lookup by explicit keys yields one row
+    double rowCount = 1;
+
+    if (lookup.LookupKeys().Maybe<TCoIterator>()) {
+        if (!tableStats) {
+            return;
+        }
+        rowCount = tableStats->Nrows;
+    }
+
+    auto lookupStats = TOptimizerStatistics(rowCount, columnCount, 0);
+    typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(lookupStats));
+}
+
+/**
+ * Compute statistics for RowsSourceSetting
+ * We look into the range expression to check if it's a point lookup or a full scan
+ * We currently don't try to figure out whether this is a small range vs a full scan
+ */
+void InferStatisticsForRowsSourceSettings(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+    auto settings = TExprBase(input).Cast<TKqpReadRangesSourceSettings>();
+
+    // Nothing can be derived without statistics for the underlying table
+    auto tableStats = typeCtx->GetStats(settings.Table().Raw());
+    if (!tableStats) {
+        return;
+    }
+
+    double rowCount = tableStats->Nrows;
+
+    // A key range whose both ends are key tuples is estimated at a single row
+    // We don't currently check the size of an index lookup
+    if (auto maybeRange = settings.RangesExpr().Maybe<TKqlKeyRange>()) {
+        auto keyRange = maybeRange.Cast();
+        if (keyRange.From().Maybe<TKqlKeyTuple>() && keyRange.To().Maybe<TKqlKeyTuple>()) {
+            rowCount = 1;
+        }
+    }
+
+    int columnCount = settings.Columns().Size();
+
+    // The cost of the table is carried over unchanged
+    double tableCost = tableStats->Cost.value();
+
+    auto sourceStats = TOptimizerStatistics(rowCount, columnCount, tableCost);
+    typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(sourceStats));
+}
+
+/**
+ * Compute statistics for index lookup
* Currently we just make up a number for cardinality (5) and set cost to 0
-*/
+ */
void InferStatisticsForIndexLookup(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
-
auto outputStats = TOptimizerStatistics(5, 5, 0.0);
- typeCtx->SetStats( input.Get(), std::make_shared<TOptimizerStatistics>(outputStats) );
+ typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats));
+}
+
+/**
+ * Compute statistics for map join
+ * FIX: Currently we treat all joins the same from the cost perspective, need to refine cost function
+ */
+void InferStatisticsForMapJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+    auto mapJoin = TExprBase(input).Cast<TCoMapJoinCore>();
+
+    // Left side is the join input, right side is the dictionary
+    auto lhsStats = typeCtx->GetStats(mapJoin.LeftInput().Raw());
+    auto rhsStats = typeCtx->GetStats(mapJoin.RightDict().Raw());
+
+    // Both sides must have statistics, otherwise the join is left without an estimate
+    if (!(lhsStats && rhsStats)) {
+        return;
+    }
+
+    typeCtx->SetStats(mapJoin.Raw(), std::make_shared<TOptimizerStatistics>(
+        ComputeJoinStats(*lhsStats, *rhsStats, MapJoin)));
+}
+
+/**
+ * Compute statistics for grace join
+ * FIX: Currently we treat all joins the same from the cost perspective, need to refine cost function
+ */
+void InferStatisticsForGraceJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+    auto graceJoin = TExprBase(input).Cast<TCoGraceJoinCore>();
+
+    auto lhsStats = typeCtx->GetStats(graceJoin.LeftInput().Raw());
+    auto rhsStats = typeCtx->GetStats(graceJoin.RightInput().Raw());
+
+    // Both sides must have statistics, otherwise the join is left without an estimate
+    if (!(lhsStats && rhsStats)) {
+        return;
+    }
+
+    typeCtx->SetStats(graceJoin.Raw(), std::make_shared<TOptimizerStatistics>(
+        ComputeJoinStats(*lhsStats, *rhsStats, GraceJoin)));
+}
+
+/**
+ * Infer statistics for result binding of a stage
+ *
+ * A parameter named "%kqp%tx_result_binding_<tx>_<result>" is resolved against txStats:
+ * the statistics stored at txStats[<tx>][<result>] are assigned both to the name atom and
+ * to the parameter node itself.
+ *
+ * Parameters with any other name, with indices that do not parse as numbers, or with
+ * indices that point outside of txStats are left without statistics.
+ */
+void InferStatisticsForResultBinding(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx,
+    TVector<TVector<std::shared_ptr<TOptimizerStatistics>>>& txStats) {
+
+    auto inputNode = TExprBase(input);
+    auto param = inputNode.Cast<TCoParameter>();
+
+    if (!param.Name().Maybe<TCoAtom>()) {
+        return;
+    }
+
+    auto atom = param.Name().Cast<TCoAtom>();
+
+    // Match the full prefix including the trailing '_' so that the suffix is always "<tx>_<result>"
+    TStringBuf suffix;
+    if (!atom.Value().AfterPrefix("%kqp%tx_result_binding_", suffix)) {
+        return;
+    }
+
+    TStringBuf bindingNoStr, resultNoStr;
+    suffix.Split('_', bindingNoStr, resultNoStr);
+
+    // from_chars leaves the output untouched on failure, so its result has to be checked
+    auto parseIndex = [](TStringBuf str, size_t& value) {
+        auto [ptr, ec] = std::from_chars(str.data(), str.data() + str.size(), value);
+        Y_UNUSED(ptr);
+        return ec == std::errc();
+    };
+
+    size_t bindingNo = 0;
+    size_t resultNo = 0;
+    if (!parseIndex(bindingNoStr, bindingNo) || !parseIndex(resultNoStr, resultNo)) {
+        return;
+    }
+
+    // The referenced transaction may not have been recorded (yet); never index past the end
+    if (bindingNo >= txStats.size() || resultNo >= txStats[bindingNo].size()) {
+        return;
+    }
+
+    const auto& bindingStats = txStats[bindingNo][resultNo];
+    typeCtx->SetStats(param.Name().Raw(), bindingStats);
+    typeCtx->SetStats(inputNode.Raw(), bindingStats);
+}
+
+/**
+ * Infer statistics for DqSource
+ *
+ * The statistics and the cost recorded for the Settings of the DqSource are passed up
+ * to the DqSource node itself. Nothing is recorded when the Settings have no statistics.
+ */
+void InferStatisticsForDqSource(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+    auto source = TExprBase(input).Cast<TDqSource>();
+    auto settingsNode = source.Settings().Raw();
+
+    if (auto settingsStats = typeCtx->GetStats(settingsNode)) {
+        typeCtx->SetStats(input.Get(), settingsStats);
+        typeCtx->SetCost(input.Get(), typeCtx->GetCost(settingsNode));
+    }
+}
+
+/**
+ * When encountering a KqpPhysicalTx, the statistics of each of its results are collected
+ * into a vector (in result order) and appended to txStats, where they can later be
+ * accessed via binding parameters
+ */
+void AppendTxStats(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx,
+    TVector<TVector<std::shared_ptr<TOptimizerStatistics>>>& txStats) {
+
+    auto physicalTx = TExprBase(input).Cast<TKqpPhysicalTx>();
+    auto results = physicalTx.Results();
+
+    // One entry per result; an entry is whatever GetStats returns for that result node
+    TVector<std::shared_ptr<TOptimizerStatistics>> resultStats;
+    for (size_t idx = 0; idx < results.Size(); ++idx) {
+        resultStats.push_back(typeCtx->GetStats(results.Item(idx).Raw()));
+    }
+
+    txStats.push_back(resultStats);
}
/**
* DoTransform method matches operators and callables in the query DAG and
* uses pre-computed statistics and costs of the children to compute their cost.
-*/
-IGraphTransformer::TStatus TKqpStatisticsTransformer::DoTransform(TExprNode::TPtr input,
+ */
+IGraphTransformer::TStatus TKqpStatisticsTransformer::DoTransform(TExprNode::TPtr input,
TExprNode::TPtr& output, TExprContext& ctx) {
+ Y_UNUSED(ctx);
output = input;
if (!Config->HasOptEnableCostBasedOptimization()) {
return IGraphTransformer::TStatus::Ok;
}
-
+
TOptimizeExprSettings settings(nullptr);
- auto ret = OptimizeExpr(input, output, [*this](const TExprNode::TPtr& input, TExprContext& ctx) {
- Y_UNUSED(ctx);
- auto output = input;
+ TVector<TVector<std::shared_ptr<TOptimizerStatistics>>> txStats;
+
+ VisitExprLambdasLast(
+ input, [*this, &txStats](const TExprNode::TPtr& input) {
- if (TCoFlatMap::Match(input.Get())){
- InferStatisticsForFlatMap(input, TypeCtx);
+ // Generic matchers
+ if (TCoFilterBase::Match(input.Get())){
+ InferStatisticsForFilter(input, TypeCtx);
}
else if(TCoSkipNullMembers::Match(input.Get())){
InferStatisticsForSkipNullMembers(input, TypeCtx);
@@ -82,21 +276,82 @@ IGraphTransformer::TStatus TKqpStatisticsTransformer::DoTransform(TExprNode::TPt
else if(TCoAggregateMergeFinalize::Match(input.Get())){
InferStatisticsForAggregateMergeFinalize(input, TypeCtx);
}
+
+ // KQP Matchers
else if(TKqlReadTableBase::Match(input.Get()) || TKqlReadTableRangesBase::Match(input.Get())){
InferStatisticsForReadTable(input, TypeCtx, KqpCtx);
}
- else if(TKqlLookupTableBase::Match(input.Get()) || TKqlLookupIndexBase::Match(input.Get())){
+ else if(TKqlLookupTableBase::Match(input.Get())) {
+ InferStatisticsForLookupTable(input, TypeCtx);
+ }
+ else if(TKqlLookupIndexBase::Match(input.Get())){
InferStatisticsForIndexLookup(input, TypeCtx);
}
+ else if(TKqpTable::Match(input.Get())) {
+ InferStatisticsForKqpTable(input, TypeCtx, KqpCtx);
+ }
+ else if (TKqpReadRangesSourceSettings::Match(input.Get())) {
+ InferStatisticsForRowsSourceSettings(input, TypeCtx);
+ }
+
+ // Join matchers
+ else if(TCoMapJoinCore::Match(input.Get())) {
+ InferStatisticsForMapJoin(input, TypeCtx);
+ }
+ else if(TCoGraceJoinCore::Match(input.Get())) {
+ InferStatisticsForGraceJoin(input, TypeCtx);
+ }
- return output;
- }, ctx, settings);
+ // Do nothing in case of EquiJoin, otherwise the EquiJoin rule won't fire
+ else if(TCoEquiJoin::Match(input.Get())){
+ }
+
+ // In case of DqSource, propagate the statistics from the correct argument
+ else if (TDqSource::Match(input.Get())) {
+ InferStatisticsForDqSource(input, TypeCtx);
+ }
+
+ // Match a result binding atom and connect it to a stage
+ else if(TCoParameter::Match(input.Get())) {
+ InferStatisticsForResultBinding(input, TypeCtx, txStats);
+ }
+
+ // Finally, use a default rule to propagate the statistics and costs
+ else {
+
+ // default sum propagation
+ if (input->ChildrenSize() >= 1) {
+ auto stats = TypeCtx->GetStats(input->ChildRef(0).Get());
+ if (stats) {
+ TypeCtx->SetStats(input.Get(), stats);
+ }
+ }
+ }
+
+ // We have a separate rule for all callables that may use a lambda
+ // we need to take each generic callable and see if it includes a lambda
+ // if so - we will map the input to the callable to the argument of the lambda
+ if (input->IsCallable()) {
+ PropagateStatisticsToLambdaArgument(input, TypeCtx);
+ }
- return ret;
+ return true; },
+
+ [*this, &txStats](const TExprNode::TPtr& input) {
+ if (TDqStageBase::Match(input.Get())) {
+ InferStatisticsForStage(input, TypeCtx);
+ } else if (TKqpPhysicalTx::Match(input.Get())) {
+ AppendTxStats(input, TypeCtx, txStats);
+ } else if (TCoFlatMapBase::Match(input.Get())) {
+ InferStatisticsForFlatMap(input, TypeCtx);
+ }
+
+ return true;
+ });
+ return IGraphTransformer::TStatus::Ok;
}
TAutoPtr<IGraphTransformer> NKikimr::NKqp::CreateKqpStatisticsTransformer(const TIntrusivePtr<TKqpOptimizeContext>& kqpCtx,
TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config) {
-
return THolder<IGraphTransformer>(new TKqpStatisticsTransformer(kqpCtx, typeCtx, config));
}
diff --git a/ydb/library/yql/core/CMakeLists.darwin-x86_64.txt b/ydb/library/yql/core/CMakeLists.darwin-x86_64.txt
index 0e86b20c363..893a2f0cff9 100644
--- a/ydb/library/yql/core/CMakeLists.darwin-x86_64.txt
+++ b/ydb/library/yql/core/CMakeLists.darwin-x86_64.txt
@@ -85,6 +85,7 @@ target_link_libraries(library-yql-core PUBLIC
target_sources(library-yql-core PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_aggregate_expander.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_callable_transform.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_cost_function.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_csv.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_execution.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_expr_constraint.cpp
diff --git a/ydb/library/yql/core/CMakeLists.linux-aarch64.txt b/ydb/library/yql/core/CMakeLists.linux-aarch64.txt
index b0e95587bd7..f947778c5f2 100644
--- a/ydb/library/yql/core/CMakeLists.linux-aarch64.txt
+++ b/ydb/library/yql/core/CMakeLists.linux-aarch64.txt
@@ -86,6 +86,7 @@ target_link_libraries(library-yql-core PUBLIC
target_sources(library-yql-core PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_aggregate_expander.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_callable_transform.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_cost_function.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_csv.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_execution.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_expr_constraint.cpp
diff --git a/ydb/library/yql/core/CMakeLists.linux-x86_64.txt b/ydb/library/yql/core/CMakeLists.linux-x86_64.txt
index b0e95587bd7..f947778c5f2 100644
--- a/ydb/library/yql/core/CMakeLists.linux-x86_64.txt
+++ b/ydb/library/yql/core/CMakeLists.linux-x86_64.txt
@@ -86,6 +86,7 @@ target_link_libraries(library-yql-core PUBLIC
target_sources(library-yql-core PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_aggregate_expander.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_callable_transform.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_cost_function.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_csv.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_execution.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_expr_constraint.cpp
diff --git a/ydb/library/yql/core/CMakeLists.windows-x86_64.txt b/ydb/library/yql/core/CMakeLists.windows-x86_64.txt
index 0e86b20c363..893a2f0cff9 100644
--- a/ydb/library/yql/core/CMakeLists.windows-x86_64.txt
+++ b/ydb/library/yql/core/CMakeLists.windows-x86_64.txt
@@ -85,6 +85,7 @@ target_link_libraries(library-yql-core PUBLIC
target_sources(library-yql-core PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_aggregate_expander.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_callable_transform.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_cost_function.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_csv.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_execution.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_expr_constraint.cpp
diff --git a/ydb/library/yql/core/ya.make b/ydb/library/yql/core/ya.make
index 4fa07131448..088e2dae50f 100644
--- a/ydb/library/yql/core/ya.make
+++ b/ydb/library/yql/core/ya.make
@@ -5,6 +5,7 @@ SRCS(
yql_atom_enums.h
yql_callable_transform.cpp
yql_callable_transform.h
+ yql_cost_function.cpp
yql_csv.cpp
yql_csv.h
yql_data_provider.h
diff --git a/ydb/library/yql/core/yql_cost_function.cpp b/ydb/library/yql/core/yql_cost_function.cpp
new file mode 100644
index 00000000000..635b0552eb3
--- /dev/null
+++ b/ydb/library/yql/core/yql_cost_function.cpp
@@ -0,0 +1,22 @@
+#include "yql_cost_function.h"
+
+using namespace NYql;
+
+/**
+ * Compute the cost and output cardinality of a join
+ *
+ * Currently a very basic computation targeted at GraceJoin; the join implementation
+ * type is accepted but not used yet.
+ *
+ * The build is on the right side, so we make the build side a bit more expensive than the probe
+*/
+TOptimizerStatistics NYql::ComputeJoinStats(TOptimizerStatistics leftStats, TOptimizerStatistics rightStats, EJoinImplType joinImpl) {
+    Y_UNUSED(joinImpl);
+
+    // Output cardinality: a fixed 0.2 factor applied to the product of the input cardinalities
+    const double outputRows = 0.2 * leftStats.Nrows * rightStats.Nrows;
+    const int outputCols = leftStats.Ncols + rightStats.Ncols;
+
+    // Left rows are counted once, right rows twice; then the output size and the
+    // costs of both inputs are added on top
+    double totalCost = leftStats.Nrows + 2.0 * rightStats.Nrows;
+    totalCost += outputRows;
+    totalCost += leftStats.Cost.value();
+    totalCost += rightStats.Cost.value();
+
+    return TOptimizerStatistics(outputRows, outputCols, totalCost);
+}
diff --git a/ydb/library/yql/core/yql_cost_function.h b/ydb/library/yql/core/yql_cost_function.h
new file mode 100644
index 00000000000..eedd8034b96
--- /dev/null
+++ b/ydb/library/yql/core/yql_cost_function.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include "yql_statistics.h"
+
+/**
+ * The cost function for cost based optimizer currently consists of methods for computing
+ * both the cost and cardinalities of individual plan operators
+*/
+namespace NYql {
+
+/**
+ * Physical join implementation a cost estimate is requested for.
+ * Passed to ComputeJoinStats by the callers that infer statistics for join callables.
+ */
+enum EJoinImplType {
+ DictJoin,
+ MapJoin,
+ GraceJoin
+};
+
+/**
+ * Compute the statistics (cardinality, number of columns and cost) of a join
+ * from the statistics of its left and right inputs.
+ */
+TOptimizerStatistics ComputeJoinStats(TOptimizerStatistics leftStats, TOptimizerStatistics rightStats, EJoinImplType joinType);
+
+} \ No newline at end of file
diff --git a/ydb/library/yql/core/yql_expr_optimize.cpp b/ydb/library/yql/core/yql_expr_optimize.cpp
index db25fed93cb..8887211c00e 100644
--- a/ydb/library/yql/core/yql_expr_optimize.cpp
+++ b/ydb/library/yql/core/yql_expr_optimize.cpp
@@ -329,6 +329,32 @@ namespace {
}
}
+    // Visits every node reachable from `node` exactly once (tracked in visitedNodes).
+    // For each node the order is: non-lambda children, preLambdaFunc(node),
+    // lambda children, postLambdaFunc(node). Return values of the callbacks are ignored.
+    void VisitExprLambdasLastInternal(const TExprNode::TPtr& node,
+        const TExprVisitPtrFunc& preLambdaFunc,
+        const TExprVisitPtrFunc& postLambdaFunc,
+        TNodeSet& visitedNodes)
+    {
+        const bool firstVisit = visitedNodes.emplace(node.Get()).second;
+        if (!firstVisit) {
+            return;
+        }
+
+        // Recurses into the direct children that are (wantLambdas) or are not (!wantLambdas) lambdas
+        const auto descend = [&](bool wantLambdas) {
+            for (auto child : node->Children()) {
+                if (child->IsLambda() == wantLambdas) {
+                    VisitExprLambdasLastInternal(child, preLambdaFunc, postLambdaFunc, visitedNodes);
+                }
+            }
+        };
+
+        descend(false);
+        preLambdaFunc(node);
+
+        descend(true);
+        postLambdaFunc(node);
+    }
+
void VisitExprInternal(const TExprNode& node, const TExprVisitRefFunc& preFunc,
const TExprVisitRefFunc& postFunc, TNodeSet& visitedNodes)
{
@@ -863,6 +889,12 @@ void VisitExpr(const TExprNode& root, const TExprVisitRefFunc& func) {
void VisitExpr(const TExprNode::TPtr& root, const TExprVisitPtrFunc& func, TNodeSet& visitedNodes) {
VisitExprInternal(root, func, {}, visitedNodes);
}
+
+/**
+ * Visits the expression graph rooted at root so that, for every node, its
+ * non-lambda children are processed before preLambdaFunc is called on the node
+ * and its lambda children are processed before postLambdaFunc is called on it.
+ * Each node is visited at most once per call.
+ */
+void VisitExprLambdasLast(const TExprNode::TPtr& root, const TExprVisitPtrFunc& preLambdaFunc, const TExprVisitPtrFunc& postLambdaFunc)
+{
+    // Every top-level call starts with its own empty set of visited nodes
+    TNodeSet seen;
+    VisitExprLambdasLastInternal(root, preLambdaFunc, postLambdaFunc, seen);
+}
void VisitExprByFirst(const TExprNode::TPtr& root, const TExprVisitPtrFunc& func) {
TNodeSet visitedNodes;
diff --git a/ydb/library/yql/core/yql_expr_optimize.h b/ydb/library/yql/core/yql_expr_optimize.h
index 7fdc64132b5..f9463751c75 100644
--- a/ydb/library/yql/core/yql_expr_optimize.h
+++ b/ydb/library/yql/core/yql_expr_optimize.h
@@ -58,6 +58,8 @@ void VisitExpr(const TExprNode::TPtr& root, const TExprVisitPtrFunc& func);
void VisitExpr(const TExprNode::TPtr& root, const TExprVisitPtrFunc& preFunc, const TExprVisitPtrFunc& postFunc);
void VisitExpr(const TExprNode::TPtr& root, const TExprVisitPtrFunc& func, TNodeSet& visitedNodes);
void VisitExpr(const TExprNode& root, const TExprVisitRefFunc& func);
+void VisitExprLambdasLast(const TExprNode::TPtr& root, const TExprVisitPtrFunc& preLambdaFunc, const TExprVisitPtrFunc& postLambdaFunc);
+
void VisitExprByFirst(const TExprNode::TPtr& root, const TExprVisitPtrFunc& func);
void VisitExprByFirst(const TExprNode::TPtr& root, const TExprVisitPtrFunc& preFunc, const TExprVisitPtrFunc& postFunc);
diff --git a/ydb/library/yql/core/yql_statistics.cpp b/ydb/library/yql/core/yql_statistics.cpp
index c076592671a..91043b31b7c 100644
--- a/ydb/library/yql/core/yql_statistics.cpp
+++ b/ydb/library/yql/core/yql_statistics.cpp
@@ -20,9 +20,9 @@ bool TOptimizerStatistics::Empty() const {
TOptimizerStatistics& TOptimizerStatistics::operator+=(const TOptimizerStatistics& other) {
Nrows += other.Nrows;
Ncols += other.Ncols;
- if (Cost) {
+ if (Cost.has_value() && other.Cost.has_value()) {
Cost = *Cost + *other.Cost;
- } else {
+ } else if (other.Cost.has_value()) {
Cost = other.Cost;
}
return *this;
diff --git a/ydb/library/yql/core/yql_statistics.h b/ydb/library/yql/core/yql_statistics.h
index ea711e56aaa..a8fed920d69 100644
--- a/ydb/library/yql/core/yql_statistics.h
+++ b/ydb/library/yql/core/yql_statistics.h
@@ -1,5 +1,6 @@
#pragma once
+#include <util/generic/string.h>
#include <optional>
#include <iostream>
@@ -16,10 +17,13 @@ struct TOptimizerStatistics {
double Nrows = 0;
int Ncols = 0;
std::optional<double> Cost;
+ TString Descr;
TOptimizerStatistics() : Cost(std::nullopt) {}
TOptimizerStatistics(double nrows,int ncols): Nrows(nrows), Ncols(ncols), Cost(std::nullopt) {}
TOptimizerStatistics(double nrows,int ncols, double cost): Nrows(nrows), Ncols(ncols), Cost(cost) {}
+ TOptimizerStatistics(double nrows,int ncols, double cost, TString descr): Nrows(nrows), Ncols(ncols), Cost(cost), Descr(descr) {}
+
TOptimizerStatistics& operator+=(const TOptimizerStatistics& other);
bool Empty() const;
diff --git a/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp b/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
index 4eb27f97f59..3e1622cebee 100644
--- a/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
@@ -8,6 +8,7 @@
#include <ydb/library/yql/providers/common/provider/yql_provider.h>
#include <ydb/library/yql/core/yql_type_helpers.h>
#include <ydb/library/yql/core/yql_statistics.h>
+#include <ydb/library/yql/core/yql_cost_function.h>
#include <ydb/library/yql/core/cbo/cbo_optimizer.h> //interface
@@ -192,29 +193,6 @@ struct TJoinOptimizerNode : public IBaseOptimizerNode {
virtual ~TJoinOptimizerNode() {}
/**
- * Compute and set the statistics for this node.
- * Currently we have a very rough calculation of statistics
- */
- void ComputeStatistics() {
- double newCard = 0.2 * LeftArg->Stats->Nrows * RightArg->Stats->Nrows;
- int newNCols = LeftArg->Stats->Ncols + RightArg->Stats->Ncols;
- Stats = std::make_shared<TOptimizerStatistics>(newCard,newNCols);
- }
-
- /**
- * Compute the cost of the join based on statistics and costs of children
- * Again, we only have a rought calculation at this time
- */
- double ComputeCost() {
- Y_ENSURE(LeftArg->Stats->Cost.has_value() && RightArg->Stats->Cost.has_value(),
- "Missing values for costs in join computation");
-
- return 2.0 * LeftArg->Stats->Nrows + RightArg->Stats->Nrows
- + Stats->Nrows
- + LeftArg->Stats->Cost.value() + RightArg->Stats->Cost.value();
- }
-
- /**
* Print out the join tree, rooted at this node
*/
virtual void Print(std::stringstream& stream, int ntabs=0) {
@@ -246,11 +224,12 @@ struct TJoinOptimizerNode : public IBaseOptimizerNode {
* Create a new join and compute its statistics and cost
*/
std::shared_ptr<TJoinOptimizerNode> MakeJoin(std::shared_ptr<IBaseOptimizerNode> left,
- std::shared_ptr<IBaseOptimizerNode> right, const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions) {
+ std::shared_ptr<IBaseOptimizerNode> right,
+ const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions,
+ EJoinImplType joinImpl) {
auto res = std::make_shared<TJoinOptimizerNode>(left, right, joinConditions);
- res->ComputeStatistics();
- res->Stats->Cost = res->ComputeCost();
+ res->Stats = std::make_shared<TOptimizerStatistics>( ComputeJoinStats(*left->Stats, *right->Stats, joinImpl));
return res;
}
@@ -681,20 +660,20 @@ template <int N> void TDPccpSolver<N>::EmitCsgCmp(const std::bitset<N>& S1, cons
if (! DpTable.contains(joined)) {
TEdge e1 = Graph.FindCrossingEdge(S1, S2);
- DpTable[joined] = MakeJoin(DpTable[S1], DpTable[S2], e1.JoinConditions);
+ DpTable[joined] = MakeJoin(DpTable[S1], DpTable[S2], e1.JoinConditions, GraceJoin);
TEdge e2 = Graph.FindCrossingEdge(S2, S1);
std::shared_ptr<TJoinOptimizerNode> newJoin =
- MakeJoin(DpTable[S2], DpTable[S1], e2.JoinConditions);
+ MakeJoin(DpTable[S2], DpTable[S1], e2.JoinConditions, GraceJoin);
if (newJoin->Stats->Cost.value() < DpTable[joined]->Stats->Cost.value()){
DpTable[joined] = newJoin;
}
} else {
TEdge e1 = Graph.FindCrossingEdge(S1, S2);
std::shared_ptr<TJoinOptimizerNode> newJoin1 =
- MakeJoin(DpTable[S1], DpTable[S2], e1.JoinConditions);
+ MakeJoin(DpTable[S1], DpTable[S2], e1.JoinConditions, GraceJoin);
TEdge e2 = Graph.FindCrossingEdge(S2, S1);
std::shared_ptr<TJoinOptimizerNode> newJoin2 =
- MakeJoin(DpTable[S2], DpTable[S1], e2.JoinConditions);
+ MakeJoin(DpTable[S2], DpTable[S1], e2.JoinConditions, GraceJoin);
if (newJoin1->Stats->Cost.value() < DpTable[joined]->Stats->Cost.value()){
DpTable[joined] = newJoin1;
}
diff --git a/ydb/library/yql/dq/opt/dq_opt_stat.cpp b/ydb/library/yql/dq/opt/dq_opt_stat.cpp
index 6424e89478e..67c921d26cd 100644
--- a/ydb/library/yql/dq/opt/dq_opt_stat.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_stat.cpp
@@ -1,6 +1,8 @@
#include "dq_opt_stat.h"
#include <ydb/library/yql/core/yql_opt_utils.h>
+#include <ydb/library/yql/utils/log/log.h>
+
namespace NYql::NDq {
@@ -10,15 +12,13 @@ using namespace NNodes;
* For Flatmap we check the input and fetch the statistcs and cost from below
* Then we analyze the filter predicate and compute it's selectivity and apply it
* to the result.
+ *
+ * If this flatmap's lambda is a join, we propagate the join result as the output of FlatMap
*/
void InferStatisticsForFlatMap(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
auto inputNode = TExprBase(input);
- auto flatmap = inputNode.Cast<TCoFlatMap>();
- if (!IsPredicateFlatMap(flatmap.Lambda().Body().Ref())) {
- return;
- }
-
+ auto flatmap = inputNode.Cast<TCoFlatMapBase>();
auto flatmapInput = flatmap.Input();
auto inputStats = typeCtx->GetStats(flatmapInput.Raw());
@@ -26,15 +26,65 @@ void InferStatisticsForFlatMap(const TExprNode::TPtr& input, TTypeAnnotationCont
return;
}
+ if (IsPredicateFlatMap(flatmap.Lambda().Body().Ref())) {
+ // Selectivity is the fraction of tuples that are selected by this predicate
+ // Currently we just set the number to 10% before we have statistics and parse
+ // the predicate
+ double selectivity = 0.1;
+
+ auto outputStats = TOptimizerStatistics(inputStats->Nrows * selectivity, inputStats->Ncols);
+
+ typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats) );
+ typeCtx->SetCost(input.Get(), typeCtx->GetCost(flatmapInput.Raw()));
+
+ }
+ else if (flatmap.Lambda().Body().Maybe<TCoMapJoinCore>() ||
+ flatmap.Lambda().Body().Maybe<TCoMap>().Input().Maybe<TCoMapJoinCore>() ||
+ flatmap.Lambda().Body().Maybe<TCoJoinDict>() ||
+ flatmap.Lambda().Body().Maybe<TCoMap>().Input().Maybe<TCoJoinDict>()){
+
+ typeCtx->SetStats(input.Get(), typeCtx->GetStats(flatmap.Lambda().Body().Raw()));
+ }
+ else {
+ typeCtx->SetStats(input.Get(), typeCtx->GetStats(flatmapInput.Raw()));
+ }
+}
+
+/**
+ * For Filter we check the input and fetch the statistics and cost from below
+ * Then we analyze the filter predicate and compute its selectivity and apply it
+ * to the result, just like in FlatMap, except we check for a specific pattern:
+ * If the filter's lambda is an Exists callable with a Member callable, we set the
+ * selectivity to 1 to be consistent with SkipNullMembers in the logical plan
+ */
+void InferStatisticsForFilter(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+
+ auto inputNode = TExprBase(input);
+ auto filter = inputNode.Cast<TCoFilterBase>();
+ auto filterInput = filter.Input();
+ auto inputStats = typeCtx->GetStats(filterInput.Raw());
+
+ if (!inputStats){
+ return;
+ }
+
// Selectivity is the fraction of tuples that are selected by this predicate
// Currently we just set the number to 10% before we have statistics and parse
// the predicate
double selectivity = 0.1;
+ auto filterLambda = filter.Lambda();
+ if (auto exists = filterLambda.Body().Maybe<TCoExists>()) {
+ if (exists.Cast().Optional().Maybe<TCoMember>()) {
+ selectivity = 1.0;
+ }
+ }
+
auto outputStats = TOptimizerStatistics(inputStats->Nrows * selectivity, inputStats->Ncols);
typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats) );
- typeCtx->SetCost(input.Get(), typeCtx->GetCost(flatmapInput.Raw()));
+ typeCtx->SetCost(input.Get(), typeCtx->GetCost(filterInput.Raw()));
+
}
/**
@@ -114,4 +164,68 @@ void InferStatisticsForAggregateMergeFinalize(const TExprNode::TPtr& input, TTyp
typeCtx->SetCost( input.Get(), typeCtx->GetCost( aggInput.Raw() ) );
}
+/**
+ * For callables that include lambdas, we want to propagate the statistics from the callable's
+ * input (child 0) to the lambda's arguments, so that the operators inside the lambda receive
+ * the correct statistics.
+ *
+ * Every child after the first one that is a lambda with at least one argument is processed:
+ *  - if the input is a list, the statistics and cost of its j-th element are assigned to the
+ *    j-th lambda argument, for as many positions as both the list and the lambda have;
+ *  - otherwise the statistics and cost of the input are assigned to the first lambda argument.
+ *
+ * Nothing is assigned for inputs that have no statistics.
+ */
+void PropagateStatisticsToLambdaArgument(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+
+    if (input->ChildrenSize() < 2) {
+        return;
+    }
+
+    const auto& callableInput = input->ChildRef(0);
+
+    // Iterate over all children except for the input
+    // Check if the child is a lambda and propagate the statistics into it
+    for (size_t i = 1; i < input->ChildrenSize(); i++) {
+        auto maybeLambda = TExprBase(input->ChildRef(i));
+        if (!maybeLambda.Maybe<TCoLambda>()) {
+            continue;
+        }
+
+        auto lambda = maybeLambda.Cast<TCoLambda>();
+        const size_t argCount = lambda.Args().Size();
+        if (!argCount) {
+            continue;
+        }
+
+        // If the input to the callable is a list, then lambda also takes a list of arguments
+        // So we need to propagate corresponding arguments
+
+        if (callableInput->IsList()) {
+            // The lambda may declare fewer arguments than the list has elements,
+            // so the index is bounded by both sizes
+            for (size_t j = 0; j < callableInput->ChildrenSize() && j < argCount; j++) {
+                auto inputStats = typeCtx->GetStats(callableInput->Child(j));
+                if (inputStats) {
+                    typeCtx->SetStats(lambda.Args().Arg(j).Raw(), inputStats);
+                    typeCtx->SetCost(lambda.Args().Arg(j).Raw(), typeCtx->GetCost(callableInput->Child(j)));
+                }
+            }
+        }
+        else {
+            auto inputStats = typeCtx->GetStats(callableInput.Get());
+            if (!inputStats) {
+                // The same input is looked up for every lambda, so there is nothing left to propagate
+                return;
+            }
+
+            typeCtx->SetStats(lambda.Args().Arg(0).Raw(), inputStats);
+            typeCtx->SetCost(lambda.Args().Arg(0).Raw(), typeCtx->GetCost(callableInput.Get()));
+        }
+    }
+}
+
+/**
+ * Once the program lambda of a stage has been processed, the statistics and the cost
+ * of the lambda body become the statistics and the cost of the stage itself.
+ * A stage whose lambda body has no statistics is left untouched.
+ */
+void InferStatisticsForStage(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+    auto stage = TExprBase(input).Cast<TDqStageBase>();
+
+    auto bodyStats = typeCtx->GetStats(stage.Program().Body().Raw());
+    if (!bodyStats) {
+        return;
+    }
+
+    typeCtx->SetStats(stage.Raw(), bodyStats);
+    typeCtx->SetCost(stage.Raw(), typeCtx->GetCost(stage.Program().Body().Raw()));
+}
+
} // namespace NYql::NDq {
diff --git a/ydb/library/yql/dq/opt/dq_opt_stat.h b/ydb/library/yql/dq/opt/dq_opt_stat.h
index 01c71771344..4baa3f272e9 100644
--- a/ydb/library/yql/dq/opt/dq_opt_stat.h
+++ b/ydb/library/yql/dq/opt/dq_opt_stat.h
@@ -5,9 +5,13 @@
namespace NYql::NDq {
void InferStatisticsForFlatMap(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
+void InferStatisticsForFilter(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
void InferStatisticsForSkipNullMembers(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
void InferStatisticsForExtractMembers(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
void InferStatisticsForAggregateCombine(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
void InferStatisticsForAggregateMergeFinalize(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
+void PropagateStatisticsToLambdaArgument(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
+void PropagateStatisticsToStageArguments(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
+void InferStatisticsForStage(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
} // namespace NYql::NDq { \ No newline at end of file