aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/comp_nodes/mkql_block_func.cpp
blob: 91008c7f38f962f62b1cf152428bc1b6257da92d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include "mkql_block_func.h"

#include <yql/essentials/minikql/computation/mkql_block_impl.h>

#include <yql/essentials/minikql/arrow/arrow_defs.h>
#include <yql/essentials/minikql/mkql_node_builder.h>
#include <yql/essentials/minikql/mkql_node_cast.h>
#include <yql/essentials/minikql/mkql_type_builder.h>

#include <arrow/compute/cast.h>

namespace NKikimr {
namespace NMiniKQL {

namespace {

// Looks up the builtin kernel registered under `funcName` whose signature matches
// the given argument types and return type.
//
// Every entry of `inputTypes`, and `returnType` itself, must be a TBlockType
// (enforced by AS_TYPE). Each one is reduced to the scheme type id of its item
// type after UnpackOptionalData, so whether an item is Optional does not take
// part in the lookup.
//
// Fails via MKQL_ENSURE when the registry has no matching kernel; the returned
// reference points at the registry-owned kernel.
const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TString& funcName, const TVector<TType*>& inputTypes, TType* returnType) {
    // Block type -> scheme type id of its (optional-unwrapped) item type.
    // Shared by the argument loop and the return type so both are resolved the same way.
    const auto toSchemeType = [](TType* type) -> NUdf::TDataTypeId {
        const auto blockType = AS_TYPE(TBlockType, type);
        bool isOptional = false; // required out-param; optionality is not used for the lookup
        const auto dataType = UnpackOptionalData(blockType->GetItemType(), isOptional);
        return dataType->GetSchemeType();
    };

    std::vector<NUdf::TDataTypeId> argTypes;
    argTypes.reserve(inputTypes.size());
    for (const auto& t : inputTypes) {
        argTypes.push_back(toSchemeType(t));
    }

    const NUdf::TDataTypeId returnTypeId = toSchemeType(returnType);

    const auto kernel = builtins.FindKernel(funcName, argTypes.data(), argTypes.size(), returnTypeId);
    MKQL_ENSURE(kernel, "Can't find kernel for " << funcName);
    return *kernel;
}

// Block computation node for "BitCast": runs Arrow's exact-match cast kernel for
// the (argument type -> target type) pair, configured with CastOptions(false).
class TBlockBitCastWrapper : public TBlockFuncNode {
public:
    // NOTE(review): base classes are initialized before members, so the base
    // receives `&CastOptions` while that member is not yet constructed. This is
    // only sound if TBlockFuncNode merely stores the pointer and reads it later —
    // confirm against TBlockFuncNode before changing the initialization order.
    TBlockBitCastWrapper(TComputationMutables& mutables, IComputationNode* arg, TType* argType, TType* to)
        : TBlockFuncNode(mutables, "BitCast", { arg }, { argType }, ResolveKernel(argType, to), {}, &CastOptions)
        , CastOptions(false)
    {
    }

private:
    // Finds the scalar Arrow cast kernel converting `from` into `to`.
    // Fails via MKQL_ENSURE / ARROW_RESULT if `to` has no Arrow equivalent, no cast
    // function exists for it, the function is not SCALAR, or no kernel accepts `from`.
    static const arrow::compute::ScalarKernel& ResolveKernel(TType* from, TType* to) {
        std::shared_ptr<arrow::DataType> targetArrowType;
        MKQL_ENSURE(ConvertArrowType(to, targetArrowType), "can't get arrow type");

        const auto castFunction = ARROW_RESULT(arrow::compute::GetCastFunction(targetArrowType));
        MKQL_ENSURE(castFunction != nullptr, "missing function");
        MKQL_ENSURE(castFunction->kind() == arrow::compute::Function::SCALAR, "expected SCALAR function");

        // The cast takes exactly one input: the source value description.
        const std::vector<arrow::ValueDescr> inputDescrs{ ToValueDescr(from) };
        const auto exactKernel = ARROW_RESULT(castFunction->DispatchExact(inputDescrs));
        // The SCALAR kind check above is what justifies this downcast.
        return *static_cast<const arrow::compute::ScalarKernel*>(exactKernel);
    }

    // Options handed to the base node by address; false = non-"safe" cast mode.
    const arrow::compute::CastOptions CastOptions;
};

} // namespace

// Builds a TBlockFuncNode for a block-level builtin call.
// Input 0 is a data literal holding the builtin's name; inputs 1..N are the
// arguments, whose types are taken from the callable's signature.
IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    const auto inputsCount = callable.GetInputsCount();
    MKQL_ENSURE(inputsCount >= 1, "Expected at least 1 arg");

    const auto funcName = TString(AS_VALUE(TDataLiteral, callable.GetInput(0))->AsValue().AsStringRef());

    const auto callableType = callable.GetType();
    TComputationNodePtrVector argsNodes;
    TVector<TType*> argsTypes;
    for (ui32 argIndex = 1; argIndex < inputsCount; ++argIndex) {
        argsNodes.push_back(LocateNode(ctx.NodeLocator, callable, argIndex));
        argsTypes.push_back(callableType->GetArgumentType(argIndex));
    }

    const auto returnType = callableType->GetReturnType();
    const TKernel& kernel = ResolveKernel(*ctx.FunctionRegistry.GetBuiltins(), funcName, argsTypes, returnType);

    if (!kernel.IsPolymorphic()) {
        // Monomorphic kernel: reuse the pre-built Arrow kernel, nothing extra to own.
        return new TBlockFuncNode(ctx.Mutables, funcName, std::move(argsNodes), argsTypes, kernel.GetArrowKernel(), {}, kernel.Family.FunctionOptions);
    }

    // Polymorphic kernel: instantiate it for these concrete types. The resulting
    // handle is passed both dereferenced and as-is — presumably so the node keeps
    // it alive (NOTE(review): confirm TBlockFuncNode's ownership of that argument).
    auto arrowKernel = kernel.MakeArrowKernel(argsTypes, returnType);
    return new TBlockFuncNode(ctx.Mutables, funcName, std::move(argsNodes), argsTypes, *arrowKernel, arrowKernel, kernel.Family.FunctionOptions);
}

// Builds a TBlockBitCastWrapper. Input 0 is the value to cast; input 1 must be a
// type node (checked via its static type) naming the cast target.
IComputationNode* WrapBlockBitCast(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
    const auto argNode = LocateNode(ctx.NodeLocator, callable, 0);

    auto targetTypeInput = callable.GetInput(1);
    MKQL_ENSURE(targetTypeInput.GetStaticType()->IsType(), "Expected type");

    const auto argType = callable.GetType()->GetArgumentType(0);
    // The node itself is the TType once the IsType() check above has passed.
    const auto targetType = static_cast<TType*>(targetTypeInput.GetNode());
    return new TBlockBitCastWrapper(ctx.Mutables, argNode, argType, targetType);
}

}
}