Skip to content

Commit

Permalink
Add CLI command for combining batch runs (#1542)
Browse files Browse the repository at this point in the history
* Add CLI command for combining batch runs

* Error message typo fix

* Exclude non-directories when iterating
  • Loading branch information
matt-graham authored Dec 17, 2024
1 parent a4fc9b2 commit d60c3b6
Showing 1 changed file with 67 additions and 1 deletion.
68 changes: 67 additions & 1 deletion src/tlo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import tempfile
from collections import defaultdict
from pathlib import Path
from shutil import copytree
from typing import Dict

import click
Expand Down Expand Up @@ -40,6 +41,7 @@ def cli(ctx, config_file, verbose):
* submit scenarios to batch system
* query batch system about job and tasks
* download output results for completed job
* combine runs from multiple batch jobs with same draws
"""
ctx.ensure_object(dict)
ctx.obj["config_file"] = config_file
Expand Down Expand Up @@ -844,5 +846,69 @@ def add_tasks(batch_service_client, user_identity, job_id,
batch_service_client.task.add_collection(job_id, tasks)


if __name__ == '__main__':
def _numbered_subdirectories(directory: Path, kind: str) -> list[int]:
    """Return the sorted integer names of the sub-directories of ``directory``.

    Entries which are not directories are ignored.

    :param directory: Directory whose immediate sub-directories are listed.
    :param kind: Label (for example ``"draw"`` or ``"run"``) used in the error message.
    :raises click.UsageError: If a sub-directory name cannot be parsed as an integer.
    """
    numbers = []
    for path in directory.iterdir():
        if not path.is_dir():
            continue
        try:
            numbers.append(int(path.name))
        except ValueError:
            # Report as a usage error rather than letting a raw traceback escape the CLI.
            msg = f"Unexpected non-integer {kind} directory name: {path}"
            raise click.UsageError(msg) from None
    return sorted(numbers)


@cli.command()
@click.argument(
    "output_results_directory",
    type=click.Path(exists=True, file_okay=False, writable=True, path_type=Path),
)
@click.argument(
    "additional_result_directories",
    nargs=-1,
    type=click.Path(exists=True, file_okay=False, path_type=Path),
)
def combine_runs(output_results_directory: Path, additional_result_directories: tuple[Path, ...]) -> None:
    """Combine runs from multiple batch jobs locally.

    Merges runs from each draw in one or more additional results directories
    into corresponding draws in output results directory. Merged runs are
    renumbered to follow on from the runs already in the output directory.

    All results directories must contain same draw numbers and the draw numbers
    must be consecutive integers starting from 0. All run numbers in the output
    result directory draw directories must be consecutive integers starting
    from 0.
    """
    if not additional_result_directories:
        msg = "One or more additional results directories to merge must be specified"
        raise click.UsageError(msg)

    # Every results directory must hold exactly the draws 0..N-1, with N taken
    # from the output directory.
    results_directories = (output_results_directory, *additional_result_directories)
    draws_per_directory = [
        _numbered_subdirectories(results_directory, "draw")
        for results_directory in results_directories
    ]
    draws = list(range(len(draws_per_directory[0])))
    if any(directory_draws != draws for directory_draws in draws_per_directory):
        msg = (
            "All results directories must contain same draws, "
            "consecutively numbered from 0."
        )
        raise click.UsageError(msg)

    # Validate the existing runs in the output directory before copying anything
    # and record, per draw, the next free run number.
    next_run_per_draw = {}
    for draw in draws:
        runs = _numbered_subdirectories(output_results_directory / str(draw), "run")
        if runs != list(range(len(runs))):
            msg = "All runs in output directory must be consecutively numbered from 0."
            raise click.UsageError(msg)
        next_run_per_draw[draw] = len(runs)

    for results_directory in additional_result_directories:
        for draw in draws:
            # The listing is materialised by ``sorted`` before any copy happens.
            for source_path in sorted((results_directory / str(draw)).iterdir()):
                if not source_path.is_dir():
                    continue
                destination_path = (
                    output_results_directory / str(draw) / str(next_run_per_draw[draw])
                )
                copytree(source_path, destination_path)
                # The counter persists across additional directories so that runs
                # from a later directory never target a destination already
                # created while merging an earlier one.
                next_run_per_draw[draw] += 1


if __name__ == "__main__":
    # Script entry point: start the CLI with an empty dict as the initial
    # context object.
    cli(obj=dict())

0 comments on commit d60c3b6

Please sign in to comment.