Merge pull request #20 from NREL/pp/batch_background_fix
Batch + monitor + background fix
ppinchuk authored Sep 7, 2023
2 parents a135c89 + fb1b09f commit c040ad8
Showing 7 changed files with 51 additions and 17 deletions.
31 changes: 22 additions & 9 deletions gaps/batch.py
@@ -136,6 +136,7 @@ def _run_pipelines(self, monitor_background=False):
"""Run the pipeline modules for each batch job."""

for sub_directory in self.sub_dirs:
os.chdir(sub_directory)
pipeline_config = sub_directory / self._pipeline_fp.name
if not pipeline_config.is_file():
raise gapsConfigError(
@@ -216,7 +217,11 @@ def run(self, dry_run=False, monitor_background=False):
if dry_run:
return

self._run_pipelines(monitor_background=monitor_background)
cwd = os.getcwd()
try:
self._run_pipelines(monitor_background=monitor_background)
finally:
os.chdir(cwd)


def _load_batch_config(config_fp):
@@ -358,19 +363,27 @@ def _parse_config(config):

for key, value in args.items():
if isinstance(value, str):
msg = ('Batch arguments should be lists but found '
f'"{key}": "{value}"')
msg = (
"Batch arguments should be lists but found "
f"{key!r}: {value!r}"
)
raise gapsValueError(msg)

sets.add(set_tag)

products = _enumerated_product(args.values())
set_str = f' in set "{set_tag}"' if set_tag else ''
logger.info(f'Found {len(products)} batch projects{set_str}. '
'Creating jobs...')
if len(products) > 1e3:
msg = (f'Large number of batch jobs found: {len(products)}! '
'Proceeding, but consider double checking your config.')
num_batch_jobs = len(products)
set_str = f" in set {set_tag!r}" if set_tag else ""
logger.info(
"Found %d batch projects%s. Creating jobs...",
num_batch_jobs,
set_str,
)
if num_batch_jobs > 1e3:
msg = (
f"Large number of batch jobs found: {num_batch_jobs:,}! "
"Proceeding, but consider double checking your config."
)
warn(msg)
logger.warning(msg)

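Note on the `gaps/batch.py` change: `_run_pipelines` now `os.chdir`s into each batch sub-directory before kicking off its pipeline (so background monitors and relative paths resolve against that sub-directory), and `run` wraps the call in `try`/`finally` so the caller's working directory is always restored. A minimal sketch of that pattern, where `run_pipeline` and the config filename are stand-ins rather than the real GAPs API:

```python
import os
from pathlib import Path


def run_all(sub_dirs, run_pipeline):
    """Kick off a pipeline from inside each sub-directory, then restore the cwd."""
    cwd = os.getcwd()
    try:
        for sub_directory in map(Path, sub_dirs):
            os.chdir(sub_directory)  # background/monitor processes inherit this cwd
            run_pipeline(sub_directory / "config_pipeline.json")
    finally:
        os.chdir(cwd)  # restored even if one of the kickoffs raises
```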
3 changes: 1 addition & 2 deletions gaps/cli/pipeline.py
@@ -3,7 +3,6 @@
GAPs pipeline CLI entry points.
"""
import os
import sys
import logging
from pathlib import Path
from warnings import warn
@@ -77,6 +76,7 @@ def pipeline(ctx, config_file, cancel, monitor, background=False):
raise gapsExecutionError(msg)
ctx.obj["LOG_STREAM"] = False
_kickoff_background(config_file)
return

project_dir = str(Path(config_file).parent.expanduser().resolve())
status = Status(project_dir).update_from_all_job_files(purge=False)
@@ -107,7 +107,6 @@ def _kickoff_background(config_file): # pragma: no cover
click.echo(
f"Kicking off pipeline job in the background. Monitor PID: {pid}"
)
sys.exit()


def pipeline_command(template_config):
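Note on the `gaps/cli/pipeline.py` change: `sys.exit()` is removed from `_kickoff_background`, and `pipeline` instead returns early once the background process has been started, so in-process callers (such as the batch runner) are not terminated mid-loop. A rough sketch of the resulting control flow, with a simplified signature (the real command takes a click context and more options):

```python
import click


def _kickoff_background(config_file):
    """Placeholder for the real kickoff, which spawns a detached monitor process."""
    click.echo(f"Kicking off pipeline job in the background for {config_file!r}")


def pipeline(config_file, background=False):
    """Sketch: an early return replaces the old sys.exit() after a background kickoff."""
    if background:
        _kickoff_background(config_file)
        return  # nothing more to do in this process; let the caller continue

    # ... foreground execution / monitoring would continue here ...
```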
10 changes: 8 additions & 2 deletions gaps/pipeline.py
@@ -79,9 +79,15 @@ def _cancel_all_jobs(self):
"""Cancel all jobs in this pipeline."""
status = self.status
for job_id, hardware in zip(status.job_ids, status.job_hardware):
if job_id is None:
continue

manager = HardwareOption(hardware).manager
if manager is not None:
manager.cancel(job_id)
if manager is None:
continue

manager.cancel(job_id)

logger.info("Pipeline job %r cancelled.", self.name)

def _main(self):
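Note on the `gaps/pipeline.py` change: `_cancel_all_jobs` gains guard clauses so that entries without a job ID, and hardware without a cancellable manager, are skipped rather than passed to `manager.cancel`. A condensed sketch of the loop, with `manager_for_hardware` standing in for the real `HardwareOption(hardware).manager` lookup:

```python
def cancel_all_jobs(status, manager_for_hardware):
    """Cancel every job that has both a job ID and a hardware manager."""
    for job_id, hardware in zip(status.job_ids, status.job_hardware):
        if job_id is None:  # e.g. a step that was never actually submitted
            continue

        manager = manager_for_hardware(hardware)
        if manager is None:  # e.g. local hardware with nothing to cancel
            continue

        manager.cancel(job_id)
```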
2 changes: 1 addition & 1 deletion gaps/version.py
@@ -1,3 +1,3 @@
"""GAPs Version Number. """

__version__ = "0.4.3"
__version__ = "0.4.4"
2 changes: 2 additions & 0 deletions tests/cli/test_cli.py
@@ -400,6 +400,8 @@ def preprocess_run_config(config, project_dir, out_dir):
main, ["pipeline", "-c", pipe_config_fp.as_posix(), "--background"]
)

time.sleep(10) # give job enough time to run a little

status = Status(tmp_cwd).update_from_all_job_files(purge=False)
assert "monitor_pid" in status

11 changes: 8 additions & 3 deletions tests/test_batch.py
@@ -390,19 +390,24 @@ def test_batch_job_run(typical_batch_config, monkeypatch):
assert count_0 == 8, "Unknown starting files detected!"

config_cache = []
working_dirs = []

def _test_call(config, monitor, *__, **___):
assert not monitor
config_cache.append(config)
working_dirs.append(os.getcwd())

monkeypatch.setattr(
gaps.pipeline.Pipeline, "run", _test_call, raising=True
)

cwd = os.getcwd()
BatchJob(typical_batch_config).run()
assert len(config_cache) == 9
assert set(fp.name for fp in config_cache) == {"config_pipeline.json"}
assert len(set(fp.parent for fp in config_cache)) == 9
assert len(set(working_dirs)) == 9
assert cwd == os.getcwd()

BatchJob(typical_batch_config).delete()
count_2 = len(list(batch_dir.glob("*")))
@@ -548,16 +553,16 @@ def test_bad_str_arg(typical_batch_config):
batch_dir = typical_batch_config.parent

config = ConfigType.JSON.load(typical_batch_config)
config['sets'][0]['args']['project_points'] = 'bad_str'
with open(typical_batch_config, 'w') as f:
config["sets"][0]["args"]["project_points"] = "bad_str"
with open(typical_batch_config, "w") as f:
ConfigType.JSON.dump(config, f)

count_0 = len(list(batch_dir.glob("*")))
assert count_0 == 8, "Unknown starting files detected!"

with pytest.raises(gapsValueError) as exc_info:
BatchJob(typical_batch_config).run(dry_run=True)
assert 'Batch arguments should be lists' in str(exc_info)
assert "Batch arguments should be lists" in str(exc_info)


if __name__ == "__main__":
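Note on the `tests/test_batch.py` change: `Pipeline.run` is monkeypatched with a recorder so the test can assert that each of the nine batch jobs was kicked off from its own sub-directory and that the original working directory is restored after `BatchJob.run()`. A sketch of the recording helper (names mirror the test but are illustrative):

```python
import os


def make_recording_run(config_cache, working_dirs):
    """Build a Pipeline.run stand-in that records the config and the cwd at call time."""

    def _test_call(config, monitor, *__, **___):
        assert not monitor
        config_cache.append(config)
        working_dirs.append(os.getcwd())  # one entry per batch sub-directory

    return _test_call
```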
9 changes: 9 additions & 0 deletions tests/test_pipeline.py
@@ -545,6 +545,15 @@ def cache_cancel_calls(__, job_id):
StatusField.HARDWARE: HardwareOption.EAGLE,
},
)
Status.mark_job_as_submitted(
sample_pipeline_config.parent,
"run",
"test4",
job_attrs={
StatusField.JOB_ID: None,
StatusField.HARDWARE: HardwareOption.EAGLE,
},
)
Pipeline.cancel_all(sample_pipeline_config)
assert set(cancelled_jobs) == {1, 12}
assert_message_was_logged("Pipeline job", "INFO")
