diff options
| author | vvvv <[email protected]> | 2022-08-30 17:04:46 +0300 |
|---|---|---|
| committer | vvvv <[email protected]> | 2022-08-30 17:04:46 +0300 |
| commit | be4e3940d724b3f7fbdc9eacb063817fd2fe1fb7 (patch) | |
| tree | 66bb16153f72c3762085fb34648321f45b34db2f | |
| parent | b7d7e31aed459810973844be026ce24aa309a0f3 (diff) | |
support of aggregations in window definition and functions
| -rw-r--r-- | ydb/library/yql/core/common_opt/yql_co_pgselect.cpp | 446 | ||||
| -rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_pg.cpp | 424 | ||||
| -rw-r--r-- | ydb/library/yql/sql/pg/pg_sql.cpp | 3 |
3 files changed, 485 insertions, 388 deletions
diff --git a/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp b/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp index a523537f962..51decdd9d72 100644 --- a/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp +++ b/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp @@ -8,10 +8,10 @@ namespace NYql { -TNodeMap<ui32> GatherSubLinks(const TExprNode::TPtr& lambda) { +TNodeMap<ui32> GatherSubLinks(const TExprNode::TPtr& root) { TNodeMap<ui32> subLinks; - VisitExpr(lambda->TailPtr(), [&](const TExprNode::TPtr& node) { + VisitExpr(root, [&](const TExprNode::TPtr& node) { if (node->IsCallable("PgSubLink")) { subLinks[node.Get()] = subLinks.size(); return false; @@ -147,14 +147,26 @@ TExprNode::TPtr JoinColumns(TPositionHandle pos, const TExprNode::TPtr& list1, c using TAggregationMap = TNodeMap<std::pair<ui32, bool>>; // uid + sublink test expression using TAggs = TVector<std::pair<TExprNode::TPtr, TExprNode::TPtr>>; -void RewriteAggs(TExprNode::TPtr& lambda, const TAggregationMap& aggId, TExprContext& ctx, TOptimizeContext& optCtx, bool testExpr) { - auto status = OptimizeExpr(lambda, lambda, [&](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr { +void RewriteAggs(TExprNode::TPtr& lambda, const TAggregationMap& aggId, TExprContext& ctx, TOptimizeContext& optCtx, bool testExpr); + +void RewriteAggsPartial(TExprNode::TPtr& root, const TExprNode::TPtr& arg, const TAggregationMap& aggId, TExprContext& ctx, TOptimizeContext& optCtx, bool testExpr) { + auto subLinks = GatherSubLinks(root); + auto status = OptimizeExpr(root, root, [&](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr { + auto subIt = subLinks.find(node.Get()); + if (subIt != subLinks.end()) { + if (!node->Child(3)->IsCallable("Void")) { + auto lambda = node->ChildPtr(3); + RewriteAggs(lambda, aggId, ctx, optCtx, true); + return ctx.ChangeChild(*node, 3, std::move(lambda)); + } + } + auto it = aggId.find(node.Get()); if (it != aggId.end() && it->second.second == testExpr) { auto ret = ctx.Builder(node->Pos()) .Callable("Member") - .Add(0, lambda->Head().HeadPtr()) - .Atom(1, "_yql_agg_" + ToString(it->second.first)) + .Add(0, arg) + .Atom(1, "_yql_agg_" + ToString(it->second.first)) .Seal() .Build(); @@ -165,43 +177,26 @@ void RewriteAggs(TExprNode::TPtr& lambda, const TAggregationMap& aggId, TExprCon }, ctx, TOptimizeExprSettings(optCtx.Types)); YQL_ENSURE(status.Level != IGraphTransformer::TStatus::Error); +} + +void RewriteAggs(TExprNode::TPtr& lambda, const TAggregationMap& aggId, TExprContext& ctx, TOptimizeContext& optCtx, bool testExpr) { + RewriteAggsPartial(lambda, lambda->Head().HeadPtr(), aggId, ctx, optCtx, testExpr); }; -std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos, - const TExprNode::TPtr& list, const TExprNode::TPtr& lambda, +std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinksPartial(TPositionHandle pos, + const TExprNode::TPtr& list, const TExprNode::TPtr& root, const TExprNode::TPtr& originalArg, const TNodeMap<ui32>& subLinks, const TVector<TString>& inputAliases, const TExprNode::TListType& cleanedInputs, - const TAggregationMap* aggId, TExprContext& ctx, TOptimizeContext& optCtx, + TExprContext& ctx, TOptimizeContext& optCtx, const TString& leftPrefix = {}, TVector<TString>* sublinkColumns = nullptr) { + auto newRoot = root; auto newList = list; - auto originalRow = lambda->Head().HeadPtr(); - auto arg = ctx.NewArgument(pos, "row"); - auto arguments = ctx.NewArguments(pos, { arg }); - auto root = lambda->TailPtr(); TNodeOnNodeOwnedMap deepClones; - auto status = OptimizeExpr(root, root, [&](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr { - if (aggId) { - auto it = aggId->find(node.Get()); - if (it != aggId->end() && !it->second.second) { - auto ret = ctx.Builder(node->Pos()) - .Callable("Member") - .Add(0, arg) - .Atom(1, "_yql_agg_" + ToString(it->second.first)) - .Seal() - .Build(); - - return ret; - } - } - + auto status = OptimizeExpr(newRoot, newRoot, [&](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr { auto it = subLinks.find(node.Get()); if (it != subLinks.end()) { auto linkType = node->Head().Content(); auto testLambda = node->ChildPtr(3); - if (aggId && (linkType == "any" || linkType == "all")) { - RewriteAggs(testLambda, *aggId, ctx, optCtx, true); - } - auto extColumns = NTypeAnnImpl::ExtractExternalColumns(node->Tail()); if (extColumns.empty()) { auto select = ExpandPgSelectSublink(node->TailPtr(), ctx, optCtx, it->second, cleanedInputs, inputAliases); @@ -261,7 +256,7 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos, if (useIn) { auto value = ctx.ReplaceNodes(testLambda->Tail().ChildPtr(2), { - {testLambda->Head().Child(0), originalRow} + {testLambda->Head().Child(0), originalArg} }); return ctx.Builder(node->Pos()) @@ -292,7 +287,7 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos, .Build(); auto foldExpr = ctx.ReplaceNodes(testLambda->TailPtr(), { - {testLambda->Head().Child(0), originalRow}, + {testLambda->Head().Child(0), originalArg}, {testLambda->Head().Child(1), value}, }); @@ -544,7 +539,7 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos, .Callable(0, ">") .Callable(0, "Coalesce") .Callable(0, "Member") - .Add(0, originalRow) + .Add(0, originalArg) .Atom(1, columnName) .Seal() .Callable(1, "Uint64") @@ -561,13 +556,13 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos, return ctx.Builder(node->Pos()) .Callable("Ensure") .Callable(0, "Member") - .Add(0, originalRow) + .Add(0, originalArg) .Atom(1, columnName + "_value") .Seal() .Callable(1, "<=") .Callable(0, "Coalesce") .Callable(0, "Member") - .Add(0, originalRow) + .Add(0, originalArg) .Atom(1, columnName + "_count") .Seal() .Callable(1, "Uint64") @@ -590,7 +585,7 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos, .Callable(0, "!=") .Callable(0, "Coalesce") .Callable(0, "Member") - .Add(0, originalRow) + .Add(0, originalArg) .Atom(1, columnName + "_count") .Seal() .Callable(1, "Uint64") @@ -602,7 +597,7 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos, .Seal() .Seal() .Callable(1, "Member") - .Add(0, originalRow) + .Add(0, originalArg) .Atom(1, columnName + "_value") .Seal() .Seal() @@ -616,7 +611,7 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos, .Callable(0, "==") .Callable(0, "Coalesce") .Callable(0, "Member") - .Add(0, originalRow) + .Add(0, originalArg) .Atom(1, columnName + "_count") .Seal() .Callable(1, "Uint64") @@ -628,7 +623,7 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos, .Seal() .Seal() .Callable(1, "Member") - .Add(0, originalRow) + .Add(0, originalArg) .Atom(1, columnName + "_value") .Seal() .Seal() @@ -644,12 +639,33 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos, }, ctx, TOptimizeExprSettings(optCtx.Types)); YQL_ENSURE(status.Level != IGraphTransformer::TStatus::Error); - root = ctx.ReplaceNode(std::move(root), *originalRow, arg); - auto newLambda = ctx.NewLambda(pos, std::move(arguments), std::move(root)); + return { + newRoot, + newList + }; +} + +std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos, + const TExprNode::TPtr& list, const TExprNode::TPtr& lambda, + const TNodeMap<ui32>& subLinks, const TVector<TString>& inputAliases, + const TExprNode::TListType& cleanedInputs, + TExprContext& ctx, TOptimizeContext& optCtx, + const TString& leftPrefix = {}, TVector<TString>* sublinkColumns = nullptr) { + + auto arg = ctx.NewArgument(pos, "row"); + auto arguments = ctx.NewArguments(pos, { arg }); + auto root = lambda->TailPtr(); + + TExprNode::TPtr newList, newRoot; + std::tie(newRoot, newList) = RewriteSubLinksPartial(pos, list, root, lambda->Head().HeadPtr(), subLinks, inputAliases, + cleanedInputs, ctx, optCtx, leftPrefix, sublinkColumns); + + newRoot = ctx.ReplaceNode(std::move(newRoot), lambda->Head().Head(), arg); + auto newLambda = ctx.NewLambda(pos, std::move(arguments), std::move(newRoot)); return { - newList, - newLambda + newLambda, + newList }; } @@ -658,8 +674,8 @@ TExprNode::TPtr BuildFilter(TPositionHandle pos, const TExprNode::TPtr& list, co TExprNode::TPtr actualList = list, actualFilter = filter; auto subLinks = GatherSubLinks(filter); if (!subLinks.empty()) { - std::tie(actualList, actualFilter) = RewriteSubLinks(filter->Pos(), list, filter, - subLinks, inputAliases, cleanedInputs, nullptr, ctx, optCtx); + std::tie(actualFilter, actualList) = RewriteSubLinks(filter->Pos(), list, filter, + subLinks, inputAliases, cleanedInputs, ctx, optCtx); } return ctx.Builder(pos) @@ -793,8 +809,8 @@ struct TWindowsCtx { TNodeMap<ui32> FuncsId; }; -void GatherUsedWindows(const TExprNode::TPtr& window, TExprNode::TPtr& projectionLambda, TWindowsCtx& winCtx) { - VisitExpr(projectionLambda->TailPtr(), [&](const TExprNode::TPtr& node) { +void GatherUsedWindows(const TExprNode::TPtr& window, const TExprNode::TPtr& projectionLambda, TWindowsCtx& winCtx) { + VisitExpr(projectionLambda, [&](const TExprNode::TPtr& node) { if (node->IsCallable("PgWindowCall") || node->IsCallable("PgAggWindowCall")) { YQL_ENSURE(window); ui32 windowIndex; @@ -1644,7 +1660,7 @@ void GatherAggregationsFromLambda(const TExprNode::TPtr& lambda, TAggs& aggs, TA TExprNode::TPtr BuildAggregationTraits(TPositionHandle pos, bool onWindow, const TString& distinctColumnName, const std::pair<TExprNode::TPtr, TExprNode::TPtr>& agg, - const TExprNode::TPtr& listTypeNode, TExprContext& ctx) { + const TExprNode::TPtr& listTypeNode, const TAggregationMap* aggId, TExprContext& ctx, TOptimizeContext& optCtx) { auto func = agg.first->Head().Content(); TExprNode::TPtr type = ctx.Builder(pos) .Callable("ListItemType") @@ -1672,7 +1688,12 @@ TExprNode::TPtr BuildAggregationTraits(TPositionHandle pos, bool onWindow, const auto arguments = ctx.NewArguments(pos, { arg }); TExprNode::TListType aggFuncArgs; for (ui32 j = onWindow ? 3 : 2; j < agg.first->ChildrenSize(); ++j) { - aggFuncArgs.push_back(ctx.ReplaceNode(agg.first->ChildPtr(j), *agg.second, arg)); + auto root = agg.first->ChildPtr(j); + if (aggId && onWindow) { + RewriteAggsPartial(root, arg, *aggId, ctx, optCtx, false); + } + + aggFuncArgs.push_back(ctx.ReplaceNode(std::move(root), *agg.second, arg)); } extractor = ctx.NewLambda(pos, std::move(arguments), std::move(aggFuncArgs)); @@ -1689,7 +1710,7 @@ TExprNode::TPtr BuildAggregationTraits(TPositionHandle pos, bool onWindow, const TExprNode::TPtr BuildGroup(TPositionHandle pos, TExprNode::TPtr list, const TAggs& aggs, const TExprNode::TPtr& groupBy, - const TExprNode::TPtr& finalExtTypes, TExprContext& ctx) { + const TExprNode::TPtr& finalExtTypes, TExprContext& ctx, TOptimizeContext& optCtx) { bool needRemapForDistinct = false; for (ui32 i = 0; i < aggs.size(); ++i) { @@ -1755,7 +1776,7 @@ TExprNode::TPtr BuildGroup(TPositionHandle pos, TExprNode::TPtr list, TExprNode::TListType payloadItems; for (ui32 i = 0; i < aggs.size(); ++i) { const bool distinct = GetSetting(*aggs[i].first->Child(1), "distinct") != nullptr; - auto traits = BuildAggregationTraits(pos, false, distinct ? "_yql_distinct_" + ToString(i) : "", aggs[i], listTypeNode, ctx); + auto traits = BuildAggregationTraits(pos, false, distinct ? "_yql_distinct_" + ToString(i) : "", aggs[i], listTypeNode, nullptr, ctx, optCtx); if (distinct) { payloadItems.push_back(ctx.Builder(pos) .List() @@ -1829,12 +1850,12 @@ TExprNode::TPtr BuildGroup(TPositionHandle pos, TExprNode::TPtr list, .Build(); } - auto keys = ctx.NewList(pos, std::move(keysItems)); + auto keysNode = ctx.NewList(pos, std::move(keysItems)); return ctx.Builder(pos) .Callable("Aggregate") .Add(0, list) - .Add(1, keys) + .Add(1, keysNode) .Add(2, payloadsNode) .List(3) // options .Seal() @@ -1846,12 +1867,12 @@ TExprNode::TPtr BuildHaving(TPositionHandle pos, TExprNode::TPtr list, const TEx const TAggregationMap& aggId, const TVector<TString>& inputAliases, const TExprNode::TListType& cleanedInputs, TExprContext& ctx, TOptimizeContext& optCtx) { auto havingLambda = having->TailPtr(); - auto havingSubLinks = GatherSubLinks(havingLambda); + auto havingLambdaRoot = havingLambda->TailPtr(); + RewriteAggsPartial(havingLambdaRoot, havingLambda->Head().HeadPtr(), aggId, ctx, optCtx, false); + auto havingSubLinks = GatherSubLinks(havingLambdaRoot); if (!havingSubLinks.empty()) { - std::tie(list, havingLambda) = RewriteSubLinks(havingLambda->Pos(), list, havingLambda, - havingSubLinks, inputAliases, cleanedInputs, &aggId, ctx, optCtx); - } else { - RewriteAggs(havingLambda, aggId, ctx, optCtx, false); + std::tie(havingLambdaRoot, list) = RewriteSubLinksPartial(havingLambda->Pos(), list, havingLambdaRoot, + havingLambda->Head().HeadPtr(), havingSubLinks, inputAliases, cleanedInputs, ctx, optCtx); } return ctx.Builder(pos) @@ -1861,7 +1882,7 @@ TExprNode::TPtr BuildHaving(TPositionHandle pos, TExprNode::TPtr list, const TEx .Param("row") .Callable("Coalesce") .Callable(0, "FromPg") - .Apply(0, havingLambda) + .ApplyPartial(0, havingLambda->HeadPtr(), havingLambdaRoot) .With(0, "row") .Seal() .Seal() @@ -1914,8 +1935,14 @@ std::tuple<TExprNode::TPtr, TExprNode::TPtr> BuildFrame(TPositionHandle pos, con return { begin, end }; } -TExprNode::TPtr BuildSortTraits(TPositionHandle pos, const TExprNode& sortColumns, const TExprNode::TPtr& list, TExprContext& ctx) { +TExprNode::TPtr BuildSortTraits(TPositionHandle pos, const TExprNode& sortColumns, const TExprNode::TPtr& list, + const TAggregationMap* aggId, TExprContext& ctx, TOptimizeContext& optCtx) { if (sortColumns.ChildrenSize() == 1) { + auto lambda = sortColumns.Head().ChildPtr(1); + if (aggId) { + RewriteAggs(lambda, *aggId, ctx, optCtx, false); + } + return ctx.Builder(pos) .Callable("SortTraits") .Callable(0, "TypeOf") @@ -1926,7 +1953,7 @@ TExprNode::TPtr BuildSortTraits(TPositionHandle pos, const TExprNode& sortColumn .Seal() .Lambda(2) .Param("row") - .Apply(sortColumns.Head().ChildPtr(1)) + .Apply(lambda) .With(0, "row") .Seal() .Seal() @@ -1953,7 +1980,12 @@ TExprNode::TPtr BuildSortTraits(TPositionHandle pos, const TExprNode& sortColumn .List() .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { for (ui32 i = 0; i < sortColumns.ChildrenSize(); ++i) { - parent.Apply(i, sortColumns.Child(i)->ChildPtr(1)) + auto lambda = sortColumns.Child(i)->ChildPtr(1); + if (aggId) { + RewriteAggs(lambda, *aggId, ctx, optCtx, false); + } + + parent.Apply(i, lambda) .With(0, "row") .Seal(); } @@ -1968,7 +2000,7 @@ TExprNode::TPtr BuildSortTraits(TPositionHandle pos, const TExprNode& sortColumn } TExprNode::TPtr BuildWindows(TPositionHandle pos, const TExprNode::TPtr& list, const TExprNode::TPtr& window, const TWindowsCtx& winCtx, - TExprNode::TPtr& projectionLambda, TExprContext& ctx, TOptimizeContext& optCtx) { + TExprNode::TPtr& projectionRoot, const TExprNode::TPtr& projectionArg, const TAggregationMap& aggId, TExprContext& ctx, TOptimizeContext& optCtx) { auto ret = list; auto listTypeNode = ctx.Builder(pos) .Callable("TypeOf") @@ -1980,19 +2012,57 @@ TExprNode::TPtr BuildWindows(TPositionHandle pos, const TExprNode::TPtr& list, c auto winDef = window->Tail().Child(x.first); const auto& frameSettings = winDef->Tail(); - TExprNode::TListType keys; - for (auto p : winDef->Child(2)->Children()) { - YQL_ENSURE(p->IsCallable("PgGroup")); - const auto& member = p->Tail().Tail(); - YQL_ENSURE(member.IsCallable("Member")); - keys.push_back(member.TailPtr()); + TExprNode::TListType keysItems; + if (winDef->Child(2)->ChildrenSize()) { + auto arg = ctx.NewArgument(pos, "row"); + auto arguments = ctx.NewArguments(pos, { arg }); + + TExprNode::TListType newColumns; + for (ui32 i = 0; i < winDef->Child(2)->ChildrenSize(); ++i) { + const auto& group = winDef->Child(2)->Child(i); + auto lambda = group->TailPtr(); + RewriteAggs(lambda, aggId, ctx, optCtx, false); + auto name = "_yql_partition_key_" + ToString(x.first) + "_" + ToString(i); + keysItems.push_back(ctx.NewAtom(pos, name)); + newColumns.push_back(ctx.Builder(pos) + .List() + .Atom(0, name) + .Apply(1, *lambda) + .With(0, arg) + .Seal() + .Seal() + .Build()); + } + + auto newColumnsNode = ctx.NewCallable(pos, "AsStruct", std::move(newColumns)); + auto root = ctx.Builder(pos) + .Callable("FlattenMembers") + .List(0) + .Atom(0, "") + .Add(1, arg) + .Seal() + .List(1) + .Atom(0, "") + .Add(1, newColumnsNode) + .Seal() + .Seal() + .Build(); + + auto keyExprsLambda = ctx.NewLambda(pos, std::move(arguments), std::move(root)); + + ret = ctx.Builder(pos) + .Callable("OrderedMap") + .Add(0, ret) + .Add(1, keyExprsLambda) + .Seal() + .Build(); } - auto keysNode = ctx.NewList(pos, std::move(keys)); + auto keysNode = ctx.NewList(pos, std::move(keysItems)); auto sortNode = ctx.NewCallable(pos, "Void", {}); TExprNode::TPtr keyLambda; if (winDef->Child(3)->ChildrenSize() > 0) { - sortNode = BuildSortTraits(pos, *winDef->Child(3), ret, ctx); + sortNode = BuildSortTraits(pos, *winDef->Child(3), ret, &aggId, ctx, optCtx); keyLambda = sortNode->TailPtr(); } else { keyLambda = ctx.Builder(pos) @@ -2004,15 +2074,30 @@ TExprNode::TPtr BuildWindows(TPositionHandle pos, const TExprNode::TPtr& list, c .Build(); } - TExprNode::TListType args; - // default frame - auto begin = ctx.NewCallable(pos, "Void", {}); - auto end = winDef->Child(3)->ChildrenSize() > 0 ? - ctx.NewCallable(pos, "Int32", { ctx.NewAtom(pos, "0") }) : - ctx.NewCallable(pos, "Void", {}); + TExprNode::TPtr begin, end; + bool useRange = false; if (HasSetting(frameSettings, "type")) { std::tie(begin, end) = BuildFrame(pos, frameSettings, ctx); + } else { + // default frame + if (winDef->Child(3)->ChildrenSize() > 0) { + useRange = true; + begin = ctx.Builder(pos) + .List() + .Atom(0, "preceding") + .Atom(1, "unbounded") + .Seal() + .Build(); + end = ctx.Builder(pos) + .List() + .Atom(0, "currentRow") + .Seal() + .Build(); + } else { + begin = ctx.NewCallable(pos, "Void", {}); + end = begin; + } } args.push_back(ctx.Builder(pos) @@ -2034,7 +2119,7 @@ TExprNode::TPtr BuildWindows(TPositionHandle pos, const TExprNode::TPtr& list, c bool isAgg = p.first->IsCallable("PgAggWindowCall"); TExprNode::TPtr value; if (isAgg) { - value = BuildAggregationTraits(pos, true, "", p, listTypeNode, ctx); + value = BuildAggregationTraits(pos, true, "", p, listTypeNode, &aggId, ctx, optCtx); } else { if (name == "row_number") { value = ctx.Builder(pos) @@ -2061,8 +2146,10 @@ TExprNode::TPtr BuildWindows(TPositionHandle pos, const TExprNode::TPtr& list, c } else if (name == "lead" || name == "lag") { auto arg = ctx.NewArgument(pos, "row"); auto arguments = ctx.NewArguments(pos, { arg }); + auto root = p.first->TailPtr(); + RewriteAggsPartial(root, arg, aggId, ctx, optCtx, false); auto extractor = ctx.NewLambda(pos, std::move(arguments), - ctx.ReplaceNode(p.first->TailPtr(), *p.second, arg)); + ctx.ReplaceNode(std::move(root), *p.second, arg)); value = ctx.Builder(pos) .Callable(name == "lead" ? "Lead" : "Lag") @@ -2085,7 +2172,7 @@ TExprNode::TPtr BuildWindows(TPositionHandle pos, const TExprNode::TPtr& list, c .Build()); } - auto winOnRows = ctx.NewCallable(pos, "WinOnRows", std::move(args)); + auto winOnRows = ctx.NewCallable(pos, useRange ? "WinOnRange" : "WinOnRows", std::move(args)); auto frames = ctx.Builder(pos) .List() @@ -2103,12 +2190,12 @@ TExprNode::TPtr BuildWindows(TPositionHandle pos, const TExprNode::TPtr& list, c .Build(); } - auto status = OptimizeExpr(projectionLambda, projectionLambda, [&](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr { + auto status = OptimizeExpr(projectionRoot, projectionRoot, [&](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr { auto it = winCtx.FuncsId.find(node.Get()); if (it != winCtx.FuncsId.end()) { auto ret = ctx.Builder(pos) .Callable("Member") - .Add(0, projectionLambda->Head().HeadPtr()) + .Add(0, projectionArg) .Atom(1, "_yql_win_" + ToString(it->second)) .Seal() .Build(); @@ -2149,7 +2236,8 @@ TExprNode::TPtr BuildSortLambda(TPositionHandle pos, const TExprNode::TPtr& sort return lambda; } -TExprNode::TPtr BuildSort(TPositionHandle pos, const TExprNode::TPtr& sort, const TExprNode::TPtr& list, const TExprNode::TPtr& sortLambda, TExprContext& ctx) { +TExprNode::TPtr BuildSort(TPositionHandle pos, const TExprNode::TPtr& sort, const TExprNode::TPtr& list, + const TExprNode::TPtr& sortLambdaRoot, const TExprNode::TPtr& sortLambdaArgs, TExprContext& ctx) { const auto& keys = sort->Tail(); TExprNode::TListType dirItems; @@ -2167,13 +2255,18 @@ TExprNode::TPtr BuildSort(TPositionHandle pos, const TExprNode::TPtr& sort, cons .Callable("Sort") .Add(0, list) .Add(1, dir) - .Add(2, sortLambda) + .Lambda(2) + .Param("row") + .ApplyPartial(sortLambdaArgs, sortLambdaRoot) + .With(0, "row") + .Seal() + .Seal() .Seal() .Build(); } TExprNode::TPtr BuildDistinctOn(TPositionHandle pos, TExprNode::TPtr list, const TExprNode::TPtr& distinctOn, - const TExprNode::TPtr& sort, TExprContext& ctx) { + const TExprNode::TPtr& sort, TExprContext& ctx, TOptimizeContext& optCtx) { // filter by RowNumber() == 1 TExprNode::TListType args; @@ -2280,7 +2373,7 @@ TExprNode::TPtr BuildDistinctOn(TPositionHandle pos, TExprNode::TPtr list, const auto sortNode = ctx.NewCallable(pos, "Void", {}); if (sort && sort->Tail().ChildrenSize() > 0) { - sortNode = BuildSortTraits(pos, sort->Tail(), list, ctx); + sortNode = BuildSortTraits(pos, sort->Tail(), list, nullptr, ctx, optCtx); } auto ret = ctx.Builder(pos) @@ -2378,39 +2471,34 @@ TExprNode::TPtr BuildLimit(TPositionHandle pos, const TExprNode::TPtr& limit, co .Build(); } -TExprNode::TPtr AddExtColumns(const TExprNode::TPtr& lambda, const TExprNode::TPtr& finalExtTypes, +TExprNode::TPtr AddExtColumns(const TExprNode::TPtr& projectionRoot, const TExprNode::TPtr& projectionArg, const TExprNode::TPtr& finalExtTypes, TExprNode::TListType& columns, ui32 subLinkId, TExprContext& ctx) { - return ctx.Builder(lambda->Pos()) - .Lambda() - .Param("row") - .Callable("FlattenMembers") - .List(0) - .Atom(0, "") - .Apply(1, lambda) - .With(0, "row") - .Seal() - .Seal() - .List(1) - .Atom(0, "_yql_join_sublink_" + ToString(subLinkId) + "_") - .Callable(1, "FilterMembers") - .Arg(0, "row") - .List(1) - .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder & { - ui32 i = 0; - for (const auto& x : finalExtTypes->Children()) { - auto alias = x->Head().Content(); - auto type = x->Tail().GetTypeAnn()->Cast<TTypeExprType>()->GetType()->Cast<TStructExprType>(); - for (const auto& item : type->GetItems()) { - auto withAlias = NTypeAnnImpl::MakeAliasedColumn(alias, item->GetName()); - parent.Atom(i++, withAlias); - columns.push_back(ctx.NewAtom(lambda->Pos(), TString("_yql_join_sublink_") + - ToString(subLinkId) + "_" + withAlias)); - } + return ctx.Builder(projectionRoot->Pos()) + .Callable("FlattenMembers") + .List(0) + .Atom(0, "") + .Add(1, projectionRoot) + .Seal() + .List(1) + .Atom(0, "_yql_join_sublink_" + ToString(subLinkId) + "_") + .Callable(1, "FilterMembers") + .Add(0, projectionArg) + .List(1) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder & { + ui32 i = 0; + for (const auto& x : finalExtTypes->Children()) { + auto alias = x->Head().Content(); + auto type = x->Tail().GetTypeAnn()->Cast<TTypeExprType>()->GetType()->Cast<TStructExprType>(); + for (const auto& item : type->GetItems()) { + auto withAlias = NTypeAnnImpl::MakeAliasedColumn(alias, item->GetName()); + parent.Atom(i++, withAlias); + columns.push_back(ctx.NewAtom(projectionRoot->Pos(), TString("_yql_join_sublink_") + + ToString(subLinkId) + "_" + withAlias)); } + } - return parent; - }) - .Seal() + return parent; + }) .Seal() .Seal() .Seal() @@ -2440,40 +2528,35 @@ void BuildExtraSortColumns(const TExprNode::TPtr& groupBy, } } -TExprNode::TPtr AddExtraSortColumns(const TExprNode::TPtr& lambda, const TExprNode::TPtr& groupBy, +TExprNode::TPtr AddExtraSortColumns(const TExprNode::TPtr& root, const TExprNode::TPtr& originalArg, const TExprNode::TPtr& groupBy, const TExprNode::TPtr& extraSortColumns, const TExprNode::TPtr& extraSortKeys, size_t aggIndexBegin, size_t aggIndexEnd, TExprContext& ctx) { - return ctx.Builder(lambda->Pos()) - .Lambda() - .Param("row") - .Callable("FlattenMembers") - .List(0) - .Atom(0, "") - .Apply(1, lambda) - .With(0, "row") - .Seal() - .Seal() - .List(1) - .Atom(0, "") - .Callable(1, "AsStruct") - .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder & { - TVector<TString> list; - BuildExtraSortColumns(groupBy, extraSortColumns, extraSortKeys, aggIndexBegin, aggIndexEnd, list); - for (ui32 i = 0; i < list.size(); ++i) { - TStringBuf from = list[i]; - from.SkipPrefix("_yql_extra_"); - parent.List(i) - .Atom(0, list[i]) - .Callable(1, "Member") - .Arg(0, "row") - .Atom(1, from) - .Seal() - .Seal(); - } + return ctx.Builder(root->Pos()) + .Callable("FlattenMembers") + .List(0) + .Atom(0, "") + .Add(1, root) + .Seal() + .List(1) + .Atom(0, "") + .Callable(1, "AsStruct") + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder & { + TVector<TString> list; + BuildExtraSortColumns(groupBy, extraSortColumns, extraSortKeys, aggIndexBegin, aggIndexEnd, list); + for (ui32 i = 0; i < list.size(); ++i) { + TStringBuf from = list[i]; + from.SkipPrefix("_yql_extra_"); + parent.List(i) + .Atom(0, list[i]) + .Callable(1, "Member") + .Add(0, originalArg) + .Atom(1, from) + .Seal() + .Seal(); + } - return parent; - }) - .Seal() + return parent; + }) .Seal() .Seal() .Seal() @@ -2819,6 +2902,8 @@ TExprNode::TPtr ExpandPgSelectImpl(const TExprNode::TPtr& node, TExprContext& ct } else { YQL_ENSURE(result); TExprNode::TPtr projectionLambda = BuildProjectionLambda(node->Pos(), result, subLinkId.Defined(), ctx); + TExprNode::TPtr projectionArg = projectionLambda->Head().HeadPtr(); + TExprNode::TPtr projectionRoot = projectionLambda->TailPtr(); TVector<TString> inputAliases; TExprNode::TListType cleanedInputs; TWindowsCtx winCtx; @@ -2869,6 +2954,22 @@ TExprNode::TPtr ExpandPgSelectImpl(const TExprNode::TPtr& node, TExprContext& ct GatherAggregationsFromLambda(having->Tail().TailPtr(), aggs, aggId, false); } + if (!winCtx.Window2funcs.empty()) { + YQL_ENSURE(window); + for (const auto& x : winCtx.Window2funcs) { + auto winDef = window->Tail().Child(x.first); + for (ui32 i = 0; i < winDef->Child(2)->ChildrenSize(); ++i) { + const auto& group = winDef->Child(2)->Child(i); + GatherAggregationsFromLambda(group->TailPtr(), aggs, aggId, false); + } + + for (ui32 i = 0; i < winDef->Child(3)->ChildrenSize(); ++i) { + const auto& sort = winDef->Child(3)->Child(i); + GatherAggregationsFromLambda(sort->ChildPtr(1), aggs, aggId, false); + } + } + } + TExprNode::TPtr sortLambda; auto aggsSizeBeforeSort = aggs.size(); if (sort) { @@ -2876,16 +2977,8 @@ TExprNode::TPtr ExpandPgSelectImpl(const TExprNode::TPtr& node, TExprContext& ct GatherAggregationsFromLambda(sortLambda, aggs, aggId, false); } - auto projectionSubLinks = GatherSubLinks(projectionLambda); - if (!projectionSubLinks.empty()) { - std::tie(list, projectionLambda) = RewriteSubLinks(projectionLambda->Pos(), list, projectionLambda, - projectionSubLinks, inputAliases, cleanedInputs, &aggId, ctx, optCtx); - } else { - RewriteAggs(projectionLambda, aggId, ctx, optCtx, false); - } - if (groupBy) { - list = BuildGroup(node->Pos(), list, aggs, groupBy, finalExtTypes, ctx); + list = BuildGroup(node->Pos(), list, aggs, groupBy, finalExtTypes, ctx, optCtx); } if (having) { @@ -2893,23 +2986,35 @@ TExprNode::TPtr ExpandPgSelectImpl(const TExprNode::TPtr& node, TExprContext& ct } if (!winCtx.Funcs.empty()) { - list = BuildWindows(node->Pos(), list, window, winCtx, projectionLambda, ctx, optCtx); + list = BuildWindows(node->Pos(), list, window, winCtx, projectionRoot, projectionArg, aggId, ctx, optCtx); + } + + RewriteAggsPartial(projectionRoot, projectionArg, aggId, ctx, optCtx, false); + auto projectionSubLinks = GatherSubLinks(projectionRoot); + if (!projectionSubLinks.empty()) { + std::tie(projectionRoot, list) = RewriteSubLinksPartial(projectionLambda->Pos(), list, projectionRoot, projectionArg, + projectionSubLinks, inputAliases, cleanedInputs, ctx, optCtx); } if (finalExtTypes) { - projectionLambda = AddExtColumns(projectionLambda, finalExtTypes->TailPtr(), columnsItems, *subLinkId, ctx); + projectionRoot = AddExtColumns(projectionRoot, projectionArg, finalExtTypes->TailPtr(), columnsItems, *subLinkId, ctx); } bool hasExtraSortColumns = (extraSortColumns || extraSortKeys || (aggsSizeBeforeSort < aggs.size())); if (hasExtraSortColumns) { YQL_ENSURE(!distinctAll && !distinctOn); - projectionLambda = AddExtraSortColumns(projectionLambda, groupBy, extraSortColumns, extraSortKeys, aggsSizeBeforeSort, aggs.size(), ctx); + projectionRoot = AddExtraSortColumns(projectionRoot, projectionArg, groupBy, extraSortColumns, extraSortKeys, aggsSizeBeforeSort, aggs.size(), ctx); } list = ctx.Builder(node->Pos()) .Callable("OrderedMap") .Add(0, list) - .Add(1, projectionLambda) + .Lambda(1) + .Param("row") + .ApplyPartial(projectionLambda->HeadPtr(), projectionRoot) + .With(0, "row") + .Seal() + .Seal() .Seal() .Build(); @@ -2918,20 +3023,21 @@ TExprNode::TPtr ExpandPgSelectImpl(const TExprNode::TPtr& node, TExprContext& ct list = ctx.NewCallable(node->Pos(), "SqlAggregateAll", { list }); } else if (distinctOn) { YQL_ENSURE(!extraSortColumns); - list = BuildDistinctOn(node->Pos(), list, distinctOn->TailPtr(), sort, ctx); + list = BuildDistinctOn(node->Pos(), list, distinctOn->TailPtr(), sort, ctx, optCtx); } TVector<TString> sublinkColumns; if (sort) { - auto sortSubLinks = GatherSubLinks(sortLambda); + auto sortLambdaRoot = sortLambda->TailPtr(); + RewriteAggsPartial(sortLambdaRoot, sortLambda->Head().HeadPtr(), aggId, ctx, optCtx, false); + + auto sortSubLinks = GatherSubLinks(sortLambdaRoot); if (!sortSubLinks.empty()) { - std::tie(list, sortLambda) = RewriteSubLinks(sortLambda->Pos(), list, sortLambda, - sortSubLinks, inputAliases, cleanedInputs, &aggId, ctx, optCtx, "_yql_extra_", &sublinkColumns); - } else { - RewriteAggs(sortLambda, aggId, ctx, optCtx, false); + std::tie(sortLambdaRoot, list) = RewriteSubLinksPartial(sortLambda->Pos(), list, sortLambdaRoot, + sortLambda->Head().HeadPtr(), sortSubLinks, inputAliases, cleanedInputs, ctx, optCtx, "_yql_extra_", &sublinkColumns); } - list = BuildSort(node->Pos(), sort, list, sortLambda, ctx); + list = BuildSort(node->Pos(), sort, list, sortLambdaRoot, sortLambda->HeadPtr(), ctx); } if (hasExtraSortColumns) { @@ -2987,7 +3093,7 @@ TExprNode::TPtr ExpandPgSelectImpl(const TExprNode::TPtr& node, TExprContext& ct auto finalSort = GetSetting(node->Head(), "sort"); if (finalSort && finalSort->Tail().ChildrenSize() > 0) { auto finalSortLambda = BuildSortLambda(node->Pos(), finalSort, ctx); - list = BuildSort(node->Pos(), finalSort, list, finalSortLambda, ctx); + list = BuildSort(node->Pos(), finalSort, list, finalSortLambda->TailPtr(), finalSortLambda->HeadPtr(), ctx); } auto limit = GetSetting(node->Head(), "limit"); diff --git a/ydb/library/yql/core/type_ann/type_ann_pg.cpp b/ydb/library/yql/core/type_ann/type_ann_pg.cpp index b1879b6e300..314eddab845 100644 --- a/ydb/library/yql/core/type_ann/type_ann_pg.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_pg.cpp @@ -1984,8 +1984,180 @@ void ScanAggregations(const TExprNode::TPtr& root, bool& hasAggregations) { } } +struct TGroupExpr { + TExprNode::TPtr OriginalRoot; + ui64 Hash; + TExprNode::TPtr TypeNode; +}; + +ui64 CalculateExprHash(const TExprNode& root, TNodeMap<ui64>& visited) { + auto it = visited.find(&root); + if (it != visited.end()) { + return it->second; + } + + ui64 hash = 0; + switch (root.Type()) { + case TExprNode::EType::Callable: + hash = CseeHash(root.Content().Size(), hash); + hash = CseeHash(root.Content().Data(), root.Content().Size(), hash); + [[fallthrough]]; + case TExprNode::EType::List: + hash = CseeHash(root.ChildrenSize(), hash); + for (ui32 i = 0; i < root.ChildrenSize(); ++i) { + hash = CalculateExprHash(*root.Child(i), visited); + } + + break; + case TExprNode::EType::Atom: + hash = CseeHash(root.Content().Size(), hash); + hash = CseeHash(root.Content().Data(), root.Content().Size(), hash); + hash = CseeHash(root.GetFlagsToCompare(), hash); + break; + default: + YQL_ENSURE(false, "Unexpected node type"); + } + + visited.emplace(&root, hash); + return hash; +} + +bool ExprNodesEquals(const TExprNode& left, const TExprNode& right, TNodeSet& visited) { + if (!visited.emplace(&left).second) { + return true; + } + + if (left.Type() != right.Type()) { + return false; + } + + switch (left.Type()) { + case TExprNode::EType::Callable: + if (left.Content() != right.Content()) { + return false; + } + + [[fallthrough]]; + case TExprNode::EType::List: + if (left.ChildrenSize() != right.ChildrenSize()) { + return false; + } + + for (ui32 i = 0; i < left.ChildrenSize(); ++i) { + if (!ExprNodesEquals(*left.Child(i), *right.Child(i), visited)) { + return false; + } + } + + return true; + case TExprNode::EType::Atom: + return left.Content() == right.Content() && left.GetFlagsToCompare() == right.GetFlagsToCompare(); + case TExprNode::EType::Argument: + return left.GetArgIndex() == right.GetArgIndex(); + default: + YQL_ENSURE(false, "Unexpected node type"); + } +} + +bool ScanExprForMatchedGroup(const TExprNode::TPtr& row, const TExprNode& root, const TVector<TGroupExpr>& exprs, + TNodeOnNodeOwnedMap& replaces, TNodeMap<ui64>& hashVisited, TNodeMap<bool>& nodeVisited, TExprContext& ctx) { + auto it = nodeVisited.find(&root); + if (it != nodeVisited.end()) { + return it->second; + } + + if (root.IsCallable("PgSubLink")) { + const auto& testRowLambda = *root.Child(3); + if (!testRowLambda.IsCallable("Void")) { + hashVisited[testRowLambda.Head().Child(0)] = 0; // original row + hashVisited[testRowLambda.Head().Child(1)] = 1; // sublink value + ScanExprForMatchedGroup(testRowLambda.Head().ChildPtr(0), testRowLambda.Tail(), + exprs, replaces, hashVisited, nodeVisited, ctx); + } + + nodeVisited[&root] = false; + return false; + } + + bool hasChanges = false; + for (const auto& child : root.Children()) { + if (!ScanExprForMatchedGroup(row, *child, exprs, replaces, hashVisited, nodeVisited, ctx)) { + hasChanges = true; + } + } + + if (hasChanges) { + nodeVisited[&root] = false; + return false; + } + + ui64 hash = CalculateExprHash(root, hashVisited); + for (ui32 i = 0; i < exprs.size(); ++i) { + if (exprs[i].Hash != hash) { + continue; + } + + TNodeSet equalsVisited; + if (!ExprNodesEquals(*exprs[i].OriginalRoot, root, equalsVisited)) { + continue; + } + + replaces[&root] = ctx.Builder(root.Pos()) + .Callable("PgGroupRef") + .Add(0, row) + .Add(1, exprs[i].TypeNode) + .Atom(2, ToString(i)) + .Seal() + .Build(); + + nodeVisited[&root] = false; + return false; + } + + nodeVisited[&root] = true; + return true; +} + +TExprNode::TPtr ReplaceGroupByExpr(const TExprNode::TPtr& root, const TExprNode& groups, TExprContext& ctx) { + if (!groups.ChildrenSize()) { + return root; + } + + // calculate hashes + TVector<TGroupExpr> exprs; + TExprNode::TListType typeNodes; + for (ui32 index = 0; index < groups.ChildrenSize(); ++index) { + const auto& g = *groups.Child(index); + const auto& lambda = g.Tail(); + TNodeMap<ui64> visited; + visited[&lambda.Head().Head()] = 0; + exprs.push_back({ + lambda.TailPtr(), + CalculateExprHash(lambda.Tail(), visited), + ExpandType(g.Pos(), *lambda.GetTypeAnn(), ctx) + }); + } + + TNodeOnNodeOwnedMap replaces; + TNodeMap<ui64> hashVisited; + TNodeMap<bool> nodeVisited; + hashVisited[&root->Head().Head()] = 0; + ScanExprForMatchedGroup(root->Head().HeadPtr(), root->Tail(), exprs, replaces, hashVisited, nodeVisited, ctx); + auto ret = root; + if (replaces.empty()) { + return ret; + } + + TOptimizeExprSettings settings(nullptr); + settings.VisitTuples = true; + auto status = RemapExpr(ret, ret, replaces, ctx, settings); + YQL_ENSURE(status != IGraphTransformer::TStatus::Error); + return ret; +} + bool ValidateGroups(TInputs& inputs, const THashSet<TString>& possibleAliases, - const TExprNode& data, TExtContext& ctx, TExprNode::TListType& newGroups, bool& hasNewGroups, bool scanColumnsOnly) { + const TExprNode& data, TExtContext& ctx, TExprNode::TListType& newGroups, bool& hasNewGroups, bool scanColumnsOnly, + bool allowAggregates, const TExprNode::TPtr& groupBy) { newGroups.clear(); hasNewGroups = false; bool hasColumnRef = false; @@ -2012,7 +2184,7 @@ bool ValidateGroups(TInputs& inputs, const THashSet<TString>& possibleAliases, bool hasNestedAggregations = false; ScanAggregations(group->Tail().TailPtr(), hasNestedAggregations); - if (hasNestedAggregations) { + if (!allowAggregates && hasNestedAggregations) { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(group->Pos()), "Nested aggregations aren't allowed")); return false; } @@ -2042,9 +2214,19 @@ bool ValidateGroups(TInputs& inputs, const THashSet<TString>& possibleAliases, auto newGroup = ctx.Expr.NewCallable(group->Pos(), "PgGroup", std::move(newChildren)); newGroups.push_back(newGroup); hasNewGroups = true; - } else { - newGroups.push_back(data.ChildPtr(index)); + continue; } + + if (groupBy) { + auto ret = ReplaceGroupByExpr(group->TailPtr(), groupBy->Tail(), ctx.Expr); + if (ret != group->TailPtr()) { + newGroups.push_back(ctx.Expr.ChangeChild(*group, 1, std::move(ret))); + hasNewGroups = true; + continue; + } + } + + newGroups.push_back(data.ChildPtr(index)); } return true; @@ -2092,7 +2274,7 @@ TMap<TString, ui32> ExtractExternalColumns(const TExprNode& select) { } bool ValidateSort(TInputs& inputs, TInputs& subLinkInputs, const THashSet<TString>& possibleAliases, - const TExprNode& data, TExtContext& ctx, bool& hasNewSort, TExprNode::TListType& newSorts, bool scanColumnsOnly) { + const TExprNode& data, TExtContext& ctx, bool& hasNewSort, TExprNode::TListType& newSorts, bool scanColumnsOnly, const TExprNode::TPtr& groupBy) { newSorts.clear(); for (ui32 index = 0; index < data.ChildrenSize(); ++index) { auto oneSort = data.Child(index); @@ -2163,184 +2345,24 @@ bool ValidateSort(TInputs& inputs, TInputs& subLinkInputs, const THashSet<TStrin auto newSort = ctx.Expr.ChangeChildren(*oneSort, std::move(newChildren)); newSorts.push_back(newSort); hasNewSort = true; - } else { - newSorts.push_back(data.ChildPtr(index)); - } - } - - return true; -} - -ui64 CalculateExprHash(const TExprNode& root, TNodeMap<ui64>& visited) { - auto it = visited.find(&root); - if (it != visited.end()) { - return it->second; - } - - ui64 hash = 0; - switch (root.Type()) { - case TExprNode::EType::Callable: - hash = CseeHash(root.Content().Size(), hash); - hash = CseeHash(root.Content().Data(), root.Content().Size(), hash); - [[fallthrough]]; - case TExprNode::EType::List: - hash = CseeHash(root.ChildrenSize(), hash); - for (ui32 i = 0; i < root.ChildrenSize(); ++i) { - hash = CalculateExprHash(*root.Child(i), visited); - } - - break; - case TExprNode::EType::Atom: - hash = CseeHash(root.Content().Size(), hash); - hash = CseeHash(root.Content().Data(), root.Content().Size(), hash); - hash = CseeHash(root.GetFlagsToCompare(), hash); - break; - default: - YQL_ENSURE(false, "Unexpected node type"); - } - - visited.emplace(&root, hash); - return hash; -} - -bool ExprNodesEquals(const TExprNode& left, const TExprNode& right, TNodeSet& visited) { - if (!visited.emplace(&left).second) { - return true; - } - - if (left.Type() != right.Type()) { - return false; - } - - switch (left.Type()) { - case TExprNode::EType::Callable: - if (left.Content() != right.Content()) { - return false; - } - - [[fallthrough]]; - case TExprNode::EType::List: - if (left.ChildrenSize() != right.ChildrenSize()) { - return false; - } - - for (ui32 i = 0; i < left.ChildrenSize(); ++i) { - if (!ExprNodesEquals(*left.Child(i), *right.Child(i), visited)) { - return false; - } - } - - return true; - case TExprNode::EType::Atom: - return left.Content() == right.Content() && left.GetFlagsToCompare() == right.GetFlagsToCompare(); - case TExprNode::EType::Argument: - return left.GetArgIndex() == right.GetArgIndex(); - default: - YQL_ENSURE(false, "Unexpected node type"); - } -} - -struct TGroupExpr { - TExprNode::TPtr OriginalRoot; - ui64 Hash; - TExprNode::TPtr TypeNode; -}; - -bool ScanExprForMatchedGroup(const TExprNode::TPtr& row, const TExprNode& root, const TVector<TGroupExpr>& exprs, - TNodeOnNodeOwnedMap& replaces, TNodeMap<ui64>& hashVisited, TNodeMap<bool>& nodeVisited, TExprContext& ctx) { - auto it = nodeVisited.find(&root); - if (it != nodeVisited.end()) { - return it->second; - } - - if (root.IsCallable("PgSubLink")) { - const auto& testRowLambda = *root.Child(3); - if (!testRowLambda.IsCallable("Void")) { - hashVisited[testRowLambda.Head().Child(0)] = 0; // original row - hashVisited[testRowLambda.Head().Child(1)] = 1; // sublink value - ScanExprForMatchedGroup(testRowLambda.Head().ChildPtr(0), testRowLambda.Tail(), - exprs, replaces, hashVisited, nodeVisited, ctx); - } - - nodeVisited[&root] = false; - return false; - } - - bool hasChanges = false; - for (const auto& child : root.Children()) { - if (!ScanExprForMatchedGroup(row, *child, exprs, replaces, hashVisited, nodeVisited, ctx)) { - hasChanges = true; - } - } - - if (hasChanges) { - nodeVisited[&root] = false; - return false; - } - - ui64 hash = CalculateExprHash(root, hashVisited); - for (ui32 i = 0; i < exprs.size(); ++i) { - if (exprs[i].Hash != hash) { continue; } - TNodeSet equalsVisited; - if (!ExprNodesEquals(*exprs[i].OriginalRoot, root, equalsVisited)) { - continue; + if (groupBy) { + auto ret = ReplaceGroupByExpr(newLambda, groupBy->Tail(), ctx.Expr); + if (ret != newLambda) { + newSorts.push_back(ctx.Expr.ChangeChild(*oneSort, 1, std::move(ret))); + hasNewSort = true; + continue; + } } - replaces[&root] = ctx.Builder(root.Pos()) - .Callable("PgGroupRef") - .Add(0, row) - .Add(1, exprs[i].TypeNode) - .Atom(2, ToString(i)) - .Seal() - .Build(); - nodeVisited[&root] = false; - return false; + newSorts.push_back(data.ChildPtr(index)); } - nodeVisited[&root] = true; return true; } -TExprNode::TPtr ReplaceGroupByExpr(const TExprNode::TPtr& root, const TExprNode& groups, TExprContext& ctx) { - if (!groups.ChildrenSize()) { - return root; - } - - // calculate hashes - TVector<TGroupExpr> exprs; - TExprNode::TListType typeNodes; - for (ui32 index = 0; index < groups.ChildrenSize(); ++index) { - const auto& g = *groups.Child(index); - const auto& lambda = g.Tail(); - TNodeMap<ui64> visited; - visited[&lambda.Head().Head()] = 0; - exprs.push_back({ - lambda.TailPtr(), - CalculateExprHash(lambda.Tail(), visited), - ExpandType(g.Pos(), *lambda.GetTypeAnn(), ctx) - }); - } - - TNodeOnNodeOwnedMap replaces; - TNodeMap<ui64> hashVisited; - TNodeMap<bool> nodeVisited; - hashVisited[&root->Head().Head()] = 0; - ScanExprForMatchedGroup(root->Head().HeadPtr(), root->Tail(), exprs, replaces, hashVisited, nodeVisited, ctx); - auto ret = root; - if (replaces.empty()) { - return ret; - } - - TOptimizeExprSettings settings(nullptr); - settings.VisitTuples = true; - auto status = RemapExpr(ret, ret, replaces, ctx, settings); - YQL_ENSURE(status != IGraphTransformer::TStatus::Error); - return ret; -} - bool GatherExtraSortColumns(const TExprNode& data, const TInputs& inputs, TExprNode::TPtr& extraInputColumns, TExprNode::TPtr& extraKeys, TExprContext& ctx) { ui32 inputsCount = inputs.size() - 1; THashMap<ui32, TSet<TString>> columns; @@ -3180,7 +3202,7 @@ IGraphTransformer::TStatus PgSetItemWrapper(const TExprNode::TPtr& input, TExprN TExprNode::TListType newGroups; bool hasNewGroups = false; - if (!ValidateGroups(joinInputs, possibleAliases, data, ctx, newGroups, hasNewGroups, scanColumnsOnly)) { + if (!ValidateGroups(joinInputs, possibleAliases, data, ctx, newGroups, hasNewGroups, scanColumnsOnly, false, nullptr)) { return IGraphTransformer::TStatus::Error; } @@ -3223,28 +3245,19 @@ IGraphTransformer::TStatus PgSetItemWrapper(const TExprNode::TPtr& input, TExprN auto partitions = x->Child(2); auto sort = x->Child(3); - bool needRebuildPartition = false; - for (const auto& p : partitions->Children()) { - if (p->Child(0)->IsCallable("Void")) { - needRebuildPartition = true; - break; - } - } auto newChildren = x->ChildrenList(); - if (needRebuildPartition) { - TExprNode::TListType newGroups; - bool hasNewGroups = false; - if (!ValidateGroups(joinInputs, possibleAliases, *partitions, ctx, newGroups, hasNewGroups, scanColumnsOnly)) { - return IGraphTransformer::TStatus::Error; - } - - newChildren[2] = ctx.Expr.NewList(x->Pos(), std::move(newGroups)); + TExprNode::TListType newGroups; + bool hasNewGroups = false; + if (!ValidateGroups(joinInputs, possibleAliases, *partitions, ctx, newGroups, hasNewGroups, scanColumnsOnly, true, groupBy)) { + return IGraphTransformer::TStatus::Error; } + newChildren[2] = ctx.Expr.NewList(x->Pos(), std::move(newGroups)); + bool hasNewSort = false; TExprNode::TListType newSorts; - if (!ValidateSort(joinInputs, joinInputs, possibleAliases, *sort, ctx, hasNewSort, newSorts, scanColumnsOnly)) { + if (!ValidateSort(joinInputs, joinInputs, possibleAliases, *sort, ctx, hasNewSort, newSorts, scanColumnsOnly, groupBy)) { return IGraphTransformer::TStatus::Error; } @@ -3252,7 +3265,7 @@ IGraphTransformer::TStatus PgSetItemWrapper(const TExprNode::TPtr& input, TExprN newChildren[3] = ctx.Expr.NewList(x->Pos(), std::move(newSorts)); } - if (needRebuildPartition || hasNewSort) { + if (hasNewGroups || hasNewSort) { hasChanges = true; newWindow.push_back(ctx.Expr.ChangeChildren(*x, std::move(newChildren))); } else { @@ -3299,7 +3312,7 @@ IGraphTransformer::TStatus PgSetItemWrapper(const TExprNode::TPtr& input, TExprN TExprNode::TListType newGroups; TInputs projectionInputs; projectionInputs.push_back(TInput{ "", outputRowType, Nothing(), TInput::Projection, {} }); - if (!ValidateGroups(projectionInputs, {}, data, ctx, newGroups, hasNewGroups, scanColumnsOnly)) { + if (!ValidateGroups(projectionInputs, {}, data, ctx, newGroups, hasNewGroups, scanColumnsOnly, false, nullptr)) { return IGraphTransformer::TStatus::Error; } @@ -3338,7 +3351,7 @@ IGraphTransformer::TStatus PgSetItemWrapper(const TExprNode::TPtr& input, TExprN bool hasNewSort = false; TExprNode::TListType newSortTupleItems; // no effective types yet, scan lambda bodies - if (!ValidateSort(projectionInputs, joinInputs, possibleAliases, data, ctx, hasNewSort, newSortTupleItems, scanColumnsOnly)) { + if (!ValidateSort(projectionInputs, joinInputs, possibleAliases, data, ctx, hasNewSort, newSortTupleItems, scanColumnsOnly, groupBy)) { return IGraphTransformer::TStatus::Error; } @@ -3351,29 +3364,6 @@ IGraphTransformer::TStatus PgSetItemWrapper(const TExprNode::TPtr& input, TExprN } if (!scanColumnsOnly) { - if (groupBy) { - TExprNode::TListType newSortItems; - bool hasChanges = false; - for (ui32 index = 0; index < data.ChildrenSize(); ++index) { - const auto& column = *data.Child(index); - auto ret = ReplaceGroupByExpr(column.ChildPtr(1), groupBy->Tail(), ctx.Expr); - if (ret != column.ChildPtr(1)) { - hasChanges = true; - newSortItems.push_back(ctx.Expr.ChangeChild(column, 1, std::move(ret))); - } - else { - newSortItems.push_back(data.ChildPtr(index)); - } - } - - if (hasChanges) { - auto newSort = ctx.Expr.NewList(input->Pos(), std::move(newSortItems)); - auto newSettings = ReplaceSetting(options, {}, "sort", newSort, ctx.Expr); - output = ctx.Expr.ChangeChild(*input, 0, std::move(newSettings)); - return IGraphTransformer::TStatus::Repeat; - } - } - if (!GetSetting(options, "final_extra_sort_columns") && !GetSetting(options, "final_extra_sort_keys")) { TExprNode::TPtr extraColumns; TExprNode::TPtr extraKeys; @@ -3727,7 +3717,7 @@ IGraphTransformer::TStatus PgSelectWrapper(const TExprNode::TPtr& input, TExprNo // no effective types yet, scan lambda bodies bool hasNewSort = false; - if (!ValidateSort(projectionInputs, projectionInputs, {}, data, ctx, hasNewSort, newSortTupleItems, false)) { + if (!ValidateSort(projectionInputs, projectionInputs, {}, data, ctx, hasNewSort, newSortTupleItems, false, nullptr)) { return IGraphTransformer::TStatus::Error; } diff --git a/ydb/library/yql/sql/pg/pg_sql.cpp b/ydb/library/yql/sql/pg/pg_sql.cpp index 027c4d19ecb..23c55f3abc4 100644 --- a/ydb/library/yql/sql/pg/pg_sql.cpp +++ b/ydb/library/yql/sql/pg/pg_sql.cpp @@ -1837,7 +1837,7 @@ public: return nullptr; } - auto sort = ParseSortBy(CAST_NODE_EXT(PG_SortBy, T_SortBy, node), false); + auto sort = ParseSortBy(CAST_NODE_EXT(PG_SortBy, T_SortBy, node), true); if (!sort) { return nullptr; } @@ -1851,6 +1851,7 @@ public: auto node = ListNodeNth(value->partitionClause, i); TExprSettings settings; settings.AllowColumns = true; + settings.AllowAggregates = true; settings.Scope = "PARTITITON BY"; auto expr = ParseExpr(node, settings); if (!expr) { |
