aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp
blob: 5aacd3a6699453178ebd5415cf45d334c664134e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include "yql_simple_arrow_resolver.h"

#include <yql/essentials/minikql/arrow/mkql_functions.h>
#include <yql/essentials/minikql/mkql_program_builder.h>
#include <yql/essentials/minikql/mkql_type_builder.h>
#include <yql/essentials/minikql/mkql_function_registry.h>
#include <yql/essentials/providers/common/mkql/yql_type_mkql.h>

#include <util/stream/null.h>

namespace NYql {

using namespace NKikimr::NMiniKQL;

class TSimpleArrowResolver: public IArrowResolver {
public:
    TSimpleArrowResolver(const IFunctionRegistry& functionRegistry)
        : FunctionRegistry_(functionRegistry)
    {}

private:
    EStatus LoadFunctionMetadata(const TPosition& pos, TStringBuf name, const TVector<const TTypeAnnotationNode*>& argTypes,
        const TTypeAnnotationNode* returnType, TExprContext& ctx) const override
    {
        try {
            TScopedAlloc alloc(__LOCATION__);
            TTypeEnvironment env(alloc);
            TTypeBuilder typeBuilder(env);
            TNullOutput null;
            TVector<TType*> mkqlInputTypes;
            for (const auto& type : argTypes) {
                auto mkqlType = NCommon::BuildType(*type, typeBuilder, null);
                YQL_ENSURE(mkqlType, "Failed to convert type " << *type << " to MKQL type");
                mkqlInputTypes.emplace_back(mkqlType);
            }
            TType* mkqlOutputType = NCommon::BuildType(*returnType, typeBuilder, null);
            bool found = FindArrowFunction(name, mkqlInputTypes, mkqlOutputType, *FunctionRegistry_.GetBuiltins());
            return found ? EStatus::OK : EStatus::NOT_FOUND;
        } catch (const std::exception& e) {
            ctx.AddError(TIssue(pos, e.what()));
            return EStatus::ERROR;
        }
    }

    EStatus HasCast(const TPosition& pos, const TTypeAnnotationNode* from, const TTypeAnnotationNode* to, TExprContext& ctx) const override {
        try {
            TScopedAlloc alloc(__LOCATION__);
            TTypeEnvironment env(alloc);
            TTypeBuilder typeBuilder(env);
            TNullOutput null;
            auto mkqlFromType = NCommon::BuildType(*from, typeBuilder, null);
            auto mkqlToType = NCommon::BuildType(*to, typeBuilder, null);
            return HasArrowCast(mkqlFromType, mkqlToType) ? EStatus::OK : EStatus::NOT_FOUND;
        } catch (const std::exception& e) {
            ctx.AddError(TIssue(pos, e.what()));
            return EStatus::ERROR;
        }
    }

    EStatus AreTypesSupported(const TPosition& pos, const TVector<const TTypeAnnotationNode*>& types, TExprContext& ctx,
        const TUnsupportedTypeCallback& onUnsupported = {}) const override
    {
        try {
            TScopedAlloc alloc(__LOCATION__);
            TTypeEnvironment env(alloc);
            TTypeBuilder typeBuilder(env);

	    bool allOk = true;
            TArrowConvertFailedCallback cb;
            if (onUnsupported) {
                cb = [&](TType* failed) {
                    if (failed->IsData()) {
                        auto slot = static_cast<TDataType*>(failed)->GetDataSlot();
                        YQL_ENSURE(slot);
                        onUnsupported(*slot);
                    } else {
                        onUnsupported(NYql::NCommon::ConvertMiniKQLTypeKind(failed));
                    }
                };
            }

            for (const auto& type : types) {
                YQL_ENSURE(type);
                TNullOutput null;
                auto mkqlType = NCommon::BuildType(*type, typeBuilder, null);
                std::shared_ptr<arrow::DataType> arrowType;
                if (!ConvertArrowType(mkqlType, arrowType, cb)) {
                    allOk = false;
                    if (!cb) {
                        break;
                    }
                }
            }
            return allOk ? EStatus::OK : EStatus::NOT_FOUND;
        } catch (const std::exception& e) {
            ctx.AddError(TIssue(pos, e.what()));
            return EStatus::ERROR;
        }
    }

private:
    const IFunctionRegistry& FunctionRegistry_;
};

// Creates a resolver backed by the builtins of `functionRegistry`.
// The registry is held by reference inside the resolver, so it must outlive the result.
IArrowResolver::TPtr MakeSimpleArrowResolver(const IFunctionRegistry& functionRegistry) {
    IArrowResolver::TPtr resolver(new TSimpleArrowResolver(functionRegistry));
    return resolver;
}

} // namespace NYql