aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql
diff options
context:
space:
mode:
author vokayndzop <vokayndzop@yandex-team.com> 2024-12-16 15:55:05 +0300
committer vokayndzop <vokayndzop@yandex-team.com> 2024-12-16 16:34:36 +0300
commit b1cde7dcb055fb6f3367e81fd0f57bd55b8bb93c (patch)
tree 230bddb8bb4ce7d8290a16a4465ec98dbf513a5a /yql/essentials/minikql
parent 88e0ad5922cea1349ec1f8cbf133524cf865d696 (diff)
download ydb-b1cde7dcb055fb6f3367e81fd0f57bd55b8bb93c.tar.gz
MR: support ALL ROWS PER MATCH
commit_hash:9e2ba38d0d523bb870f6dc76717a3bec5d8ffadc
Diffstat (limited to 'yql/essentials/minikql')
-rw-r--r-- yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp 207
-rw-r--r-- yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h 71
-rw-r--r-- yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h 2
-rw-r--r-- yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h 27
-rw-r--r-- yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.cpp 144
-rw-r--r-- yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.h 72
-rw-r--r-- yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp 63
-rw-r--r-- yql/essentials/minikql/comp_nodes/ya.make.inc 1
-rw-r--r-- yql/essentials/minikql/mkql_program_builder.cpp 198
-rw-r--r-- yql/essentials/minikql/mkql_program_builder.h 11
10 files changed, 531 insertions, 265 deletions
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp b/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp
index 80caaad37e..e1fc529fda 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp
+++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp
@@ -1,32 +1,28 @@
#include "mkql_match_recognize_list.h"
-#include "mkql_match_recognize_matched_vars.h"
#include "mkql_match_recognize_measure_arg.h"
+#include "mkql_match_recognize_matched_vars.h"
#include "mkql_match_recognize_nfa.h"
+#include "mkql_match_recognize_rows_formatter.h"
#include "mkql_match_recognize_save_load.h"
#include "mkql_saveload.h"
#include <yql/essentials/core/sql_types/match_recognize.h>
-#include <yql/essentials/minikql/computation/mkql_computation_node_impl.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
+#include <yql/essentials/minikql/computation/mkql_computation_node_impl.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_pack.h>
+#include <yql/essentials/minikql/mkql_node.h>
#include <yql/essentials/minikql/mkql_node_cast.h>
-#include <yql/essentials/minikql/mkql_runtime_version.h>
#include <yql/essentials/minikql/mkql_string_util.h>
-#include <yql/essentials/core/sql_types/match_recognize.h>
+
#include <deque>
namespace NKikimr::NMiniKQL {
namespace NMatchRecognize {
-enum class EOutputColumnSource {PartitionKey, Measure};
-using TOutputColumnOrder = std::vector<std::pair<EOutputColumnSource, size_t>, TMKQLAllocator<std::pair<EOutputColumnSource, size_t>>>;
-
constexpr ui32 StateVersion = 1;
-using namespace NYql::NMatchRecognize;
-
struct TMatchRecognizeProcessorParameters {
IComputationExternalNode* InputDataArg;
TRowPattern Pattern;
@@ -37,27 +33,21 @@ struct TMatchRecognizeProcessorParameters {
TComputationNodePtrVector Defines;
IComputationExternalNode* MeasureInputDataArg;
TMeasureInputColumnOrder MeasureInputColumnOrder;
- TComputationNodePtrVector Measures;
- TOutputColumnOrder OutputColumnOrder;
- TAfterMatchSkipTo SkipTo;
+ TAfterMatchSkipTo SkipTo;
};
class TStreamingMatchRecognize {
- using TPartitionList = TSparseList;
- using TRange = TPartitionList::TRange;
public:
TStreamingMatchRecognize(
NUdf::TUnboxedValue&& partitionKey,
const TMatchRecognizeProcessorParameters& parameters,
- TNfaTransitionGraph::TPtr nfaTransitions,
- const TContainerCacheOnContext& cache
- )
+ const IRowsFormatter::TState& rowsFormatterState,
+ TNfaTransitionGraph::TPtr nfaTransitions)
: PartitionKey(std::move(partitionKey))
, Parameters(parameters)
+ , RowsFormatter_(IRowsFormatter::Create(rowsFormatterState))
, Nfa(nfaTransitions, parameters.MatchedVarsArg, parameters.Defines, parameters.SkipTo)
- , Cache(cache)
- {
- }
+ {}
bool ProcessInputRow(NUdf::TUnboxedValue&& row, TComputationContext& ctx) {
Parameters.InputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TListValue<TSparseList>>(Rows));
@@ -71,33 +61,26 @@ public:
}
NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) {
+ if (auto result = RowsFormatter_->GetOtherMatchRow(ctx, Rows, PartitionKey, Nfa.GetTransitionGraph())) {
+ return result;
+ }
auto match = Nfa.GetMatched();
if (!match) {
return NUdf::TUnboxedValue{};
}
- Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, match->Vars));
+ Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TSparseList::TRange>>(ctx.HolderFactory, match->Vars));
Parameters.MeasureInputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TMeasureInputDataValue>(
ctx.HolderFactory.Create<TListValue<TSparseList>>(Rows),
Parameters.MeasureInputColumnOrder,
Parameters.MatchedVarsArg->GetValue(ctx),
Parameters.VarNames,
MatchNumber
- ));
- NUdf::TUnboxedValue *itemsPtr = nullptr;
- const auto result = Cache.NewArray(ctx, Parameters.OutputColumnOrder.size(), itemsPtr);
- for (auto const& c: Parameters.OutputColumnOrder) {
- switch(c.first) {
- case EOutputColumnSource::Measure:
- *itemsPtr++ = Parameters.Measures[c.second]->GetValue(ctx);
- break;
- case EOutputColumnSource::PartitionKey:
- *itemsPtr++ = PartitionKey.GetElement(c.second);
- break;
- }
- }
+ ));
+ auto result = RowsFormatter_->GetFirstMatchRow(ctx, Rows, PartitionKey, Nfa.GetTransitionGraph(), *match);
Nfa.AfterMatchSkip(*match);
return result;
}
+
bool ProcessEndOfData(TComputationContext& ctx) {
return Nfa.ProcessEndOfData(ctx);
}
@@ -117,11 +100,11 @@ public:
}
private:
- const NUdf::TUnboxedValue PartitionKey;
+ NUdf::TUnboxedValue PartitionKey;
const TMatchRecognizeProcessorParameters& Parameters;
+ std::unique_ptr<IRowsFormatter> RowsFormatter_;
TSparseList Rows;
TNfa Nfa;
- const TContainerCacheOnContext& Cache;
ui64 MatchNumber = 0;
};
@@ -135,7 +118,7 @@ public:
IComputationNode* partitionKey,
TType* partitionKeyType,
const TMatchRecognizeProcessorParameters& parameters,
- const TContainerCacheOnContext& cache,
+ const IRowsFormatter::TState& rowsFormatterState,
TComputationContext &ctx,
TType* rowType,
const TMutableObjectOverBoxedValue<TValuePackerBoxed>& rowPacker
@@ -145,8 +128,8 @@ public:
, PartitionKey(partitionKey)
, PartitionKeyPacker(true, partitionKeyType)
, Parameters(parameters)
+ , RowsFormatterState(rowsFormatterState)
, RowPatternConfiguration(TNfaTransitionGraphBuilder::Create(parameters.Pattern, parameters.VarNamesLookup))
- , Cache(cache)
, Terminating(false)
, SerializerContext(ctx, rowType, rowPacker)
, Ctx(ctx)
@@ -184,8 +167,8 @@ public:
PartitionHandler.reset(new TStreamingMatchRecognize(
std::move(key),
Parameters,
- RowPatternConfiguration,
- Cache
+ RowsFormatterState,
+ RowPatternConfiguration
));
PartitionHandler->Load(in);
}
@@ -250,8 +233,9 @@ public:
PartitionHandler.reset(new TStreamingMatchRecognize(
std::move(partitionKey),
Parameters,
- RowPatternConfiguration,
- Cache));
+ RowsFormatterState,
+ RowPatternConfiguration
+ ));
PartitionHandler->ProcessInputRow(std::move(temp), ctx);
}
if (Terminating) {
@@ -266,8 +250,8 @@ private:
IComputationNode* PartitionKey;
TValuePackerGeneric<false> PartitionKeyPacker;
const TMatchRecognizeProcessorParameters& Parameters;
+ const IRowsFormatter::TState& RowsFormatterState;
const TNfaTransitionGraph::TPtr RowPatternConfiguration;
- const TContainerCacheOnContext& Cache;
NUdf::TUnboxedValue DelayedRow;
bool Terminating;
TSerializerContext SerializerContext;
@@ -286,7 +270,7 @@ public:
IComputationNode* partitionKey,
TType* partitionKeyType,
const TMatchRecognizeProcessorParameters& parameters,
- const TContainerCacheOnContext& cache,
+ const IRowsFormatter::TState& rowsFormatterState,
TComputationContext &ctx,
TType* rowType,
const TMutableObjectOverBoxedValue<TValuePackerBoxed>& rowPacker
@@ -296,8 +280,8 @@ public:
, PartitionKey(partitionKey)
, PartitionKeyPacker(true, partitionKeyType)
, Parameters(parameters)
+ , RowsFormatterState(rowsFormatterState)
, NfaTransitionGraph(TNfaTransitionGraphBuilder::Create(parameters.Pattern, parameters.VarNamesLookup))
- , Cache(cache)
, SerializerContext(ctx, rowType, rowPacker)
, Ctx(ctx)
{}
@@ -335,8 +319,10 @@ public:
std::make_unique<TStreamingMatchRecognize>(
std::move(key),
Parameters,
- NfaTransitionGraph,
- Cache));
+ RowsFormatterState,
+ NfaTransitionGraph
+ )
+ );
pair.first->second->Load(in);
}
@@ -402,8 +388,8 @@ private:
return Partitions.emplace_hint(it, TString(packedKey), std::make_unique<TStreamingMatchRecognize>(
std::move(partitionKey),
Parameters,
- NfaTransitionGraph,
- Cache
+ RowsFormatterState,
+ NfaTransitionGraph
));
}
}
@@ -418,8 +404,8 @@ private:
//TODO switch to tuple compare
TValuePackerGeneric<false> PartitionKeyPacker;
const TMatchRecognizeProcessorParameters& Parameters;
+ const IRowsFormatter::TState& RowsFormatterState;
const TNfaTransitionGraph::TPtr NfaTransitionGraph;
- const TContainerCacheOnContext& Cache;
TSerializerContext SerializerContext;
TComputationContext& Ctx;
};
@@ -428,20 +414,23 @@ template<class State>
class TMatchRecognizeWrapper : public TStatefulFlowComputationNode<TMatchRecognizeWrapper<State>, true> {
using TBaseComputation = TStatefulFlowComputationNode<TMatchRecognizeWrapper<State>, true>;
public:
- TMatchRecognizeWrapper(TComputationMutables &mutables, EValueRepresentation kind, IComputationNode *inputFlow,
- IComputationExternalNode *inputRowArg,
- IComputationNode *partitionKey,
- TType* partitionKeyType,
- const TMatchRecognizeProcessorParameters& parameters,
- TType* rowType
- )
- :TBaseComputation(mutables, inputFlow, kind, EValueRepresentation::Embedded)
+ TMatchRecognizeWrapper(
+ TComputationMutables& mutables,
+ EValueRepresentation kind,
+ IComputationNode *inputFlow,
+ IComputationExternalNode *inputRowArg,
+ IComputationNode *partitionKey,
+ TType* partitionKeyType,
+ TMatchRecognizeProcessorParameters&& parameters,
+ IRowsFormatter::TState&& rowsFormatterState,
+ TType* rowType)
+ : TBaseComputation(mutables, inputFlow, kind, EValueRepresentation::Embedded)
, InputFlow(inputFlow)
, InputRowArg(inputRowArg)
, PartitionKey(partitionKey)
, PartitionKeyType(partitionKeyType)
- , Parameters(parameters)
- , Cache(mutables)
+ , Parameters(std::move(parameters))
+ , RowsFormatterState(std::move(rowsFormatterState))
, RowType(rowType)
, RowPacker(mutables)
{}
@@ -453,7 +442,7 @@ public:
PartitionKey,
PartitionKeyType,
Parameters,
- Cache,
+ RowsFormatterState,
ctx,
RowType,
RowPacker
@@ -468,7 +457,7 @@ public:
PartitionKey,
PartitionKeyType,
Parameters,
- Cache,
+ RowsFormatterState,
ctx,
RowType,
RowPacker
@@ -503,7 +492,7 @@ private:
Own(flow, Parameters.CurrentRowIndexArg);
Own(flow, Parameters.MeasureInputDataArg);
DependsOn(flow, PartitionKey);
- for (auto& m: Parameters.Measures) {
+ for (auto& m: RowsFormatterState.Measures) {
DependsOn(flow, m);
}
for (auto& d: Parameters.Defines) {
@@ -516,32 +505,31 @@ private:
IComputationExternalNode* const InputRowArg;
IComputationNode* const PartitionKey;
TType* const PartitionKeyType;
- const TMatchRecognizeProcessorParameters Parameters;
- const TContainerCacheOnContext Cache;
+ TMatchRecognizeProcessorParameters Parameters;
+ IRowsFormatter::TState RowsFormatterState;
TType* const RowType;
TMutableObjectOverBoxedValue<TValuePackerBoxed> RowPacker;
};
TOutputColumnOrder GetOutputColumnOrder(TRuntimeNode partitionKyeColumnsIndexes, TRuntimeNode measureColumnsIndexes) {
- using tempMapValue = std::pair<EOutputColumnSource, size_t>;
- std::unordered_map<size_t, tempMapValue, std::hash<size_t>, std::equal_to<size_t>, TMKQLAllocator<std::pair<const size_t, tempMapValue>, EMemorySubPool::Temporary>> temp;
+ std::unordered_map<size_t, TOutputColumnEntry, std::hash<size_t>, std::equal_to<size_t>, TMKQLAllocator<std::pair<const size_t, TOutputColumnEntry>, EMemorySubPool::Temporary>> temp;
{
auto list = AS_VALUE(TListLiteral, partitionKyeColumnsIndexes);
for (ui32 i = 0; i != list->GetItemsCount(); ++i) {
auto index = AS_VALUE(TDataLiteral, list->GetItems()[i])->AsValue().Get<ui32>();
- temp[index] = std::make_pair(EOutputColumnSource::PartitionKey, i);
+ temp[index] = {i, EOutputColumnSource::PartitionKey};
}
}
{
auto list = AS_VALUE(TListLiteral, measureColumnsIndexes);
for (ui32 i = 0; i != list->GetItemsCount(); ++i) {
auto index = AS_VALUE(TDataLiteral, list->GetItems()[i])->AsValue().Get<ui32>();
- temp[index] = std::make_pair(EOutputColumnSource::Measure, i);
+ temp[index] = {i, EOutputColumnSource::Measure};
}
}
if (temp.empty())
return {};
- auto outputSize = max_element(temp.cbegin(), temp.cend())->first + 1;
+ auto outputSize = std::ranges::max_element(temp, {}, &std::pair<const size_t, TOutputColumnEntry>::first)->first + 1;
TOutputColumnOrder result(outputSize);
for (const auto& [i, v]: temp) {
result[i] = v;
@@ -576,7 +564,6 @@ TRowPattern ConvertPattern(const TRuntimeNode& pattern) {
}
TMeasureInputColumnOrder GetMeasureColumnOrder(const TListLiteral& specialColumnIndexes, ui32 inputRowColumnCount) {
- using NYql::NMatchRecognize::EMeasureInputDataSpecialColumns;
//Use Last enum value to denote that a column comes from the input table
TMeasureInputColumnOrder result(inputRowColumnCount + specialColumnIndexes.GetItemsCount(), std::make_pair(EMeasureInputDataSpecialColumns::Last, 0));
if (specialColumnIndexes.GetItemsCount() != 0) {
@@ -621,7 +608,6 @@ std::pair<TUnboxedValueVector, THashMap<TString, size_t>> ConvertListOfStrings(c
} //namespace NMatchRecognize
-
IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
using namespace NMatchRecognize;
size_t inputIndex = 0;
@@ -641,9 +627,9 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
const auto& pattern = callable.GetInput(inputIndex++);
const auto& currentRowIndexArg = callable.GetInput(inputIndex++);
const auto& inputDataArg = callable.GetInput(inputIndex++);
- const auto& varNames = callable.GetInput(inputIndex++);
+ const auto& defineNames = callable.GetInput(inputIndex++);
TRuntimeNode::TList defines;
- for (size_t i = 0; i != AS_VALUE(TListLiteral, varNames)->GetItemsCount(); ++i) {
+ for (size_t i = 0; i != AS_VALUE(TListLiteral, defineNames)->GetItemsCount(); ++i) {
defines.push_back(callable.GetInput(inputIndex++));
}
const auto& streamingMode = callable.GetInput(inputIndex++);
@@ -652,47 +638,58 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
skipTo.To = static_cast<EAfterMatchSkipTo>(AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().Get<i32>());
skipTo.Var = AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().AsStringRef();
}
+ NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch = NYql::NMatchRecognize::ERowsPerMatch::OneRow;
+ TOutputColumnOrder outputColumnOrder;
+ if (inputIndex + 2 <= callable.GetInputsCount()) {
+ rowsPerMatch = static_cast<ERowsPerMatch>(AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().Get<i32>());
+ outputColumnOrder = IRowsFormatter::GetOutputColumnOrder(callable.GetInput(inputIndex++));
+ } else {
+ outputColumnOrder = GetOutputColumnOrder(partitionColumnIndexes, measureColumnIndexes);
+ }
MKQL_ENSURE(callable.GetInputsCount() == inputIndex, "Wrong input count");
- const auto& [vars, varsLookup] = ConvertListOfStrings(varNames);
+ const auto& [varNames, varNamesLookup] = ConvertListOfStrings(defineNames);
auto* rowType = AS_TYPE(TStructType, AS_TYPE(TFlowType, inputFlow.GetStaticType())->GetItemType());
- const auto parameters = TMatchRecognizeProcessorParameters {
- static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputDataArg.GetNode()))
- , ConvertPattern(pattern)
- , vars
- , varsLookup
- , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *matchedVarsArg.GetNode()))
- , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *currentRowIndexArg.GetNode()))
- , ConvertVectorOfCallables(defines, ctx)
- , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *measureInputDataArg.GetNode()))
- , GetMeasureColumnOrder(
+ auto parameters = TMatchRecognizeProcessorParameters {
+ static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputDataArg.GetNode())),
+ ConvertPattern(pattern),
+ varNames,
+ varNamesLookup,
+ static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *matchedVarsArg.GetNode())),
+ static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *currentRowIndexArg.GetNode())),
+ ConvertVectorOfCallables(defines, ctx),
+ static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *measureInputDataArg.GetNode())),
+ GetMeasureColumnOrder(
*AS_VALUE(TListLiteral, measureSpecialColumnIndexes),
AS_VALUE(TDataLiteral, inputRowColumnCount)->AsValue().Get<ui32>()
- )
- , ConvertVectorOfCallables(measures, ctx)
- , GetOutputColumnOrder(partitionColumnIndexes, measureColumnIndexes)
- , skipTo
+ ),
+ skipTo
};
+ IRowsFormatter::TState rowsFormatterState(ctx, outputColumnOrder, ConvertVectorOfCallables(measures, ctx), rowsPerMatch);
if (AS_VALUE(TDataLiteral, streamingMode)->AsValue().Get<bool>()) {
- return new TMatchRecognizeWrapper<TStateForInterleavedPartitions>(ctx.Mutables
- , GetValueRepresentation(inputFlow.GetStaticType())
- , LocateNode(ctx.NodeLocator, *inputFlow.GetNode())
- , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode()))
- , LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
- , partitionKeySelector.GetStaticType()
- , std::move(parameters)
- , rowType
+ return new TMatchRecognizeWrapper<TStateForInterleavedPartitions>(
+ ctx.Mutables,
+ GetValueRepresentation(inputFlow.GetStaticType()),
+ LocateNode(ctx.NodeLocator, *inputFlow.GetNode()),
+ static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode())),
+ LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode()),
+ partitionKeySelector.GetStaticType(),
+ std::move(parameters),
+ std::move(rowsFormatterState),
+ rowType
);
} else {
- return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions>(ctx.Mutables
- , GetValueRepresentation(inputFlow.GetStaticType())
- , LocateNode(ctx.NodeLocator, *inputFlow.GetNode())
- , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode()))
- , LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
- , partitionKeySelector.GetStaticType()
- , std::move(parameters)
- , rowType
+ return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions>(
+ ctx.Mutables,
+ GetValueRepresentation(inputFlow.GetStaticType()),
+ LocateNode(ctx.NodeLocator, *inputFlow.GetNode()),
+ static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode())),
+ LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode()),
+ partitionKeySelector.GetStaticType(),
+ std::move(parameters),
+ std::move(rowsFormatterState),
+ rowType
);
}
}
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h
index 3c7771a41e..cf9bfb46a9 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h
+++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h
@@ -18,26 +18,29 @@ public:
class TRange {
public:
TRange()
- : FromIndex(-1)
- , ToIndex(-1)
+ : FromIndex(Max())
+ , ToIndex(Max())
+ , NfaIndex_(Max())
{
}
explicit TRange(ui64 index)
- : FromIndex(index)
- , ToIndex(index)
+ : FromIndex(index)
+ , ToIndex(index)
+ , NfaIndex_(Max())
{
}
TRange(ui64 from, ui64 to)
- : FromIndex(from)
- , ToIndex(to)
+ : FromIndex(from)
+ , ToIndex(to)
+ , NfaIndex_(Max())
{
MKQL_ENSURE(FromIndex <= ToIndex, "Internal logic error");
}
bool IsValid() const {
- return true;
+ return FromIndex != Max<size_t>() && ToIndex != Max<size_t>();
}
size_t From() const {
@@ -50,6 +53,11 @@ public:
return ToIndex;
}
+ [[nodiscard]] size_t NfaIndex() const {
+ MKQL_ENSURE(IsValid(), "Internal logic error");
+ return NfaIndex_;
+ }
+
size_t Size() const {
MKQL_ENSURE(IsValid(), "Internal logic error");
return ToIndex - FromIndex + 1;
@@ -61,8 +69,9 @@ public:
}
private:
- ui64 FromIndex;
- ui64 ToIndex;
+ size_t FromIndex;
+ size_t ToIndex;
+ size_t NfaIndex_;
};
TRange Append(NUdf::TUnboxedValue&& value) {
@@ -156,14 +165,12 @@ class TSparseList {
private:
//TODO consider to replace hash table with contiguous chunks
- using TAllocator = TMKQLAllocator<std::pair<const size_t, TItem>, EMemorySubPool::Temporary>;
-
using TStorage = std::unordered_map<
size_t,
TItem,
std::hash<size_t>,
std::equal_to<size_t>,
- TAllocator>;
+ TMKQLAllocator<std::pair<const size_t, TItem>, EMemorySubPool::Temporary>>;
TStorage Storage;
};
@@ -178,8 +185,9 @@ public:
public:
TRange()
: Container()
- , FromIndex(-1)
- , ToIndex(-1)
+ , FromIndex(Max())
+ , ToIndex(Max())
+ , NfaIndex_(Max())
{
}
@@ -187,6 +195,7 @@ public:
: Container(other.Container)
, FromIndex(other.FromIndex)
, ToIndex(other.ToIndex)
+ , NfaIndex_(other.NfaIndex_)
{
LockRange(FromIndex, ToIndex);
}
@@ -195,6 +204,7 @@ public:
: Container(other.Container)
, FromIndex(other.FromIndex)
, ToIndex(other.ToIndex)
+ , NfaIndex_(other.NfaIndex_)
{
other.Reset();
}
@@ -212,6 +222,7 @@ public:
Container = other.Container;
FromIndex = other.FromIndex;
ToIndex = other.ToIndex;
+ NfaIndex_ = other.NfaIndex_;
LockRange(FromIndex, ToIndex);
return *this;
}
@@ -224,20 +235,21 @@ public:
Container = other.Container;
FromIndex = other.FromIndex;
ToIndex = other.ToIndex;
+ NfaIndex_ = other.NfaIndex_;
other.Reset();
return *this;
}
friend inline bool operator==(const TRange& lhs, const TRange& rhs) {
- return std::tie(lhs.FromIndex, lhs.ToIndex) == std::tie(rhs.FromIndex, rhs.ToIndex);
+ return std::tie(lhs.FromIndex, lhs.ToIndex, lhs.NfaIndex_) == std::tie(rhs.FromIndex, rhs.ToIndex, rhs.NfaIndex_);
}
friend inline bool operator<(const TRange& lhs, const TRange& rhs) {
- return std::tie(lhs.FromIndex, lhs.ToIndex) < std::tie(rhs.FromIndex, rhs.ToIndex);
+ return std::tie(lhs.FromIndex, lhs.ToIndex, lhs.NfaIndex_) < std::tie(rhs.FromIndex, rhs.ToIndex, rhs.NfaIndex_);
}
bool IsValid() const {
- return static_cast<bool>(Container);
+ return static_cast<bool>(Container) && FromIndex != Max<size_t>() && ToIndex != Max<size_t>();
}
size_t From() const {
@@ -250,6 +262,15 @@ public:
return ToIndex;
}
+ [[nodiscard]] size_t NfaIndex() const {
+ MKQL_ENSURE(IsValid(), "Internal logic error");
+ return NfaIndex_;
+ }
+
+ void NfaIndex(size_t index) {
+ NfaIndex_ = index;
+ }
+
size_t Size() const {
MKQL_ENSURE(IsValid(), "Internal logic error");
return ToIndex - FromIndex + 1;
@@ -264,16 +285,17 @@ public:
void Release() {
UnlockRange(FromIndex, ToIndex);
Container.Reset();
- FromIndex = -1;
- ToIndex = -1;
+ FromIndex = Max();
+ ToIndex = Max();
+ NfaIndex_ = Max();
}
void Save(TMrOutputSerializer& serializer) const {
- serializer(Container, FromIndex, ToIndex);
+ serializer(Container, FromIndex, ToIndex, NfaIndex_);
}
void Load(TMrInputSerializer& serializer) {
- serializer(Container, FromIndex, ToIndex);
+ serializer(Container, FromIndex, ToIndex, NfaIndex_);
}
private:
@@ -281,6 +303,7 @@ public:
: Container(container)
, FromIndex(index)
, ToIndex(index)
+ , NfaIndex_(Max())
{}
void LockRange(size_t from, size_t to) {
@@ -297,13 +320,15 @@ public:
void Reset() {
Container.Reset();
- FromIndex = -1;
- ToIndex = -1;
+ FromIndex = Max();
+ ToIndex = Max();
+ NfaIndex_ = Max();
}
TContainerPtr Container;
size_t FromIndex;
size_t ToIndex;
+ size_t NfaIndex_;
};
public:
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h
index 8c2576798a..de0ccc352f 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h
+++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h
@@ -15,7 +15,7 @@ void Extend(TMatchedVar<R>& var, const R& r) {
var.emplace_back(r);
} else {
MKQL_ENSURE(r.From() > var.back().To(), "Internal logic error");
- if (var.back().To() + 1 == r.From()) {
+ if (var.back().To() + 1 == r.From() && var.back().NfaIndex() == r.NfaIndex()) {
var.back().Extend();
} else {
var.emplace_back(r);
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h
index 2b194212f4..c868379d40 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h
+++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h
@@ -20,9 +20,10 @@ struct TEpsilonTransitions {
friend constexpr bool operator==(const TEpsilonTransitions&, const TEpsilonTransitions&) = default;
};
struct TMatchedVarTransition {
+ size_t To;
ui32 VarIndex;
bool SaveState;
- size_t To;
+ bool ExcludeFromOutput;
friend constexpr bool operator==(const TMatchedVarTransition&, const TMatchedVarTransition&) = default;
};
struct TQuantityEnterTransition {
@@ -116,7 +117,7 @@ struct TNfaTransitionGraph {
serializer(tr.To);
},
[&](const TMatchedVarTransition& tr) {
- serializer(tr.VarIndex, tr.SaveState, tr.To);
+ serializer(tr.To, tr.VarIndex, tr.SaveState, tr.ExcludeFromOutput);
},
[&](const TQuantityEnterTransition& tr) {
serializer(tr.To);
@@ -141,7 +142,7 @@ struct TNfaTransitionGraph {
serializer(tr.To);
},
[&](TMatchedVarTransition& tr) {
- serializer(tr.VarIndex, tr.SaveState, tr.To);
+ serializer(tr.To, tr.VarIndex, tr.SaveState, tr.ExcludeFromOutput);
},
[&](TQuantityEnterTransition& tr) {
serializer(tr.To);
@@ -297,7 +298,7 @@ private:
auto input = AddNode();
auto output = AddNode();
auto item = factor.Primary.index() == 0 ?
- BuildVar(varNameToIndex.at(std::get<0>(factor.Primary)), !factor.Unused) :
+ BuildVar(varNameToIndex.at(std::get<0>(factor.Primary)), !factor.Unused, !factor.Output) :
BuildTerms(std::get<1>(factor.Primary), varNameToIndex);
if (1 == factor.QuantityMin && 1 == factor.QuantityMax) { //simple linear case
Graph->Transitions[input] = TEpsilonTransitions{{item.Input}};
@@ -319,15 +320,16 @@ private:
}
return {input, output};
}
- TNfaItem BuildVar(ui32 varIndex, bool isUsed) {
+ TNfaItem BuildVar(ui32 varIndex, bool isUsed, bool excludeFromOutput) {
auto input = AddNode();
auto matchVar = AddNode();
auto output = AddNode();
Graph->Transitions[input] = TEpsilonTransitions({matchVar});
Graph->Transitions[matchVar] = TMatchedVarTransition{
+ output,
varIndex,
isUsed,
- output,
+ excludeFromOutput,
};
return {input, output};
}
@@ -441,6 +443,7 @@ public:
if (matchedVarTransition->SaveState) {
auto vars = state.Match.Vars; //TODO get rid of this copy
auto& matchedVar = vars[varIndex];
+ currentRowLock.NfaIndex(state.Index);
Extend(matchedVar, currentRowLock);
newStates.emplace(matchedVarTransition->To, TMatch{state.Match.BeginIndex, currentRowLock.To(), std::move(vars)}, state.Quantifiers);
} else {
@@ -575,6 +578,10 @@ public:
}
}
+ const TNfaTransitionGraph& GetTransitionGraph() const {
+ return *TransitionGraph;
+ }
+
private:
//TODO (zverevgeny): Consider to change to std::vector for the sake of perf
using TStateSet = std::set<TState, std::less<TState>, TMKQLAllocator<TState>>;
@@ -593,14 +600,14 @@ private:
[&](const TEpsilonTransitions& epsilonTransitions) {
deletedStates.insert(state);
for (const auto& i : epsilonTransitions.To) {
- newStates.emplace(i, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, state.Quantifiers);
+ newStates.emplace(i, state.Match, state.Quantifiers);
}
},
[&](const TQuantityEnterTransition& quantityEnterTransition) {
deletedStates.insert(state);
auto quantifiers = state.Quantifiers; //TODO get rid of this copy
quantifiers.push_back(0);
- newStates.emplace(quantityEnterTransition.To, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(quantifiers));
+ newStates.emplace(quantityEnterTransition.To, state.Match, std::move(quantifiers));
},
[&](const TQuantityExitTransition& quantityExitTransition) {
deletedStates.insert(state);
@@ -608,12 +615,12 @@ private:
if (state.Quantifiers.back() + 1 < quantityMax) {
auto q = state.Quantifiers;
q.back()++;
- newStates.emplace(toFindMore, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(q));
+ newStates.emplace(toFindMore, state.Match, std::move(q));
}
if (quantityMin <= state.Quantifiers.back() + 1 && state.Quantifiers.back() + 1 <= quantityMax) {
auto q = state.Quantifiers;
q.pop_back();
- newStates.emplace(toMatched, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(q));
+ newStates.emplace(toMatched, state.Match, std::move(q));
}
},
}, TransitionGraph->Transitions[state.Index]);
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.cpp b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.cpp
new file mode 100644
index 0000000000..6d70458b3b
--- /dev/null
+++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.cpp
@@ -0,0 +1,144 @@
+#include "mkql_match_recognize_rows_formatter.h"
+
+#include <yql/essentials/minikql/mkql_node.h>
+#include <yql/essentials/minikql/mkql_node_cast.h>
+
+namespace NKikimr::NMiniKQL::NMatchRecognize {
+
+namespace {
+
+// Formatter selected by IRowsFormatter::Create for ERowsPerMatch::OneRow:
+// every match produces exactly one output row.
+class TOneRowFormatter final : public IRowsFormatter {
+public:
+    explicit TOneRowFormatter(const TState& state) : IRowsFormatter(state) {}
+
+    // Remembers the match, builds its single output row through DoGetMatchRow
+    // and then calls IRowsFormatter::Clear() so no per-match state is kept
+    // between matches.
+    NUdf::TUnboxedValue GetFirstMatchRow(
+        TComputationContext& ctx,
+        const TSparseList& rows,
+        const NUdf::TUnboxedValue& partitionKey,
+        const TNfaTransitionGraph& graph,
+        const TNfa::TMatch& match) override {
+        Match_ = match;
+        const auto result = DoGetMatchRow(ctx, rows, partitionKey, graph);
+        IRowsFormatter::Clear();
+        return result;
+    }
+
+    // A match never has further rows in this mode, so an empty value is
+    // always returned. The parameters are unused and therefore left unnamed.
+    NUdf::TUnboxedValue GetOtherMatchRow(
+        TComputationContext& /*ctx*/,
+        const TSparseList& /*rows*/,
+        const NUdf::TUnboxedValue& /*partitionKey*/,
+        const TNfaTransitionGraph& /*graph*/) override {
+        return NUdf::TUnboxedValue{};
+    }
+};
+
+// Formatter selected by IRowsFormatter::Create for ERowsPerMatch::AllRows:
+// a match is emitted row by row. GetFirstMatchRow starts a new match and
+// returns its first row; GetOtherMatchRow returns the following rows and an
+// empty value once nothing is left.
+class TAllRowsFormatter final : public IRowsFormatter {
+public:
+    explicit TAllRowsFormatter(const IRowsFormatter::TState& state) : IRowsFormatter(state) {}
+
+    // Starts emitting `match`: stores a copy of it, positions the cursor at
+    // the first matched row and indexes every matched range by its last row.
+    NUdf::TUnboxedValue GetFirstMatchRow(
+        TComputationContext& ctx,
+        const TSparseList& rows,
+        const NUdf::TUnboxedValue& partitionKey,
+        const TNfaTransitionGraph& graph,
+        const TNfa::TMatch& match) override {
+        // The lookup holds references into Match_.Vars. Drop whatever the
+        // previous match left behind before Match_ is overwritten: Clear()
+        // below is not reached for every match (e.g. a single-row match or a
+        // match whose trailing rows are excluded from output), and emplace()
+        // would keep a stale entry for a key that already exists.
+        ToIndexToMatchRangeLookup_.clear();
+        Match_ = match;
+        CurrentRowIndex_ = Match_.BeginIndex;
+        for (const auto& matchedVar : Match_.Vars) {
+            for (const auto& range : matchedVar) {
+                ToIndexToMatchRangeLookup_.emplace(range.To(), range);
+            }
+        }
+        return GetMatchRow(ctx, rows, partitionKey, graph);
+    }
+
+    // Returns the next row of the match started by GetFirstMatchRow, or an
+    // empty value when the cursor is already past Match_.EndIndex.
+    NUdf::TUnboxedValue GetOtherMatchRow(
+        TComputationContext& ctx,
+        const TSparseList& rows,
+        const NUdf::TUnboxedValue& partitionKey,
+        const TNfaTransitionGraph& graph) override {
+        return GetMatchRow(ctx, rows, partitionKey, graph);
+    }
+
+private:
+    NUdf::TUnboxedValue GetMatchRow(TComputationContext& ctx, const TSparseList& rows, const NUdf::TUnboxedValue& partitionKey, const TNfaTransitionGraph& graph) {
+        // Skip rows whose matched-var transition is flagged ExcludeFromOutput.
+        // The range covering a row is the first one whose To() is not less
+        // than the row index; the NFA node recorded in that range tells which
+        // transition matched it.
+        while (CurrentRowIndex_ <= Match_.EndIndex) {
+            if (auto iter = ToIndexToMatchRangeLookup_.lower_bound(CurrentRowIndex_);
+                iter == ToIndexToMatchRangeLookup_.end()) {
+                MKQL_ENSURE(false, "Internal logic error");
+            } else if (auto transition = std::get_if<TMatchedVarTransition>(&graph.Transitions.at(iter->second.NfaIndex()));
+                !transition) {
+                MKQL_ENSURE(false, "Internal logic error");
+            } else if (transition->ExcludeFromOutput) {
+                ++CurrentRowIndex_;
+            } else {
+                break;
+            }
+        }
+        if (CurrentRowIndex_ > Match_.EndIndex) {
+            return NUdf::TUnboxedValue{};
+        }
+        const auto result = DoGetMatchRow(ctx, rows, partitionKey, graph);
+        ++CurrentRowIndex_;
+        // NOTE(review): the loop above treats EndIndex as inclusive, yet the
+        // state is cleared as soon as the cursor *reaches* EndIndex, i.e.
+        // before the row at EndIndex has been emitted. Depending on what
+        // IRowsFormatter::Clear() resets (not visible here) this may drop the
+        // last row of a multi-row match - confirm whether `>` was intended.
+        if (CurrentRowIndex_ == Match_.EndIndex) {
+            Clear();
+        }
+        return result;
+    }
+
+    // Resets the base-class state and forgets the range index of the match.
+    void Clear() {
+        IRowsFormatter::Clear();
+        ToIndexToMatchRangeLookup_.clear();
+    }
+
+    // Last row index of a matched range -> that range. The mapped values are
+    // references into Match_.Vars, so the map must never outlive the match it
+    // was built from.
+    TMap<size_t, const TSparseList::TRange&> ToIndexToMatchRangeLookup_;
+};
+
+} // anonymous namespace
+
+// Binds the formatter to `state`. TState owns a std::unique_ptr and is
+// therefore not copyable, so State_ can only refer to the caller's object:
+// `state` has to outlive the formatter.
+IRowsFormatter::IRowsFormatter(const TState& state)
+    : State_(state)
+{}
+
+// Decodes the output column order from a list literal of struct literals.
+// For every struct, member 0 is read as a ui32 column index and member 1 as
+// an i32 that is cast to EOutputColumnSource. Entries are returned in list
+// order.
+TOutputColumnOrder IRowsFormatter::GetOutputColumnOrder(
+    TRuntimeNode outputColumnOrder) {
+    const auto list = AS_VALUE(TListLiteral, outputColumnOrder);
+    const TConstArrayRef<TRuntimeNode> items(list->GetItems(), list->GetItemsCount());
+    TOutputColumnOrder result;
+    // The final size is known up front: allocate once.
+    result.reserve(items.size());
+    for (const auto& item : items) {
+        const auto entry = AS_VALUE(TStructLiteral, item);
+        result.emplace_back(
+            AS_VALUE(TDataLiteral, entry->GetValue(0))->AsValue().Get<ui32>(),
+            static_cast<EOutputColumnSource>(AS_VALUE(TDataLiteral, entry->GetValue(1))->AsValue().Get<i32>())
+        );
+    }
+    return result;
+}
+
+// Builds one output row: allocates an array with one slot per entry of
+// State_.OutputColumnOrder and fills the slots in that order. Each entry
+// names where its value comes from:
+//   PartitionKey - element `Index` of the partition key,
+//   Measure      - value of the measure node `Index`,
+//   Other        - element `Index` of the input row at CurrentRowIndex_.
+// `graph` is not used here.
+NUdf::TUnboxedValue IRowsFormatter::DoGetMatchRow(TComputationContext& ctx, const TSparseList& rows, const NUdf::TUnboxedValue& partitionKey, const TNfaTransitionGraph& graph) {
+    NUdf::TUnboxedValue* cursor = nullptr;
+    const auto row = State_.Cache->NewArray(ctx, State_.OutputColumnOrder.size(), cursor);
+    for (const auto& column : State_.OutputColumnOrder) {
+        const auto index = column.Index;
+        switch (column.SourceType) {
+            case EOutputColumnSource::PartitionKey: {
+                *cursor = partitionKey.GetElement(index);
+                ++cursor;
+                break;
+            }
+            case EOutputColumnSource::Measure: {
+                *cursor = State_.Measures[index]->GetValue(ctx);
+                ++cursor;
+                break;
+            }
+            case EOutputColumnSource::Other: {
+                *cursor = rows.Get(CurrentRowIndex_).GetElement(index);
+                ++cursor;
+                break;
+            }
+        }
+    }
+    return row;
+}
+
+// Factory: picks the formatter implementation for state.RowsPerMatch.
+// The created formatter is constructed from `state`, which the caller keeps.
+std::unique_ptr<IRowsFormatter> IRowsFormatter::Create(const IRowsFormatter::TState& state) {
+    switch (state.RowsPerMatch) {
+        case ERowsPerMatch::OneRow:
+            return std::make_unique<TOneRowFormatter>(state);
+        case ERowsPerMatch::AllRows:
+            return std::make_unique<TAllRowsFormatter>(state);
+    }
+    // RowsPerMatch is produced by casting an integer read from the program,
+    // so it may hold a value outside the enum. Fail explicitly instead of
+    // running off the end of a non-void function (undefined behaviour).
+    MKQL_ENSURE(false, "Unexpected rows per match mode");
+    // Keeps every path returning a value; not expected to be reached when the
+    // check above fails.
+    return nullptr;
+}
+
+} //namespace NKikimr::NMiniKQL::NMatchRecognize
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.h
new file mode 100644
index 0000000000..39750bf4ea
--- /dev/null
+++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.h
@@ -0,0 +1,72 @@
+#pragma once
+
+#include "mkql_match_recognize_nfa.h"
+
+#include <yql/essentials/core/sql_types/match_recognize.h>
+#include <yql/essentials/minikql/computation/mkql_computation_node.h>
+#include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
+#include <yql/essentials/minikql/mkql_alloc.h>
+#include <yql/essentials/public/udf/udf_value.h>
+
+namespace NKikimr::NMiniKQL::NMatchRecognize {
+
+struct TOutputColumnEntry { // Describes where one output column's value comes from.
+ size_t Index; // Position inside the source selected by SourceType: partition key element, measure number, or input row member.
+ NYql::NMatchRecognize::EOutputColumnSource SourceType; // Which source Index refers to (PartitionKey, Measure or Other).
+};
+using TOutputColumnOrder = std::vector<TOutputColumnEntry, TMKQLAllocator<TOutputColumnEntry>>; // One entry per output column, in output order.
+
+class IRowsFormatter { // Base class for objects that build MATCH_RECOGNIZE output rows; Create() picks the implementation from TState::RowsPerMatch.
+public:
+ struct TState { // Configuration read by a formatter through its State_ reference.
+ std::unique_ptr<TContainerCacheOnContext> Cache; // Used by DoGetMatchRow to allocate the output row array.
+ TOutputColumnOrder OutputColumnOrder; // One entry per output column, in output order.
+ TComputationNodePtrVector Measures; // Measure nodes; DoGetMatchRow evaluates Measures[entry.Index] for Measure columns.
+ NYql::NMatchRecognize::ERowsPerMatch RowsPerMatch; // Selects the concrete formatter in Create().
+
+ TState(
+ const TComputationNodeFactoryContext& ctx, // Only ctx.Mutables is used, to construct the container cache.
+ TOutputColumnOrder outputColumnOrder, // Taken by value and moved into the member.
+ TComputationNodePtrVector measures, // Taken by value and moved into the member.
+ NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch)
+ : Cache(std::make_unique<TContainerCacheOnContext>(ctx.Mutables))
+ , OutputColumnOrder(std::move(outputColumnOrder))
+ , Measures(std::move(measures))
+ , RowsPerMatch(rowsPerMatch)
+ {}
+ };
+
+ explicit IRowsFormatter(const TState& state); // Defined out of line; presumably binds State_ to `state`, so `state` must outlive the formatter - confirm against the definition.
+ virtual ~IRowsFormatter() = default;
+
+ virtual NUdf::TUnboxedValue GetFirstMatchRow( // Implemented by subclasses; judging by the name and the `match` argument, presumably yields the first output row of a new match - confirm in the implementations.
+ TComputationContext& ctx,
+ const TSparseList& rows,
+ const NUdf::TUnboxedValue& partitionKey,
+ const TNfaTransitionGraph& graph,
+ const TNfa::TMatch& match) = 0;
+
+ virtual NUdf::TUnboxedValue GetOtherMatchRow( // Implemented by subclasses; presumably yields subsequent rows of the match passed to GetFirstMatchRow - confirm in the implementations.
+ TComputationContext& ctx,
+ const TSparseList& rows,
+ const NUdf::TUnboxedValue& partitionKey,
+ const TNfaTransitionGraph& graph) = 0;
+
+ static TOutputColumnOrder GetOutputColumnOrder(TRuntimeNode outputColumnOrder); // Decodes a literal list of {Index, SourceType} structs into a TOutputColumnOrder.
+
+ static std::unique_ptr<IRowsFormatter> Create(const TState& state); // Factory: returns the formatter matching state.RowsPerMatch.
+
+protected:
+ NUdf::TUnboxedValue DoGetMatchRow(TComputationContext& ctx, const TSparseList& rows, const NUdf::TUnboxedValue& partitionKey, const TNfaTransitionGraph& graph); // Builds one output row, filling each column from the partition key, a measure, or the input row at CurrentRowIndex_.
+
+ inline void Clear() { // Resets the per-match members to their initial values.
+ Match_ = {};
+ CurrentRowIndex_ = Max();
+ }
+
+ const TState& State_; // Non-owning reference to the configuration.
+ TNfa::TMatch Match_ {}; // Value-initialized here and reset by Clear(); presumably the match being formatted - confirm in subclasses.
+ size_t CurrentRowIndex_ = Max(); // Index into `rows` read by DoGetMatchRow for Other columns; Max() initially and after Clear().
+};
+
+} // namespace NKikimr::NMiniKQL::NMatchRecognize
diff --git a/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp b/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp
index a9fad1b6ef..513a72df5e 100644
--- a/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp
+++ b/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp
@@ -64,57 +64,46 @@ namespace {
const TTestInputData& input) {
TProgramBuilder& pgmBuilder = *setup.PgmBuilder;
- auto structType = pgmBuilder.NewStructType({
- {"time", pgmBuilder.NewDataType(NUdf::TDataType<i64>::Id)},
- {"key", pgmBuilder.NewDataType(NUdf::TDataType<char*>::Id)},
- {"sum", pgmBuilder.NewDataType(NUdf::TDataType<ui32>::Id)},
- {"part", pgmBuilder.NewDataType(NUdf::TDataType<char*>::Id)}});
+ const auto structType = pgmBuilder.NewStructType({
+ {"time", pgmBuilder.NewDataType(NUdf::EDataSlot::Int64)},
+ {"key", pgmBuilder.NewDataType(NUdf::EDataSlot::String)},
+ {"sum", pgmBuilder.NewDataType(NUdf::EDataSlot::Uint32)},
+ {"part", pgmBuilder.NewDataType(NUdf::EDataSlot::String)}
+ });
TVector<TRuntimeNode> items;
- for (size_t i = 0; i < input.size(); ++i)
- {
- auto time = pgmBuilder.NewDataLiteral<i64>(std::get<0>(input[i]));
- auto key = pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(NUdf::TStringRef(std::get<1>(input[i])));
- auto sum = pgmBuilder.NewDataLiteral<ui32>(std::get<2>(input[i]));
- auto part = pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(NUdf::TStringRef(std::get<3>(input[i])));
-
- auto item = pgmBuilder.NewStruct(structType,
- {{"time", time}, {"key", key}, {"sum", sum}, {"part", part}});
- items.push_back(std::move(item));
+ for (size_t i = 0; i < input.size(); ++i) {
+ const auto& [time, key, sum, part] = input[i];
+ items.push_back(pgmBuilder.NewStruct({
+ {"time", pgmBuilder.NewDataLiteral(time)},
+ {"key", pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(key)},
+ {"sum", pgmBuilder.NewDataLiteral(sum)},
+ {"part", pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(part)},
+ }));
}
const auto list = pgmBuilder.NewList(structType, std::move(items));
auto inputFlow = pgmBuilder.ToFlow(list);
-
- TVector<TStringBuf> partitionColumns;
- TVector<std::pair<TStringBuf, TProgramBuilder::TBinaryLambda>> getMeasures = {{
- std::make_pair(
- TStringBuf("key"),
- [&](TRuntimeNode /*measureInputDataArg*/, TRuntimeNode /*matchedVarsArg*/) {
- return pgmBuilder.NewDataLiteral<ui32>(56);
- }
- )}};
- TVector<std::pair<TStringBuf, TProgramBuilder::TTernaryLambda>> getDefines = {{
- std::make_pair(
- TStringBuf("A"),
- [&](TRuntimeNode /*inputDataArg*/, TRuntimeNode /*matchedVarsArg*/, TRuntimeNode /*currentRowIndexArg*/) {
- return pgmBuilder.NewDataLiteral<bool>(true);
- }
- )}};
-
auto pgmReturn = pgmBuilder.MatchRecognizeCore(
inputFlow,
[&](TRuntimeNode item) {
- return pgmBuilder.Member(item, "part");
+ return pgmBuilder.NewTuple({pgmBuilder.Member(item, "part")});
},
- partitionColumns,
- getMeasures,
+ {},
+ {"key"sv},
+ {[&](TRuntimeNode /*measureInputDataArg*/, TRuntimeNode /*matchedVarsArg*/) {
+ return pgmBuilder.NewDataLiteral<ui32>(56);
+ }},
{
{NYql::NMatchRecognize::TRowPatternFactor{"A", 3, 3, false, false, false}}
},
- getDefines,
+ {"A"sv},
+ {[&](TRuntimeNode /*inputDataArg*/, TRuntimeNode /*matchedVarsArg*/, TRuntimeNode /*currentRowIndexArg*/) {
+ return pgmBuilder.NewDataLiteral<bool>(true);
+ }},
streamingMode,
- {NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""}
+ {NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""},
+ NYql::NMatchRecognize::ERowsPerMatch::OneRow
);
auto graph = setup.BuildGraph(pgmReturn);
diff --git a/yql/essentials/minikql/comp_nodes/ya.make.inc b/yql/essentials/minikql/comp_nodes/ya.make.inc
index 518f5f3c43..b2f0da8ac9 100644
--- a/yql/essentials/minikql/comp_nodes/ya.make.inc
+++ b/yql/essentials/minikql/comp_nodes/ya.make.inc
@@ -79,6 +79,7 @@ SET(ORIG_SOURCES
mkql_mapnext.cpp
mkql_map_join.cpp
mkql_match_recognize.cpp
+ mkql_match_recognize_rows_formatter.cpp
mkql_multihopping.cpp
mkql_multimap.cpp
mkql_next_value.cpp
diff --git a/yql/essentials/minikql/mkql_program_builder.cpp b/yql/essentials/minikql/mkql_program_builder.cpp
index 691a642814..0b9197024d 100644
--- a/yql/essentials/minikql/mkql_program_builder.cpp
+++ b/yql/essentials/minikql/mkql_program_builder.cpp
@@ -1,5 +1,4 @@
#include "mkql_program_builder.h"
-#include "mkql_opt_literal.h"
#include "mkql_node_visitor.h"
#include "mkql_node_cast.h"
#include "mkql_runtime_version.h"
@@ -11,6 +10,7 @@
#include "yql/essentials/core/sql_types/time_order_recover.h"
#include <yql/essentials/parser/pg_catalog/catalog.h>
+#include <util/generic/overloaded.h>
#include <util/string/cast.h>
#include <util/string/printf.h>
#include <array>
@@ -6035,15 +6035,19 @@ TRuntimeNode PatternToRuntimeNode(const TRowPattern& pattern, const TProgramBuil
TTupleLiteralBuilder termBuilder(env);
for (const auto& factor: term) {
TTupleLiteralBuilder factorBuilder(env);
- factorBuilder.Add(factor.Primary.index() == 0 ?
- programBuilder.NewDataLiteral<NUdf::EDataSlot::String>(std::get<0>(factor.Primary)) :
- PatternToRuntimeNode(std::get<1>(factor.Primary), programBuilder)
- );
- factorBuilder.Add(programBuilder.NewDataLiteral<ui64>(factor.QuantityMin));
- factorBuilder.Add(programBuilder.NewDataLiteral<ui64>(factor.QuantityMax));
- factorBuilder.Add(programBuilder.NewDataLiteral<bool>(factor.Greedy));
- factorBuilder.Add(programBuilder.NewDataLiteral<bool>(factor.Output));
- factorBuilder.Add(programBuilder.NewDataLiteral<bool>(factor.Unused));
+ factorBuilder.Add(std::visit(TOverloaded {
+ [&](const TString& s) {
+ return programBuilder.NewDataLiteral<NUdf::EDataSlot::String>(s);
+ },
+ [&](const TRowPattern& pattern) {
+ return PatternToRuntimeNode(pattern, programBuilder);
+ },
+ }, factor.Primary));
+ factorBuilder.Add(programBuilder.NewDataLiteral(factor.QuantityMin));
+ factorBuilder.Add(programBuilder.NewDataLiteral(factor.QuantityMax));
+ factorBuilder.Add(programBuilder.NewDataLiteral(factor.Greedy));
+ factorBuilder.Add(programBuilder.NewDataLiteral(factor.Output));
+ factorBuilder.Add(programBuilder.NewDataLiteral(factor.Unused));
termBuilder.Add({factorBuilder.Build(), true});
}
patternBuilder.Add({termBuilder.Build(), true});
@@ -6056,151 +6060,172 @@ TRuntimeNode PatternToRuntimeNode(const TRowPattern& pattern, const TProgramBuil
TRuntimeNode TProgramBuilder::MatchRecognizeCore(
TRuntimeNode inputStream,
const TUnaryLambda& getPartitionKeySelectorNode,
- const TArrayRef<TStringBuf>& partitionColumns,
- const TArrayRef<std::pair<TStringBuf, TBinaryLambda>>& getMeasures,
+ const TArrayRef<TStringBuf>& partitionColumnNames,
+ const TVector<TStringBuf>& measureColumnNames,
+ const TVector<TBinaryLambda>& getMeasures,
const NYql::NMatchRecognize::TRowPattern& pattern,
- const TArrayRef<std::pair<TStringBuf, TTernaryLambda>>& getDefines,
+ const TVector<TStringBuf>& defineVarNames,
+ const TVector<TTernaryLambda>& getDefines,
bool streamingMode,
- const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo
+ const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo,
+ NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch
) {
MKQL_ENSURE(RuntimeVersion >= 42, "MatchRecognize is not supported in runtime version " << RuntimeVersion);
const auto inputRowType = AS_TYPE(TStructType, AS_TYPE(TFlowType, inputStream.GetStaticType())->GetItemType());
const auto inputRowArg = Arg(inputRowType);
const auto partitionKeySelectorNode = getPartitionKeySelectorNode(inputRowArg);
+ const auto partitionColumnTypes = AS_TYPE(TTupleType, partitionKeySelectorNode.GetStaticType())->GetElements();
- TStructTypeBuilder indexRangeTypeBuilder(Env);
- indexRangeTypeBuilder.Add("From", TDataType::Create(NUdf::TDataType<ui64>::Id, Env));
- indexRangeTypeBuilder.Add("To", TDataType::Create(NUdf::TDataType<ui64>::Id, Env));
- const auto& rangeList = TListType::Create(indexRangeTypeBuilder.Build(), Env);
+ const auto rangeList = NewListType(NewStructType({
+ {"From", NewDataType(NUdf::EDataSlot::Uint64)},
+ {"To", NewDataType(NUdf::EDataSlot::Uint64)}
+ }));
TStructTypeBuilder matchedVarsTypeBuilder(Env);
for (const auto& var: GetPatternVars(pattern)) {
matchedVarsTypeBuilder.Add(var, rangeList);
}
- TRuntimeNode matchedVarsArg = Arg(matchedVarsTypeBuilder.Build());
+ const auto matchedVarsType = matchedVarsTypeBuilder.Build();
+ TRuntimeNode matchedVarsArg = Arg(matchedVarsType);
//---These vars may be empty in case of no measures
TRuntimeNode measureInputDataArg;
std::vector<TRuntimeNode> specialColumnIndexesInMeasureInputDataRow;
TVector<TRuntimeNode> measures;
- TVector<TType*> measureTypes;
//---
if (getMeasures.empty()) {
measureInputDataArg = Arg(Env.GetTypeOfVoidLazy());
} else {
- using NYql::NMatchRecognize::EMeasureInputDataSpecialColumns;
measures.reserve(getMeasures.size());
- measureTypes.reserve(getMeasures.size());
specialColumnIndexesInMeasureInputDataRow.resize(static_cast<size_t>(NYql::NMatchRecognize::EMeasureInputDataSpecialColumns::Last));
TStructTypeBuilder measureInputDataRowTypeBuilder(Env);
- for (ui32 i = 0; i != inputRowType->GetMembersCount(); ++i) {
+ for (ui32 i = 0; i < inputRowType->GetMembersCount(); ++i) {
measureInputDataRowTypeBuilder.Add(inputRowType->GetMemberName(i), inputRowType->GetMemberType(i));
}
measureInputDataRowTypeBuilder.Add(
MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::Classifier),
- TDataType::Create(NUdf::TDataType<NYql::NUdf::TUtf8>::Id, Env)
+ NewDataType(NUdf::EDataSlot::Utf8)
);
measureInputDataRowTypeBuilder.Add(
MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::MatchNumber),
- TDataType::Create(NUdf::TDataType<ui64>::Id, Env)
+ NewDataType(NUdf::EDataSlot::Uint64)
);
const auto measureInputDataRowType = measureInputDataRowTypeBuilder.Build();
- for (ui32 i = 0; i != measureInputDataRowType->GetMembersCount(); ++i) {
+ for (ui32 i = 0; i < measureInputDataRowType->GetMembersCount(); ++i) {
//assume a few, if grows, it's better to use a lookup table here
static_assert(static_cast<size_t>(EMeasureInputDataSpecialColumns::Last) < 5);
for (size_t j = 0; j != static_cast<size_t>(EMeasureInputDataSpecialColumns::Last); ++j) {
if (measureInputDataRowType->GetMemberName(i) ==
NYql::NMatchRecognize::MeasureInputDataSpecialColumnName(static_cast<EMeasureInputDataSpecialColumns>(j)))
- specialColumnIndexesInMeasureInputDataRow[j] = NewDataLiteral<ui32>(i);
+ specialColumnIndexesInMeasureInputDataRow[j] = NewDataLiteral(i);
}
}
- measureInputDataArg = Arg(TListType::Create(measureInputDataRowType, Env));
+ measureInputDataArg = Arg(NewListType(measureInputDataRowType));
for (size_t i = 0; i != getMeasures.size(); ++i) {
- measures.push_back(getMeasures[i].second(measureInputDataArg, matchedVarsArg));
- measureTypes.push_back(measures[i].GetStaticType());
+ measures.push_back(getMeasures[i](measureInputDataArg, matchedVarsArg));
}
}
TStructTypeBuilder outputRowTypeBuilder(Env);
THashMap<TStringBuf, size_t> partitionColumnLookup;
- for (size_t i = 0; i != partitionColumns.size(); ++i) {
- const auto& name = partitionColumns[i];
- partitionColumnLookup[name] = i;
- outputRowTypeBuilder.Add(
- name,
- AS_TYPE(TTupleType, partitionKeySelectorNode.GetStaticType())->GetElementType(i)
- );
- }
THashMap<TStringBuf, size_t> measureColumnLookup;
- for (size_t i = 0; i != measures.size(); ++i) {
- const auto& name = getMeasures[i].first;
- measureColumnLookup[name] = i;
- outputRowTypeBuilder.Add(
- name,
- measures[i].GetStaticType()
- );
+ THashMap<TStringBuf, size_t> otherColumnLookup;
+ for (size_t i = 0; i < measureColumnNames.size(); ++i) {
+ const auto name = measureColumnNames[i];
+ measureColumnLookup.emplace(name, i);
+ outputRowTypeBuilder.Add(name, measures[i].GetStaticType());
+ }
+ switch (rowsPerMatch) {
+ case NYql::NMatchRecognize::ERowsPerMatch::OneRow:
+ for (size_t i = 0; i < partitionColumnNames.size(); ++i) {
+ const auto name = partitionColumnNames[i];
+ partitionColumnLookup.emplace(name, i);
+ outputRowTypeBuilder.Add(name, partitionColumnTypes[i]);
+ }
+ break;
+ case NYql::NMatchRecognize::ERowsPerMatch::AllRows:
+ for (size_t i = 0; i < inputRowType->GetMembersCount(); ++i) {
+ const auto name = inputRowType->GetMemberName(i);
+ otherColumnLookup.emplace(name, i);
+ outputRowTypeBuilder.Add(name, inputRowType->GetMemberType(i));
+ }
+ break;
}
auto outputRowType = outputRowTypeBuilder.Build();
std::vector<TRuntimeNode> partitionColumnIndexes(partitionColumnLookup.size());
std::vector<TRuntimeNode> measureColumnIndexes(measureColumnLookup.size());
- for (ui32 i = 0; i != outputRowType->GetMembersCount(); ++i) {
- if (auto it = partitionColumnLookup.find(outputRowType->GetMemberName(i)); it != partitionColumnLookup.end()) {
- partitionColumnIndexes[it->second] = NewDataLiteral<ui32>(i);
- }
- else if (auto it = measureColumnLookup.find(outputRowType->GetMemberName(i)); it != measureColumnLookup.end()) {
- measureColumnIndexes[it->second] = NewDataLiteral<ui32>(i);
+ TVector<TRuntimeNode> outputColumnOrder(NDetail::TReserveTag{outputRowType->GetMembersCount()});
+ for (ui32 i = 0; i < outputRowType->GetMembersCount(); ++i) {
+ const auto name = outputRowType->GetMemberName(i);
+ if (auto iter = partitionColumnLookup.find(name);
+ iter != partitionColumnLookup.end()) {
+ partitionColumnIndexes[iter->second] = NewDataLiteral(i);
+ outputColumnOrder.push_back(NewStruct({
+ std::pair{"Index", NewDataLiteral(iter->second)},
+ std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::PartitionKey))},
+ }));
+ } else if (auto iter = measureColumnLookup.find(name);
+ iter != measureColumnLookup.end()) {
+ measureColumnIndexes[iter->second] = NewDataLiteral(i);
+ outputColumnOrder.push_back(NewStruct({
+ std::pair{"Index", NewDataLiteral(iter->second)},
+ std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::Measure))},
+ }));
+ } else if (auto iter = otherColumnLookup.find(name);
+ iter != otherColumnLookup.end()) {
+ outputColumnOrder.push_back(NewStruct({
+ std::pair{"Index", NewDataLiteral(iter->second)},
+ std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::Other))},
+ }));
}
}
- auto outputType = (TType*)TFlowType::Create(outputRowType, Env);
+ const auto outputType = NewFlowType(outputRowType);
- THashMap<TStringBuf , size_t> patternVarLookup;
- for (ui32 i = 0; i != AS_TYPE(TStructType, matchedVarsArg.GetStaticType())->GetMembersCount(); ++i){
- patternVarLookup[AS_TYPE(TStructType, matchedVarsArg.GetStaticType())->GetMemberName(i)] = i;
+ THashMap<TStringBuf, size_t> patternVarLookup;
+ for (ui32 i = 0; i < matchedVarsType->GetMembersCount(); ++i) {
+ patternVarLookup[matchedVarsType->GetMemberName(i)] = i;
}
- THashMap<TStringBuf , size_t> defineLookup;
- for (size_t i = 0; i != getDefines.size(); ++i) {
- defineLookup[getDefines[i].first] = i;
+ THashMap<TStringBuf, size_t> defineLookup;
+ for (size_t i = 0; i < defineVarNames.size(); ++i) {
+ const auto name = defineVarNames[i];
+ defineLookup[name] = i;
}
- std::vector<TRuntimeNode> defineNames(patternVarLookup.size());
- std::vector<TRuntimeNode> defineNodes(patternVarLookup.size());
- const auto& inputDataArg = Arg(TListType::Create(inputRowType, Env));
- const auto& currentRowIndexArg = Arg(TDataType::Create(NUdf::TDataType<ui64>::Id, Env));
+ TVector<TRuntimeNode> defineNames(defineVarNames.size());
+ TVector<TRuntimeNode> defineNodes(patternVarLookup.size());
+ const auto inputDataArg = Arg(NewListType(inputRowType));
+ const auto currentRowIndexArg = Arg(NewDataType(NUdf::EDataSlot::Uint64));
for (const auto& [v, i]: patternVarLookup) {
defineNames[i] = NewDataLiteral<NUdf::EDataSlot::String>(v);
- if (const auto it = defineLookup.find(v); it != defineLookup.end()) {
- defineNodes[i] = getDefines[it->second].second(inputDataArg, matchedVarsArg, currentRowIndexArg);
- }
- else { //no predicate for var
- if ("$" == v || "^" == v) {
- //DO nothing, //will be handled in a specific way
- }
- else { // a var without a predicate matches any row
- defineNodes[i] = NewDataLiteral<bool>(true);
- }
+ if (auto iter = defineLookup.find(v);
+ iter != defineLookup.end()) {
+ defineNodes[i] = getDefines[iter->second](inputDataArg, matchedVarsArg, currentRowIndexArg);
+ } else if ("$" == v || "^" == v) {
+ //DO nothing, //will be handled in a specific way
+ } else { // a var without a predicate matches any row
+ defineNodes[i] = NewDataLiteral(true);
}
}
TCallableBuilder callableBuilder(GetTypeEnvironment(), "MatchRecognizeCore", outputType);
- auto indexType = TDataType::Create(NUdf::TDataType<ui32>::Id, Env);
- auto indexListType = TListType::Create(indexType, Env);
+ const auto indexType = NewDataType(NUdf::EDataSlot::Uint32);
+ const auto outputColumnEntryType = NewStructType({
+ {"Index", NewDataType(NUdf::EDataSlot::Uint64)},
+ {"SourceType", NewDataType(NUdf::EDataSlot::Int32)},
+ });
callableBuilder.Add(inputStream);
callableBuilder.Add(inputRowArg);
callableBuilder.Add(partitionKeySelectorNode);
- callableBuilder.Add(TRuntimeNode(TListLiteral::Create(partitionColumnIndexes.data(), partitionColumnIndexes.size(), indexListType, Env), true));
+ callableBuilder.Add(NewList(indexType, partitionColumnIndexes));
callableBuilder.Add(measureInputDataArg);
- callableBuilder.Add(TRuntimeNode(TListLiteral::Create(
- specialColumnIndexesInMeasureInputDataRow.data(), specialColumnIndexesInMeasureInputDataRow.size(),
- indexListType, Env
- ),
- true));
- callableBuilder.Add(NewDataLiteral<ui32>(inputRowType->GetMembersCount()));
+ callableBuilder.Add(NewList(indexType, specialColumnIndexesInMeasureInputDataRow));
+ callableBuilder.Add(NewDataLiteral(inputRowType->GetMembersCount()));
callableBuilder.Add(matchedVarsArg);
- callableBuilder.Add(TRuntimeNode(TListLiteral::Create(measureColumnIndexes.data(), measureColumnIndexes.size(), indexListType, Env), true));
+ callableBuilder.Add(NewList(indexType, measureColumnIndexes));
for (const auto& m: measures) {
callableBuilder.Add(m);
}
@@ -6209,16 +6234,19 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore(
callableBuilder.Add(currentRowIndexArg);
callableBuilder.Add(inputDataArg);
- const auto stringType = NewDataType(NUdf::EDataSlot::String);
- callableBuilder.Add(TRuntimeNode(TListLiteral::Create(defineNames.begin(), defineNames.size(), TListType::Create(stringType, Env), Env), true));
+ callableBuilder.Add(NewList(NewDataType(NUdf::EDataSlot::String), defineNames));
for (const auto& d: defineNodes) {
callableBuilder.Add(d);
}
callableBuilder.Add(NewDataLiteral(streamingMode));
- if (RuntimeVersion >= 52U) {
+ if constexpr (RuntimeVersion >= 52U) {
callableBuilder.Add(NewDataLiteral(static_cast<i32>(skipTo.To)));
callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(skipTo.Var));
}
+ if constexpr (RuntimeVersion >= 53U) {
+ callableBuilder.Add(NewDataLiteral(static_cast<i32>(rowsPerMatch)));
+ callableBuilder.Add(NewList(outputColumnEntryType, outputColumnOrder));
+ }
return TRuntimeNode(callableBuilder.Build(), false);
}
diff --git a/yql/essentials/minikql/mkql_program_builder.h b/yql/essentials/minikql/mkql_program_builder.h
index 762d3a013e..9e6b0d97d0 100644
--- a/yql/essentials/minikql/mkql_program_builder.h
+++ b/yql/essentials/minikql/mkql_program_builder.h
@@ -712,12 +712,15 @@ public:
TRuntimeNode MatchRecognizeCore(
TRuntimeNode inputStream,
const TUnaryLambda& getPartitionKeySelectorNode,
- const TArrayRef<TStringBuf>& partitionColumns,
- const TArrayRef<std::pair<TStringBuf, TBinaryLambda>>& getMeasures,
+ const TArrayRef<TStringBuf>& partitionColumnNames,
+ const TVector<TStringBuf>& measureColumnNames,
+ const TVector<TBinaryLambda>& getMeasures,
const NYql::NMatchRecognize::TRowPattern& pattern,
- const TArrayRef<std::pair<TStringBuf, TTernaryLambda>>& getDefines,
+ const TVector<TStringBuf>& defineVarNames,
+ const TVector<TTernaryLambda>& getDefines,
bool streamingMode,
- const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo
+ const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo,
+ NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch
);
TRuntimeNode TimeOrderRecover(