diff options
author | arcadia-devtools <arcadia-devtools@yandex-team.ru> | 2022-06-07 05:49:38 +0300 |
---|---|---|
committer | arcadia-devtools <arcadia-devtools@yandex-team.ru> | 2022-06-07 05:49:38 +0300 |
commit | dd24834944d3b49f6b7be1199a349c7d9ea3b7b2 (patch) | |
tree | dc70048d5f9846d23e3f92e9d1bbe806553b716a /util/random/shuffle.h | |
parent | 811d88e6a929f3a50645eeb76bdea2c47dbd034e (diff) | |
download | ydb-dd24834944d3b49f6b7be1199a349c7d9ea3b7b2.tar.gz |
Pull request "ShufflePart function" by @igormarkov00 from https://github.com/catboost/catboost/pull/2087
MERGED FROM https://github.com/catboost/catboost/pull/2087
ref:d27dfbe948e17ef1feb8ad2b13b409915afc86c8
Diffstat (limited to 'util/random/shuffle.h')
-rw-r--r-- | util/random/shuffle.h | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/util/random/shuffle.h b/util/random/shuffle.h index 274ac147c9..218695696b 100644 --- a/util/random/shuffle.h +++ b/util/random/shuffle.h @@ -4,11 +4,13 @@ #include "entropy.h" #include <util/generic/utility.h> +#include <util/system/yassert.h> // some kind of https://en.wikipedia.org/wiki/Fisher–Yates_shuffle#The_modern_algorithm template <typename TRandIter, typename TRandIterEnd> inline void Shuffle(TRandIter begin, TRandIterEnd end) { + Y_ASSERT(begin <= end); static_assert(sizeof(end - begin) <= sizeof(size_t), "fixme"); static_assert(sizeof(TReallyFastRng32::RandMax()) <= sizeof(size_t), "fixme"); @@ -21,6 +23,7 @@ inline void Shuffle(TRandIter begin, TRandIterEnd end) { template <typename TRandIter, typename TRandIterEnd, typename TRandGen> inline void Shuffle(TRandIter begin, TRandIterEnd end, TRandGen&& gen) { + Y_ASSERT(begin <= end); const size_t sz = end - begin; for (size_t i = 1; i < sz; ++i) { @@ -28,6 +31,37 @@ inline void Shuffle(TRandIter begin, TRandIterEnd end, TRandGen&& gen) { } } +// Fills the first `size` elements of the range with elements chosen uniformly at random +// from the whole range, without replacement +template <typename TRandIter, typename TRandIterEnd> +inline void PartialShuffle(TRandIter begin, TRandIterEnd end, size_t size) { + Y_ASSERT(begin <= end); + static_assert(sizeof(end - begin) <= sizeof(size_t), "fixme"); + static_assert(sizeof(TReallyFastRng32::RandMax()) <= sizeof(size_t), "fixme"); + + if ((size_t)(end - begin) < (size_t)TReallyFastRng32::RandMax()) { + PartialShuffle(begin, end, size, TReallyFastRng32(Seed())); + } else { + PartialShuffle(begin, end, size, TFastRng64(Seed())); + } +} + +template <typename TRandIter, typename TRandIterEnd, typename TRandGen> +inline void PartialShuffle(TRandIter begin, TRandIterEnd end, size_t size, TRandGen&& gen) { + Y_ASSERT(begin <= end); + + const size_t totalSize = end - begin; + Y_ASSERT(size <= totalSize); // Size of shuffled part should be less than or equal to the size of container + if (totalSize == 0) { + return; + } + size = Min(size, totalSize - 1); + + for (size_t i = 0; i < size; ++i) { + DoSwap(*(begin + i), *(begin + gen.Uniform(i, totalSize))); + } +} + template <typename TRange> inline void ShuffleRange(TRange& range) { auto b = range.begin(); |