about summary refs log tree commit diff stats
diff options
context:
space:
mode:
author vvvv <vvvv@ydb.tech> 2023-05-19 17:48:47 +0300
committer vvvv <vvvv@ydb.tech> 2023-05-19 17:48:47 +0300
commit 253904e00ed53deb7ba78466b41f25a78f74b6de (patch)
tree c67c6d444d76ff0751c092dffabfd6feccddd81a
parent 81eb00dfba71fc3110d5af2eed7fa320ace33e5b (diff)
download ydb-253904e00ed53deb7ba78466b41f25a78f74b6de.tar.gz
Implementation of agg_apply mode for Pg aggregations
-rw-r--r-- ydb/core/kqp/provider/yql_kikimr_datasource.cpp 4
-rw-r--r-- ydb/library/yql/core/common_opt/yql_co_flow2.cpp 11
-rw-r--r-- ydb/library/yql/core/type_ann/type_ann_list.cpp 87
-rw-r--r-- ydb/library/yql/core/type_ann/type_ann_pg.cpp 342
-rw-r--r-- ydb/library/yql/core/yql_aggregate_expander.cpp 20
-rw-r--r-- ydb/library/yql/core/yql_expr_type_annotation.cpp 341
-rw-r--r-- ydb/library/yql/core/yql_expr_type_annotation.h 7
-rw-r--r-- ydb/library/yql/parser/pg_catalog/catalog.cpp 36
-rw-r--r-- ydb/library/yql/parser/pg_catalog/catalog.h 1
-rw-r--r-- ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp 8
-rw-r--r-- ydb/library/yql/sql/v1/aggregation.cpp 5
11 files changed, 521 insertions, 341 deletions
diff --git a/ydb/core/kqp/provider/yql_kikimr_datasource.cpp b/ydb/core/kqp/provider/yql_kikimr_datasource.cpp
index 65f45687112..e92fb35197c 100644
--- a/ydb/core/kqp/provider/yql_kikimr_datasource.cpp
+++ b/ydb/core/kqp/provider/yql_kikimr_datasource.cpp
@@ -535,8 +535,8 @@ public:
if (typeInfo.GetTypeId() == NScheme::NTypeIds::Pg) {
auto* typeDesc = typeInfo.GetTypeDesc();
auto* pg = ydbType.mutable_pg_type();
- pg->set_type_name(NPg::PgTypeNameFromTypeDesc(typeDesc));
- pg->set_oid(NPg::PgTypeIdFromTypeDesc(typeDesc));
+ pg->set_type_name(NKikimr::NPg::PgTypeNameFromTypeDesc(typeDesc));
+ pg->set_oid(NKikimr::NPg::PgTypeIdFromTypeDesc(typeDesc));
} else {
auto& item = notNull
? ydbType
diff --git a/ydb/library/yql/core/common_opt/yql_co_flow2.cpp b/ydb/library/yql/core/common_opt/yql_co_flow2.cpp
index a5dc8179691..e30ef474cf9 100644
--- a/ydb/library/yql/core/common_opt/yql_co_flow2.cpp
+++ b/ydb/library/yql/core/common_opt/yql_co_flow2.cpp
@@ -1746,12 +1746,15 @@ void RegisterCoFlowCallables2(TCallableOptimizerMap& map) {
auto structType = type->Cast<TStructExprType>();
TSet<TStringBuf> usedFields;
auto extractor = node->Child(2);
- TSet<TStringBuf> lambdaSubset;
- if (!HaveFieldsSubset(extractor->ChildPtr(1), *extractor->Child(0)->Child(0), lambdaSubset, *optCtx.ParentsMap)) {
- return node;
+ for (ui32 i = 1; i < extractor->ChildrenSize(); ++i) {
+ TSet<TStringBuf> lambdaSubset;
+ if (!HaveFieldsSubset(extractor->ChildPtr(i), *extractor->Child(0)->Child(0), lambdaSubset, *optCtx.ParentsMap)) {
+ return node;
+ }
+
+ usedFields.insert(lambdaSubset.cbegin(), lambdaSubset.cend());
}
- usedFields.insert(lambdaSubset.cbegin(), lambdaSubset.cend());
if (usedFields.size() == structType->GetSize()) {
return node;
}
diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp
index 4af726c03f1..0de23081a7d 100644
--- a/ydb/library/yql/core/type_ann/type_ann_list.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp
@@ -8,6 +8,8 @@
#include <ydb/library/yql/core/yql_opt_window.h>
#include <ydb/library/yql/core/yql_type_helpers.h>
+#include <ydb/library/yql/parser/pg_catalog/catalog.h>
+
#include <util/generic/algorithm.h>
#include <util/string/join.h>
@@ -21,29 +23,39 @@ namespace {
return x->GetTypeAnn() && x->GetTypeAnn()->GetKind() == ETypeAnnotationKind::EmptyList;
};
- bool ApplyOriginalType(TExprNode::TPtr input, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx) {
- if (!EnsureStructType(input->Pos(), *originalExtractorType, ctx)) {
- return false;
+ const TTypeAnnotationNode* GetOriginalResultType(TPositionHandle pos, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx) {
+ if (!EnsureStructType(pos, *originalExtractorType, ctx)) {
+ return nullptr;
}
auto structType = originalExtractorType->Cast<TStructExprType>();
if (structType->GetSize() != 1) {
- ctx.AddError(TIssue(ctx.GetPosition(input->Pos()),
+ ctx.AddError(TIssue(ctx.GetPosition(pos),
TStringBuilder() << "Expected struct with one member"));
- return false;
+ return nullptr;
}
- input->SetTypeAnn(structType->GetItems()[0]->GetItemType());
+ auto type = structType->GetItems()[0]->GetItemType();
if (isMany) {
- if (input->GetTypeAnn()->GetKind() != ETypeAnnotationKind::Optional) {
- ctx.AddError(TIssue(ctx.GetPosition(input->Pos()),
+ if (type->GetKind() != ETypeAnnotationKind::Optional) {
+ ctx.AddError(TIssue(ctx.GetPosition(pos),
TStringBuilder() << "Expected optional state"));
- return false;
+ return nullptr;
}
- input->SetTypeAnn(input->GetTypeAnn()->Cast<TOptionalExprType>()->GetItemType());
+ type = type->Cast<TOptionalExprType>()->GetItemType();
}
+ return type;
+ }
+
+ bool ApplyOriginalType(TExprNode::TPtr input, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx) {
+ auto type = GetOriginalResultType(input->Pos(), isMany, originalExtractorType, ctx);
+ if (!type) {
+ return false;
+ }
+
+ input->SetTypeAnn(type);
return true;
}
@@ -5360,6 +5372,61 @@ namespace {
input->SetTypeAnn(retType);
} else if (name == "some") {
input->SetTypeAnn(lambda->GetTypeAnn());
+ } else if (name.StartsWith("pg_")) {
+ auto func = name;
+ func.SkipPrefix("pg_");
+ TVector<ui32> argTypes;
+ bool needRetype = false;
+ if (auto status = ExtractPgTypesFromMultiLambda(lambda, argTypes, needRetype, ctx.Expr);
+ status != IGraphTransformer::TStatus::Ok) {
+ return status;
+ }
+
+ if (overState && !hasOriginalType) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ TStringBuilder() << "Partial aggregation of " << name << " is not supported"));
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ const TTypeAnnotationNode* originalResultType = nullptr;
+ if (hasOriginalType) {
+ auto originalExtractorType = input->Child(3)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
+ originalResultType = GetOriginalResultType(input->Pos(), isMany, originalExtractorType, ctx.Expr);
+ if (!originalResultType) {
+ return IGraphTransformer::TStatus::Error;
+ }
+ }
+
+ const NPg::TAggregateDesc* aggDescPtr;
+ try {
+ if (overState) {
+ if (argTypes.size() != 1) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ "Expected only one argument"));
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ aggDescPtr = &NPg::LookupAggregation(TString(func), argTypes[0], originalResultType->Cast<TPgExprType>()->GetId());
+ } else {
+ aggDescPtr = &NPg::LookupAggregation(TString(func), argTypes);
+ }
+ if (aggDescPtr->Kind != NPg::EAggKind::Normal) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ "Only normal aggregation supported"));
+ return IGraphTransformer::TStatus::Error;
+ }
+ } catch (const yexception& e) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), e.what()));
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (overState) {
+ input->SetTypeAnn(originalResultType);
+ } else {
+ const NPg::TAggregateDesc& aggDesc = *aggDescPtr;
+ const auto& finalDesc = NPg::LookupProc(aggDesc.FinalFuncId ? aggDesc.FinalFuncId : aggDesc.TransFuncId);
+ input->SetTypeAnn(ctx.Expr.MakeType<TPgExprType>(finalDesc.ResultType));
+ }
} else {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
TStringBuilder() << "Unsupported agg name: " << name));
diff --git a/ydb/library/yql/core/type_ann/type_ann_pg.cpp b/ydb/library/yql/core/type_ann/type_ann_pg.cpp
index b5d9e715311..77acf647e43 100644
--- a/ydb/library/yql/core/type_ann/type_ann_pg.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_pg.cpp
@@ -285,13 +285,14 @@ IGraphTransformer::TStatus ToPgWrapper(const TExprNode::TPtr& input, TExprNode::
}
IGraphTransformer::TStatus PgCloneWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
+ Y_UNUSED(output);
if (!EnsureDependsOnTail(*input, ctx.Expr, 1)) {
return IGraphTransformer::TStatus::Error;
}
if (IsNull(input->Head())) {
- output = input->HeadPtr();
- return IGraphTransformer::TStatus::Repeat;
+ input->SetTypeAnn(input->Head().GetTypeAnn());
+ return IGraphTransformer::TStatus::Ok;
}
auto type = input->Head().GetTypeAnn();
@@ -1080,335 +1081,26 @@ IGraphTransformer::TStatus PgAggregationTraitsWrapper(const TExprNode::TPtr& inp
TVector<ui32> argTypes;
bool needRetype = false;
- for (ui32 i = 1; i < lambda->ChildrenSize(); ++i) {
- auto type = lambda->Child(i)->GetTypeAnn();
- ui32 argType;
- bool convertToPg;
- if (!ExtractPgType(type, argType, convertToPg, lambda->Child(i)->Pos(), ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
- }
-
- if (convertToPg) {
- needRetype = true;
- }
-
- argTypes.push_back(argType);
+ if (auto status = ExtractPgTypesFromMultiLambda(lambda, argTypes, needRetype, ctx.Expr);
+ status != IGraphTransformer::TStatus::Ok) {
+ return status;
}
- if (needRetype) {
- auto newLambda = ctx.Expr.DeepCopyLambda(*lambda);
- for (ui32 i = 1; i < lambda->ChildrenSize(); ++i) {
- auto type = lambda->Child(i)->GetTypeAnn();
- ui32 argType;
- bool convertToPg;
- if (!ExtractPgType(type, argType, convertToPg, lambda->Child(i)->Pos(), ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
- }
-
- if (convertToPg) {
- newLambda->ChildRef(i) = ctx.Expr.NewCallable(newLambda->Child(i)->Pos(), "ToPg", { newLambda->ChildPtr(i) });
- }
+ const NPg::TAggregateDesc* aggDescPtr;
+ try {
+ aggDescPtr = &NPg::LookupAggregation(TString(func), argTypes);
+ if (aggDescPtr->Kind != NPg::EAggKind::Normal) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ "Only normal aggregation supported"));
+ return IGraphTransformer::TStatus::Error;
}
-
- lambda = newLambda;
- return IGraphTransformer::TStatus::Repeat;
- }
-
- const auto& aggDesc = NPg::LookupAggregation(TString(func), argTypes);
- if (aggDesc.Kind != NPg::EAggKind::Normal) {
- ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
- "Only normal aggregation supported"));
+ } catch (const yexception& e) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), e.what()));
return IGraphTransformer::TStatus::Error;
}
- auto idLambda = ctx.Expr.Builder(input->Pos())
- .Lambda()
- .Param("state")
- .Arg("state")
- .Seal()
- .Build();
-
- auto saveLambda = idLambda;
- auto loadLambda = idLambda;
- auto finishLambda = idLambda;
- if (aggDesc.FinalFuncId) {
- finishLambda = ctx.Expr.Builder(input->Pos())
- .Lambda()
- .Param("state")
- .Callable("PgResolvedCallCtx")
- .Atom(0, NPg::LookupProc(aggDesc.FinalFuncId).Name)
- .Atom(1, ToString(aggDesc.FinalFuncId))
- .List(2)
- .Seal()
- .Arg(3, "state")
- .Seal()
- .Seal()
- .Build();
- }
-
- auto nullValue = ctx.Expr.NewCallable(input->Pos(), "Null", {});
- auto initValue = nullValue;
- if (aggDesc.InitValue) {
- initValue = ctx.Expr.Builder(input->Pos())
- .Callable("PgCast")
- .Callable(0, "PgConst")
- .Atom(0, aggDesc.InitValue)
- .Callable(1, "PgType")
- .Atom(0, "text")
- .Seal()
- .Seal()
- .Callable(1, "PgType")
- .Atom(0, NPg::LookupType(aggDesc.TransTypeId).Name)
- .Seal()
- .Seal()
- .Build();
- }
-
- const auto& transFuncDesc = NPg::LookupProc(aggDesc.TransFuncId);
- // use first non-null value as state if transFunc is strict
- bool searchNonNullForState = false;
- if (transFuncDesc.IsStrict && !aggDesc.InitValue) {
- Y_ENSURE(argTypes.size() == 1);
- searchNonNullForState = true;
- }
-
- TExprNode::TPtr initLambda, updateLambda;
- if (!searchNonNullForState) {
- initLambda = ctx.Expr.Builder(input->Pos())
- .Lambda()
- .Param("row")
- .Param("parent")
- .Callable("PgResolvedCallCtx")
- .Atom(0, transFuncDesc.Name)
- .Atom(1, ToString(aggDesc.TransFuncId))
- .List(2)
- .Seal()
- .Callable(3, "PgClone")
- .Add(0, initValue)
- .Callable(1, "DependsOn")
- .Arg(0, "parent")
- .Seal()
- .Seal()
- .Apply(4, lambda)
- .With(0, "row")
- .Seal()
- .Seal()
- .Seal()
- .Build();
-
- updateLambda = ctx.Expr.Builder(input->Pos())
- .Lambda()
- .Param("row")
- .Param("state")
- .Param("parent")
- .Callable("Coalesce")
- .Callable(0, "PgResolvedCallCtx")
- .Atom(0, transFuncDesc.Name)
- .Atom(1, ToString(aggDesc.TransFuncId))
- .List(2)
- .Seal()
- .Callable(3, "Coalesce")
- .Arg(0, "state")
- .Callable(1, "PgClone")
- .Add(0, initValue)
- .Callable(1, "DependsOn")
- .Arg(0, "parent")
- .Seal()
- .Seal()
- .Seal()
- .Apply(4, lambda)
- .With(0, "row")
- .Seal()
- .Seal()
- .Arg(1, "state")
- .Seal()
- .Seal()
- .Build();
- } else {
- initLambda = ctx.Expr.Builder(input->Pos())
- .Lambda()
- .Param("row")
- .Apply(lambda)
- .With(0, "row")
- .Seal()
- .Seal()
- .Build();
-
- if (lambdaResult->GetKind() == ETypeAnnotationKind::Null) {
- initLambda = ctx.Expr.Builder(input->Pos())
- .Lambda()
- .Param("row")
- .Callable("PgCast")
- .Apply(0, initLambda)
- .With(0, "row")
- .Seal()
- .Callable(1, "PgType")
- .Atom(0, NPg::LookupType(aggDesc.TransTypeId).Name)
- .Seal()
- .Seal()
- .Seal()
- .Build();
- }
-
- updateLambda = ctx.Expr.Builder(input->Pos())
- .Lambda()
- .Param("row")
- .Param("state")
- .Callable("If")
- .Callable(0, "Exists")
- .Arg(0, "state")
- .Seal()
- .Callable(1, "Coalesce")
- .Callable(0, "PgResolvedCallCtx")
- .Atom(0, transFuncDesc.Name)
- .Atom(1, ToString(aggDesc.TransFuncId))
- .List(2)
- .Seal()
- .Arg(3, "state")
- .Apply(4, lambda)
- .With(0, "row")
- .Seal()
- .Seal()
- .Arg(1, "state")
- .Seal()
- .Apply(2, lambda)
- .With(0, "row")
- .Seal()
- .Seal()
- .Seal()
- .Build();
- }
-
- auto mergeLambda = ctx.Expr.Builder(input->Pos())
- .Lambda()
- .Param("state1")
- .Param("state2")
- .Callable("Void")
- .Seal()
- .Seal()
- .Build();
-
- auto zero = ctx.Expr.Builder(input->Pos())
- .Callable("PgConst")
- .Atom(0, "0")
- .Callable(1, "PgType")
- .Atom(0, "int8")
- .Seal()
- .Seal()
- .Build();
-
- auto defaultValue = (func != "count") ? nullValue : zero;
-
- if (aggDesc.SerializeFuncId) {
- const auto& serializeFuncDesc = NPg::LookupProc(aggDesc.SerializeFuncId);
- saveLambda = ctx.Expr.Builder(input->Pos())
- .Lambda()
- .Param("state")
- .Callable("PgResolvedCallCtx")
- .Atom(0, serializeFuncDesc.Name)
- .Atom(1, ToString(aggDesc.SerializeFuncId))
- .List(2)
- .Seal()
- .Arg(3, "state")
- .Seal()
- .Seal()
- .Build();
- }
-
- if (aggDesc.DeserializeFuncId) {
- const auto& deserializeFuncDesc = NPg::LookupProc(aggDesc.DeserializeFuncId);
- loadLambda = ctx.Expr.Builder(input->Pos())
- .Lambda()
- .Param("state")
- .Callable("PgResolvedCallCtx")
- .Atom(0, deserializeFuncDesc.Name)
- .Atom(1, ToString(aggDesc.DeserializeFuncId))
- .List(2)
- .Seal()
- .Arg(3, "state")
- .Callable(4, "PgInternal0")
- .Seal()
- .Seal()
- .Seal()
- .Build();
- }
-
- if (aggDesc.CombineFuncId) {
- const auto& combineFuncDesc = NPg::LookupProc(aggDesc.CombineFuncId);
- if (combineFuncDesc.IsStrict) {
- mergeLambda = ctx.Expr.Builder(input->Pos())
- .Lambda()
- .Param("state1")
- .Param("state2")
- .Callable("If")
- .Callable(0, "Exists")
- .Arg(0, "state1")
- .Seal()
- .Callable(1, "Coalesce")
- .Callable(0, "PgResolvedCallCtx")
- .Atom(0, combineFuncDesc.Name)
- .Atom(1, ToString(aggDesc.CombineFuncId))
- .List(2)
- .Seal()
- .Arg(3, "state1")
- .Arg(4, "state2")
- .Seal()
- .Arg(1, "state1")
- .Seal()
- .Arg(2, "state2")
- .Seal()
- .Seal()
- .Build();
- } else {
- mergeLambda = ctx.Expr.Builder(input->Pos())
- .Lambda()
- .Param("state1")
- .Param("state2")
- .Callable("PgResolvedCallCtx")
- .Atom(0, combineFuncDesc.Name)
- .Atom(1, ToString(aggDesc.CombineFuncId))
- .List(2)
- .Seal()
- .Arg(3, "state1")
- .Arg(4, "state2")
- .Seal()
- .Seal()
- .Build();
- }
- }
-
- auto typeNode = ExpandType(input->Pos(), *itemType, ctx.Expr);
- if (onWindow) {
- output = ctx.Expr.Builder(input->Pos())
- .Callable("WindowTraits")
- .Add(0, typeNode)
- .Add(1, initLambda)
- .Add(2, updateLambda)
- .Lambda(3)
- .Param("value")
- .Param("state")
- .Callable("Void")
- .Seal()
- .Seal()
- .Add(4, finishLambda)
- .Add(5, defaultValue)
- .Seal()
- .Build();
- } else {
- output = ctx.Expr.Builder(input->Pos())
- .Callable("AggregationTraits")
- .Add(0, typeNode)
- .Add(1, initLambda)
- .Add(2, updateLambda)
- .Add(3, saveLambda)
- .Add(4, loadLambda)
- .Add(5, mergeLambda)
- .Add(6, finishLambda)
- .Add(7, defaultValue)
- .Seal()
- .Build();
- }
-
+ const NPg::TAggregateDesc& aggDesc = *aggDescPtr;
+ output = ExpandPgAggregationTraits(input->Pos(), aggDesc, onWindow, lambda, argTypes, itemType, ctx.Expr);
return IGraphTransformer::TStatus::Repeat;
}
diff --git a/ydb/library/yql/core/yql_aggregate_expander.cpp b/ydb/library/yql/core/yql_aggregate_expander.cpp
index 177755f7b15..8ac57ad21d4 100644
--- a/ydb/library/yql/core/yql_aggregate_expander.cpp
+++ b/ydb/library/yql/core/yql_aggregate_expander.cpp
@@ -96,6 +96,26 @@ TExprNode::TPtr TAggregateExpander::ExpandAggregate()
TExprNode::TPtr TAggregateExpander::ExpandAggApply(const TExprNode::TPtr& node)
{
auto name = node->Head().Content();
+ if (name.StartsWith("pg_")) {
+ auto func = name.SubStr(3);
+ auto itemType = node->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
+ TVector<ui32> argTypes;
+ bool needRetype = false;
+ auto status = ExtractPgTypesFromMultiLambda(node->ChildRef(2), argTypes, needRetype, Ctx);
+ YQL_ENSURE(status == IGraphTransformer::TStatus::Ok);
+
+ const NPg::TAggregateDesc* aggDescPtr;
+ if (node->Content().EndsWith("State")) {
+ auto stateType = node->Child(2)->GetTypeAnn()->Cast<TPgExprType>()->GetId();
+ auto resultType = node->GetTypeAnn()->Cast<TPgExprType>()->GetId();
+ aggDescPtr = &NPg::LookupAggregation(TString(func), stateType, resultType);
+ } else {
+ aggDescPtr = &NPg::LookupAggregation(TString(func), argTypes);
+ }
+
+ return ExpandPgAggregationTraits(node->Pos(), *aggDescPtr, false, node->ChildPtr(2), argTypes, itemType, Ctx);
+ }
+
auto exportsPtr = TypesCtx.Modules->GetModule("/lib/yql/aggregate.yql");
YQL_ENSURE(exportsPtr);
const auto& exports = exportsPtr->Symbols();
diff --git a/ydb/library/yql/core/yql_expr_type_annotation.cpp b/ydb/library/yql/core/yql_expr_type_annotation.cpp
index 621483765f5..12135cbbf70 100644
--- a/ydb/library/yql/core/yql_expr_type_annotation.cpp
+++ b/ydb/library/yql/core/yql_expr_type_annotation.cpp
@@ -5566,6 +5566,16 @@ const TTypeAnnotationNode* AggApplySerializedStateType(const TExprNode::TPtr& in
}
return stateType;
+ } else if (name.StartsWith("pg_")) {
+ auto func = name.SubStr(3);
+ TVector<ui32> argTypes;
+ bool needRetype = false;
+ auto status = ExtractPgTypesFromMultiLambda(input->ChildRef(2), argTypes, needRetype, ctx);
+ YQL_ENSURE(status == IGraphTransformer::TStatus::Ok);
+
+ const NPg::TAggregateDesc& aggDesc = NPg::LookupAggregation(TString(func), argTypes);
+ const auto& procDesc = NPg::LookupProc(aggDesc.SerializeFuncId ? aggDesc.SerializeFuncId : aggDesc.TransFuncId);
+ return ctx.MakeType<TPgExprType>(procDesc.ResultType);
} else {
YQL_ENSURE(false, "Unknown AggApply: " << name);
}
@@ -5696,4 +5706,335 @@ bool GetMinMaxResultType(const TPositionHandle& pos, const TTypeAnnotationNode&
return true;
}
+IGraphTransformer::TStatus ExtractPgTypesFromMultiLambda(TExprNode::TPtr& lambda, TVector<ui32>& argTypes,
+ bool& needRetype, TExprContext& ctx) {
+ for (ui32 i = 1; i < lambda->ChildrenSize(); ++i) {
+ auto type = lambda->Child(i)->GetTypeAnn();
+ ui32 argType;
+ bool convertToPg;
+ if (!ExtractPgType(type, argType, convertToPg, lambda->Child(i)->Pos(), ctx)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (convertToPg) {
+ needRetype = true;
+ }
+
+ argTypes.push_back(argType);
+ }
+
+ if (needRetype) {
+ auto newLambda = ctx.DeepCopyLambda(*lambda);
+ for (ui32 i = 1; i < lambda->ChildrenSize(); ++i) {
+ auto type = lambda->Child(i)->GetTypeAnn();
+ ui32 argType;
+ bool convertToPg;
+ if (!ExtractPgType(type, argType, convertToPg, lambda->Child(i)->Pos(), ctx)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (convertToPg) {
+ newLambda->ChildRef(i) = ctx.NewCallable(newLambda->Child(i)->Pos(), "ToPg", { newLambda->ChildPtr(i) });
+ }
+ }
+
+ lambda = newLambda;
+ return IGraphTransformer::TStatus::Repeat;
+ }
+
+ return IGraphTransformer::TStatus::Ok;
+}
+
+TExprNode::TPtr ExpandPgAggregationTraits(TPositionHandle pos, const NPg::TAggregateDesc& aggDesc, bool onWindow,
+ const TExprNode::TPtr& lambda, const TVector<ui32>& argTypes, const TTypeAnnotationNode* itemType, TExprContext& ctx) {
+ auto idLambda = ctx.Builder(pos)
+ .Lambda()
+ .Param("state")
+ .Arg("state")
+ .Seal()
+ .Build();
+
+ auto saveLambda = idLambda;
+ auto loadLambda = idLambda;
+ auto finishLambda = idLambda;
+ if (aggDesc.FinalFuncId) {
+ finishLambda = ctx.Builder(pos)
+ .Lambda()
+ .Param("state")
+ .Callable("PgResolvedCallCtx")
+ .Atom(0, NPg::LookupProc(aggDesc.FinalFuncId).Name)
+ .Atom(1, ToString(aggDesc.FinalFuncId))
+ .List(2)
+ .Seal()
+ .Arg(3, "state")
+ .Seal()
+ .Seal()
+ .Build();
+ }
+
+ auto nullValue = ctx.NewCallable(pos, "Null", {});
+ auto initValue = nullValue;
+ if (aggDesc.InitValue) {
+ initValue = ctx.Builder(pos)
+ .Callable("PgCast")
+ .Callable(0, "PgConst")
+ .Atom(0, aggDesc.InitValue)
+ .Callable(1, "PgType")
+ .Atom(0, "text")
+ .Seal()
+ .Seal()
+ .Callable(1, "PgType")
+ .Atom(0, NPg::LookupType(aggDesc.TransTypeId).Name)
+ .Seal()
+ .Seal()
+ .Build();
+ }
+
+ const auto& transFuncDesc = NPg::LookupProc(aggDesc.TransFuncId);
+ // use first non-null value as state if transFunc is strict
+ bool searchNonNullForState = false;
+ if (transFuncDesc.IsStrict && !aggDesc.InitValue) {
+ Y_ENSURE(argTypes.size() == 1);
+ searchNonNullForState = true;
+ }
+
+ TExprNode::TPtr initLambda, updateLambda;
+ if (!searchNonNullForState) {
+ initLambda = ctx.Builder(pos)
+ .Lambda()
+ .Param("row")
+ .Param("parent")
+ .Callable("PgResolvedCallCtx")
+ .Atom(0, transFuncDesc.Name)
+ .Atom(1, ToString(aggDesc.TransFuncId))
+ .List(2)
+ .Seal()
+ .Callable(3, "PgClone")
+ .Add(0, initValue)
+ .Callable(1, "DependsOn")
+ .Arg(0, "parent")
+ .Seal()
+ .Seal()
+ .Apply(4, lambda)
+ .With(0, "row")
+ .Seal()
+ .Seal()
+ .Seal()
+ .Build();
+
+ updateLambda = ctx.Builder(pos)
+ .Lambda()
+ .Param("row")
+ .Param("state")
+ .Param("parent")
+ .Callable("Coalesce")
+ .Callable(0, "PgResolvedCallCtx")
+ .Atom(0, transFuncDesc.Name)
+ .Atom(1, ToString(aggDesc.TransFuncId))
+ .List(2)
+ .Seal()
+ .Callable(3, "Coalesce")
+ .Arg(0, "state")
+ .Callable(1, "PgClone")
+ .Add(0, initValue)
+ .Callable(1, "DependsOn")
+ .Arg(0, "parent")
+ .Seal()
+ .Seal()
+ .Seal()
+ .Apply(4, lambda)
+ .With(0, "row")
+ .Seal()
+ .Seal()
+ .Arg(1, "state")
+ .Seal()
+ .Seal()
+ .Build();
+ } else {
+ initLambda = ctx.Builder(pos)
+ .Lambda()
+ .Param("row")
+ .Apply(lambda)
+ .With(0, "row")
+ .Seal()
+ .Seal()
+ .Build();
+
+ if (lambda->GetTypeAnn()->GetKind() == ETypeAnnotationKind::Null) {
+ initLambda = ctx.Builder(pos)
+ .Lambda()
+ .Param("row")
+ .Callable("PgCast")
+ .Apply(0, initLambda)
+ .With(0, "row")
+ .Seal()
+ .Callable(1, "PgType")
+ .Atom(0, NPg::LookupType(aggDesc.TransTypeId).Name)
+ .Seal()
+ .Seal()
+ .Seal()
+ .Build();
+ }
+
+ updateLambda = ctx.Builder(pos)
+ .Lambda()
+ .Param("row")
+ .Param("state")
+ .Callable("If")
+ .Callable(0, "Exists")
+ .Arg(0, "state")
+ .Seal()
+ .Callable(1, "Coalesce")
+ .Callable(0, "PgResolvedCallCtx")
+ .Atom(0, transFuncDesc.Name)
+ .Atom(1, ToString(aggDesc.TransFuncId))
+ .List(2)
+ .Seal()
+ .Arg(3, "state")
+ .Apply(4, lambda)
+ .With(0, "row")
+ .Seal()
+ .Seal()
+ .Arg(1, "state")
+ .Seal()
+ .Apply(2, lambda)
+ .With(0, "row")
+ .Seal()
+ .Seal()
+ .Seal()
+ .Build();
+ }
+
+ auto mergeLambda = ctx.Builder(pos)
+ .Lambda()
+ .Param("state1")
+ .Param("state2")
+ .Callable("Void")
+ .Seal()
+ .Seal()
+ .Build();
+
+ auto zero = ctx.Builder(pos)
+ .Callable("PgConst")
+ .Atom(0, "0")
+ .Callable(1, "PgType")
+ .Atom(0, "int8")
+ .Seal()
+ .Seal()
+ .Build();
+
+ auto defaultValue = (aggDesc.Name != "count") ? nullValue : zero;
+
+ if (aggDesc.SerializeFuncId) {
+ const auto& serializeFuncDesc = NPg::LookupProc(aggDesc.SerializeFuncId);
+ saveLambda = ctx.Builder(pos)
+ .Lambda()
+ .Param("state")
+ .Callable("PgResolvedCallCtx")
+ .Atom(0, serializeFuncDesc.Name)
+ .Atom(1, ToString(aggDesc.SerializeFuncId))
+ .List(2)
+ .Seal()
+ .Arg(3, "state")
+ .Seal()
+ .Seal()
+ .Build();
+ }
+
+ if (aggDesc.DeserializeFuncId) {
+ const auto& deserializeFuncDesc = NPg::LookupProc(aggDesc.DeserializeFuncId);
+ loadLambda = ctx.Builder(pos)
+ .Lambda()
+ .Param("state")
+ .Callable("PgResolvedCallCtx")
+ .Atom(0, deserializeFuncDesc.Name)
+ .Atom(1, ToString(aggDesc.DeserializeFuncId))
+ .List(2)
+ .Seal()
+ .Arg(3, "state")
+ .Callable(4, "PgInternal0")
+ .Seal()
+ .Seal()
+ .Seal()
+ .Build();
+ }
+
+ if (aggDesc.CombineFuncId) {
+ const auto& combineFuncDesc = NPg::LookupProc(aggDesc.CombineFuncId);
+ if (combineFuncDesc.IsStrict) {
+ mergeLambda = ctx.Builder(pos)
+ .Lambda()
+ .Param("state1")
+ .Param("state2")
+ .Callable("If")
+ .Callable(0, "Exists")
+ .Arg(0, "state1")
+ .Seal()
+ .Callable(1, "Coalesce")
+ .Callable(0, "PgResolvedCallCtx")
+ .Atom(0, combineFuncDesc.Name)
+ .Atom(1, ToString(aggDesc.CombineFuncId))
+ .List(2)
+ .Seal()
+ .Arg(3, "state1")
+ .Arg(4, "state2")
+ .Seal()
+ .Arg(1, "state1")
+ .Seal()
+ .Arg(2, "state2")
+ .Seal()
+ .Seal()
+ .Build();
+ } else {
+ mergeLambda = ctx.Builder(pos)
+ .Lambda()
+ .Param("state1")
+ .Param("state2")
+ .Callable("PgResolvedCallCtx")
+ .Atom(0, combineFuncDesc.Name)
+ .Atom(1, ToString(aggDesc.CombineFuncId))
+ .List(2)
+ .Seal()
+ .Arg(3, "state1")
+ .Arg(4, "state2")
+ .Seal()
+ .Seal()
+ .Build();
+ }
+ }
+
+ auto typeNode = ExpandType(pos, *itemType, ctx);
+ if (onWindow) {
+ return ctx.Builder(pos)
+ .Callable("WindowTraits")
+ .Add(0, typeNode)
+ .Add(1, initLambda)
+ .Add(2, updateLambda)
+ .Lambda(3)
+ .Param("value")
+ .Param("state")
+ .Callable("Void")
+ .Seal()
+ .Seal()
+ .Add(4, finishLambda)
+ .Add(5, defaultValue)
+ .Seal()
+ .Build();
+ } else {
+ return ctx.Builder(pos)
+ .Callable("AggregationTraits")
+ .Add(0, typeNode)
+ .Add(1, initLambda)
+ .Add(2, updateLambda)
+ .Add(3, saveLambda)
+ .Add(4, loadLambda)
+ .Add(5, mergeLambda)
+ .Add(6, finishLambda)
+ .Add(7, defaultValue)
+ .Seal()
+ .Build();
+ }
+}
+
+
} // NYql
diff --git a/ydb/library/yql/core/yql_expr_type_annotation.h b/ydb/library/yql/core/yql_expr_type_annotation.h
index a7f45045229..6daf85dcca4 100644
--- a/ydb/library/yql/core/yql_expr_type_annotation.h
+++ b/ydb/library/yql/core/yql_expr_type_annotation.h
@@ -6,6 +6,7 @@
#include <ydb/library/yql/ast/yql_expr.h>
#include <ydb/library/yql/core/expr_nodes/yql_expr_nodes.h>
#include <ydb/library/yql/minikql/mkql_type_ops.h>
+#include <ydb/library/yql/parser/pg_catalog/catalog.h>
#include <library/cpp/enumbitset/enumbitset.h>
@@ -317,4 +318,10 @@ bool GetAvgResultType(const TPositionHandle& pos, const TTypeAnnotationNode& inp
bool GetAvgResultTypeOverState(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx);
bool GetMinMaxResultType(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx);
+IGraphTransformer::TStatus ExtractPgTypesFromMultiLambda(TExprNode::TPtr& lambda, TVector<ui32>& argTypes,
+ bool& needRetype, TExprContext& ctx);
+
+TExprNode::TPtr ExpandPgAggregationTraits(TPositionHandle pos, const NPg::TAggregateDesc& aggDesc, bool onWindow,
+ const TExprNode::TPtr& lambda, const TVector<ui32>& argTypes, const TTypeAnnotationNode* itemType, TExprContext& ctx);
+
}
diff --git a/ydb/library/yql/parser/pg_catalog/catalog.cpp b/ydb/library/yql/parser/pg_catalog/catalog.cpp
index 947972a04b7..20a8318f504 100644
--- a/ydb/library/yql/parser/pg_catalog/catalog.cpp
+++ b/ydb/library/yql/parser/pg_catalog/catalog.cpp
@@ -1475,6 +1475,20 @@ bool ValidateAggregateArgs(const TAggregateDesc& d, const TVector<ui32>& argType
return ValidateArgs(d.ArgTypes, argTypeIds);
}
+bool ValidateAggregateArgs(const TAggregateDesc& d, ui32 stateType, ui32 resultType) {
+ auto expectedStateType = LookupProc(d.SerializeFuncId ? d.SerializeFuncId : d.TransFuncId).ResultType;
+ if (stateType != expectedStateType) {
+ return false;
+ }
+
+ auto expectedResultType = LookupProc(d.FinalFuncId ? d.FinalFuncId : d.TransFuncId).ResultType;
+ if (resultType != expectedResultType) {
+ return false;
+ }
+
+ return true;
+}
+
const TAggregateDesc& LookupAggregation(const TString& name, const TVector<ui32>& argTypeIds) {
const auto& catalog = TCatalog::Instance();
auto aggIdPtr = catalog.AggregationsByName.FindPtr(to_lower(name));
@@ -1496,6 +1510,28 @@ const TAggregateDesc& LookupAggregation(const TString& name, const TVector<ui32>
<< ArgTypesList(argTypeIds);
}
+const TAggregateDesc& LookupAggregation(const TString& name, ui32 stateType, ui32 resultType) {
+ const auto& catalog = TCatalog::Instance();
+ auto aggIdPtr = catalog.AggregationsByName.FindPtr(to_lower(name));
+ if (!aggIdPtr) {
+ throw yexception() << "No such aggregate: " << name;
+ }
+
+ for (const auto& id : *aggIdPtr) {
+ const auto& d = catalog.Aggregations.FindPtr(id);
+ Y_ENSURE(d);
+ if (!ValidateAggregateArgs(*d, stateType, resultType)) {
+ continue;
+ }
+
+ return *d;
+ }
+
+ throw yexception() << "Unable to find an overload for aggregate " << name << " with given state type: " <<
+ NPg::LookupType(stateType).Name << " and result type: " <<
+ NPg::LookupType(resultType).Name;
+}
+
bool HasOpClass(EOpClassMethod method, ui32 typeId) {
const auto& catalog = TCatalog::Instance();
return catalog.OpClasses.contains(std::make_pair(method, typeId));
diff --git a/ydb/library/yql/parser/pg_catalog/catalog.h b/ydb/library/yql/parser/pg_catalog/catalog.h
index cc1bbe46732..e8d4968b29a 100644
--- a/ydb/library/yql/parser/pg_catalog/catalog.h
+++ b/ydb/library/yql/parser/pg_catalog/catalog.h
@@ -206,6 +206,7 @@ const TOperDesc& LookupOper(ui32 operId);
bool HasAggregation(const TString& name);
const TAggregateDesc& LookupAggregation(const TString& name, const TVector<ui32>& argTypeIds);
+const TAggregateDesc& LookupAggregation(const TString& name, ui32 stateType, ui32 resultType);
bool HasOpClass(EOpClassMethod method, ui32 typeId);
const TOpClassDesc* LookupDefaultOpClass(EOpClassMethod method, ui32 typeId);
diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
index fa583b41107..8969d7c3fba 100644
--- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
+++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
@@ -2505,6 +2505,14 @@ TMkqlCommonCallableCompiler::TShared::TShared() {
AddCallable("PgClone", [](const TExprNode& node, TMkqlBuildContext& ctx) {
auto input = MkqlBuildExpr(*node.Child(0), ctx);
+ if (IsNull(node.Head())) {
+ return input;
+ }
+
+ if (NPg::LookupType(node.GetTypeAnn()->Cast<TPgExprType>()->GetId()).PassByValue) {
+ return input;
+ }
+
TVector<TRuntimeNode> dependentNodes;
for (ui32 i = 1; i < node.ChildrenSize(); ++i) {
dependentNodes.push_back(MkqlBuildExpr(*node.Child(i), ctx));
diff --git a/ydb/library/yql/sql/v1/aggregation.cpp b/ydb/library/yql/sql/v1/aggregation.cpp
index 57a8395d0fa..ceb2d3a3ad0 100644
--- a/ydb/library/yql/sql/v1/aggregation.cpp
+++ b/ydb/library/yql/sql/v1/aggregation.cpp
@@ -1370,6 +1370,11 @@ public:
Y_UNUSED(many);
Y_UNUSED(ctx);
Y_UNUSED(allowAggApply);
+ if (ctx.EmitAggApply && allowAggApply && AggMode != EAggregateMode::OverWindow) {
+ return Y("AggApply",
+ Q("pg_" + to_lower(PgFunc)), Y("ListItemType", type), Lambda);
+ }
+
return Y(AggMode == EAggregateMode::OverWindow ? "PgWindowTraits" : "PgAggregationTraits",
Q(PgFunc), Y("ListItemType", type), Lambda);
}