summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author vvvv <[email protected]> 2022-07-29 13:31:19 +0300
committer vvvv <[email protected]> 2022-07-29 13:31:19 +0300
commit c47ea27d22ee334dd949daf9bd7a39078639efdd (patch)
tree 1e1481509741a54e19c3b327606e64c42ca65a10
parent c42419d12ba88adc5e1dcaac8577df094fa7d92e (diff)
Initial implementation of AggApply for count. Only peephole is working at the moment.
-rw-r--r-- ydb/core/kqp/opt/logical/kqp_opt_log.cpp 2
-rw-r--r-- ydb/core/kqp/provider/yql_kikimr_datasink.cpp 2
-rw-r--r-- ydb/core/kqp/provider/yql_kikimr_opt.cpp 12
-rw-r--r-- ydb/core/kqp/provider/yql_kikimr_provider_impl.h 2
-rw-r--r-- ydb/library/yql/core/common_opt/yql_co_flow1.cpp 4
-rw-r--r-- ydb/library/yql/core/common_opt/yql_co_flow2.cpp 4
-rw-r--r-- ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp 31
-rw-r--r-- ydb/library/yql/core/type_ann/type_ann_core.cpp 1
-rw-r--r-- ydb/library/yql/core/type_ann/type_ann_list.cpp 79
-rw-r--r-- ydb/library/yql/core/type_ann/type_ann_list.h 1
-rw-r--r-- ydb/library/yql/core/yql_opt_aggregate.cpp 68
-rw-r--r-- ydb/library/yql/core/yql_opt_aggregate.h 8
-rw-r--r-- ydb/library/yql/dq/opt/dq_opt_log.cpp 4
-rw-r--r-- ydb/library/yql/dq/opt/dq_opt_log.h 2
-rw-r--r-- ydb/library/yql/providers/dq/opt/logical_optimize.cpp 2
-rw-r--r-- ydb/library/yql/sql/v1/aggregation.cpp 20
-rw-r--r-- ydb/library/yql/sql/v1/context.cpp 1
-rw-r--r-- ydb/library/yql/sql/v1/context.h 1
-rw-r--r-- ydb/library/yql/sql/v1/sql.cpp 3
19 files changed, 202 insertions, 45 deletions
diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp
index 03a101d0e00..8c4405e6892 100644
--- a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp
+++ b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp
@@ -74,7 +74,7 @@ protected:
}
TMaybeNode<TExprBase> RewriteAggregate(TExprBase node, TExprContext& ctx) {
- TExprBase output = DqRewriteAggregate(node, ctx);
+ TExprBase output = DqRewriteAggregate(node, ctx, TypesCtx);
DumpAppliedRule("RewriteAggregate", node.Ptr(), output.Ptr(), ctx);
return output;
}
diff --git a/ydb/core/kqp/provider/yql_kikimr_datasink.cpp b/ydb/core/kqp/provider/yql_kikimr_datasink.cpp
index 0d1f6af7694..37f443e1bb8 100644
--- a/ydb/core/kqp/provider/yql_kikimr_datasink.cpp
+++ b/ydb/core/kqp/provider/yql_kikimr_datasink.cpp
@@ -313,7 +313,7 @@ public:
, SessionCtx(sessionCtx)
, IntentDeterminationTransformer(CreateKiSinkIntentDeterminationTransformer(sessionCtx))
, TypeAnnotationTransformer(CreateKiSinkTypeAnnotationTransformer(gateway, sessionCtx))
- , LogicalOptProposalTransformer(CreateKiLogicalOptProposalTransformer(sessionCtx))
+ , LogicalOptProposalTransformer(CreateKiLogicalOptProposalTransformer(sessionCtx, types))
, PhysicalOptProposalTransformer(CreateKiPhysicalOptProposalTransformer(sessionCtx))
, CallableExecutionTransformer(CreateKiSinkCallableExecutionTransformer(gateway, sessionCtx, queryExecutor))
, PlanInfoTransformer(CreateKiSinkPlanInfoTransformer(queryExecutor))
diff --git a/ydb/core/kqp/provider/yql_kikimr_opt.cpp b/ydb/core/kqp/provider/yql_kikimr_opt.cpp
index e04c518f2f7..a3b8a36848a 100644
--- a/ydb/core/kqp/provider/yql_kikimr_opt.cpp
+++ b/ydb/core/kqp/provider/yql_kikimr_opt.cpp
@@ -100,7 +100,7 @@ TExprNode::TPtr KiEraseOverSelectRow(TExprBase node, TExprContext& ctx) {
return node.Ptr();
}
-TExprNode::TPtr KiRewriteAggregate(TExprBase node, TExprContext& ctx) {
+TExprNode::TPtr KiRewriteAggregate(TExprBase node, TExprContext& ctx, TTypeAnnotationContext& typesCtx) {
if (!node.Maybe<TCoAggregate>()) {
return node.Ptr();
}
@@ -120,7 +120,7 @@ TExprNode::TPtr KiRewriteAggregate(TExprBase node, TExprContext& ctx) {
}
YQL_CLOG(INFO, ProviderKikimr) << "KiRewriteAggregate";
- return ExpandAggregate(true, node.Ptr(), ctx);
+ return ExpandAggregate(true, node.Ptr(), ctx, typesCtx);
}
TExprNode::TPtr KiRedundantSortByPk(TExprBase node, TExprContext& ctx,
@@ -655,8 +655,8 @@ TExprNode::TPtr KiApplyExtractMembersToSelectRange(TExprBase node, TExprContext&
} // namespace
-TAutoPtr<IGraphTransformer> CreateKiLogicalOptProposalTransformer(TIntrusivePtr<TKikimrSessionContext> sessionCtx) {
- return CreateFunctorTransformer([sessionCtx](const TExprNode::TPtr& input, TExprNode::TPtr& output,
+TAutoPtr<IGraphTransformer> CreateKiLogicalOptProposalTransformer(TIntrusivePtr<TKikimrSessionContext> sessionCtx, TTypeAnnotationContext& types) {
+ return CreateFunctorTransformer([sessionCtx, &types](const TExprNode::TPtr& input, TExprNode::TPtr& output,
TExprContext& ctx)
{
typedef IGraphTransformer::TStatus TStatus;
@@ -666,7 +666,7 @@ TAutoPtr<IGraphTransformer> CreateKiLogicalOptProposalTransformer(TIntrusivePtr<
GatherParents(*input, parentsMap);
optCtx.ParentsMap = &parentsMap;
- TStatus status = OptimizeExpr(input, output, [sessionCtx, &optCtx](const TExprNode::TPtr& inputNode, TExprContext& ctx) {
+ TStatus status = OptimizeExpr(input, output, [sessionCtx, &optCtx, &types](const TExprNode::TPtr& inputNode, TExprContext& ctx) {
auto ret = inputNode;
TExprBase node(inputNode);
@@ -710,7 +710,7 @@ TAutoPtr<IGraphTransformer> CreateKiLogicalOptProposalTransformer(TIntrusivePtr<
return ret;
}
- ret = KiRewriteAggregate(node, ctx);
+ ret = KiRewriteAggregate(node, ctx, types);
if (ret != inputNode) {
return ret;
}
diff --git a/ydb/core/kqp/provider/yql_kikimr_provider_impl.h b/ydb/core/kqp/provider/yql_kikimr_provider_impl.h
index 2149d95309f..43e710002e5 100644
--- a/ydb/core/kqp/provider/yql_kikimr_provider_impl.h
+++ b/ydb/core/kqp/provider/yql_kikimr_provider_impl.h
@@ -181,7 +181,7 @@ TAutoPtr<IGraphTransformer> CreateKiSourceTypeAnnotationTransformer(TIntrusivePt
TTypeAnnotationContext& types);
TAutoPtr<IGraphTransformer> CreateKiSinkTypeAnnotationTransformer(TIntrusivePtr<IKikimrGateway> gateway,
TIntrusivePtr<TKikimrSessionContext> sessionCtx);
-TAutoPtr<IGraphTransformer> CreateKiLogicalOptProposalTransformer(TIntrusivePtr<TKikimrSessionContext> sessionCtx);
+TAutoPtr<IGraphTransformer> CreateKiLogicalOptProposalTransformer(TIntrusivePtr<TKikimrSessionContext> sessionCtx, TTypeAnnotationContext& types);
TAutoPtr<IGraphTransformer> CreateKiPhysicalOptProposalTransformer(TIntrusivePtr<TKikimrSessionContext> sessionCtx);
TAutoPtr<IGraphTransformer> CreateKiSourceLoadTableMetadataTransformer(TIntrusivePtr<IKikimrGateway> gateway,
TIntrusivePtr<TKikimrSessionContext> sessionCtx);
diff --git a/ydb/library/yql/core/common_opt/yql_co_flow1.cpp b/ydb/library/yql/core/common_opt/yql_co_flow1.cpp
index d15804a3051..4b61c4ebd85 100644
--- a/ydb/library/yql/core/common_opt/yql_co_flow1.cpp
+++ b/ydb/library/yql/core/common_opt/yql_co_flow1.cpp
@@ -1269,6 +1269,10 @@ TExprNode::TPtr CountAggregateRewrite(const TCoAggregate& node, TExprContext& ct
const bool isDistinct = (aggregatedColumn.Ref().ChildrenSize() == 3);
auto traits = aggregatedColumn.Ref().Child(1);
+ if (!traits->IsCallable("AggregationTraits")) {
+ return node.Ptr();
+ }
+
auto outputColumn = aggregatedColumn.Ref().HeadPtr();
// validation of traits
auto inputItemType = traits->Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType();
diff --git a/ydb/library/yql/core/common_opt/yql_co_flow2.cpp b/ydb/library/yql/core/common_opt/yql_co_flow2.cpp
index 7a67e5bc029..b4c1b618b9e 100644
--- a/ydb/library/yql/core/common_opt/yql_co_flow2.cpp
+++ b/ydb/library/yql/core/common_opt/yql_co_flow2.cpp
@@ -50,6 +50,10 @@ TExprNode::TPtr AggregateSubsetFieldsAnalyzer(const TCoAggregate& node, TExprCon
}
else {
auto traits = x.Ref().Child(1);
+ if (!traits->IsCallable("AggregationTraits")) {
+ return node.Ptr();
+ }
+
auto structType = traits->Child(0)->GetTypeAnn()->Cast<TTypeExprType>()->GetType()->Cast<TStructExprType>();
for (const auto& item : structType->GetItems()) {
usedFields.insert(item->GetName());
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
index 3b577aa0655..4f67f210c39 100644
--- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
+++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
@@ -31,8 +31,8 @@ using namespace NNodes;
using TPeepHoleOptimizerPtr = TExprNode::TPtr (*const)(const TExprNode::TPtr&, TExprContext&);
using TPeepHoleOptimizerMap = std::unordered_map<std::string_view, TPeepHoleOptimizerPtr>;
-using TNonDeterministicOptimizerPtr = TExprNode::TPtr (*const)(const TExprNode::TPtr&, TExprContext&, TTypeAnnotationContext& types);
-using TNonDeterministicOptimizerMap = std::unordered_map<std::string_view, TNonDeterministicOptimizerPtr>;
+using TExtPeepHoleOptimizerPtr = TExprNode::TPtr (*const)(const TExprNode::TPtr&, TExprContext&, TTypeAnnotationContext& types);
+using TExtPeepHoleOptimizerMap = std::unordered_map<std::string_view, TExtPeepHoleOptimizerPtr>;
TExprNode::TPtr MakeNothing(TPositionHandle pos, const TTypeAnnotationNode& type, TExprContext& ctx) {
return ctx.NewCallable(pos, "Nothing", {ExpandType(pos, *ctx.MakeType<TOptionalExprType>(&type), ctx)});
@@ -1884,22 +1884,26 @@ TExprNode::TPtr ExpandFilter(const TExprNode::TPtr& input, TExprContext& ctx) {
}
IGraphTransformer::TStatus PeepHoleCommonStage(const TExprNode::TPtr& input, TExprNode::TPtr& output,
- TExprContext& ctx, TTypeAnnotationContext& types, const TPeepHoleOptimizerMap& optimizers)
+ TExprContext& ctx, TTypeAnnotationContext& types,
+ const TPeepHoleOptimizerMap& optimizers, const TExtPeepHoleOptimizerMap& extOptimizers)
{
TOptimizeExprSettings settings(&types);
settings.CustomInstantTypeTransformer = types.CustomInstantTypeTransformer.Get();
- return OptimizeExpr(input, output, [&optimizers](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr {
+ return OptimizeExpr(input, output, [&optimizers, &extOptimizers, &types](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr {
if (const auto rule = optimizers.find(node->Content()); optimizers.cend() != rule)
return (rule->second)(node, ctx);
+ if (const auto rule = extOptimizers.find(node->Content()); extOptimizers.cend() != rule)
+ return (rule->second)(node, ctx, types);
+
return node;
}, ctx, settings);
}
IGraphTransformer::TStatus PeepHoleFinalStage(const TExprNode::TPtr& input, TExprNode::TPtr& output,
TExprContext& ctx, TTypeAnnotationContext& types, bool* hasNonDeterministicFunctions,
- const TPeepHoleOptimizerMap& optimizers, const TNonDeterministicOptimizerMap& nonDetOptimizers)
+ const TPeepHoleOptimizerMap& optimizers, const TExtPeepHoleOptimizerMap& nonDetOptimizers)
{
TOptimizeExprSettings settings(&types);
settings.CustomInstantTypeTransformer = types.CustomInstantTypeTransformer.Get();
@@ -5823,7 +5827,6 @@ struct TPeepHoleRules {
{"OptionalReduce", &ExpandOptionalReduce},
{"AggrMin", &ExpandAggrMinMax<true>},
{"AggrMax", &ExpandAggrMinMax<false>},
- {"Aggregate", &ExpandAggregatePeephole},
{"And", &OptimizeLogicalDups<true>},
{"Or", &OptimizeLogicalDups<false>},
{"CombineByKey", &ExpandCombineByKey},
@@ -5850,6 +5853,10 @@ struct TPeepHoleRules {
{"ToFlow", &DropToFlowDeps},
};
+ static constexpr std::initializer_list<TExtPeepHoleOptimizerMap::value_type> CommonStageExtRulesInit = {
+ {"Aggregate", &ExpandAggregatePeephole},
+ };
+
static constexpr std::initializer_list<TPeepHoleOptimizerMap::value_type> SimplifyStageRulesInit = {
{"Map", &OptimizeMap<false, EnableNewOptimizers>},
{"OrderedMap", &OptimizeMap<true, EnableNewOptimizers>},
@@ -5909,7 +5916,7 @@ struct TPeepHoleRules {
{"SqueezeToDict", &OptimizeSqueezeToDict}
};
- static constexpr std::initializer_list<TNonDeterministicOptimizerMap::value_type> FinalStageNonDetRulesInit = {
+ static constexpr std::initializer_list<TExtPeepHoleOptimizerMap::value_type> FinalStageNonDetRulesInit = {
{"Random", &Random0Arg<double>},
{"RandomNumber", &Random0Arg<ui64>},
{"RandomUuid", &Random0Arg<TGUID>},
@@ -5921,6 +5928,7 @@ struct TPeepHoleRules {
TPeepHoleRules()
: CommonStageRules(CommonStageRulesInit)
+ , CommonStageExtRules(CommonStageExtRulesInit)
, FinalStageRules(FinalStageRulesInit)
, SimplifyStageRules(SimplifyStageRulesInit)
, FinalStageNonDetRules(FinalStageNonDetRulesInit)
@@ -5931,9 +5939,10 @@ struct TPeepHoleRules {
}
const TPeepHoleOptimizerMap CommonStageRules;
+ const TExtPeepHoleOptimizerMap CommonStageExtRules;
const TPeepHoleOptimizerMap FinalStageRules;
const TPeepHoleOptimizerMap SimplifyStageRules;
- const TNonDeterministicOptimizerMap FinalStageNonDetRules;
+ const TExtPeepHoleOptimizerMap FinalStageNonDetRules;
};
template <bool EnableNewOptimizers>
@@ -5955,7 +5964,9 @@ THolder<IGraphTransformer> CreatePeepHoleCommonStageTransformer(TTypeAnnotationC
pipeline.AddCommonOptimization(issueCode);
pipeline.Add(CreateFunctorTransformer(
[&types](const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) {
- return PeepHoleCommonStage(input, output, ctx, types, TPeepHoleRules<EnableNewOptimizers>::Instance().CommonStageRules);
+ return PeepHoleCommonStage(input, output, ctx, types,
+ TPeepHoleRules<EnableNewOptimizers>::Instance().CommonStageRules,
+ TPeepHoleRules<EnableNewOptimizers>::Instance().CommonStageExtRules);
}
),
"PeepHoleCommon",
@@ -5997,7 +6008,7 @@ THolder<IGraphTransformer> CreatePeepHoleFinalStageTransformer(TTypeAnnotationCo
}
const auto& nonDetStageRules = withNonDeterministicRules ?
- TPeepHoleRules<EnableNewOptimizers>::Instance().FinalStageNonDetRules : TNonDeterministicOptimizerMap{};
+ TPeepHoleRules<EnableNewOptimizers>::Instance().FinalStageNonDetRules : TExtPeepHoleOptimizerMap{};
return PeepHoleFinalStage(input, output, ctx, types, hasNonDeterministicFunctions, stageRules, nonDetStageRules);
}
diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp
index b2dcbdc39fe..7b51ab5fac6 100644
--- a/ydb/library/yql/core/type_ann/type_ann_core.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp
@@ -11244,6 +11244,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
Functions["MultiAggregate"] = &MultiAggregateWrapper;
Functions["Aggregate"] = &AggregateWrapper;
Functions["SqlAggregateAll"] = &SqlAggregateAllWrapper;
+ Functions["AggApply"] = &AggApplyWrapper;
Functions["WinOnRows"] = &WinOnWrapper;
Functions["WinOnGroups"] = &WinOnWrapper;
Functions["WinOnRange"] = &WinOnWrapper;
diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp
index 743a7a872db..e9b27e10c91 100644
--- a/ydb/library/yql/core/type_ann/type_ann_list.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp
@@ -4726,7 +4726,9 @@ namespace {
}
}
- if (!child->Child(1)->IsCallable("AggregationTraits")) {
+ const bool isAggApply = child->Child(1)->IsCallable("AggApply");
+ const bool isTraits = child->Child(1)->IsCallable("AggregationTraits");
+ if (!isAggApply && !isTraits) {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(child->Child(1)->Pos()), "Expected aggregation traits"));
return IGraphTransformer::TStatus::Error;
}
@@ -4752,7 +4754,7 @@ namespace {
}
}
- auto finishType = child->Child(1)->Child(6)->GetTypeAnn();
+ auto finishType = isAggApply ? child->Child(1)->GetTypeAnn() : child->Child(1)->Child(6)->GetTypeAnn();
bool isOptional = finishType->GetKind() == ETypeAnnotationKind::Optional;
if (child->Head().IsList()) {
if (isOptional) {
@@ -4774,17 +4776,33 @@ namespace {
item->GetKind() != ETypeAnnotationKind::Optional ? ctx.Expr.MakeType<TOptionalExprType>(item) : item));
}
} else {
- auto defVal = child->Child(1)->Child(7);
- if (defVal->IsCallable("Null") && !isOptional && !isHopping && input->Child(1)->ChildrenSize() == 0) {
+ const TTypeAnnotationNode* defValType;
+ bool isDefNull;
+ if (isTraits) {
+ auto defVal = child->Child(1)->Child(7);
+ isDefNull = defVal->IsCallable("Null");
+ defValType = defVal->GetTypeAnn();
+ } else {
+ auto name = child->Child(1)->Child(0)->Content();
+ if (name == "count" || name == "count_all") {
+ isDefNull = false;
+ defValType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64);
+ } else {
+ isDefNull = true;
+ defValType = ctx.Expr.MakeType<TNullExprType>();
+ }
+ }
+
+ if (isDefNull && !isOptional && !isHopping && input->Child(1)->ChildrenSize() == 0) {
if (finishType->GetKind() != ETypeAnnotationKind::Null &&
finishType->GetKind() != ETypeAnnotationKind::Pg) {
finishType = ctx.Expr.MakeType<TOptionalExprType>(finishType);
}
- } else if (!defVal->IsCallable("Null") && defVal->GetTypeAnn()->GetKind() != ETypeAnnotationKind::Optional
+ } else if (!isDefNull && defValType->GetKind() != ETypeAnnotationKind::Optional
&& finishType->GetKind() == ETypeAnnotationKind::Optional) {
finishType = finishType->Cast<TOptionalExprType>()->GetItemType();
- } else if (!defVal->IsCallable("Null") && finishType->GetKind() == ETypeAnnotationKind::Null && input->Child(1)->ChildrenSize() == 0) {
- finishType = defVal->GetTypeAnn();
+ } else if (!isDefNull && finishType->GetKind() == ETypeAnnotationKind::Null && input->Child(1)->ChildrenSize() == 0) {
+ finishType = defValType;
}
rowColumns.push_back(ctx.Expr.MakeType<TItemExprType>(child->Head().Content(), finishType));
@@ -4840,6 +4858,53 @@ namespace {
return IGraphTransformer::TStatus::Ok;
}
+ IGraphTransformer::TStatus AggApplyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
+ Y_UNUSED(output);
+ if (!EnsureArgsCount(*input, 3, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!EnsureAtom(*input->Child(0), ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ auto name = input->Child(0)->Content();
+ if (auto status = EnsureTypeRewrite(input->ChildRef(1), ctx.Expr); status != IGraphTransformer::TStatus::Ok) {
+ return status;
+ }
+
+ auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
+ if (itemType->GetKind() != ETypeAnnotationKind::Struct) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ TStringBuilder() << "Expected struct type, but got: " << *itemType));
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ auto& lambda = input->ChildRef(2);
+ const auto status = ConvertToLambda(lambda, ctx.Expr, 1);
+ if (status.Level != IGraphTransformer::TStatus::Ok) {
+ return status;
+ }
+
+ if (!UpdateLambdaAllArgumentsTypes(lambda, { itemType }, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!lambda->GetTypeAnn()) {
+ return IGraphTransformer::TStatus::Repeat;
+ }
+
+ if (name == "count" || name == "count_all") {
+ input->SetTypeAnn(ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64));
+ } else {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ TStringBuilder() << "Unsupported agg name: " << name));
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ return IGraphTransformer::TStatus::Ok;
+ }
+
IGraphTransformer::TStatus FilterNullMembersWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
if (!EnsureMinMaxArgsCount(*input, 1, 2, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
diff --git a/ydb/library/yql/core/type_ann/type_ann_list.h b/ydb/library/yql/core/type_ann/type_ann_list.h
index 3fd48517da4..840f6477742 100644
--- a/ydb/library/yql/core/type_ann/type_ann_list.h
+++ b/ydb/library/yql/core/type_ann/type_ann_list.h
@@ -94,6 +94,7 @@ namespace NTypeAnnImpl {
IGraphTransformer::TStatus MultiAggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus AggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus SqlAggregateAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
+ IGraphTransformer::TStatus AggApplyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus FilterNullMembersWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus SkipNullMembersWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus FilterNullElementsWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
diff --git a/ydb/library/yql/core/yql_opt_aggregate.cpp b/ydb/library/yql/core/yql_opt_aggregate.cpp
index e977a47feae..edd7b4d26b2 100644
--- a/ydb/library/yql/core/yql_opt_aggregate.cpp
+++ b/ydb/library/yql/core/yql_opt_aggregate.cpp
@@ -1,15 +1,60 @@
#include "yql_opt_aggregate.h"
#include "yql_opt_utils.h"
#include "yql_opt_window.h"
+#include "yql_expr_optimize.h"
#include "yql_expr_type_annotation.h"
namespace NYql {
-TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, TExprContext& ctx, bool forceCompact) {
+TExprNode::TPtr ExpandAggApply(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx) {
+ auto name = node->Head().Content();
+ auto exportsPtr = typesCtx.Modules->GetModule("/lib/yql/aggregate.yql");
+ YQL_ENSURE(exportsPtr);
+ const auto& exports = exportsPtr->Symbols();
+ const auto ex = exports.find(TString(name) + "_traits_factory");
+ YQL_ENSURE(exports.cend() != ex);
+ TNodeOnNodeOwnedMap deepClones;
+ auto lambda = ctx.DeepCopy(*ex->second, exportsPtr->ExprCtx(), deepClones, true, false);
+
+ auto listTypeNode = ctx.NewCallable(node->Pos(), "ListType", { node->ChildPtr(1) });
+ auto extractor = node->ChildPtr(2);
+
+ auto traits = ctx.ReplaceNodes(lambda->TailPtr(), {
+ {lambda->Head().Child(0), listTypeNode},
+ {lambda->Head().Child(1), extractor}
+ });
+
+ ctx.Step.Repeat(TExprStep::ExpandApplyForLambdas);
+ auto status = ExpandApply(traits, traits, ctx);
+ YQL_ENSURE(status != IGraphTransformer::TStatus::Error);
+ return traits;
+}
+
+TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, bool forceCompact) {
auto list = node->HeadPtr();
auto keyColumns = node->ChildPtr(1);
auto aggregatedColumns = node->Child(2);
auto settings = node->Child(3);
+ TExprNode::TListType traits;
+ bool needRebuild = false;
+ for (ui32 index = 0; index < aggregatedColumns->ChildrenSize(); ++index) {
+ auto trait = aggregatedColumns->Child(index)->ChildPtr(1);
+ if (trait->IsCallable("AggApply")) {
+ trait = ExpandAggApply(trait, ctx, typesCtx);
+ needRebuild = true;
+ }
+
+ traits.push_back(trait);
+ }
+
+ if (needRebuild) {
+ TExprNode::TListType newAggregatedColumnsItems = aggregatedColumns->ChildrenList();
+ for (ui32 index = 0; index < aggregatedColumns->ChildrenSize(); ++index) {
+ newAggregatedColumnsItems[index] = ctx.ChangeChild(*(newAggregatedColumnsItems[index]), 1, std::move(traits[index]));
+ }
+
+ return ctx.ChangeChild(*node, 2, ctx.NewList(node->Pos(), std::move(newAggregatedColumnsItems)));
+ }
YQL_ENSURE(!HasSetting(*settings, "hopping"), "Aggregate with hopping unsupported here.");
@@ -26,8 +71,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T
TExprNode::TPtr sortOrder = voidNode;
bool effectiveCompact = forceCompact || HasSetting(*settings, "compact");
- for (ui32 index = 0; index < aggregatedColumns->ChildrenSize(); ++index) {
- auto trait = aggregatedColumns->Child(index)->Child(1);
+ for (const auto& trait : traits) {
auto mergeLambda = trait->Child(5);
if (mergeLambda->Tail().IsCallable("Void")) {
effectiveCompact = true;
@@ -206,7 +250,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T
TExprNode::TListType nothingStates;
for (ui32 index = 0; index < aggregatedColumns->ChildrenSize(); ++index) {
- auto trait = aggregatedColumns->Child(index)->Child(1);
+ auto trait = traits[index];
auto saveLambda = trait->Child(3);
auto saveLambdaType = saveLambda->GetTypeAnn();
@@ -232,7 +276,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
ui32 ndx = 0;
for (ui32 i: nondistinctColumns) {
- auto trait = aggregatedColumns->Child(i)->Child(1);
+ auto trait = traits[i];
auto initLambda = trait->Child(1);
if (initLambda->Head().ChildrenSize() == 1) {
parent.List(ndx++)
@@ -280,7 +324,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
ui32 ndx = 0;
for (ui32 i: nondistinctColumns) {
- auto trait = aggregatedColumns->Child(i)->Child(1);
+ auto trait = traits[i];
auto updateLambda = trait->Child(2);
if (updateLambda->Head().ChildrenSize() == 2) {
parent.List(ndx++)
@@ -345,7 +389,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T
.Add(1, nothingStates[i])
.Seal();
} else {
- auto trait = aggregatedColumns->Child(i)->Child(1);
+ auto trait = traits[i];
auto saveLambda = trait->Child(3);
if (!distinctFields.empty()) {
parent.List(i)
@@ -579,7 +623,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
ui32 ndx = 0;
for (ui32 i: indicies) {
- auto trait = aggregatedColumns->Child(i)->Child(1);
+ auto trait = traits[i];
auto initLambda = trait->Child(1);
if (initLambda->Head().ChildrenSize() == 1) {
parent.List(ndx++)
@@ -636,7 +680,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T
.Callable(0, "AsStruct")
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
for (ui32 i: indicies) {
- auto trait = aggregatedColumns->Child(i)->Child(1);
+ auto trait = traits[i];
auto saveLambda = trait->Child(3);
parent.List(ndx++)
.Add(0, initialColumnNames[i])
@@ -934,7 +978,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
for (ui32 i = 0; i < initialColumnNames.size(); ++i) {
auto child = aggregatedColumns->Child(i);
- auto trait = child->Child(1);
+ auto trait = traits[i];
if (!compact) {
auto loadLambda = trait->Child(4);
@@ -1093,7 +1137,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
for (ui32 i = 0; i < initialColumnNames.size(); ++i) {
auto child = aggregatedColumns->Child(i);
- auto trait = child->Child(1);
+ auto trait = traits[i];
if (!compact) {
auto loadLambda = trait->Child(4);
auto mergeLambda = trait->Child(5);
@@ -1345,7 +1389,7 @@ TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, T
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
for (ui32 i = 0; i < initialColumnNames.size(); ++i) {
auto child = aggregatedColumns->Child(i);
- auto trait = child->Child(1);
+ auto trait = traits[i];
auto finishLambda = trait->Child(6);
if (!compact && !distinctFields.empty()) {
diff --git a/ydb/library/yql/core/yql_opt_aggregate.h b/ydb/library/yql/core/yql_opt_aggregate.h
index 9707405544c..f1bb5d51345 100644
--- a/ydb/library/yql/core/yql_opt_aggregate.h
+++ b/ydb/library/yql/core/yql_opt_aggregate.h
@@ -1,11 +1,13 @@
#pragma once
#include <ydb/library/yql/core/expr_nodes/yql_expr_nodes.h>
+#include "yql_type_annotation.h"
+
namespace NYql {
-TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, TExprContext& ctx, bool forceCompact = false);
-inline TExprNode::TPtr ExpandAggregatePeephole(const TExprNode::TPtr& node, TExprContext& ctx) {
- return ExpandAggregate(false, node, ctx, true);
+TExprNode::TPtr ExpandAggregate(bool allowPickle, const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, bool forceCompact = false);
+inline TExprNode::TPtr ExpandAggregatePeephole(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx) {
+ return ExpandAggregate(false, node, ctx, typesCtx, true);
}
}
diff --git a/ydb/library/yql/dq/opt/dq_opt_log.cpp b/ydb/library/yql/dq/opt/dq_opt_log.cpp
index 10feb101ed0..05bee5130b9 100644
--- a/ydb/library/yql/dq/opt/dq_opt_log.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_log.cpp
@@ -13,12 +13,12 @@ using namespace NYql::NNodes;
namespace NYql::NDq {
-TExprBase DqRewriteAggregate(TExprBase node, TExprContext& ctx) {
+TExprBase DqRewriteAggregate(TExprBase node, TExprContext& ctx, TTypeAnnotationContext& typesCtx) {
if (!node.Maybe<TCoAggregate>()) {
return node;
}
- auto result = ExpandAggregate(true, node.Ptr(), ctx);
+ auto result = ExpandAggregate(true, node.Ptr(), ctx, typesCtx);
YQL_ENSURE(result);
return TExprBase(result);
diff --git a/ydb/library/yql/dq/opt/dq_opt_log.h b/ydb/library/yql/dq/opt/dq_opt_log.h
index 8e9ed97b063..024ab5551e9 100644
--- a/ydb/library/yql/dq/opt/dq_opt_log.h
+++ b/ydb/library/yql/dq/opt/dq_opt_log.h
@@ -12,7 +12,7 @@ namespace NYql {
namespace NYql::NDq {
-NNodes::TExprBase DqRewriteAggregate(NNodes::TExprBase node, TExprContext& ctx);
+NNodes::TExprBase DqRewriteAggregate(NNodes::TExprBase node, TExprContext& ctx, TTypeAnnotationContext& typesCtx);
NNodes::TExprBase DqRewriteTakeSortToTopSort(NNodes::TExprBase node, TExprContext& ctx, const TParentsMap& parents);
diff --git a/ydb/library/yql/providers/dq/opt/logical_optimize.cpp b/ydb/library/yql/providers/dq/opt/logical_optimize.cpp
index 91b394273e4..c6024d4ade7 100644
--- a/ydb/library/yql/providers/dq/opt/logical_optimize.cpp
+++ b/ydb/library/yql/providers/dq/opt/logical_optimize.cpp
@@ -81,7 +81,7 @@ protected:
if (hopSetting) {
return RewriteAsHoppingWindow(node, ctx, input.Cast()).Cast();
} else {
- return DqRewriteAggregate(node, ctx);
+ return DqRewriteAggregate(node, ctx, *Types);
}
}
return node;
diff --git a/ydb/library/yql/sql/v1/aggregation.cpp b/ydb/library/yql/sql/v1/aggregation.cpp
index 64a9c657db9..72079f08d10 100644
--- a/ydb/library/yql/sql/v1/aggregation.cpp
+++ b/ydb/library/yql/sql/v1/aggregation.cpp
@@ -29,6 +29,10 @@ namespace {
}
}
+static const THashSet<TString> AggApplyFuncs = {
+ "count_traits_factory"
+};
+
class TAggregationFactory : public IAggregation {
public:
TAggregationFactory(TPosition pos, const TString& name, const TString& func, EAggregateMode aggMode,
@@ -37,6 +41,10 @@ public:
BuildBind(Pos, aggMode == EAggregateMode::OverWindow ? "window_module" : "aggregate_module", func) : nullptr),
Multi(multi), ValidateArgs(validateArgs), DynamicFactory(!Factory)
{
+ if (!func.empty() && AggApplyFuncs.contains(func)) {
+ AggApplyName = func.substr(0, func.size() - 15);
+ }
+
if (!Factory) {
FakeSource = BuildFakeSource(pos);
}
@@ -44,6 +52,10 @@ public:
protected:
bool InitAggr(TContext& ctx, bool isFactory, ISource* src, TAstListNode& node, const TVector<TNodePtr>& exprs) override {
+ if (!ctx.EmitAggApply) {
+ AggApplyName = "";
+ }
+
if (ValidateArgs || isFactory) {
ui32 expectedArgs = ValidateArgs && !Factory ? 2 : (isFactory ? 0 : 1);
if (!Factory && ValidateArgs) {
@@ -79,6 +91,9 @@ protected:
Name = src->MakeLocalName(Name);
}
+ if (Expr && Expr->IsAsterisk() && AggApplyName == "count") {
+ AggApplyName = "count_all";
+ }
if (!Init(ctx, src)) {
return false;
@@ -100,6 +115,10 @@ protected:
TNodePtr GetApply(const TNodePtr& type) const override {
if (!Multi) {
+ if (!DynamicFactory && !AggApplyName.empty()) {
+ return Y("AggApply", Q(AggApplyName), Y("ListItemType", type), BuildLambda(Pos, Y("row"), Y("PersistableRepr", Expr)));
+ }
+
return Y("Apply", Factory, (DynamicFactory ? Y("ListItemType", type) : type),
BuildLambda(Pos, Y("row"), Y("PersistableRepr", Expr)));
}
@@ -183,6 +202,7 @@ protected:
TNodePtr Expr;
bool Multi;
bool ValidateArgs;
+ TString AggApplyName;
TVector<TNodePtr> Exprs;
private:
diff --git a/ydb/library/yql/sql/v1/context.cpp b/ydb/library/yql/sql/v1/context.cpp
index 72bd5c5f1dd..71c8d02a461 100644
--- a/ydb/library/yql/sql/v1/context.cpp
+++ b/ydb/library/yql/sql/v1/context.cpp
@@ -56,6 +56,7 @@ THashMap<TStringBuf, TPragmaField> CTX_PRAGMA_FIELDS = {
{"AnsiCurrentRow", &TContext::AnsiCurrentRow},
{"EmitStartsWith", &TContext::EmitStartsWith},
{"EnforceAnsiOrderByLimitInUnionAll", &TContext::EnforceAnsiOrderByLimitInUnionAll},
+ {"EmitAggApply", &TContext::EmitAggApply},
};
typedef TMaybe<bool> TContext::*TPragmaMaybeField;
diff --git a/ydb/library/yql/sql/v1/context.h b/ydb/library/yql/sql/v1/context.h
index bc1ef027507..bfddf56c12d 100644
--- a/ydb/library/yql/sql/v1/context.h
+++ b/ydb/library/yql/sql/v1/context.h
@@ -272,6 +272,7 @@ namespace NSQLTranslationV1 {
NYql::TWarningPolicy WarningPolicy;
TString PqReadByRtmrCluster;
bool EmitStartsWith = true;
+ bool EmitAggApply = false;
};
class TColumnRefScope {
diff --git a/ydb/library/yql/sql/v1/sql.cpp b/ydb/library/yql/sql/v1/sql.cpp
index 1945b734f23..134b153359b 100644
--- a/ydb/library/yql/sql/v1/sql.cpp
+++ b/ydb/library/yql/sql/v1/sql.cpp
@@ -9826,6 +9826,9 @@ TNodePtr TSqlQuery::PragmaStatement(const TRule_pragma_stmt& stmt, bool& success
} else if (normalizedPragma == "disableansicurrentrow") {
Ctx.AnsiCurrentRow = false;
Ctx.IncrementMonCounter("sql_pragma", "DisableAnsiCurrentRow");
+ } else if (normalizedPragma == "emitaggapply") {
+ Ctx.EmitAggApply = true;
+ Ctx.IncrementMonCounter("sql_pragma", "EmitAggApply");
} else {
Error() << "Unknown pragma: " << pragma;
Ctx.IncrementMonCounter("sql_errors", "UnknownPragma");