diff options
| author | vitalyisaev <[email protected]> | 2023-11-14 09:58:56 +0300 |
|---|---|---|
| committer | vitalyisaev <[email protected]> | 2023-11-14 10:20:20 +0300 |
| commit | c2b2dfd9827a400a8495e172a56343462e3ceb82 (patch) | |
| tree | cd4e4f597d01bede4c82dffeb2d780d0a9046bd0 /contrib/clickhouse/src/Functions/array/arrayDotProduct.cpp | |
| parent | d4ae8f119e67808cb0cf776ba6e0cf95296f2df7 (diff) | |
YQ Connector: move tests from yql to ydb (OSS)
Перенос папки с тестами на Коннектор из папки yql в папку ydb (синхронизируется с github).
Diffstat (limited to 'contrib/clickhouse/src/Functions/array/arrayDotProduct.cpp')
| -rw-r--r-- | contrib/clickhouse/src/Functions/array/arrayDotProduct.cpp | 80 |
1 files changed, 80 insertions, 0 deletions
diff --git a/contrib/clickhouse/src/Functions/array/arrayDotProduct.cpp b/contrib/clickhouse/src/Functions/array/arrayDotProduct.cpp new file mode 100644 index 00000000000..47e865785d4 --- /dev/null +++ b/contrib/clickhouse/src/Functions/array/arrayDotProduct.cpp @@ -0,0 +1,80 @@ +#include <DataTypes/DataTypesNumber.h> +#include <Functions/FunctionFactory.h> +#include <Core/Types_fwd.h> +#include <DataTypes/Serializations/ISerialization.h> +#include <Functions/castTypeToEither.h> +#include <Functions/array/arrayScalarProduct.h> +#include <base/types.h> +#include <Functions/FunctionBinaryArithmetic.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; +} + +struct NameArrayDotProduct +{ + static constexpr auto name = "arrayDotProduct"; +}; + +class ArrayDotProductImpl +{ +public: + static DataTypePtr getReturnType(const DataTypePtr & left, const DataTypePtr & right) + { + using Types = TypeList<DataTypeFloat32, DataTypeFloat64, + DataTypeUInt8, DataTypeUInt16, DataTypeUInt32, DataTypeUInt64, + DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64>; + + DataTypePtr result_type; + bool valid = castTypeToEither(Types{}, left.get(), [&](const auto & left_) + { + return castTypeToEither(Types{}, right.get(), [&](const auto & right_) + { + using LeftDataType = typename std::decay_t<decltype(left_)>::FieldType; + using RightDataType = typename std::decay_t<decltype(right_)>::FieldType; + using ResultType = typename NumberTraits::ResultOfAdditionMultiplication<LeftDataType, RightDataType>::Type; + if (std::is_same_v<LeftDataType, Float32> && std::is_same_v<RightDataType, Float32>) + result_type = std::make_shared<DataTypeFloat32>(); + else + result_type = std::make_shared<DataTypeFromFieldType<ResultType>>(); + return true; + }); + }); + + if (!valid) + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Arguments of function {} " + "only support: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", + std::string(NameArrayDotProduct::name)); + return result_type; + } + + template <typename ResultType, typename T, typename U> + static inline NO_SANITIZE_UNDEFINED ResultType apply( + const T * left, + const U * right, + size_t size) + { + ResultType result = 0; + for (size_t i = 0; i < size; ++i) + result += static_cast<ResultType>(left[i]) * static_cast<ResultType>(right[i]); + return result; + } +}; + +using FunctionArrayDotProduct = FunctionArrayScalarProduct<ArrayDotProductImpl, NameArrayDotProduct>; + +REGISTER_FUNCTION(ArrayDotProduct) +{ + factory.registerFunction<FunctionArrayDotProduct>(); +} + +// These functions are used by TupleOrArrayFunction in Function/vectorFunctions.cpp +FunctionPtr createFunctionArrayDotProduct(ContextPtr context_) { return FunctionArrayDotProduct::create(context_); } +} |
