aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/comp_nodes/mkql_callable.cpp
blob: 4f49d4e4fef5ee32cbbeac445e0bfdcb430ac45f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
#include "mkql_callable.h"
#include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h>  // Y_IGNORE
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/mkql_node_cast.h>

namespace NKikimr {
namespace NMiniKQL {

namespace {

/// Computation node that evaluates to a callable value.
///
/// The callable is described by two parts of the computation graph:
///   * ArgNodes   - external nodes that receive the call arguments, in order;
///   * ResultNode - the node whose value is returned from each call.
/// Invoking the produced value assigns the passed arguments to ArgNodes and then
/// evaluates ResultNode in the computation context that created the value. When
/// codegen is compiled in and enabled for the context, an LLVM-generated `Run`
/// function performs the same steps instead of the interpreted TValue.
class TCallableWrapper : public TCustomValueCodegeneratorNode<TCallableWrapper> {
    typedef TCustomValueCodegeneratorNode<TCallableWrapper> TBaseComputation;
private:
    /// Interpreted callable value.
    class TValue : public TComputationValue<TValue> {
    public:
        /// `argNodes` is stored by reference (see the ArgNodes member below).
        TValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, IComputationNode* resultNode,
            const TComputationExternalNodePtrVector& argNodes)
            : TComputationValue(memInfo)
            , CompCtx(compCtx)
            , ResultNode(resultNode)
            , ArgNodes(argNodes)
        {}

    private:
        /// Binds args[0], args[1], ... to the argument nodes in order, then computes the result.
        /// NOTE(review): `args` is assumed to hold at least ArgNodes.size() values; nothing here checks it.
        NUdf::TUnboxedValue Run(const NUdf::IValueBuilder*, const NUdf::TUnboxedValuePod* args) const override
        {
            for (const auto node : ArgNodes) {
                node->SetValue(CompCtx, NUdf::TUnboxedValuePod(*args++));
            }

            return ResultNode->GetValue(CompCtx);
        }

        TComputationContext& CompCtx;
        IComputationNode *const ResultNode;
        // Held by reference instead of copied. The vector is owned by the TCallableWrapper node
        // that creates this value, and this value already relies on that graph outliving it
        // (CompCtx is a reference, ResultNode a raw pointer into the same graph). The previous
        // by-value member cost a full vector copy on every DoCalculate call.
        const TComputationExternalNodePtrVector& ArgNodes;
    };

    /// Callable value backed by the LLVM-generated `Run` function.
    class TCodegenValue : public TComputationValue<TCodegenValue> {
    public:
        using TBase = TComputationValue<TCodegenValue>;

        /// Native signature of the generated function: (context, pointer to argument array) -> result.
        using TRunPtr = NUdf::TUnboxedValuePod (*)(TComputationContext*, const NUdf::TUnboxedValuePod*);

        TCodegenValue(TMemoryUsageInfo* memInfo, TRunPtr run, TComputationContext* ctx)
            : TBase(memInfo)
            , RunFunc(run)
            , Ctx(ctx)
        {}

    private:
        /// Forwards the call to the generated function with the captured context.
        NUdf::TUnboxedValue Run(const NUdf::IValueBuilder*, const NUdf::TUnboxedValuePod* args) const override {
            return RunFunc(Ctx, args);
        }

        const TRunPtr RunFunc;
        TComputationContext* const Ctx;
    };
public:
    /// @param mutables   Graph mutables passed through to the base node.
    /// @param resultNode Node computing the callable's result.
    /// @param argNodes   External nodes receiving the call arguments, in argument order.
    TCallableWrapper(TComputationMutables& mutables, IComputationNode* resultNode, TComputationExternalNodePtrVector&& argNodes)
        : TBaseComputation(mutables)
        , ResultNode(resultNode)
        , ArgNodes(std::move(argNodes))
    {
    }

    /// Produces the callable value. Uses the generated `Run` when LLVM execution is enabled
    /// for `ctx` and a native pointer was resolved; otherwise falls back to the interpreted TValue.
    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
#ifndef MKQL_DISABLE_CODEGEN
        if (ctx.ExecuteLLVM && Run)
            return ctx.HolderFactory.Create<TCodegenValue>(Run, &ctx);
#endif
        return ctx.HolderFactory.Create<TValue>(ctx, ResultNode, ArgNodes);
    }

private:
    /// Declares the argument nodes as owned by this node and the result node as a dependency.
    void RegisterDependencies() const final {
        for (const auto& arg : ArgNodes) {
            Own(arg);
        }

        DependsOn(ResultNode);
    }

#ifndef MKQL_DISABLE_CODEGEN
    /// Emits the LLVM `Run` function and exports its symbol.
    void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
        RunFunc = GenerateRun(codegen);
        codegen.ExportSymbol(RunFunc);
    }

    /// Resolves the native pointer for the emitted `Run` function, if one was generated.
    void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
        if (RunFunc)
            Run = reinterpret_cast<TRunPtr>(codegen.GetPointerToFunction(RunFunc));
    }

    /// Builds (or reuses, if already present in the module) the LLVM function implementing the call.
    ///
    /// Non-Windows signature: i128 Run(ctx*, [N x i128]* args).
    /// Windows signature:     void Run(i128* sret result, ctx*, [N x i128]* args)
    /// (presumably to match how the native ABI returns a 16-byte TUnboxedValuePod - confirm).
    /// The body loads the argument array, stores each element into its external node, computes
    /// ResultNode and returns (or stores) that value.
    Function* GenerateRun(NYql::NCodegen::ICodegen& codegen) const {
        auto& module = codegen.GetModule();
        auto& context = codegen.GetContext();

        const auto& name = TBaseComputation::MakeName("Run");
        if (const auto f = module.getFunction(name.c_str()))
            return f;

        const auto valueType = Type::getInt128Ty(context);
        const auto argsType = ArrayType::get(valueType, ArgNodes.size());
        const auto contextType = GetCompContextType(context);

        const auto funcType = codegen.GetEffectiveTarget() != NYql::NCodegen::ETarget::Windows ?
            FunctionType::get(valueType, {PointerType::getUnqual(contextType), PointerType::getUnqual(argsType)}, false):
            FunctionType::get(Type::getVoidTy(context), {PointerType::getUnqual(valueType), PointerType::getUnqual(contextType), PointerType::getUnqual(argsType)}, false);

        TCodegenContext ctx(codegen);
        ctx.Func = cast<Function>(module.getOrInsertFunction(name.c_str(), funcType).getCallee());

        DISubprogramAnnotator annotator(ctx, ctx.Func);

        auto args = ctx.Func->arg_begin();

        // On Windows the first parameter is the hidden result pointer.
        const auto resultArg = codegen.GetEffectiveTarget() == NYql::NCodegen::ETarget::Windows ? &*args++ : nullptr;
        if (resultArg) {
            resultArg->addAttr(Attribute::StructRet);
            resultArg->addAttr(Attribute::NoAlias);
        }

        ctx.Ctx = &*args;
        const auto argsPtr = &*++args;

        const auto main = BasicBlock::Create(context, "main", ctx.Func);
        auto block = main;

        const auto arguments = new LoadInst(argsType, argsPtr, "arguments", block);

        unsigned i = 0U;
        for (const auto node : ArgNodes) {
            // Validate before emitting any IR for this argument.
            const auto codegenArgNode = dynamic_cast<ICodegeneratorExternalNode*>(node);
            MKQL_ENSURE(codegenArgNode, "Argument must be codegenerator node.");
            const auto arg = ExtractValueInst::Create(arguments, {i++}, "arg", block);
            codegenArgNode->CreateSetValue(ctx, block, arg);
        }

        const auto result = GetNodeValue(ResultNode, ctx, block);

        if (resultArg) {
            new StoreInst(result, resultArg, block);
            ReturnInst::Create(context, block);
        } else {
            ReturnInst::Create(context, result, block);
        }
        return ctx.Func;
    }

    using TRunPtr = TCodegenValue::TRunPtr;

    // LLVM function emitted by GenerateFunctions; null until generated.
    Function* RunFunc = nullptr;

    // Native entry point resolved in FinalizeFunctions; null means DoCalculate uses TValue.
    TRunPtr Run = nullptr;
#endif

    IComputationNode *const ResultNode;
    const TComputationExternalNodePtrVector ArgNodes;
};

}

/// Factory for the "Callable" computation node.
///
/// The callable must have at least one input. Inputs [0, count - 1) must each be an `Arg`
/// callable with no inputs and merge mode disabled; they are mapped to external nodes that
/// receive call arguments. The final input (index count - 1) is located as the result node.
IComputationNode* WrapCallable(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    const auto inputsCount = callable.GetInputsCount();
    MKQL_ENSURE(inputsCount > 0U, "Expected at least one argument");

    // Every input except the last one describes an argument.
    const auto resultIndex = inputsCount - 1U;
    const auto resultNode = LocateNode(ctx.NodeLocator, callable, resultIndex);

    TComputationExternalNodePtrVector externalArgs;
    externalArgs.reserve(resultIndex);
    for (ui32 index = 0U; index != resultIndex; ++index) {
        const auto argCallable = AS_CALLABLE("Arg", callable.GetInput(index));
        const auto argType = argCallable->GetType();
        MKQL_ENSURE(argType->GetName() == "Arg", "Wrong Callable arguments");
        MKQL_ENSURE(argCallable->GetInputsCount() == 0, "Wrong Callable arguments");
        MKQL_ENSURE(argType->IsMergeDisabled(), "Merge mode is not disabled");

        externalArgs.push_back(LocateExternalNode(ctx.NodeLocator, callable, index));
    }

    return new TCallableWrapper(ctx.Mutables, resultNode, std::move(externalArgs));
}

}
}