
Dev #11

Merged: 12 commits, Aug 26, 2023
bio_vae/datasets.py (4 additions, 1 deletion)
@@ -12,13 +12,16 @@
 
 from functools import lru_cache
 
+from albumentations import Compose
+from typing import Callable
+
 
 class DatasetGlob(Dataset):
     def __init__(
         self,
         path_glob,
         over_sampling=1,
-        transform=None,
+        transform: Callable = Compose([]),
         samples=-1,
         shuffle=True,
         **kwargs,
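For context on the new default, a minimal standalone sketch (a NumPy array stands in for a real sample; the flip pipeline is illustrative) of why an empty albumentations Compose can replace transform=None: it acts as an identity transform, so dataset code can apply self.transform unconditionally instead of branching on None.

# Hedged sketch: Compose([]) behaves as a "do nothing" transform.
import numpy as np
from albumentations import Compose, HorizontalFlip

image = np.zeros((8, 8, 3), dtype=np.uint8)

identity = Compose([])  # the new default: applies no ops
assert (identity(image=image)["image"] == image).all()

flip = Compose([HorizontalFlip(p=1.0)])  # a real pipeline slots into the same parameter
flipped = flip(image=image)["image"]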
bio_vae/lightning/dataloader.py (25 additions, 15 deletions)
@@ -42,25 +42,35 @@ def __init__(
     def get_dataset(self):
         return self.dataset
 
-    def splitting(self, dataset, split=0.8, seed=42):
-        if len(dataset) < 4:
-            return dataset, dataset, dataset, dataset
-        spliting_shares = [
-            len(dataset) * split * split,  # train
-            len(dataset) * split * (1 - split),  # test
-            len(dataset) * split * (1 - split),  # predict
-            len(dataset) * (1 - split) * (1 - split),  # val
-        ]
-
-        train, test, predict, val = random_split(
+
+    def splitting(self, dataset, split_train=0.8, split_val=0.1, seed=42):
+        if len(dataset) < 3:
+            return dataset, dataset, dataset
+
+        train_share = int(len(dataset) * split_train)
+        val_share = int(len(dataset) * split_val)
+        test_share = len(dataset) - train_share - val_share
+
+        # Ensure that the splits add up correctly
+        if train_share + val_share + test_share != len(dataset):
+            raise ValueError("The splitting ratios do not add up to the length of the dataset")
+
+        torch.manual_seed(seed)  # for reproducibility
+
+        train, val, test = random_split(
             dataset,
-            list(map(int, saferound(spliting_shares, places=0))),
+            [train_share, val_share, test_share]
         )
 
-        return test, train, predict, val
+        return train, val, test
 
     def setup(self, stage=None):
-        self.test, self.train, self.predict, self.val = self.splitting(self.dataset)
+        self.train, self.val, self.test = self.splitting(self.dataset)
+
+        # self.test = self.get_dataloader(test)
+        # self.predict = self.get_dataloader(predict)
+        # self.train = self.get_dataloader(train)
+        # self.val = self.get_dataloader(val)
 
     def test_dataloader(self):
         return DataLoader(self.test, **self.data_loader_settings)
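As a sanity check on the new arithmetic, a self-contained sketch (toy dataset and seed chosen for illustration) showing that flooring the train and val shares and giving the remainder to test always satisfies random_split's requirement that the lengths sum to len(dataset):

# Self-contained sketch of the new split arithmetic on a toy dataset.
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(103).unsqueeze(1))  # 103 samples

split_train, split_val = 0.8, 0.1
train_share = int(len(dataset) * split_train)        # floor(82.4) = 82
val_share = int(len(dataset) * split_val)            # floor(10.3) = 10
test_share = len(dataset) - train_share - val_share  # 11 absorbs the rounding error

torch.manual_seed(42)  # fixed seed makes the split reproducible across runs
train, val, test = random_split(dataset, [train_share, val_share, test_share])
assert len(train) + len(val) + len(test) == len(dataset)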
@@ -72,7 +82,7 @@ def val_dataloader(self):
         return DataLoader(self.val, **self.data_loader_settings)
 
     def predict_dataloader(self):
-        return DataLoader(self.predict, **self.data_loader_settings)
+        return DataLoader(self.dataset, **self.data_loader_settings)
 
     # def teardown(self, stage: Optional[str] = None):
     #     # Used to clean-up when the run is finished
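The last hunk changes what prediction iterates over: predict_dataloader now wraps the full dataset rather than a held-out predict split. A minimal sketch (a toy dataset stands in for self.dataset) of the resulting behaviour:

# Hedged sketch: the predict loader now visits every sample exactly once.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).unsqueeze(1))
predict_loader = DataLoader(dataset, batch_size=4, shuffle=False)

seen = sum(batch[0].shape[0] for batch in predict_loader)
assert seen == len(dataset)  # all samples scored, none held out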
environment.yml (0 additions, 2 deletions)
@@ -2,8 +2,6 @@
 channels:
   - bioconda
   - pytorch
-  - idr
-  - ome
   - conda-forge
   # - defaults
   - torch
pyproject.toml (1 addition, 1 deletion)
@@ -28,7 +28,7 @@ umap-learn = {extras = ["plot"], version = "^0.5.3"}
 colorcet = "^3.0.1"
 holoviews = "^1.15.2"
 # idr-py = "^0.4.2"
-llvmlite = "^0.39.1"
+#llvmlite = "^0.39.1"
 torchmetrics = "^0.11.0"
 tensorboard = "^2.11.2"
 albumentations = "^1.3.0"
Binary file added: scripts/.train_ivy_gap_legacy.py.swp (contents not shown)