diff --git a/src/spatialdata_io/__main__.py b/src/spatialdata_io/__main__.py index c65e46f4..3e01c751 100644 --- a/src/spatialdata_io/__main__.py +++ b/src/spatialdata_io/__main__.py @@ -417,6 +417,8 @@ def visium_hd_wrapper( output: str, dataset_id: str | None = None, filtered_counts_file: bool = True, + load_segmentations_only: bool = True, + load_nucleus_segmentations: bool = False, bin_size: int | list[int] | None = None, bins_as_squares: bool = True, fullres_image_file: str | Path | None = None, @@ -428,6 +430,8 @@ def visium_hd_wrapper( path=input, dataset_id=dataset_id, filtered_counts_file=filtered_counts_file, + load_segmentations_only=load_segmentations_only, + load_nucleus_segmentations=load_nucleus_segmentations, bin_size=bin_size, bins_as_squares=bins_as_squares, fullres_image_file=fullres_image_file, diff --git a/src/spatialdata_io/_constants/_constants.py b/src/spatialdata_io/_constants/_constants.py index 33795f92..0c3db3c6 100644 --- a/src/spatialdata_io/_constants/_constants.py +++ b/src/spatialdata_io/_constants/_constants.py @@ -353,11 +353,16 @@ class VisiumHDKeys(ModeEnum): BIN_PREFIX = "square_" MICROSCOPE_IMAGE = "microscope_image" BINNED_OUTPUTS = "binned_outputs" + SEGMENTATION_OUTPUTS = "segmented_outputs" # counts and locations files FILTERED_COUNTS_FILE = "filtered_feature_bc_matrix.h5" RAW_COUNTS_FILE = "raw_feature_bc_matrix.h5" TISSUE_POSITIONS_FILE = "tissue_positions.parquet" + BARCODE_MAPPINGS_FILE = "barcode_mappings.parquet" + FILTERED_CELL_COUNTS_FILE = "filtered_feature_cell_matrix.h5" + CELL_SEGMENTATION_GEOJSON_PATH = "cell_segmentations.geojson" + NUCLEUS_SEGMENTATION_GEOJSON_PATH = "nucleus_segmentations.geojson" # images IMAGE_HIRES_FILE = "tissue_hires_image.png" @@ -399,3 +404,7 @@ class VisiumHDKeys(ModeEnum): MICROSCOPE_COLROW_TO_SPOT_COLROW = ("microscope_colrow_to_spot_colrow",) SPOT_COLROW_TO_MICROSCOPE_COLROW = ("spot_colrow_to_microscope_colrow",) FILE_FORMAT = "file_format" + + # Cell Segmentation keys + CELL_SEG_KEY_HD = 'cell_segmentations' + NUCLEUS_SEG_KEY_HD = 'nucleus_segmentations' diff --git a/src/spatialdata_io/readers/visium_hd.py b/src/spatialdata_io/readers/visium_hd.py index 6a5f3bf6..1534f0db 100644 --- a/src/spatialdata_io/readers/visium_hd.py +++ b/src/spatialdata_io/readers/visium_hd.py @@ -10,11 +10,14 @@ import h5py import numpy as np import pandas as pd +import pyarrow.parquet as pq import scanpy as sc from dask_image.imread import imread from geopandas import GeoDataFrame from imageio import imread as imread2 from numpy.random import default_rng +from scipy.sparse import csc_matrix +from shapely.geometry import Polygon from skimage.transform import ProjectiveTransform, warp from spatialdata import ( SpatialData, @@ -32,18 +35,20 @@ if TYPE_CHECKING: from collections.abc import Mapping + import anndata from multiscale_spatial_image import MultiscaleSpatialImage from spatial_image import SpatialImage from spatialdata._types import ArrayLike RNG = default_rng(0) - @inject_docs(vx=VisiumHDKeys) def visium_hd( path: str | Path, dataset_id: str | None = None, filtered_counts_file: bool = True, + load_segmentations_only: bool = True, + load_nucleus_segmentations: bool = False, bin_size: int | list[int] | None = None, bins_as_squares: bool = True, annotate_table_by_labels: bool = False, @@ -56,10 +61,6 @@ def visium_hd( ) -> SpatialData: """Read *10x Genomics* Visium HD formatted dataset. - .. seealso:: - - - `Space Ranger output `_. 
- Parameters ---------- path @@ -70,6 +71,9 @@ def visium_hd( filtered_counts_file It sets the value of `counts_file` to ``{vx.FILTERED_COUNTS_FILE!r}`` (when `True`) or to ``{vx.RAW_COUNTS_FILE!r}`` (when `False`). + load_segmentations_only + If `True`, only the segmented cell boundaries and their associated counts will be loaded. All binned data + will be skipped. bin_size When specified, load the data of a specific bin size, or a list of bin sizes. By default, it loads all the available bin sizes. @@ -105,11 +109,21 @@ def visium_hd( images: dict[str, Any] = {} labels: dict[str, Any] = {} + # Check for segmentation files + SEGMENTED_OUTPUTS_PATH = path / VisiumHDKeys.SEGMENTATION_OUTPUTS + COUNT_MATRIX_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys.FILTERED_CELL_COUNTS_FILE + CELL_GEOJSON_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys.CELL_SEGMENTATION_GEOJSON_PATH + NUCLEUS_GEOJSON_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys.NUCLEUS_SEGMENTATION_GEOJSON_PATH + SCALE_FACTORS_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys.SPATIAL / VisiumHDKeys.SCALEFACTORS_FILE + BARCODE_MAPPINGS_PATH = next((file for file in path.rglob("*") if file.name.endswith(VisiumHDKeys.BARCODE_MAPPINGS_FILE)), None) + FILTERED_MATRIX_2U_PATH = path / VisiumHDKeys.BINNED_OUTPUTS / f"{VisiumHDKeys.BIN_PREFIX}002um" / VisiumHDKeys.FILTERED_COUNTS_FILE + cell_segmentation_files_exist = COUNT_MATRIX_PATH.exists() and CELL_GEOJSON_PATH.exists() and SCALE_FACTORS_PATH.exists() + nucleus_segmentation_files_exist = NUCLEUS_GEOJSON_PATH.exists() and (BARCODE_MAPPINGS_PATH is not None and BARCODE_MAPPINGS_PATH.exists()) and FILTERED_MATRIX_2U_PATH.exists() + if dataset_id is None: dataset_id = _infer_dataset_id(path) filename_prefix = _get_filename_prefix(path, dataset_id) - def load_image(path: Path, suffix: str, scale_factors: list[int] | None = None) -> None: _load_image( path=path, @@ -131,154 +145,195 @@ def load_image(path: Path, suffix: str, scale_factors: list[int] | None = None) stacklevel=2, ) - def _get_bins(path_bins: Path) -> list[str]: - return sorted( - [ - bin_size.name - for bin_size in path_bins.iterdir() - if bin_size.is_dir() and bin_size.name.startswith(VisiumHDKeys.BIN_PREFIX) - ] - ) - - all_path_bins = [path_bin for path_bin in all_files if VisiumHDKeys.BINNED_OUTPUTS in str(path_bin)] - if len(all_path_bins) != 0: - path_bins_parts = all_path_bins[ - -1 - ].parts # just choosing last one here as users might have tar file which would be first - path_bins = Path(*path_bins_parts[: path_bins_parts.index(VisiumHDKeys.BINNED_OUTPUTS) + 1]) - else: - path_bins = path - all_bin_sizes = _get_bins(path_bins) - - bin_sizes = [] - if bin_size is not None: - if not isinstance(bin_size, list): - bin_size = [bin_size] - bin_sizes = [f"square_{bs:03}um" for bs in bin_size if f"square_{bs:03}um" in all_bin_sizes] - if len(bin_sizes) < len(bin_size): - warnings.warn( - f"Requested bin size {bin_size} (available {all_bin_sizes}); ignoring the bin sizes that are not " - "found.", - UserWarning, - stacklevel=2, + # Load Binned Data (skipped if load_segmentations_only is True) + if not load_segmentations_only: + def _get_bins(path_bins: Path) -> list[str]: + return sorted( + [ + bin_size.name + for bin_size in path_bins.iterdir() + if bin_size.is_dir() and bin_size.name.startswith(VisiumHDKeys.BIN_PREFIX) + ] ) - if bin_size is None or bin_sizes == []: - bin_sizes = all_bin_sizes - - # iterate over the given bins and load the data - for bin_size_str in bin_sizes: - path_bin = path_bins / bin_size_str - counts_file = 
VisiumHDKeys.FILTERED_COUNTS_FILE if filtered_counts_file else VisiumHDKeys.RAW_COUNTS_FILE - adata = sc.read_10x_h5( - path_bin / counts_file, - gex_only=False, - **anndata_kwargs, - ) - path_bin_spatial = path_bin / VisiumHDKeys.SPATIAL + all_path_bins = [path_bin for path_bin in all_files if VisiumHDKeys.BINNED_OUTPUTS in str(path_bin)] + if len(all_path_bins) != 0: + path_bins_parts = all_path_bins[ + -1 + ].parts + path_bins = Path(*path_bins_parts[: path_bins_parts.index(VisiumHDKeys.BINNED_OUTPUTS) + 1]) + else: + path_bins = path + all_bin_sizes = _get_bins(path_bins) + + bin_sizes = [] + if bin_size is not None: + if not isinstance(bin_size, list): + bin_size = [bin_size] + bin_sizes = [f"square_{bs:03}um" for bs in bin_size if f"square_{bs:03}um" in all_bin_sizes] + if len(bin_sizes) < len(bin_size): + warnings.warn( + f"Requested bin size {bin_size} (available {all_bin_sizes}); ignoring the bin sizes that are not " + "found.", + UserWarning, + stacklevel=2, + ) + if bin_size is None or bin_sizes == []: + bin_sizes = all_bin_sizes - with open(path_bin_spatial / VisiumHDKeys.SCALEFACTORS_FILE) as file: - scalefactors = json.load(file) + for bin_size_str in bin_sizes: + path_bin = path_bins / bin_size_str + counts_file = VisiumHDKeys.FILTERED_COUNTS_FILE if filtered_counts_file else VisiumHDKeys.RAW_COUNTS_FILE + adata = sc.read_10x_h5( + path_bin / counts_file, + gex_only=False, + **anndata_kwargs, + ) - # consistency check - found_bin_size = re.search(r"\d{3}", bin_size_str) - assert found_bin_size is not None - assert float(found_bin_size.group()) == scalefactors[VisiumHDKeys.SCALEFACTORS_BIN_SIZE_UM] - assert np.isclose( - scalefactors[VisiumHDKeys.SCALEFACTORS_BIN_SIZE_UM] - / scalefactors[VisiumHDKeys.SCALEFACTORS_SPOT_DIAMETER_FULLRES], - scalefactors[VisiumHDKeys.SCALEFACTORS_MICRONS_PER_PIXEL], - ) + path_bin_spatial = path_bin / VisiumHDKeys.SPATIAL - tissue_positions_file = path_bin_spatial / VisiumHDKeys.TISSUE_POSITIONS_FILE - - # read coordinates and set up adata.obs and adata.obsm - coords = pd.read_parquet(tissue_positions_file) - assert all( - coords.columns.values - == [ - VisiumHDKeys.BARCODE, - VisiumHDKeys.IN_TISSUE, - VisiumHDKeys.ARRAY_ROW, - VisiumHDKeys.ARRAY_COL, - VisiumHDKeys.LOCATIONS_Y, - VisiumHDKeys.LOCATIONS_X, - ] - ) - coords.set_index(VisiumHDKeys.BARCODE, inplace=True, drop=True) - coords_filtered = coords.loc[adata.obs.index] - adata.obs = pd.merge(adata.obs, coords_filtered, how="left", left_index=True, right_index=True) - # compatibility to legacy squidpy - adata.obsm["spatial"] = adata.obs[[VisiumHDKeys.LOCATIONS_X, VisiumHDKeys.LOCATIONS_Y]].values - # dropping the spatial coordinates (will be stored in shapes) - adata.obs.drop( - columns=[ - VisiumHDKeys.LOCATIONS_X, - VisiumHDKeys.LOCATIONS_Y, - ], - inplace=True, - ) - adata.obs[VisiumHDKeys.INSTANCE_KEY] = np.arange(len(adata)) + with open(path_bin_spatial / VisiumHDKeys.SCALEFACTORS_FILE) as file: + scalefactors = json.load(file) - # scaling - transform_original = Identity() - transform_lowres = Scale( - np.array( - [ - scalefactors[VisiumHDKeys.SCALEFACTORS_LOWRES], - scalefactors[VisiumHDKeys.SCALEFACTORS_LOWRES], - ] - ), - axes=("x", "y"), - ) - transform_hires = Scale( - np.array( - [ - scalefactors[VisiumHDKeys.SCALEFACTORS_HIRES], - scalefactors[VisiumHDKeys.SCALEFACTORS_HIRES], - ] - ), - axes=("x", "y"), - ) - # parse shapes - shapes_name = dataset_id + "_" + bin_size_str - radius = scalefactors[VisiumHDKeys.SCALEFACTORS_SPOT_DIAMETER_FULLRES] / 2.0 - transformations = { - 
dataset_id: transform_original, - f"{dataset_id}_downscaled_hires": transform_hires, - f"{dataset_id}_downscaled_lowres": transform_lowres, - } - circles = ShapesModel.parse( - adata.obsm["spatial"], - geometry=0, - radius=radius, - index=adata.obs[VisiumHDKeys.INSTANCE_KEY].copy(), - transformations=transformations, - ) - if not bins_as_squares: - shapes[shapes_name] = circles - else: - squares_series = circles.buffer(radius, cap_style=3) - shapes[shapes_name] = ShapesModel.parse( - GeoDataFrame(geometry=squares_series), transformations=transformations + found_bin_size = re.search(r"\d{3}", bin_size_str) + assert found_bin_size is not None + assert float(found_bin_size.group()) == scalefactors[VisiumHDKeys.SCALEFACTORS_BIN_SIZE_UM] + assert np.isclose( + scalefactors[VisiumHDKeys.SCALEFACTORS_BIN_SIZE_UM] + / scalefactors[VisiumHDKeys.SCALEFACTORS_SPOT_DIAMETER_FULLRES], + scalefactors[VisiumHDKeys.SCALEFACTORS_MICRONS_PER_PIXEL], ) - # parse table - adata.obs[VisiumHDKeys.REGION_KEY] = shapes_name - adata.obs[VisiumHDKeys.REGION_KEY] = adata.obs[VisiumHDKeys.REGION_KEY].astype("category") - - tables[bin_size_str] = TableModel.parse( - adata, - region=shapes_name, - region_key=str(VisiumHDKeys.REGION_KEY), - instance_key=str(VisiumHDKeys.INSTANCE_KEY), + tissue_positions_file = path_bin_spatial / VisiumHDKeys.TISSUE_POSITIONS_FILE + + coords = pd.read_parquet(tissue_positions_file) + assert all( + coords.columns.values + == [ + VisiumHDKeys.BARCODE, + VisiumHDKeys.IN_TISSUE, + VisiumHDKeys.ARRAY_ROW, + VisiumHDKeys.ARRAY_COL, + VisiumHDKeys.LOCATIONS_Y, + VisiumHDKeys.LOCATIONS_X, + ] + ) + coords.set_index(VisiumHDKeys.BARCODE, inplace=True, drop=True) + coords_filtered = coords.loc[adata.obs.index] + adata.obs = pd.merge(adata.obs, coords_filtered, how="left", left_index=True, right_index=True) + adata.obsm["spatial"] = adata.obs[[VisiumHDKeys.LOCATIONS_X, VisiumHDKeys.LOCATIONS_Y]].values + adata.obs.drop( + columns=[ + VisiumHDKeys.LOCATIONS_X, + VisiumHDKeys.LOCATIONS_Y, + ], + inplace=True, + ) + adata.obs[VisiumHDKeys.INSTANCE_KEY] = np.arange(len(adata)) + + transform_lowres = Scale( + np.array( + [ + scalefactors[VisiumHDKeys.SCALEFACTORS_LOWRES], + scalefactors[VisiumHDKeys.SCALEFACTORS_LOWRES], + ] + ), + axes=("x", "y"), + ) + transform_hires = Scale( + np.array( + [ + scalefactors[VisiumHDKeys.SCALEFACTORS_HIRES], + scalefactors[VisiumHDKeys.SCALEFACTORS_HIRES], + ] + ), + axes=("x", "y"), + ) + shapes_name = dataset_id + "_" + bin_size_str + radius = scalefactors[VisiumHDKeys.SCALEFACTORS_SPOT_DIAMETER_FULLRES] / 2.0 + + # Here we ensure that only the correct coordinate systems are created for the binned data + transformations = { + f"{dataset_id}_downscaled_hires": transform_hires, + f"{dataset_id}_downscaled_lowres": transform_lowres, + } + circles = ShapesModel.parse( + adata.obsm["spatial"], + geometry=0, + radius=radius, + index=adata.obs[VisiumHDKeys.INSTANCE_KEY].copy(), + transformations=transformations, + ) + if not bins_as_squares: + shapes[shapes_name] = circles + else: + squares_series = circles.buffer(radius, cap_style=3) + shapes[shapes_name] = ShapesModel.parse( + GeoDataFrame(geometry=squares_series), transformations=transformations + ) + + adata.obs[VisiumHDKeys.REGION_KEY] = shapes_name + adata.obs[VisiumHDKeys.REGION_KEY] = adata.obs[VisiumHDKeys.REGION_KEY].astype("category") + + tables[bin_size_str] = TableModel.parse( + adata, + region=shapes_name, + region_key=str(VisiumHDKeys.REGION_KEY), + instance_key=str(VisiumHDKeys.INSTANCE_KEY), + ) + if 
var_names_make_unique: + tables[bin_size_str].var_names_make_unique() + + # Integrate the segmentation data (skipped if segmentation files are not found) + if cell_segmentation_files_exist: + print("Found segmentation data. Incorporating cell_segmentations.") + cell_adata_hd = sc.read_10x_h5(COUNT_MATRIX_PATH) + cell_adata_hd.var_names_make_unique() + + shapes_transformations_hd = _make_shapes_transformation(scale_factors_path=SCALE_FACTORS_PATH, dataset_id=dataset_id) # Used for both cell and nucleus segmentations + cell_geojson_features_map = _make_geojson_features_map(CELL_GEOJSON_PATH) + cell_shapes_gdf = _extract_geometries_from_geojson(cell_adata_hd, geojson_features_map=cell_geojson_features_map) + + SHAPES_KEY_HD = f"{dataset_id}_{VisiumHDKeys.CELL_SEG_KEY_HD}" + cell_adata_hd.obs['cell_id'] = cell_adata_hd.obs.index + cell_adata_hd.obs['region'] = SHAPES_KEY_HD + cell_adata_hd.obs['region'] = cell_adata_hd.obs['region'].astype('category') + cell_adata_hd = cell_adata_hd[cell_shapes_gdf.index].copy() + + shapes[SHAPES_KEY_HD] = ShapesModel.parse(cell_shapes_gdf, transformations=shapes_transformations_hd) + tables[VisiumHDKeys.CELL_SEG_KEY_HD] = TableModel.parse( + cell_adata_hd, + region=SHAPES_KEY_HD, + region_key='region', + instance_key='cell_id' ) - if var_names_make_unique: - tables[bin_size_str].var_names_make_unique() - # read full resolution image + # load nucleus segmentations if available + if nucleus_segmentation_files_exist and load_nucleus_segmentations: + print("Found nucleus segmentation data. Incorporating nucleus_segmentations.") + + nucleus_adata_hd = _make_filtered_nucleus_adata(filtered_matrix_h5_path=FILTERED_MATRIX_2U_PATH,barcode_mappings_parquet_path=BARCODE_MAPPINGS_PATH) + geojson_features_map = _make_geojson_features_map(NUCLEUS_GEOJSON_PATH) + nucleus_shapes_gdf = _extract_geometries_from_geojson(adata=nucleus_adata_hd, geojson_features_map=geojson_features_map) + + SHAPES_KEY_HD = f"{dataset_id}_{VisiumHDKeys.NUCLEUS_SEG_KEY_HD}" + nucleus_adata_hd.obs['cell_id'] = nucleus_adata_hd.obs.index + nucleus_adata_hd.obs['region'] = SHAPES_KEY_HD + nucleus_adata_hd.obs['region'] = nucleus_adata_hd.obs['region'].astype('category') + nucleus_adata_hd = nucleus_adata_hd[nucleus_shapes_gdf.index].copy() + + shapes[SHAPES_KEY_HD] = ShapesModel.parse(nucleus_shapes_gdf, transformations=shapes_transformations_hd) + tables[VisiumHDKeys.NUCLEUS_SEG_KEY_HD] = TableModel.parse( + nucleus_adata_hd, + region=SHAPES_KEY_HD, + region_key='region', + instance_key='cell_id' + ) + + # Read all images and add transformations for both binning and segmentation + fullres_image_file_paths = [] if fullres_image_file is not None: - fullres_image_file = Path(fullres_image_file) + fullres_image_file_paths.append(Path(fullres_image_file)) else: path_fullres = path / VisiumHDKeys.MICROSCOPE_IMAGE if path_fullres.exists(): @@ -305,62 +360,76 @@ def _get_bins(path_bins: Path) -> list[str]: if fullres_image_file is not None: load_image( - path=fullres_image_file, + path=fullres_image_file_paths[0], suffix="_full_image", scale_factors=[2, 2, 2, 2], ) + else: + warnings.warn( + "No full resolution image found. 
If incorrect, please specify the path in the " + "`fullres_image_file` parameter when calling the `visium_hd` reader function.", + UserWarning, + stacklevel=2, + ) - # hires image hires_image_path = [path for path in all_files if VisiumHDKeys.IMAGE_HIRES_FILE in str(path)] - if len(hires_image_path) == 0: + if len(hires_image_path) > 0: + load_image( + path=hires_image_path[0], + suffix="_hires_image", + ) + if not load_segmentations_only and "transform_hires" in locals(): + set_transformation( + images[dataset_id + "_hires_image"], + { + f"{dataset_id}_downscaled_hires": Identity(), + dataset_id: transform_hires.inverse(), + }, + set_all=True, + ) + if cell_segmentation_files_exist: + set_transformation( + images[dataset_id + "_hires_image"], + {f"{dataset_id}_downscaled_hires": Identity()}, + set_all=True, + ) + else: warnings.warn( f"No image path found containing the hires image: {VisiumHDKeys.IMAGE_HIRES_FILE}", UserWarning, stacklevel=2, ) - load_image( - path=hires_image_path[0], - suffix="_hires_image", - ) - set_transformation( - images[dataset_id + "_hires_image"], - { - f"{dataset_id}_downscaled_hires": Identity(), - dataset_id: transform_hires.inverse(), - }, - set_all=True, - ) - # lowres image lowres_image_path = [path for path in all_files if VisiumHDKeys.IMAGE_LOWRES_FILE in str(path)] - if len(lowres_image_path) == 0: + if len(lowres_image_path) > 0: + load_image( + path=lowres_image_path[0], + suffix="_lowres_image", + ) + if not load_segmentations_only and "transform_lowres" in locals(): + set_transformation( + images[dataset_id + "_lowres_image"], + { + f"{dataset_id}_downscaled_lowres": Identity(), + dataset_id: transform_lowres.inverse(), + }, + set_all=True, + ) + if cell_segmentation_files_exist: + set_transformation( + images[dataset_id + "_lowres_image"], + {f"{dataset_id}_downscaled_lowres": Identity()}, + set_all=True, + ) + else: warnings.warn( f"No image path found containing the lowres image: {VisiumHDKeys.IMAGE_LOWRES_FILE}", UserWarning, stacklevel=2, ) - load_image( - path=lowres_image_path[0], - suffix="_lowres_image", - ) - set_transformation( - images[dataset_id + "_lowres_image"], - { - f"{dataset_id}_downscaled_lowres": Identity(), - dataset_id: transform_lowres.inverse(), - }, - set_all=True, - ) - # cytassist image cytassist_path = [path for path in all_files if VisiumHDKeys.IMAGE_CYTASSIST in str(path)] - if len(cytassist_path) == 0: - warnings.warn( - f"No image path found containing the cytassist image: {VisiumHDKeys.IMAGE_CYTASSIST}", - UserWarning, - stacklevel=2, - ) - if load_all_images: + if load_all_images and len(cytassist_path) > 0: load_image( path=cytassist_path[0], suffix="_cytassist_image", @@ -373,21 +442,17 @@ def _get_bins(path_bins: Path) -> list[str]: projective /= projective[2, 2] if _projective_matrix_is_affine(projective): affine = Affine(projective, input_axes=("x", "y"), output_axes=("x", "y")) - set_transformation(image, affine, dataset_id) + if not load_segmentations_only: + set_transformation(image, affine, dataset_id) else: - # the projective matrix is not affine, we will separate the affine part and the projective shift, and apply - # the projective shift to the image affine_matrix, projective_shift = _decompose_projective_matrix(projective) affine = Affine(affine_matrix, input_axes=("x", "y"), output_axes=("x", "y")) - - # determine the size of the transformed image bounding_box = get_extent(image, coordinate_system=dataset_id) x0, x1 = bounding_box["x"] y0, y1 = bounding_box["y"] x1 -= 1 y1 -= 1 corners = [(x0, y0), 
(x1, y0), (x1, y1), (x0, y1)] - transformed_corners = [] for x, y in corners: px, py = _projective_matrix_transform_point(projective_shift, x, y) @@ -399,33 +464,29 @@ def _get_bins(path_bins: Path) -> list[str]: np.max(transformed_corners_array[:, 0]), np.max(transformed_corners_array[:, 1]), ) - # the first two components are <= 0, we just discard them since the cytassist image has a lot of padding - # and therefore we can safely discard pixels with negative coordinates transformed_shape = (np.ceil(transformed_bounds[2]), np.ceil(transformed_bounds[3])) - - # flip xy transformed_shape = (transformed_shape[1], transformed_shape[0]) - - # the cytassist image is a small, single-scale image, so we can compute it in memory numpy_data = image.transpose("y", "x", "c").data.compute() warped = warp( numpy_data, ProjectiveTransform(projective_shift).inverse, output_shape=transformed_shape, order=1 ) warped = np.round(warped * 255).astype(np.uint8) - warped = Image2DModel.parse(warped, dims=("y", "x", "c"), transformations={dataset_id: affine}, rgb=True) - - # we replace the cytassist image with the warped image - images[dataset_id + "_cytassist_image"] = warped + if not load_segmentations_only: + warped = Image2DModel.parse(warped, dims=("y", "x", "c"), transformations={dataset_id: affine}, rgb=True) + images[dataset_id + "_cytassist_image"] = warped + elif load_all_images: + warnings.warn( + f"No image path found containing the cytassist image: {VisiumHDKeys.IMAGE_CYTASSIST}", + UserWarning, + stacklevel=2, + ) sdata = SpatialData(tables=tables, images=images, shapes=shapes, labels=labels) - if annotate_table_by_labels: + if annotate_table_by_labels and not load_segmentations_only: for bin_size_str in bin_sizes: shapes_name = dataset_id + "_" + bin_size_str - - # add labels layer (rasterized bins). labels_name = f"{dataset_id}_{bin_size_str}_labels" - labels_element = rasterize_bins( sdata, bins=shapes_name, @@ -435,7 +496,6 @@ def _get_bins(path_bins: Path) -> list[str]: value_key=None, return_region_as_labels=True, ) - sdata[labels_name] = labels_element rasterize_bins_link_table_to_labels( sdata=sdata, table_name=bin_size_str, rasterized_labels_name=labels_name @@ -584,3 +644,144 @@ def _get_transform_matrices(metadata: dict[str, Any], hd_layout: dict[str, Any]) transform_matrices[key] = np.array(coefficients).reshape(3, 3) return transform_matrices + +def _make_filtered_nucleus_adata( + filtered_matrix_h5_path: Path, + barcode_mappings_parquet_path: Path, + bin_col_name: str = 'square_002um', + aggregate_col_name: str = 'cell_id' +) -> anndata.AnnData: + """Generate a filtered AnnData object by aggregating 2um binned data based on nucleus segmentation. + + Uses a 2um filtered_feature_bc_matrix.h5 file and a barcode_mappings.parquet file containing + barcode mappings, filters the data to include only valid nucleus mappings, + and aggregates the data based on specified bin into cell IDs which only contain + the 2um square data under segmented nuclei. + + Parameters: + ----------- + filtered_matrix_h5_path : Path + Path to the 10x Genomics HDF5 matrix file. + barcode_mappings_parquet_path : Path + Path to the Parquet file containing barcode mappings. + bin_col_name : str, optional + Column name in the barcode mappings that specifies the spatial bin (default is 'square_002um'). + aggregate_col_name : str, optional + Column name in the barcode mappings that specifies the aggregate cell ID (default is 'cell_id'). 
+
+    Returns
+    -------
+    anndata.AnnData
+        An AnnData object where the observations correspond to filtered cell IDs
+        and the variables correspond to the original features from the input data.
+    """
+    # Read in the necessary files
+    adata_2um = sc.read_10x_h5(filtered_matrix_h5_path)
+    barcode_mappings = pq.read_table(barcode_mappings_parquet_path)
+
+    # Keep only rows with a valid cell ID that also fall inside a segmented nucleus.
+    # The two conditions are applied as separate element-wise filters (a plain Python
+    # `and` between Arrow arrays would not combine them element-wise).
+    barcode_mappings = barcode_mappings.filter(barcode_mappings["cell_id"].is_valid())
+    barcode_mappings = barcode_mappings.filter(barcode_mappings["in_nucleus"])
+
+    # Filter the 2um adata to only include squares present in the barcode mappings
+    valid_squares = barcode_mappings[bin_col_name].unique()
+    squares_to_keep = np.intersect1d(adata_2um.obs_names, valid_squares)
+    adata_filtered = adata_2um[squares_to_keep, :].copy()
+
+    # Map each square to its corresponding cell ID
+    square_to_cell_map = dict(
+        zip(
+            barcode_mappings[bin_col_name].to_pylist(),
+            barcode_mappings[aggregate_col_name].to_pylist(),
+            strict=False,
+        )
+    )
+    ordered_cell_ids = [square_to_cell_map[square] for square in adata_filtered.obs_names]
+    unique_cells = list(dict.fromkeys(ordered_cell_ids).keys())
+    cell_to_idx = {cell: i for i, cell in enumerate(unique_cells)}
+
+    # Make the aggregation matrix
+    col_indices = [cell_to_idx[cell] for cell in ordered_cell_ids]
+    row_indices = np.arange(len(ordered_cell_ids))
+    data = np.ones_like(row_indices)
+
+    aggregation_matrix = csc_matrix(
+        (data, (row_indices, col_indices)),
+        shape=(adata_filtered.n_obs, len(unique_cells)),
+    )
+
+    # Make the final AnnData object where cell IDs are filtered
+    # to the data under the segmented nuclei
+    nucleus_matrix_sparse = adata_filtered.X.T.dot(aggregation_matrix)
+    adata_nucleus = sc.AnnData(nucleus_matrix_sparse.T)
+    adata_nucleus.obs_names = unique_cells
+    adata_nucleus.var = adata_filtered.var
+
+    return adata_nucleus
+
+def _extract_geometries_from_geojson(adata: anndata.AnnData, geojson_features_map: dict[str, Any]) -> GeoDataFrame:
+    """Extract geometries and create a GeoDataFrame from a GeoJSON features map.
+
+    Parameters
+    ----------
+    adata : anndata.AnnData
+        AnnData object containing cell data.
+    geojson_features_map : dict[str, Any]
+        Dictionary mapping cell IDs to GeoJSON features.
+
+    Returns
+    -------
+    GeoDataFrame
+        A GeoDataFrame containing cell IDs and their corresponding geometries.
+    """
+    geometries = []
+    cell_ids_ordered = []
+
+    for obs_index_str in adata.obs.index:
+        feature = geojson_features_map.get(obs_index_str)
+        if feature:
+            polygon_coords = np.array(feature["geometry"]["coordinates"][0])
+            geometries.append(Polygon(polygon_coords))
+            cell_ids_ordered.append(obs_index_str)
+        else:
+            geometries.append(None)
+            cell_ids_ordered.append(obs_index_str)
+
+    # Drop observations for which no segmentation polygon was found.
+    valid_indices = [i for i, geom in enumerate(geometries) if geom is not None]
+    geometries = [geometries[i] for i in valid_indices]
+    cell_ids_ordered = [cell_ids_ordered[i] for i in valid_indices]
+
+    return GeoDataFrame(
+        {"cell_id": cell_ids_ordered, "geometry": geometries},
+        index=cell_ids_ordered,
+    )
+
+def _make_shapes_transformation(scale_factors_path: Path, dataset_id: str) -> dict[str, Scale]:
+    """Load scale factors for lowres and hires images and create transformations.
+
+    Parameters
+    ----------
+    scale_factors_path : Path
+        Path to the scale factors JSON file.
+    dataset_id : str
+        Unique identifier of the dataset.
+
+    Returns
+    -------
+    dict[str, Scale]
+        A dictionary containing the transformations for lowres and hires images.
+ """ + with open(scale_factors_path) as f: + scale_data_hd = json.load(f) + lowres_scale_factor_hd = scale_data_hd['tissue_lowres_scalef'] + hires_scale_factor_hd = scale_data_hd['tissue_hires_scalef'] + + return { + f"{dataset_id}_downscaled_lowres": Scale(np.array([lowres_scale_factor_hd, lowres_scale_factor_hd]), axes=("x", "y")), + f"{dataset_id}_downscaled_hires": Scale(np.array([hires_scale_factor_hd, hires_scale_factor_hd]), axes=("x", "y")) + } + +def _make_geojson_features_map(geojson_path: Path) -> dict[str, Any]: + with open(geojson_path) as f: + geojson_data = json.load(f) + return { + f"cellid_{feature['properties']['cell_id']:09d}-1": feature + for feature in geojson_data['features'] + } diff --git a/tests/test_visium_hd.py b/tests/test_visium_hd.py new file mode 100644 index 00000000..9bb70efe --- /dev/null +++ b/tests/test_visium_hd.py @@ -0,0 +1,177 @@ +import math +from pathlib import Path +from tempfile import TemporaryDirectory + +import numpy as np +import pytest +from click.testing import CliRunner +from spatialdata import get_extent, read_zarr +from spatialdata.models import get_table_keys + +from spatialdata_io.__main__ import visium_hd_wrapper +from spatialdata_io._constants._constants import VisiumHDKeys +from spatialdata_io.readers.visium_hd import ( + _decompose_projective_matrix, + _projective_matrix_is_affine, + visium_hd, +) +from tests._utils import skip_if_below_python_version + +# --- UNIT TESTS FOR HELPER FUNCTIONS --- + +def test_projective_matrix_is_affine() -> None: + """Test the affine matrix check function.""" + # An affine matrix should have [0, 0, 1] as its last row + affine_matrix = np.array([[2, 0.5, 10], [0.5, 2, 20], [0, 0, 1]]) + assert _projective_matrix_is_affine(affine_matrix) + + # A projective matrix is not affine if the last row is different + projective_matrix = np.array([[2, 0.5, 10], [0.5, 2, 20], [0.01, 0.02, 1]]) + assert not _projective_matrix_is_affine(projective_matrix) + + +def test_decompose_projective_matrix() -> None: + """Test the decomposition of a projective matrix into affine and shift components.""" + projective_matrix = np.array([[1, 2, 3], [4, 5, 6], [0.1, 0.2, 1]]) + affine, shift = _decompose_projective_matrix(projective_matrix) + + expected_affine = np.array([[1, 2, 3], [4, 5, 6], [0, 0, 1]]) + + # The affine component should be correctly extracted + assert np.allclose(affine, expected_affine) + # Recomposing the affine and shift matrices should yield the original projective matrix + assert np.allclose(affine @ shift, projective_matrix) + + +# --- END-TO-END TESTS ON EXAMPLE DATA --- + +# TODO: Replace with the actual Visium HD test dataset folder name +# This dataset name is used to locate the test data in the './data/' directory. +# See https://github.com/scverse/spatialdata-io/blob/main/.github/workflows/prepare_test_data.yaml +# for instructions on how to download and place the data on disk. +DATASET_FOLDER = "Visium_HD_Mouse_Brain_Chunk" +DATASET_ID = "visium_hd_tiny" + + +@skip_if_below_python_version() +def test_visium_hd_data_extent() -> None: + """Check the spatial extent of the loaded Visium HD data.""" + f = Path("./data") / DATASET_FOLDER + if not f.is_dir(): + pytest.skip(f"Test data not found at '{f}'. 
Skipping extent test.") + + sdata = visium_hd(f, dataset_id=DATASET_ID) + extent = get_extent(sdata, exact=False) + extent = {ax: (math.floor(extent[ax][0]), math.ceil(extent[ax][1])) for ax in extent} + + # TODO: Replace with the actual expected extent of your test data + expected_extent = "{'x': (1000, 7000), 'y': (2000, 8000)}" + assert str(extent) == expected_extent + + +@skip_if_below_python_version() +@pytest.mark.parametrize( + "params", + [ + # Test case 1: Default binned data loading (squares) + {"load_segmentations_only": False, "load_nucleus_segmentations": False, "bins_as_squares": True, "annotate_table_by_labels": False, "load_all_images": False}, + # Test case 2: Binned data as circles + {"load_segmentations_only": False, "load_nucleus_segmentations": False, "bins_as_squares": False, "annotate_table_by_labels": False, "load_all_images": False}, + # Test case 3: Binned data with tables annotating labels instead of shapes + {"load_segmentations_only": False, "load_nucleus_segmentations": False, "bins_as_squares": True, "annotate_table_by_labels": True, "load_all_images": False}, + # Test case 4: Load binned data AND all segmentations (cell + nucleus) + {"load_segmentations_only": False, "load_nucleus_segmentations": True, "bins_as_squares": True, "annotate_table_by_labels": False, "load_all_images": False}, + # Test case 5: Load cell segmentations only + {"load_segmentations_only": True, "load_nucleus_segmentations": False, "bins_as_squares": True, "annotate_table_by_labels": False, "load_all_images": False}, + # Test case 6: Load all segmentations (cell + nucleus) only + {"load_segmentations_only": True, "load_nucleus_segmentations": True, "bins_as_squares": True, "annotate_table_by_labels": False, "load_all_images": False}, + # Test case 7: Load everything, including auxiliary images like CytAssist + {"load_segmentations_only": False, "load_nucleus_segmentations": True, "bins_as_squares": True, "annotate_table_by_labels": False, "load_all_images": True}, + ], +) +def test_visium_hd_data_integrity(params: dict[str, bool]) -> None: + """Check the integrity of various components of the loaded SpatialData object.""" + f = Path("./data") / DATASET_FOLDER + if not f.is_dir(): + pytest.skip(f"Test data not found at '{f}'. 
Skipping integrity test.") + + sdata = visium_hd(f, dataset_id=DATASET_ID, **params) + + # --- IMAGE CHECKS --- + assert f"{DATASET_ID}_full_image" in sdata.images + assert f"{DATASET_ID}_hires_image" in sdata.images + assert f"{DATASET_ID}_lowres_image" in sdata.images + if params.get("load_all_images", False): + assert f"{DATASET_ID}_cytassist_image" in sdata.images + + # --- SEGMENTATION CHECKS (loaded in all modes if present) --- + # TODO: Update placeholder values with actual data from your test dataset + assert VisiumHDKeys.CELL_SEG_KEY_HD in sdata.tables + assert f"{DATASET_ID}_{VisiumHDKeys.CELL_SEG_KEY_HD}" in sdata.shapes + cell_table = sdata.tables[VisiumHDKeys.CELL_SEG_KEY_HD] + assert cell_table.shape == (2485, 36738) # Example shape (n_obs, n_vars) + assert "cellid_000000001-1" in cell_table.obs_names # Example cell ID + + if params["load_nucleus_segmentations"]: + assert VisiumHDKeys.NUCLEUS_SEG_KEY_HD in sdata.tables + assert f"{DATASET_ID}_{VisiumHDKeys.NUCLEUS_SEG_KEY_HD}" in sdata.shapes + nuc_table = sdata.tables[VisiumHDKeys.NUCLEUS_SEG_KEY_HD] + assert nuc_table.shape == (2485, 36738) # Example shape + else: + assert VisiumHDKeys.NUCLEUS_SEG_KEY_HD not in sdata.tables + + # --- BINNED DATA CHECKS --- + if params["load_segmentations_only"]: + assert "square_002um" not in sdata.tables + else: + assert "square_008um" in sdata.tables + table = sdata.tables["square_008um"] + assert table.shape == (39000, 36738) # Example shape + assert "AAACCGGGTTTA-1" in table.obs_names # Example barcode + assert np.array_equal(table.X.indices[:3], [10, 20, 30]) # Example indices + + shape_name = f"{DATASET_ID}_square_008um" + labels_name = f"{shape_name}_labels" + if params["annotate_table_by_labels"]: + assert labels_name in sdata.labels + region, _, _ = get_table_keys(table) + assert region == labels_name + else: + assert shape_name in sdata.shapes + region, _, _ = get_table_keys(table) + assert region == shape_name + # Check for circles vs. squares + if params["bins_as_squares"]: + assert "radius" not in sdata.shapes[shape_name] + else: + assert "radius" in sdata.shapes[shape_name] + +# --- CLI WRAPPER TEST --- + +@skip_if_below_python_version() +def test_cli_visium_hd(runner: CliRunner) -> None: + """Test the command-line interface for the Visium HD reader.""" + f = Path("./data") / DATASET_FOLDER + if not f.is_dir(): + pytest.skip(f"Test data not found at '{f}'. Skipping CLI test.") + + with TemporaryDirectory() as tmpdir: + output_zarr = Path(tmpdir) / "data.zarr" + result = runner.invoke( + visium_hd_wrapper, + [ + "--path", + str(f), + "--output", + str(output_zarr), + ], + ) + assert result.exit_code == 0, result.output + # Verify the output can be read + sdata = read_zarr(output_zarr) + + # A simple check to confirm data was loaded + # The default dataset_id is inferred from the feature slice file name. + # This assert may need adjustment based on your test data's file names. + inferred_dataset_id = DATASET_FOLDER.replace("_outs", "") # Example inference + assert f"{inferred_dataset_id}_full_image" in sdata.images