aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoravevad <avevad@yandex-team.com>2023-10-16 19:15:25 +0300
committeravevad <avevad@yandex-team.com>2023-10-16 19:38:17 +0300
commita14372edbfbba7a2df98c30189e1d3080b4d4877 (patch)
treee7d3db9866fdda7644c7c1403e1a322dac985b14
parent6a25aaf186439ce43a1e0be7ba52d68c4d415c0b (diff)
downloadydb-a14372edbfbba7a2df98c30189e1d3080b4d4877.tar.gz
YQL-16823 Refactor match_recognize NFA
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp21
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h89
-rw-r--r--ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp2
3 files changed, 61 insertions, 51 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp
index 1a19d714d9..af82da13d8 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp
@@ -41,9 +41,12 @@ class TBackTrackingMatchRecognize {
using TMatchedVars = TMatchedVars<TRange>;
public:
//TODO(YQL-16486): create a tree for backtracking(replace var names with indexes)
- struct TPatternConfiguration {
- using TPtr = std::shared_ptr<TPatternConfiguration>;
- static TPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) {
+
+ struct TPatternConfiguration {};
+
+ struct TPatternConfigurationBuilder {
+ using TPatternConfigurationPtr = std::shared_ptr<TPatternConfiguration>;
+ static TPatternConfigurationPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) {
Y_UNUSED(pattern);
Y_UNUSED(varNameToIndex);
return std::make_shared<TPatternConfiguration>();
@@ -53,7 +56,7 @@ public:
TBackTrackingMatchRecognize(
NUdf::TUnboxedValue&& partitionKey,
const TMatchRecognizeProcessorParameters& parameters,
- const TPatternConfiguration::TPtr pattern,
+ const TPatternConfigurationBuilder::TPatternConfigurationPtr pattern,
const TContainerCacheOnContext& cache
)
: PartitionKey(std::move(partitionKey))
@@ -135,7 +138,7 @@ class TStreamingMatchRecognize {
using TRange = TPartitionList::TRange;
using TMatchedVars = TMatchedVars<TRange>;
public:
- using TPatternConfiguration = TNfaTransitionGraph;
+ using TPatternConfigurationBuilder = TNfaTransitionGraphBuilder;
TStreamingMatchRecognize(
NUdf::TUnboxedValue&& partitionKey,
const TMatchRecognizeProcessorParameters& parameters,
@@ -198,7 +201,7 @@ template <typename Algo>
class TStateForNonInterleavedPartitions
: public TComputationValue<TStateForNonInterleavedPartitions<Algo>>
{
- using TRowPatternConfiguration = typename Algo::TPatternConfiguration;
+ using TRowPatternConfigurationBuilder = typename Algo::TPatternConfigurationBuilder;
public:
TStateForNonInterleavedPartitions(
TMemoryUsageInfo* memInfo,
@@ -213,7 +216,7 @@ public:
, PartitionKey(partitionKey)
, PartitionKeyPacker(true, partitionKeyType)
, Parameters(parameters)
- , RowPatternConfiguration(TRowPatternConfiguration::Create(parameters.Pattern, parameters.VarNamesLookup))
+ , RowPatternConfiguration(TRowPatternConfigurationBuilder::Create(parameters.Pattern, parameters.VarNamesLookup))
, Cache(cache)
, Terminating(false)
{}
@@ -280,7 +283,7 @@ private:
IComputationNode* PartitionKey;
TValuePackerGeneric<false> PartitionKeyPacker;
const TMatchRecognizeProcessorParameters& Parameters;
- const typename TRowPatternConfiguration::TPtr RowPatternConfiguration;
+ const typename TRowPatternConfigurationBuilder::TPatternConfigurationPtr RowPatternConfiguration;
const TContainerCacheOnContext& Cache;
NUdf::TUnboxedValue DelayedRow;
bool Terminating;
@@ -304,7 +307,7 @@ public:
, PartitionKey(partitionKey)
, PartitionKeyPacker(true, partitionKeyType)
, Parameters(parameters)
- , NfaTransitionGraph(TNfaTransitionGraph::Create(parameters.Pattern, parameters.VarNamesLookup))
+ , NfaTransitionGraph(TNfaTransitionGraphBuilder::Create(parameters.Pattern, parameters.VarNamesLookup))
, Cache(cache)
{
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h
index a99250918d..a8cb7a3209 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h
@@ -23,22 +23,29 @@ using TNfaTransition = std::variant<
TQuantityExitTransition
>;
-class TNfaTransitionGraph {
-public:
+struct TNfaTransitionGraph {
+ std::vector<TNfaTransition> Transitions;
+ size_t Input;
+ size_t Output;
+
using TPtr = std::shared_ptr<TNfaTransitionGraph>;
- static TPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) {
- auto result = TPtr(new TNfaTransitionGraph());
- auto item = result->BuildTerms(pattern, varNameToIndex);
- result->Input = item.Input;
- result->Output = item.Output;
- return result;
- }
- friend class TNfa;
+};
+
+class TNfaTransitionGraphBuilder {
private:
struct TNfaItem {
size_t Input;
size_t Output;
};
+
+ TNfaTransitionGraphBuilder(TNfaTransitionGraph::TPtr graph)
+ : Graph(graph) {}
+
+ size_t AddNode() {
+ Graph->Transitions.resize(Graph->Transitions.size() + 1);
+ return Graph->Transitions.size() - 1;
+ }
+
TNfaItem BuildTerms(const std::vector<TRowPatternTerm>& terms, const THashMap<TString, size_t>& varNameToIndex) {
auto input = AddNode();
auto output = AddNode();
@@ -46,9 +53,9 @@ private:
for (const auto& t: terms) {
auto a = BuildTerm(t, varNameToIndex);
fromInput.push_back(a.Input);
- Transitions[a.Output] = TEpsilonTransitions({output});
+ Graph->Transitions[a.Output] = TEpsilonTransitions({output});
}
- Transitions[input] = std::move(fromInput);
+ Graph->Transitions[input] = std::move(fromInput);
return {input, output};
}
TNfaItem BuildTerm(const TRowPatternTerm& term, const THashMap<TString, size_t>& varNameToIndex) {
@@ -59,10 +66,10 @@ private:
automata.push_back(BuildFactor(f, varNameToIndex));
}
for (size_t i = 0; i != automata.size() - 1; ++i) {
- Transitions[automata[i].Output] = TEpsilonTransitions({automata[i+1].Input});
+ Graph->Transitions[automata[i].Output] = TEpsilonTransitions({automata[i + 1].Input});
}
- Transitions[input] = TEpsilonTransitions({automata.front().Input});
- Transitions[automata.back().Output] = TEpsilonTransitions({output});
+ Graph->Transitions[input] = TEpsilonTransitions({automata.front().Input});
+ Graph->Transitions[automata.back().Output] = TEpsilonTransitions({output});
return {input, output};
}
TNfaItem BuildFactor(const TRowPatternFactor& factor, const THashMap<TString, size_t>& varNameToIndex) {
@@ -75,28 +82,31 @@ private:
auto fromInput = TEpsilonTransitions{interim};
if (factor.QuantityMin == 0)
fromInput.push_back(output);
- Transitions[input] = fromInput;
- Transitions[interim] = TQuantityEnterTransition{item.Input};
- Transitions[item.Output] = std::pair{std::pair{factor.QuantityMin, factor.QuantityMax}, std::pair{item.Input, output}};
+ Graph->Transitions[input] = fromInput;
+ Graph->Transitions[interim] = TQuantityEnterTransition{item.Input};
+ Graph->Transitions[item.Output] = std::pair{std::pair{factor.QuantityMin, factor.QuantityMax}, std::pair{item.Input, output}};
return {input, output};
}
TNfaItem BuildVar(ui32 varIndex) {
auto input = AddNode();
auto matchVar = AddNode();
auto output = AddNode();
- Transitions[input] = TEpsilonTransitions({matchVar});
- Transitions[matchVar] = std::pair{varIndex, output};
+ Graph->Transitions[input] = TEpsilonTransitions({matchVar});
+ Graph->Transitions[matchVar] = std::pair{varIndex, output};
return {input, output};
}
-
- size_t AddNode() {
- Transitions.resize(Transitions.size() + 1);
- return Transitions.size() - 1;
+public:
+ using TPatternConfigurationPtr = TNfaTransitionGraph::TPtr;
+ static TPatternConfigurationPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) {
+ auto result = std::make_shared<TNfaTransitionGraph>();
+ TNfaTransitionGraphBuilder builder(result);
+ auto item = builder.BuildTerms(pattern, varNameToIndex);
+ result->Input = item.Input;
+ result->Output = item.Output;
+ return result;
}
private:
- std::vector<TNfaTransition> Transitions;
- size_t Input;
- size_t Output;
+ TNfaTransitionGraph::TPtr Graph;
};
class TNfa {
@@ -106,28 +116,26 @@ class TNfa {
TState(size_t index, const TMatchedVars& vars, std::stack<ui64>&& quantifiers)
: Index(index)
, Vars(vars)
- , Quantifiers(quantifiers)
- {}
+ , Quantifiers(quantifiers) {}
const size_t Index;
TMatchedVars Vars;
std::stack<ui64> Quantifiers; //get rid of this
- friend inline bool operator < (const TState& lhs, const TState& rhs) {
+ friend inline bool operator<(const TState& lhs, const TState& rhs) {
return std::tie(lhs.Index, lhs.Quantifiers, lhs.Vars) < std::tie(rhs.Index, rhs.Quantifiers, rhs.Vars);
}
- friend inline bool operator == (const TState& lhs, const TState& rhs) {
+ friend inline bool operator==(const TState& lhs, const TState& rhs) {
return std::tie(lhs.Index, lhs.Quantifiers, lhs.Vars) == std::tie(rhs.Index, rhs.Quantifiers, rhs.Vars);
}
};
public:
TNfa(TNfaTransitionGraph::TPtr transitionGraph, IComputationExternalNode* matchedRangesArg, const TComputationNodePtrVector& defines)
- : TransitionGraph(transitionGraph)
- , MatchedRangesArg(matchedRangesArg)
- , Defines(defines)
- {
+ : TransitionGraph(transitionGraph)
+ , MatchedRangesArg(matchedRangesArg)
+ , Defines(defines) {
}
- void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx){
+ void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx) {
ActiveStates.emplace(TransitionGraph->Input, TMatchedVars(Defines.size()), std::stack<ui64>{});
MakeEpsilonTransitions();
std::set<TState> newStates;
@@ -157,7 +165,7 @@ public:
bool HasMatched() const {
for (auto& s: ActiveStates) {
- if (s.Index == TransitionGraph->Output){
+ if (s.Index == TransitionGraph->Output) {
return true;
}
}
@@ -166,7 +174,7 @@ public:
std::optional<TMatchedVars> GetMatched() {
for (auto& s: ActiveStates) {
- if (s.Index == TransitionGraph->Output){
+ if (s.Index == TransitionGraph->Output) {
auto result = s.Vars;
ActiveStates.erase(s);
return result;
@@ -179,11 +187,10 @@ private:
//TODO (zverevgeny): Consider to change to std::vector for the sake of perf
using TStateSet = std::set<TState>;
struct TTransitionVisitor {
- TTransitionVisitor(const TState& state, TStateSet& newStates, TStateSet& deletedStates)
+ TTransitionVisitor(const TState& state, TStateSet& newStates, TStateSet& deletedStates)
: State(state)
, NewStates(newStates)
- , DeletedStates(deletedStates)
- {}
+ , DeletedStates(deletedStates) {}
void operator()(const TMatchedVarTransition& var) {
//Transitions of TMatchedVarTransition type are handled in ProcessRow method
Y_UNUSED(var);
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp
index 7dec0a2287..46ba1a68f2 100644
--- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp
@@ -49,7 +49,7 @@ struct TNfaSetup {
for(size_t i = 0; i != vars.size(); ++i) {
varNameLookup[varVec[i]] = i;
}
- const auto& transitionGraph = TNfaTransitionGraph::Create(pattern, varNameLookup);
+ const auto& transitionGraph = TNfaTransitionGraphBuilder::Create(pattern, varNameLookup);
TComputationNodePtrVector defines;
defines.reserve(Defines.size());
for (auto& d: Defines) {