Skip to content

Commit 41d2b2a

Browse files
authored
Merge pull request #56 from swift-nav/random_without_replacement_bug
Bug: randint_without_replacement
2 parents 4fd5a10 + 1e590a7 commit 41d2b2a

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

albatross/random_utils.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,39 @@
1717
#include <random>
1818
#include <set>
1919

20+
namespace albatross {
2021
/*
2122
* Samples integers between low and high (inclusive) with replacement.
2223
*/
2324
inline std::vector<std::size_t>
2425
randint_without_replacement(std::size_t n, std::size_t low, std::size_t high,
2526
std::default_random_engine &gen) {
2627
assert(n >= 0);
27-
assert(n <= (high - low));
28+
29+
std::size_t n_choices = high - low + 1;
30+
assert(n <= n_choices);
31+
32+
if (n == (high - low + 1)) {
33+
std::vector<std::size_t> all_inds(n);
34+
std::iota(all_inds.begin(), all_inds.end(), 0);
35+
return all_inds;
36+
}
37+
38+
if (n > n_choices / 2 + 1) {
39+
// Since we're trying to randomly sample more than half of the
40+
// points it'll be faster to randomly sample which points we
41+
// should throw out than which ones we should keep.
42+
const auto to_throw_out =
43+
randint_without_replacement(n_choices - n, low, high, gen);
44+
auto to_keep = indices_complement(to_throw_out, high - low);
45+
46+
if (low != 0) {
47+
for (auto &el : to_keep) {
48+
el += low;
49+
}
50+
}
51+
return to_keep;
52+
}
2853

2954
std::uniform_int_distribution<std::size_t> dist(low, high);
3055
std::set<int> samples;
@@ -33,5 +58,6 @@ randint_without_replacement(std::size_t n, std::size_t low, std::size_t high,
3358
}
3459
return std::vector<std::size_t>(samples.begin(), samples.end());
3560
}
61+
}
3662

3763
#endif

tests/test_random_utils.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,9 @@ TEST(test_random_utils, randint_without_replacement) {
3333
}
3434
}
3535

36+
TEST(test_random_utils, randint_without_replacement_full_set) {
37+
std::default_random_engine gen;
38+
const auto inds = randint_without_replacement(10, 0, 9, gen);
39+
}
40+
3641
} // namespace albatross

0 commit comments

Comments
 (0)