aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoraneporada <aneporada@ydb.tech>2023-10-22 18:37:57 +0300
committeraneporada <aneporada@ydb.tech>2023-10-22 19:14:30 +0300
commit3dbd883b0904a0edcf65cf77cc00f257f1d419d9 (patch)
treef4297b8eff18d955e6e31ae1c42bc3a288cf1ece
parent1a21631ab8cf2a29afe6498d3bb6471c922860bc (diff)
downloadydb-3dbd883b0904a0edcf65cf77cc00f257f1d419d9.tar.gz
Fix memoization in EqualNodes/CompareNodes
-rw-r--r--ydb/library/yql/core/yql_expr_csee.cpp37
1 files changed, 29 insertions, 8 deletions
diff --git a/ydb/library/yql/core/yql_expr_csee.cpp b/ydb/library/yql/core/yql_expr_csee.cpp
index 2b9cd321228..e9a7662183f 100644
--- a/ydb/library/yql/core/yql_expr_csee.cpp
+++ b/ydb/library/yql/core/yql_expr_csee.cpp
@@ -220,17 +220,29 @@ namespace {
return hash;
}
+ using TEqualResults = THashMap<std::pair<const TExprNode*, const TExprNode*>, bool>;
+ bool DoEqualNodes(const TExprNode& left, TLambdaFrame& currLeftFrame, const TExprNode& right, TLambdaFrame& currRightFrame,
+ TEqualResults& visited, const TColumnOrderStorage& coStore);
bool EqualNodes(const TExprNode& left, TLambdaFrame& currLeftFrame, const TExprNode& right, TLambdaFrame& currRightFrame,
- TNodeSet& visited, const TColumnOrderStorage& coStore)
+ TEqualResults& visited, const TColumnOrderStorage& coStore)
{
if (&left == &right) {
return true;
}
- if (!visited.emplace(&left).second) {
- return true;
+ auto key = std::make_pair(&left, &right);
+ if (auto it = visited.find(key); it != visited.end()) {
+ return it->second;
}
+ bool res = DoEqualNodes(left, currLeftFrame, right, currRightFrame, visited, coStore);
+ visited[key] = res;
+ return res;
+ }
+
+ bool DoEqualNodes(const TExprNode& left, TLambdaFrame& currLeftFrame, const TExprNode& right, TLambdaFrame& currRightFrame,
+ TEqualResults& visited, const TColumnOrderStorage& coStore)
+ {
if (left.Type() != right.Type()) {
return false;
}
@@ -350,15 +362,24 @@ namespace {
return false;
}
- int CompareNodes(const TExprNode& left, const TExprNode& right, TNodeSet& visited) {
+ using TCompareResults = THashMap<std::pair<const TExprNode*, const TExprNode*>, int>;
+ int DoCompareNodes(const TExprNode& left, const TExprNode& right, TCompareResults& visited);
+ int CompareNodes(const TExprNode& left, const TExprNode& right, TCompareResults& visited) {
if (&left == &right) {
return 0;
}
- if (!visited.emplace(&left).second) {
- return 0;
+ auto key = std::make_pair(&left, &right);
+ if (auto it = visited.find(key); it != visited.end()) {
+ return it->second;
}
+ int res = DoCompareNodes(left, right, visited);
+ visited[key] = res;
+ return res;
+ }
+
+ int DoCompareNodes(const TExprNode& left, const TExprNode& right, TCompareResults& visited) {
if (left.Type() != right.Type()) {
return (int)left.Type() - (int)right.Type();
}
@@ -492,7 +513,7 @@ namespace {
}
bool EqualNodes(const TExprNode& left, const TExprNode& right, const TColumnOrderStorage& coStore) {
- TNodeSet visited;
+ TEqualResults visited;
TLambdaFrame frame;
return EqualNodes(left, frame, right, frame, visited, coStore);
}
@@ -605,7 +626,7 @@ IGraphTransformer::TStatus EliminateCommonSubExpressions(const TExprNode::TPtr&
}
int CompareNodes(const TExprNode& left, const TExprNode& right) {
- TNodeSet visited;
+ TCompareResults visited;
return CompareNodes(left, right, visited);
}