aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <vvvv@ydb.tech>2022-12-19 20:34:44 +0300
committervvvv <vvvv@ydb.tech>2022-12-19 20:34:44 +0300
commit0b60d14b135627201f472fceca644d6a7e01019f (patch)
treeaf0c5be2fcaf38dd2a17e3fcd134dcd233ae2e70
parent81a58118787d9507ae39aeecaf003285bc3da5f6 (diff)
downloadydb-0b60d14b135627201f472fceca644d6a7e01019f.tar.gz
support of final aggregation by keys (DQ/YT)
-rw-r--r--ydb/library/yql/core/expr_nodes/yql_expr_nodes.json9
-rw-r--r--ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp21
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_blocks.cpp79
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_blocks.h1
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_core.cpp3
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_list.cpp166
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_list.h1
-rw-r--r--ydb/library/yql/core/yql_aggregate_expander.cpp107
-rw-r--r--ydb/library/yql/core/yql_aggregate_expander.h10
-rw-r--r--ydb/library/yql/core/yql_expr_type_annotation.cpp70
-rw-r--r--ydb/library/yql/core/yql_expr_type_annotation.h7
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_log.cpp2
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_phy.cpp94
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_phy.h2
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h12
-rw-r--r--ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp22
-rw-r--r--ydb/library/yql/providers/dq/opt/physical_optimize.cpp5
17 files changed, 513 insertions, 98 deletions
diff --git a/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json b/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json
index 0c31c72546..8698b9d604 100644
--- a/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json
+++ b/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json
@@ -2260,6 +2260,15 @@
{"Index": 2, "Name": "Keys", "Type": "TCoAtomList"},
{"Index": 3, "Name": "Aggregations", "Type": "TExprList"}
]
+ },
+ {
+ "Name": "TCoShuffleByKeys",
+ "Base": "TCoInputBase",
+ "Match": {"Type": "Callable", "Name": "ShuffleByKeys"},
+ "Children": [
+ {"Index": 1, "Name": "KeySelectorLambda", "Type": "TCoLambda"},
+ {"Index": 2, "Name": "ListHandlerLambda", "Type": "TCoLambda"}
+ ]
}
]
}
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
index 20d6fae8d5..1dd2d05f8e 100644
--- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
+++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
@@ -68,13 +68,23 @@ TExprNode::TPtr Now0Arg(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnn
}
bool IsArgumentsOnlyLambda(const TExprNode& lambda, TVector<ui32>& argIndices) {
+ TNodeMap<ui32> args;
+ for (ui32 i = 0; i < lambda.Head().ChildrenSize(); ++i) {
+ args.insert(std::make_pair(lambda.Head().Child(i), i));
+ }
+
for (ui32 i = 1; i < lambda.ChildrenSize(); ++i) {
auto root = lambda.Child(i);
- if (!root->IsArgument() || root->GetLambdaLevel() > 0) {
+ if (!root->IsArgument()) {
+ return false;
+ }
+
+ auto it = args.find(root);
+ if (it == args.end()) {
return false;
}
- argIndices.push_back(root->GetArgIndex());
+ argIndices.push_back(it->second);
}
return true;
@@ -2456,7 +2466,7 @@ TExprNode::TPtr ExpandMux(const TExprNode::TPtr& node, TExprContext& ctx) {
return node;
}
-TExprNode::TPtr ExpandLMap(const TExprNode::TPtr& node, TExprContext& ctx) {
+TExprNode::TPtr ExpandLMapOrShuffleByKeys(const TExprNode::TPtr& node, TExprContext& ctx) {
YQL_CLOG(DEBUG, CorePeepHole) << "Expand " << node->Content();
return ctx.Builder(node->Pos())
.Callable("Collect")
@@ -6678,8 +6688,9 @@ struct TPeepHoleRules {
{"OrderedFilter", &ExpandFilter},
{"TakeWhile", &ExpandFilter<false>},
{"SkipWhile", &ExpandFilter<true>},
- {"LMap", &ExpandLMap},
- {"OrderedLMap", &ExpandLMap},
+ {"LMap", &ExpandLMapOrShuffleByKeys},
+ {"OrderedLMap", &ExpandLMapOrShuffleByKeys},
+ {"ShuffleByKeys", &ExpandLMapOrShuffleByKeys},
{"ExpandMap", &OptimizeExpandMap},
{"MultiMap", &OptimizeMultiMap<false>},
{"OrderedMultiMap", &OptimizeMultiMap<true>},
diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
index bcf423eac7..e95cd3cd33 100644
--- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
@@ -251,8 +251,31 @@ IGraphTransformer::TStatus BlockBitCastWrapper(const TExprNode::TPtr& input, TEx
return IGraphTransformer::TStatus::Ok;
}
-bool ValidateBlockAggs(TPositionHandle pos, const TTypeAnnotationNode::TListType inputItems, const TExprNode& aggs,
- TTypeAnnotationNode::TListType& retMultiType, TExprContext& ctx) {
+bool ValidateBlockKeys(TPositionHandle pos, const TTypeAnnotationNode::TListType& inputItems,
+ const TExprNode& keys, TTypeAnnotationNode::TListType& retMultiType, TExprContext& ctx) {
+ if (!EnsureTupleMinSize(keys, 1, ctx)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ for (auto child : keys.Children()) {
+ if (!EnsureAtom(*child, ctx)) {
+ return false;
+ }
+
+ ui32 keyColumnIndex;
+ if (!TryFromString(child->Content(), keyColumnIndex) || keyColumnIndex >= inputItems.size()) {
+ ctx.AddError(TIssue(ctx.GetPosition(pos), "Bad key column index"));
+ return false;
+ }
+
+ retMultiType.push_back(inputItems[keyColumnIndex]);
+ }
+
+ return true;
+}
+
+bool ValidateBlockAggs(TPositionHandle pos, const TTypeAnnotationNode::TListType& inputItems, const TExprNode& aggs,
+ TTypeAnnotationNode::TListType& retMultiType, TExprContext& ctx, bool overState) {
if (!EnsureTuple(aggs, ctx)) {
return false;
}
@@ -262,8 +285,9 @@ bool ValidateBlockAggs(TPositionHandle pos, const TTypeAnnotationNode::TListType
return false;
}
- if (!agg->Head().IsCallable("AggBlockApply")) {
- ctx.AddError(TIssue(ctx.GetPosition(pos), "Expected AggBlockApply"));
+ auto expectedCallable = overState ? "AggBlockApplyState" : "AggBlockApply";
+ if (!agg->Head().IsCallable(expectedCallable)) {
+ ctx.AddError(TIssue(ctx.GetPosition(pos), TStringBuilder() << "Expected: " << expectedCallable));
return false;
}
@@ -287,7 +311,8 @@ bool ValidateBlockAggs(TPositionHandle pos, const TTypeAnnotationNode::TListType
}
}
- retMultiType.push_back(AggApplySerializedStateType(agg->HeadPtr(), ctx));
+ auto retAggType = overState ? agg->HeadPtr()->GetTypeAnn() : AggApplySerializedStateType(agg->HeadPtr(), ctx);
+ retMultiType.push_back(retAggType);
}
return true;
@@ -321,7 +346,7 @@ IGraphTransformer::TStatus BlockCombineAllWrapper(const TExprNode::TPtr& input,
}
TTypeAnnotationNode::TListType retMultiType;
- if (!ValidateBlockAggs(input->Pos(), blockItemTypes, *input->Child(2), retMultiType, ctx.Expr)) {
+ if (!ValidateBlockAggs(input->Pos(), blockItemTypes, *input->Child(2), retMultiType, ctx.Expr, false)) {
return IGraphTransformer::TStatus::Error;
}
@@ -362,21 +387,41 @@ IGraphTransformer::TStatus BlockCombineHashedWrapper(const TExprNode::TPtr& inpu
}
TTypeAnnotationNode::TListType retMultiType;
- for (auto child : input->Child(2)->Children()) {
- if (!EnsureAtom(*child, ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
- }
+ if (!ValidateBlockKeys(input->Pos(), blockItemTypes, *input->Child(2), retMultiType, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
- ui32 keyColumnIndex;
- if (!TryFromString(child->Content(), keyColumnIndex) || keyColumnIndex >= blockItemTypes.size()) {
- ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "Bad key column index"));
- return IGraphTransformer::TStatus::Error;
- }
+ if (!ValidateBlockAggs(input->Pos(), blockItemTypes, *input->Child(3), retMultiType, ctx.Expr, false)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ for (auto& t : retMultiType) {
+ t = ctx.Expr.MakeType<TBlockExprType>(t);
+ }
+
+ retMultiType.push_back(ctx.Expr.MakeType<TScalarExprType>(ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64)));
+ auto outputItemType = ctx.Expr.MakeType<TMultiExprType>(retMultiType);
+ input->SetTypeAnn(ctx.Expr.MakeType<TFlowExprType>(outputItemType));
+ return IGraphTransformer::TStatus::Ok;
+}
- retMultiType.push_back(blockItemTypes[keyColumnIndex]);
+IGraphTransformer::TStatus BlockMergeFinalizeHashedWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) {
+ Y_UNUSED(output);
+ if (!EnsureArgsCount(*input, 3U, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ TTypeAnnotationNode::TListType blockItemTypes;
+ if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ TTypeAnnotationNode::TListType retMultiType;
+ if (!ValidateBlockKeys(input->Pos(), blockItemTypes, *input->Child(1), retMultiType, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
}
- if (!ValidateBlockAggs(input->Pos(), blockItemTypes, *input->Child(3), retMultiType, ctx.Expr)) {
+ if (!ValidateBlockAggs(input->Pos(), blockItemTypes, *input->Child(2), retMultiType, ctx.Expr, true)) {
return IGraphTransformer::TStatus::Error;
}
diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.h b/ydb/library/yql/core/type_ann/type_ann_blocks.h
index b8fd6fb9b8..e461142643 100644
--- a/ydb/library/yql/core/type_ann/type_ann_blocks.h
+++ b/ydb/library/yql/core/type_ann/type_ann_blocks.h
@@ -16,6 +16,7 @@ namespace NTypeAnnImpl {
IGraphTransformer::TStatus BlockBitCastWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx);
IGraphTransformer::TStatus BlockCombineAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx);
IGraphTransformer::TStatus BlockCombineHashedWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx);
+ IGraphTransformer::TStatus BlockMergeFinalizeHashedWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx);
} // namespace NTypeAnnImpl
} // namespace NYql
diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp
index 54931a0cfa..361f97dec5 100644
--- a/ydb/library/yql/core/type_ann/type_ann_core.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp
@@ -11364,6 +11364,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
Functions["Chain1Map"] = &Chain1MapWrapper;
Functions["LMap"] = &LMapWrapper;
Functions["OrderedLMap"] = &LMapWrapper;
+ Functions["ShuffleByKeys"] = &ShuffleByKeysWrapper;
Functions["Struct"] = &StructWrapper;
Functions["AddMember"] = &AddMemberWrapper;
Functions["RemoveMember"] = &RemoveMemberWrapper<false>;
@@ -11584,6 +11585,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
Functions["AggApply"] = &AggApplyWrapper;
Functions["AggApplyState"] = &AggApplyWrapper;
Functions["AggBlockApply"] = &AggBlockApplyWrapper;
+ Functions["AggBlockApplyState"] = &AggBlockApplyWrapper;
Functions["WinOnRows"] = &WinOnWrapper;
Functions["WinOnGroups"] = &WinOnWrapper;
Functions["WinOnRange"] = &WinOnWrapper;
@@ -11792,6 +11794,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
ExtFunctions["BlockBitCast"] = &BlockBitCastWrapper;
ExtFunctions["BlockCombineAll"] = &BlockCombineAllWrapper;
ExtFunctions["BlockCombineHashed"] = &BlockCombineHashedWrapper;
+ ExtFunctions["BlockMergeFinalizeHashed"] = &BlockMergeFinalizeHashedWrapper;
Functions["AsRange"] = &AsRangeWrapper;
Functions["RangeCreate"] = &RangeCreateWrapper;
diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp
index 040dbb9e04..210d5c9533 100644
--- a/ydb/library/yql/core/type_ann/type_ann_list.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp
@@ -871,6 +871,74 @@ namespace {
return IGraphTransformer::TStatus::Ok;
}
+ IGraphTransformer::TStatus ShuffleByKeysWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
+ Y_UNUSED(output);
+ if (!EnsureArgsCount(*input, 3, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!EnsureListType(input->Head(), ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ auto itemType = input->Head().GetTypeAnn()->Cast<TListExprType>()->GetItemType();
+ auto& lambdaKeySelector = input->ChildRef(1);
+ auto status = ConvertToLambda(lambdaKeySelector, ctx.Expr, 1);
+ if (status.Level != IGraphTransformer::TStatus::Ok) {
+ return status;
+ }
+
+ if (!UpdateLambdaAllArgumentsTypes(lambdaKeySelector, {itemType}, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!lambdaKeySelector->GetTypeAnn()) {
+ return IGraphTransformer::TStatus::Repeat;
+ }
+
+ auto keyType = lambdaKeySelector->GetTypeAnn();
+ if (!EnsureHashableKey(lambdaKeySelector->Pos(), keyType, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!EnsureEquatableKey(lambdaKeySelector->Pos(), keyType, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ auto& lambdaHandler = input->ChildRef(2);
+ status = ConvertToLambda(lambdaHandler, ctx.Expr, 1);
+ if (status.Level != IGraphTransformer::TStatus::Ok) {
+ return status;
+ }
+
+ auto handlerStreamType = ctx.Expr.MakeType<TStreamExprType>(itemType);
+
+ if (!UpdateLambdaAllArgumentsTypes(lambdaHandler, { handlerStreamType }, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!lambdaHandler->GetTypeAnn()) {
+ return IGraphTransformer::TStatus::Repeat;
+ }
+
+ if (!EnsureSeqOrOptionalType(lambdaHandler->Pos(), *lambdaHandler->GetTypeAnn(), ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ auto retKind = lambdaHandler->GetTypeAnn()->GetKind();
+ const TTypeAnnotationNode* retItemType;
+ if (retKind == ETypeAnnotationKind::List) {
+ retItemType = lambdaHandler->GetTypeAnn()->Cast<TListExprType>()->GetItemType();
+ } else if (retKind == ETypeAnnotationKind::Optional) {
+ retItemType = lambdaHandler->GetTypeAnn()->Cast<TOptionalExprType>()->GetItemType();
+ } else {
+ retItemType = lambdaHandler->GetTypeAnn()->Cast<TStreamExprType>()->GetItemType();
+ }
+
+ input->SetTypeAnn(ctx.Expr.MakeType<TListExprType>(retItemType));
+ return IGraphTransformer::TStatus::Ok;
+ }
+
IGraphTransformer::TStatus FoldMapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
if (!EnsureArgsCount(*input, 3, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
@@ -5155,7 +5223,16 @@ namespace {
}
if (name == "count" || name == "count_all") {
- input->SetTypeAnn(ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64));
+ const TTypeAnnotationNode* retType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64);
+ if (overState) {
+ if (!IsSameAnnotation(*lambda->GetTypeAnn(), *retType)) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ TStringBuilder() << "Mismatch count type, expected: " << *lambda->GetTypeAnn() << ", but got: " << *retType));
+ return IGraphTransformer::TStatus::Error;
+ }
+ }
+
+ input->SetTypeAnn(retType);
} else if (name == "sum") {
const TTypeAnnotationNode* retType;
if (!GetSumResultType(input->Pos(), *lambda->GetTypeAnn(), retType, ctx.Expr)) {
@@ -5178,42 +5255,8 @@ namespace {
return IGraphTransformer::TStatus::Error;
}
} else {
- auto itemType = lambda->GetTypeAnn();
- if (IsNull(*itemType)) {
- retType = itemType;
- } else {
- bool isOptional = false;
- if (itemType->GetKind() == ETypeAnnotationKind::Optional) {
- isOptional = true;
- itemType = itemType->Cast<TOptionalExprType>()->GetItemType();
- }
-
- if (!EnsureTupleTypeSize(lambda->Pos(), itemType, 2, ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
- }
-
- auto tupleType = itemType->Cast<TTupleExprType>();
- auto sumType = tupleType->GetItems()[0];
- const TTypeAnnotationNode* sumTypeOut;
- if (!GetSumResultType(input->Pos(), *sumType, sumTypeOut, ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
- }
-
- if (!IsSameAnnotation(*sumType, *sumTypeOut)) {
- ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
- TStringBuilder() << "Mismatch sum type, expected: " << *sumType << ", but got: " << *sumTypeOut));
- return IGraphTransformer::TStatus::Error;
- }
-
- auto countType = tupleType->GetItems()[1];
- if (!EnsureSpecificDataType(lambda->Pos(), *countType, EDataSlot::Uint64, ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
- }
-
- retType = sumType;
- if (isOptional) {
- retType = ctx.Expr.MakeType<TOptionalExprType>(retType);
- }
+ if (!GetAvgResultTypeOverState(input->Pos(), *lambda->GetTypeAnn(), retType, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
}
}
@@ -5240,6 +5283,14 @@ namespace {
return IGraphTransformer::TStatus::Error;
}
+ if (overState) {
+ if (!IsSameAnnotation(*lambda->GetTypeAnn(), *retType)) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ TStringBuilder() << "Mismatch min/max type, expected: " << *lambda->GetTypeAnn() << ", but got: " << *retType));
+ return IGraphTransformer::TStatus::Error;
+ }
+ }
+
input->SetTypeAnn(retType);
} else {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
@@ -5252,6 +5303,7 @@ namespace {
IGraphTransformer::TStatus AggBlockApplyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
Y_UNUSED(output);
+ const bool overState = input->Content().EndsWith("State");
if (!EnsureMinArgsCount(*input, 1, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
@@ -5263,7 +5315,7 @@ namespace {
auto name = input->Child(0)->Content();
ui32 expectedArgs;
if (name == "count_all") {
- expectedArgs = 1;
+ expectedArgs = overState ? 2 : 1;
} else if (name == "count" || name == "sum" || name == "avg" || name == "min" || name == "max") {
expectedArgs = 2;
} else {
@@ -5283,7 +5335,17 @@ namespace {
}
if (name == "count_all" || name == "count") {
- input->SetTypeAnn(ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64));
+ const TTypeAnnotationNode* retType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64);
+ if (overState) {
+ auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
+ if (!IsSameAnnotation(*itemType, *retType)) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ TStringBuilder() << "Mismatch count type, expected: " << *itemType << ", but got: " << *retType));
+ return IGraphTransformer::TStatus::Error;
+ }
+ }
+
+ input->SetTypeAnn(retType);
} else if (name == "sum") {
auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
const TTypeAnnotationNode* retType;
@@ -5291,12 +5353,26 @@ namespace {
return IGraphTransformer::TStatus::Error;
}
+ if (overState) {
+ if (!IsSameAnnotation(*itemType, *retType)) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ TStringBuilder() << "Mismatch sum type, expected: " << *itemType << ", but got: " << *retType));
+ return IGraphTransformer::TStatus::Error;
+ }
+ }
+
input->SetTypeAnn(retType);
} else if (name == "avg") {
auto itemType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
const TTypeAnnotationNode* retType;
- if (!GetAvgResultType(input->Pos(), *itemType, retType, ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
+ if (!overState) {
+ if (!GetAvgResultType(input->Pos(), *itemType, retType, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+ } else {
+ if (!GetAvgResultTypeOverState(input->Pos(), *itemType, retType, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
}
input->SetTypeAnn(retType);
@@ -5307,6 +5383,14 @@ namespace {
return IGraphTransformer::TStatus::Error;
}
+ if (overState) {
+ if (!IsSameAnnotation(*itemType, *retType)) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ TStringBuilder() << "Mismatch min/max type, expected: " << *itemType << ", but got: " << *retType));
+ return IGraphTransformer::TStatus::Error;
+ }
+ }
+
input->SetTypeAnn(retType);
} else {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
diff --git a/ydb/library/yql/core/type_ann/type_ann_list.h b/ydb/library/yql/core/type_ann/type_ann_list.h
index afeaa54f7b..6e04740df3 100644
--- a/ydb/library/yql/core/type_ann/type_ann_list.h
+++ b/ydb/library/yql/core/type_ann/type_ann_list.h
@@ -17,6 +17,7 @@ namespace NTypeAnnImpl {
IGraphTransformer::TStatus MapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus MapNextWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus LMapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
+ IGraphTransformer::TStatus ShuffleByKeysWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
template <bool Warn>
IGraphTransformer::TStatus FlatMapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
template <bool Ordered>
diff --git a/ydb/library/yql/core/yql_aggregate_expander.cpp b/ydb/library/yql/core/yql_aggregate_expander.cpp
index 6ffce1fd3d..79f1dc8cc2 100644
--- a/ydb/library/yql/core/yql_aggregate_expander.cpp
+++ b/ydb/library/yql/core/yql_aggregate_expander.cpp
@@ -63,6 +63,13 @@ TExprNode::TPtr TAggregateExpander::ExpandAggregate()
return ret;
}
}
+
+ if (Suffix == "MergeFinalize") {
+ auto ret = TryGenerateBlockMergeFinalize();
+ if (ret) {
+ return ret;
+ }
+ }
}
if (!allTraitsCollected) {
@@ -492,14 +499,8 @@ TExprNode::TPtr TAggregateExpander::GetFinalAggStateExtractor(ui32 i) {
.Build();
}
-TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombineAllOrHashed() {
- if (!TypesCtx.ArrowResolver) {
- return nullptr;
- }
-
- const bool hashed = (KeyColumns->ChildrenSize() > 0);
-
- auto streamArg = Ctx.NewArgument(Node->Pos(), "stream");
+TExprNode::TPtr TAggregateExpander::MakeInputBlocks(const TExprNode::TPtr& streamArg, TExprNode::TListType& keyIdxs,
+ TVector<TString>& outputColumns, TExprNode::TListType& aggs, bool overState) {
auto flow = Ctx.NewCallable(Node->Pos(), "ToFlow", { streamArg });
TVector<TString> inputColumns;
for (ui32 i = 0; i < RowType->GetSize(); ++i) {
@@ -514,9 +515,6 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombineAllOrHashed() {
}
TExprNode::TListType extractorRoots;
- TExprNode::TListType aggs;
- TVector<TString> outputColumns;
- TExprNode::TListType keyIdxs;
TVector<const TTypeAnnotationNode*> allKeyTypes;
for (ui32 index = 0; index < KeyColumns->ChildrenSize(); ++index) {
auto keyName = KeyColumns->Child(index)->Content();
@@ -538,7 +536,7 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombineAllOrHashed() {
for (ui32 index = 0; index < AggregatedColumns->ChildrenSize(); ++index) {
auto trait = AggregatedColumns->Child(index)->ChildPtr(1);
- if (trait->Child(0)->Content() == "count_all") {
+ if (!overState && trait->Child(0)->Content() == "count_all") {
// 0 columns
aggs.push_back(Ctx.Builder(Node->Pos())
.List()
@@ -547,7 +545,8 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombineAllOrHashed() {
.Seal()
.Seal()
.Build());
- } else {
+ }
+ else {
// 1 column
auto root = trait->Child(2)->TailPtr();
auto rowArg = &trait->Child(2)->Head().Head();
@@ -575,7 +574,7 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombineAllOrHashed() {
aggs.push_back(Ctx.Builder(Node->Pos())
.List()
- .Callable(0, "AggBlockApply")
+ .Callable(0, TString("AggBlockApply") + (overState ? "State" : ""))
.Atom(0, trait->Child(0)->Content())
.Add(1, ExpandType(Node->Pos(), *trait->Child(2)->GetTypeAnn(), Ctx))
.Seal()
@@ -592,6 +591,25 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombineAllOrHashed() {
auto extractorLambda = Ctx.NewLambda(Node->Pos(), Ctx.NewArguments(Node->Pos(), std::move(extractorArgs)), std::move(extractorRoots));
auto mappedWideFlow = Ctx.NewCallable(Node->Pos(), "WideMap", { wideFlow, extractorLambda });
auto blocks = Ctx.NewCallable(Node->Pos(), "WideToBlocks", { mappedWideFlow });
+ return blocks;
+}
+
+TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombineAllOrHashed() {
+ if (!TypesCtx.ArrowResolver) {
+ return nullptr;
+ }
+
+ const bool hashed = (KeyColumns->ChildrenSize() > 0);
+
+ auto streamArg = Ctx.NewArgument(Node->Pos(), "stream");
+ TExprNode::TListType keyIdxs;
+ TVector<TString> outputColumns;
+ TExprNode::TListType aggs;
+ auto blocks = MakeInputBlocks(streamArg, keyIdxs, outputColumns, aggs, false);
+ if (!blocks) {
+ return nullptr;
+ }
+
TExprNode::TPtr aggWideFlow;
if (hashed) {
aggWideFlow = Ctx.Builder(Node->Pos())
@@ -2234,4 +2252,65 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombine() {
return TryGenerateBlockCombineAllOrHashed();
}
+TExprNode::TPtr TAggregateExpander::TryGenerateBlockMergeFinalize() {
+ if (UsePartitionsByKeys) {
+ return nullptr;
+ }
+
+ if (HaveSessionSetting || HaveDistinct) {
+ return nullptr;
+ }
+
+ for (const auto& x : AggregatedColumns->Children()) {
+ auto trait = x->ChildPtr(1);
+ if (!trait->IsCallable("AggApplyState")) {
+ return nullptr;
+ }
+ }
+
+ return TryGenerateBlockMergeFinalizeHashed();
+}
+
+TExprNode::TPtr TAggregateExpander::TryGenerateBlockMergeFinalizeHashed() {
+ if (!TypesCtx.ArrowResolver) {
+ return nullptr;
+ }
+
+ if (KeyColumns->ChildrenSize() == 0) {
+ return nullptr;
+ }
+
+ auto streamArg = Ctx.NewArgument(Node->Pos(), "stream");
+ TExprNode::TListType keyIdxs;
+ TVector<TString> outputColumns;
+ TExprNode::TListType aggs;
+ auto blocks = MakeInputBlocks(streamArg, keyIdxs, outputColumns, aggs, true);
+ if (!blocks) {
+ return nullptr;
+ }
+
+ auto aggWideFlow = Ctx.Builder(Node->Pos())
+ .Callable("WideFromBlocks")
+ .Callable(0, "BlockMergeFinalizeHashed")
+ .Add(0, blocks)
+ .Add(1, Ctx.NewList(Node->Pos(), std::move(keyIdxs)))
+ .Add(2, Ctx.NewList(Node->Pos(), std::move(aggs)))
+ .Seal()
+ .Seal()
+ .Build();
+
+ auto finalFlow = MakeNarrowMap(Node->Pos(), outputColumns, aggWideFlow, Ctx);
+ auto root = Ctx.NewCallable(Node->Pos(), "FromFlow", { finalFlow });
+ auto lambdaStream = Ctx.NewLambda(Node->Pos(), Ctx.NewArguments(Node->Pos(), { streamArg }), std::move(root));
+
+ auto keySelector = BuildKeySelector(Node->Pos(), *OriginalRowType, KeyColumns, Ctx);
+ return Ctx.Builder(Node->Pos())
+ .Callable("ShuffleByKeys")
+ .Add(0, AggList)
+ .Add(1, keySelector)
+ .Add(2, lambdaStream)
+ .Seal()
+ .Build();
+}
+
} // namespace NYql
diff --git a/ydb/library/yql/core/yql_aggregate_expander.h b/ydb/library/yql/core/yql_aggregate_expander.h
index 63f1cc9dfc..7695c7cf6b 100644
--- a/ydb/library/yql/core/yql_aggregate_expander.h
+++ b/ydb/library/yql/core/yql_aggregate_expander.h
@@ -8,12 +8,13 @@ namespace NYql {
class TAggregateExpander {
public:
- TAggregateExpander(bool allowPickle, const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx,
+ TAggregateExpander(bool allowPickle, bool usePartitionsByKeys, const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx,
bool forceCompact = false, bool compactForDistinct = false, bool usePhases = false)
: Node(node)
, Ctx(ctx)
, TypesCtx(typesCtx)
, AllowPickle(allowPickle)
+ , UsePartitionsByKeys(usePartitionsByKeys)
, ForceCompact(forceCompact)
, CompactForDistinct(compactForDistinct)
, UsePhases(usePhases)
@@ -76,7 +77,11 @@ private:
void GenerateInitForDistinct(TExprNodeBuilder& parent, ui32& ndx, const TIdxSet& indicies, const TExprNode::TPtr& distinctField);
TExprNode::TPtr GenerateJustOverStates(const TExprNode::TPtr& input, const TIdxSet& indicies);
TExprNode::TPtr TryGenerateBlockCombineAllOrHashed();
+ TExprNode::TPtr TryGenerateBlockMergeFinalizeHashed();
TExprNode::TPtr TryGenerateBlockCombine();
+ TExprNode::TPtr TryGenerateBlockMergeFinalize();
+ TExprNode::TPtr MakeInputBlocks(const TExprNode::TPtr& streamArg, TExprNode::TListType& keyIdxs,
+ TVector<TString>& outputColumns, TExprNode::TListType& aggs, bool overState);
private:
static constexpr TStringBuf SessionStartMemberName = "_yql_group_session_start";
@@ -85,6 +90,7 @@ private:
TExprContext& Ctx;
TTypeAnnotationContext& TypesCtx;
bool AllowPickle;
+ bool UsePartitionsByKeys;
bool ForceCompact;
bool CompactForDistinct;
bool UsePhases;
@@ -121,7 +127,7 @@ private:
};
inline TExprNode::TPtr ExpandAggregatePeephole(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx) {
- TAggregateExpander aggExpander(false, node, ctx, typesCtx, true);
+ TAggregateExpander aggExpander(false, true, node, ctx, typesCtx, true);
return aggExpander.ExpandAggregate();
}
diff --git a/ydb/library/yql/core/yql_expr_type_annotation.cpp b/ydb/library/yql/core/yql_expr_type_annotation.cpp
index 040ff5dd49..008dcc700e 100644
--- a/ydb/library/yql/core/yql_expr_type_annotation.cpp
+++ b/ydb/library/yql/core/yql_expr_type_annotation.cpp
@@ -5405,10 +5405,10 @@ const TTypeAnnotationNode* AggApplySerializedStateType(const TExprNode::TPtr& in
}
}
-bool GetSumResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx) {
+bool GetSumResultType(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx) {
bool isOptional;
const TDataExprType* lambdaType;
- if(IsDataOrOptionalOfData(&itemType, isOptional, lambdaType)) {
+ if(IsDataOrOptionalOfData(&inputType, isOptional, lambdaType)) {
auto lambdaTypeSlot = lambdaType->GetSlot();
const TTypeAnnotationNode *sumResultType = nullptr;
if (IsDataTypeSigned(lambdaTypeSlot)) {
@@ -5432,28 +5432,28 @@ bool GetSumResultType(const TPositionHandle& pos, const TTypeAnnotationNode& ite
retType = sumResultType;
return true;
- } else if (IsNull(itemType)) {
+ } else if (IsNull(inputType)) {
retType = ctx.MakeType<TNullExprType>();
return true;
} else {
ctx.AddError(TIssue(ctx.GetPosition(pos),
- TStringBuilder() << "Unsupported type: " << FormatType(&itemType) << ". Expected Data or Optional of Data."));
+ TStringBuilder() << "Unsupported type: " << FormatType(&inputType) << ". Expected Data or Optional of Data."));
return false;
}
}
-bool GetAvgResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx) {
+bool GetAvgResultType(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx) {
bool isOptional;
const TDataExprType* lambdaType;
- if(IsDataOrOptionalOfData(&itemType, isOptional, lambdaType)) {
+ if(IsDataOrOptionalOfData(&inputType, isOptional, lambdaType)) {
auto lambdaTypeSlot = lambdaType->GetSlot();
const TTypeAnnotationNode *avgResultType = nullptr;
if (IsDataTypeNumeric(lambdaTypeSlot)) {
avgResultType = ctx.MakeType<TDataExprType>(EDataSlot::Double);
} else if (IsDataTypeDecimal(lambdaTypeSlot)) {
- avgResultType = &itemType;
+ avgResultType = &inputType;
} else if (IsDataTypeInterval(lambdaTypeSlot)) {
- avgResultType = &itemType;
+ avgResultType = &inputType;
} else {
ctx.AddError(TIssue(ctx.GetPosition(pos),
TStringBuilder() << "Unsupported column type: " << lambdaTypeSlot));
@@ -5466,23 +5466,65 @@ bool GetAvgResultType(const TPositionHandle& pos, const TTypeAnnotationNode& ite
retType = avgResultType;
return true;
- } else if (IsNull(itemType)) {
+ } else if (IsNull(inputType)) {
retType = ctx.MakeType<TNullExprType>();
return true;
} else {
ctx.AddError(TIssue(ctx.GetPosition(pos),
- TStringBuilder() << "Unsupported type: " << FormatType(&itemType) << ". Expected Data or Optional of Data."));
+ TStringBuilder() << "Unsupported type: " << FormatType(&inputType) << ". Expected Data or Optional of Data."));
return false;
}
}
-bool GetMinMaxResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx) {
- if (!itemType.IsComparable()) {
- ctx.AddError(TIssue(ctx.GetPosition(pos), TStringBuilder() << "Expected comparable type, but got: " << itemType));
+bool GetAvgResultTypeOverState(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx) {
+ if (IsNull(inputType)) {
+ retType = &inputType;
+ } else {
+ auto itemType = &inputType;
+ bool isOptional = false;
+ if (itemType->GetKind() == ETypeAnnotationKind::Optional) {
+ isOptional = true;
+ itemType = itemType->Cast<TOptionalExprType>()->GetItemType();
+ }
+
+ if (!EnsureTupleTypeSize(pos, itemType, 2, ctx)) {
+ return false;
+ }
+
+ auto tupleType = itemType->Cast<TTupleExprType>();
+ auto sumType = tupleType->GetItems()[0];
+ const TTypeAnnotationNode* sumTypeOut;
+ if (!GetSumResultType(pos, *sumType, sumTypeOut, ctx)) {
+ return false;
+ }
+
+ if (!IsSameAnnotation(*sumType, *sumTypeOut)) {
+ ctx.AddError(TIssue(ctx.GetPosition(pos),
+ TStringBuilder() << "Mismatch sum type, expected: " << *sumType << ", but got: " << *sumTypeOut));
+ return false;
+ }
+
+ auto countType = tupleType->GetItems()[1];
+ if (!EnsureSpecificDataType(pos, *countType, EDataSlot::Uint64, ctx)) {
+ return false;
+ }
+
+ retType = sumType;
+ if (isOptional) {
+ retType = ctx.MakeType<TOptionalExprType>(retType);
+ }
+ }
+
+ return true;
+}
+
+bool GetMinMaxResultType(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx) {
+ if (!inputType.IsComparable()) {
+ ctx.AddError(TIssue(ctx.GetPosition(pos), TStringBuilder() << "Expected comparable type, but got: " << inputType));
return false;
}
- retType = &itemType;
+ retType = &inputType;
return true;
}
diff --git a/ydb/library/yql/core/yql_expr_type_annotation.h b/ydb/library/yql/core/yql_expr_type_annotation.h
index 59a2ecaa35..e6a19bb723 100644
--- a/ydb/library/yql/core/yql_expr_type_annotation.h
+++ b/ydb/library/yql/core/yql_expr_type_annotation.h
@@ -302,8 +302,9 @@ bool EnsureBlockOrScalarType(TPositionHandle position, const TTypeAnnotationNode
const TTypeAnnotationNode* GetBlockItemType(const TTypeAnnotationNode& type, bool& isScalar);
const TTypeAnnotationNode* AggApplySerializedStateType(const TExprNode::TPtr& input, TExprContext& ctx);
-bool GetSumResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx);
-bool GetAvgResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx);
-bool GetMinMaxResultType(const TPositionHandle& pos, const TTypeAnnotationNode& itemType, const TTypeAnnotationNode*& retType, TExprContext& ctx);
+bool GetSumResultType(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx);
+bool GetAvgResultType(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx);
+bool GetAvgResultTypeOverState(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx);
+bool GetMinMaxResultType(const TPositionHandle& pos, const TTypeAnnotationNode& inputType, const TTypeAnnotationNode*& retType, TExprContext& ctx);
}
diff --git a/ydb/library/yql/dq/opt/dq_opt_log.cpp b/ydb/library/yql/dq/opt/dq_opt_log.cpp
index b8c97743e0..27803e6e54 100644
--- a/ydb/library/yql/dq/opt/dq_opt_log.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_log.cpp
@@ -18,7 +18,7 @@ TExprBase DqRewriteAggregate(TExprBase node, TExprContext& ctx, TTypeAnnotationC
return node;
}
- TAggregateExpander aggExpander(true, node.Ptr(), ctx, typesCtx, false, compactForDistinct, usePhases);
+ TAggregateExpander aggExpander(true, false, node.Ptr(), ctx, typesCtx, false, compactForDistinct, usePhases);
auto result = aggExpander.ExpandAggregate();
YQL_ENSURE(result);
diff --git a/ydb/library/yql/dq/opt/dq_opt_phy.cpp b/ydb/library/yql/dq/opt/dq_opt_phy.cpp
index 276ca6b748..72d8dab524 100644
--- a/ydb/library/yql/dq/opt/dq_opt_phy.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_phy.cpp
@@ -926,6 +926,100 @@ TExprBase DqBuildPartitionStage(TExprBase node, TExprContext& ctx, const TParent
return DqBuildPartitionsStageStub<TCoPartitionByKey>(std::move(node), ctx, parentsMap);
}
+TExprBase DqBuildShuffleStage(TExprBase node, TExprContext& ctx, const TParentsMap& parentsMap) {
+ auto shuffleInput = node.Maybe<TCoShuffleByKeys>().Input();
+ if (!shuffleInput.Maybe<TDqCnUnionAll>()) {
+ return node;
+ }
+
+ auto shuffle = node.Cast<TCoShuffleByKeys>();
+ if (!IsDqPureExpr(shuffle.KeySelectorLambda()) ||
+ !IsDqPureExpr(shuffle.ListHandlerLambda()))
+ {
+ return node;
+ }
+
+ auto dqUnion = shuffle.Input().Cast<TDqCnUnionAll>();
+
+ if (!IsSingleConsumerConnection(dqUnion, parentsMap)) {
+ return node;
+ }
+
+ auto keyLambda = shuffle.KeySelectorLambda();
+ TVector<TExprBase> keyElements;
+ if (auto maybeTuple = keyLambda.Body().Maybe<TExprList>()) {
+ auto tuple = maybeTuple.Cast();
+ for (const auto& element : tuple) {
+ keyElements.push_back(element);
+ }
+ } else {
+ keyElements.push_back(keyLambda.Body());
+ }
+
+ TVector<TCoAtom> keyColumns;
+ keyColumns.reserve(keyElements.size());
+ for (auto& element : keyElements) {
+ if (!element.Maybe<TCoMember>()) {
+ return node;
+ }
+
+ auto member = element.Cast<TCoMember>();
+ if (member.Struct().Raw() != keyLambda.Args().Arg(0).Raw()) {
+ return node;
+ }
+
+ keyColumns.push_back(member.Name());
+ }
+
+ if (keyColumns.empty()) {
+ return node;
+ }
+
+ auto connection = Build<TDqCnHashShuffle>(ctx, node.Pos())
+ .Output()
+ .Stage(dqUnion.Output().Stage())
+ .Index(dqUnion.Output().Index())
+ .Build()
+ .KeyColumns()
+ .Add(keyColumns)
+ .Build()
+ .Done();
+
+ TCoArgument programArg = Build<TCoArgument>(ctx, node.Pos())
+ .Name("arg")
+ .Done();
+
+ TVector<TCoArgument> inputArgs;
+ TVector<TExprBase> inputConns;
+
+ inputConns.push_back(connection);
+ inputArgs.push_back(programArg);
+
+ auto handler = shuffle.ListHandlerLambda();
+
+ auto shuffleStage = Build<TDqStage>(ctx, node.Pos())
+ .Inputs()
+ .Add(inputConns)
+ .Build()
+ .Program()
+ .Args(inputArgs)
+ .Body<TCoToStream>()
+ .Input<TExprApplier>()
+ .Apply(handler)
+ .With(handler.Args().Arg(0), programArg)
+ .Build()
+ .Build()
+ .Build()
+ .Settings(TDqStageSettings().BuildNode(ctx, node.Pos()))
+ .Done();
+
+ return Build<TDqCnUnionAll>(ctx, node.Pos())
+ .Output()
+ .Stage(shuffleStage)
+ .Index().Build("0")
+ .Build()
+ .Done();
+}
/*
* Optimizer rule which handles a switch to scalar expression context for aggregation results.
diff --git a/ydb/library/yql/dq/opt/dq_opt_phy.h b/ydb/library/yql/dq/opt/dq_opt_phy.h
index e0b26e9429..482277bbe8 100644
--- a/ydb/library/yql/dq/opt/dq_opt_phy.h
+++ b/ydb/library/yql/dq/opt/dq_opt_phy.h
@@ -46,6 +46,8 @@ NNodes::TExprBase DqBuildPartitionsStage(NNodes::TExprBase node, TExprContext& c
NNodes::TExprBase DqBuildPartitionStage(NNodes::TExprBase node, TExprContext& ctx, const TParentsMap& parentsMap);
+NNodes::TExprBase DqBuildShuffleStage(NNodes::TExprBase node, TExprContext& ctx, const TParentsMap& parentsMap);
+
NNodes::TExprBase DqBuildAggregationResultStage(NNodes::TExprBase node, TExprContext& ctx,
IOptimizationContext& optCtx);
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h b/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h
index 8f3744bdc3..a37b84c33b 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_rh_hash.h
@@ -4,6 +4,8 @@
#include <util/generic/yexception.h>
#include <vector>
+#include <util/digest/city.h>
+
namespace NKikimr {
namespace NMiniKQL {
@@ -17,6 +19,7 @@ protected:
explicit TRobinHoodHashBase(ui64 initialCapacity = 1u << 8)
: Capacity(initialCapacity)
+ , SelfHash(GetSelfHash(this))
{
Y_ENSURE((Capacity & (Capacity - 1)) == 0);
}
@@ -91,7 +94,7 @@ public:
private:
Y_FORCE_INLINE char* InsertImpl(TKey key, bool& isNew, ui64 capacity, TVec& data) {
isNew = false;
- ui64 bucket = THash()(key) & (capacity - 1);
+ ui64 bucket = (SelfHash ^ THash()(key)) & (capacity - 1);
char* ptr = data.data() + AsDeriv().GetCellSize() * bucket;
TPSLStorage distance = 0;
char* returnPtr;
@@ -168,6 +171,12 @@ private:
ptr = (ptr == data.data() + data.size()) ? data.data() : ptr;
}
+ static ui64 GetSelfHash(void* self) {
+ char buf[sizeof(void*)];
+ *(void**)buf = self;
+ return CityHash64(buf, sizeof(buf));
+ }
+
protected:
void Init() {
Allocate(Capacity, Data);
@@ -195,6 +204,7 @@ private:
ui64 Size = 0;
ui64 Capacity;
TVec Data;
+ const ui64 SelfHash;
};
template <typename TKey, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>>
diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
index 321e345287..dff429c515 100644
--- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
+++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
@@ -2453,6 +2453,28 @@ TMkqlCommonCallableCompiler::TShared::TShared() {
return ctx.ProgramBuilder.BlockCombineHashed(arg, filterColumn, keys, aggs, returnType);
});
+ AddCallable("BlockMergeFinalizeHashed", [](const TExprNode& node, TMkqlBuildContext& ctx) {
+ auto arg = MkqlBuildExpr(*node.Child(0), ctx);
+ TVector<ui32> keys;
+ for (const auto& key : node.Child(1)->Children()) {
+ keys.push_back(FromString<ui32>(key->Content()));
+ }
+
+ TVector<TAggInfo> aggs;
+ for (const auto& agg : node.Child(2)->Children()) {
+ TAggInfo info;
+ info.Name = TString(agg->Head().Head().Content());
+ for (ui32 i = 1; i < agg->ChildrenSize(); ++i) {
+ info.ArgsColumns.push_back(FromString<ui32>(agg->Child(i)->Content()));
+ }
+
+ aggs.push_back(info);
+ }
+
+ auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder);
+ return ctx.ProgramBuilder.BlockMergeFinalizeHashed(arg, keys, aggs, returnType);
+ });
+
AddCallable("BlockCompress", [](const TExprNode& node, TMkqlBuildContext& ctx) {
const auto flow = MkqlBuildExpr(node.Head(), ctx);
const auto index = FromString<ui32>(node.Child(1)->Content());
diff --git a/ydb/library/yql/providers/dq/opt/physical_optimize.cpp b/ydb/library/yql/providers/dq/opt/physical_optimize.cpp
index e0c6f75cab..f89ed69a64 100644
--- a/ydb/library/yql/providers/dq/opt/physical_optimize.cpp
+++ b/ydb/library/yql/providers/dq/opt/physical_optimize.cpp
@@ -33,6 +33,7 @@ public:
AddHandler(0, &TCoFlatMapBase::Match, HNDL(BuildFlatmapStage<false>));
AddHandler(0, &TCoCombineByKey::Match, HNDL(PushCombineToStage<false>));
AddHandler(0, &TCoPartitionsByKeys::Match, HNDL(BuildPartitionsStage));
+ AddHandler(0, &TCoShuffleByKeys::Match, HNDL(BuildShuffleStage));
AddHandler(0, &TCoPartitionByKey::Match, HNDL(BuildPartitionStage));
AddHandler(0, &TCoAsList::Match, HNDL(BuildAggregationResultStage));
AddHandler(0, &TCoTopSort::Match, HNDL(BuildTopSortStage<false>));
@@ -272,6 +273,10 @@ protected:
return DqBuildPartitionStage(node, ctx, *getParents());
}
+ TMaybeNode<TExprBase> BuildShuffleStage(TExprBase node, TExprContext& ctx, const TGetParents& getParents) {
+ return DqBuildShuffleStage(node, ctx, *getParents());
+ }
+
TMaybeNode<TExprBase> BuildAggregationResultStage(TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx) {
return DqBuildAggregationResultStage(node, ctx, optCtx);
}