author | pavelvelikhov <pavelvelikhov@yandex-team.com> | 2023-09-22 16:11:12 +0300 |
---|---|---|
committer | pavelvelikhov <pavelvelikhov@yandex-team.com> | 2023-09-22 17:00:42 +0300 |
commit | 4ee466d5baecacc8e10b342caae9b57f8bf477af (patch) | |
tree | 6cc3ff243adb496a14c1945afc3c969f0cc67bdf | |
parent | 77fc18884ef08f62201e024b6c7b97084a358273 (diff) | |
download | ydb-4ee466d5baecacc8e10b342caae9b57f8bf477af.tar.gz |
Compute optimizer statistics at the physical level and add them to the Explain plan
Final commit
20 files changed, 593 insertions, 70 deletions
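In short, the patch threads `TTypeAnnotationContext` (where the statistics transformer stores per-node `TOptimizerStatistics`) into the plan serializer, which then reports those estimates as `E-Rows` and `E-Cost` operator properties whenever cost-based optimization is enabled. Below is a condensed, standalone sketch of that reporting step, modeled on the `AddOptimizerEstimates` helper added in `kqp_query_plan.cpp` further down; `TStats` and `TProps` are simplified stand-ins rather than the real YDB types, and the extra `has_value()` guard is only there to keep the sketch self-contained (the patch itself reads `Cost.value()` directly).

```cpp
#include <iostream>
#include <map>
#include <memory>
#include <optional>
#include <string>

// Simplified stand-ins for TOptimizerStatistics and an operator's property map.
struct TStats {
    double Nrows = 0;
    std::optional<double> Cost;
};
using TProps = std::map<std::string, std::string>;

// Mirrors the reporting added in AddOptimizerEstimates: emit the optimizer's
// estimates when statistics exist for the expression, otherwise "No estimate".
void AddOptimizerEstimates(TProps& op, const std::shared_ptr<TStats>& stats) {
    if (stats && stats->Cost.has_value()) {
        op["E-Rows"] = std::to_string(stats->Nrows);
        op["E-Cost"] = std::to_string(*stats->Cost);
    } else {
        op["E-Rows"] = "No estimate";
        op["E-Cost"] = "No estimate";
    }
}

int main() {
    auto scanStats = std::make_shared<TStats>();
    scanStats->Nrows = 10000;   // hypothetical estimate for a table scan
    scanStats->Cost = 10000;

    TProps scanOp, lookupOp;
    AddOptimizerEstimates(scanOp, scanStats);
    AddOptimizerEstimates(lookupOp, nullptr);  // no statistics were inferred

    std::cout << "scan:   E-Rows=" << scanOp["E-Rows"] << " E-Cost=" << scanOp["E-Cost"] << "\n";
    std::cout << "lookup: E-Rows=" << lookupOp["E-Rows"] << " E-Cost=" << lookupOp["E-Cost"] << "\n";
    return 0;
}
```

Estimates are emitted only when `OptEnableCostBasedOptimization` is set in the configuration; otherwise the serializer leaves the plan output unchanged, exactly as before.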
diff --git a/ydb/core/kqp/host/kqp_explain_prepared.cpp b/ydb/core/kqp/host/kqp_explain_prepared.cpp index 695104abcd1..f72f49c3123 100644 --- a/ydb/core/kqp/host/kqp_explain_prepared.cpp +++ b/ydb/core/kqp/host/kqp_explain_prepared.cpp @@ -15,12 +15,14 @@ using namespace NThreading; class TKqpExplainPreparedTransformer : public NYql::TGraphTransformerBase { public: TKqpExplainPreparedTransformer(TIntrusivePtr<IKqpGateway> gateway, const TString& cluster, - TIntrusivePtr<TKqlTransformContext> transformCtx, const NMiniKQL::IFunctionRegistry* funcRegistry) + TIntrusivePtr<TKqlTransformContext> transformCtx, const NMiniKQL::IFunctionRegistry* funcRegistry, + TTypeAnnotationContext& typeCtx) : Gateway(gateway) , Cluster(cluster) , TransformCtx(transformCtx) , FuncRegistry(funcRegistry) , CurrentTxIndex(0) + , TypeCtx(typeCtx) { TxAlloc = TransformCtx->QueryCtx->QueryData->GetAllocState(); } @@ -57,7 +59,7 @@ public: } PhyQuerySetTxPlans(query, TKqpPhysicalQuery(input), std::move(TxResults), - ctx, Cluster, TransformCtx->Tables, TransformCtx->Config); + ctx, Cluster, TransformCtx->Tables, TransformCtx->Config, TypeCtx); query.SetQueryAst(KqpExprToPrettyString(*input, ctx)); return TStatus::Ok; @@ -112,16 +114,18 @@ private: NThreading::TFuture<IKqpGateway::TExecPhysicalResult> ExecuteFuture; NThreading::TPromise<void> Promise; TTxAllocatorState::TPtr TxAlloc; + TTypeAnnotationContext& TypeCtx; }; TAutoPtr<IGraphTransformer> CreateKqpExplainPreparedTransformer(TIntrusivePtr<IKqpGateway> gateway, const TString& cluster, TIntrusivePtr<TKqlTransformContext> transformCtx, const NMiniKQL::IFunctionRegistry* funcRegistry, - TIntrusivePtr<ITimeProvider> timeProvider, TIntrusivePtr<IRandomProvider> randomProvider) + TIntrusivePtr<ITimeProvider> timeProvider, TIntrusivePtr<IRandomProvider> randomProvider, + TTypeAnnotationContext& typeCtx) { Y_UNUSED(randomProvider); Y_UNUSED(timeProvider); - return new TKqpExplainPreparedTransformer(gateway, cluster, transformCtx, funcRegistry); + return new TKqpExplainPreparedTransformer(gateway, cluster, transformCtx, funcRegistry, typeCtx); } } // namespace NKqp diff --git a/ydb/core/kqp/host/kqp_host_impl.h b/ydb/core/kqp/host/kqp_host_impl.h index b2972b01763..d6fb8986b07 100644 --- a/ydb/core/kqp/host/kqp_host_impl.h +++ b/ydb/core/kqp/host/kqp_host_impl.h @@ -249,7 +249,8 @@ TIntrusivePtr<IKqpRunner> CreateKqpRunner(TIntrusivePtr<IKqpGateway> gateway, co TAutoPtr<NYql::IGraphTransformer> CreateKqpExplainPreparedTransformer(TIntrusivePtr<IKqpGateway> gateway, const TString& cluster, TIntrusivePtr<TKqlTransformContext> transformCtx, const NMiniKQL::IFunctionRegistry* funcRegistry, - TIntrusivePtr<ITimeProvider> timeProvider, TIntrusivePtr<IRandomProvider> randomProvider); + TIntrusivePtr<ITimeProvider> timeProvider, TIntrusivePtr<IRandomProvider> randomProvider, + NYql::TTypeAnnotationContext& typeCtx); TAutoPtr<NYql::IGraphTransformer> CreateKqpTypeAnnotationTransformer(const TString& cluster, TIntrusivePtr<NYql::TKikimrTablesData> tablesData, NYql::TTypeAnnotationContext& typesCtx, diff --git a/ydb/core/kqp/host/kqp_runner.cpp b/ydb/core/kqp/host/kqp_runner.cpp index f77295229d8..fa22700cafc 100644 --- a/ydb/core/kqp/host/kqp_runner.cpp +++ b/ydb/core/kqp/host/kqp_runner.cpp @@ -281,7 +281,7 @@ public: PreparedExplainTransformer = TTransformationPipeline(typesCtx) .Add(CreateKqpExplainPreparedTransformer( - Gateway, Cluster, TransformCtx, &funcRegistry, timeProvider, randomProvider), "ExplainQuery") + Gateway, Cluster, TransformCtx, &funcRegistry, timeProvider, 
randomProvider, *typesCtx), "ExplainQuery") .Build(false); PhysicalOptimizeTransformer = CreateKqpQueryBlocksTransformer(TTransformationPipeline(typesCtx) @@ -329,6 +329,7 @@ public: .AddTypeAnnotationTransformer(CreateKqpTypeAnnotationTransformer(Cluster, sessionCtx->TablesPtr(), *typesCtx, Config)) .AddPostTypeAnnotation() .Add(CreateKqpBuildPhysicalQueryTransformer(OptimizeCtx, BuildQueryCtx), "BuildPhysicalQuery") + .Add(CreateKqpStatisticsTransformer(OptimizeCtx, *typesCtx, Config), "Statistics") .Build(false); PhysicalPeepholeTransformer = TTransformationPipeline(typesCtx) diff --git a/ydb/core/kqp/opt/kqp_query_plan.cpp b/ydb/core/kqp/opt/kqp_query_plan.cpp index 3836800ca4b..0529b4735cc 100644 --- a/ydb/core/kqp/opt/kqp_query_plan.cpp +++ b/ydb/core/kqp/opt/kqp_query_plan.cpp @@ -66,13 +66,15 @@ struct TSerializerCtx { TSerializerCtx(TExprContext& exprCtx, const TString& cluster, const TIntrusivePtr<NYql::TKikimrTablesData> tablesData, const TKikimrConfiguration::TPtr config, ui32 txCount, - TVector<TVector<NKikimrMiniKQL::TResult>> pureTxResults) + TVector<TVector<NKikimrMiniKQL::TResult>> pureTxResults, + TTypeAnnotationContext& typeCtx) : ExprCtx(exprCtx) , Cluster(cluster) , TablesData(tablesData) , Config(config) , TxCount(txCount) , PureTxResults(std::move(pureTxResults)) + , TypeCtx(typeCtx) {} TMap<TString, TTableInfo> Tables; @@ -88,6 +90,7 @@ struct TSerializerCtx { const TKikimrConfiguration::TPtr Config; const ui32 TxCount; TVector<TVector<NKikimrMiniKQL::TResult>> PureTxResults; + TTypeAnnotationContext& TypeCtx; }; TString GetExprStr(const TExprBase& scalar, bool quoteStr = true) { @@ -724,6 +727,8 @@ private: op.Properties["Reverse"] = true; } + AddOptimizerEstimates(op, sourceSettings); + SerializerCtx.Tables[table].Reads.push_back(readInfo); if (readInfo.Type == EPlanTableReadType::Scan) { @@ -833,6 +838,8 @@ private: op.Properties["Reverse"] = true; } + AddOptimizerEstimates(op, sourceSettings); + if (readInfo.Type == EPlanTableReadType::FullScan) { op.Properties["Name"] = "TableFullScan"; AddOperator(planNode, "TableFullScan", std::move(op)); @@ -956,6 +963,17 @@ private: })).Cast<TCoJoinDict>(); operatorId = Visit(flatMap, join, planNode); node = join.Ptr(); + } else if (auto maybeJoinDict = TMaybeNode<TCoGraceJoinCore>(node)) { + operatorId = Visit(maybeJoinDict.Cast(), planNode); + } else if (TMaybeNode<TCoFlatMapBase>(node).Lambda().Body().Maybe<TCoGraceJoinCore>() || + TMaybeNode<TCoFlatMapBase>(node).Lambda().Body().Maybe<TCoMap>().Input().Maybe<TCoGraceJoinCore>()) { + auto flatMap = TMaybeNode<TCoFlatMapBase>(node).Cast(); + auto join = TExprBase(FindNode(node, [](const TExprNode::TPtr& node) { + Y_ENSURE(!TMaybeNode<TDqConnection>(node).IsValid()); + return TMaybeNode<TCoGraceJoinCore>(node).IsValid(); + })).Cast<TCoGraceJoinCore>(); + operatorId = Visit(flatMap, join, planNode); + node = join.Ptr(); } else if (auto maybeCondense1 = TMaybeNode<TCoCondense1>(node)) { operatorId = Visit(maybeCondense1.Cast(), planNode); } else if (auto maybeCondense = TMaybeNode<TCoCondense>(node)) { @@ -1130,6 +1148,9 @@ private: TOperator op; op.Properties["Name"] = name; + + AddOptimizerEstimates(op, join); + auto operatorId = AddOperator(planNode, name, std::move(op)); auto inputs = Visit(flatMap.Input().Ptr(), planNode); @@ -1142,6 +1163,9 @@ private: TOperator op; op.Properties["Name"] = name; + + AddOptimizerEstimates(op, join); + return AddOperator(planNode, name, std::move(op)); } @@ -1150,6 +1174,9 @@ private: TOperator op; op.Properties["Name"] = name; + + 
AddOptimizerEstimates(op, join); + auto operatorId = AddOperator(planNode, name, std::move(op)); auto inputs = Visit(flatMap.Input().Ptr(), planNode); @@ -1162,6 +1189,34 @@ private: TOperator op; op.Properties["Name"] = name; + + AddOptimizerEstimates(op, join); + + return AddOperator(planNode, name, std::move(op)); + } + + ui32 Visit(const TCoFlatMapBase& flatMap, const TCoGraceJoinCore& join, TQueryPlanNode& planNode) { + const auto name = TStringBuilder() << join.JoinKind().Value() << "Join (Grace)"; + + TOperator op; + op.Properties["Name"] = name; + auto operatorId = AddOperator(planNode, name, std::move(op)); + + AddOptimizerEstimates(op, join); + + auto inputs = Visit(flatMap.Input().Ptr(), planNode); + planNode.Operators[operatorId].Inputs.insert(inputs.begin(), inputs.end()); + return operatorId; + } + + ui32 Visit(const TCoGraceJoinCore& join, TQueryPlanNode& planNode) { + const auto name = TStringBuilder() << join.JoinKind().Value() << "Join (Grace)"; + + TOperator op; + op.Properties["Name"] = name; + + AddOptimizerEstimates(op, join); + return AddOperator(planNode, name, std::move(op)); } @@ -1176,6 +1231,22 @@ private: return pred; } + void AddOptimizerEstimates(TOperator& op, const TExprBase& expr) { + if (!SerializerCtx.Config->HasOptEnableCostBasedOptimization()) { + return; + } + + if (auto stats = SerializerCtx.TypeCtx.GetStats(expr.Raw())) { + op.Properties["E-Rows"] = stats->Nrows; + op.Properties["E-Cost"] = stats->Cost.value(); + } + else { + op.Properties["E-Rows"] = "No estimate"; + op.Properties["E-Cost"] = "No estimate"; + + } + } + ui32 Visit(const TCoFilterBase& filter, TQueryPlanNode& planNode) { TOperator op; op.Properties["Name"] = "Filter"; @@ -1183,6 +1254,8 @@ private: auto pred = ExtractPredicate(filter.Lambda()); op.Properties["Predicate"] = pred.Body; + AddOptimizerEstimates(op, filter); + if (filter.Limit()) { op.Properties["Limit"] = PrettyExprStr(filter.Limit().Cast()); } @@ -1205,6 +1278,8 @@ private: columns.AppendValue(col.Value()); } + AddOptimizerEstimates(op, lookup); + SerializerCtx.Tables[table].Reads.push_back(readInfo); planNode.NodeInfo["Tables"].AppendValue(op.Properties["Table"]); return AddOperator(planNode, "TablePointLookup", std::move(op)); @@ -1306,6 +1381,8 @@ private: op.Properties["SsaProgram"] = GetSsaProgramInJsonByTable(table, planNode.StageProto); } + AddOptimizerEstimates(op, read); + ui32 operatorId; if (readInfo.Type == EPlanTableReadType::FullScan) { op.Properties["Name"] = "TableFullScan"; @@ -1444,6 +1521,8 @@ private: SerializerCtx.Tables[table].Reads.push_back(readInfo); + AddOptimizerEstimates(op, read); + ui32 operatorId; if (readInfo.Type == EPlanTableReadType::Scan) { op.Properties["Name"] = "TableRangeScan"; @@ -1664,9 +1743,10 @@ TString SerializeTxPlans(const TVector<const TString>& txPlans, const TString co // TODO(sk): check params from correlated subqueries // lookup join void PhyQuerySetTxPlans(NKqpProto::TKqpPhyQuery& queryProto, const TKqpPhysicalQuery& query, TVector<TVector<NKikimrMiniKQL::TResult>> pureTxResults, TExprContext& ctx, const TString& cluster, - const TIntrusivePtr<NYql::TKikimrTablesData> tablesData, TKikimrConfiguration::TPtr config) + const TIntrusivePtr<NYql::TKikimrTablesData> tablesData, TKikimrConfiguration::TPtr config, + TTypeAnnotationContext& typeCtx) { - TSerializerCtx serializerCtx(ctx, cluster, tablesData, config, query.Transactions().Size(), std::move(pureTxResults)); + TSerializerCtx serializerCtx(ctx, cluster, tablesData, config, query.Transactions().Size(), 
std::move(pureTxResults), typeCtx); /* bindingName -> stage */ auto collectBindings = [&serializerCtx, &query] (auto id, const auto& phase) { diff --git a/ydb/core/kqp/opt/kqp_query_plan.h b/ydb/core/kqp/opt/kqp_query_plan.h index 5b08b9d4968..04ffe9516a5 100644 --- a/ydb/core/kqp/opt/kqp_query_plan.h +++ b/ydb/core/kqp/opt/kqp_query_plan.h @@ -34,7 +34,8 @@ enum class EPlanTableWriteType { */ void PhyQuerySetTxPlans(NKqpProto::TKqpPhyQuery& queryProto, const NYql::NNodes::TKqpPhysicalQuery& query, TVector<TVector<NKikimrMiniKQL::TResult>> pureTxResults, NYql::TExprContext& ctx, const TString& cluster, - const TIntrusivePtr<NYql::TKikimrTablesData> tablesData, NYql::TKikimrConfiguration::TPtr config); + const TIntrusivePtr<NYql::TKikimrTablesData> tablesData, NYql::TKikimrConfiguration::TPtr config, + NYql::TTypeAnnotationContext& typeCtx); /* * Fill stages in given txPlan with ExecutionStats fields. Each plan stage stores StageGuid which is diff --git a/ydb/core/kqp/opt/kqp_statistics_transformer.cpp b/ydb/core/kqp/opt/kqp_statistics_transformer.cpp index ab4b0a32188..3b55caf4355 100644 --- a/ydb/core/kqp/opt/kqp_statistics_transformer.cpp +++ b/ydb/core/kqp/opt/kqp_statistics_transformer.cpp @@ -1,7 +1,9 @@ #include "kqp_statistics_transformer.h" #include <ydb/library/yql/utils/log/log.h> #include <ydb/library/yql/dq/opt/dq_opt_stat.h> +#include <ydb/library/yql/core/yql_cost_function.h> +#include <charconv> using namespace NYql; using namespace NYql::NNodes; @@ -11,7 +13,7 @@ using namespace NYql::NDq; /** * Compute statistics and cost for read table * Currently we look up the number of rows and attributes in the statistics service -*/ + */ void InferStatisticsForReadTable(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx, const TKqpOptimizeContext& kqpCtx) { @@ -21,54 +23,246 @@ void InferStatisticsForReadTable(const TExprNode::TPtr& input, TTypeAnnotationCo const TExprNode* path; - if ( auto readTable = inputNode.Maybe<TKqlReadTableBase>()){ + if (auto readTable = inputNode.Maybe<TKqlReadTableBase>()) { path = readTable.Cast().Table().Path().Raw(); nAttrs = readTable.Cast().Columns().Size(); - } else if(auto readRanges = inputNode.Maybe<TKqlReadTableRangesBase>()){ + } else if (auto readRanges = inputNode.Maybe<TKqlReadTableRangesBase>()) { path = readRanges.Cast().Table().Path().Raw(); nAttrs = readRanges.Cast().Columns().Size(); } else { - Y_ENSURE(false,"Invalid node type for InferStatisticsForReadTable"); + Y_ENSURE(false, "Invalid node type for InferStatisticsForReadTable"); } const auto& tableData = kqpCtx.Tables->ExistingTable(kqpCtx.Cluster, path->Content()); nRows = tableData.Metadata->RecordsCount; + YQL_CLOG(TRACE, CoreDq) << "Infer statistics for read table, nrows:" << nRows << ", nattrs: " << nAttrs; auto outputStats = TOptimizerStatistics(nRows, nAttrs, 0.0); - typeCtx->SetStats( input.Get(), std::make_shared<TOptimizerStatistics>(outputStats) ); + typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats)); } /** - * Compute sstatistics for index lookup + * Infer statistics for KQP table + */ +void InferStatisticsForKqpTable(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx, + const TKqpOptimizeContext& kqpCtx) { + + auto inputNode = TExprBase(input); + auto readTable = inputNode.Cast<TKqpTable>(); + auto path = readTable.Path(); + + const auto& tableData = kqpCtx.Tables->ExistingTable(kqpCtx.Cluster, path.Value()); + double nRows = tableData.Metadata->RecordsCount; + int nAttrs = tableData.Metadata->Columns.size(); + 
YQL_CLOG(TRACE, CoreDq) << "Infer statistics for table: " << path.Value() << ", nrows: " << nRows << ", nattrs: " << nAttrs; + + auto outputStats = TOptimizerStatistics(nRows, nAttrs, 0.0); + typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats)); +} + +/** + * Infer statistics for TableLookup + * + * Table lookup can be done with an Iterator, in which case we treat it as a full scan + * We don't differentiate between a small range and full scan at this time + */ +void InferStatisticsForLookupTable(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { + auto inputNode = TExprBase(input); + auto lookupTable = inputNode.Cast<TKqlLookupTableBase>(); + + int nAttrs = lookupTable.Columns().Size(); + double nRows = 0; + + auto inputStats = typeCtx->GetStats(lookupTable.Table().Raw()); + + if (lookupTable.LookupKeys().Maybe<TCoIterator>()) { + if (inputStats) { + nRows = inputStats->Nrows; + } else { + return; + } + } else { + nRows = 1; + } + + auto outputStats = TOptimizerStatistics(nRows, nAttrs, 0); + typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats)); +} + +/** + * Compute statistics for RowsSourceSetting + * We look into range expression to check if its a point lookup or a full scan + * We currently don't try to figure out whether this is a small range vs full scan + */ +void InferStatisticsForRowsSourceSettings(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { + auto inputNode = TExprBase(input); + auto sourceSettings = inputNode.Cast<TKqpReadRangesSourceSettings>(); + + auto inputStats = typeCtx->GetStats(sourceSettings.Table().Raw()); + if (!inputStats) { + return; + } + + double nRows = inputStats->Nrows; + + // Check if we have a range expression, in that case just assign a single row to this read + // We don't currently check the size of an index lookup + if (sourceSettings.RangesExpr().Maybe<TKqlKeyRange>()) { + auto range = sourceSettings.RangesExpr().Cast<TKqlKeyRange>(); + auto maybeFromKey = range.From().Maybe<TKqlKeyTuple>(); + auto maybeToKey = range.To().Maybe<TKqlKeyTuple>(); + if (maybeFromKey && maybeToKey) { + nRows = 1; + } + } + + int nAttrs = sourceSettings.Columns().Size(); + double cost = inputStats->Cost.value(); + + auto outputStats = TOptimizerStatistics(nRows, nAttrs, cost); + typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats)); +} + +/** + * Compute statistics for index lookup * Currently we just make up a number for cardinality (5) and set cost to 0 -*/ + */ void InferStatisticsForIndexLookup(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { - auto outputStats = TOptimizerStatistics(5, 5, 0.0); - typeCtx->SetStats( input.Get(), std::make_shared<TOptimizerStatistics>(outputStats) ); + typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats)); +} + +/** + * Compute statistics for map join + * FIX: Currently we treat all join the same from the cost perspective, need to refine cost function + */ +void InferStatisticsForMapJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { + auto inputNode = TExprBase(input); + auto join = inputNode.Cast<TCoMapJoinCore>(); + + auto leftArg = join.LeftInput(); + auto rightArg = join.RightDict(); + + auto leftStats = typeCtx->GetStats(leftArg.Raw()); + auto rightStats = typeCtx->GetStats(rightArg.Raw()); + + if (!leftStats || !rightStats) { + return; + } + + typeCtx->SetStats(join.Raw(), std::make_shared<TOptimizerStatistics>( + ComputeJoinStats(*leftStats, 
*rightStats, MapJoin))); +} + +/** + * Compute statistics for grace join + * FIX: Currently we treat all join the same from the cost perspective, need to refine cost function + */ +void InferStatisticsForGraceJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { + auto inputNode = TExprBase(input); + auto join = inputNode.Cast<TCoGraceJoinCore>(); + + auto leftArg = join.LeftInput(); + auto rightArg = join.RightInput(); + + auto leftStats = typeCtx->GetStats(leftArg.Raw()); + auto rightStats = typeCtx->GetStats(rightArg.Raw()); + + if (!leftStats || !rightStats) { + return; + } + + typeCtx->SetStats(join.Raw(), std::make_shared<TOptimizerStatistics>( + ComputeJoinStats(*leftStats, *rightStats, GraceJoin))); +} + +/*** + * Infer statistics for result binding of a stage + */ +void InferStatisticsForResultBinding(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx, + TVector<TVector<std::shared_ptr<TOptimizerStatistics>>>& txStats) { + + auto inputNode = TExprBase(input); + auto param = inputNode.Cast<TCoParameter>(); + + if (param.Name().Maybe<TCoAtom>()) { + auto atom = param.Name().Cast<TCoAtom>(); + if (atom.Value().StartsWith("%kqp%tx_result_binding")) { + TStringBuf suffix; + atom.Value().AfterPrefix("%kqp%tx_result_binding_", suffix); + TStringBuf bindingNoStr, resultNoStr; + suffix.Split('_', bindingNoStr, resultNoStr); + + int bindingNo; + int resultNo; + std::from_chars(bindingNoStr.data(), bindingNoStr.data() + bindingNoStr.size(), bindingNo); + std::from_chars(resultNoStr.data(), resultNoStr.data() + resultNoStr.size(), resultNo); + + typeCtx->SetStats(param.Name().Raw(), txStats[bindingNo][resultNo]); + typeCtx->SetStats(inputNode.Raw(), txStats[bindingNo][resultNo]); + } + } +} + +/** + * Infer statistics for DqSource + * + * We just pass up the statistics from the Settings of the DqSource + */ +void InferStatisticsForDqSource(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { + auto inputNode = TExprBase(input); + auto dqSource = inputNode.Cast<TDqSource>(); + auto inputStats = typeCtx->GetStats(dqSource.Settings().Raw()); + if (!inputStats) { + return; + } + + typeCtx->SetStats(input.Get(), inputStats); + typeCtx->SetCost(input.Get(), typeCtx->GetCost(dqSource.Settings().Raw())); +} + +/** + * When encountering a KqpPhysicalTx, we save the results of the stage in a vector + * where it can later be accessed via binding parameters + */ +void AppendTxStats(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx, + TVector<TVector<std::shared_ptr<TOptimizerStatistics>>>& txStats) { + + auto inputNode = TExprBase(input); + auto tx = inputNode.Cast<TKqpPhysicalTx>(); + TVector<std::shared_ptr<TOptimizerStatistics>> vec; + + for (size_t i = 0; i < tx.Results().Size(); i++) { + vec.push_back(typeCtx->GetStats(tx.Results().Item(i).Raw())); + } + + txStats.push_back(vec); } /** * DoTransform method matches operators and callables in the query DAG and * uses pre-computed statistics and costs of the children to compute their cost. 
-*/ -IGraphTransformer::TStatus TKqpStatisticsTransformer::DoTransform(TExprNode::TPtr input, + */ +IGraphTransformer::TStatus TKqpStatisticsTransformer::DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) { + Y_UNUSED(ctx); output = input; if (!Config->HasOptEnableCostBasedOptimization()) { return IGraphTransformer::TStatus::Ok; } - + TOptimizeExprSettings settings(nullptr); - auto ret = OptimizeExpr(input, output, [*this](const TExprNode::TPtr& input, TExprContext& ctx) { - Y_UNUSED(ctx); - auto output = input; + TVector<TVector<std::shared_ptr<TOptimizerStatistics>>> txStats; + + VisitExprLambdasLast( + input, [*this, &txStats](const TExprNode::TPtr& input) { - if (TCoFlatMap::Match(input.Get())){ - InferStatisticsForFlatMap(input, TypeCtx); + // Generic matchers + if (TCoFilterBase::Match(input.Get())){ + InferStatisticsForFilter(input, TypeCtx); } else if(TCoSkipNullMembers::Match(input.Get())){ InferStatisticsForSkipNullMembers(input, TypeCtx); @@ -82,21 +276,82 @@ IGraphTransformer::TStatus TKqpStatisticsTransformer::DoTransform(TExprNode::TPt else if(TCoAggregateMergeFinalize::Match(input.Get())){ InferStatisticsForAggregateMergeFinalize(input, TypeCtx); } + + // KQP Matchers else if(TKqlReadTableBase::Match(input.Get()) || TKqlReadTableRangesBase::Match(input.Get())){ InferStatisticsForReadTable(input, TypeCtx, KqpCtx); } - else if(TKqlLookupTableBase::Match(input.Get()) || TKqlLookupIndexBase::Match(input.Get())){ + else if(TKqlLookupTableBase::Match(input.Get())) { + InferStatisticsForLookupTable(input, TypeCtx); + } + else if(TKqlLookupIndexBase::Match(input.Get())){ InferStatisticsForIndexLookup(input, TypeCtx); } + else if(TKqpTable::Match(input.Get())) { + InferStatisticsForKqpTable(input, TypeCtx, KqpCtx); + } + else if (TKqpReadRangesSourceSettings::Match(input.Get())) { + InferStatisticsForRowsSourceSettings(input, TypeCtx); + } + + // Join matchers + else if(TCoMapJoinCore::Match(input.Get())) { + InferStatisticsForMapJoin(input, TypeCtx); + } + else if(TCoGraceJoinCore::Match(input.Get())) { + InferStatisticsForGraceJoin(input, TypeCtx); + } - return output; - }, ctx, settings); + // Do nothing in case of EquiJoin, otherwise the EquiJoin rule won't fire + else if(TCoEquiJoin::Match(input.Get())){ + } + + // In case of DqSource, propagate the statistics from the correct argument + else if (TDqSource::Match(input.Get())) { + InferStatisticsForDqSource(input, TypeCtx); + } + + // Match a result binding atom and connect it to a stage + else if(TCoParameter::Match(input.Get())) { + InferStatisticsForResultBinding(input, TypeCtx, txStats); + } + + // Finally, use a default rule to propagate the statistics and costs + else { + + // default sum propagation + if (input->ChildrenSize() >= 1) { + auto stats = TypeCtx->GetStats(input->ChildRef(0).Get()); + if (stats) { + TypeCtx->SetStats(input.Get(), stats); + } + } + } + + // We have a separate rule for all callables that may use a lambda + // we need to take each generic callable and see if it includes a lambda + // if so - we will map the input to the callable to the argument of the lambda + if (input->IsCallable()) { + PropagateStatisticsToLambdaArgument(input, TypeCtx); + } - return ret; + return true; }, + + [*this, &txStats](const TExprNode::TPtr& input) { + if (TDqStageBase::Match(input.Get())) { + InferStatisticsForStage(input, TypeCtx); + } else if (TKqpPhysicalTx::Match(input.Get())) { + AppendTxStats(input, TypeCtx, txStats); + } else if (TCoFlatMapBase::Match(input.Get())) { + 
InferStatisticsForFlatMap(input, TypeCtx); + } + + return true; + }); + return IGraphTransformer::TStatus::Ok; } TAutoPtr<IGraphTransformer> NKikimr::NKqp::CreateKqpStatisticsTransformer(const TIntrusivePtr<TKqpOptimizeContext>& kqpCtx, TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config) { - return THolder<IGraphTransformer>(new TKqpStatisticsTransformer(kqpCtx, typeCtx, config)); } diff --git a/ydb/library/yql/core/CMakeLists.darwin-x86_64.txt b/ydb/library/yql/core/CMakeLists.darwin-x86_64.txt index 0e86b20c363..893a2f0cff9 100644 --- a/ydb/library/yql/core/CMakeLists.darwin-x86_64.txt +++ b/ydb/library/yql/core/CMakeLists.darwin-x86_64.txt @@ -85,6 +85,7 @@ target_link_libraries(library-yql-core PUBLIC target_sources(library-yql-core PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_aggregate_expander.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_callable_transform.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_cost_function.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_csv.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_execution.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_expr_constraint.cpp diff --git a/ydb/library/yql/core/CMakeLists.linux-aarch64.txt b/ydb/library/yql/core/CMakeLists.linux-aarch64.txt index b0e95587bd7..f947778c5f2 100644 --- a/ydb/library/yql/core/CMakeLists.linux-aarch64.txt +++ b/ydb/library/yql/core/CMakeLists.linux-aarch64.txt @@ -86,6 +86,7 @@ target_link_libraries(library-yql-core PUBLIC target_sources(library-yql-core PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_aggregate_expander.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_callable_transform.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_cost_function.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_csv.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_execution.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_expr_constraint.cpp diff --git a/ydb/library/yql/core/CMakeLists.linux-x86_64.txt b/ydb/library/yql/core/CMakeLists.linux-x86_64.txt index b0e95587bd7..f947778c5f2 100644 --- a/ydb/library/yql/core/CMakeLists.linux-x86_64.txt +++ b/ydb/library/yql/core/CMakeLists.linux-x86_64.txt @@ -86,6 +86,7 @@ target_link_libraries(library-yql-core PUBLIC target_sources(library-yql-core PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_aggregate_expander.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_callable_transform.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_cost_function.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_csv.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_execution.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_expr_constraint.cpp diff --git a/ydb/library/yql/core/CMakeLists.windows-x86_64.txt b/ydb/library/yql/core/CMakeLists.windows-x86_64.txt index 0e86b20c363..893a2f0cff9 100644 --- a/ydb/library/yql/core/CMakeLists.windows-x86_64.txt +++ b/ydb/library/yql/core/CMakeLists.windows-x86_64.txt @@ -85,6 +85,7 @@ target_link_libraries(library-yql-core PUBLIC target_sources(library-yql-core PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_aggregate_expander.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_callable_transform.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_cost_function.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_csv.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_execution.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/core/yql_expr_constraint.cpp diff --git a/ydb/library/yql/core/ya.make b/ydb/library/yql/core/ya.make index 4fa07131448..088e2dae50f 100644 --- a/ydb/library/yql/core/ya.make +++ 
b/ydb/library/yql/core/ya.make @@ -5,6 +5,7 @@ SRCS( yql_atom_enums.h yql_callable_transform.cpp yql_callable_transform.h + yql_cost_function.cpp yql_csv.cpp yql_csv.h yql_data_provider.h diff --git a/ydb/library/yql/core/yql_cost_function.cpp b/ydb/library/yql/core/yql_cost_function.cpp new file mode 100644 index 00000000000..635b0552eb3 --- /dev/null +++ b/ydb/library/yql/core/yql_cost_function.cpp @@ -0,0 +1,22 @@ +#include "yql_cost_function.h" + +using namespace NYql; + +/** + * Compute the cost and output cardinality of a join + * + * Currently a very basic computation targeted at GraceJoin + * + * The build is on the right side, so we make the build side a bit more expensive than the probe +*/ +TOptimizerStatistics NYql::ComputeJoinStats(TOptimizerStatistics leftStats, TOptimizerStatistics rightStats, EJoinImplType joinImpl) { + Y_UNUSED(joinImpl); + + double newCard = 0.2 * leftStats.Nrows * rightStats.Nrows; + int newNCols = leftStats.Ncols + rightStats.Ncols; + double cost = leftStats.Nrows + 2.0 * rightStats.Nrows + + newCard + + leftStats.Cost.value() + rightStats.Cost.value(); + + return TOptimizerStatistics(newCard, newNCols, cost); +} diff --git a/ydb/library/yql/core/yql_cost_function.h b/ydb/library/yql/core/yql_cost_function.h new file mode 100644 index 00000000000..eedd8034b96 --- /dev/null +++ b/ydb/library/yql/core/yql_cost_function.h @@ -0,0 +1,19 @@ +#pragma once + +#include "yql_statistics.h" + +/** + * The cost function for cost based optimizer currently consists of methods for computing + * both the cost and cardinalities of individual plan operators +*/ +namespace NYql { + +enum EJoinImplType { + DictJoin, + MapJoin, + GraceJoin +}; + +TOptimizerStatistics ComputeJoinStats(TOptimizerStatistics leftStats, TOptimizerStatistics rightStats, EJoinImplType joinType); + +}
\ No newline at end of file diff --git a/ydb/library/yql/core/yql_expr_optimize.cpp b/ydb/library/yql/core/yql_expr_optimize.cpp index db25fed93cb..8887211c00e 100644 --- a/ydb/library/yql/core/yql_expr_optimize.cpp +++ b/ydb/library/yql/core/yql_expr_optimize.cpp @@ -329,6 +329,32 @@ namespace { } } + void VisitExprLambdasLastInternal(const TExprNode::TPtr& node, + const TExprVisitPtrFunc& preLambdaFunc, + const TExprVisitPtrFunc& postLambdaFunc, + TNodeSet& visitedNodes) + { + if (!visitedNodes.emplace(node.Get()).second) { + return; + } + + for (auto child : node->Children()) { + if (!child->IsLambda()) { + VisitExprLambdasLastInternal(child, preLambdaFunc, postLambdaFunc, visitedNodes); + } + } + + preLambdaFunc(node); + + for (auto child : node->Children()) { + if (child->IsLambda()) { + VisitExprLambdasLastInternal(child, preLambdaFunc, postLambdaFunc, visitedNodes); + } + } + + postLambdaFunc(node); + } + void VisitExprInternal(const TExprNode& node, const TExprVisitRefFunc& preFunc, const TExprVisitRefFunc& postFunc, TNodeSet& visitedNodes) { @@ -863,6 +889,12 @@ void VisitExpr(const TExprNode& root, const TExprVisitRefFunc& func) { void VisitExpr(const TExprNode::TPtr& root, const TExprVisitPtrFunc& func, TNodeSet& visitedNodes) { VisitExprInternal(root, func, {}, visitedNodes); } + +void VisitExprLambdasLast(const TExprNode::TPtr& root, const TExprVisitPtrFunc& preLambdaFunc, const TExprVisitPtrFunc& postLambdaFunc) +{ + TNodeSet visitedNodes; + VisitExprLambdasLastInternal(root, preLambdaFunc, postLambdaFunc, visitedNodes); +} void VisitExprByFirst(const TExprNode::TPtr& root, const TExprVisitPtrFunc& func) { TNodeSet visitedNodes; diff --git a/ydb/library/yql/core/yql_expr_optimize.h b/ydb/library/yql/core/yql_expr_optimize.h index 7fdc64132b5..f9463751c75 100644 --- a/ydb/library/yql/core/yql_expr_optimize.h +++ b/ydb/library/yql/core/yql_expr_optimize.h @@ -58,6 +58,8 @@ void VisitExpr(const TExprNode::TPtr& root, const TExprVisitPtrFunc& func); void VisitExpr(const TExprNode::TPtr& root, const TExprVisitPtrFunc& preFunc, const TExprVisitPtrFunc& postFunc); void VisitExpr(const TExprNode::TPtr& root, const TExprVisitPtrFunc& func, TNodeSet& visitedNodes); void VisitExpr(const TExprNode& root, const TExprVisitRefFunc& func); +void VisitExprLambdasLast(const TExprNode::TPtr& root, const TExprVisitPtrFunc& preLambdaFunc, const TExprVisitPtrFunc& postLambdaFunc); + void VisitExprByFirst(const TExprNode::TPtr& root, const TExprVisitPtrFunc& func); void VisitExprByFirst(const TExprNode::TPtr& root, const TExprVisitPtrFunc& preFunc, const TExprVisitPtrFunc& postFunc); diff --git a/ydb/library/yql/core/yql_statistics.cpp b/ydb/library/yql/core/yql_statistics.cpp index c076592671a..91043b31b7c 100644 --- a/ydb/library/yql/core/yql_statistics.cpp +++ b/ydb/library/yql/core/yql_statistics.cpp @@ -20,9 +20,9 @@ bool TOptimizerStatistics::Empty() const { TOptimizerStatistics& TOptimizerStatistics::operator+=(const TOptimizerStatistics& other) { Nrows += other.Nrows; Ncols += other.Ncols; - if (Cost) { + if (Cost.has_value() && other.Cost.has_value()) { Cost = *Cost + *other.Cost; - } else { + } else if (other.Cost.has_value()) { Cost = other.Cost; } return *this; diff --git a/ydb/library/yql/core/yql_statistics.h b/ydb/library/yql/core/yql_statistics.h index ea711e56aaa..a8fed920d69 100644 --- a/ydb/library/yql/core/yql_statistics.h +++ b/ydb/library/yql/core/yql_statistics.h @@ -1,5 +1,6 @@ #pragma once +#include <util/generic/string.h> #include <optional> #include <iostream> @@ -16,10 
+17,13 @@ struct TOptimizerStatistics { double Nrows = 0; int Ncols = 0; std::optional<double> Cost; + TString Descr; TOptimizerStatistics() : Cost(std::nullopt) {} TOptimizerStatistics(double nrows,int ncols): Nrows(nrows), Ncols(ncols), Cost(std::nullopt) {} TOptimizerStatistics(double nrows,int ncols, double cost): Nrows(nrows), Ncols(ncols), Cost(cost) {} + TOptimizerStatistics(double nrows,int ncols, double cost, TString descr): Nrows(nrows), Ncols(ncols), Cost(cost), Descr(descr) {} + TOptimizerStatistics& operator+=(const TOptimizerStatistics& other); bool Empty() const; diff --git a/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp b/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp index 4eb27f97f59..3e1622cebee 100644 --- a/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp @@ -8,6 +8,7 @@ #include <ydb/library/yql/providers/common/provider/yql_provider.h> #include <ydb/library/yql/core/yql_type_helpers.h> #include <ydb/library/yql/core/yql_statistics.h> +#include <ydb/library/yql/core/yql_cost_function.h> #include <ydb/library/yql/core/cbo/cbo_optimizer.h> //interface @@ -192,29 +193,6 @@ struct TJoinOptimizerNode : public IBaseOptimizerNode { virtual ~TJoinOptimizerNode() {} /** - * Compute and set the statistics for this node. - * Currently we have a very rough calculation of statistics - */ - void ComputeStatistics() { - double newCard = 0.2 * LeftArg->Stats->Nrows * RightArg->Stats->Nrows; - int newNCols = LeftArg->Stats->Ncols + RightArg->Stats->Ncols; - Stats = std::make_shared<TOptimizerStatistics>(newCard,newNCols); - } - - /** - * Compute the cost of the join based on statistics and costs of children - * Again, we only have a rought calculation at this time - */ - double ComputeCost() { - Y_ENSURE(LeftArg->Stats->Cost.has_value() && RightArg->Stats->Cost.has_value(), - "Missing values for costs in join computation"); - - return 2.0 * LeftArg->Stats->Nrows + RightArg->Stats->Nrows - + Stats->Nrows - + LeftArg->Stats->Cost.value() + RightArg->Stats->Cost.value(); - } - - /** * Print out the join tree, rooted at this node */ virtual void Print(std::stringstream& stream, int ntabs=0) { @@ -246,11 +224,12 @@ struct TJoinOptimizerNode : public IBaseOptimizerNode { * Create a new join and compute its statistics and cost */ std::shared_ptr<TJoinOptimizerNode> MakeJoin(std::shared_ptr<IBaseOptimizerNode> left, - std::shared_ptr<IBaseOptimizerNode> right, const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions) { + std::shared_ptr<IBaseOptimizerNode> right, + const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions, + EJoinImplType joinImpl) { auto res = std::make_shared<TJoinOptimizerNode>(left, right, joinConditions); - res->ComputeStatistics(); - res->Stats->Cost = res->ComputeCost(); + res->Stats = std::make_shared<TOptimizerStatistics>( ComputeJoinStats(*left->Stats, *right->Stats, joinImpl)); return res; } @@ -681,20 +660,20 @@ template <int N> void TDPccpSolver<N>::EmitCsgCmp(const std::bitset<N>& S1, cons if (! 
DpTable.contains(joined)) { TEdge e1 = Graph.FindCrossingEdge(S1, S2); - DpTable[joined] = MakeJoin(DpTable[S1], DpTable[S2], e1.JoinConditions); + DpTable[joined] = MakeJoin(DpTable[S1], DpTable[S2], e1.JoinConditions, GraceJoin); TEdge e2 = Graph.FindCrossingEdge(S2, S1); std::shared_ptr<TJoinOptimizerNode> newJoin = - MakeJoin(DpTable[S2], DpTable[S1], e2.JoinConditions); + MakeJoin(DpTable[S2], DpTable[S1], e2.JoinConditions, GraceJoin); if (newJoin->Stats->Cost.value() < DpTable[joined]->Stats->Cost.value()){ DpTable[joined] = newJoin; } } else { TEdge e1 = Graph.FindCrossingEdge(S1, S2); std::shared_ptr<TJoinOptimizerNode> newJoin1 = - MakeJoin(DpTable[S1], DpTable[S2], e1.JoinConditions); + MakeJoin(DpTable[S1], DpTable[S2], e1.JoinConditions, GraceJoin); TEdge e2 = Graph.FindCrossingEdge(S2, S1); std::shared_ptr<TJoinOptimizerNode> newJoin2 = - MakeJoin(DpTable[S2], DpTable[S1], e2.JoinConditions); + MakeJoin(DpTable[S2], DpTable[S1], e2.JoinConditions, GraceJoin); if (newJoin1->Stats->Cost.value() < DpTable[joined]->Stats->Cost.value()){ DpTable[joined] = newJoin1; } diff --git a/ydb/library/yql/dq/opt/dq_opt_stat.cpp b/ydb/library/yql/dq/opt/dq_opt_stat.cpp index 6424e89478e..67c921d26cd 100644 --- a/ydb/library/yql/dq/opt/dq_opt_stat.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_stat.cpp @@ -1,6 +1,8 @@ #include "dq_opt_stat.h" #include <ydb/library/yql/core/yql_opt_utils.h> +#include <ydb/library/yql/utils/log/log.h> + namespace NYql::NDq { @@ -10,15 +12,13 @@ using namespace NNodes; * For Flatmap we check the input and fetch the statistcs and cost from below * Then we analyze the filter predicate and compute it's selectivity and apply it * to the result. + * + * If this flatmap's lambda is a join, we propagate the join result as the output of FlatMap */ void InferStatisticsForFlatMap(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { auto inputNode = TExprBase(input); - auto flatmap = inputNode.Cast<TCoFlatMap>(); - if (!IsPredicateFlatMap(flatmap.Lambda().Body().Ref())) { - return; - } - + auto flatmap = inputNode.Cast<TCoFlatMapBase>(); auto flatmapInput = flatmap.Input(); auto inputStats = typeCtx->GetStats(flatmapInput.Raw()); @@ -26,15 +26,65 @@ void InferStatisticsForFlatMap(const TExprNode::TPtr& input, TTypeAnnotationCont return; } + if (IsPredicateFlatMap(flatmap.Lambda().Body().Ref())) { + // Selectivity is the fraction of tuples that are selected by this predicate + // Currently we just set the number to 10% before we have statistics and parse + // the predicate + double selectivity = 0.1; + + auto outputStats = TOptimizerStatistics(inputStats->Nrows * selectivity, inputStats->Ncols); + + typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats) ); + typeCtx->SetCost(input.Get(), typeCtx->GetCost(flatmapInput.Raw())); + + } + else if (flatmap.Lambda().Body().Maybe<TCoMapJoinCore>() || + flatmap.Lambda().Body().Maybe<TCoMap>().Input().Maybe<TCoMapJoinCore>() || + flatmap.Lambda().Body().Maybe<TCoJoinDict>() || + flatmap.Lambda().Body().Maybe<TCoMap>().Input().Maybe<TCoJoinDict>()){ + + typeCtx->SetStats(input.Get(), typeCtx->GetStats(flatmap.Lambda().Body().Raw())); + } + else { + typeCtx->SetStats(input.Get(), typeCtx->GetStats(flatmapInput.Raw())); + } +} + +/** + * For Filter we check the input and fetch the statistcs and cost from below + * Then we analyze the filter predicate and compute it's selectivity and apply it + * to the result, just like in FlatMap, except we check for a specific pattern: + * If the filter's lambda is an 
Exists callable with a Member callable, we set the + * selectivity to 1 to be consistent with SkipNullMembers in the logical plan + */ +void InferStatisticsForFilter(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { + + auto inputNode = TExprBase(input); + auto filter = inputNode.Cast<TCoFilterBase>(); + auto filterInput = filter.Input(); + auto inputStats = typeCtx->GetStats(filterInput.Raw()); + + if (!inputStats){ + return; + } + // Selectivity is the fraction of tuples that are selected by this predicate // Currently we just set the number to 10% before we have statistics and parse // the predicate double selectivity = 0.1; + auto filterLambda = filter.Lambda(); + if (auto exists = filterLambda.Body().Maybe<TCoExists>()) { + if (exists.Cast().Optional().Maybe<TCoMember>()) { + selectivity = 1.0; + } + } + auto outputStats = TOptimizerStatistics(inputStats->Nrows * selectivity, inputStats->Ncols); typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats) ); - typeCtx->SetCost(input.Get(), typeCtx->GetCost(flatmapInput.Raw())); + typeCtx->SetCost(input.Get(), typeCtx->GetCost(filterInput.Raw())); + } /** @@ -114,4 +164,68 @@ void InferStatisticsForAggregateMergeFinalize(const TExprNode::TPtr& input, TTyp typeCtx->SetCost( input.Get(), typeCtx->GetCost( aggInput.Raw() ) ); } +/*** + * For callables that include lambdas, we want to propagate the statistics from lambda's input to its argument, so + * that the operators inside lambda receive the correct statistics +*/ +void PropagateStatisticsToLambdaArgument(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { + + if (input->ChildrenSize()<2) { + return; + } + + auto callableInput = input->ChildRef(0); + + // Iterate over all children except for the input + // Check if the child is a lambda and propagate the statistics into it + for (size_t i=1; i<input->ChildrenSize(); i++) { + auto maybeLambda = TExprBase(input->ChildRef(i)); + if (!maybeLambda.Maybe<TCoLambda>()) { + continue; + } + + auto lambda = maybeLambda.Cast<TCoLambda>(); + if (!lambda.Args().Size()){ + continue; + } + + // If the input to the callable is a list, then lambda also takes a list of arguments + // So we need to propagate corresponding arguments + + if (callableInput->IsList()){ + for(size_t j=0; j<callableInput->ChildrenSize(); j++){ + auto inputStats = typeCtx->GetStats(callableInput->Child(j) ); + if (inputStats){ + typeCtx->SetStats( lambda.Args().Arg(j).Raw(), inputStats ); + typeCtx->SetCost( lambda.Args().Arg(j).Raw(), typeCtx->GetCost( callableInput->Child(j) )); + } + } + + } + else { + auto inputStats = typeCtx->GetStats(callableInput.Get()); + if (!inputStats) { + return; + } + + typeCtx->SetStats( lambda.Args().Arg(0).Raw(), inputStats ); + typeCtx->SetCost( lambda.Args().Arg(0).Raw(), typeCtx->GetCost( callableInput.Get() )); + } + } +} + +/** + * After processing the lambda for the stage we set the stage output to the result of the lambda +*/ +void InferStatisticsForStage(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { + auto inputNode = TExprBase(input); + auto stage = inputNode.Cast<TDqStageBase>(); + + auto lambdaStats = typeCtx->GetStats( stage.Program().Body().Raw()); + if (lambdaStats){ + typeCtx->SetStats( stage.Raw(), lambdaStats ); + typeCtx->SetCost( stage.Raw(), typeCtx->GetCost( stage.Program().Body().Raw())); + } +} + } // namespace NYql::NDq { diff --git a/ydb/library/yql/dq/opt/dq_opt_stat.h b/ydb/library/yql/dq/opt/dq_opt_stat.h index 01c71771344..4baa3f272e9 100644 --- 
a/ydb/library/yql/dq/opt/dq_opt_stat.h +++ b/ydb/library/yql/dq/opt/dq_opt_stat.h @@ -5,9 +5,13 @@ namespace NYql::NDq { void InferStatisticsForFlatMap(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); +void InferStatisticsForFilter(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); void InferStatisticsForSkipNullMembers(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); void InferStatisticsForExtractMembers(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); void InferStatisticsForAggregateCombine(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); void InferStatisticsForAggregateMergeFinalize(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); +void PropagateStatisticsToLambdaArgument(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); +void PropagateStatisticsToStageArguments(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); +void InferStatisticsForStage(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); } // namespace NYql::NDq {
\ No newline at end of file
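For reference, a worked example of the new join cost formula introduced in `yql_cost_function.cpp` above, which is now shared by the DPccp join enumerator (`MakeJoin`) and by the statistics transformer's `MapJoinCore`/`GraceJoinCore` rules. The struct below is a minimal stand-in for `TOptimizerStatistics`, and the row counts and costs are invented purely for illustration.

```cpp
#include <iostream>

// Minimal stand-in for TOptimizerStatistics (Nrows, Ncols, Cost).
struct Stats {
    double Nrows;
    int Ncols;
    double Cost;
};

// Same arithmetic as NYql::ComputeJoinStats in the patch: the right (build)
// side is weighted 2x, and the estimated output cardinality is 20% of the
// cross product of the inputs.
Stats ComputeJoinStats(const Stats& left, const Stats& right) {
    double newCard = 0.2 * left.Nrows * right.Nrows;
    int newNCols = left.Ncols + right.Ncols;
    double cost = left.Nrows + 2.0 * right.Nrows + newCard + left.Cost + right.Cost;
    return Stats{newCard, newNCols, cost};
}

int main() {
    Stats probe{1000, 3, 1000};  // hypothetical left input: 1000 rows, cost 1000
    Stats build{100, 2, 100};    // hypothetical right input: 100 rows, cost 100
    Stats joined = ComputeJoinStats(probe, build);
    // E-Rows = 0.2 * 1000 * 100                  = 20000
    // E-Cost = 1000 + 2*100 + 20000 + 1000 + 100 = 22300
    std::cout << "rows=" << joined.Nrows << " cols=" << joined.Ncols
              << " cost=" << joined.Cost << "\n";
    return 0;
}
```

With these inputs the join is estimated at 20,000 output rows and a cost of 22,300; because the build (right) side carries twice the weight, the enumerator in `EmitCsgCmp`, which costs both join orders, will prefer the plan that puts the smaller input on the right, all else being equal.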