summaryrefslogtreecommitdiffstats
path: root/library/cpp/dot_product/dot_product_simple.cpp
blob: 50afcfe62a3b8a715236e8fd0c3da39a3c0c0840 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include "dot_product_simple.h"

namespace {
    // Portable scalar dot product of two arrays holding `length` elements each.
    //
    // Every element is converted to `Res` before multiplying, so `Res` decides
    // the precision/width of both the products and the running sums.
    //
    // The main loop consumes four elements per step and keeps four independent
    // partial sums; leftover elements (fewer than four) are folded into the
    // first partial sum. The partial sums are combined left to right on return.
    // An empty input (`length == 0`) yields zero.
    template <typename Res, typename Number>
    Res DotProductSimpleImpl(const Number* lhs, const Number* rhs, size_t length) noexcept {
        Res lane0 = 0;
        Res lane1 = 0;
        Res lane2 = 0;
        Res lane3 = 0;

        // Largest multiple of four that does not exceed `length`.
        const size_t blockedEnd = length - (length % 4);

        size_t pos = 0;
        for (; pos < blockedEnd; pos += 4) {
            lane0 += static_cast<Res>(lhs[pos]) * static_cast<Res>(rhs[pos]);
            lane1 += static_cast<Res>(lhs[pos + 1]) * static_cast<Res>(rhs[pos + 1]);
            lane2 += static_cast<Res>(lhs[pos + 2]) * static_cast<Res>(rhs[pos + 2]);
            lane3 += static_cast<Res>(lhs[pos + 3]) * static_cast<Res>(rhs[pos + 3]);
        }

        // Tail: the remaining 0..3 elements all go into the first lane.
        for (; pos < length; ++pos) {
            lane0 += static_cast<Res>(lhs[pos]) * static_cast<Res>(rhs[pos]);
        }

        return lane0 + lane1 + lane2 + lane3;
    }
}

// Single-precision dot product of `lhs` and `rhs`, each holding `length`
// floats. Products and sums are computed in `float` by the shared scalar
// implementation.
float DotProductSimple(const float* lhs, const float* rhs, size_t length) noexcept {
    // The element type is deduced from the pointer arguments.
    const float dot = DotProductSimpleImpl<float>(lhs, rhs, length);
    return dot;
}

// Double-precision dot product of `lhs` and `rhs`, each holding `length`
// doubles. Products and sums are computed in `double` by the shared scalar
// implementation.
double DotProductSimple(const double* lhs, const double* rhs, size_t length) noexcept {
    // The element type is deduced from the pointer arguments.
    const double dot = DotProductSimpleImpl<double>(lhs, rhs, length);
    return dot;
}

// Dot product over 4-bit fields packed two per byte.
//
// For each of the `lengtInBytes` byte pairs, the low nibbles of `lhs[i]` and
// `rhs[i]` are multiplied together, the high nibbles are multiplied together,
// and both products are added to the running ui32 total, which is returned.
ui32 DotProductUI4Simple(const ui8* lhs, const ui8* rhs, size_t lengtInBytes) noexcept {
    ui32 total = 0;
    for (size_t byteIdx = 0; byteIdx != lengtInBytes; ++byteIdx) {
        const ui32 l = lhs[byteIdx];
        const ui32 r = rhs[byteIdx];

        const ui32 lowProduct = (l & 0x0f) * (r & 0x0f);
        // Same value as ((l & 0xf0) * (r & 0xf0)) >> 8: each masked operand is
        // its high nibble times 16, so the product carries a factor of 256.
        const ui32 highProduct = (l >> 4) * (r >> 4);

        total += lowProduct;
        total += highProduct;
    }
    return total;
}

// Computes up to three dot products over the same pair of float arrays in a
// single pass: lhs·lhs (stored in LL), lhs·rhs (stored in LR) and, only when
// `computeRR` is true, rhs·rhs (stored in RR).
//
// When `computeRR` is false, rhs·rhs is not accumulated and RR is set to the
// RR value of a default-constructed TTriWayDotProduct<float>.
TTriWayDotProduct<float> TriWayDotProductSimple(
    const float* lhs,
    const float* rhs,
    size_t length,
    bool computeRR) noexcept
{
    // Source of the fallback RR value used when rhs·rhs is not requested.
    static constexpr TTriWayDotProduct<float> defaults;

    float accLL = 0.0f;
    float accLR = 0.0f;
    float accRR = 0.0f;

    size_t pos = 0;
    while (pos < length) {
        const float left = lhs[pos];
        const float right = rhs[pos];
        ++pos;

        accLL += left * left;
        accLR += left * right;
        if (computeRR) {
            accRR += right * right;
        }
    }

    TTriWayDotProduct<float> out;
    out.LL = accLL;
    out.LR = accLR;
    out.RR = computeRR ? accRR : defaults.RR;
    return out;
}