Skip to content

Commit

Permalink
Fix validation and scale accessor multiscale image (#719)
Browse files Browse the repository at this point in the history
* better comment validation datatree

* added API get_pyramid_levels()
  • Loading branch information
LucaMarconato authored Oct 1, 2024
1 parent f492f55 commit 3988452
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 56 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning][].
### Minor

- Added `shortest_path` parameter to `get_transformation_between_coordinate_systems`
- Added `get_pyramid_levels()` utils API

## [0.2.3] - 2024-09-25

Expand Down
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Operations on `SpatialData` objects.
unpad_raster
are_extents_equal
deepcopy
get_pyramid_levels
```

## Models
Expand Down
3 changes: 2 additions & 1 deletion src/spatialdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"get_centroids",
"read_zarr",
"unpad_raster",
"get_pyramid_levels",
"save_transformations",
"get_dask_backing_files",
"are_extents_equal",
Expand Down Expand Up @@ -75,4 +76,4 @@
from spatialdata._core.spatialdata import SpatialData
from spatialdata._io._utils import get_dask_backing_files, save_transformations
from spatialdata._io.io_zarr import read_zarr
from spatialdata._utils import unpad_raster
from spatialdata._utils import get_pyramid_levels, unpad_raster
22 changes: 3 additions & 19 deletions src/spatialdata/_io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from xarray import DataArray

from spatialdata._core.spatialdata import SpatialData
from spatialdata._utils import iterate_pyramid_levels
from spatialdata._utils import get_pyramid_levels
from spatialdata.models._utils import (
MappingToCoordinateSystem_t,
SpatialElement,
Expand Down Expand Up @@ -122,22 +122,6 @@ def _write_metadata(
group.attrs["spatialdata_attrs"] = attrs


def _iter_multiscale(
data: DataTree,
attr: str | None,
) -> list[Any]:
# TODO: put this check also in the validator for raster multiscales
for i in data:
variables = set(data[i].variables.keys())
names: set[str] = variables.difference({"c", "z", "y", "x"})
if len(names) != 1:
raise ValueError(f"Invalid variable name: `{names}`.")
name: str = next(iter(names))
if attr is not None:
return [getattr(data[i][name], attr) for i in data]
return [data[i][name] for i in data]


class dircmp(filecmp.dircmp): # type: ignore[type-arg]
"""
Compare the content of dir1 and dir2.
Expand Down Expand Up @@ -241,8 +225,8 @@ def _(element: DataArray) -> list[str]:

@get_dask_backing_files.register(DataTree)
def _(element: DataTree) -> list[str]:
xdata0 = next(iter(iterate_pyramid_levels(element)))
return _get_backing_files(xdata0.data)
dask_data_scale0 = get_pyramid_levels(element, attr="data", n=0)
return _get_backing_files(dask_data_scale0)


@get_dask_backing_files.register(DaskDataFrame)
Expand Down
10 changes: 5 additions & 5 deletions src/spatialdata/_io/io_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from spatialdata._io._utils import (
_get_transformations_from_ngff_dict,
_iter_multiscale,
overwrite_coordinate_transformations_raster,
)
from spatialdata._io.format import (
Expand All @@ -26,6 +25,7 @@
RasterFormatV01,
_parse_version,
)
from spatialdata._utils import get_pyramid_levels
from spatialdata.models._utils import get_channels
from spatialdata.models.models import ATTRS_KEY
from spatialdata.transformations._utils import (
Expand Down Expand Up @@ -180,8 +180,8 @@ def _get_group_for_writing_transformations() -> zarr.Group:
group=_get_group_for_writing_transformations(), transformations=transformations, axes=input_axes
)
elif isinstance(raster_data, DataTree):
data = _iter_multiscale(raster_data, "data")
list_of_input_axes: list[Any] = _iter_multiscale(raster_data, "dims")
data = get_pyramid_levels(raster_data, attr="data")
list_of_input_axes: list[Any] = get_pyramid_levels(raster_data, attr="dims")
assert len(set(list_of_input_axes)) == 1
input_axes = list_of_input_axes[0]
# saving only the transformations of the first scale
Expand All @@ -191,8 +191,8 @@ def _get_group_for_writing_transformations() -> zarr.Group:
transformations = _get_transformations_xarray(xdata)
assert transformations is not None
assert len(transformations) > 0
chunks = _iter_multiscale(raster_data, "chunks")
# coords = _iter_multiscale(raster_data, "coords")
chunks = get_pyramid_levels(raster_data, "chunks")
# coords = iterate_pyramid_levels(raster_data, "coords")
parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format)
storage_options = [{"chunks": chunk} for chunk in chunks]
write_multi_scale_ngff(
Expand Down
63 changes: 36 additions & 27 deletions src/spatialdata/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re
import warnings
from collections.abc import Generator
from itertools import islice
from typing import Any, Callable, TypeVar, Union

import numpy as np
Expand Down Expand Up @@ -150,45 +151,53 @@ def _compute_paddings(data: DataArray, axis: str) -> tuple[int, int]:
return compute_coordinates(unpadded)


# TODO: probably we want this method to live in multiscale_spatial_image
def multiscale_spatial_image_from_data_tree(data_tree: DataTree) -> DataTree:
warnings.warn(
f"{multiscale_spatial_image_from_data_tree} is deprecated and will be removed in version 0.2.0.",
DeprecationWarning,
stacklevel=2,
)
d = {}
for k, dt in data_tree.items():
v = dt.values()
assert len(v) == 1
xdata = v.__iter__().__next__()
d[k] = xdata
def get_pyramid_levels(image: DataTree, attr: str | None = None, n: int | None = None) -> list[Any] | Any:
"""
Access the data/attribute of the pyramid levels of a multiscale spatial image.
Parameters
----------
image
The multiscale spatial image.
attr
If `None`, return each pyramid level as a `DataArray`; otherwise, return the attribute with this name
from each level's `DataArray` (e.g. `"data"`, `"dims"` or `"chunks"`).
n
If not None, return only the `n`-th pyramid level (0-based, so `n=0` is the first level yielded).
return DataTree.from_dict(d)
Returns
-------
A list with the data (or the requested attribute) of every pyramid level, or, if `n` is given, the data (or attribute) of that single level.
"""
generator = iterate_pyramid_levels(image, attr)
if n is not None:
return next(iter(islice(generator, n, None)))
return list(generator)


# TODO: this function is similar to _iter_multiscale(), the latter is more powerful but not exposed to the user.
# Use only one and expose it to the user in this file
def iterate_pyramid_levels(image: DataTree) -> Generator[DataArray, None, None]:
def iterate_pyramid_levels(
data: DataTree,
attr: str | None,
) -> Generator[Any, None, None]:
"""
Iterate over the pyramid levels of a multiscale spatial image.
Parameters
----------
image
The multiscale spatial image.
data
The multiscale spatial image
attr
If `None`, yield each pyramid level as a `DataArray`; otherwise, yield the attribute with this name
from each level's `DataArray`.
Returns
-------
A generator that yields the pyramid levels.
A generator to iterate over the pyramid levels.
"""
for k in range(len(image)):
scale_name = f"scale{k}"
dt = image[scale_name]
v = dt.values()
assert len(v) == 1
xdata = next(iter(v))
yield xdata
names = data["scale0"].ds.keys()
name: str = next(iter(names))
for scale in data:
yield data[scale][name] if attr is None else getattr(data[scale][name], attr)


def _inplace_fix_subset_categorical_obs(subset_adata: AnnData, original_adata: AnnData) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ def _(self, data: DataTree) -> None:
if j != k:
raise ValueError(f"Wrong key for multiscale data, found: `{j}`, expected: `{k}`.")
name = {list(data[i].data_vars.keys())[0] for i in data}
if len(name) > 1:
raise ValueError(f"Wrong name for datatree: `{name}`.")
if len(name) != 1:
raise ValueError(f"Expected exactly one data variable for the datatree: found `{name}`.")
name = list(name)[0]
for d in data:
super().validate(data[d][name])
Expand Down
4 changes: 2 additions & 2 deletions tests/core/operations/test_rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from spatialdata import SpatialData, get_extent
from spatialdata._core.operations.rasterize import rasterize
from spatialdata._core.query.relational_query import get_element_instances
from spatialdata._io._utils import _iter_multiscale
from spatialdata._utils import get_pyramid_levels
from spatialdata.models import PointsModel, ShapesModel, TableModel, get_axes_names
from spatialdata.models._utils import get_spatial_axes
from spatialdata.transformations import MapAxis
Expand Down Expand Up @@ -57,7 +57,7 @@ def _get_data_of_largest_scale(raster):
if isinstance(raster, DataArray):
return raster.data.compute()

xdata = next(iter(_iter_multiscale(raster, None)))
xdata = get_pyramid_levels(raster, n=0)
return xdata.data.compute()

for element_name, raster in rasters.items():
Expand Down

0 comments on commit 3988452

Please sign in to comment.