aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <vvvv@ydb.tech>2022-12-28 12:29:17 +0300
committervvvv <vvvv@ydb.tech>2022-12-28 12:29:17 +0300
commit2612e0181f336d66c1d20040283a2bc333a6bba2 (patch)
treeb25719426aeb4dc3630ce10804ab7b86f6d4d7b2
parent8a658f1fe52a945383c412b8d610d43b0536c0ba (diff)
downloadydb-2612e0181f336d66c1d20040283a2bc333a6bba2.tar.gz
initial support of generic distinct with GROUP BY
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_blocks.cpp31
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_core.cpp15
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_list.cpp71
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_list.h3
-rw-r--r--ydb/library/yql/core/yql_aggregate_expander.cpp91
-rw-r--r--ydb/library/yql/core/yql_aggregate_expander.h3
-rw-r--r--ydb/library/yql/core/yql_expr_type_annotation.cpp9
-rw-r--r--ydb/library/yql/core/yql_expr_type_annotation.h3
-rw-r--r--ydb/library/yql/minikql/arrow/arrow_util.cpp16
-rw-r--r--ydb/library/yql/minikql/arrow/arrow_util.h4
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp254
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg.h1
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp1
-rw-r--r--ydb/library/yql/minikql/mkql_program_builder.cpp42
-rw-r--r--ydb/library/yql/minikql/mkql_program_builder.h2
-rw-r--r--ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp31
16 files changed, 520 insertions, 57 deletions
diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
index 98025c96847..3ad8cdf39e9 100644
--- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
@@ -1,4 +1,5 @@
#include "type_ann_blocks.h"
+#include "type_ann_list.h"
#include <ydb/library/yql/core/expr_nodes/yql_expr_nodes.h>
#include <ydb/library/yql/core/yql_expr_type_annotation.h>
@@ -315,6 +316,12 @@ bool ValidateBlockAggs(TPositionHandle pos, const TTypeAnnotationNode::TListType
return false;
}
+ if (overState) {
+ if (!EnsureTupleSize(*agg, 2, ctx)) {
+ return false;
+ }
+ }
+
auto expectedCallable = overState ? "AggBlockApplyState" : "AggBlockApply";
if (!agg->Head().IsCallable(expectedCallable)) {
ctx.AddError(TIssue(ctx.GetPosition(pos), TStringBuilder() << "Expected: " << expectedCallable));
@@ -445,12 +452,12 @@ IGraphTransformer::TStatus BlockCombineHashedWrapper(const TExprNode::TPtr& inpu
IGraphTransformer::TStatus BlockMergeFinalizeHashedWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) {
Y_UNUSED(output);
const bool many = input->Content().EndsWith("ManyFinalizeHashed");
- if (!EnsureArgsCount(*input, 3U, ctx.Expr)) {
+ if (!EnsureArgsCount(*input, many ? 5U : 3U, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
TTypeAnnotationNode::TListType blockItemTypes;
- if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) {
+ if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr, false, !many)) {
return IGraphTransformer::TStatus::Error;
}
@@ -467,6 +474,26 @@ IGraphTransformer::TStatus BlockMergeFinalizeHashedWrapper(const TExprNode::TPtr
t = ctx.Expr.MakeType<TBlockExprType>(t);
}
+ if (many) {
+ if (!EnsureAtom(*input->Child(3), ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ ui32 streamIndex;
+ if (!TryFromString(input->Child(3)->Content(), streamIndex) || streamIndex >= blockItemTypes.size()) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Child(3)->Pos()), "Bad stream index"));
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!EnsureSpecificDataType(input->Child(3)->Pos(), *blockItemTypes[streamIndex], EDataSlot::Uint32, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!ValidateAggManyStreams(*input->Child(4), input->Child(2)->ChildrenSize(), ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+ }
+
retMultiType.push_back(ctx.Expr.MakeType<TScalarExprType>(ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64)));
auto outputItemType = ctx.Expr.MakeType<TMultiExprType>(retMultiType);
input->SetTypeAnn(ctx.Expr.MakeType<TFlowExprType>(outputItemType));
diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp
index 5aa6f048c0b..89d5b26259b 100644
--- a/ydb/library/yql/core/type_ann/type_ann_core.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp
@@ -11591,13 +11591,6 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
Functions["CastStruct"] = &CastStructWrapper;
Functions["AggregationTraits"] = &AggregationTraitsWrapper;
Functions["MultiAggregate"] = &MultiAggregateWrapper;
- Functions["Aggregate"] = &AggregateWrapper;
- Functions["AggregateCombine"] = &AggregateWrapper;
- Functions["AggregateCombineState"] = &AggregateWrapper;
- Functions["AggregateMergeState"] = &AggregateWrapper;
- Functions["AggregateFinalize"] = &AggregateWrapper;
- Functions["AggregateMergeFinalize"] = &AggregateWrapper;
- Functions["AggregateMergeManyFinalize"] = &AggregateWrapper;
Functions["AggOverState"] = &AggOverStateWrapper;
Functions["SqlAggregateAll"] = &SqlAggregateAllWrapper;
Functions["CountedAggregateAll"] = &CountedAggregateAllWrapper;
@@ -11885,6 +11878,14 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
ExtFunctions["SafeCast"] = &CastWrapper<false>;
ExtFunctions["StrictCast"] = &CastWrapper<true>;
+ ExtFunctions["Aggregate"] = &AggregateWrapper;
+ ExtFunctions["AggregateCombine"] = &AggregateWrapper;
+ ExtFunctions["AggregateCombineState"] = &AggregateWrapper;
+ ExtFunctions["AggregateMergeState"] = &AggregateWrapper;
+ ExtFunctions["AggregateFinalize"] = &AggregateWrapper;
+ ExtFunctions["AggregateMergeFinalize"] = &AggregateWrapper;
+ ExtFunctions["AggregateMergeManyFinalize"] = &AggregateWrapper;
+
ColumnOrderFunctions["PgSetItem"] = &OrderForPgSetItem;
ColumnOrderFunctions["AssumeColumnOrder"] = &OrderForAssumeColumnOrder;
diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp
index 210d5c95337..bcb66b0b740 100644
--- a/ydb/library/yql/core/type_ann/type_ann_list.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp
@@ -4647,9 +4647,10 @@ namespace {
return output ? IGraphTransformer::TStatus::Repeat : IGraphTransformer::TStatus::Error;
}
- IGraphTransformer::TStatus AggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
+ IGraphTransformer::TStatus AggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) {
TStringBuf suffix = input->Content();
YQL_ENSURE(suffix.SkipPrefix("Aggregate"));
+ const bool isMany = suffix == "MergeManyFinalize";
if (!EnsureMinArgsCount(*input, 3, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
@@ -4683,11 +4684,21 @@ namespace {
return IGraphTransformer::TStatus::Repeat;
}
+ if (isMany && ctx.Types.UseBlocks && !inputStructType->FindItem("_yql_group_stream_index")) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ TStringBuilder() << "Missing service column: _yql_group_stream_index"));
+ return IGraphTransformer::TStatus::Error;
+ }
+
auto status = NormalizeTupleOfAtoms(input, 1, output, ctx.Expr);
if (status != IGraphTransformer::TStatus::Ok) {
return status;
}
+ if (!EnsureTuple(*input->Child(2), ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
if (input->ChildrenSize() < 4U) {
auto children = input->ChildrenList();
children.push_back(ctx.Expr.NewList(input->Pos(), {}));
@@ -4707,6 +4718,7 @@ namespace {
}
bool isHopping = false;
+ bool hasManyStreams = false;
const auto settings = input->Child(3);
if (!EnsureTuple(*settings, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
@@ -4823,6 +4835,21 @@ namespace {
if (!EnsureTupleSize(*setting, 1, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
+ } else if (settingName == "many_streams" && isMany) {
+ if (!EnsureTupleSize(*setting, 2, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ auto value = setting->ChildPtr(1);
+ if (!EnsureTuple(*value, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!ValidateAggManyStreams(*value, input->Child(2)->ChildrenSize(), ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ hasManyStreams = true;
} else {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(setting->Head().Pos()),
TStringBuilder() << "Unexpected setting: " << settingName));
@@ -4830,6 +4857,12 @@ namespace {
}
}
+ if (isMany && !hasManyStreams && ctx.Types.UseBlocks) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(settings->Head().Pos()),
+ "Missing setting: many_streams"));
+ return IGraphTransformer::TStatus::Error;
+ }
+
for (auto& child : input->Child(1)->Children()) {
if (!EnsureAtom(*child, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
@@ -7113,5 +7146,41 @@ namespace {
.Seal()
.Build();
}
+
+ bool ValidateAggManyStreams(const TExprNode& value, ui32 aggCount, TExprContext& ctx) {
+ THashSet<ui32> usedIdxs;
+ for (const auto& child : value.Children()) {
+ if (!EnsureTuple(*child, ctx)) {
+ return false;
+ }
+
+ for (const auto& atom : child->Children()) {
+ if (!EnsureAtom(*atom, ctx)) {
+ return false;
+ }
+
+ ui32 idx;
+ if (!TryFromString(atom->Content(), idx) || idx >= aggCount) {
+ ctx.AddError(TIssue(ctx.GetPosition(atom->Pos()),
+ TStringBuilder() << "Invalid aggregation index: " << atom->Content()));
+ return false;
+ }
+
+ if (!usedIdxs.insert(idx).second) {
+ ctx.AddError(TIssue(ctx.GetPosition(atom->Pos()),
+ TStringBuilder() << "Duplication of aggregation index: " << atom->Content()));
+ return false;
+ }
+ }
+ }
+
+ if (usedIdxs.size() != aggCount) {
+ ctx.AddError(TIssue(ctx.GetPosition(value.Pos()),
+ TStringBuilder() << "Mismatch of total aggregations count in streams, expected: " << aggCount << ", got: " << usedIdxs.size()));
+ return false;
+ }
+
+ return true;
+ }
} // namespace NTypeAnnImpl
}
diff --git a/ydb/library/yql/core/type_ann/type_ann_list.h b/ydb/library/yql/core/type_ann/type_ann_list.h
index 6e04740df37..10df53318ca 100644
--- a/ydb/library/yql/core/type_ann/type_ann_list.h
+++ b/ydb/library/yql/core/type_ann/type_ann_list.h
@@ -10,6 +10,7 @@ namespace NTypeAnnImpl {
IGraphTransformer::TStatus InferPositionalUnionType(TPositionHandle pos, const TExprNode::TListType& children,
TColumnOrder& resultColumnOrder, const TStructExprType*& resultStructType, TExtContext& ctx);
TExprNode::TPtr ExpandToWindowTraits(const TExprNode& input, TExprContext& ctx);
+ bool ValidateAggManyStreams(const TExprNode& value, ui32 aggCount, TExprContext& ctx);
IGraphTransformer::TStatus FilterWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
template <bool InverseCondition>
@@ -93,7 +94,7 @@ namespace NTypeAnnImpl {
IGraphTransformer::TStatus AssumeColumnOrderWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx);
IGraphTransformer::TStatus AggregationTraitsWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus MultiAggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
- IGraphTransformer::TStatus AggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
+ IGraphTransformer::TStatus AggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx);
IGraphTransformer::TStatus AggOverStateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus SqlAggregateAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus CountedAggregateAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
diff --git a/ydb/library/yql/core/yql_aggregate_expander.cpp b/ydb/library/yql/core/yql_aggregate_expander.cpp
index 730811858d5..ab0fb4bef94 100644
--- a/ydb/library/yql/core/yql_aggregate_expander.cpp
+++ b/ydb/library/yql/core/yql_aggregate_expander.cpp
@@ -500,7 +500,7 @@ TExprNode::TPtr TAggregateExpander::GetFinalAggStateExtractor(ui32 i) {
}
TExprNode::TPtr TAggregateExpander::MakeInputBlocks(const TExprNode::TPtr& streamArg, TExprNode::TListType& keyIdxs,
- TVector<TString>& outputColumns, TExprNode::TListType& aggs, bool overState, bool many) {
+ TVector<TString>& outputColumns, TExprNode::TListType& aggs, bool overState, bool many, ui32* streamIdxColumn) {
auto flow = Ctx.NewCallable(Node->Pos(), "ToFlow", { streamArg });
TVector<TString> inputColumns;
for (ui32 i = 0; i < RowType->GetSize(); ++i) {
@@ -528,6 +528,16 @@ TExprNode::TPtr TAggregateExpander::MakeInputBlocks(const TExprNode::TPtr& strea
outputColumns.push_back(TString(keyName));
}
+ if (many) {
+ auto rowIndex = RowType->FindItem("_yql_group_stream_index");
+ YQL_ENSURE(rowIndex, "Unknown column: _yql_group_stream_index");
+ if (streamIdxColumn) {
+ *streamIdxColumn = extractorRoots.size();
+ }
+
+ extractorRoots.push_back(extractorArgs[*rowIndex]);
+ }
+
bool supported = false;
YQL_ENSURE(TypesCtx.ArrowResolver->AreTypesSupported(Ctx.GetPosition(Node->Pos()), allKeyTypes, supported, Ctx));
if (!supported) {
@@ -1989,6 +1999,21 @@ TExprNode::TPtr TAggregateExpander::GenerateJustOverStates(const TExprNode::TPtr
.Build();
}
+TExprNode::TPtr TAggregateExpander::SerializeIdxSet(const TIdxSet& indicies) {
+ return Ctx.Builder(Node->Pos())
+ .List()
+ .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
+ ui32 pos = 0;
+ for (ui32 i : indicies) {
+ parent.Atom(pos++, ToString(i));
+ }
+
+ return parent;
+ })
+ .Seal()
+ .Build();
+}
+
TExprNode::TPtr TAggregateExpander::GeneratePhases() {
const bool many = HaveDistinct;
YQL_CLOG(DEBUG, Core) << "Aggregate: generate " << (many ? "phases with distinct" : "simple phases");
@@ -2125,6 +2150,8 @@ TExprNode::TPtr TAggregateExpander::GeneratePhases() {
// UnionAll
// MergeManyFinalize
TExprNode::TListType unionAllInputs;
+ TExprNode::TListType streams;
+
if (!NonDistinctColumns.empty()) {
TExprNode::TListType combineColumns;
for (ui32 i : NonDistinctColumns) {
@@ -2146,6 +2173,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePhases() {
.Build();
unionAllInputs.push_back(GenerateJustOverStates(combine, NonDistinctColumns));
+ streams.push_back(SerializeIdxSet(NonDistinctColumns));
}
for (ui32 index = 0; index < DistinctFields.size(); ++index) {
@@ -2299,6 +2327,32 @@ TExprNode::TPtr TAggregateExpander::GeneratePhases() {
.Build();
unionAllInputs.push_back(GenerateJustOverStates(combine, indicies));
+ streams.push_back(SerializeIdxSet(indicies));
+ }
+
+ if (TypesCtx.UseBlocks) {
+ for (ui32 i = 0; i < unionAllInputs.size(); ++i) {
+ unionAllInputs[i] = Ctx.Builder(Node->Pos())
+ .Callable("Map")
+ .Add(0, unionAllInputs[i])
+ .Lambda(1)
+ .Param("row")
+ .Callable("AddMember")
+ .Arg(0, "row")
+ .Atom(1, "_yql_group_stream_index")
+ .Callable(2, "Uint32")
+ .Atom(0, ToString(i))
+ .Seal()
+ .Seal()
+ .Seal()
+ .Seal()
+ .Build();
+ }
+ }
+
+ auto settings = Node->ChildPtr(3);
+ if (TypesCtx.UseBlocks) {
+ settings = AddSetting(*settings, Node->Pos(), "many_streams", Ctx.NewList(Node->Pos(), std::move(streams)), Ctx);
}
auto unionAll = Ctx.NewCallable(Node->Pos(), "UnionAll", std::move(unionAllInputs));
@@ -2307,7 +2361,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePhases() {
.Add(0, unionAll)
.Add(1, KeyColumns)
.Add(2, Ctx.NewList(Node->Pos(), std::move(finalizeColumns)))
- .Add(3, Node->ChildPtr(3))
+ .Add(3, settings)
.Seal()
.Build();
@@ -2357,25 +2411,42 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockMergeFinalizeHashed() {
return nullptr;
}
+ bool isMany = Suffix == "MergeManyFinalize";
auto streamArg = Ctx.NewArgument(Node->Pos(), "stream");
TExprNode::TListType keyIdxs;
TVector<TString> outputColumns;
TExprNode::TListType aggs;
- auto blocks = MakeInputBlocks(streamArg, keyIdxs, outputColumns, aggs, true, Suffix == "MergeManyFinalize");
+ ui32 streamIdxColumn;
+ auto blocks = MakeInputBlocks(streamArg, keyIdxs, outputColumns, aggs, true, isMany, &streamIdxColumn);
if (!blocks) {
return nullptr;
}
- auto aggWideFlow = Ctx.Builder(Node->Pos())
- .Callable("WideFromBlocks")
- .Callable(0, TStringBuilder() << "Block" << Suffix << "Hashed")
- .Add(0, blocks)
- .Add(1, Ctx.NewList(Node->Pos(), std::move(keyIdxs)))
- .Add(2, Ctx.NewList(Node->Pos(), std::move(aggs)))
- .Seal()
+ TExprNode::TPtr aggBlocks;
+ if (!isMany) {
+ aggBlocks = Ctx.Builder(Node->Pos())
+ .Callable("BlockMergeFinalizeHashed")
+ .Add(0, blocks)
+ .Add(1, Ctx.NewList(Node->Pos(), std::move(keyIdxs)))
+ .Add(2, Ctx.NewList(Node->Pos(), std::move(aggs)))
+ .Seal()
+ .Build();
+ } else {
+ auto manyStreamsSetting = GetSetting(*Node->Child(3), "many_streams");
+ YQL_ENSURE(manyStreamsSetting, "Missing many_streams setting");
+
+ aggBlocks = Ctx.Builder(Node->Pos())
+ .Callable("BlockMergeManyFinalizeHashed")
+ .Add(0, blocks)
+ .Add(1, Ctx.NewList(Node->Pos(), std::move(keyIdxs)))
+ .Add(2, Ctx.NewList(Node->Pos(), std::move(aggs)))
+ .Atom(3, ToString(streamIdxColumn))
+ .Add(4, manyStreamsSetting->TailPtr())
.Seal()
.Build();
+ }
+ auto aggWideFlow = Ctx.NewCallable(Node->Pos(), "WideFromBlocks", { aggBlocks });
auto finalFlow = MakeNarrowMap(Node->Pos(), outputColumns, aggWideFlow, Ctx);
auto root = Ctx.NewCallable(Node->Pos(), "FromFlow", { finalFlow });
auto lambdaStream = Ctx.NewLambda(Node->Pos(), Ctx.NewArguments(Node->Pos(), { streamArg }), std::move(root));
diff --git a/ydb/library/yql/core/yql_aggregate_expander.h b/ydb/library/yql/core/yql_aggregate_expander.h
index 213abe4ad48..74106737b1e 100644
--- a/ydb/library/yql/core/yql_aggregate_expander.h
+++ b/ydb/library/yql/core/yql_aggregate_expander.h
@@ -76,12 +76,13 @@ private:
TExprNode::TPtr GeneratePhases();
void GenerateInitForDistinct(TExprNodeBuilder& parent, ui32& ndx, const TIdxSet& indicies, const TExprNode::TPtr& distinctField);
TExprNode::TPtr GenerateJustOverStates(const TExprNode::TPtr& input, const TIdxSet& indicies);
+ TExprNode::TPtr SerializeIdxSet(const TIdxSet& indicies);
TExprNode::TPtr TryGenerateBlockCombineAllOrHashed();
TExprNode::TPtr TryGenerateBlockMergeFinalizeHashed();
TExprNode::TPtr TryGenerateBlockCombine();
TExprNode::TPtr TryGenerateBlockMergeFinalize();
TExprNode::TPtr MakeInputBlocks(const TExprNode::TPtr& streamArg, TExprNode::TListType& keyIdxs,
- TVector<TString>& outputColumns, TExprNode::TListType& aggs, bool overState, bool many);
+ TVector<TString>& outputColumns, TExprNode::TListType& aggs, bool overState, bool many, ui32* streamIdxColumn = nullptr);
private:
static constexpr TStringBuf SessionStartMemberName = "_yql_group_session_start";
diff --git a/ydb/library/yql/core/yql_expr_type_annotation.cpp b/ydb/library/yql/core/yql_expr_type_annotation.cpp
index 07688adebf0..607615739f0 100644
--- a/ydb/library/yql/core/yql_expr_type_annotation.cpp
+++ b/ydb/library/yql/core/yql_expr_type_annotation.cpp
@@ -2697,7 +2697,7 @@ bool EnsureWideFlowType(TPositionHandle position, const TTypeAnnotationNode& typ
return true;
}
-bool EnsureWideFlowBlockType(const TExprNode& node, TTypeAnnotationNode::TListType& blockItemTypes, TExprContext& ctx, bool allowChunked) {
+bool EnsureWideFlowBlockType(const TExprNode& node, TTypeAnnotationNode::TListType& blockItemTypes, TExprContext& ctx, bool allowChunked, bool allowScalar) {
if (!EnsureWideFlowType(node, ctx)) {
return false;
}
@@ -2709,12 +2709,17 @@ bool EnsureWideFlowBlockType(const TExprNode& node, TTypeAnnotationNode::TListTy
}
bool isScalar;
- for (const auto& type : items) {
+ for (ui32 i = 0; i < items.size(); ++i) {
+ const auto& type = items[i];
if (!EnsureBlockOrScalarType(node.Pos(), *type, ctx, allowChunked)) {
return false;
}
blockItemTypes.push_back(GetBlockItemType(*type, isScalar));
+ if (!allowScalar && isScalar && (i + 1 != items.size())) {
+ ctx.AddError(TIssue(ctx.GetPosition(node.Pos()), "Scalars are not allowed"));
+ return false;
+ }
}
if (!isScalar) {
diff --git a/ydb/library/yql/core/yql_expr_type_annotation.h b/ydb/library/yql/core/yql_expr_type_annotation.h
index 9a0fa776425..4ac91b5ed26 100644
--- a/ydb/library/yql/core/yql_expr_type_annotation.h
+++ b/ydb/library/yql/core/yql_expr_type_annotation.h
@@ -118,7 +118,8 @@ bool EnsureFlowType(const TExprNode& node, TExprContext& ctx);
bool EnsureFlowType(TPositionHandle position, const TTypeAnnotationNode& type, TExprContext& ctx);
bool EnsureWideFlowType(const TExprNode& node, TExprContext& ctx);
bool EnsureWideFlowType(TPositionHandle position, const TTypeAnnotationNode& type, TExprContext& ctx);
-bool EnsureWideFlowBlockType(const TExprNode& node, TTypeAnnotationNode::TListType& blockItemTypes, TExprContext& ctx, bool allowChunked = false);
+bool EnsureWideFlowBlockType(const TExprNode& node, TTypeAnnotationNode::TListType& blockItemTypes, TExprContext& ctx,
+ bool allowChunked = false, bool allowScalar = true);
bool EnsureOptionalType(const TExprNode& node, TExprContext& ctx);
bool EnsureOptionalType(TPositionHandle position, const TTypeAnnotationNode& type, TExprContext& ctx);
bool EnsureType(const TExprNode& node, TExprContext& ctx);
diff --git a/ydb/library/yql/minikql/arrow/arrow_util.cpp b/ydb/library/yql/minikql/arrow/arrow_util.cpp
index d74d81fc740..2acd9654dc2 100644
--- a/ydb/library/yql/minikql/arrow/arrow_util.cpp
+++ b/ydb/library/yql/minikql/arrow/arrow_util.cpp
@@ -1,4 +1,5 @@
#include "arrow_util.h"
+#include <ydb/library/yql/minikql/mkql_node_builder.h>
#include <util/system/yassert.h>
@@ -36,5 +37,20 @@ std::shared_ptr<arrow::ArrayData> Chop(std::shared_ptr<arrow::ArrayData>& data,
return first;
}
+std::shared_ptr<arrow::ArrayData> Unwrap(const arrow::ArrayData& data, TType* itemType) {
+ bool isOptional;
+ auto unpacked = UnpackOptional(itemType, isOptional);
+ MKQL_ENSURE(isOptional, "Expected optional");
+ if (unpacked->IsOptional() || unpacked->IsVariant()) {
+ MKQL_ENSURE(data.child_data.size() == 1, "Expected struct with one element");
+ return data.child_data[0];
+ } else {
+ auto buffers = data.buffers;
+ MKQL_ENSURE(buffers.size() >= 1, "Missing nullable bitmap");
+ buffers[0] = nullptr;
+ return arrow::ArrayData::Make(data.type, data.length, buffers, data.child_data, data.dictionary, 0, data.offset);
+ }
+}
+
}
diff --git a/ydb/library/yql/minikql/arrow/arrow_util.h b/ydb/library/yql/minikql/arrow/arrow_util.h
index b30dc6f3072..76fb9a9aef6 100644
--- a/ydb/library/yql/minikql/arrow/arrow_util.h
+++ b/ydb/library/yql/minikql/arrow/arrow_util.h
@@ -1,6 +1,7 @@
#pragma once
#include <arrow/array/data.h>
+#include <ydb/library/yql/minikql/mkql_node.h>
namespace NKikimr::NMiniKQL {
@@ -10,4 +11,7 @@ std::shared_ptr<arrow::ArrayData> DeepSlice(const std::shared_ptr<arrow::ArrayDa
/// \brief Chops first len items of `data` as new ArrayData object
std::shared_ptr<arrow::ArrayData> Chop(std::shared_ptr<arrow::ArrayData>& data, size_t len);
+/// \brief Remove optional from `data` as new ArrayData object
+std::shared_ptr<arrow::ArrayData> Unwrap(const arrow::ArrayData& data, TType* itemType);
+
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
index 60531a0dbcd..6d4c8ab6fef 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
@@ -10,6 +10,7 @@
#include <ydb/library/yql/minikql/mkql_node_builder.h>
#include <ydb/library/yql/minikql/arrow/arrow_defs.h>
+#include <ydb/library/yql/minikql/arrow/arrow_util.h>
#include <arrow/scalar.h>
#include <arrow/array/array_primitive.h>
@@ -302,6 +303,8 @@ namespace {
template <typename T>
struct TAggParams {
std::unique_ptr<IPreparedBlockAggregator<T>> Prepared_;
+ ui32 Column_ = 0;
+ TType* StateType_ = nullptr;
};
struct TKeyParams {
@@ -673,10 +676,10 @@ TStringBuf GetKeyView(const TSSOKey& key) {
return key.AsView();
}
-template <typename TKey, typename TAggregator, typename TFixedAggState, bool UseSet, bool UseFilter, bool Finalize, typename TDerived>
+template <typename TKey, typename TAggregator, typename TFixedAggState, bool UseSet, bool UseFilter, bool Finalize, bool Many, typename TDerived>
class THashedWrapperBase : public TStatefulWideFlowComputationNode<TDerived> {
public:
- using TSelf = THashedWrapperBase<TKey, TAggregator, TFixedAggState, UseSet, UseFilter, Finalize, TDerived>;
+ using TSelf = THashedWrapperBase<TKey, TAggregator, TFixedAggState, UseSet, UseFilter, Finalize, Many, TDerived>;
using TBase = TStatefulWideFlowComputationNode<TDerived>;
static constexpr bool UseArena = !InlineAggState && std::is_same<TFixedAggState, TStateArena>::value;
@@ -687,7 +690,9 @@ public:
size_t width,
const std::vector<TKeyParams>& keys,
std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers,
- TVector<TAggParams<TAggregator>>&& aggsParams)
+ TVector<TAggParams<TAggregator>>&& aggsParams,
+ ui32 streamIndex,
+ TVector<TVector<ui32>>&& streams)
: TBase(mutables, flow, EValueRepresentation::Any)
, Flow_(flow)
, FilterColumn_(filterColumn)
@@ -696,6 +701,8 @@ public:
, Keys_(keys)
, KeySerializers_(std::move(keySerializers))
, AggsParams_(std::move(aggsParams))
+ , StreamIndex_(streamIndex)
+ , Streams_(std::move(streams))
{
MKQL_ENSURE(Width_ > 0, "Missing block length column");
if constexpr (UseFilter) {
@@ -742,6 +749,19 @@ public:
}
}
+ const ui32* streamIndexData = nullptr;
+ if constexpr (Many) {
+ auto streamIndexDatum = TArrowBlock::From(s.Values_[StreamIndex_]).GetDatum();
+ MKQL_ENSURE(streamIndexDatum.is_array(), "Expected array");
+ streamIndexData = streamIndexDatum.array()->template GetValues<ui32>(1);
+ s.UnwrappedValues_ = s.Values_;
+ for (const auto& p : AggsParams_) {
+ const auto& columnDatum = TArrowBlock::From(s.UnwrappedValues_[p.Column_]).GetDatum();
+ MKQL_ENSURE(columnDatum.is_array(), "Expected array");
+ s.UnwrappedValues_[p.Column_] = ctx.HolderFactory.CreateArrowBlock(Unwrap(*columnDatum.array(), p.StateType_));
+ }
+ }
+
s.HasValues_ = true;
TVector<arrow::Datum> keysDatum;
keysDatum.reserve(Keys_.size());
@@ -778,9 +798,9 @@ public:
}
} else {
if (!InlineAggState) {
- Insert(*s.HashFixedMap_, key, row, output, s);
+ Insert(*s.HashFixedMap_, key, row, streamIndexData, output, s);
} else {
- Insert(*s.HashMap_, key, row, output, s);
+ Insert(*s.HashMap_, key, row, streamIndexData, output, s);
}
}
}
@@ -862,6 +882,8 @@ private:
TVector<NUdf::TUnboxedValue> Values_;
TVector<NUdf::TUnboxedValue*> ValuePointers_;
TVector<std::unique_ptr<TAggregator>> Aggs_;
+ TVector<ui32> AggStateOffsets_;
+ TVector<NUdf::TUnboxedValue> UnwrappedValues_;
bool IsFinished_ = false;
bool HasValues_ = false;
ui32 TotalStateSize_ = 0;
@@ -870,19 +892,26 @@ private:
std::unique_ptr<TFixedHashMapImpl<TKey, TFixedAggState, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>>> HashFixedMap_;
TPagedArena Arena_;
- TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const TVector<TAggParams<TAggregator>>& params, TComputationContext& ctx)
+ TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const TVector<TAggParams<TAggregator>>& params,
+ const TVector<TVector<ui32>>& streams, TComputationContext& ctx)
: TBase(memInfo)
, Values_(width)
, ValuePointers_(width)
+ , UnwrappedValues_(width)
, Arena_(TlsAllocState)
{
for (size_t i = 0; i < width; ++i) {
ValuePointers_[i] = &Values_[i];
}
+ if constexpr (Many) {
+ TotalStateSize_ += streams.size();
+ }
+
for (const auto& p : params) {
Aggs_.emplace_back(p.Prepared_->Make(ctx));
MKQL_ENSURE(Aggs_.back()->StateSize == p.Prepared_->StateSize, "State size mismatch");
+ AggStateOffsets_.emplace_back(TotalStateSize_);
TotalStateSize_ += Aggs_.back()->StateSize;
}
@@ -906,13 +935,13 @@ private:
TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
if (!state.HasValue()) {
- state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, ctx);
+ state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, Streams_, ctx);
}
return *static_cast<TState*>(state.AsBoxed().Get());
}
template <typename THash>
- void Insert(THash& hash, const TKey& key, ui64 row, NUdf::TUnboxedValue*const* output, TState& s) const {
+ void Insert(THash& hash, const TKey& key, ui64 row, const ui32* streamIndexData, NUdf::TUnboxedValue*const* output, TState& s) const {
bool isNew;
auto iter = hash.Insert(key, isNew);
char* payload = (char*)hash.GetMutablePayload(iter);
@@ -926,16 +955,30 @@ private:
ptr = payload;
}
- for (size_t i = 0; i < s.Aggs_.size(); ++i) {
- if (output[Keys_.size() + i]) {
- if constexpr (Finalize) {
- s.Aggs_[i]->LoadState(ptr, s.Values_.data(), row);
- } else {
- s.Aggs_[i]->InitKey(ptr, s.Values_.data(), row);
+ if constexpr (Many) {
+ static_assert(Finalize);
+ ui32 currentStreamIndex = streamIndexData[row];
+ MKQL_ENSURE(currentStreamIndex < Streams_.size(), "Invalid stream index");
+ memset(ptr, 0, Streams_.size());
+ ptr[currentStreamIndex] = 1;
+
+ for (auto i : Streams_[currentStreamIndex]) {
+ if (output[Keys_.size() + i]) {
+ s.Aggs_[i]->LoadState(ptr + s.AggStateOffsets_[i], s.UnwrappedValues_.data(), row);
}
}
+ } else {
+ for (size_t i = 0; i < s.Aggs_.size(); ++i) {
+ if (output[Keys_.size() + i]) {
+ if constexpr (Finalize) {
+ s.Aggs_[i]->LoadState(ptr, s.Values_.data(), row);
+ } else {
+ s.Aggs_[i]->InitKey(ptr, s.Values_.data(), row);
+ }
+ }
- ptr += s.Aggs_[i]->StateSize;
+ ptr += s.Aggs_[i]->StateSize;
+ }
}
if constexpr (std::is_same<TKey, TSSOKey>::value) {
@@ -950,16 +993,35 @@ private:
ptr = payload;
}
- for (size_t i = 0; i < s.Aggs_.size(); ++i) {
- if (output[Keys_.size() + i]) {
- if constexpr (Finalize) {
- s.Aggs_[i]->UpdateState(ptr, s.Values_.data(), row);
- } else {
- s.Aggs_[i]->UpdateKey(ptr, s.Values_.data(), row);
+ if constexpr (Many) {
+ static_assert(Finalize);
+ ui32 currentStreamIndex = streamIndexData[row];
+ MKQL_ENSURE(currentStreamIndex < Streams_.size(), "Invalid stream index");
+
+ bool isNewStream = !ptr[currentStreamIndex];
+ ptr[currentStreamIndex] = 1;
+
+ for (auto i : Streams_[currentStreamIndex]) {
+ if (output[Keys_.size() + i]) {
+ if (isNewStream) {
+ s.Aggs_[i]->LoadState(ptr + s.AggStateOffsets_[i], s.UnwrappedValues_.data(), row);
+ } else {
+ s.Aggs_[i]->UpdateState(ptr + s.AggStateOffsets_[i], s.UnwrappedValues_.data(), row);
+ }
}
}
+ } else {
+ for (size_t i = 0; i < s.Aggs_.size(); ++i) {
+ if (output[Keys_.size() + i]) {
+ if constexpr (Finalize) {
+ s.Aggs_[i]->UpdateState(ptr, s.Values_.data(), row);
+ } else {
+ s.Aggs_[i]->UpdateKey(ptr, s.Values_.data(), row);
+ }
+ }
- ptr += s.Aggs_[i]->StateSize;
+ ptr += s.Aggs_[i]->StateSize;
+ }
}
}
}
@@ -987,6 +1049,14 @@ private:
kb->Add(in);
}
+ if constexpr (Many) {
+ for (ui32 i = 0; i < Streams_.size(); ++i) {
+ MKQL_ENSURE(ptr[i], "Missing partial aggregation state");
+ }
+
+ ptr += Streams_.size();
+ }
+
for (size_t i = 0; i < s.Aggs_.size(); ++i) {
if (output[Keys_.size() + i]) {
aggBuilders[i]->Add(ptr);
@@ -1009,13 +1079,15 @@ private:
const std::vector<TKeyParams> Keys_;
const TVector<TAggParams<TAggregator>> AggsParams_;
std::vector<std::unique_ptr<IKeySerializer>> KeySerializers_;
+ const ui32 StreamIndex_;
+ const TVector<TVector<ui32>> Streams_;
};
template <typename TKey, typename TFixedAggState, bool UseSet, bool UseFilter>
-class TBlockCombineHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter>> {
+class TBlockCombineHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, false, TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter>> {
public:
using TSelf = TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter>;
- using TBase = THashedWrapperBase<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, TSelf>;
+ using TBase = THashedWrapperBase<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, false, TSelf>;
TBlockCombineHashedWrapper(TComputationMutables& mutables,
IComputationWideFlowNode* flow,
@@ -1024,15 +1096,15 @@ public:
const std::vector<TKeyParams>& keys,
std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers,
TVector<TAggParams<IBlockAggregatorCombineKeys>>&& aggsParams)
- : TBase(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams))
+ : TBase(mutables, flow, filterColumn, width, keys, std::move(keySerializers), std::move(aggsParams), 0, {})
{}
};
template <typename TKey, typename TFixedAggState, bool UseSet>
-class TBlockMergeFinalizeHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet>> {
+class TBlockMergeFinalizeHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, false, TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet>> {
public:
using TSelf = TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet>;
- using TBase = THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, TSelf>;
+ using TBase = THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, false, TSelf>;
TBlockMergeFinalizeHashedWrapper(TComputationMutables& mutables,
IComputationWideFlowNode* flow,
@@ -1040,7 +1112,24 @@ public:
const std::vector<TKeyParams>& keys,
std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers,
TVector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams)
- : TBase(mutables, flow, {}, width, keys, std::move(keySerializers), std::move(aggsParams))
+ : TBase(mutables, flow, {}, width, keys, std::move(keySerializers), std::move(aggsParams), 0, {})
+ {}
+};
+
+template <typename TKey, typename TFixedAggState>
+class TBlockMergeManyFinalizeHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, false, false, true, true, TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState>> {
+public:
+ using TSelf = TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState>;
+ using TBase = THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, false, false, true, true, TSelf>;
+
+ TBlockMergeManyFinalizeHashedWrapper(TComputationMutables& mutables,
+ IComputationWideFlowNode* flow,
+ size_t width,
+ const std::vector<TKeyParams>& keys,
+ std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers,
+ TVector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams,
+ ui32 streamIndex, TVector<TVector<ui32>>&& streams)
+ : TBase(mutables, flow, {}, width, keys, std::move(keySerializers), std::move(aggsParams), streamIndex, std::move(streams))
{}
};
@@ -1080,7 +1169,7 @@ std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorFinalizeKeys>> PrepareB
}
template <typename TAggregator>
-ui32 FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, std::optional<ui32> filterColumn, TVector<TAggParams<TAggregator>>& aggsParams, const TTypeEnvironment& env) {
+ui32 FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, std::optional<ui32> filterColumn, TVector<TAggParams<TAggregator>>& aggsParams, const TTypeEnvironment& env, bool overState) {
ui32 totalStateSize = 0;
for (ui32 i = 0; i < aggsVal->GetValuesCount(); ++i) {
auto aggVal = AS_VALUE(TTupleLiteral, aggsVal->GetValue(i));
@@ -1093,6 +1182,12 @@ ui32 FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, std::optional<
TAggParams<TAggregator> p;
p.Prepared_ = PrepareBlockAggregator<TAggregator>(GetBlockAggregatorFactory(name), tupleType, filterColumn, argColumns, env);
+ if (overState) {
+ MKQL_ENSURE(argColumns.size() == 1, "Expected exactly one column");
+ p.Column_ = argColumns[0];
+ p.StateType_ = AS_TYPE(TBlockType, tupleType->GetElementType(p.Column_))->GetItemType();
+ }
+
totalStateSize += p.Prepared_->StateSize;
aggsParams.emplace_back(std::move(p));
}
@@ -1185,6 +1280,51 @@ IComputationNode* MakeBlockMergeFinalizeHashedWrapper(
return MakeBlockMergeFinalizeHashedWrapper<TSSOKey, UseSet>(totalStateSize, mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams));
}
+template <typename TKey>
+IComputationNode* MakeBlockMergeManyFinalizeHashedWrapper(
+ ui32 totalStateSize,
+ TComputationMutables& mutables,
+ IComputationWideFlowNode* flow,
+ size_t width,
+ const std::vector<TKeyParams>& keys,
+ std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers,
+ TVector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams,
+ ui32 streamIndex,
+ TVector<TVector<ui32>>&& streams) {
+
+ if (totalStateSize <= sizeof(TState8)) {
+ return new TBlockMergeManyFinalizeHashedWrapper<TKey, TState8>(mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams), streamIndex, std::move(streams));
+ }
+
+ if (totalStateSize <= sizeof(TState16)) {
+ return new TBlockMergeManyFinalizeHashedWrapper<TKey, TState16>(mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams), streamIndex, std::move(streams));
+ }
+
+ return new TBlockMergeManyFinalizeHashedWrapper<TKey, TStateArena>(mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams), streamIndex, std::move(streams));
+}
+
+IComputationNode* MakeBlockMergeManyFinalizeHashedWrapper(
+ ui32 totalKeysSize,
+ ui32 totalStateSize,
+ TComputationMutables& mutables,
+ IComputationWideFlowNode* flow,
+ size_t width,
+ const std::vector<TKeyParams>& keys,
+ std::vector<std::unique_ptr<IKeySerializer>>&& keySerializers,
+ TVector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams,
+ ui32 streamIndex,
+ TVector<TVector<ui32>>&& streams) {
+ if (totalKeysSize <= sizeof(ui32)) {
+ return MakeBlockMergeManyFinalizeHashedWrapper<ui32>(totalStateSize, mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams), streamIndex, std::move(streams));
+ }
+
+ if (totalKeysSize <= sizeof(ui64)) {
+ return MakeBlockMergeManyFinalizeHashedWrapper<ui64>(totalStateSize, mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams), streamIndex, std::move(streams));
+ }
+
+ return MakeBlockMergeManyFinalizeHashedWrapper<TSSOKey>(totalStateSize, mutables, flow, width, keys, std::move(keySerializers), std::move(aggsParams), streamIndex, std::move(streams));
+}
+
void PrepareKeys(const std::vector<TKeyParams>& keys, ui32& totalKeysSize, std::vector<std::unique_ptr<IKeySerializer>>& keySerializers) {
totalKeysSize = 0;
keySerializers.clear();
@@ -1280,6 +1420,19 @@ void PrepareKeys(const std::vector<TKeyParams>& keys, ui32& totalKeysSize, std::
}
}
+void FillAggStreams(TRuntimeNode streamsNode, TVector<TVector<ui32>>& streams) {
+ auto streamsVal = AS_VALUE(TTupleLiteral, streamsNode);
+ for (ui32 i = 0; i < streamsVal->GetValuesCount(); ++i) {
+ streams.emplace_back();
+ auto& stream = streams.back();
+ auto streamVal = AS_VALUE(TTupleLiteral, streamsVal->GetValue(i));
+ for (ui32 j = 0; j < streamVal->GetValuesCount(); ++j) {
+ ui32 index = AS_VALUE(TDataLiteral, streamVal->GetValue(j))->AsValue().Get<ui32>();
+ stream.emplace_back(index);
+ }
+ }
+}
+
}
IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
@@ -1298,7 +1451,7 @@ IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNod
auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2));
TVector<TAggParams<IBlockAggregatorCombineAll>> aggsParams;
- ui32 totalStateSize = FillAggParams<IBlockAggregatorCombineAll>(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env);
+ ui32 totalStateSize = FillAggParams<IBlockAggregatorCombineAll>(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env, false);
return new TBlockCombineAllWrapper(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), std::move(aggsParams));
}
@@ -1325,7 +1478,7 @@ IComputationNode* WrapBlockCombineHashed(TCallable& callable, const TComputation
auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(3));
TVector<TAggParams<IBlockAggregatorCombineKeys>> aggsParams;
- ui32 totalStateSize = FillAggParams<IBlockAggregatorCombineKeys>(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env);
+ ui32 totalStateSize = FillAggParams<IBlockAggregatorCombineKeys>(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env, false);
ui32 totalKeysSize = 0;
std::vector<std::unique_ptr<IKeySerializer>> keySerializers;
@@ -1363,7 +1516,7 @@ IComputationNode* WrapBlockMergeFinalizeHashed(TCallable& callable, const TCompu
auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2));
TVector<TAggParams<IBlockAggregatorFinalizeKeys>> aggsParams;
- ui32 totalStateSize = FillAggParams<IBlockAggregatorFinalizeKeys>(aggsVal, tupleType, {}, aggsParams, ctx.Env);
+ ui32 totalStateSize = FillAggParams<IBlockAggregatorFinalizeKeys>(aggsVal, tupleType, {}, aggsParams, ctx.Env, true);
ui32 totalKeysSize = 0;
std::vector<std::unique_ptr<IKeySerializer>> keySerializers;
@@ -1376,5 +1529,42 @@ IComputationNode* WrapBlockMergeFinalizeHashed(TCallable& callable, const TCompu
}
}
+IComputationNode* WrapBlockMergeManyFinalizeHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
+ MKQL_ENSURE(callable.GetInputsCount() == 5, "Expected 5 args");
+ const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
+ const auto tupleType = AS_TYPE(TTupleType, flowType->GetItemType());
+
+ auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0));
+ MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
+
+ auto keysVal = AS_VALUE(TTupleLiteral, callable.GetInput(1));
+ std::vector<TKeyParams> keys;
+ for (ui32 i = 0; i < keysVal->GetValuesCount(); ++i) {
+ ui32 index = AS_VALUE(TDataLiteral, keysVal->GetValue(i))->AsValue().Get<ui32>();
+ keys.emplace_back(TKeyParams{ index, tupleType->GetElementType(index) });
+ }
+
+ auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2));
+ TVector<TAggParams<IBlockAggregatorFinalizeKeys>> aggsParams;
+ ui32 totalStateSize = FillAggParams<IBlockAggregatorFinalizeKeys>(aggsVal, tupleType, {}, aggsParams, ctx.Env, true);
+
+ ui32 totalKeysSize = 0;
+ std::vector<std::unique_ptr<IKeySerializer>> keySerializers;
+ PrepareKeys(keys, totalKeysSize, keySerializers);
+
+ ui32 streamIndex = AS_VALUE(TDataLiteral, callable.GetInput(3))->AsValue().Get<ui32>();
+ TVector<TVector<ui32>> streams;
+ FillAggStreams(callable.GetInput(4), streams);
+ totalStateSize += streams.size();
+
+ if (aggsParams.size() == 0) {
+ return MakeBlockMergeFinalizeHashedWrapper<true>(totalKeysSize, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(),
+ keys, std::move(keySerializers), std::move(aggsParams));
+ } else {
+ return MakeBlockMergeManyFinalizeHashedWrapper(totalKeysSize, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(),
+ keys, std::move(keySerializers), std::move(aggsParams), streamIndex, std::move(streams));
+ }
+}
+
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.h
index 872788babec..6bd98eace24 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.h
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.h
@@ -8,6 +8,7 @@ namespace NMiniKQL {
IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNodeFactoryContext& ctx);
IComputationNode* WrapBlockCombineHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx);
IComputationNode* WrapBlockMergeFinalizeHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx);
+IComputationNode* WrapBlockMergeManyFinalizeHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx);
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
index fc79af03498..e29c051c971 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
@@ -285,6 +285,7 @@ struct TCallableComputationNodeBuilderFuncMapFiller {
{"BlockCombineAll", &WrapBlockCombineAll},
{"BlockCombineHashed", &WrapBlockCombineHashed},
{"BlockMergeFinalizeHashed", &WrapBlockMergeFinalizeHashed},
+ {"BlockMergeManyFinalizeHashed", &WrapBlockMergeManyFinalizeHashed},
{"MakeHeap", &WrapMakeHeap},
{"PushHeap", &WrapPushHeap},
{"PopHeap", &WrapPopHeap},
diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp
index 6403fc25ec8..bb9bf360732 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.cpp
+++ b/ydb/library/yql/minikql/mkql_program_builder.cpp
@@ -5409,6 +5409,48 @@ TRuntimeNode TProgramBuilder::BlockMergeFinalizeHashed(TRuntimeNode flow, const
return TRuntimeNode(builder.Build(), false);
}
+TRuntimeNode TProgramBuilder::BlockMergeManyFinalizeHashed(TRuntimeNode flow, const TArrayRef<ui32>& keys,
+ const TArrayRef<const TAggInfo>& aggs, ui32 streamIndex, const TVector<TVector<ui32>>& streams, TType* returnType) {
+ if constexpr (RuntimeVersion < 31U) {
+ THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
+ }
+
+ TCallableBuilder builder(Env, __func__, returnType);
+ builder.Add(flow);
+
+ TVector<TRuntimeNode> keyNodes;
+ for (const auto& key : keys) {
+ keyNodes.push_back(NewDataLiteral<ui32>(key));
+ }
+
+ builder.Add(NewTuple(keyNodes));
+ TVector<TRuntimeNode> aggsNodes;
+ for (const auto& agg : aggs) {
+ TVector<TRuntimeNode> params;
+ params.push_back(NewDataLiteral<NUdf::EDataSlot::String>(agg.Name));
+ for (const auto& col : agg.ArgsColumns) {
+ params.push_back(NewDataLiteral<ui32>(col));
+ }
+
+ aggsNodes.push_back(NewTuple(params));
+ }
+
+ builder.Add(NewTuple(aggsNodes));
+ builder.Add(NewDataLiteral<ui32>(streamIndex));
+ TVector<TRuntimeNode> streamsNodes;
+ for (const auto& s : streams) {
+ TVector<TRuntimeNode> streamNodes;
+ for (const auto& i : s) {
+ streamNodes.push_back(NewDataLiteral<ui32>(i));
+ }
+
+ streamsNodes.push_back(NewTuple(streamNodes));
+ }
+
+ builder.Add(NewTuple(streamsNodes));
+ return TRuntimeNode(builder.Build(), false);
+}
+
bool CanExportType(TType* type, const TTypeEnvironment& env) {
if (type->GetKind() == TType::EKind::Type) {
return false; // Type of Type
diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h
index 407db818dae..d037a3cfd47 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.h
+++ b/ydb/library/yql/minikql/mkql_program_builder.h
@@ -264,6 +264,8 @@ public:
const TArrayRef<const TAggInfo>& aggs, TType* returnType);
TRuntimeNode BlockMergeFinalizeHashed(TRuntimeNode flow, const TArrayRef<ui32>& keys,
const TArrayRef<const TAggInfo>& aggs, TType* returnType);
+ TRuntimeNode BlockMergeManyFinalizeHashed(TRuntimeNode flow, const TArrayRef<ui32>& keys,
+ const TArrayRef<const TAggInfo>& aggs, ui32 streamIndex, const TVector<TVector<ui32>>& streams, TType* returnType);
// udfs
TRuntimeNode Udf(
diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
index 84421c87b9f..cec8f7e8633 100644
--- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
+++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
@@ -2475,6 +2475,37 @@ TMkqlCommonCallableCompiler::TShared::TShared() {
return ctx.ProgramBuilder.BlockMergeFinalizeHashed(arg, keys, aggs, returnType);
});
+ AddCallable("BlockMergeManyFinalizeHashed", [](const TExprNode& node, TMkqlBuildContext& ctx) {
+ auto arg = MkqlBuildExpr(*node.Child(0), ctx);
+ TVector<ui32> keys;
+ for (const auto& key : node.Child(1)->Children()) {
+ keys.push_back(FromString<ui32>(key->Content()));
+ }
+
+ TVector<TAggInfo> aggs;
+ for (const auto& agg : node.Child(2)->Children()) {
+ TAggInfo info;
+ info.Name = TString(agg->Head().Head().Content());
+ for (ui32 i = 1; i < agg->ChildrenSize(); ++i) {
+ info.ArgsColumns.push_back(FromString<ui32>(agg->Child(i)->Content()));
+ }
+
+ aggs.push_back(info);
+ }
+
+ ui32 streamIndex = FromString<ui32>(node.Child(3)->Content());
+ TVector<TVector<ui32>> streams;
+ for (const auto& child : node.Child(4)->Children()) {
+ auto& stream = streams.emplace_back();
+ for (const auto& atom : child->Children()) {
+ stream.emplace_back(FromString<ui32>(atom->Content()));
+ }
+ }
+
+ auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder);
+ return ctx.ProgramBuilder.BlockMergeManyFinalizeHashed(arg, keys, aggs, streamIndex, streams, returnType);
+ });
+
AddCallable("BlockCompress", [](const TExprNode& node, TMkqlBuildContext& ctx) {
const auto flow = MkqlBuildExpr(node.Head(), ctx);
const auto index = FromString<ui32>(node.Child(1)->Content());