summaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/Functions/array/arrayDotProduct.cpp
diff options
context:
space:
mode:
authorvitalyisaev <[email protected]>2023-11-14 09:58:56 +0300
committervitalyisaev <[email protected]>2023-11-14 10:20:20 +0300
commitc2b2dfd9827a400a8495e172a56343462e3ceb82 (patch)
treecd4e4f597d01bede4c82dffeb2d780d0a9046bd0 /contrib/clickhouse/src/Functions/array/arrayDotProduct.cpp
parentd4ae8f119e67808cb0cf776ba6e0cf95296f2df7 (diff)
YQ Connector: move tests from yql to ydb (OSS)
Перенос папки с тестами на Коннектор из папки yql в папку ydb (синхронизируется с github).
Diffstat (limited to 'contrib/clickhouse/src/Functions/array/arrayDotProduct.cpp')
-rw-r--r--contrib/clickhouse/src/Functions/array/arrayDotProduct.cpp80
1 files changed, 80 insertions, 0 deletions
diff --git a/contrib/clickhouse/src/Functions/array/arrayDotProduct.cpp b/contrib/clickhouse/src/Functions/array/arrayDotProduct.cpp
new file mode 100644
index 00000000000..47e865785d4
--- /dev/null
+++ b/contrib/clickhouse/src/Functions/array/arrayDotProduct.cpp
@@ -0,0 +1,80 @@
+#include <DataTypes/DataTypesNumber.h>
+#include <Functions/FunctionFactory.h>
+#include <Core/Types_fwd.h>
+#include <DataTypes/Serializations/ISerialization.h>
+#include <Functions/castTypeToEither.h>
+#include <Functions/array/arrayScalarProduct.h>
+#include <base/types.h>
+#include <Functions/FunctionBinaryArithmetic.h>
+
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+ extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+}
+
+struct NameArrayDotProduct
+{
+ static constexpr auto name = "arrayDotProduct";
+};
+
+class ArrayDotProductImpl
+{
+public:
+ static DataTypePtr getReturnType(const DataTypePtr & left, const DataTypePtr & right)
+ {
+ using Types = TypeList<DataTypeFloat32, DataTypeFloat64,
+ DataTypeUInt8, DataTypeUInt16, DataTypeUInt32, DataTypeUInt64,
+ DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64>;
+
+ DataTypePtr result_type;
+ bool valid = castTypeToEither(Types{}, left.get(), [&](const auto & left_)
+ {
+ return castTypeToEither(Types{}, right.get(), [&](const auto & right_)
+ {
+ using LeftDataType = typename std::decay_t<decltype(left_)>::FieldType;
+ using RightDataType = typename std::decay_t<decltype(right_)>::FieldType;
+ using ResultType = typename NumberTraits::ResultOfAdditionMultiplication<LeftDataType, RightDataType>::Type;
+ if (std::is_same_v<LeftDataType, Float32> && std::is_same_v<RightDataType, Float32>)
+ result_type = std::make_shared<DataTypeFloat32>();
+ else
+ result_type = std::make_shared<DataTypeFromFieldType<ResultType>>();
+ return true;
+ });
+ });
+
+ if (!valid)
+ throw Exception(
+ ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
+ "Arguments of function {} "
+ "only support: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.",
+ std::string(NameArrayDotProduct::name));
+ return result_type;
+ }
+
+ template <typename ResultType, typename T, typename U>
+ static inline NO_SANITIZE_UNDEFINED ResultType apply(
+ const T * left,
+ const U * right,
+ size_t size)
+ {
+ ResultType result = 0;
+ for (size_t i = 0; i < size; ++i)
+ result += static_cast<ResultType>(left[i]) * static_cast<ResultType>(right[i]);
+ return result;
+ }
+};
+
+using FunctionArrayDotProduct = FunctionArrayScalarProduct<ArrayDotProductImpl, NameArrayDotProduct>;
+
+REGISTER_FUNCTION(ArrayDotProduct)
+{
+ factory.registerFunction<FunctionArrayDotProduct>();
+}
+
+// These functions are used by TupleOrArrayFunction in Function/vectorFunctions.cpp
+FunctionPtr createFunctionArrayDotProduct(ContextPtr context_) { return FunctionArrayDotProduct::create(context_); }
+}