aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/comp_nodes/mkql_block_func.cpp
blob: d94f0ff400a4bc3b4b4f1eeec3be9e063312075c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include "mkql_block_func.h"

#include <yql/essentials/minikql/computation/mkql_block_impl.h>

#include <yql/essentials/minikql/arrow/arrow_defs.h>
#include <yql/essentials/minikql/mkql_node_builder.h>
#include <yql/essentials/minikql/mkql_node_cast.h>
#include <yql/essentials/minikql/mkql_type_builder.h>

#include <arrow/compute/cast.h>

namespace NKikimr {
namespace NMiniKQL {

namespace {

const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TString& funcName, const TVector<TType*>& inputTypes, TType* returnType) {
    std::vector<NUdf::TDataTypeId> argTypes;
    for (const auto& t : inputTypes) {
        auto asBlockType = AS_TYPE(TBlockType, t);
        bool isOptional;
        auto dataType = UnpackOptionalData(asBlockType->GetItemType(), isOptional);
        argTypes.push_back(dataType->GetSchemeType());
    }

    NUdf::TDataTypeId returnTypeId;
    {
        auto asBlockType = AS_TYPE(TBlockType, returnType);
        bool isOptional;
        auto dataType = UnpackOptionalData(asBlockType->GetItemType(), isOptional);
        returnTypeId = dataType->GetSchemeType();
    }

    auto kernel = builtins.FindKernel(funcName, argTypes.data(), argTypes.size(), returnTypeId);
    MKQL_ENSURE(kernel, "Can't find kernel for " << funcName);
    return *kernel;
}

} // namespace

// Builds a TBlockFuncNode for a BlockFunc callable.
//
// Input 0 of `callable` is a data literal holding the builtin function name;
// inputs 1..N-1 are the argument nodes. Their declared types (taken from the
// callable's type) together with the return type select the kernel through
// ResolveKernel. A polymorphic kernel gets a freshly made arrow kernel, and the
// node receives both a reference to it and the owning handle. Otherwise the
// kernel's own arrow kernel is used and an empty handle is passed.
IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    const auto inputsCount = callable.GetInputsCount();
    MKQL_ENSURE(inputsCount >= 1, "Expected at least 1 arg");

    const auto nameLiteral = AS_VALUE(TDataLiteral, callable.GetInput(0));
    const TString funcName(nameLiteral->AsValue().AsStringRef());

    const auto callableType = callable.GetType();
    TComputationNodePtrVector argNodes;
    TVector<TType*> argTypes;
    // Skip input 0 (the function name); everything after it is a real argument.
    for (ui32 index = 1; index < inputsCount; ++index) {
        argNodes.push_back(LocateNode(ctx.NodeLocator, callable, index));
        argTypes.push_back(callableType->GetArgumentType(index));
    }

    const auto returnType = callableType->GetReturnType();
    const TKernel& kernel = ResolveKernel(*ctx.FunctionRegistry.GetBuiltins(), funcName, argTypes, returnType);

    if (!kernel.IsPolymorphic()) {
        return new TBlockFuncNode(ctx.Mutables, funcName, std::move(argNodes), argTypes, kernel.GetArrowKernel(), {}, kernel.Family.FunctionOptions);
    }

    // The handle is copied, not moved: `*arrowKernel` is evaluated in the same
    // call, and argument evaluation order is unspecified.
    auto arrowKernel = kernel.MakeArrowKernel(argTypes, returnType);
    return new TBlockFuncNode(ctx.Mutables, funcName, std::move(argNodes), argTypes, *arrowKernel, arrowKernel, kernel.Family.FunctionOptions);
}

}
}