diff options
author | avevad <avevad@yandex-team.com> | 2023-10-16 19:15:25 +0300 |
---|---|---|
committer | avevad <avevad@yandex-team.com> | 2023-10-16 19:38:17 +0300 |
commit | a14372edbfbba7a2df98c30189e1d3080b4d4877 (patch) | |
tree | e7d3db9866fdda7644c7c1403e1a322dac985b14 | |
parent | 6a25aaf186439ce43a1e0be7ba52d68c4d415c0b (diff) | |
download | ydb-a14372edbfbba7a2df98c30189e1d3080b4d4877.tar.gz |
YQL-16823 Refactor match_recognize NFA
3 files changed, 61 insertions, 51 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp index 1a19d714d9..af82da13d8 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp @@ -41,9 +41,12 @@ class TBackTrackingMatchRecognize { using TMatchedVars = TMatchedVars<TRange>; public: //TODO(YQL-16486): create a tree for backtracking(replace var names with indexes) - struct TPatternConfiguration { - using TPtr = std::shared_ptr<TPatternConfiguration>; - static TPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) { + + struct TPatternConfiguration {}; + + struct TPatternConfigurationBuilder { + using TPatternConfigurationPtr = std::shared_ptr<TPatternConfiguration>; + static TPatternConfigurationPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) { Y_UNUSED(pattern); Y_UNUSED(varNameToIndex); return std::make_shared<TPatternConfiguration>(); @@ -53,7 +56,7 @@ public: TBackTrackingMatchRecognize( NUdf::TUnboxedValue&& partitionKey, const TMatchRecognizeProcessorParameters& parameters, - const TPatternConfiguration::TPtr pattern, + const TPatternConfigurationBuilder::TPatternConfigurationPtr pattern, const TContainerCacheOnContext& cache ) : PartitionKey(std::move(partitionKey)) @@ -135,7 +138,7 @@ class TStreamingMatchRecognize { using TRange = TPartitionList::TRange; using TMatchedVars = TMatchedVars<TRange>; public: - using TPatternConfiguration = TNfaTransitionGraph; + using TPatternConfigurationBuilder = TNfaTransitionGraphBuilder; TStreamingMatchRecognize( NUdf::TUnboxedValue&& partitionKey, const TMatchRecognizeProcessorParameters& parameters, @@ -198,7 +201,7 @@ template <typename Algo> class TStateForNonInterleavedPartitions : public TComputationValue<TStateForNonInterleavedPartitions<Algo>> { - using TRowPatternConfiguration = typename Algo::TPatternConfiguration; + using TRowPatternConfigurationBuilder = typename Algo::TPatternConfigurationBuilder; public: TStateForNonInterleavedPartitions( TMemoryUsageInfo* memInfo, @@ -213,7 +216,7 @@ public: , PartitionKey(partitionKey) , PartitionKeyPacker(true, partitionKeyType) , Parameters(parameters) - , RowPatternConfiguration(TRowPatternConfiguration::Create(parameters.Pattern, parameters.VarNamesLookup)) + , RowPatternConfiguration(TRowPatternConfigurationBuilder::Create(parameters.Pattern, parameters.VarNamesLookup)) , Cache(cache) , Terminating(false) {} @@ -280,7 +283,7 @@ private: IComputationNode* PartitionKey; TValuePackerGeneric<false> PartitionKeyPacker; const TMatchRecognizeProcessorParameters& Parameters; - const typename TRowPatternConfiguration::TPtr RowPatternConfiguration; + const typename TRowPatternConfigurationBuilder::TPatternConfigurationPtr RowPatternConfiguration; const TContainerCacheOnContext& Cache; NUdf::TUnboxedValue DelayedRow; bool Terminating; @@ -304,7 +307,7 @@ public: , PartitionKey(partitionKey) , PartitionKeyPacker(true, partitionKeyType) , Parameters(parameters) - , NfaTransitionGraph(TNfaTransitionGraph::Create(parameters.Pattern, parameters.VarNamesLookup)) + , NfaTransitionGraph(TNfaTransitionGraphBuilder::Create(parameters.Pattern, parameters.VarNamesLookup)) , Cache(cache) { } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h index a99250918d..a8cb7a3209 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h +++ b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h @@ -23,22 +23,29 @@ using TNfaTransition = std::variant< TQuantityExitTransition >; -class TNfaTransitionGraph { -public: +struct TNfaTransitionGraph { + std::vector<TNfaTransition> Transitions; + size_t Input; + size_t Output; + using TPtr = std::shared_ptr<TNfaTransitionGraph>; - static TPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) { - auto result = TPtr(new TNfaTransitionGraph()); - auto item = result->BuildTerms(pattern, varNameToIndex); - result->Input = item.Input; - result->Output = item.Output; - return result; - } - friend class TNfa; +}; + +class TNfaTransitionGraphBuilder { private: struct TNfaItem { size_t Input; size_t Output; }; + + TNfaTransitionGraphBuilder(TNfaTransitionGraph::TPtr graph) + : Graph(graph) {} + + size_t AddNode() { + Graph->Transitions.resize(Graph->Transitions.size() + 1); + return Graph->Transitions.size() - 1; + } + TNfaItem BuildTerms(const std::vector<TRowPatternTerm>& terms, const THashMap<TString, size_t>& varNameToIndex) { auto input = AddNode(); auto output = AddNode(); @@ -46,9 +53,9 @@ private: for (const auto& t: terms) { auto a = BuildTerm(t, varNameToIndex); fromInput.push_back(a.Input); - Transitions[a.Output] = TEpsilonTransitions({output}); + Graph->Transitions[a.Output] = TEpsilonTransitions({output}); } - Transitions[input] = std::move(fromInput); + Graph->Transitions[input] = std::move(fromInput); return {input, output}; } TNfaItem BuildTerm(const TRowPatternTerm& term, const THashMap<TString, size_t>& varNameToIndex) { @@ -59,10 +66,10 @@ private: automata.push_back(BuildFactor(f, varNameToIndex)); } for (size_t i = 0; i != automata.size() - 1; ++i) { - Transitions[automata[i].Output] = TEpsilonTransitions({automata[i+1].Input}); + Graph->Transitions[automata[i].Output] = TEpsilonTransitions({automata[i + 1].Input}); } - Transitions[input] = TEpsilonTransitions({automata.front().Input}); - Transitions[automata.back().Output] = TEpsilonTransitions({output}); + Graph->Transitions[input] = TEpsilonTransitions({automata.front().Input}); + Graph->Transitions[automata.back().Output] = TEpsilonTransitions({output}); return {input, output}; } TNfaItem BuildFactor(const TRowPatternFactor& factor, const THashMap<TString, size_t>& varNameToIndex) { @@ -75,28 +82,31 @@ private: auto fromInput = TEpsilonTransitions{interim}; if (factor.QuantityMin == 0) fromInput.push_back(output); - Transitions[input] = fromInput; - Transitions[interim] = TQuantityEnterTransition{item.Input}; - Transitions[item.Output] = std::pair{std::pair{factor.QuantityMin, factor.QuantityMax}, std::pair{item.Input, output}}; + Graph->Transitions[input] = fromInput; + Graph->Transitions[interim] = TQuantityEnterTransition{item.Input}; + Graph->Transitions[item.Output] = std::pair{std::pair{factor.QuantityMin, factor.QuantityMax}, std::pair{item.Input, output}}; return {input, output}; } TNfaItem BuildVar(ui32 varIndex) { auto input = AddNode(); auto matchVar = AddNode(); auto output = AddNode(); - Transitions[input] = TEpsilonTransitions({matchVar}); - Transitions[matchVar] = std::pair{varIndex, output}; + Graph->Transitions[input] = TEpsilonTransitions({matchVar}); + Graph->Transitions[matchVar] = std::pair{varIndex, output}; return {input, output}; } - - size_t AddNode() { - Transitions.resize(Transitions.size() + 1); - return Transitions.size() - 1; +public: + using TPatternConfigurationPtr = TNfaTransitionGraph::TPtr; + static TPatternConfigurationPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) { + auto result = std::make_shared<TNfaTransitionGraph>(); + TNfaTransitionGraphBuilder builder(result); + auto item = builder.BuildTerms(pattern, varNameToIndex); + result->Input = item.Input; + result->Output = item.Output; + return result; } private: - std::vector<TNfaTransition> Transitions; - size_t Input; - size_t Output; + TNfaTransitionGraph::TPtr Graph; }; class TNfa { @@ -106,28 +116,26 @@ class TNfa { TState(size_t index, const TMatchedVars& vars, std::stack<ui64>&& quantifiers) : Index(index) , Vars(vars) - , Quantifiers(quantifiers) - {} + , Quantifiers(quantifiers) {} const size_t Index; TMatchedVars Vars; std::stack<ui64> Quantifiers; //get rid of this - friend inline bool operator < (const TState& lhs, const TState& rhs) { + friend inline bool operator<(const TState& lhs, const TState& rhs) { return std::tie(lhs.Index, lhs.Quantifiers, lhs.Vars) < std::tie(rhs.Index, rhs.Quantifiers, rhs.Vars); } - friend inline bool operator == (const TState& lhs, const TState& rhs) { + friend inline bool operator==(const TState& lhs, const TState& rhs) { return std::tie(lhs.Index, lhs.Quantifiers, lhs.Vars) == std::tie(rhs.Index, rhs.Quantifiers, rhs.Vars); } }; public: TNfa(TNfaTransitionGraph::TPtr transitionGraph, IComputationExternalNode* matchedRangesArg, const TComputationNodePtrVector& defines) - : TransitionGraph(transitionGraph) - , MatchedRangesArg(matchedRangesArg) - , Defines(defines) - { + : TransitionGraph(transitionGraph) + , MatchedRangesArg(matchedRangesArg) + , Defines(defines) { } - void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx){ + void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx) { ActiveStates.emplace(TransitionGraph->Input, TMatchedVars(Defines.size()), std::stack<ui64>{}); MakeEpsilonTransitions(); std::set<TState> newStates; @@ -157,7 +165,7 @@ public: bool HasMatched() const { for (auto& s: ActiveStates) { - if (s.Index == TransitionGraph->Output){ + if (s.Index == TransitionGraph->Output) { return true; } } @@ -166,7 +174,7 @@ public: std::optional<TMatchedVars> GetMatched() { for (auto& s: ActiveStates) { - if (s.Index == TransitionGraph->Output){ + if (s.Index == TransitionGraph->Output) { auto result = s.Vars; ActiveStates.erase(s); return result; @@ -179,11 +187,10 @@ private: //TODO (zverevgeny): Consider to change to std::vector for the sake of perf using TStateSet = std::set<TState>; struct TTransitionVisitor { - TTransitionVisitor(const TState& state, TStateSet& newStates, TStateSet& deletedStates) + TTransitionVisitor(const TState& state, TStateSet& newStates, TStateSet& deletedStates) : State(state) , NewStates(newStates) - , DeletedStates(deletedStates) - {} + , DeletedStates(deletedStates) {} void operator()(const TMatchedVarTransition& var) { //Transitions of TMatchedVarTransition type are handled in ProcessRow method Y_UNUSED(var); diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp index 7dec0a2287..46ba1a68f2 100644 --- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp +++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp @@ -49,7 +49,7 @@ struct TNfaSetup { for(size_t i = 0; i != vars.size(); ++i) { varNameLookup[varVec[i]] = i; } - const auto& transitionGraph = TNfaTransitionGraph::Create(pattern, varNameLookup); + const auto& transitionGraph = TNfaTransitionGraphBuilder::Create(pattern, varNameLookup); TComputationNodePtrVector defines; defines.reserve(Defines.size()); for (auto& d: Defines) { |