diff options
author | freulaeux <freulaeux@yandex-team.ru> | 2022-02-10 16:52:27 +0300 |
---|---|---|
committer | Daniil Cherednik <dcherednik@yandex-team.ru> | 2022-02-10 16:52:27 +0300 |
commit | 97307a7569f3826a54684399895ab653340d3d5d (patch) | |
tree | 0bb7987fbc03d534b88e7da11d512b6d75f5d15f | |
parent | 75bb4f570f7414195473e1290048787dddf72c4b (diff) | |
download | ydb-97307a7569f3826a54684399895ab653340d3d5d.tar.gz |
Restoring authorship annotation for <freulaeux@yandex-team.ru>. Commit 1 of 2.
24 files changed, 2030 insertions, 2030 deletions
diff --git a/ydb/core/formats/arrow_helpers.cpp b/ydb/core/formats/arrow_helpers.cpp index 3e1e1b0444..346e6b9fe6 100644 --- a/ydb/core/formats/arrow_helpers.cpp +++ b/ydb/core/formats/arrow_helpers.cpp @@ -6,11 +6,11 @@ #include <contrib/libs/apache/arrow/cpp/src/arrow/io/memory.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/reader.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_primitive.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/array/builder_primitive.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/type_traits.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_primitive.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/array/builder_primitive.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/type_traits.h> #include <library/cpp/containers/stack_vector/stack_vec.h> -#include <memory> +#include <memory> #define Y_VERIFY_OK(status) Y_VERIFY(status.ok(), "%s", status.ToString().c_str()) @@ -386,8 +386,8 @@ std::shared_ptr<arrow::RecordBatch> ExtractColumns(const std::shared_ptr<arrow:: return arrow::RecordBatch::Make(dstSchema, srcBatch->num_rows(), columns); } - - + + std::shared_ptr<arrow::Table> CombineInTable(const std::vector<std::shared_ptr<arrow::RecordBatch>>& batches) { auto res = arrow::Table::FromRecordBatches(batches); if (!res.ok()) { @@ -1001,41 +1001,41 @@ bool TArrowToYdbConverter::Process(const arrow::RecordBatch& batch, TString& err return true; } -std::shared_ptr<arrow::Array> NumVecToArray(const std::shared_ptr<arrow::DataType>& type, - const std::vector<double>& vec) { - std::shared_ptr<arrow::Array> out; - SwitchType(type->id(), [&](const auto& type) { - using TWrap = std::decay_t<decltype(type)>; - if constexpr (arrow::is_number_type<typename TWrap::T>::value) { - typename arrow::TypeTraits<typename TWrap::T>::BuilderType builder; - for (const auto val : vec) { - Y_VERIFY(builder.Append(static_cast<typename TWrap::T::c_type>(val)).ok()); - } - Y_VERIFY(builder.Finish(&out).ok()); - return true; - } - return false; - }); - return out; -} - -std::shared_ptr<arrow::Array> BoolVecToArray(const std::vector<bool>& vec) { - std::shared_ptr<arrow::Array> out; - arrow::BooleanBuilder builder; - for (const auto val : vec) { - Y_VERIFY(builder.Append(val).ok()); - } - Y_VERIFY(builder.Finish(&out).ok()); - return out; -} - - -bool ArrayScalarsEqual(const std::shared_ptr<arrow::Array>& lhs, const std::shared_ptr<arrow::Array>& rhs) { - bool res = lhs->length() == rhs->length(); - for (int64_t i = 0; i < lhs->length() && res; ++i) { - res &= arrow::ScalarEquals(*lhs->GetScalar(i).ValueOrDie(), *rhs->GetScalar(i).ValueOrDie()); - } - return res; -} - +std::shared_ptr<arrow::Array> NumVecToArray(const std::shared_ptr<arrow::DataType>& type, + const std::vector<double>& vec) { + std::shared_ptr<arrow::Array> out; + SwitchType(type->id(), [&](const auto& type) { + using TWrap = std::decay_t<decltype(type)>; + if constexpr (arrow::is_number_type<typename TWrap::T>::value) { + typename arrow::TypeTraits<typename TWrap::T>::BuilderType builder; + for (const auto val : vec) { + Y_VERIFY(builder.Append(static_cast<typename TWrap::T::c_type>(val)).ok()); + } + Y_VERIFY(builder.Finish(&out).ok()); + return true; + } + return false; + }); + return out; +} + +std::shared_ptr<arrow::Array> BoolVecToArray(const std::vector<bool>& vec) { + std::shared_ptr<arrow::Array> out; + arrow::BooleanBuilder builder; + for (const auto val : vec) { + Y_VERIFY(builder.Append(val).ok()); + } + Y_VERIFY(builder.Finish(&out).ok()); + return out; +} + + +bool ArrayScalarsEqual(const std::shared_ptr<arrow::Array>& lhs, const std::shared_ptr<arrow::Array>& rhs) { + bool res = lhs->length() == rhs->length(); + for (int64_t i = 0; i < lhs->length() && res; ++i) { + res &= arrow::ScalarEquals(*lhs->GetScalar(i).ValueOrDie(), *rhs->GetScalar(i).ValueOrDie()); + } + return res; +} + } diff --git a/ydb/core/formats/arrow_helpers.h b/ydb/core/formats/arrow_helpers.h index cd3ec9f865..534da6a646 100644 --- a/ydb/core/formats/arrow_helpers.h +++ b/ydb/core/formats/arrow_helpers.h @@ -202,9 +202,9 @@ inline bool HasNulls(const std::shared_ptr<arrow::Array>& column) { return column->null_bitmap_data(); } -bool ArrayScalarsEqual(const std::shared_ptr<arrow::Array>& lhs, const std::shared_ptr<arrow::Array>& rhs); -std::shared_ptr<arrow::Array> NumVecToArray(const std::shared_ptr<arrow::DataType>& type, - const std::vector<double>& vec); -std::shared_ptr<arrow::Array> BoolVecToArray(const std::vector<bool>& vec); +bool ArrayScalarsEqual(const std::shared_ptr<arrow::Array>& lhs, const std::shared_ptr<arrow::Array>& rhs); +std::shared_ptr<arrow::Array> NumVecToArray(const std::shared_ptr<arrow::DataType>& type, + const std::vector<double>& vec); +std::shared_ptr<arrow::Array> BoolVecToArray(const std::vector<bool>& vec); } diff --git a/ydb/core/formats/bit_cast.h b/ydb/core/formats/bit_cast.h index 486645877d..b4d66de9b8 100644 --- a/ydb/core/formats/bit_cast.h +++ b/ydb/core/formats/bit_cast.h @@ -1,27 +1,27 @@ -#pragma once - -#include <string.h> -#include <algorithm> -#include <type_traits> - - -/** \brief Returns value `from` converted to type `To` while retaining bit representation. - * `To` and `From` must satisfy `CopyConstructible`. - */ -template <typename To, typename From> -std::decay_t<To> bit_cast(const From & from) -{ - To res {}; - memcpy(static_cast<void*>(&res), &from, std::min(sizeof(res), sizeof(from))); - return res; -} - -/** \brief Returns value `from` converted to type `To` while retaining bit representation. - * `To` and `From` must satisfy `CopyConstructible`. - */ -template <typename To, typename From> -std::decay_t<To> safe_bit_cast(const From & from) -{ - static_assert(sizeof(To) == sizeof(From), "bit cast on types of different width"); - return bit_cast<To, From>(from); -}
\ No newline at end of file +#pragma once + +#include <string.h> +#include <algorithm> +#include <type_traits> + + +/** \brief Returns value `from` converted to type `To` while retaining bit representation. + * `To` and `From` must satisfy `CopyConstructible`. + */ +template <typename To, typename From> +std::decay_t<To> bit_cast(const From & from) +{ + To res {}; + memcpy(static_cast<void*>(&res), &from, std::min(sizeof(res), sizeof(from))); + return res; +} + +/** \brief Returns value `from` converted to type `To` while retaining bit representation. + * `To` and `From` must satisfy `CopyConstructible`. + */ +template <typename To, typename From> +std::decay_t<To> safe_bit_cast(const From & from) +{ + static_assert(sizeof(To) == sizeof(From), "bit cast on types of different width"); + return bit_cast<To, From>(from); +}
\ No newline at end of file diff --git a/ydb/core/formats/clickhouse_type_traits.h b/ydb/core/formats/clickhouse_type_traits.h index b694d2b91f..3b4a079114 100644 --- a/ydb/core/formats/clickhouse_type_traits.h +++ b/ydb/core/formats/clickhouse_type_traits.h @@ -1,101 +1,101 @@ -#pragma once -#include <type_traits> +#pragma once +#include <type_traits> #include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/type.h> - + namespace NKikimr::NArrow { - -constexpr size_t NextSize(size_t size) { - if (size < 8) { - return size * 2; - } - return size; -} - -struct TError {}; - -template <bool is_signed, bool is_floating, size_t size> -struct TConstruct { - using Type = TError; -}; - -template <> struct TConstruct<false, false, 1> { using Type = arrow::UInt8Type; }; -template <> struct TConstruct<false, false, 2> { using Type = arrow::UInt16Type; }; -template <> struct TConstruct<false, false, 4> { using Type = arrow::UInt32Type; }; -template <> struct TConstruct<false, false, 8> { using Type = arrow::UInt64Type; }; -template <> struct TConstruct<false, true, 1> { using Type = arrow::FloatType; }; -template <> struct TConstruct<false, true, 2> { using Type = arrow::FloatType; }; -template <> struct TConstruct<false, true, 4> { using Type = arrow::FloatType; }; -template <> struct TConstruct<false, true, 8> { using Type = arrow::DoubleType; }; -template <> struct TConstruct<true, false, 1> { using Type = arrow::Int8Type; }; -template <> struct TConstruct<true, false, 2> { using Type = arrow::Int16Type; }; -template <> struct TConstruct<true, false, 4> { using Type = arrow::Int32Type; }; -template <> struct TConstruct<true, false, 8> { using Type = arrow::Int64Type; }; -template <> struct TConstruct<true, true, 1> { using Type = arrow::FloatType; }; -template <> struct TConstruct<true, true, 2> { using Type = arrow::FloatType; }; -template <> struct TConstruct<true, true, 4> { using Type = arrow::FloatType; }; -template <> struct TConstruct<true, true, 8> { using Type = arrow::DoubleType; }; - -template <typename A, typename B> -struct TResultOfAdditionMultiplication { - using Type = typename TConstruct< - std::is_signed_v<A> || std::is_signed_v<B>, - std::is_floating_point_v<A> || std::is_floating_point_v<B>, - NextSize(std::max(sizeof(A), sizeof(B)))>::Type; -}; - -template <typename A, typename B> -struct TResultOfSubtraction { - using Type = typename TConstruct< - true, - std::is_floating_point_v<A> || std::is_floating_point_v<B>, - NextSize(std::max(sizeof(A), sizeof(B)))>::Type; -}; - -template <typename A, typename B> -struct TResultOfFloatingPointDivision { - using Type = arrow::DoubleType; -}; - -template <typename A, typename B> -struct TResultOfIntegerDivision { - using Type = typename TConstruct< - std::is_signed_v<A> || std::is_signed_v<B>, - false, - sizeof(A)>::Type; -}; - -template <typename A, typename B> -struct TResultOfModulo { - static constexpr bool result_is_signed = std::is_signed_v<A>; - /// If modulo of division can yield negative number, we need larger type to accommodate it. - /// Example: toInt32(-199) % toUInt8(200) will return -199 that does not fit in Int8, only in Int16. - static constexpr size_t size_of_result = result_is_signed ? NextSize(sizeof(B)) : sizeof(B); - using Type0 = typename TConstruct<result_is_signed, false, size_of_result>::Type; - using Type = std::conditional_t<std::is_floating_point_v<A> || std::is_floating_point_v<B>, arrow::DoubleType, Type0>; -}; - -template <typename A> -struct TResultOfNegate { - using Type = typename TConstruct< - true, - std::is_floating_point_v<A>, - std::is_signed_v<A> ? sizeof(A) : NextSize(sizeof(A))>::Type; -}; - -template <typename A> -struct TResultOfAbs { - using Type = typename TConstruct< - false, - std::is_floating_point_v<A>, - sizeof(A)>::Type; -}; - -template <typename A> struct TToInteger { - using Type = typename TConstruct< - std::is_signed_v<A>, - false, - std::is_floating_point_v<A> ? 8 : sizeof(A)>::Type; -}; - + +constexpr size_t NextSize(size_t size) { + if (size < 8) { + return size * 2; + } + return size; +} + +struct TError {}; + +template <bool is_signed, bool is_floating, size_t size> +struct TConstruct { + using Type = TError; +}; + +template <> struct TConstruct<false, false, 1> { using Type = arrow::UInt8Type; }; +template <> struct TConstruct<false, false, 2> { using Type = arrow::UInt16Type; }; +template <> struct TConstruct<false, false, 4> { using Type = arrow::UInt32Type; }; +template <> struct TConstruct<false, false, 8> { using Type = arrow::UInt64Type; }; +template <> struct TConstruct<false, true, 1> { using Type = arrow::FloatType; }; +template <> struct TConstruct<false, true, 2> { using Type = arrow::FloatType; }; +template <> struct TConstruct<false, true, 4> { using Type = arrow::FloatType; }; +template <> struct TConstruct<false, true, 8> { using Type = arrow::DoubleType; }; +template <> struct TConstruct<true, false, 1> { using Type = arrow::Int8Type; }; +template <> struct TConstruct<true, false, 2> { using Type = arrow::Int16Type; }; +template <> struct TConstruct<true, false, 4> { using Type = arrow::Int32Type; }; +template <> struct TConstruct<true, false, 8> { using Type = arrow::Int64Type; }; +template <> struct TConstruct<true, true, 1> { using Type = arrow::FloatType; }; +template <> struct TConstruct<true, true, 2> { using Type = arrow::FloatType; }; +template <> struct TConstruct<true, true, 4> { using Type = arrow::FloatType; }; +template <> struct TConstruct<true, true, 8> { using Type = arrow::DoubleType; }; + +template <typename A, typename B> +struct TResultOfAdditionMultiplication { + using Type = typename TConstruct< + std::is_signed_v<A> || std::is_signed_v<B>, + std::is_floating_point_v<A> || std::is_floating_point_v<B>, + NextSize(std::max(sizeof(A), sizeof(B)))>::Type; +}; + +template <typename A, typename B> +struct TResultOfSubtraction { + using Type = typename TConstruct< + true, + std::is_floating_point_v<A> || std::is_floating_point_v<B>, + NextSize(std::max(sizeof(A), sizeof(B)))>::Type; +}; + +template <typename A, typename B> +struct TResultOfFloatingPointDivision { + using Type = arrow::DoubleType; +}; + +template <typename A, typename B> +struct TResultOfIntegerDivision { + using Type = typename TConstruct< + std::is_signed_v<A> || std::is_signed_v<B>, + false, + sizeof(A)>::Type; +}; + +template <typename A, typename B> +struct TResultOfModulo { + static constexpr bool result_is_signed = std::is_signed_v<A>; + /// If modulo of division can yield negative number, we need larger type to accommodate it. + /// Example: toInt32(-199) % toUInt8(200) will return -199 that does not fit in Int8, only in Int16. + static constexpr size_t size_of_result = result_is_signed ? NextSize(sizeof(B)) : sizeof(B); + using Type0 = typename TConstruct<result_is_signed, false, size_of_result>::Type; + using Type = std::conditional_t<std::is_floating_point_v<A> || std::is_floating_point_v<B>, arrow::DoubleType, Type0>; +}; + +template <typename A> +struct TResultOfNegate { + using Type = typename TConstruct< + true, + std::is_floating_point_v<A>, + std::is_signed_v<A> ? sizeof(A) : NextSize(sizeof(A))>::Type; +}; + +template <typename A> +struct TResultOfAbs { + using Type = typename TConstruct< + false, + std::is_floating_point_v<A>, + sizeof(A)>::Type; +}; + +template <typename A> struct TToInteger { + using Type = typename TConstruct< + std::is_signed_v<A>, + false, + std::is_floating_point_v<A> ? 8 : sizeof(A)>::Type; +}; + } diff --git a/ydb/core/formats/custom_registry.cpp b/ydb/core/formats/custom_registry.cpp index 404d01d6a9..ea9433e52d 100644 --- a/ydb/core/formats/custom_registry.cpp +++ b/ydb/core/formats/custom_registry.cpp @@ -1,63 +1,63 @@ -#include "functions.h" -#include "func_common.h" +#include "functions.h" +#include "func_common.h" #include <util/system/yassert.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.cc> #include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h> - -namespace cp = ::arrow::compute; - + +namespace cp = ::arrow::compute; + namespace NKikimr::NArrow { - -static void RegisterMath(cp::FunctionRegistry* registry) { - Y_VERIFY(registry->AddFunction(MakeMathUnary<TAcosh>(TAcosh::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeMathUnary<TAtanh>(TAtanh::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeMathUnary<TCbrt>(TCbrt::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeMathUnary<TCosh>(TCosh::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeConstNullary<TE>(TE::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeMathUnary<TErf>(TErf::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeMathUnary<TErfc>(TErfc::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeMathUnary<TExp>(TExp::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeMathUnary<TExp2>(TExp2::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeMathUnary<TExp10>(TExp10::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeMathBinary<THypot>(THypot::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeMathUnary<TLgamma>(TLgamma::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeConstNullary<TPi>(TPi::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeMathUnary<TSinh>(TSinh::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeMathUnary<TSqrt>(TSqrt::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeMathUnary<TTgamma>(TTgamma::Name)).ok()); -} - -static void RegisterRound(cp::FunctionRegistry* registry) { - Y_VERIFY(registry->AddFunction(MakeArithmeticUnary<TRound>(TRound::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeArithmeticUnary<TRoundBankers>(TRoundBankers::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeArithmeticUnary<TRoundToExp2>(TRoundToExp2::Name)).ok()); -} - -static void RegisterArithmetic(cp::FunctionRegistry* registry) { - Y_VERIFY(registry->AddFunction(MakeArithmeticIntBinary<TGreatestCommonDivisor>(TGreatestCommonDivisor::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeArithmeticIntBinary<TLeastCommonMultiple>(TLeastCommonMultiple::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeArithmeticBinary<TModulo>(TModulo::Name)).ok()); - Y_VERIFY(registry->AddFunction(MakeArithmeticBinary<TModuloOrZero>(TModuloOrZero::Name)).ok()); -} - - -static std::unique_ptr<cp::FunctionRegistry> CreateCustomRegistry() { - auto registry = cp::FunctionRegistry::Make(); - RegisterMath(registry.get()); - RegisterRound(registry.get()); - RegisterArithmetic(registry.get()); - cp::internal::RegisterScalarCast(registry.get()); - return registry; -} - -cp::FunctionRegistry* GetCustomFunctionRegistry() { - static auto g_registry = CreateCustomRegistry(); - return g_registry.get(); -} - -cp::ExecContext* GetCustomExecContext() { - static auto context = std::make_unique<cp::ExecContext>(arrow::default_memory_pool(), NULLPTR, GetCustomFunctionRegistry()); - return context.get(); -} - -} + +static void RegisterMath(cp::FunctionRegistry* registry) { + Y_VERIFY(registry->AddFunction(MakeMathUnary<TAcosh>(TAcosh::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeMathUnary<TAtanh>(TAtanh::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeMathUnary<TCbrt>(TCbrt::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeMathUnary<TCosh>(TCosh::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeConstNullary<TE>(TE::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeMathUnary<TErf>(TErf::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeMathUnary<TErfc>(TErfc::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeMathUnary<TExp>(TExp::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeMathUnary<TExp2>(TExp2::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeMathUnary<TExp10>(TExp10::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeMathBinary<THypot>(THypot::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeMathUnary<TLgamma>(TLgamma::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeConstNullary<TPi>(TPi::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeMathUnary<TSinh>(TSinh::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeMathUnary<TSqrt>(TSqrt::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeMathUnary<TTgamma>(TTgamma::Name)).ok()); +} + +static void RegisterRound(cp::FunctionRegistry* registry) { + Y_VERIFY(registry->AddFunction(MakeArithmeticUnary<TRound>(TRound::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeArithmeticUnary<TRoundBankers>(TRoundBankers::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeArithmeticUnary<TRoundToExp2>(TRoundToExp2::Name)).ok()); +} + +static void RegisterArithmetic(cp::FunctionRegistry* registry) { + Y_VERIFY(registry->AddFunction(MakeArithmeticIntBinary<TGreatestCommonDivisor>(TGreatestCommonDivisor::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeArithmeticIntBinary<TLeastCommonMultiple>(TLeastCommonMultiple::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeArithmeticBinary<TModulo>(TModulo::Name)).ok()); + Y_VERIFY(registry->AddFunction(MakeArithmeticBinary<TModuloOrZero>(TModuloOrZero::Name)).ok()); +} + + +static std::unique_ptr<cp::FunctionRegistry> CreateCustomRegistry() { + auto registry = cp::FunctionRegistry::Make(); + RegisterMath(registry.get()); + RegisterRound(registry.get()); + RegisterArithmetic(registry.get()); + cp::internal::RegisterScalarCast(registry.get()); + return registry; +} + +cp::FunctionRegistry* GetCustomFunctionRegistry() { + static auto g_registry = CreateCustomRegistry(); + return g_registry.get(); +} + +cp::ExecContext* GetCustomExecContext() { + static auto context = std::make_unique<cp::ExecContext>(arrow::default_memory_pool(), NULLPTR, GetCustomFunctionRegistry()); + return context.get(); +} + +} diff --git a/ydb/core/formats/custom_registry.h b/ydb/core/formats/custom_registry.h index 8442e44eee..ff25ae2a3c 100644 --- a/ydb/core/formats/custom_registry.h +++ b/ydb/core/formats/custom_registry.h @@ -1,9 +1,9 @@ -#pragma once +#pragma once #include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h> - -namespace cp = ::arrow::compute; - + +namespace cp = ::arrow::compute; + namespace NKikimr::NArrow { - cp::FunctionRegistry* GetCustomFunctionRegistry(); - cp::ExecContext* GetCustomExecContext(); -} + cp::FunctionRegistry* GetCustomFunctionRegistry(); + cp::ExecContext* GetCustomExecContext(); +} diff --git a/ydb/core/formats/execs.h b/ydb/core/formats/execs.h index bf98b70927..3600e9ea2c 100644 --- a/ydb/core/formats/execs.h +++ b/ydb/core/formats/execs.h @@ -1,4 +1,4 @@ -#pragma once +#pragma once #include <contrib/libs/apache/arrow/cpp/src/arrow/scalar.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/status.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/type.h> @@ -10,174 +10,174 @@ #include <contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.h> #include <util/datetime/base.h> #include <util/system/yassert.h> -#include <cstdint> -#include <memory> -#include <type_traits> -#include <vector> -#include "switch_type.h" - -namespace cp = arrow::compute; -using cp::internal::applicator::ScalarBinary; -using cp::internal::applicator::ScalarUnary; - +#include <cstdint> +#include <memory> +#include <type_traits> +#include <vector> +#include "switch_type.h" + +namespace cp = arrow::compute; +using cp::internal::applicator::ScalarBinary; +using cp::internal::applicator::ScalarUnary; + namespace NKikimr::NArrow { - -template <template <typename... Args> class KernelGenerator, typename Op> -cp::ArrayKernelExec ArithmeticBinaryExec(cp::internal::detail::GetTypeId getId) { - switch (getId.id) { - case arrow::Type::INT8: - return KernelGenerator<arrow::Int8Type, arrow::Int8Type, arrow::Int8Type, Op>::Exec; - case arrow::Type::UINT8: - return KernelGenerator<arrow::UInt8Type, arrow::UInt8Type, arrow::UInt8Type, Op>::Exec; - case arrow::Type::INT16: - return KernelGenerator<arrow::Int16Type, arrow::Int16Type, arrow::Int16Type, Op>::Exec; - case arrow::Type::UINT16: - return KernelGenerator<arrow::UInt16Type, arrow::UInt16Type, arrow::UInt16Type, Op>::Exec; - case arrow::Type::INT32: - return KernelGenerator<arrow::Int32Type, arrow::Int32Type, arrow::Int32Type, Op>::Exec; - case arrow::Type::UINT32: - return KernelGenerator<arrow::UInt32Type, arrow::UInt32Type, arrow::UInt32Type, Op>::Exec; - case arrow::Type::INT64: - return KernelGenerator<arrow::Int64Type, arrow::Int64Type, arrow::Int64Type, Op>::Exec; - case arrow::Type::TIMESTAMP: - return KernelGenerator<arrow::Int64Type, arrow::Int64Type, arrow::Int64Type, Op>::Exec; - case arrow::Type::UINT64: - return KernelGenerator<arrow::UInt64Type, arrow::UInt64Type, arrow::UInt64Type, Op>::Exec; - case arrow::Type::FLOAT: - return KernelGenerator<arrow::FloatType, arrow::FloatType, arrow::FloatType, Op>::Exec; - case arrow::Type::DOUBLE: - return KernelGenerator<arrow::DoubleType, arrow::DoubleType, arrow::DoubleType, Op>::Exec; - default: - Y_VERIFY(false); - return cp::internal::ExecFail; - } -} - -template <template <typename... Args> class KernelGenerator, typename Op> -cp::ArrayKernelExec ArithmeticBinaryIntExec(cp::internal::detail::GetTypeId getId) { - switch (getId.id) { - case arrow::Type::INT8: - return KernelGenerator<arrow::Int8Type, arrow::Int8Type, arrow::Int8Type, Op>::Exec; - case arrow::Type::UINT8: - return KernelGenerator<arrow::UInt8Type, arrow::UInt8Type, arrow::UInt8Type, Op>::Exec; - case arrow::Type::INT16: - return KernelGenerator<arrow::Int16Type, arrow::Int16Type, arrow::Int16Type, Op>::Exec; - case arrow::Type::UINT16: - return KernelGenerator<arrow::UInt16Type, arrow::UInt16Type, arrow::UInt16Type, Op>::Exec; - case arrow::Type::INT32: - return KernelGenerator<arrow::Int32Type, arrow::Int32Type, arrow::Int32Type, Op>::Exec; - case arrow::Type::UINT32: - return KernelGenerator<arrow::UInt32Type, arrow::UInt32Type, arrow::UInt32Type, Op>::Exec; - case arrow::Type::INT64: - return KernelGenerator<arrow::Int64Type, arrow::Int64Type, arrow::Int64Type, Op>::Exec; - case arrow::Type::TIMESTAMP: - return KernelGenerator<arrow::Int64Type, arrow::Int64Type, arrow::Int64Type, Op>::Exec; - case arrow::Type::UINT64: - return KernelGenerator<arrow::UInt64Type, arrow::UInt64Type, arrow::UInt64Type, Op>::Exec; - default: - Y_VERIFY(false); - return cp::internal::ExecFail; - } -} - -template <template <typename... Args> class KernelGenerator, typename Op> -cp::ArrayKernelExec ArithmeticUnaryExec(cp::internal::detail::GetTypeId getId) { - switch (getId.id) { - case arrow::Type::INT8: - return KernelGenerator<arrow::Int8Type, arrow::Int8Type, Op>::Exec; - case arrow::Type::UINT8: - return KernelGenerator<arrow::UInt8Type, arrow::UInt8Type, Op>::Exec; - case arrow::Type::INT16: - return KernelGenerator<arrow::Int16Type, arrow::Int16Type, Op>::Exec; - case arrow::Type::UINT16: - return KernelGenerator<arrow::UInt16Type, arrow::UInt16Type, Op>::Exec; - case arrow::Type::INT32: - return KernelGenerator<arrow::Int32Type, arrow::Int32Type, Op>::Exec; - case arrow::Type::UINT32: - return KernelGenerator<arrow::UInt32Type, arrow::UInt32Type, Op>::Exec; - case arrow::Type::INT64: - return KernelGenerator<arrow::Int64Type, arrow::Int64Type, Op>::Exec; - case arrow::Type::TIMESTAMP: - return KernelGenerator<arrow::Int64Type, arrow::Int64Type, Op>::Exec; - case arrow::Type::UINT64: - return KernelGenerator<arrow::UInt64Type, arrow::UInt64Type, Op>::Exec; - case arrow::Type::FLOAT: - return KernelGenerator<arrow::FloatType, arrow::FloatType, Op>::Exec; - case arrow::Type::DOUBLE: - return KernelGenerator<arrow::DoubleType, arrow::DoubleType, Op>::Exec; - default: - Y_VERIFY(false); - return cp::internal::ExecFail; - } -} - -template <template <typename... Args> class KernelGenerator, typename Op> -cp::ArrayKernelExec MathUnaryExec(cp::internal::detail::GetTypeId getId) { - switch (getId.id) { - case arrow::Type::INT8: - return KernelGenerator<arrow::DoubleType, arrow::Int8Type, Op>::Exec; - case arrow::Type::UINT8: - return KernelGenerator<arrow::DoubleType, arrow::UInt8Type, Op>::Exec; - case arrow::Type::INT16: - return KernelGenerator<arrow::DoubleType, arrow::Int16Type, Op>::Exec; - case arrow::Type::UINT16: - return KernelGenerator<arrow::DoubleType, arrow::UInt16Type, Op>::Exec; - case arrow::Type::INT32: - return KernelGenerator<arrow::DoubleType, arrow::Int32Type, Op>::Exec; - case arrow::Type::UINT32: - return KernelGenerator<arrow::DoubleType, arrow::UInt32Type, Op>::Exec; - case arrow::Type::INT64: - return KernelGenerator<arrow::DoubleType, arrow::Int64Type, Op>::Exec; - case arrow::Type::TIMESTAMP: - return KernelGenerator<arrow::DoubleType, arrow::Int64Type, Op>::Exec; - case arrow::Type::UINT64: - return KernelGenerator<arrow::DoubleType, arrow::UInt64Type, Op>::Exec; - case arrow::Type::FLOAT: - return KernelGenerator<arrow::DoubleType, arrow::FloatType, Op>::Exec; - case arrow::Type::DOUBLE: - return KernelGenerator<arrow::DoubleType, arrow::DoubleType, Op>::Exec; - default: - Y_VERIFY(false); - return cp::internal::ExecFail; - } -} - -template <template <typename... Args> class KernelGenerator, typename Op> -cp::ArrayKernelExec MathBinaryExec(cp::internal::detail::GetTypeId getId) { - switch (getId.id) { - case arrow::Type::INT8: - return KernelGenerator<arrow::DoubleType, arrow::Int8Type, arrow::Int8Type, Op>::Exec; - case arrow::Type::UINT8: - return KernelGenerator<arrow::DoubleType, arrow::UInt8Type, arrow::UInt8Type, Op>::Exec; - case arrow::Type::INT16: - return KernelGenerator<arrow::DoubleType, arrow::Int16Type, arrow::Int16Type, Op>::Exec; - case arrow::Type::UINT16: - return KernelGenerator<arrow::DoubleType, arrow::UInt16Type, arrow::UInt16Type, Op>::Exec; - case arrow::Type::INT32: - return KernelGenerator<arrow::DoubleType, arrow::Int32Type, arrow::Int32Type, Op>::Exec; - case arrow::Type::UINT32: - return KernelGenerator<arrow::DoubleType, arrow::UInt32Type, arrow::UInt32Type, Op>::Exec; - case arrow::Type::INT64: - return KernelGenerator<arrow::DoubleType, arrow::Int64Type, arrow::Int64Type, Op>::Exec; - case arrow::Type::TIMESTAMP: - return KernelGenerator<arrow::DoubleType, arrow::Int64Type, arrow::Int64Type, Op>::Exec; - case arrow::Type::UINT64: - return KernelGenerator<arrow::DoubleType, arrow::UInt64Type, arrow::UInt64Type, Op>::Exec; - case arrow::Type::FLOAT: - return KernelGenerator<arrow::DoubleType, arrow::FloatType, arrow::FloatType, Op>::Exec; - case arrow::Type::DOUBLE: - return KernelGenerator<arrow::DoubleType, arrow::DoubleType, arrow::DoubleType, Op>::Exec; - default: - Y_VERIFY(false); - return cp::internal::ExecFail; - } -} - - -template <typename TOperator, typename TRes> -static arrow::Status SimpleNullaryExec(cp::KernelContext* ctx, const cp::ExecBatch&, arrow::Datum* out) { - *out = arrow::MakeScalar(TOperator:: template Call<typename cp::internal::GetViewType<TRes>::T>(ctx)); - return arrow::Status::OK(); -} - + +template <template <typename... Args> class KernelGenerator, typename Op> +cp::ArrayKernelExec ArithmeticBinaryExec(cp::internal::detail::GetTypeId getId) { + switch (getId.id) { + case arrow::Type::INT8: + return KernelGenerator<arrow::Int8Type, arrow::Int8Type, arrow::Int8Type, Op>::Exec; + case arrow::Type::UINT8: + return KernelGenerator<arrow::UInt8Type, arrow::UInt8Type, arrow::UInt8Type, Op>::Exec; + case arrow::Type::INT16: + return KernelGenerator<arrow::Int16Type, arrow::Int16Type, arrow::Int16Type, Op>::Exec; + case arrow::Type::UINT16: + return KernelGenerator<arrow::UInt16Type, arrow::UInt16Type, arrow::UInt16Type, Op>::Exec; + case arrow::Type::INT32: + return KernelGenerator<arrow::Int32Type, arrow::Int32Type, arrow::Int32Type, Op>::Exec; + case arrow::Type::UINT32: + return KernelGenerator<arrow::UInt32Type, arrow::UInt32Type, arrow::UInt32Type, Op>::Exec; + case arrow::Type::INT64: + return KernelGenerator<arrow::Int64Type, arrow::Int64Type, arrow::Int64Type, Op>::Exec; + case arrow::Type::TIMESTAMP: + return KernelGenerator<arrow::Int64Type, arrow::Int64Type, arrow::Int64Type, Op>::Exec; + case arrow::Type::UINT64: + return KernelGenerator<arrow::UInt64Type, arrow::UInt64Type, arrow::UInt64Type, Op>::Exec; + case arrow::Type::FLOAT: + return KernelGenerator<arrow::FloatType, arrow::FloatType, arrow::FloatType, Op>::Exec; + case arrow::Type::DOUBLE: + return KernelGenerator<arrow::DoubleType, arrow::DoubleType, arrow::DoubleType, Op>::Exec; + default: + Y_VERIFY(false); + return cp::internal::ExecFail; + } +} + +template <template <typename... Args> class KernelGenerator, typename Op> +cp::ArrayKernelExec ArithmeticBinaryIntExec(cp::internal::detail::GetTypeId getId) { + switch (getId.id) { + case arrow::Type::INT8: + return KernelGenerator<arrow::Int8Type, arrow::Int8Type, arrow::Int8Type, Op>::Exec; + case arrow::Type::UINT8: + return KernelGenerator<arrow::UInt8Type, arrow::UInt8Type, arrow::UInt8Type, Op>::Exec; + case arrow::Type::INT16: + return KernelGenerator<arrow::Int16Type, arrow::Int16Type, arrow::Int16Type, Op>::Exec; + case arrow::Type::UINT16: + return KernelGenerator<arrow::UInt16Type, arrow::UInt16Type, arrow::UInt16Type, Op>::Exec; + case arrow::Type::INT32: + return KernelGenerator<arrow::Int32Type, arrow::Int32Type, arrow::Int32Type, Op>::Exec; + case arrow::Type::UINT32: + return KernelGenerator<arrow::UInt32Type, arrow::UInt32Type, arrow::UInt32Type, Op>::Exec; + case arrow::Type::INT64: + return KernelGenerator<arrow::Int64Type, arrow::Int64Type, arrow::Int64Type, Op>::Exec; + case arrow::Type::TIMESTAMP: + return KernelGenerator<arrow::Int64Type, arrow::Int64Type, arrow::Int64Type, Op>::Exec; + case arrow::Type::UINT64: + return KernelGenerator<arrow::UInt64Type, arrow::UInt64Type, arrow::UInt64Type, Op>::Exec; + default: + Y_VERIFY(false); + return cp::internal::ExecFail; + } +} + +template <template <typename... Args> class KernelGenerator, typename Op> +cp::ArrayKernelExec ArithmeticUnaryExec(cp::internal::detail::GetTypeId getId) { + switch (getId.id) { + case arrow::Type::INT8: + return KernelGenerator<arrow::Int8Type, arrow::Int8Type, Op>::Exec; + case arrow::Type::UINT8: + return KernelGenerator<arrow::UInt8Type, arrow::UInt8Type, Op>::Exec; + case arrow::Type::INT16: + return KernelGenerator<arrow::Int16Type, arrow::Int16Type, Op>::Exec; + case arrow::Type::UINT16: + return KernelGenerator<arrow::UInt16Type, arrow::UInt16Type, Op>::Exec; + case arrow::Type::INT32: + return KernelGenerator<arrow::Int32Type, arrow::Int32Type, Op>::Exec; + case arrow::Type::UINT32: + return KernelGenerator<arrow::UInt32Type, arrow::UInt32Type, Op>::Exec; + case arrow::Type::INT64: + return KernelGenerator<arrow::Int64Type, arrow::Int64Type, Op>::Exec; + case arrow::Type::TIMESTAMP: + return KernelGenerator<arrow::Int64Type, arrow::Int64Type, Op>::Exec; + case arrow::Type::UINT64: + return KernelGenerator<arrow::UInt64Type, arrow::UInt64Type, Op>::Exec; + case arrow::Type::FLOAT: + return KernelGenerator<arrow::FloatType, arrow::FloatType, Op>::Exec; + case arrow::Type::DOUBLE: + return KernelGenerator<arrow::DoubleType, arrow::DoubleType, Op>::Exec; + default: + Y_VERIFY(false); + return cp::internal::ExecFail; + } +} + +template <template <typename... Args> class KernelGenerator, typename Op> +cp::ArrayKernelExec MathUnaryExec(cp::internal::detail::GetTypeId getId) { + switch (getId.id) { + case arrow::Type::INT8: + return KernelGenerator<arrow::DoubleType, arrow::Int8Type, Op>::Exec; + case arrow::Type::UINT8: + return KernelGenerator<arrow::DoubleType, arrow::UInt8Type, Op>::Exec; + case arrow::Type::INT16: + return KernelGenerator<arrow::DoubleType, arrow::Int16Type, Op>::Exec; + case arrow::Type::UINT16: + return KernelGenerator<arrow::DoubleType, arrow::UInt16Type, Op>::Exec; + case arrow::Type::INT32: + return KernelGenerator<arrow::DoubleType, arrow::Int32Type, Op>::Exec; + case arrow::Type::UINT32: + return KernelGenerator<arrow::DoubleType, arrow::UInt32Type, Op>::Exec; + case arrow::Type::INT64: + return KernelGenerator<arrow::DoubleType, arrow::Int64Type, Op>::Exec; + case arrow::Type::TIMESTAMP: + return KernelGenerator<arrow::DoubleType, arrow::Int64Type, Op>::Exec; + case arrow::Type::UINT64: + return KernelGenerator<arrow::DoubleType, arrow::UInt64Type, Op>::Exec; + case arrow::Type::FLOAT: + return KernelGenerator<arrow::DoubleType, arrow::FloatType, Op>::Exec; + case arrow::Type::DOUBLE: + return KernelGenerator<arrow::DoubleType, arrow::DoubleType, Op>::Exec; + default: + Y_VERIFY(false); + return cp::internal::ExecFail; + } +} + +template <template <typename... Args> class KernelGenerator, typename Op> +cp::ArrayKernelExec MathBinaryExec(cp::internal::detail::GetTypeId getId) { + switch (getId.id) { + case arrow::Type::INT8: + return KernelGenerator<arrow::DoubleType, arrow::Int8Type, arrow::Int8Type, Op>::Exec; + case arrow::Type::UINT8: + return KernelGenerator<arrow::DoubleType, arrow::UInt8Type, arrow::UInt8Type, Op>::Exec; + case arrow::Type::INT16: + return KernelGenerator<arrow::DoubleType, arrow::Int16Type, arrow::Int16Type, Op>::Exec; + case arrow::Type::UINT16: + return KernelGenerator<arrow::DoubleType, arrow::UInt16Type, arrow::UInt16Type, Op>::Exec; + case arrow::Type::INT32: + return KernelGenerator<arrow::DoubleType, arrow::Int32Type, arrow::Int32Type, Op>::Exec; + case arrow::Type::UINT32: + return KernelGenerator<arrow::DoubleType, arrow::UInt32Type, arrow::UInt32Type, Op>::Exec; + case arrow::Type::INT64: + return KernelGenerator<arrow::DoubleType, arrow::Int64Type, arrow::Int64Type, Op>::Exec; + case arrow::Type::TIMESTAMP: + return KernelGenerator<arrow::DoubleType, arrow::Int64Type, arrow::Int64Type, Op>::Exec; + case arrow::Type::UINT64: + return KernelGenerator<arrow::DoubleType, arrow::UInt64Type, arrow::UInt64Type, Op>::Exec; + case arrow::Type::FLOAT: + return KernelGenerator<arrow::DoubleType, arrow::FloatType, arrow::FloatType, Op>::Exec; + case arrow::Type::DOUBLE: + return KernelGenerator<arrow::DoubleType, arrow::DoubleType, arrow::DoubleType, Op>::Exec; + default: + Y_VERIFY(false); + return cp::internal::ExecFail; + } +} + + +template <typename TOperator, typename TRes> +static arrow::Status SimpleNullaryExec(cp::KernelContext* ctx, const cp::ExecBatch&, arrow::Datum* out) { + *out = arrow::MakeScalar(TOperator:: template Call<typename cp::internal::GetViewType<TRes>::T>(ctx)); + return arrow::Status::OK(); +} + } diff --git a/ydb/core/formats/func_common.h b/ydb/core/formats/func_common.h index c697fcf03e..b661f88756 100644 --- a/ydb/core/formats/func_common.h +++ b/ydb/core/formats/func_common.h @@ -1,177 +1,177 @@ -#pragma once +#pragma once #include <contrib/libs/apache/arrow/cpp/src/arrow/scalar.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/status.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/type.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/type_fwd.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/type_traits.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/compute/function.h> - + #include <util/system/yassert.h> - -#include <type_traits> - -#include "switch_type.h" -#include "execs.h" - -namespace cp = arrow::compute; -using cp::internal::applicator::ScalarBinary; -using cp::internal::applicator::ScalarUnary; - + +#include <type_traits> + +#include "switch_type.h" +#include "execs.h" + +namespace cp = arrow::compute; +using cp::internal::applicator::ScalarBinary; +using cp::internal::applicator::ScalarUnary; + namespace NKikimr::NArrow { - -template <typename T> -using IsUnsignedInteger = - std::integral_constant<bool, std::is_integral<T>::value && - std::is_unsigned<T>::value>; - -template <typename T> -using IsSignedInteger = - std::integral_constant<bool, std::is_integral<T>::value && - std::is_signed<T>::value>; - -template<typename T> -using IsNumeric = std::integral_constant<bool, IsSignedInteger<T>::value || - IsUnsignedInteger<T>::value || - std::is_floating_point<T>::value>; - -template<typename TArr> -using IsArrayNumeric = std::integral_constant<bool, arrow::is_number_type<typename TArr::TypeClass>::value>; - - -template <typename T, typename R = T> -using EnableIfSigned = - std::enable_if_t<IsSignedInteger<T>::value, R>; - -template <typename T, typename R = T> -using EnableIfUnsigned = - std::enable_if_t<IsUnsignedInteger<T>::value, R>; - -template <typename T, typename R = T> -using EnableIfInteger = std::enable_if_t<IsSignedInteger<T>::value || - IsUnsignedInteger<T>::value, R>; - -template <typename T, typename R = T> -using EnableIfFloatingPoint = - std::enable_if_t<std::is_floating_point<T>::value, R>; - -template <typename T, typename R = T> -using EnableIfFloat64 = - std::enable_if_t<std::is_same<T, arrow::TypeTraits<arrow::DoubleType>::CType>::value, R>; - -template <typename T, typename R = T> -using EnableIfFloat32 = - std::enable_if_t<std::is_same<T, arrow::TypeTraits<arrow::FloatType>::CType>::value, R>; - -template <typename T, typename R = T> -using EnableIfNumeric = - std::enable_if_t<IsNumeric<T>::value, R>; - - -template <typename TType> -using TArray = typename arrow::TypeTraits<TType>::ArrayType; - -template <typename TType> -using TBuilder = typename arrow::TypeTraits<TType>::BuilderType; - -template <typename TSignedInt> -TSignedInt SafeSignedNegate(TSignedInt u) { - using TUnsignedInt = typename std::make_unsigned<TSignedInt>::type; - return static_cast<TSignedInt>(~static_cast<TUnsignedInt>(u) + 1); -} - -struct TArithmeticFunction : cp::ScalarFunction { - using ScalarFunction::ScalarFunction; - - arrow::Result<const arrow::compute::Kernel*> DispatchBest(std::vector<arrow::ValueDescr>* values) const override { - RETURN_NOT_OK(CheckArity(*values)); - - using arrow::compute::detail::DispatchExactImpl; - if (auto* kernel = DispatchExactImpl(this, *values)) { - return kernel; - } - - arrow::compute::internal::EnsureDictionaryDecoded(values); - - // Only promote types for binary functions - if (values->size() == 2) { - arrow::compute::internal::ReplaceNullWithOtherType(values); - if (auto type = arrow::compute::internal::CommonNumeric(*values)) { - arrow::compute::internal::ReplaceTypes(type, values); - } - #if 0 // TODO: dates + ints - else if (auto type = arrow::compute::internal::CommonTimestamp(*values)) { - arrow::compute::internal::ReplaceTypes(type, values); - } - #endif - } - - if (auto* kernel = DispatchExactImpl(this, *values)) { - return kernel; - } - return arrow::compute::detail::NoMatchingKernel(this, *values); - } -}; - - - -template <typename Op> -std::shared_ptr<cp::ScalarFunction> MakeConstNullary(const std::string& name) { - auto func = std::make_shared<arrow::compute::ScalarFunction>(name, cp::Arity::Nullary(), nullptr); - cp::ArrayKernelExec exec = SimpleNullaryExec<Op, arrow::DoubleType>; - Y_VERIFY(func->AddKernel({}, arrow::float64(), exec).ok()); - return func; -} - - -template <typename Op> -std::shared_ptr<cp::ScalarFunction> MakeArithmeticBinary(const std::string& name) { - auto func = std::make_shared<TArithmeticFunction>(name, cp::Arity::Binary(), nullptr); - for (const auto& ty : cp::internal::NumericTypes()) { - auto exec = ArithmeticBinaryExec<ScalarBinary, Op>(ty); - Y_VERIFY(func->AddKernel({ty, ty}, ty, exec).ok()); - } - return func; -} - -template <typename Op> -std::shared_ptr<cp::ScalarFunction> MakeArithmeticIntBinary(const std::string& name) { - auto func = std::make_shared<TArithmeticFunction>(name, cp::Arity::Binary(), nullptr); - for (const auto& ty : cp::internal::IntTypes()) { - auto exec = ArithmeticBinaryIntExec<ScalarBinary, Op>(ty); - Y_VERIFY(func->AddKernel({ty, ty}, ty, exec).ok()); - } - return func; -} - - -template <typename Op> -std::shared_ptr<cp::ScalarFunction> MakeArithmeticUnary(const std::string& name) { - auto func = std::make_shared<TArithmeticFunction>(name, cp::Arity::Unary(), nullptr); - for (const auto& ty : cp::internal::NumericTypes()) { - auto exec = ArithmeticUnaryExec<ScalarUnary, Op>(ty); - Y_VERIFY(func->AddKernel({ty}, ty, exec).ok()); - } - return func; -} - -template <typename Op> -std::shared_ptr<cp::ScalarFunction> MakeMathUnary(const std::string& name) { - auto func = std::make_shared<TArithmeticFunction>(name, cp::Arity::Unary(), nullptr); - for (const auto& ty : cp::internal::NumericTypes()) { - auto exec = MathUnaryExec<ScalarUnary, Op>(ty); - Y_VERIFY(func->AddKernel({ty}, arrow::float64(), exec).ok()); - } - return func; -} - -template <typename Op> -std::shared_ptr<cp::ScalarFunction> MakeMathBinary(const std::string& name) { - auto func = std::make_shared<TArithmeticFunction>(name, cp::Arity::Binary(), nullptr); - for (const auto& ty : cp::internal::NumericTypes()) { - auto exec = MathBinaryExec<ScalarBinary, Op>(ty); - Y_VERIFY(func->AddKernel({ty, ty}, arrow::float64(), exec).ok()); - } - return func; -} - + +template <typename T> +using IsUnsignedInteger = + std::integral_constant<bool, std::is_integral<T>::value && + std::is_unsigned<T>::value>; + +template <typename T> +using IsSignedInteger = + std::integral_constant<bool, std::is_integral<T>::value && + std::is_signed<T>::value>; + +template<typename T> +using IsNumeric = std::integral_constant<bool, IsSignedInteger<T>::value || + IsUnsignedInteger<T>::value || + std::is_floating_point<T>::value>; + +template<typename TArr> +using IsArrayNumeric = std::integral_constant<bool, arrow::is_number_type<typename TArr::TypeClass>::value>; + + +template <typename T, typename R = T> +using EnableIfSigned = + std::enable_if_t<IsSignedInteger<T>::value, R>; + +template <typename T, typename R = T> +using EnableIfUnsigned = + std::enable_if_t<IsUnsignedInteger<T>::value, R>; + +template <typename T, typename R = T> +using EnableIfInteger = std::enable_if_t<IsSignedInteger<T>::value || + IsUnsignedInteger<T>::value, R>; + +template <typename T, typename R = T> +using EnableIfFloatingPoint = + std::enable_if_t<std::is_floating_point<T>::value, R>; + +template <typename T, typename R = T> +using EnableIfFloat64 = + std::enable_if_t<std::is_same<T, arrow::TypeTraits<arrow::DoubleType>::CType>::value, R>; + +template <typename T, typename R = T> +using EnableIfFloat32 = + std::enable_if_t<std::is_same<T, arrow::TypeTraits<arrow::FloatType>::CType>::value, R>; + +template <typename T, typename R = T> +using EnableIfNumeric = + std::enable_if_t<IsNumeric<T>::value, R>; + + +template <typename TType> +using TArray = typename arrow::TypeTraits<TType>::ArrayType; + +template <typename TType> +using TBuilder = typename arrow::TypeTraits<TType>::BuilderType; + +template <typename TSignedInt> +TSignedInt SafeSignedNegate(TSignedInt u) { + using TUnsignedInt = typename std::make_unsigned<TSignedInt>::type; + return static_cast<TSignedInt>(~static_cast<TUnsignedInt>(u) + 1); +} + +struct TArithmeticFunction : cp::ScalarFunction { + using ScalarFunction::ScalarFunction; + + arrow::Result<const arrow::compute::Kernel*> DispatchBest(std::vector<arrow::ValueDescr>* values) const override { + RETURN_NOT_OK(CheckArity(*values)); + + using arrow::compute::detail::DispatchExactImpl; + if (auto* kernel = DispatchExactImpl(this, *values)) { + return kernel; + } + + arrow::compute::internal::EnsureDictionaryDecoded(values); + + // Only promote types for binary functions + if (values->size() == 2) { + arrow::compute::internal::ReplaceNullWithOtherType(values); + if (auto type = arrow::compute::internal::CommonNumeric(*values)) { + arrow::compute::internal::ReplaceTypes(type, values); + } + #if 0 // TODO: dates + ints + else if (auto type = arrow::compute::internal::CommonTimestamp(*values)) { + arrow::compute::internal::ReplaceTypes(type, values); + } + #endif + } + + if (auto* kernel = DispatchExactImpl(this, *values)) { + return kernel; + } + return arrow::compute::detail::NoMatchingKernel(this, *values); + } +}; + + + +template <typename Op> +std::shared_ptr<cp::ScalarFunction> MakeConstNullary(const std::string& name) { + auto func = std::make_shared<arrow::compute::ScalarFunction>(name, cp::Arity::Nullary(), nullptr); + cp::ArrayKernelExec exec = SimpleNullaryExec<Op, arrow::DoubleType>; + Y_VERIFY(func->AddKernel({}, arrow::float64(), exec).ok()); + return func; +} + + +template <typename Op> +std::shared_ptr<cp::ScalarFunction> MakeArithmeticBinary(const std::string& name) { + auto func = std::make_shared<TArithmeticFunction>(name, cp::Arity::Binary(), nullptr); + for (const auto& ty : cp::internal::NumericTypes()) { + auto exec = ArithmeticBinaryExec<ScalarBinary, Op>(ty); + Y_VERIFY(func->AddKernel({ty, ty}, ty, exec).ok()); + } + return func; +} + +template <typename Op> +std::shared_ptr<cp::ScalarFunction> MakeArithmeticIntBinary(const std::string& name) { + auto func = std::make_shared<TArithmeticFunction>(name, cp::Arity::Binary(), nullptr); + for (const auto& ty : cp::internal::IntTypes()) { + auto exec = ArithmeticBinaryIntExec<ScalarBinary, Op>(ty); + Y_VERIFY(func->AddKernel({ty, ty}, ty, exec).ok()); + } + return func; +} + + +template <typename Op> +std::shared_ptr<cp::ScalarFunction> MakeArithmeticUnary(const std::string& name) { + auto func = std::make_shared<TArithmeticFunction>(name, cp::Arity::Unary(), nullptr); + for (const auto& ty : cp::internal::NumericTypes()) { + auto exec = ArithmeticUnaryExec<ScalarUnary, Op>(ty); + Y_VERIFY(func->AddKernel({ty}, ty, exec).ok()); + } + return func; +} + +template <typename Op> +std::shared_ptr<cp::ScalarFunction> MakeMathUnary(const std::string& name) { + auto func = std::make_shared<TArithmeticFunction>(name, cp::Arity::Unary(), nullptr); + for (const auto& ty : cp::internal::NumericTypes()) { + auto exec = MathUnaryExec<ScalarUnary, Op>(ty); + Y_VERIFY(func->AddKernel({ty}, arrow::float64(), exec).ok()); + } + return func; +} + +template <typename Op> +std::shared_ptr<cp::ScalarFunction> MakeMathBinary(const std::string& name) { + auto func = std::make_shared<TArithmeticFunction>(name, cp::Arity::Binary(), nullptr); + for (const auto& ty : cp::internal::NumericTypes()) { + auto exec = MathBinaryExec<ScalarBinary, Op>(ty); + Y_VERIFY(func->AddKernel({ty, ty}, arrow::float64(), exec).ok()); + } + return func; +} + } diff --git a/ydb/core/formats/func_gcd.h b/ydb/core/formats/func_gcd.h index 9d4b72e4b0..d69ec595fd 100644 --- a/ydb/core/formats/func_gcd.h +++ b/ydb/core/formats/func_gcd.h @@ -1,37 +1,37 @@ -#pragma once +#pragma once #include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_base.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/array/builder_base.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/compute/kernel.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/datum.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/type_traits.h> -#include "func_common.h" -#include <cstdlib> -#include <type_traits> - +#include "func_common.h" +#include <cstdlib> +#include <type_traits> + namespace NKikimr::NArrow { - -template<typename T> -__attribute__((always_inline)) void FastIntSwap(T& lhs, T& rhs) { - lhs ^= rhs; - rhs ^= lhs; - lhs ^= rhs; - } - -struct TGreatestCommonDivisor { - - static constexpr const char * Name = "gcd"; - - template <typename TRes, typename TArg0, typename TArg1> - static constexpr TRes Call(arrow::compute::KernelContext*, TArg0 lhs, TArg1 rhs, arrow::Status*) { - static_assert(std::is_integral_v<TRes>, ""); - static_assert(std::is_integral_v<TArg0>, ""); - static_assert(std::is_integral_v<TArg1>, ""); - while (rhs != 0) { - lhs %= rhs; - FastIntSwap(lhs, rhs); - } - return lhs; - } -}; - + +template<typename T> +__attribute__((always_inline)) void FastIntSwap(T& lhs, T& rhs) { + lhs ^= rhs; + rhs ^= lhs; + lhs ^= rhs; + } + +struct TGreatestCommonDivisor { + + static constexpr const char * Name = "gcd"; + + template <typename TRes, typename TArg0, typename TArg1> + static constexpr TRes Call(arrow::compute::KernelContext*, TArg0 lhs, TArg1 rhs, arrow::Status*) { + static_assert(std::is_integral_v<TRes>, ""); + static_assert(std::is_integral_v<TArg0>, ""); + static_assert(std::is_integral_v<TArg1>, ""); + while (rhs != 0) { + lhs %= rhs; + FastIntSwap(lhs, rhs); + } + return lhs; + } +}; + } diff --git a/ydb/core/formats/func_lcm.h b/ydb/core/formats/func_lcm.h index 1e31b0e516..410389187f 100644 --- a/ydb/core/formats/func_lcm.h +++ b/ydb/core/formats/func_lcm.h @@ -1,27 +1,27 @@ -#pragma once -#include "func_common.h" -#include "func_mul.h" -#include "func_gcd.h" -#include "clickhouse_type_traits.h" - +#pragma once +#include "func_common.h" +#include "func_mul.h" +#include "func_gcd.h" +#include "clickhouse_type_traits.h" + namespace NKikimr::NArrow { - -struct TLeastCommonMultiple { - - static constexpr const char * Name = "lcm"; - - template <typename TRes, typename TArg0, typename TArg1> - static constexpr TRes Call(arrow::compute::KernelContext* ctx, TArg0 lhs, TArg1 rhs, arrow::Status* st) { - static_assert(std::is_integral_v<TRes>, ""); - static_assert(std::is_integral_v<TArg0>, ""); - static_assert(std::is_integral_v<TArg1>, ""); - auto gcd = TGreatestCommonDivisor::Call<TRes, TArg0, TArg1>(ctx, lhs, rhs, st); - if (ARROW_PREDICT_FALSE(gcd == 0)) { - *st = arrow::Status::Invalid("divide by zero"); - return 0; - } - return TMultiply::Call<TRes, TArg0, TArg1>(ctx, lhs, rhs, st) / gcd; - } -}; - + +struct TLeastCommonMultiple { + + static constexpr const char * Name = "lcm"; + + template <typename TRes, typename TArg0, typename TArg1> + static constexpr TRes Call(arrow::compute::KernelContext* ctx, TArg0 lhs, TArg1 rhs, arrow::Status* st) { + static_assert(std::is_integral_v<TRes>, ""); + static_assert(std::is_integral_v<TArg0>, ""); + static_assert(std::is_integral_v<TArg1>, ""); + auto gcd = TGreatestCommonDivisor::Call<TRes, TArg0, TArg1>(ctx, lhs, rhs, st); + if (ARROW_PREDICT_FALSE(gcd == 0)) { + *st = arrow::Status::Invalid("divide by zero"); + return 0; + } + return TMultiply::Call<TRes, TArg0, TArg1>(ctx, lhs, rhs, st) / gcd; + } +}; + } diff --git a/ydb/core/formats/func_math.h b/ydb/core/formats/func_math.h index 74cf8694ff..24d919e817 100644 --- a/ydb/core/formats/func_math.h +++ b/ydb/core/formats/func_math.h @@ -1,169 +1,169 @@ -#pragma once -#include "func_common.h" -#include <cmath> - +#pragma once +#include "func_common.h" +#include <cmath> + namespace NKikimr::NArrow { - -struct TAcosh { - - static constexpr const char * Name = "acosh"; - - template <typename TRes, typename TArg> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - return std::acosh(arg); - } -}; - -struct TAtanh { - - static constexpr const char * Name = "atanh"; - - template <typename TRes, typename TArg> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - return std::atanh(arg); - } -}; - -struct TCbrt { - - static constexpr const char * Name = "cbrt"; - - template <typename TRes, typename TArg> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - return std::cbrt(arg); - } -}; - -struct TCosh { - - static constexpr const char * Name = "cosh"; - - template <typename TRes, typename TArg> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - return std::cosh(arg); - } -}; - -struct TE { - - static constexpr const char * Name = "e"; - static constexpr double value = 2.7182818284590452353602874713526624977572470; - - template <typename TRes> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*) { - return value; - } -}; - -struct TErf { - - static constexpr const char * Name = "erf"; - - template <typename TRes, typename TArg> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - return std::erf(arg); - } -}; - -struct TErfc { - - static constexpr const char * Name = "erfc"; - - template <typename TRes, typename TArg> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - return std::erfc(arg); - } -}; - -struct TExp { - - static constexpr const char * Name = "exp"; - - template <typename TRes, typename TArg> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - return std::exp(arg); - } -}; - -struct TExp2 { - - static constexpr const char * Name = "exp2"; - - template <typename TRes, typename TArg> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - return std::exp2(arg); - } -}; - -struct TExp10 { - - static constexpr const char * Name = "exp10"; - - template <typename TRes, typename TArg> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - return exp10(arg); - } -}; - -struct THypot { - - static constexpr const char * Name = "hypot"; - - template <typename TRes, typename TArg0, typename TArg1> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg0 lhs, TArg1 rhs, arrow::Status*) { - return std::hypot(lhs, rhs); - } -}; - -struct TLgamma { - - static constexpr const char * Name = "lgamma"; - - template <typename TRes, typename TArg> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - return std::lgamma(arg); - } -}; - -struct TPi { - - static constexpr const char * Name = "pi"; - static constexpr double value = 3.1415926535897932384626433832795028841971693; - - template <typename TRes> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*) { - return value; - } -}; - -struct TSinh { - - static constexpr const char * Name = "sinh"; - - template <typename TRes, typename TArg> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - return std::sinh(arg); - } -}; - -struct TSqrt { - - static constexpr const char * Name = "sqrt"; - - template <typename TRes, typename TArg> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - return std::sqrt(arg); - } -}; - -struct TTgamma { - - static constexpr const char * Name = "tgamma"; - - template <typename TRes, typename TArg> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - return std::tgamma(arg); - } -}; - + +struct TAcosh { + + static constexpr const char * Name = "acosh"; + + template <typename TRes, typename TArg> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + return std::acosh(arg); + } +}; + +struct TAtanh { + + static constexpr const char * Name = "atanh"; + + template <typename TRes, typename TArg> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + return std::atanh(arg); + } +}; + +struct TCbrt { + + static constexpr const char * Name = "cbrt"; + + template <typename TRes, typename TArg> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + return std::cbrt(arg); + } +}; + +struct TCosh { + + static constexpr const char * Name = "cosh"; + + template <typename TRes, typename TArg> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + return std::cosh(arg); + } +}; + +struct TE { + + static constexpr const char * Name = "e"; + static constexpr double value = 2.7182818284590452353602874713526624977572470; + + template <typename TRes> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*) { + return value; + } +}; + +struct TErf { + + static constexpr const char * Name = "erf"; + + template <typename TRes, typename TArg> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + return std::erf(arg); + } +}; + +struct TErfc { + + static constexpr const char * Name = "erfc"; + + template <typename TRes, typename TArg> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + return std::erfc(arg); + } +}; + +struct TExp { + + static constexpr const char * Name = "exp"; + + template <typename TRes, typename TArg> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + return std::exp(arg); + } +}; + +struct TExp2 { + + static constexpr const char * Name = "exp2"; + + template <typename TRes, typename TArg> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + return std::exp2(arg); + } +}; + +struct TExp10 { + + static constexpr const char * Name = "exp10"; + + template <typename TRes, typename TArg> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + return exp10(arg); + } +}; + +struct THypot { + + static constexpr const char * Name = "hypot"; + + template <typename TRes, typename TArg0, typename TArg1> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg0 lhs, TArg1 rhs, arrow::Status*) { + return std::hypot(lhs, rhs); + } +}; + +struct TLgamma { + + static constexpr const char * Name = "lgamma"; + + template <typename TRes, typename TArg> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + return std::lgamma(arg); + } +}; + +struct TPi { + + static constexpr const char * Name = "pi"; + static constexpr double value = 3.1415926535897932384626433832795028841971693; + + template <typename TRes> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*) { + return value; + } +}; + +struct TSinh { + + static constexpr const char * Name = "sinh"; + + template <typename TRes, typename TArg> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + return std::sinh(arg); + } +}; + +struct TSqrt { + + static constexpr const char * Name = "sqrt"; + + template <typename TRes, typename TArg> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + return std::sqrt(arg); + } +}; + +struct TTgamma { + + static constexpr const char * Name = "tgamma"; + + template <typename TRes, typename TArg> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + return std::tgamma(arg); + } +}; + } diff --git a/ydb/core/formats/func_modulo.h b/ydb/core/formats/func_modulo.h index 85bcc7b9a1..cf810f9d79 100644 --- a/ydb/core/formats/func_modulo.h +++ b/ydb/core/formats/func_modulo.h @@ -1,32 +1,32 @@ -#pragma once -#include "func_common.h" -#include "clickhouse_type_traits.h" - +#pragma once +#include "func_common.h" +#include "clickhouse_type_traits.h" + namespace NKikimr::NArrow { - -struct TModulo { - - static constexpr const char * Name = "mod"; - - template <typename TRes, typename TArg0, typename TArg1> - static constexpr EnableIfInteger<TRes> Call(arrow::compute::KernelContext*, TArg0 lhs, TArg1 rhs, arrow::Status* st) { - static_assert(std::is_same<TRes, TArg0>::value && std::is_same<TRes, TArg1>::value, ""); - if (ARROW_PREDICT_FALSE(rhs == 0)) { - *st = arrow::Status::Invalid("divide by zero"); - return 0; - } - return static_cast<TRes>(lhs) % static_cast<TRes>(rhs); - } - - template <typename TRes, typename TArg0, typename TArg1> - static constexpr EnableIfFloatingPoint<TRes> Call(arrow::compute::KernelContext*, TArg0 lhs, TArg1 rhs, arrow::Status* st) { - static_assert(std::is_same<TRes, TArg0>::value && std::is_same<TRes, TArg1>::value, ""); - if (static_cast<typename TToInteger<TArg1>::Type::c_type>(rhs) == 0) { - *st = arrow::Status::Invalid("divide by zero"); - return 0; - } - return static_cast<typename TToInteger<TArg0>::Type::c_type>(lhs) % static_cast<typename TToInteger<TArg1>::Type::c_type>(rhs); - } -}; - -} + +struct TModulo { + + static constexpr const char * Name = "mod"; + + template <typename TRes, typename TArg0, typename TArg1> + static constexpr EnableIfInteger<TRes> Call(arrow::compute::KernelContext*, TArg0 lhs, TArg1 rhs, arrow::Status* st) { + static_assert(std::is_same<TRes, TArg0>::value && std::is_same<TRes, TArg1>::value, ""); + if (ARROW_PREDICT_FALSE(rhs == 0)) { + *st = arrow::Status::Invalid("divide by zero"); + return 0; + } + return static_cast<TRes>(lhs) % static_cast<TRes>(rhs); + } + + template <typename TRes, typename TArg0, typename TArg1> + static constexpr EnableIfFloatingPoint<TRes> Call(arrow::compute::KernelContext*, TArg0 lhs, TArg1 rhs, arrow::Status* st) { + static_assert(std::is_same<TRes, TArg0>::value && std::is_same<TRes, TArg1>::value, ""); + if (static_cast<typename TToInteger<TArg1>::Type::c_type>(rhs) == 0) { + *st = arrow::Status::Invalid("divide by zero"); + return 0; + } + return static_cast<typename TToInteger<TArg0>::Type::c_type>(lhs) % static_cast<typename TToInteger<TArg1>::Type::c_type>(rhs); + } +}; + +} diff --git a/ydb/core/formats/func_modulo_or_zero.h b/ydb/core/formats/func_modulo_or_zero.h index a2a4a494a3..be05e263f7 100644 --- a/ydb/core/formats/func_modulo_or_zero.h +++ b/ydb/core/formats/func_modulo_or_zero.h @@ -1,30 +1,30 @@ -#pragma once -#include "func_common.h" -#include "clickhouse_type_traits.h" - +#pragma once +#include "func_common.h" +#include "clickhouse_type_traits.h" + namespace NKikimr::NArrow { - -struct TModuloOrZero { - - static constexpr const char * Name = "modOrZero"; - - template <typename TRes, typename TArg0, typename TArg1> - static constexpr EnableIfInteger<TRes> Call(arrow::compute::KernelContext*, TArg0 lhs, TArg1 rhs, arrow::Status*) { - static_assert(std::is_same<TRes, TArg0>::value && std::is_same<TRes, TArg1>::value, ""); - if (ARROW_PREDICT_FALSE(rhs == 0)) { - return 0; - } - return static_cast<TRes>(lhs) % static_cast<TRes>(rhs); - } - - template <typename TRes, typename TArg0, typename TArg1> - static constexpr EnableIfFloatingPoint<TRes> Call(arrow::compute::KernelContext*, TArg0 lhs, TArg1 rhs, arrow::Status*) { - static_assert(std::is_same<TRes, TArg0>::value && std::is_same<TRes, TArg1>::value, ""); - if (static_cast<typename TToInteger<TArg1>::Type::c_type>(rhs) == 0) { - return 0; - } - return static_cast<typename TToInteger<TArg0>::Type::c_type>(lhs) % static_cast<typename TToInteger<TArg1>::Type::c_type>(rhs); - } -}; - -} + +struct TModuloOrZero { + + static constexpr const char * Name = "modOrZero"; + + template <typename TRes, typename TArg0, typename TArg1> + static constexpr EnableIfInteger<TRes> Call(arrow::compute::KernelContext*, TArg0 lhs, TArg1 rhs, arrow::Status*) { + static_assert(std::is_same<TRes, TArg0>::value && std::is_same<TRes, TArg1>::value, ""); + if (ARROW_PREDICT_FALSE(rhs == 0)) { + return 0; + } + return static_cast<TRes>(lhs) % static_cast<TRes>(rhs); + } + + template <typename TRes, typename TArg0, typename TArg1> + static constexpr EnableIfFloatingPoint<TRes> Call(arrow::compute::KernelContext*, TArg0 lhs, TArg1 rhs, arrow::Status*) { + static_assert(std::is_same<TRes, TArg0>::value && std::is_same<TRes, TArg1>::value, ""); + if (static_cast<typename TToInteger<TArg1>::Type::c_type>(rhs) == 0) { + return 0; + } + return static_cast<typename TToInteger<TArg0>::Type::c_type>(lhs) % static_cast<typename TToInteger<TArg1>::Type::c_type>(rhs); + } +}; + +} diff --git a/ydb/core/formats/func_mul.h b/ydb/core/formats/func_mul.h index d9be3198e9..6294353988 100644 --- a/ydb/core/formats/func_mul.h +++ b/ydb/core/formats/func_mul.h @@ -1,56 +1,56 @@ -#include "func_common.h" - -namespace cp = arrow::compute; - +#include "func_common.h" + +namespace cp = arrow::compute; + namespace NKikimr::NArrow { - -template <typename T, typename TUnsigned = typename std::make_unsigned<T>::type> -constexpr TUnsigned ToUnsigned(T sgnd) { - return static_cast<TUnsigned>(sgnd); -} - -struct TMultiply { - static_assert(std::is_same<decltype(int8_t() * int8_t()), int32_t>::value, ""); - static_assert(std::is_same<decltype(uint8_t() * uint8_t()), int32_t>::value, ""); - static_assert(std::is_same<decltype(int16_t() * int16_t()), int32_t>::value, ""); - static_assert(std::is_same<decltype(uint16_t() * uint16_t()), int32_t>::value, ""); - static_assert(std::is_same<decltype(int32_t() * int32_t()), int32_t>::value, ""); - static_assert(std::is_same<decltype(uint32_t() * uint32_t()), uint32_t>::value, ""); - static_assert(std::is_same<decltype(int64_t() * int64_t()), int64_t>::value, ""); - static_assert(std::is_same<decltype(uint64_t() * uint64_t()), uint64_t>::value, ""); - - template <typename T, typename TArg0, typename TArg1> - static constexpr arrow::enable_if_floating_point<T> Call(cp::KernelContext*, T left, T right, - arrow::Status*) { - return left * right; - } - - template <typename T, typename TArg0, typename TArg1> - static constexpr std::enable_if_t<IsUnsignedInteger<T>::value && !std::is_same<T, uint16_t>::value, T> - Call(cp::KernelContext*, T left, T right, arrow::Status*) { - return left * right; - } - - template <typename T, typename TArg0, typename TArg1> - static constexpr std::enable_if_t<IsSignedInteger<T>::value && !std::is_same<T, int16_t>::value, T> - Call(cp::KernelContext*, T left, T right, arrow::Status*) { - return ToUnsigned(left) * ToUnsigned(right); - } - - // Multiplication of 16 bit integer types implicitly promotes to signed 32 bit - // integer. However, some inputs may nevertheless overflow (which triggers undefined - // behaviour). Therefore we first cast to 32 bit unsigned integers where overflow is - // well defined. - template <typename T, typename TArg0, typename TArg1> - static constexpr arrow::enable_if_same<T, int16_t, T> Call(cp::KernelContext*, int16_t left, - int16_t right, arrow::Status*) { - return static_cast<uint32_t>(left) * static_cast<uint32_t>(right); - } - template <typename T, typename TArg0, typename TArg1> - static constexpr arrow::enable_if_same<T, uint16_t, T> Call(cp::KernelContext*, uint16_t left, - uint16_t right, arrow::Status*) { - return static_cast<uint32_t>(left) * static_cast<uint32_t>(right); - } -}; - -} + +template <typename T, typename TUnsigned = typename std::make_unsigned<T>::type> +constexpr TUnsigned ToUnsigned(T sgnd) { + return static_cast<TUnsigned>(sgnd); +} + +struct TMultiply { + static_assert(std::is_same<decltype(int8_t() * int8_t()), int32_t>::value, ""); + static_assert(std::is_same<decltype(uint8_t() * uint8_t()), int32_t>::value, ""); + static_assert(std::is_same<decltype(int16_t() * int16_t()), int32_t>::value, ""); + static_assert(std::is_same<decltype(uint16_t() * uint16_t()), int32_t>::value, ""); + static_assert(std::is_same<decltype(int32_t() * int32_t()), int32_t>::value, ""); + static_assert(std::is_same<decltype(uint32_t() * uint32_t()), uint32_t>::value, ""); + static_assert(std::is_same<decltype(int64_t() * int64_t()), int64_t>::value, ""); + static_assert(std::is_same<decltype(uint64_t() * uint64_t()), uint64_t>::value, ""); + + template <typename T, typename TArg0, typename TArg1> + static constexpr arrow::enable_if_floating_point<T> Call(cp::KernelContext*, T left, T right, + arrow::Status*) { + return left * right; + } + + template <typename T, typename TArg0, typename TArg1> + static constexpr std::enable_if_t<IsUnsignedInteger<T>::value && !std::is_same<T, uint16_t>::value, T> + Call(cp::KernelContext*, T left, T right, arrow::Status*) { + return left * right; + } + + template <typename T, typename TArg0, typename TArg1> + static constexpr std::enable_if_t<IsSignedInteger<T>::value && !std::is_same<T, int16_t>::value, T> + Call(cp::KernelContext*, T left, T right, arrow::Status*) { + return ToUnsigned(left) * ToUnsigned(right); + } + + // Multiplication of 16 bit integer types implicitly promotes to signed 32 bit + // integer. However, some inputs may nevertheless overflow (which triggers undefined + // behaviour). Therefore we first cast to 32 bit unsigned integers where overflow is + // well defined. + template <typename T, typename TArg0, typename TArg1> + static constexpr arrow::enable_if_same<T, int16_t, T> Call(cp::KernelContext*, int16_t left, + int16_t right, arrow::Status*) { + return static_cast<uint32_t>(left) * static_cast<uint32_t>(right); + } + template <typename T, typename TArg0, typename TArg1> + static constexpr arrow::enable_if_same<T, uint16_t, T> Call(cp::KernelContext*, uint16_t left, + uint16_t right, arrow::Status*) { + return static_cast<uint32_t>(left) * static_cast<uint32_t>(right); + } +}; + +} diff --git a/ydb/core/formats/func_round.h b/ydb/core/formats/func_round.h index a65e893622..978b355303 100644 --- a/ydb/core/formats/func_round.h +++ b/ydb/core/formats/func_round.h @@ -1,69 +1,69 @@ -#pragma once +#pragma once #include <contrib/libs/apache/arrow/cpp/src/arrow/type.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/type_traits.h> -#include "func_common.h" -#include "bit_cast.h" - -#include <cmath> -#include <cstdint> -#include <fenv.h> -#include <type_traits> - +#include "func_common.h" +#include "bit_cast.h" + +#include <cmath> +#include <cstdint> +#include <fenv.h> +#include <type_traits> + namespace NKikimr::NArrow { - -struct TRound { - - static constexpr const char * Name = "round"; - - template <typename TRes, typename TArg> - static constexpr TRes Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - return std::round(arg); - } - -}; - - -struct TRoundBankers { - - static constexpr const char * Name = "roundBankers"; - - template <typename TRes, typename TArg> - static constexpr TRes Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - fesetround(FE_TONEAREST); - return std::rint(arg); - } -}; - -struct TRoundToExp2 { - static constexpr const char * Name = "roundToExp2"; - - template <typename TRes, typename TArg> - static constexpr std::enable_if_t<std::is_integral_v<TRes> && - (sizeof(TRes) <= sizeof(uint32_t)), TRes> - Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - static_assert(std::is_same_v<TRes, TArg>, ""); - return arg <= 0 ? 0 : (TRes(1) << (31 - __builtin_clz(arg))); - } - - template <typename TRes, typename TArg> - static constexpr std::enable_if_t<std::is_integral_v<TRes> && - (sizeof(TRes) == sizeof(uint64_t)), TRes> - Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - static_assert(std::is_same_v<TRes, TArg>, ""); - return arg <= 0 ? 0 : (TRes(1) << (63 - __builtin_clzll(arg))); - } - - template <typename TRes, typename TArg> - static constexpr EnableIfFloat32<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - static_assert(std::is_same_v<TRes, TArg>, ""); - return bit_cast<TRes>(bit_cast<uint32_t>(arg) & ~((1ULL << 23) - 1)); - } - - template <typename TRes, typename TArg> - static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { - static_assert(std::is_same_v<TRes, TArg>, ""); - return bit_cast<TRes>(bit_cast<uint64_t>(arg) & ~((1ULL << 52) - 1)); - } -}; - + +struct TRound { + + static constexpr const char * Name = "round"; + + template <typename TRes, typename TArg> + static constexpr TRes Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + return std::round(arg); + } + +}; + + +struct TRoundBankers { + + static constexpr const char * Name = "roundBankers"; + + template <typename TRes, typename TArg> + static constexpr TRes Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + fesetround(FE_TONEAREST); + return std::rint(arg); + } +}; + +struct TRoundToExp2 { + static constexpr const char * Name = "roundToExp2"; + + template <typename TRes, typename TArg> + static constexpr std::enable_if_t<std::is_integral_v<TRes> && + (sizeof(TRes) <= sizeof(uint32_t)), TRes> + Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + static_assert(std::is_same_v<TRes, TArg>, ""); + return arg <= 0 ? 0 : (TRes(1) << (31 - __builtin_clz(arg))); + } + + template <typename TRes, typename TArg> + static constexpr std::enable_if_t<std::is_integral_v<TRes> && + (sizeof(TRes) == sizeof(uint64_t)), TRes> + Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + static_assert(std::is_same_v<TRes, TArg>, ""); + return arg <= 0 ? 0 : (TRes(1) << (63 - __builtin_clzll(arg))); + } + + template <typename TRes, typename TArg> + static constexpr EnableIfFloat32<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + static_assert(std::is_same_v<TRes, TArg>, ""); + return bit_cast<TRes>(bit_cast<uint32_t>(arg) & ~((1ULL << 23) - 1)); + } + + template <typename TRes, typename TArg> + static constexpr EnableIfFloat64<TRes> Call(arrow::compute::KernelContext*, TArg arg, arrow::Status*) { + static_assert(std::is_same_v<TRes, TArg>, ""); + return bit_cast<TRes>(bit_cast<uint64_t>(arg) & ~((1ULL << 52) - 1)); + } +}; + } diff --git a/ydb/core/formats/function_factory.h b/ydb/core/formats/function_factory.h index 4c61a3387c..803a7c2e5a 100644 --- a/ydb/core/formats/function_factory.h +++ b/ydb/core/formats/function_factory.h @@ -1,75 +1,75 @@ #pragma once - + #include <ydb/core/formats/arrow/compute/registry.h> #include <ydb/core/formats/arrow/result.h> #include <ydb/core/formats/arrow/status.h> -#include <algorithm> -#include <memory> -#include <shared_mutex> -#include <string> -#include <unordered_map> -#include <utility> - +#include <algorithm> +#include <memory> +#include <shared_mutex> +#include <string> +#include <unordered_map> +#include <utility> + #include <ydb/core/formats/arrow/compute/function.h> - + namespace NKikimr::NArrow { - -template <bool mt = false> -class TFunctionFactory { -public: - - TFunctionFactory() = default; - - arrow::Status AddFunction(std::shared_ptr<arrow::compute::Function> function, bool allowOverwrite) { - auto it = nameToFunc.find(function->name()); - if (it != nameToFunc.end() && !allowOverwrite) { - return arrow::Status::KeyError("Already have a function registered with name: ", function->name()); - } - nameToFunc[function->name()] = std::move(function); - return arrow::Status::OK(); - } - - - arrow::Result<std::shared_ptr<arrow::compute::Function>> GetFunction(const std::string& name) const { - auto it = nameToFunc.find(name); - if (it == nameToFunc.end()) { - return arrow::compute::GetFunctionRegistry()->GetFunction(name); - } - return it->second; - } -private: - std::unordered_map<std::string, std::shared_ptr<arrow::compute::Function>> nameToFunc; -}; - -template <> -class TFunctionFactory<true> { -public: - - TFunctionFactory() = default; - - arrow::Status AddFunction(std::shared_ptr<arrow::compute::Function> function, bool allowOverwrite) { - std::unique_lock guard(lock); - auto it = nameToFunc.find(function->name()); - if (it != nameToFunc.end() && !allowOverwrite) { - return arrow::Status::KeyError("Already have a function registered with name: ", function->name()); - } - nameToFunc[function->name()] = std::move(function); - return arrow::Status::OK(); - } - - - arrow::Result<std::shared_ptr<arrow::compute::Function>> GetFunction(const std::string& name) const { - std::shared_lock guard(lock); - auto it = nameToFunc.find(name); - if (it == nameToFunc.end()) { - return arrow::compute::GetFunctionRegistry()->GetFunction(name); - } - return it->second; - } -private: - mutable std::shared_mutex lock; - std::unordered_map<std::string, std::shared_ptr<arrow::compute::Function>> nameToFunc; -}; - + +template <bool mt = false> +class TFunctionFactory { +public: + + TFunctionFactory() = default; + + arrow::Status AddFunction(std::shared_ptr<arrow::compute::Function> function, bool allowOverwrite) { + auto it = nameToFunc.find(function->name()); + if (it != nameToFunc.end() && !allowOverwrite) { + return arrow::Status::KeyError("Already have a function registered with name: ", function->name()); + } + nameToFunc[function->name()] = std::move(function); + return arrow::Status::OK(); + } + + + arrow::Result<std::shared_ptr<arrow::compute::Function>> GetFunction(const std::string& name) const { + auto it = nameToFunc.find(name); + if (it == nameToFunc.end()) { + return arrow::compute::GetFunctionRegistry()->GetFunction(name); + } + return it->second; + } +private: + std::unordered_map<std::string, std::shared_ptr<arrow::compute::Function>> nameToFunc; +}; + +template <> +class TFunctionFactory<true> { +public: + + TFunctionFactory() = default; + + arrow::Status AddFunction(std::shared_ptr<arrow::compute::Function> function, bool allowOverwrite) { + std::unique_lock guard(lock); + auto it = nameToFunc.find(function->name()); + if (it != nameToFunc.end() && !allowOverwrite) { + return arrow::Status::KeyError("Already have a function registered with name: ", function->name()); + } + nameToFunc[function->name()] = std::move(function); + return arrow::Status::OK(); + } + + + arrow::Result<std::shared_ptr<arrow::compute::Function>> GetFunction(const std::string& name) const { + std::shared_lock guard(lock); + auto it = nameToFunc.find(name); + if (it == nameToFunc.end()) { + return arrow::compute::GetFunctionRegistry()->GetFunction(name); + } + return it->second; + } +private: + mutable std::shared_mutex lock; + std::unordered_map<std::string, std::shared_ptr<arrow::compute::Function>> nameToFunc; +}; + } diff --git a/ydb/core/formats/functions.h b/ydb/core/formats/functions.h index 4c0ecda6be..4b8ad3a056 100644 --- a/ydb/core/formats/functions.h +++ b/ydb/core/formats/functions.h @@ -1,8 +1,8 @@ -#pragma once - -#include "func_gcd.h" -#include "func_lcm.h" -#include "func_modulo.h" -#include "func_modulo_or_zero.h" -#include "func_math.h" -#include "func_round.h" +#pragma once + +#include "func_gcd.h" +#include "func_lcm.h" +#include "func_modulo.h" +#include "func_modulo_or_zero.h" +#include "func_math.h" +#include "func_round.h" diff --git a/ydb/core/formats/program.cpp b/ydb/core/formats/program.cpp index 3bc39be6fe..402bc1900f 100644 --- a/ydb/core/formats/program.cpp +++ b/ydb/core/formats/program.cpp @@ -1,13 +1,13 @@ -#include <memory> -#include <unordered_map> -#include <vector> -#include <cstdint> -#include <algorithm> - +#include <memory> +#include <unordered_map> +#include <vector> +#include <cstdint> +#include <algorithm> + #include "program.h" #include "arrow_helpers.h" -#include <util/system/yassert.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> +#include <util/system/yassert.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_base.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/array/builder_primitive.h> @@ -16,7 +16,7 @@ #include <ydb/core/util/yverify_stream.h> namespace NKikimr::NArrow { - + const char * GetFunctionName(EOperation op) { switch (op) { case EOperation::CastBoolean: @@ -83,18 +83,18 @@ const char * GetFunctionName(EOperation op) { return "multiply"; case EOperation::Divide: return "divide"; - case EOperation::Abs: - return "abs"; - case EOperation::Negate: - return "negate"; - case EOperation::Gcd: - return "gcd"; - case EOperation::Lcm: - return "lcm"; - case EOperation::Modulo: - return "mod"; - case EOperation::ModuloOrZero: - return "modOrZero"; + case EOperation::Abs: + return "abs"; + case EOperation::Negate: + return "negate"; + case EOperation::Gcd: + return "gcd"; + case EOperation::Lcm: + return "lcm"; + case EOperation::Modulo: + return "mod"; + case EOperation::ModuloOrZero: + return "modOrZero"; case EOperation::AddNotNull: return "add_checked"; case EOperation::SubtractNotNull: @@ -109,52 +109,52 @@ const char * GetFunctionName(EOperation op) { case EOperation::MatchSubstring: return "match_substring"; - case EOperation::Acosh: - return "acosh"; - case EOperation::Atanh: - return "atanh"; - case EOperation::Cbrt: - return "cbrt"; - case EOperation::Cosh: - return "cosh"; - case EOperation::E: - return "e"; - case EOperation::Erf: - return "erf"; - case EOperation::Erfc: - return "erfc"; - case EOperation::Exp: - return "exp"; - case EOperation::Exp2: - return "exp2"; - case EOperation::Exp10: - return "exp10"; - case EOperation::Hypot: - return "hypot"; - case EOperation::Lgamma: - return "lgamma"; - case EOperation::Pi: - return "pi"; - case EOperation::Sinh: - return "sinh"; - case EOperation::Sqrt: - return "sqrt"; - case EOperation::Tgamma: - return "tgamma"; - - case EOperation::Floor: - return "floor"; - case EOperation::Ceil: - return "ceil"; - case EOperation::Trunc: - return "trunc"; - case EOperation::Round: - return "round"; - case EOperation::RoundBankers: - return "roundBankers"; - case EOperation::RoundToExp2: - return "roundToExp2"; - + case EOperation::Acosh: + return "acosh"; + case EOperation::Atanh: + return "atanh"; + case EOperation::Cbrt: + return "cbrt"; + case EOperation::Cosh: + return "cosh"; + case EOperation::E: + return "e"; + case EOperation::Erf: + return "erf"; + case EOperation::Erfc: + return "erfc"; + case EOperation::Exp: + return "exp"; + case EOperation::Exp2: + return "exp2"; + case EOperation::Exp10: + return "exp10"; + case EOperation::Hypot: + return "hypot"; + case EOperation::Lgamma: + return "lgamma"; + case EOperation::Pi: + return "pi"; + case EOperation::Sinh: + return "sinh"; + case EOperation::Sqrt: + return "sqrt"; + case EOperation::Tgamma: + return "tgamma"; + + case EOperation::Floor: + return "floor"; + case EOperation::Ceil: + return "ceil"; + case EOperation::Trunc: + return "trunc"; + case EOperation::Round: + return "round"; + case EOperation::RoundBankers: + return "roundBankers"; + case EOperation::RoundToExp2: + return "roundToExp2"; + // TODO: "is_in", "index_in" default: @@ -163,160 +163,160 @@ const char * GetFunctionName(EOperation op) { return ""; } -void AddColumn(std::shared_ptr<TProgramStep::TDatumBatch>& batch, - std::string field_name, - const arrow::Datum& column) { - auto field = ::arrow::field(std::move(field_name), column.type()); - Y_VERIFY(field != nullptr); - Y_VERIFY(field->type()->Equals(column.type())); - Y_VERIFY(column.is_scalar() || column.length() == batch->rows); - auto new_schema = *batch->fields->AddField(batch->fields->num_fields(), field); - batch->datums.push_back(column); - batch->fields = new_schema; -} - -arrow::Result<arrow::Datum> GetColumnByName(const std::shared_ptr<TProgramStep::TDatumBatch>& batch, const std::string& name) { - int i = batch->fields->GetFieldIndex(name); - if (i == -1) { - return arrow::Status::Invalid("Not found or duplicate"); - } - else { - return batch->datums[i]; - } -} - -std::shared_ptr<TProgramStep::TDatumBatch> ToTDatumBatch(std::shared_ptr<arrow::RecordBatch>& batch) { - std::vector<arrow::Datum> datums; - datums.reserve(batch->num_columns()); - for (int64_t i = 0; i < batch->num_columns(); ++i) { - datums.push_back(arrow::Datum(batch->column(i))); - } - return std::make_shared<TProgramStep::TDatumBatch>(TProgramStep::TDatumBatch{std::make_shared<arrow::Schema>(*batch->schema()), batch->num_rows(), std::move(datums)}); -} - -std::shared_ptr<arrow::RecordBatch> ToRecordBatch(std::shared_ptr<TProgramStep::TDatumBatch>& batch) { - std::vector<std::shared_ptr<arrow::Array>> columns; - columns.reserve(batch->datums.size()); - for (auto col : batch->datums) { - if (col.is_scalar()) { - columns.push_back(*arrow::MakeArrayFromScalar(*col.scalar(), batch->rows)); - } - else if (col.is_array()){ - Y_VERIFY(col.length() != -1); - columns.push_back(col.make_array()); - } - } - return arrow::RecordBatch::Make(batch->fields, batch->rows, columns); -} - - +void AddColumn(std::shared_ptr<TProgramStep::TDatumBatch>& batch, + std::string field_name, + const arrow::Datum& column) { + auto field = ::arrow::field(std::move(field_name), column.type()); + Y_VERIFY(field != nullptr); + Y_VERIFY(field->type()->Equals(column.type())); + Y_VERIFY(column.is_scalar() || column.length() == batch->rows); + auto new_schema = *batch->fields->AddField(batch->fields->num_fields(), field); + batch->datums.push_back(column); + batch->fields = new_schema; +} + +arrow::Result<arrow::Datum> GetColumnByName(const std::shared_ptr<TProgramStep::TDatumBatch>& batch, const std::string& name) { + int i = batch->fields->GetFieldIndex(name); + if (i == -1) { + return arrow::Status::Invalid("Not found or duplicate"); + } + else { + return batch->datums[i]; + } +} + +std::shared_ptr<TProgramStep::TDatumBatch> ToTDatumBatch(std::shared_ptr<arrow::RecordBatch>& batch) { + std::vector<arrow::Datum> datums; + datums.reserve(batch->num_columns()); + for (int64_t i = 0; i < batch->num_columns(); ++i) { + datums.push_back(arrow::Datum(batch->column(i))); + } + return std::make_shared<TProgramStep::TDatumBatch>(TProgramStep::TDatumBatch{std::make_shared<arrow::Schema>(*batch->schema()), batch->num_rows(), std::move(datums)}); +} + +std::shared_ptr<arrow::RecordBatch> ToRecordBatch(std::shared_ptr<TProgramStep::TDatumBatch>& batch) { + std::vector<std::shared_ptr<arrow::Array>> columns; + columns.reserve(batch->datums.size()); + for (auto col : batch->datums) { + if (col.is_scalar()) { + columns.push_back(*arrow::MakeArrayFromScalar(*col.scalar(), batch->rows)); + } + else if (col.is_array()){ + Y_VERIFY(col.length() != -1); + columns.push_back(col.make_array()); + } + } + return arrow::RecordBatch::Make(batch->fields, batch->rows, columns); +} + + std::shared_ptr<arrow::Array> MakeConstantColumn(const arrow::Scalar& value, int64_t size) { auto res = arrow::MakeArrayFromScalar(value, size); Y_VERIFY(res.ok()); return *res; } -//firstly try to call function from custom registry, if fails call from default -arrow::Result<arrow::Datum> CallFromCustomOrDefaultRegistry(EOperation funcId, const std::vector<arrow::Datum>& arguments, arrow::compute::ExecContext* ctx) { - std::string funcName = GetFunctionName(funcId); - if (ctx != nullptr && ctx->func_registry()->GetFunction(funcName).ok()) { - return arrow::compute::CallFunction(GetFunctionName(funcId), arguments, ctx); - } else { - return arrow::compute::CallFunction(GetFunctionName(funcId), arguments); - } -} - -std::shared_ptr<arrow::Array> CallArrayFunction(EOperation funcId, const std::vector<std::string>& args, - std::shared_ptr<arrow::RecordBatch> batch, arrow::compute::ExecContext* ctx) { +//firstly try to call function from custom registry, if fails call from default +arrow::Result<arrow::Datum> CallFromCustomOrDefaultRegistry(EOperation funcId, const std::vector<arrow::Datum>& arguments, arrow::compute::ExecContext* ctx) { + std::string funcName = GetFunctionName(funcId); + if (ctx != nullptr && ctx->func_registry()->GetFunction(funcName).ok()) { + return arrow::compute::CallFunction(GetFunctionName(funcId), arguments, ctx); + } else { + return arrow::compute::CallFunction(GetFunctionName(funcId), arguments); + } +} + +std::shared_ptr<arrow::Array> CallArrayFunction(EOperation funcId, const std::vector<std::string>& args, + std::shared_ptr<arrow::RecordBatch> batch, arrow::compute::ExecContext* ctx) { std::vector<arrow::Datum> arguments; arguments.reserve(args.size()); for (auto& colName : args) { auto column = batch->GetColumnByName(colName); Y_VERIFY(column); - arguments.push_back(arrow::Datum(*column)); + arguments.push_back(arrow::Datum(*column)); } - std::string funcName = GetFunctionName(funcId); - arrow::Result<arrow::Datum> result; - result = CallFromCustomOrDefaultRegistry(funcId, arguments, ctx); + std::string funcName = GetFunctionName(funcId); + arrow::Result<arrow::Datum> result; + result = CallFromCustomOrDefaultRegistry(funcId, arguments, ctx); + Y_VERIFY(result.ok()); + Y_VERIFY(result->is_array()); + return result->make_array(); +} + + +std::shared_ptr<arrow::Scalar> CallScalarFunction(EOperation funcId, const std::vector<std::string>& args, + std::shared_ptr<arrow::RecordBatch> batch, arrow::compute::ExecContext* ctx) { + std::vector<arrow::Datum> arguments; + arguments.reserve(args.size()); + + for (auto& colName : args) { + auto column = batch->GetColumnByName(colName); + Y_VERIFY(column); + arguments.push_back(arrow::Datum{column}); + } + std::string funcName = GetFunctionName(funcId); + arrow::Result<arrow::Datum> result; + result = CallFromCustomOrDefaultRegistry(funcId, arguments, ctx); Y_VERIFY(result.ok()); - Y_VERIFY(result->is_array()); - return result->make_array(); + Y_VERIFY(result->is_scalar()); + return result->scalar(); } - -std::shared_ptr<arrow::Scalar> CallScalarFunction(EOperation funcId, const std::vector<std::string>& args, - std::shared_ptr<arrow::RecordBatch> batch, arrow::compute::ExecContext* ctx) { - std::vector<arrow::Datum> arguments; - arguments.reserve(args.size()); - - for (auto& colName : args) { - auto column = batch->GetColumnByName(colName); - Y_VERIFY(column); - arguments.push_back(arrow::Datum{column}); - } - std::string funcName = GetFunctionName(funcId); - arrow::Result<arrow::Datum> result; - result = CallFromCustomOrDefaultRegistry(funcId, arguments, ctx); - Y_VERIFY(result.ok()); - Y_VERIFY(result->is_scalar()); - return result->scalar(); -} - -arrow::Datum CallFunctionById(EOperation funcId, const std::vector<std::string>& args, - std::shared_ptr<TProgramStep::TDatumBatch> batch, arrow::compute::ExecContext* ctx) { - std::vector<arrow::Datum> arguments; - arguments.reserve(args.size()); - - for (auto& colName : args) { - auto column = GetColumnByName(batch, colName); - Y_VERIFY(column.ok()); - arguments.push_back(*column); - } - std::string funcName = GetFunctionName(funcId); - arrow::Result<arrow::Datum> result; - if (ctx != nullptr && ctx->func_registry()->GetFunction(funcName).ok()) { - result = arrow::compute::CallFunction(GetFunctionName(funcId), arguments, ctx); - } else { - result = arrow::compute::CallFunction(GetFunctionName(funcId), arguments); - } - Y_VERIFY(result.ok()); - return result.ValueOrDie(); +arrow::Datum CallFunctionById(EOperation funcId, const std::vector<std::string>& args, + std::shared_ptr<TProgramStep::TDatumBatch> batch, arrow::compute::ExecContext* ctx) { + std::vector<arrow::Datum> arguments; + arguments.reserve(args.size()); + + for (auto& colName : args) { + auto column = GetColumnByName(batch, colName); + Y_VERIFY(column.ok()); + arguments.push_back(*column); + } + std::string funcName = GetFunctionName(funcId); + arrow::Result<arrow::Datum> result; + if (ctx != nullptr && ctx->func_registry()->GetFunction(funcName).ok()) { + result = arrow::compute::CallFunction(GetFunctionName(funcId), arguments, ctx); + } else { + result = arrow::compute::CallFunction(GetFunctionName(funcId), arguments); + } + Y_VERIFY(result.ok()); + return result.ValueOrDie(); } - - -void TProgramStep::ApplyAssignes(std::shared_ptr<TProgramStep::TDatumBatch>& batch, arrow::compute::ExecContext* ctx) const { + + +void TProgramStep::ApplyAssignes(std::shared_ptr<TProgramStep::TDatumBatch>& batch, arrow::compute::ExecContext* ctx) const { if (Assignes.empty()) { return; } - batch->datums.reserve(batch->datums.size() + Assignes.size()); + batch->datums.reserve(batch->datums.size() + Assignes.size()); for (auto& assign : Assignes) { - Y_VERIFY(!GetColumnByName(batch, assign.GetName()).ok()); + Y_VERIFY(!GetColumnByName(batch, assign.GetName()).ok()); - arrow::Datum column; + arrow::Datum column; if (assign.IsConstant()) { - column = assign.GetConstant(); + column = assign.GetConstant(); } else { - column = CallFunctionById(assign.GetOperation(), assign.GetArguments(), batch, ctx); + column = CallFunctionById(assign.GetOperation(), assign.GetArguments(), batch, ctx); } - AddColumn(batch, assign.GetName(), column); + AddColumn(batch, assign.GetName(), column); } - //Y_VERIFY(batch->Validate().ok()); + //Y_VERIFY(batch->Validate().ok()); } -void TProgramStep::ApplyFilters(std::shared_ptr<TDatumBatch>& batch) const { +void TProgramStep::ApplyFilters(std::shared_ptr<TDatumBatch>& batch) const { if (Filters.empty()) { return; } std::vector<std::vector<bool>> filters; filters.reserve(Filters.size()); for (auto& colName : Filters) { - auto column = GetColumnByName(batch, colName); - Y_VERIFY(column.ok()); - Y_VERIFY(column->is_array()); - Y_VERIFY(column->type() == arrow::boolean()); - auto boolColumn = std::static_pointer_cast<arrow::BooleanArray>(column->make_array()); + auto column = GetColumnByName(batch, colName); + Y_VERIFY(column.ok()); + Y_VERIFY(column->is_array()); + Y_VERIFY(column->type() == arrow::boolean()); + auto boolColumn = std::static_pointer_cast<arrow::BooleanArray>(column->make_array()); filters.push_back(std::vector<bool>(boolColumn->length())); auto& bits = filters.back(); for (size_t i = 0; i < bits.size(); ++i) { @@ -330,52 +330,52 @@ void TProgramStep::ApplyFilters(std::shared_ptr<TDatumBatch>& batch) const { } if (bits.size()) { - auto filter = NArrow::MakeFilter(bits); - std::unordered_set<std::string_view> projSet; - for (auto& str: Projection) { - projSet.insert(str); - } - for (int64_t i = 0; i < batch->fields->num_fields(); ++i) { - //only array filtering, scalar cannot be filtered - auto& cur_field_name = batch->fields->field(i)->name(); - bool is_proj = (Projection.empty() || projSet.contains(cur_field_name)); - if (batch->datums[i].is_array() && is_proj) { - auto res = arrow::compute::Filter(batch->datums[i].make_array(), filter); - Y_VERIFY_S(res.ok(), res.status().message()); - Y_VERIFY((*res).kind() == batch->datums[i].kind()); - batch->datums[i] = *res; - } - } - int newRows = 0; - for (int64_t i = 0; i < filter->length(); ++i) { - newRows += filter->Value(i); - } - batch->rows = newRows; + auto filter = NArrow::MakeFilter(bits); + std::unordered_set<std::string_view> projSet; + for (auto& str: Projection) { + projSet.insert(str); + } + for (int64_t i = 0; i < batch->fields->num_fields(); ++i) { + //only array filtering, scalar cannot be filtered + auto& cur_field_name = batch->fields->field(i)->name(); + bool is_proj = (Projection.empty() || projSet.contains(cur_field_name)); + if (batch->datums[i].is_array() && is_proj) { + auto res = arrow::compute::Filter(batch->datums[i].make_array(), filter); + Y_VERIFY_S(res.ok(), res.status().message()); + Y_VERIFY((*res).kind() == batch->datums[i].kind()); + batch->datums[i] = *res; + } + } + int newRows = 0; + for (int64_t i = 0; i < filter->length(); ++i) { + newRows += filter->Value(i); + } + batch->rows = newRows; } } -void TProgramStep::ApplyProjection(std::shared_ptr<TDatumBatch>& batch) const { - if (Projection.empty()) { - return; - } - std::unordered_set<std::string_view> projSet; - for (auto& str: Projection) { - projSet.insert(str); - } - std::vector<std::shared_ptr<arrow::Field>> newFields; - std::vector<arrow::Datum> newDatums; - for (int64_t i = 0; i < batch->fields->num_fields(); ++i) { - auto& cur_field_name = batch->fields->field(i)->name(); - if (projSet.contains(cur_field_name)) { - newFields.push_back(batch->fields->field(i)); - Y_VERIFY(newFields.back()); - newDatums.push_back(batch->datums[i]); - } - } - batch->fields = std::make_shared<arrow::Schema>(newFields); - batch->datums = std::move(newDatums); -} - +void TProgramStep::ApplyProjection(std::shared_ptr<TDatumBatch>& batch) const { + if (Projection.empty()) { + return; + } + std::unordered_set<std::string_view> projSet; + for (auto& str: Projection) { + projSet.insert(str); + } + std::vector<std::shared_ptr<arrow::Field>> newFields; + std::vector<arrow::Datum> newDatums; + for (int64_t i = 0; i < batch->fields->num_fields(); ++i) { + auto& cur_field_name = batch->fields->field(i)->name(); + if (projSet.contains(cur_field_name)) { + newFields.push_back(batch->fields->field(i)); + Y_VERIFY(newFields.back()); + newDatums.push_back(batch->datums[i]); + } + } + batch->fields = std::make_shared<arrow::Schema>(newFields); + batch->datums = std::move(newDatums); +} + void TProgramStep::ApplyProjection(std::shared_ptr<arrow::RecordBatch>& batch) const { if (Projection.empty()) { return; @@ -389,12 +389,12 @@ void TProgramStep::ApplyProjection(std::shared_ptr<arrow::RecordBatch>& batch) c batch = NArrow::ExtractColumns(batch, std::make_shared<arrow::Schema>(fields)); } -void TProgramStep::Apply(std::shared_ptr<arrow::RecordBatch>& batch, arrow::compute::ExecContext* ctx) const { - auto rb = ToTDatumBatch(batch); - ApplyAssignes(rb, ctx); - ApplyFilters(rb); - ApplyProjection(rb); - batch = ToRecordBatch(rb); -} - +void TProgramStep::Apply(std::shared_ptr<arrow::RecordBatch>& batch, arrow::compute::ExecContext* ctx) const { + auto rb = ToTDatumBatch(batch); + ApplyAssignes(rb, ctx); + ApplyFilters(rb); + ApplyProjection(rb); + batch = ToRecordBatch(rb); +} + } diff --git a/ydb/core/formats/program.h b/ydb/core/formats/program.h index ff15d30ebf..9d8a2c027d 100644 --- a/ydb/core/formats/program.h +++ b/ydb/core/formats/program.h @@ -1,5 +1,5 @@ #pragma once -#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/exec.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/exec.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> #include <util/system/types.h> @@ -38,17 +38,17 @@ enum class EOperation { And, Or, Xor, - // + // Add, Subtract, Multiply, Divide, - Abs, - Negate, - Gcd, - Lcm, - Modulo, - ModuloOrZero, + Abs, + Negate, + Gcd, + Lcm, + Modulo, + ModuloOrZero, AddNotNull, SubtractNotNull, MultiplyNotNull, @@ -56,30 +56,30 @@ enum class EOperation { // BinaryLength, MatchSubstring, - // math - Acosh, - Atanh, - Cbrt, - Cosh, - E, - Erf, - Erfc, - Exp, - Exp2, - Exp10, - Hypot, - Lgamma, - Pi, - Sinh, - Sqrt, - Tgamma, - // round - Floor, - Ceil, - Trunc, - Round, - RoundBankers, - RoundToExp2 + // math + Acosh, + Atanh, + Cbrt, + Cosh, + E, + Erf, + Erfc, + Exp, + Exp2, + Exp10, + Hypot, + Lgamma, + Pi, + Sinh, + Sqrt, + Tgamma, + // round + Floor, + Ceil, + Trunc, + Round, + RoundBankers, + RoundToExp2 }; const char * GetFunctionName(EOperation op); @@ -174,29 +174,29 @@ struct TProgramStep { std::vector<std::string> Projection; // Step's result columns (remove others) // TODO: group by - struct TDatumBatch { - std::shared_ptr<arrow::Schema> fields; - int64_t rows; - std::vector<arrow::Datum> datums; - }; - + struct TDatumBatch { + std::shared_ptr<arrow::Schema> fields; + int64_t rows; + std::vector<arrow::Datum> datums; + }; + bool Empty() const { return Assignes.empty() && Filters.empty() && Projection.empty(); } - void Apply(std::shared_ptr<arrow::RecordBatch>& batch, arrow::compute::ExecContext* ctx) const; + void Apply(std::shared_ptr<arrow::RecordBatch>& batch, arrow::compute::ExecContext* ctx) const; - void ApplyAssignes(std::shared_ptr<TDatumBatch>& batch, arrow::compute::ExecContext* ctx) const; - void ApplyFilters(std::shared_ptr<TDatumBatch>& batch) const; + void ApplyAssignes(std::shared_ptr<TDatumBatch>& batch, arrow::compute::ExecContext* ctx) const; + void ApplyFilters(std::shared_ptr<TDatumBatch>& batch) const; void ApplyProjection(std::shared_ptr<arrow::RecordBatch>& batch) const; - void ApplyProjection(std::shared_ptr<TDatumBatch>& batch) const; + void ApplyProjection(std::shared_ptr<TDatumBatch>& batch) const; }; inline void ApplyProgram(std::shared_ptr<arrow::RecordBatch>& batch, - const std::vector<std::shared_ptr<TProgramStep>>& program, - arrow::compute::ExecContext* ctx = nullptr) { + const std::vector<std::shared_ptr<TProgramStep>>& program, + arrow::compute::ExecContext* ctx = nullptr) { for (auto& step : program) { - step->Apply(batch, ctx); + step->Apply(batch, ctx); } } diff --git a/ydb/core/formats/ut/ya.make b/ydb/core/formats/ut/ya.make index d58f1f4f4f..878516a600 100644 --- a/ydb/core/formats/ut/ya.make +++ b/ydb/core/formats/ut/ya.make @@ -30,11 +30,11 @@ CFLAGS( SRCS( ut_arrow.cpp - ut_arithmetic.cpp - ut_math.cpp - ut_round.cpp - ut_program_step.cpp - custom_registry.cpp + ut_arithmetic.cpp + ut_math.cpp + ut_round.cpp + ut_program_step.cpp + custom_registry.cpp ) END() diff --git a/ydb/core/formats/ut_arithmetic.cpp b/ydb/core/formats/ut_arithmetic.cpp index ffffad99cc..1eff1bd338 100644 --- a/ydb/core/formats/ut_arithmetic.cpp +++ b/ydb/core/formats/ut_arithmetic.cpp @@ -1,372 +1,372 @@ -#include <cmath> -#include <cstdint> -#include <iterator> -#include <library/cpp/testing/unittest/registar.h> -#include <ctime> -#include <vector> -#include <algorithm> - -#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h> - -#include "func_common.h" -#include "functions.h" -#include "custom_registry.h" -#include "arrow_helpers.h" - +#include <cmath> +#include <cstdint> +#include <iterator> +#include <library/cpp/testing/unittest/registar.h> +#include <ctime> +#include <vector> +#include <algorithm> + +#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h> + +#include "func_common.h" +#include "functions.h" +#include "custom_registry.h" +#include "arrow_helpers.h" + namespace NKikimr::NArrow { - -namespace cp = ::arrow::compute; -using cp::internal::applicator::ScalarBinary; -using cp::internal::applicator::ScalarUnary; - -const std::vector<std::shared_ptr<arrow::DataType>> nonPrimitiveTypes = { - arrow::list(arrow::utf8()), - arrow::list(arrow::int64()), - arrow::large_list(arrow::large_utf8()), - arrow::fixed_size_list(arrow::utf8(), 3), - arrow::fixed_size_list(arrow::int64(), 4), - arrow::dictionary(arrow::int32(), arrow::utf8()) -}; - -const std::vector<std::shared_ptr<arrow::DataType>> decimalTypes = { - arrow::decimal128(12, 2), - arrow::decimal256(12, 2) -}; - - -template<typename TType> -EnableIfSigned<TType> MakeRandomNum() { - double lowerLimit = -120; - double upperLimit = 120; - return static_cast<TType>(double(rand()) * (upperLimit - lowerLimit) / RAND_MAX + lowerLimit); -} - -template<typename TType> -EnableIfUnsigned<TType> MakeRandomNum() { - double lowerLimit = 0; - double upperLimit = 120; - return static_cast<TType>(double(rand()) * (upperLimit - lowerLimit) / RAND_MAX + lowerLimit); -} - -template<typename TType> -std::shared_ptr<TArray<TType>> GenerateTestArray(int64_t sz) { - std::srand(std::time(nullptr)); - TBuilder<TType> builder; - std::shared_ptr<TArray<TType>> res; - arrow::Status st; - UNIT_ASSERT(builder.Reserve(sz).ok()); - for (int64_t i = 0; i < sz; ++i) { - UNIT_ASSERT(builder.Append(MakeRandomNum<typename TType::c_type>()).ok()); - } - UNIT_ASSERT(builder.Finish(&res).ok()); - return res; -} - -void TestWrongTypeUnary(const std::string& func_name, std::shared_ptr<arrow::DataType> type, cp::ExecContext* exc = nullptr) { - auto res = arrow::compute::CallFunction(func_name, {arrow::MakeArrayOfNull(type, 1).ValueOrDie()}, exc); - UNIT_ASSERT_EQUAL(res.ok(), false); -} - -void TestWrongTypeBinary(const std::string& func_name, std::shared_ptr<arrow::DataType> type0, - std::shared_ptr<arrow::DataType> type1, - cp::ExecContext* exc = nullptr) { - auto res = arrow::compute::CallFunction(func_name, {arrow::MakeArrayOfNull(type0, 1).ValueOrDie(), - arrow::MakeArrayOfNull(type1, 1).ValueOrDie()}, - exc); - UNIT_ASSERT_EQUAL(res.ok(), false); -} - -std::shared_ptr<arrow::Array> MakeBooleanArray(const std::vector<bool>& arr) { - arrow::BooleanBuilder builder; - std::shared_ptr<arrow::BooleanArray> res; - UNIT_ASSERT(builder.Reserve(arr.size()).ok()); - for (size_t i = 0; i < arr.size(); ++i) { - UNIT_ASSERT(builder.Append(arr[i]).ok()); - } - UNIT_ASSERT(builder.Finish(&res).ok()); - return res; -} - -Y_UNIT_TEST_SUITE(ArrowAbsTest) { - Y_UNIT_TEST(AbsSignedInts) { - for (auto type : cp::internal::SignedIntTypes()) { - auto arg0 = NumVecToArray(type, {-32, -12, 54}); - auto expected = NumVecToArray(type, {32, 12, 54}); - auto res = arrow::compute::CallFunction("abs", {arg0}); - UNIT_ASSERT(res->Equals(expected)); - } - } - - Y_UNIT_TEST(AbsUnsignedInts) { - for (auto type : cp::internal::UnsignedIntTypes()) { - auto arg0 = NumVecToArray(type, {32, 12, 54}); - auto expected = NumVecToArray(type, {32, 12, 54}); - auto res = arrow::compute::CallFunction("abs", {arg0}); - UNIT_ASSERT(res->Equals(expected)); - } - } - - Y_UNIT_TEST(AbsFloating) { - for (auto type : cp::internal::FloatingPointTypes()) { - auto arg0 = NumVecToArray(type, {-32.4, -12.3, 54.7}); - auto expected = NumVecToArray(type, {32.4, 12.3, 54.7}); - auto res = arrow::compute::CallFunction("abs", {arg0}); - UNIT_ASSERT(res->Equals(expected)); - } - } - - Y_UNIT_TEST(AbsNull) { - for (auto type : cp::internal::IntTypes()) { - auto arg0 = arrow::MakeArrayOfNull(type, 3).ValueOrDie(); - auto expected = arrow::MakeArrayOfNull(type, 3).ValueOrDie(); - auto res = arrow::compute::CallFunction("abs", {arg0}); - UNIT_ASSERT(res->Equals(expected)); - } - } - - /*Y_UNIT_TEST(AbsWrongTypes) { - for (auto type : cp::internal::TemporalTypes()) { - TestWrongTypeUnary(TAbsoluteValue::Name, type); - } - for (auto type : cp::internal::StringTypes()) { - TestWrongTypeUnary(TAbsoluteValue::Name, type); - } - for (auto type : cp::internal::BaseBinaryTypes()) { - TestWrongTypeUnary(TAbsoluteValue::Name, type); - } - for (auto type : nonPrimitiveTypes) { - TestWrongTypeUnary(TAbsoluteValue::Name, type); - } - }*/ -} - -Y_UNIT_TEST_SUITE(ArrowNegateTest) { - Y_UNIT_TEST(NegateSignedInts) { - for (auto type : cp::internal::SignedIntTypes()) { - auto arg0 = NumVecToArray(type, {-32, -12, 54}); - auto expected = NumVecToArray(type, {32, 12, -54}); - auto res = arrow::compute::CallFunction("negate", {arg0}); - UNIT_ASSERT(res->Equals(expected)); - } - } - - Y_UNIT_TEST(NegateFloating) { - for (auto type : cp::internal::FloatingPointTypes()) { - auto arg0 = NumVecToArray(type, {-32.4, -12.3, 54.7}); - auto expected = NumVecToArray(type, {32.4, 12.3, -54.7}); - auto res = arrow::compute::CallFunction("negate", {arg0}); - UNIT_ASSERT(res->Equals(expected)); - } - } - - Y_UNIT_TEST(NegateWrongTypes) { - for (auto type : cp::internal::TemporalTypes()) { - TestWrongTypeUnary("negate", type); - } - for (auto type : cp::internal::StringTypes()) { - TestWrongTypeUnary("negate", type); - } - for (auto type : cp::internal::BaseBinaryTypes()) { - TestWrongTypeUnary("negate", type); - } - for (auto type : nonPrimitiveTypes) { - TestWrongTypeUnary("negate", type); - } - } - - Y_UNIT_TEST(NegateNull) { - for (auto type : cp::internal::SignedIntTypes()) { - auto arg0 = arrow::MakeArrayOfNull(type, 3).ValueOrDie(); - auto expected = arrow::MakeArrayOfNull(type, 3).ValueOrDie(); - auto res = arrow::compute::CallFunction("negate", {arg0}); - UNIT_ASSERT(res->Equals(expected)); - } - } -} - - - -Y_UNIT_TEST_SUITE(GcdTest) { - Y_UNIT_TEST(GcdInts) { - for (auto type : cp::internal::IntTypes()) { - auto arg0 = NumVecToArray(type, {32, 12}); - auto arg1 = NumVecToArray(type, {24, 10}); - auto expected = NumVecToArray(type, {8, 2}); - auto res = arrow::compute::CallFunction(TGreatestCommonDivisor::Name, {arg0, arg1}, GetCustomExecContext()); - UNIT_ASSERT(res->Equals(expected)); - } - } - - Y_UNIT_TEST(GcdNull) { - for (auto type : cp::internal::IntTypes()) { - auto arg0 = arrow::MakeArrayOfNull(type, 2).ValueOrDie(); - auto arg1 = NumVecToArray(type, {24, 10}); - auto expected = arrow::MakeArrayOfNull(type, 2).ValueOrDie(); - auto res = arrow::compute::CallFunction(TGreatestCommonDivisor::Name, {arg0, arg1}, GetCustomExecContext()); - UNIT_ASSERT(res->Equals(expected)); - } - } - - Y_UNIT_TEST(GcdDiffTypes) { - auto registry = cp::GetFunctionRegistry(); - auto func = MakeArithmeticIntBinary<TGreatestCommonDivisor>(TGreatestCommonDivisor::Name); - UNIT_ASSERT(registry->AddFunction(func, true).ok()); - for (auto type0 : cp::internal::IntTypes()) { - for (auto type1 : cp::internal::IntTypes()) { - auto arg0 = NumVecToArray(type0, {32, 12}); - auto arg1 = NumVecToArray(type1, {16, 10}); - auto res = arrow::compute::CallFunction(TGreatestCommonDivisor::Name, {arg0, arg1}, GetCustomExecContext()); - auto expected = NumVecToArray(res->type(), {16, 2}); - UNIT_ASSERT(res->Equals(expected)); - } - } - } - - Y_UNIT_TEST(GcdWrongTypes) { - for (auto type0 : cp::internal::FloatingPointTypes()) { - for (auto type1 : cp::internal::NumericTypes()) { - TestWrongTypeBinary(TGreatestCommonDivisor::Name, type0, type1, GetCustomExecContext()); - } - } - for (auto type : cp::internal::FloatingPointTypes()) { - TestWrongTypeBinary(TGreatestCommonDivisor::Name, type, type, GetCustomExecContext()); - } - for (auto type : cp::internal::TemporalTypes()) { - TestWrongTypeBinary(TGreatestCommonDivisor::Name, type, type, GetCustomExecContext()); - } - for (auto type : cp::internal::StringTypes()) { - TestWrongTypeBinary(TGreatestCommonDivisor::Name, type, type, GetCustomExecContext()); - } - for (auto type : cp::internal::BaseBinaryTypes()) { - TestWrongTypeBinary(TGreatestCommonDivisor::Name, type, type); - } - for (auto type : nonPrimitiveTypes) { - TestWrongTypeBinary(TGreatestCommonDivisor::Name, type, type); - } - } -} - - -Y_UNIT_TEST_SUITE(LcmTest) { - Y_UNIT_TEST(LcmInts) { - for (auto type : cp::internal::IntTypes()) { - auto arg0 = NumVecToArray(type, {6, 12}); - auto arg1 = NumVecToArray(type, {3, 10}); - auto expected = NumVecToArray(type, {6, 60}); - auto res = arrow::compute::CallFunction(TLeastCommonMultiple::Name, {arg0, arg1}, GetCustomExecContext()); - UNIT_ASSERT(res->Equals(expected)); - } - } - - Y_UNIT_TEST(LcmNull) { - for (auto type : cp::internal::IntTypes()) { - auto arg0 = arrow::MakeArrayOfNull(type, 2).ValueOrDie(); - auto arg1 = NumVecToArray(type, {24, 10}); - auto expected = arrow::MakeArrayOfNull(type, 2).ValueOrDie(); - auto res = arrow::compute::CallFunction(TLeastCommonMultiple::Name, {arg0, arg1}, GetCustomExecContext()); - UNIT_ASSERT(res->Equals(expected)); - } - } - - Y_UNIT_TEST(LcmDiffTypes) { - for (auto type0 : cp::internal::IntTypes()) { - for (auto type1 : cp::internal::IntTypes()) { - auto arg0 = NumVecToArray(type0, {6, 12}); - auto arg1 = NumVecToArray(type1, {3, 10}); - auto res = arrow::compute::CallFunction(TLeastCommonMultiple::Name, {arg0, arg1}, GetCustomExecContext()); - auto expected = NumVecToArray(res->type(), {6, 60}); - UNIT_ASSERT(res->Equals(expected)); - } - } - } - - Y_UNIT_TEST(LcmWrongTypes) { - for (auto type0 : cp::internal::FloatingPointTypes()) { - for (auto type1 : cp::internal::NumericTypes()) { - TestWrongTypeBinary(TLeastCommonMultiple::Name, type0, type1, GetCustomExecContext()); - } - } - for (auto type : cp::internal::FloatingPointTypes()) { - TestWrongTypeBinary(TLeastCommonMultiple::Name, type, type, GetCustomExecContext()); - } - for (auto type : cp::internal::TemporalTypes()) { - TestWrongTypeBinary(TLeastCommonMultiple::Name, type, type, GetCustomExecContext()); - } - for (auto type : cp::internal::StringTypes()) { - TestWrongTypeBinary(TLeastCommonMultiple::Name, type, type, GetCustomExecContext()); - } - for (auto type : cp::internal::BaseBinaryTypes()) { - TestWrongTypeBinary(TLeastCommonMultiple::Name, type, type, GetCustomExecContext()); - } - for (auto type : nonPrimitiveTypes) { - TestWrongTypeBinary(TLeastCommonMultiple::Name, type, type, GetCustomExecContext()); - } - } - - Y_UNIT_TEST(LcmNullDivide) { - for (auto type : cp::internal::IntTypes()) { - auto arg0 = NumVecToArray(type, {0, 16}); - auto arg1 = NumVecToArray(type, {0, 5}); - auto res = arrow::compute::CallFunction(TLeastCommonMultiple::Name, {arg0, arg1}, GetCustomExecContext()); - UNIT_ASSERT(!res.ok()); - } - } -} - -Y_UNIT_TEST_SUITE(ModuloTest) { - Y_UNIT_TEST(ModuloInts) { - for (auto type : cp::internal::IntTypes()) { - auto arg0 = NumVecToArray(type, {10, 16}); - auto arg1 = NumVecToArray(type, {3, 5}); - auto expected = NumVecToArray(type, {1, 1}); - auto res = arrow::compute::CallFunction(TModulo::Name, {arg0, arg1}, GetCustomExecContext()); - UNIT_ASSERT(res->Equals(expected)); - } - } - - Y_UNIT_TEST(ModuloFloats) { - for (auto type : cp::internal::FloatingPointTypes()) { - auto arg0 = NumVecToArray(type, {10.2234, 16.2347}); - auto arg1 = NumVecToArray(type, {3.22343, 5.4234}); - auto expected = NumVecToArray(type, {1, 1}); - auto res = arrow::compute::CallFunction(TModulo::Name, {arg0, arg1}, GetCustomExecContext()); - UNIT_ASSERT(res->Equals(expected)); - } - } - - Y_UNIT_TEST(ModuloNullDivide) { - for (auto type : cp::internal::NumericTypes()) { - auto arg0 = NumVecToArray(type, {10.2234, 16.2347}); - auto arg1 = NumVecToArray(type, {0, 5.4234}); - auto res = arrow::compute::CallFunction(TModulo::Name, {arg0, arg1}, GetCustomExecContext()); - UNIT_ASSERT(!res.ok()); - } - } - -} - -Y_UNIT_TEST_SUITE(ModuloOrZeroTest) { - Y_UNIT_TEST(ModuloOrZeroInts) { - for (auto type : cp::internal::IntTypes()) { - auto arg0 = NumVecToArray(type, {10, 16}); - auto arg1 = NumVecToArray(type, {6, 0}); - auto expected = NumVecToArray(type, {4, 0}); - auto res = arrow::compute::CallFunction(TModuloOrZero::Name, {arg0, arg1}, GetCustomExecContext()); - UNIT_ASSERT(res->Equals(expected)); - } - } - - Y_UNIT_TEST(ModuloOrZeroFloats) { - for (auto type : cp::internal::FloatingPointTypes()) { - auto arg0 = NumVecToArray(type, {10.2234, 16.2347}); - auto arg1 = NumVecToArray(type, {0.23, 5.4234}); - auto expected = NumVecToArray(type, {0, 1}); - auto res = arrow::compute::CallFunction(TModuloOrZero::Name, {arg0, arg1}, GetCustomExecContext()); - UNIT_ASSERT(res->Equals(expected)); - } - } - -} - + +namespace cp = ::arrow::compute; +using cp::internal::applicator::ScalarBinary; +using cp::internal::applicator::ScalarUnary; + +const std::vector<std::shared_ptr<arrow::DataType>> nonPrimitiveTypes = { + arrow::list(arrow::utf8()), + arrow::list(arrow::int64()), + arrow::large_list(arrow::large_utf8()), + arrow::fixed_size_list(arrow::utf8(), 3), + arrow::fixed_size_list(arrow::int64(), 4), + arrow::dictionary(arrow::int32(), arrow::utf8()) +}; + +const std::vector<std::shared_ptr<arrow::DataType>> decimalTypes = { + arrow::decimal128(12, 2), + arrow::decimal256(12, 2) +}; + + +template<typename TType> +EnableIfSigned<TType> MakeRandomNum() { + double lowerLimit = -120; + double upperLimit = 120; + return static_cast<TType>(double(rand()) * (upperLimit - lowerLimit) / RAND_MAX + lowerLimit); +} + +template<typename TType> +EnableIfUnsigned<TType> MakeRandomNum() { + double lowerLimit = 0; + double upperLimit = 120; + return static_cast<TType>(double(rand()) * (upperLimit - lowerLimit) / RAND_MAX + lowerLimit); +} + +template<typename TType> +std::shared_ptr<TArray<TType>> GenerateTestArray(int64_t sz) { + std::srand(std::time(nullptr)); + TBuilder<TType> builder; + std::shared_ptr<TArray<TType>> res; + arrow::Status st; + UNIT_ASSERT(builder.Reserve(sz).ok()); + for (int64_t i = 0; i < sz; ++i) { + UNIT_ASSERT(builder.Append(MakeRandomNum<typename TType::c_type>()).ok()); + } + UNIT_ASSERT(builder.Finish(&res).ok()); + return res; +} + +void TestWrongTypeUnary(const std::string& func_name, std::shared_ptr<arrow::DataType> type, cp::ExecContext* exc = nullptr) { + auto res = arrow::compute::CallFunction(func_name, {arrow::MakeArrayOfNull(type, 1).ValueOrDie()}, exc); + UNIT_ASSERT_EQUAL(res.ok(), false); +} + +void TestWrongTypeBinary(const std::string& func_name, std::shared_ptr<arrow::DataType> type0, + std::shared_ptr<arrow::DataType> type1, + cp::ExecContext* exc = nullptr) { + auto res = arrow::compute::CallFunction(func_name, {arrow::MakeArrayOfNull(type0, 1).ValueOrDie(), + arrow::MakeArrayOfNull(type1, 1).ValueOrDie()}, + exc); + UNIT_ASSERT_EQUAL(res.ok(), false); +} + +std::shared_ptr<arrow::Array> MakeBooleanArray(const std::vector<bool>& arr) { + arrow::BooleanBuilder builder; + std::shared_ptr<arrow::BooleanArray> res; + UNIT_ASSERT(builder.Reserve(arr.size()).ok()); + for (size_t i = 0; i < arr.size(); ++i) { + UNIT_ASSERT(builder.Append(arr[i]).ok()); + } + UNIT_ASSERT(builder.Finish(&res).ok()); + return res; +} + +Y_UNIT_TEST_SUITE(ArrowAbsTest) { + Y_UNIT_TEST(AbsSignedInts) { + for (auto type : cp::internal::SignedIntTypes()) { + auto arg0 = NumVecToArray(type, {-32, -12, 54}); + auto expected = NumVecToArray(type, {32, 12, 54}); + auto res = arrow::compute::CallFunction("abs", {arg0}); + UNIT_ASSERT(res->Equals(expected)); + } + } + + Y_UNIT_TEST(AbsUnsignedInts) { + for (auto type : cp::internal::UnsignedIntTypes()) { + auto arg0 = NumVecToArray(type, {32, 12, 54}); + auto expected = NumVecToArray(type, {32, 12, 54}); + auto res = arrow::compute::CallFunction("abs", {arg0}); + UNIT_ASSERT(res->Equals(expected)); + } + } + + Y_UNIT_TEST(AbsFloating) { + for (auto type : cp::internal::FloatingPointTypes()) { + auto arg0 = NumVecToArray(type, {-32.4, -12.3, 54.7}); + auto expected = NumVecToArray(type, {32.4, 12.3, 54.7}); + auto res = arrow::compute::CallFunction("abs", {arg0}); + UNIT_ASSERT(res->Equals(expected)); + } + } + + Y_UNIT_TEST(AbsNull) { + for (auto type : cp::internal::IntTypes()) { + auto arg0 = arrow::MakeArrayOfNull(type, 3).ValueOrDie(); + auto expected = arrow::MakeArrayOfNull(type, 3).ValueOrDie(); + auto res = arrow::compute::CallFunction("abs", {arg0}); + UNIT_ASSERT(res->Equals(expected)); + } + } + + /*Y_UNIT_TEST(AbsWrongTypes) { + for (auto type : cp::internal::TemporalTypes()) { + TestWrongTypeUnary(TAbsoluteValue::Name, type); + } + for (auto type : cp::internal::StringTypes()) { + TestWrongTypeUnary(TAbsoluteValue::Name, type); + } + for (auto type : cp::internal::BaseBinaryTypes()) { + TestWrongTypeUnary(TAbsoluteValue::Name, type); + } + for (auto type : nonPrimitiveTypes) { + TestWrongTypeUnary(TAbsoluteValue::Name, type); + } + }*/ +} + +Y_UNIT_TEST_SUITE(ArrowNegateTest) { + Y_UNIT_TEST(NegateSignedInts) { + for (auto type : cp::internal::SignedIntTypes()) { + auto arg0 = NumVecToArray(type, {-32, -12, 54}); + auto expected = NumVecToArray(type, {32, 12, -54}); + auto res = arrow::compute::CallFunction("negate", {arg0}); + UNIT_ASSERT(res->Equals(expected)); + } + } + + Y_UNIT_TEST(NegateFloating) { + for (auto type : cp::internal::FloatingPointTypes()) { + auto arg0 = NumVecToArray(type, {-32.4, -12.3, 54.7}); + auto expected = NumVecToArray(type, {32.4, 12.3, -54.7}); + auto res = arrow::compute::CallFunction("negate", {arg0}); + UNIT_ASSERT(res->Equals(expected)); + } + } + + Y_UNIT_TEST(NegateWrongTypes) { + for (auto type : cp::internal::TemporalTypes()) { + TestWrongTypeUnary("negate", type); + } + for (auto type : cp::internal::StringTypes()) { + TestWrongTypeUnary("negate", type); + } + for (auto type : cp::internal::BaseBinaryTypes()) { + TestWrongTypeUnary("negate", type); + } + for (auto type : nonPrimitiveTypes) { + TestWrongTypeUnary("negate", type); + } + } + + Y_UNIT_TEST(NegateNull) { + for (auto type : cp::internal::SignedIntTypes()) { + auto arg0 = arrow::MakeArrayOfNull(type, 3).ValueOrDie(); + auto expected = arrow::MakeArrayOfNull(type, 3).ValueOrDie(); + auto res = arrow::compute::CallFunction("negate", {arg0}); + UNIT_ASSERT(res->Equals(expected)); + } + } +} + + + +Y_UNIT_TEST_SUITE(GcdTest) { + Y_UNIT_TEST(GcdInts) { + for (auto type : cp::internal::IntTypes()) { + auto arg0 = NumVecToArray(type, {32, 12}); + auto arg1 = NumVecToArray(type, {24, 10}); + auto expected = NumVecToArray(type, {8, 2}); + auto res = arrow::compute::CallFunction(TGreatestCommonDivisor::Name, {arg0, arg1}, GetCustomExecContext()); + UNIT_ASSERT(res->Equals(expected)); + } + } + + Y_UNIT_TEST(GcdNull) { + for (auto type : cp::internal::IntTypes()) { + auto arg0 = arrow::MakeArrayOfNull(type, 2).ValueOrDie(); + auto arg1 = NumVecToArray(type, {24, 10}); + auto expected = arrow::MakeArrayOfNull(type, 2).ValueOrDie(); + auto res = arrow::compute::CallFunction(TGreatestCommonDivisor::Name, {arg0, arg1}, GetCustomExecContext()); + UNIT_ASSERT(res->Equals(expected)); + } + } + + Y_UNIT_TEST(GcdDiffTypes) { + auto registry = cp::GetFunctionRegistry(); + auto func = MakeArithmeticIntBinary<TGreatestCommonDivisor>(TGreatestCommonDivisor::Name); + UNIT_ASSERT(registry->AddFunction(func, true).ok()); + for (auto type0 : cp::internal::IntTypes()) { + for (auto type1 : cp::internal::IntTypes()) { + auto arg0 = NumVecToArray(type0, {32, 12}); + auto arg1 = NumVecToArray(type1, {16, 10}); + auto res = arrow::compute::CallFunction(TGreatestCommonDivisor::Name, {arg0, arg1}, GetCustomExecContext()); + auto expected = NumVecToArray(res->type(), {16, 2}); + UNIT_ASSERT(res->Equals(expected)); + } + } + } + + Y_UNIT_TEST(GcdWrongTypes) { + for (auto type0 : cp::internal::FloatingPointTypes()) { + for (auto type1 : cp::internal::NumericTypes()) { + TestWrongTypeBinary(TGreatestCommonDivisor::Name, type0, type1, GetCustomExecContext()); + } + } + for (auto type : cp::internal::FloatingPointTypes()) { + TestWrongTypeBinary(TGreatestCommonDivisor::Name, type, type, GetCustomExecContext()); + } + for (auto type : cp::internal::TemporalTypes()) { + TestWrongTypeBinary(TGreatestCommonDivisor::Name, type, type, GetCustomExecContext()); + } + for (auto type : cp::internal::StringTypes()) { + TestWrongTypeBinary(TGreatestCommonDivisor::Name, type, type, GetCustomExecContext()); + } + for (auto type : cp::internal::BaseBinaryTypes()) { + TestWrongTypeBinary(TGreatestCommonDivisor::Name, type, type); + } + for (auto type : nonPrimitiveTypes) { + TestWrongTypeBinary(TGreatestCommonDivisor::Name, type, type); + } + } +} + + +Y_UNIT_TEST_SUITE(LcmTest) { + Y_UNIT_TEST(LcmInts) { + for (auto type : cp::internal::IntTypes()) { + auto arg0 = NumVecToArray(type, {6, 12}); + auto arg1 = NumVecToArray(type, {3, 10}); + auto expected = NumVecToArray(type, {6, 60}); + auto res = arrow::compute::CallFunction(TLeastCommonMultiple::Name, {arg0, arg1}, GetCustomExecContext()); + UNIT_ASSERT(res->Equals(expected)); + } + } + + Y_UNIT_TEST(LcmNull) { + for (auto type : cp::internal::IntTypes()) { + auto arg0 = arrow::MakeArrayOfNull(type, 2).ValueOrDie(); + auto arg1 = NumVecToArray(type, {24, 10}); + auto expected = arrow::MakeArrayOfNull(type, 2).ValueOrDie(); + auto res = arrow::compute::CallFunction(TLeastCommonMultiple::Name, {arg0, arg1}, GetCustomExecContext()); + UNIT_ASSERT(res->Equals(expected)); + } + } + + Y_UNIT_TEST(LcmDiffTypes) { + for (auto type0 : cp::internal::IntTypes()) { + for (auto type1 : cp::internal::IntTypes()) { + auto arg0 = NumVecToArray(type0, {6, 12}); + auto arg1 = NumVecToArray(type1, {3, 10}); + auto res = arrow::compute::CallFunction(TLeastCommonMultiple::Name, {arg0, arg1}, GetCustomExecContext()); + auto expected = NumVecToArray(res->type(), {6, 60}); + UNIT_ASSERT(res->Equals(expected)); + } + } + } + + Y_UNIT_TEST(LcmWrongTypes) { + for (auto type0 : cp::internal::FloatingPointTypes()) { + for (auto type1 : cp::internal::NumericTypes()) { + TestWrongTypeBinary(TLeastCommonMultiple::Name, type0, type1, GetCustomExecContext()); + } + } + for (auto type : cp::internal::FloatingPointTypes()) { + TestWrongTypeBinary(TLeastCommonMultiple::Name, type, type, GetCustomExecContext()); + } + for (auto type : cp::internal::TemporalTypes()) { + TestWrongTypeBinary(TLeastCommonMultiple::Name, type, type, GetCustomExecContext()); + } + for (auto type : cp::internal::StringTypes()) { + TestWrongTypeBinary(TLeastCommonMultiple::Name, type, type, GetCustomExecContext()); + } + for (auto type : cp::internal::BaseBinaryTypes()) { + TestWrongTypeBinary(TLeastCommonMultiple::Name, type, type, GetCustomExecContext()); + } + for (auto type : nonPrimitiveTypes) { + TestWrongTypeBinary(TLeastCommonMultiple::Name, type, type, GetCustomExecContext()); + } + } + + Y_UNIT_TEST(LcmNullDivide) { + for (auto type : cp::internal::IntTypes()) { + auto arg0 = NumVecToArray(type, {0, 16}); + auto arg1 = NumVecToArray(type, {0, 5}); + auto res = arrow::compute::CallFunction(TLeastCommonMultiple::Name, {arg0, arg1}, GetCustomExecContext()); + UNIT_ASSERT(!res.ok()); + } + } +} + +Y_UNIT_TEST_SUITE(ModuloTest) { + Y_UNIT_TEST(ModuloInts) { + for (auto type : cp::internal::IntTypes()) { + auto arg0 = NumVecToArray(type, {10, 16}); + auto arg1 = NumVecToArray(type, {3, 5}); + auto expected = NumVecToArray(type, {1, 1}); + auto res = arrow::compute::CallFunction(TModulo::Name, {arg0, arg1}, GetCustomExecContext()); + UNIT_ASSERT(res->Equals(expected)); + } + } + + Y_UNIT_TEST(ModuloFloats) { + for (auto type : cp::internal::FloatingPointTypes()) { + auto arg0 = NumVecToArray(type, {10.2234, 16.2347}); + auto arg1 = NumVecToArray(type, {3.22343, 5.4234}); + auto expected = NumVecToArray(type, {1, 1}); + auto res = arrow::compute::CallFunction(TModulo::Name, {arg0, arg1}, GetCustomExecContext()); + UNIT_ASSERT(res->Equals(expected)); + } + } + + Y_UNIT_TEST(ModuloNullDivide) { + for (auto type : cp::internal::NumericTypes()) { + auto arg0 = NumVecToArray(type, {10.2234, 16.2347}); + auto arg1 = NumVecToArray(type, {0, 5.4234}); + auto res = arrow::compute::CallFunction(TModulo::Name, {arg0, arg1}, GetCustomExecContext()); + UNIT_ASSERT(!res.ok()); + } + } + +} + +Y_UNIT_TEST_SUITE(ModuloOrZeroTest) { + Y_UNIT_TEST(ModuloOrZeroInts) { + for (auto type : cp::internal::IntTypes()) { + auto arg0 = NumVecToArray(type, {10, 16}); + auto arg1 = NumVecToArray(type, {6, 0}); + auto expected = NumVecToArray(type, {4, 0}); + auto res = arrow::compute::CallFunction(TModuloOrZero::Name, {arg0, arg1}, GetCustomExecContext()); + UNIT_ASSERT(res->Equals(expected)); + } + } + + Y_UNIT_TEST(ModuloOrZeroFloats) { + for (auto type : cp::internal::FloatingPointTypes()) { + auto arg0 = NumVecToArray(type, {10.2234, 16.2347}); + auto arg1 = NumVecToArray(type, {0.23, 5.4234}); + auto expected = NumVecToArray(type, {0, 1}); + auto res = arrow::compute::CallFunction(TModuloOrZero::Name, {arg0, arg1}, GetCustomExecContext()); + UNIT_ASSERT(res->Equals(expected)); + } + } + +} + } diff --git a/ydb/core/formats/ut_math.cpp b/ydb/core/formats/ut_math.cpp index c9945c4e09..93fbb00dea 100644 --- a/ydb/core/formats/ut_math.cpp +++ b/ydb/core/formats/ut_math.cpp @@ -1,72 +1,72 @@ -#include <cmath> -#include <cstdint> -#include <iterator> -#include <library/cpp/testing/unittest/registar.h> -#include <ctime> -#include <vector> -#include <algorithm> - -#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h> - - -#include "func_common.h" -#include "functions.h" -#include "custom_registry.h" -#include "arrow_helpers.h" - - +#include <cmath> +#include <cstdint> +#include <iterator> +#include <library/cpp/testing/unittest/registar.h> +#include <ctime> +#include <vector> +#include <algorithm> + +#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h> + + +#include "func_common.h" +#include "functions.h" +#include "custom_registry.h" +#include "arrow_helpers.h" + + namespace NKikimr::NArrow { - -namespace cp = ::arrow::compute; - - -Y_UNIT_TEST_SUITE(MathTest) { - Y_UNIT_TEST(E) { - auto res = arrow::compute::CallFunction(TE::Name, {}, GetCustomExecContext()); - UNIT_ASSERT(res->scalar()->Equals(arrow::MakeScalar(std::exp(1.0)))); - } - - Y_UNIT_TEST(Pi) { - auto res = arrow::compute::CallFunction(TPi::Name, {}, GetCustomExecContext()); - UNIT_ASSERT(res->scalar()->Equals(arrow::MakeScalar(std::atan2(0, -1)))); - } - - Y_UNIT_TEST(AcoshFloat32) { - std::vector<double> argVec = {2.324, 1.34234, 41.14324, 123}; - std::vector<double> expVec; - for (auto val : argVec) { - expVec.push_back(std::acosh(static_cast<float>(val))); - } - auto expRes = NumVecToArray(arrow::float64(), expVec); - auto res = arrow::compute::CallFunction(TAcosh::Name, {NumVecToArray(arrow::float32(), argVec)}, GetCustomExecContext()); - UNIT_ASSERT(res->Equals(expRes)); - } - - Y_UNIT_TEST(AcoshFloat64) { - std::vector<double> argVec = {2.324, 1.34234, 41.14324, 123}; - std::vector<double> expVec; - for (auto val : argVec) { - expVec.push_back(std::acosh(val)); - } - auto expRes = NumVecToArray(arrow::float64(), expVec); - auto res = arrow::compute::CallFunction(TAcosh::Name, {NumVecToArray(arrow::float64(), argVec)}, GetCustomExecContext()); - UNIT_ASSERT(res->Equals(expRes)); - } - - Y_UNIT_TEST(AcoshInts) { - std::vector<double> argVec = {2.324, 1.34234, 41.14324, 123}; - std::vector<double> expVec; - for (auto val : argVec) { - expVec.push_back(std::acosh(static_cast<int64_t>(val))); - } - auto expRes = NumVecToArray(arrow::float64(), expVec); - for (auto type : cp::internal::IntTypes()) { - auto res = arrow::compute::CallFunction(TAcosh::Name, {NumVecToArray(type, argVec)}, GetCustomExecContext()); - UNIT_ASSERT(res->Equals(expRes)); - } - } - -} - -} + +namespace cp = ::arrow::compute; + + +Y_UNIT_TEST_SUITE(MathTest) { + Y_UNIT_TEST(E) { + auto res = arrow::compute::CallFunction(TE::Name, {}, GetCustomExecContext()); + UNIT_ASSERT(res->scalar()->Equals(arrow::MakeScalar(std::exp(1.0)))); + } + + Y_UNIT_TEST(Pi) { + auto res = arrow::compute::CallFunction(TPi::Name, {}, GetCustomExecContext()); + UNIT_ASSERT(res->scalar()->Equals(arrow::MakeScalar(std::atan2(0, -1)))); + } + + Y_UNIT_TEST(AcoshFloat32) { + std::vector<double> argVec = {2.324, 1.34234, 41.14324, 123}; + std::vector<double> expVec; + for (auto val : argVec) { + expVec.push_back(std::acosh(static_cast<float>(val))); + } + auto expRes = NumVecToArray(arrow::float64(), expVec); + auto res = arrow::compute::CallFunction(TAcosh::Name, {NumVecToArray(arrow::float32(), argVec)}, GetCustomExecContext()); + UNIT_ASSERT(res->Equals(expRes)); + } + + Y_UNIT_TEST(AcoshFloat64) { + std::vector<double> argVec = {2.324, 1.34234, 41.14324, 123}; + std::vector<double> expVec; + for (auto val : argVec) { + expVec.push_back(std::acosh(val)); + } + auto expRes = NumVecToArray(arrow::float64(), expVec); + auto res = arrow::compute::CallFunction(TAcosh::Name, {NumVecToArray(arrow::float64(), argVec)}, GetCustomExecContext()); + UNIT_ASSERT(res->Equals(expRes)); + } + + Y_UNIT_TEST(AcoshInts) { + std::vector<double> argVec = {2.324, 1.34234, 41.14324, 123}; + std::vector<double> expVec; + for (auto val : argVec) { + expVec.push_back(std::acosh(static_cast<int64_t>(val))); + } + auto expRes = NumVecToArray(arrow::float64(), expVec); + for (auto type : cp::internal::IntTypes()) { + auto res = arrow::compute::CallFunction(TAcosh::Name, {NumVecToArray(type, argVec)}, GetCustomExecContext()); + UNIT_ASSERT(res->Equals(expRes)); + } + } + +} + +} diff --git a/ydb/core/formats/ut_program_step.cpp b/ydb/core/formats/ut_program_step.cpp index fa368af562..c908a2f742 100644 --- a/ydb/core/formats/ut_program_step.cpp +++ b/ydb/core/formats/ut_program_step.cpp @@ -1,199 +1,199 @@ -#include <array> -#include <memory> -#include <vector> - -#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/exec.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/type_fwd.h> -#include <library/cpp/testing/unittest/registar.h> -#include "custom_registry.h" -#include "program.h" -#include "arrow_helpers.h" - +#include <array> +#include <memory> +#include <vector> + +#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/exec.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/type_fwd.h> +#include <library/cpp/testing/unittest/registar.h> +#include "custom_registry.h" +#include "program.h" +#include "arrow_helpers.h" + namespace NKikimr::NArrow { - -size_t FilterTest(std::vector<std::shared_ptr<arrow::Array>> args, EOperation frst, EOperation scnd) { - auto schema = std::make_shared<arrow::Schema>(std::vector{ - std::make_shared<arrow::Field>("x", args.at(0)->type()), - std::make_shared<arrow::Field>("y", args.at(1)->type()), - std::make_shared<arrow::Field>("z", args.at(2)->type())}); - auto rBatch = arrow::RecordBatch::Make(schema, 3, std::vector{args.at(0), args.at(1), args.at(2)}); - - auto ps = std::make_shared<TProgramStep>(); - ps->Assignes = {TAssign("res1", frst, {"x", "y"}), TAssign("res2", scnd, {"res1", "z"})}; - ps->Filters = {"res2"}; - ps->Projection = {"res1", "res2"}; - ApplyProgram(rBatch, {ps}, GetCustomExecContext()); - UNIT_ASSERT(rBatch->ValidateFull().ok()); - UNIT_ASSERT(rBatch->num_columns() == 2); - return rBatch->num_rows(); -} - -size_t FilterTestUnary(std::vector<std::shared_ptr<arrow::Array>> args, EOperation frst, EOperation scnd) { - auto schema = std::make_shared<arrow::Schema>(std::vector{ - std::make_shared<arrow::Field>("x", args.at(0)->type()), - std::make_shared<arrow::Field>("z", args.at(1)->type())}); - auto rBatch = arrow::RecordBatch::Make(schema, 3, std::vector{args.at(0), args.at(1)}); - - auto ps = std::make_shared<TProgramStep>(); - ps->Assignes = {TAssign("res1", frst, {"x"}), TAssign("res2", scnd, {"res1", "z"})}; - ps->Filters = {"res2"}; - ps->Projection = {"res1", "res2"}; - ApplyProgram(rBatch, {ps}, GetCustomExecContext()); - UNIT_ASSERT(rBatch->ValidateFull().ok()); - UNIT_ASSERT(rBatch->num_columns() == 2); - return rBatch->num_rows(); -} - - -Y_UNIT_TEST_SUITE(ProgramStepTest) { - Y_UNIT_TEST(ProgramStepRound0) { - for (auto eop : {EOperation::Round, EOperation::RoundBankers, EOperation::RoundToExp2}) { - auto x = NumVecToArray(arrow::float64(), {32.3, 12.5, 34.7}); - auto z = arrow::compute::CallFunction(GetFunctionName(eop), {x}, GetCustomExecContext()); - UNIT_ASSERT(FilterTestUnary({x, z->make_array()}, eop, EOperation::Equal) == 3); - } - } - - Y_UNIT_TEST(ProgramStepRound1) { - for (auto eop : {EOperation::Ceil, EOperation::Floor, EOperation::Trunc}) { - auto x = NumVecToArray(arrow::float64(), {32.3, 12.5, 34.7}); - auto z = arrow::compute::CallFunction(GetFunctionName(eop), {x}); - UNIT_ASSERT(FilterTestUnary({x, z->make_array()}, eop, EOperation::Equal) == 3); - } - } - - Y_UNIT_TEST(Filter) { - auto x = NumVecToArray(arrow::int32(), {10, 34, 8}); - auto y = NumVecToArray(arrow::uint32(), {10, 34, 8}); - auto z = NumVecToArray(arrow::int64(), {33, 70, 12}); - UNIT_ASSERT(FilterTest({x, y, z}, EOperation::Add, EOperation::Less) == 2); - } - - Y_UNIT_TEST(ProgramStepAdd) { - auto x = NumVecToArray(arrow::int32(), {10, 34, 8}); - auto y = NumVecToArray(arrow::int32(), {32, 12, 4}); - auto z = arrow::compute::CallFunction("add", {x, y}); - UNIT_ASSERT(FilterTest({x, y, z->make_array()}, EOperation::Add, EOperation::Equal) == 3); - } - - Y_UNIT_TEST(ProgramStepSubstract) { - auto x = NumVecToArray(arrow::int32(), {10, 34, 8}); - auto y = NumVecToArray(arrow::int32(), {32, 12, 4}); - auto z = arrow::compute::CallFunction("subtract", {x, y}); - UNIT_ASSERT(FilterTest({x, y, z->make_array()}, EOperation::Subtract, EOperation::Equal) == 3); - } - - Y_UNIT_TEST(ProgramStepMultiply) { - auto x = NumVecToArray(arrow::int32(), {10, 34, 8}); - auto y = NumVecToArray(arrow::int32(), {32, 12, 4}); - auto z = arrow::compute::CallFunction("multiply", {x, y}); - UNIT_ASSERT(FilterTest({x, y, z->make_array()}, EOperation::Multiply, EOperation::Equal) == 3); - } - - Y_UNIT_TEST(ProgramStepDivide) { - auto x = NumVecToArray(arrow::int32(), {10, 34, 8}); - auto y = NumVecToArray(arrow::int32(), {32, 12, 4}); - auto z = arrow::compute::CallFunction("divide", {x, y}); - UNIT_ASSERT(FilterTest({x, y, z->make_array()}, EOperation::Divide, EOperation::Equal) == 3); - } - - Y_UNIT_TEST(ProgramStepGcd) { - auto x = NumVecToArray(arrow::int32(), {64, 16, 8}); - auto y = NumVecToArray(arrow::int32(), {32, 12, 4}); - auto z = arrow::compute::CallFunction("gcd", {x, y}, GetCustomExecContext()); - UNIT_ASSERT(FilterTest({x, y, z->make_array()}, EOperation::Gcd, EOperation::Equal) == 3); - } - - Y_UNIT_TEST(ProgramStepLcm) { - auto x = NumVecToArray(arrow::int32(), {64, 16, 8}); - auto y = NumVecToArray(arrow::int32(), {32, 12, 4}); - auto z = arrow::compute::CallFunction("lcm", {x, y}, GetCustomExecContext()); - UNIT_ASSERT(FilterTest({x, y, z->make_array()}, EOperation::Lcm, EOperation::Equal) == 3); - } - - Y_UNIT_TEST(ProgramStepMod) { - auto x = NumVecToArray(arrow::int32(), {64, 16, 8}); - auto y = NumVecToArray(arrow::int32(), {3, 5, 2}); - auto z = arrow::compute::CallFunction("mod", {x, y}, GetCustomExecContext()); - UNIT_ASSERT(FilterTest({x, y, z->make_array()}, EOperation::Modulo, EOperation::Equal) == 3); - } - - Y_UNIT_TEST(ProgramStepModOrZero) { - auto x = NumVecToArray(arrow::int32(), {64, 16, 8}); - auto y = NumVecToArray(arrow::int32(), {3, 5, 0}); - auto z = arrow::compute::CallFunction("modOrZero", {x, y}, GetCustomExecContext()); - UNIT_ASSERT(FilterTest({x, y, z->make_array()}, EOperation::ModuloOrZero, EOperation::Equal) == 3); - } - - Y_UNIT_TEST(ProgramStepAbs) { - auto x = NumVecToArray(arrow::int32(), {-64, -16, 8}); - auto z = arrow::compute::CallFunction("abs", {x}); - UNIT_ASSERT(FilterTestUnary({x, z->make_array()}, EOperation::Abs, EOperation::Equal) == 3); - } - - Y_UNIT_TEST(ProgramStepNegate) { - auto x = NumVecToArray(arrow::int32(), {-64, -16, 8}); - auto z = arrow::compute::CallFunction("negate", {x}); - UNIT_ASSERT(FilterTestUnary({x, z->make_array()}, EOperation::Negate, EOperation::Equal) == 3); - } - - Y_UNIT_TEST(ProgramStepCompares) { - for (auto eop : {EOperation::Equal, EOperation::Less, EOperation::Greater, EOperation::GreaterEqual, - EOperation::LessEqual, EOperation::NotEqual}) { - auto x = NumVecToArray(arrow::int32(), {64, 5, 1}); - auto y = NumVecToArray(arrow::int32(), {64, 1, 5}); - auto z = arrow::compute::CallFunction(GetFunctionName(eop), {x, y}); - UNIT_ASSERT(FilterTest({x, y, z->make_array()}, eop, EOperation::Equal) == 3); - } - } - - Y_UNIT_TEST(ProgramStepLogic0) { - for (auto eop : {EOperation::And, EOperation::Or, EOperation::Xor}) { - auto x = BoolVecToArray({true, false, false}); - auto y = BoolVecToArray({true, true, false}); - auto z = arrow::compute::CallFunction(GetFunctionName(eop), {x, y}); - UNIT_ASSERT(FilterTest({x, y, z->make_array()}, eop, EOperation::Equal) == 3); - } - } - - Y_UNIT_TEST(ProgramStepLogic1) { - auto x = BoolVecToArray({true, false, false}); - auto z = arrow::compute::CallFunction("invert", {x}); - UNIT_ASSERT(FilterTestUnary({x, z->make_array()}, EOperation::Invert, EOperation::Equal) == 3); - } - - Y_UNIT_TEST(ProgramStepScalarTest) { - auto schema = std::make_shared<arrow::Schema>(std::vector{ - std::make_shared<arrow::Field>("x", arrow::int64()), - std::make_shared<arrow::Field>("filter", arrow::boolean())}); - auto rBatch = arrow::RecordBatch::Make(schema, 4, std::vector{NumVecToArray(arrow::int64(), {64, 5, 1, 43}), - BoolVecToArray({true, false, false, true})}); - auto ps = std::make_shared<TProgramStep>(); - ps->Assignes = {TAssign("y", 56), TAssign("res", EOperation::Add, {"x", "y"})}; - ps->Filters = {"filter"}; - ps->Projection = {"res", "filter"}; - ApplyProgram(rBatch, {ps}, GetCustomExecContext()); - UNIT_ASSERT(rBatch->ValidateFull().ok()); - UNIT_ASSERT(rBatch->num_columns() == 2); - UNIT_ASSERT(rBatch->num_rows() == 2); - } - - Y_UNIT_TEST(ProgramStepEmptyFilter) { - auto schema = std::make_shared<arrow::Schema>(std::vector{ - std::make_shared<arrow::Field>("x", arrow::int64()), - std::make_shared<arrow::Field>("filter", arrow::boolean())}); - auto rBatch = arrow::RecordBatch::Make(schema, 4, std::vector{NumVecToArray(arrow::int64(), {64, 5, 1, 43}), - BoolVecToArray({true, false, false, true})}); - auto ps = std::make_shared<TProgramStep>(); - ps->Assignes = {TAssign("y", 56), TAssign("res", EOperation::Add, {"x", "y"})}; - ps->Filters = {}; - ps->Projection = {"res", "filter"}; - ApplyProgram(rBatch, {ps}, GetCustomExecContext()); - UNIT_ASSERT(rBatch->ValidateFull().ok()); - UNIT_ASSERT(rBatch->num_columns() == 2); - UNIT_ASSERT(rBatch->num_rows() == 4); - } -} - -} + +size_t FilterTest(std::vector<std::shared_ptr<arrow::Array>> args, EOperation frst, EOperation scnd) { + auto schema = std::make_shared<arrow::Schema>(std::vector{ + std::make_shared<arrow::Field>("x", args.at(0)->type()), + std::make_shared<arrow::Field>("y", args.at(1)->type()), + std::make_shared<arrow::Field>("z", args.at(2)->type())}); + auto rBatch = arrow::RecordBatch::Make(schema, 3, std::vector{args.at(0), args.at(1), args.at(2)}); + + auto ps = std::make_shared<TProgramStep>(); + ps->Assignes = {TAssign("res1", frst, {"x", "y"}), TAssign("res2", scnd, {"res1", "z"})}; + ps->Filters = {"res2"}; + ps->Projection = {"res1", "res2"}; + ApplyProgram(rBatch, {ps}, GetCustomExecContext()); + UNIT_ASSERT(rBatch->ValidateFull().ok()); + UNIT_ASSERT(rBatch->num_columns() == 2); + return rBatch->num_rows(); +} + +size_t FilterTestUnary(std::vector<std::shared_ptr<arrow::Array>> args, EOperation frst, EOperation scnd) { + auto schema = std::make_shared<arrow::Schema>(std::vector{ + std::make_shared<arrow::Field>("x", args.at(0)->type()), + std::make_shared<arrow::Field>("z", args.at(1)->type())}); + auto rBatch = arrow::RecordBatch::Make(schema, 3, std::vector{args.at(0), args.at(1)}); + + auto ps = std::make_shared<TProgramStep>(); + ps->Assignes = {TAssign("res1", frst, {"x"}), TAssign("res2", scnd, {"res1", "z"})}; + ps->Filters = {"res2"}; + ps->Projection = {"res1", "res2"}; + ApplyProgram(rBatch, {ps}, GetCustomExecContext()); + UNIT_ASSERT(rBatch->ValidateFull().ok()); + UNIT_ASSERT(rBatch->num_columns() == 2); + return rBatch->num_rows(); +} + + +Y_UNIT_TEST_SUITE(ProgramStepTest) { + Y_UNIT_TEST(ProgramStepRound0) { + for (auto eop : {EOperation::Round, EOperation::RoundBankers, EOperation::RoundToExp2}) { + auto x = NumVecToArray(arrow::float64(), {32.3, 12.5, 34.7}); + auto z = arrow::compute::CallFunction(GetFunctionName(eop), {x}, GetCustomExecContext()); + UNIT_ASSERT(FilterTestUnary({x, z->make_array()}, eop, EOperation::Equal) == 3); + } + } + + Y_UNIT_TEST(ProgramStepRound1) { + for (auto eop : {EOperation::Ceil, EOperation::Floor, EOperation::Trunc}) { + auto x = NumVecToArray(arrow::float64(), {32.3, 12.5, 34.7}); + auto z = arrow::compute::CallFunction(GetFunctionName(eop), {x}); + UNIT_ASSERT(FilterTestUnary({x, z->make_array()}, eop, EOperation::Equal) == 3); + } + } + + Y_UNIT_TEST(Filter) { + auto x = NumVecToArray(arrow::int32(), {10, 34, 8}); + auto y = NumVecToArray(arrow::uint32(), {10, 34, 8}); + auto z = NumVecToArray(arrow::int64(), {33, 70, 12}); + UNIT_ASSERT(FilterTest({x, y, z}, EOperation::Add, EOperation::Less) == 2); + } + + Y_UNIT_TEST(ProgramStepAdd) { + auto x = NumVecToArray(arrow::int32(), {10, 34, 8}); + auto y = NumVecToArray(arrow::int32(), {32, 12, 4}); + auto z = arrow::compute::CallFunction("add", {x, y}); + UNIT_ASSERT(FilterTest({x, y, z->make_array()}, EOperation::Add, EOperation::Equal) == 3); + } + + Y_UNIT_TEST(ProgramStepSubstract) { + auto x = NumVecToArray(arrow::int32(), {10, 34, 8}); + auto y = NumVecToArray(arrow::int32(), {32, 12, 4}); + auto z = arrow::compute::CallFunction("subtract", {x, y}); + UNIT_ASSERT(FilterTest({x, y, z->make_array()}, EOperation::Subtract, EOperation::Equal) == 3); + } + + Y_UNIT_TEST(ProgramStepMultiply) { + auto x = NumVecToArray(arrow::int32(), {10, 34, 8}); + auto y = NumVecToArray(arrow::int32(), {32, 12, 4}); + auto z = arrow::compute::CallFunction("multiply", {x, y}); + UNIT_ASSERT(FilterTest({x, y, z->make_array()}, EOperation::Multiply, EOperation::Equal) == 3); + } + + Y_UNIT_TEST(ProgramStepDivide) { + auto x = NumVecToArray(arrow::int32(), {10, 34, 8}); + auto y = NumVecToArray(arrow::int32(), {32, 12, 4}); + auto z = arrow::compute::CallFunction("divide", {x, y}); + UNIT_ASSERT(FilterTest({x, y, z->make_array()}, EOperation::Divide, EOperation::Equal) == 3); + } + + Y_UNIT_TEST(ProgramStepGcd) { + auto x = NumVecToArray(arrow::int32(), {64, 16, 8}); + auto y = NumVecToArray(arrow::int32(), {32, 12, 4}); + auto z = arrow::compute::CallFunction("gcd", {x, y}, GetCustomExecContext()); + UNIT_ASSERT(FilterTest({x, y, z->make_array()}, EOperation::Gcd, EOperation::Equal) == 3); + } + + Y_UNIT_TEST(ProgramStepLcm) { + auto x = NumVecToArray(arrow::int32(), {64, 16, 8}); + auto y = NumVecToArray(arrow::int32(), {32, 12, 4}); + auto z = arrow::compute::CallFunction("lcm", {x, y}, GetCustomExecContext()); + UNIT_ASSERT(FilterTest({x, y, z->make_array()}, EOperation::Lcm, EOperation::Equal) == 3); + } + + Y_UNIT_TEST(ProgramStepMod) { + auto x = NumVecToArray(arrow::int32(), {64, 16, 8}); + auto y = NumVecToArray(arrow::int32(), {3, 5, 2}); + auto z = arrow::compute::CallFunction("mod", {x, y}, GetCustomExecContext()); + UNIT_ASSERT(FilterTest({x, y, z->make_array()}, EOperation::Modulo, EOperation::Equal) == 3); + } + + Y_UNIT_TEST(ProgramStepModOrZero) { + auto x = NumVecToArray(arrow::int32(), {64, 16, 8}); + auto y = NumVecToArray(arrow::int32(), {3, 5, 0}); + auto z = arrow::compute::CallFunction("modOrZero", {x, y}, GetCustomExecContext()); + UNIT_ASSERT(FilterTest({x, y, z->make_array()}, EOperation::ModuloOrZero, EOperation::Equal) == 3); + } + + Y_UNIT_TEST(ProgramStepAbs) { + auto x = NumVecToArray(arrow::int32(), {-64, -16, 8}); + auto z = arrow::compute::CallFunction("abs", {x}); + UNIT_ASSERT(FilterTestUnary({x, z->make_array()}, EOperation::Abs, EOperation::Equal) == 3); + } + + Y_UNIT_TEST(ProgramStepNegate) { + auto x = NumVecToArray(arrow::int32(), {-64, -16, 8}); + auto z = arrow::compute::CallFunction("negate", {x}); + UNIT_ASSERT(FilterTestUnary({x, z->make_array()}, EOperation::Negate, EOperation::Equal) == 3); + } + + Y_UNIT_TEST(ProgramStepCompares) { + for (auto eop : {EOperation::Equal, EOperation::Less, EOperation::Greater, EOperation::GreaterEqual, + EOperation::LessEqual, EOperation::NotEqual}) { + auto x = NumVecToArray(arrow::int32(), {64, 5, 1}); + auto y = NumVecToArray(arrow::int32(), {64, 1, 5}); + auto z = arrow::compute::CallFunction(GetFunctionName(eop), {x, y}); + UNIT_ASSERT(FilterTest({x, y, z->make_array()}, eop, EOperation::Equal) == 3); + } + } + + Y_UNIT_TEST(ProgramStepLogic0) { + for (auto eop : {EOperation::And, EOperation::Or, EOperation::Xor}) { + auto x = BoolVecToArray({true, false, false}); + auto y = BoolVecToArray({true, true, false}); + auto z = arrow::compute::CallFunction(GetFunctionName(eop), {x, y}); + UNIT_ASSERT(FilterTest({x, y, z->make_array()}, eop, EOperation::Equal) == 3); + } + } + + Y_UNIT_TEST(ProgramStepLogic1) { + auto x = BoolVecToArray({true, false, false}); + auto z = arrow::compute::CallFunction("invert", {x}); + UNIT_ASSERT(FilterTestUnary({x, z->make_array()}, EOperation::Invert, EOperation::Equal) == 3); + } + + Y_UNIT_TEST(ProgramStepScalarTest) { + auto schema = std::make_shared<arrow::Schema>(std::vector{ + std::make_shared<arrow::Field>("x", arrow::int64()), + std::make_shared<arrow::Field>("filter", arrow::boolean())}); + auto rBatch = arrow::RecordBatch::Make(schema, 4, std::vector{NumVecToArray(arrow::int64(), {64, 5, 1, 43}), + BoolVecToArray({true, false, false, true})}); + auto ps = std::make_shared<TProgramStep>(); + ps->Assignes = {TAssign("y", 56), TAssign("res", EOperation::Add, {"x", "y"})}; + ps->Filters = {"filter"}; + ps->Projection = {"res", "filter"}; + ApplyProgram(rBatch, {ps}, GetCustomExecContext()); + UNIT_ASSERT(rBatch->ValidateFull().ok()); + UNIT_ASSERT(rBatch->num_columns() == 2); + UNIT_ASSERT(rBatch->num_rows() == 2); + } + + Y_UNIT_TEST(ProgramStepEmptyFilter) { + auto schema = std::make_shared<arrow::Schema>(std::vector{ + std::make_shared<arrow::Field>("x", arrow::int64()), + std::make_shared<arrow::Field>("filter", arrow::boolean())}); + auto rBatch = arrow::RecordBatch::Make(schema, 4, std::vector{NumVecToArray(arrow::int64(), {64, 5, 1, 43}), + BoolVecToArray({true, false, false, true})}); + auto ps = std::make_shared<TProgramStep>(); + ps->Assignes = {TAssign("y", 56), TAssign("res", EOperation::Add, {"x", "y"})}; + ps->Filters = {}; + ps->Projection = {"res", "filter"}; + ApplyProgram(rBatch, {ps}, GetCustomExecContext()); + UNIT_ASSERT(rBatch->ValidateFull().ok()); + UNIT_ASSERT(rBatch->num_columns() == 2); + UNIT_ASSERT(rBatch->num_rows() == 4); + } +} + +} diff --git a/ydb/core/formats/ut_round.cpp b/ydb/core/formats/ut_round.cpp index 7b639b0888..05fc6f3879 100644 --- a/ydb/core/formats/ut_round.cpp +++ b/ydb/core/formats/ut_round.cpp @@ -1,53 +1,53 @@ -#include <cmath> -#include <cstdint> -#include <iterator> -#include <library/cpp/testing/unittest/registar.h> -#include <ctime> -#include <vector> -#include <algorithm> - -#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h> - -#include "func_common.h" -#include "functions.h" -#include "custom_registry.h" -#include "arrow_helpers.h" - - +#include <cmath> +#include <cstdint> +#include <iterator> +#include <library/cpp/testing/unittest/registar.h> +#include <ctime> +#include <vector> +#include <algorithm> + +#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h> + +#include "func_common.h" +#include "functions.h" +#include "custom_registry.h" +#include "arrow_helpers.h" + + namespace NKikimr::NArrow { - -namespace cp = ::arrow::compute; - - -Y_UNIT_TEST_SUITE(RoundsTest) { - Y_UNIT_TEST(RoundTest) { - for (auto ty : cp::internal::FloatingPointTypes()) { - auto arg = NumVecToArray(ty, {2.34, 5.65, 10.01, 100.0}); - auto expRes = NumVecToArray(ty, {2, 6, 10, 100}); - auto res = arrow::compute::CallFunction(TRound::Name, {arg}, GetCustomExecContext()); - UNIT_ASSERT(res->Equals(expRes)); - } - } - - Y_UNIT_TEST(RoundBankersTest) { - for (auto ty : cp::internal::FloatingPointTypes()) { - auto arg = NumVecToArray(ty, {2.34, 5.5, 6.5, 100.7}); - auto expRes = NumVecToArray(ty, {2, 6, 6, 101}); - auto res = arrow::compute::CallFunction(TRoundBankers::Name, {arg}, GetCustomExecContext()); - UNIT_ASSERT(res->Equals(expRes)); - } - } - - Y_UNIT_TEST(RoundToExp2Test) { - for (auto ty : cp::internal::NumericTypes()) { - auto arg = NumVecToArray(ty, {2.34, 5.5, 6.5, 100.7, 54}); - auto expRes = NumVecToArray(ty, {2, 4, 4, 64, 32}); - auto res = arrow::compute::CallFunction(TRoundToExp2::Name, {arg}, GetCustomExecContext()); - UNIT_ASSERT(res->Equals(expRes)); - } - } - -} - -} + +namespace cp = ::arrow::compute; + + +Y_UNIT_TEST_SUITE(RoundsTest) { + Y_UNIT_TEST(RoundTest) { + for (auto ty : cp::internal::FloatingPointTypes()) { + auto arg = NumVecToArray(ty, {2.34, 5.65, 10.01, 100.0}); + auto expRes = NumVecToArray(ty, {2, 6, 10, 100}); + auto res = arrow::compute::CallFunction(TRound::Name, {arg}, GetCustomExecContext()); + UNIT_ASSERT(res->Equals(expRes)); + } + } + + Y_UNIT_TEST(RoundBankersTest) { + for (auto ty : cp::internal::FloatingPointTypes()) { + auto arg = NumVecToArray(ty, {2.34, 5.5, 6.5, 100.7}); + auto expRes = NumVecToArray(ty, {2, 6, 6, 101}); + auto res = arrow::compute::CallFunction(TRoundBankers::Name, {arg}, GetCustomExecContext()); + UNIT_ASSERT(res->Equals(expRes)); + } + } + + Y_UNIT_TEST(RoundToExp2Test) { + for (auto ty : cp::internal::NumericTypes()) { + auto arg = NumVecToArray(ty, {2.34, 5.5, 6.5, 100.7, 54}); + auto expRes = NumVecToArray(ty, {2, 4, 4, 64, 32}); + auto res = arrow::compute::CallFunction(TRoundToExp2::Name, {arg}, GetCustomExecContext()); + UNIT_ASSERT(res->Equals(expRes)); + } + } + +} + +} |