aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
authorarcadia-devtools <arcadia-devtools@yandex-team.ru>2022-06-07 05:49:38 +0300
committerarcadia-devtools <arcadia-devtools@yandex-team.ru>2022-06-07 05:49:38 +0300
commitdd24834944d3b49f6b7be1199a349c7d9ea3b7b2 (patch)
treedc70048d5f9846d23e3f92e9d1bbe806553b716a /util
parent811d88e6a929f3a50645eeb76bdea2c47dbd034e (diff)
downloadydb-dd24834944d3b49f6b7be1199a349c7d9ea3b7b2.tar.gz
Pull request "ShufflePart function" by @igormarkov00 from https://github.com/catboost/catboost/pull/2087
MERGED FROM https://github.com/catboost/catboost/pull/2087 ref:d27dfbe948e17ef1feb8ad2b13b409915afc86c8
Diffstat (limited to 'util')
-rw-r--r--util/random/shuffle.h34
-rw-r--r--util/random/shuffle_ut.cpp39
2 files changed, 73 insertions, 0 deletions
diff --git a/util/random/shuffle.h b/util/random/shuffle.h
index 274ac147c9..218695696b 100644
--- a/util/random/shuffle.h
+++ b/util/random/shuffle.h
@@ -4,11 +4,13 @@
#include "entropy.h"
#include <util/generic/utility.h>
+#include <util/system/yassert.h>
// some kind of https://en.wikipedia.org/wiki/Fisher–Yates_shuffle#The_modern_algorithm
template <typename TRandIter, typename TRandIterEnd>
inline void Shuffle(TRandIter begin, TRandIterEnd end) {
+ Y_ASSERT(begin <= end);
static_assert(sizeof(end - begin) <= sizeof(size_t), "fixme");
static_assert(sizeof(TReallyFastRng32::RandMax()) <= sizeof(size_t), "fixme");
@@ -21,6 +23,7 @@ inline void Shuffle(TRandIter begin, TRandIterEnd end) {
template <typename TRandIter, typename TRandIterEnd, typename TRandGen>
inline void Shuffle(TRandIter begin, TRandIterEnd end, TRandGen&& gen) {
+ Y_ASSERT(begin <= end);
const size_t sz = end - begin;
for (size_t i = 1; i < sz; ++i) {
@@ -28,6 +31,37 @@ inline void Shuffle(TRandIter begin, TRandIterEnd end, TRandGen&& gen) {
}
}
+// Fills first size elements of array with equiprobably randomly
+// chosen elements of array with no replacement
+template <typename TRandIter, typename TRandIterEnd>
+inline void PartialShuffle(TRandIter begin, TRandIterEnd end, size_t size) {
+ Y_ASSERT(begin <= end);
+ static_assert(sizeof(end - begin) <= sizeof(size_t), "fixme");
+ static_assert(sizeof(TReallyFastRng32::RandMax()) <= sizeof(size_t), "fixme");
+
+ if ((size_t)(end - begin) < (size_t)TReallyFastRng32::RandMax()) {
+ PartialShuffle(begin, end, size, TReallyFastRng32(Seed()));
+ } else {
+ PartialShuffle(begin, end, size, TFastRng64(Seed()));
+ }
+}
+
+template <typename TRandIter, typename TRandIterEnd, typename TRandGen>
+inline void PartialShuffle(TRandIter begin, TRandIterEnd end, size_t size, TRandGen&& gen) {
+ Y_ASSERT(begin <= end);
+
+ const size_t totalSize = end - begin;
+ Y_ASSERT(size <= totalSize); // Size of shuffled part should be less than or equal to the size of container
+ if (totalSize == 0) {
+ return;
+ }
+ size = Min(size, totalSize - 1);
+
+ for (size_t i = 0; i < size; ++i) {
+ DoSwap(*(begin + i), *(begin + gen.Uniform(i, totalSize)));
+ }
+}
+
template <typename TRange>
inline void ShuffleRange(TRange& range) {
auto b = range.begin();
diff --git a/util/random/shuffle_ut.cpp b/util/random/shuffle_ut.cpp
index 87cbae94c0..8cab95d8b2 100644
--- a/util/random/shuffle_ut.cpp
+++ b/util/random/shuffle_ut.cpp
@@ -45,31 +45,70 @@ Y_UNIT_TEST_SUITE(TRandUtilsTest) {
UNIT_ASSERT(s0 != s1); // if shuffle does work, chances it will fail are 1 to 64!.
}
+ template <typename... A>
+ static void TestIterPartial(A&&... args) {
+ TString s0, s1;
+
+ auto f = [&](int shuffledSize) {
+ auto b = s1.begin();
+ auto e = s1.end();
+
+ PartialShuffle(b, e, shuffledSize, args...);
+ };
+
+ s1 = "";
+ f(0);
+
+ s1 = "01";
+ f(1);
+
+ s1 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ,.;?!";
+ s0 = s1.copy();
+ f(2);
+ size_t matchesCount = 0;
+ for (size_t i = 0; i < s0.size(); ++i) {
+ matchesCount += (s0[i] == s1[i]);
+ }
+
+ UNIT_ASSERT(matchesCount >= s1.size() - 4); // partial shuffle doesn't make non-necessary swaps.
+
+ s1 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ,.;?!";
+ s0 = s1.copy();
+ f(64);
+
+ UNIT_ASSERT(s0 != s1); // if partial shuffle does work, chances it will fail are 1 to 64!.
+ }
+
Y_UNIT_TEST(TestShuffle) {
TestRange();
+ TestIterPartial();
}
Y_UNIT_TEST(TestShuffleMersenne64) {
TMersenne<ui64> prng(42);
TestRange(prng);
+ TestIterPartial(prng);
}
Y_UNIT_TEST(TestShuffleMersenne32) {
TMersenne<ui32> prng(24);
TestIter(prng);
+ TestIterPartial(prng);
}
Y_UNIT_TEST(TestShuffleFast32) {
TFastRng32 prng(24, 0);
TestIter(prng);
+ TestIterPartial(prng);
}
Y_UNIT_TEST(TestShuffleFast64) {
TFastRng64 prng(24, 0, 25, 1);
TestIter(prng);
+ TestIterPartial(prng);
}
}