diff options
| author | AlexSm <[email protected]> | 2024-01-04 15:09:05 +0100 |
|---|---|---|
| committer | GitHub <[email protected]> | 2024-01-04 15:09:05 +0100 |
| commit | dab291146f6cd7d35684e3a1150e5bb1c412982c (patch) | |
| tree | 36ef35f6cacb6432845a4a33f940c95871036b32 /contrib/clickhouse/src/Functions/FunctionBinaryArithmetic.h | |
| parent | 63660ad5e7512029fd0218e7a636580695a24e1f (diff) | |
Library import 5, delete go dependencies (#832)
* Library import 5, delete go dependencies
* Fix yt client
Diffstat (limited to 'contrib/clickhouse/src/Functions/FunctionBinaryArithmetic.h')
| -rw-r--r-- | contrib/clickhouse/src/Functions/FunctionBinaryArithmetic.h | 2482 |
1 files changed, 0 insertions, 2482 deletions
diff --git a/contrib/clickhouse/src/Functions/FunctionBinaryArithmetic.h b/contrib/clickhouse/src/Functions/FunctionBinaryArithmetic.h deleted file mode 100644 index a409b75c83e..00000000000 --- a/contrib/clickhouse/src/Functions/FunctionBinaryArithmetic.h +++ /dev/null @@ -1,2482 +0,0 @@ -#pragma once - -// Include this first, because `#define _asan_poison_address` from -// llvm/Support/Compiler.h conflicts with its forward declaration in -// sanitizer/asan_interface.h -#include <memory> -#include <type_traits> -#include <base/wide_integer_to_string.h> - -#include <Columns/ColumnAggregateFunction.h> -#include <Columns/ColumnConst.h> -#include <Columns/ColumnDecimal.h> -#include <Columns/ColumnFixedString.h> -#include <Columns/ColumnNullable.h> -#include <Columns/ColumnString.h> -#include <Columns/ColumnVector.h> -#include <Core/DecimalFunctions.h> -#include <DataTypes/DataTypeAggregateFunction.h> -#include <DataTypes/DataTypeDate.h> -#include <DataTypes/DataTypeDateTime.h> -#include <DataTypes/DataTypeDateTime64.h> -#include <DataTypes/DataTypeFactory.h> -#include <DataTypes/DataTypeFixedString.h> -#include <DataTypes/DataTypeInterval.h> -#include <DataTypes/DataTypeTuple.h> -#include <DataTypes/DataTypeString.h> -#include <DataTypes/DataTypeIPv4andIPv6.h> -#include <DataTypes/DataTypesDecimal.h> -#include <DataTypes/DataTypesNumber.h> -#include <DataTypes/Native.h> -#include <DataTypes/NumberTraits.h> -#include <Functions/DivisionUtils.h> -#include <Functions/FunctionFactory.h> -#include <Functions/FunctionHelpers.h> -#include <Functions/IFunction.h> -#include <Functions/IsOperation.h> -#include <Functions/castTypeToEither.h> -#include <Interpreters/castColumn.h> -#include <base/TypeList.h> -#include <base/map.h> -#include <Common/FieldVisitorsAccurateComparison.h> -#include <Common/assert_cast.h> -#include <Common/typeid_cast.h> -#include <Common/Arena.h> -#include <Core/ColumnWithTypeAndName.h> -#include <base/types.h> -#include <Columns/ColumnArray.h> -#include <Columns/IColumn.h> -#include <Core/ColumnsWithTypeAndName.h> -#include <DataTypes/IDataType.h> -#include <DataTypes/getMostSubtype.h> -#include <base/TypeLists.h> -#include <DataTypes/DataTypeArray.h> -#include <DataTypes/DataTypeLowCardinality.h> -#include <Interpreters/Context.h> - -#if USE_EMBEDDED_COMPILER -# error #include <llvm/IR/IRBuilder.h> -#endif - -#include <cassert> - -namespace DB -{ - -namespace ErrorCodes -{ - extern const int ILLEGAL_COLUMN; - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int LOGICAL_ERROR; - extern const int DECIMAL_OVERFLOW; - extern const int CANNOT_ADD_DIFFERENT_AGGREGATE_STATES; - extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; - extern const int SIZES_OF_ARRAYS_DONT_MATCH; -} - -namespace traits_ -{ -struct InvalidType; /// Used to indicate undefined operation - -template <bool V, typename T> struct Case : std::bool_constant<V> { using type = T; }; - -/// Switch<Case<C0, T0>, ...> -- select the first Ti for which Ci is true, InvalidType if none. -template <typename... Ts> using Switch = typename std::disjunction<Ts..., Case<true, InvalidType>>::type; - -template <class T> -using DataTypeFromFieldType = std::conditional_t<std::is_same_v<T, NumberTraits::Error>, - InvalidType, DataTypeNumber<T>>; - -template <typename DataType> constexpr bool IsIntegral = false; -template <> inline constexpr bool IsIntegral<DataTypeUInt8> = true; -template <> inline constexpr bool IsIntegral<DataTypeUInt16> = true; -template <> inline constexpr bool IsIntegral<DataTypeUInt32> = true; -template <> inline constexpr bool IsIntegral<DataTypeUInt64> = true; -template <> inline constexpr bool IsIntegral<DataTypeInt8> = true; -template <> inline constexpr bool IsIntegral<DataTypeInt16> = true; -template <> inline constexpr bool IsIntegral<DataTypeInt32> = true; -template <> inline constexpr bool IsIntegral<DataTypeInt64> = true; - -template <typename DataType> constexpr bool IsExtended = false; -template <> inline constexpr bool IsExtended<DataTypeUInt128> = true; -template <> inline constexpr bool IsExtended<DataTypeUInt256> = true; -template <> inline constexpr bool IsExtended<DataTypeInt128> = true; -template <> inline constexpr bool IsExtended<DataTypeInt256> = true; - -template <typename DataType> constexpr bool IsIntegralOrExtended = IsIntegral<DataType> || IsExtended<DataType>; -template <typename DataType> constexpr bool IsIntegralOrExtendedOrDecimal = - IsIntegralOrExtended<DataType> || - IsDataTypeDecimal<DataType>; - -template <typename DataType> constexpr bool IsFloatingPoint = false; -template <> inline constexpr bool IsFloatingPoint<DataTypeFloat32> = true; -template <> inline constexpr bool IsFloatingPoint<DataTypeFloat64> = true; - -template <typename DataType> constexpr bool IsArray = false; -template <> inline constexpr bool IsArray<DataTypeArray> = true; - -template <typename DataType> constexpr bool IsDateOrDateTime = false; -template <> inline constexpr bool IsDateOrDateTime<DataTypeDate> = true; -template <> inline constexpr bool IsDateOrDateTime<DataTypeDateTime> = true; - -template <typename DataType> constexpr bool IsIPv4 = false; -template <> inline constexpr bool IsIPv4<DataTypeIPv4> = true; - -template <typename T0, typename T1> constexpr bool UseLeftDecimal = false; -template <> inline constexpr bool UseLeftDecimal<DataTypeDecimal<Decimal256>, DataTypeDecimal<Decimal128>> = true; -template <> inline constexpr bool UseLeftDecimal<DataTypeDecimal<Decimal256>, DataTypeDecimal<Decimal64>> = true; -template <> inline constexpr bool UseLeftDecimal<DataTypeDecimal<Decimal256>, DataTypeDecimal<Decimal32>> = true; -template <> inline constexpr bool UseLeftDecimal<DataTypeDecimal<Decimal128>, DataTypeDecimal<Decimal32>> = true; -template <> inline constexpr bool UseLeftDecimal<DataTypeDecimal<Decimal128>, DataTypeDecimal<Decimal64>> = true; -template <> inline constexpr bool UseLeftDecimal<DataTypeDecimal<Decimal64>, DataTypeDecimal<Decimal32>> = true; - -template <typename DataType> constexpr bool IsFixedString = false; -template <> inline constexpr bool IsFixedString<DataTypeFixedString> = true; - -template <typename DataType> constexpr bool IsString = false; -template <> inline constexpr bool IsString<DataTypeString> = true; - -template <template <typename, typename> class Operation, typename LeftDataType, typename RightDataType> -struct BinaryOperationTraits -{ - using T0 = typename LeftDataType::FieldType; - using T1 = typename RightDataType::FieldType; -private: /// it's not correct for Decimal - using Op = Operation<T0, T1>; - -public: - static constexpr bool allow_decimal = IsOperation<Operation>::allow_decimal; - - /// Appropriate result type for binary operator on numeric types. "Date" can also mean - /// DateTime, but if both operands are Dates, their type must be the same (e.g. Date - DateTime is invalid). - using ResultDataType = Switch< - /// Decimal cases - Case<!allow_decimal && (IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>), InvalidType>, - Case< - IsDataTypeDecimal<LeftDataType> && IsDataTypeDecimal<RightDataType> && UseLeftDecimal<LeftDataType, RightDataType>, - LeftDataType>, - Case<IsDataTypeDecimal<LeftDataType> && IsDataTypeDecimal<RightDataType>, RightDataType>, - Case<IsDataTypeDecimal<LeftDataType> && IsIntegralOrExtended<RightDataType>, LeftDataType>, - Case<IsDataTypeDecimal<RightDataType> && IsIntegralOrExtended<LeftDataType>, RightDataType>, - - /// e.g Decimal +-*/ Float, least(Decimal, Float), greatest(Decimal, Float) = Float64 - Case<IsOperation<Operation>::allow_decimal && IsDataTypeDecimal<LeftDataType> && IsFloatingPoint<RightDataType>, DataTypeFloat64>, - Case<IsOperation<Operation>::allow_decimal && IsDataTypeDecimal<RightDataType> && IsFloatingPoint<LeftDataType>, DataTypeFloat64>, - - Case<IsOperation<Operation>::bit_hamming_distance && IsIntegral<LeftDataType> && IsIntegral<RightDataType>, DataTypeUInt8>, - Case<IsOperation<Operation>::bit_hamming_distance && IsFixedString<LeftDataType> && IsFixedString<RightDataType>, DataTypeUInt16>, - Case<IsOperation<Operation>::bit_hamming_distance && IsString<LeftDataType> && IsString<RightDataType>, DataTypeUInt64>, - - /// Decimal <op> Real is not supported (traditional DBs convert Decimal <op> Real to Real) - Case<IsDataTypeDecimal<LeftDataType> && !IsIntegralOrExtendedOrDecimal<RightDataType>, InvalidType>, - Case<IsDataTypeDecimal<RightDataType> && !IsIntegralOrExtendedOrDecimal<LeftDataType>, InvalidType>, - - /// number <op> number -> see corresponding impl - Case<!IsDateOrDateTime<LeftDataType> && !IsDateOrDateTime<RightDataType>, DataTypeFromFieldType<typename Op::ResultType>>, - - /// Date + Integral -> Date - /// Integral + Date -> Date - Case< - IsOperation<Operation>::plus, - Switch<Case<IsIntegral<RightDataType>, LeftDataType>, Case<IsIntegral<LeftDataType>, RightDataType>>>, - - /// Date - Date -> Int32 - /// Date - Integral -> Date - Case< - IsOperation<Operation>::minus, - Switch< - Case<std::is_same_v<LeftDataType, RightDataType>, DataTypeInt32>, - Case<IsDateOrDateTime<LeftDataType> && IsIntegral<RightDataType>, LeftDataType>>>, - - /// least(Date, Date) -> Date - /// greatest(Date, Date) -> Date - Case< - std::is_same_v<LeftDataType, RightDataType> && (IsOperation<Operation>::least || IsOperation<Operation>::greatest), - LeftDataType>, - - /// Date % Int32 -> Int32 - /// Date % Float -> Float64 - Case< - IsOperation<Operation>::modulo || IsOperation<Operation>::positive_modulo, - Switch< - Case<IsDateOrDateTime<LeftDataType> && IsIntegral<RightDataType>, RightDataType>, - Case<IsDateOrDateTime<LeftDataType> && IsFloatingPoint<RightDataType>, DataTypeFloat64>>>>; -}; -} - -namespace impl_ -{ - -/** Arithmetic operations: +, -, *, /, %, - * intDiv (integer division) - * Bitwise operations: |, &, ^, ~. - * Etc. - */ - -enum class OpCase { Vector, LeftConstant, RightConstant }; - -constexpr const auto & undec(const auto & x) { return x; } -constexpr const auto & undec(const is_decimal auto & x) { return x.value; } - -template <typename A, typename B, typename Op, typename OpResultType = typename Op::ResultType> -struct BinaryOperation -{ - using ResultType = OpResultType; - static const constexpr bool allow_fixed_string = false; - static const constexpr bool allow_string_integer = false; - - template <OpCase op_case> - static void NO_INLINE process(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t size, const NullMap * right_nullmap = nullptr) - { - if constexpr (op_case == OpCase::RightConstant) - { - if (right_nullmap && (*right_nullmap)[0]) - return; - - for (size_t i = 0; i < size; ++i) - c[i] = Op::template apply<ResultType>(a[i], *b); - } - else - { - if (right_nullmap) - { - for (size_t i = 0; i < size; ++i) - if ((*right_nullmap)[i]) - c[i] = ResultType(); - else - apply<op_case>(a, b, c, i); - } - else - for (size_t i = 0; i < size; ++i) - apply<op_case>(a, b, c, i); - } - } - - static ResultType process(A a, B b) { return Op::template apply<ResultType>(a, b); } - -private: - template <OpCase op_case> - static inline void apply(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t i) - { - if constexpr (op_case == OpCase::Vector) - c[i] = Op::template apply<ResultType>(a[i], b[i]); - else - c[i] = Op::template apply<ResultType>(*a, b[i]); - } -}; - -template <typename B, typename Op> -struct StringIntegerOperationImpl -{ - template <OpCase op_case> - static void NO_INLINE processFixedString(const UInt8 * __restrict in_vec, const UInt64 n, const B * __restrict b, ColumnFixedString::Chars & out_vec, size_t size) - { - size_t prev_offset = 0; - out_vec.reserve(n * size); - for (size_t i = 0; i < size; ++i) - { - if constexpr (op_case == OpCase::LeftConstant) - { - Op::apply(&in_vec[0], &in_vec[n], b[i], out_vec); - } - else - { - size_t new_offset = prev_offset + n; - - if constexpr (op_case == OpCase::Vector) - { - Op::apply(&in_vec[prev_offset], &in_vec[new_offset], b[i], out_vec); - } - else - { - Op::apply(&in_vec[prev_offset], &in_vec[new_offset], b[0], out_vec); - } - prev_offset = new_offset; - } - } - } - - - template <OpCase op_case> - static void NO_INLINE processString(const UInt8 * __restrict in_vec, const UInt64 * __restrict in_offsets, const B * __restrict b, ColumnString::Chars & out_vec, ColumnString::Offsets & out_offsets, size_t size) - { - size_t prev_offset = 0; - - for (size_t i = 0; i < size; ++i) - { - if constexpr (op_case == OpCase::LeftConstant) - { - Op::apply(&in_vec[0], &in_vec[in_offsets[0] - 1], b[i], out_vec, out_offsets); - } - else - { - size_t new_offset = in_offsets[i]; - - if constexpr (op_case == OpCase::Vector) - { - Op::apply(&in_vec[prev_offset], &in_vec[new_offset - 1], b[i], out_vec, out_offsets); - } - else - { - Op::apply(&in_vec[prev_offset], &in_vec[new_offset - 1], b[0], out_vec, out_offsets); - } - - prev_offset = new_offset; - } - } - } -}; - -template <typename Op> -struct FixedStringOperationImpl -{ - template <OpCase op_case> - static void NO_INLINE process( - const UInt8 * __restrict a, const UInt8 * __restrict b, UInt8 * __restrict result, - size_t size, [[maybe_unused]] size_t N) - { - if constexpr (op_case == OpCase::Vector) - for (size_t i = 0; i < size; ++i) - result[i] = Op::template apply<UInt8>(a[i], b[i]); - else if constexpr (op_case == OpCase::LeftConstant) - withConst<true>(b, a, result, size, N); - else - withConst<false>(a, b, result, size, N); - } - -private: - template <bool inverted> - static void NO_INLINE withConst(const UInt8 * __restrict a, const UInt8 * __restrict b, UInt8 * __restrict c, size_t size, size_t N) - { - /// These complications are needed to avoid integer division in inner loop. - - /// Create a pattern of repeated values of b with at least 16 bytes, - /// so we can read 16 bytes of this repeated pattern starting from any offset inside b. - /// - /// Example: - /// - /// N = 6 - /// ------ - /// [abcdefabcdefabcdefabc] - /// ^^^^^^^^^^^^^^^^ - /// 16 bytes starting from the last offset inside b. - - const size_t b_repeated_size = N + 15; - - UInt8 b_repeated[b_repeated_size]; - - for (size_t i = 0; i < b_repeated_size; ++i) - b_repeated[i] = b[i % N]; - - size_t b_offset = 0; - const size_t b_increment = 16 % N; - - /// Example: - /// - /// At first iteration we copy 16 bytes at offset 0 from b_repeated: - /// [abcdefabcdefabcdefabc] - /// ^^^^^^^^^^^^^^^^ - /// At second iteration we copy 16 bytes at offset 4 = 16 % 6 from b_repeated: - /// [abcdefabcdefabcdefabc] - /// ^^^^^^^^^^^^^^^^ - /// At third iteration we copy 16 bytes at offset 2 = (16 * 2) % 6 from b_repeated: - /// [abcdefabcdefabcdefabc] - /// ^^^^^^^^^^^^^^^^ - - /// PaddedPODArray allows overflow for 15 bytes. - for (size_t i = 0; i < size; i += 16) - { - /// This loop is formed in a way to be vectorized into two SIMD mov. - for (size_t j = 0; j < 16; ++j) - c[i + j] = inverted - ? Op::template apply<UInt8>(a[i + j], b_repeated[b_offset + j]) - : Op::template apply<UInt8>(b_repeated[b_offset + j], a[i + j]); - - b_offset += b_increment; - - if (b_offset >= N) /// This condition is easily predictable. - b_offset -= N; - } - } -}; - -template <typename Op> -struct FixedStringReduceOperationImpl -{ - template <OpCase op_case> - static void inline process(const UInt8 * __restrict a, const UInt8 * __restrict b, UInt16 * __restrict result, size_t size, size_t N) - { - if constexpr (op_case == OpCase::Vector) - vectorVector(a, b, result, size, N); - else if constexpr (op_case == OpCase::LeftConstant) - vectorConstant(b, a, result, size, N); - else - vectorConstant(a, b, result, size, N); - } - -private: - static void vectorVector(const UInt8 * __restrict a, const UInt8 * __restrict b, UInt16 * __restrict result, size_t size, size_t N) - { - for (size_t i = 0; i < size; ++i) - { - size_t offset = i * N; - for (size_t j = 0; j < N; ++j) - { - result[i] += Op::template apply<UInt8>(a[offset + j], b[offset + j]); - } - } - } - - static void vectorConstant(const UInt8 * __restrict a, const UInt8 * __restrict b, UInt16 * __restrict result, size_t size, size_t N) - { - for (size_t i = 0; i < size; ++i) - { - size_t offset = i * N; - for (size_t j = 0; j < N; ++j) - { - result[i] += Op::template apply<UInt8>(a[offset + j], b[j]); - } - } - } -}; - -template <typename Op> -struct StringReduceOperationImpl -{ - static void vectorVector( - const ColumnString::Chars & a, - const ColumnString::Offsets & offsets_a, - const ColumnString::Chars & b, - const ColumnString::Offsets & offsets_b, - PaddedPODArray<UInt64> & res) - { - size_t size = res.size(); - for (size_t i = 0; i < size; ++i) - { - res[i] = process( - a.data() + offsets_a[i - 1], - a.data() + offsets_a[i] - 1, - b.data() + offsets_b[i - 1], - b.data() + offsets_b[i] - 1); - } - } - - static void - vectorConstant(const ColumnString::Chars & a, const ColumnString::Offsets & offsets_a, std::string_view b, PaddedPODArray<UInt64> & res) - { - size_t size = res.size(); - for (size_t i = 0; i < size; ++i) - { - res[i] = process( - a.data() + offsets_a[i - 1], - a.data() + offsets_a[i] - 1, - reinterpret_cast<const UInt8 *>(b.data()), - reinterpret_cast<const UInt8 *>(b.data()) + b.size()); - } - } - - static inline UInt64 constConst(std::string_view a, std::string_view b) - { - return process( - reinterpret_cast<const UInt8 *>(a.data()), - reinterpret_cast<const UInt8 *>(a.data()) + a.size(), - reinterpret_cast<const UInt8 *>(b.data()), - reinterpret_cast<const UInt8 *>(b.data()) + b.size()); - } - -private: - static UInt64 process(const UInt8 * __restrict start_a, const UInt8 * __restrict end_a, const UInt8 * start_b, const UInt8 * end_b) - { - UInt64 res = 0; - while (start_a < end_a && start_b < end_b) - res += Op::template apply<UInt8>(*start_a++, *start_b++); - - while (start_a < end_a) - res += Op::template apply<UInt8>(*start_a++, 0); - while (start_b < end_b) - res += Op::template apply<UInt8>(0, *start_b++); - return res; - } -}; - -template <typename A, typename B, typename Op, typename ResultType = typename Op::ResultType> -struct BinaryOperationImpl : BinaryOperation<A, B, Op, ResultType> { }; - -/** - * Binary operations with Decimals (either Decimal OP Decimal or Decimal Op Float) need to scale the args correctly. - * - + (plus), - (minus), * (multiply), least and greatest operations scale one of the args (which scale factor is not 1). - * The resulting scale is either left or the right scale. - * - / (divide) operation scales the first argument. - * The resulting scale is the first one's. - */ -template <template <typename, typename> typename Operation, class OpResultType, bool check_overflow = true> -struct DecimalBinaryOperation -{ -private: - using ResultType = OpResultType; // e.g. Decimal32 - using NativeResultType = NativeType<ResultType>; // e.g. UInt32 for Decimal32 - - using ResultContainerType = typename ColumnVectorOrDecimal<ResultType>::Container; - -public: - template <OpCase op_case, bool is_decimal_a, bool is_decimal_b> - static void NO_INLINE process(const auto & a, const auto & b, ResultContainerType & c, - NativeResultType scale_a, NativeResultType scale_b, const NullMap * right_nullmap = nullptr) - { - if constexpr (op_case == OpCase::LeftConstant) static_assert(!is_decimal<decltype(a)>); - if constexpr (op_case == OpCase::RightConstant) static_assert(!is_decimal<decltype(b)>); - - size_t size; - - if constexpr (op_case == OpCase::LeftConstant) - size = b.size(); - else - size = a.size(); - - if constexpr (is_plus_minus_compare) - { - if (scale_a != 1) - { - for (size_t i = 0; i < size; ++i) - c[i] = applyScaled<true>( - static_cast<NativeResultType>(unwrap<op_case, OpCase::LeftConstant>(a, i)), - static_cast<NativeResultType>(unwrap<op_case, OpCase::RightConstant>(b, i)), - scale_a); - return; - } - else if (scale_b != 1) - { - for (size_t i = 0; i < size; ++i) - c[i] = applyScaled<false>( - static_cast<NativeResultType>(unwrap<op_case, OpCase::LeftConstant>(a, i)), - static_cast<NativeResultType>(unwrap<op_case, OpCase::RightConstant>(b, i)), - scale_b); - return; - } - } - else if constexpr (is_multiply) - { - if (scale_a != 1) - { - for (size_t i = 0; i < size; ++i) - c[i] = applyScaled<true, false>( - static_cast<NativeResultType>(unwrap<op_case, OpCase::LeftConstant>(a, i)), - static_cast<NativeResultType>(unwrap<op_case, OpCase::RightConstant>(b, i)), - scale_a); - return; - } - else if (scale_b != 1) - { - for (size_t i = 0; i < size; ++i) - c[i] = applyScaled<false, false>( - static_cast<NativeResultType>(unwrap<op_case, OpCase::LeftConstant>(a, i)), - static_cast<NativeResultType>(unwrap<op_case, OpCase::RightConstant>(b, i)), - scale_b); - return; - } - - } - else if constexpr (is_division && is_decimal_b) - { - processWithRightNullmapImpl<op_case>(a, b, c, size, right_nullmap, [&scale_a](const auto & left, const auto & right) - { - return applyScaledDiv<is_decimal_a>( - static_cast<NativeResultType>(left), right, scale_a); - }); - return; - } - - processWithRightNullmapImpl<op_case>( - a, b, c, size, right_nullmap, - [](const auto & left, const auto & right) - { - return apply( - static_cast<NativeResultType>(left), - static_cast<NativeResultType>(right)); - }); - } - - template <bool is_decimal_a, bool is_decimal_b, class A, class B> - static ResultType process(A a, B b, NativeResultType scale_a, NativeResultType scale_b) - requires(!is_decimal<A> && !is_decimal<B>) - { - if constexpr (is_division && is_decimal_b) - return applyScaledDiv<is_decimal_a>(a, b, scale_a); - else if constexpr (is_plus_minus_compare) - { - if (scale_a != 1) - return applyScaled<true>(a, b, scale_a); - if (scale_b != 1) - return applyScaled<false>(a, b, scale_b); - } - - return apply(a, b); - } - -private: - template <OpCase op_case, typename ApplyFunc> - static inline void processWithRightNullmapImpl(const auto & a, const auto & b, ResultContainerType & c, size_t size, const NullMap * right_nullmap, ApplyFunc apply_func) - { - if (right_nullmap) - { - if constexpr (op_case == OpCase::RightConstant) - { - if ((*right_nullmap)[0]) - return; - - for (size_t i = 0; i < size; ++i) - c[i] = apply_func(undec(a[i]), undec(b)); - } - else - { - for (size_t i = 0; i < size; ++i) - { - if ((*right_nullmap)[i]) - c[i] = ResultType(); - else - c[i] = apply_func(unwrap<op_case, OpCase::LeftConstant>(a, i), undec(b[i])); - } - } - } - else - for (size_t i = 0; i < size; ++i) - c[i] = apply_func(unwrap<op_case, OpCase::LeftConstant>(a, i), unwrap<op_case, OpCase::RightConstant>(b, i)); - } - - static constexpr bool is_plus_minus = IsOperation<Operation>::plus || - IsOperation<Operation>::minus; - static constexpr bool is_multiply = IsOperation<Operation>::multiply; - static constexpr bool is_float_division = IsOperation<Operation>::div_floating; - static constexpr bool is_int_division = IsOperation<Operation>::div_int || - IsOperation<Operation>::div_int_or_zero; - static constexpr bool is_division = is_float_division || is_int_division; - static constexpr bool is_compare = IsOperation<Operation>::least || - IsOperation<Operation>::greatest; - static constexpr bool is_plus_minus_compare = is_plus_minus || is_compare; - static constexpr bool can_overflow = is_plus_minus || is_multiply; - - using Op = std::conditional_t<is_float_division, - DivideIntegralImpl<NativeResultType, NativeResultType>, /// substitute divide by intDiv (throw on division by zero) - Operation<NativeResultType, NativeResultType>>; - - template <OpCase op_case, OpCase target, class E> - static auto unwrap(const E& elem, size_t i) - { - if constexpr (op_case == target) - return undec(elem); - else - return undec(elem[i]); - } - - /// there's implicit type conversion here - static NativeResultType apply(NativeResultType a, NativeResultType b) - { - if constexpr (can_overflow && check_overflow) - { - NativeResultType res; - if (Op::template apply<NativeResultType>(a, b, res)) - throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal math overflow"); - return res; - } - else - return Op::template apply<NativeResultType>(a, b); - } - - template <bool scale_left, bool may_check_overflow = true> - static NO_SANITIZE_UNDEFINED NativeResultType applyScaled(NativeResultType a, NativeResultType b, NativeResultType scale) - { - static_assert(is_plus_minus_compare || is_multiply); - NativeResultType res; - - if constexpr (check_overflow && may_check_overflow) - { - bool overflow = false; - - if constexpr (scale_left) - overflow |= common::mulOverflow(a, scale, a); - else - overflow |= common::mulOverflow(b, scale, b); - - if constexpr (can_overflow) - overflow |= Op::template apply<NativeResultType>(a, b, res); - else - res = Op::template apply<NativeResultType>(a, b); - - if (overflow) - throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal math overflow"); - } - else - { - if constexpr (scale_left) - a *= scale; - else - b *= scale; - res = Op::template apply<NativeResultType>(a, b); - } - - return res; - } - - template <bool is_decimal_a> - static NO_SANITIZE_UNDEFINED NativeResultType applyScaledDiv(NativeResultType a, NativeResultType b, NativeResultType scale) - { - if constexpr (is_division) - { - if constexpr (check_overflow) - { - bool overflow = false; - if constexpr (!is_decimal_a) - overflow |= common::mulOverflow(scale, scale, scale); - overflow |= common::mulOverflow(a, scale, a); - if (overflow) - throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal math overflow"); - } - else - { - if constexpr (!is_decimal_a) - scale *= scale; - a *= scale; - } - - return Op::template apply<NativeResultType>(a, b); - } - } -}; -} - -using namespace traits_; -using namespace impl_; - -template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true, bool valid_on_float_arguments = true, bool division_by_nullable = false> -class FunctionBinaryArithmetic : public IFunction -{ - static constexpr bool is_plus = IsOperation<Op>::plus; - static constexpr bool is_minus = IsOperation<Op>::minus; - static constexpr bool is_multiply = IsOperation<Op>::multiply; - static constexpr bool is_division = IsOperation<Op>::division; - static constexpr bool is_bit_hamming_distance = IsOperation<Op>::bit_hamming_distance; - static constexpr bool is_modulo = IsOperation<Op>::modulo; - static constexpr bool is_div_int = IsOperation<Op>::div_int; - static constexpr bool is_div_int_or_zero = IsOperation<Op>::div_int_or_zero; - - ContextPtr context; - bool check_decimal_overflow = true; - - static bool castType(const IDataType * type, auto && f) - { - using Types = TypeList< - DataTypeUInt8, DataTypeUInt16, DataTypeUInt32, DataTypeUInt64, DataTypeUInt128, DataTypeUInt256, - DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64, DataTypeInt128, DataTypeInt256, - DataTypeDecimal32, DataTypeDecimal64, DataTypeDecimal128, DataTypeDecimal256, - DataTypeDate, DataTypeDateTime, - DataTypeFixedString, DataTypeString, - DataTypeInterval>; - - using Floats = TypeList<DataTypeFloat32, DataTypeFloat64>; - - using ValidTypes = std::conditional_t<valid_on_float_arguments, - TypeListConcat<Types, Floats>, - Types>; - - return castTypeToEither(ValidTypes{}, type, std::forward<decltype(f)>(f)); - } - - template <typename F> - static bool castBothTypes(const IDataType * left, const IDataType * right, F && f) - { - return castType(left, [&](const auto & left_) - { - return castType(right, [&](const auto & right_) - { - return f(left_, right_); - }); - }); - } - - static FunctionOverloadResolverPtr - getFunctionForIntervalArithmetic(const DataTypePtr & type0, const DataTypePtr & type1, ContextPtr context) - { - bool first_is_date_or_datetime = isDateOrDate32(type0) || isDateTime(type0) || isDateTime64(type0); - bool second_is_date_or_datetime = isDateOrDate32(type1) || isDateTime(type1) || isDateTime64(type1); - - /// Exactly one argument must be Date or DateTime - if (first_is_date_or_datetime == second_is_date_or_datetime) - return {}; - - /// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval. - /// We construct another function (example: addMonths) and call it. - - if constexpr (!is_plus && !is_minus) - return {}; - - const DataTypePtr & type_time = first_is_date_or_datetime ? type0 : type1; - const DataTypePtr & type_interval = first_is_date_or_datetime ? type1 : type0; - - bool interval_is_number = isNumber(type_interval); - - const DataTypeInterval * interval_data_type = nullptr; - if (!interval_is_number) - { - interval_data_type = checkAndGetDataType<DataTypeInterval>(type_interval.get()); - - if (!interval_data_type) - return {}; - } - - if (second_is_date_or_datetime && is_minus) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Wrong order of arguments for function {}: " - "argument of type Interval cannot be first", name); - - std::string function_name; - if (interval_data_type) - { - function_name = fmt::format("{}{}s", - is_plus ? "add" : "subtract", - interval_data_type->getKind().toString()); - } - else - { - if (isDateOrDate32(type_time)) - function_name = is_plus ? "addDays" : "subtractDays"; - else - function_name = is_plus ? "addSeconds" : "subtractSeconds"; - } - - return FunctionFactory::instance().get(function_name, context); - } - - static FunctionOverloadResolverPtr - getFunctionForDateTupleOfIntervalsArithmetic(const DataTypePtr & type0, const DataTypePtr & type1, ContextPtr context) - { - bool first_is_date_or_datetime = isDateOrDate32(type0) || isDateTime(type0) || isDateTime64(type0); - bool second_is_date_or_datetime = isDateOrDate32(type1) || isDateTime(type1) || isDateTime64(type1); - - /// Exactly one argument must be Date or DateTime - if (first_is_date_or_datetime == second_is_date_or_datetime) - return {}; - - if (!isTuple(type0) && !isTuple(type1)) - return {}; - - /// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Tuple. - /// We construct another function and call it. - if constexpr (!is_plus && !is_minus) - return {}; - - if (isTuple(type0) && second_is_date_or_datetime && is_minus) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Wrong order of arguments for function {}: " - "argument of Tuple type cannot be first", name); - - std::string function_name; - if (is_plus) - { - function_name = "addTupleOfIntervals"; - } - else - { - function_name = "subtractTupleOfIntervals"; - } - - return FunctionFactory::instance().get(function_name, context); - } - - static FunctionOverloadResolverPtr - getFunctionForMergeIntervalsArithmetic(const DataTypePtr & type0, const DataTypePtr & type1, ContextPtr context) - { - /// Special case when the function is plus or minus, first argument is Interval or Tuple of Intervals - /// and the second argument is the Interval of a different kind. - /// We construct another function (example: addIntervals) and call it - - if constexpr (!is_plus && !is_minus) - return {}; - - const auto * tuple_data_type_0 = checkAndGetDataType<DataTypeTuple>(type0.get()); - const auto * interval_data_type_0 = checkAndGetDataType<DataTypeInterval>(type0.get()); - const auto * interval_data_type_1 = checkAndGetDataType<DataTypeInterval>(type1.get()); - - if ((!tuple_data_type_0 && !interval_data_type_0) || !interval_data_type_1) - return {}; - - if (interval_data_type_0 && interval_data_type_0->equals(*interval_data_type_1)) - return {}; - - if (tuple_data_type_0) - { - const auto & tuple_types = tuple_data_type_0->getElements(); - for (const auto & type : tuple_types) - if (!isInterval(type)) - return {}; - } - - std::string function_name; - if (is_plus) - { - function_name = "addInterval"; - } - else - { - function_name = "subtractInterval"; - } - - return FunctionFactory::instance().get(function_name, context); - } - - static FunctionOverloadResolverPtr - getFunctionForTupleArithmetic(const DataTypePtr & type0, const DataTypePtr & type1, ContextPtr context) - { - if (!isTuple(type0) || !isTuple(type1)) - return {}; - - /// Special case when the function is plus, minus or multiply, both arguments are tuples. - /// We construct another function (example: tuplePlus) and call it. - - if constexpr (!is_plus && !is_minus && !is_multiply) - return {}; - - std::string function_name; - if (is_plus) - { - function_name = "tuplePlus"; - } - else if (is_minus) - { - function_name = "tupleMinus"; - } - else - { - function_name = "dotProduct"; - } - - return FunctionFactory::instance().get(function_name, context); - } - - static FunctionOverloadResolverPtr - getFunctionForTupleAndNumberArithmetic(const DataTypePtr & type0, const DataTypePtr & type1, ContextPtr context) - { - if (!(isTuple(type0) && isNumber(type1)) && !(isTuple(type1) && isNumber(type0))) - return {}; - - /// Special case when the function is multiply or divide, one of arguments is Tuple and another is Number. - /// We construct another function (example: tupleMultiplyByNumber) and call it. - - if constexpr (!is_multiply && !is_division) - return {}; - - if (isNumber(type0) && is_division) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Wrong order of arguments for function {}: " - "argument of numeric type cannot be first", name); - - std::string function_name; - if constexpr (is_multiply) - { - function_name = "tupleMultiplyByNumber"; - } - else // is_division - { - if constexpr (is_modulo) - { - function_name = "tupleModuloByNumber"; - } - else if constexpr (is_div_int) - { - function_name = "tupleIntDivByNumber"; - } - else if constexpr (is_div_int_or_zero) - { - function_name = "tupleIntDivOrZeroByNumber"; - } - else - { - function_name = "tupleDivideByNumber"; - } - } - - return FunctionFactory::instance().get(function_name, context); - } - - static bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1) - { - if constexpr (!is_multiply) - return false; - - WhichDataType which0(type0); - WhichDataType which1(type1); - - return (which0.isAggregateFunction() && which1.isNativeUInt()) - || (which0.isNativeUInt() && which1.isAggregateFunction()); - } - - static bool isAggregateAddition(const DataTypePtr & type0, const DataTypePtr & type1) - { - if constexpr (!is_plus) - return false; - - WhichDataType which0(type0); - WhichDataType which1(type1); - - return which0.isAggregateFunction() && which1.isAggregateFunction(); - } - - /// Multiply aggregation state by integer constant: by merging it with itself specified number of times. - ColumnPtr executeAggregateMultiply(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const - { - ColumnsWithTypeAndName new_arguments = arguments; - if (WhichDataType(new_arguments[1].type).isAggregateFunction()) - std::swap(new_arguments[0], new_arguments[1]); - - if (!isColumnConst(*new_arguments[1].column)) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of argument of aggregation state multiply. " - "Should be integer constant", new_arguments[1].column->getName()); - - const IColumn & agg_state_column = *new_arguments[0].column; - bool agg_state_is_const = isColumnConst(agg_state_column); - const ColumnAggregateFunction & column = typeid_cast<const ColumnAggregateFunction &>( - agg_state_is_const ? assert_cast<const ColumnConst &>(agg_state_column).getDataColumn() : agg_state_column); - - AggregateFunctionPtr function = column.getAggregateFunction(); - - size_t size = agg_state_is_const ? 1 : input_rows_count; - - auto column_to = ColumnAggregateFunction::create(function); - column_to->reserve(size); - - auto column_from = ColumnAggregateFunction::create(function); - column_from->reserve(size); - - for (size_t i = 0; i < size; ++i) - { - column_to->insertDefault(); - column_from->insertFrom(column.getData()[i]); - } - - auto & vec_to = column_to->getData(); - auto & vec_from = column_from->getData(); - - UInt64 m = typeid_cast<const ColumnConst *>(new_arguments[1].column.get())->getValue<UInt64>(); - - // Since we merge the function states by ourselves, we have to have an - // Arena for this. Pass it to the resulting column so that the arena - // has a proper lifetime. - auto arena = std::make_shared<Arena>(); - column_to->addArena(arena); - - /// We use exponentiation by squaring algorithm to perform multiplying aggregate states by N in O(log(N)) operations - /// https://en.wikipedia.org/wiki/Exponentiation_by_squaring - while (m) - { - if (m % 2) - { - for (size_t i = 0; i < size; ++i) - function->merge(vec_to[i], vec_from[i], arena.get()); - --m; - } - else - { - for (size_t i = 0; i < size; ++i) - function->merge(vec_from[i], vec_from[i], arena.get()); - m /= 2; - } - } - - if (agg_state_is_const) - return ColumnConst::create(std::move(column_to), input_rows_count); - else - return column_to; - } - - /// Merge two aggregation states together. - ColumnPtr executeAggregateAddition(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const - { - const IColumn & lhs_column = *arguments[0].column; - const IColumn & rhs_column = *arguments[1].column; - - bool lhs_is_const = isColumnConst(lhs_column); - bool rhs_is_const = isColumnConst(rhs_column); - - const ColumnAggregateFunction & lhs = typeid_cast<const ColumnAggregateFunction &>( - lhs_is_const ? assert_cast<const ColumnConst &>(lhs_column).getDataColumn() : lhs_column); - const ColumnAggregateFunction & rhs = typeid_cast<const ColumnAggregateFunction &>( - rhs_is_const ? assert_cast<const ColumnConst &>(rhs_column).getDataColumn() : rhs_column); - - AggregateFunctionPtr function = lhs.getAggregateFunction(); - - size_t size = (lhs_is_const && rhs_is_const) ? 1 : input_rows_count; - - auto column_to = ColumnAggregateFunction::create(function); - column_to->reserve(size); - - for (size_t i = 0; i < size; ++i) - { - column_to->insertFrom(lhs.getData()[lhs_is_const ? 0 : i]); - column_to->insertMergeFrom(rhs.getData()[rhs_is_const ? 0 : i]); - } - - if (lhs_is_const && rhs_is_const) - return ColumnConst::create(std::move(column_to), input_rows_count); - else - return column_to; - } - - ColumnPtr executeDateTimeIntervalPlusMinus(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, - size_t input_rows_count, const FunctionOverloadResolverPtr & function_builder) const - { - ColumnsWithTypeAndName new_arguments = arguments; - - /// Interval argument must be second. - if (isDateOrDate32(arguments[1].type) || isDateTime(arguments[1].type) || isDateTime64(arguments[1].type)) - std::swap(new_arguments[0], new_arguments[1]); - - /// Change interval argument type to its representation - if (WhichDataType(new_arguments[1].type).isInterval()) - new_arguments[1].type = std::make_shared<DataTypeNumber<DataTypeInterval::FieldType>>(); - - auto function = function_builder->build(new_arguments); - return function->execute(new_arguments, result_type, input_rows_count); - } - - ColumnPtr executeDateTimeTupleOfIntervalsPlusMinus(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, - size_t input_rows_count, const FunctionOverloadResolverPtr & function_builder) const - { - ColumnsWithTypeAndName new_arguments = arguments; - - /// Tuple argument must be second. - if (isTuple(arguments[0].type)) - std::swap(new_arguments[0], new_arguments[1]); - - auto function = function_builder->build(new_arguments); - - return function->execute(new_arguments, result_type, input_rows_count); - } - - ColumnPtr executeIntervalTupleOfIntervalsPlusMinus(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, - size_t input_rows_count, const FunctionOverloadResolverPtr & function_builder) const - { - auto function = function_builder->build(arguments); - - return function->execute(arguments, result_type, input_rows_count); - } - - ColumnPtr executeArrayImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const - { - const auto * return_type_array = checkAndGetDataType<DataTypeArray>(result_type.get()); - - if (!return_type_array) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Return type for function {} must be array.", getName()); - - auto num_args = arguments.size(); - DataTypes data_types; - - ColumnsWithTypeAndName new_arguments {num_args}; - DataTypePtr result_array_type; - - const auto * left_const = typeid_cast<const ColumnConst *>(arguments[0].column.get()); - const auto * right_const = typeid_cast<const ColumnConst *>(arguments[1].column.get()); - - /// Unpacking arrays if both are constants. - if (left_const && right_const) - { - new_arguments[0] = {left_const->getDataColumnPtr(), arguments[0].type, arguments[0].name}; - new_arguments[1] = {right_const->getDataColumnPtr(), arguments[1].type, arguments[1].name}; - auto col = executeImpl(new_arguments, result_type, 1); - return ColumnConst::create(std::move(col), input_rows_count); - } - - /// Unpacking arrays if at least one column is constant. - if (left_const || right_const) - { - new_arguments[0] = {arguments[0].column->convertToFullColumnIfConst(), arguments[0].type, arguments[0].name}; - new_arguments[1] = {arguments[1].column->convertToFullColumnIfConst(), arguments[1].type, arguments[1].name}; - return executeImpl(new_arguments, result_type, input_rows_count); - } - - const auto * left_array_col = typeid_cast<const ColumnArray *>(arguments[0].column.get()); - const auto * right_array_col = typeid_cast<const ColumnArray *>(arguments[1].column.get()); - if (!left_array_col->hasEqualOffsets(*right_array_col)) - throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Two arguments for function {} must have equal sizes", getName()); - - const auto & left_array_type = typeid_cast<const DataTypeArray *>(arguments[0].type.get())->getNestedType(); - new_arguments[0] = {left_array_col->getDataPtr(), left_array_type, arguments[0].name}; - - const auto & right_array_type = typeid_cast<const DataTypeArray *>(arguments[1].type.get())->getNestedType(); - new_arguments[1] = {right_array_col->getDataPtr(), right_array_type, arguments[1].name}; - - result_array_type = typeid_cast<const DataTypeArray *>(result_type.get())->getNestedType(); - - size_t rows_count = 0; - const auto & left_offsets = left_array_col->getOffsets(); - if (!left_offsets.empty()) - rows_count = left_offsets.back(); - auto res = executeImpl(new_arguments, result_array_type, rows_count); - - return ColumnArray::create(res, typeid_cast<const ColumnArray *>(arguments[0].column.get())->getOffsetsPtr()); - } - - ColumnPtr executeTupleNumberOperator(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, - size_t input_rows_count, const FunctionOverloadResolverPtr & function_builder) const - { - ColumnsWithTypeAndName new_arguments = arguments; - - /// Number argument must be second. - if (isNumber(arguments[0].type)) - std::swap(new_arguments[0], new_arguments[1]); - - auto function = function_builder->build(new_arguments); - - return function->execute(new_arguments, result_type, input_rows_count); - } - - template <typename T, typename ResultDataType> - static auto helperGetOrConvert(const auto & col_const, const auto & col) - { - using ResultType = typename ResultDataType::FieldType; - using NativeResultType = NativeType<ResultType>; - - if constexpr (IsFloatingPoint<ResultDataType> && is_decimal<T>) - return DecimalUtils::convertTo<NativeResultType>(col_const->template getValue<T>(), col.getScale()); - else if constexpr (is_decimal<T>) - return col_const->template getValue<T>().value; - else - return col_const->template getValue<T>(); - } - - template <OpCase op_case, bool left_decimal, bool right_decimal, typename OpImpl, typename OpImplCheck> - void helperInvokeEither(const auto& left, const auto& right, auto& vec_res, auto scale_a, auto scale_b, const NullMap * right_nullmap) const - { - if (check_decimal_overflow) - OpImplCheck::template process<op_case, left_decimal, right_decimal>(left, right, vec_res, scale_a, scale_b, right_nullmap); - else - OpImpl::template process<op_case, left_decimal, right_decimal>(left, right, vec_res, scale_a, scale_b, right_nullmap); - } - - template <class LeftDataType, class RightDataType, class ResultDataType> - ColumnPtr executeNumericWithDecimal( - const auto & left, const auto & right, - const ColumnConst * const col_left_const, const ColumnConst * const col_right_const, - const auto * const col_left, const auto * const col_right, - size_t col_left_size, const NullMap * right_nullmap) const - { - using T0 = typename LeftDataType::FieldType; - using T1 = typename RightDataType::FieldType; - using ResultType = typename ResultDataType::FieldType; - - using NativeResultType = NativeType<ResultType>; - using OpImpl = DecimalBinaryOperation<Op, ResultType, false>; - using OpImplCheck = DecimalBinaryOperation<Op, ResultType, true>; - - using ColVecResult = ColumnVectorOrDecimal<ResultType>; - - static constexpr const bool left_is_decimal = is_decimal<T0>; - static constexpr const bool right_is_decimal = is_decimal<T1>; - - typename ColVecResult::MutablePtr col_res = nullptr; - - const ResultDataType type = decimalResultType<is_multiply, is_division>(left, right); - - const ResultType scale_a = [&] - { - if constexpr (IsDataTypeDecimal<RightDataType> && is_division) - return right.getScaleMultiplier(); // the division impl uses only the scale_a - else - { - if constexpr (is_multiply) - // the decimal impl uses scales, but if the result is decimal, both of the arguments are decimal, - // so they would multiply correctly, so we need to scale the result to the neutral element (1). - // The explicit type is needed as the int (in contrast with float) can't be implicitly converted - // to decimal. - return ResultType{1}; - else - return type.scaleFactorFor(left, false); - } - }(); - - const ResultType scale_b = [&] - { - if constexpr (is_multiply) - return ResultType{1}; - else - return type.scaleFactorFor(right, is_division); - }(); - - /// non-vector result - if (col_left_const && col_right_const) - { - const NativeResultType const_a = static_cast<NativeResultType>( - helperGetOrConvert<T0, ResultDataType>(col_left_const, left)); - const NativeResultType const_b = static_cast<NativeResultType>( - helperGetOrConvert<T1, ResultDataType>(col_right_const, right)); - - ResultType res = {}; - if (!right_nullmap || !(*right_nullmap)[0]) - res = check_decimal_overflow - ? OpImplCheck::template process<left_is_decimal, right_is_decimal>(const_a, const_b, scale_a, scale_b) - : OpImpl::template process<left_is_decimal, right_is_decimal>(const_a, const_b, scale_a, scale_b); - - return ResultDataType(type.getPrecision(), type.getScale()) - .createColumnConst(col_left_const->size(), toField(res, type.getScale())); - } - - col_res = ColVecResult::create(0, type.getScale()); - - auto & vec_res = col_res->getData(); - vec_res.resize(col_left_size); - - if (col_left && col_right) - { - helperInvokeEither<OpCase::Vector, left_is_decimal, right_is_decimal, OpImpl, OpImplCheck>( - col_left->getData(), col_right->getData(), vec_res, scale_a, scale_b, right_nullmap); - } - else if (col_left_const && col_right) - { - const NativeResultType const_a = static_cast<NativeResultType>( - helperGetOrConvert<T0, ResultDataType>(col_left_const, left)); - - helperInvokeEither<OpCase::LeftConstant, left_is_decimal, right_is_decimal, OpImpl, OpImplCheck>( - const_a, col_right->getData(), vec_res, scale_a, scale_b, right_nullmap); - } - else if (col_left && col_right_const) - { - const NativeResultType const_b = static_cast<NativeResultType>( - helperGetOrConvert<T1, ResultDataType>(col_right_const, right)); - - helperInvokeEither<OpCase::RightConstant, left_is_decimal, right_is_decimal, OpImpl, OpImplCheck>( - col_left->getData(), const_b, vec_res, scale_a, scale_b, right_nullmap); - } - else - return nullptr; - - return col_res; - } - -public: - static constexpr auto name = Name::name; - static FunctionPtr create(ContextPtr context) { return std::make_shared<FunctionBinaryArithmetic>(context); } - - explicit FunctionBinaryArithmetic(ContextPtr context_) - : context(context_), - check_decimal_overflow(decimalCheckArithmeticOverflow(context)) - {} - - String getName() const override { return name; } - - size_t getNumberOfArguments() const override { return 2; } - - bool useDefaultImplementationForNulls() const override - { - /// We shouldn't use default implementation for nulls for the case when operation is divide, - /// intDiv or modulo and denominator is Nullable(Something), because it may cause division - /// by zero error (when value is Null we store default value 0 in nested column). - return !division_by_nullable; - } - - bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & arguments) const override - { - return ((IsOperation<Op>::div_int || IsOperation<Op>::modulo || IsOperation<Op>::positive_modulo) && !arguments[1].is_const) - || (IsOperation<Op>::div_floating - && (isDecimalOrNullableDecimal(arguments[0].type) || isDecimalOrNullableDecimal(arguments[1].type))); - } - - DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override - { - return getReturnTypeImplStatic(arguments, context); - } - - static DataTypePtr getReturnTypeImplStatic(const DataTypes & arguments, ContextPtr context) - { - /// Special case when multiply aggregate function state - if (isAggregateMultiply(arguments[0], arguments[1])) - { - if (WhichDataType(arguments[0]).isAggregateFunction()) - return arguments[0]; - return arguments[1]; - } - - /// Special case - addition of two aggregate functions states - if (isAggregateAddition(arguments[0], arguments[1])) - { - if (!arguments[0]->equals(*arguments[1])) - throw Exception(ErrorCodes::CANNOT_ADD_DIFFERENT_AGGREGATE_STATES, - "Cannot add aggregate states of different functions: {} and {}", - arguments[0]->getName(), arguments[1]->getName()); - - return arguments[0]; - } - - /// Special case - one or both arguments are IPv4 - if (isIPv4(arguments[0]) || isIPv4(arguments[1])) - { - DataTypes new_arguments { - isIPv4(arguments[0]) ? std::make_shared<DataTypeUInt32>() : arguments[0], - isIPv4(arguments[1]) ? std::make_shared<DataTypeUInt32>() : arguments[1], - }; - - return getReturnTypeImplStatic(new_arguments, context); - } - - - if constexpr (is_plus || is_minus) - { - if (isArray(arguments[0]) && isArray(arguments[1])) - { - DataTypes new_arguments { - static_cast<const DataTypeArray &>(*arguments[0]).getNestedType(), - static_cast<const DataTypeArray &>(*arguments[1]).getNestedType(), - }; - return std::make_shared<DataTypeArray>(getReturnTypeImplStatic(new_arguments, context)); - } - } - - - /// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval. - if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0], arguments[1], context)) - { - ColumnsWithTypeAndName new_arguments(2); - - for (size_t i = 0; i < 2; ++i) - new_arguments[i].type = arguments[i]; - - /// Interval argument must be second. - if (isDateOrDate32(new_arguments[1].type) || isDateTime(new_arguments[1].type) || isDateTime64(new_arguments[1].type)) - std::swap(new_arguments[0], new_arguments[1]); - - /// Change interval argument to its representation - new_arguments[1].type = std::make_shared<DataTypeNumber<DataTypeInterval::FieldType>>(); - - auto function = function_builder->build(new_arguments); - return function->getResultType(); - } - - /// Special case when the function is plus, minus or multiply, both arguments are tuples. - if (auto function_builder = getFunctionForTupleArithmetic(arguments[0], arguments[1], context)) - { - ColumnsWithTypeAndName new_arguments(2); - - for (size_t i = 0; i < 2; ++i) - new_arguments[i].type = arguments[i]; - - auto function = function_builder->build(new_arguments); - return function->getResultType(); - } - - /// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Tuple. - if (auto function_builder = getFunctionForDateTupleOfIntervalsArithmetic(arguments[0], arguments[1], context)) - { - ColumnsWithTypeAndName new_arguments(2); - - for (size_t i = 0; i < 2; ++i) - new_arguments[i].type = arguments[i]; - - /// Tuple argument must be second. - if (isTuple(new_arguments[0].type)) - std::swap(new_arguments[0], new_arguments[1]); - - auto function = function_builder->build(new_arguments); - return function->getResultType(); - } - - /// Special case when the function is plus or minus, one of arguments is Interval/Tuple of Intervals and another is Interval. - if (auto function_builder = getFunctionForMergeIntervalsArithmetic(arguments[0], arguments[1], context)) - { - ColumnsWithTypeAndName new_arguments(2); - - for (size_t i = 0; i < 2; ++i) - new_arguments[i].type = arguments[i]; - - auto function = function_builder->build(new_arguments); - return function->getResultType(); - } - - /// Special case when the function is multiply or divide, one of arguments is Tuple and another is Number. - if (auto function_builder = getFunctionForTupleAndNumberArithmetic(arguments[0], arguments[1], context)) - { - ColumnsWithTypeAndName new_arguments(2); - - for (size_t i = 0; i < 2; ++i) - new_arguments[i].type = arguments[i]; - - /// Number argument must be second. - if (isNumber(new_arguments[0].type)) - std::swap(new_arguments[0], new_arguments[1]); - - auto function = function_builder->build(new_arguments); - return function->getResultType(); - } - - DataTypePtr type_res; - - const bool valid = castBothTypes(arguments[0].get(), arguments[1].get(), [&](const auto & left, const auto & right) - { - using LeftDataType = std::decay_t<decltype(left)>; - using RightDataType = std::decay_t<decltype(right)>; - - if constexpr ((std::is_same_v<DataTypeFixedString, LeftDataType> || std::is_same_v<DataTypeString, LeftDataType>) || - (std::is_same_v<DataTypeFixedString, RightDataType> || std::is_same_v<DataTypeString, RightDataType>)) - { - if constexpr (std::is_same_v<DataTypeFixedString, LeftDataType> && - std::is_same_v<DataTypeFixedString, RightDataType>) - { - if constexpr (!Op<DataTypeFixedString, DataTypeFixedString>::allow_fixed_string) - return false; - else - { - if (left.getN() == right.getN()) - { - if constexpr (is_bit_hamming_distance) - type_res = std::make_shared<DataTypeUInt16>(); - else - type_res = std::make_shared<LeftDataType>(left.getN()); - return true; - } - } - } - - if constexpr ( - is_bit_hamming_distance - && std::is_same_v<DataTypeString, LeftDataType> && std::is_same_v<DataTypeString, RightDataType>) - type_res = std::make_shared<DataTypeUInt64>(); - else if constexpr (!Op<LeftDataType, RightDataType>::allow_string_integer) - return false; - else if constexpr (!IsIntegral<RightDataType>) - return false; - else if constexpr (std::is_same_v<DataTypeFixedString, LeftDataType>) - type_res = std::make_shared<LeftDataType>(left.getN()); - else - type_res = std::make_shared<DataTypeString>(); - return true; - } - else if constexpr (std::is_same_v<LeftDataType, DataTypeInterval> || std::is_same_v<RightDataType, DataTypeInterval>) - { - if constexpr (std::is_same_v<LeftDataType, DataTypeInterval> && - std::is_same_v<RightDataType, DataTypeInterval>) - { - if constexpr (is_plus || is_minus) - { - if (left.getKind() == right.getKind()) - { - type_res = std::make_shared<LeftDataType>(left.getKind()); - return true; - } - } - } - } - else - { - using ResultDataType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::ResultDataType; - - if constexpr (!std::is_same_v<ResultDataType, InvalidType>) - { - if constexpr (IsDataTypeDecimal<LeftDataType> && IsDataTypeDecimal<RightDataType>) - { - if constexpr (is_division) - { - if (context->getSettingsRef().decimal_check_overflow) - { - /// Check overflow by using operands scale (based on big decimal division implementation details): - /// big decimal arithmetic is based on big integers, decimal operands are converted to big integers - /// i.e. int_operand = decimal_operand*10^scale - /// For division, left operand will be scaled by right operand scale also to do big integer division, - /// BigInt result = left*10^(left_scale + right_scale) / right * 10^right_scale - /// So, we can check upfront possible overflow just by checking max scale used for left operand - /// Note: it doesn't detect all possible overflow during big decimal division - if (left.getScale() + right.getScale() > ResultDataType::maxPrecision()) - throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Overflow during decimal division"); - } - } - ResultDataType result_type = decimalResultType<is_multiply, is_division>(left, right); - type_res = std::make_shared<ResultDataType>(result_type.getPrecision(), result_type.getScale()); - } - else if constexpr ((IsDataTypeDecimal<LeftDataType> && IsFloatingPoint<RightDataType>) || - (IsDataTypeDecimal<RightDataType> && IsFloatingPoint<LeftDataType>)) - type_res = std::make_shared<DataTypeFloat64>(); - else if constexpr (IsDataTypeDecimal<LeftDataType>) - type_res = std::make_shared<LeftDataType>(left.getPrecision(), left.getScale()); - else if constexpr (IsDataTypeDecimal<RightDataType>) - type_res = std::make_shared<RightDataType>(right.getPrecision(), right.getScale()); - else if constexpr (std::is_same_v<ResultDataType, DataTypeDateTime>) - { - // Special case for DateTime: binary OPS should reuse timezone - // of DateTime argument as timezeone of result type. - // NOTE: binary plus/minus are not allowed on DateTime64, and we are not handling it here. - - const TimezoneMixin * tz = nullptr; - if constexpr (std::is_same_v<RightDataType, DataTypeDateTime>) - tz = &right; - if constexpr (std::is_same_v<LeftDataType, DataTypeDateTime>) - tz = &left; - type_res = std::make_shared<ResultDataType>(*tz); - } - else - type_res = std::make_shared<ResultDataType>(); - return true; - } - } - return false; - }); - - if (valid) - return type_res; - - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal types {} and {} of arguments of function {}", - arguments[0]->getName(), arguments[1]->getName(), String(name)); - } - - ColumnPtr executeFixedString(const ColumnsWithTypeAndName & arguments) const - { - using OpImpl = FixedStringOperationImpl<Op<UInt8, UInt8>>; - using OpReduceImpl = FixedStringReduceOperationImpl<Op<UInt8, UInt8>>; - - const auto * const col_left_raw = arguments[0].column.get(); - const auto * const col_right_raw = arguments[1].column.get(); - - if (const auto * col_left_const = checkAndGetColumnConst<ColumnFixedString>(col_left_raw)) - { - if (const auto * col_right_const = checkAndGetColumnConst<ColumnFixedString>(col_right_raw)) - { - const auto * col_left = checkAndGetColumn<ColumnFixedString>(col_left_const->getDataColumn()); - const auto * col_right = checkAndGetColumn<ColumnFixedString>(col_right_const->getDataColumn()); - - if (col_left->getN() != col_right->getN()) - return nullptr; - - if constexpr (is_bit_hamming_distance) - { - auto col_res = ColumnUInt16::create(); - auto & data = col_res->getData(); - data.resize_fill(col_left->size()); - - OpReduceImpl::template process<OpCase::Vector>( - col_left->getChars().data(), col_right->getChars().data(), data.data(), data.size(), col_left->getN()); - - return ColumnConst::create(std::move(col_res), col_left_raw->size()); - } - else - { - auto col_res = ColumnFixedString::create(col_left->getN()); - auto & out_chars = col_res->getChars(); - - out_chars.resize(col_left->getN()); - - OpImpl::template process<OpCase::Vector>( - col_left->getChars().data(), col_right->getChars().data(), out_chars.data(), out_chars.size(), {}); - - return ColumnConst::create(std::move(col_res), col_left_raw->size()); - } - - } - } - - const bool is_left_column_const = checkAndGetColumnConst<ColumnFixedString>(col_left_raw) != nullptr; - const bool is_right_column_const = checkAndGetColumnConst<ColumnFixedString>(col_right_raw) != nullptr; - - const auto * col_left = is_left_column_const - ? checkAndGetColumn<ColumnFixedString>( - checkAndGetColumnConst<ColumnFixedString>(col_left_raw)->getDataColumn()) - : checkAndGetColumn<ColumnFixedString>(col_left_raw); - const auto * col_right = is_right_column_const - ? checkAndGetColumn<ColumnFixedString>( - checkAndGetColumnConst<ColumnFixedString>(col_right_raw)->getDataColumn()) - : checkAndGetColumn<ColumnFixedString>(col_right_raw); - - if (col_left && col_right) - { - if (col_left->getN() != col_right->getN()) - return nullptr; - - if constexpr (is_bit_hamming_distance) - { - auto col_res = ColumnUInt16::create(); - auto & data = col_res->getData(); - data.resize_fill(is_right_column_const ? col_left->size() : col_right->size()); - - if (!is_left_column_const && !is_right_column_const) - { - OpReduceImpl::template process<OpCase::Vector>( - col_left->getChars().data(), col_right->getChars().data(), data.data(), data.size(), col_left->getN()); - } - else if (is_left_column_const) - { - OpReduceImpl::template process<OpCase::LeftConstant>( - col_left->getChars().data(), col_right->getChars().data(), data.data(), data.size(), col_left->getN()); - } - else - { - OpReduceImpl::template process<OpCase::RightConstant>( - col_left->getChars().data(), col_right->getChars().data(), data.data(), data.size(), col_left->getN()); - } - - return col_res; - } - else - { - auto col_res = ColumnFixedString::create(col_left->getN()); - auto & out_chars = col_res->getChars(); - out_chars.resize((is_right_column_const ? col_left->size() : col_right->size()) * col_left->getN()); - - if (!is_left_column_const && !is_right_column_const) - { - OpImpl::template process<OpCase::Vector>( - col_left->getChars().data(), col_right->getChars().data(), out_chars.data(), out_chars.size(), {}); - } - else if (is_left_column_const) - { - OpImpl::template process<OpCase::LeftConstant>( - col_left->getChars().data(), col_right->getChars().data(), out_chars.data(), out_chars.size(), col_left->getN()); - } - else - { - OpImpl::template process<OpCase::RightConstant>( - col_left->getChars().data(), col_right->getChars().data(), out_chars.data(), out_chars.size(), col_left->getN()); - } - - return col_res; - } - } - return nullptr; - } - - /// Only used for bitHammingDistance - ColumnPtr executeString(const ColumnsWithTypeAndName & arguments) const - { - using OpImpl = StringReduceOperationImpl<Op<UInt8, UInt8>>; - - const auto * const col_left_raw = arguments[0].column.get(); - const auto * const col_right_raw = arguments[1].column.get(); - - if (const auto * col_left_const = checkAndGetColumnConst<ColumnString>(col_left_raw)) - { - if (const auto * col_right_const = checkAndGetColumnConst<ColumnString>(col_right_raw)) - { - const auto * col_left = checkAndGetColumn<ColumnString>(col_left_const->getDataColumn()); - const auto * col_right = checkAndGetColumn<ColumnString>(col_right_const->getDataColumn()); - - std::string_view a = col_left->getDataAt(0).toView(); - std::string_view b = col_right->getDataAt(0).toView(); - - auto res = OpImpl::constConst(a, b); - - return DataTypeUInt64{}.createColumnConst(1, res); - } - } - - const bool is_left_column_const = checkAndGetColumnConst<ColumnString>(col_left_raw) != nullptr; - const bool is_right_column_const = checkAndGetColumnConst<ColumnString>(col_right_raw) != nullptr; - - const auto * col_left = is_left_column_const - ? checkAndGetColumn<ColumnString>(checkAndGetColumnConst<ColumnString>(col_left_raw)->getDataColumn()) - : checkAndGetColumn<ColumnString>(col_left_raw); - const auto * col_right = is_right_column_const - ? checkAndGetColumn<ColumnString>(checkAndGetColumnConst<ColumnString>(col_right_raw)->getDataColumn()) - : checkAndGetColumn<ColumnString>(col_right_raw); - - if (col_left && col_right) - { - auto col_res = ColumnUInt64::create(); - auto & data = col_res->getData(); - data.resize(is_right_column_const ? col_left->size() : col_right->size()); - - if (!is_left_column_const && !is_right_column_const) - { - OpImpl::vectorVector( - col_left->getChars(), col_left->getOffsets(), col_right->getChars(), col_right->getOffsets(), data); - } - else if (is_left_column_const) - { - std::string_view str_view = col_left->getDataAt(0).toView(); - OpImpl::vectorConstant(col_right->getChars(), col_right->getOffsets(), str_view, data); - } - else - { - std::string_view str_view = col_right->getDataAt(0).toView(); - OpImpl::vectorConstant(col_left->getChars(), col_left->getOffsets(), str_view, data); - } - - return col_res; - } - return nullptr; - } - -template <typename LeftColumnType, typename A, typename B> -ColumnPtr executeStringInteger(const ColumnsWithTypeAndName & arguments, const A & left, const B & right) const - { - using LeftDataType = std::decay_t<decltype(left)>; - using RightDataType = std::decay_t<decltype(right)>; - - const auto * const col_left_raw = arguments[0].column.get(); - const auto * const col_right_raw = arguments[1].column.get(); - using T1 = typename RightDataType::FieldType; - - using ColVecT1 = ColumnVector<T1>; - const ColVecT1 * const col_right = checkAndGetColumn<ColVecT1>(col_right_raw); - const ColumnConst * const col_right_const = checkAndGetColumnConst<ColVecT1>(col_right_raw); - - using OpImpl = StringIntegerOperationImpl<T1, Op<LeftDataType, T1>>; - - const ColumnConst * const col_left_const = checkAndGetColumnConst<LeftColumnType>(col_left_raw); - - const auto * col_left = col_left_const ? checkAndGetColumn<LeftColumnType>(col_left_const->getDataColumn()) - : checkAndGetColumn<LeftColumnType>(col_left_raw); - - if (!col_left) - return nullptr; - - const typename LeftColumnType::Chars & in_vec = col_left->getChars(); - - typename LeftColumnType::MutablePtr col_res; - if constexpr (std::is_same_v<LeftDataType, DataTypeFixedString>) - col_res = LeftColumnType::create(col_left->getN()); - else - col_res = LeftColumnType::create(); - - typename LeftColumnType::Chars & out_vec = col_res->getChars(); - - if (col_left_const && col_right_const) - { - const T1 value = col_right_const->template getValue<T1>(); - if constexpr (std::is_same_v<LeftDataType, DataTypeFixedString>) - { - OpImpl::template processFixedString<OpCase::Vector>(in_vec.data(), col_left->getN(), &value, out_vec, 1); - } - else - { - ColumnString::Offsets & out_offsets = col_res->getOffsets(); - OpImpl::template processString<OpCase::Vector>(in_vec.data(), col_left->getOffsets().data(), &value, out_vec, out_offsets, 1); - } - - return ColumnConst::create(std::move(col_res), col_left_const->size()); - } - else if (!col_left_const && !col_right_const && col_right) - { - if constexpr (std::is_same_v<LeftDataType, DataTypeFixedString>) - { - OpImpl::template processFixedString<OpCase::Vector>(in_vec.data(), col_left->getN(), col_right->getData().data(), out_vec, col_left->size()); - } - else - { - ColumnString::Offsets & out_offsets = col_res->getOffsets(); - out_offsets.reserve(col_left->size()); - OpImpl::template processString<OpCase::Vector>( - in_vec.data(), col_left->getOffsets().data(), col_right->getData().data(), out_vec, out_offsets, col_left->size()); - } - } - else if (col_left_const && col_right) - { - if constexpr (std::is_same_v<LeftDataType, DataTypeFixedString>) - { - OpImpl::template processFixedString<OpCase::LeftConstant>( - in_vec.data(), col_left->getN(), col_right->getData().data(), out_vec, col_right->size()); - } - else - { - ColumnString::Offsets & out_offsets = col_res->getOffsets(); - out_offsets.reserve(col_right->size()); - OpImpl::template processString<OpCase::LeftConstant>( - in_vec.data(), col_left->getOffsets().data(), col_right->getData().data(), out_vec, out_offsets, col_right->size()); - } - } - else if (col_right_const) - { - const T1 value = col_right_const->template getValue<T1>(); - if constexpr (std::is_same_v<LeftDataType, DataTypeFixedString>) - { - OpImpl::template processFixedString<OpCase::RightConstant>(in_vec.data(), col_left->getN(), &value, out_vec, col_left->size()); - } - else - { - ColumnString::Offsets & out_offsets = col_res->getOffsets(); - out_offsets.reserve(col_left->size()); - OpImpl::template processString<OpCase::RightConstant>( - in_vec.data(), col_left->getOffsets().data(), &value, out_vec, out_offsets, col_left->size()); - } - } - else - return nullptr; - - return col_res; - } - - template <typename A, typename B> - ColumnPtr executeNumeric(const ColumnsWithTypeAndName & arguments, const A & left, const B & right, const NullMap * right_nullmap) const - { - using LeftDataType = std::decay_t<decltype(left)>; - using RightDataType = std::decay_t<decltype(right)>; - using ResultDataType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::ResultDataType; - - if constexpr (std::is_same_v<ResultDataType, InvalidType>) - return nullptr; - else // we can't avoid the else because otherwise the compiler may assume the ResultDataType may be Invalid - // and that would produce the compile error. - { - constexpr bool decimal_with_float = (IsDataTypeDecimal<LeftDataType> && IsFloatingPoint<RightDataType>) - || (IsFloatingPoint<LeftDataType> && IsDataTypeDecimal<RightDataType>); - - using T0 = std::conditional_t<decimal_with_float, Float64, typename LeftDataType::FieldType>; - using T1 = std::conditional_t<decimal_with_float, Float64, typename RightDataType::FieldType>; - using ResultType = typename ResultDataType::FieldType; - using ColVecT0 = ColumnVectorOrDecimal<T0>; - using ColVecT1 = ColumnVectorOrDecimal<T1>; - using ColVecResult = ColumnVectorOrDecimal<ResultType>; - - ColumnPtr left_col = nullptr; - ColumnPtr right_col = nullptr; - - /// When Decimal op Float32/64, convert both of them into Float64 - if constexpr (decimal_with_float) - { - const auto converted_type = std::make_shared<DataTypeFloat64>(); - left_col = castColumn(arguments[0], converted_type); - right_col = castColumn(arguments[1], converted_type); - } - else - { - left_col = arguments[0].column; - right_col = arguments[1].column; - } - const auto * const col_left_raw = left_col.get(); - const auto * const col_right_raw = right_col.get(); - - const size_t col_left_size = col_left_raw->size(); - - const ColumnConst * const col_left_const = checkAndGetColumnConst<ColVecT0>(col_left_raw); - const ColumnConst * const col_right_const = checkAndGetColumnConst<ColVecT1>(col_right_raw); - - const ColVecT0 * const col_left = checkAndGetColumn<ColVecT0>(col_left_raw); - const ColVecT1 * const col_right = checkAndGetColumn<ColVecT1>(col_right_raw); - - if constexpr (IsDataTypeDecimal<ResultDataType>) - { - return executeNumericWithDecimal<LeftDataType, RightDataType, ResultDataType>( - left, right, - col_left_const, col_right_const, - col_left, col_right, - col_left_size, - right_nullmap); - } - else // can't avoid else and another indentation level, otherwise the compiler would try to instantiate - // ColVecResult for Decimals which would lead to a compile error. - { - using OpImpl = BinaryOperationImpl<T0, T1, Op<T0, T1>, ResultType>; - - /// non-vector result - if (col_left_const && col_right_const) - { - const auto res = right_nullmap && (*right_nullmap)[0] ? ResultType() : OpImpl::process( - col_left_const->template getValue<T0>(), - col_right_const->template getValue<T1>()); - - return ResultDataType().createColumnConst(col_left_const->size(), toField(res)); - } - - typename ColVecResult::MutablePtr col_res = ColVecResult::create(); - - auto & vec_res = col_res->getData(); - vec_res.resize(col_left_size); - - if (col_left && col_right) - { - OpImpl::template process<OpCase::Vector>( - col_left->getData().data(), - col_right->getData().data(), - vec_res.data(), - vec_res.size(), - right_nullmap); - } - else if (col_left_const && col_right) - { - const T0 value = col_left_const->template getValue<T0>(); - - OpImpl::template process<OpCase::LeftConstant>( - &value, - col_right->getData().data(), - vec_res.data(), - vec_res.size(), - right_nullmap); - } - else if (col_left && col_right_const) - { - const T1 value = col_right_const->template getValue<T1>(); - - OpImpl::template process<OpCase::RightConstant>( - col_left->getData().data(), &value, vec_res.data(), vec_res.size(), right_nullmap); - } - else - return nullptr; - - return col_res; - } - } - } - - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override - { - /// Special case when multiply aggregate function state - if (isAggregateMultiply(arguments[0].type, arguments[1].type)) - { - return executeAggregateMultiply(arguments, result_type, input_rows_count); - } - - /// Special case - addition of two aggregate functions states - if (isAggregateAddition(arguments[0].type, arguments[1].type)) - { - return executeAggregateAddition(arguments, result_type, input_rows_count); - } - - /// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval. - if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0].type, arguments[1].type, context)) - { - return executeDateTimeIntervalPlusMinus(arguments, result_type, input_rows_count, function_builder); - } - - /// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Tuple. - if (auto function_builder = getFunctionForDateTupleOfIntervalsArithmetic(arguments[0].type, arguments[1].type, context)) - { - return executeDateTimeTupleOfIntervalsPlusMinus(arguments, result_type, input_rows_count, function_builder); - } - - /// Special case when the function is plus or minus, one of arguments is Interval/Tuple of Intervals and another is Interval. - if (auto function_builder = getFunctionForMergeIntervalsArithmetic(arguments[0].type, arguments[1].type, context)) - { - return executeIntervalTupleOfIntervalsPlusMinus(arguments, result_type, input_rows_count, function_builder); - } - - /// Special case when the function is plus, minus or multiply, both arguments are tuples. - if (auto function_builder = getFunctionForTupleArithmetic(arguments[0].type, arguments[1].type, context)) - { - return function_builder->build(arguments)->execute(arguments, result_type, input_rows_count); - } - - /// Special case when the function is multiply or divide, one of arguments is Tuple and another is Number. - if (auto function_builder = getFunctionForTupleAndNumberArithmetic(arguments[0].type, arguments[1].type, context)) - { - return executeTupleNumberOperator(arguments, result_type, input_rows_count, function_builder); - } - - return executeImpl2(arguments, result_type, input_rows_count); - } - - ColumnPtr executeImpl2(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, const NullMap * right_nullmap = nullptr) const - { - const auto & left_argument = arguments[0]; - const auto & right_argument = arguments[1]; - - /// Process special case when operation is divide, intDiv or modulo and denominator - /// is Nullable(Something) to prevent division by zero error. - if (division_by_nullable && !right_nullmap) - { - assert(right_argument.type->isNullable()); - - bool is_const = checkColumnConst<ColumnNullable>(right_argument.column.get()); - const ColumnNullable * nullable_column = is_const ? checkAndGetColumnConstData<ColumnNullable>(right_argument.column.get()) - : checkAndGetColumn<ColumnNullable>(*right_argument.column); - - const auto & null_bytemap = nullable_column->getNullMapData(); - auto res = executeImpl2(createBlockWithNestedColumns(arguments), removeNullable(result_type), input_rows_count, &null_bytemap); - return wrapInNullable(res, arguments, result_type, input_rows_count); - } - - /// Special case - one or both arguments are IPv4 - if (isIPv4(arguments[0].type) || isIPv4(arguments[1].type)) - { - ColumnsWithTypeAndName new_arguments { - { - isIPv4(arguments[0].type) ? castColumn(arguments[0], std::make_shared<DataTypeUInt32>()) : arguments[0].column, - isIPv4(arguments[0].type) ? std::make_shared<DataTypeUInt32>() : arguments[0].type, - arguments[0].name, - }, - { - isIPv4(arguments[1].type) ? castColumn(arguments[1], std::make_shared<DataTypeUInt32>()) : arguments[1].column, - isIPv4(arguments[1].type) ? std::make_shared<DataTypeUInt32>() : arguments[1].type, - arguments[1].name - } - }; - - return executeImpl2(new_arguments, result_type, input_rows_count, right_nullmap); - } - - const auto * const left_generic = left_argument.type.get(); - const auto * const right_generic = right_argument.type.get(); - ColumnPtr res; - - const bool valid = castBothTypes(left_generic, right_generic, [&](const auto & left, const auto & right) - { - using LeftDataType = std::decay_t<decltype(left)>; - using RightDataType = std::decay_t<decltype(right)>; - - if constexpr ((std::is_same_v<DataTypeFixedString, LeftDataType> || std::is_same_v<DataTypeString, LeftDataType>) || - (std::is_same_v<DataTypeFixedString, RightDataType> || std::is_same_v<DataTypeString, RightDataType>)) - { - if constexpr (std::is_same_v<DataTypeFixedString, LeftDataType> && - std::is_same_v<DataTypeFixedString, RightDataType>) - { - if constexpr (!Op<DataTypeFixedString, DataTypeFixedString>::allow_fixed_string) - return false; - else - return (res = executeFixedString(arguments)) != nullptr; - } - - if constexpr ( - is_bit_hamming_distance - && std::is_same_v<DataTypeString, LeftDataType> && std::is_same_v<DataTypeString, RightDataType>) - return (res = executeString(arguments)) != nullptr; - else if constexpr (!Op<LeftDataType, RightDataType>::allow_string_integer) - return false; - else if constexpr (!IsIntegral<RightDataType>) - return false; - else if constexpr (std::is_same_v<DataTypeFixedString, LeftDataType>) - { - return (res = executeStringInteger<ColumnFixedString>(arguments, left, right)) != nullptr; - } - else if constexpr (std::is_same_v<DataTypeString, LeftDataType>) - return (res = executeStringInteger<ColumnString>(arguments, left, right)) != nullptr; - } - else - return (res = executeNumeric(arguments, left, right, right_nullmap)) != nullptr; - }); - - if (isArray(result_type)) - return executeArrayImpl(arguments, result_type, input_rows_count); - - if (!valid) - { - // This is a logical error, because the types should have been checked - // by getReturnTypeImpl(). - throw Exception(ErrorCodes::LOGICAL_ERROR, - "Arguments of '{}' have incorrect data types: '{}' of type '{}'," - " '{}' of type '{}'", getName(), - left_argument.name, left_argument.type->getName(), - right_argument.name, right_argument.type->getName()); - } - - return res; - } - -#if USE_EMBEDDED_COMPILER - bool isCompilableImpl(const DataTypes & arguments, const DataTypePtr & result_type) const override - { - if (2 != arguments.size()) - return false; - - if (!canBeNativeType(*arguments[0]) || !canBeNativeType(*arguments[1]) || !canBeNativeType(*result_type)) - return false; - - WhichDataType data_type_lhs(arguments[0]); - WhichDataType data_type_rhs(arguments[1]); - if ((data_type_lhs.isDateOrDate32() || data_type_lhs.isDateTime()) || - (data_type_rhs.isDateOrDate32() || data_type_rhs.isDateTime())) - return false; - - return castBothTypes(arguments[0].get(), arguments[1].get(), [&](const auto & left, const auto & right) - { - using LeftDataType = std::decay_t<decltype(left)>; - using RightDataType = std::decay_t<decltype(right)>; - if constexpr (!std::is_same_v<DataTypeFixedString, LeftDataType> && - !std::is_same_v<DataTypeFixedString, RightDataType> && - !std::is_same_v<DataTypeString, LeftDataType> && - !std::is_same_v<DataTypeString, RightDataType>) - { - using ResultDataType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::ResultDataType; - using OpSpec = Op<typename LeftDataType::FieldType, typename RightDataType::FieldType>; - if constexpr (!std::is_same_v<ResultDataType, InvalidType> && !IsDataTypeDecimal<ResultDataType> && OpSpec::compilable) - return true; - } - return false; - }); - } - - llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const ValuesWithType & arguments, const DataTypePtr & result_type) const override - { - assert(2 == arguments.size()); - - llvm::Value * result = nullptr; - castBothTypes(arguments[0].type.get(), arguments[1].type.get(), [&](const auto & left, const auto & right) - { - using LeftDataType = std::decay_t<decltype(left)>; - using RightDataType = std::decay_t<decltype(right)>; - if constexpr (!std::is_same_v<DataTypeFixedString, LeftDataType> && - !std::is_same_v<DataTypeFixedString, RightDataType> && - !std::is_same_v<DataTypeString, LeftDataType> && - !std::is_same_v<DataTypeString, RightDataType>) - { - using ResultDataType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::ResultDataType; - using OpSpec = Op<typename LeftDataType::FieldType, typename RightDataType::FieldType>; - if constexpr (!std::is_same_v<ResultDataType, InvalidType> && !IsDataTypeDecimal<ResultDataType> && OpSpec::compilable) - { - auto & b = static_cast<llvm::IRBuilder<> &>(builder); - auto * lval = nativeCast(b, arguments[0], result_type); - auto * rval = nativeCast(b, arguments[1], result_type); - result = OpSpec::compile(b, lval, rval, std::is_signed_v<typename ResultDataType::FieldType>); - - return true; - } - } - - return false; - }); - - return result; - } -#endif - - bool canBeExecutedOnDefaultArguments() const override { return valid_on_default_arguments; } -}; - - -template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true, bool valid_on_float_arguments = true, bool division_by_nullable = false> -class FunctionBinaryArithmeticWithConstants : public FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments, division_by_nullable> -{ -public: - using Base = FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments, division_by_nullable>; - using Monotonicity = typename Base::Monotonicity; - - static FunctionPtr create( - const ColumnWithTypeAndName & left_, - const ColumnWithTypeAndName & right_, - const DataTypePtr & return_type_, - ContextPtr context) - { - return std::make_shared<FunctionBinaryArithmeticWithConstants>(left_, right_, return_type_, context); - } - - FunctionBinaryArithmeticWithConstants( - const ColumnWithTypeAndName & left_, - const ColumnWithTypeAndName & right_, - const DataTypePtr & return_type_, - ContextPtr context_) - : Base(context_), left(left_), right(right_), return_type(return_type_) - { - } - - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override - { - if (left.column && isColumnConst(*left.column) && arguments.size() == 1) - { - ColumnsWithTypeAndName columns_with_constant - = {{left.column->cloneResized(input_rows_count), left.type, left.name}, - arguments[0]}; - - return Base::executeImpl(columns_with_constant, result_type, input_rows_count); - } - else if (right.column && isColumnConst(*right.column) && arguments.size() == 1) - { - ColumnsWithTypeAndName columns_with_constant - = {arguments[0], - {right.column->cloneResized(input_rows_count), right.type, right.name}}; - - return Base::executeImpl(columns_with_constant, result_type, input_rows_count); - } - else - return Base::executeImpl(arguments, result_type, input_rows_count); - } - - bool hasInformationAboutMonotonicity() const override - { - const std::string_view name_view = Name::name; - return (name_view == "minus" || name_view == "plus" || name_view == "divide" || name_view == "intDiv"); - } - - Monotonicity getMonotonicityForRange(const IDataType &, const Field & left_point, const Field & right_point) const override - { - const std::string_view name_view = Name::name; - - // For simplicity, we treat null values as monotonicity breakers, except for variable / non-zero constant. - if (left_point.isNull() || right_point.isNull()) - { - if (name_view == "divide" || name_view == "intDiv") - { - // variable / constant - if (right.column && isColumnConst(*right.column)) - { - auto constant = (*right.column)[0]; - if (applyVisitor(FieldVisitorAccurateEquals(), constant, Field(0))) - return {false, true, false}; // variable / 0 is undefined, let's treat it as non-monotonic - bool is_constant_positive = applyVisitor(FieldVisitorAccurateLess(), Field(0), constant); - - // division is saturated to `inf`, thus it doesn't have overflow issues. - return {true, is_constant_positive, true}; - } - } - return {false, true, false, false}; - } - - // For simplicity, we treat every single value interval as positive monotonic. - if (applyVisitor(FieldVisitorAccurateEquals(), left_point, right_point)) - return {true, true, false, false}; - - if (name_view == "minus" || name_view == "plus") - { - // const +|- variable - if (left.column && isColumnConst(*left.column)) - { - auto left_type = removeNullable(removeLowCardinality(left.type)); - auto right_type = removeNullable(removeLowCardinality(right.type)); - auto ret_type = removeNullable(removeLowCardinality(return_type)); - - auto transform = [&](const Field & point) - { - ColumnsWithTypeAndName columns_with_constant - = {{left_type->createColumnConst(1, (*left.column)[0]), left_type, left.name}, - {right_type->createColumnConst(1, point), right_type, right.name}}; - - /// This is a bit dangerous to call Base::executeImpl cause it ignores `use Default Implementation For XXX` flags. - /// It was possible to check monotonicity for nullable right type which result to exception. - /// Adding removeNullable above fixes the issue, but some other inconsistency may left. - auto col = Base::executeImpl(columns_with_constant, ret_type, 1); - Field point_transformed; - col->get(0, point_transformed); - return point_transformed; - }; - - bool is_positive_monotonicity = applyVisitor(FieldVisitorAccurateLess(), left_point, right_point) - == applyVisitor(FieldVisitorAccurateLess(), transform(left_point), transform(right_point)); - - if (name_view == "plus") - { - // Check if there is an overflow - if (is_positive_monotonicity) - return {true, true, false, true}; - else - return {false, true, false, false}; - } - else - { - // Check if there is an overflow - if (!is_positive_monotonicity) - return {true, false, false, true}; - else - return {false, false, false, false}; - } - } - // variable +|- constant - else if (right.column && isColumnConst(*right.column)) - { - auto left_type = removeNullable(removeLowCardinality(left.type)); - auto right_type = removeNullable(removeLowCardinality(right.type)); - auto ret_type = removeNullable(removeLowCardinality(return_type)); - - auto transform = [&](const Field & point) - { - ColumnsWithTypeAndName columns_with_constant - = {{left_type->createColumnConst(1, point), left_type, left.name}, - {right_type->createColumnConst(1, (*right.column)[0]), right_type, right.name}}; - - auto col = Base::executeImpl(columns_with_constant, ret_type, 1); - Field point_transformed; - col->get(0, point_transformed); - return point_transformed; - }; - - // Check if there is an overflow - if (applyVisitor(FieldVisitorAccurateLess(), left_point, right_point) - == applyVisitor(FieldVisitorAccurateLess(), transform(left_point), transform(right_point))) - return {true, true, false, true}; - else - return {false, true, false, false}; - } - } - if (name_view == "divide" || name_view == "intDiv") - { - bool is_strict = name_view == "divide"; - - // const / variable - if (left.column && isColumnConst(*left.column)) - { - auto constant = (*left.column)[0]; - if (applyVisitor(FieldVisitorAccurateEquals(), constant, Field(0))) - return {true, true, false, false}; // 0 / 0 is undefined, thus it's not always monotonic - - bool is_constant_positive = applyVisitor(FieldVisitorAccurateLess(), Field(0), constant); - if (applyVisitor(FieldVisitorAccurateLess(), left_point, Field(0)) - && applyVisitor(FieldVisitorAccurateLess(), right_point, Field(0))) - { - return {true, is_constant_positive, false, is_strict}; - } - else if ( - applyVisitor(FieldVisitorAccurateLess(), Field(0), left_point) - && applyVisitor(FieldVisitorAccurateLess(), Field(0), right_point)) - { - return {true, !is_constant_positive, false, is_strict}; - } - } - // variable / constant - else if (right.column && isColumnConst(*right.column)) - { - auto constant = (*right.column)[0]; - if (applyVisitor(FieldVisitorAccurateEquals(), constant, Field(0))) - return {false, true, false, false}; // variable / 0 is undefined, let's treat it as non-monotonic - - bool is_constant_positive = applyVisitor(FieldVisitorAccurateLess(), Field(0), constant); - // division is saturated to `inf`, thus it doesn't have overflow issues. - return {true, is_constant_positive, true, is_strict}; - } - } - return {false, true, false}; - } - -private: - ColumnWithTypeAndName left; - ColumnWithTypeAndName right; - DataTypePtr return_type; -}; - -template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true, bool valid_on_float_arguments = true> -class BinaryArithmeticOverloadResolver : public IFunctionOverloadResolver -{ -public: - static constexpr auto name = Name::name; - static FunctionOverloadResolverPtr create(ContextPtr context) - { - return std::make_unique<BinaryArithmeticOverloadResolver>(context); - } - - explicit BinaryArithmeticOverloadResolver(ContextPtr context_) : context(context_) {} - - String getName() const override { return name; } - size_t getNumberOfArguments() const override { return 2; } - bool isVariadic() const override { return false; } - - FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override - { - /// Check the case when operation is divide, intDiv or modulo and denominator is Nullable(Something). - /// For divide operation we should check only Nullable(Decimal), because only this case can throw division by zero error. - bool division_by_nullable = !arguments[0].type->onlyNull() && !arguments[1].type->onlyNull() && arguments[1].type->isNullable() - && (IsOperation<Op>::div_int || IsOperation<Op>::modulo || IsOperation<Op>::positive_modulo - || (IsOperation<Op>::div_floating - && (isDecimalOrNullableDecimal(arguments[0].type) || isDecimalOrNullableDecimal(arguments[1].type)))); - - /// More efficient specialization for two numeric arguments. - if (arguments.size() == 2 - && ((arguments[0].column && isColumnConst(*arguments[0].column)) - || (arguments[1].column && isColumnConst(*arguments[1].column)))) - { - auto function = division_by_nullable ? FunctionBinaryArithmeticWithConstants<Op, Name, valid_on_default_arguments, valid_on_float_arguments, true>::create( - arguments[0], arguments[1], return_type, context) - : FunctionBinaryArithmeticWithConstants<Op, Name, valid_on_default_arguments, valid_on_float_arguments, false>::create( - arguments[0], arguments[1], return_type, context); - - return std::make_unique<FunctionToFunctionBaseAdaptor>( - function, - collections::map<DataTypes>(arguments, [](const auto & elem) { return elem.type; }), - return_type); - } - auto function = division_by_nullable - ? FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments, true>::create(context) - : FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments, false>::create(context); - - return std::make_unique<FunctionToFunctionBaseAdaptor>( - function, - collections::map<DataTypes>(arguments, [](const auto & elem) { return elem.type; }), - return_type); - - } - - DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override - { - if (arguments.size() != 2) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, - "Number of arguments for function {} doesn't match: passed {}, should be 2", - getName(), arguments.size()); - return FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments>::getReturnTypeImplStatic(arguments, context); - } - -private: - ContextPtr context; -}; -} |
