diff options
author | vokayndzop <vokayndzop@yandex-team.com> | 2024-12-16 15:55:05 +0300 |
---|---|---|
committer | vokayndzop <vokayndzop@yandex-team.com> | 2024-12-16 16:34:36 +0300 |
commit | b1cde7dcb055fb6f3367e81fd0f57bd55b8bb93c (patch) | |
tree | 230bddb8bb4ce7d8290a16a4465ec98dbf513a5a | |
parent | 88e0ad5922cea1349ec1f8cbf133524cf865d696 (diff) | |
download | ydb-b1cde7dcb055fb6f3367e81fd0f57bd55b8bb93c.tar.gz |
MR: support ALL ROWS PER MATCH
commit_hash:9e2ba38d0d523bb870f6dc76717a3bec5d8ffadc
24 files changed, 800 insertions, 401 deletions
diff --git a/yql/essentials/core/sql_types/match_recognize.h b/yql/essentials/core/sql_types/match_recognize.h index 0c6105ad94..e142587890 100644 --- a/yql/essentials/core/sql_types/match_recognize.h +++ b/yql/essentials/core/sql_types/match_recognize.h @@ -23,6 +23,16 @@ struct TAfterMatchSkipTo { [[nodiscard]] bool operator==(const TAfterMatchSkipTo&) const noexcept = default; }; +enum class ERowsPerMatch { + OneRow, + AllRows +}; +enum class EOutputColumnSource { + PartitionKey, + Measure, + Other, +}; + constexpr size_t MaxPatternNesting = 20; //Limit recursion for patterns constexpr size_t MaxPermutedItems = 6; @@ -47,8 +57,8 @@ using TRowPatternPrimary = std::variant<TString, TRowPattern>; struct TRowPatternFactor { TRowPatternPrimary Primary; - uint64_t QuantityMin; - uint64_t QuantityMax; + ui64 QuantityMin; + ui64 QuantityMax; bool Greedy; bool Output; //include in output with ALL ROW PER MATCH bool Unused; // optimization flag; is true when the variable is not used in defines and measures diff --git a/yql/essentials/core/type_ann/type_ann_match_recognize.cpp b/yql/essentials/core/type_ann/type_ann_match_recognize.cpp index 58d24e7b1b..4a6d61192e 100644 --- a/yql/essentials/core/type_ann/type_ann_match_recognize.cpp +++ b/yql/essentials/core/type_ann/type_ann_match_recognize.cpp @@ -10,11 +10,11 @@ MatchRecognizeWrapper(const TExprNode::TPtr &input, TExprNode::TPtr &output, TCo if (!EnsureArgsCount(*input, 5, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } - const auto& source = input->ChildRef(0); + const auto source = input->Child(0); auto& partitionKeySelector = input->ChildRef(1); - const auto& partitionColumns = input->ChildRef(2); - const auto& sortTraits = input->ChildRef(3); - const auto& params = input->ChildRef(4); + const auto partitionColumns = input->Child(2); + const auto sortTraits = input->Child(3); + const auto params = input->Child(4); Y_UNUSED(source, sortTraits); auto status = ConvertToLambda(partitionKeySelector, ctx.Expr, 1, 1); if (status.Level != IGraphTransformer::TStatus::Ok) { @@ -31,11 +31,22 @@ MatchRecognizeWrapper(const TExprNode::TPtr &input, TExprNode::TPtr &output, TCo //merge measure columns, came from params, with partition columns to form output row type auto outputTableColumns = params->GetTypeAnn()->Cast<TStructExprType>()->GetItems(); - for (size_t i = 0; i != partitionColumns->ChildrenSize(); ++i) { - outputTableColumns.push_back(ctx.Expr.MakeType<TItemExprType>( - partitionColumns->ChildRef(i)->Content(), - partitionKeySelectorItemTypes[i] - )); + if (const auto rowsPerMatch = params->Child(1); + "RowsPerMatch_OneRow" == rowsPerMatch->Content()) { + for (size_t i = 0; i != partitionColumns->ChildrenSize(); ++i) { + outputTableColumns.push_back(ctx.Expr.MakeType<TItemExprType>( + partitionColumns->Child(i)->Content(), + partitionKeySelectorItemTypes[i] + )); + } + } else if ("RowsPerMatch_AllRows" == rowsPerMatch->Content()) { + const auto& inputTableColumns = GetSeqItemType(source->GetTypeAnn())->Cast<TStructExprType>()->GetItems(); + for (const auto& column : inputTableColumns) { + outputTableColumns.push_back(ctx.Expr.MakeType<TItemExprType>(column->GetName(), column->GetItemType())); + } + } else { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(rowsPerMatch->Pos()), "Unknown RowsPerMatch option")); + return IGraphTransformer::TStatus::Error; } const auto outputTableRowType = ctx.Expr.MakeType<TStructExprType>(outputTableColumns); input->SetTypeAnn(ctx.Expr.MakeType<TListExprType>(outputTableRowType)); @@ -48,7 +59,7 @@ MatchRecognizeParamsWrapper(const TExprNode::TPtr &input, TExprNode::TPtr &outpu if (!EnsureArgsCount(*input, 5, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } - const auto& measures = input->ChildRef(0); + const auto measures = input->Child(0); input->SetTypeAnn(measures->GetTypeAnn()); return IGraphTransformer::TStatus::Ok; } @@ -77,10 +88,10 @@ MatchRecognizeMeasuresWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& out if (!EnsureMinArgsCount(*input, 3, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } - const auto& inputRowType = input->ChildRef(0); - const auto& pattern = input->ChildRef(1); - const auto& names = input->ChildRef(2); - const size_t FirstLambdaIndex = 3; + const auto inputRowType = input->Child(0); + const auto pattern = input->Child(1); + const auto names = input->Child(2); + constexpr size_t FirstLambdaIndex = 3; if (!EnsureTupleOfAtoms(*names, ctx.Expr)) { return IGraphTransformer::TStatus::Error; @@ -119,7 +130,7 @@ MatchRecognizeMeasuresWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& out return IGraphTransformer::TStatus::Error; } if (auto type = lambda->GetTypeAnn()) { - items.push_back(ctx.Expr.MakeType<TItemExprType>(names->ChildRef(i)->Content(), type)); + items.push_back(ctx.Expr.MakeType<TItemExprType>(names->Child(i)->Content(), type)); } else { return IGraphTransformer::TStatus::Repeat; } @@ -143,9 +154,9 @@ MatchRecognizeDefinesWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& outp if (!EnsureMinArgsCount(*input, 3, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } - const auto& inputRowType = input->ChildRef(0); - const auto& pattern = input->ChildRef(1); - const auto& names = input->ChildRef(2); + const auto inputRowType = input->Child(0); + const auto pattern = input->Child(1); + const auto names = input->Child(2); const size_t FirstLambdaIndex = 3; if (!EnsureTupleOfAtoms(*names, ctx.Expr)) { @@ -176,7 +187,7 @@ MatchRecognizeDefinesWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& outp } if (auto type = lambda->GetTypeAnn()) { if (IsBoolLike(*type)) { - items.push_back(ctx.Expr.MakeType<TItemExprType>(names->ChildRef(i)->Content(), type)); + items.push_back(ctx.Expr.MakeType<TItemExprType>(names->Child(i)->Content(), type)); } else { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambda->Pos()), "DEFINE expression must be a predicate")); return IGraphTransformer::TStatus::Error; @@ -200,18 +211,18 @@ bool ValidateSettings(const TExprNode::TPtr& settings, TExprContext& ctx) { return false; } - const auto streamingMode = settings->ChildRef(0); + const auto streamingMode = settings->Child(0); if (!EnsureTupleOfAtoms(*streamingMode, ctx)) { return false; } if (!EnsureArgsCount(*streamingMode, 2, ctx)) { return false; } - if (streamingMode->ChildRef(0)->Content() != "Streaming") { + if (streamingMode->Child(0)->Content() != "Streaming") { ctx.AddError(TIssue(ctx.GetPosition(settings->Pos()), "Expected Streaming setting")); return false; } - const auto mode = streamingMode->ChildRef(1)->Content(); + const auto mode = streamingMode->Child(1)->Content(); if (mode != "0" and mode != "1") { ctx.AddError(TIssue(ctx.GetPosition(settings->Pos()), TStringBuilder() << "Expected 0 or 1, but got: " << mode)); return false; @@ -231,11 +242,11 @@ MatchRecognizeCoreWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, if (!EnsureArgsCount(*input, 5, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } - const auto& source = input->ChildRef(0); + const auto source = input->Child(0); auto& partitionKeySelector = input->ChildRef(1); - const auto& partitionColumns = input->ChildRef(2); - const auto& params = input->ChildRef(3); - const auto& settings = input->ChildRef(4); + const auto partitionColumns = input->Child(2); + const auto params = input->Child(3); + const auto settings = input->Child(4); if (not params->IsCallable("MatchRecognizeParams")) { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(params->Pos()), "Expected MatchRecognizeParams")); return IGraphTransformer::TStatus::Error; @@ -248,9 +259,9 @@ MatchRecognizeCoreWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, if (!EnsureFlowType(*source, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } - const auto& inputRowType = GetSeqItemType(source->GetTypeAnn()); - const auto& define = params->ChildRef(4); - if (not inputRowType->Equals(*define->ChildRef(0)->GetTypeAnn()->Cast<TTypeExprType>()->GetType())) { + const auto inputRowType = GetSeqItemType(source->GetTypeAnn()); + const auto define = params->Child(4); + if (not inputRowType->Equals(*define->Child(0)->GetTypeAnn()->Cast<TTypeExprType>()->GetType())) { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "Expected the same input row type as for DEFINE")); return IGraphTransformer::TStatus::Error; } @@ -279,11 +290,22 @@ MatchRecognizeCoreWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, } auto outputTableColumns = params->GetTypeAnn()->Cast<TStructExprType>()->GetItems(); - for (size_t i = 0; i != partitionColumns->ChildrenSize(); ++i) { - outputTableColumns.push_back(ctx.Expr.MakeType<TItemExprType>( - partitionColumns->ChildRef(i)->Content(), - partitionKeySelectorItemTypes[i] - )); + if (const auto rowsPerMatch = params->Child(1); + "RowsPerMatch_OneRow" == rowsPerMatch->Content()) { + for (size_t i = 0; i != partitionColumns->ChildrenSize(); ++i) { + outputTableColumns.push_back(ctx.Expr.MakeType<TItemExprType>( + partitionColumns->Child(i)->Content(), + partitionKeySelectorItemTypes[i] + )); + } + } else if ("RowsPerMatch_AllRows" == rowsPerMatch->Content()) { + const auto& inputTableColumns = GetSeqItemType(source->GetTypeAnn())->Cast<TStructExprType>()->GetItems(); + for (const auto& column : inputTableColumns) { + outputTableColumns.push_back(ctx.Expr.MakeType<TItemExprType>(column->GetName(), column->GetItemType())); + } + } else { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(rowsPerMatch->Pos()), "Unknown RowsPerMatch option")); + return IGraphTransformer::TStatus::Error; } const auto outputTableRowType = ctx.Expr.MakeType<TStructExprType>(outputTableColumns); input->SetTypeAnn(ctx.Expr.MakeType<TFlowExprType>(outputTableRowType)); diff --git a/yql/essentials/core/yql_opt_match_recognize.cpp b/yql/essentials/core/yql_opt_match_recognize.cpp index 8e40e400a1..f2bef7d7d4 100644 --- a/yql/essentials/core/yql_opt_match_recognize.cpp +++ b/yql/essentials/core/yql_opt_match_recognize.cpp @@ -22,8 +22,8 @@ bool IsStreaming(const TExprNode::TPtr& input, const TTypeAnnotationContext& typ bool hasPq = false; NYql::VisitExpr(input, [&hasPq](const TExprNode::TPtr& node){ if (node->IsCallable("DataSource")) { - YQL_ENSURE(node->ChildrenSize() > 0 and node->ChildRef(0)->IsAtom()); - hasPq = node->ChildRef(0)->Content() == "pq"; + YQL_ENSURE(node->ChildrenSize() > 0 and node->Child(0)->IsAtom()); + hasPq = node->Child(0)->Content() == "pq"; } return !hasPq; }); @@ -39,8 +39,8 @@ std::optional<TSet<TStringBuf>> FindUsedVars(const TExprNode::TPtr& params) { const auto createVisitor = [&usedVars, &allVarsUsed](const TExprNode::TPtr& varsArg) { return [&varsArg, &usedVars, &allVarsUsed](const TExprNode::TPtr& node) -> bool { if (node->IsCallable("Member")) { - if (node->ChildRef(0) == varsArg) { - usedVars.insert(node->ChildRef(1)->Content()); + if (node->Child(0) == varsArg) { + usedVars.insert(node->Child(1)->Content()); return false; } } @@ -51,23 +51,23 @@ std::optional<TSet<TStringBuf>> FindUsedVars(const TExprNode::TPtr& params) { }; }; - const auto& measures = params->ChildRef(0); + const auto measures = params->Child(0); static constexpr size_t measureLambdasStartPos = 3; for (size_t pos = measureLambdasStartPos; pos != measures->ChildrenSize(); pos++) { - const auto& lambda = measures->ChildRef(pos); - const auto& lambdaArgs = lambda->ChildRef(0); - const auto& lambdaBody = lambda->ChildRef(1); - const auto& varsArg = lambdaArgs->ChildRef(1); + const auto lambda = measures->Child(pos); + const auto lambdaArgs = lambda->Child(0); + const auto lambdaBody = lambda->ChildPtr(1); + const auto varsArg = lambdaArgs->ChildPtr(1); NYql::VisitExpr(lambdaBody, createVisitor(varsArg)); } - const auto& defines = params->ChildRef(4); + const auto defines = params->Child(4); static constexpr size_t defineLambdasStartPos = 3; for (size_t pos = defineLambdasStartPos; pos != defines->ChildrenSize(); pos++) { - const auto& lambda = defines->ChildRef(pos); - const auto& lambdaArgs = lambda->ChildRef(0); - const auto& lambdaBody = lambda->ChildRef(1); - const auto& varsArg = lambdaArgs->ChildRef(1); + const auto lambda = defines->Child(pos); + const auto lambdaArgs = lambda->Child(0); + const auto lambdaBody = lambda->ChildPtr(1); + const auto varsArg = lambdaArgs->ChildPtr(1); NYql::VisitExpr(lambdaBody, createVisitor(varsArg)); } @@ -75,25 +75,26 @@ std::optional<TSet<TStringBuf>> FindUsedVars(const TExprNode::TPtr& params) { } // usedVars can be std::nullopt if all vars could probably be used -TExprNode::TPtr MarkUnusedPatternVars(const TExprNode::TPtr& node, TExprContext& ctx, const std::optional<TSet<TStringBuf>> &usedVars) { +TExprNode::TPtr MarkUnusedPatternVars(const TExprNode::TPtr& node, TExprContext& ctx, const std::optional<TSet<TStringBuf>> &usedVars, TStringBuf rowsPerMatch) { const auto pos = node->Pos(); - if (node->ChildrenSize() != 0 && node->ChildRef(0)->IsAtom()) { - const auto& varName = node->ChildRef(0)->Content(); - bool varUsed = !usedVars.has_value() || usedVars.value().contains(varName); + if (node->ChildrenSize() != 0 && node->Child(0)->IsAtom()) { + const auto varName = node->Child(0)->Content(); + const auto output = node->Child(4); + const auto varUnused = ("RowsPerMatch_AllRows" != rowsPerMatch || !output) && usedVars && !usedVars->contains(varName); return ctx.Builder(pos) .List() - .Add(0, node->ChildRef(0)) - .Add(1, node->ChildRef(1)) - .Add(2, node->ChildRef(2)) - .Add(3, node->ChildRef(3)) - .Add(4, node->ChildRef(4)) - .Add(5, ctx.NewAtom(pos, varUsed ? "0" : "1")) + .Add(0, node->ChildPtr(0)) + .Add(1, node->ChildPtr(1)) + .Add(2, node->ChildPtr(2)) + .Add(3, node->ChildPtr(3)) + .Add(4, output) + .Add(5, ctx.NewAtom(pos, ToString(varUnused))) .Seal() .Build(); } TExprNodeList newChildren; for (size_t chPos = 0; chPos != node->ChildrenSize(); chPos++) { - newChildren.push_back(MarkUnusedPatternVars(node->ChildRef(chPos), ctx, usedVars)); + newChildren.push_back(MarkUnusedPatternVars(node->ChildPtr(chPos), ctx, usedVars, rowsPerMatch)); } if (node->IsCallable()) { return ctx.Builder(pos).Callable(node->Content()).Add(std::move(newChildren)).Seal().Build(); @@ -106,11 +107,11 @@ TExprNode::TPtr MarkUnusedPatternVars(const TExprNode::TPtr& node, TExprContext& TExprNode::TPtr ExpandMatchRecognize(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typeAnnCtx) { YQL_ENSURE(node->IsCallable({"MatchRecognize"})); - const auto& input = node->ChildRef(0); - const auto& partitionKeySelector = node->ChildRef(1); - const auto& partitionColumns = node->ChildRef(2); - const auto& sortTraits = node->ChildRef(3); - const auto& params = node->ChildRef(4); + const auto input = node->Child(0); + const auto partitionKeySelector = node->Child(1); + const auto partitionColumns = node->Child(2); + const auto sortTraits = node->Child(3); + const auto params = node->ChildPtr(4); const auto pos = node->Pos(); const bool isStreaming = IsStreaming(input, typeAnnCtx); @@ -118,6 +119,7 @@ TExprNode::TPtr ExpandMatchRecognize(const TExprNode::TPtr& node, TExprContext& TExprNode::TPtr settings = AddSetting(*ctx.NewList(pos, {}), pos, "Streaming", ctx.NewAtom(pos, ToString(isStreaming)), ctx); + const auto rowsPerMatch = params->Child(1)->Content(); const auto matchRecognize = ctx.Builder(pos) .Lambda() .Param("sortedPartition") @@ -129,11 +131,11 @@ TExprNode::TPtr ExpandMatchRecognize(const TExprNode::TPtr& node, TExprContext& .Add(1, partitionKeySelector) .Add(2, partitionColumns) .Callable(3, params->Content()) - .Add(0, params->ChildRef(0)) - .Add(1, params->ChildRef(1)) - .Add(2, params->ChildRef(2)) - .Add(3, MarkUnusedPatternVars(params->ChildRef(3), ctx, FindUsedVars(params))) - .Add(4, params->ChildRef(4)) + .Add(0, params->ChildPtr(0)) + .Add(1, params->ChildPtr(1)) + .Add(2, params->ChildPtr(2)) + .Add(3, MarkUnusedPatternVars(params->ChildPtr(3), ctx, FindUsedVars(params), rowsPerMatch)) + .Add(4, params->ChildPtr(4)) .Seal() .Add(4, settings) .Seal() diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp b/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp index 80caaad37e..e1fc529fda 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp +++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp @@ -1,32 +1,28 @@ #include "mkql_match_recognize_list.h" -#include "mkql_match_recognize_matched_vars.h" #include "mkql_match_recognize_measure_arg.h" +#include "mkql_match_recognize_matched_vars.h" #include "mkql_match_recognize_nfa.h" +#include "mkql_match_recognize_rows_formatter.h" #include "mkql_match_recognize_save_load.h" #include "mkql_saveload.h" #include <yql/essentials/core/sql_types/match_recognize.h> -#include <yql/essentials/minikql/computation/mkql_computation_node_impl.h> #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h> #include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h> +#include <yql/essentials/minikql/computation/mkql_computation_node_impl.h> #include <yql/essentials/minikql/computation/mkql_computation_node_pack.h> +#include <yql/essentials/minikql/mkql_node.h> #include <yql/essentials/minikql/mkql_node_cast.h> -#include <yql/essentials/minikql/mkql_runtime_version.h> #include <yql/essentials/minikql/mkql_string_util.h> -#include <yql/essentials/core/sql_types/match_recognize.h> + #include <deque> namespace NKikimr::NMiniKQL { namespace NMatchRecognize { -enum class EOutputColumnSource {PartitionKey, Measure}; -using TOutputColumnOrder = std::vector<std::pair<EOutputColumnSource, size_t>, TMKQLAllocator<std::pair<EOutputColumnSource, size_t>>>; - constexpr ui32 StateVersion = 1; -using namespace NYql::NMatchRecognize; - struct TMatchRecognizeProcessorParameters { IComputationExternalNode* InputDataArg; TRowPattern Pattern; @@ -37,27 +33,21 @@ struct TMatchRecognizeProcessorParameters { TComputationNodePtrVector Defines; IComputationExternalNode* MeasureInputDataArg; TMeasureInputColumnOrder MeasureInputColumnOrder; - TComputationNodePtrVector Measures; - TOutputColumnOrder OutputColumnOrder; - TAfterMatchSkipTo SkipTo; + TAfterMatchSkipTo SkipTo; }; class TStreamingMatchRecognize { - using TPartitionList = TSparseList; - using TRange = TPartitionList::TRange; public: TStreamingMatchRecognize( NUdf::TUnboxedValue&& partitionKey, const TMatchRecognizeProcessorParameters& parameters, - TNfaTransitionGraph::TPtr nfaTransitions, - const TContainerCacheOnContext& cache - ) + const IRowsFormatter::TState& rowsFormatterState, + TNfaTransitionGraph::TPtr nfaTransitions) : PartitionKey(std::move(partitionKey)) , Parameters(parameters) + , RowsFormatter_(IRowsFormatter::Create(rowsFormatterState)) , Nfa(nfaTransitions, parameters.MatchedVarsArg, parameters.Defines, parameters.SkipTo) - , Cache(cache) - { - } + {} bool ProcessInputRow(NUdf::TUnboxedValue&& row, TComputationContext& ctx) { Parameters.InputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TListValue<TSparseList>>(Rows)); @@ -71,33 +61,26 @@ public: } NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) { + if (auto result = RowsFormatter_->GetOtherMatchRow(ctx, Rows, PartitionKey, Nfa.GetTransitionGraph())) { + return result; + } auto match = Nfa.GetMatched(); if (!match) { return NUdf::TUnboxedValue{}; } - Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, match->Vars)); + Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TSparseList::TRange>>(ctx.HolderFactory, match->Vars)); Parameters.MeasureInputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TMeasureInputDataValue>( ctx.HolderFactory.Create<TListValue<TSparseList>>(Rows), Parameters.MeasureInputColumnOrder, Parameters.MatchedVarsArg->GetValue(ctx), Parameters.VarNames, MatchNumber - )); - NUdf::TUnboxedValue *itemsPtr = nullptr; - const auto result = Cache.NewArray(ctx, Parameters.OutputColumnOrder.size(), itemsPtr); - for (auto const& c: Parameters.OutputColumnOrder) { - switch(c.first) { - case EOutputColumnSource::Measure: - *itemsPtr++ = Parameters.Measures[c.second]->GetValue(ctx); - break; - case EOutputColumnSource::PartitionKey: - *itemsPtr++ = PartitionKey.GetElement(c.second); - break; - } - } + )); + auto result = RowsFormatter_->GetFirstMatchRow(ctx, Rows, PartitionKey, Nfa.GetTransitionGraph(), *match); Nfa.AfterMatchSkip(*match); return result; } + bool ProcessEndOfData(TComputationContext& ctx) { return Nfa.ProcessEndOfData(ctx); } @@ -117,11 +100,11 @@ public: } private: - const NUdf::TUnboxedValue PartitionKey; + NUdf::TUnboxedValue PartitionKey; const TMatchRecognizeProcessorParameters& Parameters; + std::unique_ptr<IRowsFormatter> RowsFormatter_; TSparseList Rows; TNfa Nfa; - const TContainerCacheOnContext& Cache; ui64 MatchNumber = 0; }; @@ -135,7 +118,7 @@ public: IComputationNode* partitionKey, TType* partitionKeyType, const TMatchRecognizeProcessorParameters& parameters, - const TContainerCacheOnContext& cache, + const IRowsFormatter::TState& rowsFormatterState, TComputationContext &ctx, TType* rowType, const TMutableObjectOverBoxedValue<TValuePackerBoxed>& rowPacker @@ -145,8 +128,8 @@ public: , PartitionKey(partitionKey) , PartitionKeyPacker(true, partitionKeyType) , Parameters(parameters) + , RowsFormatterState(rowsFormatterState) , RowPatternConfiguration(TNfaTransitionGraphBuilder::Create(parameters.Pattern, parameters.VarNamesLookup)) - , Cache(cache) , Terminating(false) , SerializerContext(ctx, rowType, rowPacker) , Ctx(ctx) @@ -184,8 +167,8 @@ public: PartitionHandler.reset(new TStreamingMatchRecognize( std::move(key), Parameters, - RowPatternConfiguration, - Cache + RowsFormatterState, + RowPatternConfiguration )); PartitionHandler->Load(in); } @@ -250,8 +233,9 @@ public: PartitionHandler.reset(new TStreamingMatchRecognize( std::move(partitionKey), Parameters, - RowPatternConfiguration, - Cache)); + RowsFormatterState, + RowPatternConfiguration + )); PartitionHandler->ProcessInputRow(std::move(temp), ctx); } if (Terminating) { @@ -266,8 +250,8 @@ private: IComputationNode* PartitionKey; TValuePackerGeneric<false> PartitionKeyPacker; const TMatchRecognizeProcessorParameters& Parameters; + const IRowsFormatter::TState& RowsFormatterState; const TNfaTransitionGraph::TPtr RowPatternConfiguration; - const TContainerCacheOnContext& Cache; NUdf::TUnboxedValue DelayedRow; bool Terminating; TSerializerContext SerializerContext; @@ -286,7 +270,7 @@ public: IComputationNode* partitionKey, TType* partitionKeyType, const TMatchRecognizeProcessorParameters& parameters, - const TContainerCacheOnContext& cache, + const IRowsFormatter::TState& rowsFormatterState, TComputationContext &ctx, TType* rowType, const TMutableObjectOverBoxedValue<TValuePackerBoxed>& rowPacker @@ -296,8 +280,8 @@ public: , PartitionKey(partitionKey) , PartitionKeyPacker(true, partitionKeyType) , Parameters(parameters) + , RowsFormatterState(rowsFormatterState) , NfaTransitionGraph(TNfaTransitionGraphBuilder::Create(parameters.Pattern, parameters.VarNamesLookup)) - , Cache(cache) , SerializerContext(ctx, rowType, rowPacker) , Ctx(ctx) {} @@ -335,8 +319,10 @@ public: std::make_unique<TStreamingMatchRecognize>( std::move(key), Parameters, - NfaTransitionGraph, - Cache)); + RowsFormatterState, + NfaTransitionGraph + ) + ); pair.first->second->Load(in); } @@ -402,8 +388,8 @@ private: return Partitions.emplace_hint(it, TString(packedKey), std::make_unique<TStreamingMatchRecognize>( std::move(partitionKey), Parameters, - NfaTransitionGraph, - Cache + RowsFormatterState, + NfaTransitionGraph )); } } @@ -418,8 +404,8 @@ private: //TODO switch to tuple compare TValuePackerGeneric<false> PartitionKeyPacker; const TMatchRecognizeProcessorParameters& Parameters; + const IRowsFormatter::TState& RowsFormatterState; const TNfaTransitionGraph::TPtr NfaTransitionGraph; - const TContainerCacheOnContext& Cache; TSerializerContext SerializerContext; TComputationContext& Ctx; }; @@ -428,20 +414,23 @@ template<class State> class TMatchRecognizeWrapper : public TStatefulFlowComputationNode<TMatchRecognizeWrapper<State>, true> { using TBaseComputation = TStatefulFlowComputationNode<TMatchRecognizeWrapper<State>, true>; public: - TMatchRecognizeWrapper(TComputationMutables &mutables, EValueRepresentation kind, IComputationNode *inputFlow, - IComputationExternalNode *inputRowArg, - IComputationNode *partitionKey, - TType* partitionKeyType, - const TMatchRecognizeProcessorParameters& parameters, - TType* rowType - ) - :TBaseComputation(mutables, inputFlow, kind, EValueRepresentation::Embedded) + TMatchRecognizeWrapper( + TComputationMutables& mutables, + EValueRepresentation kind, + IComputationNode *inputFlow, + IComputationExternalNode *inputRowArg, + IComputationNode *partitionKey, + TType* partitionKeyType, + TMatchRecognizeProcessorParameters&& parameters, + IRowsFormatter::TState&& rowsFormatterState, + TType* rowType) + : TBaseComputation(mutables, inputFlow, kind, EValueRepresentation::Embedded) , InputFlow(inputFlow) , InputRowArg(inputRowArg) , PartitionKey(partitionKey) , PartitionKeyType(partitionKeyType) - , Parameters(parameters) - , Cache(mutables) + , Parameters(std::move(parameters)) + , RowsFormatterState(std::move(rowsFormatterState)) , RowType(rowType) , RowPacker(mutables) {} @@ -453,7 +442,7 @@ public: PartitionKey, PartitionKeyType, Parameters, - Cache, + RowsFormatterState, ctx, RowType, RowPacker @@ -468,7 +457,7 @@ public: PartitionKey, PartitionKeyType, Parameters, - Cache, + RowsFormatterState, ctx, RowType, RowPacker @@ -503,7 +492,7 @@ private: Own(flow, Parameters.CurrentRowIndexArg); Own(flow, Parameters.MeasureInputDataArg); DependsOn(flow, PartitionKey); - for (auto& m: Parameters.Measures) { + for (auto& m: RowsFormatterState.Measures) { DependsOn(flow, m); } for (auto& d: Parameters.Defines) { @@ -516,32 +505,31 @@ private: IComputationExternalNode* const InputRowArg; IComputationNode* const PartitionKey; TType* const PartitionKeyType; - const TMatchRecognizeProcessorParameters Parameters; - const TContainerCacheOnContext Cache; + TMatchRecognizeProcessorParameters Parameters; + IRowsFormatter::TState RowsFormatterState; TType* const RowType; TMutableObjectOverBoxedValue<TValuePackerBoxed> RowPacker; }; TOutputColumnOrder GetOutputColumnOrder(TRuntimeNode partitionKyeColumnsIndexes, TRuntimeNode measureColumnsIndexes) { - using tempMapValue = std::pair<EOutputColumnSource, size_t>; - std::unordered_map<size_t, tempMapValue, std::hash<size_t>, std::equal_to<size_t>, TMKQLAllocator<std::pair<const size_t, tempMapValue>, EMemorySubPool::Temporary>> temp; + std::unordered_map<size_t, TOutputColumnEntry, std::hash<size_t>, std::equal_to<size_t>, TMKQLAllocator<std::pair<const size_t, TOutputColumnEntry>, EMemorySubPool::Temporary>> temp; { auto list = AS_VALUE(TListLiteral, partitionKyeColumnsIndexes); for (ui32 i = 0; i != list->GetItemsCount(); ++i) { auto index = AS_VALUE(TDataLiteral, list->GetItems()[i])->AsValue().Get<ui32>(); - temp[index] = std::make_pair(EOutputColumnSource::PartitionKey, i); + temp[index] = {i, EOutputColumnSource::PartitionKey}; } } { auto list = AS_VALUE(TListLiteral, measureColumnsIndexes); for (ui32 i = 0; i != list->GetItemsCount(); ++i) { auto index = AS_VALUE(TDataLiteral, list->GetItems()[i])->AsValue().Get<ui32>(); - temp[index] = std::make_pair(EOutputColumnSource::Measure, i); + temp[index] = {i, EOutputColumnSource::Measure}; } } if (temp.empty()) return {}; - auto outputSize = max_element(temp.cbegin(), temp.cend())->first + 1; + auto outputSize = std::ranges::max_element(temp, {}, &std::pair<const size_t, TOutputColumnEntry>::first)->first + 1; TOutputColumnOrder result(outputSize); for (const auto& [i, v]: temp) { result[i] = v; @@ -576,7 +564,6 @@ TRowPattern ConvertPattern(const TRuntimeNode& pattern) { } TMeasureInputColumnOrder GetMeasureColumnOrder(const TListLiteral& specialColumnIndexes, ui32 inputRowColumnCount) { - using NYql::NMatchRecognize::EMeasureInputDataSpecialColumns; //Use Last enum value to denote that c colum comes from the input table TMeasureInputColumnOrder result(inputRowColumnCount + specialColumnIndexes.GetItemsCount(), std::make_pair(EMeasureInputDataSpecialColumns::Last, 0)); if (specialColumnIndexes.GetItemsCount() != 0) { @@ -621,7 +608,6 @@ std::pair<TUnboxedValueVector, THashMap<TString, size_t>> ConvertListOfStrings(c } //namespace NMatchRecognize - IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputationNodeFactoryContext& ctx) { using namespace NMatchRecognize; size_t inputIndex = 0; @@ -641,9 +627,9 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation const auto& pattern = callable.GetInput(inputIndex++); const auto& currentRowIndexArg = callable.GetInput(inputIndex++); const auto& inputDataArg = callable.GetInput(inputIndex++); - const auto& varNames = callable.GetInput(inputIndex++); + const auto& defineNames = callable.GetInput(inputIndex++); TRuntimeNode::TList defines; - for (size_t i = 0; i != AS_VALUE(TListLiteral, varNames)->GetItemsCount(); ++i) { + for (size_t i = 0; i != AS_VALUE(TListLiteral, defineNames)->GetItemsCount(); ++i) { defines.push_back(callable.GetInput(inputIndex++)); } const auto& streamingMode = callable.GetInput(inputIndex++); @@ -652,47 +638,58 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation skipTo.To = static_cast<EAfterMatchSkipTo>(AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().Get<i32>()); skipTo.Var = AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().AsStringRef(); } + NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch = NYql::NMatchRecognize::ERowsPerMatch::OneRow; + TOutputColumnOrder outputColumnOrder; + if (inputIndex + 2 <= callable.GetInputsCount()) { + rowsPerMatch = static_cast<ERowsPerMatch>(AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().Get<i32>()); + outputColumnOrder = IRowsFormatter::GetOutputColumnOrder(callable.GetInput(inputIndex++)); + } else { + outputColumnOrder = GetOutputColumnOrder(partitionColumnIndexes, measureColumnIndexes); + } MKQL_ENSURE(callable.GetInputsCount() == inputIndex, "Wrong input count"); - const auto& [vars, varsLookup] = ConvertListOfStrings(varNames); + const auto& [varNames, varNamesLookup] = ConvertListOfStrings(defineNames); auto* rowType = AS_TYPE(TStructType, AS_TYPE(TFlowType, inputFlow.GetStaticType())->GetItemType()); - const auto parameters = TMatchRecognizeProcessorParameters { - static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputDataArg.GetNode())) - , ConvertPattern(pattern) - , vars - , varsLookup - , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *matchedVarsArg.GetNode())) - , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *currentRowIndexArg.GetNode())) - , ConvertVectorOfCallables(defines, ctx) - , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *measureInputDataArg.GetNode())) - , GetMeasureColumnOrder( + auto parameters = TMatchRecognizeProcessorParameters { + static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputDataArg.GetNode())), + ConvertPattern(pattern), + varNames, + varNamesLookup, + static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *matchedVarsArg.GetNode())), + static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *currentRowIndexArg.GetNode())), + ConvertVectorOfCallables(defines, ctx), + static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *measureInputDataArg.GetNode())), + GetMeasureColumnOrder( *AS_VALUE(TListLiteral, measureSpecialColumnIndexes), AS_VALUE(TDataLiteral, inputRowColumnCount)->AsValue().Get<ui32>() - ) - , ConvertVectorOfCallables(measures, ctx) - , GetOutputColumnOrder(partitionColumnIndexes, measureColumnIndexes) - , skipTo + ), + skipTo }; + IRowsFormatter::TState rowsFormatterState(ctx, outputColumnOrder, ConvertVectorOfCallables(measures, ctx), rowsPerMatch); if (AS_VALUE(TDataLiteral, streamingMode)->AsValue().Get<bool>()) { - return new TMatchRecognizeWrapper<TStateForInterleavedPartitions>(ctx.Mutables - , GetValueRepresentation(inputFlow.GetStaticType()) - , LocateNode(ctx.NodeLocator, *inputFlow.GetNode()) - , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode())) - , LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode()) - , partitionKeySelector.GetStaticType() - , std::move(parameters) - , rowType + return new TMatchRecognizeWrapper<TStateForInterleavedPartitions>( + ctx.Mutables, + GetValueRepresentation(inputFlow.GetStaticType()), + LocateNode(ctx.NodeLocator, *inputFlow.GetNode()), + static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode())), + LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode()), + partitionKeySelector.GetStaticType(), + std::move(parameters), + std::move(rowsFormatterState), + rowType ); } else { - return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions>(ctx.Mutables - , GetValueRepresentation(inputFlow.GetStaticType()) - , LocateNode(ctx.NodeLocator, *inputFlow.GetNode()) - , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode())) - , LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode()) - , partitionKeySelector.GetStaticType() - , std::move(parameters) - , rowType + return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions>( + ctx.Mutables, + GetValueRepresentation(inputFlow.GetStaticType()), + LocateNode(ctx.NodeLocator, *inputFlow.GetNode()), + static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode())), + LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode()), + partitionKeySelector.GetStaticType(), + std::move(parameters), + std::move(rowsFormatterState), + rowType ); } } diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h index 3c7771a41e..cf9bfb46a9 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h +++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h @@ -18,26 +18,29 @@ public: class TRange { public: TRange() - : FromIndex(-1) - , ToIndex(-1) + : FromIndex(Max()) + , ToIndex(Max()) + , NfaIndex_(Max()) { } explicit TRange(ui64 index) - : FromIndex(index) - , ToIndex(index) + : FromIndex(index) + , ToIndex(index) + , NfaIndex_(Max()) { } TRange(ui64 from, ui64 to) - : FromIndex(from) - , ToIndex(to) + : FromIndex(from) + , ToIndex(to) + , NfaIndex_(Max()) { MKQL_ENSURE(FromIndex <= ToIndex, "Internal logic error"); } bool IsValid() const { - return true; + return FromIndex != Max<size_t>() && ToIndex != Max<size_t>(); } size_t From() const { @@ -50,6 +53,11 @@ public: return ToIndex; } + [[nodiscard]] size_t NfaIndex() const { + MKQL_ENSURE(IsValid(), "Internal logic error"); + return NfaIndex_; + } + size_t Size() const { MKQL_ENSURE(IsValid(), "Internal logic error"); return ToIndex - FromIndex + 1; @@ -61,8 +69,9 @@ public: } private: - ui64 FromIndex; - ui64 ToIndex; + size_t FromIndex; + size_t ToIndex; + size_t NfaIndex_; }; TRange Append(NUdf::TUnboxedValue&& value) { @@ -156,14 +165,12 @@ class TSparseList { private: //TODO consider to replace hash table with contiguous chunks - using TAllocator = TMKQLAllocator<std::pair<const size_t, TItem>, EMemorySubPool::Temporary>; - using TStorage = std::unordered_map< size_t, TItem, std::hash<size_t>, std::equal_to<size_t>, - TAllocator>; + TMKQLAllocator<std::pair<const size_t, TItem>, EMemorySubPool::Temporary>>; TStorage Storage; }; @@ -178,8 +185,9 @@ public: public: TRange() : Container() - , FromIndex(-1) - , ToIndex(-1) + , FromIndex(Max()) + , ToIndex(Max()) + , NfaIndex_(Max()) { } @@ -187,6 +195,7 @@ public: : Container(other.Container) , FromIndex(other.FromIndex) , ToIndex(other.ToIndex) + , NfaIndex_(other.NfaIndex_) { LockRange(FromIndex, ToIndex); } @@ -195,6 +204,7 @@ public: : Container(other.Container) , FromIndex(other.FromIndex) , ToIndex(other.ToIndex) + , NfaIndex_(other.NfaIndex_) { other.Reset(); } @@ -212,6 +222,7 @@ public: Container = other.Container; FromIndex = other.FromIndex; ToIndex = other.ToIndex; + NfaIndex_ = other.NfaIndex_; LockRange(FromIndex, ToIndex); return *this; } @@ -224,20 +235,21 @@ public: Container = other.Container; FromIndex = other.FromIndex; ToIndex = other.ToIndex; + NfaIndex_ = other.NfaIndex_; other.Reset(); return *this; } friend inline bool operator==(const TRange& lhs, const TRange& rhs) { - return std::tie(lhs.FromIndex, lhs.ToIndex) == std::tie(rhs.FromIndex, rhs.ToIndex); + return std::tie(lhs.FromIndex, lhs.ToIndex, lhs.NfaIndex_) == std::tie(rhs.FromIndex, rhs.ToIndex, rhs.NfaIndex_); } friend inline bool operator<(const TRange& lhs, const TRange& rhs) { - return std::tie(lhs.FromIndex, lhs.ToIndex) < std::tie(rhs.FromIndex, rhs.ToIndex); + return std::tie(lhs.FromIndex, lhs.ToIndex, lhs.NfaIndex_) < std::tie(rhs.FromIndex, rhs.ToIndex, rhs.NfaIndex_); } bool IsValid() const { - return static_cast<bool>(Container); + return static_cast<bool>(Container) && FromIndex != Max<size_t>() && ToIndex != Max<size_t>(); } size_t From() const { @@ -250,6 +262,15 @@ public: return ToIndex; } + [[nodiscard]] size_t NfaIndex() const { + MKQL_ENSURE(IsValid(), "Internal logic error"); + return NfaIndex_; + } + + void NfaIndex(size_t index) { + NfaIndex_ = index; + } + size_t Size() const { MKQL_ENSURE(IsValid(), "Internal logic error"); return ToIndex - FromIndex + 1; @@ -264,16 +285,17 @@ public: void Release() { UnlockRange(FromIndex, ToIndex); Container.Reset(); - FromIndex = -1; - ToIndex = -1; + FromIndex = Max(); + ToIndex = Max(); + NfaIndex_ = Max(); } void Save(TMrOutputSerializer& serializer) const { - serializer(Container, FromIndex, ToIndex); + serializer(Container, FromIndex, ToIndex, NfaIndex_); } void Load(TMrInputSerializer& serializer) { - serializer(Container, FromIndex, ToIndex); + serializer(Container, FromIndex, ToIndex, NfaIndex_); } private: @@ -281,6 +303,7 @@ public: : Container(container) , FromIndex(index) , ToIndex(index) + , NfaIndex_(Max()) {} void LockRange(size_t from, size_t to) { @@ -297,13 +320,15 @@ public: void Reset() { Container.Reset(); - FromIndex = -1; - ToIndex = -1; + FromIndex = Max(); + ToIndex = Max(); + NfaIndex_ = Max(); } TContainerPtr Container; size_t FromIndex; size_t ToIndex; + size_t NfaIndex_; }; public: diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h index 8c2576798a..de0ccc352f 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h +++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h @@ -15,7 +15,7 @@ void Extend(TMatchedVar<R>& var, const R& r) { var.emplace_back(r); } else { MKQL_ENSURE(r.From() > var.back().To(), "Internal logic error"); - if (var.back().To() + 1 == r.From()) { + if (var.back().To() + 1 == r.From() && var.back().NfaIndex() == r.NfaIndex()) { var.back().Extend(); } else { var.emplace_back(r); diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h index 2b194212f4..c868379d40 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h +++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h @@ -20,9 +20,10 @@ struct TEpsilonTransitions { friend constexpr bool operator==(const TEpsilonTransitions&, const TEpsilonTransitions&) = default; }; struct TMatchedVarTransition { + size_t To; ui32 VarIndex; bool SaveState; - size_t To; + bool ExcludeFromOutput; friend constexpr bool operator==(const TMatchedVarTransition&, const TMatchedVarTransition&) = default; }; struct TQuantityEnterTransition { @@ -116,7 +117,7 @@ struct TNfaTransitionGraph { serializer(tr.To); }, [&](const TMatchedVarTransition& tr) { - serializer(tr.VarIndex, tr.SaveState, tr.To); + serializer(tr.To, tr.VarIndex, tr.SaveState, tr.ExcludeFromOutput); }, [&](const TQuantityEnterTransition& tr) { serializer(tr.To); @@ -141,7 +142,7 @@ struct TNfaTransitionGraph { serializer(tr.To); }, [&](TMatchedVarTransition& tr) { - serializer(tr.VarIndex, tr.SaveState, tr.To); + serializer(tr.To, tr.VarIndex, tr.SaveState, tr.ExcludeFromOutput); }, [&](TQuantityEnterTransition& tr) { serializer(tr.To); @@ -297,7 +298,7 @@ private: auto input = AddNode(); auto output = AddNode(); auto item = factor.Primary.index() == 0 ? - BuildVar(varNameToIndex.at(std::get<0>(factor.Primary)), !factor.Unused) : + BuildVar(varNameToIndex.at(std::get<0>(factor.Primary)), !factor.Unused, !factor.Output) : BuildTerms(std::get<1>(factor.Primary), varNameToIndex); if (1 == factor.QuantityMin && 1 == factor.QuantityMax) { //simple linear case Graph->Transitions[input] = TEpsilonTransitions{{item.Input}}; @@ -319,15 +320,16 @@ private: } return {input, output}; } - TNfaItem BuildVar(ui32 varIndex, bool isUsed) { + TNfaItem BuildVar(ui32 varIndex, bool isUsed, bool excludeFromOutput) { auto input = AddNode(); auto matchVar = AddNode(); auto output = AddNode(); Graph->Transitions[input] = TEpsilonTransitions({matchVar}); Graph->Transitions[matchVar] = TMatchedVarTransition{ + output, varIndex, isUsed, - output, + excludeFromOutput, }; return {input, output}; } @@ -441,6 +443,7 @@ public: if (matchedVarTransition->SaveState) { auto vars = state.Match.Vars; //TODO get rid of this copy auto& matchedVar = vars[varIndex]; + currentRowLock.NfaIndex(state.Index); Extend(matchedVar, currentRowLock); newStates.emplace(matchedVarTransition->To, TMatch{state.Match.BeginIndex, currentRowLock.To(), std::move(vars)}, state.Quantifiers); } else { @@ -575,6 +578,10 @@ public: } } + const TNfaTransitionGraph& GetTransitionGraph() const { + return *TransitionGraph; + } + private: //TODO (zverevgeny): Consider to change to std::vector for the sake of perf using TStateSet = std::set<TState, std::less<TState>, TMKQLAllocator<TState>>; @@ -593,14 +600,14 @@ private: [&](const TEpsilonTransitions& epsilonTransitions) { deletedStates.insert(state); for (const auto& i : epsilonTransitions.To) { - newStates.emplace(i, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, state.Quantifiers); + newStates.emplace(i, state.Match, state.Quantifiers); } }, [&](const TQuantityEnterTransition& quantityEnterTransition) { deletedStates.insert(state); auto quantifiers = state.Quantifiers; //TODO get rid of this copy quantifiers.push_back(0); - newStates.emplace(quantityEnterTransition.To, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(quantifiers)); + newStates.emplace(quantityEnterTransition.To, state.Match, std::move(quantifiers)); }, [&](const TQuantityExitTransition& quantityExitTransition) { deletedStates.insert(state); @@ -608,12 +615,12 @@ private: if (state.Quantifiers.back() + 1 < quantityMax) { auto q = state.Quantifiers; q.back()++; - newStates.emplace(toFindMore, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(q)); + newStates.emplace(toFindMore, state.Match, std::move(q)); } if (quantityMin <= state.Quantifiers.back() + 1 && state.Quantifiers.back() + 1 <= quantityMax) { auto q = state.Quantifiers; q.pop_back(); - newStates.emplace(toMatched, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(q)); + newStates.emplace(toMatched, state.Match, std::move(q)); } }, }, TransitionGraph->Transitions[state.Index]); diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.cpp b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.cpp new file mode 100644 index 0000000000..6d70458b3b --- /dev/null +++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.cpp @@ -0,0 +1,144 @@ +#include "mkql_match_recognize_rows_formatter.h" + +#include <yql/essentials/minikql/mkql_node.h> +#include <yql/essentials/minikql/mkql_node_cast.h> + +namespace NKikimr::NMiniKQL::NMatchRecognize { + +namespace { + +class TOneRowFormatter final : public IRowsFormatter { +public: + explicit TOneRowFormatter(const TState& state) : IRowsFormatter(state) {} + + NUdf::TUnboxedValue GetFirstMatchRow( + TComputationContext& ctx, + const TSparseList& rows, + const NUdf::TUnboxedValue& partitionKey, + const TNfaTransitionGraph& graph, + const TNfa::TMatch& match) { + Match_ = match; + const auto result = DoGetMatchRow(ctx, rows, partitionKey, graph); + IRowsFormatter::Clear(); + return result; + } + + NUdf::TUnboxedValue GetOtherMatchRow( + TComputationContext& ctx, + const TSparseList& rows, + const NUdf::TUnboxedValue& partitionKey, + const TNfaTransitionGraph& graph) { + return NUdf::TUnboxedValue{}; + } +}; + +class TAllRowsFormatter final : public IRowsFormatter { +public: + explicit TAllRowsFormatter(const IRowsFormatter::TState& state) : IRowsFormatter(state) {} + + NUdf::TUnboxedValue GetFirstMatchRow( + TComputationContext& ctx, + const TSparseList& rows, + const NUdf::TUnboxedValue& partitionKey, + const TNfaTransitionGraph& graph, + const TNfa::TMatch& match) { + Match_ = match; + CurrentRowIndex_ = Match_.BeginIndex; + for (const auto& matchedVar : Match_.Vars) { + for (const auto& range : matchedVar) { + ToIndexToMatchRangeLookup_.emplace(range.To(), range); + } + } + return GetMatchRow(ctx, rows, partitionKey, graph); + } + + NUdf::TUnboxedValue GetOtherMatchRow( + TComputationContext& ctx, + const TSparseList& rows, + const NUdf::TUnboxedValue& partitionKey, + const TNfaTransitionGraph& graph) { + return GetMatchRow(ctx, rows, partitionKey, graph); + } + +private: + NUdf::TUnboxedValue GetMatchRow(TComputationContext& ctx, const TSparseList& rows, const NUdf::TUnboxedValue& partitionKey, const TNfaTransitionGraph& graph) { + while (CurrentRowIndex_ <= Match_.EndIndex) { + if (auto iter = ToIndexToMatchRangeLookup_.lower_bound(CurrentRowIndex_); + iter == ToIndexToMatchRangeLookup_.end()) { + MKQL_ENSURE(false, "Internal logic error"); + } else if (auto transition = std::get_if<TMatchedVarTransition>(&graph.Transitions.at(iter->second.NfaIndex())); + !transition) { + MKQL_ENSURE(false, "Internal logic error"); + } else if (transition->ExcludeFromOutput) { + ++CurrentRowIndex_; + } else { + break; + } + } + if (CurrentRowIndex_ > Match_.EndIndex) { + return NUdf::TUnboxedValue{}; + } + const auto result = DoGetMatchRow(ctx, rows, partitionKey, graph); + ++CurrentRowIndex_; + if (CurrentRowIndex_ == Match_.EndIndex) { + Clear(); + } + return result; + } + + void Clear() { + IRowsFormatter::Clear(); + ToIndexToMatchRangeLookup_.clear(); + } + + TMap<size_t, const TSparseList::TRange&> ToIndexToMatchRangeLookup_; +}; + +} // anonymous namespace + +IRowsFormatter::IRowsFormatter(const TState& state) : State_(state) {} + +TOutputColumnOrder IRowsFormatter::GetOutputColumnOrder( + TRuntimeNode outputColumnOrder) { + TOutputColumnOrder result; + auto list = AS_VALUE(TListLiteral, outputColumnOrder); + TConstArrayRef<TRuntimeNode> items(list->GetItems(), list->GetItemsCount()); + for (auto item : items) { + const auto entry = AS_VALUE(TStructLiteral, item); + result.emplace_back( + AS_VALUE(TDataLiteral, entry->GetValue(0))->AsValue().Get<ui32>(), + static_cast<EOutputColumnSource>(AS_VALUE(TDataLiteral, entry->GetValue(1))->AsValue().Get<i32>()) + ); + } + return result; +} + +NUdf::TUnboxedValue IRowsFormatter::DoGetMatchRow(TComputationContext& ctx, const TSparseList& rows, const NUdf::TUnboxedValue& partitionKey, const TNfaTransitionGraph& graph) { + NUdf::TUnboxedValue *itemsPtr = nullptr; + const auto result = State_.Cache->NewArray(ctx, State_.OutputColumnOrder.size(), itemsPtr); + for (const auto& columnEntry: State_.OutputColumnOrder) { + switch(columnEntry.SourceType) { + case EOutputColumnSource::PartitionKey: + *itemsPtr++ = partitionKey.GetElement(columnEntry.Index); + break; + case EOutputColumnSource::Measure: + *itemsPtr++ = State_.Measures[columnEntry.Index]->GetValue(ctx); + break; + case EOutputColumnSource::Other: + *itemsPtr++ = rows.Get(CurrentRowIndex_).GetElement(columnEntry.Index); + break; + } + } + return result; +} + +std::unique_ptr<IRowsFormatter> IRowsFormatter::Create(const IRowsFormatter::TState& state) { + switch (state.RowsPerMatch) { + case ERowsPerMatch::OneRow: + return std::unique_ptr<IRowsFormatter>(new TOneRowFormatter(state)); + case ERowsPerMatch::AllRows: + return std::unique_ptr<IRowsFormatter>(new TAllRowsFormatter(state)); + } +} + +} //namespace NKikimr::NMiniKQL::NMatchRecognize diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.h new file mode 100644 index 0000000000..39750bf4ea --- /dev/null +++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.h @@ -0,0 +1,72 @@ +#pragma once + +#include "mkql_match_recognize_nfa.h" + +#include <yql/essentials/core/sql_types/match_recognize.h> +#include <yql/essentials/minikql/computation/mkql_computation_node.h> +#include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h> +#include <yql/essentials/minikql/mkql_alloc.h> +#include <yql/essentials/public/udf/udf_value.h> + +namespace NKikimr::NMiniKQL::NMatchRecognize { + +struct TOutputColumnEntry { + size_t Index; + NYql::NMatchRecognize::EOutputColumnSource SourceType; +}; +using TOutputColumnOrder = std::vector<TOutputColumnEntry, TMKQLAllocator<TOutputColumnEntry>>; + +class IRowsFormatter { +public: + struct TState { + std::unique_ptr<TContainerCacheOnContext> Cache; + TOutputColumnOrder OutputColumnOrder; + TComputationNodePtrVector Measures; + NYql::NMatchRecognize::ERowsPerMatch RowsPerMatch; + + TState( + const TComputationNodeFactoryContext& ctx, + TOutputColumnOrder outputColumnOrder, + TComputationNodePtrVector measures, + NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch) + : Cache(std::make_unique<TContainerCacheOnContext>(ctx.Mutables)) + , OutputColumnOrder(std::move(outputColumnOrder)) + , Measures(std::move(measures)) + , RowsPerMatch(rowsPerMatch) + {} + }; + + explicit IRowsFormatter(const TState& state); + virtual ~IRowsFormatter() = default; + + virtual NUdf::TUnboxedValue GetFirstMatchRow( + TComputationContext& ctx, + const TSparseList& rows, + const NUdf::TUnboxedValue& partitionKey, + const TNfaTransitionGraph& graph, + const TNfa::TMatch& match) = 0; + + virtual NUdf::TUnboxedValue GetOtherMatchRow( + TComputationContext& ctx, + const TSparseList& rows, + const NUdf::TUnboxedValue& partitionKey, + const TNfaTransitionGraph& graph) = 0; + + static TOutputColumnOrder GetOutputColumnOrder(TRuntimeNode outputColumnOrder); + + static std::unique_ptr<IRowsFormatter> Create(const TState& state); + +protected: + NUdf::TUnboxedValue DoGetMatchRow(TComputationContext& ctx, const TSparseList& rows, const NUdf::TUnboxedValue& partitionKey, const TNfaTransitionGraph& graph); + + inline void Clear() { + Match_ = {}; + CurrentRowIndex_ = Max(); + } + + const TState& State_; + TNfa::TMatch Match_ {}; + size_t CurrentRowIndex_ = Max(); +}; + +} // namespace NKikimr::NMiniKQL::NMatchRecognize diff --git a/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp b/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp index a9fad1b6ef..513a72df5e 100644 --- a/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp +++ b/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp @@ -64,57 +64,46 @@ namespace { const TTestInputData& input) { TProgramBuilder& pgmBuilder = *setup.PgmBuilder; - auto structType = pgmBuilder.NewStructType({ - {"time", pgmBuilder.NewDataType(NUdf::TDataType<i64>::Id)}, - {"key", pgmBuilder.NewDataType(NUdf::TDataType<char*>::Id)}, - {"sum", pgmBuilder.NewDataType(NUdf::TDataType<ui32>::Id)}, - {"part", pgmBuilder.NewDataType(NUdf::TDataType<char*>::Id)}}); + const auto structType = pgmBuilder.NewStructType({ + {"time", pgmBuilder.NewDataType(NUdf::EDataSlot::Int64)}, + {"key", pgmBuilder.NewDataType(NUdf::EDataSlot::String)}, + {"sum", pgmBuilder.NewDataType(NUdf::EDataSlot::Uint32)}, + {"part", pgmBuilder.NewDataType(NUdf::EDataSlot::String)} + }); TVector<TRuntimeNode> items; - for (size_t i = 0; i < input.size(); ++i) - { - auto time = pgmBuilder.NewDataLiteral<i64>(std::get<0>(input[i])); - auto key = pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(NUdf::TStringRef(std::get<1>(input[i]))); - auto sum = pgmBuilder.NewDataLiteral<ui32>(std::get<2>(input[i])); - auto part = pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(NUdf::TStringRef(std::get<3>(input[i]))); - - auto item = pgmBuilder.NewStruct(structType, - {{"time", time}, {"key", key}, {"sum", sum}, {"part", part}}); - items.push_back(std::move(item)); + for (size_t i = 0; i < input.size(); ++i) { + const auto& [time, key, sum, part] = input[i]; + items.push_back(pgmBuilder.NewStruct({ + {"time", pgmBuilder.NewDataLiteral(time)}, + {"key", pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(key)}, + {"sum", pgmBuilder.NewDataLiteral(sum)}, + {"part", pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(part)}, + })); } const auto list = pgmBuilder.NewList(structType, std::move(items)); auto inputFlow = pgmBuilder.ToFlow(list); - - TVector<TStringBuf> partitionColumns; - TVector<std::pair<TStringBuf, TProgramBuilder::TBinaryLambda>> getMeasures = {{ - std::make_pair( - TStringBuf("key"), - [&](TRuntimeNode /*measureInputDataArg*/, TRuntimeNode /*matchedVarsArg*/) { - return pgmBuilder.NewDataLiteral<ui32>(56); - } - )}}; - TVector<std::pair<TStringBuf, TProgramBuilder::TTernaryLambda>> getDefines = {{ - std::make_pair( - TStringBuf("A"), - [&](TRuntimeNode /*inputDataArg*/, TRuntimeNode /*matchedVarsArg*/, TRuntimeNode /*currentRowIndexArg*/) { - return pgmBuilder.NewDataLiteral<bool>(true); - } - )}}; - auto pgmReturn = pgmBuilder.MatchRecognizeCore( inputFlow, [&](TRuntimeNode item) { - return pgmBuilder.Member(item, "part"); + return pgmBuilder.NewTuple({pgmBuilder.Member(item, "part")}); }, - partitionColumns, - getMeasures, + {}, + {"key"sv}, + {[&](TRuntimeNode /*measureInputDataArg*/, TRuntimeNode /*matchedVarsArg*/) { + return pgmBuilder.NewDataLiteral<ui32>(56); + }}, { {NYql::NMatchRecognize::TRowPatternFactor{"A", 3, 3, false, false, false}} }, - getDefines, + {"A"sv}, + {[&](TRuntimeNode /*inputDataArg*/, TRuntimeNode /*matchedVarsArg*/, TRuntimeNode /*currentRowIndexArg*/) { + return pgmBuilder.NewDataLiteral<bool>(true); + }}, streamingMode, - {NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""} + {NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""}, + NYql::NMatchRecognize::ERowsPerMatch::OneRow ); auto graph = setup.BuildGraph(pgmReturn); diff --git a/yql/essentials/minikql/comp_nodes/ya.make.inc b/yql/essentials/minikql/comp_nodes/ya.make.inc index 518f5f3c43..b2f0da8ac9 100644 --- a/yql/essentials/minikql/comp_nodes/ya.make.inc +++ b/yql/essentials/minikql/comp_nodes/ya.make.inc @@ -79,6 +79,7 @@ SET(ORIG_SOURCES mkql_mapnext.cpp mkql_map_join.cpp mkql_match_recognize.cpp + mkql_match_recognize_rows_formatter.cpp mkql_multihopping.cpp mkql_multimap.cpp mkql_next_value.cpp diff --git a/yql/essentials/minikql/mkql_program_builder.cpp b/yql/essentials/minikql/mkql_program_builder.cpp index 691a642814..0b9197024d 100644 --- a/yql/essentials/minikql/mkql_program_builder.cpp +++ b/yql/essentials/minikql/mkql_program_builder.cpp @@ -1,5 +1,4 @@ #include "mkql_program_builder.h" -#include "mkql_opt_literal.h" #include "mkql_node_visitor.h" #include "mkql_node_cast.h" #include "mkql_runtime_version.h" @@ -11,6 +10,7 @@ #include "yql/essentials/core/sql_types/time_order_recover.h" #include <yql/essentials/parser/pg_catalog/catalog.h> +#include <util/generic/overloaded.h> #include <util/string/cast.h> #include <util/string/printf.h> #include <array> @@ -6035,15 +6035,19 @@ TRuntimeNode PatternToRuntimeNode(const TRowPattern& pattern, const TProgramBuil TTupleLiteralBuilder termBuilder(env); for (const auto& factor: term) { TTupleLiteralBuilder factorBuilder(env); - factorBuilder.Add(factor.Primary.index() == 0 ? - programBuilder.NewDataLiteral<NUdf::EDataSlot::String>(std::get<0>(factor.Primary)) : - PatternToRuntimeNode(std::get<1>(factor.Primary), programBuilder) - ); - factorBuilder.Add(programBuilder.NewDataLiteral<ui64>(factor.QuantityMin)); - factorBuilder.Add(programBuilder.NewDataLiteral<ui64>(factor.QuantityMax)); - factorBuilder.Add(programBuilder.NewDataLiteral<bool>(factor.Greedy)); - factorBuilder.Add(programBuilder.NewDataLiteral<bool>(factor.Output)); - factorBuilder.Add(programBuilder.NewDataLiteral<bool>(factor.Unused)); + factorBuilder.Add(std::visit(TOverloaded { + [&](const TString& s) { + return programBuilder.NewDataLiteral<NUdf::EDataSlot::String>(s); + }, + [&](const TRowPattern& pattern) { + return PatternToRuntimeNode(pattern, programBuilder); + }, + }, factor.Primary)); + factorBuilder.Add(programBuilder.NewDataLiteral(factor.QuantityMin)); + factorBuilder.Add(programBuilder.NewDataLiteral(factor.QuantityMax)); + factorBuilder.Add(programBuilder.NewDataLiteral(factor.Greedy)); + factorBuilder.Add(programBuilder.NewDataLiteral(factor.Output)); + factorBuilder.Add(programBuilder.NewDataLiteral(factor.Unused)); termBuilder.Add({factorBuilder.Build(), true}); } patternBuilder.Add({termBuilder.Build(), true}); @@ -6056,151 +6060,172 @@ TRuntimeNode PatternToRuntimeNode(const TRowPattern& pattern, const TProgramBuil TRuntimeNode TProgramBuilder::MatchRecognizeCore( TRuntimeNode inputStream, const TUnaryLambda& getPartitionKeySelectorNode, - const TArrayRef<TStringBuf>& partitionColumns, - const TArrayRef<std::pair<TStringBuf, TBinaryLambda>>& getMeasures, + const TArrayRef<TStringBuf>& partitionColumnNames, + const TVector<TStringBuf>& measureColumnNames, + const TVector<TBinaryLambda>& getMeasures, const NYql::NMatchRecognize::TRowPattern& pattern, - const TArrayRef<std::pair<TStringBuf, TTernaryLambda>>& getDefines, + const TVector<TStringBuf>& defineVarNames, + const TVector<TTernaryLambda>& getDefines, bool streamingMode, - const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo + const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo, + NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch ) { MKQL_ENSURE(RuntimeVersion >= 42, "MatchRecognize is not supported in runtime version " << RuntimeVersion); const auto inputRowType = AS_TYPE(TStructType, AS_TYPE(TFlowType, inputStream.GetStaticType())->GetItemType()); const auto inputRowArg = Arg(inputRowType); const auto partitionKeySelectorNode = getPartitionKeySelectorNode(inputRowArg); + const auto partitionColumnTypes = AS_TYPE(TTupleType, partitionKeySelectorNode.GetStaticType())->GetElements(); - TStructTypeBuilder indexRangeTypeBuilder(Env); - indexRangeTypeBuilder.Add("From", TDataType::Create(NUdf::TDataType<ui64>::Id, Env)); - indexRangeTypeBuilder.Add("To", TDataType::Create(NUdf::TDataType<ui64>::Id, Env)); - const auto& rangeList = TListType::Create(indexRangeTypeBuilder.Build(), Env); + const auto rangeList = NewListType(NewStructType({ + {"From", NewDataType(NUdf::EDataSlot::Uint64)}, + {"To", NewDataType(NUdf::EDataSlot::Uint64)} + })); TStructTypeBuilder matchedVarsTypeBuilder(Env); for (const auto& var: GetPatternVars(pattern)) { matchedVarsTypeBuilder.Add(var, rangeList); } - TRuntimeNode matchedVarsArg = Arg(matchedVarsTypeBuilder.Build()); + const auto matchedVarsType = matchedVarsTypeBuilder.Build(); + TRuntimeNode matchedVarsArg = Arg(matchedVarsType); //---These vars may be empty in case of no measures TRuntimeNode measureInputDataArg; std::vector<TRuntimeNode> specialColumnIndexesInMeasureInputDataRow; TVector<TRuntimeNode> measures; - TVector<TType*> measureTypes; //--- if (getMeasures.empty()) { measureInputDataArg = Arg(Env.GetTypeOfVoidLazy()); } else { - using NYql::NMatchRecognize::EMeasureInputDataSpecialColumns; measures.reserve(getMeasures.size()); - measureTypes.reserve(getMeasures.size()); specialColumnIndexesInMeasureInputDataRow.resize(static_cast<size_t>(NYql::NMatchRecognize::EMeasureInputDataSpecialColumns::Last)); TStructTypeBuilder measureInputDataRowTypeBuilder(Env); - for (ui32 i = 0; i != inputRowType->GetMembersCount(); ++i) { + for (ui32 i = 0; i < inputRowType->GetMembersCount(); ++i) { measureInputDataRowTypeBuilder.Add(inputRowType->GetMemberName(i), inputRowType->GetMemberType(i)); } measureInputDataRowTypeBuilder.Add( MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::Classifier), - TDataType::Create(NUdf::TDataType<NYql::NUdf::TUtf8>::Id, Env) + NewDataType(NUdf::EDataSlot::Utf8) ); measureInputDataRowTypeBuilder.Add( MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::MatchNumber), - TDataType::Create(NUdf::TDataType<ui64>::Id, Env) + NewDataType(NUdf::EDataSlot::Uint64) ); const auto measureInputDataRowType = measureInputDataRowTypeBuilder.Build(); - for (ui32 i = 0; i != measureInputDataRowType->GetMembersCount(); ++i) { + for (ui32 i = 0; i < measureInputDataRowType->GetMembersCount(); ++i) { //assume a few, if grows, it's better to use a lookup table here static_assert(static_cast<size_t>(EMeasureInputDataSpecialColumns::Last) < 5); for (size_t j = 0; j != static_cast<size_t>(EMeasureInputDataSpecialColumns::Last); ++j) { if (measureInputDataRowType->GetMemberName(i) == NYql::NMatchRecognize::MeasureInputDataSpecialColumnName(static_cast<EMeasureInputDataSpecialColumns>(j))) - specialColumnIndexesInMeasureInputDataRow[j] = NewDataLiteral<ui32>(i); + specialColumnIndexesInMeasureInputDataRow[j] = NewDataLiteral(i); } } - measureInputDataArg = Arg(TListType::Create(measureInputDataRowType, Env)); + measureInputDataArg = Arg(NewListType(measureInputDataRowType)); for (size_t i = 0; i != getMeasures.size(); ++i) { - measures.push_back(getMeasures[i].second(measureInputDataArg, matchedVarsArg)); - measureTypes.push_back(measures[i].GetStaticType()); + measures.push_back(getMeasures[i](measureInputDataArg, matchedVarsArg)); } } TStructTypeBuilder outputRowTypeBuilder(Env); THashMap<TStringBuf, size_t> partitionColumnLookup; - for (size_t i = 0; i != partitionColumns.size(); ++i) { - const auto& name = partitionColumns[i]; - partitionColumnLookup[name] = i; - outputRowTypeBuilder.Add( - name, - AS_TYPE(TTupleType, partitionKeySelectorNode.GetStaticType())->GetElementType(i) - ); - } THashMap<TStringBuf, size_t> measureColumnLookup; - for (size_t i = 0; i != measures.size(); ++i) { - const auto& name = getMeasures[i].first; - measureColumnLookup[name] = i; - outputRowTypeBuilder.Add( - name, - measures[i].GetStaticType() - ); + THashMap<TStringBuf, size_t> otherColumnLookup; + for (size_t i = 0; i < measureColumnNames.size(); ++i) { + const auto name = measureColumnNames[i]; + measureColumnLookup.emplace(name, i); + outputRowTypeBuilder.Add(name, measures[i].GetStaticType()); + } + switch (rowsPerMatch) { + case NYql::NMatchRecognize::ERowsPerMatch::OneRow: + for (size_t i = 0; i < partitionColumnNames.size(); ++i) { + const auto name = partitionColumnNames[i]; + partitionColumnLookup.emplace(name, i); + outputRowTypeBuilder.Add(name, partitionColumnTypes[i]); + } + break; + case NYql::NMatchRecognize::ERowsPerMatch::AllRows: + for (size_t i = 0; i < inputRowType->GetMembersCount(); ++i) { + const auto name = inputRowType->GetMemberName(i); + otherColumnLookup.emplace(name, i); + outputRowTypeBuilder.Add(name, inputRowType->GetMemberType(i)); + } + break; } auto outputRowType = outputRowTypeBuilder.Build(); std::vector<TRuntimeNode> partitionColumnIndexes(partitionColumnLookup.size()); std::vector<TRuntimeNode> measureColumnIndexes(measureColumnLookup.size()); - for (ui32 i = 0; i != outputRowType->GetMembersCount(); ++i) { - if (auto it = partitionColumnLookup.find(outputRowType->GetMemberName(i)); it != partitionColumnLookup.end()) { - partitionColumnIndexes[it->second] = NewDataLiteral<ui32>(i); - } - else if (auto it = measureColumnLookup.find(outputRowType->GetMemberName(i)); it != measureColumnLookup.end()) { - measureColumnIndexes[it->second] = NewDataLiteral<ui32>(i); + TVector<TRuntimeNode> outputColumnOrder(NDetail::TReserveTag{outputRowType->GetMembersCount()}); + for (ui32 i = 0; i < outputRowType->GetMembersCount(); ++i) { + const auto name = outputRowType->GetMemberName(i); + if (auto iter = partitionColumnLookup.find(name); + iter != partitionColumnLookup.end()) { + partitionColumnIndexes[iter->second] = NewDataLiteral(i); + outputColumnOrder.push_back(NewStruct({ + std::pair{"Index", NewDataLiteral(iter->second)}, + std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::PartitionKey))}, + })); + } else if (auto iter = measureColumnLookup.find(name); + iter != measureColumnLookup.end()) { + measureColumnIndexes[iter->second] = NewDataLiteral(i); + outputColumnOrder.push_back(NewStruct({ + std::pair{"Index", NewDataLiteral(iter->second)}, + std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::Measure))}, + })); + } else if (auto iter = otherColumnLookup.find(name); + iter != otherColumnLookup.end()) { + outputColumnOrder.push_back(NewStruct({ + std::pair{"Index", NewDataLiteral(iter->second)}, + std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::Other))}, + })); } } - auto outputType = (TType*)TFlowType::Create(outputRowType, Env); + const auto outputType = NewFlowType(outputRowType); - THashMap<TStringBuf , size_t> patternVarLookup; - for (ui32 i = 0; i != AS_TYPE(TStructType, matchedVarsArg.GetStaticType())->GetMembersCount(); ++i){ - patternVarLookup[AS_TYPE(TStructType, matchedVarsArg.GetStaticType())->GetMemberName(i)] = i; + THashMap<TStringBuf, size_t> patternVarLookup; + for (ui32 i = 0; i < matchedVarsType->GetMembersCount(); ++i) { + patternVarLookup[matchedVarsType->GetMemberName(i)] = i; } - THashMap<TStringBuf , size_t> defineLookup; - for (size_t i = 0; i != getDefines.size(); ++i) { - defineLookup[getDefines[i].first] = i; + THashMap<TStringBuf, size_t> defineLookup; + for (size_t i = 0; i < defineVarNames.size(); ++i) { + const auto name = defineVarNames[i]; + defineLookup[name] = i; } - std::vector<TRuntimeNode> defineNames(patternVarLookup.size()); - std::vector<TRuntimeNode> defineNodes(patternVarLookup.size()); - const auto& inputDataArg = Arg(TListType::Create(inputRowType, Env)); - const auto& currentRowIndexArg = Arg(TDataType::Create(NUdf::TDataType<ui64>::Id, Env)); + TVector<TRuntimeNode> defineNames(defineVarNames.size()); + TVector<TRuntimeNode> defineNodes(patternVarLookup.size()); + const auto inputDataArg = Arg(NewListType(inputRowType)); + const auto currentRowIndexArg = Arg(NewDataType(NUdf::EDataSlot::Uint64)); for (const auto& [v, i]: patternVarLookup) { defineNames[i] = NewDataLiteral<NUdf::EDataSlot::String>(v); - if (const auto it = defineLookup.find(v); it != defineLookup.end()) { - defineNodes[i] = getDefines[it->second].second(inputDataArg, matchedVarsArg, currentRowIndexArg); - } - else { //no predicate for var - if ("$" == v || "^" == v) { - //DO nothing, //will be handled in a specific way - } - else { // a var without a predicate matches any row - defineNodes[i] = NewDataLiteral<bool>(true); - } + if (auto iter = defineLookup.find(v); + iter != defineLookup.end()) { + defineNodes[i] = getDefines[iter->second](inputDataArg, matchedVarsArg, currentRowIndexArg); + } else if ("$" == v || "^" == v) { + //DO nothing, //will be handled in a specific way + } else { // a var without a predicate matches any row + defineNodes[i] = NewDataLiteral(true); } } TCallableBuilder callableBuilder(GetTypeEnvironment(), "MatchRecognizeCore", outputType); - auto indexType = TDataType::Create(NUdf::TDataType<ui32>::Id, Env); - auto indexListType = TListType::Create(indexType, Env); + const auto indexType = NewDataType(NUdf::EDataSlot::Uint32); + const auto outputColumnEntryType = NewStructType({ + {"Index", NewDataType(NUdf::EDataSlot::Uint64)}, + {"SourceType", NewDataType(NUdf::EDataSlot::Int32)}, + }); callableBuilder.Add(inputStream); callableBuilder.Add(inputRowArg); callableBuilder.Add(partitionKeySelectorNode); - callableBuilder.Add(TRuntimeNode(TListLiteral::Create(partitionColumnIndexes.data(), partitionColumnIndexes.size(), indexListType, Env), true)); + callableBuilder.Add(NewList(indexType, partitionColumnIndexes)); callableBuilder.Add(measureInputDataArg); - callableBuilder.Add(TRuntimeNode(TListLiteral::Create( - specialColumnIndexesInMeasureInputDataRow.data(), specialColumnIndexesInMeasureInputDataRow.size(), - indexListType, Env - ), - true)); - callableBuilder.Add(NewDataLiteral<ui32>(inputRowType->GetMembersCount())); + callableBuilder.Add(NewList(indexType, specialColumnIndexesInMeasureInputDataRow)); + callableBuilder.Add(NewDataLiteral(inputRowType->GetMembersCount())); callableBuilder.Add(matchedVarsArg); - callableBuilder.Add(TRuntimeNode(TListLiteral::Create(measureColumnIndexes.data(), measureColumnIndexes.size(), indexListType, Env), true)); + callableBuilder.Add(NewList(indexType, measureColumnIndexes)); for (const auto& m: measures) { callableBuilder.Add(m); } @@ -6209,16 +6234,19 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore( callableBuilder.Add(currentRowIndexArg); callableBuilder.Add(inputDataArg); - const auto stringType = NewDataType(NUdf::EDataSlot::String); - callableBuilder.Add(TRuntimeNode(TListLiteral::Create(defineNames.begin(), defineNames.size(), TListType::Create(stringType, Env), Env), true)); + callableBuilder.Add(NewList(NewDataType(NUdf::EDataSlot::String), defineNames)); for (const auto& d: defineNodes) { callableBuilder.Add(d); } callableBuilder.Add(NewDataLiteral(streamingMode)); - if (RuntimeVersion >= 52U) { + if constexpr (RuntimeVersion >= 52U) { callableBuilder.Add(NewDataLiteral(static_cast<i32>(skipTo.To))); callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(skipTo.Var)); } + if constexpr (RuntimeVersion >= 53U) { + callableBuilder.Add(NewDataLiteral(static_cast<i32>(rowsPerMatch))); + callableBuilder.Add(NewList(outputColumnEntryType, outputColumnOrder)); + } return TRuntimeNode(callableBuilder.Build(), false); } diff --git a/yql/essentials/minikql/mkql_program_builder.h b/yql/essentials/minikql/mkql_program_builder.h index 762d3a013e..9e6b0d97d0 100644 --- a/yql/essentials/minikql/mkql_program_builder.h +++ b/yql/essentials/minikql/mkql_program_builder.h @@ -712,12 +712,15 @@ public: TRuntimeNode MatchRecognizeCore( TRuntimeNode inputStream, const TUnaryLambda& getPartitionKeySelectorNode, - const TArrayRef<TStringBuf>& partitionColumns, - const TArrayRef<std::pair<TStringBuf, TBinaryLambda>>& getMeasures, + const TArrayRef<TStringBuf>& partitionColumnNames, + const TVector<TStringBuf>& measureColumnNames, + const TVector<TBinaryLambda>& getMeasures, const NYql::NMatchRecognize::TRowPattern& pattern, - const TArrayRef<std::pair<TStringBuf, TTernaryLambda>>& getDefines, + const TVector<TStringBuf>& defineVarNames, + const TVector<TTernaryLambda>& getDefines, bool streamingMode, - const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo + const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo, + NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch ); TRuntimeNode TimeOrderRecover( diff --git a/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp b/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp index fb2a2a29c1..211fc6c84c 100644 --- a/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp +++ b/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp @@ -878,17 +878,17 @@ TMkqlCommonCallableCompiler::TShared::TShared() { const auto& settings = node.Child(4); //explore params - const auto& measures = params->ChildRef(0); - const auto& skipTo = params->ChildRef(2); - const auto& pattern = params->ChildRef(3); - const auto& defines = params->ChildRef(4); + const auto measures = params->Child(0); + const auto skipTo = params->Child(2); + const auto pattern = params->Child(3); + const auto defines = params->Child(4); //explore measures - const auto measureNames = measures->ChildRef(2); + const auto measureNames = measures->Child(2); constexpr size_t FirstMeasureLambdaIndex = 3; //explore defines - const auto defineNames = defines->ChildRef(2); + const auto defineNames = defines->Child(2); const size_t FirstDefineLambdaIndex = 3; TVector<TStringBuf> partitionColumnNames; @@ -900,25 +900,21 @@ TMkqlCommonCallableCompiler::TShared::TShared() { return MkqlBuildLambda(*partitionKeySelector, ctx, {inputRowArg}); }; - TVector<std::pair<TStringBuf, TProgramBuilder::TTernaryLambda>> getDefines(defineNames->ChildrenSize()); + TVector<TStringBuf> defineVarNames(defineNames->ChildrenSize()); + TVector<TProgramBuilder::TTernaryLambda> getDefines(defineNames->ChildrenSize()); for (size_t i = 0; i != defineNames->ChildrenSize(); ++i) { - getDefines[i] = std::pair{ - defineNames->ChildRef(i)->Content(), - [i, defines, &ctx](TRuntimeNode data, TRuntimeNode matchedVars, TRuntimeNode rowIndex) { - return MkqlBuildLambda(*defines->ChildRef(FirstDefineLambdaIndex + i), ctx, - {data, matchedVars, rowIndex}); - } + defineVarNames[i] = defineNames->Child(i)->Content(); + getDefines[i] = [i, defines, &ctx](TRuntimeNode data, TRuntimeNode matchedVars, TRuntimeNode rowIndex) { + return MkqlBuildLambda(*defines->Child(FirstDefineLambdaIndex + i), ctx, {data, matchedVars, rowIndex}); }; } - TVector<std::pair<TStringBuf, TProgramBuilder::TBinaryLambda>> getMeasures(measureNames->ChildrenSize()); + TVector<TStringBuf> measureColumnNames(measureNames->ChildrenSize()); + TVector<TProgramBuilder::TBinaryLambda> getMeasures(measureNames->ChildrenSize()); for (size_t i = 0; i != measureNames->ChildrenSize(); ++i) { - getMeasures[i] = std::pair{ - measureNames->ChildRef(i)->Content(), - [i, measures, &ctx](TRuntimeNode data, TRuntimeNode matchedVars) { - return MkqlBuildLambda(*measures->ChildRef(FirstMeasureLambdaIndex + i), ctx, - {data, matchedVars}); - } + measureColumnNames[i] = measureNames->Child(i)->Content(); + getMeasures[i] = [i, measures, &ctx](TRuntimeNode data, TRuntimeNode matchedVars) { + return MkqlBuildLambda(*measures->Child(FirstMeasureLambdaIndex + i), ctx, {data, matchedVars}); }; } @@ -928,18 +924,26 @@ TMkqlCommonCallableCompiler::TShared::TShared() { NYql::NMatchRecognize::EAfterMatchSkipTo to; MKQL_ENSURE(TryFromString<NYql::NMatchRecognize::EAfterMatchSkipTo>(stringTo, to), "MATCH_RECOGNIZE: <row pattern skip to> cannot parse AfterMatchSkipTo mode"); + auto rowsPerMatchString = params->Child(1)->Content(); + MKQL_ENSURE(rowsPerMatchString.SkipPrefix("RowsPerMatch_"), R"(MATCH_RECOGNIZE: <row pattern rows per match> should start with "RowsPerMatch_")"); + NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch; + MKQL_ENSURE(TryFromString<NYql::NMatchRecognize::ERowsPerMatch>(rowsPerMatchString, rowsPerMatch), "MATCH_RECOGNIZE: cannot parse RowsPerMatch mode"); + const auto streamingMode = FromString<bool>(settings->Child(0)->Child(1)->Content()); return ctx.ProgramBuilder.MatchRecognizeCore( MkqlBuildExpr(*inputStream, ctx), getPartitionKeySelector, partitionColumnNames, + measureColumnNames, getMeasures, NYql::NMatchRecognize::ConvertPattern(pattern, ctx.ExprCtx), + defineVarNames, getDefines, streamingMode, - NYql::NMatchRecognize::TAfterMatchSkipTo{to, TString{var}} - ); + NYql::NMatchRecognize::TAfterMatchSkipTo{to, TString{var}}, + rowsPerMatch + ); }); AddCallable("TimeOrderRecover", [](const TExprNode& node, TMkqlBuildContext& ctx) { diff --git a/yql/essentials/sql/v1/SQLv1.g.in b/yql/essentials/sql/v1/SQLv1.g.in index e9685c5094..670ad27e3e 100644 --- a/yql/essentials/sql/v1/SQLv1.g.in +++ b/yql/essentials/sql/v1/SQLv1.g.in @@ -460,7 +460,7 @@ row_pattern_primary: | DOLLAR | CARET | LPAREN row_pattern? RPAREN - | LBRACE_CURLY MINUS row_pattern MINUS RBRACE_CURLY //TODO This rule accepts spaces between brace and minus sign, i.e: { - S2 - } that is not supposed to. Handle this case in https://st.yandex-team.ru/YQL-16227 + | LBRACE_CURLY MINUS row_pattern MINUS RBRACE_CURLY | row_pattern_permute ; diff --git a/yql/essentials/sql/v1/SQLv1Antlr4.g.in b/yql/essentials/sql/v1/SQLv1Antlr4.g.in index 40593fe075..89131437e9 100644 --- a/yql/essentials/sql/v1/SQLv1Antlr4.g.in +++ b/yql/essentials/sql/v1/SQLv1Antlr4.g.in @@ -459,7 +459,7 @@ row_pattern_primary: | DOLLAR | CARET | LPAREN row_pattern? RPAREN - | LBRACE_CURLY MINUS row_pattern MINUS RBRACE_CURLY //TODO This rule accepts spaces between brace and minus sign, i.e: { - S2 - } that is not supposed to. Handle this case in https://st.yandex-team.ru/YQL-16227 + | LBRACE_CURLY MINUS row_pattern MINUS RBRACE_CURLY | row_pattern_permute ; diff --git a/yql/essentials/sql/v1/match_recognize.cpp b/yql/essentials/sql/v1/match_recognize.cpp index 47055e2f3d..84a20ae273 100644 --- a/yql/essentials/sql/v1/match_recognize.cpp +++ b/yql/essentials/sql/v1/match_recognize.cpp @@ -21,7 +21,7 @@ public: std::pair<TPosition, TVector<TNamedFunction>>&& partitioners, std::pair<TPosition, TVector<TSortSpecificationPtr>>&& sortSpecs, std::pair<TPosition, TVector<TNamedFunction>>&& measures, - std::pair<TPosition, ERowsPerMatch>&& rowsPerMatch, + std::pair<TPosition, NYql::NMatchRecognize::ERowsPerMatch>&& rowsPerMatch, std::pair<TPosition, NYql::NMatchRecognize::TAfterMatchSkipTo>&& skipTo, std::pair<TPosition, NYql::NMatchRecognize::TRowPattern>&& pattern, std::pair<TPosition, TNodePtr>&& subset, @@ -56,7 +56,7 @@ private: std::pair<TPosition, TVector<TNamedFunction>>&& partitioners, std::pair<TPosition, TVector<TSortSpecificationPtr>>&& sortSpecs, std::pair<TPosition, TVector<TNamedFunction>>&& measures, - std::pair<TPosition, ERowsPerMatch>&& rowsPerMatch, + std::pair<TPosition, NYql::NMatchRecognize::ERowsPerMatch>&& rowsPerMatch, std::pair<TPosition, NYql::NMatchRecognize::TAfterMatchSkipTo>&& skipTo, std::pair<TPosition, NYql::NMatchRecognize::TRowPattern>&& pattern, std::pair<TPosition, TNodePtr>&& subset, diff --git a/yql/essentials/sql/v1/match_recognize.h b/yql/essentials/sql/v1/match_recognize.h index b78c0faf65..4b0e98b9b7 100644 --- a/yql/essentials/sql/v1/match_recognize.h +++ b/yql/essentials/sql/v1/match_recognize.h @@ -10,11 +10,6 @@ struct TNamedFunction { TString name; }; -enum class ERowsPerMatch { - OneRow, - AllRows -}; - class TMatchRecognizeBuilder: public TSimpleRefCount<TMatchRecognizeBuilder> { public: TMatchRecognizeBuilder( @@ -22,7 +17,7 @@ public: std::pair<TPosition, TVector<TNamedFunction>>&& partitioners, std::pair<TPosition, TVector<TSortSpecificationPtr>>&& sortSpecs, std::pair<TPosition, TVector<TNamedFunction>>&& measures, - std::pair<TPosition, ERowsPerMatch>&& rowsPerMatch, + std::pair<TPosition, NYql::NMatchRecognize::ERowsPerMatch>&& rowsPerMatch, std::pair<TPosition, NYql::NMatchRecognize::TAfterMatchSkipTo>&& skipTo, std::pair<TPosition, NYql::NMatchRecognize::TRowPattern>&& pattern, std::pair<TPosition, TNodePtr>&& subset, @@ -45,7 +40,7 @@ private: std::pair<TPosition, TVector<TNamedFunction>> Partitioners; std::pair<TPosition, TVector<TSortSpecificationPtr>> SortSpecs; std::pair<TPosition, TVector<TNamedFunction>> Measures; - std::pair<TPosition, ERowsPerMatch> RowsPerMatch; + std::pair<TPosition, NYql::NMatchRecognize::ERowsPerMatch> RowsPerMatch; std::pair<TPosition, NYql::NMatchRecognize::TAfterMatchSkipTo> SkipTo; std::pair<TPosition, NYql::NMatchRecognize::TRowPattern> Pattern; std::pair<TPosition, TNodePtr> Subset; diff --git a/yql/essentials/sql/v1/sql_match_recognize.cpp b/yql/essentials/sql/v1/sql_match_recognize.cpp index 47e001efbb..41415b7f23 100644 --- a/yql/essentials/sql/v1/sql_match_recognize.cpp +++ b/yql/essentials/sql/v1/sql_match_recognize.cpp @@ -53,15 +53,9 @@ TMatchRecognizeBuilderPtr TSqlMatchRecognizeClause::CreateBuilder(const NSQLv1Ge measures = ParseMeasures(measuresClause.GetRule_row_pattern_measure_list2()); } - TPosition rowsPerMatchPos = pos; - ERowsPerMatch rowsPerMatch = ERowsPerMatch::OneRow; + auto rowsPerMatch = std::pair {pos, NYql::NMatchRecognize::ERowsPerMatch::OneRow}; if (matchRecognizeClause.HasBlock6()) { - std::tie(rowsPerMatchPos, rowsPerMatch) = ParseRowsPerMatch(matchRecognizeClause.GetBlock6().GetRule_row_pattern_rows_per_match1()); - if (ERowsPerMatch::AllRows == rowsPerMatch) { - //https://st.yandex-team.ru/YQL-16213 - Ctx.Error(pos, TIssuesIds::CORE) << "ALL ROWS PER MATCH is not supported yet"; - return {}; - } + rowsPerMatch = ParseRowsPerMatch(matchRecognizeClause.GetBlock6().GetRule_row_pattern_rows_per_match1()); } const auto& commonSyntax = matchRecognizeClause.GetRule_row_pattern_common_syntax7(); @@ -126,7 +120,7 @@ TMatchRecognizeBuilderPtr TSqlMatchRecognizeClause::CreateBuilder(const NSQLv1Ge std::pair{partitionsPos, std::move(partitioners)}, std::pair{orderByPos, std::move(sortSpecs)}, std::pair{measuresPos, measures}, - std::pair{rowsPerMatchPos, rowsPerMatch}, + std::move(rowsPerMatch), std::move(skipTo), std::pair{patternPos, std::move(pattern)}, std::pair{subsetPos, std::move(subset)}, @@ -159,7 +153,6 @@ TNamedFunction TSqlMatchRecognizeClause::ParseOneMeasure(const TRule_row_pattern TColumnRefScope scope(Ctx, EColumnRefState::MatchRecognize); const auto& expr = TSqlExpression(Ctx, Mode).Build(node.GetRule_expr1()); const auto& name = Id(node.GetRule_an_id3(), *this); - //TODO https://st.yandex-team.ru/YQL-16186 //Each measure must be a lambda, that accepts 2 args: // - List<InputTableColumns + _yql_Classifier, _yql_MatchNumber> // - Struct that maps row pattern variables to ranges in the queue @@ -174,18 +167,18 @@ TVector<TNamedFunction> TSqlMatchRecognizeClause::ParseMeasures(const TRule_row_ return result; } -std::pair<TPosition, ERowsPerMatch> TSqlMatchRecognizeClause::ParseRowsPerMatch(const TRule_row_pattern_rows_per_match& rowsPerMatchClause) { +std::pair<TPosition, NYql::NMatchRecognize::ERowsPerMatch> TSqlMatchRecognizeClause::ParseRowsPerMatch(const TRule_row_pattern_rows_per_match& rowsPerMatchClause) { switch(rowsPerMatchClause.GetAltCase()) { case TRule_row_pattern_rows_per_match::kAltRowPatternRowsPerMatch1: return std::pair { TokenPosition(rowsPerMatchClause.GetAlt_row_pattern_rows_per_match1().GetToken1()), - ERowsPerMatch::OneRow + NYql::NMatchRecognize::ERowsPerMatch::OneRow }; case TRule_row_pattern_rows_per_match::kAltRowPatternRowsPerMatch2: return std::pair { TokenPosition(rowsPerMatchClause.GetAlt_row_pattern_rows_per_match2().GetToken1()), - ERowsPerMatch::AllRows + NYql::NMatchRecognize::ERowsPerMatch::AllRows }; case TRule_row_pattern_rows_per_match::ALT_NOT_SET: Y_ABORT("You should change implementation according to grammar changes"); @@ -233,13 +226,13 @@ std::pair<TPosition, NYql::NMatchRecognize::TAfterMatchSkipTo> TSqlMatchRecogniz } } -NYql::NMatchRecognize::TRowPatternTerm TSqlMatchRecognizeClause::ParsePatternTerm(const TRule_row_pattern_term& node){ +NYql::NMatchRecognize::TRowPatternTerm TSqlMatchRecognizeClause::ParsePatternTerm(const TRule_row_pattern_term& node, size_t patternNestingLevel, bool outputArg) { NYql::NMatchRecognize::TRowPatternTerm term; TPosition pos; for (const auto& factor: node.GetBlock1()) { const auto& primaryVar = factor.GetRule_row_pattern_factor1().GetRule_row_pattern_primary1(); NYql::NMatchRecognize::TRowPatternPrimary primary; - bool output = true; + bool output = outputArg; switch (primaryVar.GetAltCase()) { case TRule_row_pattern_primary::kAltRowPatternPrimary1: primary = PatternVar(primaryVar.GetAlt_row_pattern_primary1().GetRule_row_pattern_primary_variable_name1().GetRule_row_pattern_variable_name1(), *this); @@ -253,9 +246,8 @@ NYql::NMatchRecognize::TRowPatternTerm TSqlMatchRecognizeClause::ParsePatternTer Y_ENSURE("^" == std::get<0>(primary)); break; case TRule_row_pattern_primary::kAltRowPatternPrimary4: { - if (++PatternNestingLevel <= NYql::NMatchRecognize::MaxPatternNesting) { - primary = ParsePattern(primaryVar.GetAlt_row_pattern_primary4().GetBlock2().GetRule_row_pattern1()); - --PatternNestingLevel; + if (patternNestingLevel <= NYql::NMatchRecognize::MaxPatternNesting) { + primary = ParsePattern(primaryVar.GetAlt_row_pattern_primary4().GetBlock2().GetRule_row_pattern1(), patternNestingLevel + 1, output); } else { Ctx.Error(TokenPosition(primaryVar.GetAlt_row_pattern_primary4().GetToken1())) << "To big nesting level in the pattern"; @@ -265,15 +257,14 @@ NYql::NMatchRecognize::TRowPatternTerm TSqlMatchRecognizeClause::ParsePatternTer } case TRule_row_pattern_primary::kAltRowPatternPrimary5: output = false; - Ctx.Error(TokenPosition(primaryVar.GetAlt_row_pattern_primary4().GetToken1())) - << "ALL ROWS PER MATCH and {- -} are not supported yet"; //https://st.yandex-team.ru/YQL-16227 + primary = ParsePattern(primaryVar.GetAlt_row_pattern_primary5().GetRule_row_pattern3(), patternNestingLevel + 1, output); break; case TRule_row_pattern_primary::kAltRowPatternPrimary6: { std::vector<NYql::NMatchRecognize::TRowPatternPrimary> items{ParsePattern( - primaryVar.GetAlt_row_pattern_primary6().GetRule_row_pattern_permute1().GetRule_row_pattern3()) + primaryVar.GetAlt_row_pattern_primary6().GetRule_row_pattern_permute1().GetRule_row_pattern3(), patternNestingLevel + 1, output) }; for (const auto& p: primaryVar.GetAlt_row_pattern_primary6().GetRule_row_pattern_permute1().GetBlock4()) { - items.push_back(ParsePattern(p.GetRule_row_pattern2())); + items.push_back(ParsePattern(p.GetRule_row_pattern2(), patternNestingLevel + 1, output)); } //Permutations now is a syntactic sugar and converted to all possible alternatives if (items.size() > NYql::NMatchRecognize::MaxPermutedItems) { @@ -346,11 +337,11 @@ NYql::NMatchRecognize::TRowPatternTerm TSqlMatchRecognizeClause::ParsePatternTer return term; } -NYql::NMatchRecognize::TRowPattern TSqlMatchRecognizeClause::ParsePattern(const TRule_row_pattern& node){ +NYql::NMatchRecognize::TRowPattern TSqlMatchRecognizeClause::ParsePattern(const TRule_row_pattern& node, size_t patternNestingLevel, bool output){ TVector<NYql::NMatchRecognize::TRowPatternTerm> result; - result.push_back(ParsePatternTerm(node.GetRule_row_pattern_term1())); + result.push_back(ParsePatternTerm(node.GetRule_row_pattern_term1(), patternNestingLevel, output)); for (const auto& term: node.GetBlock2()) - result.push_back(ParsePatternTerm(term.GetRule_row_pattern_term2())); + result.push_back(ParsePatternTerm(term.GetRule_row_pattern_term2(), patternNestingLevel, output)); return result; } @@ -364,7 +355,6 @@ TNamedFunction TSqlMatchRecognizeClause::ParseOneDefinition(const TRule_row_patt TVector<TNamedFunction> TSqlMatchRecognizeClause::ParseDefinitions(const TRule_row_pattern_definition_list& node) { TVector<TNamedFunction> result { ParseOneDefinition(node.GetRule_row_pattern_definition1())}; for (const auto& d: node.GetBlock2()) { - //TODO https://st.yandex-team.ru/YQL-16186 //Each define must be a predicate lambda, that accepts 3 args: // - List<input table rows> // - A struct that maps row pattern variables to ranges in the queue diff --git a/yql/essentials/sql/v1/sql_match_recognize.h b/yql/essentials/sql/v1/sql_match_recognize.h index 6766acc953..219baeaa09 100644 --- a/yql/essentials/sql/v1/sql_match_recognize.h +++ b/yql/essentials/sql/v1/sql_match_recognize.h @@ -17,14 +17,12 @@ private: TVector<TNamedFunction> ParsePartitionBy(const TRule_window_partition_clause& partitionClause); TNamedFunction ParseOneMeasure(const TRule_row_pattern_measure_definition& node); TVector<TNamedFunction> ParseMeasures(const TRule_row_pattern_measure_list& node); - std::pair<TPosition, ERowsPerMatch> ParseRowsPerMatch(const TRule_row_pattern_rows_per_match& rowsPerMatchClause); + std::pair<TPosition, NYql::NMatchRecognize::ERowsPerMatch> ParseRowsPerMatch(const TRule_row_pattern_rows_per_match& rowsPerMatchClause); std::pair<TPosition, NYql::NMatchRecognize::TAfterMatchSkipTo> ParseAfterMatchSkipTo(const TRule_row_pattern_skip_to& skipToClause); - NYql::NMatchRecognize::TRowPatternTerm ParsePatternTerm(const TRule_row_pattern_term& node); - NYql::NMatchRecognize::TRowPattern ParsePattern(const TRule_row_pattern& node); + NYql::NMatchRecognize::TRowPatternTerm ParsePatternTerm(const TRule_row_pattern_term& node, size_t patternNestingLevel, bool output); + NYql::NMatchRecognize::TRowPattern ParsePattern(const TRule_row_pattern& node, size_t patternNestingLevel = 1, bool output = true); TNamedFunction ParseOneDefinition(const TRule_row_pattern_definition& node); TVector<TNamedFunction> ParseDefinitions(const TRule_row_pattern_definition_list& node); -private: - size_t PatternNestingLevel = 0; }; } // namespace NSQLTranslationV1 diff --git a/yql/essentials/sql/v1/sql_match_recognize_ut.cpp b/yql/essentials/sql/v1/sql_match_recognize_ut.cpp index 20c5e6ab7b..f591ef0647 100644 --- a/yql/essentials/sql/v1/sql_match_recognize_ut.cpp +++ b/yql/essentials/sql/v1/sql_match_recognize_ut.cpp @@ -183,7 +183,7 @@ FROM Input MATCH_RECOGNIZE( ) )"; auto r = MatchRecognizeSqlToYql(stmt); - UNIT_ASSERT(not r.IsOk()); ///https://st.yandex-team.ru/YQL-16213 + UNIT_ASSERT(r.IsOk()); } { //default const auto stmt = R"( diff --git a/yql/essentials/tests/sql/sql2yql/canondata/result.json b/yql/essentials/tests/sql/sql2yql/canondata/result.json index c4b7d336e6..0c46d1d209 100644 --- a/yql/essentials/tests/sql/sql2yql/canondata/result.json +++ b/yql/essentials/tests/sql/sql2yql/canondata/result.json @@ -11213,6 +11213,13 @@ "uri": "https://{canondata_backend}/1942173/99e88108149e222741552e7e6cddef041d6a2846/resource.tar.gz#test_sql2yql.test_match_recognize-alerts_without_order_/sql.yql" } ], + "test_sql2yql.test[match_recognize-all_rows_per_match]": [ + { + "checksum": "31a940b1bbcea146ae9383e278326c7a", + "size": 6729, + "uri": "https://{canondata_backend}/1889210/954e2f1656d98697ece5794c59acf75dd1d40612/resource.tar.gz#test_sql2yql.test_match_recognize-all_rows_per_match_/sql.yql" + } + ], "test_sql2yql.test[match_recognize-greedy_quantifiers]": [ { "checksum": "41e90a3a986f9b2a7a36a83b918667cc", @@ -28001,6 +28008,11 @@ "uri": "file://test_sql_format.test_match_recognize-alerts_without_order_/formatted.sql" } ], + "test_sql_format.test[match_recognize-all_rows_per_match]": [ + { + "uri": "file://test_sql_format.test_match_recognize-all_rows_per_match_/formatted.sql" + } + ], "test_sql_format.test[match_recognize-greedy_quantifiers]": [ { "uri": "file://test_sql_format.test_match_recognize-greedy_quantifiers_/formatted.sql" diff --git a/yql/essentials/tests/sql/sql2yql/canondata/test_sql_format.test_match_recognize-all_rows_per_match_/formatted.sql b/yql/essentials/tests/sql/sql2yql/canondata/test_sql_format.test_match_recognize-all_rows_per_match_/formatted.sql new file mode 100644 index 0000000000..edcc854cfa --- /dev/null +++ b/yql/essentials/tests/sql/sql2yql/canondata/test_sql_format.test_match_recognize-all_rows_per_match_/formatted.sql @@ -0,0 +1,52 @@ +PRAGMA FeatureR010 = "prototype"; + +$input = + SELECT + * + FROM + AS_TABLE([ + <|time: 0, value: 0|>, + <|time: 100, value: 1|>, + <|time: 200, value: 2|>, + <|time: 300, value: 3|>, + <|time: 400, value: 4|>, + <|time: 500, value: 5|>, + <|time: 600, value: 0|>, + <|time: 700, value: 1|>, + <|time: 800, value: 2|>, + <|time: 900, value: 3|>, + <|time: 1000, value: 4|>, + <|time: 1100, value: 5|>, + <|time: 1200, value: 0|>, + ]) +; + +SELECT + * +FROM + $input MATCH_RECOGNIZE ( + ORDER BY + CAST(time AS Timestamp) + MEASURES + FIRST(A.time) AS a_time, + FIRST(B.time) AS b_time, + LAST(C.time) AS c_time, + LAST(F.time) AS f_time + ALL ROWS PER MATCH + AFTER MATCH SKIP PAST LAST ROW + PATTERN (A B {- C -} D {- E -} F +) + DEFINE + A AS A.value == 0 + AND COALESCE(A.time - FIRST(A.time) <= 1000, TRUE), + B AS B.value == 1 + AND COALESCE(B.time - FIRST(A.time) <= 1000, TRUE), + C AS C.value == 2 + AND COALESCE(C.time - FIRST(A.time) <= 1000, TRUE), + D AS D.value == 3 + AND COALESCE(D.time - FIRST(A.time) <= 1000, TRUE), + E AS E.value == 4 + AND COALESCE(E.time - FIRST(A.time) <= 1000, TRUE), + F AS F.value == 5 + AND COALESCE(F.time - FIRST(A.time) <= 1000, TRUE) + ) +; diff --git a/yql/essentials/tests/sql/suites/match_recognize/all_rows_per_match.sql b/yql/essentials/tests/sql/suites/match_recognize/all_rows_per_match.sql new file mode 100644 index 0000000000..ce55e529fb --- /dev/null +++ b/yql/essentials/tests/sql/suites/match_recognize/all_rows_per_match.sql @@ -0,0 +1,48 @@ +PRAGMA FeatureR010="prototype"; + +$input = SELECT * FROM AS_TABLE([ + <|time: 0, value: 0|>, + <|time: 100, value: 1|>, + <|time: 200, value: 2|>, + <|time: 300, value: 3|>, + <|time: 400, value: 4|>, + <|time: 500, value: 5|>, + <|time: 600, value: 0|>, + <|time: 700, value: 1|>, + <|time: 800, value: 2|>, + <|time: 900, value: 3|>, + <|time: 1000, value: 4|>, + <|time: 1100, value: 5|>, + <|time: 1200, value: 0|>, +]); + +SELECT * FROM $input MATCH_RECOGNIZE( + ORDER BY CAST(time as Timestamp) + MEASURES + FIRST(A.time) AS a_time, + FIRST(B.time) AS b_time, + LAST(C.time) AS c_time, + LAST(F.time) AS f_time + ALL ROWS PER MATCH + AFTER MATCH SKIP PAST LAST ROW + PATTERN (A B {- C -} D {- E -} F+) + DEFINE + A AS + A.value = 0 AND + COALESCE(A.time - FIRST(A.time) <= 1000, TRUE), + B AS + B.value = 1 AND + COALESCE(B.time - FIRST(A.time) <= 1000, TRUE), + C AS + C.value = 2 AND + COALESCE(C.time - FIRST(A.time) <= 1000, TRUE), + D AS + D.value = 3 AND + COALESCE(D.time - FIRST(A.time) <= 1000, TRUE), + E AS + E.value = 4 AND + COALESCE(E.time - FIRST(A.time) <= 1000, TRUE), + F AS + F.value = 5 AND + COALESCE(F.time - FIRST(A.time) <= 1000, TRUE) +); |