Skip to content

Commit

Permalink
move masking to base image
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzocerrone committed Oct 14, 2024
1 parent f1aa2ec commit 7cd4407
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 52 deletions.
54 changes: 2 additions & 52 deletions src/ngio/core/image_handler.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
"""A module to handle OME-NGFF images stored in Zarr format."""

from typing import Literal

import dask.array as da
import numpy as np

from ngio._common_types import ArrayLike
from ngio.core.image_like_handler import ImageLike
from ngio.core.roi import WorldCooROI
from ngio.io import StoreOrGroup
from ngio.ngff_meta.fractal_image_meta import ImageMeta, PixelSize

Expand Down Expand Up @@ -44,6 +37,7 @@ def __init__(
strict (bool): Whether to raise an error where a pixel size is not found
to match the requested "pixel_size".
cache (bool): Whether to cache the metadata.
label_group: The group containing the labels.
"""
super().__init__(
store=store,
Expand All @@ -54,8 +48,8 @@ def __init__(
strict=strict,
meta_mode="image",
cache=cache,
_label_group=label_group,
)
self._label_group = label_group

@property
def metadata(self) -> ImageMeta:
Expand All @@ -74,47 +68,3 @@ def get_channel_idx(
) -> int:
"""Return the index of the channel."""
return self.metadata.get_channel_idx(label=label, wavelength_id=wavelength_id)

def get_array_masked(
self,
roi: WorldCooROI,
t: int | slice | None = None,
c: int | slice | None = None,
mask_mode: Literal["bbox", "mask"] = "bbox",
mode: Literal["numpy"] = "numpy",
preserve_dimensions: bool = False,
) -> ArrayLike:
"""Return the image data from a region of interest (ROI).
Args:
roi (WorldCooROI): The region of interest.
t (int | slice | None): The time index or slice.
c (int | slice | None): The channel index or slice.
mask_mode (str): Masking mode
mode (str): The mode to return the data.
preserve_dimensions (bool): Whether to preserve the dimensions of the data.
"""
label_name = roi.infos.get("label_name", None)
if label_name is None:
raise ValueError("The label name must be provided in the ROI infos.")

data_pipe = self._build_roi_pipe(
roi=roi, t=t, c=c, preserve_dimensions=preserve_dimensions
)

if mask_mode == "bbox":
return self._get_pipe(data_pipe=data_pipe, mode=mode)

label = self._label_group.get_label(label_name, pixel_size=self.pixel_size)

mask = label.mask(
roi,
t=t,
mode=mode,
)
array = self._get_pipe(data_pipe=data_pipe, mode=mode)
if mode == "numpy":
return_array = np.where(mask, array, 0)
else:
return_array = da.where(mask, array, 0)
return return_array
45 changes: 45 additions & 0 deletions src/ngio/core/image_like_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
strict: bool = True,
meta_mode: Literal["image", "label"] = "image",
cache: bool = True,
_label_group=None,
) -> None:
"""Initialize the MultiscaleHandler in read mode.
Expand All @@ -57,6 +58,7 @@ def __init__(
to match the requested "pixel_size".
meta_mode (str): The mode of the metadata handler.
cache (bool): Whether to cache the metadata.
_label_group (LabelGroup): The group containing the label data (internal use).
"""
if not strict:
warn("Strict mode is not fully supported yet.", UserWarning, stacklevel=2)
Expand Down Expand Up @@ -89,6 +91,8 @@ def __init__(
self._init_dataset(dataset)
self._dask_lock = None

self._label_group = _label_group

def _init_dataset(self, dataset: Dataset):
"""Set the dataset of the image.
Expand Down Expand Up @@ -415,6 +419,47 @@ def set_array(
)
self._set_pipe(data_pipe=data_pipe, patch=patch)

def get_array_masked(
self,
roi: WorldCooROI,
t: int | slice | None = None,
c: int | slice | None = None,
mask_mode: Literal["bbox", "mask"] = "bbox",
mode: Literal["numpy"] = "numpy",
preserve_dimensions: bool = False,
) -> ArrayLike:
"""Return the image data from a region of interest (ROI).
Args:
roi (WorldCooROI): The region of interest.
t (int | slice | None): The time index or slice.
c (int | slice | None): The channel index or slice.
mask_mode (str): Masking mode
mode (str): The mode to return the data.
preserve_dimensions (bool): Whether to preserve the dimensions of the data.
"""
label_name = roi.infos.get("label_name", None)
if label_name is None:
raise ValueError("The label name must be provided in the ROI infos.")

data_pipe = self._build_roi_pipe(
roi=roi, t=t, c=c, preserve_dimensions=preserve_dimensions
)

if mask_mode == "bbox":
return self._get_pipe(data_pipe=data_pipe, mode=mode)

label = self._label_group.get_label(label_name, pixel_size=self.pixel_size)

mask = label.mask(
roi,
t=t,
mode=mode,
)
array = self._get_pipe(data_pipe=data_pipe, mode=mode)
where_func = np.where if mode == "numpy" else da.where
return where_func(mask, array, 0)

def consolidate(self, order: int = 1) -> None:
"""Consolidate the Zarr array."""
processed_paths = [self]
Expand Down
3 changes: 3 additions & 0 deletions src/ngio/core/label_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
highest_resolution: bool = False,
strict: bool = True,
cache: bool = True,
label_group=None,
) -> None:
"""Initialize the the Label Object.
Expand All @@ -44,6 +45,7 @@ def __init__(
strict (bool): Whether to raise an error where a pixel size is not found
to match the requested "pixel_size".
cache (bool): Whether to cache the metadata.
label_group: The group containing the labels.
"""
super().__init__(
store,
Expand All @@ -54,6 +56,7 @@ def __init__(
strict=strict,
meta_mode="label",
cache=cache,
label_group=label_group,
)

@property
Expand Down

0 comments on commit 7cd4407

Please sign in to comment.