aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp
diff options
context:
space:
mode:
authorazevaykin <azevaykin@yandex-team.com>2024-03-26 18:35:08 +0300
committerazevaykin <azevaykin@yandex-team.com>2024-03-26 18:49:55 +0300
commitce034cd07e7ebbff42723e51067bac58d1943f47 (patch)
tree32ff4320767eecab50c7840d533224f9b710acd3 /library/cpp
parentf4ad51b5e6cbbe7d4970e6777f2ab415a25edba7 (diff)
downloadydb-ce034cd07e7ebbff42723e51067bac58d1943f47.tar.gz
Publish library/cpp/dot_product
Publish pod_product to https://github.com/ydb-platform/ydb It has already been published to github: https://github.com/catboost/catboost/tree/master/library/cpp/dot_product 44150e7508881f4239c960f90320799b1b090072
Diffstat (limited to 'library/cpp')
-rw-r--r--library/cpp/dot_product/README.md15
-rw-r--r--library/cpp/dot_product/dot_product.cpp274
-rw-r--r--library/cpp/dot_product/dot_product.h96
-rw-r--r--library/cpp/dot_product/dot_product_avx2.cpp344
-rw-r--r--library/cpp/dot_product/dot_product_avx2.h19
-rw-r--r--library/cpp/dot_product/dot_product_simple.cpp44
-rw-r--r--library/cpp/dot_product/dot_product_simple.h40
-rw-r--r--library/cpp/dot_product/dot_product_sse.cpp219
-rw-r--r--library/cpp/dot_product/dot_product_sse.h19
-rw-r--r--library/cpp/dot_product/ya.make20
10 files changed, 1090 insertions, 0 deletions
diff --git a/library/cpp/dot_product/README.md b/library/cpp/dot_product/README.md
new file mode 100644
index 0000000000..516dcf31de
--- /dev/null
+++ b/library/cpp/dot_product/README.md
@@ -0,0 +1,15 @@
+Библиотека для вычисления скалярного произведения векторов.
+=====================================================
+
+Данная библиотека содержит функцию DotProduct, вычисляющую скалярное произведение векторов различных типов.
+В отличии от наивной реализации, библиотека использует SSE и работает существенно быстрее. Для сравнения
+можно посмотреть результаты бенчмарка.
+
+Типичное использование - замена кусков кода вроде:
+```
+for (int i = 0; i < len; i++)
+ dot_product += a[i] * b[i]);
+```
+на существенно более эффективный вызов ```DotProduct(a, b, len)```.
+
+Работает для типов i8, i32, float, double.
diff --git a/library/cpp/dot_product/dot_product.cpp b/library/cpp/dot_product/dot_product.cpp
new file mode 100644
index 0000000000..6be4d0a78f
--- /dev/null
+++ b/library/cpp/dot_product/dot_product.cpp
@@ -0,0 +1,274 @@
+#include "dot_product.h"
+#include "dot_product_sse.h"
+#include "dot_product_avx2.h"
+#include "dot_product_simple.h"
+
+#include <library/cpp/sse/sse.h>
+#include <library/cpp/testing/common/env.h>
+#include <util/system/compiler.h>
+#include <util/generic/utility.h>
+#include <util/system/cpu_id.h>
+#include <util/system/env.h>
+
+namespace NDotProductImpl {
+ i32 (*DotProductI8Impl)(const i8* lhs, const i8* rhs, size_t length) noexcept = &DotProductSimple;
+ ui32 (*DotProductUi8Impl)(const ui8* lhs, const ui8* rhs, size_t length) noexcept = &DotProductSimple;
+ i64 (*DotProductI32Impl)(const i32* lhs, const i32* rhs, size_t length) noexcept = &DotProductSimple;
+ float (*DotProductFloatImpl)(const float* lhs, const float* rhs, size_t length) noexcept = &DotProductSimple;
+ double (*DotProductDoubleImpl)(const double* lhs, const double* rhs, size_t length) noexcept = &DotProductSimple;
+
+ namespace {
+ [[maybe_unused]] const int _ = [] {
+ if (!FromYaTest() && GetEnv("Y_NO_AVX_IN_DOT_PRODUCT") == "" && NX86::HaveAVX2() && NX86::HaveFMA()) {
+ DotProductI8Impl = &DotProductAvx2;
+ DotProductUi8Impl = &DotProductAvx2;
+ DotProductI32Impl = &DotProductAvx2;
+ DotProductFloatImpl = &DotProductAvx2;
+ DotProductDoubleImpl = &DotProductAvx2;
+ } else {
+#ifdef ARCADIA_SSE
+ DotProductI8Impl = &DotProductSse;
+ DotProductUi8Impl = &DotProductSse;
+ DotProductI32Impl = &DotProductSse;
+ DotProductFloatImpl = &DotProductSse;
+ DotProductDoubleImpl = &DotProductSse;
+#endif
+ }
+ return 0;
+ }();
+ }
+}
+
+#ifdef ARCADIA_SSE
+float L2NormSquared(const float* v, size_t length) noexcept {
+ __m128 sum1 = _mm_setzero_ps();
+ __m128 sum2 = _mm_setzero_ps();
+ __m128 a1, a2, m1, m2;
+
+ while (length >= 8) {
+ a1 = _mm_loadu_ps(v);
+ m1 = _mm_mul_ps(a1, a1);
+
+ a2 = _mm_loadu_ps(v + 4);
+ sum1 = _mm_add_ps(sum1, m1);
+
+ m2 = _mm_mul_ps(a2, a2);
+ sum2 = _mm_add_ps(sum2, m2);
+
+ length -= 8;
+ v += 8;
+ }
+
+ if (length >= 4) {
+ a1 = _mm_loadu_ps(v);
+ sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, a1));
+
+ length -= 4;
+ v += 4;
+ }
+
+ sum1 = _mm_add_ps(sum1, sum2);
+
+ if (length) {
+ switch (length) {
+ case 3:
+ a1 = _mm_set_ps(0.0f, v[2], v[1], v[0]);
+ break;
+
+ case 2:
+ a1 = _mm_set_ps(0.0f, 0.0f, v[1], v[0]);
+ break;
+
+ case 1:
+ a1 = _mm_set_ps(0.0f, 0.0f, 0.0f, v[0]);
+ break;
+
+ default:
+ Y_UNREACHABLE();
+ }
+
+ sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, a1));
+ }
+
+ alignas(16) float res[4];
+ _mm_store_ps(res, sum1);
+
+ return res[0] + res[1] + res[2] + res[3];
+}
+
+template <bool computeLL, bool computeLR, bool computeRR>
+Y_FORCE_INLINE
+static void TriWayDotProductIteration(__m128& sumLL, __m128& sumLR, __m128& sumRR, const __m128 a, const __m128 b) {
+ if constexpr (computeLL) {
+ sumLL = _mm_add_ps(sumLL, _mm_mul_ps(a, a));
+ }
+ if constexpr (computeLR) {
+ sumLR = _mm_add_ps(sumLR, _mm_mul_ps(a, b));
+ }
+ if constexpr (computeRR) {
+ sumRR = _mm_add_ps(sumRR, _mm_mul_ps(b, b));
+ }
+}
+
+
+template <bool computeLL, bool computeLR, bool computeRR>
+static TTriWayDotProduct<float> TriWayDotProductImpl(const float* lhs, const float* rhs, size_t length) noexcept {
+ __m128 sumLL1 = _mm_setzero_ps();
+ __m128 sumLR1 = _mm_setzero_ps();
+ __m128 sumRR1 = _mm_setzero_ps();
+ __m128 sumLL2 = _mm_setzero_ps();
+ __m128 sumLR2 = _mm_setzero_ps();
+ __m128 sumRR2 = _mm_setzero_ps();
+
+ while (length >= 8) {
+ TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, _mm_loadu_ps(lhs + 0), _mm_loadu_ps(rhs + 0));
+ TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL2, sumLR2, sumRR2, _mm_loadu_ps(lhs + 4), _mm_loadu_ps(rhs + 4));
+ length -= 8;
+ lhs += 8;
+ rhs += 8;
+ }
+
+ if (length >= 4) {
+ TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, _mm_loadu_ps(lhs + 0), _mm_loadu_ps(rhs + 0));
+ length -= 4;
+ lhs += 4;
+ rhs += 4;
+ }
+
+ if constexpr (computeLL) {
+ sumLL1 = _mm_add_ps(sumLL1, sumLL2);
+ }
+ if constexpr (computeLR) {
+ sumLR1 = _mm_add_ps(sumLR1, sumLR2);
+ }
+ if constexpr (computeRR) {
+ sumRR1 = _mm_add_ps(sumRR1, sumRR2);
+ }
+
+ if (length) {
+ __m128 a, b;
+ switch (length) {
+ case 3:
+ a = _mm_set_ps(0.0f, lhs[2], lhs[1], lhs[0]);
+ b = _mm_set_ps(0.0f, rhs[2], rhs[1], rhs[0]);
+ break;
+ case 2:
+ a = _mm_set_ps(0.0f, 0.0f, lhs[1], lhs[0]);
+ b = _mm_set_ps(0.0f, 0.0f, rhs[1], rhs[0]);
+ break;
+ case 1:
+ a = _mm_set_ps(0.0f, 0.0f, 0.0f, lhs[0]);
+ b = _mm_set_ps(0.0f, 0.0f, 0.0f, rhs[0]);
+ break;
+ default:
+ Y_UNREACHABLE();
+ }
+ TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, a, b);
+ }
+
+ __m128 t0 = sumLL1;
+ __m128 t1 = sumLR1;
+ __m128 t2 = sumRR1;
+ __m128 t3 = _mm_setzero_ps();
+ _MM_TRANSPOSE4_PS(t0, t1, t2, t3);
+ t0 = _mm_add_ps(t0, t1);
+ t0 = _mm_add_ps(t0, t2);
+ t0 = _mm_add_ps(t0, t3);
+
+ alignas(16) float res[4];
+ _mm_store_ps(res, t0);
+ TTriWayDotProduct<float> result{res[0], res[1], res[2]};
+ static constexpr const TTriWayDotProduct<float> def;
+ // fill skipped fields with default values
+ if constexpr (!computeLL) {
+ result.LL = def.LL;
+ }
+ if constexpr (!computeLR) {
+ result.LR = def.LR;
+ }
+ if constexpr (!computeRR) {
+ result.RR = def.RR;
+ }
+ return result;
+}
+
+
+TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept {
+ mask &= 0b111;
+ if (Y_LIKELY(mask == 0b111)) { // compute dot-product and length² of two vectors
+ return TriWayDotProductImpl<true, true, true>(lhs, rhs, length);
+ } else if (Y_LIKELY(mask == 0b110 || mask == 0b011)) { // compute dot-product and length² of one vector
+ const bool computeLL = (mask == 0b110);
+ if (!computeLL) {
+ DoSwap(lhs, rhs);
+ }
+ auto result = TriWayDotProductImpl<true, true, false>(lhs, rhs, length);
+ if (!computeLL) {
+ DoSwap(result.LL, result.RR);
+ }
+ return result;
+ } else {
+ // dispatch unlikely & sparse cases
+ TTriWayDotProduct<float> result{};
+ switch(mask) {
+ case 0b000:
+ break;
+ case 0b100:
+ result.LL = L2NormSquared(lhs, length);
+ break;
+ case 0b010:
+ result.LR = DotProduct(lhs, rhs, length);
+ break;
+ case 0b001:
+ result.RR = L2NormSquared(rhs, length);
+ break;
+ case 0b101:
+ result.LL = L2NormSquared(lhs, length);
+ result.RR = L2NormSquared(rhs, length);
+ break;
+ default:
+ Y_UNREACHABLE();
+ }
+ return result;
+ }
+}
+
+#else
+
+float L2NormSquared(const float* v, size_t length) noexcept {
+ return DotProduct(v, v, length);
+}
+
+TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept {
+ TTriWayDotProduct<float> result;
+ if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::LL)) {
+ result.LL = L2NormSquared(lhs, length);
+ }
+ if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::LR)) {
+ result.LR = DotProduct(lhs, rhs, length);
+ }
+ if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::RR)) {
+ result.RR = L2NormSquared(rhs, length);
+ }
+ return result;
+}
+
+#endif // ARCADIA_SSE
+
+namespace NDotProduct {
+ void DisableAvx2() {
+#ifdef ARCADIA_SSE
+ NDotProductImpl::DotProductI8Impl = &DotProductSse;
+ NDotProductImpl::DotProductUi8Impl = &DotProductSse;
+ NDotProductImpl::DotProductI32Impl = &DotProductSse;
+ NDotProductImpl::DotProductFloatImpl = &DotProductSse;
+ NDotProductImpl::DotProductDoubleImpl = &DotProductSse;
+#else
+ NDotProductImpl::DotProductI8Impl = &DotProductSimple;
+ NDotProductImpl::DotProductUi8Impl = &DotProductSimple;
+ NDotProductImpl::DotProductI32Impl = &DotProductSimple;
+ NDotProductImpl::DotProductFloatImpl = &DotProductSimple;
+ NDotProductImpl::DotProductDoubleImpl = &DotProductSimple;
+#endif
+ }
+}
diff --git a/library/cpp/dot_product/dot_product.h b/library/cpp/dot_product/dot_product.h
new file mode 100644
index 0000000000..0765633abd
--- /dev/null
+++ b/library/cpp/dot_product/dot_product.h
@@ -0,0 +1,96 @@
+#pragma once
+
+#include <util/system/types.h>
+#include <util/system/compiler.h>
+
+#include <numeric>
+
+/**
+ * Dot product (Inner product or scalar product) implementation using SSE when possible.
+ */
+namespace NDotProductImpl {
+ extern i32 (*DotProductI8Impl)(const i8* lhs, const i8* rhs, size_t length) noexcept;
+ extern ui32 (*DotProductUi8Impl)(const ui8* lhs, const ui8* rhs, size_t length) noexcept;
+ extern i64 (*DotProductI32Impl)(const i32* lhs, const i32* rhs, size_t length) noexcept;
+ extern float (*DotProductFloatImpl)(const float* lhs, const float* rhs, size_t length) noexcept;
+ extern double (*DotProductDoubleImpl)(const double* lhs, const double* rhs, size_t length) noexcept;
+}
+
+Y_PURE_FUNCTION
+inline i32 DotProduct(const i8* lhs, const i8* rhs, size_t length) noexcept {
+ return NDotProductImpl::DotProductI8Impl(lhs, rhs, length);
+}
+
+Y_PURE_FUNCTION
+inline ui32 DotProduct(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
+ return NDotProductImpl::DotProductUi8Impl(lhs, rhs, length);
+}
+
+Y_PURE_FUNCTION
+inline i64 DotProduct(const i32* lhs, const i32* rhs, size_t length) noexcept {
+ return NDotProductImpl::DotProductI32Impl(lhs, rhs, length);
+}
+
+Y_PURE_FUNCTION
+inline float DotProduct(const float* lhs, const float* rhs, size_t length) noexcept {
+ return NDotProductImpl::DotProductFloatImpl(lhs, rhs, length);
+}
+
+Y_PURE_FUNCTION
+inline double DotProduct(const double* lhs, const double* rhs, size_t length) noexcept {
+ return NDotProductImpl::DotProductDoubleImpl(lhs, rhs, length);
+}
+
+/**
+ * Dot product to itself
+ */
+Y_PURE_FUNCTION
+float L2NormSquared(const float* v, size_t length) noexcept;
+
+// TODO(yazevnul): make `L2NormSquared` for double, this should be faster than `DotProduct`
+// where `lhs == rhs` because it will save N load instructions.
+
+template <typename T>
+struct TTriWayDotProduct {
+ T LL = 1;
+ T LR = 0;
+ T RR = 1;
+};
+
+enum class ETriWayDotProductComputeMask: unsigned {
+ // basic
+ LL = 0b100,
+ LR = 0b010,
+ RR = 0b001,
+
+ // useful combinations
+ All = 0b111,
+ Left = 0b110, // skip computation of R·R
+ Right = 0b011, // skip computation of L·L
+};
+
+Y_PURE_FUNCTION
+TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept;
+
+/**
+ * For two vectors L and R computes 3 dot-products: L·L, L·R, R·R
+ */
+Y_PURE_FUNCTION
+static inline TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, ETriWayDotProductComputeMask mask = ETriWayDotProductComputeMask::All) noexcept {
+ return TriWayDotProduct(lhs, rhs, length, static_cast<unsigned>(mask));
+}
+
+namespace NDotProduct {
+ // Simpler wrapper allowing to use this functions as template argument.
+ template <typename T>
+ struct TDotProduct {
+ using TResult = decltype(DotProduct(static_cast<const T*>(nullptr), static_cast<const T*>(nullptr), 0));
+ Y_PURE_FUNCTION
+ inline TResult operator()(const T* l, const T* r, size_t length) const {
+ return DotProduct(l, r, length);
+ }
+ };
+
+ void DisableAvx2();
+}
+
diff --git a/library/cpp/dot_product/dot_product_avx2.cpp b/library/cpp/dot_product/dot_product_avx2.cpp
new file mode 100644
index 0000000000..a0f7c169ee
--- /dev/null
+++ b/library/cpp/dot_product/dot_product_avx2.cpp
@@ -0,0 +1,344 @@
+#include "dot_product_avx2.h"
+#include "dot_product_simple.h"
+#include "dot_product_sse.h"
+
+#if defined(_avx2_) && defined(_fma_)
+
+#include <util/system/platform.h>
+#include <util/system/compiler.h>
+#include <util/generic/utility.h>
+
+#include <immintrin.h>
+
+namespace {
+ constexpr i64 Bits(int n) {
+ return i64(-1) ^ ((i64(1) << (64 - n)) - 1);
+ }
+
+ constexpr __m256 BlendMask64[8] = {
+ __m256i{Bits(64), Bits(64), Bits(64), Bits(64)},
+ __m256i{0, Bits(64), Bits(64), Bits(64)},
+ __m256i{0, 0, Bits(64), Bits(64)},
+ __m256i{0, 0, 0, Bits(64)},
+ };
+
+ constexpr __m256 BlendMask32[8] = {
+ __m256i{Bits(64), Bits(64), Bits(64), Bits(64)},
+ __m256i{Bits(32), Bits(64), Bits(64), Bits(64)},
+ __m256i{0, Bits(64), Bits(64), Bits(64)},
+ __m256i{0, Bits(32), Bits(64), Bits(64)},
+ __m256i{0, 0, Bits(64), Bits(64)},
+ __m256i{0, 0, Bits(32), Bits(64)},
+ __m256i{0, 0, 0, Bits(64)},
+ __m256i{0, 0, 0, Bits(32)},
+ };
+
+ constexpr __m128 BlendMask8[16] = {
+ __m128i{Bits(64), Bits(64)},
+ __m128i{Bits(56), Bits(64)},
+ __m128i{Bits(48), Bits(64)},
+ __m128i{Bits(40), Bits(64)},
+ __m128i{Bits(32), Bits(64)},
+ __m128i{Bits(24), Bits(64)},
+ __m128i{Bits(16), Bits(64)},
+ __m128i{Bits(8), Bits(64)},
+ __m128i{0, Bits(64)},
+ __m128i{0, Bits(56)},
+ __m128i{0, Bits(48)},
+ __m128i{0, Bits(40)},
+ __m128i{0, Bits(32)},
+ __m128i{0, Bits(24)},
+ __m128i{0, Bits(16)},
+ __m128i{0, Bits(8)},
+ };
+
+ // See https://stackoverflow.com/a/60109639
+ // Horizontal sum of eight i32 values in an avx register
+ i32 HsumI32(__m256i v) {
+ __m128i x = _mm_add_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
+ __m128i hi64 = _mm_unpackhi_epi64(x, x);
+ __m128i sum64 = _mm_add_epi32(hi64, x);
+ __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+ __m128i sum32 = _mm_add_epi32(sum64, hi32);
+ return _mm_cvtsi128_si32(sum32);
+ }
+
+ // Horizontal sum of four i64 values in an avx register
+ i64 HsumI64(__m256i v) {
+ __m128i x = _mm_add_epi64(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
+ return _mm_cvtsi128_si64(x) + _mm_extract_epi64(x, 1);
+ }
+
+ // Horizontal sum of eight float values in an avx register
+ float HsumFloat(__m256 v) {
+ __m256 y = _mm256_permute2f128_ps(v, v, 1);
+ v = _mm256_add_ps(v, y);
+ v = _mm256_hadd_ps(v, v);
+ return _mm256_cvtss_f32(_mm256_hadd_ps(v, v));
+ }
+
+ // Horizontal sum of four double values in an avx register
+ double HsumDouble(__m256 v) {
+ __m128d x = _mm_add_pd(_mm256_castpd256_pd128(v), _mm256_extractf128_pd(v, 1));
+ x = _mm_add_pd(x, _mm_shuffle_pd(x, x, 1));
+ return _mm_cvtsd_f64(x);
+ }
+
+ __m128i Load128i(const void* ptr) {
+ return _mm_loadu_si128((const __m128i*)ptr);
+ }
+
+ __m256i Load256i(const void* ptr) {
+ return _mm256_loadu_si256((const __m256i*)ptr);
+ }
+
+ // Unrolled dot product for relatively small sizes
+ // The loop with known upper bound is unrolled by the compiler, no need to do anything special about it
+ template <size_t size, class TInput, class TExtend>
+ i32 DotProductInt8Avx2_Unroll(const TInput* lhs, const TInput* rhs, TExtend extend) noexcept {
+ static_assert(size % 16 == 0);
+ auto sum = _mm256_setzero_ps();
+ for (size_t i = 0; i != size; i += 16) {
+ sum = _mm256_add_epi32(sum, _mm256_madd_epi16(extend(Load128i(lhs + i)), extend(Load128i(rhs + i))));
+ }
+
+ return HsumI32(sum);
+ }
+
+ template <class TInput, class TExtend>
+ i32 DotProductInt8Avx2(const TInput* lhs, const TInput* rhs, size_t length, TExtend extend) noexcept {
+ // Fully unrolled versions for small multiples for 16
+ switch (length) {
+ case 16: return DotProductInt8Avx2_Unroll<16>(lhs, rhs, extend);
+ case 32: return DotProductInt8Avx2_Unroll<32>(lhs, rhs, extend);
+ case 48: return DotProductInt8Avx2_Unroll<48>(lhs, rhs, extend);
+ case 64: return DotProductInt8Avx2_Unroll<64>(lhs, rhs, extend);
+ }
+
+ __m256i sum = _mm256_setzero_ps();
+
+ if (const auto leftover = length % 16; leftover != 0) {
+ auto a = _mm_blendv_epi8(
+ Load128i(lhs), _mm_setzero_ps(), BlendMask8[leftover]);
+ auto b = _mm_blendv_epi8(
+ Load128i(rhs), _mm_setzero_ps(), BlendMask8[leftover]);
+
+ sum = _mm256_madd_epi16(extend(a), extend(b));
+
+ lhs += leftover;
+ rhs += leftover;
+ length -= leftover;
+ }
+
+ while (length >= 32) {
+ const auto l0 = extend(Load128i(lhs));
+ const auto r0 = extend(Load128i(rhs));
+ const auto l1 = extend(Load128i(lhs + 16));
+ const auto r1 = extend(Load128i(rhs + 16));
+
+ const auto s0 = _mm256_madd_epi16(l0, r0);
+ const auto s1 = _mm256_madd_epi16(l1, r1);
+
+ sum = _mm256_add_epi32(sum, _mm256_add_epi32(s0, s1));
+
+ lhs += 32;
+ rhs += 32;
+ length -= 32;
+ }
+
+ if (length > 0) {
+ auto l = extend(Load128i(lhs));
+ auto r = extend(Load128i(rhs));
+
+ sum = _mm256_add_epi32(sum, _mm256_madd_epi16(l, r));
+ }
+
+ return HsumI32(sum);
+ }
+}
+
+i32 DotProductAvx2(const i8* lhs, const i8* rhs, size_t length) noexcept {
+ if (length < 16) {
+ return DotProductSse(lhs, rhs, length);
+ }
+ return DotProductInt8Avx2(lhs, rhs, length, [](const __m128i x) {
+ return _mm256_cvtepi8_epi16(x);
+ });
+}
+
+ui32 DotProductAvx2(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
+ if (length < 16) {
+ return DotProductSse(lhs, rhs, length);
+ }
+ return DotProductInt8Avx2(lhs, rhs, length, [](const __m128i x) {
+ return _mm256_cvtepu8_epi16(x);
+ });
+}
+
+i64 DotProductAvx2(const i32* lhs, const i32* rhs, size_t length) noexcept {
+ if (length < 16) {
+ return DotProductSse(lhs, rhs, length);
+ }
+ __m256i res = _mm256_setzero_ps();
+
+ if (const auto leftover = length % 8; leftover != 0) {
+ // Use floating-point blendv. Who cares as long as the size is right.
+ __m256i a = _mm256_blendv_ps(
+ Load256i(lhs), _mm256_setzero_ps(), BlendMask32[leftover]);
+ __m256i b = _mm256_blendv_ps(
+ Load256i(rhs), _mm256_setzero_ps(), BlendMask32[leftover]);
+
+ res = _mm256_mul_epi32(a, b);
+ a = _mm256_alignr_epi8(a, a, 4);
+ b = _mm256_alignr_epi8(b, b, 4);
+ res = _mm256_add_epi64(_mm256_mul_epi32(a, b), res);
+
+ lhs += leftover;
+ rhs += leftover;
+ length -= leftover;
+ }
+
+ while (length >= 8) {
+ __m256i a = Load256i(lhs);
+ __m256i b = Load256i(rhs);
+ res = _mm256_add_epi64(_mm256_mul_epi32(a, b), res); // This is lower parts multiplication
+ a = _mm256_alignr_epi8(a, a, 4);
+ b = _mm256_alignr_epi8(b, b, 4);
+ res = _mm256_add_epi64(_mm256_mul_epi32(a, b), res);
+ rhs += 8;
+ lhs += 8;
+ length -= 8;
+ }
+
+ return HsumI64(res);
+}
+
+float DotProductAvx2(const float* lhs, const float* rhs, size_t length) noexcept {
+ if (length < 16) {
+ return DotProductSse(lhs, rhs, length);
+ }
+ __m256 sum1 = _mm256_setzero_ps();
+ __m256 sum2 = _mm256_setzero_ps();
+ __m256 a1, b1, a2, b2;
+
+ if (const auto leftover = length % 8; leftover != 0) {
+ a1 = _mm256_blendv_ps(
+ _mm256_loadu_ps(lhs), _mm256_setzero_ps(), BlendMask32[leftover]);
+ b1 = _mm256_blendv_ps(
+ _mm256_loadu_ps(rhs), _mm256_setzero_ps(), BlendMask32[leftover]);
+ sum1 = _mm256_mul_ps(a1, b1);
+ lhs += leftover;
+ rhs += leftover;
+ length -= leftover;
+ }
+
+ while (length >= 16) {
+ a1 = _mm256_loadu_ps(lhs);
+ b1 = _mm256_loadu_ps(rhs);
+ a2 = _mm256_loadu_ps(lhs + 8);
+ b2 = _mm256_loadu_ps(rhs + 8);
+
+ sum1 = _mm256_fmadd_ps(a1, b1, sum1);
+ sum2 = _mm256_fmadd_ps(a2, b2, sum2);
+
+ length -= 16;
+ lhs += 16;
+ rhs += 16;
+ }
+
+ if (length > 0) {
+ a1 = _mm256_loadu_ps(lhs);
+ b1 = _mm256_loadu_ps(rhs);
+ sum1 = _mm256_fmadd_ps(a1, b1, sum1);
+ }
+
+ return HsumFloat(_mm256_add_ps(sum1, sum2));
+}
+
+double DotProductAvx2(const double* lhs, const double* rhs, size_t length) noexcept {
+ if (length < 16) {
+ return DotProductSse(lhs, rhs, length);
+ }
+ __m256d sum1 = _mm256_setzero_pd();
+ __m256d sum2 = _mm256_setzero_pd();
+ __m256d a1, b1, a2, b2;
+
+ if (const auto leftover = length % 4; leftover != 0) {
+ a1 = _mm256_blendv_pd(
+ _mm256_loadu_pd(lhs), _mm256_setzero_ps(), BlendMask64[leftover]);
+ b1 = _mm256_blendv_pd(
+ _mm256_loadu_pd(rhs), _mm256_setzero_ps(), BlendMask64[leftover]);
+ sum1 = _mm256_mul_pd(a1, b1);
+ lhs += leftover;
+ rhs += leftover;
+ length -= leftover;
+ }
+
+ while (length >= 8) {
+ a1 = _mm256_loadu_pd(lhs);
+ b1 = _mm256_loadu_pd(rhs);
+ a2 = _mm256_loadu_pd(lhs + 4);
+ b2 = _mm256_loadu_pd(rhs + 4);
+
+ sum1 = _mm256_fmadd_pd(a1, b1, sum1);
+ sum2 = _mm256_fmadd_pd(a2, b2, sum2);
+
+ length -= 8;
+ lhs += 8;
+ rhs += 8;
+ }
+
+ if (length > 0) {
+ a1 = _mm256_loadu_pd(lhs);
+ b1 = _mm256_loadu_pd(rhs);
+ sum1 = _mm256_fmadd_pd(a1, b1, sum1);
+ }
+
+ return HsumDouble(_mm256_add_pd(sum1, sum2));
+}
+
+#elif defined(ARCADIA_SSE)
+
+i32 DotProductAvx2(const i8* lhs, const i8* rhs, size_t length) noexcept {
+ return DotProductSse(lhs, rhs, length);
+}
+
+ui32 DotProductAvx2(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
+ return DotProductSse(lhs, rhs, length);
+}
+
+i64 DotProductAvx2(const i32* lhs, const i32* rhs, size_t length) noexcept {
+ return DotProductSse(lhs, rhs, length);
+}
+
+float DotProductAvx2(const float* lhs, const float* rhs, size_t length) noexcept {
+ return DotProductSse(lhs, rhs, length);
+}
+
+double DotProductAvx2(const double* lhs, const double* rhs, size_t length) noexcept {
+ return DotProductSse(lhs, rhs, length);
+}
+
+#else
+
+i32 DotProductAvx2(const i8* lhs, const i8* rhs, size_t length) noexcept {
+ return DotProductSimple(lhs, rhs, length);
+}
+
+ui32 DotProductAvx2(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
+ return DotProductSimple(lhs, rhs, length);
+}
+
+i64 DotProductAvx2(const i32* lhs, const i32* rhs, size_t length) noexcept {
+ return DotProductSimple(lhs, rhs, length);
+}
+
+float DotProductAvx2(const float* lhs, const float* rhs, size_t length) noexcept {
+ return DotProductSimple(lhs, rhs, length);
+}
+
+double DotProductAvx2(const double* lhs, const double* rhs, size_t length) noexcept {
+ return DotProductSimple(lhs, rhs, length);
+}
+
+#endif
diff --git a/library/cpp/dot_product/dot_product_avx2.h b/library/cpp/dot_product/dot_product_avx2.h
new file mode 100644
index 0000000000..715f151f44
--- /dev/null
+++ b/library/cpp/dot_product/dot_product_avx2.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <util/system/types.h>
+#include <util/system/compiler.h>
+
+Y_PURE_FUNCTION
+i32 DotProductAvx2(const i8* lhs, const i8* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+ui32 DotProductAvx2(const ui8* lhs, const ui8* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+i64 DotProductAvx2(const i32* lhs, const i32* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+float DotProductAvx2(const float* lhs, const float* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+double DotProductAvx2(const double* lhs, const double* rhs, size_t length) noexcept;
diff --git a/library/cpp/dot_product/dot_product_simple.cpp b/library/cpp/dot_product/dot_product_simple.cpp
new file mode 100644
index 0000000000..02891c8a22
--- /dev/null
+++ b/library/cpp/dot_product/dot_product_simple.cpp
@@ -0,0 +1,44 @@
+#include "dot_product_simple.h"
+
+namespace {
+ template <typename Res, typename Number>
+ static Res DotProductSimpleImpl(const Number* lhs, const Number* rhs, size_t length) noexcept {
+ Res s0 = 0;
+ Res s1 = 0;
+ Res s2 = 0;
+ Res s3 = 0;
+
+ while (length >= 4) {
+ s0 += static_cast<Res>(lhs[0]) * static_cast<Res>(rhs[0]);
+ s1 += static_cast<Res>(lhs[1]) * static_cast<Res>(rhs[1]);
+ s2 += static_cast<Res>(lhs[2]) * static_cast<Res>(rhs[2]);
+ s3 += static_cast<Res>(lhs[3]) * static_cast<Res>(rhs[3]);
+ lhs += 4;
+ rhs += 4;
+ length -= 4;
+ }
+
+ while (length--) {
+ s0 += static_cast<Res>(*lhs++) * static_cast<Res>(*rhs++);
+ }
+
+ return s0 + s1 + s2 + s3;
+ }
+}
+
+float DotProductSimple(const float* lhs, const float* rhs, size_t length) noexcept {
+ return DotProductSimpleImpl<float, float>(lhs, rhs, length);
+}
+
+double DotProductSimple(const double* lhs, const double* rhs, size_t length) noexcept {
+ return DotProductSimpleImpl<double, double>(lhs, rhs, length);
+}
+
+ui32 DotProductUI4Simple(const ui8* lhs, const ui8* rhs, size_t lengtInBytes) noexcept {
+ ui32 res = 0;
+ for (size_t i = 0; i < lengtInBytes; ++i) {
+ res += static_cast<ui32>(lhs[i] & 0x0f) * static_cast<ui32>(rhs[i] & 0x0f);
+ res += static_cast<ui32>(lhs[i] & 0xf0) * static_cast<ui32>(rhs[i] & 0xf0) >> 8;
+ }
+ return res;
+}
diff --git a/library/cpp/dot_product/dot_product_simple.h b/library/cpp/dot_product/dot_product_simple.h
new file mode 100644
index 0000000000..dd13dd7592
--- /dev/null
+++ b/library/cpp/dot_product/dot_product_simple.h
@@ -0,0 +1,40 @@
+#pragma once
+
+#include <util/system/compiler.h>
+#include <util/system/types.h>
+
+#include <numeric>
+
+/**
+ * Dot product implementation without SSE optimizations.
+ */
+Y_PURE_FUNCTION
+inline ui32 DotProductSimple(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
+ return std::inner_product(lhs, lhs + length, rhs, static_cast<ui32>(0u),
+ [](ui32 x1, ui16 x2) {return x1 + x2;},
+ [](ui16 x1, ui8 x2) {return x1 * x2;});
+}
+
+Y_PURE_FUNCTION
+inline i32 DotProductSimple(const i8* lhs, const i8* rhs, size_t length) noexcept {
+ return std::inner_product(lhs, lhs + length, rhs, static_cast<i32>(0),
+ [](i32 x1, i16 x2) {return x1 + x2;},
+ [](i16 x1, i8 x2) {return x1 * x2;});
+}
+
+Y_PURE_FUNCTION
+inline i64 DotProductSimple(const i32* lhs, const i32* rhs, size_t length) noexcept {
+ return std::inner_product(lhs, lhs + length, rhs, static_cast<i64>(0),
+ [](i64 x1, i64 x2) {return x1 + x2;},
+ [](i64 x1, i32 x2) {return x1 * x2;});
+}
+
+Y_PURE_FUNCTION
+float DotProductSimple(const float* lhs, const float* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+double DotProductSimple(const double* lhs, const double* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+ui32 DotProductUI4Simple(const ui8* lhs, const ui8* rhs, size_t lengtInBytes) noexcept;
+
diff --git a/library/cpp/dot_product/dot_product_sse.cpp b/library/cpp/dot_product/dot_product_sse.cpp
new file mode 100644
index 0000000000..5256cfe98a
--- /dev/null
+++ b/library/cpp/dot_product/dot_product_sse.cpp
@@ -0,0 +1,219 @@
+#include "dot_product_sse.h"
+
+#include <library/cpp/sse/sse.h>
+#include <util/system/platform.h>
+#include <util/system/compiler.h>
+
+#ifdef ARCADIA_SSE
+i32 DotProductSse(const i8* lhs, const i8* rhs, size_t length) noexcept {
+ const __m128i zero = _mm_setzero_si128();
+ __m128i resVec = zero;
+ while (length >= 16) {
+ __m128i lVec = _mm_loadu_si128((const __m128i*)lhs);
+ __m128i rVec = _mm_loadu_si128((const __m128i*)rhs);
+
+#ifdef _sse4_1_
+ __m128i lLo = _mm_cvtepi8_epi16(lVec);
+ __m128i rLo = _mm_cvtepi8_epi16(rVec);
+ __m128i lHi = _mm_cvtepi8_epi16(_mm_alignr_epi8(lVec, lVec, 8));
+ __m128i rHi = _mm_cvtepi8_epi16(_mm_alignr_epi8(rVec, rVec, 8));
+#else
+ __m128i lLo = _mm_srai_epi16(_mm_unpacklo_epi8(zero, lVec), 8);
+ __m128i rLo = _mm_srai_epi16(_mm_unpacklo_epi8(zero, rVec), 8);
+ __m128i lHi = _mm_srai_epi16(_mm_unpackhi_epi8(zero, lVec), 8);
+ __m128i rHi = _mm_srai_epi16(_mm_unpackhi_epi8(zero, rVec), 8);
+#endif
+ resVec = _mm_add_epi32(resVec,
+ _mm_add_epi32(_mm_madd_epi16(lLo, rLo), _mm_madd_epi16(lHi, rHi)));
+
+ lhs += 16;
+ rhs += 16;
+ length -= 16;
+ }
+
+ alignas(16) i32 res[4];
+ _mm_store_si128((__m128i*)res, resVec);
+ i32 sum = res[0] + res[1] + res[2] + res[3];
+ for (size_t i = 0; i < length; ++i) {
+ sum += static_cast<i32>(lhs[i]) * static_cast<i32>(rhs[i]);
+ }
+
+ return sum;
+}
+
+ui32 DotProductSse(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
+ const __m128i zero = _mm_setzero_si128();
+ __m128i resVec = zero;
+ while (length >= 16) {
+ __m128i lVec = _mm_loadu_si128((const __m128i*)lhs);
+ __m128i rVec = _mm_loadu_si128((const __m128i*)rhs);
+
+ __m128i lLo = _mm_unpacklo_epi8(lVec, zero);
+ __m128i rLo = _mm_unpacklo_epi8(rVec, zero);
+ __m128i lHi = _mm_unpackhi_epi8(lVec, zero);
+ __m128i rHi = _mm_unpackhi_epi8(rVec, zero);
+
+ resVec = _mm_add_epi32(resVec,
+ _mm_add_epi32(_mm_madd_epi16(lLo, rLo), _mm_madd_epi16(lHi, rHi)));
+
+ lhs += 16;
+ rhs += 16;
+ length -= 16;
+ }
+
+ alignas(16) i32 res[4];
+ _mm_store_si128((__m128i*)res, resVec);
+ i32 sum = res[0] + res[1] + res[2] + res[3];
+ for (size_t i = 0; i < length; ++i) {
+ sum += static_cast<i32>(lhs[i]) * static_cast<i32>(rhs[i]);
+ }
+
+ return static_cast<ui32>(sum);
+}
+#ifdef _sse4_1_
+
+i64 DotProductSse(const i32* lhs, const i32* rhs, size_t length) noexcept {
+ __m128i zero = _mm_setzero_si128();
+ __m128i res = zero;
+
+ while (length >= 4) {
+ __m128i a = _mm_loadu_si128((const __m128i*)lhs);
+ __m128i b = _mm_loadu_si128((const __m128i*)rhs);
+ res = _mm_add_epi64(_mm_mul_epi32(a, b), res); // This is lower parts multiplication
+ a = _mm_alignr_epi8(a, a, 4);
+ b = _mm_alignr_epi8(b, b, 4);
+ res = _mm_add_epi64(_mm_mul_epi32(a, b), res);
+ rhs += 4;
+ lhs += 4;
+ length -= 4;
+ }
+
+ alignas(16) i64 r[2];
+ _mm_store_si128((__m128i*)r, res);
+ i64 sum = r[0] + r[1];
+
+ for (size_t i = 0; i < length; ++i) {
+ sum += static_cast<i64>(lhs[i]) * static_cast<i64>(rhs[i]);
+ }
+
+ return sum;
+}
+
+#else
+#include "dot_product_simple.h"
+
+i64 DotProductSse(const i32* lhs, const i32* rhs, size_t length) noexcept {
+ return DotProductSimple(lhs, rhs, length);
+}
+
+#endif
+
+float DotProductSse(const float* lhs, const float* rhs, size_t length) noexcept {
+ __m128 sum1 = _mm_setzero_ps();
+ __m128 sum2 = _mm_setzero_ps();
+ __m128 a1, b1, a2, b2, m1, m2;
+
+ while (length >= 8) {
+ a1 = _mm_loadu_ps(lhs);
+ b1 = _mm_loadu_ps(rhs);
+ m1 = _mm_mul_ps(a1, b1);
+
+ a2 = _mm_loadu_ps(lhs + 4);
+ sum1 = _mm_add_ps(sum1, m1);
+
+ b2 = _mm_loadu_ps(rhs + 4);
+ m2 = _mm_mul_ps(a2, b2);
+
+ sum2 = _mm_add_ps(sum2, m2);
+
+ length -= 8;
+ lhs += 8;
+ rhs += 8;
+ }
+
+ if (length >= 4) {
+ a1 = _mm_loadu_ps(lhs);
+ b1 = _mm_loadu_ps(rhs);
+ sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, b1));
+
+ length -= 4;
+ lhs += 4;
+ rhs += 4;
+ }
+
+ sum1 = _mm_add_ps(sum1, sum2);
+
+ if (length) {
+ switch (length) {
+ case 3:
+ a1 = _mm_set_ps(0.0f, lhs[2], lhs[1], lhs[0]);
+ b1 = _mm_set_ps(0.0f, rhs[2], rhs[1], rhs[0]);
+ break;
+
+ case 2:
+ a1 = _mm_set_ps(0.0f, 0.0f, lhs[1], lhs[0]);
+ b1 = _mm_set_ps(0.0f, 0.0f, rhs[1], rhs[0]);
+ break;
+
+ case 1:
+ a1 = _mm_set_ps(0.0f, 0.0f, 0.0f, lhs[0]);
+ b1 = _mm_set_ps(0.0f, 0.0f, 0.0f, rhs[0]);
+ break;
+
+ default:
+ Y_UNREACHABLE();
+ }
+
+ sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, b1));
+ }
+
+ alignas(16) float res[4];
+ _mm_store_ps(res, sum1);
+
+ return res[0] + res[1] + res[2] + res[3];
+}
+
+double DotProductSse(const double* lhs, const double* rhs, size_t length) noexcept {
+ __m128d sum1 = _mm_setzero_pd();
+ __m128d sum2 = _mm_setzero_pd();
+ __m128d a1, b1, a2, b2;
+
+ while (length >= 4) {
+ a1 = _mm_loadu_pd(lhs);
+ b1 = _mm_loadu_pd(rhs);
+ sum1 = _mm_add_pd(sum1, _mm_mul_pd(a1, b1));
+
+ a2 = _mm_loadu_pd(lhs + 2);
+ b2 = _mm_loadu_pd(rhs + 2);
+ sum2 = _mm_add_pd(sum2, _mm_mul_pd(a2, b2));
+
+ length -= 4;
+ lhs += 4;
+ rhs += 4;
+ }
+
+ if (length >= 2) {
+ a1 = _mm_loadu_pd(lhs);
+ b1 = _mm_loadu_pd(rhs);
+ sum1 = _mm_add_pd(sum1, _mm_mul_pd(a1, b1));
+
+ length -= 2;
+ lhs += 2;
+ rhs += 2;
+ }
+
+ sum1 = _mm_add_pd(sum1, sum2);
+
+ if (length > 0) {
+ a1 = _mm_set_pd(lhs[0], 0.0);
+ b1 = _mm_set_pd(rhs[0], 0.0);
+ sum1 = _mm_add_pd(sum1, _mm_mul_pd(a1, b1));
+ }
+
+ alignas(16) double res[2];
+ _mm_store_pd(res, sum1);
+
+ return res[0] + res[1];
+}
+
+#endif // ARCADIA_SSE
diff --git a/library/cpp/dot_product/dot_product_sse.h b/library/cpp/dot_product/dot_product_sse.h
new file mode 100644
index 0000000000..814736007d
--- /dev/null
+++ b/library/cpp/dot_product/dot_product_sse.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <util/system/types.h>
+#include <util/system/compiler.h>
+
+Y_PURE_FUNCTION
+i32 DotProductSse(const i8* lhs, const i8* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+ui32 DotProductSse(const ui8* lhs, const ui8* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+i64 DotProductSse(const i32* lhs, const i32* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+float DotProductSse(const float* lhs, const float* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+double DotProductSse(const double* lhs, const double* rhs, size_t length) noexcept;
diff --git a/library/cpp/dot_product/ya.make b/library/cpp/dot_product/ya.make
new file mode 100644
index 0000000000..b308967b4b
--- /dev/null
+++ b/library/cpp/dot_product/ya.make
@@ -0,0 +1,20 @@
+LIBRARY()
+
+SRCS(
+ dot_product.cpp
+ dot_product_sse.cpp
+ dot_product_simple.cpp
+)
+
+IF (USE_SSE4 == "yes" AND OS_LINUX == "yes")
+ SRC_C_AVX2(dot_product_avx2.cpp -mfma)
+ELSE()
+ SRC(dot_product_avx2.cpp)
+ENDIF()
+
+PEERDIR(
+ library/cpp/sse
+ library/cpp/testing/common
+)
+
+END()