// Unit tests for the dot-product routines from dot_product.h and
// dot_product_simple.h: DotProduct, DotProductSimple, L2NormSquared,
// TriWayDotProduct and DotProductUI4Simple. Optimized results are checked
// against a local reference (SimpleDotProduct) and against pinned golden
// values.
//
// NOTE(review): this copy of the file has lost everything that was written
// between angle brackets:
//   - the #include directives below have no header name;
//   - each `template` has no parameter list;
//   - `TVector` has no element type;
//   - `static_cast` has no target type;
//   - the SimpleDotProduct call sites carry no explicit template arguments
//     (its `Res` parameter appears only in the return type, so it cannot be
//     deduced).
// As shown it will not compile. Restore the bracketed text from version
// control rather than guessing it.
// The missing headers presumably provide the UNIT_ASSERT macros, TVector,
// TReallyFastRng32, TString/TStringBuilder, ToString and std::fabs/std::abs.
// The test-name suffixes (8 / 8u / 32 / f / d) presumably encode the element
// types. Confirm all of this against history.
#include "dot_product.h"
#include "dot_product_simple.h"
#include
#include
#include
#include
#include

// NOTE(review): "TDocProductTestSuite" looks like a typo for "TDotProduct...".
// Renaming it changes the registered suite name, so check test filters first.
Y_UNIT_TEST_SUITE(TDocProductTestSuite) {
    // Absolute tolerance used by every floating-point comparison below.
    const double EPSILON = 0.00001;

    // Fills dst[0..length) with Rnd.Uniform(maxNum) draws from a
    // TReallyFastRng32 seeded with `seed`, so every run sees the same data.
    // maxNum is ~Num(0) (all bits set). For unsigned Num that is the type's
    // maximum. For signed Num it is -1, so the effective bound depends on
    // Uniform()'s parameter type. NOTE(review): confirm the intended range
    // for signed element types.
    template void FillWithRandomNumbers(Num * dst, int seed, size_t length) {
        TReallyFastRng32 Rnd(seed);
        Num maxNum = ~Num(0);
        for (size_t i = 0; i < length; ++i) {
            dst[i] = Rnd.Uniform(maxNum);
        }
    }

    // Fills dst[0..length) with Rnd.GenRandReal1() values from a
    // deterministically seeded TReallyFastRng32. The values are presumably
    // uniform reals in [0, 1]; confirm against the RNG's documentation.
    template void FillWithRandomFloats(Num * dst, int seed, size_t length) {
        TReallyFastRng32 Rnd(seed);
        for (size_t i = 0; i < length; ++i) {
            dst[i] = Rnd.GenRandReal1();
        }
    }

    // Plain reference dot product over lhs[0..length) and rhs[0..length).
    // Each operand is converted before the multiply (presumably
    // static_cast<Res>), so products of narrow integers are formed in the
    // accumulator type. Results are summed into a Res that starts at 0.
    // Res occurs only in the return type, so callers must spell it out
    // explicitly.
    template Res SimpleDotProduct(const Int* lhs, const Int* rhs, size_t length) {
        Res sum = 0;
        for (size_t i = 0; i < length; ++i) {
            sum += static_cast(lhs[i]) * static_cast(rhs[i]);
        }
        return sum;
    }

    Y_UNIT_TEST(TestDotProduct8) {
        TVector a(100);
        FillWithRandomNumbers(a.data(), 179, 100);
        TVector b(100);
        FillWithRandomNumbers(b.data(), 239, 100);
        // Sweep start offsets i = 0..29 and every length with
        // length + i + 1 < a.size(), so the kernels see many start positions
        // and tail lengths. This presumably exercises unaligned heads and
        // partial-vector tails in the SIMD paths.
        for (size_t i = 0; i < 30; ++i) {
            for (size_t length = 1; length + i + 1 < a.size(); ++length) {
                // Both implementations must match the reference exactly.
                UNIT_ASSERT_EQUAL(DotProduct(a.data() + i, b.data() + i, length), (SimpleDotProduct(a.data() + i, b.data() + i, length)));
                UNIT_ASSERT_EQUAL(DotProductSimple(a.data() + i, b.data() + i, length), (SimpleDotProduct(a.data() + i, b.data() + i, length)));
            }
        }
    }

    // Same offset/length sweep and exact comparison as TestDotProduct8, for
    // the element type implied by the "8u" suffix (unsigned 8-bit, presumably).
    Y_UNIT_TEST(TestDotProduct8u) {
        TVector a(100);
        FillWithRandomNumbers(a.data(), 179, 100);
        TVector b(100);
        FillWithRandomNumbers(b.data(), 239, 100);
        for (size_t i = 0; i < 30; ++i) {
            for (size_t length = 1; length + i + 1 < a.size(); ++length) {
                UNIT_ASSERT_EQUAL(DotProduct(a.data() + i, b.data() + i, length), (SimpleDotProduct(a.data() + i, b.data() + i, length)));
                UNIT_ASSERT_EQUAL(DotProductSimple(a.data() + i, b.data() + i, length), (SimpleDotProduct(a.data() + i, b.data() + i, length)));
            }
        }
    }

    // Same offset/length sweep and exact comparison as TestDotProduct8, for
    // the element type implied by the "32" suffix (32-bit integers, presumably).
    Y_UNIT_TEST(TestDotProduct32) {
        TVector a(100);
        FillWithRandomNumbers(a.data(), 179, 100);
        TVector b(100);
        FillWithRandomNumbers(b.data(), 239, 100);
        for (size_t i = 0; i < 30; ++i) {
            for (size_t length = 1; length + i + 1 < a.size(); ++length) {
                UNIT_ASSERT_EQUAL(DotProduct(a.data() + i, b.data() + i, length), (SimpleDotProduct(a.data() + i, b.data() + i, length)));
                UNIT_ASSERT_EQUAL(DotProductSimple(a.data() + i, b.data() + i, length), (SimpleDotProduct(a.data() + i, b.data() + i, length)));
            }
        }
    }

    // Floating-point variant of the sweep. Results are compared against the
    // reference within the absolute tolerance EPSILON instead of exactly.
    Y_UNIT_TEST(TestDotProductf) {
        TVector a(100);
        FillWithRandomFloats(a.data(), 179, 100);
        TVector b(100);
        FillWithRandomFloats(b.data(), 239, 100);
        for (size_t i = 0; i < 30; ++i) {
            for (size_t length = 1; length + i + 1 < a.size(); ++length) {
                UNIT_ASSERT(std::fabs(DotProduct(a.data() + i, b.data() + i, length) - (SimpleDotProduct(a.data() + i, b.data() + i, length))) < EPSILON);
                UNIT_ASSERT(std::fabs(DotProductSimple(a.data() + i, b.data() + i, length) - (SimpleDotProduct(a.data() + i, b.data() + i, length))) < EPSILON);
            }
        }
    }

    // Checks L2NormSquared(x) against DotProductSimple(x, x), within EPSILON,
    // for both random inputs over the same offset/length sweep.
    // NOTE(review): "Sqared" is a typo for "Squared". Renaming changes the
    // registered test name.
    Y_UNIT_TEST(TestL2NormSqaredf) {
        TVector a(100);
        FillWithRandomFloats(a.data(), 179, 100);
        TVector b(100);
        FillWithRandomFloats(b.data(), 239, 100);
        for (size_t i = 0; i < 30; ++i) {
            for (size_t length = 1; length + i + 1 < a.size(); ++length) {
                UNIT_ASSERT(std::fabs(L2NormSquared(a.data() + i, length) - DotProductSimple(a.data() + i, a.data() + i, length)) < EPSILON);
                UNIT_ASSERT(std::fabs(L2NormSquared(b.data() + i, length) - DotProductSimple(b.data() + i, b.data() + i, length)) < EPSILON);
            }
        }
    }

    // Same EPSILON-tolerance comparison as TestDotProductf, for the element
    // type implied by the "d" suffix (double, presumably). The vectors are
    // still filled via FillWithRandomFloats.
    Y_UNIT_TEST(TestDotProductd) {
        TVector a(100);
        FillWithRandomFloats(a.data(), 179, 100);
        TVector b(100);
        FillWithRandomFloats(b.data(), 239, 100);
        for (size_t i = 0; i < 30; ++i) {
            for (size_t length = 1; length + i + 1 < a.size(); ++length) {
                UNIT_ASSERT(std::fabs(DotProduct(a.data() + i, b.data() + i, length) - (SimpleDotProduct(a.data() + i, b.data() + i, length))) < EPSILON);
                UNIT_ASSERT(std::fabs(DotProductSimple(a.data() + i, b.data() + i, length) - (SimpleDotProduct(a.data() + i, b.data() + i, length))) < EPSILON);
            }
        }
    }

    // Checks TriWayDotProduct on float inputs against three separate
    // reference dot products. Both sides are compared through the cosine
    // LR / sqrt(LL * RR).
    Y_UNIT_TEST(TestCombinedDotProductf) {
        TVector a(100);
        FillWithRandomFloats(a.data(), 179, 100);
        TVector b(100);
        FillWithRandomFloats(b.data(), 239, 100);
        // Reference triple built from SimpleDotProduct(l, l), (l, r) and
        // (r, r). The brace-init order presumably matches TTriWayDotProduct's
        // member order LL, LR, RR; confirm against its declaration.
        auto simple3WayProduct = [](const float* l, const float* r, size_t length) -> TTriWayDotProduct {
            return { SimpleDotProduct(l, l, length), SimpleDotProduct(l, r, length), SimpleDotProduct(r, r, length) };
        };
        // Cosine similarity of the two vectors from a LL/LR/RR triple.
        auto cosine = [](const auto p) {
            return p.LR / sqrt(p.LL * p.RR);
        };
        for (size_t i = 0; i < 30; ++i) {
            for (size_t length = 1; length + i + 1 < a.size(); ++length) {
                // Attached to each assertion so failures report the case.
                const TString testCaseExpl = TStringBuilder() << "i = " << i << "; length = " << length;
                {
                    // Default mask: the cosines from the optimized and
                    // reference triples must agree within EPSILON.
                    const float c1 = cosine(TriWayDotProduct(a.data() + i, b.data() + i, length));
                    const float c2 = cosine(simple3WayProduct(a.data() + i, b.data() + i, length));
                    UNIT_ASSERT_DOUBLES_EQUAL_C(c1, c2, EPSILON, testCaseExpl);
                }
                {
                    // Left mask: the test expects RR to come back as 1.0
                    // (presumably left uncomputed). RR is then forced to 1 on
                    // both sides, so only LL and LR influence the compared
                    // cosines.
                    auto cpl = TriWayDotProduct(a.data() + i, b.data() + i, length, ETriWayDotProductComputeMask::Left);
                    auto cnl = simple3WayProduct(a.data() + i, b.data() + i, length);
                    UNIT_ASSERT_DOUBLES_EQUAL(cpl.RR, 1.0, EPSILON);
                    cpl.RR = 1;
                    cnl.RR = 1;
                    UNIT_ASSERT_DOUBLES_EQUAL_C(cosine(cpl), cosine(cnl), EPSILON, testCaseExpl);
                }
                {
                    // Right mask: mirror of the Left case. LL is expected to
                    // be 1.0 and is forced to 1 on both sides before comparing.
                    auto cpr = TriWayDotProduct(a.data() + i, b.data() + i, length, ETriWayDotProductComputeMask::Right);
                    auto cnr = simple3WayProduct(a.data() + i, b.data() + i, length);
                    UNIT_ASSERT_DOUBLES_EQUAL(cpr.LL, 1.0, EPSILON);
                    cpr.LL = 1;
                    cnr.LL = 1;
                    UNIT_ASSERT_DOUBLES_EQUAL_C(cosine(cpr), cosine(cnr), EPSILON, testCaseExpl);
                }
            }
        }
    }

    // Zero-length calls with null pointers must return 0, exactly for the
    // UNIT_ASSERT_EQUAL cases and within EPSILON for the std::abs cases. Each
    // line targets a different overload via the (stripped) static_cast type.
    Y_UNIT_TEST(TestDotProductZeroLength) {
        UNIT_ASSERT_EQUAL(DotProduct(static_cast(nullptr), nullptr, 0), 0);
        UNIT_ASSERT_EQUAL(DotProduct(static_cast(nullptr), nullptr, 0), 0);
        UNIT_ASSERT_EQUAL(DotProduct(static_cast(nullptr), nullptr, 0), 0);
        UNIT_ASSERT(std::abs(DotProduct(static_cast(nullptr), nullptr, 0)) < EPSILON);
        UNIT_ASSERT(std::abs(DotProduct(static_cast(nullptr), nullptr, 0)) < EPSILON);
        UNIT_ASSERT_EQUAL(DotProductSimple(static_cast(nullptr), nullptr, 0), 0);
        UNIT_ASSERT_EQUAL(DotProductSimple(static_cast(nullptr), nullptr, 0), 0);
        UNIT_ASSERT_EQUAL(DotProductSimple(static_cast(nullptr), nullptr, 0), 0);
        UNIT_ASSERT(std::abs(DotProductSimple(static_cast(nullptr), nullptr, 0)) < EPSILON);
        UNIT_ASSERT(std::abs(DotProductSimple(static_cast(nullptr), nullptr, 0)) < EPSILON);
    }

    // Repeated DotProduct calls on the same 1003-element float inputs must
    // return exactly the same value. With ARCADIA_SSE defined, the result is
    // also pinned to its ToString() form. The golden string depends on the
    // seeds and the RNG sequence.
    Y_UNIT_TEST(TestDotProductFloatStability) {
        TVector a(1003);
        FillWithRandomFloats(a.data(), 179, a.size());
        TVector b(1003);
        FillWithRandomFloats(b.data(), 239, b.size());
        float res = DotProduct(a.data(), b.data(), a.size());
        for (size_t i = 0; i < 30; ++i)
            UNIT_ASSERT_VALUES_EQUAL(DotProduct(a.data(), b.data(), a.size()), res);
#ifdef ARCADIA_SSE
        UNIT_ASSERT_VALUES_EQUAL(ToString(res), "250.502");
#endif
    }

    // Double-precision counterpart of TestDotProductFloatStability, with
    // different seeds and golden string.
    Y_UNIT_TEST(TestDotProductDoubleStability) {
        TVector a(1003);
        FillWithRandomFloats(a.data(), 13133, a.size());
        TVector b(1003);
        FillWithRandomFloats(b.data(), 1121, b.size());
        double res = DotProduct(a.data(), b.data(), a.size());
        for (size_t i = 0; i < 30; ++i)
            UNIT_ASSERT_VALUES_EQUAL(DotProduct(a.data(), b.data(), a.size()), res);
#ifdef ARCADIA_SSE
        UNIT_ASSERT_VALUES_EQUAL(ToString(res), "235.7826026");
#endif
    }

    // DotProduct and DotProductSimple must both reproduce the first result
    // on repeated calls, and that result is pinned to 90928. The golden
    // value depends on the seeds, the RNG sequence and the (stripped)
    // element type.
    Y_UNIT_TEST(TestDotProductCharStability) {
        TVector a(1003);
        FillWithRandomNumbers(a.data(), 1079, a.size());
        TVector b(1003);
        FillWithRandomNumbers(b.data(), 2139, b.size());
        ui32 res = DotProduct(a.data(), b.data(), a.size());
        for (size_t i = 0; i < 30; ++i) {
            UNIT_ASSERT_VALUES_EQUAL(DotProduct(a.data(), b.data(), a.size()), res);
            UNIT_ASSERT_VALUES_EQUAL(DotProductSimple(a.data(), b.data(), a.size()), res);
        }
        UNIT_ASSERT_VALUES_EQUAL(res, 90928);
    }

    // Same as TestDotProductCharStability, presumably for the unsigned
    // element type; the result is pinned to 16420179.
    Y_UNIT_TEST(TestDotProductCharStabilityU) {
        TVector a(1003);
        FillWithRandomNumbers(a.data(), 1079, a.size());
        TVector b(1003);
        FillWithRandomNumbers(b.data(), 2139, b.size());
        ui32 res = DotProduct(a.data(), b.data(), a.size());
        for (size_t i = 0; i < 30; ++i) {
            UNIT_ASSERT_VALUES_EQUAL(DotProduct(a.data(), b.data(), a.size()), res);
            UNIT_ASSERT_VALUES_EQUAL(DotProductSimple(a.data(), b.data(), a.size()), res);
        }
        UNIT_ASSERT_VALUES_EQUAL(res, 16420179);
    }

    // Hand-computed case for DotProductUI4Simple. Each ui8 is built as
    // low + (high << 4), i.e. two 4-bit values per byte. The expected sum is
    // written per byte as (low*low, high*high):
    //   (1*2, 3*4), (15*1, 8*8), (0*7, 5*0), (3*1, 1*4)
    Y_UNIT_TEST(TestDotProductUI4Manual) {
        static ui8 a[4] = {1 + (3 << 4), 15 + (8 << 4), 0 + (5 << 4), 3 + (1 << 4)};
        static ui8 b[4] = {2 + (4 << 4), 1 + (8 << 4), 7 + (0 << 4), 1 + (4 << 4)};
        UNIT_ASSERT_VALUES_EQUAL(DotProductUI4Simple(a, b, 4), 2 + 12 + 15 + 64 + 0 + 0 + 3 + 4);
    }
}