diff options
author | vvvv <vvvv@ydb.tech> | 2023-05-19 17:48:47 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2023-05-19 17:48:47 +0300 |
commit | 253904e00ed53deb7ba78466b41f25a78f74b6de (patch) | |
tree | c67c6d444d76ff0751c092dffabfd6feccddd81a | |
parent | 81eb00dfba71fc3110d5af2eed7fa320ace33e5b (diff) | |
download | ydb-253904e00ed53deb7ba78466b41f25a78f74b6de.tar.gz |
Implementation of agg_apply mode for Pg aggregations
-rw-r--r-- | ydb/core/kqp/provider/yql_kikimr_datasource.cpp | 4 | ||||
-rw-r--r-- | ydb/library/yql/core/common_opt/yql_co_flow2.cpp | 11 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_list.cpp | 87 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_pg.cpp | 342 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_aggregate_expander.cpp | 20 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_expr_type_annotation.cpp | 341 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_expr_type_annotation.h | 7 | ||||
-rw-r--r-- | ydb/library/yql/parser/pg_catalog/catalog.cpp | 36 | ||||
-rw-r--r-- | ydb/library/yql/parser/pg_catalog/catalog.h | 1 | ||||
-rw-r--r-- | ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp | 8 | ||||
-rw-r--r-- | ydb/library/yql/sql/v1/aggregation.cpp | 5 |
11 files changed, 521 insertions, 341 deletions
diff --git a/ydb/core/kqp/provider/yql_kikimr_datasource.cpp b/ydb/core/kqp/provider/yql_kikimr_datasource.cpp index 65f45687112..e92fb35197c 100644 --- a/ydb/core/kqp/provider/yql_kikimr_datasource.cpp +++ b/ydb/core/kqp/provider/yql_kikimr_datasource.cpp @@ -535,8 +535,8 @@ public: if (typeInfo.GetTypeId() == NScheme::NTypeIds::Pg) { auto* typeDesc = typeInfo.GetTypeDesc(); auto* pg = ydbType.mutable_pg_type(); - pg->set_type_name(NPg::PgTypeNameFromTypeDesc(typeDesc)); - pg->set_oid(NPg::PgTypeIdFromTypeDesc(typeDesc)); + pg->set_type_name(NKikimr::NPg::PgTypeNameFromTypeDesc(typeDesc)); + pg->set_oid(NKikimr::NPg::PgTypeIdFromTypeDesc(typeDesc)); } else { auto& item = notNull ? ydbType diff --git a/ydb/library/yql/core/common_opt/yql_co_flow2.cpp b/ydb/library/yql/core/common_opt/yql_co_flow2.cpp index a5dc8179691..e30ef474cf9 100644 --- a/ydb/library/yql/core/common_opt/yql_co_flow2.cpp +++ b/ydb/library/yql/core/common_opt/yql_co_flow2.cpp @@ -1746,12 +1746,15 @@ void RegisterCoFlowCallables2(TCallableOptimizerMap& map) { auto structType = type->Cast<TStructExprType>(); TSet<TStringBuf> usedFields; auto extractor = node->Child(2); - TSet<TStringBuf> lambdaSubset; - if (!HaveFieldsSubset(extractor->ChildPtr(1), *extractor->Child(0)->Child(0), lambdaSubset, *optCtx.ParentsMap)) { - return node; + for (ui32 i = 1; i < extractor->ChildrenSize(); ++i) { + TSet<TStringBuf> lambdaSubset; + if (!HaveFieldsSubset(extractor->ChildPtr(i), *extractor->Child(0)->Child(0), lambdaSubset, *optCtx.ParentsMap)) { + return node; + } + + usedFields.insert(lambdaSubset.cbegin(), lambdaSubset.cend()); } - usedFields.insert(lambdaSubset.cbegin(), lambdaSubset.cend()); if (usedFields.size() == structType->GetSize()) { return node; } diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp index 4af726c03f1..0de23081a7d 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp @@ -8,6 +8,8 @@ #include <ydb/library/yql/core/yql_opt_window.h> #include <ydb/library/yql/core/yql_type_helpers.h> +#include <ydb/library/yql/parser/pg_catalog/catalog.h> + #include <util/generic/algorithm.h> #include <util/string/join.h> @@ -21,29 +23,39 @@ namespace { return x->GetTypeAnn() && x->GetTypeAnn()->GetKind() == ETypeAnnotationKind::EmptyList; }; - bool ApplyOriginalType(TExprNode::TPtr input, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx) { - if (!EnsureStructType(input->Pos(), *originalExtractorType, ctx)) { - return false; + const TTypeAnnotationNode* GetOriginalResultType(TPositionHandle pos, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx) { + if (!EnsureStructType(pos, *originalExtractorType, ctx)) { + return nullptr; } auto structType = originalExtractorType->Cast<TStructExprType>(); if (structType->GetSize() != 1) { - ctx.AddError(TIssue(ctx.GetPosition(input->Pos()), + ctx.AddError(TIssue(ctx.GetPosition(pos), TStringBuilder() << "Expected struct with one member")); - return false; + return nullptr; } - input->SetTypeAnn(structType->GetItems()[0]->GetItemType()); + auto type = structType->GetItems()[0]->GetItemType(); if (isMany) { - if (input->GetTypeAnn()->GetKind() != ETypeAnnotationKind::Optional) { - ctx.AddError(TIssue(ctx.GetPosition(input->Pos()), + if (type->GetKind() != ETypeAnnotationKind::Optional) { + ctx.AddError(TIssue(ctx.GetPosition(pos), TStringBuilder() << "Expected optional state")); - return false; + return nullptr; } - input->SetTypeAnn(input->GetTypeAnn()->Cast<TOptionalExprType>()->GetItemType()); + type = type->Cast<TOptionalExprType>()->GetItemType(); } + return type; + } + + bool ApplyOriginalType(TExprNode::TPtr input, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx) { + auto type = GetOriginalResultType(input->Pos(), isMany, originalExtractorType, ctx); + if (!type) { + return false; + } + + input->SetTypeAnn(type); return true; } @@ -5360,6 +5372,61 @@ namespace { input->SetTypeAnn(retType); } else if (name == "some") { input->SetTypeAnn(lambda->GetTypeAnn()); + } else if (name.StartsWith("pg_")) { + auto func = name; + func.SkipPrefix("pg_"); + TVector<ui32> argTypes; + bool needRetype = false; + if (auto status = ExtractPgTypesFromMultiLambda(lambda, argTypes, needRetype, ctx.Expr); + status != IGraphTransformer::TStatus::Ok) { + return status; + } + + if (overState && !hasOriginalType) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Partial aggregation of " << name << " is not supported")); + return IGraphTransformer::TStatus::Error; + } + + const TTypeAnnotationNode* originalResultType = nullptr; + if (hasOriginalType) { + auto originalExtractorType = input->Child(3)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); + originalResultType = GetOriginalResultType(input->Pos(), isMany, originalExtractorType, ctx.Expr); + if (!originalResultType) { + return IGraphTransformer::TStatus::Error; + } + } + + const NPg::TAggregateDesc* aggDescPtr; + try { + if (overState) { + if (argTypes.size() != 1) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + "Expected only one argument")); + return IGraphTransformer::TStatus::Error; + } + + aggDescPtr = &NPg::LookupAggregation(TString(func), argTypes[0], originalResultType->Cast<TPgExprType>()->GetId()); + } else { + aggDescPtr = &NPg::LookupAggregation(TString(func), argTypes); + } + if (aggDescPtr->Kind != NPg::EAggKind::Normal) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + "Only normal aggregation supported")); + return IGraphTransformer::TStatus::Error; + } + } catch (const yexception& e) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), e.what())); + return IGraphTransformer::TStatus::Error; + } + + if (overState) { + input->SetTypeAnn(originalResultType); + } else { + const NPg::TAggregateDesc& aggDesc = *aggDescPtr; + const auto& finalDesc = NPg::LookupProc(aggDesc.FinalFuncId ? aggDesc.FinalFuncId : aggDesc.TransFuncId); + input->SetTypeAnn(ctx.Expr.MakeType<TPgExprType>(finalDesc.ResultType)); + } } else { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), TStringBuilder() << "Unsupported agg name: " << name)); diff --git a/ydb/library/yql/core/type_ann/type_ann_pg.cpp b/ydb/library/yql/core/type_ann/type_ann_pg.cpp index b5d9e715311..77acf647e43 100644 --- a/ydb/library/yql/core/type_ann/type_ann_pg.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_pg.cpp @@ -285,13 +285,14 @@ IGraphTransformer::TStatus ToPgWrapper(const TExprNode::TPtr& input, TExprNode:: } IGraphTransformer::TStatus PgCloneWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { + Y_UNUSED(output); if (!EnsureDependsOnTail(*input, ctx.Expr, 1)) { return IGraphTransformer::TStatus::Error; } if (IsNull(input->Head())) { - output = input->HeadPtr(); - return IGraphTransformer::TStatus::Repeat; + input->SetTypeAnn(input->Head().GetTypeAnn()); + return IGraphTransformer::TStatus::Ok; } auto type = input->Head().GetTypeAnn(); @@ -1080,335 +1081,26 @@ IGraphTransformer::TStatus PgAggregationTraitsWrapper(const TExprNode::TPtr& inp TVector<ui32> argTypes; bool needRetype = false; - for (ui32 i = 1; i < lambda->ChildrenSize(); ++i) { - auto type = lambda->Child(i)->GetTypeAnn(); - ui32 argType; - bool convertToPg; - if (!ExtractPgType(type, argType, convertToPg, lambda->Child(i)->Pos(), ctx.Expr)) { - return IGraphTransformer::TStatus::Error; - } - - if (convertToPg) { - needRetype = true; - } - - argTypes.push_back(argType); + if (auto status = ExtractPgTypesFromMultiLambda(lambda, argTypes, needRetype, ctx.Expr); + status != IGraphTransformer::TStatus::Ok) { + return status; } - if (needRetype) { - auto newLambda = ctx.Expr.DeepCopyLambda(*lambda); - for (ui32 i = 1; i < lambda->ChildrenSize(); ++i) { - auto type = lambda->Child(i)->GetTypeAnn(); - ui32 argType; - bool convertToPg; - if (!ExtractPgType(type, argType, convertToPg, lambda->Child(i)->Pos(), ctx.Expr)) { - return IGraphTransformer::TStatus::Error; - } - - if (convertToPg) { - newLambda->ChildRef(i) = ctx.Expr.NewCallable(newLambda->Child(i)->Pos(), "ToPg", { newLambda->ChildPtr(i) }); - } + const NPg::TAggregateDesc* aggDescPtr; + try { + aggDescPtr = &NPg::LookupAggregation(TString(func), argTypes); + if (aggDescPtr->Kind != NPg::EAggKind::Normal) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + "Only normal aggregation supported")); + return IGraphTransformer::TStatus::Error; } - - lambda = newLambda; - return IGraphTransformer::TStatus::Repeat; - } - - const auto& aggDesc = NPg::LookupAggregation(TString(func), argTypes); - if (aggDesc.Kind != NPg::EAggKind::Normal) { - ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), - "Only normal aggregation supported")); + } catch (const yexception& e) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), e.what())); return IGraphTransformer::TStatus::Error; } - auto idLambda = ctx.Expr.Builder(input->Pos()) - .Lambda() - .Param("state") - .Arg("state") - .Seal() - .Build(); - - auto saveLambda = idLambda; - auto loadLambda = idLambda; - auto finishLambda = idLambda; - if (aggDesc.FinalFuncId) { - finishLambda = ctx.Expr.Builder(input->Pos()) - .Lambda() - .Param("state") - .Callable("PgResolvedCallCtx") - .Atom(0, NPg::LookupProc(aggDesc.FinalFuncId).Name) - .Atom(1, ToString(aggDesc.FinalFuncId)) - .List(2) - .Seal() - .Arg(3, "state") - .Seal() - .Seal() - .Build(); - } - - auto nullValue = ctx.Expr.NewCallable(input->Pos(), "Null", {}); - auto initValue = nullValue; - if (aggDesc.InitValue) { - initValue = ctx.Expr.Builder(input->Pos()) - .Callable("PgCast") - .Callable(0, "PgConst") - .Atom(0, aggDesc.InitValue) - .Callable(1, "PgType") - .Atom(0, "text") - .Seal() - .Seal() - .Callable(1, "PgType") - .Atom(0, NPg::LookupType(aggDesc.TransTypeId).Name) - .Seal() - .Seal() - .Build(); - } - - const auto& transFuncDesc = NPg::LookupProc(aggDesc.TransFuncId); - // use first non-null value as state if transFunc is strict - bool searchNonNullForState = false; - if (transFuncDesc.IsStrict && !aggDesc.InitValue) { - Y_ENSURE(argTypes.size() == 1); - searchNonNullForState = true; - } - - TExprNode::TPtr initLambda, updateLambda; - if (!searchNonNullForState) { - initLambda = ctx.Expr.Builder(input->Pos()) - .Lambda() - .Param("row") - .Param("parent") - .Callable("PgResolvedCallCtx") - .Atom(0, transFuncDesc.Name) - .Atom(1, ToString(aggDesc.TransFuncId)) - .List(2) - .Seal() - .Callable(3, "PgClone") - .Add(0, initValue) - .Callable(1, "DependsOn") - .Arg(0, "parent") - .Seal() - .Seal() - .Apply(4, lambda) - .With(0, "row") - .Seal() - .Seal() - .Seal() - .Build(); - - updateLambda = ctx.Expr.Builder(input->Pos()) - .Lambda() - .Param("row") - .Param("state") - .Param("parent") - .Callable("Coalesce") - .Callable(0, "PgResolvedCallCtx") - .Atom(0, transFuncDesc.Name) - .Atom(1, ToString(aggDesc.TransFuncId)) - .List(2) - .Seal() - .Callable(3, "Coalesce") - .Arg(0, "state") - .Callable(1, "PgClone") - .Add(0, initValue) - .Callable(1, "DependsOn") - .Arg(0, "parent") - .Seal() - .Seal() - .Seal() - .Apply(4, lambda) - .With(0, "row") - .Seal() - .Seal() - .Arg(1, "state") - .Seal() - .Seal() - .Build(); - } else { - initLambda = ctx.Expr.Builder(input->Pos()) - .Lambda() - .Param("row") - .Apply(lambda) - .With(0, "row") - .Seal() - .Seal() - .Build(); - - if (lambdaResult->GetKind() == ETypeAnnotationKind::Null) { - initLambda = ctx.Expr.Builder(input->Pos()) - .Lambda() - .Param("row") - .Callable("PgCast") - .Apply(0, initLambda) - .With(0, "row") - .Seal() - .Callable(1, "PgType") - .Atom(0, NPg::LookupType(aggDesc.TransTypeId).Name) - .Seal() - .Seal() - .Seal() - .Build(); - } - - updateLambda = ctx.Expr.Builder(input->Pos()) - .Lambda() - .Param("row") - .Param("state") - .Callable("If") - .Callable(0, "Exists") - .Arg(0, "state") - .Seal() - .Callable(1, "Coalesce") - .Callable(0, "PgResolvedCallCtx") - .Atom(0, transFuncDesc.Name) - .Atom(1, ToString(aggDesc.TransFuncId)) - .List(2) - .Seal() - .Arg(3, "state") - .Apply(4, lambda) - .With(0, "row") - .Seal() - .Seal() - .Arg(1, "state") - .Seal() - .Apply(2, lambda) - .With(0, "row") - .Seal() - .Seal() - .Seal() - .Build(); - } - - auto mergeLambda = ctx.Expr.Builder(input->Pos()) - .Lambda() - .Param("state1") - .Param("state2") - .Callable("Void") - .Seal() - .Seal() - .Build(); - - auto zero = ctx.Expr.Builder(input->Pos()) - .Callable("PgConst") - .Atom(0, "0") - .Callable(1, "PgType") - .Atom(0, "int8") - .Seal() - .Seal() - .Build(); - - auto defaultValue = (func != "count") ? nullValue : zero; - - if (aggDesc.SerializeFuncId) { - const auto& serializeFuncDesc = NPg::LookupProc(aggDesc.SerializeFuncId); - saveLambda = ctx.Expr.Builder(input->Pos()) - .Lambda() - .Param("state") - .Callable("PgResolvedCallCtx") - .Atom(0, serializeFuncDesc.Name) - .Atom(1, ToString(aggDesc.SerializeFuncId)) - .List(2) - .Seal() - .Arg(3, "state") - .Seal() - .Seal() - .Build(); - } - - if (aggDesc.DeserializeFuncId) { - const auto& deserializeFuncDesc = NPg::LookupProc(aggDesc.DeserializeFuncId); - loadLambda = ctx.Expr.Builder(input->Pos()) - .Lambda() - .Param("state") - .Callable("PgResolvedCallCtx") - .Atom(0, deserializeFuncDesc.Name) - .Atom(1, ToString(aggDesc.DeserializeFuncId)) - .List(2) - .Seal() - .Arg(3, "state") - .Callable(4, "PgInternal0") - .Seal() - .Seal() - .Seal() - .Build(); - } - - if (aggDesc.CombineFuncId) { - const auto& combineFuncDesc = NPg::LookupProc(aggDesc.CombineFuncId); - if (combineFuncDesc.IsStrict) { - mergeLambda = ctx.Expr.Builder(input->Pos()) - .Lambda() - .Param("state1") - .Param("state2") - .Callable("If") - .Callable(0, "Exists") - .Arg(0, "state1") - .Seal() - .Callable(1, "Coalesce") - .Callable(0, "PgResolvedCallCtx") - .Atom(0, combineFuncDesc.Name) - .Atom(1, ToString(aggDesc.CombineFuncId)) - .List(2) - .Seal() - .Arg(3, "state1") - .Arg(4, "state2") - .Seal() - .Arg(1, "state1") - .Seal() - .Arg(2, "state2") - .Seal() - .Seal() - .Build(); - } else { - mergeLambda = ctx.Expr.Builder(input->Pos()) - .Lambda() - .Param("state1") - .Param("state2") - .Callable("PgResolvedCallCtx") - .Atom(0, combineFuncDesc.Name) - .Atom(1, ToString(aggDesc.CombineFuncId)) - .List(2) - .Seal() - .Arg(3, "state1") - .Arg(4, "state2") - .Seal() - .Seal() - .Build(); - } - } - - auto typeNode = ExpandType(input->Pos(), *itemType, ctx.Expr); - if (onWindow) { - output = ctx.Expr.Builder(input->Pos()) - .Callable("WindowTraits") - .Add(0, typeNode) - .Add(1, initLambda) - .Add(2, updateLambda) - .Lambda(3) - .Param("value") - .Param("state") - .Callable("Void") - .Seal() - .Seal() - .Add(4, finishLambda) - .Add(5, defaultValue) - .Seal() - .Build(); - } else { - output = ctx.Expr.Builder(input->Pos()) - .Callable("AggregationTraits") - .Add(0, typeNode) - .Add(1, initLambda) - .Add(2, updateLambda) - .Add(3, saveLambda) - .Add(4, loadLambda) - .Add(5, mergeLambda) - .Add(6, finishLambda) - .Add(7, defaultValue) - .Seal() - .Build(); - } - + const NPg::TAggregateDesc& aggDesc = *aggDescPtr; + output = ExpandPgAggregationTraits(input->Pos(), aggDesc, onWindow, lambda, argTypes, itemType, ctx.Expr); return IGraphTransformer::TStatus::Repeat; } diff --git a/ydb/library/yql/core/yql_aggregate_expander.cpp b/ydb/library/yql/core/yql_aggregate_expander.cpp index 177755f7b15..8ac57ad21d4 100644 --- a/ydb/library/yql/core/yql_aggregate_expander.cpp +++ b/ydb/library/yql/core/yql_aggregate_expander.cpp @@ -96,6 +96,26 @@ TExprNode::TPtr TAggregateExpander::ExpandAggregate() TExprNode::TPtr TAggregateExpander::ExpandAggApply(const TExprNode::TPtr& node) { auto name = node->Head().Content(); + if (name.StartsWith("pg_")) { + auto func = name.SubStr(3); + auto itemType = node->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); + TVector<ui32> argTypes; + bool needRetype = false; + auto status = ExtractPgTypesFromMultiLambda(node->ChildRef(2), argTypes, needRetype, Ctx); + YQL_ENSURE(status == IGraphTransformer::TStatus::Ok); + + const NPg::TAggregateDesc* aggDescPtr; + if (node->Content().EndsWith("State")) { + auto stateType = node->Child(2)->GetTypeAnn()->Cast<TPgExprType>()->GetId(); + auto resultType = node->GetTypeAnn()->Cast<TPgExprType>()->GetId(); + aggDescPtr = &NPg::LookupAggregation(TString(func), stateType, resultType); + } else { + aggDescPtr = &NPg::LookupAggregation(TString(func), argTypes); + } + + return ExpandPgAggregationTraits(node->Pos(), *aggDescPtr, false, node->ChildPtr(2), argTypes, itemType, Ctx); + } + auto exportsPtr = TypesCtx.Modules->GetModule("/lib/yql/aggregate.yql"); YQL_ENSURE(exportsPtr); const auto& exports = exportsPtr->Symbols(); diff --git a/ydb/library/yql/core/yql_expr_type_annotation.cpp b/ydb/library/yql/core/yql_expr_type_annotation.cpp index 621483765f5..12135cbbf70 100644 --- a/ydb/library/yql/core/yql_expr_type_annotation.cpp +++ b/ydb/library/yql/core/yql_expr_type_annotation.cpp @@ -5566,6 +5566,16 @@ const TTypeAnnotationNode* AggApplySerializedStateType(const TExprNode::TPtr& in } return stateType; + } else if (name.StartsWith("pg_")) { + auto func = name.SubStr(3); + TVector<ui32> argTypes; + bool needRetype = false; + auto status = ExtractPgTypesFromMultiLambda(input->ChildRef(2), argTypes, needRetype, ctx); + YQL_ENSURE(status == IGraphTransformer::TStatus::Ok); + + const NPg::TAggregateDesc& aggDesc = NPg::LookupAggregation(TString(func), argTypes); + const auto& procDesc = NPg::LookupProc(aggDesc.SerializeFuncId ? aggDesc.SerializeFuncId : aggDesc.TransFuncId); + return ctx.MakeType<TPgExprType>(procDesc.ResultType); } else { YQL_ENSURE(false, "Unknown AggApply: " << name); } @@ -5696,4 +5706,335 @@ bool GetMinMaxResultType(const TPositionHandle& pos, const TTypeAnnotationNode& return true; } +IGraphTransformer::TStatus ExtractPgTypesFromMultiLambda(TExprNode::TPtr& lambda, TVector<ui32>& argTypes, + bool& needRetype, TExprContext& ctx) { + for (ui32 i = 1; i < lambda->ChildrenSize(); ++i) { + auto type = lambda->Child(i)->GetTypeAnn(); + ui32 argType; + bool convertToPg; + if (!ExtractPgType(type, argType, convertToPg, lambda->Child(i)->Pos(), ctx)) { + return IGraphTransformer::TStatus::Error; + } + + if (convertToPg) { + needRetype = true; + } + + argTypes.push_back(argType); + } + + if (needRetype) { + auto newLambda = ctx.DeepCopyLambda(*lambda); + for (ui32 i = 1; i < lambda->ChildrenSize(); ++i) { + auto type = lambda->Child(i)->GetTypeAnn(); + ui32 argType; + bool convertToPg; + if (!ExtractPgType(type, argType, convertToPg, lambda->Child(i)->Pos(), ctx)) { + return IGraphTransformer::TStatus::Error; + } + + if (convertToPg) { + newLambda->ChildRef(i) = ctx.NewCallable(newLambda->Child(i)->Pos(), "ToPg", { newLambda->ChildPtr(i) }); + } + } + + lambda = newLambda; + return IGraphTransformer::TStatus::Repeat; + } + + return IGraphTransformer::TStatus::Ok; +} + +TExprNode::TPtr ExpandPgAggregationTraits(TPositionHandle pos, const NPg::TAggregateDesc& aggDesc, bool onWindow, + const TExprNode::TPtr& lambda, const TVector<ui32>& argTypes, const TTypeAnnotationNode* itemType, TExprContext& ctx) { + auto idLambda = ctx.Builder(pos) + .Lambda() + .Param("state") + .Arg("state") + .Seal() + .Build(); + + auto saveLambda = idLambda; + auto loadLambda = idLambda; + auto finishLambda = idLambda; + if (aggDesc.FinalFuncId) { + finishLambda = ctx.Builder(pos) + .Lambda() + .Param("state") + .Callable("PgResolvedCallCtx") + .Atom(0, NPg::LookupProc(aggDesc.FinalFuncId).Name) + .Atom(1, ToString(aggDesc.FinalFuncId)) + .List(2) + .Seal() + .Arg(3, "state") + .Seal() + .Seal() + .Build(); + } + + auto nullValue = ctx.NewCallable(pos, "Null", {}); + auto initValue = nullValue; + if (aggDesc.InitValue) { + initValue = ctx.Builder(pos) + .Callable("PgCast") + .Callable(0, "PgConst") + .Atom(0, aggDesc.InitValue) + .Callable(1, "PgType") + .Atom(0, "text") + .Seal() + .Seal() + .Callable(1, "PgType") + .Atom(0, NPg::LookupType(aggDesc.TransTypeId).Name) + .Seal() + .Seal() + .Build(); + } + + const auto& transFuncDesc = NPg::LookupProc(aggDesc.TransFuncId); + // use first non-null value as state if transFunc is strict + bool searchNonNullForState = false; + if (transFuncDesc.IsStrict && !aggDesc.InitValue) { + Y_ENSURE(argTypes.size() == 1); + searchNonNullForState = true; + } + + TExprNode::TPtr initLambda, updateLambda; + if (!searchNonNullForState) { + initLambda = ctx.Builder(pos) + .Lambda() + .Param("row") + .Param("parent") + .Callable("PgResolvedCallCtx") + .Atom(0, transFuncDesc.Name) + .Atom(1, ToString(aggDesc.TransFuncId)) + .List(2) + .Seal() + .Callable(3, "PgClone") + .Add(0, initValue) + .Callable(1, "DependsOn") + .Arg(0, "parent") + .Seal() + .Seal() + .Apply(4, lambda) + .With(0, "row") + .Seal() + .Seal() + .Seal() + .Build(); + + updateLambda = ctx.Builder(pos) + .Lambda() + .Param("row") + .Param("state") + .Param("parent") + .Callable("Coalesce") + .Callable(0, "PgResolvedCallCtx") + .Atom(0, transFuncDesc.Name) + .Atom(1, ToString(aggDesc.TransFuncId)) + .List(2) + .Seal() + .Callable(3, "Coalesce") + .Arg(0, "state") + .Callable(1, "PgClone") + .Add(0, initValue) + .Callable(1, "DependsOn") + .Arg(0, "parent") + .Seal() + .Seal() + .Seal() + .Apply(4, lambda) + .With(0, "row") + .Seal() + .Seal() + .Arg(1, "state") + .Seal() + .Seal() + .Build(); + } else { + initLambda = ctx.Builder(pos) + .Lambda() + .Param("row") + .Apply(lambda) + .With(0, "row") + .Seal() + .Seal() + .Build(); + + if (lambda->GetTypeAnn()->GetKind() == ETypeAnnotationKind::Null) { + initLambda = ctx.Builder(pos) + .Lambda() + .Param("row") + .Callable("PgCast") + .Apply(0, initLambda) + .With(0, "row") + .Seal() + .Callable(1, "PgType") + .Atom(0, NPg::LookupType(aggDesc.TransTypeId).Name) + .Seal() + .Seal() + .Seal() + .Build(); + } + + updateLambda = ctx.Builder(pos) + .Lambda() + .Param("row") + .Param("state") + .Callable("If") + .Callable(0, "Exists") + .Arg(0, "state") + .Seal() + .Callable(1, "Coalesce") + .Callable(0, "PgResolvedCallCtx") + .Atom(0, transFuncDesc.Name) + .Atom(1, ToString(aggDesc.TransFuncId)) + .List(2) + .Seal() + .Arg(3, "state") + .Apply(4, lambda) + .With(0, "row") + .Seal() + .Seal() + .Arg(1, "state") + .Seal() + .Apply(2, lambda) + .With(0, "row") + .Seal() + .Seal() + .Seal() + .Build(); + } + + auto mergeLambda = ctx.Builder(pos) + .Lambda() + .Param("state1") + .Param("state2") + .Callable("Void") + .Seal() + .Seal() + .Build(); + + auto zero = ctx.Builder(pos) + .Callable("PgConst") + .Atom(0, "0") + .Callable(1, "PgType") + .Atom(0, "int8") + .Seal() + .Seal() + .Build(); + + auto defaultValue = (aggDesc.Name != "count") ? nullValue : zero; + + if (aggDesc.SerializeFuncId) { + const auto& serializeFuncDesc = NPg::LookupProc(aggDesc.SerializeFuncId); + saveLambda = ctx.Builder(pos) + .Lambda() + .Param("state") + .Callable("PgResolvedCallCtx") + .Atom(0, serializeFuncDesc.Name) + .Atom(1, ToString(aggDesc.SerializeFuncId)) + .List(2) + .Seal() + .Arg(3, "state") + .Seal() + .Seal() + .Build(); + } + + if (aggDesc.DeserializeFuncId) { + const auto& deserializeFuncDesc = NPg::LookupProc(aggDesc.DeserializeFuncId); + loadLambda = ctx.Builder(pos) + .Lambda() + .Param("state") + .Callable("PgResolvedCallCtx") + .Atom(0, deserializeFuncDesc.Name) + .Atom(1, ToString(aggDesc.DeserializeFuncId)) + .List(2) + .Seal() + .Arg(3, "state") + .Callable(4, "PgInternal0") + .Seal() + .Seal() + .Seal() + .Build(); + } + + if (aggDesc.CombineFuncId) { + const auto& combineFuncDesc = NPg::LookupProc(aggDesc.CombineFuncId); + if (combineFuncDesc.IsStrict) { + mergeLambda = ctx.Builder(pos) + .Lambda() + .Param("state1") + .Param("state2") + .Callable("If") + .Callable(0, "Exists") + .Arg(0, "state1") + .Seal() + .Callable(1, "Coalesce") + .Callable(0, "PgResolvedCallCtx") + .Atom(0, combineFuncDesc.Name) + .Atom(1, ToString(aggDesc.CombineFuncId)) + .List(2) + .Seal() + .Arg(3, "state1") + .Arg(4, "state2") + .Seal() + .Arg(1, "state1") + .Seal() + .Arg(2, "state2") + .Seal() + .Seal() + .Build(); + } else { + mergeLambda = ctx.Builder(pos) + .Lambda() + .Param("state1") + .Param("state2") + .Callable("PgResolvedCallCtx") + .Atom(0, combineFuncDesc.Name) + .Atom(1, ToString(aggDesc.CombineFuncId)) + .List(2) + .Seal() + .Arg(3, "state1") + .Arg(4, "state2") + .Seal() + .Seal() + .Build(); + } + } + + auto typeNode = ExpandType(pos, *itemType, ctx); + if (onWindow) { + return ctx.Builder(pos) + .Callable("WindowTraits") + .Add(0, typeNode) + .Add(1, initLambda) + .Add(2, updateLambda) + .Lambda(3) + .Param("value") + .Param("state") + .Callable("Void") + .Seal() + .Seal() + .Add(4, finishLambda) + .Add(5, defaultValue) + .Seal() + .Build(); + } else { + return ctx.Builder(pos) + .Callable("AggregationTraits") + .Add(0, typeNode) + .Add(1, initLambda) + .Add(2, updateLambda) + .Add(3, saveLambda) + .Add(4, loadLambda) + .Add(5, mergeLambda) + .Add(6, finishLambda) + .Add(7, defaultValue) + .Seal() + .Build(); + } +} + + } // NYql diff --git a/ydb/library/yql/core/yql_expr_type_annotation.h b/ydb/library/yql/core/yql_expr_type_annotation.h index a7f45045229..6daf85dcca4 100644 --- a/ydb/library/yql/core/yql_expr_type_annotation.h +++ b/ydb/library/yql/core/yql_expr_type_annotation.h @@ -6,6 +6,7 @@ #include <ydb/library/yql/ast/yql_expr.h> #include <ydb/library/yql/core/expr_nodes/yql_expr_nodes.h> #include <ydb/library/yql/minikql/mkql_type_ops.h> +#include <ydb/library/yql/parser/pg_catalog/catalog.h> #include <library/cpp/enumbitset/enumbitset.h> @@ -317,4 +318,10 @@ bool GetAvgResultType(const TPositionHandle& pos, const TTypeAnnotationNode& inp bool GetAvgResultTypeOverState(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx); bool GetMinMaxResultType(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx); +IGraphTransformer::TStatus ExtractPgTypesFromMultiLambda(TExprNode::TPtr& lambda, TVector<ui32>& argTypes, + bool& needRetype, TExprContext& ctx); + +TExprNode::TPtr ExpandPgAggregationTraits(TPositionHandle pos, const NPg::TAggregateDesc& aggDesc, bool onWindow, + const TExprNode::TPtr& lambda, const TVector<ui32>& argTypes, const TTypeAnnotationNode* itemType, TExprContext& ctx); + } diff --git a/ydb/library/yql/parser/pg_catalog/catalog.cpp b/ydb/library/yql/parser/pg_catalog/catalog.cpp index 947972a04b7..20a8318f504 100644 --- a/ydb/library/yql/parser/pg_catalog/catalog.cpp +++ b/ydb/library/yql/parser/pg_catalog/catalog.cpp @@ -1475,6 +1475,20 @@ bool ValidateAggregateArgs(const TAggregateDesc& d, const TVector<ui32>& argType return ValidateArgs(d.ArgTypes, argTypeIds); } +bool ValidateAggregateArgs(const TAggregateDesc& d, ui32 stateType, ui32 resultType) { + auto expectedStateType = LookupProc(d.SerializeFuncId ? d.SerializeFuncId : d.TransFuncId).ResultType; + if (stateType != expectedStateType) { + return false; + } + + auto expectedResultType = LookupProc(d.FinalFuncId ? d.FinalFuncId : d.TransFuncId).ResultType; + if (resultType != expectedResultType) { + return false; + } + + return true; +} + const TAggregateDesc& LookupAggregation(const TString& name, const TVector<ui32>& argTypeIds) { const auto& catalog = TCatalog::Instance(); auto aggIdPtr = catalog.AggregationsByName.FindPtr(to_lower(name)); @@ -1496,6 +1510,28 @@ const TAggregateDesc& LookupAggregation(const TString& name, const TVector<ui32> << ArgTypesList(argTypeIds); } +const TAggregateDesc& LookupAggregation(const TString& name, ui32 stateType, ui32 resultType) { + const auto& catalog = TCatalog::Instance(); + auto aggIdPtr = catalog.AggregationsByName.FindPtr(to_lower(name)); + if (!aggIdPtr) { + throw yexception() << "No such aggregate: " << name; + } + + for (const auto& id : *aggIdPtr) { + const auto& d = catalog.Aggregations.FindPtr(id); + Y_ENSURE(d); + if (!ValidateAggregateArgs(*d, stateType, resultType)) { + continue; + } + + return *d; + } + + throw yexception() << "Unable to find an overload for aggregate " << name << " with given state type: " << + NPg::LookupType(stateType).Name << " and result type: " << + NPg::LookupType(resultType).Name; +} + bool HasOpClass(EOpClassMethod method, ui32 typeId) { const auto& catalog = TCatalog::Instance(); return catalog.OpClasses.contains(std::make_pair(method, typeId)); diff --git a/ydb/library/yql/parser/pg_catalog/catalog.h b/ydb/library/yql/parser/pg_catalog/catalog.h index cc1bbe46732..e8d4968b29a 100644 --- a/ydb/library/yql/parser/pg_catalog/catalog.h +++ b/ydb/library/yql/parser/pg_catalog/catalog.h @@ -206,6 +206,7 @@ const TOperDesc& LookupOper(ui32 operId); bool HasAggregation(const TString& name); const TAggregateDesc& LookupAggregation(const TString& name, const TVector<ui32>& argTypeIds); +const TAggregateDesc& LookupAggregation(const TString& name, ui32 stateType, ui32 resultType); bool HasOpClass(EOpClassMethod method, ui32 typeId); const TOpClassDesc* LookupDefaultOpClass(EOpClassMethod method, ui32 typeId); diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp index fa583b41107..8969d7c3fba 100644 --- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp +++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp @@ -2505,6 +2505,14 @@ TMkqlCommonCallableCompiler::TShared::TShared() { AddCallable("PgClone", [](const TExprNode& node, TMkqlBuildContext& ctx) { auto input = MkqlBuildExpr(*node.Child(0), ctx); + if (IsNull(node.Head())) { + return input; + } + + if (NPg::LookupType(node.GetTypeAnn()->Cast<TPgExprType>()->GetId()).PassByValue) { + return input; + } + TVector<TRuntimeNode> dependentNodes; for (ui32 i = 1; i < node.ChildrenSize(); ++i) { dependentNodes.push_back(MkqlBuildExpr(*node.Child(i), ctx)); diff --git a/ydb/library/yql/sql/v1/aggregation.cpp b/ydb/library/yql/sql/v1/aggregation.cpp index 57a8395d0fa..ceb2d3a3ad0 100644 --- a/ydb/library/yql/sql/v1/aggregation.cpp +++ b/ydb/library/yql/sql/v1/aggregation.cpp @@ -1370,6 +1370,11 @@ public: Y_UNUSED(many); Y_UNUSED(ctx); Y_UNUSED(allowAggApply); + if (ctx.EmitAggApply && allowAggApply && AggMode != EAggregateMode::OverWindow) { + return Y("AggApply", + Q("pg_" + to_lower(PgFunc)), Y("ListItemType", type), Lambda); + } + return Y(AggMode == EAggregateMode::OverWindow ? "PgWindowTraits" : "PgAggregationTraits", Q(PgFunc), Y("ListItemType", type), Lambda); } |