Skip to content

Commit

Permalink
Small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rvandewater committed Feb 20, 2023
1 parent e752ecd commit 9450cab
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
4 changes: 1 addition & 3 deletions icu_benchmarks/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def train_common(
test_on: str = Split.test,
use_wandb: bool = True,
cpu: bool = False,
num_workers: int = min(len(os.sched_getaffinity(0)), torch.cuda.device_count() * 4 * int(torch.cuda.is_available()), 16),
num_workers: int = min(torch.cuda.device_count() * 4 * int(torch.cuda.is_available()), 16),
):
"""Common wrapper to train all benchmarked models.
Expand All @@ -44,13 +44,11 @@ def train_common(
log_dir: Path to directory where model output should be saved.
load_weights: If set to true, skip training and load weights from source_dir instead.
source_dir: If set to load weights, path to directory containing trained weights.
seed: Common seed used for any random operation.
reproducible: If set to true, set torch to run reproducibly.
mode: Mode of the model. Can be one of the values of RunMode.
model: Model to be trained.
weight: Weight to be used for the loss function.
optimizer: Optimizer to be used for training.
do_test: If set to true, evaluate the model on the test set.
batch_size: Batch size to be used for training.
epochs: Number of epochs to train for.
patience: Number of epochs to wait before early stopping.
Expand Down
9 changes: 4 additions & 5 deletions icu_benchmarks/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import gin
import logging
import sys
import torch
from pathlib import Path

import importlib.util
Expand Down Expand Up @@ -37,7 +36,7 @@ def main(my_args=tuple(sys.argv[1:])):

log_fmt = "%(asctime)s - %(levelname)s: %(message)s"
logging.basicConfig(format=log_fmt)
logging.getLogger().setLevel(logging.INFO)
logging.getLogger().setLevel(logging.DEBUG)

load_weights = args.command == "evaluate"
args.data_dir = Path(args.data_dir)
Expand Down Expand Up @@ -141,6 +140,6 @@ def main(my_args=tuple(sys.argv[1:])):
plot_aggregated_results(run_dir, "aggregated_test_metrics.json")


"""Main module."""
if __name__ == "__main__":
main()
# """Main module."""
# if __name__ == "__main__":
# main()
2 changes: 2 additions & 0 deletions icu_benchmarks/run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ def log_full_line(msg: str, level: int = logging.INFO, char: str = "-", num_newl
"{0:{char}^{width}}{1}".format(msg, "\n" * num_newlines, char=char, width=terminal_size.columns - reserved_chars),
)

# aggregate_results(Path(r"C:\Users\Robin\Downloads\2023-02-12T15-46-31"))

def load_pretrained_imputation_model(use_pretrained_imputation):
if use_pretrained_imputation is not None and not Path(use_pretrained_imputation).exists():
logging.warning("the specified pretrained imputation model does not exist")
Expand Down
1 change: 1 addition & 0 deletions scripts/evaluate_results/aggregrate_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,4 @@ def aggregate_results(
# Exclude datasets
results = results[results["Model"].isin(models)]
return results
print(aggregate_results(Path(r"C:\Users\Robin\Downloads\2023-02-12T15-46-31"), decimals=1))

0 comments on commit 9450cab

Please sign in to comment.