diff options
author | Artem Zuikov <chertus@gmail.com> | 2022-02-10 16:46:27 +0300 |
---|---|---|
committer | Daniil Cherednik <dcherednik@yandex-team.ru> | 2022-02-10 16:46:27 +0300 |
commit | d23c9e2333524245de2f721e04136f51c31411ef (patch) | |
tree | 841eba1726e3db80e41401053939b3f5ef86ab2d /contrib/libs/apache | |
parent | 6f0f263753da4db2175d8b66d02619f6a476d319 (diff) | |
download | ydb-d23c9e2333524245de2f721e04136f51c31411ef.tar.gz |
Restoring authorship annotation for Artem Zuikov <chertus@gmail.com>. Commit 1 of 2.
Diffstat (limited to 'contrib/libs/apache')
83 files changed, 18246 insertions, 18246 deletions
diff --git a/contrib/libs/apache/arrow/.yandex_meta/devtools.licenses.report b/contrib/libs/apache/arrow/.yandex_meta/devtools.licenses.report index 1f1bbe7849..f3f9e90568 100644 --- a/contrib/libs/apache/arrow/.yandex_meta/devtools.licenses.report +++ b/contrib/libs/apache/arrow/.yandex_meta/devtools.licenses.report @@ -1308,25 +1308,25 @@ BELONGS ya.make cpp/src/arrow/compute/util_internal.h [1:16] cpp/src/arrow/config.cc [1:16] cpp/src/arrow/config.h [1:16] - cpp/src/arrow/csv/api.h [1:16] - cpp/src/arrow/csv/chunker.cc [1:16] - cpp/src/arrow/csv/chunker.h [1:16] - cpp/src/arrow/csv/column_builder.cc [1:16] - cpp/src/arrow/csv/column_builder.h [1:16] - cpp/src/arrow/csv/column_decoder.cc [1:16] - cpp/src/arrow/csv/column_decoder.h [1:16] - cpp/src/arrow/csv/converter.cc [1:16] - cpp/src/arrow/csv/converter.h [1:16] - cpp/src/arrow/csv/inference_internal.h [1:16] - cpp/src/arrow/csv/options.cc [1:16] - cpp/src/arrow/csv/options.h [1:16] - cpp/src/arrow/csv/parser.cc [1:16] - cpp/src/arrow/csv/parser.h [1:16] - cpp/src/arrow/csv/reader.cc [1:16] - cpp/src/arrow/csv/reader.h [1:16] - cpp/src/arrow/csv/type_fwd.h [1:16] - cpp/src/arrow/csv/writer.cc [1:16] - cpp/src/arrow/csv/writer.h [1:16] + cpp/src/arrow/csv/api.h [1:16] + cpp/src/arrow/csv/chunker.cc [1:16] + cpp/src/arrow/csv/chunker.h [1:16] + cpp/src/arrow/csv/column_builder.cc [1:16] + cpp/src/arrow/csv/column_builder.h [1:16] + cpp/src/arrow/csv/column_decoder.cc [1:16] + cpp/src/arrow/csv/column_decoder.h [1:16] + cpp/src/arrow/csv/converter.cc [1:16] + cpp/src/arrow/csv/converter.h [1:16] + cpp/src/arrow/csv/inference_internal.h [1:16] + cpp/src/arrow/csv/options.cc [1:16] + cpp/src/arrow/csv/options.h [1:16] + cpp/src/arrow/csv/parser.cc [1:16] + cpp/src/arrow/csv/parser.h [1:16] + cpp/src/arrow/csv/reader.cc [1:16] + cpp/src/arrow/csv/reader.h [1:16] + cpp/src/arrow/csv/type_fwd.h [1:16] + cpp/src/arrow/csv/writer.cc [1:16] + cpp/src/arrow/csv/writer.h [1:16] cpp/src/arrow/datum.cc [1:16] cpp/src/arrow/datum.h [1:16] cpp/src/arrow/device.cc [1:16] diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/api.h b/contrib/libs/apache/arrow/cpp/src/arrow/api.h index 8958eaf1c9..1ac5b20893 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/api.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/api.h @@ -1,44 +1,44 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Coarse public API while the library is in development - -#pragma once - -#include "arrow/array.h" // IYWU pragma: export -#include "arrow/array/concatenate.h" // IYWU pragma: export -#include "arrow/buffer.h" // IYWU pragma: export -#include "arrow/builder.h" // IYWU pragma: export -#include "arrow/chunked_array.h" // IYWU pragma: export -#include "arrow/compare.h" // IYWU pragma: export -#include "arrow/config.h" // IYWU pragma: export -#include "arrow/datum.h" // IYWU pragma: export -#include "arrow/extension_type.h" // IYWU pragma: export -#include "arrow/memory_pool.h" // IYWU pragma: export -#include "arrow/pretty_print.h" // IYWU pragma: export -#include "arrow/record_batch.h" // IYWU pragma: export -#include "arrow/result.h" // IYWU pragma: export -#include "arrow/status.h" // IYWU pragma: export -#include "arrow/table.h" // IYWU pragma: export -#include "arrow/table_builder.h" // IYWU pragma: export -#include "arrow/tensor.h" // IYWU pragma: export -#include "arrow/type.h" // IYWU pragma: export -#include "arrow/util/key_value_metadata.h" // IWYU pragma: export -#include "arrow/visitor.h" // IYWU pragma: export - -/// \brief Top-level namespace for Apache Arrow C++ API -namespace arrow {} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Coarse public API while the library is in development + +#pragma once + +#include "arrow/array.h" // IYWU pragma: export +#include "arrow/array/concatenate.h" // IYWU pragma: export +#include "arrow/buffer.h" // IYWU pragma: export +#include "arrow/builder.h" // IYWU pragma: export +#include "arrow/chunked_array.h" // IYWU pragma: export +#include "arrow/compare.h" // IYWU pragma: export +#include "arrow/config.h" // IYWU pragma: export +#include "arrow/datum.h" // IYWU pragma: export +#include "arrow/extension_type.h" // IYWU pragma: export +#include "arrow/memory_pool.h" // IYWU pragma: export +#include "arrow/pretty_print.h" // IYWU pragma: export +#include "arrow/record_batch.h" // IYWU pragma: export +#include "arrow/result.h" // IYWU pragma: export +#include "arrow/status.h" // IYWU pragma: export +#include "arrow/table.h" // IYWU pragma: export +#include "arrow/table_builder.h" // IYWU pragma: export +#include "arrow/tensor.h" // IYWU pragma: export +#include "arrow/type.h" // IYWU pragma: export +#include "arrow/util/key_value_metadata.h" // IWYU pragma: export +#include "arrow/visitor.h" // IYWU pragma: export + +/// \brief Top-level namespace for Apache Arrow C++ API +namespace arrow {} diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/array/builder_binary.h b/contrib/libs/apache/arrow/cpp/src/arrow/array/builder_binary.h index 62edc69fb8..c895240e23 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/array/builder_binary.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/array/builder_binary.h @@ -53,7 +53,7 @@ class BaseBinaryBuilder : public ArrayBuilder { explicit BaseBinaryBuilder(MemoryPool* pool = default_memory_pool()) : ArrayBuilder(pool), offsets_builder_(pool), value_data_builder_(pool) {} - BaseBinaryBuilder(const std::shared_ptr<DataType>& /*type*/, MemoryPool* pool) + BaseBinaryBuilder(const std::shared_ptr<DataType>& /*type*/, MemoryPool* pool) : BaseBinaryBuilder(pool) {} Status Append(const uint8_t* value, offset_type length) { diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/array/builder_dict.h b/contrib/libs/apache/arrow/cpp/src/arrow/array/builder_dict.h index eb96482dbf..bb43658868 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/array/builder_dict.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/array/builder_dict.h @@ -421,11 +421,11 @@ class DictionaryBuilderBase<BuilderType, NullType> : public ArrayBuilder { DictionaryBuilderBase( enable_if_t<std::is_base_of<AdaptiveIntBuilderBase, B>::value, uint8_t> start_int_size, - const std::shared_ptr<DataType>& /*value_type*/, + const std::shared_ptr<DataType>& /*value_type*/, MemoryPool* pool = default_memory_pool()) : ArrayBuilder(pool), indices_builder_(start_int_size, pool) {} - explicit DictionaryBuilderBase(const std::shared_ptr<DataType>& /*value_type*/, + explicit DictionaryBuilderBase(const std::shared_ptr<DataType>& /*value_type*/, MemoryPool* pool = default_memory_pool()) : ArrayBuilder(pool), indices_builder_(pool) {} @@ -439,7 +439,7 @@ class DictionaryBuilderBase<BuilderType, NullType> : public ArrayBuilder { explicit DictionaryBuilderBase(MemoryPool* pool = default_memory_pool()) : ArrayBuilder(pool), indices_builder_(pool) {} - explicit DictionaryBuilderBase(const std::shared_ptr<Array>& /*dictionary*/, + explicit DictionaryBuilderBase(const std::shared_ptr<Array>& /*dictionary*/, MemoryPool* pool = default_memory_pool()) : ArrayBuilder(pool), indices_builder_(pool) {} diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/array/builder_primitive.h b/contrib/libs/apache/arrow/cpp/src/arrow/array/builder_primitive.h index 80cfc4061b..9bd7a52c34 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/array/builder_primitive.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/array/builder_primitive.h @@ -32,7 +32,7 @@ namespace arrow { class ARROW_EXPORT NullBuilder : public ArrayBuilder { public: explicit NullBuilder(MemoryPool* pool = default_memory_pool()) : ArrayBuilder(pool) {} - explicit NullBuilder(const std::shared_ptr<DataType>& /*type*/, + explicit NullBuilder(const std::shared_ptr<DataType>& /*type*/, MemoryPool* pool = default_memory_pool()) : NullBuilder(pool) {} diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h index a890cd362f..13f1ea762a 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h @@ -1,35 +1,35 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// NOTE: API is EXPERIMENTAL and will change without going through a -// deprecation cycle - -#pragma once - -/// \defgroup compute-concrete-options Concrete option classes for compute functions -/// @{ -/// @} - -#include "arrow/compute/api_aggregate.h" // IWYU pragma: export -#include "arrow/compute/api_scalar.h" // IWYU pragma: export -#include "arrow/compute/api_vector.h" // IWYU pragma: export -#include "arrow/compute/cast.h" // IWYU pragma: export -#include "arrow/compute/exec.h" // IWYU pragma: export -#include "arrow/compute/function.h" // IWYU pragma: export -#include "arrow/compute/kernel.h" // IWYU pragma: export -#include "arrow/compute/registry.h" // IWYU pragma: export -#include "arrow/datum.h" // IWYU pragma: export +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + +#pragma once + +/// \defgroup compute-concrete-options Concrete option classes for compute functions +/// @{ +/// @} + +#include "arrow/compute/api_aggregate.h" // IWYU pragma: export +#include "arrow/compute/api_scalar.h" // IWYU pragma: export +#include "arrow/compute/api_vector.h" // IWYU pragma: export +#include "arrow/compute/cast.h" // IWYU pragma: export +#include "arrow/compute/exec.h" // IWYU pragma: export +#include "arrow/compute/function.h" // IWYU pragma: export +#include "arrow/compute/kernel.h" // IWYU pragma: export +#include "arrow/compute/registry.h" // IWYU pragma: export +#include "arrow/datum.h" // IWYU pragma: export diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_aggregate.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_aggregate.cc index 1b00c366bf..6e9d9de0c5 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_aggregate.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_aggregate.cc @@ -1,30 +1,30 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/compute/api_aggregate.h" - -#include "arrow/compute/exec.h" +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/api_aggregate.h" + +#include "arrow/compute/exec.h" #include "arrow/compute/function_internal.h" #include "arrow/compute/registry.h" #include "arrow/compute/util_internal.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" - -namespace arrow { + +namespace arrow { namespace internal { template <> @@ -52,9 +52,9 @@ struct EnumTraits<compute::QuantileOptions::Interpolation> }; } // namespace internal -namespace compute { - -// ---------------------------------------------------------------------- +namespace compute { + +// ---------------------------------------------------------------------- // Function options using ::arrow::internal::checked_cast; @@ -133,33 +133,33 @@ void RegisterAggregateOptions(FunctionRegistry* registry) { } // namespace internal // ---------------------------------------------------------------------- -// Scalar aggregates - +// Scalar aggregates + Result<Datum> Count(const Datum& value, const ScalarAggregateOptions& options, ExecContext* ctx) { - return CallFunction("count", {value}, &options, ctx); -} - + return CallFunction("count", {value}, &options, ctx); +} + Result<Datum> Mean(const Datum& value, const ScalarAggregateOptions& options, ExecContext* ctx) { return CallFunction("mean", {value}, &options, ctx); -} - +} + Result<Datum> Sum(const Datum& value, const ScalarAggregateOptions& options, ExecContext* ctx) { return CallFunction("sum", {value}, &options, ctx); -} - +} + Result<Datum> MinMax(const Datum& value, const ScalarAggregateOptions& options, ExecContext* ctx) { - return CallFunction("min_max", {value}, &options, ctx); -} - + return CallFunction("min_max", {value}, &options, ctx); +} + Result<Datum> Any(const Datum& value, const ScalarAggregateOptions& options, ExecContext* ctx) { return CallFunction("any", {value}, &options, ctx); -} - +} + Result<Datum> All(const Datum& value, const ScalarAggregateOptions& options, ExecContext* ctx) { return CallFunction("all", {value}, &options, ctx); @@ -169,16 +169,16 @@ Result<Datum> Mode(const Datum& value, const ModeOptions& options, ExecContext* return CallFunction("mode", {value}, &options, ctx); } -Result<Datum> Stddev(const Datum& value, const VarianceOptions& options, - ExecContext* ctx) { - return CallFunction("stddev", {value}, &options, ctx); -} - -Result<Datum> Variance(const Datum& value, const VarianceOptions& options, - ExecContext* ctx) { - return CallFunction("variance", {value}, &options, ctx); -} - +Result<Datum> Stddev(const Datum& value, const VarianceOptions& options, + ExecContext* ctx) { + return CallFunction("stddev", {value}, &options, ctx); +} + +Result<Datum> Variance(const Datum& value, const VarianceOptions& options, + ExecContext* ctx) { + return CallFunction("variance", {value}, &options, ctx); +} + Result<Datum> Quantile(const Datum& value, const QuantileOptions& options, ExecContext* ctx) { return CallFunction("quantile", {value}, &options, ctx); @@ -193,5 +193,5 @@ Result<Datum> Index(const Datum& value, const IndexOptions& options, ExecContext return CallFunction("index", {value}, &options, ctx); } -} // namespace compute -} // namespace arrow +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_aggregate.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_aggregate.h index 7a6c44bd92..99ea33f7bf 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_aggregate.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_aggregate.h @@ -1,58 +1,58 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Eager evaluation convenience APIs for invoking common functions, including -// necessary memory allocations - -#pragma once - -#include "arrow/compute/function.h" -#include "arrow/datum.h" -#include "arrow/result.h" -#include "arrow/util/macros.h" -#include "arrow/util/visibility.h" - -namespace arrow { - -class Array; - -namespace compute { - -class ExecContext; - -// ---------------------------------------------------------------------- -// Aggregate functions - -/// \addtogroup compute-concrete-options -/// @{ - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Eager evaluation convenience APIs for invoking common functions, including +// necessary memory allocations + +#pragma once + +#include "arrow/compute/function.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class Array; + +namespace compute { + +class ExecContext; + +// ---------------------------------------------------------------------- +// Aggregate functions + +/// \addtogroup compute-concrete-options +/// @{ + /// \brief Control general scalar aggregate kernel behavior -/// +/// /// By default, null values are ignored class ARROW_EXPORT ScalarAggregateOptions : public FunctionOptions { public: explicit ScalarAggregateOptions(bool skip_nulls = true, uint32_t min_count = 1); constexpr static char const kTypeName[] = "ScalarAggregateOptions"; static ScalarAggregateOptions Defaults() { return ScalarAggregateOptions{}; } - + bool skip_nulls; uint32_t min_count; }; - + /// \brief Control Mode kernel behavior /// /// Returns top-n common values and counts. @@ -62,12 +62,12 @@ class ARROW_EXPORT ModeOptions : public FunctionOptions { explicit ModeOptions(int64_t n = 1); constexpr static char const kTypeName[] = "ModeOptions"; static ModeOptions Defaults() { return ModeOptions{}; } - + int64_t n = 1; -}; - +}; + /// \brief Control Delta Degrees of Freedom (ddof) of Variance and Stddev kernel -/// +/// /// The divisor used in calculations is N - ddof, where N is the number of elements. /// By default, ddof is zero, and population variance or stddev is returned. class ARROW_EXPORT VarianceOptions : public FunctionOptions { @@ -91,23 +91,23 @@ class ARROW_EXPORT QuantileOptions : public FunctionOptions { HIGHER, NEAREST, MIDPOINT, - }; - + }; + explicit QuantileOptions(double q = 0.5, enum Interpolation interpolation = LINEAR); - + explicit QuantileOptions(std::vector<double> q, enum Interpolation interpolation = LINEAR); - + constexpr static char const kTypeName[] = "QuantileOptions"; static QuantileOptions Defaults() { return QuantileOptions{}; } /// quantile must be between 0 and 1 inclusive std::vector<double> q; enum Interpolation interpolation; -}; - +}; + /// \brief Control TDigest approximate quantile kernel behavior -/// +/// /// By default, returns the median value. class ARROW_EXPORT TDigestOptions : public FunctionOptions { public: @@ -117,7 +117,7 @@ class ARROW_EXPORT TDigestOptions : public FunctionOptions { uint32_t buffer_size = 500); constexpr static char const kTypeName[] = "TDigestOptions"; static TDigestOptions Defaults() { return TDigestOptions{}; } - + /// quantile must be between 0 and 1 inclusive std::vector<double> q; /// compression parameter, default 100 @@ -125,7 +125,7 @@ class ARROW_EXPORT TDigestOptions : public FunctionOptions { /// input buffer size, default 500 uint32_t buffer_size; }; - + /// \brief Control Index kernel behavior class ARROW_EXPORT IndexOptions : public FunctionOptions { public: @@ -135,73 +135,73 @@ class ARROW_EXPORT IndexOptions : public FunctionOptions { constexpr static char const kTypeName[] = "IndexOptions"; std::shared_ptr<Scalar> value; -}; - -/// @} - -/// \brief Count non-null (or null) values in an array. -/// +}; + +/// @} + +/// \brief Count non-null (or null) values in an array. +/// /// \param[in] options counting options, see ScalarAggregateOptions for more information -/// \param[in] datum to count -/// \param[in] ctx the function execution context, optional -/// \return out resulting datum -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT +/// \param[in] datum to count +/// \param[in] ctx the function execution context, optional +/// \return out resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result<Datum> Count( const Datum& datum, const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(), ExecContext* ctx = NULLPTR); - -/// \brief Compute the mean of a numeric array. -/// -/// \param[in] value datum to compute the mean, expecting Array + +/// \brief Compute the mean of a numeric array. +/// +/// \param[in] value datum to compute the mean, expecting Array /// \param[in] options see ScalarAggregateOptions for more information -/// \param[in] ctx the function execution context, optional -/// \return datum of the computed mean as a DoubleScalar -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed mean as a DoubleScalar +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result<Datum> Mean( const Datum& value, const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(), ExecContext* ctx = NULLPTR); - -/// \brief Sum values of a numeric array. -/// -/// \param[in] value datum to sum, expecting Array or ChunkedArray + +/// \brief Sum values of a numeric array. +/// +/// \param[in] value datum to sum, expecting Array or ChunkedArray /// \param[in] options see ScalarAggregateOptions for more information -/// \param[in] ctx the function execution context, optional -/// \return datum of the computed sum as a Scalar -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed sum as a Scalar +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result<Datum> Sum( const Datum& value, const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(), ExecContext* ctx = NULLPTR); - -/// \brief Calculate the min / max of a numeric array -/// -/// This function returns both the min and max as a struct scalar, with type -/// struct<min: T, max: T>, where T is the input type -/// -/// \param[in] value input datum, expecting Array or ChunkedArray + +/// \brief Calculate the min / max of a numeric array +/// +/// This function returns both the min and max as a struct scalar, with type +/// struct<min: T, max: T>, where T is the input type +/// +/// \param[in] value input datum, expecting Array or ChunkedArray /// \param[in] options see ScalarAggregateOptions for more information -/// \param[in] ctx the function execution context, optional -/// \return resulting datum as a struct<min: T, max: T> scalar -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT +/// \param[in] ctx the function execution context, optional +/// \return resulting datum as a struct<min: T, max: T> scalar +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result<Datum> MinMax( const Datum& value, const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(), ExecContext* ctx = NULLPTR); - + /// \brief Test whether any element in a boolean array evaluates to true. /// /// This function returns true if any of the elements in the array evaluates @@ -244,53 +244,53 @@ Result<Datum> All( const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(), ExecContext* ctx = NULLPTR); -/// \brief Calculate the modal (most common) value of a numeric array -/// +/// \brief Calculate the modal (most common) value of a numeric array +/// /// This function returns top-n most common values and number of times they occur as /// an array of `struct<mode: T, count: int64>`, where T is the input type. /// Values with larger counts are returned before smaller ones. /// If there are more than one values with same count, smaller value is returned first. -/// -/// \param[in] value input datum, expecting Array or ChunkedArray +/// +/// \param[in] value input datum, expecting Array or ChunkedArray /// \param[in] options see ModeOptions for more information -/// \param[in] ctx the function execution context, optional +/// \param[in] ctx the function execution context, optional /// \return resulting datum as an array of struct<mode: T, count: int64> -/// -/// \since 2.0.0 -/// \note API not yet finalized -ARROW_EXPORT +/// +/// \since 2.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result<Datum> Mode(const Datum& value, const ModeOptions& options = ModeOptions::Defaults(), ExecContext* ctx = NULLPTR); - -/// \brief Calculate the standard deviation of a numeric array -/// -/// \param[in] value input datum, expecting Array or ChunkedArray -/// \param[in] options see VarianceOptions for more information -/// \param[in] ctx the function execution context, optional -/// \return datum of the computed standard deviation as a DoubleScalar -/// -/// \since 2.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result<Datum> Stddev(const Datum& value, - const VarianceOptions& options = VarianceOptions::Defaults(), - ExecContext* ctx = NULLPTR); - -/// \brief Calculate the variance of a numeric array -/// -/// \param[in] value input datum, expecting Array or ChunkedArray -/// \param[in] options see VarianceOptions for more information -/// \param[in] ctx the function execution context, optional -/// \return datum of the computed variance as a DoubleScalar -/// -/// \since 2.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result<Datum> Variance(const Datum& value, - const VarianceOptions& options = VarianceOptions::Defaults(), - ExecContext* ctx = NULLPTR); - + +/// \brief Calculate the standard deviation of a numeric array +/// +/// \param[in] value input datum, expecting Array or ChunkedArray +/// \param[in] options see VarianceOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed standard deviation as a DoubleScalar +/// +/// \since 2.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> Stddev(const Datum& value, + const VarianceOptions& options = VarianceOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Calculate the variance of a numeric array +/// +/// \param[in] value input datum, expecting Array or ChunkedArray +/// \param[in] options see VarianceOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed variance as a DoubleScalar +/// +/// \since 2.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> Variance(const Datum& value, + const VarianceOptions& options = VarianceOptions::Defaults(), + ExecContext* ctx = NULLPTR); + /// \brief Calculate the quantiles of a numeric array /// /// \param[in] value input datum, expecting Array or ChunkedArray @@ -429,5 +429,5 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat ExecContext* ctx = default_exec_context()); } // namespace internal -} // namespace compute -} // namespace arrow +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_scalar.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_scalar.cc index 1feb4e7eee..1d374cb915 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_scalar.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_scalar.cc @@ -1,37 +1,37 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/compute/api_scalar.h" - -#include <memory> -#include <sstream> -#include <string> - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/api_scalar.h" + +#include <memory> +#include <sstream> +#include <string> + #include "arrow/array/array_base.h" -#include "arrow/compute/exec.h" +#include "arrow/compute/exec.h" #include "arrow/compute/function_internal.h" #include "arrow/compute/registry.h" #include "arrow/compute/util_internal.h" -#include "arrow/status.h" -#include "arrow/type.h" +#include "arrow/status.h" +#include "arrow/type.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" - -namespace arrow { + +namespace arrow { namespace internal { template <> @@ -100,8 +100,8 @@ struct EnumTraits<compute::CompareOperator> }; } // namespace internal -namespace compute { - +namespace compute { + // ---------------------------------------------------------------------- // Function options @@ -302,19 +302,19 @@ void RegisterScalarOptions(FunctionRegistry* registry) { } } // namespace internal -#define SCALAR_EAGER_UNARY(NAME, REGISTRY_NAME) \ - Result<Datum> NAME(const Datum& value, ExecContext* ctx) { \ - return CallFunction(REGISTRY_NAME, {value}, ctx); \ - } - -#define SCALAR_EAGER_BINARY(NAME, REGISTRY_NAME) \ - Result<Datum> NAME(const Datum& left, const Datum& right, ExecContext* ctx) { \ - return CallFunction(REGISTRY_NAME, {left, right}, ctx); \ - } - -// ---------------------------------------------------------------------- -// Arithmetic - +#define SCALAR_EAGER_UNARY(NAME, REGISTRY_NAME) \ + Result<Datum> NAME(const Datum& value, ExecContext* ctx) { \ + return CallFunction(REGISTRY_NAME, {value}, ctx); \ + } + +#define SCALAR_EAGER_BINARY(NAME, REGISTRY_NAME) \ + Result<Datum> NAME(const Datum& left, const Datum& right, ExecContext* ctx) { \ + return CallFunction(REGISTRY_NAME, {left, right}, ctx); \ + } + +// ---------------------------------------------------------------------- +// Arithmetic + #define SCALAR_ARITHMETIC_UNARY(NAME, REGISTRY_NAME, REGISTRY_CHECKED_NAME) \ Result<Datum> NAME(const Datum& arg, ArithmeticOptions options, ExecContext* ctx) { \ auto func_name = (options.check_overflow) ? REGISTRY_CHECKED_NAME : REGISTRY_NAME; \ @@ -335,17 +335,17 @@ SCALAR_ARITHMETIC_UNARY(Log10, "log10", "log10_checked") SCALAR_ARITHMETIC_UNARY(Log2, "log2", "log2_checked") SCALAR_ARITHMETIC_UNARY(Log1p, "log1p", "log1p_checked") -#define SCALAR_ARITHMETIC_BINARY(NAME, REGISTRY_NAME, REGISTRY_CHECKED_NAME) \ - Result<Datum> NAME(const Datum& left, const Datum& right, ArithmeticOptions options, \ - ExecContext* ctx) { \ - auto func_name = (options.check_overflow) ? REGISTRY_CHECKED_NAME : REGISTRY_NAME; \ - return CallFunction(func_name, {left, right}, ctx); \ - } - -SCALAR_ARITHMETIC_BINARY(Add, "add", "add_checked") -SCALAR_ARITHMETIC_BINARY(Subtract, "subtract", "subtract_checked") -SCALAR_ARITHMETIC_BINARY(Multiply, "multiply", "multiply_checked") -SCALAR_ARITHMETIC_BINARY(Divide, "divide", "divide_checked") +#define SCALAR_ARITHMETIC_BINARY(NAME, REGISTRY_NAME, REGISTRY_CHECKED_NAME) \ + Result<Datum> NAME(const Datum& left, const Datum& right, ArithmeticOptions options, \ + ExecContext* ctx) { \ + auto func_name = (options.check_overflow) ? REGISTRY_CHECKED_NAME : REGISTRY_NAME; \ + return CallFunction(func_name, {left, right}, ctx); \ + } + +SCALAR_ARITHMETIC_BINARY(Add, "add", "add_checked") +SCALAR_ARITHMETIC_BINARY(Subtract, "subtract", "subtract_checked") +SCALAR_ARITHMETIC_BINARY(Multiply, "multiply", "multiply_checked") +SCALAR_ARITHMETIC_BINARY(Divide, "divide", "divide_checked") SCALAR_ARITHMETIC_BINARY(Power, "power", "power_checked") SCALAR_ARITHMETIC_BINARY(ShiftLeft, "shift_left", "shift_left_checked") SCALAR_ARITHMETIC_BINARY(ShiftRight, "shift_right", "shift_right_checked") @@ -353,7 +353,7 @@ SCALAR_EAGER_BINARY(Atan2, "atan2") SCALAR_EAGER_UNARY(Floor, "floor") SCALAR_EAGER_UNARY(Ceil, "ceil") SCALAR_EAGER_UNARY(Trunc, "trunc") - + Result<Datum> MaxElementWise(const std::vector<Datum>& args, ElementWiseAggregateOptions options, ExecContext* ctx) { return CallFunction("max_element_wise", args, &options, ctx); @@ -364,14 +364,14 @@ Result<Datum> MinElementWise(const std::vector<Datum>& args, return CallFunction("min_element_wise", args, &options, ctx); } -// ---------------------------------------------------------------------- -// Set-related operations - -static Result<Datum> ExecSetLookup(const std::string& func_name, const Datum& data, +// ---------------------------------------------------------------------- +// Set-related operations + +static Result<Datum> ExecSetLookup(const std::string& func_name, const Datum& data, const SetLookupOptions& options, ExecContext* ctx) { if (!options.value_set.is_arraylike()) { - return Status::Invalid("Set lookup value set must be Array or ChunkedArray"); - } + return Status::Invalid("Set lookup value set must be Array or ChunkedArray"); + } std::shared_ptr<DataType> data_type; if (data.type()->id() == Type::DICTIONARY) { data_type = @@ -379,85 +379,85 @@ static Result<Datum> ExecSetLookup(const std::string& func_name, const Datum& da } else { data_type = data.type(); } - + if (options.value_set.length() > 0 && !data_type->Equals(options.value_set.type())) { - std::stringstream ss; + std::stringstream ss; ss << "Array type didn't match type of values set: " << data_type->ToString() << " vs " << options.value_set.type()->ToString(); - return Status::Invalid(ss.str()); - } - return CallFunction(func_name, {data}, &options, ctx); -} - + return Status::Invalid(ss.str()); + } + return CallFunction(func_name, {data}, &options, ctx); +} + Result<Datum> IsIn(const Datum& values, const SetLookupOptions& options, ExecContext* ctx) { return ExecSetLookup("is_in", values, options, ctx); } -Result<Datum> IsIn(const Datum& values, const Datum& value_set, ExecContext* ctx) { +Result<Datum> IsIn(const Datum& values, const Datum& value_set, ExecContext* ctx) { return ExecSetLookup("is_in", values, SetLookupOptions{value_set}, ctx); -} - +} + Result<Datum> IndexIn(const Datum& values, const SetLookupOptions& options, ExecContext* ctx) { return ExecSetLookup("index_in", values, options, ctx); } -Result<Datum> IndexIn(const Datum& values, const Datum& value_set, ExecContext* ctx) { +Result<Datum> IndexIn(const Datum& values, const Datum& value_set, ExecContext* ctx) { return ExecSetLookup("index_in", values, SetLookupOptions{value_set}, ctx); -} - -// ---------------------------------------------------------------------- -// Boolean functions - -SCALAR_EAGER_UNARY(Invert, "invert") -SCALAR_EAGER_BINARY(And, "and") -SCALAR_EAGER_BINARY(KleeneAnd, "and_kleene") -SCALAR_EAGER_BINARY(Or, "or") -SCALAR_EAGER_BINARY(KleeneOr, "or_kleene") -SCALAR_EAGER_BINARY(Xor, "xor") +} + +// ---------------------------------------------------------------------- +// Boolean functions + +SCALAR_EAGER_UNARY(Invert, "invert") +SCALAR_EAGER_BINARY(And, "and") +SCALAR_EAGER_BINARY(KleeneAnd, "and_kleene") +SCALAR_EAGER_BINARY(Or, "or") +SCALAR_EAGER_BINARY(KleeneOr, "or_kleene") +SCALAR_EAGER_BINARY(Xor, "xor") SCALAR_EAGER_BINARY(AndNot, "and_not") SCALAR_EAGER_BINARY(KleeneAndNot, "and_not_kleene") - -// ---------------------------------------------------------------------- - -Result<Datum> Compare(const Datum& left, const Datum& right, CompareOptions options, - ExecContext* ctx) { - std::string func_name; - switch (options.op) { - case CompareOperator::EQUAL: - func_name = "equal"; - break; - case CompareOperator::NOT_EQUAL: - func_name = "not_equal"; - break; - case CompareOperator::GREATER: - func_name = "greater"; - break; - case CompareOperator::GREATER_EQUAL: - func_name = "greater_equal"; - break; - case CompareOperator::LESS: - func_name = "less"; - break; - case CompareOperator::LESS_EQUAL: - func_name = "less_equal"; - break; - } + +// ---------------------------------------------------------------------- + +Result<Datum> Compare(const Datum& left, const Datum& right, CompareOptions options, + ExecContext* ctx) { + std::string func_name; + switch (options.op) { + case CompareOperator::EQUAL: + func_name = "equal"; + break; + case CompareOperator::NOT_EQUAL: + func_name = "not_equal"; + break; + case CompareOperator::GREATER: + func_name = "greater"; + break; + case CompareOperator::GREATER_EQUAL: + func_name = "greater_equal"; + break; + case CompareOperator::LESS: + func_name = "less"; + break; + case CompareOperator::LESS_EQUAL: + func_name = "less_equal"; + break; + } return CallFunction(func_name, {left, right}, nullptr, ctx); -} - -// ---------------------------------------------------------------------- -// Validity functions - -SCALAR_EAGER_UNARY(IsValid, "is_valid") -SCALAR_EAGER_UNARY(IsNull, "is_null") +} + +// ---------------------------------------------------------------------- +// Validity functions + +SCALAR_EAGER_UNARY(IsValid, "is_valid") +SCALAR_EAGER_UNARY(IsNull, "is_null") SCALAR_EAGER_UNARY(IsNan, "is_nan") - -Result<Datum> FillNull(const Datum& values, const Datum& fill_value, ExecContext* ctx) { - return CallFunction("fill_null", {values, fill_value}, ctx); -} - + +Result<Datum> FillNull(const Datum& values, const Datum& fill_value, ExecContext* ctx) { + return CallFunction("fill_null", {values, fill_value}, ctx); +} + Result<Datum> IfElse(const Datum& cond, const Datum& if_true, const Datum& if_false, ExecContext* ctx) { return CallFunction("if_else", {cond, if_true, if_false}, ctx); @@ -494,5 +494,5 @@ Result<Datum> DayOfWeek(const Datum& arg, DayOfWeekOptions options, ExecContext* return CallFunction("day_of_week", {arg}, &options, ctx); } -} // namespace compute -} // namespace arrow +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_scalar.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_scalar.h index e07e41569a..edad35f53d 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_scalar.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_scalar.h @@ -1,55 +1,55 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Eager evaluation convenience APIs for invoking common functions, including -// necessary memory allocations - -#pragma once - -#include <string> -#include <utility> - -#include "arrow/compute/exec.h" // IWYU pragma: keep -#include "arrow/compute/function.h" -#include "arrow/datum.h" -#include "arrow/result.h" -#include "arrow/util/macros.h" -#include "arrow/util/visibility.h" - -namespace arrow { -namespace compute { - -/// \addtogroup compute-concrete-options -/// -/// @{ - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Eager evaluation convenience APIs for invoking common functions, including +// necessary memory allocations + +#pragma once + +#include <string> +#include <utility> + +#include "arrow/compute/exec.h" // IWYU pragma: keep +#include "arrow/compute/function.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + +/// \addtogroup compute-concrete-options +/// +/// @{ + class ARROW_EXPORT ArithmeticOptions : public FunctionOptions { public: explicit ArithmeticOptions(bool check_overflow = false); constexpr static char const kTypeName[] = "ArithmeticOptions"; - bool check_overflow; -}; - + bool check_overflow; +}; + class ARROW_EXPORT ElementWiseAggregateOptions : public FunctionOptions { public: explicit ElementWiseAggregateOptions(bool skip_nulls = true); constexpr static char const kTypeName[] = "ElementWiseAggregateOptions"; static ElementWiseAggregateOptions Defaults() { return ElementWiseAggregateOptions{}; } - + bool skip_nulls; }; @@ -80,11 +80,11 @@ class ARROW_EXPORT MatchSubstringOptions : public FunctionOptions { constexpr static char const kTypeName[] = "MatchSubstringOptions"; /// The exact substring (or regex, depending on kernel) to look for inside input values. - std::string pattern; + std::string pattern; /// Whether to perform a case-insensitive match. bool ignore_case = false; -}; - +}; + class ARROW_EXPORT SplitOptions : public FunctionOptions { public: explicit SplitOptions(int64_t max_splits = -1, bool reverse = false); @@ -150,34 +150,34 @@ class ARROW_EXPORT ExtractRegexOptions : public FunctionOptions { std::string pattern; }; -/// Options for IsIn and IndexIn functions +/// Options for IsIn and IndexIn functions class ARROW_EXPORT SetLookupOptions : public FunctionOptions { public: explicit SetLookupOptions(Datum value_set, bool skip_nulls = false); SetLookupOptions(); constexpr static char const kTypeName[] = "SetLookupOptions"; - - /// The set of values to look up input values into. - Datum value_set; - /// Whether nulls in `value_set` count for lookup. - /// - /// If true, any null in `value_set` is ignored and nulls in the input - /// produce null (IndexIn) or false (IsIn) values in the output. - /// If false, any null in `value_set` is successfully matched in - /// the input. - bool skip_nulls; -}; - + + /// The set of values to look up input values into. + Datum value_set; + /// Whether nulls in `value_set` count for lookup. + /// + /// If true, any null in `value_set` is ignored and nulls in the input + /// produce null (IndexIn) or false (IsIn) values in the output. + /// If false, any null in `value_set` is successfully matched in + /// the input. + bool skip_nulls; +}; + class ARROW_EXPORT StrptimeOptions : public FunctionOptions { public: explicit StrptimeOptions(std::string format, TimeUnit::type unit); StrptimeOptions(); constexpr static char const kTypeName[] = "StrptimeOptions"; - - std::string format; - TimeUnit::type unit; -}; - + + std::string format; + TimeUnit::type unit; +}; + class ARROW_EXPORT PadOptions : public FunctionOptions { public: explicit PadOptions(int64_t width, std::string padding = " "); @@ -209,21 +209,21 @@ class ARROW_EXPORT SliceOptions : public FunctionOptions { int64_t start, stop, step; }; -enum CompareOperator : int8_t { - EQUAL, - NOT_EQUAL, - GREATER, - GREATER_EQUAL, - LESS, - LESS_EQUAL, -}; - +enum CompareOperator : int8_t { + EQUAL, + NOT_EQUAL, + GREATER, + GREATER_EQUAL, + LESS, + LESS_EQUAL, +}; + struct ARROW_EXPORT CompareOptions { - explicit CompareOptions(CompareOperator op) : op(op) {} + explicit CompareOptions(CompareOperator op) : op(op) {} CompareOptions() : CompareOptions(CompareOperator::EQUAL) {} - enum CompareOperator op; -}; - + enum CompareOperator op; +}; + class ARROW_EXPORT MakeStructOptions : public FunctionOptions { public: MakeStructOptions(std::vector<std::string> n, std::vector<bool> r, @@ -254,8 +254,8 @@ struct ARROW_EXPORT DayOfWeekOptions : public FunctionOptions { uint32_t week_start; }; -/// @} - +/// @} + /// \brief Get the absolute value of a value. /// /// If argument is null the result will be null. @@ -269,59 +269,59 @@ Result<Datum> AbsoluteValue(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), ExecContext* ctx = NULLPTR); -/// \brief Add two values together. Array values must be the same length. If -/// either addend is null the result will be null. -/// -/// \param[in] left the first addend -/// \param[in] right the second addend -/// \param[in] options arithmetic options (overflow handling), optional -/// \param[in] ctx the function execution context, optional -/// \return the elementwise sum -ARROW_EXPORT -Result<Datum> Add(const Datum& left, const Datum& right, - ArithmeticOptions options = ArithmeticOptions(), - ExecContext* ctx = NULLPTR); - -/// \brief Subtract two values. Array values must be the same length. If the -/// minuend or subtrahend is null the result will be null. -/// -/// \param[in] left the value subtracted from (minuend) -/// \param[in] right the value by which the minuend is reduced (subtrahend) -/// \param[in] options arithmetic options (overflow handling), optional -/// \param[in] ctx the function execution context, optional -/// \return the elementwise difference -ARROW_EXPORT -Result<Datum> Subtract(const Datum& left, const Datum& right, - ArithmeticOptions options = ArithmeticOptions(), - ExecContext* ctx = NULLPTR); - -/// \brief Multiply two values. Array values must be the same length. If either -/// factor is null the result will be null. -/// -/// \param[in] left the first factor -/// \param[in] right the second factor -/// \param[in] options arithmetic options (overflow handling), optional -/// \param[in] ctx the function execution context, optional -/// \return the elementwise product -ARROW_EXPORT -Result<Datum> Multiply(const Datum& left, const Datum& right, - ArithmeticOptions options = ArithmeticOptions(), - ExecContext* ctx = NULLPTR); - -/// \brief Divide two values. Array values must be the same length. If either -/// argument is null the result will be null. For integer types, if there is -/// a zero divisor, an error will be raised. -/// -/// \param[in] left the dividend -/// \param[in] right the divisor -/// \param[in] options arithmetic options (enable/disable overflow checking), optional -/// \param[in] ctx the function execution context, optional -/// \return the elementwise quotient -ARROW_EXPORT -Result<Datum> Divide(const Datum& left, const Datum& right, - ArithmeticOptions options = ArithmeticOptions(), - ExecContext* ctx = NULLPTR); - +/// \brief Add two values together. Array values must be the same length. If +/// either addend is null the result will be null. +/// +/// \param[in] left the first addend +/// \param[in] right the second addend +/// \param[in] options arithmetic options (overflow handling), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise sum +ARROW_EXPORT +Result<Datum> Add(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Subtract two values. Array values must be the same length. If the +/// minuend or subtrahend is null the result will be null. +/// +/// \param[in] left the value subtracted from (minuend) +/// \param[in] right the value by which the minuend is reduced (subtrahend) +/// \param[in] options arithmetic options (overflow handling), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise difference +ARROW_EXPORT +Result<Datum> Subtract(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Multiply two values. Array values must be the same length. If either +/// factor is null the result will be null. +/// +/// \param[in] left the first factor +/// \param[in] right the second factor +/// \param[in] options arithmetic options (overflow handling), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise product +ARROW_EXPORT +Result<Datum> Multiply(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Divide two values. Array values must be the same length. If either +/// argument is null the result will be null. For integer types, if there is +/// a zero divisor, an error will be raised. +/// +/// \param[in] left the dividend +/// \param[in] right the divisor +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise quotient +ARROW_EXPORT +Result<Datum> Divide(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + /// \brief Negate values. /// /// If argument is null the result will be null. @@ -549,98 +549,98 @@ Result<Datum> MinElementWise( ARROW_EXPORT Result<Datum> Sign(const Datum& arg, ExecContext* ctx = NULLPTR); -/// \brief Compare a numeric array with a scalar. -/// -/// \param[in] left datum to compare, must be an Array -/// \param[in] right datum to compare, must be a Scalar of the same type than -/// left Datum. -/// \param[in] options compare options -/// \param[in] ctx the function execution context, optional -/// \return resulting datum -/// -/// Note on floating point arrays, this uses ieee-754 compare semantics. -/// -/// \since 1.0.0 -/// \note API not yet finalized +/// \brief Compare a numeric array with a scalar. +/// +/// \param[in] left datum to compare, must be an Array +/// \param[in] right datum to compare, must be a Scalar of the same type than +/// left Datum. +/// \param[in] options compare options +/// \param[in] ctx the function execution context, optional +/// \return resulting datum +/// +/// Note on floating point arrays, this uses ieee-754 compare semantics. +/// +/// \since 1.0.0 +/// \note API not yet finalized ARROW_DEPRECATED("Deprecated in 5.0.0. Use each compare function directly") -ARROW_EXPORT +ARROW_EXPORT Result<Datum> Compare(const Datum& left, const Datum& right, CompareOptions options, ExecContext* ctx = NULLPTR); - -/// \brief Invert the values of a boolean datum -/// \param[in] value datum to invert -/// \param[in] ctx the function execution context, optional -/// \return the resulting datum -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result<Datum> Invert(const Datum& value, ExecContext* ctx = NULLPTR); - -/// \brief Element-wise AND of two boolean datums which always propagates nulls -/// (null and false is null). -/// + +/// \brief Invert the values of a boolean datum +/// \param[in] value datum to invert +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> Invert(const Datum& value, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise AND of two boolean datums which always propagates nulls +/// (null and false is null). +/// /// \param[in] left left operand /// \param[in] right right operand -/// \param[in] ctx the function execution context, optional -/// \return the resulting datum -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result<Datum> And(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); - -/// \brief Element-wise AND of two boolean datums with a Kleene truth table -/// (null and false is false). -/// +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> And(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise AND of two boolean datums with a Kleene truth table +/// (null and false is false). +/// /// \param[in] left left operand /// \param[in] right right operand -/// \param[in] ctx the function execution context, optional -/// \return the resulting datum -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result<Datum> KleeneAnd(const Datum& left, const Datum& right, - ExecContext* ctx = NULLPTR); - -/// \brief Element-wise OR of two boolean datums which always propagates nulls -/// (null and true is null). -/// +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> KleeneAnd(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief Element-wise OR of two boolean datums which always propagates nulls +/// (null and true is null). +/// /// \param[in] left left operand /// \param[in] right right operand -/// \param[in] ctx the function execution context, optional -/// \return the resulting datum -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result<Datum> Or(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); - -/// \brief Element-wise OR of two boolean datums with a Kleene truth table -/// (null or true is true). -/// +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> Or(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise OR of two boolean datums with a Kleene truth table +/// (null or true is true). +/// /// \param[in] left left operand /// \param[in] right right operand -/// \param[in] ctx the function execution context, optional -/// \return the resulting datum -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result<Datum> KleeneOr(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); - -/// \brief Element-wise XOR of two boolean datums +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> KleeneOr(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise XOR of two boolean datums /// \param[in] left left operand /// \param[in] right right operand -/// \param[in] ctx the function execution context, optional -/// \return the resulting datum -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result<Datum> Xor(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); - +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> Xor(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + /// \brief Element-wise AND NOT of two boolean datums which always propagates nulls /// (null and not true is null). /// @@ -668,73 +668,73 @@ ARROW_EXPORT Result<Datum> KleeneAndNot(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); -/// \brief IsIn returns true for each element of `values` that is contained in -/// `value_set` -/// +/// \brief IsIn returns true for each element of `values` that is contained in +/// `value_set` +/// /// Behaviour of nulls is governed by SetLookupOptions::skip_nulls. -/// -/// \param[in] values array-like input to look up in value_set +/// +/// \param[in] values array-like input to look up in value_set /// \param[in] options SetLookupOptions -/// \param[in] ctx the function execution context, optional -/// \return the resulting datum -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result<Datum> IsIn(const Datum& values, const SetLookupOptions& options, ExecContext* ctx = NULLPTR); ARROW_EXPORT -Result<Datum> IsIn(const Datum& values, const Datum& value_set, - ExecContext* ctx = NULLPTR); - -/// \brief IndexIn examines each slot in the values against a value_set array. -/// If the value is not found in value_set, null will be output. -/// If found, the index of occurrence within value_set (ignoring duplicates) -/// will be output. -/// -/// For example given values = [99, 42, 3, null] and -/// value_set = [3, 3, 99], the output will be = [1, null, 0, null] -/// +Result<Datum> IsIn(const Datum& values, const Datum& value_set, + ExecContext* ctx = NULLPTR); + +/// \brief IndexIn examines each slot in the values against a value_set array. +/// If the value is not found in value_set, null will be output. +/// If found, the index of occurrence within value_set (ignoring duplicates) +/// will be output. +/// +/// For example given values = [99, 42, 3, null] and +/// value_set = [3, 3, 99], the output will be = [1, null, 0, null] +/// /// Behaviour of nulls is governed by SetLookupOptions::skip_nulls. -/// -/// \param[in] values array-like input +/// +/// \param[in] values array-like input /// \param[in] options SetLookupOptions -/// \param[in] ctx the function execution context, optional -/// \return the resulting datum -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result<Datum> IndexIn(const Datum& values, const SetLookupOptions& options, ExecContext* ctx = NULLPTR); ARROW_EXPORT -Result<Datum> IndexIn(const Datum& values, const Datum& value_set, - ExecContext* ctx = NULLPTR); - -/// \brief IsValid returns true for each element of `values` that is not null, -/// false otherwise -/// -/// \param[in] values input to examine for validity -/// \param[in] ctx the function execution context, optional -/// \return the resulting datum -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result<Datum> IsValid(const Datum& values, ExecContext* ctx = NULLPTR); - -/// \brief IsNull returns true for each element of `values` that is null, -/// false otherwise -/// -/// \param[in] values input to examine for nullity -/// \param[in] ctx the function execution context, optional -/// \return the resulting datum -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result<Datum> IsNull(const Datum& values, ExecContext* ctx = NULLPTR); - +Result<Datum> IndexIn(const Datum& values, const Datum& value_set, + ExecContext* ctx = NULLPTR); + +/// \brief IsValid returns true for each element of `values` that is not null, +/// false otherwise +/// +/// \param[in] values input to examine for validity +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> IsValid(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief IsNull returns true for each element of `values` that is null, +/// false otherwise +/// +/// \param[in] values input to examine for nullity +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> IsNull(const Datum& values, ExecContext* ctx = NULLPTR); + /// \brief IsNan returns true for each element of `values` that is NaN, /// false otherwise /// @@ -747,21 +747,21 @@ Result<Datum> IsNull(const Datum& values, ExecContext* ctx = NULLPTR); ARROW_EXPORT Result<Datum> IsNan(const Datum& values, ExecContext* ctx = NULLPTR); -/// \brief FillNull replaces each null element in `values` -/// with `fill_value` -/// -/// \param[in] values input to examine for nullity -/// \param[in] fill_value scalar -/// \param[in] ctx the function execution context, optional -/// -/// \return the resulting datum -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result<Datum> FillNull(const Datum& values, const Datum& fill_value, - ExecContext* ctx = NULLPTR); - +/// \brief FillNull replaces each null element in `values` +/// with `fill_value` +/// +/// \param[in] values input to examine for nullity +/// \param[in] fill_value scalar +/// \param[in] ctx the function execution context, optional +/// +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> FillNull(const Datum& values, const Datum& fill_value, + ExecContext* ctx = NULLPTR); + /// \brief IfElse returns elements chosen from `left` or `right` /// depending on `cond`. `null` values in `cond` will be promoted to the result /// @@ -985,5 +985,5 @@ Result<Datum> Nanosecond(const Datum& values, ExecContext* ctx = NULLPTR); /// \note API not yet finalized ARROW_EXPORT Result<Datum> Subsecond(const Datum& values, ExecContext* ctx = NULLPTR); -} // namespace compute -} // namespace arrow +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_vector.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_vector.cc index a68969b2ee..967829f425 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_vector.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_vector.cc @@ -1,43 +1,43 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/compute/api_vector.h" - -#include <memory> +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/api_vector.h" + +#include <memory> #include <sstream> -#include <utility> -#include <vector> - -#include "arrow/array/array_nested.h" -#include "arrow/array/builder_primitive.h" -#include "arrow/compute/exec.h" +#include <utility> +#include <vector> + +#include "arrow/array/array_nested.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/compute/exec.h" #include "arrow/compute/function_internal.h" #include "arrow/compute/registry.h" -#include "arrow/datum.h" -#include "arrow/record_batch.h" -#include "arrow/result.h" -#include "arrow/util/checked_cast.h" +#include "arrow/datum.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" - -namespace arrow { - + +namespace arrow { + using internal::checked_cast; -using internal::checked_pointer_cast; - +using internal::checked_pointer_cast; + namespace internal { using compute::DictionaryEncodeOptions; using compute::FilterOptions; @@ -73,9 +73,9 @@ struct EnumTraits<DictionaryEncodeOptions::NullEncodingBehavior> }; } // namespace internal -namespace compute { - -// ---------------------------------------------------------------------- +namespace compute { + +// ---------------------------------------------------------------------- // Function options bool SortKey::Equals(const SortKey& other) const { @@ -152,16 +152,16 @@ void RegisterVectorOptions(FunctionRegistry* registry) { } // namespace internal // ---------------------------------------------------------------------- -// Direct exec interface to kernels - -Result<std::shared_ptr<Array>> NthToIndices(const Array& values, int64_t n, - ExecContext* ctx) { - PartitionNthOptions options(/*pivot=*/n); - ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("partition_nth_indices", - {Datum(values)}, &options, ctx)); - return result.make_array(); -} - +// Direct exec interface to kernels + +Result<std::shared_ptr<Array>> NthToIndices(const Array& values, int64_t n, + ExecContext* ctx) { + PartitionNthOptions options(/*pivot=*/n); + ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("partition_nth_indices", + {Datum(values)}, &options, ctx)); + return result.make_array(); +} + Result<Datum> ReplaceWithMask(const Datum& values, const Datum& mask, const Datum& replacements, ExecContext* ctx) { return CallFunction("replace_with_mask", {values, mask, replacements}, ctx); @@ -172,9 +172,9 @@ Result<std::shared_ptr<Array>> SortIndices(const Array& values, SortOrder order, ArraySortOptions options(order); ARROW_ASSIGN_OR_RAISE( Datum result, CallFunction("array_sort_indices", {Datum(values)}, &options, ctx)); - return result.make_array(); -} - + return result.make_array(); +} + Result<std::shared_ptr<Array>> SortIndices(const ChunkedArray& chunked_array, SortOrder order, ExecContext* ctx) { SortOptions options({SortKey("not-used", order)}); @@ -190,94 +190,94 @@ Result<std::shared_ptr<Array>> SortIndices(const Datum& datum, const SortOptions return result.make_array(); } -Result<std::shared_ptr<Array>> Unique(const Datum& value, ExecContext* ctx) { - ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("unique", {value}, ctx)); - return result.make_array(); -} - +Result<std::shared_ptr<Array>> Unique(const Datum& value, ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("unique", {value}, ctx)); + return result.make_array(); +} + Result<Datum> DictionaryEncode(const Datum& value, const DictionaryEncodeOptions& options, ExecContext* ctx) { return CallFunction("dictionary_encode", {value}, &options, ctx); -} - -const char kValuesFieldName[] = "values"; -const char kCountsFieldName[] = "counts"; -const int32_t kValuesFieldIndex = 0; -const int32_t kCountsFieldIndex = 1; - -Result<std::shared_ptr<StructArray>> ValueCounts(const Datum& value, ExecContext* ctx) { - ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("value_counts", {value}, ctx)); - return checked_pointer_cast<StructArray>(result.make_array()); -} - -// ---------------------------------------------------------------------- -// Filter- and take-related selection functions - -Result<Datum> Filter(const Datum& values, const Datum& filter, - const FilterOptions& options, ExecContext* ctx) { - // Invoke metafunction which deals with Datum kinds other than just Array, - // ChunkedArray. - return CallFunction("filter", {values, filter}, &options, ctx); -} - -Result<Datum> Take(const Datum& values, const Datum& filter, const TakeOptions& options, - ExecContext* ctx) { - // Invoke metafunction which deals with Datum kinds other than just Array, - // ChunkedArray. - return CallFunction("take", {values, filter}, &options, ctx); -} - -Result<std::shared_ptr<Array>> Take(const Array& values, const Array& indices, - const TakeOptions& options, ExecContext* ctx) { - ARROW_ASSIGN_OR_RAISE(Datum out, Take(Datum(values), Datum(indices), options, ctx)); - return out.make_array(); -} - -// ---------------------------------------------------------------------- -// Deprecated functions - -Result<std::shared_ptr<ChunkedArray>> Take(const ChunkedArray& values, - const Array& indices, - const TakeOptions& options, ExecContext* ctx) { - ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(values), Datum(indices), options, ctx)); - return result.chunked_array(); -} - -Result<std::shared_ptr<ChunkedArray>> Take(const ChunkedArray& values, - const ChunkedArray& indices, - const TakeOptions& options, ExecContext* ctx) { - ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(values), Datum(indices), options, ctx)); - return result.chunked_array(); -} - -Result<std::shared_ptr<ChunkedArray>> Take(const Array& values, - const ChunkedArray& indices, - const TakeOptions& options, ExecContext* ctx) { - ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(values), Datum(indices), options, ctx)); - return result.chunked_array(); -} - -Result<std::shared_ptr<RecordBatch>> Take(const RecordBatch& batch, const Array& indices, - const TakeOptions& options, ExecContext* ctx) { - ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(batch), Datum(indices), options, ctx)); - return result.record_batch(); -} - -Result<std::shared_ptr<Table>> Take(const Table& table, const Array& indices, - const TakeOptions& options, ExecContext* ctx) { - ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(table), Datum(indices), options, ctx)); - return result.table(); -} - -Result<std::shared_ptr<Table>> Take(const Table& table, const ChunkedArray& indices, - const TakeOptions& options, ExecContext* ctx) { - ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(table), Datum(indices), options, ctx)); - return result.table(); -} - +} + +const char kValuesFieldName[] = "values"; +const char kCountsFieldName[] = "counts"; +const int32_t kValuesFieldIndex = 0; +const int32_t kCountsFieldIndex = 1; + +Result<std::shared_ptr<StructArray>> ValueCounts(const Datum& value, ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("value_counts", {value}, ctx)); + return checked_pointer_cast<StructArray>(result.make_array()); +} + +// ---------------------------------------------------------------------- +// Filter- and take-related selection functions + +Result<Datum> Filter(const Datum& values, const Datum& filter, + const FilterOptions& options, ExecContext* ctx) { + // Invoke metafunction which deals with Datum kinds other than just Array, + // ChunkedArray. + return CallFunction("filter", {values, filter}, &options, ctx); +} + +Result<Datum> Take(const Datum& values, const Datum& filter, const TakeOptions& options, + ExecContext* ctx) { + // Invoke metafunction which deals with Datum kinds other than just Array, + // ChunkedArray. + return CallFunction("take", {values, filter}, &options, ctx); +} + +Result<std::shared_ptr<Array>> Take(const Array& values, const Array& indices, + const TakeOptions& options, ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum out, Take(Datum(values), Datum(indices), options, ctx)); + return out.make_array(); +} + +// ---------------------------------------------------------------------- +// Deprecated functions + +Result<std::shared_ptr<ChunkedArray>> Take(const ChunkedArray& values, + const Array& indices, + const TakeOptions& options, ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(values), Datum(indices), options, ctx)); + return result.chunked_array(); +} + +Result<std::shared_ptr<ChunkedArray>> Take(const ChunkedArray& values, + const ChunkedArray& indices, + const TakeOptions& options, ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(values), Datum(indices), options, ctx)); + return result.chunked_array(); +} + +Result<std::shared_ptr<ChunkedArray>> Take(const Array& values, + const ChunkedArray& indices, + const TakeOptions& options, ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(values), Datum(indices), options, ctx)); + return result.chunked_array(); +} + +Result<std::shared_ptr<RecordBatch>> Take(const RecordBatch& batch, const Array& indices, + const TakeOptions& options, ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(batch), Datum(indices), options, ctx)); + return result.record_batch(); +} + +Result<std::shared_ptr<Table>> Take(const Table& table, const Array& indices, + const TakeOptions& options, ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(table), Datum(indices), options, ctx)); + return result.table(); +} + +Result<std::shared_ptr<Table>> Take(const Table& table, const ChunkedArray& indices, + const TakeOptions& options, ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(table), Datum(indices), options, ctx)); + return result.table(); +} + Result<std::shared_ptr<Array>> SortToIndices(const Array& values, ExecContext* ctx) { return SortIndices(values, SortOrder::Ascending, ctx); } -} // namespace compute -} // namespace arrow +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_vector.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_vector.h index 9d8d4271db..4a804e4a57 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_vector.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/api_vector.h @@ -1,65 +1,65 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <memory> - -#include "arrow/compute/function.h" -#include "arrow/datum.h" -#include "arrow/result.h" -#include "arrow/type_fwd.h" - -namespace arrow { -namespace compute { - -class ExecContext; - -/// \addtogroup compute-concrete-options -/// @{ - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> + +#include "arrow/compute/function.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/type_fwd.h" + +namespace arrow { +namespace compute { + +class ExecContext; + +/// \addtogroup compute-concrete-options +/// @{ + class ARROW_EXPORT FilterOptions : public FunctionOptions { public: - /// Configure the action taken when a slot of the selection mask is null - enum NullSelectionBehavior { - /// the corresponding filtered value will be removed in the output - DROP, - /// the corresponding filtered value will be null in the output - EMIT_NULL, - }; - + /// Configure the action taken when a slot of the selection mask is null + enum NullSelectionBehavior { + /// the corresponding filtered value will be removed in the output + DROP, + /// the corresponding filtered value will be null in the output + EMIT_NULL, + }; + explicit FilterOptions(NullSelectionBehavior null_selection = DROP); constexpr static char const kTypeName[] = "FilterOptions"; - static FilterOptions Defaults() { return FilterOptions(); } - - NullSelectionBehavior null_selection_behavior = DROP; -}; - + static FilterOptions Defaults() { return FilterOptions(); } + + NullSelectionBehavior null_selection_behavior = DROP; +}; + class ARROW_EXPORT TakeOptions : public FunctionOptions { public: explicit TakeOptions(bool boundscheck = true); constexpr static char const kTypeName[] = "TakeOptions"; - static TakeOptions BoundsCheck() { return TakeOptions(true); } - static TakeOptions NoBoundsCheck() { return TakeOptions(false); } - static TakeOptions Defaults() { return BoundsCheck(); } + static TakeOptions BoundsCheck() { return TakeOptions(true); } + static TakeOptions NoBoundsCheck() { return TakeOptions(false); } + static TakeOptions Defaults() { return BoundsCheck(); } bool boundscheck = true; -}; - +}; + /// \brief Options for the dictionary encode function class ARROW_EXPORT DictionaryEncodeOptions : public FunctionOptions { public: @@ -119,58 +119,58 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { std::vector<SortKey> sort_keys; }; -/// \brief Partitioning options for NthToIndices +/// \brief Partitioning options for NthToIndices class ARROW_EXPORT PartitionNthOptions : public FunctionOptions { public: explicit PartitionNthOptions(int64_t pivot); PartitionNthOptions() : PartitionNthOptions(0) {} constexpr static char const kTypeName[] = "PartitionNthOptions"; - - /// The index into the equivalent sorted array of the partition pivot element. - int64_t pivot; -}; - -/// @} - -/// \brief Filter with a boolean selection filter -/// -/// The output will be populated with values from the input at positions -/// where the selection filter is not 0. Nulls in the filter will be handled -/// based on options.null_selection_behavior. -/// -/// For example given values = ["a", "b", "c", null, "e", "f"] and -/// filter = [0, 1, 1, 0, null, 1], the output will be -/// (null_selection_behavior == DROP) = ["b", "c", "f"] -/// (null_selection_behavior == EMIT_NULL) = ["b", "c", null, "f"] -/// -/// \param[in] values array to filter -/// \param[in] filter indicates which values should be filtered out -/// \param[in] options configures null_selection_behavior -/// \param[in] ctx the function execution context, optional -/// \return the resulting datum -ARROW_EXPORT -Result<Datum> Filter(const Datum& values, const Datum& filter, - const FilterOptions& options = FilterOptions::Defaults(), - ExecContext* ctx = NULLPTR); - -namespace internal { - -// These internal functions are implemented in kernels/vector_selection.cc - -/// \brief Return the number of selected indices in the boolean filter -ARROW_EXPORT -int64_t GetFilterOutputSize(const ArrayData& filter, - FilterOptions::NullSelectionBehavior null_selection); - -/// \brief Compute uint64 selection indices for use with Take given a boolean -/// filter -ARROW_EXPORT -Result<std::shared_ptr<ArrayData>> GetTakeIndices( - const ArrayData& filter, FilterOptions::NullSelectionBehavior null_selection, - MemoryPool* memory_pool = default_memory_pool()); - -} // namespace internal - + + /// The index into the equivalent sorted array of the partition pivot element. + int64_t pivot; +}; + +/// @} + +/// \brief Filter with a boolean selection filter +/// +/// The output will be populated with values from the input at positions +/// where the selection filter is not 0. Nulls in the filter will be handled +/// based on options.null_selection_behavior. +/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// filter = [0, 1, 1, 0, null, 1], the output will be +/// (null_selection_behavior == DROP) = ["b", "c", "f"] +/// (null_selection_behavior == EMIT_NULL) = ["b", "c", null, "f"] +/// +/// \param[in] values array to filter +/// \param[in] filter indicates which values should be filtered out +/// \param[in] options configures null_selection_behavior +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +ARROW_EXPORT +Result<Datum> Filter(const Datum& values, const Datum& filter, + const FilterOptions& options = FilterOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +namespace internal { + +// These internal functions are implemented in kernels/vector_selection.cc + +/// \brief Return the number of selected indices in the boolean filter +ARROW_EXPORT +int64_t GetFilterOutputSize(const ArrayData& filter, + FilterOptions::NullSelectionBehavior null_selection); + +/// \brief Compute uint64 selection indices for use with Take given a boolean +/// filter +ARROW_EXPORT +Result<std::shared_ptr<ArrayData>> GetTakeIndices( + const ArrayData& filter, FilterOptions::NullSelectionBehavior null_selection, + MemoryPool* memory_pool = default_memory_pool()); + +} // namespace internal + /// \brief ReplaceWithMask replaces each value in the array corresponding /// to a true value in the mask with the next element from `replacements`. /// @@ -188,72 +188,72 @@ ARROW_EXPORT Result<Datum> ReplaceWithMask(const Datum& values, const Datum& mask, const Datum& replacements, ExecContext* ctx = NULLPTR); -/// \brief Take from an array of values at indices in another array -/// -/// The output array will be of the same type as the input values -/// array, with elements taken from the values array at the given -/// indices. If an index is null then the taken element will be null. -/// -/// For example given values = ["a", "b", "c", null, "e", "f"] and -/// indices = [2, 1, null, 3], the output will be -/// = [values[2], values[1], null, values[3]] -/// = ["c", "b", null, null] -/// -/// \param[in] values datum from which to take -/// \param[in] indices which values to take -/// \param[in] options options -/// \param[in] ctx the function execution context, optional -/// \return the resulting datum -ARROW_EXPORT -Result<Datum> Take(const Datum& values, const Datum& indices, - const TakeOptions& options = TakeOptions::Defaults(), - ExecContext* ctx = NULLPTR); - -/// \brief Take with Array inputs and output -ARROW_EXPORT -Result<std::shared_ptr<Array>> Take(const Array& values, const Array& indices, - const TakeOptions& options = TakeOptions::Defaults(), - ExecContext* ctx = NULLPTR); - -/// \brief Returns indices that partition an array around n-th -/// sorted element. -/// -/// Find index of n-th(0 based) smallest value and perform indirect -/// partition of an array around that element. Output indices[0 ~ n-1] -/// holds values no greater than n-th element, and indices[n+1 ~ end] -/// holds values no less than n-th element. Elements in each partition -/// is not sorted. Nulls will be partitioned to the end of the output. -/// Output is not guaranteed to be stable. -/// -/// \param[in] values array to be partitioned -/// \param[in] n pivot array around sorted n-th element -/// \param[in] ctx the function execution context, optional -/// \return offsets indices that would partition an array -ARROW_EXPORT -Result<std::shared_ptr<Array>> NthToIndices(const Array& values, int64_t n, - ExecContext* ctx = NULLPTR); - +/// \brief Take from an array of values at indices in another array +/// +/// The output array will be of the same type as the input values +/// array, with elements taken from the values array at the given +/// indices. If an index is null then the taken element will be null. +/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// indices = [2, 1, null, 3], the output will be +/// = [values[2], values[1], null, values[3]] +/// = ["c", "b", null, null] +/// +/// \param[in] values datum from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +ARROW_EXPORT +Result<Datum> Take(const Datum& values, const Datum& indices, + const TakeOptions& options = TakeOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Take with Array inputs and output +ARROW_EXPORT +Result<std::shared_ptr<Array>> Take(const Array& values, const Array& indices, + const TakeOptions& options = TakeOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Returns indices that partition an array around n-th +/// sorted element. +/// +/// Find index of n-th(0 based) smallest value and perform indirect +/// partition of an array around that element. Output indices[0 ~ n-1] +/// holds values no greater than n-th element, and indices[n+1 ~ end] +/// holds values no less than n-th element. Elements in each partition +/// is not sorted. Nulls will be partitioned to the end of the output. +/// Output is not guaranteed to be stable. +/// +/// \param[in] values array to be partitioned +/// \param[in] n pivot array around sorted n-th element +/// \param[in] ctx the function execution context, optional +/// \return offsets indices that would partition an array +ARROW_EXPORT +Result<std::shared_ptr<Array>> NthToIndices(const Array& values, int64_t n, + ExecContext* ctx = NULLPTR); + /// \brief Returns the indices that would sort an array in the /// specified order. -/// -/// Perform an indirect sort of array. The output array will contain -/// indices that would sort an array, which would be the same length +/// +/// Perform an indirect sort of array. The output array will contain +/// indices that would sort an array, which would be the same length /// as input. Nulls will be stably partitioned to the end of the output /// regardless of order. -/// +/// /// For example given array = [null, 1, 3.3, null, 2, 5.3] and order /// = SortOrder::DESCENDING, the output will be [5, 2, 4, 1, 0, /// 3]. -/// +/// /// \param[in] array array to sort /// \param[in] order ascending or descending -/// \param[in] ctx the function execution context, optional -/// \return offsets indices that would sort an array -ARROW_EXPORT +/// \param[in] ctx the function execution context, optional +/// \return offsets indices that would sort an array +ARROW_EXPORT Result<std::shared_ptr<Array>> SortIndices(const Array& array, SortOrder order = SortOrder::Ascending, ExecContext* ctx = NULLPTR); - + /// \brief Returns the indices that would sort a chunked array in the /// specified order. /// @@ -300,44 +300,44 @@ ARROW_EXPORT Result<std::shared_ptr<Array>> SortIndices(const Datum& datum, const SortOptions& options, ExecContext* ctx = NULLPTR); -/// \brief Compute unique elements from an array-like object -/// -/// Note if a null occurs in the input it will NOT be included in the output. -/// -/// \param[in] datum array-like input -/// \param[in] ctx the function execution context, optional -/// \return result as Array -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result<std::shared_ptr<Array>> Unique(const Datum& datum, ExecContext* ctx = NULLPTR); - -// Constants for accessing the output of ValueCounts -ARROW_EXPORT extern const char kValuesFieldName[]; -ARROW_EXPORT extern const char kCountsFieldName[]; -ARROW_EXPORT extern const int32_t kValuesFieldIndex; -ARROW_EXPORT extern const int32_t kCountsFieldIndex; - -/// \brief Return counts of unique elements from an array-like object. -/// -/// Note that the counts do not include counts for nulls in the array. These can be -/// obtained separately from metadata. -/// -/// For floating point arrays there is no attempt to normalize -0.0, 0.0 and NaN values -/// which can lead to unexpected results if the input Array has these values. -/// -/// \param[in] value array-like input -/// \param[in] ctx the function execution context, optional -/// \return counts An array of <input type "Values", int64_t "Counts"> structs. -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result<std::shared_ptr<StructArray>> ValueCounts(const Datum& value, - ExecContext* ctx = NULLPTR); - -/// \brief Dictionary-encode values in an array-like object +/// \brief Compute unique elements from an array-like object +/// +/// Note if a null occurs in the input it will NOT be included in the output. +/// +/// \param[in] datum array-like input +/// \param[in] ctx the function execution context, optional +/// \return result as Array +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<std::shared_ptr<Array>> Unique(const Datum& datum, ExecContext* ctx = NULLPTR); + +// Constants for accessing the output of ValueCounts +ARROW_EXPORT extern const char kValuesFieldName[]; +ARROW_EXPORT extern const char kCountsFieldName[]; +ARROW_EXPORT extern const int32_t kValuesFieldIndex; +ARROW_EXPORT extern const int32_t kCountsFieldIndex; + +/// \brief Return counts of unique elements from an array-like object. +/// +/// Note that the counts do not include counts for nulls in the array. These can be +/// obtained separately from metadata. +/// +/// For floating point arrays there is no attempt to normalize -0.0, 0.0 and NaN values +/// which can lead to unexpected results if the input Array has these values. +/// +/// \param[in] value array-like input +/// \param[in] ctx the function execution context, optional +/// \return counts An array of <input type "Values", int64_t "Counts"> structs. +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<std::shared_ptr<StructArray>> ValueCounts(const Datum& value, + ExecContext* ctx = NULLPTR); + +/// \brief Dictionary-encode values in an array-like object /// /// Any nulls encountered in the dictionary will be handled according to the /// specified null encoding behavior. @@ -349,62 +349,62 @@ Result<std::shared_ptr<StructArray>> ValueCounts(const Datum& value, /// If the input is already dictionary encoded this function is a no-op unless /// it needs to modify the null_encoding (TODO) /// -/// \param[in] data array-like input -/// \param[in] ctx the function execution context, optional +/// \param[in] data array-like input +/// \param[in] ctx the function execution context, optional /// \param[in] options configures null encoding behavior -/// \return result with same shape and type as input -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT +/// \return result with same shape and type as input +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result<Datum> DictionaryEncode( const Datum& data, const DictionaryEncodeOptions& options = DictionaryEncodeOptions::Defaults(), ExecContext* ctx = NULLPTR); - -// ---------------------------------------------------------------------- -// Deprecated functions - -ARROW_DEPRECATED("Deprecated in 1.0.0. Use Datum-based version") -ARROW_EXPORT -Result<std::shared_ptr<ChunkedArray>> Take( - const ChunkedArray& values, const Array& indices, - const TakeOptions& options = TakeOptions::Defaults(), ExecContext* context = NULLPTR); - -ARROW_DEPRECATED("Deprecated in 1.0.0. Use Datum-based version") -ARROW_EXPORT -Result<std::shared_ptr<ChunkedArray>> Take( - const ChunkedArray& values, const ChunkedArray& indices, - const TakeOptions& options = TakeOptions::Defaults(), ExecContext* context = NULLPTR); - -ARROW_DEPRECATED("Deprecated in 1.0.0. Use Datum-based version") -ARROW_EXPORT -Result<std::shared_ptr<ChunkedArray>> Take( - const Array& values, const ChunkedArray& indices, - const TakeOptions& options = TakeOptions::Defaults(), ExecContext* context = NULLPTR); - -ARROW_DEPRECATED("Deprecated in 1.0.0. Use Datum-based version") -ARROW_EXPORT -Result<std::shared_ptr<RecordBatch>> Take( - const RecordBatch& batch, const Array& indices, - const TakeOptions& options = TakeOptions::Defaults(), ExecContext* context = NULLPTR); - -ARROW_DEPRECATED("Deprecated in 1.0.0. Use Datum-based version") -ARROW_EXPORT -Result<std::shared_ptr<Table>> Take(const Table& table, const Array& indices, - const TakeOptions& options = TakeOptions::Defaults(), - ExecContext* context = NULLPTR); - -ARROW_DEPRECATED("Deprecated in 1.0.0. Use Datum-based version") -ARROW_EXPORT -Result<std::shared_ptr<Table>> Take(const Table& table, const ChunkedArray& indices, - const TakeOptions& options = TakeOptions::Defaults(), - ExecContext* context = NULLPTR); - + +// ---------------------------------------------------------------------- +// Deprecated functions + +ARROW_DEPRECATED("Deprecated in 1.0.0. Use Datum-based version") +ARROW_EXPORT +Result<std::shared_ptr<ChunkedArray>> Take( + const ChunkedArray& values, const Array& indices, + const TakeOptions& options = TakeOptions::Defaults(), ExecContext* context = NULLPTR); + +ARROW_DEPRECATED("Deprecated in 1.0.0. Use Datum-based version") +ARROW_EXPORT +Result<std::shared_ptr<ChunkedArray>> Take( + const ChunkedArray& values, const ChunkedArray& indices, + const TakeOptions& options = TakeOptions::Defaults(), ExecContext* context = NULLPTR); + +ARROW_DEPRECATED("Deprecated in 1.0.0. Use Datum-based version") +ARROW_EXPORT +Result<std::shared_ptr<ChunkedArray>> Take( + const Array& values, const ChunkedArray& indices, + const TakeOptions& options = TakeOptions::Defaults(), ExecContext* context = NULLPTR); + +ARROW_DEPRECATED("Deprecated in 1.0.0. Use Datum-based version") +ARROW_EXPORT +Result<std::shared_ptr<RecordBatch>> Take( + const RecordBatch& batch, const Array& indices, + const TakeOptions& options = TakeOptions::Defaults(), ExecContext* context = NULLPTR); + +ARROW_DEPRECATED("Deprecated in 1.0.0. Use Datum-based version") +ARROW_EXPORT +Result<std::shared_ptr<Table>> Take(const Table& table, const Array& indices, + const TakeOptions& options = TakeOptions::Defaults(), + ExecContext* context = NULLPTR); + +ARROW_DEPRECATED("Deprecated in 1.0.0. Use Datum-based version") +ARROW_EXPORT +Result<std::shared_ptr<Table>> Take(const Table& table, const ChunkedArray& indices, + const TakeOptions& options = TakeOptions::Defaults(), + ExecContext* context = NULLPTR); + ARROW_DEPRECATED("Deprecated in 3.0.0. Use SortIndices()") ARROW_EXPORT Result<std::shared_ptr<Array>> SortToIndices(const Array& values, ExecContext* ctx = NULLPTR); -} // namespace compute -} // namespace arrow +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.cc index 4de68ba8d9..d92079cde8 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.cc @@ -1,128 +1,128 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/compute/cast.h" - -#include <mutex> +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/cast.h" + +#include <mutex> #include <sstream> -#include <string> -#include <unordered_map> -#include <unordered_set> -#include <utility> -#include <vector> - -#include "arrow/compute/cast_internal.h" -#include "arrow/compute/exec.h" +#include <string> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "arrow/compute/cast_internal.h" +#include "arrow/compute/exec.h" #include "arrow/compute/function_internal.h" -#include "arrow/compute/kernel.h" -#include "arrow/compute/kernels/codegen_internal.h" -#include "arrow/compute/registry.h" -#include "arrow/util/logging.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/codegen_internal.h" +#include "arrow/compute/registry.h" +#include "arrow/util/logging.h" #include "arrow/util/reflection_internal.h" - -namespace arrow { - -using internal::ToTypeName; - -namespace compute { -namespace internal { - + +namespace arrow { + +using internal::ToTypeName; + +namespace compute { +namespace internal { + // ---------------------------------------------------------------------- // Function options namespace { -std::unordered_map<int, std::shared_ptr<CastFunction>> g_cast_table; +std::unordered_map<int, std::shared_ptr<CastFunction>> g_cast_table; std::once_flag cast_table_initialized; - -void AddCastFunctions(const std::vector<std::shared_ptr<CastFunction>>& funcs) { - for (const auto& func : funcs) { - g_cast_table[static_cast<int>(func->out_type_id())] = func; - } -} - -void InitCastTable() { - AddCastFunctions(GetBooleanCasts()); - AddCastFunctions(GetBinaryLikeCasts()); - AddCastFunctions(GetNestedCasts()); - AddCastFunctions(GetNumericCasts()); - AddCastFunctions(GetTemporalCasts()); + +void AddCastFunctions(const std::vector<std::shared_ptr<CastFunction>>& funcs) { + for (const auto& func : funcs) { + g_cast_table[static_cast<int>(func->out_type_id())] = func; + } +} + +void InitCastTable() { + AddCastFunctions(GetBooleanCasts()); + AddCastFunctions(GetBinaryLikeCasts()); + AddCastFunctions(GetNestedCasts()); + AddCastFunctions(GetNumericCasts()); + AddCastFunctions(GetTemporalCasts()); AddCastFunctions(GetDictionaryCasts()); -} - -void EnsureInitCastTable() { std::call_once(cast_table_initialized, InitCastTable); } - -// Private version of GetCastFunction with better error reporting -// if the input type is known. -Result<std::shared_ptr<CastFunction>> GetCastFunctionInternal( - const std::shared_ptr<DataType>& to_type, const DataType* from_type = nullptr) { - internal::EnsureInitCastTable(); - auto it = internal::g_cast_table.find(static_cast<int>(to_type->id())); - if (it == internal::g_cast_table.end()) { - if (from_type != nullptr) { - return Status::NotImplemented("Unsupported cast from ", *from_type, " to ", - *to_type, - " (no available cast function for target type)"); - } else { - return Status::NotImplemented("Unsupported cast to ", *to_type, - " (no available cast function for target type)"); - } - } - return it->second; -} - +} + +void EnsureInitCastTable() { std::call_once(cast_table_initialized, InitCastTable); } + +// Private version of GetCastFunction with better error reporting +// if the input type is known. +Result<std::shared_ptr<CastFunction>> GetCastFunctionInternal( + const std::shared_ptr<DataType>& to_type, const DataType* from_type = nullptr) { + internal::EnsureInitCastTable(); + auto it = internal::g_cast_table.find(static_cast<int>(to_type->id())); + if (it == internal::g_cast_table.end()) { + if (from_type != nullptr) { + return Status::NotImplemented("Unsupported cast from ", *from_type, " to ", + *to_type, + " (no available cast function for target type)"); + } else { + return Status::NotImplemented("Unsupported cast to ", *to_type, + " (no available cast function for target type)"); + } + } + return it->second; +} + const FunctionDoc cast_doc{"Cast values to another data type", ("Behavior when values wouldn't fit in the target type\n" "can be controlled through CastOptions."), {"input"}, "CastOptions"}; - + // Metafunction for dispatching to appropriate CastFunction. This corresponds -// to the standard SQL CAST(expr AS target_type) -class CastMetaFunction : public MetaFunction { - public: +// to the standard SQL CAST(expr AS target_type) +class CastMetaFunction : public MetaFunction { + public: CastMetaFunction() : MetaFunction("cast", Arity::Unary(), &cast_doc) {} - - Result<const CastOptions*> ValidateOptions(const FunctionOptions* options) const { - auto cast_options = static_cast<const CastOptions*>(options); - - if (cast_options == nullptr || cast_options->to_type == nullptr) { - return Status::Invalid( - "Cast requires that options be passed with " - "the to_type populated"); - } - - return cast_options; - } - - Result<Datum> ExecuteImpl(const std::vector<Datum>& args, - const FunctionOptions* options, - ExecContext* ctx) const override { - ARROW_ASSIGN_OR_RAISE(auto cast_options, ValidateOptions(options)); - if (args[0].type()->Equals(*cast_options->to_type)) { - return args[0]; - } - ARROW_ASSIGN_OR_RAISE( - std::shared_ptr<CastFunction> cast_func, - GetCastFunctionInternal(cast_options->to_type, args[0].type().get())); - return cast_func->Execute(args, options, ctx); - } -}; - + + Result<const CastOptions*> ValidateOptions(const FunctionOptions* options) const { + auto cast_options = static_cast<const CastOptions*>(options); + + if (cast_options == nullptr || cast_options->to_type == nullptr) { + return Status::Invalid( + "Cast requires that options be passed with " + "the to_type populated"); + } + + return cast_options; + } + + Result<Datum> ExecuteImpl(const std::vector<Datum>& args, + const FunctionOptions* options, + ExecContext* ctx) const override { + ARROW_ASSIGN_OR_RAISE(auto cast_options, ValidateOptions(options)); + if (args[0].type()->Equals(*cast_options->to_type)) { + return args[0]; + } + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr<CastFunction> cast_func, + GetCastFunctionInternal(cast_options->to_type, args[0].type().get())); + return cast_func->Execute(args, options, ctx); + } +}; + static auto kCastOptionsType = GetFunctionOptionsType<CastOptions>( arrow::internal::DataMember("to_type", &CastOptions::to_type), arrow::internal::DataMember("allow_int_overflow", &CastOptions::allow_int_overflow), @@ -135,12 +135,12 @@ static auto kCastOptionsType = GetFunctionOptionsType<CastOptions>( arrow::internal::DataMember("allow_invalid_utf8", &CastOptions::allow_invalid_utf8)); } // namespace -void RegisterScalarCast(FunctionRegistry* registry) { - DCHECK_OK(registry->AddFunction(std::make_shared<CastMetaFunction>())); +void RegisterScalarCast(FunctionRegistry* registry) { + DCHECK_OK(registry->AddFunction(std::make_shared<CastMetaFunction>())); DCHECK_OK(registry->AddFunctionOptionsType(kCastOptionsType)); -} -} // namespace internal - +} +} // namespace internal + CastOptions::CastOptions(bool safe) : FunctionOptions(internal::kCastOptionsType), allow_int_overflow(!safe), @@ -149,53 +149,53 @@ CastOptions::CastOptions(bool safe) allow_decimal_truncate(!safe), allow_float_truncate(!safe), allow_invalid_utf8(!safe) {} - + constexpr char CastOptions::kTypeName[]; - + CastFunction::CastFunction(std::string name, Type::type out_type_id) : ScalarFunction(std::move(name), Arity::Unary(), /*doc=*/nullptr), out_type_id_(out_type_id) {} - -Status CastFunction::AddKernel(Type::type in_type_id, ScalarKernel kernel) { - // We use the same KernelInit for every cast - kernel.init = internal::CastState::Init; - RETURN_NOT_OK(ScalarFunction::AddKernel(kernel)); + +Status CastFunction::AddKernel(Type::type in_type_id, ScalarKernel kernel) { + // We use the same KernelInit for every cast + kernel.init = internal::CastState::Init; + RETURN_NOT_OK(ScalarFunction::AddKernel(kernel)); in_type_ids_.push_back(in_type_id); - return Status::OK(); -} - -Status CastFunction::AddKernel(Type::type in_type_id, std::vector<InputType> in_types, - OutputType out_type, ArrayKernelExec exec, - NullHandling::type null_handling, - MemAllocation::type mem_allocation) { - ScalarKernel kernel; - kernel.signature = KernelSignature::Make(std::move(in_types), std::move(out_type)); - kernel.exec = exec; - kernel.null_handling = null_handling; - kernel.mem_allocation = mem_allocation; - return AddKernel(in_type_id, std::move(kernel)); -} - + return Status::OK(); +} + +Status CastFunction::AddKernel(Type::type in_type_id, std::vector<InputType> in_types, + OutputType out_type, ArrayKernelExec exec, + NullHandling::type null_handling, + MemAllocation::type mem_allocation) { + ScalarKernel kernel; + kernel.signature = KernelSignature::Make(std::move(in_types), std::move(out_type)); + kernel.exec = exec; + kernel.null_handling = null_handling; + kernel.mem_allocation = mem_allocation; + return AddKernel(in_type_id, std::move(kernel)); +} + Result<const Kernel*> CastFunction::DispatchExact( - const std::vector<ValueDescr>& values) const { + const std::vector<ValueDescr>& values) const { RETURN_NOT_OK(CheckArity(values)); - - std::vector<const ScalarKernel*> candidate_kernels; - for (const auto& kernel : kernels_) { - if (kernel.signature->MatchesInputs(values)) { - candidate_kernels.push_back(&kernel); - } - } - - if (candidate_kernels.size() == 0) { - return Status::NotImplemented("Unsupported cast from ", values[0].type->ToString(), + + std::vector<const ScalarKernel*> candidate_kernels; + for (const auto& kernel : kernels_) { + if (kernel.signature->MatchesInputs(values)) { + candidate_kernels.push_back(&kernel); + } + } + + if (candidate_kernels.size() == 0) { + return Status::NotImplemented("Unsupported cast from ", values[0].type->ToString(), " to ", ToTypeName(out_type_id_), " using function ", - this->name()); + this->name()); } if (candidate_kernels.size() == 1) { - // One match, return it - return candidate_kernels[0]; + // One match, return it + return candidate_kernels[0]; } // Now we are in a casting scenario where we may have both a EXACT_TYPE and @@ -206,41 +206,41 @@ Result<const Kernel*> CastFunction::DispatchExact( if (arg0.kind() == InputType::EXACT_TYPE) { // Bingo. Return it return kernel; - } - } + } + } // We didn't find an exact match. So just return some kernel that matches return candidate_kernels[0]; -} - -Result<Datum> Cast(const Datum& value, const CastOptions& options, ExecContext* ctx) { - return CallFunction("cast", {value}, &options, ctx); -} - -Result<Datum> Cast(const Datum& value, std::shared_ptr<DataType> to_type, - const CastOptions& options, ExecContext* ctx) { - CastOptions options_with_to_type = options; - options_with_to_type.to_type = to_type; - return Cast(value, options_with_to_type, ctx); -} - -Result<std::shared_ptr<Array>> Cast(const Array& value, std::shared_ptr<DataType> to_type, - const CastOptions& options, ExecContext* ctx) { - ARROW_ASSIGN_OR_RAISE(Datum result, Cast(Datum(value), to_type, options, ctx)); - return result.make_array(); -} - -Result<std::shared_ptr<CastFunction>> GetCastFunction( - const std::shared_ptr<DataType>& to_type) { - return internal::GetCastFunctionInternal(to_type); -} - -bool CanCast(const DataType& from_type, const DataType& to_type) { - internal::EnsureInitCastTable(); +} + +Result<Datum> Cast(const Datum& value, const CastOptions& options, ExecContext* ctx) { + return CallFunction("cast", {value}, &options, ctx); +} + +Result<Datum> Cast(const Datum& value, std::shared_ptr<DataType> to_type, + const CastOptions& options, ExecContext* ctx) { + CastOptions options_with_to_type = options; + options_with_to_type.to_type = to_type; + return Cast(value, options_with_to_type, ctx); +} + +Result<std::shared_ptr<Array>> Cast(const Array& value, std::shared_ptr<DataType> to_type, + const CastOptions& options, ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, Cast(Datum(value), to_type, options, ctx)); + return result.make_array(); +} + +Result<std::shared_ptr<CastFunction>> GetCastFunction( + const std::shared_ptr<DataType>& to_type) { + return internal::GetCastFunctionInternal(to_type); +} + +bool CanCast(const DataType& from_type, const DataType& to_type) { + internal::EnsureInitCastTable(); auto it = internal::g_cast_table.find(static_cast<int>(to_type.id())); - if (it == internal::g_cast_table.end()) { - return false; - } + if (it == internal::g_cast_table.end()) { + return false; + } const CastFunction* function = it->second.get(); DCHECK_EQ(function->out_type_id(), to_type.id()); @@ -251,8 +251,8 @@ bool CanCast(const DataType& from_type, const DataType& to_type) { } return false; -} - +} + Result<std::vector<Datum>> Cast(std::vector<Datum> datums, std::vector<ValueDescr> descrs, ExecContext* ctx) { for (size_t i = 0; i != datums.size(); ++i) { @@ -269,5 +269,5 @@ Result<std::vector<Datum>> Cast(std::vector<Datum> datums, std::vector<ValueDesc return datums; } -} // namespace compute -} // namespace arrow +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.h index 131f57f892..a0944ac721 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.h @@ -1,156 +1,156 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <memory> -#include <string> -#include <vector> - -#include "arrow/compute/function.h" -#include "arrow/compute/kernel.h" -#include "arrow/datum.h" -#include "arrow/result.h" -#include "arrow/status.h" -#include "arrow/type.h" -#include "arrow/util/macros.h" -#include "arrow/util/visibility.h" - -namespace arrow { - -class Array; - -namespace compute { - -class ExecContext; - -/// \addtogroup compute-concrete-options -/// @{ - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> +#include <string> +#include <vector> + +#include "arrow/compute/function.h" +#include "arrow/compute/kernel.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class Array; + +namespace compute { + +class ExecContext; + +/// \addtogroup compute-concrete-options +/// @{ + class ARROW_EXPORT CastOptions : public FunctionOptions { public: explicit CastOptions(bool safe = true); - + constexpr static char const kTypeName[] = "CastOptions"; static CastOptions Safe(std::shared_ptr<DataType> to_type = NULLPTR) { CastOptions safe(true); safe.to_type = std::move(to_type); return safe; } - + static CastOptions Unsafe(std::shared_ptr<DataType> to_type = NULLPTR) { CastOptions unsafe(false); unsafe.to_type = std::move(to_type); return unsafe; } - - // Type being casted to. May be passed separate to eager function - // compute::Cast - std::shared_ptr<DataType> to_type; - - bool allow_int_overflow; - bool allow_time_truncate; - bool allow_time_overflow; - bool allow_decimal_truncate; - bool allow_float_truncate; - // Indicate if conversions from Binary/FixedSizeBinary to string must - // validate the utf8 payload. - bool allow_invalid_utf8; -}; - -/// @} - -// Cast functions are _not_ registered in the FunctionRegistry, though they use -// the same execution machinery -class CastFunction : public ScalarFunction { - public: + + // Type being casted to. May be passed separate to eager function + // compute::Cast + std::shared_ptr<DataType> to_type; + + bool allow_int_overflow; + bool allow_time_truncate; + bool allow_time_overflow; + bool allow_decimal_truncate; + bool allow_float_truncate; + // Indicate if conversions from Binary/FixedSizeBinary to string must + // validate the utf8 payload. + bool allow_invalid_utf8; +}; + +/// @} + +// Cast functions are _not_ registered in the FunctionRegistry, though they use +// the same execution machinery +class CastFunction : public ScalarFunction { + public: CastFunction(std::string name, Type::type out_type_id); - + Type::type out_type_id() const { return out_type_id_; } const std::vector<Type::type>& in_type_ids() const { return in_type_ids_; } - - Status AddKernel(Type::type in_type_id, std::vector<InputType> in_types, - OutputType out_type, ArrayKernelExec exec, - NullHandling::type = NullHandling::INTERSECTION, - MemAllocation::type = MemAllocation::PREALLOCATE); - - // Note, this function toggles off memory allocation and sets the init - // function to CastInit - Status AddKernel(Type::type in_type_id, ScalarKernel kernel); - + + Status AddKernel(Type::type in_type_id, std::vector<InputType> in_types, + OutputType out_type, ArrayKernelExec exec, + NullHandling::type = NullHandling::INTERSECTION, + MemAllocation::type = MemAllocation::PREALLOCATE); + + // Note, this function toggles off memory allocation and sets the init + // function to CastInit + Status AddKernel(Type::type in_type_id, ScalarKernel kernel); + Result<const Kernel*> DispatchExact( - const std::vector<ValueDescr>& values) const override; - - private: + const std::vector<ValueDescr>& values) const override; + + private: std::vector<Type::type> in_type_ids_; const Type::type out_type_id_; -}; - -ARROW_EXPORT -Result<std::shared_ptr<CastFunction>> GetCastFunction( - const std::shared_ptr<DataType>& to_type); - -/// \brief Return true if a cast function is defined -ARROW_EXPORT -bool CanCast(const DataType& from_type, const DataType& to_type); - -// ---------------------------------------------------------------------- -// Convenience invocation APIs for a number of kernels - -/// \brief Cast from one array type to another -/// \param[in] value array to cast -/// \param[in] to_type type to cast to -/// \param[in] options casting options -/// \param[in] ctx the function execution context, optional -/// \return the resulting array -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result<std::shared_ptr<Array>> Cast(const Array& value, std::shared_ptr<DataType> to_type, - const CastOptions& options = CastOptions::Safe(), - ExecContext* ctx = NULLPTR); - -/// \brief Cast from one array type to another -/// \param[in] value array to cast -/// \param[in] options casting options. The "to_type" field must be populated -/// \param[in] ctx the function execution context, optional -/// \return the resulting array -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result<Datum> Cast(const Datum& value, const CastOptions& options, - ExecContext* ctx = NULLPTR); - -/// \brief Cast from one value to another -/// \param[in] value datum to cast -/// \param[in] to_type type to cast to -/// \param[in] options casting options -/// \param[in] ctx the function execution context, optional -/// \return the resulting datum -/// -/// \since 1.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result<Datum> Cast(const Datum& value, std::shared_ptr<DataType> to_type, - const CastOptions& options = CastOptions::Safe(), - ExecContext* ctx = NULLPTR); - +}; + +ARROW_EXPORT +Result<std::shared_ptr<CastFunction>> GetCastFunction( + const std::shared_ptr<DataType>& to_type); + +/// \brief Return true if a cast function is defined +ARROW_EXPORT +bool CanCast(const DataType& from_type, const DataType& to_type); + +// ---------------------------------------------------------------------- +// Convenience invocation APIs for a number of kernels + +/// \brief Cast from one array type to another +/// \param[in] value array to cast +/// \param[in] to_type type to cast to +/// \param[in] options casting options +/// \param[in] ctx the function execution context, optional +/// \return the resulting array +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<std::shared_ptr<Array>> Cast(const Array& value, std::shared_ptr<DataType> to_type, + const CastOptions& options = CastOptions::Safe(), + ExecContext* ctx = NULLPTR); + +/// \brief Cast from one array type to another +/// \param[in] value array to cast +/// \param[in] options casting options. The "to_type" field must be populated +/// \param[in] ctx the function execution context, optional +/// \return the resulting array +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> Cast(const Datum& value, const CastOptions& options, + ExecContext* ctx = NULLPTR); + +/// \brief Cast from one value to another +/// \param[in] value datum to cast +/// \param[in] to_type type to cast to +/// \param[in] options casting options +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> Cast(const Datum& value, std::shared_ptr<DataType> to_type, + const CastOptions& options = CastOptions::Safe(), + ExecContext* ctx = NULLPTR); + /// \brief Cast several values simultaneously. Safe cast options are used. /// \param[in] values datums to cast /// \param[in] descrs ValueDescrs to cast to @@ -163,5 +163,5 @@ ARROW_EXPORT Result<std::vector<Datum>> Cast(std::vector<Datum> values, std::vector<ValueDescr> descrs, ExecContext* ctx = NULLPTR); -} // namespace compute -} // namespace arrow +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/cast_internal.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/cast_internal.h index 0105d08a57..97975b8006 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/cast_internal.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/cast_internal.h @@ -1,43 +1,43 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <memory> -#include <vector> - -#include "arrow/compute/cast.h" // IWYU pragma: keep -#include "arrow/compute/kernel.h" // IWYU pragma: keep -#include "arrow/compute/kernels/codegen_internal.h" // IWYU pragma: keep - -namespace arrow { -namespace compute { -namespace internal { - -using CastState = OptionsWrapper<CastOptions>; - -// See kernels/scalar_cast_*.cc for these -std::vector<std::shared_ptr<CastFunction>> GetBooleanCasts(); -std::vector<std::shared_ptr<CastFunction>> GetNumericCasts(); -std::vector<std::shared_ptr<CastFunction>> GetTemporalCasts(); -std::vector<std::shared_ptr<CastFunction>> GetBinaryLikeCasts(); -std::vector<std::shared_ptr<CastFunction>> GetNestedCasts(); +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> +#include <vector> + +#include "arrow/compute/cast.h" // IWYU pragma: keep +#include "arrow/compute/kernel.h" // IWYU pragma: keep +#include "arrow/compute/kernels/codegen_internal.h" // IWYU pragma: keep + +namespace arrow { +namespace compute { +namespace internal { + +using CastState = OptionsWrapper<CastOptions>; + +// See kernels/scalar_cast_*.cc for these +std::vector<std::shared_ptr<CastFunction>> GetBooleanCasts(); +std::vector<std::shared_ptr<CastFunction>> GetNumericCasts(); +std::vector<std::shared_ptr<CastFunction>> GetTemporalCasts(); +std::vector<std::shared_ptr<CastFunction>> GetBinaryLikeCasts(); +std::vector<std::shared_ptr<CastFunction>> GetNestedCasts(); std::vector<std::shared_ptr<CastFunction>> GetDictionaryCasts(); - -} // namespace internal -} // namespace compute -} // namespace arrow + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/exec.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/exec.cc index 63f8d39f55..55fb256a36 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/exec.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/exec.cc @@ -1,64 +1,64 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/compute/exec.h" - -#include <algorithm> -#include <cstddef> -#include <cstdint> -#include <memory> -#include <utility> -#include <vector> - -#include "arrow/array/array_base.h" -#include "arrow/array/array_primitive.h" -#include "arrow/array/data.h" -#include "arrow/array/util.h" -#include "arrow/buffer.h" -#include "arrow/chunked_array.h" -#include "arrow/compute/exec_internal.h" -#include "arrow/compute/function.h" -#include "arrow/compute/kernel.h" -#include "arrow/compute/registry.h" -#include "arrow/compute/util_internal.h" -#include "arrow/datum.h" +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/exec.h" + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <memory> +#include <utility> +#include <vector> + +#include "arrow/array/array_base.h" +#include "arrow/array/array_primitive.h" +#include "arrow/array/data.h" +#include "arrow/array/util.h" +#include "arrow/buffer.h" +#include "arrow/chunked_array.h" +#include "arrow/compute/exec_internal.h" +#include "arrow/compute/function.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/registry.h" +#include "arrow/compute/util_internal.h" +#include "arrow/datum.h" #include "arrow/pretty_print.h" #include "arrow/record_batch.h" -#include "arrow/scalar.h" -#include "arrow/status.h" -#include "arrow/type.h" -#include "arrow/type_traits.h" -#include "arrow/util/bit_util.h" -#include "arrow/util/bitmap_ops.h" -#include "arrow/util/checked_cast.h" -#include "arrow/util/cpu_info.h" -#include "arrow/util/logging.h" +#include "arrow/scalar.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/cpu_info.h" +#include "arrow/util/logging.h" #include "arrow/util/make_unique.h" #include "arrow/util/vector.h" - -namespace arrow { - -using internal::BitmapAnd; -using internal::checked_cast; -using internal::CopyBitmap; -using internal::CpuInfo; - -namespace compute { - + +namespace arrow { + +using internal::BitmapAnd; +using internal::checked_cast; +using internal::CopyBitmap; +using internal::CpuInfo; + +namespace compute { + ExecContext* default_exec_context() { static ExecContext default_ctx; return &default_ctx; @@ -157,22 +157,22 @@ Result<std::shared_ptr<RecordBatch>> ExecBatch::ToRecordBatch( return RecordBatch::Make(std::move(schema), length, std::move(columns)); } -namespace { - -Result<std::shared_ptr<Buffer>> AllocateDataBuffer(KernelContext* ctx, int64_t length, - int bit_width) { - if (bit_width == 1) { - return ctx->AllocateBitmap(length); - } else { +namespace { + +Result<std::shared_ptr<Buffer>> AllocateDataBuffer(KernelContext* ctx, int64_t length, + int bit_width) { + if (bit_width == 1) { + return ctx->AllocateBitmap(length); + } else { int64_t buffer_size = BitUtil::BytesForBits(length * bit_width); - return ctx->Allocate(buffer_size); - } -} - + return ctx->Allocate(buffer_size); + } +} + struct BufferPreallocation { explicit BufferPreallocation(int bit_width = -1, int added_length = 0) : bit_width(bit_width), added_length(added_length) {} - + int bit_width; int added_length; }; @@ -182,7 +182,7 @@ void ComputeDataPreallocate(const DataType& type, if (is_fixed_width(type.id()) && type.id() != Type::NA) { widths->emplace_back(checked_cast<const FixedWidthType&>(type).bit_width()); return; - } + } // Preallocate binary and list offsets switch (type.id()) { case Type::BINARY: @@ -199,12 +199,12 @@ void ComputeDataPreallocate(const DataType& type, default: break; } -} - -} // namespace - -namespace detail { - +} + +} // namespace + +namespace detail { + Status CheckAllValues(const std::vector<Datum>& values) { for (const auto& value : values) { if (!value.is_value()) { @@ -215,102 +215,102 @@ Status CheckAllValues(const std::vector<Datum>& values) { return Status::OK(); } -ExecBatchIterator::ExecBatchIterator(std::vector<Datum> args, int64_t length, - int64_t max_chunksize) - : args_(std::move(args)), - position_(0), - length_(length), - max_chunksize_(max_chunksize) { - chunk_indexes_.resize(args_.size(), 0); - chunk_positions_.resize(args_.size(), 0); -} - -Result<std::unique_ptr<ExecBatchIterator>> ExecBatchIterator::Make( - std::vector<Datum> args, int64_t max_chunksize) { - for (const auto& arg : args) { - if (!(arg.is_arraylike() || arg.is_scalar())) { - return Status::Invalid( - "ExecBatchIterator only works with Scalar, Array, and " - "ChunkedArray arguments"); - } - } - - // If the arguments are all scalars, then the length is 1 - int64_t length = 1; - - bool length_set = false; - for (auto& arg : args) { - if (arg.is_scalar()) { - continue; - } - if (!length_set) { - length = arg.length(); - length_set = true; - } else { - if (arg.length() != length) { - return Status::Invalid("Array arguments must all be the same length"); - } - } - } - - max_chunksize = std::min(length, max_chunksize); - - return std::unique_ptr<ExecBatchIterator>( - new ExecBatchIterator(std::move(args), length, max_chunksize)); -} - -bool ExecBatchIterator::Next(ExecBatch* batch) { - if (position_ == length_) { - return false; - } - - // Determine how large the common contiguous "slice" of all the arguments is - int64_t iteration_size = std::min(length_ - position_, max_chunksize_); - - // If length_ is 0, then this loop will never execute - for (size_t i = 0; i < args_.size() && iteration_size > 0; ++i) { - // If the argument is not a chunked array, it's either a Scalar or Array, - // in which case it doesn't influence the size of this batch. Note that if - // the args are all scalars the batch length is 1 - if (args_[i].kind() != Datum::CHUNKED_ARRAY) { - continue; - } - const ChunkedArray& arg = *args_[i].chunked_array(); - std::shared_ptr<Array> current_chunk; - while (true) { - current_chunk = arg.chunk(chunk_indexes_[i]); - if (chunk_positions_[i] == current_chunk->length()) { - // Chunk is zero-length, or was exhausted in the previous iteration - chunk_positions_[i] = 0; - ++chunk_indexes_[i]; - continue; - } - break; - } - iteration_size = - std::min(current_chunk->length() - chunk_positions_[i], iteration_size); - } - - // Now, fill the batch - batch->values.resize(args_.size()); - batch->length = iteration_size; - for (size_t i = 0; i < args_.size(); ++i) { - if (args_[i].is_scalar()) { - batch->values[i] = args_[i].scalar(); - } else if (args_[i].is_array()) { - batch->values[i] = args_[i].array()->Slice(position_, iteration_size); - } else { - const ChunkedArray& carr = *args_[i].chunked_array(); - const auto& chunk = carr.chunk(chunk_indexes_[i]); - batch->values[i] = chunk->data()->Slice(chunk_positions_[i], iteration_size); - chunk_positions_[i] += iteration_size; - } - } - position_ += iteration_size; - DCHECK_LE(position_, length_); - return true; -} - +ExecBatchIterator::ExecBatchIterator(std::vector<Datum> args, int64_t length, + int64_t max_chunksize) + : args_(std::move(args)), + position_(0), + length_(length), + max_chunksize_(max_chunksize) { + chunk_indexes_.resize(args_.size(), 0); + chunk_positions_.resize(args_.size(), 0); +} + +Result<std::unique_ptr<ExecBatchIterator>> ExecBatchIterator::Make( + std::vector<Datum> args, int64_t max_chunksize) { + for (const auto& arg : args) { + if (!(arg.is_arraylike() || arg.is_scalar())) { + return Status::Invalid( + "ExecBatchIterator only works with Scalar, Array, and " + "ChunkedArray arguments"); + } + } + + // If the arguments are all scalars, then the length is 1 + int64_t length = 1; + + bool length_set = false; + for (auto& arg : args) { + if (arg.is_scalar()) { + continue; + } + if (!length_set) { + length = arg.length(); + length_set = true; + } else { + if (arg.length() != length) { + return Status::Invalid("Array arguments must all be the same length"); + } + } + } + + max_chunksize = std::min(length, max_chunksize); + + return std::unique_ptr<ExecBatchIterator>( + new ExecBatchIterator(std::move(args), length, max_chunksize)); +} + +bool ExecBatchIterator::Next(ExecBatch* batch) { + if (position_ == length_) { + return false; + } + + // Determine how large the common contiguous "slice" of all the arguments is + int64_t iteration_size = std::min(length_ - position_, max_chunksize_); + + // If length_ is 0, then this loop will never execute + for (size_t i = 0; i < args_.size() && iteration_size > 0; ++i) { + // If the argument is not a chunked array, it's either a Scalar or Array, + // in which case it doesn't influence the size of this batch. Note that if + // the args are all scalars the batch length is 1 + if (args_[i].kind() != Datum::CHUNKED_ARRAY) { + continue; + } + const ChunkedArray& arg = *args_[i].chunked_array(); + std::shared_ptr<Array> current_chunk; + while (true) { + current_chunk = arg.chunk(chunk_indexes_[i]); + if (chunk_positions_[i] == current_chunk->length()) { + // Chunk is zero-length, or was exhausted in the previous iteration + chunk_positions_[i] = 0; + ++chunk_indexes_[i]; + continue; + } + break; + } + iteration_size = + std::min(current_chunk->length() - chunk_positions_[i], iteration_size); + } + + // Now, fill the batch + batch->values.resize(args_.size()); + batch->length = iteration_size; + for (size_t i = 0; i < args_.size(); ++i) { + if (args_[i].is_scalar()) { + batch->values[i] = args_[i].scalar(); + } else if (args_[i].is_array()) { + batch->values[i] = args_[i].array()->Slice(position_, iteration_size); + } else { + const ChunkedArray& carr = *args_[i].chunked_array(); + const auto& chunk = carr.chunk(chunk_indexes_[i]); + batch->values[i] = chunk->data()->Slice(chunk_positions_[i], iteration_size); + chunk_positions_[i] += iteration_size; + } + } + position_ += iteration_size; + DCHECK_LE(position_, length_); + return true; +} + namespace { struct NullGeneralization { @@ -327,7 +327,7 @@ struct NullGeneralization { const auto& arr = *datum.array(); - // Do not count the bits if they haven't been counted already + // Do not count the bits if they haven't been counted already const int64_t known_null_count = arr.null_count.load(); if ((known_null_count == 0) || (arr.buffers[0] == NULLPTR)) { return ALL_VALID; @@ -338,88 +338,88 @@ struct NullGeneralization { } return PERHAPS_NULL; - } + } }; - -// Null propagation implementation that deals both with preallocated bitmaps -// and maybe-to-be allocated bitmaps -// -// If the bitmap is preallocated, it MUST be populated (since it might be a -// view of a much larger bitmap). If it isn't preallocated, then we have -// more flexibility. -// -// * If the batch has no nulls, then we do nothing -// * If only a single array has nulls, and its offset is a multiple of 8, -// then we can zero-copy the bitmap into the output -// * Otherwise, we allocate the bitmap and populate it -class NullPropagator { - public: - NullPropagator(KernelContext* ctx, const ExecBatch& batch, ArrayData* output) - : ctx_(ctx), batch_(batch), output_(output) { + +// Null propagation implementation that deals both with preallocated bitmaps +// and maybe-to-be allocated bitmaps +// +// If the bitmap is preallocated, it MUST be populated (since it might be a +// view of a much larger bitmap). If it isn't preallocated, then we have +// more flexibility. +// +// * If the batch has no nulls, then we do nothing +// * If only a single array has nulls, and its offset is a multiple of 8, +// then we can zero-copy the bitmap into the output +// * Otherwise, we allocate the bitmap and populate it +class NullPropagator { + public: + NullPropagator(KernelContext* ctx, const ExecBatch& batch, ArrayData* output) + : ctx_(ctx), batch_(batch), output_(output) { for (const Datum& datum : batch_.values) { auto null_generalization = NullGeneralization::Get(datum); if (null_generalization == NullGeneralization::ALL_NULL) { is_all_null_ = true; - } + } if (null_generalization != NullGeneralization::ALL_VALID && datum.kind() == Datum::ARRAY) { arrays_with_nulls_.push_back(datum.array().get()); } - } - - if (output->buffers[0] != nullptr) { - bitmap_preallocated_ = true; - SetBitmap(output_->buffers[0].get()); - } - } - - void SetBitmap(Buffer* bitmap) { bitmap_ = bitmap->mutable_data(); } - - Status EnsureAllocated() { - if (bitmap_preallocated_) { - return Status::OK(); - } - ARROW_ASSIGN_OR_RAISE(output_->buffers[0], ctx_->AllocateBitmap(output_->length)); - SetBitmap(output_->buffers[0].get()); - return Status::OK(); - } - + } + + if (output->buffers[0] != nullptr) { + bitmap_preallocated_ = true; + SetBitmap(output_->buffers[0].get()); + } + } + + void SetBitmap(Buffer* bitmap) { bitmap_ = bitmap->mutable_data(); } + + Status EnsureAllocated() { + if (bitmap_preallocated_) { + return Status::OK(); + } + ARROW_ASSIGN_OR_RAISE(output_->buffers[0], ctx_->AllocateBitmap(output_->length)); + SetBitmap(output_->buffers[0].get()); + return Status::OK(); + } + Status AllNullShortCircuit() { // OK, the output should be all null output_->null_count = output_->length; - + if (bitmap_preallocated_) { BitUtil::SetBitsTo(bitmap_, output_->offset, output_->length, false); return Status::OK(); } - // Walk all the values with nulls instead of breaking on the first in case - // we find a bitmap that can be reused in the non-preallocated case + // Walk all the values with nulls instead of breaking on the first in case + // we find a bitmap that can be reused in the non-preallocated case for (const ArrayData* arr : arrays_with_nulls_) { if (arr->null_count.load() == arr->length && arr->buffers[0] != nullptr) { // Reuse this all null bitmap output_->buffers[0] = arr->buffers[0]; return Status::OK(); - } - } - + } + } + RETURN_NOT_OK(EnsureAllocated()); BitUtil::SetBitsTo(bitmap_, output_->offset, output_->length, false); return Status::OK(); - } - - Status PropagateSingle() { - // One array + } + + Status PropagateSingle() { + // One array const ArrayData& arr = *arrays_with_nulls_[0]; - const std::shared_ptr<Buffer>& arr_bitmap = arr.buffers[0]; - - // Reuse the null count if it's known - output_->null_count = arr.null_count.load(); - - if (bitmap_preallocated_) { - CopyBitmap(arr_bitmap->data(), arr.offset, arr.length, bitmap_, output_->offset); + const std::shared_ptr<Buffer>& arr_bitmap = arr.buffers[0]; + + // Reuse the null count if it's known + output_->null_count = arr.null_count.load(); + + if (bitmap_preallocated_) { + CopyBitmap(arr_bitmap->data(), arr.offset, arr.length, bitmap_, output_->offset); return Status::OK(); } @@ -437,144 +437,144 @@ class NullPropagator { } else if (arr.offset % 8 == 0) { output_->buffers[0] = SliceBuffer(arr_bitmap, arr.offset / 8, BitUtil::BytesForBits(arr.length)); - } else { + } else { RETURN_NOT_OK(EnsureAllocated()); CopyBitmap(arr_bitmap->data(), arr.offset, arr.length, bitmap_, /*dst_offset=*/0); - } - return Status::OK(); - } - - Status PropagateMultiple() { - // More than one array. We use BitmapAnd to intersect their bitmaps - - // Do not compute the intersection null count until it's needed - RETURN_NOT_OK(EnsureAllocated()); - - auto Accumulate = [&](const ArrayData& left, const ArrayData& right) { - DCHECK(left.buffers[0]); - DCHECK(right.buffers[0]); - BitmapAnd(left.buffers[0]->data(), left.offset, right.buffers[0]->data(), - right.offset, output_->length, output_->offset, - output_->buffers[0]->mutable_data()); - }; - + } + return Status::OK(); + } + + Status PropagateMultiple() { + // More than one array. We use BitmapAnd to intersect their bitmaps + + // Do not compute the intersection null count until it's needed + RETURN_NOT_OK(EnsureAllocated()); + + auto Accumulate = [&](const ArrayData& left, const ArrayData& right) { + DCHECK(left.buffers[0]); + DCHECK(right.buffers[0]); + BitmapAnd(left.buffers[0]->data(), left.offset, right.buffers[0]->data(), + right.offset, output_->length, output_->offset, + output_->buffers[0]->mutable_data()); + }; + DCHECK_GT(arrays_with_nulls_.size(), 1); - - // Seed the output bitmap with the & of the first two bitmaps + + // Seed the output bitmap with the & of the first two bitmaps Accumulate(*arrays_with_nulls_[0], *arrays_with_nulls_[1]); - - // Accumulate the rest + + // Accumulate the rest for (size_t i = 2; i < arrays_with_nulls_.size(); ++i) { Accumulate(*output_, *arrays_with_nulls_[i]); - } - return Status::OK(); - } - - Status Execute() { + } + return Status::OK(); + } + + Status Execute() { if (is_all_null_) { // An all-null value (scalar null or all-null array) gives us a short // circuit opportunity return AllNullShortCircuit(); - } - - // At this point, by construction we know that all of the values in + } + + // At this point, by construction we know that all of the values in // arrays_with_nulls_ are arrays that are not all null. So there are a - // few cases: - // - // * No arrays. This is a no-op w/o preallocation but when the bitmap is - // pre-allocated we have to fill it with 1's - // * One array, whose bitmap can be zero-copied (w/o preallocation, and - // when no byte is split) or copied (split byte or w/ preallocation) - // * More than one array, we must compute the intersection of all the - // bitmaps - // - // BUT, if the output offset is nonzero for some reason, we copy into the - // output unconditionally - - output_->null_count = kUnknownNullCount; - + // few cases: + // + // * No arrays. This is a no-op w/o preallocation but when the bitmap is + // pre-allocated we have to fill it with 1's + // * One array, whose bitmap can be zero-copied (w/o preallocation, and + // when no byte is split) or copied (split byte or w/ preallocation) + // * More than one array, we must compute the intersection of all the + // bitmaps + // + // BUT, if the output offset is nonzero for some reason, we copy into the + // output unconditionally + + output_->null_count = kUnknownNullCount; + if (arrays_with_nulls_.empty()) { - // No arrays with nulls case - output_->null_count = 0; - if (bitmap_preallocated_) { - BitUtil::SetBitsTo(bitmap_, output_->offset, output_->length, true); - } - return Status::OK(); + // No arrays with nulls case + output_->null_count = 0; + if (bitmap_preallocated_) { + BitUtil::SetBitsTo(bitmap_, output_->offset, output_->length, true); + } + return Status::OK(); } if (arrays_with_nulls_.size() == 1) { - return PropagateSingle(); - } + return PropagateSingle(); + } return PropagateMultiple(); - } - - private: - KernelContext* ctx_; - const ExecBatch& batch_; + } + + private: + KernelContext* ctx_; + const ExecBatch& batch_; std::vector<const ArrayData*> arrays_with_nulls_; bool is_all_null_ = false; - ArrayData* output_; - uint8_t* bitmap_; - bool bitmap_preallocated_ = false; -}; - -std::shared_ptr<ChunkedArray> ToChunkedArray(const std::vector<Datum>& values, - const std::shared_ptr<DataType>& type) { - std::vector<std::shared_ptr<Array>> arrays; + ArrayData* output_; + uint8_t* bitmap_; + bool bitmap_preallocated_ = false; +}; + +std::shared_ptr<ChunkedArray> ToChunkedArray(const std::vector<Datum>& values, + const std::shared_ptr<DataType>& type) { + std::vector<std::shared_ptr<Array>> arrays; arrays.reserve(values.size()); for (const Datum& val : values) { if (val.length() == 0) { - // Skip empty chunks - continue; - } + // Skip empty chunks + continue; + } arrays.emplace_back(val.make_array()); - } + } return std::make_shared<ChunkedArray>(std::move(arrays), type); -} - -bool HaveChunkedArray(const std::vector<Datum>& values) { - for (const auto& value : values) { - if (value.kind() == Datum::CHUNKED_ARRAY) { - return true; - } - } - return false; -} - +} + +bool HaveChunkedArray(const std::vector<Datum>& values) { + for (const auto& value : values) { + if (value.kind() == Datum::CHUNKED_ARRAY) { + return true; + } + } + return false; +} + template <typename KernelType> class KernelExecutorImpl : public KernelExecutor { - public: + public: Status Init(KernelContext* kernel_ctx, KernelInitArgs args) override { kernel_ctx_ = kernel_ctx; kernel_ = static_cast<const KernelType*>(args.kernel); - + // Resolve the output descriptor for this kernel ARROW_ASSIGN_OR_RAISE( output_descr_, kernel_->signature->out_type().Resolve(kernel_ctx_, args.inputs)); - - return Status::OK(); - } - + + return Status::OK(); + } + protected: - // This is overridden by the VectorExecutor - virtual Status SetupArgIteration(const std::vector<Datum>& args) { + // This is overridden by the VectorExecutor + virtual Status SetupArgIteration(const std::vector<Datum>& args) { ARROW_ASSIGN_OR_RAISE( batch_iterator_, ExecBatchIterator::Make(args, exec_context()->exec_chunksize())); - return Status::OK(); - } - - Result<std::shared_ptr<ArrayData>> PrepareOutput(int64_t length) { - auto out = std::make_shared<ArrayData>(output_descr_.type, length); - out->buffers.resize(output_num_buffers_); - - if (validity_preallocated_) { + return Status::OK(); + } + + Result<std::shared_ptr<ArrayData>> PrepareOutput(int64_t length) { + auto out = std::make_shared<ArrayData>(output_descr_.type, length); + out->buffers.resize(output_num_buffers_); + + if (validity_preallocated_) { ARROW_ASSIGN_OR_RAISE(out->buffers[0], kernel_ctx_->AllocateBitmap(length)); - } + } if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) { out->null_count = 0; - } + } for (size_t i = 0; i < data_preallocated_.size(); ++i) { const auto& prealloc = data_preallocated_[i]; if (prealloc.bit_width >= 0) { @@ -584,183 +584,183 @@ class KernelExecutorImpl : public KernelExecutor { prealloc.bit_width)); } } - return out; - } - + return out; + } + ExecContext* exec_context() { return kernel_ctx_->exec_context(); } KernelState* state() { return kernel_ctx_->state(); } - - // Not all of these members are used for every executor type - + + // Not all of these members are used for every executor type + KernelContext* kernel_ctx_; - const KernelType* kernel_; - std::unique_ptr<ExecBatchIterator> batch_iterator_; - ValueDescr output_descr_; - - int output_num_buffers_; - - // If true, then memory is preallocated for the validity bitmap with the same - // strategy as the data buffer(s). - bool validity_preallocated_ = false; + const KernelType* kernel_; + std::unique_ptr<ExecBatchIterator> batch_iterator_; + ValueDescr output_descr_; + + int output_num_buffers_; + + // If true, then memory is preallocated for the validity bitmap with the same + // strategy as the data buffer(s). + bool validity_preallocated_ = false; // The kernel writes into data buffers preallocated for these bit widths // (0 indicates no preallocation); std::vector<BufferPreallocation> data_preallocated_; -}; - +}; + class ScalarExecutor : public KernelExecutorImpl<ScalarKernel> { - public: - Status Execute(const std::vector<Datum>& args, ExecListener* listener) override { - RETURN_NOT_OK(PrepareExecute(args)); - ExecBatch batch; - while (batch_iterator_->Next(&batch)) { - RETURN_NOT_OK(ExecuteBatch(batch, listener)); - } - if (preallocate_contiguous_) { - // If we preallocated one big chunk, since the kernel execution is - // completed, we can now emit it - RETURN_NOT_OK(listener->OnResult(std::move(preallocated_))); - } - return Status::OK(); - } - - Datum WrapResults(const std::vector<Datum>& inputs, - const std::vector<Datum>& outputs) override { - if (output_descr_.shape == ValueDescr::SCALAR) { - DCHECK_GT(outputs.size(), 0); - if (outputs.size() == 1) { - // Return as SCALAR - return outputs[0]; - } else { - // Return as COLLECTION - return outputs; - } - } else { - // If execution yielded multiple chunks (because large arrays were split - // based on the ExecContext parameters, then the result is a ChunkedArray - if (HaveChunkedArray(inputs) || outputs.size() > 1) { - return ToChunkedArray(outputs, output_descr_.type); - } else if (outputs.size() == 1) { - // Outputs have just one element - return outputs[0]; - } else { - // XXX: In the case where no outputs are omitted, is returning a 0-length - // array always the correct move? + public: + Status Execute(const std::vector<Datum>& args, ExecListener* listener) override { + RETURN_NOT_OK(PrepareExecute(args)); + ExecBatch batch; + while (batch_iterator_->Next(&batch)) { + RETURN_NOT_OK(ExecuteBatch(batch, listener)); + } + if (preallocate_contiguous_) { + // If we preallocated one big chunk, since the kernel execution is + // completed, we can now emit it + RETURN_NOT_OK(listener->OnResult(std::move(preallocated_))); + } + return Status::OK(); + } + + Datum WrapResults(const std::vector<Datum>& inputs, + const std::vector<Datum>& outputs) override { + if (output_descr_.shape == ValueDescr::SCALAR) { + DCHECK_GT(outputs.size(), 0); + if (outputs.size() == 1) { + // Return as SCALAR + return outputs[0]; + } else { + // Return as COLLECTION + return outputs; + } + } else { + // If execution yielded multiple chunks (because large arrays were split + // based on the ExecContext parameters, then the result is a ChunkedArray + if (HaveChunkedArray(inputs) || outputs.size() > 1) { + return ToChunkedArray(outputs, output_descr_.type); + } else if (outputs.size() == 1) { + // Outputs have just one element + return outputs[0]; + } else { + // XXX: In the case where no outputs are omitted, is returning a 0-length + // array always the correct move? return MakeArrayOfNull(output_descr_.type, /*length=*/0, exec_context()->memory_pool()) .ValueOrDie(); - } - } - } - - protected: - Status ExecuteBatch(const ExecBatch& batch, ExecListener* listener) { - Datum out; - RETURN_NOT_OK(PrepareNextOutput(batch, &out)); - - if (output_descr_.shape == ValueDescr::ARRAY) { - ArrayData* out_arr = out.mutable_array(); - if (kernel_->null_handling == NullHandling::INTERSECTION) { + } + } + } + + protected: + Status ExecuteBatch(const ExecBatch& batch, ExecListener* listener) { + Datum out; + RETURN_NOT_OK(PrepareNextOutput(batch, &out)); + + if (output_descr_.shape == ValueDescr::ARRAY) { + ArrayData* out_arr = out.mutable_array(); + if (kernel_->null_handling == NullHandling::INTERSECTION) { RETURN_NOT_OK(PropagateNulls(kernel_ctx_, batch, out_arr)); - } else if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) { - out_arr->null_count = 0; - } - } else { - if (kernel_->null_handling == NullHandling::INTERSECTION) { - // set scalar validity - out.scalar()->is_valid = - std::all_of(batch.values.begin(), batch.values.end(), - [](const Datum& input) { return input.scalar()->is_valid; }); - } else if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) { - out.scalar()->is_valid = true; - } - } - + } else if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) { + out_arr->null_count = 0; + } + } else { + if (kernel_->null_handling == NullHandling::INTERSECTION) { + // set scalar validity + out.scalar()->is_valid = + std::all_of(batch.values.begin(), batch.values.end(), + [](const Datum& input) { return input.scalar()->is_valid; }); + } else if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) { + out.scalar()->is_valid = true; + } + } + RETURN_NOT_OK(kernel_->exec(kernel_ctx_, batch, &out)); - if (!preallocate_contiguous_) { - // If we are producing chunked output rather than one big array, then - // emit each chunk as soon as it's available - RETURN_NOT_OK(listener->OnResult(std::move(out))); - } - return Status::OK(); - } - - Status PrepareExecute(const std::vector<Datum>& args) { + if (!preallocate_contiguous_) { + // If we are producing chunked output rather than one big array, then + // emit each chunk as soon as it's available + RETURN_NOT_OK(listener->OnResult(std::move(out))); + } + return Status::OK(); + } + + Status PrepareExecute(const std::vector<Datum>& args) { RETURN_NOT_OK(this->SetupArgIteration(args)); - - if (output_descr_.shape == ValueDescr::ARRAY) { - // If the executor is configured to produce a single large Array output for - // kernels supporting preallocation, then we do so up front and then - // iterate over slices of that large array. Otherwise, we preallocate prior - // to processing each batch emitted from the ExecBatchIterator - RETURN_NOT_OK(SetupPreallocation(batch_iterator_->length())); - } - return Status::OK(); - } - - // We must accommodate two different modes of execution for preallocated - // execution - // - // * A single large ("contiguous") allocation that we populate with results - // on a chunkwise basis according to the ExecBatchIterator. This permits - // parallelization even if the objective is to obtain a single Array or - // ChunkedArray at the end - // * A standalone buffer preallocation for each chunk emitted from the - // ExecBatchIterator - // - // When data buffer preallocation is not possible (e.g. with BINARY / STRING - // outputs), then contiguous results are only possible if the input is - // contiguous. - - Status PrepareNextOutput(const ExecBatch& batch, Datum* out) { - if (output_descr_.shape == ValueDescr::ARRAY) { - if (preallocate_contiguous_) { - // The output is already fully preallocated - const int64_t batch_start_position = batch_iterator_->position() - batch.length; - - if (batch.length < batch_iterator_->length()) { - // If this is a partial execution, then we write into a slice of - // preallocated_ - out->value = preallocated_->Slice(batch_start_position, batch.length); - } else { - // Otherwise write directly into preallocated_. The main difference - // computationally (versus the Slice approach) is that the null_count - // may not need to be recomputed in the result - out->value = preallocated_; - } - } else { - // We preallocate (maybe) only for the output of processing the current - // batch - ARROW_ASSIGN_OR_RAISE(out->value, PrepareOutput(batch.length)); - } - } else { - // For scalar outputs, we set a null scalar of the correct type to - // communicate the output type to the kernel if needed - // - // XXX: Is there some way to avoid this step? - out->value = MakeNullScalar(output_descr_.type); - } - return Status::OK(); - } - - Status SetupPreallocation(int64_t total_length) { - output_num_buffers_ = static_cast<int>(output_descr_.type->layout().buffers.size()); - - // Decide if we need to preallocate memory for this kernel - validity_preallocated_ = - (kernel_->null_handling != NullHandling::COMPUTED_NO_PREALLOCATE && + + if (output_descr_.shape == ValueDescr::ARRAY) { + // If the executor is configured to produce a single large Array output for + // kernels supporting preallocation, then we do so up front and then + // iterate over slices of that large array. Otherwise, we preallocate prior + // to processing each batch emitted from the ExecBatchIterator + RETURN_NOT_OK(SetupPreallocation(batch_iterator_->length())); + } + return Status::OK(); + } + + // We must accommodate two different modes of execution for preallocated + // execution + // + // * A single large ("contiguous") allocation that we populate with results + // on a chunkwise basis according to the ExecBatchIterator. This permits + // parallelization even if the objective is to obtain a single Array or + // ChunkedArray at the end + // * A standalone buffer preallocation for each chunk emitted from the + // ExecBatchIterator + // + // When data buffer preallocation is not possible (e.g. with BINARY / STRING + // outputs), then contiguous results are only possible if the input is + // contiguous. + + Status PrepareNextOutput(const ExecBatch& batch, Datum* out) { + if (output_descr_.shape == ValueDescr::ARRAY) { + if (preallocate_contiguous_) { + // The output is already fully preallocated + const int64_t batch_start_position = batch_iterator_->position() - batch.length; + + if (batch.length < batch_iterator_->length()) { + // If this is a partial execution, then we write into a slice of + // preallocated_ + out->value = preallocated_->Slice(batch_start_position, batch.length); + } else { + // Otherwise write directly into preallocated_. The main difference + // computationally (versus the Slice approach) is that the null_count + // may not need to be recomputed in the result + out->value = preallocated_; + } + } else { + // We preallocate (maybe) only for the output of processing the current + // batch + ARROW_ASSIGN_OR_RAISE(out->value, PrepareOutput(batch.length)); + } + } else { + // For scalar outputs, we set a null scalar of the correct type to + // communicate the output type to the kernel if needed + // + // XXX: Is there some way to avoid this step? + out->value = MakeNullScalar(output_descr_.type); + } + return Status::OK(); + } + + Status SetupPreallocation(int64_t total_length) { + output_num_buffers_ = static_cast<int>(output_descr_.type->layout().buffers.size()); + + // Decide if we need to preallocate memory for this kernel + validity_preallocated_ = + (kernel_->null_handling != NullHandling::COMPUTED_NO_PREALLOCATE && kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL && output_descr_.type->id() != Type::NA); if (kernel_->mem_allocation == MemAllocation::PREALLOCATE) { ComputeDataPreallocate(*output_descr_.type, &data_preallocated_); } - + // Contiguous preallocation only possible on non-nested types if all // buffers are preallocated. Otherwise, we must go chunk-by-chunk. - // + // // Some kernels are also unable to write into sliced outputs, so we respect the // kernel's attributes. - preallocate_contiguous_ = + preallocate_contiguous_ = (exec_context()->preallocate_contiguous() && kernel_->can_write_into_slices && validity_preallocated_ && !is_nested(output_descr_.type->id()) && !is_dictionary(output_descr_.type->id()) && @@ -769,202 +769,202 @@ class ScalarExecutor : public KernelExecutorImpl<ScalarKernel> { [](const BufferPreallocation& prealloc) { return prealloc.bit_width >= 0; })); - if (preallocate_contiguous_) { - ARROW_ASSIGN_OR_RAISE(preallocated_, PrepareOutput(total_length)); - } - return Status::OK(); - } - - // If true, and the kernel and output type supports preallocation (for both - // the validity and data buffers), then we allocate one big array and then - // iterate through it while executing the kernel in chunks - bool preallocate_contiguous_ = false; - - // For storing a contiguous preallocation per above. Unused otherwise - std::shared_ptr<ArrayData> preallocated_; -}; - -Status PackBatchNoChunks(const std::vector<Datum>& args, ExecBatch* out) { - int64_t length = 0; - for (const auto& arg : args) { - switch (arg.kind()) { - case Datum::SCALAR: - case Datum::ARRAY: + if (preallocate_contiguous_) { + ARROW_ASSIGN_OR_RAISE(preallocated_, PrepareOutput(total_length)); + } + return Status::OK(); + } + + // If true, and the kernel and output type supports preallocation (for both + // the validity and data buffers), then we allocate one big array and then + // iterate through it while executing the kernel in chunks + bool preallocate_contiguous_ = false; + + // For storing a contiguous preallocation per above. Unused otherwise + std::shared_ptr<ArrayData> preallocated_; +}; + +Status PackBatchNoChunks(const std::vector<Datum>& args, ExecBatch* out) { + int64_t length = 0; + for (const auto& arg : args) { + switch (arg.kind()) { + case Datum::SCALAR: + case Datum::ARRAY: case Datum::CHUNKED_ARRAY: - length = std::max(arg.length(), length); - break; - default: - DCHECK(false); - break; - } - } - out->length = length; - out->values = args; - return Status::OK(); -} - + length = std::max(arg.length(), length); + break; + default: + DCHECK(false); + break; + } + } + out->length = length; + out->values = args; + return Status::OK(); +} + class VectorExecutor : public KernelExecutorImpl<VectorKernel> { - public: - Status Execute(const std::vector<Datum>& args, ExecListener* listener) override { - RETURN_NOT_OK(PrepareExecute(args)); - ExecBatch batch; - if (kernel_->can_execute_chunkwise) { - while (batch_iterator_->Next(&batch)) { - RETURN_NOT_OK(ExecuteBatch(batch, listener)); - } - } else { - RETURN_NOT_OK(PackBatchNoChunks(args, &batch)); - RETURN_NOT_OK(ExecuteBatch(batch, listener)); - } - return Finalize(listener); - } - - Datum WrapResults(const std::vector<Datum>& inputs, - const std::vector<Datum>& outputs) override { - // If execution yielded multiple chunks (because large arrays were split - // based on the ExecContext parameters, then the result is a ChunkedArray + public: + Status Execute(const std::vector<Datum>& args, ExecListener* listener) override { + RETURN_NOT_OK(PrepareExecute(args)); + ExecBatch batch; + if (kernel_->can_execute_chunkwise) { + while (batch_iterator_->Next(&batch)) { + RETURN_NOT_OK(ExecuteBatch(batch, listener)); + } + } else { + RETURN_NOT_OK(PackBatchNoChunks(args, &batch)); + RETURN_NOT_OK(ExecuteBatch(batch, listener)); + } + return Finalize(listener); + } + + Datum WrapResults(const std::vector<Datum>& inputs, + const std::vector<Datum>& outputs) override { + // If execution yielded multiple chunks (because large arrays were split + // based on the ExecContext parameters, then the result is a ChunkedArray if (kernel_->output_chunked && (HaveChunkedArray(inputs) || outputs.size() > 1)) { return ToChunkedArray(outputs, output_descr_.type); } else if (outputs.size() == 1) { // Outputs have just one element return outputs[0]; - } else { + } else { // XXX: In the case where no outputs are omitted, is returning a 0-length // array always the correct move? return MakeArrayOfNull(output_descr_.type, /*length=*/0).ValueOrDie(); - } - } - - protected: - Status ExecuteBatch(const ExecBatch& batch, ExecListener* listener) { - if (batch.length == 0) { - // Skip empty batches. This may only happen when not using - // ExecBatchIterator - return Status::OK(); - } - Datum out; - if (output_descr_.shape == ValueDescr::ARRAY) { - // We preallocate (maybe) only for the output of processing the current - // batch - ARROW_ASSIGN_OR_RAISE(out.value, PrepareOutput(batch.length)); - } - - if (kernel_->null_handling == NullHandling::INTERSECTION && - output_descr_.shape == ValueDescr::ARRAY) { + } + } + + protected: + Status ExecuteBatch(const ExecBatch& batch, ExecListener* listener) { + if (batch.length == 0) { + // Skip empty batches. This may only happen when not using + // ExecBatchIterator + return Status::OK(); + } + Datum out; + if (output_descr_.shape == ValueDescr::ARRAY) { + // We preallocate (maybe) only for the output of processing the current + // batch + ARROW_ASSIGN_OR_RAISE(out.value, PrepareOutput(batch.length)); + } + + if (kernel_->null_handling == NullHandling::INTERSECTION && + output_descr_.shape == ValueDescr::ARRAY) { RETURN_NOT_OK(PropagateNulls(kernel_ctx_, batch, out.mutable_array())); - } + } RETURN_NOT_OK(kernel_->exec(kernel_ctx_, batch, &out)); - if (!kernel_->finalize) { - // If there is no result finalizer (e.g. for hash-based functions, we can - // emit the processed batch right away rather than waiting - RETURN_NOT_OK(listener->OnResult(std::move(out))); - } else { - results_.emplace_back(std::move(out)); - } - return Status::OK(); - } - - Status Finalize(ExecListener* listener) { - if (kernel_->finalize) { - // Intermediate results require post-processing after the execution is - // completed (possibly involving some accumulated state) + if (!kernel_->finalize) { + // If there is no result finalizer (e.g. for hash-based functions, we can + // emit the processed batch right away rather than waiting + RETURN_NOT_OK(listener->OnResult(std::move(out))); + } else { + results_.emplace_back(std::move(out)); + } + return Status::OK(); + } + + Status Finalize(ExecListener* listener) { + if (kernel_->finalize) { + // Intermediate results require post-processing after the execution is + // completed (possibly involving some accumulated state) RETURN_NOT_OK(kernel_->finalize(kernel_ctx_, &results_)); - for (const auto& result : results_) { - RETURN_NOT_OK(listener->OnResult(result)); - } - } - return Status::OK(); - } - - Status SetupArgIteration(const std::vector<Datum>& args) override { - if (kernel_->can_execute_chunkwise) { + for (const auto& result : results_) { + RETURN_NOT_OK(listener->OnResult(result)); + } + } + return Status::OK(); + } + + Status SetupArgIteration(const std::vector<Datum>& args) override { + if (kernel_->can_execute_chunkwise) { ARROW_ASSIGN_OR_RAISE(batch_iterator_, ExecBatchIterator::Make( args, exec_context()->exec_chunksize())); - } - return Status::OK(); - } - - Status PrepareExecute(const std::vector<Datum>& args) { + } + return Status::OK(); + } + + Status PrepareExecute(const std::vector<Datum>& args) { RETURN_NOT_OK(this->SetupArgIteration(args)); - output_num_buffers_ = static_cast<int>(output_descr_.type->layout().buffers.size()); - - // Decide if we need to preallocate memory for this kernel - validity_preallocated_ = - (kernel_->null_handling != NullHandling::COMPUTED_NO_PREALLOCATE && - kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL); + output_num_buffers_ = static_cast<int>(output_descr_.type->layout().buffers.size()); + + // Decide if we need to preallocate memory for this kernel + validity_preallocated_ = + (kernel_->null_handling != NullHandling::COMPUTED_NO_PREALLOCATE && + kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL); if (kernel_->mem_allocation == MemAllocation::PREALLOCATE) { ComputeDataPreallocate(*output_descr_.type, &data_preallocated_); } - return Status::OK(); - } - - std::vector<Datum> results_; -}; - + return Status::OK(); + } + + std::vector<Datum> results_; +}; + class ScalarAggExecutor : public KernelExecutorImpl<ScalarAggregateKernel> { - public: + public: Status Init(KernelContext* ctx, KernelInitArgs args) override { input_descrs_ = &args.inputs; options_ = args.options; return KernelExecutorImpl<ScalarAggregateKernel>::Init(ctx, args); } - - Status Execute(const std::vector<Datum>& args, ExecListener* listener) override { + + Status Execute(const std::vector<Datum>& args, ExecListener* listener) override { RETURN_NOT_OK(this->SetupArgIteration(args)); - - ExecBatch batch; - while (batch_iterator_->Next(&batch)) { - // TODO: implement parallelism - if (batch.length > 0) { - RETURN_NOT_OK(Consume(batch)); - } - } - - Datum out; + + ExecBatch batch; + while (batch_iterator_->Next(&batch)) { + // TODO: implement parallelism + if (batch.length > 0) { + RETURN_NOT_OK(Consume(batch)); + } + } + + Datum out; RETURN_NOT_OK(kernel_->finalize(kernel_ctx_, &out)); - RETURN_NOT_OK(listener->OnResult(std::move(out))); - return Status::OK(); - } - - Datum WrapResults(const std::vector<Datum>&, - const std::vector<Datum>& outputs) override { - DCHECK_EQ(1, outputs.size()); - return outputs[0]; - } - - private: - Status Consume(const ExecBatch& batch) { + RETURN_NOT_OK(listener->OnResult(std::move(out))); + return Status::OK(); + } + + Datum WrapResults(const std::vector<Datum>&, + const std::vector<Datum>& outputs) override { + DCHECK_EQ(1, outputs.size()); + return outputs[0]; + } + + private: + Status Consume(const ExecBatch& batch) { // FIXME(ARROW-11840) don't merge *any* aggegates for every batch ARROW_ASSIGN_OR_RAISE( auto batch_state, kernel_->init(kernel_ctx_, {kernel_, *input_descrs_, options_})); - - if (batch_state == nullptr) { + + if (batch_state == nullptr) { return Status::Invalid("ScalarAggregation requires non-null kernel state"); - } - + } + KernelContext batch_ctx(exec_context()); - batch_ctx.SetState(batch_state.get()); - + batch_ctx.SetState(batch_state.get()); + RETURN_NOT_OK(kernel_->consume(&batch_ctx, batch)); RETURN_NOT_OK(kernel_->merge(kernel_ctx_, std::move(*batch_state), state())); - return Status::OK(); - } + return Status::OK(); + } const std::vector<ValueDescr>* input_descrs_; const FunctionOptions* options_; -}; - -template <typename ExecutorType, - typename FunctionType = typename ExecutorType::FunctionType> +}; + +template <typename ExecutorType, + typename FunctionType = typename ExecutorType::FunctionType> Result<std::unique_ptr<KernelExecutor>> MakeExecutor(ExecContext* ctx, const Function* func, const FunctionOptions* options) { - DCHECK_EQ(ExecutorType::function_kind, func->kind()); - auto typed_func = checked_cast<const FunctionType*>(func); + DCHECK_EQ(ExecutorType::function_kind, func->kind()); + auto typed_func = checked_cast<const FunctionType*>(func); return std::unique_ptr<KernelExecutor>(new ExecutorType(ctx, typed_func, options)); -} - +} + } // namespace Status PropagateNulls(KernelContext* ctx, const ExecBatch& batch, ArrayData* output) { @@ -975,7 +975,7 @@ Status PropagateNulls(KernelContext* ctx, const ExecBatch& batch, ArrayData* out // Null output type is a no-op (rare when this would happen but we at least // will test for it) return Status::OK(); - } + } // This function is ONLY able to write into output with non-zero offset // when the bitmap is preallocated. This could be a DCHECK but returning @@ -987,8 +987,8 @@ Status PropagateNulls(KernelContext* ctx, const ExecBatch& batch, ArrayData* out } NullPropagator propagator(ctx, batch, output); return propagator.Execute(); -} - +} + std::unique_ptr<KernelExecutor> KernelExecutor::MakeScalar() { return ::arrow::internal::make_unique<detail::ScalarExecutor>(); } @@ -1001,50 +1001,50 @@ std::unique_ptr<KernelExecutor> KernelExecutor::MakeScalarAggregate() { return ::arrow::internal::make_unique<detail::ScalarAggExecutor>(); } -} // namespace detail - +} // namespace detail + ExecContext::ExecContext(MemoryPool* pool, ::arrow::internal::Executor* executor, FunctionRegistry* func_registry) : pool_(pool), executor_(executor) { - this->func_registry_ = func_registry == nullptr ? GetFunctionRegistry() : func_registry; -} - -CpuInfo* ExecContext::cpu_info() const { return CpuInfo::GetInstance(); } - -// ---------------------------------------------------------------------- -// SelectionVector - -SelectionVector::SelectionVector(std::shared_ptr<ArrayData> data) - : data_(std::move(data)) { - DCHECK_EQ(Type::INT32, data_->type->id()); - DCHECK_EQ(0, data_->GetNullCount()); - indices_ = data_->GetValues<int32_t>(1); -} - -SelectionVector::SelectionVector(const Array& arr) : SelectionVector(arr.data()) {} - -int32_t SelectionVector::length() const { return static_cast<int32_t>(data_->length); } - -Result<std::shared_ptr<SelectionVector>> SelectionVector::FromMask( - const BooleanArray& arr) { - return Status::NotImplemented("FromMask"); -} - -Result<Datum> CallFunction(const std::string& func_name, const std::vector<Datum>& args, - const FunctionOptions* options, ExecContext* ctx) { - if (ctx == nullptr) { - ExecContext default_ctx; - return CallFunction(func_name, args, options, &default_ctx); - } - ARROW_ASSIGN_OR_RAISE(std::shared_ptr<const Function> func, - ctx->func_registry()->GetFunction(func_name)); - return func->Execute(args, options, ctx); -} - -Result<Datum> CallFunction(const std::string& func_name, const std::vector<Datum>& args, - ExecContext* ctx) { - return CallFunction(func_name, args, /*options=*/nullptr, ctx); -} - -} // namespace compute -} // namespace arrow + this->func_registry_ = func_registry == nullptr ? GetFunctionRegistry() : func_registry; +} + +CpuInfo* ExecContext::cpu_info() const { return CpuInfo::GetInstance(); } + +// ---------------------------------------------------------------------- +// SelectionVector + +SelectionVector::SelectionVector(std::shared_ptr<ArrayData> data) + : data_(std::move(data)) { + DCHECK_EQ(Type::INT32, data_->type->id()); + DCHECK_EQ(0, data_->GetNullCount()); + indices_ = data_->GetValues<int32_t>(1); +} + +SelectionVector::SelectionVector(const Array& arr) : SelectionVector(arr.data()) {} + +int32_t SelectionVector::length() const { return static_cast<int32_t>(data_->length); } + +Result<std::shared_ptr<SelectionVector>> SelectionVector::FromMask( + const BooleanArray& arr) { + return Status::NotImplemented("FromMask"); +} + +Result<Datum> CallFunction(const std::string& func_name, const std::vector<Datum>& args, + const FunctionOptions* options, ExecContext* ctx) { + if (ctx == nullptr) { + ExecContext default_ctx; + return CallFunction(func_name, args, options, &default_ctx); + } + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<const Function> func, + ctx->func_registry()->GetFunction(func_name)); + return func->Execute(args, options, ctx); +} + +Result<Datum> CallFunction(const std::string& func_name, const std::vector<Datum>& args, + ExecContext* ctx) { + return CallFunction(func_name, args, /*options=*/nullptr, ctx); +} + +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/exec.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/exec.h index de1b695de4..227d0c76ad 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/exec.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/exec.h @@ -1,183 +1,183 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// NOTE: API is EXPERIMENTAL and will change without going through a -// deprecation cycle - -#pragma once - -#include <cstdint> -#include <limits> -#include <memory> -#include <string> -#include <utility> -#include <vector> - -#include "arrow/array/data.h" +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + +#pragma once + +#include <cstdint> +#include <limits> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "arrow/array/data.h" #include "arrow/compute/exec/expression.h" -#include "arrow/datum.h" -#include "arrow/memory_pool.h" -#include "arrow/result.h" -#include "arrow/type_fwd.h" -#include "arrow/util/macros.h" +#include "arrow/datum.h" +#include "arrow/memory_pool.h" +#include "arrow/result.h" +#include "arrow/type_fwd.h" +#include "arrow/util/macros.h" #include "arrow/util/type_fwd.h" -#include "arrow/util/visibility.h" - -namespace arrow { -namespace internal { - -class CpuInfo; - -} // namespace internal - -namespace compute { - +#include "arrow/util/visibility.h" + +namespace arrow { +namespace internal { + +class CpuInfo; + +} // namespace internal + +namespace compute { + class FunctionOptions; -class FunctionRegistry; - -// It seems like 64K might be a good default chunksize to use for execution -// based on the experience of other query processing systems. The current -// default is not to chunk contiguous arrays, though, but this may change in -// the future once parallel execution is implemented -static constexpr int64_t kDefaultExecChunksize = UINT16_MAX; - -/// \brief Context for expression-global variables and options used by -/// function evaluation -class ARROW_EXPORT ExecContext { - public: - // If no function registry passed, the default is used. - explicit ExecContext(MemoryPool* pool = default_memory_pool(), +class FunctionRegistry; + +// It seems like 64K might be a good default chunksize to use for execution +// based on the experience of other query processing systems. The current +// default is not to chunk contiguous arrays, though, but this may change in +// the future once parallel execution is implemented +static constexpr int64_t kDefaultExecChunksize = UINT16_MAX; + +/// \brief Context for expression-global variables and options used by +/// function evaluation +class ARROW_EXPORT ExecContext { + public: + // If no function registry passed, the default is used. + explicit ExecContext(MemoryPool* pool = default_memory_pool(), ::arrow::internal::Executor* executor = NULLPTR, - FunctionRegistry* func_registry = NULLPTR); - - /// \brief The MemoryPool used for allocations, default is - /// default_memory_pool(). - MemoryPool* memory_pool() const { return pool_; } - - ::arrow::internal::CpuInfo* cpu_info() const; - + FunctionRegistry* func_registry = NULLPTR); + + /// \brief The MemoryPool used for allocations, default is + /// default_memory_pool(). + MemoryPool* memory_pool() const { return pool_; } + + ::arrow::internal::CpuInfo* cpu_info() const; + /// \brief An Executor which may be used to parallelize execution. ::arrow::internal::Executor* executor() const { return executor_; } - /// \brief The FunctionRegistry for looking up functions by name and - /// selecting kernels for execution. Defaults to the library-global function - /// registry provided by GetFunctionRegistry. - FunctionRegistry* func_registry() const { return func_registry_; } - - // \brief Set maximum length unit of work for kernel execution. Larger - // contiguous array inputs will be split into smaller chunks, and, if - // possible and enabled, processed in parallel. The default chunksize is - // INT64_MAX, so contiguous arrays are not split. - void set_exec_chunksize(int64_t chunksize) { exec_chunksize_ = chunksize; } - - // \brief Maximum length for ExecBatch data chunks processed by - // kernels. Contiguous array inputs with longer length will be split into - // smaller chunks. - int64_t exec_chunksize() const { return exec_chunksize_; } - - /// \brief Set whether to use multiple threads for function execution. This - /// is not yet used. - void set_use_threads(bool use_threads = true) { use_threads_ = use_threads; } - - /// \brief If true, then utilize multiple threads where relevant for function - /// execution. This is not yet used. - bool use_threads() const { return use_threads_; } - - // Set the preallocation strategy for kernel execution as it relates to - // chunked execution. For chunked execution, whether via ChunkedArray inputs - // or splitting larger Array arguments into smaller pieces, contiguous - // allocation (if permitted by the kernel) will allocate one large array to - // write output into yielding it to the caller at the end. If this option is - // set to off, then preallocations will be performed independently for each - // chunk of execution - // - // TODO: At some point we might want the limit the size of contiguous - // preallocations. For example, even if the exec_chunksize is 64K or less, we - // might limit contiguous allocations to 1M records, say. - void set_preallocate_contiguous(bool preallocate) { - preallocate_contiguous_ = preallocate; - } - - /// \brief If contiguous preallocations should be used when doing chunked - /// execution as specified by exec_chunksize(). See - /// set_preallocate_contiguous() for more information. - bool preallocate_contiguous() const { return preallocate_contiguous_; } - - private: - MemoryPool* pool_; + /// \brief The FunctionRegistry for looking up functions by name and + /// selecting kernels for execution. Defaults to the library-global function + /// registry provided by GetFunctionRegistry. + FunctionRegistry* func_registry() const { return func_registry_; } + + // \brief Set maximum length unit of work for kernel execution. Larger + // contiguous array inputs will be split into smaller chunks, and, if + // possible and enabled, processed in parallel. The default chunksize is + // INT64_MAX, so contiguous arrays are not split. + void set_exec_chunksize(int64_t chunksize) { exec_chunksize_ = chunksize; } + + // \brief Maximum length for ExecBatch data chunks processed by + // kernels. Contiguous array inputs with longer length will be split into + // smaller chunks. + int64_t exec_chunksize() const { return exec_chunksize_; } + + /// \brief Set whether to use multiple threads for function execution. This + /// is not yet used. + void set_use_threads(bool use_threads = true) { use_threads_ = use_threads; } + + /// \brief If true, then utilize multiple threads where relevant for function + /// execution. This is not yet used. + bool use_threads() const { return use_threads_; } + + // Set the preallocation strategy for kernel execution as it relates to + // chunked execution. For chunked execution, whether via ChunkedArray inputs + // or splitting larger Array arguments into smaller pieces, contiguous + // allocation (if permitted by the kernel) will allocate one large array to + // write output into yielding it to the caller at the end. If this option is + // set to off, then preallocations will be performed independently for each + // chunk of execution + // + // TODO: At some point we might want the limit the size of contiguous + // preallocations. For example, even if the exec_chunksize is 64K or less, we + // might limit contiguous allocations to 1M records, say. + void set_preallocate_contiguous(bool preallocate) { + preallocate_contiguous_ = preallocate; + } + + /// \brief If contiguous preallocations should be used when doing chunked + /// execution as specified by exec_chunksize(). See + /// set_preallocate_contiguous() for more information. + bool preallocate_contiguous() const { return preallocate_contiguous_; } + + private: + MemoryPool* pool_; ::arrow::internal::Executor* executor_; - FunctionRegistry* func_registry_; - int64_t exec_chunksize_ = std::numeric_limits<int64_t>::max(); - bool preallocate_contiguous_ = true; - bool use_threads_ = true; -}; - + FunctionRegistry* func_registry_; + int64_t exec_chunksize_ = std::numeric_limits<int64_t>::max(); + bool preallocate_contiguous_ = true; + bool use_threads_ = true; +}; + ARROW_EXPORT ExecContext* default_exec_context(); -// TODO: Consider standardizing on uint16 selection vectors and only use them -// when we can ensure that each value is 64K length or smaller - -/// \brief Container for an array of value selection indices that were -/// materialized from a filter. -/// -/// Columnar query engines (see e.g. [1]) have found that rather than -/// materializing filtered data, the filter can instead be converted to an -/// array of the "on" indices and then "fusing" these indices in operator -/// implementations. This is especially relevant for aggregations but also -/// applies to scalar operations. -/// -/// We are not yet using this so this is mostly a placeholder for now. -/// -/// [1]: http://cidrdb.org/cidr2005/papers/P19.pdf -class ARROW_EXPORT SelectionVector { - public: - explicit SelectionVector(std::shared_ptr<ArrayData> data); - - explicit SelectionVector(const Array& arr); - - /// \brief Create SelectionVector from boolean mask - static Result<std::shared_ptr<SelectionVector>> FromMask(const BooleanArray& arr); - - const int32_t* indices() const { return indices_; } - int32_t length() const; - - private: - std::shared_ptr<ArrayData> data_; - const int32_t* indices_; -}; - -/// \brief A unit of work for kernel execution. It contains a collection of -/// Array and Scalar values and an optional SelectionVector indicating that -/// there is an unmaterialized filter that either must be materialized, or (if -/// the kernel supports it) pushed down into the kernel implementation. -/// -/// ExecBatch is semantically similar to RecordBatch in that in a SQL context -/// it represents a collection of records, but constant "columns" are -/// represented by Scalar values rather than having to be converted into arrays -/// with repeated values. -/// -/// TODO: Datum uses arrow/util/variant.h which may be a bit heavier-weight -/// than is desirable for this class. Microbenchmarks would help determine for -/// sure. See ARROW-8928. +// TODO: Consider standardizing on uint16 selection vectors and only use them +// when we can ensure that each value is 64K length or smaller + +/// \brief Container for an array of value selection indices that were +/// materialized from a filter. +/// +/// Columnar query engines (see e.g. [1]) have found that rather than +/// materializing filtered data, the filter can instead be converted to an +/// array of the "on" indices and then "fusing" these indices in operator +/// implementations. This is especially relevant for aggregations but also +/// applies to scalar operations. +/// +/// We are not yet using this so this is mostly a placeholder for now. +/// +/// [1]: http://cidrdb.org/cidr2005/papers/P19.pdf +class ARROW_EXPORT SelectionVector { + public: + explicit SelectionVector(std::shared_ptr<ArrayData> data); + + explicit SelectionVector(const Array& arr); + + /// \brief Create SelectionVector from boolean mask + static Result<std::shared_ptr<SelectionVector>> FromMask(const BooleanArray& arr); + + const int32_t* indices() const { return indices_; } + int32_t length() const; + + private: + std::shared_ptr<ArrayData> data_; + const int32_t* indices_; +}; + +/// \brief A unit of work for kernel execution. It contains a collection of +/// Array and Scalar values and an optional SelectionVector indicating that +/// there is an unmaterialized filter that either must be materialized, or (if +/// the kernel supports it) pushed down into the kernel implementation. +/// +/// ExecBatch is semantically similar to RecordBatch in that in a SQL context +/// it represents a collection of records, but constant "columns" are +/// represented by Scalar values rather than having to be converted into arrays +/// with repeated values. +/// +/// TODO: Datum uses arrow/util/variant.h which may be a bit heavier-weight +/// than is desirable for this class. Microbenchmarks would help determine for +/// sure. See ARROW-8928. struct ARROW_EXPORT ExecBatch { ExecBatch() = default; - ExecBatch(std::vector<Datum> values, int64_t length) - : values(std::move(values)), length(length) {} - + ExecBatch(std::vector<Datum> values, int64_t length) + : values(std::move(values)), length(length) {} + explicit ExecBatch(const RecordBatch& batch); static Result<ExecBatch> Make(std::vector<Datum> values); @@ -185,80 +185,80 @@ struct ARROW_EXPORT ExecBatch { Result<std::shared_ptr<RecordBatch>> ToRecordBatch( std::shared_ptr<Schema> schema, MemoryPool* pool = default_memory_pool()) const; - /// The values representing positional arguments to be passed to a kernel's - /// exec function for processing. - std::vector<Datum> values; - - /// A deferred filter represented as an array of indices into the values. - /// - /// For example, the filter [true, true, false, true] would be represented as - /// the selection vector [0, 1, 3]. When the selection vector is set, - /// ExecBatch::length is equal to the length of this array. - std::shared_ptr<SelectionVector> selection_vector; - + /// The values representing positional arguments to be passed to a kernel's + /// exec function for processing. + std::vector<Datum> values; + + /// A deferred filter represented as an array of indices into the values. + /// + /// For example, the filter [true, true, false, true] would be represented as + /// the selection vector [0, 1, 3]. When the selection vector is set, + /// ExecBatch::length is equal to the length of this array. + std::shared_ptr<SelectionVector> selection_vector; + /// A predicate Expression guaranteed to evaluate to true for all rows in this batch. Expression guarantee = literal(true); - /// The semantic length of the ExecBatch. When the values are all scalars, - /// the length should be set to 1, otherwise the length is taken from the - /// array values, except when there is a selection vector. When there is a - /// selection vector set, the length of the batch is the length of the - /// selection. - /// - /// If the array values are of length 0 then the length is 0 regardless of - /// whether any values are Scalar. In general ExecBatch objects are produced - /// by ExecBatchIterator which by design does not yield length-0 batches. - int64_t length; - - /// \brief Return the value at the i-th index - template <typename index_type> - inline const Datum& operator[](index_type i) const { - return values[i]; - } - + /// The semantic length of the ExecBatch. When the values are all scalars, + /// the length should be set to 1, otherwise the length is taken from the + /// array values, except when there is a selection vector. When there is a + /// selection vector set, the length of the batch is the length of the + /// selection. + /// + /// If the array values are of length 0 then the length is 0 regardless of + /// whether any values are Scalar. In general ExecBatch objects are produced + /// by ExecBatchIterator which by design does not yield length-0 batches. + int64_t length; + + /// \brief Return the value at the i-th index + template <typename index_type> + inline const Datum& operator[](index_type i) const { + return values[i]; + } + bool Equals(const ExecBatch& other) const; - /// \brief A convenience for the number of values / arguments. - int num_values() const { return static_cast<int>(values.size()); } - + /// \brief A convenience for the number of values / arguments. + int num_values() const { return static_cast<int>(values.size()); } + ExecBatch Slice(int64_t offset, int64_t length) const; - /// \brief A convenience for returning the ValueDescr objects (types and - /// shapes) from the batch. - std::vector<ValueDescr> GetDescriptors() const { - std::vector<ValueDescr> result; - for (const auto& value : this->values) { - result.emplace_back(value.descr()); - } - return result; - } + /// \brief A convenience for returning the ValueDescr objects (types and + /// shapes) from the batch. + std::vector<ValueDescr> GetDescriptors() const { + std::vector<ValueDescr> result; + for (const auto& value : this->values) { + result.emplace_back(value.descr()); + } + return result; + } ARROW_EXPORT friend void PrintTo(const ExecBatch&, std::ostream*); -}; - +}; + inline bool operator==(const ExecBatch& l, const ExecBatch& r) { return l.Equals(r); } inline bool operator!=(const ExecBatch& l, const ExecBatch& r) { return !l.Equals(r); } -/// \defgroup compute-call-function One-shot calls to compute functions -/// -/// @{ - -/// \brief One-shot invoker for all types of functions. -/// -/// Does kernel dispatch, argument checking, iteration of ChunkedArray inputs, -/// and wrapping of outputs. -ARROW_EXPORT -Result<Datum> CallFunction(const std::string& func_name, const std::vector<Datum>& args, - const FunctionOptions* options, ExecContext* ctx = NULLPTR); - -/// \brief Variant of CallFunction which uses a function's default options. -/// -/// NB: Some functions require FunctionOptions be provided. -ARROW_EXPORT -Result<Datum> CallFunction(const std::string& func_name, const std::vector<Datum>& args, - ExecContext* ctx = NULLPTR); - -/// @} - -} // namespace compute -} // namespace arrow +/// \defgroup compute-call-function One-shot calls to compute functions +/// +/// @{ + +/// \brief One-shot invoker for all types of functions. +/// +/// Does kernel dispatch, argument checking, iteration of ChunkedArray inputs, +/// and wrapping of outputs. +ARROW_EXPORT +Result<Datum> CallFunction(const std::string& func_name, const std::vector<Datum>& args, + const FunctionOptions* options, ExecContext* ctx = NULLPTR); + +/// \brief Variant of CallFunction which uses a function's default options. +/// +/// NB: Some functions require FunctionOptions be provided. +ARROW_EXPORT +Result<Datum> CallFunction(const std::string& func_name, const std::vector<Datum>& args, + ExecContext* ctx = NULLPTR); + +/// @} + +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/exec_internal.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/exec_internal.h index 55daa243cd..e2872e6141 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/exec_internal.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/exec_internal.h @@ -1,111 +1,111 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <cstdint> -#include <limits> -#include <memory> -#include <string> -#include <vector> - -#include "arrow/array.h" -#include "arrow/buffer.h" -#include "arrow/compute/exec.h" -#include "arrow/compute/kernel.h" -#include "arrow/status.h" -#include "arrow/util/visibility.h" - -namespace arrow { -namespace compute { - -class Function; - -static constexpr int64_t kDefaultMaxChunksize = std::numeric_limits<int64_t>::max(); - -namespace detail { - -/// \brief Break std::vector<Datum> into a sequence of ExecBatch for kernel -/// execution -class ARROW_EXPORT ExecBatchIterator { - public: - /// \brief Construct iterator and do basic argument validation - /// - /// \param[in] args the Datum argument, must be all array-like or scalar - /// \param[in] max_chunksize the maximum length of each ExecBatch. Depending - /// on the chunk layout of ChunkedArray. - static Result<std::unique_ptr<ExecBatchIterator>> Make( - std::vector<Datum> args, int64_t max_chunksize = kDefaultMaxChunksize); - - /// \brief Compute the next batch. Always returns at least one batch. Return - /// false if the iterator is exhausted - bool Next(ExecBatch* batch); - - int64_t length() const { return length_; } - - int64_t position() const { return position_; } - - int64_t max_chunksize() const { return max_chunksize_; } - - private: - ExecBatchIterator(std::vector<Datum> args, int64_t length, int64_t max_chunksize); - - std::vector<Datum> args_; - std::vector<int> chunk_indexes_; - std::vector<int64_t> chunk_positions_; - int64_t position_; - int64_t length_; - int64_t max_chunksize_; -}; - -// "Push" / listener API like IPC reader so that consumers can receive -// processed chunks as soon as they're available. - -class ARROW_EXPORT ExecListener { - public: - virtual ~ExecListener() = default; - - virtual Status OnResult(Datum) { return Status::NotImplemented("OnResult"); } -}; - -class DatumAccumulator : public ExecListener { - public: +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstdint> +#include <limits> +#include <memory> +#include <string> +#include <vector> + +#include "arrow/array.h" +#include "arrow/buffer.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/kernel.h" +#include "arrow/status.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + +class Function; + +static constexpr int64_t kDefaultMaxChunksize = std::numeric_limits<int64_t>::max(); + +namespace detail { + +/// \brief Break std::vector<Datum> into a sequence of ExecBatch for kernel +/// execution +class ARROW_EXPORT ExecBatchIterator { + public: + /// \brief Construct iterator and do basic argument validation + /// + /// \param[in] args the Datum argument, must be all array-like or scalar + /// \param[in] max_chunksize the maximum length of each ExecBatch. Depending + /// on the chunk layout of ChunkedArray. + static Result<std::unique_ptr<ExecBatchIterator>> Make( + std::vector<Datum> args, int64_t max_chunksize = kDefaultMaxChunksize); + + /// \brief Compute the next batch. Always returns at least one batch. Return + /// false if the iterator is exhausted + bool Next(ExecBatch* batch); + + int64_t length() const { return length_; } + + int64_t position() const { return position_; } + + int64_t max_chunksize() const { return max_chunksize_; } + + private: + ExecBatchIterator(std::vector<Datum> args, int64_t length, int64_t max_chunksize); + + std::vector<Datum> args_; + std::vector<int> chunk_indexes_; + std::vector<int64_t> chunk_positions_; + int64_t position_; + int64_t length_; + int64_t max_chunksize_; +}; + +// "Push" / listener API like IPC reader so that consumers can receive +// processed chunks as soon as they're available. + +class ARROW_EXPORT ExecListener { + public: + virtual ~ExecListener() = default; + + virtual Status OnResult(Datum) { return Status::NotImplemented("OnResult"); } +}; + +class DatumAccumulator : public ExecListener { + public: DatumAccumulator() = default; - - Status OnResult(Datum value) override { - values_.emplace_back(value); - return Status::OK(); - } - + + Status OnResult(Datum value) override { + values_.emplace_back(value); + return Status::OK(); + } + std::vector<Datum> values() { return std::move(values_); } - - private: - std::vector<Datum> values_; -}; - -/// \brief Check that each Datum is of a "value" type, which means either -/// SCALAR, ARRAY, or CHUNKED_ARRAY. If there are chunked inputs, then these -/// inputs will be split into non-chunked ExecBatch values for execution -Status CheckAllValues(const std::vector<Datum>& values); - + + private: + std::vector<Datum> values_; +}; + +/// \brief Check that each Datum is of a "value" type, which means either +/// SCALAR, ARRAY, or CHUNKED_ARRAY. If there are chunked inputs, then these +/// inputs will be split into non-chunked ExecBatch values for execution +Status CheckAllValues(const std::vector<Datum>& values); + class ARROW_EXPORT KernelExecutor { - public: + public: virtual ~KernelExecutor() = default; - + /// The Kernel's `init` method must be called and any KernelState set in the /// KernelContext *before* KernelExecutor::Init is called. This is to facilitate /// the case where init may be expensive and does not need to be called again for @@ -113,30 +113,30 @@ class ARROW_EXPORT KernelExecutor { /// for all scanned batches in a dataset filter. virtual Status Init(KernelContext*, KernelInitArgs) = 0; - /// XXX: Better configurability for listener - /// Not thread-safe - virtual Status Execute(const std::vector<Datum>& args, ExecListener* listener) = 0; - - virtual Datum WrapResults(const std::vector<Datum>& args, - const std::vector<Datum>& outputs) = 0; - + /// XXX: Better configurability for listener + /// Not thread-safe + virtual Status Execute(const std::vector<Datum>& args, ExecListener* listener) = 0; + + virtual Datum WrapResults(const std::vector<Datum>& args, + const std::vector<Datum>& outputs) = 0; + static std::unique_ptr<KernelExecutor> MakeScalar(); static std::unique_ptr<KernelExecutor> MakeVector(); static std::unique_ptr<KernelExecutor> MakeScalarAggregate(); -}; - -/// \brief Populate validity bitmap with the intersection of the nullity of the -/// arguments. If a preallocated bitmap is not provided, then one will be -/// allocated if needed (in some cases a bitmap can be zero-copied from the -/// arguments). If any Scalar value is null, then the entire validity bitmap -/// will be set to null. -/// -/// \param[in] ctx kernel execution context, for memory allocation etc. -/// \param[in] batch the data batch -/// \param[in] out the output ArrayData, must not be null -ARROW_EXPORT -Status PropagateNulls(KernelContext* ctx, const ExecBatch& batch, ArrayData* out); - -} // namespace detail -} // namespace compute -} // namespace arrow +}; + +/// \brief Populate validity bitmap with the intersection of the nullity of the +/// arguments. If a preallocated bitmap is not provided, then one will be +/// allocated if needed (in some cases a bitmap can be zero-copied from the +/// arguments). If any Scalar value is null, then the entire validity bitmap +/// will be set to null. +/// +/// \param[in] ctx kernel execution context, for memory allocation etc. +/// \param[in] batch the data batch +/// \param[in] out the output ArrayData, must not be null +ARROW_EXPORT +Status PropagateNulls(KernelContext* ctx, const ExecBatch& batch, ArrayData* out); + +} // namespace detail +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/function.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/function.cc index 05d14d03b1..3eefb327c1 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/function.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/function.cc @@ -1,46 +1,46 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/compute/function.h" - -#include <cstddef> -#include <memory> -#include <sstream> - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/function.h" + +#include <cstddef> +#include <memory> +#include <sstream> + #include "arrow/compute/api_scalar.h" #include "arrow/compute/cast.h" -#include "arrow/compute/exec.h" -#include "arrow/compute/exec_internal.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/exec_internal.h" #include "arrow/compute/function_internal.h" #include "arrow/compute/kernels/common.h" #include "arrow/compute/registry.h" -#include "arrow/datum.h" -#include "arrow/util/cpu_info.h" - -namespace arrow { +#include "arrow/datum.h" +#include "arrow/util/cpu_info.h" + +namespace arrow { using internal::checked_cast; -namespace compute { +namespace compute { Result<std::shared_ptr<Buffer>> FunctionOptionsType::Serialize( const FunctionOptions&) const { return Status::NotImplemented("Serialize for ", type_name()); } - + Result<std::unique_ptr<FunctionOptions>> FunctionOptionsType::Deserialize( const Buffer& buffer) const { return Status::NotImplemented("Deserialize for ", type_name()); @@ -79,7 +79,7 @@ static Status CheckArityImpl(const Function* function, int passed_num_args, return Status::Invalid("VarArgs function ", function->name(), " needs at least ", function->arity().num_args, " arguments but ", passed_num_args_label, " only ", passed_num_args); - } + } if (!function->arity().is_varargs && passed_num_args != function->arity().num_args) { return Status::Invalid("Function ", function->name(), " accepts ", @@ -87,18 +87,18 @@ static Status CheckArityImpl(const Function* function, int passed_num_args, passed_num_args_label, " ", passed_num_args); } - return Status::OK(); -} - + return Status::OK(); +} + Status Function::CheckArity(const std::vector<InputType>& in_types) const { return CheckArityImpl(this, static_cast<int>(in_types.size()), "kernel accepts"); -} - +} + Status Function::CheckArity(const std::vector<ValueDescr>& descrs) const { return CheckArityImpl(this, static_cast<int>(descrs.size()), "attempted to look up kernel(s) with"); } - + namespace detail { Status NoMatchingKernel(const Function* func, const std::vector<ValueDescr>& descrs) { @@ -112,38 +112,38 @@ const KernelType* DispatchExactImpl(const std::vector<KernelType*>& kernels, const std::vector<ValueDescr>& values) { const KernelType* kernel_matches[SimdLevel::MAX] = {nullptr}; - // Validate arity - for (const auto& kernel : kernels) { + // Validate arity + for (const auto& kernel : kernels) { if (kernel->signature->MatchesInputs(values)) { kernel_matches[kernel->simd_level] = kernel; - } - } - - // Dispatch as the CPU feature + } + } + + // Dispatch as the CPU feature #if defined(ARROW_HAVE_RUNTIME_AVX512) || defined(ARROW_HAVE_RUNTIME_AVX2) - auto cpu_info = arrow::internal::CpuInfo::GetInstance(); -#endif -#if defined(ARROW_HAVE_RUNTIME_AVX512) - if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) { - if (kernel_matches[SimdLevel::AVX512]) { - return kernel_matches[SimdLevel::AVX512]; - } - } + auto cpu_info = arrow::internal::CpuInfo::GetInstance(); #endif -#if defined(ARROW_HAVE_RUNTIME_AVX2) - if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) { - if (kernel_matches[SimdLevel::AVX2]) { - return kernel_matches[SimdLevel::AVX2]; - } - } -#endif - if (kernel_matches[SimdLevel::NONE]) { - return kernel_matches[SimdLevel::NONE]; - } - +#if defined(ARROW_HAVE_RUNTIME_AVX512) + if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) { + if (kernel_matches[SimdLevel::AVX512]) { + return kernel_matches[SimdLevel::AVX512]; + } + } +#endif +#if defined(ARROW_HAVE_RUNTIME_AVX2) + if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) { + if (kernel_matches[SimdLevel::AVX2]) { + return kernel_matches[SimdLevel::AVX2]; + } + } +#endif + if (kernel_matches[SimdLevel::NONE]) { + return kernel_matches[SimdLevel::NONE]; + } + return nullptr; -} - +} + const Kernel* DispatchExactImpl(const Function* func, const std::vector<ValueDescr>& values) { if (func->kind() == Function::SCALAR) { @@ -189,19 +189,19 @@ Result<const Kernel*> Function::DispatchBest(std::vector<ValueDescr>* values) co return DispatchExact(*values); } -Result<Datum> Function::Execute(const std::vector<Datum>& args, - const FunctionOptions* options, ExecContext* ctx) const { - if (options == nullptr) { - options = default_options(); - } - if (ctx == nullptr) { - ExecContext default_ctx; - return Execute(args, options, &default_ctx); - } - - // type-check Datum arguments here. Really we'd like to avoid this as much as - // possible - RETURN_NOT_OK(detail::CheckAllValues(args)); +Result<Datum> Function::Execute(const std::vector<Datum>& args, + const FunctionOptions* options, ExecContext* ctx) const { + if (options == nullptr) { + options = default_options(); + } + if (ctx == nullptr) { + ExecContext default_ctx; + return Execute(args, options, &default_ctx); + } + + // type-check Datum arguments here. Really we'd like to avoid this as much as + // possible + RETURN_NOT_OK(detail::CheckAllValues(args)); std::vector<ValueDescr> inputs(args.size()); for (size_t i = 0; i != args.size(); ++i) { inputs[i] = args[i].descr(); @@ -230,11 +230,11 @@ Result<Datum> Function::Execute(const std::vector<Datum>& args, } RETURN_NOT_OK(executor->Init(&kernel_ctx, {kernel, inputs, options})); - auto listener = std::make_shared<detail::DatumAccumulator>(); + auto listener = std::make_shared<detail::DatumAccumulator>(); RETURN_NOT_OK(executor->Execute(implicitly_cast_args, listener.get())); return executor->WrapResults(implicitly_cast_args, listener->values()); -} - +} + Status Function::Validate() const { if (!doc_->summary.empty()) { // Documentation given, check its contents @@ -252,59 +252,59 @@ Status Function::Validate() const { return Status::OK(); } -Status ScalarFunction::AddKernel(std::vector<InputType> in_types, OutputType out_type, - ArrayKernelExec exec, KernelInit init) { +Status ScalarFunction::AddKernel(std::vector<InputType> in_types, OutputType out_type, + ArrayKernelExec exec, KernelInit init) { RETURN_NOT_OK(CheckArity(in_types)); - - if (arity_.is_varargs && in_types.size() != 1) { - return Status::Invalid("VarArgs signatures must have exactly one input type"); - } - auto sig = - KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs); - kernels_.emplace_back(std::move(sig), exec, init); - return Status::OK(); -} - -Status ScalarFunction::AddKernel(ScalarKernel kernel) { + + if (arity_.is_varargs && in_types.size() != 1) { + return Status::Invalid("VarArgs signatures must have exactly one input type"); + } + auto sig = + KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs); + kernels_.emplace_back(std::move(sig), exec, init); + return Status::OK(); +} + +Status ScalarFunction::AddKernel(ScalarKernel kernel) { RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); - if (arity_.is_varargs && !kernel.signature->is_varargs()) { - return Status::Invalid("Function accepts varargs but kernel signature does not"); - } - kernels_.emplace_back(std::move(kernel)); - return Status::OK(); -} - -Status VectorFunction::AddKernel(std::vector<InputType> in_types, OutputType out_type, - ArrayKernelExec exec, KernelInit init) { + if (arity_.is_varargs && !kernel.signature->is_varargs()) { + return Status::Invalid("Function accepts varargs but kernel signature does not"); + } + kernels_.emplace_back(std::move(kernel)); + return Status::OK(); +} + +Status VectorFunction::AddKernel(std::vector<InputType> in_types, OutputType out_type, + ArrayKernelExec exec, KernelInit init) { RETURN_NOT_OK(CheckArity(in_types)); - - if (arity_.is_varargs && in_types.size() != 1) { - return Status::Invalid("VarArgs signatures must have exactly one input type"); - } - auto sig = - KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs); - kernels_.emplace_back(std::move(sig), exec, init); - return Status::OK(); -} - -Status VectorFunction::AddKernel(VectorKernel kernel) { + + if (arity_.is_varargs && in_types.size() != 1) { + return Status::Invalid("VarArgs signatures must have exactly one input type"); + } + auto sig = + KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs); + kernels_.emplace_back(std::move(sig), exec, init); + return Status::OK(); +} + +Status VectorFunction::AddKernel(VectorKernel kernel) { RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); - if (arity_.is_varargs && !kernel.signature->is_varargs()) { - return Status::Invalid("Function accepts varargs but kernel signature does not"); - } - kernels_.emplace_back(std::move(kernel)); - return Status::OK(); -} - -Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) { + if (arity_.is_varargs && !kernel.signature->is_varargs()) { + return Status::Invalid("Function accepts varargs but kernel signature does not"); + } + kernels_.emplace_back(std::move(kernel)); + return Status::OK(); +} + +Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) { RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); - if (arity_.is_varargs && !kernel.signature->is_varargs()) { - return Status::Invalid("Function accepts varargs but kernel signature does not"); - } - kernels_.emplace_back(std::move(kernel)); - return Status::OK(); -} - + if (arity_.is_varargs && !kernel.signature->is_varargs()) { + return Status::Invalid("Function accepts varargs but kernel signature does not"); + } + kernels_.emplace_back(std::move(kernel)); + return Status::OK(); +} + Status HashAggregateFunction::AddKernel(HashAggregateKernel kernel) { RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); if (arity_.is_varargs && !kernel.signature->is_varargs()) { @@ -312,19 +312,19 @@ Status HashAggregateFunction::AddKernel(HashAggregateKernel kernel) { } kernels_.emplace_back(std::move(kernel)); return Status::OK(); -} - -Result<Datum> MetaFunction::Execute(const std::vector<Datum>& args, - const FunctionOptions* options, - ExecContext* ctx) const { +} + +Result<Datum> MetaFunction::Execute(const std::vector<Datum>& args, + const FunctionOptions* options, + ExecContext* ctx) const { RETURN_NOT_OK( CheckArityImpl(this, static_cast<int>(args.size()), "attempted to Execute with")); - if (options == nullptr) { - options = default_options(); - } - return ExecuteImpl(args, options, ctx); -} - -} // namespace compute -} // namespace arrow + if (options == nullptr) { + options = default_options(); + } + return ExecuteImpl(args, options, ctx); +} + +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/function.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/function.h index bd854bbb28..69c55cd998 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/function.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/function.h @@ -1,45 +1,45 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// NOTE: API is EXPERIMENTAL and will change without going through a -// deprecation cycle. - -#pragma once - -#include <string> -#include <utility> -#include <vector> - -#include "arrow/compute/kernel.h" -#include "arrow/compute/type_fwd.h" -#include "arrow/datum.h" -#include "arrow/result.h" -#include "arrow/status.h" +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle. + +#pragma once + +#include <string> +#include <utility> +#include <vector> + +#include "arrow/compute/kernel.h" +#include "arrow/compute/type_fwd.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/status.h" #include "arrow/util/compare.h" -#include "arrow/util/macros.h" -#include "arrow/util/visibility.h" - -namespace arrow { -namespace compute { - -/// \defgroup compute-functions Abstract compute function API -/// -/// @{ - +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + +/// \defgroup compute-functions Abstract compute function API +/// +/// @{ + /// \brief Extension point for defining options outside libarrow (but /// still within this project). class ARROW_EXPORT FunctionOptionsType { @@ -54,12 +54,12 @@ class ARROW_EXPORT FunctionOptionsType { const Buffer& buffer) const; }; -/// \brief Base class for specifying options configuring a function's behavior, -/// such as error handling. +/// \brief Base class for specifying options configuring a function's behavior, +/// such as error handling. class ARROW_EXPORT FunctionOptions : public util::EqualityComparable<FunctionOptions> { public: virtual ~FunctionOptions() = default; - + const FunctionOptionsType* options_type() const { return options_type_; } const char* type_name() const { return options_type()->type_name(); } @@ -84,40 +84,40 @@ class ARROW_EXPORT FunctionOptions : public util::EqualityComparable<FunctionOpt ARROW_EXPORT void PrintTo(const FunctionOptions&, std::ostream*); -/// \brief Contains the number of required arguments for the function. -/// -/// Naming conventions taken from https://en.wikipedia.org/wiki/Arity. -struct ARROW_EXPORT Arity { - /// \brief A function taking no arguments - static Arity Nullary() { return Arity(0, false); } - - /// \brief A function taking 1 argument - static Arity Unary() { return Arity(1, false); } - - /// \brief A function taking 2 arguments - static Arity Binary() { return Arity(2, false); } - - /// \brief A function taking 3 arguments - static Arity Ternary() { return Arity(3, false); } - - /// \brief A function taking a variable number of arguments - /// - /// \param[in] min_args the minimum number of arguments required when - /// invoking the function - static Arity VarArgs(int min_args = 0) { return Arity(min_args, true); } - - // NOTE: the 0-argument form (default constructor) is required for Cython - explicit Arity(int num_args = 0, bool is_varargs = false) - : num_args(num_args), is_varargs(is_varargs) {} - - /// The number of required arguments (or the minimum number for varargs - /// functions). - int num_args; - - /// If true, then the num_args is the minimum number of required arguments. - bool is_varargs = false; -}; - +/// \brief Contains the number of required arguments for the function. +/// +/// Naming conventions taken from https://en.wikipedia.org/wiki/Arity. +struct ARROW_EXPORT Arity { + /// \brief A function taking no arguments + static Arity Nullary() { return Arity(0, false); } + + /// \brief A function taking 1 argument + static Arity Unary() { return Arity(1, false); } + + /// \brief A function taking 2 arguments + static Arity Binary() { return Arity(2, false); } + + /// \brief A function taking 3 arguments + static Arity Ternary() { return Arity(3, false); } + + /// \brief A function taking a variable number of arguments + /// + /// \param[in] min_args the minimum number of arguments required when + /// invoking the function + static Arity VarArgs(int min_args = 0) { return Arity(min_args, true); } + + // NOTE: the 0-argument form (default constructor) is required for Cython + explicit Arity(int num_args = 0, bool is_varargs = false) + : num_args(num_args), is_varargs(is_varargs) {} + + /// The number of required arguments (or the minimum number for varargs + /// functions). + int num_args; + + /// If true, then the num_args is the minimum number of required arguments. + bool is_varargs = false; +}; + struct ARROW_EXPORT FunctionDoc { /// \brief A one-line summary of the function, using a verb. /// @@ -149,57 +149,57 @@ struct ARROW_EXPORT FunctionDoc { static const FunctionDoc& Empty(); }; -/// \brief Base class for compute functions. Function implementations contain a -/// collection of "kernels" which are implementations of the function for -/// specific argument types. Selecting a viable kernel for executing a function -/// is referred to as "dispatching". -class ARROW_EXPORT Function { - public: - /// \brief The kind of function, which indicates in what contexts it is - /// valid for use. - enum Kind { - /// A function that performs scalar data operations on whole arrays of - /// data. Can generally process Array or Scalar values. The size of the - /// output will be the same as the size (or broadcasted size, in the case - /// of mixing Array and Scalar inputs) of the input. - SCALAR, - - /// A function with array input and output whose behavior depends on the - /// values of the entire arrays passed, rather than the value of each scalar - /// value. - VECTOR, - - /// A function that computes scalar summary statistics from array input. - SCALAR_AGGREGATE, - +/// \brief Base class for compute functions. Function implementations contain a +/// collection of "kernels" which are implementations of the function for +/// specific argument types. Selecting a viable kernel for executing a function +/// is referred to as "dispatching". +class ARROW_EXPORT Function { + public: + /// \brief The kind of function, which indicates in what contexts it is + /// valid for use. + enum Kind { + /// A function that performs scalar data operations on whole arrays of + /// data. Can generally process Array or Scalar values. The size of the + /// output will be the same as the size (or broadcasted size, in the case + /// of mixing Array and Scalar inputs) of the input. + SCALAR, + + /// A function with array input and output whose behavior depends on the + /// values of the entire arrays passed, rather than the value of each scalar + /// value. + VECTOR, + + /// A function that computes scalar summary statistics from array input. + SCALAR_AGGREGATE, + /// A function that computes grouped summary statistics from array input /// and an array of group identifiers. HASH_AGGREGATE, - /// A function that dispatches to other functions and does not contain its - /// own kernels. - META - }; - - virtual ~Function() = default; - - /// \brief The name of the kernel. The registry enforces uniqueness of names. - const std::string& name() const { return name_; } - - /// \brief The kind of kernel, which indicates in what contexts it is valid - /// for use. - Function::Kind kind() const { return kind_; } - - /// \brief Contains the number of arguments the function requires, or if the - /// function accepts variable numbers of arguments. - const Arity& arity() const { return arity_; } - + /// A function that dispatches to other functions and does not contain its + /// own kernels. + META + }; + + virtual ~Function() = default; + + /// \brief The name of the kernel. The registry enforces uniqueness of names. + const std::string& name() const { return name_; } + + /// \brief The kind of kernel, which indicates in what contexts it is valid + /// for use. + Function::Kind kind() const { return kind_; } + + /// \brief Contains the number of arguments the function requires, or if the + /// function accepts variable numbers of arguments. + const Arity& arity() const { return arity_; } + /// \brief Return the function documentation const FunctionDoc& doc() const { return *doc_; } - /// \brief Returns the number of registered kernels for this function. - virtual int num_kernels() const = 0; - + /// \brief Returns the number of registered kernels for this function. + virtual int num_kernels() const = 0; + /// \brief Return a kernel that can execute the function given the exact /// argument types (without implicit type casts or scalar->array promotions). /// @@ -215,67 +215,67 @@ class ARROW_EXPORT Function { /// are responsible for casting inputs to the type and shape required by the kernel. virtual Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const; - /// \brief Execute the function eagerly with the passed input arguments with - /// kernel dispatch, batch iteration, and memory allocation details taken - /// care of. - /// - /// If the `options` pointer is null, then `default_options()` will be used. - /// - /// This function can be overridden in subclasses. - virtual Result<Datum> Execute(const std::vector<Datum>& args, - const FunctionOptions* options, ExecContext* ctx) const; - - /// \brief Returns a the default options for this function. - /// - /// Whatever option semantics a Function has, implementations must guarantee - /// that default_options() is valid to pass to Execute as options. - const FunctionOptions* default_options() const { return default_options_; } - + /// \brief Execute the function eagerly with the passed input arguments with + /// kernel dispatch, batch iteration, and memory allocation details taken + /// care of. + /// + /// If the `options` pointer is null, then `default_options()` will be used. + /// + /// This function can be overridden in subclasses. + virtual Result<Datum> Execute(const std::vector<Datum>& args, + const FunctionOptions* options, ExecContext* ctx) const; + + /// \brief Returns a the default options for this function. + /// + /// Whatever option semantics a Function has, implementations must guarantee + /// that default_options() is valid to pass to Execute as options. + const FunctionOptions* default_options() const { return default_options_; } + virtual Status Validate() const; - protected: - Function(std::string name, Function::Kind kind, const Arity& arity, + protected: + Function(std::string name, Function::Kind kind, const Arity& arity, const FunctionDoc* doc, const FunctionOptions* default_options) - : name_(std::move(name)), - kind_(kind), - arity_(arity), + : name_(std::move(name)), + kind_(kind), + arity_(arity), doc_(doc ? doc : &FunctionDoc::Empty()), - default_options_(default_options) {} - + default_options_(default_options) {} + Status CheckArity(const std::vector<InputType>&) const; Status CheckArity(const std::vector<ValueDescr>&) const; - - std::string name_; - Function::Kind kind_; - Arity arity_; + + std::string name_; + Function::Kind kind_; + Arity arity_; const FunctionDoc* doc_; - const FunctionOptions* default_options_ = NULLPTR; -}; - -namespace detail { - -template <typename KernelType> -class FunctionImpl : public Function { - public: - /// \brief Return pointers to current-available kernels for inspection - std::vector<const KernelType*> kernels() const { - std::vector<const KernelType*> result; - for (const auto& kernel : kernels_) { - result.push_back(&kernel); - } - return result; - } - - int num_kernels() const override { return static_cast<int>(kernels_.size()); } - - protected: - FunctionImpl(std::string name, Function::Kind kind, const Arity& arity, + const FunctionOptions* default_options_ = NULLPTR; +}; + +namespace detail { + +template <typename KernelType> +class FunctionImpl : public Function { + public: + /// \brief Return pointers to current-available kernels for inspection + std::vector<const KernelType*> kernels() const { + std::vector<const KernelType*> result; + for (const auto& kernel : kernels_) { + result.push_back(&kernel); + } + return result; + } + + int num_kernels() const override { return static_cast<int>(kernels_.size()); } + + protected: + FunctionImpl(std::string name, Function::Kind kind, const Arity& arity, const FunctionDoc* doc, const FunctionOptions* default_options) : Function(std::move(name), kind, arity, doc, default_options) {} - - std::vector<KernelType> kernels_; -}; - + + std::vector<KernelType> kernels_; +}; + /// \brief Look up a kernel in a function. If no Kernel is found, nullptr is returned. ARROW_EXPORT const Kernel* DispatchExactImpl(const Function* func, const std::vector<ValueDescr>&); @@ -284,72 +284,72 @@ const Kernel* DispatchExactImpl(const Function* func, const std::vector<ValueDes ARROW_EXPORT Status NoMatchingKernel(const Function* func, const std::vector<ValueDescr>&); -} // namespace detail - -/// \brief A function that executes elementwise operations on arrays or -/// scalars, and therefore whose results generally do not depend on the order -/// of the values in the arguments. Accepts and returns arrays that are all of -/// the same size. These functions roughly correspond to the functions used in -/// SQL expressions. -class ARROW_EXPORT ScalarFunction : public detail::FunctionImpl<ScalarKernel> { - public: - using KernelType = ScalarKernel; - +} // namespace detail + +/// \brief A function that executes elementwise operations on arrays or +/// scalars, and therefore whose results generally do not depend on the order +/// of the values in the arguments. Accepts and returns arrays that are all of +/// the same size. These functions roughly correspond to the functions used in +/// SQL expressions. +class ARROW_EXPORT ScalarFunction : public detail::FunctionImpl<ScalarKernel> { + public: + using KernelType = ScalarKernel; + ScalarFunction(std::string name, const Arity& arity, const FunctionDoc* doc, - const FunctionOptions* default_options = NULLPTR) + const FunctionOptions* default_options = NULLPTR) : detail::FunctionImpl<ScalarKernel>(std::move(name), Function::SCALAR, arity, doc, - default_options) {} - - /// \brief Add a kernel with given input/output types, no required state - /// initialization, preallocation for fixed-width types, and default null - /// handling (intersect validity bitmaps of inputs). - Status AddKernel(std::vector<InputType> in_types, OutputType out_type, - ArrayKernelExec exec, KernelInit init = NULLPTR); - - /// \brief Add a kernel (function implementation). Returns error if the - /// kernel's signature does not match the function's arity. - Status AddKernel(ScalarKernel kernel); -}; - -/// \brief A function that executes general array operations that may yield -/// outputs of different sizes or have results that depend on the whole array -/// contents. These functions roughly correspond to the functions found in -/// non-SQL array languages like APL and its derivatives. -class ARROW_EXPORT VectorFunction : public detail::FunctionImpl<VectorKernel> { - public: - using KernelType = VectorKernel; - + default_options) {} + + /// \brief Add a kernel with given input/output types, no required state + /// initialization, preallocation for fixed-width types, and default null + /// handling (intersect validity bitmaps of inputs). + Status AddKernel(std::vector<InputType> in_types, OutputType out_type, + ArrayKernelExec exec, KernelInit init = NULLPTR); + + /// \brief Add a kernel (function implementation). Returns error if the + /// kernel's signature does not match the function's arity. + Status AddKernel(ScalarKernel kernel); +}; + +/// \brief A function that executes general array operations that may yield +/// outputs of different sizes or have results that depend on the whole array +/// contents. These functions roughly correspond to the functions found in +/// non-SQL array languages like APL and its derivatives. +class ARROW_EXPORT VectorFunction : public detail::FunctionImpl<VectorKernel> { + public: + using KernelType = VectorKernel; + VectorFunction(std::string name, const Arity& arity, const FunctionDoc* doc, - const FunctionOptions* default_options = NULLPTR) + const FunctionOptions* default_options = NULLPTR) : detail::FunctionImpl<VectorKernel>(std::move(name), Function::VECTOR, arity, doc, - default_options) {} - - /// \brief Add a simple kernel with given input/output types, no required - /// state initialization, no data preallocation, and no preallocation of the - /// validity bitmap. - Status AddKernel(std::vector<InputType> in_types, OutputType out_type, - ArrayKernelExec exec, KernelInit init = NULLPTR); - - /// \brief Add a kernel (function implementation). Returns error if the - /// kernel's signature does not match the function's arity. - Status AddKernel(VectorKernel kernel); -}; - -class ARROW_EXPORT ScalarAggregateFunction - : public detail::FunctionImpl<ScalarAggregateKernel> { - public: - using KernelType = ScalarAggregateKernel; - + default_options) {} + + /// \brief Add a simple kernel with given input/output types, no required + /// state initialization, no data preallocation, and no preallocation of the + /// validity bitmap. + Status AddKernel(std::vector<InputType> in_types, OutputType out_type, + ArrayKernelExec exec, KernelInit init = NULLPTR); + + /// \brief Add a kernel (function implementation). Returns error if the + /// kernel's signature does not match the function's arity. + Status AddKernel(VectorKernel kernel); +}; + +class ARROW_EXPORT ScalarAggregateFunction + : public detail::FunctionImpl<ScalarAggregateKernel> { + public: + using KernelType = ScalarAggregateKernel; + ScalarAggregateFunction(std::string name, const Arity& arity, const FunctionDoc* doc, - const FunctionOptions* default_options = NULLPTR) - : detail::FunctionImpl<ScalarAggregateKernel>( + const FunctionOptions* default_options = NULLPTR) + : detail::FunctionImpl<ScalarAggregateKernel>( std::move(name), Function::SCALAR_AGGREGATE, arity, doc, default_options) {} - - /// \brief Add a kernel (function implementation). Returns error if the - /// kernel's signature does not match the function's arity. - Status AddKernel(ScalarAggregateKernel kernel); + + /// \brief Add a kernel (function implementation). Returns error if the + /// kernel's signature does not match the function's arity. + Status AddKernel(ScalarAggregateKernel kernel); }; - + class ARROW_EXPORT HashAggregateFunction : public detail::FunctionImpl<HashAggregateKernel> { public: @@ -363,31 +363,31 @@ class ARROW_EXPORT HashAggregateFunction /// \brief Add a kernel (function implementation). Returns error if the /// kernel's signature does not match the function's arity. Status AddKernel(HashAggregateKernel kernel); -}; - -/// \brief A function that dispatches to other functions. Must implement -/// MetaFunction::ExecuteImpl. -/// -/// For Array, ChunkedArray, and Scalar Datum kinds, may rely on the execution -/// of concrete Function types, but must handle other Datum kinds on its own. -class ARROW_EXPORT MetaFunction : public Function { - public: - int num_kernels() const override { return 0; } - - Result<Datum> Execute(const std::vector<Datum>& args, const FunctionOptions* options, - ExecContext* ctx) const override; - - protected: - virtual Result<Datum> ExecuteImpl(const std::vector<Datum>& args, - const FunctionOptions* options, - ExecContext* ctx) const = 0; - +}; + +/// \brief A function that dispatches to other functions. Must implement +/// MetaFunction::ExecuteImpl. +/// +/// For Array, ChunkedArray, and Scalar Datum kinds, may rely on the execution +/// of concrete Function types, but must handle other Datum kinds on its own. +class ARROW_EXPORT MetaFunction : public Function { + public: + int num_kernels() const override { return 0; } + + Result<Datum> Execute(const std::vector<Datum>& args, const FunctionOptions* options, + ExecContext* ctx) const override; + + protected: + virtual Result<Datum> ExecuteImpl(const std::vector<Datum>& args, + const FunctionOptions* options, + ExecContext* ctx) const = 0; + MetaFunction(std::string name, const Arity& arity, const FunctionDoc* doc, - const FunctionOptions* default_options = NULLPTR) + const FunctionOptions* default_options = NULLPTR) : Function(std::move(name), Function::META, arity, doc, default_options) {} -}; - -/// @} - -} // namespace compute -} // namespace arrow +}; + +/// @} + +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernel.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernel.cc index f131f524d2..6d6dbb5ee5 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernel.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernel.cc @@ -1,72 +1,72 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/compute/kernel.h" - -#include <cstddef> -#include <memory> -#include <sstream> -#include <string> - -#include "arrow/buffer.h" -#include "arrow/compute/exec.h" -#include "arrow/compute/util_internal.h" -#include "arrow/result.h" -#include "arrow/type_traits.h" -#include "arrow/util/bit_util.h" -#include "arrow/util/checked_cast.h" -#include "arrow/util/hash_util.h" -#include "arrow/util/logging.h" -#include "arrow/util/macros.h" - -namespace arrow { - -using internal::checked_cast; -using internal::hash_combine; - -static constexpr size_t kHashSeed = 0; - -namespace compute { - -// ---------------------------------------------------------------------- -// KernelContext - -Result<std::shared_ptr<ResizableBuffer>> KernelContext::Allocate(int64_t nbytes) { - return AllocateResizableBuffer(nbytes, exec_ctx_->memory_pool()); -} - -Result<std::shared_ptr<ResizableBuffer>> KernelContext::AllocateBitmap(int64_t num_bits) { - const int64_t nbytes = BitUtil::BytesForBits(num_bits); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ResizableBuffer> result, - AllocateResizableBuffer(nbytes, exec_ctx_->memory_pool())); - // Since bitmaps are typically written bit by bit, we could leak uninitialized bits. - // Make sure all memory is initialized (this also appeases Valgrind). - internal::ZeroMemory(result.get()); - return result; -} - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/kernel.h" + +#include <cstddef> +#include <memory> +#include <sstream> +#include <string> + +#include "arrow/buffer.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/util_internal.h" +#include "arrow/result.h" +#include "arrow/type_traits.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/hash_util.h" +#include "arrow/util/logging.h" +#include "arrow/util/macros.h" + +namespace arrow { + +using internal::checked_cast; +using internal::hash_combine; + +static constexpr size_t kHashSeed = 0; + +namespace compute { + +// ---------------------------------------------------------------------- +// KernelContext + +Result<std::shared_ptr<ResizableBuffer>> KernelContext::Allocate(int64_t nbytes) { + return AllocateResizableBuffer(nbytes, exec_ctx_->memory_pool()); +} + +Result<std::shared_ptr<ResizableBuffer>> KernelContext::AllocateBitmap(int64_t num_bits) { + const int64_t nbytes = BitUtil::BytesForBits(num_bits); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ResizableBuffer> result, + AllocateResizableBuffer(nbytes, exec_ctx_->memory_pool())); + // Since bitmaps are typically written bit by bit, we could leak uninitialized bits. + // Make sure all memory is initialized (this also appeases Valgrind). + internal::ZeroMemory(result.get()); + return result; +} + Status Kernel::InitAll(KernelContext* ctx, const KernelInitArgs& args, std::vector<std::unique_ptr<KernelState>>* states) { for (auto& state : *states) { ARROW_ASSIGN_OR_RAISE(state, args.kernel->init(ctx, args)); - } + } return Status::OK(); -} - +} + Result<std::unique_ptr<KernelState>> ScalarAggregateKernel::MergeAll( const ScalarAggregateKernel* kernel, KernelContext* ctx, std::vector<std::unique_ptr<KernelState>> states) { @@ -78,409 +78,409 @@ Result<std::unique_ptr<KernelState>> ScalarAggregateKernel::MergeAll( } return std::move(out); } - -// ---------------------------------------------------------------------- -// Some basic TypeMatcher implementations - -namespace match { - -class SameTypeIdMatcher : public TypeMatcher { - public: - explicit SameTypeIdMatcher(Type::type accepted_id) : accepted_id_(accepted_id) {} - - bool Matches(const DataType& type) const override { return type.id() == accepted_id_; } - - std::string ToString() const override { - std::stringstream ss; - ss << "Type::" << ::arrow::internal::ToString(accepted_id_); - return ss.str(); - } - - bool Equals(const TypeMatcher& other) const override { - if (this == &other) { - return true; - } - auto casted = dynamic_cast<const SameTypeIdMatcher*>(&other); - if (casted == nullptr) { - return false; - } - return this->accepted_id_ == casted->accepted_id_; - } - - private: - Type::type accepted_id_; -}; - -std::shared_ptr<TypeMatcher> SameTypeId(Type::type type_id) { - return std::make_shared<SameTypeIdMatcher>(type_id); -} - -template <typename ArrowType> -class TimeUnitMatcher : public TypeMatcher { - using ThisType = TimeUnitMatcher<ArrowType>; - - public: - explicit TimeUnitMatcher(TimeUnit::type accepted_unit) - : accepted_unit_(accepted_unit) {} - - bool Matches(const DataType& type) const override { - if (type.id() != ArrowType::type_id) { - return false; - } - const auto& time_type = checked_cast<const ArrowType&>(type); - return time_type.unit() == accepted_unit_; - } - - bool Equals(const TypeMatcher& other) const override { - if (this == &other) { - return true; - } - auto casted = dynamic_cast<const ThisType*>(&other); - if (casted == nullptr) { - return false; - } - return this->accepted_unit_ == casted->accepted_unit_; - } - - std::string ToString() const override { - std::stringstream ss; - ss << ArrowType::type_name() << "(" << ::arrow::internal::ToString(accepted_unit_) - << ")"; - return ss.str(); - } - - private: - TimeUnit::type accepted_unit_; -}; - -using DurationTypeUnitMatcher = TimeUnitMatcher<DurationType>; -using Time32TypeUnitMatcher = TimeUnitMatcher<Time32Type>; -using Time64TypeUnitMatcher = TimeUnitMatcher<Time64Type>; -using TimestampTypeUnitMatcher = TimeUnitMatcher<TimestampType>; - -std::shared_ptr<TypeMatcher> TimestampTypeUnit(TimeUnit::type unit) { - return std::make_shared<TimestampTypeUnitMatcher>(unit); -} - -std::shared_ptr<TypeMatcher> Time32TypeUnit(TimeUnit::type unit) { - return std::make_shared<Time32TypeUnitMatcher>(unit); -} - -std::shared_ptr<TypeMatcher> Time64TypeUnit(TimeUnit::type unit) { - return std::make_shared<Time64TypeUnitMatcher>(unit); -} - -std::shared_ptr<TypeMatcher> DurationTypeUnit(TimeUnit::type unit) { - return std::make_shared<DurationTypeUnitMatcher>(unit); -} - -class IntegerMatcher : public TypeMatcher { - public: - IntegerMatcher() {} - - bool Matches(const DataType& type) const override { return is_integer(type.id()); } - - bool Equals(const TypeMatcher& other) const override { - if (this == &other) { - return true; - } - auto casted = dynamic_cast<const IntegerMatcher*>(&other); - return casted != nullptr; - } - - std::string ToString() const override { return "integer"; } -}; - -std::shared_ptr<TypeMatcher> Integer() { return std::make_shared<IntegerMatcher>(); } - -class PrimitiveMatcher : public TypeMatcher { - public: - PrimitiveMatcher() {} - - bool Matches(const DataType& type) const override { return is_primitive(type.id()); } - - bool Equals(const TypeMatcher& other) const override { - if (this == &other) { - return true; - } - auto casted = dynamic_cast<const PrimitiveMatcher*>(&other); - return casted != nullptr; - } - - std::string ToString() const override { return "primitive"; } -}; - -std::shared_ptr<TypeMatcher> Primitive() { return std::make_shared<PrimitiveMatcher>(); } - -class BinaryLikeMatcher : public TypeMatcher { - public: - BinaryLikeMatcher() {} - - bool Matches(const DataType& type) const override { return is_binary_like(type.id()); } - - bool Equals(const TypeMatcher& other) const override { - if (this == &other) { - return true; - } - auto casted = dynamic_cast<const BinaryLikeMatcher*>(&other); - return casted != nullptr; - } - std::string ToString() const override { return "binary-like"; } -}; - -std::shared_ptr<TypeMatcher> BinaryLike() { - return std::make_shared<BinaryLikeMatcher>(); -} - -class LargeBinaryLikeMatcher : public TypeMatcher { - public: - LargeBinaryLikeMatcher() {} - - bool Matches(const DataType& type) const override { - return is_large_binary_like(type.id()); - } - - bool Equals(const TypeMatcher& other) const override { - if (this == &other) { - return true; - } - auto casted = dynamic_cast<const LargeBinaryLikeMatcher*>(&other); - return casted != nullptr; - } - std::string ToString() const override { return "large-binary-like"; } -}; - -std::shared_ptr<TypeMatcher> LargeBinaryLike() { - return std::make_shared<LargeBinaryLikeMatcher>(); -} - -} // namespace match - -// ---------------------------------------------------------------------- -// InputType - -size_t InputType::Hash() const { - size_t result = kHashSeed; - hash_combine(result, static_cast<int>(shape_)); - hash_combine(result, static_cast<int>(kind_)); - switch (kind_) { - case InputType::EXACT_TYPE: - hash_combine(result, type_->Hash()); - break; - default: - break; - } - return result; -} - -std::string InputType::ToString() const { - std::stringstream ss; - switch (shape_) { - case ValueDescr::ANY: - ss << "any"; - break; - case ValueDescr::ARRAY: - ss << "array"; - break; - case ValueDescr::SCALAR: - ss << "scalar"; - break; - default: - DCHECK(false); - break; - } - ss << "["; - switch (kind_) { - case InputType::ANY_TYPE: - ss << "any"; - break; - case InputType::EXACT_TYPE: - ss << type_->ToString(); - break; - case InputType::USE_TYPE_MATCHER: { - ss << type_matcher_->ToString(); - } break; - default: - DCHECK(false); - break; - } - ss << "]"; - return ss.str(); -} - -bool InputType::Equals(const InputType& other) const { - if (this == &other) { - return true; - } - if (kind_ != other.kind_ || shape_ != other.shape_) { - return false; - } - switch (kind_) { - case InputType::ANY_TYPE: - return true; - case InputType::EXACT_TYPE: - return type_->Equals(*other.type_); - case InputType::USE_TYPE_MATCHER: - return type_matcher_->Equals(*other.type_matcher_); - default: - return false; - } -} - -bool InputType::Matches(const ValueDescr& descr) const { - if (shape_ != ValueDescr::ANY && descr.shape != shape_) { - return false; - } - switch (kind_) { - case InputType::EXACT_TYPE: - return type_->Equals(*descr.type); - case InputType::USE_TYPE_MATCHER: - return type_matcher_->Matches(*descr.type); - default: - // ANY_TYPE - return true; - } -} - -bool InputType::Matches(const Datum& value) const { return Matches(value.descr()); } - -const std::shared_ptr<DataType>& InputType::type() const { - DCHECK_EQ(InputType::EXACT_TYPE, kind_); - return type_; -} - -const TypeMatcher& InputType::type_matcher() const { - DCHECK_EQ(InputType::USE_TYPE_MATCHER, kind_); - return *type_matcher_; -} - -// ---------------------------------------------------------------------- -// OutputType - -OutputType::OutputType(ValueDescr descr) : OutputType(descr.type) { - shape_ = descr.shape; -} - -Result<ValueDescr> OutputType::Resolve(KernelContext* ctx, - const std::vector<ValueDescr>& args) const { - ValueDescr::Shape broadcasted_shape = GetBroadcastShape(args); - if (kind_ == OutputType::FIXED) { - return ValueDescr(type_, shape_ == ValueDescr::ANY ? broadcasted_shape : shape_); - } else { - ARROW_ASSIGN_OR_RAISE(ValueDescr resolved_descr, resolver_(ctx, args)); - if (resolved_descr.shape == ValueDescr::ANY) { - resolved_descr.shape = broadcasted_shape; - } - return resolved_descr; - } -} - -const std::shared_ptr<DataType>& OutputType::type() const { - DCHECK_EQ(FIXED, kind_); - return type_; -} - -const OutputType::Resolver& OutputType::resolver() const { - DCHECK_EQ(COMPUTED, kind_); - return resolver_; -} - -std::string OutputType::ToString() const { - if (kind_ == OutputType::FIXED) { - return type_->ToString(); - } else { - return "computed"; - } -} - -// ---------------------------------------------------------------------- -// KernelSignature - -KernelSignature::KernelSignature(std::vector<InputType> in_types, OutputType out_type, - bool is_varargs) - : in_types_(std::move(in_types)), - out_type_(std::move(out_type)), - is_varargs_(is_varargs), - hash_code_(0) { + +// ---------------------------------------------------------------------- +// Some basic TypeMatcher implementations + +namespace match { + +class SameTypeIdMatcher : public TypeMatcher { + public: + explicit SameTypeIdMatcher(Type::type accepted_id) : accepted_id_(accepted_id) {} + + bool Matches(const DataType& type) const override { return type.id() == accepted_id_; } + + std::string ToString() const override { + std::stringstream ss; + ss << "Type::" << ::arrow::internal::ToString(accepted_id_); + return ss.str(); + } + + bool Equals(const TypeMatcher& other) const override { + if (this == &other) { + return true; + } + auto casted = dynamic_cast<const SameTypeIdMatcher*>(&other); + if (casted == nullptr) { + return false; + } + return this->accepted_id_ == casted->accepted_id_; + } + + private: + Type::type accepted_id_; +}; + +std::shared_ptr<TypeMatcher> SameTypeId(Type::type type_id) { + return std::make_shared<SameTypeIdMatcher>(type_id); +} + +template <typename ArrowType> +class TimeUnitMatcher : public TypeMatcher { + using ThisType = TimeUnitMatcher<ArrowType>; + + public: + explicit TimeUnitMatcher(TimeUnit::type accepted_unit) + : accepted_unit_(accepted_unit) {} + + bool Matches(const DataType& type) const override { + if (type.id() != ArrowType::type_id) { + return false; + } + const auto& time_type = checked_cast<const ArrowType&>(type); + return time_type.unit() == accepted_unit_; + } + + bool Equals(const TypeMatcher& other) const override { + if (this == &other) { + return true; + } + auto casted = dynamic_cast<const ThisType*>(&other); + if (casted == nullptr) { + return false; + } + return this->accepted_unit_ == casted->accepted_unit_; + } + + std::string ToString() const override { + std::stringstream ss; + ss << ArrowType::type_name() << "(" << ::arrow::internal::ToString(accepted_unit_) + << ")"; + return ss.str(); + } + + private: + TimeUnit::type accepted_unit_; +}; + +using DurationTypeUnitMatcher = TimeUnitMatcher<DurationType>; +using Time32TypeUnitMatcher = TimeUnitMatcher<Time32Type>; +using Time64TypeUnitMatcher = TimeUnitMatcher<Time64Type>; +using TimestampTypeUnitMatcher = TimeUnitMatcher<TimestampType>; + +std::shared_ptr<TypeMatcher> TimestampTypeUnit(TimeUnit::type unit) { + return std::make_shared<TimestampTypeUnitMatcher>(unit); +} + +std::shared_ptr<TypeMatcher> Time32TypeUnit(TimeUnit::type unit) { + return std::make_shared<Time32TypeUnitMatcher>(unit); +} + +std::shared_ptr<TypeMatcher> Time64TypeUnit(TimeUnit::type unit) { + return std::make_shared<Time64TypeUnitMatcher>(unit); +} + +std::shared_ptr<TypeMatcher> DurationTypeUnit(TimeUnit::type unit) { + return std::make_shared<DurationTypeUnitMatcher>(unit); +} + +class IntegerMatcher : public TypeMatcher { + public: + IntegerMatcher() {} + + bool Matches(const DataType& type) const override { return is_integer(type.id()); } + + bool Equals(const TypeMatcher& other) const override { + if (this == &other) { + return true; + } + auto casted = dynamic_cast<const IntegerMatcher*>(&other); + return casted != nullptr; + } + + std::string ToString() const override { return "integer"; } +}; + +std::shared_ptr<TypeMatcher> Integer() { return std::make_shared<IntegerMatcher>(); } + +class PrimitiveMatcher : public TypeMatcher { + public: + PrimitiveMatcher() {} + + bool Matches(const DataType& type) const override { return is_primitive(type.id()); } + + bool Equals(const TypeMatcher& other) const override { + if (this == &other) { + return true; + } + auto casted = dynamic_cast<const PrimitiveMatcher*>(&other); + return casted != nullptr; + } + + std::string ToString() const override { return "primitive"; } +}; + +std::shared_ptr<TypeMatcher> Primitive() { return std::make_shared<PrimitiveMatcher>(); } + +class BinaryLikeMatcher : public TypeMatcher { + public: + BinaryLikeMatcher() {} + + bool Matches(const DataType& type) const override { return is_binary_like(type.id()); } + + bool Equals(const TypeMatcher& other) const override { + if (this == &other) { + return true; + } + auto casted = dynamic_cast<const BinaryLikeMatcher*>(&other); + return casted != nullptr; + } + std::string ToString() const override { return "binary-like"; } +}; + +std::shared_ptr<TypeMatcher> BinaryLike() { + return std::make_shared<BinaryLikeMatcher>(); +} + +class LargeBinaryLikeMatcher : public TypeMatcher { + public: + LargeBinaryLikeMatcher() {} + + bool Matches(const DataType& type) const override { + return is_large_binary_like(type.id()); + } + + bool Equals(const TypeMatcher& other) const override { + if (this == &other) { + return true; + } + auto casted = dynamic_cast<const LargeBinaryLikeMatcher*>(&other); + return casted != nullptr; + } + std::string ToString() const override { return "large-binary-like"; } +}; + +std::shared_ptr<TypeMatcher> LargeBinaryLike() { + return std::make_shared<LargeBinaryLikeMatcher>(); +} + +} // namespace match + +// ---------------------------------------------------------------------- +// InputType + +size_t InputType::Hash() const { + size_t result = kHashSeed; + hash_combine(result, static_cast<int>(shape_)); + hash_combine(result, static_cast<int>(kind_)); + switch (kind_) { + case InputType::EXACT_TYPE: + hash_combine(result, type_->Hash()); + break; + default: + break; + } + return result; +} + +std::string InputType::ToString() const { + std::stringstream ss; + switch (shape_) { + case ValueDescr::ANY: + ss << "any"; + break; + case ValueDescr::ARRAY: + ss << "array"; + break; + case ValueDescr::SCALAR: + ss << "scalar"; + break; + default: + DCHECK(false); + break; + } + ss << "["; + switch (kind_) { + case InputType::ANY_TYPE: + ss << "any"; + break; + case InputType::EXACT_TYPE: + ss << type_->ToString(); + break; + case InputType::USE_TYPE_MATCHER: { + ss << type_matcher_->ToString(); + } break; + default: + DCHECK(false); + break; + } + ss << "]"; + return ss.str(); +} + +bool InputType::Equals(const InputType& other) const { + if (this == &other) { + return true; + } + if (kind_ != other.kind_ || shape_ != other.shape_) { + return false; + } + switch (kind_) { + case InputType::ANY_TYPE: + return true; + case InputType::EXACT_TYPE: + return type_->Equals(*other.type_); + case InputType::USE_TYPE_MATCHER: + return type_matcher_->Equals(*other.type_matcher_); + default: + return false; + } +} + +bool InputType::Matches(const ValueDescr& descr) const { + if (shape_ != ValueDescr::ANY && descr.shape != shape_) { + return false; + } + switch (kind_) { + case InputType::EXACT_TYPE: + return type_->Equals(*descr.type); + case InputType::USE_TYPE_MATCHER: + return type_matcher_->Matches(*descr.type); + default: + // ANY_TYPE + return true; + } +} + +bool InputType::Matches(const Datum& value) const { return Matches(value.descr()); } + +const std::shared_ptr<DataType>& InputType::type() const { + DCHECK_EQ(InputType::EXACT_TYPE, kind_); + return type_; +} + +const TypeMatcher& InputType::type_matcher() const { + DCHECK_EQ(InputType::USE_TYPE_MATCHER, kind_); + return *type_matcher_; +} + +// ---------------------------------------------------------------------- +// OutputType + +OutputType::OutputType(ValueDescr descr) : OutputType(descr.type) { + shape_ = descr.shape; +} + +Result<ValueDescr> OutputType::Resolve(KernelContext* ctx, + const std::vector<ValueDescr>& args) const { + ValueDescr::Shape broadcasted_shape = GetBroadcastShape(args); + if (kind_ == OutputType::FIXED) { + return ValueDescr(type_, shape_ == ValueDescr::ANY ? broadcasted_shape : shape_); + } else { + ARROW_ASSIGN_OR_RAISE(ValueDescr resolved_descr, resolver_(ctx, args)); + if (resolved_descr.shape == ValueDescr::ANY) { + resolved_descr.shape = broadcasted_shape; + } + return resolved_descr; + } +} + +const std::shared_ptr<DataType>& OutputType::type() const { + DCHECK_EQ(FIXED, kind_); + return type_; +} + +const OutputType::Resolver& OutputType::resolver() const { + DCHECK_EQ(COMPUTED, kind_); + return resolver_; +} + +std::string OutputType::ToString() const { + if (kind_ == OutputType::FIXED) { + return type_->ToString(); + } else { + return "computed"; + } +} + +// ---------------------------------------------------------------------- +// KernelSignature + +KernelSignature::KernelSignature(std::vector<InputType> in_types, OutputType out_type, + bool is_varargs) + : in_types_(std::move(in_types)), + out_type_(std::move(out_type)), + is_varargs_(is_varargs), + hash_code_(0) { DCHECK(!is_varargs || (is_varargs && (in_types_.size() >= 1))); -} - -std::shared_ptr<KernelSignature> KernelSignature::Make(std::vector<InputType> in_types, - OutputType out_type, - bool is_varargs) { - return std::make_shared<KernelSignature>(std::move(in_types), std::move(out_type), - is_varargs); -} - -bool KernelSignature::Equals(const KernelSignature& other) const { - if (is_varargs_ != other.is_varargs_) { - return false; - } - if (in_types_.size() != other.in_types_.size()) { - return false; - } - for (size_t i = 0; i < in_types_.size(); ++i) { - if (!in_types_[i].Equals(other.in_types_[i])) { - return false; - } - } - return true; -} - -bool KernelSignature::MatchesInputs(const std::vector<ValueDescr>& args) const { - if (is_varargs_) { +} + +std::shared_ptr<KernelSignature> KernelSignature::Make(std::vector<InputType> in_types, + OutputType out_type, + bool is_varargs) { + return std::make_shared<KernelSignature>(std::move(in_types), std::move(out_type), + is_varargs); +} + +bool KernelSignature::Equals(const KernelSignature& other) const { + if (is_varargs_ != other.is_varargs_) { + return false; + } + if (in_types_.size() != other.in_types_.size()) { + return false; + } + for (size_t i = 0; i < in_types_.size(); ++i) { + if (!in_types_[i].Equals(other.in_types_[i])) { + return false; + } + } + return true; +} + +bool KernelSignature::MatchesInputs(const std::vector<ValueDescr>& args) const { + if (is_varargs_) { for (size_t i = 0; i < args.size(); ++i) { if (!in_types_[std::min(i, in_types_.size() - 1)].Matches(args[i])) { - return false; - } - } - } else { - if (args.size() != in_types_.size()) { - return false; - } - for (size_t i = 0; i < in_types_.size(); ++i) { - if (!in_types_[i].Matches(args[i])) { - return false; - } - } - } - return true; -} - -size_t KernelSignature::Hash() const { - if (hash_code_ != 0) { - return hash_code_; - } - size_t result = kHashSeed; - for (const auto& in_type : in_types_) { - hash_combine(result, in_type.Hash()); - } - hash_code_ = result; - return result; -} - -std::string KernelSignature::ToString() const { - std::stringstream ss; - - if (is_varargs_) { + return false; + } + } + } else { + if (args.size() != in_types_.size()) { + return false; + } + for (size_t i = 0; i < in_types_.size(); ++i) { + if (!in_types_[i].Matches(args[i])) { + return false; + } + } + } + return true; +} + +size_t KernelSignature::Hash() const { + if (hash_code_ != 0) { + return hash_code_; + } + size_t result = kHashSeed; + for (const auto& in_type : in_types_) { + hash_combine(result, in_type.Hash()); + } + hash_code_ = result; + return result; +} + +std::string KernelSignature::ToString() const { + std::stringstream ss; + + if (is_varargs_) { ss << "varargs["; - } else { - ss << "("; + } else { + ss << "("; } for (size_t i = 0; i < in_types_.size(); ++i) { if (i > 0) { ss << ", "; - } + } ss << in_types_[i].ToString(); } if (is_varargs_) { ss << "]"; } else { - ss << ")"; - } - ss << " -> " << out_type_.ToString(); - return ss.str(); -} - -} // namespace compute -} // namespace arrow + ss << ")"; + } + ss << " -> " << out_type_.ToString(); + return ss.str(); +} + +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernel.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernel.h index 36d20c7289..6cea5558e9 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernel.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernel.h @@ -1,695 +1,695 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// NOTE: API is EXPERIMENTAL and will change without going through a -// deprecation cycle - -#pragma once - -#include <cstddef> -#include <cstdint> -#include <functional> -#include <memory> -#include <string> -#include <utility> -#include <vector> - -#include "arrow/buffer.h" -#include "arrow/compute/exec.h" -#include "arrow/datum.h" -#include "arrow/memory_pool.h" -#include "arrow/result.h" -#include "arrow/status.h" -#include "arrow/type.h" -#include "arrow/util/macros.h" -#include "arrow/util/visibility.h" - -namespace arrow { -namespace compute { - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + +#pragma once + +#include <cstddef> +#include <cstdint> +#include <functional> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "arrow/buffer.h" +#include "arrow/compute/exec.h" +#include "arrow/datum.h" +#include "arrow/memory_pool.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + class FunctionOptions; - -/// \brief Base class for opaque kernel-specific state. For example, if there -/// is some kind of initialization required. -struct ARROW_EXPORT KernelState { - virtual ~KernelState() = default; -}; - -/// \brief Context/state for the execution of a particular kernel. -class ARROW_EXPORT KernelContext { - public: + +/// \brief Base class for opaque kernel-specific state. For example, if there +/// is some kind of initialization required. +struct ARROW_EXPORT KernelState { + virtual ~KernelState() = default; +}; + +/// \brief Context/state for the execution of a particular kernel. +class ARROW_EXPORT KernelContext { + public: explicit KernelContext(ExecContext* exec_ctx) : exec_ctx_(exec_ctx), state_() {} - - /// \brief Allocate buffer from the context's memory pool. The contents are - /// not initialized. - Result<std::shared_ptr<ResizableBuffer>> Allocate(int64_t nbytes); - - /// \brief Allocate buffer for bitmap from the context's memory pool. Like - /// Allocate, the contents of the buffer are not initialized but the last - /// byte is preemptively zeroed to help avoid ASAN or valgrind issues. - Result<std::shared_ptr<ResizableBuffer>> AllocateBitmap(int64_t num_bits); - - /// \brief Assign the active KernelState to be utilized for each stage of - /// kernel execution. Ownership and memory lifetime of the KernelState must - /// be minded separately. - void SetState(KernelState* state) { state_ = state; } - - KernelState* state() { return state_; } - - /// \brief Configuration related to function execution that is to be shared - /// across multiple kernels. - ExecContext* exec_context() { return exec_ctx_; } - - /// \brief The memory pool to use for allocations. For now, it uses the - /// MemoryPool contained in the ExecContext used to create the KernelContext. - MemoryPool* memory_pool() { return exec_ctx_->memory_pool(); } - - private: - ExecContext* exec_ctx_; - KernelState* state_; -}; - -/// \brief The standard kernel execution API that must be implemented for -/// SCALAR and VECTOR kernel types. This includes both stateless and stateful -/// kernels. Kernels depending on some execution state access that state via -/// subclasses of KernelState set on the KernelContext object. May be used for -/// SCALAR and VECTOR kernel kinds. Implementations should endeavor to write -/// into pre-allocated memory if they are able, though for some kernels -/// (e.g. in cases when a builder like StringBuilder) must be employed this may -/// not be possible. + + /// \brief Allocate buffer from the context's memory pool. The contents are + /// not initialized. + Result<std::shared_ptr<ResizableBuffer>> Allocate(int64_t nbytes); + + /// \brief Allocate buffer for bitmap from the context's memory pool. Like + /// Allocate, the contents of the buffer are not initialized but the last + /// byte is preemptively zeroed to help avoid ASAN or valgrind issues. + Result<std::shared_ptr<ResizableBuffer>> AllocateBitmap(int64_t num_bits); + + /// \brief Assign the active KernelState to be utilized for each stage of + /// kernel execution. Ownership and memory lifetime of the KernelState must + /// be minded separately. + void SetState(KernelState* state) { state_ = state; } + + KernelState* state() { return state_; } + + /// \brief Configuration related to function execution that is to be shared + /// across multiple kernels. + ExecContext* exec_context() { return exec_ctx_; } + + /// \brief The memory pool to use for allocations. For now, it uses the + /// MemoryPool contained in the ExecContext used to create the KernelContext. + MemoryPool* memory_pool() { return exec_ctx_->memory_pool(); } + + private: + ExecContext* exec_ctx_; + KernelState* state_; +}; + +/// \brief The standard kernel execution API that must be implemented for +/// SCALAR and VECTOR kernel types. This includes both stateless and stateful +/// kernels. Kernels depending on some execution state access that state via +/// subclasses of KernelState set on the KernelContext object. May be used for +/// SCALAR and VECTOR kernel kinds. Implementations should endeavor to write +/// into pre-allocated memory if they are able, though for some kernels +/// (e.g. in cases when a builder like StringBuilder) must be employed this may +/// not be possible. using ArrayKernelExec = std::function<Status(KernelContext*, const ExecBatch&, Datum*)>; - -/// \brief An type-checking interface to permit customizable validation rules -/// for use with InputType and KernelSignature. This is for scenarios where the -/// acceptance is not an exact type instance, such as a TIMESTAMP type for a -/// specific TimeUnit, but permitting any time zone. -struct ARROW_EXPORT TypeMatcher { - virtual ~TypeMatcher() = default; - - /// \brief Return true if this matcher accepts the data type. - virtual bool Matches(const DataType& type) const = 0; - - /// \brief A human-interpretable string representation of what the type - /// matcher checks for, usable when printing KernelSignature or formatting - /// error messages. - virtual std::string ToString() const = 0; - - /// \brief Return true if this TypeMatcher contains the same matching rule as - /// the other. Currently depends on RTTI. - virtual bool Equals(const TypeMatcher& other) const = 0; -}; - -namespace match { - -/// \brief Match any DataType instance having the same DataType::id. -ARROW_EXPORT std::shared_ptr<TypeMatcher> SameTypeId(Type::type type_id); - -/// \brief Match any TimestampType instance having the same unit, but the time -/// zones can be different. -ARROW_EXPORT std::shared_ptr<TypeMatcher> TimestampTypeUnit(TimeUnit::type unit); -ARROW_EXPORT std::shared_ptr<TypeMatcher> Time32TypeUnit(TimeUnit::type unit); -ARROW_EXPORT std::shared_ptr<TypeMatcher> Time64TypeUnit(TimeUnit::type unit); -ARROW_EXPORT std::shared_ptr<TypeMatcher> DurationTypeUnit(TimeUnit::type unit); - -// \brief Match any integer type -ARROW_EXPORT std::shared_ptr<TypeMatcher> Integer(); - -// Match types using 32-bit varbinary representation -ARROW_EXPORT std::shared_ptr<TypeMatcher> BinaryLike(); - -// Match types using 64-bit varbinary representation -ARROW_EXPORT std::shared_ptr<TypeMatcher> LargeBinaryLike(); - -// \brief Match any primitive type (boolean or any type representable as a C -// Type) -ARROW_EXPORT std::shared_ptr<TypeMatcher> Primitive(); - -} // namespace match - -/// \brief An object used for type- and shape-checking arguments to be passed -/// to a kernel and stored in a KernelSignature. Distinguishes between ARRAY -/// and SCALAR arguments using ValueDescr::Shape. The type-checking rule can be -/// supplied either with an exact DataType instance or a custom TypeMatcher. -class ARROW_EXPORT InputType { - public: - /// \brief The kind of type-checking rule that the InputType contains. - enum Kind { - /// \brief Accept any value type. - ANY_TYPE, - - /// \brief A fixed arrow::DataType and will only exact match having this - /// exact type (e.g. same TimestampType unit, same decimal scale and - /// precision, or same nested child types). - EXACT_TYPE, - - /// \brief Uses a TypeMatcher implementation to check the type. - USE_TYPE_MATCHER - }; - - /// \brief Accept any value type but with a specific shape (e.g. any Array or - /// any Scalar). - InputType(ValueDescr::Shape shape = ValueDescr::ANY) // NOLINT implicit construction - : kind_(ANY_TYPE), shape_(shape) {} - - /// \brief Accept an exact value type. - InputType(std::shared_ptr<DataType> type, // NOLINT implicit construction - ValueDescr::Shape shape = ValueDescr::ANY) - : kind_(EXACT_TYPE), shape_(shape), type_(std::move(type)) {} - - /// \brief Accept an exact value type and shape provided by a ValueDescr. - InputType(const ValueDescr& descr) // NOLINT implicit construction - : InputType(descr.type, descr.shape) {} - - /// \brief Use the passed TypeMatcher to type check. - InputType(std::shared_ptr<TypeMatcher> type_matcher, // NOLINT implicit construction - ValueDescr::Shape shape = ValueDescr::ANY) - : kind_(USE_TYPE_MATCHER), shape_(shape), type_matcher_(std::move(type_matcher)) {} - - /// \brief Match any type with the given Type::type. Uses a TypeMatcher for - /// its implementation. - explicit InputType(Type::type type_id, ValueDescr::Shape shape = ValueDescr::ANY) - : InputType(match::SameTypeId(type_id), shape) {} - - InputType(const InputType& other) { CopyInto(other); } - - void operator=(const InputType& other) { CopyInto(other); } - - InputType(InputType&& other) { MoveInto(std::forward<InputType>(other)); } - - void operator=(InputType&& other) { MoveInto(std::forward<InputType>(other)); } - - // \brief Match an array with the given exact type. Convenience constructor. - static InputType Array(std::shared_ptr<DataType> type) { - return InputType(std::move(type), ValueDescr::ARRAY); - } - - // \brief Match a scalar with the given exact type. Convenience constructor. - static InputType Scalar(std::shared_ptr<DataType> type) { - return InputType(std::move(type), ValueDescr::SCALAR); - } - - // \brief Match an array with the given Type::type id. Convenience - // constructor. - static InputType Array(Type::type id) { return InputType(id, ValueDescr::ARRAY); } - - // \brief Match a scalar with the given Type::type id. Convenience - // constructor. - static InputType Scalar(Type::type id) { return InputType(id, ValueDescr::SCALAR); } - - /// \brief Return true if this input type matches the same type cases as the - /// other. - bool Equals(const InputType& other) const; - - bool operator==(const InputType& other) const { return this->Equals(other); } - - bool operator!=(const InputType& other) const { return !(*this == other); } - - /// \brief Return hash code. - size_t Hash() const; - - /// \brief Render a human-readable string representation. - std::string ToString() const; - - /// \brief Return true if the value matches this argument kind in type - /// and shape. - bool Matches(const Datum& value) const; - - /// \brief Return true if the value descriptor matches this argument kind in - /// type and shape. - bool Matches(const ValueDescr& value) const; - - /// \brief The type matching rule that this InputType uses. - Kind kind() const { return kind_; } - - /// \brief Indicates whether this InputType matches Array (ValueDescr::ARRAY), - /// Scalar (ValueDescr::SCALAR) values, or both (ValueDescr::ANY). - ValueDescr::Shape shape() const { return shape_; } - - /// \brief For InputType::EXACT_TYPE kind, the exact type that this InputType - /// must match. Otherwise this function should not be used and will assert in - /// debug builds. - const std::shared_ptr<DataType>& type() const; - - /// \brief For InputType::USE_TYPE_MATCHER, the TypeMatcher to be used for - /// checking the type of a value. Otherwise this function should not be used - /// and will assert in debug builds. - const TypeMatcher& type_matcher() const; - - private: - void CopyInto(const InputType& other) { - this->kind_ = other.kind_; - this->shape_ = other.shape_; - this->type_ = other.type_; - this->type_matcher_ = other.type_matcher_; - } - - void MoveInto(InputType&& other) { - this->kind_ = other.kind_; - this->shape_ = other.shape_; - this->type_ = std::move(other.type_); - this->type_matcher_ = std::move(other.type_matcher_); - } - - Kind kind_; - - ValueDescr::Shape shape_ = ValueDescr::ANY; - - // For EXACT_TYPE Kind - std::shared_ptr<DataType> type_; - - // For USE_TYPE_MATCHER Kind - std::shared_ptr<TypeMatcher> type_matcher_; -}; - -/// \brief Container to capture both exact and input-dependent output types. -/// -/// The value shape returned by Resolve will be determined by broadcasting the -/// shapes of the input arguments, otherwise this is handled by the -/// user-defined resolver function: -/// -/// * Any ARRAY shape -> output shape is ARRAY -/// * All SCALAR shapes -> output shape is SCALAR -class ARROW_EXPORT OutputType { - public: - /// \brief An enum indicating whether the value type is an invariant fixed - /// value or one that's computed by a kernel-defined resolver function. - enum ResolveKind { FIXED, COMPUTED }; - - /// Type resolution function. Given input types and shapes, return output - /// type and shape. This function SHOULD _not_ be used to check for arity, - /// that is to be performed one or more layers above. May make use of kernel - /// state to know what type to output in some cases. - using Resolver = - std::function<Result<ValueDescr>(KernelContext*, const std::vector<ValueDescr>&)>; - - /// \brief Output an exact type, but with shape determined by promoting the - /// shapes of the inputs (any ARRAY argument yields ARRAY). - OutputType(std::shared_ptr<DataType> type) // NOLINT implicit construction - : kind_(FIXED), type_(std::move(type)) {} - - /// \brief Output the exact type and shape provided by a ValueDescr - OutputType(ValueDescr descr); // NOLINT implicit construction - - explicit OutputType(Resolver resolver) - : kind_(COMPUTED), resolver_(std::move(resolver)) {} - - OutputType(const OutputType& other) { - this->kind_ = other.kind_; - this->shape_ = other.shape_; - this->type_ = other.type_; - this->resolver_ = other.resolver_; - } - - OutputType(OutputType&& other) { - this->kind_ = other.kind_; - this->type_ = std::move(other.type_); - this->shape_ = other.shape_; - this->resolver_ = other.resolver_; - } - + +/// \brief An type-checking interface to permit customizable validation rules +/// for use with InputType and KernelSignature. This is for scenarios where the +/// acceptance is not an exact type instance, such as a TIMESTAMP type for a +/// specific TimeUnit, but permitting any time zone. +struct ARROW_EXPORT TypeMatcher { + virtual ~TypeMatcher() = default; + + /// \brief Return true if this matcher accepts the data type. + virtual bool Matches(const DataType& type) const = 0; + + /// \brief A human-interpretable string representation of what the type + /// matcher checks for, usable when printing KernelSignature or formatting + /// error messages. + virtual std::string ToString() const = 0; + + /// \brief Return true if this TypeMatcher contains the same matching rule as + /// the other. Currently depends on RTTI. + virtual bool Equals(const TypeMatcher& other) const = 0; +}; + +namespace match { + +/// \brief Match any DataType instance having the same DataType::id. +ARROW_EXPORT std::shared_ptr<TypeMatcher> SameTypeId(Type::type type_id); + +/// \brief Match any TimestampType instance having the same unit, but the time +/// zones can be different. +ARROW_EXPORT std::shared_ptr<TypeMatcher> TimestampTypeUnit(TimeUnit::type unit); +ARROW_EXPORT std::shared_ptr<TypeMatcher> Time32TypeUnit(TimeUnit::type unit); +ARROW_EXPORT std::shared_ptr<TypeMatcher> Time64TypeUnit(TimeUnit::type unit); +ARROW_EXPORT std::shared_ptr<TypeMatcher> DurationTypeUnit(TimeUnit::type unit); + +// \brief Match any integer type +ARROW_EXPORT std::shared_ptr<TypeMatcher> Integer(); + +// Match types using 32-bit varbinary representation +ARROW_EXPORT std::shared_ptr<TypeMatcher> BinaryLike(); + +// Match types using 64-bit varbinary representation +ARROW_EXPORT std::shared_ptr<TypeMatcher> LargeBinaryLike(); + +// \brief Match any primitive type (boolean or any type representable as a C +// Type) +ARROW_EXPORT std::shared_ptr<TypeMatcher> Primitive(); + +} // namespace match + +/// \brief An object used for type- and shape-checking arguments to be passed +/// to a kernel and stored in a KernelSignature. Distinguishes between ARRAY +/// and SCALAR arguments using ValueDescr::Shape. The type-checking rule can be +/// supplied either with an exact DataType instance or a custom TypeMatcher. +class ARROW_EXPORT InputType { + public: + /// \brief The kind of type-checking rule that the InputType contains. + enum Kind { + /// \brief Accept any value type. + ANY_TYPE, + + /// \brief A fixed arrow::DataType and will only exact match having this + /// exact type (e.g. same TimestampType unit, same decimal scale and + /// precision, or same nested child types). + EXACT_TYPE, + + /// \brief Uses a TypeMatcher implementation to check the type. + USE_TYPE_MATCHER + }; + + /// \brief Accept any value type but with a specific shape (e.g. any Array or + /// any Scalar). + InputType(ValueDescr::Shape shape = ValueDescr::ANY) // NOLINT implicit construction + : kind_(ANY_TYPE), shape_(shape) {} + + /// \brief Accept an exact value type. + InputType(std::shared_ptr<DataType> type, // NOLINT implicit construction + ValueDescr::Shape shape = ValueDescr::ANY) + : kind_(EXACT_TYPE), shape_(shape), type_(std::move(type)) {} + + /// \brief Accept an exact value type and shape provided by a ValueDescr. + InputType(const ValueDescr& descr) // NOLINT implicit construction + : InputType(descr.type, descr.shape) {} + + /// \brief Use the passed TypeMatcher to type check. + InputType(std::shared_ptr<TypeMatcher> type_matcher, // NOLINT implicit construction + ValueDescr::Shape shape = ValueDescr::ANY) + : kind_(USE_TYPE_MATCHER), shape_(shape), type_matcher_(std::move(type_matcher)) {} + + /// \brief Match any type with the given Type::type. Uses a TypeMatcher for + /// its implementation. + explicit InputType(Type::type type_id, ValueDescr::Shape shape = ValueDescr::ANY) + : InputType(match::SameTypeId(type_id), shape) {} + + InputType(const InputType& other) { CopyInto(other); } + + void operator=(const InputType& other) { CopyInto(other); } + + InputType(InputType&& other) { MoveInto(std::forward<InputType>(other)); } + + void operator=(InputType&& other) { MoveInto(std::forward<InputType>(other)); } + + // \brief Match an array with the given exact type. Convenience constructor. + static InputType Array(std::shared_ptr<DataType> type) { + return InputType(std::move(type), ValueDescr::ARRAY); + } + + // \brief Match a scalar with the given exact type. Convenience constructor. + static InputType Scalar(std::shared_ptr<DataType> type) { + return InputType(std::move(type), ValueDescr::SCALAR); + } + + // \brief Match an array with the given Type::type id. Convenience + // constructor. + static InputType Array(Type::type id) { return InputType(id, ValueDescr::ARRAY); } + + // \brief Match a scalar with the given Type::type id. Convenience + // constructor. + static InputType Scalar(Type::type id) { return InputType(id, ValueDescr::SCALAR); } + + /// \brief Return true if this input type matches the same type cases as the + /// other. + bool Equals(const InputType& other) const; + + bool operator==(const InputType& other) const { return this->Equals(other); } + + bool operator!=(const InputType& other) const { return !(*this == other); } + + /// \brief Return hash code. + size_t Hash() const; + + /// \brief Render a human-readable string representation. + std::string ToString() const; + + /// \brief Return true if the value matches this argument kind in type + /// and shape. + bool Matches(const Datum& value) const; + + /// \brief Return true if the value descriptor matches this argument kind in + /// type and shape. + bool Matches(const ValueDescr& value) const; + + /// \brief The type matching rule that this InputType uses. + Kind kind() const { return kind_; } + + /// \brief Indicates whether this InputType matches Array (ValueDescr::ARRAY), + /// Scalar (ValueDescr::SCALAR) values, or both (ValueDescr::ANY). + ValueDescr::Shape shape() const { return shape_; } + + /// \brief For InputType::EXACT_TYPE kind, the exact type that this InputType + /// must match. Otherwise this function should not be used and will assert in + /// debug builds. + const std::shared_ptr<DataType>& type() const; + + /// \brief For InputType::USE_TYPE_MATCHER, the TypeMatcher to be used for + /// checking the type of a value. Otherwise this function should not be used + /// and will assert in debug builds. + const TypeMatcher& type_matcher() const; + + private: + void CopyInto(const InputType& other) { + this->kind_ = other.kind_; + this->shape_ = other.shape_; + this->type_ = other.type_; + this->type_matcher_ = other.type_matcher_; + } + + void MoveInto(InputType&& other) { + this->kind_ = other.kind_; + this->shape_ = other.shape_; + this->type_ = std::move(other.type_); + this->type_matcher_ = std::move(other.type_matcher_); + } + + Kind kind_; + + ValueDescr::Shape shape_ = ValueDescr::ANY; + + // For EXACT_TYPE Kind + std::shared_ptr<DataType> type_; + + // For USE_TYPE_MATCHER Kind + std::shared_ptr<TypeMatcher> type_matcher_; +}; + +/// \brief Container to capture both exact and input-dependent output types. +/// +/// The value shape returned by Resolve will be determined by broadcasting the +/// shapes of the input arguments, otherwise this is handled by the +/// user-defined resolver function: +/// +/// * Any ARRAY shape -> output shape is ARRAY +/// * All SCALAR shapes -> output shape is SCALAR +class ARROW_EXPORT OutputType { + public: + /// \brief An enum indicating whether the value type is an invariant fixed + /// value or one that's computed by a kernel-defined resolver function. + enum ResolveKind { FIXED, COMPUTED }; + + /// Type resolution function. Given input types and shapes, return output + /// type and shape. This function SHOULD _not_ be used to check for arity, + /// that is to be performed one or more layers above. May make use of kernel + /// state to know what type to output in some cases. + using Resolver = + std::function<Result<ValueDescr>(KernelContext*, const std::vector<ValueDescr>&)>; + + /// \brief Output an exact type, but with shape determined by promoting the + /// shapes of the inputs (any ARRAY argument yields ARRAY). + OutputType(std::shared_ptr<DataType> type) // NOLINT implicit construction + : kind_(FIXED), type_(std::move(type)) {} + + /// \brief Output the exact type and shape provided by a ValueDescr + OutputType(ValueDescr descr); // NOLINT implicit construction + + explicit OutputType(Resolver resolver) + : kind_(COMPUTED), resolver_(std::move(resolver)) {} + + OutputType(const OutputType& other) { + this->kind_ = other.kind_; + this->shape_ = other.shape_; + this->type_ = other.type_; + this->resolver_ = other.resolver_; + } + + OutputType(OutputType&& other) { + this->kind_ = other.kind_; + this->type_ = std::move(other.type_); + this->shape_ = other.shape_; + this->resolver_ = other.resolver_; + } + OutputType& operator=(const OutputType&) = default; OutputType& operator=(OutputType&&) = default; - /// \brief Return the shape and type of the expected output value of the - /// kernel given the value descriptors (shapes and types) of the input - /// arguments. The resolver may make use of state information kept in the - /// KernelContext. - Result<ValueDescr> Resolve(KernelContext* ctx, - const std::vector<ValueDescr>& args) const; - - /// \brief The exact output value type for the FIXED kind. - const std::shared_ptr<DataType>& type() const; - - /// \brief For use with COMPUTED resolution strategy. It may be more - /// convenient to invoke this with OutputType::Resolve returned from this - /// method. - const Resolver& resolver() const; - - /// \brief Render a human-readable string representation. - std::string ToString() const; - - /// \brief Return the kind of type resolution of this output type, whether - /// fixed/invariant or computed by a resolver. - ResolveKind kind() const { return kind_; } - - /// \brief If the shape is ANY, then Resolve will compute the shape based on - /// the input arguments. - ValueDescr::Shape shape() const { return shape_; } - - private: - ResolveKind kind_; - - // For FIXED resolution - std::shared_ptr<DataType> type_; - - /// \brief The shape of the output type to return when using Resolve. If ANY - /// will promote the input shapes. - ValueDescr::Shape shape_ = ValueDescr::ANY; - - // For COMPUTED resolution - Resolver resolver_; -}; - -/// \brief Holds the input types and output type of the kernel. -/// + /// \brief Return the shape and type of the expected output value of the + /// kernel given the value descriptors (shapes and types) of the input + /// arguments. The resolver may make use of state information kept in the + /// KernelContext. + Result<ValueDescr> Resolve(KernelContext* ctx, + const std::vector<ValueDescr>& args) const; + + /// \brief The exact output value type for the FIXED kind. + const std::shared_ptr<DataType>& type() const; + + /// \brief For use with COMPUTED resolution strategy. It may be more + /// convenient to invoke this with OutputType::Resolve returned from this + /// method. + const Resolver& resolver() const; + + /// \brief Render a human-readable string representation. + std::string ToString() const; + + /// \brief Return the kind of type resolution of this output type, whether + /// fixed/invariant or computed by a resolver. + ResolveKind kind() const { return kind_; } + + /// \brief If the shape is ANY, then Resolve will compute the shape based on + /// the input arguments. + ValueDescr::Shape shape() const { return shape_; } + + private: + ResolveKind kind_; + + // For FIXED resolution + std::shared_ptr<DataType> type_; + + /// \brief The shape of the output type to return when using Resolve. If ANY + /// will promote the input shapes. + ValueDescr::Shape shape_ = ValueDescr::ANY; + + // For COMPUTED resolution + Resolver resolver_; +}; + +/// \brief Holds the input types and output type of the kernel. +/// /// VarArgs functions with minimum N arguments should pass up to N input types to be /// used to validate the input types of a function invocation. The first N-1 types /// will be matched against the first N-1 arguments, and the last type will be /// matched against the remaining arguments. -class ARROW_EXPORT KernelSignature { - public: - KernelSignature(std::vector<InputType> in_types, OutputType out_type, - bool is_varargs = false); - - /// \brief Convenience ctor since make_shared can be awkward - static std::shared_ptr<KernelSignature> Make(std::vector<InputType> in_types, - OutputType out_type, - bool is_varargs = false); - - /// \brief Return true if the signature if compatible with the list of input - /// value descriptors. - bool MatchesInputs(const std::vector<ValueDescr>& descriptors) const; - - /// \brief Returns true if the input types of each signature are - /// equal. Well-formed functions should have a deterministic output type - /// given input types, but currently it is the responsibility of the - /// developer to ensure this. - bool Equals(const KernelSignature& other) const; - - bool operator==(const KernelSignature& other) const { return this->Equals(other); } - - bool operator!=(const KernelSignature& other) const { return !(*this == other); } - - /// \brief Compute a hash code for the signature - size_t Hash() const; - - /// \brief The input types for the kernel. For VarArgs functions, this should - /// generally contain a single validator to use for validating all of the - /// function arguments. - const std::vector<InputType>& in_types() const { return in_types_; } - - /// \brief The output type for the kernel. Use Resolve to return the exact - /// output given input argument ValueDescrs, since many kernels' output types - /// depend on their input types (or their type metadata). - const OutputType& out_type() const { return out_type_; } - - /// \brief Render a human-readable string representation - std::string ToString() const; - - bool is_varargs() const { return is_varargs_; } - - private: - std::vector<InputType> in_types_; - OutputType out_type_; - bool is_varargs_; - - // For caching the hash code after it's computed the first time - mutable uint64_t hash_code_; -}; - -/// \brief A function may contain multiple variants of a kernel for a given -/// type combination for different SIMD levels. Based on the active system's -/// CPU info or the user's preferences, we can elect to use one over the other. -struct SimdLevel { - enum type { NONE = 0, SSE4_2, AVX, AVX2, AVX512, NEON, MAX }; -}; - -/// \brief The strategy to use for propagating or otherwise populating the -/// validity bitmap of a kernel output. -struct NullHandling { - enum type { - /// Compute the output validity bitmap by intersecting the validity bitmaps - /// of the arguments using bitwise-and operations. This means that values - /// in the output are valid/non-null only if the corresponding values in - /// all input arguments were valid/non-null. Kernel generally need not - /// touch the bitmap thereafter, but a kernel's exec function is permitted - /// to alter the bitmap after the null intersection is computed if it needs - /// to. - INTERSECTION, - - /// Kernel expects a pre-allocated buffer to write the result bitmap - /// into. The preallocated memory is not zeroed (except for the last byte), - /// so the kernel should ensure to completely populate the bitmap. - COMPUTED_PREALLOCATE, - - /// Kernel allocates and sets the validity bitmap of the output. - COMPUTED_NO_PREALLOCATE, - - /// Kernel output is never null and a validity bitmap does not need to be - /// allocated. - OUTPUT_NOT_NULL - }; -}; - -/// \brief The preference for memory preallocation of fixed-width type outputs -/// in kernel execution. -struct MemAllocation { - enum type { - // For data types that support pre-allocation (i.e. fixed-width), the - // kernel expects to be provided a pre-allocated data buffer to write - // into. Non-fixed-width types must always allocate their own data - // buffers. The allocation made for the same length as the execution batch, - // so vector kernels yielding differently sized output should not use this. - // - // It is valid for the data to not be preallocated but the validity bitmap - // is (or is computed using the intersection/bitwise-and method). - // - // For variable-size output types like BinaryType or StringType, or for - // nested types, this option has no effect. - PREALLOCATE, - - // The kernel is responsible for allocating its own data buffer for - // fixed-width type outputs. - NO_PREALLOCATE - }; -}; - -struct Kernel; - -/// \brief Arguments to pass to a KernelInit function. A struct is used to help -/// avoid API breakage should the arguments passed need to be expanded. -struct KernelInitArgs { - /// \brief A pointer to the kernel being initialized. The init function may - /// depend on the kernel's KernelSignature or other data contained there. - const Kernel* kernel; - - /// \brief The types and shapes of the input arguments that the kernel is - /// about to be executed against. - /// - /// TODO: should this be const std::vector<ValueDescr>*? const-ref is being - /// used to avoid the cost of copying the struct into the args struct. - const std::vector<ValueDescr>& inputs; - - /// \brief Opaque options specific to this kernel. May be nullptr for functions - /// that do not require options. - const FunctionOptions* options; -}; - -/// \brief Common initializer function for all kernel types. +class ARROW_EXPORT KernelSignature { + public: + KernelSignature(std::vector<InputType> in_types, OutputType out_type, + bool is_varargs = false); + + /// \brief Convenience ctor since make_shared can be awkward + static std::shared_ptr<KernelSignature> Make(std::vector<InputType> in_types, + OutputType out_type, + bool is_varargs = false); + + /// \brief Return true if the signature if compatible with the list of input + /// value descriptors. + bool MatchesInputs(const std::vector<ValueDescr>& descriptors) const; + + /// \brief Returns true if the input types of each signature are + /// equal. Well-formed functions should have a deterministic output type + /// given input types, but currently it is the responsibility of the + /// developer to ensure this. + bool Equals(const KernelSignature& other) const; + + bool operator==(const KernelSignature& other) const { return this->Equals(other); } + + bool operator!=(const KernelSignature& other) const { return !(*this == other); } + + /// \brief Compute a hash code for the signature + size_t Hash() const; + + /// \brief The input types for the kernel. For VarArgs functions, this should + /// generally contain a single validator to use for validating all of the + /// function arguments. + const std::vector<InputType>& in_types() const { return in_types_; } + + /// \brief The output type for the kernel. Use Resolve to return the exact + /// output given input argument ValueDescrs, since many kernels' output types + /// depend on their input types (or their type metadata). + const OutputType& out_type() const { return out_type_; } + + /// \brief Render a human-readable string representation + std::string ToString() const; + + bool is_varargs() const { return is_varargs_; } + + private: + std::vector<InputType> in_types_; + OutputType out_type_; + bool is_varargs_; + + // For caching the hash code after it's computed the first time + mutable uint64_t hash_code_; +}; + +/// \brief A function may contain multiple variants of a kernel for a given +/// type combination for different SIMD levels. Based on the active system's +/// CPU info or the user's preferences, we can elect to use one over the other. +struct SimdLevel { + enum type { NONE = 0, SSE4_2, AVX, AVX2, AVX512, NEON, MAX }; +}; + +/// \brief The strategy to use for propagating or otherwise populating the +/// validity bitmap of a kernel output. +struct NullHandling { + enum type { + /// Compute the output validity bitmap by intersecting the validity bitmaps + /// of the arguments using bitwise-and operations. This means that values + /// in the output are valid/non-null only if the corresponding values in + /// all input arguments were valid/non-null. Kernel generally need not + /// touch the bitmap thereafter, but a kernel's exec function is permitted + /// to alter the bitmap after the null intersection is computed if it needs + /// to. + INTERSECTION, + + /// Kernel expects a pre-allocated buffer to write the result bitmap + /// into. The preallocated memory is not zeroed (except for the last byte), + /// so the kernel should ensure to completely populate the bitmap. + COMPUTED_PREALLOCATE, + + /// Kernel allocates and sets the validity bitmap of the output. + COMPUTED_NO_PREALLOCATE, + + /// Kernel output is never null and a validity bitmap does not need to be + /// allocated. + OUTPUT_NOT_NULL + }; +}; + +/// \brief The preference for memory preallocation of fixed-width type outputs +/// in kernel execution. +struct MemAllocation { + enum type { + // For data types that support pre-allocation (i.e. fixed-width), the + // kernel expects to be provided a pre-allocated data buffer to write + // into. Non-fixed-width types must always allocate their own data + // buffers. The allocation made for the same length as the execution batch, + // so vector kernels yielding differently sized output should not use this. + // + // It is valid for the data to not be preallocated but the validity bitmap + // is (or is computed using the intersection/bitwise-and method). + // + // For variable-size output types like BinaryType or StringType, or for + // nested types, this option has no effect. + PREALLOCATE, + + // The kernel is responsible for allocating its own data buffer for + // fixed-width type outputs. + NO_PREALLOCATE + }; +}; + +struct Kernel; + +/// \brief Arguments to pass to a KernelInit function. A struct is used to help +/// avoid API breakage should the arguments passed need to be expanded. +struct KernelInitArgs { + /// \brief A pointer to the kernel being initialized. The init function may + /// depend on the kernel's KernelSignature or other data contained there. + const Kernel* kernel; + + /// \brief The types and shapes of the input arguments that the kernel is + /// about to be executed against. + /// + /// TODO: should this be const std::vector<ValueDescr>*? const-ref is being + /// used to avoid the cost of copying the struct into the args struct. + const std::vector<ValueDescr>& inputs; + + /// \brief Opaque options specific to this kernel. May be nullptr for functions + /// that do not require options. + const FunctionOptions* options; +}; + +/// \brief Common initializer function for all kernel types. using KernelInit = std::function<Result<std::unique_ptr<KernelState>>( KernelContext*, const KernelInitArgs&)>; - -/// \brief Base type for kernels. Contains the function signature and -/// optionally the state initialization function, along with some common -/// attributes -struct Kernel { - Kernel() = default; - - Kernel(std::shared_ptr<KernelSignature> sig, KernelInit init) - : signature(std::move(sig)), init(std::move(init)) {} - - Kernel(std::vector<InputType> in_types, OutputType out_type, KernelInit init) + +/// \brief Base type for kernels. Contains the function signature and +/// optionally the state initialization function, along with some common +/// attributes +struct Kernel { + Kernel() = default; + + Kernel(std::shared_ptr<KernelSignature> sig, KernelInit init) + : signature(std::move(sig)), init(std::move(init)) {} + + Kernel(std::vector<InputType> in_types, OutputType out_type, KernelInit init) : Kernel(KernelSignature::Make(std::move(in_types), std::move(out_type)), std::move(init)) {} - - /// \brief The "signature" of the kernel containing the InputType input - /// argument validators and OutputType output type and shape resolver. - std::shared_ptr<KernelSignature> signature; - - /// \brief Create a new KernelState for invocations of this kernel, e.g. to - /// set up any options or state relevant for execution. - KernelInit init; - + + /// \brief The "signature" of the kernel containing the InputType input + /// argument validators and OutputType output type and shape resolver. + std::shared_ptr<KernelSignature> signature; + + /// \brief Create a new KernelState for invocations of this kernel, e.g. to + /// set up any options or state relevant for execution. + KernelInit init; + /// \brief Create a vector of new KernelState for invocations of this kernel. static Status InitAll(KernelContext*, const KernelInitArgs&, std::vector<std::unique_ptr<KernelState>>*); - /// \brief Indicates whether execution can benefit from parallelization - /// (splitting large chunks into smaller chunks and using multiple - /// threads). Some kernels may not support parallel execution at - /// all. Synchronization and concurrency-related issues are currently the - /// responsibility of the Kernel's implementation. - bool parallelizable = true; - - /// \brief Indicates the level of SIMD instruction support in the host CPU is - /// required to use the function. The intention is for functions to be able to - /// contain multiple kernels with the same signature but different levels of SIMD, - /// so that the most optimized kernel supported on a host's processor can be chosen. - SimdLevel::type simd_level = SimdLevel::NONE; -}; - -/// \brief Common kernel base data structure for ScalarKernel and -/// VectorKernel. It is called "ArrayKernel" in that the functions generally -/// output array values (as opposed to scalar values in the case of aggregate -/// functions). -struct ArrayKernel : public Kernel { + /// \brief Indicates whether execution can benefit from parallelization + /// (splitting large chunks into smaller chunks and using multiple + /// threads). Some kernels may not support parallel execution at + /// all. Synchronization and concurrency-related issues are currently the + /// responsibility of the Kernel's implementation. + bool parallelizable = true; + + /// \brief Indicates the level of SIMD instruction support in the host CPU is + /// required to use the function. The intention is for functions to be able to + /// contain multiple kernels with the same signature but different levels of SIMD, + /// so that the most optimized kernel supported on a host's processor can be chosen. + SimdLevel::type simd_level = SimdLevel::NONE; +}; + +/// \brief Common kernel base data structure for ScalarKernel and +/// VectorKernel. It is called "ArrayKernel" in that the functions generally +/// output array values (as opposed to scalar values in the case of aggregate +/// functions). +struct ArrayKernel : public Kernel { ArrayKernel() = default; - - ArrayKernel(std::shared_ptr<KernelSignature> sig, ArrayKernelExec exec, - KernelInit init = NULLPTR) - : Kernel(std::move(sig), init), exec(std::move(exec)) {} - - ArrayKernel(std::vector<InputType> in_types, OutputType out_type, ArrayKernelExec exec, - KernelInit init = NULLPTR) + + ArrayKernel(std::shared_ptr<KernelSignature> sig, ArrayKernelExec exec, + KernelInit init = NULLPTR) + : Kernel(std::move(sig), init), exec(std::move(exec)) {} + + ArrayKernel(std::vector<InputType> in_types, OutputType out_type, ArrayKernelExec exec, + KernelInit init = NULLPTR) : Kernel(std::move(in_types), std::move(out_type), std::move(init)), exec(std::move(exec)) {} - - /// \brief Perform a single invocation of this kernel. Depending on the - /// implementation, it may only write into preallocated memory, while in some - /// cases it will allocate its own memory. Any required state is managed - /// through the KernelContext. - ArrayKernelExec exec; - - /// \brief Writing execution results into larger contiguous allocations - /// requires that the kernel be able to write into sliced output ArrayData*, - /// including sliced output validity bitmaps. Some kernel implementations may - /// not be able to do this, so setting this to false disables this - /// functionality. - bool can_write_into_slices = true; -}; - -/// \brief Kernel data structure for implementations of ScalarFunction. In -/// addition to the members found in ArrayKernel, contains the null handling -/// and memory pre-allocation preferences. -struct ScalarKernel : public ArrayKernel { - using ArrayKernel::ArrayKernel; - - // For scalar functions preallocated data and intersecting arg validity - // bitmaps is a reasonable default - NullHandling::type null_handling = NullHandling::INTERSECTION; - MemAllocation::type mem_allocation = MemAllocation::PREALLOCATE; -}; - -// ---------------------------------------------------------------------- -// VectorKernel (for VectorFunction) - -/// \brief See VectorKernel::finalize member for usage + + /// \brief Perform a single invocation of this kernel. Depending on the + /// implementation, it may only write into preallocated memory, while in some + /// cases it will allocate its own memory. Any required state is managed + /// through the KernelContext. + ArrayKernelExec exec; + + /// \brief Writing execution results into larger contiguous allocations + /// requires that the kernel be able to write into sliced output ArrayData*, + /// including sliced output validity bitmaps. Some kernel implementations may + /// not be able to do this, so setting this to false disables this + /// functionality. + bool can_write_into_slices = true; +}; + +/// \brief Kernel data structure for implementations of ScalarFunction. In +/// addition to the members found in ArrayKernel, contains the null handling +/// and memory pre-allocation preferences. +struct ScalarKernel : public ArrayKernel { + using ArrayKernel::ArrayKernel; + + // For scalar functions preallocated data and intersecting arg validity + // bitmaps is a reasonable default + NullHandling::type null_handling = NullHandling::INTERSECTION; + MemAllocation::type mem_allocation = MemAllocation::PREALLOCATE; +}; + +// ---------------------------------------------------------------------- +// VectorKernel (for VectorFunction) + +/// \brief See VectorKernel::finalize member for usage using VectorFinalize = std::function<Status(KernelContext*, std::vector<Datum>*)>; - -/// \brief Kernel data structure for implementations of VectorFunction. In -/// addition to the members found in ArrayKernel, contains an optional -/// finalizer function, the null handling and memory pre-allocation preferences -/// (which have different defaults from ScalarKernel), and some other -/// execution-related options. -struct VectorKernel : public ArrayKernel { + +/// \brief Kernel data structure for implementations of VectorFunction. In +/// addition to the members found in ArrayKernel, contains an optional +/// finalizer function, the null handling and memory pre-allocation preferences +/// (which have different defaults from ScalarKernel), and some other +/// execution-related options. +struct VectorKernel : public ArrayKernel { VectorKernel() = default; - - VectorKernel(std::shared_ptr<KernelSignature> sig, ArrayKernelExec exec) + + VectorKernel(std::shared_ptr<KernelSignature> sig, ArrayKernelExec exec) : ArrayKernel(std::move(sig), std::move(exec)) {} - - VectorKernel(std::vector<InputType> in_types, OutputType out_type, ArrayKernelExec exec, - KernelInit init = NULLPTR, VectorFinalize finalize = NULLPTR) - : ArrayKernel(std::move(in_types), std::move(out_type), std::move(exec), - std::move(init)), - finalize(std::move(finalize)) {} - - VectorKernel(std::shared_ptr<KernelSignature> sig, ArrayKernelExec exec, - KernelInit init = NULLPTR, VectorFinalize finalize = NULLPTR) - : ArrayKernel(std::move(sig), std::move(exec), std::move(init)), - finalize(std::move(finalize)) {} - - /// \brief For VectorKernel, convert intermediate results into finalized - /// results. Mutates input argument. Some kernels may accumulate state - /// (example: hashing-related functions) through processing chunked inputs, and - /// then need to attach some accumulated state to each of the outputs of - /// processing each chunk of data. - VectorFinalize finalize; - - /// Since vector kernels generally are implemented rather differently from - /// scalar/elementwise kernels (and they may not even yield arrays of the same - /// size), so we make the developer opt-in to any memory preallocation rather - /// than having to turn it off. - NullHandling::type null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; - MemAllocation::type mem_allocation = MemAllocation::NO_PREALLOCATE; - - /// Some vector kernels can do chunkwise execution using ExecBatchIterator, - /// in some cases accumulating some state. Other kernels (like Take) need to - /// be passed whole arrays and don't work on ChunkedArray inputs - bool can_execute_chunkwise = true; - - /// Some kernels (like unique and value_counts) yield non-chunked output from - /// chunked-array inputs. This option controls how the results are boxed when - /// returned from ExecVectorFunction - /// - /// true -> ChunkedArray - /// false -> Array - bool output_chunked = true; -}; - -// ---------------------------------------------------------------------- -// ScalarAggregateKernel (for ScalarAggregateFunction) - + + VectorKernel(std::vector<InputType> in_types, OutputType out_type, ArrayKernelExec exec, + KernelInit init = NULLPTR, VectorFinalize finalize = NULLPTR) + : ArrayKernel(std::move(in_types), std::move(out_type), std::move(exec), + std::move(init)), + finalize(std::move(finalize)) {} + + VectorKernel(std::shared_ptr<KernelSignature> sig, ArrayKernelExec exec, + KernelInit init = NULLPTR, VectorFinalize finalize = NULLPTR) + : ArrayKernel(std::move(sig), std::move(exec), std::move(init)), + finalize(std::move(finalize)) {} + + /// \brief For VectorKernel, convert intermediate results into finalized + /// results. Mutates input argument. Some kernels may accumulate state + /// (example: hashing-related functions) through processing chunked inputs, and + /// then need to attach some accumulated state to each of the outputs of + /// processing each chunk of data. + VectorFinalize finalize; + + /// Since vector kernels generally are implemented rather differently from + /// scalar/elementwise kernels (and they may not even yield arrays of the same + /// size), so we make the developer opt-in to any memory preallocation rather + /// than having to turn it off. + NullHandling::type null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + MemAllocation::type mem_allocation = MemAllocation::NO_PREALLOCATE; + + /// Some vector kernels can do chunkwise execution using ExecBatchIterator, + /// in some cases accumulating some state. Other kernels (like Take) need to + /// be passed whole arrays and don't work on ChunkedArray inputs + bool can_execute_chunkwise = true; + + /// Some kernels (like unique and value_counts) yield non-chunked output from + /// chunked-array inputs. This option controls how the results are boxed when + /// returned from ExecVectorFunction + /// + /// true -> ChunkedArray + /// false -> Array + bool output_chunked = true; +}; + +// ---------------------------------------------------------------------- +// ScalarAggregateKernel (for ScalarAggregateFunction) + using ScalarAggregateConsume = std::function<Status(KernelContext*, const ExecBatch&)>; - -using ScalarAggregateMerge = + +using ScalarAggregateMerge = std::function<Status(KernelContext*, KernelState&&, KernelState*)>; - -// Finalize returns Datum to permit multiple return values + +// Finalize returns Datum to permit multiple return values using ScalarAggregateFinalize = std::function<Status(KernelContext*, Datum*)>; - -/// \brief Kernel data structure for implementations of -/// ScalarAggregateFunction. The four necessary components of an aggregation -/// kernel are the init, consume, merge, and finalize functions. -/// -/// * init: creates a new KernelState for a kernel. -/// * consume: processes an ExecBatch and updates the KernelState found in the -/// KernelContext. -/// * merge: combines one KernelState with another. -/// * finalize: produces the end result of the aggregation using the -/// KernelState in the KernelContext. -struct ScalarAggregateKernel : public Kernel { + +/// \brief Kernel data structure for implementations of +/// ScalarAggregateFunction. The four necessary components of an aggregation +/// kernel are the init, consume, merge, and finalize functions. +/// +/// * init: creates a new KernelState for a kernel. +/// * consume: processes an ExecBatch and updates the KernelState found in the +/// KernelContext. +/// * merge: combines one KernelState with another. +/// * finalize: produces the end result of the aggregation using the +/// KernelState in the KernelContext. +struct ScalarAggregateKernel : public Kernel { ScalarAggregateKernel() = default; - - ScalarAggregateKernel(std::shared_ptr<KernelSignature> sig, KernelInit init, - ScalarAggregateConsume consume, ScalarAggregateMerge merge, - ScalarAggregateFinalize finalize) + + ScalarAggregateKernel(std::shared_ptr<KernelSignature> sig, KernelInit init, + ScalarAggregateConsume consume, ScalarAggregateMerge merge, + ScalarAggregateFinalize finalize) : Kernel(std::move(sig), std::move(init)), - consume(std::move(consume)), - merge(std::move(merge)), - finalize(std::move(finalize)) {} - - ScalarAggregateKernel(std::vector<InputType> in_types, OutputType out_type, - KernelInit init, ScalarAggregateConsume consume, - ScalarAggregateMerge merge, ScalarAggregateFinalize finalize) + consume(std::move(consume)), + merge(std::move(merge)), + finalize(std::move(finalize)) {} + + ScalarAggregateKernel(std::vector<InputType> in_types, OutputType out_type, + KernelInit init, ScalarAggregateConsume consume, + ScalarAggregateMerge merge, ScalarAggregateFinalize finalize) : ScalarAggregateKernel( KernelSignature::Make(std::move(in_types), std::move(out_type)), std::move(init), std::move(consume), std::move(merge), std::move(finalize)) {} - + /// \brief Merge a vector of KernelStates into a single KernelState. /// The merged state will be returned and will be set on the KernelContext. static Result<std::unique_ptr<KernelState>> MergeAll( const ScalarAggregateKernel* kernel, KernelContext* ctx, std::vector<std::unique_ptr<KernelState>> states); - ScalarAggregateConsume consume; - ScalarAggregateMerge merge; - ScalarAggregateFinalize finalize; -}; - + ScalarAggregateConsume consume; + ScalarAggregateMerge merge; + ScalarAggregateFinalize finalize; +}; + // ---------------------------------------------------------------------- // HashAggregateKernel (for HashAggregateFunction) @@ -735,5 +735,5 @@ struct HashAggregateKernel : public Kernel { HashAggregateFinalize finalize; }; -} // namespace compute -} // namespace arrow +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_basic.cc index a7df66695b..88f3b87d9e 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -1,44 +1,44 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/compute/api_aggregate.h" -#include "arrow/compute/kernels/aggregate_basic_internal.h" -#include "arrow/compute/kernels/aggregate_internal.h" -#include "arrow/compute/kernels/common.h" -#include "arrow/util/cpu_info.h" -#include "arrow/util/make_unique.h" - -namespace arrow { -namespace compute { - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/api_aggregate.h" +#include "arrow/compute/kernels/aggregate_basic_internal.h" +#include "arrow/compute/kernels/aggregate_internal.h" +#include "arrow/compute/kernels/common.h" +#include "arrow/util/cpu_info.h" +#include "arrow/util/make_unique.h" + +namespace arrow { +namespace compute { + namespace { Status AggregateConsume(KernelContext* ctx, const ExecBatch& batch) { return checked_cast<ScalarAggregator*>(ctx->state())->Consume(ctx, batch); -} - +} + Status AggregateMerge(KernelContext* ctx, KernelState&& src, KernelState* dst) { return checked_cast<ScalarAggregator*>(dst)->MergeFrom(ctx, std::move(src)); -} - +} + Status AggregateFinalize(KernelContext* ctx, Datum* out) { return checked_cast<ScalarAggregator*>(ctx->state())->Finalize(ctx, out); -} - +} + } // namespace void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init, @@ -52,12 +52,12 @@ void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init, namespace aggregate { -// ---------------------------------------------------------------------- -// Count implementation - -struct CountImpl : public ScalarAggregator { +// ---------------------------------------------------------------------- +// Count implementation + +struct CountImpl : public ScalarAggregator { explicit CountImpl(ScalarAggregateOptions options) : options(std::move(options)) {} - + Status Consume(KernelContext*, const ExecBatch& batch) override { if (batch[0].is_array()) { const ArrayData& input = *batch[0].array(); @@ -70,80 +70,80 @@ struct CountImpl : public ScalarAggregator { this->non_nulls += input.is_valid * batch.length; } return Status::OK(); - } - + } + Status MergeFrom(KernelContext*, KernelState&& src) override { - const auto& other_state = checked_cast<const CountImpl&>(src); - this->non_nulls += other_state.non_nulls; - this->nulls += other_state.nulls; + const auto& other_state = checked_cast<const CountImpl&>(src); + this->non_nulls += other_state.non_nulls; + this->nulls += other_state.nulls; return Status::OK(); - } - + } + Status Finalize(KernelContext* ctx, Datum* out) override { - const auto& state = checked_cast<const CountImpl&>(*ctx->state()); + const auto& state = checked_cast<const CountImpl&>(*ctx->state()); if (state.options.skip_nulls) { *out = Datum(state.non_nulls); } else { *out = Datum(state.nulls); - } + } return Status::OK(); - } - + } + ScalarAggregateOptions options; - int64_t non_nulls = 0; - int64_t nulls = 0; -}; - + int64_t non_nulls = 0; + int64_t nulls = 0; +}; + Result<std::unique_ptr<KernelState>> CountInit(KernelContext*, const KernelInitArgs& args) { - return ::arrow::internal::make_unique<CountImpl>( + return ::arrow::internal::make_unique<CountImpl>( static_cast<const ScalarAggregateOptions&>(*args.options)); -} - -// ---------------------------------------------------------------------- -// Sum implementation - +} + +// ---------------------------------------------------------------------- +// Sum implementation + template <typename ArrowType> struct SumImplDefault : public SumImpl<ArrowType, SimdLevel::NONE> { explicit SumImplDefault(const ScalarAggregateOptions& options_) { this->options = options_; } -}; - +}; + template <typename ArrowType> struct MeanImplDefault : public MeanImpl<ArrowType, SimdLevel::NONE> { explicit MeanImplDefault(const ScalarAggregateOptions& options_) { this->options = options_; } -}; - +}; + Result<std::unique_ptr<KernelState>> SumInit(KernelContext* ctx, const KernelInitArgs& args) { SumLikeInit<SumImplDefault> visitor( ctx, *args.inputs[0].type, static_cast<const ScalarAggregateOptions&>(*args.options)); - return visitor.Create(); -} - + return visitor.Create(); +} + Result<std::unique_ptr<KernelState>> MeanInit(KernelContext* ctx, const KernelInitArgs& args) { SumLikeInit<MeanImplDefault> visitor( ctx, *args.inputs[0].type, static_cast<const ScalarAggregateOptions&>(*args.options)); - return visitor.Create(); -} - -// ---------------------------------------------------------------------- -// MinMax implementation - + return visitor.Create(); +} + +// ---------------------------------------------------------------------- +// MinMax implementation + Result<std::unique_ptr<KernelState>> MinMaxInit(KernelContext* ctx, const KernelInitArgs& args) { - MinMaxInitState<SimdLevel::NONE> visitor( - ctx, *args.inputs[0].type, args.kernel->signature->out_type().type(), + MinMaxInitState<SimdLevel::NONE> visitor( + ctx, *args.inputs[0].type, args.kernel->signature->out_type().type(), static_cast<const ScalarAggregateOptions&>(*args.options)); - return visitor.Create(); -} - + return visitor.Create(); +} + // ---------------------------------------------------------------------- // Any implementation @@ -203,8 +203,8 @@ Result<std::unique_ptr<KernelState>> AnyInit(KernelContext*, const KernelInitArg static_cast<const ScalarAggregateOptions&>(*args.options); return ::arrow::internal::make_unique<BooleanAnyImpl>( static_cast<const ScalarAggregateOptions&>(*args.options)); -} - +} + // ---------------------------------------------------------------------- // All implementation @@ -394,17 +394,17 @@ struct IndexInit { } }; -void AddBasicAggKernels(KernelInit init, - const std::vector<std::shared_ptr<DataType>>& types, - std::shared_ptr<DataType> out_ty, ScalarAggregateFunction* func, - SimdLevel::type simd_level) { - for (const auto& ty : types) { - // array[InT] -> scalar[OutT] - auto sig = KernelSignature::Make({InputType::Array(ty)}, ValueDescr::Scalar(out_ty)); - AddAggKernel(std::move(sig), init, func, simd_level); - } -} - +void AddBasicAggKernels(KernelInit init, + const std::vector<std::shared_ptr<DataType>>& types, + std::shared_ptr<DataType> out_ty, ScalarAggregateFunction* func, + SimdLevel::type simd_level) { + for (const auto& ty : types) { + // array[InT] -> scalar[OutT] + auto sig = KernelSignature::Make({InputType::Array(ty)}, ValueDescr::Scalar(out_ty)); + AddAggKernel(std::move(sig), init, func, simd_level); + } +} + void AddScalarAggKernels(KernelInit init, const std::vector<std::shared_ptr<DataType>>& types, std::shared_ptr<DataType> out_ty, @@ -425,20 +425,20 @@ void AddArrayScalarAggKernels(KernelInit init, AddScalarAggKernels(init, types, out_ty, func); } -void AddMinMaxKernels(KernelInit init, - const std::vector<std::shared_ptr<DataType>>& types, - ScalarAggregateFunction* func, SimdLevel::type simd_level) { - for (const auto& ty : types) { +void AddMinMaxKernels(KernelInit init, + const std::vector<std::shared_ptr<DataType>>& types, + ScalarAggregateFunction* func, SimdLevel::type simd_level) { + for (const auto& ty : types) { // any[T] -> scalar[struct<min: T, max: T>] - auto out_ty = struct_({field("min", ty), field("max", ty)}); + auto out_ty = struct_({field("min", ty), field("max", ty)}); auto sig = KernelSignature::Make({InputType(ty)}, ValueDescr::Scalar(out_ty)); - AddAggKernel(std::move(sig), init, func, simd_level); - } -} - -} // namespace aggregate - -namespace internal { + AddAggKernel(std::move(sig), init, func, simd_level); + } +} + +} // namespace aggregate + +namespace internal { namespace { const FunctionDoc count_doc{"Count the number of null / non-null values", @@ -496,21 +496,21 @@ const FunctionDoc index_doc{"Find the index of the first occurrence of a given v } // namespace -void RegisterScalarAggregateBasic(FunctionRegistry* registry) { +void RegisterScalarAggregateBasic(FunctionRegistry* registry) { static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults(); - + auto func = std::make_shared<ScalarAggregateFunction>( "count", Arity::Unary(), &count_doc, &default_scalar_aggregate_options); - // Takes any array input, outputs int64 scalar - InputType any_array(ValueDescr::ARRAY); + // Takes any array input, outputs int64 scalar + InputType any_array(ValueDescr::ARRAY); AddAggKernel(KernelSignature::Make({any_array}, ValueDescr::Scalar(int64())), aggregate::CountInit, func.get()); AddAggKernel( KernelSignature::Make({InputType(ValueDescr::SCALAR)}, ValueDescr::Scalar(int64())), aggregate::CountInit, func.get()); - DCHECK_OK(registry->AddFunction(std::move(func))); - + DCHECK_OK(registry->AddFunction(std::move(func))); + func = std::make_shared<ScalarAggregateFunction>("sum", Arity::Unary(), &sum_doc, &default_scalar_aggregate_options); aggregate::AddArrayScalarAggKernels(aggregate::SumInit, {boolean()}, int64(), @@ -521,59 +521,59 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { func.get()); aggregate::AddArrayScalarAggKernels(aggregate::SumInit, FloatingPointTypes(), float64(), func.get()); - // Add the SIMD variants for sum + // Add the SIMD variants for sum #if defined(ARROW_HAVE_RUNTIME_AVX2) || defined(ARROW_HAVE_RUNTIME_AVX512) - auto cpu_info = arrow::internal::CpuInfo::GetInstance(); -#endif -#if defined(ARROW_HAVE_RUNTIME_AVX2) - if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) { - aggregate::AddSumAvx2AggKernels(func.get()); - } -#endif -#if defined(ARROW_HAVE_RUNTIME_AVX512) - if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) { - aggregate::AddSumAvx512AggKernels(func.get()); - } + auto cpu_info = arrow::internal::CpuInfo::GetInstance(); #endif - DCHECK_OK(registry->AddFunction(std::move(func))); - +#if defined(ARROW_HAVE_RUNTIME_AVX2) + if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) { + aggregate::AddSumAvx2AggKernels(func.get()); + } +#endif +#if defined(ARROW_HAVE_RUNTIME_AVX512) + if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) { + aggregate::AddSumAvx512AggKernels(func.get()); + } +#endif + DCHECK_OK(registry->AddFunction(std::move(func))); + func = std::make_shared<ScalarAggregateFunction>("mean", Arity::Unary(), &mean_doc, &default_scalar_aggregate_options); aggregate::AddArrayScalarAggKernels(aggregate::MeanInit, {boolean()}, float64(), func.get()); aggregate::AddArrayScalarAggKernels(aggregate::MeanInit, NumericTypes(), float64(), func.get()); - // Add the SIMD variants for mean -#if defined(ARROW_HAVE_RUNTIME_AVX2) - if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) { - aggregate::AddMeanAvx2AggKernels(func.get()); - } -#endif -#if defined(ARROW_HAVE_RUNTIME_AVX512) - if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) { - aggregate::AddMeanAvx512AggKernels(func.get()); - } -#endif - DCHECK_OK(registry->AddFunction(std::move(func))); - + // Add the SIMD variants for mean +#if defined(ARROW_HAVE_RUNTIME_AVX2) + if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) { + aggregate::AddMeanAvx2AggKernels(func.get()); + } +#endif +#if defined(ARROW_HAVE_RUNTIME_AVX512) + if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) { + aggregate::AddMeanAvx512AggKernels(func.get()); + } +#endif + DCHECK_OK(registry->AddFunction(std::move(func))); + func = std::make_shared<ScalarAggregateFunction>( "min_max", Arity::Unary(), &min_max_doc, &default_scalar_aggregate_options); - aggregate::AddMinMaxKernels(aggregate::MinMaxInit, {boolean()}, func.get()); - aggregate::AddMinMaxKernels(aggregate::MinMaxInit, NumericTypes(), func.get()); - // Add the SIMD variants for min max -#if defined(ARROW_HAVE_RUNTIME_AVX2) - if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) { - aggregate::AddMinMaxAvx2AggKernels(func.get()); - } -#endif -#if defined(ARROW_HAVE_RUNTIME_AVX512) - if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) { - aggregate::AddMinMaxAvx512AggKernels(func.get()); - } -#endif - - DCHECK_OK(registry->AddFunction(std::move(func))); - + aggregate::AddMinMaxKernels(aggregate::MinMaxInit, {boolean()}, func.get()); + aggregate::AddMinMaxKernels(aggregate::MinMaxInit, NumericTypes(), func.get()); + // Add the SIMD variants for min max +#if defined(ARROW_HAVE_RUNTIME_AVX2) + if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) { + aggregate::AddMinMaxAvx2AggKernels(func.get()); + } +#endif +#if defined(ARROW_HAVE_RUNTIME_AVX512) + if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) { + aggregate::AddMinMaxAvx512AggKernels(func.get()); + } +#endif + + DCHECK_OK(registry->AddFunction(std::move(func))); + // any func = std::make_shared<ScalarAggregateFunction>("any", Arity::Unary(), &any_doc, &default_scalar_aggregate_options); @@ -597,8 +597,8 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { aggregate::AddBasicAggKernels(aggregate::IndexInit::Init, TemporalTypes(), int64(), func.get()); DCHECK_OK(registry->AddFunction(std::move(func))); -} - -} // namespace internal -} // namespace compute -} // namespace arrow +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h index 5163d3fd03..60419356c5 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h @@ -1,63 +1,63 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <cmath> - -#include "arrow/compute/api_aggregate.h" -#include "arrow/compute/kernels/aggregate_internal.h" -#include "arrow/compute/kernels/common.h" -#include "arrow/util/align_util.h" -#include "arrow/util/bit_block_counter.h" - -namespace arrow { -namespace compute { -namespace aggregate { - -void AddBasicAggKernels(KernelInit init, - const std::vector<std::shared_ptr<DataType>>& types, - std::shared_ptr<DataType> out_ty, ScalarAggregateFunction* func, - SimdLevel::type simd_level = SimdLevel::NONE); - -void AddMinMaxKernels(KernelInit init, - const std::vector<std::shared_ptr<DataType>>& types, - ScalarAggregateFunction* func, - SimdLevel::type simd_level = SimdLevel::NONE); - -// SIMD variants for kernels -void AddSumAvx2AggKernels(ScalarAggregateFunction* func); -void AddMeanAvx2AggKernels(ScalarAggregateFunction* func); -void AddMinMaxAvx2AggKernels(ScalarAggregateFunction* func); - -void AddSumAvx512AggKernels(ScalarAggregateFunction* func); -void AddMeanAvx512AggKernels(ScalarAggregateFunction* func); -void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func); - -// ---------------------------------------------------------------------- -// Sum implementation - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cmath> + +#include "arrow/compute/api_aggregate.h" +#include "arrow/compute/kernels/aggregate_internal.h" +#include "arrow/compute/kernels/common.h" +#include "arrow/util/align_util.h" +#include "arrow/util/bit_block_counter.h" + +namespace arrow { +namespace compute { +namespace aggregate { + +void AddBasicAggKernels(KernelInit init, + const std::vector<std::shared_ptr<DataType>>& types, + std::shared_ptr<DataType> out_ty, ScalarAggregateFunction* func, + SimdLevel::type simd_level = SimdLevel::NONE); + +void AddMinMaxKernels(KernelInit init, + const std::vector<std::shared_ptr<DataType>>& types, + ScalarAggregateFunction* func, + SimdLevel::type simd_level = SimdLevel::NONE); + +// SIMD variants for kernels +void AddSumAvx2AggKernels(ScalarAggregateFunction* func); +void AddMeanAvx2AggKernels(ScalarAggregateFunction* func); +void AddMinMaxAvx2AggKernels(ScalarAggregateFunction* func); + +void AddSumAvx512AggKernels(ScalarAggregateFunction* func); +void AddMeanAvx512AggKernels(ScalarAggregateFunction* func); +void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func); + +// ---------------------------------------------------------------------- +// Sum implementation + template <typename ArrowType, SimdLevel::type SimdLevel> struct SumImpl : public ScalarAggregator { using ThisType = SumImpl<ArrowType, SimdLevel>; using CType = typename ArrowType::c_type; - using SumType = typename FindAccumulatorType<ArrowType>::Type; + using SumType = typename FindAccumulatorType<ArrowType>::Type; using OutputType = typename TypeTraits<SumType>::ScalarType; - + Status Consume(KernelContext*, const ExecBatch& batch) override { if (batch[0].is_array()) { const auto& data = batch[0].array(); @@ -70,173 +70,173 @@ struct SumImpl : public ScalarAggregator { arrow::compute::detail::SumArray<CType, typename SumType::c_type, SimdLevel>( *data); } - } else { + } else { const auto& data = *batch[0].scalar(); this->count += data.is_valid * batch.length; if (data.is_valid) { this->sum += internal::UnboxScalar<ArrowType>::Unbox(data) * batch.length; - } - } + } + } return Status::OK(); - } - + } + Status MergeFrom(KernelContext*, KernelState&& src) override { const auto& other = checked_cast<const ThisType&>(src); this->count += other.count; this->sum += other.sum; return Status::OK(); - } - + } + Status Finalize(KernelContext*, Datum* out) override { if (this->count < options.min_count) { out->value = std::make_shared<OutputType>(); - } else { + } else { out->value = MakeScalar(this->sum); - } + } return Status::OK(); - } - - size_t count = 0; - typename SumType::c_type sum = 0; + } + + size_t count = 0; + typename SumType::c_type sum = 0; ScalarAggregateOptions options; -}; - +}; + template <typename ArrowType, SimdLevel::type SimdLevel> struct MeanImpl : public SumImpl<ArrowType, SimdLevel> { Status Finalize(KernelContext*, Datum* out) override { if (this->count < options.min_count) { out->value = std::make_shared<DoubleScalar>(); - } else { + } else { const double mean = static_cast<double>(this->sum) / this->count; out->value = std::make_shared<DoubleScalar>(mean); - } + } return Status::OK(); - } + } ScalarAggregateOptions options; -}; - -template <template <typename> class KernelClass> -struct SumLikeInit { - std::unique_ptr<KernelState> state; - KernelContext* ctx; - const DataType& type; +}; + +template <template <typename> class KernelClass> +struct SumLikeInit { + std::unique_ptr<KernelState> state; + KernelContext* ctx; + const DataType& type; const ScalarAggregateOptions& options; - + SumLikeInit(KernelContext* ctx, const DataType& type, const ScalarAggregateOptions& options) : ctx(ctx), type(type), options(options) {} - - Status Visit(const DataType&) { return Status::NotImplemented("No sum implemented"); } - - Status Visit(const HalfFloatType&) { - return Status::NotImplemented("No sum implemented"); - } - - Status Visit(const BooleanType&) { + + Status Visit(const DataType&) { return Status::NotImplemented("No sum implemented"); } + + Status Visit(const HalfFloatType&) { + return Status::NotImplemented("No sum implemented"); + } + + Status Visit(const BooleanType&) { state.reset(new KernelClass<BooleanType>(options)); - return Status::OK(); - } - - template <typename Type> - enable_if_number<Type, Status> Visit(const Type&) { + return Status::OK(); + } + + template <typename Type> + enable_if_number<Type, Status> Visit(const Type&) { state.reset(new KernelClass<Type>(options)); - return Status::OK(); - } - + return Status::OK(); + } + Result<std::unique_ptr<KernelState>> Create() { RETURN_NOT_OK(VisitTypeInline(type, this)); - return std::move(state); - } -}; - -// ---------------------------------------------------------------------- -// MinMax implementation - -template <typename ArrowType, SimdLevel::type SimdLevel, typename Enable = void> -struct MinMaxState {}; - -template <typename ArrowType, SimdLevel::type SimdLevel> -struct MinMaxState<ArrowType, SimdLevel, enable_if_boolean<ArrowType>> { - using ThisType = MinMaxState<ArrowType, SimdLevel>; - using T = typename ArrowType::c_type; - - ThisType& operator+=(const ThisType& rhs) { - this->has_nulls |= rhs.has_nulls; - this->has_values |= rhs.has_values; - this->min = this->min && rhs.min; - this->max = this->max || rhs.max; - return *this; - } - - void MergeOne(T value) { - this->min = this->min && value; - this->max = this->max || value; - } - - T min = true; - T max = false; - bool has_nulls = false; - bool has_values = false; -}; - -template <typename ArrowType, SimdLevel::type SimdLevel> -struct MinMaxState<ArrowType, SimdLevel, enable_if_integer<ArrowType>> { - using ThisType = MinMaxState<ArrowType, SimdLevel>; - using T = typename ArrowType::c_type; - - ThisType& operator+=(const ThisType& rhs) { - this->has_nulls |= rhs.has_nulls; - this->has_values |= rhs.has_values; - this->min = std::min(this->min, rhs.min); - this->max = std::max(this->max, rhs.max); - return *this; - } - - void MergeOne(T value) { - this->min = std::min(this->min, value); - this->max = std::max(this->max, value); - } - - T min = std::numeric_limits<T>::max(); - T max = std::numeric_limits<T>::min(); - bool has_nulls = false; - bool has_values = false; -}; - -template <typename ArrowType, SimdLevel::type SimdLevel> -struct MinMaxState<ArrowType, SimdLevel, enable_if_floating_point<ArrowType>> { - using ThisType = MinMaxState<ArrowType, SimdLevel>; - using T = typename ArrowType::c_type; - - ThisType& operator+=(const ThisType& rhs) { - this->has_nulls |= rhs.has_nulls; - this->has_values |= rhs.has_values; - this->min = std::fmin(this->min, rhs.min); - this->max = std::fmax(this->max, rhs.max); - return *this; - } - - void MergeOne(T value) { - this->min = std::fmin(this->min, value); - this->max = std::fmax(this->max, value); - } - - T min = std::numeric_limits<T>::infinity(); - T max = -std::numeric_limits<T>::infinity(); - bool has_nulls = false; - bool has_values = false; -}; - -template <typename ArrowType, SimdLevel::type SimdLevel> -struct MinMaxImpl : public ScalarAggregator { - using ArrayType = typename TypeTraits<ArrowType>::ArrayType; - using ThisType = MinMaxImpl<ArrowType, SimdLevel>; - using StateType = MinMaxState<ArrowType, SimdLevel>; - + return std::move(state); + } +}; + +// ---------------------------------------------------------------------- +// MinMax implementation + +template <typename ArrowType, SimdLevel::type SimdLevel, typename Enable = void> +struct MinMaxState {}; + +template <typename ArrowType, SimdLevel::type SimdLevel> +struct MinMaxState<ArrowType, SimdLevel, enable_if_boolean<ArrowType>> { + using ThisType = MinMaxState<ArrowType, SimdLevel>; + using T = typename ArrowType::c_type; + + ThisType& operator+=(const ThisType& rhs) { + this->has_nulls |= rhs.has_nulls; + this->has_values |= rhs.has_values; + this->min = this->min && rhs.min; + this->max = this->max || rhs.max; + return *this; + } + + void MergeOne(T value) { + this->min = this->min && value; + this->max = this->max || value; + } + + T min = true; + T max = false; + bool has_nulls = false; + bool has_values = false; +}; + +template <typename ArrowType, SimdLevel::type SimdLevel> +struct MinMaxState<ArrowType, SimdLevel, enable_if_integer<ArrowType>> { + using ThisType = MinMaxState<ArrowType, SimdLevel>; + using T = typename ArrowType::c_type; + + ThisType& operator+=(const ThisType& rhs) { + this->has_nulls |= rhs.has_nulls; + this->has_values |= rhs.has_values; + this->min = std::min(this->min, rhs.min); + this->max = std::max(this->max, rhs.max); + return *this; + } + + void MergeOne(T value) { + this->min = std::min(this->min, value); + this->max = std::max(this->max, value); + } + + T min = std::numeric_limits<T>::max(); + T max = std::numeric_limits<T>::min(); + bool has_nulls = false; + bool has_values = false; +}; + +template <typename ArrowType, SimdLevel::type SimdLevel> +struct MinMaxState<ArrowType, SimdLevel, enable_if_floating_point<ArrowType>> { + using ThisType = MinMaxState<ArrowType, SimdLevel>; + using T = typename ArrowType::c_type; + + ThisType& operator+=(const ThisType& rhs) { + this->has_nulls |= rhs.has_nulls; + this->has_values |= rhs.has_values; + this->min = std::fmin(this->min, rhs.min); + this->max = std::fmax(this->max, rhs.max); + return *this; + } + + void MergeOne(T value) { + this->min = std::fmin(this->min, value); + this->max = std::fmax(this->max, value); + } + + T min = std::numeric_limits<T>::infinity(); + T max = -std::numeric_limits<T>::infinity(); + bool has_nulls = false; + bool has_values = false; +}; + +template <typename ArrowType, SimdLevel::type SimdLevel> +struct MinMaxImpl : public ScalarAggregator { + using ArrayType = typename TypeTraits<ArrowType>::ArrayType; + using ThisType = MinMaxImpl<ArrowType, SimdLevel>; + using StateType = MinMaxState<ArrowType, SimdLevel>; + MinMaxImpl(const std::shared_ptr<DataType>& out_type, const ScalarAggregateOptions& options) - : out_type(out_type), options(options) {} - + : out_type(out_type), options(options) {} + Status Consume(KernelContext*, const ExecBatch& batch) override { if (batch[0].is_array()) { return ConsumeArray(ArrayType(batch[0].array())); @@ -245,15 +245,15 @@ struct MinMaxImpl : public ScalarAggregator { } Status ConsumeScalar(const Scalar& scalar) { - StateType local; + StateType local; local.has_nulls = !scalar.is_valid; local.has_values = scalar.is_valid; - + if (local.has_nulls && !options.skip_nulls) { this->state = local; return Status::OK(); } - + local.MergeOne(internal::UnboxScalar<ArrowType>::Unbox(scalar)); this->state = local; return Status::OK(); @@ -262,143 +262,143 @@ struct MinMaxImpl : public ScalarAggregator { Status ConsumeArray(const ArrayType& arr) { StateType local; - const auto null_count = arr.null_count(); - local.has_nulls = null_count > 0; - local.has_values = (arr.length() - null_count) > 0; - + const auto null_count = arr.null_count(); + local.has_nulls = null_count > 0; + local.has_values = (arr.length() - null_count) > 0; + if (local.has_nulls && !options.skip_nulls) { - this->state = local; + this->state = local; return Status::OK(); - } - - if (local.has_nulls) { - local += ConsumeWithNulls(arr); - } else { // All true values - for (int64_t i = 0; i < arr.length(); i++) { - local.MergeOne(arr.Value(i)); - } - } - this->state = local; + } + + if (local.has_nulls) { + local += ConsumeWithNulls(arr); + } else { // All true values + for (int64_t i = 0; i < arr.length(); i++) { + local.MergeOne(arr.Value(i)); + } + } + this->state = local; return Status::OK(); - } - + } + Status MergeFrom(KernelContext*, KernelState&& src) override { - const auto& other = checked_cast<const ThisType&>(src); - this->state += other.state; + const auto& other = checked_cast<const ThisType&>(src); + this->state += other.state; return Status::OK(); - } - + } + Status Finalize(KernelContext*, Datum* out) override { - using ScalarType = typename TypeTraits<ArrowType>::ScalarType; - - std::vector<std::shared_ptr<Scalar>> values; + using ScalarType = typename TypeTraits<ArrowType>::ScalarType; + + std::vector<std::shared_ptr<Scalar>> values; if (!state.has_values || (state.has_nulls && !options.skip_nulls)) { - // (null, null) - values = {std::make_shared<ScalarType>(), std::make_shared<ScalarType>()}; - } else { - values = {std::make_shared<ScalarType>(state.min), - std::make_shared<ScalarType>(state.max)}; - } + // (null, null) + values = {std::make_shared<ScalarType>(), std::make_shared<ScalarType>()}; + } else { + values = {std::make_shared<ScalarType>(state.min), + std::make_shared<ScalarType>(state.max)}; + } out->value = std::make_shared<StructScalar>(std::move(values), this->out_type); return Status::OK(); - } - - std::shared_ptr<DataType> out_type; + } + + std::shared_ptr<DataType> out_type; ScalarAggregateOptions options; - MinMaxState<ArrowType, SimdLevel> state; - - private: - StateType ConsumeWithNulls(const ArrayType& arr) const { - StateType local; - const int64_t length = arr.length(); - int64_t offset = arr.offset(); - const uint8_t* bitmap = arr.null_bitmap_data(); - int64_t idx = 0; - - const auto p = arrow::internal::BitmapWordAlign<1>(bitmap, offset, length); - // First handle the leading bits - const int64_t leading_bits = p.leading_bits; - while (idx < leading_bits) { - if (BitUtil::GetBit(bitmap, offset)) { - local.MergeOne(arr.Value(idx)); - } - idx++; - offset++; - } - - // The aligned parts scanned with BitBlockCounter - arrow::internal::BitBlockCounter data_counter(bitmap, offset, length - leading_bits); - auto current_block = data_counter.NextWord(); - while (idx < length) { - if (current_block.AllSet()) { // All true values - int run_length = 0; - // Scan forward until a block that has some false values (or the end) - while (current_block.length > 0 && current_block.AllSet()) { - run_length += current_block.length; - current_block = data_counter.NextWord(); - } - for (int64_t i = 0; i < run_length; i++) { - local.MergeOne(arr.Value(idx + i)); - } - idx += run_length; - offset += run_length; - // The current_block already computed, advance to next loop - continue; - } else if (!current_block.NoneSet()) { // Some values are null - BitmapReader reader(arr.null_bitmap_data(), offset, current_block.length); - for (int64_t i = 0; i < current_block.length; i++) { - if (reader.IsSet()) { - local.MergeOne(arr.Value(idx + i)); - } - reader.Next(); - } - - idx += current_block.length; - offset += current_block.length; - } else { // All null values - idx += current_block.length; - offset += current_block.length; - } - current_block = data_counter.NextWord(); - } - - return local; - } -}; - -template <SimdLevel::type SimdLevel> -struct BooleanMinMaxImpl : public MinMaxImpl<BooleanType, SimdLevel> { - using StateType = MinMaxState<BooleanType, SimdLevel>; - using ArrayType = typename TypeTraits<BooleanType>::ArrayType; - using MinMaxImpl<BooleanType, SimdLevel>::MinMaxImpl; - using MinMaxImpl<BooleanType, SimdLevel>::options; - + MinMaxState<ArrowType, SimdLevel> state; + + private: + StateType ConsumeWithNulls(const ArrayType& arr) const { + StateType local; + const int64_t length = arr.length(); + int64_t offset = arr.offset(); + const uint8_t* bitmap = arr.null_bitmap_data(); + int64_t idx = 0; + + const auto p = arrow::internal::BitmapWordAlign<1>(bitmap, offset, length); + // First handle the leading bits + const int64_t leading_bits = p.leading_bits; + while (idx < leading_bits) { + if (BitUtil::GetBit(bitmap, offset)) { + local.MergeOne(arr.Value(idx)); + } + idx++; + offset++; + } + + // The aligned parts scanned with BitBlockCounter + arrow::internal::BitBlockCounter data_counter(bitmap, offset, length - leading_bits); + auto current_block = data_counter.NextWord(); + while (idx < length) { + if (current_block.AllSet()) { // All true values + int run_length = 0; + // Scan forward until a block that has some false values (or the end) + while (current_block.length > 0 && current_block.AllSet()) { + run_length += current_block.length; + current_block = data_counter.NextWord(); + } + for (int64_t i = 0; i < run_length; i++) { + local.MergeOne(arr.Value(idx + i)); + } + idx += run_length; + offset += run_length; + // The current_block already computed, advance to next loop + continue; + } else if (!current_block.NoneSet()) { // Some values are null + BitmapReader reader(arr.null_bitmap_data(), offset, current_block.length); + for (int64_t i = 0; i < current_block.length; i++) { + if (reader.IsSet()) { + local.MergeOne(arr.Value(idx + i)); + } + reader.Next(); + } + + idx += current_block.length; + offset += current_block.length; + } else { // All null values + idx += current_block.length; + offset += current_block.length; + } + current_block = data_counter.NextWord(); + } + + return local; + } +}; + +template <SimdLevel::type SimdLevel> +struct BooleanMinMaxImpl : public MinMaxImpl<BooleanType, SimdLevel> { + using StateType = MinMaxState<BooleanType, SimdLevel>; + using ArrayType = typename TypeTraits<BooleanType>::ArrayType; + using MinMaxImpl<BooleanType, SimdLevel>::MinMaxImpl; + using MinMaxImpl<BooleanType, SimdLevel>::options; + Status Consume(KernelContext*, const ExecBatch& batch) override { if (ARROW_PREDICT_FALSE(batch[0].is_scalar())) { return ConsumeScalar(checked_cast<const BooleanScalar&>(*batch[0].scalar())); } - StateType local; - ArrayType arr(batch[0].array()); - - const auto arr_length = arr.length(); - const auto null_count = arr.null_count(); - const auto valid_count = arr_length - null_count; - - local.has_nulls = null_count > 0; - local.has_values = valid_count > 0; + StateType local; + ArrayType arr(batch[0].array()); + + const auto arr_length = arr.length(); + const auto null_count = arr.null_count(); + const auto valid_count = arr_length - null_count; + + local.has_nulls = null_count > 0; + local.has_values = valid_count > 0; if (local.has_nulls && !options.skip_nulls) { - this->state = local; + this->state = local; return Status::OK(); - } - - const auto true_count = arr.true_count(); - const auto false_count = valid_count - true_count; - local.max = true_count > 0; - local.min = false_count == 0; - - this->state = local; + } + + const auto true_count = arr.true_count(); + const auto false_count = valid_count - true_count; + local.max = true_count > 0; + local.min = false_count == 0; + + this->state = local; return Status::OK(); - } + } Status ConsumeScalar(const BooleanScalar& scalar) { StateType local; @@ -418,46 +418,46 @@ struct BooleanMinMaxImpl : public MinMaxImpl<BooleanType, SimdLevel> { this->state = local; return Status::OK(); } -}; - -template <SimdLevel::type SimdLevel> -struct MinMaxInitState { - std::unique_ptr<KernelState> state; - KernelContext* ctx; - const DataType& in_type; - const std::shared_ptr<DataType>& out_type; +}; + +template <SimdLevel::type SimdLevel> +struct MinMaxInitState { + std::unique_ptr<KernelState> state; + KernelContext* ctx; + const DataType& in_type; + const std::shared_ptr<DataType>& out_type; const ScalarAggregateOptions& options; - - MinMaxInitState(KernelContext* ctx, const DataType& in_type, + + MinMaxInitState(KernelContext* ctx, const DataType& in_type, const std::shared_ptr<DataType>& out_type, const ScalarAggregateOptions& options) - : ctx(ctx), in_type(in_type), out_type(out_type), options(options) {} - - Status Visit(const DataType&) { - return Status::NotImplemented("No min/max implemented"); - } - - Status Visit(const HalfFloatType&) { - return Status::NotImplemented("No min/max implemented"); - } - - Status Visit(const BooleanType&) { - state.reset(new BooleanMinMaxImpl<SimdLevel>(out_type, options)); - return Status::OK(); - } - - template <typename Type> - enable_if_number<Type, Status> Visit(const Type&) { - state.reset(new MinMaxImpl<Type, SimdLevel>(out_type, options)); - return Status::OK(); - } - + : ctx(ctx), in_type(in_type), out_type(out_type), options(options) {} + + Status Visit(const DataType&) { + return Status::NotImplemented("No min/max implemented"); + } + + Status Visit(const HalfFloatType&) { + return Status::NotImplemented("No min/max implemented"); + } + + Status Visit(const BooleanType&) { + state.reset(new BooleanMinMaxImpl<SimdLevel>(out_type, options)); + return Status::OK(); + } + + template <typename Type> + enable_if_number<Type, Status> Visit(const Type&) { + state.reset(new MinMaxImpl<Type, SimdLevel>(out_type, options)); + return Status::OK(); + } + Result<std::unique_ptr<KernelState>> Create() { RETURN_NOT_OK(VisitTypeInline(in_type, this)); - return std::move(state); - } -}; - -} // namespace aggregate -} // namespace compute -} // namespace arrow + return std::move(state); + } +}; + +} // namespace aggregate +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_internal.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_internal.h index ed29f26f2c..930242ac92 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_internal.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_internal.h @@ -1,54 +1,54 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include "arrow/type.h" -#include "arrow/type_traits.h" +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/type.h" +#include "arrow/type_traits.h" #include "arrow/util/bit_run_reader.h" #include "arrow/util/logging.h" - -namespace arrow { -namespace compute { - -// Find the largest compatible primitive type for a primitive type. -template <typename I, typename Enable = void> -struct FindAccumulatorType {}; - -template <typename I> -struct FindAccumulatorType<I, enable_if_boolean<I>> { - using Type = UInt64Type; -}; - -template <typename I> -struct FindAccumulatorType<I, enable_if_signed_integer<I>> { - using Type = Int64Type; -}; - -template <typename I> -struct FindAccumulatorType<I, enable_if_unsigned_integer<I>> { - using Type = UInt64Type; -}; - -template <typename I> -struct FindAccumulatorType<I, enable_if_floating_point<I>> { - using Type = DoubleType; -}; - + +namespace arrow { +namespace compute { + +// Find the largest compatible primitive type for a primitive type. +template <typename I, typename Enable = void> +struct FindAccumulatorType {}; + +template <typename I> +struct FindAccumulatorType<I, enable_if_boolean<I>> { + using Type = UInt64Type; +}; + +template <typename I> +struct FindAccumulatorType<I, enable_if_signed_integer<I>> { + using Type = Int64Type; +}; + +template <typename I> +struct FindAccumulatorType<I, enable_if_unsigned_integer<I>> { + using Type = UInt64Type; +}; + +template <typename I> +struct FindAccumulatorType<I, enable_if_floating_point<I>> { + using Type = DoubleType; +}; + struct ScalarAggregator : public KernelState { virtual Status Consume(KernelContext* ctx, const ExecBatch& batch) = 0; virtual Status MergeFrom(KernelContext* ctx, KernelState&& src) = 0; @@ -168,5 +168,5 @@ SumType SumArray(const ArrayData& data) { } // namespace detail -} // namespace compute -} // namespace arrow +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_mode.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_mode.cc index 6ad0eeb645..4d8f0fc42d 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_mode.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_mode.cc @@ -1,24 +1,24 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include <cmath> +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <cmath> #include <queue> #include <utility> - + #include "arrow/compute/api_aggregate.h" #include "arrow/compute/kernels/aggregate_internal.h" #include "arrow/compute/kernels/common.h" @@ -26,31 +26,31 @@ #include "arrow/result.h" #include "arrow/stl_allocator.h" #include "arrow/type_traits.h" - -namespace arrow { -namespace compute { + +namespace arrow { +namespace compute { namespace internal { - -namespace { - + +namespace { + using ModeState = OptionsWrapper<ModeOptions>; - + constexpr char kModeFieldName[] = "mode"; constexpr char kCountFieldName[] = "count"; - + constexpr uint64_t kCountEOF = ~0ULL; - + template <typename InType, typename CType = typename InType::c_type> Result<std::pair<CType*, int64_t*>> PrepareOutput(int64_t n, KernelContext* ctx, Datum* out) { const auto& mode_type = TypeTraits<InType>::type_singleton(); const auto& count_type = int64(); - + auto mode_data = ArrayData::Make(mode_type, /*length=*/n, /*null_count=*/0); mode_data->buffers.resize(2, nullptr); auto count_data = ArrayData::Make(count_type, n, 0); count_data->buffers.resize(2, nullptr); - + CType* mode_buffer = nullptr; int64_t* count_buffer = nullptr; @@ -59,28 +59,28 @@ Result<std::pair<CType*, int64_t*>> PrepareOutput(int64_t n, KernelContext* ctx, ARROW_ASSIGN_OR_RAISE(count_data->buffers[1], ctx->Allocate(n * sizeof(int64_t))); mode_buffer = mode_data->template GetMutableValues<CType>(1); count_buffer = count_data->template GetMutableValues<int64_t>(1); - } - + } + const auto& out_type = struct_({field(kModeFieldName, mode_type), field(kCountFieldName, count_type)}); *out = Datum(ArrayData::Make(out_type, n, {nullptr}, {mode_data, count_data}, 0)); return std::make_pair(mode_buffer, count_buffer); -} - +} + // find top-n value:count pairs with minimal heap // suboptimal for tiny or large n, possibly okay as we're not in hot path template <typename InType, typename Generator> Status Finalize(KernelContext* ctx, Datum* out, Generator&& gen) { using CType = typename InType::c_type; - + using ValueCountPair = std::pair<CType, uint64_t>; auto gt = [](const ValueCountPair& lhs, const ValueCountPair& rhs) { const bool rhs_is_nan = rhs.first != rhs.first; // nan as largest value return lhs.second > rhs.second || (lhs.second == rhs.second && (lhs.first < rhs.first || rhs_is_nan)); }; - + std::priority_queue<ValueCountPair, std::vector<ValueCountPair>, decltype(gt)> min_heap( std::move(gt)); @@ -94,10 +94,10 @@ Status Finalize(KernelContext* ctx, Datum* out, Generator&& gen) { } else if (gt(value_count, min_heap.top())) { min_heap.pop(); min_heap.push(value_count); - } - } + } + } const int64_t n = min_heap.size(); - + CType* mode_buffer; int64_t* count_buffer; ARROW_ASSIGN_OR_RAISE(std::tie(mode_buffer, count_buffer), @@ -109,29 +109,29 @@ Status Finalize(KernelContext* ctx, Datum* out, Generator&& gen) { } return Status::OK(); -} - +} + // count value occurances for integers with narrow value range // O(1) space, O(n) time template <typename T> struct CountModer { using CType = typename T::c_type; - + CType min; std::vector<uint64_t> counts; - + CountModer(CType min, CType max) { uint32_t value_range = static_cast<uint32_t>(max - min) + 1; DCHECK_LT(value_range, 1 << 20); this->min = min; this->counts.resize(value_range, 0); - } - + } + Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { // count values in all chunks, ignore nulls const Datum& datum = batch[0]; CountValues<CType>(this->counts.data(), datum, this->min); - + // generator to emit next value:count pair int index = 0; auto gen = [&]() { @@ -145,17 +145,17 @@ struct CountModer { } return std::pair<CType, uint64_t>(0, kCountEOF); }; - + return Finalize<T>(ctx, out, std::move(gen)); } }; - + // booleans can be handled more straightforward template <> struct CountModer<BooleanType> { Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { int64_t counts[2]{}; - + const Datum& datum = batch[0]; for (const auto& array : datum.chunks()) { if (array->length() > array->null_count()) { @@ -164,13 +164,13 @@ struct CountModer<BooleanType> { const int64_t false_count = array->length() - array->null_count() - true_count; counts[true] += true_count; counts[false] += false_count; - } - } - + } + } + const ModeOptions& options = ModeState::Get(ctx); const int64_t distinct_values = (counts[0] != 0) + (counts[1] != 0); const int64_t n = std::min(options.n, distinct_values); - + bool* mode_buffer; int64_t* count_buffer; ARROW_ASSIGN_OR_RAISE(std::tie(mode_buffer, count_buffer), @@ -183,31 +183,31 @@ struct CountModer<BooleanType> { if (n == 2) { mode_buffer[1] = !index; count_buffer[1] = counts[!index]; - } - } + } + } return Status::OK(); - } -}; - + } +}; + // copy and sort approach for floating points or integers with wide value range // O(n) space, O(nlogn) time template <typename T> struct SortModer { using CType = typename T::c_type; using Allocator = arrow::stl::allocator<CType>; - + Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { // copy all chunks to a buffer, ignore nulls and nans std::vector<CType, Allocator> in_buffer(Allocator(ctx->memory_pool())); - + uint64_t nan_count = 0; const Datum& datum = batch[0]; const int64_t in_length = datum.length() - datum.null_count(); if (in_length > 0) { in_buffer.resize(in_length); CopyNonNullValues(datum, in_buffer.data()); - + // drop nan if (is_floating_type<T>::value) { const auto& it = std::remove_if(in_buffer.begin(), in_buffer.end(), @@ -243,14 +243,14 @@ struct SortModer { }; return Finalize<T>(ctx, out, std::move(gen)); - } + } }; - + // pick counting or sorting approach per integers value range template <typename T> struct CountOrSortModer { using CType = typename T::c_type; - + Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { // cross point to benefit from counting approach // about 2x improvement for int32/64 from micro-benchmarking @@ -265,12 +265,12 @@ struct CountOrSortModer { if (static_cast<uint64_t>(max) - static_cast<uint64_t>(min) <= kMaxValueRange) { return CountModer<T>(min, max).Exec(ctx, batch, out); } - } + } return SortModer<T>().Exec(ctx, batch, out); - } + } }; - + template <typename InType, typename Enable = void> struct Moder; @@ -278,30 +278,30 @@ template <> struct Moder<Int8Type> { CountModer<Int8Type> impl; Moder() : impl(-128, 127) {} -}; - +}; + template <> struct Moder<UInt8Type> { CountModer<UInt8Type> impl; Moder() : impl(0, 255) {} }; - + template <> struct Moder<BooleanType> { CountModer<BooleanType> impl; }; - + template <typename InType> struct Moder<InType, enable_if_t<(is_integer_type<InType>::value && (sizeof(typename InType::c_type) > 1))>> { CountOrSortModer<InType> impl; }; - + template <typename InType> struct Moder<InType, enable_if_t<is_floating_type<InType>::value>> { SortModer<InType> impl; }; - + template <typename T> Status ScalarMode(KernelContext* ctx, const Scalar& scalar, Datum* out) { using CType = typename T::c_type; @@ -314,12 +314,12 @@ Status ScalarMode(KernelContext* ctx, const Scalar& scalar, Datum* out) { } return std::pair<CType, uint64_t>(static_cast<CType>(0), kCountEOF); }); - } + } return Finalize<T>(ctx, out, []() { return std::pair<CType, uint64_t>(static_cast<CType>(0), kCountEOF); }); } - + template <typename _, typename InType> struct ModeExecutor { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { @@ -336,9 +336,9 @@ struct ModeExecutor { } return Moder<InType>().impl.Exec(ctx, batch, out); - } -}; - + } +}; + VectorKernel NewModeKernel(const std::shared_ptr<DataType>& in_type) { VectorKernel kernel; kernel.init = ModeState::Init; @@ -349,8 +349,8 @@ VectorKernel NewModeKernel(const std::shared_ptr<DataType>& in_type) { kernel.signature = KernelSignature::Make({InputType(in_type)}, ValueDescr::Array(out_type)); return kernel; -} - +} + void AddBooleanModeKernel(VectorFunction* func) { VectorKernel kernel = NewModeKernel(boolean()); kernel.exec = ModeExecutor<StructType, BooleanType>::Exec; @@ -362,9 +362,9 @@ void AddNumericModeKernels(VectorFunction* func) { VectorKernel kernel = NewModeKernel(type); kernel.exec = GenerateNumeric<ModeExecutor, StructType>(*type); DCHECK_OK(func->AddKernel(kernel)); - } -} - + } +} + const FunctionDoc mode_doc{ "Calculate the modal (most common) values of a numeric array", ("Returns top-n most common values and number of times they occur in an array.\n" @@ -376,8 +376,8 @@ const FunctionDoc mode_doc{ {"array"}, "ModeOptions"}; -} // namespace - +} // namespace + void RegisterScalarAggregateMode(FunctionRegistry* registry) { static auto default_options = ModeOptions::Defaults(); auto func = std::make_shared<VectorFunction>("mode", Arity::Unary(), &mode_doc, @@ -385,8 +385,8 @@ void RegisterScalarAggregateMode(FunctionRegistry* registry) { AddBooleanModeKernel(func.get()); AddNumericModeKernels(func.get()); DCHECK_OK(registry->AddFunction(std::move(func))); -} - +} + } // namespace internal -} // namespace compute -} // namespace arrow +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_var_std.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_var_std.cc index d6965fed4a..82fc3a2752 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_var_std.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_var_std.cc @@ -1,70 +1,70 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + #include <cmath> - + #include "arrow/compute/api_aggregate.h" #include "arrow/compute/kernels/aggregate_internal.h" #include "arrow/compute/kernels/common.h" #include "arrow/util/bit_run_reader.h" #include "arrow/util/int128_internal.h" -namespace arrow { -namespace compute { +namespace arrow { +namespace compute { namespace internal { - -namespace { - + +namespace { + using arrow::internal::int128_t; using arrow::internal::VisitSetBitRunsVoid; -template <typename ArrowType> -struct VarStdState { - using ArrayType = typename TypeTraits<ArrowType>::ArrayType; +template <typename ArrowType> +struct VarStdState { + using ArrayType = typename TypeTraits<ArrowType>::ArrayType; using CType = typename ArrowType::c_type; - using ThisType = VarStdState<ArrowType>; - + using ThisType = VarStdState<ArrowType>; + // float/double/int64: calculate `m2` (sum((X-mean)^2)) with `two pass algorithm` - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm template <typename T = ArrowType> enable_if_t<is_floating_type<T>::value || (sizeof(CType) > 4)> Consume( const ArrayType& array) { - int64_t count = array.length() - array.null_count(); - if (count == 0) { - return; - } - + int64_t count = array.length() - array.null_count(); + if (count == 0) { + return; + } + using SumType = typename std::conditional<is_floating_type<T>::value, double, int128_t>::type; SumType sum = arrow::compute::detail::SumArray<CType, SumType, SimdLevel::NONE>(*array.data()); - + const double mean = static_cast<double>(sum) / count; const double m2 = arrow::compute::detail::SumArray<CType, double, SimdLevel::NONE>( *array.data(), [mean](CType value) { const double v = static_cast<double>(value); return (v - mean) * (v - mean); }); - - this->count = count; + + this->count = count; this->mean = mean; - this->m2 = m2; - } - + this->m2 = m2; + } + // int32/16/8: textbook one pass algorithm with integer arithmetic template <typename T = ArrowType> enable_if_t<is_integer_type<T>::value && (sizeof(CType) <= 4)> Consume( @@ -118,69 +118,69 @@ struct VarStdState { // Combine `m2` from two chunks (m2 = n*s2) // https://www.emathzone.com/tutorials/basic-statistics/combined-variance.html - void MergeFrom(const ThisType& state) { - if (state.count == 0) { - return; - } - if (this->count == 0) { - this->count = state.count; + void MergeFrom(const ThisType& state) { + if (state.count == 0) { + return; + } + if (this->count == 0) { + this->count = state.count; this->mean = state.mean; - this->m2 = state.m2; - return; - } + this->m2 = state.m2; + return; + } double mean = (this->mean * this->count + state.mean * state.count) / (this->count + state.count); this->m2 += state.m2 + this->count * (this->mean - mean) * (this->mean - mean) + state.count * (state.mean - mean) * (state.mean - mean); - this->count += state.count; + this->count += state.count; this->mean = mean; - } - - int64_t count = 0; + } + + int64_t count = 0; double mean = 0; double m2 = 0; // m2 = count*s2 = sum((X-mean)^2) -}; - -enum class VarOrStd : bool { Var, Std }; - -template <typename ArrowType> -struct VarStdImpl : public ScalarAggregator { - using ThisType = VarStdImpl<ArrowType>; - using ArrayType = typename TypeTraits<ArrowType>::ArrayType; - - explicit VarStdImpl(const std::shared_ptr<DataType>& out_type, - const VarianceOptions& options, VarOrStd return_type) - : out_type(out_type), options(options), return_type(return_type) {} - +}; + +enum class VarOrStd : bool { Var, Std }; + +template <typename ArrowType> +struct VarStdImpl : public ScalarAggregator { + using ThisType = VarStdImpl<ArrowType>; + using ArrayType = typename TypeTraits<ArrowType>::ArrayType; + + explicit VarStdImpl(const std::shared_ptr<DataType>& out_type, + const VarianceOptions& options, VarOrStd return_type) + : out_type(out_type), options(options), return_type(return_type) {} + Status Consume(KernelContext*, const ExecBatch& batch) override { - ArrayType array(batch[0].array()); - this->state.Consume(array); + ArrayType array(batch[0].array()); + this->state.Consume(array); return Status::OK(); - } - + } + Status MergeFrom(KernelContext*, KernelState&& src) override { - const auto& other = checked_cast<const ThisType&>(src); - this->state.MergeFrom(other.state); + const auto& other = checked_cast<const ThisType&>(src); + this->state.MergeFrom(other.state); return Status::OK(); - } - + } + Status Finalize(KernelContext*, Datum* out) override { - if (this->state.count <= options.ddof) { + if (this->state.count <= options.ddof) { out->value = std::make_shared<DoubleScalar>(); - } else { - double var = this->state.m2 / (this->state.count - options.ddof); - out->value = + } else { + double var = this->state.m2 / (this->state.count - options.ddof); + out->value = std::make_shared<DoubleScalar>(return_type == VarOrStd::Var ? var : sqrt(var)); - } + } return Status::OK(); - } - - std::shared_ptr<DataType> out_type; - VarStdState<ArrowType> state; - VarianceOptions options; - VarOrStd return_type; -}; - + } + + std::shared_ptr<DataType> out_type; + VarStdState<ArrowType> state; + VarianceOptions options; + VarOrStd return_type; +}; + struct ScalarVarStdImpl : public ScalarAggregator { explicit ScalarVarStdImpl(const VarianceOptions& options) : options(options), seen(false) {} @@ -209,77 +209,77 @@ struct ScalarVarStdImpl : public ScalarAggregator { bool seen; }; -struct VarStdInitState { - std::unique_ptr<KernelState> state; - KernelContext* ctx; - const DataType& in_type; - const std::shared_ptr<DataType>& out_type; - const VarianceOptions& options; - VarOrStd return_type; - - VarStdInitState(KernelContext* ctx, const DataType& in_type, - const std::shared_ptr<DataType>& out_type, - const VarianceOptions& options, VarOrStd return_type) - : ctx(ctx), - in_type(in_type), - out_type(out_type), - options(options), - return_type(return_type) {} - - Status Visit(const DataType&) { - return Status::NotImplemented("No variance/stddev implemented"); - } - - Status Visit(const HalfFloatType&) { - return Status::NotImplemented("No variance/stddev implemented"); - } - - template <typename Type> - enable_if_t<is_number_type<Type>::value, Status> Visit(const Type&) { - state.reset(new VarStdImpl<Type>(out_type, options, return_type)); - return Status::OK(); - } - +struct VarStdInitState { + std::unique_ptr<KernelState> state; + KernelContext* ctx; + const DataType& in_type; + const std::shared_ptr<DataType>& out_type; + const VarianceOptions& options; + VarOrStd return_type; + + VarStdInitState(KernelContext* ctx, const DataType& in_type, + const std::shared_ptr<DataType>& out_type, + const VarianceOptions& options, VarOrStd return_type) + : ctx(ctx), + in_type(in_type), + out_type(out_type), + options(options), + return_type(return_type) {} + + Status Visit(const DataType&) { + return Status::NotImplemented("No variance/stddev implemented"); + } + + Status Visit(const HalfFloatType&) { + return Status::NotImplemented("No variance/stddev implemented"); + } + + template <typename Type> + enable_if_t<is_number_type<Type>::value, Status> Visit(const Type&) { + state.reset(new VarStdImpl<Type>(out_type, options, return_type)); + return Status::OK(); + } + Result<std::unique_ptr<KernelState>> Create() { RETURN_NOT_OK(VisitTypeInline(in_type, this)); - return std::move(state); - } -}; - + return std::move(state); + } +}; + Result<std::unique_ptr<KernelState>> StddevInit(KernelContext* ctx, const KernelInitArgs& args) { - VarStdInitState visitor( - ctx, *args.inputs[0].type, args.kernel->signature->out_type().type(), - static_cast<const VarianceOptions&>(*args.options), VarOrStd::Std); - return visitor.Create(); -} - + VarStdInitState visitor( + ctx, *args.inputs[0].type, args.kernel->signature->out_type().type(), + static_cast<const VarianceOptions&>(*args.options), VarOrStd::Std); + return visitor.Create(); +} + Result<std::unique_ptr<KernelState>> VarianceInit(KernelContext* ctx, const KernelInitArgs& args) { - VarStdInitState visitor( - ctx, *args.inputs[0].type, args.kernel->signature->out_type().type(), - static_cast<const VarianceOptions&>(*args.options), VarOrStd::Var); - return visitor.Create(); -} - + VarStdInitState visitor( + ctx, *args.inputs[0].type, args.kernel->signature->out_type().type(), + static_cast<const VarianceOptions&>(*args.options), VarOrStd::Var); + return visitor.Create(); +} + Result<std::unique_ptr<KernelState>> ScalarVarStdInit(KernelContext* ctx, const KernelInitArgs& args) { return arrow::internal::make_unique<ScalarVarStdImpl>( static_cast<const VarianceOptions&>(*args.options)); } -void AddVarStdKernels(KernelInit init, - const std::vector<std::shared_ptr<DataType>>& types, - ScalarAggregateFunction* func) { - for (const auto& ty : types) { - auto sig = KernelSignature::Make({InputType::Array(ty)}, float64()); - AddAggKernel(std::move(sig), init, func); +void AddVarStdKernels(KernelInit init, + const std::vector<std::shared_ptr<DataType>>& types, + ScalarAggregateFunction* func) { + for (const auto& ty : types) { + auto sig = KernelSignature::Make({InputType::Array(ty)}, float64()); + AddAggKernel(std::move(sig), init, func); sig = KernelSignature::Make({InputType::Scalar(ty)}, float64()); AddAggKernel(std::move(sig), ScalarVarStdInit, func); - } -} - + } +} + const FunctionDoc stddev_doc{ "Calculate the standard deviation of a numeric array", ("The number of degrees of freedom can be controlled using VarianceOptions.\n" @@ -288,7 +288,7 @@ const FunctionDoc stddev_doc{ "to satisfy `ddof`, null is returned."), {"array"}, "VarianceOptions"}; - + const FunctionDoc variance_doc{ "Calculate the variance of a numeric array", ("The number of degrees of freedom can be controlled using VarianceOptions.\n" @@ -298,22 +298,22 @@ const FunctionDoc variance_doc{ {"array"}, "VarianceOptions"}; -std::shared_ptr<ScalarAggregateFunction> AddStddevAggKernels() { - static auto default_std_options = VarianceOptions::Defaults(); +std::shared_ptr<ScalarAggregateFunction> AddStddevAggKernels() { + static auto default_std_options = VarianceOptions::Defaults(); auto func = std::make_shared<ScalarAggregateFunction>( "stddev", Arity::Unary(), &stddev_doc, &default_std_options); AddVarStdKernels(StddevInit, NumericTypes(), func.get()); - return func; -} - -std::shared_ptr<ScalarAggregateFunction> AddVarianceAggKernels() { - static auto default_var_options = VarianceOptions::Defaults(); + return func; +} + +std::shared_ptr<ScalarAggregateFunction> AddVarianceAggKernels() { + static auto default_var_options = VarianceOptions::Defaults(); auto func = std::make_shared<ScalarAggregateFunction>( "variance", Arity::Unary(), &variance_doc, &default_var_options); AddVarStdKernels(VarianceInit, NumericTypes(), func.get()); - return func; -} - + return func; +} + } // namespace void RegisterScalarAggregateVariance(FunctionRegistry* registry) { @@ -322,5 +322,5 @@ void RegisterScalarAggregateVariance(FunctionRegistry* registry) { } } // namespace internal -} // namespace compute -} // namespace arrow +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/codegen_internal.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/codegen_internal.cc index bab8e7000c..aa342eec25 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -1,195 +1,195 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/compute/kernels/codegen_internal.h" - -#include <functional> -#include <memory> -#include <mutex> -#include <vector> - -#include "arrow/type_fwd.h" - -namespace arrow { -namespace compute { -namespace internal { - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/kernels/codegen_internal.h" + +#include <functional> +#include <memory> +#include <mutex> +#include <vector> + +#include "arrow/type_fwd.h" + +namespace arrow { +namespace compute { +namespace internal { + Status ExecFail(KernelContext* ctx, const ExecBatch& batch, Datum* out) { return Status::NotImplemented("This kernel is malformed"); -} - -ArrayKernelExec MakeFlippedBinaryExec(ArrayKernelExec exec) { - return [exec](KernelContext* ctx, const ExecBatch& batch, Datum* out) { - ExecBatch flipped_batch = batch; - std::swap(flipped_batch.values[0], flipped_batch.values[1]); +} + +ArrayKernelExec MakeFlippedBinaryExec(ArrayKernelExec exec) { + return [exec](KernelContext* ctx, const ExecBatch& batch, Datum* out) { + ExecBatch flipped_batch = batch; + std::swap(flipped_batch.values[0], flipped_batch.values[1]); return exec(ctx, flipped_batch, out); - }; -} - -std::vector<std::shared_ptr<DataType>> g_signed_int_types; -std::vector<std::shared_ptr<DataType>> g_unsigned_int_types; -std::vector<std::shared_ptr<DataType>> g_int_types; -std::vector<std::shared_ptr<DataType>> g_floating_types; -std::vector<std::shared_ptr<DataType>> g_numeric_types; -std::vector<std::shared_ptr<DataType>> g_base_binary_types; -std::vector<std::shared_ptr<DataType>> g_temporal_types; -std::vector<std::shared_ptr<DataType>> g_primitive_types; + }; +} + +std::vector<std::shared_ptr<DataType>> g_signed_int_types; +std::vector<std::shared_ptr<DataType>> g_unsigned_int_types; +std::vector<std::shared_ptr<DataType>> g_int_types; +std::vector<std::shared_ptr<DataType>> g_floating_types; +std::vector<std::shared_ptr<DataType>> g_numeric_types; +std::vector<std::shared_ptr<DataType>> g_base_binary_types; +std::vector<std::shared_ptr<DataType>> g_temporal_types; +std::vector<std::shared_ptr<DataType>> g_primitive_types; std::vector<Type::type> g_decimal_type_ids; -static std::once_flag codegen_static_initialized; - -template <typename T> -void Extend(const std::vector<T>& values, std::vector<T>* out) { - for (const auto& t : values) { - out->push_back(t); - } -} - -static void InitStaticData() { - // Signed int types - g_signed_int_types = {int8(), int16(), int32(), int64()}; - - // Unsigned int types - g_unsigned_int_types = {uint8(), uint16(), uint32(), uint64()}; - - // All int types - Extend(g_unsigned_int_types, &g_int_types); - Extend(g_signed_int_types, &g_int_types); - - // Floating point types - g_floating_types = {float32(), float64()}; - +static std::once_flag codegen_static_initialized; + +template <typename T> +void Extend(const std::vector<T>& values, std::vector<T>* out) { + for (const auto& t : values) { + out->push_back(t); + } +} + +static void InitStaticData() { + // Signed int types + g_signed_int_types = {int8(), int16(), int32(), int64()}; + + // Unsigned int types + g_unsigned_int_types = {uint8(), uint16(), uint32(), uint64()}; + + // All int types + Extend(g_unsigned_int_types, &g_int_types); + Extend(g_signed_int_types, &g_int_types); + + // Floating point types + g_floating_types = {float32(), float64()}; + // Decimal types g_decimal_type_ids = {Type::DECIMAL128, Type::DECIMAL256}; - // Numeric types - Extend(g_int_types, &g_numeric_types); - Extend(g_floating_types, &g_numeric_types); - - // Temporal types - g_temporal_types = {date32(), - date64(), - time32(TimeUnit::SECOND), - time32(TimeUnit::MILLI), - time64(TimeUnit::MICRO), - time64(TimeUnit::NANO), - timestamp(TimeUnit::SECOND), - timestamp(TimeUnit::MILLI), - timestamp(TimeUnit::MICRO), - timestamp(TimeUnit::NANO)}; - - // Base binary types (without FixedSizeBinary) - g_base_binary_types = {binary(), utf8(), large_binary(), large_utf8()}; - - // Non-parametric, non-nested types. This also DOES NOT include - // - // * Decimal - // * Fixed Size Binary - // * Time32 - // * Time64 - // * Timestamp - g_primitive_types = {null(), boolean(), date32(), date64()}; - Extend(g_numeric_types, &g_primitive_types); - Extend(g_base_binary_types, &g_primitive_types); -} - -const std::vector<std::shared_ptr<DataType>>& BaseBinaryTypes() { - std::call_once(codegen_static_initialized, InitStaticData); - return g_base_binary_types; -} - -const std::vector<std::shared_ptr<DataType>>& StringTypes() { - static DataTypeVector types = {utf8(), large_utf8()}; - return types; -} - -const std::vector<std::shared_ptr<DataType>>& SignedIntTypes() { - std::call_once(codegen_static_initialized, InitStaticData); - return g_signed_int_types; -} - -const std::vector<std::shared_ptr<DataType>>& UnsignedIntTypes() { - std::call_once(codegen_static_initialized, InitStaticData); - return g_unsigned_int_types; -} - -const std::vector<std::shared_ptr<DataType>>& IntTypes() { - std::call_once(codegen_static_initialized, InitStaticData); - return g_int_types; -} - -const std::vector<std::shared_ptr<DataType>>& FloatingPointTypes() { - std::call_once(codegen_static_initialized, InitStaticData); - return g_floating_types; -} - + // Numeric types + Extend(g_int_types, &g_numeric_types); + Extend(g_floating_types, &g_numeric_types); + + // Temporal types + g_temporal_types = {date32(), + date64(), + time32(TimeUnit::SECOND), + time32(TimeUnit::MILLI), + time64(TimeUnit::MICRO), + time64(TimeUnit::NANO), + timestamp(TimeUnit::SECOND), + timestamp(TimeUnit::MILLI), + timestamp(TimeUnit::MICRO), + timestamp(TimeUnit::NANO)}; + + // Base binary types (without FixedSizeBinary) + g_base_binary_types = {binary(), utf8(), large_binary(), large_utf8()}; + + // Non-parametric, non-nested types. This also DOES NOT include + // + // * Decimal + // * Fixed Size Binary + // * Time32 + // * Time64 + // * Timestamp + g_primitive_types = {null(), boolean(), date32(), date64()}; + Extend(g_numeric_types, &g_primitive_types); + Extend(g_base_binary_types, &g_primitive_types); +} + +const std::vector<std::shared_ptr<DataType>>& BaseBinaryTypes() { + std::call_once(codegen_static_initialized, InitStaticData); + return g_base_binary_types; +} + +const std::vector<std::shared_ptr<DataType>>& StringTypes() { + static DataTypeVector types = {utf8(), large_utf8()}; + return types; +} + +const std::vector<std::shared_ptr<DataType>>& SignedIntTypes() { + std::call_once(codegen_static_initialized, InitStaticData); + return g_signed_int_types; +} + +const std::vector<std::shared_ptr<DataType>>& UnsignedIntTypes() { + std::call_once(codegen_static_initialized, InitStaticData); + return g_unsigned_int_types; +} + +const std::vector<std::shared_ptr<DataType>>& IntTypes() { + std::call_once(codegen_static_initialized, InitStaticData); + return g_int_types; +} + +const std::vector<std::shared_ptr<DataType>>& FloatingPointTypes() { + std::call_once(codegen_static_initialized, InitStaticData); + return g_floating_types; +} + const std::vector<Type::type>& DecimalTypeIds() { std::call_once(codegen_static_initialized, InitStaticData); return g_decimal_type_ids; } -const std::vector<TimeUnit::type>& AllTimeUnits() { - static std::vector<TimeUnit::type> units = {TimeUnit::SECOND, TimeUnit::MILLI, - TimeUnit::MICRO, TimeUnit::NANO}; - return units; -} - -const std::vector<std::shared_ptr<DataType>>& NumericTypes() { - std::call_once(codegen_static_initialized, InitStaticData); - return g_numeric_types; -} - -const std::vector<std::shared_ptr<DataType>>& TemporalTypes() { - std::call_once(codegen_static_initialized, InitStaticData); - return g_temporal_types; -} - -const std::vector<std::shared_ptr<DataType>>& PrimitiveTypes() { - std::call_once(codegen_static_initialized, InitStaticData); - return g_primitive_types; -} - -const std::vector<std::shared_ptr<DataType>>& ExampleParametricTypes() { - static DataTypeVector example_parametric_types = { +const std::vector<TimeUnit::type>& AllTimeUnits() { + static std::vector<TimeUnit::type> units = {TimeUnit::SECOND, TimeUnit::MILLI, + TimeUnit::MICRO, TimeUnit::NANO}; + return units; +} + +const std::vector<std::shared_ptr<DataType>>& NumericTypes() { + std::call_once(codegen_static_initialized, InitStaticData); + return g_numeric_types; +} + +const std::vector<std::shared_ptr<DataType>>& TemporalTypes() { + std::call_once(codegen_static_initialized, InitStaticData); + return g_temporal_types; +} + +const std::vector<std::shared_ptr<DataType>>& PrimitiveTypes() { + std::call_once(codegen_static_initialized, InitStaticData); + return g_primitive_types; +} + +const std::vector<std::shared_ptr<DataType>>& ExampleParametricTypes() { + static DataTypeVector example_parametric_types = { decimal128(12, 2), - duration(TimeUnit::SECOND), - timestamp(TimeUnit::SECOND), - time32(TimeUnit::SECOND), - time64(TimeUnit::MICRO), - fixed_size_binary(0), - list(null()), - large_list(null()), - fixed_size_list(field("dummy", null()), 0), - struct_({}), - sparse_union(FieldVector{}), - dense_union(FieldVector{}), - dictionary(int32(), null()), - map(null(), null())}; - return example_parametric_types; -} - -// Construct dummy parametric types so that we can get VisitTypeInline to -// work above - -Result<ValueDescr> FirstType(KernelContext*, const std::vector<ValueDescr>& descrs) { + duration(TimeUnit::SECOND), + timestamp(TimeUnit::SECOND), + time32(TimeUnit::SECOND), + time64(TimeUnit::MICRO), + fixed_size_binary(0), + list(null()), + large_list(null()), + fixed_size_list(field("dummy", null()), 0), + struct_({}), + sparse_union(FieldVector{}), + dense_union(FieldVector{}), + dictionary(int32(), null()), + map(null(), null())}; + return example_parametric_types; +} + +// Construct dummy parametric types so that we can get VisitTypeInline to +// work above + +Result<ValueDescr> FirstType(KernelContext*, const std::vector<ValueDescr>& descrs) { ValueDescr result = descrs.front(); result.shape = GetBroadcastShape(descrs); return result; -} - +} + void EnsureDictionaryDecoded(std::vector<ValueDescr>* descrs) { for (ValueDescr& descr : *descrs) { if (descr.type->id() == Type::DICTIONARY) { @@ -332,6 +332,6 @@ std::shared_ptr<DataType> CommonBinary(const std::vector<ValueDescr>& descrs) { return large_binary(); } -} // namespace internal -} // namespace compute -} // namespace arrow +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h index cb9b13bb3d..11a08a6ea9 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1,95 +1,95 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <cstdint> +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstdint> #include <cstring> -#include <memory> -#include <string> -#include <utility> -#include <vector> - -#include "arrow/array/builder_binary.h" -#include "arrow/array/data.h" -#include "arrow/buffer.h" -#include "arrow/buffer_builder.h" -#include "arrow/compute/exec.h" -#include "arrow/compute/kernel.h" -#include "arrow/datum.h" -#include "arrow/result.h" -#include "arrow/scalar.h" -#include "arrow/status.h" -#include "arrow/type.h" -#include "arrow/type_traits.h" -#include "arrow/util/bit_block_counter.h" -#include "arrow/util/bit_util.h" -#include "arrow/util/bitmap_generate.h" -#include "arrow/util/bitmap_reader.h" -#include "arrow/util/bitmap_writer.h" -#include "arrow/util/checked_cast.h" -#include "arrow/util/decimal.h" -#include "arrow/util/logging.h" -#include "arrow/util/macros.h" -#include "arrow/util/make_unique.h" -#include "arrow/util/optional.h" -#include "arrow/util/string_view.h" -#include "arrow/visitor_inline.h" - -namespace arrow { - -using internal::BinaryBitBlockCounter; -using internal::BitBlockCount; -using internal::BitmapReader; -using internal::checked_cast; -using internal::FirstTimeBitmapWriter; -using internal::GenerateBitsUnrolled; -using internal::VisitBitBlocksVoid; -using internal::VisitTwoBitBlocksVoid; - -namespace compute { -namespace internal { - -/// KernelState adapter for the common case of kernels whose only -/// state is an instance of a subclass of FunctionOptions. -/// Default FunctionOptions are *not* handled here. -template <typename OptionsType> -struct OptionsWrapper : public KernelState { - explicit OptionsWrapper(OptionsType options) : options(std::move(options)) {} - +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "arrow/array/builder_binary.h" +#include "arrow/array/data.h" +#include "arrow/buffer.h" +#include "arrow/buffer_builder.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/kernel.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/scalar.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_generate.h" +#include "arrow/util/bitmap_reader.h" +#include "arrow/util/bitmap_writer.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/decimal.h" +#include "arrow/util/logging.h" +#include "arrow/util/macros.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/optional.h" +#include "arrow/util/string_view.h" +#include "arrow/visitor_inline.h" + +namespace arrow { + +using internal::BinaryBitBlockCounter; +using internal::BitBlockCount; +using internal::BitmapReader; +using internal::checked_cast; +using internal::FirstTimeBitmapWriter; +using internal::GenerateBitsUnrolled; +using internal::VisitBitBlocksVoid; +using internal::VisitTwoBitBlocksVoid; + +namespace compute { +namespace internal { + +/// KernelState adapter for the common case of kernels whose only +/// state is an instance of a subclass of FunctionOptions. +/// Default FunctionOptions are *not* handled here. +template <typename OptionsType> +struct OptionsWrapper : public KernelState { + explicit OptionsWrapper(OptionsType options) : options(std::move(options)) {} + static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx, const KernelInitArgs& args) { - if (auto options = static_cast<const OptionsType*>(args.options)) { - return ::arrow::internal::make_unique<OptionsWrapper>(*options); - } - + if (auto options = static_cast<const OptionsType*>(args.options)) { + return ::arrow::internal::make_unique<OptionsWrapper>(*options); + } + return Status::Invalid( "Attempted to initialize KernelState from null FunctionOptions"); - } - - static const OptionsType& Get(const KernelState& state) { - return ::arrow::internal::checked_cast<const OptionsWrapper&>(state).options; - } - - static const OptionsType& Get(KernelContext* ctx) { return Get(*ctx->state()); } - - OptionsType options; -}; - + } + + static const OptionsType& Get(const KernelState& state) { + return ::arrow::internal::checked_cast<const OptionsWrapper&>(state).options; + } + + static const OptionsType& Get(KernelContext* ctx) { return Get(*ctx->state()); } + + OptionsType options; +}; + /// KernelState adapter for when the state is an instance constructed with the /// KernelContext and the FunctionOptions as argument template <typename StateType, typename OptionsType> @@ -118,41 +118,41 @@ struct KernelStateFromFunctionOptions : public KernelState { StateType state; }; -// ---------------------------------------------------------------------- -// Input and output value type definitions - -template <typename Type, typename Enable = void> -struct GetViewType; - -template <typename Type> -struct GetViewType<Type, enable_if_has_c_type<Type>> { - using T = typename Type::c_type; - using PhysicalType = T; - - static T LogicalValue(PhysicalType value) { return value; } -}; - -template <typename Type> -struct GetViewType<Type, enable_if_t<is_base_binary_type<Type>::value || - is_fixed_size_binary_type<Type>::value>> { - using T = util::string_view; - using PhysicalType = T; - - static T LogicalValue(PhysicalType value) { return value; } -}; - -template <> -struct GetViewType<Decimal128Type> { - using T = Decimal128; - using PhysicalType = util::string_view; - - static T LogicalValue(PhysicalType value) { - return Decimal128(reinterpret_cast<const uint8_t*>(value.data())); - } +// ---------------------------------------------------------------------- +// Input and output value type definitions + +template <typename Type, typename Enable = void> +struct GetViewType; + +template <typename Type> +struct GetViewType<Type, enable_if_has_c_type<Type>> { + using T = typename Type::c_type; + using PhysicalType = T; + + static T LogicalValue(PhysicalType value) { return value; } +}; + +template <typename Type> +struct GetViewType<Type, enable_if_t<is_base_binary_type<Type>::value || + is_fixed_size_binary_type<Type>::value>> { + using T = util::string_view; + using PhysicalType = T; + + static T LogicalValue(PhysicalType value) { return value; } +}; + +template <> +struct GetViewType<Decimal128Type> { + using T = Decimal128; + using PhysicalType = util::string_view; + + static T LogicalValue(PhysicalType value) { + return Decimal128(reinterpret_cast<const uint8_t*>(value.data())); + } static T LogicalValue(T value) { return value; } -}; - +}; + template <> struct GetViewType<Decimal256Type> { using T = Decimal256; @@ -165,88 +165,88 @@ struct GetViewType<Decimal256Type> { static T LogicalValue(T value) { return value; } }; -template <typename Type, typename Enable = void> -struct GetOutputType; - -template <typename Type> -struct GetOutputType<Type, enable_if_has_c_type<Type>> { - using T = typename Type::c_type; -}; - -template <typename Type> -struct GetOutputType<Type, enable_if_t<is_string_like_type<Type>::value>> { - using T = std::string; -}; - -template <> -struct GetOutputType<Decimal128Type> { - using T = Decimal128; -}; - +template <typename Type, typename Enable = void> +struct GetOutputType; + +template <typename Type> +struct GetOutputType<Type, enable_if_has_c_type<Type>> { + using T = typename Type::c_type; +}; + +template <typename Type> +struct GetOutputType<Type, enable_if_t<is_string_like_type<Type>::value>> { + using T = std::string; +}; + +template <> +struct GetOutputType<Decimal128Type> { + using T = Decimal128; +}; + template <> struct GetOutputType<Decimal256Type> { using T = Decimal256; }; -// ---------------------------------------------------------------------- -// Iteration / value access utilities - -template <typename T, typename R = void> -using enable_if_has_c_type_not_boolean = - enable_if_t<has_c_type<T>::value && !is_boolean_type<T>::value, R>; - -// Iterator over various input array types, yielding a GetViewType<Type> - -template <typename Type, typename Enable = void> -struct ArrayIterator; - -template <typename Type> -struct ArrayIterator<Type, enable_if_has_c_type_not_boolean<Type>> { - using T = typename Type::c_type; - const T* values; - - explicit ArrayIterator(const ArrayData& data) : values(data.GetValues<T>(1)) {} - T operator()() { return *values++; } -}; - -template <typename Type> -struct ArrayIterator<Type, enable_if_boolean<Type>> { - BitmapReader reader; - - explicit ArrayIterator(const ArrayData& data) - : reader(data.buffers[1]->data(), data.offset, data.length) {} - bool operator()() { - bool out = reader.IsSet(); - reader.Next(); - return out; - } -}; - -template <typename Type> -struct ArrayIterator<Type, enable_if_base_binary<Type>> { - using offset_type = typename Type::offset_type; - const ArrayData& arr; - const offset_type* offsets; - offset_type cur_offset; - const char* data; - int64_t position; - - explicit ArrayIterator(const ArrayData& arr) - : arr(arr), - offsets(reinterpret_cast<const offset_type*>(arr.buffers[1]->data()) + - arr.offset), - cur_offset(offsets[0]), - data(reinterpret_cast<const char*>(arr.buffers[2]->data())), - position(0) {} - - util::string_view operator()() { - offset_type next_offset = offsets[++position]; - auto result = util::string_view(data + cur_offset, next_offset - cur_offset); - cur_offset = next_offset; - return result; - } -}; - +// ---------------------------------------------------------------------- +// Iteration / value access utilities + +template <typename T, typename R = void> +using enable_if_has_c_type_not_boolean = + enable_if_t<has_c_type<T>::value && !is_boolean_type<T>::value, R>; + +// Iterator over various input array types, yielding a GetViewType<Type> + +template <typename Type, typename Enable = void> +struct ArrayIterator; + +template <typename Type> +struct ArrayIterator<Type, enable_if_has_c_type_not_boolean<Type>> { + using T = typename Type::c_type; + const T* values; + + explicit ArrayIterator(const ArrayData& data) : values(data.GetValues<T>(1)) {} + T operator()() { return *values++; } +}; + +template <typename Type> +struct ArrayIterator<Type, enable_if_boolean<Type>> { + BitmapReader reader; + + explicit ArrayIterator(const ArrayData& data) + : reader(data.buffers[1]->data(), data.offset, data.length) {} + bool operator()() { + bool out = reader.IsSet(); + reader.Next(); + return out; + } +}; + +template <typename Type> +struct ArrayIterator<Type, enable_if_base_binary<Type>> { + using offset_type = typename Type::offset_type; + const ArrayData& arr; + const offset_type* offsets; + offset_type cur_offset; + const char* data; + int64_t position; + + explicit ArrayIterator(const ArrayData& arr) + : arr(arr), + offsets(reinterpret_cast<const offset_type*>(arr.buffers[1]->data()) + + arr.offset), + cur_offset(offsets[0]), + data(reinterpret_cast<const char*>(arr.buffers[2]->data())), + position(0) {} + + util::string_view operator()() { + offset_type next_offset = offsets[++position]; + auto result = util::string_view(data + cur_offset, next_offset - cur_offset); + cur_offset = next_offset; + return result; + } +}; + template <typename Type> struct ArrayIterator<Type, enable_if_decimal<Type>> { using T = typename TypeTraits<Type>::ScalarType::ValueType; @@ -259,27 +259,27 @@ struct ArrayIterator<Type, enable_if_decimal<Type>> { T operator()() { return T{values++->data()}; } }; -// Iterator over various output array types, taking a GetOutputType<Type> - -template <typename Type, typename Enable = void> -struct OutputArrayWriter; - -template <typename Type> -struct OutputArrayWriter<Type, enable_if_has_c_type_not_boolean<Type>> { - using T = typename Type::c_type; - T* values; - - explicit OutputArrayWriter(ArrayData* data) : values(data->GetMutableValues<T>(1)) {} - - void Write(T value) { *values++ = value; } - - // Note that this doesn't write the null bitmap, which should be consistent - // with Write / WriteNull calls - void WriteNull() { *values++ = T{}; } +// Iterator over various output array types, taking a GetOutputType<Type> + +template <typename Type, typename Enable = void> +struct OutputArrayWriter; + +template <typename Type> +struct OutputArrayWriter<Type, enable_if_has_c_type_not_boolean<Type>> { + using T = typename Type::c_type; + T* values; + + explicit OutputArrayWriter(ArrayData* data) : values(data->GetMutableValues<T>(1)) {} + + void Write(T value) { *values++ = value; } + + // Note that this doesn't write the null bitmap, which should be consistent + // with Write / WriteNull calls + void WriteNull() { *values++ = T{}; } void WriteAllNull(int64_t length) { std::memset(values, 0, sizeof(T) * length); } -}; - +}; + template <typename Type> struct OutputArrayWriter<Type, enable_if_decimal<Type>> { using T = typename TypeTraits<Type>::ScalarType::ValueType; @@ -296,35 +296,35 @@ struct OutputArrayWriter<Type, enable_if_decimal<Type>> { void WriteAllNull(int64_t length) { std::memset(values, 0, sizeof(T) * length); } }; -// (Un)box Scalar to / from C++ value - -template <typename Type, typename Enable = void> -struct UnboxScalar; - -template <typename Type> -struct UnboxScalar<Type, enable_if_has_c_type<Type>> { - using T = typename Type::c_type; - static T Unbox(const Scalar& val) { - return *reinterpret_cast<const T*>( - checked_cast<const ::arrow::internal::PrimitiveScalarBase&>(val).data()); - } -}; - -template <typename Type> +// (Un)box Scalar to / from C++ value + +template <typename Type, typename Enable = void> +struct UnboxScalar; + +template <typename Type> +struct UnboxScalar<Type, enable_if_has_c_type<Type>> { + using T = typename Type::c_type; + static T Unbox(const Scalar& val) { + return *reinterpret_cast<const T*>( + checked_cast<const ::arrow::internal::PrimitiveScalarBase&>(val).data()); + } +}; + +template <typename Type> struct UnboxScalar<Type, enable_if_has_string_view<Type>> { - static util::string_view Unbox(const Scalar& val) { + static util::string_view Unbox(const Scalar& val) { if (!val.is_valid) return util::string_view(); - return util::string_view(*checked_cast<const BaseBinaryScalar&>(val).value); - } -}; - -template <> -struct UnboxScalar<Decimal128Type> { - static Decimal128 Unbox(const Scalar& val) { - return checked_cast<const Decimal128Scalar&>(val).value; - } -}; - + return util::string_view(*checked_cast<const BaseBinaryScalar&>(val).value); + } +}; + +template <> +struct UnboxScalar<Decimal128Type> { + static Decimal128 Unbox(const Scalar& val) { + return checked_cast<const Decimal128Scalar&>(val).value; + } +}; + template <> struct UnboxScalar<Decimal256Type> { static Decimal256 Unbox(const Scalar& val) { @@ -332,36 +332,36 @@ struct UnboxScalar<Decimal256Type> { } }; -template <typename Type, typename Enable = void> -struct BoxScalar; - -template <typename Type> -struct BoxScalar<Type, enable_if_has_c_type<Type>> { - using T = typename GetOutputType<Type>::T; +template <typename Type, typename Enable = void> +struct BoxScalar; + +template <typename Type> +struct BoxScalar<Type, enable_if_has_c_type<Type>> { + using T = typename GetOutputType<Type>::T; static void Box(T val, Scalar* out) { // Enables BoxScalar<Int64Type> to work on a (for example) Time64Scalar T* mutable_data = reinterpret_cast<T*>( checked_cast<::arrow::internal::PrimitiveScalarBase*>(out)->mutable_data()); *mutable_data = val; } -}; - -template <typename Type> -struct BoxScalar<Type, enable_if_base_binary<Type>> { - using T = typename GetOutputType<Type>::T; - using ScalarType = typename TypeTraits<Type>::ScalarType; - static void Box(T val, Scalar* out) { - checked_cast<ScalarType*>(out)->value = std::make_shared<Buffer>(val); - } -}; - -template <> -struct BoxScalar<Decimal128Type> { - using T = Decimal128; - using ScalarType = Decimal128Scalar; - static void Box(T val, Scalar* out) { checked_cast<ScalarType*>(out)->value = val; } -}; - +}; + +template <typename Type> +struct BoxScalar<Type, enable_if_base_binary<Type>> { + using T = typename GetOutputType<Type>::T; + using ScalarType = typename TypeTraits<Type>::ScalarType; + static void Box(T val, Scalar* out) { + checked_cast<ScalarType*>(out)->value = std::make_shared<Buffer>(val); + } +}; + +template <> +struct BoxScalar<Decimal128Type> { + using T = Decimal128; + using ScalarType = Decimal128Scalar; + static void Box(T val, Scalar* out) { checked_cast<ScalarType*>(out)->value = val; } +}; + template <> struct BoxScalar<Decimal256Type> { using T = Decimal256; @@ -369,21 +369,21 @@ struct BoxScalar<Decimal256Type> { static void Box(T val, Scalar* out) { checked_cast<ScalarType*>(out)->value = val; } }; -// A VisitArrayDataInline variant that calls its visitor function with logical -// values, such as Decimal128 rather than util::string_view. - -template <typename T, typename VisitFunc, typename NullFunc> +// A VisitArrayDataInline variant that calls its visitor function with logical +// values, such as Decimal128 rather than util::string_view. + +template <typename T, typename VisitFunc, typename NullFunc> static typename arrow::internal::call_traits::enable_if_return<VisitFunc, void>::type VisitArrayValuesInline(const ArrayData& arr, VisitFunc&& valid_func, NullFunc&& null_func) { - VisitArrayDataInline<T>( - arr, - [&](typename GetViewType<T>::PhysicalType v) { - valid_func(GetViewType<T>::LogicalValue(std::move(v))); - }, - std::forward<NullFunc>(null_func)); -} - + VisitArrayDataInline<T>( + arr, + [&](typename GetViewType<T>::PhysicalType v) { + valid_func(GetViewType<T>::LogicalValue(std::move(v))); + }, + std::forward<NullFunc>(null_func)); +} + template <typename T, typename VisitFunc, typename NullFunc> static typename arrow::internal::call_traits::enable_if_return<VisitFunc, Status>::type VisitArrayValuesInline(const ArrayData& arr, VisitFunc&& valid_func, @@ -396,110 +396,110 @@ VisitArrayValuesInline(const ArrayData& arr, VisitFunc&& valid_func, std::forward<NullFunc>(null_func)); } -// Like VisitArrayValuesInline, but for binary functions. - -template <typename Arg0Type, typename Arg1Type, typename VisitFunc, typename NullFunc> -static void VisitTwoArrayValuesInline(const ArrayData& arr0, const ArrayData& arr1, - VisitFunc&& valid_func, NullFunc&& null_func) { - ArrayIterator<Arg0Type> arr0_it(arr0); - ArrayIterator<Arg1Type> arr1_it(arr1); - - auto visit_valid = [&](int64_t i) { - valid_func(GetViewType<Arg0Type>::LogicalValue(arr0_it()), - GetViewType<Arg1Type>::LogicalValue(arr1_it())); - }; - auto visit_null = [&]() { - arr0_it(); - arr1_it(); - null_func(); - }; - VisitTwoBitBlocksVoid(arr0.buffers[0], arr0.offset, arr1.buffers[0], arr1.offset, - arr0.length, std::move(visit_valid), std::move(visit_null)); -} - -// ---------------------------------------------------------------------- -// Reusable type resolvers - -Result<ValueDescr> FirstType(KernelContext*, const std::vector<ValueDescr>& descrs); - -// ---------------------------------------------------------------------- -// Generate an array kernel given template classes - +// Like VisitArrayValuesInline, but for binary functions. + +template <typename Arg0Type, typename Arg1Type, typename VisitFunc, typename NullFunc> +static void VisitTwoArrayValuesInline(const ArrayData& arr0, const ArrayData& arr1, + VisitFunc&& valid_func, NullFunc&& null_func) { + ArrayIterator<Arg0Type> arr0_it(arr0); + ArrayIterator<Arg1Type> arr1_it(arr1); + + auto visit_valid = [&](int64_t i) { + valid_func(GetViewType<Arg0Type>::LogicalValue(arr0_it()), + GetViewType<Arg1Type>::LogicalValue(arr1_it())); + }; + auto visit_null = [&]() { + arr0_it(); + arr1_it(); + null_func(); + }; + VisitTwoBitBlocksVoid(arr0.buffers[0], arr0.offset, arr1.buffers[0], arr1.offset, + arr0.length, std::move(visit_valid), std::move(visit_null)); +} + +// ---------------------------------------------------------------------- +// Reusable type resolvers + +Result<ValueDescr> FirstType(KernelContext*, const std::vector<ValueDescr>& descrs); + +// ---------------------------------------------------------------------- +// Generate an array kernel given template classes + Status ExecFail(KernelContext* ctx, const ExecBatch& batch, Datum* out); - -ArrayKernelExec MakeFlippedBinaryExec(ArrayKernelExec exec); - -// ---------------------------------------------------------------------- -// Helpers for iterating over common DataType instances for adding kernels to -// functions - -const std::vector<std::shared_ptr<DataType>>& BaseBinaryTypes(); -const std::vector<std::shared_ptr<DataType>>& StringTypes(); -const std::vector<std::shared_ptr<DataType>>& SignedIntTypes(); -const std::vector<std::shared_ptr<DataType>>& UnsignedIntTypes(); -const std::vector<std::shared_ptr<DataType>>& IntTypes(); -const std::vector<std::shared_ptr<DataType>>& FloatingPointTypes(); + +ArrayKernelExec MakeFlippedBinaryExec(ArrayKernelExec exec); + +// ---------------------------------------------------------------------- +// Helpers for iterating over common DataType instances for adding kernels to +// functions + +const std::vector<std::shared_ptr<DataType>>& BaseBinaryTypes(); +const std::vector<std::shared_ptr<DataType>>& StringTypes(); +const std::vector<std::shared_ptr<DataType>>& SignedIntTypes(); +const std::vector<std::shared_ptr<DataType>>& UnsignedIntTypes(); +const std::vector<std::shared_ptr<DataType>>& IntTypes(); +const std::vector<std::shared_ptr<DataType>>& FloatingPointTypes(); const std::vector<Type::type>& DecimalTypeIds(); - -ARROW_EXPORT -const std::vector<TimeUnit::type>& AllTimeUnits(); - -// Returns a vector of example instances of parametric types such as -// -// * Decimal -// * Timestamp (requiring unit) -// * Time32 (requiring unit) -// * Time64 (requiring unit) -// * Duration (requiring unit) -// * List, LargeList, FixedSizeList -// * Struct -// * Union -// * Dictionary -// * Map -// -// Generally kernels will use the "FirstType" OutputType::Resolver above for -// the OutputType of the kernel's signature and match::SameTypeId for the -// corresponding InputType -const std::vector<std::shared_ptr<DataType>>& ExampleParametricTypes(); - -// Number types without boolean -const std::vector<std::shared_ptr<DataType>>& NumericTypes(); - -// Temporal types including time and timestamps for each unit -const std::vector<std::shared_ptr<DataType>>& TemporalTypes(); - -// Integer, floating point, base binary, and temporal -const std::vector<std::shared_ptr<DataType>>& PrimitiveTypes(); - -// ---------------------------------------------------------------------- -// "Applicators" take an operator definition (which may be scalar-valued or -// array-valued) and creates an ArrayKernelExec which can be used to add an -// ArrayKernel to a Function. - -namespace applicator { - -// Generate an ArrayKernelExec given a functor that handles all of its own -// iteration, etc. -// -// Operator must implement -// + +ARROW_EXPORT +const std::vector<TimeUnit::type>& AllTimeUnits(); + +// Returns a vector of example instances of parametric types such as +// +// * Decimal +// * Timestamp (requiring unit) +// * Time32 (requiring unit) +// * Time64 (requiring unit) +// * Duration (requiring unit) +// * List, LargeList, FixedSizeList +// * Struct +// * Union +// * Dictionary +// * Map +// +// Generally kernels will use the "FirstType" OutputType::Resolver above for +// the OutputType of the kernel's signature and match::SameTypeId for the +// corresponding InputType +const std::vector<std::shared_ptr<DataType>>& ExampleParametricTypes(); + +// Number types without boolean +const std::vector<std::shared_ptr<DataType>>& NumericTypes(); + +// Temporal types including time and timestamps for each unit +const std::vector<std::shared_ptr<DataType>>& TemporalTypes(); + +// Integer, floating point, base binary, and temporal +const std::vector<std::shared_ptr<DataType>>& PrimitiveTypes(); + +// ---------------------------------------------------------------------- +// "Applicators" take an operator definition (which may be scalar-valued or +// array-valued) and creates an ArrayKernelExec which can be used to add an +// ArrayKernel to a Function. + +namespace applicator { + +// Generate an ArrayKernelExec given a functor that handles all of its own +// iteration, etc. +// +// Operator must implement +// // static Status Call(KernelContext*, const ArrayData& in, ArrayData* out) // static Status Call(KernelContext*, const Scalar& in, Scalar* out) -template <typename Operator> +template <typename Operator> static Status SimpleUnary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - if (batch[0].kind() == Datum::SCALAR) { + if (batch[0].kind() == Datum::SCALAR) { return Operator::Call(ctx, *batch[0].scalar(), out->scalar().get()); - } else if (batch.length > 0) { + } else if (batch.length > 0) { return Operator::Call(ctx, *batch[0].array(), out->mutable_array()); - } + } return Status::OK(); -} - -// Generate an ArrayKernelExec given a functor that handles all of its own -// iteration, etc. -// -// Operator must implement -// +} + +// Generate an ArrayKernelExec given a functor that handles all of its own +// iteration, etc. +// +// Operator must implement +// // static Status Call(KernelContext*, const ArrayData& arg0, const ArrayData& arg1, // ArrayData* out) // static Status Call(KernelContext*, const ArrayData& arg0, const Scalar& arg1, @@ -508,7 +508,7 @@ static Status SimpleUnary(KernelContext* ctx, const ExecBatch& batch, Datum* out // ArrayData* out) // static Status Call(KernelContext*, const Scalar& arg0, const Scalar& arg1, // Scalar* out) -template <typename Operator> +template <typename Operator> static Status SimpleBinary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { if (batch.length == 0) return Status::OK(); @@ -528,50 +528,50 @@ static Status SimpleBinary(KernelContext* ctx, const ExecBatch& batch, Datum* ou return Operator::Call(ctx, *batch[0].scalar(), *batch[1].scalar(), out->scalar().get()); } - } -} - -// OutputAdapter allows passing an inlineable lambda that provides a sequence -// of output values to write into output memory. Boolean and primitive outputs -// are currently implemented, and the validity bitmap is presumed to be handled -// at a higher level, so this writes into every output slot, null or not. -template <typename Type, typename Enable = void> -struct OutputAdapter; - -template <typename Type> -struct OutputAdapter<Type, enable_if_boolean<Type>> { - template <typename Generator> + } +} + +// OutputAdapter allows passing an inlineable lambda that provides a sequence +// of output values to write into output memory. Boolean and primitive outputs +// are currently implemented, and the validity bitmap is presumed to be handled +// at a higher level, so this writes into every output slot, null or not. +template <typename Type, typename Enable = void> +struct OutputAdapter; + +template <typename Type> +struct OutputAdapter<Type, enable_if_boolean<Type>> { + template <typename Generator> static Status Write(KernelContext*, Datum* out, Generator&& generator) { - ArrayData* out_arr = out->mutable_array(); - auto out_bitmap = out_arr->buffers[1]->mutable_data(); - GenerateBitsUnrolled(out_bitmap, out_arr->offset, out_arr->length, - std::forward<Generator>(generator)); + ArrayData* out_arr = out->mutable_array(); + auto out_bitmap = out_arr->buffers[1]->mutable_data(); + GenerateBitsUnrolled(out_bitmap, out_arr->offset, out_arr->length, + std::forward<Generator>(generator)); return Status::OK(); - } -}; - -template <typename Type> -struct OutputAdapter<Type, enable_if_has_c_type_not_boolean<Type>> { - template <typename Generator> + } +}; + +template <typename Type> +struct OutputAdapter<Type, enable_if_has_c_type_not_boolean<Type>> { + template <typename Generator> static Status Write(KernelContext*, Datum* out, Generator&& generator) { - ArrayData* out_arr = out->mutable_array(); - auto out_data = out_arr->GetMutableValues<typename Type::c_type>(1); - // TODO: Is this as fast as a more explicitly inlined function? - for (int64_t i = 0; i < out_arr->length; ++i) { - *out_data++ = generator(); - } + ArrayData* out_arr = out->mutable_array(); + auto out_data = out_arr->GetMutableValues<typename Type::c_type>(1); + // TODO: Is this as fast as a more explicitly inlined function? + for (int64_t i = 0; i < out_arr->length; ++i) { + *out_data++ = generator(); + } return Status::OK(); - } -}; - -template <typename Type> -struct OutputAdapter<Type, enable_if_base_binary<Type>> { - template <typename Generator> + } +}; + +template <typename Type> +struct OutputAdapter<Type, enable_if_base_binary<Type>> { + template <typename Generator> static Status Write(KernelContext* ctx, Datum* out, Generator&& generator) { return Status::NotImplemented("NYI"); - } -}; - + } +}; + template <typename Type> struct OutputAdapter<Type, enable_if_decimal<Type>> { using T = typename TypeTraits<Type>::ScalarType::ValueType; @@ -588,578 +588,578 @@ struct OutputAdapter<Type, enable_if_decimal<Type>> { } }; -// A kernel exec generator for unary functions that addresses both array and -// scalar inputs and dispatches input iteration and output writing to other -// templates -// -// This template executes the operator even on the data behind null values, -// therefore it is generally only suitable for operators that are safe to apply -// even on the null slot values. -// -// The "Op" functor should have the form -// -// struct Op { -// template <typename OutValue, typename Arg0Value> +// A kernel exec generator for unary functions that addresses both array and +// scalar inputs and dispatches input iteration and output writing to other +// templates +// +// This template executes the operator even on the data behind null values, +// therefore it is generally only suitable for operators that are safe to apply +// even on the null slot values. +// +// The "Op" functor should have the form +// +// struct Op { +// template <typename OutValue, typename Arg0Value> // static OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) { -// // implementation +// // implementation // // NOTE: "status" should only populated with errors, // // leave it unmodified to indicate Status::OK() -// } -// }; -template <typename OutType, typename Arg0Type, typename Op> -struct ScalarUnary { - using OutValue = typename GetOutputType<OutType>::T; - using Arg0Value = typename GetViewType<Arg0Type>::T; - +// } +// }; +template <typename OutType, typename Arg0Type, typename Op> +struct ScalarUnary { + using OutValue = typename GetOutputType<OutType>::T; + using Arg0Value = typename GetViewType<Arg0Type>::T; + static Status ExecArray(KernelContext* ctx, const ArrayData& arg0, Datum* out) { Status st = Status::OK(); - ArrayIterator<Arg0Type> arg0_it(arg0); + ArrayIterator<Arg0Type> arg0_it(arg0); RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue { return Op::template Call<OutValue, Arg0Value>(ctx, arg0_it(), &st); })); return st; - } - + } + static Status ExecScalar(KernelContext* ctx, const Scalar& arg0, Datum* out) { Status st = Status::OK(); Scalar* out_scalar = out->scalar().get(); - if (arg0.is_valid) { - Arg0Value arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0); + if (arg0.is_valid) { + Arg0Value arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0); out_scalar->is_valid = true; BoxScalar<OutType>::Box(Op::template Call<OutValue, Arg0Value>(ctx, arg0_val, &st), out_scalar); - } else { + } else { out_scalar->is_valid = false; - } + } return st; - } - + } + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - if (batch[0].kind() == Datum::ARRAY) { + if (batch[0].kind() == Datum::ARRAY) { return ExecArray(ctx, *batch[0].array(), out); - } else { + } else { return ExecScalar(ctx, *batch[0].scalar(), out); - } - } -}; - -// An alternative to ScalarUnary that Applies a scalar operation with state on -// only the not-null values of a single array -template <typename OutType, typename Arg0Type, typename Op> -struct ScalarUnaryNotNullStateful { - using ThisType = ScalarUnaryNotNullStateful<OutType, Arg0Type, Op>; - using OutValue = typename GetOutputType<OutType>::T; - using Arg0Value = typename GetViewType<Arg0Type>::T; - - Op op; - explicit ScalarUnaryNotNullStateful(Op op) : op(std::move(op)) {} - - // NOTE: In ArrayExec<Type>, Type is really OutputType - - template <typename Type, typename Enable = void> - struct ArrayExec { + } + } +}; + +// An alternative to ScalarUnary that Applies a scalar operation with state on +// only the not-null values of a single array +template <typename OutType, typename Arg0Type, typename Op> +struct ScalarUnaryNotNullStateful { + using ThisType = ScalarUnaryNotNullStateful<OutType, Arg0Type, Op>; + using OutValue = typename GetOutputType<OutType>::T; + using Arg0Value = typename GetViewType<Arg0Type>::T; + + Op op; + explicit ScalarUnaryNotNullStateful(Op op) : op(std::move(op)) {} + + // NOTE: In ArrayExec<Type>, Type is really OutputType + + template <typename Type, typename Enable = void> + struct ArrayExec { static Status Exec(const ThisType& functor, KernelContext* ctx, const ExecBatch& batch, Datum* out) { - ARROW_LOG(FATAL) << "Missing ArrayExec specialization for output type " - << out->type(); + ARROW_LOG(FATAL) << "Missing ArrayExec specialization for output type " + << out->type(); return Status::NotImplemented("NYI"); - } - }; - - template <typename Type> - struct ArrayExec< - Type, enable_if_t<has_c_type<Type>::value && !is_boolean_type<Type>::value>> { + } + }; + + template <typename Type> + struct ArrayExec< + Type, enable_if_t<has_c_type<Type>::value && !is_boolean_type<Type>::value>> { static Status Exec(const ThisType& functor, KernelContext* ctx, const ArrayData& arg0, Datum* out) { Status st = Status::OK(); - ArrayData* out_arr = out->mutable_array(); - auto out_data = out_arr->GetMutableValues<OutValue>(1); - VisitArrayValuesInline<Arg0Type>( - arg0, - [&](Arg0Value v) { + ArrayData* out_arr = out->mutable_array(); + auto out_data = out_arr->GetMutableValues<OutValue>(1); + VisitArrayValuesInline<Arg0Type>( + arg0, + [&](Arg0Value v) { *out_data++ = functor.op.template Call<OutValue, Arg0Value>(ctx, v, &st); - }, - [&]() { - // null + }, + [&]() { + // null *out_data++ = OutValue{}; - }); + }); return st; - } - }; - - template <typename Type> - struct ArrayExec<Type, enable_if_base_binary<Type>> { + } + }; + + template <typename Type> + struct ArrayExec<Type, enable_if_base_binary<Type>> { static Status Exec(const ThisType& functor, KernelContext* ctx, const ArrayData& arg0, Datum* out) { - // NOTE: This code is not currently used by any kernels and has - // suboptimal performance because it's recomputing the validity bitmap - // that is already computed by the kernel execution layer. Consider - // writing a lower-level "output adapter" for base binary types. - typename TypeTraits<Type>::BuilderType builder; + // NOTE: This code is not currently used by any kernels and has + // suboptimal performance because it's recomputing the validity bitmap + // that is already computed by the kernel execution layer. Consider + // writing a lower-level "output adapter" for base binary types. + typename TypeTraits<Type>::BuilderType builder; Status st = Status::OK(); RETURN_NOT_OK(VisitArrayValuesInline<Arg0Type>( arg0, [&](Arg0Value v) { return builder.Append(functor.op.Call(ctx, v, &st)); }, [&]() { return builder.AppendNull(); })); if (st.ok()) { - std::shared_ptr<ArrayData> result; + std::shared_ptr<ArrayData> result; RETURN_NOT_OK(builder.FinishInternal(&result)); - out->value = std::move(result); - } + out->value = std::move(result); + } return st; - } - }; - - template <typename Type> - struct ArrayExec<Type, enable_if_t<is_boolean_type<Type>::value>> { + } + }; + + template <typename Type> + struct ArrayExec<Type, enable_if_t<is_boolean_type<Type>::value>> { static Status Exec(const ThisType& functor, KernelContext* ctx, const ArrayData& arg0, Datum* out) { Status st = Status::OK(); - ArrayData* out_arr = out->mutable_array(); - FirstTimeBitmapWriter out_writer(out_arr->buffers[1]->mutable_data(), - out_arr->offset, out_arr->length); - VisitArrayValuesInline<Arg0Type>( - arg0, - [&](Arg0Value v) { + ArrayData* out_arr = out->mutable_array(); + FirstTimeBitmapWriter out_writer(out_arr->buffers[1]->mutable_data(), + out_arr->offset, out_arr->length); + VisitArrayValuesInline<Arg0Type>( + arg0, + [&](Arg0Value v) { if (functor.op.template Call<OutValue, Arg0Value>(ctx, v, &st)) { - out_writer.Set(); - } - out_writer.Next(); - }, - [&]() { - // null - out_writer.Clear(); - out_writer.Next(); - }); - out_writer.Finish(); + out_writer.Set(); + } + out_writer.Next(); + }, + [&]() { + // null + out_writer.Clear(); + out_writer.Next(); + }); + out_writer.Finish(); return st; - } - }; - - template <typename Type> + } + }; + + template <typename Type> struct ArrayExec<Type, enable_if_decimal<Type>> { static Status Exec(const ThisType& functor, KernelContext* ctx, const ArrayData& arg0, Datum* out) { Status st = Status::OK(); - ArrayData* out_arr = out->mutable_array(); + ArrayData* out_arr = out->mutable_array(); // Decimal128 data buffers are not safely reinterpret_cast-able on big-endian using endian_agnostic = std::array<uint8_t, sizeof(typename TypeTraits<Type>::ScalarType::ValueType)>; auto out_data = out_arr->GetMutableValues<endian_agnostic>(1); - VisitArrayValuesInline<Arg0Type>( - arg0, - [&](Arg0Value v) { + VisitArrayValuesInline<Arg0Type>( + arg0, + [&](Arg0Value v) { functor.op.template Call<OutValue, Arg0Value>(ctx, v, &st) .ToBytes(out_data++->data()); - }, + }, [&]() { // null std::memset(out_data, 0, sizeof(*out_data)); ++out_data; }); return st; - } - }; - + } + }; + Status Scalar(KernelContext* ctx, const Scalar& arg0, Datum* out) { Status st = Status::OK(); - if (arg0.is_valid) { - Arg0Value arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0); + if (arg0.is_valid) { + Arg0Value arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0); BoxScalar<OutType>::Box( this->op.template Call<OutValue, Arg0Value>(ctx, arg0_val, &st), out->scalar().get()); - } + } return st; - } - + } + Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - if (batch[0].kind() == Datum::ARRAY) { + if (batch[0].kind() == Datum::ARRAY) { return ArrayExec<OutType>::Exec(*this, ctx, *batch[0].array(), out); - } else { - return Scalar(ctx, *batch[0].scalar(), out); - } - } -}; - -// An alternative to ScalarUnary that Applies a scalar operation on only the -// not-null values of a single array. The operator is not stateful; if the -// operator requires some initialization use ScalarUnaryNotNullStateful -template <typename OutType, typename Arg0Type, typename Op> -struct ScalarUnaryNotNull { - using OutValue = typename GetOutputType<OutType>::T; - using Arg0Value = typename GetViewType<Arg0Type>::T; - + } else { + return Scalar(ctx, *batch[0].scalar(), out); + } + } +}; + +// An alternative to ScalarUnary that Applies a scalar operation on only the +// not-null values of a single array. The operator is not stateful; if the +// operator requires some initialization use ScalarUnaryNotNullStateful +template <typename OutType, typename Arg0Type, typename Op> +struct ScalarUnaryNotNull { + using OutValue = typename GetOutputType<OutType>::T; + using Arg0Value = typename GetViewType<Arg0Type>::T; + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // Seed kernel with dummy state - ScalarUnaryNotNullStateful<OutType, Arg0Type, Op> kernel({}); - return kernel.Exec(ctx, batch, out); - } -}; - -// A kernel exec generator for binary functions that addresses both array and -// scalar inputs and dispatches input iteration and output writing to other -// templates -// -// This template executes the operator even on the data behind null values, -// therefore it is generally only suitable for operators that are safe to apply -// even on the null slot values. -// -// The "Op" functor should have the form -// -// struct Op { -// template <typename OutValue, typename Arg0Value, typename Arg1Value> + // Seed kernel with dummy state + ScalarUnaryNotNullStateful<OutType, Arg0Type, Op> kernel({}); + return kernel.Exec(ctx, batch, out); + } +}; + +// A kernel exec generator for binary functions that addresses both array and +// scalar inputs and dispatches input iteration and output writing to other +// templates +// +// This template executes the operator even on the data behind null values, +// therefore it is generally only suitable for operators that are safe to apply +// even on the null slot values. +// +// The "Op" functor should have the form +// +// struct Op { +// template <typename OutValue, typename Arg0Value, typename Arg1Value> // static OutValue Call(KernelContext* ctx, Arg0Value arg0, Arg1Value arg1, Status* st) // { -// // implementation +// // implementation // // NOTE: "status" should only populated with errors, // // leave it unmodified to indicate Status::OK() -// } -// }; -template <typename OutType, typename Arg0Type, typename Arg1Type, typename Op> -struct ScalarBinary { - using OutValue = typename GetOutputType<OutType>::T; - using Arg0Value = typename GetViewType<Arg0Type>::T; - using Arg1Value = typename GetViewType<Arg1Type>::T; - +// } +// }; +template <typename OutType, typename Arg0Type, typename Arg1Type, typename Op> +struct ScalarBinary { + using OutValue = typename GetOutputType<OutType>::T; + using Arg0Value = typename GetViewType<Arg0Type>::T; + using Arg1Value = typename GetViewType<Arg1Type>::T; + static Status ArrayArray(KernelContext* ctx, const ArrayData& arg0, const ArrayData& arg1, Datum* out) { Status st = Status::OK(); - ArrayIterator<Arg0Type> arg0_it(arg0); - ArrayIterator<Arg1Type> arg1_it(arg1); + ArrayIterator<Arg0Type> arg0_it(arg0); + ArrayIterator<Arg1Type> arg1_it(arg1); RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue { return Op::template Call<OutValue, Arg0Value, Arg1Value>(ctx, arg0_it(), arg1_it(), &st); })); return st; - } - + } + static Status ArrayScalar(KernelContext* ctx, const ArrayData& arg0, const Scalar& arg1, Datum* out) { Status st = Status::OK(); - ArrayIterator<Arg0Type> arg0_it(arg0); - auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1); + ArrayIterator<Arg0Type> arg0_it(arg0); + auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1); RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue { return Op::template Call<OutValue, Arg0Value, Arg1Value>(ctx, arg0_it(), arg1_val, &st); })); return st; - } - + } + static Status ScalarArray(KernelContext* ctx, const Scalar& arg0, const ArrayData& arg1, Datum* out) { Status st = Status::OK(); - auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0); - ArrayIterator<Arg1Type> arg1_it(arg1); + auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0); + ArrayIterator<Arg1Type> arg1_it(arg1); RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue { return Op::template Call<OutValue, Arg0Value, Arg1Value>(ctx, arg0_val, arg1_it(), &st); })); return st; - } - + } + static Status ScalarScalar(KernelContext* ctx, const Scalar& arg0, const Scalar& arg1, Datum* out) { Status st = Status::OK(); - if (out->scalar()->is_valid) { - auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0); - auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1); + if (out->scalar()->is_valid) { + auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0); + auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1); BoxScalar<OutType>::Box( Op::template Call<OutValue, Arg0Value, Arg1Value>(ctx, arg0_val, arg1_val, &st), out->scalar().get()); - } + } return st; - } - + } + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - if (batch[0].kind() == Datum::ARRAY) { - if (batch[1].kind() == Datum::ARRAY) { - return ArrayArray(ctx, *batch[0].array(), *batch[1].array(), out); - } else { - return ArrayScalar(ctx, *batch[0].array(), *batch[1].scalar(), out); - } - } else { - if (batch[1].kind() == Datum::ARRAY) { - return ScalarArray(ctx, *batch[0].scalar(), *batch[1].array(), out); - } else { - return ScalarScalar(ctx, *batch[0].scalar(), *batch[1].scalar(), out); - } - } - } -}; - -// An alternative to ScalarBinary that Applies a scalar operation with state on -// only the value pairs which are not-null in both arrays -template <typename OutType, typename Arg0Type, typename Arg1Type, typename Op> -struct ScalarBinaryNotNullStateful { - using ThisType = ScalarBinaryNotNullStateful<OutType, Arg0Type, Arg1Type, Op>; - using OutValue = typename GetOutputType<OutType>::T; - using Arg0Value = typename GetViewType<Arg0Type>::T; - using Arg1Value = typename GetViewType<Arg1Type>::T; - - Op op; - explicit ScalarBinaryNotNullStateful(Op op) : op(std::move(op)) {} - - // NOTE: In ArrayExec<Type>, Type is really OutputType - + if (batch[0].kind() == Datum::ARRAY) { + if (batch[1].kind() == Datum::ARRAY) { + return ArrayArray(ctx, *batch[0].array(), *batch[1].array(), out); + } else { + return ArrayScalar(ctx, *batch[0].array(), *batch[1].scalar(), out); + } + } else { + if (batch[1].kind() == Datum::ARRAY) { + return ScalarArray(ctx, *batch[0].scalar(), *batch[1].array(), out); + } else { + return ScalarScalar(ctx, *batch[0].scalar(), *batch[1].scalar(), out); + } + } + } +}; + +// An alternative to ScalarBinary that Applies a scalar operation with state on +// only the value pairs which are not-null in both arrays +template <typename OutType, typename Arg0Type, typename Arg1Type, typename Op> +struct ScalarBinaryNotNullStateful { + using ThisType = ScalarBinaryNotNullStateful<OutType, Arg0Type, Arg1Type, Op>; + using OutValue = typename GetOutputType<OutType>::T; + using Arg0Value = typename GetViewType<Arg0Type>::T; + using Arg1Value = typename GetViewType<Arg1Type>::T; + + Op op; + explicit ScalarBinaryNotNullStateful(Op op) : op(std::move(op)) {} + + // NOTE: In ArrayExec<Type>, Type is really OutputType + Status ArrayArray(KernelContext* ctx, const ArrayData& arg0, const ArrayData& arg1, Datum* out) { Status st = Status::OK(); - OutputArrayWriter<OutType> writer(out->mutable_array()); - VisitTwoArrayValuesInline<Arg0Type, Arg1Type>( - arg0, arg1, - [&](Arg0Value u, Arg1Value v) { + OutputArrayWriter<OutType> writer(out->mutable_array()); + VisitTwoArrayValuesInline<Arg0Type, Arg1Type>( + arg0, arg1, + [&](Arg0Value u, Arg1Value v) { writer.Write(op.template Call<OutValue, Arg0Value, Arg1Value>(ctx, u, v, &st)); - }, - [&]() { writer.WriteNull(); }); + }, + [&]() { writer.WriteNull(); }); return st; - } - + } + Status ArrayScalar(KernelContext* ctx, const ArrayData& arg0, const Scalar& arg1, Datum* out) { Status st = Status::OK(); - OutputArrayWriter<OutType> writer(out->mutable_array()); - if (arg1.is_valid) { - const auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1); - VisitArrayValuesInline<Arg0Type>( - arg0, - [&](Arg0Value u) { - writer.Write( + OutputArrayWriter<OutType> writer(out->mutable_array()); + if (arg1.is_valid) { + const auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1); + VisitArrayValuesInline<Arg0Type>( + arg0, + [&](Arg0Value u) { + writer.Write( op.template Call<OutValue, Arg0Value, Arg1Value>(ctx, u, arg1_val, &st)); - }, - [&]() { writer.WriteNull(); }); + }, + [&]() { writer.WriteNull(); }); } else { writer.WriteAllNull(out->mutable_array()->length); - } + } return st; - } - + } + Status ScalarArray(KernelContext* ctx, const Scalar& arg0, const ArrayData& arg1, Datum* out) { Status st = Status::OK(); - OutputArrayWriter<OutType> writer(out->mutable_array()); - if (arg0.is_valid) { - const auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0); - VisitArrayValuesInline<Arg1Type>( - arg1, - [&](Arg1Value v) { - writer.Write( + OutputArrayWriter<OutType> writer(out->mutable_array()); + if (arg0.is_valid) { + const auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0); + VisitArrayValuesInline<Arg1Type>( + arg1, + [&](Arg1Value v) { + writer.Write( op.template Call<OutValue, Arg0Value, Arg1Value>(ctx, arg0_val, v, &st)); - }, - [&]() { writer.WriteNull(); }); + }, + [&]() { writer.WriteNull(); }); } else { writer.WriteAllNull(out->mutable_array()->length); - } + } return st; - } - + } + Status ScalarScalar(KernelContext* ctx, const Scalar& arg0, const Scalar& arg1, Datum* out) { Status st = Status::OK(); - if (arg0.is_valid && arg1.is_valid) { - const auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0); - const auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1); - BoxScalar<OutType>::Box( + if (arg0.is_valid && arg1.is_valid) { + const auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0); + const auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1); + BoxScalar<OutType>::Box( op.template Call<OutValue, Arg0Value, Arg1Value>(ctx, arg0_val, arg1_val, &st), - out->scalar().get()); - } + out->scalar().get()); + } return st; - } - + } + Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - if (batch[0].kind() == Datum::ARRAY) { - if (batch[1].kind() == Datum::ARRAY) { - return ArrayArray(ctx, *batch[0].array(), *batch[1].array(), out); - } else { - return ArrayScalar(ctx, *batch[0].array(), *batch[1].scalar(), out); - } - } else { - if (batch[1].kind() == Datum::ARRAY) { - return ScalarArray(ctx, *batch[0].scalar(), *batch[1].array(), out); - } else { - return ScalarScalar(ctx, *batch[0].scalar(), *batch[1].scalar(), out); - } - } - } -}; - -// An alternative to ScalarBinary that Applies a scalar operation on only -// the value pairs which are not-null in both arrays. -// The operator is not stateful; if the operator requires some initialization -// use ScalarBinaryNotNullStateful. -template <typename OutType, typename Arg0Type, typename Arg1Type, typename Op> -struct ScalarBinaryNotNull { - using OutValue = typename GetOutputType<OutType>::T; - using Arg0Value = typename GetViewType<Arg0Type>::T; - using Arg1Value = typename GetViewType<Arg1Type>::T; - + if (batch[0].kind() == Datum::ARRAY) { + if (batch[1].kind() == Datum::ARRAY) { + return ArrayArray(ctx, *batch[0].array(), *batch[1].array(), out); + } else { + return ArrayScalar(ctx, *batch[0].array(), *batch[1].scalar(), out); + } + } else { + if (batch[1].kind() == Datum::ARRAY) { + return ScalarArray(ctx, *batch[0].scalar(), *batch[1].array(), out); + } else { + return ScalarScalar(ctx, *batch[0].scalar(), *batch[1].scalar(), out); + } + } + } +}; + +// An alternative to ScalarBinary that Applies a scalar operation on only +// the value pairs which are not-null in both arrays. +// The operator is not stateful; if the operator requires some initialization +// use ScalarBinaryNotNullStateful. +template <typename OutType, typename Arg0Type, typename Arg1Type, typename Op> +struct ScalarBinaryNotNull { + using OutValue = typename GetOutputType<OutType>::T; + using Arg0Value = typename GetViewType<Arg0Type>::T; + using Arg1Value = typename GetViewType<Arg1Type>::T; + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // Seed kernel with dummy state - ScalarBinaryNotNullStateful<OutType, Arg0Type, Arg1Type, Op> kernel({}); - return kernel.Exec(ctx, batch, out); - } -}; - -// A kernel exec generator for binary kernels where both input types are the -// same -template <typename OutType, typename ArgType, typename Op> -using ScalarBinaryEqualTypes = ScalarBinary<OutType, ArgType, ArgType, Op>; - -// A kernel exec generator for non-null binary kernels where both input types are the -// same -template <typename OutType, typename ArgType, typename Op> -using ScalarBinaryNotNullEqualTypes = ScalarBinaryNotNull<OutType, ArgType, ArgType, Op>; - -} // namespace applicator - -// ---------------------------------------------------------------------- -// BEGIN of kernel generator-dispatchers ("GD") -// -// These GD functions instantiate kernel functor templates and select one of -// the instantiated kernels dynamically based on the data type or Type::type id -// that is passed. This enables functions to be populated with kernels by -// looping over vectors of data types rather than using macros or other -// approaches. -// -// The kernel functor must be of the form: -// -// template <typename Type0, typename Type1, Args...> -// struct FUNCTOR { -// static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { -// // IMPLEMENTATION -// } -// }; -// -// When you pass FUNCTOR to a GD function, you must pass at least one static -// type along with the functor -- this is often the fixed return type of the -// functor. This Type0 argument is passed as the first argument to the functor -// during instantiation. The 2nd type passed to the functor is the DataType -// subclass corresponding to the type passed as argument (not template type) to -// the function. -// -// For example, GenerateNumeric<FUNCTOR, Type0>(int32()) will select a kernel -// instantiated like FUNCTOR<Type0, Int32Type>. Any additional variadic -// template arguments will be passed as additional template arguments to the -// kernel template. - -namespace detail { - -// Convenience so we can pass DataType or Type::type for the GD's -struct GetTypeId { - Type::type id; - GetTypeId(const std::shared_ptr<DataType>& type) // NOLINT implicit construction - : id(type->id()) {} - GetTypeId(const DataType& type) // NOLINT implicit construction - : id(type.id()) {} - GetTypeId(Type::type id) // NOLINT implicit construction - : id(id) {} -}; - -} // namespace detail - -// GD for numeric types (integer and floating point) -template <template <typename...> class Generator, typename Type0, typename... Args> -ArrayKernelExec GenerateNumeric(detail::GetTypeId get_id) { - switch (get_id.id) { - case Type::INT8: - return Generator<Type0, Int8Type, Args...>::Exec; - case Type::UINT8: - return Generator<Type0, UInt8Type, Args...>::Exec; - case Type::INT16: - return Generator<Type0, Int16Type, Args...>::Exec; - case Type::UINT16: - return Generator<Type0, UInt16Type, Args...>::Exec; - case Type::INT32: - return Generator<Type0, Int32Type, Args...>::Exec; - case Type::UINT32: - return Generator<Type0, UInt32Type, Args...>::Exec; - case Type::INT64: - return Generator<Type0, Int64Type, Args...>::Exec; - case Type::UINT64: - return Generator<Type0, UInt64Type, Args...>::Exec; - case Type::FLOAT: - return Generator<Type0, FloatType, Args...>::Exec; - case Type::DOUBLE: - return Generator<Type0, DoubleType, Args...>::Exec; - default: - DCHECK(false); - return ExecFail; - } -} - -// Generate a kernel given a templated functor for floating point types -// -// See "Numeric" above for description of the generator functor -template <template <typename...> class Generator, typename Type0, typename... Args> -ArrayKernelExec GenerateFloatingPoint(detail::GetTypeId get_id) { - switch (get_id.id) { - case Type::FLOAT: - return Generator<Type0, FloatType, Args...>::Exec; - case Type::DOUBLE: - return Generator<Type0, DoubleType, Args...>::Exec; - default: - DCHECK(false); - return ExecFail; - } -} - -// Generate a kernel given a templated functor for integer types -// -// See "Numeric" above for description of the generator functor -template <template <typename...> class Generator, typename Type0, typename... Args> -ArrayKernelExec GenerateInteger(detail::GetTypeId get_id) { - switch (get_id.id) { - case Type::INT8: - return Generator<Type0, Int8Type, Args...>::Exec; - case Type::INT16: - return Generator<Type0, Int16Type, Args...>::Exec; - case Type::INT32: - return Generator<Type0, Int32Type, Args...>::Exec; - case Type::INT64: - return Generator<Type0, Int64Type, Args...>::Exec; - case Type::UINT8: - return Generator<Type0, UInt8Type, Args...>::Exec; - case Type::UINT16: - return Generator<Type0, UInt16Type, Args...>::Exec; - case Type::UINT32: - return Generator<Type0, UInt32Type, Args...>::Exec; - case Type::UINT64: - return Generator<Type0, UInt64Type, Args...>::Exec; - default: - DCHECK(false); - return ExecFail; - } -} - -template <template <typename...> class Generator, typename Type0, typename... Args> -ArrayKernelExec GeneratePhysicalInteger(detail::GetTypeId get_id) { - switch (get_id.id) { - case Type::INT8: - return Generator<Type0, Int8Type, Args...>::Exec; - case Type::INT16: - return Generator<Type0, Int16Type, Args...>::Exec; - case Type::INT32: - case Type::DATE32: - case Type::TIME32: - return Generator<Type0, Int32Type, Args...>::Exec; - case Type::INT64: - case Type::DATE64: - case Type::TIMESTAMP: - case Type::TIME64: - case Type::DURATION: - return Generator<Type0, Int64Type, Args...>::Exec; - case Type::UINT8: - return Generator<Type0, UInt8Type, Args...>::Exec; - case Type::UINT16: - return Generator<Type0, UInt16Type, Args...>::Exec; - case Type::UINT32: - return Generator<Type0, UInt32Type, Args...>::Exec; - case Type::UINT64: - return Generator<Type0, UInt64Type, Args...>::Exec; - default: - DCHECK(false); - return ExecFail; - } -} - + // Seed kernel with dummy state + ScalarBinaryNotNullStateful<OutType, Arg0Type, Arg1Type, Op> kernel({}); + return kernel.Exec(ctx, batch, out); + } +}; + +// A kernel exec generator for binary kernels where both input types are the +// same +template <typename OutType, typename ArgType, typename Op> +using ScalarBinaryEqualTypes = ScalarBinary<OutType, ArgType, ArgType, Op>; + +// A kernel exec generator for non-null binary kernels where both input types are the +// same +template <typename OutType, typename ArgType, typename Op> +using ScalarBinaryNotNullEqualTypes = ScalarBinaryNotNull<OutType, ArgType, ArgType, Op>; + +} // namespace applicator + +// ---------------------------------------------------------------------- +// BEGIN of kernel generator-dispatchers ("GD") +// +// These GD functions instantiate kernel functor templates and select one of +// the instantiated kernels dynamically based on the data type or Type::type id +// that is passed. This enables functions to be populated with kernels by +// looping over vectors of data types rather than using macros or other +// approaches. +// +// The kernel functor must be of the form: +// +// template <typename Type0, typename Type1, Args...> +// struct FUNCTOR { +// static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { +// // IMPLEMENTATION +// } +// }; +// +// When you pass FUNCTOR to a GD function, you must pass at least one static +// type along with the functor -- this is often the fixed return type of the +// functor. This Type0 argument is passed as the first argument to the functor +// during instantiation. The 2nd type passed to the functor is the DataType +// subclass corresponding to the type passed as argument (not template type) to +// the function. +// +// For example, GenerateNumeric<FUNCTOR, Type0>(int32()) will select a kernel +// instantiated like FUNCTOR<Type0, Int32Type>. Any additional variadic +// template arguments will be passed as additional template arguments to the +// kernel template. + +namespace detail { + +// Convenience so we can pass DataType or Type::type for the GD's +struct GetTypeId { + Type::type id; + GetTypeId(const std::shared_ptr<DataType>& type) // NOLINT implicit construction + : id(type->id()) {} + GetTypeId(const DataType& type) // NOLINT implicit construction + : id(type.id()) {} + GetTypeId(Type::type id) // NOLINT implicit construction + : id(id) {} +}; + +} // namespace detail + +// GD for numeric types (integer and floating point) +template <template <typename...> class Generator, typename Type0, typename... Args> +ArrayKernelExec GenerateNumeric(detail::GetTypeId get_id) { + switch (get_id.id) { + case Type::INT8: + return Generator<Type0, Int8Type, Args...>::Exec; + case Type::UINT8: + return Generator<Type0, UInt8Type, Args...>::Exec; + case Type::INT16: + return Generator<Type0, Int16Type, Args...>::Exec; + case Type::UINT16: + return Generator<Type0, UInt16Type, Args...>::Exec; + case Type::INT32: + return Generator<Type0, Int32Type, Args...>::Exec; + case Type::UINT32: + return Generator<Type0, UInt32Type, Args...>::Exec; + case Type::INT64: + return Generator<Type0, Int64Type, Args...>::Exec; + case Type::UINT64: + return Generator<Type0, UInt64Type, Args...>::Exec; + case Type::FLOAT: + return Generator<Type0, FloatType, Args...>::Exec; + case Type::DOUBLE: + return Generator<Type0, DoubleType, Args...>::Exec; + default: + DCHECK(false); + return ExecFail; + } +} + +// Generate a kernel given a templated functor for floating point types +// +// See "Numeric" above for description of the generator functor +template <template <typename...> class Generator, typename Type0, typename... Args> +ArrayKernelExec GenerateFloatingPoint(detail::GetTypeId get_id) { + switch (get_id.id) { + case Type::FLOAT: + return Generator<Type0, FloatType, Args...>::Exec; + case Type::DOUBLE: + return Generator<Type0, DoubleType, Args...>::Exec; + default: + DCHECK(false); + return ExecFail; + } +} + +// Generate a kernel given a templated functor for integer types +// +// See "Numeric" above for description of the generator functor +template <template <typename...> class Generator, typename Type0, typename... Args> +ArrayKernelExec GenerateInteger(detail::GetTypeId get_id) { + switch (get_id.id) { + case Type::INT8: + return Generator<Type0, Int8Type, Args...>::Exec; + case Type::INT16: + return Generator<Type0, Int16Type, Args...>::Exec; + case Type::INT32: + return Generator<Type0, Int32Type, Args...>::Exec; + case Type::INT64: + return Generator<Type0, Int64Type, Args...>::Exec; + case Type::UINT8: + return Generator<Type0, UInt8Type, Args...>::Exec; + case Type::UINT16: + return Generator<Type0, UInt16Type, Args...>::Exec; + case Type::UINT32: + return Generator<Type0, UInt32Type, Args...>::Exec; + case Type::UINT64: + return Generator<Type0, UInt64Type, Args...>::Exec; + default: + DCHECK(false); + return ExecFail; + } +} + +template <template <typename...> class Generator, typename Type0, typename... Args> +ArrayKernelExec GeneratePhysicalInteger(detail::GetTypeId get_id) { + switch (get_id.id) { + case Type::INT8: + return Generator<Type0, Int8Type, Args...>::Exec; + case Type::INT16: + return Generator<Type0, Int16Type, Args...>::Exec; + case Type::INT32: + case Type::DATE32: + case Type::TIME32: + return Generator<Type0, Int32Type, Args...>::Exec; + case Type::INT64: + case Type::DATE64: + case Type::TIMESTAMP: + case Type::TIME64: + case Type::DURATION: + return Generator<Type0, Int64Type, Args...>::Exec; + case Type::UINT8: + return Generator<Type0, UInt8Type, Args...>::Exec; + case Type::UINT16: + return Generator<Type0, UInt16Type, Args...>::Exec; + case Type::UINT32: + return Generator<Type0, UInt32Type, Args...>::Exec; + case Type::UINT64: + return Generator<Type0, UInt64Type, Args...>::Exec; + default: + DCHECK(false); + return ExecFail; + } +} + template <template <typename... Args> class Generator, typename... Args> ArrayKernelExec GeneratePhysicalNumeric(detail::GetTypeId get_id) { switch (get_id.id) { @@ -1195,68 +1195,68 @@ ArrayKernelExec GeneratePhysicalNumeric(detail::GetTypeId get_id) { } } -// Generate a kernel given a templated functor for integer types -// -// See "Numeric" above for description of the generator functor -template <template <typename...> class Generator, typename Type0, typename... Args> -ArrayKernelExec GenerateSignedInteger(detail::GetTypeId get_id) { - switch (get_id.id) { - case Type::INT8: - return Generator<Type0, Int8Type, Args...>::Exec; - case Type::INT16: - return Generator<Type0, Int16Type, Args...>::Exec; - case Type::INT32: - return Generator<Type0, Int32Type, Args...>::Exec; - case Type::INT64: - return Generator<Type0, Int64Type, Args...>::Exec; - default: - DCHECK(false); - return ExecFail; - } -} - -// Generate a kernel given a templated functor. Only a single template is -// instantiated for each bit width, and the functor is expected to treat types -// of the same bit width the same without utilizing any type-specific behavior -// (e.g. int64 should be handled equivalent to uint64 or double -- all 64 -// bits). -// -// See "Numeric" above for description of the generator functor +// Generate a kernel given a templated functor for integer types +// +// See "Numeric" above for description of the generator functor +template <template <typename...> class Generator, typename Type0, typename... Args> +ArrayKernelExec GenerateSignedInteger(detail::GetTypeId get_id) { + switch (get_id.id) { + case Type::INT8: + return Generator<Type0, Int8Type, Args...>::Exec; + case Type::INT16: + return Generator<Type0, Int16Type, Args...>::Exec; + case Type::INT32: + return Generator<Type0, Int32Type, Args...>::Exec; + case Type::INT64: + return Generator<Type0, Int64Type, Args...>::Exec; + default: + DCHECK(false); + return ExecFail; + } +} + +// Generate a kernel given a templated functor. Only a single template is +// instantiated for each bit width, and the functor is expected to treat types +// of the same bit width the same without utilizing any type-specific behavior +// (e.g. int64 should be handled equivalent to uint64 or double -- all 64 +// bits). +// +// See "Numeric" above for description of the generator functor template <template <typename...> class Generator, typename... Args> -ArrayKernelExec GenerateTypeAgnosticPrimitive(detail::GetTypeId get_id) { - switch (get_id.id) { - case Type::NA: +ArrayKernelExec GenerateTypeAgnosticPrimitive(detail::GetTypeId get_id) { + switch (get_id.id) { + case Type::NA: return Generator<NullType, Args...>::Exec; - case Type::BOOL: + case Type::BOOL: return Generator<BooleanType, Args...>::Exec; - case Type::UINT8: - case Type::INT8: + case Type::UINT8: + case Type::INT8: return Generator<UInt8Type, Args...>::Exec; - case Type::UINT16: - case Type::INT16: + case Type::UINT16: + case Type::INT16: return Generator<UInt16Type, Args...>::Exec; - case Type::UINT32: - case Type::INT32: - case Type::FLOAT: - case Type::DATE32: - case Type::TIME32: + case Type::UINT32: + case Type::INT32: + case Type::FLOAT: + case Type::DATE32: + case Type::TIME32: case Type::INTERVAL_MONTHS: return Generator<UInt32Type, Args...>::Exec; - case Type::UINT64: - case Type::INT64: - case Type::DOUBLE: - case Type::DATE64: - case Type::TIMESTAMP: - case Type::TIME64: - case Type::DURATION: + case Type::UINT64: + case Type::INT64: + case Type::DOUBLE: + case Type::DATE64: + case Type::TIMESTAMP: + case Type::TIME64: + case Type::DURATION: case Type::INTERVAL_DAY_TIME: return Generator<UInt64Type, Args...>::Exec; - default: - DCHECK(false); - return ExecFail; - } -} - + default: + DCHECK(false); + return ExecFail; + } +} + // similar to GenerateTypeAgnosticPrimitive, but for variable types template <template <typename...> class Generator, typename... Args> ArrayKernelExec GenerateTypeAgnosticVarBinaryBase(detail::GetTypeId get_id) { @@ -1273,69 +1273,69 @@ ArrayKernelExec GenerateTypeAgnosticVarBinaryBase(detail::GetTypeId get_id) { } } -// Generate a kernel given a templated functor for base binary types. Generates -// a single kernel for binary/string and large binary / large string. If your -// kernel implementation needs access to the specific type at compile time, -// please use BaseBinarySpecific. -// -// See "Numeric" above for description of the generator functor -template <template <typename...> class Generator, typename Type0, typename... Args> -ArrayKernelExec GenerateVarBinaryBase(detail::GetTypeId get_id) { - switch (get_id.id) { - case Type::BINARY: - case Type::STRING: - return Generator<Type0, BinaryType, Args...>::Exec; - case Type::LARGE_BINARY: - case Type::LARGE_STRING: - return Generator<Type0, LargeBinaryType, Args...>::Exec; - default: - DCHECK(false); - return ExecFail; - } -} - -// See BaseBinary documentation -template <template <typename...> class Generator, typename Type0, typename... Args> -ArrayKernelExec GenerateVarBinary(detail::GetTypeId get_id) { - switch (get_id.id) { - case Type::BINARY: - return Generator<Type0, BinaryType, Args...>::Exec; - case Type::STRING: - return Generator<Type0, StringType, Args...>::Exec; - case Type::LARGE_BINARY: - return Generator<Type0, LargeBinaryType, Args...>::Exec; - case Type::LARGE_STRING: - return Generator<Type0, LargeStringType, Args...>::Exec; - default: - DCHECK(false); - return ExecFail; - } -} - -// Generate a kernel given a templated functor for temporal types -// -// See "Numeric" above for description of the generator functor -template <template <typename...> class Generator, typename Type0, typename... Args> -ArrayKernelExec GenerateTemporal(detail::GetTypeId get_id) { - switch (get_id.id) { - case Type::DATE32: - return Generator<Type0, Date32Type, Args...>::Exec; - case Type::DATE64: - return Generator<Type0, Date64Type, Args...>::Exec; - case Type::DURATION: - return Generator<Type0, DurationType, Args...>::Exec; - case Type::TIME32: - return Generator<Type0, Time32Type, Args...>::Exec; - case Type::TIME64: - return Generator<Type0, Time64Type, Args...>::Exec; - case Type::TIMESTAMP: - return Generator<Type0, TimestampType, Args...>::Exec; - default: - DCHECK(false); - return ExecFail; - } -} - +// Generate a kernel given a templated functor for base binary types. Generates +// a single kernel for binary/string and large binary / large string. If your +// kernel implementation needs access to the specific type at compile time, +// please use BaseBinarySpecific. +// +// See "Numeric" above for description of the generator functor +template <template <typename...> class Generator, typename Type0, typename... Args> +ArrayKernelExec GenerateVarBinaryBase(detail::GetTypeId get_id) { + switch (get_id.id) { + case Type::BINARY: + case Type::STRING: + return Generator<Type0, BinaryType, Args...>::Exec; + case Type::LARGE_BINARY: + case Type::LARGE_STRING: + return Generator<Type0, LargeBinaryType, Args...>::Exec; + default: + DCHECK(false); + return ExecFail; + } +} + +// See BaseBinary documentation +template <template <typename...> class Generator, typename Type0, typename... Args> +ArrayKernelExec GenerateVarBinary(detail::GetTypeId get_id) { + switch (get_id.id) { + case Type::BINARY: + return Generator<Type0, BinaryType, Args...>::Exec; + case Type::STRING: + return Generator<Type0, StringType, Args...>::Exec; + case Type::LARGE_BINARY: + return Generator<Type0, LargeBinaryType, Args...>::Exec; + case Type::LARGE_STRING: + return Generator<Type0, LargeStringType, Args...>::Exec; + default: + DCHECK(false); + return ExecFail; + } +} + +// Generate a kernel given a templated functor for temporal types +// +// See "Numeric" above for description of the generator functor +template <template <typename...> class Generator, typename Type0, typename... Args> +ArrayKernelExec GenerateTemporal(detail::GetTypeId get_id) { + switch (get_id.id) { + case Type::DATE32: + return Generator<Type0, Date32Type, Args...>::Exec; + case Type::DATE64: + return Generator<Type0, Date64Type, Args...>::Exec; + case Type::DURATION: + return Generator<Type0, DurationType, Args...>::Exec; + case Type::TIME32: + return Generator<Type0, Time32Type, Args...>::Exec; + case Type::TIME64: + return Generator<Type0, Time64Type, Args...>::Exec; + case Type::TIMESTAMP: + return Generator<Type0, TimestampType, Args...>::Exec; + default: + DCHECK(false); + return ExecFail; + } +} + // Generate a kernel given a templated functor for decimal types // // See "Numeric" above for description of the generator functor @@ -1352,9 +1352,9 @@ ArrayKernelExec GenerateDecimal(detail::GetTypeId get_id) { } } -// END of kernel generator-dispatchers -// ---------------------------------------------------------------------- - +// END of kernel generator-dispatchers +// ---------------------------------------------------------------------- + ARROW_EXPORT void EnsureDictionaryDecoded(std::vector<ValueDescr>* descrs); @@ -1376,6 +1376,6 @@ std::shared_ptr<DataType> CommonTimestamp(const std::vector<ValueDescr>& descrs) ARROW_EXPORT std::shared_ptr<DataType> CommonBinary(const std::vector<ValueDescr>& descrs); -} // namespace internal -} // namespace compute -} // namespace arrow +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/common.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/common.h index 21244320f3..9ee2ec977a 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/common.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/common.h @@ -1,54 +1,54 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -// IWYU pragma: begin_exports - -#include <cstdint> -#include <memory> -#include <string> -#include <type_traits> -#include <utility> -#include <vector> - -#include "arrow/array/data.h" -#include "arrow/buffer.h" -#include "arrow/chunked_array.h" -#include "arrow/compute/exec.h" -#include "arrow/compute/function.h" -#include "arrow/compute/kernel.h" -#include "arrow/compute/kernels/codegen_internal.h" -#include "arrow/compute/registry.h" -#include "arrow/datum.h" -#include "arrow/memory_pool.h" -#include "arrow/status.h" -#include "arrow/type.h" -#include "arrow/type_traits.h" -#include "arrow/util/checked_cast.h" -#include "arrow/util/logging.h" -#include "arrow/util/macros.h" -#include "arrow/util/string_view.h" - -// IWYU pragma: end_exports - -namespace arrow { - -using internal::checked_cast; -using internal::checked_pointer_cast; - -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +// IWYU pragma: begin_exports + +#include <cstdint> +#include <memory> +#include <string> +#include <type_traits> +#include <utility> +#include <vector> + +#include "arrow/array/data.h" +#include "arrow/buffer.h" +#include "arrow/chunked_array.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/function.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/codegen_internal.h" +#include "arrow/compute/registry.h" +#include "arrow/datum.h" +#include "arrow/memory_pool.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/logging.h" +#include "arrow/util/macros.h" +#include "arrow/util/string_view.h" + +// IWYU pragma: end_exports + +namespace arrow { + +using internal::checked_cast; +using internal::checked_pointer_cast; + +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index a5d4a55774..f05cc0f3d3 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -1,83 +1,83 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + #include <algorithm> #include <cmath> #include <limits> #include <utility> #include "arrow/compute/kernels/codegen_internal.h" -#include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/common.h" #include "arrow/compute/kernels/util_internal.h" #include "arrow/type.h" #include "arrow/type_traits.h" #include "arrow/util/decimal.h" -#include "arrow/util/int_util_internal.h" -#include "arrow/util/macros.h" - -namespace arrow { - -using internal::AddWithOverflow; -using internal::DivideWithOverflow; -using internal::MultiplyWithOverflow; +#include "arrow/util/int_util_internal.h" +#include "arrow/util/macros.h" + +namespace arrow { + +using internal::AddWithOverflow; +using internal::DivideWithOverflow; +using internal::MultiplyWithOverflow; using internal::NegateWithOverflow; -using internal::SubtractWithOverflow; - -namespace compute { -namespace internal { - -using applicator::ScalarBinaryEqualTypes; -using applicator::ScalarBinaryNotNullEqualTypes; +using internal::SubtractWithOverflow; + +namespace compute { +namespace internal { + +using applicator::ScalarBinaryEqualTypes; +using applicator::ScalarBinaryNotNullEqualTypes; using applicator::ScalarUnary; using applicator::ScalarUnaryNotNull; - -namespace { - -template <typename T> -using is_unsigned_integer = std::integral_constant<bool, std::is_integral<T>::value && - std::is_unsigned<T>::value>; - -template <typename T> -using is_signed_integer = - std::integral_constant<bool, std::is_integral<T>::value && std::is_signed<T>::value>; - + +namespace { + +template <typename T> +using is_unsigned_integer = std::integral_constant<bool, std::is_integral<T>::value && + std::is_unsigned<T>::value>; + +template <typename T> +using is_signed_integer = + std::integral_constant<bool, std::is_integral<T>::value && std::is_signed<T>::value>; + template <typename T, typename R = T> using enable_if_signed_integer = enable_if_t<is_signed_integer<T>::value, R>; - + template <typename T, typename R = T> using enable_if_unsigned_integer = enable_if_t<is_unsigned_integer<T>::value, R>; - + template <typename T, typename R = T> -using enable_if_integer = +using enable_if_integer = enable_if_t<is_signed_integer<T>::value || is_unsigned_integer<T>::value, R>; - + template <typename T, typename R = T> using enable_if_floating_point = enable_if_t<std::is_floating_point<T>::value, R>; -template <typename T> +template <typename T> using enable_if_decimal = enable_if_t<std::is_same<Decimal128, T>::value || std::is_same<Decimal256, T>::value, T>; - -template <typename T, typename Unsigned = typename std::make_unsigned<T>::type> -constexpr Unsigned to_unsigned(T signed_) { - return static_cast<Unsigned>(signed_); -} - + +template <typename T, typename Unsigned = typename std::make_unsigned<T>::type> +constexpr Unsigned to_unsigned(T signed_) { + return static_cast<Unsigned>(signed_); +} + struct AbsoluteValue { template <typename T, typename Arg> static constexpr enable_if_floating_point<T> Call(KernelContext*, T arg, Status*) { @@ -119,201 +119,201 @@ struct AbsoluteValueChecked { } }; -struct Add { +struct Add { template <typename T, typename Arg0, typename Arg1> static constexpr enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) { - return left + right; - } - + return left + right; + } + template <typename T, typename Arg0, typename Arg1> static constexpr enable_if_unsigned_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) { - return left + right; - } - + return left + right; + } + template <typename T, typename Arg0, typename Arg1> static constexpr enable_if_signed_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) { - return arrow::internal::SafeSignedAdd(left, right); - } + return arrow::internal::SafeSignedAdd(left, right); + } template <typename T, typename Arg0, typename Arg1> static enable_if_decimal<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) { return left + right; } -}; - -struct AddChecked { - template <typename T, typename Arg0, typename Arg1> +}; + +struct AddChecked { + template <typename T, typename Arg0, typename Arg1> static enable_if_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { - static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); - T result = 0; - if (ARROW_PREDICT_FALSE(AddWithOverflow(left, right, &result))) { + static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); + T result = 0; + if (ARROW_PREDICT_FALSE(AddWithOverflow(left, right, &result))) { *st = Status::Invalid("overflow"); - } - return result; - } - - template <typename T, typename Arg0, typename Arg1> + } + return result; + } + + template <typename T, typename Arg0, typename Arg1> static enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) { - static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); - return left + right; - } + static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); + return left + right; + } template <typename T, typename Arg0, typename Arg1> static enable_if_decimal<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) { return left + right; } -}; - -struct Subtract { +}; + +struct Subtract { template <typename T, typename Arg0, typename Arg1> static constexpr enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) { static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); - return left - right; - } - + return left - right; + } + template <typename T, typename Arg0, typename Arg1> static constexpr enable_if_unsigned_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) { static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); - return left - right; - } - + return left - right; + } + template <typename T, typename Arg0, typename Arg1> static constexpr enable_if_signed_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) { static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); - return arrow::internal::SafeSignedSubtract(left, right); - } + return arrow::internal::SafeSignedSubtract(left, right); + } template <typename T, typename Arg0, typename Arg1> static enable_if_decimal<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) { return left + (-right); } -}; - -struct SubtractChecked { - template <typename T, typename Arg0, typename Arg1> +}; + +struct SubtractChecked { + template <typename T, typename Arg0, typename Arg1> static enable_if_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { - static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); - T result = 0; - if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) { + static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); + T result = 0; + if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) { *st = Status::Invalid("overflow"); - } - return result; - } - - template <typename T, typename Arg0, typename Arg1> + } + return result; + } + + template <typename T, typename Arg0, typename Arg1> static enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) { - static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); - return left - right; - } + static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); + return left - right; + } template <typename T, typename Arg0, typename Arg1> static enable_if_decimal<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) { return left + (-right); } -}; - -struct Multiply { - static_assert(std::is_same<decltype(int8_t() * int8_t()), int32_t>::value, ""); - static_assert(std::is_same<decltype(uint8_t() * uint8_t()), int32_t>::value, ""); - static_assert(std::is_same<decltype(int16_t() * int16_t()), int32_t>::value, ""); - static_assert(std::is_same<decltype(uint16_t() * uint16_t()), int32_t>::value, ""); - static_assert(std::is_same<decltype(int32_t() * int32_t()), int32_t>::value, ""); - static_assert(std::is_same<decltype(uint32_t() * uint32_t()), uint32_t>::value, ""); - static_assert(std::is_same<decltype(int64_t() * int64_t()), int64_t>::value, ""); - static_assert(std::is_same<decltype(uint64_t() * uint64_t()), uint64_t>::value, ""); - +}; + +struct Multiply { + static_assert(std::is_same<decltype(int8_t() * int8_t()), int32_t>::value, ""); + static_assert(std::is_same<decltype(uint8_t() * uint8_t()), int32_t>::value, ""); + static_assert(std::is_same<decltype(int16_t() * int16_t()), int32_t>::value, ""); + static_assert(std::is_same<decltype(uint16_t() * uint16_t()), int32_t>::value, ""); + static_assert(std::is_same<decltype(int32_t() * int32_t()), int32_t>::value, ""); + static_assert(std::is_same<decltype(uint32_t() * uint32_t()), uint32_t>::value, ""); + static_assert(std::is_same<decltype(int64_t() * int64_t()), int64_t>::value, ""); + static_assert(std::is_same<decltype(uint64_t() * uint64_t()), uint64_t>::value, ""); + template <typename T, typename Arg0, typename Arg1> static constexpr enable_if_floating_point<T> Call(KernelContext*, T left, T right, Status*) { - return left * right; - } - + return left * right; + } + template <typename T, typename Arg0, typename Arg1> static constexpr enable_if_t< is_unsigned_integer<T>::value && !std::is_same<T, uint16_t>::value, T> Call(KernelContext*, T left, T right, Status*) { - return left * right; - } - + return left * right; + } + template <typename T, typename Arg0, typename Arg1> static constexpr enable_if_t< is_signed_integer<T>::value && !std::is_same<T, int16_t>::value, T> Call(KernelContext*, T left, T right, Status*) { - return to_unsigned(left) * to_unsigned(right); - } - - // Multiplication of 16 bit integer types implicitly promotes to signed 32 bit - // integer. However, some inputs may nevertheless overflow (which triggers undefined - // behaviour). Therefore we first cast to 32 bit unsigned integers where overflow is - // well defined. + return to_unsigned(left) * to_unsigned(right); + } + + // Multiplication of 16 bit integer types implicitly promotes to signed 32 bit + // integer. However, some inputs may nevertheless overflow (which triggers undefined + // behaviour). Therefore we first cast to 32 bit unsigned integers where overflow is + // well defined. template <typename T, typename Arg0, typename Arg1> static constexpr enable_if_same<T, int16_t, T> Call(KernelContext*, int16_t left, int16_t right, Status*) { - return static_cast<uint32_t>(left) * static_cast<uint32_t>(right); - } + return static_cast<uint32_t>(left) * static_cast<uint32_t>(right); + } template <typename T, typename Arg0, typename Arg1> static constexpr enable_if_same<T, uint16_t, T> Call(KernelContext*, uint16_t left, uint16_t right, Status*) { - return static_cast<uint32_t>(left) * static_cast<uint32_t>(right); - } + return static_cast<uint32_t>(left) * static_cast<uint32_t>(right); + } template <typename T, typename Arg0, typename Arg1> static enable_if_decimal<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) { return left * right; } -}; - -struct MultiplyChecked { - template <typename T, typename Arg0, typename Arg1> +}; + +struct MultiplyChecked { + template <typename T, typename Arg0, typename Arg1> static enable_if_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { - static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); - T result = 0; - if (ARROW_PREDICT_FALSE(MultiplyWithOverflow(left, right, &result))) { + static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); + T result = 0; + if (ARROW_PREDICT_FALSE(MultiplyWithOverflow(left, right, &result))) { *st = Status::Invalid("overflow"); - } - return result; - } - - template <typename T, typename Arg0, typename Arg1> + } + return result; + } + + template <typename T, typename Arg0, typename Arg1> static enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) { - static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); - return left * right; - } + static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); + return left * right; + } template <typename T, typename Arg0, typename Arg1> static enable_if_decimal<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) { return left * right; } -}; - -struct Divide { - template <typename T, typename Arg0, typename Arg1> +}; + +struct Divide { + template <typename T, typename Arg0, typename Arg1> static enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) { - return left / right; - } - - template <typename T, typename Arg0, typename Arg1> + return left / right; + } + + template <typename T, typename Arg0, typename Arg1> static enable_if_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { - T result; - if (ARROW_PREDICT_FALSE(DivideWithOverflow(left, right, &result))) { - if (right == 0) { + T result; + if (ARROW_PREDICT_FALSE(DivideWithOverflow(left, right, &result))) { + if (right == 0) { *st = Status::Invalid("divide by zero"); - } else { - result = 0; - } - } - return result; - } + } else { + result = 0; + } + } + return result; + } template <typename T, typename Arg0, typename Arg1> static enable_if_decimal<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { @@ -324,41 +324,41 @@ struct Divide { return left / right; } } -}; - -struct DivideChecked { - template <typename T, typename Arg0, typename Arg1> +}; + +struct DivideChecked { + template <typename T, typename Arg0, typename Arg1> static enable_if_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { - static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); - T result; - if (ARROW_PREDICT_FALSE(DivideWithOverflow(left, right, &result))) { - if (right == 0) { + static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); + T result; + if (ARROW_PREDICT_FALSE(DivideWithOverflow(left, right, &result))) { + if (right == 0) { *st = Status::Invalid("divide by zero"); - } else { + } else { *st = Status::Invalid("overflow"); - } - } - return result; - } - - template <typename T, typename Arg0, typename Arg1> + } + } + return result; + } + + template <typename T, typename Arg0, typename Arg1> static enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { - static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); - if (ARROW_PREDICT_FALSE(right == 0)) { + static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, ""); + if (ARROW_PREDICT_FALSE(right == 0)) { *st = Status::Invalid("divide by zero"); - return 0; - } - return left / right; - } + return 0; + } + return left / right; + } template <typename T, typename Arg0, typename Arg1> static enable_if_decimal<T> Call(KernelContext* ctx, Arg0 left, Arg1 right, Status* st) { return Divide::Call<T>(ctx, left, right, st); } -}; - +}; + struct Negate { template <typename T, typename Arg> static constexpr enable_if_floating_point<T> Call(KernelContext*, Arg arg, Status*) { @@ -838,37 +838,37 @@ struct Trunc { } }; -// Generate a kernel given an arithmetic functor -template <template <typename... Args> class KernelGenerator, typename Op> +// Generate a kernel given an arithmetic functor +template <template <typename... Args> class KernelGenerator, typename Op> ArrayKernelExec ArithmeticExecFromOp(detail::GetTypeId get_id) { - switch (get_id.id) { - case Type::INT8: - return KernelGenerator<Int8Type, Int8Type, Op>::Exec; - case Type::UINT8: - return KernelGenerator<UInt8Type, UInt8Type, Op>::Exec; - case Type::INT16: - return KernelGenerator<Int16Type, Int16Type, Op>::Exec; - case Type::UINT16: - return KernelGenerator<UInt16Type, UInt16Type, Op>::Exec; - case Type::INT32: - return KernelGenerator<Int32Type, Int32Type, Op>::Exec; - case Type::UINT32: - return KernelGenerator<UInt32Type, UInt32Type, Op>::Exec; - case Type::INT64: - case Type::TIMESTAMP: - return KernelGenerator<Int64Type, Int64Type, Op>::Exec; - case Type::UINT64: - return KernelGenerator<UInt64Type, UInt64Type, Op>::Exec; - case Type::FLOAT: - return KernelGenerator<FloatType, FloatType, Op>::Exec; - case Type::DOUBLE: - return KernelGenerator<DoubleType, DoubleType, Op>::Exec; - default: - DCHECK(false); - return ExecFail; - } -} - + switch (get_id.id) { + case Type::INT8: + return KernelGenerator<Int8Type, Int8Type, Op>::Exec; + case Type::UINT8: + return KernelGenerator<UInt8Type, UInt8Type, Op>::Exec; + case Type::INT16: + return KernelGenerator<Int16Type, Int16Type, Op>::Exec; + case Type::UINT16: + return KernelGenerator<UInt16Type, UInt16Type, Op>::Exec; + case Type::INT32: + return KernelGenerator<Int32Type, Int32Type, Op>::Exec; + case Type::UINT32: + return KernelGenerator<UInt32Type, UInt32Type, Op>::Exec; + case Type::INT64: + case Type::TIMESTAMP: + return KernelGenerator<Int64Type, Int64Type, Op>::Exec; + case Type::UINT64: + return KernelGenerator<UInt64Type, UInt64Type, Op>::Exec; + case Type::FLOAT: + return KernelGenerator<FloatType, FloatType, Op>::Exec; + case Type::DOUBLE: + return KernelGenerator<DoubleType, DoubleType, Op>::Exec; + default: + DCHECK(false); + return ExecFail; + } +} + // Generate a kernel given a bitwise arithmetic functor. Assumes the // functor treats all integer types of equal width identically template <template <typename... Args> class KernelGenerator, typename Op> @@ -1050,7 +1050,7 @@ Result<ValueDescr> ResolveDecimalDivisionOutput(KernelContext*, }); } -template <typename Op> +template <typename Op> void AddDecimalBinaryKernels(const std::string& name, std::shared_ptr<ScalarFunction>* func) { OutputType out_type(null()); @@ -1182,26 +1182,26 @@ template <typename Op> std::shared_ptr<ScalarFunction> MakeArithmeticFunction(std::string name, const FunctionDoc* doc) { auto func = std::make_shared<ArithmeticFunction>(name, Arity::Binary(), doc); - for (const auto& ty : NumericTypes()) { + for (const auto& ty : NumericTypes()) { auto exec = ArithmeticExecFromOp<ScalarBinaryEqualTypes, Op>(ty); - DCHECK_OK(func->AddKernel({ty, ty}, ty, exec)); - } - return func; -} - -// Like MakeArithmeticFunction, but for arithmetic ops that need to run -// only on non-null output. -template <typename Op> + DCHECK_OK(func->AddKernel({ty, ty}, ty, exec)); + } + return func; +} + +// Like MakeArithmeticFunction, but for arithmetic ops that need to run +// only on non-null output. +template <typename Op> std::shared_ptr<ScalarFunction> MakeArithmeticFunctionNotNull(std::string name, const FunctionDoc* doc) { auto func = std::make_shared<ArithmeticFunction>(name, Arity::Binary(), doc); - for (const auto& ty : NumericTypes()) { + for (const auto& ty : NumericTypes()) { auto exec = ArithmeticExecFromOp<ScalarBinaryNotNullEqualTypes, Op>(ty); - DCHECK_OK(func->AddKernel({ty, ty}, ty, exec)); - } - return func; -} - + DCHECK_OK(func->AddKernel({ty, ty}, ty, exec)); + } + return func; +} + template <typename Op> std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunction(std::string name, const FunctionDoc* doc) { @@ -1606,10 +1606,10 @@ const FunctionDoc trunc_doc{ ("Calculate the nearest integer not greater in magnitude than to the " "argument element-wise."), {"x"}}; -} // namespace - -void RegisterScalarArithmetic(FunctionRegistry* registry) { - // ---------------------------------------------------------------------- +} // namespace + +void RegisterScalarArithmetic(FunctionRegistry* registry) { + // ---------------------------------------------------------------------- auto absolute_value = MakeUnaryArithmeticFunction<AbsoluteValue>("abs", &absolute_value_doc); DCHECK_OK(registry->AddFunction(std::move(absolute_value))); @@ -1622,54 +1622,54 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { // ---------------------------------------------------------------------- auto add = MakeArithmeticFunction<Add>("add", &add_doc); AddDecimalBinaryKernels<Add>("add", &add); - DCHECK_OK(registry->AddFunction(std::move(add))); - - // ---------------------------------------------------------------------- + DCHECK_OK(registry->AddFunction(std::move(add))); + + // ---------------------------------------------------------------------- auto add_checked = MakeArithmeticFunctionNotNull<AddChecked>("add_checked", &add_checked_doc); AddDecimalBinaryKernels<AddChecked>("add_checked", &add_checked); - DCHECK_OK(registry->AddFunction(std::move(add_checked))); - - // ---------------------------------------------------------------------- + DCHECK_OK(registry->AddFunction(std::move(add_checked))); + + // ---------------------------------------------------------------------- auto subtract = MakeArithmeticFunction<Subtract>("subtract", &sub_doc); AddDecimalBinaryKernels<Subtract>("subtract", &subtract); - - // Add subtract(timestamp, timestamp) -> duration - for (auto unit : AllTimeUnits()) { - InputType in_type(match::TimestampTypeUnit(unit)); + + // Add subtract(timestamp, timestamp) -> duration + for (auto unit : AllTimeUnits()) { + InputType in_type(match::TimestampTypeUnit(unit)); auto exec = ArithmeticExecFromOp<ScalarBinaryEqualTypes, Subtract>(Type::TIMESTAMP); - DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), std::move(exec))); - } - - DCHECK_OK(registry->AddFunction(std::move(subtract))); - - // ---------------------------------------------------------------------- + DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), std::move(exec))); + } + + DCHECK_OK(registry->AddFunction(std::move(subtract))); + + // ---------------------------------------------------------------------- auto subtract_checked = MakeArithmeticFunctionNotNull<SubtractChecked>( "subtract_checked", &sub_checked_doc); AddDecimalBinaryKernels<SubtractChecked>("subtract_checked", &subtract_checked); - DCHECK_OK(registry->AddFunction(std::move(subtract_checked))); - - // ---------------------------------------------------------------------- + DCHECK_OK(registry->AddFunction(std::move(subtract_checked))); + + // ---------------------------------------------------------------------- auto multiply = MakeArithmeticFunction<Multiply>("multiply", &mul_doc); AddDecimalBinaryKernels<Multiply>("multiply", &multiply); - DCHECK_OK(registry->AddFunction(std::move(multiply))); - - // ---------------------------------------------------------------------- + DCHECK_OK(registry->AddFunction(std::move(multiply))); + + // ---------------------------------------------------------------------- auto multiply_checked = MakeArithmeticFunctionNotNull<MultiplyChecked>( "multiply_checked", &mul_checked_doc); AddDecimalBinaryKernels<MultiplyChecked>("multiply_checked", &multiply_checked); - DCHECK_OK(registry->AddFunction(std::move(multiply_checked))); - - // ---------------------------------------------------------------------- + DCHECK_OK(registry->AddFunction(std::move(multiply_checked))); + + // ---------------------------------------------------------------------- auto divide = MakeArithmeticFunctionNotNull<Divide>("divide", &div_doc); AddDecimalBinaryKernels<Divide>("divide", ÷); - DCHECK_OK(registry->AddFunction(std::move(divide))); - - // ---------------------------------------------------------------------- + DCHECK_OK(registry->AddFunction(std::move(divide))); + + // ---------------------------------------------------------------------- auto divide_checked = MakeArithmeticFunctionNotNull<DivideChecked>("divide_checked", &div_checked_doc); AddDecimalBinaryKernels<DivideChecked>("divide_checked", ÷_checked); - DCHECK_OK(registry->AddFunction(std::move(divide_checked))); + DCHECK_OK(registry->AddFunction(std::move(divide_checked))); // ---------------------------------------------------------------------- auto negate = MakeUnaryArithmeticFunction<Negate>("negate", &negate_doc); @@ -1816,8 +1816,8 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { auto trunc = MakeUnaryArithmeticFunctionFloatingPoint<Trunc>("trunc", &trunc_doc); DCHECK_OK(registry->AddFunction(std::move(trunc))); -} - -} // namespace internal -} // namespace compute -} // namespace arrow +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_boolean.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_boolean.cc index 7a0e3654ed..63fddcd1fe 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_boolean.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_boolean.cc @@ -1,61 +1,61 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include <array> - -#include "arrow/compute/kernels/common.h" -#include "arrow/util/bit_util.h" -#include "arrow/util/bitmap.h" -#include "arrow/util/bitmap_ops.h" - -namespace arrow { - -using internal::Bitmap; - -namespace compute { - -namespace { - -template <typename ComputeWord> -void ComputeKleene(ComputeWord&& compute_word, KernelContext* ctx, const ArrayData& left, - const ArrayData& right, ArrayData* out) { +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <array> + +#include "arrow/compute/kernels/common.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap.h" +#include "arrow/util/bitmap_ops.h" + +namespace arrow { + +using internal::Bitmap; + +namespace compute { + +namespace { + +template <typename ComputeWord> +void ComputeKleene(ComputeWord&& compute_word, KernelContext* ctx, const ArrayData& left, + const ArrayData& right, ArrayData* out) { DCHECK(left.null_count != 0 || right.null_count != 0) << "ComputeKleene is unnecessarily expensive for the non-null case"; - + Bitmap left_valid_bm{left.buffers[0], left.offset, left.length}; Bitmap left_data_bm{left.buffers[1], left.offset, left.length}; - + Bitmap right_valid_bm{right.buffers[0], right.offset, right.length}; Bitmap right_data_bm{right.buffers[1], right.offset, right.length}; - + std::array<Bitmap, 2> out_bms{Bitmap(out->buffers[0], out->offset, out->length), Bitmap(out->buffers[1], out->offset, out->length)}; - - auto apply = [&](uint64_t left_valid, uint64_t left_data, uint64_t right_valid, + + auto apply = [&](uint64_t left_valid, uint64_t left_data, uint64_t right_valid, uint64_t right_data, uint64_t* out_validity, uint64_t* out_data) { - auto left_true = left_valid & left_data; - auto left_false = left_valid & ~left_data; - - auto right_true = right_valid & right_data; - auto right_false = right_valid & ~right_data; - + auto left_true = left_valid & left_data; + auto left_false = left_valid & ~left_data; + + auto right_true = right_valid & right_data; + auto right_false = right_valid & ~right_data; + compute_word(left_true, left_false, right_true, right_false, out_validity, out_data); - }; - + }; + if (right.null_count == 0) { std::array<Bitmap, 3> in_bms{left_valid_bm, left_data_bm, right_data_bm}; Bitmap::VisitWordsAndWrite( @@ -65,7 +65,7 @@ void ComputeKleene(ComputeWord&& compute_word, KernelContext* ctx, const ArrayDa }); return; } - + if (left.null_count == 0) { std::array<Bitmap, 3> in_bms{left_data_bm, right_valid_bm, right_data_bm}; Bitmap::VisitWordsAndWrite( @@ -74,7 +74,7 @@ void ComputeKleene(ComputeWord&& compute_word, KernelContext* ctx, const ArrayDa apply(~uint64_t(0), in[0], in[1], in[2], &(out->at(0)), &(out->at(1))); }); return; - } + } DCHECK(left.null_count != 0 && right.null_count != 0); std::array<Bitmap, 4> in_bms{left_valid_bm, left_data_bm, right_valid_bm, @@ -84,8 +84,8 @@ void ComputeKleene(ComputeWord&& compute_word, KernelContext* ctx, const ArrayDa [&](const std::array<uint64_t, 4>& in, std::array<uint64_t, 2>* out) { apply(in[0], in[1], in[2], in[3], &(out->at(0)), &(out->at(1))); }); -} - +} + inline BooleanScalar InvertScalar(const Scalar& in) { return in.is_valid ? BooleanScalar(!checked_cast<const BooleanScalar&>(in).value) : BooleanScalar(); @@ -121,13 +121,13 @@ struct AndOp : Commutative<AndOp> { static Status Call(KernelContext* ctx, const Scalar& left, const Scalar& right, Scalar* out) { if (left.is_valid && right.is_valid) { - checked_cast<BooleanScalar*>(out)->value = + checked_cast<BooleanScalar*>(out)->value = checked_cast<const BooleanScalar&>(left).value && checked_cast<const BooleanScalar&>(right).value; - } + } return Status::OK(); - } - + } + static Status Call(KernelContext* ctx, const ArrayData& left, const Scalar& right, ArrayData* out) { if (right.is_valid) { @@ -136,17 +136,17 @@ struct AndOp : Commutative<AndOp> { : GetBitmap(*out, 1).SetBitsTo(false); } return Status::OK(); - } - + } + static Status Call(KernelContext* ctx, const ArrayData& left, const ArrayData& right, ArrayData* out) { - ::arrow::internal::BitmapAnd(left.buffers[1]->data(), left.offset, - right.buffers[1]->data(), right.offset, right.length, - out->offset, out->buffers[1]->mutable_data()); + ::arrow::internal::BitmapAnd(left.buffers[1]->data(), left.offset, + right.buffers[1]->data(), right.offset, right.length, + out->offset, out->buffers[1]->mutable_data()); return Status::OK(); - } -}; - + } +}; + struct KleeneAndOp : Commutative<KleeneAndOp> { using Commutative<KleeneAndOp>::Call; @@ -202,23 +202,23 @@ struct KleeneAndOp : Commutative<KleeneAndOp> { static Status Call(KernelContext* ctx, const ArrayData& left, const ArrayData& right, ArrayData* out) { - if (left.GetNullCount() == 0 && right.GetNullCount() == 0) { + if (left.GetNullCount() == 0 && right.GetNullCount() == 0) { out->null_count = 0; // Kleene kernels have validity bitmap pre-allocated. Therefore, set it to 1 BitUtil::SetBitmap(out->buffers[0]->mutable_data(), out->offset, out->length); return AndOp::Call(ctx, left, right, out); - } - auto compute_word = [](uint64_t left_true, uint64_t left_false, uint64_t right_true, - uint64_t right_false, uint64_t* out_valid, - uint64_t* out_data) { - *out_data = left_true & right_true; - *out_valid = left_false | right_false | (left_true & right_true); - }; - ComputeKleene(compute_word, ctx, left, right, out); + } + auto compute_word = [](uint64_t left_true, uint64_t left_false, uint64_t right_true, + uint64_t right_false, uint64_t* out_valid, + uint64_t* out_data) { + *out_data = left_true & right_true; + *out_valid = left_false | right_false | (left_true & right_true); + }; + ComputeKleene(compute_word, ctx, left, right, out); return Status::OK(); - } -}; - + } +}; + struct OrOp : Commutative<OrOp> { using Commutative<OrOp>::Call; @@ -244,13 +244,13 @@ struct OrOp : Commutative<OrOp> { static Status Call(KernelContext* ctx, const ArrayData& left, const ArrayData& right, ArrayData* out) { - ::arrow::internal::BitmapOr(left.buffers[1]->data(), left.offset, - right.buffers[1]->data(), right.offset, right.length, - out->offset, out->buffers[1]->mutable_data()); + ::arrow::internal::BitmapOr(left.buffers[1]->data(), left.offset, + right.buffers[1]->data(), right.offset, right.length, + out->offset, out->buffers[1]->mutable_data()); return Status::OK(); - } -}; - + } +}; + struct KleeneOrOp : Commutative<KleeneOrOp> { using Commutative<KleeneOrOp>::Call; @@ -306,25 +306,25 @@ struct KleeneOrOp : Commutative<KleeneOrOp> { static Status Call(KernelContext* ctx, const ArrayData& left, const ArrayData& right, ArrayData* out) { - if (left.GetNullCount() == 0 && right.GetNullCount() == 0) { + if (left.GetNullCount() == 0 && right.GetNullCount() == 0) { out->null_count = 0; // Kleene kernels have validity bitmap pre-allocated. Therefore, set it to 1 BitUtil::SetBitmap(out->buffers[0]->mutable_data(), out->offset, out->length); return OrOp::Call(ctx, left, right, out); - } - - static auto compute_word = [](uint64_t left_true, uint64_t left_false, - uint64_t right_true, uint64_t right_false, - uint64_t* out_valid, uint64_t* out_data) { - *out_data = left_true | right_true; - *out_valid = left_true | right_true | (left_false & right_false); - }; - + } + + static auto compute_word = [](uint64_t left_true, uint64_t left_false, + uint64_t right_true, uint64_t right_false, + uint64_t* out_valid, uint64_t* out_data) { + *out_data = left_true | right_true; + *out_valid = left_true | right_true | (left_false & right_false); + }; + ComputeKleene(compute_word, ctx, left, right, out); return Status::OK(); - } -}; - + } +}; + struct XorOp : Commutative<XorOp> { using Commutative<XorOp>::Call; @@ -350,13 +350,13 @@ struct XorOp : Commutative<XorOp> { static Status Call(KernelContext* ctx, const ArrayData& left, const ArrayData& right, ArrayData* out) { - ::arrow::internal::BitmapXor(left.buffers[1]->data(), left.offset, - right.buffers[1]->data(), right.offset, right.length, - out->offset, out->buffers[1]->mutable_data()); + ::arrow::internal::BitmapXor(left.buffers[1]->data(), left.offset, + right.buffers[1]->data(), right.offset, right.length, + out->offset, out->buffers[1]->mutable_data()); return Status::OK(); - } -}; - + } +}; + struct AndNotOp { static Status Call(KernelContext* ctx, const Scalar& left, const Scalar& right, Scalar* out) { @@ -458,18 +458,18 @@ struct KleeneAndNotOp { void MakeFunction(const std::string& name, int arity, ArrayKernelExec exec, const FunctionDoc* doc, FunctionRegistry* registry, - NullHandling::type null_handling = NullHandling::INTERSECTION) { + NullHandling::type null_handling = NullHandling::INTERSECTION) { auto func = std::make_shared<ScalarFunction>(name, Arity(arity), doc); - - // Scalar arguments not yet supported + + // Scalar arguments not yet supported std::vector<InputType> in_types(arity, InputType(boolean())); - ScalarKernel kernel(std::move(in_types), boolean(), exec); - kernel.null_handling = null_handling; - - DCHECK_OK(func->AddKernel(kernel)); - DCHECK_OK(registry->AddFunction(std::move(func))); -} - + ScalarKernel kernel(std::move(in_types), boolean(), exec); + kernel.null_handling = null_handling; + + DCHECK_OK(func->AddKernel(kernel)); + DCHECK_OK(registry->AddFunction(std::move(func))); +} + const FunctionDoc invert_doc{"Invert boolean values", "", {"values"}}; const FunctionDoc and_doc{ @@ -538,26 +538,26 @@ const FunctionDoc or_kleene_doc{ "For a different null behavior, see function \"and\"."), {"x", "y"}}; -} // namespace - -namespace internal { - -void RegisterScalarBoolean(FunctionRegistry* registry) { - // These functions can write into sliced output bitmaps +} // namespace + +namespace internal { + +void RegisterScalarBoolean(FunctionRegistry* registry) { + // These functions can write into sliced output bitmaps MakeFunction("invert", 1, applicator::SimpleUnary<InvertOp>, &invert_doc, registry); MakeFunction("and", 2, applicator::SimpleBinary<AndOp>, &and_doc, registry); MakeFunction("and_not", 2, applicator::SimpleBinary<AndNotOp>, &and_not_doc, registry); MakeFunction("or", 2, applicator::SimpleBinary<OrOp>, &or_doc, registry); MakeFunction("xor", 2, applicator::SimpleBinary<XorOp>, &xor_doc, registry); - + MakeFunction("and_kleene", 2, applicator::SimpleBinary<KleeneAndOp>, &and_kleene_doc, registry, NullHandling::COMPUTED_PREALLOCATE); MakeFunction("and_not_kleene", 2, applicator::SimpleBinary<KleeneAndNotOp>, &and_not_kleene_doc, registry, NullHandling::COMPUTED_PREALLOCATE); MakeFunction("or_kleene", 2, applicator::SimpleBinary<KleeneOrOp>, &or_kleene_doc, registry, NullHandling::COMPUTED_PREALLOCATE); -} - -} // namespace internal -} // namespace compute -} // namespace arrow +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc index dad94c1ace..c5fd7b78b1 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc @@ -1,70 +1,70 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Cast types to boolean - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Cast types to boolean + #include "arrow/array/builder_primitive.h" -#include "arrow/compute/kernels/common.h" -#include "arrow/compute/kernels/scalar_cast_internal.h" -#include "arrow/util/value_parsing.h" - -namespace arrow { - -using internal::ParseValue; - -namespace compute { -namespace internal { - -struct IsNonZero { - template <typename OutValue, typename Arg0Value> +#include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/scalar_cast_internal.h" +#include "arrow/util/value_parsing.h" + +namespace arrow { + +using internal::ParseValue; + +namespace compute { +namespace internal { + +struct IsNonZero { + template <typename OutValue, typename Arg0Value> static OutValue Call(KernelContext*, Arg0Value val, Status*) { - return val != 0; - } -}; - -struct ParseBooleanString { - template <typename OutValue, typename Arg0Value> + return val != 0; + } +}; + +struct ParseBooleanString { + template <typename OutValue, typename Arg0Value> static OutValue Call(KernelContext*, Arg0Value val, Status* st) { - bool result = false; - if (ARROW_PREDICT_FALSE(!ParseValue<BooleanType>(val.data(), val.size(), &result))) { + bool result = false; + if (ARROW_PREDICT_FALSE(!ParseValue<BooleanType>(val.data(), val.size(), &result))) { *st = Status::Invalid("Failed to parse value: ", val); - } - return result; - } -}; - -std::vector<std::shared_ptr<CastFunction>> GetBooleanCasts() { - auto func = std::make_shared<CastFunction>("cast_boolean", Type::BOOL); - AddCommonCasts(Type::BOOL, boolean(), func.get()); + } + return result; + } +}; + +std::vector<std::shared_ptr<CastFunction>> GetBooleanCasts() { + auto func = std::make_shared<CastFunction>("cast_boolean", Type::BOOL); + AddCommonCasts(Type::BOOL, boolean(), func.get()); AddZeroCopyCast(Type::BOOL, boolean(), boolean(), func.get()); - - for (const auto& ty : NumericTypes()) { - ArrayKernelExec exec = - GenerateNumeric<applicator::ScalarUnary, BooleanType, IsNonZero>(*ty); - DCHECK_OK(func->AddKernel(ty->id(), {ty}, boolean(), exec)); - } - for (const auto& ty : BaseBinaryTypes()) { - ArrayKernelExec exec = GenerateVarBinaryBase<applicator::ScalarUnaryNotNull, - BooleanType, ParseBooleanString>(*ty); - DCHECK_OK(func->AddKernel(ty->id(), {ty}, boolean(), exec)); - } - return {func}; -} - -} // namespace internal -} // namespace compute -} // namespace arrow + + for (const auto& ty : NumericTypes()) { + ArrayKernelExec exec = + GenerateNumeric<applicator::ScalarUnary, BooleanType, IsNonZero>(*ty); + DCHECK_OK(func->AddKernel(ty->id(), {ty}, boolean(), exec)); + } + for (const auto& ty : BaseBinaryTypes()) { + ArrayKernelExec exec = GenerateVarBinaryBase<applicator::ScalarUnaryNotNull, + BooleanType, ParseBooleanString>(*ty); + DCHECK_OK(func->AddKernel(ty->id(), {ty}, boolean(), exec)); + } + return {func}; +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc index 198c82bd97..e25523a3c1 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc @@ -1,175 +1,175 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/compute/kernels/scalar_cast_internal.h" -#include "arrow/compute/cast_internal.h" -#include "arrow/compute/kernels/common.h" -#include "arrow/extension_type.h" - -namespace arrow { - -using internal::PrimitiveScalarBase; - -namespace compute { -namespace internal { - -// ---------------------------------------------------------------------- - -template <typename OutT, typename InT> -ARROW_DISABLE_UBSAN("float-cast-overflow") -void DoStaticCast(const void* in_data, int64_t in_offset, int64_t length, - int64_t out_offset, void* out_data) { - auto in = reinterpret_cast<const InT*>(in_data) + in_offset; - auto out = reinterpret_cast<OutT*>(out_data) + out_offset; - for (int64_t i = 0; i < length; ++i) { - *out++ = static_cast<OutT>(*in++); - } -} - -using StaticCastFunc = std::function<void(const void*, int64_t, int64_t, int64_t, void*)>; - -template <typename OutType, typename InType, typename Enable = void> -struct CastPrimitive { - static void Exec(const Datum& input, Datum* out) { - using OutT = typename OutType::c_type; - using InT = typename InType::c_type; - - StaticCastFunc caster = DoStaticCast<OutT, InT>; - if (input.kind() == Datum::ARRAY) { - const ArrayData& arr = *input.array(); - ArrayData* out_arr = out->mutable_array(); - caster(arr.buffers[1]->data(), arr.offset, arr.length, out_arr->offset, - out_arr->buffers[1]->mutable_data()); - } else { - // Scalar path. Use the caster with length 1 to place the casted value into - // the output - const auto& in_scalar = input.scalar_as<PrimitiveScalarBase>(); - auto out_scalar = checked_cast<PrimitiveScalarBase*>(out->scalar().get()); - caster(in_scalar.data(), /*in_offset=*/0, /*length=*/1, /*out_offset=*/0, - out_scalar->mutable_data()); - } - } -}; - -template <typename OutType, typename InType> -struct CastPrimitive<OutType, InType, enable_if_t<std::is_same<OutType, InType>::value>> { - // memcpy output - static void Exec(const Datum& input, Datum* out) { - using T = typename InType::c_type; - - if (input.kind() == Datum::ARRAY) { - const ArrayData& arr = *input.array(); - ArrayData* out_arr = out->mutable_array(); - std::memcpy( - reinterpret_cast<T*>(out_arr->buffers[1]->mutable_data()) + out_arr->offset, - reinterpret_cast<const T*>(arr.buffers[1]->data()) + arr.offset, - arr.length * sizeof(T)); - } else { - // Scalar path. Use the caster with length 1 to place the casted value into - // the output - const auto& in_scalar = input.scalar_as<PrimitiveScalarBase>(); - auto out_scalar = checked_cast<PrimitiveScalarBase*>(out->scalar().get()); - *reinterpret_cast<T*>(out_scalar->mutable_data()) = - *reinterpret_cast<const T*>(in_scalar.data()); - } - } -}; - -template <typename InType> -void CastNumberImpl(Type::type out_type, const Datum& input, Datum* out) { - switch (out_type) { - case Type::INT8: - return CastPrimitive<Int8Type, InType>::Exec(input, out); - case Type::INT16: - return CastPrimitive<Int16Type, InType>::Exec(input, out); - case Type::INT32: - return CastPrimitive<Int32Type, InType>::Exec(input, out); - case Type::INT64: - return CastPrimitive<Int64Type, InType>::Exec(input, out); - case Type::UINT8: - return CastPrimitive<UInt8Type, InType>::Exec(input, out); - case Type::UINT16: - return CastPrimitive<UInt16Type, InType>::Exec(input, out); - case Type::UINT32: - return CastPrimitive<UInt32Type, InType>::Exec(input, out); - case Type::UINT64: - return CastPrimitive<UInt64Type, InType>::Exec(input, out); - case Type::FLOAT: - return CastPrimitive<FloatType, InType>::Exec(input, out); - case Type::DOUBLE: - return CastPrimitive<DoubleType, InType>::Exec(input, out); - default: - break; - } -} - -void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type, const Datum& input, - Datum* out) { - switch (in_type) { - case Type::INT8: - return CastNumberImpl<Int8Type>(out_type, input, out); - case Type::INT16: - return CastNumberImpl<Int16Type>(out_type, input, out); - case Type::INT32: - return CastNumberImpl<Int32Type>(out_type, input, out); - case Type::INT64: - return CastNumberImpl<Int64Type>(out_type, input, out); - case Type::UINT8: - return CastNumberImpl<UInt8Type>(out_type, input, out); - case Type::UINT16: - return CastNumberImpl<UInt16Type>(out_type, input, out); - case Type::UINT32: - return CastNumberImpl<UInt32Type>(out_type, input, out); - case Type::UINT64: - return CastNumberImpl<UInt64Type>(out_type, input, out); - case Type::FLOAT: - return CastNumberImpl<FloatType>(out_type, input, out); - case Type::DOUBLE: - return CastNumberImpl<DoubleType>(out_type, input, out); - default: - DCHECK(false); - break; - } -} - -// ---------------------------------------------------------------------- - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/kernels/scalar_cast_internal.h" +#include "arrow/compute/cast_internal.h" +#include "arrow/compute/kernels/common.h" +#include "arrow/extension_type.h" + +namespace arrow { + +using internal::PrimitiveScalarBase; + +namespace compute { +namespace internal { + +// ---------------------------------------------------------------------- + +template <typename OutT, typename InT> +ARROW_DISABLE_UBSAN("float-cast-overflow") +void DoStaticCast(const void* in_data, int64_t in_offset, int64_t length, + int64_t out_offset, void* out_data) { + auto in = reinterpret_cast<const InT*>(in_data) + in_offset; + auto out = reinterpret_cast<OutT*>(out_data) + out_offset; + for (int64_t i = 0; i < length; ++i) { + *out++ = static_cast<OutT>(*in++); + } +} + +using StaticCastFunc = std::function<void(const void*, int64_t, int64_t, int64_t, void*)>; + +template <typename OutType, typename InType, typename Enable = void> +struct CastPrimitive { + static void Exec(const Datum& input, Datum* out) { + using OutT = typename OutType::c_type; + using InT = typename InType::c_type; + + StaticCastFunc caster = DoStaticCast<OutT, InT>; + if (input.kind() == Datum::ARRAY) { + const ArrayData& arr = *input.array(); + ArrayData* out_arr = out->mutable_array(); + caster(arr.buffers[1]->data(), arr.offset, arr.length, out_arr->offset, + out_arr->buffers[1]->mutable_data()); + } else { + // Scalar path. Use the caster with length 1 to place the casted value into + // the output + const auto& in_scalar = input.scalar_as<PrimitiveScalarBase>(); + auto out_scalar = checked_cast<PrimitiveScalarBase*>(out->scalar().get()); + caster(in_scalar.data(), /*in_offset=*/0, /*length=*/1, /*out_offset=*/0, + out_scalar->mutable_data()); + } + } +}; + +template <typename OutType, typename InType> +struct CastPrimitive<OutType, InType, enable_if_t<std::is_same<OutType, InType>::value>> { + // memcpy output + static void Exec(const Datum& input, Datum* out) { + using T = typename InType::c_type; + + if (input.kind() == Datum::ARRAY) { + const ArrayData& arr = *input.array(); + ArrayData* out_arr = out->mutable_array(); + std::memcpy( + reinterpret_cast<T*>(out_arr->buffers[1]->mutable_data()) + out_arr->offset, + reinterpret_cast<const T*>(arr.buffers[1]->data()) + arr.offset, + arr.length * sizeof(T)); + } else { + // Scalar path. Use the caster with length 1 to place the casted value into + // the output + const auto& in_scalar = input.scalar_as<PrimitiveScalarBase>(); + auto out_scalar = checked_cast<PrimitiveScalarBase*>(out->scalar().get()); + *reinterpret_cast<T*>(out_scalar->mutable_data()) = + *reinterpret_cast<const T*>(in_scalar.data()); + } + } +}; + +template <typename InType> +void CastNumberImpl(Type::type out_type, const Datum& input, Datum* out) { + switch (out_type) { + case Type::INT8: + return CastPrimitive<Int8Type, InType>::Exec(input, out); + case Type::INT16: + return CastPrimitive<Int16Type, InType>::Exec(input, out); + case Type::INT32: + return CastPrimitive<Int32Type, InType>::Exec(input, out); + case Type::INT64: + return CastPrimitive<Int64Type, InType>::Exec(input, out); + case Type::UINT8: + return CastPrimitive<UInt8Type, InType>::Exec(input, out); + case Type::UINT16: + return CastPrimitive<UInt16Type, InType>::Exec(input, out); + case Type::UINT32: + return CastPrimitive<UInt32Type, InType>::Exec(input, out); + case Type::UINT64: + return CastPrimitive<UInt64Type, InType>::Exec(input, out); + case Type::FLOAT: + return CastPrimitive<FloatType, InType>::Exec(input, out); + case Type::DOUBLE: + return CastPrimitive<DoubleType, InType>::Exec(input, out); + default: + break; + } +} + +void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type, const Datum& input, + Datum* out) { + switch (in_type) { + case Type::INT8: + return CastNumberImpl<Int8Type>(out_type, input, out); + case Type::INT16: + return CastNumberImpl<Int16Type>(out_type, input, out); + case Type::INT32: + return CastNumberImpl<Int32Type>(out_type, input, out); + case Type::INT64: + return CastNumberImpl<Int64Type>(out_type, input, out); + case Type::UINT8: + return CastNumberImpl<UInt8Type>(out_type, input, out); + case Type::UINT16: + return CastNumberImpl<UInt16Type>(out_type, input, out); + case Type::UINT32: + return CastNumberImpl<UInt32Type>(out_type, input, out); + case Type::UINT64: + return CastNumberImpl<UInt64Type>(out_type, input, out); + case Type::FLOAT: + return CastNumberImpl<FloatType>(out_type, input, out); + case Type::DOUBLE: + return CastNumberImpl<DoubleType>(out_type, input, out); + default: + DCHECK(false); + break; + } +} + +// ---------------------------------------------------------------------- + Status UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { DCHECK(out->is_array()); - DictionaryArray dict_arr(batch[0].array()); - const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options; - - const auto& dict_type = *dict_arr.dictionary()->type(); + DictionaryArray dict_arr(batch[0].array()); + const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options; + + const auto& dict_type = *dict_arr.dictionary()->type(); if (!dict_type.Equals(options.to_type) && !CanCast(dict_type, *options.to_type)) { return Status::Invalid("Cast type ", options.to_type->ToString(), " incompatible with dictionary type ", dict_type.ToString()); - } - + } + ARROW_ASSIGN_OR_RAISE(*out, Take(Datum(dict_arr.dictionary()), Datum(dict_arr.indices()), TakeOptions::Defaults(), ctx->exec_context())); if (!dict_type.Equals(options.to_type)) { ARROW_ASSIGN_OR_RAISE(*out, Cast(*out, options)); - } + } return Status::OK(); -} - +} + Status OutputAllNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { if (out->is_scalar()) { out->scalar()->is_valid = false; @@ -179,23 +179,23 @@ Status OutputAllNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { output->null_count = batch.length; } return Status::OK(); -} - +} + Status CastFromExtension(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const CastOptions& options = checked_cast<const CastState*>(ctx->state())->options; - - const DataType& in_type = *batch[0].type(); - const auto storage_type = checked_cast<const ExtensionType&>(in_type).storage_type(); - - ExtensionArray extension(batch[0].array()); - - Datum casted_storage; + const CastOptions& options = checked_cast<const CastState*>(ctx->state())->options; + + const DataType& in_type = *batch[0].type(); + const auto storage_type = checked_cast<const ExtensionType&>(in_type).storage_type(); + + ExtensionArray extension(batch[0].array()); + + Datum casted_storage; RETURN_NOT_OK(Cast(*extension.storage(), out->type(), options, ctx->exec_context()) .Value(&casted_storage)); - out->value = casted_storage.array(); + out->value = casted_storage.array(); return Status::OK(); -} - +} + Status CastFromNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { if (!batch[0].is_scalar()) { ArrayData* output = out->mutable_array(); @@ -204,25 +204,25 @@ Status CastFromNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { out->value = nulls->data(); } return Status::OK(); -} - -Result<ValueDescr> ResolveOutputFromOptions(KernelContext* ctx, - const std::vector<ValueDescr>& args) { - const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options; - return ValueDescr(options.to_type, args[0].shape); -} - -/// You will see some of kernels with -/// -/// kOutputTargetType -/// -/// for their output type resolution. This is somewhat of an eyesore but the -/// easiest initial way to get the requested cast type including the TimeUnit -/// to the kernel (which is needed to compute the output) was through -/// CastOptions - -OutputType kOutputTargetType(ResolveOutputFromOptions); - +} + +Result<ValueDescr> ResolveOutputFromOptions(KernelContext* ctx, + const std::vector<ValueDescr>& args) { + const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options; + return ValueDescr(options.to_type, args[0].shape); +} + +/// You will see some of kernels with +/// +/// kOutputTargetType +/// +/// for their output type resolution. This is somewhat of an eyesore but the +/// easiest initial way to get the requested cast type including the TimeUnit +/// to the kernel (which is needed to compute the output) was through +/// CastOptions + +OutputType kOutputTargetType(ResolveOutputFromOptions); + Status ZeroCopyCastExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { DCHECK_EQ(batch[0].kind(), Datum::ARRAY); // Make a copy of the buffers into a destination array without carrying @@ -235,51 +235,51 @@ Status ZeroCopyCastExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) output->offset = input.offset; output->child_data = input.child_data; return Status::OK(); -} - -void AddZeroCopyCast(Type::type in_type_id, InputType in_type, OutputType out_type, - CastFunction* func) { - auto sig = KernelSignature::Make({in_type}, out_type); - ScalarKernel kernel; +} + +void AddZeroCopyCast(Type::type in_type_id, InputType in_type, OutputType out_type, + CastFunction* func) { + auto sig = KernelSignature::Make({in_type}, out_type); + ScalarKernel kernel; kernel.exec = TrivialScalarUnaryAsArraysExec(ZeroCopyCastExec); - kernel.signature = sig; - kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; - kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; - DCHECK_OK(func->AddKernel(in_type_id, std::move(kernel))); -} - -static bool CanCastFromDictionary(Type::type type_id) { - return (is_primitive(type_id) || is_base_binary_like(type_id) || - is_fixed_size_binary(type_id)); -} - -void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* func) { - // From null to this type + kernel.signature = sig; + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + DCHECK_OK(func->AddKernel(in_type_id, std::move(kernel))); +} + +static bool CanCastFromDictionary(Type::type type_id) { + return (is_primitive(type_id) || is_base_binary_like(type_id) || + is_fixed_size_binary(type_id)); +} + +void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* func) { + // From null to this type ScalarKernel kernel; kernel.exec = CastFromNull; kernel.signature = KernelSignature::Make({null()}, out_ty); kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; DCHECK_OK(func->AddKernel(Type::NA, std::move(kernel))); - - // From dictionary to this type - if (CanCastFromDictionary(out_type_id)) { - // Dictionary unpacking not implemented for boolean or nested types. - // - // XXX: Uses Take and does its own memory allocation for the moment. We can - // fix this later. + + // From dictionary to this type + if (CanCastFromDictionary(out_type_id)) { + // Dictionary unpacking not implemented for boolean or nested types. + // + // XXX: Uses Take and does its own memory allocation for the moment. We can + // fix this later. DCHECK_OK(func->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, out_ty, TrivialScalarUnaryAsArraysExec(UnpackDictionary), NullHandling::COMPUTED_NO_PREALLOCATE, MemAllocation::NO_PREALLOCATE)); - } - - // From extension type to this type - DCHECK_OK(func->AddKernel(Type::EXTENSION, {InputType::Array(Type::EXTENSION)}, out_ty, - CastFromExtension, NullHandling::COMPUTED_NO_PREALLOCATE, - MemAllocation::NO_PREALLOCATE)); -} - -} // namespace internal -} // namespace compute -} // namespace arrow + } + + // From extension type to this type + DCHECK_OK(func->AddKernel(Type::EXTENSION, {InputType::Array(Type::EXTENSION)}, out_ty, + CastFromExtension, NullHandling::COMPUTED_NO_PREALLOCATE, + MemAllocation::NO_PREALLOCATE)); +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.h index 2419d898a6..12e3605695 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.h @@ -1,88 +1,88 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include "arrow/compute/api_vector.h" -#include "arrow/compute/cast.h" // IWYU pragma: export -#include "arrow/compute/cast_internal.h" // IWYU pragma: export -#include "arrow/compute/kernels/common.h" +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/compute/api_vector.h" +#include "arrow/compute/cast.h" // IWYU pragma: export +#include "arrow/compute/cast_internal.h" // IWYU pragma: export +#include "arrow/compute/kernels/common.h" #include "arrow/compute/kernels/util_internal.h" - -namespace arrow { - -using internal::checked_cast; - -namespace compute { -namespace internal { - -template <typename OutType, typename InType, typename Enable = void> -struct CastFunctor {}; - -// No-op functor for identity casts -template <typename O, typename I> -struct CastFunctor< - O, I, enable_if_t<std::is_same<O, I>::value && is_parameter_free_type<I>::value>> { + +namespace arrow { + +using internal::checked_cast; + +namespace compute { +namespace internal { + +template <typename OutType, typename InType, typename Enable = void> +struct CastFunctor {}; + +// No-op functor for identity casts +template <typename O, typename I> +struct CastFunctor< + O, I, enable_if_t<std::is_same<O, I>::value && is_parameter_free_type<I>::value>> { static Status Exec(KernelContext*, const ExecBatch&, Datum*) { return Status::OK(); } -}; - +}; + Status CastFromExtension(KernelContext* ctx, const ExecBatch& batch, Datum* out); - -// Utility for numeric casts -void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type, const Datum& input, - Datum* out); - -// ---------------------------------------------------------------------- -// Dictionary to other things - + +// Utility for numeric casts +void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type, const Datum& input, + Datum* out); + +// ---------------------------------------------------------------------- +// Dictionary to other things + Status UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out); - + Status OutputAllNull(KernelContext* ctx, const ExecBatch& batch, Datum* out); - + Status CastFromNull(KernelContext* ctx, const ExecBatch& batch, Datum* out); - + // Adds a cast function where CastFunctor is specialized and the input and output // types are parameter free (have a type_singleton). Scalar inputs are handled by // wrapping with TrivialScalarUnaryAsArraysExec. -template <typename InType, typename OutType> -void AddSimpleCast(InputType in_ty, OutputType out_ty, CastFunction* func) { +template <typename InType, typename OutType> +void AddSimpleCast(InputType in_ty, OutputType out_ty, CastFunction* func) { DCHECK_OK(func->AddKernel( InType::type_id, {in_ty}, out_ty, TrivialScalarUnaryAsArraysExec(CastFunctor<OutType, InType>::Exec))); -} - +} + Status ZeroCopyCastExec(KernelContext* ctx, const ExecBatch& batch, Datum* out); - -void AddZeroCopyCast(Type::type in_type_id, InputType in_type, OutputType out_type, - CastFunction* func); - -// OutputType::Resolver that returns a descr with the shape of the input -// argument and the type from CastOptions -Result<ValueDescr> ResolveOutputFromOptions(KernelContext* ctx, - const std::vector<ValueDescr>& args); - -ARROW_EXPORT extern OutputType kOutputTargetType; - -// Add generic casts to out_ty from: -// - the null type -// - dictionary with out_ty as given value type -// - extension types with a compatible storage type -void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* func); - -} // namespace internal -} // namespace compute -} // namespace arrow + +void AddZeroCopyCast(Type::type in_type_id, InputType in_type, OutputType out_type, + CastFunction* func); + +// OutputType::Resolver that returns a descr with the shape of the input +// argument and the type from CastOptions +Result<ValueDescr> ResolveOutputFromOptions(KernelContext* ctx, + const std::vector<ValueDescr>& args); + +ARROW_EXPORT extern OutputType kOutputTargetType; + +// Add generic casts to out_ty from: +// - the null type +// - dictionary with out_ty as given value type +// - extension types with a compatible storage type +void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* func); + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc index ec92dbb5d6..8b8fdf094a 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc @@ -1,46 +1,46 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Implementation of casting to (or between) list types - -#include <utility> -#include <vector> - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Implementation of casting to (or between) list types + +#include <utility> +#include <vector> + #include "arrow/array/builder_nested.h" #include "arrow/compute/api_scalar.h" -#include "arrow/compute/cast.h" -#include "arrow/compute/kernels/common.h" -#include "arrow/compute/kernels/scalar_cast_internal.h" +#include "arrow/compute/cast.h" +#include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/scalar_cast_internal.h" #include "arrow/util/bitmap_ops.h" - -namespace arrow { + +namespace arrow { using internal::CopyBitmap; -namespace compute { -namespace internal { - -template <typename Type> +namespace compute { +namespace internal { + +template <typename Type> Status CastListExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { using offset_type = typename Type::offset_type; using ScalarType = typename TypeTraits<Type>::ScalarType; - + const CastOptions& options = CastState::Get(ctx); - + auto child_type = checked_cast<const Type&>(*out->type()).value_type(); if (out->kind() == Datum::SCALAR) { @@ -55,11 +55,11 @@ Status CastListExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { out_scalar->is_valid = true; } return Status::OK(); - } - + } + const ArrayData& in_array = *batch[0].array(); ArrayData* out_array = out->mutable_array(); - + // Copy from parent out_array->buffers = in_array.buffers; Datum values = in_array.child_data[0]; @@ -88,46 +88,46 @@ Status CastListExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { DCHECK_EQ(Datum::ARRAY, cast_values.kind()); out_array->child_data.push_back(cast_values.array()); return Status::OK(); -} - -template <typename Type> -void AddListCast(CastFunction* func) { - ScalarKernel kernel; - kernel.exec = CastListExec<Type>; - kernel.signature = KernelSignature::Make({InputType(Type::type_id)}, kOutputTargetType); - kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; - DCHECK_OK(func->AddKernel(Type::type_id, std::move(kernel))); -} - -std::vector<std::shared_ptr<CastFunction>> GetNestedCasts() { - // We use the list<T> from the CastOptions when resolving the output type - - auto cast_list = std::make_shared<CastFunction>("cast_list", Type::LIST); - AddCommonCasts(Type::LIST, kOutputTargetType, cast_list.get()); - AddListCast<ListType>(cast_list.get()); - - auto cast_large_list = - std::make_shared<CastFunction>("cast_large_list", Type::LARGE_LIST); - AddCommonCasts(Type::LARGE_LIST, kOutputTargetType, cast_large_list.get()); - AddListCast<LargeListType>(cast_large_list.get()); - - // FSL is a bit incomplete at the moment - auto cast_fsl = - std::make_shared<CastFunction>("cast_fixed_size_list", Type::FIXED_SIZE_LIST); - AddCommonCasts(Type::FIXED_SIZE_LIST, kOutputTargetType, cast_fsl.get()); - - // So is struct - auto cast_struct = std::make_shared<CastFunction>("cast_struct", Type::STRUCT); - AddCommonCasts(Type::STRUCT, kOutputTargetType, cast_struct.get()); - +} + +template <typename Type> +void AddListCast(CastFunction* func) { + ScalarKernel kernel; + kernel.exec = CastListExec<Type>; + kernel.signature = KernelSignature::Make({InputType(Type::type_id)}, kOutputTargetType); + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + DCHECK_OK(func->AddKernel(Type::type_id, std::move(kernel))); +} + +std::vector<std::shared_ptr<CastFunction>> GetNestedCasts() { + // We use the list<T> from the CastOptions when resolving the output type + + auto cast_list = std::make_shared<CastFunction>("cast_list", Type::LIST); + AddCommonCasts(Type::LIST, kOutputTargetType, cast_list.get()); + AddListCast<ListType>(cast_list.get()); + + auto cast_large_list = + std::make_shared<CastFunction>("cast_large_list", Type::LARGE_LIST); + AddCommonCasts(Type::LARGE_LIST, kOutputTargetType, cast_large_list.get()); + AddListCast<LargeListType>(cast_large_list.get()); + + // FSL is a bit incomplete at the moment + auto cast_fsl = + std::make_shared<CastFunction>("cast_fixed_size_list", Type::FIXED_SIZE_LIST); + AddCommonCasts(Type::FIXED_SIZE_LIST, kOutputTargetType, cast_fsl.get()); + + // So is struct + auto cast_struct = std::make_shared<CastFunction>("cast_struct", Type::STRUCT); + AddCommonCasts(Type::STRUCT, kOutputTargetType, cast_struct.get()); + // So is dictionary auto cast_dictionary = std::make_shared<CastFunction>("cast_dictionary", Type::DICTIONARY); AddCommonCasts(Type::DICTIONARY, kOutputTargetType, cast_dictionary.get()); return {cast_list, cast_large_list, cast_fsl, cast_struct, cast_dictionary}; -} - -} // namespace internal -} // namespace compute -} // namespace arrow +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index cc7b533f26..ae9a04e8e9 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -1,399 +1,399 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Implementation of casting to integer, floating point, or decimal types - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Implementation of casting to integer, floating point, or decimal types + #include "arrow/array/builder_primitive.h" -#include "arrow/compute/kernels/common.h" -#include "arrow/compute/kernels/scalar_cast_internal.h" +#include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/scalar_cast_internal.h" #include "arrow/compute/kernels/util_internal.h" -#include "arrow/util/bit_block_counter.h" -#include "arrow/util/int_util.h" -#include "arrow/util/value_parsing.h" - -namespace arrow { - -using internal::BitBlockCount; -using internal::CheckIntegersInRange; -using internal::IntegersCanFit; -using internal::OptionalBitBlockCounter; -using internal::ParseValue; - -namespace compute { -namespace internal { - +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/int_util.h" +#include "arrow/util/value_parsing.h" + +namespace arrow { + +using internal::BitBlockCount; +using internal::CheckIntegersInRange; +using internal::IntegersCanFit; +using internal::OptionalBitBlockCounter; +using internal::ParseValue; + +namespace compute { +namespace internal { + Status CastIntegerToInteger(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const auto& options = checked_cast<const CastState*>(ctx->state())->options; - if (!options.allow_int_overflow) { + const auto& options = checked_cast<const CastState*>(ctx->state())->options; + if (!options.allow_int_overflow) { RETURN_NOT_OK(IntegersCanFit(batch[0], *out->type())); - } - CastNumberToNumberUnsafe(batch[0].type()->id(), out->type()->id(), batch[0], out); + } + CastNumberToNumberUnsafe(batch[0].type()->id(), out->type()->id(), batch[0], out); return Status::OK(); -} - +} + Status CastFloatingToFloating(KernelContext*, const ExecBatch& batch, Datum* out) { - CastNumberToNumberUnsafe(batch[0].type()->id(), out->type()->id(), batch[0], out); + CastNumberToNumberUnsafe(batch[0].type()->id(), out->type()->id(), batch[0], out); return Status::OK(); -} - -// ---------------------------------------------------------------------- -// Implement fast safe floating point to integer cast - -// InType is a floating point type we are planning to cast to integer -template <typename InType, typename OutType, typename InT = typename InType::c_type, - typename OutT = typename OutType::c_type> -ARROW_DISABLE_UBSAN("float-cast-overflow") -Status CheckFloatTruncation(const Datum& input, const Datum& output) { - auto WasTruncated = [&](OutT out_val, InT in_val) -> bool { - return static_cast<InT>(out_val) != in_val; - }; - auto WasTruncatedMaybeNull = [&](OutT out_val, InT in_val, bool is_valid) -> bool { - return is_valid && static_cast<InT>(out_val) != in_val; - }; - auto GetErrorMessage = [&](InT val) { +} + +// ---------------------------------------------------------------------- +// Implement fast safe floating point to integer cast + +// InType is a floating point type we are planning to cast to integer +template <typename InType, typename OutType, typename InT = typename InType::c_type, + typename OutT = typename OutType::c_type> +ARROW_DISABLE_UBSAN("float-cast-overflow") +Status CheckFloatTruncation(const Datum& input, const Datum& output) { + auto WasTruncated = [&](OutT out_val, InT in_val) -> bool { + return static_cast<InT>(out_val) != in_val; + }; + auto WasTruncatedMaybeNull = [&](OutT out_val, InT in_val, bool is_valid) -> bool { + return is_valid && static_cast<InT>(out_val) != in_val; + }; + auto GetErrorMessage = [&](InT val) { return Status::Invalid("Float value ", val, " was truncated converting to ", - *output.type()); - }; - - if (input.kind() == Datum::SCALAR) { - DCHECK_EQ(output.kind(), Datum::SCALAR); - const auto& in_scalar = input.scalar_as<typename TypeTraits<InType>::ScalarType>(); - const auto& out_scalar = output.scalar_as<typename TypeTraits<OutType>::ScalarType>(); - if (WasTruncatedMaybeNull(out_scalar.value, in_scalar.value, out_scalar.is_valid)) { - return GetErrorMessage(in_scalar.value); - } - return Status::OK(); - } - - const ArrayData& in_array = *input.array(); - const ArrayData& out_array = *output.array(); - - const InT* in_data = in_array.GetValues<InT>(1); - const OutT* out_data = out_array.GetValues<OutT>(1); - - const uint8_t* bitmap = nullptr; - if (in_array.buffers[0]) { - bitmap = in_array.buffers[0]->data(); - } - OptionalBitBlockCounter bit_counter(bitmap, in_array.offset, in_array.length); - int64_t position = 0; - int64_t offset_position = in_array.offset; - while (position < in_array.length) { - BitBlockCount block = bit_counter.NextBlock(); - bool block_out_of_bounds = false; - if (block.popcount == block.length) { - // Fast path: branchless - for (int64_t i = 0; i < block.length; ++i) { - block_out_of_bounds |= WasTruncated(out_data[i], in_data[i]); - } - } else if (block.popcount > 0) { - // Indices have nulls, must only boundscheck non-null values - for (int64_t i = 0; i < block.length; ++i) { - block_out_of_bounds |= WasTruncatedMaybeNull( - out_data[i], in_data[i], BitUtil::GetBit(bitmap, offset_position + i)); - } - } - if (ARROW_PREDICT_FALSE(block_out_of_bounds)) { - if (in_array.GetNullCount() > 0) { - for (int64_t i = 0; i < block.length; ++i) { - if (WasTruncatedMaybeNull(out_data[i], in_data[i], - BitUtil::GetBit(bitmap, offset_position + i))) { - return GetErrorMessage(in_data[i]); - } - } - } else { - for (int64_t i = 0; i < block.length; ++i) { - if (WasTruncated(out_data[i], in_data[i])) { - return GetErrorMessage(in_data[i]); - } - } - } - } - in_data += block.length; - out_data += block.length; - position += block.length; - offset_position += block.length; - } - return Status::OK(); -} - -template <typename InType> -Status CheckFloatToIntTruncationImpl(const Datum& input, const Datum& output) { - switch (output.type()->id()) { - case Type::INT8: - return CheckFloatTruncation<InType, Int8Type>(input, output); - case Type::INT16: - return CheckFloatTruncation<InType, Int16Type>(input, output); - case Type::INT32: - return CheckFloatTruncation<InType, Int32Type>(input, output); - case Type::INT64: - return CheckFloatTruncation<InType, Int64Type>(input, output); - case Type::UINT8: - return CheckFloatTruncation<InType, UInt8Type>(input, output); - case Type::UINT16: - return CheckFloatTruncation<InType, UInt16Type>(input, output); - case Type::UINT32: - return CheckFloatTruncation<InType, UInt32Type>(input, output); - case Type::UINT64: - return CheckFloatTruncation<InType, UInt64Type>(input, output); - default: - break; - } - DCHECK(false); - return Status::OK(); -} - -Status CheckFloatToIntTruncation(const Datum& input, const Datum& output) { - switch (input.type()->id()) { - case Type::FLOAT: - return CheckFloatToIntTruncationImpl<FloatType>(input, output); - case Type::DOUBLE: - return CheckFloatToIntTruncationImpl<DoubleType>(input, output); - default: - break; - } - DCHECK(false); - return Status::OK(); -} - + *output.type()); + }; + + if (input.kind() == Datum::SCALAR) { + DCHECK_EQ(output.kind(), Datum::SCALAR); + const auto& in_scalar = input.scalar_as<typename TypeTraits<InType>::ScalarType>(); + const auto& out_scalar = output.scalar_as<typename TypeTraits<OutType>::ScalarType>(); + if (WasTruncatedMaybeNull(out_scalar.value, in_scalar.value, out_scalar.is_valid)) { + return GetErrorMessage(in_scalar.value); + } + return Status::OK(); + } + + const ArrayData& in_array = *input.array(); + const ArrayData& out_array = *output.array(); + + const InT* in_data = in_array.GetValues<InT>(1); + const OutT* out_data = out_array.GetValues<OutT>(1); + + const uint8_t* bitmap = nullptr; + if (in_array.buffers[0]) { + bitmap = in_array.buffers[0]->data(); + } + OptionalBitBlockCounter bit_counter(bitmap, in_array.offset, in_array.length); + int64_t position = 0; + int64_t offset_position = in_array.offset; + while (position < in_array.length) { + BitBlockCount block = bit_counter.NextBlock(); + bool block_out_of_bounds = false; + if (block.popcount == block.length) { + // Fast path: branchless + for (int64_t i = 0; i < block.length; ++i) { + block_out_of_bounds |= WasTruncated(out_data[i], in_data[i]); + } + } else if (block.popcount > 0) { + // Indices have nulls, must only boundscheck non-null values + for (int64_t i = 0; i < block.length; ++i) { + block_out_of_bounds |= WasTruncatedMaybeNull( + out_data[i], in_data[i], BitUtil::GetBit(bitmap, offset_position + i)); + } + } + if (ARROW_PREDICT_FALSE(block_out_of_bounds)) { + if (in_array.GetNullCount() > 0) { + for (int64_t i = 0; i < block.length; ++i) { + if (WasTruncatedMaybeNull(out_data[i], in_data[i], + BitUtil::GetBit(bitmap, offset_position + i))) { + return GetErrorMessage(in_data[i]); + } + } + } else { + for (int64_t i = 0; i < block.length; ++i) { + if (WasTruncated(out_data[i], in_data[i])) { + return GetErrorMessage(in_data[i]); + } + } + } + } + in_data += block.length; + out_data += block.length; + position += block.length; + offset_position += block.length; + } + return Status::OK(); +} + +template <typename InType> +Status CheckFloatToIntTruncationImpl(const Datum& input, const Datum& output) { + switch (output.type()->id()) { + case Type::INT8: + return CheckFloatTruncation<InType, Int8Type>(input, output); + case Type::INT16: + return CheckFloatTruncation<InType, Int16Type>(input, output); + case Type::INT32: + return CheckFloatTruncation<InType, Int32Type>(input, output); + case Type::INT64: + return CheckFloatTruncation<InType, Int64Type>(input, output); + case Type::UINT8: + return CheckFloatTruncation<InType, UInt8Type>(input, output); + case Type::UINT16: + return CheckFloatTruncation<InType, UInt16Type>(input, output); + case Type::UINT32: + return CheckFloatTruncation<InType, UInt32Type>(input, output); + case Type::UINT64: + return CheckFloatTruncation<InType, UInt64Type>(input, output); + default: + break; + } + DCHECK(false); + return Status::OK(); +} + +Status CheckFloatToIntTruncation(const Datum& input, const Datum& output) { + switch (input.type()->id()) { + case Type::FLOAT: + return CheckFloatToIntTruncationImpl<FloatType>(input, output); + case Type::DOUBLE: + return CheckFloatToIntTruncationImpl<DoubleType>(input, output); + default: + break; + } + DCHECK(false); + return Status::OK(); +} + Status CastFloatingToInteger(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const auto& options = checked_cast<const CastState*>(ctx->state())->options; - CastNumberToNumberUnsafe(batch[0].type()->id(), out->type()->id(), batch[0], out); - if (!options.allow_float_truncate) { + const auto& options = checked_cast<const CastState*>(ctx->state())->options; + CastNumberToNumberUnsafe(batch[0].type()->id(), out->type()->id(), batch[0], out); + if (!options.allow_float_truncate) { RETURN_NOT_OK(CheckFloatToIntTruncation(batch[0], *out)); - } - return Status::OK(); -} - -// ---------------------------------------------------------------------- -// Implement fast integer to floating point cast - -// These are the limits for exact representation of whole numbers in floating -// point numbers -template <typename T> -struct FloatingIntegerBound {}; - -template <> -struct FloatingIntegerBound<float> { - static const int64_t value = 1LL << 24; -}; - -template <> -struct FloatingIntegerBound<double> { - static const int64_t value = 1LL << 53; -}; - -template <typename InType, typename OutType, typename InT = typename InType::c_type, - typename OutT = typename OutType::c_type, - bool IsSigned = is_signed_integer_type<InType>::value> -Status CheckIntegerFloatTruncateImpl(const Datum& input) { - using InScalarType = typename TypeTraits<InType>::ScalarType; - const int64_t limit = FloatingIntegerBound<OutT>::value; - InScalarType bound_lower(IsSigned ? -limit : 0); - InScalarType bound_upper(limit); - return CheckIntegersInRange(input, bound_lower, bound_upper); -} - -Status CheckForIntegerToFloatingTruncation(const Datum& input, Type::type out_type) { - switch (input.type()->id()) { - // Small integers are all exactly representable as whole numbers - case Type::INT8: - case Type::INT16: - case Type::UINT8: - case Type::UINT16: - return Status::OK(); - case Type::INT32: { - if (out_type == Type::DOUBLE) { - return Status::OK(); - } - return CheckIntegerFloatTruncateImpl<Int32Type, FloatType>(input); - } - case Type::UINT32: { - if (out_type == Type::DOUBLE) { - return Status::OK(); - } - return CheckIntegerFloatTruncateImpl<UInt32Type, FloatType>(input); - } - case Type::INT64: { - if (out_type == Type::FLOAT) { - return CheckIntegerFloatTruncateImpl<Int64Type, FloatType>(input); - } else { - return CheckIntegerFloatTruncateImpl<Int64Type, DoubleType>(input); - } - } - case Type::UINT64: { - if (out_type == Type::FLOAT) { - return CheckIntegerFloatTruncateImpl<UInt64Type, FloatType>(input); - } else { - return CheckIntegerFloatTruncateImpl<UInt64Type, DoubleType>(input); - } - } - default: - break; - } - DCHECK(false); + } return Status::OK(); -} - +} + +// ---------------------------------------------------------------------- +// Implement fast integer to floating point cast + +// These are the limits for exact representation of whole numbers in floating +// point numbers +template <typename T> +struct FloatingIntegerBound {}; + +template <> +struct FloatingIntegerBound<float> { + static const int64_t value = 1LL << 24; +}; + +template <> +struct FloatingIntegerBound<double> { + static const int64_t value = 1LL << 53; +}; + +template <typename InType, typename OutType, typename InT = typename InType::c_type, + typename OutT = typename OutType::c_type, + bool IsSigned = is_signed_integer_type<InType>::value> +Status CheckIntegerFloatTruncateImpl(const Datum& input) { + using InScalarType = typename TypeTraits<InType>::ScalarType; + const int64_t limit = FloatingIntegerBound<OutT>::value; + InScalarType bound_lower(IsSigned ? -limit : 0); + InScalarType bound_upper(limit); + return CheckIntegersInRange(input, bound_lower, bound_upper); +} + +Status CheckForIntegerToFloatingTruncation(const Datum& input, Type::type out_type) { + switch (input.type()->id()) { + // Small integers are all exactly representable as whole numbers + case Type::INT8: + case Type::INT16: + case Type::UINT8: + case Type::UINT16: + return Status::OK(); + case Type::INT32: { + if (out_type == Type::DOUBLE) { + return Status::OK(); + } + return CheckIntegerFloatTruncateImpl<Int32Type, FloatType>(input); + } + case Type::UINT32: { + if (out_type == Type::DOUBLE) { + return Status::OK(); + } + return CheckIntegerFloatTruncateImpl<UInt32Type, FloatType>(input); + } + case Type::INT64: { + if (out_type == Type::FLOAT) { + return CheckIntegerFloatTruncateImpl<Int64Type, FloatType>(input); + } else { + return CheckIntegerFloatTruncateImpl<Int64Type, DoubleType>(input); + } + } + case Type::UINT64: { + if (out_type == Type::FLOAT) { + return CheckIntegerFloatTruncateImpl<UInt64Type, FloatType>(input); + } else { + return CheckIntegerFloatTruncateImpl<UInt64Type, DoubleType>(input); + } + } + default: + break; + } + DCHECK(false); + return Status::OK(); +} + Status CastIntegerToFloating(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const auto& options = checked_cast<const CastState*>(ctx->state())->options; - Type::type out_type = out->type()->id(); - if (!options.allow_float_truncate) { + const auto& options = checked_cast<const CastState*>(ctx->state())->options; + Type::type out_type = out->type()->id(); + if (!options.allow_float_truncate) { RETURN_NOT_OK(CheckForIntegerToFloatingTruncation(batch[0], out_type)); - } - CastNumberToNumberUnsafe(batch[0].type()->id(), out_type, batch[0], out); + } + CastNumberToNumberUnsafe(batch[0].type()->id(), out_type, batch[0], out); return Status::OK(); -} - -// ---------------------------------------------------------------------- -// Boolean to number - -struct BooleanToNumber { - template <typename OutValue, typename Arg0Value> +} + +// ---------------------------------------------------------------------- +// Boolean to number + +struct BooleanToNumber { + template <typename OutValue, typename Arg0Value> static OutValue Call(KernelContext*, Arg0Value val, Status*) { - constexpr auto kOne = static_cast<OutValue>(1); - constexpr auto kZero = static_cast<OutValue>(0); - return val ? kOne : kZero; - } -}; - -template <typename O> -struct CastFunctor<O, BooleanType, enable_if_number<O>> { + constexpr auto kOne = static_cast<OutValue>(1); + constexpr auto kZero = static_cast<OutValue>(0); + return val ? kOne : kZero; + } +}; + +template <typename O> +struct CastFunctor<O, BooleanType, enable_if_number<O>> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { return applicator::ScalarUnary<O, BooleanType, BooleanToNumber>::Exec(ctx, batch, out); - } -}; - -// ---------------------------------------------------------------------- -// String to number - -template <typename OutType> -struct ParseString { - template <typename OutValue, typename Arg0Value> + } +}; + +// ---------------------------------------------------------------------- +// String to number + +template <typename OutType> +struct ParseString { + template <typename OutValue, typename Arg0Value> OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const { - OutValue result = OutValue(0); - if (ARROW_PREDICT_FALSE(!ParseValue<OutType>(val.data(), val.size(), &result))) { + OutValue result = OutValue(0); + if (ARROW_PREDICT_FALSE(!ParseValue<OutType>(val.data(), val.size(), &result))) { *st = Status::Invalid("Failed to parse string: '", val, "' as a scalar of type ", TypeTraits<OutType>::type_singleton()->ToString()); - } - return result; - } -}; - -template <typename O, typename I> -struct CastFunctor<O, I, enable_if_base_binary<I>> { + } + return result; + } +}; + +template <typename O, typename I> +struct CastFunctor<O, I, enable_if_base_binary<I>> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { return applicator::ScalarUnaryNotNull<O, I, ParseString<O>>::Exec(ctx, batch, out); - } -}; - -// ---------------------------------------------------------------------- -// Decimal to integer - -struct DecimalToIntegerMixin { + } +}; + +// ---------------------------------------------------------------------- +// Decimal to integer + +struct DecimalToIntegerMixin { template <typename OutValue, typename Arg0Value> OutValue ToInteger(KernelContext* ctx, const Arg0Value& val, Status* st) const { - constexpr auto min_value = std::numeric_limits<OutValue>::min(); - constexpr auto max_value = std::numeric_limits<OutValue>::max(); - - if (!allow_int_overflow_ && ARROW_PREDICT_FALSE(val < min_value || val > max_value)) { + constexpr auto min_value = std::numeric_limits<OutValue>::min(); + constexpr auto max_value = std::numeric_limits<OutValue>::max(); + + if (!allow_int_overflow_ && ARROW_PREDICT_FALSE(val < min_value || val > max_value)) { *st = Status::Invalid("Integer value out of bounds"); - return OutValue{}; // Zero - } else { - return static_cast<OutValue>(val.low_bits()); - } - } - - DecimalToIntegerMixin(int32_t in_scale, bool allow_int_overflow) - : in_scale_(in_scale), allow_int_overflow_(allow_int_overflow) {} - - int32_t in_scale_; - bool allow_int_overflow_; -}; - -struct UnsafeUpscaleDecimalToInteger : public DecimalToIntegerMixin { - using DecimalToIntegerMixin::DecimalToIntegerMixin; - - template <typename OutValue, typename Arg0Value> + return OutValue{}; // Zero + } else { + return static_cast<OutValue>(val.low_bits()); + } + } + + DecimalToIntegerMixin(int32_t in_scale, bool allow_int_overflow) + : in_scale_(in_scale), allow_int_overflow_(allow_int_overflow) {} + + int32_t in_scale_; + bool allow_int_overflow_; +}; + +struct UnsafeUpscaleDecimalToInteger : public DecimalToIntegerMixin { + using DecimalToIntegerMixin::DecimalToIntegerMixin; + + template <typename OutValue, typename Arg0Value> OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const { return ToInteger<OutValue>(ctx, val.IncreaseScaleBy(-in_scale_), st); - } -}; - -struct UnsafeDownscaleDecimalToInteger : public DecimalToIntegerMixin { - using DecimalToIntegerMixin::DecimalToIntegerMixin; - - template <typename OutValue, typename Arg0Value> + } +}; + +struct UnsafeDownscaleDecimalToInteger : public DecimalToIntegerMixin { + using DecimalToIntegerMixin::DecimalToIntegerMixin; + + template <typename OutValue, typename Arg0Value> OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const { return ToInteger<OutValue>(ctx, val.ReduceScaleBy(in_scale_, false), st); - } -}; - -struct SafeRescaleDecimalToInteger : public DecimalToIntegerMixin { - using DecimalToIntegerMixin::DecimalToIntegerMixin; - - template <typename OutValue, typename Arg0Value> + } +}; + +struct SafeRescaleDecimalToInteger : public DecimalToIntegerMixin { + using DecimalToIntegerMixin::DecimalToIntegerMixin; + + template <typename OutValue, typename Arg0Value> OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const { - auto result = val.Rescale(in_scale_, 0); - if (ARROW_PREDICT_FALSE(!result.ok())) { + auto result = val.Rescale(in_scale_, 0); + if (ARROW_PREDICT_FALSE(!result.ok())) { *st = result.status(); - return OutValue{}; // Zero - } else { + return OutValue{}; // Zero + } else { return ToInteger<OutValue>(ctx, *result, st); - } - } -}; - + } + } +}; + template <typename O, typename I> struct CastFunctor<O, I, enable_if_t<is_integer_type<O>::value && is_decimal_type<I>::value>> { - using out_type = typename O::c_type; - + using out_type = typename O::c_type; + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const auto& options = checked_cast<const CastState*>(ctx->state())->options; - + const auto& options = checked_cast<const CastState*>(ctx->state())->options; + const auto& in_type_inst = checked_cast<const I&>(*batch[0].type()); - const auto in_scale = in_type_inst.scale(); - - if (options.allow_decimal_truncate) { - if (in_scale < 0) { - // Unsafe upscale + const auto in_scale = in_type_inst.scale(); + + if (options.allow_decimal_truncate) { + if (in_scale < 0) { + // Unsafe upscale applicator::ScalarUnaryNotNullStateful<O, I, UnsafeUpscaleDecimalToInteger> - kernel(UnsafeUpscaleDecimalToInteger{in_scale, options.allow_int_overflow}); - return kernel.Exec(ctx, batch, out); - } else { - // Unsafe downscale + kernel(UnsafeUpscaleDecimalToInteger{in_scale, options.allow_int_overflow}); + return kernel.Exec(ctx, batch, out); + } else { + // Unsafe downscale applicator::ScalarUnaryNotNullStateful<O, I, UnsafeDownscaleDecimalToInteger> - kernel(UnsafeDownscaleDecimalToInteger{in_scale, options.allow_int_overflow}); - return kernel.Exec(ctx, batch, out); - } - } else { - // Safe rescale + kernel(UnsafeDownscaleDecimalToInteger{in_scale, options.allow_int_overflow}); + return kernel.Exec(ctx, batch, out); + } + } else { + // Safe rescale applicator::ScalarUnaryNotNullStateful<O, I, SafeRescaleDecimalToInteger> kernel( SafeRescaleDecimalToInteger{in_scale, options.allow_int_overflow}); - return kernel.Exec(ctx, batch, out); - } - } -}; - -// ---------------------------------------------------------------------- -// Decimal to decimal - + return kernel.Exec(ctx, batch, out); + } + } +}; + +// ---------------------------------------------------------------------- +// Decimal to decimal + // Helper that converts the input and output decimals // For instance, Decimal128 -> Decimal256 requires converting, then scaling // Decimal256 -> Decimal128 requires scaling, then truncating @@ -413,15 +413,15 @@ struct DecimalConversions<Decimal128, Decimal256> { static Decimal256 ConvertInput(Decimal256&& val) { return val; } static Decimal128 ConvertOutput(Decimal256&& val) { return Decimal128(val.little_endian_array()[1], val.little_endian_array()[0]); - } + } }; - + template <> struct DecimalConversions<Decimal128, Decimal128> { static Decimal128 ConvertInput(Decimal128&& val) { return val; } static Decimal128 ConvertOutput(Decimal128&& val) { return val; } -}; - +}; + struct UnsafeUpscaleDecimal { template <typename OutValue, typename Arg0Value> OutValue Call(KernelContext*, Arg0Value val, Status*) const { @@ -431,17 +431,17 @@ struct UnsafeUpscaleDecimal { int32_t by_; }; -struct UnsafeDownscaleDecimal { +struct UnsafeDownscaleDecimal { template <typename OutValue, typename Arg0Value> OutValue Call(KernelContext*, Arg0Value val, Status*) const { using Conv = DecimalConversions<OutValue, Arg0Value>; return Conv::ConvertOutput( Conv::ConvertInput(std::move(val)).ReduceScaleBy(by_, false)); - } + } int32_t by_; -}; - -struct SafeRescaleDecimal { +}; + +struct SafeRescaleDecimal { template <typename OutValue, typename Arg0Value> OutValue Call(KernelContext*, Arg0Value val, Status* st) const { using Conv = DecimalConversions<OutValue, Arg0Value>; @@ -450,7 +450,7 @@ struct SafeRescaleDecimal { if (ARROW_PREDICT_FALSE(!maybe_rescaled.ok())) { *st = maybe_rescaled.status(); return {}; // Zero - } + } if (ARROW_PREDICT_TRUE(maybe_rescaled->FitsInPrecision(out_precision_))) { return Conv::ConvertOutput(maybe_rescaled.MoveValueUnsafe()); @@ -458,199 +458,199 @@ struct SafeRescaleDecimal { *st = Status::Invalid("Decimal value does not fit in precision ", out_precision_); return {}; // Zero - } - - int32_t out_scale_, out_precision_, in_scale_; -}; - + } + + int32_t out_scale_, out_precision_, in_scale_; +}; + template <typename O, typename I> struct CastFunctor<O, I, enable_if_t<is_decimal_type<O>::value && is_decimal_type<I>::value>> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const auto& options = checked_cast<const CastState*>(ctx->state())->options; - + const auto& options = checked_cast<const CastState*>(ctx->state())->options; + const auto& in_type = checked_cast<const I&>(*batch[0].type()); const auto& out_type = checked_cast<const O&>(*out->type()); const auto in_scale = in_type.scale(); const auto out_scale = out_type.scale(); - - if (options.allow_decimal_truncate) { - if (in_scale < out_scale) { - // Unsafe upscale + + if (options.allow_decimal_truncate) { + if (in_scale < out_scale) { + // Unsafe upscale applicator::ScalarUnaryNotNullStateful<O, I, UnsafeUpscaleDecimal> kernel( UnsafeUpscaleDecimal{out_scale - in_scale}); - return kernel.Exec(ctx, batch, out); - } else { - // Unsafe downscale + return kernel.Exec(ctx, batch, out); + } else { + // Unsafe downscale applicator::ScalarUnaryNotNullStateful<O, I, UnsafeDownscaleDecimal> kernel( UnsafeDownscaleDecimal{in_scale - out_scale}); - return kernel.Exec(ctx, batch, out); - } - } + return kernel.Exec(ctx, batch, out); + } + } // Safe rescale applicator::ScalarUnaryNotNullStateful<O, I, SafeRescaleDecimal> kernel( SafeRescaleDecimal{out_scale, out_type.precision(), in_scale}); return kernel.Exec(ctx, batch, out); - } -}; - -// ---------------------------------------------------------------------- -// Real to decimal - -struct RealToDecimal { - template <typename OutValue, typename RealType> + } +}; + +// ---------------------------------------------------------------------- +// Real to decimal + +struct RealToDecimal { + template <typename OutValue, typename RealType> OutValue Call(KernelContext*, RealType val, Status* st) const { auto maybe_decimal = OutValue::FromReal(val, out_precision_, out_scale_); if (ARROW_PREDICT_TRUE(maybe_decimal.ok())) { return maybe_decimal.MoveValueUnsafe(); - } + } if (!allow_truncate_) { *st = maybe_decimal.status(); } return {}; // Zero - } - - int32_t out_scale_, out_precision_; - bool allow_truncate_; -}; - + } + + int32_t out_scale_, out_precision_; + bool allow_truncate_; +}; + template <typename O, typename I> struct CastFunctor<O, I, enable_if_t<is_decimal_type<O>::value && is_floating_type<I>::value>> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const auto& options = checked_cast<const CastState*>(ctx->state())->options; + const auto& options = checked_cast<const CastState*>(ctx->state())->options; const auto& out_type = checked_cast<const O&>(*out->type()); const auto out_scale = out_type.scale(); const auto out_precision = out_type.precision(); - + applicator::ScalarUnaryNotNullStateful<O, I, RealToDecimal> kernel( - RealToDecimal{out_scale, out_precision, options.allow_decimal_truncate}); - return kernel.Exec(ctx, batch, out); - } -}; - -// ---------------------------------------------------------------------- -// Decimal to real - -struct DecimalToReal { - template <typename RealType, typename Arg0Value> + RealToDecimal{out_scale, out_precision, options.allow_decimal_truncate}); + return kernel.Exec(ctx, batch, out); + } +}; + +// ---------------------------------------------------------------------- +// Decimal to real + +struct DecimalToReal { + template <typename RealType, typename Arg0Value> RealType Call(KernelContext*, const Arg0Value& val, Status*) const { return val.template ToReal<RealType>(in_scale_); - } - - int32_t in_scale_; -}; - + } + + int32_t in_scale_; +}; + template <typename O, typename I> struct CastFunctor<O, I, enable_if_t<is_floating_type<O>::value && is_decimal_type<I>::value>> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const auto& in_type = checked_cast<const I&>(*batch[0].type()); const auto in_scale = in_type.scale(); - + applicator::ScalarUnaryNotNullStateful<O, I, DecimalToReal> kernel( - DecimalToReal{in_scale}); - return kernel.Exec(ctx, batch, out); - } -}; - -// ---------------------------------------------------------------------- -// Top-level kernel instantiation - -namespace { - -template <typename OutType> -void AddCommonNumberCasts(const std::shared_ptr<DataType>& out_ty, CastFunction* func) { - AddCommonCasts(out_ty->id(), out_ty, func); - - // Cast from boolean to number - DCHECK_OK(func->AddKernel(Type::BOOL, {boolean()}, out_ty, - CastFunctor<OutType, BooleanType>::Exec)); - - // Cast from other strings - for (const std::shared_ptr<DataType>& in_ty : BaseBinaryTypes()) { - auto exec = GenerateVarBinaryBase<CastFunctor, OutType>(*in_ty); - DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, exec)); - } -} - -template <typename OutType> -std::shared_ptr<CastFunction> GetCastToInteger(std::string name) { - auto func = std::make_shared<CastFunction>(std::move(name), OutType::type_id); - auto out_ty = TypeTraits<OutType>::type_singleton(); - - for (const std::shared_ptr<DataType>& in_ty : IntTypes()) { - DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastIntegerToInteger)); - } - - // Cast from floating point - for (const std::shared_ptr<DataType>& in_ty : FloatingPointTypes()) { - DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToInteger)); - } - - // From other numbers to integer - AddCommonNumberCasts<OutType>(out_ty, func.get()); - - // From decimal to integer + DecimalToReal{in_scale}); + return kernel.Exec(ctx, batch, out); + } +}; + +// ---------------------------------------------------------------------- +// Top-level kernel instantiation + +namespace { + +template <typename OutType> +void AddCommonNumberCasts(const std::shared_ptr<DataType>& out_ty, CastFunction* func) { + AddCommonCasts(out_ty->id(), out_ty, func); + + // Cast from boolean to number + DCHECK_OK(func->AddKernel(Type::BOOL, {boolean()}, out_ty, + CastFunctor<OutType, BooleanType>::Exec)); + + // Cast from other strings + for (const std::shared_ptr<DataType>& in_ty : BaseBinaryTypes()) { + auto exec = GenerateVarBinaryBase<CastFunctor, OutType>(*in_ty); + DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, exec)); + } +} + +template <typename OutType> +std::shared_ptr<CastFunction> GetCastToInteger(std::string name) { + auto func = std::make_shared<CastFunction>(std::move(name), OutType::type_id); + auto out_ty = TypeTraits<OutType>::type_singleton(); + + for (const std::shared_ptr<DataType>& in_ty : IntTypes()) { + DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastIntegerToInteger)); + } + + // Cast from floating point + for (const std::shared_ptr<DataType>& in_ty : FloatingPointTypes()) { + DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToInteger)); + } + + // From other numbers to integer + AddCommonNumberCasts<OutType>(out_ty, func.get()); + + // From decimal to integer DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType(Type::DECIMAL)}, out_ty, - CastFunctor<OutType, Decimal128Type>::Exec)); + CastFunctor<OutType, Decimal128Type>::Exec)); DCHECK_OK(func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, out_ty, CastFunctor<OutType, Decimal256Type>::Exec)); - return func; -} - -template <typename OutType> -std::shared_ptr<CastFunction> GetCastToFloating(std::string name) { - auto func = std::make_shared<CastFunction>(std::move(name), OutType::type_id); - auto out_ty = TypeTraits<OutType>::type_singleton(); - - // Casts from integer to floating point - for (const std::shared_ptr<DataType>& in_ty : IntTypes()) { - DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastIntegerToFloating)); - } - - // Cast from floating point - for (const std::shared_ptr<DataType>& in_ty : FloatingPointTypes()) { - DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToFloating)); - } - - // From other numbers to floating point - AddCommonNumberCasts<OutType>(out_ty, func.get()); - - // From decimal to floating point + return func; +} + +template <typename OutType> +std::shared_ptr<CastFunction> GetCastToFloating(std::string name) { + auto func = std::make_shared<CastFunction>(std::move(name), OutType::type_id); + auto out_ty = TypeTraits<OutType>::type_singleton(); + + // Casts from integer to floating point + for (const std::shared_ptr<DataType>& in_ty : IntTypes()) { + DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastIntegerToFloating)); + } + + // Cast from floating point + for (const std::shared_ptr<DataType>& in_ty : FloatingPointTypes()) { + DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToFloating)); + } + + // From other numbers to floating point + AddCommonNumberCasts<OutType>(out_ty, func.get()); + + // From decimal to floating point DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType(Type::DECIMAL)}, out_ty, - CastFunctor<OutType, Decimal128Type>::Exec)); + CastFunctor<OutType, Decimal128Type>::Exec)); DCHECK_OK(func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, out_ty, CastFunctor<OutType, Decimal256Type>::Exec)); - return func; -} - + return func; +} + std::shared_ptr<CastFunction> GetCastToDecimal128() { - OutputType sig_out_ty(ResolveOutputFromOptions); - + OutputType sig_out_ty(ResolveOutputFromOptions); + auto func = std::make_shared<CastFunction>("cast_decimal", Type::DECIMAL128); AddCommonCasts(Type::DECIMAL128, sig_out_ty, func.get()); - - // Cast from floating point - DCHECK_OK(func->AddKernel(Type::FLOAT, {float32()}, sig_out_ty, - CastFunctor<Decimal128Type, FloatType>::Exec)); - DCHECK_OK(func->AddKernel(Type::DOUBLE, {float64()}, sig_out_ty, - CastFunctor<Decimal128Type, DoubleType>::Exec)); - - // Cast from other decimal - auto exec = CastFunctor<Decimal128Type, Decimal128Type>::Exec; - // We resolve the output type of this kernel from the CastOptions + + // Cast from floating point + DCHECK_OK(func->AddKernel(Type::FLOAT, {float32()}, sig_out_ty, + CastFunctor<Decimal128Type, FloatType>::Exec)); + DCHECK_OK(func->AddKernel(Type::DOUBLE, {float64()}, sig_out_ty, + CastFunctor<Decimal128Type, DoubleType>::Exec)); + + // Cast from other decimal + auto exec = CastFunctor<Decimal128Type, Decimal128Type>::Exec; + // We resolve the output type of this kernel from the CastOptions DCHECK_OK( func->AddKernel(Type::DECIMAL128, {InputType(Type::DECIMAL128)}, sig_out_ty, exec)); exec = CastFunctor<Decimal128Type, Decimal256Type>::Exec; DCHECK_OK( func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, sig_out_ty, exec)); - return func; -} - + return func; +} + std::shared_ptr<CastFunction> GetCastToDecimal256() { OutputType sig_out_ty(ResolveOutputFromOptions); @@ -673,55 +673,55 @@ std::shared_ptr<CastFunction> GetCastToDecimal256() { return func; } -} // namespace - -std::vector<std::shared_ptr<CastFunction>> GetNumericCasts() { - std::vector<std::shared_ptr<CastFunction>> functions; - - // Make a cast to null that does not do much. Not sure why we need to be able - // to cast from dict<null> -> null but there are unit tests for it - auto cast_null = std::make_shared<CastFunction>("cast_null", Type::NA); +} // namespace + +std::vector<std::shared_ptr<CastFunction>> GetNumericCasts() { + std::vector<std::shared_ptr<CastFunction>> functions; + + // Make a cast to null that does not do much. Not sure why we need to be able + // to cast from dict<null> -> null but there are unit tests for it + auto cast_null = std::make_shared<CastFunction>("cast_null", Type::NA); DCHECK_OK(cast_null->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, null(), OutputAllNull)); - functions.push_back(cast_null); - - functions.push_back(GetCastToInteger<Int8Type>("cast_int8")); - functions.push_back(GetCastToInteger<Int16Type>("cast_int16")); - - auto cast_int32 = GetCastToInteger<Int32Type>("cast_int32"); - // Convert DATE32 or TIME32 to INT32 zero copy - AddZeroCopyCast(Type::DATE32, date32(), int32(), cast_int32.get()); - AddZeroCopyCast(Type::TIME32, InputType(Type::TIME32), int32(), cast_int32.get()); - functions.push_back(cast_int32); - - auto cast_int64 = GetCastToInteger<Int64Type>("cast_int64"); - // Convert DATE64, DURATION, TIMESTAMP, TIME64 to INT64 zero copy - AddZeroCopyCast(Type::DATE64, InputType(Type::DATE64), int64(), cast_int64.get()); - AddZeroCopyCast(Type::DURATION, InputType(Type::DURATION), int64(), cast_int64.get()); - AddZeroCopyCast(Type::TIMESTAMP, InputType(Type::TIMESTAMP), int64(), cast_int64.get()); - AddZeroCopyCast(Type::TIME64, InputType(Type::TIME64), int64(), cast_int64.get()); - functions.push_back(cast_int64); - - functions.push_back(GetCastToInteger<UInt8Type>("cast_uint8")); - functions.push_back(GetCastToInteger<UInt16Type>("cast_uint16")); - functions.push_back(GetCastToInteger<UInt32Type>("cast_uint32")); - functions.push_back(GetCastToInteger<UInt64Type>("cast_uint64")); - - // HalfFloat is a bit brain-damaged for now - auto cast_half_float = - std::make_shared<CastFunction>("cast_half_float", Type::HALF_FLOAT); - AddCommonCasts(Type::HALF_FLOAT, float16(), cast_half_float.get()); - functions.push_back(cast_half_float); - - functions.push_back(GetCastToFloating<FloatType>("cast_float")); - functions.push_back(GetCastToFloating<DoubleType>("cast_double")); - + functions.push_back(cast_null); + + functions.push_back(GetCastToInteger<Int8Type>("cast_int8")); + functions.push_back(GetCastToInteger<Int16Type>("cast_int16")); + + auto cast_int32 = GetCastToInteger<Int32Type>("cast_int32"); + // Convert DATE32 or TIME32 to INT32 zero copy + AddZeroCopyCast(Type::DATE32, date32(), int32(), cast_int32.get()); + AddZeroCopyCast(Type::TIME32, InputType(Type::TIME32), int32(), cast_int32.get()); + functions.push_back(cast_int32); + + auto cast_int64 = GetCastToInteger<Int64Type>("cast_int64"); + // Convert DATE64, DURATION, TIMESTAMP, TIME64 to INT64 zero copy + AddZeroCopyCast(Type::DATE64, InputType(Type::DATE64), int64(), cast_int64.get()); + AddZeroCopyCast(Type::DURATION, InputType(Type::DURATION), int64(), cast_int64.get()); + AddZeroCopyCast(Type::TIMESTAMP, InputType(Type::TIMESTAMP), int64(), cast_int64.get()); + AddZeroCopyCast(Type::TIME64, InputType(Type::TIME64), int64(), cast_int64.get()); + functions.push_back(cast_int64); + + functions.push_back(GetCastToInteger<UInt8Type>("cast_uint8")); + functions.push_back(GetCastToInteger<UInt16Type>("cast_uint16")); + functions.push_back(GetCastToInteger<UInt32Type>("cast_uint32")); + functions.push_back(GetCastToInteger<UInt64Type>("cast_uint64")); + + // HalfFloat is a bit brain-damaged for now + auto cast_half_float = + std::make_shared<CastFunction>("cast_half_float", Type::HALF_FLOAT); + AddCommonCasts(Type::HALF_FLOAT, float16(), cast_half_float.get()); + functions.push_back(cast_half_float); + + functions.push_back(GetCastToFloating<FloatType>("cast_float")); + functions.push_back(GetCastToFloating<DoubleType>("cast_double")); + functions.push_back(GetCastToDecimal128()); functions.push_back(GetCastToDecimal256()); - - return functions; -} - -} // namespace internal -} // namespace compute -} // namespace arrow + + return functions; +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_string.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_string.cc index 3ce537b722..56a19a69a1 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_string.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_string.cc @@ -1,107 +1,107 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + #include <limits> - -#include "arrow/array/array_base.h" + +#include "arrow/array/array_base.h" #include "arrow/array/builder_binary.h" -#include "arrow/compute/kernels/common.h" -#include "arrow/compute/kernels/scalar_cast_internal.h" -#include "arrow/result.h" -#include "arrow/util/formatting.h" +#include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/scalar_cast_internal.h" +#include "arrow/result.h" +#include "arrow/util/formatting.h" #include "arrow/util/int_util.h" -#include "arrow/util/optional.h" -#include "arrow/util/utf8.h" -#include "arrow/visitor_inline.h" - -namespace arrow { - -using internal::StringFormatter; -using util::InitializeUTF8; -using util::ValidateUTF8; - -namespace compute { -namespace internal { - +#include "arrow/util/optional.h" +#include "arrow/util/utf8.h" +#include "arrow/visitor_inline.h" + +namespace arrow { + +using internal::StringFormatter; +using util::InitializeUTF8; +using util::ValidateUTF8; + +namespace compute { +namespace internal { + namespace { -// ---------------------------------------------------------------------- -// Number / Boolean to String - +// ---------------------------------------------------------------------- +// Number / Boolean to String + template <typename O, typename I> struct NumericToStringCastFunctor { - using value_type = typename TypeTraits<I>::CType; - using BuilderType = typename TypeTraits<O>::BuilderType; - using FormatterType = StringFormatter<I>; - + using value_type = typename TypeTraits<I>::CType; + using BuilderType = typename TypeTraits<O>::BuilderType; + using FormatterType = StringFormatter<I>; + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { DCHECK(out->is_array()); - const ArrayData& input = *batch[0].array(); - ArrayData* output = out->mutable_array(); + const ArrayData& input = *batch[0].array(); + ArrayData* output = out->mutable_array(); return Convert(ctx, input, output); - } - - static Status Convert(KernelContext* ctx, const ArrayData& input, ArrayData* output) { - FormatterType formatter(input.type); - BuilderType builder(input.type, ctx->memory_pool()); - RETURN_NOT_OK(VisitArrayDataInline<I>( - input, - [&](value_type v) { - return formatter(v, [&](util::string_view v) { return builder.Append(v); }); - }, - [&]() { return builder.AppendNull(); })); - - std::shared_ptr<Array> output_array; - RETURN_NOT_OK(builder.Finish(&output_array)); - *output = std::move(*output_array->data()); - return Status::OK(); - } -}; - -// ---------------------------------------------------------------------- + } + + static Status Convert(KernelContext* ctx, const ArrayData& input, ArrayData* output) { + FormatterType formatter(input.type); + BuilderType builder(input.type, ctx->memory_pool()); + RETURN_NOT_OK(VisitArrayDataInline<I>( + input, + [&](value_type v) { + return formatter(v, [&](util::string_view v) { return builder.Append(v); }); + }, + [&]() { return builder.AppendNull(); })); + + std::shared_ptr<Array> output_array; + RETURN_NOT_OK(builder.Finish(&output_array)); + *output = std::move(*output_array->data()); + return Status::OK(); + } +}; + +// ---------------------------------------------------------------------- // Binary-like to binary-like -// - -#if defined(_MSC_VER) -// Silence warning: """'visitor': unreferenced local variable""" -#pragma warning(push) -#pragma warning(disable : 4101) -#endif - -struct Utf8Validator { - Status VisitNull() { return Status::OK(); } - - Status VisitValue(util::string_view str) { - if (ARROW_PREDICT_FALSE(!ValidateUTF8(str))) { - return Status::Invalid("Invalid UTF8 payload"); - } - return Status::OK(); - } -}; - -template <typename I, typename O> +// + +#if defined(_MSC_VER) +// Silence warning: """'visitor': unreferenced local variable""" +#pragma warning(push) +#pragma warning(disable : 4101) +#endif + +struct Utf8Validator { + Status VisitNull() { return Status::OK(); } + + Status VisitValue(util::string_view str) { + if (ARROW_PREDICT_FALSE(!ValidateUTF8(str))) { + return Status::Invalid("Invalid UTF8 payload"); + } + return Status::OK(); + } +}; + +template <typename I, typename O> Status CastBinaryToBinaryOffsets(KernelContext* ctx, const ArrayData& input, ArrayData* output) { static_assert(std::is_same<I, O>::value, "Cast same-width offsets (no-op)"); return Status::OK(); } - + // Upcast offsets -template <> +template <> Status CastBinaryToBinaryOffsets<int32_t, int64_t>(KernelContext* ctx, const ArrayData& input, ArrayData* output) { @@ -117,15 +117,15 @@ Status CastBinaryToBinaryOffsets<int32_t, int64_t>(KernelContext* ctx, output->length + 1); return Status::OK(); } - + // Downcast offsets -template <> +template <> Status CastBinaryToBinaryOffsets<int64_t, int32_t>(KernelContext* ctx, const ArrayData& input, ArrayData* output) { using input_offset_type = int64_t; using output_offset_type = int32_t; - + constexpr input_offset_type kMaxOffset = std::numeric_limits<output_offset_type>::max(); auto input_offsets = input.GetValues<input_offset_type>(1); @@ -167,31 +167,31 @@ Status BinaryToBinaryCastExec(KernelContext* ctx, const ExecBatch& batch, Datum* ctx, input, out->mutable_array()); } -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + // ---------------------------------------------------------------------- // Cast functions registration - -template <typename OutType> + +template <typename OutType> void AddNumberToStringCasts(CastFunction* func) { auto out_ty = TypeTraits<OutType>::type_singleton(); - DCHECK_OK(func->AddKernel(Type::BOOL, {boolean()}, out_ty, + DCHECK_OK(func->AddKernel(Type::BOOL, {boolean()}, out_ty, TrivialScalarUnaryAsArraysExec( NumericToStringCastFunctor<OutType, BooleanType>::Exec), - NullHandling::COMPUTED_NO_PREALLOCATE)); - - for (const std::shared_ptr<DataType>& in_ty : NumericTypes()) { + NullHandling::COMPUTED_NO_PREALLOCATE)); + + for (const std::shared_ptr<DataType>& in_ty : NumericTypes()) { DCHECK_OK( func->AddKernel(in_ty->id(), {in_ty}, out_ty, TrivialScalarUnaryAsArraysExec( GenerateNumeric<NumericToStringCastFunctor, OutType>(*in_ty)), NullHandling::COMPUTED_NO_PREALLOCATE)); - } -} - + } +} + template <typename OutType, typename InType> void AddBinaryToBinaryCast(CastFunction* func) { auto in_ty = TypeTraits<InType>::type_singleton(); @@ -213,35 +213,35 @@ void AddBinaryToBinaryCast(CastFunction* func) { } // namespace -std::vector<std::shared_ptr<CastFunction>> GetBinaryLikeCasts() { - auto cast_binary = std::make_shared<CastFunction>("cast_binary", Type::BINARY); - AddCommonCasts(Type::BINARY, binary(), cast_binary.get()); +std::vector<std::shared_ptr<CastFunction>> GetBinaryLikeCasts() { + auto cast_binary = std::make_shared<CastFunction>("cast_binary", Type::BINARY); + AddCommonCasts(Type::BINARY, binary(), cast_binary.get()); AddBinaryToBinaryCast<BinaryType>(cast_binary.get()); - - auto cast_large_binary = - std::make_shared<CastFunction>("cast_large_binary", Type::LARGE_BINARY); - AddCommonCasts(Type::LARGE_BINARY, large_binary(), cast_large_binary.get()); + + auto cast_large_binary = + std::make_shared<CastFunction>("cast_large_binary", Type::LARGE_BINARY); + AddCommonCasts(Type::LARGE_BINARY, large_binary(), cast_large_binary.get()); AddBinaryToBinaryCast<LargeBinaryType>(cast_large_binary.get()); - - auto cast_string = std::make_shared<CastFunction>("cast_string", Type::STRING); - AddCommonCasts(Type::STRING, utf8(), cast_string.get()); + + auto cast_string = std::make_shared<CastFunction>("cast_string", Type::STRING); + AddCommonCasts(Type::STRING, utf8(), cast_string.get()); AddNumberToStringCasts<StringType>(cast_string.get()); AddBinaryToBinaryCast<StringType>(cast_string.get()); - - auto cast_large_string = - std::make_shared<CastFunction>("cast_large_string", Type::LARGE_STRING); - AddCommonCasts(Type::LARGE_STRING, large_utf8(), cast_large_string.get()); + + auto cast_large_string = + std::make_shared<CastFunction>("cast_large_string", Type::LARGE_STRING); + AddCommonCasts(Type::LARGE_STRING, large_utf8(), cast_large_string.get()); AddNumberToStringCasts<LargeStringType>(cast_large_string.get()); AddBinaryToBinaryCast<LargeStringType>(cast_large_string.get()); - + auto cast_fsb = std::make_shared<CastFunction>("cast_fixed_size_binary", Type::FIXED_SIZE_BINARY); AddCommonCasts(Type::FIXED_SIZE_BINARY, OutputType(ResolveOutputFromOptions), cast_fsb.get()); return {cast_binary, cast_large_binary, cast_string, cast_large_string, cast_fsb}; -} - -} // namespace internal -} // namespace compute -} // namespace arrow +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc index 1a58fce7c7..a06d473329 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc @@ -1,260 +1,260 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Implementation of casting to (or between) temporal types - -#include <limits> - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Implementation of casting to (or between) temporal types + +#include <limits> + #include "arrow/array/builder_time.h" -#include "arrow/compute/kernels/common.h" -#include "arrow/compute/kernels/scalar_cast_internal.h" -#include "arrow/util/bitmap_reader.h" -#include "arrow/util/time.h" -#include "arrow/util/value_parsing.h" - -namespace arrow { - -using internal::ParseValue; - -namespace compute { -namespace internal { - -constexpr int64_t kMillisecondsInDay = 86400000; - -// ---------------------------------------------------------------------- -// From one timestamp to another - -template <typename in_type, typename out_type> +#include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/scalar_cast_internal.h" +#include "arrow/util/bitmap_reader.h" +#include "arrow/util/time.h" +#include "arrow/util/value_parsing.h" + +namespace arrow { + +using internal::ParseValue; + +namespace compute { +namespace internal { + +constexpr int64_t kMillisecondsInDay = 86400000; + +// ---------------------------------------------------------------------- +// From one timestamp to another + +template <typename in_type, typename out_type> Status ShiftTime(KernelContext* ctx, const util::DivideOrMultiply factor_op, const int64_t factor, const ArrayData& input, ArrayData* output) { - const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options; + const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options; auto in_data = input.GetValues<in_type>(1); - auto out_data = output->GetMutableValues<out_type>(1); - - if (factor == 1) { - for (int64_t i = 0; i < input.length; i++) { - out_data[i] = static_cast<out_type>(in_data[i]); - } - } else if (factor_op == util::MULTIPLY) { - if (options.allow_time_overflow) { - for (int64_t i = 0; i < input.length; i++) { - out_data[i] = static_cast<out_type>(in_data[i] * factor); - } - } else { + auto out_data = output->GetMutableValues<out_type>(1); + + if (factor == 1) { + for (int64_t i = 0; i < input.length; i++) { + out_data[i] = static_cast<out_type>(in_data[i]); + } + } else if (factor_op == util::MULTIPLY) { + if (options.allow_time_overflow) { + for (int64_t i = 0; i < input.length; i++) { + out_data[i] = static_cast<out_type>(in_data[i] * factor); + } + } else { #define RAISE_OVERFLOW_CAST(VAL) \ return Status::Invalid("Casting from ", input.type->ToString(), " to ", \ output->type->ToString(), " would result in ", \ "out of bounds timestamp: ", VAL); - - int64_t max_val = std::numeric_limits<int64_t>::max() / factor; - int64_t min_val = std::numeric_limits<int64_t>::min() / factor; - if (input.null_count != 0) { - BitmapReader bit_reader(input.buffers[0]->data(), input.offset, input.length); - for (int64_t i = 0; i < input.length; i++) { - if (bit_reader.IsSet() && (in_data[i] < min_val || in_data[i] > max_val)) { - RAISE_OVERFLOW_CAST(in_data[i]); - } - out_data[i] = static_cast<out_type>(in_data[i] * factor); - bit_reader.Next(); - } - } else { - for (int64_t i = 0; i < input.length; i++) { - if (in_data[i] < min_val || in_data[i] > max_val) { - RAISE_OVERFLOW_CAST(in_data[i]); - } - out_data[i] = static_cast<out_type>(in_data[i] * factor); - } - } - -#undef RAISE_OVERFLOW_CAST - } - } else { - if (options.allow_time_truncate) { - for (int64_t i = 0; i < input.length; i++) { - out_data[i] = static_cast<out_type>(in_data[i] / factor); - } - } else { + + int64_t max_val = std::numeric_limits<int64_t>::max() / factor; + int64_t min_val = std::numeric_limits<int64_t>::min() / factor; + if (input.null_count != 0) { + BitmapReader bit_reader(input.buffers[0]->data(), input.offset, input.length); + for (int64_t i = 0; i < input.length; i++) { + if (bit_reader.IsSet() && (in_data[i] < min_val || in_data[i] > max_val)) { + RAISE_OVERFLOW_CAST(in_data[i]); + } + out_data[i] = static_cast<out_type>(in_data[i] * factor); + bit_reader.Next(); + } + } else { + for (int64_t i = 0; i < input.length; i++) { + if (in_data[i] < min_val || in_data[i] > max_val) { + RAISE_OVERFLOW_CAST(in_data[i]); + } + out_data[i] = static_cast<out_type>(in_data[i] * factor); + } + } + +#undef RAISE_OVERFLOW_CAST + } + } else { + if (options.allow_time_truncate) { + for (int64_t i = 0; i < input.length; i++) { + out_data[i] = static_cast<out_type>(in_data[i] / factor); + } + } else { #define RAISE_INVALID_CAST(VAL) \ return Status::Invalid("Casting from ", input.type->ToString(), " to ", \ output->type->ToString(), " would lose data: ", VAL); - - if (input.null_count != 0) { - BitmapReader bit_reader(input.buffers[0]->data(), input.offset, input.length); - for (int64_t i = 0; i < input.length; i++) { - out_data[i] = static_cast<out_type>(in_data[i] / factor); - if (bit_reader.IsSet() && (out_data[i] * factor != in_data[i])) { - RAISE_INVALID_CAST(in_data[i]); - } - bit_reader.Next(); - } - } else { - for (int64_t i = 0; i < input.length; i++) { - out_data[i] = static_cast<out_type>(in_data[i] / factor); - if (out_data[i] * factor != in_data[i]) { - RAISE_INVALID_CAST(in_data[i]); - } - } - } - -#undef RAISE_INVALID_CAST - } - } + + if (input.null_count != 0) { + BitmapReader bit_reader(input.buffers[0]->data(), input.offset, input.length); + for (int64_t i = 0; i < input.length; i++) { + out_data[i] = static_cast<out_type>(in_data[i] / factor); + if (bit_reader.IsSet() && (out_data[i] * factor != in_data[i])) { + RAISE_INVALID_CAST(in_data[i]); + } + bit_reader.Next(); + } + } else { + for (int64_t i = 0; i < input.length; i++) { + out_data[i] = static_cast<out_type>(in_data[i] / factor); + if (out_data[i] * factor != in_data[i]) { + RAISE_INVALID_CAST(in_data[i]); + } + } + } + +#undef RAISE_INVALID_CAST + } + } return Status::OK(); -} - -// <TimestampType, TimestampType> and <DurationType, DurationType> -template <typename O, typename I> -struct CastFunctor< - O, I, - enable_if_t<(is_timestamp_type<O>::value && is_timestamp_type<I>::value) || - (is_duration_type<O>::value && is_duration_type<I>::value)>> { +} + +// <TimestampType, TimestampType> and <DurationType, DurationType> +template <typename O, typename I> +struct CastFunctor< + O, I, + enable_if_t<(is_timestamp_type<O>::value && is_timestamp_type<I>::value) || + (is_duration_type<O>::value && is_duration_type<I>::value)>> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - DCHECK_EQ(batch[0].kind(), Datum::ARRAY); - - const ArrayData& input = *batch[0].array(); - ArrayData* output = out->mutable_array(); - - // If units are the same, zero copy, otherwise convert - const auto& in_type = checked_cast<const I&>(*batch[0].type()); - const auto& out_type = checked_cast<const O&>(*output->type); - - // The units may be equal if the time zones are different. We might go to - // lengths to make this zero copy in the future but we leave it for now - - auto conversion = util::GetTimestampConversion(in_type.unit(), out_type.unit()); + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + + const ArrayData& input = *batch[0].array(); + ArrayData* output = out->mutable_array(); + + // If units are the same, zero copy, otherwise convert + const auto& in_type = checked_cast<const I&>(*batch[0].type()); + const auto& out_type = checked_cast<const O&>(*output->type); + + // The units may be equal if the time zones are different. We might go to + // lengths to make this zero copy in the future but we leave it for now + + auto conversion = util::GetTimestampConversion(in_type.unit(), out_type.unit()); return ShiftTime<int64_t, int64_t>(ctx, conversion.first, conversion.second, input, output); - } -}; - -template <> -struct CastFunctor<Date32Type, TimestampType> { + } +}; + +template <> +struct CastFunctor<Date32Type, TimestampType> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - DCHECK_EQ(batch[0].kind(), Datum::ARRAY); - - const ArrayData& input = *batch[0].array(); - ArrayData* output = out->mutable_array(); - - const auto& in_type = checked_cast<const TimestampType&>(*input.type); - - static const int64_t kTimestampToDateFactors[4] = { - 86400LL, // SECOND - 86400LL * 1000LL, // MILLI - 86400LL * 1000LL * 1000LL, // MICRO - 86400LL * 1000LL * 1000LL * 1000LL, // NANO - }; - - const int64_t factor = kTimestampToDateFactors[static_cast<int>(in_type.unit())]; + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + + const ArrayData& input = *batch[0].array(); + ArrayData* output = out->mutable_array(); + + const auto& in_type = checked_cast<const TimestampType&>(*input.type); + + static const int64_t kTimestampToDateFactors[4] = { + 86400LL, // SECOND + 86400LL * 1000LL, // MILLI + 86400LL * 1000LL * 1000LL, // MICRO + 86400LL * 1000LL * 1000LL * 1000LL, // NANO + }; + + const int64_t factor = kTimestampToDateFactors[static_cast<int>(in_type.unit())]; return ShiftTime<int64_t, int32_t>(ctx, util::DIVIDE, factor, input, output); - } -}; - -template <> -struct CastFunctor<Date64Type, TimestampType> { + } +}; + +template <> +struct CastFunctor<Date64Type, TimestampType> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - DCHECK_EQ(batch[0].kind(), Datum::ARRAY); - - const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options; - const ArrayData& input = *batch[0].array(); - ArrayData* output = out->mutable_array(); - const auto& in_type = checked_cast<const TimestampType&>(*input.type); - - auto conversion = util::GetTimestampConversion(in_type.unit(), TimeUnit::MILLI); + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + + const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options; + const ArrayData& input = *batch[0].array(); + ArrayData* output = out->mutable_array(); + const auto& in_type = checked_cast<const TimestampType&>(*input.type); + + auto conversion = util::GetTimestampConversion(in_type.unit(), TimeUnit::MILLI); RETURN_NOT_OK((ShiftTime<int64_t, int64_t>(ctx, conversion.first, conversion.second, input, output))); - - // Ensure that intraday milliseconds have been zeroed out - auto out_data = output->GetMutableValues<int64_t>(1); - - if (input.null_count != 0) { - BitmapReader bit_reader(input.buffers[0]->data(), input.offset, input.length); - - for (int64_t i = 0; i < input.length; ++i) { - const int64_t remainder = out_data[i] % kMillisecondsInDay; - if (ARROW_PREDICT_FALSE(!options.allow_time_truncate && bit_reader.IsSet() && - remainder > 0)) { + + // Ensure that intraday milliseconds have been zeroed out + auto out_data = output->GetMutableValues<int64_t>(1); + + if (input.null_count != 0) { + BitmapReader bit_reader(input.buffers[0]->data(), input.offset, input.length); + + for (int64_t i = 0; i < input.length; ++i) { + const int64_t remainder = out_data[i] % kMillisecondsInDay; + if (ARROW_PREDICT_FALSE(!options.allow_time_truncate && bit_reader.IsSet() && + remainder > 0)) { return Status::Invalid("Timestamp value had non-zero intraday milliseconds"); - } - out_data[i] -= remainder; - bit_reader.Next(); - } - } else { - for (int64_t i = 0; i < input.length; ++i) { - const int64_t remainder = out_data[i] % kMillisecondsInDay; - if (ARROW_PREDICT_FALSE(!options.allow_time_truncate && remainder > 0)) { + } + out_data[i] -= remainder; + bit_reader.Next(); + } + } else { + for (int64_t i = 0; i < input.length; ++i) { + const int64_t remainder = out_data[i] % kMillisecondsInDay; + if (ARROW_PREDICT_FALSE(!options.allow_time_truncate && remainder > 0)) { return Status::Invalid("Timestamp value had non-zero intraday milliseconds"); - } - out_data[i] -= remainder; - } - } + } + out_data[i] -= remainder; + } + } return Status::OK(); - } -}; - -// ---------------------------------------------------------------------- -// From one time32 or time64 to another - -template <typename O, typename I> -struct CastFunctor<O, I, enable_if_t<is_time_type<I>::value && is_time_type<O>::value>> { - using in_t = typename I::c_type; - using out_t = typename O::c_type; - + } +}; + +// ---------------------------------------------------------------------- +// From one time32 or time64 to another + +template <typename O, typename I> +struct CastFunctor<O, I, enable_if_t<is_time_type<I>::value && is_time_type<O>::value>> { + using in_t = typename I::c_type; + using out_t = typename O::c_type; + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - DCHECK_EQ(batch[0].kind(), Datum::ARRAY); - - const ArrayData& input = *batch[0].array(); - ArrayData* output = out->mutable_array(); - - // If units are the same, zero copy, otherwise convert - const auto& in_type = checked_cast<const I&>(*input.type); - const auto& out_type = checked_cast<const O&>(*output->type); - DCHECK_NE(in_type.unit(), out_type.unit()) << "Do not cast equal types"; - auto conversion = util::GetTimestampConversion(in_type.unit(), out_type.unit()); + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + + const ArrayData& input = *batch[0].array(); + ArrayData* output = out->mutable_array(); + + // If units are the same, zero copy, otherwise convert + const auto& in_type = checked_cast<const I&>(*input.type); + const auto& out_type = checked_cast<const O&>(*output->type); + DCHECK_NE(in_type.unit(), out_type.unit()) << "Do not cast equal types"; + auto conversion = util::GetTimestampConversion(in_type.unit(), out_type.unit()); return ShiftTime<in_t, out_t>(ctx, conversion.first, conversion.second, input, output); - } -}; - -// ---------------------------------------------------------------------- -// Between date32 and date64 - -template <> -struct CastFunctor<Date64Type, Date32Type> { + } +}; + +// ---------------------------------------------------------------------- +// Between date32 and date64 + +template <> +struct CastFunctor<Date64Type, Date32Type> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - DCHECK_EQ(batch[0].kind(), Datum::ARRAY); - + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + return ShiftTime<int32_t, int64_t>(ctx, util::MULTIPLY, kMillisecondsInDay, *batch[0].array(), out->mutable_array()); - } -}; - -template <> -struct CastFunctor<Date32Type, Date64Type> { + } +}; + +template <> +struct CastFunctor<Date32Type, Date64Type> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - DCHECK_EQ(batch[0].kind(), Datum::ARRAY); - + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + return ShiftTime<int64_t, int32_t>(ctx, util::DIVIDE, kMillisecondsInDay, *batch[0].array(), out->mutable_array()); - } -}; - -// ---------------------------------------------------------------------- + } +}; + +// ---------------------------------------------------------------------- // date32, date64 to timestamp template <> @@ -289,164 +289,164 @@ struct CastFunctor<TimestampType, Date64Type> { }; // ---------------------------------------------------------------------- -// String to Timestamp - -struct ParseTimestamp { - template <typename OutValue, typename Arg0Value> +// String to Timestamp + +struct ParseTimestamp { + template <typename OutValue, typename Arg0Value> OutValue Call(KernelContext*, Arg0Value val, Status* st) const { - OutValue result = 0; - if (ARROW_PREDICT_FALSE(!ParseValue(type, val.data(), val.size(), &result))) { + OutValue result = 0; + if (ARROW_PREDICT_FALSE(!ParseValue(type, val.data(), val.size(), &result))) { *st = Status::Invalid("Failed to parse string: '", val, "' as a scalar of type ", type.ToString()); - } - return result; - } - - const TimestampType& type; -}; - -template <typename I> -struct CastFunctor<TimestampType, I, enable_if_t<is_base_binary_type<I>::value>> { + } + return result; + } + + const TimestampType& type; +}; + +template <typename I> +struct CastFunctor<TimestampType, I, enable_if_t<is_base_binary_type<I>::value>> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const auto& out_type = checked_cast<const TimestampType&>(*out->type()); - applicator::ScalarUnaryNotNullStateful<TimestampType, I, ParseTimestamp> kernel( - ParseTimestamp{out_type}); - return kernel.Exec(ctx, batch, out); - } -}; - -template <typename Type> -void AddCrossUnitCast(CastFunction* func) { - ScalarKernel kernel; + const auto& out_type = checked_cast<const TimestampType&>(*out->type()); + applicator::ScalarUnaryNotNullStateful<TimestampType, I, ParseTimestamp> kernel( + ParseTimestamp{out_type}); + return kernel.Exec(ctx, batch, out); + } +}; + +template <typename Type> +void AddCrossUnitCast(CastFunction* func) { + ScalarKernel kernel; kernel.exec = TrivialScalarUnaryAsArraysExec(CastFunctor<Type, Type>::Exec); - kernel.signature = KernelSignature::Make({InputType(Type::type_id)}, kOutputTargetType); - DCHECK_OK(func->AddKernel(Type::type_id, std::move(kernel))); -} - -std::shared_ptr<CastFunction> GetDate32Cast() { - auto func = std::make_shared<CastFunction>("cast_date32", Type::DATE32); - auto out_ty = date32(); - AddCommonCasts(Type::DATE32, out_ty, func.get()); - - // int32 -> date32 - AddZeroCopyCast(Type::INT32, int32(), date32(), func.get()); - - // date64 -> date32 - AddSimpleCast<Date64Type, Date32Type>(date64(), date32(), func.get()); - - // timestamp -> date32 - AddSimpleCast<TimestampType, Date32Type>(InputType(Type::TIMESTAMP), date32(), - func.get()); - return func; -} - -std::shared_ptr<CastFunction> GetDate64Cast() { - auto func = std::make_shared<CastFunction>("cast_date64", Type::DATE64); - auto out_ty = date64(); - AddCommonCasts(Type::DATE64, out_ty, func.get()); - - // int64 -> date64 - AddZeroCopyCast(Type::INT64, int64(), date64(), func.get()); - - // date32 -> date64 - AddSimpleCast<Date32Type, Date64Type>(date32(), date64(), func.get()); - - // timestamp -> date64 - AddSimpleCast<TimestampType, Date64Type>(InputType(Type::TIMESTAMP), date64(), - func.get()); - return func; -} - -std::shared_ptr<CastFunction> GetDurationCast() { - auto func = std::make_shared<CastFunction>("cast_duration", Type::DURATION); - AddCommonCasts(Type::DURATION, kOutputTargetType, func.get()); - - auto seconds = duration(TimeUnit::SECOND); - auto millis = duration(TimeUnit::MILLI); - auto micros = duration(TimeUnit::MICRO); - auto nanos = duration(TimeUnit::NANO); - - // Same integer representation - AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get()); - - // Between durations - AddCrossUnitCast<DurationType>(func.get()); - - return func; -} - -std::shared_ptr<CastFunction> GetTime32Cast() { - auto func = std::make_shared<CastFunction>("cast_time32", Type::TIME32); - AddCommonCasts(Type::TIME32, kOutputTargetType, func.get()); - - // Zero copy when the unit is the same or same integer representation - AddZeroCopyCast(Type::INT32, /*in_type=*/int32(), kOutputTargetType, func.get()); - - // time64 -> time32 - AddSimpleCast<Time64Type, Time32Type>(InputType(Type::TIME64), kOutputTargetType, - func.get()); - - // time32 -> time32 - AddCrossUnitCast<Time32Type>(func.get()); - - return func; -} - -std::shared_ptr<CastFunction> GetTime64Cast() { - auto func = std::make_shared<CastFunction>("cast_time64", Type::TIME64); - AddCommonCasts(Type::TIME64, kOutputTargetType, func.get()); - - // Zero copy when the unit is the same or same integer representation - AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get()); - - // time32 -> time64 - AddSimpleCast<Time32Type, Time64Type>(InputType(Type::TIME32), kOutputTargetType, - func.get()); - - // Between durations - AddCrossUnitCast<Time64Type>(func.get()); - - return func; -} - -std::shared_ptr<CastFunction> GetTimestampCast() { - auto func = std::make_shared<CastFunction>("cast_timestamp", Type::TIMESTAMP); - AddCommonCasts(Type::TIMESTAMP, kOutputTargetType, func.get()); - - // Same integer representation - AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get()); - - // From date types + kernel.signature = KernelSignature::Make({InputType(Type::type_id)}, kOutputTargetType); + DCHECK_OK(func->AddKernel(Type::type_id, std::move(kernel))); +} + +std::shared_ptr<CastFunction> GetDate32Cast() { + auto func = std::make_shared<CastFunction>("cast_date32", Type::DATE32); + auto out_ty = date32(); + AddCommonCasts(Type::DATE32, out_ty, func.get()); + + // int32 -> date32 + AddZeroCopyCast(Type::INT32, int32(), date32(), func.get()); + + // date64 -> date32 + AddSimpleCast<Date64Type, Date32Type>(date64(), date32(), func.get()); + + // timestamp -> date32 + AddSimpleCast<TimestampType, Date32Type>(InputType(Type::TIMESTAMP), date32(), + func.get()); + return func; +} + +std::shared_ptr<CastFunction> GetDate64Cast() { + auto func = std::make_shared<CastFunction>("cast_date64", Type::DATE64); + auto out_ty = date64(); + AddCommonCasts(Type::DATE64, out_ty, func.get()); + + // int64 -> date64 + AddZeroCopyCast(Type::INT64, int64(), date64(), func.get()); + + // date32 -> date64 + AddSimpleCast<Date32Type, Date64Type>(date32(), date64(), func.get()); + + // timestamp -> date64 + AddSimpleCast<TimestampType, Date64Type>(InputType(Type::TIMESTAMP), date64(), + func.get()); + return func; +} + +std::shared_ptr<CastFunction> GetDurationCast() { + auto func = std::make_shared<CastFunction>("cast_duration", Type::DURATION); + AddCommonCasts(Type::DURATION, kOutputTargetType, func.get()); + + auto seconds = duration(TimeUnit::SECOND); + auto millis = duration(TimeUnit::MILLI); + auto micros = duration(TimeUnit::MICRO); + auto nanos = duration(TimeUnit::NANO); + + // Same integer representation + AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get()); + + // Between durations + AddCrossUnitCast<DurationType>(func.get()); + + return func; +} + +std::shared_ptr<CastFunction> GetTime32Cast() { + auto func = std::make_shared<CastFunction>("cast_time32", Type::TIME32); + AddCommonCasts(Type::TIME32, kOutputTargetType, func.get()); + + // Zero copy when the unit is the same or same integer representation + AddZeroCopyCast(Type::INT32, /*in_type=*/int32(), kOutputTargetType, func.get()); + + // time64 -> time32 + AddSimpleCast<Time64Type, Time32Type>(InputType(Type::TIME64), kOutputTargetType, + func.get()); + + // time32 -> time32 + AddCrossUnitCast<Time32Type>(func.get()); + + return func; +} + +std::shared_ptr<CastFunction> GetTime64Cast() { + auto func = std::make_shared<CastFunction>("cast_time64", Type::TIME64); + AddCommonCasts(Type::TIME64, kOutputTargetType, func.get()); + + // Zero copy when the unit is the same or same integer representation + AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get()); + + // time32 -> time64 + AddSimpleCast<Time32Type, Time64Type>(InputType(Type::TIME32), kOutputTargetType, + func.get()); + + // Between durations + AddCrossUnitCast<Time64Type>(func.get()); + + return func; +} + +std::shared_ptr<CastFunction> GetTimestampCast() { + auto func = std::make_shared<CastFunction>("cast_timestamp", Type::TIMESTAMP); + AddCommonCasts(Type::TIMESTAMP, kOutputTargetType, func.get()); + + // Same integer representation + AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get()); + + // From date types // TODO: ARROW-8876, these casts are not directly tested AddSimpleCast<Date32Type, TimestampType>(InputType(Type::DATE32), kOutputTargetType, func.get()); AddSimpleCast<Date64Type, TimestampType>(InputType(Type::DATE64), kOutputTargetType, func.get()); - - // string -> timestamp - AddSimpleCast<StringType, TimestampType>(utf8(), kOutputTargetType, func.get()); - // large_string -> timestamp - AddSimpleCast<LargeStringType, TimestampType>(large_utf8(), kOutputTargetType, - func.get()); - - // From one timestamp to another - AddCrossUnitCast<TimestampType>(func.get()); - - return func; -} - -std::vector<std::shared_ptr<CastFunction>> GetTemporalCasts() { - std::vector<std::shared_ptr<CastFunction>> functions; - - functions.push_back(GetDate32Cast()); - functions.push_back(GetDate64Cast()); - functions.push_back(GetDurationCast()); - functions.push_back(GetTime32Cast()); - functions.push_back(GetTime64Cast()); - functions.push_back(GetTimestampCast()); - return functions; -} - -} // namespace internal -} // namespace compute -} // namespace arrow + + // string -> timestamp + AddSimpleCast<StringType, TimestampType>(utf8(), kOutputTargetType, func.get()); + // large_string -> timestamp + AddSimpleCast<LargeStringType, TimestampType>(large_utf8(), kOutputTargetType, + func.get()); + + // From one timestamp to another + AddCrossUnitCast<TimestampType>(func.get()); + + return func; +} + +std::vector<std::shared_ptr<CastFunction>> GetTemporalCasts() { + std::vector<std::shared_ptr<CastFunction>> functions; + + functions.push_back(GetDate32Cast()); + functions.push_back(GetDate64Cast()); + functions.push_back(GetDurationCast()); + functions.push_back(GetTime32Cast()); + functions.push_back(GetTime64Cast()); + functions.push_back(GetTimestampCast()); + return functions; +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_compare.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_compare.cc index 4342d776c3..713875937a 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -1,70 +1,70 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + #include <cmath> #include <limits> #include "arrow/compute/api_scalar.h" -#include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/common.h" #include "arrow/util/bitmap_ops.h" - -namespace arrow { - -using internal::checked_cast; -using internal::checked_pointer_cast; -using util::string_view; - -namespace compute { -namespace internal { - -namespace { - -struct Equal { + +namespace arrow { + +using internal::checked_cast; +using internal::checked_pointer_cast; +using util::string_view; + +namespace compute { +namespace internal { + +namespace { + +struct Equal { template <typename T, typename Arg0, typename Arg1> static constexpr T Call(KernelContext*, const Arg0& left, const Arg1& right, Status*) { static_assert(std::is_same<T, bool>::value && std::is_same<Arg0, Arg1>::value, ""); - return left == right; - } -}; - -struct NotEqual { + return left == right; + } +}; + +struct NotEqual { template <typename T, typename Arg0, typename Arg1> static constexpr T Call(KernelContext*, const Arg0& left, const Arg1& right, Status*) { static_assert(std::is_same<T, bool>::value && std::is_same<Arg0, Arg1>::value, ""); - return left != right; - } -}; - -struct Greater { + return left != right; + } +}; + +struct Greater { template <typename T, typename Arg0, typename Arg1> static constexpr T Call(KernelContext*, const Arg0& left, const Arg1& right, Status*) { static_assert(std::is_same<T, bool>::value && std::is_same<Arg0, Arg1>::value, ""); - return left > right; - } -}; - -struct GreaterEqual { + return left > right; + } +}; + +struct GreaterEqual { template <typename T, typename Arg0, typename Arg1> static constexpr T Call(KernelContext*, const Arg0& left, const Arg1& right, Status*) { static_assert(std::is_same<T, bool>::value && std::is_same<Arg0, Arg1>::value, ""); - return left >= right; - } -}; - + return left >= right; + } +}; + template <typename T> using is_unsigned_integer = std::integral_constant<bool, std::is_integral<T>::value && std::is_unsigned<T>::value>; @@ -138,22 +138,22 @@ struct Maximum { } }; -// Implement Less, LessEqual by flipping arguments to Greater, GreaterEqual - -template <typename Op> -void AddIntegerCompare(const std::shared_ptr<DataType>& ty, ScalarFunction* func) { - auto exec = - GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(*ty); - DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec))); -} - -template <typename InType, typename Op> -void AddGenericCompare(const std::shared_ptr<DataType>& ty, ScalarFunction* func) { - DCHECK_OK( - func->AddKernel({ty, ty}, boolean(), - applicator::ScalarBinaryEqualTypes<BooleanType, InType, Op>::Exec)); -} - +// Implement Less, LessEqual by flipping arguments to Greater, GreaterEqual + +template <typename Op> +void AddIntegerCompare(const std::shared_ptr<DataType>& ty, ScalarFunction* func) { + auto exec = + GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(*ty); + DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec))); +} + +template <typename InType, typename Op> +void AddGenericCompare(const std::shared_ptr<DataType>& ty, ScalarFunction* func) { + DCHECK_OK( + func->AddKernel({ty, ty}, boolean(), + applicator::ScalarBinaryEqualTypes<BooleanType, InType, Op>::Exec)); +} + struct CompareFunction : ScalarFunction { using ScalarFunction::ScalarFunction; @@ -201,79 +201,79 @@ struct VarArgsCompareFunction : ScalarFunction { } }; -template <typename Op> +template <typename Op> std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, const FunctionDoc* doc) { auto func = std::make_shared<CompareFunction>(name, Arity::Binary(), doc); - - DCHECK_OK(func->AddKernel( - {boolean(), boolean()}, boolean(), - applicator::ScalarBinary<BooleanType, BooleanType, BooleanType, Op>::Exec)); - - for (const std::shared_ptr<DataType>& ty : IntTypes()) { - AddIntegerCompare<Op>(ty, func.get()); - } - AddIntegerCompare<Op>(date32(), func.get()); - AddIntegerCompare<Op>(date64(), func.get()); - - AddGenericCompare<FloatType, Op>(float32(), func.get()); - AddGenericCompare<DoubleType, Op>(float64(), func.get()); - - // Add timestamp kernels - for (auto unit : AllTimeUnits()) { - InputType in_type(match::TimestampTypeUnit(unit)); - auto exec = - GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>( - int64()); - DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec))); - } - - // Duration - for (auto unit : AllTimeUnits()) { - InputType in_type(match::DurationTypeUnit(unit)); - auto exec = - GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>( - int64()); - DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec))); - } - - // Time32 and Time64 - for (auto unit : {TimeUnit::SECOND, TimeUnit::MILLI}) { - InputType in_type(match::Time32TypeUnit(unit)); - auto exec = - GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>( - int32()); - DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec))); - } - for (auto unit : {TimeUnit::MICRO, TimeUnit::NANO}) { - InputType in_type(match::Time64TypeUnit(unit)); - auto exec = - GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>( - int64()); - DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec))); - } - - for (const std::shared_ptr<DataType>& ty : BaseBinaryTypes()) { - auto exec = - GenerateVarBinaryBase<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(*ty); - DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec))); - } - - return func; -} - -std::shared_ptr<ScalarFunction> MakeFlippedFunction(std::string name, + + DCHECK_OK(func->AddKernel( + {boolean(), boolean()}, boolean(), + applicator::ScalarBinary<BooleanType, BooleanType, BooleanType, Op>::Exec)); + + for (const std::shared_ptr<DataType>& ty : IntTypes()) { + AddIntegerCompare<Op>(ty, func.get()); + } + AddIntegerCompare<Op>(date32(), func.get()); + AddIntegerCompare<Op>(date64(), func.get()); + + AddGenericCompare<FloatType, Op>(float32(), func.get()); + AddGenericCompare<DoubleType, Op>(float64(), func.get()); + + // Add timestamp kernels + for (auto unit : AllTimeUnits()) { + InputType in_type(match::TimestampTypeUnit(unit)); + auto exec = + GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>( + int64()); + DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec))); + } + + // Duration + for (auto unit : AllTimeUnits()) { + InputType in_type(match::DurationTypeUnit(unit)); + auto exec = + GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>( + int64()); + DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec))); + } + + // Time32 and Time64 + for (auto unit : {TimeUnit::SECOND, TimeUnit::MILLI}) { + InputType in_type(match::Time32TypeUnit(unit)); + auto exec = + GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>( + int32()); + DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec))); + } + for (auto unit : {TimeUnit::MICRO, TimeUnit::NANO}) { + InputType in_type(match::Time64TypeUnit(unit)); + auto exec = + GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>( + int64()); + DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec))); + } + + for (const std::shared_ptr<DataType>& ty : BaseBinaryTypes()) { + auto exec = + GenerateVarBinaryBase<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(*ty); + DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec))); + } + + return func; +} + +std::shared_ptr<ScalarFunction> MakeFlippedFunction(std::string name, const ScalarFunction& func, const FunctionDoc* doc) { auto flipped_func = std::make_shared<CompareFunction>(name, Arity::Binary(), doc); - for (const ScalarKernel* kernel : func.kernels()) { - ScalarKernel flipped_kernel = *kernel; - flipped_kernel.exec = MakeFlippedBinaryExec(kernel->exec); - DCHECK_OK(flipped_func->AddKernel(std::move(flipped_kernel))); - } - return flipped_func; -} - + for (const ScalarKernel* kernel : func.kernels()) { + ScalarKernel flipped_kernel = *kernel; + flipped_kernel.exec = MakeFlippedBinaryExec(kernel->exec); + DCHECK_OK(flipped_func->AddKernel(std::move(flipped_kernel))); + } + return flipped_func; +} + using MinMaxState = OptionsWrapper<ElementWiseAggregateOptions>; // Implement a variadic scalar min/max kernel. @@ -489,23 +489,23 @@ const FunctionDoc max_element_wise_doc{ "NaN will be taken over null, but not over any valid float."), {"*args"}, "ElementWiseAggregateOptions"}; -} // namespace - -void RegisterScalarComparison(FunctionRegistry* registry) { +} // namespace + +void RegisterScalarComparison(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunction(MakeCompareFunction<Equal>("equal", &equal_doc))); DCHECK_OK( registry->AddFunction(MakeCompareFunction<NotEqual>("not_equal", ¬_equal_doc))); - + auto greater = MakeCompareFunction<Greater>("greater", &greater_doc); auto greater_equal = MakeCompareFunction<GreaterEqual>("greater_equal", &greater_equal_doc); - + auto less = MakeFlippedFunction("less", *greater, &less_doc); auto less_equal = MakeFlippedFunction("less_equal", *greater_equal, &less_equal_doc); - DCHECK_OK(registry->AddFunction(std::move(less))); - DCHECK_OK(registry->AddFunction(std::move(less_equal))); - DCHECK_OK(registry->AddFunction(std::move(greater))); - DCHECK_OK(registry->AddFunction(std::move(greater_equal))); + DCHECK_OK(registry->AddFunction(std::move(less))); + DCHECK_OK(registry->AddFunction(std::move(less_equal))); + DCHECK_OK(registry->AddFunction(std::move(greater))); + DCHECK_OK(registry->AddFunction(std::move(greater_equal))); // ---------------------------------------------------------------------- // Variadic element-wise functions @@ -517,8 +517,8 @@ void RegisterScalarComparison(FunctionRegistry* registry) { auto max_element_wise = MakeScalarMinMax<Maximum>("max_element_wise", &max_element_wise_doc); DCHECK_OK(registry->AddFunction(std::move(max_element_wise))); -} - -} // namespace internal -} // namespace compute -} // namespace arrow +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_fill_null.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_fill_null.cc index cf22b0de3d..e189c294be 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_fill_null.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_fill_null.cc @@ -1,157 +1,157 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include <algorithm> -#include <cstring> - -#include "arrow/compute/kernels/common.h" -#include "arrow/scalar.h" -#include "arrow/util/bit_block_counter.h" -#include "arrow/util/bit_util.h" -#include "arrow/util/bitmap_ops.h" - -namespace arrow { - -using internal::BitBlockCount; -using internal::BitBlockCounter; - -namespace compute { -namespace internal { - -namespace { - -template <typename Type, typename Enable = void> -struct FillNullFunctor {}; - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <algorithm> +#include <cstring> + +#include "arrow/compute/kernels/common.h" +#include "arrow/scalar.h" +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" + +namespace arrow { + +using internal::BitBlockCount; +using internal::BitBlockCounter; + +namespace compute { +namespace internal { + +namespace { + +template <typename Type, typename Enable = void> +struct FillNullFunctor {}; + // Numeric inputs -template <typename Type> -struct FillNullFunctor<Type, enable_if_t<is_number_type<Type>::value>> { - using T = typename TypeTraits<Type>::CType; - +template <typename Type> +struct FillNullFunctor<Type, enable_if_t<is_number_type<Type>::value>> { + using T = typename TypeTraits<Type>::CType; + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const ArrayData& data = *batch[0].array(); - const Scalar& fill_value = *batch[1].scalar(); - ArrayData* output = out->mutable_array(); - - // Ensure the kernel is configured properly to have no validity bitmap / - // null count 0 unless we explicitly propagate it below. - DCHECK(output->buffers[0] == nullptr); - - T value = UnboxScalar<Type>::Unbox(fill_value); - if (data.MayHaveNulls() != 0 && fill_value.is_valid) { + const ArrayData& data = *batch[0].array(); + const Scalar& fill_value = *batch[1].scalar(); + ArrayData* output = out->mutable_array(); + + // Ensure the kernel is configured properly to have no validity bitmap / + // null count 0 unless we explicitly propagate it below. + DCHECK(output->buffers[0] == nullptr); + + T value = UnboxScalar<Type>::Unbox(fill_value); + if (data.MayHaveNulls() != 0 && fill_value.is_valid) { ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> out_buf, ctx->Allocate(data.length * sizeof(T))); - - const uint8_t* is_valid = data.buffers[0]->data(); - const T* in_values = data.GetValues<T>(1); - T* out_values = reinterpret_cast<T*>(out_buf->mutable_data()); - int64_t offset = data.offset; - BitBlockCounter bit_counter(is_valid, data.offset, data.length); - while (offset < data.offset + data.length) { - BitBlockCount block = bit_counter.NextWord(); - if (block.AllSet()) { - // Block all not null - std::memcpy(out_values, in_values, block.length * sizeof(T)); - } else if (block.NoneSet()) { - // Block all null - std::fill(out_values, out_values + block.length, value); - } else { - for (int64_t i = 0; i < block.length; ++i) { - out_values[i] = BitUtil::GetBit(is_valid, offset + i) ? in_values[i] : value; - } - } - offset += block.length; - out_values += block.length; - in_values += block.length; - } - output->buffers[1] = out_buf; + + const uint8_t* is_valid = data.buffers[0]->data(); + const T* in_values = data.GetValues<T>(1); + T* out_values = reinterpret_cast<T*>(out_buf->mutable_data()); + int64_t offset = data.offset; + BitBlockCounter bit_counter(is_valid, data.offset, data.length); + while (offset < data.offset + data.length) { + BitBlockCount block = bit_counter.NextWord(); + if (block.AllSet()) { + // Block all not null + std::memcpy(out_values, in_values, block.length * sizeof(T)); + } else if (block.NoneSet()) { + // Block all null + std::fill(out_values, out_values + block.length, value); + } else { + for (int64_t i = 0; i < block.length; ++i) { + out_values[i] = BitUtil::GetBit(is_valid, offset + i) ? in_values[i] : value; + } + } + offset += block.length; + out_values += block.length; + in_values += block.length; + } + output->buffers[1] = out_buf; output->null_count = 0; - } else { - *output = data; - } + } else { + *output = data; + } return Status::OK(); - } -}; - + } +}; + // Boolean input -template <typename Type> -struct FillNullFunctor<Type, enable_if_t<is_boolean_type<Type>::value>> { +template <typename Type> +struct FillNullFunctor<Type, enable_if_t<is_boolean_type<Type>::value>> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const ArrayData& data = *batch[0].array(); - const Scalar& fill_value = *batch[1].scalar(); - ArrayData* output = out->mutable_array(); - - bool value = UnboxScalar<BooleanType>::Unbox(fill_value); - if (data.MayHaveNulls() != 0 && fill_value.is_valid) { + const ArrayData& data = *batch[0].array(); + const Scalar& fill_value = *batch[1].scalar(); + ArrayData* output = out->mutable_array(); + + bool value = UnboxScalar<BooleanType>::Unbox(fill_value); + if (data.MayHaveNulls() != 0 && fill_value.is_valid) { ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> out_buf, ctx->AllocateBitmap(data.length)); - - const uint8_t* is_valid = data.buffers[0]->data(); - const uint8_t* data_bitmap = data.buffers[1]->data(); - uint8_t* out_bitmap = out_buf->mutable_data(); - - int64_t data_offset = data.offset; - BitBlockCounter bit_counter(is_valid, data.offset, data.length); - - int64_t out_offset = 0; - while (out_offset < data.length) { - BitBlockCount block = bit_counter.NextWord(); - if (block.AllSet()) { - // Block all not null - ::arrow::internal::CopyBitmap(data_bitmap, data_offset, block.length, - out_bitmap, out_offset); - } else if (block.NoneSet()) { - // Block all null - BitUtil::SetBitsTo(out_bitmap, out_offset, block.length, value); - } else { - for (int64_t i = 0; i < block.length; ++i) { - BitUtil::SetBitTo(out_bitmap, out_offset + i, - BitUtil::GetBit(is_valid, data_offset + i) - ? BitUtil::GetBit(data_bitmap, data_offset + i) - : value); - } - } - data_offset += block.length; - out_offset += block.length; - } - output->buffers[1] = out_buf; + + const uint8_t* is_valid = data.buffers[0]->data(); + const uint8_t* data_bitmap = data.buffers[1]->data(); + uint8_t* out_bitmap = out_buf->mutable_data(); + + int64_t data_offset = data.offset; + BitBlockCounter bit_counter(is_valid, data.offset, data.length); + + int64_t out_offset = 0; + while (out_offset < data.length) { + BitBlockCount block = bit_counter.NextWord(); + if (block.AllSet()) { + // Block all not null + ::arrow::internal::CopyBitmap(data_bitmap, data_offset, block.length, + out_bitmap, out_offset); + } else if (block.NoneSet()) { + // Block all null + BitUtil::SetBitsTo(out_bitmap, out_offset, block.length, value); + } else { + for (int64_t i = 0; i < block.length; ++i) { + BitUtil::SetBitTo(out_bitmap, out_offset + i, + BitUtil::GetBit(is_valid, data_offset + i) + ? BitUtil::GetBit(data_bitmap, data_offset + i) + : value); + } + } + data_offset += block.length; + out_offset += block.length; + } + output->buffers[1] = out_buf; output->null_count = 0; - } else { - *output = data; - } + } else { + *output = data; + } return Status::OK(); - } -}; - + } +}; + // Null input -template <typename Type> -struct FillNullFunctor<Type, enable_if_t<is_null_type<Type>::value>> { +template <typename Type> +struct FillNullFunctor<Type, enable_if_t<is_null_type<Type>::value>> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // Nothing preallocated, so we assign into the output - *out->mutable_array() = *batch[0].array(); + // Nothing preallocated, so we assign into the output + *out->mutable_array() = *batch[0].array(); return Status::OK(); - } -}; - + } +}; + // Binary-like input template <typename Type> @@ -194,20 +194,20 @@ struct FillNullFunctor<Type, enable_if_t<is_base_binary_type<Type>::value>> { } }; -void AddBasicFillNullKernels(ScalarKernel kernel, ScalarFunction* func) { - auto AddKernels = [&](const std::vector<std::shared_ptr<DataType>>& types) { - for (const std::shared_ptr<DataType>& ty : types) { - kernel.signature = - KernelSignature::Make({InputType::Array(ty), InputType::Scalar(ty)}, ty); - kernel.exec = GenerateTypeAgnosticPrimitive<FillNullFunctor>(*ty); - DCHECK_OK(func->AddKernel(kernel)); - } - }; - AddKernels(NumericTypes()); - AddKernels(TemporalTypes()); - AddKernels({boolean(), null()}); -} - +void AddBasicFillNullKernels(ScalarKernel kernel, ScalarFunction* func) { + auto AddKernels = [&](const std::vector<std::shared_ptr<DataType>>& types) { + for (const std::shared_ptr<DataType>& ty : types) { + kernel.signature = + KernelSignature::Make({InputType::Array(ty), InputType::Scalar(ty)}, ty); + kernel.exec = GenerateTypeAgnosticPrimitive<FillNullFunctor>(*ty); + DCHECK_OK(func->AddKernel(kernel)); + } + }; + AddKernels(NumericTypes()); + AddKernels(TemporalTypes()); + AddKernels({boolean(), null()}); +} + void AddBinaryFillNullKernels(ScalarKernel kernel, ScalarFunction* func) { for (const std::shared_ptr<DataType>& ty : BaseBinaryTypes()) { kernel.signature = @@ -224,21 +224,21 @@ const FunctionDoc fill_null_doc{ "Each null value in `values` is replaced with `fill_value`."), {"values", "fill_value"}}; -} // namespace - -void RegisterScalarFillNull(FunctionRegistry* registry) { - { - ScalarKernel fill_null_base; - fill_null_base.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; - fill_null_base.mem_allocation = MemAllocation::NO_PREALLOCATE; +} // namespace + +void RegisterScalarFillNull(FunctionRegistry* registry) { + { + ScalarKernel fill_null_base; + fill_null_base.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + fill_null_base.mem_allocation = MemAllocation::NO_PREALLOCATE; auto fill_null = std::make_shared<ScalarFunction>("fill_null", Arity::Binary(), &fill_null_doc); - AddBasicFillNullKernels(fill_null_base, fill_null.get()); + AddBasicFillNullKernels(fill_null_base, fill_null.get()); AddBinaryFillNullKernels(fill_null_base, fill_null.get()); - DCHECK_OK(registry->AddFunction(fill_null)); - } -} - -} // namespace internal -} // namespace compute -} // namespace arrow + DCHECK_OK(registry->AddFunction(fill_null)); + } +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_nested.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_nested.cc index e9f0696c8f..cae2df4a09 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_nested.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_nested.cc @@ -1,60 +1,60 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Vector kernels involving nested types - -#include "arrow/array/array_base.h" +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Vector kernels involving nested types + +#include "arrow/array/array_base.h" #include "arrow/compute/api_scalar.h" -#include "arrow/compute/kernels/common.h" -#include "arrow/result.h" -#include "arrow/util/bit_block_counter.h" - -namespace arrow { -namespace compute { -namespace internal { -namespace { - -template <typename Type, typename offset_type = typename Type::offset_type> +#include "arrow/compute/kernels/common.h" +#include "arrow/result.h" +#include "arrow/util/bit_block_counter.h" + +namespace arrow { +namespace compute { +namespace internal { +namespace { + +template <typename Type, typename offset_type = typename Type::offset_type> Status ListValueLength(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - using ScalarType = typename TypeTraits<Type>::ScalarType; - using OffsetScalarType = typename TypeTraits<Type>::OffsetScalarType; - - if (batch[0].kind() == Datum::ARRAY) { - typename TypeTraits<Type>::ArrayType list(batch[0].array()); - ArrayData* out_arr = out->mutable_array(); - auto out_values = out_arr->GetMutableValues<offset_type>(1); - const offset_type* offsets = list.raw_value_offsets(); - ::arrow::internal::VisitBitBlocksVoid( - list.data()->buffers[0], list.offset(), list.length(), - [&](int64_t position) { - *out_values++ = offsets[position + 1] - offsets[position]; - }, - [&]() { *out_values++ = 0; }); - } else { - const auto& arg0 = batch[0].scalar_as<ScalarType>(); - if (arg0.is_valid) { - checked_cast<OffsetScalarType*>(out->scalar().get())->value = - static_cast<offset_type>(arg0.value->length()); - } - } + using ScalarType = typename TypeTraits<Type>::ScalarType; + using OffsetScalarType = typename TypeTraits<Type>::OffsetScalarType; + + if (batch[0].kind() == Datum::ARRAY) { + typename TypeTraits<Type>::ArrayType list(batch[0].array()); + ArrayData* out_arr = out->mutable_array(); + auto out_values = out_arr->GetMutableValues<offset_type>(1); + const offset_type* offsets = list.raw_value_offsets(); + ::arrow::internal::VisitBitBlocksVoid( + list.data()->buffers[0], list.offset(), list.length(), + [&](int64_t position) { + *out_values++ = offsets[position + 1] - offsets[position]; + }, + [&]() { *out_values++ = 0; }); + } else { + const auto& arg0 = batch[0].scalar_as<ScalarType>(); + if (arg0.is_valid) { + checked_cast<OffsetScalarType*>(out->scalar().get())->value = + static_cast<offset_type>(arg0.value->length()); + } + } return Status::OK(); -} - +} + const FunctionDoc list_value_length_doc{ "Compute list lengths", ("`lists` must have a list-like type.\n" @@ -154,16 +154,16 @@ const FunctionDoc make_struct_doc{"Wrap Arrays into a StructArray", {"*args"}, "MakeStructOptions"}; -} // namespace - -void RegisterScalarNested(FunctionRegistry* registry) { +} // namespace + +void RegisterScalarNested(FunctionRegistry* registry) { auto list_value_length = std::make_shared<ScalarFunction>( "list_value_length", Arity::Unary(), &list_value_length_doc); - DCHECK_OK(list_value_length->AddKernel({InputType(Type::LIST)}, int32(), - ListValueLength<ListType>)); - DCHECK_OK(list_value_length->AddKernel({InputType(Type::LARGE_LIST)}, int64(), - ListValueLength<LargeListType>)); - DCHECK_OK(registry->AddFunction(std::move(list_value_length))); + DCHECK_OK(list_value_length->AddKernel({InputType(Type::LIST)}, int32(), + ListValueLength<ListType>)); + DCHECK_OK(list_value_length->AddKernel({InputType(Type::LARGE_LIST)}, int64(), + ListValueLength<LargeListType>)); + DCHECK_OK(registry->AddFunction(std::move(list_value_length))); static MakeStructOptions kDefaultMakeStructOptions; auto make_struct_function = std::make_shared<ScalarFunction>( @@ -176,8 +176,8 @@ void RegisterScalarNested(FunctionRegistry* registry) { kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; DCHECK_OK(make_struct_function->AddKernel(std::move(kernel))); DCHECK_OK(registry->AddFunction(std::move(make_struct_function))); -} - -} // namespace internal -} // namespace compute -} // namespace arrow +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc index 3e2e95e540..867d8d041f 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc @@ -1,45 +1,45 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/array/array_base.h" -#include "arrow/array/builder_primitive.h" -#include "arrow/compute/api_scalar.h" +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/array/array_base.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/compute/api_scalar.h" #include "arrow/compute/cast.h" -#include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/common.h" #include "arrow/compute/kernels/util_internal.h" -#include "arrow/util/bit_util.h" -#include "arrow/util/bitmap_writer.h" -#include "arrow/util/hashing.h" -#include "arrow/visitor_inline.h" - -namespace arrow { - -using internal::checked_cast; -using internal::HashTraits; - -namespace compute { -namespace internal { -namespace { - -template <typename Type> -struct SetLookupState : public KernelState { +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_writer.h" +#include "arrow/util/hashing.h" +#include "arrow/visitor_inline.h" + +namespace arrow { + +using internal::checked_cast; +using internal::HashTraits; + +namespace compute { +namespace internal { +namespace { + +template <typename Type> +struct SetLookupState : public KernelState { explicit SetLookupState(MemoryPool* pool) : lookup_table(pool, 0) {} - - Status Init(const SetLookupOptions& options) { + + Status Init(const SetLookupOptions& options) { if (options.value_set.kind() == Datum::ARRAY) { const ArrayData& value_set = *options.value_set.array(); memo_index_to_value_index.reserve(value_set.length); @@ -63,11 +63,11 @@ struct SetLookupState : public KernelState { Status AddArrayValueSet(const SetLookupOptions& options, const ArrayData& data, int64_t start_index = 0) { - using T = typename GetViewType<Type>::T; + using T = typename GetViewType<Type>::T; int32_t index = static_cast<int32_t>(start_index); - auto visit_valid = [&](T v) { + auto visit_valid = [&](T v) { const auto memo_size = static_cast<int32_t>(memo_index_to_value_index.size()); - int32_t unused_memo_index; + int32_t unused_memo_index; auto on_found = [&](int32_t memo_index) { DCHECK_LT(memo_index, memo_size); }; auto on_not_found = [&](int32_t memo_index) { DCHECK_EQ(memo_index, memo_size); @@ -77,8 +77,8 @@ struct SetLookupState : public KernelState { v, std::move(on_found), std::move(on_not_found), &unused_memo_index)); ++index; return Status::OK(); - }; - auto visit_null = [&]() { + }; + auto visit_null = [&]() { const auto memo_size = static_cast<int32_t>(memo_index_to_value_index.size()); auto on_found = [&](int32_t memo_index) { DCHECK_LT(memo_index, memo_size); }; auto on_not_found = [&](int32_t memo_index) { @@ -87,96 +87,96 @@ struct SetLookupState : public KernelState { }; lookup_table.GetOrInsertNull(std::move(on_found), std::move(on_not_found)); ++index; - return Status::OK(); - }; + return Status::OK(); + }; return VisitArrayDataInline<Type>(data, visit_valid, visit_null); - } - - using MemoTable = typename HashTraits<Type>::MemoTableType; - MemoTable lookup_table; + } + + using MemoTable = typename HashTraits<Type>::MemoTableType; + MemoTable lookup_table; // When there are duplicates in value_set, the MemoTable indices must // be mapped back to indices in the value_set. std::vector<int32_t> memo_index_to_value_index; int32_t null_index = -1; -}; - -template <> -struct SetLookupState<NullType> : public KernelState { - explicit SetLookupState(MemoryPool*) {} - - Status Init(const SetLookupOptions& options) { +}; + +template <> +struct SetLookupState<NullType> : public KernelState { + explicit SetLookupState(MemoryPool*) {} + + Status Init(const SetLookupOptions& options) { value_set_has_null = (options.value_set.length() > 0) && !options.skip_nulls; - return Status::OK(); - } - + return Status::OK(); + } + bool value_set_has_null; -}; - -// TODO: Put this concept somewhere reusable -template <int width> -struct UnsignedIntType; - -template <> -struct UnsignedIntType<1> { - using Type = UInt8Type; -}; - -template <> -struct UnsignedIntType<2> { - using Type = UInt16Type; -}; - -template <> -struct UnsignedIntType<4> { - using Type = UInt32Type; -}; - -template <> -struct UnsignedIntType<8> { - using Type = UInt64Type; -}; - -// Constructing the type requires a type parameter -struct InitStateVisitor { - KernelContext* ctx; +}; + +// TODO: Put this concept somewhere reusable +template <int width> +struct UnsignedIntType; + +template <> +struct UnsignedIntType<1> { + using Type = UInt8Type; +}; + +template <> +struct UnsignedIntType<2> { + using Type = UInt16Type; +}; + +template <> +struct UnsignedIntType<4> { + using Type = UInt32Type; +}; + +template <> +struct UnsignedIntType<8> { + using Type = UInt64Type; +}; + +// Constructing the type requires a type parameter +struct InitStateVisitor { + KernelContext* ctx; SetLookupOptions options; const std::shared_ptr<DataType>& arg_type; - std::unique_ptr<KernelState> result; - + std::unique_ptr<KernelState> result; + InitStateVisitor(KernelContext* ctx, const KernelInitArgs& args) : ctx(ctx), options(*checked_cast<const SetLookupOptions*>(args.options)), arg_type(args.inputs[0].type) {} - - template <typename Type> - Status Init() { - using StateType = SetLookupState<Type>; - result.reset(new StateType(ctx->exec_context()->memory_pool())); + + template <typename Type> + Status Init() { + using StateType = SetLookupState<Type>; + result.reset(new StateType(ctx->exec_context()->memory_pool())); return static_cast<StateType*>(result.get())->Init(options); - } - - Status Visit(const DataType&) { return Init<NullType>(); } - - template <typename Type> - enable_if_boolean<Type, Status> Visit(const Type&) { - return Init<BooleanType>(); - } - - template <typename Type> - enable_if_t<has_c_type<Type>::value && !is_boolean_type<Type>::value, Status> Visit( - const Type&) { - return Init<typename UnsignedIntType<sizeof(typename Type::c_type)>::Type>(); - } - - template <typename Type> - enable_if_base_binary<Type, Status> Visit(const Type&) { - return Init<typename Type::PhysicalType>(); - } - - // Handle Decimal128Type, FixedSizeBinaryType - Status Visit(const FixedSizeBinaryType& type) { return Init<FixedSizeBinaryType>(); } - + } + + Status Visit(const DataType&) { return Init<NullType>(); } + + template <typename Type> + enable_if_boolean<Type, Status> Visit(const Type&) { + return Init<BooleanType>(); + } + + template <typename Type> + enable_if_t<has_c_type<Type>::value && !is_boolean_type<Type>::value, Status> Visit( + const Type&) { + return Init<typename UnsignedIntType<sizeof(typename Type::c_type)>::Type>(); + } + + template <typename Type> + enable_if_base_binary<Type, Status> Visit(const Type&) { + return Init<typename Type::PhysicalType>(); + } + + // Handle Decimal128Type, FixedSizeBinaryType + Status Visit(const FixedSizeBinaryType& type) { return Init<FixedSizeBinaryType>(); } + Result<std::unique_ptr<KernelState>> GetResult() { if (!options.value_set.type()->Equals(arg_type)) { ARROW_ASSIGN_OR_RAISE( @@ -186,9 +186,9 @@ struct InitStateVisitor { RETURN_NOT_OK(VisitTypeInline(*arg_type, this)); return std::move(result); - } -}; - + } +}; + Result<std::unique_ptr<KernelState>> InitSetLookup(KernelContext* ctx, const KernelInitArgs& args) { if (args.options == nullptr) { @@ -197,246 +197,246 @@ Result<std::unique_ptr<KernelState>> InitSetLookup(KernelContext* ctx, } return InitStateVisitor{ctx, args}.GetResult(); -} - -struct IndexInVisitor { - KernelContext* ctx; - const ArrayData& data; - Datum* out; - Int32Builder builder; - - IndexInVisitor(KernelContext* ctx, const ArrayData& data, Datum* out) - : ctx(ctx), data(data), out(out), builder(ctx->exec_context()->memory_pool()) {} - +} + +struct IndexInVisitor { + KernelContext* ctx; + const ArrayData& data; + Datum* out; + Int32Builder builder; + + IndexInVisitor(KernelContext* ctx, const ArrayData& data, Datum* out) + : ctx(ctx), data(data), out(out), builder(ctx->exec_context()->memory_pool()) {} + Status Visit(const DataType& type) { DCHECK_EQ(type.id(), Type::NA); - const auto& state = checked_cast<const SetLookupState<NullType>&>(*ctx->state()); - if (data.length != 0) { + const auto& state = checked_cast<const SetLookupState<NullType>&>(*ctx->state()); + if (data.length != 0) { // skip_nulls is honored for consistency with other types if (state.value_set_has_null) { - RETURN_NOT_OK(this->builder.Reserve(data.length)); - for (int64_t i = 0; i < data.length; ++i) { - this->builder.UnsafeAppend(0); - } + RETURN_NOT_OK(this->builder.Reserve(data.length)); + for (int64_t i = 0; i < data.length; ++i) { + this->builder.UnsafeAppend(0); + } } else { RETURN_NOT_OK(this->builder.AppendNulls(data.length)); - } - } - return Status::OK(); - } - - template <typename Type> - Status ProcessIndexIn() { - using T = typename GetViewType<Type>::T; - - const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state()); - - RETURN_NOT_OK(this->builder.Reserve(data.length)); - VisitArrayDataInline<Type>( - data, - [&](T v) { - int32_t index = state.lookup_table.Get(v); - if (index != -1) { - // matching needle; output index from value_set + } + } + return Status::OK(); + } + + template <typename Type> + Status ProcessIndexIn() { + using T = typename GetViewType<Type>::T; + + const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state()); + + RETURN_NOT_OK(this->builder.Reserve(data.length)); + VisitArrayDataInline<Type>( + data, + [&](T v) { + int32_t index = state.lookup_table.Get(v); + if (index != -1) { + // matching needle; output index from value_set this->builder.UnsafeAppend(state.memo_index_to_value_index[index]); - } else { - // no matching needle; output null - this->builder.UnsafeAppendNull(); - } - }, - [&]() { + } else { + // no matching needle; output null + this->builder.UnsafeAppendNull(); + } + }, + [&]() { if (state.null_index != -1) { - // value_set included null + // value_set included null this->builder.UnsafeAppend(state.null_index); - } else { - // value_set does not include null; output null - this->builder.UnsafeAppendNull(); - } - }); - return Status::OK(); - } - - template <typename Type> - enable_if_boolean<Type, Status> Visit(const Type&) { - return ProcessIndexIn<BooleanType>(); - } - - template <typename Type> - enable_if_t<has_c_type<Type>::value && !is_boolean_type<Type>::value, Status> Visit( - const Type&) { - return ProcessIndexIn< - typename UnsignedIntType<sizeof(typename Type::c_type)>::Type>(); - } - - template <typename Type> - enable_if_base_binary<Type, Status> Visit(const Type&) { - return ProcessIndexIn<typename Type::PhysicalType>(); - } - - // Handle Decimal128Type, FixedSizeBinaryType - Status Visit(const FixedSizeBinaryType& type) { - return ProcessIndexIn<FixedSizeBinaryType>(); - } - - Status Execute() { - Status s = VisitTypeInline(*data.type, this); - if (!s.ok()) { - return s; - } - std::shared_ptr<ArrayData> out_data; - RETURN_NOT_OK(this->builder.FinishInternal(&out_data)); - out->value = std::move(out_data); - return Status::OK(); - } -}; - + } else { + // value_set does not include null; output null + this->builder.UnsafeAppendNull(); + } + }); + return Status::OK(); + } + + template <typename Type> + enable_if_boolean<Type, Status> Visit(const Type&) { + return ProcessIndexIn<BooleanType>(); + } + + template <typename Type> + enable_if_t<has_c_type<Type>::value && !is_boolean_type<Type>::value, Status> Visit( + const Type&) { + return ProcessIndexIn< + typename UnsignedIntType<sizeof(typename Type::c_type)>::Type>(); + } + + template <typename Type> + enable_if_base_binary<Type, Status> Visit(const Type&) { + return ProcessIndexIn<typename Type::PhysicalType>(); + } + + // Handle Decimal128Type, FixedSizeBinaryType + Status Visit(const FixedSizeBinaryType& type) { + return ProcessIndexIn<FixedSizeBinaryType>(); + } + + Status Execute() { + Status s = VisitTypeInline(*data.type, this); + if (!s.ok()) { + return s; + } + std::shared_ptr<ArrayData> out_data; + RETURN_NOT_OK(this->builder.FinishInternal(&out_data)); + out->value = std::move(out_data); + return Status::OK(); + } +}; + Status ExecIndexIn(KernelContext* ctx, const ExecBatch& batch, Datum* out) { return IndexInVisitor(ctx, *batch[0].array(), out).Execute(); -} - -// ---------------------------------------------------------------------- - +} + +// ---------------------------------------------------------------------- + // IsIn writes the results into a preallocated boolean data bitmap -struct IsInVisitor { - KernelContext* ctx; - const ArrayData& data; - Datum* out; - - IsInVisitor(KernelContext* ctx, const ArrayData& data, Datum* out) - : ctx(ctx), data(data), out(out) {} - +struct IsInVisitor { + KernelContext* ctx; + const ArrayData& data; + Datum* out; + + IsInVisitor(KernelContext* ctx, const ArrayData& data, Datum* out) + : ctx(ctx), data(data), out(out) {} + Status Visit(const DataType& type) { DCHECK_EQ(type.id(), Type::NA); - const auto& state = checked_cast<const SetLookupState<NullType>&>(*ctx->state()); - ArrayData* output = out->mutable_array(); + const auto& state = checked_cast<const SetLookupState<NullType>&>(*ctx->state()); + ArrayData* output = out->mutable_array(); // skip_nulls is honored for consistency with other types BitUtil::SetBitsTo(output->buffers[1]->mutable_data(), output->offset, output->length, state.value_set_has_null); - return Status::OK(); - } - - template <typename Type> - Status ProcessIsIn() { - using T = typename GetViewType<Type>::T; - const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state()); - ArrayData* output = out->mutable_array(); - - FirstTimeBitmapWriter writer(output->buffers[1]->mutable_data(), output->offset, - output->length); - - VisitArrayDataInline<Type>( - this->data, - [&](T v) { - if (state.lookup_table.Get(v) != -1) { - writer.Set(); - } else { - writer.Clear(); - } - writer.Next(); - }, - [&]() { + return Status::OK(); + } + + template <typename Type> + Status ProcessIsIn() { + using T = typename GetViewType<Type>::T; + const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state()); + ArrayData* output = out->mutable_array(); + + FirstTimeBitmapWriter writer(output->buffers[1]->mutable_data(), output->offset, + output->length); + + VisitArrayDataInline<Type>( + this->data, + [&](T v) { + if (state.lookup_table.Get(v) != -1) { + writer.Set(); + } else { + writer.Clear(); + } + writer.Next(); + }, + [&]() { if (state.null_index != -1) { writer.Set(); } else { writer.Clear(); } - writer.Next(); - }); - writer.Finish(); - return Status::OK(); - } - - template <typename Type> - enable_if_boolean<Type, Status> Visit(const Type&) { - return ProcessIsIn<BooleanType>(); - } - - template <typename Type> - enable_if_t<has_c_type<Type>::value && !is_boolean_type<Type>::value, Status> Visit( - const Type&) { - return ProcessIsIn<typename UnsignedIntType<sizeof(typename Type::c_type)>::Type>(); - } - - template <typename Type> - enable_if_base_binary<Type, Status> Visit(const Type&) { - return ProcessIsIn<typename Type::PhysicalType>(); - } - - // Handle Decimal128Type, FixedSizeBinaryType - Status Visit(const FixedSizeBinaryType& type) { - return ProcessIsIn<FixedSizeBinaryType>(); - } - - Status Execute() { return VisitTypeInline(*data.type, this); } -}; - + writer.Next(); + }); + writer.Finish(); + return Status::OK(); + } + + template <typename Type> + enable_if_boolean<Type, Status> Visit(const Type&) { + return ProcessIsIn<BooleanType>(); + } + + template <typename Type> + enable_if_t<has_c_type<Type>::value && !is_boolean_type<Type>::value, Status> Visit( + const Type&) { + return ProcessIsIn<typename UnsignedIntType<sizeof(typename Type::c_type)>::Type>(); + } + + template <typename Type> + enable_if_base_binary<Type, Status> Visit(const Type&) { + return ProcessIsIn<typename Type::PhysicalType>(); + } + + // Handle Decimal128Type, FixedSizeBinaryType + Status Visit(const FixedSizeBinaryType& type) { + return ProcessIsIn<FixedSizeBinaryType>(); + } + + Status Execute() { return VisitTypeInline(*data.type, this); } +}; + Status ExecIsIn(KernelContext* ctx, const ExecBatch& batch, Datum* out) { return IsInVisitor(ctx, *batch[0].array(), out).Execute(); -} - -// Unary set lookup kernels available for the following input types -// -// * Null type -// * Boolean -// * Numeric -// * Simple temporal types (date, time, timestamp) -// * Base binary types -// * Decimal - -void AddBasicSetLookupKernels(ScalarKernel kernel, - const std::shared_ptr<DataType>& out_ty, - ScalarFunction* func) { - auto AddKernels = [&](const std::vector<std::shared_ptr<DataType>>& types) { - for (const std::shared_ptr<DataType>& ty : types) { - kernel.signature = KernelSignature::Make({ty}, out_ty); - DCHECK_OK(func->AddKernel(kernel)); - } - }; - - AddKernels(BaseBinaryTypes()); - AddKernels(NumericTypes()); - AddKernels(TemporalTypes()); - - std::vector<Type::type> other_types = {Type::BOOL, Type::DECIMAL, - Type::FIXED_SIZE_BINARY}; - for (auto ty : other_types) { - kernel.signature = KernelSignature::Make({InputType::Array(ty)}, out_ty); - DCHECK_OK(func->AddKernel(kernel)); - } -} - -// Enables calling is_in with CallFunction as though it were binary. -class IsInMetaBinary : public MetaFunction { - public: +} + +// Unary set lookup kernels available for the following input types +// +// * Null type +// * Boolean +// * Numeric +// * Simple temporal types (date, time, timestamp) +// * Base binary types +// * Decimal + +void AddBasicSetLookupKernels(ScalarKernel kernel, + const std::shared_ptr<DataType>& out_ty, + ScalarFunction* func) { + auto AddKernels = [&](const std::vector<std::shared_ptr<DataType>>& types) { + for (const std::shared_ptr<DataType>& ty : types) { + kernel.signature = KernelSignature::Make({ty}, out_ty); + DCHECK_OK(func->AddKernel(kernel)); + } + }; + + AddKernels(BaseBinaryTypes()); + AddKernels(NumericTypes()); + AddKernels(TemporalTypes()); + + std::vector<Type::type> other_types = {Type::BOOL, Type::DECIMAL, + Type::FIXED_SIZE_BINARY}; + for (auto ty : other_types) { + kernel.signature = KernelSignature::Make({InputType::Array(ty)}, out_ty); + DCHECK_OK(func->AddKernel(kernel)); + } +} + +// Enables calling is_in with CallFunction as though it were binary. +class IsInMetaBinary : public MetaFunction { + public: IsInMetaBinary() : MetaFunction("is_in_meta_binary", Arity::Binary(), /*doc=*/nullptr) {} - - Result<Datum> ExecuteImpl(const std::vector<Datum>& args, - const FunctionOptions* options, - ExecContext* ctx) const override { - if (options != nullptr) { - return Status::Invalid("Unexpected options for 'is_in_meta_binary' function"); - } - return IsIn(args[0], args[1], ctx); - } -}; - -// Enables calling index_in with CallFunction as though it were binary. -class IndexInMetaBinary : public MetaFunction { - public: + + Result<Datum> ExecuteImpl(const std::vector<Datum>& args, + const FunctionOptions* options, + ExecContext* ctx) const override { + if (options != nullptr) { + return Status::Invalid("Unexpected options for 'is_in_meta_binary' function"); + } + return IsIn(args[0], args[1], ctx); + } +}; + +// Enables calling index_in with CallFunction as though it were binary. +class IndexInMetaBinary : public MetaFunction { + public: IndexInMetaBinary() : MetaFunction("index_in_meta_binary", Arity::Binary(), /*doc=*/nullptr) {} - - Result<Datum> ExecuteImpl(const std::vector<Datum>& args, - const FunctionOptions* options, - ExecContext* ctx) const override { - if (options != nullptr) { - return Status::Invalid("Unexpected options for 'index_in_meta_binary' function"); - } - return IndexIn(args[0], args[1], ctx); - } -}; - + + Result<Datum> ExecuteImpl(const std::vector<Datum>& args, + const FunctionOptions* options, + ExecContext* ctx) const override { + if (options != nullptr) { + return Status::Invalid("Unexpected options for 'index_in_meta_binary' function"); + } + return IndexIn(args[0], args[1], ctx); + } +}; + struct SetLookupFunction : ScalarFunction { using ScalarFunction::ScalarFunction; @@ -466,48 +466,48 @@ const FunctionDoc index_in_doc{ {"values"}, "SetLookupOptions"}; -} // namespace - -void RegisterScalarSetLookup(FunctionRegistry* registry) { +} // namespace + +void RegisterScalarSetLookup(FunctionRegistry* registry) { // IsIn writes its boolean output into preallocated memory - { - ScalarKernel isin_base; - isin_base.init = InitSetLookup; + { + ScalarKernel isin_base; + isin_base.init = InitSetLookup; isin_base.exec = TrivialScalarUnaryAsArraysExec(ExecIsIn, NullHandling::OUTPUT_NOT_NULL); isin_base.null_handling = NullHandling::OUTPUT_NOT_NULL; auto is_in = std::make_shared<SetLookupFunction>("is_in", Arity::Unary(), &is_in_doc); - - AddBasicSetLookupKernels(isin_base, /*output_type=*/boolean(), is_in.get()); - - isin_base.signature = KernelSignature::Make({null()}, boolean()); - DCHECK_OK(is_in->AddKernel(isin_base)); - DCHECK_OK(registry->AddFunction(is_in)); - - DCHECK_OK(registry->AddFunction(std::make_shared<IsInMetaBinary>())); - } - - // IndexIn uses Int32Builder and so is responsible for all its own allocation - { - ScalarKernel index_in_base; - index_in_base.init = InitSetLookup; + + AddBasicSetLookupKernels(isin_base, /*output_type=*/boolean(), is_in.get()); + + isin_base.signature = KernelSignature::Make({null()}, boolean()); + DCHECK_OK(is_in->AddKernel(isin_base)); + DCHECK_OK(registry->AddFunction(is_in)); + + DCHECK_OK(registry->AddFunction(std::make_shared<IsInMetaBinary>())); + } + + // IndexIn uses Int32Builder and so is responsible for all its own allocation + { + ScalarKernel index_in_base; + index_in_base.init = InitSetLookup; index_in_base.exec = TrivialScalarUnaryAsArraysExec( ExecIndexIn, NullHandling::COMPUTED_NO_PREALLOCATE); - index_in_base.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; - index_in_base.mem_allocation = MemAllocation::NO_PREALLOCATE; + index_in_base.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + index_in_base.mem_allocation = MemAllocation::NO_PREALLOCATE; auto index_in = std::make_shared<SetLookupFunction>("index_in", Arity::Unary(), &index_in_doc); - - AddBasicSetLookupKernels(index_in_base, /*output_type=*/int32(), index_in.get()); - - index_in_base.signature = KernelSignature::Make({null()}, int32()); - DCHECK_OK(index_in->AddKernel(index_in_base)); - DCHECK_OK(registry->AddFunction(index_in)); - - DCHECK_OK(registry->AddFunction(std::make_shared<IndexInMetaBinary>())); - } -} - -} // namespace internal -} // namespace compute -} // namespace arrow + + AddBasicSetLookupKernels(index_in_base, /*output_type=*/int32(), index_in.get()); + + index_in_base.signature = KernelSignature::Make({null()}, int32()); + DCHECK_OK(index_in->AddKernel(index_in_base)); + DCHECK_OK(registry->AddFunction(index_in)); + + DCHECK_OK(registry->AddFunction(std::make_shared<IndexInMetaBinary>())); + } +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_string.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_string.cc index ab0a490eeb..9d2ed1764e 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -1,29 +1,29 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include <algorithm> -#include <cctype> +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <algorithm> +#include <cctype> #include <iterator> -#include <string> - -#ifdef ARROW_WITH_UTF8PROC -#include <utf8proc.h> -#endif - +#include <string> + +#ifdef ARROW_WITH_UTF8PROC +#include <utf8proc.h> +#endif + #ifdef ARROW_WITH_RE2 #include <re2/re2.h> #endif @@ -33,22 +33,22 @@ #include "arrow/buffer_builder.h" #include "arrow/builder.h" -#include "arrow/compute/api_scalar.h" -#include "arrow/compute/kernels/common.h" +#include "arrow/compute/api_scalar.h" +#include "arrow/compute/kernels/common.h" #include "arrow/util/checked_cast.h" -#include "arrow/util/utf8.h" -#include "arrow/util/value_parsing.h" +#include "arrow/util/utf8.h" +#include "arrow/util/value_parsing.h" #include "arrow/visitor_inline.h" - -namespace arrow { + +namespace arrow { using internal::checked_cast; -namespace compute { -namespace internal { - -namespace { - +namespace compute { +namespace internal { + +namespace { + #ifdef ARROW_WITH_RE2 util::string_view ToStringView(re2::StringPiece piece) { return {piece.data(), piece.length()}; @@ -66,33 +66,33 @@ Status RegexStatus(const RE2& regex) { } #endif -// Code units in the range [a-z] can only be an encoding of an ascii -// character/codepoint, not the 2nd, 3rd or 4th code unit (byte) of an different -// codepoint. This guaranteed by non-overlap design of the unicode standard. (see -// section 2.5 of Unicode Standard Core Specification v13.0) - -static inline uint8_t ascii_tolower(uint8_t utf8_code_unit) { - return ((utf8_code_unit >= 'A') && (utf8_code_unit <= 'Z')) ? (utf8_code_unit + 32) - : utf8_code_unit; -} - -static inline uint8_t ascii_toupper(uint8_t utf8_code_unit) { - return ((utf8_code_unit >= 'a') && (utf8_code_unit <= 'z')) ? (utf8_code_unit - 32) - : utf8_code_unit; -} - -template <typename T> -static inline bool IsAsciiCharacter(T character) { - return character < 128; -} - -struct BinaryLength { - template <typename OutValue, typename Arg0Value = util::string_view> +// Code units in the range [a-z] can only be an encoding of an ascii +// character/codepoint, not the 2nd, 3rd or 4th code unit (byte) of an different +// codepoint. This guaranteed by non-overlap design of the unicode standard. (see +// section 2.5 of Unicode Standard Core Specification v13.0) + +static inline uint8_t ascii_tolower(uint8_t utf8_code_unit) { + return ((utf8_code_unit >= 'A') && (utf8_code_unit <= 'Z')) ? (utf8_code_unit + 32) + : utf8_code_unit; +} + +static inline uint8_t ascii_toupper(uint8_t utf8_code_unit) { + return ((utf8_code_unit >= 'a') && (utf8_code_unit <= 'z')) ? (utf8_code_unit - 32) + : utf8_code_unit; +} + +template <typename T> +static inline bool IsAsciiCharacter(T character) { + return character < 128; +} + +struct BinaryLength { + template <typename OutValue, typename Arg0Value = util::string_view> static OutValue Call(KernelContext*, Arg0Value val, Status*) { - return static_cast<OutValue>(val.size()); - } -}; - + return static_cast<OutValue>(val.size()); + } +}; + struct Utf8Length { template <typename OutValue, typename Arg0Value = util::string_view> static OutValue Call(KernelContext*, Arg0Value val, Status*) { @@ -102,28 +102,28 @@ struct Utf8Length { } }; -#ifdef ARROW_WITH_UTF8PROC - -// Direct lookup tables for unicode properties -constexpr uint32_t kMaxCodepointLookup = - 0xffff; // up to this codepoint is in a lookup table -std::vector<uint32_t> lut_upper_codepoint; -std::vector<uint32_t> lut_lower_codepoint; -std::vector<utf8proc_category_t> lut_category; -std::once_flag flag_case_luts; - -void EnsureLookupTablesFilled() { - std::call_once(flag_case_luts, []() { - lut_upper_codepoint.reserve(kMaxCodepointLookup + 1); - lut_lower_codepoint.reserve(kMaxCodepointLookup + 1); - for (uint32_t i = 0; i <= kMaxCodepointLookup; i++) { - lut_upper_codepoint.push_back(utf8proc_toupper(i)); - lut_lower_codepoint.push_back(utf8proc_tolower(i)); - lut_category.push_back(utf8proc_category(i)); - } - }); -} - +#ifdef ARROW_WITH_UTF8PROC + +// Direct lookup tables for unicode properties +constexpr uint32_t kMaxCodepointLookup = + 0xffff; // up to this codepoint is in a lookup table +std::vector<uint32_t> lut_upper_codepoint; +std::vector<uint32_t> lut_lower_codepoint; +std::vector<utf8proc_category_t> lut_category; +std::once_flag flag_case_luts; + +void EnsureLookupTablesFilled() { + std::call_once(flag_case_luts, []() { + lut_upper_codepoint.reserve(kMaxCodepointLookup + 1); + lut_lower_codepoint.reserve(kMaxCodepointLookup + 1); + for (uint32_t i = 0; i <= kMaxCodepointLookup; i++) { + lut_upper_codepoint.push_back(utf8proc_toupper(i)); + lut_lower_codepoint.push_back(utf8proc_tolower(i)); + lut_category.push_back(utf8proc_category(i)); + } + }); +} + #else void EnsureLookupTablesFilled() {} @@ -154,67 +154,67 @@ struct StringTransformBase { template <typename Type, typename StringTransform> struct StringTransformExecBase { - using offset_type = typename Type::offset_type; - using ArrayType = typename TypeTraits<Type>::ArrayType; - + using offset_type = typename Type::offset_type; + using ArrayType = typename TypeTraits<Type>::ArrayType; + static Status Execute(KernelContext* ctx, StringTransform* transform, const ExecBatch& batch, Datum* out) { if (batch[0].kind() == Datum::ARRAY) { return ExecArray(ctx, transform, batch[0].array(), out); - } + } DCHECK_EQ(batch[0].kind(), Datum::SCALAR); return ExecScalar(ctx, transform, batch[0].scalar(), out); - } - + } + static Status ExecArray(KernelContext* ctx, StringTransform* transform, const std::shared_ptr<ArrayData>& data, Datum* out) { ArrayType input(data); ArrayData* output = out->mutable_array(); - + const int64_t input_ncodeunits = input.total_values_length(); const int64_t input_nstrings = input.length(); - + const int64_t output_ncodeunits_max = transform->MaxCodeunits(input_nstrings, input_ncodeunits); if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) { return Status::CapacityError( "Result might not fit in a 32bit utf8 array, convert to large_utf8"); } - + ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(output_ncodeunits_max)); output->buffers[2] = values_buffer; - + // String offsets are preallocated offset_type* output_string_offsets = output->GetMutableValues<offset_type>(1); uint8_t* output_str = output->buffers[2]->mutable_data(); offset_type output_ncodeunits = 0; - + output_string_offsets[0] = 0; for (int64_t i = 0; i < input_nstrings; i++) { if (!input.IsNull(i)) { - offset_type input_string_ncodeunits; + offset_type input_string_ncodeunits; const uint8_t* input_string = input.GetValue(i, &input_string_ncodeunits); auto encoded_nbytes = static_cast<offset_type>(transform->Transform( input_string, input_string_ncodeunits, output_str + output_ncodeunits)); if (encoded_nbytes < 0) { return transform->InvalidStatus(); - } - output_ncodeunits += encoded_nbytes; - } + } + output_ncodeunits += encoded_nbytes; + } output_string_offsets[i + 1] = output_ncodeunits; } DCHECK_LE(output_ncodeunits, output_ncodeunits_max); - + // Trim the codepoint buffer, since we allocated too much return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true); } - + static Status ExecScalar(KernelContext* ctx, StringTransform* transform, const std::shared_ptr<Scalar>& scalar, Datum* out) { const auto& input = checked_cast<const BaseBinaryScalar&>(*scalar); if (!input.is_valid) { return Status::OK(); - } + } auto* result = checked_cast<BaseBinaryScalar*>(out->scalar().get()); result->is_valid = true; const int64_t data_nbytes = static_cast<int64_t>(input.value->size()); @@ -233,9 +233,9 @@ struct StringTransformExecBase { } DCHECK_LE(encoded_nbytes, output_ncodeunits_max); return value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true); - } -}; - + } +}; + template <typename Type, typename StringTransform> struct StringTransformExec : public StringTransformExecBase<Type, StringTransform> { using StringTransformExecBase<Type, StringTransform>::Execute; @@ -300,26 +300,26 @@ struct CaseMappingTransform { struct UTF8UpperTransform : public CaseMappingTransform { static uint32_t TransformCodepoint(uint32_t codepoint) { - return codepoint <= kMaxCodepointLookup ? lut_upper_codepoint[codepoint] - : utf8proc_toupper(codepoint); - } -}; - -template <typename Type> + return codepoint <= kMaxCodepointLookup ? lut_upper_codepoint[codepoint] + : utf8proc_toupper(codepoint); + } +}; + +template <typename Type> using UTF8Upper = StringTransformExec<Type, StringTransformCodepoint<UTF8UpperTransform>>; struct UTF8LowerTransform : public CaseMappingTransform { - static uint32_t TransformCodepoint(uint32_t codepoint) { - return codepoint <= kMaxCodepointLookup ? lut_lower_codepoint[codepoint] - : utf8proc_tolower(codepoint); - } -}; - + static uint32_t TransformCodepoint(uint32_t codepoint) { + return codepoint <= kMaxCodepointLookup ? lut_lower_codepoint[codepoint] + : utf8proc_tolower(codepoint); + } +}; + template <typename Type> using UTF8Lower = StringTransformExec<Type, StringTransformCodepoint<UTF8LowerTransform>>; - -#endif // ARROW_WITH_UTF8PROC - + +#endif // ARROW_WITH_UTF8PROC + struct AsciiReverseTransform : public StringTransformBase { int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, uint8_t* output) { @@ -357,129 +357,129 @@ struct Utf8ReverseTransform : public StringTransformBase { template <typename Type> using Utf8Reverse = StringTransformExec<Type, Utf8ReverseTransform>; -using TransformFunc = std::function<void(const uint8_t*, int64_t, uint8_t*)>; - -// Transform a buffer of offsets to one which begins with 0 and has same -// value lengths. -template <typename T> -Status GetShiftedOffsets(KernelContext* ctx, const Buffer& input_buffer, int64_t offset, - int64_t length, std::shared_ptr<Buffer>* out) { - ARROW_ASSIGN_OR_RAISE(*out, ctx->Allocate((length + 1) * sizeof(T))); - const T* input_offsets = reinterpret_cast<const T*>(input_buffer.data()) + offset; - T* out_offsets = reinterpret_cast<T*>((*out)->mutable_data()); - T first_offset = *input_offsets; - for (int64_t i = 0; i < length; ++i) { - *out_offsets++ = input_offsets[i] - first_offset; - } - *out_offsets = input_offsets[length] - first_offset; - return Status::OK(); -} - -// Apply `transform` to input character data- this function cannot change the -// length -template <typename Type> +using TransformFunc = std::function<void(const uint8_t*, int64_t, uint8_t*)>; + +// Transform a buffer of offsets to one which begins with 0 and has same +// value lengths. +template <typename T> +Status GetShiftedOffsets(KernelContext* ctx, const Buffer& input_buffer, int64_t offset, + int64_t length, std::shared_ptr<Buffer>* out) { + ARROW_ASSIGN_OR_RAISE(*out, ctx->Allocate((length + 1) * sizeof(T))); + const T* input_offsets = reinterpret_cast<const T*>(input_buffer.data()) + offset; + T* out_offsets = reinterpret_cast<T*>((*out)->mutable_data()); + T first_offset = *input_offsets; + for (int64_t i = 0; i < length; ++i) { + *out_offsets++ = input_offsets[i] - first_offset; + } + *out_offsets = input_offsets[length] - first_offset; + return Status::OK(); +} + +// Apply `transform` to input character data- this function cannot change the +// length +template <typename Type> Status StringDataTransform(KernelContext* ctx, const ExecBatch& batch, TransformFunc transform, Datum* out) { - using ArrayType = typename TypeTraits<Type>::ArrayType; - using offset_type = typename Type::offset_type; - - if (batch[0].kind() == Datum::ARRAY) { - const ArrayData& input = *batch[0].array(); - ArrayType input_boxed(batch[0].array()); - - ArrayData* out_arr = out->mutable_array(); - - if (input.offset == 0) { - // We can reuse offsets from input - out_arr->buffers[1] = input.buffers[1]; - } else { - DCHECK(input.buffers[1]); - // We must allocate new space for the offsets and shift the existing offsets + using ArrayType = typename TypeTraits<Type>::ArrayType; + using offset_type = typename Type::offset_type; + + if (batch[0].kind() == Datum::ARRAY) { + const ArrayData& input = *batch[0].array(); + ArrayType input_boxed(batch[0].array()); + + ArrayData* out_arr = out->mutable_array(); + + if (input.offset == 0) { + // We can reuse offsets from input + out_arr->buffers[1] = input.buffers[1]; + } else { + DCHECK(input.buffers[1]); + // We must allocate new space for the offsets and shift the existing offsets RETURN_NOT_OK(GetShiftedOffsets<offset_type>(ctx, *input.buffers[1], input.offset, input.length, &out_arr->buffers[1])); - } - - // Allocate space for output data - int64_t data_nbytes = input_boxed.total_values_length(); + } + + // Allocate space for output data + int64_t data_nbytes = input_boxed.total_values_length(); RETURN_NOT_OK(ctx->Allocate(data_nbytes).Value(&out_arr->buffers[2])); - if (input.length > 0) { - transform(input.buffers[2]->data() + input_boxed.value_offset(0), data_nbytes, - out_arr->buffers[2]->mutable_data()); - } - } else { - const auto& input = checked_cast<const BaseBinaryScalar&>(*batch[0].scalar()); - auto result = checked_pointer_cast<BaseBinaryScalar>(MakeNullScalar(out->type())); - if (input.is_valid) { - result->is_valid = true; - int64_t data_nbytes = input.value->size(); + if (input.length > 0) { + transform(input.buffers[2]->data() + input_boxed.value_offset(0), data_nbytes, + out_arr->buffers[2]->mutable_data()); + } + } else { + const auto& input = checked_cast<const BaseBinaryScalar&>(*batch[0].scalar()); + auto result = checked_pointer_cast<BaseBinaryScalar>(MakeNullScalar(out->type())); + if (input.is_valid) { + result->is_valid = true; + int64_t data_nbytes = input.value->size(); RETURN_NOT_OK(ctx->Allocate(data_nbytes).Value(&result->value)); - transform(input.value->data(), data_nbytes, result->value->mutable_data()); - } + transform(input.value->data(), data_nbytes, result->value->mutable_data()); + } out->value = result; - } + } return Status::OK(); -} - -void TransformAsciiUpper(const uint8_t* input, int64_t length, uint8_t* output) { - std::transform(input, input + length, output, ascii_toupper); -} - -template <typename Type> -struct AsciiUpper { +} + +void TransformAsciiUpper(const uint8_t* input, int64_t length, uint8_t* output) { + std::transform(input, input + length, output, ascii_toupper); +} + +template <typename Type> +struct AsciiUpper { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { return StringDataTransform<Type>(ctx, batch, TransformAsciiUpper, out); - } -}; - -void TransformAsciiLower(const uint8_t* input, int64_t length, uint8_t* output) { - std::transform(input, input + length, output, ascii_tolower); -} - -template <typename Type> -struct AsciiLower { + } +}; + +void TransformAsciiLower(const uint8_t* input, int64_t length, uint8_t* output) { + std::transform(input, input + length, output, ascii_tolower); +} + +template <typename Type> +struct AsciiLower { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { return StringDataTransform<Type>(ctx, batch, TransformAsciiLower, out); - } -}; - -// ---------------------------------------------------------------------- -// exact pattern detection - -using StrToBoolTransformFunc = - std::function<void(const void*, const uint8_t*, int64_t, int64_t, uint8_t*)>; - -// Apply `transform` to input character data- this function cannot change the -// length -template <typename Type> -void StringBoolTransform(KernelContext* ctx, const ExecBatch& batch, - StrToBoolTransformFunc transform, Datum* out) { - using offset_type = typename Type::offset_type; - - if (batch[0].kind() == Datum::ARRAY) { - const ArrayData& input = *batch[0].array(); - ArrayData* out_arr = out->mutable_array(); - if (input.length > 0) { - transform( - reinterpret_cast<const offset_type*>(input.buffers[1]->data()) + input.offset, - input.buffers[2]->data(), input.length, out_arr->offset, - out_arr->buffers[1]->mutable_data()); - } - } else { - const auto& input = checked_cast<const BaseBinaryScalar&>(*batch[0].scalar()); - if (input.is_valid) { - uint8_t result_value = 0; - std::array<offset_type, 2> offsets{0, - static_cast<offset_type>(input.value->size())}; - transform(offsets.data(), input.value->data(), 1, /*output_offset=*/0, - &result_value); + } +}; + +// ---------------------------------------------------------------------- +// exact pattern detection + +using StrToBoolTransformFunc = + std::function<void(const void*, const uint8_t*, int64_t, int64_t, uint8_t*)>; + +// Apply `transform` to input character data- this function cannot change the +// length +template <typename Type> +void StringBoolTransform(KernelContext* ctx, const ExecBatch& batch, + StrToBoolTransformFunc transform, Datum* out) { + using offset_type = typename Type::offset_type; + + if (batch[0].kind() == Datum::ARRAY) { + const ArrayData& input = *batch[0].array(); + ArrayData* out_arr = out->mutable_array(); + if (input.length > 0) { + transform( + reinterpret_cast<const offset_type*>(input.buffers[1]->data()) + input.offset, + input.buffers[2]->data(), input.length, out_arr->offset, + out_arr->buffers[1]->mutable_data()); + } + } else { + const auto& input = checked_cast<const BaseBinaryScalar&>(*batch[0].scalar()); + if (input.is_valid) { + uint8_t result_value = 0; + std::array<offset_type, 2> offsets{0, + static_cast<offset_type>(input.value->size())}; + transform(offsets.data(), input.value->data(), 1, /*output_offset=*/0, + &result_value); out->value = std::make_shared<BooleanScalar>(result_value > 0); - } - } -} - + } + } +} + using MatchSubstringState = OptionsWrapper<MatchSubstringOptions>; - + // This is an implementation of the Knuth-Morris-Pratt algorithm struct PlainSubstringMatcher { const MatchSubstringOptions& options_; @@ -507,31 +507,31 @@ struct PlainSubstringMatcher { } prefix_length++; prefix_table[pos + 1] = prefix_length; - } - } - + } + } + int64_t Find(util::string_view current) const { // Phase 2: Find the prefix in the data const auto pattern_length = options_.pattern.size(); - int64_t pattern_pos = 0; + int64_t pattern_pos = 0; int64_t pos = 0; if (pattern_length == 0) return 0; for (const auto c : current) { while ((pattern_pos >= 0) && (options_.pattern[pattern_pos] != c)) { - pattern_pos = prefix_table[pattern_pos]; - } - pattern_pos++; + pattern_pos = prefix_table[pattern_pos]; + } + pattern_pos++; if (static_cast<size_t>(pattern_pos) == pattern_length) { return pos + 1 - pattern_length; - } + } pos++; - } + } return -1; - } - + } + bool Match(util::string_view current) const { return Find(current) >= 0; } }; - + struct PlainStartsWithMatcher { const MatchSubstringOptions& options_; @@ -607,12 +607,12 @@ struct RegexSubstringMatcher { template <typename Type, typename Matcher> struct MatchSubstringImpl { - using offset_type = typename Type::offset_type; + using offset_type = typename Type::offset_type; static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out, const Matcher* matcher) { - StringBoolTransform<Type>( - ctx, batch, + StringBoolTransform<Type>( + ctx, batch, [&matcher](const void* raw_offsets, const uint8_t* data, int64_t length, int64_t output_offset, uint8_t* output) { const offset_type* offsets = reinterpret_cast<const offset_type*>(raw_offsets); @@ -626,12 +626,12 @@ struct MatchSubstringImpl { bitmap_writer.Next(); } bitmap_writer.Finish(); - }, - out); + }, + out); return Status::OK(); - } -}; - + } +}; + template <typename Type, typename Matcher> struct MatchSubstring { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { @@ -842,7 +842,7 @@ const FunctionDoc match_like_doc( #endif -void AddMatchSubstring(FunctionRegistry* registry) { +void AddMatchSubstring(FunctionRegistry* registry) { { auto func = std::make_shared<ScalarFunction>("match_substring", Arity::Unary(), &match_substring_doc); @@ -1344,420 +1344,420 @@ void AddSlice(FunctionRegistry* registry) { &utf8_slice_codeunits_doc); using t32 = SliceCodeunits<StringType>; using t64 = SliceCodeunits<LargeStringType>; - DCHECK_OK( + DCHECK_OK( func->AddKernel({utf8()}, utf8(), t32::Exec, SliceCodeunitsTransform::State::Init)); DCHECK_OK(func->AddKernel({large_utf8()}, large_utf8(), t64::Exec, SliceCodeunitsTransform::State::Init)); - DCHECK_OK(registry->AddFunction(std::move(func))); -} - -// IsAlpha/Digit etc - -#ifdef ARROW_WITH_UTF8PROC - -static inline bool HasAnyUnicodeGeneralCategory(uint32_t codepoint, uint32_t mask) { - utf8proc_category_t general_category = codepoint <= kMaxCodepointLookup - ? lut_category[codepoint] - : utf8proc_category(codepoint); - uint32_t general_category_bit = 1 << general_category; - // for e.g. undefined (but valid) codepoints, general_category == 0 == - // UTF8PROC_CATEGORY_CN - return (general_category != UTF8PROC_CATEGORY_CN) && - ((general_category_bit & mask) != 0); -} - -template <typename... Categories> -static inline bool HasAnyUnicodeGeneralCategory(uint32_t codepoint, uint32_t mask, - utf8proc_category_t category, - Categories... categories) { - return HasAnyUnicodeGeneralCategory(codepoint, mask | (1 << category), categories...); -} - -template <typename... Categories> -static inline bool HasAnyUnicodeGeneralCategory(uint32_t codepoint, - utf8proc_category_t category, - Categories... categories) { - return HasAnyUnicodeGeneralCategory(codepoint, static_cast<uint32_t>(1u << category), - categories...); -} - -static inline bool IsCasedCharacterUnicode(uint32_t codepoint) { - return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LU, - UTF8PROC_CATEGORY_LL, UTF8PROC_CATEGORY_LT) || - ((static_cast<uint32_t>(utf8proc_toupper(codepoint)) != codepoint) || - (static_cast<uint32_t>(utf8proc_tolower(codepoint)) != codepoint)); -} - -static inline bool IsLowerCaseCharacterUnicode(uint32_t codepoint) { - // although this trick seems to work for upper case, this is not enough for lower case - // testing, see https://github.com/JuliaStrings/utf8proc/issues/195 . But currently the - // best we can do - return (HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LL) || - ((static_cast<uint32_t>(utf8proc_toupper(codepoint)) != codepoint) && - (static_cast<uint32_t>(utf8proc_tolower(codepoint)) == codepoint))) && - !HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LT); -} - -static inline bool IsUpperCaseCharacterUnicode(uint32_t codepoint) { - // this seems to be a good workaround for utf8proc not having case information - // https://github.com/JuliaStrings/utf8proc/issues/195 - return (HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LU) || - ((static_cast<uint32_t>(utf8proc_toupper(codepoint)) == codepoint) && - (static_cast<uint32_t>(utf8proc_tolower(codepoint)) != codepoint))) && - !HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LT); -} - -static inline bool IsAlphaNumericCharacterUnicode(uint32_t codepoint) { - return HasAnyUnicodeGeneralCategory( - codepoint, UTF8PROC_CATEGORY_LU, UTF8PROC_CATEGORY_LL, UTF8PROC_CATEGORY_LT, - UTF8PROC_CATEGORY_LM, UTF8PROC_CATEGORY_LO, UTF8PROC_CATEGORY_ND, - UTF8PROC_CATEGORY_NL, UTF8PROC_CATEGORY_NO); -} - -static inline bool IsAlphaCharacterUnicode(uint32_t codepoint) { - return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LU, - UTF8PROC_CATEGORY_LL, UTF8PROC_CATEGORY_LT, - UTF8PROC_CATEGORY_LM, UTF8PROC_CATEGORY_LO); -} - -static inline bool IsDecimalCharacterUnicode(uint32_t codepoint) { - return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_ND); -} - -static inline bool IsDigitCharacterUnicode(uint32_t codepoint) { - // Python defines this as Numeric_Type=Digit or Numeric_Type=Decimal. - // utf8proc has no support for this, this is the best we can do: - return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_ND); -} - -static inline bool IsNumericCharacterUnicode(uint32_t codepoint) { - // Formally this is not correct, but utf8proc does not allow us to query for Numerical - // properties, e.g. Numeric_Value and Numeric_Type - // Python defines Numeric as Numeric_Type=Digit, Numeric_Type=Decimal or - // Numeric_Type=Numeric. - return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_ND, - UTF8PROC_CATEGORY_NL, UTF8PROC_CATEGORY_NO); -} - -static inline bool IsSpaceCharacterUnicode(uint32_t codepoint) { - auto property = utf8proc_get_property(codepoint); - return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_ZS) || - property->bidi_class == UTF8PROC_BIDI_CLASS_WS || - property->bidi_class == UTF8PROC_BIDI_CLASS_B || - property->bidi_class == UTF8PROC_BIDI_CLASS_S; -} - -static inline bool IsPrintableCharacterUnicode(uint32_t codepoint) { - uint32_t general_category = utf8proc_category(codepoint); - return (general_category != UTF8PROC_CATEGORY_CN) && - !HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_CC, - UTF8PROC_CATEGORY_CF, UTF8PROC_CATEGORY_CS, - UTF8PROC_CATEGORY_CO, UTF8PROC_CATEGORY_ZS, - UTF8PROC_CATEGORY_ZL, UTF8PROC_CATEGORY_ZP); -} - -#endif - -static inline bool IsLowerCaseCharacterAscii(uint8_t ascii_character) { - return (ascii_character >= 'a') && (ascii_character <= 'z'); -} - -static inline bool IsUpperCaseCharacterAscii(uint8_t ascii_character) { - return (ascii_character >= 'A') && (ascii_character <= 'Z'); -} - -static inline bool IsCasedCharacterAscii(uint8_t ascii_character) { - return IsLowerCaseCharacterAscii(ascii_character) || - IsUpperCaseCharacterAscii(ascii_character); -} - -static inline bool IsAlphaCharacterAscii(uint8_t ascii_character) { - return IsCasedCharacterAscii(ascii_character); // same -} - -static inline bool IsAlphaNumericCharacterAscii(uint8_t ascii_character) { - return ((ascii_character >= '0') && (ascii_character <= '9')) || - ((ascii_character >= 'a') && (ascii_character <= 'z')) || - ((ascii_character >= 'A') && (ascii_character <= 'Z')); -} - -static inline bool IsDecimalCharacterAscii(uint8_t ascii_character) { - return ((ascii_character >= '0') && (ascii_character <= '9')); -} - -static inline bool IsSpaceCharacterAscii(uint8_t ascii_character) { - return ((ascii_character >= 0x09) && (ascii_character <= 0x0D)) || - (ascii_character == ' '); -} - -static inline bool IsPrintableCharacterAscii(uint8_t ascii_character) { - return ((ascii_character >= ' ') && (ascii_character <= '~')); -} - -template <typename Derived, bool allow_empty = false> -struct CharacterPredicateUnicode { + DCHECK_OK(registry->AddFunction(std::move(func))); +} + +// IsAlpha/Digit etc + +#ifdef ARROW_WITH_UTF8PROC + +static inline bool HasAnyUnicodeGeneralCategory(uint32_t codepoint, uint32_t mask) { + utf8proc_category_t general_category = codepoint <= kMaxCodepointLookup + ? lut_category[codepoint] + : utf8proc_category(codepoint); + uint32_t general_category_bit = 1 << general_category; + // for e.g. undefined (but valid) codepoints, general_category == 0 == + // UTF8PROC_CATEGORY_CN + return (general_category != UTF8PROC_CATEGORY_CN) && + ((general_category_bit & mask) != 0); +} + +template <typename... Categories> +static inline bool HasAnyUnicodeGeneralCategory(uint32_t codepoint, uint32_t mask, + utf8proc_category_t category, + Categories... categories) { + return HasAnyUnicodeGeneralCategory(codepoint, mask | (1 << category), categories...); +} + +template <typename... Categories> +static inline bool HasAnyUnicodeGeneralCategory(uint32_t codepoint, + utf8proc_category_t category, + Categories... categories) { + return HasAnyUnicodeGeneralCategory(codepoint, static_cast<uint32_t>(1u << category), + categories...); +} + +static inline bool IsCasedCharacterUnicode(uint32_t codepoint) { + return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LU, + UTF8PROC_CATEGORY_LL, UTF8PROC_CATEGORY_LT) || + ((static_cast<uint32_t>(utf8proc_toupper(codepoint)) != codepoint) || + (static_cast<uint32_t>(utf8proc_tolower(codepoint)) != codepoint)); +} + +static inline bool IsLowerCaseCharacterUnicode(uint32_t codepoint) { + // although this trick seems to work for upper case, this is not enough for lower case + // testing, see https://github.com/JuliaStrings/utf8proc/issues/195 . But currently the + // best we can do + return (HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LL) || + ((static_cast<uint32_t>(utf8proc_toupper(codepoint)) != codepoint) && + (static_cast<uint32_t>(utf8proc_tolower(codepoint)) == codepoint))) && + !HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LT); +} + +static inline bool IsUpperCaseCharacterUnicode(uint32_t codepoint) { + // this seems to be a good workaround for utf8proc not having case information + // https://github.com/JuliaStrings/utf8proc/issues/195 + return (HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LU) || + ((static_cast<uint32_t>(utf8proc_toupper(codepoint)) == codepoint) && + (static_cast<uint32_t>(utf8proc_tolower(codepoint)) != codepoint))) && + !HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LT); +} + +static inline bool IsAlphaNumericCharacterUnicode(uint32_t codepoint) { + return HasAnyUnicodeGeneralCategory( + codepoint, UTF8PROC_CATEGORY_LU, UTF8PROC_CATEGORY_LL, UTF8PROC_CATEGORY_LT, + UTF8PROC_CATEGORY_LM, UTF8PROC_CATEGORY_LO, UTF8PROC_CATEGORY_ND, + UTF8PROC_CATEGORY_NL, UTF8PROC_CATEGORY_NO); +} + +static inline bool IsAlphaCharacterUnicode(uint32_t codepoint) { + return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LU, + UTF8PROC_CATEGORY_LL, UTF8PROC_CATEGORY_LT, + UTF8PROC_CATEGORY_LM, UTF8PROC_CATEGORY_LO); +} + +static inline bool IsDecimalCharacterUnicode(uint32_t codepoint) { + return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_ND); +} + +static inline bool IsDigitCharacterUnicode(uint32_t codepoint) { + // Python defines this as Numeric_Type=Digit or Numeric_Type=Decimal. + // utf8proc has no support for this, this is the best we can do: + return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_ND); +} + +static inline bool IsNumericCharacterUnicode(uint32_t codepoint) { + // Formally this is not correct, but utf8proc does not allow us to query for Numerical + // properties, e.g. Numeric_Value and Numeric_Type + // Python defines Numeric as Numeric_Type=Digit, Numeric_Type=Decimal or + // Numeric_Type=Numeric. + return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_ND, + UTF8PROC_CATEGORY_NL, UTF8PROC_CATEGORY_NO); +} + +static inline bool IsSpaceCharacterUnicode(uint32_t codepoint) { + auto property = utf8proc_get_property(codepoint); + return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_ZS) || + property->bidi_class == UTF8PROC_BIDI_CLASS_WS || + property->bidi_class == UTF8PROC_BIDI_CLASS_B || + property->bidi_class == UTF8PROC_BIDI_CLASS_S; +} + +static inline bool IsPrintableCharacterUnicode(uint32_t codepoint) { + uint32_t general_category = utf8proc_category(codepoint); + return (general_category != UTF8PROC_CATEGORY_CN) && + !HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_CC, + UTF8PROC_CATEGORY_CF, UTF8PROC_CATEGORY_CS, + UTF8PROC_CATEGORY_CO, UTF8PROC_CATEGORY_ZS, + UTF8PROC_CATEGORY_ZL, UTF8PROC_CATEGORY_ZP); +} + +#endif + +static inline bool IsLowerCaseCharacterAscii(uint8_t ascii_character) { + return (ascii_character >= 'a') && (ascii_character <= 'z'); +} + +static inline bool IsUpperCaseCharacterAscii(uint8_t ascii_character) { + return (ascii_character >= 'A') && (ascii_character <= 'Z'); +} + +static inline bool IsCasedCharacterAscii(uint8_t ascii_character) { + return IsLowerCaseCharacterAscii(ascii_character) || + IsUpperCaseCharacterAscii(ascii_character); +} + +static inline bool IsAlphaCharacterAscii(uint8_t ascii_character) { + return IsCasedCharacterAscii(ascii_character); // same +} + +static inline bool IsAlphaNumericCharacterAscii(uint8_t ascii_character) { + return ((ascii_character >= '0') && (ascii_character <= '9')) || + ((ascii_character >= 'a') && (ascii_character <= 'z')) || + ((ascii_character >= 'A') && (ascii_character <= 'Z')); +} + +static inline bool IsDecimalCharacterAscii(uint8_t ascii_character) { + return ((ascii_character >= '0') && (ascii_character <= '9')); +} + +static inline bool IsSpaceCharacterAscii(uint8_t ascii_character) { + return ((ascii_character >= 0x09) && (ascii_character <= 0x0D)) || + (ascii_character == ' '); +} + +static inline bool IsPrintableCharacterAscii(uint8_t ascii_character) { + return ((ascii_character >= ' ') && (ascii_character <= '~')); +} + +template <typename Derived, bool allow_empty = false> +struct CharacterPredicateUnicode { static bool Call(KernelContext*, const uint8_t* input, size_t input_string_ncodeunits, Status* st) { - if (allow_empty && input_string_ncodeunits == 0) { - return true; - } - bool all; - bool any = false; - if (!ARROW_PREDICT_TRUE(arrow::util::UTF8AllOf( - input, input + input_string_ncodeunits, &all, [&any](uint32_t codepoint) { - any |= Derived::PredicateCharacterAny(codepoint); - return Derived::PredicateCharacterAll(codepoint); - }))) { + if (allow_empty && input_string_ncodeunits == 0) { + return true; + } + bool all; + bool any = false; + if (!ARROW_PREDICT_TRUE(arrow::util::UTF8AllOf( + input, input + input_string_ncodeunits, &all, [&any](uint32_t codepoint) { + any |= Derived::PredicateCharacterAny(codepoint); + return Derived::PredicateCharacterAll(codepoint); + }))) { *st = Status::Invalid("Invalid UTF8 sequence in input"); - return false; - } - return all & any; - } - - static inline bool PredicateCharacterAny(uint32_t) { - return true; // default condition make sure there is at least 1 charachter - } -}; - -template <typename Derived, bool allow_empty = false> -struct CharacterPredicateAscii { + return false; + } + return all & any; + } + + static inline bool PredicateCharacterAny(uint32_t) { + return true; // default condition make sure there is at least 1 charachter + } +}; + +template <typename Derived, bool allow_empty = false> +struct CharacterPredicateAscii { static bool Call(KernelContext*, const uint8_t* input, size_t input_string_ncodeunits, Status*) { - if (allow_empty && input_string_ncodeunits == 0) { - return true; - } - bool any = false; - // MB: A simple for loops seems 8% faster on gcc 9.3, running the IsAlphaNumericAscii - // benchmark. I don't consider that worth it. - bool all = std::all_of(input, input + input_string_ncodeunits, - [&any](uint8_t ascii_character) { - any |= Derived::PredicateCharacterAny(ascii_character); - return Derived::PredicateCharacterAll(ascii_character); - }); - return all & any; - } - - static inline bool PredicateCharacterAny(uint8_t) { - return true; // default condition make sure there is at least 1 charachter - } -}; - -#ifdef ARROW_WITH_UTF8PROC -struct IsAlphaNumericUnicode : CharacterPredicateUnicode<IsAlphaNumericUnicode> { - static inline bool PredicateCharacterAll(uint32_t codepoint) { - return IsAlphaNumericCharacterUnicode(codepoint); - } -}; -#endif - -struct IsAlphaNumericAscii : CharacterPredicateAscii<IsAlphaNumericAscii> { - static inline bool PredicateCharacterAll(uint8_t ascii_character) { - return IsAlphaNumericCharacterAscii(ascii_character); - } -}; - -#ifdef ARROW_WITH_UTF8PROC -struct IsAlphaUnicode : CharacterPredicateUnicode<IsAlphaUnicode> { - static inline bool PredicateCharacterAll(uint32_t codepoint) { - return IsAlphaCharacterUnicode(codepoint); - } -}; -#endif - -struct IsAlphaAscii : CharacterPredicateAscii<IsAlphaAscii> { - static inline bool PredicateCharacterAll(uint8_t ascii_character) { - return IsAlphaCharacterAscii(ascii_character); - } -}; - -#ifdef ARROW_WITH_UTF8PROC -struct IsDecimalUnicode : CharacterPredicateUnicode<IsDecimalUnicode> { - static inline bool PredicateCharacterAll(uint32_t codepoint) { - return IsDecimalCharacterUnicode(codepoint); - } -}; -#endif - -struct IsDecimalAscii : CharacterPredicateAscii<IsDecimalAscii> { - static inline bool PredicateCharacterAll(uint8_t ascii_character) { - return IsDecimalCharacterAscii(ascii_character); - } -}; - -#ifdef ARROW_WITH_UTF8PROC -struct IsDigitUnicode : CharacterPredicateUnicode<IsDigitUnicode> { - static inline bool PredicateCharacterAll(uint32_t codepoint) { - return IsDigitCharacterUnicode(codepoint); - } -}; - -struct IsNumericUnicode : CharacterPredicateUnicode<IsNumericUnicode> { - static inline bool PredicateCharacterAll(uint32_t codepoint) { - return IsNumericCharacterUnicode(codepoint); - } -}; -#endif - -struct IsAscii { + if (allow_empty && input_string_ncodeunits == 0) { + return true; + } + bool any = false; + // MB: A simple for loops seems 8% faster on gcc 9.3, running the IsAlphaNumericAscii + // benchmark. I don't consider that worth it. + bool all = std::all_of(input, input + input_string_ncodeunits, + [&any](uint8_t ascii_character) { + any |= Derived::PredicateCharacterAny(ascii_character); + return Derived::PredicateCharacterAll(ascii_character); + }); + return all & any; + } + + static inline bool PredicateCharacterAny(uint8_t) { + return true; // default condition make sure there is at least 1 charachter + } +}; + +#ifdef ARROW_WITH_UTF8PROC +struct IsAlphaNumericUnicode : CharacterPredicateUnicode<IsAlphaNumericUnicode> { + static inline bool PredicateCharacterAll(uint32_t codepoint) { + return IsAlphaNumericCharacterUnicode(codepoint); + } +}; +#endif + +struct IsAlphaNumericAscii : CharacterPredicateAscii<IsAlphaNumericAscii> { + static inline bool PredicateCharacterAll(uint8_t ascii_character) { + return IsAlphaNumericCharacterAscii(ascii_character); + } +}; + +#ifdef ARROW_WITH_UTF8PROC +struct IsAlphaUnicode : CharacterPredicateUnicode<IsAlphaUnicode> { + static inline bool PredicateCharacterAll(uint32_t codepoint) { + return IsAlphaCharacterUnicode(codepoint); + } +}; +#endif + +struct IsAlphaAscii : CharacterPredicateAscii<IsAlphaAscii> { + static inline bool PredicateCharacterAll(uint8_t ascii_character) { + return IsAlphaCharacterAscii(ascii_character); + } +}; + +#ifdef ARROW_WITH_UTF8PROC +struct IsDecimalUnicode : CharacterPredicateUnicode<IsDecimalUnicode> { + static inline bool PredicateCharacterAll(uint32_t codepoint) { + return IsDecimalCharacterUnicode(codepoint); + } +}; +#endif + +struct IsDecimalAscii : CharacterPredicateAscii<IsDecimalAscii> { + static inline bool PredicateCharacterAll(uint8_t ascii_character) { + return IsDecimalCharacterAscii(ascii_character); + } +}; + +#ifdef ARROW_WITH_UTF8PROC +struct IsDigitUnicode : CharacterPredicateUnicode<IsDigitUnicode> { + static inline bool PredicateCharacterAll(uint32_t codepoint) { + return IsDigitCharacterUnicode(codepoint); + } +}; + +struct IsNumericUnicode : CharacterPredicateUnicode<IsNumericUnicode> { + static inline bool PredicateCharacterAll(uint32_t codepoint) { + return IsNumericCharacterUnicode(codepoint); + } +}; +#endif + +struct IsAscii { static bool Call(KernelContext*, const uint8_t* input, size_t input_string_nascii_characters, Status*) { - return std::all_of(input, input + input_string_nascii_characters, - IsAsciiCharacter<uint8_t>); - } -}; - -#ifdef ARROW_WITH_UTF8PROC -struct IsLowerUnicode : CharacterPredicateUnicode<IsLowerUnicode> { - static inline bool PredicateCharacterAll(uint32_t codepoint) { - // Only for cased character it needs to be lower case - return !IsCasedCharacterUnicode(codepoint) || IsLowerCaseCharacterUnicode(codepoint); - } - static inline bool PredicateCharacterAny(uint32_t codepoint) { - return IsCasedCharacterUnicode(codepoint); // at least 1 cased character - } -}; -#endif - -struct IsLowerAscii : CharacterPredicateAscii<IsLowerAscii> { - static inline bool PredicateCharacterAll(uint8_t ascii_character) { - // Only for cased character it needs to be lower case - return !IsCasedCharacterAscii(ascii_character) || - IsLowerCaseCharacterAscii(ascii_character); - } - static inline bool PredicateCharacterAny(uint8_t ascii_character) { - return IsCasedCharacterAscii(ascii_character); // at least 1 cased character - } -}; - -#ifdef ARROW_WITH_UTF8PROC -struct IsPrintableUnicode - : CharacterPredicateUnicode<IsPrintableUnicode, /*allow_empty=*/true> { - static inline bool PredicateCharacterAll(uint32_t codepoint) { - return codepoint == ' ' || IsPrintableCharacterUnicode(codepoint); - } -}; -#endif - -struct IsPrintableAscii - : CharacterPredicateAscii<IsPrintableAscii, /*allow_empty=*/true> { - static inline bool PredicateCharacterAll(uint8_t ascii_character) { - return IsPrintableCharacterAscii(ascii_character); - } -}; - -#ifdef ARROW_WITH_UTF8PROC -struct IsSpaceUnicode : CharacterPredicateUnicode<IsSpaceUnicode> { - static inline bool PredicateCharacterAll(uint32_t codepoint) { - return IsSpaceCharacterUnicode(codepoint); - } -}; -#endif - -struct IsSpaceAscii : CharacterPredicateAscii<IsSpaceAscii> { - static inline bool PredicateCharacterAll(uint8_t ascii_character) { - return IsSpaceCharacterAscii(ascii_character); - } -}; - -#ifdef ARROW_WITH_UTF8PROC -struct IsTitleUnicode { + return std::all_of(input, input + input_string_nascii_characters, + IsAsciiCharacter<uint8_t>); + } +}; + +#ifdef ARROW_WITH_UTF8PROC +struct IsLowerUnicode : CharacterPredicateUnicode<IsLowerUnicode> { + static inline bool PredicateCharacterAll(uint32_t codepoint) { + // Only for cased character it needs to be lower case + return !IsCasedCharacterUnicode(codepoint) || IsLowerCaseCharacterUnicode(codepoint); + } + static inline bool PredicateCharacterAny(uint32_t codepoint) { + return IsCasedCharacterUnicode(codepoint); // at least 1 cased character + } +}; +#endif + +struct IsLowerAscii : CharacterPredicateAscii<IsLowerAscii> { + static inline bool PredicateCharacterAll(uint8_t ascii_character) { + // Only for cased character it needs to be lower case + return !IsCasedCharacterAscii(ascii_character) || + IsLowerCaseCharacterAscii(ascii_character); + } + static inline bool PredicateCharacterAny(uint8_t ascii_character) { + return IsCasedCharacterAscii(ascii_character); // at least 1 cased character + } +}; + +#ifdef ARROW_WITH_UTF8PROC +struct IsPrintableUnicode + : CharacterPredicateUnicode<IsPrintableUnicode, /*allow_empty=*/true> { + static inline bool PredicateCharacterAll(uint32_t codepoint) { + return codepoint == ' ' || IsPrintableCharacterUnicode(codepoint); + } +}; +#endif + +struct IsPrintableAscii + : CharacterPredicateAscii<IsPrintableAscii, /*allow_empty=*/true> { + static inline bool PredicateCharacterAll(uint8_t ascii_character) { + return IsPrintableCharacterAscii(ascii_character); + } +}; + +#ifdef ARROW_WITH_UTF8PROC +struct IsSpaceUnicode : CharacterPredicateUnicode<IsSpaceUnicode> { + static inline bool PredicateCharacterAll(uint32_t codepoint) { + return IsSpaceCharacterUnicode(codepoint); + } +}; +#endif + +struct IsSpaceAscii : CharacterPredicateAscii<IsSpaceAscii> { + static inline bool PredicateCharacterAll(uint8_t ascii_character) { + return IsSpaceCharacterAscii(ascii_character); + } +}; + +#ifdef ARROW_WITH_UTF8PROC +struct IsTitleUnicode { static bool Call(KernelContext*, const uint8_t* input, size_t input_string_ncodeunits, Status* st) { - // rules: - // * 1: lower case follows cased - // * 2: upper case follows uncased - // * 3: at least 1 cased character (which logically should be upper/title) - bool rules_1_and_2; - bool previous_cased = false; // in LL, LU or LT - bool rule_3 = false; - bool status = - arrow::util::UTF8AllOf(input, input + input_string_ncodeunits, &rules_1_and_2, - [&previous_cased, &rule_3](uint32_t codepoint) { - if (IsLowerCaseCharacterUnicode(codepoint)) { - if (!previous_cased) return false; // rule 1 broken - previous_cased = true; - } else if (IsCasedCharacterUnicode(codepoint)) { - if (previous_cased) return false; // rule 2 broken - // next should be a lower case or uncased - previous_cased = true; - rule_3 = true; // rule 3 obeyed - } else { - // a non-cased char, like _ or 1 - // next should be upper case or more uncased - previous_cased = false; - } - return true; - }); - if (!ARROW_PREDICT_TRUE(status)) { + // rules: + // * 1: lower case follows cased + // * 2: upper case follows uncased + // * 3: at least 1 cased character (which logically should be upper/title) + bool rules_1_and_2; + bool previous_cased = false; // in LL, LU or LT + bool rule_3 = false; + bool status = + arrow::util::UTF8AllOf(input, input + input_string_ncodeunits, &rules_1_and_2, + [&previous_cased, &rule_3](uint32_t codepoint) { + if (IsLowerCaseCharacterUnicode(codepoint)) { + if (!previous_cased) return false; // rule 1 broken + previous_cased = true; + } else if (IsCasedCharacterUnicode(codepoint)) { + if (previous_cased) return false; // rule 2 broken + // next should be a lower case or uncased + previous_cased = true; + rule_3 = true; // rule 3 obeyed + } else { + // a non-cased char, like _ or 1 + // next should be upper case or more uncased + previous_cased = false; + } + return true; + }); + if (!ARROW_PREDICT_TRUE(status)) { *st = Status::Invalid("Invalid UTF8 sequence in input"); - return false; - } - return rules_1_and_2 & rule_3; - } -}; -#endif - -struct IsTitleAscii { + return false; + } + return rules_1_and_2 & rule_3; + } +}; +#endif + +struct IsTitleAscii { static bool Call(KernelContext*, const uint8_t* input, size_t input_string_ncodeunits, Status*) { - // rules: - // * 1: lower case follows cased - // * 2: upper case follows uncased - // * 3: at least 1 cased character (which logically should be upper/title) - bool rules_1_and_2 = true; - bool previous_cased = false; // in LL, LU or LT - bool rule_3 = false; - // we cannot rely on std::all_of because we need guaranteed order - for (const uint8_t* c = input; c < input + input_string_ncodeunits; ++c) { - if (IsLowerCaseCharacterAscii(*c)) { - if (!previous_cased) { - // rule 1 broken - rules_1_and_2 = false; - break; - } - previous_cased = true; - } else if (IsCasedCharacterAscii(*c)) { - if (previous_cased) { - // rule 2 broken - rules_1_and_2 = false; - break; - } - // next should be a lower case or uncased - previous_cased = true; - rule_3 = true; // rule 3 obeyed - } else { - // a non-cased char, like _ or 1 - // next should be upper case or more uncased - previous_cased = false; - } - } - return rules_1_and_2 & rule_3; - } -}; - -#ifdef ARROW_WITH_UTF8PROC -struct IsUpperUnicode : CharacterPredicateUnicode<IsUpperUnicode> { - static inline bool PredicateCharacterAll(uint32_t codepoint) { - // Only for cased character it needs to be lower case - return !IsCasedCharacterUnicode(codepoint) || IsUpperCaseCharacterUnicode(codepoint); - } - static inline bool PredicateCharacterAny(uint32_t codepoint) { - return IsCasedCharacterUnicode(codepoint); // at least 1 cased character - } -}; -#endif - -struct IsUpperAscii : CharacterPredicateAscii<IsUpperAscii> { - static inline bool PredicateCharacterAll(uint8_t ascii_character) { - // Only for cased character it needs to be lower case - return !IsCasedCharacterAscii(ascii_character) || - IsUpperCaseCharacterAscii(ascii_character); - } - static inline bool PredicateCharacterAny(uint8_t ascii_character) { - return IsCasedCharacterAscii(ascii_character); // at least 1 cased character - } -}; - + // rules: + // * 1: lower case follows cased + // * 2: upper case follows uncased + // * 3: at least 1 cased character (which logically should be upper/title) + bool rules_1_and_2 = true; + bool previous_cased = false; // in LL, LU or LT + bool rule_3 = false; + // we cannot rely on std::all_of because we need guaranteed order + for (const uint8_t* c = input; c < input + input_string_ncodeunits; ++c) { + if (IsLowerCaseCharacterAscii(*c)) { + if (!previous_cased) { + // rule 1 broken + rules_1_and_2 = false; + break; + } + previous_cased = true; + } else if (IsCasedCharacterAscii(*c)) { + if (previous_cased) { + // rule 2 broken + rules_1_and_2 = false; + break; + } + // next should be a lower case or uncased + previous_cased = true; + rule_3 = true; // rule 3 obeyed + } else { + // a non-cased char, like _ or 1 + // next should be upper case or more uncased + previous_cased = false; + } + } + return rules_1_and_2 & rule_3; + } +}; + +#ifdef ARROW_WITH_UTF8PROC +struct IsUpperUnicode : CharacterPredicateUnicode<IsUpperUnicode> { + static inline bool PredicateCharacterAll(uint32_t codepoint) { + // Only for cased character it needs to be lower case + return !IsCasedCharacterUnicode(codepoint) || IsUpperCaseCharacterUnicode(codepoint); + } + static inline bool PredicateCharacterAny(uint32_t codepoint) { + return IsCasedCharacterUnicode(codepoint); // at least 1 cased character + } +}; +#endif + +struct IsUpperAscii : CharacterPredicateAscii<IsUpperAscii> { + static inline bool PredicateCharacterAll(uint8_t ascii_character) { + // Only for cased character it needs to be lower case + return !IsCasedCharacterAscii(ascii_character) || + IsUpperCaseCharacterAscii(ascii_character); + } + static inline bool PredicateCharacterAny(uint8_t ascii_character) { + return IsCasedCharacterAscii(ascii_character); // at least 1 cased character + } +}; + // splitting template <typename Options> @@ -2215,7 +2215,7 @@ void AddSplit(FunctionRegistry* registry) { #endif } -// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- // Replace substring (plain, regex) template <typename Type, typename Replacer> @@ -2773,43 +2773,43 @@ void AddExtractRegex(FunctionRegistry* registry) { #endif // ARROW_WITH_RE2 // ---------------------------------------------------------------------- -// strptime string parsing - -using StrptimeState = OptionsWrapper<StrptimeOptions>; - -struct ParseStrptime { - explicit ParseStrptime(const StrptimeOptions& options) - : parser(TimestampParser::MakeStrptime(options.format)), unit(options.unit) {} - - template <typename... Ignored> +// strptime string parsing + +using StrptimeState = OptionsWrapper<StrptimeOptions>; + +struct ParseStrptime { + explicit ParseStrptime(const StrptimeOptions& options) + : parser(TimestampParser::MakeStrptime(options.format)), unit(options.unit) {} + + template <typename... Ignored> int64_t Call(KernelContext*, util::string_view val, Status* st) const { - int64_t result = 0; - if (!(*parser)(val.data(), val.size(), unit, &result)) { + int64_t result = 0; + if (!(*parser)(val.data(), val.size(), unit, &result)) { *st = Status::Invalid("Failed to parse string: '", val, "' as a scalar of type ", TimestampType(unit).ToString()); - } - return result; - } - - std::shared_ptr<TimestampParser> parser; - TimeUnit::type unit; -}; - -template <typename InputType> + } + return result; + } + + std::shared_ptr<TimestampParser> parser; + TimeUnit::type unit; +}; + +template <typename InputType> Status StrptimeExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - applicator::ScalarUnaryNotNullStateful<TimestampType, InputType, ParseStrptime> kernel{ - ParseStrptime(StrptimeState::Get(ctx))}; - return kernel.Exec(ctx, batch, out); -} - -Result<ValueDescr> StrptimeResolve(KernelContext* ctx, const std::vector<ValueDescr>&) { - if (ctx->state()) { - return ::arrow::timestamp(StrptimeState::Get(ctx).unit); - } - - return Status::Invalid("strptime does not provide default StrptimeOptions"); -} - + applicator::ScalarUnaryNotNullStateful<TimestampType, InputType, ParseStrptime> kernel{ + ParseStrptime(StrptimeState::Get(ctx))}; + return kernel.Exec(ctx, batch, out); +} + +Result<ValueDescr> StrptimeResolve(KernelContext* ctx, const std::vector<ValueDescr>&) { + if (ctx->state()) { + return ::arrow::timestamp(StrptimeState::Get(ctx).unit); + } + + return Status::Invalid("strptime does not provide default StrptimeOptions"); +} + // ---------------------------------------------------------------------- // string padding @@ -3273,31 +3273,31 @@ const FunctionDoc utf8_length_doc("Compute UTF8 string lengths", "UTF8 characters. Null values emit null."), {"strings"}); -void AddStrptime(FunctionRegistry* registry) { +void AddStrptime(FunctionRegistry* registry) { auto func = std::make_shared<ScalarFunction>("strptime", Arity::Unary(), &strptime_doc); - DCHECK_OK(func->AddKernel({utf8()}, OutputType(StrptimeResolve), - StrptimeExec<StringType>, StrptimeState::Init)); - DCHECK_OK(func->AddKernel({large_utf8()}, OutputType(StrptimeResolve), - StrptimeExec<LargeStringType>, StrptimeState::Init)); - DCHECK_OK(registry->AddFunction(std::move(func))); -} - -void AddBinaryLength(FunctionRegistry* registry) { + DCHECK_OK(func->AddKernel({utf8()}, OutputType(StrptimeResolve), + StrptimeExec<StringType>, StrptimeState::Init)); + DCHECK_OK(func->AddKernel({large_utf8()}, OutputType(StrptimeResolve), + StrptimeExec<LargeStringType>, StrptimeState::Init)); + DCHECK_OK(registry->AddFunction(std::move(func))); +} + +void AddBinaryLength(FunctionRegistry* registry) { auto func = std::make_shared<ScalarFunction>("binary_length", Arity::Unary(), &binary_length_doc); - ArrayKernelExec exec_offset_32 = - applicator::ScalarUnaryNotNull<Int32Type, StringType, BinaryLength>::Exec; - ArrayKernelExec exec_offset_64 = - applicator::ScalarUnaryNotNull<Int64Type, LargeStringType, BinaryLength>::Exec; - for (const auto& input_type : {binary(), utf8()}) { - DCHECK_OK(func->AddKernel({input_type}, int32(), exec_offset_32)); - } - for (const auto& input_type : {large_binary(), large_utf8()}) { - DCHECK_OK(func->AddKernel({input_type}, int64(), exec_offset_64)); - } - DCHECK_OK(registry->AddFunction(std::move(func))); -} - + ArrayKernelExec exec_offset_32 = + applicator::ScalarUnaryNotNull<Int32Type, StringType, BinaryLength>::Exec; + ArrayKernelExec exec_offset_64 = + applicator::ScalarUnaryNotNull<Int64Type, LargeStringType, BinaryLength>::Exec; + for (const auto& input_type : {binary(), utf8()}) { + DCHECK_OK(func->AddKernel({input_type}, int32(), exec_offset_32)); + } + for (const auto& input_type : {large_binary(), large_utf8()}) { + DCHECK_OK(func->AddKernel({input_type}, int64(), exec_offset_64)); + } + DCHECK_OK(registry->AddFunction(std::move(func))); +} + void AddUtf8Length(FunctionRegistry* registry) { auto func = std::make_shared<ScalarFunction>("utf8_length", Arity::Unary(), &utf8_length_doc); @@ -3821,7 +3821,7 @@ void AddBinaryJoin(FunctionRegistry* registry) { } } -template <template <typename> class ExecFunctor> +template <template <typename> class ExecFunctor> void MakeUnaryStringBatchKernel( std::string name, FunctionRegistry* registry, const FunctionDoc* doc, MemAllocation::type mem_allocation = MemAllocation::PREALLOCATE) { @@ -3838,9 +3838,9 @@ void MakeUnaryStringBatchKernel( kernel.mem_allocation = mem_allocation; DCHECK_OK(func->AddKernel(std::move(kernel))); } - DCHECK_OK(registry->AddFunction(std::move(func))); -} - + DCHECK_OK(registry->AddFunction(std::move(func))); +} + template <template <typename> class ExecFunctor> void MakeUnaryStringBatchKernelWithState( std::string name, FunctionRegistry* registry, const FunctionDoc* doc, @@ -3861,71 +3861,71 @@ void MakeUnaryStringBatchKernelWithState( DCHECK_OK(registry->AddFunction(std::move(func))); } -#ifdef ARROW_WITH_UTF8PROC - -template <template <typename> class Transformer> +#ifdef ARROW_WITH_UTF8PROC + +template <template <typename> class Transformer> void MakeUnaryStringUTF8TransformKernel(std::string name, FunctionRegistry* registry, const FunctionDoc* doc) { auto func = std::make_shared<ScalarFunction>(name, Arity::Unary(), doc); - ArrayKernelExec exec_32 = Transformer<StringType>::Exec; - ArrayKernelExec exec_64 = Transformer<LargeStringType>::Exec; - DCHECK_OK(func->AddKernel({utf8()}, utf8(), exec_32)); - DCHECK_OK(func->AddKernel({large_utf8()}, large_utf8(), exec_64)); - DCHECK_OK(registry->AddFunction(std::move(func))); -} - -#endif - + ArrayKernelExec exec_32 = Transformer<StringType>::Exec; + ArrayKernelExec exec_64 = Transformer<LargeStringType>::Exec; + DCHECK_OK(func->AddKernel({utf8()}, utf8(), exec_32)); + DCHECK_OK(func->AddKernel({large_utf8()}, large_utf8(), exec_64)); + DCHECK_OK(registry->AddFunction(std::move(func))); +} + +#endif + // NOTE: Predicate should only populate 'status' with errors, // leave it unmodified to indicate Status::OK() using StringPredicate = std::function<bool(KernelContext*, const uint8_t*, size_t, Status*)>; - -template <typename Type> + +template <typename Type> Status ApplyPredicate(KernelContext* ctx, const ExecBatch& batch, StringPredicate predicate, Datum* out) { Status st = Status::OK(); - EnsureLookupTablesFilled(); - if (batch[0].kind() == Datum::ARRAY) { - const ArrayData& input = *batch[0].array(); - ArrayIterator<Type> input_it(input); - ArrayData* out_arr = out->mutable_array(); - ::arrow::internal::GenerateBitsUnrolled( - out_arr->buffers[1]->mutable_data(), out_arr->offset, input.length, - [&]() -> bool { - util::string_view val = input_it(); + EnsureLookupTablesFilled(); + if (batch[0].kind() == Datum::ARRAY) { + const ArrayData& input = *batch[0].array(); + ArrayIterator<Type> input_it(input); + ArrayData* out_arr = out->mutable_array(); + ::arrow::internal::GenerateBitsUnrolled( + out_arr->buffers[1]->mutable_data(), out_arr->offset, input.length, + [&]() -> bool { + util::string_view val = input_it(); return predicate(ctx, reinterpret_cast<const uint8_t*>(val.data()), val.size(), &st); - }); - } else { - const auto& input = checked_cast<const BaseBinaryScalar&>(*batch[0].scalar()); - if (input.is_valid) { + }); + } else { + const auto& input = checked_cast<const BaseBinaryScalar&>(*batch[0].scalar()); + if (input.is_valid) { bool boolean_result = predicate(ctx, input.value->data(), static_cast<size_t>(input.value->size()), &st); // UTF decoding can lead to issues if (st.ok()) { out->value = std::make_shared<BooleanScalar>(boolean_result); - } - } - } + } + } + } return st; -} - -template <typename Predicate> +} + +template <typename Predicate> void AddUnaryStringPredicate(std::string name, FunctionRegistry* registry, const FunctionDoc* doc) { auto func = std::make_shared<ScalarFunction>(name, Arity::Unary(), doc); - auto exec_32 = [](KernelContext* ctx, const ExecBatch& batch, Datum* out) { + auto exec_32 = [](KernelContext* ctx, const ExecBatch& batch, Datum* out) { return ApplyPredicate<StringType>(ctx, batch, Predicate::Call, out); - }; - auto exec_64 = [](KernelContext* ctx, const ExecBatch& batch, Datum* out) { + }; + auto exec_64 = [](KernelContext* ctx, const ExecBatch& batch, Datum* out) { return ApplyPredicate<LargeStringType>(ctx, batch, Predicate::Call, out); - }; - DCHECK_OK(func->AddKernel({utf8()}, boolean(), std::move(exec_32))); - DCHECK_OK(func->AddKernel({large_utf8()}, boolean(), std::move(exec_64))); - DCHECK_OK(registry->AddFunction(std::move(func))); -} - + }; + DCHECK_OK(func->AddKernel({utf8()}, boolean(), std::move(exec_32))); + DCHECK_OK(func->AddKernel({large_utf8()}, boolean(), std::move(exec_64))); + DCHECK_OK(registry->AddFunction(std::move(func))); +} + FunctionDoc StringPredicateDoc(std::string summary, std::string description) { return FunctionDoc{std::move(summary), std::move(description), {"strings"}}; } @@ -4041,9 +4041,9 @@ const FunctionDoc utf8_reverse_doc( "composed of multiple codepoints."), {"strings"}); -} // namespace - -void RegisterScalarStringAscii(FunctionRegistry* registry) { +} // namespace + +void RegisterScalarStringAscii(FunctionRegistry* registry) { // ascii_upper and ascii_lower are able to reuse the original offsets buffer, // so don't preallocate them in the output. MakeUnaryStringBatchKernel<AsciiUpper>("ascii_upper", registry, &ascii_upper_doc, @@ -4058,7 +4058,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { &ascii_rtrim_whitespace_doc); MakeUnaryStringBatchKernel<AsciiReverse>("ascii_reverse", registry, &ascii_reverse_doc); MakeUnaryStringBatchKernel<Utf8Reverse>("utf8_reverse", registry, &utf8_reverse_doc); - + MakeUnaryStringBatchKernelWithState<AsciiCenter>("ascii_center", registry, &ascii_center_doc); MakeUnaryStringBatchKernelWithState<AsciiLPad>("ascii_lpad", registry, &ascii_lpad_doc); @@ -4067,7 +4067,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { &utf8_center_doc); MakeUnaryStringBatchKernelWithState<Utf8LPad>("utf8_lpad", registry, &utf8_lpad_doc); MakeUnaryStringBatchKernelWithState<Utf8RPad>("utf8_rpad", registry, &utf8_rpad_doc); - + MakeUnaryStringBatchKernelWithState<AsciiTrim>("ascii_trim", registry, &ascii_trim_doc); MakeUnaryStringBatchKernelWithState<AsciiLTrim>("ascii_ltrim", registry, &ascii_ltrim_doc); @@ -4081,16 +4081,16 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { AddUnaryStringPredicate<IsAlphaAscii>("ascii_is_alpha", registry, &ascii_is_alpha_doc); AddUnaryStringPredicate<IsDecimalAscii>("ascii_is_decimal", registry, &ascii_is_decimal_doc); - // no is_digit for ascii, since it is the same as is_decimal + // no is_digit for ascii, since it is the same as is_decimal AddUnaryStringPredicate<IsLowerAscii>("ascii_is_lower", registry, &ascii_is_lower_doc); - // no is_numeric for ascii, since it is the same as is_decimal + // no is_numeric for ascii, since it is the same as is_decimal AddUnaryStringPredicate<IsPrintableAscii>("ascii_is_printable", registry, &ascii_is_printable_doc); AddUnaryStringPredicate<IsSpaceAscii>("ascii_is_space", registry, &ascii_is_space_doc); AddUnaryStringPredicate<IsTitleAscii>("ascii_is_title", registry, &ascii_is_title_doc); AddUnaryStringPredicate<IsUpperAscii>("ascii_is_upper", registry, &ascii_is_upper_doc); - -#ifdef ARROW_WITH_UTF8PROC + +#ifdef ARROW_WITH_UTF8PROC MakeUnaryStringUTF8TransformKernel<UTF8Upper>("utf8_upper", registry, &utf8_upper_doc); MakeUnaryStringUTF8TransformKernel<UTF8Lower>("utf8_lower", registry, &utf8_lower_doc); MakeUnaryStringBatchKernel<UTF8TrimWhitespace>("utf8_trim_whitespace", registry, @@ -4102,7 +4102,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { MakeUnaryStringBatchKernelWithState<UTF8Trim>("utf8_trim", registry, &utf8_trim_doc); MakeUnaryStringBatchKernelWithState<UTF8LTrim>("utf8_ltrim", registry, &utf8_ltrim_doc); MakeUnaryStringBatchKernelWithState<UTF8RTrim>("utf8_rtrim", registry, &utf8_rtrim_doc); - + AddUnaryStringPredicate<IsAlphaNumericUnicode>("utf8_is_alnum", registry, &utf8_is_alnum_doc); AddUnaryStringPredicate<IsAlphaUnicode>("utf8_is_alpha", registry, &utf8_is_alpha_doc); @@ -4117,11 +4117,11 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { AddUnaryStringPredicate<IsSpaceUnicode>("utf8_is_space", registry, &utf8_is_space_doc); AddUnaryStringPredicate<IsTitleUnicode>("utf8_is_title", registry, &utf8_is_title_doc); AddUnaryStringPredicate<IsUpperUnicode>("utf8_is_upper", registry, &utf8_is_upper_doc); -#endif - - AddBinaryLength(registry); +#endif + + AddBinaryLength(registry); AddUtf8Length(registry); - AddMatchSubstring(registry); + AddMatchSubstring(registry); AddFindSubstring(registry); AddCountSubstring(registry); MakeUnaryStringBatchKernelWithState<ReplaceSubStringPlain>( @@ -4136,10 +4136,10 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { AddReplaceSlice(registry); AddSlice(registry); AddSplit(registry); - AddStrptime(registry); + AddStrptime(registry); AddBinaryJoin(registry); -} - -} // namespace internal -} // namespace compute -} // namespace arrow +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_validity.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_validity.cc index ead88abc0f..befb116348 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_validity.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/scalar_validity.cc @@ -1,65 +1,65 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + #include <cmath> -#include "arrow/compute/kernels/common.h" - -#include "arrow/util/bit_util.h" -#include "arrow/util/bitmap_ops.h" - -namespace arrow { - -using internal::CopyBitmap; -using internal::InvertBitmap; - -namespace compute { -namespace internal { -namespace { - -struct IsValidOperator { +#include "arrow/compute/kernels/common.h" + +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" + +namespace arrow { + +using internal::CopyBitmap; +using internal::InvertBitmap; + +namespace compute { +namespace internal { +namespace { + +struct IsValidOperator { static Status Call(KernelContext* ctx, const Scalar& in, Scalar* out) { - checked_cast<BooleanScalar*>(out)->value = in.is_valid; + checked_cast<BooleanScalar*>(out)->value = in.is_valid; return Status::OK(); - } - + } + static Status Call(KernelContext* ctx, const ArrayData& arr, ArrayData* out) { - DCHECK_EQ(out->offset, 0); - DCHECK_LE(out->length, arr.length); - if (arr.MayHaveNulls()) { - // Input has nulls => output is the null (validity) bitmap. - // To avoid copying the null bitmap, slice from the starting byte offset - // and set the offset to the remaining bit offset. - out->offset = arr.offset % 8; - out->buffers[1] = - arr.offset == 0 ? arr.buffers[0] - : SliceBuffer(arr.buffers[0], arr.offset / 8, - BitUtil::BytesForBits(out->length + out->offset)); + DCHECK_EQ(out->offset, 0); + DCHECK_LE(out->length, arr.length); + if (arr.MayHaveNulls()) { + // Input has nulls => output is the null (validity) bitmap. + // To avoid copying the null bitmap, slice from the starting byte offset + // and set the offset to the remaining bit offset. + out->offset = arr.offset % 8; + out->buffers[1] = + arr.offset == 0 ? arr.buffers[0] + : SliceBuffer(arr.buffers[0], arr.offset / 8, + BitUtil::BytesForBits(out->length + out->offset)); return Status::OK(); - } - - // Input has no nulls => output is entirely true. + } + + // Input has no nulls => output is entirely true. ARROW_ASSIGN_OR_RAISE(out->buffers[1], ctx->AllocateBitmap(out->length + out->offset)); - BitUtil::SetBitsTo(out->buffers[1]->mutable_data(), out->offset, out->length, true); + BitUtil::SetBitsTo(out->buffers[1]->mutable_data(), out->offset, out->length, true); return Status::OK(); - } -}; - + } +}; + struct IsFiniteOperator { template <typename OutType, typename InType> static constexpr OutType Call(KernelContext*, const InType& value, Status*) { @@ -74,49 +74,49 @@ struct IsInfOperator { } }; -struct IsNullOperator { +struct IsNullOperator { static Status Call(KernelContext* ctx, const Scalar& in, Scalar* out) { - checked_cast<BooleanScalar*>(out)->value = !in.is_valid; + checked_cast<BooleanScalar*>(out)->value = !in.is_valid; return Status::OK(); - } - + } + static Status Call(KernelContext* ctx, const ArrayData& arr, ArrayData* out) { - if (arr.MayHaveNulls()) { - // Input has nulls => output is the inverted null (validity) bitmap. - InvertBitmap(arr.buffers[0]->data(), arr.offset, arr.length, - out->buffers[1]->mutable_data(), out->offset); + if (arr.MayHaveNulls()) { + // Input has nulls => output is the inverted null (validity) bitmap. + InvertBitmap(arr.buffers[0]->data(), arr.offset, arr.length, + out->buffers[1]->mutable_data(), out->offset); } else { // Input has no nulls => output is entirely false. BitUtil::SetBitsTo(out->buffers[1]->mutable_data(), out->offset, out->length, false); - } + } return Status::OK(); } }; - + struct IsNanOperator { template <typename OutType, typename InType> static constexpr OutType Call(KernelContext*, const InType& value, Status*) { return std::isnan(value); - } -}; - + } +}; + void MakeFunction(std::string name, const FunctionDoc* doc, std::vector<InputType> in_types, OutputType out_type, - ArrayKernelExec exec, FunctionRegistry* registry, - MemAllocation::type mem_allocation, bool can_write_into_slices) { - Arity arity{static_cast<int>(in_types.size())}; + ArrayKernelExec exec, FunctionRegistry* registry, + MemAllocation::type mem_allocation, bool can_write_into_slices) { + Arity arity{static_cast<int>(in_types.size())}; auto func = std::make_shared<ScalarFunction>(name, arity, doc); - - ScalarKernel kernel(std::move(in_types), out_type, exec); - kernel.null_handling = NullHandling::OUTPUT_NOT_NULL; - kernel.can_write_into_slices = can_write_into_slices; - kernel.mem_allocation = mem_allocation; - - DCHECK_OK(func->AddKernel(std::move(kernel))); - DCHECK_OK(registry->AddFunction(std::move(func))); -} - + + ScalarKernel kernel(std::move(in_types), out_type, exec); + kernel.null_handling = NullHandling::OUTPUT_NOT_NULL; + kernel.can_write_into_slices = can_write_into_slices; + kernel.mem_allocation = mem_allocation; + + DCHECK_OK(func->AddKernel(std::move(kernel))); + DCHECK_OK(registry->AddFunction(std::move(func))); +} + template <typename InType, typename Op> void AddFloatValidityKernel(const std::shared_ptr<DataType>& ty, ScalarFunction* func) { DCHECK_OK(func->AddKernel({ty}, boolean(), @@ -154,40 +154,40 @@ std::shared_ptr<ScalarFunction> MakeIsNanFunction(std::string name, } Status IsValidExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const Datum& arg0 = batch[0]; - if (arg0.type()->id() == Type::NA) { - auto false_value = std::make_shared<BooleanScalar>(false); - if (arg0.kind() == Datum::SCALAR) { + const Datum& arg0 = batch[0]; + if (arg0.type()->id() == Type::NA) { + auto false_value = std::make_shared<BooleanScalar>(false); + if (arg0.kind() == Datum::SCALAR) { out->value = false_value; - } else { - std::shared_ptr<Array> false_values; + } else { + std::shared_ptr<Array> false_values; RETURN_NOT_OK(MakeArrayFromScalar(*false_value, out->length(), ctx->memory_pool()) .Value(&false_values)); - out->value = false_values->data(); - } + out->value = false_values->data(); + } return Status::OK(); - } else { + } else { return applicator::SimpleUnary<IsValidOperator>(ctx, batch, out); - } -} - + } +} + Status IsNullExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const Datum& arg0 = batch[0]; - if (arg0.type()->id() == Type::NA) { - if (arg0.kind() == Datum::SCALAR) { + const Datum& arg0 = batch[0]; + if (arg0.type()->id() == Type::NA) { + if (arg0.kind() == Datum::SCALAR) { out->value = std::make_shared<BooleanScalar>(true); - } else { - // Data is preallocated - ArrayData* out_arr = out->mutable_array(); - BitUtil::SetBitsTo(out_arr->buffers[1]->mutable_data(), out_arr->offset, - out_arr->length, true); - } + } else { + // Data is preallocated + ArrayData* out_arr = out->mutable_array(); + BitUtil::SetBitsTo(out_arr->buffers[1]->mutable_data(), out_arr->offset, + out_arr->length, true); + } return Status::OK(); - } else { + } else { return applicator::SimpleUnary<IsNullOperator>(ctx, batch, out); - } -} - + } +} + const FunctionDoc is_valid_doc( "Return true if non-null", ("For each input value, emit true iff the value is valid (non-null)."), {"values"}); @@ -210,21 +210,21 @@ const FunctionDoc is_nan_doc("Return true if NaN", ("For each input value, emit true iff the value is NaN."), {"values"}); -} // namespace - -void RegisterScalarValidity(FunctionRegistry* registry) { +} // namespace + +void RegisterScalarValidity(FunctionRegistry* registry) { MakeFunction("is_valid", &is_valid_doc, {ValueDescr::ANY}, boolean(), IsValidExec, registry, MemAllocation::NO_PREALLOCATE, /*can_write_into_slices=*/false); - + MakeFunction("is_null", &is_null_doc, {ValueDescr::ANY}, boolean(), IsNullExec, registry, MemAllocation::PREALLOCATE, - /*can_write_into_slices=*/true); + /*can_write_into_slices=*/true); DCHECK_OK(registry->AddFunction(MakeIsFiniteFunction("is_finite", &is_finite_doc))); DCHECK_OK(registry->AddFunction(MakeIsInfFunction("is_inf", &is_inf_doc))); DCHECK_OK(registry->AddFunction(MakeIsNanFunction("is_nan", &is_nan_doc))); -} - -} // namespace internal -} // namespace compute -} // namespace arrow +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/util_internal.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/util_internal.cc index 846fa26baf..df011f802c 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/util_internal.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/util_internal.cc @@ -1,62 +1,62 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/compute/kernels/util_internal.h" - -#include <cstdint> - -#include "arrow/array/data.h" -#include "arrow/type.h" -#include "arrow/util/checked_cast.h" - -namespace arrow { - -using internal::checked_cast; - -namespace compute { -namespace internal { - -const uint8_t* GetValidityBitmap(const ArrayData& data) { - const uint8_t* bitmap = nullptr; - if (data.buffers[0]) { - bitmap = data.buffers[0]->data(); - } - return bitmap; -} - -int GetBitWidth(const DataType& type) { - return checked_cast<const FixedWidthType&>(type).bit_width(); -} - -PrimitiveArg GetPrimitiveArg(const ArrayData& arr) { - PrimitiveArg arg; - arg.is_valid = GetValidityBitmap(arr); - arg.data = arr.buffers[1]->data(); - arg.bit_width = GetBitWidth(*arr.type); - arg.offset = arr.offset; - arg.length = arr.length; - if (arg.bit_width > 1) { - arg.data += arr.offset * arg.bit_width / 8; - } - // This may be kUnknownNullCount +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/kernels/util_internal.h" + +#include <cstdint> + +#include "arrow/array/data.h" +#include "arrow/type.h" +#include "arrow/util/checked_cast.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute { +namespace internal { + +const uint8_t* GetValidityBitmap(const ArrayData& data) { + const uint8_t* bitmap = nullptr; + if (data.buffers[0]) { + bitmap = data.buffers[0]->data(); + } + return bitmap; +} + +int GetBitWidth(const DataType& type) { + return checked_cast<const FixedWidthType&>(type).bit_width(); +} + +PrimitiveArg GetPrimitiveArg(const ArrayData& arr) { + PrimitiveArg arg; + arg.is_valid = GetValidityBitmap(arr); + arg.data = arr.buffers[1]->data(); + arg.bit_width = GetBitWidth(*arr.type); + arg.offset = arr.offset; + arg.length = arr.length; + if (arg.bit_width > 1) { + arg.data += arr.offset * arg.bit_width / 8; + } + // This may be kUnknownNullCount arg.null_count = (arg.is_valid != nullptr) ? arr.null_count.load() : 0; - return arg; -} - + return arg; +} + ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec, NullHandling::type null_handling) { return [=](KernelContext* ctx, const ExecBatch& batch, Datum* out) -> Status { @@ -77,6 +77,6 @@ ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec, }; } -} // namespace internal -} // namespace compute -} // namespace arrow +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/util_internal.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/util_internal.h index 394e08da58..03d7c0da2b 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/util_internal.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/util_internal.h @@ -1,35 +1,35 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <cstdint> +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstdint> #include <utility> - + #include "arrow/array/util.h" -#include "arrow/buffer.h" +#include "arrow/buffer.h" #include "arrow/compute/kernels/codegen_internal.h" #include "arrow/compute/type_fwd.h" #include "arrow/util/bit_run_reader.h" - -namespace arrow { -namespace compute { -namespace internal { - + +namespace arrow { +namespace compute { +namespace internal { + // Used in some kernels and testing - not provided by default in MSVC // and _USE_MATH_DEFINES is not reliable with unity builds #ifndef M_PI @@ -42,31 +42,31 @@ namespace internal { #define M_PI_4 0.785398163397448309616 #endif -// An internal data structure for unpacking a primitive argument to pass to a -// kernel implementation -struct PrimitiveArg { - const uint8_t* is_valid; - // If the bit_width is a multiple of 8 (i.e. not boolean), then "data" should - // be shifted by offset * (bit_width / 8). For bit-packed data, the offset - // must be used when indexing. - const uint8_t* data; - int bit_width; - int64_t length; - int64_t offset; - // This may be kUnknownNullCount if the null_count has not yet been computed, - // so use null_count != 0 to determine "may have nulls". - int64_t null_count; -}; - -// Get validity bitmap data or return nullptr if there is no validity buffer -const uint8_t* GetValidityBitmap(const ArrayData& data); - -int GetBitWidth(const DataType& type); - -// Reduce code size by dealing with the unboxing of the kernel inputs once -// rather than duplicating compiled code to do all these in each kernel. -PrimitiveArg GetPrimitiveArg(const ArrayData& arr); - +// An internal data structure for unpacking a primitive argument to pass to a +// kernel implementation +struct PrimitiveArg { + const uint8_t* is_valid; + // If the bit_width is a multiple of 8 (i.e. not boolean), then "data" should + // be shifted by offset * (bit_width / 8). For bit-packed data, the offset + // must be used when indexing. + const uint8_t* data; + int bit_width; + int64_t length; + int64_t offset; + // This may be kUnknownNullCount if the null_count has not yet been computed, + // so use null_count != 0 to determine "may have nulls". + int64_t null_count; +}; + +// Get validity bitmap data or return nullptr if there is no validity buffer +const uint8_t* GetValidityBitmap(const ArrayData& data); + +int GetBitWidth(const DataType& type); + +// Reduce code size by dealing with the unboxing of the kernel inputs once +// rather than duplicating compiled code to do all these in each kernel. +PrimitiveArg GetPrimitiveArg(const ArrayData& arr); + // Augment a unary ArrayKernelExec which supports only array-like inputs with support for // scalar inputs. Scalars will be transformed to 1-long arrays with the scalar's value (or // null if the scalar is null) as its only element. This 1-long array will be passed to @@ -161,6 +161,6 @@ int64_t CopyNonNullValues(const Datum& datum, T* out) { return n; } -} // namespace internal -} // namespace compute -} // namespace arrow +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/vector_hash.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/vector_hash.cc index a68e78130f..9c37f23faf 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/vector_hash.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/vector_hash.cc @@ -1,173 +1,173 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include <cstring> -#include <mutex> - -#include "arrow/array/array_base.h" -#include "arrow/array/array_dict.h" -#include "arrow/array/array_nested.h" -#include "arrow/array/builder_primitive.h" +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <cstring> +#include <mutex> + +#include "arrow/array/array_base.h" +#include "arrow/array/array_dict.h" +#include "arrow/array/array_nested.h" +#include "arrow/array/builder_primitive.h" #include "arrow/array/concatenate.h" -#include "arrow/array/dict_internal.h" -#include "arrow/array/util.h" -#include "arrow/compute/api_vector.h" -#include "arrow/compute/kernels/common.h" -#include "arrow/result.h" -#include "arrow/util/hashing.h" -#include "arrow/util/make_unique.h" - -namespace arrow { - -using internal::DictionaryTraits; -using internal::HashTraits; - -namespace compute { -namespace internal { - -namespace { - -class ActionBase { - public: - ActionBase(const std::shared_ptr<DataType>& type, MemoryPool* pool) - : type_(type), pool_(pool) {} - - protected: - std::shared_ptr<DataType> type_; - MemoryPool* pool_; -}; - -// ---------------------------------------------------------------------- -// Unique - -class UniqueAction final : public ActionBase { - public: - using ActionBase::ActionBase; - - static constexpr bool with_error_status = false; - +#include "arrow/array/dict_internal.h" +#include "arrow/array/util.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/kernels/common.h" +#include "arrow/result.h" +#include "arrow/util/hashing.h" +#include "arrow/util/make_unique.h" + +namespace arrow { + +using internal::DictionaryTraits; +using internal::HashTraits; + +namespace compute { +namespace internal { + +namespace { + +class ActionBase { + public: + ActionBase(const std::shared_ptr<DataType>& type, MemoryPool* pool) + : type_(type), pool_(pool) {} + + protected: + std::shared_ptr<DataType> type_; + MemoryPool* pool_; +}; + +// ---------------------------------------------------------------------- +// Unique + +class UniqueAction final : public ActionBase { + public: + using ActionBase::ActionBase; + + static constexpr bool with_error_status = false; + UniqueAction(const std::shared_ptr<DataType>& type, const FunctionOptions* options, MemoryPool* pool) : ActionBase(type, pool) {} - Status Reset() { return Status::OK(); } - - Status Reserve(const int64_t length) { return Status::OK(); } - - template <class Index> - void ObserveNullFound(Index index) {} - - template <class Index> - void ObserveNullNotFound(Index index) {} - - template <class Index> - void ObserveFound(Index index) {} - - template <class Index> - void ObserveNotFound(Index index) {} - + Status Reset() { return Status::OK(); } + + Status Reserve(const int64_t length) { return Status::OK(); } + + template <class Index> + void ObserveNullFound(Index index) {} + + template <class Index> + void ObserveNullNotFound(Index index) {} + + template <class Index> + void ObserveFound(Index index) {} + + template <class Index> + void ObserveNotFound(Index index) {} + bool ShouldEncodeNulls() { return true; } - Status Flush(Datum* out) { return Status::OK(); } - - Status FlushFinal(Datum* out) { return Status::OK(); } -}; - -// ---------------------------------------------------------------------- -// Count values - -class ValueCountsAction final : ActionBase { - public: - using ActionBase::ActionBase; - - static constexpr bool with_error_status = true; - + Status Flush(Datum* out) { return Status::OK(); } + + Status FlushFinal(Datum* out) { return Status::OK(); } +}; + +// ---------------------------------------------------------------------- +// Count values + +class ValueCountsAction final : ActionBase { + public: + using ActionBase::ActionBase; + + static constexpr bool with_error_status = true; + ValueCountsAction(const std::shared_ptr<DataType>& type, const FunctionOptions* options, MemoryPool* pool) - : ActionBase(type, pool), count_builder_(pool) {} - - Status Reserve(const int64_t length) { - // builder size is independent of input array size. - return Status::OK(); - } - - Status Reset() { - count_builder_.Reset(); - return Status::OK(); - } - - // Don't do anything on flush because we don't want to finalize the builder - // or incur the cost of memory copies. - Status Flush(Datum* out) { return Status::OK(); } - - // Return the counts corresponding the MemoTable keys. - Status FlushFinal(Datum* out) { - std::shared_ptr<ArrayData> result; - RETURN_NOT_OK(count_builder_.FinishInternal(&result)); - out->value = std::move(result); - return Status::OK(); - } - - template <class Index> - void ObserveNullFound(Index index) { - count_builder_[index]++; - } - - template <class Index> - void ObserveNullNotFound(Index index) { - ARROW_LOG(FATAL) << "ObserveNullNotFound without err_status should not be called"; - } - - template <class Index> - void ObserveNullNotFound(Index index, Status* status) { - Status s = count_builder_.Append(1); - if (ARROW_PREDICT_FALSE(!s.ok())) { - *status = s; - } - } - - template <class Index> - void ObserveFound(Index slot) { - count_builder_[slot]++; - } - - template <class Index> - void ObserveNotFound(Index slot, Status* status) { - Status s = count_builder_.Append(1); - if (ARROW_PREDICT_FALSE(!s.ok())) { - *status = s; - } - } - + : ActionBase(type, pool), count_builder_(pool) {} + + Status Reserve(const int64_t length) { + // builder size is independent of input array size. + return Status::OK(); + } + + Status Reset() { + count_builder_.Reset(); + return Status::OK(); + } + + // Don't do anything on flush because we don't want to finalize the builder + // or incur the cost of memory copies. + Status Flush(Datum* out) { return Status::OK(); } + + // Return the counts corresponding the MemoTable keys. + Status FlushFinal(Datum* out) { + std::shared_ptr<ArrayData> result; + RETURN_NOT_OK(count_builder_.FinishInternal(&result)); + out->value = std::move(result); + return Status::OK(); + } + + template <class Index> + void ObserveNullFound(Index index) { + count_builder_[index]++; + } + + template <class Index> + void ObserveNullNotFound(Index index) { + ARROW_LOG(FATAL) << "ObserveNullNotFound without err_status should not be called"; + } + + template <class Index> + void ObserveNullNotFound(Index index, Status* status) { + Status s = count_builder_.Append(1); + if (ARROW_PREDICT_FALSE(!s.ok())) { + *status = s; + } + } + + template <class Index> + void ObserveFound(Index slot) { + count_builder_[slot]++; + } + + template <class Index> + void ObserveNotFound(Index slot, Status* status) { + Status s = count_builder_.Append(1); + if (ARROW_PREDICT_FALSE(!s.ok())) { + *status = s; + } + } + bool ShouldEncodeNulls() const { return true; } - private: - Int64Builder count_builder_; -}; - -// ---------------------------------------------------------------------- -// Dictionary encode implementation - -class DictEncodeAction final : public ActionBase { - public: - using ActionBase::ActionBase; - - static constexpr bool with_error_status = false; - + private: + Int64Builder count_builder_; +}; + +// ---------------------------------------------------------------------- +// Dictionary encode implementation + +class DictEncodeAction final : public ActionBase { + public: + using ActionBase::ActionBase; + + static constexpr bool with_error_status = false; + DictEncodeAction(const std::shared_ptr<DataType>& type, const FunctionOptions* options, MemoryPool* pool) : ActionBase(type, pool), indices_builder_(pool) { @@ -175,174 +175,174 @@ class DictEncodeAction final : public ActionBase { encode_options_ = *options_ptr; } } - - Status Reset() { - indices_builder_.Reset(); - return Status::OK(); - } - - Status Reserve(const int64_t length) { return indices_builder_.Reserve(length); } - - template <class Index> - void ObserveNullFound(Index index) { + + Status Reset() { + indices_builder_.Reset(); + return Status::OK(); + } + + Status Reserve(const int64_t length) { return indices_builder_.Reserve(length); } + + template <class Index> + void ObserveNullFound(Index index) { if (encode_options_.null_encoding_behavior == DictionaryEncodeOptions::MASK) { indices_builder_.UnsafeAppendNull(); } else { indices_builder_.UnsafeAppend(index); } - } - - template <class Index> - void ObserveNullNotFound(Index index) { + } + + template <class Index> + void ObserveNullNotFound(Index index) { ObserveNullFound(index); - } - - template <class Index> - void ObserveFound(Index index) { - indices_builder_.UnsafeAppend(index); - } - - template <class Index> - void ObserveNotFound(Index index) { - ObserveFound(index); - } - + } + + template <class Index> + void ObserveFound(Index index) { + indices_builder_.UnsafeAppend(index); + } + + template <class Index> + void ObserveNotFound(Index index) { + ObserveFound(index); + } + bool ShouldEncodeNulls() { return encode_options_.null_encoding_behavior == DictionaryEncodeOptions::ENCODE; } - Status Flush(Datum* out) { - std::shared_ptr<ArrayData> result; - RETURN_NOT_OK(indices_builder_.FinishInternal(&result)); - out->value = std::move(result); - return Status::OK(); - } - - Status FlushFinal(Datum* out) { return Status::OK(); } - - private: - Int32Builder indices_builder_; + Status Flush(Datum* out) { + std::shared_ptr<ArrayData> result; + RETURN_NOT_OK(indices_builder_.FinishInternal(&result)); + out->value = std::move(result); + return Status::OK(); + } + + Status FlushFinal(Datum* out) { return Status::OK(); } + + private: + Int32Builder indices_builder_; DictionaryEncodeOptions encode_options_; -}; - -class HashKernel : public KernelState { - public: +}; + +class HashKernel : public KernelState { + public: HashKernel() : options_(nullptr) {} explicit HashKernel(const FunctionOptions* options) : options_(options) {} - // Reset for another run. - virtual Status Reset() = 0; - - // Flush out accumulated results from the last invocation of Call. - virtual Status Flush(Datum* out) = 0; - // Flush out accumulated results across all invocations of Call. The kernel - // should not be used until after Reset() is called. - virtual Status FlushFinal(Datum* out) = 0; - // Get the values (keys) accumulated in the dictionary so far. - virtual Status GetDictionary(std::shared_ptr<ArrayData>* out) = 0; - - virtual std::shared_ptr<DataType> value_type() const = 0; - - Status Append(KernelContext* ctx, const ArrayData& input) { - std::lock_guard<std::mutex> guard(lock_); - return Append(input); - } - - // Prepare the Action for the given input (e.g. reserve appropriately sized - // data structures) and visit the given input with Action. - virtual Status Append(const ArrayData& arr) = 0; - - protected: + // Reset for another run. + virtual Status Reset() = 0; + + // Flush out accumulated results from the last invocation of Call. + virtual Status Flush(Datum* out) = 0; + // Flush out accumulated results across all invocations of Call. The kernel + // should not be used until after Reset() is called. + virtual Status FlushFinal(Datum* out) = 0; + // Get the values (keys) accumulated in the dictionary so far. + virtual Status GetDictionary(std::shared_ptr<ArrayData>* out) = 0; + + virtual std::shared_ptr<DataType> value_type() const = 0; + + Status Append(KernelContext* ctx, const ArrayData& input) { + std::lock_guard<std::mutex> guard(lock_); + return Append(input); + } + + // Prepare the Action for the given input (e.g. reserve appropriately sized + // data structures) and visit the given input with Action. + virtual Status Append(const ArrayData& arr) = 0; + + protected: const FunctionOptions* options_; - std::mutex lock_; -}; - -// ---------------------------------------------------------------------- -// Base class for all "regular" hash kernel implementations -// (NullType has a separate implementation) - -template <typename Type, typename Scalar, typename Action, + std::mutex lock_; +}; + +// ---------------------------------------------------------------------- +// Base class for all "regular" hash kernel implementations +// (NullType has a separate implementation) + +template <typename Type, typename Scalar, typename Action, bool with_error_status = Action::with_error_status> -class RegularHashKernel : public HashKernel { - public: +class RegularHashKernel : public HashKernel { + public: RegularHashKernel(const std::shared_ptr<DataType>& type, const FunctionOptions* options, MemoryPool* pool) : HashKernel(options), pool_(pool), type_(type), action_(type, options, pool) {} - - Status Reset() override { - memo_table_.reset(new MemoTable(pool_, 0)); - return action_.Reset(); - } - - Status Append(const ArrayData& arr) override { - RETURN_NOT_OK(action_.Reserve(arr.length)); - return DoAppend(arr); - } - - Status Flush(Datum* out) override { return action_.Flush(out); } - - Status FlushFinal(Datum* out) override { return action_.FlushFinal(out); } - - Status GetDictionary(std::shared_ptr<ArrayData>* out) override { - return DictionaryTraits<Type>::GetDictionaryArrayData(pool_, type_, *memo_table_, - 0 /* start_offset */, out); - } - - std::shared_ptr<DataType> value_type() const override { return type_; } - - template <bool HasError = with_error_status> - enable_if_t<!HasError, Status> DoAppend(const ArrayData& arr) { - return VisitArrayDataInline<Type>( - arr, - [this](Scalar v) { - auto on_found = [this](int32_t memo_index) { - action_.ObserveFound(memo_index); - }; - auto on_not_found = [this](int32_t memo_index) { - action_.ObserveNotFound(memo_index); - }; - - int32_t unused_memo_index; - return memo_table_->GetOrInsert(v, std::move(on_found), std::move(on_not_found), - &unused_memo_index); - }, - [this]() { + + Status Reset() override { + memo_table_.reset(new MemoTable(pool_, 0)); + return action_.Reset(); + } + + Status Append(const ArrayData& arr) override { + RETURN_NOT_OK(action_.Reserve(arr.length)); + return DoAppend(arr); + } + + Status Flush(Datum* out) override { return action_.Flush(out); } + + Status FlushFinal(Datum* out) override { return action_.FlushFinal(out); } + + Status GetDictionary(std::shared_ptr<ArrayData>* out) override { + return DictionaryTraits<Type>::GetDictionaryArrayData(pool_, type_, *memo_table_, + 0 /* start_offset */, out); + } + + std::shared_ptr<DataType> value_type() const override { return type_; } + + template <bool HasError = with_error_status> + enable_if_t<!HasError, Status> DoAppend(const ArrayData& arr) { + return VisitArrayDataInline<Type>( + arr, + [this](Scalar v) { + auto on_found = [this](int32_t memo_index) { + action_.ObserveFound(memo_index); + }; + auto on_not_found = [this](int32_t memo_index) { + action_.ObserveNotFound(memo_index); + }; + + int32_t unused_memo_index; + return memo_table_->GetOrInsert(v, std::move(on_found), std::move(on_not_found), + &unused_memo_index); + }, + [this]() { if (action_.ShouldEncodeNulls()) { - auto on_found = [this](int32_t memo_index) { - action_.ObserveNullFound(memo_index); - }; - auto on_not_found = [this](int32_t memo_index) { - action_.ObserveNullNotFound(memo_index); - }; - memo_table_->GetOrInsertNull(std::move(on_found), std::move(on_not_found)); - } else { - action_.ObserveNullNotFound(-1); - } - return Status::OK(); - }); - } - - template <bool HasError = with_error_status> - enable_if_t<HasError, Status> DoAppend(const ArrayData& arr) { - return VisitArrayDataInline<Type>( - arr, - [this](Scalar v) { - Status s = Status::OK(); - auto on_found = [this](int32_t memo_index) { - action_.ObserveFound(memo_index); - }; - auto on_not_found = [this, &s](int32_t memo_index) { - action_.ObserveNotFound(memo_index, &s); - }; - - int32_t unused_memo_index; - RETURN_NOT_OK(memo_table_->GetOrInsert( - v, std::move(on_found), std::move(on_not_found), &unused_memo_index)); - return s; - }, - [this]() { - // Null - Status s = Status::OK(); + auto on_found = [this](int32_t memo_index) { + action_.ObserveNullFound(memo_index); + }; + auto on_not_found = [this](int32_t memo_index) { + action_.ObserveNullNotFound(memo_index); + }; + memo_table_->GetOrInsertNull(std::move(on_found), std::move(on_not_found)); + } else { + action_.ObserveNullNotFound(-1); + } + return Status::OK(); + }); + } + + template <bool HasError = with_error_status> + enable_if_t<HasError, Status> DoAppend(const ArrayData& arr) { + return VisitArrayDataInline<Type>( + arr, + [this](Scalar v) { + Status s = Status::OK(); + auto on_found = [this](int32_t memo_index) { + action_.ObserveFound(memo_index); + }; + auto on_not_found = [this, &s](int32_t memo_index) { + action_.ObserveNotFound(memo_index, &s); + }; + + int32_t unused_memo_index; + RETURN_NOT_OK(memo_table_->GetOrInsert( + v, std::move(on_found), std::move(on_not_found), &unused_memo_index)); + return s; + }, + [this]() { + // Null + Status s = Status::OK(); auto on_found = [this](int32_t memo_index) { action_.ObserveNullFound(memo_index); }; @@ -350,49 +350,49 @@ class RegularHashKernel : public HashKernel { action_.ObserveNullNotFound(memo_index, &s); }; if (action_.ShouldEncodeNulls()) { - memo_table_->GetOrInsertNull(std::move(on_found), std::move(on_not_found)); - } - return s; - }); - } - - protected: - using MemoTable = typename HashTraits<Type>::MemoTableType; - - MemoryPool* pool_; - std::shared_ptr<DataType> type_; - Action action_; - std::unique_ptr<MemoTable> memo_table_; -}; - -// ---------------------------------------------------------------------- -// Hash kernel implementation for nulls - + memo_table_->GetOrInsertNull(std::move(on_found), std::move(on_not_found)); + } + return s; + }); + } + + protected: + using MemoTable = typename HashTraits<Type>::MemoTableType; + + MemoryPool* pool_; + std::shared_ptr<DataType> type_; + Action action_; + std::unique_ptr<MemoTable> memo_table_; +}; + +// ---------------------------------------------------------------------- +// Hash kernel implementation for nulls + template <typename Action, bool with_error_status = Action::with_error_status> -class NullHashKernel : public HashKernel { - public: +class NullHashKernel : public HashKernel { + public: NullHashKernel(const std::shared_ptr<DataType>& type, const FunctionOptions* options, MemoryPool* pool) : pool_(pool), type_(type), action_(type, options, pool) {} - - Status Reset() override { return action_.Reset(); } - + + Status Reset() override { return action_.Reset(); } + Status Append(const ArrayData& arr) override { return DoAppend(arr); } template <bool HasError = with_error_status> enable_if_t<!HasError, Status> DoAppend(const ArrayData& arr) { - RETURN_NOT_OK(action_.Reserve(arr.length)); - for (int64_t i = 0; i < arr.length; ++i) { - if (i == 0) { + RETURN_NOT_OK(action_.Reserve(arr.length)); + for (int64_t i = 0; i < arr.length; ++i) { + if (i == 0) { seen_null_ = true; - action_.ObserveNullNotFound(0); - } else { - action_.ObserveNullFound(0); - } - } - return Status::OK(); - } - + action_.ObserveNullNotFound(0); + } else { + action_.ObserveNullFound(0); + } + } + return Status::OK(); + } + template <bool HasError = with_error_status> enable_if_t<HasError, Status> DoAppend(const ArrayData& arr) { Status s = Status::OK(); @@ -408,41 +408,41 @@ class NullHashKernel : public HashKernel { return s; } - Status Flush(Datum* out) override { return action_.Flush(out); } - Status FlushFinal(Datum* out) override { return action_.FlushFinal(out); } - - Status GetDictionary(std::shared_ptr<ArrayData>* out) override { + Status Flush(Datum* out) override { return action_.Flush(out); } + Status FlushFinal(Datum* out) override { return action_.FlushFinal(out); } + + Status GetDictionary(std::shared_ptr<ArrayData>* out) override { std::shared_ptr<NullArray> null_array; if (seen_null_) { null_array = std::make_shared<NullArray>(1); } else { null_array = std::make_shared<NullArray>(0); } - *out = null_array->data(); - return Status::OK(); - } - - std::shared_ptr<DataType> value_type() const override { return type_; } - - protected: - MemoryPool* pool_; - std::shared_ptr<DataType> type_; + *out = null_array->data(); + return Status::OK(); + } + + std::shared_ptr<DataType> value_type() const override { return type_; } + + protected: + MemoryPool* pool_; + std::shared_ptr<DataType> type_; bool seen_null_ = false; - Action action_; -}; - -// ---------------------------------------------------------------------- -// Hashing for dictionary type - -class DictionaryHashKernel : public HashKernel { - public: - explicit DictionaryHashKernel(std::unique_ptr<HashKernel> indices_kernel) - : indices_kernel_(std::move(indices_kernel)) {} - - Status Reset() override { return indices_kernel_->Reset(); } - + Action action_; +}; + +// ---------------------------------------------------------------------- +// Hashing for dictionary type + +class DictionaryHashKernel : public HashKernel { + public: + explicit DictionaryHashKernel(std::unique_ptr<HashKernel> indices_kernel) + : indices_kernel_(std::move(indices_kernel)) {} + + Status Reset() override { return indices_kernel_->Reset(); } + Status Append(const ArrayData& arr) override { - if (!dictionary_) { + if (!dictionary_) { dictionary_ = arr.dictionary; } else if (!MakeArray(dictionary_)->Equals(*MakeArray(arr.dictionary))) { // NOTE: This approach computes a new dictionary unification per chunk. @@ -468,238 +468,238 @@ class DictionaryHashKernel : public HashKernel { auto tmp, arrow::internal::checked_cast<const DictionaryArray&>(*in_dict_array) .Transpose(arr.type, out_dict, transpose)); return indices_kernel_->Append(*tmp->data()); - } - - return indices_kernel_->Append(arr); - } - - Status Flush(Datum* out) override { return indices_kernel_->Flush(out); } - - Status FlushFinal(Datum* out) override { return indices_kernel_->FlushFinal(out); } - - Status GetDictionary(std::shared_ptr<ArrayData>* out) override { - return indices_kernel_->GetDictionary(out); - } - - std::shared_ptr<DataType> value_type() const override { - return indices_kernel_->value_type(); - } - - std::shared_ptr<ArrayData> dictionary() const { return dictionary_; } - - private: - std::unique_ptr<HashKernel> indices_kernel_; - std::shared_ptr<ArrayData> dictionary_; -}; - -// ---------------------------------------------------------------------- - -template <typename Type, typename Action, typename Enable = void> -struct HashKernelTraits {}; - -template <typename Type, typename Action> -struct HashKernelTraits<Type, Action, enable_if_null<Type>> { - using HashKernel = NullHashKernel<Action>; -}; - -template <typename Type, typename Action> -struct HashKernelTraits<Type, Action, enable_if_has_c_type<Type>> { - using HashKernel = RegularHashKernel<Type, typename Type::c_type, Action>; -}; - -template <typename Type, typename Action> -struct HashKernelTraits<Type, Action, enable_if_has_string_view<Type>> { - using HashKernel = RegularHashKernel<Type, util::string_view, Action>; -}; - -template <typename Type, typename Action> + } + + return indices_kernel_->Append(arr); + } + + Status Flush(Datum* out) override { return indices_kernel_->Flush(out); } + + Status FlushFinal(Datum* out) override { return indices_kernel_->FlushFinal(out); } + + Status GetDictionary(std::shared_ptr<ArrayData>* out) override { + return indices_kernel_->GetDictionary(out); + } + + std::shared_ptr<DataType> value_type() const override { + return indices_kernel_->value_type(); + } + + std::shared_ptr<ArrayData> dictionary() const { return dictionary_; } + + private: + std::unique_ptr<HashKernel> indices_kernel_; + std::shared_ptr<ArrayData> dictionary_; +}; + +// ---------------------------------------------------------------------- + +template <typename Type, typename Action, typename Enable = void> +struct HashKernelTraits {}; + +template <typename Type, typename Action> +struct HashKernelTraits<Type, Action, enable_if_null<Type>> { + using HashKernel = NullHashKernel<Action>; +}; + +template <typename Type, typename Action> +struct HashKernelTraits<Type, Action, enable_if_has_c_type<Type>> { + using HashKernel = RegularHashKernel<Type, typename Type::c_type, Action>; +}; + +template <typename Type, typename Action> +struct HashKernelTraits<Type, Action, enable_if_has_string_view<Type>> { + using HashKernel = RegularHashKernel<Type, util::string_view, Action>; +}; + +template <typename Type, typename Action> Result<std::unique_ptr<HashKernel>> HashInitImpl(KernelContext* ctx, const KernelInitArgs& args) { - using HashKernelType = typename HashKernelTraits<Type, Action>::HashKernel; + using HashKernelType = typename HashKernelTraits<Type, Action>::HashKernel; auto result = ::arrow::internal::make_unique<HashKernelType>( args.inputs[0].type, args.options, ctx->memory_pool()); RETURN_NOT_OK(result->Reset()); - return std::move(result); -} - -template <typename Type, typename Action> + return std::move(result); +} + +template <typename Type, typename Action> Result<std::unique_ptr<KernelState>> HashInit(KernelContext* ctx, const KernelInitArgs& args) { return HashInitImpl<Type, Action>(ctx, args); -} - -template <typename Action> -KernelInit GetHashInit(Type::type type_id) { - // ARROW-8933: Generate only a single hash kernel per physical data - // representation - switch (type_id) { - case Type::NA: - return HashInit<NullType, Action>; - case Type::BOOL: - return HashInit<BooleanType, Action>; - case Type::INT8: - case Type::UINT8: - return HashInit<UInt8Type, Action>; - case Type::INT16: - case Type::UINT16: - return HashInit<UInt16Type, Action>; - case Type::INT32: - case Type::UINT32: - case Type::FLOAT: - case Type::DATE32: - case Type::TIME32: - return HashInit<UInt32Type, Action>; - case Type::INT64: - case Type::UINT64: - case Type::DOUBLE: - case Type::DATE64: - case Type::TIME64: - case Type::TIMESTAMP: - case Type::DURATION: - return HashInit<UInt64Type, Action>; - case Type::BINARY: - case Type::STRING: - return HashInit<BinaryType, Action>; - case Type::LARGE_BINARY: - case Type::LARGE_STRING: - return HashInit<LargeBinaryType, Action>; - case Type::FIXED_SIZE_BINARY: +} + +template <typename Action> +KernelInit GetHashInit(Type::type type_id) { + // ARROW-8933: Generate only a single hash kernel per physical data + // representation + switch (type_id) { + case Type::NA: + return HashInit<NullType, Action>; + case Type::BOOL: + return HashInit<BooleanType, Action>; + case Type::INT8: + case Type::UINT8: + return HashInit<UInt8Type, Action>; + case Type::INT16: + case Type::UINT16: + return HashInit<UInt16Type, Action>; + case Type::INT32: + case Type::UINT32: + case Type::FLOAT: + case Type::DATE32: + case Type::TIME32: + return HashInit<UInt32Type, Action>; + case Type::INT64: + case Type::UINT64: + case Type::DOUBLE: + case Type::DATE64: + case Type::TIME64: + case Type::TIMESTAMP: + case Type::DURATION: + return HashInit<UInt64Type, Action>; + case Type::BINARY: + case Type::STRING: + return HashInit<BinaryType, Action>; + case Type::LARGE_BINARY: + case Type::LARGE_STRING: + return HashInit<LargeBinaryType, Action>; + case Type::FIXED_SIZE_BINARY: case Type::DECIMAL128: case Type::DECIMAL256: - return HashInit<FixedSizeBinaryType, Action>; - default: - DCHECK(false); - return nullptr; - } -} - + return HashInit<FixedSizeBinaryType, Action>; + default: + DCHECK(false); + return nullptr; + } +} + using DictionaryEncodeState = OptionsWrapper<DictionaryEncodeOptions>; -template <typename Action> +template <typename Action> Result<std::unique_ptr<KernelState>> DictionaryHashInit(KernelContext* ctx, const KernelInitArgs& args) { - const auto& dict_type = checked_cast<const DictionaryType&>(*args.inputs[0].type); + const auto& dict_type = checked_cast<const DictionaryType&>(*args.inputs[0].type); Result<std::unique_ptr<HashKernel>> indices_hasher; - switch (dict_type.index_type()->id()) { - case Type::INT8: - indices_hasher = HashInitImpl<UInt8Type, Action>(ctx, args); - break; - case Type::INT16: - indices_hasher = HashInitImpl<UInt16Type, Action>(ctx, args); - break; - case Type::INT32: - indices_hasher = HashInitImpl<UInt32Type, Action>(ctx, args); - break; - case Type::INT64: - indices_hasher = HashInitImpl<UInt64Type, Action>(ctx, args); - break; - default: - DCHECK(false) << "Unsupported dictionary index type"; - break; - } + switch (dict_type.index_type()->id()) { + case Type::INT8: + indices_hasher = HashInitImpl<UInt8Type, Action>(ctx, args); + break; + case Type::INT16: + indices_hasher = HashInitImpl<UInt16Type, Action>(ctx, args); + break; + case Type::INT32: + indices_hasher = HashInitImpl<UInt32Type, Action>(ctx, args); + break; + case Type::INT64: + indices_hasher = HashInitImpl<UInt64Type, Action>(ctx, args); + break; + default: + DCHECK(false) << "Unsupported dictionary index type"; + break; + } RETURN_NOT_OK(indices_hasher); return ::arrow::internal::make_unique<DictionaryHashKernel>( std::move(indices_hasher.ValueOrDie())); -} - +} + Status HashExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - auto hash_impl = checked_cast<HashKernel*>(ctx->state()); + auto hash_impl = checked_cast<HashKernel*>(ctx->state()); RETURN_NOT_OK(hash_impl->Append(ctx, *batch[0].array())); RETURN_NOT_OK(hash_impl->Flush(out)); return Status::OK(); -} - +} + Status UniqueFinalize(KernelContext* ctx, std::vector<Datum>* out) { - auto hash_impl = checked_cast<HashKernel*>(ctx->state()); - std::shared_ptr<ArrayData> uniques; + auto hash_impl = checked_cast<HashKernel*>(ctx->state()); + std::shared_ptr<ArrayData> uniques; RETURN_NOT_OK(hash_impl->GetDictionary(&uniques)); - *out = {Datum(uniques)}; + *out = {Datum(uniques)}; return Status::OK(); -} - +} + Status DictEncodeFinalize(KernelContext* ctx, std::vector<Datum>* out) { - auto hash_impl = checked_cast<HashKernel*>(ctx->state()); - std::shared_ptr<ArrayData> uniques; + auto hash_impl = checked_cast<HashKernel*>(ctx->state()); + std::shared_ptr<ArrayData> uniques; RETURN_NOT_OK(hash_impl->GetDictionary(&uniques)); - auto dict_type = dictionary(int32(), uniques->type); - auto dict = MakeArray(uniques); - for (size_t i = 0; i < out->size(); ++i) { - (*out)[i] = - std::make_shared<DictionaryArray>(dict_type, (*out)[i].make_array(), dict); - } + auto dict_type = dictionary(int32(), uniques->type); + auto dict = MakeArray(uniques); + for (size_t i = 0; i < out->size(); ++i) { + (*out)[i] = + std::make_shared<DictionaryArray>(dict_type, (*out)[i].make_array(), dict); + } return Status::OK(); -} - -std::shared_ptr<ArrayData> BoxValueCounts(const std::shared_ptr<ArrayData>& uniques, - const std::shared_ptr<ArrayData>& counts) { - auto data_type = - struct_({field(kValuesFieldName, uniques->type), field(kCountsFieldName, int64())}); - ArrayVector children = {MakeArray(uniques), MakeArray(counts)}; - return std::make_shared<StructArray>(data_type, uniques->length, children)->data(); -} - +} + +std::shared_ptr<ArrayData> BoxValueCounts(const std::shared_ptr<ArrayData>& uniques, + const std::shared_ptr<ArrayData>& counts) { + auto data_type = + struct_({field(kValuesFieldName, uniques->type), field(kCountsFieldName, int64())}); + ArrayVector children = {MakeArray(uniques), MakeArray(counts)}; + return std::make_shared<StructArray>(data_type, uniques->length, children)->data(); +} + Status ValueCountsFinalize(KernelContext* ctx, std::vector<Datum>* out) { - auto hash_impl = checked_cast<HashKernel*>(ctx->state()); - std::shared_ptr<ArrayData> uniques; - Datum value_counts; - + auto hash_impl = checked_cast<HashKernel*>(ctx->state()); + std::shared_ptr<ArrayData> uniques; + Datum value_counts; + RETURN_NOT_OK(hash_impl->GetDictionary(&uniques)); RETURN_NOT_OK(hash_impl->FlushFinal(&value_counts)); - *out = {Datum(BoxValueCounts(uniques, value_counts.array()))}; + *out = {Datum(BoxValueCounts(uniques, value_counts.array()))}; return Status::OK(); -} - +} + Status UniqueFinalizeDictionary(KernelContext* ctx, std::vector<Datum>* out) { RETURN_NOT_OK(UniqueFinalize(ctx, out)); - auto hash = checked_cast<DictionaryHashKernel*>(ctx->state()); - (*out)[0].mutable_array()->dictionary = hash->dictionary(); + auto hash = checked_cast<DictionaryHashKernel*>(ctx->state()); + (*out)[0].mutable_array()->dictionary = hash->dictionary(); return Status::OK(); -} - +} + Status ValueCountsFinalizeDictionary(KernelContext* ctx, std::vector<Datum>* out) { - auto hash = checked_cast<DictionaryHashKernel*>(ctx->state()); - std::shared_ptr<ArrayData> uniques; - Datum value_counts; + auto hash = checked_cast<DictionaryHashKernel*>(ctx->state()); + std::shared_ptr<ArrayData> uniques; + Datum value_counts; RETURN_NOT_OK(hash->GetDictionary(&uniques)); RETURN_NOT_OK(hash->FlushFinal(&value_counts)); - uniques->dictionary = hash->dictionary(); - *out = {Datum(BoxValueCounts(uniques, value_counts.array()))}; + uniques->dictionary = hash->dictionary(); + *out = {Datum(BoxValueCounts(uniques, value_counts.array()))}; return Status::OK(); -} - -ValueDescr DictEncodeOutput(KernelContext*, const std::vector<ValueDescr>& descrs) { - return ValueDescr::Array(dictionary(int32(), descrs[0].type)); -} - -ValueDescr ValueCountsOutput(KernelContext*, const std::vector<ValueDescr>& descrs) { - return ValueDescr::Array(struct_( - {field(kValuesFieldName, descrs[0].type), field(kCountsFieldName, int64())})); -} - -template <typename Action> -void AddHashKernels(VectorFunction* func, VectorKernel base, OutputType out_ty) { - for (const auto& ty : PrimitiveTypes()) { - base.init = GetHashInit<Action>(ty->id()); - base.signature = KernelSignature::Make({InputType::Array(ty)}, out_ty); - DCHECK_OK(func->AddKernel(base)); - } - - // Example parametric types that we want to match only on Type::type - auto parametric_types = {time32(TimeUnit::SECOND), time64(TimeUnit::MICRO), - timestamp(TimeUnit::SECOND), fixed_size_binary(0)}; - for (const auto& ty : parametric_types) { - base.init = GetHashInit<Action>(ty->id()); - base.signature = KernelSignature::Make({InputType::Array(ty->id())}, out_ty); - DCHECK_OK(func->AddKernel(base)); - } - +} + +ValueDescr DictEncodeOutput(KernelContext*, const std::vector<ValueDescr>& descrs) { + return ValueDescr::Array(dictionary(int32(), descrs[0].type)); +} + +ValueDescr ValueCountsOutput(KernelContext*, const std::vector<ValueDescr>& descrs) { + return ValueDescr::Array(struct_( + {field(kValuesFieldName, descrs[0].type), field(kCountsFieldName, int64())})); +} + +template <typename Action> +void AddHashKernels(VectorFunction* func, VectorKernel base, OutputType out_ty) { + for (const auto& ty : PrimitiveTypes()) { + base.init = GetHashInit<Action>(ty->id()); + base.signature = KernelSignature::Make({InputType::Array(ty)}, out_ty); + DCHECK_OK(func->AddKernel(base)); + } + + // Example parametric types that we want to match only on Type::type + auto parametric_types = {time32(TimeUnit::SECOND), time64(TimeUnit::MICRO), + timestamp(TimeUnit::SECOND), fixed_size_binary(0)}; + for (const auto& ty : parametric_types) { + base.init = GetHashInit<Action>(ty->id()); + base.signature = KernelSignature::Make({InputType::Array(ty->id())}, out_ty); + DCHECK_OK(func->AddKernel(base)); + } + for (auto t : {Type::DECIMAL128, Type::DECIMAL256}) { base.init = GetHashInit<Action>(t); base.signature = KernelSignature::Make({InputType::Array(t)}, out_ty); DCHECK_OK(func->AddKernel(base)); } -} - +} + const FunctionDoc unique_doc( "Compute unique elements", ("Return an array with distinct values. Nulls in the input are ignored."), @@ -718,65 +718,65 @@ const FunctionDoc dictionary_encode_doc( ("Return a dictionary-encoded version of the input array."), {"array"}, "DictionaryEncodeOptions"); -} // namespace - -void RegisterVectorHash(FunctionRegistry* registry) { - VectorKernel base; - base.exec = HashExec; - - // ---------------------------------------------------------------------- - // unique - - base.finalize = UniqueFinalize; - base.output_chunked = false; +} // namespace + +void RegisterVectorHash(FunctionRegistry* registry) { + VectorKernel base; + base.exec = HashExec; + + // ---------------------------------------------------------------------- + // unique + + base.finalize = UniqueFinalize; + base.output_chunked = false; auto unique = std::make_shared<VectorFunction>("unique", Arity::Unary(), &unique_doc); - AddHashKernels<UniqueAction>(unique.get(), base, OutputType(FirstType)); - - // Dictionary unique - base.init = DictionaryHashInit<UniqueAction>; - base.finalize = UniqueFinalizeDictionary; - base.signature = - KernelSignature::Make({InputType::Array(Type::DICTIONARY)}, OutputType(FirstType)); - DCHECK_OK(unique->AddKernel(base)); - - DCHECK_OK(registry->AddFunction(std::move(unique))); - - // ---------------------------------------------------------------------- - // value_counts - - base.finalize = ValueCountsFinalize; + AddHashKernels<UniqueAction>(unique.get(), base, OutputType(FirstType)); + + // Dictionary unique + base.init = DictionaryHashInit<UniqueAction>; + base.finalize = UniqueFinalizeDictionary; + base.signature = + KernelSignature::Make({InputType::Array(Type::DICTIONARY)}, OutputType(FirstType)); + DCHECK_OK(unique->AddKernel(base)); + + DCHECK_OK(registry->AddFunction(std::move(unique))); + + // ---------------------------------------------------------------------- + // value_counts + + base.finalize = ValueCountsFinalize; auto value_counts = std::make_shared<VectorFunction>("value_counts", Arity::Unary(), &value_counts_doc); - AddHashKernels<ValueCountsAction>(value_counts.get(), base, - OutputType(ValueCountsOutput)); - - // Dictionary value counts - base.init = DictionaryHashInit<ValueCountsAction>; - base.finalize = ValueCountsFinalizeDictionary; - base.signature = KernelSignature::Make({InputType::Array(Type::DICTIONARY)}, - OutputType(ValueCountsOutput)); - DCHECK_OK(value_counts->AddKernel(base)); - - DCHECK_OK(registry->AddFunction(std::move(value_counts))); - - // ---------------------------------------------------------------------- - // dictionary_encode - - base.finalize = DictEncodeFinalize; - // Unique and ValueCounts output unchunked arrays - base.output_chunked = true; + AddHashKernels<ValueCountsAction>(value_counts.get(), base, + OutputType(ValueCountsOutput)); + + // Dictionary value counts + base.init = DictionaryHashInit<ValueCountsAction>; + base.finalize = ValueCountsFinalizeDictionary; + base.signature = KernelSignature::Make({InputType::Array(Type::DICTIONARY)}, + OutputType(ValueCountsOutput)); + DCHECK_OK(value_counts->AddKernel(base)); + + DCHECK_OK(registry->AddFunction(std::move(value_counts))); + + // ---------------------------------------------------------------------- + // dictionary_encode + + base.finalize = DictEncodeFinalize; + // Unique and ValueCounts output unchunked arrays + base.output_chunked = true; auto dict_encode = std::make_shared<VectorFunction>("dictionary_encode", Arity::Unary(), &dictionary_encode_doc, &kDefaultDictionaryEncodeOptions); - AddHashKernels<DictEncodeAction>(dict_encode.get(), base, OutputType(DictEncodeOutput)); - - // Calling dictionary_encode on dictionary input not supported, but if it - // ends up being needed (or convenience), a kernel could be added to make it - // a no-op - - DCHECK_OK(registry->AddFunction(std::move(dict_encode))); -} - -} // namespace internal -} // namespace compute -} // namespace arrow + AddHashKernels<DictEncodeAction>(dict_encode.get(), base, OutputType(DictEncodeOutput)); + + // Calling dictionary_encode on dictionary input not supported, but if it + // ends up being needed (or convenience), a kernel could be added to make it + // a no-op + + DCHECK_OK(registry->AddFunction(std::move(dict_encode))); +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/vector_nested.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/vector_nested.cc index b84640854e..627f4edf96 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/vector_nested.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/vector_nested.cc @@ -1,68 +1,68 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Vector kernels involving nested types - -#include "arrow/array/array_base.h" -#include "arrow/compute/kernels/common.h" -#include "arrow/result.h" - -namespace arrow { -namespace compute { -namespace internal { -namespace { - -template <typename Type> +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Vector kernels involving nested types + +#include "arrow/array/array_base.h" +#include "arrow/compute/kernels/common.h" +#include "arrow/result.h" + +namespace arrow { +namespace compute { +namespace internal { +namespace { + +template <typename Type> Status ListFlatten(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - typename TypeTraits<Type>::ArrayType list_array(batch[0].array()); + typename TypeTraits<Type>::ArrayType list_array(batch[0].array()); ARROW_ASSIGN_OR_RAISE(auto result, list_array.Flatten(ctx->memory_pool())); out->value = result->data(); return Status::OK(); -} - -template <typename Type, typename offset_type = typename Type::offset_type> +} + +template <typename Type, typename offset_type = typename Type::offset_type> Status ListParentIndices(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - typename TypeTraits<Type>::ArrayType list(batch[0].array()); - ArrayData* out_arr = out->mutable_array(); - - const offset_type* offsets = list.raw_value_offsets(); - offset_type values_length = offsets[list.length()] - offsets[0]; - - out_arr->length = values_length; - out_arr->null_count = 0; + typename TypeTraits<Type>::ArrayType list(batch[0].array()); + ArrayData* out_arr = out->mutable_array(); + + const offset_type* offsets = list.raw_value_offsets(); + offset_type values_length = offsets[list.length()] - offsets[0]; + + out_arr->length = values_length; + out_arr->null_count = 0; ARROW_ASSIGN_OR_RAISE(out_arr->buffers[1], ctx->Allocate(values_length * sizeof(offset_type))); - auto out_indices = reinterpret_cast<offset_type*>(out_arr->buffers[1]->mutable_data()); - for (int64_t i = 0; i < list.length(); ++i) { - // Note: In most cases, null slots are empty, but when they are non-empty - // we write out the indices so make sure they are accounted for. This - // behavior could be changed if needed in the future. - for (offset_type j = offsets[i]; j < offsets[i + 1]; ++j) { - *out_indices++ = static_cast<offset_type>(i); - } - } + auto out_indices = reinterpret_cast<offset_type*>(out_arr->buffers[1]->mutable_data()); + for (int64_t i = 0; i < list.length(); ++i) { + // Note: In most cases, null slots are empty, but when they are non-empty + // we write out the indices so make sure they are accounted for. This + // behavior could be changed if needed in the future. + for (offset_type j = offsets[i]; j < offsets[i + 1]; ++j) { + *out_indices++ = static_cast<offset_type>(i); + } + } return Status::OK(); -} - -Result<ValueDescr> ValuesType(KernelContext*, const std::vector<ValueDescr>& args) { - const auto& list_type = checked_cast<const BaseListType&>(*args[0].type); - return ValueDescr::Array(list_type.value_type()); -} - +} + +Result<ValueDescr> ValuesType(KernelContext*, const std::vector<ValueDescr>& args) { + const auto& list_type = checked_cast<const BaseListType&>(*args[0].type); + return ValueDescr::Array(list_type.value_type()); +} + const FunctionDoc list_flatten_doc( "Flatten list values", ("`lists` must have a list-like type.\n" @@ -77,26 +77,26 @@ const FunctionDoc list_parent_indices_doc( "is emitted."), {"lists"}); -} // namespace - -void RegisterVectorNested(FunctionRegistry* registry) { +} // namespace + +void RegisterVectorNested(FunctionRegistry* registry) { auto flatten = std::make_shared<VectorFunction>("list_flatten", Arity::Unary(), &list_flatten_doc); - DCHECK_OK(flatten->AddKernel({InputType::Array(Type::LIST)}, OutputType(ValuesType), - ListFlatten<ListType>)); - DCHECK_OK(flatten->AddKernel({InputType::Array(Type::LARGE_LIST)}, - OutputType(ValuesType), ListFlatten<LargeListType>)); - DCHECK_OK(registry->AddFunction(std::move(flatten))); - + DCHECK_OK(flatten->AddKernel({InputType::Array(Type::LIST)}, OutputType(ValuesType), + ListFlatten<ListType>)); + DCHECK_OK(flatten->AddKernel({InputType::Array(Type::LARGE_LIST)}, + OutputType(ValuesType), ListFlatten<LargeListType>)); + DCHECK_OK(registry->AddFunction(std::move(flatten))); + auto list_parent_indices = std::make_shared<VectorFunction>( "list_parent_indices", Arity::Unary(), &list_parent_indices_doc); - DCHECK_OK(list_parent_indices->AddKernel({InputType::Array(Type::LIST)}, int32(), - ListParentIndices<ListType>)); - DCHECK_OK(list_parent_indices->AddKernel({InputType::Array(Type::LARGE_LIST)}, int64(), - ListParentIndices<LargeListType>)); - DCHECK_OK(registry->AddFunction(std::move(list_parent_indices))); -} - -} // namespace internal -} // namespace compute -} // namespace arrow + DCHECK_OK(list_parent_indices->AddKernel({InputType::Array(Type::LIST)}, int32(), + ListParentIndices<ListType>)); + DCHECK_OK(list_parent_indices->AddKernel({InputType::Array(Type::LARGE_LIST)}, int64(), + ListParentIndices<LargeListType>)); + DCHECK_OK(registry->AddFunction(std::move(list_parent_indices))); +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/vector_selection.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/vector_selection.cc index 5845a7ee2d..f4fd377eff 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/vector_selection.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/vector_selection.cc @@ -1,122 +1,122 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include <algorithm> -#include <cstring> -#include <limits> - -#include "arrow/array/array_base.h" -#include "arrow/array/array_binary.h" -#include "arrow/array/array_dict.h" -#include "arrow/array/array_nested.h" -#include "arrow/array/builder_primitive.h" -#include "arrow/array/concatenate.h" -#include "arrow/buffer_builder.h" -#include "arrow/chunked_array.h" -#include "arrow/compute/api_vector.h" -#include "arrow/compute/kernels/common.h" -#include "arrow/compute/kernels/util_internal.h" -#include "arrow/extension_type.h" -#include "arrow/record_batch.h" -#include "arrow/result.h" -#include "arrow/table.h" -#include "arrow/type.h" -#include "arrow/util/bit_block_counter.h" +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <algorithm> +#include <cstring> +#include <limits> + +#include "arrow/array/array_base.h" +#include "arrow/array/array_binary.h" +#include "arrow/array/array_dict.h" +#include "arrow/array/array_nested.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/array/concatenate.h" +#include "arrow/buffer_builder.h" +#include "arrow/chunked_array.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/util_internal.h" +#include "arrow/extension_type.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/table.h" +#include "arrow/type.h" +#include "arrow/util/bit_block_counter.h" #include "arrow/util/bit_run_reader.h" -#include "arrow/util/bit_util.h" -#include "arrow/util/bitmap_ops.h" -#include "arrow/util/bitmap_reader.h" -#include "arrow/util/int_util.h" - -namespace arrow { - -using internal::BinaryBitBlockCounter; -using internal::BitBlockCount; -using internal::BitBlockCounter; -using internal::CheckIndexBounds; -using internal::CopyBitmap; -using internal::CountSetBits; -using internal::GetArrayView; -using internal::GetByteWidth; -using internal::OptionalBitBlockCounter; -using internal::OptionalBitIndexer; - -namespace compute { -namespace internal { - -int64_t GetFilterOutputSize(const ArrayData& filter, - FilterOptions::NullSelectionBehavior null_selection) { - int64_t output_size = 0; - - if (filter.MayHaveNulls()) { - const uint8_t* filter_is_valid = filter.buffers[0]->data(); - BinaryBitBlockCounter bit_counter(filter.buffers[1]->data(), filter.offset, - filter_is_valid, filter.offset, filter.length); - int64_t position = 0; - if (null_selection == FilterOptions::EMIT_NULL) { - while (position < filter.length) { - BitBlockCount block = bit_counter.NextOrNotWord(); - output_size += block.popcount; - position += block.length; - } - } else { - while (position < filter.length) { - BitBlockCount block = bit_counter.NextAndWord(); - output_size += block.popcount; - position += block.length; - } - } - } else { - // The filter has no nulls, so we can use CountSetBits - output_size = CountSetBits(filter.buffers[1]->data(), filter.offset, filter.length); - } - return output_size; -} - +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/bitmap_reader.h" +#include "arrow/util/int_util.h" + +namespace arrow { + +using internal::BinaryBitBlockCounter; +using internal::BitBlockCount; +using internal::BitBlockCounter; +using internal::CheckIndexBounds; +using internal::CopyBitmap; +using internal::CountSetBits; +using internal::GetArrayView; +using internal::GetByteWidth; +using internal::OptionalBitBlockCounter; +using internal::OptionalBitIndexer; + +namespace compute { +namespace internal { + +int64_t GetFilterOutputSize(const ArrayData& filter, + FilterOptions::NullSelectionBehavior null_selection) { + int64_t output_size = 0; + + if (filter.MayHaveNulls()) { + const uint8_t* filter_is_valid = filter.buffers[0]->data(); + BinaryBitBlockCounter bit_counter(filter.buffers[1]->data(), filter.offset, + filter_is_valid, filter.offset, filter.length); + int64_t position = 0; + if (null_selection == FilterOptions::EMIT_NULL) { + while (position < filter.length) { + BitBlockCount block = bit_counter.NextOrNotWord(); + output_size += block.popcount; + position += block.length; + } + } else { + while (position < filter.length) { + BitBlockCount block = bit_counter.NextAndWord(); + output_size += block.popcount; + position += block.length; + } + } + } else { + // The filter has no nulls, so we can use CountSetBits + output_size = CountSetBits(filter.buffers[1]->data(), filter.offset, filter.length); + } + return output_size; +} + namespace { -template <typename IndexType> -Result<std::shared_ptr<ArrayData>> GetTakeIndicesImpl( - const ArrayData& filter, FilterOptions::NullSelectionBehavior null_selection, - MemoryPool* memory_pool) { - using T = typename IndexType::c_type; - - const uint8_t* filter_data = filter.buffers[1]->data(); +template <typename IndexType> +Result<std::shared_ptr<ArrayData>> GetTakeIndicesImpl( + const ArrayData& filter, FilterOptions::NullSelectionBehavior null_selection, + MemoryPool* memory_pool) { + using T = typename IndexType::c_type; + + const uint8_t* filter_data = filter.buffers[1]->data(); const bool have_filter_nulls = filter.MayHaveNulls(); const uint8_t* filter_is_valid = have_filter_nulls ? filter.buffers[0]->data() : nullptr; - + if (have_filter_nulls && null_selection == FilterOptions::EMIT_NULL) { // Most complex case: the filter may have nulls and we don't drop them. // The logic is ternary: // - filter is null: emit null // - filter is valid and true: emit index // - filter is valid and false: don't emit anything - + typename TypeTraits<IndexType>::BuilderType builder(memory_pool); - + // The position relative to the start of the filter T position = 0; // The current position taking the filter offset into account int64_t position_with_offset = filter.offset; // To count blocks where filter_data[i] || !filter_is_valid[i] - BinaryBitBlockCounter filter_counter(filter_data, filter.offset, filter_is_valid, - filter.offset, filter.length); + BinaryBitBlockCounter filter_counter(filter_data, filter.offset, filter_is_valid, + filter.offset, filter.length); BitBlockCounter is_valid_counter(filter_is_valid, filter.offset, filter.length); while (position < filter.length) { // true OR NOT valid @@ -125,13 +125,13 @@ Result<std::shared_ptr<ArrayData>> GetTakeIndicesImpl( position += selected_or_null_block.length; position_with_offset += selected_or_null_block.length; continue; - } + } RETURN_NOT_OK(builder.Reserve(selected_or_null_block.popcount)); - + // If the values are all valid and the selected_or_null_block is full, // then we can infer that all the values are true and skip the bit checking BitBlockCount is_valid_block = is_valid_counter.NextWord(); - + if (selected_or_null_block.AllSet() && is_valid_block.AllSet()) { // All the values are selected and non-null for (int64_t i = 0; i < selected_or_null_block.length; ++i) { @@ -144,24 +144,24 @@ Result<std::shared_ptr<ArrayData>> GetTakeIndicesImpl( if (BitUtil::GetBit(filter_is_valid, position_with_offset)) { if (BitUtil::GetBit(filter_data, position_with_offset)) { builder.UnsafeAppend(position); - } + } } else { // Null slot, so append a null builder.UnsafeAppendNull(); - } + } ++position; ++position_with_offset; - } - } - } + } + } + } std::shared_ptr<ArrayData> result; RETURN_NOT_OK(builder.FinishInternal(&result)); return result; } - + // Other cases don't emit nulls and are therefore simpler. TypedBufferBuilder<T> builder(memory_pool); - + if (have_filter_nulls) { // The filter may have nulls, so we scan the validity bitmap and the filter // data bitmap together. @@ -180,24 +180,24 @@ Result<std::shared_ptr<ArrayData>> GetTakeIndicesImpl( if (and_block.AllSet()) { // All the values are selected and non-null for (int64_t i = 0; i < and_block.length; ++i) { - builder.UnsafeAppend(position++); - } + builder.UnsafeAppend(position++); + } position_with_offset += and_block.length; } else if (!and_block.NoneSet()) { // Some of the values are false or null for (int64_t i = 0; i < and_block.length; ++i) { if (BitUtil::GetBit(filter_is_valid, position_with_offset) && BitUtil::GetBit(filter_data, position_with_offset)) { - builder.UnsafeAppend(position); - } - ++position; - ++position_with_offset; - } - } else { + builder.UnsafeAppend(position); + } + ++position; + ++position_with_offset; + } + } else { position += and_block.length; position_with_offset += and_block.length; - } - } + } + } } else { // The filter has no nulls, so we need only look for true values RETURN_NOT_OK(::arrow::internal::VisitSetBitRuns( @@ -209,1465 +209,1465 @@ Result<std::shared_ptr<ArrayData>> GetTakeIndicesImpl( } return Status::OK(); })); - } + } const int64_t length = builder.length(); std::shared_ptr<Buffer> out_buffer; RETURN_NOT_OK(builder.Finish(&out_buffer)); return std::make_shared<ArrayData>(TypeTraits<IndexType>::type_singleton(), length, BufferVector{nullptr, out_buffer}, /*null_count=*/0); -} - +} + } // namespace -Result<std::shared_ptr<ArrayData>> GetTakeIndices( - const ArrayData& filter, FilterOptions::NullSelectionBehavior null_selection, - MemoryPool* memory_pool) { - DCHECK_EQ(filter.type->id(), Type::BOOL); - if (filter.length <= std::numeric_limits<uint16_t>::max()) { - return GetTakeIndicesImpl<UInt16Type>(filter, null_selection, memory_pool); - } else if (filter.length <= std::numeric_limits<uint32_t>::max()) { - return GetTakeIndicesImpl<UInt32Type>(filter, null_selection, memory_pool); - } else { - // Arrays over 4 billion elements, not especially likely. - return Status::NotImplemented( - "Filter length exceeds UINT32_MAX, " - "consider a different strategy for selecting elements"); - } -} - -namespace { - -using FilterState = OptionsWrapper<FilterOptions>; -using TakeState = OptionsWrapper<TakeOptions>; - -Status PreallocateData(KernelContext* ctx, int64_t length, int bit_width, - bool allocate_validity, ArrayData* out) { - // Preallocate memory - out->length = length; - out->buffers.resize(2); - - if (allocate_validity) { - ARROW_ASSIGN_OR_RAISE(out->buffers[0], ctx->AllocateBitmap(length)); - } - if (bit_width == 1) { - ARROW_ASSIGN_OR_RAISE(out->buffers[1], ctx->AllocateBitmap(length)); - } else { - ARROW_ASSIGN_OR_RAISE(out->buffers[1], ctx->Allocate(length * bit_width / 8)); - } - return Status::OK(); -} - -// ---------------------------------------------------------------------- -// Implement optimized take for primitive types from boolean to 1/2/4/8-byte -// C-type based types. Use common implementation for every byte width and only -// generate code for unsigned integer indices, since after boundschecking to -// check for negative numbers in the indices we can safely reinterpret_cast -// signed integers as unsigned. - -/// \brief The Take implementation for primitive (fixed-width) types does not -/// use the logical Arrow type but rather the physical C type. This way we -/// only generate one take function for each byte width. -/// -/// This function assumes that the indices have been boundschecked. -template <typename IndexCType, typename ValueCType> -struct PrimitiveTakeImpl { - static void Exec(const PrimitiveArg& values, const PrimitiveArg& indices, - ArrayData* out_arr) { - auto values_data = reinterpret_cast<const ValueCType*>(values.data); - auto values_is_valid = values.is_valid; - auto values_offset = values.offset; - - auto indices_data = reinterpret_cast<const IndexCType*>(indices.data); - auto indices_is_valid = indices.is_valid; - auto indices_offset = indices.offset; - - auto out = out_arr->GetMutableValues<ValueCType>(1); - auto out_is_valid = out_arr->buffers[0]->mutable_data(); - auto out_offset = out_arr->offset; - - // If either the values or indices have nulls, we preemptively zero out the - // out validity bitmap so that we don't have to use ClearBit in each - // iteration for nulls. - if (values.null_count != 0 || indices.null_count != 0) { - BitUtil::SetBitsTo(out_is_valid, out_offset, indices.length, false); - } - - OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset, - indices.length); - int64_t position = 0; - int64_t valid_count = 0; - while (position < indices.length) { - BitBlockCount block = indices_bit_counter.NextBlock(); - if (values.null_count == 0) { - // Values are never null, so things are easier - valid_count += block.popcount; - if (block.popcount == block.length) { - // Fastest path: neither values nor index nulls - BitUtil::SetBitsTo(out_is_valid, out_offset + position, block.length, true); - for (int64_t i = 0; i < block.length; ++i) { - out[position] = values_data[indices_data[position]]; - ++position; - } - } else if (block.popcount > 0) { - // Slow path: some indices but not all are null - for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) { - // index is not null - BitUtil::SetBit(out_is_valid, out_offset + position); - out[position] = values_data[indices_data[position]]; - } else { - out[position] = ValueCType{}; - } - ++position; - } - } else { - memset(out + position, 0, sizeof(ValueCType) * block.length); - position += block.length; - } - } else { - // Values have nulls, so we must do random access into the values bitmap - if (block.popcount == block.length) { - // Faster path: indices are not null but values may be - for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(values_is_valid, - values_offset + indices_data[position])) { - // value is not null - out[position] = values_data[indices_data[position]]; - BitUtil::SetBit(out_is_valid, out_offset + position); - ++valid_count; - } else { - out[position] = ValueCType{}; - } - ++position; - } - } else if (block.popcount > 0) { - // Slow path: some but not all indices are null. Since we are doing - // random access in general we have to check the value nullness one by - // one. - for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(indices_is_valid, indices_offset + position) && - BitUtil::GetBit(values_is_valid, - values_offset + indices_data[position])) { - // index is not null && value is not null - out[position] = values_data[indices_data[position]]; - BitUtil::SetBit(out_is_valid, out_offset + position); - ++valid_count; - } else { - out[position] = ValueCType{}; - } - ++position; - } - } else { - memset(out + position, 0, sizeof(ValueCType) * block.length); - position += block.length; - } - } - } - out_arr->null_count = out_arr->length - valid_count; - } -}; - -template <typename IndexCType> -struct BooleanTakeImpl { - static void Exec(const PrimitiveArg& values, const PrimitiveArg& indices, - ArrayData* out_arr) { - const uint8_t* values_data = values.data; - auto values_is_valid = values.is_valid; - auto values_offset = values.offset; - - auto indices_data = reinterpret_cast<const IndexCType*>(indices.data); - auto indices_is_valid = indices.is_valid; - auto indices_offset = indices.offset; - - auto out = out_arr->buffers[1]->mutable_data(); - auto out_is_valid = out_arr->buffers[0]->mutable_data(); - auto out_offset = out_arr->offset; - - // If either the values or indices have nulls, we preemptively zero out the - // out validity bitmap so that we don't have to use ClearBit in each - // iteration for nulls. - if (values.null_count != 0 || indices.null_count != 0) { - BitUtil::SetBitsTo(out_is_valid, out_offset, indices.length, false); - } - // Avoid uninitialized data in values array - BitUtil::SetBitsTo(out, out_offset, indices.length, false); - - auto PlaceDataBit = [&](int64_t loc, IndexCType index) { - BitUtil::SetBitTo(out, out_offset + loc, - BitUtil::GetBit(values_data, values_offset + index)); - }; - - OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset, - indices.length); - int64_t position = 0; - int64_t valid_count = 0; - while (position < indices.length) { - BitBlockCount block = indices_bit_counter.NextBlock(); - if (values.null_count == 0) { - // Values are never null, so things are easier - valid_count += block.popcount; - if (block.popcount == block.length) { - // Fastest path: neither values nor index nulls - BitUtil::SetBitsTo(out_is_valid, out_offset + position, block.length, true); - for (int64_t i = 0; i < block.length; ++i) { - PlaceDataBit(position, indices_data[position]); - ++position; - } - } else if (block.popcount > 0) { - // Slow path: some but not all indices are null - for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) { - // index is not null - BitUtil::SetBit(out_is_valid, out_offset + position); - PlaceDataBit(position, indices_data[position]); - } - ++position; - } - } else { - position += block.length; - } - } else { - // Values have nulls, so we must do random access into the values bitmap - if (block.popcount == block.length) { - // Faster path: indices are not null but values may be - for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(values_is_valid, - values_offset + indices_data[position])) { - // value is not null - BitUtil::SetBit(out_is_valid, out_offset + position); - PlaceDataBit(position, indices_data[position]); - ++valid_count; - } - ++position; - } - } else if (block.popcount > 0) { - // Slow path: some but not all indices are null. Since we are doing - // random access in general we have to check the value nullness one by - // one. - for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) { - // index is not null - if (BitUtil::GetBit(values_is_valid, - values_offset + indices_data[position])) { - // value is not null - PlaceDataBit(position, indices_data[position]); - BitUtil::SetBit(out_is_valid, out_offset + position); - ++valid_count; - } - } - ++position; - } - } else { - position += block.length; - } - } - } - out_arr->null_count = out_arr->length - valid_count; - } -}; - -template <template <typename...> class TakeImpl, typename... Args> -void TakeIndexDispatch(const PrimitiveArg& values, const PrimitiveArg& indices, - ArrayData* out) { - // With the simplifying assumption that boundschecking has taken place - // already at a higher level, we can now assume that the index values are all - // non-negative. Thus, we can interpret signed integers as unsigned and avoid - // having to generate double the amount of binary code to handle each integer - // width. - switch (indices.bit_width) { - case 8: - return TakeImpl<uint8_t, Args...>::Exec(values, indices, out); - case 16: - return TakeImpl<uint16_t, Args...>::Exec(values, indices, out); - case 32: - return TakeImpl<uint32_t, Args...>::Exec(values, indices, out); - case 64: - return TakeImpl<uint64_t, Args...>::Exec(values, indices, out); - default: - DCHECK(false) << "Invalid indices byte width"; - break; - } -} - +Result<std::shared_ptr<ArrayData>> GetTakeIndices( + const ArrayData& filter, FilterOptions::NullSelectionBehavior null_selection, + MemoryPool* memory_pool) { + DCHECK_EQ(filter.type->id(), Type::BOOL); + if (filter.length <= std::numeric_limits<uint16_t>::max()) { + return GetTakeIndicesImpl<UInt16Type>(filter, null_selection, memory_pool); + } else if (filter.length <= std::numeric_limits<uint32_t>::max()) { + return GetTakeIndicesImpl<UInt32Type>(filter, null_selection, memory_pool); + } else { + // Arrays over 4 billion elements, not especially likely. + return Status::NotImplemented( + "Filter length exceeds UINT32_MAX, " + "consider a different strategy for selecting elements"); + } +} + +namespace { + +using FilterState = OptionsWrapper<FilterOptions>; +using TakeState = OptionsWrapper<TakeOptions>; + +Status PreallocateData(KernelContext* ctx, int64_t length, int bit_width, + bool allocate_validity, ArrayData* out) { + // Preallocate memory + out->length = length; + out->buffers.resize(2); + + if (allocate_validity) { + ARROW_ASSIGN_OR_RAISE(out->buffers[0], ctx->AllocateBitmap(length)); + } + if (bit_width == 1) { + ARROW_ASSIGN_OR_RAISE(out->buffers[1], ctx->AllocateBitmap(length)); + } else { + ARROW_ASSIGN_OR_RAISE(out->buffers[1], ctx->Allocate(length * bit_width / 8)); + } + return Status::OK(); +} + +// ---------------------------------------------------------------------- +// Implement optimized take for primitive types from boolean to 1/2/4/8-byte +// C-type based types. Use common implementation for every byte width and only +// generate code for unsigned integer indices, since after boundschecking to +// check for negative numbers in the indices we can safely reinterpret_cast +// signed integers as unsigned. + +/// \brief The Take implementation for primitive (fixed-width) types does not +/// use the logical Arrow type but rather the physical C type. This way we +/// only generate one take function for each byte width. +/// +/// This function assumes that the indices have been boundschecked. +template <typename IndexCType, typename ValueCType> +struct PrimitiveTakeImpl { + static void Exec(const PrimitiveArg& values, const PrimitiveArg& indices, + ArrayData* out_arr) { + auto values_data = reinterpret_cast<const ValueCType*>(values.data); + auto values_is_valid = values.is_valid; + auto values_offset = values.offset; + + auto indices_data = reinterpret_cast<const IndexCType*>(indices.data); + auto indices_is_valid = indices.is_valid; + auto indices_offset = indices.offset; + + auto out = out_arr->GetMutableValues<ValueCType>(1); + auto out_is_valid = out_arr->buffers[0]->mutable_data(); + auto out_offset = out_arr->offset; + + // If either the values or indices have nulls, we preemptively zero out the + // out validity bitmap so that we don't have to use ClearBit in each + // iteration for nulls. + if (values.null_count != 0 || indices.null_count != 0) { + BitUtil::SetBitsTo(out_is_valid, out_offset, indices.length, false); + } + + OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset, + indices.length); + int64_t position = 0; + int64_t valid_count = 0; + while (position < indices.length) { + BitBlockCount block = indices_bit_counter.NextBlock(); + if (values.null_count == 0) { + // Values are never null, so things are easier + valid_count += block.popcount; + if (block.popcount == block.length) { + // Fastest path: neither values nor index nulls + BitUtil::SetBitsTo(out_is_valid, out_offset + position, block.length, true); + for (int64_t i = 0; i < block.length; ++i) { + out[position] = values_data[indices_data[position]]; + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some indices but not all are null + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) { + // index is not null + BitUtil::SetBit(out_is_valid, out_offset + position); + out[position] = values_data[indices_data[position]]; + } else { + out[position] = ValueCType{}; + } + ++position; + } + } else { + memset(out + position, 0, sizeof(ValueCType) * block.length); + position += block.length; + } + } else { + // Values have nulls, so we must do random access into the values bitmap + if (block.popcount == block.length) { + // Faster path: indices are not null but values may be + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(values_is_valid, + values_offset + indices_data[position])) { + // value is not null + out[position] = values_data[indices_data[position]]; + BitUtil::SetBit(out_is_valid, out_offset + position); + ++valid_count; + } else { + out[position] = ValueCType{}; + } + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some but not all indices are null. Since we are doing + // random access in general we have to check the value nullness one by + // one. + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(indices_is_valid, indices_offset + position) && + BitUtil::GetBit(values_is_valid, + values_offset + indices_data[position])) { + // index is not null && value is not null + out[position] = values_data[indices_data[position]]; + BitUtil::SetBit(out_is_valid, out_offset + position); + ++valid_count; + } else { + out[position] = ValueCType{}; + } + ++position; + } + } else { + memset(out + position, 0, sizeof(ValueCType) * block.length); + position += block.length; + } + } + } + out_arr->null_count = out_arr->length - valid_count; + } +}; + +template <typename IndexCType> +struct BooleanTakeImpl { + static void Exec(const PrimitiveArg& values, const PrimitiveArg& indices, + ArrayData* out_arr) { + const uint8_t* values_data = values.data; + auto values_is_valid = values.is_valid; + auto values_offset = values.offset; + + auto indices_data = reinterpret_cast<const IndexCType*>(indices.data); + auto indices_is_valid = indices.is_valid; + auto indices_offset = indices.offset; + + auto out = out_arr->buffers[1]->mutable_data(); + auto out_is_valid = out_arr->buffers[0]->mutable_data(); + auto out_offset = out_arr->offset; + + // If either the values or indices have nulls, we preemptively zero out the + // out validity bitmap so that we don't have to use ClearBit in each + // iteration for nulls. + if (values.null_count != 0 || indices.null_count != 0) { + BitUtil::SetBitsTo(out_is_valid, out_offset, indices.length, false); + } + // Avoid uninitialized data in values array + BitUtil::SetBitsTo(out, out_offset, indices.length, false); + + auto PlaceDataBit = [&](int64_t loc, IndexCType index) { + BitUtil::SetBitTo(out, out_offset + loc, + BitUtil::GetBit(values_data, values_offset + index)); + }; + + OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset, + indices.length); + int64_t position = 0; + int64_t valid_count = 0; + while (position < indices.length) { + BitBlockCount block = indices_bit_counter.NextBlock(); + if (values.null_count == 0) { + // Values are never null, so things are easier + valid_count += block.popcount; + if (block.popcount == block.length) { + // Fastest path: neither values nor index nulls + BitUtil::SetBitsTo(out_is_valid, out_offset + position, block.length, true); + for (int64_t i = 0; i < block.length; ++i) { + PlaceDataBit(position, indices_data[position]); + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some but not all indices are null + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) { + // index is not null + BitUtil::SetBit(out_is_valid, out_offset + position); + PlaceDataBit(position, indices_data[position]); + } + ++position; + } + } else { + position += block.length; + } + } else { + // Values have nulls, so we must do random access into the values bitmap + if (block.popcount == block.length) { + // Faster path: indices are not null but values may be + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(values_is_valid, + values_offset + indices_data[position])) { + // value is not null + BitUtil::SetBit(out_is_valid, out_offset + position); + PlaceDataBit(position, indices_data[position]); + ++valid_count; + } + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some but not all indices are null. Since we are doing + // random access in general we have to check the value nullness one by + // one. + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) { + // index is not null + if (BitUtil::GetBit(values_is_valid, + values_offset + indices_data[position])) { + // value is not null + PlaceDataBit(position, indices_data[position]); + BitUtil::SetBit(out_is_valid, out_offset + position); + ++valid_count; + } + } + ++position; + } + } else { + position += block.length; + } + } + } + out_arr->null_count = out_arr->length - valid_count; + } +}; + +template <template <typename...> class TakeImpl, typename... Args> +void TakeIndexDispatch(const PrimitiveArg& values, const PrimitiveArg& indices, + ArrayData* out) { + // With the simplifying assumption that boundschecking has taken place + // already at a higher level, we can now assume that the index values are all + // non-negative. Thus, we can interpret signed integers as unsigned and avoid + // having to generate double the amount of binary code to handle each integer + // width. + switch (indices.bit_width) { + case 8: + return TakeImpl<uint8_t, Args...>::Exec(values, indices, out); + case 16: + return TakeImpl<uint16_t, Args...>::Exec(values, indices, out); + case 32: + return TakeImpl<uint32_t, Args...>::Exec(values, indices, out); + case 64: + return TakeImpl<uint64_t, Args...>::Exec(values, indices, out); + default: + DCHECK(false) << "Invalid indices byte width"; + break; + } +} + Status PrimitiveTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - if (TakeState::Get(ctx).boundscheck) { + if (TakeState::Get(ctx).boundscheck) { RETURN_NOT_OK(CheckIndexBounds(*batch[1].array(), batch[0].length())); - } - - PrimitiveArg values = GetPrimitiveArg(*batch[0].array()); - PrimitiveArg indices = GetPrimitiveArg(*batch[1].array()); - - ArrayData* out_arr = out->mutable_array(); - - // TODO: When neither values nor indices contain nulls, we can skip - // allocating the validity bitmap altogether and save time and space. A - // streamlined PrimitiveTakeImpl would need to be written that skips all - // interactions with the output validity bitmap, though. + } + + PrimitiveArg values = GetPrimitiveArg(*batch[0].array()); + PrimitiveArg indices = GetPrimitiveArg(*batch[1].array()); + + ArrayData* out_arr = out->mutable_array(); + + // TODO: When neither values nor indices contain nulls, we can skip + // allocating the validity bitmap altogether and save time and space. A + // streamlined PrimitiveTakeImpl would need to be written that skips all + // interactions with the output validity bitmap, though. RETURN_NOT_OK(PreallocateData(ctx, indices.length, values.bit_width, /*allocate_validity=*/true, out_arr)); - switch (values.bit_width) { - case 1: + switch (values.bit_width) { + case 1: TakeIndexDispatch<BooleanTakeImpl>(values, indices, out_arr); break; - case 8: + case 8: TakeIndexDispatch<PrimitiveTakeImpl, int8_t>(values, indices, out_arr); break; - case 16: + case 16: TakeIndexDispatch<PrimitiveTakeImpl, int16_t>(values, indices, out_arr); break; - case 32: + case 32: TakeIndexDispatch<PrimitiveTakeImpl, int32_t>(values, indices, out_arr); break; - case 64: + case 64: TakeIndexDispatch<PrimitiveTakeImpl, int64_t>(values, indices, out_arr); break; - default: - DCHECK(false) << "Invalid values byte width"; - break; - } + default: + DCHECK(false) << "Invalid values byte width"; + break; + } return Status::OK(); -} - -// ---------------------------------------------------------------------- -// Optimized and streamlined filter for primitive types - -// Use either BitBlockCounter or BinaryBitBlockCounter to quickly scan filter a -// word at a time for the DROP selection type. -class DropNullCounter { - public: - // validity bitmap may be null - DropNullCounter(const uint8_t* validity, const uint8_t* data, int64_t offset, - int64_t length) - : data_counter_(data, offset, length), - data_and_validity_counter_(data, offset, validity, offset, length), - has_validity_(validity != nullptr) {} - - BitBlockCount NextBlock() { - if (has_validity_) { - // filter is true AND not null - return data_and_validity_counter_.NextAndWord(); - } else { - return data_counter_.NextWord(); - } - } - - private: - // For when just data is present, but no validity bitmap - BitBlockCounter data_counter_; - - // For when both validity bitmap and data are present - BinaryBitBlockCounter data_and_validity_counter_; - const bool has_validity_; -}; - -/// \brief The Filter implementation for primitive (fixed-width) types does not -/// use the logical Arrow type but rather the physical C type. This way we only -/// generate one take function for each byte width. We use the same -/// implementation here for boolean and fixed-byte-size inputs with some -/// template specialization. -template <typename ArrowType> -class PrimitiveFilterImpl { - public: - using T = typename std::conditional<std::is_same<ArrowType, BooleanType>::value, - uint8_t, typename ArrowType::c_type>::type; - - PrimitiveFilterImpl(const PrimitiveArg& values, const PrimitiveArg& filter, - FilterOptions::NullSelectionBehavior null_selection, - ArrayData* out_arr) - : values_is_valid_(values.is_valid), - values_data_(reinterpret_cast<const T*>(values.data)), - values_null_count_(values.null_count), - values_offset_(values.offset), - values_length_(values.length), - filter_is_valid_(filter.is_valid), - filter_data_(filter.data), - filter_null_count_(filter.null_count), - filter_offset_(filter.offset), - null_selection_(null_selection) { - if (out_arr->buffers[0] != nullptr) { - // May not be allocated if neither filter nor values contains nulls - out_is_valid_ = out_arr->buffers[0]->mutable_data(); - } - out_data_ = reinterpret_cast<T*>(out_arr->buffers[1]->mutable_data()); - out_offset_ = out_arr->offset; - out_length_ = out_arr->length; - out_position_ = 0; - } - - void ExecNonNull() { - // Fast filter when values and filter are not null +} + +// ---------------------------------------------------------------------- +// Optimized and streamlined filter for primitive types + +// Use either BitBlockCounter or BinaryBitBlockCounter to quickly scan filter a +// word at a time for the DROP selection type. +class DropNullCounter { + public: + // validity bitmap may be null + DropNullCounter(const uint8_t* validity, const uint8_t* data, int64_t offset, + int64_t length) + : data_counter_(data, offset, length), + data_and_validity_counter_(data, offset, validity, offset, length), + has_validity_(validity != nullptr) {} + + BitBlockCount NextBlock() { + if (has_validity_) { + // filter is true AND not null + return data_and_validity_counter_.NextAndWord(); + } else { + return data_counter_.NextWord(); + } + } + + private: + // For when just data is present, but no validity bitmap + BitBlockCounter data_counter_; + + // For when both validity bitmap and data are present + BinaryBitBlockCounter data_and_validity_counter_; + const bool has_validity_; +}; + +/// \brief The Filter implementation for primitive (fixed-width) types does not +/// use the logical Arrow type but rather the physical C type. This way we only +/// generate one take function for each byte width. We use the same +/// implementation here for boolean and fixed-byte-size inputs with some +/// template specialization. +template <typename ArrowType> +class PrimitiveFilterImpl { + public: + using T = typename std::conditional<std::is_same<ArrowType, BooleanType>::value, + uint8_t, typename ArrowType::c_type>::type; + + PrimitiveFilterImpl(const PrimitiveArg& values, const PrimitiveArg& filter, + FilterOptions::NullSelectionBehavior null_selection, + ArrayData* out_arr) + : values_is_valid_(values.is_valid), + values_data_(reinterpret_cast<const T*>(values.data)), + values_null_count_(values.null_count), + values_offset_(values.offset), + values_length_(values.length), + filter_is_valid_(filter.is_valid), + filter_data_(filter.data), + filter_null_count_(filter.null_count), + filter_offset_(filter.offset), + null_selection_(null_selection) { + if (out_arr->buffers[0] != nullptr) { + // May not be allocated if neither filter nor values contains nulls + out_is_valid_ = out_arr->buffers[0]->mutable_data(); + } + out_data_ = reinterpret_cast<T*>(out_arr->buffers[1]->mutable_data()); + out_offset_ = out_arr->offset; + out_length_ = out_arr->length; + out_position_ = 0; + } + + void ExecNonNull() { + // Fast filter when values and filter are not null ::arrow::internal::VisitSetBitRunsVoid( filter_data_, filter_offset_, values_length_, [&](int64_t position, int64_t length) { WriteValueSegment(position, length); }); - } - - void Exec() { - if (filter_null_count_ == 0 && values_null_count_ == 0) { - return ExecNonNull(); - } - - // Bit counters used for both null_selection behaviors - DropNullCounter drop_null_counter(filter_is_valid_, filter_data_, filter_offset_, - values_length_); - OptionalBitBlockCounter data_counter(values_is_valid_, values_offset_, - values_length_); - OptionalBitBlockCounter filter_valid_counter(filter_is_valid_, filter_offset_, - values_length_); - - auto WriteNotNull = [&](int64_t index) { - BitUtil::SetBit(out_is_valid_, out_offset_ + out_position_); - // Increments out_position_ - WriteValue(index); - }; - - auto WriteMaybeNull = [&](int64_t index) { - BitUtil::SetBitTo(out_is_valid_, out_offset_ + out_position_, - BitUtil::GetBit(values_is_valid_, values_offset_ + index)); - // Increments out_position_ - WriteValue(index); - }; - - int64_t in_position = 0; - while (in_position < values_length_) { - BitBlockCount filter_block = drop_null_counter.NextBlock(); - BitBlockCount filter_valid_block = filter_valid_counter.NextWord(); - BitBlockCount data_block = data_counter.NextWord(); - if (filter_block.AllSet() && data_block.AllSet()) { - // Fastest path: all values in block are included and not null - BitUtil::SetBitsTo(out_is_valid_, out_offset_ + out_position_, - filter_block.length, true); - WriteValueSegment(in_position, filter_block.length); - in_position += filter_block.length; - } else if (filter_block.AllSet()) { - // Faster: all values are selected, but some values are null - // Batch copy bits from values validity bitmap to output validity bitmap - CopyBitmap(values_is_valid_, values_offset_ + in_position, filter_block.length, - out_is_valid_, out_offset_ + out_position_); - WriteValueSegment(in_position, filter_block.length); - in_position += filter_block.length; - } else if (filter_block.NoneSet() && null_selection_ == FilterOptions::DROP) { - // For this exceedingly common case in low-selectivity filters we can - // skip further analysis of the data and move on to the next block. - in_position += filter_block.length; - } else { - // Some filter values are false or null - if (data_block.AllSet()) { - // No values are null - if (filter_valid_block.AllSet()) { - // Filter is non-null but some values are false - for (int64_t i = 0; i < filter_block.length; ++i) { - if (BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) { - WriteNotNull(in_position); - } - ++in_position; - } - } else if (null_selection_ == FilterOptions::DROP) { - // If any values are selected, they ARE NOT null - for (int64_t i = 0; i < filter_block.length; ++i) { - if (BitUtil::GetBit(filter_is_valid_, filter_offset_ + in_position) && - BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) { - WriteNotNull(in_position); - } - ++in_position; - } - } else { // null_selection == FilterOptions::EMIT_NULL - // Data values in this block are not null - for (int64_t i = 0; i < filter_block.length; ++i) { - const bool is_valid = - BitUtil::GetBit(filter_is_valid_, filter_offset_ + in_position); - if (is_valid && - BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) { - // Filter slot is non-null and set - WriteNotNull(in_position); - } else if (!is_valid) { - // Filter slot is null, so we have a null in the output - BitUtil::ClearBit(out_is_valid_, out_offset_ + out_position_); - WriteNull(); - } - ++in_position; - } - } - } else { // !data_block.AllSet() - // Some values are null - if (filter_valid_block.AllSet()) { - // Filter is non-null but some values are false - for (int64_t i = 0; i < filter_block.length; ++i) { - if (BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) { - WriteMaybeNull(in_position); - } - ++in_position; - } - } else if (null_selection_ == FilterOptions::DROP) { - // If any values are selected, they ARE NOT null - for (int64_t i = 0; i < filter_block.length; ++i) { - if (BitUtil::GetBit(filter_is_valid_, filter_offset_ + in_position) && - BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) { - WriteMaybeNull(in_position); - } - ++in_position; - } - } else { // null_selection == FilterOptions::EMIT_NULL - // Data values in this block are not null - for (int64_t i = 0; i < filter_block.length; ++i) { - const bool is_valid = - BitUtil::GetBit(filter_is_valid_, filter_offset_ + in_position); - if (is_valid && - BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) { - // Filter slot is non-null and set - WriteMaybeNull(in_position); - } else if (!is_valid) { - // Filter slot is null, so we have a null in the output - BitUtil::ClearBit(out_is_valid_, out_offset_ + out_position_); - WriteNull(); - } - ++in_position; - } - } - } - } // !filter_block.AllSet() - } // while(in_position < values_length_) - } - - // Write the next out_position given the selected in_position for the input - // data and advance out_position - void WriteValue(int64_t in_position) { - out_data_[out_position_++] = values_data_[in_position]; - } - - void WriteValueSegment(int64_t in_start, int64_t length) { - std::memcpy(out_data_ + out_position_, values_data_ + in_start, length * sizeof(T)); - out_position_ += length; - } - - void WriteNull() { - // Zero the memory - out_data_[out_position_++] = T{}; - } - - private: - const uint8_t* values_is_valid_; - const T* values_data_; - int64_t values_null_count_; - int64_t values_offset_; - int64_t values_length_; - const uint8_t* filter_is_valid_; - const uint8_t* filter_data_; - int64_t filter_null_count_; - int64_t filter_offset_; - FilterOptions::NullSelectionBehavior null_selection_; - uint8_t* out_is_valid_; - T* out_data_; - int64_t out_offset_; - int64_t out_length_; - int64_t out_position_; -}; - -template <> -inline void PrimitiveFilterImpl<BooleanType>::WriteValue(int64_t in_position) { - BitUtil::SetBitTo(out_data_, out_offset_ + out_position_++, - BitUtil::GetBit(values_data_, values_offset_ + in_position)); -} - -template <> -inline void PrimitiveFilterImpl<BooleanType>::WriteValueSegment(int64_t in_start, - int64_t length) { - CopyBitmap(values_data_, values_offset_ + in_start, length, out_data_, - out_offset_ + out_position_); - out_position_ += length; -} - -template <> -inline void PrimitiveFilterImpl<BooleanType>::WriteNull() { - // Zero the bit - BitUtil::ClearBit(out_data_, out_offset_ + out_position_++); -} - + } + + void Exec() { + if (filter_null_count_ == 0 && values_null_count_ == 0) { + return ExecNonNull(); + } + + // Bit counters used for both null_selection behaviors + DropNullCounter drop_null_counter(filter_is_valid_, filter_data_, filter_offset_, + values_length_); + OptionalBitBlockCounter data_counter(values_is_valid_, values_offset_, + values_length_); + OptionalBitBlockCounter filter_valid_counter(filter_is_valid_, filter_offset_, + values_length_); + + auto WriteNotNull = [&](int64_t index) { + BitUtil::SetBit(out_is_valid_, out_offset_ + out_position_); + // Increments out_position_ + WriteValue(index); + }; + + auto WriteMaybeNull = [&](int64_t index) { + BitUtil::SetBitTo(out_is_valid_, out_offset_ + out_position_, + BitUtil::GetBit(values_is_valid_, values_offset_ + index)); + // Increments out_position_ + WriteValue(index); + }; + + int64_t in_position = 0; + while (in_position < values_length_) { + BitBlockCount filter_block = drop_null_counter.NextBlock(); + BitBlockCount filter_valid_block = filter_valid_counter.NextWord(); + BitBlockCount data_block = data_counter.NextWord(); + if (filter_block.AllSet() && data_block.AllSet()) { + // Fastest path: all values in block are included and not null + BitUtil::SetBitsTo(out_is_valid_, out_offset_ + out_position_, + filter_block.length, true); + WriteValueSegment(in_position, filter_block.length); + in_position += filter_block.length; + } else if (filter_block.AllSet()) { + // Faster: all values are selected, but some values are null + // Batch copy bits from values validity bitmap to output validity bitmap + CopyBitmap(values_is_valid_, values_offset_ + in_position, filter_block.length, + out_is_valid_, out_offset_ + out_position_); + WriteValueSegment(in_position, filter_block.length); + in_position += filter_block.length; + } else if (filter_block.NoneSet() && null_selection_ == FilterOptions::DROP) { + // For this exceedingly common case in low-selectivity filters we can + // skip further analysis of the data and move on to the next block. + in_position += filter_block.length; + } else { + // Some filter values are false or null + if (data_block.AllSet()) { + // No values are null + if (filter_valid_block.AllSet()) { + // Filter is non-null but some values are false + for (int64_t i = 0; i < filter_block.length; ++i) { + if (BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) { + WriteNotNull(in_position); + } + ++in_position; + } + } else if (null_selection_ == FilterOptions::DROP) { + // If any values are selected, they ARE NOT null + for (int64_t i = 0; i < filter_block.length; ++i) { + if (BitUtil::GetBit(filter_is_valid_, filter_offset_ + in_position) && + BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) { + WriteNotNull(in_position); + } + ++in_position; + } + } else { // null_selection == FilterOptions::EMIT_NULL + // Data values in this block are not null + for (int64_t i = 0; i < filter_block.length; ++i) { + const bool is_valid = + BitUtil::GetBit(filter_is_valid_, filter_offset_ + in_position); + if (is_valid && + BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) { + // Filter slot is non-null and set + WriteNotNull(in_position); + } else if (!is_valid) { + // Filter slot is null, so we have a null in the output + BitUtil::ClearBit(out_is_valid_, out_offset_ + out_position_); + WriteNull(); + } + ++in_position; + } + } + } else { // !data_block.AllSet() + // Some values are null + if (filter_valid_block.AllSet()) { + // Filter is non-null but some values are false + for (int64_t i = 0; i < filter_block.length; ++i) { + if (BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) { + WriteMaybeNull(in_position); + } + ++in_position; + } + } else if (null_selection_ == FilterOptions::DROP) { + // If any values are selected, they ARE NOT null + for (int64_t i = 0; i < filter_block.length; ++i) { + if (BitUtil::GetBit(filter_is_valid_, filter_offset_ + in_position) && + BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) { + WriteMaybeNull(in_position); + } + ++in_position; + } + } else { // null_selection == FilterOptions::EMIT_NULL + // Data values in this block are not null + for (int64_t i = 0; i < filter_block.length; ++i) { + const bool is_valid = + BitUtil::GetBit(filter_is_valid_, filter_offset_ + in_position); + if (is_valid && + BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) { + // Filter slot is non-null and set + WriteMaybeNull(in_position); + } else if (!is_valid) { + // Filter slot is null, so we have a null in the output + BitUtil::ClearBit(out_is_valid_, out_offset_ + out_position_); + WriteNull(); + } + ++in_position; + } + } + } + } // !filter_block.AllSet() + } // while(in_position < values_length_) + } + + // Write the next out_position given the selected in_position for the input + // data and advance out_position + void WriteValue(int64_t in_position) { + out_data_[out_position_++] = values_data_[in_position]; + } + + void WriteValueSegment(int64_t in_start, int64_t length) { + std::memcpy(out_data_ + out_position_, values_data_ + in_start, length * sizeof(T)); + out_position_ += length; + } + + void WriteNull() { + // Zero the memory + out_data_[out_position_++] = T{}; + } + + private: + const uint8_t* values_is_valid_; + const T* values_data_; + int64_t values_null_count_; + int64_t values_offset_; + int64_t values_length_; + const uint8_t* filter_is_valid_; + const uint8_t* filter_data_; + int64_t filter_null_count_; + int64_t filter_offset_; + FilterOptions::NullSelectionBehavior null_selection_; + uint8_t* out_is_valid_; + T* out_data_; + int64_t out_offset_; + int64_t out_length_; + int64_t out_position_; +}; + +template <> +inline void PrimitiveFilterImpl<BooleanType>::WriteValue(int64_t in_position) { + BitUtil::SetBitTo(out_data_, out_offset_ + out_position_++, + BitUtil::GetBit(values_data_, values_offset_ + in_position)); +} + +template <> +inline void PrimitiveFilterImpl<BooleanType>::WriteValueSegment(int64_t in_start, + int64_t length) { + CopyBitmap(values_data_, values_offset_ + in_start, length, out_data_, + out_offset_ + out_position_); + out_position_ += length; +} + +template <> +inline void PrimitiveFilterImpl<BooleanType>::WriteNull() { + // Zero the bit + BitUtil::ClearBit(out_data_, out_offset_ + out_position_++); +} + Status PrimitiveFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - PrimitiveArg values = GetPrimitiveArg(*batch[0].array()); - PrimitiveArg filter = GetPrimitiveArg(*batch[1].array()); - FilterOptions::NullSelectionBehavior null_selection = - FilterState::Get(ctx).null_selection_behavior; - - int64_t output_length = GetFilterOutputSize(*batch[1].array(), null_selection); - - ArrayData* out_arr = out->mutable_array(); - - // The output precomputed null count is unknown except in the narrow - // condition that all the values are non-null and the filter will not cause - // any new nulls to be created. - if (values.null_count == 0 && - (null_selection == FilterOptions::DROP || filter.null_count == 0)) { - out_arr->null_count = 0; - } else { - out_arr->null_count = kUnknownNullCount; - } - - // When neither the values nor filter is known to have any nulls, we will - // elect the optimized ExecNonNull path where there is no need to populate a - // validity bitmap. - bool allocate_validity = values.null_count != 0 || filter.null_count != 0; - + PrimitiveArg values = GetPrimitiveArg(*batch[0].array()); + PrimitiveArg filter = GetPrimitiveArg(*batch[1].array()); + FilterOptions::NullSelectionBehavior null_selection = + FilterState::Get(ctx).null_selection_behavior; + + int64_t output_length = GetFilterOutputSize(*batch[1].array(), null_selection); + + ArrayData* out_arr = out->mutable_array(); + + // The output precomputed null count is unknown except in the narrow + // condition that all the values are non-null and the filter will not cause + // any new nulls to be created. + if (values.null_count == 0 && + (null_selection == FilterOptions::DROP || filter.null_count == 0)) { + out_arr->null_count = 0; + } else { + out_arr->null_count = kUnknownNullCount; + } + + // When neither the values nor filter is known to have any nulls, we will + // elect the optimized ExecNonNull path where there is no need to populate a + // validity bitmap. + bool allocate_validity = values.null_count != 0 || filter.null_count != 0; + RETURN_NOT_OK( PreallocateData(ctx, output_length, values.bit_width, allocate_validity, out_arr)); - - switch (values.bit_width) { - case 1: + + switch (values.bit_width) { + case 1: PrimitiveFilterImpl<BooleanType>(values, filter, null_selection, out_arr).Exec(); break; - case 8: + case 8: PrimitiveFilterImpl<UInt8Type>(values, filter, null_selection, out_arr).Exec(); break; - case 16: + case 16: PrimitiveFilterImpl<UInt16Type>(values, filter, null_selection, out_arr).Exec(); break; - case 32: + case 32: PrimitiveFilterImpl<UInt32Type>(values, filter, null_selection, out_arr).Exec(); break; - case 64: + case 64: PrimitiveFilterImpl<UInt64Type>(values, filter, null_selection, out_arr).Exec(); break; - default: - DCHECK(false) << "Invalid values bit width"; - break; - } + default: + DCHECK(false) << "Invalid values bit width"; + break; + } return Status::OK(); -} - -// ---------------------------------------------------------------------- -// Optimized filter for base binary types (32-bit and 64-bit) - -#define BINARY_FILTER_SETUP_COMMON() \ - auto raw_offsets = \ - reinterpret_cast<const offset_type*>(values.buffers[1]->data()) + values.offset; \ - const uint8_t* raw_data = values.buffers[2]->data(); \ - \ - TypedBufferBuilder<offset_type> offset_builder(ctx->memory_pool()); \ - TypedBufferBuilder<uint8_t> data_builder(ctx->memory_pool()); \ - RETURN_NOT_OK(offset_builder.Reserve(output_length + 1)); \ - \ - /* Presize the data builder with a rough estimate */ \ - if (values.length > 0) { \ - const double mean_value_length = (raw_offsets[values.length] - raw_offsets[0]) / \ - static_cast<double>(values.length); \ - RETURN_NOT_OK( \ - data_builder.Reserve(static_cast<int64_t>(mean_value_length * output_length))); \ - } \ - int64_t space_available = data_builder.capacity(); \ - offset_type offset = 0; - -#define APPEND_RAW_DATA(DATA, NBYTES) \ - if (ARROW_PREDICT_FALSE(NBYTES > space_available)) { \ - RETURN_NOT_OK(data_builder.Reserve(NBYTES)); \ - space_available = data_builder.capacity() - data_builder.length(); \ - } \ - data_builder.UnsafeAppend(DATA, NBYTES); \ - space_available -= NBYTES - -#define APPEND_SINGLE_VALUE() \ - do { \ - offset_type val_size = raw_offsets[in_position + 1] - raw_offsets[in_position]; \ - APPEND_RAW_DATA(raw_data + raw_offsets[in_position], val_size); \ - offset += val_size; \ - } while (0) - -// Optimized binary filter for the case where neither values nor filter have -// nulls -template <typename Type> -Status BinaryFilterNonNullImpl(KernelContext* ctx, const ArrayData& values, - const ArrayData& filter, int64_t output_length, - FilterOptions::NullSelectionBehavior null_selection, - ArrayData* out) { - using offset_type = typename Type::offset_type; - const auto filter_data = filter.buffers[1]->data(); - - BINARY_FILTER_SETUP_COMMON(); - +} + +// ---------------------------------------------------------------------- +// Optimized filter for base binary types (32-bit and 64-bit) + +#define BINARY_FILTER_SETUP_COMMON() \ + auto raw_offsets = \ + reinterpret_cast<const offset_type*>(values.buffers[1]->data()) + values.offset; \ + const uint8_t* raw_data = values.buffers[2]->data(); \ + \ + TypedBufferBuilder<offset_type> offset_builder(ctx->memory_pool()); \ + TypedBufferBuilder<uint8_t> data_builder(ctx->memory_pool()); \ + RETURN_NOT_OK(offset_builder.Reserve(output_length + 1)); \ + \ + /* Presize the data builder with a rough estimate */ \ + if (values.length > 0) { \ + const double mean_value_length = (raw_offsets[values.length] - raw_offsets[0]) / \ + static_cast<double>(values.length); \ + RETURN_NOT_OK( \ + data_builder.Reserve(static_cast<int64_t>(mean_value_length * output_length))); \ + } \ + int64_t space_available = data_builder.capacity(); \ + offset_type offset = 0; + +#define APPEND_RAW_DATA(DATA, NBYTES) \ + if (ARROW_PREDICT_FALSE(NBYTES > space_available)) { \ + RETURN_NOT_OK(data_builder.Reserve(NBYTES)); \ + space_available = data_builder.capacity() - data_builder.length(); \ + } \ + data_builder.UnsafeAppend(DATA, NBYTES); \ + space_available -= NBYTES + +#define APPEND_SINGLE_VALUE() \ + do { \ + offset_type val_size = raw_offsets[in_position + 1] - raw_offsets[in_position]; \ + APPEND_RAW_DATA(raw_data + raw_offsets[in_position], val_size); \ + offset += val_size; \ + } while (0) + +// Optimized binary filter for the case where neither values nor filter have +// nulls +template <typename Type> +Status BinaryFilterNonNullImpl(KernelContext* ctx, const ArrayData& values, + const ArrayData& filter, int64_t output_length, + FilterOptions::NullSelectionBehavior null_selection, + ArrayData* out) { + using offset_type = typename Type::offset_type; + const auto filter_data = filter.buffers[1]->data(); + + BINARY_FILTER_SETUP_COMMON(); + RETURN_NOT_OK(arrow::internal::VisitSetBitRuns( filter_data, filter.offset, filter.length, [&](int64_t position, int64_t length) { - // Bulk-append raw data + // Bulk-append raw data const offset_type run_data_bytes = (raw_offsets[position + length] - raw_offsets[position]); APPEND_RAW_DATA(raw_data + raw_offsets[position], run_data_bytes); - // Append offsets + // Append offsets offset_type cur_offset = raw_offsets[position]; for (int64_t i = 0; i < length; ++i) { - offset_builder.UnsafeAppend(offset); + offset_builder.UnsafeAppend(offset); offset += raw_offsets[i + position + 1] - cur_offset; cur_offset = raw_offsets[i + position + 1]; - } + } return Status::OK(); })); - offset_builder.UnsafeAppend(offset); - out->length = output_length; - RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1])); - return data_builder.Finish(&out->buffers[2]); -} - -template <typename Type> -Status BinaryFilterImpl(KernelContext* ctx, const ArrayData& values, - const ArrayData& filter, int64_t output_length, - FilterOptions::NullSelectionBehavior null_selection, - ArrayData* out) { - using offset_type = typename Type::offset_type; - - const auto filter_data = filter.buffers[1]->data(); - const uint8_t* filter_is_valid = GetValidityBitmap(filter); - const int64_t filter_offset = filter.offset; - - const uint8_t* values_is_valid = GetValidityBitmap(values); - const int64_t values_offset = values.offset; - - uint8_t* out_is_valid = out->buffers[0]->mutable_data(); - // Zero bits and then only have to set valid values to true - BitUtil::SetBitsTo(out_is_valid, 0, output_length, false); - - // We use 3 block counters for fast scanning of the filter - // - // * values_valid_counter: for values null/not-null - // * filter_valid_counter: for filter null/not-null - // * filter_counter: for filter true/false - OptionalBitBlockCounter values_valid_counter(values_is_valid, values_offset, - values.length); - OptionalBitBlockCounter filter_valid_counter(filter_is_valid, filter_offset, - filter.length); - BitBlockCounter filter_counter(filter_data, filter_offset, filter.length); - - BINARY_FILTER_SETUP_COMMON(); - + offset_builder.UnsafeAppend(offset); + out->length = output_length; + RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1])); + return data_builder.Finish(&out->buffers[2]); +} + +template <typename Type> +Status BinaryFilterImpl(KernelContext* ctx, const ArrayData& values, + const ArrayData& filter, int64_t output_length, + FilterOptions::NullSelectionBehavior null_selection, + ArrayData* out) { + using offset_type = typename Type::offset_type; + + const auto filter_data = filter.buffers[1]->data(); + const uint8_t* filter_is_valid = GetValidityBitmap(filter); + const int64_t filter_offset = filter.offset; + + const uint8_t* values_is_valid = GetValidityBitmap(values); + const int64_t values_offset = values.offset; + + uint8_t* out_is_valid = out->buffers[0]->mutable_data(); + // Zero bits and then only have to set valid values to true + BitUtil::SetBitsTo(out_is_valid, 0, output_length, false); + + // We use 3 block counters for fast scanning of the filter + // + // * values_valid_counter: for values null/not-null + // * filter_valid_counter: for filter null/not-null + // * filter_counter: for filter true/false + OptionalBitBlockCounter values_valid_counter(values_is_valid, values_offset, + values.length); + OptionalBitBlockCounter filter_valid_counter(filter_is_valid, filter_offset, + filter.length); + BitBlockCounter filter_counter(filter_data, filter_offset, filter.length); + + BINARY_FILTER_SETUP_COMMON(); + int64_t in_position = 0; int64_t out_position = 0; - while (in_position < filter.length) { - BitBlockCount filter_valid_block = filter_valid_counter.NextWord(); - BitBlockCount values_valid_block = values_valid_counter.NextWord(); - BitBlockCount filter_block = filter_counter.NextWord(); - if (filter_block.NoneSet() && null_selection == FilterOptions::DROP) { - // For this exceedingly common case in low-selectivity filters we can - // skip further analysis of the data and move on to the next block. - in_position += filter_block.length; - } else if (filter_valid_block.AllSet()) { - // Simpler path: no filter values are null - if (filter_block.AllSet()) { - // Fastest path: filter values are all true and not null - if (values_valid_block.AllSet()) { - // The values aren't null either - BitUtil::SetBitsTo(out_is_valid, out_position, filter_block.length, true); - - // Bulk-append raw data - offset_type block_data_bytes = - (raw_offsets[in_position + filter_block.length] - raw_offsets[in_position]); - APPEND_RAW_DATA(raw_data + raw_offsets[in_position], block_data_bytes); - // Append offsets - for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { - offset_builder.UnsafeAppend(offset); - offset += raw_offsets[in_position + 1] - raw_offsets[in_position]; - } - out_position += filter_block.length; - } else { - // Some of the values in this block are null - for (int64_t i = 0; i < filter_block.length; - ++i, ++in_position, ++out_position) { - offset_builder.UnsafeAppend(offset); - if (BitUtil::GetBit(values_is_valid, values_offset + in_position)) { - BitUtil::SetBit(out_is_valid, out_position); - APPEND_SINGLE_VALUE(); - } - } - } - } else { // !filter_block.AllSet() - // Some of the filter values are false, but all not null - if (values_valid_block.AllSet()) { - // All the values are not-null, so we can skip null checking for - // them - for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { - if (BitUtil::GetBit(filter_data, filter_offset + in_position)) { - offset_builder.UnsafeAppend(offset); - BitUtil::SetBit(out_is_valid, out_position++); - APPEND_SINGLE_VALUE(); - } - } - } else { - // Some of the values in the block are null, so we have to check - // each one - for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { - if (BitUtil::GetBit(filter_data, filter_offset + in_position)) { - offset_builder.UnsafeAppend(offset); - if (BitUtil::GetBit(values_is_valid, values_offset + in_position)) { - BitUtil::SetBit(out_is_valid, out_position); - APPEND_SINGLE_VALUE(); - } - ++out_position; - } - } - } - } - } else { // !filter_valid_block.AllSet() - // Some of the filter values are null, so we have to handle the DROP - // versus EMIT_NULL null selection behavior. - if (null_selection == FilterOptions::DROP) { - // Filter null values are treated as false. - if (values_valid_block.AllSet()) { - for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { - if (BitUtil::GetBit(filter_is_valid, filter_offset + in_position) && - BitUtil::GetBit(filter_data, filter_offset + in_position)) { - offset_builder.UnsafeAppend(offset); - BitUtil::SetBit(out_is_valid, out_position++); - APPEND_SINGLE_VALUE(); - } - } - } else { - for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { - if (BitUtil::GetBit(filter_is_valid, filter_offset + in_position) && - BitUtil::GetBit(filter_data, filter_offset + in_position)) { - offset_builder.UnsafeAppend(offset); - if (BitUtil::GetBit(values_is_valid, values_offset + in_position)) { - BitUtil::SetBit(out_is_valid, out_position); - APPEND_SINGLE_VALUE(); - } - ++out_position; - } - } - } - } else { - // EMIT_NULL - - // Filter null values are appended to output as null whether the - // value in the corresponding slot is valid or not - if (values_valid_block.AllSet()) { - for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { - const bool filter_not_null = - BitUtil::GetBit(filter_is_valid, filter_offset + in_position); - if (filter_not_null && - BitUtil::GetBit(filter_data, filter_offset + in_position)) { - offset_builder.UnsafeAppend(offset); - BitUtil::SetBit(out_is_valid, out_position++); - APPEND_SINGLE_VALUE(); - } else if (!filter_not_null) { - offset_builder.UnsafeAppend(offset); - ++out_position; - } - } - } else { - for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { - const bool filter_not_null = - BitUtil::GetBit(filter_is_valid, filter_offset + in_position); - if (filter_not_null && - BitUtil::GetBit(filter_data, filter_offset + in_position)) { - offset_builder.UnsafeAppend(offset); - if (BitUtil::GetBit(values_is_valid, values_offset + in_position)) { - BitUtil::SetBit(out_is_valid, out_position); - APPEND_SINGLE_VALUE(); - } - ++out_position; - } else if (!filter_not_null) { - offset_builder.UnsafeAppend(offset); - ++out_position; - } - } - } - } - } - } - offset_builder.UnsafeAppend(offset); - out->length = output_length; - RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1])); - return data_builder.Finish(&out->buffers[2]); -} - -#undef BINARY_FILTER_SETUP_COMMON -#undef APPEND_RAW_DATA -#undef APPEND_SINGLE_VALUE - + while (in_position < filter.length) { + BitBlockCount filter_valid_block = filter_valid_counter.NextWord(); + BitBlockCount values_valid_block = values_valid_counter.NextWord(); + BitBlockCount filter_block = filter_counter.NextWord(); + if (filter_block.NoneSet() && null_selection == FilterOptions::DROP) { + // For this exceedingly common case in low-selectivity filters we can + // skip further analysis of the data and move on to the next block. + in_position += filter_block.length; + } else if (filter_valid_block.AllSet()) { + // Simpler path: no filter values are null + if (filter_block.AllSet()) { + // Fastest path: filter values are all true and not null + if (values_valid_block.AllSet()) { + // The values aren't null either + BitUtil::SetBitsTo(out_is_valid, out_position, filter_block.length, true); + + // Bulk-append raw data + offset_type block_data_bytes = + (raw_offsets[in_position + filter_block.length] - raw_offsets[in_position]); + APPEND_RAW_DATA(raw_data + raw_offsets[in_position], block_data_bytes); + // Append offsets + for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { + offset_builder.UnsafeAppend(offset); + offset += raw_offsets[in_position + 1] - raw_offsets[in_position]; + } + out_position += filter_block.length; + } else { + // Some of the values in this block are null + for (int64_t i = 0; i < filter_block.length; + ++i, ++in_position, ++out_position) { + offset_builder.UnsafeAppend(offset); + if (BitUtil::GetBit(values_is_valid, values_offset + in_position)) { + BitUtil::SetBit(out_is_valid, out_position); + APPEND_SINGLE_VALUE(); + } + } + } + } else { // !filter_block.AllSet() + // Some of the filter values are false, but all not null + if (values_valid_block.AllSet()) { + // All the values are not-null, so we can skip null checking for + // them + for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { + if (BitUtil::GetBit(filter_data, filter_offset + in_position)) { + offset_builder.UnsafeAppend(offset); + BitUtil::SetBit(out_is_valid, out_position++); + APPEND_SINGLE_VALUE(); + } + } + } else { + // Some of the values in the block are null, so we have to check + // each one + for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { + if (BitUtil::GetBit(filter_data, filter_offset + in_position)) { + offset_builder.UnsafeAppend(offset); + if (BitUtil::GetBit(values_is_valid, values_offset + in_position)) { + BitUtil::SetBit(out_is_valid, out_position); + APPEND_SINGLE_VALUE(); + } + ++out_position; + } + } + } + } + } else { // !filter_valid_block.AllSet() + // Some of the filter values are null, so we have to handle the DROP + // versus EMIT_NULL null selection behavior. + if (null_selection == FilterOptions::DROP) { + // Filter null values are treated as false. + if (values_valid_block.AllSet()) { + for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { + if (BitUtil::GetBit(filter_is_valid, filter_offset + in_position) && + BitUtil::GetBit(filter_data, filter_offset + in_position)) { + offset_builder.UnsafeAppend(offset); + BitUtil::SetBit(out_is_valid, out_position++); + APPEND_SINGLE_VALUE(); + } + } + } else { + for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { + if (BitUtil::GetBit(filter_is_valid, filter_offset + in_position) && + BitUtil::GetBit(filter_data, filter_offset + in_position)) { + offset_builder.UnsafeAppend(offset); + if (BitUtil::GetBit(values_is_valid, values_offset + in_position)) { + BitUtil::SetBit(out_is_valid, out_position); + APPEND_SINGLE_VALUE(); + } + ++out_position; + } + } + } + } else { + // EMIT_NULL + + // Filter null values are appended to output as null whether the + // value in the corresponding slot is valid or not + if (values_valid_block.AllSet()) { + for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { + const bool filter_not_null = + BitUtil::GetBit(filter_is_valid, filter_offset + in_position); + if (filter_not_null && + BitUtil::GetBit(filter_data, filter_offset + in_position)) { + offset_builder.UnsafeAppend(offset); + BitUtil::SetBit(out_is_valid, out_position++); + APPEND_SINGLE_VALUE(); + } else if (!filter_not_null) { + offset_builder.UnsafeAppend(offset); + ++out_position; + } + } + } else { + for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { + const bool filter_not_null = + BitUtil::GetBit(filter_is_valid, filter_offset + in_position); + if (filter_not_null && + BitUtil::GetBit(filter_data, filter_offset + in_position)) { + offset_builder.UnsafeAppend(offset); + if (BitUtil::GetBit(values_is_valid, values_offset + in_position)) { + BitUtil::SetBit(out_is_valid, out_position); + APPEND_SINGLE_VALUE(); + } + ++out_position; + } else if (!filter_not_null) { + offset_builder.UnsafeAppend(offset); + ++out_position; + } + } + } + } + } + } + offset_builder.UnsafeAppend(offset); + out->length = output_length; + RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1])); + return data_builder.Finish(&out->buffers[2]); +} + +#undef BINARY_FILTER_SETUP_COMMON +#undef APPEND_RAW_DATA +#undef APPEND_SINGLE_VALUE + Status BinaryFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - FilterOptions::NullSelectionBehavior null_selection = - FilterState::Get(ctx).null_selection_behavior; - - const ArrayData& values = *batch[0].array(); - const ArrayData& filter = *batch[1].array(); - int64_t output_length = GetFilterOutputSize(filter, null_selection); - ArrayData* out_arr = out->mutable_array(); - - // The output precomputed null count is unknown except in the narrow - // condition that all the values are non-null and the filter will not cause - // any new nulls to be created. - if (values.null_count == 0 && - (null_selection == FilterOptions::DROP || filter.null_count == 0)) { - out_arr->null_count = 0; - } else { - out_arr->null_count = kUnknownNullCount; - } - Type::type type_id = values.type->id(); - if (values.null_count == 0 && filter.null_count == 0) { - // Faster no-nulls case - if (is_binary_like(type_id)) { + FilterOptions::NullSelectionBehavior null_selection = + FilterState::Get(ctx).null_selection_behavior; + + const ArrayData& values = *batch[0].array(); + const ArrayData& filter = *batch[1].array(); + int64_t output_length = GetFilterOutputSize(filter, null_selection); + ArrayData* out_arr = out->mutable_array(); + + // The output precomputed null count is unknown except in the narrow + // condition that all the values are non-null and the filter will not cause + // any new nulls to be created. + if (values.null_count == 0 && + (null_selection == FilterOptions::DROP || filter.null_count == 0)) { + out_arr->null_count = 0; + } else { + out_arr->null_count = kUnknownNullCount; + } + Type::type type_id = values.type->id(); + if (values.null_count == 0 && filter.null_count == 0) { + // Faster no-nulls case + if (is_binary_like(type_id)) { RETURN_NOT_OK(BinaryFilterNonNullImpl<BinaryType>( ctx, values, filter, output_length, null_selection, out_arr)); - } else if (is_large_binary_like(type_id)) { + } else if (is_large_binary_like(type_id)) { RETURN_NOT_OK(BinaryFilterNonNullImpl<LargeBinaryType>( ctx, values, filter, output_length, null_selection, out_arr)); - } else { - DCHECK(false); - } - } else { - // Output may have nulls + } else { + DCHECK(false); + } + } else { + // Output may have nulls RETURN_NOT_OK(ctx->AllocateBitmap(output_length).Value(&out_arr->buffers[0])); - if (is_binary_like(type_id)) { + if (is_binary_like(type_id)) { RETURN_NOT_OK(BinaryFilterImpl<BinaryType>(ctx, values, filter, output_length, null_selection, out_arr)); - } else if (is_large_binary_like(type_id)) { + } else if (is_large_binary_like(type_id)) { RETURN_NOT_OK(BinaryFilterImpl<LargeBinaryType>(ctx, values, filter, output_length, null_selection, out_arr)); - } else { - DCHECK(false); - } - } + } else { + DCHECK(false); + } + } return Status::OK(); -} - -// ---------------------------------------------------------------------- -// Null take and filter - +} + +// ---------------------------------------------------------------------- +// Null take and filter + Status NullTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - if (TakeState::Get(ctx).boundscheck) { + if (TakeState::Get(ctx).boundscheck) { RETURN_NOT_OK(CheckIndexBounds(*batch[1].array(), batch[0].length())); - } - // batch.length doesn't take into account the take indices - auto new_length = batch[1].array()->length; - out->value = std::make_shared<NullArray>(new_length)->data(); + } + // batch.length doesn't take into account the take indices + auto new_length = batch[1].array()->length; + out->value = std::make_shared<NullArray>(new_length)->data(); return Status::OK(); -} - +} + Status NullFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - int64_t output_length = GetFilterOutputSize( - *batch[1].array(), FilterState::Get(ctx).null_selection_behavior); - out->value = std::make_shared<NullArray>(output_length)->data(); + int64_t output_length = GetFilterOutputSize( + *batch[1].array(), FilterState::Get(ctx).null_selection_behavior); + out->value = std::make_shared<NullArray>(output_length)->data(); return Status::OK(); -} - -// ---------------------------------------------------------------------- -// Dictionary take and filter - +} + +// ---------------------------------------------------------------------- +// Dictionary take and filter + Status DictionaryTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - DictionaryArray values(batch[0].array()); - Datum result; + DictionaryArray values(batch[0].array()); + Datum result; RETURN_NOT_OK( Take(Datum(values.indices()), batch[1], TakeState::Get(ctx), ctx->exec_context()) .Value(&result)); - DictionaryArray taken_values(values.type(), result.make_array(), values.dictionary()); - out->value = taken_values.data(); + DictionaryArray taken_values(values.type(), result.make_array(), values.dictionary()); + out->value = taken_values.data(); return Status::OK(); -} - +} + Status DictionaryFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - DictionaryArray dict_values(batch[0].array()); - Datum result; + DictionaryArray dict_values(batch[0].array()); + Datum result; RETURN_NOT_OK(Filter(Datum(dict_values.indices()), batch[1].array(), FilterState::Get(ctx), ctx->exec_context()) .Value(&result)); - DictionaryArray filtered_values(dict_values.type(), result.make_array(), - dict_values.dictionary()); - out->value = filtered_values.data(); + DictionaryArray filtered_values(dict_values.type(), result.make_array(), + dict_values.dictionary()); + out->value = filtered_values.data(); return Status::OK(); -} - -// ---------------------------------------------------------------------- -// Extension take and filter - +} + +// ---------------------------------------------------------------------- +// Extension take and filter + Status ExtensionTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - ExtensionArray values(batch[0].array()); - Datum result; + ExtensionArray values(batch[0].array()); + Datum result; RETURN_NOT_OK( Take(Datum(values.storage()), batch[1], TakeState::Get(ctx), ctx->exec_context()) .Value(&result)); - ExtensionArray taken_values(values.type(), result.make_array()); - out->value = taken_values.data(); + ExtensionArray taken_values(values.type(), result.make_array()); + out->value = taken_values.data(); return Status::OK(); -} - +} + Status ExtensionFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - ExtensionArray ext_values(batch[0].array()); - Datum result; + ExtensionArray ext_values(batch[0].array()); + Datum result; RETURN_NOT_OK(Filter(Datum(ext_values.storage()), batch[1].array(), FilterState::Get(ctx), ctx->exec_context()) .Value(&result)); - ExtensionArray filtered_values(ext_values.type(), result.make_array()); - out->value = filtered_values.data(); + ExtensionArray filtered_values(ext_values.type(), result.make_array()); + out->value = filtered_values.data(); return Status::OK(); -} - -// ---------------------------------------------------------------------- -// Implement take for other data types where there is less performance -// sensitivity by visiting the selected indices. - -// Use CRTP to dispatch to type-specific processing of take indices for each -// unsigned integer type. -template <typename Impl, typename Type> -struct Selection { - using ValuesArrayType = typename TypeTraits<Type>::ArrayType; - - // Forwards the generic value visitors to the take index visitor template - template <typename IndexCType> - struct TakeAdapter { - static constexpr bool is_take = true; - - Impl* impl; - explicit TakeAdapter(Impl* impl) : impl(impl) {} - template <typename ValidVisitor, typename NullVisitor> - Status Generate(ValidVisitor&& visit_valid, NullVisitor&& visit_null) { - return impl->template VisitTake<IndexCType>(std::forward<ValidVisitor>(visit_valid), - std::forward<NullVisitor>(visit_null)); - } - }; - - // Forwards the generic value visitors to the VisitFilter template - struct FilterAdapter { - static constexpr bool is_take = false; - - Impl* impl; - explicit FilterAdapter(Impl* impl) : impl(impl) {} - template <typename ValidVisitor, typename NullVisitor> - Status Generate(ValidVisitor&& visit_valid, NullVisitor&& visit_null) { - return impl->VisitFilter(std::forward<ValidVisitor>(visit_valid), - std::forward<NullVisitor>(visit_null)); - } - }; - - KernelContext* ctx; - std::shared_ptr<ArrayData> values; - std::shared_ptr<ArrayData> selection; - int64_t output_length; - ArrayData* out; - TypedBufferBuilder<bool> validity_builder; - - Selection(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, Datum* out) - : ctx(ctx), - values(batch[0].array()), - selection(batch[1].array()), - output_length(output_length), - out(out->mutable_array()), - validity_builder(ctx->memory_pool()) {} - - virtual ~Selection() = default; - - Status FinishCommon() { - out->buffers.resize(values->buffers.size()); - out->length = validity_builder.length(); - out->null_count = validity_builder.false_count(); - return validity_builder.Finish(&out->buffers[0]); - } - - template <typename IndexCType, typename ValidVisitor, typename NullVisitor> - Status VisitTake(ValidVisitor&& visit_valid, NullVisitor&& visit_null) { - const auto indices_values = selection->GetValues<IndexCType>(1); - const uint8_t* is_valid = GetValidityBitmap(*selection); - OptionalBitIndexer indices_is_valid(selection->buffers[0], selection->offset); - OptionalBitIndexer values_is_valid(values->buffers[0], values->offset); - - const bool values_have_nulls = values->MayHaveNulls(); - OptionalBitBlockCounter bit_counter(is_valid, selection->offset, selection->length); - int64_t position = 0; - while (position < selection->length) { - BitBlockCount block = bit_counter.NextBlock(); - const bool indices_have_nulls = block.popcount < block.length; - if (!indices_have_nulls && !values_have_nulls) { - // Fastest path, neither indices nor values have nulls - validity_builder.UnsafeAppend(block.length, true); - for (int64_t i = 0; i < block.length; ++i) { - RETURN_NOT_OK(visit_valid(indices_values[position++])); - } - } else if (block.popcount > 0) { - // Since we have to branch on whether the indices are null or not, we - // combine the "non-null indices block but some values null" and - // "some-null indices block but values non-null" into a single loop. - for (int64_t i = 0; i < block.length; ++i) { - if ((!indices_have_nulls || indices_is_valid[position]) && - values_is_valid[indices_values[position]]) { - validity_builder.UnsafeAppend(true); - RETURN_NOT_OK(visit_valid(indices_values[position])); - } else { - validity_builder.UnsafeAppend(false); - RETURN_NOT_OK(visit_null()); - } - ++position; - } - } else { - // The whole block is null - validity_builder.UnsafeAppend(block.length, false); - for (int64_t i = 0; i < block.length; ++i) { - RETURN_NOT_OK(visit_null()); - } - position += block.length; - } - } - return Status::OK(); - } - - // We use the NullVisitor both for "selected" nulls as well as "emitted" - // nulls coming from the filter when using FilterOptions::EMIT_NULL - template <typename ValidVisitor, typename NullVisitor> - Status VisitFilter(ValidVisitor&& visit_valid, NullVisitor&& visit_null) { - auto null_selection = FilterState::Get(ctx).null_selection_behavior; - - const auto filter_data = selection->buffers[1]->data(); - - const uint8_t* filter_is_valid = GetValidityBitmap(*selection); - const int64_t filter_offset = selection->offset; - OptionalBitIndexer values_is_valid(values->buffers[0], values->offset); - - // We use 3 block counters for fast scanning of the filter - // - // * values_valid_counter: for values null/not-null - // * filter_valid_counter: for filter null/not-null - // * filter_counter: for filter true/false - OptionalBitBlockCounter values_valid_counter(GetValidityBitmap(*values), - values->offset, values->length); - OptionalBitBlockCounter filter_valid_counter(filter_is_valid, filter_offset, - selection->length); - BitBlockCounter filter_counter(filter_data, filter_offset, selection->length); - int64_t in_position = 0; - - auto AppendNotNull = [&](int64_t index) -> Status { - validity_builder.UnsafeAppend(true); - return visit_valid(index); - }; - - auto AppendNull = [&]() -> Status { - validity_builder.UnsafeAppend(false); - return visit_null(); - }; - - auto AppendMaybeNull = [&](int64_t index) -> Status { - if (values_is_valid[index]) { - return AppendNotNull(index); - } else { - return AppendNull(); - } - }; - - while (in_position < selection->length) { - BitBlockCount filter_valid_block = filter_valid_counter.NextWord(); - BitBlockCount values_valid_block = values_valid_counter.NextWord(); - BitBlockCount filter_block = filter_counter.NextWord(); - if (filter_block.NoneSet() && null_selection == FilterOptions::DROP) { - // For this exceedingly common case in low-selectivity filters we can - // skip further analysis of the data and move on to the next block. - in_position += filter_block.length; - } else if (filter_valid_block.AllSet()) { - // Simpler path: no filter values are null - if (filter_block.AllSet()) { - // Fastest path: filter values are all true and not null - if (values_valid_block.AllSet()) { - // The values aren't null either - validity_builder.UnsafeAppend(filter_block.length, true); - for (int64_t i = 0; i < filter_block.length; ++i) { - RETURN_NOT_OK(visit_valid(in_position++)); - } - } else { - // Some of the values in this block are null - for (int64_t i = 0; i < filter_block.length; ++i) { - RETURN_NOT_OK(AppendMaybeNull(in_position++)); - } - } - } else { // !filter_block.AllSet() - // Some of the filter values are false, but all not null - if (values_valid_block.AllSet()) { - // All the values are not-null, so we can skip null checking for - // them - for (int64_t i = 0; i < filter_block.length; ++i) { - if (BitUtil::GetBit(filter_data, filter_offset + in_position)) { - RETURN_NOT_OK(AppendNotNull(in_position)); - } - ++in_position; - } - } else { - // Some of the values in the block are null, so we have to check - // each one - for (int64_t i = 0; i < filter_block.length; ++i) { - if (BitUtil::GetBit(filter_data, filter_offset + in_position)) { - RETURN_NOT_OK(AppendMaybeNull(in_position)); - } - ++in_position; - } - } - } - } else { // !filter_valid_block.AllSet() - // Some of the filter values are null, so we have to handle the DROP - // versus EMIT_NULL null selection behavior. - if (null_selection == FilterOptions::DROP) { - // Filter null values are treated as false. - for (int64_t i = 0; i < filter_block.length; ++i) { - if (BitUtil::GetBit(filter_is_valid, filter_offset + in_position) && - BitUtil::GetBit(filter_data, filter_offset + in_position)) { - RETURN_NOT_OK(AppendMaybeNull(in_position)); - } - ++in_position; - } - } else { - // Filter null values are appended to output as null whether the - // value in the corresponding slot is valid or not - for (int64_t i = 0; i < filter_block.length; ++i) { - const bool filter_not_null = - BitUtil::GetBit(filter_is_valid, filter_offset + in_position); - if (filter_not_null && - BitUtil::GetBit(filter_data, filter_offset + in_position)) { - RETURN_NOT_OK(AppendMaybeNull(in_position)); - } else if (!filter_not_null) { - // EMIT_NULL case - RETURN_NOT_OK(AppendNull()); - } - ++in_position; - } - } - } - } - return Status::OK(); - } - - virtual Status Init() { return Status::OK(); } - - // Implementation specific finish logic - virtual Status Finish() = 0; - - Status ExecTake() { - RETURN_NOT_OK(this->validity_builder.Reserve(output_length)); - RETURN_NOT_OK(Init()); - int index_width = GetByteWidth(*this->selection->type); - - // CTRP dispatch here - switch (index_width) { - case 1: { - Status s = - static_cast<Impl*>(this)->template GenerateOutput<TakeAdapter<uint8_t>>(); - RETURN_NOT_OK(s); - } break; - case 2: { - Status s = - static_cast<Impl*>(this)->template GenerateOutput<TakeAdapter<uint16_t>>(); - RETURN_NOT_OK(s); - } break; - case 4: { - Status s = - static_cast<Impl*>(this)->template GenerateOutput<TakeAdapter<uint32_t>>(); - RETURN_NOT_OK(s); - } break; - case 8: { - Status s = - static_cast<Impl*>(this)->template GenerateOutput<TakeAdapter<uint64_t>>(); - RETURN_NOT_OK(s); - } break; - default: - DCHECK(false) << "Invalid index width"; - break; - } - RETURN_NOT_OK(this->FinishCommon()); - return Finish(); - } - - Status ExecFilter() { - RETURN_NOT_OK(this->validity_builder.Reserve(output_length)); - RETURN_NOT_OK(Init()); - // CRTP dispatch - Status s = static_cast<Impl*>(this)->template GenerateOutput<FilterAdapter>(); - RETURN_NOT_OK(s); - RETURN_NOT_OK(this->FinishCommon()); - return Finish(); - } -}; - -#define LIFT_BASE_MEMBERS() \ - using ValuesArrayType = typename Base::ValuesArrayType; \ - using Base::ctx; \ - using Base::values; \ - using Base::selection; \ - using Base::output_length; \ - using Base::out; \ - using Base::validity_builder - -static inline Status VisitNoop() { return Status::OK(); } - -// A selection implementation for 32-bit and 64-bit variable binary -// types. Common generated kernels are shared between Binary/String and -// LargeBinary/LargeString -template <typename Type> -struct VarBinaryImpl : public Selection<VarBinaryImpl<Type>, Type> { - using offset_type = typename Type::offset_type; - - using Base = Selection<VarBinaryImpl<Type>, Type>; - LIFT_BASE_MEMBERS(); - - std::shared_ptr<ArrayData> values_as_binary; - TypedBufferBuilder<offset_type> offset_builder; - TypedBufferBuilder<uint8_t> data_builder; - - static constexpr int64_t kOffsetLimit = std::numeric_limits<offset_type>::max() - 1; - - VarBinaryImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, - Datum* out) - : Base(ctx, batch, output_length, out), - offset_builder(ctx->memory_pool()), - data_builder(ctx->memory_pool()) {} - - template <typename Adapter> - Status GenerateOutput() { - ValuesArrayType typed_values(this->values_as_binary); - - // Presize the data builder with a rough estimate of the required data size - if (values->length > 0) { - const double mean_value_length = - (typed_values.total_values_length() / static_cast<double>(values->length)); - - // TODO: See if possible to reduce output_length for take/filter cases - // where there are nulls in the selection array - RETURN_NOT_OK( - data_builder.Reserve(static_cast<int64_t>(mean_value_length * output_length))); - } - int64_t space_available = data_builder.capacity(); - - const offset_type* raw_offsets = typed_values.raw_value_offsets(); - const uint8_t* raw_data = typed_values.raw_data(); - - offset_type offset = 0; - Adapter adapter(this); - RETURN_NOT_OK(adapter.Generate( - [&](int64_t index) { - offset_builder.UnsafeAppend(offset); - offset_type val_offset = raw_offsets[index]; - offset_type val_size = raw_offsets[index + 1] - val_offset; - - // Use static property to prune this code from the filter path in - // optimized builds - if (Adapter::is_take && - ARROW_PREDICT_FALSE(static_cast<int64_t>(offset) + - static_cast<int64_t>(val_size)) > kOffsetLimit) { - return Status::Invalid("Take operation overflowed binary array capacity"); - } - offset += val_size; - if (ARROW_PREDICT_FALSE(val_size > space_available)) { - RETURN_NOT_OK(data_builder.Reserve(val_size)); - space_available = data_builder.capacity() - data_builder.length(); - } - data_builder.UnsafeAppend(raw_data + val_offset, val_size); - space_available -= val_size; - return Status::OK(); - }, - [&]() { - offset_builder.UnsafeAppend(offset); - return Status::OK(); - })); - offset_builder.UnsafeAppend(offset); - return Status::OK(); - } - - Status Init() override { - ARROW_ASSIGN_OR_RAISE(this->values_as_binary, - GetArrayView(this->values, TypeTraits<Type>::type_singleton())); - return offset_builder.Reserve(output_length + 1); - } - - Status Finish() override { - RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1])); - return data_builder.Finish(&out->buffers[2]); - } -}; - -struct FSBImpl : public Selection<FSBImpl, FixedSizeBinaryType> { - using Base = Selection<FSBImpl, FixedSizeBinaryType>; - LIFT_BASE_MEMBERS(); - - TypedBufferBuilder<uint8_t> data_builder; - - FSBImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, Datum* out) - : Base(ctx, batch, output_length, out), data_builder(ctx->memory_pool()) {} - - template <typename Adapter> - Status GenerateOutput() { - FixedSizeBinaryArray typed_values(this->values); - int32_t value_size = typed_values.byte_width(); - - RETURN_NOT_OK(data_builder.Reserve(value_size * output_length)); - Adapter adapter(this); - return adapter.Generate( - [&](int64_t index) { - auto val = typed_values.GetView(index); - data_builder.UnsafeAppend(reinterpret_cast<const uint8_t*>(val.data()), - value_size); - return Status::OK(); - }, - [&]() { - data_builder.UnsafeAppend(value_size, static_cast<uint8_t>(0x00)); - return Status::OK(); - }); - } - - Status Finish() override { return data_builder.Finish(&out->buffers[1]); } -}; - -template <typename Type> -struct ListImpl : public Selection<ListImpl<Type>, Type> { - using offset_type = typename Type::offset_type; - - using Base = Selection<ListImpl<Type>, Type>; - LIFT_BASE_MEMBERS(); - - TypedBufferBuilder<offset_type> offset_builder; - typename TypeTraits<Type>::OffsetBuilderType child_index_builder; - - ListImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, Datum* out) - : Base(ctx, batch, output_length, out), - offset_builder(ctx->memory_pool()), - child_index_builder(ctx->memory_pool()) {} - - template <typename Adapter> - Status GenerateOutput() { - ValuesArrayType typed_values(this->values); - - // TODO presize child_index_builder with a similar heuristic as VarBinaryImpl - - offset_type offset = 0; - Adapter adapter(this); - RETURN_NOT_OK(adapter.Generate( - [&](int64_t index) { - offset_builder.UnsafeAppend(offset); - offset_type value_offset = typed_values.value_offset(index); - offset_type value_length = typed_values.value_length(index); - offset += value_length; - RETURN_NOT_OK(child_index_builder.Reserve(value_length)); - for (offset_type j = value_offset; j < value_offset + value_length; ++j) { - child_index_builder.UnsafeAppend(j); - } - return Status::OK(); - }, - [&]() { - offset_builder.UnsafeAppend(offset); - return Status::OK(); - })); - offset_builder.UnsafeAppend(offset); - return Status::OK(); - } - - Status Init() override { - RETURN_NOT_OK(offset_builder.Reserve(output_length + 1)); - return Status::OK(); - } - - Status Finish() override { - std::shared_ptr<Array> child_indices; - RETURN_NOT_OK(child_index_builder.Finish(&child_indices)); - - ValuesArrayType typed_values(this->values); - - // No need to boundscheck the child values indices - ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> taken_child, - Take(*typed_values.values(), *child_indices, - TakeOptions::NoBoundsCheck(), ctx->exec_context())); - RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1])); - out->child_data = {taken_child->data()}; - return Status::OK(); - } -}; - +} + +// ---------------------------------------------------------------------- +// Implement take for other data types where there is less performance +// sensitivity by visiting the selected indices. + +// Use CRTP to dispatch to type-specific processing of take indices for each +// unsigned integer type. +template <typename Impl, typename Type> +struct Selection { + using ValuesArrayType = typename TypeTraits<Type>::ArrayType; + + // Forwards the generic value visitors to the take index visitor template + template <typename IndexCType> + struct TakeAdapter { + static constexpr bool is_take = true; + + Impl* impl; + explicit TakeAdapter(Impl* impl) : impl(impl) {} + template <typename ValidVisitor, typename NullVisitor> + Status Generate(ValidVisitor&& visit_valid, NullVisitor&& visit_null) { + return impl->template VisitTake<IndexCType>(std::forward<ValidVisitor>(visit_valid), + std::forward<NullVisitor>(visit_null)); + } + }; + + // Forwards the generic value visitors to the VisitFilter template + struct FilterAdapter { + static constexpr bool is_take = false; + + Impl* impl; + explicit FilterAdapter(Impl* impl) : impl(impl) {} + template <typename ValidVisitor, typename NullVisitor> + Status Generate(ValidVisitor&& visit_valid, NullVisitor&& visit_null) { + return impl->VisitFilter(std::forward<ValidVisitor>(visit_valid), + std::forward<NullVisitor>(visit_null)); + } + }; + + KernelContext* ctx; + std::shared_ptr<ArrayData> values; + std::shared_ptr<ArrayData> selection; + int64_t output_length; + ArrayData* out; + TypedBufferBuilder<bool> validity_builder; + + Selection(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, Datum* out) + : ctx(ctx), + values(batch[0].array()), + selection(batch[1].array()), + output_length(output_length), + out(out->mutable_array()), + validity_builder(ctx->memory_pool()) {} + + virtual ~Selection() = default; + + Status FinishCommon() { + out->buffers.resize(values->buffers.size()); + out->length = validity_builder.length(); + out->null_count = validity_builder.false_count(); + return validity_builder.Finish(&out->buffers[0]); + } + + template <typename IndexCType, typename ValidVisitor, typename NullVisitor> + Status VisitTake(ValidVisitor&& visit_valid, NullVisitor&& visit_null) { + const auto indices_values = selection->GetValues<IndexCType>(1); + const uint8_t* is_valid = GetValidityBitmap(*selection); + OptionalBitIndexer indices_is_valid(selection->buffers[0], selection->offset); + OptionalBitIndexer values_is_valid(values->buffers[0], values->offset); + + const bool values_have_nulls = values->MayHaveNulls(); + OptionalBitBlockCounter bit_counter(is_valid, selection->offset, selection->length); + int64_t position = 0; + while (position < selection->length) { + BitBlockCount block = bit_counter.NextBlock(); + const bool indices_have_nulls = block.popcount < block.length; + if (!indices_have_nulls && !values_have_nulls) { + // Fastest path, neither indices nor values have nulls + validity_builder.UnsafeAppend(block.length, true); + for (int64_t i = 0; i < block.length; ++i) { + RETURN_NOT_OK(visit_valid(indices_values[position++])); + } + } else if (block.popcount > 0) { + // Since we have to branch on whether the indices are null or not, we + // combine the "non-null indices block but some values null" and + // "some-null indices block but values non-null" into a single loop. + for (int64_t i = 0; i < block.length; ++i) { + if ((!indices_have_nulls || indices_is_valid[position]) && + values_is_valid[indices_values[position]]) { + validity_builder.UnsafeAppend(true); + RETURN_NOT_OK(visit_valid(indices_values[position])); + } else { + validity_builder.UnsafeAppend(false); + RETURN_NOT_OK(visit_null()); + } + ++position; + } + } else { + // The whole block is null + validity_builder.UnsafeAppend(block.length, false); + for (int64_t i = 0; i < block.length; ++i) { + RETURN_NOT_OK(visit_null()); + } + position += block.length; + } + } + return Status::OK(); + } + + // We use the NullVisitor both for "selected" nulls as well as "emitted" + // nulls coming from the filter when using FilterOptions::EMIT_NULL + template <typename ValidVisitor, typename NullVisitor> + Status VisitFilter(ValidVisitor&& visit_valid, NullVisitor&& visit_null) { + auto null_selection = FilterState::Get(ctx).null_selection_behavior; + + const auto filter_data = selection->buffers[1]->data(); + + const uint8_t* filter_is_valid = GetValidityBitmap(*selection); + const int64_t filter_offset = selection->offset; + OptionalBitIndexer values_is_valid(values->buffers[0], values->offset); + + // We use 3 block counters for fast scanning of the filter + // + // * values_valid_counter: for values null/not-null + // * filter_valid_counter: for filter null/not-null + // * filter_counter: for filter true/false + OptionalBitBlockCounter values_valid_counter(GetValidityBitmap(*values), + values->offset, values->length); + OptionalBitBlockCounter filter_valid_counter(filter_is_valid, filter_offset, + selection->length); + BitBlockCounter filter_counter(filter_data, filter_offset, selection->length); + int64_t in_position = 0; + + auto AppendNotNull = [&](int64_t index) -> Status { + validity_builder.UnsafeAppend(true); + return visit_valid(index); + }; + + auto AppendNull = [&]() -> Status { + validity_builder.UnsafeAppend(false); + return visit_null(); + }; + + auto AppendMaybeNull = [&](int64_t index) -> Status { + if (values_is_valid[index]) { + return AppendNotNull(index); + } else { + return AppendNull(); + } + }; + + while (in_position < selection->length) { + BitBlockCount filter_valid_block = filter_valid_counter.NextWord(); + BitBlockCount values_valid_block = values_valid_counter.NextWord(); + BitBlockCount filter_block = filter_counter.NextWord(); + if (filter_block.NoneSet() && null_selection == FilterOptions::DROP) { + // For this exceedingly common case in low-selectivity filters we can + // skip further analysis of the data and move on to the next block. + in_position += filter_block.length; + } else if (filter_valid_block.AllSet()) { + // Simpler path: no filter values are null + if (filter_block.AllSet()) { + // Fastest path: filter values are all true and not null + if (values_valid_block.AllSet()) { + // The values aren't null either + validity_builder.UnsafeAppend(filter_block.length, true); + for (int64_t i = 0; i < filter_block.length; ++i) { + RETURN_NOT_OK(visit_valid(in_position++)); + } + } else { + // Some of the values in this block are null + for (int64_t i = 0; i < filter_block.length; ++i) { + RETURN_NOT_OK(AppendMaybeNull(in_position++)); + } + } + } else { // !filter_block.AllSet() + // Some of the filter values are false, but all not null + if (values_valid_block.AllSet()) { + // All the values are not-null, so we can skip null checking for + // them + for (int64_t i = 0; i < filter_block.length; ++i) { + if (BitUtil::GetBit(filter_data, filter_offset + in_position)) { + RETURN_NOT_OK(AppendNotNull(in_position)); + } + ++in_position; + } + } else { + // Some of the values in the block are null, so we have to check + // each one + for (int64_t i = 0; i < filter_block.length; ++i) { + if (BitUtil::GetBit(filter_data, filter_offset + in_position)) { + RETURN_NOT_OK(AppendMaybeNull(in_position)); + } + ++in_position; + } + } + } + } else { // !filter_valid_block.AllSet() + // Some of the filter values are null, so we have to handle the DROP + // versus EMIT_NULL null selection behavior. + if (null_selection == FilterOptions::DROP) { + // Filter null values are treated as false. + for (int64_t i = 0; i < filter_block.length; ++i) { + if (BitUtil::GetBit(filter_is_valid, filter_offset + in_position) && + BitUtil::GetBit(filter_data, filter_offset + in_position)) { + RETURN_NOT_OK(AppendMaybeNull(in_position)); + } + ++in_position; + } + } else { + // Filter null values are appended to output as null whether the + // value in the corresponding slot is valid or not + for (int64_t i = 0; i < filter_block.length; ++i) { + const bool filter_not_null = + BitUtil::GetBit(filter_is_valid, filter_offset + in_position); + if (filter_not_null && + BitUtil::GetBit(filter_data, filter_offset + in_position)) { + RETURN_NOT_OK(AppendMaybeNull(in_position)); + } else if (!filter_not_null) { + // EMIT_NULL case + RETURN_NOT_OK(AppendNull()); + } + ++in_position; + } + } + } + } + return Status::OK(); + } + + virtual Status Init() { return Status::OK(); } + + // Implementation specific finish logic + virtual Status Finish() = 0; + + Status ExecTake() { + RETURN_NOT_OK(this->validity_builder.Reserve(output_length)); + RETURN_NOT_OK(Init()); + int index_width = GetByteWidth(*this->selection->type); + + // CTRP dispatch here + switch (index_width) { + case 1: { + Status s = + static_cast<Impl*>(this)->template GenerateOutput<TakeAdapter<uint8_t>>(); + RETURN_NOT_OK(s); + } break; + case 2: { + Status s = + static_cast<Impl*>(this)->template GenerateOutput<TakeAdapter<uint16_t>>(); + RETURN_NOT_OK(s); + } break; + case 4: { + Status s = + static_cast<Impl*>(this)->template GenerateOutput<TakeAdapter<uint32_t>>(); + RETURN_NOT_OK(s); + } break; + case 8: { + Status s = + static_cast<Impl*>(this)->template GenerateOutput<TakeAdapter<uint64_t>>(); + RETURN_NOT_OK(s); + } break; + default: + DCHECK(false) << "Invalid index width"; + break; + } + RETURN_NOT_OK(this->FinishCommon()); + return Finish(); + } + + Status ExecFilter() { + RETURN_NOT_OK(this->validity_builder.Reserve(output_length)); + RETURN_NOT_OK(Init()); + // CRTP dispatch + Status s = static_cast<Impl*>(this)->template GenerateOutput<FilterAdapter>(); + RETURN_NOT_OK(s); + RETURN_NOT_OK(this->FinishCommon()); + return Finish(); + } +}; + +#define LIFT_BASE_MEMBERS() \ + using ValuesArrayType = typename Base::ValuesArrayType; \ + using Base::ctx; \ + using Base::values; \ + using Base::selection; \ + using Base::output_length; \ + using Base::out; \ + using Base::validity_builder + +static inline Status VisitNoop() { return Status::OK(); } + +// A selection implementation for 32-bit and 64-bit variable binary +// types. Common generated kernels are shared between Binary/String and +// LargeBinary/LargeString +template <typename Type> +struct VarBinaryImpl : public Selection<VarBinaryImpl<Type>, Type> { + using offset_type = typename Type::offset_type; + + using Base = Selection<VarBinaryImpl<Type>, Type>; + LIFT_BASE_MEMBERS(); + + std::shared_ptr<ArrayData> values_as_binary; + TypedBufferBuilder<offset_type> offset_builder; + TypedBufferBuilder<uint8_t> data_builder; + + static constexpr int64_t kOffsetLimit = std::numeric_limits<offset_type>::max() - 1; + + VarBinaryImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, + Datum* out) + : Base(ctx, batch, output_length, out), + offset_builder(ctx->memory_pool()), + data_builder(ctx->memory_pool()) {} + + template <typename Adapter> + Status GenerateOutput() { + ValuesArrayType typed_values(this->values_as_binary); + + // Presize the data builder with a rough estimate of the required data size + if (values->length > 0) { + const double mean_value_length = + (typed_values.total_values_length() / static_cast<double>(values->length)); + + // TODO: See if possible to reduce output_length for take/filter cases + // where there are nulls in the selection array + RETURN_NOT_OK( + data_builder.Reserve(static_cast<int64_t>(mean_value_length * output_length))); + } + int64_t space_available = data_builder.capacity(); + + const offset_type* raw_offsets = typed_values.raw_value_offsets(); + const uint8_t* raw_data = typed_values.raw_data(); + + offset_type offset = 0; + Adapter adapter(this); + RETURN_NOT_OK(adapter.Generate( + [&](int64_t index) { + offset_builder.UnsafeAppend(offset); + offset_type val_offset = raw_offsets[index]; + offset_type val_size = raw_offsets[index + 1] - val_offset; + + // Use static property to prune this code from the filter path in + // optimized builds + if (Adapter::is_take && + ARROW_PREDICT_FALSE(static_cast<int64_t>(offset) + + static_cast<int64_t>(val_size)) > kOffsetLimit) { + return Status::Invalid("Take operation overflowed binary array capacity"); + } + offset += val_size; + if (ARROW_PREDICT_FALSE(val_size > space_available)) { + RETURN_NOT_OK(data_builder.Reserve(val_size)); + space_available = data_builder.capacity() - data_builder.length(); + } + data_builder.UnsafeAppend(raw_data + val_offset, val_size); + space_available -= val_size; + return Status::OK(); + }, + [&]() { + offset_builder.UnsafeAppend(offset); + return Status::OK(); + })); + offset_builder.UnsafeAppend(offset); + return Status::OK(); + } + + Status Init() override { + ARROW_ASSIGN_OR_RAISE(this->values_as_binary, + GetArrayView(this->values, TypeTraits<Type>::type_singleton())); + return offset_builder.Reserve(output_length + 1); + } + + Status Finish() override { + RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1])); + return data_builder.Finish(&out->buffers[2]); + } +}; + +struct FSBImpl : public Selection<FSBImpl, FixedSizeBinaryType> { + using Base = Selection<FSBImpl, FixedSizeBinaryType>; + LIFT_BASE_MEMBERS(); + + TypedBufferBuilder<uint8_t> data_builder; + + FSBImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, Datum* out) + : Base(ctx, batch, output_length, out), data_builder(ctx->memory_pool()) {} + + template <typename Adapter> + Status GenerateOutput() { + FixedSizeBinaryArray typed_values(this->values); + int32_t value_size = typed_values.byte_width(); + + RETURN_NOT_OK(data_builder.Reserve(value_size * output_length)); + Adapter adapter(this); + return adapter.Generate( + [&](int64_t index) { + auto val = typed_values.GetView(index); + data_builder.UnsafeAppend(reinterpret_cast<const uint8_t*>(val.data()), + value_size); + return Status::OK(); + }, + [&]() { + data_builder.UnsafeAppend(value_size, static_cast<uint8_t>(0x00)); + return Status::OK(); + }); + } + + Status Finish() override { return data_builder.Finish(&out->buffers[1]); } +}; + +template <typename Type> +struct ListImpl : public Selection<ListImpl<Type>, Type> { + using offset_type = typename Type::offset_type; + + using Base = Selection<ListImpl<Type>, Type>; + LIFT_BASE_MEMBERS(); + + TypedBufferBuilder<offset_type> offset_builder; + typename TypeTraits<Type>::OffsetBuilderType child_index_builder; + + ListImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, Datum* out) + : Base(ctx, batch, output_length, out), + offset_builder(ctx->memory_pool()), + child_index_builder(ctx->memory_pool()) {} + + template <typename Adapter> + Status GenerateOutput() { + ValuesArrayType typed_values(this->values); + + // TODO presize child_index_builder with a similar heuristic as VarBinaryImpl + + offset_type offset = 0; + Adapter adapter(this); + RETURN_NOT_OK(adapter.Generate( + [&](int64_t index) { + offset_builder.UnsafeAppend(offset); + offset_type value_offset = typed_values.value_offset(index); + offset_type value_length = typed_values.value_length(index); + offset += value_length; + RETURN_NOT_OK(child_index_builder.Reserve(value_length)); + for (offset_type j = value_offset; j < value_offset + value_length; ++j) { + child_index_builder.UnsafeAppend(j); + } + return Status::OK(); + }, + [&]() { + offset_builder.UnsafeAppend(offset); + return Status::OK(); + })); + offset_builder.UnsafeAppend(offset); + return Status::OK(); + } + + Status Init() override { + RETURN_NOT_OK(offset_builder.Reserve(output_length + 1)); + return Status::OK(); + } + + Status Finish() override { + std::shared_ptr<Array> child_indices; + RETURN_NOT_OK(child_index_builder.Finish(&child_indices)); + + ValuesArrayType typed_values(this->values); + + // No need to boundscheck the child values indices + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> taken_child, + Take(*typed_values.values(), *child_indices, + TakeOptions::NoBoundsCheck(), ctx->exec_context())); + RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1])); + out->child_data = {taken_child->data()}; + return Status::OK(); + } +}; + struct DenseUnionImpl : public Selection<DenseUnionImpl, DenseUnionType> { using Base = Selection<DenseUnionImpl, DenseUnionType>; LIFT_BASE_MEMBERS(); @@ -1743,144 +1743,144 @@ struct DenseUnionImpl : public Selection<DenseUnionImpl, DenseUnionType> { } }; -struct FSLImpl : public Selection<FSLImpl, FixedSizeListType> { - Int64Builder child_index_builder; - - using Base = Selection<FSLImpl, FixedSizeListType>; - LIFT_BASE_MEMBERS(); - - FSLImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, Datum* out) - : Base(ctx, batch, output_length, out), child_index_builder(ctx->memory_pool()) {} - - template <typename Adapter> - Status GenerateOutput() { - ValuesArrayType typed_values(this->values); - int32_t list_size = typed_values.list_type()->list_size(); - - /// We must take list_size elements even for null elements of - /// indices. - RETURN_NOT_OK(child_index_builder.Reserve(output_length * list_size)); - - Adapter adapter(this); - return adapter.Generate( - [&](int64_t index) { - int64_t offset = index * list_size; - for (int64_t j = offset; j < offset + list_size; ++j) { - child_index_builder.UnsafeAppend(j); - } - return Status::OK(); - }, - [&]() { return child_index_builder.AppendNulls(list_size); }); - } - - Status Finish() override { - std::shared_ptr<Array> child_indices; - RETURN_NOT_OK(child_index_builder.Finish(&child_indices)); - - ValuesArrayType typed_values(this->values); - - // No need to boundscheck the child values indices - ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> taken_child, - Take(*typed_values.values(), *child_indices, - TakeOptions::NoBoundsCheck(), ctx->exec_context())); - out->child_data = {taken_child->data()}; - return Status::OK(); - } -}; - -// ---------------------------------------------------------------------- -// Struct selection implementations - -// We need a slightly different approach for StructType. For Take, we can -// invoke Take on each struct field's data with boundschecking disabled. For -// Filter on the other hand, if we naively call Filter on each field, then the -// filter output length will have to be redundantly computed. Thus, for Filter -// we instead convert the filter to selection indices and then invoke take. - -// Struct selection implementation. ONLY used for Take -struct StructImpl : public Selection<StructImpl, StructType> { - using Base = Selection<StructImpl, StructType>; - LIFT_BASE_MEMBERS(); - using Base::Base; - - template <typename Adapter> - Status GenerateOutput() { - StructArray typed_values(values); - Adapter adapter(this); - // There's nothing to do for Struct except to generate the validity bitmap - return adapter.Generate([&](int64_t index) { return Status::OK(); }, - /*visit_null=*/VisitNoop); - } - - Status Finish() override { - StructArray typed_values(values); - - // Select from children without boundschecking - out->child_data.resize(values->type->num_fields()); - for (int field_index = 0; field_index < values->type->num_fields(); ++field_index) { - ARROW_ASSIGN_OR_RAISE(Datum taken_field, - Take(Datum(typed_values.field(field_index)), Datum(selection), - TakeOptions::NoBoundsCheck(), ctx->exec_context())); - out->child_data[field_index] = taken_field.array(); - } - return Status::OK(); - } -}; - +struct FSLImpl : public Selection<FSLImpl, FixedSizeListType> { + Int64Builder child_index_builder; + + using Base = Selection<FSLImpl, FixedSizeListType>; + LIFT_BASE_MEMBERS(); + + FSLImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, Datum* out) + : Base(ctx, batch, output_length, out), child_index_builder(ctx->memory_pool()) {} + + template <typename Adapter> + Status GenerateOutput() { + ValuesArrayType typed_values(this->values); + int32_t list_size = typed_values.list_type()->list_size(); + + /// We must take list_size elements even for null elements of + /// indices. + RETURN_NOT_OK(child_index_builder.Reserve(output_length * list_size)); + + Adapter adapter(this); + return adapter.Generate( + [&](int64_t index) { + int64_t offset = index * list_size; + for (int64_t j = offset; j < offset + list_size; ++j) { + child_index_builder.UnsafeAppend(j); + } + return Status::OK(); + }, + [&]() { return child_index_builder.AppendNulls(list_size); }); + } + + Status Finish() override { + std::shared_ptr<Array> child_indices; + RETURN_NOT_OK(child_index_builder.Finish(&child_indices)); + + ValuesArrayType typed_values(this->values); + + // No need to boundscheck the child values indices + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> taken_child, + Take(*typed_values.values(), *child_indices, + TakeOptions::NoBoundsCheck(), ctx->exec_context())); + out->child_data = {taken_child->data()}; + return Status::OK(); + } +}; + +// ---------------------------------------------------------------------- +// Struct selection implementations + +// We need a slightly different approach for StructType. For Take, we can +// invoke Take on each struct field's data with boundschecking disabled. For +// Filter on the other hand, if we naively call Filter on each field, then the +// filter output length will have to be redundantly computed. Thus, for Filter +// we instead convert the filter to selection indices and then invoke take. + +// Struct selection implementation. ONLY used for Take +struct StructImpl : public Selection<StructImpl, StructType> { + using Base = Selection<StructImpl, StructType>; + LIFT_BASE_MEMBERS(); + using Base::Base; + + template <typename Adapter> + Status GenerateOutput() { + StructArray typed_values(values); + Adapter adapter(this); + // There's nothing to do for Struct except to generate the validity bitmap + return adapter.Generate([&](int64_t index) { return Status::OK(); }, + /*visit_null=*/VisitNoop); + } + + Status Finish() override { + StructArray typed_values(values); + + // Select from children without boundschecking + out->child_data.resize(values->type->num_fields()); + for (int field_index = 0; field_index < values->type->num_fields(); ++field_index) { + ARROW_ASSIGN_OR_RAISE(Datum taken_field, + Take(Datum(typed_values.field(field_index)), Datum(selection), + TakeOptions::NoBoundsCheck(), ctx->exec_context())); + out->child_data[field_index] = taken_field.array(); + } + return Status::OK(); + } +}; + Status StructFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // Transform filter to selection indices and then use Take. - std::shared_ptr<ArrayData> indices; + // Transform filter to selection indices and then use Take. + std::shared_ptr<ArrayData> indices; RETURN_NOT_OK(GetTakeIndices(*batch[1].array(), FilterState::Get(ctx).null_selection_behavior, ctx->memory_pool()) .Value(&indices)); - - Datum result; + + Datum result; RETURN_NOT_OK( Take(batch[0], Datum(indices), TakeOptions::NoBoundsCheck(), ctx->exec_context()) .Value(&result)); - out->value = result.array(); + out->value = result.array(); return Status::OK(); -} - -#undef LIFT_BASE_MEMBERS - -// ---------------------------------------------------------------------- -// Implement Filter metafunction - -Result<std::shared_ptr<RecordBatch>> FilterRecordBatch(const RecordBatch& batch, - const Datum& filter, - const FunctionOptions* options, - ExecContext* ctx) { - if (batch.num_rows() != filter.length()) { - return Status::Invalid("Filter inputs must all be the same length"); - } - - // Convert filter to selection vector/indices and use Take - const auto& filter_opts = *static_cast<const FilterOptions*>(options); - ARROW_ASSIGN_OR_RAISE( - std::shared_ptr<ArrayData> indices, +} + +#undef LIFT_BASE_MEMBERS + +// ---------------------------------------------------------------------- +// Implement Filter metafunction + +Result<std::shared_ptr<RecordBatch>> FilterRecordBatch(const RecordBatch& batch, + const Datum& filter, + const FunctionOptions* options, + ExecContext* ctx) { + if (batch.num_rows() != filter.length()) { + return Status::Invalid("Filter inputs must all be the same length"); + } + + // Convert filter to selection vector/indices and use Take + const auto& filter_opts = *static_cast<const FilterOptions*>(options); + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr<ArrayData> indices, GetTakeIndices(*filter.array(), filter_opts.null_selection_behavior, ctx->memory_pool())); - std::vector<std::shared_ptr<Array>> columns(batch.num_columns()); - for (int i = 0; i < batch.num_columns(); ++i) { - ARROW_ASSIGN_OR_RAISE(Datum out, Take(batch.column(i)->data(), Datum(indices), - TakeOptions::NoBoundsCheck(), ctx)); - columns[i] = out.make_array(); - } + std::vector<std::shared_ptr<Array>> columns(batch.num_columns()); + for (int i = 0; i < batch.num_columns(); ++i) { + ARROW_ASSIGN_OR_RAISE(Datum out, Take(batch.column(i)->data(), Datum(indices), + TakeOptions::NoBoundsCheck(), ctx)); + columns[i] = out.make_array(); + } return RecordBatch::Make(batch.schema(), indices->length, std::move(columns)); -} - -Result<std::shared_ptr<Table>> FilterTable(const Table& table, const Datum& filter, - const FunctionOptions* options, - ExecContext* ctx) { - if (table.num_rows() != filter.length()) { - return Status::Invalid("Filter inputs must all be the same length"); - } +} + +Result<std::shared_ptr<Table>> FilterTable(const Table& table, const Datum& filter, + const FunctionOptions* options, + ExecContext* ctx) { + if (table.num_rows() != filter.length()) { + return Status::Invalid("Filter inputs must all be the same length"); + } if (table.num_rows() == 0) { return Table::Make(table.schema(), table.columns(), 0); } - + // Last input element will be the filter array const int num_columns = table.num_columns(); std::vector<ArrayVector> inputs(num_columns + 1); @@ -1914,7 +1914,7 @@ Result<std::shared_ptr<Table>> FilterTable(const Table& table, const Datum& filt for (int64_t i = 0; i < num_chunks; ++i) { const ArrayData& filter_chunk = *inputs.back()[i]->data(); - ARROW_ASSIGN_OR_RAISE( + ARROW_ASSIGN_OR_RAISE( const auto indices, GetTakeIndices(filter_chunk, filter_opts.null_selection_behavior, ctx->memory_pool())); @@ -1930,7 +1930,7 @@ Result<std::shared_ptr<Table>> FilterTable(const Table& table, const Datum& filt } out_num_rows += indices->length; } - } + } ChunkedArrayVector out_chunks(num_columns); for (int i = 0; i < num_columns; ++i) { @@ -1938,10 +1938,10 @@ Result<std::shared_ptr<Table>> FilterTable(const Table& table, const Datum& filt table.column(i)->type()); } return Table::Make(table.schema(), std::move(out_chunks), out_num_rows); -} - -static auto kDefaultFilterOptions = FilterOptions::Defaults(); - +} + +static auto kDefaultFilterOptions = FilterOptions::Defaults(); + const FunctionDoc filter_doc( "Filter with a boolean selection filter", ("The output is populated with values from the input at positions\n" @@ -1949,244 +1949,244 @@ const FunctionDoc filter_doc( "are handled based on FilterOptions."), {"input", "selection_filter"}, "FilterOptions"); -class FilterMetaFunction : public MetaFunction { - public: - FilterMetaFunction() +class FilterMetaFunction : public MetaFunction { + public: + FilterMetaFunction() : MetaFunction("filter", Arity::Binary(), &filter_doc, &kDefaultFilterOptions) {} - - Result<Datum> ExecuteImpl(const std::vector<Datum>& args, - const FunctionOptions* options, - ExecContext* ctx) const override { - if (args[1].type()->id() != Type::BOOL) { - return Status::NotImplemented("Filter argument must be boolean type"); - } - - if (args[0].kind() == Datum::RECORD_BATCH) { - auto values_batch = args[0].record_batch(); - ARROW_ASSIGN_OR_RAISE( - std::shared_ptr<RecordBatch> out_batch, - FilterRecordBatch(*args[0].record_batch(), args[1], options, ctx)); - return Datum(out_batch); - } else if (args[0].kind() == Datum::TABLE) { - ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Table> out_table, - FilterTable(*args[0].table(), args[1], options, ctx)); - return Datum(out_table); - } else { - return CallFunction("array_filter", args, options, ctx); - } - } -}; - -// ---------------------------------------------------------------------- -// Implement Take metafunction - -// Shorthand naming of these functions -// A -> Array -// C -> ChunkedArray -// R -> RecordBatch -// T -> Table - -Result<std::shared_ptr<Array>> TakeAA(const Array& values, const Array& indices, - const TakeOptions& options, ExecContext* ctx) { - ARROW_ASSIGN_OR_RAISE(Datum result, - CallFunction("array_take", {values, indices}, &options, ctx)); - return result.make_array(); -} - -Result<std::shared_ptr<ChunkedArray>> TakeCA(const ChunkedArray& values, - const Array& indices, - const TakeOptions& options, - ExecContext* ctx) { - auto num_chunks = values.num_chunks(); - std::vector<std::shared_ptr<Array>> new_chunks(1); // Hard-coded 1 for now - std::shared_ptr<Array> current_chunk; - - // Case 1: `values` has a single chunk, so just use it - if (num_chunks == 1) { - current_chunk = values.chunk(0); - } else { - // TODO Case 2: See if all `indices` fall in the same chunk and call Array Take on it - // See - // https://github.com/apache/arrow/blob/6f2c9041137001f7a9212f244b51bc004efc29af/r/src/compute.cpp#L123-L151 - // TODO Case 3: If indices are sorted, can slice them and call Array Take - - // Case 4: Else, concatenate chunks and call Array Take - ARROW_ASSIGN_OR_RAISE(current_chunk, - Concatenate(values.chunks(), ctx->memory_pool())); - } - // Call Array Take on our single chunk - ARROW_ASSIGN_OR_RAISE(new_chunks[0], TakeAA(*current_chunk, indices, options, ctx)); - return std::make_shared<ChunkedArray>(std::move(new_chunks)); -} - -Result<std::shared_ptr<ChunkedArray>> TakeCC(const ChunkedArray& values, - const ChunkedArray& indices, - const TakeOptions& options, - ExecContext* ctx) { - auto num_chunks = indices.num_chunks(); - std::vector<std::shared_ptr<Array>> new_chunks(num_chunks); - for (int i = 0; i < num_chunks; i++) { - // Take with that indices chunk - // Note that as currently implemented, this is inefficient because `values` - // will get concatenated on every iteration of this loop - ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ChunkedArray> current_chunk, - TakeCA(values, *indices.chunk(i), options, ctx)); - // Concatenate the result to make a single array for this chunk - ARROW_ASSIGN_OR_RAISE(new_chunks[i], - Concatenate(current_chunk->chunks(), ctx->memory_pool())); - } - return std::make_shared<ChunkedArray>(std::move(new_chunks)); -} - -Result<std::shared_ptr<ChunkedArray>> TakeAC(const Array& values, - const ChunkedArray& indices, - const TakeOptions& options, - ExecContext* ctx) { - auto num_chunks = indices.num_chunks(); - std::vector<std::shared_ptr<Array>> new_chunks(num_chunks); - for (int i = 0; i < num_chunks; i++) { - // Take with that indices chunk - ARROW_ASSIGN_OR_RAISE(new_chunks[i], TakeAA(values, *indices.chunk(i), options, ctx)); - } - return std::make_shared<ChunkedArray>(std::move(new_chunks)); -} - -Result<std::shared_ptr<RecordBatch>> TakeRA(const RecordBatch& batch, - const Array& indices, - const TakeOptions& options, - ExecContext* ctx) { - auto ncols = batch.num_columns(); - auto nrows = indices.length(); - std::vector<std::shared_ptr<Array>> columns(ncols); - for (int j = 0; j < ncols; j++) { - ARROW_ASSIGN_OR_RAISE(columns[j], TakeAA(*batch.column(j), indices, options, ctx)); - } + + Result<Datum> ExecuteImpl(const std::vector<Datum>& args, + const FunctionOptions* options, + ExecContext* ctx) const override { + if (args[1].type()->id() != Type::BOOL) { + return Status::NotImplemented("Filter argument must be boolean type"); + } + + if (args[0].kind() == Datum::RECORD_BATCH) { + auto values_batch = args[0].record_batch(); + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr<RecordBatch> out_batch, + FilterRecordBatch(*args[0].record_batch(), args[1], options, ctx)); + return Datum(out_batch); + } else if (args[0].kind() == Datum::TABLE) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Table> out_table, + FilterTable(*args[0].table(), args[1], options, ctx)); + return Datum(out_table); + } else { + return CallFunction("array_filter", args, options, ctx); + } + } +}; + +// ---------------------------------------------------------------------- +// Implement Take metafunction + +// Shorthand naming of these functions +// A -> Array +// C -> ChunkedArray +// R -> RecordBatch +// T -> Table + +Result<std::shared_ptr<Array>> TakeAA(const Array& values, const Array& indices, + const TakeOptions& options, ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, + CallFunction("array_take", {values, indices}, &options, ctx)); + return result.make_array(); +} + +Result<std::shared_ptr<ChunkedArray>> TakeCA(const ChunkedArray& values, + const Array& indices, + const TakeOptions& options, + ExecContext* ctx) { + auto num_chunks = values.num_chunks(); + std::vector<std::shared_ptr<Array>> new_chunks(1); // Hard-coded 1 for now + std::shared_ptr<Array> current_chunk; + + // Case 1: `values` has a single chunk, so just use it + if (num_chunks == 1) { + current_chunk = values.chunk(0); + } else { + // TODO Case 2: See if all `indices` fall in the same chunk and call Array Take on it + // See + // https://github.com/apache/arrow/blob/6f2c9041137001f7a9212f244b51bc004efc29af/r/src/compute.cpp#L123-L151 + // TODO Case 3: If indices are sorted, can slice them and call Array Take + + // Case 4: Else, concatenate chunks and call Array Take + ARROW_ASSIGN_OR_RAISE(current_chunk, + Concatenate(values.chunks(), ctx->memory_pool())); + } + // Call Array Take on our single chunk + ARROW_ASSIGN_OR_RAISE(new_chunks[0], TakeAA(*current_chunk, indices, options, ctx)); + return std::make_shared<ChunkedArray>(std::move(new_chunks)); +} + +Result<std::shared_ptr<ChunkedArray>> TakeCC(const ChunkedArray& values, + const ChunkedArray& indices, + const TakeOptions& options, + ExecContext* ctx) { + auto num_chunks = indices.num_chunks(); + std::vector<std::shared_ptr<Array>> new_chunks(num_chunks); + for (int i = 0; i < num_chunks; i++) { + // Take with that indices chunk + // Note that as currently implemented, this is inefficient because `values` + // will get concatenated on every iteration of this loop + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ChunkedArray> current_chunk, + TakeCA(values, *indices.chunk(i), options, ctx)); + // Concatenate the result to make a single array for this chunk + ARROW_ASSIGN_OR_RAISE(new_chunks[i], + Concatenate(current_chunk->chunks(), ctx->memory_pool())); + } + return std::make_shared<ChunkedArray>(std::move(new_chunks)); +} + +Result<std::shared_ptr<ChunkedArray>> TakeAC(const Array& values, + const ChunkedArray& indices, + const TakeOptions& options, + ExecContext* ctx) { + auto num_chunks = indices.num_chunks(); + std::vector<std::shared_ptr<Array>> new_chunks(num_chunks); + for (int i = 0; i < num_chunks; i++) { + // Take with that indices chunk + ARROW_ASSIGN_OR_RAISE(new_chunks[i], TakeAA(values, *indices.chunk(i), options, ctx)); + } + return std::make_shared<ChunkedArray>(std::move(new_chunks)); +} + +Result<std::shared_ptr<RecordBatch>> TakeRA(const RecordBatch& batch, + const Array& indices, + const TakeOptions& options, + ExecContext* ctx) { + auto ncols = batch.num_columns(); + auto nrows = indices.length(); + std::vector<std::shared_ptr<Array>> columns(ncols); + for (int j = 0; j < ncols; j++) { + ARROW_ASSIGN_OR_RAISE(columns[j], TakeAA(*batch.column(j), indices, options, ctx)); + } return RecordBatch::Make(batch.schema(), nrows, std::move(columns)); -} - -Result<std::shared_ptr<Table>> TakeTA(const Table& table, const Array& indices, - const TakeOptions& options, ExecContext* ctx) { - auto ncols = table.num_columns(); - std::vector<std::shared_ptr<ChunkedArray>> columns(ncols); - - for (int j = 0; j < ncols; j++) { - ARROW_ASSIGN_OR_RAISE(columns[j], TakeCA(*table.column(j), indices, options, ctx)); - } +} + +Result<std::shared_ptr<Table>> TakeTA(const Table& table, const Array& indices, + const TakeOptions& options, ExecContext* ctx) { + auto ncols = table.num_columns(); + std::vector<std::shared_ptr<ChunkedArray>> columns(ncols); + + for (int j = 0; j < ncols; j++) { + ARROW_ASSIGN_OR_RAISE(columns[j], TakeCA(*table.column(j), indices, options, ctx)); + } return Table::Make(table.schema(), std::move(columns)); -} - -Result<std::shared_ptr<Table>> TakeTC(const Table& table, const ChunkedArray& indices, - const TakeOptions& options, ExecContext* ctx) { - auto ncols = table.num_columns(); - std::vector<std::shared_ptr<ChunkedArray>> columns(ncols); - for (int j = 0; j < ncols; j++) { - ARROW_ASSIGN_OR_RAISE(columns[j], TakeCC(*table.column(j), indices, options, ctx)); - } +} + +Result<std::shared_ptr<Table>> TakeTC(const Table& table, const ChunkedArray& indices, + const TakeOptions& options, ExecContext* ctx) { + auto ncols = table.num_columns(); + std::vector<std::shared_ptr<ChunkedArray>> columns(ncols); + for (int j = 0; j < ncols; j++) { + ARROW_ASSIGN_OR_RAISE(columns[j], TakeCC(*table.column(j), indices, options, ctx)); + } return Table::Make(table.schema(), std::move(columns)); -} - -static auto kDefaultTakeOptions = TakeOptions::Defaults(); - +} + +static auto kDefaultTakeOptions = TakeOptions::Defaults(); + const FunctionDoc take_doc( "Select values from an input based on indices from another array", ("The output is populated with values from the input at positions\n" "given by `indices`. Nulls in `indices` emit null in the output."), {"input", "indices"}, "TakeOptions"); -// Metafunction for dispatching to different Take implementations other than -// Array-Array. -// -// TODO: Revamp approach to executing Take operations. In addition to being -// overly complex dispatching, there is no parallelization. -class TakeMetaFunction : public MetaFunction { - public: +// Metafunction for dispatching to different Take implementations other than +// Array-Array. +// +// TODO: Revamp approach to executing Take operations. In addition to being +// overly complex dispatching, there is no parallelization. +class TakeMetaFunction : public MetaFunction { + public: TakeMetaFunction() : MetaFunction("take", Arity::Binary(), &take_doc, &kDefaultTakeOptions) {} - - Result<Datum> ExecuteImpl(const std::vector<Datum>& args, - const FunctionOptions* options, - ExecContext* ctx) const override { - Datum::Kind index_kind = args[1].kind(); - const TakeOptions& take_opts = static_cast<const TakeOptions&>(*options); - switch (args[0].kind()) { - case Datum::ARRAY: - if (index_kind == Datum::ARRAY) { - return TakeAA(*args[0].make_array(), *args[1].make_array(), take_opts, ctx); - } else if (index_kind == Datum::CHUNKED_ARRAY) { - return TakeAC(*args[0].make_array(), *args[1].chunked_array(), take_opts, ctx); - } - break; - case Datum::CHUNKED_ARRAY: - if (index_kind == Datum::ARRAY) { - return TakeCA(*args[0].chunked_array(), *args[1].make_array(), take_opts, ctx); - } else if (index_kind == Datum::CHUNKED_ARRAY) { - return TakeCC(*args[0].chunked_array(), *args[1].chunked_array(), take_opts, - ctx); - } - break; - case Datum::RECORD_BATCH: - if (index_kind == Datum::ARRAY) { - return TakeRA(*args[0].record_batch(), *args[1].make_array(), take_opts, ctx); - } - break; - case Datum::TABLE: - if (index_kind == Datum::ARRAY) { - return TakeTA(*args[0].table(), *args[1].make_array(), take_opts, ctx); - } else if (index_kind == Datum::CHUNKED_ARRAY) { - return TakeTC(*args[0].table(), *args[1].chunked_array(), take_opts, ctx); - } - break; - default: - break; - } - return Status::NotImplemented( - "Unsupported types for take operation: " - "values=", - args[0].ToString(), "indices=", args[1].ToString()); - } -}; - -// ---------------------------------------------------------------------- - -template <typename Impl> + + Result<Datum> ExecuteImpl(const std::vector<Datum>& args, + const FunctionOptions* options, + ExecContext* ctx) const override { + Datum::Kind index_kind = args[1].kind(); + const TakeOptions& take_opts = static_cast<const TakeOptions&>(*options); + switch (args[0].kind()) { + case Datum::ARRAY: + if (index_kind == Datum::ARRAY) { + return TakeAA(*args[0].make_array(), *args[1].make_array(), take_opts, ctx); + } else if (index_kind == Datum::CHUNKED_ARRAY) { + return TakeAC(*args[0].make_array(), *args[1].chunked_array(), take_opts, ctx); + } + break; + case Datum::CHUNKED_ARRAY: + if (index_kind == Datum::ARRAY) { + return TakeCA(*args[0].chunked_array(), *args[1].make_array(), take_opts, ctx); + } else if (index_kind == Datum::CHUNKED_ARRAY) { + return TakeCC(*args[0].chunked_array(), *args[1].chunked_array(), take_opts, + ctx); + } + break; + case Datum::RECORD_BATCH: + if (index_kind == Datum::ARRAY) { + return TakeRA(*args[0].record_batch(), *args[1].make_array(), take_opts, ctx); + } + break; + case Datum::TABLE: + if (index_kind == Datum::ARRAY) { + return TakeTA(*args[0].table(), *args[1].make_array(), take_opts, ctx); + } else if (index_kind == Datum::CHUNKED_ARRAY) { + return TakeTC(*args[0].table(), *args[1].chunked_array(), take_opts, ctx); + } + break; + default: + break; + } + return Status::NotImplemented( + "Unsupported types for take operation: " + "values=", + args[0].ToString(), "indices=", args[1].ToString()); + } +}; + +// ---------------------------------------------------------------------- + +template <typename Impl> Status FilterExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: where are the values and filter length equality checked? - int64_t output_length = GetFilterOutputSize( - *batch[1].array(), FilterState::Get(ctx).null_selection_behavior); - Impl kernel(ctx, batch, output_length, out); + // TODO: where are the values and filter length equality checked? + int64_t output_length = GetFilterOutputSize( + *batch[1].array(), FilterState::Get(ctx).null_selection_behavior); + Impl kernel(ctx, batch, output_length, out); return kernel.ExecFilter(); -} - -template <typename Impl> +} + +template <typename Impl> Status TakeExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - if (TakeState::Get(ctx).boundscheck) { + if (TakeState::Get(ctx).boundscheck) { RETURN_NOT_OK(CheckIndexBounds(*batch[1].array(), batch[0].length())); - } - Impl kernel(ctx, batch, /*output_length=*/batch[1].length(), out); + } + Impl kernel(ctx, batch, /*output_length=*/batch[1].length(), out); return kernel.ExecTake(); -} - -struct SelectionKernelDescr { - InputType input; - ArrayKernelExec exec; -}; - +} + +struct SelectionKernelDescr { + InputType input; + ArrayKernelExec exec; +}; + void RegisterSelectionFunction(const std::string& name, const FunctionDoc* doc, VectorKernel base_kernel, InputType selection_type, - const std::vector<SelectionKernelDescr>& descrs, - const FunctionOptions* default_options, - FunctionRegistry* registry) { + const std::vector<SelectionKernelDescr>& descrs, + const FunctionOptions* default_options, + FunctionRegistry* registry) { auto func = std::make_shared<VectorFunction>(name, Arity::Binary(), doc, default_options); - for (auto& descr : descrs) { - base_kernel.signature = KernelSignature::Make( - {std::move(descr.input), selection_type}, OutputType(FirstType)); - base_kernel.exec = descr.exec; - DCHECK_OK(func->AddKernel(base_kernel)); - } - DCHECK_OK(registry->AddFunction(std::move(func))); -} - + for (auto& descr : descrs) { + base_kernel.signature = KernelSignature::Make( + {std::move(descr.input), selection_type}, OutputType(FirstType)); + base_kernel.exec = descr.exec; + DCHECK_OK(func->AddKernel(base_kernel)); + } + DCHECK_OK(registry->AddFunction(std::move(func))); +} + const FunctionDoc array_filter_doc( "Filter with a boolean selection filter", ("The output is populated with values from the input `array` at positions\n" @@ -2200,69 +2200,69 @@ const FunctionDoc array_take_doc( "given by `indices`. Nulls in `indices` emit null in the output."), {"array", "indices"}, "TakeOptions"); -} // namespace - -void RegisterVectorSelection(FunctionRegistry* registry) { - // Filter kernels - std::vector<SelectionKernelDescr> filter_kernel_descrs = { - {InputType(match::Primitive(), ValueDescr::ARRAY), PrimitiveFilter}, - {InputType(match::BinaryLike(), ValueDescr::ARRAY), BinaryFilter}, - {InputType(match::LargeBinaryLike(), ValueDescr::ARRAY), BinaryFilter}, - {InputType::Array(Type::FIXED_SIZE_BINARY), FilterExec<FSBImpl>}, - {InputType::Array(null()), NullFilter}, - {InputType::Array(Type::DECIMAL), FilterExec<FSBImpl>}, - {InputType::Array(Type::DICTIONARY), DictionaryFilter}, - {InputType::Array(Type::EXTENSION), ExtensionFilter}, - {InputType::Array(Type::LIST), FilterExec<ListImpl<ListType>>}, - {InputType::Array(Type::LARGE_LIST), FilterExec<ListImpl<LargeListType>>}, - {InputType::Array(Type::FIXED_SIZE_LIST), FilterExec<FSLImpl>}, +} // namespace + +void RegisterVectorSelection(FunctionRegistry* registry) { + // Filter kernels + std::vector<SelectionKernelDescr> filter_kernel_descrs = { + {InputType(match::Primitive(), ValueDescr::ARRAY), PrimitiveFilter}, + {InputType(match::BinaryLike(), ValueDescr::ARRAY), BinaryFilter}, + {InputType(match::LargeBinaryLike(), ValueDescr::ARRAY), BinaryFilter}, + {InputType::Array(Type::FIXED_SIZE_BINARY), FilterExec<FSBImpl>}, + {InputType::Array(null()), NullFilter}, + {InputType::Array(Type::DECIMAL), FilterExec<FSBImpl>}, + {InputType::Array(Type::DICTIONARY), DictionaryFilter}, + {InputType::Array(Type::EXTENSION), ExtensionFilter}, + {InputType::Array(Type::LIST), FilterExec<ListImpl<ListType>>}, + {InputType::Array(Type::LARGE_LIST), FilterExec<ListImpl<LargeListType>>}, + {InputType::Array(Type::FIXED_SIZE_LIST), FilterExec<FSLImpl>}, {InputType::Array(Type::DENSE_UNION), FilterExec<DenseUnionImpl>}, - {InputType::Array(Type::STRUCT), StructFilter}, - // TODO: Reuse ListType kernel for MAP - {InputType::Array(Type::MAP), FilterExec<ListImpl<MapType>>}, - }; - - VectorKernel filter_base; - filter_base.init = FilterState::Init; + {InputType::Array(Type::STRUCT), StructFilter}, + // TODO: Reuse ListType kernel for MAP + {InputType::Array(Type::MAP), FilterExec<ListImpl<MapType>>}, + }; + + VectorKernel filter_base; + filter_base.init = FilterState::Init; RegisterSelectionFunction("array_filter", &array_filter_doc, filter_base, - /*selection_type=*/InputType::Array(boolean()), - filter_kernel_descrs, &kDefaultFilterOptions, registry); - - DCHECK_OK(registry->AddFunction(std::make_shared<FilterMetaFunction>())); - - // Take kernels - std::vector<SelectionKernelDescr> take_kernel_descrs = { - {InputType(match::Primitive(), ValueDescr::ARRAY), PrimitiveTake}, - {InputType(match::BinaryLike(), ValueDescr::ARRAY), - TakeExec<VarBinaryImpl<BinaryType>>}, - {InputType(match::LargeBinaryLike(), ValueDescr::ARRAY), - TakeExec<VarBinaryImpl<LargeBinaryType>>}, - {InputType::Array(Type::FIXED_SIZE_BINARY), TakeExec<FSBImpl>}, - {InputType::Array(null()), NullTake}, + /*selection_type=*/InputType::Array(boolean()), + filter_kernel_descrs, &kDefaultFilterOptions, registry); + + DCHECK_OK(registry->AddFunction(std::make_shared<FilterMetaFunction>())); + + // Take kernels + std::vector<SelectionKernelDescr> take_kernel_descrs = { + {InputType(match::Primitive(), ValueDescr::ARRAY), PrimitiveTake}, + {InputType(match::BinaryLike(), ValueDescr::ARRAY), + TakeExec<VarBinaryImpl<BinaryType>>}, + {InputType(match::LargeBinaryLike(), ValueDescr::ARRAY), + TakeExec<VarBinaryImpl<LargeBinaryType>>}, + {InputType::Array(Type::FIXED_SIZE_BINARY), TakeExec<FSBImpl>}, + {InputType::Array(null()), NullTake}, {InputType::Array(Type::DECIMAL128), TakeExec<FSBImpl>}, {InputType::Array(Type::DECIMAL256), TakeExec<FSBImpl>}, - {InputType::Array(Type::DICTIONARY), DictionaryTake}, - {InputType::Array(Type::EXTENSION), ExtensionTake}, - {InputType::Array(Type::LIST), TakeExec<ListImpl<ListType>>}, - {InputType::Array(Type::LARGE_LIST), TakeExec<ListImpl<LargeListType>>}, - {InputType::Array(Type::FIXED_SIZE_LIST), TakeExec<FSLImpl>}, + {InputType::Array(Type::DICTIONARY), DictionaryTake}, + {InputType::Array(Type::EXTENSION), ExtensionTake}, + {InputType::Array(Type::LIST), TakeExec<ListImpl<ListType>>}, + {InputType::Array(Type::LARGE_LIST), TakeExec<ListImpl<LargeListType>>}, + {InputType::Array(Type::FIXED_SIZE_LIST), TakeExec<FSLImpl>}, {InputType::Array(Type::DENSE_UNION), TakeExec<DenseUnionImpl>}, - {InputType::Array(Type::STRUCT), TakeExec<StructImpl>}, - // TODO: Reuse ListType kernel for MAP - {InputType::Array(Type::MAP), TakeExec<ListImpl<MapType>>}, - }; - - VectorKernel take_base; - take_base.init = TakeState::Init; - take_base.can_execute_chunkwise = false; - RegisterSelectionFunction( + {InputType::Array(Type::STRUCT), TakeExec<StructImpl>}, + // TODO: Reuse ListType kernel for MAP + {InputType::Array(Type::MAP), TakeExec<ListImpl<MapType>>}, + }; + + VectorKernel take_base; + take_base.init = TakeState::Init; + take_base.can_execute_chunkwise = false; + RegisterSelectionFunction( "array_take", &array_take_doc, take_base, - /*selection_type=*/InputType(match::Integer(), ValueDescr::ARRAY), - take_kernel_descrs, &kDefaultTakeOptions, registry); - - DCHECK_OK(registry->AddFunction(std::make_shared<TakeMetaFunction>())); -} - -} // namespace internal -} // namespace compute -} // namespace arrow + /*selection_type=*/InputType(match::Integer(), ValueDescr::ARRAY), + take_kernel_descrs, &kDefaultTakeOptions, registry); + + DCHECK_OK(registry->AddFunction(std::make_shared<TakeMetaFunction>())); +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/vector_sort.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/vector_sort.cc index 7fa43e715d..b7e7adc70e 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -1,30 +1,30 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include <algorithm> +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <algorithm> #include <cmath> -#include <limits> -#include <numeric> +#include <limits> +#include <numeric> #include <type_traits> #include <utility> - -#include "arrow/array/data.h" -#include "arrow/compute/api_vector.h" -#include "arrow/compute/kernels/common.h" + +#include "arrow/array/data.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/kernels/common.h" #include "arrow/compute/kernels/util_internal.h" #include "arrow/table.h" #include "arrow/type_traits.h" @@ -32,16 +32,16 @@ #include "arrow/util/bitmap.h" #include "arrow/util/bitmap_ops.h" #include "arrow/util/checked_cast.h" -#include "arrow/util/optional.h" +#include "arrow/util/optional.h" #include "arrow/visitor_inline.h" - -namespace arrow { + +namespace arrow { using internal::checked_cast; -namespace compute { +namespace compute { namespace internal { - + // Visit all physical types for which sorting is implemented. #define VISIT_PHYSICAL_TYPES(VISIT) \ VISIT(BooleanType) \ @@ -61,8 +61,8 @@ namespace internal { VISIT(Decimal128Type) \ VISIT(Decimal256Type) -namespace { - +namespace { + // The target chunk in a chunked array. template <typename ArrayType> struct ResolvedChunk { @@ -315,59 +315,59 @@ uint64_t* PartitionNulls(uint64_t* indices_begin, uint64_t* indices_end, null_count); } -// ---------------------------------------------------------------------- -// partition_nth_indices implementation - -// We need to preserve the options -using PartitionNthToIndicesState = internal::OptionsWrapper<PartitionNthOptions>; - -template <typename OutType, typename InType> -struct PartitionNthToIndices { - using ArrayType = typename TypeTraits<InType>::ArrayType; +// ---------------------------------------------------------------------- +// partition_nth_indices implementation + +// We need to preserve the options +using PartitionNthToIndicesState = internal::OptionsWrapper<PartitionNthOptions>; + +template <typename OutType, typename InType> +struct PartitionNthToIndices { + using ArrayType = typename TypeTraits<InType>::ArrayType; static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { using GetView = GetViewType<InType>; - if (ctx->state() == nullptr) { + if (ctx->state() == nullptr) { return Status::Invalid("NthToIndices requires PartitionNthOptions"); - } - + } + ArrayType arr(batch[0].array()); - - int64_t pivot = PartitionNthToIndicesState::Get(ctx).pivot; - if (pivot > arr.length()) { + + int64_t pivot = PartitionNthToIndicesState::Get(ctx).pivot; + if (pivot > arr.length()) { return Status::IndexError("NthToIndices index out of bound"); - } - ArrayData* out_arr = out->mutable_array(); - uint64_t* out_begin = out_arr->GetMutableValues<uint64_t>(1); - uint64_t* out_end = out_begin + arr.length(); - std::iota(out_begin, out_end, 0); - if (pivot == arr.length()) { + } + ArrayData* out_arr = out->mutable_array(); + uint64_t* out_begin = out_arr->GetMutableValues<uint64_t>(1); + uint64_t* out_end = out_begin + arr.length(); + std::iota(out_begin, out_end, 0); + if (pivot == arr.length()) { return Status::OK(); - } + } auto nulls_begin = PartitionNulls<ArrayType, NonStablePartitioner>(out_begin, out_end, arr, 0); - auto nth_begin = out_begin + pivot; - if (nth_begin < nulls_begin) { - std::nth_element(out_begin, nth_begin, nulls_begin, - [&arr](uint64_t left, uint64_t right) { + auto nth_begin = out_begin + pivot; + if (nth_begin < nulls_begin) { + std::nth_element(out_begin, nth_begin, nulls_begin, + [&arr](uint64_t left, uint64_t right) { const auto lval = GetView::LogicalValue(arr.GetView(left)); const auto rval = GetView::LogicalValue(arr.GetView(right)); return lval < rval; - }); - } + }); + } return Status::OK(); - } -}; - + } +}; + // ---------------------------------------------------------------------- // Array sorting implementations -template <typename ArrayType, typename VisitorNotNull, typename VisitorNull> -inline void VisitRawValuesInline(const ArrayType& values, - VisitorNotNull&& visitor_not_null, - VisitorNull&& visitor_null) { - const auto data = values.raw_values(); +template <typename ArrayType, typename VisitorNotNull, typename VisitorNull> +inline void VisitRawValuesInline(const ArrayType& values, + VisitorNotNull&& visitor_not_null, + VisitorNull&& visitor_null) { + const auto data = values.raw_values(); VisitBitBlocksVoid( values.null_bitmap(), values.offset(), values.length(), [&](int64_t i) { visitor_not_null(data[i]); }, [&]() { visitor_null(); }); @@ -383,20 +383,20 @@ inline void VisitRawValuesInline(const BooleanArray& values, values.null_bitmap(), values.offset(), values.length(), [&](int64_t i) { visitor_not_null(BitUtil::GetBit(data, values.offset() + i)); }, [&]() { visitor_null(); }); - } else { + } else { // Can avoid GetBit() overhead in the no-nulls case VisitBitBlocksVoid( values.data()->buffers[1], values.offset(), values.length(), [&](int64_t i) { visitor_not_null(true); }, [&]() { visitor_not_null(false); }); - } -} - -template <typename ArrowType> + } +} + +template <typename ArrowType> class ArrayCompareSorter { - using ArrayType = typename TypeTraits<ArrowType>::ArrayType; + using ArrayType = typename TypeTraits<ArrowType>::ArrayType; using GetView = GetViewType<ArrowType>; - - public: + + public: // Returns where null starts. // // `offset` is used when this is called on a chunk of a chunked array @@ -420,54 +420,54 @@ class ArrayCompareSorter { // If we use 'right < left' here, '<' is only required. return rhs < lhs; }); - } + } return nulls_begin; - } -}; - -template <typename ArrowType> + } +}; + +template <typename ArrowType> class ArrayCountSorter { - using ArrayType = typename TypeTraits<ArrowType>::ArrayType; - using c_type = typename ArrowType::c_type; - - public: + using ArrayType = typename TypeTraits<ArrowType>::ArrayType; + using c_type = typename ArrowType::c_type; + + public: ArrayCountSorter() = default; - + explicit ArrayCountSorter(c_type min, c_type max) { SetMinMax(min, max); } - - // Assume: max >= min && (max - min) < 4Gi - void SetMinMax(c_type min, c_type max) { - min_ = min; - value_range_ = static_cast<uint32_t>(max - min) + 1; - } - + + // Assume: max >= min && (max - min) < 4Gi + void SetMinMax(c_type min, c_type max) { + min_ = min; + value_range_ = static_cast<uint32_t>(max - min) + 1; + } + // Returns where null starts. uint64_t* Sort(uint64_t* indices_begin, uint64_t* indices_end, const ArrayType& values, int64_t offset, const ArraySortOptions& options) { - // 32bit counter performs much better than 64bit one - if (values.length() < (1LL << 32)) { + // 32bit counter performs much better than 64bit one + if (values.length() < (1LL << 32)) { return SortInternal<uint32_t>(indices_begin, indices_end, values, offset, options); - } else { + } else { return SortInternal<uint64_t>(indices_begin, indices_end, values, offset, options); - } - } - - private: - c_type min_{0}; - uint32_t value_range_{0}; - + } + } + + private: + c_type min_{0}; + uint32_t value_range_{0}; + // Returns where null starts. // // `offset` is used when this is called on a chunk of a chunked array - template <typename CounterType> + template <typename CounterType> uint64_t* SortInternal(uint64_t* indices_begin, uint64_t* indices_end, const ArrayType& values, int64_t offset, const ArraySortOptions& options) { - const uint32_t value_range = value_range_; - - // first slot reserved for prefix sum - std::vector<CounterType> counts(1 + value_range); - + const uint32_t value_range = value_range_; + + // first slot reserved for prefix sum + std::vector<CounterType> counts(1 + value_range); + if (options.order == SortOrder::Ascending) { VisitRawValuesInline( values, [&](c_type v) { ++counts[v - min_ + 1]; }, []() {}); @@ -497,7 +497,7 @@ class ArrayCountSorter { } } }; - + using ::arrow::internal::Bitmap; template <> @@ -526,135 +526,135 @@ class ArrayCountSorter<BooleanType> { } else { // zeros start after ones counts[0] = ones; - } - VisitRawValuesInline( + } + VisitRawValuesInline( values, [&](bool v) { indices_begin[counts[v]++] = index++; }, [&]() { indices_begin[null_position++] = index++; }); return nulls_begin; - } -}; - -// Sort integers with counting sort or comparison based sorting algorithm -// - Use O(n) counting sort if values are in a small range -// - Use O(nlogn) std::stable_sort otherwise -template <typename ArrowType> + } +}; + +// Sort integers with counting sort or comparison based sorting algorithm +// - Use O(n) counting sort if values are in a small range +// - Use O(nlogn) std::stable_sort otherwise +template <typename ArrowType> class ArrayCountOrCompareSorter { - using ArrayType = typename TypeTraits<ArrowType>::ArrayType; - using c_type = typename ArrowType::c_type; - - public: + using ArrayType = typename TypeTraits<ArrowType>::ArrayType; + using c_type = typename ArrowType::c_type; + + public: // Returns where null starts. // // `offset` is used when this is called on a chunk of a chunked array uint64_t* Sort(uint64_t* indices_begin, uint64_t* indices_end, const ArrayType& values, int64_t offset, const ArraySortOptions& options) { - if (values.length() >= countsort_min_len_ && values.length() > values.null_count()) { + if (values.length() >= countsort_min_len_ && values.length() > values.null_count()) { c_type min, max; std::tie(min, max) = GetMinMax<c_type>(*values.data()); - - // For signed int32/64, (max - min) may overflow and trigger UBSAN. - // Cast to largest unsigned type(uint64_t) before subtraction. - if (static_cast<uint64_t>(max) - static_cast<uint64_t>(min) <= - countsort_max_range_) { - count_sorter_.SetMinMax(min, max); + + // For signed int32/64, (max - min) may overflow and trigger UBSAN. + // Cast to largest unsigned type(uint64_t) before subtraction. + if (static_cast<uint64_t>(max) - static_cast<uint64_t>(min) <= + countsort_max_range_) { + count_sorter_.SetMinMax(min, max); return count_sorter_.Sort(indices_begin, indices_end, values, offset, options); - } - } - + } + } + return compare_sorter_.Sort(indices_begin, indices_end, values, offset, options); - } - - private: + } + + private: ArrayCompareSorter<ArrowType> compare_sorter_; ArrayCountSorter<ArrowType> count_sorter_; - - // Cross point to prefer counting sort than stl::stable_sort(merge sort) - // - array to be sorted is longer than "count_min_len_" - // - value range (max-min) is within "count_max_range_" - // - // The optimal setting depends heavily on running CPU. Below setting is - // conservative to adapt to various hardware and keep code simple. - // It's possible to decrease array-len and/or increase value-range to cover - // more cases, or setup a table for best array-len/value-range combinations. - // See https://issues.apache.org/jira/browse/ARROW-1571 for detailed analysis. - static const uint32_t countsort_min_len_ = 1024; - static const uint32_t countsort_max_range_ = 4096; -}; - -template <typename Type, typename Enable = void> + + // Cross point to prefer counting sort than stl::stable_sort(merge sort) + // - array to be sorted is longer than "count_min_len_" + // - value range (max-min) is within "count_max_range_" + // + // The optimal setting depends heavily on running CPU. Below setting is + // conservative to adapt to various hardware and keep code simple. + // It's possible to decrease array-len and/or increase value-range to cover + // more cases, or setup a table for best array-len/value-range combinations. + // See https://issues.apache.org/jira/browse/ARROW-1571 for detailed analysis. + static const uint32_t countsort_min_len_ = 1024; + static const uint32_t countsort_max_range_ = 4096; +}; + +template <typename Type, typename Enable = void> struct ArraySorter; - -template <> + +template <> struct ArraySorter<BooleanType> { ArrayCountSorter<BooleanType> impl; -}; - -template <> +}; + +template <> struct ArraySorter<UInt8Type> { ArrayCountSorter<UInt8Type> impl; ArraySorter() : impl(0, 255) {} -}; - +}; + template <> struct ArraySorter<Int8Type> { ArrayCountSorter<Int8Type> impl; ArraySorter() : impl(-128, 127) {} }; -template <typename Type> +template <typename Type> struct ArraySorter<Type, enable_if_t<(is_integer_type<Type>::value && (sizeof(typename Type::c_type) > 1)) || is_temporal_type<Type>::value>> { ArrayCountOrCompareSorter<Type> impl; -}; - -template <typename Type> +}; + +template <typename Type> struct ArraySorter< Type, enable_if_t<is_floating_type<Type>::value || is_base_binary_type<Type>::value || is_fixed_size_binary_type<Type>::value>> { ArrayCompareSorter<Type> impl; -}; - +}; + using ArraySortIndicesState = internal::OptionsWrapper<ArraySortOptions>; -template <typename OutType, typename InType> +template <typename OutType, typename InType> struct ArraySortIndices { - using ArrayType = typename TypeTraits<InType>::ArrayType; + using ArrayType = typename TypeTraits<InType>::ArrayType; static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const auto& options = ArraySortIndicesState::Get(ctx); ArrayType arr(batch[0].array()); - ArrayData* out_arr = out->mutable_array(); - uint64_t* out_begin = out_arr->GetMutableValues<uint64_t>(1); - uint64_t* out_end = out_begin + arr.length(); + ArrayData* out_arr = out->mutable_array(); + uint64_t* out_begin = out_arr->GetMutableValues<uint64_t>(1); + uint64_t* out_end = out_begin + arr.length(); std::iota(out_begin, out_end, 0); - + ArraySorter<InType> sorter; sorter.impl.Sort(out_begin, out_end, arr, 0, options); return Status::OK(); - } -}; - -// Sort indices kernels implemented for -// + } +}; + +// Sort indices kernels implemented for +// // * Boolean type -// * Number types -// * Base binary types - -template <template <typename...> class ExecTemplate> -void AddSortingKernels(VectorKernel base, VectorFunction* func) { +// * Number types +// * Base binary types + +template <template <typename...> class ExecTemplate> +void AddSortingKernels(VectorKernel base, VectorFunction* func) { // bool type base.signature = KernelSignature::Make({InputType::Array(boolean())}, uint64()); base.exec = ExecTemplate<UInt64Type, BooleanType>::Exec; DCHECK_OK(func->AddKernel(base)); - for (const auto& ty : NumericTypes()) { + for (const auto& ty : NumericTypes()) { auto physical_type = GetPhysicalType(ty); - base.signature = KernelSignature::Make({InputType::Array(ty)}, uint64()); + base.signature = KernelSignature::Make({InputType::Array(ty)}, uint64()); base.exec = GenerateNumeric<ExecTemplate, UInt64Type>(*physical_type); - DCHECK_OK(func->AddKernel(base)); - } + DCHECK_OK(func->AddKernel(base)); + } for (const auto& ty : TemporalTypes()) { auto physical_type = GetPhysicalType(ty); base.signature = KernelSignature::Make({InputType::Array(ty)}, uint64()); @@ -666,18 +666,18 @@ void AddSortingKernels(VectorKernel base, VectorFunction* func) { base.exec = GenerateDecimal<ExecTemplate, UInt64Type>(id); DCHECK_OK(func->AddKernel(base)); } - for (const auto& ty : BaseBinaryTypes()) { + for (const auto& ty : BaseBinaryTypes()) { auto physical_type = GetPhysicalType(ty); - base.signature = KernelSignature::Make({InputType::Array(ty)}, uint64()); + base.signature = KernelSignature::Make({InputType::Array(ty)}, uint64()); base.exec = GenerateVarBinaryBase<ExecTemplate, UInt64Type>(*physical_type); - DCHECK_OK(func->AddKernel(base)); - } + DCHECK_OK(func->AddKernel(base)); + } base.signature = KernelSignature::Make({InputType::Array(Type::FIXED_SIZE_BINARY)}, uint64()); base.exec = ExecTemplate<UInt64Type, FixedSizeBinaryType>::Exec; DCHECK_OK(func->AddKernel(base)); -} - +} + // ---------------------------------------------------------------------- // ChunkedArray sorting implementations @@ -1808,31 +1808,31 @@ const FunctionDoc partition_nth_indices_doc( } // namespace -void RegisterVectorSort(FunctionRegistry* registry) { - // The kernel outputs into preallocated memory and is never null - VectorKernel base; - base.mem_allocation = MemAllocation::PREALLOCATE; - base.null_handling = NullHandling::OUTPUT_NOT_NULL; - +void RegisterVectorSort(FunctionRegistry* registry) { + // The kernel outputs into preallocated memory and is never null + VectorKernel base; + base.mem_allocation = MemAllocation::PREALLOCATE; + base.null_handling = NullHandling::OUTPUT_NOT_NULL; + auto array_sort_indices = std::make_shared<VectorFunction>( "array_sort_indices", Arity::Unary(), &array_sort_indices_doc, &kDefaultArraySortOptions); base.init = ArraySortIndicesState::Init; AddSortingKernels<ArraySortIndices>(base, array_sort_indices.get()); DCHECK_OK(registry->AddFunction(std::move(array_sort_indices))); - + DCHECK_OK(registry->AddFunction(std::make_shared<SortIndicesMetaFunction>())); - // partition_nth_indices has a parameter so needs its init function + // partition_nth_indices has a parameter so needs its init function auto part_indices = std::make_shared<VectorFunction>( "partition_nth_indices", Arity::Unary(), &partition_nth_indices_doc); - base.init = PartitionNthToIndicesState::Init; - AddSortingKernels<PartitionNthToIndices>(base, part_indices.get()); - DCHECK_OK(registry->AddFunction(std::move(part_indices))); -} - + base.init = PartitionNthToIndicesState::Init; + AddSortingKernels<PartitionNthToIndices>(base, part_indices.get()); + DCHECK_OK(registry->AddFunction(std::move(part_indices))); +} + #undef VISIT_PHYSICAL_TYPES -} // namespace internal -} // namespace compute -} // namespace arrow +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/registry.cc b/contrib/libs/apache/arrow/cpp/src/arrow/compute/registry.cc index ca7b613730..7439faa7b2 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/registry.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/registry.cc @@ -1,64 +1,64 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/compute/registry.h" - -#include <algorithm> -#include <memory> -#include <mutex> -#include <unordered_map> -#include <utility> - -#include "arrow/compute/function.h" +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/registry.h" + +#include <algorithm> +#include <memory> +#include <mutex> +#include <unordered_map> +#include <utility> + +#include "arrow/compute/function.h" #include "arrow/compute/function_internal.h" -#include "arrow/compute/registry_internal.h" -#include "arrow/status.h" +#include "arrow/compute/registry_internal.h" +#include "arrow/status.h" #include "arrow/util/logging.h" - -namespace arrow { -namespace compute { - -class FunctionRegistry::FunctionRegistryImpl { - public: - Status AddFunction(std::shared_ptr<Function> function, bool allow_overwrite) { + +namespace arrow { +namespace compute { + +class FunctionRegistry::FunctionRegistryImpl { + public: + Status AddFunction(std::shared_ptr<Function> function, bool allow_overwrite) { RETURN_NOT_OK(function->Validate()); - std::lock_guard<std::mutex> mutation_guard(lock_); - - const std::string& name = function->name(); - auto it = name_to_function_.find(name); - if (it != name_to_function_.end() && !allow_overwrite) { - return Status::KeyError("Already have a function registered with name: ", name); - } - name_to_function_[name] = std::move(function); - return Status::OK(); - } - - Status AddAlias(const std::string& target_name, const std::string& source_name) { - std::lock_guard<std::mutex> mutation_guard(lock_); - - auto it = name_to_function_.find(source_name); - if (it == name_to_function_.end()) { - return Status::KeyError("No function registered with name: ", source_name); - } - name_to_function_[target_name] = it->second; - return Status::OK(); - } - + std::lock_guard<std::mutex> mutation_guard(lock_); + + const std::string& name = function->name(); + auto it = name_to_function_.find(name); + if (it != name_to_function_.end() && !allow_overwrite) { + return Status::KeyError("Already have a function registered with name: ", name); + } + name_to_function_[name] = std::move(function); + return Status::OK(); + } + + Status AddAlias(const std::string& target_name, const std::string& source_name) { + std::lock_guard<std::mutex> mutation_guard(lock_); + + auto it = name_to_function_.find(source_name); + if (it == name_to_function_.end()) { + return Status::KeyError("No function registered with name: ", source_name); + } + name_to_function_[target_name] = it->second; + return Status::OK(); + } + Status AddFunctionOptionsType(const FunctionOptionsType* options_type, bool allow_overwrite = false) { std::lock_guard<std::mutex> mutation_guard(lock_); @@ -73,23 +73,23 @@ class FunctionRegistry::FunctionRegistryImpl { return Status::OK(); } - Result<std::shared_ptr<Function>> GetFunction(const std::string& name) const { - auto it = name_to_function_.find(name); - if (it == name_to_function_.end()) { - return Status::KeyError("No function registered with name: ", name); - } - return it->second; - } - - std::vector<std::string> GetFunctionNames() const { - std::vector<std::string> results; - for (auto it : name_to_function_) { - results.push_back(it.first); - } - std::sort(results.begin(), results.end()); - return results; - } - + Result<std::shared_ptr<Function>> GetFunction(const std::string& name) const { + auto it = name_to_function_.find(name); + if (it == name_to_function_.end()) { + return Status::KeyError("No function registered with name: ", name); + } + return it->second; + } + + std::vector<std::string> GetFunctionNames() const { + std::vector<std::string> results; + for (auto it : name_to_function_) { + results.push_back(it.first); + } + std::sort(results.begin(), results.end()); + return results; + } + Result<const FunctionOptionsType*> GetFunctionOptionsType( const std::string& name) const { auto it = name_to_options_type_.find(name); @@ -99,80 +99,80 @@ class FunctionRegistry::FunctionRegistryImpl { return it->second; } - int num_functions() const { return static_cast<int>(name_to_function_.size()); } - - private: - std::mutex lock_; - std::unordered_map<std::string, std::shared_ptr<Function>> name_to_function_; + int num_functions() const { return static_cast<int>(name_to_function_.size()); } + + private: + std::mutex lock_; + std::unordered_map<std::string, std::shared_ptr<Function>> name_to_function_; std::unordered_map<std::string, const FunctionOptionsType*> name_to_options_type_; -}; - -std::unique_ptr<FunctionRegistry> FunctionRegistry::Make() { - return std::unique_ptr<FunctionRegistry>(new FunctionRegistry()); -} - -FunctionRegistry::FunctionRegistry() { impl_.reset(new FunctionRegistryImpl()); } - -FunctionRegistry::~FunctionRegistry() {} - -Status FunctionRegistry::AddFunction(std::shared_ptr<Function> function, - bool allow_overwrite) { - return impl_->AddFunction(std::move(function), allow_overwrite); -} - -Status FunctionRegistry::AddAlias(const std::string& target_name, - const std::string& source_name) { - return impl_->AddAlias(target_name, source_name); -} - +}; + +std::unique_ptr<FunctionRegistry> FunctionRegistry::Make() { + return std::unique_ptr<FunctionRegistry>(new FunctionRegistry()); +} + +FunctionRegistry::FunctionRegistry() { impl_.reset(new FunctionRegistryImpl()); } + +FunctionRegistry::~FunctionRegistry() {} + +Status FunctionRegistry::AddFunction(std::shared_ptr<Function> function, + bool allow_overwrite) { + return impl_->AddFunction(std::move(function), allow_overwrite); +} + +Status FunctionRegistry::AddAlias(const std::string& target_name, + const std::string& source_name) { + return impl_->AddAlias(target_name, source_name); +} + Status FunctionRegistry::AddFunctionOptionsType(const FunctionOptionsType* options_type, bool allow_overwrite) { return impl_->AddFunctionOptionsType(options_type, allow_overwrite); } -Result<std::shared_ptr<Function>> FunctionRegistry::GetFunction( - const std::string& name) const { - return impl_->GetFunction(name); -} - -std::vector<std::string> FunctionRegistry::GetFunctionNames() const { - return impl_->GetFunctionNames(); -} - +Result<std::shared_ptr<Function>> FunctionRegistry::GetFunction( + const std::string& name) const { + return impl_->GetFunction(name); +} + +std::vector<std::string> FunctionRegistry::GetFunctionNames() const { + return impl_->GetFunctionNames(); +} + Result<const FunctionOptionsType*> FunctionRegistry::GetFunctionOptionsType( const std::string& name) const { return impl_->GetFunctionOptionsType(name); } -int FunctionRegistry::num_functions() const { return impl_->num_functions(); } - -namespace internal { - -static std::unique_ptr<FunctionRegistry> CreateBuiltInRegistry() { - auto registry = FunctionRegistry::Make(); - - // Scalar functions - RegisterScalarArithmetic(registry.get()); - RegisterScalarBoolean(registry.get()); - RegisterScalarCast(registry.get()); - RegisterScalarComparison(registry.get()); - RegisterScalarNested(registry.get()); - RegisterScalarSetLookup(registry.get()); - RegisterScalarStringAscii(registry.get()); - RegisterScalarValidity(registry.get()); - RegisterScalarFillNull(registry.get()); +int FunctionRegistry::num_functions() const { return impl_->num_functions(); } + +namespace internal { + +static std::unique_ptr<FunctionRegistry> CreateBuiltInRegistry() { + auto registry = FunctionRegistry::Make(); + + // Scalar functions + RegisterScalarArithmetic(registry.get()); + RegisterScalarBoolean(registry.get()); + RegisterScalarCast(registry.get()); + RegisterScalarComparison(registry.get()); + RegisterScalarNested(registry.get()); + RegisterScalarSetLookup(registry.get()); + RegisterScalarStringAscii(registry.get()); + RegisterScalarValidity(registry.get()); + RegisterScalarFillNull(registry.get()); RegisterScalarIfElse(registry.get()); RegisterScalarTemporal(registry.get()); - + RegisterScalarOptions(registry.get()); - - // Vector functions - RegisterVectorHash(registry.get()); + + // Vector functions + RegisterVectorHash(registry.get()); RegisterVectorReplace(registry.get()); - RegisterVectorSelection(registry.get()); - RegisterVectorNested(registry.get()); - RegisterVectorSort(registry.get()); - + RegisterVectorSelection(registry.get()); + RegisterVectorNested(registry.get()); + RegisterVectorSort(registry.get()); + RegisterVectorOptions(registry.get()); // Aggregate functions @@ -185,15 +185,15 @@ static std::unique_ptr<FunctionRegistry> CreateBuiltInRegistry() { RegisterAggregateOptions(registry.get()); - return registry; -} - -} // namespace internal - -FunctionRegistry* GetFunctionRegistry() { - static auto g_registry = internal::CreateBuiltInRegistry(); - return g_registry.get(); -} - -} // namespace compute -} // namespace arrow + return registry; +} + +} // namespace internal + +FunctionRegistry* GetFunctionRegistry() { + static auto g_registry = internal::CreateBuiltInRegistry(); + return g_registry.get(); +} + +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/registry.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/registry.h index e83036db6a..6769ecf79c 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/registry.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/registry.h @@ -1,93 +1,93 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// NOTE: API is EXPERIMENTAL and will change without going through a -// deprecation cycle - -#pragma once - -#include <memory> -#include <string> -#include <vector> - -#include "arrow/result.h" -#include "arrow/status.h" -#include "arrow/util/visibility.h" - -namespace arrow { -namespace compute { - -class Function; +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + +#pragma once + +#include <memory> +#include <string> +#include <vector> + +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + +class Function; class FunctionOptionsType; - -/// \brief A mutable central function registry for built-in functions as well -/// as user-defined functions. Functions are implementations of -/// arrow::compute::Function. -/// -/// Generally, each function contains kernels which are implementations of a -/// function for a specific argument signature. After looking up a function in -/// the registry, one can either execute it eagerly with Function::Execute or -/// use one of the function's dispatch methods to pick a suitable kernel for -/// lower-level function execution. -class ARROW_EXPORT FunctionRegistry { - public: - ~FunctionRegistry(); - - /// \brief Construct a new registry. Most users only need to use the global - /// registry - static std::unique_ptr<FunctionRegistry> Make(); - - /// \brief Add a new function to the registry. Returns Status::KeyError if a - /// function with the same name is already registered - Status AddFunction(std::shared_ptr<Function> function, bool allow_overwrite = false); - - /// \brief Add aliases for the given function name. Returns Status::KeyError if the - /// function with the given name is not registered - Status AddAlias(const std::string& target_name, const std::string& source_name); - + +/// \brief A mutable central function registry for built-in functions as well +/// as user-defined functions. Functions are implementations of +/// arrow::compute::Function. +/// +/// Generally, each function contains kernels which are implementations of a +/// function for a specific argument signature. After looking up a function in +/// the registry, one can either execute it eagerly with Function::Execute or +/// use one of the function's dispatch methods to pick a suitable kernel for +/// lower-level function execution. +class ARROW_EXPORT FunctionRegistry { + public: + ~FunctionRegistry(); + + /// \brief Construct a new registry. Most users only need to use the global + /// registry + static std::unique_ptr<FunctionRegistry> Make(); + + /// \brief Add a new function to the registry. Returns Status::KeyError if a + /// function with the same name is already registered + Status AddFunction(std::shared_ptr<Function> function, bool allow_overwrite = false); + + /// \brief Add aliases for the given function name. Returns Status::KeyError if the + /// function with the given name is not registered + Status AddAlias(const std::string& target_name, const std::string& source_name); + /// \brief Add a new function options type to the registry. Returns Status::KeyError if /// a function options type with the same name is already registered Status AddFunctionOptionsType(const FunctionOptionsType* options_type, bool allow_overwrite = false); - /// \brief Retrieve a function by name from the registry - Result<std::shared_ptr<Function>> GetFunction(const std::string& name) const; - - /// \brief Return vector of all entry names in the registry. Helpful for - /// displaying a manifest of available functions - std::vector<std::string> GetFunctionNames() const; - + /// \brief Retrieve a function by name from the registry + Result<std::shared_ptr<Function>> GetFunction(const std::string& name) const; + + /// \brief Return vector of all entry names in the registry. Helpful for + /// displaying a manifest of available functions + std::vector<std::string> GetFunctionNames() const; + /// \brief Retrieve a function options type by name from the registry Result<const FunctionOptionsType*> GetFunctionOptionsType( const std::string& name) const; - /// \brief The number of currently registered functions - int num_functions() const; - - private: - FunctionRegistry(); - - // Use PIMPL pattern to not have std::unordered_map here - class FunctionRegistryImpl; - std::unique_ptr<FunctionRegistryImpl> impl_; -}; - -/// \brief Return the process-global function registry -ARROW_EXPORT FunctionRegistry* GetFunctionRegistry(); - -} // namespace compute -} // namespace arrow + /// \brief The number of currently registered functions + int num_functions() const; + + private: + FunctionRegistry(); + + // Use PIMPL pattern to not have std::unordered_map here + class FunctionRegistryImpl; + std::unique_ptr<FunctionRegistryImpl> impl_; +}; + +/// \brief Return the process-global function registry +ARROW_EXPORT FunctionRegistry* GetFunctionRegistry(); + +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/registry_internal.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/registry_internal.h index 892b54341d..f078bc5510 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/registry_internal.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/registry_internal.h @@ -1,63 +1,63 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -namespace arrow { -namespace compute { - -class FunctionRegistry; - -namespace internal { - -// Built-in scalar / elementwise functions -void RegisterScalarArithmetic(FunctionRegistry* registry); -void RegisterScalarBoolean(FunctionRegistry* registry); -void RegisterScalarCast(FunctionRegistry* registry); -void RegisterScalarComparison(FunctionRegistry* registry); -void RegisterScalarNested(FunctionRegistry* registry); -void RegisterScalarSetLookup(FunctionRegistry* registry); -void RegisterScalarStringAscii(FunctionRegistry* registry); -void RegisterScalarValidity(FunctionRegistry* registry); -void RegisterScalarFillNull(FunctionRegistry* registry); +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +namespace arrow { +namespace compute { + +class FunctionRegistry; + +namespace internal { + +// Built-in scalar / elementwise functions +void RegisterScalarArithmetic(FunctionRegistry* registry); +void RegisterScalarBoolean(FunctionRegistry* registry); +void RegisterScalarCast(FunctionRegistry* registry); +void RegisterScalarComparison(FunctionRegistry* registry); +void RegisterScalarNested(FunctionRegistry* registry); +void RegisterScalarSetLookup(FunctionRegistry* registry); +void RegisterScalarStringAscii(FunctionRegistry* registry); +void RegisterScalarValidity(FunctionRegistry* registry); +void RegisterScalarFillNull(FunctionRegistry* registry); void RegisterScalarIfElse(FunctionRegistry* registry); void RegisterScalarTemporal(FunctionRegistry* registry); - + void RegisterScalarOptions(FunctionRegistry* registry); -// Vector functions -void RegisterVectorHash(FunctionRegistry* registry); +// Vector functions +void RegisterVectorHash(FunctionRegistry* registry); void RegisterVectorReplace(FunctionRegistry* registry); -void RegisterVectorSelection(FunctionRegistry* registry); -void RegisterVectorNested(FunctionRegistry* registry); -void RegisterVectorSort(FunctionRegistry* registry); - +void RegisterVectorSelection(FunctionRegistry* registry); +void RegisterVectorNested(FunctionRegistry* registry); +void RegisterVectorSort(FunctionRegistry* registry); + void RegisterVectorOptions(FunctionRegistry* registry); -// Aggregate functions -void RegisterScalarAggregateBasic(FunctionRegistry* registry); +// Aggregate functions +void RegisterScalarAggregateBasic(FunctionRegistry* registry); void RegisterScalarAggregateMode(FunctionRegistry* registry); void RegisterScalarAggregateQuantile(FunctionRegistry* registry); void RegisterScalarAggregateTDigest(FunctionRegistry* registry); void RegisterScalarAggregateVariance(FunctionRegistry* registry); void RegisterHashAggregateBasic(FunctionRegistry* registry); - + void RegisterAggregateOptions(FunctionRegistry* registry); -} // namespace internal -} // namespace compute -} // namespace arrow +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/type_fwd.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/type_fwd.h index eebc8c1b67..167cda6a04 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/type_fwd.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/type_fwd.h @@ -1,48 +1,48 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -namespace arrow { - -struct Datum; +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +namespace arrow { + +struct Datum; struct ValueDescr; - -namespace compute { - + +namespace compute { + class Function; class FunctionOptions; class CastOptions; struct ExecBatch; -class ExecContext; -class KernelContext; - -struct Kernel; -struct ScalarKernel; -struct ScalarAggregateKernel; -struct VectorKernel; - +class ExecContext; +class KernelContext; + +struct Kernel; +struct ScalarKernel; +struct ScalarAggregateKernel; +struct VectorKernel; + struct KernelState; class Expression; class ExecNode; class ExecPlan; -} // namespace compute -} // namespace arrow +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/compute/util_internal.h b/contrib/libs/apache/arrow/cpp/src/arrow/compute/util_internal.h index 396c2ca2a0..4f7e43dae5 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/compute/util_internal.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/compute/util_internal.h @@ -1,32 +1,32 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include "arrow/buffer.h" - -namespace arrow { -namespace compute { -namespace internal { - -static inline void ZeroMemory(Buffer* buffer) { - std::memset(buffer->mutable_data(), 0, buffer->size()); -} - -} // namespace internal -} // namespace compute -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/buffer.h" + +namespace arrow { +namespace compute { +namespace internal { + +static inline void ZeroMemory(Buffer* buffer) { + std::memset(buffer->mutable_data(), 0, buffer->size()); +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/api.h b/contrib/libs/apache/arrow/cpp/src/arrow/csv/api.h index 7bf3931576..ed247e369b 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/api.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/api.h @@ -1,26 +1,26 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include "arrow/csv/options.h" -#include "arrow/csv/reader.h" - -// The writer depends on compute module for casting. -#ifdef ARROW_COMPUTE -#include "arrow/csv/writer.h" -#endif +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/csv/options.h" +#include "arrow/csv/reader.h" + +// The writer depends on compute module for casting. +#ifdef ARROW_COMPUTE +#include "arrow/csv/writer.h" +#endif diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/chunker.cc b/contrib/libs/apache/arrow/cpp/src/arrow/csv/chunker.cc index b3a0dead59..900f2d6228 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/chunker.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/chunker.cc @@ -1,300 +1,300 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/csv/chunker.h" - -#include <cstdint> -#include <memory> -#include <utility> - -#include "arrow/status.h" -#include "arrow/util/logging.h" -#include "arrow/util/make_unique.h" -#include "arrow/util/string_view.h" - -namespace arrow { -namespace csv { - -namespace { - -// NOTE: csvmonkey (https://github.com/dw/csvmonkey) has optimization ideas - -template <bool quoting, bool escaping> -class Lexer { - public: - enum State { - FIELD_START, - IN_FIELD, - AT_ESCAPE, - IN_QUOTED_FIELD, - AT_QUOTED_QUOTE, - AT_QUOTED_ESCAPE - }; - - explicit Lexer(const ParseOptions& options) : options_(options) { - DCHECK_EQ(quoting, options_.quoting); - DCHECK_EQ(escaping, options_.escaping); - } - - const char* ReadLine(const char* data, const char* data_end) { - // The parsing state machine - char c; - DCHECK_GT(data_end - data, 0); - if (ARROW_PREDICT_TRUE(state_ == FIELD_START)) { - goto FieldStart; - } - switch (state_) { - case FIELD_START: - goto FieldStart; - case IN_FIELD: - goto InField; - case AT_ESCAPE: - goto AtEscape; - case IN_QUOTED_FIELD: - goto InQuotedField; - case AT_QUOTED_QUOTE: - goto AtQuotedQuote; - case AT_QUOTED_ESCAPE: - goto AtQuotedEscape; - } - - FieldStart: - // At the start of a field - if (ARROW_PREDICT_FALSE(data == data_end)) { - state_ = FIELD_START; - goto AbortLine; - } - // Quoting is only recognized at start of field - if (quoting && *data == options_.quote_char) { - data++; - goto InQuotedField; - } else { - goto InField; - } - - InField: - // Inside a non-quoted part of a field - if (ARROW_PREDICT_FALSE(data == data_end)) { - state_ = IN_FIELD; - goto AbortLine; - } - c = *data++; - if (escaping && ARROW_PREDICT_FALSE(c == options_.escape_char)) { - if (ARROW_PREDICT_FALSE(data == data_end)) { - state_ = AT_ESCAPE; - goto AbortLine; - } - data++; - goto InField; - } - if (ARROW_PREDICT_FALSE(c == '\r')) { - if (ARROW_PREDICT_TRUE(data != data_end) && *data == '\n') { - data++; - } - goto LineEnd; - } - if (ARROW_PREDICT_FALSE(c == '\n')) { - goto LineEnd; - } - if (ARROW_PREDICT_FALSE(c == options_.delimiter)) { - goto FieldEnd; - } - goto InField; - - AtEscape: - // Coming here if last block ended on a non-quoted escape - data++; - goto InField; - - InQuotedField: - // Inside a quoted part of a field - if (ARROW_PREDICT_FALSE(data == data_end)) { - state_ = IN_QUOTED_FIELD; - goto AbortLine; - } - c = *data++; - if (escaping && ARROW_PREDICT_FALSE(c == options_.escape_char)) { - if (ARROW_PREDICT_FALSE(data == data_end)) { - state_ = AT_QUOTED_ESCAPE; - goto AbortLine; - } - data++; - goto InQuotedField; - } - if (ARROW_PREDICT_FALSE(c == options_.quote_char)) { - if (ARROW_PREDICT_FALSE(data == data_end)) { - state_ = AT_QUOTED_QUOTE; - goto AbortLine; - } - if (options_.double_quote && *data == options_.quote_char) { - // Double-quoting - data++; - } else { - // End of single-quoting - goto InField; - } - } - goto InQuotedField; - - AtQuotedEscape: - // Coming here if last block ended on a quoted escape - data++; - goto InQuotedField; - - AtQuotedQuote: - // Coming here if last block ended on a quoted quote - if (options_.double_quote && *data == options_.quote_char) { - // Double-quoting - data++; - goto InQuotedField; - } else { - // End of single-quoting - goto InField; - } - - FieldEnd: - // At the end of a field - goto FieldStart; - - LineEnd: - state_ = FIELD_START; - return data; - - AbortLine: - // Truncated line - return nullptr; - } - - protected: - const ParseOptions& options_; - State state_ = FIELD_START; -}; - -// A BoundaryFinder implementation that assumes CSV cells can contain raw newlines, -// and uses actual CSV lexing to delimit them. -template <bool quoting, bool escaping> -class LexingBoundaryFinder : public BoundaryFinder { - public: - explicit LexingBoundaryFinder(ParseOptions options) : options_(std::move(options)) {} - - Status FindFirst(util::string_view partial, util::string_view block, - int64_t* out_pos) override { - Lexer<quoting, escaping> lexer(options_); - - const char* line_end = - lexer.ReadLine(partial.data(), partial.data() + partial.size()); - DCHECK_EQ(line_end, nullptr); // Otherwise `partial` is a whole CSV line - line_end = lexer.ReadLine(block.data(), block.data() + block.size()); - - if (line_end == nullptr) { - // No complete CSV line - *out_pos = -1; - } else { - *out_pos = static_cast<int64_t>(line_end - block.data()); - DCHECK_GT(*out_pos, 0); - } - return Status::OK(); - } - - Status FindLast(util::string_view block, int64_t* out_pos) override { - Lexer<quoting, escaping> lexer(options_); - - const char* data = block.data(); - const char* const data_end = block.data() + block.size(); - - while (data < data_end) { - const char* line_end = lexer.ReadLine(data, data_end); - if (line_end == nullptr) { - // Cannot read any further - break; - } - DCHECK_GT(line_end, data); - data = line_end; - } - if (data == block.data()) { - // No complete CSV line - *out_pos = -1; - } else { - *out_pos = static_cast<int64_t>(data - block.data()); - DCHECK_GT(*out_pos, 0); - } - return Status::OK(); - } - - Status FindNth(util::string_view partial, util::string_view block, int64_t count, - int64_t* out_pos, int64_t* num_found) override { - Lexer<quoting, escaping> lexer(options_); - int64_t found = 0; - const char* data = block.data(); - const char* const data_end = block.data() + block.size(); - - const char* line_end; - if (partial.size()) { - line_end = lexer.ReadLine(partial.data(), partial.data() + partial.size()); - DCHECK_EQ(line_end, nullptr); // Otherwise `partial` is a whole CSV line - } - - for (; data < data_end && found < count; ++found) { - line_end = lexer.ReadLine(data, data_end); - if (line_end == nullptr) { - // Cannot read any further - break; - } - DCHECK_GT(line_end, data); - data = line_end; - } - - if (data == block.data()) { - // No complete CSV line - *out_pos = kNoDelimiterFound; - } else { - *out_pos = static_cast<int64_t>(data - block.data()); - } - *num_found = found; - return Status::OK(); - } - - protected: - ParseOptions options_; -}; - -} // namespace - -std::unique_ptr<Chunker> MakeChunker(const ParseOptions& options) { - std::shared_ptr<BoundaryFinder> delimiter; - if (!options.newlines_in_values) { - delimiter = MakeNewlineBoundaryFinder(); - } else { - if (options.quoting) { - if (options.escaping) { - delimiter = std::make_shared<LexingBoundaryFinder<true, true>>(options); - } else { - delimiter = std::make_shared<LexingBoundaryFinder<true, false>>(options); - } - } else { - if (options.escaping) { - delimiter = std::make_shared<LexingBoundaryFinder<false, true>>(options); - } else { - delimiter = std::make_shared<LexingBoundaryFinder<false, false>>(options); - } - } - } - return internal::make_unique<Chunker>(std::move(delimiter)); -} - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/csv/chunker.h" + +#include <cstdint> +#include <memory> +#include <utility> + +#include "arrow/status.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/string_view.h" + +namespace arrow { +namespace csv { + +namespace { + +// NOTE: csvmonkey (https://github.com/dw/csvmonkey) has optimization ideas + +template <bool quoting, bool escaping> +class Lexer { + public: + enum State { + FIELD_START, + IN_FIELD, + AT_ESCAPE, + IN_QUOTED_FIELD, + AT_QUOTED_QUOTE, + AT_QUOTED_ESCAPE + }; + + explicit Lexer(const ParseOptions& options) : options_(options) { + DCHECK_EQ(quoting, options_.quoting); + DCHECK_EQ(escaping, options_.escaping); + } + + const char* ReadLine(const char* data, const char* data_end) { + // The parsing state machine + char c; + DCHECK_GT(data_end - data, 0); + if (ARROW_PREDICT_TRUE(state_ == FIELD_START)) { + goto FieldStart; + } + switch (state_) { + case FIELD_START: + goto FieldStart; + case IN_FIELD: + goto InField; + case AT_ESCAPE: + goto AtEscape; + case IN_QUOTED_FIELD: + goto InQuotedField; + case AT_QUOTED_QUOTE: + goto AtQuotedQuote; + case AT_QUOTED_ESCAPE: + goto AtQuotedEscape; + } + + FieldStart: + // At the start of a field + if (ARROW_PREDICT_FALSE(data == data_end)) { + state_ = FIELD_START; + goto AbortLine; + } + // Quoting is only recognized at start of field + if (quoting && *data == options_.quote_char) { + data++; + goto InQuotedField; + } else { + goto InField; + } + + InField: + // Inside a non-quoted part of a field + if (ARROW_PREDICT_FALSE(data == data_end)) { + state_ = IN_FIELD; + goto AbortLine; + } + c = *data++; + if (escaping && ARROW_PREDICT_FALSE(c == options_.escape_char)) { + if (ARROW_PREDICT_FALSE(data == data_end)) { + state_ = AT_ESCAPE; + goto AbortLine; + } + data++; + goto InField; + } + if (ARROW_PREDICT_FALSE(c == '\r')) { + if (ARROW_PREDICT_TRUE(data != data_end) && *data == '\n') { + data++; + } + goto LineEnd; + } + if (ARROW_PREDICT_FALSE(c == '\n')) { + goto LineEnd; + } + if (ARROW_PREDICT_FALSE(c == options_.delimiter)) { + goto FieldEnd; + } + goto InField; + + AtEscape: + // Coming here if last block ended on a non-quoted escape + data++; + goto InField; + + InQuotedField: + // Inside a quoted part of a field + if (ARROW_PREDICT_FALSE(data == data_end)) { + state_ = IN_QUOTED_FIELD; + goto AbortLine; + } + c = *data++; + if (escaping && ARROW_PREDICT_FALSE(c == options_.escape_char)) { + if (ARROW_PREDICT_FALSE(data == data_end)) { + state_ = AT_QUOTED_ESCAPE; + goto AbortLine; + } + data++; + goto InQuotedField; + } + if (ARROW_PREDICT_FALSE(c == options_.quote_char)) { + if (ARROW_PREDICT_FALSE(data == data_end)) { + state_ = AT_QUOTED_QUOTE; + goto AbortLine; + } + if (options_.double_quote && *data == options_.quote_char) { + // Double-quoting + data++; + } else { + // End of single-quoting + goto InField; + } + } + goto InQuotedField; + + AtQuotedEscape: + // Coming here if last block ended on a quoted escape + data++; + goto InQuotedField; + + AtQuotedQuote: + // Coming here if last block ended on a quoted quote + if (options_.double_quote && *data == options_.quote_char) { + // Double-quoting + data++; + goto InQuotedField; + } else { + // End of single-quoting + goto InField; + } + + FieldEnd: + // At the end of a field + goto FieldStart; + + LineEnd: + state_ = FIELD_START; + return data; + + AbortLine: + // Truncated line + return nullptr; + } + + protected: + const ParseOptions& options_; + State state_ = FIELD_START; +}; + +// A BoundaryFinder implementation that assumes CSV cells can contain raw newlines, +// and uses actual CSV lexing to delimit them. +template <bool quoting, bool escaping> +class LexingBoundaryFinder : public BoundaryFinder { + public: + explicit LexingBoundaryFinder(ParseOptions options) : options_(std::move(options)) {} + + Status FindFirst(util::string_view partial, util::string_view block, + int64_t* out_pos) override { + Lexer<quoting, escaping> lexer(options_); + + const char* line_end = + lexer.ReadLine(partial.data(), partial.data() + partial.size()); + DCHECK_EQ(line_end, nullptr); // Otherwise `partial` is a whole CSV line + line_end = lexer.ReadLine(block.data(), block.data() + block.size()); + + if (line_end == nullptr) { + // No complete CSV line + *out_pos = -1; + } else { + *out_pos = static_cast<int64_t>(line_end - block.data()); + DCHECK_GT(*out_pos, 0); + } + return Status::OK(); + } + + Status FindLast(util::string_view block, int64_t* out_pos) override { + Lexer<quoting, escaping> lexer(options_); + + const char* data = block.data(); + const char* const data_end = block.data() + block.size(); + + while (data < data_end) { + const char* line_end = lexer.ReadLine(data, data_end); + if (line_end == nullptr) { + // Cannot read any further + break; + } + DCHECK_GT(line_end, data); + data = line_end; + } + if (data == block.data()) { + // No complete CSV line + *out_pos = -1; + } else { + *out_pos = static_cast<int64_t>(data - block.data()); + DCHECK_GT(*out_pos, 0); + } + return Status::OK(); + } + + Status FindNth(util::string_view partial, util::string_view block, int64_t count, + int64_t* out_pos, int64_t* num_found) override { + Lexer<quoting, escaping> lexer(options_); + int64_t found = 0; + const char* data = block.data(); + const char* const data_end = block.data() + block.size(); + + const char* line_end; + if (partial.size()) { + line_end = lexer.ReadLine(partial.data(), partial.data() + partial.size()); + DCHECK_EQ(line_end, nullptr); // Otherwise `partial` is a whole CSV line + } + + for (; data < data_end && found < count; ++found) { + line_end = lexer.ReadLine(data, data_end); + if (line_end == nullptr) { + // Cannot read any further + break; + } + DCHECK_GT(line_end, data); + data = line_end; + } + + if (data == block.data()) { + // No complete CSV line + *out_pos = kNoDelimiterFound; + } else { + *out_pos = static_cast<int64_t>(data - block.data()); + } + *num_found = found; + return Status::OK(); + } + + protected: + ParseOptions options_; +}; + +} // namespace + +std::unique_ptr<Chunker> MakeChunker(const ParseOptions& options) { + std::shared_ptr<BoundaryFinder> delimiter; + if (!options.newlines_in_values) { + delimiter = MakeNewlineBoundaryFinder(); + } else { + if (options.quoting) { + if (options.escaping) { + delimiter = std::make_shared<LexingBoundaryFinder<true, true>>(options); + } else { + delimiter = std::make_shared<LexingBoundaryFinder<true, false>>(options); + } + } else { + if (options.escaping) { + delimiter = std::make_shared<LexingBoundaryFinder<false, true>>(options); + } else { + delimiter = std::make_shared<LexingBoundaryFinder<false, false>>(options); + } + } + } + return internal::make_unique<Chunker>(std::move(delimiter)); +} + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/chunker.h b/contrib/libs/apache/arrow/cpp/src/arrow/csv/chunker.h index 662b16ec40..bcebd6572e 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/chunker.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/chunker.h @@ -1,36 +1,36 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <cstdint> -#include <memory> - -#include "arrow/csv/options.h" -#include "arrow/status.h" -#include "arrow/util/delimiting.h" -#include "arrow/util/macros.h" -#include "arrow/util/visibility.h" - -namespace arrow { -namespace csv { - -ARROW_EXPORT -std::unique_ptr<Chunker> MakeChunker(const ParseOptions& options); - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstdint> +#include <memory> + +#include "arrow/csv/options.h" +#include "arrow/status.h" +#include "arrow/util/delimiting.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace csv { + +ARROW_EXPORT +std::unique_ptr<Chunker> MakeChunker(const ParseOptions& options); + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/column_builder.cc b/contrib/libs/apache/arrow/cpp/src/arrow/csv/column_builder.cc index bc97442873..910ca1980c 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/column_builder.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/column_builder.cc @@ -1,367 +1,367 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include <cstddef> -#include <cstdint> -#include <memory> -#include <mutex> -#include <sstream> -#include <string> -#include <utility> -#include <vector> - -#include "arrow/array.h" -#include "arrow/array/builder_base.h" -#include "arrow/chunked_array.h" -#include "arrow/csv/column_builder.h" -#include "arrow/csv/converter.h" -#include "arrow/csv/inference_internal.h" -#include "arrow/csv/options.h" -#include "arrow/csv/parser.h" -#include "arrow/status.h" -#include "arrow/type.h" -#include "arrow/type_fwd.h" -#include "arrow/util/logging.h" -#include "arrow/util/task_group.h" - -namespace arrow { -namespace csv { - -class BlockParser; - -using internal::TaskGroup; - -class ConcreteColumnBuilder : public ColumnBuilder { - public: - explicit ConcreteColumnBuilder(MemoryPool* pool, - std::shared_ptr<internal::TaskGroup> task_group, - int32_t col_index = -1) - : ColumnBuilder(std::move(task_group)), pool_(pool), col_index_(col_index) {} - - void Append(const std::shared_ptr<BlockParser>& parser) override { - Insert(static_cast<int64_t>(chunks_.size()), parser); - } - - Result<std::shared_ptr<ChunkedArray>> Finish() override { - std::lock_guard<std::mutex> lock(mutex_); - - return FinishUnlocked(); - } - - protected: - virtual std::shared_ptr<DataType> type() const = 0; - - Result<std::shared_ptr<ChunkedArray>> FinishUnlocked() { - auto type = this->type(); - for (const auto& chunk : chunks_) { - if (chunk == nullptr) { - return Status::UnknownError("a chunk failed converting for an unknown reason"); - } - DCHECK_EQ(chunk->type()->id(), type->id()) << "Chunk types not equal!"; - } - return std::make_shared<ChunkedArray>(chunks_, std::move(type)); - } - - void ReserveChunks(int64_t block_index) { - // Create a null Array pointer at the back at the list. - std::lock_guard<std::mutex> lock(mutex_); - ReserveChunksUnlocked(block_index); - } - - void ReserveChunksUnlocked(int64_t block_index) { - // Create a null Array pointer at the back at the list. - size_t chunk_index = static_cast<size_t>(block_index); - if (chunks_.size() <= chunk_index) { - chunks_.resize(chunk_index + 1); - } - } - - Status SetChunk(int64_t chunk_index, Result<std::shared_ptr<Array>> maybe_array) { - std::lock_guard<std::mutex> lock(mutex_); - return SetChunkUnlocked(chunk_index, std::move(maybe_array)); - } - - Status SetChunkUnlocked(int64_t chunk_index, - Result<std::shared_ptr<Array>> maybe_array) { - // Should not insert an already built chunk - DCHECK_EQ(chunks_[chunk_index], nullptr); - - if (maybe_array.ok()) { - chunks_[chunk_index] = *std::move(maybe_array); - return Status::OK(); - } else { - return WrapConversionError(maybe_array.status()); - } - } - - Status WrapConversionError(const Status& st) { - if (ARROW_PREDICT_TRUE(st.ok())) { - return st; - } else { - std::stringstream ss; - ss << "In CSV column #" << col_index_ << ": " << st.message(); - return st.WithMessage(ss.str()); - } - } - - MemoryPool* pool_; - int32_t col_index_; - - ArrayVector chunks_; - - std::mutex mutex_; -}; - -////////////////////////////////////////////////////////////////////////// -// Null column builder implementation (for a column not in the CSV file) - -class NullColumnBuilder : public ConcreteColumnBuilder { - public: - explicit NullColumnBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool, - const std::shared_ptr<internal::TaskGroup>& task_group) - : ConcreteColumnBuilder(pool, task_group), type_(type) {} - - void Insert(int64_t block_index, const std::shared_ptr<BlockParser>& parser) override; - - protected: - std::shared_ptr<DataType> type() const override { return type_; } - - std::shared_ptr<DataType> type_; -}; - -void NullColumnBuilder::Insert(int64_t block_index, - const std::shared_ptr<BlockParser>& parser) { - ReserveChunks(block_index); - - // Spawn a task that will build an array of nulls with the right DataType - const int32_t num_rows = parser->num_rows(); - DCHECK_GE(num_rows, 0); - - task_group_->Append([=]() -> Status { - std::unique_ptr<ArrayBuilder> builder; - RETURN_NOT_OK(MakeBuilder(pool_, type_, &builder)); - std::shared_ptr<Array> res; - RETURN_NOT_OK(builder->AppendNulls(num_rows)); - RETURN_NOT_OK(builder->Finish(&res)); - - return SetChunk(block_index, res); - }); -} - -////////////////////////////////////////////////////////////////////////// -// Pre-typed column builder implementation - -class TypedColumnBuilder : public ConcreteColumnBuilder { - public: - TypedColumnBuilder(const std::shared_ptr<DataType>& type, int32_t col_index, - const ConvertOptions& options, MemoryPool* pool, - const std::shared_ptr<internal::TaskGroup>& task_group) - : ConcreteColumnBuilder(pool, task_group, col_index), - type_(type), - options_(options) {} - - Status Init(); - - void Insert(int64_t block_index, const std::shared_ptr<BlockParser>& parser) override; - - protected: - std::shared_ptr<DataType> type() const override { return type_; } - - std::shared_ptr<DataType> type_; - // CAUTION: ConvertOptions can grow large (if it customizes hundreds or - // thousands of columns), so avoid copying it in each TypedColumnBuilder. - const ConvertOptions& options_; - - std::shared_ptr<Converter> converter_; -}; - -Status TypedColumnBuilder::Init() { - ARROW_ASSIGN_OR_RAISE(converter_, Converter::Make(type_, options_, pool_)); - return Status::OK(); -} - -void TypedColumnBuilder::Insert(int64_t block_index, - const std::shared_ptr<BlockParser>& parser) { - DCHECK_NE(converter_, nullptr); - - ReserveChunks(block_index); - - // We're careful that all references in the closure outlive the Append() call - task_group_->Append([=]() -> Status { - return SetChunk(block_index, converter_->Convert(*parser, col_index_)); - }); -} - -////////////////////////////////////////////////////////////////////////// -// Type-inferring column builder implementation - -class InferringColumnBuilder : public ConcreteColumnBuilder { - public: - InferringColumnBuilder(int32_t col_index, const ConvertOptions& options, - MemoryPool* pool, - const std::shared_ptr<internal::TaskGroup>& task_group) - : ConcreteColumnBuilder(pool, task_group, col_index), - options_(options), - infer_status_(options) {} - - Status Init(); - - void Insert(int64_t block_index, const std::shared_ptr<BlockParser>& parser) override; - Result<std::shared_ptr<ChunkedArray>> Finish() override; - - protected: - std::shared_ptr<DataType> type() const override { - DCHECK_NE(converter_, nullptr); - return converter_->type(); - } - - Status UpdateType(); - Status TryConvertChunk(int64_t chunk_index); - // This must be called unlocked! - void ScheduleConvertChunk(int64_t chunk_index); - - // CAUTION: ConvertOptions can grow large (if it customizes hundreds or - // thousands of columns), so avoid copying it in each InferringColumnBuilder. - const ConvertOptions& options_; - - // Current inference status - InferStatus infer_status_; - std::shared_ptr<Converter> converter_; - - // The parsers corresponding to each chunk (for reconverting) - std::vector<std::shared_ptr<BlockParser>> parsers_; -}; - -Status InferringColumnBuilder::Init() { return UpdateType(); } - -Status InferringColumnBuilder::UpdateType() { - return infer_status_.MakeConverter(pool_).Value(&converter_); -} - -void InferringColumnBuilder::ScheduleConvertChunk(int64_t chunk_index) { - task_group_->Append([=]() { return TryConvertChunk(chunk_index); }); -} - -Status InferringColumnBuilder::TryConvertChunk(int64_t chunk_index) { - std::unique_lock<std::mutex> lock(mutex_); - std::shared_ptr<Converter> converter = converter_; - std::shared_ptr<BlockParser> parser = parsers_[chunk_index]; - InferKind kind = infer_status_.kind(); - - DCHECK_NE(parser, nullptr); - - lock.unlock(); - auto maybe_array = converter->Convert(*parser, col_index_); - lock.lock(); - - if (kind != infer_status_.kind()) { - // infer_kind_ was changed by another task, reconvert - lock.unlock(); - ScheduleConvertChunk(chunk_index); - return Status::OK(); - } - - if (maybe_array.ok() || !infer_status_.can_loosen_type()) { - // Conversion succeeded, or failed definitively - if (!infer_status_.can_loosen_type()) { - // We won't try to reconvert anymore - parsers_[chunk_index].reset(); - } - return SetChunkUnlocked(chunk_index, maybe_array); - } - - // Conversion failed, try another type - infer_status_.LoosenType(maybe_array.status()); - RETURN_NOT_OK(UpdateType()); - - // Reconvert past finished chunks - // (unfinished chunks will notice by themselves if they need reconverting) - const auto nchunks = static_cast<int64_t>(chunks_.size()); - for (int64_t i = 0; i < nchunks; ++i) { - if (i != chunk_index && chunks_[i]) { - // We're assuming the chunk was converted using the wrong type - // (which should be true unless the executor reorders tasks) - chunks_[i].reset(); - lock.unlock(); - ScheduleConvertChunk(i); - lock.lock(); - } - } - - // Reconvert this chunk - lock.unlock(); - ScheduleConvertChunk(chunk_index); - - return Status::OK(); -} - -void InferringColumnBuilder::Insert(int64_t block_index, - const std::shared_ptr<BlockParser>& parser) { - // Create a slot for the new chunk and spawn a task to convert it - size_t chunk_index = static_cast<size_t>(block_index); - { - std::lock_guard<std::mutex> lock(mutex_); - - DCHECK_NE(converter_, nullptr); - if (parsers_.size() <= chunk_index) { - parsers_.resize(chunk_index + 1); - } - // Should not insert an already converting chunk - DCHECK_EQ(parsers_[chunk_index], nullptr); - parsers_[chunk_index] = parser; - ReserveChunksUnlocked(block_index); - } - - ScheduleConvertChunk(chunk_index); -} - -Result<std::shared_ptr<ChunkedArray>> InferringColumnBuilder::Finish() { - std::lock_guard<std::mutex> lock(mutex_); - - parsers_.clear(); - return FinishUnlocked(); -} - -////////////////////////////////////////////////////////////////////////// -// Factory functions - -Result<std::shared_ptr<ColumnBuilder>> ColumnBuilder::Make( - MemoryPool* pool, const std::shared_ptr<DataType>& type, int32_t col_index, - const ConvertOptions& options, const std::shared_ptr<TaskGroup>& task_group) { - auto ptr = - std::make_shared<TypedColumnBuilder>(type, col_index, options, pool, task_group); - RETURN_NOT_OK(ptr->Init()); - return ptr; -} - -Result<std::shared_ptr<ColumnBuilder>> ColumnBuilder::Make( - MemoryPool* pool, int32_t col_index, const ConvertOptions& options, - const std::shared_ptr<TaskGroup>& task_group) { - auto ptr = - std::make_shared<InferringColumnBuilder>(col_index, options, pool, task_group); - RETURN_NOT_OK(ptr->Init()); - return ptr; -} - -Result<std::shared_ptr<ColumnBuilder>> ColumnBuilder::MakeNull( - MemoryPool* pool, const std::shared_ptr<DataType>& type, - const std::shared_ptr<internal::TaskGroup>& task_group) { - return std::make_shared<NullColumnBuilder>(type, pool, task_group); -} - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <cstddef> +#include <cstdint> +#include <memory> +#include <mutex> +#include <sstream> +#include <string> +#include <utility> +#include <vector> + +#include "arrow/array.h" +#include "arrow/array/builder_base.h" +#include "arrow/chunked_array.h" +#include "arrow/csv/column_builder.h" +#include "arrow/csv/converter.h" +#include "arrow/csv/inference_internal.h" +#include "arrow/csv/options.h" +#include "arrow/csv/parser.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" +#include "arrow/util/logging.h" +#include "arrow/util/task_group.h" + +namespace arrow { +namespace csv { + +class BlockParser; + +using internal::TaskGroup; + +class ConcreteColumnBuilder : public ColumnBuilder { + public: + explicit ConcreteColumnBuilder(MemoryPool* pool, + std::shared_ptr<internal::TaskGroup> task_group, + int32_t col_index = -1) + : ColumnBuilder(std::move(task_group)), pool_(pool), col_index_(col_index) {} + + void Append(const std::shared_ptr<BlockParser>& parser) override { + Insert(static_cast<int64_t>(chunks_.size()), parser); + } + + Result<std::shared_ptr<ChunkedArray>> Finish() override { + std::lock_guard<std::mutex> lock(mutex_); + + return FinishUnlocked(); + } + + protected: + virtual std::shared_ptr<DataType> type() const = 0; + + Result<std::shared_ptr<ChunkedArray>> FinishUnlocked() { + auto type = this->type(); + for (const auto& chunk : chunks_) { + if (chunk == nullptr) { + return Status::UnknownError("a chunk failed converting for an unknown reason"); + } + DCHECK_EQ(chunk->type()->id(), type->id()) << "Chunk types not equal!"; + } + return std::make_shared<ChunkedArray>(chunks_, std::move(type)); + } + + void ReserveChunks(int64_t block_index) { + // Create a null Array pointer at the back at the list. + std::lock_guard<std::mutex> lock(mutex_); + ReserveChunksUnlocked(block_index); + } + + void ReserveChunksUnlocked(int64_t block_index) { + // Create a null Array pointer at the back at the list. + size_t chunk_index = static_cast<size_t>(block_index); + if (chunks_.size() <= chunk_index) { + chunks_.resize(chunk_index + 1); + } + } + + Status SetChunk(int64_t chunk_index, Result<std::shared_ptr<Array>> maybe_array) { + std::lock_guard<std::mutex> lock(mutex_); + return SetChunkUnlocked(chunk_index, std::move(maybe_array)); + } + + Status SetChunkUnlocked(int64_t chunk_index, + Result<std::shared_ptr<Array>> maybe_array) { + // Should not insert an already built chunk + DCHECK_EQ(chunks_[chunk_index], nullptr); + + if (maybe_array.ok()) { + chunks_[chunk_index] = *std::move(maybe_array); + return Status::OK(); + } else { + return WrapConversionError(maybe_array.status()); + } + } + + Status WrapConversionError(const Status& st) { + if (ARROW_PREDICT_TRUE(st.ok())) { + return st; + } else { + std::stringstream ss; + ss << "In CSV column #" << col_index_ << ": " << st.message(); + return st.WithMessage(ss.str()); + } + } + + MemoryPool* pool_; + int32_t col_index_; + + ArrayVector chunks_; + + std::mutex mutex_; +}; + +////////////////////////////////////////////////////////////////////////// +// Null column builder implementation (for a column not in the CSV file) + +class NullColumnBuilder : public ConcreteColumnBuilder { + public: + explicit NullColumnBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool, + const std::shared_ptr<internal::TaskGroup>& task_group) + : ConcreteColumnBuilder(pool, task_group), type_(type) {} + + void Insert(int64_t block_index, const std::shared_ptr<BlockParser>& parser) override; + + protected: + std::shared_ptr<DataType> type() const override { return type_; } + + std::shared_ptr<DataType> type_; +}; + +void NullColumnBuilder::Insert(int64_t block_index, + const std::shared_ptr<BlockParser>& parser) { + ReserveChunks(block_index); + + // Spawn a task that will build an array of nulls with the right DataType + const int32_t num_rows = parser->num_rows(); + DCHECK_GE(num_rows, 0); + + task_group_->Append([=]() -> Status { + std::unique_ptr<ArrayBuilder> builder; + RETURN_NOT_OK(MakeBuilder(pool_, type_, &builder)); + std::shared_ptr<Array> res; + RETURN_NOT_OK(builder->AppendNulls(num_rows)); + RETURN_NOT_OK(builder->Finish(&res)); + + return SetChunk(block_index, res); + }); +} + +////////////////////////////////////////////////////////////////////////// +// Pre-typed column builder implementation + +class TypedColumnBuilder : public ConcreteColumnBuilder { + public: + TypedColumnBuilder(const std::shared_ptr<DataType>& type, int32_t col_index, + const ConvertOptions& options, MemoryPool* pool, + const std::shared_ptr<internal::TaskGroup>& task_group) + : ConcreteColumnBuilder(pool, task_group, col_index), + type_(type), + options_(options) {} + + Status Init(); + + void Insert(int64_t block_index, const std::shared_ptr<BlockParser>& parser) override; + + protected: + std::shared_ptr<DataType> type() const override { return type_; } + + std::shared_ptr<DataType> type_; + // CAUTION: ConvertOptions can grow large (if it customizes hundreds or + // thousands of columns), so avoid copying it in each TypedColumnBuilder. + const ConvertOptions& options_; + + std::shared_ptr<Converter> converter_; +}; + +Status TypedColumnBuilder::Init() { + ARROW_ASSIGN_OR_RAISE(converter_, Converter::Make(type_, options_, pool_)); + return Status::OK(); +} + +void TypedColumnBuilder::Insert(int64_t block_index, + const std::shared_ptr<BlockParser>& parser) { + DCHECK_NE(converter_, nullptr); + + ReserveChunks(block_index); + + // We're careful that all references in the closure outlive the Append() call + task_group_->Append([=]() -> Status { + return SetChunk(block_index, converter_->Convert(*parser, col_index_)); + }); +} + +////////////////////////////////////////////////////////////////////////// +// Type-inferring column builder implementation + +class InferringColumnBuilder : public ConcreteColumnBuilder { + public: + InferringColumnBuilder(int32_t col_index, const ConvertOptions& options, + MemoryPool* pool, + const std::shared_ptr<internal::TaskGroup>& task_group) + : ConcreteColumnBuilder(pool, task_group, col_index), + options_(options), + infer_status_(options) {} + + Status Init(); + + void Insert(int64_t block_index, const std::shared_ptr<BlockParser>& parser) override; + Result<std::shared_ptr<ChunkedArray>> Finish() override; + + protected: + std::shared_ptr<DataType> type() const override { + DCHECK_NE(converter_, nullptr); + return converter_->type(); + } + + Status UpdateType(); + Status TryConvertChunk(int64_t chunk_index); + // This must be called unlocked! + void ScheduleConvertChunk(int64_t chunk_index); + + // CAUTION: ConvertOptions can grow large (if it customizes hundreds or + // thousands of columns), so avoid copying it in each InferringColumnBuilder. + const ConvertOptions& options_; + + // Current inference status + InferStatus infer_status_; + std::shared_ptr<Converter> converter_; + + // The parsers corresponding to each chunk (for reconverting) + std::vector<std::shared_ptr<BlockParser>> parsers_; +}; + +Status InferringColumnBuilder::Init() { return UpdateType(); } + +Status InferringColumnBuilder::UpdateType() { + return infer_status_.MakeConverter(pool_).Value(&converter_); +} + +void InferringColumnBuilder::ScheduleConvertChunk(int64_t chunk_index) { + task_group_->Append([=]() { return TryConvertChunk(chunk_index); }); +} + +Status InferringColumnBuilder::TryConvertChunk(int64_t chunk_index) { + std::unique_lock<std::mutex> lock(mutex_); + std::shared_ptr<Converter> converter = converter_; + std::shared_ptr<BlockParser> parser = parsers_[chunk_index]; + InferKind kind = infer_status_.kind(); + + DCHECK_NE(parser, nullptr); + + lock.unlock(); + auto maybe_array = converter->Convert(*parser, col_index_); + lock.lock(); + + if (kind != infer_status_.kind()) { + // infer_kind_ was changed by another task, reconvert + lock.unlock(); + ScheduleConvertChunk(chunk_index); + return Status::OK(); + } + + if (maybe_array.ok() || !infer_status_.can_loosen_type()) { + // Conversion succeeded, or failed definitively + if (!infer_status_.can_loosen_type()) { + // We won't try to reconvert anymore + parsers_[chunk_index].reset(); + } + return SetChunkUnlocked(chunk_index, maybe_array); + } + + // Conversion failed, try another type + infer_status_.LoosenType(maybe_array.status()); + RETURN_NOT_OK(UpdateType()); + + // Reconvert past finished chunks + // (unfinished chunks will notice by themselves if they need reconverting) + const auto nchunks = static_cast<int64_t>(chunks_.size()); + for (int64_t i = 0; i < nchunks; ++i) { + if (i != chunk_index && chunks_[i]) { + // We're assuming the chunk was converted using the wrong type + // (which should be true unless the executor reorders tasks) + chunks_[i].reset(); + lock.unlock(); + ScheduleConvertChunk(i); + lock.lock(); + } + } + + // Reconvert this chunk + lock.unlock(); + ScheduleConvertChunk(chunk_index); + + return Status::OK(); +} + +void InferringColumnBuilder::Insert(int64_t block_index, + const std::shared_ptr<BlockParser>& parser) { + // Create a slot for the new chunk and spawn a task to convert it + size_t chunk_index = static_cast<size_t>(block_index); + { + std::lock_guard<std::mutex> lock(mutex_); + + DCHECK_NE(converter_, nullptr); + if (parsers_.size() <= chunk_index) { + parsers_.resize(chunk_index + 1); + } + // Should not insert an already converting chunk + DCHECK_EQ(parsers_[chunk_index], nullptr); + parsers_[chunk_index] = parser; + ReserveChunksUnlocked(block_index); + } + + ScheduleConvertChunk(chunk_index); +} + +Result<std::shared_ptr<ChunkedArray>> InferringColumnBuilder::Finish() { + std::lock_guard<std::mutex> lock(mutex_); + + parsers_.clear(); + return FinishUnlocked(); +} + +////////////////////////////////////////////////////////////////////////// +// Factory functions + +Result<std::shared_ptr<ColumnBuilder>> ColumnBuilder::Make( + MemoryPool* pool, const std::shared_ptr<DataType>& type, int32_t col_index, + const ConvertOptions& options, const std::shared_ptr<TaskGroup>& task_group) { + auto ptr = + std::make_shared<TypedColumnBuilder>(type, col_index, options, pool, task_group); + RETURN_NOT_OK(ptr->Init()); + return ptr; +} + +Result<std::shared_ptr<ColumnBuilder>> ColumnBuilder::Make( + MemoryPool* pool, int32_t col_index, const ConvertOptions& options, + const std::shared_ptr<TaskGroup>& task_group) { + auto ptr = + std::make_shared<InferringColumnBuilder>(col_index, options, pool, task_group); + RETURN_NOT_OK(ptr->Init()); + return ptr; +} + +Result<std::shared_ptr<ColumnBuilder>> ColumnBuilder::MakeNull( + MemoryPool* pool, const std::shared_ptr<DataType>& type, + const std::shared_ptr<internal::TaskGroup>& task_group) { + return std::make_shared<NullColumnBuilder>(type, pool, task_group); +} + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/column_builder.h b/contrib/libs/apache/arrow/cpp/src/arrow/csv/column_builder.h index 170a8ad067..72bb46586e 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/column_builder.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/column_builder.h @@ -1,78 +1,78 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <cstdint> -#include <memory> -#include <utility> - -#include "arrow/result.h" -#include "arrow/type_fwd.h" -#include "arrow/util/type_fwd.h" -#include "arrow/util/visibility.h" - -namespace arrow { -namespace csv { - -class BlockParser; -struct ConvertOptions; - -class ARROW_EXPORT ColumnBuilder { - public: - virtual ~ColumnBuilder() = default; - - /// Spawn a task that will try to convert and append the given CSV block. - /// All calls to Append() should happen on the same thread, otherwise - /// call Insert() instead. - virtual void Append(const std::shared_ptr<BlockParser>& parser) = 0; - - /// Spawn a task that will try to convert and insert the given CSV block - virtual void Insert(int64_t block_index, - const std::shared_ptr<BlockParser>& parser) = 0; - - /// Return the final chunked array. The TaskGroup _must_ have finished! - virtual Result<std::shared_ptr<ChunkedArray>> Finish() = 0; - - std::shared_ptr<internal::TaskGroup> task_group() { return task_group_; } - - /// Construct a strictly-typed ColumnBuilder. - static Result<std::shared_ptr<ColumnBuilder>> Make( - MemoryPool* pool, const std::shared_ptr<DataType>& type, int32_t col_index, - const ConvertOptions& options, - const std::shared_ptr<internal::TaskGroup>& task_group); - - /// Construct a type-inferring ColumnBuilder. - static Result<std::shared_ptr<ColumnBuilder>> Make( - MemoryPool* pool, int32_t col_index, const ConvertOptions& options, - const std::shared_ptr<internal::TaskGroup>& task_group); - - /// Construct a ColumnBuilder for a column of nulls - /// (i.e. not present in the CSV file). - static Result<std::shared_ptr<ColumnBuilder>> MakeNull( - MemoryPool* pool, const std::shared_ptr<DataType>& type, - const std::shared_ptr<internal::TaskGroup>& task_group); - - protected: - explicit ColumnBuilder(std::shared_ptr<internal::TaskGroup> task_group) - : task_group_(std::move(task_group)) {} - - std::shared_ptr<internal::TaskGroup> task_group_; -}; - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstdint> +#include <memory> +#include <utility> + +#include "arrow/result.h" +#include "arrow/type_fwd.h" +#include "arrow/util/type_fwd.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace csv { + +class BlockParser; +struct ConvertOptions; + +class ARROW_EXPORT ColumnBuilder { + public: + virtual ~ColumnBuilder() = default; + + /// Spawn a task that will try to convert and append the given CSV block. + /// All calls to Append() should happen on the same thread, otherwise + /// call Insert() instead. + virtual void Append(const std::shared_ptr<BlockParser>& parser) = 0; + + /// Spawn a task that will try to convert and insert the given CSV block + virtual void Insert(int64_t block_index, + const std::shared_ptr<BlockParser>& parser) = 0; + + /// Return the final chunked array. The TaskGroup _must_ have finished! + virtual Result<std::shared_ptr<ChunkedArray>> Finish() = 0; + + std::shared_ptr<internal::TaskGroup> task_group() { return task_group_; } + + /// Construct a strictly-typed ColumnBuilder. + static Result<std::shared_ptr<ColumnBuilder>> Make( + MemoryPool* pool, const std::shared_ptr<DataType>& type, int32_t col_index, + const ConvertOptions& options, + const std::shared_ptr<internal::TaskGroup>& task_group); + + /// Construct a type-inferring ColumnBuilder. + static Result<std::shared_ptr<ColumnBuilder>> Make( + MemoryPool* pool, int32_t col_index, const ConvertOptions& options, + const std::shared_ptr<internal::TaskGroup>& task_group); + + /// Construct a ColumnBuilder for a column of nulls + /// (i.e. not present in the CSV file). + static Result<std::shared_ptr<ColumnBuilder>> MakeNull( + MemoryPool* pool, const std::shared_ptr<DataType>& type, + const std::shared_ptr<internal::TaskGroup>& task_group); + + protected: + explicit ColumnBuilder(std::shared_ptr<internal::TaskGroup> task_group) + : task_group_(std::move(task_group)) {} + + std::shared_ptr<internal::TaskGroup> task_group_; +}; + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/column_decoder.cc b/contrib/libs/apache/arrow/cpp/src/arrow/csv/column_decoder.cc index 436d703a9c..70d8e90b35 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/column_decoder.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/column_decoder.cc @@ -1,243 +1,243 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/csv/column_decoder.h" - -#include <cstddef> -#include <cstdint> -#include <memory> -#include <sstream> -#include <string> -#include <utility> - -#include "arrow/array.h" -#include "arrow/array/builder_base.h" -#include "arrow/csv/converter.h" -#include "arrow/csv/inference_internal.h" -#include "arrow/csv/options.h" -#include "arrow/csv/parser.h" -#include "arrow/status.h" -#include "arrow/type.h" -#include "arrow/type_fwd.h" -#include "arrow/util/future.h" -#include "arrow/util/logging.h" -#include "arrow/util/task_group.h" - -namespace arrow { -namespace csv { - -using internal::TaskGroup; - -class ConcreteColumnDecoder : public ColumnDecoder { - public: - explicit ConcreteColumnDecoder(MemoryPool* pool, int32_t col_index = -1) - : ColumnDecoder(), pool_(pool), col_index_(col_index) {} - - protected: - // XXX useful? - virtual std::shared_ptr<DataType> type() const = 0; - - Result<std::shared_ptr<Array>> WrapConversionError( - const Result<std::shared_ptr<Array>>& result) { - if (ARROW_PREDICT_TRUE(result.ok())) { - return result; - } else { - const auto& st = result.status(); - std::stringstream ss; - ss << "In CSV column #" << col_index_ << ": " << st.message(); - return st.WithMessage(ss.str()); - } - } - - MemoryPool* pool_; - int32_t col_index_; - internal::Executor* executor_; -}; - -////////////////////////////////////////////////////////////////////////// -// Null column decoder implementation (for a column not in the CSV file) - -class NullColumnDecoder : public ConcreteColumnDecoder { - public: - explicit NullColumnDecoder(const std::shared_ptr<DataType>& type, MemoryPool* pool) - : ConcreteColumnDecoder(pool), type_(type) {} - - Future<std::shared_ptr<Array>> Decode( - const std::shared_ptr<BlockParser>& parser) override; - - protected: - std::shared_ptr<DataType> type() const override { return type_; } - - std::shared_ptr<DataType> type_; -}; - -Future<std::shared_ptr<Array>> NullColumnDecoder::Decode( - const std::shared_ptr<BlockParser>& parser) { - DCHECK_GE(parser->num_rows(), 0); - return WrapConversionError(MakeArrayOfNull(type_, parser->num_rows(), pool_)); -} - -////////////////////////////////////////////////////////////////////////// -// Pre-typed column decoder implementation - -class TypedColumnDecoder : public ConcreteColumnDecoder { - public: - TypedColumnDecoder(const std::shared_ptr<DataType>& type, int32_t col_index, - const ConvertOptions& options, MemoryPool* pool) - : ConcreteColumnDecoder(pool, col_index), type_(type), options_(options) {} - - Status Init(); - - Future<std::shared_ptr<Array>> Decode( - const std::shared_ptr<BlockParser>& parser) override; - - protected: - std::shared_ptr<DataType> type() const override { return type_; } - - std::shared_ptr<DataType> type_; - // CAUTION: ConvertOptions can grow large (if it customizes hundreds or - // thousands of columns), so avoid copying it in each TypedColumnDecoder. - const ConvertOptions& options_; - - std::shared_ptr<Converter> converter_; -}; - -Status TypedColumnDecoder::Init() { - ARROW_ASSIGN_OR_RAISE(converter_, Converter::Make(type_, options_, pool_)); - return Status::OK(); -} - -Future<std::shared_ptr<Array>> TypedColumnDecoder::Decode( - const std::shared_ptr<BlockParser>& parser) { - DCHECK_NE(converter_, nullptr); - return Future<std::shared_ptr<Array>>::MakeFinished( - WrapConversionError(converter_->Convert(*parser, col_index_))); -} - -////////////////////////////////////////////////////////////////////////// -// Type-inferring column builder implementation - -class InferringColumnDecoder : public ConcreteColumnDecoder { - public: - InferringColumnDecoder(int32_t col_index, const ConvertOptions& options, - MemoryPool* pool) - : ConcreteColumnDecoder(pool, col_index), - options_(options), - infer_status_(options), - type_frozen_(false) { - first_inference_run_ = Future<>::Make(); - first_inferrer_ = 0; - } - - Status Init(); - - Future<std::shared_ptr<Array>> Decode( - const std::shared_ptr<BlockParser>& parser) override; - - protected: - std::shared_ptr<DataType> type() const override { - DCHECK_NE(converter_, nullptr); - return converter_->type(); - } - - Status UpdateType(); - Result<std::shared_ptr<Array>> RunInference(const std::shared_ptr<BlockParser>& parser); - - // CAUTION: ConvertOptions can grow large (if it customizes hundreds or - // thousands of columns), so avoid copying it in each InferringColumnDecoder. - const ConvertOptions& options_; - - // Current inference status - InferStatus infer_status_; - bool type_frozen_; - std::atomic<int> first_inferrer_; - Future<> first_inference_run_; - std::shared_ptr<Converter> converter_; -}; - -Status InferringColumnDecoder::Init() { return UpdateType(); } - -Status InferringColumnDecoder::UpdateType() { - return infer_status_.MakeConverter(pool_).Value(&converter_); -} - -Result<std::shared_ptr<Array>> InferringColumnDecoder::RunInference( - const std::shared_ptr<BlockParser>& parser) { - while (true) { - // (no one else should be updating converter_ concurrently) - auto maybe_array = converter_->Convert(*parser, col_index_); - - if (maybe_array.ok() || !infer_status_.can_loosen_type()) { - // Conversion succeeded, or failed definitively - DCHECK(!type_frozen_); - type_frozen_ = true; - return maybe_array; - } - // Conversion failed temporarily, try another type - infer_status_.LoosenType(maybe_array.status()); - auto update_status = UpdateType(); - if (!update_status.ok()) { - return update_status; - } - } -} - -Future<std::shared_ptr<Array>> InferringColumnDecoder::Decode( - const std::shared_ptr<BlockParser>& parser) { - bool already_taken = first_inferrer_.fetch_or(1); - // First block: run inference - if (!already_taken) { - auto maybe_array = RunInference(parser); - first_inference_run_.MarkFinished(); - return Future<std::shared_ptr<Array>>::MakeFinished(std::move(maybe_array)); - } - - // Non-first block: wait for inference to finish on first block now, - // without blocking a TaskGroup thread. - return first_inference_run_.Then([this, parser] { - DCHECK(type_frozen_); - auto maybe_array = converter_->Convert(*parser, col_index_); - return WrapConversionError(converter_->Convert(*parser, col_index_)); - }); -} - -////////////////////////////////////////////////////////////////////////// -// Factory functions - -Result<std::shared_ptr<ColumnDecoder>> ColumnDecoder::Make( - MemoryPool* pool, int32_t col_index, const ConvertOptions& options) { - auto ptr = std::make_shared<InferringColumnDecoder>(col_index, options, pool); - RETURN_NOT_OK(ptr->Init()); - return ptr; -} - -Result<std::shared_ptr<ColumnDecoder>> ColumnDecoder::Make( - MemoryPool* pool, std::shared_ptr<DataType> type, int32_t col_index, - const ConvertOptions& options) { - auto ptr = - std::make_shared<TypedColumnDecoder>(std::move(type), col_index, options, pool); - RETURN_NOT_OK(ptr->Init()); - return ptr; -} - -Result<std::shared_ptr<ColumnDecoder>> ColumnDecoder::MakeNull( - MemoryPool* pool, std::shared_ptr<DataType> type) { - return std::make_shared<NullColumnDecoder>(std::move(type), pool); -} - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/csv/column_decoder.h" + +#include <cstddef> +#include <cstdint> +#include <memory> +#include <sstream> +#include <string> +#include <utility> + +#include "arrow/array.h" +#include "arrow/array/builder_base.h" +#include "arrow/csv/converter.h" +#include "arrow/csv/inference_internal.h" +#include "arrow/csv/options.h" +#include "arrow/csv/parser.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" +#include "arrow/util/future.h" +#include "arrow/util/logging.h" +#include "arrow/util/task_group.h" + +namespace arrow { +namespace csv { + +using internal::TaskGroup; + +class ConcreteColumnDecoder : public ColumnDecoder { + public: + explicit ConcreteColumnDecoder(MemoryPool* pool, int32_t col_index = -1) + : ColumnDecoder(), pool_(pool), col_index_(col_index) {} + + protected: + // XXX useful? + virtual std::shared_ptr<DataType> type() const = 0; + + Result<std::shared_ptr<Array>> WrapConversionError( + const Result<std::shared_ptr<Array>>& result) { + if (ARROW_PREDICT_TRUE(result.ok())) { + return result; + } else { + const auto& st = result.status(); + std::stringstream ss; + ss << "In CSV column #" << col_index_ << ": " << st.message(); + return st.WithMessage(ss.str()); + } + } + + MemoryPool* pool_; + int32_t col_index_; + internal::Executor* executor_; +}; + +////////////////////////////////////////////////////////////////////////// +// Null column decoder implementation (for a column not in the CSV file) + +class NullColumnDecoder : public ConcreteColumnDecoder { + public: + explicit NullColumnDecoder(const std::shared_ptr<DataType>& type, MemoryPool* pool) + : ConcreteColumnDecoder(pool), type_(type) {} + + Future<std::shared_ptr<Array>> Decode( + const std::shared_ptr<BlockParser>& parser) override; + + protected: + std::shared_ptr<DataType> type() const override { return type_; } + + std::shared_ptr<DataType> type_; +}; + +Future<std::shared_ptr<Array>> NullColumnDecoder::Decode( + const std::shared_ptr<BlockParser>& parser) { + DCHECK_GE(parser->num_rows(), 0); + return WrapConversionError(MakeArrayOfNull(type_, parser->num_rows(), pool_)); +} + +////////////////////////////////////////////////////////////////////////// +// Pre-typed column decoder implementation + +class TypedColumnDecoder : public ConcreteColumnDecoder { + public: + TypedColumnDecoder(const std::shared_ptr<DataType>& type, int32_t col_index, + const ConvertOptions& options, MemoryPool* pool) + : ConcreteColumnDecoder(pool, col_index), type_(type), options_(options) {} + + Status Init(); + + Future<std::shared_ptr<Array>> Decode( + const std::shared_ptr<BlockParser>& parser) override; + + protected: + std::shared_ptr<DataType> type() const override { return type_; } + + std::shared_ptr<DataType> type_; + // CAUTION: ConvertOptions can grow large (if it customizes hundreds or + // thousands of columns), so avoid copying it in each TypedColumnDecoder. + const ConvertOptions& options_; + + std::shared_ptr<Converter> converter_; +}; + +Status TypedColumnDecoder::Init() { + ARROW_ASSIGN_OR_RAISE(converter_, Converter::Make(type_, options_, pool_)); + return Status::OK(); +} + +Future<std::shared_ptr<Array>> TypedColumnDecoder::Decode( + const std::shared_ptr<BlockParser>& parser) { + DCHECK_NE(converter_, nullptr); + return Future<std::shared_ptr<Array>>::MakeFinished( + WrapConversionError(converter_->Convert(*parser, col_index_))); +} + +////////////////////////////////////////////////////////////////////////// +// Type-inferring column builder implementation + +class InferringColumnDecoder : public ConcreteColumnDecoder { + public: + InferringColumnDecoder(int32_t col_index, const ConvertOptions& options, + MemoryPool* pool) + : ConcreteColumnDecoder(pool, col_index), + options_(options), + infer_status_(options), + type_frozen_(false) { + first_inference_run_ = Future<>::Make(); + first_inferrer_ = 0; + } + + Status Init(); + + Future<std::shared_ptr<Array>> Decode( + const std::shared_ptr<BlockParser>& parser) override; + + protected: + std::shared_ptr<DataType> type() const override { + DCHECK_NE(converter_, nullptr); + return converter_->type(); + } + + Status UpdateType(); + Result<std::shared_ptr<Array>> RunInference(const std::shared_ptr<BlockParser>& parser); + + // CAUTION: ConvertOptions can grow large (if it customizes hundreds or + // thousands of columns), so avoid copying it in each InferringColumnDecoder. + const ConvertOptions& options_; + + // Current inference status + InferStatus infer_status_; + bool type_frozen_; + std::atomic<int> first_inferrer_; + Future<> first_inference_run_; + std::shared_ptr<Converter> converter_; +}; + +Status InferringColumnDecoder::Init() { return UpdateType(); } + +Status InferringColumnDecoder::UpdateType() { + return infer_status_.MakeConverter(pool_).Value(&converter_); +} + +Result<std::shared_ptr<Array>> InferringColumnDecoder::RunInference( + const std::shared_ptr<BlockParser>& parser) { + while (true) { + // (no one else should be updating converter_ concurrently) + auto maybe_array = converter_->Convert(*parser, col_index_); + + if (maybe_array.ok() || !infer_status_.can_loosen_type()) { + // Conversion succeeded, or failed definitively + DCHECK(!type_frozen_); + type_frozen_ = true; + return maybe_array; + } + // Conversion failed temporarily, try another type + infer_status_.LoosenType(maybe_array.status()); + auto update_status = UpdateType(); + if (!update_status.ok()) { + return update_status; + } + } +} + +Future<std::shared_ptr<Array>> InferringColumnDecoder::Decode( + const std::shared_ptr<BlockParser>& parser) { + bool already_taken = first_inferrer_.fetch_or(1); + // First block: run inference + if (!already_taken) { + auto maybe_array = RunInference(parser); + first_inference_run_.MarkFinished(); + return Future<std::shared_ptr<Array>>::MakeFinished(std::move(maybe_array)); + } + + // Non-first block: wait for inference to finish on first block now, + // without blocking a TaskGroup thread. + return first_inference_run_.Then([this, parser] { + DCHECK(type_frozen_); + auto maybe_array = converter_->Convert(*parser, col_index_); + return WrapConversionError(converter_->Convert(*parser, col_index_)); + }); +} + +////////////////////////////////////////////////////////////////////////// +// Factory functions + +Result<std::shared_ptr<ColumnDecoder>> ColumnDecoder::Make( + MemoryPool* pool, int32_t col_index, const ConvertOptions& options) { + auto ptr = std::make_shared<InferringColumnDecoder>(col_index, options, pool); + RETURN_NOT_OK(ptr->Init()); + return ptr; +} + +Result<std::shared_ptr<ColumnDecoder>> ColumnDecoder::Make( + MemoryPool* pool, std::shared_ptr<DataType> type, int32_t col_index, + const ConvertOptions& options) { + auto ptr = + std::make_shared<TypedColumnDecoder>(std::move(type), col_index, options, pool); + RETURN_NOT_OK(ptr->Init()); + return ptr; +} + +Result<std::shared_ptr<ColumnDecoder>> ColumnDecoder::MakeNull( + MemoryPool* pool, std::shared_ptr<DataType> type) { + return std::make_shared<NullColumnDecoder>(std::move(type), pool); +} + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/column_decoder.h b/contrib/libs/apache/arrow/cpp/src/arrow/csv/column_decoder.h index 5fbbd5df58..1b72573dee 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/column_decoder.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/column_decoder.h @@ -1,64 +1,64 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <cstdint> -#include <memory> -#include <utility> - -#include "arrow/result.h" -#include "arrow/type_fwd.h" -#include "arrow/util/type_fwd.h" -#include "arrow/util/visibility.h" - -namespace arrow { -namespace csv { - -class BlockParser; -struct ConvertOptions; - -class ARROW_EXPORT ColumnDecoder { - public: - virtual ~ColumnDecoder() = default; - - /// Spawn a task that will try to convert and insert the given CSV block - virtual Future<std::shared_ptr<Array>> Decode( - const std::shared_ptr<BlockParser>& parser) = 0; - - /// Construct a strictly-typed ColumnDecoder. - static Result<std::shared_ptr<ColumnDecoder>> Make(MemoryPool* pool, - std::shared_ptr<DataType> type, - int32_t col_index, - const ConvertOptions& options); - - /// Construct a type-inferring ColumnDecoder. - /// Inference will run only on the first block, the type will be frozen afterwards. - static Result<std::shared_ptr<ColumnDecoder>> Make(MemoryPool* pool, int32_t col_index, - const ConvertOptions& options); - - /// Construct a ColumnDecoder for a column of nulls - /// (i.e. not present in the CSV file). - static Result<std::shared_ptr<ColumnDecoder>> MakeNull(MemoryPool* pool, - std::shared_ptr<DataType> type); - - protected: - ColumnDecoder() = default; -}; - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstdint> +#include <memory> +#include <utility> + +#include "arrow/result.h" +#include "arrow/type_fwd.h" +#include "arrow/util/type_fwd.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace csv { + +class BlockParser; +struct ConvertOptions; + +class ARROW_EXPORT ColumnDecoder { + public: + virtual ~ColumnDecoder() = default; + + /// Spawn a task that will try to convert and insert the given CSV block + virtual Future<std::shared_ptr<Array>> Decode( + const std::shared_ptr<BlockParser>& parser) = 0; + + /// Construct a strictly-typed ColumnDecoder. + static Result<std::shared_ptr<ColumnDecoder>> Make(MemoryPool* pool, + std::shared_ptr<DataType> type, + int32_t col_index, + const ConvertOptions& options); + + /// Construct a type-inferring ColumnDecoder. + /// Inference will run only on the first block, the type will be frozen afterwards. + static Result<std::shared_ptr<ColumnDecoder>> Make(MemoryPool* pool, int32_t col_index, + const ConvertOptions& options); + + /// Construct a ColumnDecoder for a column of nulls + /// (i.e. not present in the CSV file). + static Result<std::shared_ptr<ColumnDecoder>> MakeNull(MemoryPool* pool, + std::shared_ptr<DataType> type); + + protected: + ColumnDecoder() = default; +}; + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/converter.cc b/contrib/libs/apache/arrow/cpp/src/arrow/csv/converter.cc index cb72b22b40..5d0386c6ca 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/converter.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/converter.cc @@ -1,692 +1,692 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/csv/converter.h" - -#include <cstring> -#include <limits> -#include <sstream> -#include <string> -#include <type_traits> -#include <vector> - -#include "arrow/array/builder_binary.h" -#include "arrow/array/builder_decimal.h" -#include "arrow/array/builder_dict.h" -#include "arrow/array/builder_primitive.h" -#include "arrow/csv/parser.h" -#include "arrow/status.h" -#include "arrow/type.h" -#include "arrow/type_fwd.h" -#include "arrow/type_traits.h" -#include "arrow/util/checked_cast.h" -#include "arrow/util/decimal.h" -#include "arrow/util/trie.h" -#include "arrow/util/utf8.h" -#include "arrow/util/value_parsing.h" // IWYU pragma: keep - -namespace arrow { -namespace csv { - -using internal::checked_cast; -using internal::Trie; -using internal::TrieBuilder; - -namespace { - -Status GenericConversionError(const std::shared_ptr<DataType>& type, const uint8_t* data, - uint32_t size) { - return Status::Invalid("CSV conversion error to ", type->ToString(), - ": invalid value '", - std::string(reinterpret_cast<const char*>(data), size), "'"); -} - -inline bool IsWhitespace(uint8_t c) { - if (ARROW_PREDICT_TRUE(c > ' ')) { - return false; - } - return c == ' ' || c == '\t'; -} - -// Updates data_inout and size_inout to not include leading/trailing whitespace -// characters. -inline void TrimWhiteSpace(const uint8_t** data_inout, uint32_t* size_inout) { - const uint8_t*& data = *data_inout; - uint32_t& size = *size_inout; - // Skip trailing whitespace - if (ARROW_PREDICT_TRUE(size > 0) && ARROW_PREDICT_FALSE(IsWhitespace(data[size - 1]))) { - const uint8_t* p = data + size - 1; - while (size > 0 && IsWhitespace(*p)) { - --size; - --p; - } - } - // Skip leading whitespace - if (ARROW_PREDICT_TRUE(size > 0) && ARROW_PREDICT_FALSE(IsWhitespace(data[0]))) { - while (size > 0 && IsWhitespace(*data)) { - --size; - ++data; - } - } -} - -Status InitializeTrie(const std::vector<std::string>& inputs, Trie* trie) { - TrieBuilder builder; - for (const auto& s : inputs) { - RETURN_NOT_OK(builder.Append(s, true /* allow_duplicates */)); - } - *trie = builder.Finish(); - return Status::OK(); -} - -// Presize a builder based on parser contents -template <typename BuilderType> -enable_if_t<!is_base_binary_type<typename BuilderType::TypeClass>::value, Status> -PresizeBuilder(const BlockParser& parser, BuilderType* builder) { - return builder->Resize(parser.num_rows()); -} - -// Same, for variable-sized binary builders -template <typename T> -Status PresizeBuilder(const BlockParser& parser, BaseBinaryBuilder<T>* builder) { - RETURN_NOT_OK(builder->Resize(parser.num_rows())); - return builder->ReserveData(parser.num_bytes()); -} - -///////////////////////////////////////////////////////////////////////// -// Per-type value decoders - -struct ValueDecoder { - explicit ValueDecoder(const std::shared_ptr<DataType>& type, - const ConvertOptions& options) - : type_(type), options_(options) {} - - Status Initialize() { - // TODO no need to build a separate Trie for each instance - return InitializeTrie(options_.null_values, &null_trie_); - } - - bool IsNull(const uint8_t* data, uint32_t size, bool quoted) { - if (quoted) { - return false; - } - return null_trie_.Find( - util::string_view(reinterpret_cast<const char*>(data), size)) >= 0; - } - - protected: - Trie null_trie_; - std::shared_ptr<DataType> type_; - const ConvertOptions& options_; -}; - -// -// Value decoder for fixed-size binary -// - -struct FixedSizeBinaryValueDecoder : public ValueDecoder { - using value_type = const uint8_t*; - - explicit FixedSizeBinaryValueDecoder(const std::shared_ptr<DataType>& type, - const ConvertOptions& options) - : ValueDecoder(type, options), - byte_width_(checked_cast<const FixedSizeBinaryType&>(*type).byte_width()) {} - - Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { - if (ARROW_PREDICT_FALSE(size != byte_width_)) { - return Status::Invalid("CSV conversion error to ", type_->ToString(), ": got a ", - size, "-byte long string"); - } - *out = data; - return Status::OK(); - } - - protected: - const uint32_t byte_width_; -}; - -// -// Value decoder for variable-size binary -// - -template <bool CheckUTF8> -struct BinaryValueDecoder : public ValueDecoder { - using value_type = util::string_view; - - using ValueDecoder::ValueDecoder; - - Status Initialize() { - util::InitializeUTF8(); - return ValueDecoder::Initialize(); - } - - Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { - if (CheckUTF8 && ARROW_PREDICT_FALSE(!util::ValidateUTF8(data, size))) { - return Status::Invalid("CSV conversion error to ", type_->ToString(), - ": invalid UTF8 data"); - } - *out = {reinterpret_cast<const char*>(data), size}; - return Status::OK(); - } - - bool IsNull(const uint8_t* data, uint32_t size, bool quoted) { - return options_.strings_can_be_null && - (!quoted || options_.quoted_strings_can_be_null) && - ValueDecoder::IsNull(data, size, false /* quoted */); - } -}; - -// -// Value decoder for integers and floats -// - -template <typename T> -struct NumericValueDecoder : public ValueDecoder { - using value_type = typename T::c_type; - - using ValueDecoder::ValueDecoder; - - Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { - // XXX should quoted values be allowed at all? - TrimWhiteSpace(&data, &size); - if (ARROW_PREDICT_FALSE( - !internal::ParseValue<T>(reinterpret_cast<const char*>(data), size, out))) { - return GenericConversionError(type_, data, size); - } - return Status::OK(); - } -}; - -// -// Value decoder for booleans -// - -struct BooleanValueDecoder : public ValueDecoder { - using value_type = bool; - - using ValueDecoder::ValueDecoder; - - Status Initialize() { - // TODO no need to build separate Tries for each instance - RETURN_NOT_OK(InitializeTrie(options_.true_values, &true_trie_)); - RETURN_NOT_OK(InitializeTrie(options_.false_values, &false_trie_)); - return ValueDecoder::Initialize(); - } - - Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { - // XXX should quoted values be allowed at all? - if (false_trie_.Find(util::string_view(reinterpret_cast<const char*>(data), size)) >= - 0) { - *out = false; - return Status::OK(); - } - if (ARROW_PREDICT_TRUE(true_trie_.Find(util::string_view( - reinterpret_cast<const char*>(data), size)) >= 0)) { - *out = true; - return Status::OK(); - } - return GenericConversionError(type_, data, size); - } - - protected: - Trie true_trie_; - Trie false_trie_; -}; - -// -// Value decoder for decimals -// - -struct DecimalValueDecoder : public ValueDecoder { - using value_type = Decimal128; - - explicit DecimalValueDecoder(const std::shared_ptr<DataType>& type, - const ConvertOptions& options) - : ValueDecoder(type, options), - decimal_type_(internal::checked_cast<const DecimalType&>(*type_)), - type_precision_(decimal_type_.precision()), - type_scale_(decimal_type_.scale()) {} - - Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { - TrimWhiteSpace(&data, &size); - Decimal128 decimal; - int32_t precision, scale; - util::string_view view(reinterpret_cast<const char*>(data), size); - RETURN_NOT_OK(Decimal128::FromString(view, &decimal, &precision, &scale)); - if (precision > type_precision_) { - return Status::Invalid("Error converting '", view, "' to ", type_->ToString(), - ": precision not supported by type."); - } - if (scale != type_scale_) { - ARROW_ASSIGN_OR_RAISE(*out, decimal.Rescale(scale, type_scale_)); - } else { - *out = std::move(decimal); - } - return Status::OK(); - } - - protected: - const DecimalType& decimal_type_; - const int32_t type_precision_; - const int32_t type_scale_; -}; - -// -// Value decoders for timestamps -// - -struct InlineISO8601ValueDecoder : public ValueDecoder { - using value_type = int64_t; - - explicit InlineISO8601ValueDecoder(const std::shared_ptr<DataType>& type, - const ConvertOptions& options) - : ValueDecoder(type, options), - unit_(checked_cast<const TimestampType&>(*type_).unit()) {} - - Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { - if (ARROW_PREDICT_FALSE(!internal::ParseTimestampISO8601( - reinterpret_cast<const char*>(data), size, unit_, out))) { - return GenericConversionError(type_, data, size); - } - return Status::OK(); - } - - protected: - TimeUnit::type unit_; -}; - -struct SingleParserTimestampValueDecoder : public ValueDecoder { - using value_type = int64_t; - - explicit SingleParserTimestampValueDecoder(const std::shared_ptr<DataType>& type, - const ConvertOptions& options) - : ValueDecoder(type, options), - unit_(checked_cast<const TimestampType&>(*type_).unit()), - parser_(*options_.timestamp_parsers[0]) {} - - Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { - if (ARROW_PREDICT_FALSE( - !parser_(reinterpret_cast<const char*>(data), size, unit_, out))) { - return GenericConversionError(type_, data, size); - } - return Status::OK(); - } - - protected: - TimeUnit::type unit_; - const TimestampParser& parser_; -}; - -struct MultipleParsersTimestampValueDecoder : public ValueDecoder { - using value_type = int64_t; - - explicit MultipleParsersTimestampValueDecoder(const std::shared_ptr<DataType>& type, - const ConvertOptions& options) - : ValueDecoder(type, options), - unit_(checked_cast<const TimestampType&>(*type_).unit()), - parsers_(GetParsers(options_)) {} - - Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { - for (const auto& parser : parsers_) { - if (parser->operator()(reinterpret_cast<const char*>(data), size, unit_, out)) { - return Status::OK(); - } - } - return GenericConversionError(type_, data, size); - } - - protected: - using ParserVector = std::vector<const TimestampParser*>; - - static ParserVector GetParsers(const ConvertOptions& options) { - ParserVector parsers(options.timestamp_parsers.size()); - for (size_t i = 0; i < options.timestamp_parsers.size(); ++i) { - parsers[i] = options.timestamp_parsers[i].get(); - } - return parsers; - } - - TimeUnit::type unit_; - std::vector<const TimestampParser*> parsers_; -}; - -///////////////////////////////////////////////////////////////////////// -// Concrete Converter hierarchy - -class ConcreteConverter : public Converter { - public: - using Converter::Converter; -}; - -class ConcreteDictionaryConverter : public DictionaryConverter { - public: - using DictionaryConverter::DictionaryConverter; -}; - -// -// Concrete Converter for nulls -// - -class NullConverter : public ConcreteConverter { - public: - NullConverter(const std::shared_ptr<DataType>& type, const ConvertOptions& options, - MemoryPool* pool) - : ConcreteConverter(type, options, pool), decoder_(type_, options_) {} - - Result<std::shared_ptr<Array>> Convert(const BlockParser& parser, - int32_t col_index) override { - NullBuilder builder(pool_); - - auto visit = [&](const uint8_t* data, uint32_t size, bool quoted) -> Status { - if (ARROW_PREDICT_TRUE(decoder_.IsNull(data, size, quoted))) { - return builder.AppendNull(); - } else { - return GenericConversionError(type_, data, size); - } - }; - RETURN_NOT_OK(parser.VisitColumn(col_index, visit)); - std::shared_ptr<Array> res; - RETURN_NOT_OK(builder.Finish(&res)); - return res; - } - - protected: - Status Initialize() override { return decoder_.Initialize(); } - - ValueDecoder decoder_; -}; - -// -// Concrete Converter for primitives -// - -template <typename T, typename ValueDecoderType> -class PrimitiveConverter : public ConcreteConverter { - public: - PrimitiveConverter(const std::shared_ptr<DataType>& type, const ConvertOptions& options, - MemoryPool* pool) - : ConcreteConverter(type, options, pool), decoder_(type_, options_) {} - - Result<std::shared_ptr<Array>> Convert(const BlockParser& parser, - int32_t col_index) override { - using BuilderType = typename TypeTraits<T>::BuilderType; - using value_type = typename ValueDecoderType::value_type; - - BuilderType builder(type_, pool_); - RETURN_NOT_OK(PresizeBuilder(parser, &builder)); - - auto visit = [&](const uint8_t* data, uint32_t size, bool quoted) -> Status { - if (decoder_.IsNull(data, size, quoted /* quoted */)) { - return builder.AppendNull(); - } - value_type value{}; - RETURN_NOT_OK(decoder_.Decode(data, size, quoted, &value)); - builder.UnsafeAppend(value); - return Status::OK(); - }; - RETURN_NOT_OK(parser.VisitColumn(col_index, visit)); - - std::shared_ptr<Array> res; - RETURN_NOT_OK(builder.Finish(&res)); - return res; - } - - protected: - Status Initialize() override { return decoder_.Initialize(); } - - ValueDecoderType decoder_; -}; - -// -// Concrete Converter for dictionaries -// - -template <typename T, typename ValueDecoderType> -class TypedDictionaryConverter : public ConcreteDictionaryConverter { - public: - TypedDictionaryConverter(const std::shared_ptr<DataType>& value_type, - const ConvertOptions& options, MemoryPool* pool) - : ConcreteDictionaryConverter(value_type, options, pool), - decoder_(value_type, options_) {} - - Result<std::shared_ptr<Array>> Convert(const BlockParser& parser, - int32_t col_index) override { - // We use a fixed index width so that all column chunks get the same index type - using BuilderType = Dictionary32Builder<T>; - using value_type = typename ValueDecoderType::value_type; - - BuilderType builder(value_type_, pool_); - RETURN_NOT_OK(PresizeBuilder(parser, &builder)); - - auto visit = [&](const uint8_t* data, uint32_t size, bool quoted) -> Status { - if (decoder_.IsNull(data, size, quoted /* quoted */)) { - return builder.AppendNull(); - } - if (ARROW_PREDICT_FALSE(builder.dictionary_length() > max_cardinality_)) { - return Status::IndexError("Dictionary length exceeded max cardinality"); - } - value_type value{}; - RETURN_NOT_OK(decoder_.Decode(data, size, quoted, &value)); - return builder.Append(value); - }; - RETURN_NOT_OK(parser.VisitColumn(col_index, visit)); - - std::shared_ptr<Array> res; - RETURN_NOT_OK(builder.Finish(&res)); - return res; - } - - void SetMaxCardinality(int32_t max_length) override { max_cardinality_ = max_length; } - - protected: - Status Initialize() override { - util::InitializeUTF8(); - return decoder_.Initialize(); - } - - ValueDecoderType decoder_; - int32_t max_cardinality_ = std::numeric_limits<int32_t>::max(); -}; - -// -// Concrete Converter factory for timestamps -// - -template <template <typename, typename> class ConverterType> -std::shared_ptr<Converter> MakeTimestampConverter(const std::shared_ptr<DataType>& type, - const ConvertOptions& options, - MemoryPool* pool) { - if (options.timestamp_parsers.size() == 0) { - // Default to ISO-8601 - return std::make_shared<ConverterType<TimestampType, InlineISO8601ValueDecoder>>( - type, options, pool); - } else if (options.timestamp_parsers.size() == 1) { - // Single user-supplied converter - return std::make_shared< - ConverterType<TimestampType, SingleParserTimestampValueDecoder>>(type, options, - pool); - } else { - // Multiple converters, must iterate for each value - return std::make_shared< - ConverterType<TimestampType, MultipleParsersTimestampValueDecoder>>(type, options, - pool); - } -} - -} // namespace - -///////////////////////////////////////////////////////////////////////// -// Base Converter class implementation - -Converter::Converter(const std::shared_ptr<DataType>& type, const ConvertOptions& options, - MemoryPool* pool) - : options_(options), pool_(pool), type_(type) {} - -DictionaryConverter::DictionaryConverter(const std::shared_ptr<DataType>& value_type, - const ConvertOptions& options, MemoryPool* pool) - : Converter(dictionary(int32(), value_type), options, pool), - value_type_(value_type) {} - -Result<std::shared_ptr<Converter>> Converter::Make(const std::shared_ptr<DataType>& type, - const ConvertOptions& options, - MemoryPool* pool) { - std::shared_ptr<Converter> ptr; - - switch (type->id()) { -#define CONVERTER_CASE(TYPE_ID, CONVERTER_TYPE) \ - case TYPE_ID: \ - ptr.reset(new CONVERTER_TYPE(type, options, pool)); \ - break; - -#define NUMERIC_CONVERTER_CASE(TYPE_ID, TYPE_CLASS) \ - CONVERTER_CASE(TYPE_ID, \ - (PrimitiveConverter<TYPE_CLASS, NumericValueDecoder<TYPE_CLASS>>)) - - CONVERTER_CASE(Type::NA, NullConverter) - NUMERIC_CONVERTER_CASE(Type::INT8, Int8Type) - NUMERIC_CONVERTER_CASE(Type::INT16, Int16Type) - NUMERIC_CONVERTER_CASE(Type::INT32, Int32Type) - NUMERIC_CONVERTER_CASE(Type::INT64, Int64Type) - NUMERIC_CONVERTER_CASE(Type::UINT8, UInt8Type) - NUMERIC_CONVERTER_CASE(Type::UINT16, UInt16Type) - NUMERIC_CONVERTER_CASE(Type::UINT32, UInt32Type) - NUMERIC_CONVERTER_CASE(Type::UINT64, UInt64Type) - NUMERIC_CONVERTER_CASE(Type::FLOAT, FloatType) - NUMERIC_CONVERTER_CASE(Type::DOUBLE, DoubleType) - NUMERIC_CONVERTER_CASE(Type::DATE32, Date32Type) - NUMERIC_CONVERTER_CASE(Type::DATE64, Date64Type) - CONVERTER_CASE(Type::BOOL, (PrimitiveConverter<BooleanType, BooleanValueDecoder>)) - CONVERTER_CASE(Type::BINARY, - (PrimitiveConverter<BinaryType, BinaryValueDecoder<false>>)) - CONVERTER_CASE(Type::LARGE_BINARY, - (PrimitiveConverter<LargeBinaryType, BinaryValueDecoder<false>>)) - CONVERTER_CASE(Type::FIXED_SIZE_BINARY, - (PrimitiveConverter<FixedSizeBinaryType, FixedSizeBinaryValueDecoder>)) - CONVERTER_CASE(Type::DECIMAL, - (PrimitiveConverter<Decimal128Type, DecimalValueDecoder>)) - - case Type::TIMESTAMP: - ptr = MakeTimestampConverter<PrimitiveConverter>(type, options, pool); - break; - - case Type::STRING: - if (options.check_utf8) { - ptr = std::make_shared<PrimitiveConverter<StringType, BinaryValueDecoder<true>>>( - type, options, pool); - } else { - ptr = std::make_shared<PrimitiveConverter<StringType, BinaryValueDecoder<false>>>( - type, options, pool); - } - break; - - case Type::LARGE_STRING: - if (options.check_utf8) { - ptr = std::make_shared< - PrimitiveConverter<LargeStringType, BinaryValueDecoder<true>>>(type, options, - pool); - } else { - ptr = std::make_shared< - PrimitiveConverter<LargeStringType, BinaryValueDecoder<false>>>(type, options, - pool); - } - break; - - case Type::DICTIONARY: { - const auto& dict_type = checked_cast<const DictionaryType&>(*type); - if (dict_type.index_type()->id() != Type::INT32) { - return Status::NotImplemented( - "CSV conversion to dictionary only supported for int32 indices, " - "got ", - type->ToString()); - } - return DictionaryConverter::Make(dict_type.value_type(), options, pool); - } - - default: { - return Status::NotImplemented("CSV conversion to ", type->ToString(), - " is not supported"); - } - -#undef CONVERTER_CASE -#undef NUMERIC_CONVERTER_CASE - } - RETURN_NOT_OK(ptr->Initialize()); - return ptr; -} - -Result<std::shared_ptr<DictionaryConverter>> DictionaryConverter::Make( - const std::shared_ptr<DataType>& type, const ConvertOptions& options, - MemoryPool* pool) { - std::shared_ptr<DictionaryConverter> ptr; - - switch (type->id()) { -#define CONVERTER_CASE(TYPE_ID, TYPE, VALUE_DECODER_TYPE) \ - case TYPE_ID: \ - ptr.reset( \ - new TypedDictionaryConverter<TYPE, VALUE_DECODER_TYPE>(type, options, pool)); \ - break; - - // XXX Are 32-bit types useful? - CONVERTER_CASE(Type::INT32, Int32Type, NumericValueDecoder<Int32Type>) - CONVERTER_CASE(Type::INT64, Int64Type, NumericValueDecoder<Int64Type>) - CONVERTER_CASE(Type::UINT32, UInt32Type, NumericValueDecoder<UInt32Type>) - CONVERTER_CASE(Type::UINT64, UInt64Type, NumericValueDecoder<UInt64Type>) - CONVERTER_CASE(Type::FLOAT, FloatType, NumericValueDecoder<FloatType>) - CONVERTER_CASE(Type::DOUBLE, DoubleType, NumericValueDecoder<DoubleType>) - CONVERTER_CASE(Type::DECIMAL, Decimal128Type, DecimalValueDecoder) - CONVERTER_CASE(Type::FIXED_SIZE_BINARY, FixedSizeBinaryType, - FixedSizeBinaryValueDecoder) - CONVERTER_CASE(Type::BINARY, BinaryType, BinaryValueDecoder<false>) - CONVERTER_CASE(Type::LARGE_BINARY, LargeBinaryType, BinaryValueDecoder<false>) - - case Type::STRING: - if (options.check_utf8) { - ptr = std::make_shared< - TypedDictionaryConverter<StringType, BinaryValueDecoder<true>>>(type, options, - pool); - } else { - ptr = std::make_shared< - TypedDictionaryConverter<StringType, BinaryValueDecoder<false>>>( - type, options, pool); - } - break; - - case Type::LARGE_STRING: - if (options.check_utf8) { - ptr = std::make_shared< - TypedDictionaryConverter<LargeStringType, BinaryValueDecoder<true>>>( - type, options, pool); - } else { - ptr = std::make_shared< - TypedDictionaryConverter<LargeStringType, BinaryValueDecoder<false>>>( - type, options, pool); - } - break; - - default: { - return Status::NotImplemented("CSV dictionary conversion to ", type->ToString(), - " is not supported"); - } - -#undef CONVERTER_CASE - } - RETURN_NOT_OK(ptr->Initialize()); - return ptr; -} - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/csv/converter.h" + +#include <cstring> +#include <limits> +#include <sstream> +#include <string> +#include <type_traits> +#include <vector> + +#include "arrow/array/builder_binary.h" +#include "arrow/array/builder_decimal.h" +#include "arrow/array/builder_dict.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/csv/parser.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" +#include "arrow/type_traits.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/decimal.h" +#include "arrow/util/trie.h" +#include "arrow/util/utf8.h" +#include "arrow/util/value_parsing.h" // IWYU pragma: keep + +namespace arrow { +namespace csv { + +using internal::checked_cast; +using internal::Trie; +using internal::TrieBuilder; + +namespace { + +Status GenericConversionError(const std::shared_ptr<DataType>& type, const uint8_t* data, + uint32_t size) { + return Status::Invalid("CSV conversion error to ", type->ToString(), + ": invalid value '", + std::string(reinterpret_cast<const char*>(data), size), "'"); +} + +inline bool IsWhitespace(uint8_t c) { + if (ARROW_PREDICT_TRUE(c > ' ')) { + return false; + } + return c == ' ' || c == '\t'; +} + +// Updates data_inout and size_inout to not include leading/trailing whitespace +// characters. +inline void TrimWhiteSpace(const uint8_t** data_inout, uint32_t* size_inout) { + const uint8_t*& data = *data_inout; + uint32_t& size = *size_inout; + // Skip trailing whitespace + if (ARROW_PREDICT_TRUE(size > 0) && ARROW_PREDICT_FALSE(IsWhitespace(data[size - 1]))) { + const uint8_t* p = data + size - 1; + while (size > 0 && IsWhitespace(*p)) { + --size; + --p; + } + } + // Skip leading whitespace + if (ARROW_PREDICT_TRUE(size > 0) && ARROW_PREDICT_FALSE(IsWhitespace(data[0]))) { + while (size > 0 && IsWhitespace(*data)) { + --size; + ++data; + } + } +} + +Status InitializeTrie(const std::vector<std::string>& inputs, Trie* trie) { + TrieBuilder builder; + for (const auto& s : inputs) { + RETURN_NOT_OK(builder.Append(s, true /* allow_duplicates */)); + } + *trie = builder.Finish(); + return Status::OK(); +} + +// Presize a builder based on parser contents +template <typename BuilderType> +enable_if_t<!is_base_binary_type<typename BuilderType::TypeClass>::value, Status> +PresizeBuilder(const BlockParser& parser, BuilderType* builder) { + return builder->Resize(parser.num_rows()); +} + +// Same, for variable-sized binary builders +template <typename T> +Status PresizeBuilder(const BlockParser& parser, BaseBinaryBuilder<T>* builder) { + RETURN_NOT_OK(builder->Resize(parser.num_rows())); + return builder->ReserveData(parser.num_bytes()); +} + +///////////////////////////////////////////////////////////////////////// +// Per-type value decoders + +struct ValueDecoder { + explicit ValueDecoder(const std::shared_ptr<DataType>& type, + const ConvertOptions& options) + : type_(type), options_(options) {} + + Status Initialize() { + // TODO no need to build a separate Trie for each instance + return InitializeTrie(options_.null_values, &null_trie_); + } + + bool IsNull(const uint8_t* data, uint32_t size, bool quoted) { + if (quoted) { + return false; + } + return null_trie_.Find( + util::string_view(reinterpret_cast<const char*>(data), size)) >= 0; + } + + protected: + Trie null_trie_; + std::shared_ptr<DataType> type_; + const ConvertOptions& options_; +}; + +// +// Value decoder for fixed-size binary +// + +struct FixedSizeBinaryValueDecoder : public ValueDecoder { + using value_type = const uint8_t*; + + explicit FixedSizeBinaryValueDecoder(const std::shared_ptr<DataType>& type, + const ConvertOptions& options) + : ValueDecoder(type, options), + byte_width_(checked_cast<const FixedSizeBinaryType&>(*type).byte_width()) {} + + Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { + if (ARROW_PREDICT_FALSE(size != byte_width_)) { + return Status::Invalid("CSV conversion error to ", type_->ToString(), ": got a ", + size, "-byte long string"); + } + *out = data; + return Status::OK(); + } + + protected: + const uint32_t byte_width_; +}; + +// +// Value decoder for variable-size binary +// + +template <bool CheckUTF8> +struct BinaryValueDecoder : public ValueDecoder { + using value_type = util::string_view; + + using ValueDecoder::ValueDecoder; + + Status Initialize() { + util::InitializeUTF8(); + return ValueDecoder::Initialize(); + } + + Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { + if (CheckUTF8 && ARROW_PREDICT_FALSE(!util::ValidateUTF8(data, size))) { + return Status::Invalid("CSV conversion error to ", type_->ToString(), + ": invalid UTF8 data"); + } + *out = {reinterpret_cast<const char*>(data), size}; + return Status::OK(); + } + + bool IsNull(const uint8_t* data, uint32_t size, bool quoted) { + return options_.strings_can_be_null && + (!quoted || options_.quoted_strings_can_be_null) && + ValueDecoder::IsNull(data, size, false /* quoted */); + } +}; + +// +// Value decoder for integers and floats +// + +template <typename T> +struct NumericValueDecoder : public ValueDecoder { + using value_type = typename T::c_type; + + using ValueDecoder::ValueDecoder; + + Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { + // XXX should quoted values be allowed at all? + TrimWhiteSpace(&data, &size); + if (ARROW_PREDICT_FALSE( + !internal::ParseValue<T>(reinterpret_cast<const char*>(data), size, out))) { + return GenericConversionError(type_, data, size); + } + return Status::OK(); + } +}; + +// +// Value decoder for booleans +// + +struct BooleanValueDecoder : public ValueDecoder { + using value_type = bool; + + using ValueDecoder::ValueDecoder; + + Status Initialize() { + // TODO no need to build separate Tries for each instance + RETURN_NOT_OK(InitializeTrie(options_.true_values, &true_trie_)); + RETURN_NOT_OK(InitializeTrie(options_.false_values, &false_trie_)); + return ValueDecoder::Initialize(); + } + + Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { + // XXX should quoted values be allowed at all? + if (false_trie_.Find(util::string_view(reinterpret_cast<const char*>(data), size)) >= + 0) { + *out = false; + return Status::OK(); + } + if (ARROW_PREDICT_TRUE(true_trie_.Find(util::string_view( + reinterpret_cast<const char*>(data), size)) >= 0)) { + *out = true; + return Status::OK(); + } + return GenericConversionError(type_, data, size); + } + + protected: + Trie true_trie_; + Trie false_trie_; +}; + +// +// Value decoder for decimals +// + +struct DecimalValueDecoder : public ValueDecoder { + using value_type = Decimal128; + + explicit DecimalValueDecoder(const std::shared_ptr<DataType>& type, + const ConvertOptions& options) + : ValueDecoder(type, options), + decimal_type_(internal::checked_cast<const DecimalType&>(*type_)), + type_precision_(decimal_type_.precision()), + type_scale_(decimal_type_.scale()) {} + + Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { + TrimWhiteSpace(&data, &size); + Decimal128 decimal; + int32_t precision, scale; + util::string_view view(reinterpret_cast<const char*>(data), size); + RETURN_NOT_OK(Decimal128::FromString(view, &decimal, &precision, &scale)); + if (precision > type_precision_) { + return Status::Invalid("Error converting '", view, "' to ", type_->ToString(), + ": precision not supported by type."); + } + if (scale != type_scale_) { + ARROW_ASSIGN_OR_RAISE(*out, decimal.Rescale(scale, type_scale_)); + } else { + *out = std::move(decimal); + } + return Status::OK(); + } + + protected: + const DecimalType& decimal_type_; + const int32_t type_precision_; + const int32_t type_scale_; +}; + +// +// Value decoders for timestamps +// + +struct InlineISO8601ValueDecoder : public ValueDecoder { + using value_type = int64_t; + + explicit InlineISO8601ValueDecoder(const std::shared_ptr<DataType>& type, + const ConvertOptions& options) + : ValueDecoder(type, options), + unit_(checked_cast<const TimestampType&>(*type_).unit()) {} + + Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { + if (ARROW_PREDICT_FALSE(!internal::ParseTimestampISO8601( + reinterpret_cast<const char*>(data), size, unit_, out))) { + return GenericConversionError(type_, data, size); + } + return Status::OK(); + } + + protected: + TimeUnit::type unit_; +}; + +struct SingleParserTimestampValueDecoder : public ValueDecoder { + using value_type = int64_t; + + explicit SingleParserTimestampValueDecoder(const std::shared_ptr<DataType>& type, + const ConvertOptions& options) + : ValueDecoder(type, options), + unit_(checked_cast<const TimestampType&>(*type_).unit()), + parser_(*options_.timestamp_parsers[0]) {} + + Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { + if (ARROW_PREDICT_FALSE( + !parser_(reinterpret_cast<const char*>(data), size, unit_, out))) { + return GenericConversionError(type_, data, size); + } + return Status::OK(); + } + + protected: + TimeUnit::type unit_; + const TimestampParser& parser_; +}; + +struct MultipleParsersTimestampValueDecoder : public ValueDecoder { + using value_type = int64_t; + + explicit MultipleParsersTimestampValueDecoder(const std::shared_ptr<DataType>& type, + const ConvertOptions& options) + : ValueDecoder(type, options), + unit_(checked_cast<const TimestampType&>(*type_).unit()), + parsers_(GetParsers(options_)) {} + + Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { + for (const auto& parser : parsers_) { + if (parser->operator()(reinterpret_cast<const char*>(data), size, unit_, out)) { + return Status::OK(); + } + } + return GenericConversionError(type_, data, size); + } + + protected: + using ParserVector = std::vector<const TimestampParser*>; + + static ParserVector GetParsers(const ConvertOptions& options) { + ParserVector parsers(options.timestamp_parsers.size()); + for (size_t i = 0; i < options.timestamp_parsers.size(); ++i) { + parsers[i] = options.timestamp_parsers[i].get(); + } + return parsers; + } + + TimeUnit::type unit_; + std::vector<const TimestampParser*> parsers_; +}; + +///////////////////////////////////////////////////////////////////////// +// Concrete Converter hierarchy + +class ConcreteConverter : public Converter { + public: + using Converter::Converter; +}; + +class ConcreteDictionaryConverter : public DictionaryConverter { + public: + using DictionaryConverter::DictionaryConverter; +}; + +// +// Concrete Converter for nulls +// + +class NullConverter : public ConcreteConverter { + public: + NullConverter(const std::shared_ptr<DataType>& type, const ConvertOptions& options, + MemoryPool* pool) + : ConcreteConverter(type, options, pool), decoder_(type_, options_) {} + + Result<std::shared_ptr<Array>> Convert(const BlockParser& parser, + int32_t col_index) override { + NullBuilder builder(pool_); + + auto visit = [&](const uint8_t* data, uint32_t size, bool quoted) -> Status { + if (ARROW_PREDICT_TRUE(decoder_.IsNull(data, size, quoted))) { + return builder.AppendNull(); + } else { + return GenericConversionError(type_, data, size); + } + }; + RETURN_NOT_OK(parser.VisitColumn(col_index, visit)); + std::shared_ptr<Array> res; + RETURN_NOT_OK(builder.Finish(&res)); + return res; + } + + protected: + Status Initialize() override { return decoder_.Initialize(); } + + ValueDecoder decoder_; +}; + +// +// Concrete Converter for primitives +// + +template <typename T, typename ValueDecoderType> +class PrimitiveConverter : public ConcreteConverter { + public: + PrimitiveConverter(const std::shared_ptr<DataType>& type, const ConvertOptions& options, + MemoryPool* pool) + : ConcreteConverter(type, options, pool), decoder_(type_, options_) {} + + Result<std::shared_ptr<Array>> Convert(const BlockParser& parser, + int32_t col_index) override { + using BuilderType = typename TypeTraits<T>::BuilderType; + using value_type = typename ValueDecoderType::value_type; + + BuilderType builder(type_, pool_); + RETURN_NOT_OK(PresizeBuilder(parser, &builder)); + + auto visit = [&](const uint8_t* data, uint32_t size, bool quoted) -> Status { + if (decoder_.IsNull(data, size, quoted /* quoted */)) { + return builder.AppendNull(); + } + value_type value{}; + RETURN_NOT_OK(decoder_.Decode(data, size, quoted, &value)); + builder.UnsafeAppend(value); + return Status::OK(); + }; + RETURN_NOT_OK(parser.VisitColumn(col_index, visit)); + + std::shared_ptr<Array> res; + RETURN_NOT_OK(builder.Finish(&res)); + return res; + } + + protected: + Status Initialize() override { return decoder_.Initialize(); } + + ValueDecoderType decoder_; +}; + +// +// Concrete Converter for dictionaries +// + +template <typename T, typename ValueDecoderType> +class TypedDictionaryConverter : public ConcreteDictionaryConverter { + public: + TypedDictionaryConverter(const std::shared_ptr<DataType>& value_type, + const ConvertOptions& options, MemoryPool* pool) + : ConcreteDictionaryConverter(value_type, options, pool), + decoder_(value_type, options_) {} + + Result<std::shared_ptr<Array>> Convert(const BlockParser& parser, + int32_t col_index) override { + // We use a fixed index width so that all column chunks get the same index type + using BuilderType = Dictionary32Builder<T>; + using value_type = typename ValueDecoderType::value_type; + + BuilderType builder(value_type_, pool_); + RETURN_NOT_OK(PresizeBuilder(parser, &builder)); + + auto visit = [&](const uint8_t* data, uint32_t size, bool quoted) -> Status { + if (decoder_.IsNull(data, size, quoted /* quoted */)) { + return builder.AppendNull(); + } + if (ARROW_PREDICT_FALSE(builder.dictionary_length() > max_cardinality_)) { + return Status::IndexError("Dictionary length exceeded max cardinality"); + } + value_type value{}; + RETURN_NOT_OK(decoder_.Decode(data, size, quoted, &value)); + return builder.Append(value); + }; + RETURN_NOT_OK(parser.VisitColumn(col_index, visit)); + + std::shared_ptr<Array> res; + RETURN_NOT_OK(builder.Finish(&res)); + return res; + } + + void SetMaxCardinality(int32_t max_length) override { max_cardinality_ = max_length; } + + protected: + Status Initialize() override { + util::InitializeUTF8(); + return decoder_.Initialize(); + } + + ValueDecoderType decoder_; + int32_t max_cardinality_ = std::numeric_limits<int32_t>::max(); +}; + +// +// Concrete Converter factory for timestamps +// + +template <template <typename, typename> class ConverterType> +std::shared_ptr<Converter> MakeTimestampConverter(const std::shared_ptr<DataType>& type, + const ConvertOptions& options, + MemoryPool* pool) { + if (options.timestamp_parsers.size() == 0) { + // Default to ISO-8601 + return std::make_shared<ConverterType<TimestampType, InlineISO8601ValueDecoder>>( + type, options, pool); + } else if (options.timestamp_parsers.size() == 1) { + // Single user-supplied converter + return std::make_shared< + ConverterType<TimestampType, SingleParserTimestampValueDecoder>>(type, options, + pool); + } else { + // Multiple converters, must iterate for each value + return std::make_shared< + ConverterType<TimestampType, MultipleParsersTimestampValueDecoder>>(type, options, + pool); + } +} + +} // namespace + +///////////////////////////////////////////////////////////////////////// +// Base Converter class implementation + +Converter::Converter(const std::shared_ptr<DataType>& type, const ConvertOptions& options, + MemoryPool* pool) + : options_(options), pool_(pool), type_(type) {} + +DictionaryConverter::DictionaryConverter(const std::shared_ptr<DataType>& value_type, + const ConvertOptions& options, MemoryPool* pool) + : Converter(dictionary(int32(), value_type), options, pool), + value_type_(value_type) {} + +Result<std::shared_ptr<Converter>> Converter::Make(const std::shared_ptr<DataType>& type, + const ConvertOptions& options, + MemoryPool* pool) { + std::shared_ptr<Converter> ptr; + + switch (type->id()) { +#define CONVERTER_CASE(TYPE_ID, CONVERTER_TYPE) \ + case TYPE_ID: \ + ptr.reset(new CONVERTER_TYPE(type, options, pool)); \ + break; + +#define NUMERIC_CONVERTER_CASE(TYPE_ID, TYPE_CLASS) \ + CONVERTER_CASE(TYPE_ID, \ + (PrimitiveConverter<TYPE_CLASS, NumericValueDecoder<TYPE_CLASS>>)) + + CONVERTER_CASE(Type::NA, NullConverter) + NUMERIC_CONVERTER_CASE(Type::INT8, Int8Type) + NUMERIC_CONVERTER_CASE(Type::INT16, Int16Type) + NUMERIC_CONVERTER_CASE(Type::INT32, Int32Type) + NUMERIC_CONVERTER_CASE(Type::INT64, Int64Type) + NUMERIC_CONVERTER_CASE(Type::UINT8, UInt8Type) + NUMERIC_CONVERTER_CASE(Type::UINT16, UInt16Type) + NUMERIC_CONVERTER_CASE(Type::UINT32, UInt32Type) + NUMERIC_CONVERTER_CASE(Type::UINT64, UInt64Type) + NUMERIC_CONVERTER_CASE(Type::FLOAT, FloatType) + NUMERIC_CONVERTER_CASE(Type::DOUBLE, DoubleType) + NUMERIC_CONVERTER_CASE(Type::DATE32, Date32Type) + NUMERIC_CONVERTER_CASE(Type::DATE64, Date64Type) + CONVERTER_CASE(Type::BOOL, (PrimitiveConverter<BooleanType, BooleanValueDecoder>)) + CONVERTER_CASE(Type::BINARY, + (PrimitiveConverter<BinaryType, BinaryValueDecoder<false>>)) + CONVERTER_CASE(Type::LARGE_BINARY, + (PrimitiveConverter<LargeBinaryType, BinaryValueDecoder<false>>)) + CONVERTER_CASE(Type::FIXED_SIZE_BINARY, + (PrimitiveConverter<FixedSizeBinaryType, FixedSizeBinaryValueDecoder>)) + CONVERTER_CASE(Type::DECIMAL, + (PrimitiveConverter<Decimal128Type, DecimalValueDecoder>)) + + case Type::TIMESTAMP: + ptr = MakeTimestampConverter<PrimitiveConverter>(type, options, pool); + break; + + case Type::STRING: + if (options.check_utf8) { + ptr = std::make_shared<PrimitiveConverter<StringType, BinaryValueDecoder<true>>>( + type, options, pool); + } else { + ptr = std::make_shared<PrimitiveConverter<StringType, BinaryValueDecoder<false>>>( + type, options, pool); + } + break; + + case Type::LARGE_STRING: + if (options.check_utf8) { + ptr = std::make_shared< + PrimitiveConverter<LargeStringType, BinaryValueDecoder<true>>>(type, options, + pool); + } else { + ptr = std::make_shared< + PrimitiveConverter<LargeStringType, BinaryValueDecoder<false>>>(type, options, + pool); + } + break; + + case Type::DICTIONARY: { + const auto& dict_type = checked_cast<const DictionaryType&>(*type); + if (dict_type.index_type()->id() != Type::INT32) { + return Status::NotImplemented( + "CSV conversion to dictionary only supported for int32 indices, " + "got ", + type->ToString()); + } + return DictionaryConverter::Make(dict_type.value_type(), options, pool); + } + + default: { + return Status::NotImplemented("CSV conversion to ", type->ToString(), + " is not supported"); + } + +#undef CONVERTER_CASE +#undef NUMERIC_CONVERTER_CASE + } + RETURN_NOT_OK(ptr->Initialize()); + return ptr; +} + +Result<std::shared_ptr<DictionaryConverter>> DictionaryConverter::Make( + const std::shared_ptr<DataType>& type, const ConvertOptions& options, + MemoryPool* pool) { + std::shared_ptr<DictionaryConverter> ptr; + + switch (type->id()) { +#define CONVERTER_CASE(TYPE_ID, TYPE, VALUE_DECODER_TYPE) \ + case TYPE_ID: \ + ptr.reset( \ + new TypedDictionaryConverter<TYPE, VALUE_DECODER_TYPE>(type, options, pool)); \ + break; + + // XXX Are 32-bit types useful? + CONVERTER_CASE(Type::INT32, Int32Type, NumericValueDecoder<Int32Type>) + CONVERTER_CASE(Type::INT64, Int64Type, NumericValueDecoder<Int64Type>) + CONVERTER_CASE(Type::UINT32, UInt32Type, NumericValueDecoder<UInt32Type>) + CONVERTER_CASE(Type::UINT64, UInt64Type, NumericValueDecoder<UInt64Type>) + CONVERTER_CASE(Type::FLOAT, FloatType, NumericValueDecoder<FloatType>) + CONVERTER_CASE(Type::DOUBLE, DoubleType, NumericValueDecoder<DoubleType>) + CONVERTER_CASE(Type::DECIMAL, Decimal128Type, DecimalValueDecoder) + CONVERTER_CASE(Type::FIXED_SIZE_BINARY, FixedSizeBinaryType, + FixedSizeBinaryValueDecoder) + CONVERTER_CASE(Type::BINARY, BinaryType, BinaryValueDecoder<false>) + CONVERTER_CASE(Type::LARGE_BINARY, LargeBinaryType, BinaryValueDecoder<false>) + + case Type::STRING: + if (options.check_utf8) { + ptr = std::make_shared< + TypedDictionaryConverter<StringType, BinaryValueDecoder<true>>>(type, options, + pool); + } else { + ptr = std::make_shared< + TypedDictionaryConverter<StringType, BinaryValueDecoder<false>>>( + type, options, pool); + } + break; + + case Type::LARGE_STRING: + if (options.check_utf8) { + ptr = std::make_shared< + TypedDictionaryConverter<LargeStringType, BinaryValueDecoder<true>>>( + type, options, pool); + } else { + ptr = std::make_shared< + TypedDictionaryConverter<LargeStringType, BinaryValueDecoder<false>>>( + type, options, pool); + } + break; + + default: { + return Status::NotImplemented("CSV dictionary conversion to ", type->ToString(), + " is not supported"); + } + +#undef CONVERTER_CASE + } + RETURN_NOT_OK(ptr->Initialize()); + return ptr; +} + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/converter.h b/contrib/libs/apache/arrow/cpp/src/arrow/csv/converter.h index 639f692f26..3bf1cb898f 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/converter.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/converter.h @@ -1,82 +1,82 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <cstdint> -#include <memory> - -#include "arrow/csv/options.h" -#include "arrow/result.h" -#include "arrow/type_fwd.h" -#include "arrow/util/macros.h" -#include "arrow/util/visibility.h" - -namespace arrow { -namespace csv { - -class BlockParser; - -class ARROW_EXPORT Converter { - public: - Converter(const std::shared_ptr<DataType>& type, const ConvertOptions& options, - MemoryPool* pool); - virtual ~Converter() = default; - - virtual Result<std::shared_ptr<Array>> Convert(const BlockParser& parser, - int32_t col_index) = 0; - - std::shared_ptr<DataType> type() const { return type_; } - - // Create a Converter for the given data type - static Result<std::shared_ptr<Converter>> Make( - const std::shared_ptr<DataType>& type, const ConvertOptions& options, - MemoryPool* pool = default_memory_pool()); - - protected: - ARROW_DISALLOW_COPY_AND_ASSIGN(Converter); - - virtual Status Initialize() = 0; - - // CAUTION: ConvertOptions can grow large (if it customizes hundreds or - // thousands of columns), so avoid copying it in each Converter. - const ConvertOptions& options_; - MemoryPool* pool_; - std::shared_ptr<DataType> type_; -}; - -class ARROW_EXPORT DictionaryConverter : public Converter { - public: - DictionaryConverter(const std::shared_ptr<DataType>& value_type, - const ConvertOptions& options, MemoryPool* pool); - - // If the dictionary length goes above this value, conversion will fail - // with Status::IndexError. - virtual void SetMaxCardinality(int32_t max_length) = 0; - - // Create a Converter for the given dictionary value type. - // The dictionary index type will always be Int32. - static Result<std::shared_ptr<DictionaryConverter>> Make( - const std::shared_ptr<DataType>& value_type, const ConvertOptions& options, - MemoryPool* pool = default_memory_pool()); - - protected: - std::shared_ptr<DataType> value_type_; -}; - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstdint> +#include <memory> + +#include "arrow/csv/options.h" +#include "arrow/result.h" +#include "arrow/type_fwd.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace csv { + +class BlockParser; + +class ARROW_EXPORT Converter { + public: + Converter(const std::shared_ptr<DataType>& type, const ConvertOptions& options, + MemoryPool* pool); + virtual ~Converter() = default; + + virtual Result<std::shared_ptr<Array>> Convert(const BlockParser& parser, + int32_t col_index) = 0; + + std::shared_ptr<DataType> type() const { return type_; } + + // Create a Converter for the given data type + static Result<std::shared_ptr<Converter>> Make( + const std::shared_ptr<DataType>& type, const ConvertOptions& options, + MemoryPool* pool = default_memory_pool()); + + protected: + ARROW_DISALLOW_COPY_AND_ASSIGN(Converter); + + virtual Status Initialize() = 0; + + // CAUTION: ConvertOptions can grow large (if it customizes hundreds or + // thousands of columns), so avoid copying it in each Converter. + const ConvertOptions& options_; + MemoryPool* pool_; + std::shared_ptr<DataType> type_; +}; + +class ARROW_EXPORT DictionaryConverter : public Converter { + public: + DictionaryConverter(const std::shared_ptr<DataType>& value_type, + const ConvertOptions& options, MemoryPool* pool); + + // If the dictionary length goes above this value, conversion will fail + // with Status::IndexError. + virtual void SetMaxCardinality(int32_t max_length) = 0; + + // Create a Converter for the given dictionary value type. + // The dictionary index type will always be Int32. + static Result<std::shared_ptr<DictionaryConverter>> Make( + const std::shared_ptr<DataType>& value_type, const ConvertOptions& options, + MemoryPool* pool = default_memory_pool()); + + protected: + std::shared_ptr<DataType> value_type_; +}; + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/inference_internal.h b/contrib/libs/apache/arrow/cpp/src/arrow/csv/inference_internal.h index 42486a1eba..9549a55bea 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/inference_internal.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/inference_internal.h @@ -1,150 +1,150 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <memory> - -#include "arrow/csv/converter.h" -#include "arrow/csv/options.h" -#include "arrow/util/logging.h" - -namespace arrow { -namespace csv { - -enum class InferKind { - Null, - Integer, - Boolean, - Real, - Date, - Timestamp, - TimestampNS, - TextDict, - BinaryDict, - Text, - Binary -}; - -class InferStatus { - public: - explicit InferStatus(const ConvertOptions& options) - : kind_(InferKind::Null), can_loosen_type_(true), options_(options) {} - - InferKind kind() const { return kind_; } - - bool can_loosen_type() const { return can_loosen_type_; } - - void LoosenType(const Status& conversion_error) { - DCHECK(can_loosen_type_); - - switch (kind_) { - case InferKind::Null: - return SetKind(InferKind::Integer); - case InferKind::Integer: - return SetKind(InferKind::Boolean); - case InferKind::Boolean: - return SetKind(InferKind::Date); - case InferKind::Date: - return SetKind(InferKind::Timestamp); - case InferKind::Timestamp: - return SetKind(InferKind::TimestampNS); - case InferKind::TimestampNS: - return SetKind(InferKind::Real); - case InferKind::Real: - if (options_.auto_dict_encode) { - return SetKind(InferKind::TextDict); - } else { - return SetKind(InferKind::Text); - } - case InferKind::TextDict: - if (conversion_error.IsIndexError()) { - // Cardinality too large, fall back to non-dict encoding - return SetKind(InferKind::Text); - } else { - // Assuming UTF8 validation failure - return SetKind(InferKind::BinaryDict); - } - break; - case InferKind::BinaryDict: - // Assuming cardinality too large - return SetKind(InferKind::Binary); - case InferKind::Text: - // Assuming UTF8 validation failure - return SetKind(InferKind::Binary); - default: - ARROW_LOG(FATAL) << "Shouldn't come here"; - } - } - - Result<std::shared_ptr<Converter>> MakeConverter(MemoryPool* pool) { - auto make_converter = - [&](std::shared_ptr<DataType> type) -> Result<std::shared_ptr<Converter>> { - return Converter::Make(type, options_, pool); - }; - - auto make_dict_converter = - [&](std::shared_ptr<DataType> type) -> Result<std::shared_ptr<Converter>> { - ARROW_ASSIGN_OR_RAISE(auto dict_converter, - DictionaryConverter::Make(type, options_, pool)); - dict_converter->SetMaxCardinality(options_.auto_dict_max_cardinality); - return dict_converter; - }; - - switch (kind_) { - case InferKind::Null: - return make_converter(null()); - case InferKind::Integer: - return make_converter(int64()); - case InferKind::Boolean: - return make_converter(boolean()); - case InferKind::Date: - return make_converter(date32()); - case InferKind::Timestamp: - return make_converter(timestamp(TimeUnit::SECOND)); - case InferKind::TimestampNS: - return make_converter(timestamp(TimeUnit::NANO)); - case InferKind::Real: - return make_converter(float64()); - case InferKind::Text: - return make_converter(utf8()); - case InferKind::Binary: - return make_converter(binary()); - case InferKind::TextDict: - return make_dict_converter(utf8()); - case InferKind::BinaryDict: - return make_dict_converter(binary()); - } - return Status::UnknownError("Shouldn't come here"); - } - - protected: - void SetKind(InferKind kind) { - kind_ = kind; - if (kind == InferKind::Binary) { - // Binary is the catch-all type - can_loosen_type_ = false; - } - } - - InferKind kind_; - bool can_loosen_type_; - const ConvertOptions& options_; -}; - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> + +#include "arrow/csv/converter.h" +#include "arrow/csv/options.h" +#include "arrow/util/logging.h" + +namespace arrow { +namespace csv { + +enum class InferKind { + Null, + Integer, + Boolean, + Real, + Date, + Timestamp, + TimestampNS, + TextDict, + BinaryDict, + Text, + Binary +}; + +class InferStatus { + public: + explicit InferStatus(const ConvertOptions& options) + : kind_(InferKind::Null), can_loosen_type_(true), options_(options) {} + + InferKind kind() const { return kind_; } + + bool can_loosen_type() const { return can_loosen_type_; } + + void LoosenType(const Status& conversion_error) { + DCHECK(can_loosen_type_); + + switch (kind_) { + case InferKind::Null: + return SetKind(InferKind::Integer); + case InferKind::Integer: + return SetKind(InferKind::Boolean); + case InferKind::Boolean: + return SetKind(InferKind::Date); + case InferKind::Date: + return SetKind(InferKind::Timestamp); + case InferKind::Timestamp: + return SetKind(InferKind::TimestampNS); + case InferKind::TimestampNS: + return SetKind(InferKind::Real); + case InferKind::Real: + if (options_.auto_dict_encode) { + return SetKind(InferKind::TextDict); + } else { + return SetKind(InferKind::Text); + } + case InferKind::TextDict: + if (conversion_error.IsIndexError()) { + // Cardinality too large, fall back to non-dict encoding + return SetKind(InferKind::Text); + } else { + // Assuming UTF8 validation failure + return SetKind(InferKind::BinaryDict); + } + break; + case InferKind::BinaryDict: + // Assuming cardinality too large + return SetKind(InferKind::Binary); + case InferKind::Text: + // Assuming UTF8 validation failure + return SetKind(InferKind::Binary); + default: + ARROW_LOG(FATAL) << "Shouldn't come here"; + } + } + + Result<std::shared_ptr<Converter>> MakeConverter(MemoryPool* pool) { + auto make_converter = + [&](std::shared_ptr<DataType> type) -> Result<std::shared_ptr<Converter>> { + return Converter::Make(type, options_, pool); + }; + + auto make_dict_converter = + [&](std::shared_ptr<DataType> type) -> Result<std::shared_ptr<Converter>> { + ARROW_ASSIGN_OR_RAISE(auto dict_converter, + DictionaryConverter::Make(type, options_, pool)); + dict_converter->SetMaxCardinality(options_.auto_dict_max_cardinality); + return dict_converter; + }; + + switch (kind_) { + case InferKind::Null: + return make_converter(null()); + case InferKind::Integer: + return make_converter(int64()); + case InferKind::Boolean: + return make_converter(boolean()); + case InferKind::Date: + return make_converter(date32()); + case InferKind::Timestamp: + return make_converter(timestamp(TimeUnit::SECOND)); + case InferKind::TimestampNS: + return make_converter(timestamp(TimeUnit::NANO)); + case InferKind::Real: + return make_converter(float64()); + case InferKind::Text: + return make_converter(utf8()); + case InferKind::Binary: + return make_converter(binary()); + case InferKind::TextDict: + return make_dict_converter(utf8()); + case InferKind::BinaryDict: + return make_dict_converter(binary()); + } + return Status::UnknownError("Shouldn't come here"); + } + + protected: + void SetKind(InferKind kind) { + kind_ = kind; + if (kind == InferKind::Binary) { + // Binary is the catch-all type + can_loosen_type_ = false; + } + } + + InferKind kind_; + bool can_loosen_type_; + const ConvertOptions& options_; +}; + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/options.cc b/contrib/libs/apache/arrow/cpp/src/arrow/csv/options.cc index c71cfdaf29..f15fada47e 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/options.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/options.cc @@ -1,83 +1,83 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/csv/options.h" - -namespace arrow { -namespace csv { - -ParseOptions ParseOptions::Defaults() { return ParseOptions(); } - -Status ParseOptions::Validate() const { - if (ARROW_PREDICT_FALSE(delimiter == '\n' || delimiter == '\r')) { - return Status::Invalid("ParseOptions: delimiter cannot be \\r or \\n"); - } - if (ARROW_PREDICT_FALSE(quoting && (quote_char == '\n' || quote_char == '\r'))) { - return Status::Invalid("ParseOptions: quote_char cannot be \\r or \\n"); - } - if (ARROW_PREDICT_FALSE(escaping && (escape_char == '\n' || escape_char == '\r'))) { - return Status::Invalid("ParseOptions: escape_char cannot be \\r or \\n"); - } - return Status::OK(); -} - -ConvertOptions ConvertOptions::Defaults() { - auto options = ConvertOptions(); - // Same default null / true / false spellings as in Pandas. - options.null_values = {"", "#N/A", "#N/A N/A", "#NA", "-1.#IND", "-1.#QNAN", - "-NaN", "-nan", "1.#IND", "1.#QNAN", "N/A", "NA", - "NULL", "NaN", "n/a", "nan", "null"}; - options.true_values = {"1", "True", "TRUE", "true"}; - options.false_values = {"0", "False", "FALSE", "false"}; - return options; -} - -Status ConvertOptions::Validate() const { return Status::OK(); } - -ReadOptions ReadOptions::Defaults() { return ReadOptions(); } - -Status ReadOptions::Validate() const { - if (ARROW_PREDICT_FALSE(block_size < 1)) { - // Min is 1 because some tests use really small block sizes - return Status::Invalid("ReadOptions: block_size must be at least 1: ", block_size); - } - if (ARROW_PREDICT_FALSE(skip_rows < 0)) { - return Status::Invalid("ReadOptions: skip_rows cannot be negative: ", skip_rows); - } - if (ARROW_PREDICT_FALSE(skip_rows_after_names < 0)) { - return Status::Invalid("ReadOptions: skip_rows_after_names cannot be negative: ", - skip_rows_after_names); - } - if (ARROW_PREDICT_FALSE(autogenerate_column_names && !column_names.empty())) { - return Status::Invalid( - "ReadOptions: autogenerate_column_names cannot be true when column_names are " - "provided"); - } - return Status::OK(); -} - -WriteOptions WriteOptions::Defaults() { return WriteOptions(); } - -Status WriteOptions::Validate() const { - if (ARROW_PREDICT_FALSE(batch_size < 1)) { - return Status::Invalid("WriteOptions: batch_size must be at least 1: ", batch_size); - } - return Status::OK(); -} - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/csv/options.h" + +namespace arrow { +namespace csv { + +ParseOptions ParseOptions::Defaults() { return ParseOptions(); } + +Status ParseOptions::Validate() const { + if (ARROW_PREDICT_FALSE(delimiter == '\n' || delimiter == '\r')) { + return Status::Invalid("ParseOptions: delimiter cannot be \\r or \\n"); + } + if (ARROW_PREDICT_FALSE(quoting && (quote_char == '\n' || quote_char == '\r'))) { + return Status::Invalid("ParseOptions: quote_char cannot be \\r or \\n"); + } + if (ARROW_PREDICT_FALSE(escaping && (escape_char == '\n' || escape_char == '\r'))) { + return Status::Invalid("ParseOptions: escape_char cannot be \\r or \\n"); + } + return Status::OK(); +} + +ConvertOptions ConvertOptions::Defaults() { + auto options = ConvertOptions(); + // Same default null / true / false spellings as in Pandas. + options.null_values = {"", "#N/A", "#N/A N/A", "#NA", "-1.#IND", "-1.#QNAN", + "-NaN", "-nan", "1.#IND", "1.#QNAN", "N/A", "NA", + "NULL", "NaN", "n/a", "nan", "null"}; + options.true_values = {"1", "True", "TRUE", "true"}; + options.false_values = {"0", "False", "FALSE", "false"}; + return options; +} + +Status ConvertOptions::Validate() const { return Status::OK(); } + +ReadOptions ReadOptions::Defaults() { return ReadOptions(); } + +Status ReadOptions::Validate() const { + if (ARROW_PREDICT_FALSE(block_size < 1)) { + // Min is 1 because some tests use really small block sizes + return Status::Invalid("ReadOptions: block_size must be at least 1: ", block_size); + } + if (ARROW_PREDICT_FALSE(skip_rows < 0)) { + return Status::Invalid("ReadOptions: skip_rows cannot be negative: ", skip_rows); + } + if (ARROW_PREDICT_FALSE(skip_rows_after_names < 0)) { + return Status::Invalid("ReadOptions: skip_rows_after_names cannot be negative: ", + skip_rows_after_names); + } + if (ARROW_PREDICT_FALSE(autogenerate_column_names && !column_names.empty())) { + return Status::Invalid( + "ReadOptions: autogenerate_column_names cannot be true when column_names are " + "provided"); + } + return Status::OK(); +} + +WriteOptions WriteOptions::Defaults() { return WriteOptions(); } + +Status WriteOptions::Validate() const { + if (ARROW_PREDICT_FALSE(batch_size < 1)) { + return Status::Invalid("WriteOptions: batch_size must be at least 1: ", batch_size); + } + return Status::OK(); +} + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/options.h b/contrib/libs/apache/arrow/cpp/src/arrow/csv/options.h index 5face6f32d..9e6f704af9 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/options.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/options.h @@ -1,189 +1,189 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <cstdint> -#include <memory> -#include <string> -#include <unordered_map> -#include <vector> - -#include "arrow/csv/type_fwd.h" -#include "arrow/io/interfaces.h" -#include "arrow/status.h" -#include "arrow/util/visibility.h" - -namespace arrow { - -class DataType; -class TimestampParser; - -namespace csv { - -// Silly workaround for https://github.com/michaeljones/breathe/issues/453 -constexpr char kDefaultEscapeChar = '\\'; - -struct ARROW_EXPORT ParseOptions { - // Parsing options - - /// Field delimiter - char delimiter = ','; - /// Whether quoting is used - bool quoting = true; - /// Quoting character (if `quoting` is true) - char quote_char = '"'; - /// Whether a quote inside a value is double-quoted - bool double_quote = true; - /// Whether escaping is used - bool escaping = false; - /// Escaping character (if `escaping` is true) - char escape_char = kDefaultEscapeChar; - /// Whether values are allowed to contain CR (0x0d) and LF (0x0a) characters - bool newlines_in_values = false; - /// Whether empty lines are ignored. If false, an empty line represents - /// a single empty value (assuming a one-column CSV file). - bool ignore_empty_lines = true; - - /// Create parsing options with default values - static ParseOptions Defaults(); - - /// \brief Test that all set options are valid - Status Validate() const; -}; - -struct ARROW_EXPORT ConvertOptions { - // Conversion options - - /// Whether to check UTF8 validity of string columns - bool check_utf8 = true; - /// Optional per-column types (disabling type inference on those columns) - std::unordered_map<std::string, std::shared_ptr<DataType>> column_types; - /// Recognized spellings for null values - std::vector<std::string> null_values; - /// Recognized spellings for boolean true values - std::vector<std::string> true_values; - /// Recognized spellings for boolean false values - std::vector<std::string> false_values; - - /// Whether string / binary columns can have null values. - /// - /// If true, then strings in "null_values" are considered null for string columns. - /// If false, then all strings are valid string values. - bool strings_can_be_null = false; - /// Whether string / binary columns can have quoted null values. - /// - /// If true *and* `strings_can_be_null` is true, then quoted strings in - /// "null_values" are also considered null for string columns. Otherwise, - /// quoted strings are never considered null. - bool quoted_strings_can_be_null = true; - - /// Whether to try to automatically dict-encode string / binary data. - /// If true, then when type inference detects a string or binary column, - /// it is dict-encoded up to `auto_dict_max_cardinality` distinct values - /// (per chunk), after which it switches to regular encoding. - /// - /// This setting is ignored for non-inferred columns (those in `column_types`). - bool auto_dict_encode = false; - int32_t auto_dict_max_cardinality = 50; - - // XXX Should we have a separate FilterOptions? - - /// If non-empty, indicates the names of columns from the CSV file that should - /// be actually read and converted (in the vector's order). - /// Columns not in this vector will be ignored. - std::vector<std::string> include_columns; - /// If false, columns in `include_columns` but not in the CSV file will error out. - /// If true, columns in `include_columns` but not in the CSV file will produce - /// a column of nulls (whose type is selected using `column_types`, - /// or null by default) - /// This option is ignored if `include_columns` is empty. - bool include_missing_columns = false; - - /// User-defined timestamp parsers, using the virtual parser interface in - /// arrow/util/value_parsing.h. More than one parser can be specified, and - /// the CSV conversion logic will try parsing values starting from the - /// beginning of this vector. If no parsers are specified, we use the default - /// built-in ISO-8601 parser. - std::vector<std::shared_ptr<TimestampParser>> timestamp_parsers; - - /// Create conversion options with default values, including conventional - /// values for `null_values`, `true_values` and `false_values` - static ConvertOptions Defaults(); - - /// \brief Test that all set options are valid - Status Validate() const; -}; - -struct ARROW_EXPORT ReadOptions { - // Reader options - - /// Whether to use the global CPU thread pool - bool use_threads = true; - - /// \brief Block size we request from the IO layer. - /// - /// This will determine multi-threading granularity as well as - /// the size of individual record batches. - /// Minimum valid value for block size is 1 - int32_t block_size = 1 << 20; // 1 MB - - /// Number of header rows to skip (not including the row of column names, if any) - int32_t skip_rows = 0; - - /// Number of rows to skip after the column names are read, if any - int32_t skip_rows_after_names = 0; - - /// Column names for the target table. - /// If empty, fall back on autogenerate_column_names. - std::vector<std::string> column_names; - - /// Whether to autogenerate column names if `column_names` is empty. - /// If true, column names will be of the form "f0", "f1"... - /// If false, column names will be read from the first CSV row after `skip_rows`. - bool autogenerate_column_names = false; - - /// Create read options with default values - static ReadOptions Defaults(); - - /// \brief Test that all set options are valid - Status Validate() const; -}; - -/// Experimental -struct ARROW_EXPORT WriteOptions { - /// Whether to write an initial header line with column names - bool include_header = true; - - /// \brief Maximum number of rows processed at a time - /// - /// The CSV writer converts and writes data in batches of N rows. - /// This number can impact performance. - int32_t batch_size = 1024; - - /// \brief IO context for writing. - io::IOContext io_context; - - /// Create write options with default values - static WriteOptions Defaults(); - - /// \brief Test that all set options are valid - Status Validate() const; -}; - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstdint> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "arrow/csv/type_fwd.h" +#include "arrow/io/interfaces.h" +#include "arrow/status.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class DataType; +class TimestampParser; + +namespace csv { + +// Silly workaround for https://github.com/michaeljones/breathe/issues/453 +constexpr char kDefaultEscapeChar = '\\'; + +struct ARROW_EXPORT ParseOptions { + // Parsing options + + /// Field delimiter + char delimiter = ','; + /// Whether quoting is used + bool quoting = true; + /// Quoting character (if `quoting` is true) + char quote_char = '"'; + /// Whether a quote inside a value is double-quoted + bool double_quote = true; + /// Whether escaping is used + bool escaping = false; + /// Escaping character (if `escaping` is true) + char escape_char = kDefaultEscapeChar; + /// Whether values are allowed to contain CR (0x0d) and LF (0x0a) characters + bool newlines_in_values = false; + /// Whether empty lines are ignored. If false, an empty line represents + /// a single empty value (assuming a one-column CSV file). + bool ignore_empty_lines = true; + + /// Create parsing options with default values + static ParseOptions Defaults(); + + /// \brief Test that all set options are valid + Status Validate() const; +}; + +struct ARROW_EXPORT ConvertOptions { + // Conversion options + + /// Whether to check UTF8 validity of string columns + bool check_utf8 = true; + /// Optional per-column types (disabling type inference on those columns) + std::unordered_map<std::string, std::shared_ptr<DataType>> column_types; + /// Recognized spellings for null values + std::vector<std::string> null_values; + /// Recognized spellings for boolean true values + std::vector<std::string> true_values; + /// Recognized spellings for boolean false values + std::vector<std::string> false_values; + + /// Whether string / binary columns can have null values. + /// + /// If true, then strings in "null_values" are considered null for string columns. + /// If false, then all strings are valid string values. + bool strings_can_be_null = false; + /// Whether string / binary columns can have quoted null values. + /// + /// If true *and* `strings_can_be_null` is true, then quoted strings in + /// "null_values" are also considered null for string columns. Otherwise, + /// quoted strings are never considered null. + bool quoted_strings_can_be_null = true; + + /// Whether to try to automatically dict-encode string / binary data. + /// If true, then when type inference detects a string or binary column, + /// it is dict-encoded up to `auto_dict_max_cardinality` distinct values + /// (per chunk), after which it switches to regular encoding. + /// + /// This setting is ignored for non-inferred columns (those in `column_types`). + bool auto_dict_encode = false; + int32_t auto_dict_max_cardinality = 50; + + // XXX Should we have a separate FilterOptions? + + /// If non-empty, indicates the names of columns from the CSV file that should + /// be actually read and converted (in the vector's order). + /// Columns not in this vector will be ignored. + std::vector<std::string> include_columns; + /// If false, columns in `include_columns` but not in the CSV file will error out. + /// If true, columns in `include_columns` but not in the CSV file will produce + /// a column of nulls (whose type is selected using `column_types`, + /// or null by default) + /// This option is ignored if `include_columns` is empty. + bool include_missing_columns = false; + + /// User-defined timestamp parsers, using the virtual parser interface in + /// arrow/util/value_parsing.h. More than one parser can be specified, and + /// the CSV conversion logic will try parsing values starting from the + /// beginning of this vector. If no parsers are specified, we use the default + /// built-in ISO-8601 parser. + std::vector<std::shared_ptr<TimestampParser>> timestamp_parsers; + + /// Create conversion options with default values, including conventional + /// values for `null_values`, `true_values` and `false_values` + static ConvertOptions Defaults(); + + /// \brief Test that all set options are valid + Status Validate() const; +}; + +struct ARROW_EXPORT ReadOptions { + // Reader options + + /// Whether to use the global CPU thread pool + bool use_threads = true; + + /// \brief Block size we request from the IO layer. + /// + /// This will determine multi-threading granularity as well as + /// the size of individual record batches. + /// Minimum valid value for block size is 1 + int32_t block_size = 1 << 20; // 1 MB + + /// Number of header rows to skip (not including the row of column names, if any) + int32_t skip_rows = 0; + + /// Number of rows to skip after the column names are read, if any + int32_t skip_rows_after_names = 0; + + /// Column names for the target table. + /// If empty, fall back on autogenerate_column_names. + std::vector<std::string> column_names; + + /// Whether to autogenerate column names if `column_names` is empty. + /// If true, column names will be of the form "f0", "f1"... + /// If false, column names will be read from the first CSV row after `skip_rows`. + bool autogenerate_column_names = false; + + /// Create read options with default values + static ReadOptions Defaults(); + + /// \brief Test that all set options are valid + Status Validate() const; +}; + +/// Experimental +struct ARROW_EXPORT WriteOptions { + /// Whether to write an initial header line with column names + bool include_header = true; + + /// \brief Maximum number of rows processed at a time + /// + /// The CSV writer converts and writes data in batches of N rows. + /// This number can impact performance. + int32_t batch_size = 1024; + + /// \brief IO context for writing. + io::IOContext io_context; + + /// Create write options with default values + static WriteOptions Defaults(); + + /// \brief Test that all set options are valid + Status Validate() const; +}; + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/parser.cc b/contrib/libs/apache/arrow/cpp/src/arrow/csv/parser.cc index 446f36a4ee..0e1fd91c51 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/parser.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/parser.cc @@ -1,581 +1,581 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/csv/parser.h" - -#include <algorithm> -#include <cstdio> -#include <limits> -#include <utility> - -#include "arrow/memory_pool.h" -#include "arrow/result.h" -#include "arrow/status.h" -#include "arrow/util/logging.h" - -namespace arrow { -namespace csv { - -using detail::DataBatch; -using detail::ParsedValueDesc; - -namespace { - -template <typename... Args> -Status ParseError(Args&&... args) { - return Status::Invalid("CSV parse error: ", std::forward<Args>(args)...); -} - -Status MismatchingColumns(int32_t expected, int32_t actual, int64_t row_num, - util::string_view row) { - std::string ellipse; - if (row.length() > 100) { - row = row.substr(0, 96); - ellipse = " ..."; - } - if (row_num < 0) { - return ParseError("Expected ", expected, " columns, got ", actual, ": ", row, - ellipse); - } - return ParseError("Row #", row_num, ": Expected ", expected, " columns, got ", actual, - ": ", row, ellipse); -} - -inline bool IsControlChar(uint8_t c) { return c < ' '; } - -template <bool Quoting, bool Escaping> -class SpecializedOptions { - public: - static constexpr bool quoting = Quoting; - static constexpr bool escaping = Escaping; -}; - -// A helper class allocating the buffer for parsed values and writing into it -// without any further resizes, except at the end. -class PresizedDataWriter { - public: - PresizedDataWriter(MemoryPool* pool, uint32_t size) - : parsed_size_(0), parsed_capacity_(size) { - parsed_buffer_ = *AllocateResizableBuffer(parsed_capacity_, pool); - parsed_ = parsed_buffer_->mutable_data(); - } - - void Finish(std::shared_ptr<Buffer>* out_parsed) { - ARROW_CHECK_OK(parsed_buffer_->Resize(parsed_size_)); - *out_parsed = parsed_buffer_; - } - - void BeginLine() { saved_parsed_size_ = parsed_size_; } - - void PushFieldChar(char c) { - DCHECK_LT(parsed_size_, parsed_capacity_); - parsed_[parsed_size_++] = static_cast<uint8_t>(c); - } - - // Rollback the state that was saved in BeginLine() - void RollbackLine() { parsed_size_ = saved_parsed_size_; } - - int64_t size() { return parsed_size_; } - - protected: - std::shared_ptr<ResizableBuffer> parsed_buffer_; - uint8_t* parsed_; - int64_t parsed_size_; - int64_t parsed_capacity_; - // Checkpointing, for when an incomplete line is encountered at end of block - int64_t saved_parsed_size_; -}; - -template <typename Derived> -class ValueDescWriter { - public: - Derived* derived() { return static_cast<Derived*>(this); } - - template <typename DataWriter> - void Start(DataWriter& parsed_writer) { - derived()->PushValue( - {static_cast<uint32_t>(parsed_writer.size()) & 0x7fffffffU, false}); - } - - void BeginLine() { saved_values_size_ = values_size_; } - - // Rollback the state that was saved in BeginLine() - void RollbackLine() { values_size_ = saved_values_size_; } - - void StartField(bool quoted) { quoted_ = quoted; } - - template <typename DataWriter> - void FinishField(DataWriter* parsed_writer) { - derived()->PushValue( - {static_cast<uint32_t>(parsed_writer->size()) & 0x7fffffffU, quoted_}); - } - - void Finish(std::shared_ptr<Buffer>* out_values) { - ARROW_CHECK_OK(values_buffer_->Resize(values_size_ * sizeof(*values_))); - *out_values = values_buffer_; - } - - protected: - ValueDescWriter(MemoryPool* pool, int64_t values_capacity) - : values_size_(0), values_capacity_(values_capacity) { - values_buffer_ = *AllocateResizableBuffer(values_capacity_ * sizeof(*values_), pool); - values_ = reinterpret_cast<ParsedValueDesc*>(values_buffer_->mutable_data()); - } - - std::shared_ptr<ResizableBuffer> values_buffer_; - ParsedValueDesc* values_; - int64_t values_size_; - int64_t values_capacity_; - bool quoted_; - // Checkpointing, for when an incomplete line is encountered at end of block - int64_t saved_values_size_; -}; - -// A helper class handling a growable buffer for values offsets. This class is -// used when the number of columns is not yet known and we therefore cannot -// efficiently presize the target area for a given number of rows. -class ResizableValueDescWriter : public ValueDescWriter<ResizableValueDescWriter> { - public: - explicit ResizableValueDescWriter(MemoryPool* pool) - : ValueDescWriter(pool, /*values_capacity=*/256) {} - - void PushValue(ParsedValueDesc v) { - if (ARROW_PREDICT_FALSE(values_size_ == values_capacity_)) { - values_capacity_ = values_capacity_ * 2; - ARROW_CHECK_OK(values_buffer_->Resize(values_capacity_ * sizeof(*values_))); - values_ = reinterpret_cast<ParsedValueDesc*>(values_buffer_->mutable_data()); - } - values_[values_size_++] = v; - } -}; - -// A helper class allocating the buffer for values offsets and writing into it -// without any further resizes, except at the end. This class is used once the -// number of columns is known, as it eliminates resizes and generates simpler, -// faster CSV parsing code. -class PresizedValueDescWriter : public ValueDescWriter<PresizedValueDescWriter> { - public: - PresizedValueDescWriter(MemoryPool* pool, int32_t num_rows, int32_t num_cols) - : ValueDescWriter(pool, /*values_capacity=*/1 + num_rows * num_cols) {} - - void PushValue(ParsedValueDesc v) { - DCHECK_LT(values_size_, values_capacity_); - values_[values_size_++] = v; - } -}; - -} // namespace - -class BlockParserImpl { - public: - BlockParserImpl(MemoryPool* pool, ParseOptions options, int32_t num_cols, - int64_t first_row, int32_t max_num_rows) - : pool_(pool), - options_(options), - first_row_(first_row), - max_num_rows_(max_num_rows), - batch_(num_cols) {} - - const DataBatch& parsed_batch() const { return batch_; } - - int64_t first_row_num() const { return first_row_; } - - template <typename SpecializedOptions, typename ValueDescWriter, typename DataWriter> - Status ParseLine(ValueDescWriter* values_writer, DataWriter* parsed_writer, - const char* data, const char* data_end, bool is_final, - const char** out_data) { - int32_t num_cols = 0; - char c; - const auto start = data; - - DCHECK_GT(data_end, data); - - auto FinishField = [&]() { values_writer->FinishField(parsed_writer); }; - - values_writer->BeginLine(); - parsed_writer->BeginLine(); - - // The parsing state machine - - // Special case empty lines: do we start with a newline separator? - c = *data; - if (ARROW_PREDICT_FALSE(IsControlChar(c))) { - if (c == '\r') { - data++; - if (data < data_end && *data == '\n') { - data++; - } - goto EmptyLine; - } - if (c == '\n') { - data++; - goto EmptyLine; - } - } - - FieldStart: - // At the start of a field - // Quoting is only recognized at start of field - if (SpecializedOptions::quoting && - ARROW_PREDICT_FALSE(*data == options_.quote_char)) { - ++data; - values_writer->StartField(true /* quoted */); - goto InQuotedField; - } else { - values_writer->StartField(false /* quoted */); - goto InField; - } - - InField: - // Inside a non-quoted part of a field - if (ARROW_PREDICT_FALSE(data == data_end)) { - goto AbortLine; - } - c = *data++; - if (SpecializedOptions::escaping && ARROW_PREDICT_FALSE(c == options_.escape_char)) { - if (ARROW_PREDICT_FALSE(data == data_end)) { - goto AbortLine; - } - c = *data++; - parsed_writer->PushFieldChar(c); - goto InField; - } - if (ARROW_PREDICT_FALSE(c == options_.delimiter)) { - goto FieldEnd; - } - if (ARROW_PREDICT_FALSE(IsControlChar(c))) { - if (c == '\r') { - // In the middle of a newline separator? - if (ARROW_PREDICT_TRUE(data < data_end) && *data == '\n') { - data++; - } - goto LineEnd; - } - if (c == '\n') { - goto LineEnd; - } - } - parsed_writer->PushFieldChar(c); - goto InField; - - InQuotedField: - // Inside a quoted part of a field - if (ARROW_PREDICT_FALSE(data == data_end)) { - goto AbortLine; - } - c = *data++; - if (SpecializedOptions::escaping && ARROW_PREDICT_FALSE(c == options_.escape_char)) { - if (ARROW_PREDICT_FALSE(data == data_end)) { - goto AbortLine; - } - c = *data++; - parsed_writer->PushFieldChar(c); - goto InQuotedField; - } - if (ARROW_PREDICT_FALSE(c == options_.quote_char)) { - if (options_.double_quote && ARROW_PREDICT_TRUE(data < data_end) && - ARROW_PREDICT_FALSE(*data == options_.quote_char)) { - // Double-quoting - ++data; - } else { - // End of single-quoting - goto InField; - } - } - parsed_writer->PushFieldChar(c); - goto InQuotedField; - - FieldEnd: - // At the end of a field - FinishField(); - ++num_cols; - if (ARROW_PREDICT_FALSE(data == data_end)) { - goto AbortLine; - } - goto FieldStart; - - LineEnd: - // At the end of line - FinishField(); - ++num_cols; - if (ARROW_PREDICT_FALSE(num_cols != batch_.num_cols_)) { - if (batch_.num_cols_ == -1) { - batch_.num_cols_ = num_cols; - } else { - // Find the end of the line without newline or carriage return - auto end = data; - if (*(end - 1) == '\n') { - --end; - } - if (*(end - 1) == '\r') { - --end; - } - return MismatchingColumns(batch_.num_cols_, num_cols, - first_row_ < 0 ? -1 : first_row_ + batch_.num_rows_, - util::string_view(start, end - start)); - } - } - ++batch_.num_rows_; - *out_data = data; - return Status::OK(); - - AbortLine: - // Not a full line except perhaps if in final block - if (is_final) { - goto LineEnd; - } - // Truncated line at end of block, rewind parsed state - values_writer->RollbackLine(); - parsed_writer->RollbackLine(); - return Status::OK(); - - EmptyLine: - if (!options_.ignore_empty_lines) { - if (batch_.num_cols_ == -1) { - // Consider as single value - batch_.num_cols_ = 1; - } - // Record as row of empty (null?) values - while (num_cols++ < batch_.num_cols_) { - values_writer->StartField(false /* quoted */); - FinishField(); - } - ++batch_.num_rows_; - } - *out_data = data; - return Status::OK(); - } - - template <typename SpecializedOptions, typename ValueDescWriter, typename DataWriter> - Status ParseChunk(ValueDescWriter* values_writer, DataWriter* parsed_writer, - const char* data, const char* data_end, bool is_final, - int32_t rows_in_chunk, const char** out_data, - bool* finished_parsing) { - int32_t num_rows_deadline = batch_.num_rows_ + rows_in_chunk; - - while (data < data_end && batch_.num_rows_ < num_rows_deadline) { - const char* line_end = data; - RETURN_NOT_OK(ParseLine<SpecializedOptions>(values_writer, parsed_writer, data, - data_end, is_final, &line_end)); - if (line_end == data) { - // Cannot parse any further - *finished_parsing = true; - break; - } - data = line_end; - } - // Append new buffers and update size - std::shared_ptr<Buffer> values_buffer; - values_writer->Finish(&values_buffer); - if (values_buffer->size() > 0) { - values_size_ += - static_cast<int32_t>(values_buffer->size() / sizeof(ParsedValueDesc) - 1); - batch_.values_buffers_.push_back(std::move(values_buffer)); - } - *out_data = data; - return Status::OK(); - } - - template <typename SpecializedOptions> - Status ParseSpecialized(const std::vector<util::string_view>& views, bool is_final, - uint32_t* out_size) { - batch_ = DataBatch{batch_.num_cols_}; - values_size_ = 0; - - size_t total_view_length = 0; - for (const auto& view : views) { - total_view_length += view.length(); - } - if (total_view_length > std::numeric_limits<uint32_t>::max()) { - return Status::Invalid("CSV block too large"); - } - - PresizedDataWriter parsed_writer(pool_, static_cast<uint32_t>(total_view_length)); - uint32_t total_parsed_length = 0; - - for (const auto& view : views) { - const char* data = view.data(); - const char* data_end = view.data() + view.length(); - bool finished_parsing = false; - - if (batch_.num_cols_ == -1) { - // Can't presize values when the number of columns is not known, first parse - // a single line - const int32_t rows_in_chunk = 1; - ResizableValueDescWriter values_writer(pool_); - values_writer.Start(parsed_writer); - - RETURN_NOT_OK(ParseChunk<SpecializedOptions>(&values_writer, &parsed_writer, data, - data_end, is_final, rows_in_chunk, - &data, &finished_parsing)); - if (batch_.num_cols_ == -1) { - return ParseError("Empty CSV file or block: cannot infer number of columns"); - } - } - - while (!finished_parsing && data < data_end && batch_.num_rows_ < max_num_rows_) { - // We know the number of columns, so can presize a values array for - // a given number of rows - DCHECK_GE(batch_.num_cols_, 0); - - int32_t rows_in_chunk; - constexpr int32_t kTargetChunkSize = 32768; // in number of values - if (batch_.num_cols_ > 0) { - rows_in_chunk = std::min(std::max(kTargetChunkSize / batch_.num_cols_, 512), - max_num_rows_ - batch_.num_rows_); - } else { - rows_in_chunk = std::min(kTargetChunkSize, max_num_rows_ - batch_.num_rows_); - } - - PresizedValueDescWriter values_writer(pool_, rows_in_chunk, batch_.num_cols_); - values_writer.Start(parsed_writer); - - RETURN_NOT_OK(ParseChunk<SpecializedOptions>(&values_writer, &parsed_writer, data, - data_end, is_final, rows_in_chunk, - &data, &finished_parsing)); - } - DCHECK_GE(data, view.data()); - DCHECK_LE(data, data_end); - total_parsed_length += static_cast<uint32_t>(data - view.data()); - - if (data < data_end) { - // Stopped early, for some reason - break; - } - } - - parsed_writer.Finish(&batch_.parsed_buffer_); - batch_.parsed_size_ = static_cast<int32_t>(batch_.parsed_buffer_->size()); - batch_.parsed_ = batch_.parsed_buffer_->data(); - - if (batch_.num_cols_ == -1) { - DCHECK_EQ(batch_.num_rows_, 0); - } - DCHECK_EQ(values_size_, batch_.num_rows_ * batch_.num_cols_); -#ifndef NDEBUG - if (batch_.num_rows_ > 0) { - // Ending parsed offset should be equal to number of parsed bytes - DCHECK_GT(batch_.values_buffers_.size(), 0); - const auto& last_values_buffer = batch_.values_buffers_.back(); - const auto last_values = - reinterpret_cast<const ParsedValueDesc*>(last_values_buffer->data()); - const auto last_values_size = last_values_buffer->size() / sizeof(ParsedValueDesc); - const auto check_parsed_size = - static_cast<int32_t>(last_values[last_values_size - 1].offset); - DCHECK_EQ(batch_.parsed_size_, check_parsed_size); - } else { - DCHECK_EQ(batch_.parsed_size_, 0); - } -#endif - *out_size = static_cast<uint32_t>(total_parsed_length); - return Status::OK(); - } - - Status Parse(const std::vector<util::string_view>& data, bool is_final, - uint32_t* out_size) { - if (options_.quoting) { - if (options_.escaping) { - return ParseSpecialized<SpecializedOptions<true, true>>(data, is_final, out_size); - } else { - return ParseSpecialized<SpecializedOptions<true, false>>(data, is_final, - out_size); - } - } else { - if (options_.escaping) { - return ParseSpecialized<SpecializedOptions<false, true>>(data, is_final, - out_size); - } else { - return ParseSpecialized<SpecializedOptions<false, false>>(data, is_final, - out_size); - } - } - } - - protected: - MemoryPool* pool_; - const ParseOptions options_; - const int64_t first_row_; - // The maximum number of rows to parse from a block - int32_t max_num_rows_; - - // Unparsed data size - int32_t values_size_; - // Parsed data batch - DataBatch batch_; -}; - -BlockParser::BlockParser(ParseOptions options, int32_t num_cols, int64_t first_row, - int32_t max_num_rows) - : BlockParser(default_memory_pool(), options, num_cols, first_row, max_num_rows) {} - -BlockParser::BlockParser(MemoryPool* pool, ParseOptions options, int32_t num_cols, - int64_t first_row, int32_t max_num_rows) - : impl_(new BlockParserImpl(pool, std::move(options), num_cols, first_row, - max_num_rows)) {} - -BlockParser::~BlockParser() {} - -Status BlockParser::Parse(const std::vector<util::string_view>& data, - uint32_t* out_size) { - return impl_->Parse(data, false /* is_final */, out_size); -} - -Status BlockParser::ParseFinal(const std::vector<util::string_view>& data, - uint32_t* out_size) { - return impl_->Parse(data, true /* is_final */, out_size); -} - -Status BlockParser::Parse(util::string_view data, uint32_t* out_size) { - return impl_->Parse({data}, false /* is_final */, out_size); -} - -Status BlockParser::ParseFinal(util::string_view data, uint32_t* out_size) { - return impl_->Parse({data}, true /* is_final */, out_size); -} - -const DataBatch& BlockParser::parsed_batch() const { return impl_->parsed_batch(); } - -int64_t BlockParser::first_row_num() const { return impl_->first_row_num(); } - -int32_t SkipRows(const uint8_t* data, uint32_t size, int32_t num_rows, - const uint8_t** out_data) { - const auto end = data + size; - int32_t skipped_rows = 0; - *out_data = data; - - for (; skipped_rows < num_rows; ++skipped_rows) { - uint8_t c; - do { - while (ARROW_PREDICT_FALSE(data < end && !IsControlChar(*data))) { - ++data; - } - if (ARROW_PREDICT_FALSE(data == end)) { - return skipped_rows; - } - c = *data++; - } while (c != '\r' && c != '\n'); - if (c == '\r' && data < end && *data == '\n') { - ++data; - } - *out_data = data; - } - - return skipped_rows; -} - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/csv/parser.h" + +#include <algorithm> +#include <cstdio> +#include <limits> +#include <utility> + +#include "arrow/memory_pool.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/logging.h" + +namespace arrow { +namespace csv { + +using detail::DataBatch; +using detail::ParsedValueDesc; + +namespace { + +template <typename... Args> +Status ParseError(Args&&... args) { + return Status::Invalid("CSV parse error: ", std::forward<Args>(args)...); +} + +Status MismatchingColumns(int32_t expected, int32_t actual, int64_t row_num, + util::string_view row) { + std::string ellipse; + if (row.length() > 100) { + row = row.substr(0, 96); + ellipse = " ..."; + } + if (row_num < 0) { + return ParseError("Expected ", expected, " columns, got ", actual, ": ", row, + ellipse); + } + return ParseError("Row #", row_num, ": Expected ", expected, " columns, got ", actual, + ": ", row, ellipse); +} + +inline bool IsControlChar(uint8_t c) { return c < ' '; } + +template <bool Quoting, bool Escaping> +class SpecializedOptions { + public: + static constexpr bool quoting = Quoting; + static constexpr bool escaping = Escaping; +}; + +// A helper class allocating the buffer for parsed values and writing into it +// without any further resizes, except at the end. +class PresizedDataWriter { + public: + PresizedDataWriter(MemoryPool* pool, uint32_t size) + : parsed_size_(0), parsed_capacity_(size) { + parsed_buffer_ = *AllocateResizableBuffer(parsed_capacity_, pool); + parsed_ = parsed_buffer_->mutable_data(); + } + + void Finish(std::shared_ptr<Buffer>* out_parsed) { + ARROW_CHECK_OK(parsed_buffer_->Resize(parsed_size_)); + *out_parsed = parsed_buffer_; + } + + void BeginLine() { saved_parsed_size_ = parsed_size_; } + + void PushFieldChar(char c) { + DCHECK_LT(parsed_size_, parsed_capacity_); + parsed_[parsed_size_++] = static_cast<uint8_t>(c); + } + + // Rollback the state that was saved in BeginLine() + void RollbackLine() { parsed_size_ = saved_parsed_size_; } + + int64_t size() { return parsed_size_; } + + protected: + std::shared_ptr<ResizableBuffer> parsed_buffer_; + uint8_t* parsed_; + int64_t parsed_size_; + int64_t parsed_capacity_; + // Checkpointing, for when an incomplete line is encountered at end of block + int64_t saved_parsed_size_; +}; + +template <typename Derived> +class ValueDescWriter { + public: + Derived* derived() { return static_cast<Derived*>(this); } + + template <typename DataWriter> + void Start(DataWriter& parsed_writer) { + derived()->PushValue( + {static_cast<uint32_t>(parsed_writer.size()) & 0x7fffffffU, false}); + } + + void BeginLine() { saved_values_size_ = values_size_; } + + // Rollback the state that was saved in BeginLine() + void RollbackLine() { values_size_ = saved_values_size_; } + + void StartField(bool quoted) { quoted_ = quoted; } + + template <typename DataWriter> + void FinishField(DataWriter* parsed_writer) { + derived()->PushValue( + {static_cast<uint32_t>(parsed_writer->size()) & 0x7fffffffU, quoted_}); + } + + void Finish(std::shared_ptr<Buffer>* out_values) { + ARROW_CHECK_OK(values_buffer_->Resize(values_size_ * sizeof(*values_))); + *out_values = values_buffer_; + } + + protected: + ValueDescWriter(MemoryPool* pool, int64_t values_capacity) + : values_size_(0), values_capacity_(values_capacity) { + values_buffer_ = *AllocateResizableBuffer(values_capacity_ * sizeof(*values_), pool); + values_ = reinterpret_cast<ParsedValueDesc*>(values_buffer_->mutable_data()); + } + + std::shared_ptr<ResizableBuffer> values_buffer_; + ParsedValueDesc* values_; + int64_t values_size_; + int64_t values_capacity_; + bool quoted_; + // Checkpointing, for when an incomplete line is encountered at end of block + int64_t saved_values_size_; +}; + +// A helper class handling a growable buffer for values offsets. This class is +// used when the number of columns is not yet known and we therefore cannot +// efficiently presize the target area for a given number of rows. +class ResizableValueDescWriter : public ValueDescWriter<ResizableValueDescWriter> { + public: + explicit ResizableValueDescWriter(MemoryPool* pool) + : ValueDescWriter(pool, /*values_capacity=*/256) {} + + void PushValue(ParsedValueDesc v) { + if (ARROW_PREDICT_FALSE(values_size_ == values_capacity_)) { + values_capacity_ = values_capacity_ * 2; + ARROW_CHECK_OK(values_buffer_->Resize(values_capacity_ * sizeof(*values_))); + values_ = reinterpret_cast<ParsedValueDesc*>(values_buffer_->mutable_data()); + } + values_[values_size_++] = v; + } +}; + +// A helper class allocating the buffer for values offsets and writing into it +// without any further resizes, except at the end. This class is used once the +// number of columns is known, as it eliminates resizes and generates simpler, +// faster CSV parsing code. +class PresizedValueDescWriter : public ValueDescWriter<PresizedValueDescWriter> { + public: + PresizedValueDescWriter(MemoryPool* pool, int32_t num_rows, int32_t num_cols) + : ValueDescWriter(pool, /*values_capacity=*/1 + num_rows * num_cols) {} + + void PushValue(ParsedValueDesc v) { + DCHECK_LT(values_size_, values_capacity_); + values_[values_size_++] = v; + } +}; + +} // namespace + +class BlockParserImpl { + public: + BlockParserImpl(MemoryPool* pool, ParseOptions options, int32_t num_cols, + int64_t first_row, int32_t max_num_rows) + : pool_(pool), + options_(options), + first_row_(first_row), + max_num_rows_(max_num_rows), + batch_(num_cols) {} + + const DataBatch& parsed_batch() const { return batch_; } + + int64_t first_row_num() const { return first_row_; } + + template <typename SpecializedOptions, typename ValueDescWriter, typename DataWriter> + Status ParseLine(ValueDescWriter* values_writer, DataWriter* parsed_writer, + const char* data, const char* data_end, bool is_final, + const char** out_data) { + int32_t num_cols = 0; + char c; + const auto start = data; + + DCHECK_GT(data_end, data); + + auto FinishField = [&]() { values_writer->FinishField(parsed_writer); }; + + values_writer->BeginLine(); + parsed_writer->BeginLine(); + + // The parsing state machine + + // Special case empty lines: do we start with a newline separator? + c = *data; + if (ARROW_PREDICT_FALSE(IsControlChar(c))) { + if (c == '\r') { + data++; + if (data < data_end && *data == '\n') { + data++; + } + goto EmptyLine; + } + if (c == '\n') { + data++; + goto EmptyLine; + } + } + + FieldStart: + // At the start of a field + // Quoting is only recognized at start of field + if (SpecializedOptions::quoting && + ARROW_PREDICT_FALSE(*data == options_.quote_char)) { + ++data; + values_writer->StartField(true /* quoted */); + goto InQuotedField; + } else { + values_writer->StartField(false /* quoted */); + goto InField; + } + + InField: + // Inside a non-quoted part of a field + if (ARROW_PREDICT_FALSE(data == data_end)) { + goto AbortLine; + } + c = *data++; + if (SpecializedOptions::escaping && ARROW_PREDICT_FALSE(c == options_.escape_char)) { + if (ARROW_PREDICT_FALSE(data == data_end)) { + goto AbortLine; + } + c = *data++; + parsed_writer->PushFieldChar(c); + goto InField; + } + if (ARROW_PREDICT_FALSE(c == options_.delimiter)) { + goto FieldEnd; + } + if (ARROW_PREDICT_FALSE(IsControlChar(c))) { + if (c == '\r') { + // In the middle of a newline separator? + if (ARROW_PREDICT_TRUE(data < data_end) && *data == '\n') { + data++; + } + goto LineEnd; + } + if (c == '\n') { + goto LineEnd; + } + } + parsed_writer->PushFieldChar(c); + goto InField; + + InQuotedField: + // Inside a quoted part of a field + if (ARROW_PREDICT_FALSE(data == data_end)) { + goto AbortLine; + } + c = *data++; + if (SpecializedOptions::escaping && ARROW_PREDICT_FALSE(c == options_.escape_char)) { + if (ARROW_PREDICT_FALSE(data == data_end)) { + goto AbortLine; + } + c = *data++; + parsed_writer->PushFieldChar(c); + goto InQuotedField; + } + if (ARROW_PREDICT_FALSE(c == options_.quote_char)) { + if (options_.double_quote && ARROW_PREDICT_TRUE(data < data_end) && + ARROW_PREDICT_FALSE(*data == options_.quote_char)) { + // Double-quoting + ++data; + } else { + // End of single-quoting + goto InField; + } + } + parsed_writer->PushFieldChar(c); + goto InQuotedField; + + FieldEnd: + // At the end of a field + FinishField(); + ++num_cols; + if (ARROW_PREDICT_FALSE(data == data_end)) { + goto AbortLine; + } + goto FieldStart; + + LineEnd: + // At the end of line + FinishField(); + ++num_cols; + if (ARROW_PREDICT_FALSE(num_cols != batch_.num_cols_)) { + if (batch_.num_cols_ == -1) { + batch_.num_cols_ = num_cols; + } else { + // Find the end of the line without newline or carriage return + auto end = data; + if (*(end - 1) == '\n') { + --end; + } + if (*(end - 1) == '\r') { + --end; + } + return MismatchingColumns(batch_.num_cols_, num_cols, + first_row_ < 0 ? -1 : first_row_ + batch_.num_rows_, + util::string_view(start, end - start)); + } + } + ++batch_.num_rows_; + *out_data = data; + return Status::OK(); + + AbortLine: + // Not a full line except perhaps if in final block + if (is_final) { + goto LineEnd; + } + // Truncated line at end of block, rewind parsed state + values_writer->RollbackLine(); + parsed_writer->RollbackLine(); + return Status::OK(); + + EmptyLine: + if (!options_.ignore_empty_lines) { + if (batch_.num_cols_ == -1) { + // Consider as single value + batch_.num_cols_ = 1; + } + // Record as row of empty (null?) values + while (num_cols++ < batch_.num_cols_) { + values_writer->StartField(false /* quoted */); + FinishField(); + } + ++batch_.num_rows_; + } + *out_data = data; + return Status::OK(); + } + + template <typename SpecializedOptions, typename ValueDescWriter, typename DataWriter> + Status ParseChunk(ValueDescWriter* values_writer, DataWriter* parsed_writer, + const char* data, const char* data_end, bool is_final, + int32_t rows_in_chunk, const char** out_data, + bool* finished_parsing) { + int32_t num_rows_deadline = batch_.num_rows_ + rows_in_chunk; + + while (data < data_end && batch_.num_rows_ < num_rows_deadline) { + const char* line_end = data; + RETURN_NOT_OK(ParseLine<SpecializedOptions>(values_writer, parsed_writer, data, + data_end, is_final, &line_end)); + if (line_end == data) { + // Cannot parse any further + *finished_parsing = true; + break; + } + data = line_end; + } + // Append new buffers and update size + std::shared_ptr<Buffer> values_buffer; + values_writer->Finish(&values_buffer); + if (values_buffer->size() > 0) { + values_size_ += + static_cast<int32_t>(values_buffer->size() / sizeof(ParsedValueDesc) - 1); + batch_.values_buffers_.push_back(std::move(values_buffer)); + } + *out_data = data; + return Status::OK(); + } + + template <typename SpecializedOptions> + Status ParseSpecialized(const std::vector<util::string_view>& views, bool is_final, + uint32_t* out_size) { + batch_ = DataBatch{batch_.num_cols_}; + values_size_ = 0; + + size_t total_view_length = 0; + for (const auto& view : views) { + total_view_length += view.length(); + } + if (total_view_length > std::numeric_limits<uint32_t>::max()) { + return Status::Invalid("CSV block too large"); + } + + PresizedDataWriter parsed_writer(pool_, static_cast<uint32_t>(total_view_length)); + uint32_t total_parsed_length = 0; + + for (const auto& view : views) { + const char* data = view.data(); + const char* data_end = view.data() + view.length(); + bool finished_parsing = false; + + if (batch_.num_cols_ == -1) { + // Can't presize values when the number of columns is not known, first parse + // a single line + const int32_t rows_in_chunk = 1; + ResizableValueDescWriter values_writer(pool_); + values_writer.Start(parsed_writer); + + RETURN_NOT_OK(ParseChunk<SpecializedOptions>(&values_writer, &parsed_writer, data, + data_end, is_final, rows_in_chunk, + &data, &finished_parsing)); + if (batch_.num_cols_ == -1) { + return ParseError("Empty CSV file or block: cannot infer number of columns"); + } + } + + while (!finished_parsing && data < data_end && batch_.num_rows_ < max_num_rows_) { + // We know the number of columns, so can presize a values array for + // a given number of rows + DCHECK_GE(batch_.num_cols_, 0); + + int32_t rows_in_chunk; + constexpr int32_t kTargetChunkSize = 32768; // in number of values + if (batch_.num_cols_ > 0) { + rows_in_chunk = std::min(std::max(kTargetChunkSize / batch_.num_cols_, 512), + max_num_rows_ - batch_.num_rows_); + } else { + rows_in_chunk = std::min(kTargetChunkSize, max_num_rows_ - batch_.num_rows_); + } + + PresizedValueDescWriter values_writer(pool_, rows_in_chunk, batch_.num_cols_); + values_writer.Start(parsed_writer); + + RETURN_NOT_OK(ParseChunk<SpecializedOptions>(&values_writer, &parsed_writer, data, + data_end, is_final, rows_in_chunk, + &data, &finished_parsing)); + } + DCHECK_GE(data, view.data()); + DCHECK_LE(data, data_end); + total_parsed_length += static_cast<uint32_t>(data - view.data()); + + if (data < data_end) { + // Stopped early, for some reason + break; + } + } + + parsed_writer.Finish(&batch_.parsed_buffer_); + batch_.parsed_size_ = static_cast<int32_t>(batch_.parsed_buffer_->size()); + batch_.parsed_ = batch_.parsed_buffer_->data(); + + if (batch_.num_cols_ == -1) { + DCHECK_EQ(batch_.num_rows_, 0); + } + DCHECK_EQ(values_size_, batch_.num_rows_ * batch_.num_cols_); +#ifndef NDEBUG + if (batch_.num_rows_ > 0) { + // Ending parsed offset should be equal to number of parsed bytes + DCHECK_GT(batch_.values_buffers_.size(), 0); + const auto& last_values_buffer = batch_.values_buffers_.back(); + const auto last_values = + reinterpret_cast<const ParsedValueDesc*>(last_values_buffer->data()); + const auto last_values_size = last_values_buffer->size() / sizeof(ParsedValueDesc); + const auto check_parsed_size = + static_cast<int32_t>(last_values[last_values_size - 1].offset); + DCHECK_EQ(batch_.parsed_size_, check_parsed_size); + } else { + DCHECK_EQ(batch_.parsed_size_, 0); + } +#endif + *out_size = static_cast<uint32_t>(total_parsed_length); + return Status::OK(); + } + + Status Parse(const std::vector<util::string_view>& data, bool is_final, + uint32_t* out_size) { + if (options_.quoting) { + if (options_.escaping) { + return ParseSpecialized<SpecializedOptions<true, true>>(data, is_final, out_size); + } else { + return ParseSpecialized<SpecializedOptions<true, false>>(data, is_final, + out_size); + } + } else { + if (options_.escaping) { + return ParseSpecialized<SpecializedOptions<false, true>>(data, is_final, + out_size); + } else { + return ParseSpecialized<SpecializedOptions<false, false>>(data, is_final, + out_size); + } + } + } + + protected: + MemoryPool* pool_; + const ParseOptions options_; + const int64_t first_row_; + // The maximum number of rows to parse from a block + int32_t max_num_rows_; + + // Unparsed data size + int32_t values_size_; + // Parsed data batch + DataBatch batch_; +}; + +BlockParser::BlockParser(ParseOptions options, int32_t num_cols, int64_t first_row, + int32_t max_num_rows) + : BlockParser(default_memory_pool(), options, num_cols, first_row, max_num_rows) {} + +BlockParser::BlockParser(MemoryPool* pool, ParseOptions options, int32_t num_cols, + int64_t first_row, int32_t max_num_rows) + : impl_(new BlockParserImpl(pool, std::move(options), num_cols, first_row, + max_num_rows)) {} + +BlockParser::~BlockParser() {} + +Status BlockParser::Parse(const std::vector<util::string_view>& data, + uint32_t* out_size) { + return impl_->Parse(data, false /* is_final */, out_size); +} + +Status BlockParser::ParseFinal(const std::vector<util::string_view>& data, + uint32_t* out_size) { + return impl_->Parse(data, true /* is_final */, out_size); +} + +Status BlockParser::Parse(util::string_view data, uint32_t* out_size) { + return impl_->Parse({data}, false /* is_final */, out_size); +} + +Status BlockParser::ParseFinal(util::string_view data, uint32_t* out_size) { + return impl_->Parse({data}, true /* is_final */, out_size); +} + +const DataBatch& BlockParser::parsed_batch() const { return impl_->parsed_batch(); } + +int64_t BlockParser::first_row_num() const { return impl_->first_row_num(); } + +int32_t SkipRows(const uint8_t* data, uint32_t size, int32_t num_rows, + const uint8_t** out_data) { + const auto end = data + size; + int32_t skipped_rows = 0; + *out_data = data; + + for (; skipped_rows < num_rows; ++skipped_rows) { + uint8_t c; + do { + while (ARROW_PREDICT_FALSE(data < end && !IsControlChar(*data))) { + ++data; + } + if (ARROW_PREDICT_FALSE(data == end)) { + return skipped_rows; + } + c = *data++; + } while (c != '\r' && c != '\n'); + if (c == '\r' && data < end && *data == '\n') { + ++data; + } + *out_data = data; + } + + return skipped_rows; +} + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/parser.h b/contrib/libs/apache/arrow/cpp/src/arrow/csv/parser.h index ffc735c228..76ba0fbaf3 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/parser.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/parser.h @@ -1,202 +1,202 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <cstddef> -#include <cstdint> -#include <memory> -#include <vector> - -#include "arrow/buffer.h" -#include "arrow/csv/options.h" -#include "arrow/status.h" -#include "arrow/util/macros.h" -#include "arrow/util/string_view.h" -#include "arrow/util/visibility.h" - -namespace arrow { - -class MemoryPool; - -namespace csv { - -/// Skip at most num_rows from the given input. The input pointer is updated -/// and the number of actually skipped rows is returns (may be less than -/// requested if the input is too short). -ARROW_EXPORT -int32_t SkipRows(const uint8_t* data, uint32_t size, int32_t num_rows, - const uint8_t** out_data); - -class BlockParserImpl; - -namespace detail { - -struct ParsedValueDesc { - uint32_t offset : 31; - bool quoted : 1; -}; - -class ARROW_EXPORT DataBatch { - public: - explicit DataBatch(int32_t num_cols) : num_cols_(num_cols) {} - - /// \brief Return the number of parsed rows - int32_t num_rows() const { return num_rows_; } - /// \brief Return the number of parsed columns - int32_t num_cols() const { return num_cols_; } - /// \brief Return the total size in bytes of parsed data - uint32_t num_bytes() const { return parsed_size_; } - - template <typename Visitor> - Status VisitColumn(int32_t col_index, int64_t first_row, Visitor&& visit) const { - using detail::ParsedValueDesc; - - int64_t row = first_row; - for (size_t buf_index = 0; buf_index < values_buffers_.size(); ++buf_index) { - const auto& values_buffer = values_buffers_[buf_index]; - const auto values = reinterpret_cast<const ParsedValueDesc*>(values_buffer->data()); - const auto max_pos = - static_cast<int32_t>(values_buffer->size() / sizeof(ParsedValueDesc)) - 1; - for (int32_t pos = col_index; pos < max_pos; pos += num_cols_, ++row) { - auto start = values[pos].offset; - auto stop = values[pos + 1].offset; - auto quoted = values[pos + 1].quoted; - Status status = visit(parsed_ + start, stop - start, quoted); - if (ARROW_PREDICT_FALSE(!status.ok())) { - if (first_row >= 0) { - status = status.WithMessage("Row #", row, ": ", status.message()); - } - ARROW_RETURN_NOT_OK(status); - } - } - } - return Status::OK(); - } - - template <typename Visitor> - Status VisitLastRow(Visitor&& visit) const { - using detail::ParsedValueDesc; - - const auto& values_buffer = values_buffers_.back(); - const auto values = reinterpret_cast<const ParsedValueDesc*>(values_buffer->data()); - const auto start_pos = - static_cast<int32_t>(values_buffer->size() / sizeof(ParsedValueDesc)) - - num_cols_ - 1; - for (int32_t col_index = 0; col_index < num_cols_; ++col_index) { - auto start = values[start_pos + col_index].offset; - auto stop = values[start_pos + col_index + 1].offset; - auto quoted = values[start_pos + col_index + 1].quoted; - ARROW_RETURN_NOT_OK(visit(parsed_ + start, stop - start, quoted)); - } - return Status::OK(); - } - - protected: - // The number of rows in this batch - int32_t num_rows_ = 0; - // The number of columns - int32_t num_cols_ = 0; - - // XXX should we ensure the parsed buffer is padded with 8 or 16 excess zero bytes? - // It may help with null parsing... - std::vector<std::shared_ptr<Buffer>> values_buffers_; - std::shared_ptr<Buffer> parsed_buffer_; - const uint8_t* parsed_ = NULLPTR; - int32_t parsed_size_ = 0; - - friend class ::arrow::csv::BlockParserImpl; -}; - -} // namespace detail - -constexpr int32_t kMaxParserNumRows = 100000; - -/// \class BlockParser -/// \brief A reusable block-based parser for CSV data -/// -/// The parser takes a block of CSV data and delimits rows and fields, -/// unquoting and unescaping them on the fly. Parsed data is own by the -/// parser, so the original buffer can be discarded after Parse() returns. -/// -/// If the block is truncated (i.e. not all data can be parsed), it is up -/// to the caller to arrange the next block to start with the trailing data. -/// Also, if the previous block ends with CR (0x0d) and a new block starts -/// with LF (0x0a), the parser will consider the leading newline as an empty -/// line; the caller should therefore strip it. -class ARROW_EXPORT BlockParser { - public: - explicit BlockParser(ParseOptions options, int32_t num_cols = -1, - int64_t first_row = -1, int32_t max_num_rows = kMaxParserNumRows); - explicit BlockParser(MemoryPool* pool, ParseOptions options, int32_t num_cols = -1, - int64_t first_row = -1, int32_t max_num_rows = kMaxParserNumRows); - ~BlockParser(); - - /// \brief Parse a block of data - /// - /// Parse a block of CSV data, ingesting up to max_num_rows rows. - /// The number of bytes actually parsed is returned in out_size. - Status Parse(util::string_view data, uint32_t* out_size); - - /// \brief Parse sequential blocks of data - /// - /// Only the last block is allowed to be truncated. - Status Parse(const std::vector<util::string_view>& data, uint32_t* out_size); - - /// \brief Parse the final block of data - /// - /// Like Parse(), but called with the final block in a file. - /// The last row may lack a trailing line separator. - Status ParseFinal(util::string_view data, uint32_t* out_size); - - /// \brief Parse the final sequential blocks of data - /// - /// Only the last block is allowed to be truncated. - Status ParseFinal(const std::vector<util::string_view>& data, uint32_t* out_size); - - /// \brief Return the number of parsed rows - int32_t num_rows() const { return parsed_batch().num_rows(); } - /// \brief Return the number of parsed columns - int32_t num_cols() const { return parsed_batch().num_cols(); } - /// \brief Return the total size in bytes of parsed data - uint32_t num_bytes() const { return parsed_batch().num_bytes(); } - /// \brief Return the row number of the first row in the block or -1 if unsupported - int64_t first_row_num() const; - - /// \brief Visit parsed values in a column - /// - /// The signature of the visitor is - /// Status(const uint8_t* data, uint32_t size, bool quoted) - template <typename Visitor> - Status VisitColumn(int32_t col_index, Visitor&& visit) const { - return parsed_batch().VisitColumn(col_index, first_row_num(), - std::forward<Visitor>(visit)); - } - - template <typename Visitor> - Status VisitLastRow(Visitor&& visit) const { - return parsed_batch().VisitLastRow(std::forward<Visitor>(visit)); - } - - protected: - std::unique_ptr<BlockParserImpl> impl_; - - const detail::DataBatch& parsed_batch() const; -}; - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstddef> +#include <cstdint> +#include <memory> +#include <vector> + +#include "arrow/buffer.h" +#include "arrow/csv/options.h" +#include "arrow/status.h" +#include "arrow/util/macros.h" +#include "arrow/util/string_view.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class MemoryPool; + +namespace csv { + +/// Skip at most num_rows from the given input. The input pointer is updated +/// and the number of actually skipped rows is returns (may be less than +/// requested if the input is too short). +ARROW_EXPORT +int32_t SkipRows(const uint8_t* data, uint32_t size, int32_t num_rows, + const uint8_t** out_data); + +class BlockParserImpl; + +namespace detail { + +struct ParsedValueDesc { + uint32_t offset : 31; + bool quoted : 1; +}; + +class ARROW_EXPORT DataBatch { + public: + explicit DataBatch(int32_t num_cols) : num_cols_(num_cols) {} + + /// \brief Return the number of parsed rows + int32_t num_rows() const { return num_rows_; } + /// \brief Return the number of parsed columns + int32_t num_cols() const { return num_cols_; } + /// \brief Return the total size in bytes of parsed data + uint32_t num_bytes() const { return parsed_size_; } + + template <typename Visitor> + Status VisitColumn(int32_t col_index, int64_t first_row, Visitor&& visit) const { + using detail::ParsedValueDesc; + + int64_t row = first_row; + for (size_t buf_index = 0; buf_index < values_buffers_.size(); ++buf_index) { + const auto& values_buffer = values_buffers_[buf_index]; + const auto values = reinterpret_cast<const ParsedValueDesc*>(values_buffer->data()); + const auto max_pos = + static_cast<int32_t>(values_buffer->size() / sizeof(ParsedValueDesc)) - 1; + for (int32_t pos = col_index; pos < max_pos; pos += num_cols_, ++row) { + auto start = values[pos].offset; + auto stop = values[pos + 1].offset; + auto quoted = values[pos + 1].quoted; + Status status = visit(parsed_ + start, stop - start, quoted); + if (ARROW_PREDICT_FALSE(!status.ok())) { + if (first_row >= 0) { + status = status.WithMessage("Row #", row, ": ", status.message()); + } + ARROW_RETURN_NOT_OK(status); + } + } + } + return Status::OK(); + } + + template <typename Visitor> + Status VisitLastRow(Visitor&& visit) const { + using detail::ParsedValueDesc; + + const auto& values_buffer = values_buffers_.back(); + const auto values = reinterpret_cast<const ParsedValueDesc*>(values_buffer->data()); + const auto start_pos = + static_cast<int32_t>(values_buffer->size() / sizeof(ParsedValueDesc)) - + num_cols_ - 1; + for (int32_t col_index = 0; col_index < num_cols_; ++col_index) { + auto start = values[start_pos + col_index].offset; + auto stop = values[start_pos + col_index + 1].offset; + auto quoted = values[start_pos + col_index + 1].quoted; + ARROW_RETURN_NOT_OK(visit(parsed_ + start, stop - start, quoted)); + } + return Status::OK(); + } + + protected: + // The number of rows in this batch + int32_t num_rows_ = 0; + // The number of columns + int32_t num_cols_ = 0; + + // XXX should we ensure the parsed buffer is padded with 8 or 16 excess zero bytes? + // It may help with null parsing... + std::vector<std::shared_ptr<Buffer>> values_buffers_; + std::shared_ptr<Buffer> parsed_buffer_; + const uint8_t* parsed_ = NULLPTR; + int32_t parsed_size_ = 0; + + friend class ::arrow::csv::BlockParserImpl; +}; + +} // namespace detail + +constexpr int32_t kMaxParserNumRows = 100000; + +/// \class BlockParser +/// \brief A reusable block-based parser for CSV data +/// +/// The parser takes a block of CSV data and delimits rows and fields, +/// unquoting and unescaping them on the fly. Parsed data is own by the +/// parser, so the original buffer can be discarded after Parse() returns. +/// +/// If the block is truncated (i.e. not all data can be parsed), it is up +/// to the caller to arrange the next block to start with the trailing data. +/// Also, if the previous block ends with CR (0x0d) and a new block starts +/// with LF (0x0a), the parser will consider the leading newline as an empty +/// line; the caller should therefore strip it. +class ARROW_EXPORT BlockParser { + public: + explicit BlockParser(ParseOptions options, int32_t num_cols = -1, + int64_t first_row = -1, int32_t max_num_rows = kMaxParserNumRows); + explicit BlockParser(MemoryPool* pool, ParseOptions options, int32_t num_cols = -1, + int64_t first_row = -1, int32_t max_num_rows = kMaxParserNumRows); + ~BlockParser(); + + /// \brief Parse a block of data + /// + /// Parse a block of CSV data, ingesting up to max_num_rows rows. + /// The number of bytes actually parsed is returned in out_size. + Status Parse(util::string_view data, uint32_t* out_size); + + /// \brief Parse sequential blocks of data + /// + /// Only the last block is allowed to be truncated. + Status Parse(const std::vector<util::string_view>& data, uint32_t* out_size); + + /// \brief Parse the final block of data + /// + /// Like Parse(), but called with the final block in a file. + /// The last row may lack a trailing line separator. + Status ParseFinal(util::string_view data, uint32_t* out_size); + + /// \brief Parse the final sequential blocks of data + /// + /// Only the last block is allowed to be truncated. + Status ParseFinal(const std::vector<util::string_view>& data, uint32_t* out_size); + + /// \brief Return the number of parsed rows + int32_t num_rows() const { return parsed_batch().num_rows(); } + /// \brief Return the number of parsed columns + int32_t num_cols() const { return parsed_batch().num_cols(); } + /// \brief Return the total size in bytes of parsed data + uint32_t num_bytes() const { return parsed_batch().num_bytes(); } + /// \brief Return the row number of the first row in the block or -1 if unsupported + int64_t first_row_num() const; + + /// \brief Visit parsed values in a column + /// + /// The signature of the visitor is + /// Status(const uint8_t* data, uint32_t size, bool quoted) + template <typename Visitor> + Status VisitColumn(int32_t col_index, Visitor&& visit) const { + return parsed_batch().VisitColumn(col_index, first_row_num(), + std::forward<Visitor>(visit)); + } + + template <typename Visitor> + Status VisitLastRow(Visitor&& visit) const { + return parsed_batch().VisitLastRow(std::forward<Visitor>(visit)); + } + + protected: + std::unique_ptr<BlockParserImpl> impl_; + + const detail::DataBatch& parsed_batch() const; +}; + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/reader.cc b/contrib/libs/apache/arrow/cpp/src/arrow/csv/reader.cc index 1a7836561d..d31d39ccf8 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/reader.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/reader.cc @@ -1,1279 +1,1279 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/csv/reader.h" - -#include <cstdint> -#include <cstring> -#include <functional> -#include <limits> -#include <memory> -#include <sstream> -#include <string> -#include <unordered_map> -#include <utility> -#include <vector> - -#include "arrow/array.h" -#include "arrow/buffer.h" -#include "arrow/csv/chunker.h" -#include "arrow/csv/column_builder.h" -#include "arrow/csv/column_decoder.h" -#include "arrow/csv/options.h" -#include "arrow/csv/parser.h" -#include "arrow/io/interfaces.h" -#include "arrow/result.h" -#include "arrow/status.h" -#include "arrow/table.h" -#include "arrow/type.h" -#include "arrow/type_fwd.h" -#include "arrow/util/async_generator.h" -#include "arrow/util/future.h" -#include "arrow/util/iterator.h" -#include "arrow/util/logging.h" -#include "arrow/util/macros.h" -#include "arrow/util/optional.h" -#include "arrow/util/task_group.h" -#include "arrow/util/thread_pool.h" -#include "arrow/util/utf8.h" -#include "arrow/util/vector.h" - -namespace arrow { -namespace csv { - -using internal::Executor; - -namespace { - -struct ConversionSchema { - struct Column { - std::string name; - // Physical column index in CSV file - int32_t index; - // If true, make a column of nulls - bool is_missing; - // If set, convert the CSV column to this type - // If unset (and is_missing is false), infer the type from the CSV column - std::shared_ptr<DataType> type; - }; - - static Column NullColumn(std::string col_name, std::shared_ptr<DataType> type) { - return Column{std::move(col_name), -1, true, std::move(type)}; - } - - static Column TypedColumn(std::string col_name, int32_t col_index, - std::shared_ptr<DataType> type) { - return Column{std::move(col_name), col_index, false, std::move(type)}; - } - - static Column InferredColumn(std::string col_name, int32_t col_index) { - return Column{std::move(col_name), col_index, false, nullptr}; - } - - std::vector<Column> columns; -}; - -// An iterator of Buffers that makes sure there is no straddling CRLF sequence. -class CSVBufferIterator { - public: - static Iterator<std::shared_ptr<Buffer>> Make( - Iterator<std::shared_ptr<Buffer>> buffer_iterator) { - Transformer<std::shared_ptr<Buffer>, std::shared_ptr<Buffer>> fn = - CSVBufferIterator(); - return MakeTransformedIterator(std::move(buffer_iterator), fn); - } - - static AsyncGenerator<std::shared_ptr<Buffer>> MakeAsync( - AsyncGenerator<std::shared_ptr<Buffer>> buffer_iterator) { - Transformer<std::shared_ptr<Buffer>, std::shared_ptr<Buffer>> fn = - CSVBufferIterator(); - return MakeTransformedGenerator(std::move(buffer_iterator), fn); - } - - Result<TransformFlow<std::shared_ptr<Buffer>>> operator()(std::shared_ptr<Buffer> buf) { - if (buf == nullptr) { - // EOF - return TransformFinish(); - } - - int64_t offset = 0; - if (first_buffer_) { - ARROW_ASSIGN_OR_RAISE(auto data, util::SkipUTF8BOM(buf->data(), buf->size())); - offset += data - buf->data(); - DCHECK_GE(offset, 0); - first_buffer_ = false; - } - - if (trailing_cr_ && buf->data()[offset] == '\n') { - // Skip '\r\n' line separator that started at the end of previous buffer - ++offset; - } - - trailing_cr_ = (buf->data()[buf->size() - 1] == '\r'); - buf = SliceBuffer(buf, offset); - if (buf->size() == 0) { - // EOF - return TransformFinish(); - } else { - return TransformYield(buf); - } - } - - protected: - bool first_buffer_ = true; - // Whether there was a trailing CR at the end of last received buffer - bool trailing_cr_ = false; -}; - -struct CSVBlock { - // (partial + completion + buffer) is an entire delimited CSV buffer. - std::shared_ptr<Buffer> partial; - std::shared_ptr<Buffer> completion; - std::shared_ptr<Buffer> buffer; - int64_t block_index; - bool is_final; - int64_t bytes_skipped; - std::function<Status(int64_t)> consume_bytes; -}; - -} // namespace -} // namespace csv - -template <> -struct IterationTraits<csv::CSVBlock> { - static csv::CSVBlock End() { return csv::CSVBlock{{}, {}, {}, -1, true, 0, {}}; } - static bool IsEnd(const csv::CSVBlock& val) { return val.block_index < 0; } -}; - -namespace csv { -namespace { - -// This is a callable that can be used to transform an iterator. The source iterator -// will contain buffers of data and the output iterator will contain delimited CSV -// blocks. util::optional is used so that there is an end token (required by the -// iterator APIs (e.g. Visit)) even though an empty optional is never used in this code. -class BlockReader { - public: - BlockReader(std::unique_ptr<Chunker> chunker, std::shared_ptr<Buffer> first_buffer, - int64_t skip_rows) - : chunker_(std::move(chunker)), - partial_(std::make_shared<Buffer>("")), - buffer_(std::move(first_buffer)), - skip_rows_(skip_rows) {} - - protected: - std::unique_ptr<Chunker> chunker_; - std::shared_ptr<Buffer> partial_, buffer_; - int64_t skip_rows_; - int64_t block_index_ = 0; - // Whether there was a trailing CR at the end of last received buffer - bool trailing_cr_ = false; -}; - -// An object that reads delimited CSV blocks for serial use. -// The number of bytes consumed should be notified after each read, -// using CSVBlock::consume_bytes. -class SerialBlockReader : public BlockReader { - public: - using BlockReader::BlockReader; - - static Iterator<CSVBlock> MakeIterator( - Iterator<std::shared_ptr<Buffer>> buffer_iterator, std::unique_ptr<Chunker> chunker, - std::shared_ptr<Buffer> first_buffer, int64_t skip_rows) { - auto block_reader = - std::make_shared<SerialBlockReader>(std::move(chunker), first_buffer, skip_rows); - // Wrap shared pointer in callable - Transformer<std::shared_ptr<Buffer>, CSVBlock> block_reader_fn = - [block_reader](std::shared_ptr<Buffer> buf) { - return (*block_reader)(std::move(buf)); - }; - return MakeTransformedIterator(std::move(buffer_iterator), block_reader_fn); - } - - static AsyncGenerator<CSVBlock> MakeAsyncIterator( - AsyncGenerator<std::shared_ptr<Buffer>> buffer_generator, - std::unique_ptr<Chunker> chunker, std::shared_ptr<Buffer> first_buffer, - int64_t skip_rows) { - auto block_reader = - std::make_shared<SerialBlockReader>(std::move(chunker), first_buffer, skip_rows); - // Wrap shared pointer in callable - Transformer<std::shared_ptr<Buffer>, CSVBlock> block_reader_fn = - [block_reader](std::shared_ptr<Buffer> next) { - return (*block_reader)(std::move(next)); - }; - return MakeTransformedGenerator(std::move(buffer_generator), block_reader_fn); - } - - Result<TransformFlow<CSVBlock>> operator()(std::shared_ptr<Buffer> next_buffer) { - if (buffer_ == nullptr) { - return TransformFinish(); - } - - bool is_final = (next_buffer == nullptr); - int64_t bytes_skipped = 0; - - if (skip_rows_) { - bytes_skipped += partial_->size(); - auto orig_size = buffer_->size(); - RETURN_NOT_OK( - chunker_->ProcessSkip(partial_, buffer_, is_final, &skip_rows_, &buffer_)); - bytes_skipped += orig_size - buffer_->size(); - auto empty = std::make_shared<Buffer>(nullptr, 0); - if (skip_rows_) { - // Still have rows beyond this buffer to skip return empty block - partial_ = std::move(buffer_); - buffer_ = next_buffer; - return TransformYield<CSVBlock>(CSVBlock{empty, empty, empty, block_index_++, - is_final, bytes_skipped, - [](int64_t) { return Status::OK(); }}); - } - partial_ = std::move(empty); - } - - std::shared_ptr<Buffer> completion; - - if (is_final) { - // End of file reached => compute completion from penultimate block - RETURN_NOT_OK(chunker_->ProcessFinal(partial_, buffer_, &completion, &buffer_)); - } else { - // Get completion of partial from previous block. - RETURN_NOT_OK( - chunker_->ProcessWithPartial(partial_, buffer_, &completion, &buffer_)); - } - int64_t bytes_before_buffer = partial_->size() + completion->size(); - - auto consume_bytes = [this, bytes_before_buffer, - next_buffer](int64_t nbytes) -> Status { - DCHECK_GE(nbytes, 0); - auto offset = nbytes - bytes_before_buffer; - if (offset < 0) { - // Should not happen - return Status::Invalid("CSV parser got out of sync with chunker"); - } - partial_ = SliceBuffer(buffer_, offset); - buffer_ = next_buffer; - return Status::OK(); - }; - - return TransformYield<CSVBlock>(CSVBlock{partial_, completion, buffer_, - block_index_++, is_final, bytes_skipped, - std::move(consume_bytes)}); - } -}; - -// An object that reads delimited CSV blocks for threaded use. -class ThreadedBlockReader : public BlockReader { - public: - using BlockReader::BlockReader; - - static AsyncGenerator<CSVBlock> MakeAsyncIterator( - AsyncGenerator<std::shared_ptr<Buffer>> buffer_generator, - std::unique_ptr<Chunker> chunker, std::shared_ptr<Buffer> first_buffer, - int64_t skip_rows) { - auto block_reader = std::make_shared<ThreadedBlockReader>(std::move(chunker), - first_buffer, skip_rows); - // Wrap shared pointer in callable - Transformer<std::shared_ptr<Buffer>, CSVBlock> block_reader_fn = - [block_reader](std::shared_ptr<Buffer> next) { return (*block_reader)(next); }; - return MakeTransformedGenerator(std::move(buffer_generator), block_reader_fn); - } - - Result<TransformFlow<CSVBlock>> operator()(std::shared_ptr<Buffer> next_buffer) { - if (buffer_ == nullptr) { - // EOF - return TransformFinish(); - } - - bool is_final = (next_buffer == nullptr); - - auto current_partial = std::move(partial_); - auto current_buffer = std::move(buffer_); - int64_t bytes_skipped = 0; - - if (skip_rows_) { - auto orig_size = current_buffer->size(); - bytes_skipped = current_partial->size(); - RETURN_NOT_OK(chunker_->ProcessSkip(current_partial, current_buffer, is_final, - &skip_rows_, ¤t_buffer)); - bytes_skipped += orig_size - current_buffer->size(); - current_partial = std::make_shared<Buffer>(nullptr, 0); - if (skip_rows_) { - partial_ = std::move(current_buffer); - buffer_ = std::move(next_buffer); - return TransformYield<CSVBlock>(CSVBlock{current_partial, - current_partial, - current_partial, - block_index_++, - is_final, - bytes_skipped, - {}}); - } - } - - std::shared_ptr<Buffer> whole, completion, next_partial; - - if (is_final) { - // End of file reached => compute completion from penultimate block - RETURN_NOT_OK( - chunker_->ProcessFinal(current_partial, current_buffer, &completion, &whole)); - } else { - // Get completion of partial from previous block. - std::shared_ptr<Buffer> starts_with_whole; - // Get completion of partial from previous block. - RETURN_NOT_OK(chunker_->ProcessWithPartial(current_partial, current_buffer, - &completion, &starts_with_whole)); - - // Get a complete CSV block inside `partial + block`, and keep - // the rest for the next iteration. - RETURN_NOT_OK(chunker_->Process(starts_with_whole, &whole, &next_partial)); - } - - partial_ = std::move(next_partial); - buffer_ = std::move(next_buffer); - - return TransformYield<CSVBlock>(CSVBlock{ - current_partial, completion, whole, block_index_++, is_final, bytes_skipped, {}}); - } -}; - -struct ParsedBlock { - std::shared_ptr<BlockParser> parser; - int64_t block_index; - int64_t bytes_parsed_or_skipped; -}; - -struct DecodedBlock { - std::shared_ptr<RecordBatch> record_batch; - // Represents the number of input bytes represented by this batch - // This will include bytes skipped when skipping rows after the header - int64_t bytes_processed; -}; - -} // namespace - -} // namespace csv - -template <> -struct IterationTraits<csv::ParsedBlock> { - static csv::ParsedBlock End() { return csv::ParsedBlock{nullptr, -1, -1}; } - static bool IsEnd(const csv::ParsedBlock& val) { return val.block_index < 0; } -}; - -template <> -struct IterationTraits<csv::DecodedBlock> { - static csv::DecodedBlock End() { return csv::DecodedBlock{nullptr, -1}; } - static bool IsEnd(const csv::DecodedBlock& val) { return val.bytes_processed < 0; } -}; - -namespace csv { -namespace { - -// A function object that takes in a buffer of CSV data and returns a parsed batch of CSV -// data (CSVBlock -> ParsedBlock) for use with MakeMappedGenerator. -// The parsed batch contains a list of offsets for each of the columns so that columns -// can be individually scanned -// -// This operator is not re-entrant -class BlockParsingOperator { - public: - BlockParsingOperator(io::IOContext io_context, ParseOptions parse_options, - int num_csv_cols, int64_t first_row) - : io_context_(io_context), - parse_options_(parse_options), - num_csv_cols_(num_csv_cols), - count_rows_(first_row >= 0), - num_rows_seen_(first_row) {} - - Result<ParsedBlock> operator()(const CSVBlock& block) { - constexpr int32_t max_num_rows = std::numeric_limits<int32_t>::max(); - auto parser = std::make_shared<BlockParser>( - io_context_.pool(), parse_options_, num_csv_cols_, num_rows_seen_, max_num_rows); - - std::shared_ptr<Buffer> straddling; - std::vector<util::string_view> views; - if (block.partial->size() != 0 || block.completion->size() != 0) { - if (block.partial->size() == 0) { - straddling = block.completion; - } else if (block.completion->size() == 0) { - straddling = block.partial; - } else { - ARROW_ASSIGN_OR_RAISE( - straddling, - ConcatenateBuffers({block.partial, block.completion}, io_context_.pool())); - } - views = {util::string_view(*straddling), util::string_view(*block.buffer)}; - } else { - views = {util::string_view(*block.buffer)}; - } - uint32_t parsed_size; - if (block.is_final) { - RETURN_NOT_OK(parser->ParseFinal(views, &parsed_size)); - } else { - RETURN_NOT_OK(parser->Parse(views, &parsed_size)); - } - if (count_rows_) { - num_rows_seen_ += parser->num_rows(); - } - RETURN_NOT_OK(block.consume_bytes(parsed_size)); - return ParsedBlock{std::move(parser), block.block_index, - static_cast<int64_t>(parsed_size) + block.bytes_skipped}; - } - - private: - io::IOContext io_context_; - ParseOptions parse_options_; - int num_csv_cols_; - bool count_rows_; - int64_t num_rows_seen_; -}; - -// A function object that takes in parsed batch of CSV data and decodes it to an arrow -// record batch (ParsedBlock -> DecodedBlock) for use with MakeMappedGenerator. -class BlockDecodingOperator { - public: - Future<DecodedBlock> operator()(const ParsedBlock& block) { - DCHECK(!state_->column_decoders.empty()); - std::vector<Future<std::shared_ptr<Array>>> decoded_array_futs; - for (auto& decoder : state_->column_decoders) { - decoded_array_futs.push_back(decoder->Decode(block.parser)); - } - auto bytes_parsed_or_skipped = block.bytes_parsed_or_skipped; - auto decoded_arrays_fut = All(std::move(decoded_array_futs)); - auto state = state_; - return decoded_arrays_fut.Then( - [state, bytes_parsed_or_skipped]( - const std::vector<Result<std::shared_ptr<Array>>>& maybe_decoded_arrays) - -> Result<DecodedBlock> { - ARROW_ASSIGN_OR_RAISE(auto decoded_arrays, - internal::UnwrapOrRaise(maybe_decoded_arrays)); - - ARROW_ASSIGN_OR_RAISE(auto batch, - state->DecodedArraysToBatch(std::move(decoded_arrays))); - return DecodedBlock{std::move(batch), bytes_parsed_or_skipped}; - }); - } - - static Result<BlockDecodingOperator> Make(io::IOContext io_context, - ConvertOptions convert_options, - ConversionSchema conversion_schema) { - BlockDecodingOperator op(std::move(io_context), std::move(convert_options), - std::move(conversion_schema)); - RETURN_NOT_OK(op.state_->MakeColumnDecoders(io_context)); - return op; - } - - private: - BlockDecodingOperator(io::IOContext io_context, ConvertOptions convert_options, - ConversionSchema conversion_schema) - : state_(std::make_shared<State>(std::move(io_context), std::move(convert_options), - std::move(conversion_schema))) {} - - struct State { - State(io::IOContext io_context, ConvertOptions convert_options, - ConversionSchema conversion_schema) - : convert_options(std::move(convert_options)), - conversion_schema(std::move(conversion_schema)) {} - - Result<std::shared_ptr<RecordBatch>> DecodedArraysToBatch( - std::vector<std::shared_ptr<Array>> arrays) { - if (schema == nullptr) { - FieldVector fields(arrays.size()); - for (size_t i = 0; i < arrays.size(); ++i) { - fields[i] = field(conversion_schema.columns[i].name, arrays[i]->type()); - } - schema = arrow::schema(std::move(fields)); - } - const auto n_rows = arrays[0]->length(); - return RecordBatch::Make(schema, n_rows, std::move(arrays)); - } - - // Make column decoders from conversion schema - Status MakeColumnDecoders(io::IOContext io_context) { - for (const auto& column : conversion_schema.columns) { - std::shared_ptr<ColumnDecoder> decoder; - if (column.is_missing) { - ARROW_ASSIGN_OR_RAISE(decoder, - ColumnDecoder::MakeNull(io_context.pool(), column.type)); - } else if (column.type != nullptr) { - ARROW_ASSIGN_OR_RAISE( - decoder, ColumnDecoder::Make(io_context.pool(), column.type, column.index, - convert_options)); - } else { - ARROW_ASSIGN_OR_RAISE( - decoder, - ColumnDecoder::Make(io_context.pool(), column.index, convert_options)); - } - column_decoders.push_back(std::move(decoder)); - } - return Status::OK(); - } - - ConvertOptions convert_options; - ConversionSchema conversion_schema; - std::vector<std::shared_ptr<ColumnDecoder>> column_decoders; - std::shared_ptr<Schema> schema; - }; - - std::shared_ptr<State> state_; -}; - -///////////////////////////////////////////////////////////////////////// -// Base class for common functionality - -class ReaderMixin { - public: - ReaderMixin(io::IOContext io_context, std::shared_ptr<io::InputStream> input, - const ReadOptions& read_options, const ParseOptions& parse_options, - const ConvertOptions& convert_options, bool count_rows) - : io_context_(std::move(io_context)), - read_options_(read_options), - parse_options_(parse_options), - convert_options_(convert_options), - count_rows_(count_rows), - num_rows_seen_(count_rows_ ? 1 : -1), - input_(std::move(input)) {} - - protected: - // Read header and column names from buffer, create column builders - // Returns the # of bytes consumed - Result<int64_t> ProcessHeader(const std::shared_ptr<Buffer>& buf, - std::shared_ptr<Buffer>* rest) { - const uint8_t* data = buf->data(); - const auto data_end = data + buf->size(); - DCHECK_GT(data_end - data, 0); - - if (read_options_.skip_rows) { - // Skip initial rows (potentially invalid CSV data) - auto num_skipped_rows = SkipRows(data, static_cast<uint32_t>(data_end - data), - read_options_.skip_rows, &data); - if (num_skipped_rows < read_options_.skip_rows) { - return Status::Invalid( - "Could not skip initial ", read_options_.skip_rows, - " rows from CSV file, " - "either file is too short or header is larger than block size"); - } - if (count_rows_) { - num_rows_seen_ += num_skipped_rows; - } - } - - if (read_options_.column_names.empty()) { - // Parse one row (either to read column names or to know the number of columns) - BlockParser parser(io_context_.pool(), parse_options_, num_csv_cols_, - num_rows_seen_, 1); - uint32_t parsed_size = 0; - RETURN_NOT_OK(parser.Parse( - util::string_view(reinterpret_cast<const char*>(data), data_end - data), - &parsed_size)); - if (parser.num_rows() != 1) { - return Status::Invalid( - "Could not read first row from CSV file, either " - "file is too short or header is larger than block size"); - } - if (parser.num_cols() == 0) { - return Status::Invalid("No columns in CSV file"); - } - - if (read_options_.autogenerate_column_names) { - column_names_ = GenerateColumnNames(parser.num_cols()); - } else { - // Read column names from header row - auto visit = [&](const uint8_t* data, uint32_t size, bool quoted) -> Status { - column_names_.emplace_back(reinterpret_cast<const char*>(data), size); - return Status::OK(); - }; - RETURN_NOT_OK(parser.VisitLastRow(visit)); - DCHECK_EQ(static_cast<size_t>(parser.num_cols()), column_names_.size()); - // Skip parsed header row - data += parsed_size; - if (count_rows_) { - ++num_rows_seen_; - } - } - } else { - column_names_ = read_options_.column_names; - } - - if (count_rows_) { - // increase rows seen to skip past rows which will be skipped - num_rows_seen_ += read_options_.skip_rows_after_names; - } - - auto bytes_consumed = data - buf->data(); - *rest = SliceBuffer(buf, bytes_consumed); - - num_csv_cols_ = static_cast<int32_t>(column_names_.size()); - DCHECK_GT(num_csv_cols_, 0); - - RETURN_NOT_OK(MakeConversionSchema()); - return bytes_consumed; - } - - std::vector<std::string> GenerateColumnNames(int32_t num_cols) { - std::vector<std::string> res; - res.reserve(num_cols); - for (int32_t i = 0; i < num_cols; ++i) { - std::stringstream ss; - ss << "f" << i; - res.push_back(ss.str()); - } - return res; - } - - // Make conversion schema from options and parsed CSV header - Status MakeConversionSchema() { - // Append a column converted from CSV data - auto append_csv_column = [&](std::string col_name, int32_t col_index) { - // Does the named column have a fixed type? - auto it = convert_options_.column_types.find(col_name); - if (it == convert_options_.column_types.end()) { - conversion_schema_.columns.push_back( - ConversionSchema::InferredColumn(std::move(col_name), col_index)); - } else { - conversion_schema_.columns.push_back( - ConversionSchema::TypedColumn(std::move(col_name), col_index, it->second)); - } - }; - - // Append a column of nulls - auto append_null_column = [&](std::string col_name) { - // If the named column has a fixed type, use it, otherwise use null() - std::shared_ptr<DataType> type; - auto it = convert_options_.column_types.find(col_name); - if (it == convert_options_.column_types.end()) { - type = null(); - } else { - type = it->second; - } - conversion_schema_.columns.push_back( - ConversionSchema::NullColumn(std::move(col_name), std::move(type))); - }; - - if (convert_options_.include_columns.empty()) { - // Include all columns in CSV file order - for (int32_t col_index = 0; col_index < num_csv_cols_; ++col_index) { - append_csv_column(column_names_[col_index], col_index); - } - } else { - // Include columns from `include_columns` (in that order) - // Compute indices of columns in the CSV file - std::unordered_map<std::string, int32_t> col_indices; - col_indices.reserve(column_names_.size()); - for (int32_t i = 0; i < static_cast<int32_t>(column_names_.size()); ++i) { - col_indices.emplace(column_names_[i], i); - } - - for (const auto& col_name : convert_options_.include_columns) { - auto it = col_indices.find(col_name); - if (it != col_indices.end()) { - append_csv_column(col_name, it->second); - } else if (convert_options_.include_missing_columns) { - append_null_column(col_name); - } else { - return Status::KeyError("Column '", col_name, - "' in include_columns " - "does not exist in CSV file"); - } - } - } - return Status::OK(); - } - - struct ParseResult { - std::shared_ptr<BlockParser> parser; - int64_t parsed_bytes; - }; - - Result<ParseResult> Parse(const std::shared_ptr<Buffer>& partial, - const std::shared_ptr<Buffer>& completion, - const std::shared_ptr<Buffer>& block, int64_t block_index, - bool is_final) { - static constexpr int32_t max_num_rows = std::numeric_limits<int32_t>::max(); - auto parser = std::make_shared<BlockParser>( - io_context_.pool(), parse_options_, num_csv_cols_, num_rows_seen_, max_num_rows); - - std::shared_ptr<Buffer> straddling; - std::vector<util::string_view> views; - if (partial->size() != 0 || completion->size() != 0) { - if (partial->size() == 0) { - straddling = completion; - } else if (completion->size() == 0) { - straddling = partial; - } else { - ARROW_ASSIGN_OR_RAISE( - straddling, ConcatenateBuffers({partial, completion}, io_context_.pool())); - } - views = {util::string_view(*straddling), util::string_view(*block)}; - } else { - views = {util::string_view(*block)}; - } - uint32_t parsed_size; - if (is_final) { - RETURN_NOT_OK(parser->ParseFinal(views, &parsed_size)); - } else { - RETURN_NOT_OK(parser->Parse(views, &parsed_size)); - } - if (count_rows_) { - num_rows_seen_ += parser->num_rows(); - } - return ParseResult{std::move(parser), static_cast<int64_t>(parsed_size)}; - } - - io::IOContext io_context_; - ReadOptions read_options_; - ParseOptions parse_options_; - ConvertOptions convert_options_; - - // Number of columns in the CSV file - int32_t num_csv_cols_ = -1; - // Whether num_rows_seen_ tracks the number of rows seen in the CSV being parsed - bool count_rows_; - // Number of rows seen in the csv. Not used if count_rows is false - int64_t num_rows_seen_; - // Column names in the CSV file - std::vector<std::string> column_names_; - ConversionSchema conversion_schema_; - - std::shared_ptr<io::InputStream> input_; - std::shared_ptr<internal::TaskGroup> task_group_; -}; - -///////////////////////////////////////////////////////////////////////// -// Base class for one-shot table readers - -class BaseTableReader : public ReaderMixin, public csv::TableReader { - public: - using ReaderMixin::ReaderMixin; - - virtual Status Init() = 0; - - Future<std::shared_ptr<Table>> ReadAsync() override { - return Future<std::shared_ptr<Table>>::MakeFinished(Read()); - } - - protected: - // Make column builders from conversion schema - Status MakeColumnBuilders() { - for (const auto& column : conversion_schema_.columns) { - std::shared_ptr<ColumnBuilder> builder; - if (column.is_missing) { - ARROW_ASSIGN_OR_RAISE(builder, ColumnBuilder::MakeNull(io_context_.pool(), - column.type, task_group_)); - } else if (column.type != nullptr) { - ARROW_ASSIGN_OR_RAISE( - builder, ColumnBuilder::Make(io_context_.pool(), column.type, column.index, - convert_options_, task_group_)); - } else { - ARROW_ASSIGN_OR_RAISE(builder, - ColumnBuilder::Make(io_context_.pool(), column.index, - convert_options_, task_group_)); - } - column_builders_.push_back(std::move(builder)); - } - return Status::OK(); - } - - Result<int64_t> ParseAndInsert(const std::shared_ptr<Buffer>& partial, - const std::shared_ptr<Buffer>& completion, - const std::shared_ptr<Buffer>& block, - int64_t block_index, bool is_final) { - ARROW_ASSIGN_OR_RAISE(auto result, - Parse(partial, completion, block, block_index, is_final)); - RETURN_NOT_OK(ProcessData(result.parser, block_index)); - return result.parsed_bytes; - } - - // Trigger conversion of parsed block data - Status ProcessData(const std::shared_ptr<BlockParser>& parser, int64_t block_index) { - for (auto& builder : column_builders_) { - builder->Insert(block_index, parser); - } - return Status::OK(); - } - - Result<std::shared_ptr<Table>> MakeTable() { - DCHECK_EQ(column_builders_.size(), conversion_schema_.columns.size()); - - std::vector<std::shared_ptr<Field>> fields; - std::vector<std::shared_ptr<ChunkedArray>> columns; - - for (int32_t i = 0; i < static_cast<int32_t>(column_builders_.size()); ++i) { - const auto& column = conversion_schema_.columns[i]; - ARROW_ASSIGN_OR_RAISE(auto array, column_builders_[i]->Finish()); - fields.push_back(::arrow::field(column.name, array->type())); - columns.emplace_back(std::move(array)); - } - return Table::Make(schema(std::move(fields)), std::move(columns)); - } - - // Column builders for target Table (in ConversionSchema order) - std::vector<std::shared_ptr<ColumnBuilder>> column_builders_; -}; - -///////////////////////////////////////////////////////////////////////// -// Base class for streaming readers - -class StreamingReaderImpl : public ReaderMixin, - public csv::StreamingReader, - public std::enable_shared_from_this<StreamingReaderImpl> { - public: - StreamingReaderImpl(io::IOContext io_context, std::shared_ptr<io::InputStream> input, - const ReadOptions& read_options, const ParseOptions& parse_options, - const ConvertOptions& convert_options, bool count_rows) - : ReaderMixin(io_context, std::move(input), read_options, parse_options, - convert_options, count_rows), - bytes_decoded_(std::make_shared<std::atomic<int64_t>>(0)) {} - - Future<> Init(Executor* cpu_executor) { - ARROW_ASSIGN_OR_RAISE(auto istream_it, - io::MakeInputStreamIterator(input_, read_options_.block_size)); - - // TODO Consider exposing readahead as a read option (ARROW-12090) - ARROW_ASSIGN_OR_RAISE(auto bg_it, MakeBackgroundGenerator(std::move(istream_it), - io_context_.executor())); - - auto transferred_it = MakeTransferredGenerator(bg_it, cpu_executor); - - auto buffer_generator = CSVBufferIterator::MakeAsync(std::move(transferred_it)); - - int max_readahead = cpu_executor->GetCapacity(); - auto self = shared_from_this(); - - return buffer_generator().Then([self, buffer_generator, max_readahead]( - const std::shared_ptr<Buffer>& first_buffer) { - return self->InitAfterFirstBuffer(first_buffer, buffer_generator, max_readahead); - }); - } - - std::shared_ptr<Schema> schema() const override { return schema_; } - - int64_t bytes_read() const override { return bytes_decoded_->load(); } - - Status ReadNext(std::shared_ptr<RecordBatch>* batch) override { - auto next_fut = ReadNextAsync(); - auto next_result = next_fut.result(); - return std::move(next_result).Value(batch); - } - - Future<std::shared_ptr<RecordBatch>> ReadNextAsync() override { - return record_batch_gen_(); - } - - protected: - Future<> InitAfterFirstBuffer(const std::shared_ptr<Buffer>& first_buffer, - AsyncGenerator<std::shared_ptr<Buffer>> buffer_generator, - int max_readahead) { - if (first_buffer == nullptr) { - return Status::Invalid("Empty CSV file"); - } - - std::shared_ptr<Buffer> after_header; - ARROW_ASSIGN_OR_RAISE(auto header_bytes_consumed, - ProcessHeader(first_buffer, &after_header)); - bytes_decoded_->fetch_add(header_bytes_consumed); - - auto parser_op = - BlockParsingOperator(io_context_, parse_options_, num_csv_cols_, num_rows_seen_); - ARROW_ASSIGN_OR_RAISE( - auto decoder_op, - BlockDecodingOperator::Make(io_context_, convert_options_, conversion_schema_)); - - auto block_gen = SerialBlockReader::MakeAsyncIterator( - std::move(buffer_generator), MakeChunker(parse_options_), std::move(after_header), - read_options_.skip_rows_after_names); - auto parsed_block_gen = - MakeMappedGenerator(std::move(block_gen), std::move(parser_op)); - auto rb_gen = MakeMappedGenerator(std::move(parsed_block_gen), std::move(decoder_op)); - - auto self = shared_from_this(); - return rb_gen().Then([self, rb_gen, max_readahead](const DecodedBlock& first_block) { - return self->InitAfterFirstBatch(first_block, std::move(rb_gen), max_readahead); - }); - } - - Status InitAfterFirstBatch(const DecodedBlock& first_block, - AsyncGenerator<DecodedBlock> batch_gen, int max_readahead) { - schema_ = first_block.record_batch->schema(); - - AsyncGenerator<DecodedBlock> readahead_gen; - if (read_options_.use_threads) { - readahead_gen = MakeReadaheadGenerator(std::move(batch_gen), max_readahead); - } else { - readahead_gen = std::move(batch_gen); - } - - AsyncGenerator<DecodedBlock> restarted_gen; - // Streaming reader should not emit empty record batches - if (first_block.record_batch->num_rows() > 0) { - restarted_gen = MakeGeneratorStartsWith({first_block}, std::move(readahead_gen)); - } else { - restarted_gen = std::move(readahead_gen); - } - - auto bytes_decoded = bytes_decoded_; - auto unwrap_and_record_bytes = - [bytes_decoded]( - const DecodedBlock& block) -> Result<std::shared_ptr<RecordBatch>> { - bytes_decoded->fetch_add(block.bytes_processed); - return block.record_batch; - }; - - auto unwrapped = - MakeMappedGenerator(std::move(restarted_gen), std::move(unwrap_and_record_bytes)); - - record_batch_gen_ = MakeCancellable(std::move(unwrapped), io_context_.stop_token()); - return Status::OK(); - } - - std::shared_ptr<Schema> schema_; - AsyncGenerator<std::shared_ptr<RecordBatch>> record_batch_gen_; - // bytes which have been decoded and asked for by the caller - std::shared_ptr<std::atomic<int64_t>> bytes_decoded_; -}; - -///////////////////////////////////////////////////////////////////////// -// Serial TableReader implementation - -class SerialTableReader : public BaseTableReader { - public: - using BaseTableReader::BaseTableReader; - - Status Init() override { - ARROW_ASSIGN_OR_RAISE(auto istream_it, - io::MakeInputStreamIterator(input_, read_options_.block_size)); - - // Since we're converting serially, no need to readahead more than one block - int32_t block_queue_size = 1; - ARROW_ASSIGN_OR_RAISE(auto rh_it, - MakeReadaheadIterator(std::move(istream_it), block_queue_size)); - buffer_iterator_ = CSVBufferIterator::Make(std::move(rh_it)); - return Status::OK(); - } - - Result<std::shared_ptr<Table>> Read() override { - task_group_ = internal::TaskGroup::MakeSerial(io_context_.stop_token()); - - // First block - ARROW_ASSIGN_OR_RAISE(auto first_buffer, buffer_iterator_.Next()); - if (first_buffer == nullptr) { - return Status::Invalid("Empty CSV file"); - } - RETURN_NOT_OK(ProcessHeader(first_buffer, &first_buffer)); - RETURN_NOT_OK(MakeColumnBuilders()); - - auto block_iterator = SerialBlockReader::MakeIterator( - std::move(buffer_iterator_), MakeChunker(parse_options_), std::move(first_buffer), - read_options_.skip_rows_after_names); - while (true) { - RETURN_NOT_OK(io_context_.stop_token().Poll()); - - ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_iterator.Next()); - if (IsIterationEnd(maybe_block)) { - // EOF - break; - } - ARROW_ASSIGN_OR_RAISE( - int64_t parsed_bytes, - ParseAndInsert(maybe_block.partial, maybe_block.completion, maybe_block.buffer, - maybe_block.block_index, maybe_block.is_final)); - RETURN_NOT_OK(maybe_block.consume_bytes(parsed_bytes)); - } - // Finish conversion, create schema and table - RETURN_NOT_OK(task_group_->Finish()); - return MakeTable(); - } - - protected: - Iterator<std::shared_ptr<Buffer>> buffer_iterator_; -}; - -class AsyncThreadedTableReader - : public BaseTableReader, - public std::enable_shared_from_this<AsyncThreadedTableReader> { - public: - using BaseTableReader::BaseTableReader; - - AsyncThreadedTableReader(io::IOContext io_context, - std::shared_ptr<io::InputStream> input, - const ReadOptions& read_options, - const ParseOptions& parse_options, - const ConvertOptions& convert_options, Executor* cpu_executor) - // Count rows is currently not supported during parallel read - : BaseTableReader(std::move(io_context), input, read_options, parse_options, - convert_options, /*count_rows=*/false), - cpu_executor_(cpu_executor) {} - - ~AsyncThreadedTableReader() override { - if (task_group_) { - // In case of error, make sure all pending tasks are finished before - // we start destroying BaseTableReader members - ARROW_UNUSED(task_group_->Finish()); - } - } - - Status Init() override { - ARROW_ASSIGN_OR_RAISE(auto istream_it, - io::MakeInputStreamIterator(input_, read_options_.block_size)); - - int max_readahead = cpu_executor_->GetCapacity(); - int readahead_restart = std::max(1, max_readahead / 2); - - ARROW_ASSIGN_OR_RAISE( - auto bg_it, MakeBackgroundGenerator(std::move(istream_it), io_context_.executor(), - max_readahead, readahead_restart)); - - auto transferred_it = MakeTransferredGenerator(bg_it, cpu_executor_); - buffer_generator_ = CSVBufferIterator::MakeAsync(std::move(transferred_it)); - return Status::OK(); - } - - Result<std::shared_ptr<Table>> Read() override { return ReadAsync().result(); } - - Future<std::shared_ptr<Table>> ReadAsync() override { - task_group_ = - internal::TaskGroup::MakeThreaded(cpu_executor_, io_context_.stop_token()); - - auto self = shared_from_this(); - return ProcessFirstBuffer().Then([self](const std::shared_ptr<Buffer>& first_buffer) { - auto block_generator = ThreadedBlockReader::MakeAsyncIterator( - self->buffer_generator_, MakeChunker(self->parse_options_), - std::move(first_buffer), self->read_options_.skip_rows_after_names); - - std::function<Status(CSVBlock)> block_visitor = - [self](CSVBlock maybe_block) -> Status { - // The logic in VisitAsyncGenerator ensures that we will never be - // passed an empty block (visit does not call with the end token) so - // we can be assured maybe_block has a value. - DCHECK_GE(maybe_block.block_index, 0); - DCHECK(!maybe_block.consume_bytes); - - // Launch parse task - self->task_group_->Append([self, maybe_block] { - return self - ->ParseAndInsert(maybe_block.partial, maybe_block.completion, - maybe_block.buffer, maybe_block.block_index, - maybe_block.is_final) - .status(); - }); - return Status::OK(); - }; - - return VisitAsyncGenerator(std::move(block_generator), block_visitor) - .Then([self]() -> Future<> { - // By this point we've added all top level tasks so it is safe to call - // FinishAsync - return self->task_group_->FinishAsync(); - }) - .Then([self]() -> Result<std::shared_ptr<Table>> { - // Finish conversion, create schema and table - return self->MakeTable(); - }); - }); - } - - protected: - Future<std::shared_ptr<Buffer>> ProcessFirstBuffer() { - // First block - auto first_buffer_future = buffer_generator_(); - return first_buffer_future.Then([this](const std::shared_ptr<Buffer>& first_buffer) - -> Result<std::shared_ptr<Buffer>> { - if (first_buffer == nullptr) { - return Status::Invalid("Empty CSV file"); - } - std::shared_ptr<Buffer> first_buffer_processed; - RETURN_NOT_OK(ProcessHeader(first_buffer, &first_buffer_processed)); - RETURN_NOT_OK(MakeColumnBuilders()); - return first_buffer_processed; - }); - } - - Executor* cpu_executor_; - AsyncGenerator<std::shared_ptr<Buffer>> buffer_generator_; -}; - -Result<std::shared_ptr<TableReader>> MakeTableReader( - MemoryPool* pool, io::IOContext io_context, std::shared_ptr<io::InputStream> input, - const ReadOptions& read_options, const ParseOptions& parse_options, - const ConvertOptions& convert_options) { - RETURN_NOT_OK(parse_options.Validate()); - RETURN_NOT_OK(read_options.Validate()); - RETURN_NOT_OK(convert_options.Validate()); - std::shared_ptr<BaseTableReader> reader; - if (read_options.use_threads) { - auto cpu_executor = internal::GetCpuThreadPool(); - reader = std::make_shared<AsyncThreadedTableReader>( - io_context, input, read_options, parse_options, convert_options, cpu_executor); - } else { - reader = std::make_shared<SerialTableReader>(io_context, input, read_options, - parse_options, convert_options, - /*count_rows=*/true); - } - RETURN_NOT_OK(reader->Init()); - return reader; -} - -Future<std::shared_ptr<StreamingReader>> MakeStreamingReader( - io::IOContext io_context, std::shared_ptr<io::InputStream> input, - internal::Executor* cpu_executor, const ReadOptions& read_options, - const ParseOptions& parse_options, const ConvertOptions& convert_options) { - RETURN_NOT_OK(parse_options.Validate()); - RETURN_NOT_OK(read_options.Validate()); - RETURN_NOT_OK(convert_options.Validate()); - std::shared_ptr<StreamingReaderImpl> reader; - reader = std::make_shared<StreamingReaderImpl>( - io_context, input, read_options, parse_options, convert_options, - /*count_rows=*/!read_options.use_threads || cpu_executor->GetCapacity() == 1); - return reader->Init(cpu_executor).Then([reader] { - return std::dynamic_pointer_cast<StreamingReader>(reader); - }); -} - -///////////////////////////////////////////////////////////////////////// -// Row count implementation - -class CSVRowCounter : public ReaderMixin, - public std::enable_shared_from_this<CSVRowCounter> { - public: - CSVRowCounter(io::IOContext io_context, Executor* cpu_executor, - std::shared_ptr<io::InputStream> input, const ReadOptions& read_options, - const ParseOptions& parse_options) - : ReaderMixin(io_context, std::move(input), read_options, parse_options, - ConvertOptions::Defaults(), /*count_rows=*/true), - cpu_executor_(cpu_executor), - row_count_(0) {} - - Future<int64_t> Count() { - auto self = shared_from_this(); - return Init(self).Then([self]() { return self->DoCount(self); }); - } - - private: - Future<> Init(const std::shared_ptr<CSVRowCounter>& self) { - ARROW_ASSIGN_OR_RAISE(auto istream_it, - io::MakeInputStreamIterator(input_, read_options_.block_size)); - // TODO Consider exposing readahead as a read option (ARROW-12090) - ARROW_ASSIGN_OR_RAISE(auto bg_it, MakeBackgroundGenerator(std::move(istream_it), - io_context_.executor())); - auto transferred_it = MakeTransferredGenerator(bg_it, cpu_executor_); - auto buffer_generator = CSVBufferIterator::MakeAsync(std::move(transferred_it)); - - return buffer_generator().Then( - [self, buffer_generator](std::shared_ptr<Buffer> first_buffer) { - if (!first_buffer) { - return Status::Invalid("Empty CSV file"); - } - RETURN_NOT_OK(self->ProcessHeader(first_buffer, &first_buffer)); - self->block_generator_ = SerialBlockReader::MakeAsyncIterator( - buffer_generator, MakeChunker(self->parse_options_), - std::move(first_buffer), 0); - return Status::OK(); - }); - } - - Future<int64_t> DoCount(const std::shared_ptr<CSVRowCounter>& self) { - // count_cb must return a value instead of Status/Future<> to work with - // MakeMappedGenerator, and it must use a type with a valid end value to work with - // IterationEnd. - std::function<Result<util::optional<int64_t>>(const CSVBlock&)> count_cb = - [self](const CSVBlock& maybe_block) -> Result<util::optional<int64_t>> { - ARROW_ASSIGN_OR_RAISE( - auto parser, - self->Parse(maybe_block.partial, maybe_block.completion, maybe_block.buffer, - maybe_block.block_index, maybe_block.is_final)); - RETURN_NOT_OK(maybe_block.consume_bytes(parser.parsed_bytes)); - self->row_count_ += parser.parser->num_rows(); - return parser.parser->num_rows(); - }; - auto count_gen = MakeMappedGenerator(block_generator_, std::move(count_cb)); - return DiscardAllFromAsyncGenerator(count_gen).Then( - [self]() { return self->row_count_; }); - } - - Executor* cpu_executor_; - AsyncGenerator<CSVBlock> block_generator_; - int64_t row_count_; -}; - -} // namespace - -///////////////////////////////////////////////////////////////////////// -// Factory functions - -Result<std::shared_ptr<TableReader>> TableReader::Make( - io::IOContext io_context, std::shared_ptr<io::InputStream> input, - const ReadOptions& read_options, const ParseOptions& parse_options, - const ConvertOptions& convert_options) { - return MakeTableReader(io_context.pool(), io_context, std::move(input), read_options, - parse_options, convert_options); -} - -Result<std::shared_ptr<TableReader>> TableReader::Make( - MemoryPool* pool, io::IOContext io_context, std::shared_ptr<io::InputStream> input, - const ReadOptions& read_options, const ParseOptions& parse_options, - const ConvertOptions& convert_options) { - return MakeTableReader(pool, io_context, std::move(input), read_options, parse_options, - convert_options); -} - -Result<std::shared_ptr<StreamingReader>> StreamingReader::Make( - MemoryPool* pool, std::shared_ptr<io::InputStream> input, - const ReadOptions& read_options, const ParseOptions& parse_options, - const ConvertOptions& convert_options) { - auto io_context = io::IOContext(pool); - auto cpu_executor = internal::GetCpuThreadPool(); - auto reader_fut = MakeStreamingReader(io_context, std::move(input), cpu_executor, - read_options, parse_options, convert_options); - auto reader_result = reader_fut.result(); - ARROW_ASSIGN_OR_RAISE(auto reader, reader_result); - return reader; -} - -Result<std::shared_ptr<StreamingReader>> StreamingReader::Make( - io::IOContext io_context, std::shared_ptr<io::InputStream> input, - const ReadOptions& read_options, const ParseOptions& parse_options, - const ConvertOptions& convert_options) { - auto cpu_executor = internal::GetCpuThreadPool(); - auto reader_fut = MakeStreamingReader(io_context, std::move(input), cpu_executor, - read_options, parse_options, convert_options); - auto reader_result = reader_fut.result(); - ARROW_ASSIGN_OR_RAISE(auto reader, reader_result); - return reader; -} - -Future<std::shared_ptr<StreamingReader>> StreamingReader::MakeAsync( - io::IOContext io_context, std::shared_ptr<io::InputStream> input, - internal::Executor* cpu_executor, const ReadOptions& read_options, - const ParseOptions& parse_options, const ConvertOptions& convert_options) { - return MakeStreamingReader(io_context, std::move(input), cpu_executor, read_options, - parse_options, convert_options); -} - -Future<int64_t> CountRowsAsync(io::IOContext io_context, - std::shared_ptr<io::InputStream> input, - internal::Executor* cpu_executor, - const ReadOptions& read_options, - const ParseOptions& parse_options) { - RETURN_NOT_OK(parse_options.Validate()); - RETURN_NOT_OK(read_options.Validate()); - auto counter = std::make_shared<CSVRowCounter>( - io_context, cpu_executor, std::move(input), read_options, parse_options); - return counter->Count(); -} - -} // namespace csv - -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/csv/reader.h" + +#include <cstdint> +#include <cstring> +#include <functional> +#include <limits> +#include <memory> +#include <sstream> +#include <string> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "arrow/array.h" +#include "arrow/buffer.h" +#include "arrow/csv/chunker.h" +#include "arrow/csv/column_builder.h" +#include "arrow/csv/column_decoder.h" +#include "arrow/csv/options.h" +#include "arrow/csv/parser.h" +#include "arrow/io/interfaces.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/table.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" +#include "arrow/util/async_generator.h" +#include "arrow/util/future.h" +#include "arrow/util/iterator.h" +#include "arrow/util/logging.h" +#include "arrow/util/macros.h" +#include "arrow/util/optional.h" +#include "arrow/util/task_group.h" +#include "arrow/util/thread_pool.h" +#include "arrow/util/utf8.h" +#include "arrow/util/vector.h" + +namespace arrow { +namespace csv { + +using internal::Executor; + +namespace { + +struct ConversionSchema { + struct Column { + std::string name; + // Physical column index in CSV file + int32_t index; + // If true, make a column of nulls + bool is_missing; + // If set, convert the CSV column to this type + // If unset (and is_missing is false), infer the type from the CSV column + std::shared_ptr<DataType> type; + }; + + static Column NullColumn(std::string col_name, std::shared_ptr<DataType> type) { + return Column{std::move(col_name), -1, true, std::move(type)}; + } + + static Column TypedColumn(std::string col_name, int32_t col_index, + std::shared_ptr<DataType> type) { + return Column{std::move(col_name), col_index, false, std::move(type)}; + } + + static Column InferredColumn(std::string col_name, int32_t col_index) { + return Column{std::move(col_name), col_index, false, nullptr}; + } + + std::vector<Column> columns; +}; + +// An iterator of Buffers that makes sure there is no straddling CRLF sequence. +class CSVBufferIterator { + public: + static Iterator<std::shared_ptr<Buffer>> Make( + Iterator<std::shared_ptr<Buffer>> buffer_iterator) { + Transformer<std::shared_ptr<Buffer>, std::shared_ptr<Buffer>> fn = + CSVBufferIterator(); + return MakeTransformedIterator(std::move(buffer_iterator), fn); + } + + static AsyncGenerator<std::shared_ptr<Buffer>> MakeAsync( + AsyncGenerator<std::shared_ptr<Buffer>> buffer_iterator) { + Transformer<std::shared_ptr<Buffer>, std::shared_ptr<Buffer>> fn = + CSVBufferIterator(); + return MakeTransformedGenerator(std::move(buffer_iterator), fn); + } + + Result<TransformFlow<std::shared_ptr<Buffer>>> operator()(std::shared_ptr<Buffer> buf) { + if (buf == nullptr) { + // EOF + return TransformFinish(); + } + + int64_t offset = 0; + if (first_buffer_) { + ARROW_ASSIGN_OR_RAISE(auto data, util::SkipUTF8BOM(buf->data(), buf->size())); + offset += data - buf->data(); + DCHECK_GE(offset, 0); + first_buffer_ = false; + } + + if (trailing_cr_ && buf->data()[offset] == '\n') { + // Skip '\r\n' line separator that started at the end of previous buffer + ++offset; + } + + trailing_cr_ = (buf->data()[buf->size() - 1] == '\r'); + buf = SliceBuffer(buf, offset); + if (buf->size() == 0) { + // EOF + return TransformFinish(); + } else { + return TransformYield(buf); + } + } + + protected: + bool first_buffer_ = true; + // Whether there was a trailing CR at the end of last received buffer + bool trailing_cr_ = false; +}; + +struct CSVBlock { + // (partial + completion + buffer) is an entire delimited CSV buffer. + std::shared_ptr<Buffer> partial; + std::shared_ptr<Buffer> completion; + std::shared_ptr<Buffer> buffer; + int64_t block_index; + bool is_final; + int64_t bytes_skipped; + std::function<Status(int64_t)> consume_bytes; +}; + +} // namespace +} // namespace csv + +template <> +struct IterationTraits<csv::CSVBlock> { + static csv::CSVBlock End() { return csv::CSVBlock{{}, {}, {}, -1, true, 0, {}}; } + static bool IsEnd(const csv::CSVBlock& val) { return val.block_index < 0; } +}; + +namespace csv { +namespace { + +// This is a callable that can be used to transform an iterator. The source iterator +// will contain buffers of data and the output iterator will contain delimited CSV +// blocks. util::optional is used so that there is an end token (required by the +// iterator APIs (e.g. Visit)) even though an empty optional is never used in this code. +class BlockReader { + public: + BlockReader(std::unique_ptr<Chunker> chunker, std::shared_ptr<Buffer> first_buffer, + int64_t skip_rows) + : chunker_(std::move(chunker)), + partial_(std::make_shared<Buffer>("")), + buffer_(std::move(first_buffer)), + skip_rows_(skip_rows) {} + + protected: + std::unique_ptr<Chunker> chunker_; + std::shared_ptr<Buffer> partial_, buffer_; + int64_t skip_rows_; + int64_t block_index_ = 0; + // Whether there was a trailing CR at the end of last received buffer + bool trailing_cr_ = false; +}; + +// An object that reads delimited CSV blocks for serial use. +// The number of bytes consumed should be notified after each read, +// using CSVBlock::consume_bytes. +class SerialBlockReader : public BlockReader { + public: + using BlockReader::BlockReader; + + static Iterator<CSVBlock> MakeIterator( + Iterator<std::shared_ptr<Buffer>> buffer_iterator, std::unique_ptr<Chunker> chunker, + std::shared_ptr<Buffer> first_buffer, int64_t skip_rows) { + auto block_reader = + std::make_shared<SerialBlockReader>(std::move(chunker), first_buffer, skip_rows); + // Wrap shared pointer in callable + Transformer<std::shared_ptr<Buffer>, CSVBlock> block_reader_fn = + [block_reader](std::shared_ptr<Buffer> buf) { + return (*block_reader)(std::move(buf)); + }; + return MakeTransformedIterator(std::move(buffer_iterator), block_reader_fn); + } + + static AsyncGenerator<CSVBlock> MakeAsyncIterator( + AsyncGenerator<std::shared_ptr<Buffer>> buffer_generator, + std::unique_ptr<Chunker> chunker, std::shared_ptr<Buffer> first_buffer, + int64_t skip_rows) { + auto block_reader = + std::make_shared<SerialBlockReader>(std::move(chunker), first_buffer, skip_rows); + // Wrap shared pointer in callable + Transformer<std::shared_ptr<Buffer>, CSVBlock> block_reader_fn = + [block_reader](std::shared_ptr<Buffer> next) { + return (*block_reader)(std::move(next)); + }; + return MakeTransformedGenerator(std::move(buffer_generator), block_reader_fn); + } + + Result<TransformFlow<CSVBlock>> operator()(std::shared_ptr<Buffer> next_buffer) { + if (buffer_ == nullptr) { + return TransformFinish(); + } + + bool is_final = (next_buffer == nullptr); + int64_t bytes_skipped = 0; + + if (skip_rows_) { + bytes_skipped += partial_->size(); + auto orig_size = buffer_->size(); + RETURN_NOT_OK( + chunker_->ProcessSkip(partial_, buffer_, is_final, &skip_rows_, &buffer_)); + bytes_skipped += orig_size - buffer_->size(); + auto empty = std::make_shared<Buffer>(nullptr, 0); + if (skip_rows_) { + // Still have rows beyond this buffer to skip return empty block + partial_ = std::move(buffer_); + buffer_ = next_buffer; + return TransformYield<CSVBlock>(CSVBlock{empty, empty, empty, block_index_++, + is_final, bytes_skipped, + [](int64_t) { return Status::OK(); }}); + } + partial_ = std::move(empty); + } + + std::shared_ptr<Buffer> completion; + + if (is_final) { + // End of file reached => compute completion from penultimate block + RETURN_NOT_OK(chunker_->ProcessFinal(partial_, buffer_, &completion, &buffer_)); + } else { + // Get completion of partial from previous block. + RETURN_NOT_OK( + chunker_->ProcessWithPartial(partial_, buffer_, &completion, &buffer_)); + } + int64_t bytes_before_buffer = partial_->size() + completion->size(); + + auto consume_bytes = [this, bytes_before_buffer, + next_buffer](int64_t nbytes) -> Status { + DCHECK_GE(nbytes, 0); + auto offset = nbytes - bytes_before_buffer; + if (offset < 0) { + // Should not happen + return Status::Invalid("CSV parser got out of sync with chunker"); + } + partial_ = SliceBuffer(buffer_, offset); + buffer_ = next_buffer; + return Status::OK(); + }; + + return TransformYield<CSVBlock>(CSVBlock{partial_, completion, buffer_, + block_index_++, is_final, bytes_skipped, + std::move(consume_bytes)}); + } +}; + +// An object that reads delimited CSV blocks for threaded use. +class ThreadedBlockReader : public BlockReader { + public: + using BlockReader::BlockReader; + + static AsyncGenerator<CSVBlock> MakeAsyncIterator( + AsyncGenerator<std::shared_ptr<Buffer>> buffer_generator, + std::unique_ptr<Chunker> chunker, std::shared_ptr<Buffer> first_buffer, + int64_t skip_rows) { + auto block_reader = std::make_shared<ThreadedBlockReader>(std::move(chunker), + first_buffer, skip_rows); + // Wrap shared pointer in callable + Transformer<std::shared_ptr<Buffer>, CSVBlock> block_reader_fn = + [block_reader](std::shared_ptr<Buffer> next) { return (*block_reader)(next); }; + return MakeTransformedGenerator(std::move(buffer_generator), block_reader_fn); + } + + Result<TransformFlow<CSVBlock>> operator()(std::shared_ptr<Buffer> next_buffer) { + if (buffer_ == nullptr) { + // EOF + return TransformFinish(); + } + + bool is_final = (next_buffer == nullptr); + + auto current_partial = std::move(partial_); + auto current_buffer = std::move(buffer_); + int64_t bytes_skipped = 0; + + if (skip_rows_) { + auto orig_size = current_buffer->size(); + bytes_skipped = current_partial->size(); + RETURN_NOT_OK(chunker_->ProcessSkip(current_partial, current_buffer, is_final, + &skip_rows_, ¤t_buffer)); + bytes_skipped += orig_size - current_buffer->size(); + current_partial = std::make_shared<Buffer>(nullptr, 0); + if (skip_rows_) { + partial_ = std::move(current_buffer); + buffer_ = std::move(next_buffer); + return TransformYield<CSVBlock>(CSVBlock{current_partial, + current_partial, + current_partial, + block_index_++, + is_final, + bytes_skipped, + {}}); + } + } + + std::shared_ptr<Buffer> whole, completion, next_partial; + + if (is_final) { + // End of file reached => compute completion from penultimate block + RETURN_NOT_OK( + chunker_->ProcessFinal(current_partial, current_buffer, &completion, &whole)); + } else { + // Get completion of partial from previous block. + std::shared_ptr<Buffer> starts_with_whole; + // Get completion of partial from previous block. + RETURN_NOT_OK(chunker_->ProcessWithPartial(current_partial, current_buffer, + &completion, &starts_with_whole)); + + // Get a complete CSV block inside `partial + block`, and keep + // the rest for the next iteration. + RETURN_NOT_OK(chunker_->Process(starts_with_whole, &whole, &next_partial)); + } + + partial_ = std::move(next_partial); + buffer_ = std::move(next_buffer); + + return TransformYield<CSVBlock>(CSVBlock{ + current_partial, completion, whole, block_index_++, is_final, bytes_skipped, {}}); + } +}; + +struct ParsedBlock { + std::shared_ptr<BlockParser> parser; + int64_t block_index; + int64_t bytes_parsed_or_skipped; +}; + +struct DecodedBlock { + std::shared_ptr<RecordBatch> record_batch; + // Represents the number of input bytes represented by this batch + // This will include bytes skipped when skipping rows after the header + int64_t bytes_processed; +}; + +} // namespace + +} // namespace csv + +template <> +struct IterationTraits<csv::ParsedBlock> { + static csv::ParsedBlock End() { return csv::ParsedBlock{nullptr, -1, -1}; } + static bool IsEnd(const csv::ParsedBlock& val) { return val.block_index < 0; } +}; + +template <> +struct IterationTraits<csv::DecodedBlock> { + static csv::DecodedBlock End() { return csv::DecodedBlock{nullptr, -1}; } + static bool IsEnd(const csv::DecodedBlock& val) { return val.bytes_processed < 0; } +}; + +namespace csv { +namespace { + +// A function object that takes in a buffer of CSV data and returns a parsed batch of CSV +// data (CSVBlock -> ParsedBlock) for use with MakeMappedGenerator. +// The parsed batch contains a list of offsets for each of the columns so that columns +// can be individually scanned +// +// This operator is not re-entrant +class BlockParsingOperator { + public: + BlockParsingOperator(io::IOContext io_context, ParseOptions parse_options, + int num_csv_cols, int64_t first_row) + : io_context_(io_context), + parse_options_(parse_options), + num_csv_cols_(num_csv_cols), + count_rows_(first_row >= 0), + num_rows_seen_(first_row) {} + + Result<ParsedBlock> operator()(const CSVBlock& block) { + constexpr int32_t max_num_rows = std::numeric_limits<int32_t>::max(); + auto parser = std::make_shared<BlockParser>( + io_context_.pool(), parse_options_, num_csv_cols_, num_rows_seen_, max_num_rows); + + std::shared_ptr<Buffer> straddling; + std::vector<util::string_view> views; + if (block.partial->size() != 0 || block.completion->size() != 0) { + if (block.partial->size() == 0) { + straddling = block.completion; + } else if (block.completion->size() == 0) { + straddling = block.partial; + } else { + ARROW_ASSIGN_OR_RAISE( + straddling, + ConcatenateBuffers({block.partial, block.completion}, io_context_.pool())); + } + views = {util::string_view(*straddling), util::string_view(*block.buffer)}; + } else { + views = {util::string_view(*block.buffer)}; + } + uint32_t parsed_size; + if (block.is_final) { + RETURN_NOT_OK(parser->ParseFinal(views, &parsed_size)); + } else { + RETURN_NOT_OK(parser->Parse(views, &parsed_size)); + } + if (count_rows_) { + num_rows_seen_ += parser->num_rows(); + } + RETURN_NOT_OK(block.consume_bytes(parsed_size)); + return ParsedBlock{std::move(parser), block.block_index, + static_cast<int64_t>(parsed_size) + block.bytes_skipped}; + } + + private: + io::IOContext io_context_; + ParseOptions parse_options_; + int num_csv_cols_; + bool count_rows_; + int64_t num_rows_seen_; +}; + +// A function object that takes in parsed batch of CSV data and decodes it to an arrow +// record batch (ParsedBlock -> DecodedBlock) for use with MakeMappedGenerator. +class BlockDecodingOperator { + public: + Future<DecodedBlock> operator()(const ParsedBlock& block) { + DCHECK(!state_->column_decoders.empty()); + std::vector<Future<std::shared_ptr<Array>>> decoded_array_futs; + for (auto& decoder : state_->column_decoders) { + decoded_array_futs.push_back(decoder->Decode(block.parser)); + } + auto bytes_parsed_or_skipped = block.bytes_parsed_or_skipped; + auto decoded_arrays_fut = All(std::move(decoded_array_futs)); + auto state = state_; + return decoded_arrays_fut.Then( + [state, bytes_parsed_or_skipped]( + const std::vector<Result<std::shared_ptr<Array>>>& maybe_decoded_arrays) + -> Result<DecodedBlock> { + ARROW_ASSIGN_OR_RAISE(auto decoded_arrays, + internal::UnwrapOrRaise(maybe_decoded_arrays)); + + ARROW_ASSIGN_OR_RAISE(auto batch, + state->DecodedArraysToBatch(std::move(decoded_arrays))); + return DecodedBlock{std::move(batch), bytes_parsed_or_skipped}; + }); + } + + static Result<BlockDecodingOperator> Make(io::IOContext io_context, + ConvertOptions convert_options, + ConversionSchema conversion_schema) { + BlockDecodingOperator op(std::move(io_context), std::move(convert_options), + std::move(conversion_schema)); + RETURN_NOT_OK(op.state_->MakeColumnDecoders(io_context)); + return op; + } + + private: + BlockDecodingOperator(io::IOContext io_context, ConvertOptions convert_options, + ConversionSchema conversion_schema) + : state_(std::make_shared<State>(std::move(io_context), std::move(convert_options), + std::move(conversion_schema))) {} + + struct State { + State(io::IOContext io_context, ConvertOptions convert_options, + ConversionSchema conversion_schema) + : convert_options(std::move(convert_options)), + conversion_schema(std::move(conversion_schema)) {} + + Result<std::shared_ptr<RecordBatch>> DecodedArraysToBatch( + std::vector<std::shared_ptr<Array>> arrays) { + if (schema == nullptr) { + FieldVector fields(arrays.size()); + for (size_t i = 0; i < arrays.size(); ++i) { + fields[i] = field(conversion_schema.columns[i].name, arrays[i]->type()); + } + schema = arrow::schema(std::move(fields)); + } + const auto n_rows = arrays[0]->length(); + return RecordBatch::Make(schema, n_rows, std::move(arrays)); + } + + // Make column decoders from conversion schema + Status MakeColumnDecoders(io::IOContext io_context) { + for (const auto& column : conversion_schema.columns) { + std::shared_ptr<ColumnDecoder> decoder; + if (column.is_missing) { + ARROW_ASSIGN_OR_RAISE(decoder, + ColumnDecoder::MakeNull(io_context.pool(), column.type)); + } else if (column.type != nullptr) { + ARROW_ASSIGN_OR_RAISE( + decoder, ColumnDecoder::Make(io_context.pool(), column.type, column.index, + convert_options)); + } else { + ARROW_ASSIGN_OR_RAISE( + decoder, + ColumnDecoder::Make(io_context.pool(), column.index, convert_options)); + } + column_decoders.push_back(std::move(decoder)); + } + return Status::OK(); + } + + ConvertOptions convert_options; + ConversionSchema conversion_schema; + std::vector<std::shared_ptr<ColumnDecoder>> column_decoders; + std::shared_ptr<Schema> schema; + }; + + std::shared_ptr<State> state_; +}; + +///////////////////////////////////////////////////////////////////////// +// Base class for common functionality + +class ReaderMixin { + public: + ReaderMixin(io::IOContext io_context, std::shared_ptr<io::InputStream> input, + const ReadOptions& read_options, const ParseOptions& parse_options, + const ConvertOptions& convert_options, bool count_rows) + : io_context_(std::move(io_context)), + read_options_(read_options), + parse_options_(parse_options), + convert_options_(convert_options), + count_rows_(count_rows), + num_rows_seen_(count_rows_ ? 1 : -1), + input_(std::move(input)) {} + + protected: + // Read header and column names from buffer, create column builders + // Returns the # of bytes consumed + Result<int64_t> ProcessHeader(const std::shared_ptr<Buffer>& buf, + std::shared_ptr<Buffer>* rest) { + const uint8_t* data = buf->data(); + const auto data_end = data + buf->size(); + DCHECK_GT(data_end - data, 0); + + if (read_options_.skip_rows) { + // Skip initial rows (potentially invalid CSV data) + auto num_skipped_rows = SkipRows(data, static_cast<uint32_t>(data_end - data), + read_options_.skip_rows, &data); + if (num_skipped_rows < read_options_.skip_rows) { + return Status::Invalid( + "Could not skip initial ", read_options_.skip_rows, + " rows from CSV file, " + "either file is too short or header is larger than block size"); + } + if (count_rows_) { + num_rows_seen_ += num_skipped_rows; + } + } + + if (read_options_.column_names.empty()) { + // Parse one row (either to read column names or to know the number of columns) + BlockParser parser(io_context_.pool(), parse_options_, num_csv_cols_, + num_rows_seen_, 1); + uint32_t parsed_size = 0; + RETURN_NOT_OK(parser.Parse( + util::string_view(reinterpret_cast<const char*>(data), data_end - data), + &parsed_size)); + if (parser.num_rows() != 1) { + return Status::Invalid( + "Could not read first row from CSV file, either " + "file is too short or header is larger than block size"); + } + if (parser.num_cols() == 0) { + return Status::Invalid("No columns in CSV file"); + } + + if (read_options_.autogenerate_column_names) { + column_names_ = GenerateColumnNames(parser.num_cols()); + } else { + // Read column names from header row + auto visit = [&](const uint8_t* data, uint32_t size, bool quoted) -> Status { + column_names_.emplace_back(reinterpret_cast<const char*>(data), size); + return Status::OK(); + }; + RETURN_NOT_OK(parser.VisitLastRow(visit)); + DCHECK_EQ(static_cast<size_t>(parser.num_cols()), column_names_.size()); + // Skip parsed header row + data += parsed_size; + if (count_rows_) { + ++num_rows_seen_; + } + } + } else { + column_names_ = read_options_.column_names; + } + + if (count_rows_) { + // increase rows seen to skip past rows which will be skipped + num_rows_seen_ += read_options_.skip_rows_after_names; + } + + auto bytes_consumed = data - buf->data(); + *rest = SliceBuffer(buf, bytes_consumed); + + num_csv_cols_ = static_cast<int32_t>(column_names_.size()); + DCHECK_GT(num_csv_cols_, 0); + + RETURN_NOT_OK(MakeConversionSchema()); + return bytes_consumed; + } + + std::vector<std::string> GenerateColumnNames(int32_t num_cols) { + std::vector<std::string> res; + res.reserve(num_cols); + for (int32_t i = 0; i < num_cols; ++i) { + std::stringstream ss; + ss << "f" << i; + res.push_back(ss.str()); + } + return res; + } + + // Make conversion schema from options and parsed CSV header + Status MakeConversionSchema() { + // Append a column converted from CSV data + auto append_csv_column = [&](std::string col_name, int32_t col_index) { + // Does the named column have a fixed type? + auto it = convert_options_.column_types.find(col_name); + if (it == convert_options_.column_types.end()) { + conversion_schema_.columns.push_back( + ConversionSchema::InferredColumn(std::move(col_name), col_index)); + } else { + conversion_schema_.columns.push_back( + ConversionSchema::TypedColumn(std::move(col_name), col_index, it->second)); + } + }; + + // Append a column of nulls + auto append_null_column = [&](std::string col_name) { + // If the named column has a fixed type, use it, otherwise use null() + std::shared_ptr<DataType> type; + auto it = convert_options_.column_types.find(col_name); + if (it == convert_options_.column_types.end()) { + type = null(); + } else { + type = it->second; + } + conversion_schema_.columns.push_back( + ConversionSchema::NullColumn(std::move(col_name), std::move(type))); + }; + + if (convert_options_.include_columns.empty()) { + // Include all columns in CSV file order + for (int32_t col_index = 0; col_index < num_csv_cols_; ++col_index) { + append_csv_column(column_names_[col_index], col_index); + } + } else { + // Include columns from `include_columns` (in that order) + // Compute indices of columns in the CSV file + std::unordered_map<std::string, int32_t> col_indices; + col_indices.reserve(column_names_.size()); + for (int32_t i = 0; i < static_cast<int32_t>(column_names_.size()); ++i) { + col_indices.emplace(column_names_[i], i); + } + + for (const auto& col_name : convert_options_.include_columns) { + auto it = col_indices.find(col_name); + if (it != col_indices.end()) { + append_csv_column(col_name, it->second); + } else if (convert_options_.include_missing_columns) { + append_null_column(col_name); + } else { + return Status::KeyError("Column '", col_name, + "' in include_columns " + "does not exist in CSV file"); + } + } + } + return Status::OK(); + } + + struct ParseResult { + std::shared_ptr<BlockParser> parser; + int64_t parsed_bytes; + }; + + Result<ParseResult> Parse(const std::shared_ptr<Buffer>& partial, + const std::shared_ptr<Buffer>& completion, + const std::shared_ptr<Buffer>& block, int64_t block_index, + bool is_final) { + static constexpr int32_t max_num_rows = std::numeric_limits<int32_t>::max(); + auto parser = std::make_shared<BlockParser>( + io_context_.pool(), parse_options_, num_csv_cols_, num_rows_seen_, max_num_rows); + + std::shared_ptr<Buffer> straddling; + std::vector<util::string_view> views; + if (partial->size() != 0 || completion->size() != 0) { + if (partial->size() == 0) { + straddling = completion; + } else if (completion->size() == 0) { + straddling = partial; + } else { + ARROW_ASSIGN_OR_RAISE( + straddling, ConcatenateBuffers({partial, completion}, io_context_.pool())); + } + views = {util::string_view(*straddling), util::string_view(*block)}; + } else { + views = {util::string_view(*block)}; + } + uint32_t parsed_size; + if (is_final) { + RETURN_NOT_OK(parser->ParseFinal(views, &parsed_size)); + } else { + RETURN_NOT_OK(parser->Parse(views, &parsed_size)); + } + if (count_rows_) { + num_rows_seen_ += parser->num_rows(); + } + return ParseResult{std::move(parser), static_cast<int64_t>(parsed_size)}; + } + + io::IOContext io_context_; + ReadOptions read_options_; + ParseOptions parse_options_; + ConvertOptions convert_options_; + + // Number of columns in the CSV file + int32_t num_csv_cols_ = -1; + // Whether num_rows_seen_ tracks the number of rows seen in the CSV being parsed + bool count_rows_; + // Number of rows seen in the csv. Not used if count_rows is false + int64_t num_rows_seen_; + // Column names in the CSV file + std::vector<std::string> column_names_; + ConversionSchema conversion_schema_; + + std::shared_ptr<io::InputStream> input_; + std::shared_ptr<internal::TaskGroup> task_group_; +}; + +///////////////////////////////////////////////////////////////////////// +// Base class for one-shot table readers + +class BaseTableReader : public ReaderMixin, public csv::TableReader { + public: + using ReaderMixin::ReaderMixin; + + virtual Status Init() = 0; + + Future<std::shared_ptr<Table>> ReadAsync() override { + return Future<std::shared_ptr<Table>>::MakeFinished(Read()); + } + + protected: + // Make column builders from conversion schema + Status MakeColumnBuilders() { + for (const auto& column : conversion_schema_.columns) { + std::shared_ptr<ColumnBuilder> builder; + if (column.is_missing) { + ARROW_ASSIGN_OR_RAISE(builder, ColumnBuilder::MakeNull(io_context_.pool(), + column.type, task_group_)); + } else if (column.type != nullptr) { + ARROW_ASSIGN_OR_RAISE( + builder, ColumnBuilder::Make(io_context_.pool(), column.type, column.index, + convert_options_, task_group_)); + } else { + ARROW_ASSIGN_OR_RAISE(builder, + ColumnBuilder::Make(io_context_.pool(), column.index, + convert_options_, task_group_)); + } + column_builders_.push_back(std::move(builder)); + } + return Status::OK(); + } + + Result<int64_t> ParseAndInsert(const std::shared_ptr<Buffer>& partial, + const std::shared_ptr<Buffer>& completion, + const std::shared_ptr<Buffer>& block, + int64_t block_index, bool is_final) { + ARROW_ASSIGN_OR_RAISE(auto result, + Parse(partial, completion, block, block_index, is_final)); + RETURN_NOT_OK(ProcessData(result.parser, block_index)); + return result.parsed_bytes; + } + + // Trigger conversion of parsed block data + Status ProcessData(const std::shared_ptr<BlockParser>& parser, int64_t block_index) { + for (auto& builder : column_builders_) { + builder->Insert(block_index, parser); + } + return Status::OK(); + } + + Result<std::shared_ptr<Table>> MakeTable() { + DCHECK_EQ(column_builders_.size(), conversion_schema_.columns.size()); + + std::vector<std::shared_ptr<Field>> fields; + std::vector<std::shared_ptr<ChunkedArray>> columns; + + for (int32_t i = 0; i < static_cast<int32_t>(column_builders_.size()); ++i) { + const auto& column = conversion_schema_.columns[i]; + ARROW_ASSIGN_OR_RAISE(auto array, column_builders_[i]->Finish()); + fields.push_back(::arrow::field(column.name, array->type())); + columns.emplace_back(std::move(array)); + } + return Table::Make(schema(std::move(fields)), std::move(columns)); + } + + // Column builders for target Table (in ConversionSchema order) + std::vector<std::shared_ptr<ColumnBuilder>> column_builders_; +}; + +///////////////////////////////////////////////////////////////////////// +// Base class for streaming readers + +class StreamingReaderImpl : public ReaderMixin, + public csv::StreamingReader, + public std::enable_shared_from_this<StreamingReaderImpl> { + public: + StreamingReaderImpl(io::IOContext io_context, std::shared_ptr<io::InputStream> input, + const ReadOptions& read_options, const ParseOptions& parse_options, + const ConvertOptions& convert_options, bool count_rows) + : ReaderMixin(io_context, std::move(input), read_options, parse_options, + convert_options, count_rows), + bytes_decoded_(std::make_shared<std::atomic<int64_t>>(0)) {} + + Future<> Init(Executor* cpu_executor) { + ARROW_ASSIGN_OR_RAISE(auto istream_it, + io::MakeInputStreamIterator(input_, read_options_.block_size)); + + // TODO Consider exposing readahead as a read option (ARROW-12090) + ARROW_ASSIGN_OR_RAISE(auto bg_it, MakeBackgroundGenerator(std::move(istream_it), + io_context_.executor())); + + auto transferred_it = MakeTransferredGenerator(bg_it, cpu_executor); + + auto buffer_generator = CSVBufferIterator::MakeAsync(std::move(transferred_it)); + + int max_readahead = cpu_executor->GetCapacity(); + auto self = shared_from_this(); + + return buffer_generator().Then([self, buffer_generator, max_readahead]( + const std::shared_ptr<Buffer>& first_buffer) { + return self->InitAfterFirstBuffer(first_buffer, buffer_generator, max_readahead); + }); + } + + std::shared_ptr<Schema> schema() const override { return schema_; } + + int64_t bytes_read() const override { return bytes_decoded_->load(); } + + Status ReadNext(std::shared_ptr<RecordBatch>* batch) override { + auto next_fut = ReadNextAsync(); + auto next_result = next_fut.result(); + return std::move(next_result).Value(batch); + } + + Future<std::shared_ptr<RecordBatch>> ReadNextAsync() override { + return record_batch_gen_(); + } + + protected: + Future<> InitAfterFirstBuffer(const std::shared_ptr<Buffer>& first_buffer, + AsyncGenerator<std::shared_ptr<Buffer>> buffer_generator, + int max_readahead) { + if (first_buffer == nullptr) { + return Status::Invalid("Empty CSV file"); + } + + std::shared_ptr<Buffer> after_header; + ARROW_ASSIGN_OR_RAISE(auto header_bytes_consumed, + ProcessHeader(first_buffer, &after_header)); + bytes_decoded_->fetch_add(header_bytes_consumed); + + auto parser_op = + BlockParsingOperator(io_context_, parse_options_, num_csv_cols_, num_rows_seen_); + ARROW_ASSIGN_OR_RAISE( + auto decoder_op, + BlockDecodingOperator::Make(io_context_, convert_options_, conversion_schema_)); + + auto block_gen = SerialBlockReader::MakeAsyncIterator( + std::move(buffer_generator), MakeChunker(parse_options_), std::move(after_header), + read_options_.skip_rows_after_names); + auto parsed_block_gen = + MakeMappedGenerator(std::move(block_gen), std::move(parser_op)); + auto rb_gen = MakeMappedGenerator(std::move(parsed_block_gen), std::move(decoder_op)); + + auto self = shared_from_this(); + return rb_gen().Then([self, rb_gen, max_readahead](const DecodedBlock& first_block) { + return self->InitAfterFirstBatch(first_block, std::move(rb_gen), max_readahead); + }); + } + + Status InitAfterFirstBatch(const DecodedBlock& first_block, + AsyncGenerator<DecodedBlock> batch_gen, int max_readahead) { + schema_ = first_block.record_batch->schema(); + + AsyncGenerator<DecodedBlock> readahead_gen; + if (read_options_.use_threads) { + readahead_gen = MakeReadaheadGenerator(std::move(batch_gen), max_readahead); + } else { + readahead_gen = std::move(batch_gen); + } + + AsyncGenerator<DecodedBlock> restarted_gen; + // Streaming reader should not emit empty record batches + if (first_block.record_batch->num_rows() > 0) { + restarted_gen = MakeGeneratorStartsWith({first_block}, std::move(readahead_gen)); + } else { + restarted_gen = std::move(readahead_gen); + } + + auto bytes_decoded = bytes_decoded_; + auto unwrap_and_record_bytes = + [bytes_decoded]( + const DecodedBlock& block) -> Result<std::shared_ptr<RecordBatch>> { + bytes_decoded->fetch_add(block.bytes_processed); + return block.record_batch; + }; + + auto unwrapped = + MakeMappedGenerator(std::move(restarted_gen), std::move(unwrap_and_record_bytes)); + + record_batch_gen_ = MakeCancellable(std::move(unwrapped), io_context_.stop_token()); + return Status::OK(); + } + + std::shared_ptr<Schema> schema_; + AsyncGenerator<std::shared_ptr<RecordBatch>> record_batch_gen_; + // bytes which have been decoded and asked for by the caller + std::shared_ptr<std::atomic<int64_t>> bytes_decoded_; +}; + +///////////////////////////////////////////////////////////////////////// +// Serial TableReader implementation + +class SerialTableReader : public BaseTableReader { + public: + using BaseTableReader::BaseTableReader; + + Status Init() override { + ARROW_ASSIGN_OR_RAISE(auto istream_it, + io::MakeInputStreamIterator(input_, read_options_.block_size)); + + // Since we're converting serially, no need to readahead more than one block + int32_t block_queue_size = 1; + ARROW_ASSIGN_OR_RAISE(auto rh_it, + MakeReadaheadIterator(std::move(istream_it), block_queue_size)); + buffer_iterator_ = CSVBufferIterator::Make(std::move(rh_it)); + return Status::OK(); + } + + Result<std::shared_ptr<Table>> Read() override { + task_group_ = internal::TaskGroup::MakeSerial(io_context_.stop_token()); + + // First block + ARROW_ASSIGN_OR_RAISE(auto first_buffer, buffer_iterator_.Next()); + if (first_buffer == nullptr) { + return Status::Invalid("Empty CSV file"); + } + RETURN_NOT_OK(ProcessHeader(first_buffer, &first_buffer)); + RETURN_NOT_OK(MakeColumnBuilders()); + + auto block_iterator = SerialBlockReader::MakeIterator( + std::move(buffer_iterator_), MakeChunker(parse_options_), std::move(first_buffer), + read_options_.skip_rows_after_names); + while (true) { + RETURN_NOT_OK(io_context_.stop_token().Poll()); + + ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_iterator.Next()); + if (IsIterationEnd(maybe_block)) { + // EOF + break; + } + ARROW_ASSIGN_OR_RAISE( + int64_t parsed_bytes, + ParseAndInsert(maybe_block.partial, maybe_block.completion, maybe_block.buffer, + maybe_block.block_index, maybe_block.is_final)); + RETURN_NOT_OK(maybe_block.consume_bytes(parsed_bytes)); + } + // Finish conversion, create schema and table + RETURN_NOT_OK(task_group_->Finish()); + return MakeTable(); + } + + protected: + Iterator<std::shared_ptr<Buffer>> buffer_iterator_; +}; + +class AsyncThreadedTableReader + : public BaseTableReader, + public std::enable_shared_from_this<AsyncThreadedTableReader> { + public: + using BaseTableReader::BaseTableReader; + + AsyncThreadedTableReader(io::IOContext io_context, + std::shared_ptr<io::InputStream> input, + const ReadOptions& read_options, + const ParseOptions& parse_options, + const ConvertOptions& convert_options, Executor* cpu_executor) + // Count rows is currently not supported during parallel read + : BaseTableReader(std::move(io_context), input, read_options, parse_options, + convert_options, /*count_rows=*/false), + cpu_executor_(cpu_executor) {} + + ~AsyncThreadedTableReader() override { + if (task_group_) { + // In case of error, make sure all pending tasks are finished before + // we start destroying BaseTableReader members + ARROW_UNUSED(task_group_->Finish()); + } + } + + Status Init() override { + ARROW_ASSIGN_OR_RAISE(auto istream_it, + io::MakeInputStreamIterator(input_, read_options_.block_size)); + + int max_readahead = cpu_executor_->GetCapacity(); + int readahead_restart = std::max(1, max_readahead / 2); + + ARROW_ASSIGN_OR_RAISE( + auto bg_it, MakeBackgroundGenerator(std::move(istream_it), io_context_.executor(), + max_readahead, readahead_restart)); + + auto transferred_it = MakeTransferredGenerator(bg_it, cpu_executor_); + buffer_generator_ = CSVBufferIterator::MakeAsync(std::move(transferred_it)); + return Status::OK(); + } + + Result<std::shared_ptr<Table>> Read() override { return ReadAsync().result(); } + + Future<std::shared_ptr<Table>> ReadAsync() override { + task_group_ = + internal::TaskGroup::MakeThreaded(cpu_executor_, io_context_.stop_token()); + + auto self = shared_from_this(); + return ProcessFirstBuffer().Then([self](const std::shared_ptr<Buffer>& first_buffer) { + auto block_generator = ThreadedBlockReader::MakeAsyncIterator( + self->buffer_generator_, MakeChunker(self->parse_options_), + std::move(first_buffer), self->read_options_.skip_rows_after_names); + + std::function<Status(CSVBlock)> block_visitor = + [self](CSVBlock maybe_block) -> Status { + // The logic in VisitAsyncGenerator ensures that we will never be + // passed an empty block (visit does not call with the end token) so + // we can be assured maybe_block has a value. + DCHECK_GE(maybe_block.block_index, 0); + DCHECK(!maybe_block.consume_bytes); + + // Launch parse task + self->task_group_->Append([self, maybe_block] { + return self + ->ParseAndInsert(maybe_block.partial, maybe_block.completion, + maybe_block.buffer, maybe_block.block_index, + maybe_block.is_final) + .status(); + }); + return Status::OK(); + }; + + return VisitAsyncGenerator(std::move(block_generator), block_visitor) + .Then([self]() -> Future<> { + // By this point we've added all top level tasks so it is safe to call + // FinishAsync + return self->task_group_->FinishAsync(); + }) + .Then([self]() -> Result<std::shared_ptr<Table>> { + // Finish conversion, create schema and table + return self->MakeTable(); + }); + }); + } + + protected: + Future<std::shared_ptr<Buffer>> ProcessFirstBuffer() { + // First block + auto first_buffer_future = buffer_generator_(); + return first_buffer_future.Then([this](const std::shared_ptr<Buffer>& first_buffer) + -> Result<std::shared_ptr<Buffer>> { + if (first_buffer == nullptr) { + return Status::Invalid("Empty CSV file"); + } + std::shared_ptr<Buffer> first_buffer_processed; + RETURN_NOT_OK(ProcessHeader(first_buffer, &first_buffer_processed)); + RETURN_NOT_OK(MakeColumnBuilders()); + return first_buffer_processed; + }); + } + + Executor* cpu_executor_; + AsyncGenerator<std::shared_ptr<Buffer>> buffer_generator_; +}; + +Result<std::shared_ptr<TableReader>> MakeTableReader( + MemoryPool* pool, io::IOContext io_context, std::shared_ptr<io::InputStream> input, + const ReadOptions& read_options, const ParseOptions& parse_options, + const ConvertOptions& convert_options) { + RETURN_NOT_OK(parse_options.Validate()); + RETURN_NOT_OK(read_options.Validate()); + RETURN_NOT_OK(convert_options.Validate()); + std::shared_ptr<BaseTableReader> reader; + if (read_options.use_threads) { + auto cpu_executor = internal::GetCpuThreadPool(); + reader = std::make_shared<AsyncThreadedTableReader>( + io_context, input, read_options, parse_options, convert_options, cpu_executor); + } else { + reader = std::make_shared<SerialTableReader>(io_context, input, read_options, + parse_options, convert_options, + /*count_rows=*/true); + } + RETURN_NOT_OK(reader->Init()); + return reader; +} + +Future<std::shared_ptr<StreamingReader>> MakeStreamingReader( + io::IOContext io_context, std::shared_ptr<io::InputStream> input, + internal::Executor* cpu_executor, const ReadOptions& read_options, + const ParseOptions& parse_options, const ConvertOptions& convert_options) { + RETURN_NOT_OK(parse_options.Validate()); + RETURN_NOT_OK(read_options.Validate()); + RETURN_NOT_OK(convert_options.Validate()); + std::shared_ptr<StreamingReaderImpl> reader; + reader = std::make_shared<StreamingReaderImpl>( + io_context, input, read_options, parse_options, convert_options, + /*count_rows=*/!read_options.use_threads || cpu_executor->GetCapacity() == 1); + return reader->Init(cpu_executor).Then([reader] { + return std::dynamic_pointer_cast<StreamingReader>(reader); + }); +} + +///////////////////////////////////////////////////////////////////////// +// Row count implementation + +class CSVRowCounter : public ReaderMixin, + public std::enable_shared_from_this<CSVRowCounter> { + public: + CSVRowCounter(io::IOContext io_context, Executor* cpu_executor, + std::shared_ptr<io::InputStream> input, const ReadOptions& read_options, + const ParseOptions& parse_options) + : ReaderMixin(io_context, std::move(input), read_options, parse_options, + ConvertOptions::Defaults(), /*count_rows=*/true), + cpu_executor_(cpu_executor), + row_count_(0) {} + + Future<int64_t> Count() { + auto self = shared_from_this(); + return Init(self).Then([self]() { return self->DoCount(self); }); + } + + private: + Future<> Init(const std::shared_ptr<CSVRowCounter>& self) { + ARROW_ASSIGN_OR_RAISE(auto istream_it, + io::MakeInputStreamIterator(input_, read_options_.block_size)); + // TODO Consider exposing readahead as a read option (ARROW-12090) + ARROW_ASSIGN_OR_RAISE(auto bg_it, MakeBackgroundGenerator(std::move(istream_it), + io_context_.executor())); + auto transferred_it = MakeTransferredGenerator(bg_it, cpu_executor_); + auto buffer_generator = CSVBufferIterator::MakeAsync(std::move(transferred_it)); + + return buffer_generator().Then( + [self, buffer_generator](std::shared_ptr<Buffer> first_buffer) { + if (!first_buffer) { + return Status::Invalid("Empty CSV file"); + } + RETURN_NOT_OK(self->ProcessHeader(first_buffer, &first_buffer)); + self->block_generator_ = SerialBlockReader::MakeAsyncIterator( + buffer_generator, MakeChunker(self->parse_options_), + std::move(first_buffer), 0); + return Status::OK(); + }); + } + + Future<int64_t> DoCount(const std::shared_ptr<CSVRowCounter>& self) { + // count_cb must return a value instead of Status/Future<> to work with + // MakeMappedGenerator, and it must use a type with a valid end value to work with + // IterationEnd. + std::function<Result<util::optional<int64_t>>(const CSVBlock&)> count_cb = + [self](const CSVBlock& maybe_block) -> Result<util::optional<int64_t>> { + ARROW_ASSIGN_OR_RAISE( + auto parser, + self->Parse(maybe_block.partial, maybe_block.completion, maybe_block.buffer, + maybe_block.block_index, maybe_block.is_final)); + RETURN_NOT_OK(maybe_block.consume_bytes(parser.parsed_bytes)); + self->row_count_ += parser.parser->num_rows(); + return parser.parser->num_rows(); + }; + auto count_gen = MakeMappedGenerator(block_generator_, std::move(count_cb)); + return DiscardAllFromAsyncGenerator(count_gen).Then( + [self]() { return self->row_count_; }); + } + + Executor* cpu_executor_; + AsyncGenerator<CSVBlock> block_generator_; + int64_t row_count_; +}; + +} // namespace + +///////////////////////////////////////////////////////////////////////// +// Factory functions + +Result<std::shared_ptr<TableReader>> TableReader::Make( + io::IOContext io_context, std::shared_ptr<io::InputStream> input, + const ReadOptions& read_options, const ParseOptions& parse_options, + const ConvertOptions& convert_options) { + return MakeTableReader(io_context.pool(), io_context, std::move(input), read_options, + parse_options, convert_options); +} + +Result<std::shared_ptr<TableReader>> TableReader::Make( + MemoryPool* pool, io::IOContext io_context, std::shared_ptr<io::InputStream> input, + const ReadOptions& read_options, const ParseOptions& parse_options, + const ConvertOptions& convert_options) { + return MakeTableReader(pool, io_context, std::move(input), read_options, parse_options, + convert_options); +} + +Result<std::shared_ptr<StreamingReader>> StreamingReader::Make( + MemoryPool* pool, std::shared_ptr<io::InputStream> input, + const ReadOptions& read_options, const ParseOptions& parse_options, + const ConvertOptions& convert_options) { + auto io_context = io::IOContext(pool); + auto cpu_executor = internal::GetCpuThreadPool(); + auto reader_fut = MakeStreamingReader(io_context, std::move(input), cpu_executor, + read_options, parse_options, convert_options); + auto reader_result = reader_fut.result(); + ARROW_ASSIGN_OR_RAISE(auto reader, reader_result); + return reader; +} + +Result<std::shared_ptr<StreamingReader>> StreamingReader::Make( + io::IOContext io_context, std::shared_ptr<io::InputStream> input, + const ReadOptions& read_options, const ParseOptions& parse_options, + const ConvertOptions& convert_options) { + auto cpu_executor = internal::GetCpuThreadPool(); + auto reader_fut = MakeStreamingReader(io_context, std::move(input), cpu_executor, + read_options, parse_options, convert_options); + auto reader_result = reader_fut.result(); + ARROW_ASSIGN_OR_RAISE(auto reader, reader_result); + return reader; +} + +Future<std::shared_ptr<StreamingReader>> StreamingReader::MakeAsync( + io::IOContext io_context, std::shared_ptr<io::InputStream> input, + internal::Executor* cpu_executor, const ReadOptions& read_options, + const ParseOptions& parse_options, const ConvertOptions& convert_options) { + return MakeStreamingReader(io_context, std::move(input), cpu_executor, read_options, + parse_options, convert_options); +} + +Future<int64_t> CountRowsAsync(io::IOContext io_context, + std::shared_ptr<io::InputStream> input, + internal::Executor* cpu_executor, + const ReadOptions& read_options, + const ParseOptions& parse_options) { + RETURN_NOT_OK(parse_options.Validate()); + RETURN_NOT_OK(read_options.Validate()); + auto counter = std::make_shared<CSVRowCounter>( + io_context, cpu_executor, std::move(input), read_options, parse_options); + return counter->Count(); +} + +} // namespace csv + +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/reader.h b/contrib/libs/apache/arrow/cpp/src/arrow/csv/reader.h index 48f02882b1..b1c1749f4b 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/reader.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/reader.h @@ -1,123 +1,123 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <memory> - -#include "arrow/csv/options.h" // IWYU pragma: keep -#include "arrow/io/interfaces.h" -#include "arrow/record_batch.h" -#include "arrow/result.h" -#include "arrow/type.h" -#include "arrow/type_fwd.h" -#include "arrow/util/future.h" -#include "arrow/util/thread_pool.h" -#include "arrow/util/visibility.h" - -namespace arrow { -namespace io { -class InputStream; -} // namespace io - -namespace csv { - -/// A class that reads an entire CSV file into a Arrow Table -class ARROW_EXPORT TableReader { - public: - virtual ~TableReader() = default; - - /// Read the entire CSV file and convert it to a Arrow Table - virtual Result<std::shared_ptr<Table>> Read() = 0; - /// Read the entire CSV file and convert it to a Arrow Table - virtual Future<std::shared_ptr<Table>> ReadAsync() = 0; - - /// Create a TableReader instance - static Result<std::shared_ptr<TableReader>> Make(io::IOContext io_context, - std::shared_ptr<io::InputStream> input, - const ReadOptions&, - const ParseOptions&, - const ConvertOptions&); - - ARROW_DEPRECATED("Use MemoryPool-less variant (the IOContext holds a pool already)") - static Result<std::shared_ptr<TableReader>> Make( - MemoryPool* pool, io::IOContext io_context, std::shared_ptr<io::InputStream> input, - const ReadOptions&, const ParseOptions&, const ConvertOptions&); -}; - -/// \brief A class that reads a CSV file incrementally -/// -/// Caveats: -/// - For now, this is always single-threaded (regardless of `ReadOptions::use_threads`. -/// - Type inference is done on the first block and types are frozen afterwards; -/// to make sure the right data types are inferred, either set -/// `ReadOptions::block_size` to a large enough value, or use -/// `ConvertOptions::column_types` to set the desired data types explicitly. -class ARROW_EXPORT StreamingReader : public RecordBatchReader { - public: - virtual ~StreamingReader() = default; - - virtual Future<std::shared_ptr<RecordBatch>> ReadNextAsync() = 0; - - /// \brief Return the number of bytes which have been read and processed - /// - /// The returned number includes CSV bytes which the StreamingReader has - /// finished processing, but not bytes for which some processing (e.g. - /// CSV parsing or conversion to Arrow layout) is still ongoing. - /// - /// Furthermore, the following rules apply: - /// - bytes skipped by `ReadOptions.skip_rows` are counted as being read before - /// any records are returned. - /// - bytes read while parsing the header are counted as being read before any - /// records are returned. - /// - bytes skipped by `ReadOptions.skip_rows_after_names` are counted after the - /// first batch is returned. - virtual int64_t bytes_read() const = 0; - - /// Create a StreamingReader instance - /// - /// This involves some I/O as the first batch must be loaded during the creation process - /// so it is returned as a future - /// - /// Currently, the StreamingReader is not async-reentrant and does not do any fan-out - /// parsing (see ARROW-11889) - static Future<std::shared_ptr<StreamingReader>> MakeAsync( - io::IOContext io_context, std::shared_ptr<io::InputStream> input, - internal::Executor* cpu_executor, const ReadOptions&, const ParseOptions&, - const ConvertOptions&); - - static Result<std::shared_ptr<StreamingReader>> Make( - io::IOContext io_context, std::shared_ptr<io::InputStream> input, - const ReadOptions&, const ParseOptions&, const ConvertOptions&); - - ARROW_DEPRECATED("Use IOContext-based overload") - static Result<std::shared_ptr<StreamingReader>> Make( - MemoryPool* pool, std::shared_ptr<io::InputStream> input, - const ReadOptions& read_options, const ParseOptions& parse_options, - const ConvertOptions& convert_options); -}; - -/// \brief Count the logical rows of data in a CSV file (i.e. the -/// number of rows you would get if you read the file into a table). -ARROW_EXPORT -Future<int64_t> CountRowsAsync(io::IOContext io_context, - std::shared_ptr<io::InputStream> input, - internal::Executor* cpu_executor, const ReadOptions&, - const ParseOptions&); - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> + +#include "arrow/csv/options.h" // IWYU pragma: keep +#include "arrow/io/interfaces.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" +#include "arrow/util/future.h" +#include "arrow/util/thread_pool.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace io { +class InputStream; +} // namespace io + +namespace csv { + +/// A class that reads an entire CSV file into a Arrow Table +class ARROW_EXPORT TableReader { + public: + virtual ~TableReader() = default; + + /// Read the entire CSV file and convert it to a Arrow Table + virtual Result<std::shared_ptr<Table>> Read() = 0; + /// Read the entire CSV file and convert it to a Arrow Table + virtual Future<std::shared_ptr<Table>> ReadAsync() = 0; + + /// Create a TableReader instance + static Result<std::shared_ptr<TableReader>> Make(io::IOContext io_context, + std::shared_ptr<io::InputStream> input, + const ReadOptions&, + const ParseOptions&, + const ConvertOptions&); + + ARROW_DEPRECATED("Use MemoryPool-less variant (the IOContext holds a pool already)") + static Result<std::shared_ptr<TableReader>> Make( + MemoryPool* pool, io::IOContext io_context, std::shared_ptr<io::InputStream> input, + const ReadOptions&, const ParseOptions&, const ConvertOptions&); +}; + +/// \brief A class that reads a CSV file incrementally +/// +/// Caveats: +/// - For now, this is always single-threaded (regardless of `ReadOptions::use_threads`. +/// - Type inference is done on the first block and types are frozen afterwards; +/// to make sure the right data types are inferred, either set +/// `ReadOptions::block_size` to a large enough value, or use +/// `ConvertOptions::column_types` to set the desired data types explicitly. +class ARROW_EXPORT StreamingReader : public RecordBatchReader { + public: + virtual ~StreamingReader() = default; + + virtual Future<std::shared_ptr<RecordBatch>> ReadNextAsync() = 0; + + /// \brief Return the number of bytes which have been read and processed + /// + /// The returned number includes CSV bytes which the StreamingReader has + /// finished processing, but not bytes for which some processing (e.g. + /// CSV parsing or conversion to Arrow layout) is still ongoing. + /// + /// Furthermore, the following rules apply: + /// - bytes skipped by `ReadOptions.skip_rows` are counted as being read before + /// any records are returned. + /// - bytes read while parsing the header are counted as being read before any + /// records are returned. + /// - bytes skipped by `ReadOptions.skip_rows_after_names` are counted after the + /// first batch is returned. + virtual int64_t bytes_read() const = 0; + + /// Create a StreamingReader instance + /// + /// This involves some I/O as the first batch must be loaded during the creation process + /// so it is returned as a future + /// + /// Currently, the StreamingReader is not async-reentrant and does not do any fan-out + /// parsing (see ARROW-11889) + static Future<std::shared_ptr<StreamingReader>> MakeAsync( + io::IOContext io_context, std::shared_ptr<io::InputStream> input, + internal::Executor* cpu_executor, const ReadOptions&, const ParseOptions&, + const ConvertOptions&); + + static Result<std::shared_ptr<StreamingReader>> Make( + io::IOContext io_context, std::shared_ptr<io::InputStream> input, + const ReadOptions&, const ParseOptions&, const ConvertOptions&); + + ARROW_DEPRECATED("Use IOContext-based overload") + static Result<std::shared_ptr<StreamingReader>> Make( + MemoryPool* pool, std::shared_ptr<io::InputStream> input, + const ReadOptions& read_options, const ParseOptions& parse_options, + const ConvertOptions& convert_options); +}; + +/// \brief Count the logical rows of data in a CSV file (i.e. the +/// number of rows you would get if you read the file into a table). +ARROW_EXPORT +Future<int64_t> CountRowsAsync(io::IOContext io_context, + std::shared_ptr<io::InputStream> input, + internal::Executor* cpu_executor, const ReadOptions&, + const ParseOptions&); + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/type_fwd.h b/contrib/libs/apache/arrow/cpp/src/arrow/csv/type_fwd.h index c0a53847a9..e34a1ab7f5 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/type_fwd.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/type_fwd.h @@ -1,28 +1,28 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -namespace arrow { -namespace csv { - -class TableReader; -struct ConvertOptions; -struct ReadOptions; -struct ParseOptions; -struct WriteOptions; - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +namespace arrow { +namespace csv { + +class TableReader; +struct ConvertOptions; +struct ReadOptions; +struct ParseOptions; +struct WriteOptions; + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/writer.cc b/contrib/libs/apache/arrow/cpp/src/arrow/csv/writer.cc index 1b782cae7d..ac58350221 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/writer.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/writer.cc @@ -1,460 +1,460 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/csv/writer.h" -#include "arrow/array.h" -#include "arrow/compute/cast.h" -#include "arrow/io/interfaces.h" -#include "arrow/ipc/writer.h" -#include "arrow/record_batch.h" -#include "arrow/result.h" -#include "arrow/result_internal.h" -#include "arrow/stl_allocator.h" -#include "arrow/util/iterator.h" -#include "arrow/util/logging.h" -#include "arrow/util/make_unique.h" - -#include "arrow/visitor_inline.h" - -namespace arrow { -namespace csv { -// This implementation is intentionally light on configurability to minimize the size of -// the initial PR. Aditional features can be added as there is demand and interest to -// implement them. -// -// The algorithm used here at a high level is to break RecordBatches/Tables into slices -// and convert each slice independently. A slice is then converted to CSV by first -// scanning each column to determine the size of its contents when rendered as a string in -// CSV. For non-string types this requires casting the value to string (which is cached). -// This data is used to understand the precise length of each row and a single allocation -// for the final CSV data buffer. Once the final size is known each column is then -// iterated over again to place its contents into the CSV data buffer. The rationale for -// choosing this approach is it allows for reuse of the cast functionality in the compute -// module and inline data visiting functionality in the core library. A performance -// comparison has not been done using a naive single-pass approach. This approach might -// still be competitive due to reduction in the number of per row branches necessary with -// a single pass approach. Profiling would likely yield further opportunities for -// optimization with this approach. - -namespace { - -struct SliceIteratorFunctor { - Result<std::shared_ptr<RecordBatch>> Next() { - if (current_offset < batch->num_rows()) { - std::shared_ptr<RecordBatch> next = batch->Slice(current_offset, slice_size); - current_offset += slice_size; - return next; - } - return IterationTraits<std::shared_ptr<RecordBatch>>::End(); - } - const RecordBatch* const batch; - const int64_t slice_size; - int64_t current_offset; -}; - -RecordBatchIterator RecordBatchSliceIterator(const RecordBatch& batch, - int64_t slice_size) { - SliceIteratorFunctor functor = {&batch, slice_size, /*offset=*/static_cast<int64_t>(0)}; - return RecordBatchIterator(std::move(functor)); -} - -// Counts the number of characters that need escaping in s. -int64_t CountEscapes(util::string_view s) { - return static_cast<int64_t>(std::count(s.begin(), s.end(), '"')); -} - -// Matching quote pair character length. -constexpr int64_t kQuoteCount = 2; -constexpr int64_t kQuoteDelimiterCount = kQuoteCount + /*end_char*/ 1; - -// Interface for generating CSV data per column. -// The intended usage is to iteratively call UpdateRowLengths for a column and -// then PopulateColumns. PopulateColumns must be called in the reverse order of the -// populators (it populates data backwards). -class ColumnPopulator { - public: - ColumnPopulator(MemoryPool* pool, char end_char) : end_char_(end_char), pool_(pool) {} - - virtual ~ColumnPopulator() = default; - - // Adds the number of characters each entry in data will add to to elements - // in row_lengths. - Status UpdateRowLengths(const Array& data, int32_t* row_lengths) { - compute::ExecContext ctx(pool_); - // Populators are intented to be applied to reasonably small data. In most cases - // threading overhead would not be justified. - ctx.set_use_threads(false); - ASSIGN_OR_RAISE( - std::shared_ptr<Array> casted, - compute::Cast(data, /*to_type=*/utf8(), compute::CastOptions(), &ctx)); - casted_array_ = internal::checked_pointer_cast<StringArray>(casted); - return UpdateRowLengths(row_lengths); - } - - // Places string data onto each row in output and updates the corresponding row - // row pointers in preparation for calls to other (preceding) ColumnPopulators. - // Args: - // output: character buffer to write to. - // offsets: an array of end of row column within the the output buffer (values are - // one past the end of the position to write to). - virtual void PopulateColumns(char* output, int32_t* offsets) const = 0; - - protected: - virtual Status UpdateRowLengths(int32_t* row_lengths) = 0; - std::shared_ptr<StringArray> casted_array_; - const char end_char_; - - private: - MemoryPool* const pool_; -}; - -// Copies the contents of to out properly escaping any necessary characters. -// Returns the position prior to last copied character (out_end is decremented). -char* EscapeReverse(arrow::util::string_view s, char* out_end) { - for (const char* val = s.data() + s.length() - 1; val >= s.data(); val--, out_end--) { - if (*val == '"') { - *out_end = *val; - out_end--; - } - *out_end = *val; - } - return out_end; -} - -// Populator for non-string types. This populator relies on compute Cast functionality to -// String if it doesn't exist it will be an error. it also assumes the resulting string -// from a cast does not require quoting or escaping. -class UnquotedColumnPopulator : public ColumnPopulator { - public: - explicit UnquotedColumnPopulator(MemoryPool* memory_pool, char end_char) - : ColumnPopulator(memory_pool, end_char) {} - - Status UpdateRowLengths(int32_t* row_lengths) override { - for (int x = 0; x < casted_array_->length(); x++) { - row_lengths[x] += casted_array_->value_length(x); - } - return Status::OK(); - } - - void PopulateColumns(char* output, int32_t* offsets) const override { - VisitArrayDataInline<StringType>( - *casted_array_->data(), - [&](arrow::util::string_view s) { - int64_t next_column_offset = s.length() + /*end_char*/ 1; - memcpy((output + *offsets - next_column_offset), s.data(), s.length()); - *(output + *offsets - 1) = end_char_; - *offsets -= static_cast<int32_t>(next_column_offset); - offsets++; - }, - [&]() { - // Nulls are empty (unquoted) to distinguish with empty string. - *(output + *offsets - 1) = end_char_; - *offsets -= 1; - offsets++; - }); - } -}; - -// Strings need special handling to ensure they are escaped properly. -// This class handles escaping assuming that all strings will be quoted -// and that the only character within the string that needs to escaped is -// a quote character (") and escaping is done my adding another quote. -class QuotedColumnPopulator : public ColumnPopulator { - public: - QuotedColumnPopulator(MemoryPool* pool, char end_char) - : ColumnPopulator(pool, end_char) {} - - Status UpdateRowLengths(int32_t* row_lengths) override { - const StringArray& input = *casted_array_; - int row_number = 0; - row_needs_escaping_.resize(casted_array_->length()); - VisitArrayDataInline<StringType>( - *input.data(), - [&](arrow::util::string_view s) { - int64_t escaped_count = CountEscapes(s); - // TODO: Maybe use 64 bit row lengths or safe cast? - row_needs_escaping_[row_number] = escaped_count > 0; - row_lengths[row_number] += static_cast<int32_t>(s.length()) + - static_cast<int32_t>(escaped_count + kQuoteCount); - row_number++; - }, - [&]() { - row_needs_escaping_[row_number] = false; - row_number++; - }); - return Status::OK(); - } - - void PopulateColumns(char* output, int32_t* offsets) const override { - auto needs_escaping = row_needs_escaping_.begin(); - VisitArrayDataInline<StringType>( - *(casted_array_->data()), - [&](arrow::util::string_view s) { - // still needs string content length to be added - char* row_end = output + *offsets; - int32_t next_column_offset = 0; - if (!*needs_escaping) { - next_column_offset = static_cast<int32_t>(s.length() + kQuoteDelimiterCount); - memcpy(row_end - next_column_offset + /*quote_offset=*/1, s.data(), - s.length()); - } else { - // Adjust row_end by 3: 1 quote char, 1 end char and 1 to position at the - // first position to write to. - next_column_offset = - static_cast<int32_t>(row_end - EscapeReverse(s, row_end - 3)); - } - *(row_end - next_column_offset) = '"'; - *(row_end - 2) = '"'; - *(row_end - 1) = end_char_; - *offsets -= next_column_offset; - offsets++; - needs_escaping++; - }, - [&]() { - // Nulls are empty (unquoted) to distinguish with empty string. - *(output + *offsets - 1) = end_char_; - *offsets -= 1; - offsets++; - needs_escaping++; - }); - } - - private: - // Older version of GCC don't support custom allocators - // at some point we should change this to use memory_pool - // backed allocator. - std::vector<bool> row_needs_escaping_; -}; - -struct PopulatorFactory { - template <typename TypeClass> - enable_if_t<is_base_binary_type<TypeClass>::value || - std::is_same<FixedSizeBinaryType, TypeClass>::value, - Status> - Visit(const TypeClass& type) { - populator = new QuotedColumnPopulator(pool, end_char); - return Status::OK(); - } - - template <typename TypeClass> - enable_if_dictionary<TypeClass, Status> Visit(const TypeClass& type) { - return VisitTypeInline(*type.value_type(), this); - } - - template <typename TypeClass> - enable_if_t<is_nested_type<TypeClass>::value || is_extension_type<TypeClass>::value, - Status> - Visit(const TypeClass& type) { - return Status::Invalid("Unsupported Type:", type.ToString()); - } - - template <typename TypeClass> - enable_if_t<is_primitive_ctype<TypeClass>::value || is_decimal_type<TypeClass>::value || - is_null_type<TypeClass>::value || is_temporal_type<TypeClass>::value, - Status> - Visit(const TypeClass& type) { - populator = new UnquotedColumnPopulator(pool, end_char); - return Status::OK(); - } - - char end_char; - MemoryPool* pool; - ColumnPopulator* populator; -}; - -Result<std::unique_ptr<ColumnPopulator>> MakePopulator(const Field& field, char end_char, - MemoryPool* pool) { - PopulatorFactory factory{end_char, pool, nullptr}; - RETURN_NOT_OK(VisitTypeInline(*field.type(), &factory)); - return std::unique_ptr<ColumnPopulator>(factory.populator); -} - -class CSVWriterImpl : public ipc::RecordBatchWriter { - public: - static Result<std::shared_ptr<CSVWriterImpl>> Make( - io::OutputStream* sink, std::shared_ptr<io::OutputStream> owned_sink, - std::shared_ptr<Schema> schema, const WriteOptions& options) { - RETURN_NOT_OK(options.Validate()); - std::vector<std::unique_ptr<ColumnPopulator>> populators(schema->num_fields()); - for (int col = 0; col < schema->num_fields(); col++) { - char end_char = col < schema->num_fields() - 1 ? ',' : '\n'; - ASSIGN_OR_RAISE(populators[col], MakePopulator(*schema->field(col), end_char, - options.io_context.pool())); - } - auto writer = std::make_shared<CSVWriterImpl>( - sink, std::move(owned_sink), std::move(schema), std::move(populators), options); - RETURN_NOT_OK(writer->PrepareForContentsWrite()); - if (options.include_header) { - RETURN_NOT_OK(writer->WriteHeader()); - } - return writer; - } - - Status WriteRecordBatch(const RecordBatch& batch) override { - RecordBatchIterator iterator = RecordBatchSliceIterator(batch, options_.batch_size); - for (auto maybe_slice : iterator) { - ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> slice, maybe_slice); - RETURN_NOT_OK(TranslateMinimalBatch(*slice)); - RETURN_NOT_OK(sink_->Write(data_buffer_)); - stats_.num_record_batches++; - } - return Status::OK(); - } - - Status WriteTable(const Table& table, int64_t max_chunksize) override { - TableBatchReader reader(table); - reader.set_chunksize(max_chunksize > 0 ? max_chunksize : options_.batch_size); - std::shared_ptr<RecordBatch> batch; - RETURN_NOT_OK(reader.ReadNext(&batch)); - while (batch != nullptr) { - RETURN_NOT_OK(TranslateMinimalBatch(*batch)); - RETURN_NOT_OK(sink_->Write(data_buffer_)); - RETURN_NOT_OK(reader.ReadNext(&batch)); - stats_.num_record_batches++; - } - - return Status::OK(); - } - - Status Close() override { return Status::OK(); } - - ipc::WriteStats stats() const override { return stats_; } - - CSVWriterImpl(io::OutputStream* sink, std::shared_ptr<io::OutputStream> owned_sink, - std::shared_ptr<Schema> schema, - std::vector<std::unique_ptr<ColumnPopulator>> populators, - const WriteOptions& options) - : sink_(sink), - owned_sink_(std::move(owned_sink)), - column_populators_(std::move(populators)), - offsets_(0, 0, ::arrow::stl::allocator<char*>(options.io_context.pool())), - schema_(std::move(schema)), - options_(options) {} - - private: - Status PrepareForContentsWrite() { - // Only called once, as part of initialization - if (data_buffer_ == nullptr) { - ASSIGN_OR_RAISE(data_buffer_, - AllocateResizableBuffer( - options_.batch_size * schema_->num_fields() * kColumnSizeGuess, - options_.io_context.pool())); - } - return Status::OK(); - } - - int64_t CalculateHeaderSize() const { - int64_t header_length = 0; - for (int col = 0; col < schema_->num_fields(); col++) { - const std::string& col_name = schema_->field(col)->name(); - header_length += col_name.size(); - header_length += CountEscapes(col_name); - } - return header_length + (kQuoteDelimiterCount * schema_->num_fields()); - } - - Status WriteHeader() { - // Only called once, as part of initialization - RETURN_NOT_OK(data_buffer_->Resize(CalculateHeaderSize(), /*shrink_to_fit=*/false)); - char* next = - reinterpret_cast<char*>(data_buffer_->mutable_data() + data_buffer_->size() - 1); - for (int col = schema_->num_fields() - 1; col >= 0; col--) { - *next-- = ','; - *next-- = '"'; - next = EscapeReverse(schema_->field(col)->name(), next); - *next-- = '"'; - } - *(data_buffer_->mutable_data() + data_buffer_->size() - 1) = '\n'; - DCHECK_EQ(reinterpret_cast<uint8_t*>(next + 1), data_buffer_->data()); - return sink_->Write(data_buffer_); - } - - Status TranslateMinimalBatch(const RecordBatch& batch) { - if (batch.num_rows() == 0) { - return Status::OK(); - } - offsets_.resize(batch.num_rows()); - std::fill(offsets_.begin(), offsets_.end(), 0); - - // Calculate relative offsets for each row (excluding delimiters) - for (int32_t col = 0; col < static_cast<int32_t>(column_populators_.size()); col++) { - RETURN_NOT_OK( - column_populators_[col]->UpdateRowLengths(*batch.column(col), offsets_.data())); - } - // Calculate cumulalative offsets for each row (including delimiters). - offsets_[0] += batch.num_columns(); - for (int64_t row = 1; row < batch.num_rows(); row++) { - offsets_[row] += offsets_[row - 1] + /*delimiter lengths*/ batch.num_columns(); - } - // Resize the target buffer to required size. We assume batch to batch sizes - // should be pretty close so don't shrink the buffer to avoid allocation churn. - RETURN_NOT_OK(data_buffer_->Resize(offsets_.back(), /*shrink_to_fit=*/false)); - - // Use the offsets to populate contents. - for (auto populator = column_populators_.rbegin(); - populator != column_populators_.rend(); populator++) { - (*populator) - ->PopulateColumns(reinterpret_cast<char*>(data_buffer_->mutable_data()), - offsets_.data()); - } - DCHECK_EQ(0, offsets_[0]); - return Status::OK(); - } - - static constexpr int64_t kColumnSizeGuess = 8; - io::OutputStream* sink_; - std::shared_ptr<io::OutputStream> owned_sink_; - std::vector<std::unique_ptr<ColumnPopulator>> column_populators_; - std::vector<int32_t, arrow::stl::allocator<int32_t>> offsets_; - std::shared_ptr<ResizableBuffer> data_buffer_; - const std::shared_ptr<Schema> schema_; - const WriteOptions options_; - ipc::WriteStats stats_; -}; - -} // namespace - -Status WriteCSV(const Table& table, const WriteOptions& options, - arrow::io::OutputStream* output) { - ASSIGN_OR_RAISE(auto writer, MakeCSVWriter(output, table.schema(), options)); - RETURN_NOT_OK(writer->WriteTable(table)); - return writer->Close(); -} - -Status WriteCSV(const RecordBatch& batch, const WriteOptions& options, - arrow::io::OutputStream* output) { - ASSIGN_OR_RAISE(auto writer, MakeCSVWriter(output, batch.schema(), options)); - RETURN_NOT_OK(writer->WriteRecordBatch(batch)); - return writer->Close(); -} - -ARROW_EXPORT -Result<std::shared_ptr<ipc::RecordBatchWriter>> MakeCSVWriter( - std::shared_ptr<io::OutputStream> sink, const std::shared_ptr<Schema>& schema, - const WriteOptions& options) { - return CSVWriterImpl::Make(sink.get(), sink, schema, options); -} - -ARROW_EXPORT -Result<std::shared_ptr<ipc::RecordBatchWriter>> MakeCSVWriter( - io::OutputStream* sink, const std::shared_ptr<Schema>& schema, - const WriteOptions& options) { - return CSVWriterImpl::Make(sink, nullptr, schema, options); -} - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/csv/writer.h" +#include "arrow/array.h" +#include "arrow/compute/cast.h" +#include "arrow/io/interfaces.h" +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/result_internal.h" +#include "arrow/stl_allocator.h" +#include "arrow/util/iterator.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" + +#include "arrow/visitor_inline.h" + +namespace arrow { +namespace csv { +// This implementation is intentionally light on configurability to minimize the size of +// the initial PR. Aditional features can be added as there is demand and interest to +// implement them. +// +// The algorithm used here at a high level is to break RecordBatches/Tables into slices +// and convert each slice independently. A slice is then converted to CSV by first +// scanning each column to determine the size of its contents when rendered as a string in +// CSV. For non-string types this requires casting the value to string (which is cached). +// This data is used to understand the precise length of each row and a single allocation +// for the final CSV data buffer. Once the final size is known each column is then +// iterated over again to place its contents into the CSV data buffer. The rationale for +// choosing this approach is it allows for reuse of the cast functionality in the compute +// module and inline data visiting functionality in the core library. A performance +// comparison has not been done using a naive single-pass approach. This approach might +// still be competitive due to reduction in the number of per row branches necessary with +// a single pass approach. Profiling would likely yield further opportunities for +// optimization with this approach. + +namespace { + +struct SliceIteratorFunctor { + Result<std::shared_ptr<RecordBatch>> Next() { + if (current_offset < batch->num_rows()) { + std::shared_ptr<RecordBatch> next = batch->Slice(current_offset, slice_size); + current_offset += slice_size; + return next; + } + return IterationTraits<std::shared_ptr<RecordBatch>>::End(); + } + const RecordBatch* const batch; + const int64_t slice_size; + int64_t current_offset; +}; + +RecordBatchIterator RecordBatchSliceIterator(const RecordBatch& batch, + int64_t slice_size) { + SliceIteratorFunctor functor = {&batch, slice_size, /*offset=*/static_cast<int64_t>(0)}; + return RecordBatchIterator(std::move(functor)); +} + +// Counts the number of characters that need escaping in s. +int64_t CountEscapes(util::string_view s) { + return static_cast<int64_t>(std::count(s.begin(), s.end(), '"')); +} + +// Matching quote pair character length. +constexpr int64_t kQuoteCount = 2; +constexpr int64_t kQuoteDelimiterCount = kQuoteCount + /*end_char*/ 1; + +// Interface for generating CSV data per column. +// The intended usage is to iteratively call UpdateRowLengths for a column and +// then PopulateColumns. PopulateColumns must be called in the reverse order of the +// populators (it populates data backwards). +class ColumnPopulator { + public: + ColumnPopulator(MemoryPool* pool, char end_char) : end_char_(end_char), pool_(pool) {} + + virtual ~ColumnPopulator() = default; + + // Adds the number of characters each entry in data will add to to elements + // in row_lengths. + Status UpdateRowLengths(const Array& data, int32_t* row_lengths) { + compute::ExecContext ctx(pool_); + // Populators are intented to be applied to reasonably small data. In most cases + // threading overhead would not be justified. + ctx.set_use_threads(false); + ASSIGN_OR_RAISE( + std::shared_ptr<Array> casted, + compute::Cast(data, /*to_type=*/utf8(), compute::CastOptions(), &ctx)); + casted_array_ = internal::checked_pointer_cast<StringArray>(casted); + return UpdateRowLengths(row_lengths); + } + + // Places string data onto each row in output and updates the corresponding row + // row pointers in preparation for calls to other (preceding) ColumnPopulators. + // Args: + // output: character buffer to write to. + // offsets: an array of end of row column within the the output buffer (values are + // one past the end of the position to write to). + virtual void PopulateColumns(char* output, int32_t* offsets) const = 0; + + protected: + virtual Status UpdateRowLengths(int32_t* row_lengths) = 0; + std::shared_ptr<StringArray> casted_array_; + const char end_char_; + + private: + MemoryPool* const pool_; +}; + +// Copies the contents of to out properly escaping any necessary characters. +// Returns the position prior to last copied character (out_end is decremented). +char* EscapeReverse(arrow::util::string_view s, char* out_end) { + for (const char* val = s.data() + s.length() - 1; val >= s.data(); val--, out_end--) { + if (*val == '"') { + *out_end = *val; + out_end--; + } + *out_end = *val; + } + return out_end; +} + +// Populator for non-string types. This populator relies on compute Cast functionality to +// String if it doesn't exist it will be an error. it also assumes the resulting string +// from a cast does not require quoting or escaping. +class UnquotedColumnPopulator : public ColumnPopulator { + public: + explicit UnquotedColumnPopulator(MemoryPool* memory_pool, char end_char) + : ColumnPopulator(memory_pool, end_char) {} + + Status UpdateRowLengths(int32_t* row_lengths) override { + for (int x = 0; x < casted_array_->length(); x++) { + row_lengths[x] += casted_array_->value_length(x); + } + return Status::OK(); + } + + void PopulateColumns(char* output, int32_t* offsets) const override { + VisitArrayDataInline<StringType>( + *casted_array_->data(), + [&](arrow::util::string_view s) { + int64_t next_column_offset = s.length() + /*end_char*/ 1; + memcpy((output + *offsets - next_column_offset), s.data(), s.length()); + *(output + *offsets - 1) = end_char_; + *offsets -= static_cast<int32_t>(next_column_offset); + offsets++; + }, + [&]() { + // Nulls are empty (unquoted) to distinguish with empty string. + *(output + *offsets - 1) = end_char_; + *offsets -= 1; + offsets++; + }); + } +}; + +// Strings need special handling to ensure they are escaped properly. +// This class handles escaping assuming that all strings will be quoted +// and that the only character within the string that needs to escaped is +// a quote character (") and escaping is done my adding another quote. +class QuotedColumnPopulator : public ColumnPopulator { + public: + QuotedColumnPopulator(MemoryPool* pool, char end_char) + : ColumnPopulator(pool, end_char) {} + + Status UpdateRowLengths(int32_t* row_lengths) override { + const StringArray& input = *casted_array_; + int row_number = 0; + row_needs_escaping_.resize(casted_array_->length()); + VisitArrayDataInline<StringType>( + *input.data(), + [&](arrow::util::string_view s) { + int64_t escaped_count = CountEscapes(s); + // TODO: Maybe use 64 bit row lengths or safe cast? + row_needs_escaping_[row_number] = escaped_count > 0; + row_lengths[row_number] += static_cast<int32_t>(s.length()) + + static_cast<int32_t>(escaped_count + kQuoteCount); + row_number++; + }, + [&]() { + row_needs_escaping_[row_number] = false; + row_number++; + }); + return Status::OK(); + } + + void PopulateColumns(char* output, int32_t* offsets) const override { + auto needs_escaping = row_needs_escaping_.begin(); + VisitArrayDataInline<StringType>( + *(casted_array_->data()), + [&](arrow::util::string_view s) { + // still needs string content length to be added + char* row_end = output + *offsets; + int32_t next_column_offset = 0; + if (!*needs_escaping) { + next_column_offset = static_cast<int32_t>(s.length() + kQuoteDelimiterCount); + memcpy(row_end - next_column_offset + /*quote_offset=*/1, s.data(), + s.length()); + } else { + // Adjust row_end by 3: 1 quote char, 1 end char and 1 to position at the + // first position to write to. + next_column_offset = + static_cast<int32_t>(row_end - EscapeReverse(s, row_end - 3)); + } + *(row_end - next_column_offset) = '"'; + *(row_end - 2) = '"'; + *(row_end - 1) = end_char_; + *offsets -= next_column_offset; + offsets++; + needs_escaping++; + }, + [&]() { + // Nulls are empty (unquoted) to distinguish with empty string. + *(output + *offsets - 1) = end_char_; + *offsets -= 1; + offsets++; + needs_escaping++; + }); + } + + private: + // Older version of GCC don't support custom allocators + // at some point we should change this to use memory_pool + // backed allocator. + std::vector<bool> row_needs_escaping_; +}; + +struct PopulatorFactory { + template <typename TypeClass> + enable_if_t<is_base_binary_type<TypeClass>::value || + std::is_same<FixedSizeBinaryType, TypeClass>::value, + Status> + Visit(const TypeClass& type) { + populator = new QuotedColumnPopulator(pool, end_char); + return Status::OK(); + } + + template <typename TypeClass> + enable_if_dictionary<TypeClass, Status> Visit(const TypeClass& type) { + return VisitTypeInline(*type.value_type(), this); + } + + template <typename TypeClass> + enable_if_t<is_nested_type<TypeClass>::value || is_extension_type<TypeClass>::value, + Status> + Visit(const TypeClass& type) { + return Status::Invalid("Unsupported Type:", type.ToString()); + } + + template <typename TypeClass> + enable_if_t<is_primitive_ctype<TypeClass>::value || is_decimal_type<TypeClass>::value || + is_null_type<TypeClass>::value || is_temporal_type<TypeClass>::value, + Status> + Visit(const TypeClass& type) { + populator = new UnquotedColumnPopulator(pool, end_char); + return Status::OK(); + } + + char end_char; + MemoryPool* pool; + ColumnPopulator* populator; +}; + +Result<std::unique_ptr<ColumnPopulator>> MakePopulator(const Field& field, char end_char, + MemoryPool* pool) { + PopulatorFactory factory{end_char, pool, nullptr}; + RETURN_NOT_OK(VisitTypeInline(*field.type(), &factory)); + return std::unique_ptr<ColumnPopulator>(factory.populator); +} + +class CSVWriterImpl : public ipc::RecordBatchWriter { + public: + static Result<std::shared_ptr<CSVWriterImpl>> Make( + io::OutputStream* sink, std::shared_ptr<io::OutputStream> owned_sink, + std::shared_ptr<Schema> schema, const WriteOptions& options) { + RETURN_NOT_OK(options.Validate()); + std::vector<std::unique_ptr<ColumnPopulator>> populators(schema->num_fields()); + for (int col = 0; col < schema->num_fields(); col++) { + char end_char = col < schema->num_fields() - 1 ? ',' : '\n'; + ASSIGN_OR_RAISE(populators[col], MakePopulator(*schema->field(col), end_char, + options.io_context.pool())); + } + auto writer = std::make_shared<CSVWriterImpl>( + sink, std::move(owned_sink), std::move(schema), std::move(populators), options); + RETURN_NOT_OK(writer->PrepareForContentsWrite()); + if (options.include_header) { + RETURN_NOT_OK(writer->WriteHeader()); + } + return writer; + } + + Status WriteRecordBatch(const RecordBatch& batch) override { + RecordBatchIterator iterator = RecordBatchSliceIterator(batch, options_.batch_size); + for (auto maybe_slice : iterator) { + ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> slice, maybe_slice); + RETURN_NOT_OK(TranslateMinimalBatch(*slice)); + RETURN_NOT_OK(sink_->Write(data_buffer_)); + stats_.num_record_batches++; + } + return Status::OK(); + } + + Status WriteTable(const Table& table, int64_t max_chunksize) override { + TableBatchReader reader(table); + reader.set_chunksize(max_chunksize > 0 ? max_chunksize : options_.batch_size); + std::shared_ptr<RecordBatch> batch; + RETURN_NOT_OK(reader.ReadNext(&batch)); + while (batch != nullptr) { + RETURN_NOT_OK(TranslateMinimalBatch(*batch)); + RETURN_NOT_OK(sink_->Write(data_buffer_)); + RETURN_NOT_OK(reader.ReadNext(&batch)); + stats_.num_record_batches++; + } + + return Status::OK(); + } + + Status Close() override { return Status::OK(); } + + ipc::WriteStats stats() const override { return stats_; } + + CSVWriterImpl(io::OutputStream* sink, std::shared_ptr<io::OutputStream> owned_sink, + std::shared_ptr<Schema> schema, + std::vector<std::unique_ptr<ColumnPopulator>> populators, + const WriteOptions& options) + : sink_(sink), + owned_sink_(std::move(owned_sink)), + column_populators_(std::move(populators)), + offsets_(0, 0, ::arrow::stl::allocator<char*>(options.io_context.pool())), + schema_(std::move(schema)), + options_(options) {} + + private: + Status PrepareForContentsWrite() { + // Only called once, as part of initialization + if (data_buffer_ == nullptr) { + ASSIGN_OR_RAISE(data_buffer_, + AllocateResizableBuffer( + options_.batch_size * schema_->num_fields() * kColumnSizeGuess, + options_.io_context.pool())); + } + return Status::OK(); + } + + int64_t CalculateHeaderSize() const { + int64_t header_length = 0; + for (int col = 0; col < schema_->num_fields(); col++) { + const std::string& col_name = schema_->field(col)->name(); + header_length += col_name.size(); + header_length += CountEscapes(col_name); + } + return header_length + (kQuoteDelimiterCount * schema_->num_fields()); + } + + Status WriteHeader() { + // Only called once, as part of initialization + RETURN_NOT_OK(data_buffer_->Resize(CalculateHeaderSize(), /*shrink_to_fit=*/false)); + char* next = + reinterpret_cast<char*>(data_buffer_->mutable_data() + data_buffer_->size() - 1); + for (int col = schema_->num_fields() - 1; col >= 0; col--) { + *next-- = ','; + *next-- = '"'; + next = EscapeReverse(schema_->field(col)->name(), next); + *next-- = '"'; + } + *(data_buffer_->mutable_data() + data_buffer_->size() - 1) = '\n'; + DCHECK_EQ(reinterpret_cast<uint8_t*>(next + 1), data_buffer_->data()); + return sink_->Write(data_buffer_); + } + + Status TranslateMinimalBatch(const RecordBatch& batch) { + if (batch.num_rows() == 0) { + return Status::OK(); + } + offsets_.resize(batch.num_rows()); + std::fill(offsets_.begin(), offsets_.end(), 0); + + // Calculate relative offsets for each row (excluding delimiters) + for (int32_t col = 0; col < static_cast<int32_t>(column_populators_.size()); col++) { + RETURN_NOT_OK( + column_populators_[col]->UpdateRowLengths(*batch.column(col), offsets_.data())); + } + // Calculate cumulalative offsets for each row (including delimiters). + offsets_[0] += batch.num_columns(); + for (int64_t row = 1; row < batch.num_rows(); row++) { + offsets_[row] += offsets_[row - 1] + /*delimiter lengths*/ batch.num_columns(); + } + // Resize the target buffer to required size. We assume batch to batch sizes + // should be pretty close so don't shrink the buffer to avoid allocation churn. + RETURN_NOT_OK(data_buffer_->Resize(offsets_.back(), /*shrink_to_fit=*/false)); + + // Use the offsets to populate contents. + for (auto populator = column_populators_.rbegin(); + populator != column_populators_.rend(); populator++) { + (*populator) + ->PopulateColumns(reinterpret_cast<char*>(data_buffer_->mutable_data()), + offsets_.data()); + } + DCHECK_EQ(0, offsets_[0]); + return Status::OK(); + } + + static constexpr int64_t kColumnSizeGuess = 8; + io::OutputStream* sink_; + std::shared_ptr<io::OutputStream> owned_sink_; + std::vector<std::unique_ptr<ColumnPopulator>> column_populators_; + std::vector<int32_t, arrow::stl::allocator<int32_t>> offsets_; + std::shared_ptr<ResizableBuffer> data_buffer_; + const std::shared_ptr<Schema> schema_; + const WriteOptions options_; + ipc::WriteStats stats_; +}; + +} // namespace + +Status WriteCSV(const Table& table, const WriteOptions& options, + arrow::io::OutputStream* output) { + ASSIGN_OR_RAISE(auto writer, MakeCSVWriter(output, table.schema(), options)); + RETURN_NOT_OK(writer->WriteTable(table)); + return writer->Close(); +} + +Status WriteCSV(const RecordBatch& batch, const WriteOptions& options, + arrow::io::OutputStream* output) { + ASSIGN_OR_RAISE(auto writer, MakeCSVWriter(output, batch.schema(), options)); + RETURN_NOT_OK(writer->WriteRecordBatch(batch)); + return writer->Close(); +} + +ARROW_EXPORT +Result<std::shared_ptr<ipc::RecordBatchWriter>> MakeCSVWriter( + std::shared_ptr<io::OutputStream> sink, const std::shared_ptr<Schema>& schema, + const WriteOptions& options) { + return CSVWriterImpl::Make(sink.get(), sink, schema, options); +} + +ARROW_EXPORT +Result<std::shared_ptr<ipc::RecordBatchWriter>> MakeCSVWriter( + io::OutputStream* sink, const std::shared_ptr<Schema>& schema, + const WriteOptions& options) { + return CSVWriterImpl::Make(sink, nullptr, schema, options); +} + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/csv/writer.h b/contrib/libs/apache/arrow/cpp/src/arrow/csv/writer.h index 2f1442ae0a..bb31b223a8 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/csv/writer.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/csv/writer.h @@ -1,73 +1,73 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <memory> - -#include "arrow/csv/options.h" -#include "arrow/io/interfaces.h" -#include "arrow/ipc/type_fwd.h" -#include "arrow/record_batch.h" -#include "arrow/table.h" - -namespace arrow { -namespace csv { -// Functionality for converting Arrow data to Comma separated value text. -// This library supports all primitive types that can be cast to a StringArrays. -// It applies to following formatting rules: -// - For non-binary types no quotes surround values. Nulls are represented as the empty -// string. -// - For binary types all non-null data is quoted (and quotes within data are escaped -// with an additional quote). -// Null values are empty and unquoted. -// - LF (\n) is always used as a line ending. - -/// \brief Converts table to a CSV and writes the results to output. -/// Experimental -ARROW_EXPORT Status WriteCSV(const Table& table, const WriteOptions& options, - arrow::io::OutputStream* output); -/// \brief Converts batch to CSV and writes the results to output. -/// Experimental -ARROW_EXPORT Status WriteCSV(const RecordBatch& batch, const WriteOptions& options, - arrow::io::OutputStream* output); - -/// \brief Create a new CSV writer. User is responsible for closing the -/// actual OutputStream. -/// -/// \param[in] sink output stream to write to -/// \param[in] schema the schema of the record batches to be written -/// \param[in] options options for serialization -/// \return Result<std::shared_ptr<RecordBatchWriter>> -ARROW_EXPORT -Result<std::shared_ptr<ipc::RecordBatchWriter>> MakeCSVWriter( - std::shared_ptr<io::OutputStream> sink, const std::shared_ptr<Schema>& schema, - const WriteOptions& options = WriteOptions::Defaults()); - -/// \brief Create a new CSV writer. -/// -/// \param[in] sink output stream to write to (does not take ownership) -/// \param[in] schema the schema of the record batches to be written -/// \param[in] options options for serialization -/// \return Result<std::shared_ptr<RecordBatchWriter>> -ARROW_EXPORT -Result<std::shared_ptr<ipc::RecordBatchWriter>> MakeCSVWriter( - io::OutputStream* sink, const std::shared_ptr<Schema>& schema, - const WriteOptions& options = WriteOptions::Defaults()); - -} // namespace csv -} // namespace arrow +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> + +#include "arrow/csv/options.h" +#include "arrow/io/interfaces.h" +#include "arrow/ipc/type_fwd.h" +#include "arrow/record_batch.h" +#include "arrow/table.h" + +namespace arrow { +namespace csv { +// Functionality for converting Arrow data to Comma separated value text. +// This library supports all primitive types that can be cast to a StringArrays. +// It applies to following formatting rules: +// - For non-binary types no quotes surround values. Nulls are represented as the empty +// string. +// - For binary types all non-null data is quoted (and quotes within data are escaped +// with an additional quote). +// Null values are empty and unquoted. +// - LF (\n) is always used as a line ending. + +/// \brief Converts table to a CSV and writes the results to output. +/// Experimental +ARROW_EXPORT Status WriteCSV(const Table& table, const WriteOptions& options, + arrow::io::OutputStream* output); +/// \brief Converts batch to CSV and writes the results to output. +/// Experimental +ARROW_EXPORT Status WriteCSV(const RecordBatch& batch, const WriteOptions& options, + arrow::io::OutputStream* output); + +/// \brief Create a new CSV writer. User is responsible for closing the +/// actual OutputStream. +/// +/// \param[in] sink output stream to write to +/// \param[in] schema the schema of the record batches to be written +/// \param[in] options options for serialization +/// \return Result<std::shared_ptr<RecordBatchWriter>> +ARROW_EXPORT +Result<std::shared_ptr<ipc::RecordBatchWriter>> MakeCSVWriter( + std::shared_ptr<io::OutputStream> sink, const std::shared_ptr<Schema>& schema, + const WriteOptions& options = WriteOptions::Defaults()); + +/// \brief Create a new CSV writer. +/// +/// \param[in] sink output stream to write to (does not take ownership) +/// \param[in] schema the schema of the record batches to be written +/// \param[in] options options for serialization +/// \return Result<std::shared_ptr<RecordBatchWriter>> +ARROW_EXPORT +Result<std::shared_ptr<ipc::RecordBatchWriter>> MakeCSVWriter( + io::OutputStream* sink, const std::shared_ptr<Schema>& schema, + const WriteOptions& options = WriteOptions::Defaults()); + +} // namespace csv +} // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/util/int128_internal.h b/contrib/libs/apache/arrow/cpp/src/arrow/util/int128_internal.h index 1d494671a9..b7d40118b4 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/util/int128_internal.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/util/int128_internal.h @@ -20,7 +20,7 @@ #include "arrow/util/macros.h" #ifndef ARROW_USE_NATIVE_INT128 -#include <boost/multiprecision/cpp_int.hpp> +#include <boost/multiprecision/cpp_int.hpp> #endif namespace arrow { diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/util/logging.cc b/contrib/libs/apache/arrow/cpp/src/arrow/util/logging.cc index 65359b4408..f212d3d406 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/util/logging.cc +++ b/contrib/libs/apache/arrow/cpp/src/arrow/util/logging.cc @@ -28,7 +28,7 @@ #include <signal.h> #include <vector> -#error #include "glog/logging.h" +#error #include "glog/logging.h" // Restore our versions of DCHECK and friends, as GLog defines its own #undef DCHECK diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/util/optional.h b/contrib/libs/apache/arrow/cpp/src/arrow/util/optional.h index b824b499bb..546c74bb32 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/util/optional.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/util/optional.h @@ -17,17 +17,17 @@ #pragma once -#include <optional> +#include <optional> namespace arrow { namespace util { template <typename T> -using optional = std::optional<T>; +using optional = std::optional<T>; -using std::bad_optional_access; -using std::make_optional; -using std::nullopt; +using std::bad_optional_access; +using std::make_optional; +using std::nullopt; } // namespace util } // namespace arrow diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/vendored/datetime/ios.h b/contrib/libs/apache/arrow/cpp/src/arrow/vendored/datetime/ios.h index 46567d69b1..f3edfcd1dc 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/vendored/datetime/ios.h +++ b/contrib/libs/apache/arrow/cpp/src/arrow/vendored/datetime/ios.h @@ -1,53 +1,53 @@ -// -// ios.h -// DateTimeLib -// -// The MIT License (MIT) -// -// Copyright (c) 2016 Alexander Kormanovsky -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -#ifndef ios_hpp -#define ios_hpp - -#if __APPLE__ -# include <TargetConditionals.h> -# if TARGET_OS_IPHONE -# include <string> - - namespace arrow_vendored - { - namespace date - { - namespace iOSUtils - { - - std::string get_tzdata_path(); - std::string get_current_timezone(); - - } // namespace iOSUtils - } // namespace date - } // namespace arrow_vendored - -# endif // TARGET_OS_IPHONE -#else // !__APPLE__ -# define TARGET_OS_IPHONE 0 -#endif // !__APPLE__ -#endif // ios_hpp +// +// ios.h +// DateTimeLib +// +// The MIT License (MIT) +// +// Copyright (c) 2016 Alexander Kormanovsky +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ios_hpp +#define ios_hpp + +#if __APPLE__ +# include <TargetConditionals.h> +# if TARGET_OS_IPHONE +# include <string> + + namespace arrow_vendored + { + namespace date + { + namespace iOSUtils + { + + std::string get_tzdata_path(); + std::string get_current_timezone(); + + } // namespace iOSUtils + } // namespace date + } // namespace arrow_vendored + +# endif // TARGET_OS_IPHONE +#else // !__APPLE__ +# define TARGET_OS_IPHONE 0 +#endif // !__APPLE__ +#endif // ios_hpp diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/vendored/datetime/ios.mm b/contrib/libs/apache/arrow/cpp/src/arrow/vendored/datetime/ios.mm index 18c521201d..7d432afe85 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/vendored/datetime/ios.mm +++ b/contrib/libs/apache/arrow/cpp/src/arrow/vendored/datetime/ios.mm @@ -1,340 +1,340 @@ -// -// The MIT License (MIT) -// -// Copyright (c) 2016 Alexander Kormanovsky -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. -// - -#include "ios.h" - -#if TARGET_OS_IPHONE - -#include <Foundation/Foundation.h> - -#include <fstream> -#include <zlib.h> -#include <sys/stat.h> - -#ifndef TAR_DEBUG -# define TAR_DEBUG 0 -#endif - -#define INTERNAL_DIR "Library" -#define TZDATA_DIR "tzdata" -#define TARGZ_EXTENSION "tar.gz" - -#define TAR_BLOCK_SIZE 512 -#define TAR_TYPE_POSITION 156 -#define TAR_NAME_POSITION 0 -#define TAR_NAME_SIZE 100 -#define TAR_SIZE_POSITION 124 -#define TAR_SIZE_SIZE 12 - -namespace arrow_vendored -{ -namespace date -{ - namespace iOSUtils - { - - struct TarInfo - { - char objType; - std::string objName; - size_t realContentSize; // writable size without padding zeroes - size_t blocksContentSize; // adjusted size to 512 bytes blocks - bool success; - }; - - std::string convertCFStringRefPathToCStringPath(CFStringRef ref); - bool extractTzdata(CFURLRef homeUrl, CFURLRef archiveUrl, std::string destPath); - TarInfo getTarObjectInfo(std::ifstream &readStream); - std::string getTarObject(std::ifstream &readStream, int64_t size); - bool writeFile(const std::string &tzdataPath, const std::string &fileName, - const std::string &data, size_t realContentSize); - - std::string - get_current_timezone() - { - CFTimeZoneRef tzRef = CFTimeZoneCopySystem(); - CFStringRef tzNameRef = CFTimeZoneGetName(tzRef); - CFIndex bufferSize = CFStringGetLength(tzNameRef) + 1; - char buffer[bufferSize]; - - if (CFStringGetCString(tzNameRef, buffer, bufferSize, kCFStringEncodingUTF8)) - { - CFRelease(tzRef); - return std::string(buffer); - } - - CFRelease(tzRef); - - return ""; - } - - std::string - get_tzdata_path() - { - CFURLRef homeUrlRef = CFCopyHomeDirectoryURL(); - CFStringRef homePath = CFURLCopyPath(homeUrlRef); - std::string path(std::string(convertCFStringRefPathToCStringPath(homePath)) + - INTERNAL_DIR + "/" + TZDATA_DIR); - std::string result_path(std::string(convertCFStringRefPathToCStringPath(homePath)) + - INTERNAL_DIR); - - if (access(path.c_str(), F_OK) == 0) - { -#if TAR_DEBUG - printf("tzdata dir exists\n"); -#endif - CFRelease(homeUrlRef); - CFRelease(homePath); - - return result_path; - } - - CFBundleRef mainBundle = CFBundleGetMainBundle(); - CFArrayRef paths = CFBundleCopyResourceURLsOfType(mainBundle, CFSTR(TARGZ_EXTENSION), - NULL); - - if (CFArrayGetCount(paths) != 0) - { - // get archive path, assume there is no other tar.gz in bundle - CFURLRef archiveUrl = static_cast<CFURLRef>(CFArrayGetValueAtIndex(paths, 0)); - CFStringRef archiveName = CFURLCopyPath(archiveUrl); - archiveUrl = CFBundleCopyResourceURL(mainBundle, archiveName, NULL, NULL); - - extractTzdata(homeUrlRef, archiveUrl, path); - - CFRelease(archiveUrl); - CFRelease(archiveName); - } - - CFRelease(homeUrlRef); - CFRelease(homePath); - CFRelease(paths); - - return result_path; - } - - std::string - convertCFStringRefPathToCStringPath(CFStringRef ref) - { - CFIndex bufferSize = CFStringGetMaximumSizeOfFileSystemRepresentation(ref); - char *buffer = new char[bufferSize]; - CFStringGetFileSystemRepresentation(ref, buffer, bufferSize); - auto result = std::string(buffer); - delete[] buffer; - return result; - } - - bool - extractTzdata(CFURLRef homeUrl, CFURLRef archiveUrl, std::string destPath) - { - std::string TAR_TMP_PATH = "/tmp.tar"; - - CFStringRef homeStringRef = CFURLCopyPath(homeUrl); - auto homePath = convertCFStringRefPathToCStringPath(homeStringRef); - CFRelease(homeStringRef); - - CFStringRef archiveStringRef = CFURLCopyPath(archiveUrl); - auto archivePath = convertCFStringRefPathToCStringPath(archiveStringRef); - CFRelease(archiveStringRef); - - // create Library path - auto libraryPath = homePath + INTERNAL_DIR; - - // create tzdata path - auto tzdataPath = libraryPath + "/" + TZDATA_DIR; - - // -- replace %20 with " " - const std::string search = "%20"; - const std::string replacement = " "; - size_t pos = 0; - - while ((pos = archivePath.find(search, pos)) != std::string::npos) { - archivePath.replace(pos, search.length(), replacement); - pos += replacement.length(); - } - - gzFile tarFile = gzopen(archivePath.c_str(), "rb"); - - // create tar unpacking path - auto tarPath = libraryPath + TAR_TMP_PATH; - - // create tzdata directory - mkdir(destPath.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH); - - // ======= extract tar ======== - - std::ofstream os(tarPath.c_str(), std::ofstream::out | std::ofstream::app); - unsigned int bufferLength = 1024 * 256; // 256Kb - unsigned char *buffer = (unsigned char *)malloc(bufferLength); - bool success = true; - - while (true) - { - int readBytes = gzread(tarFile, buffer, bufferLength); - - if (readBytes > 0) - { - os.write((char *) &buffer[0], readBytes); - } - else - if (readBytes == 0) - { - break; - } - else - if (readBytes == -1) - { - printf("decompression failed\n"); - success = false; - break; - } - else - { - printf("unexpected zlib state\n"); - success = false; - break; - } - } - - os.close(); - free(buffer); - gzclose(tarFile); - - if (!success) - { - remove(tarPath.c_str()); - return false; - } - - // ======== extract files ========= - - uint64_t location = 0; // Position in the file - - // get file size - struct stat stat_buf; - int res = stat(tarPath.c_str(), &stat_buf); - if (res != 0) - { - printf("error file size\n"); - remove(tarPath.c_str()); - return false; - } - int64_t tarSize = stat_buf.st_size; - - // create read stream - std::ifstream is(tarPath.c_str(), std::ifstream::in | std::ifstream::binary); - - // process files - while (location < tarSize) - { - TarInfo info = getTarObjectInfo(is); - - if (!info.success || info.realContentSize == 0) - { - break; // something wrong or all files are read - } - - switch (info.objType) - { - case '0': // file - case '\0': // - { - std::string obj = getTarObject(is, info.blocksContentSize); -#if TAR_DEBUG - size += info.realContentSize; - printf("#%i %s file size %lld written total %ld from %lld\n", ++count, - info.objName.c_str(), info.realContentSize, size, tarSize); -#endif - writeFile(tzdataPath, info.objName, obj, info.realContentSize); - location += info.blocksContentSize; - - break; - } - } - } - - remove(tarPath.c_str()); - - return true; - } - - TarInfo - getTarObjectInfo(std::ifstream &readStream) - { - int64_t length = TAR_BLOCK_SIZE; - char buffer[length]; - char type; - char name[TAR_NAME_SIZE + 1]; - char sizeBuf[TAR_SIZE_SIZE + 1]; - - readStream.read(buffer, length); - - memcpy(&type, &buffer[TAR_TYPE_POSITION], 1); - - memset(&name, '\0', TAR_NAME_SIZE + 1); - memcpy(&name, &buffer[TAR_NAME_POSITION], TAR_NAME_SIZE); - - memset(&sizeBuf, '\0', TAR_SIZE_SIZE + 1); - memcpy(&sizeBuf, &buffer[TAR_SIZE_POSITION], TAR_SIZE_SIZE); - size_t realSize = strtol(sizeBuf, NULL, 8); - size_t blocksSize = realSize + (TAR_BLOCK_SIZE - (realSize % TAR_BLOCK_SIZE)); - - return {type, std::string(name), realSize, blocksSize, true}; - } - - std::string - getTarObject(std::ifstream &readStream, int64_t size) - { - char buffer[size]; - readStream.read(buffer, size); - return std::string(buffer); - } - - bool - writeFile(const std::string &tzdataPath, const std::string &fileName, const std::string &data, - size_t realContentSize) - { - std::ofstream os(tzdataPath + "/" + fileName, std::ofstream::out | std::ofstream::binary); - - if (!os) { - return false; - } - - // trim empty space - char trimmedData[realContentSize + 1]; - memset(&trimmedData, '\0', realContentSize); - memcpy(&trimmedData, data.c_str(), realContentSize); - - // write - os.write(trimmedData, realContentSize); - os.close(); - - return true; - } - - } // namespace iOSUtils -} // namespace date -} // namespace arrow_vendored - -#endif // TARGET_OS_IPHONE +// +// The MIT License (MIT) +// +// Copyright (c) 2016 Alexander Kormanovsky +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +// + +#include "ios.h" + +#if TARGET_OS_IPHONE + +#include <Foundation/Foundation.h> + +#include <fstream> +#include <zlib.h> +#include <sys/stat.h> + +#ifndef TAR_DEBUG +# define TAR_DEBUG 0 +#endif + +#define INTERNAL_DIR "Library" +#define TZDATA_DIR "tzdata" +#define TARGZ_EXTENSION "tar.gz" + +#define TAR_BLOCK_SIZE 512 +#define TAR_TYPE_POSITION 156 +#define TAR_NAME_POSITION 0 +#define TAR_NAME_SIZE 100 +#define TAR_SIZE_POSITION 124 +#define TAR_SIZE_SIZE 12 + +namespace arrow_vendored +{ +namespace date +{ + namespace iOSUtils + { + + struct TarInfo + { + char objType; + std::string objName; + size_t realContentSize; // writable size without padding zeroes + size_t blocksContentSize; // adjusted size to 512 bytes blocks + bool success; + }; + + std::string convertCFStringRefPathToCStringPath(CFStringRef ref); + bool extractTzdata(CFURLRef homeUrl, CFURLRef archiveUrl, std::string destPath); + TarInfo getTarObjectInfo(std::ifstream &readStream); + std::string getTarObject(std::ifstream &readStream, int64_t size); + bool writeFile(const std::string &tzdataPath, const std::string &fileName, + const std::string &data, size_t realContentSize); + + std::string + get_current_timezone() + { + CFTimeZoneRef tzRef = CFTimeZoneCopySystem(); + CFStringRef tzNameRef = CFTimeZoneGetName(tzRef); + CFIndex bufferSize = CFStringGetLength(tzNameRef) + 1; + char buffer[bufferSize]; + + if (CFStringGetCString(tzNameRef, buffer, bufferSize, kCFStringEncodingUTF8)) + { + CFRelease(tzRef); + return std::string(buffer); + } + + CFRelease(tzRef); + + return ""; + } + + std::string + get_tzdata_path() + { + CFURLRef homeUrlRef = CFCopyHomeDirectoryURL(); + CFStringRef homePath = CFURLCopyPath(homeUrlRef); + std::string path(std::string(convertCFStringRefPathToCStringPath(homePath)) + + INTERNAL_DIR + "/" + TZDATA_DIR); + std::string result_path(std::string(convertCFStringRefPathToCStringPath(homePath)) + + INTERNAL_DIR); + + if (access(path.c_str(), F_OK) == 0) + { +#if TAR_DEBUG + printf("tzdata dir exists\n"); +#endif + CFRelease(homeUrlRef); + CFRelease(homePath); + + return result_path; + } + + CFBundleRef mainBundle = CFBundleGetMainBundle(); + CFArrayRef paths = CFBundleCopyResourceURLsOfType(mainBundle, CFSTR(TARGZ_EXTENSION), + NULL); + + if (CFArrayGetCount(paths) != 0) + { + // get archive path, assume there is no other tar.gz in bundle + CFURLRef archiveUrl = static_cast<CFURLRef>(CFArrayGetValueAtIndex(paths, 0)); + CFStringRef archiveName = CFURLCopyPath(archiveUrl); + archiveUrl = CFBundleCopyResourceURL(mainBundle, archiveName, NULL, NULL); + + extractTzdata(homeUrlRef, archiveUrl, path); + + CFRelease(archiveUrl); + CFRelease(archiveName); + } + + CFRelease(homeUrlRef); + CFRelease(homePath); + CFRelease(paths); + + return result_path; + } + + std::string + convertCFStringRefPathToCStringPath(CFStringRef ref) + { + CFIndex bufferSize = CFStringGetMaximumSizeOfFileSystemRepresentation(ref); + char *buffer = new char[bufferSize]; + CFStringGetFileSystemRepresentation(ref, buffer, bufferSize); + auto result = std::string(buffer); + delete[] buffer; + return result; + } + + bool + extractTzdata(CFURLRef homeUrl, CFURLRef archiveUrl, std::string destPath) + { + std::string TAR_TMP_PATH = "/tmp.tar"; + + CFStringRef homeStringRef = CFURLCopyPath(homeUrl); + auto homePath = convertCFStringRefPathToCStringPath(homeStringRef); + CFRelease(homeStringRef); + + CFStringRef archiveStringRef = CFURLCopyPath(archiveUrl); + auto archivePath = convertCFStringRefPathToCStringPath(archiveStringRef); + CFRelease(archiveStringRef); + + // create Library path + auto libraryPath = homePath + INTERNAL_DIR; + + // create tzdata path + auto tzdataPath = libraryPath + "/" + TZDATA_DIR; + + // -- replace %20 with " " + const std::string search = "%20"; + const std::string replacement = " "; + size_t pos = 0; + + while ((pos = archivePath.find(search, pos)) != std::string::npos) { + archivePath.replace(pos, search.length(), replacement); + pos += replacement.length(); + } + + gzFile tarFile = gzopen(archivePath.c_str(), "rb"); + + // create tar unpacking path + auto tarPath = libraryPath + TAR_TMP_PATH; + + // create tzdata directory + mkdir(destPath.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH); + + // ======= extract tar ======== + + std::ofstream os(tarPath.c_str(), std::ofstream::out | std::ofstream::app); + unsigned int bufferLength = 1024 * 256; // 256Kb + unsigned char *buffer = (unsigned char *)malloc(bufferLength); + bool success = true; + + while (true) + { + int readBytes = gzread(tarFile, buffer, bufferLength); + + if (readBytes > 0) + { + os.write((char *) &buffer[0], readBytes); + } + else + if (readBytes == 0) + { + break; + } + else + if (readBytes == -1) + { + printf("decompression failed\n"); + success = false; + break; + } + else + { + printf("unexpected zlib state\n"); + success = false; + break; + } + } + + os.close(); + free(buffer); + gzclose(tarFile); + + if (!success) + { + remove(tarPath.c_str()); + return false; + } + + // ======== extract files ========= + + uint64_t location = 0; // Position in the file + + // get file size + struct stat stat_buf; + int res = stat(tarPath.c_str(), &stat_buf); + if (res != 0) + { + printf("error file size\n"); + remove(tarPath.c_str()); + return false; + } + int64_t tarSize = stat_buf.st_size; + + // create read stream + std::ifstream is(tarPath.c_str(), std::ifstream::in | std::ifstream::binary); + + // process files + while (location < tarSize) + { + TarInfo info = getTarObjectInfo(is); + + if (!info.success || info.realContentSize == 0) + { + break; // something wrong or all files are read + } + + switch (info.objType) + { + case '0': // file + case '\0': // + { + std::string obj = getTarObject(is, info.blocksContentSize); +#if TAR_DEBUG + size += info.realContentSize; + printf("#%i %s file size %lld written total %ld from %lld\n", ++count, + info.objName.c_str(), info.realContentSize, size, tarSize); +#endif + writeFile(tzdataPath, info.objName, obj, info.realContentSize); + location += info.blocksContentSize; + + break; + } + } + } + + remove(tarPath.c_str()); + + return true; + } + + TarInfo + getTarObjectInfo(std::ifstream &readStream) + { + int64_t length = TAR_BLOCK_SIZE; + char buffer[length]; + char type; + char name[TAR_NAME_SIZE + 1]; + char sizeBuf[TAR_SIZE_SIZE + 1]; + + readStream.read(buffer, length); + + memcpy(&type, &buffer[TAR_TYPE_POSITION], 1); + + memset(&name, '\0', TAR_NAME_SIZE + 1); + memcpy(&name, &buffer[TAR_NAME_POSITION], TAR_NAME_SIZE); + + memset(&sizeBuf, '\0', TAR_SIZE_SIZE + 1); + memcpy(&sizeBuf, &buffer[TAR_SIZE_POSITION], TAR_SIZE_SIZE); + size_t realSize = strtol(sizeBuf, NULL, 8); + size_t blocksSize = realSize + (TAR_BLOCK_SIZE - (realSize % TAR_BLOCK_SIZE)); + + return {type, std::string(name), realSize, blocksSize, true}; + } + + std::string + getTarObject(std::ifstream &readStream, int64_t size) + { + char buffer[size]; + readStream.read(buffer, size); + return std::string(buffer); + } + + bool + writeFile(const std::string &tzdataPath, const std::string &fileName, const std::string &data, + size_t realContentSize) + { + std::ofstream os(tzdataPath + "/" + fileName, std::ofstream::out | std::ofstream::binary); + + if (!os) { + return false; + } + + // trim empty space + char trimmedData[realContentSize + 1]; + memset(&trimmedData, '\0', realContentSize); + memcpy(&trimmedData, data.c_str(), realContentSize); + + // write + os.write(trimmedData, realContentSize); + os.close(); + + return true; + } + + } // namespace iOSUtils +} // namespace date +} // namespace arrow_vendored + +#endif // TARGET_OS_IPHONE diff --git a/contrib/libs/apache/arrow/cpp/src/arrow/vendored/datetime/tz.cpp b/contrib/libs/apache/arrow/cpp/src/arrow/vendored/datetime/tz.cpp index e80e392bd7..b1ae8a70fe 100644 --- a/contrib/libs/apache/arrow/cpp/src/arrow/vendored/datetime/tz.cpp +++ b/contrib/libs/apache/arrow/cpp/src/arrow/vendored/datetime/tz.cpp @@ -89,7 +89,7 @@ #include "tz_private.h" #ifdef __APPLE__ -# include "ios.h" +# include "ios.h" #else # define TARGET_OS_IPHONE 0 # define TARGET_OS_SIMULATOR 0 @@ -1309,7 +1309,7 @@ void detail::Rule::split(std::vector<Rule>& rules, std::size_t i, std::size_t k, std::size_t& e) { using namespace date; - using difference_type = std::iterator_traits<std::vector<Rule>::iterator>::difference_type; + using difference_type = std::iterator_traits<std::vector<Rule>::iterator>::difference_type; // rules[i].starting_year_ <= rules[k].starting_year_ && // rules[i].ending_year_ >= rules[k].starting_year_ && // (rules[i].starting_year_ != rules[k].starting_year_ || @@ -1377,7 +1377,7 @@ detail::Rule::split(std::vector<Rule>& rules, std::size_t i, std::size_t k, std: void detail::Rule::split_overlaps(std::vector<Rule>& rules, std::size_t i, std::size_t& e) { - using difference_type = std::iterator_traits<std::vector<Rule>::iterator>::difference_type; + using difference_type = std::iterator_traits<std::vector<Rule>::iterator>::difference_type; auto j = i; for (; i + 1 < e; ++i) { @@ -1401,7 +1401,7 @@ detail::Rule::split_overlaps(std::vector<Rule>& rules, std::size_t i, std::size_ void detail::Rule::split_overlaps(std::vector<Rule>& rules) { - using difference_type = std::iterator_traits<std::vector<Rule>::iterator>::difference_type; + using difference_type = std::iterator_traits<std::vector<Rule>::iterator>::difference_type; for (std::size_t i = 0; i < rules.size();) { auto e = static_cast<std::size_t>(std::upper_bound( diff --git a/contrib/libs/apache/arrow/src/arrow/util/config.h b/contrib/libs/apache/arrow/src/arrow/util/config.h index 2d46017e47..49dfcf79a3 100644 --- a/contrib/libs/apache/arrow/src/arrow/util/config.h +++ b/contrib/libs/apache/arrow/src/arrow/util/config.h @@ -35,7 +35,7 @@ #define ARROW_PACKAGE_KIND "" #define ARROW_COMPUTE -#define ARROW_CSV +#define ARROW_CSV /* #undef ARROW_DATASET */ /* #undef ARROW_FILESYSTEM */ /* #undef ARROW_FLIGHT */ diff --git a/contrib/libs/apache/arrow/ya.make b/contrib/libs/apache/arrow/ya.make index 27b9235d9e..e85b49d3f6 100644 --- a/contrib/libs/apache/arrow/ya.make +++ b/contrib/libs/apache/arrow/ya.make @@ -1,4 +1,4 @@ -# Generated by devtools/yamaker from nixpkgs 3322db8e36d0b32700737d8de7315bd9e9c2b21a. +# Generated by devtools/yamaker from nixpkgs 3322db8e36d0b32700737d8de7315bd9e9c2b21a. LIBRARY() @@ -35,7 +35,7 @@ PEERDIR( contrib/libs/lz4 contrib/libs/re2 contrib/libs/snappy - contrib/libs/utf8proc + contrib/libs/utf8proc contrib/libs/xxhash contrib/libs/zlib contrib/libs/zstd @@ -53,9 +53,9 @@ ADDINCL( contrib/libs/flatbuffers/include contrib/libs/lz4 contrib/libs/re2 - contrib/libs/utf8proc + contrib/libs/utf8proc contrib/libs/zstd/include - contrib/restricted/boost + contrib/restricted/boost ) NO_COMPILER_WARNINGS() @@ -63,14 +63,14 @@ NO_COMPILER_WARNINGS() NO_UTIL() CFLAGS( - GLOBAL -DARROW_STATIC + GLOBAL -DARROW_STATIC -DARROW_EXPORTING -DARROW_WITH_BROTLI -DARROW_WITH_LZ4 -DARROW_WITH_RE2 -DARROW_WITH_SNAPPY -DARROW_WITH_TIMING_TESTS - -DARROW_WITH_UTF8PROC + -DARROW_WITH_UTF8PROC -DARROW_WITH_ZLIB -DARROW_WITH_ZSTD -DHAVE_INTTYPES_H @@ -80,9 +80,9 @@ CFLAGS( ) IF (NOT OS_WINDOWS) - CFLAGS( - -DHAVE_NETINET_IN_H - ) + CFLAGS( + -DHAVE_NETINET_IN_H + ) ENDIF() SRCS( @@ -159,14 +159,14 @@ SRCS( cpp/src/arrow/compute/kernels/vector_sort.cc cpp/src/arrow/compute/registry.cc cpp/src/arrow/config.cc - cpp/src/arrow/csv/chunker.cc - cpp/src/arrow/csv/column_builder.cc - cpp/src/arrow/csv/column_decoder.cc - cpp/src/arrow/csv/converter.cc - cpp/src/arrow/csv/options.cc - cpp/src/arrow/csv/parser.cc - cpp/src/arrow/csv/reader.cc - cpp/src/arrow/csv/writer.cc + cpp/src/arrow/csv/chunker.cc + cpp/src/arrow/csv/column_builder.cc + cpp/src/arrow/csv/column_decoder.cc + cpp/src/arrow/csv/converter.cc + cpp/src/arrow/csv/options.cc + cpp/src/arrow/csv/parser.cc + cpp/src/arrow/csv/reader.cc + cpp/src/arrow/csv/writer.cc cpp/src/arrow/datum.cc cpp/src/arrow/device.cc cpp/src/arrow/extension_type.cc |