From 7cd4407186a3be2d1707ff5ffb1101f96ecb7efb Mon Sep 17 00:00:00 2001 From: lorenzo Date: Mon, 14 Oct 2024 14:26:16 +0200 Subject: [PATCH] move masking to base image --- src/ngio/core/image_handler.py | 54 ++--------------------------- src/ngio/core/image_like_handler.py | 45 ++++++++++++++++++++++++ src/ngio/core/label_handler.py | 3 ++ 3 files changed, 50 insertions(+), 52 deletions(-) diff --git a/src/ngio/core/image_handler.py b/src/ngio/core/image_handler.py index 9128e43..da28734 100644 --- a/src/ngio/core/image_handler.py +++ b/src/ngio/core/image_handler.py @@ -1,13 +1,6 @@ """A module to handle OME-NGFF images stored in Zarr format.""" -from typing import Literal - -import dask.array as da -import numpy as np - -from ngio._common_types import ArrayLike from ngio.core.image_like_handler import ImageLike -from ngio.core.roi import WorldCooROI from ngio.io import StoreOrGroup from ngio.ngff_meta.fractal_image_meta import ImageMeta, PixelSize @@ -44,6 +37,7 @@ def __init__( strict (bool): Whether to raise an error where a pixel size is not found to match the requested "pixel_size". cache (bool): Whether to cache the metadata. + label_group: The group containing the labels. """ super().__init__( store=store, @@ -54,8 +48,8 @@ def __init__( strict=strict, meta_mode="image", cache=cache, + _label_group=label_group, ) - self._label_group = label_group @property def metadata(self) -> ImageMeta: @@ -74,47 +68,3 @@ def get_channel_idx( ) -> int: """Return the index of the channel.""" return self.metadata.get_channel_idx(label=label, wavelength_id=wavelength_id) - - def get_array_masked( - self, - roi: WorldCooROI, - t: int | slice | None = None, - c: int | slice | None = None, - mask_mode: Literal["bbox", "mask"] = "bbox", - mode: Literal["numpy"] = "numpy", - preserve_dimensions: bool = False, - ) -> ArrayLike: - """Return the image data from a region of interest (ROI). - - Args: - roi (WorldCooROI): The region of interest. - t (int | slice | None): The time index or slice. - c (int | slice | None): The channel index or slice. - mask_mode (str): Masking mode - mode (str): The mode to return the data. - preserve_dimensions (bool): Whether to preserve the dimensions of the data. - """ - label_name = roi.infos.get("label_name", None) - if label_name is None: - raise ValueError("The label name must be provided in the ROI infos.") - - data_pipe = self._build_roi_pipe( - roi=roi, t=t, c=c, preserve_dimensions=preserve_dimensions - ) - - if mask_mode == "bbox": - return self._get_pipe(data_pipe=data_pipe, mode=mode) - - label = self._label_group.get_label(label_name, pixel_size=self.pixel_size) - - mask = label.mask( - roi, - t=t, - mode=mode, - ) - array = self._get_pipe(data_pipe=data_pipe, mode=mode) - if mode == "numpy": - return_array = np.where(mask, array, 0) - else: - return_array = da.where(mask, array, 0) - return return_array diff --git a/src/ngio/core/image_like_handler.py b/src/ngio/core/image_like_handler.py index c1ad1a9..6dae1c7 100644 --- a/src/ngio/core/image_like_handler.py +++ b/src/ngio/core/image_like_handler.py @@ -42,6 +42,7 @@ def __init__( strict: bool = True, meta_mode: Literal["image", "label"] = "image", cache: bool = True, + _label_group=None, ) -> None: """Initialize the MultiscaleHandler in read mode. @@ -57,6 +58,7 @@ def __init__( to match the requested "pixel_size". meta_mode (str): The mode of the metadata handler. cache (bool): Whether to cache the metadata. + _label_group (LabelGroup): The group containing the label data (internal use). """ if not strict: warn("Strict mode is not fully supported yet.", UserWarning, stacklevel=2) @@ -89,6 +91,8 @@ def __init__( self._init_dataset(dataset) self._dask_lock = None + self._label_group = _label_group + def _init_dataset(self, dataset: Dataset): """Set the dataset of the image. @@ -415,6 +419,47 @@ def set_array( ) self._set_pipe(data_pipe=data_pipe, patch=patch) + def get_array_masked( + self, + roi: WorldCooROI, + t: int | slice | None = None, + c: int | slice | None = None, + mask_mode: Literal["bbox", "mask"] = "bbox", + mode: Literal["numpy"] = "numpy", + preserve_dimensions: bool = False, + ) -> ArrayLike: + """Return the image data from a region of interest (ROI). + + Args: + roi (WorldCooROI): The region of interest. + t (int | slice | None): The time index or slice. + c (int | slice | None): The channel index or slice. + mask_mode (str): Masking mode + mode (str): The mode to return the data. + preserve_dimensions (bool): Whether to preserve the dimensions of the data. + """ + label_name = roi.infos.get("label_name", None) + if label_name is None: + raise ValueError("The label name must be provided in the ROI infos.") + + data_pipe = self._build_roi_pipe( + roi=roi, t=t, c=c, preserve_dimensions=preserve_dimensions + ) + + if mask_mode == "bbox": + return self._get_pipe(data_pipe=data_pipe, mode=mode) + + label = self._label_group.get_label(label_name, pixel_size=self.pixel_size) + + mask = label.mask( + roi, + t=t, + mode=mode, + ) + array = self._get_pipe(data_pipe=data_pipe, mode=mode) + where_func = np.where if mode == "numpy" else da.where + return where_func(mask, array, 0) + def consolidate(self, order: int = 1) -> None: """Consolidate the Zarr array.""" processed_paths = [self] diff --git a/src/ngio/core/label_handler.py b/src/ngio/core/label_handler.py index 2b929ab..22e1031 100644 --- a/src/ngio/core/label_handler.py +++ b/src/ngio/core/label_handler.py @@ -30,6 +30,7 @@ def __init__( highest_resolution: bool = False, strict: bool = True, cache: bool = True, + label_group=None, ) -> None: """Initialize the the Label Object. @@ -44,6 +45,7 @@ def __init__( strict (bool): Whether to raise an error where a pixel size is not found to match the requested "pixel_size". cache (bool): Whether to cache the metadata. + label_group: The group containing the labels. """ super().__init__( store, @@ -54,6 +56,7 @@ def __init__( strict=strict, meta_mode="label", cache=cache, + label_group=label_group, ) @property