Skip to content

Commit 4f80f68

Browse files
committed
add multithreading tests
1 parent 79575bb commit 4f80f68

6 files changed

+125
-3
lines changed

CHANGES.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ transform improves multi-core utilization which may offset the performance loss
7979

8080
Added :code:`scipy.fft` backend, see #42. Fixed #46.
8181

82-
```
82+
```rst
83+
.. code-block:: python
8384
Python 3.7.5 (default, Nov 23 2019, 04:02:01)
8485
Type 'copyright', 'credits' or 'license' for more information
8586
IPython 7.11.1 -- An enhanced Interactive Python. Type '?' for help.

mkl_fft/_scipy_fft.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ def get_max_threads_count(self):
7171

7272
class _workers_data:
7373
def __init__(self, workers=None):
74-
if workers:
74+
if workers is not None:
75+
workers = _workers(workers)
7576
self.workers_ = workers
7677
else:
7778
self.workers_ = _cpu_max_threads_count().get_cpu_count()
@@ -86,6 +87,22 @@ def workers(self, workers_val):
8687
self.workerks_ = operator.index(workers_val)
8788

8889

90+
def _workers(workers):
91+
_cpu_count = os.cpu_count()
92+
if workers < 0:
93+
if workers >= -_cpu_count:
94+
workers += 1 + _cpu_count
95+
else:
96+
raise ValueError(
97+
f"workers value out of range; got {workers}, must not be"
98+
f" less than {-_cpu_count}"
99+
)
100+
elif workers == 0:
101+
raise ValueError("workers must not be zero")
102+
103+
return workers
104+
105+
89106
_workers_global_settings = contextvars.ContextVar(
90107
"scipy_backend_workers", default=_workers_data()
91108
)
Binary file not shown.

mkl_fft/tests/test_from_scipy.py mkl_fft/tests/from_scipy/test_basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file includes tests from scipy.fft module:
22
# https://github.com/scipy/scipy/blob/main/scipy/fft/tests/test_basic.py
33

4-
# TODO: remove when hfft functions are added
4+
# TODO: remove when hfft* functions are added
55
# pylint: disable=no-member
66

77
import multiprocessing
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# This file includes tests from scipy.fft module:
2+
# https://github.com/scipy/scipy/blob/main/scipy/fft/tests/test_multithreading.py.py
3+
4+
import multiprocessing
5+
import os
6+
7+
import numpy as np
8+
import pytest
9+
from numpy.testing import assert_allclose
10+
11+
import mkl_fft.interfaces.scipy_fft as fft
12+
13+
14+
@pytest.fixture(scope="module")
15+
def x():
16+
return np.random.randn(512, 128) # Must be large enough to qualify for mt
17+
18+
19+
@pytest.mark.parametrize(
20+
"func",
21+
[
22+
fft.fft,
23+
fft.ifft,
24+
fft.fft2,
25+
fft.ifft2,
26+
fft.fftn,
27+
fft.ifftn,
28+
fft.rfft,
29+
fft.irfft,
30+
fft.rfft2,
31+
fft.irfft2,
32+
fft.rfftn,
33+
fft.irfftn,
34+
# TODO: fft.hfft, fft.ihfft, fft.hfft2, fft.ihfft2, fft.hfftn, fft.ihfftn,
35+
# TODO: fft.dct, fft.idct, fft.dctn, fft.idctn,
36+
# TODO: fft.dst, fft.idst, fft.dstn, fft.idstn,
37+
],
38+
)
39+
@pytest.mark.parametrize("workers", [2, -1])
40+
def test_threaded_same(x, func, workers):
41+
expected = func(x, workers=1)
42+
actual = func(x, workers=workers)
43+
assert_allclose(actual, expected)
44+
45+
46+
def _mt_fft(x):
47+
return fft.fft(x, workers=2)
48+
49+
50+
@pytest.mark.slow
51+
def test_mixed_threads_processes(x):
52+
# Test that the fft threadpool is safe to use before & after fork
53+
54+
expect = fft.fft(x, workers=2)
55+
56+
with multiprocessing.Pool(2) as p:
57+
res = p.map(_mt_fft, [x for _ in range(4)])
58+
59+
for r in res:
60+
assert_allclose(r, expect)
61+
62+
fft.fft(x, workers=2)
63+
64+
65+
def test_invalid_workers(x):
66+
cpus = os.cpu_count()
67+
68+
fft.ifft([1], workers=-cpus)
69+
70+
with pytest.raises(ValueError, match="workers must not be zero"):
71+
fft.fft(x, workers=0)
72+
73+
with pytest.raises(ValueError, match="workers value out of range"):
74+
fft.ifft(x, workers=-cpus - 1)
75+
76+
77+
def test_set_get_workers():
78+
cpus = os.cpu_count()
79+
# scipy default is 1 but mkl_fft default is max number of threads
80+
assert fft.get_workers() == cpus
81+
with fft.set_workers(4):
82+
assert fft.get_workers() == 4
83+
84+
with fft.set_workers(-1):
85+
assert fft.get_workers() == cpus
86+
87+
assert fft.get_workers() == 4
88+
89+
# scipy default is 1 but mkl_fft default is max number of threads
90+
assert fft.get_workers() == cpus
91+
92+
with fft.set_workers(-cpus):
93+
assert fft.get_workers() == 1
94+
95+
96+
def test_set_workers_invalid():
97+
98+
with pytest.raises(ValueError): # , match='workers must not be zero'):
99+
with fft.set_workers(0):
100+
pass
101+
102+
with pytest.raises(ValueError): # , match='workers value out of range'):
103+
with fft.set_workers(-os.cpu_count() - 1):
104+
pass

0 commit comments

Comments
 (0)