Skip to content

Commit 5f5b580

Browse files
authored
Allow fn choose_multiple_weighted to return fewer than amount elts (#1623)
1 parent 7808f4e commit 5f5b580

File tree

4 files changed

+42
-45
lines changed

4 files changed

+42
-45
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.
1212
- Fix feature `simd_support` for recent nightly rust (#1586)
1313
- Add `Alphabetic` distribution. (#1587)
1414
- Re-export `rand_core` (#1602)
15+
- Allow `fn rand::seq::index::sample_weighted` and `fn IndexedRandom::choose_multiple_weighted` to return fewer than `amount` results (#1623), reverting an undocumented change (#1382) to the previous release.
1516

1617
## [0.9.0] - 2025-01-27
1718
### Security and unsafe

src/seq/index.rs

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,11 @@ where
282282
}
283283
}
284284

285-
/// Randomly sample exactly `amount` distinct indices from `0..length`
285+
/// Randomly sample `amount` distinct indices from `0..length`
286286
///
287-
/// Results are in arbitrary order (there is no guarantee of shuffling or
288-
/// ordering).
287+
/// The result may contain less than `amount` indices if insufficient non-zero
288+
/// weights are available. Results are returned in an arbitrary order (there is
289+
/// no guarantee of shuffling or ordering).
289290
///
290291
/// Function `weight` is called once for each index to provide weights.
291292
///
@@ -295,7 +296,6 @@ where
295296
///
296297
/// Error cases:
297298
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative.
298-
/// - [`WeightError::InsufficientNonZero`] when fewer than `amount` weights are positive.
299299
///
300300
/// This implementation uses `O(length + amount)` space and `O(length)` time.
301301
#[cfg(feature = "std")]
@@ -328,18 +328,20 @@ where
328328
}
329329
}
330330

331-
/// Randomly sample exactly `amount` distinct indices from `0..length`, and
332-
/// return them in an arbitrary order (there is no guarantee of shuffling or
333-
/// ordering). The weights are to be provided by the input function `weights`,
334-
/// which will be called once for each index.
331+
/// Randomly sample `amount` distinct indices from `0..length`
332+
///
333+
/// The result may contain less than `amount` indices if insufficient non-zero
334+
/// weights are available. Results are returned in an arbitrary order (there is
335+
/// no guarantee of shuffling or ordering).
336+
///
337+
/// Function `weight` is called once for each index to provide weights.
335338
///
336339
/// This implementation is based on the algorithm A-ExpJ as found in
337340
/// [Efraimidis and Spirakis, 2005](https://doi.org/10.1016/j.ipl.2005.11.003).
338341
/// It uses `O(length + amount)` space and `O(length)` time.
339342
///
340343
/// Error cases:
341344
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative.
342-
/// - [`WeightError::InsufficientNonZero`] when fewer than `amount` weights are positive.
343345
#[cfg(feature = "std")]
344346
fn sample_efraimidis_spirakis<R, F, X, N>(
345347
rng: &mut R,
@@ -403,28 +405,26 @@ where
403405
index += N::one();
404406
}
405407

406-
if candidates.len() < amount.as_usize() {
407-
return Err(WeightError::InsufficientNonZero);
408-
}
408+
if index < length {
409+
let mut x = rng.random::<f64>().ln() / candidates.peek().unwrap().key;
410+
while index < length {
411+
let weight = weight(index.as_usize()).into();
412+
if weight > 0.0 {
413+
x -= weight;
414+
if x <= 0.0 {
415+
let min_candidate = candidates.pop().unwrap();
416+
let t = (min_candidate.key * weight).exp();
417+
let key = rng.random_range(t..1.0).ln() / weight;
418+
candidates.push(Element { index, key });
409419

410-
let mut x = rng.random::<f64>().ln() / candidates.peek().unwrap().key;
411-
while index < length {
412-
let weight = weight(index.as_usize()).into();
413-
if weight > 0.0 {
414-
x -= weight;
415-
if x <= 0.0 {
416-
let min_candidate = candidates.pop().unwrap();
417-
let t = (min_candidate.key * weight).exp();
418-
let key = rng.random_range(t..1.0).ln() / weight;
419-
candidates.push(Element { index, key });
420-
421-
x = rng.random::<f64>().ln() / candidates.peek().unwrap().key;
420+
x = rng.random::<f64>().ln() / candidates.peek().unwrap().key;
421+
}
422+
} else if !(weight >= 0.0) {
423+
return Err(WeightError::InvalidWeight);
422424
}
423-
} else if !(weight >= 0.0) {
424-
return Err(WeightError::InvalidWeight);
425-
}
426425

427-
index += N::one();
426+
index += N::one();
427+
}
428428
}
429429

430430
Ok(IndexVec::from(
@@ -653,7 +653,7 @@ mod test {
653653
}
654654

655655
let r = sample_weighted(&mut seed_rng(423), 10, |i| i as f64, 10);
656-
assert_eq!(r.unwrap_err(), WeightError::InsufficientNonZero);
656+
assert_eq!(r.unwrap().len(), 9);
657657
}
658658

659659
#[test]

src/seq/iterator.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ pub trait IteratorRandom: Iterator + Sized {
134134
/// force every element to be created regardless call `.inspect(|e| ())`.
135135
///
136136
/// [`choose`]: IteratorRandom::choose
137+
//
138+
// Clippy is wrong here: we need to iterate over all entries with the RNG to
139+
// ensure that choosing is *stable*.
140+
#[allow(clippy::double_ended_iterator_last)]
137141
fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
138142
where
139143
R: Rng + ?Sized,

src/seq/slice.rs

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -173,26 +173,18 @@ pub trait IndexedRandom: Index<usize> {
173173

174174
/// Biased sampling of `amount` distinct elements
175175
///
176-
/// Similar to [`choose_multiple`], but where the likelihood of each element's
177-
/// inclusion in the output may be specified. The elements are returned in an
178-
/// arbitrary, unspecified order.
176+
/// Similar to [`choose_multiple`], but where the likelihood of each
177+
/// element's inclusion in the output may be specified. Zero-weighted
178+
/// elements are never returned; the result may therefore contain fewer
179+
/// elements than `amount` even when `self.len() >= amount`. The elements
180+
/// are returned in an arbitrary, unspecified order.
179181
///
180182
/// The specified function `weight` maps each item `x` to a relative
181183
/// likelihood `weight(x)`. The probability of each item being selected is
182184
/// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
183185
///
184-
/// If all of the weights are equal, even if they are all zero, each element has
185-
/// an equal likelihood of being selected.
186-
///
187-
/// This implementation uses `O(length + amount)` space and `O(length)` time
188-
/// if the "nightly" feature is enabled, or `O(length)` space and
189-
/// `O(length + amount * log length)` time otherwise.
190-
///
191-
/// # Known issues
192-
///
193-
/// The algorithm currently used to implement this method loses accuracy
194-
/// when small values are used for weights.
195-
/// See [#1476](https://github.com/rust-random/rand/issues/1476).
186+
/// This implementation uses `O(length + amount)` space and `O(length)` time.
187+
/// See [`index::sample_weighted`] for details.
196188
///
197189
/// # Example
198190
///
@@ -687,7 +679,7 @@ mod test {
687679
// Case 2: All of the weights are 0
688680
let choices = [('a', 0), ('b', 0), ('c', 0)];
689681
let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
690-
assert_eq!(r.unwrap_err(), WeightError::InsufficientNonZero);
682+
assert_eq!(r.unwrap().len(), 0);
691683

692684
// Case 3: Negative weights
693685
let choices = [('a', -1), ('b', 1), ('c', 1)];

0 commit comments

Comments
 (0)