Skip to content

Commit

Permalink
Re-org with distr::slice, distr::weighted modules (#1548)
Browse files Browse the repository at this point in the history
- Move `Slice` -> `slice::Choose`, `EmptySlice` -> `slice::Empty`
- Rename trait `DistString` -> `SampleString`
- Rename `DistIter` -> `Iter`, `DistMap` -> `Map`
- Move `{Weight, WeightError, WeightedIndex}` -> `weighted::{Weight,
Error, WeightedIndex}`
- Move `weighted_alias::{AliasableWeight, WeightedAliasIndex}` ->
`weighted::{..}`
- Move `weighted_tree::WeightedTreeIndex` ->
`weighted::WeightedTreeIndex`
  • Loading branch information
dhardy authored Jan 14, 2025
1 parent 16eb7de commit b4b1eb7
Show file tree
Hide file tree
Showing 23 changed files with 354 additions and 334 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/benches.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ defaults:

jobs:
clippy-fmt:
name: Check Clippy and rustfmt
name: "Benches: Check Clippy and rustfmt"
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand All @@ -33,7 +33,7 @@ jobs:
- name: Clippy
run: cargo clippy --all-targets -- -D warnings
benches:
name: Test benchmarks
name: "Benches: Test"
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/distr_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ defaults:

jobs:
clippy-fmt:
name: Check Clippy and rustfmt
name: "distr_test: Check Clippy and rustfmt"
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand All @@ -33,7 +33,7 @@ jobs:
- name: Clippy
run: cargo clippy --all-targets -- -D warnings
ks-tests:
name: Run Komogorov Smirnov tests
name: "distr_test: Run Komogorov Smirnov tests"
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
toolchain: stable
components: clippy, rustfmt
- name: Check Clippy
run: cargo clippy --all --all-targets -- -D warnings
run: cargo clippy --workspace -- -D warnings
- name: Check rustfmt
run: cargo fmt --all -- --check

Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.

## [0.9.0-beta.3] - 2025-01-03
- Add feature `thread_rng` (#1547)
- Move `distr::Slice` -> `distr::slice::Choose`, `distr::EmptySlice` -> `distr::slice::Empty` (#1548)
- Rename trait `distr::DistString` -> `distr::SampleString` (#1548)
- Rename `distr::DistIter` -> `distr::Iter`, `distr::DistMap` -> `distr::Map` (#1548)
- Move `distr::{Weight, WeightError, WeightedIndex}` -> `distr::weighted::{Weight, Error, WeightedIndex}` (#1548)

## [0.9.0-beta.1] - 2024-11-30
- Bump `rand_core` version
Expand Down
1 change: 1 addition & 0 deletions benches/benches/distr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use criterion::{criterion_group, criterion_main, Criterion, Throughput};
use criterion_cycles_per_byte::CyclesPerByte;

use rand::prelude::*;
use rand_distr::weighted::*;
use rand_distr::*;

// At this time, distributions are optimised for 64-bit platforms.
Expand Down
2 changes: 1 addition & 1 deletion benches/benches/weighted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
// except according to those terms.

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rand::distr::WeightedIndex;
use rand::distr::weighted::WeightedIndex;
use rand::prelude::*;
use rand::seq::index::sample_weighted;

Expand Down
4 changes: 2 additions & 2 deletions distr_test/tests/weighted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

mod ks;
use ks::test_discrete;
use rand::distr::{Distribution, WeightedIndex};
use rand::distr::Distribution;
use rand::seq::{IndexedRandom, IteratorRandom};
use rand_distr::{WeightedAliasIndex, WeightedTreeIndex};
use rand_distr::weighted::*;

/// Takes the unnormalized pdf and creates the cdf of a discrete distribution
fn make_cdf(num: usize, f: impl Fn(i64) -> f64) -> impl Fn(i64) -> f64 {
Expand Down
6 changes: 6 additions & 0 deletions rand_distr/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [0.5.0-beta.3] - 2025-01-03
- Bump `rand` version (#1547)
- Move `Slice` -> `slice::Choose`, `EmptySlice` -> `slice::Empty` (#1548)
- Rename trait `DistString` -> `SampleString` (#1548)
- Rename `DistIter` -> `Iter`, `DistMap` -> `Map` (#1548)
- Move `{Weight, WeightError, WeightedIndex}` -> `weighted::{Weight, Error, WeightedIndex}` (#1548)
- Move `weighted_alias::{AliasableWeight, WeightedAliasIndex}` -> `weighted::{..}` (#1548)
- Move `weighted_tree::WeightedTreeIndex` -> `weighted::WeightedTreeIndex` (#1548)

## [0.5.0-beta.2] - 2024-11-30
- Bump `rand` version
Expand Down
24 changes: 7 additions & 17 deletions rand_distr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@
//!
//! The following are re-exported:
//!
//! - The [`Distribution`] trait and [`DistIter`] helper type
//! - The [`Distribution`] trait and [`Iter`] helper type
//! - The [`StandardUniform`], [`Alphanumeric`], [`Uniform`], [`OpenClosed01`],
//! [`Open01`], [`Bernoulli`], and [`WeightedIndex`] distributions
//! [`Open01`], [`Bernoulli`] distributions
//! - The [`weighted`] module
//!
//! ## Distributions
//!
Expand Down Expand Up @@ -76,9 +77,6 @@
//! - [`UnitBall`] distribution
//! - [`UnitCircle`] distribution
//! - [`UnitDisc`] distribution
//! - Alternative implementations for weighted index sampling
//! - [`WeightedAliasIndex`] distribution
//! - [`WeightedTreeIndex`] distribution
//! - Misc. distributions
//! - [`InverseGaussian`] distribution
//! - [`NormalInverseGaussian`] distribution
Expand All @@ -94,7 +92,7 @@ extern crate std;
use rand::Rng;

pub use rand::distr::{
uniform, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, OpenClosed01,
uniform, Alphanumeric, Bernoulli, BernoulliError, Distribution, Iter, Open01, OpenClosed01,
StandardUniform, Uniform,
};

Expand Down Expand Up @@ -128,16 +126,13 @@ pub use self::unit_sphere::UnitSphere;
pub use self::weibull::{Error as WeibullError, Weibull};
pub use self::zeta::{Error as ZetaError, Zeta};
pub use self::zipf::{Error as ZipfError, Zipf};
#[cfg(feature = "alloc")]
pub use rand::distr::{WeightError, WeightedIndex};
pub use student_t::StudentT;
#[cfg(feature = "alloc")]
pub use weighted_alias::WeightedAliasIndex;
#[cfg(feature = "alloc")]
pub use weighted_tree::WeightedTreeIndex;

pub use num_traits;

#[cfg(feature = "alloc")]
pub mod weighted;

#[cfg(test)]
#[macro_use]
mod test {
Expand Down Expand Up @@ -189,11 +184,6 @@ mod test {
}
}

#[cfg(feature = "alloc")]
pub mod weighted_alias;
#[cfg(feature = "alloc")]
pub mod weighted_tree;

mod beta;
mod binomial;
mod cauchy;
Expand Down
28 changes: 28 additions & 0 deletions rand_distr/src/weighted/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright 2018 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Weighted (index) sampling
//!
//! This module is a superset of [`rand::distr::weighted`].
//!
//! Multiple implementations of weighted index sampling are provided:
//!
//! - [`WeightedIndex`] (a re-export from [`rand`]) supports fast construction
//! and `O(log N)` sampling over `N` weights.
//! It also supports updating weights with `O(N)` time.
//! - [`WeightedAliasIndex`] supports `O(1)` sampling, but due to high
//! construction time many samples are required to outperform [`WeightedIndex`].
//! - [`WeightedTreeIndex`] supports `O(log N)` sampling and
//! update/insertion/removal of weights with `O(log N)` time.
mod weighted_alias;
mod weighted_tree;

pub use rand::distr::weighted::*;
pub use weighted_alias::*;
pub use weighted_tree::*;
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
//! This module contains an implementation of alias method for sampling random
//! indices with probabilities proportional to a collection of weights.
use super::WeightError;
use super::Error;
use crate::{uniform::SampleUniform, Distribution, Uniform};
use alloc::{boxed::Box, vec, vec::Vec};
use core::fmt;
Expand Down Expand Up @@ -41,7 +41,7 @@ use serde::{Deserialize, Serialize};
/// # Example
///
/// ```
/// use rand_distr::WeightedAliasIndex;
/// use rand_distr::weighted::WeightedAliasIndex;
/// use rand::prelude::*;
///
/// let choices = vec!['a', 'b', 'c'];
Expand Down Expand Up @@ -85,14 +85,14 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
/// Creates a new [`WeightedAliasIndex`].
///
/// Error cases:
/// - [`WeightError::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`.
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number,
/// - [`Error::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`.
/// - [`Error::InvalidWeight`] when a weight is not-a-number,
/// negative or greater than `max = W::MAX / weights.len()`.
/// - [`WeightError::InsufficientNonZero`] when the sum of all weights is zero.
pub fn new(weights: Vec<W>) -> Result<Self, WeightError> {
/// - [`Error::InsufficientNonZero`] when the sum of all weights is zero.
pub fn new(weights: Vec<W>) -> Result<Self, Error> {
let n = weights.len();
if n == 0 || n > u32::MAX as usize {
return Err(WeightError::InvalidInput);
return Err(Error::InvalidInput);
}
let n = n as u32;

Expand All @@ -103,7 +103,7 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
.iter()
.all(|&w| W::ZERO <= w && w <= max_weight_size)
{
return Err(WeightError::InvalidWeight);
return Err(Error::InvalidWeight);
}

// The sum of weights will represent 100% of no alias odds.
Expand All @@ -115,7 +115,7 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
weight_sum
};
if weight_sum == W::ZERO {
return Err(WeightError::InsufficientNonZero);
return Err(Error::InsufficientNonZero);
}

// `weight_sum` would have been zero if `try_from_lossy` causes an error here.
Expand Down Expand Up @@ -384,23 +384,23 @@ mod test {
// Floating point special cases
assert_eq!(
WeightedAliasIndex::new(vec![f32::INFINITY]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(),
WeightError::InsufficientNonZero
Error::InsufficientNonZero
);
assert_eq!(
WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![f32::NEG_INFINITY]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![f32::NAN]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
}

Expand All @@ -418,11 +418,11 @@ mod test {
// Signed integer special cases
assert_eq!(
WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![i128::MIN]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
}

Expand All @@ -440,11 +440,11 @@ mod test {
// Signed integer special cases
assert_eq!(
WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![i8::MIN]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
}

Expand Down Expand Up @@ -491,15 +491,15 @@ mod test {

assert_eq!(
WeightedAliasIndex::<W>::new(vec![]).unwrap_err(),
WeightError::InvalidInput
Error::InvalidInput
);
assert_eq!(
WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(),
WeightError::InsufficientNonZero
Error::InsufficientNonZero
);
assert_eq!(
WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
}

Expand Down
Loading

0 comments on commit b4b1eb7

Please sign in to comment.