Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions parallel_numpy_rng.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def random(self, size=None, nthread=None, out=None, verify_rng=True, dtype=np.fl
size = 1
if nthread == None:
nthread = self.nthread
if nthread < 1:
raise ValueError("nthread must be >= 1")
if nthread > size:
nthread = size

Expand Down Expand Up @@ -128,6 +130,8 @@ def standard_normal(self, size=None, nthread=None, out=None, verify_rng=True, dt
nthread = self.nthread
if nthread > max(size//2,1):
nthread = max(size//2,1)
if nthread < 1:
raise ValueError("nthread must be >= 1")
if out is None:
out = np.empty(size, dtype=dtype)
if size == 0:
Expand Down Expand Up @@ -193,7 +197,7 @@ def _copy_bitgen(self):
@njit(fastmath=True, parallel=True)
def _random(states, starts, out, next_double):
nthread = len(states)
numba.set_num_threads(nthread)
numba.set_num_threads(max(1,nthread))

for t in numba.prange(nthread):
a = starts[t]
Expand All @@ -206,7 +210,7 @@ def _random(states, starts, out, next_double):
@njit(fastmath=True,parallel=True)
def _boxmuller(states, starts, out, next_double):
nthread = len(states)
numba.set_num_threads(nthread)
numba.set_num_threads(max(1,nthread))
dtype = out.dtype.type

cache = np.full(1, np.nan, dtype=dtype)
Expand Down
102 changes: 101 additions & 1 deletion test_parallel_numpy_rng.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def test_threads(allN, seed, nthread, dtype, funcname):
from parallel_numpy_rng import MTGenerator

for N in allN:
if N < nthread-1:
# don't repeatedly test N < nthread
continue

pcg = np.random.PCG64(seed)
mtg = MTGenerator(pcg)
func = getattr(mtg,funcname)
Expand Down Expand Up @@ -88,6 +92,10 @@ def test_resume(someN, seed, nthread, dtype, funcname):
rng = np.random.default_rng(seed)

for N in someN:
if N < nthread-1:
# don't repeatedly test N < nthread
continue

pcg = np.random.PCG64(seed)
mtg = MTGenerator(pcg)
func = getattr(mtg,funcname)
Expand All @@ -108,6 +116,95 @@ def test_resume(someN, seed, nthread, dtype, funcname):
assert np.allclose(a, res, atol=1e-7, rtol=0.)
elif dtype == np.float64:
assert np.allclose(a, res, atol=1e-15, rtol=0.)


def test_mixing_threads(someN, seed, nthread, dtype):
'''Test that changing the number of threads mid-stream
doesn't matter. Only standard normal holds any interesting
external state.
'''
funcname = 'standard_normal'
from parallel_numpy_rng import MTGenerator

rng = np.random.default_rng(seed)
maxthreads = nthread
del nthread

for N in someN:
if N < maxthreads-1:
# don't repeatedly test N < nthread
continue

pcg = np.random.PCG64(seed)
mtg = MTGenerator(pcg)
func = getattr(mtg,funcname)
a = func(size=N, nthread=maxthreads, dtype=dtype)

pcg = np.random.PCG64(seed)
mtg = MTGenerator(pcg)
func = getattr(mtg,funcname)

res = np.empty(N, dtype=dtype)
i = 0
tstart = np.linspace(0, N, maxthreads+1, endpoint=True, dtype=int)
# sweep from 1 to maxthreads
for t in range(maxthreads):
n = tstart[t+1]-tstart[t]
res[i:i+n] = func(size=n, nthread=t+1, dtype=dtype)
i += n

if dtype == np.float32:
assert np.allclose(a, res, atol=1e-7, rtol=0.)
elif dtype == np.float64:
assert np.allclose(a, res, atol=1e-15, rtol=0.)


def test_mixing_func(someN, seed, nthread, dtype):
'''Test interleaving random and standard_normal works for different N/thread
'''

from parallel_numpy_rng import MTGenerator

rng = np.random.default_rng(seed)

for N in someN:
if N < nthread-1:
# don't repeatedly test N < nthread
continue

pcg = np.random.PCG64(seed)
mtg = MTGenerator(pcg)
coin_seed = rng.integers(2**16)
coin_rng = np.random.default_rng(coin_seed)

nchunk = max(2,N//100)
serial = np.empty(N, dtype=dtype)
i = 0
tstart = np.linspace(0, N, nchunk+1, endpoint=True, dtype=int)
for t in range(nchunk):
n = tstart[t+1]-tstart[t]
# in each chunk, flip a coin to decide the function
func = mtg.random if coin_rng.integers(2) else mtg.standard_normal
serial[i:i+n] = func(size=n, nthread=1, dtype=dtype)
i += n


pcg = np.random.PCG64(seed)
mtg = MTGenerator(pcg)
coin_rng = np.random.default_rng(coin_seed)

parallel = np.empty(N, dtype=dtype)
i = 0
for t in range(nchunk):
n = tstart[t+1]-tstart[t]
func = mtg.random if coin_rng.integers(2) else mtg.standard_normal
parallel[i:i+n] = func(size=n, nthread=nthread, dtype=dtype)
i += n

if dtype == np.float32:
assert np.allclose(serial, parallel, atol=1e-7, rtol=0.)
elif dtype == np.float64:
assert np.allclose(serial, parallel, atol=1e-15, rtol=0.)


def test_uniform_matches_numpy(someN, seed, nthread, dtype):
Expand All @@ -122,6 +219,10 @@ def test_uniform_matches_numpy(someN, seed, nthread, dtype):
from parallel_numpy_rng import MTGenerator

for N in someN:
if N < maxthreads-1:
# don't repeatedly test N < nthread
continue

pcg = np.random.PCG64(seed)
mtg = MTGenerator(pcg)
a = mtg.random(size=N, nthread=nthread, dtype=dtype)
Expand Down Expand Up @@ -152,4 +253,3 @@ def test_finite_normals_float32():
mtg = MTGenerator(pcg)
a = mtg.standard_normal(size=20000, nthread=maxthreads, dtype=np.float32)
assert np.all(np.isfinite(a))