diff options
author | azevaykin <azevaykin@yandex-team.com> | 2024-03-26 18:35:08 +0300 |
---|---|---|
committer | azevaykin <azevaykin@yandex-team.com> | 2024-03-26 18:49:55 +0300 |
commit | ce034cd07e7ebbff42723e51067bac58d1943f47 (patch) | |
tree | 32ff4320767eecab50c7840d533224f9b710acd3 /library/cpp/dot_product/dot_product_sse.cpp | |
parent | f4ad51b5e6cbbe7d4970e6777f2ab415a25edba7 (diff) | |
download | ydb-ce034cd07e7ebbff42723e51067bac58d1943f47.tar.gz |
Publish library/cpp/dot_product
Publish pod_product to https://github.com/ydb-platform/ydb
It has already been published to github: https://github.com/catboost/catboost/tree/master/library/cpp/dot_product
44150e7508881f4239c960f90320799b1b090072
Diffstat (limited to 'library/cpp/dot_product/dot_product_sse.cpp')
-rw-r--r-- | library/cpp/dot_product/dot_product_sse.cpp | 219 |
1 files changed, 219 insertions, 0 deletions
diff --git a/library/cpp/dot_product/dot_product_sse.cpp b/library/cpp/dot_product/dot_product_sse.cpp new file mode 100644 index 0000000000..5256cfe98a --- /dev/null +++ b/library/cpp/dot_product/dot_product_sse.cpp @@ -0,0 +1,219 @@ +#include "dot_product_sse.h" + +#include <library/cpp/sse/sse.h> +#include <util/system/platform.h> +#include <util/system/compiler.h> + +#ifdef ARCADIA_SSE +i32 DotProductSse(const i8* lhs, const i8* rhs, size_t length) noexcept { + const __m128i zero = _mm_setzero_si128(); + __m128i resVec = zero; + while (length >= 16) { + __m128i lVec = _mm_loadu_si128((const __m128i*)lhs); + __m128i rVec = _mm_loadu_si128((const __m128i*)rhs); + +#ifdef _sse4_1_ + __m128i lLo = _mm_cvtepi8_epi16(lVec); + __m128i rLo = _mm_cvtepi8_epi16(rVec); + __m128i lHi = _mm_cvtepi8_epi16(_mm_alignr_epi8(lVec, lVec, 8)); + __m128i rHi = _mm_cvtepi8_epi16(_mm_alignr_epi8(rVec, rVec, 8)); +#else + __m128i lLo = _mm_srai_epi16(_mm_unpacklo_epi8(zero, lVec), 8); + __m128i rLo = _mm_srai_epi16(_mm_unpacklo_epi8(zero, rVec), 8); + __m128i lHi = _mm_srai_epi16(_mm_unpackhi_epi8(zero, lVec), 8); + __m128i rHi = _mm_srai_epi16(_mm_unpackhi_epi8(zero, rVec), 8); +#endif + resVec = _mm_add_epi32(resVec, + _mm_add_epi32(_mm_madd_epi16(lLo, rLo), _mm_madd_epi16(lHi, rHi))); + + lhs += 16; + rhs += 16; + length -= 16; + } + + alignas(16) i32 res[4]; + _mm_store_si128((__m128i*)res, resVec); + i32 sum = res[0] + res[1] + res[2] + res[3]; + for (size_t i = 0; i < length; ++i) { + sum += static_cast<i32>(lhs[i]) * static_cast<i32>(rhs[i]); + } + + return sum; +} + +ui32 DotProductSse(const ui8* lhs, const ui8* rhs, size_t length) noexcept { + const __m128i zero = _mm_setzero_si128(); + __m128i resVec = zero; + while (length >= 16) { + __m128i lVec = _mm_loadu_si128((const __m128i*)lhs); + __m128i rVec = _mm_loadu_si128((const __m128i*)rhs); + + __m128i lLo = _mm_unpacklo_epi8(lVec, zero); + __m128i rLo = _mm_unpacklo_epi8(rVec, zero); + __m128i lHi = _mm_unpackhi_epi8(lVec, zero); + __m128i rHi = _mm_unpackhi_epi8(rVec, zero); + + resVec = _mm_add_epi32(resVec, + _mm_add_epi32(_mm_madd_epi16(lLo, rLo), _mm_madd_epi16(lHi, rHi))); + + lhs += 16; + rhs += 16; + length -= 16; + } + + alignas(16) i32 res[4]; + _mm_store_si128((__m128i*)res, resVec); + i32 sum = res[0] + res[1] + res[2] + res[3]; + for (size_t i = 0; i < length; ++i) { + sum += static_cast<i32>(lhs[i]) * static_cast<i32>(rhs[i]); + } + + return static_cast<ui32>(sum); +} +#ifdef _sse4_1_ + +i64 DotProductSse(const i32* lhs, const i32* rhs, size_t length) noexcept { + __m128i zero = _mm_setzero_si128(); + __m128i res = zero; + + while (length >= 4) { + __m128i a = _mm_loadu_si128((const __m128i*)lhs); + __m128i b = _mm_loadu_si128((const __m128i*)rhs); + res = _mm_add_epi64(_mm_mul_epi32(a, b), res); // This is lower parts multiplication + a = _mm_alignr_epi8(a, a, 4); + b = _mm_alignr_epi8(b, b, 4); + res = _mm_add_epi64(_mm_mul_epi32(a, b), res); + rhs += 4; + lhs += 4; + length -= 4; + } + + alignas(16) i64 r[2]; + _mm_store_si128((__m128i*)r, res); + i64 sum = r[0] + r[1]; + + for (size_t i = 0; i < length; ++i) { + sum += static_cast<i64>(lhs[i]) * static_cast<i64>(rhs[i]); + } + + return sum; +} + +#else +#include "dot_product_simple.h" + +i64 DotProductSse(const i32* lhs, const i32* rhs, size_t length) noexcept { + return DotProductSimple(lhs, rhs, length); +} + +#endif + +float DotProductSse(const float* lhs, const float* rhs, size_t length) noexcept { + __m128 sum1 = _mm_setzero_ps(); + __m128 sum2 = _mm_setzero_ps(); + __m128 a1, b1, a2, b2, m1, m2; + + while (length >= 8) { + a1 = _mm_loadu_ps(lhs); + b1 = _mm_loadu_ps(rhs); + m1 = _mm_mul_ps(a1, b1); + + a2 = _mm_loadu_ps(lhs + 4); + sum1 = _mm_add_ps(sum1, m1); + + b2 = _mm_loadu_ps(rhs + 4); + m2 = _mm_mul_ps(a2, b2); + + sum2 = _mm_add_ps(sum2, m2); + + length -= 8; + lhs += 8; + rhs += 8; + } + + if (length >= 4) { + a1 = _mm_loadu_ps(lhs); + b1 = _mm_loadu_ps(rhs); + sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, b1)); + + length -= 4; + lhs += 4; + rhs += 4; + } + + sum1 = _mm_add_ps(sum1, sum2); + + if (length) { + switch (length) { + case 3: + a1 = _mm_set_ps(0.0f, lhs[2], lhs[1], lhs[0]); + b1 = _mm_set_ps(0.0f, rhs[2], rhs[1], rhs[0]); + break; + + case 2: + a1 = _mm_set_ps(0.0f, 0.0f, lhs[1], lhs[0]); + b1 = _mm_set_ps(0.0f, 0.0f, rhs[1], rhs[0]); + break; + + case 1: + a1 = _mm_set_ps(0.0f, 0.0f, 0.0f, lhs[0]); + b1 = _mm_set_ps(0.0f, 0.0f, 0.0f, rhs[0]); + break; + + default: + Y_UNREACHABLE(); + } + + sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, b1)); + } + + alignas(16) float res[4]; + _mm_store_ps(res, sum1); + + return res[0] + res[1] + res[2] + res[3]; +} + +double DotProductSse(const double* lhs, const double* rhs, size_t length) noexcept { + __m128d sum1 = _mm_setzero_pd(); + __m128d sum2 = _mm_setzero_pd(); + __m128d a1, b1, a2, b2; + + while (length >= 4) { + a1 = _mm_loadu_pd(lhs); + b1 = _mm_loadu_pd(rhs); + sum1 = _mm_add_pd(sum1, _mm_mul_pd(a1, b1)); + + a2 = _mm_loadu_pd(lhs + 2); + b2 = _mm_loadu_pd(rhs + 2); + sum2 = _mm_add_pd(sum2, _mm_mul_pd(a2, b2)); + + length -= 4; + lhs += 4; + rhs += 4; + } + + if (length >= 2) { + a1 = _mm_loadu_pd(lhs); + b1 = _mm_loadu_pd(rhs); + sum1 = _mm_add_pd(sum1, _mm_mul_pd(a1, b1)); + + length -= 2; + lhs += 2; + rhs += 2; + } + + sum1 = _mm_add_pd(sum1, sum2); + + if (length > 0) { + a1 = _mm_set_pd(lhs[0], 0.0); + b1 = _mm_set_pd(rhs[0], 0.0); + sum1 = _mm_add_pd(sum1, _mm_mul_pd(a1, b1)); + } + + alignas(16) double res[2]; + _mm_store_pd(res, sum1); + + return res[0] + res[1]; +} + +#endif // ARCADIA_SSE |