Skip to content

Commit

Permalink
Move c_coords out of kwargs (scverse#779)
Browse files Browse the repository at this point in the history
* refactor c_coords kwargs

* add c_coords docstring specific for images

* fix docstring

* fix docstring, add tests
  • Loading branch information
melonora authored Nov 13, 2024
1 parent 42f7b6a commit 6694c91
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/spatialdata/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def blobs(
n_shapes: int = 5,
extra_coord_system: str | None = None,
n_channels: int = 3,
c_coords: ArrayLike | None = None,
c_coords: str | list[str] | None = None,
) -> SpatialData:
"""
Blobs dataset.
Expand Down Expand Up @@ -108,7 +108,7 @@ def __init__(
n_shapes: int = 5,
extra_coord_system: str | None = None,
n_channels: int = 3,
c_coords: ArrayLike | None = None,
c_coords: str | list[str] | None = None,
) -> None:
"""
Blobs dataset.
Expand Down Expand Up @@ -176,7 +176,7 @@ def _image_blobs(
transformations: dict[str, Any] | None = None,
length: int = 512,
n_channels: int = 3,
c_coords: ArrayLike | None = None,
c_coords: str | list[str] | None = None,
multiscale: bool = False,
) -> DataArray | DataTree:
masks = []
Expand Down
20 changes: 18 additions & 2 deletions src/spatialdata/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,14 @@ def parse(
cls,
data: ArrayLike | DataArray | DaskArray,
dims: Sequence[str] | None = None,
c_coords: str | list[str] | None = None,
transformations: MappingToCoordinateSystem_t | None = None,
scale_factors: ScaleFactors_t | None = None,
method: Methods | None = None,
chunks: Chunks_t | None = None,
**kwargs: Any,
) -> DataArray | DataTree:
"""
r"""
Validate (or parse) raster data.
Parameters
Expand All @@ -110,6 +111,9 @@ def parse(
Dimensions of the data (e.g. ['c', 'y', 'x'] for 2D image data). If the data is a :class:`xarray.DataArray`,
the dimensions can also be inferred from the data. If the dimensions are not in the order (c)(z)yx, the data
will be transposed to match the order.
c_coords : str | list[str] | None
Channel names of image data. Must be equal to the length of dimension 'c'. Only supported for `Image`
models.
transformations
Dictionary of transformations to apply to the data. The key is the name of the target coordinate system,
the value is the transformation to apply. By default, a single `Identity` transformation mapping to the
Expand Down Expand Up @@ -195,7 +199,15 @@ def parse(
) from e

# finally convert to spatial image
data = to_spatial_image(array_like=data, dims=cls.dims.dims, **kwargs)
if isinstance(c_coords, str):
c_coords = [c_coords]
if c_coords is not None and len(c_coords) != data.shape[cls.dims.dims.index("c")]:
raise ValueError(
f"The number of channel names `{len(c_coords)}` does not match the length of dimension 'c'"
f" with length {data.shape[cls.dims.dims.index('c')]}."
)

data = to_spatial_image(array_like=data, dims=cls.dims.dims, c_coords=c_coords, **kwargs)
# parse transformations
_parse_transformations(data, transformations)
# convert to multiscale if needed
Expand Down Expand Up @@ -270,6 +282,8 @@ def parse( # noqa: D102
*args: Any,
**kwargs: Any,
) -> DataArray | DataTree:
if kwargs.get("c_coords") is not None:
raise ValueError("`c_coords` is not supported for labels")
if kwargs.get("scale_factors") is not None and kwargs.get("method") is None:
# Override default scaling method to preserve labels
kwargs["method"] = Methods.DASK_IMAGE_NEAREST
Expand All @@ -292,6 +306,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

@classmethod
def parse(self, *args: Any, **kwargs: Any) -> DataArray | DataTree: # noqa: D102
if kwargs.get("c_coords") is not None:
raise ValueError("`c_coords` is not supported for labels")
if kwargs.get("scale_factors") is not None and kwargs.get("method") is None:
# Override default scaling method to preserve labels
kwargs["method"] = Methods.DASK_IMAGE_NEAREST
Expand Down
21 changes: 21 additions & 0 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,3 +545,24 @@ def test_dask_points_from_parquet(points, npartitions: int, sorted_index: bool):
match=r"The index of the dataframe is not monotonic increasing\.",
):
_ = PointsModel.parse(points, npartitions=npartitions)


@pytest.mark.parametrize("scale_factors", [None, [2, 2]])
def test_c_coords_2d(scale_factors: list[int] | None):
data = np.zeros((3, 30, 30))
model = Image2DModel().parse(data, c_coords=["1st", "2nd", "3rd"], scale_factors=scale_factors)
if scale_factors is None:
assert model.coords["c"].data.tolist() == ["1st", "2nd", "3rd"]
else:
assert all(
model[group]["image"].coords["c"].data.tolist() == ["1st", "2nd", "3rd"] for group in list(model.keys())
)

with pytest.raises(ValueError, match="The number of channel names"):
Image2DModel().parse(data, c_coords=["1st", "2nd", "3rd", "too_much"], scale_factors=scale_factors)


@pytest.mark.parametrize("model", [Labels2DModel, Labels3DModel])
def test_label_no_c_coords(model: Labels2DModel | Labels3DModel):
with pytest.raises(ValueError, match="`c_coords` is not supported"):
model().parse(np.zeros((30, 30)), c_coords=["1st", "2nd", "3rd"])

0 comments on commit 6694c91

Please sign in to comment.