Zarr support with tensorstore backend #83

Merged · merged 6 commits · May 30, 2024
Changes from all commits
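A minimal usage sketch of the new backend, based on the tests added in this PR (the .zarr file names are the test fixtures; everything else is illustrative):

```python
import bfio

# Open an OME-Zarr store with the tensorstore backend; shape is reported
# without an up-front OME-XML parse.
with bfio.BioReader("4d_array.zarr", backend="tensorstore") as br:
    print(br.shape)  # (512, 672, 21, 3) for the 4d_array.zarr fixture

# Multi-resolution pyramids: pick a resolution level explicitly.
with bfio.BioReader("5025551.zarr", backend="tensorstore", level=1) as br:
    print(br.shape)  # (1350, 1351, 1, 27) at level 1 in the tests
```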
17 changes: 12 additions & 5 deletions src/bfio/base_classes.py
@@ -227,13 +227,20 @@ def read_only(self):
raise AttributeError(self._READ_ONLY_MESSAGE.format("read_only"))

def __getattribute__(self, name):
# delay metadata parsing as long as possible for tensorstore backend
if name.lower() == "metadata" and self._metadata is None:
self._metadata = self._backend.read_metadata()
# Get image dimensions using num_x, x, or X
if len(name) == 1 and name.lower() in "xyzct":
if self._metadata is None:
self._metadata = self._backend.read_metadata()
return getattr(
self._metadata.images[0].pixels, "size_{}".format(name.lower())
)
# for tensorstore, we do not need to parse metadata to get shape
if type(self._backend).__name__ == "TensorstoreReader":
return getattr(self._backend, name.upper())
else:
if self._metadata is None:
self._metadata = self._backend.read_metadata()
return getattr(
self._metadata.images[0].pixels, "size_{}".format(name.lower())
)
else:
return object.__getattribute__(self, name)

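A short sketch of what the __getattribute__ change above buys (attribute names come from the diff; the file path is a placeholder):

```python
import bfio

with bfio.BioReader("4d_array.zarr", backend="tensorstore") as br:
    # Single-letter dimension attributes are answered from the
    # TensorstoreReader's cached X/Y/Z/C/T values, so no OME-XML
    # parsing happens here.
    print(br.X, br.Y, br.Z, br.C, br.T)

    # Accessing metadata is what lazily triggers backend.read_metadata().
    md = br.metadata
    print(md.images[0].pixels.size_x)
```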
4 changes: 2 additions & 2 deletions src/bfio/bfio.py
@@ -12,7 +12,7 @@

from bfio import backends
from bfio.base_classes import BioBase
from bfio.ts_backends import TsOmeTiffReader
from bfio.ts_backends import TensorstoreReader


class BioReader(BioBase):
@@ -92,7 +92,7 @@ def __init__(
if self._backend_name == "python":
self._backend = backends.PythonReader(self)
elif self._backend_name == "tensorstore":
self._backend = TsOmeTiffReader(self)
self._backend = TensorstoreReader(self)
elif self._backend_name == "bioformats":
try:
self._backend = backends.JavaReader(self)
232 changes: 208 additions & 24 deletions src/bfio/ts_backends.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# import core packages
import logging
from pathlib import Path
from typing import Dict


@@ -9,13 +10,14 @@
from xml.etree import ElementTree as ET


from bfiocpp import TSTiffReader, Seq
from bfiocpp import TSReader, Seq, FileType, get_ome_xml
import bfio.base_classes
from bfio.utils import clean_ome_xml_for_known_issues
import zarr


class TsOmeTiffReader(bfio.base_classes.TSAbstractReader):
logger = logging.getLogger("bfio.backends.TsOmeTiffReader")
class TensorstoreReader(bfio.base_classes.TSAbstractReader):
logger = logging.getLogger("bfio.backends.TensorstoreReader")

_rdr = None
_offsets_bytes = None
@@ -24,19 +26,116 @@ class TsOmeTiffReader(bfio.base_classes.TSAbstractReader):
def __init__(self, frontend):
super().__init__(frontend)

self.logger.debug("__init__(): Initializing _rdr (TSTiffReader)...")
self._rdr = TSTiffReader(str(self.frontend._file_path))
self.logger.debug("__init__(): Initializing _rdr (TSReader)...")
extension = "".join(self.frontend._file_path.suffixes)
Member: Is it possible to access frontend._file_path through a public method, or is this the only way?

Member Author: There are no getters for this object at the moment. Since everything is public in Python and this is low-level library implementation code (no user API has direct access to frontend), I opted not to add another layer of indirection.

if extension.endswith(".ome.tif") or extension.endswith(".ome.tiff"):
# check if it satisfies all the conditions for the python backend
self._file_type = FileType.OmeTiff
self._rdr = TSReader(str(self.frontend._file_path), FileType.OmeTiff, "")
elif extension.endswith(".zarr"):
# if path exists, make sure it is a directory
if not Path.is_dir(self.frontend._file_path):
raise ValueError(
"this filetype is not supported by tensorstore backend"
)
else:
zarr_path, axes_list = self.get_zarr_array_info()
self._file_type = FileType.OmeZarr
self._rdr = TSReader(zarr_path, FileType.OmeZarr, axes_list)

self.X = self._rdr._X
self.Y = self._rdr._Y
self.Z = self._rdr._Z
self.C = self._rdr._C
self.T = self._rdr._T
self.data_type = self._rdr._datatype

def get_zarr_array_info(self):
self.logger.debug(f"Level is {self.frontend.level}")

# do test for strip images
root = None
root_path = self.frontend._file_path
try:
root = zarr.open(str(root_path.resolve()), mode="r")
except zarr.errors.PathNotFoundError:
# a workaround for pre-compute slide output directory structure
root_path = self.frontend._file_path / "data.zarr"
Member: Should this log a warning, or is it expected behavior?

Member Author: This is expected behavior. Pyramid zarr output from the precompute slide plugin puts everything inside this data.zarr and then places that directory inside a top-level directory. This is how Viv wants the folder to be organized.
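For context, a sketch of the directory layout this fallback handles, assuming the precompute-slide convention described above (names are illustrative):

```python
from pathlib import Path

import zarr

# my_image/            <- path handed to BioReader
#     data.zarr/       <- zarr root that actually opens on the retry
#         0/ 1/ ...    <- pyramid resolution levels
path = Path("my_image")
try:
    root = zarr.open(str(path.resolve()), mode="r")
except zarr.errors.PathNotFoundError:
    # retry one level deeper, mirroring the except branch above
    root = zarr.open(str((path / "data.zarr").resolve()), mode="r")
```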

root = zarr.open(root_path.resolve(), mode="r")

# do test for interleaved images
axes_list = ""
if self.frontend.level is None:
if isinstance(root, zarr.core.Array):
return str(root_path.resolve()), axes_list
elif isinstance(root, zarr.hierarchy.Group):
# the top level is a group, check if this has any arrays
num_arrays = len(sorted(root.array_keys()))
Member: Why is this the len of the sorted function instead of just taking the length of root.array_keys()?

Member Author (@sameeul, May 30, 2024): array_keys() is actually a generator function. Calling sorted() on it generates the list, and then we can get the length of all the elements it produces. Since the list contains a small number of elements (~20 max), I don't expect this to be a significant overhead. I wish there were another trick to do that.
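A minimal illustration of the point above, assuming zarr v2 where Group.array_keys() yields keys lazily:

```python
import zarr

root = zarr.group()  # in-memory group, for illustration only
root.zeros("0", shape=(16, 16))
root.zeros("1", shape=(8, 8))

# array_keys() has no len(); it has to be materialized or counted first.
num_arrays = len(sorted(root.array_keys()))         # approach used in the diff
num_arrays_alt = sum(1 for _ in root.array_keys())  # equivalent without sorting
assert num_arrays == num_arrays_alt == 2
```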

if num_arrays > 0:
array_key = next(root.array_keys())
root_path = root_path / str(array_key)
try:
axes_metadata = root.attrs["multiscales"][0]["axes"]
axes_list = "".join(
axes["name"].upper() for axes in axes_metadata
)
except KeyError:
self.logger.warning(
"Unable to find multiscales metadata. Z, C and T "
+ "dimensions might be incorrect."
)

# do test for dimension order
return str(root_path.resolve()), axes_list
else:
# need to go one more level
group_key = next(root.group_keys())
root = root[group_key]
try:
axes_metadata = root.attrs["multiscales"][0]["axes"]
axes_list = "".join(
axes["name"].upper() for axes in axes_metadata
)
except KeyError:
self.logger.warning(
"Unable to find multiscales metadata. Z, C and T "
+ "dimensions might be incorrect."
)

array_key = next(root.array_keys())
root_path = root_path / str(group_key) / str(array_key)
return str(root_path.resolve()), axes_list
else:
return str(root_path.resolve()), axes_list
else:
if isinstance(root, zarr.core.Array):
self.close()
raise ValueError(
"Level is specified but the zarr file does not contain "
+ "multiple resoulutions."
)
elif isinstance(root, zarr.hierarchy.Group):
if len(sorted(root.array_keys())) > self.frontend.level:
root_path = root_path / str(self.frontend.level)
try:
axes_metadata = root.attrs["multiscales"][0]["axes"]
axes_list = "".join(
axes["name"].upper() for axes in axes_metadata
)
except KeyError:
self.logger.warning(
"Unable to find multiscales metadata. Z, C and T "
+ "dimensions might be incorrect."
)
return str(root_path.resolve()), axes_list
else:
raise ValueError(
"The zarr file does not contain resolution "
+ "level {}.".format(self.frontend.level)
)
else:
raise ValueError(
"The zarr file does not contain resolution level {}.".format(
self.frontend.level
)
)

def __getstate__(self) -> Dict:
state_dict = {n: getattr(self, n) for n in self._STATE_DICT}
@@ -54,22 +153,10 @@ def __setstate__(self, state) -> None:
def read_metadata(self):

self.logger.debug("read_metadata(): Reading metadata...")
if self._metadata is None:
try:
self._metadata = ome_types.from_xml(
self._rdr.ome_metadata(), validate=False
)
except (ET.ParseError, ValueError):
if self.frontend.clean_metadata:
cleaned = clean_ome_xml_for_known_issues(self._rdr.ome_metadata())
self._metadata = ome_types.from_xml(cleaned, validate=False)
self.logger.warning(
"read_metadata(): OME XML required reformatting."
)
else:
raise

return self._metadata
if self._file_type == FileType.OmeTiff:
return self.read_tiff_metadata()
if self._file_type == FileType.OmeZarr:
return self.read_zarr_metadata()

def read_image(self, X, Y, Z, C, T):

@@ -92,3 +179,100 @@ def close(self):

def __del__(self):
self.close()

def read_tiff_metadata(self):
self.logger.debug("read_tiff_metadata(): Reading metadata...")
if self._metadata is None:
try:
self._metadata = ome_types.from_xml(
get_ome_xml(str(self.frontend._file_path)), validate=False
)
except (ET.ParseError, ValueError):
if self.frontend.clean_metadata:
cleaned = clean_ome_xml_for_known_issues(
get_ome_xml(str(self.frontend._file_path))
)
self._metadata = ome_types.from_xml(cleaned, validate=False)
self.logger.warning(
"read_metadata(): OME XML required reformatting."
)
else:
raise

return self._metadata

def read_zarr_metadata(self):
self.logger.debug("read_zarr_metadata(): Reading metadata...")
if self._metadata is None:

metadata_path = self.frontend._file_path.joinpath("METADATA.ome.xml")

if not metadata_path.exists():
# try to look for OME directory
metadata_path = self.frontend._file_path.joinpath("OME").joinpath(
"METADATA.ome.xml"
)
if metadata_path.exists():
if self._metadata is None:
with open(metadata_path) as fr:
metadata = fr.read()

try:
self._metadata = ome_types.from_xml(metadata, validate=False)
except ET.ParseError:
if self.frontend.clean_metadata:
cleaned = clean_ome_xml_for_known_issues(metadata)
self._metadata = ome_types.from_xml(cleaned, validate=False)
self.logger.warning(
"read_metadata(): OME XML required reformatting."
)
else:
raise

if self.frontend.level is not None:
self._metadata.images[0].pixels.size_x = self._rdr._X
self._metadata.images[0].pixels.size_y = self._rdr._Y

return self._metadata
else:
# Couldn't find OMEXML metadata, scrape metadata from file
omexml = ome_types.model.OME.model_construct()
ome_dtype = self._rdr._datatype
if ome_dtype == "float64":
ome_dtype = "double"
elif ome_dtype == "float32":
ome_dtype = "float"
else:
pass
# this is speculation, since each array in a group, in theory,
# can have distinct properties
ome_dim_order = ome_types.model.Pixels_DimensionOrder.XYZCT
size_x = self._rdr._X
size_y = self._rdr._Y
size_z = self._rdr._Z
size_c = self._rdr._C
size_t = self._rdr._T

ome_pixel = ome_types.model.Pixels(
dimension_order=ome_dim_order,
big_endian=False,
size_x=size_x,
size_y=size_y,
size_z=size_z,
size_c=size_c,
size_t=size_t,
channels=[],
type=ome_dtype,
)

for i in range(ome_pixel.size_c):
ome_pixel.channels.append(ome_types.model.Channel())

omexml.images.append(
ome_types.model.Image(
name=Path(self.frontend._file_path).name, pixels=ome_pixel
)
)

self._metadata = omexml
return self._metadata
65 changes: 65 additions & 0 deletions tests/test_read.py
@@ -338,6 +338,50 @@ def test_sub_resolution_read(self):
self.assertEqual(br.shape, (1350, 1351, 1, 27))


class TestZarrTSReader(unittest.TestCase):
def test_get_dims(self):
"""Testing metadata dimension attributes"""
with bfio.BioReader(
TEST_DIR.joinpath("4d_array.zarr"), backend="tensorstore"
) as br:
get_dims(br)
self.assertEqual(br.shape, (512, 672, 21, 3))

def test_get_pixel_size(self):
"""Testing metadata pixel sizes"""
with bfio.BioReader(
TEST_DIR.joinpath("4d_array.zarr"), backend="tensorstore"
) as br:
get_pixel_size(br)

def test_get_pixel_info(self):
"""Testing metadata pixel information"""
with bfio.BioReader(
TEST_DIR.joinpath("4d_array.zarr"), backend="tensorstore"
) as br:
get_pixel_info(br)

def test_get_channel_names(self):
"""Testing metadata channel names"""
with bfio.BioReader(
TEST_DIR.joinpath("4d_array.zarr"), backend="tensorstore"
) as br:
get_channel_names(br)

def test_sub_resolution_read(self):
"""Testing multi-resolution read"""
with bfio.BioReader(
TEST_DIR.joinpath("5025551.zarr"), backend="tensorstore"
) as br:
get_dims(br)
self.assertEqual(br.shape, (2700, 2702, 1, 27))
with bfio.BioReader(
TEST_DIR.joinpath("5025551.zarr"), backend="tensorstore", level=1
) as br:
get_dims(br)
self.assertEqual(br.shape, (1350, 1351, 1, 27))


class TestZarrMetadata(unittest.TestCase):
def test_set_metadata(self):
"""Testing metadata dimension attributes"""
@@ -355,3 +399,24 @@ def test_set_metadata(self):
logger.info(br.cnames)
logger.info(br.ps_x)
self.assertEqual(br.cnames[0], cname[0])


class TestZarrTensorstoreMetadata(unittest.TestCase):
def test_set_metadata(self):
"""Testing metadata dimension attributes"""
cname = ["test"]

image = np.load(TEST_DIR.joinpath("4d_array.npy"))

with bfio.BioWriter(TEST_DIR.joinpath("test_cname.ome.zarr")) as bw:
bw.cnames = cname
bw.ps_x = (100, "nm")
bw.shape = image.shape
bw[:] = image

with bfio.BioReader(
TEST_DIR.joinpath("test_cname.ome.zarr"), backend="tensorstore"
) as br:
logger.info(br.cnames)
logger.info(br.ps_x)
self.assertEqual(br.cnames[0], cname[0])