1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
|
#include "dot_product_simple.h"
namespace {
    // Portable scalar dot product of two arrays of `length` elements.
    //
    // Every element is converted to Res before multiplying, so Res decides the
    // precision/width of the whole computation. Elements are consumed in groups
    // of four, each position within a group feeding its own partial sum; the
    // leftover (length % 4) elements are folded into the first partial sum.
    // The four partial sums are added together, left to right, at the end.
    template <typename Res, typename Number>
    static Res DotProductSimpleImpl(const Number* lhs, const Number* rhs, size_t length) noexcept {
        Res lane0 = 0;
        Res lane1 = 0;
        Res lane2 = 0;
        Res lane3 = 0;

        // Largest multiple of four that does not exceed length.
        const size_t groupedLength = length - length % 4;

        size_t pos = 0;
        for (; pos < groupedLength; pos += 4) {
            lane0 += static_cast<Res>(lhs[pos]) * static_cast<Res>(rhs[pos]);
            lane1 += static_cast<Res>(lhs[pos + 1]) * static_cast<Res>(rhs[pos + 1]);
            lane2 += static_cast<Res>(lhs[pos + 2]) * static_cast<Res>(rhs[pos + 2]);
            lane3 += static_cast<Res>(lhs[pos + 3]) * static_cast<Res>(rhs[pos + 3]);
        }

        // Tail: at most three elements remain, all accumulated into lane0.
        for (; pos < length; ++pos) {
            lane0 += static_cast<Res>(lhs[pos]) * static_cast<Res>(rhs[pos]);
        }

        return lane0 + lane1 + lane2 + lane3;
    }
}
// Dot product of two float arrays of `length` elements, accumulated in float
// by the scalar DotProductSimpleImpl helper.
float DotProductSimple(const float* lhs, const float* rhs, size_t length) noexcept {
    // Number is deduced as float from the pointer arguments.
    const float dot = DotProductSimpleImpl<float>(lhs, rhs, length);
    return dot;
}
// Dot product of two double arrays of `length` elements, accumulated in double
// by the scalar DotProductSimpleImpl helper.
double DotProductSimple(const double* lhs, const double* rhs, size_t length) noexcept {
    // Number is deduced as double from the pointer arguments.
    const double dot = DotProductSimpleImpl<double>(lhs, rhs, length);
    return dot;
}
// Dot product over nibbles: every byte of lhs and rhs is treated as two 4-bit
// unsigned values (bits 0-3 and bits 4-7). For each of the `lengtInBytes`
// bytes, the low nibbles are multiplied together and the high nibbles are
// multiplied together; both products are added to a ui32 running total.
ui32 DotProductUI4Simple(const ui8* lhs, const ui8* rhs, size_t lengtInBytes) noexcept {
    ui32 total = 0;
    const ui8* const lhsEnd = lhs + lengtInBytes;
    while (lhs != lhsEnd) {
        const ui32 left = *lhs++;
        const ui32 right = *rhs++;
        // Low nibbles.
        total += (left & 0x0f) * (right & 0x0f);
        // High nibbles, shifted down to the 0..15 range before multiplying.
        total += (left >> 4) * (right >> 4);
    }
    return total;
}
// Single-pass computation of up to three float dot products over `length`
// elements: lhs·lhs (LL), lhs·rhs (LR) and, only when computeRR is true,
// rhs·rhs (RR). When computeRR is false, RR is set to the value carried by a
// default-constructed TTriWayDotProduct<float>.
TTriWayDotProduct<float> TriWayDotProductSimple(
    const float* lhs,
    const float* rhs,
    size_t length,
    bool computeRR) noexcept
{
    float accLL = 0.0f;
    float accLR = 0.0f;
    float accRR = 0.0f;

    const float* const lhsEnd = lhs + length;
    for (; lhs != lhsEnd; ++lhs, ++rhs) {
        const float left = *lhs;
        const float right = *rhs;
        accLL += left * left;
        accLR += left * right;
        if (computeRR) {
            accRR += right * right;
        }
    }

    TTriWayDotProduct<float> products;
    products.LL = accLL;
    products.LR = accLR;
    if (!computeRR) {
        // RR was not requested: report the struct's default RR value.
        static constexpr TTriWayDotProduct<float> defaults;
        products.RR = defaults.RR;
        return products;
    }
    products.RR = accRR;
    return products;
}
|