diff options
author | azevaykin <azevaykin@yandex-team.com> | 2024-03-26 18:35:08 +0300 |
---|---|---|
committer | azevaykin <azevaykin@yandex-team.com> | 2024-03-26 18:49:55 +0300 |
commit | ce034cd07e7ebbff42723e51067bac58d1943f47 (patch) | |
tree | 32ff4320767eecab50c7840d533224f9b710acd3 /library/cpp | |
parent | f4ad51b5e6cbbe7d4970e6777f2ab415a25edba7 (diff) | |
download | ydb-ce034cd07e7ebbff42723e51067bac58d1943f47.tar.gz |
Publish library/cpp/dot_product
Publish pod_product to https://github.com/ydb-platform/ydb
It has already been published to github: https://github.com/catboost/catboost/tree/master/library/cpp/dot_product
44150e7508881f4239c960f90320799b1b090072
Diffstat (limited to 'library/cpp')
-rw-r--r-- | library/cpp/dot_product/README.md | 15 | ||||
-rw-r--r-- | library/cpp/dot_product/dot_product.cpp | 274 | ||||
-rw-r--r-- | library/cpp/dot_product/dot_product.h | 96 | ||||
-rw-r--r-- | library/cpp/dot_product/dot_product_avx2.cpp | 344 | ||||
-rw-r--r-- | library/cpp/dot_product/dot_product_avx2.h | 19 | ||||
-rw-r--r-- | library/cpp/dot_product/dot_product_simple.cpp | 44 | ||||
-rw-r--r-- | library/cpp/dot_product/dot_product_simple.h | 40 | ||||
-rw-r--r-- | library/cpp/dot_product/dot_product_sse.cpp | 219 | ||||
-rw-r--r-- | library/cpp/dot_product/dot_product_sse.h | 19 | ||||
-rw-r--r-- | library/cpp/dot_product/ya.make | 20 |
10 files changed, 1090 insertions, 0 deletions
diff --git a/library/cpp/dot_product/README.md b/library/cpp/dot_product/README.md new file mode 100644 index 0000000000..516dcf31de --- /dev/null +++ b/library/cpp/dot_product/README.md @@ -0,0 +1,15 @@ +Библиотека для вычисления скалярного произведения векторов. +===================================================== + +Данная библиотека содержит функцию DotProduct, вычисляющую скалярное произведение векторов различных типов. +В отличии от наивной реализации, библиотека использует SSE и работает существенно быстрее. Для сравнения +можно посмотреть результаты бенчмарка. + +Типичное использование - замена кусков кода вроде: +``` +for (int i = 0; i < len; i++) + dot_product += a[i] * b[i]); +``` +на существенно более эффективный вызов ```DotProduct(a, b, len)```. + +Работает для типов i8, i32, float, double. diff --git a/library/cpp/dot_product/dot_product.cpp b/library/cpp/dot_product/dot_product.cpp new file mode 100644 index 0000000000..6be4d0a78f --- /dev/null +++ b/library/cpp/dot_product/dot_product.cpp @@ -0,0 +1,274 @@ +#include "dot_product.h" +#include "dot_product_sse.h" +#include "dot_product_avx2.h" +#include "dot_product_simple.h" + +#include <library/cpp/sse/sse.h> +#include <library/cpp/testing/common/env.h> +#include <util/system/compiler.h> +#include <util/generic/utility.h> +#include <util/system/cpu_id.h> +#include <util/system/env.h> + +namespace NDotProductImpl { + i32 (*DotProductI8Impl)(const i8* lhs, const i8* rhs, size_t length) noexcept = &DotProductSimple; + ui32 (*DotProductUi8Impl)(const ui8* lhs, const ui8* rhs, size_t length) noexcept = &DotProductSimple; + i64 (*DotProductI32Impl)(const i32* lhs, const i32* rhs, size_t length) noexcept = &DotProductSimple; + float (*DotProductFloatImpl)(const float* lhs, const float* rhs, size_t length) noexcept = &DotProductSimple; + double (*DotProductDoubleImpl)(const double* lhs, const double* rhs, size_t length) noexcept = &DotProductSimple; + + namespace { + [[maybe_unused]] const int _ = [] { + if (!FromYaTest() && GetEnv("Y_NO_AVX_IN_DOT_PRODUCT") == "" && NX86::HaveAVX2() && NX86::HaveFMA()) { + DotProductI8Impl = &DotProductAvx2; + DotProductUi8Impl = &DotProductAvx2; + DotProductI32Impl = &DotProductAvx2; + DotProductFloatImpl = &DotProductAvx2; + DotProductDoubleImpl = &DotProductAvx2; + } else { +#ifdef ARCADIA_SSE + DotProductI8Impl = &DotProductSse; + DotProductUi8Impl = &DotProductSse; + DotProductI32Impl = &DotProductSse; + DotProductFloatImpl = &DotProductSse; + DotProductDoubleImpl = &DotProductSse; +#endif + } + return 0; + }(); + } +} + +#ifdef ARCADIA_SSE +float L2NormSquared(const float* v, size_t length) noexcept { + __m128 sum1 = _mm_setzero_ps(); + __m128 sum2 = _mm_setzero_ps(); + __m128 a1, a2, m1, m2; + + while (length >= 8) { + a1 = _mm_loadu_ps(v); + m1 = _mm_mul_ps(a1, a1); + + a2 = _mm_loadu_ps(v + 4); + sum1 = _mm_add_ps(sum1, m1); + + m2 = _mm_mul_ps(a2, a2); + sum2 = _mm_add_ps(sum2, m2); + + length -= 8; + v += 8; + } + + if (length >= 4) { + a1 = _mm_loadu_ps(v); + sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, a1)); + + length -= 4; + v += 4; + } + + sum1 = _mm_add_ps(sum1, sum2); + + if (length) { + switch (length) { + case 3: + a1 = _mm_set_ps(0.0f, v[2], v[1], v[0]); + break; + + case 2: + a1 = _mm_set_ps(0.0f, 0.0f, v[1], v[0]); + break; + + case 1: + a1 = _mm_set_ps(0.0f, 0.0f, 0.0f, v[0]); + break; + + default: + Y_UNREACHABLE(); + } + + sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, a1)); + } + + alignas(16) float res[4]; + _mm_store_ps(res, sum1); + + return res[0] + res[1] + res[2] + res[3]; +} + +template <bool computeLL, bool computeLR, bool computeRR> +Y_FORCE_INLINE +static void TriWayDotProductIteration(__m128& sumLL, __m128& sumLR, __m128& sumRR, const __m128 a, const __m128 b) { + if constexpr (computeLL) { + sumLL = _mm_add_ps(sumLL, _mm_mul_ps(a, a)); + } + if constexpr (computeLR) { + sumLR = _mm_add_ps(sumLR, _mm_mul_ps(a, b)); + } + if constexpr (computeRR) { + sumRR = _mm_add_ps(sumRR, _mm_mul_ps(b, b)); + } +} + + +template <bool computeLL, bool computeLR, bool computeRR> +static TTriWayDotProduct<float> TriWayDotProductImpl(const float* lhs, const float* rhs, size_t length) noexcept { + __m128 sumLL1 = _mm_setzero_ps(); + __m128 sumLR1 = _mm_setzero_ps(); + __m128 sumRR1 = _mm_setzero_ps(); + __m128 sumLL2 = _mm_setzero_ps(); + __m128 sumLR2 = _mm_setzero_ps(); + __m128 sumRR2 = _mm_setzero_ps(); + + while (length >= 8) { + TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, _mm_loadu_ps(lhs + 0), _mm_loadu_ps(rhs + 0)); + TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL2, sumLR2, sumRR2, _mm_loadu_ps(lhs + 4), _mm_loadu_ps(rhs + 4)); + length -= 8; + lhs += 8; + rhs += 8; + } + + if (length >= 4) { + TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, _mm_loadu_ps(lhs + 0), _mm_loadu_ps(rhs + 0)); + length -= 4; + lhs += 4; + rhs += 4; + } + + if constexpr (computeLL) { + sumLL1 = _mm_add_ps(sumLL1, sumLL2); + } + if constexpr (computeLR) { + sumLR1 = _mm_add_ps(sumLR1, sumLR2); + } + if constexpr (computeRR) { + sumRR1 = _mm_add_ps(sumRR1, sumRR2); + } + + if (length) { + __m128 a, b; + switch (length) { + case 3: + a = _mm_set_ps(0.0f, lhs[2], lhs[1], lhs[0]); + b = _mm_set_ps(0.0f, rhs[2], rhs[1], rhs[0]); + break; + case 2: + a = _mm_set_ps(0.0f, 0.0f, lhs[1], lhs[0]); + b = _mm_set_ps(0.0f, 0.0f, rhs[1], rhs[0]); + break; + case 1: + a = _mm_set_ps(0.0f, 0.0f, 0.0f, lhs[0]); + b = _mm_set_ps(0.0f, 0.0f, 0.0f, rhs[0]); + break; + default: + Y_UNREACHABLE(); + } + TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, a, b); + } + + __m128 t0 = sumLL1; + __m128 t1 = sumLR1; + __m128 t2 = sumRR1; + __m128 t3 = _mm_setzero_ps(); + _MM_TRANSPOSE4_PS(t0, t1, t2, t3); + t0 = _mm_add_ps(t0, t1); + t0 = _mm_add_ps(t0, t2); + t0 = _mm_add_ps(t0, t3); + + alignas(16) float res[4]; + _mm_store_ps(res, t0); + TTriWayDotProduct<float> result{res[0], res[1], res[2]}; + static constexpr const TTriWayDotProduct<float> def; + // fill skipped fields with default values + if constexpr (!computeLL) { + result.LL = def.LL; + } + if constexpr (!computeLR) { + result.LR = def.LR; + } + if constexpr (!computeRR) { + result.RR = def.RR; + } + return result; +} + + +TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept { + mask &= 0b111; + if (Y_LIKELY(mask == 0b111)) { // compute dot-product and length² of two vectors + return TriWayDotProductImpl<true, true, true>(lhs, rhs, length); + } else if (Y_LIKELY(mask == 0b110 || mask == 0b011)) { // compute dot-product and length² of one vector + const bool computeLL = (mask == 0b110); + if (!computeLL) { + DoSwap(lhs, rhs); + } + auto result = TriWayDotProductImpl<true, true, false>(lhs, rhs, length); + if (!computeLL) { + DoSwap(result.LL, result.RR); + } + return result; + } else { + // dispatch unlikely & sparse cases + TTriWayDotProduct<float> result{}; + switch(mask) { + case 0b000: + break; + case 0b100: + result.LL = L2NormSquared(lhs, length); + break; + case 0b010: + result.LR = DotProduct(lhs, rhs, length); + break; + case 0b001: + result.RR = L2NormSquared(rhs, length); + break; + case 0b101: + result.LL = L2NormSquared(lhs, length); + result.RR = L2NormSquared(rhs, length); + break; + default: + Y_UNREACHABLE(); + } + return result; + } +} + +#else + +float L2NormSquared(const float* v, size_t length) noexcept { + return DotProduct(v, v, length); +} + +TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept { + TTriWayDotProduct<float> result; + if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::LL)) { + result.LL = L2NormSquared(lhs, length); + } + if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::LR)) { + result.LR = DotProduct(lhs, rhs, length); + } + if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::RR)) { + result.RR = L2NormSquared(rhs, length); + } + return result; +} + +#endif // ARCADIA_SSE + +namespace NDotProduct { + void DisableAvx2() { +#ifdef ARCADIA_SSE + NDotProductImpl::DotProductI8Impl = &DotProductSse; + NDotProductImpl::DotProductUi8Impl = &DotProductSse; + NDotProductImpl::DotProductI32Impl = &DotProductSse; + NDotProductImpl::DotProductFloatImpl = &DotProductSse; + NDotProductImpl::DotProductDoubleImpl = &DotProductSse; +#else + NDotProductImpl::DotProductI8Impl = &DotProductSimple; + NDotProductImpl::DotProductUi8Impl = &DotProductSimple; + NDotProductImpl::DotProductI32Impl = &DotProductSimple; + NDotProductImpl::DotProductFloatImpl = &DotProductSimple; + NDotProductImpl::DotProductDoubleImpl = &DotProductSimple; +#endif + } +} diff --git a/library/cpp/dot_product/dot_product.h b/library/cpp/dot_product/dot_product.h new file mode 100644 index 0000000000..0765633abd --- /dev/null +++ b/library/cpp/dot_product/dot_product.h @@ -0,0 +1,96 @@ +#pragma once + +#include <util/system/types.h> +#include <util/system/compiler.h> + +#include <numeric> + +/** + * Dot product (Inner product or scalar product) implementation using SSE when possible. + */ +namespace NDotProductImpl { + extern i32 (*DotProductI8Impl)(const i8* lhs, const i8* rhs, size_t length) noexcept; + extern ui32 (*DotProductUi8Impl)(const ui8* lhs, const ui8* rhs, size_t length) noexcept; + extern i64 (*DotProductI32Impl)(const i32* lhs, const i32* rhs, size_t length) noexcept; + extern float (*DotProductFloatImpl)(const float* lhs, const float* rhs, size_t length) noexcept; + extern double (*DotProductDoubleImpl)(const double* lhs, const double* rhs, size_t length) noexcept; +} + +Y_PURE_FUNCTION +inline i32 DotProduct(const i8* lhs, const i8* rhs, size_t length) noexcept { + return NDotProductImpl::DotProductI8Impl(lhs, rhs, length); +} + +Y_PURE_FUNCTION +inline ui32 DotProduct(const ui8* lhs, const ui8* rhs, size_t length) noexcept { + return NDotProductImpl::DotProductUi8Impl(lhs, rhs, length); +} + +Y_PURE_FUNCTION +inline i64 DotProduct(const i32* lhs, const i32* rhs, size_t length) noexcept { + return NDotProductImpl::DotProductI32Impl(lhs, rhs, length); +} + +Y_PURE_FUNCTION +inline float DotProduct(const float* lhs, const float* rhs, size_t length) noexcept { + return NDotProductImpl::DotProductFloatImpl(lhs, rhs, length); +} + +Y_PURE_FUNCTION +inline double DotProduct(const double* lhs, const double* rhs, size_t length) noexcept { + return NDotProductImpl::DotProductDoubleImpl(lhs, rhs, length); +} + +/** + * Dot product to itself + */ +Y_PURE_FUNCTION +float L2NormSquared(const float* v, size_t length) noexcept; + +// TODO(yazevnul): make `L2NormSquared` for double, this should be faster than `DotProduct` +// where `lhs == rhs` because it will save N load instructions. + +template <typename T> +struct TTriWayDotProduct { + T LL = 1; + T LR = 0; + T RR = 1; +}; + +enum class ETriWayDotProductComputeMask: unsigned { + // basic + LL = 0b100, + LR = 0b010, + RR = 0b001, + + // useful combinations + All = 0b111, + Left = 0b110, // skip computation of R·R + Right = 0b011, // skip computation of L·L +}; + +Y_PURE_FUNCTION +TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept; + +/** + * For two vectors L and R computes 3 dot-products: L·L, L·R, R·R + */ +Y_PURE_FUNCTION +static inline TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, ETriWayDotProductComputeMask mask = ETriWayDotProductComputeMask::All) noexcept { + return TriWayDotProduct(lhs, rhs, length, static_cast<unsigned>(mask)); +} + +namespace NDotProduct { + // Simpler wrapper allowing to use this functions as template argument. + template <typename T> + struct TDotProduct { + using TResult = decltype(DotProduct(static_cast<const T*>(nullptr), static_cast<const T*>(nullptr), 0)); + Y_PURE_FUNCTION + inline TResult operator()(const T* l, const T* r, size_t length) const { + return DotProduct(l, r, length); + } + }; + + void DisableAvx2(); +} + diff --git a/library/cpp/dot_product/dot_product_avx2.cpp b/library/cpp/dot_product/dot_product_avx2.cpp new file mode 100644 index 0000000000..a0f7c169ee --- /dev/null +++ b/library/cpp/dot_product/dot_product_avx2.cpp @@ -0,0 +1,344 @@ +#include "dot_product_avx2.h" +#include "dot_product_simple.h" +#include "dot_product_sse.h" + +#if defined(_avx2_) && defined(_fma_) + +#include <util/system/platform.h> +#include <util/system/compiler.h> +#include <util/generic/utility.h> + +#include <immintrin.h> + +namespace { + constexpr i64 Bits(int n) { + return i64(-1) ^ ((i64(1) << (64 - n)) - 1); + } + + constexpr __m256 BlendMask64[8] = { + __m256i{Bits(64), Bits(64), Bits(64), Bits(64)}, + __m256i{0, Bits(64), Bits(64), Bits(64)}, + __m256i{0, 0, Bits(64), Bits(64)}, + __m256i{0, 0, 0, Bits(64)}, + }; + + constexpr __m256 BlendMask32[8] = { + __m256i{Bits(64), Bits(64), Bits(64), Bits(64)}, + __m256i{Bits(32), Bits(64), Bits(64), Bits(64)}, + __m256i{0, Bits(64), Bits(64), Bits(64)}, + __m256i{0, Bits(32), Bits(64), Bits(64)}, + __m256i{0, 0, Bits(64), Bits(64)}, + __m256i{0, 0, Bits(32), Bits(64)}, + __m256i{0, 0, 0, Bits(64)}, + __m256i{0, 0, 0, Bits(32)}, + }; + + constexpr __m128 BlendMask8[16] = { + __m128i{Bits(64), Bits(64)}, + __m128i{Bits(56), Bits(64)}, + __m128i{Bits(48), Bits(64)}, + __m128i{Bits(40), Bits(64)}, + __m128i{Bits(32), Bits(64)}, + __m128i{Bits(24), Bits(64)}, + __m128i{Bits(16), Bits(64)}, + __m128i{Bits(8), Bits(64)}, + __m128i{0, Bits(64)}, + __m128i{0, Bits(56)}, + __m128i{0, Bits(48)}, + __m128i{0, Bits(40)}, + __m128i{0, Bits(32)}, + __m128i{0, Bits(24)}, + __m128i{0, Bits(16)}, + __m128i{0, Bits(8)}, + }; + + // See https://stackoverflow.com/a/60109639 + // Horizontal sum of eight i32 values in an avx register + i32 HsumI32(__m256i v) { + __m128i x = _mm_add_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1)); + __m128i hi64 = _mm_unpackhi_epi64(x, x); + __m128i sum64 = _mm_add_epi32(hi64, x); + __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + __m128i sum32 = _mm_add_epi32(sum64, hi32); + return _mm_cvtsi128_si32(sum32); + } + + // Horizontal sum of four i64 values in an avx register + i64 HsumI64(__m256i v) { + __m128i x = _mm_add_epi64(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1)); + return _mm_cvtsi128_si64(x) + _mm_extract_epi64(x, 1); + } + + // Horizontal sum of eight float values in an avx register + float HsumFloat(__m256 v) { + __m256 y = _mm256_permute2f128_ps(v, v, 1); + v = _mm256_add_ps(v, y); + v = _mm256_hadd_ps(v, v); + return _mm256_cvtss_f32(_mm256_hadd_ps(v, v)); + } + + // Horizontal sum of four double values in an avx register + double HsumDouble(__m256 v) { + __m128d x = _mm_add_pd(_mm256_castpd256_pd128(v), _mm256_extractf128_pd(v, 1)); + x = _mm_add_pd(x, _mm_shuffle_pd(x, x, 1)); + return _mm_cvtsd_f64(x); + } + + __m128i Load128i(const void* ptr) { + return _mm_loadu_si128((const __m128i*)ptr); + } + + __m256i Load256i(const void* ptr) { + return _mm256_loadu_si256((const __m256i*)ptr); + } + + // Unrolled dot product for relatively small sizes + // The loop with known upper bound is unrolled by the compiler, no need to do anything special about it + template <size_t size, class TInput, class TExtend> + i32 DotProductInt8Avx2_Unroll(const TInput* lhs, const TInput* rhs, TExtend extend) noexcept { + static_assert(size % 16 == 0); + auto sum = _mm256_setzero_ps(); + for (size_t i = 0; i != size; i += 16) { + sum = _mm256_add_epi32(sum, _mm256_madd_epi16(extend(Load128i(lhs + i)), extend(Load128i(rhs + i)))); + } + + return HsumI32(sum); + } + + template <class TInput, class TExtend> + i32 DotProductInt8Avx2(const TInput* lhs, const TInput* rhs, size_t length, TExtend extend) noexcept { + // Fully unrolled versions for small multiples for 16 + switch (length) { + case 16: return DotProductInt8Avx2_Unroll<16>(lhs, rhs, extend); + case 32: return DotProductInt8Avx2_Unroll<32>(lhs, rhs, extend); + case 48: return DotProductInt8Avx2_Unroll<48>(lhs, rhs, extend); + case 64: return DotProductInt8Avx2_Unroll<64>(lhs, rhs, extend); + } + + __m256i sum = _mm256_setzero_ps(); + + if (const auto leftover = length % 16; leftover != 0) { + auto a = _mm_blendv_epi8( + Load128i(lhs), _mm_setzero_ps(), BlendMask8[leftover]); + auto b = _mm_blendv_epi8( + Load128i(rhs), _mm_setzero_ps(), BlendMask8[leftover]); + + sum = _mm256_madd_epi16(extend(a), extend(b)); + + lhs += leftover; + rhs += leftover; + length -= leftover; + } + + while (length >= 32) { + const auto l0 = extend(Load128i(lhs)); + const auto r0 = extend(Load128i(rhs)); + const auto l1 = extend(Load128i(lhs + 16)); + const auto r1 = extend(Load128i(rhs + 16)); + + const auto s0 = _mm256_madd_epi16(l0, r0); + const auto s1 = _mm256_madd_epi16(l1, r1); + + sum = _mm256_add_epi32(sum, _mm256_add_epi32(s0, s1)); + + lhs += 32; + rhs += 32; + length -= 32; + } + + if (length > 0) { + auto l = extend(Load128i(lhs)); + auto r = extend(Load128i(rhs)); + + sum = _mm256_add_epi32(sum, _mm256_madd_epi16(l, r)); + } + + return HsumI32(sum); + } +} + +i32 DotProductAvx2(const i8* lhs, const i8* rhs, size_t length) noexcept { + if (length < 16) { + return DotProductSse(lhs, rhs, length); + } + return DotProductInt8Avx2(lhs, rhs, length, [](const __m128i x) { + return _mm256_cvtepi8_epi16(x); + }); +} + +ui32 DotProductAvx2(const ui8* lhs, const ui8* rhs, size_t length) noexcept { + if (length < 16) { + return DotProductSse(lhs, rhs, length); + } + return DotProductInt8Avx2(lhs, rhs, length, [](const __m128i x) { + return _mm256_cvtepu8_epi16(x); + }); +} + +i64 DotProductAvx2(const i32* lhs, const i32* rhs, size_t length) noexcept { + if (length < 16) { + return DotProductSse(lhs, rhs, length); + } + __m256i res = _mm256_setzero_ps(); + + if (const auto leftover = length % 8; leftover != 0) { + // Use floating-point blendv. Who cares as long as the size is right. + __m256i a = _mm256_blendv_ps( + Load256i(lhs), _mm256_setzero_ps(), BlendMask32[leftover]); + __m256i b = _mm256_blendv_ps( + Load256i(rhs), _mm256_setzero_ps(), BlendMask32[leftover]); + + res = _mm256_mul_epi32(a, b); + a = _mm256_alignr_epi8(a, a, 4); + b = _mm256_alignr_epi8(b, b, 4); + res = _mm256_add_epi64(_mm256_mul_epi32(a, b), res); + + lhs += leftover; + rhs += leftover; + length -= leftover; + } + + while (length >= 8) { + __m256i a = Load256i(lhs); + __m256i b = Load256i(rhs); + res = _mm256_add_epi64(_mm256_mul_epi32(a, b), res); // This is lower parts multiplication + a = _mm256_alignr_epi8(a, a, 4); + b = _mm256_alignr_epi8(b, b, 4); + res = _mm256_add_epi64(_mm256_mul_epi32(a, b), res); + rhs += 8; + lhs += 8; + length -= 8; + } + + return HsumI64(res); +} + +float DotProductAvx2(const float* lhs, const float* rhs, size_t length) noexcept { + if (length < 16) { + return DotProductSse(lhs, rhs, length); + } + __m256 sum1 = _mm256_setzero_ps(); + __m256 sum2 = _mm256_setzero_ps(); + __m256 a1, b1, a2, b2; + + if (const auto leftover = length % 8; leftover != 0) { + a1 = _mm256_blendv_ps( + _mm256_loadu_ps(lhs), _mm256_setzero_ps(), BlendMask32[leftover]); + b1 = _mm256_blendv_ps( + _mm256_loadu_ps(rhs), _mm256_setzero_ps(), BlendMask32[leftover]); + sum1 = _mm256_mul_ps(a1, b1); + lhs += leftover; + rhs += leftover; + length -= leftover; + } + + while (length >= 16) { + a1 = _mm256_loadu_ps(lhs); + b1 = _mm256_loadu_ps(rhs); + a2 = _mm256_loadu_ps(lhs + 8); + b2 = _mm256_loadu_ps(rhs + 8); + + sum1 = _mm256_fmadd_ps(a1, b1, sum1); + sum2 = _mm256_fmadd_ps(a2, b2, sum2); + + length -= 16; + lhs += 16; + rhs += 16; + } + + if (length > 0) { + a1 = _mm256_loadu_ps(lhs); + b1 = _mm256_loadu_ps(rhs); + sum1 = _mm256_fmadd_ps(a1, b1, sum1); + } + + return HsumFloat(_mm256_add_ps(sum1, sum2)); +} + +double DotProductAvx2(const double* lhs, const double* rhs, size_t length) noexcept { + if (length < 16) { + return DotProductSse(lhs, rhs, length); + } + __m256d sum1 = _mm256_setzero_pd(); + __m256d sum2 = _mm256_setzero_pd(); + __m256d a1, b1, a2, b2; + + if (const auto leftover = length % 4; leftover != 0) { + a1 = _mm256_blendv_pd( + _mm256_loadu_pd(lhs), _mm256_setzero_ps(), BlendMask64[leftover]); + b1 = _mm256_blendv_pd( + _mm256_loadu_pd(rhs), _mm256_setzero_ps(), BlendMask64[leftover]); + sum1 = _mm256_mul_pd(a1, b1); + lhs += leftover; + rhs += leftover; + length -= leftover; + } + + while (length >= 8) { + a1 = _mm256_loadu_pd(lhs); + b1 = _mm256_loadu_pd(rhs); + a2 = _mm256_loadu_pd(lhs + 4); + b2 = _mm256_loadu_pd(rhs + 4); + + sum1 = _mm256_fmadd_pd(a1, b1, sum1); + sum2 = _mm256_fmadd_pd(a2, b2, sum2); + + length -= 8; + lhs += 8; + rhs += 8; + } + + if (length > 0) { + a1 = _mm256_loadu_pd(lhs); + b1 = _mm256_loadu_pd(rhs); + sum1 = _mm256_fmadd_pd(a1, b1, sum1); + } + + return HsumDouble(_mm256_add_pd(sum1, sum2)); +} + +#elif defined(ARCADIA_SSE) + +i32 DotProductAvx2(const i8* lhs, const i8* rhs, size_t length) noexcept { + return DotProductSse(lhs, rhs, length); +} + +ui32 DotProductAvx2(const ui8* lhs, const ui8* rhs, size_t length) noexcept { + return DotProductSse(lhs, rhs, length); +} + +i64 DotProductAvx2(const i32* lhs, const i32* rhs, size_t length) noexcept { + return DotProductSse(lhs, rhs, length); +} + +float DotProductAvx2(const float* lhs, const float* rhs, size_t length) noexcept { + return DotProductSse(lhs, rhs, length); +} + +double DotProductAvx2(const double* lhs, const double* rhs, size_t length) noexcept { + return DotProductSse(lhs, rhs, length); +} + +#else + +i32 DotProductAvx2(const i8* lhs, const i8* rhs, size_t length) noexcept { + return DotProductSimple(lhs, rhs, length); +} + +ui32 DotProductAvx2(const ui8* lhs, const ui8* rhs, size_t length) noexcept { + return DotProductSimple(lhs, rhs, length); +} + +i64 DotProductAvx2(const i32* lhs, const i32* rhs, size_t length) noexcept { + return DotProductSimple(lhs, rhs, length); +} + +float DotProductAvx2(const float* lhs, const float* rhs, size_t length) noexcept { + return DotProductSimple(lhs, rhs, length); +} + +double DotProductAvx2(const double* lhs, const double* rhs, size_t length) noexcept { + return DotProductSimple(lhs, rhs, length); +} + +#endif diff --git a/library/cpp/dot_product/dot_product_avx2.h b/library/cpp/dot_product/dot_product_avx2.h new file mode 100644 index 0000000000..715f151f44 --- /dev/null +++ b/library/cpp/dot_product/dot_product_avx2.h @@ -0,0 +1,19 @@ +#pragma once + +#include <util/system/types.h> +#include <util/system/compiler.h> + +Y_PURE_FUNCTION +i32 DotProductAvx2(const i8* lhs, const i8* rhs, size_t length) noexcept; + +Y_PURE_FUNCTION +ui32 DotProductAvx2(const ui8* lhs, const ui8* rhs, size_t length) noexcept; + +Y_PURE_FUNCTION +i64 DotProductAvx2(const i32* lhs, const i32* rhs, size_t length) noexcept; + +Y_PURE_FUNCTION +float DotProductAvx2(const float* lhs, const float* rhs, size_t length) noexcept; + +Y_PURE_FUNCTION +double DotProductAvx2(const double* lhs, const double* rhs, size_t length) noexcept; diff --git a/library/cpp/dot_product/dot_product_simple.cpp b/library/cpp/dot_product/dot_product_simple.cpp new file mode 100644 index 0000000000..02891c8a22 --- /dev/null +++ b/library/cpp/dot_product/dot_product_simple.cpp @@ -0,0 +1,44 @@ +#include "dot_product_simple.h" + +namespace { + template <typename Res, typename Number> + static Res DotProductSimpleImpl(const Number* lhs, const Number* rhs, size_t length) noexcept { + Res s0 = 0; + Res s1 = 0; + Res s2 = 0; + Res s3 = 0; + + while (length >= 4) { + s0 += static_cast<Res>(lhs[0]) * static_cast<Res>(rhs[0]); + s1 += static_cast<Res>(lhs[1]) * static_cast<Res>(rhs[1]); + s2 += static_cast<Res>(lhs[2]) * static_cast<Res>(rhs[2]); + s3 += static_cast<Res>(lhs[3]) * static_cast<Res>(rhs[3]); + lhs += 4; + rhs += 4; + length -= 4; + } + + while (length--) { + s0 += static_cast<Res>(*lhs++) * static_cast<Res>(*rhs++); + } + + return s0 + s1 + s2 + s3; + } +} + +float DotProductSimple(const float* lhs, const float* rhs, size_t length) noexcept { + return DotProductSimpleImpl<float, float>(lhs, rhs, length); +} + +double DotProductSimple(const double* lhs, const double* rhs, size_t length) noexcept { + return DotProductSimpleImpl<double, double>(lhs, rhs, length); +} + +ui32 DotProductUI4Simple(const ui8* lhs, const ui8* rhs, size_t lengtInBytes) noexcept { + ui32 res = 0; + for (size_t i = 0; i < lengtInBytes; ++i) { + res += static_cast<ui32>(lhs[i] & 0x0f) * static_cast<ui32>(rhs[i] & 0x0f); + res += static_cast<ui32>(lhs[i] & 0xf0) * static_cast<ui32>(rhs[i] & 0xf0) >> 8; + } + return res; +} diff --git a/library/cpp/dot_product/dot_product_simple.h b/library/cpp/dot_product/dot_product_simple.h new file mode 100644 index 0000000000..dd13dd7592 --- /dev/null +++ b/library/cpp/dot_product/dot_product_simple.h @@ -0,0 +1,40 @@ +#pragma once + +#include <util/system/compiler.h> +#include <util/system/types.h> + +#include <numeric> + +/** + * Dot product implementation without SSE optimizations. + */ +Y_PURE_FUNCTION +inline ui32 DotProductSimple(const ui8* lhs, const ui8* rhs, size_t length) noexcept { + return std::inner_product(lhs, lhs + length, rhs, static_cast<ui32>(0u), + [](ui32 x1, ui16 x2) {return x1 + x2;}, + [](ui16 x1, ui8 x2) {return x1 * x2;}); +} + +Y_PURE_FUNCTION +inline i32 DotProductSimple(const i8* lhs, const i8* rhs, size_t length) noexcept { + return std::inner_product(lhs, lhs + length, rhs, static_cast<i32>(0), + [](i32 x1, i16 x2) {return x1 + x2;}, + [](i16 x1, i8 x2) {return x1 * x2;}); +} + +Y_PURE_FUNCTION +inline i64 DotProductSimple(const i32* lhs, const i32* rhs, size_t length) noexcept { + return std::inner_product(lhs, lhs + length, rhs, static_cast<i64>(0), + [](i64 x1, i64 x2) {return x1 + x2;}, + [](i64 x1, i32 x2) {return x1 * x2;}); +} + +Y_PURE_FUNCTION +float DotProductSimple(const float* lhs, const float* rhs, size_t length) noexcept; + +Y_PURE_FUNCTION +double DotProductSimple(const double* lhs, const double* rhs, size_t length) noexcept; + +Y_PURE_FUNCTION +ui32 DotProductUI4Simple(const ui8* lhs, const ui8* rhs, size_t lengtInBytes) noexcept; + diff --git a/library/cpp/dot_product/dot_product_sse.cpp b/library/cpp/dot_product/dot_product_sse.cpp new file mode 100644 index 0000000000..5256cfe98a --- /dev/null +++ b/library/cpp/dot_product/dot_product_sse.cpp @@ -0,0 +1,219 @@ +#include "dot_product_sse.h" + +#include <library/cpp/sse/sse.h> +#include <util/system/platform.h> +#include <util/system/compiler.h> + +#ifdef ARCADIA_SSE +i32 DotProductSse(const i8* lhs, const i8* rhs, size_t length) noexcept { + const __m128i zero = _mm_setzero_si128(); + __m128i resVec = zero; + while (length >= 16) { + __m128i lVec = _mm_loadu_si128((const __m128i*)lhs); + __m128i rVec = _mm_loadu_si128((const __m128i*)rhs); + +#ifdef _sse4_1_ + __m128i lLo = _mm_cvtepi8_epi16(lVec); + __m128i rLo = _mm_cvtepi8_epi16(rVec); + __m128i lHi = _mm_cvtepi8_epi16(_mm_alignr_epi8(lVec, lVec, 8)); + __m128i rHi = _mm_cvtepi8_epi16(_mm_alignr_epi8(rVec, rVec, 8)); +#else + __m128i lLo = _mm_srai_epi16(_mm_unpacklo_epi8(zero, lVec), 8); + __m128i rLo = _mm_srai_epi16(_mm_unpacklo_epi8(zero, rVec), 8); + __m128i lHi = _mm_srai_epi16(_mm_unpackhi_epi8(zero, lVec), 8); + __m128i rHi = _mm_srai_epi16(_mm_unpackhi_epi8(zero, rVec), 8); +#endif + resVec = _mm_add_epi32(resVec, + _mm_add_epi32(_mm_madd_epi16(lLo, rLo), _mm_madd_epi16(lHi, rHi))); + + lhs += 16; + rhs += 16; + length -= 16; + } + + alignas(16) i32 res[4]; + _mm_store_si128((__m128i*)res, resVec); + i32 sum = res[0] + res[1] + res[2] + res[3]; + for (size_t i = 0; i < length; ++i) { + sum += static_cast<i32>(lhs[i]) * static_cast<i32>(rhs[i]); + } + + return sum; +} + +ui32 DotProductSse(const ui8* lhs, const ui8* rhs, size_t length) noexcept { + const __m128i zero = _mm_setzero_si128(); + __m128i resVec = zero; + while (length >= 16) { + __m128i lVec = _mm_loadu_si128((const __m128i*)lhs); + __m128i rVec = _mm_loadu_si128((const __m128i*)rhs); + + __m128i lLo = _mm_unpacklo_epi8(lVec, zero); + __m128i rLo = _mm_unpacklo_epi8(rVec, zero); + __m128i lHi = _mm_unpackhi_epi8(lVec, zero); + __m128i rHi = _mm_unpackhi_epi8(rVec, zero); + + resVec = _mm_add_epi32(resVec, + _mm_add_epi32(_mm_madd_epi16(lLo, rLo), _mm_madd_epi16(lHi, rHi))); + + lhs += 16; + rhs += 16; + length -= 16; + } + + alignas(16) i32 res[4]; + _mm_store_si128((__m128i*)res, resVec); + i32 sum = res[0] + res[1] + res[2] + res[3]; + for (size_t i = 0; i < length; ++i) { + sum += static_cast<i32>(lhs[i]) * static_cast<i32>(rhs[i]); + } + + return static_cast<ui32>(sum); +} +#ifdef _sse4_1_ + +i64 DotProductSse(const i32* lhs, const i32* rhs, size_t length) noexcept { + __m128i zero = _mm_setzero_si128(); + __m128i res = zero; + + while (length >= 4) { + __m128i a = _mm_loadu_si128((const __m128i*)lhs); + __m128i b = _mm_loadu_si128((const __m128i*)rhs); + res = _mm_add_epi64(_mm_mul_epi32(a, b), res); // This is lower parts multiplication + a = _mm_alignr_epi8(a, a, 4); + b = _mm_alignr_epi8(b, b, 4); + res = _mm_add_epi64(_mm_mul_epi32(a, b), res); + rhs += 4; + lhs += 4; + length -= 4; + } + + alignas(16) i64 r[2]; + _mm_store_si128((__m128i*)r, res); + i64 sum = r[0] + r[1]; + + for (size_t i = 0; i < length; ++i) { + sum += static_cast<i64>(lhs[i]) * static_cast<i64>(rhs[i]); + } + + return sum; +} + +#else +#include "dot_product_simple.h" + +i64 DotProductSse(const i32* lhs, const i32* rhs, size_t length) noexcept { + return DotProductSimple(lhs, rhs, length); +} + +#endif + +float DotProductSse(const float* lhs, const float* rhs, size_t length) noexcept { + __m128 sum1 = _mm_setzero_ps(); + __m128 sum2 = _mm_setzero_ps(); + __m128 a1, b1, a2, b2, m1, m2; + + while (length >= 8) { + a1 = _mm_loadu_ps(lhs); + b1 = _mm_loadu_ps(rhs); + m1 = _mm_mul_ps(a1, b1); + + a2 = _mm_loadu_ps(lhs + 4); + sum1 = _mm_add_ps(sum1, m1); + + b2 = _mm_loadu_ps(rhs + 4); + m2 = _mm_mul_ps(a2, b2); + + sum2 = _mm_add_ps(sum2, m2); + + length -= 8; + lhs += 8; + rhs += 8; + } + + if (length >= 4) { + a1 = _mm_loadu_ps(lhs); + b1 = _mm_loadu_ps(rhs); + sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, b1)); + + length -= 4; + lhs += 4; + rhs += 4; + } + + sum1 = _mm_add_ps(sum1, sum2); + + if (length) { + switch (length) { + case 3: + a1 = _mm_set_ps(0.0f, lhs[2], lhs[1], lhs[0]); + b1 = _mm_set_ps(0.0f, rhs[2], rhs[1], rhs[0]); + break; + + case 2: + a1 = _mm_set_ps(0.0f, 0.0f, lhs[1], lhs[0]); + b1 = _mm_set_ps(0.0f, 0.0f, rhs[1], rhs[0]); + break; + + case 1: + a1 = _mm_set_ps(0.0f, 0.0f, 0.0f, lhs[0]); + b1 = _mm_set_ps(0.0f, 0.0f, 0.0f, rhs[0]); + break; + + default: + Y_UNREACHABLE(); + } + + sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, b1)); + } + + alignas(16) float res[4]; + _mm_store_ps(res, sum1); + + return res[0] + res[1] + res[2] + res[3]; +} + +double DotProductSse(const double* lhs, const double* rhs, size_t length) noexcept { + __m128d sum1 = _mm_setzero_pd(); + __m128d sum2 = _mm_setzero_pd(); + __m128d a1, b1, a2, b2; + + while (length >= 4) { + a1 = _mm_loadu_pd(lhs); + b1 = _mm_loadu_pd(rhs); + sum1 = _mm_add_pd(sum1, _mm_mul_pd(a1, b1)); + + a2 = _mm_loadu_pd(lhs + 2); + b2 = _mm_loadu_pd(rhs + 2); + sum2 = _mm_add_pd(sum2, _mm_mul_pd(a2, b2)); + + length -= 4; + lhs += 4; + rhs += 4; + } + + if (length >= 2) { + a1 = _mm_loadu_pd(lhs); + b1 = _mm_loadu_pd(rhs); + sum1 = _mm_add_pd(sum1, _mm_mul_pd(a1, b1)); + + length -= 2; + lhs += 2; + rhs += 2; + } + + sum1 = _mm_add_pd(sum1, sum2); + + if (length > 0) { + a1 = _mm_set_pd(lhs[0], 0.0); + b1 = _mm_set_pd(rhs[0], 0.0); + sum1 = _mm_add_pd(sum1, _mm_mul_pd(a1, b1)); + } + + alignas(16) double res[2]; + _mm_store_pd(res, sum1); + + return res[0] + res[1]; +} + +#endif // ARCADIA_SSE diff --git a/library/cpp/dot_product/dot_product_sse.h b/library/cpp/dot_product/dot_product_sse.h new file mode 100644 index 0000000000..814736007d --- /dev/null +++ b/library/cpp/dot_product/dot_product_sse.h @@ -0,0 +1,19 @@ +#pragma once + +#include <util/system/types.h> +#include <util/system/compiler.h> + +Y_PURE_FUNCTION +i32 DotProductSse(const i8* lhs, const i8* rhs, size_t length) noexcept; + +Y_PURE_FUNCTION +ui32 DotProductSse(const ui8* lhs, const ui8* rhs, size_t length) noexcept; + +Y_PURE_FUNCTION +i64 DotProductSse(const i32* lhs, const i32* rhs, size_t length) noexcept; + +Y_PURE_FUNCTION +float DotProductSse(const float* lhs, const float* rhs, size_t length) noexcept; + +Y_PURE_FUNCTION +double DotProductSse(const double* lhs, const double* rhs, size_t length) noexcept; diff --git a/library/cpp/dot_product/ya.make b/library/cpp/dot_product/ya.make new file mode 100644 index 0000000000..b308967b4b --- /dev/null +++ b/library/cpp/dot_product/ya.make @@ -0,0 +1,20 @@ +LIBRARY() + +SRCS( + dot_product.cpp + dot_product_sse.cpp + dot_product_simple.cpp +) + +IF (USE_SSE4 == "yes" AND OS_LINUX == "yes") + SRC_C_AVX2(dot_product_avx2.cpp -mfma) +ELSE() + SRC(dot_product_avx2.cpp) +ENDIF() + +PEERDIR( + library/cpp/sse + library/cpp/testing/common +) + +END() |