diff options
author | azevaykin <azevaykin@yandex-team.com> | 2024-05-13 16:49:46 +0300 |
---|---|---|
committer | azevaykin <azevaykin@yandex-team.com> | 2024-05-13 17:00:18 +0300 |
commit | b7deb7f0b71db7419781d1b0357dfa443ccc3ff1 (patch) | |
tree | bb0628671666543d7cd991debdeb4892f444e67b /library/cpp | |
parent | d2ca018b203eba118a0af8da0775b9c51650f29b (diff) | |
download | ydb-b7deb7f0b71db7419781d1b0357dfa443ccc3ff1.tar.gz |
Publish l1_distance & l2_distance
Publish l1_distance & l2_distance to https://github.com/ydb-platform/ydb
It has already been published to github: https://github.com/catboost/catboost/tree/master/library/cpp/
a6fd3da173e50ff5a518af0fd5b354f56ca72fdf
Diffstat (limited to 'library/cpp')
-rw-r--r-- | library/cpp/l1_distance/README.md | 15 | ||||
-rw-r--r-- | library/cpp/l1_distance/l1_distance.h | 477 | ||||
-rw-r--r-- | library/cpp/l1_distance/ya.make | 11 | ||||
-rw-r--r-- | library/cpp/l2_distance/README.md | 15 | ||||
-rw-r--r-- | library/cpp/l2_distance/l2_distance.cpp | 376 | ||||
-rw-r--r-- | library/cpp/l2_distance/l2_distance.h | 140 | ||||
-rw-r--r-- | library/cpp/l2_distance/ya.make | 12 |
7 files changed, 1046 insertions, 0 deletions
diff --git a/library/cpp/l1_distance/README.md b/library/cpp/l1_distance/README.md new file mode 100644 index 0000000000..a25f09f12a --- /dev/null +++ b/library/cpp/l1_distance/README.md @@ -0,0 +1,15 @@ +Библиотека для вычисления расстояния между векторами. +===================================================== + +Данная библиотека содержит функцию L1Distance, вычисляющую расстояние между векторами разных типов. +В отличии от наивной реализации, библиотека использует SSE и работает существенно быстрее. Для +сравнения можно посмотреть результаты бенчмарка. + +Типичное использование - замена кусков кода вроде: +``` +for (int i = 0; i < len; i++) + dist += std::abs(a[i] - b[i]); +``` +на существенно более эффективный вызов ```L1Distance(a, b, len)```. + +Работает для типов i8, ui8, i32, ui32, float, double. diff --git a/library/cpp/l1_distance/l1_distance.h b/library/cpp/l1_distance/l1_distance.h new file mode 100644 index 0000000000..71545cbe33 --- /dev/null +++ b/library/cpp/l1_distance/l1_distance.h @@ -0,0 +1,477 @@ +#pragma once + +#include <library/cpp/sse/sse.h> + +#include <util/system/types.h> +#include <util/generic/ymath.h> +#include <util/system/align.h> +#include <util/system/platform.h> + +namespace NL1Distance { + namespace NPrivate { + template <typename T> + inline T AbsDelta(T a, T b) { + if (a < b) + return b - a; + return a - b; + } + + template <typename Result, typename Number> + inline Result L1DistanceImpl(const Number* lhs, const Number* rhs, int length) { + Result sum = 0; + + for (int i = 0; i < length; i++) + sum += AbsDelta(lhs[i], rhs[i]); + + return sum; + } + + template <typename Result, typename Number> + inline Result L1DistanceImpl2(const Number* lhs, const Number* rhs, int length) { + Result s0 = 0; + Result s1 = 0; + + while (length >= 2) { + s0 += AbsDelta(lhs[0], rhs[0]); + s1 += AbsDelta(lhs[1], rhs[1]); + lhs += 2; + rhs += 2; + length -= 2; + } + + while (length--) + s0 += AbsDelta(*lhs++, *rhs++); + + return s0 + s1; + } + + template <typename Result, typename Number> + inline Result L1DistanceImpl4(const Number* lhs, const Number* rhs, int length) { + Result s0 = 0; + Result s1 = 0; + Result s2 = 0; + Result s3 = 0; + + while (length >= 4) { + s0 += AbsDelta(lhs[0], rhs[0]); + s1 += AbsDelta(lhs[1], rhs[1]); + s2 += AbsDelta(lhs[2], rhs[2]); + s3 += AbsDelta(lhs[3], rhs[3]); + lhs += 4; + rhs += 4; + length -= 4; + } + + while (length--) + s0 += AbsDelta(*lhs++, *rhs++); + + return s0 + s1 + s2 + s3; + } + + template <typename Result> + inline Result L1DistanceImplUI4(const ui8* lhs, const ui8* rhs, int lengtInBytes) { + Result sum = 0; + + for (int i = 0; i < lengtInBytes; ++i) { + sum += AbsDelta(lhs[i] & 0x0f, rhs[i] & 0x0f); + sum += AbsDelta(lhs[i] & 0xf0, rhs[i] & 0xf0) >> 4; + } + + return sum; + } + +#ifdef ARCADIA_SSE + static const __m128i MASK_UI4_1 = _mm_set_epi8(0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f); + static const __m128i MASK_UI4_2 = _mm_set_epi8(0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0); + + + Y_FORCE_INLINE ui32 L1Distance96Ui8(const ui8* lhs, const ui8* rhs) { + __m128i x1 = _mm_loadu_si128((const __m128i*)&lhs[0]); + __m128i y1 = _mm_loadu_si128((const __m128i*)&rhs[0]); + + __m128i sum = _mm_sad_epu8(x1, y1); + + __m128i x2 = _mm_loadu_si128((const __m128i*)&lhs[16]); + __m128i y2 = _mm_loadu_si128((const __m128i*)&rhs[16]); + + sum = _mm_add_epi64(sum, _mm_sad_epu8(x2, y2)); + + __m128i x3 = _mm_loadu_si128((const __m128i*)&lhs[32]); + __m128i y3 = _mm_loadu_si128((const __m128i*)&rhs[32]); + + sum = _mm_add_epi64(sum, _mm_sad_epu8(x3, y3)); + + __m128i x4 = _mm_loadu_si128((const __m128i*)&lhs[48]); + __m128i y4 = _mm_loadu_si128((const __m128i*)&rhs[48]); + + sum = _mm_add_epi64(sum, _mm_sad_epu8(x4, y4)); + + __m128i x5 = _mm_loadu_si128((const __m128i*)&lhs[64]); + __m128i y5 = _mm_loadu_si128((const __m128i*)&rhs[64]); + + sum = _mm_add_epi64(sum, _mm_sad_epu8(x5, y5)); + + __m128i x6 = _mm_loadu_si128((const __m128i*)&lhs[80]); + __m128i y6 = _mm_loadu_si128((const __m128i*)&rhs[80]); + + sum = _mm_add_epi64(sum, _mm_sad_epu8(x6, y6)); + return _mm_cvtsi128_si32(sum) + _mm_cvtsi128_si32(_mm_shuffle_epi32(sum, _MM_SHUFFLE(2, 2, 2, 2))); + } + + Y_FORCE_INLINE ui32 L1Distance96Ui4(const ui8* lhs, const ui8* rhs) { + __m128i x1 = _mm_loadu_si128((const __m128i*)&lhs[0]); + __m128i y1 = _mm_loadu_si128((const __m128i*)&rhs[0]); + __m128i sum1 = _mm_sad_epu8(_mm_and_si128(x1, MASK_UI4_1), _mm_and_si128(y1, MASK_UI4_1)); + __m128i sum2 = _mm_sad_epu8(_mm_and_si128(x1, MASK_UI4_2), _mm_and_si128(y1, MASK_UI4_2)); + + __m128i x2 = _mm_loadu_si128((const __m128i*)&lhs[16]); + __m128i y2 = _mm_loadu_si128((const __m128i*)&rhs[16]); + sum1 = _mm_add_epi64(sum1, _mm_sad_epu8(_mm_and_si128(x2, MASK_UI4_1), _mm_and_si128(y2, MASK_UI4_1))); + sum2 = _mm_add_epi64(sum2, _mm_sad_epu8(_mm_and_si128(x2, MASK_UI4_2), _mm_and_si128(y2, MASK_UI4_2))); + + __m128i x3 = _mm_loadu_si128((const __m128i*)&lhs[32]); + __m128i y3 = _mm_loadu_si128((const __m128i*)&rhs[32]); + sum1 = _mm_add_epi64(sum1, _mm_sad_epu8(_mm_and_si128(x3, MASK_UI4_1), _mm_and_si128(y3, MASK_UI4_1))); + sum2 = _mm_add_epi64(sum2, _mm_sad_epu8(_mm_and_si128(x3, MASK_UI4_2), _mm_and_si128(y3, MASK_UI4_2))); + + __m128i x4 = _mm_loadu_si128((const __m128i*)&lhs[48]); + __m128i y4 = _mm_loadu_si128((const __m128i*)&rhs[48]); + sum1 = _mm_add_epi64(sum1, _mm_sad_epu8(_mm_and_si128(x4, MASK_UI4_1), _mm_and_si128(y4, MASK_UI4_1))); + sum2 = _mm_add_epi64(sum2, _mm_sad_epu8(_mm_and_si128(x4, MASK_UI4_2), _mm_and_si128(y4, MASK_UI4_2))); + + __m128i x5 = _mm_loadu_si128((const __m128i*)&lhs[64]); + __m128i y5 = _mm_loadu_si128((const __m128i*)&rhs[64]); + sum1 = _mm_add_epi64(sum1, _mm_sad_epu8(_mm_and_si128(x5, MASK_UI4_1), _mm_and_si128(y5, MASK_UI4_1))); + sum2 = _mm_add_epi64(sum2, _mm_sad_epu8(_mm_and_si128(x5, MASK_UI4_2), _mm_and_si128(y5, MASK_UI4_2))); + + __m128i x6 = _mm_loadu_si128((const __m128i*)&lhs[80]); + __m128i y6 = _mm_loadu_si128((const __m128i*)&rhs[80]); + sum1 = _mm_add_epi64(sum1, _mm_sad_epu8(_mm_and_si128(x6, MASK_UI4_1), _mm_and_si128(y6, MASK_UI4_1))); + sum2 = _mm_add_epi64(sum2, _mm_sad_epu8(_mm_and_si128(x6, MASK_UI4_2), _mm_and_si128(y6, MASK_UI4_2))); + + return _mm_cvtsi128_si32(sum1) + _mm_cvtsi128_si32(_mm_shuffle_epi32(sum1, _MM_SHUFFLE(2, 2, 2, 2))) + + ((_mm_cvtsi128_si32(sum2) + _mm_cvtsi128_si32(_mm_shuffle_epi32(sum2, _MM_SHUFFLE(2, 2, 2, 2)))) >> 4); + } +#endif // ARCADIA_SSE + } // namespace NPrivate +} + +/** + * L1Distance (sum(abs(l[i] - r[i]))) implementation using SSE when possible. + */ +#ifdef ARCADIA_SSE + +Y_FORCE_INLINE ui32 L1Distance(const i8* lhs, const i8* rhs, int length) { + static const __m128i unsignedToSignedDiff = _mm_set_epi8( + -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128); + __m128i resVec = _mm_setzero_si128(); + + while (length >= 16) { + __m128i lVec = _mm_sub_epi8(_mm_loadu_si128((const __m128i*)lhs), unsignedToSignedDiff); + __m128i rVec = _mm_sub_epi8(_mm_loadu_si128((const __m128i*)rhs), unsignedToSignedDiff); + + resVec = _mm_add_epi64(_mm_sad_epu8(lVec, rVec), resVec); + + lhs += 16; + rhs += 16; + length -= 16; + } + + alignas(16) i64 res[2]; + _mm_store_si128((__m128i*)res, resVec); + ui32 sum = res[0] + res[1]; + for (int i = 0; i < length; ++i) { + const i32 diff = static_cast<i32>(lhs[i]) - static_cast<i32>(rhs[i]); + sum += (diff >= 0) ? diff : -diff; + } + + return sum; +} + +Y_FORCE_INLINE ui32 L1Distance(const ui8* lhs, const ui8* rhs, int length) { + if (length == 96) + return NL1Distance::NPrivate::L1Distance96Ui8(lhs, rhs); + + int l16 = length & (~15); + __m128i sum = _mm_setzero_si128(); + + if ((reinterpret_cast<uintptr_t>(lhs) & 0x0f) || (reinterpret_cast<uintptr_t>(rhs) & 0x0f)) { + for (int i = 0; i < l16; i += 16) { + __m128i a = _mm_loadu_si128((const __m128i*)(&lhs[i])); + __m128i b = _mm_loadu_si128((const __m128i*)(&rhs[i])); + + sum = _mm_add_epi64(sum, _mm_sad_epu8(a, b)); + } + } else { + for (int i = 0; i < l16; i += 16) { + __m128i sum_ab = _mm_sad_epu8(*(const __m128i*)(&lhs[i]), *(const __m128i*)(&rhs[i])); + sum = _mm_add_epi64(sum, sum_ab); + } + } + + if (l16 == length) + return _mm_cvtsi128_si32(sum) + _mm_cvtsi128_si32(_mm_shuffle_epi32(sum, _MM_SHUFFLE(2, 2, 2, 2))); + + int l4 = length & (~3); + for (int i = l16; i < l4; i += 4) { + __m128i a = _mm_set_epi32(*((const ui32*)&lhs[i]), 0, 0, 0); + __m128i b = _mm_set_epi32(*((const ui32*)&rhs[i]), 0, 0, 0); + sum = _mm_add_epi64(sum, _mm_sad_epu8(a, b)); + } + + ui32 res = _mm_cvtsi128_si32(sum) + _mm_cvtsi128_si32(_mm_shuffle_epi32(sum, _MM_SHUFFLE(2, 2, 2, 2))); + + for (int i = l4; i < length; i++) + res += lhs[i] < rhs[i] ? rhs[i] - lhs[i] : lhs[i] - rhs[i]; + + return res; +} + +Y_FORCE_INLINE ui32 L1DistanceUI4(const ui8* lhs, const ui8* rhs, int lengtInBytes) { + + if (lengtInBytes == 96) + return NL1Distance::NPrivate::L1Distance96Ui4(lhs, rhs); + + int l16 = lengtInBytes & (~15); + __m128i sum1 = _mm_setzero_si128(); + __m128i sum2 = _mm_setzero_si128(); + + for (int i = 0; i < l16; i += 16) { + __m128i a = _mm_loadu_si128((const __m128i*)(&lhs[i])); + __m128i b = _mm_loadu_si128((const __m128i*)(&rhs[i])); + + sum1 = _mm_add_epi64(sum1, _mm_sad_epu8(_mm_and_si128(a, NL1Distance::NPrivate::MASK_UI4_1), _mm_and_si128(b, NL1Distance::NPrivate::MASK_UI4_1))); + sum2 = _mm_add_epi64(sum2, _mm_sad_epu8(_mm_and_si128(a, NL1Distance::NPrivate::MASK_UI4_2), _mm_and_si128(b, NL1Distance::NPrivate::MASK_UI4_2))); + } + + if (l16 == lengtInBytes) + return _mm_cvtsi128_si32(sum1) + _mm_cvtsi128_si32(_mm_shuffle_epi32(sum1, _MM_SHUFFLE(2, 2, 2, 2))) + + ((_mm_cvtsi128_si32(sum2) + _mm_cvtsi128_si32(_mm_shuffle_epi32(sum2, _MM_SHUFFLE(2, 2, 2, 2)))) >> 4); + + int l4 = lengtInBytes & (~3); + for (int i = l16; i < l4; i += 4) { + __m128i a = _mm_set_epi32(*((const ui32*)&lhs[i]), 0, 0, 0); + __m128i b = _mm_set_epi32(*((const ui32*)&rhs[i]), 0, 0, 0); + sum1 = _mm_add_epi64(sum1, _mm_sad_epu8(_mm_and_si128(a, NL1Distance::NPrivate::MASK_UI4_1), _mm_and_si128(b, NL1Distance::NPrivate::MASK_UI4_1))); + sum2 = _mm_add_epi64(sum2, _mm_sad_epu8(_mm_and_si128(a, NL1Distance::NPrivate::MASK_UI4_2), _mm_and_si128(b, NL1Distance::NPrivate::MASK_UI4_2))); + } + + ui32 res = _mm_cvtsi128_si32(sum1) + _mm_cvtsi128_si32(_mm_shuffle_epi32(sum1, _MM_SHUFFLE(2, 2, 2, 2))) + + ((_mm_cvtsi128_si32(sum2) + _mm_cvtsi128_si32(_mm_shuffle_epi32(sum2, _MM_SHUFFLE(2, 2, 2, 2)))) >> 4); + + for (int i = l4; i < lengtInBytes; ++i) { + ui8 a1 = lhs[i] & 0x0f; + ui8 a2 = (lhs[i] & 0xf0) >> 4; + ui8 b1 = rhs[i] & 0x0f; + ui8 b2 = (rhs[i] & 0xf0) >> 4; + res += a1 < b1 ? b1 - a1 : a1 - b1; + res += a2 < b2 ? b2 - a2 : a2 - b2; + } + + return res; +} + +Y_FORCE_INLINE ui64 L1Distance(const i32* lhs, const i32* rhs, int length) { + __m128i zero = _mm_setzero_si128(); + __m128i res = zero; + + while (length >= 4) { + __m128i a = _mm_loadu_si128((const __m128i*)lhs); + __m128i b = _mm_loadu_si128((const __m128i*)rhs); + __m128i mask = _mm_cmpgt_epi32(a, b); + __m128i a2 = _mm_and_si128(mask, _mm_sub_epi32(a, b)); + b = _mm_andnot_si128(mask, _mm_sub_epi32(b, a)); + a = _mm_or_si128(a2, b); + res = _mm_add_epi64(_mm_unpackhi_epi32(a, zero), res); + res = _mm_add_epi64(_mm_unpacklo_epi32(a, zero), res); + rhs += 4; + lhs += 4; + length -= 4; + } + + alignas(16) ui64 r[2]; + _mm_store_si128((__m128i*)r, res); + ui64 sum = r[0] + r[1]; + + while (length) { + sum += lhs[0] < rhs[0] ? rhs[0] - lhs[0] : lhs[0] - rhs[0]; + ++lhs; + ++rhs; + --length; + } + + return sum; +} + +Y_FORCE_INLINE ui64 L1Distance(const ui32* lhs, const ui32* rhs, int length) { + __m128i zero = _mm_setzero_si128(); + __m128i shift = _mm_set1_epi32(0x80000000); + __m128i res = zero; + + while (length >= 4) { + __m128i a = _mm_add_epi32(_mm_loadu_si128((const __m128i*)lhs), shift); + __m128i b = _mm_add_epi32(_mm_loadu_si128((const __m128i*)rhs), shift); + __m128i mask = _mm_cmpgt_epi32(a, b); + __m128i a2 = _mm_and_si128(mask, _mm_sub_epi32(a, b)); + b = _mm_andnot_si128(mask, _mm_sub_epi32(b, a)); + a = _mm_or_si128(a2, b); + res = _mm_add_epi64(_mm_unpackhi_epi32(a, zero), res); + res = _mm_add_epi64(_mm_unpacklo_epi32(a, zero), res); + rhs += 4; + lhs += 4; + length -= 4; + } + + alignas(16) ui64 r[2]; + _mm_store_si128((__m128i*)r, res); + ui64 sum = r[0] + r[1]; + + while (length) { + sum += lhs[0] < rhs[0] ? rhs[0] - lhs[0] : lhs[0] - rhs[0]; + ++lhs; + ++rhs; + --length; + } + + return sum; +} + +Y_FORCE_INLINE float L1Distance(const float* lhs, const float* rhs, int length) { + __m128 res = _mm_setzero_ps(); + __m128 absMask = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffff)); + + while (length >= 4) { + __m128 a = _mm_loadu_ps(lhs); + __m128 b = _mm_loadu_ps(rhs); + __m128 d = _mm_sub_ps(a, b); + res = _mm_add_ps(_mm_and_ps(d, absMask), res); + rhs += 4; + lhs += 4; + length -= 4; + } + + alignas(16) float r[4]; + _mm_store_ps(r, res); + float sum = r[0] + r[1] + r[2] + r[3]; + + while (length) { + sum += std::abs(*lhs - *rhs); + ++lhs; + ++rhs; + --length; + } + + return sum; +} + +Y_FORCE_INLINE double L1Distance(const double* lhs, const double* rhs, int length) { + __m128d res = _mm_setzero_pd(); + __m128d absMask = _mm_castsi128_pd(_mm_set_epi32(0x7fffffff, 0xffffffff, 0x7fffffff, 0xffffffff)); + + while (length >= 2) { + __m128d a = _mm_loadu_pd(lhs); + __m128d b = _mm_loadu_pd(rhs); + __m128d d = _mm_sub_pd(a, b); + res = _mm_add_pd(_mm_and_pd(d, absMask), res); + rhs += 2; + lhs += 2; + length -= 2; + } + + alignas(16) double r[2]; + _mm_store_pd(r, res); + double sum = r[0] + r[1]; + + while (length) { + sum += std::abs(*lhs - *rhs); + ++lhs; + ++rhs; + --length; + } + + return sum; +} + +#else // ARCADIA_SSE + +inline ui32 L1Distance(const i8* lhs, const i8* rhs, int length) { + return NL1Distance::NPrivate::L1DistanceImpl<ui32, i8>(lhs, rhs, length); +} + +inline ui32 L1Distance(const ui8* lhs, const ui8* rhs, int length) { + return NL1Distance::NPrivate::L1DistanceImpl<ui32, ui8>(lhs, rhs, length); +} + +inline ui32 L1DistanceUI4(const ui8* lhs, const ui8* rhs, int lengtInBytes) { + return NL1Distance::NPrivate::L1DistanceImplUI4<ui32>(lhs, rhs, lengtInBytes); +} + +inline ui64 L1Distance(const ui32* lhs, const ui32* rhs, int length) { + return NL1Distance::NPrivate::L1DistanceImpl2<ui64, ui32>(lhs, rhs, length); +} + +inline ui64 L1Distance(const i32* lhs, const i32* rhs, int length) { + return NL1Distance::NPrivate::L1DistanceImpl2<ui64, i32>(lhs, rhs, length); +} + +inline float L1Distance(const float* lhs, const float* rhs, int length) { + return NL1Distance::NPrivate::L1DistanceImpl4<float, float>(lhs, rhs, length); +} + +inline double L1Distance(const double* lhs, const double* rhs, int length) { + return NL1Distance::NPrivate::L1DistanceImpl4<double, double>(lhs, rhs, length); +} + +#endif // _sse_ + +/** + * L1Distance (sum(abs(l[i] - r[i]))) implementation without SSE. + */ +inline ui32 L1DistanceSlow(const i8* lhs, const i8* rhs, int length) { + return NL1Distance::NPrivate::L1DistanceImpl<ui32, i8>(lhs, rhs, length); +} + +inline ui32 L1DistanceSlow(const ui8* lhs, const ui8* rhs, int length) { + return NL1Distance::NPrivate::L1DistanceImpl<ui32, ui8>(lhs, rhs, length); +} + +inline ui32 L1DistanceUI4Slow(const ui8* lhs, const ui8* rhs, int lengtInBytes) { + return NL1Distance::NPrivate::L1DistanceImplUI4<ui32>(lhs, rhs, lengtInBytes); +} + +inline ui64 L1DistanceSlow(const ui32* lhs, const ui32* rhs, int length) { + return NL1Distance::NPrivate::L1DistanceImpl2<ui64, ui32>(lhs, rhs, length); +} + +inline ui64 L1DistanceSlow(const i32* lhs, const i32* rhs, int length) { + return NL1Distance::NPrivate::L1DistanceImpl2<ui64, i32>(lhs, rhs, length); +} + +inline float L1DistanceSlow(const float* lhs, const float* rhs, int length) { + return NL1Distance::NPrivate::L1DistanceImpl4<float, float>(lhs, rhs, length); +} + +inline double L1DistanceSlow(const double* lhs, const double* rhs, int length) { + return NL1Distance::NPrivate::L1DistanceImpl4<double, double>(lhs, rhs, length); +} + +namespace NL1Distance { + // Simpler wrapper allowing to use this functions as template argument. + template <typename T> + struct TL1Distance { + using TResult = decltype(L1Distance(static_cast<const T*>(nullptr), static_cast<const T*>(nullptr), 0)); + + inline TResult operator()(const T* a, const T* b, int length) const { + return L1Distance(a, b, length); + } + }; + + struct TL1DistanceUI4 { + using TResult = ui32; + + inline TResult operator()(const ui8* a, const ui8* b, int lengtInBytes) const { + return L1DistanceUI4(a, b, lengtInBytes); + } + }; +} diff --git a/library/cpp/l1_distance/ya.make b/library/cpp/l1_distance/ya.make new file mode 100644 index 0000000000..9345cb99af --- /dev/null +++ b/library/cpp/l1_distance/ya.make @@ -0,0 +1,11 @@ +LIBRARY() + +SRCS( + l1_distance.h +) + +PEERDIR( + library/cpp/sse +) + +END() diff --git a/library/cpp/l2_distance/README.md b/library/cpp/l2_distance/README.md new file mode 100644 index 0000000000..e6d3b7ad41 --- /dev/null +++ b/library/cpp/l2_distance/README.md @@ -0,0 +1,15 @@ +Библиотека для вычисления расстояния между векторами. +===================================================== + +Данная библиотека содержит две функции L2Distance и L2SqrDistance. Первая вычисляет L2 расстояние между векторами +разных типов, а вторая его квадрат. В отличии от наивной реализации, библиотека использует SSE и работает существенно +быстрее. Для сравнения можно посмотреть результаты бенчмарка. + +Типичное использование - замена кусков кода вроде: +``` +for (int i = 0; i < len; i++) + dist += (a[i] - b[i]) * (a[i] - b[i]); +``` +на существенно более эффективный вызов ```L2SqrDistance(a, b, len)```. + +Работает для типов i8, ui8, i32, ui32, float, double. diff --git a/library/cpp/l2_distance/l2_distance.cpp b/library/cpp/l2_distance/l2_distance.cpp new file mode 100644 index 0000000000..2266e10aaa --- /dev/null +++ b/library/cpp/l2_distance/l2_distance.cpp @@ -0,0 +1,376 @@ +#include "l2_distance.h" + +#include <library/cpp/sse/sse.h> + +#include <contrib/libs/cblas/include/cblas.h> + +#include <util/system/platform.h> + +template <typename Result, typename Number> +inline Result SqrDelta(Number a, Number b) { + Result diff = a < b ? b - a : a - b; + return diff * diff; +} + +template <typename Result, typename Number> +inline Result L2SqrDistanceImpl(const Number* a, const Number* b, int length) { + Result res = 0; + + for (int i = 0; i < length; i++) { + res += SqrDelta<Result, Number>(a[i], b[i]); + } + + return res; +} + +template <typename Result, typename Number> +inline Result L2SqrDistanceImpl2(const Number* a, const Number* b, int length) { + Result s0 = 0; + Result s1 = 0; + + while (length >= 2) { + s0 += SqrDelta<Result, Number>(a[0], b[0]); + s1 += SqrDelta<Result, Number>(a[1], b[1]); + a += 2; + b += 2; + length -= 2; + } + + while (length--) + s0 += SqrDelta<Result, Number>(*a++, *b++); + + return s0 + s1; +} + +template <typename Result, typename Number> +inline Result L2SqrDistanceImpl4(const Number* a, const Number* b, int length) { + Result s0 = 0; + Result s1 = 0; + Result s2 = 0; + Result s3 = 0; + + while (length >= 4) { + s0 += SqrDelta<Result, Number>(a[0], b[0]); + s1 += SqrDelta<Result, Number>(a[1], b[1]); + s2 += SqrDelta<Result, Number>(a[2], b[2]); + s3 += SqrDelta<Result, Number>(a[3], b[3]); + a += 4; + b += 4; + length -= 4; + } + + while (length--) + s0 += SqrDelta<Result, Number>(*a++, *b++); + + return s0 + s1 + s2 + s3; +} + +inline ui32 L2SqrDistanceImplUI4(const ui8* a, const ui8* b, int length) { + ui32 res = 0; + for (int i = 0; i < length; i++) { + res += SqrDelta<ui32, ui8>(a[i] & 0x0f, b[i] & 0x0f); + res += SqrDelta<ui32, ui8>(a[i] & 0xf0, b[i] & 0xf0) >> 8; + } + return res; +} + + +#ifdef ARCADIA_SSE +namespace NL2Distance { + static const __m128i MASK_UI4_1 = _mm_set_epi8(0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, + 0x0f, 0x0f, 0x0f, 0x0f, 0x0f); + static const __m128i MASK_UI4_2 = _mm_set_epi8(0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0, 0xf0, + 0xf0, 0xf0, 0xf0, 0xf0, 0xf0); +} +ui32 L2SqrDistance(const i8* lhs, const i8* rhs, int length) { + const __m128i zero = _mm_setzero_si128(); + __m128i resVec = zero; + + while (length >= 16) { + __m128i vec = _mm_subs_epi8(_mm_loadu_si128((const __m128i*)lhs), _mm_loadu_si128((const __m128i*)rhs)); + +#ifdef _sse4_1_ + __m128i lo = _mm_cvtepi8_epi16(vec); + __m128i hi = _mm_cvtepi8_epi16(_mm_alignr_epi8(vec, vec, 8)); +#else + __m128i lo = _mm_srai_epi16(_mm_unpacklo_epi8(zero, vec), 8); + __m128i hi = _mm_srai_epi16(_mm_unpackhi_epi8(zero, vec), 8); +#endif + + resVec = _mm_add_epi32(resVec, + _mm_add_epi32(_mm_madd_epi16(lo, lo), _mm_madd_epi16(hi, hi))); + + lhs += 16; + rhs += 16; + length -= 16; + } + + alignas(16) ui32 res[4]; + _mm_store_si128((__m128i*)res, resVec); + ui32 sum = res[0] + res[1] + res[2] + res[3]; + for (int i = 0; i < length; ++i) { + sum += Sqr(static_cast<i32>(lhs[i]) - static_cast<i32>(rhs[i])); + } + + return sum; +} + +ui32 L2SqrDistance(const ui8* lhs, const ui8* rhs, int length) { + const __m128i zero = _mm_setzero_si128(); + __m128i resVec = zero; + + while (length >= 16) { + __m128i lVec = _mm_loadu_si128((const __m128i*)lhs); + __m128i rVec = _mm_loadu_si128((const __m128i*)rhs); + + // We will think about this vectors as about i16. + __m128i lo = _mm_sub_epi16(_mm_unpacklo_epi8(lVec, zero), _mm_unpacklo_epi8(rVec, zero)); + __m128i hi = _mm_sub_epi16(_mm_unpackhi_epi8(lVec, zero), _mm_unpackhi_epi8(rVec, zero)); + + resVec = _mm_add_epi32(resVec, + _mm_add_epi32(_mm_madd_epi16(lo, lo), _mm_madd_epi16(hi, hi))); + + lhs += 16; + rhs += 16; + length -= 16; + } + + alignas(16) ui32 res[4]; + _mm_store_si128((__m128i*)res, resVec); + ui32 sum = res[0] + res[1] + res[2] + res[3]; + for (int i = 0; i < length; ++i) { + sum += Sqr(static_cast<i32>(lhs[i]) - static_cast<i32>(rhs[i])); + } + + return sum; +} + +float L2SqrDistance(const float* lhs, const float* rhs, int length) { + __m128 sum = _mm_setzero_ps(); + + while (length >= 4) { + __m128 a = _mm_loadu_ps(lhs); + __m128 b = _mm_loadu_ps(rhs); + __m128 delta = _mm_sub_ps(a, b); + sum = _mm_add_ps(sum, _mm_mul_ps(delta, delta)); + length -= 4; + rhs += 4; + lhs += 4; + } + + alignas(16) float res[4]; + _mm_store_ps(res, sum); + + while (length--) + res[0] += Sqr(*rhs++ - *lhs++); + + return res[0] + res[1] + res[2] + res[3]; +} + +double L2SqrDistance(const double* lhs, const double* rhs, int length) { + __m128d sum = _mm_setzero_pd(); + + while (length >= 2) { + __m128d a = _mm_loadu_pd(lhs); + __m128d b = _mm_loadu_pd(rhs); + __m128d delta = _mm_sub_pd(a, b); + sum = _mm_add_pd(sum, _mm_mul_pd(delta, delta)); + length -= 2; + rhs += 2; + lhs += 2; + } + + alignas(16) double res[2]; + _mm_store_pd(res, sum); + + while (length--) + res[0] += Sqr(*rhs++ - *lhs++); + + return res[0] + res[1]; +} + +ui64 L2SqrDistance(const i32* lhs, const i32* rhs, int length) { + __m128i zero = _mm_setzero_si128(); + __m128i res = zero; + + while (length >= 4) { + __m128i a = _mm_loadu_si128((const __m128i*)lhs); + __m128i b = _mm_loadu_si128((const __m128i*)rhs); + +#ifdef _sse4_1_ + // In SSE4.1 si32*si32->si64 is available, so we may do just (a-b)*(a-b) not caring about (a-b) sign + a = _mm_sub_epi32(a, b); + res = _mm_add_epi64(_mm_mul_epi32(a, a), res); + a = _mm_alignr_epi8(a, a, 4); + res = _mm_add_epi64(_mm_mul_epi32(a, a), res); +#else + __m128i mask = _mm_cmpgt_epi32(a, b); // mask = a > b? 0xffffffff: 0; + __m128i a2 = _mm_sub_epi32(_mm_and_si128(mask, a), _mm_and_si128(mask, b)); // a2 = (a & mask) - (b & mask) (for a > b) + b = _mm_sub_epi32(_mm_andnot_si128(mask, b), _mm_andnot_si128(mask, a)); // b = (b & ~mask) - (a & ~mask) (for b > a) + a = _mm_or_si128(a2, b); // a = abs(a - b) + a2 = _mm_unpackhi_epi32(a, zero); + res = _mm_add_epi64(_mm_mul_epu32(a2, a2), res); + a2 = _mm_unpacklo_epi32(a, zero); + res = _mm_add_epi64(_mm_mul_epu32(a2, a2), res); +#endif + + rhs += 4; + lhs += 4; + length -= 4; + } + + alignas(16) ui64 r[2]; + _mm_store_si128((__m128i*)r, res); + ui64 sum = r[0] + r[1]; + + while (length) { + sum += SqrDelta<ui64, i32>(lhs[0], rhs[0]); + ++lhs; + ++rhs; + --length; + } + + return sum; +} + +ui64 L2SqrDistance(const ui32* lhs, const ui32* rhs, int length) { + __m128i zero = _mm_setzero_si128(); + __m128i shift = _mm_set1_epi32(0x80000000); + __m128i res = zero; + + while (length >= 4) { + __m128i a = _mm_add_epi32(_mm_loadu_si128((const __m128i*)lhs), shift); + __m128i b = _mm_add_epi32(_mm_loadu_si128((const __m128i*)rhs), shift); + __m128i mask = _mm_cmpgt_epi32(a, b); // mask = a > b? 0xffffffff: 0; + __m128i a2 = _mm_sub_epi32(_mm_and_si128(mask, a), _mm_and_si128(mask, b)); // a2 = (a & mask) - (b & mask) (for a > b) + b = _mm_sub_epi32(_mm_andnot_si128(mask, b), _mm_andnot_si128(mask, a)); // b = (b & ~mask) - (a & ~mask) (for b > a) + a = _mm_or_si128(a2, b); // a = abs(a - b) + +#ifdef _sse4_1_ + res = _mm_add_epi64(_mm_mul_epu32(a, a), res); + a = _mm_alignr_epi8(a, a, 4); + res = _mm_add_epi64(_mm_mul_epu32(a, a), res); +#else + a2 = _mm_unpackhi_epi32(a, zero); + res = _mm_add_epi64(_mm_mul_epu32(a2, a2), res); + a2 = _mm_unpacklo_epi32(a, zero); + res = _mm_add_epi64(_mm_mul_epu32(a2, a2), res); +#endif + + rhs += 4; + lhs += 4; + length -= 4; + } + + alignas(16) ui64 r[2]; + _mm_store_si128((__m128i*)r, res); + ui64 sum = r[0] + r[1]; + + while (length) { + sum += SqrDelta<ui64, ui32>(lhs[0], rhs[0]); + ++lhs; + ++rhs; + --length; + } + + return sum; +} + +ui32 L2SqrDistanceUI4(const ui8* lhs, const ui8* rhs, int length) { + const __m128i zero = _mm_setzero_si128(); + __m128i resVec1 = zero; + __m128i resVec2 = zero; + + while (length >= 16) { + __m128i lVec = _mm_loadu_si128((const __m128i*)lhs); + __m128i rVec = _mm_loadu_si128((const __m128i*)rhs); + + __m128i lVec1 = _mm_and_si128(lVec, NL2Distance::MASK_UI4_1); + __m128i lVec2 = _mm_and_si128(lVec, NL2Distance::MASK_UI4_2); + __m128i rVec1 = _mm_and_si128(rVec, NL2Distance::MASK_UI4_1); + __m128i rVec2 = _mm_and_si128(rVec, NL2Distance::MASK_UI4_2); + // We will think about this vectors as about i16. + __m128i lo1 = _mm_sub_epi16(_mm_unpacklo_epi8(lVec1, zero), _mm_unpacklo_epi8(rVec1, zero)); + __m128i hi1 = _mm_sub_epi16(_mm_unpackhi_epi8(lVec1, zero), _mm_unpackhi_epi8(rVec1, zero)); + __m128i lo2 = _mm_sub_epi16(_mm_unpacklo_epi8(lVec2, zero), _mm_unpacklo_epi8(rVec2, zero)); + __m128i hi2 = _mm_sub_epi16(_mm_unpackhi_epi8(lVec2, zero), _mm_unpackhi_epi8(rVec2, zero)); + + resVec1 = _mm_add_epi32(resVec1, _mm_add_epi32(_mm_madd_epi16(lo1, lo1), _mm_madd_epi16(hi1, hi1))); + resVec2 = _mm_add_epi32(resVec2, _mm_add_epi32(_mm_madd_epi16(lo2, lo2), _mm_madd_epi16(hi2, hi2))); + + lhs += 16; + rhs += 16; + length -= 16; + } + + alignas(16) ui32 res[4]; + _mm_store_si128((__m128i*)res, resVec1); + ui32 sum = res[0] + res[1] + res[2] + res[3]; + _mm_store_si128((__m128i*)res, resVec2); + sum += (res[0] + res[1] + res[2] + res[3]) >> 8; + for (int i = 0; i < length; ++i) { + sum += Sqr(static_cast<i32>(lhs[i] & 0x0f) - static_cast<i32>(rhs[i] & 0x0f)); + sum += Sqr(static_cast<i32>(lhs[i] & 0xf0) - static_cast<i32>(rhs[i] & 0xf0)) >> 8; + } + return sum; +} + +#else /* !ARCADIA_SSE */ + +ui32 L2SqrDistance(const i8* lhs, const i8* rhs, int length) { + return L2SqrDistanceImpl<ui32, i8>(lhs, rhs, length); +} + +ui32 L2SqrDistance(const ui8* lhs, const ui8* rhs, int length) { + return L2SqrDistanceImpl<ui32, ui8>(lhs, rhs, length); +} + +ui64 L2SqrDistance(const i32* a, const i32* b, int length) { + return L2SqrDistanceImpl2<ui64, i32>(a, b, length); +} + +ui64 L2SqrDistance(const ui32* a, const ui32* b, int length) { + return L2SqrDistanceImpl2<ui64, ui32>(a, b, length); +} + +float L2SqrDistance(const float* a, const float* b, int length) { + return L2SqrDistanceImpl4<float, float>(a, b, length); +} + +double L2SqrDistance(const double* a, const double* b, int length) { + return L2SqrDistanceImpl2<double, double>(a, b, length); +} + +ui32 L2SqrDistanceUI4(const ui8* lhs, const ui8* rhs, int length) { + return L2SqrDistanceImplUI4(lhs, rhs, length); +} + +#endif /* ARCADIA_SSE */ + +ui32 L2SqrDistanceSlow(const i8* lhs, const i8* rhs, int length) { + return L2SqrDistanceImpl<ui32, i8>(lhs, rhs, length); +} + +ui32 L2SqrDistanceSlow(const ui8* lhs, const ui8* rhs, int length) { + return L2SqrDistanceImpl<ui32, ui8>(lhs, rhs, length); +} + +ui64 L2SqrDistanceSlow(const i32* a, const i32* b, int length) { + return L2SqrDistanceImpl2<ui64, i32>(a, b, length); +} + +ui64 L2SqrDistanceSlow(const ui32* a, const ui32* b, int length) { + return L2SqrDistanceImpl2<ui64, ui32>(a, b, length); +} + +float L2SqrDistanceSlow(const float* a, const float* b, int length) { + return L2SqrDistanceImpl4<float, float>(a, b, length); +} + +double L2SqrDistanceSlow(const double* a, const double* b, int length) { + return L2SqrDistanceImpl2<double, double>(a, b, length); +} + +ui32 L2SqrDistanceUI4Slow(const ui8* lhs, const ui8* rhs, int length) { + return L2SqrDistanceImplUI4(lhs, rhs, length); +} diff --git a/library/cpp/l2_distance/l2_distance.h b/library/cpp/l2_distance/l2_distance.h new file mode 100644 index 0000000000..0106be70a5 --- /dev/null +++ b/library/cpp/l2_distance/l2_distance.h @@ -0,0 +1,140 @@ +#pragma once + +#include <util/system/types.h> +#include <util/generic/ymath.h> +#include <cmath> + +namespace NPrivate { + namespace NL2Distance { + template <typename Number> + inline Number L2DistanceSqrt(Number a) { + return std::sqrt(a); + } + + template <> + inline ui64 L2DistanceSqrt(ui64 a) { + // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_.28base_2.29 + ui64 res = 0; + ui64 bit = static_cast<ui64>(1) << (sizeof(ui64) * 8 - 2); + + while (bit > a) + bit >>= 2; + + while (bit != 0) { + if (a >= res + bit) { + a -= (res + bit); + res = (res >> 1) + bit; + } else { + res >>= 1; + } + bit >>= 2; + } + + return res; + } + + template <> + inline ui32 L2DistanceSqrt(ui32 a) { + return L2DistanceSqrt<ui64>(a); + } + + // Special class to match argument type and result type. + template <typename Arg> + class TMatchArgumentResult { + public: + using TResult = Arg; + }; + + template <> + class TMatchArgumentResult<i8> { + public: + using TResult = ui32; + }; + + template <> + class TMatchArgumentResult<ui8> { + public: + using TResult = ui32; + }; + + template <> + class TMatchArgumentResult<i32> { + public: + using TResult = ui64; + }; + + template <> + class TMatchArgumentResult<ui32> { + public: + using TResult = ui64; + }; + + } + +} + +/** + * sqr(l2_distance) = sum((a[i]-b[i])^2) + * If target system does not support SSE2 Slow functions are used automatically. + */ +ui32 L2SqrDistance(const i8* a, const i8* b, int cnt); +ui32 L2SqrDistance(const ui8* a, const ui8* b, int cnt); +ui64 L2SqrDistance(const i32* a, const i32* b, int length); +ui64 L2SqrDistance(const ui32* a, const ui32* b, int length); +float L2SqrDistance(const float* a, const float* b, int length); +double L2SqrDistance(const double* a, const double* b, int length); +ui32 L2SqrDistanceUI4(const ui8* a, const ui8* b, int cnt); + +ui32 L2SqrDistanceSlow(const i8* a, const i8* b, int cnt); +ui32 L2SqrDistanceSlow(const ui8* a, const ui8* b, int cnt); +ui64 L2SqrDistanceSlow(const i32* a, const i32* b, int length); +ui64 L2SqrDistanceSlow(const ui32* a, const ui32* b, int length); +float L2SqrDistanceSlow(const float* a, const float* b, int length); +double L2SqrDistanceSlow(const double* a, const double* b, int length); +ui32 L2SqrDistanceUI4Slow(const ui8* a, const ui8* b, int cnt); + +/** + * L2 distance = sqrt(sum((a[i]-b[i])^2)) + */ +template <typename Number, typename Result = typename NPrivate::NL2Distance::TMatchArgumentResult<Number>::TResult> +inline Result L2Distance(const Number* a, const Number* b, int cnt) { + return NPrivate::NL2Distance::L2DistanceSqrt(L2SqrDistance(a, b, cnt)); +} + +template <typename Number, typename Result = typename NPrivate::NL2Distance::TMatchArgumentResult<Number>::TResult> +inline Result L2DistanceSlow(const Number* a, const Number* b, int cnt) { + return NPrivate::NL2Distance::L2DistanceSqrt(L2SqrDistanceSlow(a, b, cnt)); +} + +namespace NL2Distance { + // You can use this structures as template function arguments. + template <typename T> + struct TL2Distance { + using TResult = decltype(L2Distance(static_cast<const T*>(nullptr), static_cast<const T*>(nullptr), 0)); + inline TResult operator()(const T* a, const T* b, int length) const { + return L2Distance(a, b, length); + } + }; + + struct TL2DistanceUI4 { + using TResult = ui32; + inline TResult operator()(const ui8* a, const ui8* b, int lengtInBytes) const { + return NPrivate::NL2Distance::L2DistanceSqrt(L2SqrDistanceUI4(a, b, lengtInBytes)); + } + }; + + template <typename T> + struct TL2SqrDistance { + using TResult = decltype(L2SqrDistance(static_cast<const T*>(nullptr), static_cast<const T*>(nullptr), 0)); + inline TResult operator()(const T* a, const T* b, int length) const { + return L2SqrDistance(a, b, length); + } + }; + + struct TL2SqrDistanceUI4 { + using TResult = ui32; + inline TResult operator()(const ui8* a, const ui8* b, int lengtInBytes) const { + return L2SqrDistanceUI4(a, b, lengtInBytes); + } + }; +} diff --git a/library/cpp/l2_distance/ya.make b/library/cpp/l2_distance/ya.make new file mode 100644 index 0000000000..919e77ae4a --- /dev/null +++ b/library/cpp/l2_distance/ya.make @@ -0,0 +1,12 @@ +LIBRARY() + +SRCS( + l2_distance.h + l2_distance.cpp +) + +PEERDIR( + library/cpp/sse +) + +END() |