Skip to content

Commit

Permalink
Small changes to environment.yml file, loader, constants, and cross v…
Browse files Browse the repository at this point in the history
…alidation.
  • Loading branch information
rvandewater committed Feb 20, 2023
1 parent e4f88f3 commit c52dc42
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 12 deletions.
3 changes: 2 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name: yaib_impute
name: yaib_impute2
channels:
- pytorch
- conda-forge
- anaconda
dependencies:
- python=3.10
- black=22.10.0
- coverage=6.5.0
- flake8=5.0.4
Expand Down
7 changes: 5 additions & 2 deletions icu_benchmarks/contants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
class RunMode:
from enum import Enum


class RunMode(Enum):
classification = "Classification"
imputation = "Imputation"
imputation = "Imputation"
1 change: 1 addition & 0 deletions icu_benchmarks/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def execute_repeated_cv(
load_cache: Whether to load previously cached data.
test_on: Dataset to test on. Can be "test" or "val" (e.g. for hyperparameter tuning).
mode: Run mode. Can be one of the values of RunMode
pretrained_imputation_model: Use a pretrained imputation model.
cpu: Whether to run on CPU.
Returns:
The average loss of all folds.
Expand Down
17 changes: 8 additions & 9 deletions icu_benchmarks/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from pandas import DataFrame
import gin
import numpy as np
import torch
from torch import Tensor
from torch import Tensor, cat, from_numpy, float32
from torch.utils.data import Dataset
import logging
from typing import Dict, Tuple
from torch.utils.data import Dataset

from icu_benchmarks.imputation.amputations import ampute_data
from .constants import DataSegment as Segment
Expand Down Expand Up @@ -58,7 +57,7 @@ def to_tensor(self):
if len(values) <= i:
values.append([])
values[i].append(value.unsqueeze(0))
return [torch.cat(value, dim=0) for value in values]
return [cat(value, dim=0) for value in values]

@gin.configurable("ClassificationDataset")
class ClassificationDataset(SICUDataset):
Expand Down Expand Up @@ -118,7 +117,7 @@ def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
labels = labels.astype(np.float32)
data = window.astype(np.float32)

return torch.from_numpy(data), torch.from_numpy(labels), torch.from_numpy(pad_mask)
return from_numpy(data), from_numpy(labels), from_numpy(pad_mask)

def get_balance(self) -> list:
"""Return the weight balance for the split of interest.
Expand Down Expand Up @@ -201,9 +200,9 @@ def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
amputation_mask = self.amputation_mask.loc[stay_id:stay_id, self.vars[Segment.dynamic]]

return (
torch.from_numpy(amputated_window.values).to(torch.float32),
torch.from_numpy(amputation_mask.values).to(torch.float32),
torch.from_numpy(window.values).to(torch.float32),
from_numpy(amputated_window.values).to(float32),
from_numpy(amputation_mask.values).to(float32),
from_numpy(window.values).to(float32),
)


Expand Down Expand Up @@ -271,4 +270,4 @@ def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
# slice to make sure to always return a DF
window = self.dyn_df.loc[stay_id:stay_id, :]

return torch.from_numpy(window.values).to(torch.float32)
return from_numpy(window.values).to(float32)

0 comments on commit c52dc42

Please sign in to comment.