Skip to content

Commit

Permalink
Refactor ctc (#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 authored Feb 9, 2024
1 parent 507697e commit 66c30be
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
6 changes: 2 additions & 4 deletions scripts/datasets/check_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch_em.util.debug import check_loader
from torch_em.data.sampler import MinInstanceSampler

ROOT = "/home/pape/Work/data/ctc/ctc-training-data"
ROOT = "/scratch/projects/nim00007/sam/data/ctc/"


# Some of the datasets have partial sparse labels:
Expand All @@ -11,14 +11,12 @@
# Maybe depends on the split?!
def check_ctc_segmentation():
for name in CTC_URLS.keys():
if not name.startswith("DIC"):
continue
print("Checking dataset", name)
loader = get_ctc_segmentation_loader(
ROOT, name, (1, 512, 512), 1, download=True,
sampler=MinInstanceSampler()
)
check_loader(loader, 8, instance_labels=True)
check_loader(loader, 8, plt=True, save_path="ctc.png")


if __name__ == "__main__":
Expand Down
15 changes: 10 additions & 5 deletions torch_em/data/datasets/ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@ def _require_ctc_dataset(path, dataset_name, download):

data_path = os.path.join(path, dataset_name)

if not os.path.exists(data_path):
url, checksum = CTC_URLS[dataset_name], CTC_CHECKSUMS[dataset_name]
zip_path = os.path.join(path, f"{dataset_name}.zip")
util.download_source(zip_path, url, download, checksum=checksum)
util.unzip(zip_path, path, remove=True)
if os.path.exists(data_path):
return data_path

os.makedirs(data_path)
url, checksum = CTC_URLS[dataset_name], CTC_CHECKSUMS[dataset_name]
zip_path = os.path.join(path, f"{dataset_name}.zip")
util.download_source(zip_path, url, download, checksum=checksum)
util.unzip(zip_path, path, remove=True)

return data_path

Expand Down Expand Up @@ -101,6 +104,8 @@ def get_ctc_segmentation_dataset(
splits = glob(os.path.join(data_path, "*_GT"))
splits = [os.path.basename(split) for split in splits]
splits = [split.rstrip("_GT") for split in splits]
else:
splits = split

image_path, label_path = _require_gt_images(data_path, splits)

Expand Down

0 comments on commit 66c30be

Please sign in to comment.