// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. #include #include #include #include #include #include #include "contrib/libs/apache/arrow_next/cpp/src/arrow/compare.h" #include "contrib/libs/apache/arrow_next/cpp/src/arrow/compute/api_scalar.h" #include "contrib/libs/apache/arrow_next/cpp/src/arrow/compute/cast.h" #include "contrib/libs/apache/arrow_next/cpp/src/arrow/compute/kernel.h" #include "contrib/libs/apache/arrow_next/cpp/src/arrow/compute/kernels/base_arithmetic_internal.h" #include "contrib/libs/apache/arrow_next/cpp/src/arrow/compute/kernels/codegen_internal.h" #include "contrib/libs/apache/arrow_next/cpp/src/arrow/compute/kernels/common_internal.h" #include "contrib/libs/apache/arrow_next/cpp/src/arrow/compute/kernels/util_internal.h" #include "contrib/libs/apache/arrow_next/cpp/src/arrow/type.h" #include "contrib/libs/apache/arrow_next/cpp/src/arrow/type_traits.h" #include "contrib/libs/apache/arrow_next/cpp/src/arrow/util/decimal.h" #include "contrib/libs/apache/arrow_next/cpp/src/arrow/util/int_util_overflow.h" #include "contrib/libs/apache/arrow_next/cpp/src/arrow/util/macros.h" #include "contrib/libs/apache/arrow_next/cpp/src/arrow/visit_scalar_inline.h" #include 
"contrib/libs/apache/arrow_next/cpp/src/arrow/visit_type_inline.h"

namespace arrow20 {

using internal::AddWithOverflow;
using internal::DivideWithOverflow;
using internal::MultiplyWithOverflow;
using internal::NegateWithOverflow;
using internal::SubtractWithOverflow;

namespace compute::internal {

using applicator::ScalarBinary;
using applicator::ScalarBinaryEqualTypes;
using applicator::ScalarBinaryNotNull;
using applicator::ScalarBinaryNotNullEqualTypes;
using applicator::ScalarUnary;
using applicator::ScalarUnaryNotNull;
using applicator::ScalarUnaryNotNullStateful;

namespace {

// ----------------------------------------------------------------------
// Begin utility structs for round kernels

// Convenience visitor to detect if a numeric Scalar is positive.
// `result` starts out false and is only written by the two typed overloads;
// the catch-all Scalar overload leaves it untouched and still returns OK.
struct IsPositiveVisitor {
  bool result = false;

  template Status Visit(const NumericScalar& scalar) {
    result = scalar.value > 0;
    return Status::OK();
  }
  template Status Visit(const DecimalScalar& scalar) {
    result = scalar.value > 0;
    return Status::OK();
  }
  Status Visit(const Scalar& scalar) { return Status::OK(); }
};

// Runs IsPositiveVisitor over `scalar` and returns the flag it recorded.
bool IsPositive(const Scalar& scalar) {
  IsPositiveVisitor visitor{};
  // The Status returned by the visit is intentionally discarded.
  std::ignore = VisitScalarInline(scalar, &visitor);
  return visitor.result;
}

struct RoundUtil {
  // Calculate powers of ten with arbitrary integer exponent
  // The lookup table holds 10^0..10^15; larger magnitudes are reached by
  // repeatedly multiplying the last entry by ten, and a negative `power`
  // returns the reciprocal of the positive power.
  template static enable_if_floating_value Pow10(int64_t power) {
    static constexpr T lut[] = {1e0F, 1e1F, 1e2F,  1e3F,  1e4F,  1e5F,  1e6F,  1e7F,
                                1e8F, 1e9F, 1e10F, 1e11F, 1e12F, 1e13F, 1e14F, 1e15F};
    int64_t lut_size = (sizeof(lut) / sizeof(*lut));
    int64_t abs_power = std::abs(power);
    auto pow10 = lut[std::min(abs_power, lut_size - 1)];
    while (abs_power-- >= lut_size) {
      pow10 *= 1e1F;
    }
    return (power >= 0) ? pow10 : (1 / pow10);
  }

  // Calculate powers of ten with arbitrary integer exponent
  // `power` indexes the 19-entry table (10^0..10^18) directly; its range is
  // only enforced by the DCHECKs below.
  template static enable_if_integer_value Pow10(int64_t power) {
    DCHECK_GE(power, 0);
    DCHECK_LE(power, std::numeric_limits::digits10);
    static constexpr uint64_t lut[] = {
        // clang-format off
      1ULL,
      10ULL,
      100ULL,
      1000ULL,
      10000ULL,
      100000ULL,
      1000000ULL,
      10000000ULL,
      100000000ULL,
      1000000000ULL,
      10000000000ULL,
      100000000000ULL,
      1000000000000ULL,
      10000000000000ULL,
      100000000000000ULL,
      1000000000000000ULL,
      10000000000000000ULL,
      100000000000000000ULL,
      1000000000000000000ULL
        // clang-format on
    };
    return static_cast(lut[power]);
  }
};

// End utility structs for round kernels
// ----------------------------------------------------------------------
// Begin round implementations for single scalar

// Specializations of rounding implementations for round kernels
// Each specialization offers three Round() overloads:
//  - floating point: rounds an already-scaled value and returns it;
//  - decimal: adjusts `*val` in place given the division `remainder` and the
//    rounding unit `pow10`;
//  - integer: receives `floor` and the rounding `multiple`, and reports
//    overflow through `*st` while returning the unrounded `val`.
// NOTE(review): the template argument lists are not visible in this text, so
// the rounding mode of each specialization below is inferred from its body
// and error messages -- confirm against the original source.
template struct RoundImpl;

// Presumably the round-down mode (std::floor / "down to multiple of").
template struct RoundImpl {
  template static constexpr enable_if_floating_value Round(const T val) {
    return std::floor(val);
  }

  // Removes the remainder, then steps one more `pow10` down when the
  // remainder was negative.
  template static enable_if_decimal_value Round(T* val, const T& remainder,
                                                const T& pow10, const int32_t scale) {
    (*val) -= remainder;
    if (remainder.Sign() < 0) {
      (*val) -= pow10;
    }
  }

  // For negative `val` the result is `floor - multiple`; the guard rejects the
  // case where that subtraction would go below the type's minimum.
  template static enable_if_integer_value Round(const T val, const T floor,
                                                const T multiple, Status* st) {
    if constexpr (is_signed_integer_value::value) {
      if (ARROW_PREDICT_FALSE(val < 0 && std::numeric_limits::min() + multiple > floor)) {
        *st = Status::Invalid("Rounding ", val, " down to multiple of ", multiple,
                              " would overflow");
        return val;
      }
      return val < 0 ? floor - multiple : floor;
    }
    return floor;
  }
};

// Presumably the round-up mode (std::ceil / "up to multiple of").
template struct RoundImpl {
  template static constexpr enable_if_floating_value Round(const T val) {
    return std::ceil(val);
  }

  // Removes the remainder, then steps one `pow10` up when the remainder was
  // positive.
  template static enable_if_decimal_value Round(T* val, const T& remainder,
                                                const T& pow10, const int32_t scale) {
    (*val) -= remainder;
    if (remainder.Sign() > 0 && remainder != 0) {
      (*val) += pow10;
    }
  }

  // For positive `val` the result is `floor + multiple`; the guard rejects the
  // case where that addition would exceed the type's maximum.
  template static enable_if_integer_value Round(const T val, const T floor,
                                                const T multiple, Status* st) {
    if (ARROW_PREDICT_FALSE(val > 0 && std::numeric_limits::max() - multiple < floor)) {
      *st = Status::Invalid("Rounding ", val, " up to multiple of ", multiple,
                            " would overflow");
      return val;
    }
    return val > 0 ? floor + multiple : floor;
  }
};

// Presumably the round-towards-zero mode (std::trunc; remainder is dropped).
template struct RoundImpl {
  template static constexpr enable_if_floating_value Round(const T val) {
    return std::trunc(val);
  }

  template static enable_if_decimal_value Round(T* val, const T& remainder,
                                                const T& pow10, const int32_t scale) {
    (*val) -= remainder;
  }

  // `floor` is returned unchanged; `val`, `pow10` and `st` are unused here.
  template static enable_if_integer_value Round(const T val, const T floor,
                                                const T pow10, Status* st) {
    return floor;
  }
};

// Presumably the round-away-from-zero mode: floor for negative inputs, ceil
// for non-negative ones.
template struct RoundImpl {
  template static constexpr enable_if_floating_value Round(const T val) {
    return std::signbit(val) ? std::floor(val) : std::ceil(val);
  }

  // Removes the remainder, then steps one `pow10` in the remainder's
  // direction.
  template static enable_if_decimal_value Round(T* val, const T& remainder,
                                                const T& pow10, const int32_t scale) {
    (*val) -= remainder;
    if (remainder.Sign() < 0) {
      (*val) -= pow10;
    } else if (remainder.Sign() > 0 && remainder != 0) {
      (*val) += pow10;
    }
  }

  // Applies both overflow guards (below the minimum for negative signed
  // values, above the maximum for positive values) before stepping `floor`
  // one `multiple` away from zero.
  template static enable_if_integer_value Round(const T val, const T floor,
                                                const T multiple, Status* st) {
    if constexpr (is_signed_integer_value::value) {
      if (ARROW_PREDICT_FALSE(val < 0 && std::numeric_limits::min() + multiple > floor)) {
        *st = Status::Invalid("Rounding ", val, " down to multiple of ", multiple,
                              " would overflow");
        return val;
      }
    }
    if (ARROW_PREDICT_FALSE(val > 0 && std::numeric_limits::max() - multiple < floor)) {
      *st = Status::Invalid("Rounding ", val, " up to multiple of ", multiple,
                            " would overflow");
      return val;
    }
    if constexpr (is_signed_integer_value::value) {
      return val < 0 ? floor - multiple : floor + multiple;
    }
    return floor + multiple;
  }
};

// NOTE: RoundImpl variants for the HALF_* rounding modes are only
// invoked when the fractional part is equal to 0.5 (std::round is invoked
// otherwise).
template struct RoundImpl { template static constexpr enable_if_floating_value Round(const T val) { return RoundImpl::Round(val); } template static enable_if_decimal_value Round(T* val, const T& remainder, const T& pow10, const int32_t scale) { RoundImpl::Round(val, remainder, pow10, scale); } template static constexpr enable_if_integer_value Round(const T val, const T floor, const T multiple, Status* st) { return RoundImpl::Round(val, floor, multiple, st); } }; template struct RoundImpl { template static constexpr enable_if_floating_value Round(const T val) { return RoundImpl::Round(val); } template static enable_if_decimal_value Round(T* val, const T& remainder, const T& pow10, const int32_t scale) { RoundImpl::Round(val, remainder, pow10, scale); } template static constexpr enable_if_integer_value Round(const T val, const T floor, const T multiple, Status* st) { return RoundImpl::Round(val, floor, multiple, st); } }; template struct RoundImpl { template static constexpr enable_if_floating_value Round(const T val) { return RoundImpl::Round(val); } template static enable_if_decimal_value Round(T* val, const T& remainder, const T& pow10, const int32_t scale) { RoundImpl::Round(val, remainder, pow10, scale); } template static constexpr enable_if_integer_value Round(const T val, const T floor, const T multiple, Status* st) { return RoundImpl::Round(val, floor, multiple, st); } }; template struct RoundImpl { template static constexpr enable_if_floating_value Round(const T val) { return RoundImpl::Round(val); } template static enable_if_decimal_value Round(T* val, const T& remainder, const T& multiple, const int32_t scale) { RoundImpl::Round(val, remainder, multiple, scale); } template static constexpr enable_if_integer_value Round(const T val, const T floor, const T multiple, Status* st) { return RoundImpl::Round(val, floor, multiple, st); } }; template struct RoundImpl { template static constexpr enable_if_floating_value Round(const T val) { return std::round(val * 
T(0.5)) * 2; } template static enable_if_decimal_value Round(T* val, const T& remainder, const T& pow10, const int32_t scale) { auto scaled = val->ReduceScaleBy(scale, /*round=*/false); if (scaled.low_bits() % 2 != 0) { scaled += remainder.Sign() >= 0 ? 1 : -1; } *val = scaled.IncreaseScaleBy(scale); } template static constexpr enable_if_integer_value Round(const T val, const T floor, const T multiple, Status* st) { if ((floor / multiple) % 2 == 0) { return floor; } return RoundImpl::Round(val, floor, multiple, st); } }; template struct RoundImpl { template static constexpr enable_if_floating_value Round(const T val) { return std::floor(val * T(0.5)) + std::ceil(val * T(0.5)); } template static enable_if_decimal_value Round(T* val, const T& remainder, const T& pow10, const int32_t scale) { auto scaled = val->ReduceScaleBy(scale, /*round=*/false); if (scaled.low_bits() % 2 == 0) { scaled += remainder.Sign() ? 1 : -1; } *val = scaled.IncreaseScaleBy(scale); } template static constexpr enable_if_integer_value Round(const T val, const T floor, const T multiple, Status* st) { if ((floor / multiple) % 2 != 0) { return floor; } return RoundImpl::Round(val, floor, multiple, st); } }; // End round implementations for single scalar // ---------------------------------------------------------------------- // Begin round options wrappers // Specializations of kernel state for round kernels // CType is the physical type used to store pow10 template struct RoundOptionsWrapper; template struct RoundOptionsWrapper : public OptionsWrapper { using OptionsType = RoundOptions; CType pow10; explicit RoundOptionsWrapper(OptionsType options) : OptionsWrapper(std::move(options)) { // Only positive exponents for powers of 10 are used because combining // multiply and division operations produced more stable rounding than // using multiply-only. 
Refer to NumPy's round implementation: // https://github.com/numpy/numpy/blob/7b2f20b406d27364c812f7a81a9c901afbd3600c/numpy/core/src/multiarray/calculation.c#L589 pow10 = RoundUtil::Pow10(std::abs(options.ndigits)); } static Result> Init(KernelContext* ctx, const KernelInitArgs& args) { if (auto options = static_cast(args.options)) { if constexpr (is_integer_value::value) { if (-options->ndigits > std::numeric_limits::digits10) { return Status::Invalid("Rounding to ", options->ndigits, " digits is out of range for type ", args.inputs[0].ToString()); } } return std::make_unique(*options); } return Status::Invalid( "Attempted to initialize KernelState from null FunctionOptions"); } }; template struct RoundOptionsWrapper : public OptionsWrapper { using OptionsType = RoundBinaryOptions; explicit RoundOptionsWrapper(OptionsType options) : OptionsWrapper(std::move(options)) {} static Result> Init(KernelContext* ctx, const KernelInitArgs& args) { if (auto options = static_cast(args.options)) { return std::make_unique(*options); } return Status::Invalid( "Attempted to initialize KernelState from null FunctionOptions"); } }; template struct RoundOptionsWrapper : public OptionsWrapper { using OptionsType = RoundToMultipleOptions; using OptionsWrapper::OptionsWrapper; static Result> Init(KernelContext* ctx, const KernelInitArgs& args) { auto options = static_cast(args.options); if (!options) { return Status::Invalid( "Attempted to initialize KernelState from null FunctionOptions"); } const auto& multiple = options->multiple; if (!multiple || !multiple->is_valid) { return Status::Invalid("Rounding multiple must be non-null and valid"); } if (!IsPositive(*multiple)) { return Status::Invalid("Rounding multiple must be positive"); } // Ensure the rounding multiple option matches the kernel's input type. 
std::shared_ptr to_type = args.inputs[0].GetSharedPtr(); if (!multiple->type->Equals(to_type)) { ARROW_ASSIGN_OR_RAISE( auto casted_multiple, Cast(Datum(multiple), to_type, CastOptions::Safe(), ctx->exec_context())); // Create a new option object if the rounding multiple was casted. auto new_options = OptionsType(casted_multiple.scalar(), options->round_mode); return std::make_unique(new_options); } return std::make_unique(*options); } }; template struct RoundOptionsTrait; template struct RoundOptionsTrait> { using CType = double; }; template struct RoundOptionsTrait> { using CType = double; }; template struct RoundOptionsTrait> { using CType = typename ArrowType::c_type; }; // End round options wrappers // ---------------------------------------------------------------------- // Begin round op implementations template struct RoundToMultiple { using CType = typename TypeTraits::CType; using State = RoundOptionsWrapper::CType>; CType multiple; explicit RoundToMultiple(const State& state, const DataType& out_ty) : multiple(UnboxScalar::Unbox(*state.options.multiple)) { const auto& options = state.options; DCHECK(options.multiple); DCHECK(options.multiple->is_valid); DCHECK(is_floating(options.multiple->type->id())); } template ::CType> enable_if_floating_value Call(KernelContext* ctx, CType arg, Status* st) const { // Do not process Inf or NaN because they will trigger the overflow error at end of // function. if (!std::isfinite(arg)) { return arg; } auto round_val = arg / multiple; auto frac = round_val - std::floor(round_val); if (frac != T(0)) { // Use std::round() if in tie-breaking mode and scaled value is not 0.5. if ((kRoundMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) { round_val = std::round(round_val); } else { round_val = RoundImpl::Round(round_val); } round_val *= multiple; if (!std::isfinite(round_val)) { *st = Status::Invalid("overflow occurred during rounding"); return arg; } } else { // If scaled value is an integer, then no rounding is needed. 
round_val = arg; } return round_val; } }; template struct RoundToMultiple> { using CType = typename TypeTraits::CType; using State = RoundOptionsWrapper; const ArrowType& ty; CType multiple, half_multiple, neg_half_multiple; bool has_halfway_point; explicit RoundToMultiple(const State& state, const DataType& out_ty) : ty(checked_cast(out_ty)), multiple(UnboxScalar::Unbox(*state.options.multiple)), half_multiple(multiple / 2), neg_half_multiple(-half_multiple), has_halfway_point(multiple.low_bits() % 2 == 0) { const auto& options = state.options; DCHECK(options.multiple); DCHECK(options.multiple->is_valid); DCHECK(options.multiple->type->Equals(out_ty)); } template ::CType> enable_if_decimal_value Call(KernelContext* ctx, CType arg, Status* st) const { std::pair pair; *st = arg.Divide(multiple).Value(&pair); if (!st->ok()) return arg; const auto& remainder = pair.second; if (remainder == 0) return arg; if (kRoundMode >= RoundMode::HALF_DOWN) { if (has_halfway_point && (remainder == half_multiple || remainder == neg_half_multiple)) { // On the halfway point, use tiebreaker // Manually implement rounding since we're not actually rounding a // decimal value, but rather manipulating the multiple switch (kRoundMode) { case RoundMode::HALF_DOWN: if (remainder.Sign() < 0) pair.first -= 1; break; case RoundMode::HALF_UP: if (remainder.Sign() >= 0) pair.first += 1; break; case RoundMode::HALF_TOWARDS_ZERO: // Do nothing break; case RoundMode::HALF_TOWARDS_INFINITY: pair.first += remainder.Sign() >= 0 ? 1 : -1; break; case RoundMode::HALF_TO_EVEN: if (pair.first.low_bits() % 2 != 0) { pair.first += remainder.Sign() >= 0 ? 1 : -1; } break; case RoundMode::HALF_TO_ODD: if (pair.first.low_bits() % 2 == 0) { pair.first += remainder.Sign() >= 0 ? 
1 : -1; } break; default: DCHECK(false); } } else if (remainder.Sign() >= 0) { // Positive, round up/down if (remainder > half_multiple) { pair.first += 1; } } else { // Negative, round up/down if (remainder < neg_half_multiple) { pair.first -= 1; } } } else { // Manually implement rounding since we're not actually rounding a // decimal value, but rather manipulating the multiple switch (kRoundMode) { case RoundMode::DOWN: if (remainder.Sign() < 0) pair.first -= 1; break; case RoundMode::UP: if (remainder.Sign() >= 0) pair.first += 1; break; case RoundMode::TOWARDS_ZERO: // Do nothing break; case RoundMode::TOWARDS_INFINITY: pair.first += remainder.Sign() >= 0 ? 1 : -1; break; default: DCHECK(false); } } CType round_val = pair.first * multiple; if (!round_val.FitsInPrecision(ty.precision())) { *st = Status::Invalid("Rounded value ", round_val.ToString(ty.scale()), " does not fit in precision of ", ty); return 0; } return round_val; } }; template struct RoundToMultiple> { using CType = typename TypeTraits::CType; using State = RoundOptionsWrapper; CType multiple; explicit RoundToMultiple(const State& state, const DataType& out_ty) : multiple(UnboxScalar::Unbox(*state.options.multiple)) { const auto& options = state.options; DCHECK(options.multiple); DCHECK(options.multiple->is_valid); DCHECK(is_integer(options.multiple->type->id())); } explicit RoundToMultiple(const CType multiple, const DataType& out_ty) : multiple(multiple) {} template ::CType> enable_if_integer_value Call(KernelContext* ctx, CType arg, Status* st) const { CType floor = arg / multiple * multiple; CType remainder = arg > floor ? 
arg - floor : floor - arg; if (remainder == 0) { return arg; } if (kRoundMode >= RoundMode::HALF_DOWN && remainder * 2 != multiple) { // not half way, round to nearest multiple of multiple like std::round if (remainder * 2 > multiple) { if (arg >= 0) { if (ARROW_PREDICT_FALSE(std::numeric_limits::max() - multiple < floor)) { *st = Status::Invalid("Rounding ", arg, " up to multiples of ", multiple, " would overflow"); return arg; } return floor + multiple; } else { if (ARROW_PREDICT_FALSE(std::numeric_limits::min() + multiple > floor)) { *st = Status::Invalid("Rounding ", arg, " down to multiples of ", multiple, " would overflow"); return arg; } return floor - multiple; } } else { return floor; } } else { return RoundImpl::Round(arg, floor, multiple, st); } } }; template struct Round { using CType = typename TypeTraits::CType; using State = RoundOptionsWrapper::CType>; CType pow10; int64_t ndigits; explicit Round(const State& state, const DataType& out_ty) : pow10(static_cast(state.pow10)), ndigits(state.options.ndigits) {} template ::CType> enable_if_floating_value Call(KernelContext* ctx, CType arg, Status* st) const { // Do not process Inf or NaN because they will trigger the overflow error at end of // function. if (!std::isfinite(arg)) { return arg; } auto round_val = ndigits >= 0 ? (arg * pow10) : (arg / pow10); auto frac = round_val - std::floor(round_val); if (frac != T(0)) { // Use std::round() if in tie-breaking mode and scaled value is not 0.5. if ((RndMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) { round_val = std::round(round_val); } else { round_val = RoundImpl::Round(round_val); } // Equality check is omitted so that the common case of 10^0 (integer rounding) // uses multiply-only round_val = ndigits > 0 ? (round_val / pow10) : (round_val * pow10); if (!std::isfinite(round_val)) { *st = Status::Invalid("overflow occurred during rounding"); return arg; } } else { // If scaled value is an integer, then no rounding is needed. 
// (Continuation of the floating-point Round::Call() that begins on the
// preceding line: the value was already exact, so return it unchanged.)
      round_val = arg;
    }
    return round_val;
  }
};

// Rounds a decimal value to `ndigits` fractional digits.
template struct Round> {
  using CType = typename TypeTraits::CType;
  using State = RoundOptionsWrapper;

  const ArrowType& ty;
  int64_t ndigits;
  // Number of trailing digits to drop: scale - ndigits.
  int32_t pow;
  // pow10 is "1" for the given decimal scale. Similarly half_pow10 is "0.5".
  CType pow10, half_pow10, neg_half_pow10;

  explicit Round(const State& state, const DataType& out_ty)
      : Round(state.options.ndigits, out_ty) {}

  // When `pow` is outside [0, precision) the multipliers are zeroed; Call()
  // handles both cases before touching them.
  explicit Round(int64_t ndigits, const DataType& out_ty)
      : ty(checked_cast(out_ty)),
        ndigits(ndigits),
        pow(static_cast(ty.scale() - ndigits)) {
    if (pow >= ty.precision() || pow < 0) {
      pow10 = half_pow10 = neg_half_pow10 = 0;
    } else {
      pow10 = CType::GetScaleMultiplier(pow);
      half_pow10 = CType::GetHalfScaleMultiplier(pow);
      neg_half_pow10 = -half_pow10;
    }
  }

  // Returns `arg` rounded to `ndigits` digits. Sets `*st` to Invalid and
  // returns 0 when the request or the result does not fit the type's
  // precision; returns `arg` unchanged when no digits need to be dropped.
  template ::CType> enable_if_decimal_value Call(KernelContext* ctx, CType arg,
                                                          Status* st) const {
    if (pow >= ty.precision()) {
      *st = Status::Invalid("Rounding to ", ndigits,
                            " digits will not fit in precision of ", ty);
      return 0;
    } else if (pow < 0) {
      // no-op, copy output to input
      return arg;
    }

    std::pair pair;
    *st = arg.Divide(pow10).Value(&pair);
    if (!st->ok()) return arg;
    // The remainder is effectively the scaled fractional part after division.
    const auto& remainder = pair.second;
    if (remainder == 0) return arg;
    if (kRoundMode >= RoundMode::HALF_DOWN) {
      if (remainder == half_pow10 || remainder == neg_half_pow10) {
        // On the halfway point, use tiebreaker
        RoundImpl::Round(&arg, remainder, pow10, pow);
      } else if (remainder.Sign() >= 0) {
        // Positive, round up/down
        arg -= remainder;
        if (remainder > half_pow10) {
          arg += pow10;
        }
      } else {
        // Negative, round up/down
        arg -= remainder;
        if (remainder < neg_half_pow10) {
          arg -= pow10;
        }
      }
    } else {
      // Directed rounding modes always go through the mode-specific helper.
      RoundImpl::Round(&arg, remainder, pow10, pow);
    }
    if (!arg.FitsInPrecision(ty.precision())) {
      *st = Status::Invalid("Rounded value ", arg.ToString(ty.scale()),
                            " does not fit in precision of ", ty);
      return 0;
    }
    return arg;
  }
};

// Rounds an integer value; only negative `ndigits` changes the input.
template struct Round> {
  using CType = typename TypeTraits::CType;
  using State = RoundOptionsWrapper;

  CType pow10;
  int64_t ndigits;
  const DataType& out_ty;

  explicit Round(const State& state, const DataType& out_ty)
      : pow10(static_cast(state.pow10)),
        ndigits(state.options.ndigits),
        out_ty(out_ty) {}

  // Delegates to RoundToMultiple with the precomputed power of ten as the
  // multiple; errors are reported through `*st` by that helper.
  template ::CType> enable_if_integer_value Call(KernelContext* ctx, CType arg,
                                                          Status* st) const {
    // no-op if ndigits is non-negative
    if (ndigits >= 0) {
      return arg;
    }
    // If ndigits is negative, then round to the nearest multiple of 10^ndigits.
    RoundToMultiple round_to_multiple(pow10, out_ty);
    return round_to_multiple.Call(ctx, arg, st);
  }
};

// Binary variant of Round for floating-point values: the digit count is the
// second argument of every call, so the constructor keeps no state.
template struct RoundBinary {
  using CType = typename TypeTraits::CType;
  using State = RoundOptionsWrapper::CType>;

  explicit RoundBinary(const State& state, const DataType& out_ty) {}

  // Returns `arg0` rounded to `arg1` digits.
  template ::CType0, typename CType1 = typename TypeTraits::CType1>
  enable_if_floating_value Call(KernelContext* ctx, CType0 arg0, CType1 arg1,
                                Status* st) const {
    // Do not process Inf or NaN because they will trigger the overflow error at end of
    // function.
if (!std::isfinite(arg0)) { return arg0; } // Only positive exponents for powers of 10 are used because combining // multiply and division operations produced more stable rounding than // using multiply-only. Refer to NumPy's round implementation: // https://github.com/numpy/numpy/blob/7b2f20b406d27364c812f7a81a9c901afbd3600c/numpy/core/src/multiarray/calculation.c#L589 double pow10 = RoundUtil::Pow10(std::abs(arg1)); auto round_val = arg1 >= 0 ? (arg0 * pow10) : (arg0 / pow10); auto frac = round_val - std::floor(round_val); if (frac != T(0)) { // Use std::round() if in tie-breaking mode and scaled value is not 0.5. if ((RndMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) { round_val = std::round(round_val); } else { round_val = RoundImpl::Round(round_val); } // Equality check is omitted so that the common case of 10^0 (integer rounding) // uses multiply-only round_val = arg1 > 0 ? (round_val / pow10) : (round_val * pow10); if (!std::isfinite(round_val)) { *st = Status::Invalid("overflow occurred during rounding"); return arg0; } } else { // If scaled value is an integer, then no rounding is needed. round_val = arg0; } return static_cast(round_val); } }; template struct RoundBinary> { using CType = typename TypeTraits::CType; using State = RoundOptionsWrapper; const ArrowType& ty; int32_t pow; // pow10 is "1" for the given decimal scale. Similarly half_pow10 is "0.5". 
CType half_pow10, neg_half_pow10; explicit RoundBinary(const State& state, const DataType& out_ty) : RoundBinary(out_ty) {} explicit RoundBinary(const DataType& out_ty) : ty(checked_cast(out_ty)), pow(static_cast(ty.scale() - 0)) { if (pow >= ty.precision() || pow < 0) { half_pow10 = neg_half_pow10 = 0; } else { half_pow10 = CType::GetHalfScaleMultiplier(pow); neg_half_pow10 = -half_pow10; } } template ::CType0, typename CType1 = typename TypeTraits::CType1> enable_if_decimal_value Call(KernelContext* ctx, CType0 arg0, CType1 arg1, Status* st) const { if (pow - arg1 >= ty.precision()) { *st = Status::Invalid("Rounding to ", arg1, " digits will not fit in precision of ", ty); return 0; } else if (pow < 0) { // no-op, copy output to input return arg0; } CType0 pow10 = CType0::GetScaleMultiplier(static_cast(ty.scale() - arg1)); std::pair pair; *st = arg0.Divide(pow10).Value(&pair); if (!st->ok()) return arg0; // The remainder is effectively the scaled fractional part after division. const auto& remainder = pair.second; if (remainder == 0) return arg0; if (kRoundMode >= RoundMode::HALF_DOWN) { if (remainder == half_pow10 || remainder == neg_half_pow10) { // On the halfway point, use tiebreaker RoundImpl::Round(&arg0, remainder, pow10, pow); } else if (remainder.Sign() >= 0) { // Positive, round up/down arg0 -= remainder; if (remainder > half_pow10) { arg0 += pow10; } } else { // Negative, round up/down arg0 -= remainder; if (remainder < neg_half_pow10) { arg0 -= pow10; } } } else { RoundImpl::Round(&arg0, remainder, pow10, pow); } if (!arg0.FitsInPrecision(ty.precision())) { *st = Status::Invalid("Rounded value ", arg0.ToString(ty.scale()), " does not fit in precision of ", ty); return 0; } return arg0; } }; template struct RoundBinary> { using CType = typename TypeTraits::CType; using State = RoundOptionsWrapper; const DataType& out_ty; explicit RoundBinary(const State& state, const DataType& out_ty) : out_ty(out_ty) {} template ::CType0, typename CType1 = typename 
TypeTraits::CType1> enable_if_integer_value Call(KernelContext* ctx, CType0 arg0, CType1 arg1, Status* st) const { // ndigits >= 0 is a no-op if (arg1 >= 0) { return arg0; } if (-arg1 > std::numeric_limits::digits10) { // ndigits is larger than the number of digits CType can hold *st = Status::Invalid("Rounding to ", arg1, " digits is out of range for type ", out_ty.ToString()); return arg0; } // If ndigits is negative, then round to the nearest multiple of 10^ndigits. CType pow10 = RoundUtil::Pow10(std::abs(arg1)); RoundToMultiple round_to_multiple(pow10, out_ty); return round_to_multiple.Call(ctx, arg0, st); } }; struct Floor { template static constexpr enable_if_floating_value Call(KernelContext*, Arg arg, Status*) { static_assert(std::is_same::value); return RoundImpl::Round(arg); } }; struct Ceil { template static constexpr enable_if_floating_value Call(KernelContext*, Arg arg, Status*) { static_assert(std::is_same::value); return RoundImpl::Round(arg); } }; struct Trunc { template static constexpr enable_if_floating_value Call(KernelContext*, Arg arg, Status*) { static_assert(std::is_same::value); return RoundImpl::Round(arg); } }; // End round op implementations // ---------------------------------------------------------------------- // Begin round functions struct RoundFunction : ScalarFunction { using ScalarFunction::ScalarFunction; Result DispatchBest(std::vector* types) const override { RETURN_NOT_OK(CheckArity(types->size())); using arrow20::compute::detail::DispatchExactImpl; if (auto kernel = DispatchExactImpl(this, *types)) return kernel; EnsureDictionaryDecoded(types); // for binary round functions, the second scalar must be int32 if (types->size() == 2 && (*types)[1].id() != Type::INT32) { (*types)[1] = int32(); } if (auto kernel = DispatchExactImpl(this, *types)) return kernel; return arrow20::compute::detail::NoMatchingKernel(this, *types); } }; /// A RoundFunction that promotes only decimal arguments to double. 
struct RoundDecimalToFloatingPointFunction : public RoundFunction {
  using RoundFunction::RoundFunction;

  // Tries an exact match, then retries with a decimal first argument replaced
  // by float64. The original first type is restored before reporting that no
  // kernel matches.
  Result DispatchBest(std::vector* types) const override {
    RETURN_NOT_OK(CheckArity(types->size()));

    using arrow20::compute::detail::DispatchExactImpl;
    if (auto kernel = DispatchExactImpl(this, *types)) return kernel;

    EnsureDictionaryDecoded(types);

    // Size of types is checked above.
    const auto originalType = (*types)[0];
    if (is_decimal((*types)[0].type->id())) {
      (*types)[0] = float64();
    }

    if (auto kernel = DispatchExactImpl(this, *types)) return kernel;

    (*types)[0] = originalType;
    return arrow20::compute::detail::NoMatchingKernel(this, *types);
  }
};

/// A RoundFunction that promotes only the first integer argument to double.
struct RoundIntegerToFloatingPointFunction : public RoundFunction {
  using RoundFunction::RoundFunction;

  // Same flow as above, but the promotion applies to an integer first
  // argument.
  Result DispatchBest(std::vector* types) const override {
    RETURN_NOT_OK(CheckArity(types->size()));

    using arrow20::compute::detail::DispatchExactImpl;
    if (auto kernel = DispatchExactImpl(this, *types)) return kernel;

    EnsureDictionaryDecoded(types);

    // Size of types is checked above.
    const auto originalType = (*types)[0];
    if (is_integer((*types)[0].type->id())) {
      (*types)[0] = float64();
    }

    if (auto kernel = DispatchExactImpl(this, *types)) return kernel;

    (*types)[0] = originalType;
    return arrow20::compute::detail::NoMatchingKernel(this, *types);
  }
};

/// A RoundFunction that promotes integer and decimal arguments to double.
struct RoundFloatingPointFunction : public RoundFunction {
  using RoundFunction::RoundFunction;

  // Same flow as above, but the promotion applies to an integer or decimal
  // first argument.
  Result DispatchBest(std::vector* types) const override {
    RETURN_NOT_OK(CheckArity(types->size()));

    using arrow20::compute::detail::DispatchExactImpl;
    if (auto kernel = DispatchExactImpl(this, *types)) return kernel;

    EnsureDictionaryDecoded(types);

    // Size of types is checked above.
    const auto originalType = (*types)[0];
    if (is_integer((*types)[0].type->id()) || is_decimal((*types)[0].type->id())) {
      (*types)[0] = float64();
    }

    if (auto kernel = DispatchExactImpl(this, *types)) return kernel;

    (*types)[0] = originalType;
    return arrow20::compute::detail::NoMatchingKernel(this, *types);
  }
};

// End round functions
// ----------------------------------------------------------------------
// Begin round kernels

// Expands to one `case` of the round-mode switch in RoundKernel::Exec: it
// builds the op for RoundMode::MODE from the kernel state and the output type
// and runs it over the batch. Relies on `OpImpl`, `state`, `ctx`, `batch` and
// `out` being in scope at the expansion site.
#define ROUND_CASE(MODE)                                       \
  case RoundMode::MODE: {                                      \
    using Op = OpImpl;                                         \
    return ScalarUnaryNotNullStateful(Op(state, *out->type())) \
        .Exec(ctx, batch, out);                                \
  }

// Exec the round kernel for the given types
template class OpImpl> struct RoundKernel {
  // Dispatches on the round mode stored in the kernel state. Falling out of
  // the switch means an unhandled mode and yields NotImplemented.
  static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
    using State = RoundOptionsWrapper::CType>;
    const auto& state = static_cast(*ctx->state());
    switch (state.options.round_mode) {
      ROUND_CASE(DOWN)
      ROUND_CASE(UP)
      ROUND_CASE(TOWARDS_ZERO)
      ROUND_CASE(TOWARDS_INFINITY)
      ROUND_CASE(HALF_DOWN)
      ROUND_CASE(HALF_UP)
      ROUND_CASE(HALF_TOWARDS_ZERO)
      ROUND_CASE(HALF_TOWARDS_INFINITY)
      ROUND_CASE(HALF_TO_EVEN)
      ROUND_CASE(HALF_TO_ODD)
    }
    DCHECK(false);
    return Status::NotImplemented(
        "Internal implementation error: round mode not implemented: ",
        state.options.ToString());
  }
};
#undef ROUND_CASE

// Binary counterpart of the case macron above: same expansion, but the op is
// run through applicator::ScalarBinaryNotNullStateful.
#define ROUND_BINARY_CASE(MODE)                     \
  case RoundMode::MODE: {                           \
    using Op = OpImpl;                              \
    return applicator::ScalarBinaryNotNullStateful( \
               Op(state, *out->type()))             \
        .Exec(ctx, batch, out);                     \
  }

// Exec the round (binary) kernel for the given types
template class OpImpl> struct RoundBinaryKernel {
  // Dispatches on the round mode stored in the kernel state. Falling out of
  // the switch means an unhandled mode and yields NotImplemented.
  static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
    using State = RoundOptionsWrapper::CType>;
    const auto& state = static_cast(*ctx->state());
    switch (state.options.round_mode) {
      ROUND_BINARY_CASE(DOWN)
      ROUND_BINARY_CASE(UP)
      ROUND_BINARY_CASE(TOWARDS_ZERO)
      ROUND_BINARY_CASE(TOWARDS_INFINITY)
      ROUND_BINARY_CASE(HALF_DOWN)
      ROUND_BINARY_CASE(HALF_UP)
      ROUND_BINARY_CASE(HALF_TOWARDS_ZERO)
      ROUND_BINARY_CASE(HALF_TOWARDS_INFINITY)
      ROUND_BINARY_CASE(HALF_TO_EVEN)
      ROUND_BINARY_CASE(HALF_TO_ODD)
    }
    DCHECK(false);
    return Status::NotImplemented(
        "Internal implementation error: round mode not implemented: ",
        state.options.ToString());
  }
};
#undef ROUND_BINARY_CASE

// Exec function for decimal rounding with a compile-time digit count: the op
// is built from `kDigits` and the output type, so no kernel state is read
// here.
template Status FixedRoundDecimalExec(KernelContext* ctx, const ExecSpan& batch,
                                      ExecResult* out) {
  using Op = Round;
  return ScalarUnaryNotNullStateful(
             Op(kDigits, *out->type()))
      .Exec(ctx, batch, out);
}

// End round kernels
// ----------------------------------------------------------------------
// Begin round kernel generation and function registration

// Type visitor that fills in the exec and init function pointers of a round
// kernel for integer, non-half floating-point and decimal types; any other
// type yields NotImplemented.
template < template class Op, template typename> class Kernel, typename OptionsType>
struct RoundKernelGenerator {
  template Status Visit(const ArrowType& type, ArrayKernelExec* exec,
                        KernelInit* init) {
    if constexpr (is_integer_type::value ||
                  (is_floating_type::value && !is_half_float_type::value) ||
                  is_decimal_type::value) {
      *exec = Kernel::Exec;
      *init = RoundOptionsWrapper::CType>::Init;
    } else {
      DCHECK(false);
      return Status::NotImplemented("Round does not support ", type.ToString());
    }
    return Status::OK();
  }
};

// For unary rounding functions that control kernel dispatch based on RoundMode, only on
// non-null output.
template