aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvokayndzop <vokayndzop@yandex-team.com>2024-12-16 15:55:05 +0300
committervokayndzop <vokayndzop@yandex-team.com>2024-12-16 16:34:36 +0300
commitb1cde7dcb055fb6f3367e81fd0f57bd55b8bb93c (patch)
tree230bddb8bb4ce7d8290a16a4465ec98dbf513a5a
parent88e0ad5922cea1349ec1f8cbf133524cf865d696 (diff)
downloadydb-b1cde7dcb055fb6f3367e81fd0f57bd55b8bb93c.tar.gz
MR: support ALL ROWS PER MATCH
commit_hash:9e2ba38d0d523bb870f6dc76717a3bec5d8ffadc
-rw-r--r--yql/essentials/core/sql_types/match_recognize.h14
-rw-r--r--yql/essentials/core/type_ann/type_ann_match_recognize.cpp90
-rw-r--r--yql/essentials/core/yql_opt_match_recognize.cpp72
-rw-r--r--yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp207
-rw-r--r--yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h71
-rw-r--r--yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h2
-rw-r--r--yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h27
-rw-r--r--yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.cpp144
-rw-r--r--yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.h72
-rw-r--r--yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp63
-rw-r--r--yql/essentials/minikql/comp_nodes/ya.make.inc1
-rw-r--r--yql/essentials/minikql/mkql_program_builder.cpp198
-rw-r--r--yql/essentials/minikql/mkql_program_builder.h11
-rw-r--r--yql/essentials/providers/common/mkql/yql_provider_mkql.cpp48
-rw-r--r--yql/essentials/sql/v1/SQLv1.g.in2
-rw-r--r--yql/essentials/sql/v1/SQLv1Antlr4.g.in2
-rw-r--r--yql/essentials/sql/v1/match_recognize.cpp4
-rw-r--r--yql/essentials/sql/v1/match_recognize.h9
-rw-r--r--yql/essentials/sql/v1/sql_match_recognize.cpp42
-rw-r--r--yql/essentials/sql/v1/sql_match_recognize.h8
-rw-r--r--yql/essentials/sql/v1/sql_match_recognize_ut.cpp2
-rw-r--r--yql/essentials/tests/sql/sql2yql/canondata/result.json12
-rw-r--r--yql/essentials/tests/sql/sql2yql/canondata/test_sql_format.test_match_recognize-all_rows_per_match_/formatted.sql52
-rw-r--r--yql/essentials/tests/sql/suites/match_recognize/all_rows_per_match.sql48
24 files changed, 800 insertions, 401 deletions
diff --git a/yql/essentials/core/sql_types/match_recognize.h b/yql/essentials/core/sql_types/match_recognize.h
index 0c6105ad94..e142587890 100644
--- a/yql/essentials/core/sql_types/match_recognize.h
+++ b/yql/essentials/core/sql_types/match_recognize.h
@@ -23,6 +23,16 @@ struct TAfterMatchSkipTo {
[[nodiscard]] bool operator==(const TAfterMatchSkipTo&) const noexcept = default;
};
+enum class ERowsPerMatch {
+ OneRow,
+ AllRows
+};
+enum class EOutputColumnSource {
+ PartitionKey,
+ Measure,
+ Other,
+};
+
constexpr size_t MaxPatternNesting = 20; //Limit recursion for patterns
constexpr size_t MaxPermutedItems = 6;
@@ -47,8 +57,8 @@ using TRowPatternPrimary = std::variant<TString, TRowPattern>;
struct TRowPatternFactor {
TRowPatternPrimary Primary;
- uint64_t QuantityMin;
- uint64_t QuantityMax;
+ ui64 QuantityMin;
+ ui64 QuantityMax;
bool Greedy;
bool Output; //include in output with ALL ROW PER MATCH
bool Unused; // optimization flag; is true when the variable is not used in defines and measures
diff --git a/yql/essentials/core/type_ann/type_ann_match_recognize.cpp b/yql/essentials/core/type_ann/type_ann_match_recognize.cpp
index 58d24e7b1b..4a6d61192e 100644
--- a/yql/essentials/core/type_ann/type_ann_match_recognize.cpp
+++ b/yql/essentials/core/type_ann/type_ann_match_recognize.cpp
@@ -10,11 +10,11 @@ MatchRecognizeWrapper(const TExprNode::TPtr &input, TExprNode::TPtr &output, TCo
if (!EnsureArgsCount(*input, 5, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
- const auto& source = input->ChildRef(0);
+ const auto source = input->Child(0);
auto& partitionKeySelector = input->ChildRef(1);
- const auto& partitionColumns = input->ChildRef(2);
- const auto& sortTraits = input->ChildRef(3);
- const auto& params = input->ChildRef(4);
+ const auto partitionColumns = input->Child(2);
+ const auto sortTraits = input->Child(3);
+ const auto params = input->Child(4);
Y_UNUSED(source, sortTraits);
auto status = ConvertToLambda(partitionKeySelector, ctx.Expr, 1, 1);
if (status.Level != IGraphTransformer::TStatus::Ok) {
@@ -31,11 +31,22 @@ MatchRecognizeWrapper(const TExprNode::TPtr &input, TExprNode::TPtr &output, TCo
//merge measure columns, came from params, with partition columns to form output row type
auto outputTableColumns = params->GetTypeAnn()->Cast<TStructExprType>()->GetItems();
- for (size_t i = 0; i != partitionColumns->ChildrenSize(); ++i) {
- outputTableColumns.push_back(ctx.Expr.MakeType<TItemExprType>(
- partitionColumns->ChildRef(i)->Content(),
- partitionKeySelectorItemTypes[i]
- ));
+ if (const auto rowsPerMatch = params->Child(1);
+ "RowsPerMatch_OneRow" == rowsPerMatch->Content()) {
+ for (size_t i = 0; i != partitionColumns->ChildrenSize(); ++i) {
+ outputTableColumns.push_back(ctx.Expr.MakeType<TItemExprType>(
+ partitionColumns->Child(i)->Content(),
+ partitionKeySelectorItemTypes[i]
+ ));
+ }
+ } else if ("RowsPerMatch_AllRows" == rowsPerMatch->Content()) {
+ const auto& inputTableColumns = GetSeqItemType(source->GetTypeAnn())->Cast<TStructExprType>()->GetItems();
+ for (const auto& column : inputTableColumns) {
+ outputTableColumns.push_back(ctx.Expr.MakeType<TItemExprType>(column->GetName(), column->GetItemType()));
+ }
+ } else {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(rowsPerMatch->Pos()), "Unknown RowsPerMatch option"));
+ return IGraphTransformer::TStatus::Error;
}
const auto outputTableRowType = ctx.Expr.MakeType<TStructExprType>(outputTableColumns);
input->SetTypeAnn(ctx.Expr.MakeType<TListExprType>(outputTableRowType));
@@ -48,7 +59,7 @@ MatchRecognizeParamsWrapper(const TExprNode::TPtr &input, TExprNode::TPtr &outpu
if (!EnsureArgsCount(*input, 5, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
- const auto& measures = input->ChildRef(0);
+ const auto measures = input->Child(0);
input->SetTypeAnn(measures->GetTypeAnn());
return IGraphTransformer::TStatus::Ok;
}
@@ -77,10 +88,10 @@ MatchRecognizeMeasuresWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& out
if (!EnsureMinArgsCount(*input, 3, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
- const auto& inputRowType = input->ChildRef(0);
- const auto& pattern = input->ChildRef(1);
- const auto& names = input->ChildRef(2);
- const size_t FirstLambdaIndex = 3;
+ const auto inputRowType = input->Child(0);
+ const auto pattern = input->Child(1);
+ const auto names = input->Child(2);
+ constexpr size_t FirstLambdaIndex = 3;
if (!EnsureTupleOfAtoms(*names, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
@@ -119,7 +130,7 @@ MatchRecognizeMeasuresWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& out
return IGraphTransformer::TStatus::Error;
}
if (auto type = lambda->GetTypeAnn()) {
- items.push_back(ctx.Expr.MakeType<TItemExprType>(names->ChildRef(i)->Content(), type));
+ items.push_back(ctx.Expr.MakeType<TItemExprType>(names->Child(i)->Content(), type));
} else {
return IGraphTransformer::TStatus::Repeat;
}
@@ -143,9 +154,9 @@ MatchRecognizeDefinesWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& outp
if (!EnsureMinArgsCount(*input, 3, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
- const auto& inputRowType = input->ChildRef(0);
- const auto& pattern = input->ChildRef(1);
- const auto& names = input->ChildRef(2);
+ const auto inputRowType = input->Child(0);
+ const auto pattern = input->Child(1);
+ const auto names = input->Child(2);
const size_t FirstLambdaIndex = 3;
if (!EnsureTupleOfAtoms(*names, ctx.Expr)) {
@@ -176,7 +187,7 @@ MatchRecognizeDefinesWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& outp
}
if (auto type = lambda->GetTypeAnn()) {
if (IsBoolLike(*type)) {
- items.push_back(ctx.Expr.MakeType<TItemExprType>(names->ChildRef(i)->Content(), type));
+ items.push_back(ctx.Expr.MakeType<TItemExprType>(names->Child(i)->Content(), type));
} else {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(lambda->Pos()), "DEFINE expression must be a predicate"));
return IGraphTransformer::TStatus::Error;
@@ -200,18 +211,18 @@ bool ValidateSettings(const TExprNode::TPtr& settings, TExprContext& ctx) {
return false;
}
- const auto streamingMode = settings->ChildRef(0);
+ const auto streamingMode = settings->Child(0);
if (!EnsureTupleOfAtoms(*streamingMode, ctx)) {
return false;
}
if (!EnsureArgsCount(*streamingMode, 2, ctx)) {
return false;
}
- if (streamingMode->ChildRef(0)->Content() != "Streaming") {
+ if (streamingMode->Child(0)->Content() != "Streaming") {
ctx.AddError(TIssue(ctx.GetPosition(settings->Pos()), "Expected Streaming setting"));
return false;
}
- const auto mode = streamingMode->ChildRef(1)->Content();
+ const auto mode = streamingMode->Child(1)->Content();
if (mode != "0" and mode != "1") {
ctx.AddError(TIssue(ctx.GetPosition(settings->Pos()), TStringBuilder() << "Expected 0 or 1, but got: " << mode));
return false;
@@ -231,11 +242,11 @@ MatchRecognizeCoreWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output,
if (!EnsureArgsCount(*input, 5, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
- const auto& source = input->ChildRef(0);
+ const auto source = input->Child(0);
auto& partitionKeySelector = input->ChildRef(1);
- const auto& partitionColumns = input->ChildRef(2);
- const auto& params = input->ChildRef(3);
- const auto& settings = input->ChildRef(4);
+ const auto partitionColumns = input->Child(2);
+ const auto params = input->Child(3);
+ const auto settings = input->Child(4);
if (not params->IsCallable("MatchRecognizeParams")) {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(params->Pos()), "Expected MatchRecognizeParams"));
return IGraphTransformer::TStatus::Error;
@@ -248,9 +259,9 @@ MatchRecognizeCoreWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output,
if (!EnsureFlowType(*source, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
- const auto& inputRowType = GetSeqItemType(source->GetTypeAnn());
- const auto& define = params->ChildRef(4);
- if (not inputRowType->Equals(*define->ChildRef(0)->GetTypeAnn()->Cast<TTypeExprType>()->GetType())) {
+ const auto inputRowType = GetSeqItemType(source->GetTypeAnn());
+ const auto define = params->Child(4);
+ if (not inputRowType->Equals(*define->Child(0)->GetTypeAnn()->Cast<TTypeExprType>()->GetType())) {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "Expected the same input row type as for DEFINE"));
return IGraphTransformer::TStatus::Error;
}
@@ -279,11 +290,22 @@ MatchRecognizeCoreWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output,
}
auto outputTableColumns = params->GetTypeAnn()->Cast<TStructExprType>()->GetItems();
- for (size_t i = 0; i != partitionColumns->ChildrenSize(); ++i) {
- outputTableColumns.push_back(ctx.Expr.MakeType<TItemExprType>(
- partitionColumns->ChildRef(i)->Content(),
- partitionKeySelectorItemTypes[i]
- ));
+ if (const auto rowsPerMatch = params->Child(1);
+ "RowsPerMatch_OneRow" == rowsPerMatch->Content()) {
+ for (size_t i = 0; i != partitionColumns->ChildrenSize(); ++i) {
+ outputTableColumns.push_back(ctx.Expr.MakeType<TItemExprType>(
+ partitionColumns->Child(i)->Content(),
+ partitionKeySelectorItemTypes[i]
+ ));
+ }
+ } else if ("RowsPerMatch_AllRows" == rowsPerMatch->Content()) {
+ const auto& inputTableColumns = GetSeqItemType(source->GetTypeAnn())->Cast<TStructExprType>()->GetItems();
+ for (const auto& column : inputTableColumns) {
+ outputTableColumns.push_back(ctx.Expr.MakeType<TItemExprType>(column->GetName(), column->GetItemType()));
+ }
+ } else {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(rowsPerMatch->Pos()), "Unknown RowsPerMatch option"));
+ return IGraphTransformer::TStatus::Error;
}
const auto outputTableRowType = ctx.Expr.MakeType<TStructExprType>(outputTableColumns);
input->SetTypeAnn(ctx.Expr.MakeType<TFlowExprType>(outputTableRowType));
diff --git a/yql/essentials/core/yql_opt_match_recognize.cpp b/yql/essentials/core/yql_opt_match_recognize.cpp
index 8e40e400a1..f2bef7d7d4 100644
--- a/yql/essentials/core/yql_opt_match_recognize.cpp
+++ b/yql/essentials/core/yql_opt_match_recognize.cpp
@@ -22,8 +22,8 @@ bool IsStreaming(const TExprNode::TPtr& input, const TTypeAnnotationContext& typ
bool hasPq = false;
NYql::VisitExpr(input, [&hasPq](const TExprNode::TPtr& node){
if (node->IsCallable("DataSource")) {
- YQL_ENSURE(node->ChildrenSize() > 0 and node->ChildRef(0)->IsAtom());
- hasPq = node->ChildRef(0)->Content() == "pq";
+ YQL_ENSURE(node->ChildrenSize() > 0 and node->Child(0)->IsAtom());
+ hasPq = node->Child(0)->Content() == "pq";
}
return !hasPq;
});
@@ -39,8 +39,8 @@ std::optional<TSet<TStringBuf>> FindUsedVars(const TExprNode::TPtr& params) {
const auto createVisitor = [&usedVars, &allVarsUsed](const TExprNode::TPtr& varsArg) {
return [&varsArg, &usedVars, &allVarsUsed](const TExprNode::TPtr& node) -> bool {
if (node->IsCallable("Member")) {
- if (node->ChildRef(0) == varsArg) {
- usedVars.insert(node->ChildRef(1)->Content());
+ if (node->Child(0) == varsArg) {
+ usedVars.insert(node->Child(1)->Content());
return false;
}
}
@@ -51,23 +51,23 @@ std::optional<TSet<TStringBuf>> FindUsedVars(const TExprNode::TPtr& params) {
};
};
- const auto& measures = params->ChildRef(0);
+ const auto measures = params->Child(0);
static constexpr size_t measureLambdasStartPos = 3;
for (size_t pos = measureLambdasStartPos; pos != measures->ChildrenSize(); pos++) {
- const auto& lambda = measures->ChildRef(pos);
- const auto& lambdaArgs = lambda->ChildRef(0);
- const auto& lambdaBody = lambda->ChildRef(1);
- const auto& varsArg = lambdaArgs->ChildRef(1);
+ const auto lambda = measures->Child(pos);
+ const auto lambdaArgs = lambda->Child(0);
+ const auto lambdaBody = lambda->ChildPtr(1);
+ const auto varsArg = lambdaArgs->ChildPtr(1);
NYql::VisitExpr(lambdaBody, createVisitor(varsArg));
}
- const auto& defines = params->ChildRef(4);
+ const auto defines = params->Child(4);
static constexpr size_t defineLambdasStartPos = 3;
for (size_t pos = defineLambdasStartPos; pos != defines->ChildrenSize(); pos++) {
- const auto& lambda = defines->ChildRef(pos);
- const auto& lambdaArgs = lambda->ChildRef(0);
- const auto& lambdaBody = lambda->ChildRef(1);
- const auto& varsArg = lambdaArgs->ChildRef(1);
+ const auto lambda = defines->Child(pos);
+ const auto lambdaArgs = lambda->Child(0);
+ const auto lambdaBody = lambda->ChildPtr(1);
+ const auto varsArg = lambdaArgs->ChildPtr(1);
NYql::VisitExpr(lambdaBody, createVisitor(varsArg));
}
@@ -75,25 +75,26 @@ std::optional<TSet<TStringBuf>> FindUsedVars(const TExprNode::TPtr& params) {
}
// usedVars can be std::nullopt if all vars could probably be used
-TExprNode::TPtr MarkUnusedPatternVars(const TExprNode::TPtr& node, TExprContext& ctx, const std::optional<TSet<TStringBuf>> &usedVars) {
+TExprNode::TPtr MarkUnusedPatternVars(const TExprNode::TPtr& node, TExprContext& ctx, const std::optional<TSet<TStringBuf>> &usedVars, TStringBuf rowsPerMatch) {
const auto pos = node->Pos();
- if (node->ChildrenSize() != 0 && node->ChildRef(0)->IsAtom()) {
- const auto& varName = node->ChildRef(0)->Content();
- bool varUsed = !usedVars.has_value() || usedVars.value().contains(varName);
+ if (node->ChildrenSize() != 0 && node->Child(0)->IsAtom()) {
+ const auto varName = node->Child(0)->Content();
+ const auto output = node->Child(4);
+ const auto varUnused = ("RowsPerMatch_AllRows" != rowsPerMatch || !output) && usedVars && !usedVars->contains(varName);
return ctx.Builder(pos)
.List()
- .Add(0, node->ChildRef(0))
- .Add(1, node->ChildRef(1))
- .Add(2, node->ChildRef(2))
- .Add(3, node->ChildRef(3))
- .Add(4, node->ChildRef(4))
- .Add(5, ctx.NewAtom(pos, varUsed ? "0" : "1"))
+ .Add(0, node->ChildPtr(0))
+ .Add(1, node->ChildPtr(1))
+ .Add(2, node->ChildPtr(2))
+ .Add(3, node->ChildPtr(3))
+ .Add(4, output)
+ .Add(5, ctx.NewAtom(pos, ToString(varUnused)))
.Seal()
.Build();
}
TExprNodeList newChildren;
for (size_t chPos = 0; chPos != node->ChildrenSize(); chPos++) {
- newChildren.push_back(MarkUnusedPatternVars(node->ChildRef(chPos), ctx, usedVars));
+ newChildren.push_back(MarkUnusedPatternVars(node->ChildPtr(chPos), ctx, usedVars, rowsPerMatch));
}
if (node->IsCallable()) {
return ctx.Builder(pos).Callable(node->Content()).Add(std::move(newChildren)).Seal().Build();
@@ -106,11 +107,11 @@ TExprNode::TPtr MarkUnusedPatternVars(const TExprNode::TPtr& node, TExprContext&
TExprNode::TPtr ExpandMatchRecognize(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typeAnnCtx) {
YQL_ENSURE(node->IsCallable({"MatchRecognize"}));
- const auto& input = node->ChildRef(0);
- const auto& partitionKeySelector = node->ChildRef(1);
- const auto& partitionColumns = node->ChildRef(2);
- const auto& sortTraits = node->ChildRef(3);
- const auto& params = node->ChildRef(4);
+ const auto input = node->Child(0);
+ const auto partitionKeySelector = node->Child(1);
+ const auto partitionColumns = node->Child(2);
+ const auto sortTraits = node->Child(3);
+ const auto params = node->ChildPtr(4);
const auto pos = node->Pos();
const bool isStreaming = IsStreaming(input, typeAnnCtx);
@@ -118,6 +119,7 @@ TExprNode::TPtr ExpandMatchRecognize(const TExprNode::TPtr& node, TExprContext&
TExprNode::TPtr settings = AddSetting(*ctx.NewList(pos, {}), pos,
"Streaming", ctx.NewAtom(pos, ToString(isStreaming)), ctx);
+ const auto rowsPerMatch = params->Child(1)->Content();
const auto matchRecognize = ctx.Builder(pos)
.Lambda()
.Param("sortedPartition")
@@ -129,11 +131,11 @@ TExprNode::TPtr ExpandMatchRecognize(const TExprNode::TPtr& node, TExprContext&
.Add(1, partitionKeySelector)
.Add(2, partitionColumns)
.Callable(3, params->Content())
- .Add(0, params->ChildRef(0))
- .Add(1, params->ChildRef(1))
- .Add(2, params->ChildRef(2))
- .Add(3, MarkUnusedPatternVars(params->ChildRef(3), ctx, FindUsedVars(params)))
- .Add(4, params->ChildRef(4))
+ .Add(0, params->ChildPtr(0))
+ .Add(1, params->ChildPtr(1))
+ .Add(2, params->ChildPtr(2))
+ .Add(3, MarkUnusedPatternVars(params->ChildPtr(3), ctx, FindUsedVars(params), rowsPerMatch))
+ .Add(4, params->ChildPtr(4))
.Seal()
.Add(4, settings)
.Seal()
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp b/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp
index 80caaad37e..e1fc529fda 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp
+++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize.cpp
@@ -1,32 +1,28 @@
#include "mkql_match_recognize_list.h"
-#include "mkql_match_recognize_matched_vars.h"
#include "mkql_match_recognize_measure_arg.h"
+#include "mkql_match_recognize_matched_vars.h"
#include "mkql_match_recognize_nfa.h"
+#include "mkql_match_recognize_rows_formatter.h"
#include "mkql_match_recognize_save_load.h"
#include "mkql_saveload.h"
#include <yql/essentials/core/sql_types/match_recognize.h>
-#include <yql/essentials/minikql/computation/mkql_computation_node_impl.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
+#include <yql/essentials/minikql/computation/mkql_computation_node_impl.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_pack.h>
+#include <yql/essentials/minikql/mkql_node.h>
#include <yql/essentials/minikql/mkql_node_cast.h>
-#include <yql/essentials/minikql/mkql_runtime_version.h>
#include <yql/essentials/minikql/mkql_string_util.h>
-#include <yql/essentials/core/sql_types/match_recognize.h>
+
#include <deque>
namespace NKikimr::NMiniKQL {
namespace NMatchRecognize {
-enum class EOutputColumnSource {PartitionKey, Measure};
-using TOutputColumnOrder = std::vector<std::pair<EOutputColumnSource, size_t>, TMKQLAllocator<std::pair<EOutputColumnSource, size_t>>>;
-
constexpr ui32 StateVersion = 1;
-using namespace NYql::NMatchRecognize;
-
struct TMatchRecognizeProcessorParameters {
IComputationExternalNode* InputDataArg;
TRowPattern Pattern;
@@ -37,27 +33,21 @@ struct TMatchRecognizeProcessorParameters {
TComputationNodePtrVector Defines;
IComputationExternalNode* MeasureInputDataArg;
TMeasureInputColumnOrder MeasureInputColumnOrder;
- TComputationNodePtrVector Measures;
- TOutputColumnOrder OutputColumnOrder;
- TAfterMatchSkipTo SkipTo;
+ TAfterMatchSkipTo SkipTo;
};
class TStreamingMatchRecognize {
- using TPartitionList = TSparseList;
- using TRange = TPartitionList::TRange;
public:
TStreamingMatchRecognize(
NUdf::TUnboxedValue&& partitionKey,
const TMatchRecognizeProcessorParameters& parameters,
- TNfaTransitionGraph::TPtr nfaTransitions,
- const TContainerCacheOnContext& cache
- )
+ const IRowsFormatter::TState& rowsFormatterState,
+ TNfaTransitionGraph::TPtr nfaTransitions)
: PartitionKey(std::move(partitionKey))
, Parameters(parameters)
+ , RowsFormatter_(IRowsFormatter::Create(rowsFormatterState))
, Nfa(nfaTransitions, parameters.MatchedVarsArg, parameters.Defines, parameters.SkipTo)
- , Cache(cache)
- {
- }
+ {}
bool ProcessInputRow(NUdf::TUnboxedValue&& row, TComputationContext& ctx) {
Parameters.InputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TListValue<TSparseList>>(Rows));
@@ -71,33 +61,26 @@ public:
}
NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) {
+ if (auto result = RowsFormatter_->GetOtherMatchRow(ctx, Rows, PartitionKey, Nfa.GetTransitionGraph())) {
+ return result;
+ }
auto match = Nfa.GetMatched();
if (!match) {
return NUdf::TUnboxedValue{};
}
- Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, match->Vars));
+ Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TSparseList::TRange>>(ctx.HolderFactory, match->Vars));
Parameters.MeasureInputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TMeasureInputDataValue>(
ctx.HolderFactory.Create<TListValue<TSparseList>>(Rows),
Parameters.MeasureInputColumnOrder,
Parameters.MatchedVarsArg->GetValue(ctx),
Parameters.VarNames,
MatchNumber
- ));
- NUdf::TUnboxedValue *itemsPtr = nullptr;
- const auto result = Cache.NewArray(ctx, Parameters.OutputColumnOrder.size(), itemsPtr);
- for (auto const& c: Parameters.OutputColumnOrder) {
- switch(c.first) {
- case EOutputColumnSource::Measure:
- *itemsPtr++ = Parameters.Measures[c.second]->GetValue(ctx);
- break;
- case EOutputColumnSource::PartitionKey:
- *itemsPtr++ = PartitionKey.GetElement(c.second);
- break;
- }
- }
+ ));
+ auto result = RowsFormatter_->GetFirstMatchRow(ctx, Rows, PartitionKey, Nfa.GetTransitionGraph(), *match);
Nfa.AfterMatchSkip(*match);
return result;
}
+
bool ProcessEndOfData(TComputationContext& ctx) {
return Nfa.ProcessEndOfData(ctx);
}
@@ -117,11 +100,11 @@ public:
}
private:
- const NUdf::TUnboxedValue PartitionKey;
+ NUdf::TUnboxedValue PartitionKey;
const TMatchRecognizeProcessorParameters& Parameters;
+ std::unique_ptr<IRowsFormatter> RowsFormatter_;
TSparseList Rows;
TNfa Nfa;
- const TContainerCacheOnContext& Cache;
ui64 MatchNumber = 0;
};
@@ -135,7 +118,7 @@ public:
IComputationNode* partitionKey,
TType* partitionKeyType,
const TMatchRecognizeProcessorParameters& parameters,
- const TContainerCacheOnContext& cache,
+ const IRowsFormatter::TState& rowsFormatterState,
TComputationContext &ctx,
TType* rowType,
const TMutableObjectOverBoxedValue<TValuePackerBoxed>& rowPacker
@@ -145,8 +128,8 @@ public:
, PartitionKey(partitionKey)
, PartitionKeyPacker(true, partitionKeyType)
, Parameters(parameters)
+ , RowsFormatterState(rowsFormatterState)
, RowPatternConfiguration(TNfaTransitionGraphBuilder::Create(parameters.Pattern, parameters.VarNamesLookup))
- , Cache(cache)
, Terminating(false)
, SerializerContext(ctx, rowType, rowPacker)
, Ctx(ctx)
@@ -184,8 +167,8 @@ public:
PartitionHandler.reset(new TStreamingMatchRecognize(
std::move(key),
Parameters,
- RowPatternConfiguration,
- Cache
+ RowsFormatterState,
+ RowPatternConfiguration
));
PartitionHandler->Load(in);
}
@@ -250,8 +233,9 @@ public:
PartitionHandler.reset(new TStreamingMatchRecognize(
std::move(partitionKey),
Parameters,
- RowPatternConfiguration,
- Cache));
+ RowsFormatterState,
+ RowPatternConfiguration
+ ));
PartitionHandler->ProcessInputRow(std::move(temp), ctx);
}
if (Terminating) {
@@ -266,8 +250,8 @@ private:
IComputationNode* PartitionKey;
TValuePackerGeneric<false> PartitionKeyPacker;
const TMatchRecognizeProcessorParameters& Parameters;
+ const IRowsFormatter::TState& RowsFormatterState;
const TNfaTransitionGraph::TPtr RowPatternConfiguration;
- const TContainerCacheOnContext& Cache;
NUdf::TUnboxedValue DelayedRow;
bool Terminating;
TSerializerContext SerializerContext;
@@ -286,7 +270,7 @@ public:
IComputationNode* partitionKey,
TType* partitionKeyType,
const TMatchRecognizeProcessorParameters& parameters,
- const TContainerCacheOnContext& cache,
+ const IRowsFormatter::TState& rowsFormatterState,
TComputationContext &ctx,
TType* rowType,
const TMutableObjectOverBoxedValue<TValuePackerBoxed>& rowPacker
@@ -296,8 +280,8 @@ public:
, PartitionKey(partitionKey)
, PartitionKeyPacker(true, partitionKeyType)
, Parameters(parameters)
+ , RowsFormatterState(rowsFormatterState)
, NfaTransitionGraph(TNfaTransitionGraphBuilder::Create(parameters.Pattern, parameters.VarNamesLookup))
- , Cache(cache)
, SerializerContext(ctx, rowType, rowPacker)
, Ctx(ctx)
{}
@@ -335,8 +319,10 @@ public:
std::make_unique<TStreamingMatchRecognize>(
std::move(key),
Parameters,
- NfaTransitionGraph,
- Cache));
+ RowsFormatterState,
+ NfaTransitionGraph
+ )
+ );
pair.first->second->Load(in);
}
@@ -402,8 +388,8 @@ private:
return Partitions.emplace_hint(it, TString(packedKey), std::make_unique<TStreamingMatchRecognize>(
std::move(partitionKey),
Parameters,
- NfaTransitionGraph,
- Cache
+ RowsFormatterState,
+ NfaTransitionGraph
));
}
}
@@ -418,8 +404,8 @@ private:
//TODO switch to tuple compare
TValuePackerGeneric<false> PartitionKeyPacker;
const TMatchRecognizeProcessorParameters& Parameters;
+ const IRowsFormatter::TState& RowsFormatterState;
const TNfaTransitionGraph::TPtr NfaTransitionGraph;
- const TContainerCacheOnContext& Cache;
TSerializerContext SerializerContext;
TComputationContext& Ctx;
};
@@ -428,20 +414,23 @@ template<class State>
class TMatchRecognizeWrapper : public TStatefulFlowComputationNode<TMatchRecognizeWrapper<State>, true> {
using TBaseComputation = TStatefulFlowComputationNode<TMatchRecognizeWrapper<State>, true>;
public:
- TMatchRecognizeWrapper(TComputationMutables &mutables, EValueRepresentation kind, IComputationNode *inputFlow,
- IComputationExternalNode *inputRowArg,
- IComputationNode *partitionKey,
- TType* partitionKeyType,
- const TMatchRecognizeProcessorParameters& parameters,
- TType* rowType
- )
- :TBaseComputation(mutables, inputFlow, kind, EValueRepresentation::Embedded)
+ TMatchRecognizeWrapper(
+ TComputationMutables& mutables,
+ EValueRepresentation kind,
+ IComputationNode *inputFlow,
+ IComputationExternalNode *inputRowArg,
+ IComputationNode *partitionKey,
+ TType* partitionKeyType,
+ TMatchRecognizeProcessorParameters&& parameters,
+ IRowsFormatter::TState&& rowsFormatterState,
+ TType* rowType)
+ : TBaseComputation(mutables, inputFlow, kind, EValueRepresentation::Embedded)
, InputFlow(inputFlow)
, InputRowArg(inputRowArg)
, PartitionKey(partitionKey)
, PartitionKeyType(partitionKeyType)
- , Parameters(parameters)
- , Cache(mutables)
+ , Parameters(std::move(parameters))
+ , RowsFormatterState(std::move(rowsFormatterState))
, RowType(rowType)
, RowPacker(mutables)
{}
@@ -453,7 +442,7 @@ public:
PartitionKey,
PartitionKeyType,
Parameters,
- Cache,
+ RowsFormatterState,
ctx,
RowType,
RowPacker
@@ -468,7 +457,7 @@ public:
PartitionKey,
PartitionKeyType,
Parameters,
- Cache,
+ RowsFormatterState,
ctx,
RowType,
RowPacker
@@ -503,7 +492,7 @@ private:
Own(flow, Parameters.CurrentRowIndexArg);
Own(flow, Parameters.MeasureInputDataArg);
DependsOn(flow, PartitionKey);
- for (auto& m: Parameters.Measures) {
+ for (auto& m: RowsFormatterState.Measures) {
DependsOn(flow, m);
}
for (auto& d: Parameters.Defines) {
@@ -516,32 +505,31 @@ private:
IComputationExternalNode* const InputRowArg;
IComputationNode* const PartitionKey;
TType* const PartitionKeyType;
- const TMatchRecognizeProcessorParameters Parameters;
- const TContainerCacheOnContext Cache;
+ TMatchRecognizeProcessorParameters Parameters;
+ IRowsFormatter::TState RowsFormatterState;
TType* const RowType;
TMutableObjectOverBoxedValue<TValuePackerBoxed> RowPacker;
};
TOutputColumnOrder GetOutputColumnOrder(TRuntimeNode partitionKyeColumnsIndexes, TRuntimeNode measureColumnsIndexes) {
- using tempMapValue = std::pair<EOutputColumnSource, size_t>;
- std::unordered_map<size_t, tempMapValue, std::hash<size_t>, std::equal_to<size_t>, TMKQLAllocator<std::pair<const size_t, tempMapValue>, EMemorySubPool::Temporary>> temp;
+ std::unordered_map<size_t, TOutputColumnEntry, std::hash<size_t>, std::equal_to<size_t>, TMKQLAllocator<std::pair<const size_t, TOutputColumnEntry>, EMemorySubPool::Temporary>> temp;
{
auto list = AS_VALUE(TListLiteral, partitionKyeColumnsIndexes);
for (ui32 i = 0; i != list->GetItemsCount(); ++i) {
auto index = AS_VALUE(TDataLiteral, list->GetItems()[i])->AsValue().Get<ui32>();
- temp[index] = std::make_pair(EOutputColumnSource::PartitionKey, i);
+ temp[index] = {i, EOutputColumnSource::PartitionKey};
}
}
{
auto list = AS_VALUE(TListLiteral, measureColumnsIndexes);
for (ui32 i = 0; i != list->GetItemsCount(); ++i) {
auto index = AS_VALUE(TDataLiteral, list->GetItems()[i])->AsValue().Get<ui32>();
- temp[index] = std::make_pair(EOutputColumnSource::Measure, i);
+ temp[index] = {i, EOutputColumnSource::Measure};
}
}
if (temp.empty())
return {};
- auto outputSize = max_element(temp.cbegin(), temp.cend())->first + 1;
+ auto outputSize = std::ranges::max_element(temp, {}, &std::pair<const size_t, TOutputColumnEntry>::first)->first + 1;
TOutputColumnOrder result(outputSize);
for (const auto& [i, v]: temp) {
result[i] = v;
@@ -576,7 +564,6 @@ TRowPattern ConvertPattern(const TRuntimeNode& pattern) {
}
TMeasureInputColumnOrder GetMeasureColumnOrder(const TListLiteral& specialColumnIndexes, ui32 inputRowColumnCount) {
- using NYql::NMatchRecognize::EMeasureInputDataSpecialColumns;
//Use Last enum value to denote that c colum comes from the input table
TMeasureInputColumnOrder result(inputRowColumnCount + specialColumnIndexes.GetItemsCount(), std::make_pair(EMeasureInputDataSpecialColumns::Last, 0));
if (specialColumnIndexes.GetItemsCount() != 0) {
@@ -621,7 +608,6 @@ std::pair<TUnboxedValueVector, THashMap<TString, size_t>> ConvertListOfStrings(c
} //namespace NMatchRecognize
-
IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
using namespace NMatchRecognize;
size_t inputIndex = 0;
@@ -641,9 +627,9 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
const auto& pattern = callable.GetInput(inputIndex++);
const auto& currentRowIndexArg = callable.GetInput(inputIndex++);
const auto& inputDataArg = callable.GetInput(inputIndex++);
- const auto& varNames = callable.GetInput(inputIndex++);
+ const auto& defineNames = callable.GetInput(inputIndex++);
TRuntimeNode::TList defines;
- for (size_t i = 0; i != AS_VALUE(TListLiteral, varNames)->GetItemsCount(); ++i) {
+ for (size_t i = 0; i != AS_VALUE(TListLiteral, defineNames)->GetItemsCount(); ++i) {
defines.push_back(callable.GetInput(inputIndex++));
}
const auto& streamingMode = callable.GetInput(inputIndex++);
@@ -652,47 +638,58 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
skipTo.To = static_cast<EAfterMatchSkipTo>(AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().Get<i32>());
skipTo.Var = AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().AsStringRef();
}
+ NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch = NYql::NMatchRecognize::ERowsPerMatch::OneRow;
+ TOutputColumnOrder outputColumnOrder;
+ if (inputIndex + 2 <= callable.GetInputsCount()) {
+ rowsPerMatch = static_cast<ERowsPerMatch>(AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().Get<i32>());
+ outputColumnOrder = IRowsFormatter::GetOutputColumnOrder(callable.GetInput(inputIndex++));
+ } else {
+ outputColumnOrder = GetOutputColumnOrder(partitionColumnIndexes, measureColumnIndexes);
+ }
MKQL_ENSURE(callable.GetInputsCount() == inputIndex, "Wrong input count");
- const auto& [vars, varsLookup] = ConvertListOfStrings(varNames);
+ const auto& [varNames, varNamesLookup] = ConvertListOfStrings(defineNames);
auto* rowType = AS_TYPE(TStructType, AS_TYPE(TFlowType, inputFlow.GetStaticType())->GetItemType());
- const auto parameters = TMatchRecognizeProcessorParameters {
- static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputDataArg.GetNode()))
- , ConvertPattern(pattern)
- , vars
- , varsLookup
- , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *matchedVarsArg.GetNode()))
- , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *currentRowIndexArg.GetNode()))
- , ConvertVectorOfCallables(defines, ctx)
- , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *measureInputDataArg.GetNode()))
- , GetMeasureColumnOrder(
+ auto parameters = TMatchRecognizeProcessorParameters {
+ static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputDataArg.GetNode())),
+ ConvertPattern(pattern),
+ varNames,
+ varNamesLookup,
+ static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *matchedVarsArg.GetNode())),
+ static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *currentRowIndexArg.GetNode())),
+ ConvertVectorOfCallables(defines, ctx),
+ static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *measureInputDataArg.GetNode())),
+ GetMeasureColumnOrder(
*AS_VALUE(TListLiteral, measureSpecialColumnIndexes),
AS_VALUE(TDataLiteral, inputRowColumnCount)->AsValue().Get<ui32>()
- )
- , ConvertVectorOfCallables(measures, ctx)
- , GetOutputColumnOrder(partitionColumnIndexes, measureColumnIndexes)
- , skipTo
+ ),
+ skipTo
};
+ IRowsFormatter::TState rowsFormatterState(ctx, outputColumnOrder, ConvertVectorOfCallables(measures, ctx), rowsPerMatch);
if (AS_VALUE(TDataLiteral, streamingMode)->AsValue().Get<bool>()) {
- return new TMatchRecognizeWrapper<TStateForInterleavedPartitions>(ctx.Mutables
- , GetValueRepresentation(inputFlow.GetStaticType())
- , LocateNode(ctx.NodeLocator, *inputFlow.GetNode())
- , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode()))
- , LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
- , partitionKeySelector.GetStaticType()
- , std::move(parameters)
- , rowType
+ return new TMatchRecognizeWrapper<TStateForInterleavedPartitions>(
+ ctx.Mutables,
+ GetValueRepresentation(inputFlow.GetStaticType()),
+ LocateNode(ctx.NodeLocator, *inputFlow.GetNode()),
+ static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode())),
+ LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode()),
+ partitionKeySelector.GetStaticType(),
+ std::move(parameters),
+ std::move(rowsFormatterState),
+ rowType
);
} else {
- return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions>(ctx.Mutables
- , GetValueRepresentation(inputFlow.GetStaticType())
- , LocateNode(ctx.NodeLocator, *inputFlow.GetNode())
- , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode()))
- , LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
- , partitionKeySelector.GetStaticType()
- , std::move(parameters)
- , rowType
+ return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions>(
+ ctx.Mutables,
+ GetValueRepresentation(inputFlow.GetStaticType()),
+ LocateNode(ctx.NodeLocator, *inputFlow.GetNode()),
+ static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode())),
+ LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode()),
+ partitionKeySelector.GetStaticType(),
+ std::move(parameters),
+ std::move(rowsFormatterState),
+ rowType
);
}
}
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h
index 3c7771a41e..cf9bfb46a9 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h
+++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_list.h
@@ -18,26 +18,29 @@ public:
class TRange {
public:
TRange()
- : FromIndex(-1)
- , ToIndex(-1)
+ : FromIndex(Max())
+ , ToIndex(Max())
+ , NfaIndex_(Max())
{
}
explicit TRange(ui64 index)
- : FromIndex(index)
- , ToIndex(index)
+ : FromIndex(index)
+ , ToIndex(index)
+ , NfaIndex_(Max())
{
}
TRange(ui64 from, ui64 to)
- : FromIndex(from)
- , ToIndex(to)
+ : FromIndex(from)
+ , ToIndex(to)
+ , NfaIndex_(Max())
{
MKQL_ENSURE(FromIndex <= ToIndex, "Internal logic error");
}
bool IsValid() const {
- return true;
+ return FromIndex != Max<size_t>() && ToIndex != Max<size_t>();
}
size_t From() const {
@@ -50,6 +53,11 @@ public:
return ToIndex;
}
+ [[nodiscard]] size_t NfaIndex() const {
+ MKQL_ENSURE(IsValid(), "Internal logic error");
+ return NfaIndex_;
+ }
+
size_t Size() const {
MKQL_ENSURE(IsValid(), "Internal logic error");
return ToIndex - FromIndex + 1;
@@ -61,8 +69,9 @@ public:
}
private:
- ui64 FromIndex;
- ui64 ToIndex;
+ size_t FromIndex;
+ size_t ToIndex;
+ size_t NfaIndex_;
};
TRange Append(NUdf::TUnboxedValue&& value) {
@@ -156,14 +165,12 @@ class TSparseList {
private:
//TODO consider to replace hash table with contiguous chunks
- using TAllocator = TMKQLAllocator<std::pair<const size_t, TItem>, EMemorySubPool::Temporary>;
-
using TStorage = std::unordered_map<
size_t,
TItem,
std::hash<size_t>,
std::equal_to<size_t>,
- TAllocator>;
+ TMKQLAllocator<std::pair<const size_t, TItem>, EMemorySubPool::Temporary>>;
TStorage Storage;
};
@@ -178,8 +185,9 @@ public:
public:
TRange()
: Container()
- , FromIndex(-1)
- , ToIndex(-1)
+ , FromIndex(Max())
+ , ToIndex(Max())
+ , NfaIndex_(Max())
{
}
@@ -187,6 +195,7 @@ public:
: Container(other.Container)
, FromIndex(other.FromIndex)
, ToIndex(other.ToIndex)
+ , NfaIndex_(other.NfaIndex_)
{
LockRange(FromIndex, ToIndex);
}
@@ -195,6 +204,7 @@ public:
: Container(other.Container)
, FromIndex(other.FromIndex)
, ToIndex(other.ToIndex)
+ , NfaIndex_(other.NfaIndex_)
{
other.Reset();
}
@@ -212,6 +222,7 @@ public:
Container = other.Container;
FromIndex = other.FromIndex;
ToIndex = other.ToIndex;
+ NfaIndex_ = other.NfaIndex_;
LockRange(FromIndex, ToIndex);
return *this;
}
@@ -224,20 +235,21 @@ public:
Container = other.Container;
FromIndex = other.FromIndex;
ToIndex = other.ToIndex;
+ NfaIndex_ = other.NfaIndex_;
other.Reset();
return *this;
}
friend inline bool operator==(const TRange& lhs, const TRange& rhs) {
- return std::tie(lhs.FromIndex, lhs.ToIndex) == std::tie(rhs.FromIndex, rhs.ToIndex);
+ return std::tie(lhs.FromIndex, lhs.ToIndex, lhs.NfaIndex_) == std::tie(rhs.FromIndex, rhs.ToIndex, rhs.NfaIndex_);
}
friend inline bool operator<(const TRange& lhs, const TRange& rhs) {
- return std::tie(lhs.FromIndex, lhs.ToIndex) < std::tie(rhs.FromIndex, rhs.ToIndex);
+ return std::tie(lhs.FromIndex, lhs.ToIndex, lhs.NfaIndex_) < std::tie(rhs.FromIndex, rhs.ToIndex, rhs.NfaIndex_);
}
bool IsValid() const {
- return static_cast<bool>(Container);
+ return static_cast<bool>(Container) && FromIndex != Max<size_t>() && ToIndex != Max<size_t>();
}
size_t From() const {
@@ -250,6 +262,15 @@ public:
return ToIndex;
}
+ [[nodiscard]] size_t NfaIndex() const {
+ MKQL_ENSURE(IsValid(), "Internal logic error");
+ return NfaIndex_;
+ }
+
+ void NfaIndex(size_t index) {
+ NfaIndex_ = index;
+ }
+
size_t Size() const {
MKQL_ENSURE(IsValid(), "Internal logic error");
return ToIndex - FromIndex + 1;
@@ -264,16 +285,17 @@ public:
void Release() {
UnlockRange(FromIndex, ToIndex);
Container.Reset();
- FromIndex = -1;
- ToIndex = -1;
+ FromIndex = Max();
+ ToIndex = Max();
+ NfaIndex_ = Max();
}
void Save(TMrOutputSerializer& serializer) const {
- serializer(Container, FromIndex, ToIndex);
+ serializer(Container, FromIndex, ToIndex, NfaIndex_);
}
void Load(TMrInputSerializer& serializer) {
- serializer(Container, FromIndex, ToIndex);
+ serializer(Container, FromIndex, ToIndex, NfaIndex_);
}
private:
@@ -281,6 +303,7 @@ public:
: Container(container)
, FromIndex(index)
, ToIndex(index)
+ , NfaIndex_(Max())
{}
void LockRange(size_t from, size_t to) {
@@ -297,13 +320,15 @@ public:
void Reset() {
Container.Reset();
- FromIndex = -1;
- ToIndex = -1;
+ FromIndex = Max();
+ ToIndex = Max();
+ NfaIndex_ = Max();
}
TContainerPtr Container;
size_t FromIndex;
size_t ToIndex;
+ size_t NfaIndex_;
};
public:
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h
index 8c2576798a..de0ccc352f 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h
+++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_matched_vars.h
@@ -15,7 +15,7 @@ void Extend(TMatchedVar<R>& var, const R& r) {
var.emplace_back(r);
} else {
MKQL_ENSURE(r.From() > var.back().To(), "Internal logic error");
- if (var.back().To() + 1 == r.From()) {
+ if (var.back().To() + 1 == r.From() && var.back().NfaIndex() == r.NfaIndex()) {
var.back().Extend();
} else {
var.emplace_back(r);
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h
index 2b194212f4..c868379d40 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h
+++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_nfa.h
@@ -20,9 +20,10 @@ struct TEpsilonTransitions {
friend constexpr bool operator==(const TEpsilonTransitions&, const TEpsilonTransitions&) = default;
};
struct TMatchedVarTransition {
+ size_t To;
ui32 VarIndex;
bool SaveState;
- size_t To;
+ bool ExcludeFromOutput;
friend constexpr bool operator==(const TMatchedVarTransition&, const TMatchedVarTransition&) = default;
};
struct TQuantityEnterTransition {
@@ -116,7 +117,7 @@ struct TNfaTransitionGraph {
serializer(tr.To);
},
[&](const TMatchedVarTransition& tr) {
- serializer(tr.VarIndex, tr.SaveState, tr.To);
+ serializer(tr.To, tr.VarIndex, tr.SaveState, tr.ExcludeFromOutput);
},
[&](const TQuantityEnterTransition& tr) {
serializer(tr.To);
@@ -141,7 +142,7 @@ struct TNfaTransitionGraph {
serializer(tr.To);
},
[&](TMatchedVarTransition& tr) {
- serializer(tr.VarIndex, tr.SaveState, tr.To);
+ serializer(tr.To, tr.VarIndex, tr.SaveState, tr.ExcludeFromOutput);
},
[&](TQuantityEnterTransition& tr) {
serializer(tr.To);
@@ -297,7 +298,7 @@ private:
auto input = AddNode();
auto output = AddNode();
auto item = factor.Primary.index() == 0 ?
- BuildVar(varNameToIndex.at(std::get<0>(factor.Primary)), !factor.Unused) :
+ BuildVar(varNameToIndex.at(std::get<0>(factor.Primary)), !factor.Unused, !factor.Output) :
BuildTerms(std::get<1>(factor.Primary), varNameToIndex);
if (1 == factor.QuantityMin && 1 == factor.QuantityMax) { //simple linear case
Graph->Transitions[input] = TEpsilonTransitions{{item.Input}};
@@ -319,15 +320,16 @@ private:
}
return {input, output};
}
- TNfaItem BuildVar(ui32 varIndex, bool isUsed) {
+ TNfaItem BuildVar(ui32 varIndex, bool isUsed, bool excludeFromOutput) {
auto input = AddNode();
auto matchVar = AddNode();
auto output = AddNode();
Graph->Transitions[input] = TEpsilonTransitions({matchVar});
Graph->Transitions[matchVar] = TMatchedVarTransition{
+ output,
varIndex,
isUsed,
- output,
+ excludeFromOutput,
};
return {input, output};
}
@@ -441,6 +443,7 @@ public:
if (matchedVarTransition->SaveState) {
auto vars = state.Match.Vars; //TODO get rid of this copy
auto& matchedVar = vars[varIndex];
+ currentRowLock.NfaIndex(state.Index);
Extend(matchedVar, currentRowLock);
newStates.emplace(matchedVarTransition->To, TMatch{state.Match.BeginIndex, currentRowLock.To(), std::move(vars)}, state.Quantifiers);
} else {
@@ -575,6 +578,10 @@ public:
}
}
+ const TNfaTransitionGraph& GetTransitionGraph() const {
+ return *TransitionGraph;
+ }
+
private:
//TODO (zverevgeny): Consider to change to std::vector for the sake of perf
using TStateSet = std::set<TState, std::less<TState>, TMKQLAllocator<TState>>;
@@ -593,14 +600,14 @@ private:
[&](const TEpsilonTransitions& epsilonTransitions) {
deletedStates.insert(state);
for (const auto& i : epsilonTransitions.To) {
- newStates.emplace(i, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, state.Quantifiers);
+ newStates.emplace(i, state.Match, state.Quantifiers);
}
},
[&](const TQuantityEnterTransition& quantityEnterTransition) {
deletedStates.insert(state);
auto quantifiers = state.Quantifiers; //TODO get rid of this copy
quantifiers.push_back(0);
- newStates.emplace(quantityEnterTransition.To, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(quantifiers));
+ newStates.emplace(quantityEnterTransition.To, state.Match, std::move(quantifiers));
},
[&](const TQuantityExitTransition& quantityExitTransition) {
deletedStates.insert(state);
@@ -608,12 +615,12 @@ private:
if (state.Quantifiers.back() + 1 < quantityMax) {
auto q = state.Quantifiers;
q.back()++;
- newStates.emplace(toFindMore, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(q));
+ newStates.emplace(toFindMore, state.Match, std::move(q));
}
if (quantityMin <= state.Quantifiers.back() + 1 && state.Quantifiers.back() + 1 <= quantityMax) {
auto q = state.Quantifiers;
q.pop_back();
- newStates.emplace(toMatched, TMatch{state.Match.BeginIndex, state.Match.EndIndex, state.Match.Vars}, std::move(q));
+ newStates.emplace(toMatched, state.Match, std::move(q));
}
},
}, TransitionGraph->Transitions[state.Index]);
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.cpp b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.cpp
new file mode 100644
index 0000000000..6d70458b3b
--- /dev/null
+++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.cpp
@@ -0,0 +1,144 @@
+#include "mkql_match_recognize_rows_formatter.h"
+
+#include <yql/essentials/minikql/mkql_node.h>
+#include <yql/essentials/minikql/mkql_node_cast.h>
+
+namespace NKikimr::NMiniKQL::NMatchRecognize {
+
+namespace {
+
+// Formatter used for ERowsPerMatch::OneRow: each match produces exactly one output row.
+class TOneRowFormatter final : public IRowsFormatter {
+public:
+    explicit TOneRowFormatter(const TState& state) : IRowsFormatter(state) {}
+
+    // Stores `match`, builds the single output row for it and then drops the
+    // stored match again, so nothing is left for GetOtherMatchRow to emit.
+    NUdf::TUnboxedValue GetFirstMatchRow(
+        TComputationContext& ctx,
+        const TSparseList& rows,
+        const NUdf::TUnboxedValue& partitionKey,
+        const TNfaTransitionGraph& graph,
+        const TNfa::TMatch& match) override {
+        Match_ = match;
+        const auto result = DoGetMatchRow(ctx, rows, partitionKey, graph);
+        IRowsFormatter::Clear();
+        return result;
+    }
+
+    // A match never has a second row in this mode: always returns an empty value.
+    NUdf::TUnboxedValue GetOtherMatchRow(
+        TComputationContext& /*ctx*/,
+        const TSparseList& /*rows*/,
+        const NUdf::TUnboxedValue& /*partitionKey*/,
+        const TNfaTransitionGraph& /*graph*/) override {
+        return NUdf::TUnboxedValue{};
+    }
+};
+
+// Formatter used for ERowsPerMatch::AllRows: a match produces one output row per
+// matched input row, except rows matched through a transition flagged ExcludeFromOutput.
+class TAllRowsFormatter final : public IRowsFormatter {
+public:
+    explicit TAllRowsFormatter(const IRowsFormatter::TState& state) : IRowsFormatter(state) {}
+
+    // Starts iterating over `match` and returns its first non-excluded row
+    // (or an empty value when every row of the match is excluded).
+    NUdf::TUnboxedValue GetFirstMatchRow(
+        TComputationContext& ctx,
+        const TSparseList& rows,
+        const NUdf::TUnboxedValue& partitionKey,
+        const TNfaTransitionGraph& graph,
+        const TNfa::TMatch& match) override {
+        // Entries of a previous match may still be present if that match ended on
+        // excluded rows; emplace() would not overwrite them, so start from scratch.
+        // Done before Match_ is reassigned because the entries refer into Match_.Vars.
+        ToIndexToMatchRangeLookup_.clear();
+        Match_ = match;
+        CurrentRowIndex_ = Match_.BeginIndex;
+        for (const auto& matchedVar : Match_.Vars) {
+            for (const auto& range : matchedVar) {
+                ToIndexToMatchRangeLookup_.emplace(range.To(), range);
+            }
+        }
+        return GetMatchRow(ctx, rows, partitionKey, graph);
+    }
+
+    // Returns the next non-excluded row of the current match, or an empty value
+    // once the match is exhausted.
+    NUdf::TUnboxedValue GetOtherMatchRow(
+        TComputationContext& ctx,
+        const TSparseList& rows,
+        const NUdf::TUnboxedValue& partitionKey,
+        const TNfaTransitionGraph& graph) override {
+        return GetMatchRow(ctx, rows, partitionKey, graph);
+    }
+
+private:
+    NUdf::TUnboxedValue GetMatchRow(TComputationContext& ctx, const TSparseList& rows, const NUdf::TUnboxedValue& partitionKey, const TNfaTransitionGraph& graph) {
+        // Advance past rows that must not be emitted. The lookup is keyed by the
+        // last row index of each range, so lower_bound yields the first range
+        // whose To() is not before the current row.
+        while (CurrentRowIndex_ <= Match_.EndIndex) {
+            if (auto iter = ToIndexToMatchRangeLookup_.lower_bound(CurrentRowIndex_);
+                iter == ToIndexToMatchRangeLookup_.end()) {
+                MKQL_ENSURE(false, "Internal logic error");
+            } else if (auto transition = std::get_if<TMatchedVarTransition>(&graph.Transitions.at(iter->second.NfaIndex()));
+                !transition) {
+                MKQL_ENSURE(false, "Internal logic error");
+            } else if (transition->ExcludeFromOutput) {
+                ++CurrentRowIndex_;
+            } else {
+                break;
+            }
+        }
+        if (CurrentRowIndex_ > Match_.EndIndex) {
+            return NUdf::TUnboxedValue{};
+        }
+        const auto result = DoGetMatchRow(ctx, rows, partitionKey, graph);
+        ++CurrentRowIndex_;
+        // EndIndex is inclusive (see the loop condition above), so the match is
+        // finished only after the row at EndIndex itself has been emitted.
+        if (CurrentRowIndex_ > Match_.EndIndex) {
+            Clear();
+        }
+        return result;
+    }
+
+    void Clear() {
+        // Drop the lookup first: its values are references into Match_.Vars,
+        // which the base-class Clear() resets.
+        ToIndexToMatchRangeLookup_.clear();
+        IRowsFormatter::Clear();
+    }
+
+    // Last row index of a matched range -> that range (referencing Match_.Vars).
+    TMap<size_t, const TSparseList::TRange&> ToIndexToMatchRangeLookup_;
+};
+
+} // anonymous namespace
+
+// Keeps a reference to `state`; the state object itself is not copied.
+IRowsFormatter::IRowsFormatter(const TState& state)
+    : State_(state)
+{
+}
+
+// Decodes the output column order from a list literal of struct literals:
+// member 0 of each struct is read as the ui32 column index, member 1 as an i32
+// that is cast to EOutputColumnSource.
+TOutputColumnOrder IRowsFormatter::GetOutputColumnOrder(
+    TRuntimeNode outputColumnOrder) {
+    const auto list = AS_VALUE(TListLiteral, outputColumnOrder);
+    const TConstArrayRef<TRuntimeNode> items(list->GetItems(), list->GetItemsCount());
+    TOutputColumnOrder result;
+    // The number of entries is known up front, so allocate once.
+    result.reserve(list->GetItemsCount());
+    for (const auto& item : items) {
+        const auto entry = AS_VALUE(TStructLiteral, item);
+        result.emplace_back(
+            AS_VALUE(TDataLiteral, entry->GetValue(0))->AsValue().Get<ui32>(),
+            static_cast<EOutputColumnSource>(AS_VALUE(TDataLiteral, entry->GetValue(1))->AsValue().Get<i32>())
+        );
+    }
+    return result;
+}
+
+// Builds one output row: walks State_.OutputColumnOrder and fills each cell from
+// the partition key, from a measure node, or from the input row at CurrentRowIndex_.
+NUdf::TUnboxedValue IRowsFormatter::DoGetMatchRow(TComputationContext& ctx, const TSparseList& rows, const NUdf::TUnboxedValue& partitionKey, const TNfaTransitionGraph& graph) {
+    const auto& columns = State_.OutputColumnOrder;
+    NUdf::TUnboxedValue* cell = nullptr;
+    const auto row = State_.Cache->NewArray(ctx, columns.size(), cell);
+    for (const auto& column : columns) {
+        const auto source = column.SourceType;
+        switch (source) {
+            case EOutputColumnSource::PartitionKey: {
+                *cell++ = partitionKey.GetElement(column.Index);
+                break;
+            }
+            case EOutputColumnSource::Measure: {
+                *cell++ = State_.Measures[column.Index]->GetValue(ctx);
+                break;
+            }
+            case EOutputColumnSource::Other: {
+                *cell++ = rows.Get(CurrentRowIndex_).GetElement(column.Index);
+                break;
+            }
+        }
+    }
+    return row;
+}
+
+// Factory: picks the formatter implementation that corresponds to state.RowsPerMatch.
+// The returned formatter keeps a reference to `state`.
+std::unique_ptr<IRowsFormatter> IRowsFormatter::Create(const IRowsFormatter::TState& state) {
+    switch (state.RowsPerMatch) {
+        case ERowsPerMatch::OneRow:
+            return std::make_unique<TOneRowFormatter>(state);
+        case ERowsPerMatch::AllRows:
+            return std::make_unique<TAllRowsFormatter>(state);
+    }
+    // The enum value may have been produced by a cast from a raw integer, so a
+    // value outside the known set is possible; fail explicitly instead of
+    // flowing off the end of a non-void function (undefined behavior).
+    MKQL_ENSURE(false, "Unexpected rows per match mode");
+    return nullptr;
+}
+
+} //namespace NKikimr::NMiniKQL::NMatchRecognize
diff --git a/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.h b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.h
new file mode 100644
index 0000000000..39750bf4ea
--- /dev/null
+++ b/yql/essentials/minikql/comp_nodes/mkql_match_recognize_rows_formatter.h
@@ -0,0 +1,72 @@
+#pragma once
+
+#include "mkql_match_recognize_nfa.h"
+
+#include <yql/essentials/core/sql_types/match_recognize.h>
+#include <yql/essentials/minikql/computation/mkql_computation_node.h>
+#include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
+#include <yql/essentials/minikql/mkql_alloc.h>
+#include <yql/essentials/public/udf/udf_value.h>
+
+namespace NKikimr::NMiniKQL::NMatchRecognize {
+
+struct TOutputColumnEntry { // Describes where a single output column takes its value from
+ size_t Index; // Position inside the selected source: partition key element, measure number or input row element
+ NYql::NMatchRecognize::EOutputColumnSource SourceType; // Which source supplies the value
+};
+using TOutputColumnOrder = std::vector<TOutputColumnEntry, TMKQLAllocator<TOutputColumnEntry>>; // One entry per output column, in output order
+
+class IRowsFormatter { // Turns a found match into output rows; concrete behavior depends on ERowsPerMatch
+public:
+ struct TState { // Settings and helpers a formatter needs; formatters hold it by reference
+ std::unique_ptr<TContainerCacheOnContext> Cache; // Used to allocate the output row array
+ TOutputColumnOrder OutputColumnOrder; // Source of every output column, in output order
+ TComputationNodePtrVector Measures; // Nodes evaluated for columns whose source is Measure
+ NYql::NMatchRecognize::ERowsPerMatch RowsPerMatch; // Selects the formatter implementation in Create()
+
+ TState(
+ const TComputationNodeFactoryContext& ctx,
+ TOutputColumnOrder outputColumnOrder,
+ TComputationNodePtrVector measures,
+ NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch)
+ : Cache(std::make_unique<TContainerCacheOnContext>(ctx.Mutables)) // cache is created from ctx.Mutables
+ , OutputColumnOrder(std::move(outputColumnOrder))
+ , Measures(std::move(measures))
+ , RowsPerMatch(rowsPerMatch)
+ {}
+ };
+
+ explicit IRowsFormatter(const TState& state); // stores a reference to state, no copy
+ virtual ~IRowsFormatter() = default;
+
+ virtual NUdf::TUnboxedValue GetFirstMatchRow( // Begins formatting `match` and returns its first output row
+ TComputationContext& ctx,
+ const TSparseList& rows,
+ const NUdf::TUnboxedValue& partitionKey,
+ const TNfaTransitionGraph& graph,
+ const TNfa::TMatch& match) = 0;
+
+ virtual NUdf::TUnboxedValue GetOtherMatchRow( // Returns a further row of the current match, or an empty value when none is left
+ TComputationContext& ctx,
+ const TSparseList& rows,
+ const NUdf::TUnboxedValue& partitionKey,
+ const TNfaTransitionGraph& graph) = 0;
+
+ static TOutputColumnOrder GetOutputColumnOrder(TRuntimeNode outputColumnOrder); // Decodes the column order from a literal runtime node
+
+ static std::unique_ptr<IRowsFormatter> Create(const TState& state); // Factory keyed on state.RowsPerMatch
+
+protected:
+ NUdf::TUnboxedValue DoGetMatchRow(TComputationContext& ctx, const TSparseList& rows, const NUdf::TUnboxedValue& partitionKey, const TNfaTransitionGraph& graph); // Builds one row per OutputColumnOrder using CurrentRowIndex_
+
+ inline void Clear() { // Forgets the current match and resets the row cursor
+ Match_ = {};
+ CurrentRowIndex_ = Max(); // Max() serves as the "no current row" value
+ }
+
+ const TState& State_; // Reference only: the TState must outlive the formatter
+ TNfa::TMatch Match_ {}; // Match currently being formatted
+ size_t CurrentRowIndex_ = Max(); // Index of the input row used for columns whose source is Other
+};
+
+} // namespace NKikimr::NMiniKQL::NMatchRecognize
diff --git a/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp b/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp
index a9fad1b6ef..513a72df5e 100644
--- a/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp
+++ b/yql/essentials/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp
@@ -64,57 +64,46 @@ namespace {
const TTestInputData& input) {
TProgramBuilder& pgmBuilder = *setup.PgmBuilder;
- auto structType = pgmBuilder.NewStructType({
- {"time", pgmBuilder.NewDataType(NUdf::TDataType<i64>::Id)},
- {"key", pgmBuilder.NewDataType(NUdf::TDataType<char*>::Id)},
- {"sum", pgmBuilder.NewDataType(NUdf::TDataType<ui32>::Id)},
- {"part", pgmBuilder.NewDataType(NUdf::TDataType<char*>::Id)}});
+ const auto structType = pgmBuilder.NewStructType({
+ {"time", pgmBuilder.NewDataType(NUdf::EDataSlot::Int64)},
+ {"key", pgmBuilder.NewDataType(NUdf::EDataSlot::String)},
+ {"sum", pgmBuilder.NewDataType(NUdf::EDataSlot::Uint32)},
+ {"part", pgmBuilder.NewDataType(NUdf::EDataSlot::String)}
+ });
TVector<TRuntimeNode> items;
- for (size_t i = 0; i < input.size(); ++i)
- {
- auto time = pgmBuilder.NewDataLiteral<i64>(std::get<0>(input[i]));
- auto key = pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(NUdf::TStringRef(std::get<1>(input[i])));
- auto sum = pgmBuilder.NewDataLiteral<ui32>(std::get<2>(input[i]));
- auto part = pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(NUdf::TStringRef(std::get<3>(input[i])));
-
- auto item = pgmBuilder.NewStruct(structType,
- {{"time", time}, {"key", key}, {"sum", sum}, {"part", part}});
- items.push_back(std::move(item));
+ for (size_t i = 0; i < input.size(); ++i) {
+ const auto& [time, key, sum, part] = input[i];
+ items.push_back(pgmBuilder.NewStruct({
+ {"time", pgmBuilder.NewDataLiteral(time)},
+ {"key", pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(key)},
+ {"sum", pgmBuilder.NewDataLiteral(sum)},
+ {"part", pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(part)},
+ }));
}
const auto list = pgmBuilder.NewList(structType, std::move(items));
auto inputFlow = pgmBuilder.ToFlow(list);
-
- TVector<TStringBuf> partitionColumns;
- TVector<std::pair<TStringBuf, TProgramBuilder::TBinaryLambda>> getMeasures = {{
- std::make_pair(
- TStringBuf("key"),
- [&](TRuntimeNode /*measureInputDataArg*/, TRuntimeNode /*matchedVarsArg*/) {
- return pgmBuilder.NewDataLiteral<ui32>(56);
- }
- )}};
- TVector<std::pair<TStringBuf, TProgramBuilder::TTernaryLambda>> getDefines = {{
- std::make_pair(
- TStringBuf("A"),
- [&](TRuntimeNode /*inputDataArg*/, TRuntimeNode /*matchedVarsArg*/, TRuntimeNode /*currentRowIndexArg*/) {
- return pgmBuilder.NewDataLiteral<bool>(true);
- }
- )}};
-
auto pgmReturn = pgmBuilder.MatchRecognizeCore(
inputFlow,
[&](TRuntimeNode item) {
- return pgmBuilder.Member(item, "part");
+ return pgmBuilder.NewTuple({pgmBuilder.Member(item, "part")});
},
- partitionColumns,
- getMeasures,
+ {},
+ {"key"sv},
+ {[&](TRuntimeNode /*measureInputDataArg*/, TRuntimeNode /*matchedVarsArg*/) {
+ return pgmBuilder.NewDataLiteral<ui32>(56);
+ }},
{
{NYql::NMatchRecognize::TRowPatternFactor{"A", 3, 3, false, false, false}}
},
- getDefines,
+ {"A"sv},
+ {[&](TRuntimeNode /*inputDataArg*/, TRuntimeNode /*matchedVarsArg*/, TRuntimeNode /*currentRowIndexArg*/) {
+ return pgmBuilder.NewDataLiteral<bool>(true);
+ }},
streamingMode,
- {NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""}
+ {NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""},
+ NYql::NMatchRecognize::ERowsPerMatch::OneRow
);
auto graph = setup.BuildGraph(pgmReturn);
diff --git a/yql/essentials/minikql/comp_nodes/ya.make.inc b/yql/essentials/minikql/comp_nodes/ya.make.inc
index 518f5f3c43..b2f0da8ac9 100644
--- a/yql/essentials/minikql/comp_nodes/ya.make.inc
+++ b/yql/essentials/minikql/comp_nodes/ya.make.inc
@@ -79,6 +79,7 @@ SET(ORIG_SOURCES
mkql_mapnext.cpp
mkql_map_join.cpp
mkql_match_recognize.cpp
+ mkql_match_recognize_rows_formatter.cpp
mkql_multihopping.cpp
mkql_multimap.cpp
mkql_next_value.cpp
diff --git a/yql/essentials/minikql/mkql_program_builder.cpp b/yql/essentials/minikql/mkql_program_builder.cpp
index 691a642814..0b9197024d 100644
--- a/yql/essentials/minikql/mkql_program_builder.cpp
+++ b/yql/essentials/minikql/mkql_program_builder.cpp
@@ -1,5 +1,4 @@
#include "mkql_program_builder.h"
-#include "mkql_opt_literal.h"
#include "mkql_node_visitor.h"
#include "mkql_node_cast.h"
#include "mkql_runtime_version.h"
@@ -11,6 +10,7 @@
#include "yql/essentials/core/sql_types/time_order_recover.h"
#include <yql/essentials/parser/pg_catalog/catalog.h>
+#include <util/generic/overloaded.h>
#include <util/string/cast.h>
#include <util/string/printf.h>
#include <array>
@@ -6035,15 +6035,19 @@ TRuntimeNode PatternToRuntimeNode(const TRowPattern& pattern, const TProgramBuil
TTupleLiteralBuilder termBuilder(env);
for (const auto& factor: term) {
TTupleLiteralBuilder factorBuilder(env);
- factorBuilder.Add(factor.Primary.index() == 0 ?
- programBuilder.NewDataLiteral<NUdf::EDataSlot::String>(std::get<0>(factor.Primary)) :
- PatternToRuntimeNode(std::get<1>(factor.Primary), programBuilder)
- );
- factorBuilder.Add(programBuilder.NewDataLiteral<ui64>(factor.QuantityMin));
- factorBuilder.Add(programBuilder.NewDataLiteral<ui64>(factor.QuantityMax));
- factorBuilder.Add(programBuilder.NewDataLiteral<bool>(factor.Greedy));
- factorBuilder.Add(programBuilder.NewDataLiteral<bool>(factor.Output));
- factorBuilder.Add(programBuilder.NewDataLiteral<bool>(factor.Unused));
+ factorBuilder.Add(std::visit(TOverloaded {
+ [&](const TString& s) {
+ return programBuilder.NewDataLiteral<NUdf::EDataSlot::String>(s);
+ },
+ [&](const TRowPattern& pattern) {
+ return PatternToRuntimeNode(pattern, programBuilder);
+ },
+ }, factor.Primary));
+ factorBuilder.Add(programBuilder.NewDataLiteral(factor.QuantityMin));
+ factorBuilder.Add(programBuilder.NewDataLiteral(factor.QuantityMax));
+ factorBuilder.Add(programBuilder.NewDataLiteral(factor.Greedy));
+ factorBuilder.Add(programBuilder.NewDataLiteral(factor.Output));
+ factorBuilder.Add(programBuilder.NewDataLiteral(factor.Unused));
termBuilder.Add({factorBuilder.Build(), true});
}
patternBuilder.Add({termBuilder.Build(), true});
@@ -6056,151 +6060,172 @@ TRuntimeNode PatternToRuntimeNode(const TRowPattern& pattern, const TProgramBuil
TRuntimeNode TProgramBuilder::MatchRecognizeCore(
TRuntimeNode inputStream,
const TUnaryLambda& getPartitionKeySelectorNode,
- const TArrayRef<TStringBuf>& partitionColumns,
- const TArrayRef<std::pair<TStringBuf, TBinaryLambda>>& getMeasures,
+ const TArrayRef<TStringBuf>& partitionColumnNames,
+ const TVector<TStringBuf>& measureColumnNames,
+ const TVector<TBinaryLambda>& getMeasures,
const NYql::NMatchRecognize::TRowPattern& pattern,
- const TArrayRef<std::pair<TStringBuf, TTernaryLambda>>& getDefines,
+ const TVector<TStringBuf>& defineVarNames,
+ const TVector<TTernaryLambda>& getDefines,
bool streamingMode,
- const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo
+ const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo,
+ NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch
) {
MKQL_ENSURE(RuntimeVersion >= 42, "MatchRecognize is not supported in runtime version " << RuntimeVersion);
const auto inputRowType = AS_TYPE(TStructType, AS_TYPE(TFlowType, inputStream.GetStaticType())->GetItemType());
const auto inputRowArg = Arg(inputRowType);
const auto partitionKeySelectorNode = getPartitionKeySelectorNode(inputRowArg);
+ const auto partitionColumnTypes = AS_TYPE(TTupleType, partitionKeySelectorNode.GetStaticType())->GetElements();
- TStructTypeBuilder indexRangeTypeBuilder(Env);
- indexRangeTypeBuilder.Add("From", TDataType::Create(NUdf::TDataType<ui64>::Id, Env));
- indexRangeTypeBuilder.Add("To", TDataType::Create(NUdf::TDataType<ui64>::Id, Env));
- const auto& rangeList = TListType::Create(indexRangeTypeBuilder.Build(), Env);
+ const auto rangeList = NewListType(NewStructType({
+ {"From", NewDataType(NUdf::EDataSlot::Uint64)},
+ {"To", NewDataType(NUdf::EDataSlot::Uint64)}
+ }));
TStructTypeBuilder matchedVarsTypeBuilder(Env);
for (const auto& var: GetPatternVars(pattern)) {
matchedVarsTypeBuilder.Add(var, rangeList);
}
- TRuntimeNode matchedVarsArg = Arg(matchedVarsTypeBuilder.Build());
+ const auto matchedVarsType = matchedVarsTypeBuilder.Build();
+ TRuntimeNode matchedVarsArg = Arg(matchedVarsType);
//---These vars may be empty in case of no measures
TRuntimeNode measureInputDataArg;
std::vector<TRuntimeNode> specialColumnIndexesInMeasureInputDataRow;
TVector<TRuntimeNode> measures;
- TVector<TType*> measureTypes;
//---
if (getMeasures.empty()) {
measureInputDataArg = Arg(Env.GetTypeOfVoidLazy());
} else {
- using NYql::NMatchRecognize::EMeasureInputDataSpecialColumns;
measures.reserve(getMeasures.size());
- measureTypes.reserve(getMeasures.size());
specialColumnIndexesInMeasureInputDataRow.resize(static_cast<size_t>(NYql::NMatchRecognize::EMeasureInputDataSpecialColumns::Last));
TStructTypeBuilder measureInputDataRowTypeBuilder(Env);
- for (ui32 i = 0; i != inputRowType->GetMembersCount(); ++i) {
+ for (ui32 i = 0; i < inputRowType->GetMembersCount(); ++i) {
measureInputDataRowTypeBuilder.Add(inputRowType->GetMemberName(i), inputRowType->GetMemberType(i));
}
measureInputDataRowTypeBuilder.Add(
MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::Classifier),
- TDataType::Create(NUdf::TDataType<NYql::NUdf::TUtf8>::Id, Env)
+ NewDataType(NUdf::EDataSlot::Utf8)
);
measureInputDataRowTypeBuilder.Add(
MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::MatchNumber),
- TDataType::Create(NUdf::TDataType<ui64>::Id, Env)
+ NewDataType(NUdf::EDataSlot::Uint64)
);
const auto measureInputDataRowType = measureInputDataRowTypeBuilder.Build();
- for (ui32 i = 0; i != measureInputDataRowType->GetMembersCount(); ++i) {
+ for (ui32 i = 0; i < measureInputDataRowType->GetMembersCount(); ++i) {
//assume a few, if grows, it's better to use a lookup table here
static_assert(static_cast<size_t>(EMeasureInputDataSpecialColumns::Last) < 5);
for (size_t j = 0; j != static_cast<size_t>(EMeasureInputDataSpecialColumns::Last); ++j) {
if (measureInputDataRowType->GetMemberName(i) ==
NYql::NMatchRecognize::MeasureInputDataSpecialColumnName(static_cast<EMeasureInputDataSpecialColumns>(j)))
- specialColumnIndexesInMeasureInputDataRow[j] = NewDataLiteral<ui32>(i);
+ specialColumnIndexesInMeasureInputDataRow[j] = NewDataLiteral(i);
}
}
- measureInputDataArg = Arg(TListType::Create(measureInputDataRowType, Env));
+ measureInputDataArg = Arg(NewListType(measureInputDataRowType));
for (size_t i = 0; i != getMeasures.size(); ++i) {
- measures.push_back(getMeasures[i].second(measureInputDataArg, matchedVarsArg));
- measureTypes.push_back(measures[i].GetStaticType());
+ measures.push_back(getMeasures[i](measureInputDataArg, matchedVarsArg));
}
}
TStructTypeBuilder outputRowTypeBuilder(Env);
THashMap<TStringBuf, size_t> partitionColumnLookup;
- for (size_t i = 0; i != partitionColumns.size(); ++i) {
- const auto& name = partitionColumns[i];
- partitionColumnLookup[name] = i;
- outputRowTypeBuilder.Add(
- name,
- AS_TYPE(TTupleType, partitionKeySelectorNode.GetStaticType())->GetElementType(i)
- );
- }
THashMap<TStringBuf, size_t> measureColumnLookup;
- for (size_t i = 0; i != measures.size(); ++i) {
- const auto& name = getMeasures[i].first;
- measureColumnLookup[name] = i;
- outputRowTypeBuilder.Add(
- name,
- measures[i].GetStaticType()
- );
+ THashMap<TStringBuf, size_t> otherColumnLookup;
+ for (size_t i = 0; i < measureColumnNames.size(); ++i) {
+ const auto name = measureColumnNames[i];
+ measureColumnLookup.emplace(name, i);
+ outputRowTypeBuilder.Add(name, measures[i].GetStaticType());
+ }
+ switch (rowsPerMatch) {
+ case NYql::NMatchRecognize::ERowsPerMatch::OneRow:
+ for (size_t i = 0; i < partitionColumnNames.size(); ++i) {
+ const auto name = partitionColumnNames[i];
+ partitionColumnLookup.emplace(name, i);
+ outputRowTypeBuilder.Add(name, partitionColumnTypes[i]);
+ }
+ break;
+ case NYql::NMatchRecognize::ERowsPerMatch::AllRows:
+ for (size_t i = 0; i < inputRowType->GetMembersCount(); ++i) {
+ const auto name = inputRowType->GetMemberName(i);
+ otherColumnLookup.emplace(name, i);
+ outputRowTypeBuilder.Add(name, inputRowType->GetMemberType(i));
+ }
+ break;
}
auto outputRowType = outputRowTypeBuilder.Build();
std::vector<TRuntimeNode> partitionColumnIndexes(partitionColumnLookup.size());
std::vector<TRuntimeNode> measureColumnIndexes(measureColumnLookup.size());
- for (ui32 i = 0; i != outputRowType->GetMembersCount(); ++i) {
- if (auto it = partitionColumnLookup.find(outputRowType->GetMemberName(i)); it != partitionColumnLookup.end()) {
- partitionColumnIndexes[it->second] = NewDataLiteral<ui32>(i);
- }
- else if (auto it = measureColumnLookup.find(outputRowType->GetMemberName(i)); it != measureColumnLookup.end()) {
- measureColumnIndexes[it->second] = NewDataLiteral<ui32>(i);
+ TVector<TRuntimeNode> outputColumnOrder(NDetail::TReserveTag{outputRowType->GetMembersCount()});
+ for (ui32 i = 0; i < outputRowType->GetMembersCount(); ++i) {
+ const auto name = outputRowType->GetMemberName(i);
+ if (auto iter = partitionColumnLookup.find(name);
+ iter != partitionColumnLookup.end()) {
+ partitionColumnIndexes[iter->second] = NewDataLiteral(i);
+ outputColumnOrder.push_back(NewStruct({
+ std::pair{"Index", NewDataLiteral(iter->second)},
+ std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::PartitionKey))},
+ }));
+ } else if (auto iter = measureColumnLookup.find(name);
+ iter != measureColumnLookup.end()) {
+ measureColumnIndexes[iter->second] = NewDataLiteral(i);
+ outputColumnOrder.push_back(NewStruct({
+ std::pair{"Index", NewDataLiteral(iter->second)},
+ std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::Measure))},
+ }));
+ } else if (auto iter = otherColumnLookup.find(name);
+ iter != otherColumnLookup.end()) {
+ outputColumnOrder.push_back(NewStruct({
+ std::pair{"Index", NewDataLiteral(iter->second)},
+ std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::Other))},
+ }));
}
}
- auto outputType = (TType*)TFlowType::Create(outputRowType, Env);
+ const auto outputType = NewFlowType(outputRowType);
- THashMap<TStringBuf , size_t> patternVarLookup;
- for (ui32 i = 0; i != AS_TYPE(TStructType, matchedVarsArg.GetStaticType())->GetMembersCount(); ++i){
- patternVarLookup[AS_TYPE(TStructType, matchedVarsArg.GetStaticType())->GetMemberName(i)] = i;
+ THashMap<TStringBuf, size_t> patternVarLookup;
+ for (ui32 i = 0; i < matchedVarsType->GetMembersCount(); ++i) {
+ patternVarLookup[matchedVarsType->GetMemberName(i)] = i;
}
- THashMap<TStringBuf , size_t> defineLookup;
- for (size_t i = 0; i != getDefines.size(); ++i) {
- defineLookup[getDefines[i].first] = i;
+ THashMap<TStringBuf, size_t> defineLookup;
+ for (size_t i = 0; i < defineVarNames.size(); ++i) {
+ const auto name = defineVarNames[i];
+ defineLookup[name] = i;
}
- std::vector<TRuntimeNode> defineNames(patternVarLookup.size());
- std::vector<TRuntimeNode> defineNodes(patternVarLookup.size());
- const auto& inputDataArg = Arg(TListType::Create(inputRowType, Env));
- const auto& currentRowIndexArg = Arg(TDataType::Create(NUdf::TDataType<ui64>::Id, Env));
+ TVector<TRuntimeNode> defineNames(patternVarLookup.size()); // filled by pattern-var index below, so must be sized like defineNodes (vars without DEFINE and ^/$ still get a slot)
+ TVector<TRuntimeNode> defineNodes(patternVarLookup.size());
+ const auto inputDataArg = Arg(NewListType(inputRowType));
+ const auto currentRowIndexArg = Arg(NewDataType(NUdf::EDataSlot::Uint64));
for (const auto& [v, i]: patternVarLookup) {
defineNames[i] = NewDataLiteral<NUdf::EDataSlot::String>(v);
- if (const auto it = defineLookup.find(v); it != defineLookup.end()) {
- defineNodes[i] = getDefines[it->second].second(inputDataArg, matchedVarsArg, currentRowIndexArg);
- }
- else { //no predicate for var
- if ("$" == v || "^" == v) {
- //DO nothing, //will be handled in a specific way
- }
- else { // a var without a predicate matches any row
- defineNodes[i] = NewDataLiteral<bool>(true);
- }
+ if (auto iter = defineLookup.find(v);
+ iter != defineLookup.end()) {
+ defineNodes[i] = getDefines[iter->second](inputDataArg, matchedVarsArg, currentRowIndexArg);
+ } else if ("$" == v || "^" == v) {
+ //DO nothing, //will be handled in a specific way
+ } else { // a var without a predicate matches any row
+ defineNodes[i] = NewDataLiteral(true);
}
}
TCallableBuilder callableBuilder(GetTypeEnvironment(), "MatchRecognizeCore", outputType);
- auto indexType = TDataType::Create(NUdf::TDataType<ui32>::Id, Env);
- auto indexListType = TListType::Create(indexType, Env);
+ const auto indexType = NewDataType(NUdf::EDataSlot::Uint32);
+ const auto outputColumnEntryType = NewStructType({
+ {"Index", NewDataType(NUdf::EDataSlot::Uint64)},
+ {"SourceType", NewDataType(NUdf::EDataSlot::Int32)},
+ });
callableBuilder.Add(inputStream);
callableBuilder.Add(inputRowArg);
callableBuilder.Add(partitionKeySelectorNode);
- callableBuilder.Add(TRuntimeNode(TListLiteral::Create(partitionColumnIndexes.data(), partitionColumnIndexes.size(), indexListType, Env), true));
+ callableBuilder.Add(NewList(indexType, partitionColumnIndexes));
callableBuilder.Add(measureInputDataArg);
- callableBuilder.Add(TRuntimeNode(TListLiteral::Create(
- specialColumnIndexesInMeasureInputDataRow.data(), specialColumnIndexesInMeasureInputDataRow.size(),
- indexListType, Env
- ),
- true));
- callableBuilder.Add(NewDataLiteral<ui32>(inputRowType->GetMembersCount()));
+ callableBuilder.Add(NewList(indexType, specialColumnIndexesInMeasureInputDataRow));
+ callableBuilder.Add(NewDataLiteral(inputRowType->GetMembersCount()));
callableBuilder.Add(matchedVarsArg);
- callableBuilder.Add(TRuntimeNode(TListLiteral::Create(measureColumnIndexes.data(), measureColumnIndexes.size(), indexListType, Env), true));
+ callableBuilder.Add(NewList(indexType, measureColumnIndexes));
for (const auto& m: measures) {
callableBuilder.Add(m);
}
@@ -6209,16 +6234,19 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore(
callableBuilder.Add(currentRowIndexArg);
callableBuilder.Add(inputDataArg);
- const auto stringType = NewDataType(NUdf::EDataSlot::String);
- callableBuilder.Add(TRuntimeNode(TListLiteral::Create(defineNames.begin(), defineNames.size(), TListType::Create(stringType, Env), Env), true));
+ callableBuilder.Add(NewList(NewDataType(NUdf::EDataSlot::String), defineNames));
for (const auto& d: defineNodes) {
callableBuilder.Add(d);
}
callableBuilder.Add(NewDataLiteral(streamingMode));
- if (RuntimeVersion >= 52U) {
+ if constexpr (RuntimeVersion >= 52U) {
callableBuilder.Add(NewDataLiteral(static_cast<i32>(skipTo.To)));
callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(skipTo.Var));
}
+ if constexpr (RuntimeVersion >= 53U) {
+ callableBuilder.Add(NewDataLiteral(static_cast<i32>(rowsPerMatch)));
+ callableBuilder.Add(NewList(outputColumnEntryType, outputColumnOrder));
+ }
return TRuntimeNode(callableBuilder.Build(), false);
}
diff --git a/yql/essentials/minikql/mkql_program_builder.h b/yql/essentials/minikql/mkql_program_builder.h
index 762d3a013e..9e6b0d97d0 100644
--- a/yql/essentials/minikql/mkql_program_builder.h
+++ b/yql/essentials/minikql/mkql_program_builder.h
@@ -712,12 +712,15 @@ public:
TRuntimeNode MatchRecognizeCore(
TRuntimeNode inputStream,
const TUnaryLambda& getPartitionKeySelectorNode,
- const TArrayRef<TStringBuf>& partitionColumns,
- const TArrayRef<std::pair<TStringBuf, TBinaryLambda>>& getMeasures,
+ const TArrayRef<TStringBuf>& partitionColumnNames,
+ const TVector<TStringBuf>& measureColumnNames,
+ const TVector<TBinaryLambda>& getMeasures,
const NYql::NMatchRecognize::TRowPattern& pattern,
- const TArrayRef<std::pair<TStringBuf, TTernaryLambda>>& getDefines,
+ const TVector<TStringBuf>& defineVarNames,
+ const TVector<TTernaryLambda>& getDefines,
bool streamingMode,
- const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo
+ const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo,
+ NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch
);
TRuntimeNode TimeOrderRecover(
diff --git a/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp b/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp
index fb2a2a29c1..211fc6c84c 100644
--- a/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp
+++ b/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp
@@ -878,17 +878,17 @@ TMkqlCommonCallableCompiler::TShared::TShared() {
const auto& settings = node.Child(4);
//explore params
- const auto& measures = params->ChildRef(0);
- const auto& skipTo = params->ChildRef(2);
- const auto& pattern = params->ChildRef(3);
- const auto& defines = params->ChildRef(4);
+ const auto measures = params->Child(0);
+ const auto skipTo = params->Child(2);
+ const auto pattern = params->Child(3);
+ const auto defines = params->Child(4);
//explore measures
- const auto measureNames = measures->ChildRef(2);
+ const auto measureNames = measures->Child(2);
constexpr size_t FirstMeasureLambdaIndex = 3;
//explore defines
- const auto defineNames = defines->ChildRef(2);
+ const auto defineNames = defines->Child(2);
const size_t FirstDefineLambdaIndex = 3;
TVector<TStringBuf> partitionColumnNames;
@@ -900,25 +900,21 @@ TMkqlCommonCallableCompiler::TShared::TShared() {
return MkqlBuildLambda(*partitionKeySelector, ctx, {inputRowArg});
};
- TVector<std::pair<TStringBuf, TProgramBuilder::TTernaryLambda>> getDefines(defineNames->ChildrenSize());
+ TVector<TStringBuf> defineVarNames(defineNames->ChildrenSize());
+ TVector<TProgramBuilder::TTernaryLambda> getDefines(defineNames->ChildrenSize());
for (size_t i = 0; i != defineNames->ChildrenSize(); ++i) {
- getDefines[i] = std::pair{
- defineNames->ChildRef(i)->Content(),
- [i, defines, &ctx](TRuntimeNode data, TRuntimeNode matchedVars, TRuntimeNode rowIndex) {
- return MkqlBuildLambda(*defines->ChildRef(FirstDefineLambdaIndex + i), ctx,
- {data, matchedVars, rowIndex});
- }
+ defineVarNames[i] = defineNames->Child(i)->Content();
+ getDefines[i] = [i, defines, &ctx](TRuntimeNode data, TRuntimeNode matchedVars, TRuntimeNode rowIndex) {
+ return MkqlBuildLambda(*defines->Child(FirstDefineLambdaIndex + i), ctx, {data, matchedVars, rowIndex});
};
}
- TVector<std::pair<TStringBuf, TProgramBuilder::TBinaryLambda>> getMeasures(measureNames->ChildrenSize());
+ TVector<TStringBuf> measureColumnNames(measureNames->ChildrenSize());
+ TVector<TProgramBuilder::TBinaryLambda> getMeasures(measureNames->ChildrenSize());
for (size_t i = 0; i != measureNames->ChildrenSize(); ++i) {
- getMeasures[i] = std::pair{
- measureNames->ChildRef(i)->Content(),
- [i, measures, &ctx](TRuntimeNode data, TRuntimeNode matchedVars) {
- return MkqlBuildLambda(*measures->ChildRef(FirstMeasureLambdaIndex + i), ctx,
- {data, matchedVars});
- }
+ measureColumnNames[i] = measureNames->Child(i)->Content();
+ getMeasures[i] = [i, measures, &ctx](TRuntimeNode data, TRuntimeNode matchedVars) {
+ return MkqlBuildLambda(*measures->Child(FirstMeasureLambdaIndex + i), ctx, {data, matchedVars});
};
}
@@ -928,18 +924,26 @@ TMkqlCommonCallableCompiler::TShared::TShared() {
NYql::NMatchRecognize::EAfterMatchSkipTo to;
MKQL_ENSURE(TryFromString<NYql::NMatchRecognize::EAfterMatchSkipTo>(stringTo, to), "MATCH_RECOGNIZE: <row pattern skip to> cannot parse AfterMatchSkipTo mode");
+ auto rowsPerMatchString = params->Child(1)->Content();
+ MKQL_ENSURE(rowsPerMatchString.SkipPrefix("RowsPerMatch_"), R"(MATCH_RECOGNIZE: <row pattern rows per match> should start with "RowsPerMatch_")");
+ NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch;
+ MKQL_ENSURE(TryFromString<NYql::NMatchRecognize::ERowsPerMatch>(rowsPerMatchString, rowsPerMatch), "MATCH_RECOGNIZE: cannot parse RowsPerMatch mode");
+
const auto streamingMode = FromString<bool>(settings->Child(0)->Child(1)->Content());
return ctx.ProgramBuilder.MatchRecognizeCore(
MkqlBuildExpr(*inputStream, ctx),
getPartitionKeySelector,
partitionColumnNames,
+ measureColumnNames,
getMeasures,
NYql::NMatchRecognize::ConvertPattern(pattern, ctx.ExprCtx),
+ defineVarNames,
getDefines,
streamingMode,
- NYql::NMatchRecognize::TAfterMatchSkipTo{to, TString{var}}
- );
+ NYql::NMatchRecognize::TAfterMatchSkipTo{to, TString{var}},
+ rowsPerMatch
+ );
});
AddCallable("TimeOrderRecover", [](const TExprNode& node, TMkqlBuildContext& ctx) {
diff --git a/yql/essentials/sql/v1/SQLv1.g.in b/yql/essentials/sql/v1/SQLv1.g.in
index e9685c5094..670ad27e3e 100644
--- a/yql/essentials/sql/v1/SQLv1.g.in
+++ b/yql/essentials/sql/v1/SQLv1.g.in
@@ -460,7 +460,7 @@ row_pattern_primary:
| DOLLAR
| CARET
| LPAREN row_pattern? RPAREN
- | LBRACE_CURLY MINUS row_pattern MINUS RBRACE_CURLY //TODO This rule accepts spaces between brace and minus sign, i.e: { - S2 - } that is not supposed to. Handle this case in https://st.yandex-team.ru/YQL-16227
+ | LBRACE_CURLY MINUS row_pattern MINUS RBRACE_CURLY
| row_pattern_permute
;
diff --git a/yql/essentials/sql/v1/SQLv1Antlr4.g.in b/yql/essentials/sql/v1/SQLv1Antlr4.g.in
index 40593fe075..89131437e9 100644
--- a/yql/essentials/sql/v1/SQLv1Antlr4.g.in
+++ b/yql/essentials/sql/v1/SQLv1Antlr4.g.in
@@ -459,7 +459,7 @@ row_pattern_primary:
| DOLLAR
| CARET
| LPAREN row_pattern? RPAREN
- | LBRACE_CURLY MINUS row_pattern MINUS RBRACE_CURLY //TODO This rule accepts spaces between brace and minus sign, i.e: { - S2 - } that is not supposed to. Handle this case in https://st.yandex-team.ru/YQL-16227
+ | LBRACE_CURLY MINUS row_pattern MINUS RBRACE_CURLY
| row_pattern_permute
;
diff --git a/yql/essentials/sql/v1/match_recognize.cpp b/yql/essentials/sql/v1/match_recognize.cpp
index 47055e2f3d..84a20ae273 100644
--- a/yql/essentials/sql/v1/match_recognize.cpp
+++ b/yql/essentials/sql/v1/match_recognize.cpp
@@ -21,7 +21,7 @@ public:
std::pair<TPosition, TVector<TNamedFunction>>&& partitioners,
std::pair<TPosition, TVector<TSortSpecificationPtr>>&& sortSpecs,
std::pair<TPosition, TVector<TNamedFunction>>&& measures,
- std::pair<TPosition, ERowsPerMatch>&& rowsPerMatch,
+ std::pair<TPosition, NYql::NMatchRecognize::ERowsPerMatch>&& rowsPerMatch,
std::pair<TPosition, NYql::NMatchRecognize::TAfterMatchSkipTo>&& skipTo,
std::pair<TPosition, NYql::NMatchRecognize::TRowPattern>&& pattern,
std::pair<TPosition, TNodePtr>&& subset,
@@ -56,7 +56,7 @@ private:
std::pair<TPosition, TVector<TNamedFunction>>&& partitioners,
std::pair<TPosition, TVector<TSortSpecificationPtr>>&& sortSpecs,
std::pair<TPosition, TVector<TNamedFunction>>&& measures,
- std::pair<TPosition, ERowsPerMatch>&& rowsPerMatch,
+ std::pair<TPosition, NYql::NMatchRecognize::ERowsPerMatch>&& rowsPerMatch,
std::pair<TPosition, NYql::NMatchRecognize::TAfterMatchSkipTo>&& skipTo,
std::pair<TPosition, NYql::NMatchRecognize::TRowPattern>&& pattern,
std::pair<TPosition, TNodePtr>&& subset,
diff --git a/yql/essentials/sql/v1/match_recognize.h b/yql/essentials/sql/v1/match_recognize.h
index b78c0faf65..4b0e98b9b7 100644
--- a/yql/essentials/sql/v1/match_recognize.h
+++ b/yql/essentials/sql/v1/match_recognize.h
@@ -10,11 +10,6 @@ struct TNamedFunction {
TString name;
};
-enum class ERowsPerMatch {
- OneRow,
- AllRows
-};
-
class TMatchRecognizeBuilder: public TSimpleRefCount<TMatchRecognizeBuilder> {
public:
TMatchRecognizeBuilder(
@@ -22,7 +17,7 @@ public:
std::pair<TPosition, TVector<TNamedFunction>>&& partitioners,
std::pair<TPosition, TVector<TSortSpecificationPtr>>&& sortSpecs,
std::pair<TPosition, TVector<TNamedFunction>>&& measures,
- std::pair<TPosition, ERowsPerMatch>&& rowsPerMatch,
+ std::pair<TPosition, NYql::NMatchRecognize::ERowsPerMatch>&& rowsPerMatch,
std::pair<TPosition, NYql::NMatchRecognize::TAfterMatchSkipTo>&& skipTo,
std::pair<TPosition, NYql::NMatchRecognize::TRowPattern>&& pattern,
std::pair<TPosition, TNodePtr>&& subset,
@@ -45,7 +40,7 @@ private:
std::pair<TPosition, TVector<TNamedFunction>> Partitioners;
std::pair<TPosition, TVector<TSortSpecificationPtr>> SortSpecs;
std::pair<TPosition, TVector<TNamedFunction>> Measures;
- std::pair<TPosition, ERowsPerMatch> RowsPerMatch;
+ std::pair<TPosition, NYql::NMatchRecognize::ERowsPerMatch> RowsPerMatch;
std::pair<TPosition, NYql::NMatchRecognize::TAfterMatchSkipTo> SkipTo;
std::pair<TPosition, NYql::NMatchRecognize::TRowPattern> Pattern;
std::pair<TPosition, TNodePtr> Subset;
diff --git a/yql/essentials/sql/v1/sql_match_recognize.cpp b/yql/essentials/sql/v1/sql_match_recognize.cpp
index 47e001efbb..41415b7f23 100644
--- a/yql/essentials/sql/v1/sql_match_recognize.cpp
+++ b/yql/essentials/sql/v1/sql_match_recognize.cpp
@@ -53,15 +53,9 @@ TMatchRecognizeBuilderPtr TSqlMatchRecognizeClause::CreateBuilder(const NSQLv1Ge
measures = ParseMeasures(measuresClause.GetRule_row_pattern_measure_list2());
}
- TPosition rowsPerMatchPos = pos;
- ERowsPerMatch rowsPerMatch = ERowsPerMatch::OneRow;
+ auto rowsPerMatch = std::pair {pos, NYql::NMatchRecognize::ERowsPerMatch::OneRow};
if (matchRecognizeClause.HasBlock6()) {
- std::tie(rowsPerMatchPos, rowsPerMatch) = ParseRowsPerMatch(matchRecognizeClause.GetBlock6().GetRule_row_pattern_rows_per_match1());
- if (ERowsPerMatch::AllRows == rowsPerMatch) {
- //https://st.yandex-team.ru/YQL-16213
- Ctx.Error(pos, TIssuesIds::CORE) << "ALL ROWS PER MATCH is not supported yet";
- return {};
- }
+ rowsPerMatch = ParseRowsPerMatch(matchRecognizeClause.GetBlock6().GetRule_row_pattern_rows_per_match1());
}
const auto& commonSyntax = matchRecognizeClause.GetRule_row_pattern_common_syntax7();
@@ -126,7 +120,7 @@ TMatchRecognizeBuilderPtr TSqlMatchRecognizeClause::CreateBuilder(const NSQLv1Ge
std::pair{partitionsPos, std::move(partitioners)},
std::pair{orderByPos, std::move(sortSpecs)},
std::pair{measuresPos, measures},
- std::pair{rowsPerMatchPos, rowsPerMatch},
+ std::move(rowsPerMatch),
std::move(skipTo),
std::pair{patternPos, std::move(pattern)},
std::pair{subsetPos, std::move(subset)},
@@ -159,7 +153,6 @@ TNamedFunction TSqlMatchRecognizeClause::ParseOneMeasure(const TRule_row_pattern
TColumnRefScope scope(Ctx, EColumnRefState::MatchRecognize);
const auto& expr = TSqlExpression(Ctx, Mode).Build(node.GetRule_expr1());
const auto& name = Id(node.GetRule_an_id3(), *this);
- //TODO https://st.yandex-team.ru/YQL-16186
//Each measure must be a lambda, that accepts 2 args:
// - List<InputTableColumns + _yql_Classifier, _yql_MatchNumber>
// - Struct that maps row pattern variables to ranges in the queue
@@ -174,18 +167,18 @@ TVector<TNamedFunction> TSqlMatchRecognizeClause::ParseMeasures(const TRule_row_
return result;
}
-std::pair<TPosition, ERowsPerMatch> TSqlMatchRecognizeClause::ParseRowsPerMatch(const TRule_row_pattern_rows_per_match& rowsPerMatchClause) {
+std::pair<TPosition, NYql::NMatchRecognize::ERowsPerMatch> TSqlMatchRecognizeClause::ParseRowsPerMatch(const TRule_row_pattern_rows_per_match& rowsPerMatchClause) {
switch(rowsPerMatchClause.GetAltCase()) {
case TRule_row_pattern_rows_per_match::kAltRowPatternRowsPerMatch1:
return std::pair {
TokenPosition(rowsPerMatchClause.GetAlt_row_pattern_rows_per_match1().GetToken1()),
- ERowsPerMatch::OneRow
+ NYql::NMatchRecognize::ERowsPerMatch::OneRow
};
case TRule_row_pattern_rows_per_match::kAltRowPatternRowsPerMatch2:
return std::pair {
TokenPosition(rowsPerMatchClause.GetAlt_row_pattern_rows_per_match2().GetToken1()),
- ERowsPerMatch::AllRows
+ NYql::NMatchRecognize::ERowsPerMatch::AllRows
};
case TRule_row_pattern_rows_per_match::ALT_NOT_SET:
Y_ABORT("You should change implementation according to grammar changes");
@@ -233,13 +226,13 @@ std::pair<TPosition, NYql::NMatchRecognize::TAfterMatchSkipTo> TSqlMatchRecogniz
}
}
-NYql::NMatchRecognize::TRowPatternTerm TSqlMatchRecognizeClause::ParsePatternTerm(const TRule_row_pattern_term& node){
+NYql::NMatchRecognize::TRowPatternTerm TSqlMatchRecognizeClause::ParsePatternTerm(const TRule_row_pattern_term& node, size_t patternNestingLevel, bool outputArg) {
NYql::NMatchRecognize::TRowPatternTerm term;
TPosition pos;
for (const auto& factor: node.GetBlock1()) {
const auto& primaryVar = factor.GetRule_row_pattern_factor1().GetRule_row_pattern_primary1();
NYql::NMatchRecognize::TRowPatternPrimary primary;
- bool output = true;
+ bool output = outputArg;
switch (primaryVar.GetAltCase()) {
case TRule_row_pattern_primary::kAltRowPatternPrimary1:
primary = PatternVar(primaryVar.GetAlt_row_pattern_primary1().GetRule_row_pattern_primary_variable_name1().GetRule_row_pattern_variable_name1(), *this);
@@ -253,9 +246,8 @@ NYql::NMatchRecognize::TRowPatternTerm TSqlMatchRecognizeClause::ParsePatternTer
Y_ENSURE("^" == std::get<0>(primary));
break;
case TRule_row_pattern_primary::kAltRowPatternPrimary4: {
- if (++PatternNestingLevel <= NYql::NMatchRecognize::MaxPatternNesting) {
- primary = ParsePattern(primaryVar.GetAlt_row_pattern_primary4().GetBlock2().GetRule_row_pattern1());
- --PatternNestingLevel;
+ if (patternNestingLevel <= NYql::NMatchRecognize::MaxPatternNesting) {
+ primary = ParsePattern(primaryVar.GetAlt_row_pattern_primary4().GetBlock2().GetRule_row_pattern1(), patternNestingLevel + 1, output);
} else {
Ctx.Error(TokenPosition(primaryVar.GetAlt_row_pattern_primary4().GetToken1()))
<< "To big nesting level in the pattern";
@@ -265,15 +257,14 @@ NYql::NMatchRecognize::TRowPatternTerm TSqlMatchRecognizeClause::ParsePatternTer
}
case TRule_row_pattern_primary::kAltRowPatternPrimary5:
output = false;
- Ctx.Error(TokenPosition(primaryVar.GetAlt_row_pattern_primary4().GetToken1()))
- << "ALL ROWS PER MATCH and {- -} are not supported yet"; //https://st.yandex-team.ru/YQL-16227
+ primary = ParsePattern(primaryVar.GetAlt_row_pattern_primary5().GetRule_row_pattern3(), patternNestingLevel + 1, output);
break;
case TRule_row_pattern_primary::kAltRowPatternPrimary6: {
std::vector<NYql::NMatchRecognize::TRowPatternPrimary> items{ParsePattern(
- primaryVar.GetAlt_row_pattern_primary6().GetRule_row_pattern_permute1().GetRule_row_pattern3())
+ primaryVar.GetAlt_row_pattern_primary6().GetRule_row_pattern_permute1().GetRule_row_pattern3(), patternNestingLevel + 1, output)
};
for (const auto& p: primaryVar.GetAlt_row_pattern_primary6().GetRule_row_pattern_permute1().GetBlock4()) {
- items.push_back(ParsePattern(p.GetRule_row_pattern2()));
+ items.push_back(ParsePattern(p.GetRule_row_pattern2(), patternNestingLevel + 1, output));
}
//Permutations now is a syntactic sugar and converted to all possible alternatives
if (items.size() > NYql::NMatchRecognize::MaxPermutedItems) {
@@ -346,11 +337,11 @@ NYql::NMatchRecognize::TRowPatternTerm TSqlMatchRecognizeClause::ParsePatternTer
return term;
}
-NYql::NMatchRecognize::TRowPattern TSqlMatchRecognizeClause::ParsePattern(const TRule_row_pattern& node){
+NYql::NMatchRecognize::TRowPattern TSqlMatchRecognizeClause::ParsePattern(const TRule_row_pattern& node, size_t patternNestingLevel, bool output){
TVector<NYql::NMatchRecognize::TRowPatternTerm> result;
- result.push_back(ParsePatternTerm(node.GetRule_row_pattern_term1()));
+ result.push_back(ParsePatternTerm(node.GetRule_row_pattern_term1(), patternNestingLevel, output));
for (const auto& term: node.GetBlock2())
- result.push_back(ParsePatternTerm(term.GetRule_row_pattern_term2()));
+ result.push_back(ParsePatternTerm(term.GetRule_row_pattern_term2(), patternNestingLevel, output));
return result;
}
@@ -364,7 +355,6 @@ TNamedFunction TSqlMatchRecognizeClause::ParseOneDefinition(const TRule_row_patt
TVector<TNamedFunction> TSqlMatchRecognizeClause::ParseDefinitions(const TRule_row_pattern_definition_list& node) {
TVector<TNamedFunction> result { ParseOneDefinition(node.GetRule_row_pattern_definition1())};
for (const auto& d: node.GetBlock2()) {
- //TODO https://st.yandex-team.ru/YQL-16186
//Each define must be a predicate lambda, that accepts 3 args:
// - List<input table rows>
// - A struct that maps row pattern variables to ranges in the queue
diff --git a/yql/essentials/sql/v1/sql_match_recognize.h b/yql/essentials/sql/v1/sql_match_recognize.h
index 6766acc953..219baeaa09 100644
--- a/yql/essentials/sql/v1/sql_match_recognize.h
+++ b/yql/essentials/sql/v1/sql_match_recognize.h
@@ -17,14 +17,12 @@ private:
TVector<TNamedFunction> ParsePartitionBy(const TRule_window_partition_clause& partitionClause);
TNamedFunction ParseOneMeasure(const TRule_row_pattern_measure_definition& node);
TVector<TNamedFunction> ParseMeasures(const TRule_row_pattern_measure_list& node);
- std::pair<TPosition, ERowsPerMatch> ParseRowsPerMatch(const TRule_row_pattern_rows_per_match& rowsPerMatchClause);
+ std::pair<TPosition, NYql::NMatchRecognize::ERowsPerMatch> ParseRowsPerMatch(const TRule_row_pattern_rows_per_match& rowsPerMatchClause);
std::pair<TPosition, NYql::NMatchRecognize::TAfterMatchSkipTo> ParseAfterMatchSkipTo(const TRule_row_pattern_skip_to& skipToClause);
- NYql::NMatchRecognize::TRowPatternTerm ParsePatternTerm(const TRule_row_pattern_term& node);
- NYql::NMatchRecognize::TRowPattern ParsePattern(const TRule_row_pattern& node);
+ NYql::NMatchRecognize::TRowPatternTerm ParsePatternTerm(const TRule_row_pattern_term& node, size_t patternNestingLevel, bool output);
+ NYql::NMatchRecognize::TRowPattern ParsePattern(const TRule_row_pattern& node, size_t patternNestingLevel = 1, bool output = true);
TNamedFunction ParseOneDefinition(const TRule_row_pattern_definition& node);
TVector<TNamedFunction> ParseDefinitions(const TRule_row_pattern_definition_list& node);
-private:
- size_t PatternNestingLevel = 0;
};
} // namespace NSQLTranslationV1
diff --git a/yql/essentials/sql/v1/sql_match_recognize_ut.cpp b/yql/essentials/sql/v1/sql_match_recognize_ut.cpp
index 20c5e6ab7b..f591ef0647 100644
--- a/yql/essentials/sql/v1/sql_match_recognize_ut.cpp
+++ b/yql/essentials/sql/v1/sql_match_recognize_ut.cpp
@@ -183,7 +183,7 @@ FROM Input MATCH_RECOGNIZE(
)
)";
auto r = MatchRecognizeSqlToYql(stmt);
- UNIT_ASSERT(not r.IsOk()); ///https://st.yandex-team.ru/YQL-16213
+ UNIT_ASSERT(r.IsOk());
}
{ //default
const auto stmt = R"(
diff --git a/yql/essentials/tests/sql/sql2yql/canondata/result.json b/yql/essentials/tests/sql/sql2yql/canondata/result.json
index c4b7d336e6..0c46d1d209 100644
--- a/yql/essentials/tests/sql/sql2yql/canondata/result.json
+++ b/yql/essentials/tests/sql/sql2yql/canondata/result.json
@@ -11213,6 +11213,13 @@
"uri": "https://{canondata_backend}/1942173/99e88108149e222741552e7e6cddef041d6a2846/resource.tar.gz#test_sql2yql.test_match_recognize-alerts_without_order_/sql.yql"
}
],
+ "test_sql2yql.test[match_recognize-all_rows_per_match]": [
+ {
+ "checksum": "31a940b1bbcea146ae9383e278326c7a",
+ "size": 6729,
+ "uri": "https://{canondata_backend}/1889210/954e2f1656d98697ece5794c59acf75dd1d40612/resource.tar.gz#test_sql2yql.test_match_recognize-all_rows_per_match_/sql.yql"
+ }
+ ],
"test_sql2yql.test[match_recognize-greedy_quantifiers]": [
{
"checksum": "41e90a3a986f9b2a7a36a83b918667cc",
@@ -28001,6 +28008,11 @@
"uri": "file://test_sql_format.test_match_recognize-alerts_without_order_/formatted.sql"
}
],
+ "test_sql_format.test[match_recognize-all_rows_per_match]": [
+ {
+ "uri": "file://test_sql_format.test_match_recognize-all_rows_per_match_/formatted.sql"
+ }
+ ],
"test_sql_format.test[match_recognize-greedy_quantifiers]": [
{
"uri": "file://test_sql_format.test_match_recognize-greedy_quantifiers_/formatted.sql"
diff --git a/yql/essentials/tests/sql/sql2yql/canondata/test_sql_format.test_match_recognize-all_rows_per_match_/formatted.sql b/yql/essentials/tests/sql/sql2yql/canondata/test_sql_format.test_match_recognize-all_rows_per_match_/formatted.sql
new file mode 100644
index 0000000000..edcc854cfa
--- /dev/null
+++ b/yql/essentials/tests/sql/sql2yql/canondata/test_sql_format.test_match_recognize-all_rows_per_match_/formatted.sql
@@ -0,0 +1,52 @@
+PRAGMA FeatureR010 = "prototype";
+
+$input =
+ SELECT
+ *
+ FROM
+ AS_TABLE([
+ <|time: 0, value: 0|>,
+ <|time: 100, value: 1|>,
+ <|time: 200, value: 2|>,
+ <|time: 300, value: 3|>,
+ <|time: 400, value: 4|>,
+ <|time: 500, value: 5|>,
+ <|time: 600, value: 0|>,
+ <|time: 700, value: 1|>,
+ <|time: 800, value: 2|>,
+ <|time: 900, value: 3|>,
+ <|time: 1000, value: 4|>,
+ <|time: 1100, value: 5|>,
+ <|time: 1200, value: 0|>,
+ ])
+;
+
+SELECT
+ *
+FROM
+ $input MATCH_RECOGNIZE (
+ ORDER BY
+ CAST(time AS Timestamp)
+ MEASURES
+ FIRST(A.time) AS a_time,
+ FIRST(B.time) AS b_time,
+ LAST(C.time) AS c_time,
+ LAST(F.time) AS f_time
+ ALL ROWS PER MATCH
+ AFTER MATCH SKIP PAST LAST ROW
+ PATTERN (A B {- C -} D {- E -} F +)
+ DEFINE
+ A AS A.value == 0
+ AND COALESCE(A.time - FIRST(A.time) <= 1000, TRUE),
+ B AS B.value == 1
+ AND COALESCE(B.time - FIRST(A.time) <= 1000, TRUE),
+ C AS C.value == 2
+ AND COALESCE(C.time - FIRST(A.time) <= 1000, TRUE),
+ D AS D.value == 3
+ AND COALESCE(D.time - FIRST(A.time) <= 1000, TRUE),
+ E AS E.value == 4
+ AND COALESCE(E.time - FIRST(A.time) <= 1000, TRUE),
+ F AS F.value == 5
+ AND COALESCE(F.time - FIRST(A.time) <= 1000, TRUE)
+ )
+;
diff --git a/yql/essentials/tests/sql/suites/match_recognize/all_rows_per_match.sql b/yql/essentials/tests/sql/suites/match_recognize/all_rows_per_match.sql
new file mode 100644
index 0000000000..ce55e529fb
--- /dev/null
+++ b/yql/essentials/tests/sql/suites/match_recognize/all_rows_per_match.sql
@@ -0,0 +1,48 @@
+PRAGMA FeatureR010="prototype";
+
+$input = SELECT * FROM AS_TABLE([
+ <|time: 0, value: 0|>,
+ <|time: 100, value: 1|>,
+ <|time: 200, value: 2|>,
+ <|time: 300, value: 3|>,
+ <|time: 400, value: 4|>,
+ <|time: 500, value: 5|>,
+ <|time: 600, value: 0|>,
+ <|time: 700, value: 1|>,
+ <|time: 800, value: 2|>,
+ <|time: 900, value: 3|>,
+ <|time: 1000, value: 4|>,
+ <|time: 1100, value: 5|>,
+ <|time: 1200, value: 0|>,
+]);
+
+SELECT * FROM $input MATCH_RECOGNIZE(
+ ORDER BY CAST(time as Timestamp)
+ MEASURES
+ FIRST(A.time) AS a_time,
+ FIRST(B.time) AS b_time,
+ LAST(C.time) AS c_time,
+ LAST(F.time) AS f_time
+ ALL ROWS PER MATCH
+ AFTER MATCH SKIP PAST LAST ROW
+ PATTERN (A B {- C -} D {- E -} F+)
+ DEFINE
+ A AS
+ A.value = 0 AND
+ COALESCE(A.time - FIRST(A.time) <= 1000, TRUE),
+ B AS
+ B.value = 1 AND
+ COALESCE(B.time - FIRST(A.time) <= 1000, TRUE),
+ C AS
+ C.value = 2 AND
+ COALESCE(C.time - FIRST(A.time) <= 1000, TRUE),
+ D AS
+ D.value = 3 AND
+ COALESCE(D.time - FIRST(A.time) <= 1000, TRUE),
+ E AS
+ E.value = 4 AND
+ COALESCE(E.time - FIRST(A.time) <= 1000, TRUE),
+ F AS
+ F.value = 5 AND
+ COALESCE(F.time - FIRST(A.time) <= 1000, TRUE)
+);