Skip to content

Commit

Permalink
Update script to check CEM loaders
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Dec 29, 2023
1 parent 7ebb1f5 commit 7111f72
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions scripts/datasets/check_cem.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
from torch_em.data.datasets import cem
from torch_em.util.debug import check_loader

# ROOT = "./data"
ROOT = "/scratch-grete/projects/nim00007/data/mitolab"


def get_all_shapes():
# Get the shape for the 3d datasets (id: 1-6)
data_root = "./data/10982/data/mito_benchmarks"
data_root = os.path.join(ROOT, "10982/data/mito_benchmarks")
i = 1
for root, dirs, files in os.walk(data_root):
dirs.sort()
Expand All @@ -21,7 +24,7 @@ def get_all_shapes():
i += 1

# Get the shape for the 2d dataset (id: 7)
data_root = "./data/10982/data/tem_benchmark/images"
data_root = os.path.join(ROOT, "10982/data/tem_benchmark/images")

shapes_2d = []
for image in glob(os.path.join(data_root, "*.tiff")):
Expand All @@ -39,15 +42,15 @@ def check_benchmark_loaders():
else:
patch_shape = (1,) + full_shape[1:]
loader = cem.get_benchmark_loader(
"./data", dataset_id=dataset_id, batch_size=1, patch_shape=patch_shape, ndim=2
ROOT, dataset_id=dataset_id, batch_size=1, patch_shape=patch_shape, ndim=2
)
check_loader(loader, 4, instance_labels=True)


def check_mitolab_loader():
val_fraction = 0.1
train_loader = cem.get_mitolab_loader(
"./data", split="train", batch_size=1, shuffle=True,
ROOT, split="train", batch_size=1, shuffle=True,
sampler=torch_em.data.sampler.MinInstanceSampler(),
val_fraction=val_fraction,
)
Expand All @@ -56,7 +59,7 @@ def check_mitolab_loader():
print("... done")

val_loader = cem.get_mitolab_loader(
"./data", split="val", batch_size=1, shuffle=True,
ROOT, split="val", batch_size=1, shuffle=True,
sampler=torch_em.data.sampler.MinInstanceSampler(),
val_fraction=val_fraction,
)
Expand All @@ -66,7 +69,7 @@ def check_mitolab_loader():


def analyse_mitolab():
data_root = "data/11037/cem_mitolab"
data_root = os.path.join(ROOT, "11037/cem_mitolab")
folders = glob(os.path.join(data_root, "*"))

n_datasets = len(folders)
Expand Down Expand Up @@ -94,8 +97,8 @@ def analyse_mitolab():

def main():
# get_all_shapes()
# check_benchmark_loaders()
check_mitolab_loader()
check_benchmark_loaders()
# check_mitolab_loader()
# analyse_mitolab()


Expand Down

0 comments on commit 7111f72

Please sign in to comment.