diff options
| -rw-r--r-- | ydb/core/kqp/expr_nodes/kqp_expr_nodes.json | 3 | ||||
| -rw-r--r-- | ydb/core/kqp/opt/rbo/kqp_operator.cpp | 9 | ||||
| -rw-r--r-- | ydb/core/kqp/opt/rbo/kqp_operator.h | 10 | ||||
| -rw-r--r-- | ydb/core/kqp/opt/rbo/kqp_plan_conversion_utils.cpp | 7 | ||||
| -rw-r--r-- | ydb/core/kqp/opt/rbo/kqp_rbo_rules.h | 12 | ||||
| -rw-r--r-- | ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp | 5 | ||||
| -rw-r--r-- | ydb/core/kqp/opt/rbo/kqp_rewrite_select.cpp | 98 | ||||
| -rw-r--r-- | ydb/core/kqp/opt/rbo/rules/expand_distinct_aggregation.cpp | 50 | ||||
| -rw-r--r-- | ydb/core/kqp/opt/rbo/rules/ya.make | 1 |
9 files changed, 126 insertions, 69 deletions
diff --git a/ydb/core/kqp/expr_nodes/kqp_expr_nodes.json b/ydb/core/kqp/expr_nodes/kqp_expr_nodes.json index 033f34a7078..5c9c663e8ce 100644 --- a/ydb/core/kqp/expr_nodes/kqp_expr_nodes.json +++ b/ydb/core/kqp/expr_nodes/kqp_expr_nodes.json @@ -1026,7 +1026,8 @@ "Children": [ {"Index": 0, "Name": "OriginalColName", "Type": "TCoAtom"}, {"Index": 1, "Name": "AggregationFunction", "Type": "TCoAtom"}, - {"Index": 2, "Name": "ResultColName", "Type": "TCoAtom"} + {"Index": 2, "Name": "ResultColName", "Type": "TCoAtom"}, + {"Index": 3, "Name": "Distinct", "Type": "TCoAtom", "Optional": true} ] }, { diff --git a/ydb/core/kqp/opt/rbo/kqp_operator.cpp b/ydb/core/kqp/opt/rbo/kqp_operator.cpp index 102d81f943c..6e2261e4fc9 100644 --- a/ydb/core/kqp/opt/rbo/kqp_operator.cpp +++ b/ydb/core/kqp/opt/rbo/kqp_operator.cpp @@ -1038,8 +1038,13 @@ TString TOpAggregate::ToString(TExprContext& ctx) { TStringBuilder strBuilder; strBuilder << "Aggregate ["; for (ui32 i = 0; i < AggregationTraitsList.size(); ++i) { - strBuilder << AggregationTraitsList[i].ResultColName.GetFullName() << ": " << AggregationTraitsList[i].AggFunction << "(" - << AggregationTraitsList[i].OriginalColName.GetFullName() << ")"; + strBuilder << AggregationTraitsList[i].ResultColName.GetFullName() << ": " << AggregationTraitsList[i].AggFunction << "("; + if (AggregationTraitsList[i].Distinct) { + strBuilder << "distinct "; + } + + strBuilder << AggregationTraitsList[i].OriginalColName.GetFullName(); + strBuilder << ")"; if (i + 1 != AggregationTraitsList.size()) { strBuilder << ", "; } diff --git a/ydb/core/kqp/opt/rbo/kqp_operator.h b/ydb/core/kqp/opt/rbo/kqp_operator.h index 6f58626fc56..66d8283e743 100644 --- a/ydb/core/kqp/opt/rbo/kqp_operator.h +++ b/ydb/core/kqp/opt/rbo/kqp_operator.h @@ -431,22 +431,24 @@ public: struct TOpAggregationTraits { TOpAggregationTraits() = default; - TOpAggregationTraits(const TInfoUnit& originalColName, const TString& aggFunction, const TInfoUnit& resultColName) + TOpAggregationTraits(const TInfoUnit& originalColName, const TString& aggFunction, const TInfoUnit& resultColName, bool distinct = false) : OriginalColName(originalColName) , AggFunction(aggFunction) - , ResultColName(resultColName) { + , ResultColName(resultColName) + , Distinct(distinct) { } TInfoUnit OriginalColName; TString AggFunction; TInfoUnit ResultColName; + bool Distinct; }; class TOpAggregate: public IUnaryOperator { public: - TOpAggregate(TIntrusivePtr<IOperator> input, const TVector<TOpAggregationTraits>& aggFunctions, const TVector<TInfoUnit>& keyColumns, + TOpAggregate(TIntrusivePtr<IOperator> input, const TVector<TOpAggregationTraits>& aggTraitsList, const TVector<TInfoUnit>& keyColumns, const EOpPhase aggPhase, bool distinctAll, TPositionHandle pos); - TOpAggregate(TIntrusivePtr<IOperator> input, const TVector<TOpAggregationTraits>& aggFunctions, const TVector<TInfoUnit>& keyColumns, + TOpAggregate(TIntrusivePtr<IOperator> input, const TVector<TOpAggregationTraits>& aggTraitsList, const TVector<TInfoUnit>& keyColumns, const EOpPhase aggPhase, bool distinctAll, const TPhysicalOpProps& props, TPositionHandle pos); virtual TVector<TInfoUnit> GetOutputIUs() override; diff --git a/ydb/core/kqp/opt/rbo/kqp_plan_conversion_utils.cpp b/ydb/core/kqp/opt/rbo/kqp_plan_conversion_utils.cpp index 0ced70a321e..a90f42224cb 100644 --- a/ydb/core/kqp/opt/rbo/kqp_plan_conversion_utils.cpp +++ b/ydb/core/kqp/opt/rbo/kqp_plan_conversion_utils.cpp @@ -361,6 +361,11 @@ bool GetOrdered(const TKqpOpMap& map) { return maybeOrdered && maybeOrdered.Cast().StringValue() == "True"; } +bool GetDistinct(const TKqpOpAggregationTraits& aggTraits) { + auto maybeDistinct = aggTraits.Distinct(); + return maybeDistinct && maybeDistinct.Cast().StringValue() == "distinct"; +} + } // anonymous namespace void RepairPlanOutputIUs(TOpRoot& root, TExprContext& ctx) { @@ -807,7 +812,7 @@ TIntrusivePtr<IOperator> PlanConverter::ConvertTKqpOpAggregate(TExprNode::TPtr n const auto originalColName = TInfoUnit(TString(traits.OriginalColName())); const auto aggFuncName = TString(traits.AggregationFunction()); const auto resultColName = TInfoUnit(TString(traits.ResultColName())); - TOpAggregationTraits opAggTraits(originalColName, aggFuncName, resultColName); + TOpAggregationTraits opAggTraits(originalColName, aggFuncName, resultColName, GetDistinct(traits)); opAggTraitsList.push_back(opAggTraits); } diff --git a/ydb/core/kqp/opt/rbo/kqp_rbo_rules.h b/ydb/core/kqp/opt/rbo/kqp_rbo_rules.h index a8faa5573e5..30e402c91bf 100644 --- a/ydb/core/kqp/opt/rbo/kqp_rbo_rules.h +++ b/ydb/core/kqp/opt/rbo/kqp_rbo_rules.h @@ -91,6 +91,18 @@ class TEliminateLeftJoinRule : public ISimplifiedRule { virtual TIntrusivePtr<IOperator> SimpleMatchAndApply(const TIntrusivePtr<IOperator> &input, TRBOContext &ctx, TPlanProps &props) override; }; +/** + * Expand distinct aggregation. + */ +class TExpandDistinctAggregationRule: public ISimplifiedRule { +public: + TExpandDistinctAggregationRule() + : ISimplifiedRule("Expand distinct aggregation rule", ERuleProperties::RequireParents | ERuleProperties::RequireTypes) { + } + + virtual TIntrusivePtr<IOperator> SimpleMatchAndApply(const TIntrusivePtr<IOperator>& input, TRBOContext& ctx, TPlanProps& props) override; +}; + /*** * Fuse two consequtive filters */ diff --git a/ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp b/ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp index cf16dbd31d0..85ed74595c6 100644 --- a/ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp +++ b/ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp @@ -355,6 +355,11 @@ void TKqpNewRBOTransformer::InitializeRBOOptimizationStages() { }; // Initial stages. + // Expand aggregation. + TVector<std::unique_ptr<IRule>> expandAggregationRules; + expandAggregationRules.emplace_back(std::make_unique<TExpandDistinctAggregationRule>()); + RBO.AddStage(std::make_unique<TRuleBasedStage>("Expand aggregation", std::move(expandAggregationRules))); + // Inline join filters. FIXME: Move after inlining when adding support for more advanced decorelation TVector<std::unique_ptr<IRule>> joinFiltersInlineRules; joinFiltersInlineRules.emplace_back(std::make_unique<TInlineJoinFiltersRule>()); diff --git a/ydb/core/kqp/opt/rbo/kqp_rewrite_select.cpp b/ydb/core/kqp/opt/rbo/kqp_rewrite_select.cpp index 56143eedfed..21c8b171c4f 100644 --- a/ydb/core/kqp/opt/rbo/kqp_rewrite_select.cpp +++ b/ydb/core/kqp/opt/rbo/kqp_rewrite_select.cpp @@ -126,9 +126,9 @@ TExprNode::TPtr BuildJoinKeys(const TVector<TInfoUnit> &joinKeys, const TVector< } TExprNode::TPtr BuildAggregationTraits(const TString& originalColName, const TString& aggFunction, const TString& resultColName, TExprContext& ctx, - TPositionHandle pos) { + TPositionHandle pos, bool distinct = false) { // clang-format off - return Build<TKqpOpAggregationTraits>(ctx, pos) + auto aggTraitsBuilder = Build<TKqpOpAggregationTraits>(ctx, pos) .OriginalColName<TCoAtom>() .Value(originalColName) .Build() @@ -137,9 +137,19 @@ TExprNode::TPtr BuildAggregationTraits(const TString& originalColName, const TSt .Build() .ResultColName<TCoAtom>() .Value(resultColName) - .Build() - .Done().Ptr(); + .Build(); // clang-format on + + if (distinct) { + // clang-format off + aggTraitsBuilder + .Distinct<TCoAtom>() + .Value("distinct") + .Build(); + // clang-format on + } + + return aggTraitsBuilder.Done().Ptr(); } TExprNode::TPtr BuildAggregate(TExprNode::TPtr resultExpr, const TVector<TExprNode::TPtr>& aggTraitsList, const TVector<TInfoUnit> &keys, @@ -672,26 +682,20 @@ void EliminateDuplicateAggregations(TVector<std::tuple<TInfoUnit, TExprNode::TPt } TExprNode::TPtr BuildAggregationPipeline(TExprNode::TPtr resultExpr, TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>>&& expressionsMapPreAgg, - TVector<std::pair<TInfoUnit, TExprNode::TPtr>>&& groupByKeysExpressionsMap, - TAggregationTraits&& distinctAggregationTraitsPreAggregate, TAggregationTraits&& aggTraits, + TVector<std::pair<TInfoUnit, TExprNode::TPtr>>&& groupByKeysExpressionsMap, TAggregationTraits&& aggTraits, TAggregationTraits&& distinctAggregationTraitsPostAggregate, TExprNode::TPtr& havingFilterLambda, TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>>&& expressionsMapPostAgg, TExprContext& ctx, TPositionHandle pos) { // While processing aggregations and having we could have the same aggregations functions on the same column, here we want to eliminate them. // TODO: Make a special rule in optimizer for that and support more cases, currently we support only simple one aka: // select f(a) ... having f(a) > val ...; - if (distinctAggregationTraitsPreAggregate.AggTraitsList.empty() && distinctAggregationTraitsPostAggregate.AggTraitsList.empty()) { + if (distinctAggregationTraitsPostAggregate.AggTraitsList.empty()) { EliminateDuplicateAggregations(expressionsMapPreAgg, aggTraits, expressionsMapPostAgg, havingFilterLambda, ctx, pos); } // In case we have an expression for aggregation - f(a + b ...) or group by. if (!expressionsMapPreAgg.empty() || !groupByKeysExpressionsMap.empty()) { resultExpr = BuildAggregateExpressionMap(resultExpr, expressionsMapPreAgg, groupByKeysExpressionsMap, ctx, pos); } - // Build distinct aggregate pre aggregate. - if (!distinctAggregationTraitsPreAggregate.AggTraitsList.empty()) { - resultExpr = BuildAggregate(resultExpr, distinctAggregationTraitsPreAggregate.AggTraitsList, distinctAggregationTraitsPreAggregate.KeyColumns, - /*distinct=*/true, ctx, pos); - } // Build Aggreegate. if (!aggTraits.AggTraitsList.empty()) { resultExpr = BuildAggregate(resultExpr, aggTraits.AggTraitsList, aggTraits.KeyColumns, /*distinct=*/false, ctx, pos); @@ -720,10 +724,10 @@ TExprNode::TPtr BuildAggregationPipeline(TExprNode::TPtr resultExpr, TVector<std void ProcessAggregations(TExprNode::TPtr lambdaToProcess, TString&& resultColName, THashSet<TString>& aggregationUniqueColNames, TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>>& expressionsMapPreAgg, - TVector<std::pair<TInfoUnit, TExprNode::TPtr>>& groupByKeysExpressionsMap, TAggregationTraits& distinctAggregationTraitsPreAggregate, - TAggregationTraits& aggTraits, TAggregationTraits& distinctAggregationTraitsPostAggregate, - TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>>& expressionsMapPostAgg, ui64& uniqueAggColumnId, bool& distinctPreAggregate, - const bool distinctAll, TExprContext& ctx, TPositionHandle pos) { + TVector<std::pair<TInfoUnit, TExprNode::TPtr>>& groupByKeysExpressionsMap, TAggregationTraits& aggTraits, + TAggregationTraits& distinctAggregationTraitsPostAggregate, + TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>>& expressionsMapPostAgg, ui64& uniqueAggColumnId, const bool distinctAll, + TExprContext& ctx, TPositionHandle pos) { // Here we want to process given lambda to find all aggregations and expressions. auto lambda = TCoLambda(ctx.DeepCopyLambda(*lambdaToProcess)); THashMap<TExprNode::TPtr, TString> aggregationsForReplacement; @@ -799,19 +803,12 @@ void ProcessAggregations(TExprNode::TPtr lambdaToProcess, TString&& resultColNam aggregationUniqueColNames.insert(aggColName.GetFullName()); // Distinct for column or expression f(distinct a) => (distinct a) as b -> f(b). - if (!!GetSetting(*aggregation->Child(1), "distinct")) { - const auto colName = aggColName.GetFullName(); - auto distinctAggTraits = BuildAggregationTraits(colName, "distinct", colName, ctx, pos); - distinctAggregationTraitsPreAggregate.AggTraitsList.push_back(distinctAggTraits); - distinctAggregationTraitsPreAggregate.KeyColumns.push_back(aggColName); - distinctPreAggregate = true; - } - + const bool distinct = !!GetSetting(*aggregation->Child(1), "distinct"); // Rename for aggregation result. const auto aggResultColName = TInfoUnit(GenerateUniqueColumnName(uniqueAggColumnId, "agg_result", "agg_col")); // Build an aggregation traits. - auto aggregationTraits = BuildAggregationTraits(aggColName.GetFullName(), aggFuncName, aggResultColName.GetFullName(), ctx, pos); - aggTraits.AggTraitsList.push_back(aggregationTraits); + const auto aggregationTraits = BuildAggregationTraits(aggColName.GetFullName(), aggFuncName, aggResultColName.GetFullName(), ctx, pos, distinct); + aggTraits.AggTraitsList.emplace_back(aggregationTraits); aggregationsForReplacement[aggregation] = aggResultColName.GetFullName(); } @@ -873,12 +870,10 @@ void ProcessAggregations(TExprNode::TPtr lambdaToProcess, TString&& resultColNam void ProcessAggregationsInHaving(TExprNode::TPtr having, THashSet<TString>& aggregationUniqueColNames, TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>>& expressionsMapPreAgg, - TVector<std::pair<TInfoUnit, TExprNode::TPtr>>& groupByKeysExpressionsMap, - TAggregationTraits& distinctAggregationTraitsPreAggregate, TAggregationTraits& aggTraits, + TVector<std::pair<TInfoUnit, TExprNode::TPtr>>& groupByKeysExpressionsMap, TAggregationTraits& aggTraits, TAggregationTraits& distinctAggregationTraitsPostAggregate, TExprNode::TPtr& havingFilterLambda, ui64& uniqueAggColumnId, const bool distinctAll, TExprContext& ctx, TPositionHandle pos) { Y_ENSURE(!distinctAll, "Distinct all is not supported for HAVING."); - bool distinctPreAggregate = false; // For each result item, we want to process result lambda to extract aggregations and pre/post expressions. auto yqlWhere = having->ChildPtr(1); Y_ENSURE(yqlWhere->IsCallable("YqlWhere")); @@ -886,11 +881,9 @@ void ProcessAggregationsInHaving(TExprNode::TPtr having, THashSet<TString>& aggr TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>> havingFilterHolder; TString resultColName = GenerateUniqueColumnName(uniqueAggColumnId, "having", "col"); - ProcessAggregations(yqlWhere->ChildPtr(1), std::move(resultColName), aggregationUniqueColNames, expressionsMapPreAgg, groupByKeysExpressionsMap, - distinctAggregationTraitsPreAggregate, aggTraits, distinctAggregationTraitsPostAggregate, havingFilterHolder, uniqueAggColumnId, - distinctPreAggregate, distinctAll, ctx, pos); + ProcessAggregations(yqlWhere->ChildPtr(1), std::move(resultColName), aggregationUniqueColNames, expressionsMapPreAgg, groupByKeysExpressionsMap, aggTraits, + distinctAggregationTraitsPostAggregate, havingFilterHolder, uniqueAggColumnId, distinctAll, ctx, pos); - Y_ENSURE(!distinctPreAggregate, "Distinct is not supported for HAVING."); Y_ENSURE(havingFilterHolder.size() == 1, "Invalid number of filters for HAVING."); havingFilterLambda = std::get<1>(havingFilterHolder.front()); Y_ENSURE(havingFilterLambda, "Fitler for HAVING is nullptr"); @@ -898,36 +891,23 @@ void ProcessAggregationsInHaving(TExprNode::TPtr having, THashSet<TString>& aggr void ProcessAggregationsInResultItems(TExprNode::TPtr result, THashSet<TString>& aggregationUniqueColNames, TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>>& expressionsMapPreAgg, - TVector<std::pair<TInfoUnit, TExprNode::TPtr>>& groupByKeysExpressionsMap, - TAggregationTraits& distinctAggregationTraitsPreAggregate, TAggregationTraits& aggTraits, + TVector<std::pair<TInfoUnit, TExprNode::TPtr>>& groupByKeysExpressionsMap, TAggregationTraits& aggTraits, TAggregationTraits& distinctAggregationTraitsPostAggregate, TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>>& expressionsMapPostAgg, ui64& uniqueAggColumnId, const bool distinctAll, TExprContext& ctx, TPositionHandle pos) { - bool distinctPreAggregate = false; // For each result item, we want to process result lambda to extract aggregations and pre/post expressions. for (ui32 i = 0, e = result->Child(1)->ChildrenSize(); i < e; ++i) { auto resultItem = result->Child(1)->ChildPtr(i); ProcessAggregations(resultItem->ChildPtr(2), TString(resultItem->Child(0)->Content()), aggregationUniqueColNames, expressionsMapPreAgg, - groupByKeysExpressionsMap, distinctAggregationTraitsPreAggregate, aggTraits, distinctAggregationTraitsPostAggregate, - expressionsMapPostAgg, uniqueAggColumnId, distinctPreAggregate, distinctAll, ctx, pos); - } - - // Distinct pre aggregate fro group by keys. - if (distinctPreAggregate) { - Y_ENSURE(distinctAggregationTraitsPreAggregate.AggTraitsList.size() == 1 && aggTraits.AggTraitsList.size() == 1, "Multiple distinct is not supported"); - for (const auto& key : aggTraits.KeyColumns) { - const auto colName = key.GetFullName(); - const auto distinctAggTraits = BuildAggregationTraits(colName, "distinct", colName, ctx, pos); - distinctAggregationTraitsPreAggregate.AggTraitsList.push_back(distinctAggTraits); - distinctAggregationTraitsPreAggregate.KeyColumns.push_back(colName); - } + groupByKeysExpressionsMap, aggTraits, distinctAggregationTraitsPostAggregate, expressionsMapPostAgg, uniqueAggColumnId, distinctAll, + ctx, pos); } // Distinct post aggregate for group by keys. if (distinctAll) { // distinct f(a), b group by b => f(a) as f, b group by b -> select f, b group by f, b. THashSet<TString> distinctSet; - for (const auto& key: distinctAggregationTraitsPostAggregate.KeyColumns) { + for (const auto& key : distinctAggregationTraitsPostAggregate.KeyColumns) { distinctSet.insert(key.GetFullName()); } @@ -1422,16 +1402,14 @@ TExprNode::TPtr RewriteSelect(const TExprNode::TPtr& input, TExprContext& ctx, c auto having = GetSetting(setItem->Tail(), "having"); if (having) { - ProcessAggregationsInHaving(having, aggregationUniqueColNames, expressionsMapPreAgg, groupByKeysExpressionsMap, - distinctAggregationTraitsPreAggregate, aggregationTraits, distinctAggregationTraitsPostAggregate, havingFilterLambda, - uniqueAggColumnId, distinctAll, ctx, node->Pos()); + ProcessAggregationsInHaving(having, aggregationUniqueColNames, expressionsMapPreAgg, groupByKeysExpressionsMap, aggregationTraits, + distinctAggregationTraitsPostAggregate, havingFilterLambda, uniqueAggColumnId, distinctAll, ctx, node->Pos()); } auto result = GetSetting(setItem->Tail(), "result"); // Process all aggregations in result item. - ProcessAggregationsInResultItems(result, aggregationUniqueColNames, expressionsMapPreAgg, groupByKeysExpressionsMap, - distinctAggregationTraitsPreAggregate, aggregationTraits, distinctAggregationTraitsPostAggregate, - expressionsMapPostAgg, uniqueAggColumnId, distinctAll, ctx, node->Pos()); + ProcessAggregationsInResultItems(result, aggregationUniqueColNames, expressionsMapPreAgg, groupByKeysExpressionsMap, aggregationTraits, + distinctAggregationTraitsPostAggregate, expressionsMapPostAgg, uniqueAggColumnId, distinctAll, ctx, node->Pos()); if (hasRollup) { Y_ENSURE(groupBySets.size() == 1, "Invalid group sets size for rollup."); @@ -1444,7 +1422,6 @@ TExprNode::TPtr RewriteSelect(const TExprNode::TPtr& input, TExprContext& ctx, c // We have to use keys based on group set. aggregationTraitsForSet.KeyColumns.clear(); TVector<std::pair<TInfoUnit, TExprNode::TPtr>> groupByKeysExpressionsMapForSet; - TAggregationTraits distinctAggregationTraitsPreAggregateForSet = distinctAggregationTraitsPreAggregate; TAggregationTraits distinctAggregationTraitsPostAggregateForSet = distinctAggregationTraitsPostAggregate; TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>> expressionsMapPostAggForSet = expressionsMapPostAgg; TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>> expressionsMapPreAggForSet = expressionsMapPreAgg; @@ -1490,7 +1467,7 @@ TExprNode::TPtr RewriteSelect(const TExprNode::TPtr& input, TExprContext& ctx, c auto aggregationForGroupSetResultExpr = BuildAggregationPipeline( resultExpr, std::move(expressionsMapPreAggForSet), std::move(groupByKeysExpressionsMapForSet), - std::move(distinctAggregationTraitsPreAggregateForSet), std::move(aggregationTraitsForSet), + std::move(aggregationTraitsForSet), std::move(distinctAggregationTraitsPostAggregateForSet), havingFilterLambda, std::move(expressionsMapPostAggForSet), ctx, node->Pos()); if (rollupResultExpr) { @@ -1511,9 +1488,8 @@ TExprNode::TPtr RewriteSelect(const TExprNode::TPtr& input, TExprContext& ctx, c } else { // Build an aggregation pipeline. resultExpr = BuildAggregationPipeline(resultExpr, std::move(expressionsMapPreAgg), std::move(groupByKeysExpressionsMap), - std::move(distinctAggregationTraitsPreAggregate), std::move(aggregationTraits), - std::move(distinctAggregationTraitsPostAggregate), havingFilterLambda, std::move(expressionsMapPostAgg), ctx, - node->Pos()); + std::move(aggregationTraits), std::move(distinctAggregationTraitsPostAggregate), havingFilterLambda, + std::move(expressionsMapPostAgg), ctx, node->Pos()); } finalColumnOrder.clear(); diff --git a/ydb/core/kqp/opt/rbo/rules/expand_distinct_aggregation.cpp b/ydb/core/kqp/opt/rbo/rules/expand_distinct_aggregation.cpp new file mode 100644 index 00000000000..58b59500b2e --- /dev/null +++ b/ydb/core/kqp/opt/rbo/rules/expand_distinct_aggregation.cpp @@ -0,0 +1,50 @@ +#include <ydb/core/kqp/opt/rbo/kqp_rbo_rules.h> + +namespace NKikimr::NKqp { + +namespace { + +bool IsSuitableToExpandDistinctAggregation(const TIntrusivePtr<IOperator>& input) { + if (input->GetKind() != EOperator::Aggregate) { + return false; + } + + const auto& aggTraitsList = CastOperator<TOpAggregate>(input)->GetAggregationTraits(); + return std::any_of(aggTraitsList.begin(), aggTraitsList.end(), [](const TOpAggregationTraits& aggTraits) { return aggTraits.Distinct; }); +} + +} // anonymous namespace + +TIntrusivePtr<IOperator> TExpandDistinctAggregationRule::SimpleMatchAndApply(const TIntrusivePtr<IOperator>& input, TRBOContext& rboCtx, TPlanProps& props) { + Y_UNUSED(props); + Y_UNUSED(rboCtx); + + if (!IsSuitableToExpandDistinctAggregation(input)) { + return input; + } + + const auto aggregate = CastOperator<TOpAggregate>(input); + const auto& aggTraitsList = aggregate->GetAggregationTraits(); + Y_ENSURE(aggTraitsList.size() == 1, "Multiple distinct is not supported."); + const auto& aggTraits = aggTraitsList.front(); + TVector<TInfoUnit> distinctKeys = aggregate->GetKeyColumns(); + + // Split into distinct and original aggregation. + TVector<TOpAggregationTraits> distinctTraitsList; + for (const auto& key: distinctKeys) { + distinctTraitsList.emplace_back(TOpAggregationTraits(key, "distinct", key)); + } + distinctTraitsList.emplace_back(TOpAggregationTraits(aggTraits.OriginalColName, "distinct", aggTraits.OriginalColName)); + distinctKeys.emplace_back(aggTraits.OriginalColName); + + const TIntrusivePtr<IOperator> distinctAggregation = + MakeIntrusive<TOpAggregate>(aggregate->GetInput(), distinctTraitsList, distinctKeys, EOpPhase::Undefined, + /*distinctAll=*/true, input->Pos); + TOpAggregationTraits aggregationTraits = aggTraits; + aggregationTraits.Distinct = false; + const TVector<TOpAggregationTraits> newAggTraitsList{aggregationTraits}; + return MakeIntrusive<TOpAggregate>(distinctAggregation, newAggTraitsList, aggregate->GetKeyColumns(), EOpPhase::Undefined, /*distinctAll=*/false, + input->Pos); +} + +} // namespace NKikimr::NKqp diff --git a/ydb/core/kqp/opt/rbo/rules/ya.make b/ydb/core/kqp/opt/rbo/rules/ya.make index fae225c2ec9..1620f403d10 100644 --- a/ydb/core/kqp/opt/rbo/rules/ya.make +++ b/ydb/core/kqp/opt/rbo/rules/ya.make @@ -7,6 +7,7 @@ SRCS( constant_folding_stage.cpp correlated_filter_pullup.cpp expand_cbo_tree.cpp + expand_distinct_aggregation.cpp extract_join_expressions.cpp eliminate_left_join.cpp fuse_filters.cpp |
