aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author avevad <avevad@yandex-team.com> 2023-11-03 17:15:12 +0300
committer avevad <avevad@yandex-team.com> 2023-11-03 17:52:14 +0300
commit db7511009fbab76c1baf85a729be9267aaf58911 (patch)
tree cd5a5a39771b9be8f02a2b454b1207e4ff34630d
parent 478827a455e8dd660a65cd23c01e796e667dce6c (diff)
download ydb-db7511009fbab76c1baf85a729be9267aaf58911.tar.gz
YQL-16823 Add new class for optimization of epsilon transitions in match_recognize NFA graph
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h 119
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp 71
2 files changed, 184 insertions, 6 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h
index ac89950729..a5918f17b4 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h
@@ -11,7 +11,8 @@ namespace NKikimr::NMiniKQL::NMatchRecognize {
using namespace NYql::NMatchRecognize;
-struct TVoidTransition{};
+struct TVoidTransition {
+};
using TEpsilonTransition = size_t; //to
using TEpsilonTransitions = std::vector<TEpsilonTransition, TMKQLAllocator<TEpsilonTransition>>;
using TMatchedVarTransition = std::pair<ui32, size_t>; //{varIndex, to}
@@ -25,6 +26,40 @@ using TNfaTransition = std::variant<
TQuantityExitTransition
>;
+// Visitor for std::visit over TNfaTransition.
+// Every destination node index stored in the visited transition is passed
+// through `callback` and replaced by the value it returns; the rewritten
+// transition is handed back as a TNfaTransition. Each overload receives its
+// alternative by value, so the visited variant itself is left untouched.
+struct TNfaTransitionDestinationVisitor {
+    // Function applied to each destination node index.
+    std::function<size_t(size_t)> callback;
+
+    template<typename Callback>
+    explicit TNfaTransitionDestinationVisitor(Callback callback)
+        : callback(std::move(callback))
+    {}
+
+    // No destination is touched here: the transition is returned as is.
+    TNfaTransition operator()(TVoidTransition tr) const {
+        return tr;
+    }
+
+    // {varIndex, to}: only the `to` part is rewritten.
+    TNfaTransition operator()(TMatchedVarTransition tr) const {
+        auto& destination = tr.second;
+        destination = callback(destination);
+        return tr;
+    }
+
+    // Destinations are rewritten strictly front to back: callers may pass
+    // callbacks with side effects that depend on the visiting order.
+    TNfaTransition operator()(TEpsilonTransitions tr) const {
+        for (auto it = tr.begin(); it != tr.end(); ++it) {
+            *it = callback(*it);
+        }
+        return tr;
+    }
+
+    // The transition value is itself passed to the callback as the destination.
+    TNfaTransition operator()(TQuantityEnterTransition tr) const {
+        return callback(tr);
+    }
+
+    // Both destination indices kept in tr.second are rewritten, first then second.
+    TNfaTransition operator()(TQuantityExitTransition tr) const {
+        auto& destinations = tr.second;
+        destinations.first = callback(destinations.first);
+        destinations.second = callback(destinations.second);
+        return tr;
+    }
+};
+
struct TNfaTransitionGraph {
std::vector<TNfaTransition, TMKQLAllocator<TNfaTransition>> Transitions;
size_t Input;
@@ -33,6 +68,80 @@ struct TNfaTransitionGraph {
using TPtr = std::shared_ptr<TNfaTransitionGraph>;
};
+// Post-processes a transition graph in place:
+//  1. every epsilon node gets its destination list replaced by the non-epsilon
+//     nodes reachable from it through epsilon transitions only;
+//  2. nodes that are no longer reachable are dropped and the remaining nodes
+//     are renumbered densely (Input becomes node 0).
+class TNfaTransitionGraphOptimizer {
+public:
+    TNfaTransitionGraphOptimizer(TNfaTransitionGraph::TPtr graph)
+        : Graph(std::move(graph)) {}
+
+    // Runs all optimizations; the order matters, since garbage appears only
+    // after epsilon chains have been short-circuited.
+    void DoOptimizations() {
+        EliminateEpsilonChains();
+        CollectGarbage();
+    }
+private:
+    // Replaces the destinations of each epsilon node with the set of
+    // non-epsilon nodes reachable from `node` via epsilon transitions.
+    void EliminateEpsilonChains() {
+        const size_t nodeCount = Graph->Transitions.size();
+        // Reused between iterations to avoid reallocating for every epsilon node.
+        std::vector<char> visited;
+        for (size_t node = 0; node != nodeCount; ++node) {
+            auto* ts = std::get_if<TEpsilonTransitions>(&Graph->Transitions[node]);
+            if (!ts) {
+                continue;
+            }
+            // Without the visited marks an epsilon cycle (or a self-loop) would
+            // never terminate, and diamond-shaped epsilon subgraphs would yield
+            // duplicated destinations.
+            visited.assign(nodeCount, 0);
+            visited[node] = 1;
+            TEpsilonTransitions optimizedTs;
+            auto dfsStack = *ts;
+            while (!dfsStack.empty()) {
+                const auto curNode = dfsStack.back();
+                dfsStack.pop_back();
+                if (visited[curNode]) {
+                    continue;
+                }
+                visited[curNode] = 1;
+                if (const auto* curTs = std::get_if<TEpsilonTransitions>(&Graph->Transitions[curNode])) {
+                    // Intermediate epsilon node: look through it.
+                    dfsStack.insert(dfsStack.end(), curTs->begin(), curTs->end());
+                } else {
+                    optimizedTs.push_back(curNode);
+                }
+            }
+            *ts = std::move(optimizedTs);
+        }
+    }
+
+    // Drops nodes unreachable from Input and renumbers the survivors.
+    // The Output node is always kept, so Graph->Output never refers to a
+    // removed node. If Input or Output is out of range the graph is left as is.
+    void CollectGarbage() {
+        const auto oldInput = Graph->Input;
+        const auto oldOutput = Graph->Output;
+        if (oldInput >= Graph->Transitions.size() || oldOutput >= Graph->Transitions.size()) {
+            return;
+        }
+        decltype(Graph->Transitions) oldTransitions;
+        Graph->Transitions.swap(oldTransitions);
+
+        // Scan for reachable nodes and map old node ids to new node ids.
+        // New ids are handed out in discovery order.
+        std::vector<std::optional<size_t>> mapping(oldTransitions.size(), std::nullopt);
+        std::vector<size_t> dfsStack;
+        size_t liveCount = 0;
+        const auto enqueue = [&](size_t oldNode) {
+            if (!mapping[oldNode]) {
+                mapping[oldNode] = liveCount++;
+                dfsStack.push_back(oldNode);
+            }
+        };
+        // Used only for its side effect; the transition it returns is discarded.
+        const TNfaTransitionDestinationVisitor discover([&](size_t oldToNode) -> size_t {
+            enqueue(oldToNode);
+            return oldToNode;
+        });
+        const auto markReachableFrom = [&](size_t root) {
+            enqueue(root);
+            while (!dfsStack.empty()) {
+                const auto oldNode = dfsStack.back();
+                dfsStack.pop_back();
+                std::visit(discover, oldTransitions[oldNode]);
+            }
+        };
+        markReachableFrom(oldInput);
+        // No-op when Output has already been reached from Input.
+        markReachableFrom(oldOutput);
+
+        // Rebuild transition vector
+        Graph->Transitions.resize(liveCount);
+        const TNfaTransitionDestinationVisitor remap([&mapping](size_t oldToNode) -> size_t {
+            return mapping[oldToNode].value();
+        });
+        for (size_t oldNode = 0; oldNode != oldTransitions.size(); ++oldNode) {
+            if (!mapping[oldNode]) {
+                continue;
+            }
+            // Each old transition is consumed exactly once, so it can be moved from.
+            Graph->Transitions[mapping[oldNode].value()] = std::visit(remap, std::move(oldTransitions[oldNode]));
+        }
+        Graph->Input = mapping[oldInput].value();
+        Graph->Output = mapping[oldOutput].value();
+    }
+
+    TNfaTransitionGraph::TPtr Graph;
+};
+
class TNfaTransitionGraphBuilder {
private:
struct TNfaItem {
@@ -44,7 +153,7 @@ private:
: Graph(graph) {}
size_t AddNode() {
- Graph->Transitions.resize(Graph->Transitions.size() + 1);
+ Graph->Transitions.emplace_back();
return Graph->Transitions.size() - 1;
}
@@ -114,6 +223,8 @@ public:
auto item = builder.BuildTerms(pattern, varNameToIndex);
result->Input = item.Input;
result->Output = item.Output;
+ TNfaTransitionGraphOptimizer optimizer(result);
+ optimizer.DoOptimizations();
return result;
}
private:
@@ -205,8 +316,7 @@ private:
TTransitionVisitor(const TState& state, TStateSet& newStates, TStateSet& deletedStates)
: State(state)
, NewStates(newStates)
- , DeletedStates(deletedStates)
- {}
+ , DeletedStates(deletedStates) {}
void operator()(const TVoidTransition&) {
//Do nothing for void
}
@@ -266,6 +376,7 @@ private:
} while (MakeEpsilonTransitionsImpl());
}
+
TNfaTransitionGraph::TPtr TransitionGraph;
IComputationExternalNode* const MatchedRangesArg;
const TComputationNodePtrVector Defines;
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp
index 7c7d94714a..897a973f64 100644
--- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp
@@ -40,7 +40,7 @@ struct TNfaSetup {
return graph;
}
- TNfa InitNfa(const TRowPattern& pattern) {
+ static THashMap<TString, size_t> BuildVarLookup(const TRowPattern& pattern) {
const auto& vars = GetPatternVars(pattern);
std::vector<TString> varVec{vars.cbegin(), vars.cend()};
//Simulate implicit name ordering in YQL structs
@@ -49,7 +49,11 @@ struct TNfaSetup {
for(size_t i = 0; i != vars.size(); ++i) {
varNameLookup[varVec[i]] = i;
}
- const auto& transitionGraph = TNfaTransitionGraphBuilder::Create(pattern, varNameLookup);
+ return varNameLookup;
+ }
+
+ TNfa InitNfa(const TRowPattern& pattern) {
+ const auto& transitionGraph = TNfaTransitionGraphBuilder::Create(pattern, BuildVarLookup(pattern));
TComputationNodePtrVector defines;
defines.reserve(Defines.size());
for (auto& d: Defines) {
@@ -92,6 +96,30 @@ struct TNfaSetup {
TNfa Nfa;
};
+// For each node of the graph returns how many times it is referenced as a
+// destination by nodes whose transition is not a TEpsilonTransitions.
+static TVector<size_t> CountNonEpsilonInputs(const TNfaTransitionGraph& graph) {
+    TVector<size_t> nonEpsIns(graph.Transitions.size());
+    // The visitor is used purely for its side effect; its result is ignored.
+    const TNfaTransitionDestinationVisitor countIncoming([&nonEpsIns](size_t toNode) {
+        nonEpsIns[toNode]++;
+        return 0;
+    });
+    for (const auto& transition : graph.Transitions) {
+        if (std::holds_alternative<TEpsilonTransitions>(transition)) {
+            continue;
+        }
+        std::visit(countIncoming, transition);
+    }
+    return nonEpsIns;
+}
+
+// For each node of the graph returns 1 if the node's own transition is not a
+// TEpsilonTransitions and 0 otherwise.
+static TVector<size_t> CountNonEpsilonOutputs(const TNfaTransitionGraph& graph) {
+    // The sized constructor already yields one zero per node; no resize needed.
+    TVector<size_t> nonEpsOuts(graph.Transitions.size());
+    for (size_t node = 0; node != graph.Transitions.size(); ++node) {
+        if (!std::holds_alternative<TEpsilonTransitions>(graph.Transitions[node])) {
+            nonEpsOuts[node] = 1;
+        }
+    }
+    return nonEpsOuts;
+}
+
} //namespace
Y_UNIT_TEST_SUITE(MatchRecognizeNfa) {
@@ -103,6 +131,45 @@ Y_UNIT_TEST_SUITE(MatchRecognizeNfa) {
const auto& output = transitionGraph->Transitions.at(transitionGraph->Output);
UNIT_ASSERT(std::get_if<TVoidTransition>(&output));
}
+    // Builds a graph for a pattern with nested sub-patterns and checks that every
+    // node other than Input and Output either holds a non-epsilon transition or
+    // is a destination of at least one non-epsilon transition.
+    Y_UNIT_TEST(EpsilonChainsEliminated) {
+        TScopedAlloc alloc(__LOCATION__);
+        const TRowPattern pattern{
+            {
+                TRowPatternFactor{"A", 1, 1, false, false},
+                TRowPatternFactor{"B", 1, 100, false, false},
+                TRowPatternFactor{
+                    TRowPattern{
+                        {TRowPatternFactor{"C", 1, 1, false, false}},
+                        {TRowPatternFactor{"D", 1, 1, false, false}}
+                    },
+                    1, 1, false, false
+                }
+            },
+            {
+                TRowPatternFactor{
+                    TRowPattern{{
+                        TRowPatternFactor{"E", 1, 1, false, false},
+                        TRowPatternFactor{"F", 1, 100, false, false},
+                    }},
+                    2, 100, false, false
+                },
+                TRowPatternFactor{"G", 1, 1, false, false}
+            }
+        };
+        const auto graph = TNfaTransitionGraphBuilder::Create(pattern, TNfaSetup::BuildVarLookup(pattern));
+        const auto incoming = CountNonEpsilonInputs(*graph);
+        const auto outgoing = CountNonEpsilonOutputs(*graph);
+        for (size_t node = 0; node != incoming.size(); ++node) {
+            // Input and Output are exempt from the check.
+            const bool isEndpoint = node == graph->Input || node == graph->Output;
+            if (!isEndpoint) {
+                UNIT_ASSERT_GT(incoming[node] + outgoing[node], 0);
+            }
+        }
+    }
+
//Tests for NFA-based engine for MATCH_RECOGNIZE
//In the full implementation pattern variables are calculated as lambda predicates on input partition