From 31647efe300f273c05b3703709e5aa3ab1028e5b Mon Sep 17 00:00:00 2001 From: Lehman Garrison Date: Sat, 18 Dec 2021 19:26:25 -0500 Subject: [PATCH 1/2] Add tests for mixing thread counts --- parallel_numpy_rng.py | 8 ++++-- test_parallel_numpy_rng.py | 53 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/parallel_numpy_rng.py b/parallel_numpy_rng.py index 2d99ab9..cf731d2 100644 --- a/parallel_numpy_rng.py +++ b/parallel_numpy_rng.py @@ -69,6 +69,8 @@ def random(self, size=None, nthread=None, out=None, verify_rng=True, dtype=np.fl size = 1 if nthread == None: nthread = self.nthread + if nthread < 1: + raise ValueError("nthread must be >= 1") if nthread > size: nthread = size @@ -128,6 +130,8 @@ def standard_normal(self, size=None, nthread=None, out=None, verify_rng=True, dt nthread = self.nthread if nthread > max(size//2,1): nthread = max(size//2,1) + if nthread < 1: + raise ValueError("nthread must be >= 1") if out is None: out = np.empty(size, dtype=dtype) if size == 0: @@ -193,7 +197,7 @@ def _copy_bitgen(self): @njit(fastmath=True, parallel=True) def _random(states, starts, out, next_double): nthread = len(states) - numba.set_num_threads(nthread) + numba.set_num_threads(max(1,nthread)) for t in numba.prange(nthread): a = starts[t] @@ -206,7 +210,7 @@ def _random(states, starts, out, next_double): @njit(fastmath=True,parallel=True) def _boxmuller(states, starts, out, next_double): nthread = len(states) - numba.set_num_threads(nthread) + numba.set_num_threads(max(1,nthread)) dtype = out.dtype.type cache = np.full(1, np.nan, dtype=dtype) diff --git a/test_parallel_numpy_rng.py b/test_parallel_numpy_rng.py index 1c8408b..5ebb970 100644 --- a/test_parallel_numpy_rng.py +++ b/test_parallel_numpy_rng.py @@ -58,6 +58,10 @@ def test_threads(allN, seed, nthread, dtype, funcname): from parallel_numpy_rng import MTGenerator for N in allN: + if N < maxthreads-1: + # don't repeatedly test N < nthread + continue + pcg = np.random.PCG64(seed) mtg = MTGenerator(pcg) func = getattr(mtg,funcname) @@ -88,6 +92,10 @@ def test_resume(someN, seed, nthread, dtype, funcname): rng = np.random.default_rng(seed) for N in someN: + if N < maxthreads-1: + # don't repeatedly test N < nthread + continue + pcg = np.random.PCG64(seed) mtg = MTGenerator(pcg) func = getattr(mtg,funcname) @@ -108,6 +116,47 @@ def test_resume(someN, seed, nthread, dtype, funcname): assert np.allclose(a, res, atol=1e-7, rtol=0.) elif dtype == np.float64: assert np.allclose(a, res, atol=1e-15, rtol=0.) + + +def test_mixing_threads(someN, seed, nthread, dtype): + '''Test that changing the number of threads mid-stream + doesn't matter. Only standard normal holds any interesting + external state. + ''' + funcname = 'standard_normal' + from parallel_numpy_rng import MTGenerator + + rng = np.random.default_rng(seed) + maxthreads = nthread + del nthread + + for N in someN: + if N < maxthreads-1: + # don't repeatedly test N < nthread + continue + + pcg = np.random.PCG64(seed) + mtg = MTGenerator(pcg) + func = getattr(mtg,funcname) + a = func(size=N, nthread=maxthreads, dtype=dtype) + + pcg = np.random.PCG64(seed) + mtg = MTGenerator(pcg) + func = getattr(mtg,funcname) + + res = np.empty(N, dtype=dtype) + i = 0 + tstart = np.linspace(0, N, maxthreads+1, endpoint=True, dtype=int) + # sweep from 1 to maxthreads + for t in range(maxthreads): + n = tstart[t+1]-tstart[t] + res[i:i+n] = func(size=n, nthread=t+1, dtype=dtype) + i += n + + if dtype == np.float32: + assert np.allclose(a, res, atol=1e-7, rtol=0.) + elif dtype == np.float64: + assert np.allclose(a, res, atol=1e-15, rtol=0.) def test_uniform_matches_numpy(someN, seed, nthread, dtype): @@ -122,6 +171,10 @@ def test_uniform_matches_numpy(someN, seed, nthread, dtype): from parallel_numpy_rng import MTGenerator for N in someN: + if N < maxthreads-1: + # don't repeatedly test N < nthread + continue + pcg = np.random.PCG64(seed) mtg = MTGenerator(pcg) a = mtg.random(size=N, nthread=nthread, dtype=dtype) From c91ef305b5f44d0efc889bcba9be602d0f5a7947 Mon Sep 17 00:00:00 2001 From: Lehman Garrison Date: Sat, 18 Dec 2021 19:43:26 -0500 Subject: [PATCH 2/2] Test mixing normals and uniforms --- test_parallel_numpy_rng.py | 53 +++++++++++++++++++++++++++++++++++--- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/test_parallel_numpy_rng.py b/test_parallel_numpy_rng.py index 5ebb970..2cb652a 100644 --- a/test_parallel_numpy_rng.py +++ b/test_parallel_numpy_rng.py @@ -58,7 +58,7 @@ def test_threads(allN, seed, nthread, dtype, funcname): from parallel_numpy_rng import MTGenerator for N in allN: - if N < maxthreads-1: + if N < nthread-1: # don't repeatedly test N < nthread continue @@ -92,7 +92,7 @@ def test_resume(someN, seed, nthread, dtype, funcname): rng = np.random.default_rng(seed) for N in someN: - if N < maxthreads-1: + if N < nthread-1: # don't repeatedly test N < nthread continue @@ -157,6 +157,54 @@ def test_mixing_threads(someN, seed, nthread, dtype): assert np.allclose(a, res, atol=1e-7, rtol=0.) elif dtype == np.float64: assert np.allclose(a, res, atol=1e-15, rtol=0.) + + +def test_mixing_func(someN, seed, nthread, dtype): + '''Test interleaving random and standard_normal works for different N/thread + ''' + + from parallel_numpy_rng import MTGenerator + + rng = np.random.default_rng(seed) + + for N in someN: + if N < nthread-1: + # don't repeatedly test N < nthread + continue + + pcg = np.random.PCG64(seed) + mtg = MTGenerator(pcg) + coin_seed = rng.integers(2**16) + coin_rng = np.random.default_rng(coin_seed) + + nchunk = max(2,N//100) + serial = np.empty(N, dtype=dtype) + i = 0 + tstart = np.linspace(0, N, nchunk+1, endpoint=True, dtype=int) + for t in range(nchunk): + n = tstart[t+1]-tstart[t] + # in each chunk, flip a coin to decide the function + func = mtg.random if coin_rng.integers(2) else mtg.standard_normal + serial[i:i+n] = func(size=n, nthread=1, dtype=dtype) + i += n + + + pcg = np.random.PCG64(seed) + mtg = MTGenerator(pcg) + coin_rng = np.random.default_rng(coin_seed) + + parallel = np.empty(N, dtype=dtype) + i = 0 + for t in range(nchunk): + n = tstart[t+1]-tstart[t] + func = mtg.random if coin_rng.integers(2) else mtg.standard_normal + parallel[i:i+n] = func(size=n, nthread=nthread, dtype=dtype) + i += n + + if dtype == np.float32: + assert np.allclose(serial, parallel, atol=1e-7, rtol=0.) + elif dtype == np.float64: + assert np.allclose(serial, parallel, atol=1e-15, rtol=0.) def test_uniform_matches_numpy(someN, seed, nthread, dtype): @@ -205,4 +253,3 @@ def test_finite_normals_float32(): mtg = MTGenerator(pcg) a = mtg.standard_normal(size=20000, nthread=maxthreads, dtype=np.float32) assert np.all(np.isfinite(a)) - \ No newline at end of file