Skip to content

Commit 8934c5d

Browse files
authored
Merge pull request #337 from rust-ndarray/random-using-rng
`random_*_using` API for using given RNG
2 parents 816f1a7 + 0d2fc06 commit 8934c5d

File tree

1 file changed

+92
-8
lines changed

1 file changed

+92
-8
lines changed

ndarray-linalg/src/generate.rs

Lines changed: 92 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ where
2222
a
2323
}
2424

25-
/// Generate random array
25+
/// Generate random array with given shape
26+
///
27+
/// - This function uses [rand::thread_rng].
28+
/// See [random_using] for using another RNG
2629
pub fn random<A, S, Sh, D>(sh: Sh) -> ArrayBase<S, D>
2730
where
2831
A: Scalar,
@@ -31,29 +34,77 @@ where
3134
Sh: ShapeBuilder<Dim = D>,
3235
{
3336
let mut rng = thread_rng();
34-
ArrayBase::from_shape_fn(sh, |_| A::rand(&mut rng))
37+
random_using(sh, &mut rng)
38+
}
39+
40+
/// Generate random array with given RNG
41+
///
42+
/// - See [random] for using default RNG
43+
pub fn random_using<A, S, Sh, D, R>(sh: Sh, rng: &mut R) -> ArrayBase<S, D>
44+
where
45+
A: Scalar,
46+
S: DataOwned<Elem = A>,
47+
D: Dimension,
48+
Sh: ShapeBuilder<Dim = D>,
49+
R: Rng,
50+
{
51+
ArrayBase::from_shape_fn(sh, |_| A::rand(rng))
3552
}
3653

3754
/// Generate random unitary matrix using QR decomposition
3855
///
39-
/// Be sure that this it **NOT** a uniform distribution. Use it only for test purpose.
56+
/// - Be sure that this it **NOT** a uniform distribution.
57+
/// Use it only for test purpose.
58+
/// - This function uses [rand::thread_rng].
59+
/// See [random_unitary_using] for using another RNG.
4060
pub fn random_unitary<A>(n: usize) -> Array2<A>
4161
where
4262
A: Scalar + Lapack,
4363
{
44-
let a: Array2<A> = random((n, n));
64+
let mut rng = thread_rng();
65+
random_unitary_using(n, &mut rng)
66+
}
67+
68+
/// Generate random unitary matrix using QR decomposition with given RNG
69+
///
70+
/// - Be sure that this it **NOT** a uniform distribution.
71+
/// Use it only for test purpose.
72+
/// - See [random_unitary] for using default RNG.
73+
pub fn random_unitary_using<A, R>(n: usize, rng: &mut R) -> Array2<A>
74+
where
75+
A: Scalar + Lapack,
76+
R: Rng,
77+
{
78+
let a: Array2<A> = random_using((n, n), rng);
4579
let (q, _r) = a.qr_into().unwrap();
4680
q
4781
}
4882

4983
/// Generate random regular matrix
5084
///
51-
/// Be sure that this it **NOT** a uniform distribution. Use it only for test purpose.
85+
/// - Be sure that this it **NOT** a uniform distribution.
86+
/// Use it only for test purpose.
87+
/// - This function uses [rand::thread_rng].
88+
/// See [random_regular_using] for using another RNG.
5289
pub fn random_regular<A>(n: usize) -> Array2<A>
5390
where
5491
A: Scalar + Lapack,
5592
{
56-
let a: Array2<A> = random((n, n));
93+
let mut rng = rand::thread_rng();
94+
random_regular_using(n, &mut rng)
95+
}
96+
97+
/// Generate random regular matrix with given RNG
98+
///
99+
/// - Be sure that this it **NOT** a uniform distribution.
100+
/// Use it only for test purpose.
101+
/// - See [random_regular] for using default RNG.
102+
pub fn random_regular_using<A, R>(n: usize, rng: &mut R) -> Array2<A>
103+
where
104+
A: Scalar + Lapack,
105+
R: Rng,
106+
{
107+
let a: Array2<A> = random_using((n, n), rng);
57108
let (q, mut r) = a.qr_into().unwrap();
58109
for i in 0..n {
59110
r[(i, i)] = A::one() + A::from_real(r[(i, i)].abs());
@@ -62,12 +113,28 @@ where
62113
}
63114

64115
/// Random Hermite matrix
116+
///
117+
/// - This function uses [rand::thread_rng].
118+
/// See [random_hermite_using] for using another RNG.
65119
pub fn random_hermite<A, S>(n: usize) -> ArrayBase<S, Ix2>
66120
where
67121
A: Scalar,
68122
S: DataOwned<Elem = A> + DataMut,
69123
{
70-
let mut a: ArrayBase<S, Ix2> = random((n, n));
124+
let mut rng = rand::thread_rng();
125+
random_hermite_using(n, &mut rng)
126+
}
127+
128+
/// Random Hermite matrix with given RNG
129+
///
130+
/// - See [random_hermite] for using default RNG.
131+
pub fn random_hermite_using<A, S, R>(n: usize, rng: &mut R) -> ArrayBase<S, Ix2>
132+
where
133+
A: Scalar,
134+
S: DataOwned<Elem = A> + DataMut,
135+
R: Rng,
136+
{
137+
let mut a: ArrayBase<S, Ix2> = random_using((n, n), rng);
71138
for i in 0..n {
72139
a[(i, i)] = a[(i, i)] + a[(i, i)].conj();
73140
for j in (i + 1)..n {
@@ -80,13 +147,30 @@ where
80147
/// Random Hermite Positive-definite matrix
81148
///
82149
/// - Eigenvalue of matrix must be larger than 1 (thus non-singular)
150+
/// - This function uses [rand::thread_rng].
151+
/// See [random_hpd_using] for using another RNG.
83152
///
84153
pub fn random_hpd<A, S>(n: usize) -> ArrayBase<S, Ix2>
85154
where
86155
A: Scalar,
87156
S: DataOwned<Elem = A> + DataMut,
88157
{
89-
let a: Array2<A> = random((n, n));
158+
let mut rng = rand::thread_rng();
159+
random_hpd_using(n, &mut rng)
160+
}
161+
162+
/// Random Hermite Positive-definite matrix with given RNG
163+
///
164+
/// - Eigenvalue of matrix must be larger than 1 (thus non-singular)
165+
/// - See [random_hpd] for using default RNG.
166+
///
167+
pub fn random_hpd_using<A, S, R>(n: usize, rng: &mut R) -> ArrayBase<S, Ix2>
168+
where
169+
A: Scalar,
170+
S: DataOwned<Elem = A> + DataMut,
171+
R: Rng,
172+
{
173+
let a: Array2<A> = random_using((n, n), rng);
90174
let ah: Array2<A> = conjugate(&a);
91175
ArrayBase::eye(n) + &ah.dot(&a)
92176
}

0 commit comments

Comments
 (0)