diff options
author | vvvv <[email protected]> | 2022-07-29 13:31:19 +0300 |
---|---|---|
committer | vvvv <[email protected]> | 2022-07-29 13:31:19 +0300 |
commit | c47ea27d22ee334dd949daf9bd7a39078639efdd (patch) | |
tree | 1e1481509741a54e19c3b327606e64c42ca65a10 | |
parent | c42419d12ba88adc5e1dcaac8577df094fa7d92e (diff) |
initial implementation of AggApply for count. Only peephole is working atm.
19 files changed, 202 insertions, 45 deletions
diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp index 03a101d0e00..8c4405e6892 100644 --- a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp +++ b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp @@ -74,7 +74,7 @@ protected: } TMaybeNode<TExprBase> RewriteAggregate(TExprBase node, TExprContext& ctx) { - TExprBase output = DqRewriteAggregate(node, ctx); + TExprBase output = DqRewriteAggregate(node, ctx, TypesCtx); DumpAppliedRule("RewriteAggregate", node.Ptr(), output.Ptr(), ctx); return output; } diff --git a/ydb/core/kqp/provider/yql_kikimr_datasink.cpp b/ydb/core/kqp/provider/yql_kikimr_datasink.cpp index 0d1f6af7694..37f443e1bb8 100644 --- a/ydb/core/kqp/provider/yql_kikimr_datasink.cpp +++ b/ydb/core/kqp/provider/yql_kikimr_datasink.cpp @@ -313,7 +313,7 @@ public: , SessionCtx(sessionCtx) , IntentDeterminationTransformer(CreateKiSinkIntentDeterminationTransformer(sessionCtx)) , TypeAnnotationTransformer(CreateKiSinkTypeAnnotationTransformer(gateway, sessionCtx)) - , LogicalOptProposalTransformer(CreateKiLogicalOptProposalTransformer(sessionCtx)) + , LogicalOptProposalTransformer(CreateKiLogicalOptProposalTransformer(sessionCtx, types)) , PhysicalOptProposalTransformer(CreateKiPhysicalOptProposalTransformer(sessionCtx)) , CallableExecutionTransformer(CreateKiSinkCallableExecutionTransformer(gateway, sessionCtx, queryExecutor)) , PlanInfoTransformer(CreateKiSinkPlanInfoTransformer(queryExecutor)) diff --git a/ydb/core/kqp/provider/yql_kikimr_opt.cpp b/ydb/core/kqp/provider/yql_kikimr_opt.cpp index e04c518f2f7..a3b8a36848a 100644 --- a/ydb/core/kqp/provider/yql_kikimr_opt.cpp +++ b/ydb/core/kqp/provider/yql_kikimr_opt.cpp @@ -100,7 +100,7 @@ TExprNode::TPtr KiEraseOverSelectRow(TExprBase node, TExprContext& ctx) { return node.Ptr(); } -TExprNode::TPtr KiRewriteAggregate(TExprBase node, TExprContext& ctx) { +TExprNode::TPtr KiRewriteAggregate(TExprBase node, TExprContext& ctx, TTypeAnnotationContext& typesCtx) { if (!node.Maybe<TCoAggregate>()) { return node.Ptr(); } @@ -120,7 +120,7 @@ TExprNode::TPtr KiRewriteAggregate(TExprBase node, TExprContext& ctx) { } YQL_CLOG(INFO, ProviderKikimr) << "KiRewriteAggregate"; - return ExpandAggregate(true, node.Ptr(), ctx); + return ExpandAggregate(true, node.Ptr(), ctx, typesCtx); } TExprNode::TPtr KiRedundantSortByPk(TExprBase node, TExprContext& ctx, @@ -655,8 +655,8 @@ TExprNode::TPtr KiApplyExtractMembersToSelectRange(TExprBase node, TExprContext& } // namespace -TAutoPtr<IGraphTransformer> CreateKiLogicalOptProposalTransformer(TIntrusivePtr<TKikimrSessionContext> sessionCtx) { - return CreateFunctorTransformer([sessionCtx](const TExprNode::TPtr& input, TExprNode::TPtr& output, +TAutoPtr<IGraphTransformer> CreateKiLogicalOptProposalTransformer(TIntrusivePtr<TKikimrSessionContext> sessionCtx, TTypeAnnotationContext& types) { + return CreateFunctorTransformer([sessionCtx, &types](const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) { typedef IGraphTransformer::TStatus TStatus; @@ -666,7 +666,7 @@ TAutoPtr<IGraphTransformer> CreateKiLogicalOptProposalTransformer(TIntrusivePtr< GatherParents(*input, parentsMap); optCtx.ParentsMap = &parentsMap; - TStatus status = OptimizeExpr(input, output, [sessionCtx, &optCtx](const TExprNode::TPtr& inputNode, TExprContext& ctx) { + TStatus status = OptimizeExpr(input, output, [sessionCtx, &optCtx, &types](const TExprNode::TPtr& inputNode, TExprContext& ctx) { auto ret = inputNode; TExprBase node(inputNode); @@ -710,7 +710,7 @@ TAutoPtr<IGraphTransformer> CreateKiLogicalOptProposalTransformer(TIntrusivePtr< return ret; } - ret = KiRewriteAggregate(node, ctx); + ret = KiRewriteAggregate(node, ctx, types); if (ret != inputNode) { return ret; } diff --git a/ydb/core/kqp/provider/yql_kikimr_provider_impl.h b/ydb/core/kqp/provider/yql_kikimr_provider_impl.h index 2149d95309f..43e710002e5 100644 --- a/ydb/core/kqp/provider/yql_kikimr_provider_impl.h +++ b/ydb/core/kqp/provider/yql_kikimr_provider_impl.h @@ -181,7 +181,7 @@ TAutoPtr<IGraphTransformer> CreateKiSourceTypeAnnotationTransformer(TIntrusivePt TTypeAnnotationContext& types); TAutoPtr<IGraphTransformer> CreateKiSinkTypeAnnotationTransformer(TIntrusivePtr<IKikimrGateway> gateway, TIntrusivePtr<TKikimrSessionContext> sessionCtx); -TAutoPtr<IGraphTransformer> CreateKiLogicalOptProposalTransformer(TIntrusivePtr<TKikimrSessionContext> sessionCtx); +TAutoPtr<IGraphTransformer> CreateKiLogicalOptProposalTransformer(TIntrusivePtr<TKikimrSessionContext> sessionCtx, TTypeAnnotationContext& types); TAutoPtr<IGraphTransformer> CreateKiPhysicalOptProposalTransformer(TIntrusivePtr<TKikimrSessionContext> sessionCtx); TAutoPtr<IGraphTransformer> CreateKiSourceLoadTableMetadataTransformer(TIntrusivePtr<IKikimrGateway> gateway, TIntrusivePtr<TKikimrSessionContext> sessionCtx); diff --git a/ydb/library/yql/core/common_opt/yql_co_flow1.cpp b/ydb/library/yql/core/common_opt/yql_co_flow1.cpp index d15804a3051..4b61c4ebd85 100644 --- a/ydb/library/yql/core/common_opt/yql_co_flow1.cpp +++ b/ydb/library/yql/core/common_opt/yql_co_flow1.cpp @@ -1269,6 +1269,10 @@ TExprNode::TPtr CountAggregateRewrite(const TCoAggregate& node, TExprContext& ct const bool isDistinct = (aggregatedColumn.Ref().ChildrenSize() == 3); auto traits = aggregatedColumn.Ref().Child(1); + if (!traits->IsCallable("AggregationTraits")) { + return node.Ptr(); + } + auto outputColumn = aggregatedColumn.Ref().HeadPtr(); // validation of traits auto inputItemType = traits->Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(); diff --git a/ydb/library/yql/core/common_opt/yql_co_flow2.cpp b/ydb/library/yql/core/common_opt/yql_co_flow2.cpp index 7a67e5bc029..b4c1b618b9e 100644 --- a/ydb/library/yql/core/common_opt/yql_co_flow2.cpp +++ b/ydb/library/yql/core/common_opt/yql_co_flow2.cpp @@ -50,6 +50,10 @@ TExprNode::TPtr AggregateSubsetFieldsAnalyzer(const TCoAggregate& node, TExprCon } else { auto traits = x.Ref().Child(1); + if (!traits->IsCallable("AggregationTraits")) { + return node.Ptr(); + } + auto structType = traits->Child(0)->GetTypeAnn()->Cast<TTypeExprType>()->GetType()->Cast<TStructExprType>(); for (const auto& item : structType->GetItems()) { usedFields.insert(item->GetName()); diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp index 3b577aa0655..4f67f210c39 100644 --- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp +++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp @@ -31,8 +31,8 @@ using namespace NNodes; using TPeepHoleOptimizerPtr = TExprNode::TPtr (*const)(const TExprNode::TPtr&, TExprContext&); using TPeepHoleOptimizerMap = std::unordered_map<std::string_view, TPeepHoleOptimizerPtr>; -using TNonDeterministicOptimizerPtr = TExprNode::TPtr (*const)(const TExprNode::TPtr&, TExprContext&, TTypeAnnotationContext& types); -using TNonDeterministicOptimizerMap = std::unordered_map<std::string_view, TNonDeterministicOptimizerPtr>; +using TExtPeepHoleOptimizerPtr = TExprNode::TPtr (*const)(const TExprNode::TPtr&, TExprContext&, TTypeAnnotationContext& types); +using TExtPeepHoleOptimizerMap = std::unordered_map<std::string_view, TExtPeepHoleOptimizerPtr>; TExprNode::TPtr MakeNothing(TPositionHandle pos, const TTypeAnnotationNode& type, TExprContext& ctx) { return ctx.NewCallable(pos, "Nothing", {ExpandType(pos, *ctx.MakeType<TOptionalExprType>(&type), ctx)}); @@ -1884,22 +1884,26 @@ TExprNode::TPtr ExpandFilter(const TExprNode::TPtr& input, TExprContext& ctx) { } IGraphTransformer::TStatus PeepHoleCommonStage(const TExprNode::TPtr& input, TExprNode::TPtr& output, - TExprContext& ctx, TTypeAnnotationContext& types, const TPeepHoleOptimizerMap& optimizers) + TExprContext& ctx, TTypeAnnotationContext& types, + const TPeepHoleOptimizerMap& optimizers, const TExtPeepHoleOptimizerMap& extOptimizers) { TOptimizeExprSettings settings(&types); settings.CustomInstantTypeTransformer = types.CustomInstantTypeTransformer.Get(); - return OptimizeExpr(input, output, [&optimizers](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr { + return OptimizeExpr(input, output, [&optimizers, &extOptimizers, &types](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr { if (const auto rule = optimizers.find(node->Content()); optimizers.cend() != rule) return (rule->second)(node, ctx); + if (const auto rule = extOptimizers.find(node->Content()); extOptimizers.cend() != rule) + return (rule->second)(node, ctx, types); + return node; }, ctx, settings); } IGraphTransformer::TStatus PeepHoleFinalStage(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx, TTypeAnnotationContext& types, bool* hasNonDeterministicFunctions, - const TPeepHoleOptimizerMap& optimizers, const TNonDeterministicOptimizerMap& nonDetOptimizers) + const TPeepHoleOptimizerMap& optimizers, const TExtPeepHoleOptimizerMap& nonDetOptimizers) { TOptimizeExprSettings settings(&types); settings.CustomInstantTypeTransformer = types.CustomInstantTypeTransformer.Get(); @@ -5823,7 +5827,6 @@ struct TPeepHoleRules { {"OptionalReduce", &ExpandOptionalReduce}, {"AggrMin", &ExpandAggrMinMax<true>}, {"AggrMax", &ExpandAggrMinMax<false>}, - {"Aggregate", &ExpandAggregatePeephole}, {"And", &OptimizeLogicalDups<true>}, {"Or", &OptimizeLogicalDups<false>}, {"CombineByKey", &ExpandCombineByKey}, @@ -5850,6 +5853,10 @@ struct TPeepHoleRules { {"ToFlow", &DropToFlowDeps}, }; + static constexpr std::initializer_list<TExtPeepHoleOptimizerMap::value_type> CommonStageExtRulesInit = { + {"Aggregate", &ExpandAggregatePeephole}, + }; + static constexpr std::initializer_list<TPeepHoleOptimizerMap::value_type> SimplifyStageRulesInit = { {"Map", &OptimizeMap<false, EnableNewOptimizers>}, {"OrderedMap", &OptimizeMap<true, EnableNewOptimizers>}, @@ -5909,7 +5916,7 @@ struct TPeepHoleRules { {"SqueezeToDict", &OptimizeSqueezeToDict} }; - static constexpr std::initializer_list<TNonDeterministicOptimizerMap::value_type> FinalStageNonDetRulesInit = { + static constexpr std::initializer_list<TExtPeepHoleOptimizerMap::value_type> FinalStageNonDetRulesInit = { {"Random", &Random0Arg<double>}, {"RandomNumber", &Random0Arg<ui64>}, {"RandomUuid", &Random0Arg<TGUID>}, @@ -5921,6 +5928,7 @@ struct TPeepHoleRules { TPeepHoleRules() : CommonStageRules(CommonStageRulesInit) + , CommonStageExtRules(CommonStageExtRulesInit) , FinalStageRules(FinalStageRulesInit) , SimplifyStageRules(SimplifyStageRulesInit) , FinalStageNonDetRules(FinalStageNonDetRulesInit) @@ -5931,9 +5939,10 @@ struct TPeepHoleRules { } const TPeepHoleOptimizerMap CommonStageRules; + const TExtPeepHoleOptimizerMap CommonStageExtRules; const TPeepHoleOptimizerMap FinalStageRules; const TPeepHoleOptimizerMap SimplifyStageRules; - const TNonDeterministicOptimizerMap FinalStageNonDetRules; + const TExtPeepHoleOptimizerMap FinalStageNonDetRules; }; template <bool EnableNewOptimizers> @@ -5955,7 +5964,9 @@ THolder<IGraphTransformer> CreatePeepHoleCommonStageTransformer(TTypeAnnotationC pipeline.AddCommonOptimization(issueCode); pipeline.Add(CreateFunctorTransformer( [&types](const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) { - return PeepHoleCommonStage(input, output, ctx, types, TPeepHoleRules<EnableNewOptimizers>::Instance().CommonStageRules); + return PeepHoleCommonStage(input, output, ctx, types, + TPeepHoleRules<EnableNewOptimizers>::Instance().CommonStageRules, + TPeepHoleRules<EnableNewOptimizers>::Instance().CommonStageExtRules); } ), "PeepHoleCommon", @@ -5997,7 +6008,7 @@ THolder<IGraphTransformer> CreatePeepHoleFinalStageTransformer(TTypeAnnotationCo } const auto& nonDetStageRules = withNonDeterministicRules ? - TPeepHoleRules<EnableNewOptimizers>::Instance().FinalStageNonDetRules : TNonDeterministicOptimizerMap{}; + TPeepHoleRules<EnableNewOptimizers>::Instance().FinalStageNonDetRules : TExtPeepHoleOptimizerMap{}; return PeepHoleFinalStage(input, output, ctx, types, hasNonDeterministicFunctions, stageRules, nonDetStageRules); } diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp index b2dcbdc39fe..7b51ab5fac6 100644 --- a/ydb/library/yql/core/type_ann/type_ann_core.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp @@ -11244,6 +11244,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot> Functions["MultiAggregate"] = &MultiAggregateWrapper; Functions["Aggregate"] = &AggregateWrapper; Functions["SqlAggregateAll"] = &SqlAggregateAllWrapper; + Functions["AggApply"] = &AggApplyWrapper; Functions["WinOnRows"] = &WinOnWrapper; Functions["WinOnGroups"] = &WinOnWrapper; Functions["WinOnRange"] = &WinOnWrapper; diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp index 743a7a872db..e9b27e10c91 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp @@ -4726,7 +4726,9 @@ namespace { } } - if (!child->Child(1)->IsCallable("AggregationTraits")) { + const bool isAggApply = child->Child(1)->IsCallable("AggApply"); + const bool isTraits = child->Child(1)->IsCallable("AggregationTraits"); + if (!isAggApply && !isTraits) { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(child->Child(1)->Pos()), "Expected aggregation traits")); return IGraphTransformer::TStatus::Error; } @@ -4752,7 +4754,7 @@ namespace { } } - auto finishType = child->Child(1)->Child(6)->GetTypeAnn(); + auto finishType = isAggApply ? child->Child(1)->GetTypeAnn() : child->Child(1)->Child(6)->GetTypeAnn(); bool isOptional = finishType->GetKind() == ETypeAnnotationKind::Optional; if (child->Head().IsList()) { if (isOptional) { @@ -4774,17 +4776,33 @@ namespace { item->GetKind() != ETypeAnnotationKind::Optional ? ctx.Expr.MakeType<TOptionalExprType>(item) : item)); } } else { - auto defVal = child->Child(1)->Child(7); - if (defVal->IsCallable("Null") && !isOptional && !isHopping && input->Child(1)->ChildrenSize() == 0) { + const TTypeAnnotationNode* defValType; + bool isDefNull; + if (isTraits) { + auto defVal = child->Child(1)->Child(7); + isDefNull = defVal->IsCallable("Null"); + defValType = defVal->GetTypeAnn(); + } else { + auto name = child->Child(1)->Child(0)->Content(); + if (name == "count" || name == "count_all") { + isDefNull = false; + defValType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64); + } else { + isDefNull = true; + defValType = ctx.Expr.MakeType<TNullExprType>(); + } + } + + if (isDefNull && !isOptional && !isHopping && input->Child(1)->ChildrenSize() == 0) { if (finishType->GetKind() != ETypeAnnotationKind::Null && finishType->GetKind() != ETypeAnnotationKind::Pg) { finishType = ctx.Expr.MakeType<TOptionalExprType>(finishType); } - } else if (!defVal->IsCallable("Null") && defVal->GetTypeAnn()->GetKind() != ETypeAnnotationKind::Optional + } else if (!isDefNull && defValType->GetKind() != ETypeAnnotationKind::Optional && finishType->GetKind() == ETypeAnnotationKind::Optional) { finishType = finishType->Cast<TOptionalExprType>()->GetItemType(); - } else if (!defVal->IsCallable("Null") && finishType->GetKind() == ETypeAnnotationKind::Null && input->Child(1)->ChildrenSize() == 0) { - finishType = defVal->GetTypeAnn(); + } else if (!isDefNull && finishType->GetKind() == ETypeAnnotationKind::Null && input->Child(1)->ChildrenSize() == 0) { + finishType = defValType; } rowColumns.push_back(ctx.Expr.MakeType<TItemExprType>(child->Head().Content(), finishType)); @@ -4840,6 +4858,53 @@ namespace { return IGraphTransformer::TStatus::Ok; } + IGraphTransformer::TStatus AggApplyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { + Y_UNUSED(output); + if (!EnsureArgsCount(*input, 3, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!EnsureAtom(*input->Child(0), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + auto name = input->Child(0)->Content(); + if (auto status = EnsureTypeRewrite(input->ChildRef(1), ctx.Expr); status != IGraphTransformer::TStatus::Ok) { + return status; + } + + auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); + if (itemType->GetKind() != ETypeAnnotationKind::Struct) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Expected struct type, but got: " << *itemType)); + return IGraphTransformer::TStatus::Error; + } + + auto& lambda = input->ChildRef(2); + const auto status = ConvertToLambda(lambda, ctx.Expr, 1); + if (status.Level != IGraphTransformer::TStatus::Ok) { + return status; + } + + if (!UpdateLambdaAllArgumentsTypes(lambda, { itemType }, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!lambda->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + + if (name == "count" || name == "count_all") { + input->SetTypeAnn(ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64)); + } else { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Unsupported agg name: " << name)); + return IGraphTransformer::TStatus::Error; + } + + return IGraphTransformer::TStatus::Ok; + } + IGraphTransformer::TStatus FilterNullMembersWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { if (!EnsureMinMaxArgsCount(*input, 1, 2, ctx.Expr)) { return IGraphTransformer::TStatus::Error; diff --git a/ydb/library/yql/core/type_ann/type_ann_list.h b/ydb/library/yql/core/type_ann/type_ann_list.h index 3fd48517da4..840f6477742 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.h +++ b/ydb/library/yql/core/type_ann/type_ann_list.h @@ -94,6 +94,7 @@ namespace NTypeAnnImpl { IGraphTransformer::TStatus MultiAggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus AggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus SqlAggregateAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); + IGraphTransformer::TStatus AggApplyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus FilterNullMembersWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus SkipNullMembersWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus FilterNullElementsWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); diff --git a/ydb/library/yql/core/yql_opt_aggregate.cpp b/ydb/library/yql/core/yql_opt_aggregate.cpp index e977a47feae..edd7b4d26b2 100644 --- a/ydb/library/yql/core/yql_opt_aggregate.cpp +++ b/ydb/library/yql/core/yql_opt_aggregate.cpp @@ -1,15 +1,60 @@ #include "yql_opt_aggregate.h" #include "yql_opt_utils.h" #include "yql_opt_window.h" +#include "yql_expr_optimize.h" #include "yql_expr_type_annotation.h" namespace NYql { -TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, TExprContext& ctx, bool forceCompact) { +TExprNode::TPtr ExpandAggApply(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx) { + auto name = node->Head().Content(); + auto exportsPtr = typesCtx.Modules->GetModule("/lib/yql/aggregate.yql"); + YQL_ENSURE(exportsPtr); + const auto& exports = exportsPtr->Symbols(); + const auto ex = exports.find(TString(name) + "_traits_factory"); + YQL_ENSURE(exports.cend() != ex); + TNodeOnNodeOwnedMap deepClones; + auto lambda = ctx.DeepCopy(*ex->second, exportsPtr->ExprCtx(), deepClones, true, false); + + auto listTypeNode = ctx.NewCallable(node->Pos(), "ListType", { node->ChildPtr(1) }); + auto extractor = node->ChildPtr(2); + + auto traits = ctx.ReplaceNodes(lambda->TailPtr(), { + {lambda->Head().Child(0), listTypeNode}, + {lambda->Head().Child(1), extractor} + }); + + ctx.Step.Repeat(TExprStep::ExpandApplyForLambdas); + auto status = ExpandApply(traits, traits, ctx); + YQL_ENSURE(status != IGraphTransformer::TStatus::Error); + return traits; +} + +TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, bool forceCompact) { auto list = node->HeadPtr(); auto keyColumns = node->ChildPtr(1); auto aggregatedColumns = node->Child(2); auto settings = node->Child(3); + TExprNode::TListType traits; + bool needRebuild = false; + for (ui32 index = 0; index < aggregatedColumns->ChildrenSize(); ++index) { + auto trait = aggregatedColumns->Child(index)->ChildPtr(1); + if (trait->IsCallable("AggApply")) { + trait = ExpandAggApply(trait, ctx, typesCtx); + needRebuild = true; + } + + traits.push_back(trait); + } + + if (needRebuild) { + TExprNode::TListType newAggregatedColumnsItems = aggregatedColumns->ChildrenList(); + for (ui32 index = 0; index < aggregatedColumns->ChildrenSize(); ++index) { + newAggregatedColumnsItems[index] = ctx.ChangeChild(*(newAggregatedColumnsItems[index]), 1, std::move(traits[index])); + } + + return ctx.ChangeChild(*node, 2, ctx.NewList(node->Pos(), std::move(newAggregatedColumnsItems))); + } YQL_ENSURE(!HasSetting(*settings, "hopping"), "Aggregate with hopping unsupported here."); @@ -26,8 +71,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T TExprNode::TPtr sortOrder = voidNode; bool effectiveCompact = forceCompact || HasSetting(*settings, "compact"); - for (ui32 index = 0; index < aggregatedColumns->ChildrenSize(); ++index) { - auto trait = aggregatedColumns->Child(index)->Child(1); + for (const auto& trait : traits) { auto mergeLambda = trait->Child(5); if (mergeLambda->Tail().IsCallable("Void")) { effectiveCompact = true; @@ -206,7 +250,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T TExprNode::TListType nothingStates; for (ui32 index = 0; index < aggregatedColumns->ChildrenSize(); ++index) { - auto trait = aggregatedColumns->Child(index)->Child(1); + auto trait = traits[index]; auto saveLambda = trait->Child(3); auto saveLambdaType = saveLambda->GetTypeAnn(); @@ -232,7 +276,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { ui32 ndx = 0; for (ui32 i: nondistinctColumns) { - auto trait = aggregatedColumns->Child(i)->Child(1); + auto trait = traits[i]; auto initLambda = trait->Child(1); if (initLambda->Head().ChildrenSize() == 1) { parent.List(ndx++) @@ -280,7 +324,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { ui32 ndx = 0; for (ui32 i: nondistinctColumns) { - auto trait = aggregatedColumns->Child(i)->Child(1); + auto trait = traits[i]; auto updateLambda = trait->Child(2); if (updateLambda->Head().ChildrenSize() == 2) { parent.List(ndx++) @@ -345,7 +389,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T .Add(1, nothingStates[i]) .Seal(); } else { - auto trait = aggregatedColumns->Child(i)->Child(1); + auto trait = traits[i]; auto saveLambda = trait->Child(3); if (!distinctFields.empty()) { parent.List(i) @@ -579,7 +623,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { ui32 ndx = 0; for (ui32 i: indicies) { - auto trait = aggregatedColumns->Child(i)->Child(1); + auto trait = traits[i]; auto initLambda = trait->Child(1); if (initLambda->Head().ChildrenSize() == 1) { parent.List(ndx++) @@ -636,7 +680,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T .Callable(0, "AsStruct") .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { for (ui32 i: indicies) { - auto trait = aggregatedColumns->Child(i)->Child(1); + auto trait = traits[i]; auto saveLambda = trait->Child(3); parent.List(ndx++) .Add(0, initialColumnNames[i]) @@ -934,7 +978,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { for (ui32 i = 0; i < initialColumnNames.size(); ++i) { auto child = aggregatedColumns->Child(i); - auto trait = child->Child(1); + auto trait = traits[i]; if (!compact) { auto loadLambda = trait->Child(4); @@ -1093,7 +1137,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { for (ui32 i = 0; i < initialColumnNames.size(); ++i) { auto child = aggregatedColumns->Child(i); - auto trait = child->Child(1); + auto trait = traits[i]; if (!compact) { auto loadLambda = trait->Child(4); auto mergeLambda = trait->Child(5); @@ -1345,7 +1389,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { for (ui32 i = 0; i < initialColumnNames.size(); ++i) { auto child = aggregatedColumns->Child(i); - auto trait = child->Child(1); + auto trait = traits[i]; auto finishLambda = trait->Child(6); if (!compact && !distinctFields.empty()) { diff --git a/ydb/library/yql/core/yql_opt_aggregate.h b/ydb/library/yql/core/yql_opt_aggregate.h index 9707405544c..f1bb5d51345 100644 --- a/ydb/library/yql/core/yql_opt_aggregate.h +++ b/ydb/library/yql/core/yql_opt_aggregate.h @@ -1,11 +1,13 @@ #pragma once #include <ydb/library/yql/core/expr_nodes/yql_expr_nodes.h> +#include "yql_type_annotation.h" + namespace NYql { -TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, TExprContext& ctx, bool forceCompact = false); -inline TExprNode::TPtr ExpandAggregatePeephole(const TExprNode::TPtr& node, TExprContext& ctx) { - return ExpandAggregate(false, node, ctx, true); +TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, bool forceCompact = false); +inline TExprNode::TPtr ExpandAggregatePeephole(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx) { + return ExpandAggregate(false, node, ctx, typesCtx, true); } } diff --git a/ydb/library/yql/dq/opt/dq_opt_log.cpp b/ydb/library/yql/dq/opt/dq_opt_log.cpp index 10feb101ed0..05bee5130b9 100644 --- a/ydb/library/yql/dq/opt/dq_opt_log.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_log.cpp @@ -13,12 +13,12 @@ using namespace NYql::NNodes; namespace NYql::NDq { -TExprBase DqRewriteAggregate(TExprBase node, TExprContext& ctx) { +TExprBase DqRewriteAggregate(TExprBase node, TExprContext& ctx, TTypeAnnotationContext& typesCtx) { if (!node.Maybe<TCoAggregate>()) { return node; } - auto result = ExpandAggregate(true, node.Ptr(), ctx); + auto result = ExpandAggregate(true, node.Ptr(), ctx, typesCtx); YQL_ENSURE(result); return TExprBase(result); diff --git a/ydb/library/yql/dq/opt/dq_opt_log.h b/ydb/library/yql/dq/opt/dq_opt_log.h index 8e9ed97b063..024ab5551e9 100644 --- a/ydb/library/yql/dq/opt/dq_opt_log.h +++ b/ydb/library/yql/dq/opt/dq_opt_log.h @@ -12,7 +12,7 @@ namespace NYql { namespace NYql::NDq { -NNodes::TExprBase DqRewriteAggregate(NNodes::TExprBase node, TExprContext& ctx); +NNodes::TExprBase DqRewriteAggregate(NNodes::TExprBase node, TExprContext& ctx, TTypeAnnotationContext& typesCtx); NNodes::TExprBase DqRewriteTakeSortToTopSort(NNodes::TExprBase node, TExprContext& ctx, const TParentsMap& parents); diff --git a/ydb/library/yql/providers/dq/opt/logical_optimize.cpp b/ydb/library/yql/providers/dq/opt/logical_optimize.cpp index 91b394273e4..c6024d4ade7 100644 --- a/ydb/library/yql/providers/dq/opt/logical_optimize.cpp +++ b/ydb/library/yql/providers/dq/opt/logical_optimize.cpp @@ -81,7 +81,7 @@ protected: if (hopSetting) { return RewriteAsHoppingWindow(node, ctx, input.Cast()).Cast(); } else { - return DqRewriteAggregate(node, ctx); + return DqRewriteAggregate(node, ctx, *Types); } } return node; diff --git a/ydb/library/yql/sql/v1/aggregation.cpp b/ydb/library/yql/sql/v1/aggregation.cpp index 64a9c657db9..72079f08d10 100644 --- a/ydb/library/yql/sql/v1/aggregation.cpp +++ b/ydb/library/yql/sql/v1/aggregation.cpp @@ -29,6 +29,10 @@ namespace { } } +static const THashSet<TString> AggApplyFuncs = { + "count_traits_factory" +}; + class TAggregationFactory : public IAggregation { public: TAggregationFactory(TPosition pos, const TString& name, const TString& func, EAggregateMode aggMode, @@ -37,6 +41,10 @@ public: BuildBind(Pos, aggMode == EAggregateMode::OverWindow ? "window_module" : "aggregate_module", func) : nullptr), Multi(multi), ValidateArgs(validateArgs), DynamicFactory(!Factory) { + if (!func.empty() && AggApplyFuncs.contains(func)) { + AggApplyName = func.substr(0, func.size() - 15); + } + if (!Factory) { FakeSource = BuildFakeSource(pos); } @@ -44,6 +52,10 @@ public: protected: bool InitAggr(TContext& ctx, bool isFactory, ISource* src, TAstListNode& node, const TVector<TNodePtr>& exprs) override { + if (!ctx.EmitAggApply) { + AggApplyName = ""; + } + if (ValidateArgs || isFactory) { ui32 expectedArgs = ValidateArgs && !Factory ? 2 : (isFactory ? 0 : 1); if (!Factory && ValidateArgs) { @@ -79,6 +91,9 @@ protected: Name = src->MakeLocalName(Name); } + if (Expr && Expr->IsAsterisk() && AggApplyName == "count") { + AggApplyName = "count_all"; + } if (!Init(ctx, src)) { return false; @@ -100,6 +115,10 @@ protected: TNodePtr GetApply(const TNodePtr& type) const override { if (!Multi) { + if (!DynamicFactory && !AggApplyName.empty()) { + return Y("AggApply", Q(AggApplyName), Y("ListItemType", type), BuildLambda(Pos, Y("row"), Y("PersistableRepr", Expr))); + } + return Y("Apply", Factory, (DynamicFactory ? Y("ListItemType", type) : type), BuildLambda(Pos, Y("row"), Y("PersistableRepr", Expr))); } @@ -183,6 +202,7 @@ protected: TNodePtr Expr; bool Multi; bool ValidateArgs; + TString AggApplyName; TVector<TNodePtr> Exprs; private: diff --git a/ydb/library/yql/sql/v1/context.cpp b/ydb/library/yql/sql/v1/context.cpp index 72bd5c5f1dd..71c8d02a461 100644 --- a/ydb/library/yql/sql/v1/context.cpp +++ b/ydb/library/yql/sql/v1/context.cpp @@ -56,6 +56,7 @@ THashMap<TStringBuf, TPragmaField> CTX_PRAGMA_FIELDS = { {"AnsiCurrentRow", &TContext::AnsiCurrentRow}, {"EmitStartsWith", &TContext::EmitStartsWith}, {"EnforceAnsiOrderByLimitInUnionAll", &TContext::EnforceAnsiOrderByLimitInUnionAll}, + {"EmitAggApply", &TContext::EmitAggApply}, }; typedef TMaybe<bool> TContext::*TPragmaMaybeField; diff --git a/ydb/library/yql/sql/v1/context.h b/ydb/library/yql/sql/v1/context.h index bc1ef027507..bfddf56c12d 100644 --- a/ydb/library/yql/sql/v1/context.h +++ b/ydb/library/yql/sql/v1/context.h @@ -272,6 +272,7 @@ namespace NSQLTranslationV1 { NYql::TWarningPolicy WarningPolicy; TString PqReadByRtmrCluster; bool EmitStartsWith = true; + bool EmitAggApply = false; }; class TColumnRefScope { diff --git a/ydb/library/yql/sql/v1/sql.cpp b/ydb/library/yql/sql/v1/sql.cpp index 1945b734f23..134b153359b 100644 --- a/ydb/library/yql/sql/v1/sql.cpp +++ b/ydb/library/yql/sql/v1/sql.cpp @@ -9826,6 +9826,9 @@ TNodePtr TSqlQuery::PragmaStatement(const TRule_pragma_stmt& stmt, bool& success } else if (normalizedPragma == "disableansicurrentrow") { Ctx.AnsiCurrentRow = false; Ctx.IncrementMonCounter("sql_pragma", "DisableAnsiCurrentRow"); + } else if (normalizedPragma == "emitaggapply") { + Ctx.EmitAggApply = true; + Ctx.IncrementMonCounter("sql_pragma", "EmitAggApply"); } else { Error() << "Unknown pragma: " << pragma; Ctx.IncrementMonCounter("sql_errors", "UnknownPragma"); |