diff options
author | pavelvelikhov <pavelvelikhov@yandex-team.com> | 2023-08-30 18:04:02 +0300 |
---|---|---|
committer | pavelvelikhov <pavelvelikhov@yandex-team.com> | 2023-08-30 18:45:04 +0300 |
commit | 6420e048afe32b6590f6b52f9fc6ffdcfab642ba (patch) | |
tree | c9e4e1988e574361e388c9ccc0c36812ac8ed925 | |
parent | 15b722615d5789b04575b0b279cf6f391199da68 (diff) | |
download | ydb-6420e048afe32b6590f6b52f9fc6ffdcfab642ba.tar.gz |
Integrated with statistics, bug fixes
formatting change
Stashed conflicts resolved
resolved conflicts
-rw-r--r-- | ydb/core/kqp/compile_service/kqp_compile_actor.cpp | 2 | ||||
-rw-r--r-- | ydb/core/kqp/gateway/kqp_metadata_loader.cpp | 55 | ||||
-rw-r--r-- | ydb/core/kqp/gateway/kqp_metadata_loader.h | 9 | ||||
-rw-r--r-- | ydb/core/kqp/host/kqp_runner.cpp | 2 | ||||
-rw-r--r-- | ydb/core/kqp/opt/kqp_statistics_transformer.cpp | 51 | ||||
-rw-r--r-- | ydb/core/kqp/opt/kqp_statistics_transformer.h | 17 | ||||
-rw-r--r-- | ydb/core/kqp/provider/yql_kikimr_gateway_ut.cpp | 2 | ||||
-rw-r--r-- | ydb/core/kqp/session_actor/kqp_worker_actor.cpp | 3 | ||||
-rw-r--r-- | ydb/core/kqp/ut/indexes/kqp_indexes_ut.cpp | 2 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_statistics.cpp | 4 | ||||
-rw-r--r-- | ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp | 484 | ||||
-rw-r--r-- | ydb/library/yql/dq/opt/dq_opt_stat.cpp | 57 | ||||
-rw-r--r-- | ydb/library/yql/dq/opt/dq_opt_stat.h | 3 |
13 files changed, 393 insertions, 298 deletions
diff --git a/ydb/core/kqp/compile_service/kqp_compile_actor.cpp b/ydb/core/kqp/compile_service/kqp_compile_actor.cpp index 05d0f8343d6..b2121dc5141 100644 --- a/ydb/core/kqp/compile_service/kqp_compile_actor.cpp +++ b/ydb/core/kqp/compile_service/kqp_compile_actor.cpp @@ -117,7 +117,7 @@ public: counters->TxProxyMon = new NTxProxy::TTxProxyMon(AppData(ctx)->Counters); std::shared_ptr<NYql::IKikimrGateway::IKqpTableMetadataLoader> loader = std::make_shared<TKqpTableMetadataLoader>( - TlsActivationContext->ActorSystem(), true, TempTablesState); + TlsActivationContext->ActorSystem(), Config, true, TempTablesState); Gateway = CreateKikimrIcGateway(QueryId.Cluster, QueryId.Database, std::move(loader), ctx.ExecutorThread.ActorSystem, ctx.SelfID.NodeId(), counters); Gateway->SetToken(QueryId.Cluster, UserToken); diff --git a/ydb/core/kqp/gateway/kqp_metadata_loader.cpp b/ydb/core/kqp/gateway/kqp_metadata_loader.cpp index aeb30a29ff0..c24d594a30d 100644 --- a/ydb/core/kqp/gateway/kqp_metadata_loader.cpp +++ b/ydb/core/kqp/gateway/kqp_metadata_loader.cpp @@ -2,6 +2,8 @@ #include "actors/kqp_ic_gateway_actors.h" #include <ydb/core/base/path.h> +#include <ydb/core/statistics/events.h> +#include <ydb/core/statistics/stat_service.h> #include <library/cpp/actors/core/hfunc.h> #include <library/cpp/actors/core/log.h> @@ -688,7 +690,7 @@ NThreading::TFuture<TTableMetadataResult> TKqpTableMetadataLoader::LoadTableMeta const auto schemeCacheId = MakeSchemeCacheID(); - return SendActorRequest<TRequest, TResponse, TResult>( + auto future = SendActorRequest<TRequest, TResponse, TResult>( ActorSystem, schemeCacheId, ev.Release(), @@ -778,6 +780,57 @@ NThreading::TFuture<TTableMetadataResult> TKqpTableMetadataLoader::LoadTableMeta promise.SetValue(ResultFromException<TResult>(e)); } }); + + // Create an apply for the future that will fetch table statistics and save it in the metadata + // This method will only run if cost based optimization is enabled + + if (!Config || !Config->HasOptEnableCostBasedOptimization()){ + return future; + } + + TActorSystem* actorSystem = ActorSystem; + + return future.Apply([actorSystem,table](const TFuture<TTableMetadataResult>& f) { + auto result = f.GetValue(); + if (!result.Success()) { + return MakeFuture(result); + } + + if (!result.Metadata->DoesExist){ + return MakeFuture(result); + } + + if (result.Metadata->Kind != NYql::EKikimrTableKind::Datashard && + result.Metadata->Kind != NYql::EKikimrTableKind::Olap) { + return MakeFuture(result); + } + + NKikimr::NStat::TRequest t; + t.StatType = NKikimr::NStat::EStatType::SIMPLE; + t.PathId = NKikimr::TPathId(result.Metadata->PathId.OwnerId(), result.Metadata->PathId.TableId()); + + auto event = MakeHolder<NStat::TEvStatistics::TEvGetStatistics>(); + event->StatRequests.push_back(t); + + auto statServiceId = NStat::MakeStatServiceID(); + + + return SendActorRequest<NStat::TEvStatistics::TEvGetStatistics, NStat::TEvStatistics::TEvGetStatisticsResult, TResult>( + actorSystem, + statServiceId, + event.Release(), + [result](TPromise<TResult> promise, NStat::TEvStatistics::TEvGetStatisticsResult&& response){ + if (!response.StatResponses.size()){ + return; + } + auto resp = response.StatResponses[0]; + auto s = std::get<NKikimr::NStat::TStatSimple>(resp.Statistics); + result.Metadata->RecordsCount = s.RowCount; + result.Metadata->DataSize = s.BytesSize; + promise.SetValue(result); + }); + + }); } } // namespace NKikimr::NKqp diff --git a/ydb/core/kqp/gateway/kqp_metadata_loader.h b/ydb/core/kqp/gateway/kqp_metadata_loader.h index a51f1bf9800..691f091989d 100644 --- a/ydb/core/kqp/gateway/kqp_metadata_loader.h +++ b/ydb/core/kqp/gateway/kqp_metadata_loader.h @@ -15,10 +15,13 @@ namespace NKikimr::NKqp { class TKqpTableMetadataLoader : public NYql::IKikimrGateway::IKqpTableMetadataLoader { public: - explicit TKqpTableMetadataLoader(TActorSystem* actorSystem, - bool needCollectSchemeData = false, TKqpTempTablesState::TConstPtr tempTablesState = nullptr) + explicit TKqpTableMetadataLoader(TActorSystem* actorSystem, + NYql::TKikimrConfiguration::TPtr config, + bool needCollectSchemeData = false, + TKqpTempTablesState::TConstPtr tempTablesState = nullptr) : NeedCollectSchemeData(needCollectSchemeData) , ActorSystem(actorSystem) + , Config(config) , TempTablesState(std::move(tempTablesState)) {}; @@ -56,7 +59,9 @@ private: TMutex Lock; bool NeedCollectSchemeData; TActorSystem* ActorSystem; + NYql::TKikimrConfiguration::TPtr Config; TKqpTempTablesState::TConstPtr TempTablesState; + }; } // namespace NKikimr::NKqp diff --git a/ydb/core/kqp/host/kqp_runner.cpp b/ydb/core/kqp/host/kqp_runner.cpp index b46b12b476a..a4f8aefa64b 100644 --- a/ydb/core/kqp/host/kqp_runner.cpp +++ b/ydb/core/kqp/host/kqp_runner.cpp @@ -91,7 +91,7 @@ public: .Add(CreateKqpCheckQueryTransformer(), "CheckKqlQuery") .AddPostTypeAnnotation(/* forSubgraph */ true) .AddCommonOptimization() - .Add(CreateKqpStatisticsTransformer(*typesCtx, Config), "Statistics") + .Add(CreateKqpStatisticsTransformer(OptimizeCtx, *typesCtx, Config), "Statistics") .Add(CreateKqpLogOptTransformer(OptimizeCtx, *typesCtx, Config), "LogicalOptimize") .Add(CreateLogicalDataProposalsInspector(*typesCtx), "ProvidersLogicalOptimize") .Add(CreateKqpPhyOptTransformer(OptimizeCtx, *typesCtx), "KqpPhysicalOptimize") diff --git a/ydb/core/kqp/opt/kqp_statistics_transformer.cpp b/ydb/core/kqp/opt/kqp_statistics_transformer.cpp index 20d71b3c0e9..ab4b0a32188 100644 --- a/ydb/core/kqp/opt/kqp_statistics_transformer.cpp +++ b/ydb/core/kqp/opt/kqp_statistics_transformer.cpp @@ -6,16 +6,36 @@ using namespace NYql; using namespace NYql::NNodes; using namespace NKikimr::NKqp; +using namespace NYql::NDq; /** * Compute statistics and cost for read table - * Currently we just make up a number for the cardinality (100000) and set cost to 0 + * Currently we look up the number of rows and attributes in the statistics service */ -void InferStatisticsForReadTable(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { +void InferStatisticsForReadTable(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx, + const TKqpOptimizeContext& kqpCtx) { - YQL_CLOG(TRACE, CoreDq) << "Infer statistics for read table"; + auto inputNode = TExprBase(input); + double nRows = 0; + int nAttrs = 0; - auto outputStats = TOptimizerStatistics(100000, 5, 0.0); + const TExprNode* path; + + if ( auto readTable = inputNode.Maybe<TKqlReadTableBase>()){ + path = readTable.Cast().Table().Path().Raw(); + nAttrs = readTable.Cast().Columns().Size(); + } else if(auto readRanges = inputNode.Maybe<TKqlReadTableRangesBase>()){ + path = readRanges.Cast().Table().Path().Raw(); + nAttrs = readRanges.Cast().Columns().Size(); + } else { + Y_ENSURE(false,"Invalid node type for InferStatisticsForReadTable"); + } + + const auto& tableData = kqpCtx.Tables->ExistingTable(kqpCtx.Cluster, path->Content()); + nRows = tableData.Metadata->RecordsCount; + YQL_CLOG(TRACE, CoreDq) << "Infer statistics for read table, nrows:" << nRows << ", nattrs: " << nAttrs; + + auto outputStats = TOptimizerStatistics(nRows, nAttrs, 0.0); typeCtx->SetStats( input.Get(), std::make_shared<TOptimizerStatistics>(outputStats) ); } @@ -48,16 +68,25 @@ IGraphTransformer::TStatus TKqpStatisticsTransformer::DoTransform(TExprNode::TPt auto output = input; if (TCoFlatMap::Match(input.Get())){ - NDq::InferStatisticsForFlatMap(input, typeCtx); + InferStatisticsForFlatMap(input, TypeCtx); } else if(TCoSkipNullMembers::Match(input.Get())){ - NDq::InferStatisticsForSkipNullMembers(input, typeCtx); + InferStatisticsForSkipNullMembers(input, TypeCtx); + } + else if(TCoExtractMembers::Match(input.Get())){ + InferStatisticsForExtractMembers(input, TypeCtx); + } + else if(TCoAggregateCombine::Match(input.Get())){ + InferStatisticsForAggregateCombine(input, TypeCtx); + } + else if(TCoAggregateMergeFinalize::Match(input.Get())){ + InferStatisticsForAggregateMergeFinalize(input, TypeCtx); } else if(TKqlReadTableBase::Match(input.Get()) || TKqlReadTableRangesBase::Match(input.Get())){ - InferStatisticsForReadTable(input, typeCtx); + InferStatisticsForReadTable(input, TypeCtx, KqpCtx); } else if(TKqlLookupTableBase::Match(input.Get()) || TKqlLookupIndexBase::Match(input.Get())){ - InferStatisticsForIndexLookup(input, typeCtx); + InferStatisticsForIndexLookup(input, TypeCtx); } return output; @@ -66,8 +95,8 @@ IGraphTransformer::TStatus TKqpStatisticsTransformer::DoTransform(TExprNode::TPt return ret; } -TAutoPtr<IGraphTransformer> NKikimr::NKqp::CreateKqpStatisticsTransformer(TTypeAnnotationContext& typeCtx, - const TKikimrConfiguration::TPtr& config) { +TAutoPtr<IGraphTransformer> NKikimr::NKqp::CreateKqpStatisticsTransformer(const TIntrusivePtr<TKqpOptimizeContext>& kqpCtx, + TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config) { - return THolder<IGraphTransformer>(new TKqpStatisticsTransformer(typeCtx, config)); + return THolder<IGraphTransformer>(new TKqpStatisticsTransformer(kqpCtx, typeCtx, config)); } diff --git a/ydb/core/kqp/opt/kqp_statistics_transformer.h b/ydb/core/kqp/opt/kqp_statistics_transformer.h index 78d982f9981..d1715e461e0 100644 --- a/ydb/core/kqp/opt/kqp_statistics_transformer.h +++ b/ydb/core/kqp/opt/kqp_statistics_transformer.h @@ -1,5 +1,7 @@ #pragma once +#include "kqp_opt.h" + #include <ydb/library/yql/core/yql_statistics.h> #include <ydb/core/kqp/common/kqp_yql.h> @@ -14,6 +16,7 @@ namespace NKqp { using namespace NYql; using namespace NYql::NNodes; +using namespace NOpt; /*** * Statistics transformer is a transformer that propagates statistics and costs from @@ -23,12 +26,16 @@ using namespace NYql::NNodes; */ class TKqpStatisticsTransformer : public TSyncTransformerBase { - TTypeAnnotationContext* typeCtx; + TTypeAnnotationContext* TypeCtx; const TKikimrConfiguration::TPtr& Config; + const TKqpOptimizeContext& KqpCtx; public: - TKqpStatisticsTransformer(TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config) : - typeCtx(&typeCtx), Config(config) {} + TKqpStatisticsTransformer(const TIntrusivePtr<TKqpOptimizeContext>& kqpCtx, TTypeAnnotationContext& typeCtx, + const TKikimrConfiguration::TPtr& config) : + TypeCtx(&typeCtx), + Config(config), + KqpCtx(*kqpCtx) {} // Main method of the transformer IGraphTransformer::TStatus DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) final; @@ -39,7 +46,7 @@ class TKqpStatisticsTransformer : public TSyncTransformerBase { } }; -TAutoPtr<IGraphTransformer> CreateKqpStatisticsTransformer(TTypeAnnotationContext& typeCtx, - const TKikimrConfiguration::TPtr& config); +TAutoPtr<IGraphTransformer> CreateKqpStatisticsTransformer(const TIntrusivePtr<TKqpOptimizeContext>& kqpCtx, + TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config); } } diff --git a/ydb/core/kqp/provider/yql_kikimr_gateway_ut.cpp b/ydb/core/kqp/provider/yql_kikimr_gateway_ut.cpp index 3a3ce562bf5..750c979e07a 100644 --- a/ydb/core/kqp/provider/yql_kikimr_gateway_ut.cpp +++ b/ydb/core/kqp/provider/yql_kikimr_gateway_ut.cpp @@ -73,7 +73,7 @@ TIntrusivePtr<IKqpGateway> GetIcGateway(Tests::TServer& server) { counters->Counters = new TKqpCounters(server.GetRuntime()->GetAppData(0).Counters); counters->TxProxyMon = new NTxProxy::TTxProxyMon(server.GetRuntime()->GetAppData(0).Counters); - std::shared_ptr<NYql::IKikimrGateway::IKqpTableMetadataLoader> loader = std::make_shared<TKqpTableMetadataLoader>(server.GetRuntime()->GetAnyNodeActorSystem(), false); + std::shared_ptr<NYql::IKikimrGateway::IKqpTableMetadataLoader> loader = std::make_shared<TKqpTableMetadataLoader>(server.GetRuntime()->GetAnyNodeActorSystem(),TIntrusivePtr<NYql::TKikimrConfiguration>(nullptr), false); return CreateKikimrIcGateway(TestCluster, "/Root", std::move(loader), server.GetRuntime()->GetAnyNodeActorSystem(), server.GetRuntime()->GetNodeId(0), counters); } diff --git a/ydb/core/kqp/session_actor/kqp_worker_actor.cpp b/ydb/core/kqp/session_actor/kqp_worker_actor.cpp index ebe9ac03358..e17ab10cdc4 100644 --- a/ydb/core/kqp/session_actor/kqp_worker_actor.cpp +++ b/ydb/core/kqp/session_actor/kqp_worker_actor.cpp @@ -133,7 +133,8 @@ public: LOG_D("Worker bootstrapped"); Counters->ReportWorkerCreated(Settings.DbCounters); - std::shared_ptr<NYql::IKikimrGateway::IKqpTableMetadataLoader> loader = std::make_shared<TKqpTableMetadataLoader>(TlsActivationContext->ActorSystem(), false); + std::shared_ptr<NYql::IKikimrGateway::IKqpTableMetadataLoader> loader = std::make_shared<TKqpTableMetadataLoader>( + TlsActivationContext->ActorSystem(), Config, false); Gateway = CreateKikimrIcGateway(Settings.Cluster, Settings.Database, std::move(loader), ctx.ExecutorThread.ActorSystem, ctx.SelfID.NodeId(), RequestCounters); diff --git a/ydb/core/kqp/ut/indexes/kqp_indexes_ut.cpp b/ydb/core/kqp/ut/indexes/kqp_indexes_ut.cpp index 0b5f577a7d6..16d0ca327b0 100644 --- a/ydb/core/kqp/ut/indexes/kqp_indexes_ut.cpp +++ b/ydb/core/kqp/ut/indexes/kqp_indexes_ut.cpp @@ -34,7 +34,7 @@ TIntrusivePtr<NKqp::IKqpGateway> GetIcGateway(Tests::TServer& server) { auto counters = MakeIntrusive<TKqpRequestCounters>(); counters->Counters = new TKqpCounters(server.GetRuntime()->GetAppData(0).Counters); counters->TxProxyMon = new NTxProxy::TTxProxyMon(server.GetRuntime()->GetAppData(0).Counters); - std::shared_ptr<NYql::IKikimrGateway::IKqpTableMetadataLoader> loader = std::make_shared<TKqpTableMetadataLoader>(server.GetRuntime()->GetAnyNodeActorSystem(), false); + std::shared_ptr<NYql::IKikimrGateway::IKqpTableMetadataLoader> loader = std::make_shared<TKqpTableMetadataLoader>(server.GetRuntime()->GetAnyNodeActorSystem(),TIntrusivePtr<NYql::TKikimrConfiguration>(nullptr),false); return NKqp::CreateKikimrIcGateway(TestCluster, "/Root", std::move(loader), server.GetRuntime()->GetAnyNodeActorSystem(), server.GetRuntime()->GetNodeId(0), counters); } diff --git a/ydb/library/yql/core/yql_statistics.cpp b/ydb/library/yql/core/yql_statistics.cpp index 8abb9a59b5d..ad3968f0fee 100644 --- a/ydb/library/yql/core/yql_statistics.cpp +++ b/ydb/library/yql/core/yql_statistics.cpp @@ -2,9 +2,9 @@ using namespace NYql; -std::ostream& operator<<(std::ostream& os, const TOptimizerStatistics& s) { +std::ostream& NYql::operator<<(std::ostream& os, const TOptimizerStatistics& s) { os << "Nrows: " << s.Nrows << ", Ncols: " << s.Ncols; - os << "Cost: "; + os << ", Cost: "; if (s.Cost.has_value()){ os << s.Cost.value(); } else { diff --git a/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp b/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp index 03d328ae202..867e6fd977d 100644 --- a/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp @@ -95,243 +95,16 @@ struct TEdge { }; /** - * Graph is a data structure for the join graph - * It is an undirected graph, with two edges per connection (from,to) and (to,from) - * It needs to be constructed with addNode and addEdge methods, since its - * keeping various indexes updated. - * The graph also needs to be reordered with the breadth-first search method, and - * the reordering is recorded in bfsMapping (original rel indexes can be recovered from - * this mapping) -*/ -template <int N> -struct TGraph { - // set of edges of the graph - std::unordered_set<TEdge,TEdge::HashFunction> Edges; - - // neightborgh index - TVector<std::bitset<N>> EdgeIdx; - - // number of nodes in a graph - int NNodes; - - // mapping from rel label to node in the graph - THashMap<TString,int> ScopeMapping; - - // mapping from node in the graph to rel label - TVector<TString> RevScopeMapping; - - // Breadth-first-search mapping - TVector<int> BfsMapping; - - // Empty graph constructor intializes indexes to size N - TGraph() : EdgeIdx(N), RevScopeMapping(N) {} - - // Add a node to a graph with a rel label - void AddNode(int nodeId, TString scope){ - NNodes = nodeId + 1; - ScopeMapping[scope] = nodeId; - RevScopeMapping[nodeId] = scope; - } - - // Add an edge to the graph, if the edge is already in the graph - // (we check both directions), no action is taken. Otherwise we - // insert two edges, the forward edge with original joinConditions - // and a reverse edge with swapped joinConditions - void AddEdge(TEdge e){ - if (Edges.contains(e) || Edges.contains(TEdge(e.To, e.From))) { - return; - } - - Edges.insert(e); - std::set<std::pair<TJoinColumn, TJoinColumn>> swappedSet; - for (auto c : e.JoinConditions){ - swappedSet.insert(std::make_pair(c.second, c.first)); - } - Edges.insert(TEdge(e.To,e.From,swappedSet)); - - EdgeIdx[e.From].set(e.To); - EdgeIdx[e.To].set(e.From); - } - - // Find a node by the rel scope - int FindNode(TString scope){ - return ScopeMapping[scope]; - } - - // Return a bitset of node's neighbors - inline std::bitset<N> FindNeighbors(int fromVertex) - { - return EdgeIdx[fromVertex]; - } - - // Return a bitset of node's neigbors with itself included - inline std::bitset<N> FindNeighborsWithSelf(int fromVertex) - { - std::bitset<N> res = FindNeighbors(fromVertex); - res.set(fromVertex); - return res; - } - - // Find an edge that connects two subsets of graph's nodes - // We are guaranteed to find a match - TEdge FindCrossingEdge(const std::bitset<N>& S1, const std::bitset<N>& S2) { - for(int i=0; i<NNodes; i++){ - if (!S1[i]) { - continue; - } - for (int j=0; j<NNodes; j++) { - if (!S2[j]) { - continue; - } - if ((FindNeighborsWithSelf(i) & FindNeighborsWithSelf(j)) != 0) { - auto it = Edges.find(TEdge(i, j)); - Y_VERIFY_DEBUG(it != Edges.end()); - return *it; - } - } - } - Y_ENSURE(false,"Connecting edge not found!"); - return TEdge(-1,-1); - } - - /** - * Create a union-set from the join conditions to record the equivalences. - * Then use the equivalence set to compute transitive closure of the graph. - * Transitive closure means that if we have an edge from (1,2) with join - * condition R.A = S.A and we have an edge from (2,3) with join condition - * S.A = T.A, we will find out that the join conditions form an equivalence set - * and add an edge (1,3) with join condition R.A = T.A. - */ - void ComputeTransitiveClosure(const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions) { - std::set<TJoinColumn> columnSet; - for (auto [ leftCondition, rightCondition ] : joinConditions) { - columnSet.insert( leftCondition ); - columnSet.insert( rightCondition ); - } - std::vector<TJoinColumn> columns; - for (auto c : columnSet ) { - columns.push_back(c); - } - - THashMap<TJoinColumn, int, TJoinColumn::HashFunction> indexMapping; - for (size_t i=0; i<columns.size(); i++) { - indexMapping[ columns[i] ] = i; - } - - TDisjointSets ds = TDisjointSets( columns.size() ); - for (auto [ leftCondition, rightCondition ] : joinConditions ) { - int leftIndex = indexMapping[ leftCondition ]; - int rightIndex = indexMapping[ rightCondition ]; - ds.UnionSets(leftIndex,rightIndex); - } - - for (size_t i=0;i<columns.size();i++) { - for (size_t j=0;j<i;j++) { - if (ds.CanonicSetElement(i) == ds.CanonicSetElement(j)) { - TJoinColumn left = columns[i]; - TJoinColumn right = columns[j]; - int leftNodeId = ScopeMapping[ left.RelName ]; - int rightNodeId = ScopeMapping[ right.RelName ]; - - if (! Edges.contains(TEdge(leftNodeId,rightNodeId)) && - ! Edges.contains(TEdge(rightNodeId,leftNodeId))) { - AddEdge(TEdge(leftNodeId,rightNodeId,std::make_pair(left, right))); - } else { - TEdge e1 = *Edges.find(TEdge(leftNodeId,rightNodeId)); - if (!e1.JoinConditions.contains(std::make_pair(left, right))) { - e1.JoinConditions.insert(std::make_pair(left, right)); - } - - TEdge e2 = *Edges.find(TEdge(rightNodeId,leftNodeId)); - if (!e2.JoinConditions.contains(std::make_pair(right, left))) { - e2.JoinConditions.insert(std::make_pair(right, left)); - } - } - } - } - } - } - - /** - * Reorder the graph by doing a breadth first search from node 0. - * This is required by the DPccp algorithm. - */ - TGraph<N> BfsReorder() { - std::set<int> visited; - std::queue<int> queue; - TVector<int> bfsMapping(NNodes); - - queue.push(0); - int lastVisited = 0; - - while(! queue.empty()) { - int curr = queue.front(); - queue.pop(); - if (visited.contains(curr)) { - continue; - } - bfsMapping[curr] = lastVisited; - lastVisited++; - visited.insert(curr); - - std::bitset<N> neighbors = FindNeighbors(curr); - for (int i=0; i<NNodes; i++ ) { - if (neighbors[i]){ - if (!visited.contains(i)) { - queue.push(i); - } - } - } - } - - TGraph<N> res; - - for (int i=0;i<NNodes;i++) { - res.AddNode(i,RevScopeMapping[bfsMapping[i]]); - } - - for (const TEdge& e : Edges){ - res.AddEdge( TEdge(bfsMapping[e.From], bfsMapping[e.To], e.JoinConditions) ); - } - - res.BfsMapping = bfsMapping; - - return res; - } - - /** - * Print the graph - */ - void PrintGraph(std::stringstream& stream) { - stream << "Join Graph:\n"; - stream << "nNodes: " << NNodes << ", nEdges: " << Edges.size() << "\n"; - - for(int i=0;i<NNodes;i++) { - stream << "Node:" << i << "," << RevScopeMapping[i] << "\n"; - } - for (const TEdge& e: Edges ) { - stream << "Edge: " << e.From << " -> " << e.To << "\n"; - for (auto p : e.JoinConditions) { - stream << p.first.RelName << "." - << p.first.AttributeName << "=" - << p.second.RelName << "." - << p.second.AttributeName << "\n"; - } - } - } -}; - -/** * Fetch join conditions from the equi-join tree */ void ComputeJoinConditions(const TCoEquiJoinTuple& joinTuple, std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions) { if (joinTuple.LeftScope().Maybe<TCoEquiJoinTuple>()) { - ComputeJoinConditions( joinTuple.LeftScope().Cast<TCoEquiJoinTuple>(), joinConditions ); + ComputeJoinConditions(joinTuple.LeftScope().Cast<TCoEquiJoinTuple>(), joinConditions); } if (joinTuple.RightScope().Maybe<TCoEquiJoinTuple>()) { - ComputeJoinConditions( joinTuple.RightScope().Cast<TCoEquiJoinTuple>(), joinConditions ); + ComputeJoinConditions(joinTuple.RightScope().Cast<TCoEquiJoinTuple>(), joinConditions); } size_t joinKeysCount = joinTuple.LeftKeys().Size() / 2; @@ -388,15 +161,15 @@ struct TRelOptimizerNode : public IBaseOptimizerNode { virtual ~TRelOptimizerNode() {} virtual void Print(std::stringstream& stream, int ntabs=0) { - for (int i=0;i<ntabs;i++){ + for (int i = 0; i < ntabs; i++){ stream << "\t"; } stream << "Rel: " << Label << "\n"; - for (int i=0;i<ntabs;i++){ + for (int i = 0; i < ntabs; i++){ stream << "\t"; } - stream << Stats << "\n"; + stream << *Stats << "\n"; } }; @@ -443,7 +216,7 @@ struct TJoinOptimizerNode : public IBaseOptimizerNode { * Print out the join tree, rooted at this node */ virtual void Print(std::stringstream& stream, int ntabs=0) { - for (int i=0;i<ntabs;i++){ + for (int i = 0; i < ntabs; i++){ stream << "\t"; } @@ -455,11 +228,11 @@ struct TJoinOptimizerNode : public IBaseOptimizerNode { } stream << "\n"; - for (int i=0;i<ntabs;i++){ + for (int i = 0; i < ntabs; i++){ stream << "\t"; } - stream << Stats << "\n"; + stream << *Stats << "\n"; LeftArg->Print(stream, ntabs+1); RightArg->Print(stream, ntabs+1); @@ -492,6 +265,173 @@ struct pair_hash { }; /** + * Graph is a data structure for the join graph + * It is an undirected graph, with two edges per connection (from,to) and (to,from) + * It needs to be constructed with addNode and addEdge methods, since its + * keeping various indexes updated. + * The graph also needs to be reordered with the breadth-first search method +*/ +template <int N> +struct TGraph { + // set of edges of the graph + std::unordered_set<TEdge,TEdge::HashFunction> Edges; + + // neightborgh index + TVector<std::bitset<N>> EdgeIdx; + + // number of nodes in a graph + int NNodes; + + // mapping from rel label to node in the graph + THashMap<TString,int> ScopeMapping; + + // mapping from node in the graph to rel label + TVector<TString> RevScopeMapping; + + // Empty graph constructor intializes indexes to size N + TGraph() : EdgeIdx(N), RevScopeMapping(N) {} + + // Add a node to a graph with a rel label + void AddNode(int nodeId, TString scope){ + NNodes = nodeId + 1; + ScopeMapping[scope] = nodeId; + RevScopeMapping[nodeId] = scope; + } + + // Add an edge to the graph, if the edge is already in the graph + // (we check both directions), no action is taken. Otherwise we + // insert two edges, the forward edge with original joinConditions + // and a reverse edge with swapped joinConditions + void AddEdge(TEdge e){ + if (Edges.contains(e) || Edges.contains(TEdge(e.To, e.From))) { + return; + } + + Edges.insert(e); + std::set<std::pair<TJoinColumn, TJoinColumn>> swappedSet; + for (auto c : e.JoinConditions){ + swappedSet.insert(std::make_pair(c.second, c.first)); + } + Edges.insert(TEdge(e.To,e.From,swappedSet)); + + EdgeIdx[e.From].set(e.To); + EdgeIdx[e.To].set(e.From); + } + + // Find a node by the rel scope + int FindNode(TString scope){ + return ScopeMapping[scope]; + } + + // Return a bitset of node's neighbors + inline std::bitset<N> FindNeighbors(int fromVertex) + { + return EdgeIdx[fromVertex]; + } + + // Find an edge that connects two subsets of graph's nodes + // We are guaranteed to find a match + TEdge FindCrossingEdge(const std::bitset<N>& S1, const std::bitset<N>& S2) { + for(int i = 0; i < NNodes; i++){ + if (!S1[i]) { + continue; + } + for (int j = 0; j < NNodes; j++) { + if (!S2[j]) { + continue; + } + if (EdgeIdx[i].test(j)) { + auto it = Edges.find(TEdge(i, j)); + Y_VERIFY_DEBUG(it != Edges.end()); + return *it; + } + } + } + Y_ENSURE(false,"Connecting edge not found!"); + return TEdge(-1,-1); + } + + /** + * Create a union-set from the join conditions to record the equivalences. + * Then use the equivalence set to compute transitive closure of the graph. + * Transitive closure means that if we have an edge from (1,2) with join + * condition R.A = S.A and we have an edge from (2,3) with join condition + * S.A = T.A, we will find out that the join conditions form an equivalence set + * and add an edge (1,3) with join condition R.A = T.A. + */ + void ComputeTransitiveClosure(const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions) { + std::set<TJoinColumn> columnSet; + for (auto [ leftCondition, rightCondition ] : joinConditions) { + columnSet.insert(leftCondition); + columnSet.insert(rightCondition); + } + std::vector<TJoinColumn> columns; + for (auto c : columnSet ) { + columns.push_back(c); + } + + THashMap<TJoinColumn, int, TJoinColumn::HashFunction> indexMapping; + for (size_t i=0; i<columns.size(); i++) { + indexMapping[columns[i]] = i; + } + + TDisjointSets ds = TDisjointSets( columns.size() ); + for (auto [ leftCondition, rightCondition ] : joinConditions ) { + int leftIndex = indexMapping[leftCondition]; + int rightIndex = indexMapping[rightCondition]; + ds.UnionSets(leftIndex,rightIndex); + } + + for (size_t i = 0; i < columns.size(); i++) { + for (size_t j = 0; j < i; j++) { + if (ds.CanonicSetElement(i) == ds.CanonicSetElement(j)) { + TJoinColumn left = columns[i]; + TJoinColumn right = columns[j]; + int leftNodeId = ScopeMapping[left.RelName]; + int rightNodeId = ScopeMapping[right.RelName]; + + if (! Edges.contains(TEdge(leftNodeId,rightNodeId)) && + ! Edges.contains(TEdge(rightNodeId,leftNodeId))) { + AddEdge(TEdge(leftNodeId,rightNodeId,std::make_pair(left, right))); + } else { + TEdge e1 = *Edges.find(TEdge(leftNodeId,rightNodeId)); + if (!e1.JoinConditions.contains(std::make_pair(left, right))) { + e1.JoinConditions.insert(std::make_pair(left, right)); + } + + TEdge e2 = *Edges.find(TEdge(rightNodeId,leftNodeId)); + if (!e2.JoinConditions.contains(std::make_pair(right, left))) { + e2.JoinConditions.insert(std::make_pair(right, left)); + } + } + } + } + } + } + + /** + * Print the graph + */ + void PrintGraph(std::stringstream& stream) { + stream << "Join Graph:\n"; + stream << "nNodes: " << NNodes << ", nEdges: " << Edges.size() << "\n"; + + for(int i = 0; i < NNodes; i++) { + stream << "Node:" << i << "," << RevScopeMapping[i] << "\n"; + } + for (const TEdge& e: Edges ) { + stream << "Edge: " << e.From << " -> " << e.To << "\n"; + for (auto p : e.JoinConditions) { + stream << p.first.RelName << "." + << p.first.AttributeName << "=" + << p.second.RelName << "." + << p.second.AttributeName << "\n"; + } + } + } +}; + +/** * DPcpp (Dynamic Programming with connected complement pairs) is a graph-aware * join eumeration algorithm that only considers CSGs (Connected Sub-Graphs) of * the join graph and computes CMPs (Complement pairs) that are also connected @@ -508,7 +448,7 @@ class TDPccpSolver { public: // Construct the DPccp solver based on the join graph and data about input relations - TDPccpSolver(TGraph<N>& g, std::vector<std::shared_ptr<TRelOptimizerNode>> rels): + TDPccpSolver(TGraph<N>& g, TVector<std::shared_ptr<TRelOptimizerNode>> rels): Graph(g), Rels(rels) { NNodes = g.NNodes; } @@ -537,7 +477,7 @@ class TDPccpSolver { TGraph<N>& Graph; // List of input relations to DPccp - std::vector<std::shared_ptr<TRelOptimizerNode>> Rels; + TVector<std::shared_ptr<TRelOptimizerNode>> Rels; // Emit connected subgraph void EmitCsg(const std::bitset<N>&, int=0); @@ -567,7 +507,7 @@ class TDPccpSolver { // Print tabs void PrintTabs(std::stringstream& stream, int ntabs) { - for (int i=0;i<ntabs;i++) + for (int i = 0; i < ntabs; i++) stream << "\t"; } @@ -578,7 +518,7 @@ template <int N> void TDPccpSolver<N>::PrintBitset(std::stringstream& stream, PrintTabs(stream, ntabs); stream << name << ": " << "{"; - for (int i=0;i<NNodes;i++) + for (int i = 0; i < NNodes; i++) if (s[i]) stream << i << ","; @@ -590,7 +530,7 @@ template<int N> std::bitset<N> TDPccpSolver<N>::Neighbors(const std::bitset<N>& std::bitset<N> res; - for (int i=0;i<Graph.NNodes;i++) { + for (int i = 0; i < Graph.NNodes; i++) { if (S[i]) { std::bitset<N> n = Graph.FindNeighbors(i); res = res | n; @@ -605,14 +545,14 @@ template<int N> std::bitset<N> TDPccpSolver<N>::Neighbors(const std::bitset<N>& template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve() { // Process singleton sets - for (int i=NNodes-1;i>=0;i--) { + for (int i = NNodes-1; i >= 0; i--) { std::bitset<N> s; s.set(i); - DpTable[s] = Rels[Graph.BfsMapping[i]]; + DpTable[s] = Rels[i]; } // Expand singleton sets - for (int i=NNodes-1;i>=0;i--) { + for (int i = NNodes-1; i >= 0; i--) { std::bitset<N> s; s.set(i); EmitCsg(s); @@ -622,7 +562,7 @@ template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve() // Return the entry of the dpTable that corresponds to the full // set of nodes in the graph std::bitset<N> V; - for (int i=0;i<NNodes;i++) { + for (int i = 0; i < NNodes; i++) { V.set(i); } @@ -643,12 +583,12 @@ template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve() return; } - for (int i=NNodes-1;i>=0;i--) { + for (int i = NNodes - 1; i >= 0; i--) { if (Ns[i]) { std::bitset<N> S2; S2.set(i); - EmitCsgCmp(S,S2,ntabs+1); - EnumerateCmpRec(S, S2, X|MakeB(Ns,i), ntabs+1); + EmitCsgCmp(S, S2, ntabs+1); + EnumerateCmpRec(S, S2, X | MakeB(Ns, i), ntabs+1); } } } @@ -660,7 +600,7 @@ template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve() */ template <int N> void TDPccpSolver<N>::EnumerateCsgRec(const std::bitset<N>& S, const std::bitset<N>& X, int ntabs) { - std::bitset<N> Ns = Neighbors(S,X); + std::bitset<N> Ns = Neighbors(S, X); if (Ns == std::bitset<N>()) { return; @@ -670,7 +610,7 @@ template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve() std::bitset<N> next; while(true) { - next = NextBitset(prev,Ns); + next = NextBitset(prev, Ns); EmitCsg(S | next ); if (next == Ns) { break; @@ -680,7 +620,7 @@ template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve() prev.reset(); while(true) { - next = NextBitset(prev,Ns); + next = NextBitset(prev, Ns); EnumerateCsgRec(S | next, X | Ns , ntabs+1); if (next==Ns) { break; @@ -709,7 +649,7 @@ template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve() while(true) { next = NextBitset(prev, Ns); - EmitCsgCmp(S1,S2|next, ntabs+1); + EmitCsgCmp(S1, S2 | next, ntabs+1); if (next==Ns) { break; } @@ -719,7 +659,7 @@ template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve() prev.reset(); while(true) { next = NextBitset(prev, Ns); - EnumerateCmpRec(S1, S2|next, X|Ns, ntabs+1); + EnumerateCmpRec(S1, S2 | next, X | Ns, ntabs+1); if (next==Ns) { break; } @@ -778,9 +718,9 @@ template <int N> void TDPccpSolver<N>::EmitCsgCmp(const std::bitset<N>& S1, cons template <int N> std::bitset<N> TDPccpSolver<N>::MakeBiMin(const std::bitset<N>& S) { std::bitset<N> res; - for (int i=0;i<NNodes; i++) { + for (int i = 0; i < NNodes; i++) { if (S[i]) { - for (int j=0; j<=i; j++) { + for (int j = 0; j <= i; j++) { res.set(j); } break; @@ -796,8 +736,8 @@ template <int N> std::bitset<N> TDPccpSolver<N>::MakeBiMin(const std::bitset<N>& template <int N> std::bitset<N> TDPccpSolver<N>::MakeB(const std::bitset<N>& S, int x) { std::bitset<N> res; - for (int i=0; i<NNodes; i++) { - if (S[i] && i<=x) { + for (int i = 0; i < NNodes; i++) { + if (S[i] && i <= x) { res.set(i); } } @@ -815,7 +755,7 @@ template <int N> std::bitset<N> TDPccpSolver<N>::NextBitset(const std::bitset<N> std::bitset<N> res = prev; bool carry = true; - for (int i=0; i<NNodes; i++) + for (int i = 0; i < NNodes; i++) { if (!carry) { break; @@ -909,13 +849,13 @@ TExprBase BuildTree(TExprContext& ctx, const TCoEquiJoin& equiJoin, TExprBase RearrangeEquiJoinTree(TExprContext& ctx, const TCoEquiJoin& equiJoin, std::shared_ptr<TJoinOptimizerNode> reorderResult) { TVector<TExprBase> joinArgs; - for (size_t i=0; i<equiJoin.ArgCount() - 2; i++){ + for (size_t i = 0; i < equiJoin.ArgCount() - 2; i++){ joinArgs.push_back(equiJoin.Arg(i)); } joinArgs.push_back(BuildTree(ctx,equiJoin,reorderResult)); - joinArgs.push_back(equiJoin.Arg(equiJoin.ArgCount()-1)); + joinArgs.push_back(equiJoin.Arg(equiJoin.ArgCount() - 1)); return Build<TCoEquiJoin>(ctx, equiJoin.Pos()) .Add(joinArgs) @@ -968,12 +908,12 @@ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx, YQL_ENSURE(equiJoin.ArgCount() >= 4); if (typesCtx.StatisticsMap.contains(equiJoin.Raw()) && - typesCtx.StatisticsMap[ equiJoin.Raw()]->Cost.has_value()) { + typesCtx.StatisticsMap[equiJoin.Raw()]->Cost.has_value()) { return node; } - if (! AllInnerJoins(equiJoin.Arg( equiJoin.ArgCount() - 2).Cast<TCoEquiJoinTuple>())) { + if (! AllInnerJoins(equiJoin.Arg(equiJoin.ArgCount() - 2).Cast<TCoEquiJoinTuple>())) { return node; } @@ -988,11 +928,13 @@ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx, auto input = equiJoin.Arg(i).Cast<TCoEquiJoinInput>(); auto joinArg = input.List(); - if (!typesCtx.StatisticsMap.contains( joinArg.Raw() )) { + if (!typesCtx.StatisticsMap.contains(joinArg.Raw())) { + YQL_CLOG(TRACE, CoreDq) << "Didn't find statistics for scope " << input.Scope().Cast<TCoAtom>().StringValue() << "\n"; + return node; } - if (!typesCtx.StatisticsMap[ joinArg.Raw() ]->Cost.has_value()) { + if (!typesCtx.StatisticsMap[joinArg.Raw()]->Cost.has_value()) { return node; } @@ -1002,8 +944,8 @@ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx, } auto label = scope.Cast<TCoAtom>().StringValue(); - auto stats = typesCtx.StatisticsMap[ joinArg.Raw() ]; - rels.push_back( std::shared_ptr<TRelOptimizerNode>( new TRelOptimizerNode( label, stats ))); + auto stats = typesCtx.StatisticsMap[joinArg.Raw()]; + rels.push_back( std::shared_ptr<TRelOptimizerNode>( new TRelOptimizerNode(label, stats))); } YQL_CLOG(TRACE, CoreDq) << "All statistics for join in place"; @@ -1011,18 +953,18 @@ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx, std::set<std::pair<TJoinColumn, TJoinColumn>> joinConditions; // EquiJoin argument n-2 is the actual join tree, represented as TCoEquiJoinTuple - ComputeJoinConditions(equiJoin.Arg( equiJoin.ArgCount() - 2).Cast<TCoEquiJoinTuple>(), joinConditions ); + ComputeJoinConditions(equiJoin.Arg(equiJoin.ArgCount() - 2).Cast<TCoEquiJoinTuple>(), joinConditions); // construct a graph out of join conditions TGraph<64> joinGraph; - for (size_t i=0; i<rels.size(); i++) { - joinGraph.AddNode(i,rels[i]->Label); + for (size_t i = 0; i < rels.size(); i++) { + joinGraph.AddNode(i, rels[i]->Label); } for (auto cond : joinConditions ) { int fromNode = joinGraph.FindNode(cond.first.RelName); int toNode = joinGraph.FindNode(cond.second.RelName); - joinGraph.AddEdge(TEdge(fromNode,toNode,cond)); + joinGraph.AddEdge(TEdge(fromNode, toNode, cond)); } if (NYql::NLog::YqlLogger().NeedToLog(NYql::NLog::EComponent::ProviderKqp, NYql::NLog::ELevel::TRACE)) { @@ -1042,8 +984,6 @@ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx, YQL_CLOG(TRACE, CoreDq) << str.str(); } - joinGraph = joinGraph.BfsReorder(); - // feed the graph to DPccp algorithm TDPccpSolver<64> solver(joinGraph,rels); std::shared_ptr<TJoinOptimizerNode> result = solver.Solve(); @@ -1056,7 +996,7 @@ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx, } // rewrite the join tree and record the output statistics - TExprBase res = RearrangeEquiJoinTree(ctx,equiJoin,result); + TExprBase res = RearrangeEquiJoinTree(ctx, equiJoin, result); typesCtx.StatisticsMap[ res.Raw() ] = result->Stats; return res; } diff --git a/ydb/library/yql/dq/opt/dq_opt_stat.cpp b/ydb/library/yql/dq/opt/dq_opt_stat.cpp index fcd7b692562..6424e89478e 100644 --- a/ydb/library/yql/dq/opt/dq_opt_stat.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_stat.cpp @@ -57,4 +57,61 @@ void InferStatisticsForSkipNullMembers(const TExprNode::TPtr& input, TTypeAnnota typeCtx->SetCost( input.Get(), typeCtx->GetCost( skipNullMembersInput.Raw() ) ); } +/** + * Infer statistics and costs for ExtractlMembers + * We just return the input statistics. +*/ +void InferStatisticsForExtractMembers(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { + + auto inputNode = TExprBase(input); + auto extractMembers = inputNode.Cast<TCoExtractMembers>(); + auto extractMembersInput = extractMembers.Input(); + + auto inputStats = typeCtx->GetStats(extractMembersInput.Raw() ); + if (!inputStats) { + return; + } + + typeCtx->SetStats( input.Get(), inputStats ); + typeCtx->SetCost( input.Get(), typeCtx->GetCost( extractMembersInput.Raw() ) ); +} + +/** + * Infer statistics and costs for AggregateCombine + * We just return the input statistics. +*/ +void InferStatisticsForAggregateCombine(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { + + auto inputNode = TExprBase(input); + auto agg = inputNode.Cast<TCoAggregateCombine>(); + auto aggInput = agg.Input(); + + auto inputStats = typeCtx->GetStats(aggInput.Raw()); + if (!inputStats) { + return; + } + + typeCtx->SetStats( input.Get(), inputStats ); + typeCtx->SetCost( input.Get(), typeCtx->GetCost( aggInput.Raw() ) ); +} + +/** + * Infer statistics and costs for AggregateMergeFinalize + * Just return input stats +*/ +void InferStatisticsForAggregateMergeFinalize(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { + + auto inputNode = TExprBase(input); + auto agg = inputNode.Cast<TCoAggregateMergeFinalize>(); + auto aggInput = agg.Input(); + + auto inputStats = typeCtx->GetStats(aggInput.Raw() ); + if (!inputStats) { + return; + } + + typeCtx->SetStats( input.Get(), inputStats ); + typeCtx->SetCost( input.Get(), typeCtx->GetCost( aggInput.Raw() ) ); +} + } // namespace NYql::NDq { diff --git a/ydb/library/yql/dq/opt/dq_opt_stat.h b/ydb/library/yql/dq/opt/dq_opt_stat.h index c4ab54aff52..01c71771344 100644 --- a/ydb/library/yql/dq/opt/dq_opt_stat.h +++ b/ydb/library/yql/dq/opt/dq_opt_stat.h @@ -6,5 +6,8 @@ namespace NYql::NDq { void InferStatisticsForFlatMap(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); void InferStatisticsForSkipNullMembers(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); +void InferStatisticsForExtractMembers(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); +void InferStatisticsForAggregateCombine(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); +void InferStatisticsForAggregateMergeFinalize(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); } // namespace NYql::NDq {
\ No newline at end of file |