From 3e74d8680288e4e8536f52f7de35cc7ead8e9d91 Mon Sep 17 00:00:00 2001 From: grantbuster Date: Fri, 1 Sep 2023 14:31:36 -0600 Subject: [PATCH] added protection against bad batch inputs with test --- gaps/batch.py | 23 +++++++++++++++++++++-- gaps/cli/documentation.py | 16 +++++++++++++++- tests/test_batch.py | 19 +++++++++++++++++++ 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/gaps/batch.py b/gaps/batch.py index 7b9f67e7..445c60ea 100644 --- a/gaps/batch.py +++ b/gaps/batch.py @@ -48,6 +48,7 @@ def __init__(self, config): config : str File path to config json or csv (str). """ + self._job_tags = None self._base_dir, config = _load_batch_config(config) self._pipeline_fp = Path(config["pipeline_config"]) @@ -338,7 +339,7 @@ def _check_sets(config, base_dir): def _enumerated_product(args): """An enumerated product function.""" - yield from zip(product(*(range(len(x)) for x in args)), product(*args)) + return list(zip(product(*(range(len(x)) for x in args)), product(*args))) def _parse_config(config): @@ -350,12 +351,30 @@ def _parse_config(config): for batch_set in config["sets"]: set_tag = batch_set.get("set_tag", "") args = batch_set["args"] + if set_tag in sets: msg = f"Found multiple sets with the same set_tag: {set_tag!r}" raise gapsValueError(msg) + + for key, value in args.items(): + if isinstance(value, str): + msg = ('Batch arguments should be lists but found ' + f'"{key}": "{value}"') + raise gapsValueError(msg) + sets.add(set_tag) - for inds, comb in _enumerated_product(args.values()): + products = _enumerated_product(args.values()) + set_str = f' in set "{set_tag}"' if set_tag else '' + logger.info(f'Found {len(products)} batch projects{set_str}. ' + 'Creating jobs...') + if len(products) > 1e3: + msg = (f'Large number of batch jobs found: {len(products)}! ' + 'Proceeding, but consider double checking your config.') + warn(msg) + logger.warning(msg) + + for inds, comb in products: arg_combo = dict(zip(args, comb)) arg_inds = dict(zip(args, inds)) tag_arg_comb = { diff --git a/gaps/cli/documentation.py b/gaps/cli/documentation.py index fd13557b..7487af0c 100644 --- a/gaps/cli/documentation.py +++ b/gaps/cli/documentation.py @@ -132,6 +132,16 @@ Parameters ---------- +logging : dict, optional + Dictionary containing keyword-argument pairs to pass to + `init_logger `_. This + initializes logging for the submission portion of the + pipeline. Note, however, that each step (command) will + **also** record the submission step log output to a + common "project" log file, so it's only ever necessary + to use this input if you want a different (lower) level + of verbosity than the `log_level` specified in the + config for the step of the pipeline being executed. pipeline_config : str Path to the pipeline configuration defining the commands to run for every parametric set. @@ -163,8 +173,8 @@ include an underscore, as that is provided during concatenation. - """ + _BATCH_ARGS_DICT = """.. tabs:: .. group-tab:: JSON/JSON5 @@ -602,6 +612,10 @@ def _batch_command_help(): # pragma: no cover format_inputs = {} template_config = { + "logging": { + "log_file": None, + "log_level": "INFO" + }, "pipeline_config": CommandDocumentation.REQUIRED_TAG, "sets": [ { diff --git a/tests/test_batch.py b/tests/test_batch.py index a66703f8..171aacb4 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -541,5 +541,24 @@ def test_batch_csv_setup(csv_batch_config): assert count_2 == count_0, "Batch did not clear all job files!" +def test_bad_str_arg(typical_batch_config): + """Test that a string in a batch argument will raise an error (argument + parameterizations should be lists)""" + + batch_dir = typical_batch_config.parent + + config = ConfigType.JSON.load(typical_batch_config) + config['sets'][0]['args']['project_points'] = 'bad_str' + with open(typical_batch_config, 'w') as f: + ConfigType.JSON.dump(config, f) + + count_0 = len(list(batch_dir.glob("*"))) + assert count_0 == 8, "Unknown starting files detected!" + + with pytest.raises(gapsValueError) as exc_info: + BatchJob(typical_batch_config).run(dry_run=True) + assert 'Batch arguments should be lists' in str(exc_info) + + if __name__ == "__main__": pytest.main(["-q", "--show-capture=all", Path(__file__), "-rapP"])