summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--ydb/core/kqp/expr_nodes/kqp_expr_nodes.json3
-rw-r--r--ydb/core/kqp/opt/rbo/kqp_operator.cpp9
-rw-r--r--ydb/core/kqp/opt/rbo/kqp_operator.h10
-rw-r--r--ydb/core/kqp/opt/rbo/kqp_plan_conversion_utils.cpp7
-rw-r--r--ydb/core/kqp/opt/rbo/kqp_rbo_rules.h12
-rw-r--r--ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp5
-rw-r--r--ydb/core/kqp/opt/rbo/kqp_rewrite_select.cpp98
-rw-r--r--ydb/core/kqp/opt/rbo/rules/expand_distinct_aggregation.cpp50
-rw-r--r--ydb/core/kqp/opt/rbo/rules/ya.make1
9 files changed, 126 insertions, 69 deletions
diff --git a/ydb/core/kqp/expr_nodes/kqp_expr_nodes.json b/ydb/core/kqp/expr_nodes/kqp_expr_nodes.json
index 033f34a7078..5c9c663e8ce 100644
--- a/ydb/core/kqp/expr_nodes/kqp_expr_nodes.json
+++ b/ydb/core/kqp/expr_nodes/kqp_expr_nodes.json
@@ -1026,7 +1026,8 @@
"Children": [
{"Index": 0, "Name": "OriginalColName", "Type": "TCoAtom"},
{"Index": 1, "Name": "AggregationFunction", "Type": "TCoAtom"},
- {"Index": 2, "Name": "ResultColName", "Type": "TCoAtom"}
+ {"Index": 2, "Name": "ResultColName", "Type": "TCoAtom"},
+ {"Index": 3, "Name": "Distinct", "Type": "TCoAtom", "Optional": true}
]
},
{
diff --git a/ydb/core/kqp/opt/rbo/kqp_operator.cpp b/ydb/core/kqp/opt/rbo/kqp_operator.cpp
index 102d81f943c..6e2261e4fc9 100644
--- a/ydb/core/kqp/opt/rbo/kqp_operator.cpp
+++ b/ydb/core/kqp/opt/rbo/kqp_operator.cpp
@@ -1038,8 +1038,13 @@ TString TOpAggregate::ToString(TExprContext& ctx) {
TStringBuilder strBuilder;
strBuilder << "Aggregate [";
for (ui32 i = 0; i < AggregationTraitsList.size(); ++i) {
- strBuilder << AggregationTraitsList[i].ResultColName.GetFullName() << ": " << AggregationTraitsList[i].AggFunction << "("
- << AggregationTraitsList[i].OriginalColName.GetFullName() << ")";
+ strBuilder << AggregationTraitsList[i].ResultColName.GetFullName() << ": " << AggregationTraitsList[i].AggFunction << "(";
+ if (AggregationTraitsList[i].Distinct) {
+ strBuilder << "distinct ";
+ }
+
+ strBuilder << AggregationTraitsList[i].OriginalColName.GetFullName();
+ strBuilder << ")";
if (i + 1 != AggregationTraitsList.size()) {
strBuilder << ", ";
}
diff --git a/ydb/core/kqp/opt/rbo/kqp_operator.h b/ydb/core/kqp/opt/rbo/kqp_operator.h
index 6f58626fc56..66d8283e743 100644
--- a/ydb/core/kqp/opt/rbo/kqp_operator.h
+++ b/ydb/core/kqp/opt/rbo/kqp_operator.h
@@ -431,22 +431,24 @@ public:
+// Describes one aggregation: AggFunction applied to OriginalColName, with the result stored in ResultColName.
+// Distinct is set for aggregations of the form f(distinct a).
struct TOpAggregationTraits {
TOpAggregationTraits() = default;
- TOpAggregationTraits(const TInfoUnit& originalColName, const TString& aggFunction, const TInfoUnit& resultColName)
+ TOpAggregationTraits(const TInfoUnit& originalColName, const TString& aggFunction, const TInfoUnit& resultColName, bool distinct = false)
: OriginalColName(originalColName)
, AggFunction(aggFunction)
- , ResultColName(resultColName) {
+ , ResultColName(resultColName)
+ , Distinct(distinct) {
}
TInfoUnit OriginalColName;
TString AggFunction;
TInfoUnit ResultColName;
+ // Explicit initializer: the defaulted constructor does not set this flag, and a default-initialized
+ // bool without an initializer would hold an indeterminate value.
+ bool Distinct = false;
};
class TOpAggregate: public IUnaryOperator {
public:
- TOpAggregate(TIntrusivePtr<IOperator> input, const TVector<TOpAggregationTraits>& aggFunctions, const TVector<TInfoUnit>& keyColumns,
+ TOpAggregate(TIntrusivePtr<IOperator> input, const TVector<TOpAggregationTraits>& aggTraitsList, const TVector<TInfoUnit>& keyColumns,
const EOpPhase aggPhase, bool distinctAll, TPositionHandle pos);
- TOpAggregate(TIntrusivePtr<IOperator> input, const TVector<TOpAggregationTraits>& aggFunctions, const TVector<TInfoUnit>& keyColumns,
+ TOpAggregate(TIntrusivePtr<IOperator> input, const TVector<TOpAggregationTraits>& aggTraitsList, const TVector<TInfoUnit>& keyColumns,
const EOpPhase aggPhase, bool distinctAll, const TPhysicalOpProps& props, TPositionHandle pos);
virtual TVector<TInfoUnit> GetOutputIUs() override;
diff --git a/ydb/core/kqp/opt/rbo/kqp_plan_conversion_utils.cpp b/ydb/core/kqp/opt/rbo/kqp_plan_conversion_utils.cpp
index 0ced70a321e..a90f42224cb 100644
--- a/ydb/core/kqp/opt/rbo/kqp_plan_conversion_utils.cpp
+++ b/ydb/core/kqp/opt/rbo/kqp_plan_conversion_utils.cpp
@@ -361,6 +361,11 @@ bool GetOrdered(const TKqpOpMap& map) {
return maybeOrdered && maybeOrdered.Cast().StringValue() == "True";
}
+// Returns true when the optional Distinct atom is present on the traits node and holds the value "distinct".
+bool GetDistinct(const TKqpOpAggregationTraits& aggTraits) {
+    auto distinctAtom = aggTraits.Distinct();
+    if (!distinctAtom) {
+        // The atom is optional; its absence means a regular (non-distinct) aggregation.
+        return false;
+    }
+
+    return distinctAtom.Cast().StringValue() == "distinct";
+}
+
} // anonymous namespace
void RepairPlanOutputIUs(TOpRoot& root, TExprContext& ctx) {
@@ -807,7 +812,7 @@ TIntrusivePtr<IOperator> PlanConverter::ConvertTKqpOpAggregate(TExprNode::TPtr n
const auto originalColName = TInfoUnit(TString(traits.OriginalColName()));
const auto aggFuncName = TString(traits.AggregationFunction());
const auto resultColName = TInfoUnit(TString(traits.ResultColName()));
- TOpAggregationTraits opAggTraits(originalColName, aggFuncName, resultColName);
+ TOpAggregationTraits opAggTraits(originalColName, aggFuncName, resultColName, GetDistinct(traits));
opAggTraitsList.push_back(opAggTraits);
}
diff --git a/ydb/core/kqp/opt/rbo/kqp_rbo_rules.h b/ydb/core/kqp/opt/rbo/kqp_rbo_rules.h
index a8faa5573e5..30e402c91bf 100644
--- a/ydb/core/kqp/opt/rbo/kqp_rbo_rules.h
+++ b/ydb/core/kqp/opt/rbo/kqp_rbo_rules.h
@@ -91,6 +91,18 @@ class TEliminateLeftJoinRule : public ISimplifiedRule {
virtual TIntrusivePtr<IOperator> SimpleMatchAndApply(const TIntrusivePtr<IOperator> &input, TRBOContext &ctx, TPlanProps &props) override;
};
+/**
+ * Expand distinct aggregation.
+ *
+ * Matches an aggregate whose aggregation traits contain a distinct aggregation, f(distinct a),
+ * and replaces it with two stacked aggregates: a distinct-all aggregate over the group-by keys
+ * and the aggregated column, followed by the original aggregate with the distinct flag cleared.
+ * Operators that do not match are returned unchanged.
+ *
+ * The implementation currently requires the matched aggregate to carry exactly one aggregation.
+ * The rule is registered with the RequireParents and RequireTypes properties.
+ */
+class TExpandDistinctAggregationRule: public ISimplifiedRule {
+public:
+ TExpandDistinctAggregationRule()
+ : ISimplifiedRule("Expand distinct aggregation rule", ERuleProperties::RequireParents | ERuleProperties::RequireTypes) {
+ }
+
+ virtual TIntrusivePtr<IOperator> SimpleMatchAndApply(const TIntrusivePtr<IOperator>& input, TRBOContext& ctx, TPlanProps& props) override;
+};
+
/***
* Fuse two consequtive filters
*/
diff --git a/ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp b/ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp
index cf16dbd31d0..85ed74595c6 100644
--- a/ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp
+++ b/ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp
@@ -355,6 +355,11 @@ void TKqpNewRBOTransformer::InitializeRBOOptimizationStages() {
};
// Initial stages.
+ // Expand aggregation.
+ TVector<std::unique_ptr<IRule>> expandAggregationRules;
+ expandAggregationRules.emplace_back(std::make_unique<TExpandDistinctAggregationRule>());
+ RBO.AddStage(std::make_unique<TRuleBasedStage>("Expand aggregation", std::move(expandAggregationRules)));
+
// Inline join filters. FIXME: Move after inlining when adding support for more advanced decorelation
TVector<std::unique_ptr<IRule>> joinFiltersInlineRules;
joinFiltersInlineRules.emplace_back(std::make_unique<TInlineJoinFiltersRule>());
diff --git a/ydb/core/kqp/opt/rbo/kqp_rewrite_select.cpp b/ydb/core/kqp/opt/rbo/kqp_rewrite_select.cpp
index 56143eedfed..21c8b171c4f 100644
--- a/ydb/core/kqp/opt/rbo/kqp_rewrite_select.cpp
+++ b/ydb/core/kqp/opt/rbo/kqp_rewrite_select.cpp
@@ -126,9 +126,9 @@ TExprNode::TPtr BuildJoinKeys(const TVector<TInfoUnit> &joinKeys, const TVector<
}
TExprNode::TPtr BuildAggregationTraits(const TString& originalColName, const TString& aggFunction, const TString& resultColName, TExprContext& ctx,
- TPositionHandle pos) {
+ TPositionHandle pos, bool distinct = false) {
// clang-format off
- return Build<TKqpOpAggregationTraits>(ctx, pos)
+ auto aggTraitsBuilder = Build<TKqpOpAggregationTraits>(ctx, pos)
.OriginalColName<TCoAtom>()
.Value(originalColName)
.Build()
@@ -137,9 +137,19 @@ TExprNode::TPtr BuildAggregationTraits(const TString& originalColName, const TSt
.Build()
.ResultColName<TCoAtom>()
.Value(resultColName)
- .Build()
- .Done().Ptr();
+ .Build();
// clang-format on
+
+ if (distinct) {
+ // clang-format off
+ aggTraitsBuilder
+ .Distinct<TCoAtom>()
+ .Value("distinct")
+ .Build();
+ // clang-format on
+ }
+
+ return aggTraitsBuilder.Done().Ptr();
}
TExprNode::TPtr BuildAggregate(TExprNode::TPtr resultExpr, const TVector<TExprNode::TPtr>& aggTraitsList, const TVector<TInfoUnit> &keys,
@@ -672,26 +682,20 @@ void EliminateDuplicateAggregations(TVector<std::tuple<TInfoUnit, TExprNode::TPt
}
TExprNode::TPtr BuildAggregationPipeline(TExprNode::TPtr resultExpr, TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>>&& expressionsMapPreAgg,
- TVector<std::pair<TInfoUnit, TExprNode::TPtr>>&& groupByKeysExpressionsMap,
- TAggregationTraits&& distinctAggregationTraitsPreAggregate, TAggregationTraits&& aggTraits,
+ TVector<std::pair<TInfoUnit, TExprNode::TPtr>>&& groupByKeysExpressionsMap, TAggregationTraits&& aggTraits,
TAggregationTraits&& distinctAggregationTraitsPostAggregate, TExprNode::TPtr& havingFilterLambda,
TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>>&& expressionsMapPostAgg, TExprContext& ctx,
TPositionHandle pos) {
// While processing aggregations and having we could have the same aggregations functions on the same column, here we want to eliminate them.
// TODO: Make a special rule in optimizer for that and support more cases, currently we support only simple one aka:
// select f(a) ... having f(a) > val ...;
- if (distinctAggregationTraitsPreAggregate.AggTraitsList.empty() && distinctAggregationTraitsPostAggregate.AggTraitsList.empty()) {
+ if (distinctAggregationTraitsPostAggregate.AggTraitsList.empty()) {
EliminateDuplicateAggregations(expressionsMapPreAgg, aggTraits, expressionsMapPostAgg, havingFilterLambda, ctx, pos);
}
// In case we have an expression for aggregation - f(a + b ...) or group by.
if (!expressionsMapPreAgg.empty() || !groupByKeysExpressionsMap.empty()) {
resultExpr = BuildAggregateExpressionMap(resultExpr, expressionsMapPreAgg, groupByKeysExpressionsMap, ctx, pos);
}
- // Build distinct aggregate pre aggregate.
- if (!distinctAggregationTraitsPreAggregate.AggTraitsList.empty()) {
- resultExpr = BuildAggregate(resultExpr, distinctAggregationTraitsPreAggregate.AggTraitsList, distinctAggregationTraitsPreAggregate.KeyColumns,
- /*distinct=*/true, ctx, pos);
- }
// Build Aggreegate.
if (!aggTraits.AggTraitsList.empty()) {
resultExpr = BuildAggregate(resultExpr, aggTraits.AggTraitsList, aggTraits.KeyColumns, /*distinct=*/false, ctx, pos);
@@ -720,10 +724,10 @@ TExprNode::TPtr BuildAggregationPipeline(TExprNode::TPtr resultExpr, TVector<std
void ProcessAggregations(TExprNode::TPtr lambdaToProcess, TString&& resultColName, THashSet<TString>& aggregationUniqueColNames,
TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>>& expressionsMapPreAgg,
- TVector<std::pair<TInfoUnit, TExprNode::TPtr>>& groupByKeysExpressionsMap, TAggregationTraits& distinctAggregationTraitsPreAggregate,
- TAggregationTraits& aggTraits, TAggregationTraits& distinctAggregationTraitsPostAggregate,
- TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>>& expressionsMapPostAgg, ui64& uniqueAggColumnId, bool& distinctPreAggregate,
- const bool distinctAll, TExprContext& ctx, TPositionHandle pos) {
+ TVector<std::pair<TInfoUnit, TExprNode::TPtr>>& groupByKeysExpressionsMap, TAggregationTraits& aggTraits,
+ TAggregationTraits& distinctAggregationTraitsPostAggregate,
+ TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>>& expressionsMapPostAgg, ui64& uniqueAggColumnId, const bool distinctAll,
+ TExprContext& ctx, TPositionHandle pos) {
// Here we want to process given lambda to find all aggregations and expressions.
auto lambda = TCoLambda(ctx.DeepCopyLambda(*lambdaToProcess));
THashMap<TExprNode::TPtr, TString> aggregationsForReplacement;
@@ -799,19 +803,12 @@ void ProcessAggregations(TExprNode::TPtr lambdaToProcess, TString&& resultColNam
aggregationUniqueColNames.insert(aggColName.GetFullName());
// Distinct for column or expression f(distinct a) => (distinct a) as b -> f(b).
- if (!!GetSetting(*aggregation->Child(1), "distinct")) {
- const auto colName = aggColName.GetFullName();
- auto distinctAggTraits = BuildAggregationTraits(colName, "distinct", colName, ctx, pos);
- distinctAggregationTraitsPreAggregate.AggTraitsList.push_back(distinctAggTraits);
- distinctAggregationTraitsPreAggregate.KeyColumns.push_back(aggColName);
- distinctPreAggregate = true;
- }
-
+ const bool distinct = !!GetSetting(*aggregation->Child(1), "distinct");
// Rename for aggregation result.
const auto aggResultColName = TInfoUnit(GenerateUniqueColumnName(uniqueAggColumnId, "agg_result", "agg_col"));
// Build an aggregation traits.
- auto aggregationTraits = BuildAggregationTraits(aggColName.GetFullName(), aggFuncName, aggResultColName.GetFullName(), ctx, pos);
- aggTraits.AggTraitsList.push_back(aggregationTraits);
+ const auto aggregationTraits = BuildAggregationTraits(aggColName.GetFullName(), aggFuncName, aggResultColName.GetFullName(), ctx, pos, distinct);
+ aggTraits.AggTraitsList.emplace_back(aggregationTraits);
aggregationsForReplacement[aggregation] = aggResultColName.GetFullName();
}
@@ -873,12 +870,10 @@ void ProcessAggregations(TExprNode::TPtr lambdaToProcess, TString&& resultColNam
void ProcessAggregationsInHaving(TExprNode::TPtr having, THashSet<TString>& aggregationUniqueColNames,
TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>>& expressionsMapPreAgg,
- TVector<std::pair<TInfoUnit, TExprNode::TPtr>>& groupByKeysExpressionsMap,
- TAggregationTraits& distinctAggregationTraitsPreAggregate, TAggregationTraits& aggTraits,
+ TVector<std::pair<TInfoUnit, TExprNode::TPtr>>& groupByKeysExpressionsMap, TAggregationTraits& aggTraits,
TAggregationTraits& distinctAggregationTraitsPostAggregate, TExprNode::TPtr& havingFilterLambda, ui64& uniqueAggColumnId,
const bool distinctAll, TExprContext& ctx, TPositionHandle pos) {
Y_ENSURE(!distinctAll, "Distinct all is not supported for HAVING.");
- bool distinctPreAggregate = false;
// For each result item, we want to process result lambda to extract aggregations and pre/post expressions.
auto yqlWhere = having->ChildPtr(1);
Y_ENSURE(yqlWhere->IsCallable("YqlWhere"));
@@ -886,11 +881,9 @@ void ProcessAggregationsInHaving(TExprNode::TPtr having, THashSet<TString>& aggr
TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>> havingFilterHolder;
TString resultColName = GenerateUniqueColumnName(uniqueAggColumnId, "having", "col");
- ProcessAggregations(yqlWhere->ChildPtr(1), std::move(resultColName), aggregationUniqueColNames, expressionsMapPreAgg, groupByKeysExpressionsMap,
- distinctAggregationTraitsPreAggregate, aggTraits, distinctAggregationTraitsPostAggregate, havingFilterHolder, uniqueAggColumnId,
- distinctPreAggregate, distinctAll, ctx, pos);
+ ProcessAggregations(yqlWhere->ChildPtr(1), std::move(resultColName), aggregationUniqueColNames, expressionsMapPreAgg, groupByKeysExpressionsMap, aggTraits,
+ distinctAggregationTraitsPostAggregate, havingFilterHolder, uniqueAggColumnId, distinctAll, ctx, pos);
- Y_ENSURE(!distinctPreAggregate, "Distinct is not supported for HAVING.");
Y_ENSURE(havingFilterHolder.size() == 1, "Invalid number of filters for HAVING.");
havingFilterLambda = std::get<1>(havingFilterHolder.front());
Y_ENSURE(havingFilterLambda, "Fitler for HAVING is nullptr");
@@ -898,36 +891,23 @@ void ProcessAggregationsInHaving(TExprNode::TPtr having, THashSet<TString>& aggr
void ProcessAggregationsInResultItems(TExprNode::TPtr result, THashSet<TString>& aggregationUniqueColNames,
TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>>& expressionsMapPreAgg,
- TVector<std::pair<TInfoUnit, TExprNode::TPtr>>& groupByKeysExpressionsMap,
- TAggregationTraits& distinctAggregationTraitsPreAggregate, TAggregationTraits& aggTraits,
+ TVector<std::pair<TInfoUnit, TExprNode::TPtr>>& groupByKeysExpressionsMap, TAggregationTraits& aggTraits,
TAggregationTraits& distinctAggregationTraitsPostAggregate,
TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>>& expressionsMapPostAgg, ui64& uniqueAggColumnId,
const bool distinctAll, TExprContext& ctx, TPositionHandle pos) {
- bool distinctPreAggregate = false;
// For each result item, we want to process result lambda to extract aggregations and pre/post expressions.
for (ui32 i = 0, e = result->Child(1)->ChildrenSize(); i < e; ++i) {
auto resultItem = result->Child(1)->ChildPtr(i);
ProcessAggregations(resultItem->ChildPtr(2), TString(resultItem->Child(0)->Content()), aggregationUniqueColNames, expressionsMapPreAgg,
- groupByKeysExpressionsMap, distinctAggregationTraitsPreAggregate, aggTraits, distinctAggregationTraitsPostAggregate,
- expressionsMapPostAgg, uniqueAggColumnId, distinctPreAggregate, distinctAll, ctx, pos);
- }
-
- // Distinct pre aggregate fro group by keys.
- if (distinctPreAggregate) {
- Y_ENSURE(distinctAggregationTraitsPreAggregate.AggTraitsList.size() == 1 && aggTraits.AggTraitsList.size() == 1, "Multiple distinct is not supported");
- for (const auto& key : aggTraits.KeyColumns) {
- const auto colName = key.GetFullName();
- const auto distinctAggTraits = BuildAggregationTraits(colName, "distinct", colName, ctx, pos);
- distinctAggregationTraitsPreAggregate.AggTraitsList.push_back(distinctAggTraits);
- distinctAggregationTraitsPreAggregate.KeyColumns.push_back(colName);
- }
+ groupByKeysExpressionsMap, aggTraits, distinctAggregationTraitsPostAggregate, expressionsMapPostAgg, uniqueAggColumnId, distinctAll,
+ ctx, pos);
}
// Distinct post aggregate for group by keys.
if (distinctAll) {
// distinct f(a), b group by b => f(a) as f, b group by b -> select f, b group by f, b.
THashSet<TString> distinctSet;
- for (const auto& key: distinctAggregationTraitsPostAggregate.KeyColumns) {
+ for (const auto& key : distinctAggregationTraitsPostAggregate.KeyColumns) {
distinctSet.insert(key.GetFullName());
}
@@ -1422,16 +1402,14 @@ TExprNode::TPtr RewriteSelect(const TExprNode::TPtr& input, TExprContext& ctx, c
auto having = GetSetting(setItem->Tail(), "having");
if (having) {
- ProcessAggregationsInHaving(having, aggregationUniqueColNames, expressionsMapPreAgg, groupByKeysExpressionsMap,
- distinctAggregationTraitsPreAggregate, aggregationTraits, distinctAggregationTraitsPostAggregate, havingFilterLambda,
- uniqueAggColumnId, distinctAll, ctx, node->Pos());
+ ProcessAggregationsInHaving(having, aggregationUniqueColNames, expressionsMapPreAgg, groupByKeysExpressionsMap, aggregationTraits,
+ distinctAggregationTraitsPostAggregate, havingFilterLambda, uniqueAggColumnId, distinctAll, ctx, node->Pos());
}
auto result = GetSetting(setItem->Tail(), "result");
// Process all aggregations in result item.
- ProcessAggregationsInResultItems(result, aggregationUniqueColNames, expressionsMapPreAgg, groupByKeysExpressionsMap,
- distinctAggregationTraitsPreAggregate, aggregationTraits, distinctAggregationTraitsPostAggregate,
- expressionsMapPostAgg, uniqueAggColumnId, distinctAll, ctx, node->Pos());
+ ProcessAggregationsInResultItems(result, aggregationUniqueColNames, expressionsMapPreAgg, groupByKeysExpressionsMap, aggregationTraits,
+ distinctAggregationTraitsPostAggregate, expressionsMapPostAgg, uniqueAggColumnId, distinctAll, ctx, node->Pos());
if (hasRollup) {
Y_ENSURE(groupBySets.size() == 1, "Invalid group sets size for rollup.");
@@ -1444,7 +1422,6 @@ TExprNode::TPtr RewriteSelect(const TExprNode::TPtr& input, TExprContext& ctx, c
// We have to use keys based on group set.
aggregationTraitsForSet.KeyColumns.clear();
TVector<std::pair<TInfoUnit, TExprNode::TPtr>> groupByKeysExpressionsMapForSet;
- TAggregationTraits distinctAggregationTraitsPreAggregateForSet = distinctAggregationTraitsPreAggregate;
TAggregationTraits distinctAggregationTraitsPostAggregateForSet = distinctAggregationTraitsPostAggregate;
TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>> expressionsMapPostAggForSet = expressionsMapPostAgg;
TVector<std::tuple<TInfoUnit, TExprNode::TPtr, bool>> expressionsMapPreAggForSet = expressionsMapPreAgg;
@@ -1490,7 +1467,7 @@ TExprNode::TPtr RewriteSelect(const TExprNode::TPtr& input, TExprContext& ctx, c
auto aggregationForGroupSetResultExpr = BuildAggregationPipeline(
resultExpr, std::move(expressionsMapPreAggForSet), std::move(groupByKeysExpressionsMapForSet),
- std::move(distinctAggregationTraitsPreAggregateForSet), std::move(aggregationTraitsForSet),
+ std::move(aggregationTraitsForSet),
std::move(distinctAggregationTraitsPostAggregateForSet), havingFilterLambda, std::move(expressionsMapPostAggForSet), ctx, node->Pos());
if (rollupResultExpr) {
@@ -1511,9 +1488,8 @@ TExprNode::TPtr RewriteSelect(const TExprNode::TPtr& input, TExprContext& ctx, c
} else {
// Build an aggregation pipeline.
resultExpr = BuildAggregationPipeline(resultExpr, std::move(expressionsMapPreAgg), std::move(groupByKeysExpressionsMap),
- std::move(distinctAggregationTraitsPreAggregate), std::move(aggregationTraits),
- std::move(distinctAggregationTraitsPostAggregate), havingFilterLambda, std::move(expressionsMapPostAgg), ctx,
- node->Pos());
+ std::move(aggregationTraits), std::move(distinctAggregationTraitsPostAggregate), havingFilterLambda,
+ std::move(expressionsMapPostAgg), ctx, node->Pos());
}
finalColumnOrder.clear();
diff --git a/ydb/core/kqp/opt/rbo/rules/expand_distinct_aggregation.cpp b/ydb/core/kqp/opt/rbo/rules/expand_distinct_aggregation.cpp
new file mode 100644
index 00000000000..58b59500b2e
--- /dev/null
+++ b/ydb/core/kqp/opt/rbo/rules/expand_distinct_aggregation.cpp
@@ -0,0 +1,50 @@
+#include <ydb/core/kqp/opt/rbo/kqp_rbo_rules.h>
+
+namespace NKikimr::NKqp {
+
+namespace {
+
+// Returns true when the operator is an aggregate with at least one aggregation marked as distinct.
+bool IsSuitableToExpandDistinctAggregation(const TIntrusivePtr<IOperator>& input) {
+    if (input->GetKind() != EOperator::Aggregate) {
+        return false;
+    }
+
+    const auto aggregate = CastOperator<TOpAggregate>(input);
+    for (const TOpAggregationTraits& traits : aggregate->GetAggregationTraits()) {
+        if (traits.Distinct) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+} // anonymous namespace
+
+// Splits an aggregate carrying a distinct aggregation into two stacked aggregates:
+//   f(distinct a) group by k  =>  distinct (k, a)  ->  f(a) group by k.
+// Returns the input unchanged when it is not an aggregate with a distinct aggregation.
+// Fails with Y_ENSURE when the aggregate carries more than one aggregation.
+TIntrusivePtr<IOperator> TExpandDistinctAggregationRule::SimpleMatchAndApply(const TIntrusivePtr<IOperator>& input, TRBOContext& rboCtx, TPlanProps& props) {
+    Y_UNUSED(props);
+    Y_UNUSED(rboCtx);
+
+    if (!IsSuitableToExpandDistinctAggregation(input)) {
+        return input;
+    }
+
+    const auto aggregate = CastOperator<TOpAggregate>(input);
+    const auto& aggTraitsList = aggregate->GetAggregationTraits();
+    Y_ENSURE(aggTraitsList.size() == 1, "Multiple distinct is not supported.");
+    const auto& aggTraits = aggTraitsList.front();
+    TVector<TInfoUnit> distinctKeys = aggregate->GetKeyColumns();
+
+    // The aggregated column may already be one of the group-by keys (f(distinct a) ... group by a).
+    // In that case it must not be listed a second time in the distinct aggregate.
+    const auto distinctColName = aggTraits.OriginalColName.GetFullName();
+    const bool distinctColIsKey = std::any_of(distinctKeys.begin(), distinctKeys.end(), [&distinctColName](const TInfoUnit& key) {
+        return key.GetFullName() == distinctColName;
+    });
+
+    // Split into distinct and original aggregation.
+    TVector<TOpAggregationTraits> distinctTraitsList;
+    distinctTraitsList.reserve(distinctKeys.size() + 1);
+    for (const auto& key : distinctKeys) {
+        distinctTraitsList.emplace_back(key, "distinct", key);
+    }
+    if (!distinctColIsKey) {
+        distinctTraitsList.emplace_back(aggTraits.OriginalColName, "distinct", aggTraits.OriginalColName);
+        distinctKeys.emplace_back(aggTraits.OriginalColName);
+    }
+
+    // Inner aggregate: removes duplicate rows over (group-by keys, aggregated column).
+    const TIntrusivePtr<IOperator> distinctAggregation =
+        MakeIntrusive<TOpAggregate>(aggregate->GetInput(), distinctTraitsList, distinctKeys, EOpPhase::Undefined,
+                                    /*distinctAll=*/true, input->Pos);
+
+    // Outer aggregate: the original aggregation over the original keys, with the distinct flag
+    // cleared so that this rule does not match the new operator again.
+    TOpAggregationTraits aggregationTraits = aggTraits;
+    aggregationTraits.Distinct = false;
+    const TVector<TOpAggregationTraits> newAggTraitsList{aggregationTraits};
+    return MakeIntrusive<TOpAggregate>(distinctAggregation, newAggTraitsList, aggregate->GetKeyColumns(), EOpPhase::Undefined, /*distinctAll=*/false,
+                                       input->Pos);
+}
+
+} // namespace NKikimr::NKqp
diff --git a/ydb/core/kqp/opt/rbo/rules/ya.make b/ydb/core/kqp/opt/rbo/rules/ya.make
index fae225c2ec9..1620f403d10 100644
--- a/ydb/core/kqp/opt/rbo/rules/ya.make
+++ b/ydb/core/kqp/opt/rbo/rules/ya.make
@@ -7,6 +7,7 @@ SRCS(
constant_folding_stage.cpp
correlated_filter_pullup.cpp
expand_cbo_tree.cpp
+ expand_distinct_aggregation.cpp
extract_join_expressions.cpp
eliminate_left_join.cpp
fuse_filters.cpp