aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/mkql_function_metadata.h
blob: 45c677a0998271fce4f4859ad36b76be252c414a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#pragma once

#include <yql/essentials/minikql/defs.h>
#include <yql/essentials/minikql/mkql_node.h>
#include <yql/essentials/public/udf/udf_value.h>
#include <util/digest/numeric.h>
#include <util/generic/vector.h>

#include <arrow/compute/kernel.h>

namespace NKikimr {

namespace NMiniKQL {

// Function pointer type stored in TFunctionDescriptor::Function.
// The callee receives a pointer to its arguments (TUnboxedValuePod values) and
// returns a TUnboxedValuePod by value; the argument count is not part of the signature.
using TFunctionPtr = NUdf::TUnboxedValuePod (*)(const NUdf::TUnboxedValuePod* args);

struct TFunctionParamMetadata {
    enum EFlags : ui16 {
        FlagIsNullable = 0x01,
    };

    TFunctionParamMetadata() = default;

    TFunctionParamMetadata(NUdf::TDataTypeId schemeType, ui32 flags)
        : SchemeType(schemeType)
        , Flags(flags)
    {}

    bool IsNullable() const {
        return Flags & FlagIsNullable;
    }

    NUdf::TDataTypeId SchemeType = 0;
    ui16 Flags = 0;
};

// Bundles the pieces that describe one function entry: its parameter metadata,
// the callable itself, and an optional opaque generator pointer.
// The struct only stores the pointers it is given.
struct TFunctionDescriptor {
    // All three pointers start out null.
    TFunctionDescriptor() = default;

    // `generator` is optional and defaults to null.
    TFunctionDescriptor(const TFunctionParamMetadata* resultAndArgs, TFunctionPtr function, void* generator = nullptr)
        : ResultAndArgs{resultAndArgs}
        , Function{function}
        , Generator{generator}
    {
    }

    const TFunctionParamMetadata* ResultAndArgs = nullptr; // ends with SchemeType zero
    TFunctionPtr Function = nullptr;
    // NOTE(review): untyped pointer; what it points to is not visible in this header.
    void* Generator = nullptr;
};

// Contiguous storage for TFunctionParamMetadata entries.
using TFunctionParamMetadataList = std::vector<TFunctionParamMetadata>;
using TArgType = std::pair<NUdf::TDataTypeId, bool>; // type with optional flag
// A sequence of function descriptors.
using TDescriptionList = std::vector<TFunctionDescriptor>;
// Maps a TString key to a list of descriptors.
// NOTE(review): presumably keyed by function name, one descriptor per overload - confirm.
using TFunctionsMap = std::unordered_map<TString, TDescriptionList>;

// Forward declaration: TKernelFamily's interface hands out TKernel pointers
// before TKernel itself is defined.
class TKernel;

class TKernelFamily {
public:
    const arrow::compute::FunctionOptions* FunctionOptions;

    TKernelFamily(const arrow::compute::FunctionOptions* functionOptions = nullptr)
        : FunctionOptions(functionOptions)
    {}

    virtual ~TKernelFamily() = default;
    virtual const TKernel* FindKernel(const NUdf::TDataTypeId* argTypes, size_t argTypesCount, NUdf::TDataTypeId returnType) const = 0;
    virtual TVector<const TKernel*> GetAllKernels() const = 0;
};

// Abstract description of a single kernel: the family it belongs to, its
// argument/return type ids and its null-handling mode, plus access to the
// underlying Arrow scalar kernel. All data members are const, so they are
// fixed at construction.
class TKernel {
public:
    // Null-handling mode of the kernel.
    // NOTE(review): the exact meaning of each value is defined by the code
    // that consumes NullMode, which is not visible in this header.
    enum class ENullMode {
        Default,
        AlwaysNull,
        AlwaysNotNull
    };

    // Held by reference, so the family object must outlive this kernel.
    const TKernelFamily& Family;
    // Copy of the argument type ids passed to the constructor.
    const std::vector<NUdf::TDataTypeId> ArgTypes;
    const NUdf::TDataTypeId ReturnType;
    const ENullMode NullMode;

    // Stores a reference to `family` and copies the remaining arguments.
    TKernel(const TKernelFamily& family, const std::vector<NUdf::TDataTypeId>& argTypes, NUdf::TDataTypeId returnType, ENullMode nullMode)
        : Family(family)
        , ArgTypes(argTypes)
        , ReturnType(returnType)
        , NullMode(nullMode)
    {
    }

    // Returns a reference to an Arrow scalar kernel.
    virtual const arrow::compute::ScalarKernel& GetArrowKernel() const = 0;
    // Builds an Arrow scalar kernel for the given argument and result types.
    // NOTE(review): presumably the path used when IsPolymorphic() is true - confirm.
    virtual std::shared_ptr<arrow::compute::ScalarKernel> MakeArrowKernel(const TVector<TType*>& argTypes, TType* resultType) const = 0;
    virtual bool IsPolymorphic() const = 0;

    virtual ~TKernel() = default;
};

using TKernelMapKey = std::pair<std::vector<NUdf::TDataTypeId>, NUdf::TDataTypeId>;
struct TTypeHasher {
    std::size_t operator()(const TKernelMapKey& s) const noexcept {
        size_t r = 0;
        for (const auto& x : s.first) {
            r = CombineHashes<size_t>(r, x);
        }
        r = CombineHashes<size_t>(r, s.second);

        return r;
    }
};

// Owns kernels, keyed by (argument type ids, return type id) and hashed with TTypeHasher.
using TKernelMap = std::unordered_map<TKernelMapKey, std::unique_ptr<TKernel>, TTypeHasher>;

// Owns kernel families, keyed by a TString.
// NOTE(review): presumably the key is the family name - confirm.
using TKernelFamilyMap = std::unordered_map<TString, std::unique_ptr<TKernelFamily>>;

// TKernelFamily implementation that holds a private TKernelMap.
// Member functions are only declared here; their definitions live outside this header.
class TKernelFamilyBase : public TKernelFamily
{
public:
    // `functionOptions` defaults to null.
    TKernelFamilyBase(const arrow::compute::FunctionOptions* functionOptions = nullptr);

    // Final overrides of the TKernelFamily lookup interface.
    const TKernel* FindKernel(const NUdf::TDataTypeId* argTypes, size_t argTypesCount, NUdf::TDataTypeId returnType) const final;
    TVector<const TKernel*> GetAllKernels() const final;

    // Receives ownership of `kernel` along with its argument and return type ids.
    // NOTE(review): presumably stores it in KernelMap under (argTypes, returnType) - confirm in the definition.
    void Adopt(const std::vector<NUdf::TDataTypeId>& argTypes, NUdf::TDataTypeId returnType, std::unique_ptr<TKernel>&& kernel);
private:
    TKernelMap KernelMap;
};

// Abstract registry interface for function descriptors and kernel families,
// both addressed by name. Derives from TThrRefBase and is handled through
// TIntrusivePtr (see TPtr); copying is disabled via TNonCopyable.
class IBuiltinFunctionRegistry: public TThrRefBase, private TNonCopyable
{
public:
    typedef TIntrusivePtr<IBuiltinFunctionRegistry> TPtr;

    // Returns a 64-bit etag for the registry metadata.
    // NOTE(review): how the value is derived is not visible here.
    virtual ui64 GetMetadataEtag() const = 0;

    // Writes information about the registry to `out`.
    virtual void PrintInfoTo(IOutputStream& out) const = 0;

    // Registers a single function descriptor under `name`.
    virtual void Register(const std::string_view& name, const TFunctionDescriptor& description) = 0;

    // Reports whether a builtin called `name` is present.
    virtual bool HasBuiltin(const std::string_view& name) const = 0;

    // Bulk registration; both containers are taken by rvalue reference.
    // NOTE(review): presumably descriptors in `functions` point into `arguments`,
    // which would be why the two are handed over together - confirm.
    virtual void RegisterAll(TFunctionsMap&& functions, TFunctionParamMetadataList&& arguments) = 0;

    // Read-only access to a functions map.
    virtual const TFunctionsMap& GetFunctions() const = 0;

    // Returns a descriptor for `name` given `argTypesCount` (type id, bool) pairs starting at `argTypes`.
    // NOTE(review): behaviour when no match exists is implementation-defined - check overriders.
    virtual TFunctionDescriptor GetBuiltin(const std::string_view& name, const std::pair<NUdf::TDataTypeId, bool>* argTypes, size_t argTypesCount) const = 0;

    // Looks up a kernel by name, argument type ids and return type id.
    // NOTE(review): presumably null when not found - confirm.
    virtual const TKernel* FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount, NUdf::TDataTypeId returnType) const = 0;

    // Receives ownership of `family` and associates it with `name`.
    virtual void RegisterKernelFamily(const std::string_view& name, std::unique_ptr<TKernelFamily>&& family) = 0;

    // Returns (name, family) pairs; the family pointers are raw and non-owning.
    virtual TVector<std::pair<TString, const TKernelFamily*>> GetAllKernelFamilies() const = 0;
};

}
}