aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorpavelvelikhov <pavelvelikhov@yandex-team.com>2023-08-30 18:04:02 +0300
committerpavelvelikhov <pavelvelikhov@yandex-team.com>2023-08-30 18:45:04 +0300
commit6420e048afe32b6590f6b52f9fc6ffdcfab642ba (patch)
treec9e4e1988e574361e388c9ccc0c36812ac8ed925
parent15b722615d5789b04575b0b279cf6f391199da68 (diff)
downloadydb-6420e048afe32b6590f6b52f9fc6ffdcfab642ba.tar.gz
Integrated with statistics, bug fixes
formatting change Stashed conflicts resolved resolved conflicts
-rw-r--r--ydb/core/kqp/compile_service/kqp_compile_actor.cpp2
-rw-r--r--ydb/core/kqp/gateway/kqp_metadata_loader.cpp55
-rw-r--r--ydb/core/kqp/gateway/kqp_metadata_loader.h9
-rw-r--r--ydb/core/kqp/host/kqp_runner.cpp2
-rw-r--r--ydb/core/kqp/opt/kqp_statistics_transformer.cpp51
-rw-r--r--ydb/core/kqp/opt/kqp_statistics_transformer.h17
-rw-r--r--ydb/core/kqp/provider/yql_kikimr_gateway_ut.cpp2
-rw-r--r--ydb/core/kqp/session_actor/kqp_worker_actor.cpp3
-rw-r--r--ydb/core/kqp/ut/indexes/kqp_indexes_ut.cpp2
-rw-r--r--ydb/library/yql/core/yql_statistics.cpp4
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp484
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_stat.cpp57
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_stat.h3
13 files changed, 393 insertions, 298 deletions
diff --git a/ydb/core/kqp/compile_service/kqp_compile_actor.cpp b/ydb/core/kqp/compile_service/kqp_compile_actor.cpp
index 05d0f8343d6..b2121dc5141 100644
--- a/ydb/core/kqp/compile_service/kqp_compile_actor.cpp
+++ b/ydb/core/kqp/compile_service/kqp_compile_actor.cpp
@@ -117,7 +117,7 @@ public:
counters->TxProxyMon = new NTxProxy::TTxProxyMon(AppData(ctx)->Counters);
std::shared_ptr<NYql::IKikimrGateway::IKqpTableMetadataLoader> loader =
std::make_shared<TKqpTableMetadataLoader>(
- TlsActivationContext->ActorSystem(), true, TempTablesState);
+ TlsActivationContext->ActorSystem(), Config, true, TempTablesState);
Gateway = CreateKikimrIcGateway(QueryId.Cluster, QueryId.Database, std::move(loader),
ctx.ExecutorThread.ActorSystem, ctx.SelfID.NodeId(), counters);
Gateway->SetToken(QueryId.Cluster, UserToken);
diff --git a/ydb/core/kqp/gateway/kqp_metadata_loader.cpp b/ydb/core/kqp/gateway/kqp_metadata_loader.cpp
index aeb30a29ff0..c24d594a30d 100644
--- a/ydb/core/kqp/gateway/kqp_metadata_loader.cpp
+++ b/ydb/core/kqp/gateway/kqp_metadata_loader.cpp
@@ -2,6 +2,8 @@
#include "actors/kqp_ic_gateway_actors.h"
#include <ydb/core/base/path.h>
+#include <ydb/core/statistics/events.h>
+#include <ydb/core/statistics/stat_service.h>
#include <library/cpp/actors/core/hfunc.h>
#include <library/cpp/actors/core/log.h>
@@ -688,7 +690,7 @@ NThreading::TFuture<TTableMetadataResult> TKqpTableMetadataLoader::LoadTableMeta
const auto schemeCacheId = MakeSchemeCacheID();
- return SendActorRequest<TRequest, TResponse, TResult>(
+ auto future = SendActorRequest<TRequest, TResponse, TResult>(
ActorSystem,
schemeCacheId,
ev.Release(),
@@ -778,6 +780,57 @@ NThreading::TFuture<TTableMetadataResult> TKqpTableMetadataLoader::LoadTableMeta
promise.SetValue(ResultFromException<TResult>(e));
}
});
+
+ // Create an apply for the future that will fetch table statistics and save it in the metadata
+ // This method will only run if cost based optimization is enabled
+
+ if (!Config || !Config->HasOptEnableCostBasedOptimization()){
+ return future;
+ }
+
+ TActorSystem* actorSystem = ActorSystem;
+
+ return future.Apply([actorSystem,table](const TFuture<TTableMetadataResult>& f) {
+ auto result = f.GetValue();
+ if (!result.Success()) {
+ return MakeFuture(result);
+ }
+
+ if (!result.Metadata->DoesExist){
+ return MakeFuture(result);
+ }
+
+ if (result.Metadata->Kind != NYql::EKikimrTableKind::Datashard &&
+ result.Metadata->Kind != NYql::EKikimrTableKind::Olap) {
+ return MakeFuture(result);
+ }
+
+ NKikimr::NStat::TRequest t;
+ t.StatType = NKikimr::NStat::EStatType::SIMPLE;
+ t.PathId = NKikimr::TPathId(result.Metadata->PathId.OwnerId(), result.Metadata->PathId.TableId());
+
+ auto event = MakeHolder<NStat::TEvStatistics::TEvGetStatistics>();
+ event->StatRequests.push_back(t);
+
+ auto statServiceId = NStat::MakeStatServiceID();
+
+
+ return SendActorRequest<NStat::TEvStatistics::TEvGetStatistics, NStat::TEvStatistics::TEvGetStatisticsResult, TResult>(
+ actorSystem,
+ statServiceId,
+ event.Release(),
+ [result](TPromise<TResult> promise, NStat::TEvStatistics::TEvGetStatisticsResult&& response){
+ if (!response.StatResponses.size()){
+ return;
+ }
+ auto resp = response.StatResponses[0];
+ auto s = std::get<NKikimr::NStat::TStatSimple>(resp.Statistics);
+ result.Metadata->RecordsCount = s.RowCount;
+ result.Metadata->DataSize = s.BytesSize;
+ promise.SetValue(result);
+ });
+
+ });
}
} // namespace NKikimr::NKqp
diff --git a/ydb/core/kqp/gateway/kqp_metadata_loader.h b/ydb/core/kqp/gateway/kqp_metadata_loader.h
index a51f1bf9800..691f091989d 100644
--- a/ydb/core/kqp/gateway/kqp_metadata_loader.h
+++ b/ydb/core/kqp/gateway/kqp_metadata_loader.h
@@ -15,10 +15,13 @@ namespace NKikimr::NKqp {
class TKqpTableMetadataLoader : public NYql::IKikimrGateway::IKqpTableMetadataLoader {
public:
- explicit TKqpTableMetadataLoader(TActorSystem* actorSystem,
- bool needCollectSchemeData = false, TKqpTempTablesState::TConstPtr tempTablesState = nullptr)
+ explicit TKqpTableMetadataLoader(TActorSystem* actorSystem,
+ NYql::TKikimrConfiguration::TPtr config,
+ bool needCollectSchemeData = false,
+ TKqpTempTablesState::TConstPtr tempTablesState = nullptr)
: NeedCollectSchemeData(needCollectSchemeData)
, ActorSystem(actorSystem)
+ , Config(config)
, TempTablesState(std::move(tempTablesState))
{};
@@ -56,7 +59,9 @@ private:
TMutex Lock;
bool NeedCollectSchemeData;
TActorSystem* ActorSystem;
+ NYql::TKikimrConfiguration::TPtr Config;
TKqpTempTablesState::TConstPtr TempTablesState;
+
};
} // namespace NKikimr::NKqp
diff --git a/ydb/core/kqp/host/kqp_runner.cpp b/ydb/core/kqp/host/kqp_runner.cpp
index b46b12b476a..a4f8aefa64b 100644
--- a/ydb/core/kqp/host/kqp_runner.cpp
+++ b/ydb/core/kqp/host/kqp_runner.cpp
@@ -91,7 +91,7 @@ public:
.Add(CreateKqpCheckQueryTransformer(), "CheckKqlQuery")
.AddPostTypeAnnotation(/* forSubgraph */ true)
.AddCommonOptimization()
- .Add(CreateKqpStatisticsTransformer(*typesCtx, Config), "Statistics")
+ .Add(CreateKqpStatisticsTransformer(OptimizeCtx, *typesCtx, Config), "Statistics")
.Add(CreateKqpLogOptTransformer(OptimizeCtx, *typesCtx, Config), "LogicalOptimize")
.Add(CreateLogicalDataProposalsInspector(*typesCtx), "ProvidersLogicalOptimize")
.Add(CreateKqpPhyOptTransformer(OptimizeCtx, *typesCtx), "KqpPhysicalOptimize")
diff --git a/ydb/core/kqp/opt/kqp_statistics_transformer.cpp b/ydb/core/kqp/opt/kqp_statistics_transformer.cpp
index 20d71b3c0e9..ab4b0a32188 100644
--- a/ydb/core/kqp/opt/kqp_statistics_transformer.cpp
+++ b/ydb/core/kqp/opt/kqp_statistics_transformer.cpp
@@ -6,16 +6,36 @@
using namespace NYql;
using namespace NYql::NNodes;
using namespace NKikimr::NKqp;
+using namespace NYql::NDq;
/**
* Compute statistics and cost for read table
- * Currently we just make up a number for the cardinality (100000) and set cost to 0
+ * Currently we look up the number of rows and attributes in the statistics service
*/
-void InferStatisticsForReadTable(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+void InferStatisticsForReadTable(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx,
+ const TKqpOptimizeContext& kqpCtx) {
- YQL_CLOG(TRACE, CoreDq) << "Infer statistics for read table";
+ auto inputNode = TExprBase(input);
+ double nRows = 0;
+ int nAttrs = 0;
- auto outputStats = TOptimizerStatistics(100000, 5, 0.0);
+ const TExprNode* path;
+
+ if ( auto readTable = inputNode.Maybe<TKqlReadTableBase>()){
+ path = readTable.Cast().Table().Path().Raw();
+ nAttrs = readTable.Cast().Columns().Size();
+ } else if(auto readRanges = inputNode.Maybe<TKqlReadTableRangesBase>()){
+ path = readRanges.Cast().Table().Path().Raw();
+ nAttrs = readRanges.Cast().Columns().Size();
+ } else {
+ Y_ENSURE(false,"Invalid node type for InferStatisticsForReadTable");
+ }
+
+ const auto& tableData = kqpCtx.Tables->ExistingTable(kqpCtx.Cluster, path->Content());
+ nRows = tableData.Metadata->RecordsCount;
+ YQL_CLOG(TRACE, CoreDq) << "Infer statistics for read table, nrows:" << nRows << ", nattrs: " << nAttrs;
+
+ auto outputStats = TOptimizerStatistics(nRows, nAttrs, 0.0);
typeCtx->SetStats( input.Get(), std::make_shared<TOptimizerStatistics>(outputStats) );
}
@@ -48,16 +68,25 @@ IGraphTransformer::TStatus TKqpStatisticsTransformer::DoTransform(TExprNode::TPt
auto output = input;
if (TCoFlatMap::Match(input.Get())){
- NDq::InferStatisticsForFlatMap(input, typeCtx);
+ InferStatisticsForFlatMap(input, TypeCtx);
}
else if(TCoSkipNullMembers::Match(input.Get())){
- NDq::InferStatisticsForSkipNullMembers(input, typeCtx);
+ InferStatisticsForSkipNullMembers(input, TypeCtx);
+ }
+ else if(TCoExtractMembers::Match(input.Get())){
+ InferStatisticsForExtractMembers(input, TypeCtx);
+ }
+ else if(TCoAggregateCombine::Match(input.Get())){
+ InferStatisticsForAggregateCombine(input, TypeCtx);
+ }
+ else if(TCoAggregateMergeFinalize::Match(input.Get())){
+ InferStatisticsForAggregateMergeFinalize(input, TypeCtx);
}
else if(TKqlReadTableBase::Match(input.Get()) || TKqlReadTableRangesBase::Match(input.Get())){
- InferStatisticsForReadTable(input, typeCtx);
+ InferStatisticsForReadTable(input, TypeCtx, KqpCtx);
}
else if(TKqlLookupTableBase::Match(input.Get()) || TKqlLookupIndexBase::Match(input.Get())){
- InferStatisticsForIndexLookup(input, typeCtx);
+ InferStatisticsForIndexLookup(input, TypeCtx);
}
return output;
@@ -66,8 +95,8 @@ IGraphTransformer::TStatus TKqpStatisticsTransformer::DoTransform(TExprNode::TPt
return ret;
}
-TAutoPtr<IGraphTransformer> NKikimr::NKqp::CreateKqpStatisticsTransformer(TTypeAnnotationContext& typeCtx,
- const TKikimrConfiguration::TPtr& config) {
+TAutoPtr<IGraphTransformer> NKikimr::NKqp::CreateKqpStatisticsTransformer(const TIntrusivePtr<TKqpOptimizeContext>& kqpCtx,
+ TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config) {
- return THolder<IGraphTransformer>(new TKqpStatisticsTransformer(typeCtx, config));
+ return THolder<IGraphTransformer>(new TKqpStatisticsTransformer(kqpCtx, typeCtx, config));
}
diff --git a/ydb/core/kqp/opt/kqp_statistics_transformer.h b/ydb/core/kqp/opt/kqp_statistics_transformer.h
index 78d982f9981..d1715e461e0 100644
--- a/ydb/core/kqp/opt/kqp_statistics_transformer.h
+++ b/ydb/core/kqp/opt/kqp_statistics_transformer.h
@@ -1,5 +1,7 @@
#pragma once
+#include "kqp_opt.h"
+
#include <ydb/library/yql/core/yql_statistics.h>
#include <ydb/core/kqp/common/kqp_yql.h>
@@ -14,6 +16,7 @@ namespace NKqp {
using namespace NYql;
using namespace NYql::NNodes;
+using namespace NOpt;
/***
* Statistics transformer is a transformer that propagates statistics and costs from
@@ -23,12 +26,16 @@ using namespace NYql::NNodes;
*/
class TKqpStatisticsTransformer : public TSyncTransformerBase {
- TTypeAnnotationContext* typeCtx;
+ TTypeAnnotationContext* TypeCtx;
const TKikimrConfiguration::TPtr& Config;
+ const TKqpOptimizeContext& KqpCtx;
public:
- TKqpStatisticsTransformer(TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config) :
- typeCtx(&typeCtx), Config(config) {}
+ TKqpStatisticsTransformer(const TIntrusivePtr<TKqpOptimizeContext>& kqpCtx, TTypeAnnotationContext& typeCtx,
+ const TKikimrConfiguration::TPtr& config) :
+ TypeCtx(&typeCtx),
+ Config(config),
+ KqpCtx(*kqpCtx) {}
// Main method of the transformer
IGraphTransformer::TStatus DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) final;
@@ -39,7 +46,7 @@ class TKqpStatisticsTransformer : public TSyncTransformerBase {
}
};
-TAutoPtr<IGraphTransformer> CreateKqpStatisticsTransformer(TTypeAnnotationContext& typeCtx,
- const TKikimrConfiguration::TPtr& config);
+TAutoPtr<IGraphTransformer> CreateKqpStatisticsTransformer(const TIntrusivePtr<TKqpOptimizeContext>& kqpCtx,
+ TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config);
}
}
diff --git a/ydb/core/kqp/provider/yql_kikimr_gateway_ut.cpp b/ydb/core/kqp/provider/yql_kikimr_gateway_ut.cpp
index 3a3ce562bf5..750c979e07a 100644
--- a/ydb/core/kqp/provider/yql_kikimr_gateway_ut.cpp
+++ b/ydb/core/kqp/provider/yql_kikimr_gateway_ut.cpp
@@ -73,7 +73,7 @@ TIntrusivePtr<IKqpGateway> GetIcGateway(Tests::TServer& server) {
counters->Counters = new TKqpCounters(server.GetRuntime()->GetAppData(0).Counters);
counters->TxProxyMon = new NTxProxy::TTxProxyMon(server.GetRuntime()->GetAppData(0).Counters);
- std::shared_ptr<NYql::IKikimrGateway::IKqpTableMetadataLoader> loader = std::make_shared<TKqpTableMetadataLoader>(server.GetRuntime()->GetAnyNodeActorSystem(), false);
+ std::shared_ptr<NYql::IKikimrGateway::IKqpTableMetadataLoader> loader = std::make_shared<TKqpTableMetadataLoader>(server.GetRuntime()->GetAnyNodeActorSystem(),TIntrusivePtr<NYql::TKikimrConfiguration>(nullptr), false);
return CreateKikimrIcGateway(TestCluster, "/Root", std::move(loader), server.GetRuntime()->GetAnyNodeActorSystem(),
server.GetRuntime()->GetNodeId(0), counters);
}
diff --git a/ydb/core/kqp/session_actor/kqp_worker_actor.cpp b/ydb/core/kqp/session_actor/kqp_worker_actor.cpp
index ebe9ac03358..e17ab10cdc4 100644
--- a/ydb/core/kqp/session_actor/kqp_worker_actor.cpp
+++ b/ydb/core/kqp/session_actor/kqp_worker_actor.cpp
@@ -133,7 +133,8 @@ public:
LOG_D("Worker bootstrapped");
Counters->ReportWorkerCreated(Settings.DbCounters);
- std::shared_ptr<NYql::IKikimrGateway::IKqpTableMetadataLoader> loader = std::make_shared<TKqpTableMetadataLoader>(TlsActivationContext->ActorSystem(), false);
+ std::shared_ptr<NYql::IKikimrGateway::IKqpTableMetadataLoader> loader = std::make_shared<TKqpTableMetadataLoader>(
+ TlsActivationContext->ActorSystem(), Config, false);
Gateway = CreateKikimrIcGateway(Settings.Cluster, Settings.Database, std::move(loader),
ctx.ExecutorThread.ActorSystem, ctx.SelfID.NodeId(), RequestCounters);
diff --git a/ydb/core/kqp/ut/indexes/kqp_indexes_ut.cpp b/ydb/core/kqp/ut/indexes/kqp_indexes_ut.cpp
index 0b5f577a7d6..16d0ca327b0 100644
--- a/ydb/core/kqp/ut/indexes/kqp_indexes_ut.cpp
+++ b/ydb/core/kqp/ut/indexes/kqp_indexes_ut.cpp
@@ -34,7 +34,7 @@ TIntrusivePtr<NKqp::IKqpGateway> GetIcGateway(Tests::TServer& server) {
auto counters = MakeIntrusive<TKqpRequestCounters>();
counters->Counters = new TKqpCounters(server.GetRuntime()->GetAppData(0).Counters);
counters->TxProxyMon = new NTxProxy::TTxProxyMon(server.GetRuntime()->GetAppData(0).Counters);
- std::shared_ptr<NYql::IKikimrGateway::IKqpTableMetadataLoader> loader = std::make_shared<TKqpTableMetadataLoader>(server.GetRuntime()->GetAnyNodeActorSystem(), false);
+ std::shared_ptr<NYql::IKikimrGateway::IKqpTableMetadataLoader> loader = std::make_shared<TKqpTableMetadataLoader>(server.GetRuntime()->GetAnyNodeActorSystem(),TIntrusivePtr<NYql::TKikimrConfiguration>(nullptr),false);
return NKqp::CreateKikimrIcGateway(TestCluster, "/Root", std::move(loader), server.GetRuntime()->GetAnyNodeActorSystem(),
server.GetRuntime()->GetNodeId(0), counters);
}
diff --git a/ydb/library/yql/core/yql_statistics.cpp b/ydb/library/yql/core/yql_statistics.cpp
index 8abb9a59b5d..ad3968f0fee 100644
--- a/ydb/library/yql/core/yql_statistics.cpp
+++ b/ydb/library/yql/core/yql_statistics.cpp
@@ -2,9 +2,9 @@
using namespace NYql;
-std::ostream& operator<<(std::ostream& os, const TOptimizerStatistics& s) {
+std::ostream& NYql::operator<<(std::ostream& os, const TOptimizerStatistics& s) {
os << "Nrows: " << s.Nrows << ", Ncols: " << s.Ncols;
- os << "Cost: ";
+ os << ", Cost: ";
if (s.Cost.has_value()){
os << s.Cost.value();
} else {
diff --git a/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp b/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
index 03d328ae202..867e6fd977d 100644
--- a/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp
@@ -95,243 +95,16 @@ struct TEdge {
};
/**
- * Graph is a data structure for the join graph
- * It is an undirected graph, with two edges per connection (from,to) and (to,from)
- * It needs to be constructed with addNode and addEdge methods, since its
- * keeping various indexes updated.
- * The graph also needs to be reordered with the breadth-first search method, and
- * the reordering is recorded in bfsMapping (original rel indexes can be recovered from
- * this mapping)
-*/
-template <int N>
-struct TGraph {
- // set of edges of the graph
- std::unordered_set<TEdge,TEdge::HashFunction> Edges;
-
- // neightborgh index
- TVector<std::bitset<N>> EdgeIdx;
-
- // number of nodes in a graph
- int NNodes;
-
- // mapping from rel label to node in the graph
- THashMap<TString,int> ScopeMapping;
-
- // mapping from node in the graph to rel label
- TVector<TString> RevScopeMapping;
-
- // Breadth-first-search mapping
- TVector<int> BfsMapping;
-
- // Empty graph constructor intializes indexes to size N
- TGraph() : EdgeIdx(N), RevScopeMapping(N) {}
-
- // Add a node to a graph with a rel label
- void AddNode(int nodeId, TString scope){
- NNodes = nodeId + 1;
- ScopeMapping[scope] = nodeId;
- RevScopeMapping[nodeId] = scope;
- }
-
- // Add an edge to the graph, if the edge is already in the graph
- // (we check both directions), no action is taken. Otherwise we
- // insert two edges, the forward edge with original joinConditions
- // and a reverse edge with swapped joinConditions
- void AddEdge(TEdge e){
- if (Edges.contains(e) || Edges.contains(TEdge(e.To, e.From))) {
- return;
- }
-
- Edges.insert(e);
- std::set<std::pair<TJoinColumn, TJoinColumn>> swappedSet;
- for (auto c : e.JoinConditions){
- swappedSet.insert(std::make_pair(c.second, c.first));
- }
- Edges.insert(TEdge(e.To,e.From,swappedSet));
-
- EdgeIdx[e.From].set(e.To);
- EdgeIdx[e.To].set(e.From);
- }
-
- // Find a node by the rel scope
- int FindNode(TString scope){
- return ScopeMapping[scope];
- }
-
- // Return a bitset of node's neighbors
- inline std::bitset<N> FindNeighbors(int fromVertex)
- {
- return EdgeIdx[fromVertex];
- }
-
- // Return a bitset of node's neigbors with itself included
- inline std::bitset<N> FindNeighborsWithSelf(int fromVertex)
- {
- std::bitset<N> res = FindNeighbors(fromVertex);
- res.set(fromVertex);
- return res;
- }
-
- // Find an edge that connects two subsets of graph's nodes
- // We are guaranteed to find a match
- TEdge FindCrossingEdge(const std::bitset<N>& S1, const std::bitset<N>& S2) {
- for(int i=0; i<NNodes; i++){
- if (!S1[i]) {
- continue;
- }
- for (int j=0; j<NNodes; j++) {
- if (!S2[j]) {
- continue;
- }
- if ((FindNeighborsWithSelf(i) & FindNeighborsWithSelf(j)) != 0) {
- auto it = Edges.find(TEdge(i, j));
- Y_VERIFY_DEBUG(it != Edges.end());
- return *it;
- }
- }
- }
- Y_ENSURE(false,"Connecting edge not found!");
- return TEdge(-1,-1);
- }
-
- /**
- * Create a union-set from the join conditions to record the equivalences.
- * Then use the equivalence set to compute transitive closure of the graph.
- * Transitive closure means that if we have an edge from (1,2) with join
- * condition R.A = S.A and we have an edge from (2,3) with join condition
- * S.A = T.A, we will find out that the join conditions form an equivalence set
- * and add an edge (1,3) with join condition R.A = T.A.
- */
- void ComputeTransitiveClosure(const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions) {
- std::set<TJoinColumn> columnSet;
- for (auto [ leftCondition, rightCondition ] : joinConditions) {
- columnSet.insert( leftCondition );
- columnSet.insert( rightCondition );
- }
- std::vector<TJoinColumn> columns;
- for (auto c : columnSet ) {
- columns.push_back(c);
- }
-
- THashMap<TJoinColumn, int, TJoinColumn::HashFunction> indexMapping;
- for (size_t i=0; i<columns.size(); i++) {
- indexMapping[ columns[i] ] = i;
- }
-
- TDisjointSets ds = TDisjointSets( columns.size() );
- for (auto [ leftCondition, rightCondition ] : joinConditions ) {
- int leftIndex = indexMapping[ leftCondition ];
- int rightIndex = indexMapping[ rightCondition ];
- ds.UnionSets(leftIndex,rightIndex);
- }
-
- for (size_t i=0;i<columns.size();i++) {
- for (size_t j=0;j<i;j++) {
- if (ds.CanonicSetElement(i) == ds.CanonicSetElement(j)) {
- TJoinColumn left = columns[i];
- TJoinColumn right = columns[j];
- int leftNodeId = ScopeMapping[ left.RelName ];
- int rightNodeId = ScopeMapping[ right.RelName ];
-
- if (! Edges.contains(TEdge(leftNodeId,rightNodeId)) &&
- ! Edges.contains(TEdge(rightNodeId,leftNodeId))) {
- AddEdge(TEdge(leftNodeId,rightNodeId,std::make_pair(left, right)));
- } else {
- TEdge e1 = *Edges.find(TEdge(leftNodeId,rightNodeId));
- if (!e1.JoinConditions.contains(std::make_pair(left, right))) {
- e1.JoinConditions.insert(std::make_pair(left, right));
- }
-
- TEdge e2 = *Edges.find(TEdge(rightNodeId,leftNodeId));
- if (!e2.JoinConditions.contains(std::make_pair(right, left))) {
- e2.JoinConditions.insert(std::make_pair(right, left));
- }
- }
- }
- }
- }
- }
-
- /**
- * Reorder the graph by doing a breadth first search from node 0.
- * This is required by the DPccp algorithm.
- */
- TGraph<N> BfsReorder() {
- std::set<int> visited;
- std::queue<int> queue;
- TVector<int> bfsMapping(NNodes);
-
- queue.push(0);
- int lastVisited = 0;
-
- while(! queue.empty()) {
- int curr = queue.front();
- queue.pop();
- if (visited.contains(curr)) {
- continue;
- }
- bfsMapping[curr] = lastVisited;
- lastVisited++;
- visited.insert(curr);
-
- std::bitset<N> neighbors = FindNeighbors(curr);
- for (int i=0; i<NNodes; i++ ) {
- if (neighbors[i]){
- if (!visited.contains(i)) {
- queue.push(i);
- }
- }
- }
- }
-
- TGraph<N> res;
-
- for (int i=0;i<NNodes;i++) {
- res.AddNode(i,RevScopeMapping[bfsMapping[i]]);
- }
-
- for (const TEdge& e : Edges){
- res.AddEdge( TEdge(bfsMapping[e.From], bfsMapping[e.To], e.JoinConditions) );
- }
-
- res.BfsMapping = bfsMapping;
-
- return res;
- }
-
- /**
- * Print the graph
- */
- void PrintGraph(std::stringstream& stream) {
- stream << "Join Graph:\n";
- stream << "nNodes: " << NNodes << ", nEdges: " << Edges.size() << "\n";
-
- for(int i=0;i<NNodes;i++) {
- stream << "Node:" << i << "," << RevScopeMapping[i] << "\n";
- }
- for (const TEdge& e: Edges ) {
- stream << "Edge: " << e.From << " -> " << e.To << "\n";
- for (auto p : e.JoinConditions) {
- stream << p.first.RelName << "."
- << p.first.AttributeName << "="
- << p.second.RelName << "."
- << p.second.AttributeName << "\n";
- }
- }
- }
-};
-
-/**
* Fetch join conditions from the equi-join tree
*/
void ComputeJoinConditions(const TCoEquiJoinTuple& joinTuple,
std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions) {
if (joinTuple.LeftScope().Maybe<TCoEquiJoinTuple>()) {
- ComputeJoinConditions( joinTuple.LeftScope().Cast<TCoEquiJoinTuple>(), joinConditions );
+ ComputeJoinConditions(joinTuple.LeftScope().Cast<TCoEquiJoinTuple>(), joinConditions);
}
if (joinTuple.RightScope().Maybe<TCoEquiJoinTuple>()) {
- ComputeJoinConditions( joinTuple.RightScope().Cast<TCoEquiJoinTuple>(), joinConditions );
+ ComputeJoinConditions(joinTuple.RightScope().Cast<TCoEquiJoinTuple>(), joinConditions);
}
size_t joinKeysCount = joinTuple.LeftKeys().Size() / 2;
@@ -388,15 +161,15 @@ struct TRelOptimizerNode : public IBaseOptimizerNode {
virtual ~TRelOptimizerNode() {}
virtual void Print(std::stringstream& stream, int ntabs=0) {
- for (int i=0;i<ntabs;i++){
+ for (int i = 0; i < ntabs; i++){
stream << "\t";
}
stream << "Rel: " << Label << "\n";
- for (int i=0;i<ntabs;i++){
+ for (int i = 0; i < ntabs; i++){
stream << "\t";
}
- stream << Stats << "\n";
+ stream << *Stats << "\n";
}
};
@@ -443,7 +216,7 @@ struct TJoinOptimizerNode : public IBaseOptimizerNode {
* Print out the join tree, rooted at this node
*/
virtual void Print(std::stringstream& stream, int ntabs=0) {
- for (int i=0;i<ntabs;i++){
+ for (int i = 0; i < ntabs; i++){
stream << "\t";
}
@@ -455,11 +228,11 @@ struct TJoinOptimizerNode : public IBaseOptimizerNode {
}
stream << "\n";
- for (int i=0;i<ntabs;i++){
+ for (int i = 0; i < ntabs; i++){
stream << "\t";
}
- stream << Stats << "\n";
+ stream << *Stats << "\n";
LeftArg->Print(stream, ntabs+1);
RightArg->Print(stream, ntabs+1);
@@ -492,6 +265,173 @@ struct pair_hash {
};
/**
+ * Graph is a data structure for the join graph
+ * It is an undirected graph, with two edges per connection (from,to) and (to,from)
+ * It needs to be constructed with addNode and addEdge methods, since its
+ * keeping various indexes updated.
+ * The graph also needs to be reordered with the breadth-first search method
+*/
+template <int N>
+struct TGraph {
+ // set of edges of the graph
+ std::unordered_set<TEdge,TEdge::HashFunction> Edges;
+
+ // neightborgh index
+ TVector<std::bitset<N>> EdgeIdx;
+
+ // number of nodes in a graph
+ int NNodes;
+
+ // mapping from rel label to node in the graph
+ THashMap<TString,int> ScopeMapping;
+
+ // mapping from node in the graph to rel label
+ TVector<TString> RevScopeMapping;
+
+ // Empty graph constructor intializes indexes to size N
+ TGraph() : EdgeIdx(N), RevScopeMapping(N) {}
+
+ // Add a node to a graph with a rel label
+ void AddNode(int nodeId, TString scope){
+ NNodes = nodeId + 1;
+ ScopeMapping[scope] = nodeId;
+ RevScopeMapping[nodeId] = scope;
+ }
+
+ // Add an edge to the graph, if the edge is already in the graph
+ // (we check both directions), no action is taken. Otherwise we
+ // insert two edges, the forward edge with original joinConditions
+ // and a reverse edge with swapped joinConditions
+ void AddEdge(TEdge e){
+ if (Edges.contains(e) || Edges.contains(TEdge(e.To, e.From))) {
+ return;
+ }
+
+ Edges.insert(e);
+ std::set<std::pair<TJoinColumn, TJoinColumn>> swappedSet;
+ for (auto c : e.JoinConditions){
+ swappedSet.insert(std::make_pair(c.second, c.first));
+ }
+ Edges.insert(TEdge(e.To,e.From,swappedSet));
+
+ EdgeIdx[e.From].set(e.To);
+ EdgeIdx[e.To].set(e.From);
+ }
+
+ // Find a node by the rel scope
+ int FindNode(TString scope){
+ return ScopeMapping[scope];
+ }
+
+ // Return a bitset of node's neighbors
+ inline std::bitset<N> FindNeighbors(int fromVertex)
+ {
+ return EdgeIdx[fromVertex];
+ }
+
+ // Find an edge that connects two subsets of graph's nodes
+ // We are guaranteed to find a match
+ TEdge FindCrossingEdge(const std::bitset<N>& S1, const std::bitset<N>& S2) {
+ for(int i = 0; i < NNodes; i++){
+ if (!S1[i]) {
+ continue;
+ }
+ for (int j = 0; j < NNodes; j++) {
+ if (!S2[j]) {
+ continue;
+ }
+ if (EdgeIdx[i].test(j)) {
+ auto it = Edges.find(TEdge(i, j));
+ Y_VERIFY_DEBUG(it != Edges.end());
+ return *it;
+ }
+ }
+ }
+ Y_ENSURE(false,"Connecting edge not found!");
+ return TEdge(-1,-1);
+ }
+
+ /**
+ * Create a union-set from the join conditions to record the equivalences.
+ * Then use the equivalence set to compute transitive closure of the graph.
+ * Transitive closure means that if we have an edge from (1,2) with join
+ * condition R.A = S.A and we have an edge from (2,3) with join condition
+ * S.A = T.A, we will find out that the join conditions form an equivalence set
+ * and add an edge (1,3) with join condition R.A = T.A.
+ */
+ void ComputeTransitiveClosure(const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions) {
+ std::set<TJoinColumn> columnSet;
+ for (auto [ leftCondition, rightCondition ] : joinConditions) {
+ columnSet.insert(leftCondition);
+ columnSet.insert(rightCondition);
+ }
+ std::vector<TJoinColumn> columns;
+ for (auto c : columnSet ) {
+ columns.push_back(c);
+ }
+
+ THashMap<TJoinColumn, int, TJoinColumn::HashFunction> indexMapping;
+ for (size_t i=0; i<columns.size(); i++) {
+ indexMapping[columns[i]] = i;
+ }
+
+ TDisjointSets ds = TDisjointSets( columns.size() );
+ for (auto [ leftCondition, rightCondition ] : joinConditions ) {
+ int leftIndex = indexMapping[leftCondition];
+ int rightIndex = indexMapping[rightCondition];
+ ds.UnionSets(leftIndex,rightIndex);
+ }
+
+ for (size_t i = 0; i < columns.size(); i++) {
+ for (size_t j = 0; j < i; j++) {
+ if (ds.CanonicSetElement(i) == ds.CanonicSetElement(j)) {
+ TJoinColumn left = columns[i];
+ TJoinColumn right = columns[j];
+ int leftNodeId = ScopeMapping[left.RelName];
+ int rightNodeId = ScopeMapping[right.RelName];
+
+ if (! Edges.contains(TEdge(leftNodeId,rightNodeId)) &&
+ ! Edges.contains(TEdge(rightNodeId,leftNodeId))) {
+ AddEdge(TEdge(leftNodeId,rightNodeId,std::make_pair(left, right)));
+ } else {
+ TEdge e1 = *Edges.find(TEdge(leftNodeId,rightNodeId));
+ if (!e1.JoinConditions.contains(std::make_pair(left, right))) {
+ e1.JoinConditions.insert(std::make_pair(left, right));
+ }
+
+ TEdge e2 = *Edges.find(TEdge(rightNodeId,leftNodeId));
+ if (!e2.JoinConditions.contains(std::make_pair(right, left))) {
+ e2.JoinConditions.insert(std::make_pair(right, left));
+ }
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Print the graph
+ */
+ void PrintGraph(std::stringstream& stream) {
+ stream << "Join Graph:\n";
+ stream << "nNodes: " << NNodes << ", nEdges: " << Edges.size() << "\n";
+
+ for(int i = 0; i < NNodes; i++) {
+ stream << "Node:" << i << "," << RevScopeMapping[i] << "\n";
+ }
+ for (const TEdge& e: Edges ) {
+ stream << "Edge: " << e.From << " -> " << e.To << "\n";
+ for (auto p : e.JoinConditions) {
+ stream << p.first.RelName << "."
+ << p.first.AttributeName << "="
+ << p.second.RelName << "."
+ << p.second.AttributeName << "\n";
+ }
+ }
+ }
+};
+
+/**
* DPcpp (Dynamic Programming with connected complement pairs) is a graph-aware
* join eumeration algorithm that only considers CSGs (Connected Sub-Graphs) of
* the join graph and computes CMPs (Complement pairs) that are also connected
@@ -508,7 +448,7 @@ class TDPccpSolver {
public:
// Construct the DPccp solver based on the join graph and data about input relations
- TDPccpSolver(TGraph<N>& g, std::vector<std::shared_ptr<TRelOptimizerNode>> rels):
+ TDPccpSolver(TGraph<N>& g, TVector<std::shared_ptr<TRelOptimizerNode>> rels):
Graph(g), Rels(rels) {
NNodes = g.NNodes;
}
@@ -537,7 +477,7 @@ class TDPccpSolver {
TGraph<N>& Graph;
// List of input relations to DPccp
- std::vector<std::shared_ptr<TRelOptimizerNode>> Rels;
+ TVector<std::shared_ptr<TRelOptimizerNode>> Rels;
// Emit connected subgraph
void EmitCsg(const std::bitset<N>&, int=0);
@@ -567,7 +507,7 @@ class TDPccpSolver {
// Print tabs
void PrintTabs(std::stringstream& stream, int ntabs) {
- for (int i=0;i<ntabs;i++)
+ for (int i = 0; i < ntabs; i++)
stream << "\t";
}
@@ -578,7 +518,7 @@ template <int N> void TDPccpSolver<N>::PrintBitset(std::stringstream& stream,
PrintTabs(stream, ntabs);
stream << name << ": " << "{";
- for (int i=0;i<NNodes;i++)
+ for (int i = 0; i < NNodes; i++)
if (s[i])
stream << i << ",";
@@ -590,7 +530,7 @@ template<int N> std::bitset<N> TDPccpSolver<N>::Neighbors(const std::bitset<N>&
std::bitset<N> res;
- for (int i=0;i<Graph.NNodes;i++) {
+ for (int i = 0; i < Graph.NNodes; i++) {
if (S[i]) {
std::bitset<N> n = Graph.FindNeighbors(i);
res = res | n;
@@ -605,14 +545,14 @@ template<int N> std::bitset<N> TDPccpSolver<N>::Neighbors(const std::bitset<N>&
template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve()
{
// Process singleton sets
- for (int i=NNodes-1;i>=0;i--) {
+ for (int i = NNodes-1; i >= 0; i--) {
std::bitset<N> s;
s.set(i);
- DpTable[s] = Rels[Graph.BfsMapping[i]];
+ DpTable[s] = Rels[i];
}
// Expand singleton sets
- for (int i=NNodes-1;i>=0;i--) {
+ for (int i = NNodes-1; i >= 0; i--) {
std::bitset<N> s;
s.set(i);
EmitCsg(s);
@@ -622,7 +562,7 @@ template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve()
// Return the entry of the dpTable that corresponds to the full
// set of nodes in the graph
std::bitset<N> V;
- for (int i=0;i<NNodes;i++) {
+ for (int i = 0; i < NNodes; i++) {
V.set(i);
}
@@ -643,12 +583,12 @@ template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve()
return;
}
- for (int i=NNodes-1;i>=0;i--) {
+ for (int i = NNodes - 1; i >= 0; i--) {
if (Ns[i]) {
std::bitset<N> S2;
S2.set(i);
- EmitCsgCmp(S,S2,ntabs+1);
- EnumerateCmpRec(S, S2, X|MakeB(Ns,i), ntabs+1);
+ EmitCsgCmp(S, S2, ntabs+1);
+ EnumerateCmpRec(S, S2, X | MakeB(Ns, i), ntabs+1);
}
}
}
@@ -660,7 +600,7 @@ template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve()
*/
template <int N> void TDPccpSolver<N>::EnumerateCsgRec(const std::bitset<N>& S, const std::bitset<N>& X, int ntabs) {
- std::bitset<N> Ns = Neighbors(S,X);
+ std::bitset<N> Ns = Neighbors(S, X);
if (Ns == std::bitset<N>()) {
return;
@@ -670,7 +610,7 @@ template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve()
std::bitset<N> next;
while(true) {
- next = NextBitset(prev,Ns);
+ next = NextBitset(prev, Ns);
EmitCsg(S | next );
if (next == Ns) {
break;
@@ -680,7 +620,7 @@ template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve()
prev.reset();
while(true) {
- next = NextBitset(prev,Ns);
+ next = NextBitset(prev, Ns);
EnumerateCsgRec(S | next, X | Ns , ntabs+1);
if (next==Ns) {
break;
@@ -709,7 +649,7 @@ template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve()
while(true) {
next = NextBitset(prev, Ns);
- EmitCsgCmp(S1,S2|next, ntabs+1);
+ EmitCsgCmp(S1, S2 | next, ntabs+1);
if (next==Ns) {
break;
}
@@ -719,7 +659,7 @@ template<int N> std::shared_ptr<TJoinOptimizerNode> TDPccpSolver<N>::Solve()
prev.reset();
while(true) {
next = NextBitset(prev, Ns);
- EnumerateCmpRec(S1, S2|next, X|Ns, ntabs+1);
+ EnumerateCmpRec(S1, S2 | next, X | Ns, ntabs+1);
if (next==Ns) {
break;
}
@@ -778,9 +718,9 @@ template <int N> void TDPccpSolver<N>::EmitCsgCmp(const std::bitset<N>& S1, cons
template <int N> std::bitset<N> TDPccpSolver<N>::MakeBiMin(const std::bitset<N>& S) {
std::bitset<N> res;
- for (int i=0;i<NNodes; i++) {
+ for (int i = 0; i < NNodes; i++) {
if (S[i]) {
- for (int j=0; j<=i; j++) {
+ for (int j = 0; j <= i; j++) {
res.set(j);
}
break;
@@ -796,8 +736,8 @@ template <int N> std::bitset<N> TDPccpSolver<N>::MakeBiMin(const std::bitset<N>&
template <int N> std::bitset<N> TDPccpSolver<N>::MakeB(const std::bitset<N>& S, int x) {
std::bitset<N> res;
- for (int i=0; i<NNodes; i++) {
- if (S[i] && i<=x) {
+ for (int i = 0; i < NNodes; i++) {
+ if (S[i] && i <= x) {
res.set(i);
}
}
@@ -815,7 +755,7 @@ template <int N> std::bitset<N> TDPccpSolver<N>::NextBitset(const std::bitset<N>
std::bitset<N> res = prev;
bool carry = true;
- for (int i=0; i<NNodes; i++)
+ for (int i = 0; i < NNodes; i++)
{
if (!carry) {
break;
@@ -909,13 +849,13 @@ TExprBase BuildTree(TExprContext& ctx, const TCoEquiJoin& equiJoin,
TExprBase RearrangeEquiJoinTree(TExprContext& ctx, const TCoEquiJoin& equiJoin,
std::shared_ptr<TJoinOptimizerNode> reorderResult) {
TVector<TExprBase> joinArgs;
- for (size_t i=0; i<equiJoin.ArgCount() - 2; i++){
+ for (size_t i = 0; i < equiJoin.ArgCount() - 2; i++){
joinArgs.push_back(equiJoin.Arg(i));
}
joinArgs.push_back(BuildTree(ctx,equiJoin,reorderResult));
- joinArgs.push_back(equiJoin.Arg(equiJoin.ArgCount()-1));
+ joinArgs.push_back(equiJoin.Arg(equiJoin.ArgCount() - 1));
return Build<TCoEquiJoin>(ctx, equiJoin.Pos())
.Add(joinArgs)
@@ -968,12 +908,12 @@ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx,
YQL_ENSURE(equiJoin.ArgCount() >= 4);
if (typesCtx.StatisticsMap.contains(equiJoin.Raw()) &&
- typesCtx.StatisticsMap[ equiJoin.Raw()]->Cost.has_value()) {
+ typesCtx.StatisticsMap[equiJoin.Raw()]->Cost.has_value()) {
return node;
}
- if (! AllInnerJoins(equiJoin.Arg( equiJoin.ArgCount() - 2).Cast<TCoEquiJoinTuple>())) {
+ if (! AllInnerJoins(equiJoin.Arg(equiJoin.ArgCount() - 2).Cast<TCoEquiJoinTuple>())) {
return node;
}
@@ -988,11 +928,13 @@ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx,
auto input = equiJoin.Arg(i).Cast<TCoEquiJoinInput>();
auto joinArg = input.List();
- if (!typesCtx.StatisticsMap.contains( joinArg.Raw() )) {
+ if (!typesCtx.StatisticsMap.contains(joinArg.Raw())) {
+ YQL_CLOG(TRACE, CoreDq) << "Didn't find statistics for scope " << input.Scope().Cast<TCoAtom>().StringValue() << "\n";
+
return node;
}
- if (!typesCtx.StatisticsMap[ joinArg.Raw() ]->Cost.has_value()) {
+ if (!typesCtx.StatisticsMap[joinArg.Raw()]->Cost.has_value()) {
return node;
}
@@ -1002,8 +944,8 @@ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx,
}
auto label = scope.Cast<TCoAtom>().StringValue();
- auto stats = typesCtx.StatisticsMap[ joinArg.Raw() ];
- rels.push_back( std::shared_ptr<TRelOptimizerNode>( new TRelOptimizerNode( label, stats )));
+ auto stats = typesCtx.StatisticsMap[joinArg.Raw()];
+ rels.push_back( std::shared_ptr<TRelOptimizerNode>( new TRelOptimizerNode(label, stats)));
}
YQL_CLOG(TRACE, CoreDq) << "All statistics for join in place";
@@ -1011,18 +953,18 @@ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx,
std::set<std::pair<TJoinColumn, TJoinColumn>> joinConditions;
// EquiJoin argument n-2 is the actual join tree, represented as TCoEquiJoinTuple
- ComputeJoinConditions(equiJoin.Arg( equiJoin.ArgCount() - 2).Cast<TCoEquiJoinTuple>(), joinConditions );
+ ComputeJoinConditions(equiJoin.Arg(equiJoin.ArgCount() - 2).Cast<TCoEquiJoinTuple>(), joinConditions);
// construct a graph out of join conditions
TGraph<64> joinGraph;
- for (size_t i=0; i<rels.size(); i++) {
- joinGraph.AddNode(i,rels[i]->Label);
+ for (size_t i = 0; i < rels.size(); i++) {
+ joinGraph.AddNode(i, rels[i]->Label);
}
for (auto cond : joinConditions ) {
int fromNode = joinGraph.FindNode(cond.first.RelName);
int toNode = joinGraph.FindNode(cond.second.RelName);
- joinGraph.AddEdge(TEdge(fromNode,toNode,cond));
+ joinGraph.AddEdge(TEdge(fromNode, toNode, cond));
}
if (NYql::NLog::YqlLogger().NeedToLog(NYql::NLog::EComponent::ProviderKqp, NYql::NLog::ELevel::TRACE)) {
@@ -1042,8 +984,6 @@ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx,
YQL_CLOG(TRACE, CoreDq) << str.str();
}
- joinGraph = joinGraph.BfsReorder();
-
// feed the graph to DPccp algorithm
TDPccpSolver<64> solver(joinGraph,rels);
std::shared_ptr<TJoinOptimizerNode> result = solver.Solve();
@@ -1056,7 +996,7 @@ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx,
}
// rewrite the join tree and record the output statistics
- TExprBase res = RearrangeEquiJoinTree(ctx,equiJoin,result);
+ TExprBase res = RearrangeEquiJoinTree(ctx, equiJoin, result);
typesCtx.StatisticsMap[ res.Raw() ] = result->Stats;
return res;
}
diff --git a/ydb/library/yql/dq/opt/dq_opt_stat.cpp b/ydb/library/yql/dq/opt/dq_opt_stat.cpp
index fcd7b692562..6424e89478e 100644
--- a/ydb/library/yql/dq/opt/dq_opt_stat.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_stat.cpp
@@ -57,4 +57,61 @@ void InferStatisticsForSkipNullMembers(const TExprNode::TPtr& input, TTypeAnnota
typeCtx->SetCost( input.Get(), typeCtx->GetCost( skipNullMembersInput.Raw() ) );
}
+/**
+ * Infer statistics and costs for ExtractlMembers
+ * We just return the input statistics.
+*/
+void InferStatisticsForExtractMembers(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+
+ auto inputNode = TExprBase(input);
+ auto extractMembers = inputNode.Cast<TCoExtractMembers>();
+ auto extractMembersInput = extractMembers.Input();
+
+ auto inputStats = typeCtx->GetStats(extractMembersInput.Raw() );
+ if (!inputStats) {
+ return;
+ }
+
+ typeCtx->SetStats( input.Get(), inputStats );
+ typeCtx->SetCost( input.Get(), typeCtx->GetCost( extractMembersInput.Raw() ) );
+}
+
+/**
+ * Infer statistics and costs for AggregateCombine
+ * We just return the input statistics.
+*/
+void InferStatisticsForAggregateCombine(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+
+ auto inputNode = TExprBase(input);
+ auto agg = inputNode.Cast<TCoAggregateCombine>();
+ auto aggInput = agg.Input();
+
+ auto inputStats = typeCtx->GetStats(aggInput.Raw());
+ if (!inputStats) {
+ return;
+ }
+
+ typeCtx->SetStats( input.Get(), inputStats );
+ typeCtx->SetCost( input.Get(), typeCtx->GetCost( aggInput.Raw() ) );
+}
+
+/**
+ * Infer statistics and costs for AggregateMergeFinalize
+ * Just return input stats
+*/
+void InferStatisticsForAggregateMergeFinalize(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
+
+ auto inputNode = TExprBase(input);
+ auto agg = inputNode.Cast<TCoAggregateMergeFinalize>();
+ auto aggInput = agg.Input();
+
+ auto inputStats = typeCtx->GetStats(aggInput.Raw() );
+ if (!inputStats) {
+ return;
+ }
+
+ typeCtx->SetStats( input.Get(), inputStats );
+ typeCtx->SetCost( input.Get(), typeCtx->GetCost( aggInput.Raw() ) );
+}
+
} // namespace NYql::NDq {
diff --git a/ydb/library/yql/dq/opt/dq_opt_stat.h b/ydb/library/yql/dq/opt/dq_opt_stat.h
index c4ab54aff52..01c71771344 100644
--- a/ydb/library/yql/dq/opt/dq_opt_stat.h
+++ b/ydb/library/yql/dq/opt/dq_opt_stat.h
@@ -6,5 +6,8 @@ namespace NYql::NDq {
void InferStatisticsForFlatMap(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
void InferStatisticsForSkipNullMembers(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
+void InferStatisticsForExtractMembers(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
+void InferStatisticsForAggregateCombine(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
+void InferStatisticsForAggregateMergeFinalize(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx);
} // namespace NYql::NDq { \ No newline at end of file