aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author vvvv <vvvv@ydb.tech> 2023-05-22 16:51:52 +0300
committer vvvv <vvvv@ydb.tech> 2023-05-22 16:51:52 +0300
commit 7b06ae00f14aff82006eeeb5375af1a79ef2cf00 (patch)
tree e7601642d8ab52b8ad9a00da28dbc8750acc2330
parent f86980eca696fad8264c575b94b8e0e2aa26be29 (diff)
download ydb-7b06ae00f14aff82006eeeb5375af1a79ef2cf00.tar.gz
Upper part of block PG aggregations
-rw-r--r-- ydb/library/yql/core/type_ann/type_ann_blocks.cpp 4
-rw-r--r-- ydb/library/yql/core/type_ann/type_ann_list.cpp 117
-rw-r--r-- ydb/library/yql/core/yql_aggregate_expander.cpp 71
-rw-r--r-- ydb/library/yql/core/yql_expr_type_annotation.cpp 48
-rw-r--r-- ydb/library/yql/core/yql_expr_type_annotation.h 3
5 files changed, 163 insertions, 80 deletions
diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
index 3167d31745e..1f7343585f8 100644
--- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
@@ -530,7 +530,7 @@ bool ValidateBlockAggs(TPositionHandle pos, const TTypeAnnotationNode::TListType
return false;
}
- if (agg->ChildrenSize() != agg->Head().ChildrenSize()) {
+ if (agg->ChildrenSize() + (overState ? 1 : 0) != agg->Head().ChildrenSize()) {
ctx.AddError(TIssue(ctx.GetPosition(pos), "Different amount of input arguments"));
return false;
}
@@ -548,7 +548,7 @@ bool ValidateBlockAggs(TPositionHandle pos, const TTypeAnnotationNode::TListType
return false;
}
- auto applyArgType = agg->Head().Child(i)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
+ auto applyArgType = agg->Head().Child(i + (overState ? 1 : 0))->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
auto expectedType = many ? ctx.MakeType<TOptionalExprType>(applyArgType) : applyArgType;
if (!IsSameAnnotation(*inputItems[argColumnIndex], *expectedType)) {
ctx.AddError(TIssue(ctx.GetPosition(pos), TStringBuilder() <<
diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp
index 0de23081a7d..0e259b8e31c 100644
--- a/ydb/library/yql/core/type_ann/type_ann_list.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp
@@ -23,42 +23,6 @@ namespace {
return x->GetTypeAnn() && x->GetTypeAnn()->GetKind() == ETypeAnnotationKind::EmptyList;
};
- const TTypeAnnotationNode* GetOriginalResultType(TPositionHandle pos, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx) {
- if (!EnsureStructType(pos, *originalExtractorType, ctx)) {
- return nullptr;
- }
-
- auto structType = originalExtractorType->Cast<TStructExprType>();
- if (structType->GetSize() != 1) {
- ctx.AddError(TIssue(ctx.GetPosition(pos),
- TStringBuilder() << "Expected struct with one member"));
- return nullptr;
- }
-
- auto type = structType->GetItems()[0]->GetItemType();
- if (isMany) {
- if (type->GetKind() != ETypeAnnotationKind::Optional) {
- ctx.AddError(TIssue(ctx.GetPosition(pos),
- TStringBuilder() << "Expected optional state"));
- return nullptr;
- }
-
- type = type->Cast<TOptionalExprType>()->GetItemType();
- }
-
- return type;
- }
-
- bool ApplyOriginalType(TExprNode::TPtr input, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx) {
- auto type = GetOriginalResultType(input->Pos(), isMany, originalExtractorType, ctx);
- if (!type) {
- return false;
- }
-
- input->SetTypeAnn(type);
- return true;
- }
-
TExprNode::TPtr RewriteMultiAggregate(const TExprNode& node, TExprContext& ctx) {
auto exprLambda = node.Child(1);
const TStructExprType* structType = nullptr;
@@ -5439,7 +5403,7 @@ namespace {
IGraphTransformer::TStatus AggBlockApplyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
Y_UNUSED(output);
const bool overState = input->Content().EndsWith("State");
- if (!EnsureMinArgsCount(*input, 1, ctx.Expr)) {
+ if (!EnsureMinArgsCount(*input, overState ? 2 : 1, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
@@ -5447,23 +5411,38 @@ namespace {
return IGraphTransformer::TStatus::Error;
}
- auto name = input->Child(0)->Content();
- ui32 expectedArgs;
- if (name == "count_all") {
- expectedArgs = overState ? 2 : 1;
- } else if (name == "count" || name == "sum" || name == "avg" || name == "min" || name == "max" || name == "some") {
- expectedArgs = 2;
- } else {
- ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
- TStringBuilder() << "Unsupported agg name: " << name));
- return IGraphTransformer::TStatus::Error;
+ const TTypeAnnotationNode* originalType = nullptr;
+ if (overState && !input->Child(1)->IsCallable("Void")) {
+ if (auto status = EnsureTypeRewrite(input->ChildRef(1), ctx.Expr); status != IGraphTransformer::TStatus::Ok) {
+ return status;
+ }
+
+ originalType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
}
- if (!EnsureArgsCount(*input, expectedArgs, ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
+ auto name = input->Child(0)->Content();
+ if (!name.StartsWith("pg_")) {
+ ui32 expectedArgs;
+ if (name == "count_all") {
+ expectedArgs = overState ? 2 : 1;
+ } else if (name == "count" || name == "sum" || name == "avg" || name == "min" || name == "max" || name == "some") {
+ expectedArgs = 2;
+ } else {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ TStringBuilder() << "Unsupported agg name: " << name));
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (overState) {
+ ++expectedArgs;
+ }
+
+ if (!EnsureArgsCount(*input, expectedArgs, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
}
- for (ui32 i = 1; i < expectedArgs; ++i) {
+ for (ui32 i = overState ? 2 : 1; i < input->ChildrenSize(); ++i) {
if (auto status = EnsureTypeRewrite(input->ChildRef(i), ctx.Expr); status != IGraphTransformer::TStatus::Ok) {
return status;
}
@@ -5472,7 +5451,7 @@ namespace {
if (name == "count_all" || name == "count") {
const TTypeAnnotationNode* retType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64);
if (overState) {
- auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
+ auto itemType = input->Child(2)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
if (!IsSameAnnotation(*itemType, *retType)) {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
TStringBuilder() << "Mismatch count type, expected: " << *itemType << ", but got: " << *retType));
@@ -5482,7 +5461,7 @@ namespace {
input->SetTypeAnn(retType);
} else if (name == "sum") {
- auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
+ auto itemType = input->Child(overState ? 2 : 1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
const TTypeAnnotationNode* retType;
if (!GetSumResultType(input->Pos(), *itemType, retType, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
@@ -5498,7 +5477,7 @@ namespace {
input->SetTypeAnn(retType);
} else if (name == "avg") {
- auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
+ auto itemType = input->Child(overState ? 2 : 1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
const TTypeAnnotationNode* retType;
if (!overState) {
if (!GetAvgResultType(input->Pos(), *itemType, retType, ctx.Expr)) {
@@ -5512,7 +5491,7 @@ namespace {
input->SetTypeAnn(retType);
} else if (name == "min" || name == "max") {
- auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
+ auto itemType = input->Child(overState ? 2 : 1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
const TTypeAnnotationNode* retType;
if (!GetMinMaxResultType(input->Pos(), *itemType, retType, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
@@ -5528,9 +5507,37 @@ namespace {
input->SetTypeAnn(retType);
} else if (name == "some") {
- auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
+ auto itemType = input->Child(overState ? 2 : 1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
const TTypeAnnotationNode* retType = itemType;
input->SetTypeAnn(retType);
+ } else if (name.StartsWith("pg_")) {
+ auto func = name.SubStr(3);
+ TVector<ui32> argTypes;
+ for (ui32 i = 1 + (overState ? 1 : 0); i < input->ChildrenSize(); ++i) {
+ argTypes.push_back(input->Child(i)->GetTypeAnn()->Cast<TTypeExprType>()->GetType()->Cast<TPgExprType>()->GetId());
+ }
+
+ const NPg::TAggregateDesc* aggDescPtr;
+ if (overState) {
+ YQL_ENSURE(argTypes.size() == 1);
+ if (!originalType) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ TStringBuilder() << "Partial aggreation is not supported for: " << name));
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ auto resultType = originalType->Cast<TPgExprType>()->GetId();
+ aggDescPtr = &NPg::LookupAggregation(TString(func), argTypes[0], resultType);
+ } else {
+ aggDescPtr = &NPg::LookupAggregation(TString(func), argTypes);
+ }
+
+ if (overState) {
+ input->SetTypeAnn(originalType);
+ } else {
+ auto stateType = NPg::LookupProc(aggDescPtr->SerializeFuncId ? aggDescPtr->SerializeFuncId : aggDescPtr->TransFuncId).ResultType;
+ input->SetTypeAnn(ctx.Expr.MakeType<TPgExprType>(stateType));
+ }
} else {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
TStringBuilder() << "Unsupported agg name: " << name));
diff --git a/ydb/library/yql/core/yql_aggregate_expander.cpp b/ydb/library/yql/core/yql_aggregate_expander.cpp
index 8ac57ad21d4..a5c1cacd623 100644
--- a/ydb/library/yql/core/yql_aggregate_expander.cpp
+++ b/ydb/library/yql/core/yql_aggregate_expander.cpp
@@ -570,20 +570,22 @@ TExprNode::TPtr TAggregateExpander::MakeInputBlocks(const TExprNode::TPtr& strea
auto trait = AggregatedColumns->Child(index)->ChildPtr(1);
TVector<const TTypeAnnotationNode*> allTypes;
+ const TTypeAnnotationNode* originalType = nullptr;
+ if (overState && !trait->Child(3)->IsCallable("Void")) {
+ auto originalExtractorType = trait->Child(3)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
+ originalType = GetOriginalResultType(trait->Pos(), many, originalExtractorType, Ctx);
+ YQL_ENSURE(originalType);
+ }
+
+ ui32 argsCount = trait->Child(2)->ChildrenSize() - 1;
if (!overState && trait->Child(0)->Content() == "count_all") {
- // 0 columns
- aggs.push_back(Ctx.Builder(Node->Pos())
- .List()
- .Callable(0, "AggBlockApply")
- .Atom(0, trait->Child(0)->Content())
- .Seal()
- .Seal()
- .Build());
+ argsCount = 0;
}
- else {
- // 1 column
- auto root = trait->Child(2)->TailPtr();
- auto rowArg = &trait->Child(2)->Head().Head();
+
+ auto rowArg = &trait->Child(2)->Head().Head();
+ TVector<TExprNode::TPtr> roots;
+ for (ui32 i = 1; i < argsCount + 1; ++i) {
+ auto root = trait->Child(2)->ChildPtr(i);
allTypes.push_back(root->GetTypeAnn());
auto status = OptimizeExpr(root, root, [&](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr {
@@ -598,17 +600,46 @@ TExprNode::TPtr TAggregateExpander::MakeInputBlocks(const TExprNode::TPtr& strea
}, Ctx, TOptimizeExprSettings(&TypesCtx));
YQL_ENSURE(status.Level != IGraphTransformer::TStatus::Error);
+ roots.push_back(root);
+ }
- aggs.push_back(Ctx.Builder(Node->Pos())
- .List()
- .Callable(0, TString("AggBlockApply") + (overState ? "State" : ""))
- .Atom(0, trait->Child(0)->Content())
- .Add(1, ExpandType(Node->Pos(), *trait->Child(2)->GetTypeAnn(), Ctx))
- .Seal()
- .Atom(1, ToString(extractorRoots.size()))
+ aggs.push_back(Ctx.Builder(Node->Pos())
+ .List()
+ .Callable(0, TString("AggBlockApply") + (overState ? "State" : ""))
+ .Atom(0, trait->Child(0)->Content())
+ .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
+ if (overState) {
+ if (originalType) {
+ parent.Add(1, ExpandType(Node->Pos(), *originalType, Ctx));
+ } else {
+ parent
+ .Callable(1, "Null")
+ .Seal();
+ }
+ }
+
+ return parent;
+ })
+ .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
+ // Add one expanded type child per extractor argument; no return inside the loop, or only the first argument's type would be emitted.
+ for (ui32 i = 1; i < argsCount + 1; ++i) {
+ parent.Add(i + (overState ? 1 : 0), ExpandType(Node->Pos(), *trait->Child(2)->Child(i)->GetTypeAnn(), Ctx)); // shifted by one when the state variant carries an extra child at index 1
+ }
+
+ return parent;
+ })
.Seal()
- .Build());
+ .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
+ for (ui32 i = 1; i < argsCount + 1; ++i) {
+ parent.Atom(i, ToString(extractorRoots.size() + i - 1));
+ }
+
+ return parent;
+ })
+ .Seal()
+ .Build());
+ for (auto root : roots) {
if (many) {
if (root->IsCallable("Unwrap")) {
root = root->HeadPtr();
diff --git a/ydb/library/yql/core/yql_expr_type_annotation.cpp b/ydb/library/yql/core/yql_expr_type_annotation.cpp
index 12135cbbf70..79ca51ece68 100644
--- a/ydb/library/yql/core/yql_expr_type_annotation.cpp
+++ b/ydb/library/yql/core/yql_expr_type_annotation.cpp
@@ -5569,9 +5569,15 @@ const TTypeAnnotationNode* AggApplySerializedStateType(const TExprNode::TPtr& in
} else if (name.StartsWith("pg_")) {
auto func = name.SubStr(3);
TVector<ui32> argTypes;
- bool needRetype = false;
- auto status = ExtractPgTypesFromMultiLambda(input->ChildRef(2), argTypes, needRetype, ctx);
- YQL_ENSURE(status == IGraphTransformer::TStatus::Ok);
+ if (input->Content().StartsWith("AggBlock")) {
+ for (ui32 i = 1; i < input->ChildrenSize(); ++i) {
+ argTypes.push_back(input->Child(i)->GetTypeAnn()->Cast<TTypeExprType>()->GetType()->Cast<TPgExprType>()->GetId());
+ }
+ } else {
+ bool needRetype = false;
+ auto status = ExtractPgTypesFromMultiLambda(input->ChildRef(2), argTypes, needRetype, ctx);
+ YQL_ENSURE(status == IGraphTransformer::TStatus::Ok);
+ }
const NPg::TAggregateDesc& aggDesc = NPg::LookupAggregation(TString(func), argTypes);
const auto& procDesc = NPg::LookupProc(aggDesc.SerializeFuncId ? aggDesc.SerializeFuncId : aggDesc.TransFuncId);
@@ -6036,5 +6042,41 @@ TExprNode::TPtr ExpandPgAggregationTraits(TPositionHandle pos, const NPg::TAggre
}
}
+const TTypeAnnotationNode* GetOriginalResultType(TPositionHandle pos, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx) { // Derives a result type from the extractor's struct type; returns nullptr when any check below fails.
+ if (!EnsureStructType(pos, *originalExtractorType, ctx)) { // the extractor type has to be a struct
+ return nullptr;
+ }
+ // The struct check passed, so the extractor type is cast to a struct type below.
+ auto structType = originalExtractorType->Cast<TStructExprType>();
+ if (structType->GetSize() != 1) { // exactly one member is accepted
+ ctx.AddError(TIssue(ctx.GetPosition(pos),
+ TStringBuilder() << "Expected struct with one member"));
+ return nullptr;
+ }
+ // The single member's item type is the candidate result.
+ auto type = structType->GetItems()[0]->GetItemType();
+ if (isMany) { // with isMany the member type must be Optional and is unwrapped by one level
+ if (type->GetKind() != ETypeAnnotationKind::Optional) {
+ ctx.AddError(TIssue(ctx.GetPosition(pos),
+ TStringBuilder() << "Expected optional state"));
+ return nullptr;
+ }
+
+ type = type->Cast<TOptionalExprType>()->GetItemType(); // drop the Optional wrapper
+ }
+ // Result: the member's item type, minus one Optional level when isMany is set.
+ return type;
+}
+
+bool ApplyOriginalType(TExprNode::TPtr input, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx) { // Sets input's type annotation to the result of GetOriginalResultType; returns false, without setting it, when that result is null.
+ auto type = GetOriginalResultType(input->Pos(), isMany, originalExtractorType, ctx);
+ if (!type) { // nullptr means one of GetOriginalResultType's checks failed
+ return false;
+ }
+ // A type was resolved: record it on the node.
+ input->SetTypeAnn(type);
+ return true;
+}
+
} // NYql
diff --git a/ydb/library/yql/core/yql_expr_type_annotation.h b/ydb/library/yql/core/yql_expr_type_annotation.h
index 6daf85dcca4..a2b665e562b 100644
--- a/ydb/library/yql/core/yql_expr_type_annotation.h
+++ b/ydb/library/yql/core/yql_expr_type_annotation.h
@@ -324,4 +324,7 @@ IGraphTransformer::TStatus ExtractPgTypesFromMultiLambda(TExprNode::TPtr& lambda
TExprNode::TPtr ExpandPgAggregationTraits(TPositionHandle pos, const NPg::TAggregateDesc& aggDesc, bool onWindow,
const TExprNode::TPtr& lambda, const TVector<ui32>& argTypes, const TTypeAnnotationNode* itemType, TExprContext& ctx);
+const TTypeAnnotationNode* GetOriginalResultType(TPositionHandle pos, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx);
+bool ApplyOriginalType(TExprNode::TPtr input, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx);
+
}