aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/core/arrow_kernels/registry/registry.cpp
blob: 4a4e51fa57a2d30e629d027b516d2726554975f1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include "registry.h"
#include <yql/essentials/minikql/mkql_node_cast.h>
#include <yql/essentials/minikql/mkql_node_serialization.h>

#include <memory>

namespace NYql {

namespace {
    class TLoader : std::enable_shared_from_this<TLoader> {
    public:
        TLoader()
            : Alloc_(__LOCATION__)
            , Env_(Alloc_)
        {
            Alloc_.Release();
        }

        void Init(const TString& serialized,
            const NKikimr::NMiniKQL::IFunctionRegistry& functionRegistry,
            const NKikimr::NMiniKQL::TComputationNodeFactory& nodeFactory) {
            TGuard<NKikimr::NMiniKQL::TScopedAlloc> allocGuard(Alloc_);
            Pgm_ = NKikimr::NMiniKQL::DeserializeRuntimeNode(serialized, Env_);
            auto pgmTop = AS_CALLABLE("BlockAsTuple", Pgm_);
            MKQL_ENSURE(pgmTop->GetInputsCount() == 2, "Expected tuple of 2 items");
            auto argsNode = pgmTop->GetInput(0);
            MKQL_ENSURE(!argsNode.IsImmediate() && argsNode.GetNode()->GetType()->IsCallable(), "Expected callable");
            auto argsCallable = static_cast<NKikimr::NMiniKQL::TCallable*>(argsNode.GetNode());

            Explorer_.Walk(Pgm_.GetNode(), Env_);
            NKikimr::NMiniKQL::TComputationPatternOpts opts(Alloc_.Ref(), Env_, nodeFactory,
                &functionRegistry, NUdf::EValidateMode::None, NUdf::EValidatePolicy::Exception, "OFF", NKikimr::NMiniKQL::EGraphPerProcess::Multi);
            std::vector<NKikimr::NMiniKQL::TNode*> entryPoints;
            if (argsCallable->GetType()->GetName() == "BlockAsTuple") {
                for (ui32 i = 0; i < argsCallable->GetInputsCount(); ++i) {
                    entryPoints.emplace_back(argsCallable->GetInput(i).GetNode());
                }
            }

            Alloc_.Ref().UseRefLocking = true;
            Pattern_ = NKikimr::NMiniKQL::MakeComputationPattern(Explorer_, Pgm_, entryPoints, opts);
            RandomProvider_ = CreateDefaultRandomProvider();
            TimeProvider_ = CreateDefaultTimeProvider();

            Graph_ = Pattern_->Clone(opts.ToComputationOptions(*RandomProvider_, *TimeProvider_));
            NKikimr::NMiniKQL::TBindTerminator terminator(Graph_->GetTerminator());
            Topology_ = Graph_->GetKernelsTopology();
            MKQL_ENSURE(Topology_->Items.size() >= 3, "Expected at least 3 kernels");
        }

        ~TLoader() {
            Alloc_.Acquire();
        }

        ui32 GetKernelsCount() const {
            return Topology_->Items.size() - 3;
        }

        const arrow::compute::ScalarKernel* GetKernel(ui32 index) const {
            MKQL_ENSURE(index < Topology_->Items.size() - 3, "Bad kernel index");
            return &Topology_->Items[index].Node->GetArrowKernel();
        }

    private:
        NKikimr::NMiniKQL::TScopedAlloc Alloc_;
        NKikimr::NMiniKQL::TTypeEnvironment Env_;
        NKikimr::NMiniKQL::TRuntimeNode Pgm_;
        NKikimr::NMiniKQL::TExploringNodeVisitor Explorer_;
        NKikimr::NMiniKQL::IComputationPattern::TPtr Pattern_;
        TIntrusivePtr<IRandomProvider> RandomProvider_;
        TIntrusivePtr<ITimeProvider> TimeProvider_;
        THolder<NKikimr::NMiniKQL::IComputationGraph> Graph_;
        const NKikimr::NMiniKQL::TArrowKernelsTopology* Topology_;
    };
}

std::vector<std::shared_ptr<const arrow::compute::ScalarKernel>> LoadKernels(const TString& serialized,
    const NKikimr::NMiniKQL::IFunctionRegistry& functionRegistry,
    const NKikimr::NMiniKQL::TComputationNodeFactory& nodeFactory) {
    auto loader = std::make_shared<TLoader>();
    loader->Init(serialized, functionRegistry, nodeFactory);
    std::vector<std::shared_ptr<const arrow::compute::ScalarKernel>> ret(loader->GetKernelsCount());
    auto deleter = [loader](const arrow::compute::ScalarKernel*) {};
    for (ui32 i = 0; i < ret.size(); ++i) {
        ret[i] = std::shared_ptr<const arrow::compute::ScalarKernel>(loader->GetKernel(ret.size() - 1 - i), deleter);
    }

    return ret;
}

}