From d60c3b66cd123dd754b25c131d750a39d473657c Mon Sep 17 00:00:00 2001
From: Matt Graham
Date: Tue, 17 Dec 2024 11:09:49 +0000
Subject: [PATCH] Add CLI command for combining batch runs (#1542)

* Add CLI command for combining batch runs

* Error message typo fix

* Exclude non-directories when iterating
---
Review revisions relative to the original patch:
  * Keep a running next-run number per draw while copying, so that merging
    two or more additional results directories no longer fails with
    FileExistsError when the second directory is processed.
  * Copy source runs in numeric rather than string order of their names.
  * Report non-integer draw / run directory names as a usage error rather
    than an uncaught ValueError.
  * The blob ids on the index line below are those of the original patch.

 src/tlo/cli.py | 108 ++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 107 insertions(+), 1 deletion(-)

diff --git a/src/tlo/cli.py b/src/tlo/cli.py
index 6824dd1045..1404088d76 100644
--- a/src/tlo/cli.py
+++ b/src/tlo/cli.py
@@ -8,6 +8,7 @@
 import tempfile
 from collections import defaultdict
 from pathlib import Path
+from shutil import copytree
 from typing import Dict
 
 import click
@@ -40,6 +41,7 @@ def cli(ctx, config_file, verbose):
     * submit scenarios to batch system
     * query batch system about job and tasks
     * download output results for completed job
+    * combine runs from multiple batch jobs with same draws
     """
     ctx.ensure_object(dict)
     ctx.obj["config_file"] = config_file
@@ -844,5 +846,109 @@ def add_tasks(batch_service_client, user_identity, job_id,
     batch_service_client.task.add_collection(job_id, tasks)
 
 
-if __name__ == '__main__':
+def _numbered_subdirectories(directory: Path) -> list[int]:
+    """Return the sorted integer names of the subdirectories of a directory.
+
+    Entries which are not directories are ignored. A ``click.UsageError`` is
+    raised if a subdirectory name is not the canonical decimal form of an
+    integer, as such a name cannot be mapped back to a path from its number.
+    """
+    numbers = []
+    for path in directory.iterdir():
+        if not path.is_dir():
+            continue
+        try:
+            number = int(path.name)
+        except ValueError:
+            number = None
+        if number is None or str(number) != path.name:
+            msg = f"Expected directory with an integer name but found {path}."
+            raise click.UsageError(msg)
+        numbers.append(number)
+    return sorted(numbers)
+
+
+def _numeric_name_order(path: Path) -> tuple[bool, int, str]:
+    """Sort key placing integer-named paths in numeric order before others."""
+    name = path.name
+    if name.isascii() and name.isdecimal():
+        return (False, int(name), name)
+    return (True, 0, name)
+
+
+@cli.command()
+@click.argument(
+    "output_results_directory",
+    type=click.Path(exists=True, file_okay=False, writable=True, path_type=Path),
+)
+@click.argument(
+    "additional_result_directories",
+    nargs=-1,
+    type=click.Path(exists=True, file_okay=False, path_type=Path),
+)
+def combine_runs(
+    output_results_directory: Path,
+    additional_result_directories: tuple[Path, ...],
+) -> None:
+    """Combine runs from multiple batch jobs locally.
+
+    Merges runs from each draw in one or more additional results directories
+    into corresponding draws in output results directory.
+
+    All results directories must contain same draw numbers and the draw numbers
+    must be consecutive integers starting from 0. All run numbers in the output
+    result directory draw directories must be consecutive integers starting
+    from 0.
+
+    Runs are copied, in numeric order of their directory names, to the next
+    free run numbers of each draw in the output results directory.
+    """
+    if not additional_result_directories:
+        msg = "One or more additional results directories to merge must be specified"
+        raise click.UsageError(msg)
+    results_directories = (output_results_directory, *additional_result_directories)
+    draws_per_directory = [
+        _numbered_subdirectories(results_directory)
+        for results_directory in results_directories
+    ]
+    draws = draws_per_directory[0]
+    expected_draws = list(range(len(draws)))
+    if any(other_draws != expected_draws for other_draws in draws_per_directory):
+        msg = (
+            "All results directories must contain same draws, "
+            "consecutively numbered from 0."
+        )
+        raise click.UsageError(msg)
+    # Next free run number for each draw in the output directory. This is kept
+    # up to date while copying so that runs from a second or later additional
+    # results directory continue the numbering rather than colliding with runs
+    # already copied from an earlier one.
+    next_run_per_draw = []
+    for draw in draws:
+        runs = _numbered_subdirectories(output_results_directory / str(draw))
+        if runs != list(range(len(runs))):
+            msg = "All runs in output directory must be consecutively numbered from 0."
+            raise click.UsageError(msg)
+        next_run_per_draw.append(len(runs))
+    for results_directory in additional_result_directories:
+        for draw in draws:
+            # Order source runs numerically so that run 10 follows run 9 rather
+            # than run 1, as it would if ordering the names as strings.
+            source_paths = sorted(
+                (
+                    path
+                    for path in (results_directory / str(draw)).iterdir()
+                    if path.is_dir()
+                ),
+                key=_numeric_name_order,
+            )
+            for source_path in source_paths:
+                destination_path = (
+                    output_results_directory / str(draw) / str(next_run_per_draw[draw])
+                )
+                copytree(source_path, destination_path)
+                next_run_per_draw[draw] += 1
+
+
+if __name__ == "__main__":
     cli(obj={})