aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <vvvv@yandex-team.ru>2022-04-07 06:40:48 +0300
committervvvv <vvvv@yandex-team.ru>2022-04-07 06:40:48 +0300
commitfb12d151cadf2e1ad39b0b6a7d8baa5d31ab663a (patch)
tree0d008db68e0d69dbb89fc3d9532508b3bd9be5a0
parent7c618ed9486574191f73591208f9d4388c4f1c94 (diff)
downloadydb-fb12d151cadf2e1ad39b0b6a7d8baa5d31ab663a.tar.gz
YQL-13710 pgtypes in aggregations over window
ref:b398653778664bb220808d0e2f22b5cef56f0b12
-rw-r--r--ydb/library/yql/core/common_opt/yql_co_pgselect.cpp97
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_core.cpp3
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_pg.cpp43
3 files changed, 87 insertions, 56 deletions
diff --git a/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp b/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp
index 6161f56d04..eee4c080b0 100644
--- a/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp
+++ b/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp
@@ -476,6 +476,30 @@ std::tuple<TAggs, TNodeMap<ui32>> GatherAggregations(const TExprNode::TPtr& proj
return { aggs, aggId };
}
+TExprNode::TPtr BuildAggregationTraits(TPositionHandle pos, bool onWindow,
+ const std::pair<TExprNode::TPtr, TExprNode::TPtr>& agg,
+ const TExprNode::TPtr& listTypeNode, TExprContext& ctx) {
+ auto arg = ctx.NewArgument(pos, "row");
+ auto arguments = ctx.NewArguments(pos, { arg });
+ auto func = agg.first->Head().Content();
+ TExprNode::TListType aggFuncArgs;
+ for (ui32 j = onWindow ? 2 : 1; j < agg.first->ChildrenSize(); ++j) {
+ aggFuncArgs.push_back(ctx.ReplaceNode(agg.first->ChildPtr(j), *agg.second, arg));
+ }
+
+ auto extractor = ctx.NewLambda(pos, std::move(arguments), std::move(aggFuncArgs));
+
+ return ctx.Builder(pos)
+ .Callable(onWindow ? "PgWindowTraits" : "PgAggregationTraits")
+ .Atom(0, func)
+ .Callable(1, "ListItemType")
+ .Add(0, listTypeNode)
+ .Seal()
+ .Add(2, extractor)
+ .Seal()
+ .Build();
+}
+
TExprNode::TPtr BuildGroupByAndHaving(TPositionHandle pos, const TExprNode::TPtr& list, const TAggs& aggs, const TNodeMap<ui32>& aggId,
const TExprNode::TPtr& groupBy, const TExprNode::TPtr& having, TExprNode::TPtr& projectionLambda, TExprContext& ctx, TOptimizeContext& optCtx) {
auto listTypeNode = ctx.Builder(pos)
@@ -490,28 +514,11 @@ TExprNode::TPtr BuildGroupByAndHaving(TPositionHandle pos, const TExprNode::TPtr
TNodeOnNodeOwnedMap deepClones;
TExprNode::TListType payloadItems;
for (ui32 i = 0; i < aggs.size(); ++i) {
- auto func = aggs[i].first->Head().Content();
TExprNode::TPtr traits;
if (optCtx.Types->PgTypes) {
- auto arg = ctx.NewArgument(pos, "row");
- auto arguments = ctx.NewArguments(pos, { arg });
- TExprNode::TListType aggFuncArgs;
- for (ui32 j = 1; j < aggs[i].first->ChildrenSize(); ++j) {
- aggFuncArgs.push_back(ctx.ReplaceNode(aggs[i].first->ChildPtr(j), *aggs[i].second, arg));
- }
-
- auto extractor = ctx.NewLambda(pos, std::move(arguments), std::move(aggFuncArgs));
-
- traits = ctx.Builder(pos)
- .Callable("PgAggregationTraits")
- .Atom(0, func)
- .Callable(1, "ListItemType")
- .Add(0, listTypeNode)
- .Seal()
- .Add(2, extractor)
- .Seal()
- .Build();
+ traits = BuildAggregationTraits(pos, false, aggs[i], listTypeNode, ctx);
} else {
+ auto func = aggs[i].first->Head().Content();
const auto& exports = exportsPtr->Symbols();
if (func == "count" && aggs[i].first->ChildrenSize() == 1) {
func = "count_all";
@@ -806,32 +813,36 @@ TExprNode::TPtr BuildWindows(TPositionHandle pos, const TExprNode::TPtr& list, c
bool isAgg = p.first->IsCallable("PgAggWindowCall");
TExprNode::TPtr value;
if (isAgg) {
- const auto& exports = exportsPtr->Symbols();
- if (name == "count" && p.first->ChildrenSize() == 2) {
- name = "count_all";
- }
+ if (optCtx.Types->PgTypes) {
+ value = BuildAggregationTraits(pos, true, p, listTypeNode, ctx);
+ } else {
+ const auto& exports = exportsPtr->Symbols();
+ if (name == "count" && p.first->ChildrenSize() == 2) {
+ name = "count_all";
+ }
- TString factory = TString(name) + "_traits_factory";
- const auto ex = exports.find(factory);
- YQL_ENSURE(exports.cend() != ex);
- auto lambda = ctx.DeepCopy(*ex->second, exportsPtr->ExprCtx(), deepClones, true, false);
- auto arg = ctx.NewArgument(pos, "row");
- auto arguments = ctx.NewArguments(pos, { arg });
- auto extractor = ctx.NewLambda(pos, std::move(arguments),
- ctx.ReplaceNode(p.first->TailPtr(), *p.second, arg));
-
- auto traits = ctx.ReplaceNodes(lambda->TailPtr(), {
- {lambda->Head().Child(0), listTypeNode},
- {lambda->Head().Child(1), extractor}
- });
-
- ctx.Step.Repeat(TExprStep::ExpandApplyForLambdas);
- auto status = ExpandApply(traits, traits, ctx);
- if (status == IGraphTransformer::TStatus::Error) {
- return {};
- }
+ TString factory = TString(name) + "_traits_factory";
+ const auto ex = exports.find(factory);
+ YQL_ENSURE(exports.cend() != ex);
+ auto lambda = ctx.DeepCopy(*ex->second, exportsPtr->ExprCtx(), deepClones, true, false);
+ auto arg = ctx.NewArgument(pos, "row");
+ auto arguments = ctx.NewArguments(pos, { arg });
+ auto extractor = ctx.NewLambda(pos, std::move(arguments),
+ ctx.ReplaceNode(p.first->TailPtr(), *p.second, arg));
+
+ auto traits = ctx.ReplaceNodes(lambda->TailPtr(), {
+ {lambda->Head().Child(0), listTypeNode},
+ {lambda->Head().Child(1), extractor}
+ });
- value = traits;
+ ctx.Step.Repeat(TExprStep::ExpandApplyForLambdas);
+ auto status = ExpandApply(traits, traits, ctx);
+ if (status == IGraphTransformer::TStatus::Error) {
+ return {};
+ }
+
+ value = traits;
+ }
} else {
if (name == "row_number") {
value = ctx.Builder(pos)
diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp
index 5d4000e8b3..fd50cb8301 100644
--- a/ydb/library/yql/core/type_ann/type_ann_core.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp
@@ -3216,7 +3216,7 @@ namespace NTypeAnnImpl {
if (input->Head().GetTypeAnn()->GetKind() == ETypeAnnotationKind::Stream) {
output = ctx.Expr.Builder(input->Pos())
.Callable("FromFlow")
- .Callable("WithContext")
+ .Callable(0, "WithContext")
.Callable(0, "ToFlow")
.Add(0, input->HeadPtr())
.Seal()
@@ -11256,6 +11256,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
Functions["PgType"] = &PgTypeWrapper;
Functions["PgCast"] = &PgCastWrapper;
Functions["PgAggregationTraits"] = &PgAggregationTraitsWrapper;
+ Functions["PgWindowTraits"] = &PgAggregationTraitsWrapper;
Functions["PgInternal0"] = &PgInternal0Wrapper;
Functions["AutoDemuxList"] = &AutoDemuxListWrapper;
Functions["AggrCountInit"] = &AggrCountInitWrapper;
diff --git a/ydb/library/yql/core/type_ann/type_ann_pg.cpp b/ydb/library/yql/core/type_ann/type_ann_pg.cpp
index c0313f9bfe..94c01c68ca 100644
--- a/ydb/library/yql/core/type_ann/type_ann_pg.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_pg.cpp
@@ -964,6 +964,7 @@ IGraphTransformer::TStatus PgCastWrapper(const TExprNode::TPtr& input, TExprNode
}
IGraphTransformer::TStatus PgAggregationTraitsWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
+ const bool onWindow = input->IsCallable("PgWindowTraits");
if (!EnsureArgsCount(*input, 3, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
@@ -1248,18 +1249,36 @@ IGraphTransformer::TStatus PgAggregationTraitsWrapper(const TExprNode::TPtr& inp
}
auto typeNode = ExpandType(input->Pos(), *itemType, ctx.Expr);
- output = ctx.Expr.Builder(input->Pos())
- .Callable("AggregationTraits")
- .Add(0, typeNode)
- .Add(1, initLambda)
- .Add(2, updateLambda)
- .Add(3, saveLambda)
- .Add(4, loadLambda)
- .Add(5, mergeLambda)
- .Add(6, finishLambda)
- .Add(7, defaultValue)
- .Seal()
- .Build();
+ if (onWindow) {
+ output = ctx.Expr.Builder(input->Pos())
+ .Callable("WindowTraits")
+ .Add(0, typeNode)
+ .Add(1, initLambda)
+ .Add(2, updateLambda)
+ .Lambda(3)
+ .Param("value")
+ .Param("state")
+ .Callable("Void")
+ .Seal()
+ .Seal()
+ .Add(4, finishLambda)
+ .Add(5, defaultValue)
+ .Seal()
+ .Build();
+ } else {
+ output = ctx.Expr.Builder(input->Pos())
+ .Callable("AggregationTraits")
+ .Add(0, typeNode)
+ .Add(1, initLambda)
+ .Add(2, updateLambda)
+ .Add(3, saveLambda)
+ .Add(4, loadLambda)
+ .Add(5, mergeLambda)
+ .Add(6, finishLambda)
+ .Add(7, defaultValue)
+ .Seal()
+ .Build();
+ }
return IGraphTransformer::TStatus::Repeat;
}