aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/Functions/array/arrayDotProduct.cpp
blob: 47e865785d427a1cb4e0dafb055c64c3eef7678d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Core/Types_fwd.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <Functions/castTypeToEither.h>
#include <Functions/array/arrayScalarProduct.h>
#include <base/types.h>
#include <Functions/FunctionBinaryArithmetic.h>


namespace DB
{

/// Error codes are only declared (`extern`) here; their definitions live in another translation unit.
namespace ErrorCodes { extern const int ILLEGAL_TYPE_OF_ARGUMENT; }

/// Compile-time name of the SQL function. It is supplied as the second template argument
/// of FunctionArrayScalarProduct and is also formatted into the error message thrown by
/// ArrayDotProductImpl::getReturnType.
struct NameArrayDotProduct
{
    static constexpr const char * name = "arrayDotProduct";
};

/// Kernel plugged into FunctionArrayScalarProduct for arrayDotProduct:
///  - getReturnType() picks the numeric result type for a pair of array element types;
///  - apply() computes sum(left[i] * right[i]) over two element buffers.
class ArrayDotProductImpl
{
public:
    /// Returns the result type of arrayDotProduct for element types `left` and `right`.
    ///
    /// Both types must be one of the native numeric types listed in `Types` below.
    /// If castTypeToEither cannot match either of them, ILLEGAL_TYPE_OF_ARGUMENT is thrown.
    ///
    /// The result type is NumberTraits::ResultOfAdditionMultiplication<Left, Right>.
    /// The one exception is Float32 x Float32, which is explicitly kept as Float32.
    static DataTypePtr getReturnType(const DataTypePtr & left, const DataTypePtr & right)
    {
        using Types = TypeList<DataTypeFloat32, DataTypeFloat64,
                               DataTypeUInt8, DataTypeUInt16, DataTypeUInt32, DataTypeUInt64,
                               DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64>;

        DataTypePtr result_type;
        bool valid = castTypeToEither(Types{}, left.get(), [&](const auto & left_)
        {
            return castTypeToEither(Types{}, right.get(), [&](const auto & right_)
            {
                using LeftDataType = typename std::decay_t<decltype(left_)>::FieldType;
                using RightDataType = typename std::decay_t<decltype(right_)>::FieldType;

                /// The condition depends only on template types. With `if constexpr` only
                /// the branch selected for this (left, right) pair is instantiated.
                if constexpr (std::is_same_v<LeftDataType, Float32> && std::is_same_v<RightDataType, Float32>)
                {
                    result_type = std::make_shared<DataTypeFloat32>();
                }
                else
                {
                    using ResultType = typename NumberTraits::ResultOfAdditionMultiplication<LeftDataType, RightDataType>::Type;
                    result_type = std::make_shared<DataTypeFromFieldType<ResultType>>();
                }
                return true;
            });
        });

        if (!valid)
            throw Exception(
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                "Arguments of function {} "
                "only support: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.",
                std::string(NameArrayDotProduct::name));
        return result_type;
    }

    /// Returns the sum over i in [0, size) of left[i] * right[i].
    /// Each operand is converted to ResultType before multiplying; the result is 0 when size == 0.
    /// Assumes both buffers hold at least `size` elements — presumably guaranteed by the
    /// caller (FunctionArrayScalarProduct); not checked here.
    /// NO_SANITIZE_UNDEFINED: integer accumulation may overflow. The macro presumably stops
    /// UBSan from reporting that for this function.
    template <typename ResultType, typename T, typename U>
    static inline NO_SANITIZE_UNDEFINED ResultType apply(
        const T * left,
        const U * right,
        size_t size)
    {
        ResultType result = 0;
        for (size_t i = 0; i < size; ++i)
            result += static_cast<ResultType>(left[i]) * static_cast<ResultType>(right[i]);
        return result;
    }
};

/// The arrayDotProduct SQL function: the generic FunctionArrayScalarProduct template
/// instantiated with the dot-product kernel (ArrayDotProductImpl) and its name
/// (NameArrayDotProduct).
using FunctionArrayDotProduct
    = FunctionArrayScalarProduct<ArrayDotProductImpl, NameArrayDotProduct>;

/// Registers FunctionArrayDotProduct with the function factory.
/// NOTE(review): `factory` is not declared in this file. It is presumably introduced by
/// the REGISTER_FUNCTION macro expansion — confirm in FunctionFactory.h.
REGISTER_FUNCTION(ArrayDotProduct)
{
    factory.registerFunction<FunctionArrayDotProduct>();
}

// Factory hook used by TupleOrArrayFunction in Functions/vectorFunctions.cpp (outside this file).
// It returns a fresh FunctionArrayDotProduct bound to the given context.
FunctionPtr createFunctionArrayDotProduct(ContextPtr context_)
{
    return FunctionArrayDotProduct::create(context_);
}
}