diff --git a/snakemake/__init__.py b/snakemake/__init__.py
index 3655305ea..b7481d380 100644
--- a/snakemake/__init__.py
+++ b/snakemake/__init__.py
@@ -23,7 +23,12 @@
 
 from snakemake.workflow import Workflow
 from snakemake.dag import Batch
-from snakemake.exceptions import ResourceScopesException, print_exception, WorkflowError
+from snakemake.exceptions import (
+    CliException,
+    ResourceScopesException,
+    print_exception,
+    WorkflowError,
+)
 from snakemake.logging import setup_logger, logger, SlackLogger, WMSLogger
 from snakemake.io import load_configfile, wait_for_files
 from snakemake.shell import shell
@@ -1025,6 +1030,60 @@ def parse_config(args):
     return config
 
 
+def parse_cores(cores, allow_none=False):
+    if cores is None:
+        if allow_none:
+            return cores
+        raise CliException(
+            "Error: you need to specify the maximum number of CPU cores to "
+            "be used at the same time. If you want to use N cores, say --cores N "
+            "or -cN. For all cores on your system (be sure that this is "
+            "appropriate) use --cores all. For no parallelization use --cores 1 or "
+            "-c1."
+        )
+    if cores == "all":
+        return available_cpu_count()
+    try:
+        return int(cores)
+    except ValueError:
+        raise CliException(
+            "Error parsing number of cores (--cores, -c, -j): must be integer, "
+            "empty, or 'all'."
+        )
+
+
+def parse_jobs(jobs, allow_none=False):
+    if jobs is None:
+        if allow_none:
+            return jobs
+        raise CliException(
+            "Error: you need to specify the maximum number of jobs to "
+            "be queued or executed at the same time with --jobs or -j.",
+        )
+    if jobs == "unlimited":
+        return sys.maxsize
+    try:
+        return int(jobs)
+    except ValueError:
+        raise CliException(
+            "Error parsing number of jobs (--jobs, -j): must be integer.",
+        )
+
+
+def parse_cores_jobs(cores, jobs, no_exec, non_local_exec, dryrun):
+    if no_exec or dryrun:
+        cores = parse_cores(cores, allow_none=True) or 1
+        jobs = parse_jobs(jobs, allow_none=True) or 1
+    elif non_local_exec:
+        cores = parse_cores(cores, allow_none=True)
+        jobs = parse_jobs(jobs)
+    else:
+        cores = parse_cores(cores or jobs)
+        jobs = None
+
+    return cores, jobs
+
+
 def get_profile_file(profile, file, return_default=False):
     dirs = get_appdirs()
     if os.path.exists(profile):
@@ -2666,65 +2725,16 @@ def adjust_path(f):
         or args.unlock
         or args.cleanup_metadata
     )
 
-    local_exec = not (no_exec or non_local_exec)
-
-    def parse_cores(cores):
-        if cores == "all":
-            return available_cpu_count()
-        else:
-            try:
-                return int(cores)
-            except ValueError:
-                print(
-                    "Error parsing number of cores (--cores, -c, -j): must be integer, empty, or 'all'.",
-                    file=sys.stderr,
-                )
-                sys.exit(1)
-    if args.cores is not None:
-        args.cores = parse_cores(args.cores)
-        if local_exec:
-            # avoid people accidentally setting jobs as well
-            args.jobs = None
-    else:
-        if no_exec:
-            args.cores = 1
-        elif local_exec:
-            if args.jobs is not None:
-                args.cores = parse_cores(args.jobs)
-                args.jobs = None
-            elif args.dryrun:
-                # dryrun with single core if nothing specified
-                args.cores = 1
-            else:
-                print(
-                    "Error: you need to specify the maximum number of CPU cores to "
-                    "be used at the same time. If you want to use N cores, say --cores N or "
-                    "-cN. For all cores on your system (be sure that this is appropriate) "
-                    "use --cores all. For no parallelization use --cores 1 or -c1.",
-                    file=sys.stderr,
-                )
-                sys.exit(1)
-
-    if non_local_exec:
-        if args.jobs is None:
-            print(
-                "Error: you need to specify the maximum number of jobs to "
-                "be queued or executed at the same time with --jobs or -j.",
-                file=sys.stderr,
-            )
-            sys.exit(1)
-        elif args.jobs == "unlimited":
-            args.jobs = sys.maxsize
-        else:
-            try:
-                args.jobs = int(args.jobs)
-            except ValueError:
-                print(
-                    "Error parsing number of jobs (--jobs, -j): must be integer.",
-                    file=sys.stderr,
-                )
-                sys.exit(1)
+    try:
+        cores, jobs = parse_cores_jobs(
+            args.cores, args.jobs, no_exec, non_local_exec, args.dryrun
+        )
+        args.cores = cores
+        args.jobs = jobs
+    except CliException as err:
+        print(err.msg, file=sys.stderr)
+        sys.exit(1)
 
     if args.drmaa_log_dir is not None:
         if not os.path.isabs(args.drmaa_log_dir):
diff --git a/snakemake/exceptions.py b/snakemake/exceptions.py
index 5f9dc0c7e..9cd48483f 100644
--- a/snakemake/exceptions.py
+++ b/snakemake/exceptions.py
@@ -568,3 +568,9 @@ def __init__(self, msg, invalid_resources):
         super().__init__(msg, invalid_resources)
         self.msg = msg
         self.invalid_resources = invalid_resources
+
+
+class CliException(Exception):
+    def __init__(self, msg):
+        super().__init__(msg)
+        self.msg = msg
diff --git a/tests/tests.py b/tests/tests.py
index d9273e746..59f553507 100644
--- a/tests/tests.py
+++ b/tests/tests.py
@@ -9,6 +9,10 @@
 import subprocess as sp
 from pathlib import Path
 
+from snakemake import parse_cores_jobs
+from snakemake.exceptions import CliException
+from snakemake.utils import available_cpu_count
+
 sys.path.insert(0, os.path.dirname(__file__))
 
 from .common import *
@@ -1513,6 +1517,59 @@ def test_env_modules():
     run(dpath("test_env_modules"), use_env_modules=True)
 
 
+class TestParseCoresJobs:
+    def run_test(self, func, ref):
+        if ref is None:
+            with pytest.raises(CliException):
+                func()
+            return
+        assert func() == ref
+
+    @pytest.mark.parametrize(
+        ("input", "output"),
+        [
+            [(1, 1), (1, 1)],
+            [(4, 4), (4, 4)],
+            [(None, None), (1, 1)],
+            [("all", "unlimited"), (available_cpu_count(), sys.maxsize)],
+        ],
+    )
+    def test_no_exec(self, input, output):
+        self.run_test(lambda: parse_cores_jobs(*input, True, False, False), output)
+        # Test dryrun separately
+        self.run_test(lambda: parse_cores_jobs(*input, False, False, True), output)
+
+    @pytest.mark.parametrize(
+        ("input", "output"),
+        [
+            [(1, 1), (1, 1)],
+            [(4, 4), (4, 4)],
+            [(None, 1), (None, 1)],
+            [(None, None), None],
+            [(1, None), None],
+            [("all", "unlimited"), (available_cpu_count(), sys.maxsize)],
+        ],
+    )
+    def test_non_local_job(self, input, output):
+        self.run_test(lambda: parse_cores_jobs(*input, False, True, False), output)
+
+    @pytest.mark.parametrize(
+        ("input", "output"),
+        [
+            [(1, 1), (1, None)],
+            [(4, 4), (4, None)],
+            [(None, 1), (1, None)],
+            [(None, None), None],
+            [(1, None), (1, None)],
+            [(None, "all"), (available_cpu_count(), None)],
+            [(None, "unlimited"), None],
+            [("all", "unlimited"), (available_cpu_count(), None)],
+        ],
+    )
+    def test_local_job(self, input, output):
+        self.run_test(lambda: parse_cores_jobs(*input, False, False, False), output)
+
+
 @skip_on_windows
 @connected
 def test_container():
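
A quick illustrative sketch (not part of the patch) of how `parse_cores_jobs` resolves `--cores`/`--jobs` for the three execution modes, mirroring the parametrized tests above; the positional flags are `no_exec`, `non_local_exec`, `dryrun`, in that order, as in the test calls:

```python
# Illustrative only -- behaviour as encoded in TestParseCoresJobs above.
from snakemake import parse_cores_jobs
from snakemake.utils import available_cpu_count

# Local execution: --cores is required; a --jobs value is folded into cores.
assert parse_cores_jobs("all", None, False, False, False) == (available_cpu_count(), None)

# Non-local (cluster/cloud) execution: --jobs is required, --cores may stay unset.
assert parse_cores_jobs(None, 10, False, True, False) == (None, 10)

# Dry runs and no-exec invocations fall back to a single core/job.
assert parse_cores_jobs(None, None, True, False, False) == (1, 1)
```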