Skip to content

Commit

Permalink
fix Xenium reader issue for recent machine versions (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
quentinblampey committed Jul 3, 2024
1 parent 2f7cbce commit ec4f20e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 261 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
### Changed
- Use `DataArray` and `DataTree` typing instead of (Multiscale)SpatialImage (as in `spatialdata>=0.2.0`)

### Fix
- Fix Xenium reader issue for recent machine versions (#80)

### Changed
- Fully depends on `spatialdata-io` for the MERSCOPE and the Xenium reader

## [1.1.0] - 2024-06-11

First post-publication release
Expand Down
175 changes: 15 additions & 160 deletions sopa/io/reader/merscope.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,27 @@
# Updated from spatialdata-io: https://spatialdata.scverse.org/projects/io/en/latest/
# In the future, we will completely rely on spatialdata-io

from __future__ import annotations

import logging
import re
import warnings
from pathlib import Path
from typing import Callable
from typing import Literal

import dask.array as da
import dask.dataframe as dd
import numpy as np
import xarray
from dask_image.imread import imread
from spatialdata import SpatialData
from spatialdata._logging import logger
from spatialdata.models import Image2DModel, PointsModel
from spatialdata.transformations import Affine, Identity
from spatialdata_io._constants._constants import MerscopeKeys
from spatialdata_io.readers.merscope import merscope as merscope_spatialdata_io

from .utils import _default_image_kwargs

log = logging.getLogger(__name__)


SUPPORTED_BACKENDS = ["dask_image", "rioxarray"]


def merscope(
path: str | Path,
backend: str = None,
backend: Literal["dask_image", "rioxarray"] | None = None,
z_layers: int | list[int] | None = 3,
region_name: str | None = None,
slide_name: str | None = None,
image_models_kwargs: dict | None = None,
imread_kwargs: dict | None = None,
) -> SpatialData:
"""Read MERSCOPE data as a `SpatialData` object.
"""Read MERSCOPE data as a `SpatialData` object. For more information, refer to [spatialdata-io](https://spatialdata.scverse.org/projects/io/en/latest/generated/spatialdata_io.merscope.html).
This function reads the following files:
- `detected_transcripts.csv`: transcripts locations and names
Expand All @@ -46,7 +30,7 @@ def merscope(
Args:
path: Path to the MERSCOPE directory containing all the experiment files
backend: Either `"dask_image"` or `"rioxarray"` (the latter uses less RAM). By default, uses `"rioxarray"` if and only if the `rioxarray` library is installed.
backend: Either `"dask_image"` or `"rioxarray"` (the latter uses less RAM, but requires `rioxarray` to be installed). By default, uses `"rioxarray"` if and only if the `rioxarray` library is installed.
z_layers: Indices of the z-layers to consider. Either one `int` index, or a list of `int` indices. If `None`, then no image is loaded. By default, only the middle layer is considered (that is, layer 3).
region_name: Name of the region of interest, e.g., `'region_0'`. If `None` then the name of the `path` directory is used.
slide_name: Name of the slide/run. If `None` then the name of the parent directory of `path` is used (whose name starts with a date).
Expand All @@ -56,145 +40,16 @@ def merscope(
Returns:
A `SpatialData` object representing the MERSCOPE experiment
"""
assert (
backend is None or backend in SUPPORTED_BACKENDS
), f"Backend '{backend} not supported. Should be one of: {', '.join(SUPPORTED_BACKENDS)}"

path = Path(path).absolute()
image_models_kwargs, imread_kwargs = _default_image_kwargs(image_models_kwargs, imread_kwargs)

images_dir = path / MerscopeKeys.IMAGES_DIR

microns_to_pixels = Affine(
np.genfromtxt(images_dir / MerscopeKeys.TRANSFORMATION_FILE),
input_axes=("x", "y"),
output_axes=("x", "y"),
)

vizgen_region = path.name if region_name is None else region_name
slide_name = path.parent.name if slide_name is None else slide_name
dataset_id = f"{slide_name}_{vizgen_region}"

# Images
images = {}

z_layers = [z_layers] if isinstance(z_layers, int) else z_layers or []

stainings = _get_channel_names(images_dir)
image_transformations = {"microns": microns_to_pixels.inverse()}

reader = _get_reader(backend)

if stainings:
for z_layer in z_layers:
images[f"{dataset_id}_z{z_layer}"] = reader(
images_dir,
stainings,
z_layer,
image_models_kwargs,
image_transformations,
**imread_kwargs,
)

# Transcripts
points = {}
transcript_path = path / MerscopeKeys.TRANSCRIPTS_FILE
if transcript_path.exists():
points[f"{dataset_id}_transcripts"] = _get_points(transcript_path)
else:
logger.warning(
f"Transcript file {transcript_path} does not exist. Transcripts are not loaded."
)

return SpatialData(points=points, images=images)


def _get_reader(backend: str | None) -> Callable:
if backend is not None:
return _rioxarray_load_merscope if backend == "rioxarray" else _dask_image_load_merscope
try:
import rioxarray # noqa: F401

return _rioxarray_load_merscope
except:
return _dask_image_load_merscope


def _get_channel_names(images_dir: Path) -> list[str]:
exp = r"mosaic_(?P<stain>[\w|-]+[0-9]?)_z(?P<z>[0-9]+).tif"
matches = [re.search(exp, file.name) for file in images_dir.iterdir()]

stainings = {match.group("stain") for match in matches if match}

return list(stainings)


def _rioxarray_load_merscope(
images_dir: Path,
stainings: list[str],
z_layer: int,
image_models_kwargs: dict,
transformations: dict,
**kwargs,
):
log.info("Using rioxarray backend.")

import rioxarray
from rasterio.errors import NotGeoreferencedWarning

warnings.simplefilter("ignore", category=NotGeoreferencedWarning)

im = xarray.concat(
[
rioxarray.open_rasterio(
images_dir / f"mosaic_{stain}_z{z_layer}.tif",
chunks=image_models_kwargs["chunks"],
**kwargs,
)
.rename({"band": "c"})
.reset_coords("spatial_ref", drop=True)
for stain in stainings
],
dim="c",
)

return Image2DModel.parse(
im, transformations=transformations, c_coords=stainings, **image_models_kwargs
)


def _dask_image_load_merscope(
images_dir: Path,
stainings: list[str],
z_layer: int,
image_models_kwargs: dict,
transformations: dict,
**kwargs,
):
im = da.stack(
[
imread(images_dir / f"mosaic_{stain}_z{z_layer}.tif", **kwargs).squeeze()
for stain in stainings
],
axis=0,
)

return Image2DModel.parse(
im,
dims=("c", "y", "x"),
transformations=transformations,
c_coords=stainings,
**image_models_kwargs,
)


def _get_points(transcript_path: Path):
transcript_df = dd.read_csv(transcript_path)
transcripts = PointsModel.parse(
transcript_df,
coordinates={"x": MerscopeKeys.GLOBAL_X, "y": MerscopeKeys.GLOBAL_Y},
feature_key="gene",
transformations={"microns": Identity()},
return merscope_spatialdata_io(
path,
backend=backend,
z_layers=z_layers,
region_name=region_name,
slide_name=slide_name,
image_models_kwargs=image_models_kwargs,
imread_kwargs=imread_kwargs,
cells_boundaries=False,
cells_table=False,
)
transcripts["gene"] = transcripts["gene"].astype("category")
return transcripts
114 changes: 13 additions & 101 deletions sopa/io/reader/xenium.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,10 @@
# Updated from spatialdata-io: https://spatialdata.scverse.org/projects/io/en/latest/
# In the future, we will completely rely on spatialdata-io

from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Any

import dask.array as da
import packaging.version
from dask.dataframe import read_parquet
from dask_image.imread import imread
from spatialdata import SpatialData
from spatialdata.models import Image2DModel, PointsModel
from spatialdata.transformations import Identity, Scale
from spatialdata_io._constants._constants import XeniumKeys
from spatialdata_io.readers.xenium import _parse_version_of_xenium_analyzer
from spatialdata_io.readers.xenium import xenium as xenium_spatialdata_io

from .utils import _default_image_kwargs

Expand Down Expand Up @@ -44,94 +32,18 @@ def xenium(
Returns:
A `SpatialData` object representing the Xenium experiment
"""
path = Path(path)
image_models_kwargs, imread_kwargs = _default_image_kwargs(image_models_kwargs, imread_kwargs)
image_models_kwargs["c_coords"] = [XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_0.value]

with open(path / XeniumKeys.XENIUM_SPECS) as f:
specs = json.load(f)

points = {"transcripts": _get_points(path, specs)}

with open(path / XeniumKeys.XENIUM_SPECS) as f:
specs = json.load(f)
# to trigger the warning if the version cannot be parsed
version = _parse_version_of_xenium_analyzer(specs, hide_warning=False)

images = {}
if version is None or version < packaging.version.parse("2.0.0"):
images["morphology_focus"] = _get_images(
path,
XeniumKeys.MORPHOLOGY_FOCUS_FILE,
imread_kwargs,
image_models_kwargs,
)
else:
morphology_focus_dir = path / XeniumKeys.MORPHOLOGY_FOCUS_DIR
files = [
morphology_focus_dir / XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_IMAGE.value.format(i)
for i in range(4)
]
files = [f for f in files if f.exists()]
if len(files) not in [1, 4]:
raise ValueError(
"Expected 1 (no segmentation kit) or 4 (segmentation kit) files in the morphology focus directory, "
f"found {len(files)}: {files}"
)

if len(files) == 4:
image_models_kwargs["c_coords"] = [
XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_0,
XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_1,
XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_2,
XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_3,
]

images["morphology_focus"] = _get_images(
morphology_focus_dir,
files,
imread_kwargs,
image_models_kwargs,
)

return SpatialData(images=images, points=points)


def _get_points(path: Path, specs: dict[str, Any]):
table = read_parquet(path / XeniumKeys.TRANSCRIPTS_FILE)
table["feature_name"] = table["feature_name"].apply(
lambda x: x.decode("utf-8") if isinstance(x, bytes) else str(x),
meta=("feature_name", "object"),
)

transform = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y"))
points = PointsModel.parse(
table,
coordinates={
"x": XeniumKeys.TRANSCRIPTS_X,
"y": XeniumKeys.TRANSCRIPTS_Y,
"z": XeniumKeys.TRANSCRIPTS_Z,
},
feature_key=XeniumKeys.FEATURE_NAME,
instance_key=XeniumKeys.CELL_ID,
transformations={"global": transform},
)
return points


def _get_images(
path: Path,
file: str | list[str],
imread_kwargs: dict,
image_models_kwargs: dict,
):
if isinstance(file, list):
image = da.concatenate([imread(f, **imread_kwargs) for f in file], axis=0)
else:
image = imread(path / file, **imread_kwargs)
return Image2DModel.parse(
image,
transformations={"global": Identity()},
dims=("c", "y", "x"),
**image_models_kwargs,
return xenium_spatialdata_io(
path,
cells_table=False,
aligned_images=False,
morphology_mip=False,
nucleus_labels=False,
cells_labels=False,
cells_as_circles=False,
nucleus_boundaries=False,
cells_boundaries=False,
image_models_kwargs=image_models_kwargs,
imread_kwargs=imread_kwargs,
)

0 comments on commit ec4f20e

Please sign in to comment.