diff options
| author | azevaykin <[email protected]> | 2024-03-26 18:35:08 +0300 | 
|---|---|---|
| committer | azevaykin <[email protected]> | 2024-03-26 18:49:55 +0300 | 
| commit | ce034cd07e7ebbff42723e51067bac58d1943f47 (patch) | |
| tree | 32ff4320767eecab50c7840d533224f9b710acd3 /library/cpp | |
| parent | f4ad51b5e6cbbe7d4970e6777f2ab415a25edba7 (diff) | |
Publish library/cpp/dot_product
Publish dot_product to https://github.com/ydb-platform/ydb
It has already been published to github: https://github.com/catboost/catboost/tree/master/library/cpp/dot_product
44150e7508881f4239c960f90320799b1b090072
Diffstat (limited to 'library/cpp')
| -rw-r--r-- | library/cpp/dot_product/README.md | 15 | ||||
| -rw-r--r-- | library/cpp/dot_product/dot_product.cpp | 274 | ||||
| -rw-r--r-- | library/cpp/dot_product/dot_product.h | 96 | ||||
| -rw-r--r-- | library/cpp/dot_product/dot_product_avx2.cpp | 344 | ||||
| -rw-r--r-- | library/cpp/dot_product/dot_product_avx2.h | 19 | ||||
| -rw-r--r-- | library/cpp/dot_product/dot_product_simple.cpp | 44 | ||||
| -rw-r--r-- | library/cpp/dot_product/dot_product_simple.h | 40 | ||||
| -rw-r--r-- | library/cpp/dot_product/dot_product_sse.cpp | 219 | ||||
| -rw-r--r-- | library/cpp/dot_product/dot_product_sse.h | 19 | ||||
| -rw-r--r-- | library/cpp/dot_product/ya.make | 20 | 
10 files changed, 1090 insertions, 0 deletions
diff --git a/library/cpp/dot_product/README.md b/library/cpp/dot_product/README.md new file mode 100644 index 00000000000..516dcf31de3 --- /dev/null +++ b/library/cpp/dot_product/README.md @@ -0,0 +1,15 @@ +Библиотека для вычисления скалярного произведения векторов. +===================================================== + +Данная библиотека содержит функцию DotProduct, вычисляющую скалярное произведение векторов различных типов. +В отличие от наивной реализации, библиотека использует SSE/AVX2 и работает существенно быстрее. Для сравнения +можно посмотреть результаты бенчмарка. + +Типичное использование - замена кусков кода вроде: +``` +for (int i = 0; i < len; i++) +    dot_product += a[i] * b[i]; +``` +на существенно более эффективный вызов ```DotProduct(a, b, len)```. + +Работает для типов i8, ui8, i32, float, double. diff --git a/library/cpp/dot_product/dot_product.cpp b/library/cpp/dot_product/dot_product.cpp new file mode 100644 index 00000000000..6be4d0a78f3 --- /dev/null +++ b/library/cpp/dot_product/dot_product.cpp @@ -0,0 +1,274 @@ +#include "dot_product.h" +#include "dot_product_sse.h" +#include "dot_product_avx2.h" +#include "dot_product_simple.h" + +#include <library/cpp/sse/sse.h> +#include <library/cpp/testing/common/env.h> +#include <util/system/compiler.h> +#include <util/generic/utility.h> +#include <util/system/cpu_id.h> +#include <util/system/env.h> + +namespace NDotProductImpl { +    i32 (*DotProductI8Impl)(const i8* lhs, const i8* rhs, size_t length) noexcept = &DotProductSimple; +    ui32 (*DotProductUi8Impl)(const ui8* lhs, const ui8* rhs, size_t length) noexcept = &DotProductSimple; +    i64 (*DotProductI32Impl)(const i32* lhs, const i32* rhs, size_t length) noexcept = &DotProductSimple; +    float (*DotProductFloatImpl)(const float* lhs, const float* rhs, size_t length) noexcept = &DotProductSimple; +    double (*DotProductDoubleImpl)(const double* lhs, const double* rhs, size_t length) noexcept = &DotProductSimple; + +    namespace { +        
[[maybe_unused]] const int _ = [] { +            if (!FromYaTest() && GetEnv("Y_NO_AVX_IN_DOT_PRODUCT") == "" && NX86::HaveAVX2() && NX86::HaveFMA()) { +                DotProductI8Impl = &DotProductAvx2; +                DotProductUi8Impl = &DotProductAvx2; +                DotProductI32Impl = &DotProductAvx2; +                DotProductFloatImpl = &DotProductAvx2; +                DotProductDoubleImpl = &DotProductAvx2; +            } else { +#ifdef ARCADIA_SSE +                DotProductI8Impl = &DotProductSse; +                DotProductUi8Impl = &DotProductSse; +                DotProductI32Impl = &DotProductSse; +                DotProductFloatImpl = &DotProductSse; +                DotProductDoubleImpl = &DotProductSse; +#endif +            } +            return 0; +        }(); +    } +} + +#ifdef ARCADIA_SSE +float L2NormSquared(const float* v, size_t length) noexcept { +    __m128 sum1 = _mm_setzero_ps(); +    __m128 sum2 = _mm_setzero_ps(); +    __m128 a1, a2, m1, m2; + +    while (length >= 8) { +        a1 = _mm_loadu_ps(v); +        m1 = _mm_mul_ps(a1, a1); + +        a2 = _mm_loadu_ps(v + 4); +        sum1 = _mm_add_ps(sum1, m1); + +        m2 = _mm_mul_ps(a2, a2); +        sum2 = _mm_add_ps(sum2, m2); + +        length -= 8; +        v += 8; +    } + +    if (length >= 4) { +        a1 = _mm_loadu_ps(v); +        sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, a1)); + +        length -= 4; +        v += 4; +    } + +    sum1 = _mm_add_ps(sum1, sum2); + +    if (length) { +        switch (length) { +            case 3: +                a1 = _mm_set_ps(0.0f, v[2], v[1], v[0]); +                break; + +            case 2: +                a1 = _mm_set_ps(0.0f, 0.0f, v[1], v[0]); +                break; + +            case 1: +                a1 = _mm_set_ps(0.0f, 0.0f, 0.0f, v[0]); +                break; + +            default: +                Y_UNREACHABLE(); +        } + +        sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, a1)); +    } + +    alignas(16) float 
res[4]; +    _mm_store_ps(res, sum1); + +    return res[0] + res[1] + res[2] + res[3]; +} + +template <bool computeLL, bool computeLR, bool computeRR> +Y_FORCE_INLINE +static void TriWayDotProductIteration(__m128& sumLL, __m128& sumLR, __m128& sumRR, const __m128 a, const __m128 b) { +    if constexpr (computeLL) { +        sumLL = _mm_add_ps(sumLL, _mm_mul_ps(a, a)); +    } +    if constexpr (computeLR) { +        sumLR = _mm_add_ps(sumLR, _mm_mul_ps(a, b)); +    } +    if constexpr (computeRR) { +        sumRR = _mm_add_ps(sumRR, _mm_mul_ps(b, b)); +    } +} + + +template <bool computeLL, bool computeLR, bool computeRR> +static TTriWayDotProduct<float> TriWayDotProductImpl(const float* lhs, const float* rhs, size_t length) noexcept { +    __m128 sumLL1 = _mm_setzero_ps(); +    __m128 sumLR1 = _mm_setzero_ps(); +    __m128 sumRR1 = _mm_setzero_ps(); +    __m128 sumLL2 = _mm_setzero_ps(); +    __m128 sumLR2 = _mm_setzero_ps(); +    __m128 sumRR2 = _mm_setzero_ps(); + +    while (length >= 8) { +        TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, _mm_loadu_ps(lhs + 0), _mm_loadu_ps(rhs + 0)); +        TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL2, sumLR2, sumRR2, _mm_loadu_ps(lhs + 4), _mm_loadu_ps(rhs + 4)); +        length -= 8; +        lhs += 8; +        rhs += 8; +    } + +    if (length >= 4) { +        TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, _mm_loadu_ps(lhs + 0), _mm_loadu_ps(rhs + 0)); +        length -= 4; +        lhs += 4; +        rhs += 4; +    } + +    if constexpr (computeLL) { +        sumLL1 = _mm_add_ps(sumLL1, sumLL2); +    } +    if constexpr (computeLR) { +        sumLR1 = _mm_add_ps(sumLR1, sumLR2); +    } +    if constexpr (computeRR) { +        sumRR1 = _mm_add_ps(sumRR1, sumRR2); +    } + +    if (length) { +        __m128 a, b; +        switch (length) { +            case 3: +                a = _mm_set_ps(0.0f, lhs[2], lhs[1], lhs[0]); +  
              b = _mm_set_ps(0.0f, rhs[2], rhs[1], rhs[0]); +                break; +            case 2: +                a = _mm_set_ps(0.0f, 0.0f, lhs[1], lhs[0]); +                b = _mm_set_ps(0.0f, 0.0f, rhs[1], rhs[0]); +                break; +            case 1: +                a = _mm_set_ps(0.0f, 0.0f, 0.0f, lhs[0]); +                b = _mm_set_ps(0.0f, 0.0f, 0.0f, rhs[0]); +                break; +            default: +                Y_UNREACHABLE(); +        } +        TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, a, b); +    } + +    __m128 t0 = sumLL1; +    __m128 t1 = sumLR1; +    __m128 t2 = sumRR1; +    __m128 t3 = _mm_setzero_ps(); +    _MM_TRANSPOSE4_PS(t0, t1, t2, t3); +    t0 = _mm_add_ps(t0, t1); +    t0 = _mm_add_ps(t0, t2); +    t0 = _mm_add_ps(t0, t3); + +    alignas(16) float res[4]; +    _mm_store_ps(res, t0); +    TTriWayDotProduct<float> result{res[0], res[1], res[2]}; +    static constexpr const TTriWayDotProduct<float> def; +    // fill skipped fields with default values +    if constexpr (!computeLL) { +        result.LL = def.LL; +    } +    if constexpr (!computeLR) { +        result.LR = def.LR; +    } +    if constexpr (!computeRR) { +        result.RR = def.RR; +    } +    return result; +} + + +TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept { +    mask &= 0b111; +    if (Y_LIKELY(mask == 0b111)) { // compute dot-product and length² of two vectors +        return TriWayDotProductImpl<true, true, true>(lhs, rhs, length); +    } else if (Y_LIKELY(mask == 0b110 || mask == 0b011)) { // compute dot-product and length² of one vector +        const bool computeLL = (mask == 0b110); +        if (!computeLL) { +            DoSwap(lhs, rhs); +        } +        auto result = TriWayDotProductImpl<true, true, false>(lhs, rhs, length); +        if (!computeLL) { +            DoSwap(result.LL, result.RR); +        } +        return 
result; +    } else { +        // dispatch unlikely & sparse cases +        TTriWayDotProduct<float> result{}; +        switch(mask) { +            case 0b000: +                break; +            case 0b100: +                result.LL = L2NormSquared(lhs, length); +                break; +            case 0b010: +                result.LR = DotProduct(lhs, rhs, length); +                break; +            case 0b001: +                result.RR = L2NormSquared(rhs, length); +                break; +            case 0b101: +                result.LL = L2NormSquared(lhs, length); +                result.RR = L2NormSquared(rhs, length); +                break; +            default: +                Y_UNREACHABLE(); +        } +        return result; +    } +} + +#else + +float L2NormSquared(const float* v, size_t length) noexcept { +    return DotProduct(v, v, length); +} + +TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept { +    TTriWayDotProduct<float> result; +    if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::LL)) { +        result.LL = L2NormSquared(lhs, length); +    } +    if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::LR)) { +        result.LR = DotProduct(lhs, rhs, length); +    } +    if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::RR)) { +        result.RR = L2NormSquared(rhs, length); +    } +    return result; +} + +#endif // ARCADIA_SSE + +namespace NDotProduct { +    void DisableAvx2() { +#ifdef ARCADIA_SSE +        NDotProductImpl::DotProductI8Impl = &DotProductSse; +        NDotProductImpl::DotProductUi8Impl = &DotProductSse; +        NDotProductImpl::DotProductI32Impl = &DotProductSse; +        NDotProductImpl::DotProductFloatImpl = &DotProductSse; +        NDotProductImpl::DotProductDoubleImpl = &DotProductSse; +#else +        NDotProductImpl::DotProductI8Impl = &DotProductSimple; +        NDotProductImpl::DotProductUi8Impl = 
&DotProductSimple; +        NDotProductImpl::DotProductI32Impl = &DotProductSimple; +        NDotProductImpl::DotProductFloatImpl = &DotProductSimple; +        NDotProductImpl::DotProductDoubleImpl = &DotProductSimple; +#endif +    } +} diff --git a/library/cpp/dot_product/dot_product.h b/library/cpp/dot_product/dot_product.h new file mode 100644 index 00000000000..0765633abdb --- /dev/null +++ b/library/cpp/dot_product/dot_product.h @@ -0,0 +1,96 @@ +#pragma once + +#include <util/system/types.h> +#include <util/system/compiler.h> + +#include <numeric> + +/** + * Dot product (Inner product or scalar product) implementation using SSE when possible. + */ +namespace NDotProductImpl { +    extern i32 (*DotProductI8Impl)(const i8* lhs, const i8* rhs, size_t length) noexcept; +    extern ui32 (*DotProductUi8Impl)(const ui8* lhs, const ui8* rhs, size_t length) noexcept; +    extern i64 (*DotProductI32Impl)(const i32* lhs, const i32* rhs, size_t length) noexcept; +    extern float (*DotProductFloatImpl)(const float* lhs, const float* rhs, size_t length) noexcept; +    extern double (*DotProductDoubleImpl)(const double* lhs, const double* rhs, size_t length) noexcept; +} + +Y_PURE_FUNCTION +inline i32 DotProduct(const i8* lhs, const i8* rhs, size_t length) noexcept { +    return NDotProductImpl::DotProductI8Impl(lhs, rhs, length); +} + +Y_PURE_FUNCTION +inline ui32 DotProduct(const ui8* lhs, const ui8* rhs, size_t length) noexcept { +    return NDotProductImpl::DotProductUi8Impl(lhs, rhs, length); +} + +Y_PURE_FUNCTION +inline i64 DotProduct(const i32* lhs, const i32* rhs, size_t length) noexcept { +    return NDotProductImpl::DotProductI32Impl(lhs, rhs, length); +} + +Y_PURE_FUNCTION +inline float DotProduct(const float* lhs, const float* rhs, size_t length) noexcept { +    return NDotProductImpl::DotProductFloatImpl(lhs, rhs, length); +} + +Y_PURE_FUNCTION +inline double DotProduct(const double* lhs, const double* rhs, size_t length) noexcept { +    return 
NDotProductImpl::DotProductDoubleImpl(lhs, rhs, length); +} + +/** + * Dot product to itself + */ +Y_PURE_FUNCTION +float L2NormSquared(const float* v, size_t length) noexcept; + +// TODO(yazevnul): make `L2NormSquared` for double, this should be faster than `DotProduct` +// where `lhs == rhs` because it will save N load instructions. + +template <typename T> +struct TTriWayDotProduct { +    T LL = 1; +    T LR = 0; +    T RR = 1; +}; + +enum class ETriWayDotProductComputeMask: unsigned { +    // basic +    LL = 0b100, +    LR = 0b010, +    RR = 0b001, + +    // useful combinations +    All = 0b111, +    Left = 0b110, // skip computation of R·R +    Right = 0b011, // skip computation of L·L +}; + +Y_PURE_FUNCTION +TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept; + +/** + * For two vectors L and R computes 3 dot-products: L·L, L·R, R·R + */ +Y_PURE_FUNCTION +static inline TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, ETriWayDotProductComputeMask mask = ETriWayDotProductComputeMask::All) noexcept { +    return TriWayDotProduct(lhs, rhs, length, static_cast<unsigned>(mask)); +} + +namespace NDotProduct { +    // Simpler wrapper allowing to use these functions as template argument. 
+    template <typename T> +    struct TDotProduct { +        using TResult = decltype(DotProduct(static_cast<const T*>(nullptr), static_cast<const T*>(nullptr), 0)); +        Y_PURE_FUNCTION +        inline TResult operator()(const T* l, const T* r, size_t length) const { +            return DotProduct(l, r, length); +        } +    }; + +    void DisableAvx2(); +} + diff --git a/library/cpp/dot_product/dot_product_avx2.cpp b/library/cpp/dot_product/dot_product_avx2.cpp new file mode 100644 index 00000000000..a0f7c169ee7 --- /dev/null +++ b/library/cpp/dot_product/dot_product_avx2.cpp @@ -0,0 +1,344 @@ +#include "dot_product_avx2.h" +#include "dot_product_simple.h" +#include "dot_product_sse.h" + +#if defined(_avx2_) && defined(_fma_) + +#include <util/system/platform.h> +#include <util/system/compiler.h> +#include <util/generic/utility.h> + +#include <immintrin.h> + +namespace { +    constexpr i64 Bits(int n) { +        return i64(-1) ^ ((i64(1) << (64 - n)) - 1); +    } + +    constexpr __m256 BlendMask64[8] = { +        __m256i{Bits(64), Bits(64), Bits(64), Bits(64)}, +        __m256i{0, Bits(64), Bits(64), Bits(64)}, +        __m256i{0, 0, Bits(64), Bits(64)}, +        __m256i{0, 0, 0, Bits(64)}, +    }; + +    constexpr __m256 BlendMask32[8] = { +        __m256i{Bits(64), Bits(64), Bits(64), Bits(64)}, +        __m256i{Bits(32), Bits(64), Bits(64), Bits(64)}, +        __m256i{0, Bits(64), Bits(64), Bits(64)}, +        __m256i{0, Bits(32), Bits(64), Bits(64)}, +        __m256i{0, 0, Bits(64), Bits(64)}, +        __m256i{0, 0, Bits(32), Bits(64)}, +        __m256i{0, 0, 0, Bits(64)}, +        __m256i{0, 0, 0, Bits(32)}, +    }; + +    constexpr __m128 BlendMask8[16] = { +        __m128i{Bits(64), Bits(64)}, +        __m128i{Bits(56), Bits(64)}, +        __m128i{Bits(48), Bits(64)}, +        __m128i{Bits(40), Bits(64)}, +        __m128i{Bits(32), Bits(64)}, +        __m128i{Bits(24), Bits(64)}, +        __m128i{Bits(16), Bits(64)}, +        __m128i{Bits(8), 
Bits(64)}, +        __m128i{0, Bits(64)}, +        __m128i{0, Bits(56)}, +        __m128i{0, Bits(48)}, +        __m128i{0, Bits(40)}, +        __m128i{0, Bits(32)}, +        __m128i{0, Bits(24)}, +        __m128i{0, Bits(16)}, +        __m128i{0, Bits(8)}, +    }; + +    // See https://stackoverflow.com/a/60109639 +    // Horizontal sum of eight i32 values in an avx register +    i32 HsumI32(__m256i v) { +        __m128i x = _mm_add_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1)); +        __m128i hi64  = _mm_unpackhi_epi64(x, x); +        __m128i sum64 = _mm_add_epi32(hi64, x); +        __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); +        __m128i sum32 = _mm_add_epi32(sum64, hi32); +        return _mm_cvtsi128_si32(sum32); +    } + +    // Horizontal sum of four i64 values in an avx register +    i64 HsumI64(__m256i v) { +        __m128i x = _mm_add_epi64(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1)); +        return _mm_cvtsi128_si64(x) + _mm_extract_epi64(x, 1); +    } + +    // Horizontal sum of eight float values in an avx register +    float HsumFloat(__m256 v) { +        __m256 y = _mm256_permute2f128_ps(v, v, 1); +        v = _mm256_add_ps(v, y); +        v = _mm256_hadd_ps(v, v); +        return _mm256_cvtss_f32(_mm256_hadd_ps(v, v)); +    } + +    // Horizontal sum of four double values in an avx register +    double HsumDouble(__m256 v) { +        __m128d x = _mm_add_pd(_mm256_castpd256_pd128(v), _mm256_extractf128_pd(v, 1)); +        x = _mm_add_pd(x, _mm_shuffle_pd(x, x, 1)); +        return _mm_cvtsd_f64(x); +    } + +    __m128i Load128i(const void* ptr) { +        return _mm_loadu_si128((const __m128i*)ptr); +    } + +    __m256i Load256i(const void* ptr) { +        return _mm256_loadu_si256((const __m256i*)ptr); +    } + +    // Unrolled dot product for relatively small sizes +    // The loop with known upper bound is unrolled by the compiler, no need to do anything special about it +    
template <size_t size, class TInput, class TExtend> +    i32 DotProductInt8Avx2_Unroll(const TInput* lhs, const TInput* rhs, TExtend extend) noexcept { +        static_assert(size % 16 == 0); +        auto sum = _mm256_setzero_ps(); +        for (size_t i = 0; i != size; i += 16) { +            sum = _mm256_add_epi32(sum, _mm256_madd_epi16(extend(Load128i(lhs + i)), extend(Load128i(rhs + i)))); +        } + +        return HsumI32(sum); +    } + +    template <class TInput, class TExtend> +    i32 DotProductInt8Avx2(const TInput* lhs, const TInput* rhs, size_t length, TExtend extend) noexcept { +        // Fully unrolled versions for small multiples for 16 +        switch (length) { +            case 16: return DotProductInt8Avx2_Unroll<16>(lhs, rhs, extend); +            case 32: return DotProductInt8Avx2_Unroll<32>(lhs, rhs, extend); +            case 48: return DotProductInt8Avx2_Unroll<48>(lhs, rhs, extend); +            case 64: return DotProductInt8Avx2_Unroll<64>(lhs, rhs, extend); +        } + +        __m256i sum = _mm256_setzero_ps(); + +        if (const auto leftover = length % 16; leftover != 0) { +            auto a = _mm_blendv_epi8( +                    Load128i(lhs), _mm_setzero_ps(), BlendMask8[leftover]); +            auto b = _mm_blendv_epi8( +                    Load128i(rhs), _mm_setzero_ps(), BlendMask8[leftover]); + +            sum = _mm256_madd_epi16(extend(a), extend(b)); + +            lhs += leftover; +            rhs += leftover; +            length -= leftover; +        } + +        while (length >= 32) { +            const auto l0 = extend(Load128i(lhs)); +            const auto r0 = extend(Load128i(rhs)); +            const auto l1 = extend(Load128i(lhs + 16)); +            const auto r1 = extend(Load128i(rhs + 16)); + +            const auto s0 = _mm256_madd_epi16(l0, r0); +            const auto s1 = _mm256_madd_epi16(l1, r1); + +            sum = _mm256_add_epi32(sum, _mm256_add_epi32(s0, s1)); + +            lhs += 32; +         
   rhs += 32; +            length -= 32; +        } + +        if (length > 0) { +            auto l = extend(Load128i(lhs)); +            auto r = extend(Load128i(rhs)); + +            sum = _mm256_add_epi32(sum, _mm256_madd_epi16(l, r)); +        } + +        return HsumI32(sum); +    } +} + +i32 DotProductAvx2(const i8* lhs, const i8* rhs, size_t length) noexcept { +    if (length < 16) { +        return DotProductSse(lhs, rhs, length); +    } +    return DotProductInt8Avx2(lhs, rhs, length, [](const __m128i x) { +        return _mm256_cvtepi8_epi16(x); +    }); +} + +ui32 DotProductAvx2(const ui8* lhs, const ui8* rhs, size_t length) noexcept { +    if (length < 16) { +        return DotProductSse(lhs, rhs, length); +    } +    return DotProductInt8Avx2(lhs, rhs, length, [](const __m128i x) { +        return _mm256_cvtepu8_epi16(x); +    }); +} + +i64 DotProductAvx2(const i32* lhs, const i32* rhs, size_t length) noexcept { +    if (length < 16) { +        return DotProductSse(lhs, rhs, length); +    } +    __m256i res = _mm256_setzero_ps(); + +    if (const auto leftover = length % 8; leftover != 0) { +        // Use floating-point blendv. Who cares as long as the size is right. 
+        __m256i a = _mm256_blendv_ps( +                Load256i(lhs), _mm256_setzero_ps(), BlendMask32[leftover]); +        __m256i b = _mm256_blendv_ps( +                Load256i(rhs), _mm256_setzero_ps(), BlendMask32[leftover]); + +        res = _mm256_mul_epi32(a, b); +        a = _mm256_alignr_epi8(a, a, 4); +        b = _mm256_alignr_epi8(b, b, 4); +        res = _mm256_add_epi64(_mm256_mul_epi32(a, b), res); + +        lhs += leftover; +        rhs += leftover; +        length -= leftover; +    } + +    while (length >= 8) { +        __m256i a = Load256i(lhs); +        __m256i b = Load256i(rhs); +        res = _mm256_add_epi64(_mm256_mul_epi32(a, b), res);    // This is lower parts multiplication +        a = _mm256_alignr_epi8(a, a, 4); +        b = _mm256_alignr_epi8(b, b, 4); +        res = _mm256_add_epi64(_mm256_mul_epi32(a, b), res); +        rhs += 8; +        lhs += 8; +        length -= 8; +    } + +    return HsumI64(res); +} + +float DotProductAvx2(const float* lhs, const float* rhs, size_t length) noexcept { +    if (length < 16) { +        return DotProductSse(lhs, rhs, length); +    } +    __m256 sum1 = _mm256_setzero_ps(); +    __m256 sum2 = _mm256_setzero_ps(); +    __m256 a1, b1, a2, b2; + +    if (const auto leftover = length % 8; leftover != 0) { +        a1 = _mm256_blendv_ps( +                _mm256_loadu_ps(lhs), _mm256_setzero_ps(), BlendMask32[leftover]); +        b1 = _mm256_blendv_ps( +                _mm256_loadu_ps(rhs), _mm256_setzero_ps(), BlendMask32[leftover]); +        sum1 = _mm256_mul_ps(a1, b1); +        lhs += leftover; +        rhs += leftover; +        length -= leftover; +    } + +    while (length >= 16) { +        a1 = _mm256_loadu_ps(lhs); +        b1 = _mm256_loadu_ps(rhs); +        a2 = _mm256_loadu_ps(lhs + 8); +        b2 = _mm256_loadu_ps(rhs + 8); + +        sum1 = _mm256_fmadd_ps(a1, b1, sum1); +        sum2 = _mm256_fmadd_ps(a2, b2, sum2); + +        length -= 16; +        lhs += 16; +        rhs += 16; +    
} + +    if (length > 0) { +        a1 = _mm256_loadu_ps(lhs); +        b1 = _mm256_loadu_ps(rhs); +        sum1 = _mm256_fmadd_ps(a1, b1, sum1); +    } + +    return HsumFloat(_mm256_add_ps(sum1, sum2)); +} + +double DotProductAvx2(const double* lhs, const double* rhs, size_t length) noexcept { +    if (length < 16) { +        return DotProductSse(lhs, rhs, length); +    } +    __m256d sum1 = _mm256_setzero_pd(); +    __m256d sum2 = _mm256_setzero_pd(); +    __m256d a1, b1, a2, b2; + +    if (const auto leftover = length % 4; leftover != 0) { +        a1 = _mm256_blendv_pd( +                _mm256_loadu_pd(lhs), _mm256_setzero_ps(), BlendMask64[leftover]); +        b1 = _mm256_blendv_pd( +                _mm256_loadu_pd(rhs), _mm256_setzero_ps(), BlendMask64[leftover]); +        sum1 = _mm256_mul_pd(a1, b1); +        lhs += leftover; +        rhs += leftover; +        length -= leftover; +    } + +    while (length >= 8) { +        a1 = _mm256_loadu_pd(lhs); +        b1 = _mm256_loadu_pd(rhs); +        a2 = _mm256_loadu_pd(lhs + 4); +        b2 = _mm256_loadu_pd(rhs + 4); + +        sum1 = _mm256_fmadd_pd(a1, b1, sum1); +        sum2 = _mm256_fmadd_pd(a2, b2, sum2); + +        length -= 8; +        lhs += 8; +        rhs += 8; +    } + +    if (length > 0) { +        a1 = _mm256_loadu_pd(lhs); +        b1 = _mm256_loadu_pd(rhs); +        sum1 = _mm256_fmadd_pd(a1, b1, sum1); +    } + +    return HsumDouble(_mm256_add_pd(sum1, sum2)); +} + +#elif defined(ARCADIA_SSE) + +i32 DotProductAvx2(const i8* lhs, const i8* rhs, size_t length) noexcept { +    return DotProductSse(lhs, rhs, length); +} + +ui32 DotProductAvx2(const ui8* lhs, const ui8* rhs, size_t length) noexcept { +    return DotProductSse(lhs, rhs, length); +} + +i64 DotProductAvx2(const i32* lhs, const i32* rhs, size_t length) noexcept { +    return DotProductSse(lhs, rhs, length); +} + +float DotProductAvx2(const float* lhs, const float* rhs, size_t length) noexcept { +    return DotProductSse(lhs, rhs, 
length); +} + +double DotProductAvx2(const double* lhs, const double* rhs, size_t length) noexcept { +    return DotProductSse(lhs, rhs, length); +} + +#else + +i32 DotProductAvx2(const i8* lhs, const i8* rhs, size_t length) noexcept { +    return DotProductSimple(lhs, rhs, length); +} + +ui32 DotProductAvx2(const ui8* lhs, const ui8* rhs, size_t length) noexcept { +    return DotProductSimple(lhs, rhs, length); +} + +i64 DotProductAvx2(const i32* lhs, const i32* rhs, size_t length) noexcept { +    return DotProductSimple(lhs, rhs, length); +} + +float DotProductAvx2(const float* lhs, const float* rhs, size_t length) noexcept { +    return DotProductSimple(lhs, rhs, length); +} + +double DotProductAvx2(const double* lhs, const double* rhs, size_t length) noexcept { +    return DotProductSimple(lhs, rhs, length); +} + +#endif diff --git a/library/cpp/dot_product/dot_product_avx2.h b/library/cpp/dot_product/dot_product_avx2.h new file mode 100644 index 00000000000..715f151f448 --- /dev/null +++ b/library/cpp/dot_product/dot_product_avx2.h @@ -0,0 +1,19 @@ +#pragma once + +#include <util/system/types.h> +#include <util/system/compiler.h> + +Y_PURE_FUNCTION +i32 DotProductAvx2(const i8* lhs, const i8* rhs, size_t length) noexcept; + +Y_PURE_FUNCTION +ui32 DotProductAvx2(const ui8* lhs, const ui8* rhs, size_t length) noexcept; + +Y_PURE_FUNCTION +i64 DotProductAvx2(const i32* lhs, const i32* rhs, size_t length) noexcept; + +Y_PURE_FUNCTION +float DotProductAvx2(const float* lhs, const float* rhs, size_t length) noexcept; + +Y_PURE_FUNCTION +double DotProductAvx2(const double* lhs, const double* rhs, size_t length) noexcept; diff --git a/library/cpp/dot_product/dot_product_simple.cpp b/library/cpp/dot_product/dot_product_simple.cpp new file mode 100644 index 00000000000..02891c8a228 --- /dev/null +++ b/library/cpp/dot_product/dot_product_simple.cpp @@ -0,0 +1,44 @@ +#include "dot_product_simple.h" + +namespace { +    template <typename Res, typename Number> +    static 
Res DotProductSimpleImpl(const Number* lhs, const Number* rhs, size_t length) noexcept { +        Res s0 = 0; +        Res s1 = 0; +        Res s2 = 0; +        Res s3 = 0; + +        while (length >= 4) { +            s0 += static_cast<Res>(lhs[0]) * static_cast<Res>(rhs[0]); +            s1 += static_cast<Res>(lhs[1]) * static_cast<Res>(rhs[1]); +            s2 += static_cast<Res>(lhs[2]) * static_cast<Res>(rhs[2]); +            s3 += static_cast<Res>(lhs[3]) * static_cast<Res>(rhs[3]); +            lhs += 4; +            rhs += 4; +            length -= 4; +        } + +        while (length--) { +            s0 += static_cast<Res>(*lhs++) * static_cast<Res>(*rhs++); +        } + +        return s0 + s1 + s2 + s3; +    } +} + +float DotProductSimple(const float* lhs, const float* rhs, size_t length) noexcept { +    return DotProductSimpleImpl<float, float>(lhs, rhs, length); +} + +double DotProductSimple(const double* lhs, const double* rhs, size_t length) noexcept { +    return DotProductSimpleImpl<double, double>(lhs, rhs, length); +} + +ui32 DotProductUI4Simple(const ui8* lhs, const ui8* rhs, size_t lengtInBytes) noexcept { +    ui32 res = 0; +    for (size_t i = 0; i < lengtInBytes; ++i) { +        res += static_cast<ui32>(lhs[i] & 0x0f) * static_cast<ui32>(rhs[i] & 0x0f); +        res += static_cast<ui32>(lhs[i] & 0xf0) * static_cast<ui32>(rhs[i] & 0xf0) >> 8; +    } +    return res; +} diff --git a/library/cpp/dot_product/dot_product_simple.h b/library/cpp/dot_product/dot_product_simple.h new file mode 100644 index 00000000000..dd13dd7592c --- /dev/null +++ b/library/cpp/dot_product/dot_product_simple.h @@ -0,0 +1,40 @@ +#pragma once + +#include <util/system/compiler.h> +#include <util/system/types.h> + +#include <numeric> + +/** + * Dot product implementation without SSE optimizations. 
+ */
+Y_PURE_FUNCTION
+inline ui32 DotProductSimple(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
+    // A ui8 * ui8 product is at most 255 * 255 = 65025, so it fits in ui16;
+    // the running total is kept in ui32.
+    return std::inner_product(lhs, lhs + length, rhs, static_cast<ui32>(0u),
+                              [](ui32 x1, ui16 x2) {return x1 + x2;},
+                              [](ui16 x1, ui8 x2) {return x1 * x2;});
+}
+
+/**
+ * Scalar dot product of two i8 arrays of `length` elements, accumulated in i32.
+ */
+Y_PURE_FUNCTION
+inline i32 DotProductSimple(const i8* lhs, const i8* rhs, size_t length) noexcept {
+    // An i8 * i8 product lies in [-16256, 16384], so it fits in i16.
+    return std::inner_product(lhs, lhs + length, rhs, static_cast<i32>(0),
+                              [](i32 x1, i16 x2) {return x1 + x2;},
+                              [](i16 x1, i8 x2) {return x1 * x2;});
+}
+
+/**
+ * Scalar dot product of two i32 arrays of `length` elements, accumulated in i64.
+ */
+Y_PURE_FUNCTION
+inline i64 DotProductSimple(const i32* lhs, const i32* rhs, size_t length) noexcept {
+    // The left operand is widened to i64 by the lambda signature, so each
+    // multiplication is performed in 64 bits.
+    return std::inner_product(lhs, lhs + length, rhs, static_cast<i64>(0),
+                              [](i64 x1, i64 x2) {return x1 + x2;},
+                              [](i64 x1, i32 x2) {return x1 * x2;});
+}
+
+// Scalar float overload; only declared here, the definition is not in this header.
+Y_PURE_FUNCTION
+float DotProductSimple(const float* lhs, const float* rhs, size_t length) noexcept;
+
+// Scalar double overload; only declared here, the definition is not in this header.
+Y_PURE_FUNCTION
+double DotProductSimple(const double* lhs, const double* rhs, size_t length) noexcept;
+
+// NOTE(review): judging by the name, the inputs hold packed 4-bit values and the
+// size is given in bytes rather than in elements - confirm against the definition.
+Y_PURE_FUNCTION
+ui32 DotProductUI4Simple(const ui8* lhs, const ui8* rhs, size_t lengthInBytes) noexcept;
+
diff --git a/library/cpp/dot_product/dot_product_sse.cpp b/library/cpp/dot_product/dot_product_sse.cpp
new file mode 100644
index 00000000000..5256cfe98af
--- /dev/null
+++ b/library/cpp/dot_product/dot_product_sse.cpp
@@ -0,0 +1,219 @@
+#include "dot_product_sse.h"
+
+#include <library/cpp/sse/sse.h>
+#include <util/system/platform.h>
+#include <util/system/compiler.h>
+
+#ifdef ARCADIA_SSE
+/**
+ * Dot product of two i8 arrays of `length` elements, accumulated in i32.
+ *
+ * The main loop consumes 16 elements per iteration: the bytes are sign-extended
+ * to 16 bits, multiplied pairwise and added into four 32-bit lanes with
+ * _mm_madd_epi16. The remaining (length % 16) elements go through a scalar loop.
+ */
+i32 DotProductSse(const i8* lhs, const i8* rhs, size_t length) noexcept {
+    const __m128i zero = _mm_setzero_si128();
+    __m128i resVec = zero;
+    while (length >= 16) {
+        __m128i lVec = _mm_loadu_si128((const __m128i*)lhs);
+        __m128i rVec = _mm_loadu_si128((const __m128i*)rhs);
+
+#ifdef _sse4_1_
+        // Rotating the register by 8 bytes moves its upper half into the lower
+        // half, which is the part _mm_cvtepi8_epi16 sign-extends.
+        __m128i lLo = _mm_cvtepi8_epi16(lVec);
+        __m128i rLo = _mm_cvtepi8_epi16(rVec);
+        __m128i lHi = _mm_cvtepi8_epi16(_mm_alignr_epi8(lVec, lVec, 8));
+        __m128i rHi = _mm_cvtepi8_epi16(_mm_alignr_epi8(rVec, rVec, 8));
+#else
+        // Without SSE4.1: place every byte into the high half of a 16-bit lane,
+        // then shift right arithmetically by 8 to sign-extend it.
+        __m128i lLo = _mm_srai_epi16(_mm_unpacklo_epi8(zero, lVec), 8);
+        __m128i rLo = _mm_srai_epi16(_mm_unpacklo_epi8(zero, rVec), 8);
+        __m128i lHi = _mm_srai_epi16(_mm_unpackhi_epi8(zero, lVec), 8);
+        __m128i rHi = _mm_srai_epi16(_mm_unpackhi_epi8(zero, rVec), 8);
+#endif
+        resVec = _mm_add_epi32(resVec,
+                               _mm_add_epi32(_mm_madd_epi16(lLo, rLo), _mm_madd_epi16(lHi, rHi)));
+
+        lhs += 16;
+        rhs += 16;
+        length -= 16;
+    }
+
+    // Horizontal sum of the four 32-bit lanes, then the scalar tail.
+    alignas(16) i32 res[4];
+    _mm_store_si128((__m128i*)res, resVec);
+    i32 sum = res[0] + res[1] + res[2] + res[3];
+    for (size_t i = 0; i < length; ++i) {
+        sum += static_cast<i32>(lhs[i]) * static_cast<i32>(rhs[i]);
+    }
+
+    return sum;
+}
+
+/**
+ * Dot product of two ui8 arrays of `length` elements, accumulated in ui32.
+ *
+ * The main loop consumes 16 elements per iteration: the bytes are zero-extended
+ * to 16 bits (at most 255, so still non-negative as signed 16-bit values) and
+ * combined with _mm_madd_epi16. The remaining (length % 16) elements go through
+ * a scalar loop.
+ */
+ui32 DotProductSse(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
+    const __m128i zero = _mm_setzero_si128();
+    __m128i resVec = zero;
+    while (length >= 16) {
+        __m128i lVec = _mm_loadu_si128((const __m128i*)lhs);
+        __m128i rVec = _mm_loadu_si128((const __m128i*)rhs);
+
+        // Interleaving with zero bytes zero-extends every element to 16 bits.
+        __m128i lLo = _mm_unpacklo_epi8(lVec, zero);
+        __m128i rLo = _mm_unpacklo_epi8(rVec, zero);
+        __m128i lHi = _mm_unpackhi_epi8(lVec, zero);
+        __m128i rHi = _mm_unpackhi_epi8(rVec, zero);
+
+        resVec = _mm_add_epi32(resVec,
+                               _mm_add_epi32(_mm_madd_epi16(lLo, rLo), _mm_madd_epi16(lHi, rHi)));
+
+        lhs += 16;
+        rhs += 16;
+        length -= 16;
+    }
+
+    // The lanes are read back and summed as ui32: the result type is unsigned,
+    // and doing this step in i32 would be signed overflow (undefined behaviour)
+    // as soon as the total exceeds INT32_MAX. Unsigned arithmetic wraps modulo 2^32.
+    alignas(16) ui32 res[4];
+    _mm_store_si128((__m128i*)res, resVec);
+    ui32 sum = res[0] + res[1] + res[2] + res[3];
+    for (size_t i = 0; i < length; ++i) {
+        sum += static_cast<ui32>(lhs[i]) * static_cast<ui32>(rhs[i]);
+    }
+
+    return sum;
+}
+#ifdef _sse4_1_
+
+/**
+ * Dot product of two i32 arrays of `length` elements, accumulated in i64.
+ *
+ * The main loop consumes 4 elements per iteration; the remaining (length % 4)
+ * elements go through a scalar loop.
+ */
+i64 DotProductSse(const i32* lhs, const i32* rhs, size_t length) noexcept {
+    __m128i zero = _mm_setzero_si128();
+    __m128i res = zero;
+
+    while (length >= 4) {
+        __m128i a = _mm_loadu_si128((const __m128i*)lhs);
+        __m128i b = _mm_loadu_si128((const __m128i*)rhs);
+        res = _mm_add_epi64(_mm_mul_epi32(a, b), res);    // This is lower parts multiplication
+        // _mm_mul_epi32 only multiplies 32-bit elements 0 and 2, so rotate both
+        // registers by 4 bytes to bring elements 1 and 3 into those positions.
+        a = _mm_alignr_epi8(a, a, 4);
+        b = _mm_alignr_epi8(b, b, 4);
+        res = _mm_add_epi64(_mm_mul_epi32(a, b), res);
+        rhs += 4;
+        lhs += 4;
+        length -= 4;
+    }
+
+    // Horizontal sum of the two 64-bit lanes, then the scalar tail.
+    alignas(16) i64 r[2];
+    _mm_store_si128((__m128i*)r, res);
+    i64 sum = r[0] + r[1];
+
+    for (size_t i = 0; i < length; ++i) {
+        sum += static_cast<i64>(lhs[i]) * static_cast<i64>(rhs[i]);
+    }
+
+    return sum;
+}
+
+#else
+#include "dot_product_simple.h"
+
+// Without SSE4.1 the i32 overload forwards to the scalar implementation.
+i64 DotProductSse(const i32* lhs, const i32* rhs, size_t length) noexcept {
+    return DotProductSimple(lhs, rhs, length);
+}
+
+#endif
+
+/**
+ * Dot product of two float arrays of `length` elements.
+ *
+ * Uses two independent 4-lane accumulators (8 elements per iteration), then one
+ * optional 4-element step, then a zero-padded vector for the last 1..3 elements.
+ * The lanes are added together at the end, so the order of additions differs
+ * from a plain sequential loop.
+ */
+float DotProductSse(const float* lhs, const float* rhs, size_t length) noexcept {
+    __m128 sum1 = _mm_setzero_ps();
+    __m128 sum2 = _mm_setzero_ps();
+    __m128 a1, b1, a2, b2, m1, m2;
+
+    while (length >= 8) {
+        a1 = _mm_loadu_ps(lhs);
+        b1 = _mm_loadu_ps(rhs);
+        m1 = _mm_mul_ps(a1, b1);
+
+        a2 = _mm_loadu_ps(lhs + 4);
+        sum1 = _mm_add_ps(sum1, m1);
+
+        b2 = _mm_loadu_ps(rhs + 4);
+        m2 = _mm_mul_ps(a2, b2);
+
+        sum2 = _mm_add_ps(sum2, m2);
+
+        length -= 8;
+        lhs += 8;
+        rhs += 8;
+    }
+
+    if (length >= 4) {
+        a1 = _mm_loadu_ps(lhs);
+        b1 = _mm_loadu_ps(rhs);
+        sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, b1));
+
+        length -= 4;
+        lhs += 4;
+        rhs += 4;
+    }
+
+    sum1 = _mm_add_ps(sum1, sum2);
+
+    if (length) {
+        // At this point length is 1, 2 or 3. Build the vectors element by
+        // element, padding the unused lanes with zeros (_mm_set_ps takes its
+        // arguments from the highest lane down to the lowest).
+        switch (length) {
+            case 3:
+                a1 = _mm_set_ps(0.0f, lhs[2], lhs[1], lhs[0]);
+                b1 = _mm_set_ps(0.0f, rhs[2], rhs[1], rhs[0]);
+                break;
+
+            case 2:
+                a1 = _mm_set_ps(0.0f, 0.0f, lhs[1], lhs[0]);
+                b1 = _mm_set_ps(0.0f, 0.0f, rhs[1], rhs[0]);
+                break;
+
+            case 1:
+                a1 = _mm_set_ps(0.0f, 0.0f, 0.0f, lhs[0]);
+                b1 = _mm_set_ps(0.0f, 0.0f, 0.0f, rhs[0]);
+                break;
+
+            default:
+                Y_UNREACHABLE();
+        }
+
+        sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, b1));
+    }
+
+    alignas(16) float res[4];
+    _mm_store_ps(res, sum1);
+
+    return res[0] + res[1] + res[2] + res[3];
+}
+
+/**
+ * Dot product of two double arrays of `length` elements.
+ *
+ * Uses two independent 2-lane accumulators (4 elements per iteration), then one
+ * optional 2-element step, then a zero-padded vector for a last single element.
+ * The lanes are added together at the end, so the order of additions differs
+ * from a plain sequential loop.
+ */
+double DotProductSse(const double* lhs, const double* rhs, size_t length) noexcept {
+    __m128d sum1 = _mm_setzero_pd();
+    __m128d sum2 = _mm_setzero_pd();
+    __m128d a1, b1, a2, b2;
+
+    while (length >= 4) {
+        a1 = _mm_loadu_pd(lhs);
+        b1 = _mm_loadu_pd(rhs);
+        sum1 = _mm_add_pd(sum1, _mm_mul_pd(a1, b1));
+
+        a2 = _mm_loadu_pd(lhs + 2);
+        b2 = _mm_loadu_pd(rhs + 2);
+        sum2 = _mm_add_pd(sum2, _mm_mul_pd(a2, b2));
+
+        length -= 4;
+        lhs += 4;
+        rhs += 4;
+    }
+
+    if (length >= 2) {
+        a1 = _mm_loadu_pd(lhs);
+        b1 = _mm_loadu_pd(rhs);
+        sum1 = _mm_add_pd(sum1, _mm_mul_pd(a1, b1));
+
+        length -= 2;
+        lhs += 2;
+        rhs += 2;
+    }
+
+    sum1 = _mm_add_pd(sum1, sum2);
+
+    if (length > 0) {
+        // One element is left: pair it with a zero in the other lane.
+        a1 = _mm_set_pd(lhs[0], 0.0);
+        b1 = _mm_set_pd(rhs[0], 0.0);
+        sum1 = _mm_add_pd(sum1, _mm_mul_pd(a1, b1));
+    }
+
+    alignas(16) double res[2];
+    _mm_store_pd(res, sum1);
+
+    return res[0] + res[1];
+}
+
+#endif // ARCADIA_SSE
diff --git a/library/cpp/dot_product/dot_product_sse.h b/library/cpp/dot_product/dot_product_sse.h
new file mode 100644
index 00000000000..814736007d0
--- /dev/null
+++
b/library/cpp/dot_product/dot_product_sse.h @@ -0,0 +1,19 @@
+#pragma once
+
+#include <util/system/types.h>
+#include <util/system/compiler.h>
+// SSE dot product of two i8 arrays of `length` elements; the result is returned as i32.
+Y_PURE_FUNCTION
+i32 DotProductSse(const i8* lhs, const i8* rhs, size_t length) noexcept;
+// SSE dot product of two ui8 arrays of `length` elements; the result is returned as ui32.
+Y_PURE_FUNCTION
+ui32 DotProductSse(const ui8* lhs, const ui8* rhs, size_t length) noexcept;
+// Dot product of two i32 arrays of `length` elements; the result is returned as i64.
+Y_PURE_FUNCTION
+i64 DotProductSse(const i32* lhs, const i32* rhs, size_t length) noexcept;
+// SSE dot product of two float arrays of `length` elements.
+Y_PURE_FUNCTION
+float DotProductSse(const float* lhs, const float* rhs, size_t length) noexcept;
+// SSE dot product of two double arrays of `length` elements.
+Y_PURE_FUNCTION
+double DotProductSse(const double* lhs, const double* rhs, size_t length) noexcept;
diff --git a/library/cpp/dot_product/ya.make b/library/cpp/dot_product/ya.make
new file mode 100644
index 00000000000..b308967b4be
--- /dev/null
+++ b/library/cpp/dot_product/ya.make
@@ -0,0 +1,20 @@
+LIBRARY()
+
+SRCS(
+    dot_product.cpp
+    dot_product_sse.cpp
+    dot_product_simple.cpp
+)
+
+IF (USE_SSE4 == "yes" AND OS_LINUX == "yes")
+    SRC_C_AVX2(dot_product_avx2.cpp -mfma)
+ELSE()
+    SRC(dot_product_avx2.cpp)
+ENDIF()
+
+PEERDIR(
+    library/cpp/sse
+    library/cpp/testing/common
+)
+
+END()  | 
