diff options
author | vvvv <vvvv@ydb.tech> | 2022-12-28 12:29:17 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2022-12-28 12:29:17 +0300 |
commit | 2612e0181f336d66c1d20040283a2bc333a6bba2 (patch) | |
tree | b25719426aeb4dc3630ce10804ab7b86f6d4d7b2 | |
parent | 8a658f1fe52a945383c412b8d610d43b0536c0ba (diff) | |
download | ydb-2612e0181f336d66c1d20040283a2bc333a6bba2.tar.gz |
initial support of generic distinct with GROUP BY
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_blocks.cpp | 31 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_core.cpp | 15 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_list.cpp | 71 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_list.h | 3 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_aggregate_expander.cpp | 91 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_aggregate_expander.h | 3 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_expr_type_annotation.cpp | 9 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_expr_type_annotation.h | 3 | ||||
-rw-r--r-- | ydb/library/yql/minikql/arrow/arrow_util.cpp | 16 | ||||
-rw-r--r-- | ydb/library/yql/minikql/arrow/arrow_util.h | 4 | ||||
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp | 254 | ||||
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_block_agg.h | 1 | ||||
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp | 1 | ||||
-rw-r--r-- | ydb/library/yql/minikql/mkql_program_builder.cpp | 42 | ||||
-rw-r--r-- | ydb/library/yql/minikql/mkql_program_builder.h | 2 | ||||
-rw-r--r-- | ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp | 31 |
16 files changed, 520 insertions, 57 deletions
diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp index 98025c96847..3ad8cdf39e9 100644 --- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp @@ -1,4 +1,5 @@ #include "type_ann_blocks.h" +#include "type_ann_list.h" #include <ydb/library/yql/core/expr_nodes/yql_expr_nodes.h> #include <ydb/library/yql/core/yql_expr_type_annotation.h> @@ -315,6 +316,12 @@ bool ValidateBlockAggs(TPositionHandle pos, const TTypeAnnotationNode::TListType return false; } + if (overState) { + if (!EnsureTupleSize(*agg, 2, ctx)) { + return false; + } + } + auto expectedCallable = overState ? "AggBlockApplyState" : "AggBlockApply"; if (!agg->Head().IsCallable(expectedCallable)) { ctx.AddError(TIssue(ctx.GetPosition(pos), TStringBuilder() << "Expected: " << expectedCallable)); @@ -445,12 +452,12 @@ IGraphTransformer::TStatus BlockCombineHashedWrapper(const TExprNode::TPtr& inpu IGraphTransformer::TStatus BlockMergeFinalizeHashedWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) { Y_UNUSED(output); const bool many = input->Content().EndsWith("ManyFinalizeHashed"); - if (!EnsureArgsCount(*input, 3U, ctx.Expr)) { + if (!EnsureArgsCount(*input, many ? 5U : 3U, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } TTypeAnnotationNode::TListType blockItemTypes; - if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) { + if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr, false, !many)) { return IGraphTransformer::TStatus::Error; } @@ -467,6 +474,26 @@ IGraphTransformer::TStatus BlockMergeFinalizeHashedWrapper(const TExprNode::TPtr t = ctx.Expr.MakeType<TBlockExprType>(t); } + if (many) { + if (!EnsureAtom(*input->Child(3), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + ui32 streamIndex; + if (!TryFromString(input->Child(3)->Content(), streamIndex) || streamIndex >= blockItemTypes.size()) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Child(3)->Pos()), "Bad stream index")); + return IGraphTransformer::TStatus::Error; + } + + if (!EnsureSpecificDataType(input->Child(3)->Pos(), *blockItemTypes[streamIndex], EDataSlot::Uint32, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!ValidateAggManyStreams(*input->Child(4), input->Child(2)->ChildrenSize(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + } + retMultiType.push_back(ctx.Expr.MakeType<TScalarExprType>(ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64))); auto outputItemType = ctx.Expr.MakeType<TMultiExprType>(retMultiType); input->SetTypeAnn(ctx.Expr.MakeType<TFlowExprType>(outputItemType)); diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp index 5aa6f048c0b..89d5b26259b 100644 --- a/ydb/library/yql/core/type_ann/type_ann_core.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp @@ -11591,13 +11591,6 @@ template <NKikimr::NUdf::EDataSlot DataSlot> Functions["CastStruct"] = &CastStructWrapper; Functions["AggregationTraits"] = &AggregationTraitsWrapper; Functions["MultiAggregate"] = &MultiAggregateWrapper; - Functions["Aggregate"] = &AggregateWrapper; - Functions["AggregateCombine"] = &AggregateWrapper; - Functions["AggregateCombineState"] = &AggregateWrapper; - Functions["AggregateMergeState"] = &AggregateWrapper; - Functions["AggregateFinalize"] = &AggregateWrapper; - Functions["AggregateMergeFinalize"] = &AggregateWrapper; - Functions["AggregateMergeManyFinalize"] = &AggregateWrapper; Functions["AggOverState"] = &AggOverStateWrapper; Functions["SqlAggregateAll"] = &SqlAggregateAllWrapper; Functions["CountedAggregateAll"] = &CountedAggregateAllWrapper; @@ -11885,6 +11878,14 @@ template <NKikimr::NUdf::EDataSlot DataSlot> ExtFunctions["SafeCast"] = &CastWrapper<false>; ExtFunctions["StrictCast"] = &CastWrapper<true>; + ExtFunctions["Aggregate"] = &AggregateWrapper; + ExtFunctions["AggregateCombine"] = &AggregateWrapper; + ExtFunctions["AggregateCombineState"] = &AggregateWrapper; + ExtFunctions["AggregateMergeState"] = &AggregateWrapper; + ExtFunctions["AggregateFinalize"] = &AggregateWrapper; + ExtFunctions["AggregateMergeFinalize"] = &AggregateWrapper; + ExtFunctions["AggregateMergeManyFinalize"] = &AggregateWrapper; + ColumnOrderFunctions["PgSetItem"] = &OrderForPgSetItem; ColumnOrderFunctions["AssumeColumnOrder"] = &OrderForAssumeColumnOrder; diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp index 210d5c95337..bcb66b0b740 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp @@ -4647,9 +4647,10 @@ namespace { return output ? IGraphTransformer::TStatus::Repeat : IGraphTransformer::TStatus::Error; } - IGraphTransformer::TStatus AggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { + IGraphTransformer::TStatus AggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) { TStringBuf suffix = input->Content(); YQL_ENSURE(suffix.SkipPrefix("Aggregate")); + const bool isMany = suffix == "MergeManyFinalize"; if (!EnsureMinArgsCount(*input, 3, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } @@ -4683,11 +4684,21 @@ namespace { return IGraphTransformer::TStatus::Repeat; } + if (isMany && ctx.Types.UseBlocks && !inputStructType->FindItem("_yql_group_stream_index")) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Missing service column: _yql_group_stream_index")); + return IGraphTransformer::TStatus::Error; + } + auto status = NormalizeTupleOfAtoms(input, 1, output, ctx.Expr); if (status != IGraphTransformer::TStatus::Ok) { return status; } + if (!EnsureTuple(*input->Child(2), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + if (input->ChildrenSize() < 4U) { auto children = input->ChildrenList(); children.push_back(ctx.Expr.NewList(input->Pos(), {})); @@ -4707,6 +4718,7 @@ namespace { } bool isHopping = false; + bool hasManyStreams = false; const auto settings = input->Child(3); if (!EnsureTuple(*settings, ctx.Expr)) { return IGraphTransformer::TStatus::Error; @@ -4823,6 +4835,21 @@ namespace { if (!EnsureTupleSize(*setting, 1, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } + } else if (settingName == "many_streams" && isMany) { + if (!EnsureTupleSize(*setting, 2, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + auto value = setting->ChildPtr(1); + if (!EnsureTuple(*value, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!ValidateAggManyStreams(*value, input->Child(2)->ChildrenSize(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + hasManyStreams = true; } else { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(setting->Head().Pos()), TStringBuilder() << "Unexpected setting: " << settingName)); @@ -4830,6 +4857,12 @@ namespace { } } + if (isMany && !hasManyStreams && ctx.Types.UseBlocks) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(settings->Head().Pos()), + "Missing setting: many_streams")); + return IGraphTransformer::TStatus::Error; + } + for (auto& child : input->Child(1)->Children()) { if (!EnsureAtom(*child, ctx.Expr)) { return IGraphTransformer::TStatus::Error; @@ -7113,5 +7146,41 @@ namespace { .Seal() .Build(); } + + bool ValidateAggManyStreams(const TExprNode& value, ui32 aggCount, TExprContext& ctx) { + THashSet<ui32> usedIdxs; + for (const auto& child : value.Children()) { + if (!EnsureTuple(*child, ctx)) { + return false; + } + + for (const auto& atom : child->Children()) { + if (!EnsureAtom(*atom, ctx)) { + return false; + } + + ui32 idx; + if (!TryFromString(atom->Content(), idx) || idx >= aggCount) { + ctx.AddError(TIssue(ctx.GetPosition(atom->Pos()), + TStringBuilder() << "Invalid aggregation index: " << atom->Content())); + return false; + } + + if (!usedIdxs.insert(idx).second) { + ctx.AddError(TIssue(ctx.GetPosition(atom->Pos()), + TStringBuilder() << "Duplication of aggregation index: " << atom->Content())); + return false; + } + } + } + + if (usedIdxs.size() != aggCount) { + ctx.AddError(TIssue(ctx.GetPosition(value.Pos()), + TStringBuilder() << "Mismatch of total aggregations count in streams, expected: " << aggCount << ", got: " << usedIdxs.size())); + return false; + } + + return true; + } } // namespace NTypeAnnImpl } diff --git a/ydb/library/yql/core/type_ann/type_ann_list.h b/ydb/library/yql/core/type_ann/type_ann_list.h index 6e04740df37..10df53318ca 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.h +++ b/ydb/library/yql/core/type_ann/type_ann_list.h @@ -10,6 +10,7 @@ namespace NTypeAnnImpl { IGraphTransformer::TStatus InferPositionalUnionType(TPositionHandle pos, const TExprNode::TListType& children, TColumnOrder& resultColumnOrder, const TStructExprType*& resultStructType, TExtContext& ctx); TExprNode::TPtr ExpandToWindowTraits(const TExprNode& input, TExprContext& ctx); + bool ValidateAggManyStreams(const TExprNode& value, ui32 aggCount, TExprContext& ctx); IGraphTransformer::TStatus FilterWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); template <bool InverseCondition> @@ -93,7 +94,7 @@ namespace NTypeAnnImpl { IGraphTransformer::TStatus AssumeColumnOrderWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); IGraphTransformer::TStatus AggregationTraitsWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus MultiAggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); - IGraphTransformer::TStatus AggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); + IGraphTransformer::TStatus AggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); IGraphTransformer::TStatus AggOverStateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus SqlAggregateAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus CountedAggregateAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); diff --git a/ydb/library/yql/core/yql_aggregate_expander.cpp b/ydb/library/yql/core/yql_aggregate_expander.cpp index 730811858d5..ab0fb4bef94 100644 --- a/ydb/library/yql/core/yql_aggregate_expander.cpp +++ b/ydb/library/yql/core/yql_aggregate_expander.cpp @@ -500,7 +500,7 @@ TExprNode::TPtr TAggregateExpander::GetFinalAggStateExtractor(ui32 i) { } TExprNode::TPtr TAggregateExpander::MakeInputBlocks(const TExprNode::TPtr& streamArg, TExprNode::TListType& keyIdxs, - TVector<TString>& outputColumns, TExprNode::TListType& aggs, bool overState, bool many) { + TVector<TString>& outputColumns, TExprNode::TListType& aggs, bool overState, bool many, ui32* streamIdxColumn) { auto flow = Ctx.NewCallable(Node->Pos(), "ToFlow", { streamArg }); TVector<TString> inputColumns; for (ui32 i = 0; i < RowType->GetSize(); ++i) { @@ -528,6 +528,16 @@ TExprNode::TPtr TAggregateExpander::MakeInputBlocks(const TExprNode::TPtr& strea outputColumns.push_back(TString(keyName)); } + if (many) { + auto rowIndex = RowType->FindItem("_yql_group_stream_index"); + YQL_ENSURE(rowIndex, "Unknown column: _yql_group_stream_index"); + if (streamIdxColumn) { + *streamIdxColumn = extractorRoots.size(); + } + + extractorRoots.push_back(extractorArgs[*rowIndex]); + } + bool supported = false; YQL_ENSURE(TypesCtx.ArrowResolver->AreTypesSupported(Ctx.GetPosition(Node->Pos()), allKeyTypes, supported, Ctx)); if (!supported) { @@ -1989,6 +1999,21 @@ TExprNode::TPtr TAggregateExpander::GenerateJustOverStates(const TExprNode::TPtr .Build(); } +TExprNode::TPtr TAggregateExpander::SerializeIdxSet(const TIdxSet& indicies) { + return Ctx.Builder(Node->Pos()) + .List() + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + ui32 pos = 0; + for (ui32 i : indicies) { + parent.Atom(pos++, ToString(i)); + } + + return parent; + }) + .Seal() + .Build(); +} + TExprNode::TPtr TAggregateExpander::GeneratePhases() { const bool many = HaveDistinct; YQL_CLOG(DEBUG, Core) << "Aggregate: generate " << (many ? "phases with distinct" : "simple phases"); @@ -2125,6 +2150,8 @@ TExprNode::TPtr TAggregateExpander::GeneratePhases() { // UnionAll // MergeManyFinalize TExprNode::TListType unionAllInputs; + TExprNode::TListType streams; + if (!NonDistinctColumns.empty()) { TExprNode::TListType combineColumns; for (ui32 i : NonDistinctColumns) { @@ -2146,6 +2173,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePhases() { .Build(); unionAllInputs.push_back(GenerateJustOverStates(combine, NonDistinctColumns)); + streams.push_back(SerializeIdxSet(NonDistinctColumns)); } for (ui32 index = 0; index < DistinctFields.size(); ++index) { @@ -2299,6 +2327,32 @@ TExprNode::TPtr TAggregateExpander::GeneratePhases() { .Build(); unionAllInputs.push_back(GenerateJustOverStates(combine, indicies)); + streams.push_back(SerializeIdxSet(indicies)); + } + + if (TypesCtx.UseBlocks) { + for (ui32 i = 0; i < unionAllInputs.size(); ++i) { + unionAllInputs[i] = Ctx.Builder(Node->Pos()) + .Callable("Map") + .Add(0, unionAllInputs[i]) + .Lambda(1) + .Param("row") + .Callable("AddMember") + .Arg(0, "row") + .Atom(1, "_yql_group_stream_index") + .Callable(2, "Uint32") + .Atom(0, ToString(i)) + .Seal() + .Seal() + .Seal() + .Seal() + .Build(); + } + } + + auto settings = Node->ChildPtr(3); + if (TypesCtx.UseBlocks) { + settings = AddSetting(*settings, Node->Pos(), "many_streams", Ctx.NewList(Node->Pos(), std::move(streams)), Ctx); } auto unionAll = Ctx.NewCallable(Node->Pos(), "UnionAll", std::move(unionAllInputs)); @@ -2307,7 +2361,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePhases() { .Add(0, unionAll) .Add(1, KeyColumns) .Add(2, Ctx.NewList(Node->Pos(), std::move(finalizeColumns))) - .Add(3, Node->ChildPtr(3)) + .Add(3, settings) .Seal() .Build(); @@ -2357,25 +2411,42 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockMergeFinalizeHashed() { return nullptr; } + bool isMany = Suffix == "MergeManyFinalize"; auto streamArg = Ctx.NewArgument(Node->Pos(), "stream"); TExprNode::TListType keyIdxs; TVector<TString> outputColumns; TExprNode::TListType aggs; - auto blocks = MakeInputBlocks(streamArg, keyIdxs, outputColumns, aggs, true, Suffix == "MergeManyFinalize"); + ui32 streamIdxColumn; + auto blocks = MakeInputBlocks(streamArg, keyIdxs, outputColumns, aggs, true, isMany, &streamIdxColumn); if (!blocks) { return nullptr; } - auto aggWideFlow = Ctx.Builder(Node->Pos()) - .Callable("WideFromBlocks") - .Callable(0, TStringBuilder() << "Block" << Suffix << "Hashed") - .Add(0, blocks) - .Add(1, Ctx.NewList(Node->Pos(), std::move(keyIdxs))) - .Add(2, Ctx.NewList(Node->Pos(), std::move(aggs))) - .Seal() + TExprNode::TPtr aggBlocks; + if (!isMany) { + aggBlocks = Ctx.Builder(Node->Pos()) + .Callable("BlockMergeFinalizeHashed") + .Add(0, blocks) + .Add(1, Ctx.NewList(Node->Pos(), std::move(keyIdxs))) + .Add(2, Ctx.NewList(Node->Pos(), std::move(aggs))) + .Seal() + .Build(); + } else { + auto manyStreamsSetting = GetSetting(*Node->Child(3), "many_streams"); + YQL_ENSURE(manyStreamsSetting, "Missing many_streams setting"); + + aggBlocks = Ctx.Builder(Node->Pos()) + .Callable("BlockMergeManyFinalizeHashed") + .Add(0, blocks) + .Add(1, Ctx.NewList(Node->Pos(), std::move(keyIdxs))) + .Add(2, Ctx.NewList(Node->Pos(), std::move(aggs))) + .Atom(3, ToString(streamIdxColumn)) + .Add(4, manyStreamsSetting->TailPtr()) .Seal() .Build(); + } + auto aggWideFlow = Ctx.NewCallable(Node->Pos(), "WideFromBlocks", { aggBlocks }); auto finalFlow = MakeNarrowMap(Node->Pos(), outputColumns, aggWideFlow, Ctx); auto root = Ctx.NewCallable(Node->Pos(), "FromFlow", { finalFlow }); auto lambdaStream = Ctx.NewLambda(Node->Pos(), Ctx.NewArguments(Node->Pos(), { streamArg }), std::move(root)); diff --git a/ydb/library/yql/core/yql_aggregate_expander.h b/ydb/library/yql/core/yql_aggregate_expander.h index 213abe4ad48..74106737b1e 100644 --- a/ydb/library/yql/core/yql_aggregate_expander.h +++ b/ydb/library/yql/core/yql_aggregate_expander.h @@ -76,12 +76,13 @@ private: TExprNode::TPtr GeneratePhases(); void GenerateInitForDistinct(TExprNodeBuilder& parent, ui32& ndx, const TIdxSet& indicies, const TExprNode::TPtr& distinctField); TExprNode::TPtr GenerateJustOverStates(const TExprNode::TPtr& input, const TIdxSet& indicies); + TExprNode::TPtr SerializeIdxSet(const TIdxSet& indicies); TExprNode::TPtr TryGenerateBlockCombineAllOrHashed(); TExprNode::TPtr TryGenerateBlockMergeFinalizeHashed(); TExprNode::TPtr TryGenerateBlockCombine(); TExprNode::TPtr TryGenerateBlockMergeFinalize(); TExprNode::TPtr MakeInputBlocks(const TExprNode::TPtr& streamArg, TExprNode::TListType& keyIdxs, - TVector<TString>& outputColumns, TExprNode::TListType& aggs, bool overState, bool many); + TVector<TString>& outputColumns, TExprNode::TListType& aggs, bool overState, bool many, ui32* streamIdxColumn = nullptr); private: static constexpr TStringBuf SessionStartMemberName = "_yql_group_session_start"; diff --git a/ydb/library/yql/core/yql_expr_type_annotation.cpp b/ydb/library/yql/core/yql_expr_type_annotation.cpp index 07688adebf0..607615739f0 100644 --- a/ydb/library/yql/core/yql_expr_type_annotation.cpp +++ b/ydb/library/yql/core/yql_expr_type_annotation.cpp @@ -2697,7 +2697,7 @@ bool EnsureWideFlowType(TPositionHandle position, const TTypeAnnotationNode& typ return true; } -bool EnsureWideFlowBlockType(const TExprNode& node, TTypeAnnotationNode::TListType& blockItemTypes, TExprContext& ctx, bool allowChunked) { +bool EnsureWideFlowBlockType(const TExprNode& node, TTypeAnnotationNode::TListType& blockItemTypes, TExprContext& ctx, bool allowChunked, bool allowScalar) { if (!EnsureWideFlowType(node, ctx)) { return false; } @@ -2709,12 +2709,17 @@ bool EnsureWideFlowBlockType(const TExprNode& node, TTypeAnnotationNode::TListTy } bool isScalar; - for (const auto& type : items) { + for (ui32 i = 0; i < items.size(); ++i) { + const auto& type = items[i]; if (!EnsureBlockOrScalarType(node.Pos(), *type, ctx, allowChunked)) { return false; } blockItemTypes.push_back(GetBlockItemType(*type, isScalar)); + if (!allowScalar && isScalar && (i + 1 != items.size())) { + ctx.AddError(TIssue(ctx.GetPosition(node.Pos()), "Scalars are not allowed")); + return false; + } } if (!isScalar) { diff --git a/ydb/library/yql/core/yql_expr_type_annotation.h b/ydb/library/yql/core/yql_expr_type_annotation.h index 9a0fa776425..4ac91b5ed26 100644 --- a/ydb/library/yql/core/yql_expr_type_annotation.h +++ b/ydb/library/yql/core/yql_expr_type_annotation.h @@ -118,7 +118,8 @@ bool EnsureFlowType(const TExprNode& node, TExprContext& ctx); bool EnsureFlowType(TPositionHandle position, const TTypeAnnotationNode& type, TExprContext& ctx); bool EnsureWideFlowType(const TExprNode& node, TExprContext& ctx); bool EnsureWideFlowType(TPositionHandle position, const TTypeAnnotationNode& type, TExprContext& ctx); -bool EnsureWideFlowBlockType(const TExprNode& node, TTypeAnnotationNode::TListType& blockItemTypes, TExprContext& ctx, bool allowChunked = false); +bool EnsureWideFlowBlockType(const TExprNode& node, TTypeAnnotationNode::TListType& blockItemTypes, TExprContext& ctx, + bool allowChunked = false, bool allowScalar = true); bool EnsureOptionalType(const TExprNode& node, TExprContext& ctx); bool EnsureOptionalType(TPositionHandle position, const TTypeAnnotationNode& type, TExprContext& ctx); bool EnsureType(const TExprNode& node, TExprContext& ctx); diff --git a/ydb/library/yql/minikql/arrow/arrow_util.cpp b/ydb/library/yql/minikql/arrow/arrow_util.cpp index d74d81fc740..2acd9654dc2 100644 --- a/ydb/library/yql/minikql/arrow/arrow_util.cpp +++ b/ydb/library/yql/minikql/arrow/arrow_util.cpp @@ -1,4 +1,5 @@ #include "arrow_util.h" +#include <ydb/library/yql/minikql/mkql_node_builder.h> #include <util/system/yassert.h> @@ -36,5 +37,20 @@ std::shared_ptr<arrow::ArrayData> Chop(std::shared_ptr<arrow::ArrayData>& data, return first; } +std::shared_ptr<arrow::ArrayData> Unwrap(const arrow::ArrayData& data, TType* itemType) { + bool isOptional; + auto unpacked = UnpackOptional(itemType, isOptional); + MKQL_ENSURE(isOptional, "Expected optional"); + if (unpacked->IsOptional() || unpacked->IsVariant()) { + MKQL_ENSURE(data.child_data.size() == 1, "Expected struct with one element"); + return data.child_data[0]; + } else { + auto buffers = data.buffers; + MKQL_ENSURE(buffers.size() >= 1, "Missing nullable bitmap"); + buffers[0] = nullptr; + return arrow::ArrayData::Make(data.type, data.length, buffers, data.child_data, data.dictionary, 0, data.offset); + } +} + } diff --git a/ydb/library/yql/minikql/arrow/arrow_util.h b/ydb/library/yql/minikql/arrow/arrow_util.h index b30dc6f3072..76fb9a9aef6 100644 --- a/ydb/library/yql/minikql/arrow/arrow_util.h +++ b/ydb/library/yql/minikql/arrow/arrow_util.h @@ -1,6 +1,7 @@ #pragma once #include <arrow/array/data.h> +#include <ydb/library/yql/minikql/mkql_node.h> namespace NKikimr::NMiniKQL { @@ -10,4 +11,7 @@ std::shared_ptr<arrow::ArrayData> DeepSlice(const std::shared_ptr<arrow::ArrayDa /// \brief Chops first len items of `data` as new ArrayData object std::shared_ptr<arrow::ArrayData> Chop(std::shared_ptr<arrow::ArrayData>& data, size_t len); +/// \brief Remove optional from `data` as new ArrayData object +std::shared_ptr<arrow::ArrayData> Unwrap(const arrow::ArrayData& data, TType* itemType); + } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp index 60531a0dbcd..6d4c8ab6fef 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp @@ -10,6 +10,7 @@ #include <ydb/library/yql/minikql/mkql_node_builder.h> #include <ydb/library/yql/minikql/arrow/arrow_defs.h> +#include <ydb/library/yql/minikql/arrow/arrow_util.h> #include <arrow/scalar.h> #include <arrow/array/array_primitive.h> @@ -302,6 +303,8 @@ namespace { template <typename T> struct TAggParams { std::unique_ptr<IPreparedBlockAggregator<T>> Prepared_; + ui32 Column_ = 0; + TType* StateType_ = nullptr; }; struct TKeyParams { @@ -673,10 +676,10 @@ TStringBuf GetKeyView(const TSSOKey& key) { return key.AsView(); } -template <typename TKey, typename TAggregator, typename TFixedAggState, bool UseSet, bool UseFilter, bool Finalize, typename TDerived> +template <typename TKey, typename TAggregator, typename TFixedAggState, bool UseSet, bool UseFilter, bool Finalize, bool Many, typename TDerived> class THashedWrapperBase : public TStatefulWideFlowComputationNode<TDerived> { public: - using TSelf = THashedWrapperBase<TKey, TAggregator, TFixedAggState, UseSet, UseFilter, Finalize, TDerived>; + using TSelf = THashedWrapperBase<TKey, TAggregator, TFixedAggState, UseSet, UseFilter, Finalize, Many, TDerived>; using TBase = TStatefulWideFlowComputationNode<TDerived>; static constexpr bool UseArena = !InlineAggState && std::is_same<TFixedAggState, TStateArena>::value; @@ -687,7 +690,9 @@ public: size_t width, const std::vector<TKeyParams>& keys, std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers, - TVector<TAggParams<TAggregator>>&& aggsParams) + TVector<TAggParams<TAggregator>>&& aggsParams, + ui32 streamIndex, + TVector<TVector<ui32>>&& streams) : TBase(mutables, flow, EValueRepresentation::Any) , Flow_(flow) , FilterColumn_(filterColumn) @@ -696,6 +701,8 @@ public: , Keys_(keys) , KeySerializers_(std::move(keySerializers)) , AggsParams_(std::move(aggsParams)) + , StreamIndex_(streamIndex) + , Streams_(std::move(streams)) { MKQL_ENSURE(Width_ > 0, "Missing block length column"); if constexpr (UseFilter) { @@ -742,6 +749,19 @@ public: } } + const ui32* streamIndexData = nullptr; + if constexpr (Many) { + auto streamIndexDatum = TArrowBlock::From(s.Values_[StreamIndex_]).GetDatum(); + MKQL_ENSURE(streamIndexDatum.is_array(), "Expected array"); + streamIndexData = streamIndexDatum.array()->template GetValues<ui32>(1); + s.UnwrappedValues_ = s.Values_; + for (const auto& p : AggsParams_) { + const auto& columnDatum = TArrowBlock::From(s.UnwrappedValues_[p.Column_]).GetDatum(); + MKQL_ENSURE(columnDatum.is_array(), "Expected array"); + s.UnwrappedValues_[p.Column_] = ctx.HolderFactory.CreateArrowBlock(Unwrap(*columnDatum.array(), p.StateType_)); + } + } + s.HasValues_ = true; TVector<arrow::Datum> keysDatum; keysDatum.reserve(Keys_.size()); @@ -778,9 +798,9 @@ public: } } else { if (!InlineAggState) { - Insert(*s.HashFixedMap_, key, row, output, s); + Insert(*s.HashFixedMap_, key, row, streamIndexData, output, s); } else { - Insert(*s.HashMap_, key, row, output, s); + Insert(*s.HashMap_, key, row, streamIndexData, output, s); } } } @@ -862,6 +882,8 @@ private: TVector<NUdf::TUnboxedValue> Values_; TVector<NUdf::TUnboxedValue*> ValuePointers_; TVector<std::unique_ptr<TAggregator>> Aggs_; + TVector<ui32> AggStateOffsets_; + TVector<NUdf::TUnboxedValue> UnwrappedValues_; bool IsFinished_ = false; bool HasValues_ = false; ui32 TotalStateSize_ = 0; @@ -870,19 +892,26 @@ private: std::unique_ptr<TFixedHashMapImpl<TKey, TFixedAggState, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>> HashFixedMap_; TPagedArena Arena_; - TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const TVector<TAggParams<TAggregator>>& params, TComputationContext& ctx) + TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const TVector<TAggParams<TAggregator>>& params, + const TVector<TVector<ui32>>& streams, TComputationContext& ctx) : TBase(memInfo) , Values_(width) , ValuePointers_(width) + , UnwrappedValues_(width) , Arena_(TlsAllocState) { for (size_t i = 0; i < width; ++i) { ValuePointers_[i] = &Values_[i]; } + if constexpr (Many) { + TotalStateSize_ += streams.size(); + } + for (const auto& p : params) { Aggs_.emplace_back(p.Prepared_->Make(ctx)); MKQL_ENSURE(Aggs_.back()->StateSize == p.Prepared_->StateSize, "State size mismatch"); + AggStateOffsets_.emplace_back(TotalStateSize_); TotalStateSize_ += Aggs_.back()->StateSize; } @@ -906,13 +935,13 @@ private: TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const { if (!state.HasValue()) { - state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, ctx); + state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, Streams_, ctx); } return *static_cast<TState*>(state.AsBoxed().Get()); } template <typename THash> - void Insert(THash& hash, const TKey& key, ui64 row, NUdf::TUnboxedValue*const* output, TState& s) const { + void Insert(THash& hash, const TKey& key, ui64 row, const ui32* streamIndexData, NUdf::TUnboxedValue*const* output, TState& s) const { bool isNew; auto iter = hash.Insert(key, isNew); char* payload = (char*)hash.GetMutablePayload(iter); @@ -926,16 +955,30 @@ private: ptr = payload; } - for (size_t i = 0; i < s.Aggs_.size(); ++i) { - if (output[Keys_.size() + i]) { - if constexpr (Finalize) { - s.Aggs_[i]->LoadState(ptr, s.Values_.data(), row); - } else { - s.Aggs_[i]->InitKey(ptr, s.Values_.data(), row); + if constexpr (Many) { + static_assert(Finalize); + ui32 currentStreamIndex = streamIndexData[row]; + MKQL_ENSURE(currentStreamIndex < Streams_.size(), "Invalid stream index"); + memset(ptr, 0, Streams_.size()); + ptr[currentStreamIndex] = 1; + + for (auto i : Streams_[currentStreamIndex]) { + if (output[Keys_.size() + i]) { + s.Aggs_[i]->LoadState(ptr + s.AggStateOffsets_[i], s.UnwrappedValues_.data(), row); } } + } else { + for (size_t i = 0; i < s.Aggs_.size(); ++i) { + if (output[Keys_.size() + i]) { + if constexpr (Finalize) { + s.Aggs_[i]->LoadState(ptr, s.Values_.data(), row); + } else { + s.Aggs_[i]->InitKey(ptr, s.Values_.data(), row); + } + } - ptr += s.Aggs_[i]->StateSize; + ptr += s.Aggs_[i]->StateSize; + } } if constexpr (std::is_same<TKey, TSSOKey>::value) { @@ -950,16 +993,35 @@ private: ptr = payload; } - for (size_t i = 0; i < s.Aggs_.size(); ++i) { - if (output[Keys_.size() + i]) { - if constexpr (Finalize) { - s.Aggs_[i]->UpdateState(ptr, s.Values_.data(), row); - } else { - s.Aggs_[i]->UpdateKey(ptr, s.Values_.data(), row); + if constexpr (Many) { + static_assert(Finalize); + ui32 currentStreamIndex = streamIndexData[row]; + MKQL_ENSURE(currentStreamIndex < Streams_.size(), "Invalid stream index"); + + bool isNewStream = !ptr[currentStreamIndex]; + ptr[currentStreamIndex] = 1; + + for (auto i : Streams_[currentStreamIndex]) { + if (output[Keys_.size() + i]) { + if (isNewStream) { + s.Aggs_[i]->LoadState(ptr + s.AggStateOffsets_[i], s.UnwrappedValues_.data(), row); + } else { + s.Aggs_[i]->UpdateState(ptr + s.AggStateOffsets_[i], s.UnwrappedValues_.data(), row); + } } } + } else { + for (size_t i = 0; i < s.Aggs_.size(); ++i) { + if (output[Keys_.size() + i]) { + if constexpr (Finalize) { + s.Aggs_[i]->UpdateState(ptr, s.Values_.data(), row); + } else { + s.Aggs_[i]->UpdateKey(ptr, s.Values_.data(), row); + } + } - ptr += s.Aggs_[i]->StateSize; + ptr += s.Aggs_[i]->StateSize; + } } } } @@ -987,6 +1049,14 @@ private: kb->Add(in); } + if constexpr (Many) { + for (ui32 i = 0; i < Streams_.size(); ++i) { + MKQL_ENSURE(ptr[i], "Missing partial aggregation state"); + } + + ptr += Streams_.size(); + } + for (size_t i = 0; i < s.Aggs_.size(); ++i) { if (output[Keys_.size() + i]) { aggBuilders[i]->Add(ptr); @@ -1009,13 +1079,15 @@ private: const std::vector<TKeyParams> Keys_; const TVector<TAggParams<TAggregator>> AggsParams_; std::vector<std::unique_ptr<IKeySerializer>> KeySerializers_; + const ui32 StreamIndex_; + const TVector<TVector<ui32>> Streams_; }; template <typename TKey, typename TFixedAggState, bool UseSet, bool UseFilter> -class TBlockCombineHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter>> { +class TBlockCombineHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, false, TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter>> { public: using TSelf = TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter>; - using TBase = THashedWrapperBase<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, TSelf>; + using TBase = THashedWrapperBase<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, false, TSelf>; TBlockCombineHashedWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, @@ -1024,15 +1096,15 @@ public: const std::vector<TKeyParams>& keys, std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers, TVector<TAggParams<IBlockAggregatorCombineKeys>>&& aggsParams) - : TBase(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams)) + : TBase(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams), 0, {}) {} }; template <typename TKey, typename TFixedAggState, bool UseSet> -class TBlockMergeFinalizeHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet>> { +class TBlockMergeFinalizeHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, false, TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet>> { public: using TSelf = TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet>; - using TBase = THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, TSelf>; + using TBase = THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, false, TSelf>; TBlockMergeFinalizeHashedWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, @@ -1040,7 +1112,24 @@ public: const std::vector<TKeyParams>& keys, std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers, TVector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams) - : TBase(mutables, flow, {}, width, keys, std::move(keySerializers), std::move(aggsParams)) + : TBase(mutables, flow, {}, width, keys, std::move(keySerializers), std::move(aggsParams), 0, {}) + {} +}; + +template <typename TKey, typename TFixedAggState> +class TBlockMergeManyFinalizeHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, false, false, true, true, TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState>> { +public: + using TSelf = TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState>; + using TBase = THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, false, false, true, true, TSelf>; + + TBlockMergeManyFinalizeHashedWrapper(TComputationMutables& mutables, + IComputationWideFlowNode* flow, + size_t width, + const std::vector<TKeyParams>& keys, + std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers, + TVector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams, + ui32 streamIndex, TVector<TVector<ui32>>&& streams) + : TBase(mutables, flow, {}, width, keys, std::move(keySerializers), std::move(aggsParams), streamIndex, std::move(streams)) {} }; @@ -1080,7 +1169,7 @@ std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorFinalizeKeys>> PrepareB } template <typename TAggregator> -ui32 FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, std::optional<ui32> filterColumn, TVector<TAggParams<TAggregator>>& aggsParams, const TTypeEnvironment& env) { +ui32 FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, std::optional<ui32> filterColumn, TVector<TAggParams<TAggregator>>& aggsParams, const TTypeEnvironment& env, bool overState) { ui32 totalStateSize = 0; for (ui32 i = 0; i < aggsVal->GetValuesCount(); ++i) { auto aggVal = AS_VALUE(TTupleLiteral, aggsVal->GetValue(i)); @@ -1093,6 +1182,12 @@ ui32 FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, std::optional< TAggParams<TAggregator> p; p.Prepared_ = PrepareBlockAggregator<TAggregator>(GetBlockAggregatorFactory(name), tupleType, filterColumn, argColumns, env); + if (overState) { + MKQL_ENSURE(argColumns.size() == 1, "Expected exactly one column"); + p.Column_ = argColumns[0]; + p.StateType_ = AS_TYPE(TBlockType, tupleType->GetElementType(p.Column_))->GetItemType(); + } + totalStateSize += p.Prepared_->StateSize; aggsParams.emplace_back(std::move(p)); } @@ -1185,6 +1280,51 @@ IComputationNode* MakeBlockMergeFinalizeHashedWrapper( return MakeBlockMergeFinalizeHashedWrapper<TSSOKey, UseSet>(totalStateSize, mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams)); } +template <typename TKey> +IComputationNode* MakeBlockMergeManyFinalizeHashedWrapper( + ui32 totalStateSize, + TComputationMutables& mutables, + IComputationWideFlowNode* flow, + size_t width, + const std::vector<TKeyParams>& keys, + std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers, + TVector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams, + ui32 streamIndex, + TVector<TVector<ui32>>&& streams) { + + if (totalStateSize <= sizeof(TState8)) { + return new TBlockMergeManyFinalizeHashedWrapper<TKey, TState8>(mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams), streamIndex, std::move(streams)); + } + + if (totalStateSize <= sizeof(TState16)) { + return new TBlockMergeManyFinalizeHashedWrapper<TKey, TState16>(mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams), streamIndex, std::move(streams)); + } + + return new TBlockMergeManyFinalizeHashedWrapper<TKey, TStateArena>(mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams), streamIndex, std::move(streams)); +} + +IComputationNode* MakeBlockMergeManyFinalizeHashedWrapper( + ui32 totalKeysSize, + ui32 totalStateSize, + TComputationMutables& mutables, + IComputationWideFlowNode* flow, + size_t width, + const std::vector<TKeyParams>& keys, + std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers, + TVector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams, + ui32 streamIndex, + TVector<TVector<ui32>>&& streams) { + if (totalKeysSize <= sizeof(ui32)) { + return MakeBlockMergeManyFinalizeHashedWrapper<ui32>(totalStateSize, mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams), streamIndex, std::move(streams)); + } + + if (totalKeysSize <= sizeof(ui64)) { + return MakeBlockMergeManyFinalizeHashedWrapper<ui64>(totalStateSize, mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams), streamIndex, std::move(streams)); + } + + return MakeBlockMergeManyFinalizeHashedWrapper<TSSOKey>(totalStateSize, mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams), streamIndex, std::move(streams)); +} + void PrepareKeys(const std::vector<TKeyParams>& keys, ui32& totalKeysSize, std::vector<std::unique_ptr<IKeySerializer>>& keySerializers) { totalKeysSize = 0; keySerializers.clear(); @@ -1280,6 +1420,19 @@ void PrepareKeys(const std::vector<TKeyParams>& keys, ui32& totalKeysSize, std:: } } +void FillAggStreams(TRuntimeNode streamsNode, TVector<TVector<ui32>>& streams) { + auto streamsVal = AS_VALUE(TTupleLiteral, streamsNode); + for (ui32 i = 0; i < streamsVal->GetValuesCount(); ++i) { + streams.emplace_back(); + auto& stream = streams.back(); + auto streamVal = AS_VALUE(TTupleLiteral, streamsVal->GetValue(i)); + for (ui32 j = 0; j < streamVal->GetValuesCount(); ++j) { + ui32 index = AS_VALUE(TDataLiteral, streamVal->GetValue(j))->AsValue().Get<ui32>(); + stream.emplace_back(index); + } + } +} + } IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNodeFactoryContext& ctx) { @@ -1298,7 +1451,7 @@ IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNod auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2)); TVector<TAggParams<IBlockAggregatorCombineAll>> aggsParams; - ui32 totalStateSize = FillAggParams<IBlockAggregatorCombineAll>(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env); + ui32 totalStateSize = FillAggParams<IBlockAggregatorCombineAll>(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env, false); return new TBlockCombineAllWrapper(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), std::move(aggsParams)); } @@ -1325,7 +1478,7 @@ IComputationNode* WrapBlockCombineHashed(TCallable& callable, const TComputation auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(3)); TVector<TAggParams<IBlockAggregatorCombineKeys>> aggsParams; - ui32 totalStateSize = FillAggParams<IBlockAggregatorCombineKeys>(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env); + ui32 totalStateSize = FillAggParams<IBlockAggregatorCombineKeys>(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env, false); ui32 totalKeysSize = 0; std::vector<std::unique_ptr<IKeySerializer>> keySerializers; @@ -1363,7 +1516,7 @@ IComputationNode* WrapBlockMergeFinalizeHashed(TCallable& callable, const TCompu auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2)); TVector<TAggParams<IBlockAggregatorFinalizeKeys>> aggsParams; - ui32 totalStateSize = FillAggParams<IBlockAggregatorFinalizeKeys>(aggsVal, tupleType, {}, aggsParams, ctx.Env); + ui32 totalStateSize = FillAggParams<IBlockAggregatorFinalizeKeys>(aggsVal, tupleType, {}, aggsParams, ctx.Env, true); ui32 totalKeysSize = 0; std::vector<std::unique_ptr<IKeySerializer>> keySerializers; @@ -1376,5 +1529,42 @@ IComputationNode* WrapBlockMergeFinalizeHashed(TCallable& callable, const TCompu } } +IComputationNode* WrapBlockMergeManyFinalizeHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + MKQL_ENSURE(callable.GetInputsCount() == 5, "Expected 5 args"); + const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); + const auto tupleType = AS_TYPE(TTupleType, flowType->GetItemType()); + + auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); + MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + + auto keysVal = AS_VALUE(TTupleLiteral, callable.GetInput(1)); + std::vector<TKeyParams> keys; + for (ui32 i = 0; i < keysVal->GetValuesCount(); ++i) { + ui32 index = AS_VALUE(TDataLiteral, keysVal->GetValue(i))->AsValue().Get<ui32>(); + keys.emplace_back(TKeyParams{ index, tupleType->GetElementType(index) }); + } + + auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2)); + TVector<TAggParams<IBlockAggregatorFinalizeKeys>> aggsParams; + ui32 totalStateSize = FillAggParams<IBlockAggregatorFinalizeKeys>(aggsVal, tupleType, {}, aggsParams, ctx.Env, true); + + ui32 totalKeysSize = 0; + std::vector<std::unique_ptr<IKeySerializer>> keySerializers; + PrepareKeys(keys, totalKeysSize, keySerializers); + + ui32 streamIndex = AS_VALUE(TDataLiteral, callable.GetInput(3))->AsValue().Get<ui32>(); + TVector<TVector<ui32>> streams; + FillAggStreams(callable.GetInput(4), streams); + totalStateSize += streams.size(); + + if (aggsParams.size() == 0) { + return MakeBlockMergeFinalizeHashedWrapper<true>(totalKeysSize, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), + keys, std::move(keySerializers), std::move(aggsParams)); + } else { + return MakeBlockMergeManyFinalizeHashedWrapper(totalKeysSize, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), + keys, std::move(keySerializers), std::move(aggsParams), streamIndex, std::move(streams)); + } +} + } } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.h index 872788babec..6bd98eace24 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.h +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.h @@ -8,6 +8,7 @@ namespace NMiniKQL { IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNodeFactoryContext& ctx); IComputationNode* WrapBlockCombineHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx); IComputationNode* WrapBlockMergeFinalizeHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx); +IComputationNode* WrapBlockMergeManyFinalizeHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx); } } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp index fc79af03498..e29c051c971 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp @@ -285,6 +285,7 @@ struct TCallableComputationNodeBuilderFuncMapFiller { {"BlockCombineAll", &WrapBlockCombineAll}, {"BlockCombineHashed", &WrapBlockCombineHashed}, {"BlockMergeFinalizeHashed", &WrapBlockMergeFinalizeHashed}, + {"BlockMergeManyFinalizeHashed", &WrapBlockMergeManyFinalizeHashed}, {"MakeHeap", &WrapMakeHeap}, {"PushHeap", &WrapPushHeap}, {"PopHeap", &WrapPopHeap}, diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index 6403fc25ec8..bb9bf360732 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -5409,6 +5409,48 @@ TRuntimeNode TProgramBuilder::BlockMergeFinalizeHashed(TRuntimeNode flow, const return TRuntimeNode(builder.Build(), false); } +TRuntimeNode TProgramBuilder::BlockMergeManyFinalizeHashed(TRuntimeNode flow, const TArrayRef<ui32>& keys, + const TArrayRef<const TAggInfo>& aggs, ui32 streamIndex, const TVector<TVector<ui32>>& streams, TType* returnType) { + if constexpr (RuntimeVersion < 31U) { + THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__; + } + + TCallableBuilder builder(Env, __func__, returnType); + builder.Add(flow); + + TVector<TRuntimeNode> keyNodes; + for (const auto& key : keys) { + keyNodes.push_back(NewDataLiteral<ui32>(key)); + } + + builder.Add(NewTuple(keyNodes)); + TVector<TRuntimeNode> aggsNodes; + for (const auto& agg : aggs) { + TVector<TRuntimeNode> params; + params.push_back(NewDataLiteral<NUdf::EDataSlot::String>(agg.Name)); + for (const auto& col : agg.ArgsColumns) { + params.push_back(NewDataLiteral<ui32>(col)); + } + + aggsNodes.push_back(NewTuple(params)); + } + + builder.Add(NewTuple(aggsNodes)); + builder.Add(NewDataLiteral<ui32>(streamIndex)); + TVector<TRuntimeNode> streamsNodes; + for (const auto& s : streams) { + TVector<TRuntimeNode> streamNodes; + for (const auto& i : s) { + streamNodes.push_back(NewDataLiteral<ui32>(i)); + } + + streamsNodes.push_back(NewTuple(streamNodes)); + } + + builder.Add(NewTuple(streamsNodes)); + return TRuntimeNode(builder.Build(), false); +} + bool CanExportType(TType* type, const TTypeEnvironment& env) { if (type->GetKind() == TType::EKind::Type) { return false; // Type of Type diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h index 407db818dae..d037a3cfd47 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.h +++ b/ydb/library/yql/minikql/mkql_program_builder.h @@ -264,6 +264,8 @@ public: const TArrayRef<const TAggInfo>& aggs, TType* returnType); TRuntimeNode BlockMergeFinalizeHashed(TRuntimeNode flow, const TArrayRef<ui32>& keys, const TArrayRef<const TAggInfo>& aggs, TType* returnType); + TRuntimeNode BlockMergeManyFinalizeHashed(TRuntimeNode flow, const TArrayRef<ui32>& keys, + const TArrayRef<const TAggInfo>& aggs, ui32 streamIndex, const TVector<TVector<ui32>>& streams, TType* returnType); // udfs TRuntimeNode Udf( diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp index 84421c87b9f..cec8f7e8633 100644 --- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp +++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp @@ -2475,6 +2475,37 @@ TMkqlCommonCallableCompiler::TShared::TShared() { return ctx.ProgramBuilder.BlockMergeFinalizeHashed(arg, keys, aggs, returnType); }); + AddCallable("BlockMergeManyFinalizeHashed", [](const TExprNode& node, TMkqlBuildContext& ctx) { + auto arg = MkqlBuildExpr(*node.Child(0), ctx); + TVector<ui32> keys; + for (const auto& key : node.Child(1)->Children()) { + keys.push_back(FromString<ui32>(key->Content())); + } + + TVector<TAggInfo> aggs; + for (const auto& agg : node.Child(2)->Children()) { + TAggInfo info; + info.Name = TString(agg->Head().Head().Content()); + for (ui32 i = 1; i < agg->ChildrenSize(); ++i) { + info.ArgsColumns.push_back(FromString<ui32>(agg->Child(i)->Content())); + } + + aggs.push_back(info); + } + + ui32 streamIndex = FromString<ui32>(node.Child(3)->Content()); + TVector<TVector<ui32>> streams; + for (const auto& child : node.Child(4)->Children()) { + auto& stream = streams.emplace_back(); + for (const auto& atom : child->Children()) { + stream.emplace_back(FromString<ui32>(atom->Content())); + } + } + + auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + return ctx.ProgramBuilder.BlockMergeManyFinalizeHashed(arg, keys, aggs, streamIndex, streams, returnType); + }); + AddCallable("BlockCompress", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto flow = MkqlBuildExpr(node.Head(), ctx); const auto index = FromString<ui32>(node.Child(1)->Content()); |