diff options
author | vokayndzop <vokayndzop@yandex-team.com> | 2024-12-16 15:55:05 +0300 |
---|---|---|
committer | vokayndzop <vokayndzop@yandex-team.com> | 2024-12-16 16:34:36 +0300 |
commit | b1cde7dcb055fb6f3367e81fd0f57bd55b8bb93c (patch) | |
tree | 230bddb8bb4ce7d8290a16a4465ec98dbf513a5a /yql/essentials/minikql | |
parent | 88e0ad5922cea1349ec1f8cbf133524cf865d696 (diff) | |
download | ydb-b1cde7dcb055fb6f3367e81fd0f57bd55b8bb93c.tar.gz |
MR: support ALL ROWS PER MATCH
commit_hash:9e2ba38d0d523bb870f6dc76717a3bec5d8ffadc
Diffstat (limited to 'yql/essentials/minikql')
10 files changed, 531 insertions, 265 deletions
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp b/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp index 80caaad37e..e1fc529fda 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp +++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp @@ -1,32 +1,28 @@ #include "mkql_match_recognize_list.h" -#include "mkql_match_recognize_matched_vars.h" #include "mkql_match_recognize_measure_arg.h" +#include "mkql_match_recognize_matched_vars.h" #include "mkql_match_recognize_nfa.h" +#include "mkql_match_recognize_rows_formatter.h" #include "mkql_match_recognize_save_load.h" #include "mkql_saveload.h" #include <yql/essentials/core/sql_types/match_recognize.h> -#include <yql/essentials/minikql/computation/mkql_computation_node_impl.h> #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h> #include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h> +#include <yql/essentials/minikql/computation/mkql_computation_node_impl.h> #include <yql/essentials/minikql/computation/mkql_computation_node_pack.h> +#include <yql/essentials/minikql/mkql_node.h> #include <yql/essentials/minikql/mkql_node_cast.h> -#include <yql/essentials/minikql/mkql_runtime_version.h> #include <yql/essentials/minikql/mkql_string_util.h> -#include <yql/essentials/core/sql_types/match_recognize.h> + #include <deque> namespace NKikimr::NMiniKQL { namespace NMatchRecognize { -enum class EOutputColumnSource {PartitionKey, Measure}; -using TOutputColumnOrder = std::vector<std::pair<EOutputColumnSource, size_t>, TMKQLAllocator<std::pair<EOutputColumnSource, size_t>>>; - constexpr ui32 StateVersion = 1; -using namespace NYql::NMatchRecognize; - struct TMatchRecognizeProcessorParameters { IComputationExternalNode* InputDataArg; TRowPattern Pattern; @@ -37,27 +33,21 @@ struct TMatchRecognizeProcessorParameters { TComputationNodePtrVector Defines; IComputationExternalNode* MeasureInputDataArg; TMeasureInputColumnOrder MeasureInputColumnOrder; - TComputationNodePtrVector Measures; - TOutputColumnOrder OutputColumnOrder; - TAfterMatchSkipTo SkipTo; + TAfterMatchSkipTo SkipTo; }; class TStreamingMatchRecognize { - using TPartitionList = TSparseList; - using TRange = TPartitionList::TRange; public: TStreamingMatchRecognize( NUdf::TUnboxedValue&& partitionKey, const TMatchRecognizeProcessorParameters& parameters, - TNfaTransitionGraph::TPtr nfaTransitions, - const TContainerCacheOnContext& cache - ) + const IRowsFormatter::TState& rowsFormatterState, + TNfaTransitionGraph::TPtr nfaTransitions) : PartitionKey(std::move(partitionKey)) , Parameters(parameters) + , RowsFormatter_(IRowsFormatter::Create(rowsFormatterState)) , Nfa(nfaTransitions, parameters.MatchedVarsArg, parameters.Defines, parameters.SkipTo) - , Cache(cache) - { - } + {} bool ProcessInputRow(NUdf::TUnboxedValue&& row, TComputationContext& ctx) { Parameters.InputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TListValue<TSparseList>>(Rows)); @@ -71,33 +61,26 @@ public: } NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) { + if (auto result = RowsFormatter_->GetOtherMatchRow(ctx, Rows, PartitionKey, Nfa.GetTransitionGraph())) { + return result; + } auto match = Nfa.GetMatched(); if (!match) { return NUdf::TUnboxedValue{}; } - Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, match->Vars)); + Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TSparseList::TRange>>(ctx.HolderFactory, match->Vars)); Parameters.MeasureInputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TMeasureInputDataValue>( ctx.HolderFactory.Create<TListValue<TSparseList>>(Rows), Parameters.MeasureInputColumnOrder, Parameters.MatchedVarsArg->GetValue(ctx), Parameters.VarNames, MatchNumber - )); - NUdf::TUnboxedValue *itemsPtr = nullptr; - const auto result = Cache.NewArray(ctx, Parameters.OutputColumnOrder.size(), itemsPtr); - for (auto const& c: Parameters.OutputColumnOrder) { - switch(c.first) { - case EOutputColumnSource::Measure: - *itemsPtr++ = Parameters.Measures[c.second]->GetValue(ctx); - break; - case EOutputColumnSource::PartitionKey: - *itemsPtr++ = PartitionKey.GetElement(c.second); - break; - } - } + )); + auto result = RowsFormatter_->GetFirstMatchRow(ctx, Rows, PartitionKey, Nfa.GetTransitionGraph(), *match); Nfa.AfterMatchSkip(*match); return result; } + bool ProcessEndOfData(TComputationContext& ctx) { return Nfa.ProcessEndOfData(ctx); } @@ -117,11 +100,11 @@ public: } private: - const NUdf::TUnboxedValue PartitionKey; + NUdf::TUnboxedValue PartitionKey; const TMatchRecognizeProcessorParameters& Parameters; + std::unique_ptr<IRowsFormatter> RowsFormatter_; TSparseList Rows; TNfa Nfa; - const TContainerCacheOnContext& Cache; ui64 MatchNumber = 0; }; @@ -135,7 +118,7 @@ public: IComputationNode* partitionKey, TType* partitionKeyType, const TMatchRecognizeProcessorParameters& parameters, - const TContainerCacheOnContext& cache, + const IRowsFormatter::TState& rowsFormatterState, TComputationContext &ctx, TType* rowType, const TMutableObjectOverBoxedValue<TValuePackerBoxed>& rowPacker @@ -145,8 +128,8 @@ public: , PartitionKey(partitionKey) , PartitionKeyPacker(true, partitionKeyType) , Parameters(parameters) + , RowsFormatterState(rowsFormatterState) , RowPatternConfiguration(TNfaTransitionGraphBuilder::Create(parameters.Pattern, parameters.VarNamesLookup)) - , Cache(cache) , Terminating(false) , SerializerContext(ctx, rowType, rowPacker) , Ctx(ctx) @@ -184,8 +167,8 @@ public: PartitionHandler.reset(new TStreamingMatchRecognize( std::move(key), Parameters, - RowPatternConfiguration, - Cache + RowsFormatterState, + RowPatternConfiguration )); PartitionHandler->Load(in); } @@ -250,8 +233,9 @@ public: PartitionHandler.reset(new TStreamingMatchRecognize( std::move(partitionKey), Parameters, - RowPatternConfiguration, - Cache)); + RowsFormatterState, + RowPatternConfiguration + )); PartitionHandler->ProcessInputRow(std::move(temp), ctx); } if (Terminating) { @@ -266,8 +250,8 @@ private: IComputationNode* PartitionKey; TValuePackerGeneric<false> PartitionKeyPacker; const TMatchRecognizeProcessorParameters& Parameters; + const IRowsFormatter::TState& RowsFormatterState; const TNfaTransitionGraph::TPtr RowPatternConfiguration; - const TContainerCacheOnContext& Cache; NUdf::TUnboxedValue DelayedRow; bool Terminating; TSerializerContext SerializerContext; @@ -286,7 +270,7 @@ public: IComputationNode* partitionKey, TType* partitionKeyType, const TMatchRecognizeProcessorParameters& parameters, - const TContainerCacheOnContext& cache, + const IRowsFormatter::TState& rowsFormatterState, TComputationContext &ctx, TType* rowType, const TMutableObjectOverBoxedValue<TValuePackerBoxed>& rowPacker @@ -296,8 +280,8 @@ public: , PartitionKey(partitionKey) , PartitionKeyPacker(true, partitionKeyType) , Parameters(parameters) + , RowsFormatterState(rowsFormatterState) , NfaTransitionGraph(TNfaTransitionGraphBuilder::Create(parameters.Pattern, parameters.VarNamesLookup)) - , Cache(cache) , SerializerContext(ctx, rowType, rowPacker) , Ctx(ctx) {} @@ -335,8 +319,10 @@ public: std::make_unique<TStreamingMatchRecognize>( std::move(key), Parameters, - NfaTransitionGraph, - Cache)); + RowsFormatterState, + NfaTransitionGraph + ) + ); pair.first->second->Load(in); } @@ -402,8 +388,8 @@ private: return Partitions.emplace_hint(it, TString(packedKey), std::make_unique<TStreamingMatchRecognize>( std::move(partitionKey), Parameters, - NfaTransitionGraph, - Cache + RowsFormatterState, + NfaTransitionGraph )); } } @@ -418,8 +404,8 @@ private: //TODO switch to tuple compare TValuePackerGeneric<false> PartitionKeyPacker; const TMatchRecognizeProcessorParameters& Parameters; + const IRowsFormatter::TState& RowsFormatterState; const TNfaTransitionGraph::TPtr NfaTransitionGraph; - const TContainerCacheOnContext& Cache; TSerializerContext SerializerContext; TComputationContext& Ctx; }; @@ -428,20 +414,23 @@ template<class State> class TMatchRecognizeWrapper : public TStatefulFlowComputationNode<TMatchRecognizeWrapper<State>, true> { using TBaseComputation = TStatefulFlowComputationNode<TMatchRecognizeWrapper<State>, true>; public: - TMatchRecognizeWrapper(TComputationMutables &mutables, EValueRepresentation kind, IComputationNode *inputFlow, - IComputationExternalNode *inputRowArg, - IComputationNode *partitionKey, - TType* partitionKeyType, - const TMatchRecognizeProcessorParameters& parameters, - TType* rowType - ) - :TBaseComputation(mutables, inputFlow, kind, EValueRepresentation::Embedded) + TMatchRecognizeWrapper( + TComputationMutables& mutables, + EValueRepresentation kind, + IComputationNode *inputFlow, + IComputationExternalNode *inputRowArg, + IComputationNode *partitionKey, + TType* partitionKeyType, + TMatchRecognizeProcessorParameters&& parameters, + IRowsFormatter::TState&& rowsFormatterState, + TType* rowType) + : TBaseComputation(mutables, inputFlow, kind, EValueRepresentation::Embedded) , InputFlow(inputFlow) , InputRowArg(inputRowArg) , PartitionKey(partitionKey) , PartitionKeyType(partitionKeyType) - , Parameters(parameters) - , Cache(mutables) + , Parameters(std::move(parameters)) + , RowsFormatterState(std::move(rowsFormatterState)) , RowType(rowType) , RowPacker(mutables) {} @@ -453,7 +442,7 @@ public: PartitionKey, PartitionKeyType, Parameters, - Cache, + RowsFormatterState, ctx, RowType, RowPacker @@ -468,7 +457,7 @@ public: PartitionKey, PartitionKeyType, Parameters, - Cache, + RowsFormatterState, ctx, RowType, RowPacker @@ -503,7 +492,7 @@ private: Own(flow, Parameters.CurrentRowIndexArg); Own(flow, Parameters.MeasureInputDataArg); DependsOn(flow, PartitionKey); - for (auto& m: Parameters.Measures) { + for (auto& m: RowsFormatterState.Measures) { DependsOn(flow, m); } for (auto& d: Parameters.Defines) { @@ -516,32 +505,31 @@ private: IComputationExternalNode* const InputRowArg; IComputationNode* const PartitionKey; TType* const PartitionKeyType; - const TMatchRecognizeProcessorParameters Parameters; - const TContainerCacheOnContext Cache; + TMatchRecognizeProcessorParameters Parameters; + IRowsFormatter::TState RowsFormatterState; TType* const RowType; TMutableObjectOverBoxedValue<TValuePackerBoxed> RowPacker; }; TOutputColumnOrder GetOutputColumnOrder(TRuntimeNode partitionKyeColumnsIndexes, TRuntimeNode measureColumnsIndexes) { - using tempMapValue = std::pair<EOutputColumnSource, size_t>; - std::unordered_map<size_t, tempMapValue, std::hash<size_t>, std::equal_to<size_t>, TMKQLAllocator<std::pair<const size_t, tempMapValue>, EMemorySubPool::Temporary>> temp; + std::unordered_map<size_t, TOutputColumnEntry, std::hash<size_t>, std::equal_to<size_t>, TMKQLAllocator<std::pair<const size_t, TOutputColumnEntry>, EMemorySubPool::Temporary>> temp; { auto list = AS_VALUE(TListLiteral, partitionKyeColumnsIndexes); for (ui32 i = 0; i != list->GetItemsCount(); ++i) { auto index = AS_VALUE(TDataLiteral, list->GetItems()[i])->AsValue().Get<ui32>(); - temp[index] = std::make_pair(EOutputColumnSource::PartitionKey, i); + temp[index] = {i, EOutputColumnSource::PartitionKey}; } } { auto list = AS_VALUE(TListLiteral, measureColumnsIndexes); for (ui32 i = 0; i != list->GetItemsCount(); ++i) { auto index = AS_VALUE(TDataLiteral, list->GetItems()[i])->AsValue().Get<ui32>(); - temp[index] = std::make_pair(EOutputColumnSource::Measure, i); + temp[index] = {i, EOutputColumnSource::Measure}; } } if (temp.empty()) return {}; - auto outputSize = max_element(temp.cbegin(), temp.cend())->first + 1; + auto outputSize = std::ranges::max_element(temp, {}, &std::pair<const size_t, TOutputColumnEntry>::first)->first + 1; TOutputColumnOrder result(outputSize); for (const auto& [i, v]: temp) { result[i] = v; @@ -576,7 +564,6 @@ TRowPattern ConvertPattern(const TRuntimeNode& pattern) { } TMeasureInputColumnOrder GetMeasureColumnOrder(const TListLiteral& specialColumnIndexes, ui32 inputRowColumnCount) { - using NYql::NMatchRecognize::EMeasureInputDataSpecialColumns; //Use Last enum value to denote that c colum comes from the input table TMeasureInputColumnOrder result(inputRowColumnCount + specialColumnIndexes.GetItemsCount(), std::make_pair(EMeasureInputDataSpecialColumns::Last, 0)); if (specialColumnIndexes.GetItemsCount() != 0) { @@ -621,7 +608,6 @@ std::pair<TUnboxedValueVector, THashMap<TString, size_t>> ConvertListOfStrings(c } //namespace NMatchRecognize - IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputationNodeFactoryContext& ctx) { using namespace NMatchRecognize; size_t inputIndex = 0; @@ -641,9 +627,9 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation const auto& pattern = callable.GetInput(inputIndex++); const auto& currentRowIndexArg = callable.GetInput(inputIndex++); const auto& inputDataArg = callable.GetInput(inputIndex++); - const auto& varNames = callable.GetInput(inputIndex++); + const auto& defineNames = callable.GetInput(inputIndex++); TRuntimeNode::TList defines; - for (size_t i = 0; i != AS_VALUE(TListLiteral, varNames)->GetItemsCount(); ++i) { + for (size_t i = 0; i != AS_VALUE(TListLiteral, defineNames)->GetItemsCount(); ++i) { defines.push_back(callable.GetInput(inputIndex++)); } const auto& streamingMode = callable.GetInput(inputIndex++); @@ -652,47 +638,58 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation skipTo.To = static_cast<EAfterMatchSkipTo>(AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().Get<i32>()); skipTo.Var = AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().AsStringRef(); } + NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch = NYql::NMatchRecognize::ERowsPerMatch::OneRow; + TOutputColumnOrder outputColumnOrder; + if (inputIndex + 2 <= callable.GetInputsCount()) { + rowsPerMatch = static_cast<ERowsPerMatch>(AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().Get<i32>()); + outputColumnOrder = IRowsFormatter::GetOutputColumnOrder(callable.GetInput(inputIndex++)); + } else { + outputColumnOrder = GetOutputColumnOrder(partitionColumnIndexes, measureColumnIndexes); + } MKQL_ENSURE(callable.GetInputsCount() == inputIndex, "Wrong input count"); - const auto& [vars, varsLookup] = ConvertListOfStrings(varNames); + const auto& [varNames, varNamesLookup] = ConvertListOfStrings(defineNames); auto* rowType = AS_TYPE(TStructType, AS_TYPE(TFlowType, inputFlow.GetStaticType())->GetItemType()); - const auto parameters = TMatchRecognizeProcessorParameters { - static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputDataArg.GetNode())) - , ConvertPattern(pattern) - , vars - , varsLookup - , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *matchedVarsArg.GetNode())) - , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *currentRowIndexArg.GetNode())) - , ConvertVectorOfCallables(defines, ctx) - , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *measureInputDataArg.GetNode())) - , GetMeasureColumnOrder( + auto parameters = TMatchRecognizeProcessorParameters { + static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputDataArg.GetNode())), + ConvertPattern(pattern), + varNames, + varNamesLookup, + static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *matchedVarsArg.GetNode())), + static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *currentRowIndexArg.GetNode())), + ConvertVectorOfCallables(defines, ctx), + static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *measureInputDataArg.GetNode())), + GetMeasureColumnOrder( *AS_VALUE(TListLiteral, measureSpecialColumnIndexes), AS_VALUE(TDataLiteral, inputRowColumnCount)->AsValue().Get<ui32>() - ) - , ConvertVectorOfCallables(measures, ctx) - , GetOutputColumnOrder(partitionColumnIndexes, measureColumnIndexes) - , skipTo + ), + skipTo }; + IRowsFormatter::TState rowsFormatterState(ctx, outputColumnOrder, ConvertVectorOfCallables(measures, ctx), rowsPerMatch); if (AS_VALUE(TDataLiteral, streamingMode)->AsValue().Get<bool>()) { - return new TMatchRecognizeWrapper<TStateForInterleavedPartitions>(ctx.Mutables - , GetValueRepresentation(inputFlow.GetStaticType()) - , LocateNode(ctx.NodeLocator, *inputFlow.GetNode()) - , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode())) - , LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode()) - , partitionKeySelector.GetStaticType() - , std::move(parameters) - , rowType + return new TMatchRecognizeWrapper<TStateForInterleavedPartitions>( + ctx.Mutables, + GetValueRepresentation(inputFlow.GetStaticType()), + LocateNode(ctx.NodeLocator, *inputFlow.GetNode()), + static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode())), + LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode()), + partitionKeySelector.GetStaticType(), + std::move(parameters), + std::move(rowsFormatterState), + rowType ); } else { - return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions>(ctx.Mutables - , GetValueRepresentation(inputFlow.GetStaticType()) - , LocateNode(ctx.NodeLocator, *inputFlow.GetNode()) - , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode())) - , LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode()) - , partitionKeySelector.GetStaticType() - , std::move(parameters) - , rowType + return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions>( + ctx.Mutables, + GetValueRepresentation(inputFlow.GetStaticType()), + LocateNode(ctx.NodeLocator, *inputFlow.GetNode()), + static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode())), + LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode()), + partitionKeySelector.GetStaticType(), + std::move(parameters), + std::move(rowsFormatterState), + rowType ); } } diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h index 3c7771a41e..cf9bfb46a9 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h +++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h @@ -18,26 +18,29 @@ public: class TRange { public: TRange() - : FromIndex(-1) - , ToIndex(-1) + : FromIndex(Max()) + , ToIndex(Max()) + , NfaIndex_(Max()) { } explicit TRange(ui64 index) - : FromIndex(index) - , ToIndex(index) + : FromIndex(index) + , ToIndex(index) + , NfaIndex_(Max()) { } TRange(ui64 from, ui64 to) - : FromIndex(from) - , ToIndex(to) + : FromIndex(from) + , ToIndex(to) + , NfaIndex_(Max()) { MKQL_ENSURE(FromIndex <= ToIndex, "Internal logic error"); } bool IsValid() const { - return true; + return FromIndex != Max<size_t>() && ToIndex != Max<size_t>(); } size_t From() const { @@ -50,6 +53,11 @@ public: return ToIndex; } + [[nodiscard]] size_t NfaIndex() const { + MKQL_ENSURE(IsValid(), "Internal logic error"); + return NfaIndex_; + } + size_t Size() const { MKQL_ENSURE(IsValid(), "Internal logic error"); return ToIndex - FromIndex + 1; @@ -61,8 +69,9 @@ public: } private: - ui64 FromIndex; - ui64 ToIndex; + size_t FromIndex; + size_t ToIndex; + size_t NfaIndex_; }; TRange Append(NUdf::TUnboxedValue&& value) { @@ -156,14 +165,12 @@ class TSparseList { private: //TODO consider to replace hash table with contiguous chunks - using TAllocator = TMKQLAllocator<std::pair<const size_t, TItem>, EMemorySubPool::Temporary>; - using TStorage = std::unordered_map< size_t, TItem, std::hash<size_t>, std::equal_to<size_t>, - TAllocator>; + TMKQLAllocator<std::pair<const size_t, TItem>, EMemorySubPool::Temporary>>; TStorage Storage; }; @@ -178,8 +185,9 @@ public: public: TRange() : Container() - , FromIndex(-1) - , ToIndex(-1) + , FromIndex(Max()) + , ToIndex(Max()) + , NfaIndex_(Max()) { } @@ -187,6 +195,7 @@ public: : Container(other.Container) , FromIndex(other.FromIndex) , ToIndex(other.ToIndex) + , NfaIndex_(other.NfaIndex_) { LockRange(FromIndex, ToIndex); } @@ -195,6 +204,7 @@ public: : Container(other.Container) , FromIndex(other.FromIndex) , ToIndex(other.ToIndex) + , NfaIndex_(other.NfaIndex_) { other.Reset(); } @@ -212,6 +222,7 @@ public: Container = other.Container; FromIndex = other.FromIndex; ToIndex = other.ToIndex; + NfaIndex_ = other.NfaIndex_; LockRange(FromIndex, ToIndex); return *this; } @@ -224,20 +235,21 @@ public: Container = other.Container; FromIndex = other.FromIndex; ToIndex = other.ToIndex; + NfaIndex_ = other.NfaIndex_; other.Reset(); return *this; } friend inline bool operator==(const TRange& lhs, const TRange& rhs) { - return std::tie(lhs.FromIndex, lhs.ToIndex) == std::tie(rhs.FromIndex, rhs.ToIndex); + return std::tie(lhs.FromIndex, lhs.ToIndex, lhs.NfaIndex_) == std::tie(rhs.FromIndex, rhs.ToIndex, rhs.NfaIndex_); } friend inline bool operator<(const TRange& lhs, const TRange& rhs) { - return std::tie(lhs.FromIndex, lhs.ToIndex) < std::tie(rhs.FromIndex, rhs.ToIndex); + return std::tie(lhs.FromIndex, lhs.ToIndex, lhs.NfaIndex_) < std::tie(rhs.FromIndex, rhs.ToIndex, rhs.NfaIndex_); } bool IsValid() const { - return static_cast<bool>(Container); + return static_cast<bool>(Container) && FromIndex != Max<size_t>() && ToIndex != Max<size_t>(); } size_t From() const { @@ -250,6 +262,15 @@ public: return ToIndex; } + [[nodiscard]] size_t NfaIndex() const { + MKQL_ENSURE(IsValid(), "Internal logic error"); + return NfaIndex_; + } + + void NfaIndex(size_t index) { + NfaIndex_ = index; + } + size_t Size() const { MKQL_ENSURE(IsValid(), "Internal logic error"); return ToIndex - FromIndex + 1; @@ -264,16 +285,17 @@ public: void Release() { UnlockRange(FromIndex, ToIndex); Container.Reset(); - FromIndex = -1; - ToIndex = -1; + FromIndex = Max(); + ToIndex = Max(); + NfaIndex_ = Max(); } void Save(TMrOutputSerializer& serializer) const { - serializer(Container, FromIndex, ToIndex); + serializer(Container, FromIndex, ToIndex, NfaIndex_); } void Load(TMrInputSerializer& serializer) { - serializer(Container, FromIndex, ToIndex); + serializer(Container, FromIndex, ToIndex, NfaIndex_); } private: @@ -281,6 +303,7 @@ public: : Container(container) , FromIndex(index) , ToIndex(index) + , NfaIndex_(Max()) {} void LockRange(size_t from, size_t to) { @@ -297,13 +320,15 @@ public: void Reset() { Container.Reset(); - FromIndex = -1; - ToIndex = -1; + FromIndex = Max(); + ToIndex = Max(); + NfaIndex_ = Max(); } TContainerPtr Container; size_t FromIndex; size_t ToIndex; + size_t NfaIndex_; }; public: diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h index 8c2576798a..de0ccc352f 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h +++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h @@ -15,7 +15,7 @@ void Extend(TMatchedVar<R>& var, const R& r) { var.emplace_back(r); } else { MKQL_ENSURE(r.From() > var.back().To(), "Internal logic error"); - if (var.back().To() + 1 == r.From()) { + if (var.back().To() + 1 == r.From() && var.back().NfaIndex() == r.NfaIndex()) { var.back().Extend(); } else { var.emplace_back(r); diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h index 2b194212f4..c868379d40 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h +++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h @@ -20,9 +20,10 @@ struct TEpsilonTransitions { friend constexpr bool operator==(const TEpsilonTransitions&, const TEpsilonTransitions&) = default; }; struct TMatchedVarTransition { + size_t To; ui32 VarIndex; bool SaveState; - size_t To; + bool ExcludeFromOutput; friend constexpr bool operator==(const TMatchedVarTransition&, const TMatchedVarTransition&) = default; }; struct TQuantityEnterTransition { @@ -116,7 +117,7 @@ struct TNfaTransitionGraph { serializer(tr.To); }, [&](const TMatchedVarTransition& tr) { - serializer(tr.VarIndex, tr.SaveState, tr.To); + serializer(tr.To, tr.VarIndex, tr.SaveState, tr.ExcludeFromOutput); }, [&](const TQuantityEnterTransition& tr) { serializer(tr.To); @@ -141,7 +142,7 @@ struct TNfaTransitionGraph { serializer(tr.To); }, [&](TMatchedVarTransition& tr) { - serializer(tr.VarIndex, tr.SaveState, tr.To); + serializer(tr.To, tr.VarIndex, tr.SaveState, tr.ExcludeFromOutput); }, [&](TQuantityEnterTransition& tr) { serializer(tr.To); @@ -297,7 +298,7 @@ private: auto input = AddNode(); auto output = AddNode(); auto item = factor.Primary.index() == 0 ? - BuildVar(varNameToIndex.at(std::get<0>(factor.Primary)), !factor.Unused) : + BuildVar(varNameToIndex.at(std::get<0>(factor.Primary)), !factor.Unused, !factor.Output) : BuildTerms(std::get<1>(factor.Primary), varNameToIndex); if (1 == factor.QuantityMin && 1 == factor.QuantityMax) { //simple linear case Graph->Transitions[input] = TEpsilonTransitions{{item.Input}}; @@ -319,15 +320,16 @@ private: } return {input, output}; } - TNfaItem BuildVar(ui32 varIndex, bool isUsed) { + TNfaItem BuildVar(ui32 varIndex, bool isUsed, bool excludeFromOutput) { auto input = AddNode(); auto matchVar = AddNode(); auto output = AddNode(); Graph->Transitions[input] = TEpsilonTransitions({matchVar}); Graph->Transitions[matchVar] = TMatchedVarTransition{ + output, varIndex, isUsed, - output, + excludeFromOutput, }; return {input, output}; } @@ -441,6 +443,7 @@ public: if (matchedVarTransition->SaveState) { auto vars = state.Match.Vars; //TODO get rid of this copy auto& matchedVar = vars[varIndex]; + currentRowLock.NfaIndex(state.Index); Extend(matchedVar, currentRowLock); newStates.emplace(matchedVarTransition->To, TMatch{state.Match.BeginIndex, currentRowLock.To(), std::move(vars)}, state.Quantifiers); } else { @@ -575,6 +578,10 @@ public: } } + const TNfaTransitionGraph& GetTransitionGraph() const { + return *TransitionGraph; + } + private: //TODO (zverevgeny): Consider to change to std::vector for the sake of perf using TStateSet = std::set<TState, std::less<TState>, TMKQLAllocator<TState>>; @@ -593,14 +600,14 @@ private: [&](const TEpsilonTransitions& epsilonTransitions) { deletedStates.insert(state); for (const auto& i : epsilonTransitions.To) { - newStates.emplace(i, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, state.Quantifiers); + newStates.emplace(i, state.Match, state.Quantifiers); } }, [&](const TQuantityEnterTransition& quantityEnterTransition) { deletedStates.insert(state); auto quantifiers = state.Quantifiers; //TODO get rid of this copy quantifiers.push_back(0); - newStates.emplace(quantityEnterTransition.To, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(quantifiers)); + newStates.emplace(quantityEnterTransition.To, state.Match, std::move(quantifiers)); }, [&](const TQuantityExitTransition& quantityExitTransition) { deletedStates.insert(state); @@ -608,12 +615,12 @@ private: if (state.Quantifiers.back() + 1 < quantityMax) { auto q = state.Quantifiers; q.back()++; - newStates.emplace(toFindMore, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(q)); + newStates.emplace(toFindMore, state.Match, std::move(q)); } if (quantityMin <= state.Quantifiers.back() + 1 && state.Quantifiers.back() + 1 <= quantityMax) { auto q = state.Quantifiers; q.pop_back(); - newStates.emplace(toMatched, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(q)); + newStates.emplace(toMatched, state.Match, std::move(q)); } }, }, TransitionGraph->Transitions[state.Index]); diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.cpp b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.cpp new file mode 100644 index 0000000000..6d70458b3b --- /dev/null +++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.cpp @@ -0,0 +1,144 @@ +#include "mkql_match_recognize_rows_formatter.h" + +#include <yql/essentials/minikql/mkql_node.h> +#include <yql/essentials/minikql/mkql_node_cast.h> + +namespace NKikimr::NMiniKQL::NMatchRecognize { + +namespace { + +class TOneRowFormatter final : public IRowsFormatter { +public: + explicit TOneRowFormatter(const TState& state) : IRowsFormatter(state) {} + + NUdf::TUnboxedValue GetFirstMatchRow( + TComputationContext& ctx, + const TSparseList& rows, + const NUdf::TUnboxedValue& partitionKey, + const TNfaTransitionGraph& graph, + const TNfa::TMatch& match) { + Match_ = match; + const auto result = DoGetMatchRow(ctx, rows, partitionKey, graph); + IRowsFormatter::Clear(); + return result; + } + + NUdf::TUnboxedValue GetOtherMatchRow( + TComputationContext& ctx, + const TSparseList& rows, + const NUdf::TUnboxedValue& partitionKey, + const TNfaTransitionGraph& graph) { + return NUdf::TUnboxedValue{}; + } +}; + +class TAllRowsFormatter final : public IRowsFormatter { +public: + explicit TAllRowsFormatter(const IRowsFormatter::TState& state) : IRowsFormatter(state) {} + + NUdf::TUnboxedValue GetFirstMatchRow( + TComputationContext& ctx, + const TSparseList& rows, + const NUdf::TUnboxedValue& partitionKey, + const TNfaTransitionGraph& graph, + const TNfa::TMatch& match) { + Match_ = match; + CurrentRowIndex_ = Match_.BeginIndex; + for (const auto& matchedVar : Match_.Vars) { + for (const auto& range : matchedVar) { + ToIndexToMatchRangeLookup_.emplace(range.To(), range); + } + } + return GetMatchRow(ctx, rows, partitionKey, graph); + } + + NUdf::TUnboxedValue GetOtherMatchRow( + TComputationContext& ctx, + const TSparseList& rows, + const NUdf::TUnboxedValue& partitionKey, + const TNfaTransitionGraph& graph) { + return GetMatchRow(ctx, rows, partitionKey, graph); + } + +private: + NUdf::TUnboxedValue GetMatchRow(TComputationContext& ctx, const TSparseList& rows, const NUdf::TUnboxedValue& partitionKey, const TNfaTransitionGraph& graph) { + while (CurrentRowIndex_ <= Match_.EndIndex) { + if (auto iter = ToIndexToMatchRangeLookup_.lower_bound(CurrentRowIndex_); + iter == ToIndexToMatchRangeLookup_.end()) { + MKQL_ENSURE(false, "Internal logic error"); + } else if (auto transition = std::get_if<TMatchedVarTransition>(&graph.Transitions.at(iter->second.NfaIndex())); + !transition) { + MKQL_ENSURE(false, "Internal logic error"); + } else if (transition->ExcludeFromOutput) { + ++CurrentRowIndex_; + } else { + break; + } + } + if (CurrentRowIndex_ > Match_.EndIndex) { + return NUdf::TUnboxedValue{}; + } + const auto result = DoGetMatchRow(ctx, rows, partitionKey, graph); + ++CurrentRowIndex_; + if (CurrentRowIndex_ == Match_.EndIndex) { + Clear(); + } + return result; + } + + void Clear() { + IRowsFormatter::Clear(); + ToIndexToMatchRangeLookup_.clear(); + } + + TMap<size_t, const TSparseList::TRange&> ToIndexToMatchRangeLookup_; +}; + +} // anonymous namespace + +IRowsFormatter::IRowsFormatter(const TState& state) : State_(state) {} + +TOutputColumnOrder IRowsFormatter::GetOutputColumnOrder( + TRuntimeNode outputColumnOrder) { + TOutputColumnOrder result; + auto list = AS_VALUE(TListLiteral, outputColumnOrder); + TConstArrayRef<TRuntimeNode> items(list->GetItems(), list->GetItemsCount()); + for (auto item : items) { + const auto entry = AS_VALUE(TStructLiteral, item); + result.emplace_back( + AS_VALUE(TDataLiteral, entry->GetValue(0))->AsValue().Get<ui32>(), + static_cast<EOutputColumnSource>(AS_VALUE(TDataLiteral, entry->GetValue(1))->AsValue().Get<i32>()) + ); + } + return result; +} + +NUdf::TUnboxedValue IRowsFormatter::DoGetMatchRow(TComputationContext& ctx, const TSparseList& rows, const NUdf::TUnboxedValue& partitionKey, const TNfaTransitionGraph& graph) { + NUdf::TUnboxedValue *itemsPtr = nullptr; + const auto result = State_.Cache->NewArray(ctx, State_.OutputColumnOrder.size(), itemsPtr); + for (const auto& columnEntry: State_.OutputColumnOrder) { + switch(columnEntry.SourceType) { + case EOutputColumnSource::PartitionKey: + *itemsPtr++ = partitionKey.GetElement(columnEntry.Index); + break; + case EOutputColumnSource::Measure: + *itemsPtr++ = State_.Measures[columnEntry.Index]->GetValue(ctx); + break; + case EOutputColumnSource::Other: + *itemsPtr++ = rows.Get(CurrentRowIndex_).GetElement(columnEntry.Index); + break; + } + } + return result; +} + +std::unique_ptr<IRowsFormatter> IRowsFormatter::Create(const IRowsFormatter::TState& state) { + switch (state.RowsPerMatch) { + case ERowsPerMatch::OneRow: + return std::unique_ptr<IRowsFormatter>(new TOneRowFormatter(state)); + case ERowsPerMatch::AllRows: + return std::unique_ptr<IRowsFormatter>(new TAllRowsFormatter(state)); + } +} + +} //namespace NKikimr::NMiniKQL::NMatchRecognize diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.h new file mode 100644 index 0000000000..39750bf4ea --- /dev/null +++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.h @@ -0,0 +1,72 @@ +#pragma once + +#include "mkql_match_recognize_nfa.h" + +#include <yql/essentials/core/sql_types/match_recognize.h> +#include <yql/essentials/minikql/computation/mkql_computation_node.h> +#include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h> +#include <yql/essentials/minikql/mkql_alloc.h> +#include <yql/essentials/public/udf/udf_value.h> + +namespace NKikimr::NMiniKQL::NMatchRecognize { + +struct TOutputColumnEntry { + size_t Index; + NYql::NMatchRecognize::EOutputColumnSource SourceType; +}; +using TOutputColumnOrder = std::vector<TOutputColumnEntry, TMKQLAllocator<TOutputColumnEntry>>; + +class IRowsFormatter { +public: + struct TState { + std::unique_ptr<TContainerCacheOnContext> Cache; + TOutputColumnOrder OutputColumnOrder; + TComputationNodePtrVector Measures; + NYql::NMatchRecognize::ERowsPerMatch RowsPerMatch; + + TState( + const TComputationNodeFactoryContext& ctx, + TOutputColumnOrder outputColumnOrder, + TComputationNodePtrVector measures, + NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch) + : Cache(std::make_unique<TContainerCacheOnContext>(ctx.Mutables)) + , OutputColumnOrder(std::move(outputColumnOrder)) + , Measures(std::move(measures)) + , RowsPerMatch(rowsPerMatch) + {} + }; + + explicit IRowsFormatter(const TState& state); + virtual ~IRowsFormatter() = default; + + virtual NUdf::TUnboxedValue GetFirstMatchRow( + TComputationContext& ctx, + const TSparseList& rows, + const NUdf::TUnboxedValue& partitionKey, + const TNfaTransitionGraph& graph, + const TNfa::TMatch& match) = 0; + + virtual NUdf::TUnboxedValue GetOtherMatchRow( + TComputationContext& ctx, + const TSparseList& rows, + const NUdf::TUnboxedValue& partitionKey, + const TNfaTransitionGraph& graph) = 0; + + static TOutputColumnOrder GetOutputColumnOrder(TRuntimeNode outputColumnOrder); + + static std::unique_ptr<IRowsFormatter> Create(const TState& state); + +protected: + NUdf::TUnboxedValue DoGetMatchRow(TComputationContext& ctx, const TSparseList& rows, const NUdf::TUnboxedValue& partitionKey, const TNfaTransitionGraph& graph); + + inline void Clear() { + Match_ = {}; + CurrentRowIndex_ = Max(); + } + + const TState& State_; + TNfa::TMatch Match_ {}; + size_t CurrentRowIndex_ = Max(); +}; + +} // namespace NKikimr::NMiniKQL::NMatchRecognize diff --git a/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp b/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp index a9fad1b6ef..513a72df5e 100644 --- a/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp +++ b/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp @@ -64,57 +64,46 @@ namespace { const TTestInputData& input) { TProgramBuilder& pgmBuilder = *setup.PgmBuilder; - auto structType = pgmBuilder.NewStructType({ - {"time", pgmBuilder.NewDataType(NUdf::TDataType<i64>::Id)}, - {"key", pgmBuilder.NewDataType(NUdf::TDataType<char*>::Id)}, - {"sum", pgmBuilder.NewDataType(NUdf::TDataType<ui32>::Id)}, - {"part", pgmBuilder.NewDataType(NUdf::TDataType<char*>::Id)}}); + const auto structType = pgmBuilder.NewStructType({ + {"time", pgmBuilder.NewDataType(NUdf::EDataSlot::Int64)}, + {"key", pgmBuilder.NewDataType(NUdf::EDataSlot::String)}, + {"sum", pgmBuilder.NewDataType(NUdf::EDataSlot::Uint32)}, + {"part", pgmBuilder.NewDataType(NUdf::EDataSlot::String)} + }); TVector<TRuntimeNode> items; - for (size_t i = 0; i < input.size(); ++i) - { - auto time = pgmBuilder.NewDataLiteral<i64>(std::get<0>(input[i])); - auto key = pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(NUdf::TStringRef(std::get<1>(input[i]))); - auto sum = pgmBuilder.NewDataLiteral<ui32>(std::get<2>(input[i])); - auto part = pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(NUdf::TStringRef(std::get<3>(input[i]))); - - auto item = pgmBuilder.NewStruct(structType, - {{"time", time}, {"key", key}, {"sum", sum}, {"part", part}}); - items.push_back(std::move(item)); + for (size_t i = 0; i < input.size(); ++i) { + const auto& [time, key, sum, part] = input[i]; + items.push_back(pgmBuilder.NewStruct({ + {"time", pgmBuilder.NewDataLiteral(time)}, + {"key", pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(key)}, + {"sum", pgmBuilder.NewDataLiteral(sum)}, + {"part", pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(part)}, + })); } const auto list = pgmBuilder.NewList(structType, std::move(items)); auto inputFlow = pgmBuilder.ToFlow(list); - - TVector<TStringBuf> partitionColumns; - TVector<std::pair<TStringBuf, TProgramBuilder::TBinaryLambda>> getMeasures = {{ - std::make_pair( - TStringBuf("key"), - [&](TRuntimeNode /*measureInputDataArg*/, TRuntimeNode /*matchedVarsArg*/) { - return pgmBuilder.NewDataLiteral<ui32>(56); - } - )}}; - TVector<std::pair<TStringBuf, TProgramBuilder::TTernaryLambda>> getDefines = {{ - std::make_pair( - TStringBuf("A"), - [&](TRuntimeNode /*inputDataArg*/, TRuntimeNode /*matchedVarsArg*/, TRuntimeNode /*currentRowIndexArg*/) { - return pgmBuilder.NewDataLiteral<bool>(true); - } - )}}; - auto pgmReturn = pgmBuilder.MatchRecognizeCore( inputFlow, [&](TRuntimeNode item) { - return pgmBuilder.Member(item, "part"); + return pgmBuilder.NewTuple({pgmBuilder.Member(item, "part")}); }, - partitionColumns, - getMeasures, + {}, + {"key"sv}, + {[&](TRuntimeNode /*measureInputDataArg*/, TRuntimeNode /*matchedVarsArg*/) { + return pgmBuilder.NewDataLiteral<ui32>(56); + }}, { {NYql::NMatchRecognize::TRowPatternFactor{"A", 3, 3, false, false, false}} }, - getDefines, + {"A"sv}, + {[&](TRuntimeNode /*inputDataArg*/, TRuntimeNode /*matchedVarsArg*/, TRuntimeNode /*currentRowIndexArg*/) { + return pgmBuilder.NewDataLiteral<bool>(true); + }}, streamingMode, - {NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""} + {NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""}, + NYql::NMatchRecognize::ERowsPerMatch::OneRow ); auto graph = setup.BuildGraph(pgmReturn); diff --git a/yql/essentials/minikql/comp_nodes/ya.make.inc b/yql/essentials/minikql/comp_nodes/ya.make.inc index 518f5f3c43..b2f0da8ac9 100644 --- a/yql/essentials/minikql/comp_nodes/ya.make.inc +++ b/yql/essentials/minikql/comp_nodes/ya.make.inc @@ -79,6 +79,7 @@ SET(ORIG_SOURCES mkql_mapnext.cpp mkql_map_join.cpp mkql_match_recognize.cpp + mkql_match_recognize_rows_formatter.cpp mkql_multihopping.cpp mkql_multimap.cpp mkql_next_value.cpp diff --git a/yql/essentials/minikql/mkql_program_builder.cpp b/yql/essentials/minikql/mkql_program_builder.cpp index 691a642814..0b9197024d 100644 --- a/yql/essentials/minikql/mkql_program_builder.cpp +++ b/yql/essentials/minikql/mkql_program_builder.cpp @@ -1,5 +1,4 @@ #include "mkql_program_builder.h" -#include "mkql_opt_literal.h" #include "mkql_node_visitor.h" #include "mkql_node_cast.h" #include "mkql_runtime_version.h" @@ -11,6 +10,7 @@ #include "yql/essentials/core/sql_types/time_order_recover.h" #include <yql/essentials/parser/pg_catalog/catalog.h> +#include <util/generic/overloaded.h> #include <util/string/cast.h> #include <util/string/printf.h> #include <array> @@ -6035,15 +6035,19 @@ TRuntimeNode PatternToRuntimeNode(const TRowPattern& pattern, const TProgramBuil TTupleLiteralBuilder termBuilder(env); for (const auto& factor: term) { TTupleLiteralBuilder factorBuilder(env); - factorBuilder.Add(factor.Primary.index() == 0 ? - programBuilder.NewDataLiteral<NUdf::EDataSlot::String>(std::get<0>(factor.Primary)) : - PatternToRuntimeNode(std::get<1>(factor.Primary), programBuilder) - ); - factorBuilder.Add(programBuilder.NewDataLiteral<ui64>(factor.QuantityMin)); - factorBuilder.Add(programBuilder.NewDataLiteral<ui64>(factor.QuantityMax)); - factorBuilder.Add(programBuilder.NewDataLiteral<bool>(factor.Greedy)); - factorBuilder.Add(programBuilder.NewDataLiteral<bool>(factor.Output)); - factorBuilder.Add(programBuilder.NewDataLiteral<bool>(factor.Unused)); + factorBuilder.Add(std::visit(TOverloaded { + [&](const TString& s) { + return programBuilder.NewDataLiteral<NUdf::EDataSlot::String>(s); + }, + [&](const TRowPattern& pattern) { + return PatternToRuntimeNode(pattern, programBuilder); + }, + }, factor.Primary)); + factorBuilder.Add(programBuilder.NewDataLiteral(factor.QuantityMin)); + factorBuilder.Add(programBuilder.NewDataLiteral(factor.QuantityMax)); + factorBuilder.Add(programBuilder.NewDataLiteral(factor.Greedy)); + factorBuilder.Add(programBuilder.NewDataLiteral(factor.Output)); + factorBuilder.Add(programBuilder.NewDataLiteral(factor.Unused)); termBuilder.Add({factorBuilder.Build(), true}); } patternBuilder.Add({termBuilder.Build(), true}); @@ -6056,151 +6060,172 @@ TRuntimeNode PatternToRuntimeNode(const TRowPattern& pattern, const TProgramBuil TRuntimeNode TProgramBuilder::MatchRecognizeCore( TRuntimeNode inputStream, const TUnaryLambda& getPartitionKeySelectorNode, - const TArrayRef<TStringBuf>& partitionColumns, - const TArrayRef<std::pair<TStringBuf, TBinaryLambda>>& getMeasures, + const TArrayRef<TStringBuf>& partitionColumnNames, + const TVector<TStringBuf>& measureColumnNames, + const TVector<TBinaryLambda>& getMeasures, const NYql::NMatchRecognize::TRowPattern& pattern, - const TArrayRef<std::pair<TStringBuf, TTernaryLambda>>& getDefines, + const TVector<TStringBuf>& defineVarNames, + const TVector<TTernaryLambda>& getDefines, bool streamingMode, - const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo + const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo, + NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch ) { MKQL_ENSURE(RuntimeVersion >= 42, "MatchRecognize is not supported in runtime version " << RuntimeVersion); const auto inputRowType = AS_TYPE(TStructType, AS_TYPE(TFlowType, inputStream.GetStaticType())->GetItemType()); const auto inputRowArg = Arg(inputRowType); const auto partitionKeySelectorNode = getPartitionKeySelectorNode(inputRowArg); + const auto partitionColumnTypes = AS_TYPE(TTupleType, partitionKeySelectorNode.GetStaticType())->GetElements(); - TStructTypeBuilder indexRangeTypeBuilder(Env); - indexRangeTypeBuilder.Add("From", TDataType::Create(NUdf::TDataType<ui64>::Id, Env)); - indexRangeTypeBuilder.Add("To", TDataType::Create(NUdf::TDataType<ui64>::Id, Env)); - const auto& rangeList = TListType::Create(indexRangeTypeBuilder.Build(), Env); + const auto rangeList = NewListType(NewStructType({ + {"From", NewDataType(NUdf::EDataSlot::Uint64)}, + {"To", NewDataType(NUdf::EDataSlot::Uint64)} + })); TStructTypeBuilder matchedVarsTypeBuilder(Env); for (const auto& var: GetPatternVars(pattern)) { matchedVarsTypeBuilder.Add(var, rangeList); } - TRuntimeNode matchedVarsArg = Arg(matchedVarsTypeBuilder.Build()); + const auto matchedVarsType = matchedVarsTypeBuilder.Build(); + TRuntimeNode matchedVarsArg = Arg(matchedVarsType); //---These vars may be empty in case of no measures TRuntimeNode measureInputDataArg; std::vector<TRuntimeNode> specialColumnIndexesInMeasureInputDataRow; TVector<TRuntimeNode> measures; - TVector<TType*> measureTypes; //--- if (getMeasures.empty()) { measureInputDataArg = Arg(Env.GetTypeOfVoidLazy()); } else { - using NYql::NMatchRecognize::EMeasureInputDataSpecialColumns; measures.reserve(getMeasures.size()); - measureTypes.reserve(getMeasures.size()); specialColumnIndexesInMeasureInputDataRow.resize(static_cast<size_t>(NYql::NMatchRecognize::EMeasureInputDataSpecialColumns::Last)); TStructTypeBuilder measureInputDataRowTypeBuilder(Env); - for (ui32 i = 0; i != inputRowType->GetMembersCount(); ++i) { + for (ui32 i = 0; i < inputRowType->GetMembersCount(); ++i) { measureInputDataRowTypeBuilder.Add(inputRowType->GetMemberName(i), inputRowType->GetMemberType(i)); } measureInputDataRowTypeBuilder.Add( MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::Classifier), - TDataType::Create(NUdf::TDataType<NYql::NUdf::TUtf8>::Id, Env) + NewDataType(NUdf::EDataSlot::Utf8) ); measureInputDataRowTypeBuilder.Add( MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::MatchNumber), - TDataType::Create(NUdf::TDataType<ui64>::Id, Env) + NewDataType(NUdf::EDataSlot::Uint64) ); const auto measureInputDataRowType = measureInputDataRowTypeBuilder.Build(); - for (ui32 i = 0; i != measureInputDataRowType->GetMembersCount(); ++i) { + for (ui32 i = 0; i < measureInputDataRowType->GetMembersCount(); ++i) { //assume a few, if grows, it's better to use a lookup table here static_assert(static_cast<size_t>(EMeasureInputDataSpecialColumns::Last) < 5); for (size_t j = 0; j != static_cast<size_t>(EMeasureInputDataSpecialColumns::Last); ++j) { if (measureInputDataRowType->GetMemberName(i) == NYql::NMatchRecognize::MeasureInputDataSpecialColumnName(static_cast<EMeasureInputDataSpecialColumns>(j))) - specialColumnIndexesInMeasureInputDataRow[j] = NewDataLiteral<ui32>(i); + specialColumnIndexesInMeasureInputDataRow[j] = NewDataLiteral(i); } } - measureInputDataArg = Arg(TListType::Create(measureInputDataRowType, Env)); + measureInputDataArg = Arg(NewListType(measureInputDataRowType)); for (size_t i = 0; i != getMeasures.size(); ++i) { - measures.push_back(getMeasures[i].second(measureInputDataArg, matchedVarsArg)); - measureTypes.push_back(measures[i].GetStaticType()); + measures.push_back(getMeasures[i](measureInputDataArg, matchedVarsArg)); } } TStructTypeBuilder outputRowTypeBuilder(Env); THashMap<TStringBuf, size_t> partitionColumnLookup; - for (size_t i = 0; i != partitionColumns.size(); ++i) { - const auto& name = partitionColumns[i]; - partitionColumnLookup[name] = i; - outputRowTypeBuilder.Add( - name, - AS_TYPE(TTupleType, partitionKeySelectorNode.GetStaticType())->GetElementType(i) - ); - } THashMap<TStringBuf, size_t> measureColumnLookup; - for (size_t i = 0; i != measures.size(); ++i) { - const auto& name = getMeasures[i].first; - measureColumnLookup[name] = i; - outputRowTypeBuilder.Add( - name, - measures[i].GetStaticType() - ); + THashMap<TStringBuf, size_t> otherColumnLookup; + for (size_t i = 0; i < measureColumnNames.size(); ++i) { + const auto name = measureColumnNames[i]; + measureColumnLookup.emplace(name, i); + outputRowTypeBuilder.Add(name, measures[i].GetStaticType()); + } + switch (rowsPerMatch) { + case NYql::NMatchRecognize::ERowsPerMatch::OneRow: + for (size_t i = 0; i < partitionColumnNames.size(); ++i) { + const auto name = partitionColumnNames[i]; + partitionColumnLookup.emplace(name, i); + outputRowTypeBuilder.Add(name, partitionColumnTypes[i]); + } + break; + case NYql::NMatchRecognize::ERowsPerMatch::AllRows: + for (size_t i = 0; i < inputRowType->GetMembersCount(); ++i) { + const auto name = inputRowType->GetMemberName(i); + otherColumnLookup.emplace(name, i); + outputRowTypeBuilder.Add(name, inputRowType->GetMemberType(i)); + } + break; } auto outputRowType = outputRowTypeBuilder.Build(); std::vector<TRuntimeNode> partitionColumnIndexes(partitionColumnLookup.size()); std::vector<TRuntimeNode> measureColumnIndexes(measureColumnLookup.size()); - for (ui32 i = 0; i != outputRowType->GetMembersCount(); ++i) { - if (auto it = partitionColumnLookup.find(outputRowType->GetMemberName(i)); it != partitionColumnLookup.end()) { - partitionColumnIndexes[it->second] = NewDataLiteral<ui32>(i); - } - else if (auto it = measureColumnLookup.find(outputRowType->GetMemberName(i)); it != measureColumnLookup.end()) { - measureColumnIndexes[it->second] = NewDataLiteral<ui32>(i); + TVector<TRuntimeNode> outputColumnOrder(NDetail::TReserveTag{outputRowType->GetMembersCount()}); + for (ui32 i = 0; i < outputRowType->GetMembersCount(); ++i) { + const auto name = outputRowType->GetMemberName(i); + if (auto iter = partitionColumnLookup.find(name); + iter != partitionColumnLookup.end()) { + partitionColumnIndexes[iter->second] = NewDataLiteral(i); + outputColumnOrder.push_back(NewStruct({ + std::pair{"Index", NewDataLiteral(iter->second)}, + std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::PartitionKey))}, + })); + } else if (auto iter = measureColumnLookup.find(name); + iter != measureColumnLookup.end()) { + measureColumnIndexes[iter->second] = NewDataLiteral(i); + outputColumnOrder.push_back(NewStruct({ + std::pair{"Index", NewDataLiteral(iter->second)}, + std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::Measure))}, + })); + } else if (auto iter = otherColumnLookup.find(name); + iter != otherColumnLookup.end()) { + outputColumnOrder.push_back(NewStruct({ + std::pair{"Index", NewDataLiteral(iter->second)}, + std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::Other))}, + })); } } - auto outputType = (TType*)TFlowType::Create(outputRowType, Env); + const auto outputType = NewFlowType(outputRowType); - THashMap<TStringBuf , size_t> patternVarLookup; - for (ui32 i = 0; i != AS_TYPE(TStructType, matchedVarsArg.GetStaticType())->GetMembersCount(); ++i){ - patternVarLookup[AS_TYPE(TStructType, matchedVarsArg.GetStaticType())->GetMemberName(i)] = i; + THashMap<TStringBuf, size_t> patternVarLookup; + for (ui32 i = 0; i < matchedVarsType->GetMembersCount(); ++i) { + patternVarLookup[matchedVarsType->GetMemberName(i)] = i; } - THashMap<TStringBuf , size_t> defineLookup; - for (size_t i = 0; i != getDefines.size(); ++i) { - defineLookup[getDefines[i].first] = i; + THashMap<TStringBuf, size_t> defineLookup; + for (size_t i = 0; i < defineVarNames.size(); ++i) { + const auto name = defineVarNames[i]; + defineLookup[name] = i; } - std::vector<TRuntimeNode> defineNames(patternVarLookup.size()); - std::vector<TRuntimeNode> defineNodes(patternVarLookup.size()); - const auto& inputDataArg = Arg(TListType::Create(inputRowType, Env)); - const auto& currentRowIndexArg = Arg(TDataType::Create(NUdf::TDataType<ui64>::Id, Env)); + TVector<TRuntimeNode> defineNames(defineVarNames.size()); + TVector<TRuntimeNode> defineNodes(patternVarLookup.size()); + const auto inputDataArg = Arg(NewListType(inputRowType)); + const auto currentRowIndexArg = Arg(NewDataType(NUdf::EDataSlot::Uint64)); for (const auto& [v, i]: patternVarLookup) { defineNames[i] = NewDataLiteral<NUdf::EDataSlot::String>(v); - if (const auto it = defineLookup.find(v); it != defineLookup.end()) { - defineNodes[i] = getDefines[it->second].second(inputDataArg, matchedVarsArg, currentRowIndexArg); - } - else { //no predicate for var - if ("$" == v || "^" == v) { - //DO nothing, //will be handled in a specific way - } - else { // a var without a predicate matches any row - defineNodes[i] = NewDataLiteral<bool>(true); - } + if (auto iter = defineLookup.find(v); + iter != defineLookup.end()) { + defineNodes[i] = getDefines[iter->second](inputDataArg, matchedVarsArg, currentRowIndexArg); + } else if ("$" == v || "^" == v) { + //DO nothing, //will be handled in a specific way + } else { // a var without a predicate matches any row + defineNodes[i] = NewDataLiteral(true); } } TCallableBuilder callableBuilder(GetTypeEnvironment(), "MatchRecognizeCore", outputType); - auto indexType = TDataType::Create(NUdf::TDataType<ui32>::Id, Env); - auto indexListType = TListType::Create(indexType, Env); + const auto indexType = NewDataType(NUdf::EDataSlot::Uint32); + const auto outputColumnEntryType = NewStructType({ + {"Index", NewDataType(NUdf::EDataSlot::Uint64)}, + {"SourceType", NewDataType(NUdf::EDataSlot::Int32)}, + }); callableBuilder.Add(inputStream); callableBuilder.Add(inputRowArg); callableBuilder.Add(partitionKeySelectorNode); - callableBuilder.Add(TRuntimeNode(TListLiteral::Create(partitionColumnIndexes.data(), partitionColumnIndexes.size(), indexListType, Env), true)); + callableBuilder.Add(NewList(indexType, partitionColumnIndexes)); callableBuilder.Add(measureInputDataArg); - callableBuilder.Add(TRuntimeNode(TListLiteral::Create( - specialColumnIndexesInMeasureInputDataRow.data(), specialColumnIndexesInMeasureInputDataRow.size(), - indexListType, Env - ), - true)); - callableBuilder.Add(NewDataLiteral<ui32>(inputRowType->GetMembersCount())); + callableBuilder.Add(NewList(indexType, specialColumnIndexesInMeasureInputDataRow)); + callableBuilder.Add(NewDataLiteral(inputRowType->GetMembersCount())); callableBuilder.Add(matchedVarsArg); - callableBuilder.Add(TRuntimeNode(TListLiteral::Create(measureColumnIndexes.data(), measureColumnIndexes.size(), indexListType, Env), true)); + callableBuilder.Add(NewList(indexType, measureColumnIndexes)); for (const auto& m: measures) { callableBuilder.Add(m); } @@ -6209,16 +6234,19 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore( callableBuilder.Add(currentRowIndexArg); callableBuilder.Add(inputDataArg); - const auto stringType = NewDataType(NUdf::EDataSlot::String); - callableBuilder.Add(TRuntimeNode(TListLiteral::Create(defineNames.begin(), defineNames.size(), TListType::Create(stringType, Env), Env), true)); + callableBuilder.Add(NewList(NewDataType(NUdf::EDataSlot::String), defineNames)); for (const auto& d: defineNodes) { callableBuilder.Add(d); } callableBuilder.Add(NewDataLiteral(streamingMode)); - if (RuntimeVersion >= 52U) { + if constexpr (RuntimeVersion >= 52U) { callableBuilder.Add(NewDataLiteral(static_cast<i32>(skipTo.To))); callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(skipTo.Var)); } + if constexpr (RuntimeVersion >= 53U) { + callableBuilder.Add(NewDataLiteral(static_cast<i32>(rowsPerMatch))); + callableBuilder.Add(NewList(outputColumnEntryType, outputColumnOrder)); + } return TRuntimeNode(callableBuilder.Build(), false); } diff --git a/yql/essentials/minikql/mkql_program_builder.h b/yql/essentials/minikql/mkql_program_builder.h index 762d3a013e..9e6b0d97d0 100644 --- a/yql/essentials/minikql/mkql_program_builder.h +++ b/yql/essentials/minikql/mkql_program_builder.h @@ -712,12 +712,15 @@ public: TRuntimeNode MatchRecognizeCore( TRuntimeNode inputStream, const TUnaryLambda& getPartitionKeySelectorNode, - const TArrayRef<TStringBuf>& partitionColumns, - const TArrayRef<std::pair<TStringBuf, TBinaryLambda>>& getMeasures, + const TArrayRef<TStringBuf>& partitionColumnNames, + const TVector<TStringBuf>& measureColumnNames, + const TVector<TBinaryLambda>& getMeasures, const NYql::NMatchRecognize::TRowPattern& pattern, - const TArrayRef<std::pair<TStringBuf, TTernaryLambda>>& getDefines, + const TVector<TStringBuf>& defineVarNames, + const TVector<TTernaryLambda>& getDefines, bool streamingMode, - const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo + const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo, + NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch ); TRuntimeNode TimeOrderRecover( |