Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Update check_n_jobs #14

Merged
merged 5 commits into from
Feb 14, 2023
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions sktime/utils/validation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,24 +126,28 @@ def all_inputs_are_time_like(args: list) -> bool:
def check_n_jobs(n_jobs: int) -> int:
"""Check `n_jobs` parameter according to the scikit-learn convention.

https://scikit-learn.org/stable/glossary.html#term-n_jobs

Parameters
----------
n_jobs : int, positive or -1
n_jobs : int or None
The number of jobs for parallelization.
If None or 0, 1 is used.
If negative, (n_cpus + 1 + n_jobs) is used. In such a case, -1 would use all
available CPUs and -2 would use all but one. If the number of CPUs used would
fall under 1, 1 is returned instead.

Returns
-------
n_jobs : int
Checked number of jobs.
The number of threads to be used.
"""
# scikit-learn convention
# https://scikit-learn.org/stable/glossary.html#term-n-jobs
if n_jobs is None:
if n_jobs is None or n_jobs == 0:
return 1
elif not is_int(n_jobs):
raise ValueError(f"`n_jobs` must be None or an integer, but found: {n_jobs}")
elif n_jobs < 0:
return os.cpu_count() - n_jobs + 1
return max(1, os.cpu_count() + 1 + n_jobs)
else:
return n_jobs

Expand Down