summary | refs | log | tree | commit | diff | stats
path: root/yql/essentials/minikql/comp_nodes/mkql_wide_map.cpp
diff options
context:
space:
mode:
author: atarasov5 <[email protected]> 2025-08-01 12:50:59 +0300
committer: atarasov5 <[email protected]> 2025-08-01 13:19:39 +0300
commit: dd74f77fb65e154f13376538c07dc908ac55cc3b (patch)
tree: dff76c3080e7b59b0008b13729bfcafedfec9d84 /yql/essentials/minikql/comp_nodes/mkql_wide_map.cpp
parent: 0cb0942d9ea385bc978073f3b4ea866052f2b73e (diff)
YQL-20229: Add WideMap stream overload
commit_hash:297647045a9ca9c90137f0ec6488181f81fe2447
Diffstat (limited to 'yql/essentials/minikql/comp_nodes/mkql_wide_map.cpp')
-rw-r--r-- yql/essentials/minikql/comp_nodes/mkql_wide_map.cpp 152
1 files changed, 136 insertions, 16 deletions
diff --git a/yql/essentials/minikql/comp_nodes/mkql_wide_map.cpp b/yql/essentials/minikql/comp_nodes/mkql_wide_map.cpp
index e970cd3f05a..f4a7e6ce34f 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_wide_map.cpp
+++ b/yql/essentials/minikql/comp_nodes/mkql_wide_map.cpp
@@ -10,10 +10,10 @@ using NYql::EnsureDynamicCast;
namespace {
-class TWideMapWrapper : public TStatelessWideFlowCodegeneratorNode<TWideMapWrapper> {
-using TBaseComputation = TStatelessWideFlowCodegeneratorNode<TWideMapWrapper>;
+class TWideMapFlowWrapper : public TStatelessWideFlowCodegeneratorNode<TWideMapFlowWrapper> {
+using TBaseComputation = TStatelessWideFlowCodegeneratorNode<TWideMapFlowWrapper>;
public:
- TWideMapWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TComputationExternalNodePtrVector&& items, TComputationNodePtrVector&& newItems)
+ TWideMapFlowWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TComputationExternalNodePtrVector&& items, TComputationNodePtrVector&& newItems)
: TBaseComputation(flow)
, Flow(flow)
, Items(std::move(items))
@@ -89,8 +89,8 @@ public:
private:
void RegisterDependencies() const final {
if (const auto flow = FlowDependsOn(Flow)) {
- std::for_each(Items.cbegin(), Items.cend(), std::bind(&TWideMapWrapper::Own, flow, std::placeholders::_1));
- std::for_each(NewItems.cbegin(), NewItems.cend(), std::bind(&TWideMapWrapper::DependsOn, flow, std::placeholders::_1));
+ std::for_each(Items.cbegin(), Items.cend(), std::bind(&TWideMapFlowWrapper::Own, flow, std::placeholders::_1));
+ std::for_each(NewItems.cbegin(), NewItems.cend(), std::bind(&TWideMapFlowWrapper::DependsOn, flow, std::placeholders::_1));
}
}
@@ -102,25 +102,145 @@ private:
const ui32 WideFieldsIndex;
};
+class TWideMapStreamWrapper: public TMutableComputationNode<TWideMapStreamWrapper> { // WideMap node for a stream input: DoCalculate() yields a TStreamValue that maps every wide row fetched from Stream.
+ using TBaseComputation = TMutableComputationNode<TWideMapStreamWrapper>;
+
+public:
+ TWideMapStreamWrapper(TComputationMutables& mutables, IComputationNode* stream, TComputationExternalNodePtrVector&& items, TComputationNodePtrVector&& newItems) // stream: upstream node; items: external nodes receiving the input columns; newItems: nodes producing the output columns.
+ : TBaseComputation(mutables)
+ , Stream(stream)
+ , Items(std::move(items))
+ , NewItems(std::move(newItems))
+ , PasstroughtMap(GetPasstroughtMapOneToOne(Items, NewItems)) // Indexed by input item; when set, holds the output slot that receives that item as-is (see its use in WideFetch).
+ , ReversePasstroughtMap(GetPasstroughtMapOneToOne(NewItems, Items)) // Indexed by output item; when set, holds the input item it is taken from.
+ , WideFieldsIndex(mutables.IncrementWideFieldsIndex(Items.size())) // NOTE(review): the returned index is stored but not read anywhere else in this class.
+ {
+ }
+
+ NYql::NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { // Creates a new TStreamValue around the current value of Stream, handing it this node's item vectors and passthrough maps.
+ return ctx.HolderFactory.Create<TStreamValue>(
+ ctx,
+ ctx.HolderFactory,
+ Stream->GetValue(ctx),
+ Items,
+ NewItems,
+ PasstroughtMap,
+ ReversePasstroughtMap);
+ }
+
+private:
+ class TStreamValue: public TComputationValue<TStreamValue> { // Stream value whose WideFetch() pulls one row from the wrapped stream and writes the mapped row to the caller's buffer.
+ using TBase = TComputationValue<TStreamValue>;
+
+ public:
+ TStreamValue(TMemoryUsageInfo* memInfo,
+ TComputationContext& compCtx,
+ const THolderFactory& holderFactory,
+ NYql::NUdf::TUnboxedValue&& stream,
+ const TComputationExternalNodePtrVector& items,
+ const TComputationNodePtrVector& newItems,
+ TPassthroughSpan passtroughtMap,
+ TPassthroughSpan reversePasstroughtMap)
+ : TBase(memInfo)
+ , CompCtx(compCtx)
+ , HolderFactory(holderFactory)
+ , Stream(std::move(stream))
+ , Items(items)
+ , NewItems(newItems)
+ , PasstroughtMap(std::move(passtroughtMap))
+ , ReversePasstroughtMap(std::move(reversePasstroughtMap))
+ {
+ State.resize(Items.size()); // One buffer slot per input column; reused as the upstream fetch target on every WideFetch() call.
+ Y_UNUSED(HolderFactory); // HolderFactory is stored but has no other use in this class.
+ }
+
+ NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) final { // Fetches one input row into State, then fills output[0 .. NewItems.size()); any non-Ok upstream status is returned unchanged.
+ Y_UNUSED(width); // width is ignored: the upstream fetch uses State.size() and the output loop uses NewItems.size().
+ if (const auto result = Stream.WideFetch(State.data(), State.size()); NUdf::EFetchStatus::Ok != result) {
+ return result;
+ }
+
+ for (auto i = 0U; i < Items.size(); ++i) {
+ if (const auto& map = PasstroughtMap[i]; map && !Items[i]->GetDependencesCount()) { // Passed-through item with zero dependences count: copy straight into its output slot, bypassing the external node.
+ output[*map] = State[i];
+ } else {
+ Items[i]->RefValue(CompCtx) = State[i]; // Otherwise store the fetched value into the external node's slot in CompCtx.
+ }
+ }
+
+ for (auto i = 0U; i < NewItems.size(); ++i) {
+ if (const auto& map = ReversePasstroughtMap[i]) {
+ if (const auto from = *map; !Items[from]->GetDependencesCount()) { // Same condition as in the first loop, so the source item was already written directly to the output there.
+ if (const auto first = *PasstroughtMap[from]; first != i) { // first is the slot filled by the first loop; a different output slot fed by the same item copies from it.
+ output[i] = output[first];
+ }
+ continue;
+ }
+ }
+
+ output[i] = NewItems[i]->GetValue(CompCtx); // Not a direct passthrough: evaluate the output node.
+ }
+ return NUdf::EFetchStatus::Ok;
+ }
+
+ private:
+ TComputationContext& CompCtx; // NOTE(review): CompCtx, Items and NewItems are held by reference; presumably the owning node and context outlive this value — confirm.
+ const THolderFactory& HolderFactory;
+ NUdf::TUnboxedValue Stream;
+ const TComputationExternalNodePtrVector& Items;
+ const TComputationNodePtrVector& NewItems;
+
+ const TPassthroughSpan PasstroughtMap; // presumably non-owning views over the wrapper's TPasstroughtMap members — TODO confirm TPassthroughSpan definition.
+ const TPassthroughSpan ReversePasstroughtMap;
+ TUnboxedValueVector State; // Buffer for the most recently fetched input row.
+ };
+
+ void RegisterDependencies() const final { // Adds this node as a dependence of Stream, then calls Own() on every item and DependsOn() on every new item.
+ Stream->AddDependence(this);
+ std::for_each(Items.cbegin(), Items.cend(), std::bind(&TWideMapStreamWrapper::Own, this, std::placeholders::_1));
+ std::for_each(NewItems.cbegin(), NewItems.cend(), std::bind(&TWideMapStreamWrapper::DependsOn, this, std::placeholders::_1));
+ }
+
+ IComputationNode* const Stream;
+ const TComputationExternalNodePtrVector Items;
+ const TComputationNodePtrVector NewItems;
+ const TPasstroughtMap PasstroughtMap;
+ const TPasstroughtMap ReversePasstroughtMap;
+
+ const ui32 WideFieldsIndex;
+};
}
IComputationNode* WrapWideMap(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
MKQL_ENSURE(callable.GetInputsCount() > 0U, "Expected argument.");
- const auto inputWidth = GetWideComponentsCount(AS_TYPE(TFlowType, callable.GetInput(0U).GetStaticType()));
- const auto outputWidth = GetWideComponentsCount(AS_TYPE(TFlowType, callable.GetType()->GetReturnType()));
+ MKQL_ENSURE(callable.GetInput(0U).GetStaticType()->IsFlow() || callable.GetInput(0U).GetStaticType()->IsStream(),
+ "Expected stream or flow for input.");
+
+ const auto inputWidth = GetWideComponentsCount(callable.GetInput(0U).GetStaticType());
+ const auto outputWidth = GetWideComponentsCount(callable.GetType()->GetReturnType());
+
+ if (callable.GetInput(0U).GetStaticType()->IsFlow()) {
+ MKQL_ENSURE(callable.GetType()->GetReturnType()->IsFlow(), "Expected flow return type.");
+ } else {
+ MKQL_ENSURE(callable.GetType()->GetReturnType()->IsStream(), "Expected stream return type.");
+ }
+
MKQL_ENSURE(callable.GetInputsCount() == inputWidth + outputWidth + 1U, "Wrong signature.");
- const auto flow = LocateNode(ctx.NodeLocator, callable, 0U);
- if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) {
- TComputationNodePtrVector newItems(outputWidth, nullptr);
- ui32 index = inputWidth;
- std::generate(newItems.begin(), newItems.end(), [&](){ return LocateNode(ctx.NodeLocator, callable, ++index); });
+ const auto flowOrStream = LocateNode(ctx.NodeLocator, callable, 0U);
+ TComputationNodePtrVector newItems(outputWidth, nullptr);
+ ui32 index = inputWidth;
+ std::generate(newItems.begin(), newItems.end(), [&]() { return LocateNode(ctx.NodeLocator, callable, ++index); });
- TComputationExternalNodePtrVector args(inputWidth, nullptr);
- index = 0U;
- std::generate(args.begin(), args.end(), [&](){ return LocateExternalNode(ctx.NodeLocator, callable, ++index); });
+ TComputationExternalNodePtrVector args(inputWidth, nullptr);
+ index = 0U;
+ std::generate(args.begin(), args.end(), [&]() { return LocateExternalNode(ctx.NodeLocator, callable, ++index); });
- return new TWideMapWrapper(ctx.Mutables, wide, std::move(args), std::move(newItems));
+ if (const auto flow = dynamic_cast<IComputationWideFlowNode*>(flowOrStream)) {
+ return new TWideMapFlowWrapper(ctx.Mutables, flow, std::move(args), std::move(newItems));
+ } else {
+ auto* stream = flowOrStream;
+ return new TWideMapStreamWrapper(ctx.Mutables, stream, std::move(args), std::move(newItems));
}
THROW yexception() << "Expected wide flow.";