diff options
author | vvvv <vvvv@ydb.tech> | 2022-11-09 20:10:11 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2022-11-09 20:10:11 +0300 |
commit | b681e93691425927769e98be40a5750be2a45740 (patch) | |
tree | 2a320cf0797215841e0581621fbfefb302a62c70 | |
parent | 0211ab89df33d2b2e9e55b797e40219f32998c3f (diff) | |
download | ydb-b681e93691425927769e98be40a5750be2a45740.tar.gz |
aggApply for avg
-rw-r--r-- | ydb/core/kqp/prepare/kqp_type_ann.cpp | 3 | ||||
-rw-r--r-- | ydb/library/yql/core/expr_nodes/yql_expr_nodes.json | 5 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_impl.h | 1 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_list.cpp | 144 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_aggregate_expander.cpp | 12 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_expr_type_annotation.cpp | 109 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_expr_type_annotation.h | 2 | ||||
-rw-r--r-- | ydb/library/yql/dq/opt/dq_opt_phy.cpp | 14 | ||||
-rw-r--r-- | ydb/library/yql/sql/v1/aggregation.cpp | 3 |
9 files changed, 246 insertions, 47 deletions
diff --git a/ydb/core/kqp/prepare/kqp_type_ann.cpp b/ydb/core/kqp/prepare/kqp_type_ann.cpp index 9f12a89f54..d419fbd1a8 100644 --- a/ydb/core/kqp/prepare/kqp_type_ann.cpp +++ b/ydb/core/kqp/prepare/kqp_type_ann.cpp @@ -6,6 +6,7 @@ #include <ydb/library/yql/core/type_ann/type_ann_core.h> #include "ydb/library/yql/core/type_ann/type_ann_impl.h" #include <ydb/library/yql/core/yql_opt_utils.h> +#include <ydb/library/yql/core/yql_expr_type_annotation.h> #include <ydb/library/yql/dq/type_ann/dq_type_ann.h> #include <ydb/library/yql/utils/log/log.h> @@ -895,7 +896,7 @@ TStatus AnnotateOlapAgg(const TExprNode::TPtr& node, TExprContext& ctx) { } else if (opType->Content() == "sum") { auto colType = structType->FindItemType(colName->Content()); const TTypeAnnotationNode* resultType = nullptr; - if(!NTypeAnnImpl::GetSumResultType(node->Pos(), *colType, resultType, ctx)) { + if(!GetSumResultType(node->Pos(), *colType, resultType, ctx)) { ctx.AddError(TIssue(ctx.GetPosition(node->Pos()), TStringBuilder() << "Unsupported type: " << FormatType(colType) << ". Expected Data or Optional of Data or Null.")); return TStatus::Error; diff --git a/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json b/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json index 6cd1d45e50..307844f9f7 100644 --- a/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json +++ b/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json @@ -2228,7 +2228,10 @@ { "Name": "TCoAggApplyState", "Base": "TCoAggApplyBase", - "Match": {"Type": "Callable", "Name": "AggApplyState"} + "Match": {"Type": "Callable", "Name": "AggApplyState"}, + "Children": [ + {"Index": 3, "Name": "OriginalType", "Type": "TExprBase"} + ] }, { "Name": "TCoAggOverState", diff --git a/ydb/library/yql/core/type_ann/type_ann_impl.h b/ydb/library/yql/core/type_ann/type_ann_impl.h index 96896b037a..48e8ffd60a 100644 --- a/ydb/library/yql/core/type_ann/type_ann_impl.h +++ b/ydb/library/yql/core/type_ann/type_ann_impl.h @@ -37,6 +37,5 @@ namespace NTypeAnnImpl { TMaybe<ui32> FindOrReportMissingMember(TStringBuf memberName, TPositionHandle pos, const TStructExprType& structType, TExprContext& ctx); TExprNode::TPtr MakeNothingData(TExprContext& ctx, TPositionHandle pos, TStringBuf data); - bool GetSumResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx); } // namespace NTypeAnnImpl } // namespace NYql diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp index c6708b83ca..486163dec8 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp @@ -4951,6 +4951,8 @@ namespace { .Add(0, root->ChildPtr(0)) .Add(1, root->ChildPtr(1)) .Add(2, input->ChildPtr(0)) + .Callable(3, "Void") + .Seal() .Seal() .Build(); @@ -5104,47 +5106,17 @@ namespace { return IGraphTransformer::TStatus::Ok; } - bool GetSumResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx) { - bool isOptional; - const TDataExprType* lambdaType; - if(IsDataOrOptionalOfData(&itemType, isOptional, lambdaType)) { - auto lambdaTypeSlot = lambdaType->GetSlot(); - const TTypeAnnotationNode *sumResultType = nullptr; - if (IsDataTypeSigned(lambdaTypeSlot)) { - sumResultType = ctx.MakeType<TDataExprType>(EDataSlot::Int64); - } else if (IsDataTypeUnsigned(lambdaTypeSlot)) { - sumResultType = ctx.MakeType<TDataExprType>(EDataSlot::Uint64); - } else if (IsDataTypeDecimal(lambdaTypeSlot)) { - const auto decimalType = lambdaType->Cast<TDataExprParamsType>(); - sumResultType = ctx.MakeType<TDataExprParamsType>(EDataSlot::Decimal, "35", decimalType->GetParamTwo()); - } else if (IsDataTypeFloat(lambdaTypeSlot) || IsDataTypeInterval(lambdaTypeSlot)) { - sumResultType = ctx.MakeType<TDataExprType>(lambdaTypeSlot); - } else { - ctx.AddError(TIssue(ctx.GetPosition(pos), - TStringBuilder() << "Unsupported column type: " << lambdaTypeSlot)); + IGraphTransformer::TStatus AggApplyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { + Y_UNUSED(output); + const bool overState = input->Content().EndsWith("State"); + if (overState) { + if (!EnsureArgsCount(*input, 4, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } - - if (isOptional) { - sumResultType = ctx.MakeType<TOptionalExprType>(sumResultType); - } - - retType = sumResultType; - return true; - } else if (IsNull(itemType)) { - retType = ctx.MakeType<TNullExprType>(); - return true; } else { - ctx.AddError(TIssue(ctx.GetPosition(pos), - TStringBuilder() << "Unsupported type: " << FormatType(&itemType) << ". Expected Data or Optional of Data.")); - return false; - } - } - - IGraphTransformer::TStatus AggApplyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { - Y_UNUSED(output); - if (!EnsureArgsCount(*input, 3, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; + if (!EnsureArgsCount(*input, 3, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } } if (!EnsureAtom(*input->Child(0), ctx.Expr)) { @@ -5156,6 +5128,15 @@ namespace { return status; } + bool hasOriginalType = false; + if (overState && !input->Child(3)->IsCallable("Void")) { + if (auto status = EnsureTypeRewrite(input->ChildRef(3), ctx.Expr); status != IGraphTransformer::TStatus::Ok) { + return status; + } + + hasOriginalType = true; + } + auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); auto& lambda = input->ChildRef(2); const auto status = ConvertToLambda(lambda, ctx.Expr, 1); @@ -5179,7 +5160,78 @@ namespace { return IGraphTransformer::TStatus::Error; } + if (overState) { + if (!IsSameAnnotation(*lambda->GetTypeAnn(), *retType)) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Mismatch sum type, expected: " << *lambda->GetTypeAnn() << ", but got: " << *retType)); + return IGraphTransformer::TStatus::Error; + } + } + input->SetTypeAnn(retType); + } else if (name == "avg") { + const TTypeAnnotationNode* retType; + if (!overState) { + if (!GetAvgResultType(input->Pos(), *lambda->GetTypeAnn(), retType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + } else { + auto itemType = lambda->GetTypeAnn(); + if (IsNull(*itemType)) { + retType = itemType; + } else { + bool isOptional = false; + if (itemType->GetKind() == ETypeAnnotationKind::Optional) { + isOptional = true; + itemType = itemType->Cast<TOptionalExprType>()->GetItemType(); + } + + if (!EnsureTupleTypeSize(lambda->Pos(), itemType, 2, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + auto tupleType = itemType->Cast<TTupleExprType>(); + auto sumType = tupleType->GetItems()[0]; + const TTypeAnnotationNode* sumTypeOut; + if (!GetSumResultType(input->Pos(), *sumType, sumTypeOut, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!IsSameAnnotation(*sumType, *sumTypeOut)) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Mismatch sum type, expected: " << *sumType << ", but got: " << *sumTypeOut)); + return IGraphTransformer::TStatus::Error; + } + + auto countType = tupleType->GetItems()[1]; + if (!EnsureSpecificDataType(lambda->Pos(), *countType, EDataSlot::Uint64, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + retType = sumType; + if (isOptional) { + retType = ctx.Expr.MakeType<TOptionalExprType>(retType); + } + } + } + + if (hasOriginalType) { + auto originalExtractorType = input->Child(3)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); + if (!EnsureStructType(input->Pos(), *originalExtractorType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + auto structType = originalExtractorType->Cast<TStructExprType>(); + if (structType->GetSize() != 1) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Expected struct with one member")); + return IGraphTransformer::TStatus::Error; + } + + input->SetTypeAnn(structType->GetItems()[0]->GetItemType()); + } else { + input->SetTypeAnn(retType); + } } else { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), TStringBuilder() << "Unsupported agg name: " << name)); @@ -5203,7 +5255,7 @@ namespace { ui32 expectedArgs; if (name == "count_all") { expectedArgs = 1; - } else if (name == "count" || name == "sum") { + } else if (name == "count" || name == "sum" || name == "avg") { expectedArgs = 2; } else { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), @@ -5223,7 +5275,7 @@ namespace { if (name == "count_all" || name == "count") { input->SetTypeAnn(ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64)); - } else { + } else if (name == "sum") { auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); const TTypeAnnotationNode* retType; if (!GetSumResultType(input->Pos(), *itemType, retType, ctx.Expr)) { @@ -5231,6 +5283,18 @@ namespace { } input->SetTypeAnn(retType); + } else if (name == "avg") { + auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); + const TTypeAnnotationNode* retType; + if (!GetAvgResultType(input->Pos(), *itemType, retType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + input->SetTypeAnn(retType); + } else { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Unsupported agg name: " << name)); + return IGraphTransformer::TStatus::Error; } return IGraphTransformer::TStatus::Ok; diff --git a/ydb/library/yql/core/yql_aggregate_expander.cpp b/ydb/library/yql/core/yql_aggregate_expander.cpp index cfbcc4d5f9..159ce50af7 100644 --- a/ydb/library/yql/core/yql_aggregate_expander.cpp +++ b/ydb/library/yql/core/yql_aggregate_expander.cpp @@ -97,7 +97,7 @@ TExprNode::TPtr TAggregateExpander::ExpandAggApply(const TExprNode::TPtr& node) TNodeOnNodeOwnedMap deepClones; auto lambda = Ctx.DeepCopy(*ex->second, exportsPtr->ExprCtx(), deepClones, true, false); - auto listTypeNode = Ctx.NewCallable(node->Pos(), "ListType", { node->ChildPtr(1) }); + auto listTypeNode = Ctx.NewCallable(node->Pos(), "ListType", { node->ChildPtr(node->ChildrenSize() == 4 && !node->Child(3)->IsCallable("Void") ? 3 : 1) }); auto extractor = node->ChildPtr(2); auto traits = Ctx.ReplaceNodes(lambda->TailPtr(), { @@ -1969,11 +1969,21 @@ TExprNode::TPtr TAggregateExpander::GeneratePhases() { .Build(); if (isAggApply) { + auto originalExtractorTypeNode = Ctx.Builder(Node->Pos()) + .Callable("StructType") + .List(0) + .Add(0, InitialColumnNames[index]) + .Add(1, ExpandType(Node->Pos(), *originalTrait->GetTypeAnn(), Ctx)) + .Seal() + .Seal() + .Build(); + mergeTraits.push_back(Ctx.Builder(Node->Pos()) .Callable("AggApplyState") .Add(0, originalTrait->ChildPtr(0)) .Add(1, extractorTypeNode) .Add(2, extractor) + .Add(3, originalExtractorTypeNode) .Seal() .Build()); } else { diff --git a/ydb/library/yql/core/yql_expr_type_annotation.cpp b/ydb/library/yql/core/yql_expr_type_annotation.cpp index 4e6a43e9b6..91d3bd8c70 100644 --- a/ydb/library/yql/core/yql_expr_type_annotation.cpp +++ b/ydb/library/yql/core/yql_expr_type_annotation.cpp @@ -5227,13 +5227,120 @@ const TTypeAnnotationNode* GetBlockItemType(const TTypeAnnotationNode& type, boo } const TTypeAnnotationNode* AggApplySerializedStateType(const TExprNode::TPtr& input, TExprContext& ctx) { - Y_UNUSED(ctx); auto name = input->Child(0)->Content(); if (name == "count" || name == "count_all" || name == "sum") { return input->GetTypeAnn(); + } else if (name == "avg") { + auto itemType = input->Content().StartsWith("AggBlock") ? + input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType() : + input->Child(2)->GetTypeAnn(); + if (input->Content().EndsWith("State")) { + return itemType; + } + + if (IsNull(*itemType)) { + return itemType; + } + + bool isOptional; + const TDataExprType* lambdaType; + YQL_ENSURE(IsDataOrOptionalOfData(itemType, isOptional, lambdaType)); + auto lambdaTypeSlot = lambdaType->GetSlot(); + const TTypeAnnotationNode* stateValueType; + if (IsDataTypeDecimal(lambdaTypeSlot)) { + const auto decimalType = lambdaType->Cast<TDataExprParamsType>(); + stateValueType = ctx.MakeType<TDataExprParamsType>(EDataSlot::Decimal, "35", decimalType->GetParamTwo()); + } else if (IsDataTypeInterval(lambdaTypeSlot)) { + stateValueType = ctx.MakeType<TDataExprParamsType>(EDataSlot::Decimal, "35", "0"); + } else { + stateValueType = ctx.MakeType<TDataExprType>(NUdf::EDataSlot::Double); + } + + TVector<const TTypeAnnotationNode*> items = { + stateValueType, + ctx.MakeType<TDataExprType>(NUdf::EDataSlot::Uint64) + }; + + const TTypeAnnotationNode* stateType = ctx.MakeType<TTupleExprType>(std::move(items)); + if (itemType->GetKind() == ETypeAnnotationKind::Optional) { + stateType = ctx.MakeType<TOptionalExprType>(stateType); + } + + return stateType; } else { YQL_ENSURE(false, "Unknown AggApply: " << name); } } +bool GetSumResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx) { + bool isOptional; + const TDataExprType* lambdaType; + if(IsDataOrOptionalOfData(&itemType, isOptional, lambdaType)) { + auto lambdaTypeSlot = lambdaType->GetSlot(); + const TTypeAnnotationNode *sumResultType = nullptr; + if (IsDataTypeSigned(lambdaTypeSlot)) { + sumResultType = ctx.MakeType<TDataExprType>(EDataSlot::Int64); + } else if (IsDataTypeUnsigned(lambdaTypeSlot)) { + sumResultType = ctx.MakeType<TDataExprType>(EDataSlot::Uint64); + } else if (IsDataTypeDecimal(lambdaTypeSlot)) { + const auto decimalType = lambdaType->Cast<TDataExprParamsType>(); + sumResultType = ctx.MakeType<TDataExprParamsType>(EDataSlot::Decimal, "35", decimalType->GetParamTwo()); + } else if (IsDataTypeFloat(lambdaTypeSlot) || IsDataTypeInterval(lambdaTypeSlot)) { + sumResultType = ctx.MakeType<TDataExprType>(lambdaTypeSlot); + } else { + ctx.AddError(TIssue(ctx.GetPosition(pos), + TStringBuilder() << "Unsupported column type: " << lambdaTypeSlot)); + return IGraphTransformer::TStatus::Error; + } + + if (isOptional) { + sumResultType = ctx.MakeType<TOptionalExprType>(sumResultType); + } + + retType = sumResultType; + return true; + } else if (IsNull(itemType)) { + retType = ctx.MakeType<TNullExprType>(); + return true; + } else { + ctx.AddError(TIssue(ctx.GetPosition(pos), + TStringBuilder() << "Unsupported type: " << FormatType(&itemType) << ". Expected Data or Optional of Data.")); + return false; + } +} + +bool GetAvgResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx) { + bool isOptional; + const TDataExprType* lambdaType; + if(IsDataOrOptionalOfData(&itemType, isOptional, lambdaType)) { + auto lambdaTypeSlot = lambdaType->GetSlot(); + const TTypeAnnotationNode *avgResultType = nullptr; + if (IsDataTypeNumeric(lambdaTypeSlot)) { + avgResultType = ctx.MakeType<TDataExprType>(EDataSlot::Double); + } else if (IsDataTypeDecimal(lambdaTypeSlot)) { + avgResultType = &itemType; + } else if (IsDataTypeInterval(lambdaTypeSlot)) { + avgResultType = &itemType; + } else { + ctx.AddError(TIssue(ctx.GetPosition(pos), + TStringBuilder() << "Unsupported column type: " << lambdaTypeSlot)); + return IGraphTransformer::TStatus::Error; + } + + if (isOptional) { + avgResultType = ctx.MakeType<TOptionalExprType>(avgResultType); + } + + retType = avgResultType; + return true; + } else if (IsNull(itemType)) { + retType = ctx.MakeType<TNullExprType>(); + return true; + } else { + ctx.AddError(TIssue(ctx.GetPosition(pos), + TStringBuilder() << "Unsupported type: " << FormatType(&itemType) << ". Expected Data or Optional of Data.")); + return false; + } +} + } // NYql diff --git a/ydb/library/yql/core/yql_expr_type_annotation.h b/ydb/library/yql/core/yql_expr_type_annotation.h index 781d6d13fb..a5c260a873 100644 --- a/ydb/library/yql/core/yql_expr_type_annotation.h +++ b/ydb/library/yql/core/yql_expr_type_annotation.h @@ -298,5 +298,7 @@ bool EnsureBlockOrScalarType(TPositionHandle position, const TTypeAnnotationNode const TTypeAnnotationNode* GetBlockItemType(const TTypeAnnotationNode& type, bool& isScalar); const TTypeAnnotationNode* AggApplySerializedStateType(const TExprNode::TPtr& input, TExprContext& ctx); +bool GetSumResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx); +bool GetAvgResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx); } diff --git a/ydb/library/yql/dq/opt/dq_opt_phy.cpp b/ydb/library/yql/dq/opt/dq_opt_phy.cpp index f68b79353a..884744bd11 100644 --- a/ydb/library/yql/dq/opt/dq_opt_phy.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_phy.cpp @@ -1687,6 +1687,17 @@ TExprBase DqRewriteLengthOfStageOutputLegacy(TExprBase node, TExprContext& ctx, }; auto stateRowType = ctx.MakeType<TStructExprType>(stateItems); + auto stateTypeNode = ExpandType(node.Pos(), *stateRowType, ctx); + + auto originalTypeNode = ctx.Builder(node.Pos()) + .Callable("StructType") + .List(0) + .Add(0, field.Ptr()) + .Callable(1, "VoidType") + .Seal() + .Seal() + .Seal() + .Build(); auto aggregateFinal = Build<TCoAggregateMergeFinalize>(ctx, node.Pos()) .Input(aggregateCombine) @@ -1699,7 +1710,7 @@ TExprBase DqRewriteLengthOfStageOutputLegacy(TExprBase node, TExprContext& ctx, .Name<TCoAtom>() .Value("count_all") .Build() - .InputType(ExpandType(node.Pos(), *stateRowType, ctx)) + .InputType(stateTypeNode) .Extractor<TCoLambda>() .Args({ "row" }) .Body<TCoMember>() @@ -1707,6 +1718,7 @@ TExprBase DqRewriteLengthOfStageOutputLegacy(TExprBase node, TExprContext& ctx, .Name(field) .Build() .Build() + .OriginalType(originalTypeNode) .Build() .Build() .Build() diff --git a/ydb/library/yql/sql/v1/aggregation.cpp b/ydb/library/yql/sql/v1/aggregation.cpp index c141c859c1..0182fc0bd7 100644 --- a/ydb/library/yql/sql/v1/aggregation.cpp +++ b/ydb/library/yql/sql/v1/aggregation.cpp @@ -31,7 +31,8 @@ namespace { static const THashSet<TString> AggApplyFuncs = { "count_traits_factory", - "sum_traits_factory" + "sum_traits_factory", + "avg_traits_factory", }; class TAggregationFactory : public IAggregation { |