summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorwhcrc <[email protected]>2022-06-23 16:44:34 +0300
committerwhcrc <[email protected]>2022-06-23 16:44:34 +0300
commitdc91553267a208ae6cb2eba08557578bf75d764f (patch)
tree0abb741752458e929c69eea60ba8e6e2c66b8886
parentfb769e1561a6939b3033079c9f3a634bc0f390eb (diff)
YQL-14404: support multiple outputs in PROCESS table using udf in DQ
ref:b321ab3bb5e991fe5db33d2a557935112ade8497
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_build.cpp31
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_phy.cpp95
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_phy.h2
-rw-r--r--ydb/library/yql/dq/opt/dq_opt_phy_finalizing.cpp117
4 files changed, 205 insertions, 40 deletions
diff --git a/ydb/library/yql/dq/opt/dq_opt_build.cpp b/ydb/library/yql/dq/opt/dq_opt_build.cpp
index 600dd21b7d9..ffc27fa84f6 100644
--- a/ydb/library/yql/dq/opt/dq_opt_build.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_build.cpp
@@ -109,6 +109,37 @@ void MakeConsumerReplaces(
TNodeOnNodeOwnedMap& replaces,
TExprContext& ctx)
{
+ if (!dqStage.Program().Body().Maybe<TDqReplicate>()) {
+ for (ui32 i = 0; i < consumers.size(); ++i) {
+ TVector<TCoArgument> newArgs;
+ newArgs.reserve(dqStage.Inputs().Size());
+ TNodeOnNodeOwnedMap argsMap;
+ CollectArgsReplaces(dqStage, newArgs, argsMap, ctx);
+ auto newStage = Build<TDqStage>(ctx, dqStage.Pos())
+ .InitFrom(dqStage)
+ .Program()
+ .Args(newArgs)
+ .Body<TCoFlatMap>()
+ .Input(ctx.ReplaceNodes(dqStage.Program().Body().Ptr(), argsMap))
+ .Lambda()
+ .Args({"arg"})
+ .template Body<TCoGuess>()
+ .Variant("arg")
+ .Index(consumers[i].Index())
+ .Build()
+ .Build()
+ .Build()
+ .Build()
+ .Settings(TDqStageSettings().BuildNode(ctx, dqStage.Pos()))
+ .Done().Ptr();
+ auto newOutput = Build<TDqOutput>(ctx, dqStage.Pos())
+ .Stage(newStage)
+ .Index().Build(0)
+ .Done().Ptr();
+ replaces.emplace(consumers[i].Raw(), newOutput);
+ }
+ return;
+ }
auto replicate = dqStage.Program().Body().Cast<TDqReplicate>();
TVector<TExprBase> stageResults;
diff --git a/ydb/library/yql/dq/opt/dq_opt_phy.cpp b/ydb/library/yql/dq/opt/dq_opt_phy.cpp
index ecfe4392a13..ee825d8fa47 100644
--- a/ydb/library/yql/dq/opt/dq_opt_phy.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_phy.cpp
@@ -170,6 +170,10 @@ TExprBase DqPushMembersFilterToStage(TExprBase node, TExprContext& ctx, IOptimiz
return node;
}
+ if (auto connToPushableStage = DqBuildPushableStage(dqUnion, ctx)) {
+ return TExprBase(ctx.ChangeChild(*node.Raw(), TMembersFilter::idx_Input, std::move(connToPushableStage)));
+ }
+
auto lambda = Build<TCoLambda>(ctx, filter.Pos())
.Args({"stream"})
.template Body<TMembersFilter>()
@@ -272,6 +276,32 @@ TMaybeNode<TDqStage> DqPushLambdaToStage(const TDqStage& stage, const TCoAtom& o
return newStage;
}
+TExprNode::TPtr DqBuildPushableStage(const NNodes::TDqConnection& connection, TExprContext& ctx) {
+ auto stage = connection.Output().Stage().Cast<TDqStage>();
+ auto program = stage.Program();
+ if (GetStageOutputsCount(stage) < 2 || program.Body().Maybe<TDqReplicate>()) {
+ return {};
+ }
+
+ auto newStage = Build<TDqStage>(ctx, stage.Pos())
+ .Inputs()
+ .Add(connection)
+ .Build()
+ .Program()
+ .Args({"arg"})
+ .Body("arg")
+ .Build()
+ .Settings(TDqStageSettings().BuildNode(ctx, stage.Pos()))
+ .Done();
+
+ auto output = Build<TDqOutput>(ctx, connection.Pos())
+ .Stage(newStage)
+ .Index().Build(BuildAtom("0", connection.Output().Index().Pos(), ctx))
+ .Done();
+
+ return ctx.ChangeChild(connection.Ref(), TDqConnection::idx_Output, output.Ptr());
+}
+
TMaybeNode<TDqConnection> DqPushLambdaToStageUnionAll(const TDqConnection& connection, TCoLambda& lambda,
const TVector<TDqConnection>& lambdaInputs, TExprContext& ctx, IOptimizationContext& optCtx)
{
@@ -283,7 +313,7 @@ TMaybeNode<TDqConnection> DqPushLambdaToStageUnionAll(const TDqConnection& conne
auto output = Build<TDqOutput>(ctx, connection.Pos())
.Stage(newStage.Cast())
- .Index(connection.Output().Index())
+ .Index().Build(connection.Output().Index())
.Done();
return TDqConnection(ctx.ChangeChild(connection.Ref(), TDqConnection::idx_Output, output.Ptr()));
@@ -318,6 +348,10 @@ TExprBase DqBuildFlatmapStage(TExprBase node, TExprContext& ctx, IOptimizationCo
return node;
}
+ if (auto connToPushableStage = DqBuildPushableStage(dqUnion, ctx)) {
+ return TExprBase(ctx.ChangeChild(*node.Raw(), TCoFlatMapBase::idx_Input, std::move(connToPushableStage)));
+ }
+
auto lambda = TCoLambda(ctx.Builder(flatmap.Lambda().Pos())
.Lambda()
.Param("stream")
@@ -368,6 +402,10 @@ TExprBase DqPushBaseLMapToStage(TExprBase node, TExprContext& ctx, IOptimization
return node;
}
+ if (auto connToPushableStage = DqBuildPushableStage(dqUnion, ctx)) {
+ return TExprBase(ctx.ChangeChild(*node.Raw(), BaseLMap::idx_Input, std::move(connToPushableStage)));
+ }
+
auto lambda = Build<TCoLambda>(ctx, lmap.Lambda().Pos())
.Args({"stream"})
.template Body<TCoToStream>()
@@ -383,6 +421,28 @@ TExprBase DqPushBaseLMapToStage(TExprBase node, TExprContext& ctx, IOptimization
return node;
}
+ const TTypeAnnotationNode* lmapItemTy = GetSeqItemType(lmap.Ref().GetTypeAnn());
+ if (lmapItemTy->GetKind() == ETypeAnnotationKind::Variant) {
+ // preserve typing by Mux'ing several stage outputs into one
+ const auto variantItemTy = lmapItemTy->template Cast<TVariantExprType>();
+ const auto stageOutputNum = variantItemTy->GetUnderlyingType()->template Cast<TTupleExprType>()->GetSize();
+ TVector<TExprBase> muxParts;
+ muxParts.reserve(stageOutputNum);
+ for (auto i = 0U; i < stageOutputNum; i++) {
+ const auto muxPart = Build<TDqCnUnionAll>(ctx, lmap.Lambda().Pos())
+ .Output()
+ .Stage(result.Output().Stage().Cast())
+ .Index().Build(i)
+ .Build()
+ .Done();
+ muxParts.emplace_back(muxPart);
+ }
+ return Build<TCoMux>(ctx, result.Cast().Pos())
+ .template Input<TExprList>()
+ .Add(muxParts)
+ .Build()
+ .Done();
+ }
return result.Cast();
}
@@ -420,6 +480,10 @@ TExprBase DqPushCombineToStage(TExprBase node, TExprContext& ctx, IOptimizationC
return node;
}
+ if (auto connToPushableStage = DqBuildPushableStage(dqUnion, ctx)) {
+ return TExprBase(ctx.ChangeChild(*node.Raw(), TCoCombineByKey::idx_Input, std::move(connToPushableStage)));
+ }
+
auto lambda = Build<TCoLambda>(ctx, combine.Pos())
.Args({"stream"})
.Body<TCoCombineByKey>()
@@ -732,6 +796,10 @@ TExprBase DqBuildTopSortStage(TExprBase node, TExprContext& ctx, IOptimizationCo
return node;
}
+ if (auto connToPushableStage = DqBuildPushableStage(dqUnion, ctx)) {
+ return TExprBase(ctx.ChangeChild(*node.Raw(), TCoTopSort::idx_Input, std::move(connToPushableStage)));
+ }
+
auto result = dqUnion.Output().Stage().Program().Body();
auto sortKeySelector = topSort.KeySelectorLambda();
@@ -878,6 +946,9 @@ TExprBase DqBuildSortStage(TExprBase node, TExprContext& ctx, IOptimizationConte
TMaybeNode<TDqStage> outerStage;
if (canMerge && IsMergeConnectionApplicable(sortKeyTypes)) {
+ if (auto connToPushableStage = DqBuildPushableStage(dqUnion, ctx)) {
+ return TExprBase(ctx.ChangeChild(*node.Raw(), TCoSortBase::idx_Input, std::move(connToPushableStage)));
+ }
auto lambda = Build<TCoLambda>(ctx, sort.Pos())
.Args({"stream"})
.Body<TCoSort>()
@@ -998,6 +1069,10 @@ TExprBase DqBuildTakeStage(TExprBase node, TExprContext& ctx, IOptimizationConte
return node;
}
+ if (auto connToPushableStage = DqBuildPushableStage(dqUnion, ctx)) {
+ return TExprBase(ctx.ChangeChild(*node.Raw(), TCoTake::idx_Input, std::move(connToPushableStage)));
+ }
+
auto result = dqUnion.Output().Stage().Program().Body();
auto stage = dqUnion.Output().Stage();
@@ -1060,6 +1135,10 @@ TExprBase DqBuildTakeSkipStage(TExprBase node, TExprContext& ctx, IOptimizationC
return node;
}
+ if (auto connToPushableStage = DqBuildPushableStage(dqUnion, ctx)) {
+ return TExprBase(ctx.ChangeChild(*node.Raw(), TCoTake::idx_Input, std::move(connToPushableStage)));
+ }
+
auto lambda = Build<TCoLambda>(ctx, node.Pos())
.Args({"stream"})
.Body<TCoTake>()
@@ -1113,6 +1192,10 @@ TExprBase DqRewriteLengthOfStageOutput(TExprBase node, TExprContext& ctx, IOptim
return node;
}
+ if (auto connToPushableStage = DqBuildPushableStage(dqUnion, ctx)) {
+ return TExprBase(ctx.ChangeChild(*node.Raw(), TCoLength::idx_List, std::move(connToPushableStage)));
+ }
+
auto zero = Build<TCoUint64>(ctx, node.Pos())
.Literal().Build("0")
.Done();
@@ -1475,6 +1558,10 @@ TExprBase DqBuildHasItems(TExprBase node, TExprContext& ctx, IOptimizationContex
auto unionAll = hasItems.List().Cast<TDqCnUnionAll>();
+ if (auto connToPushableStage = DqBuildPushableStage(unionAll, ctx)) {
+ return TExprBase(ctx.ChangeChild(*node.Raw(), TCoHasItems::idx_List, std::move(connToPushableStage)));
+ }
+
// Add LIMIT 1 via Take
auto takeProgram = Build<TCoLambda>(ctx, node.Pos())
.Args({"take_arg"})
@@ -1566,6 +1653,12 @@ TExprBase DqBuildScalarPrecompute(TExprBase node, TExprContext& ctx, IOptimizati
if (!output.Stage().Maybe<TDqStage>()) {
return node;
}
+ if (auto connToPushableStage = DqBuildPushableStage(unionAll, ctx)) {
+ return TExprBase(ctx.ChangeChild(
+ *node.Raw(),
+ node.Maybe<TCoToOptional>() ? TCoToOptional::idx_List : TCoHead::idx_Input,
+ std::move(connToPushableStage)));
+ }
auto stage = output.Stage().Cast<TDqStage>();
diff --git a/ydb/library/yql/dq/opt/dq_opt_phy.h b/ydb/library/yql/dq/opt/dq_opt_phy.h
index de152bdf1b4..0349e358653 100644
--- a/ydb/library/yql/dq/opt/dq_opt_phy.h
+++ b/ydb/library/yql/dq/opt/dq_opt_phy.h
@@ -14,6 +14,8 @@ NNodes::TMaybeNode<NNodes::TDqStage> DqPushLambdaToStage(const NNodes::TDqStage
const NNodes::TCoAtom& outputIndex, NNodes::TCoLambda& lambda,
const TVector<NNodes::TDqConnection>& lambdaInputs, TExprContext& ctx, IOptimizationContext& optCtx);
+TExprNode::TPtr DqBuildPushableStage(const NNodes::TDqConnection& connection, TExprContext& ctx);
+
NNodes::TMaybeNode<NNodes::TDqConnection> DqPushLambdaToStageUnionAll(const NNodes::TDqConnection& connection, NNodes::TCoLambda& lambda,
const TVector<NNodes::TDqConnection>& lambdaInputs, TExprContext& ctx, IOptimizationContext& optCtx);
diff --git a/ydb/library/yql/dq/opt/dq_opt_phy_finalizing.cpp b/ydb/library/yql/dq/opt/dq_opt_phy_finalizing.cpp
index a4204becfe9..17e5f0678fb 100644
--- a/ydb/library/yql/dq/opt/dq_opt_phy_finalizing.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_phy_finalizing.cpp
@@ -1,4 +1,5 @@
#include "dq_opt_phy_finalizing.h"
+#include "ydb/library/yql/core/yql_opt_utils.h"
#include <ydb/library/yql/dq/type_ann/dq_type_ann.h>
#include <ydb/library/yql/utils/log/log.h>
@@ -12,10 +13,7 @@ using namespace NNodes;
namespace {
-// returns new DqStage and list of added output indexes
-std::pair<TDqStage, TVector<TCoAtom>> ReplicateStageOutput(const TDqStage& stage, const TCoAtom& indexAtom,
- const TVector<TCoLambda>& lambdas, TExprContext& ctx)
-{
+ui32 GetStageOutputsCount(const TDqStage& stage, const TCoAtom& indexAtom, TExprContext& ctx) {
auto result = stage.Program().Body();
auto resultType = result.Ref().GetTypeAnn();
@@ -35,6 +33,20 @@ std::pair<TDqStage, TVector<TCoAtom>> ReplicateStageOutput(const TDqStage& stage
} else {
outputsCount = 1;
}
+ return outputsCount;
+}
+
+// returns new DqStage and list of added output indexes
+std::pair<TDqStage, TVector<TCoAtom>> ReplicateStageOutput(const TDqStage& stage, const TCoAtom& indexAtom,
+ const TVector<TCoLambda>& lambdas, TExprContext& ctx)
+{
+ auto result = stage.Program().Body();
+ auto resultType = result.Ref().GetTypeAnn();
+
+ const TTypeAnnotationNode* resultItemType = GetSeqItemType(resultType);
+
+ ui32 index = FromString<ui32>(indexAtom.Value());
+ ui32 outputsCount = GetStageOutputsCount(stage, indexAtom, ctx);
YQL_CLOG(TRACE, CoreDq) << "replicate stage (#" << stage.Ref().UniqueId() << ", " << index << "), outputs: "
<< outputsCount << ", about to add " << lambdas.size() << " copies." << Endl << PrintDqStageOnly(stage, ctx);
@@ -259,6 +271,45 @@ TExprNode::TPtr ReplicateDqOutput(TExprNode::TPtr&& input, const TMultiUsedConne
return ctx.ReplaceNodes(std::move(input), replaces);
}
+TExprNode::TPtr ReplaceStageForConsumer(TDqStage newStage, const TExprNode* consumer, TExprNode::TPtr&& input,
+ TExprContext& ctx, bool skipFirstUsage, const TExprNode* dqConnection, const TVector<TCoAtom>& outputlIndices = {}) {
+ bool isStageConsumer = TMaybeNode<TDqStage>(consumer).IsValid();
+ auto consumerNode = isStageConsumer
+ ? TDqStage(consumer).Inputs().Raw()
+ : consumer;
+
+ ui32 usageIdx = 0;
+ TExprNode::TPtr newConsumer = ctx.ShallowCopy(*consumerNode);
+ for (size_t childIndex = 0; childIndex < newConsumer->ChildrenSize(); ++childIndex) {
+ TExprBase child(newConsumer->Child(childIndex));
+
+ if (child.Raw() == dqConnection) {
+ if (skipFirstUsage && usageIdx == 0) {
+ // Keep first (any of) usage as is.
+ skipFirstUsage = false;
+ continue;
+ }
+
+ const auto newIdx = outputlIndices.empty() ? BuildAtom("0", dqConnection->Pos(), ctx) : outputlIndices[usageIdx];
+ auto newOutput = Build<TDqOutput>(ctx, child.Pos())
+ .Stage(newStage)
+ .Index(newIdx)
+ .Done();
+
+ auto newConnection = ctx.ChangeChild(child.Ref(), TDqConnection::idx_Output, newOutput.Ptr());
+
+ newConsumer = ctx.ChangeChild(*newConsumer, childIndex, std::move(newConnection));
+ ++usageIdx;
+ }
+ }
+
+ if (isStageConsumer) {
+ newConsumer = ctx.ChangeChild(*consumer, TDqStage::idx_Inputs, std::move(newConsumer));
+ }
+
+ return ctx.ReplaceNode(std::move(input), *consumer, newConsumer);
+}
+
TExprNode::TPtr ReplicateDqConnection(TExprNode::TPtr&& input, const TMultiUsedConnection& muConnection,
TExprContext& ctx)
{
@@ -271,6 +322,28 @@ TExprNode::TPtr ReplicateDqConnection(TExprNode::TPtr&& input, const TMultiUsedC
auto& consumers = muConnection.Consumers;
YQL_ENSURE(consumers.size() > 1);
+ if (GetStageOutputsCount(dqStage, outputIndex, ctx) > 1 && !dqStage.Program().Body().Maybe<TDqReplicate>()) {
+ // create a stage with single output, which is used by multiple consumers
+ auto newStage = Build<TDqStage>(ctx, dqStage.Pos())
+ .Inputs()
+ .Add(muConnection.Connection)
+ .Build()
+ .Program()
+ .Args({"arg"})
+ .Body("arg")
+ .Build()
+ .Settings(TDqStageSettings().BuildNode(ctx, dqStage.Pos()))
+ .Done();
+ TNodeSet processed;
+ for (const auto& consumer : consumers) {
+ if (processed.contains(consumer)) {
+ continue;
+ }
+ processed.insert(consumer);
+ input = ReplaceStageForConsumer(newStage, consumer, std::move(input), ctx, /* skipFirstUsage = */ false, muConnection.Connection.Raw());
+ }
+ return input;
+ }
// NOTE: Only handle one consumer at a time, as there might be dependencies between them.
// Ensure stable order by processing connection with minimal ID
@@ -297,41 +370,7 @@ TExprNode::TPtr ReplicateDqConnection(TExprNode::TPtr&& input, const TMultiUsedC
auto [newStage, newAdditionalIndexes] = ReplicateStageOutput(dqStage, outputIndex, lambdas, ctx);
- bool isStageConsumer = TMaybeNode<TDqStage>(consumer).IsValid();
- auto consumerNode = isStageConsumer
- ? TDqStage(consumer).Inputs().Raw()
- : consumer;
-
- ui32 usageIdx = 0;
- bool skipUsage = isLastConsumer;
- TExprNode::TPtr newConsumer = ctx.ShallowCopy(*consumerNode);
- for (size_t childIndex = 0; childIndex < newConsumer->ChildrenSize(); ++childIndex) {
- TExprBase child(newConsumer->Child(childIndex));
-
- if (child.Raw() == muConnection.Connection.Raw()) {
- if (skipUsage && usageIdx == 0) {
- // Keep first (any of) usage as is.
- skipUsage = false;
- continue;
- }
-
- auto newOutput = Build<TDqOutput>(ctx, child.Pos())
- .Stage(newStage)
- .Index(newAdditionalIndexes[usageIdx])
- .Done();
-
- auto newConnection = ctx.ChangeChild(child.Ref(), TDqConnection::idx_Output, newOutput.Ptr());
-
- newConsumer = ctx.ChangeChild(*newConsumer, childIndex, std::move(newConnection));
- ++usageIdx;
- }
- }
-
- if (isStageConsumer) {
- newConsumer = ctx.ChangeChild(*consumer, TDqStage::idx_Inputs, std::move(newConsumer));
- }
-
- auto result = ctx.ReplaceNode(std::move(input), *consumer, newConsumer);
+ auto result = ReplaceStageForConsumer(newStage, consumer, std::move(input), ctx, /* skipFirstUsage = */ isLastConsumer, muConnection.Connection.Raw(), newAdditionalIndexes);
return ctx.ReplaceNode(std::move(result), dqStage.Ref(), newStage.Ptr());
}