aboutsummaryrefslogtreecommitdiffstats
path: root/util/random/shuffle.h
diff options
context:
space:
mode:
authorarcadia-devtools <arcadia-devtools@yandex-team.ru>2022-06-07 05:49:38 +0300
committerarcadia-devtools <arcadia-devtools@yandex-team.ru>2022-06-07 05:49:38 +0300
commitdd24834944d3b49f6b7be1199a349c7d9ea3b7b2 (patch)
treedc70048d5f9846d23e3f92e9d1bbe806553b716a /util/random/shuffle.h
parent811d88e6a929f3a50645eeb76bdea2c47dbd034e (diff)
downloadydb-dd24834944d3b49f6b7be1199a349c7d9ea3b7b2.tar.gz
Pull request "ShufflePart function" by @igormarkov00 from https://github.com/catboost/catboost/pull/2087
MERGED FROM https://github.com/catboost/catboost/pull/2087 ref:d27dfbe948e17ef1feb8ad2b13b409915afc86c8
Diffstat (limited to 'util/random/shuffle.h')
-rw-r--r--util/random/shuffle.h34
1 files changed, 34 insertions, 0 deletions
diff --git a/util/random/shuffle.h b/util/random/shuffle.h
index 274ac147c9..218695696b 100644
--- a/util/random/shuffle.h
+++ b/util/random/shuffle.h
@@ -4,11 +4,13 @@
#include "entropy.h"
#include <util/generic/utility.h>
+#include <util/system/yassert.h>
// some kind of https://en.wikipedia.org/wiki/Fisher–Yates_shuffle#The_modern_algorithm
template <typename TRandIter, typename TRandIterEnd>
inline void Shuffle(TRandIter begin, TRandIterEnd end) {
+ Y_ASSERT(begin <= end);
static_assert(sizeof(end - begin) <= sizeof(size_t), "fixme");
static_assert(sizeof(TReallyFastRng32::RandMax()) <= sizeof(size_t), "fixme");
@@ -21,6 +23,7 @@ inline void Shuffle(TRandIter begin, TRandIterEnd end) {
template <typename TRandIter, typename TRandIterEnd, typename TRandGen>
inline void Shuffle(TRandIter begin, TRandIterEnd end, TRandGen&& gen) {
+ Y_ASSERT(begin <= end);
const size_t sz = end - begin;
for (size_t i = 1; i < sz; ++i) {
@@ -28,6 +31,37 @@ inline void Shuffle(TRandIter begin, TRandIterEnd end, TRandGen&& gen) {
}
}
+// Fills first size elements of array with equiprobably randomly
+// chosen elements of array with no replacement
+template <typename TRandIter, typename TRandIterEnd>
+inline void PartialShuffle(TRandIter begin, TRandIterEnd end, size_t size) {
+ Y_ASSERT(begin <= end);
+ static_assert(sizeof(end - begin) <= sizeof(size_t), "fixme");
+ static_assert(sizeof(TReallyFastRng32::RandMax()) <= sizeof(size_t), "fixme");
+
+ if ((size_t)(end - begin) < (size_t)TReallyFastRng32::RandMax()) {
+ PartialShuffle(begin, end, size, TReallyFastRng32(Seed()));
+ } else {
+ PartialShuffle(begin, end, size, TFastRng64(Seed()));
+ }
+}
+
+template <typename TRandIter, typename TRandIterEnd, typename TRandGen>
+inline void PartialShuffle(TRandIter begin, TRandIterEnd end, size_t size, TRandGen&& gen) {
+ Y_ASSERT(begin <= end);
+
+ const size_t totalSize = end - begin;
+ Y_ASSERT(size <= totalSize); // Size of shuffled part should be less than or equal to the size of container
+ if (totalSize == 0) {
+ return;
+ }
+ size = Min(size, totalSize - 1);
+
+ for (size_t i = 0; i < size; ++i) {
+ DoSwap(*(begin + i), *(begin + gen.Uniform(i, totalSize)));
+ }
+}
+
template <typename TRange>
inline void ShuffleRange(TRange& range) {
auto b = range.begin();