diff options
author | Vitaly Isaev <vitalyisaev@ydb.tech> | 2024-12-12 15:39:00 +0000 |
---|---|---|
committer | Vitaly Isaev <vitalyisaev@ydb.tech> | 2024-12-12 15:39:00 +0000 |
commit | 827b115675004838023427572a7c69f40a86a80a (patch) | |
tree | e99c953fe494b9de8d8597a15859d77c81f118c7 /yql/essentials/minikql/comp_nodes | |
parent | 42701242eaf5be980cb935631586d0e90b82641c (diff) | |
parent | fab222fd8176d00eee5ddafc6bce8cb95a6e3ab0 (diff) | |
download | ydb-827b115675004838023427572a7c69f40a86a80a.tar.gz |
Merge branch 'rightlib' into rightlib_20241212
Diffstat (limited to 'yql/essentials/minikql/comp_nodes')
6 files changed, 165 insertions, 77 deletions
diff --git a/yql/essentials/minikql/comp_nodes/mkql_grace_join.cpp b/yql/essentials/minikql/comp_nodes/mkql_grace_join.cpp index 08986bc8fd..871321786b 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_grace_join.cpp +++ b/yql/essentials/minikql/comp_nodes/mkql_grace_join.cpp @@ -643,12 +643,15 @@ private: } void SwitchMode(EOperatingMode mode, TComputationContext& ctx) { + LogMemoryUsage(); switch(mode) { case EOperatingMode::InMemory: { + YQL_LOG(INFO) << (const void *)&*JoinedTablePtr << "# switching Memory mode to InMemory"; MKQL_ENSURE(false, "Internal logic error"); break; } case EOperatingMode::Spilling: { + YQL_LOG(INFO) << (const void *)&*JoinedTablePtr << "# switching Memory mode to Spilling"; MKQL_ENSURE(EOperatingMode::InMemory == Mode, "Internal logic error"); auto spiller = ctx.SpillerFactory->CreateSpiller(); RightPacker->TablePtr->InitializeBucketSpillers(spiller); @@ -656,6 +659,7 @@ private: break; } case EOperatingMode::ProcessSpilled: { + YQL_LOG(INFO) << (const void *)&*JoinedTablePtr << "# switching Memory mode to ProcessSpilled"; SpilledBucketsJoinOrder.reserve(GraceJoin::NumberOfBuckets); for (ui32 i = 0; i < GraceJoin::NumberOfBuckets; ++i) SpilledBucketsJoinOrder.push_back(i); @@ -843,9 +847,6 @@ private: if (isYield == EFetchResult::One) return isYield; if (IsSpillingAllowed && ctx.SpillerFactory && IsSwitchToSpillingModeCondition()) { - LogMemoryUsage(); - YQL_LOG(INFO) << (const void *)&*JoinedTablePtr << "# switching Memory mode to Spilling"; - SwitchMode(EOperatingMode::Spilling, ctx); return EFetchResult::Yield; } @@ -861,14 +862,18 @@ private: << " HaveLeft " << *HaveMoreLeftRows << " LeftPacked " << LeftPacker->TuplesBatchPacked << " LeftBatch " << LeftPacker->BatchSize << " HaveRight " << *HaveMoreRightRows << " RightPacked " << RightPacker->TuplesBatchPacked << " RightBatch " << RightPacker->BatchSize ; + + auto& leftTable = *LeftPacker->TablePtr; + auto& rightTable = SelfJoinSameKeys_ ? *LeftPacker->TablePtr : *RightPacker->TablePtr; + if (IsSpillingAllowed && ctx.SpillerFactory && !JoinedTablePtr->TryToPreallocateMemoryForJoin(leftTable, rightTable, JoinKind, *HaveMoreLeftRows, *HaveMoreRightRows)) { + SwitchMode(EOperatingMode::Spilling, ctx); + return EFetchResult::Yield; + } + *PartialJoinCompleted = true; LeftPacker->StartTime = std::chrono::system_clock::now(); RightPacker->StartTime = std::chrono::system_clock::now(); - if ( SelfJoinSameKeys_ ) { - JoinedTablePtr->Join(*LeftPacker->TablePtr, *LeftPacker->TablePtr, JoinKind, *HaveMoreLeftRows, *HaveMoreRightRows); - } else { - JoinedTablePtr->Join(*LeftPacker->TablePtr, *RightPacker->TablePtr, JoinKind, *HaveMoreLeftRows, *HaveMoreRightRows); - } + JoinedTablePtr->Join(leftTable, rightTable, JoinKind, *HaveMoreLeftRows, *HaveMoreRightRows); JoinedTablePtr->ResetIterator(); LeftPacker->EndTime = std::chrono::system_clock::now(); RightPacker->EndTime = std::chrono::system_clock::now(); @@ -945,7 +950,6 @@ EFetchResult DoCalculateWithSpilling(TComputationContext& ctx, NUdf::TUnboxedVal } if (!IsReadyForSpilledDataProcessing()) return EFetchResult::Yield; - YQL_LOG(INFO) << (const void *)&*JoinedTablePtr << "# switching to ProcessSpilled"; SwitchMode(EOperatingMode::ProcessSpilled, ctx); return EFetchResult::Finish; } diff --git a/yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.cpp b/yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.cpp index f9b19fdfbc..ae7c89daef 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.cpp +++ b/yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.cpp @@ -320,6 +320,63 @@ void ResizeHashTable(KeysHashTable &t, ui64 newSlots){ } +bool IsTablesSwapRequired(ui64 tuplesNum1, ui64 tuplesNum2, bool table1Batch, bool table2Batch) { + return tuplesNum2 > tuplesNum1 && !table1Batch || table2Batch; +} + +ui64 ComputeJoinSlotsSizeForBucket(const TTableBucket& bucket, const TTableBucketStats& bucketStats, ui64 headerSize, bool tableHasKeyStringColumns, bool tableHasKeyIColumns) { + ui64 tuplesNum = bucketStats.TuplesNum; + + ui64 avgStringsSize = (3 * (bucket.KeyIntVals.size() - tuplesNum * headerSize) ) / ( 2 * tuplesNum + 1) + 1; + ui64 slotSize = headerSize + 1; // Header [Short Strings] SlotIdx + if (tableHasKeyStringColumns || tableHasKeyIColumns) { + slotSize = slotSize + avgStringsSize; + } + + return slotSize; +} + +ui64 ComputeNumberOfSlots(ui64 tuplesNum) { + return (3 * tuplesNum + 1) | 1; +} + +bool TTable::TryToPreallocateMemoryForJoin(TTable & t1, TTable & t2, EJoinKind joinKind, bool hasMoreLeftTuples, bool hasMoreRightTuples) { + // If the batch is final or the only one, then the buckets are processed sequentially, the memory for the hash tables is freed immediately after processing. + // So, no preallocation is required. + if (!hasMoreLeftTuples && !hasMoreRightTuples) return true; + + for (ui64 bucket = 0; bucket < GraceJoin::NumberOfBuckets; bucket++) { + ui64 tuplesNum1 = t1.TableBucketsStats[bucket].TuplesNum; + ui64 tuplesNum2 = t2.TableBucketsStats[bucket].TuplesNum; + + TTable& tableForPreallocation = IsTablesSwapRequired(tuplesNum1, tuplesNum2, hasMoreLeftTuples || LeftTableBatch_, hasMoreRightTuples || RightTableBatch_) ? t1 : t2; + if (!tableForPreallocation.TableBucketsStats[bucket].TuplesNum || tableForPreallocation.TableBuckets[bucket].NSlots) continue; + + TTableBucket& bucketForPreallocation = tableForPreallocation.TableBuckets[bucket]; + const TTableBucketStats& bucketForPreallocationStats = tableForPreallocation.TableBucketsStats[bucket]; + + const auto nSlots = ComputeJoinSlotsSizeForBucket(bucketForPreallocation, bucketForPreallocationStats, tableForPreallocation.HeaderSize, + tableForPreallocation.NumberOfKeyStringColumns != 0, tableForPreallocation.NumberOfKeyIColumns != 0); + const auto slotSize = ComputeNumberOfSlots(tableForPreallocation.TableBucketsStats[bucket].TuplesNum); + + try { + bucketForPreallocation.JoinSlots.reserve(nSlots*slotSize); + } catch (TMemoryLimitExceededException) { + for (ui64 i = 0; i < bucket; ++i) { + GraceJoin::TTableBucket * b1 = &JoinTable1->TableBuckets[i]; + b1->JoinSlots.resize(0); + b1->JoinSlots.shrink_to_fit(); + GraceJoin::TTableBucket * b2 = &JoinTable2->TableBuckets[i]; + b2->JoinSlots.resize(0); + b2->JoinSlots.shrink_to_fit(); + } + return false; + } + } + + return true; +} + // Joins two tables and returns join result in joined table. Tuples of joined table could be received by // joined table iterator @@ -368,7 +425,7 @@ void TTable::Join( TTable & t1, TTable & t2, EJoinKind joinKind, bool hasMoreLef bool table2HasKeyStringColumns = (JoinTable2->NumberOfKeyStringColumns != 0); bool table1HasKeyIColumns = (JoinTable1->NumberOfKeyIColumns != 0); bool table2HasKeyIColumns = (JoinTable2->NumberOfKeyIColumns != 0); - bool swapTables = tuplesNum2 > tuplesNum1 && !table1Batch || table2Batch; + bool swapTables = IsTablesSwapRequired(tuplesNum1, tuplesNum2, table1Batch, table2Batch); if (swapTables) { @@ -402,13 +459,7 @@ void TTable::Join( TTable & t1, TTable & t2, EJoinKind joinKind, bool hasMoreLef if (tuplesNum1 == 0 && (hasMoreRightTuples || hasMoreLeftTuples || !bucketStats2->HashtableMatches)) continue; - ui64 slotSize = headerSize2 + 1; // Header [Short Strings] SlotIdx - - ui64 avgStringsSize = ( 3 * (bucket2->KeyIntVals.size() - tuplesNum2 * headerSize2) ) / ( 2 * tuplesNum2 + 1) + 1; - - if (table2HasKeyStringColumns || table2HasKeyIColumns ) { - slotSize = slotSize + avgStringsSize; - } + ui64 slotSize = ComputeJoinSlotsSizeForBucket(*bucket2, *bucketStats2, headerSize2, table2HasKeyStringColumns, table2HasKeyIColumns); ui64 &nSlots = bucket2->NSlots; auto &joinSlots = bucket2->JoinSlots; @@ -417,7 +468,7 @@ void TTable::Join( TTable & t1, TTable & t2, EJoinKind joinKind, bool hasMoreLef Y_DEBUG_ABORT_UNLESS(bucketStats2->SlotSize == 0 || bucketStats2->SlotSize == slotSize); if (!nSlots) { - nSlots = (3 * tuplesNum2 + 1) | 1; + nSlots = ComputeNumberOfSlots(tuplesNum2); joinSlots.resize(nSlots*slotSize, 0); bloomFilter.Resize(tuplesNum2); initHashTable = true; diff --git a/yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.h b/yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.h index a4846926d1..d6b9a54aca 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.h +++ b/yql/essentials/minikql/comp_nodes/mkql_grace_join_imp.h @@ -346,6 +346,8 @@ public: // Returns value of next tuple. Returs true if there are more tuples bool NextTuple(TupleData& td); + bool TryToPreallocateMemoryForJoin(TTable & t1, TTable & t2, EJoinKind joinKind, bool hasMoreLeftTuples, bool hasMoreRightTuples); + // Joins two tables and stores join result in table data. Tuples of joined table could be received by // joined table iterator. Life time of t1, t2 should be greater than lifetime of joined table // hasMoreLeftTuples, hasMoreRightTuples is true if join is partial and more rows are coming. For final batch hasMoreLeftTuples = false, hasMoreRightTuples = false diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp b/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp index 0f758c7a8b..80caaad37e 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp +++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp @@ -54,7 +54,7 @@ public: ) : PartitionKey(std::move(partitionKey)) , Parameters(parameters) - , Nfa(nfaTransitions, parameters.MatchedVarsArg, parameters.Defines) + , Nfa(nfaTransitions, parameters.MatchedVarsArg, parameters.Defines, parameters.SkipTo) , Cache(cache) { } @@ -72,10 +72,10 @@ public: NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) { auto match = Nfa.GetMatched(); - if (!match.has_value()) { + if (!match) { return NUdf::TUnboxedValue{}; } - Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, match.value())); + Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, match->Vars)); Parameters.MeasureInputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TMeasureInputDataValue>( ctx.HolderFactory.Create<TListValue<TSparseList>>(Rows), Parameters.MeasureInputColumnOrder, @@ -95,9 +95,7 @@ public: break; } } - if (EAfterMatchSkipTo::PastLastRow == Parameters.SkipTo.To) { - Nfa.Clear(); - } + Nfa.AfterMatchSkip(*match); return result; } bool ProcessEndOfData(TComputationContext& ctx) { diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h index 944164e4bc..2b194212f4 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h +++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h @@ -350,19 +350,25 @@ class TNfa { using TRange = TSparseList::TRange; using TMatchedVars = TMatchedVars<TRange>; +public: + struct TMatch { + size_t BeginIndex; + size_t EndIndex; + TMatchedVars Vars; + }; + +private: struct TState { - size_t BeginMatchIndex; - size_t EndMatchIndex; size_t Index; - TMatchedVars Vars; + TMatch Match; std::deque<ui64, TMKQLAllocator<ui64>> Quantifiers; void Save(TMrOutputSerializer& serializer) const { - serializer.Write(BeginMatchIndex); - serializer.Write(EndMatchIndex); serializer.Write(Index); - serializer.Write(Vars.size()); - for (const auto& vector : Vars) { + serializer.Write(Match.BeginIndex); + serializer.Write(Match.EndIndex); + serializer.Write(Match.Vars.size()); + for (const auto& vector : Match.Vars) { serializer.Write(vector.size()); for (const auto& range : vector) { range.Save(serializer); @@ -375,13 +381,13 @@ class TNfa { } void Load(TMrInputSerializer& serializer) { - serializer.Read(BeginMatchIndex); - serializer.Read(EndMatchIndex); serializer.Read(Index); + serializer.Read(Match.BeginIndex); + serializer.Read(Match.EndIndex); auto varsSize = serializer.Read<TMatchedVars::size_type>(); - Vars.clear(); - Vars.resize(varsSize); - for (auto& subvec: Vars) { + Match.Vars.clear(); + Match.Vars.resize(varsSize); + for (auto& subvec: Match.Vars) { ui64 vectorSize = serializer.Read<ui64>(); subvec.resize(vectorSize); for (auto& item : subvec) { @@ -397,24 +403,29 @@ class TNfa { } friend inline bool operator<(const TState& lhs, const TState& rhs) { - auto lhsEndMatchIndex = -static_cast<i64>(lhs.EndMatchIndex); - auto rhsEndMatchIndex = -static_cast<i64>(rhs.EndMatchIndex); - return std::tie(lhs.BeginMatchIndex, lhsEndMatchIndex, lhs.Index, lhs.Quantifiers, lhs.Vars) < std::tie(rhs.BeginMatchIndex, rhsEndMatchIndex, rhs.Index, rhs.Quantifiers, rhs.Vars); + auto lhsMatchEndIndex = -static_cast<i64>(lhs.Match.EndIndex); + auto rhsMatchEndIndex = -static_cast<i64>(rhs.Match.EndIndex); + return std::tie(lhs.Match.BeginIndex, lhsMatchEndIndex, lhs.Index, lhs.Match.Vars, lhs.Quantifiers) < std::tie(rhs.Match.BeginIndex, rhsMatchEndIndex, rhs.Index, rhs.Match.Vars, rhs.Quantifiers); } friend inline bool operator==(const TState& lhs, const TState& rhs) { - return std::tie(lhs.BeginMatchIndex, lhs.EndMatchIndex, lhs.Index, lhs.Quantifiers, lhs.Vars) == std::tie(rhs.BeginMatchIndex, rhs.EndMatchIndex, rhs.Index, rhs.Quantifiers, rhs.Vars); + return std::tie(lhs.Match.BeginIndex, lhs.Match.EndIndex, lhs.Index, lhs.Match.Vars, lhs.Quantifiers) == std::tie(rhs.Match.BeginIndex, rhs.Match.EndIndex, rhs.Index, rhs.Match.Vars, rhs.Quantifiers); } }; public: - TNfa(TNfaTransitionGraph::TPtr transitionGraph, IComputationExternalNode* matchedRangesArg, const TComputationNodePtrVector& defines) - : TransitionGraph(transitionGraph) - , MatchedRangesArg(matchedRangesArg) - , Defines(defines) { - } + TNfa( + TNfaTransitionGraph::TPtr transitionGraph, + IComputationExternalNode* matchedRangesArg, + const TComputationNodePtrVector& defines, + TAfterMatchSkipTo skipTo) + : TransitionGraph(transitionGraph) + , MatchedRangesArg(matchedRangesArg) + , Defines(defines) + , SkipTo_(skipTo) + {} void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx) { - TState state(currentRowLock.From(), currentRowLock.To(), TransitionGraph->Input, TMatchedVars(Defines.size()), std::deque<ui64, TMKQLAllocator<ui64>>{}); + TState state(TransitionGraph->Input, TMatch{currentRowLock.From(), currentRowLock.To(), TMatchedVars(Defines.size())}, std::deque<ui64, TMKQLAllocator<ui64>>{}); Insert(std::move(state)); MakeEpsilonTransitions(); TStateSet newStates; @@ -423,17 +434,17 @@ public: //Here we handle only transitions of TMatchedVarTransition type, //all other transitions are handled in MakeEpsilonTransitions if (const auto* matchedVarTransition = std::get_if<TMatchedVarTransition>(&TransitionGraph->Transitions[state.Index])) { - MatchedRangesArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, state.Vars)); + MatchedRangesArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, state.Match.Vars)); const auto varIndex = matchedVarTransition->VarIndex; const auto& v = Defines[varIndex]->GetValue(ctx); if (v && v.Get<bool>()) { if (matchedVarTransition->SaveState) { - auto vars = state.Vars; //TODO get rid of this copy + auto vars = state.Match.Vars; //TODO get rid of this copy auto& matchedVar = vars[varIndex]; Extend(matchedVar, currentRowLock); - newStates.emplace(state.BeginMatchIndex, currentRowLock.To(), matchedVarTransition->To, std::move(vars), state.Quantifiers); + newStates.emplace(matchedVarTransition->To, TMatch{state.Match.BeginIndex, currentRowLock.To(), std::move(vars)}, state.Quantifiers); } else { - newStates.emplace(state.BeginMatchIndex, currentRowLock.To(), matchedVarTransition->To, state.Vars, state.Quantifiers); + newStates.emplace(matchedVarTransition->To, TMatch{state.Match.BeginIndex, currentRowLock.To(), state.Match.Vars}, state.Quantifiers); } } deletedStates.insert(state); @@ -450,8 +461,8 @@ public: bool HasMatched() const { for (auto& state: ActiveStates) { - if (auto activeStateIter = ActiveStateCounters.find(state.BeginMatchIndex), - finishedStateIter = FinishedStateCounters.find(state.BeginMatchIndex); + if (auto activeStateIter = ActiveStateCounters.find(state.Match.BeginIndex), + finishedStateIter = FinishedStateCounters.find(state.Match.BeginIndex); ((activeStateIter != ActiveStateCounters.end() && finishedStateIter != FinishedStateCounters.end() && activeStateIter->second == finishedStateIter->second) || @@ -463,16 +474,16 @@ public: return false; } - std::optional<TMatchedVars> GetMatched() { + std::optional<TMatch> GetMatched() { for (auto& state: ActiveStates) { - if (auto activeStateIter = ActiveStateCounters.find(state.BeginMatchIndex), - finishedStateIter = FinishedStateCounters.find(state.BeginMatchIndex); + if (auto activeStateIter = ActiveStateCounters.find(state.Match.BeginIndex), + finishedStateIter = FinishedStateCounters.find(state.Match.BeginIndex); ((activeStateIter != ActiveStateCounters.end() && finishedStateIter != FinishedStateCounters.end() && activeStateIter->second == finishedStateIter->second) || EndOfData) && state.Index == TransitionGraph->Output) { - auto result = state.Vars; + auto result = state.Match; Erase(std::move(state)); return result; } @@ -515,9 +526,9 @@ public: auto activeStateCountersSize = serializer.Read<ui64>(); for (size_t i = 0; i < activeStateCountersSize; ++i) { using map_type = decltype(ActiveStateCounters); - auto beginMatchIndex = serializer.Read<map_type::key_type>(); + auto matchBeginIndex = serializer.Read<map_type::key_type>(); auto counter = serializer.Read<map_type::mapped_type>(); - ActiveStateCounters.emplace(beginMatchIndex, counter); + ActiveStateCounters.emplace(matchBeginIndex, counter); } } { @@ -525,9 +536,9 @@ public: auto finishedStateCountersSize = serializer.Read<ui64>(); for (size_t i = 0; i < finishedStateCountersSize; ++i) { using map_type = decltype(FinishedStateCounters); - auto beginMatchIndex = serializer.Read<map_type::key_type>(); + auto matchBeginIndex = serializer.Read<map_type::key_type>(); auto counter = serializer.Read<map_type::mapped_type>(); - FinishedStateCounters.emplace(beginMatchIndex, counter); + FinishedStateCounters.emplace(matchBeginIndex, counter); } } } @@ -537,10 +548,31 @@ public: return HasMatched(); } - void Clear() { - ActiveStates.clear(); - ActiveStateCounters.clear(); - FinishedStateCounters.clear(); + void AfterMatchSkip(const TMatch& match) { + const auto skipToRowIndex = [&]() { + switch (SkipTo_.To) { + case EAfterMatchSkipTo::NextRow: + return match.BeginIndex + 1; + case EAfterMatchSkipTo::PastLastRow: + return match.EndIndex + 1; + case EAfterMatchSkipTo::ToFirst: + MKQL_ENSURE(false, "AFTER MATCH SKIP TO FIRST is not implemented yet"); + case EAfterMatchSkipTo::ToLast: + [[fallthrough]]; + case EAfterMatchSkipTo::To: + MKQL_ENSURE(false, "AFTER MATCH SKIP TO LAST is not implemented yet"); + } + }(); + + TStateSet deletedStates; + for (const auto& state : ActiveStates) { + if (state.Match.BeginIndex < skipToRowIndex) { + deletedStates.insert(state); + } + } + for (auto& state : deletedStates) { + Erase(std::move(state)); + } } private: @@ -561,14 +593,14 @@ private: [&](const TEpsilonTransitions& epsilonTransitions) { deletedStates.insert(state); for (const auto& i : epsilonTransitions.To) { - newStates.emplace(state.BeginMatchIndex, state.EndMatchIndex, i, state.Vars, state.Quantifiers); + newStates.emplace(i, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, state.Quantifiers); } }, [&](const TQuantityEnterTransition& quantityEnterTransition) { deletedStates.insert(state); auto quantifiers = state.Quantifiers; //TODO get rid of this copy quantifiers.push_back(0); - newStates.emplace(state.BeginMatchIndex, state.EndMatchIndex, quantityEnterTransition.To, state.Vars, std::move(quantifiers)); + newStates.emplace(quantityEnterTransition.To, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(quantifiers)); }, [&](const TQuantityExitTransition& quantityExitTransition) { deletedStates.insert(state); @@ -576,12 +608,12 @@ private: if (state.Quantifiers.back() + 1 < quantityMax) { auto q = state.Quantifiers; q.back()++; - newStates.emplace(state.BeginMatchIndex, state.EndMatchIndex, toFindMore, state.Vars, std::move(q)); + newStates.emplace(toFindMore, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(q)); } if (quantityMin <= state.Quantifiers.back() + 1 && state.Quantifiers.back() + 1 <= quantityMax) { auto q = state.Quantifiers; q.pop_back(); - newStates.emplace(state.BeginMatchIndex, state.EndMatchIndex, toMatched, state.Vars, std::move(q)); + newStates.emplace(toMatched, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(q)); } }, }, TransitionGraph->Transitions[state.Index]); @@ -610,22 +642,22 @@ private: } void Insert(TState state) { - auto beginMatchIndex = state.BeginMatchIndex; + auto matchBeginIndex = state.Match.BeginIndex; const auto& transition = TransitionGraph->Transitions[state.Index]; auto diff = static_cast<i64>(ActiveStates.insert(std::move(state)).second); - Add(ActiveStateCounters, beginMatchIndex, diff); + Add(ActiveStateCounters, matchBeginIndex, diff); if (std::holds_alternative<TVoidTransition>(transition)) { - Add(FinishedStateCounters, beginMatchIndex, diff); + Add(FinishedStateCounters, matchBeginIndex, diff); } } void Erase(TState state) { - auto beginMatchIndex = state.BeginMatchIndex; + auto matchBeginIndex = state.Match.BeginIndex; const auto& transition = TransitionGraph->Transitions[state.Index]; auto diff = -static_cast<i64>(ActiveStates.erase(std::move(state))); - Add(ActiveStateCounters, beginMatchIndex, diff); + Add(ActiveStateCounters, matchBeginIndex, diff); if (std::holds_alternative<TVoidTransition>(transition)) { - Add(FinishedStateCounters, beginMatchIndex, diff); + Add(FinishedStateCounters, matchBeginIndex, diff); } } @@ -636,6 +668,7 @@ private: THashMap<size_t, i64> ActiveStateCounters; THashMap<size_t, i64> FinishedStateCounters; bool EndOfData = false; + TAfterMatchSkipTo SkipTo_; }; }//namespace NKikimr::NMiniKQL::NMatchRecognize diff --git a/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp b/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp index 745c88e084..afdbd6b8e8 100644 --- a/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp +++ b/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_nfa_ut.cpp @@ -59,7 +59,7 @@ struct TNfaSetup { for (auto& d: Defines) { defines.push_back(d); } - return TNfa(transitionGraph, MatchedVars, defines); + return TNfa(transitionGraph, MatchedVars, defines, TAfterMatchSkipTo{EAfterMatchSkipTo::PastLastRow, ""}); } TComputationNodeFactory GetAuxCallableFactory() { @@ -261,8 +261,8 @@ Y_UNIT_TEST_SUITE(MatchRecognizeNfa) { Iota(expectedTo.begin(), expectedTo.end(), i - seriesLength + 1); for (size_t matchCount = 0; matchCount < seriesLength; ++matchCount) { auto match = setup.Nfa.GetMatched(); - UNIT_ASSERT_C(match.has_value(), i); - auto vars = match.value(); + UNIT_ASSERT_C(match, i); + auto vars = match->Vars; UNIT_ASSERT_VALUES_EQUAL_C(1, vars.size(), i); auto var = vars[0]; UNIT_ASSERT_VALUES_EQUAL_C(1, var.size(), i); |