1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
|
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "contrib/libs/apache/arrow_next/cpp/src/arrow/compute/cast.h"
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "contrib/libs/apache/arrow_next/cpp/src/arrow/compute/cast_internal.h"
#include "contrib/libs/apache/arrow_next/cpp/src/arrow/compute/exec.h"
#include "contrib/libs/apache/arrow_next/cpp/src/arrow/compute/function_internal.h"
#include "contrib/libs/apache/arrow_next/cpp/src/arrow/compute/kernel.h"
#include "contrib/libs/apache/arrow_next/cpp/src/arrow/compute/kernels/codegen_internal.h"
#include "contrib/libs/apache/arrow_next/cpp/src/arrow/compute/registry.h"
#include "contrib/libs/apache/arrow_next/cpp/src/arrow/util/logging.h"
#include "contrib/libs/apache/arrow_next/cpp/src/arrow/util/reflection_internal.h"
namespace arrow20 {
using internal::ToTypeName;
namespace compute {
namespace internal {
// ----------------------------------------------------------------------
// Function options
namespace {
// Lookup table from an output type id (Type::type converted to int) to the
// CastFunction that produces that type. Written by AddCastFunctions().
std::unordered_map<int, std::shared_ptr<CastFunction>> g_cast_table;
// Flag handed to std::call_once so that the table is populated a single time
// (see EnsureInitCastTable).
std::once_flag cast_table_initialized;
void AddCastFunctions(const std::vector<std::shared_ptr<CastFunction>>& funcs) {
for (const auto& func : funcs) {
g_cast_table[static_cast<int>(func->out_type_id())] = func;
}
}
// Populates g_cast_table with each family of cast functions, one family per
// call. AddCastFunctions() overwrites an existing entry that has the same
// output type id, so if two families ever provide a function for the same
// output type the later call below wins; keep the order in mind when editing.
void InitCastTable() {
AddCastFunctions(GetBooleanCasts());
AddCastFunctions(GetBinaryLikeCasts());
AddCastFunctions(GetNestedCasts());
AddCastFunctions(GetNumericCasts());
AddCastFunctions(GetTemporalCasts());
AddCastFunctions(GetDictionaryCasts());
AddCastFunctions(GetExtensionCasts());
}
// Makes sure InitCastTable() has run; std::call_once on
// cast_table_initialized guarantees the table is filled only once.
void EnsureInitCastTable() {
  std::call_once(cast_table_initialized, &InitCastTable);
}
// Documentation object handed to the "cast" metafunction's constructor.
// NOTE(review): the four initializers presumably map to summary, description,
// argument names and options class name -- confirm against FunctionDoc.
const FunctionDoc cast_doc{"Cast values to another data type",
("Behavior when values wouldn't fit in the target type\n"
"can be controlled through CastOptions."),
{"input"},
"CastOptions"};
// Metafunction for dispatching to appropriate CastFunction. This corresponds
// to the standard SQL CAST(expr AS target_type)
class CastMetaFunction : public MetaFunction {
 public:
  CastMetaFunction() : MetaFunction("cast", Arity::Unary(), cast_doc) {}

  // Returns `options` viewed as CastOptions. Fails with Status::Invalid when
  // no options were passed or when the target type is not populated.
  Result<const CastOptions*> ValidateOptions(const FunctionOptions* options) const {
    auto cast_options = static_cast<const CastOptions*>(options);
    if (cast_options == nullptr || cast_options->to_type == nullptr) {
      return Status::Invalid(
          "Cast requires that options be passed with "
          "the to_type populated");
    }
    return cast_options;
  }

  // Casts args[0] to the type requested in `options`.
  //
  // - Input already of the target type and not nested: returned unchanged.
  // - Input already of the target type and nested: arrays and chunked arrays
  //   are returned as a view with the target type; other nested inputs fall
  //   through to the regular cast below.
  // - Otherwise: the CastFunction registered for the target type is executed.
  //   When no such function exists, its error status is returned, extended
  //   with the input type when the input has one.
  Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
                            const FunctionOptions* options,
                            ExecContext* ctx) const override {
    ARROW_ASSIGN_OR_RAISE(auto cast_options, ValidateOptions(options));
    // args[0].type() could be a nullptr so check for that before
    // we do anything with it.
    if (args[0].type() && args[0].type()->Equals(*cast_options->to_type)) {
      // Nested types might differ in field names but still be considered equal,
      // so we can only return non-nested types as-is.
      if (!is_nested(args[0].type()->id())) {
        return args[0];
      } else if (args[0].is_array()) {
        ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ArrayData> array,
                              ::arrow20::internal::GetArrayView(
                                  args[0].array(), cast_options->to_type.owned_type));
        return Datum(array);
      } else if (args[0].is_chunked_array()) {
        ARROW_ASSIGN_OR_RAISE(
            std::shared_ptr<ChunkedArray> array,
            args[0].chunked_array()->View(cast_options->to_type.owned_type));
        return Datum(array);
      }
    }

    Result<std::shared_ptr<CastFunction>> result =
        GetCastFunction(*cast_options->to_type);
    if (!result.ok()) {
      Status s = result.status();
      // The input type may be null (see the check above). Only append it to
      // the message when it exists; dereferencing a null type here would be
      // undefined behavior.
      if (args[0].type() == nullptr) {
        return s;
      }
      return s.WithMessage(s.message(), " from ", *args[0].type());
    }
    return (*result)->Execute(args, options, ctx);
  }
};
// Options-type descriptor for CastOptions, built from the name and the
// pointer-to-member of each listed data member. It is registered with the
// function registry in RegisterScalarCast() and passed to the FunctionOptions
// base class by the CastOptions constructor.
static auto kCastOptionsType = GetFunctionOptionsType<CastOptions>(
arrow20::internal::DataMember("to_type", &CastOptions::to_type),
arrow20::internal::DataMember("allow_int_overflow", &CastOptions::allow_int_overflow),
arrow20::internal::DataMember("allow_time_truncate", &CastOptions::allow_time_truncate),
arrow20::internal::DataMember("allow_time_overflow", &CastOptions::allow_time_overflow),
arrow20::internal::DataMember("allow_decimal_truncate",
&CastOptions::allow_decimal_truncate),
arrow20::internal::DataMember("allow_float_truncate",
&CastOptions::allow_float_truncate),
arrow20::internal::DataMember("allow_invalid_utf8", &CastOptions::allow_invalid_utf8));
} // namespace
// Adds the "cast" metafunction and the CastOptions options type to
// `registry`. Both registrations are checked with DCHECK_OK.
void RegisterScalarCast(FunctionRegistry* registry) {
  auto cast_meta_function = std::make_shared<CastMetaFunction>();
  DCHECK_OK(registry->AddFunction(std::move(cast_meta_function)));
  DCHECK_OK(registry->AddFunctionOptionsType(kCastOptionsType));
}
CastFunction::CastFunction(std::string name, Type::type out_type_id)
: ScalarFunction(std::move(name), Arity::Unary(), FunctionDoc::Empty()),
out_type_id_(out_type_id) {}
// Adds `kernel` to this cast function for inputs whose type id is
// `in_type_id`.
//
// The kernel's init function is always overwritten with CastState::Init.
// `in_type_id` is recorded only after ScalarFunction::AddKernel accepted the
// kernel; if that call fails, its status is returned and nothing is recorded.
Status CastFunction::AddKernel(Type::type in_type_id, ScalarKernel kernel) {
  // We use the same KernelInit for every cast
  kernel.init = internal::CastState::Init;
  // `kernel` is a by-value parameter that is not used again, so hand it over
  // instead of copying it a second time.
  RETURN_NOT_OK(ScalarFunction::AddKernel(std::move(kernel)));
  in_type_ids_.push_back(in_type_id);
  return Status::OK();
}
// Convenience overload: assembles a ScalarKernel from a signature
// (`in_types` -> `out_type`), an exec function and the null-handling /
// memory-allocation settings, then forwards to the ScalarKernel overload.
Status CastFunction::AddKernel(Type::type in_type_id, std::vector<InputType> in_types,
                               OutputType out_type, ArrayKernelExec exec,
                               NullHandling::type null_handling,
                               MemAllocation::type mem_allocation) {
  ScalarKernel cast_kernel;
  cast_kernel.mem_allocation = mem_allocation;
  cast_kernel.null_handling = null_handling;
  cast_kernel.exec = exec;
  cast_kernel.signature =
      KernelSignature::Make(std::move(in_types), std::move(out_type));
  return AddKernel(in_type_id, std::move(cast_kernel));
}
// Selects the kernel to use for the given input types.
//
// Fails if the number of types does not satisfy CheckArity, or with
// NotImplemented if no kernel signature matches. When several kernels match,
// the first one whose first input is declared as EXACT_TYPE is preferred;
// otherwise the first matching kernel is returned.
Result<const Kernel*> CastFunction::DispatchExact(
    const std::vector<TypeHolder>& types) const {
  RETURN_NOT_OK(CheckArity(types.size()));

  // Gather every kernel whose signature accepts the inputs.
  std::vector<const ScalarKernel*> matches;
  for (const auto& candidate : kernels_) {
    if (!candidate.signature->MatchesInputs(types)) {
      continue;
    }
    matches.push_back(&candidate);
  }

  if (matches.empty()) {
    return Status::NotImplemented("Unsupported cast from ", types[0].type->ToString(),
                                  " to ", ToTypeName(out_type_id_), " using function ",
                                  this->name());
  }

  // With more than one match we may have both an EXACT_TYPE and a
  // SAME_TYPE_ID kernel: favor the exact one.
  if (matches.size() > 1) {
    for (const ScalarKernel* match : matches) {
      if (match->signature->in_types()[0].kind() == InputType::EXACT_TYPE) {
        return match;
      }
    }
  }

  // Single match, or no exact match among several: take the first one.
  return matches.front();
}
// Looks up the cast function that produces `to_type` (by type id) in the cast
// table, initializing the table first if needed. Returns NotImplemented when
// no function is registered for that type id.
Result<std::shared_ptr<CastFunction>> GetCastFunction(const DataType& to_type) {
  internal::EnsureInitCastTable();
  const int table_key = static_cast<int>(to_type.id());
  const auto entry = internal::g_cast_table.find(table_key);
  if (entry != internal::g_cast_table.end()) {
    return entry->second;
  }
  return Status::NotImplemented("Unsupported cast to ", to_type);
}
} // namespace internal
CastOptions::CastOptions(bool safe)
: FunctionOptions(internal::kCastOptionsType),
allow_int_overflow(!safe),
allow_time_truncate(!safe),
allow_time_overflow(!safe),
allow_decimal_truncate(!safe),
allow_float_truncate(!safe),
allow_invalid_utf8(!safe) {}
// True when none of the allow_* flags is set.
bool CastOptions::is_safe() const {
  const bool any_relaxation = allow_int_overflow || allow_time_truncate ||
                              allow_time_overflow || allow_decimal_truncate ||
                              allow_float_truncate || allow_invalid_utf8;
  return !any_relaxation;
}
// True when every allow_* flag is set.
bool CastOptions::is_unsafe() const {
  const bool any_check_enabled = !allow_int_overflow || !allow_time_truncate ||
                                 !allow_time_overflow || !allow_decimal_truncate ||
                                 !allow_float_truncate || !allow_invalid_utf8;
  return !any_check_enabled;
}
// Namespace-scope definition of the static constexpr array member.
// NOTE(review): presumably declared with its initializer in cast.h; since
// C++17 such members are implicitly inline, so this definition is redundant
// there -- confirm before removing.
constexpr char CastOptions::kTypeName[];
// Casts `value` by invoking the registered "cast" function with `options`
// (which carry the target type) in the execution context `ctx`.
Result<Datum> Cast(const Datum& value, const CastOptions& options, ExecContext* ctx) {
  const std::vector<Datum> cast_args{value};
  return CallFunction("cast", cast_args, &options, ctx);
}
// Casts `value` to `to_type`. `options` is copied and the copy's to_type is
// overwritten with `to_type`, so the caller's options object is not modified.
Result<Datum> Cast(const Datum& value, const TypeHolder& to_type,
                   const CastOptions& options, ExecContext* ctx) {
  CastOptions typed_options(options);
  typed_options.to_type = to_type;
  return Cast(value, typed_options, ctx);
}
// Array flavor of Cast: wraps `value` in a Datum, casts it to `to_type` and
// converts the resulting Datum back into an Array. Errors from the underlying
// cast are propagated.
Result<std::shared_ptr<Array>> Cast(const Array& value, const TypeHolder& to_type,
                                    const CastOptions& options, ExecContext* ctx) {
  const Datum input(value);
  ARROW_ASSIGN_OR_RAISE(Datum casted, Cast(input, to_type, options, ctx));
  return casted.make_array();
}
bool CanCast(const DataType& from_type, const DataType& to_type) {
internal::EnsureInitCastTable();
auto it = internal::g_cast_table.find(static_cast<int>(to_type.id()));
if (it == internal::g_cast_table.end()) {
return false;
}
const internal::CastFunction* function = it->second.get();
DCHECK_EQ(function->out_type_id(), to_type.id());
for (auto from_id : function->in_type_ids()) {
// XXX should probably check the output type as well
if (from_type.id() == from_id) return true;
}
return false;
}
} // namespace compute
} // namespace arrow20
|