1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
|
#include "yql_simple_arrow_resolver.h"
#include <yql/essentials/minikql/arrow/mkql_functions.h>
#include <yql/essentials/minikql/mkql_program_builder.h>
#include <yql/essentials/minikql/mkql_type_builder.h>
#include <yql/essentials/minikql/mkql_function_registry.h>
#include <yql/essentials/providers/common/mkql/yql_type_mkql.h>
#include <util/stream/null.h>
namespace NYql {
using namespace NKikimr::NMiniKQL;
class TSimpleArrowResolver: public IArrowResolver {
public:
TSimpleArrowResolver(const IFunctionRegistry& functionRegistry)
: FunctionRegistry_(functionRegistry)
{}
private:
EStatus LoadFunctionMetadata(const TPosition& pos, TStringBuf name, const TVector<const TTypeAnnotationNode*>& argTypes,
const TTypeAnnotationNode* returnType, TExprContext& ctx) const override
{
try {
TScopedAlloc alloc(__LOCATION__);
TTypeEnvironment env(alloc);
TTypeBuilder typeBuilder(env);
TNullOutput null;
TVector<TType*> mkqlInputTypes;
for (const auto& type : argTypes) {
auto mkqlType = NCommon::BuildType(*type, typeBuilder, null);
YQL_ENSURE(mkqlType, "Failed to convert type " << *type << " to MKQL type");
mkqlInputTypes.emplace_back(mkqlType);
}
TType* mkqlOutputType = NCommon::BuildType(*returnType, typeBuilder, null);
bool found = FindArrowFunction(name, mkqlInputTypes, mkqlOutputType, *FunctionRegistry_.GetBuiltins());
return found ? EStatus::OK : EStatus::NOT_FOUND;
} catch (const std::exception& e) {
ctx.AddError(TIssue(pos, e.what()));
return EStatus::ERROR;
}
}
EStatus HasCast(const TPosition& pos, const TTypeAnnotationNode* from, const TTypeAnnotationNode* to, TExprContext& ctx) const override {
try {
TScopedAlloc alloc(__LOCATION__);
TTypeEnvironment env(alloc);
TTypeBuilder typeBuilder(env);
TNullOutput null;
auto mkqlFromType = NCommon::BuildType(*from, typeBuilder, null);
auto mkqlToType = NCommon::BuildType(*to, typeBuilder, null);
return HasArrowCast(mkqlFromType, mkqlToType) ? EStatus::OK : EStatus::NOT_FOUND;
} catch (const std::exception& e) {
ctx.AddError(TIssue(pos, e.what()));
return EStatus::ERROR;
}
}
EStatus AreTypesSupported(const TPosition& pos, const TVector<const TTypeAnnotationNode*>& types, TExprContext& ctx,
const TUnsupportedTypeCallback& onUnsupported = {}) const override
{
try {
TScopedAlloc alloc(__LOCATION__);
TTypeEnvironment env(alloc);
TTypeBuilder typeBuilder(env);
bool allOk = true;
TArrowConvertFailedCallback cb;
if (onUnsupported) {
cb = [&](TType* failed) {
if (failed->IsData()) {
auto slot = static_cast<TDataType*>(failed)->GetDataSlot();
YQL_ENSURE(slot);
onUnsupported(*slot);
} else {
onUnsupported(NYql::NCommon::ConvertMiniKQLTypeKind(failed));
}
};
}
for (const auto& type : types) {
YQL_ENSURE(type);
TNullOutput null;
auto mkqlType = NCommon::BuildType(*type, typeBuilder, null);
std::shared_ptr<arrow::DataType> arrowType;
if (!ConvertArrowType(mkqlType, arrowType, cb)) {
allOk = false;
if (!cb) {
break;
}
}
}
return allOk ? EStatus::OK : EStatus::NOT_FOUND;
} catch (const std::exception& e) {
ctx.AddError(TIssue(pos, e.what()));
return EStatus::ERROR;
}
}
private:
const IFunctionRegistry& FunctionRegistry_;
};
/// Creates an arrow resolver that consults only the builtin functions of
/// `functionRegistry`.
///
/// The resolver keeps a reference to the registry rather than a copy, so the
/// registry must stay alive for as long as the returned pointer is used.
IArrowResolver::TPtr MakeSimpleArrowResolver(const IFunctionRegistry& functionRegistry) {
    IArrowResolver::TPtr resolver(new TSimpleArrowResolver(functionRegistry));
    return resolver;
}
} // namespace NYql
|