diff --git a/mne/parallel.py b/mne/parallel.py index 225f49e313a..95798a52e7f 100644 --- a/mne/parallel.py +++ b/mne/parallel.py @@ -51,14 +51,33 @@ def parallel_func(func, n_jobs, verbose=None): parallel_verbose = 5 if logger.level <= logging.INFO else 0 parallel = Parallel(n_jobs, verbose=parallel_verbose) my_func = delayed(func) + n_jobs = check_n_jobs(n_jobs) + return parallel, my_func, n_jobs - if n_jobs == -1: - try: - import multiprocessing - n_jobs = multiprocessing.cpu_count() - except ImportError: - logger.warn('multiprocessing not installed. Cannot run in ' - 'parallel.') - n_jobs = 1 - return parallel, my_func, n_jobs +def check_n_jobs(n_jobs): + """Check n_jobs in particular for negative values + + Parameters + ---------- + n_jobs : int + The number of jobs + + Returns + ------- + n_jobs : int + The checked number of jobs. Always positive. + """ + try: + import multiprocessing + n_cores = multiprocessing.cpu_count() + if n_cores + n_jobs <= 0: + raise ValueError('If n_jobs has a negative value it must not be less ' + 'than the number of CPUs present. You\'ve got ' + '%s CPUs' % n_cores) + n_jobs = n_cores + n_jobs + except ImportError: + logger.warn('multiprocessing not installed. Cannot run in ' + 'parallel.') + n_jobs = 1 + return n_jobs diff --git a/mne/stats/cluster_level.py b/mne/stats/cluster_level.py index 8ad485dbc85..58bd0c7a5d5 100755 --- a/mne/stats/cluster_level.py +++ b/mne/stats/cluster_level.py @@ -9,14 +9,13 @@ # License: Simplified BSD import numpy as np -from multiprocessing import cpu_count from scipy import stats, sparse, ndimage import logging logger = logging.getLogger('mne') from .parametric import f_oneway -from ..parallel import parallel_func +from ..parallel import parallel_func, check_n_jobs from ..utils import split_list from ..fixes import in1d, unravel_index from .. import verbose @@ -768,16 +767,7 @@ def permutation_cluster_1samp_test(X, threshold=1.67, n_permutations=1024, Journal of Neuroscience Methods, Vol. 164, No. 1., pp. 177-190. doi:10.1016/j.jneumeth.2007.03.024 """ - - # infer number of jobs and replace if user passes negative integer - if n_jobs <= 0: - n_cores = cpu_count() - if n_cores + n_jobs <= 0: - raise ValueError('If n_jobs has a negative value it must not be less ' - 'than the number of CPUs present. You\'ve got ' - '%s CPUs' % n_cores) - else: - n_jobs = n_cores + n_jobs + n_jobs = check_n_jobs(n_jobs) if not out_type in ['mask', 'indices']: raise ValueError('out_type must be either \'mask\' or \'indices\'') diff --git a/mne/stats/tests/test_cluster_level.py b/mne/stats/tests/test_cluster_level.py index 7055f2b9cf3..7503d132d67 100644 --- a/mne/stats/tests/test_cluster_level.py +++ b/mne/stats/tests/test_cluster_level.py @@ -8,8 +8,6 @@ permutation_cluster_1samp_test, \ spatio_temporal_cluster_1samp_test -from multiprocessing import cpu_count - noiselevel = 20 normfactor = np.hanning(20).sum() @@ -170,21 +168,6 @@ def test_cluster_permutation_t_test_with_connectivity(): connectivity=connectivity, max_step=1, threshold=1.67, n_jobs=-1000) - # test negative n_jobs arg - n_jobs = -(cpu_count() - 1) - n_jobs2 = 1 - out_1 = spatio_temporal_cluster_1samp_test( - condition1_3, n_permutations=1, - connectivity=connectivity, max_step=1, - n_jobs=n_jobs, - threshold=1.67) - out_2 = spatio_temporal_cluster_1samp_test( - condition1_3, n_permutations=1, - connectivity=connectivity, max_step=1, - n_jobs=n_jobs2, - threshold=1.67) - assert_array_almost_equal(out_1[2], out_2[2]) - def ttest_1samp(X): """Returns T-values