From 9450cabb5c1f9d6d951d438c654343d6503be921 Mon Sep 17 00:00:00 2001 From: rvandewater Date: Mon, 20 Feb 2023 17:23:42 +0100 Subject: [PATCH] Small fixes --- icu_benchmarks/models/train.py | 4 +--- icu_benchmarks/run.py | 9 ++++----- icu_benchmarks/run_utils.py | 2 ++ scripts/evaluate_results/aggregrate_experiment.py | 1 + 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/icu_benchmarks/models/train.py b/icu_benchmarks/models/train.py index ae283ec2..a0337f61 100644 --- a/icu_benchmarks/models/train.py +++ b/icu_benchmarks/models/train.py @@ -35,7 +35,7 @@ def train_common( test_on: str = Split.test, use_wandb: bool = True, cpu: bool = False, - num_workers: int = min(len(os.sched_getaffinity(0)), torch.cuda.device_count() * 4 * int(torch.cuda.is_available()), 16), + num_workers: int = min(torch.cuda.device_count() * 4 * int(torch.cuda.is_available()), 16), ): """Common wrapper to train all benchmarked models. @@ -44,13 +44,11 @@ def train_common( log_dir: Path to directory where model output should be saved. load_weights: If set to true, skip training and load weights from source_dir instead. source_dir: If set to load weights, path to directory containing trained weights. - seed: Common seed used for any random operation. reproducible: If set to true, set torch to run reproducibly. mode: Mode of the model. Can be one of the values of RunMode. model: Model to be trained. weight: Weight to be used for the loss function. optimizer: Optimizer to be used for training. - do_test: If set to true, evaluate the model on the test set. batch_size: Batch size to be used for training. epochs: Number of epochs to train for. patience: Number of epochs to wait before early stopping. diff --git a/icu_benchmarks/run.py b/icu_benchmarks/run.py index 76c9c01d..9952cd4f 100644 --- a/icu_benchmarks/run.py +++ b/icu_benchmarks/run.py @@ -4,7 +4,6 @@ import gin import logging import sys -import torch from pathlib import Path import importlib.util @@ -37,7 +36,7 @@ def main(my_args=tuple(sys.argv[1:])): log_fmt = "%(asctime)s - %(levelname)s: %(message)s" logging.basicConfig(format=log_fmt) - logging.getLogger().setLevel(logging.INFO) + logging.getLogger().setLevel(logging.DEBUG) load_weights = args.command == "evaluate" args.data_dir = Path(args.data_dir) @@ -141,6 +140,6 @@ def main(my_args=tuple(sys.argv[1:])): plot_aggregated_results(run_dir, "aggregated_test_metrics.json") -"""Main module.""" -if __name__ == "__main__": - main() +# """Main module.""" +# if __name__ == "__main__": +# main() diff --git a/icu_benchmarks/run_utils.py b/icu_benchmarks/run_utils.py index 1e5b00d5..78f04413 100644 --- a/icu_benchmarks/run_utils.py +++ b/icu_benchmarks/run_utils.py @@ -169,6 +169,8 @@ def log_full_line(msg: str, level: int = logging.INFO, char: str = "-", num_newl "{0:{char}^{width}}{1}".format(msg, "\n" * num_newlines, char=char, width=terminal_size.columns - reserved_chars), ) +# aggregate_results(Path(r"C:\Users\Robin\Downloads\2023-02-12T15-46-31")) + def load_pretrained_imputation_model(use_pretrained_imputation): if use_pretrained_imputation is not None and not Path(use_pretrained_imputation).exists(): logging.warning("the specified pretrained imputation model does not exist") diff --git a/scripts/evaluate_results/aggregrate_experiment.py b/scripts/evaluate_results/aggregrate_experiment.py index e21cc00b..c83e54bd 100644 --- a/scripts/evaluate_results/aggregrate_experiment.py +++ b/scripts/evaluate_results/aggregrate_experiment.py @@ -83,3 +83,4 @@ def aggregate_results( # Exclude datasets results = results[results["Model"].isin(models)] return results +print(aggregate_results(Path(r"C:\Users\Robin\Downloads\2023-02-12T15-46-31"), decimals=1)) \ No newline at end of file