diff options
| author | atarasov5 <[email protected]> | 2025-08-01 12:50:59 +0300 | 
|---|---|---|
| committer | atarasov5 <[email protected]> | 2025-08-01 13:19:39 +0300 | 
| commit | dd74f77fb65e154f13376538c07dc908ac55cc3b (patch) | |
| tree | dff76c3080e7b59b0008b13729bfcafedfec9d84 /yql/essentials/minikql/comp_nodes/mkql_wide_map.cpp | |
| parent | 0cb0942d9ea385bc978073f3b4ea866052f2b73e (diff) | |
YQL-20229: Add WideMap stream overload
commit_hash:297647045a9ca9c90137f0ec6488181f81fe2447
Diffstat (limited to 'yql/essentials/minikql/comp_nodes/mkql_wide_map.cpp')
| -rw-r--r-- | yql/essentials/minikql/comp_nodes/mkql_wide_map.cpp | 152 | 
1 files changed, 136 insertions, 16 deletions
| diff --git a/yql/essentials/minikql/comp_nodes/mkql_wide_map.cpp b/yql/essentials/minikql/comp_nodes/mkql_wide_map.cpp index e970cd3f05a..f4a7e6ce34f 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_wide_map.cpp +++ b/yql/essentials/minikql/comp_nodes/mkql_wide_map.cpp @@ -10,10 +10,10 @@ using NYql::EnsureDynamicCast;  namespace { -class TWideMapWrapper : public TStatelessWideFlowCodegeneratorNode<TWideMapWrapper> { -using TBaseComputation = TStatelessWideFlowCodegeneratorNode<TWideMapWrapper>; +class TWideMapFlowWrapper : public TStatelessWideFlowCodegeneratorNode<TWideMapFlowWrapper> { +using TBaseComputation = TStatelessWideFlowCodegeneratorNode<TWideMapFlowWrapper>;  public: -    TWideMapWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TComputationExternalNodePtrVector&& items, TComputationNodePtrVector&& newItems) +    TWideMapFlowWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TComputationExternalNodePtrVector&& items, TComputationNodePtrVector&& newItems)          : TBaseComputation(flow)          , Flow(flow)          , Items(std::move(items)) @@ -89,8 +89,8 @@ public:  private:      void RegisterDependencies() const final {          if (const auto flow = FlowDependsOn(Flow)) { -            std::for_each(Items.cbegin(), Items.cend(), std::bind(&TWideMapWrapper::Own, flow, std::placeholders::_1)); -            std::for_each(NewItems.cbegin(), NewItems.cend(), std::bind(&TWideMapWrapper::DependsOn, flow, std::placeholders::_1)); +            std::for_each(Items.cbegin(), Items.cend(), std::bind(&TWideMapFlowWrapper::Own, flow, std::placeholders::_1)); +            std::for_each(NewItems.cbegin(), NewItems.cend(), std::bind(&TWideMapFlowWrapper::DependsOn, flow, std::placeholders::_1));          }      } @@ -102,25 +102,145 @@ private:      const ui32 WideFieldsIndex;  }; +class TWideMapStreamWrapper: public TMutableComputationNode<TWideMapStreamWrapper> { +    using TBaseComputation = TMutableComputationNode<TWideMapStreamWrapper>; + +public: +    TWideMapStreamWrapper(TComputationMutables& mutables, IComputationNode* stream, TComputationExternalNodePtrVector&& items, TComputationNodePtrVector&& newItems) +        : TBaseComputation(mutables) +        , Stream(stream) +        , Items(std::move(items)) +        , NewItems(std::move(newItems)) +        , PasstroughtMap(GetPasstroughtMapOneToOne(Items, NewItems)) +        , ReversePasstroughtMap(GetPasstroughtMapOneToOne(NewItems, Items)) +        , WideFieldsIndex(mutables.IncrementWideFieldsIndex(Items.size())) +    { +    } + +    NYql::NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { +        return ctx.HolderFactory.Create<TStreamValue>( +                    ctx, +                     ctx.HolderFactory, +                     Stream->GetValue(ctx), +                     Items, +                     NewItems, +                     PasstroughtMap, +                     ReversePasstroughtMap); +    } + +private: +    class TStreamValue: public TComputationValue<TStreamValue> { +        using TBase = TComputationValue<TStreamValue>; + +    public: +        TStreamValue(TMemoryUsageInfo* memInfo, +                     TComputationContext& compCtx, +                     const THolderFactory& holderFactory, +                     NYql::NUdf::TUnboxedValue&& stream, +                     const TComputationExternalNodePtrVector& items, +                     const TComputationNodePtrVector& newItems, +                     TPassthroughSpan passtroughtMap, +                     TPassthroughSpan reversePasstroughtMap) +            : TBase(memInfo) +            , CompCtx(compCtx) +            , HolderFactory(holderFactory) +            , Stream(std::move(stream)) +            , Items(items) +            , NewItems(newItems) +            , PasstroughtMap(std::move(passtroughtMap)) +            , ReversePasstroughtMap(std::move(reversePasstroughtMap)) +        { +            State.resize(Items.size()); +            Y_UNUSED(HolderFactory); +        } + +        NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) final { +            Y_UNUSED(width); +            if (const auto result = Stream.WideFetch(State.data(), State.size()); NUdf::EFetchStatus::Ok != result) { +                return result; +            } + +            for (auto i = 0U; i < Items.size(); ++i) { +                if (const auto& map = PasstroughtMap[i]; map && !Items[i]->GetDependencesCount()) { +                    output[*map] = State[i]; +                } else { +                    Items[i]->RefValue(CompCtx) = State[i]; +                } +            } + +            for (auto i = 0U; i < NewItems.size(); ++i) { +                if (const auto& map = ReversePasstroughtMap[i]) { +                    if (const auto from = *map; !Items[from]->GetDependencesCount()) { +                        if (const auto first = *PasstroughtMap[from]; first != i) { +                            output[i] = output[first]; +                        } +                        continue; +                    } +                } + +                output[i] = NewItems[i]->GetValue(CompCtx); +            } +            return NUdf::EFetchStatus::Ok; +        } + +    private: +        TComputationContext& CompCtx; +        const THolderFactory& HolderFactory; +        NUdf::TUnboxedValue Stream; +        const TComputationExternalNodePtrVector& Items; +        const TComputationNodePtrVector& NewItems; + +        const TPassthroughSpan PasstroughtMap; +        const TPassthroughSpan ReversePasstroughtMap; +        TUnboxedValueVector State; +    }; + +    void RegisterDependencies() const final { +        Stream->AddDependence(this); +        std::for_each(Items.cbegin(), Items.cend(), std::bind(&TWideMapStreamWrapper::Own, this, std::placeholders::_1)); +        std::for_each(NewItems.cbegin(), NewItems.cend(), std::bind(&TWideMapStreamWrapper::DependsOn, this, std::placeholders::_1)); +    } + +    IComputationNode* const Stream; +    const TComputationExternalNodePtrVector Items; +    const TComputationNodePtrVector NewItems; +    const TPasstroughtMap PasstroughtMap; +    const TPasstroughtMap ReversePasstroughtMap; + +    const ui32 WideFieldsIndex; +};  }  IComputationNode* WrapWideMap(TCallable& callable, const TComputationNodeFactoryContext& ctx) {      MKQL_ENSURE(callable.GetInputsCount() > 0U, "Expected argument."); -    const auto inputWidth = GetWideComponentsCount(AS_TYPE(TFlowType, callable.GetInput(0U).GetStaticType())); -    const auto outputWidth = GetWideComponentsCount(AS_TYPE(TFlowType, callable.GetType()->GetReturnType())); +    MKQL_ENSURE(callable.GetInput(0U).GetStaticType()->IsFlow() || callable.GetInput(0U).GetStaticType()->IsStream(), +                "Expected stream or flow for input."); + +    const auto inputWidth = GetWideComponentsCount(callable.GetInput(0U).GetStaticType()); +    const auto outputWidth = GetWideComponentsCount(callable.GetType()->GetReturnType()); + +    if (callable.GetInput(0U).GetStaticType()->IsFlow()) { +        MKQL_ENSURE(callable.GetType()->GetReturnType()->IsFlow(), "Expected flow return type."); +    } else { +        MKQL_ENSURE(callable.GetType()->GetReturnType()->IsStream(), "Expected stream return type."); +    } +      MKQL_ENSURE(callable.GetInputsCount() == inputWidth + outputWidth + 1U, "Wrong signature."); -    const auto flow = LocateNode(ctx.NodeLocator, callable, 0U); -    if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) { -        TComputationNodePtrVector newItems(outputWidth, nullptr); -        ui32 index = inputWidth; -        std::generate(newItems.begin(), newItems.end(), [&](){ return LocateNode(ctx.NodeLocator, callable, ++index); }); +    const auto flowOrStream = LocateNode(ctx.NodeLocator, callable, 0U); +    TComputationNodePtrVector newItems(outputWidth, nullptr); +    ui32 index = inputWidth; +    std::generate(newItems.begin(), newItems.end(), [&]() { return LocateNode(ctx.NodeLocator, callable, ++index); }); -        TComputationExternalNodePtrVector args(inputWidth, nullptr); -        index = 0U; -        std::generate(args.begin(), args.end(), [&](){ return LocateExternalNode(ctx.NodeLocator, callable, ++index); }); +    TComputationExternalNodePtrVector args(inputWidth, nullptr); +    index = 0U; +    std::generate(args.begin(), args.end(), [&]() { return LocateExternalNode(ctx.NodeLocator, callable, ++index); }); -        return new TWideMapWrapper(ctx.Mutables, wide, std::move(args), std::move(newItems)); +    if (const auto flow = dynamic_cast<IComputationWideFlowNode*>(flowOrStream)) { +        return new TWideMapFlowWrapper(ctx.Mutables, flow, std::move(args), std::move(newItems)); +    } else { +        auto* stream = flowOrStream; +        return new TWideMapStreamWrapper(ctx.Mutables, stream, std::move(args), std::move(newItems));      }      THROW yexception() << "Expected wide flow."; | 
