aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql
diff options
context:
space:
mode:
authorvokayndzop <vokayndzop@yandex-team.com>2024-12-12 12:36:12 +0300
committervokayndzop <vokayndzop@yandex-team.com>2024-12-12 12:53:02 +0300
commit560023d550a203f58c64cdf9844dac139f54a368 (patch)
tree14e9d5e48bf96738ecbc8a241916e45f578c9bb9 /yql/essentials/minikql
parent8093553f5735b27c84bab0334f1150d325766268 (diff)
downloadydb-560023d550a203f58c64cdf9844dac139f54a368.tar.gz
MR: greedy quantifiers fix
commit_hash:942b86bef9990f5a15a3a7ce862665194278ffd4
Diffstat (limited to 'yql/essentials/minikql')
-rw-r--r--yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp10
-rw-r--r--yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h133
-rw-r--r--yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp6
3 files changed, 90 insertions, 59 deletions
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp b/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp
index 0f758c7a8b..80caaad37e 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp
+++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp
@@ -54,7 +54,7 @@ public:
)
: PartitionKey(std::move(partitionKey))
, Parameters(parameters)
- , Nfa(nfaTransitions, parameters.MatchedVarsArg, parameters.Defines)
+ , Nfa(nfaTransitions, parameters.MatchedVarsArg, parameters.Defines, parameters.SkipTo)
, Cache(cache)
{
}
@@ -72,10 +72,10 @@ public:
NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) {
auto match = Nfa.GetMatched();
- if (!match.has_value()) {
+ if (!match) {
return NUdf::TUnboxedValue{};
}
- Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, match.value()));
+ Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, match->Vars));
Parameters.MeasureInputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TMeasureInputDataValue>(
ctx.HolderFactory.Create<TListValue<TSparseList>>(Rows),
Parameters.MeasureInputColumnOrder,
@@ -95,9 +95,7 @@ public:
break;
}
}
- if (EAfterMatchSkipTo::PastLastRow == Parameters.SkipTo.To) {
- Nfa.Clear();
- }
+ Nfa.AfterMatchSkip(*match);
return result;
}
bool ProcessEndOfData(TComputationContext& ctx) {
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h
index 944164e4bc..2b194212f4 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h
+++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h
@@ -350,19 +350,25 @@ class TNfa {
using TRange = TSparseList::TRange;
using TMatchedVars = TMatchedVars<TRange>;
+public:
+ struct TMatch {
+ size_t BeginIndex;
+ size_t EndIndex;
+ TMatchedVars Vars;
+ };
+
+private:
struct TState {
- size_t BeginMatchIndex;
- size_t EndMatchIndex;
size_t Index;
- TMatchedVars Vars;
+ TMatch Match;
std::deque<ui64, TMKQLAllocator<ui64>> Quantifiers;
void Save(TMrOutputSerializer& serializer) const {
- serializer.Write(BeginMatchIndex);
- serializer.Write(EndMatchIndex);
serializer.Write(Index);
- serializer.Write(Vars.size());
- for (const auto& vector : Vars) {
+ serializer.Write(Match.BeginIndex);
+ serializer.Write(Match.EndIndex);
+ serializer.Write(Match.Vars.size());
+ for (const auto& vector : Match.Vars) {
serializer.Write(vector.size());
for (const auto& range : vector) {
range.Save(serializer);
@@ -375,13 +381,13 @@ class TNfa {
}
void Load(TMrInputSerializer& serializer) {
- serializer.Read(BeginMatchIndex);
- serializer.Read(EndMatchIndex);
serializer.Read(Index);
+ serializer.Read(Match.BeginIndex);
+ serializer.Read(Match.EndIndex);
auto varsSize = serializer.Read<TMatchedVars::size_type>();
- Vars.clear();
- Vars.resize(varsSize);
- for (auto& subvec: Vars) {
+ Match.Vars.clear();
+ Match.Vars.resize(varsSize);
+ for (auto& subvec: Match.Vars) {
ui64 vectorSize = serializer.Read<ui64>();
subvec.resize(vectorSize);
for (auto& item : subvec) {
@@ -397,24 +403,29 @@ class TNfa {
}
friend inline bool operator<(const TState& lhs, const TState& rhs) {
- auto lhsEndMatchIndex = -static_cast<i64>(lhs.EndMatchIndex);
- auto rhsEndMatchIndex = -static_cast<i64>(rhs.EndMatchIndex);
- return std::tie(lhs.BeginMatchIndex, lhsEndMatchIndex, lhs.Index, lhs.Quantifiers, lhs.Vars) < std::tie(rhs.BeginMatchIndex, rhsEndMatchIndex, rhs.Index, rhs.Quantifiers, rhs.Vars);
+ auto lhsMatchEndIndex = -static_cast<i64>(lhs.Match.EndIndex);
+ auto rhsMatchEndIndex = -static_cast<i64>(rhs.Match.EndIndex);
+ return std::tie(lhs.Match.BeginIndex, lhsMatchEndIndex, lhs.Index, lhs.Match.Vars, lhs.Quantifiers) < std::tie(rhs.Match.BeginIndex, rhsMatchEndIndex, rhs.Index, rhs.Match.Vars, rhs.Quantifiers);
}
friend inline bool operator==(const TState& lhs, const TState& rhs) {
- return std::tie(lhs.BeginMatchIndex, lhs.EndMatchIndex, lhs.Index, lhs.Quantifiers, lhs.Vars) == std::tie(rhs.BeginMatchIndex, rhs.EndMatchIndex, rhs.Index, rhs.Quantifiers, rhs.Vars);
+ return std::tie(lhs.Match.BeginIndex, lhs.Match.EndIndex, lhs.Index, lhs.Match.Vars, lhs.Quantifiers) == std::tie(rhs.Match.BeginIndex, rhs.Match.EndIndex, rhs.Index, rhs.Match.Vars, rhs.Quantifiers);
}
};
public:
- TNfa(TNfaTransitionGraph::TPtr transitionGraph, IComputationExternalNode* matchedRangesArg, const TComputationNodePtrVector& defines)
- : TransitionGraph(transitionGraph)
- , MatchedRangesArg(matchedRangesArg)
- , Defines(defines) {
- }
+ TNfa(
+ TNfaTransitionGraph::TPtr transitionGraph,
+ IComputationExternalNode* matchedRangesArg,
+ const TComputationNodePtrVector& defines,
+ TAfterMatchSkipTo skipTo)
+ : TransitionGraph(transitionGraph)
+ , MatchedRangesArg(matchedRangesArg)
+ , Defines(defines)
+ , SkipTo_(skipTo)
+ {}
void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx) {
- TState state(currentRowLock.From(), currentRowLock.To(), TransitionGraph->Input, TMatchedVars(Defines.size()), std::deque<ui64, TMKQLAllocator<ui64>>{});
+ TState state(TransitionGraph->Input, TMatch{currentRowLock.From(), currentRowLock.To(), TMatchedVars(Defines.size())}, std::deque<ui64, TMKQLAllocator<ui64>>{});
Insert(std::move(state));
MakeEpsilonTransitions();
TStateSet newStates;
@@ -423,17 +434,17 @@ public:
//Here we handle only transitions of TMatchedVarTransition type,
//all other transitions are handled in MakeEpsilonTransitions
if (const auto* matchedVarTransition = std::get_if<TMatchedVarTransition>(&TransitionGraph->Transitions[state.Index])) {
- MatchedRangesArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, state.Vars));
+ MatchedRangesArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, state.Match.Vars));
const auto varIndex = matchedVarTransition->VarIndex;
const auto& v = Defines[varIndex]->GetValue(ctx);
if (v && v.Get<bool>()) {
if (matchedVarTransition->SaveState) {
- auto vars = state.Vars; //TODO get rid of this copy
+ auto vars = state.Match.Vars; //TODO get rid of this copy
auto& matchedVar = vars[varIndex];
Extend(matchedVar, currentRowLock);
- newStates.emplace(state.BeginMatchIndex, currentRowLock.To(), matchedVarTransition->To, std::move(vars), state.Quantifiers);
+ newStates.emplace(matchedVarTransition->To, TMatch{state.Match.BeginIndex, currentRowLock.To(), std::move(vars)}, state.Quantifiers);
} else {
- newStates.emplace(state.BeginMatchIndex, currentRowLock.To(), matchedVarTransition->To, state.Vars, state.Quantifiers);
+ newStates.emplace(matchedVarTransition->To, TMatch{state.Match.BeginIndex, currentRowLock.To(), state.Match.Vars}, state.Quantifiers);
}
}
deletedStates.insert(state);
@@ -450,8 +461,8 @@ public:
bool HasMatched() const {
for (auto& state: ActiveStates) {
- if (auto activeStateIter = ActiveStateCounters.find(state.BeginMatchIndex),
- finishedStateIter = FinishedStateCounters.find(state.BeginMatchIndex);
+ if (auto activeStateIter = ActiveStateCounters.find(state.Match.BeginIndex),
+ finishedStateIter = FinishedStateCounters.find(state.Match.BeginIndex);
((activeStateIter != ActiveStateCounters.end() &&
finishedStateIter != FinishedStateCounters.end() &&
activeStateIter->second == finishedStateIter->second) ||
@@ -463,16 +474,16 @@ public:
return false;
}
- std::optional<TMatchedVars> GetMatched() {
+ std::optional<TMatch> GetMatched() {
for (auto& state: ActiveStates) {
- if (auto activeStateIter = ActiveStateCounters.find(state.BeginMatchIndex),
- finishedStateIter = FinishedStateCounters.find(state.BeginMatchIndex);
+ if (auto activeStateIter = ActiveStateCounters.find(state.Match.BeginIndex),
+ finishedStateIter = FinishedStateCounters.find(state.Match.BeginIndex);
((activeStateIter != ActiveStateCounters.end() &&
finishedStateIter != FinishedStateCounters.end() &&
activeStateIter->second == finishedStateIter->second) ||
EndOfData) &&
state.Index == TransitionGraph->Output) {
- auto result = state.Vars;
+ auto result = state.Match;
Erase(std::move(state));
return result;
}
@@ -515,9 +526,9 @@ public:
auto activeStateCountersSize = serializer.Read<ui64>();
for (size_t i = 0; i < activeStateCountersSize; ++i) {
using map_type = decltype(ActiveStateCounters);
- auto beginMatchIndex = serializer.Read<map_type::key_type>();
+ auto matchBeginIndex = serializer.Read<map_type::key_type>();
auto counter = serializer.Read<map_type::mapped_type>();
- ActiveStateCounters.emplace(beginMatchIndex, counter);
+ ActiveStateCounters.emplace(matchBeginIndex, counter);
}
}
{
@@ -525,9 +536,9 @@ public:
auto finishedStateCountersSize = serializer.Read<ui64>();
for (size_t i = 0; i < finishedStateCountersSize; ++i) {
using map_type = decltype(FinishedStateCounters);
- auto beginMatchIndex = serializer.Read<map_type::key_type>();
+ auto matchBeginIndex = serializer.Read<map_type::key_type>();
auto counter = serializer.Read<map_type::mapped_type>();
- FinishedStateCounters.emplace(beginMatchIndex, counter);
+ FinishedStateCounters.emplace(matchBeginIndex, counter);
}
}
}
@@ -537,10 +548,31 @@ public:
return HasMatched();
}
- void Clear() {
- ActiveStates.clear();
- ActiveStateCounters.clear();
- FinishedStateCounters.clear();
+ void AfterMatchSkip(const TMatch& match) {
+ const auto skipToRowIndex = [&]() {
+ switch (SkipTo_.To) {
+ case EAfterMatchSkipTo::NextRow:
+ return match.BeginIndex + 1;
+ case EAfterMatchSkipTo::PastLastRow:
+ return match.EndIndex + 1;
+ case EAfterMatchSkipTo::ToFirst:
+ MKQL_ENSURE(false, "AFTER MATCH SKIP TO FIRST is not implemented yet");
+ case EAfterMatchSkipTo::ToLast:
+ [[fallthrough]];
+ case EAfterMatchSkipTo::To:
+ MKQL_ENSURE(false, "AFTER MATCH SKIP TO LAST is not implemented yet");
+ }
+ }();
+
+ TStateSet deletedStates;
+ for (const auto& state : ActiveStates) {
+ if (state.Match.BeginIndex < skipToRowIndex) {
+ deletedStates.insert(state);
+ }
+ }
+ for (auto& state : deletedStates) {
+ Erase(std::move(state));
+ }
}
private:
@@ -561,14 +593,14 @@ private:
[&](const TEpsilonTransitions& epsilonTransitions) {
deletedStates.insert(state);
for (const auto& i : epsilonTransitions.To) {
- newStates.emplace(state.BeginMatchIndex, state.EndMatchIndex, i, state.Vars, state.Quantifiers);
+ newStates.emplace(i, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, state.Quantifiers);
}
},
[&](const TQuantityEnterTransition& quantityEnterTransition) {
deletedStates.insert(state);
auto quantifiers = state.Quantifiers; //TODO get rid of this copy
quantifiers.push_back(0);
- newStates.emplace(state.BeginMatchIndex, state.EndMatchIndex, quantityEnterTransition.To, state.Vars, std::move(quantifiers));
+ newStates.emplace(quantityEnterTransition.To, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(quantifiers));
},
[&](const TQuantityExitTransition& quantityExitTransition) {
deletedStates.insert(state);
@@ -576,12 +608,12 @@ private:
if (state.Quantifiers.back() + 1 < quantityMax) {
auto q = state.Quantifiers;
q.back()++;
- newStates.emplace(state.BeginMatchIndex, state.EndMatchIndex, toFindMore, state.Vars, std::move(q));
+ newStates.emplace(toFindMore, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(q));
}
if (quantityMin <= state.Quantifiers.back() + 1 && state.Quantifiers.back() + 1 <= quantityMax) {
auto q = state.Quantifiers;
q.pop_back();
- newStates.emplace(state.BeginMatchIndex, state.EndMatchIndex, toMatched, state.Vars, std::move(q));
+ newStates.emplace(toMatched, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(q));
}
},
}, TransitionGraph->Transitions[state.Index]);
@@ -610,22 +642,22 @@ private:
}
void Insert(TState state) {
- auto beginMatchIndex = state.BeginMatchIndex;
+ auto matchBeginIndex = state.Match.BeginIndex;
const auto& transition = TransitionGraph->Transitions[state.Index];
auto diff = static_cast<i64>(ActiveStates.insert(std::move(state)).second);
- Add(ActiveStateCounters, beginMatchIndex, diff);
+ Add(ActiveStateCounters, matchBeginIndex, diff);
if (std::holds_alternative<TVoidTransition>(transition)) {
- Add(FinishedStateCounters, beginMatchIndex, diff);
+ Add(FinishedStateCounters, matchBeginIndex, diff);
}
}
void Erase(TState state) {
- auto beginMatchIndex = state.BeginMatchIndex;
+ auto matchBeginIndex = state.Match.BeginIndex;
const auto& transition = TransitionGraph->Transitions[state.Index];
auto diff = -static_cast<i64>(ActiveStates.erase(std::move(state)));
- Add(ActiveStateCounters, beginMatchIndex, diff);
+ Add(ActiveStateCounters, matchBeginIndex, diff);
if (std::holds_alternative<TVoidTransition>(transition)) {
- Add(FinishedStateCounters, beginMatchIndex, diff);
+ Add(FinishedStateCounters, matchBeginIndex, diff);
}
}
@@ -636,6 +668,7 @@ private:
THashMap<size_t, i64> ActiveStateCounters;
THashMap<size_t, i64> FinishedStateCounters;
bool EndOfData = false;
+ TAfterMatchSkipTo SkipTo_;
};
}//namespace NKikimr::NMiniKQL::NMatchRecognize
diff --git a/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp b/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp
index 745c88e084..afdbd6b8e8 100644
--- a/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp
+++ b/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp
@@ -59,7 +59,7 @@ struct TNfaSetup {
for (auto& d: Defines) {
defines.push_back(d);
}
- return TNfa(transitionGraph, MatchedVars, defines);
+ return TNfa(transitionGraph, MatchedVars, defines, TAfterMatchSkipTo{EAfterMatchSkipTo::PastLastRow, ""});
}
TComputationNodeFactory GetAuxCallableFactory() {
@@ -261,8 +261,8 @@ Y_UNIT_TEST_SUITE(MatchRecognizeNfa) {
Iota(expectedTo.begin(), expectedTo.end(), i - seriesLength + 1);
for (size_t matchCount = 0; matchCount < seriesLength; ++matchCount) {
auto match = setup.Nfa.GetMatched();
- UNIT_ASSERT_C(match.has_value(), i);
- auto vars = match.value();
+ UNIT_ASSERT_C(match, i);
+ auto vars = match->Vars;
UNIT_ASSERT_VALUES_EQUAL_C(1, vars.size(), i);
auto var = vars[0];
UNIT_ASSERT_VALUES_EQUAL_C(1, var.size(), i);