diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f7f92dc..620b5621 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,12 @@ ### Changed - Use `DataArray` and `DataTree` typing instead of (Multiscale)SpatialImage (as in `spatialdata>=0.2.0`) +### Fix +- Fix Xenium reader issue for recent machine versions (#80) + +### Changed +- Fully depend on `spatialdata-io` for the MERSCOPE and Xenium readers + ## [1.1.0] - 2024-06-11 First post-publication release diff --git a/sopa/io/reader/merscope.py b/sopa/io/reader/merscope.py index 48b7c38a..b1848d8c 100644 --- a/sopa/io/reader/merscope.py +++ b/sopa/io/reader/merscope.py @@ -1,43 +1,27 @@ -# Updated from spatialdata-io: https://spatialdata.scverse.org/projects/io/en/latest/ -# In the future, we will completely rely on spatialdata-io - from __future__ import annotations import logging -import re -import warnings from pathlib import Path -from typing import Callable +from typing import Literal -import dask.array as da -import dask.dataframe as dd -import numpy as np -import xarray -from dask_image.imread import imread from spatialdata import SpatialData -from spatialdata._logging import logger -from spatialdata.models import Image2DModel, PointsModel -from spatialdata.transformations import Affine, Identity -from spatialdata_io._constants._constants import MerscopeKeys +from spatialdata_io.readers.merscope import merscope as merscope_spatialdata_io from .utils import _default_image_kwargs log = logging.getLogger(__name__) -SUPPORTED_BACKENDS = ["dask_image", "rioxarray"] - - def merscope( path: str | Path, - backend: str = None, + backend: Literal["dask_image", "rioxarray"] | None = None, z_layers: int | list[int] | None = 3, region_name: str | None = None, slide_name: str | None = None, image_models_kwargs: dict | None = None, imread_kwargs: dict | None = None, ) -> SpatialData: - """Read MERSCOPE data as a `SpatialData` object. + """Read MERSCOPE data as a `SpatialData` object. 
For more information, refer to [spatialdata-io](https://spatialdata.scverse.org/projects/io/en/latest/generated/spatialdata_io.merscope.html). This function reads the following files: - `detected_transcripts.csv`: transcripts locations and names @@ -46,7 +30,7 @@ def merscope( Args: path: Path to the MERSCOPE directory containing all the experiment files - backend: Either `"dask_image"` or `"rioxarray"` (the latter uses less RAM). By default, uses `"rioxarray"` if and only if the `rioxarray` library is installed. + backend: Either `"dask_image"` or `"rioxarray"` (the latter uses less RAM, but requires `rioxarray` to be installed). By default, uses `"rioxarray"` if and only if the `rioxarray` library is installed. z_layers: Indices of the z-layers to consider. Either one `int` index, or a list of `int` indices. If `None`, then no image is loaded. By default, only the middle layer is considered (that is, layer 3). region_name: Name of the region of interest, e.g., `'region_0'`. If `None` then the name of the `path` directory is used. slide_name: Name of the slide/run. If `None` then the name of the parent directory of `path` is used (whose name starts with a date). @@ -56,145 +40,16 @@ def merscope( Returns: A `SpatialData` object representing the MERSCOPE experiment """ - assert ( - backend is None or backend in SUPPORTED_BACKENDS - ), f"Backend '{backend} not supported. 
Should be one of: {', '.join(SUPPORTED_BACKENDS)}" - - path = Path(path).absolute() image_models_kwargs, imread_kwargs = _default_image_kwargs(image_models_kwargs, imread_kwargs) - images_dir = path / MerscopeKeys.IMAGES_DIR - - microns_to_pixels = Affine( - np.genfromtxt(images_dir / MerscopeKeys.TRANSFORMATION_FILE), - input_axes=("x", "y"), - output_axes=("x", "y"), - ) - - vizgen_region = path.name if region_name is None else region_name - slide_name = path.parent.name if slide_name is None else slide_name - dataset_id = f"{slide_name}_{vizgen_region}" - - # Images - images = {} - - z_layers = [z_layers] if isinstance(z_layers, int) else z_layers or [] - - stainings = _get_channel_names(images_dir) - image_transformations = {"microns": microns_to_pixels.inverse()} - - reader = _get_reader(backend) - - if stainings: - for z_layer in z_layers: - images[f"{dataset_id}_z{z_layer}"] = reader( - images_dir, - stainings, - z_layer, - image_models_kwargs, - image_transformations, - **imread_kwargs, - ) - - # Transcripts - points = {} - transcript_path = path / MerscopeKeys.TRANSCRIPTS_FILE - if transcript_path.exists(): - points[f"{dataset_id}_transcripts"] = _get_points(transcript_path) - else: - logger.warning( - f"Transcript file {transcript_path} does not exist. Transcripts are not loaded." 
- ) - - return SpatialData(points=points, images=images) - - -def _get_reader(backend: str | None) -> Callable: - if backend is not None: - return _rioxarray_load_merscope if backend == "rioxarray" else _dask_image_load_merscope - try: - import rioxarray # noqa: F401 - - return _rioxarray_load_merscope - except: - return _dask_image_load_merscope - - -def _get_channel_names(images_dir: Path) -> list[str]: - exp = r"mosaic_(?P<stain>[\w|-]+[0-9]?)_z(?P<z>[0-9]+).tif" - matches = [re.search(exp, file.name) for file in images_dir.iterdir()] - - stainings = {match.group("stain") for match in matches if match} - - return list(stainings) - - -def _rioxarray_load_merscope( - images_dir: Path, - stainings: list[str], - z_layer: int, - image_models_kwargs: dict, - transformations: dict, - **kwargs, -): - log.info("Using rioxarray backend.") - - import rioxarray - from rasterio.errors import NotGeoreferencedWarning - - warnings.simplefilter("ignore", category=NotGeoreferencedWarning) - - im = xarray.concat( - [ - rioxarray.open_rasterio( - images_dir / f"mosaic_{stain}_z{z_layer}.tif", - chunks=image_models_kwargs["chunks"], - **kwargs, - ) - .rename({"band": "c"}) - .reset_coords("spatial_ref", drop=True) - for stain in stainings - ], - dim="c", - ) - - return Image2DModel.parse( - im, transformations=transformations, c_coords=stainings, **image_models_kwargs - ) - - -def _dask_image_load_merscope( - images_dir: Path, - stainings: list[str], - z_layer: int, - image_models_kwargs: dict, - transformations: dict, - **kwargs, -): - im = da.stack( - [ - imread(images_dir / f"mosaic_{stain}_z{z_layer}.tif", **kwargs).squeeze() - for stain in stainings - ], - axis=0, - ) - - return Image2DModel.parse( - im, - dims=("c", "y", "x"), - transformations=transformations, - c_coords=stainings, - **image_models_kwargs, - ) - - -def _get_points(transcript_path: Path): - transcript_df = dd.read_csv(transcript_path) - transcripts = PointsModel.parse( - transcript_df, - coordinates={"x": 
MerscopeKeys.GLOBAL_X, "y": MerscopeKeys.GLOBAL_Y}, - feature_key="gene", - transformations={"microns": Identity()}, + return merscope_spatialdata_io( + path, + backend=backend, + z_layers=z_layers, + region_name=region_name, + slide_name=slide_name, + image_models_kwargs=image_models_kwargs, + imread_kwargs=imread_kwargs, + cells_boundaries=False, + cells_table=False, ) - transcripts["gene"] = transcripts["gene"].astype("category") - return transcripts diff --git a/sopa/io/reader/xenium.py b/sopa/io/reader/xenium.py index 32120ade..37cde294 100644 --- a/sopa/io/reader/xenium.py +++ b/sopa/io/reader/xenium.py @@ -1,22 +1,10 @@ -# Updated from spatialdata-io: https://spatialdata.scverse.org/projects/io/en/latest/ -# In the future, we will completely rely on spatialdata-io - from __future__ import annotations -import json import logging from pathlib import Path -from typing import Any -import dask.array as da -import packaging.version -from dask.dataframe import read_parquet -from dask_image.imread import imread from spatialdata import SpatialData -from spatialdata.models import Image2DModel, PointsModel -from spatialdata.transformations import Identity, Scale -from spatialdata_io._constants._constants import XeniumKeys -from spatialdata_io.readers.xenium import _parse_version_of_xenium_analyzer +from spatialdata_io.readers.xenium import xenium as xenium_spatialdata_io from .utils import _default_image_kwargs @@ -44,94 +32,18 @@ def xenium( Returns: A `SpatialData` object representing the Xenium experiment """ - path = Path(path) image_models_kwargs, imread_kwargs = _default_image_kwargs(image_models_kwargs, imread_kwargs) - image_models_kwargs["c_coords"] = [XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_0.value] - - with open(path / XeniumKeys.XENIUM_SPECS) as f: - specs = json.load(f) - - points = {"transcripts": _get_points(path, specs)} - - with open(path / XeniumKeys.XENIUM_SPECS) as f: - specs = json.load(f) - # to trigger the warning if the version cannot be parsed - 
version = _parse_version_of_xenium_analyzer(specs, hide_warning=False) - - images = {} - if version is None or version < packaging.version.parse("2.0.0"): - images["morphology_focus"] = _get_images( - path, - XeniumKeys.MORPHOLOGY_FOCUS_FILE, - imread_kwargs, - image_models_kwargs, - ) - else: - morphology_focus_dir = path / XeniumKeys.MORPHOLOGY_FOCUS_DIR - files = [ - morphology_focus_dir / XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_IMAGE.value.format(i) - for i in range(4) - ] - files = [f for f in files if f.exists()] - if len(files) not in [1, 4]: - raise ValueError( - "Expected 1 (no segmentation kit) or 4 (segmentation kit) files in the morphology focus directory, " - f"found {len(files)}: {files}" - ) - - if len(files) == 4: - image_models_kwargs["c_coords"] = [ - XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_0, - XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_1, - XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_2, - XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_3, - ] - - images["morphology_focus"] = _get_images( - morphology_focus_dir, - files, - imread_kwargs, - image_models_kwargs, - ) - - return SpatialData(images=images, points=points) - - -def _get_points(path: Path, specs: dict[str, Any]): - table = read_parquet(path / XeniumKeys.TRANSCRIPTS_FILE) - table["feature_name"] = table["feature_name"].apply( - lambda x: x.decode("utf-8") if isinstance(x, bytes) else str(x), - meta=("feature_name", "object"), - ) - - transform = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y")) - points = PointsModel.parse( - table, - coordinates={ - "x": XeniumKeys.TRANSCRIPTS_X, - "y": XeniumKeys.TRANSCRIPTS_Y, - "z": XeniumKeys.TRANSCRIPTS_Z, - }, - feature_key=XeniumKeys.FEATURE_NAME, - instance_key=XeniumKeys.CELL_ID, - transformations={"global": transform}, - ) - return points - -def _get_images( - path: Path, - file: str | list[str], - imread_kwargs: dict, - image_models_kwargs: dict, -): - if isinstance(file, list): - image = da.concatenate([imread(f, **imread_kwargs) for f in file], 
axis=0) - else: - image = imread(path / file, **imread_kwargs) - return Image2DModel.parse( - image, - transformations={"global": Identity()}, - dims=("c", "y", "x"), - **image_models_kwargs, + return xenium_spatialdata_io( + path, + cells_table=False, + aligned_images=False, + morphology_mip=False, + nucleus_labels=False, + cells_labels=False, + cells_as_circles=False, + nucleus_boundaries=False, + cells_boundaries=False, + image_models_kwargs=image_models_kwargs, + imread_kwargs=imread_kwargs, )