diff --git a/environment.yml b/environment.yml
index 386b6c97..da295d69 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,9 +1,10 @@
-name: yaib_impute
+name: yaib_impute2
 channels:
   - pytorch
   - conda-forge
   - anaconda
 dependencies:
+  - python=3.10
   - black=22.10.0
   - coverage=6.5.0
   - flake8=5.0.4
diff --git a/icu_benchmarks/contants.py b/icu_benchmarks/contants.py
index 2875146a..7bda61da 100644
--- a/icu_benchmarks/contants.py
+++ b/icu_benchmarks/contants.py
@@ -1,3 +1,6 @@
-class RunMode:
+from enum import Enum
+
+
+class RunMode(Enum):
     classification = "Classification"
-    imputation = "Imputation"
\ No newline at end of file
+    imputation = "Imputation"
diff --git a/icu_benchmarks/cross_validation.py b/icu_benchmarks/cross_validation.py
index e1f1a445..a31d9e65 100644
--- a/icu_benchmarks/cross_validation.py
+++ b/icu_benchmarks/cross_validation.py
@@ -48,6 +48,7 @@ def execute_repeated_cv(
         load_cache: Whether to load previously cached data.
         test_on: Dataset to test on. Can be "test" or "val" (e.g. for hyperparameter tuning).
         mode: Run mode. Can be one of the values of RunMode
+        pretrained_imputation_model: Use a pretrained imputation model.
         cpu: Whether to run on CPU.
     Returns:
         The average loss of all folds.
diff --git a/icu_benchmarks/data/loader.py b/icu_benchmarks/data/loader.py
index 19c96fb5..89fbd3e2 100644
--- a/icu_benchmarks/data/loader.py
+++ b/icu_benchmarks/data/loader.py
@@ -2,11 +2,10 @@
 from pandas import DataFrame
 import gin
 import numpy as np
-import torch
-from torch import Tensor
+from torch import Tensor, cat, from_numpy, float32
+from torch.utils.data import Dataset
 import logging
 from typing import Dict, Tuple
-from torch.utils.data import Dataset
 
 from icu_benchmarks.imputation.amputations import ampute_data
 from .constants import DataSegment as Segment
@@ -58,7 +57,7 @@ def to_tensor(self):
             if len(values) <= i:
                 values.append([])
             values[i].append(value.unsqueeze(0))
-        return [torch.cat(value, dim=0) for value in values]
+        return [cat(value, dim=0) for value in values]
 
 @gin.configurable("ClassificationDataset")
 class ClassificationDataset(SICUDataset):
@@ -118,7 +117,7 @@ def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
         labels = labels.astype(np.float32)
         data = window.astype(np.float32)
 
-        return torch.from_numpy(data), torch.from_numpy(labels), torch.from_numpy(pad_mask)
+        return from_numpy(data), from_numpy(labels), from_numpy(pad_mask)
 
     def get_balance(self) -> list:
         """Return the weight balance for the split of interest.
@@ -201,9 +200,9 @@ def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
         amputation_mask = self.amputation_mask.loc[stay_id:stay_id, self.vars[Segment.dynamic]]
 
         return (
-            torch.from_numpy(amputated_window.values).to(torch.float32),
-            torch.from_numpy(amputation_mask.values).to(torch.float32),
-            torch.from_numpy(window.values).to(torch.float32),
+            from_numpy(amputated_window.values).to(float32),
+            from_numpy(amputation_mask.values).to(float32),
+            from_numpy(window.values).to(float32),
         )
 
 
@@ -271,4 +270,4 @@ def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
         # slice to make sure to always return a DF
         window = self.dyn_df.loc[stay_id:stay_id, :]
 
-        return torch.from_numpy(window.values).to(torch.float32)
+        return from_numpy(window.values).to(float32)