diff options
author | arcadia-devtools <arcadia-devtools@yandex-team.ru> | 2022-06-07 05:49:38 +0300 |
---|---|---|
committer | arcadia-devtools <arcadia-devtools@yandex-team.ru> | 2022-06-07 05:49:38 +0300 |
commit | dd24834944d3b49f6b7be1199a349c7d9ea3b7b2 (patch) | |
tree | dc70048d5f9846d23e3f92e9d1bbe806553b716a /util/random | |
parent | 811d88e6a929f3a50645eeb76bdea2c47dbd034e (diff) | |
download | ydb-dd24834944d3b49f6b7be1199a349c7d9ea3b7b2.tar.gz |
Pull request "ShufflePart function" by @igormarkov00 from https://github.com/catboost/catboost/pull/2087
MERGED FROM https://github.com/catboost/catboost/pull/2087
ref:d27dfbe948e17ef1feb8ad2b13b409915afc86c8
Diffstat (limited to 'util/random')
-rw-r--r-- | util/random/shuffle.h | 34 | ||||
-rw-r--r-- | util/random/shuffle_ut.cpp | 39 |
2 files changed, 73 insertions, 0 deletions
diff --git a/util/random/shuffle.h b/util/random/shuffle.h index 274ac147c9..218695696b 100644 --- a/util/random/shuffle.h +++ b/util/random/shuffle.h @@ -4,11 +4,13 @@ #include "entropy.h" #include <util/generic/utility.h> +#include <util/system/yassert.h> // some kind of https://en.wikipedia.org/wiki/Fisher–Yates_shuffle#The_modern_algorithm template <typename TRandIter, typename TRandIterEnd> inline void Shuffle(TRandIter begin, TRandIterEnd end) { + Y_ASSERT(begin <= end); static_assert(sizeof(end - begin) <= sizeof(size_t), "fixme"); static_assert(sizeof(TReallyFastRng32::RandMax()) <= sizeof(size_t), "fixme"); @@ -21,6 +23,7 @@ inline void Shuffle(TRandIter begin, TRandIterEnd end) { template <typename TRandIter, typename TRandIterEnd, typename TRandGen> inline void Shuffle(TRandIter begin, TRandIterEnd end, TRandGen&& gen) { + Y_ASSERT(begin <= end); const size_t sz = end - begin; for (size_t i = 1; i < sz; ++i) { @@ -28,6 +31,37 @@ inline void Shuffle(TRandIter begin, TRandIterEnd end, TRandGen&& gen) { } } +// Fills first size elements of array with equiprobably randomly +// chosen elements of array with no replacement +template <typename TRandIter, typename TRandIterEnd> +inline void PartialShuffle(TRandIter begin, TRandIterEnd end, size_t size) { + Y_ASSERT(begin <= end); + static_assert(sizeof(end - begin) <= sizeof(size_t), "fixme"); + static_assert(sizeof(TReallyFastRng32::RandMax()) <= sizeof(size_t), "fixme"); + + if ((size_t)(end - begin) < (size_t)TReallyFastRng32::RandMax()) { + PartialShuffle(begin, end, size, TReallyFastRng32(Seed())); + } else { + PartialShuffle(begin, end, size, TFastRng64(Seed())); + } +} + +template <typename TRandIter, typename TRandIterEnd, typename TRandGen> +inline void PartialShuffle(TRandIter begin, TRandIterEnd end, size_t size, TRandGen&& gen) { + Y_ASSERT(begin <= end); + + const size_t totalSize = end - begin; + Y_ASSERT(size <= totalSize); // Size of shuffled part should be less than or equal to the size of container + if (totalSize == 0) { + return; + } + size = Min(size, totalSize - 1); + + for (size_t i = 0; i < size; ++i) { + DoSwap(*(begin + i), *(begin + gen.Uniform(i, totalSize))); + } +} + template <typename TRange> inline void ShuffleRange(TRange& range) { auto b = range.begin(); diff --git a/util/random/shuffle_ut.cpp b/util/random/shuffle_ut.cpp index 87cbae94c0..8cab95d8b2 100644 --- a/util/random/shuffle_ut.cpp +++ b/util/random/shuffle_ut.cpp @@ -45,31 +45,70 @@ Y_UNIT_TEST_SUITE(TRandUtilsTest) { UNIT_ASSERT(s0 != s1); // if shuffle does work, chances it will fail are 1 to 64!. } + template <typename... A> + static void TestIterPartial(A&&... args) { + TString s0, s1; + + auto f = [&](int shuffledSize) { + auto b = s1.begin(); + auto e = s1.end(); + + PartialShuffle(b, e, shuffledSize, args...); + }; + + s1 = ""; + f(0); + + s1 = "01"; + f(1); + + s1 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ,.;?!"; + s0 = s1.copy(); + f(2); + size_t matchesCount = 0; + for (size_t i = 0; i < s0.size(); ++i) { + matchesCount += (s0[i] == s1[i]); + } + + UNIT_ASSERT(matchesCount >= s1.size() - 4); // partial shuffle doesn't make non-necessary swaps. + + s1 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ,.;?!"; + s0 = s1.copy(); + f(64); + + UNIT_ASSERT(s0 != s1); // if partial shuffle does work, chances it will fail are 1 to 64!. + } + Y_UNIT_TEST(TestShuffle) { TestRange(); + TestIterPartial(); } Y_UNIT_TEST(TestShuffleMersenne64) { TMersenne<ui64> prng(42); TestRange(prng); + TestIterPartial(prng); } Y_UNIT_TEST(TestShuffleMersenne32) { TMersenne<ui32> prng(24); TestIter(prng); + TestIterPartial(prng); } Y_UNIT_TEST(TestShuffleFast32) { TFastRng32 prng(24, 0); TestIter(prng); + TestIterPartial(prng); } Y_UNIT_TEST(TestShuffleFast64) { TFastRng64 prng(24, 0, 25, 1); TestIter(prng); + TestIterPartial(prng); } } |