aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/dot_product/dot_product.h
blob: 0765633abdbfd6fb719b857123eb482aac6df6bb (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#pragma once

#include <util/system/types.h>
#include <util/system/compiler.h>

#include <numeric>

/**
 * Dot product (also known as inner product or scalar product) implementation
 * that uses SIMD instructions (e.g. SSE) when possible.
 */
namespace NDotProductImpl {
    // Function pointers that the public `DotProduct` overloads in this header call through,
    // one per supported element type. They are only declared here (`extern`); the definitions,
    // and the choice of which implementation each pointer refers to, live outside this header.
    extern i32 (*DotProductI8Impl)(const i8* lhs, const i8* rhs, size_t length) noexcept;
    extern ui32 (*DotProductUi8Impl)(const ui8* lhs, const ui8* rhs, size_t length) noexcept;
    extern i64 (*DotProductI32Impl)(const i32* lhs, const i32* rhs, size_t length) noexcept;
    extern float (*DotProductFloatImpl)(const float* lhs, const float* rhs, size_t length) noexcept;
    extern double (*DotProductDoubleImpl)(const double* lhs, const double* rhs, size_t length) noexcept;
}

/**
 * Dot product of two `i8` arrays.
 *
 * Calls the function that `NDotProductImpl::DotProductI8Impl` points to at the time of the call.
 *
 * @param lhs     first operand
 * @param rhs     second operand
 * @param length  forwarded unchanged; presumably the number of elements in each array
 * @return        the `i32` value produced by the selected implementation
 */
Y_PURE_FUNCTION
inline i32 DotProduct(const i8* lhs, const i8* rhs, size_t length) noexcept {
    const auto impl = NDotProductImpl::DotProductI8Impl;
    return impl(lhs, rhs, length);
}

/**
 * Dot product of two `ui8` arrays.
 *
 * Calls the function that `NDotProductImpl::DotProductUi8Impl` points to at the time of the call.
 *
 * @param lhs     first operand
 * @param rhs     second operand
 * @param length  forwarded unchanged; presumably the number of elements in each array
 * @return        the `ui32` value produced by the selected implementation
 */
Y_PURE_FUNCTION
inline ui32 DotProduct(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
    const auto impl = NDotProductImpl::DotProductUi8Impl;
    return impl(lhs, rhs, length);
}

/**
 * Dot product of two `i32` arrays.
 *
 * Calls the function that `NDotProductImpl::DotProductI32Impl` points to at the time of the call.
 *
 * @param lhs     first operand
 * @param rhs     second operand
 * @param length  forwarded unchanged; presumably the number of elements in each array
 * @return        the `i64` value produced by the selected implementation
 */
Y_PURE_FUNCTION
inline i64 DotProduct(const i32* lhs, const i32* rhs, size_t length) noexcept {
    const auto impl = NDotProductImpl::DotProductI32Impl;
    return impl(lhs, rhs, length);
}

/**
 * Dot product of two `float` arrays.
 *
 * Calls the function that `NDotProductImpl::DotProductFloatImpl` points to at the time of the call.
 *
 * @param lhs     first operand
 * @param rhs     second operand
 * @param length  forwarded unchanged; presumably the number of elements in each array
 * @return        the `float` value produced by the selected implementation
 */
Y_PURE_FUNCTION
inline float DotProduct(const float* lhs, const float* rhs, size_t length) noexcept {
    const auto impl = NDotProductImpl::DotProductFloatImpl;
    return impl(lhs, rhs, length);
}

/**
 * Dot product of two `double` arrays.
 *
 * Calls the function that `NDotProductImpl::DotProductDoubleImpl` points to at the time of the call.
 *
 * @param lhs     first operand
 * @param rhs     second operand
 * @param length  forwarded unchanged; presumably the number of elements in each array
 * @return        the `double` value produced by the selected implementation
 */
Y_PURE_FUNCTION
inline double DotProduct(const double* lhs, const double* rhs, size_t length) noexcept {
    const auto impl = NDotProductImpl::DotProductDoubleImpl;
    return impl(lhs, rhs, length);
}

/**
 * Dot product of a vector with itself, i.e. the squared L2 norm.
 *
 * Only declared in this header; the definition lives elsewhere.
 *
 * @param v       input vector
 * @param length  presumably the number of elements in `v` -- TODO confirm against the definition
 */
Y_PURE_FUNCTION
float L2NormSquared(const float* v, size_t length) noexcept;

// TODO(yazevnul): add an `L2NormSquared` overload for double; it should be faster than
// `DotProduct` with `lhs == rhs` because it saves N load instructions.

/**
 * Result of a three-way dot product of two vectors L and R.
 *
 * NOTE(review): `LL` and `RR` default to 1 rather than 0 -- presumably so that a product
 * that was not computed can still be used as a normalization divisor; confirm with the
 * implementation of `TriWayDotProduct`.
 */
template <typename T>
struct TTriWayDotProduct {
    T LL = 1; // L·L
    T LR = 0; // L·R
    T RR = 1; // R·R
};

/**
 * Bit mask naming the products of `TTriWayDotProduct`; each basic value occupies one bit,
 * and the combined values are ORs of the basic ones.
 */
enum class ETriWayDotProductComputeMask: unsigned {
    // one bit per product
    LL = 1u << 2,
    LR = 1u << 1,
    RR = 1u << 0,

    // frequently used combinations
    All = LL | LR | RR,
    Left = LL | LR, // skip computation of R·R
    Right = LR | RR, // skip computation of L·L
};

/**
 * Overload of `TriWayDotProduct` that takes the compute mask as a raw `unsigned` value.
 *
 * The overload accepting `ETriWayDotProductComputeMask` forwards here after a
 * `static_cast<unsigned>`, so `mask` is expected to use that enum's bit values.
 * Only declared in this header; the definition lives elsewhere.
 */
Y_PURE_FUNCTION
TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept;

/**
 * Computes three dot products for a pair of vectors L and R: L·L, L·R and R·R.
 *
 * Converts `mask` to its underlying `unsigned` value and delegates to the overload
 * that accepts a raw mask.
 *
 * @param lhs     vector L
 * @param rhs     vector R
 * @param length  forwarded unchanged; presumably the number of elements in each vector
 * @param mask    which products to compute; all three by default
 */
Y_PURE_FUNCTION
static inline TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, ETriWayDotProductComputeMask mask = ETriWayDotProductComputeMask::All) noexcept {
    const unsigned rawMask = static_cast<unsigned>(mask);
    return TriWayDotProduct(lhs, rhs, length, rawMask);
}

namespace NDotProduct {
    /**
     * Simple functor wrapper that allows the overloaded `DotProduct` functions
     * to be used as a template argument.
     *
     * `TResult` is the type returned by `DotProduct(const T*, const T*, size_t)` for the given `T`.
     */
    template <typename T>
    struct TDotProduct {
        using TResult = decltype(DotProduct(static_cast<const T*>(nullptr), static_cast<const T*>(nullptr), 0));

        // Forwards to the matching `DotProduct` overload. The call operator is `noexcept`
        // exactly when that overload is, so the wrapper does not drop the guarantee
        // given by the wrapped function.
        Y_PURE_FUNCTION
        inline TResult operator()(const T* l, const T* r, size_t length) const noexcept(noexcept(DotProduct(l, r, length))) {
            return DotProduct(l, r, length);
        }
    };

    // Only declared here; the definition lives elsewhere.
    // NOTE(review): judging by the name, this turns off the AVX2 code paths -- confirm against the definition.
    void DisableAvx2();
}