Skip to content

Commit

Permalink
Merge pull request #10 from fractal-analytics-platform/remote_zarrs
Browse files Browse the repository at this point in the history
Support remote zarrs streaming
  • Loading branch information
lorenzocerrone authored Nov 6, 2024
2 parents fa7a405 + c64cad2 commit 40f2440
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 36 deletions.
5 changes: 4 additions & 1 deletion src/ngio/core/image_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ngio.core.image_like_handler import ImageLike
from ngio.core.roi import WorldCooROI
from ngio.io import StoreOrGroup
from ngio.io import AccessModeLiteral, StoreOrGroup
from ngio.ngff_meta import PixelSize
from ngio.ngff_meta.fractal_image_meta import ImageMeta
from ngio.utils._common_types import ArrayLike
Expand All @@ -27,6 +27,7 @@ def __init__(
highest_resolution: bool = False,
strict: bool = True,
cache: bool = True,
mode: AccessModeLiteral = "r+",
label_group: Any = None,
) -> None:
"""Initialize the the Image Object.
Expand All @@ -42,6 +43,7 @@ def __init__(
strict (bool): Whether to raise an error where a pixel size is not found
to match the requested "pixel_size".
cache (bool): Whether to cache the metadata.
mode (AccessModeLiteral): The mode to open the group in.
label_group: The group containing the labels.
"""
super().__init__(
Expand All @@ -53,6 +55,7 @@ def __init__(
strict=strict,
meta_mode="image",
cache=cache,
mode=mode,
_label_group=label_group,
)

Expand Down
10 changes: 6 additions & 4 deletions src/ngio/core/image_like_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ngio.core.dimensions import Dimensions
from ngio.core.roi import WorldCooROI
from ngio.core.utils import Lock
from ngio.io import StoreOrGroup, open_group_wrapper
from ngio.io import AccessModeLiteral, StoreOrGroup, open_group_wrapper
from ngio.ngff_meta import (
Dataset,
ImageLabelMeta,
Expand Down Expand Up @@ -42,6 +42,7 @@ def __init__(
strict: bool = True,
meta_mode: Literal["image", "label"] = "image",
cache: bool = True,
mode: AccessModeLiteral = "r+",
_label_group: Any = None,
) -> None:
"""Initialize the MultiscaleHandler in read mode.
Expand All @@ -63,8 +64,9 @@ def __init__(
if not strict:
warn("Strict mode is not fully supported yet.", UserWarning, stacklevel=2)

self._mode = mode
if not isinstance(store, zarr.Group):
store = open_group_wrapper(store=store, mode="r+")
store = open_group_wrapper(store=store, mode=self._mode)

self._group = store

Expand Down Expand Up @@ -99,11 +101,11 @@ def _init_dataset(self, dataset: Dataset) -> None:
This method is for internal use only.
"""
self._dataset = dataset
self._array = self._group.get(self._dataset.path, None)

if self._dataset.path not in self._group.array_keys():
if self._array is None:
raise ValueError(f"Dataset {self._dataset.path} not found in the group.")

self._array = self._group[self.dataset.path]
self._diminesions = Dimensions(
on_disk_shape=self._array.shape,
axes_names=self._dataset.axes_names,
Expand Down
34 changes: 23 additions & 11 deletions src/ngio/core/label_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
highest_resolution: bool = False,
strict: bool = True,
cache: bool = True,
mode: AccessModeLiteral = "r+",
label_group: Any = None,
) -> None:
"""Initialize the the Label Object.
Expand All @@ -46,6 +47,7 @@ def __init__(
strict (bool): Whether to raise an error where a pixel size is not found
to match the requested "pixel_size".
cache (bool): Whether to cache the metadata.
mode (AccessModeLiteral): The mode to open the group in.
label_group: The group containing the labels.
"""
super().__init__(
Expand All @@ -57,6 +59,7 @@ def __init__(
strict=strict,
meta_mode="label",
cache=cache,
mode=mode,
_label_group=label_group,
)

Expand Down Expand Up @@ -216,19 +219,22 @@ def __init__(
if not isinstance(group, zarr.Group):
group = zarr.open_group(group, mode=self._mode)

if "labels" not in group:
self._group = group.create_group("labels")
self._group.attrs["labels"] = [] # initialize the labels attribute
else:
self._group = group["labels"]
assert isinstance(self._group, zarr.Group)
label_group = group.get("labels", None)
if label_group is None and not group.read_only:
label_group = group.create_group("labels")
label_group.attrs["labels"] = [] # initialize the labels attribute

assert isinstance(label_group, zarr.Group) or label_group is None
self._label_group = label_group

self._image_ref = image_ref
self._metadata_cache = cache

def list(self) -> list[str]:
"""List all labels in the group."""
_labels = self._group.attrs.get("labels", [])
if self._label_group is None:
return []
_labels = self._label_group.attrs.get("labels", [])
assert isinstance(_labels, list)
return _labels

Expand Down Expand Up @@ -259,14 +265,17 @@ def get_label(
highest_resolution (bool, optional): Whether to get the highest
resolution level
"""
if self._label_group is None:
raise ValueError("No labels found in the group.")

if name not in self.list():
raise ValueError(f"Label {name} not found in the group.")

if path is not None or pixel_size is not None:
highest_resolution = False

return Label(
store=self._group[name],
store=self._label_group[name],
path=path,
pixel_size=pixel_size,
highest_resolution=highest_resolution,
Expand All @@ -287,17 +296,20 @@ def derive(
Default is False.
**kwargs: Additional keyword arguments to pass to the new label.
"""
if self._label_group is None:
raise ValueError("Cannot derive a new label. Group is empty or read-only.")

list_of_labels = self.list()

if overwrite and name in list_of_labels:
self._group.attrs["label"] = [
self._label_group.attrs["label"] = [
label for label in list_of_labels if label != name
]
elif not overwrite and name in list_of_labels:
raise ValueError(f"Label {name} already exists in the group.")

# create the new label
new_label_group = self._group.create_group(name, overwrite=overwrite)
new_label_group = self._label_group.create_group(name, overwrite=overwrite)

if self._image_ref is None:
label_0 = self.get_label(list_of_labels[0])
Expand Down Expand Up @@ -349,5 +361,5 @@ def derive(
)

if name not in self.list():
self._group.attrs["labels"] = [*list_of_labels, name]
self._label_group.attrs["labels"] = [*list_of_labels, name]
return self.get_label(name)
8 changes: 7 additions & 1 deletion src/ngio/core/ngff_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,17 @@ def __init__(
self.store = store
self._mode = mode
self.group = open_group_wrapper(store=store, mode=self._mode)

if self.group.read_only:
self._mode = "r"

self._image_meta = get_ngff_image_meta_handler(
self.group, meta_mode="image", cache=cache
)
self._metadata_cache = cache
self.table = TableGroup(self.group, mode=self._mode)
self.label = LabelGroup(self.group, image_ref=self.get_image(), mode=self._mode)

ngio_logger.info(f"Opened image located in store: {store}")
ngio_logger.info(f"- Image number of levels: {self.num_levels}")

Expand Down Expand Up @@ -76,8 +81,9 @@ def get_image(
path=path,
pixel_size=pixel_size,
highest_resolution=highest_resolution,
label_group=LabelGroup(self.group, image_ref=None),
label_group=LabelGroup(self.group, image_ref=None, mode=self._mode),
cache=self._metadata_cache,
mode=self._mode,
)
ngio_logger.info(f"Opened image at path: {image.path}")
ngio_logger.info(f"- {image.dimensions}")
Expand Down
9 changes: 8 additions & 1 deletion src/ngio/io/_zarr_group_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path

import fsspec
import zarr

from ngio.io._zarr import (
Expand All @@ -22,8 +23,14 @@ def _check_store(store: StoreLike) -> StoreLike:
if isinstance(store, str) or isinstance(store, Path):
return store

if isinstance(store, fsspec.mapping.FSMap) or isinstance(
store, zarr.storage.FSStore
):
return store

raise NotImplementedError(
"RemoteStore is not yet supported. Please use LocalStore."
f"Store type {type(store)} is not supported. supported types are: "
"str, Path, fsspec.mapping.FSMap, zarr.storage.FSStore"
)


Expand Down
52 changes: 39 additions & 13 deletions src/ngio/tables/tables_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from ngio.io import AccessModeLiteral, StoreLike
from ngio.tables.v1 import FeatureTableV1, MaskingROITableV1, ROITableV1
from ngio.utils import ngio_logger
from ngio.utils._pydantic_utils import BaseWithExtraFields

ROITable = ROITableV1
Expand Down Expand Up @@ -95,22 +96,38 @@ def __init__(
if not isinstance(group, zarr.Group):
group = zarr.open_group(group, mode=self._mode)

if "tables" not in group:
self._group = group.create_group("tables")
else:
self._group: zarr.Group = group["tables"]
table_group = group.get("tables", None)

if table_group is None and not group.read_only:
table_group = group.create_group("tables")
table_group.attrs["tables"] = []

assert isinstance(table_group, zarr.Group) or table_group is None
self._table_group = table_group

def _validate_list_of_tables(self, list_of_tables: list[str]) -> None:
"""Validate the list of tables."""
list_of_groups = list(self._group.group_keys())
"""Validate the list of tables.
Args:
list_of_tables (list[str]): The list of tables to validate.
"""
if self._table_group is None:
return None

for table_name in list_of_tables:
if table_name not in list_of_groups:
raise ValueError(f"Table {table_name} not found in the group.")
table = self._table_group.get(table_name, None)
if table is None:
ngio_logger.warning(
f"Table {table_name} not found in the group. "
"Consider removing it from the list of tables."
)

def _get_list_of_tables(self) -> list[str]:
"""Return the list of tables."""
list_of_tables = self._group.attrs.get("tables", [])
if self._table_group is None:
return []

list_of_tables = self._table_group.attrs.get("tables", [])
self._validate_list_of_tables(list_of_tables)
assert isinstance(list_of_tables, list)
assert all(isinstance(table_name, str) for table_name in list_of_tables)
Expand All @@ -127,6 +144,9 @@ def list(
If None, all tables are listed.
Allowed values are: 'roi_table', 'feature_table', 'masking_roi_table'.
"""
if self._table_group is None:
return []

list_of_tables = self._get_list_of_tables()
self._validate_list_of_tables(list_of_tables=list_of_tables)
if table_type is None:
Expand All @@ -140,7 +160,7 @@ def list(
)
list_of_typed_tables = []
for table_name in list_of_tables:
table = self._group[table_name]
table = self._table_group[table_name]
try:
common_meta = CommonMeta(**table.attrs)
if common_meta.type == table_type:
Expand Down Expand Up @@ -173,12 +193,15 @@ def get_table(
This is usually defined in the metadata of the table, if given here,
it will overwrite the metadata.
"""
if self._table_group is None:
raise ValueError("No tables group found in the group.")

list_of_tables = self._get_list_of_tables()
if name not in list_of_tables:
raise ValueError(f"Table {name} not found in the group.")

return _get_table_impl(
group=self._group[name],
group=self._table_group[name],
validate_metadata=validate_metadata,
table_type=table_type,
validate_table=validate_table,
Expand All @@ -194,6 +217,9 @@ def new(
**type_specific_kwargs: dict,
) -> Table:
"""Add a new table to the group."""
if self._table_group is None:
raise ValueError("No tables group found in the group.")

list_of_tables = self._get_list_of_tables()
if not overwrite and name in list_of_tables:
raise ValueError(f"Table {name} already exists in the group.")
Expand All @@ -203,13 +229,13 @@ def new(

table_impl = _find_table_impl(table_type=table_type, version=version)
new_table = table_impl._new(
parent_group=self._group,
parent_group=self._table_group,
name=name,
overwrite=overwrite,
**type_specific_kwargs,
)

self._group.attrs["tables"] = [*list_of_tables, name]
self._table_group.attrs["tables"] = [*list_of_tables, name]

assert isinstance(new_table, ROITable | FeatureTable | MaskingROITable)
return new_table
13 changes: 12 additions & 1 deletion tests/core/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from importlib.metadata import version
from pathlib import Path

import fsspec
import fsspec.implementations.http
import zarr
from packaging.version import Version
from pytest import fixture
Expand All @@ -11,7 +13,7 @@


@fixture
def ome_zarr_image_v04_path(tmpdir):
def ome_zarr_image_v04_path(tmpdir: str) -> Path:
zarr_path = Path(tmpdir) / "test_ome_ngff_v04.zarr"

if ZARR_PYTHON_V == 3:
Expand All @@ -37,3 +39,12 @@ def ome_zarr_image_v04_path(tmpdir):
group.zeros(name=path, shape=shape)

return zarr_path


@fixture
def ome_zarr_image_v04_fs() -> fsspec.mapping.FSMap:
fs = fsspec.implementations.http.HTTPFileSystem(client_kwargs={})
store = fs.get_mapper(
"https://raw.githubusercontent.com/fractal-analytics-platform/fractal-tasks-core/refs/heads/main/tests/data/plate_ones.zarr/B/03/0/"
)
return store
Loading

0 comments on commit 40f2440

Please sign in to comment.