aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/comp_nodes
diff options
context:
space:
mode:
author: Vitaly Isaev <vitalyisaev@ydb.tech> 2024-12-12 15:39:00 +0000
committer: Vitaly Isaev <vitalyisaev@ydb.tech> 2024-12-12 15:39:00 +0000
commit: 827b115675004838023427572a7c69f40a86a80a (patch)
tree: e99c953fe494b9de8d8597a15859d77c81f118c7 /yql/essentials/minikql/comp_nodes
parent: 42701242eaf5be980cb935631586d0e90b82641c (diff)
parent: fab222fd8176d00eee5ddafc6bce8cb95a6e3ab0 (diff)
download: ydb-827b115675004838023427572a7c69f40a86a80a.tar.gz
Merge branch 'rightlib' into rightlib_20241212
Diffstat (limited to 'yql/essentials/minikql/comp_nodes')
-rw-r--r-- yql/essentials/minikql/comp_nodes/mkql_grace_join.cpp 22
-rw-r--r-- yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.cpp 69
-rw-r--r-- yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.h 2
-rw-r--r-- yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp 10
-rw-r--r-- yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h 133
-rw-r--r-- yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp 6
6 files changed, 165 insertions, 77 deletions
diff --git a/yql/essentials/minikql/comp_nodes/mkql_grace_join.cpp b/yql/essentials/minikql/comp_nodes/mkql_grace_join.cpp
index 08986bc8fd..871321786b 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_grace_join.cpp
+++ b/yql/essentials/minikql/comp_nodes/mkql_grace_join.cpp
@@ -643,12 +643,15 @@ private:
}
void SwitchMode(EOperatingMode mode, TComputationContext& ctx) {
+ LogMemoryUsage();
switch(mode) {
case EOperatingMode::InMemory: {
+ YQL_LOG(INFO) << (const void *)&*JoinedTablePtr << "# switching Memory mode to InMemory";
MKQL_ENSURE(false, "Internal logic error");
break;
}
case EOperatingMode::Spilling: {
+ YQL_LOG(INFO) << (const void *)&*JoinedTablePtr << "# switching Memory mode to Spilling";
MKQL_ENSURE(EOperatingMode::InMemory == Mode, "Internal logic error");
auto spiller = ctx.SpillerFactory->CreateSpiller();
RightPacker->TablePtr->InitializeBucketSpillers(spiller);
@@ -656,6 +659,7 @@ private:
break;
}
case EOperatingMode::ProcessSpilled: {
+ YQL_LOG(INFO) << (const void *)&*JoinedTablePtr << "# switching Memory mode to ProcessSpilled";
SpilledBucketsJoinOrder.reserve(GraceJoin::NumberOfBuckets);
for (ui32 i = 0; i < GraceJoin::NumberOfBuckets; ++i) SpilledBucketsJoinOrder.push_back(i);
@@ -843,9 +847,6 @@ private:
if (isYield == EFetchResult::One)
return isYield;
if (IsSpillingAllowed && ctx.SpillerFactory && IsSwitchToSpillingModeCondition()) {
- LogMemoryUsage();
- YQL_LOG(INFO) << (const void *)&*JoinedTablePtr << "# switching Memory mode to Spilling";
-
SwitchMode(EOperatingMode::Spilling, ctx);
return EFetchResult::Yield;
}
@@ -861,14 +862,18 @@ private:
<< " HaveLeft " << *HaveMoreLeftRows << " LeftPacked " << LeftPacker->TuplesBatchPacked << " LeftBatch " << LeftPacker->BatchSize
<< " HaveRight " << *HaveMoreRightRows << " RightPacked " << RightPacker->TuplesBatchPacked << " RightBatch " << RightPacker->BatchSize
;
+
+ auto& leftTable = *LeftPacker->TablePtr;
+ auto& rightTable = SelfJoinSameKeys_ ? *LeftPacker->TablePtr : *RightPacker->TablePtr;
+ if (IsSpillingAllowed && ctx.SpillerFactory && !JoinedTablePtr->TryToPreallocateMemoryForJoin(leftTable, rightTable, JoinKind, *HaveMoreLeftRows, *HaveMoreRightRows)) {
+ SwitchMode(EOperatingMode::Spilling, ctx);
+ return EFetchResult::Yield;
+ }
+
*PartialJoinCompleted = true;
LeftPacker->StartTime = std::chrono::system_clock::now();
RightPacker->StartTime = std::chrono::system_clock::now();
- if ( SelfJoinSameKeys_ ) {
- JoinedTablePtr->Join(*LeftPacker->TablePtr, *LeftPacker->TablePtr, JoinKind, *HaveMoreLeftRows, *HaveMoreRightRows);
- } else {
- JoinedTablePtr->Join(*LeftPacker->TablePtr, *RightPacker->TablePtr, JoinKind, *HaveMoreLeftRows, *HaveMoreRightRows);
- }
+ JoinedTablePtr->Join(leftTable, rightTable, JoinKind, *HaveMoreLeftRows, *HaveMoreRightRows);
JoinedTablePtr->ResetIterator();
LeftPacker->EndTime = std::chrono::system_clock::now();
RightPacker->EndTime = std::chrono::system_clock::now();
@@ -945,7 +950,6 @@ EFetchResult DoCalculateWithSpilling(TComputationContext& ctx, NUdf::TUnboxedVal
}
if (!IsReadyForSpilledDataProcessing()) return EFetchResult::Yield;
- YQL_LOG(INFO) << (const void *)&*JoinedTablePtr << "# switching to ProcessSpilled";
SwitchMode(EOperatingMode::ProcessSpilled, ctx);
return EFetchResult::Finish;
}
diff --git a/yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.cpp b/yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.cpp
index f9b19fdfbc..ae7c89daef 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.cpp
+++ b/yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.cpp
@@ -320,6 +320,63 @@ void ResizeHashTable(KeysHashTable &t, ui64 newSlots){
}
+// Tells whether the two join inputs have to be swapped.
+// The answer is yes whenever table2Batch is set; otherwise it is yes only when
+// table1Batch is clear and the second table holds more tuples than the first.
+bool IsTablesSwapRequired(ui64 tuplesNum1, ui64 tuplesNum2, bool table1Batch, bool table2Batch) {
+    if (table2Batch) {
+        return true;
+    }
+    return !table1Batch && tuplesNum2 > tuplesNum1;
+}
+
+// Computes the size of a single hash-table slot for the given bucket.
+// Slot layout: Header [Short Strings] SlotIdx. The strings part is only accounted for
+// when the table has key string columns or key interface (I) columns.
+ui64 ComputeJoinSlotsSizeForBucket(const TTableBucket& bucket, const TTableBucketStats& bucketStats, ui64 headerSize, bool tableHasKeyStringColumns, bool tableHasKeyIColumns) {
+    const ui64 fixedPartSize = headerSize + 1; // Header plus SlotIdx
+    if (!tableHasKeyStringColumns && !tableHasKeyIColumns) {
+        return fixedPartSize;
+    }
+
+    // Whatever KeyIntVals stores beyond tuplesNum headers is spread over the tuples
+    // (scaled by 3/2 and rounded up by one) to estimate the strings part of a slot.
+    const ui64 tuplesNum = bucketStats.TuplesNum;
+    const ui64 beyondHeaders = bucket.KeyIntVals.size() - tuplesNum * headerSize;
+    const ui64 avgStringsSize = (3 * beyondHeaders) / (2 * tuplesNum + 1) + 1;
+    return fixedPartSize + avgStringsSize;
+}
+
+// Returns the number of hash-table slots used for tuplesNum tuples:
+// 3 * tuplesNum + 1 with the lowest bit forced on, so the result is always odd.
+ui64 ComputeNumberOfSlots(ui64 tuplesNum) {
+    const ui64 slots = 3 * tuplesNum + 1;
+    return slots | 1;
+}
+
+// Tries to reserve the hash-table storage (JoinSlots) for every bucket that has tuples but
+// no hash table yet (NSlots == 0), picking per bucket the table selected by IsTablesSwapRequired().
+//
+// t1, t2              - the tables about to be joined.
+// joinKind            - currently not used by the preallocation logic.
+// hasMoreLeftTuples,
+// hasMoreRightTuples  - true when the join is partial and more rows are coming for that side.
+//
+// Returns true when no preallocation is needed or every reservation succeeded.
+// Returns false when a reservation threw TMemoryLimitExceededException; in that case JoinSlots
+// of all preceding buckets of both t1 and t2 are emptied and their memory is released.
+bool TTable::TryToPreallocateMemoryForJoin(TTable & t1, TTable & t2, EJoinKind joinKind, bool hasMoreLeftTuples, bool hasMoreRightTuples) {
+    // If the batch is final or the only one, then the buckets are processed sequentially, the memory for the hash tables is freed immediately after processing.
+    // So, no preallocation is required.
+    if (!hasMoreLeftTuples && !hasMoreRightTuples) return true;
+
+    // Drops the contents of a bucket's JoinSlots and gives the memory back.
+    const auto releaseJoinSlots = [](TTableBucket& b) {
+        b.JoinSlots.resize(0);
+        b.JoinSlots.shrink_to_fit();
+    };
+
+    for (ui64 bucket = 0; bucket < GraceJoin::NumberOfBuckets; bucket++) {
+        ui64 tuplesNum1 = t1.TableBucketsStats[bucket].TuplesNum;
+        ui64 tuplesNum2 = t2.TableBucketsStats[bucket].TuplesNum;
+
+        TTable& tableForPreallocation = IsTablesSwapRequired(tuplesNum1, tuplesNum2, hasMoreLeftTuples || LeftTableBatch_, hasMoreRightTuples || RightTableBatch_) ? t1 : t2;
+        // Nothing to do for an empty bucket or for a bucket whose hash table already exists.
+        if (!tableForPreallocation.TableBucketsStats[bucket].TuplesNum || tableForPreallocation.TableBuckets[bucket].NSlots) continue;
+
+        TTableBucket& bucketForPreallocation = tableForPreallocation.TableBuckets[bucket];
+        const TTableBucketStats& bucketForPreallocationStats = tableForPreallocation.TableBucketsStats[bucket];
+
+        const auto slotSize = ComputeJoinSlotsSizeForBucket(bucketForPreallocation, bucketForPreallocationStats, tableForPreallocation.HeaderSize,
+            tableForPreallocation.NumberOfKeyStringColumns != 0, tableForPreallocation.NumberOfKeyIColumns != 0);
+        const auto nSlots = ComputeNumberOfSlots(bucketForPreallocationStats.TuplesNum);
+
+        try {
+            bucketForPreallocation.JoinSlots.reserve(nSlots*slotSize);
+        } catch (const TMemoryLimitExceededException&) {
+            // Roll back through t1/t2: those are the tables reserved in above. The JoinTable1/JoinTable2
+            // members are never assigned by this function, so they are not relied upon here.
+            // NOTE(review): this also empties JoinSlots of earlier buckets that were skipped above
+            // because NSlots was already non-zero - confirm that this is what callers expect.
+            for (ui64 i = 0; i < bucket; ++i) {
+                releaseJoinSlots(t1.TableBuckets[i]);
+                releaseJoinSlots(t2.TableBuckets[i]);
+            }
+            return false;
+        }
+    }
+
+    return true;
+}
+
// Joins two tables and returns join result in joined table. Tuples of joined table could be received by
// joined table iterator
@@ -368,7 +425,7 @@ void TTable::Join( TTable & t1, TTable & t2, EJoinKind joinKind, bool hasMoreLef
bool table2HasKeyStringColumns = (JoinTable2->NumberOfKeyStringColumns != 0);
bool table1HasKeyIColumns = (JoinTable1->NumberOfKeyIColumns != 0);
bool table2HasKeyIColumns = (JoinTable2->NumberOfKeyIColumns != 0);
- bool swapTables = tuplesNum2 > tuplesNum1 && !table1Batch || table2Batch;
+ bool swapTables = IsTablesSwapRequired(tuplesNum1, tuplesNum2, table1Batch, table2Batch);
if (swapTables) {
@@ -402,13 +459,7 @@ void TTable::Join( TTable & t1, TTable & t2, EJoinKind joinKind, bool hasMoreLef
if (tuplesNum1 == 0 && (hasMoreRightTuples || hasMoreLeftTuples || !bucketStats2->HashtableMatches))
continue;
- ui64 slotSize = headerSize2 + 1; // Header [Short Strings] SlotIdx
-
- ui64 avgStringsSize = ( 3 * (bucket2->KeyIntVals.size() - tuplesNum2 * headerSize2) ) / ( 2 * tuplesNum2 + 1) + 1;
-
- if (table2HasKeyStringColumns || table2HasKeyIColumns ) {
- slotSize = slotSize + avgStringsSize;
- }
+ ui64 slotSize = ComputeJoinSlotsSizeForBucket(*bucket2, *bucketStats2, headerSize2, table2HasKeyStringColumns, table2HasKeyIColumns);
ui64 &nSlots = bucket2->NSlots;
auto &joinSlots = bucket2->JoinSlots;
@@ -417,7 +468,7 @@ void TTable::Join( TTable & t1, TTable & t2, EJoinKind joinKind, bool hasMoreLef
Y_DEBUG_ABORT_UNLESS(bucketStats2->SlotSize == 0 || bucketStats2->SlotSize == slotSize);
if (!nSlots) {
- nSlots = (3 * tuplesNum2 + 1) | 1;
+ nSlots = ComputeNumberOfSlots(tuplesNum2);
joinSlots.resize(nSlots*slotSize, 0);
bloomFilter.Resize(tuplesNum2);
initHashTable = true;
diff --git a/yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.h b/yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.h
index a4846926d1..d6b9a54aca 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.h
+++ b/yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.h
@@ -346,6 +346,8 @@ public:
// Returns value of next tuple. Returs true if there are more tuples
bool NextTuple(TupleData& td);
+ bool TryToPreallocateMemoryForJoin(TTable & t1, TTable & t2, EJoinKind joinKind, bool hasMoreLeftTuples, bool hasMoreRightTuples);
+
// Joins two tables and stores join result in table data. Tuples of joined table could be received by
// joined table iterator. Life time of t1, t2 should be greater than lifetime of joined table
// hasMoreLeftTuples, hasMoreRightTuples is true if join is partial and more rows are coming. For final batch hasMoreLeftTuples = false, hasMoreRightTuples = false
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp b/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp
index 0f758c7a8b..80caaad37e 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp
+++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp
@@ -54,7 +54,7 @@ public:
)
: PartitionKey(std::move(partitionKey))
, Parameters(parameters)
- , Nfa(nfaTransitions, parameters.MatchedVarsArg, parameters.Defines)
+ , Nfa(nfaTransitions, parameters.MatchedVarsArg, parameters.Defines, parameters.SkipTo)
, Cache(cache)
{
}
@@ -72,10 +72,10 @@ public:
NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) {
auto match = Nfa.GetMatched();
- if (!match.has_value()) {
+ if (!match) {
return NUdf::TUnboxedValue{};
}
- Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, match.value()));
+ Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, match->Vars));
Parameters.MeasureInputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TMeasureInputDataValue>(
ctx.HolderFactory.Create<TListValue<TSparseList>>(Rows),
Parameters.MeasureInputColumnOrder,
@@ -95,9 +95,7 @@ public:
break;
}
}
- if (EAfterMatchSkipTo::PastLastRow == Parameters.SkipTo.To) {
- Nfa.Clear();
- }
+ Nfa.AfterMatchSkip(*match);
return result;
}
bool ProcessEndOfData(TComputationContext& ctx) {
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h
index 944164e4bc..2b194212f4 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h
+++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h
@@ -350,19 +350,25 @@ class TNfa {
using TRange = TSparseList::TRange;
using TMatchedVars = TMatchedVars<TRange>;
+public:
+ struct TMatch { // One match: the row span it covers plus the ranges captured for the pattern variables.
+ size_t BeginIndex; // Row where the match starts; ProcessRow seeds it with currentRowLock.From().
+ size_t EndIndex; // Last row of the match; set to currentRowLock.To() whenever a row is consumed.
+ TMatchedVars Vars; // Matched ranges per variable; ProcessRow sizes it to Defines.size().
+ };
+
+private:
struct TState {
- size_t BeginMatchIndex;
- size_t EndMatchIndex;
size_t Index;
- TMatchedVars Vars;
+ TMatch Match;
std::deque<ui64, TMKQLAllocator<ui64>> Quantifiers;
void Save(TMrOutputSerializer& serializer) const {
- serializer.Write(BeginMatchIndex);
- serializer.Write(EndMatchIndex);
serializer.Write(Index);
- serializer.Write(Vars.size());
- for (const auto& vector : Vars) {
+ serializer.Write(Match.BeginIndex);
+ serializer.Write(Match.EndIndex);
+ serializer.Write(Match.Vars.size());
+ for (const auto& vector : Match.Vars) {
serializer.Write(vector.size());
for (const auto& range : vector) {
range.Save(serializer);
@@ -375,13 +381,13 @@ class TNfa {
}
void Load(TMrInputSerializer& serializer) {
- serializer.Read(BeginMatchIndex);
- serializer.Read(EndMatchIndex);
serializer.Read(Index);
+ serializer.Read(Match.BeginIndex);
+ serializer.Read(Match.EndIndex);
auto varsSize = serializer.Read<TMatchedVars::size_type>();
- Vars.clear();
- Vars.resize(varsSize);
- for (auto& subvec: Vars) {
+ Match.Vars.clear();
+ Match.Vars.resize(varsSize);
+ for (auto& subvec: Match.Vars) {
ui64 vectorSize = serializer.Read<ui64>();
subvec.resize(vectorSize);
for (auto& item : subvec) {
@@ -397,24 +403,29 @@ class TNfa {
}
friend inline bool operator<(const TState& lhs, const TState& rhs) {
- auto lhsEndMatchIndex = -static_cast<i64>(lhs.EndMatchIndex);
- auto rhsEndMatchIndex = -static_cast<i64>(rhs.EndMatchIndex);
- return std::tie(lhs.BeginMatchIndex, lhsEndMatchIndex, lhs.Index, lhs.Quantifiers, lhs.Vars) < std::tie(rhs.BeginMatchIndex, rhsEndMatchIndex, rhs.Index, rhs.Quantifiers, rhs.Vars);
+ auto lhsMatchEndIndex = -static_cast<i64>(lhs.Match.EndIndex);
+ auto rhsMatchEndIndex = -static_cast<i64>(rhs.Match.EndIndex);
+ return std::tie(lhs.Match.BeginIndex, lhsMatchEndIndex, lhs.Index, lhs.Match.Vars, lhs.Quantifiers) < std::tie(rhs.Match.BeginIndex, rhsMatchEndIndex, rhs.Index, rhs.Match.Vars, rhs.Quantifiers);
}
friend inline bool operator==(const TState& lhs, const TState& rhs) {
- return std::tie(lhs.BeginMatchIndex, lhs.EndMatchIndex, lhs.Index, lhs.Quantifiers, lhs.Vars) == std::tie(rhs.BeginMatchIndex, rhs.EndMatchIndex, rhs.Index, rhs.Quantifiers, rhs.Vars);
+ return std::tie(lhs.Match.BeginIndex, lhs.Match.EndIndex, lhs.Index, lhs.Match.Vars, lhs.Quantifiers) == std::tie(rhs.Match.BeginIndex, rhs.Match.EndIndex, rhs.Index, rhs.Match.Vars, rhs.Quantifiers);
}
};
public:
- TNfa(TNfaTransitionGraph::TPtr transitionGraph, IComputationExternalNode* matchedRangesArg, const TComputationNodePtrVector& defines)
- : TransitionGraph(transitionGraph)
- , MatchedRangesArg(matchedRangesArg)
- , Defines(defines) {
- }
+ TNfa( // Stores the transition graph, the matched-ranges argument node, the DEFINE nodes and the skip mode.
+ TNfaTransitionGraph::TPtr transitionGraph,
+ IComputationExternalNode* matchedRangesArg, // Receives the current state's Vars before a define is evaluated in ProcessRow.
+ const TComputationNodePtrVector& defines,
+ TAfterMatchSkipTo skipTo) // Read by AfterMatchSkip() to decide which states to erase after a match.
+ : TransitionGraph(transitionGraph)
+ , MatchedRangesArg(matchedRangesArg)
+ , Defines(defines)
+ , SkipTo_(skipTo)
+ {}
void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx) {
- TState state(currentRowLock.From(), currentRowLock.To(), TransitionGraph->Input, TMatchedVars(Defines.size()), std::deque<ui64, TMKQLAllocator<ui64>>{});
+ TState state(TransitionGraph->Input, TMatch{currentRowLock.From(), currentRowLock.To(), TMatchedVars(Defines.size())}, std::deque<ui64, TMKQLAllocator<ui64>>{});
Insert(std::move(state));
MakeEpsilonTransitions();
TStateSet newStates;
@@ -423,17 +434,17 @@ public:
//Here we handle only transitions of TMatchedVarTransition type,
//all other transitions are handled in MakeEpsilonTransitions
if (const auto* matchedVarTransition = std::get_if<TMatchedVarTransition>(&TransitionGraph->Transitions[state.Index])) {
- MatchedRangesArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, state.Vars));
+ MatchedRangesArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, state.Match.Vars));
const auto varIndex = matchedVarTransition->VarIndex;
const auto& v = Defines[varIndex]->GetValue(ctx);
if (v && v.Get<bool>()) {
if (matchedVarTransition->SaveState) {
- auto vars = state.Vars; //TODO get rid of this copy
+ auto vars = state.Match.Vars; //TODO get rid of this copy
auto& matchedVar = vars[varIndex];
Extend(matchedVar, currentRowLock);
- newStates.emplace(state.BeginMatchIndex, currentRowLock.To(), matchedVarTransition->To, std::move(vars), state.Quantifiers);
+ newStates.emplace(matchedVarTransition->To, TMatch{state.Match.BeginIndex, currentRowLock.To(), std::move(vars)}, state.Quantifiers);
} else {
- newStates.emplace(state.BeginMatchIndex, currentRowLock.To(), matchedVarTransition->To, state.Vars, state.Quantifiers);
+ newStates.emplace(matchedVarTransition->To, TMatch{state.Match.BeginIndex, currentRowLock.To(), state.Match.Vars}, state.Quantifiers);
}
}
deletedStates.insert(state);
@@ -450,8 +461,8 @@ public:
bool HasMatched() const {
for (auto& state: ActiveStates) {
- if (auto activeStateIter = ActiveStateCounters.find(state.BeginMatchIndex),
- finishedStateIter = FinishedStateCounters.find(state.BeginMatchIndex);
+ if (auto activeStateIter = ActiveStateCounters.find(state.Match.BeginIndex),
+ finishedStateIter = FinishedStateCounters.find(state.Match.BeginIndex);
((activeStateIter != ActiveStateCounters.end() &&
finishedStateIter != FinishedStateCounters.end() &&
activeStateIter->second == finishedStateIter->second) ||
@@ -463,16 +474,16 @@ public:
return false;
}
- std::optional<TMatchedVars> GetMatched() {
+ std::optional<TMatch> GetMatched() {
for (auto& state: ActiveStates) {
- if (auto activeStateIter = ActiveStateCounters.find(state.BeginMatchIndex),
- finishedStateIter = FinishedStateCounters.find(state.BeginMatchIndex);
+ if (auto activeStateIter = ActiveStateCounters.find(state.Match.BeginIndex),
+ finishedStateIter = FinishedStateCounters.find(state.Match.BeginIndex);
((activeStateIter != ActiveStateCounters.end() &&
finishedStateIter != FinishedStateCounters.end() &&
activeStateIter->second == finishedStateIter->second) ||
EndOfData) &&
state.Index == TransitionGraph->Output) {
- auto result = state.Vars;
+ auto result = state.Match;
Erase(std::move(state));
return result;
}
@@ -515,9 +526,9 @@ public:
auto activeStateCountersSize = serializer.Read<ui64>();
for (size_t i = 0; i < activeStateCountersSize; ++i) {
using map_type = decltype(ActiveStateCounters);
- auto beginMatchIndex = serializer.Read<map_type::key_type>();
+ auto matchBeginIndex = serializer.Read<map_type::key_type>();
auto counter = serializer.Read<map_type::mapped_type>();
- ActiveStateCounters.emplace(beginMatchIndex, counter);
+ ActiveStateCounters.emplace(matchBeginIndex, counter);
}
}
{
@@ -525,9 +536,9 @@ public:
auto finishedStateCountersSize = serializer.Read<ui64>();
for (size_t i = 0; i < finishedStateCountersSize; ++i) {
using map_type = decltype(FinishedStateCounters);
- auto beginMatchIndex = serializer.Read<map_type::key_type>();
+ auto matchBeginIndex = serializer.Read<map_type::key_type>();
auto counter = serializer.Read<map_type::mapped_type>();
- FinishedStateCounters.emplace(beginMatchIndex, counter);
+ FinishedStateCounters.emplace(matchBeginIndex, counter);
}
}
}
@@ -537,10 +548,31 @@ public:
return HasMatched();
}
- void Clear() {
- ActiveStates.clear();
- ActiveStateCounters.clear();
- FinishedStateCounters.clear();
+ void AfterMatchSkip(const TMatch& match) { // Erases every active state whose match began before the row chosen by SkipTo_.
+ const auto skipToRowIndex = [&]() { // Row index below which states are erased.
+ switch (SkipTo_.To) {
+ case EAfterMatchSkipTo::NextRow:
+ return match.BeginIndex + 1;
+ case EAfterMatchSkipTo::PastLastRow:
+ return match.EndIndex + 1;
+ case EAfterMatchSkipTo::ToFirst:
+ MKQL_ENSURE(false, "AFTER MATCH SKIP TO FIRST is not implemented yet"); // NOTE(review): no return/break after this; presumably MKQL_ENSURE(false, ...) never returns - confirm.
+ case EAfterMatchSkipTo::ToLast:
+ [[fallthrough]];
+ case EAfterMatchSkipTo::To:
+ MKQL_ENSURE(false, "AFTER MATCH SKIP TO LAST is not implemented yet"); // NOTE(review): the lambda has no return on this path either - confirm it cannot fall off the end.
+ }
+ }();
+
+ TStateSet deletedStates; // Collected first, because Erase() modifies ActiveStates.
+ for (const auto& state : ActiveStates) {
+ if (state.Match.BeginIndex < skipToRowIndex) {
+ deletedStates.insert(state);
+ }
+ }
+ for (auto& state : deletedStates) {
+ Erase(std::move(state)); // Erase() takes TState by value and also updates the per-BeginIndex counters.
+ }
}
private:
@@ -561,14 +593,14 @@ private:
[&](const TEpsilonTransitions& epsilonTransitions) {
deletedStates.insert(state);
for (const auto& i : epsilonTransitions.To) {
- newStates.emplace(state.BeginMatchIndex, state.EndMatchIndex, i, state.Vars, state.Quantifiers);
+ newStates.emplace(i, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, state.Quantifiers);
}
},
[&](const TQuantityEnterTransition& quantityEnterTransition) {
deletedStates.insert(state);
auto quantifiers = state.Quantifiers; //TODO get rid of this copy
quantifiers.push_back(0);
- newStates.emplace(state.BeginMatchIndex, state.EndMatchIndex, quantityEnterTransition.To, state.Vars, std::move(quantifiers));
+ newStates.emplace(quantityEnterTransition.To, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(quantifiers));
},
[&](const TQuantityExitTransition& quantityExitTransition) {
deletedStates.insert(state);
@@ -576,12 +608,12 @@ private:
if (state.Quantifiers.back() + 1 < quantityMax) {
auto q = state.Quantifiers;
q.back()++;
- newStates.emplace(state.BeginMatchIndex, state.EndMatchIndex, toFindMore, state.Vars, std::move(q));
+ newStates.emplace(toFindMore, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(q));
}
if (quantityMin <= state.Quantifiers.back() + 1 && state.Quantifiers.back() + 1 <= quantityMax) {
auto q = state.Quantifiers;
q.pop_back();
- newStates.emplace(state.BeginMatchIndex, state.EndMatchIndex, toMatched, state.Vars, std::move(q));
+ newStates.emplace(toMatched, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(q));
}
},
}, TransitionGraph->Transitions[state.Index]);
@@ -610,22 +642,22 @@ private:
}
void Insert(TState state) {
- auto beginMatchIndex = state.BeginMatchIndex;
+ auto matchBeginIndex = state.Match.BeginIndex;
const auto& transition = TransitionGraph->Transitions[state.Index];
auto diff = static_cast<i64>(ActiveStates.insert(std::move(state)).second);
- Add(ActiveStateCounters, beginMatchIndex, diff);
+ Add(ActiveStateCounters, matchBeginIndex, diff);
if (std::holds_alternative<TVoidTransition>(transition)) {
- Add(FinishedStateCounters, beginMatchIndex, diff);
+ Add(FinishedStateCounters, matchBeginIndex, diff);
}
}
void Erase(TState state) {
- auto beginMatchIndex = state.BeginMatchIndex;
+ auto matchBeginIndex = state.Match.BeginIndex;
const auto& transition = TransitionGraph->Transitions[state.Index];
auto diff = -static_cast<i64>(ActiveStates.erase(std::move(state)));
- Add(ActiveStateCounters, beginMatchIndex, diff);
+ Add(ActiveStateCounters, matchBeginIndex, diff);
if (std::holds_alternative<TVoidTransition>(transition)) {
- Add(FinishedStateCounters, beginMatchIndex, diff);
+ Add(FinishedStateCounters, matchBeginIndex, diff);
}
}
@@ -636,6 +668,7 @@ private:
THashMap<size_t, i64> ActiveStateCounters;
THashMap<size_t, i64> FinishedStateCounters;
bool EndOfData = false;
+ TAfterMatchSkipTo SkipTo_;
};
}//namespace NKikimr::NMiniKQL::NMatchRecognize
diff --git a/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp b/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp
index 745c88e084..afdbd6b8e8 100644
--- a/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp
+++ b/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp
@@ -59,7 +59,7 @@ struct TNfaSetup {
for (auto& d: Defines) {
defines.push_back(d);
}
- return TNfa(transitionGraph, MatchedVars, defines);
+ return TNfa(transitionGraph, MatchedVars, defines, TAfterMatchSkipTo{EAfterMatchSkipTo::PastLastRow, ""});
}
TComputationNodeFactory GetAuxCallableFactory() {
@@ -261,8 +261,8 @@ Y_UNIT_TEST_SUITE(MatchRecognizeNfa) {
Iota(expectedTo.begin(), expectedTo.end(), i - seriesLength + 1);
for (size_t matchCount = 0; matchCount < seriesLength; ++matchCount) {
auto match = setup.Nfa.GetMatched();
- UNIT_ASSERT_C(match.has_value(), i);
- auto vars = match.value();
+ UNIT_ASSERT_C(match, i);
+ auto vars = match->Vars;
UNIT_ASSERT_VALUES_EQUAL_C(1, vars.size(), i);
auto var = vars[0];
UNIT_ASSERT_VALUES_EQUAL_C(1, var.size(), i);