diff options
author | aneporada <aneporada@ydb.tech> | 2023-10-22 18:37:57 +0300 |
---|---|---|
committer | aneporada <aneporada@ydb.tech> | 2023-10-22 19:14:30 +0300 |
commit | 3dbd883b0904a0edcf65cf77cc00f257f1d419d9 (patch) | |
tree | f4297b8eff18d955e6e31ae1c42bc3a288cf1ece | |
parent | 1a21631ab8cf2a29afe6498d3bb6471c922860bc (diff) | |
download | ydb-3dbd883b0904a0edcf65cf77cc00f257f1d419d9.tar.gz |
Fix memoization in EqualNodes/CompareNodes
-rw-r--r-- | ydb/library/yql/core/yql_expr_csee.cpp | 37 |
1 files changed, 29 insertions, 8 deletions
diff --git a/ydb/library/yql/core/yql_expr_csee.cpp b/ydb/library/yql/core/yql_expr_csee.cpp index 2b9cd321228..e9a7662183f 100644 --- a/ydb/library/yql/core/yql_expr_csee.cpp +++ b/ydb/library/yql/core/yql_expr_csee.cpp @@ -220,17 +220,29 @@ namespace { return hash; } + using TEqualResults = THashMap<std::pair<const TExprNode*, const TExprNode*>, bool>; + bool DoEqualNodes(const TExprNode& left, TLambdaFrame& currLeftFrame, const TExprNode& right, TLambdaFrame& currRightFrame, + TEqualResults& visited, const TColumnOrderStorage& coStore); bool EqualNodes(const TExprNode& left, TLambdaFrame& currLeftFrame, const TExprNode& right, TLambdaFrame& currRightFrame, - TNodeSet& visited, const TColumnOrderStorage& coStore) + TEqualResults& visited, const TColumnOrderStorage& coStore) { if (&left == &right) { return true; } - if (!visited.emplace(&left).second) { - return true; + auto key = std::make_pair(&left, &right); + if (auto it = visited.find(key); it != visited.end()) { + return it->second; } + bool res = DoEqualNodes(left, currLeftFrame, right, currRightFrame, visited, coStore); + visited[key] = res; + return res; + } + + bool DoEqualNodes(const TExprNode& left, TLambdaFrame& currLeftFrame, const TExprNode& right, TLambdaFrame& currRightFrame, + TEqualResults& visited, const TColumnOrderStorage& coStore) + { if (left.Type() != right.Type()) { return false; } @@ -350,15 +362,24 @@ namespace { return false; } - int CompareNodes(const TExprNode& left, const TExprNode& right, TNodeSet& visited) { + using TCompareResults = THashMap<std::pair<const TExprNode*, const TExprNode*>, int>; + int DoCompareNodes(const TExprNode& left, const TExprNode& right, TCompareResults& visited); + int CompareNodes(const TExprNode& left, const TExprNode& right, TCompareResults& visited) { if (&left == &right) { return 0; } - if (!visited.emplace(&left).second) { - return 0; + auto key = std::make_pair(&left, &right); + if (auto it = visited.find(key); it != visited.end()) { + return it->second; } + int res = DoCompareNodes(left, right, visited); + visited[key] = res; + return res; + } + + int DoCompareNodes(const TExprNode& left, const TExprNode& right, TCompareResults& visited) { if (left.Type() != right.Type()) { return (int)left.Type() - (int)right.Type(); } @@ -492,7 +513,7 @@ namespace { } bool EqualNodes(const TExprNode& left, const TExprNode& right, const TColumnOrderStorage& coStore) { - TNodeSet visited; + TEqualResults visited; TLambdaFrame frame; return EqualNodes(left, frame, right, frame, visited, coStore); } @@ -605,7 +626,7 @@ IGraphTransformer::TStatus EliminateCommonSubExpressions(const TExprNode::TPtr& } int CompareNodes(const TExprNode& left, const TExprNode& right) { - TNodeSet visited; + TCompareResults visited; return CompareNodes(left, right, visited); } |