1717#include < random>
1818#include < set>
1919
20+ namespace albatross {
2021/*
2122 * Samples integers between low and high (inclusive) with replacement.
2223 */
2324inline std::vector<std::size_t >
2425randint_without_replacement (std::size_t n, std::size_t low, std::size_t high,
2526 std::default_random_engine &gen) {
2627 assert (n >= 0 );
27- assert (n <= (high - low));
28+
29+ std::size_t n_choices = high - low + 1 ;
30+ assert (n <= n_choices);
31+
32+ if (n == (high - low + 1 )) {
33+ std::vector<std::size_t > all_inds (n);
34+ std::iota (all_inds.begin (), all_inds.end (), 0 );
35+ return all_inds;
36+ }
37+
38+ if (n > n_choices / 2 + 1 ) {
39+ // Since we're trying to randomly sample more than half of the
40+ // points it'll be faster to randomly sample which points we
41+ // should throw out than which ones we should keep.
42+ const auto to_throw_out =
43+ randint_without_replacement (n_choices - n, low, high, gen);
44+ auto to_keep = indices_complement (to_throw_out, high - low);
45+
46+ if (low != 0 ) {
47+ for (auto &el : to_keep) {
48+ el += low;
49+ }
50+ }
51+ return to_keep;
52+ }
2853
2954 std::uniform_int_distribution<std::size_t > dist (low, high);
3055 std::set<int > samples;
@@ -33,5 +58,6 @@ randint_without_replacement(std::size_t n, std::size_t low, std::size_t high,
3358 }
3459 return std::vector<std::size_t >(samples.begin (), samples.end ());
3560}
61+ }
3662
3763#endif
0 commit comments