Skip to content

Commit

Permalink
Merge pull request #338 from rust-ndarray/test-fixed-seed
Browse files Browse the repository at this point in the history
Fix PRNG seed for random vectors and matrices in tests
  • Loading branch information
termoshtt authored Sep 17, 2022
2 parents f3a9f4c + 4ea4f36 commit 3a8520c
Show file tree
Hide file tree
Showing 26 changed files with 302 additions and 192 deletions.
1 change: 1 addition & 0 deletions ndarray-linalg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ paste = "1.0.5"
criterion = "0.3.4"
# Keep the same version as ndarray's dependency!
approx = { version = "0.4.0", features = ["num-complex"] }
rand_pcg = "0.3.1"

[[bench]]
name = "truncated_eig"
Expand Down
18 changes: 12 additions & 6 deletions ndarray-linalg/src/lobpcg/lobpcg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,8 @@ mod tests {
/// Test the `sorted_eigen` function
#[test]
fn test_sorted_eigen() {
let matrix: Array2<f64> = generate::random((10, 10)) * 10.0;
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let matrix: Array2<f64> = generate::random_using((10, 10), &mut rng) * 10.0;
let matrix = matrix.t().dot(&matrix);

// return all eigenvectors with largest first
Expand All @@ -476,7 +477,8 @@ mod tests {
/// Test the masking function
#[test]
fn test_masking() {
let matrix: Array2<f64> = generate::random((10, 5)) * 10.0;
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let matrix: Array2<f64> = generate::random_using((10, 5), &mut rng) * 10.0;
let masked_matrix = ndarray_mask(matrix.view(), &[true, true, false, true, false]);
close_l2(
&masked_matrix.slice(s![.., 2]),
Expand All @@ -488,7 +490,8 @@ mod tests {
/// Test orthonormalization of a random matrix
#[test]
fn test_orthonormalize() {
let matrix: Array2<f64> = generate::random((10, 10)) * 10.0;
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let matrix: Array2<f64> = generate::random_using((10, 10), &mut rng) * 10.0;

let (n, l) = orthonormalize(matrix.clone()).unwrap();

Expand All @@ -509,7 +512,8 @@ mod tests {
assert_symmetric(a);

let n = a.len_of(Axis(0));
let x: Array2<f64> = generate::random((n, num));
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let x: Array2<f64> = generate::random_using((n, num), &mut rng);

let result = lobpcg(|y| a.dot(&y), x, |_| {}, None, 1e-5, n * 2, order);
match result {
Expand Down Expand Up @@ -553,7 +557,8 @@ mod tests {
#[test]
fn test_eigsolver_constructed() {
let n = 50;
let tmp = generate::random((n, n));
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let tmp = generate::random_using((n, n), &mut rng);
//let (v, _) = tmp.qr_square().unwrap();
let (v, _) = orthonormalize(tmp).unwrap();

Expand All @@ -570,7 +575,8 @@ mod tests {
fn test_eigsolver_constrained() {
let diag = arr1(&[1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]);
let a = Array2::from_diag(&diag);
let x: Array2<f64> = generate::random((10, 1));
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let x: Array2<f64> = generate::random_using((10, 1), &mut rng);
let y: Array2<f64> = arr2(&[
[1.0, 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1.0, 0., 0., 0., 0., 0., 0., 0., 0.],
Expand Down
3 changes: 2 additions & 1 deletion ndarray-linalg/src/lobpcg/svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ mod tests {

#[test]
fn test_truncated_svd_random() {
let a: Array2<f64> = generate::random((50, 10));
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a: Array2<f64> = generate::random_using((50, 10), &mut rng);

let res = TruncatedSvd::new(a.clone(), Order::Largest)
.precision(1e-5)
Expand Down
20 changes: 5 additions & 15 deletions ndarray-linalg/src/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,13 @@
//! Solve `A * x = b`:
//!
//! ```
//! #[macro_use]
//! extern crate ndarray;
//! extern crate ndarray_linalg;
//!
//! use ndarray::prelude::*;
//! use ndarray_linalg::Solve;
//! # fn main() {
//!
//! let a: Array2<f64> = array![[3., 2., -1.], [2., -2., 4.], [-2., 1., -2.]];
//! let b: Array1<f64> = array![1., -2., 0.];
//! let x = a.solve_into(b).unwrap();
//! assert!(x.abs_diff_eq(&array![1., -2., -2.], 1e-9));
//!
//! # }
//! ```
//!
//! There are also special functions for solving `A^T * x = b` and
Expand All @@ -29,21 +22,18 @@
//! the beginning than solving directly using `A`:
//!
//! ```
//! # extern crate ndarray;
//! # extern crate ndarray_linalg;
//!
//! use ndarray::prelude::*;
//! use ndarray_linalg::*;
//! # fn main() {
//!
//! let a: Array2<f64> = random((3, 3));
//! /// Use fixed algorithm and seed of PRNG for reproducible test
//! let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
//!
//! let a: Array2<f64> = random_using((3, 3), &mut rng);
//! let f = a.factorize_into().unwrap(); // LU factorize A (A is consumed)
//! for _ in 0..10 {
//! let b: Array1<f64> = random(3);
//! let b: Array1<f64> = random_using(3, &mut rng);
//! let x = f.solve_into(b).unwrap(); // Solve A * x = b using factorized L, U
//! }
//!
//! # }
//! ```
use ndarray::*;
Expand Down
19 changes: 5 additions & 14 deletions ndarray-linalg/src/solveh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,8 @@
//! Solve `A * x = b`, where `A` is a Hermitian (or real symmetric) matrix:
//!
//! ```
//! #[macro_use]
//! extern crate ndarray;
//! extern crate ndarray_linalg;
//!
//! use ndarray::prelude::*;
//! use ndarray_linalg::SolveH;
//! # fn main() {
//!
//! let a: Array2<f64> = array![
//! [3., 2., -1.],
Expand All @@ -24,29 +19,25 @@
//! let b: Array1<f64> = array![11., -12., 1.];
//! let x = a.solveh_into(b).unwrap();
//! assert!(x.abs_diff_eq(&array![1., 3., -2.], 1e-9));
//!
//! # }
//! ```
//!
//! If you are solving multiple systems of linear equations with the same
//! Hermitian or real symmetric coefficient matrix `A`, it's faster to compute
//! the factorization once at the beginning than solving directly using `A`:
//!
//! ```
//! # extern crate ndarray;
//! # extern crate ndarray_linalg;
//! use ndarray::prelude::*;
//! use ndarray_linalg::*;
//! # fn main() {
//!
//! let a: Array2<f64> = random((3, 3));
//! /// Use fixed algorithm and seed of PRNG for reproducible test
//! let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
//!
//! let a: Array2<f64> = random_using((3, 3), &mut rng);
//! let f = a.factorizeh_into().unwrap(); // Factorize A (A is consumed)
//! for _ in 0..10 {
//! let b: Array1<f64> = random(3);
//! let b: Array1<f64> = random_using(3, &mut rng);
//! let x = f.solveh_into(b).unwrap(); // Solve A * x = b using the factorization
//! }
//!
//! # }
//! ```
use ndarray::*;
Expand Down
20 changes: 12 additions & 8 deletions ndarray-linalg/tests/arnoldi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ use ndarray_linalg::{krylov::*, *};

#[test]
fn aq_qh_mgs() {
let a: Array2<f64> = random((5, 5));
let v: Array1<f64> = random(5);
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a: Array2<f64> = random_using((5, 5), &mut rng);
let v: Array1<f64> = random_using(5, &mut rng);
let (q, h) = arnoldi_mgs(a.clone(), v, 1e-9);
println!("A = \n{:?}", &a);
println!("Q = \n{:?}", &q);
Expand All @@ -18,8 +19,9 @@ fn aq_qh_mgs() {

#[test]
fn aq_qh_householder() {
let a: Array2<f64> = random((5, 5));
let v: Array1<f64> = random(5);
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a: Array2<f64> = random_using((5, 5), &mut rng);
let v: Array1<f64> = random_using(5, &mut rng);
let (q, h) = arnoldi_mgs(a.clone(), v, 1e-9);
println!("A = \n{:?}", &a);
println!("Q = \n{:?}", &q);
Expand All @@ -33,8 +35,9 @@ fn aq_qh_householder() {

#[test]
fn aq_qh_mgs_complex() {
let a: Array2<c64> = random((5, 5));
let v: Array1<c64> = random(5);
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a: Array2<c64> = random_using((5, 5), &mut rng);
let v: Array1<c64> = random_using(5, &mut rng);
let (q, h) = arnoldi_mgs(a.clone(), v, 1e-9);
println!("A = \n{:?}", &a);
println!("Q = \n{:?}", &q);
Expand All @@ -48,8 +51,9 @@ fn aq_qh_mgs_complex() {

#[test]
fn aq_qh_householder_complex() {
let a: Array2<c64> = random((5, 5));
let v: Array1<c64> = random(5);
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a: Array2<c64> = random_using((5, 5), &mut rng);
let v: Array1<c64> = random_using(5, &mut rng);
let (q, h) = arnoldi_mgs(a.clone(), v, 1e-9);
println!("A = \n{:?}", &a);
println!("Q = \n{:?}", &q);
Expand Down
17 changes: 11 additions & 6 deletions ndarray-linalg/tests/cholesky.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ macro_rules! cholesky {
paste::item! {
#[test]
fn [<cholesky_ $elem>]() {
let a_orig: Array2<$elem> = random_hpd(3);
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a_orig: Array2<$elem> = random_hpd_using(3, &mut rng);
println!("a = \n{:?}", a_orig);

let upper = a_orig.cholesky(UPLO::Upper).unwrap();
Expand Down Expand Up @@ -79,7 +80,8 @@ macro_rules! cholesky_into_lower_upper {
paste::item! {
#[test]
fn [<cholesky_into_lower_upper_ $elem>]() {
let a: Array2<$elem> = random_hpd(3);
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a: Array2<$elem> = random_hpd_using(3, &mut rng);
println!("a = \n{:?}", a);
let upper = a.cholesky(UPLO::Upper).unwrap();
let fac_upper = a.factorizec(UPLO::Upper).unwrap();
Expand All @@ -106,7 +108,8 @@ macro_rules! cholesky_into_inverse {
paste::item! {
#[test]
fn [<cholesky_inverse_ $elem>]() {
let a: Array2<$elem> = random_hpd(3);
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a: Array2<$elem> = random_hpd_using(3, &mut rng);
println!("a = \n{:?}", a);
let inv = a.invc().unwrap();
assert_close_l2!(&a.dot(&inv), &Array2::eye(3), $rtol);
Expand Down Expand Up @@ -134,7 +137,8 @@ macro_rules! cholesky_det {
paste::item! {
#[test]
fn [<cholesky_det_ $elem>]() {
let a: Array2<$elem> = random_hpd(3);
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a: Array2<$elem> = random_hpd_using(3, &mut rng);
println!("a = \n{:?}", a);
let ln_det = a
.eigvalsh(UPLO::Upper)
Expand Down Expand Up @@ -168,8 +172,9 @@ macro_rules! cholesky_solve {
paste::item! {
#[test]
fn [<cholesky_solve_ $elem>]() {
let a: Array2<$elem> = random_hpd(3);
let x: Array1<$elem> = random(3);
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a: Array2<$elem> = random_hpd_using(3, &mut rng);
let x: Array1<$elem> = random_using(3, &mut rng);
let b = a.dot(&x);
println!("a = \n{:?}", a);
println!("x = \n{:?}", x);
Expand Down
3 changes: 2 additions & 1 deletion ndarray-linalg/tests/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use ndarray_linalg::*;

#[test]
fn generalize() {
let a: Array3<f64> = random((3, 2, 4).f());
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a: Array3<f64> = random_using((3, 2, 4).f(), &mut rng);
let ans = a.clone();
let a: Array3<f64> = convert::generalize(a);
assert_eq!(a, ans);
Expand Down
40 changes: 31 additions & 9 deletions ndarray-linalg/tests/det.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,23 +136,45 @@ fn det() {
assert_rclose!(result.1, ln_det, rtol);
}
}
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
for rows in 1..5 {
det_impl(random_regular::<f64>(rows), 1e-9);
det_impl(random_regular::<f32>(rows), 1e-4);
det_impl(random_regular::<c64>(rows), 1e-9);
det_impl(random_regular::<c32>(rows), 1e-4);
det_impl(random_regular::<f64>(rows).t().to_owned(), 1e-9);
det_impl(random_regular::<f32>(rows).t().to_owned(), 1e-4);
det_impl(random_regular::<c64>(rows).t().to_owned(), 1e-9);
det_impl(random_regular::<c32>(rows).t().to_owned(), 1e-4);
det_impl(random_regular_using::<f64, _>(rows, &mut rng), 1e-9);
det_impl(random_regular_using::<f32, _>(rows, &mut rng), 1e-4);
det_impl(random_regular_using::<c64, _>(rows, &mut rng), 1e-9);
det_impl(random_regular_using::<c32, _>(rows, &mut rng), 1e-4);
det_impl(
random_regular_using::<f64, _>(rows, &mut rng)
.t()
.to_owned(),
1e-9,
);
det_impl(
random_regular_using::<f32, _>(rows, &mut rng)
.t()
.to_owned(),
1e-4,
);
det_impl(
random_regular_using::<c64, _>(rows, &mut rng)
.t()
.to_owned(),
1e-9,
);
det_impl(
random_regular_using::<c32, _>(rows, &mut rng)
.t()
.to_owned(),
1e-4,
);
}
}

#[test]
fn det_nonsquare() {
macro_rules! det_nonsquare {
($elem:ty, $shape:expr) => {
let a: Array2<$elem> = random($shape);
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a: Array2<$elem> = random_using($shape, &mut rng);
assert!(a.factorize().unwrap().det().is_err());
assert!(a.factorize().unwrap().sln_det().is_err());
assert!(a.factorize().unwrap().det_into().is_err());
Expand Down
3 changes: 2 additions & 1 deletion ndarray-linalg/tests/deth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ fn deth_zero_nonsquare() {
fn deth() {
macro_rules! deth {
($elem:ty, $rows:expr, $atol:expr) => {
let a: Array2<$elem> = random_hermite($rows);
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a: Array2<$elem> = random_hermite_using($rows, &mut rng);
println!("a = \n{:?}", a);

// Compute determinant from eigenvalues.
Expand Down
12 changes: 8 additions & 4 deletions ndarray-linalg/tests/eigh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ fn fixed_t_lower() {

#[test]
fn ssqrt() {
let a: Array2<f64> = random_hpd(3);
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a: Array2<f64> = random_hpd_using(3, &mut rng);
let ans = a.clone();
let s = a.ssqrt(UPLO::Upper).unwrap();
println!("a = {:?}", &ans);
Expand All @@ -92,7 +93,8 @@ fn ssqrt() {

#[test]
fn ssqrt_t() {
let a: Array2<f64> = random_hpd(3).reversed_axes();
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a: Array2<f64> = random_hpd_using(3, &mut rng).reversed_axes();
let ans = a.clone();
let s = a.ssqrt(UPLO::Upper).unwrap();
println!("a = {:?}", &ans);
Expand All @@ -105,7 +107,8 @@ fn ssqrt_t() {

#[test]
fn ssqrt_lower() {
let a: Array2<f64> = random_hpd(3);
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a: Array2<f64> = random_hpd_using(3, &mut rng);
let ans = a.clone();
let s = a.ssqrt(UPLO::Lower).unwrap();
println!("a = {:?}", &ans);
Expand All @@ -118,7 +121,8 @@ fn ssqrt_lower() {

#[test]
fn ssqrt_t_lower() {
let a: Array2<f64> = random_hpd(3).reversed_axes();
let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
let a: Array2<f64> = random_hpd_using(3, &mut rng).reversed_axes();
let ans = a.clone();
let s = a.ssqrt(UPLO::Lower).unwrap();
println!("a = {:?}", &ans);
Expand Down
Loading

0 comments on commit 3a8520c

Please sign in to comment.