diff options
Diffstat (limited to 'yql/essentials/core/yql_opt_match_recognize.cpp')
-rw-r--r-- | yql/essentials/core/yql_opt_match_recognize.cpp | 162 |
1 files changed, 79 insertions, 83 deletions
diff --git a/yql/essentials/core/yql_opt_match_recognize.cpp b/yql/essentials/core/yql_opt_match_recognize.cpp index 30544f8333..4b6f6c7a80 100644 --- a/yql/essentials/core/yql_opt_match_recognize.cpp +++ b/yql/essentials/core/yql_opt_match_recognize.cpp @@ -32,95 +32,88 @@ bool IsStreaming(const TExprNode::TPtr& input, const TTypeAnnotationContext& typ } } -TExprNode::TPtr ExpandMatchRecognizeMeasuresAggregates(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& /* typeAnnCtx */) { - const auto pos = node->Pos(); - const auto vars = node->Child(3); - static constexpr size_t AggregatesLambdasStartPos = 4; +TExprNode::TPtr ExpandMatchRecognizeMeasuresCallables(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& /* typeAnnCtx */) { + YQL_CLOG(DEBUG, Core) << "Expand " << node->Content(); static constexpr size_t MeasuresLambdasStartPos = 3; - - return ctx.Builder(pos) + return ctx.Builder(node->Pos()) .Callable("MatchRecognizeMeasures") .Add(0, node->ChildPtr(0)) .Add(1, node->ChildPtr(1)) .Add(2, node->ChildPtr(2)) .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { - for (size_t i = 0; i < vars->ChildrenSize(); ++i) { - const auto var = vars->Child(i)->Content(); - const auto handler = node->ChildPtr(AggregatesLambdasStartPos + i); - if (!var) { - auto value = handler->GetTypeAnn()->GetKind() != ETypeAnnotationKind::Optional - ? ctx.Builder(pos).Callable("Just").Add(0, handler).Seal().Build() - : handler; - parent.Add( - MeasuresLambdasStartPos + i, - ctx.Builder(pos) - .Lambda() - .Param("data") - .Param("vars") - .Add(0, std::move(value)) - .Seal() - .Build() - ); - continue; - } - parent.Add( - MeasuresLambdasStartPos + i, - ctx.Builder(pos) - .Lambda() - .Param("data") - .Param("vars") - .Callable(0, "Member") - .Callable(0, "Head") - .Callable(0, "Aggregate") - .Callable(0, "OrderedMap") - .Callable(0, "OrderedFlatMap") - .Callable(0, "Member") - .Arg(0, "vars") - .Atom(1, var) - .Seal() - .Lambda(1) - .Param("item") - .Callable(0, "ListFromRange") - .Callable(0, "Member") - .Arg(0, "item") - .Atom(1, "From") - .Seal() - .Callable(1, "+MayWarn") - .Callable(0, "Member") - .Arg(0, "item") - .Atom(1, "To") + const auto aggregatesItems = node->Child(3); + for (size_t i = 0; i < aggregatesItems->ChildrenSize(); ++i) { + const auto item = aggregatesItems->Child(i); + auto lambda = item->ChildPtr(0); + const auto vars = item->Child(1); + const auto aggregates = item->Child(2); + parent.Lambda(MeasuresLambdasStartPos + i, lambda->Pos()) + .Param("data") + .Param("vars") + .Apply(std::move(lambda)) + .With(0) + .Callable("FlattenMembers") + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + for (size_t i = 0; i < aggregates->ChildrenSize(); ++i) { + const auto var = vars->Child(i)->Content(); + auto aggregate = aggregates->Child(i); + parent + .List(i) + .Atom(0, "") + .Callable(1, "Head") + .Callable(0, "Aggregate") + .Callable(0, "OrderedMap") + .Callable(0, "OrderedFlatMap") + .Callable(0, "Member") + .Arg(0, "vars") + .Atom(1, var) + .Seal() + .Lambda(1) + .Param("item") + .Callable(0, "ListFromRange") + .Callable(0, "Member") + .Arg(0, "item") + .Atom(1, "From") + .Seal() + .Callable(1, "+MayWarn") + .Callable(0, "Member") + .Arg(0, "item") + .Atom(1, "To") + .Seal() + .Callable(1, "Uint64") + .Atom(0, "1") + .Seal() + .Seal() + .Seal() + .Seal() .Seal() - .Callable(1, "Uint64") - .Atom(0, "1") + .Lambda(1) + .Param("index") + .Callable(0, "Unwrap") + .Callable(0, "Lookup") + .Callable(0, "ToIndexDict") + .Arg(0, "data") + .Seal() + .Arg(1, "index") + .Seal() + .Seal() .Seal() .Seal() - .Seal() - .Seal() - .Seal() - .Lambda(1) - .Param("index") - .Callable(0, "Unwrap") - .Callable(0, "Lookup") - .Callable(0, "ToIndexDict") - .Arg(0, "data") + .List(1).Seal() + .List(2) + .Add(0, std::move(aggregate)) .Seal() - .Arg(1, "index") + .List(3).Seal() .Seal() .Seal() - .Seal() - .Seal() - .List(1).Seal() - .List(2) - .Add(0, handler) - .Seal() - .List(3).Seal() - .Seal() - .Seal() - .Atom(1, handler->Child(0)->Content()) + .Seal(); + } + return parent; + }) .Seal() - .Seal() - .Build() - ); + .Done() + .Seal() + .Seal(); } return parent; }) @@ -128,13 +121,16 @@ TExprNode::TPtr ExpandMatchRecognizeMeasuresAggregates(const TExprNode::TPtr& no .Build(); } -THashSet<TStringBuf> FindUsedVars(const TExprNode::TPtr& params) { - THashSet<TStringBuf> result; +std::unordered_set<std::string_view> FindUsedVars(const TExprNode::TPtr& params) { + std::unordered_set<std::string_view> result; const auto measures = params->Child(0); - const auto measuresVars = measures->Child(3); - for (const auto& var : measuresVars->Children()) { - result.insert(var->Content()); + const auto callablesItems = measures->Child(3); + for (const auto& item : callablesItems->Children()) { + const auto vars = item->Child(1); + for (const auto& var : vars->Children()) { + result.insert(var->Content()); + } } const auto defines = params->Child(4); @@ -159,7 +155,7 @@ THashSet<TStringBuf> FindUsedVars(const TExprNode::TPtr& params) { return result; } -TExprNode::TPtr MarkUnusedPatternVars(const TExprNode::TPtr& node, TExprContext& ctx, const THashSet<TStringBuf>& usedVars, const TExprNode::TPtr& rowsPerMatch) { +TExprNode::TPtr MarkUnusedPatternVars(const TExprNode::TPtr& node, TExprContext& ctx, const std::unordered_set<std::string_view>& usedVars, const TExprNode::TPtr& rowsPerMatch) { const auto pos = node->Pos(); if (node->ChildrenSize() == 6 && node->Child(0)->IsAtom()) { const auto varName = node->Child(0)->Content(); @@ -262,7 +258,7 @@ TExprNode::TPtr ExpandMatchRecognize(const TExprNode::TPtr& node, TExprContext& return {}; } - auto measures = ExpandMatchRecognizeMeasuresAggregates(params->ChildPtr(0), ctx, typeAnnCtx); + auto measures = ExpandMatchRecognizeMeasuresCallables(params->ChildPtr(0), ctx, typeAnnCtx); auto rowsPerMatch = params->ChildPtr(1); const auto usedVars = FindUsedVars(params); auto pattern = MarkUnusedPatternVars(params->ChildPtr(3), ctx, usedVars, rowsPerMatch); |