Skip to content

Commit

Permalink
added protection against bad batch inputs with test
Browse files Browse the repository at this point in the history
  • Loading branch information
grantbuster committed Sep 1, 2023
1 parent 6e9092d commit 3e74d86
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 3 deletions.
23 changes: 21 additions & 2 deletions gaps/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self, config):
config : str
File path to config json or csv (str).
"""

self._job_tags = None
self._base_dir, config = _load_batch_config(config)
self._pipeline_fp = Path(config["pipeline_config"])
Expand Down Expand Up @@ -338,7 +339,7 @@ def _check_sets(config, base_dir):

def _enumerated_product(args):
"""An enumerated product function."""
yield from zip(product(*(range(len(x)) for x in args)), product(*args))
return list(zip(product(*(range(len(x)) for x in args)), product(*args)))


def _parse_config(config):
Expand All @@ -350,12 +351,30 @@ def _parse_config(config):
for batch_set in config["sets"]:
set_tag = batch_set.get("set_tag", "")
args = batch_set["args"]

if set_tag in sets:
msg = f"Found multiple sets with the same set_tag: {set_tag!r}"
raise gapsValueError(msg)

for key, value in args.items():
if isinstance(value, str):
msg = ('Batch arguments should be lists but found '
f'"{key}": "{value}"')
raise gapsValueError(msg)

sets.add(set_tag)

for inds, comb in _enumerated_product(args.values()):
products = _enumerated_product(args.values())
set_str = f' in set "{set_tag}"' if set_tag else ''
logger.info(f'Found {len(products)} batch projects{set_str}. '
'Creating jobs...')
if len(products) > 1e3:
msg = (f'Large number of batch jobs found: {len(products)}! '
'Proceeding, but consider double checking your config.')
warn(msg)
logger.warning(msg)

for inds, comb in products:
arg_combo = dict(zip(args, comb))
arg_inds = dict(zip(args, inds))
tag_arg_comb = {
Expand Down
16 changes: 15 additions & 1 deletion gaps/cli/documentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,16 @@
Parameters
----------
logging : dict, optional
Dictionary containing keyword-argument pairs to pass to
`init_logger <https://tinyurl.com/47hakp7f/>`_. This
initializes logging for the submission portion of the
pipeline. Note, however, that each step (command) will
**also** record the submission step log output to a
common "project" log file, so it's only ever necessary
to use this input if you want a different (lower) level
of verbosity than the `log_level` specified in the
config for the step of the pipeline being executed.
pipeline_config : str
Path to the pipeline configuration defining the commands to
run for every parametric set.
Expand Down Expand Up @@ -163,8 +173,8 @@
include an underscore, as that is provided during
concatenation.
"""

_BATCH_ARGS_DICT = """.. tabs::
.. group-tab:: JSON/JSON5
Expand Down Expand Up @@ -602,6 +612,10 @@ def _batch_command_help(): # pragma: no cover

format_inputs = {}
template_config = {
"logging": {
"log_file": None,
"log_level": "INFO"
},
"pipeline_config": CommandDocumentation.REQUIRED_TAG,
"sets": [
{
Expand Down
19 changes: 19 additions & 0 deletions tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,5 +541,24 @@ def test_batch_csv_setup(csv_batch_config):
assert count_2 == count_0, "Batch did not clear all job files!"


def test_bad_str_arg(typical_batch_config):
"""Test that a string in a batch argument will raise an error (argument
parameterizations should be lists)"""

batch_dir = typical_batch_config.parent

config = ConfigType.JSON.load(typical_batch_config)
config['sets'][0]['args']['project_points'] = 'bad_str'
with open(typical_batch_config, 'w') as f:
ConfigType.JSON.dump(config, f)

count_0 = len(list(batch_dir.glob("*")))
assert count_0 == 8, "Unknown starting files detected!"

with pytest.raises(gapsValueError) as exc_info:
BatchJob(typical_batch_config).run(dry_run=True)
assert 'Batch arguments should be lists' in str(exc_info)


if __name__ == "__main__":
pytest.main(["-q", "--show-capture=all", Path(__file__), "-rapP"])

0 comments on commit 3e74d86

Please sign in to comment.