diff options
author | vokayndzop <vokayndzop@yandex-team.com> | 2024-12-12 12:36:12 +0300 |
---|---|---|
committer | vokayndzop <vokayndzop@yandex-team.com> | 2024-12-12 12:53:02 +0300 |
commit | 560023d550a203f58c64cdf9844dac139f54a368 (patch) | |
tree | 14e9d5e48bf96738ecbc8a241916e45f578c9bb9 /yql/essentials/minikql | |
parent | 8093553f5735b27c84bab0334f1150d325766268 (diff) | |
download | ydb-560023d550a203f58c64cdf9844dac139f54a368.tar.gz |
MR: greedy quantifiers fix
commit_hash:942b86bef9990f5a15a3a7ce862665194278ffd4
Diffstat (limited to 'yql/essentials/minikql')
3 files changed, 90 insertions, 59 deletions
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp b/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp index 0f758c7a8b..80caaad37e 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp +++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp @@ -54,7 +54,7 @@ public: ) : PartitionKey(std::move(partitionKey)) , Parameters(parameters) - , Nfa(nfaTransitions, parameters.MatchedVarsArg, parameters.Defines) + , Nfa(nfaTransitions, parameters.MatchedVarsArg, parameters.Defines, parameters.SkipTo) , Cache(cache) { } @@ -72,10 +72,10 @@ public: NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) { auto match = Nfa.GetMatched(); - if (!match.has_value()) { + if (!match) { return NUdf::TUnboxedValue{}; } - Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, match.value())); + Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, match->Vars)); Parameters.MeasureInputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TMeasureInputDataValue>( ctx.HolderFactory.Create<TListValue<TSparseList>>(Rows), Parameters.MeasureInputColumnOrder, @@ -95,9 +95,7 @@ public: break; } } - if (EAfterMatchSkipTo::PastLastRow == Parameters.SkipTo.To) { - Nfa.Clear(); - } + Nfa.AfterMatchSkip(*match); return result; } bool ProcessEndOfData(TComputationContext& ctx) { diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h index 944164e4bc..2b194212f4 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h +++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h @@ -350,19 +350,25 @@ class TNfa { using TRange = TSparseList::TRange; using TMatchedVars = TMatchedVars<TRange>; +public: + struct TMatch { + size_t BeginIndex; + size_t EndIndex; + TMatchedVars Vars; + }; + +private: struct TState { - size_t BeginMatchIndex; - size_t EndMatchIndex; size_t Index; - TMatchedVars Vars; + TMatch Match; std::deque<ui64, TMKQLAllocator<ui64>> Quantifiers; void Save(TMrOutputSerializer& serializer) const { - serializer.Write(BeginMatchIndex); - serializer.Write(EndMatchIndex); serializer.Write(Index); - serializer.Write(Vars.size()); - for (const auto& vector : Vars) { + serializer.Write(Match.BeginIndex); + serializer.Write(Match.EndIndex); + serializer.Write(Match.Vars.size()); + for (const auto& vector : Match.Vars) { serializer.Write(vector.size()); for (const auto& range : vector) { range.Save(serializer); @@ -375,13 +381,13 @@ class TNfa { } void Load(TMrInputSerializer& serializer) { - serializer.Read(BeginMatchIndex); - serializer.Read(EndMatchIndex); serializer.Read(Index); + serializer.Read(Match.BeginIndex); + serializer.Read(Match.EndIndex); auto varsSize = serializer.Read<TMatchedVars::size_type>(); - Vars.clear(); - Vars.resize(varsSize); - for (auto& subvec: Vars) { + Match.Vars.clear(); + Match.Vars.resize(varsSize); + for (auto& subvec: Match.Vars) { ui64 vectorSize = serializer.Read<ui64>(); subvec.resize(vectorSize); for (auto& item : subvec) { @@ -397,24 +403,29 @@ class TNfa { } friend inline bool operator<(const TState& lhs, const TState& rhs) { - auto lhsEndMatchIndex = -static_cast<i64>(lhs.EndMatchIndex); - auto rhsEndMatchIndex = -static_cast<i64>(rhs.EndMatchIndex); - return std::tie(lhs.BeginMatchIndex, lhsEndMatchIndex, lhs.Index, lhs.Quantifiers, lhs.Vars) < std::tie(rhs.BeginMatchIndex, rhsEndMatchIndex, rhs.Index, rhs.Quantifiers, rhs.Vars); + auto lhsMatchEndIndex = -static_cast<i64>(lhs.Match.EndIndex); + auto rhsMatchEndIndex = -static_cast<i64>(rhs.Match.EndIndex); + return std::tie(lhs.Match.BeginIndex, lhsMatchEndIndex, lhs.Index, lhs.Match.Vars, lhs.Quantifiers) < std::tie(rhs.Match.BeginIndex, rhsMatchEndIndex, rhs.Index, rhs.Match.Vars, rhs.Quantifiers); } friend inline bool operator==(const TState& lhs, const TState& rhs) { - return std::tie(lhs.BeginMatchIndex, lhs.EndMatchIndex, lhs.Index, lhs.Quantifiers, lhs.Vars) == std::tie(rhs.BeginMatchIndex, rhs.EndMatchIndex, rhs.Index, rhs.Quantifiers, rhs.Vars); + return std::tie(lhs.Match.BeginIndex, lhs.Match.EndIndex, lhs.Index, lhs.Match.Vars, lhs.Quantifiers) == std::tie(rhs.Match.BeginIndex, rhs.Match.EndIndex, rhs.Index, rhs.Match.Vars, rhs.Quantifiers); } }; public: - TNfa(TNfaTransitionGraph::TPtr transitionGraph, IComputationExternalNode* matchedRangesArg, const TComputationNodePtrVector& defines) - : TransitionGraph(transitionGraph) - , MatchedRangesArg(matchedRangesArg) - , Defines(defines) { - } + TNfa( + TNfaTransitionGraph::TPtr transitionGraph, + IComputationExternalNode* matchedRangesArg, + const TComputationNodePtrVector& defines, + TAfterMatchSkipTo skipTo) + : TransitionGraph(transitionGraph) + , MatchedRangesArg(matchedRangesArg) + , Defines(defines) + , SkipTo_(skipTo) + {} void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx) { - TState state(currentRowLock.From(), currentRowLock.To(), TransitionGraph->Input, TMatchedVars(Defines.size()), std::deque<ui64, TMKQLAllocator<ui64>>{}); + TState state(TransitionGraph->Input, TMatch{currentRowLock.From(), currentRowLock.To(), TMatchedVars(Defines.size())}, std::deque<ui64, TMKQLAllocator<ui64>>{}); Insert(std::move(state)); MakeEpsilonTransitions(); TStateSet newStates; @@ -423,17 +434,17 @@ public: //Here we handle only transitions of TMatchedVarTransition type, //all other transitions are handled in MakeEpsilonTransitions if (const auto* matchedVarTransition = std::get_if<TMatchedVarTransition>(&TransitionGraph->Transitions[state.Index])) { - MatchedRangesArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, state.Vars)); + MatchedRangesArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, state.Match.Vars)); const auto varIndex = matchedVarTransition->VarIndex; const auto& v = Defines[varIndex]->GetValue(ctx); if (v && v.Get<bool>()) { if (matchedVarTransition->SaveState) { - auto vars = state.Vars; //TODO get rid of this copy + auto vars = state.Match.Vars; //TODO get rid of this copy auto& matchedVar = vars[varIndex]; Extend(matchedVar, currentRowLock); - newStates.emplace(state.BeginMatchIndex, currentRowLock.To(), matchedVarTransition->To, std::move(vars), state.Quantifiers); + newStates.emplace(matchedVarTransition->To, TMatch{state.Match.BeginIndex, currentRowLock.To(), std::move(vars)}, state.Quantifiers); } else { - newStates.emplace(state.BeginMatchIndex, currentRowLock.To(), matchedVarTransition->To, state.Vars, state.Quantifiers); + newStates.emplace(matchedVarTransition->To, TMatch{state.Match.BeginIndex, currentRowLock.To(), state.Match.Vars}, state.Quantifiers); } } deletedStates.insert(state); @@ -450,8 +461,8 @@ public: bool HasMatched() const { for (auto& state: ActiveStates) { - if (auto activeStateIter = ActiveStateCounters.find(state.BeginMatchIndex), - finishedStateIter = FinishedStateCounters.find(state.BeginMatchIndex); + if (auto activeStateIter = ActiveStateCounters.find(state.Match.BeginIndex), + finishedStateIter = FinishedStateCounters.find(state.Match.BeginIndex); ((activeStateIter != ActiveStateCounters.end() && finishedStateIter != FinishedStateCounters.end() && activeStateIter->second == finishedStateIter->second) || @@ -463,16 +474,16 @@ public: return false; } - std::optional<TMatchedVars> GetMatched() { + std::optional<TMatch> GetMatched() { for (auto& state: ActiveStates) { - if (auto activeStateIter = ActiveStateCounters.find(state.BeginMatchIndex), - finishedStateIter = FinishedStateCounters.find(state.BeginMatchIndex); + if (auto activeStateIter = ActiveStateCounters.find(state.Match.BeginIndex), + finishedStateIter = FinishedStateCounters.find(state.Match.BeginIndex); ((activeStateIter != ActiveStateCounters.end() && finishedStateIter != FinishedStateCounters.end() && activeStateIter->second == finishedStateIter->second) || EndOfData) && state.Index == TransitionGraph->Output) { - auto result = state.Vars; + auto result = state.Match; Erase(std::move(state)); return result; } @@ -515,9 +526,9 @@ public: auto activeStateCountersSize = serializer.Read<ui64>(); for (size_t i = 0; i < activeStateCountersSize; ++i) { using map_type = decltype(ActiveStateCounters); - auto beginMatchIndex = serializer.Read<map_type::key_type>(); + auto matchBeginIndex = serializer.Read<map_type::key_type>(); auto counter = serializer.Read<map_type::mapped_type>(); - ActiveStateCounters.emplace(beginMatchIndex, counter); + ActiveStateCounters.emplace(matchBeginIndex, counter); } } { @@ -525,9 +536,9 @@ public: auto finishedStateCountersSize = serializer.Read<ui64>(); for (size_t i = 0; i < finishedStateCountersSize; ++i) { using map_type = decltype(FinishedStateCounters); - auto beginMatchIndex = serializer.Read<map_type::key_type>(); + auto matchBeginIndex = serializer.Read<map_type::key_type>(); auto counter = serializer.Read<map_type::mapped_type>(); - FinishedStateCounters.emplace(beginMatchIndex, counter); + FinishedStateCounters.emplace(matchBeginIndex, counter); } } } @@ -537,10 +548,31 @@ public: return HasMatched(); } - void Clear() { - ActiveStates.clear(); - ActiveStateCounters.clear(); - FinishedStateCounters.clear(); + void AfterMatchSkip(const TMatch& match) { + const auto skipToRowIndex = [&]() { + switch (SkipTo_.To) { + case EAfterMatchSkipTo::NextRow: + return match.BeginIndex + 1; + case EAfterMatchSkipTo::PastLastRow: + return match.EndIndex + 1; + case EAfterMatchSkipTo::ToFirst: + MKQL_ENSURE(false, "AFTER MATCH SKIP TO FIRST is not implemented yet"); + case EAfterMatchSkipTo::ToLast: + [[fallthrough]]; + case EAfterMatchSkipTo::To: + MKQL_ENSURE(false, "AFTER MATCH SKIP TO LAST is not implemented yet"); + } + }(); + + TStateSet deletedStates; + for (const auto& state : ActiveStates) { + if (state.Match.BeginIndex < skipToRowIndex) { + deletedStates.insert(state); + } + } + for (auto& state : deletedStates) { + Erase(std::move(state)); + } } private: @@ -561,14 +593,14 @@ private: [&](const TEpsilonTransitions& epsilonTransitions) { deletedStates.insert(state); for (const auto& i : epsilonTransitions.To) { - newStates.emplace(state.BeginMatchIndex, state.EndMatchIndex, i, state.Vars, state.Quantifiers); + newStates.emplace(i, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, state.Quantifiers); } }, [&](const TQuantityEnterTransition& quantityEnterTransition) { deletedStates.insert(state); auto quantifiers = state.Quantifiers; //TODO get rid of this copy quantifiers.push_back(0); - newStates.emplace(state.BeginMatchIndex, state.EndMatchIndex, quantityEnterTransition.To, state.Vars, std::move(quantifiers)); + newStates.emplace(quantityEnterTransition.To, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(quantifiers)); }, [&](const TQuantityExitTransition& quantityExitTransition) { deletedStates.insert(state); @@ -576,12 +608,12 @@ private: if (state.Quantifiers.back() + 1 < quantityMax) { auto q = state.Quantifiers; q.back()++; - newStates.emplace(state.BeginMatchIndex, state.EndMatchIndex, toFindMore, state.Vars, std::move(q)); + newStates.emplace(toFindMore, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(q)); } if (quantityMin <= state.Quantifiers.back() + 1 && state.Quantifiers.back() + 1 <= quantityMax) { auto q = state.Quantifiers; q.pop_back(); - newStates.emplace(state.BeginMatchIndex, state.EndMatchIndex, toMatched, state.Vars, std::move(q)); + newStates.emplace(toMatched, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(q)); } }, }, TransitionGraph->Transitions[state.Index]); @@ -610,22 +642,22 @@ private: } void Insert(TState state) { - auto beginMatchIndex = state.BeginMatchIndex; + auto matchBeginIndex = state.Match.BeginIndex; const auto& transition = TransitionGraph->Transitions[state.Index]; auto diff = static_cast<i64>(ActiveStates.insert(std::move(state)).second); - Add(ActiveStateCounters, beginMatchIndex, diff); + Add(ActiveStateCounters, matchBeginIndex, diff); if (std::holds_alternative<TVoidTransition>(transition)) { - Add(FinishedStateCounters, beginMatchIndex, diff); + Add(FinishedStateCounters, matchBeginIndex, diff); } } void Erase(TState state) { - auto beginMatchIndex = state.BeginMatchIndex; + auto matchBeginIndex = state.Match.BeginIndex; const auto& transition = TransitionGraph->Transitions[state.Index]; auto diff = -static_cast<i64>(ActiveStates.erase(std::move(state))); - Add(ActiveStateCounters, beginMatchIndex, diff); + Add(ActiveStateCounters, matchBeginIndex, diff); if (std::holds_alternative<TVoidTransition>(transition)) { - Add(FinishedStateCounters, beginMatchIndex, diff); + Add(FinishedStateCounters, matchBeginIndex, diff); } } @@ -636,6 +668,7 @@ private: THashMap<size_t, i64> ActiveStateCounters; THashMap<size_t, i64> FinishedStateCounters; bool EndOfData = false; + TAfterMatchSkipTo SkipTo_; }; }//namespace NKikimr::NMiniKQL::NMatchRecognize diff --git a/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp b/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp index 745c88e084..afdbd6b8e8 100644 --- a/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp +++ b/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp @@ -59,7 +59,7 @@ struct TNfaSetup { for (auto& d: Defines) { defines.push_back(d); } - return TNfa(transitionGraph, MatchedVars, defines); + return TNfa(transitionGraph, MatchedVars, defines, TAfterMatchSkipTo{EAfterMatchSkipTo::PastLastRow, ""}); } TComputationNodeFactory GetAuxCallableFactory() { @@ -261,8 +261,8 @@ Y_UNIT_TEST_SUITE(MatchRecognizeNfa) { Iota(expectedTo.begin(), expectedTo.end(), i - seriesLength + 1); for (size_t matchCount = 0; matchCount < seriesLength; ++matchCount) { auto match = setup.Nfa.GetMatched(); - UNIT_ASSERT_C(match.has_value(), i); - auto vars = match.value(); + UNIT_ASSERT_C(match, i); + auto vars = match->Vars; UNIT_ASSERT_VALUES_EQUAL_C(1, vars.size(), i); auto var = vars[0]; UNIT_ASSERT_VALUES_EQUAL_C(1, var.size(), i); |