Skip to content

Commit

Permalink
Update CEM loaders
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Dec 23, 2023
1 parent 1656748 commit 12c29a3
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import imageio.v3 as imageio

from torch_em.data.datasets.cem import get_benchmark_loader, BENCHMARK_SHAPES
from torch_em.data.datasets import cem
from torch_em.util.debug import check_loader


Expand All @@ -17,17 +17,25 @@ def get_all_shapes():
i += 1


def check_all_loaders():
def check_benchmark_loaders():
for dataset_id in range(1, 7):
full_shape = BENCHMARK_SHAPES[dataset_id]
full_shape = cem.BENCHMARK_SHAPES[dataset_id]
patch_shape = (1,) + full_shape[1:]
loader = get_benchmark_loader("./data", dataset_id=dataset_id, batch_size=1, patch_shape=patch_shape, ndim=2)
loader = cem.get_benchmark_loader(
"./data", dataset_id=dataset_id, batch_size=1, patch_shape=patch_shape, ndim=2
)
check_loader(loader, 4, instance_labels=True)


def check_mitolab_loader():
    """Create the CEM MitoLab loader for "./data" and pass it to check_loader."""
    # split=None requests neither the "train" nor the "val" subset
    # (presumably the full dataset is used -- confirm in torch_em.data.datasets.cem).
    loader_kwargs = dict(split=None, batch_size=1, shuffle=True)
    mitolab_loader = cem.get_mitolab_loader("./data", **loader_kwargs)
    # NOTE(review): the second argument (8) looks like the number of samples to check -- confirm.
    check_loader(mitolab_loader, 8, instance_labels=True)


def main():
# get_all_shapes()
check_all_loaders()
# get_benchmark_shapes()
# check_benchmark_loaders()
check_mitolab_loader()


if __name__ == "__main__":
Expand Down
30 changes: 17 additions & 13 deletions torch_em/data/datasets/cem.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,27 @@
}


# TODO
def _download_cem_mitolab(path):
# os.makedirs(path, exist_ok=True)
raise NotImplementedError("Data download is not implemented yet for CEM data.")
def _get_mitolab_data(path, download):
access_id = "11037"
data_path = util.download_source_empiar(path, access_id, download)

zip_path = os.path.join(data_path, "data/cem_mitolab.zip")
if os.path.exists(zip_path):
util.unzip(zip_path, data_path, remove=True)

def _get_cem_mitolab_paths(path, split, val_fraction, download):
folders = glob(os.path.join(path, "*"))
assert all(os.path.isdir(folder) for folder in folders)
data_root = os.path.join(data_path, "cem_mitolab")
assert os.path.exists(data_root)

if len(folders) == 0 and download:
_download_cem_mitolab(path)
elif len(folders) == 0:
raise RuntimeError(f"The CEM Mitolab data is not available at {path}, but download was set to False.")
return data_root

raw_paths, label_paths = [], []

def _get_mitolab_paths(path, split, val_fraction, download):
data_path = _get_mitolab_data(path, download)

folders = glob(os.path.join(data_path, "*"))
assert all(os.path.isdir(folder) for folder in folders)

raw_paths, label_paths = [], []
for folder in folders:
images = glob(os.path.join(folder, "images", "*.tiff"))
images.sort()
Expand Down Expand Up @@ -100,7 +104,7 @@ def get_mitolab_dataset(
):
assert split in ("train", "val", None)
assert os.path.exists(path)
raw_paths, label_paths = _get_cem_mitolab_paths(path, split, val_fraction, download)
raw_paths, label_paths = _get_mitolab_paths(path, split, val_fraction, download)
return torch_em.default_segmentation_dataset(
raw_paths=raw_paths, raw_key=None,
label_paths=label_paths, label_key=None,
Expand Down

0 comments on commit 12c29a3

Please sign in to comment.