Skip to content

Commit

Permalink
Merge branch 'main' into sdata_attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
quentinblampey committed Oct 15, 2024
2 parents dfd53f0 + 8ebbde2 commit 7ee1b55
Show file tree
Hide file tree
Showing 21 changed files with 986 additions and 391 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
additional_dependencies: [numpy, types-requests]
exclude: tests/|docs/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.3
rev: v0.6.8
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
16 changes: 15 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,16 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## [0.x.x] - 2024-xx-xx
## [0.2.4] - xxxx-xx-xx

### Minor

- Added `shortest_path` parameter to `get_transformation_between_coordinate_systems`
- Added `get_pyramid_levels()` utils API
- Improved ergonomics of `concatenate()` when element names are non-unique #720
- Improved performance of writing images with multiscales #577

## [0.2.3] - 2024-09-25

### Major

Expand All @@ -17,6 +26,11 @@ and this project adheres to [Semantic Versioning][].
### Minor

- Added `clip: bool = False` parameter to `polygon_query()` #670
- Add `sort` parameter to `PointsModel.parse()` #672

### Fixed

- Fix interpolation artifact multiscale computation for labels #697

## [0.2.2] - 2024-08-07

Expand Down
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Operations on `SpatialData` objects.
unpad_raster
are_extents_equal
deepcopy
get_pyramid_levels
```

## Models
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ dependencies = [
"rich",
"setuptools",
"shapely>=2.0.1",
"spatial_image>=1.0.0",
"spatial_image>=1.1.0",
"scikit-image",
"scipy",
"typing_extensions>=4.8.0",
Expand Down Expand Up @@ -68,6 +68,7 @@ docs = [
test = [
"pytest",
"pytest-cov",
"pytest-mock",
]
torch = [
"torch"
Expand Down
3 changes: 2 additions & 1 deletion src/spatialdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"get_centroids",
"read_zarr",
"unpad_raster",
"get_pyramid_levels",
"save_transformations",
"get_dask_backing_files",
"are_extents_equal",
Expand Down Expand Up @@ -75,4 +76,4 @@
from spatialdata._core.spatialdata import SpatialData
from spatialdata._io._utils import get_dask_backing_files, save_transformations
from spatialdata._io.io_zarr import read_zarr
from spatialdata._utils import unpad_raster
from spatialdata._utils import get_pyramid_levels, unpad_raster
15 changes: 10 additions & 5 deletions src/spatialdata/_core/_utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
from collections.abc import Iterable

from spatialdata._core.spatialdata import SpatialData


def _find_common_table_keys(sdatas: list[SpatialData]) -> set[str]:
def _find_common_table_keys(sdatas: Iterable[SpatialData]) -> set[str]:
"""
Find table keys present in more than one SpatialData object.
Parameters
----------
sdatas
A list of SpatialData objects.
An `Iterable` of SpatialData objects.
Returns
-------
A set of common keys that are present in the tables of more than one SpatialData object.
"""
common_keys = set(sdatas[0].tables.keys())
common_keys: set[str] = set()

for sdata in sdatas[1:]:
common_keys.intersection_update(sdata.tables.keys())
for sdata in sdatas:
if len(common_keys) == 0:
common_keys = set(sdata.tables.keys())
else:
common_keys.intersection_update(sdata.tables.keys())

return common_keys
109 changes: 95 additions & 14 deletions src/spatialdata/_core/concatenate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable
from copy import copy # Should probably go up at the top
from itertools import chain
from typing import Any
Expand All @@ -11,7 +12,7 @@

from spatialdata._core._utils import _find_common_table_keys
from spatialdata._core.spatialdata import SpatialData
from spatialdata.models import TableModel
from spatialdata.models import SpatialElement, TableModel, get_table_keys

__all__ = [
"concatenate",
Expand Down Expand Up @@ -73,10 +74,12 @@ def _concatenate_tables(


def concatenate(
sdatas: list[SpatialData],
sdatas: Iterable[SpatialData] | dict[str, SpatialData],
region_key: str | None = None,
instance_key: str | None = None,
concatenate_tables: bool = False,
obs_names_make_unique: bool = True,
modify_tables_inplace: bool = False,
**kwargs: Any,
) -> SpatialData:
"""
Expand All @@ -85,36 +88,74 @@ def concatenate(
Parameters
----------
sdatas
The spatial data objects to concatenate.
The spatial data objects to concatenate. The names of the elements across the `SpatialData` objects must be
unique. If they are not unique, you can pass a dictionary with the suffixes as keys and the spatial data objects
as values. This will rename the names of each `SpatialElement` to ensure uniqueness of names across
`SpatialData` objects. See more on the notes.
region_key
The key to use for the region column in the concatenated object.
If all region_keys are the same, the `region_key` is used.
If `None` and all region_keys are the same, the `region_key` is used.
instance_key
The key to use for the instance column in the concatenated object.
If `None` and all instance_keys are the same, the `instance_key` is used.
concatenate_tables
Whether to merge the tables in case of having the same element name.
obs_names_make_unique
Whether to make the `obs_names` unique by calling `AnnData.obs_names_make_unique()` on each table of the
concatenated object. If you passed a dictionary with the suffixes as keys and the `SpatialData` objects as
values and if `concatenate_tables` is `True`, the `obs_names` will be made unique by adding the corresponding
suffix instead.
modify_tables_inplace
Whether to modify the tables in place. If `True`, the tables will be modified in place. If `False`, the tables
will be copied before modification. Copying is enabled by default but can be disabled for performance reasons.
kwargs
See :func:`anndata.concat` for more details.
Returns
-------
The concatenated :class:`spatialdata.SpatialData` object.
Notes
-----
If you pass a dictionary with the suffixes as keys and the `SpatialData` objects as values, the names of each
`SpatialElement` will be renamed to ensure uniqueness of names across `SpatialData` objects by adding the
corresponding suffix. To ensure the matching between existing table annotations, the `region` metadata of each
table, and the values of the `region_key` column in each table, will be altered by adding the suffix. In addition,
the `obs_names` of each table will be altered (a suffix will be added). Finally, a suffix will be added to the name
of each table iff `rename_tables` is `False`.
If you need more control in the renaming, please give us feedback, as we are still trying to find the right balance
between ergonomics and control. Also, you are welcome to copy and adjust the code of
`_fix_ensure_unique_element_names()` directly.
"""
if not isinstance(sdatas, Iterable):
raise TypeError("`sdatas` must be a `Iterable`")

if isinstance(sdatas, dict):
sdatas = _fix_ensure_unique_element_names(
sdatas,
rename_tables=not concatenate_tables,
rename_obs_names=obs_names_make_unique and concatenate_tables,
modify_tables_inplace=modify_tables_inplace,
)

ERROR_STR = (
" must have unique names across the SpatialData objects to concatenate. Please pass a `dict[str, SpatialData]`"
" to `concatenate()` to address this (see docstring)."
)

merged_images = {**{k: v for sdata in sdatas for k, v in sdata.images.items()}}
if len(merged_images) != np.sum([len(sdata.images) for sdata in sdatas]):
raise KeyError("Images must have unique names across the SpatialData objects to concatenate")
raise KeyError("Images" + ERROR_STR)
merged_labels = {**{k: v for sdata in sdatas for k, v in sdata.labels.items()}}
if len(merged_labels) != np.sum([len(sdata.labels) for sdata in sdatas]):
raise KeyError("Labels must have unique names across the SpatialData objects to concatenate")
raise KeyError("Labels" + ERROR_STR)
merged_points = {**{k: v for sdata in sdatas for k, v in sdata.points.items()}}
if len(merged_points) != np.sum([len(sdata.points) for sdata in sdatas]):
raise KeyError("Points must have unique names across the SpatialData objects to concatenate")
raise KeyError("Points" + ERROR_STR)
merged_shapes = {**{k: v for sdata in sdatas for k, v in sdata.shapes.items()}}
if len(merged_shapes) != np.sum([len(sdata.shapes) for sdata in sdatas]):
raise KeyError("Shapes must have unique names across the SpatialData objects to concatenate")

assert isinstance(sdatas, list), "sdatas must be a list"
assert len(sdatas) > 0, "sdatas must be a non-empty list"
raise KeyError("Shapes" + ERROR_STR)

if not concatenate_tables:
key_counts: dict[str, int] = defaultdict(int)
Expand All @@ -124,8 +165,8 @@ def concatenate(

if any(value > 1 for value in key_counts.values()):
warn(
"Duplicate table names found. Tables will be added with integer suffix. Set concatenate_tables to True"
"if concatenation is wished for instead.",
"Duplicate table names found. Tables will be added with integer suffix. Set `concatenate_tables` to "
"`True` if concatenation is wished for instead.",
UserWarning,
stacklevel=2,
)
Expand All @@ -147,13 +188,17 @@ def concatenate(
else:
merged_tables[k] = v

return SpatialData(
sdata = SpatialData(
images=merged_images,
labels=merged_labels,
points=merged_points,
shapes=merged_shapes,
tables=merged_tables,
)
if obs_names_make_unique:
for table in sdata.tables.values():
table.obs_names_make_unique()
return sdata


def _filter_table_in_coordinate_systems(table: AnnData, coordinate_systems: list[str]) -> AnnData:
Expand All @@ -162,3 +207,39 @@ def _filter_table_in_coordinate_systems(table: AnnData, coordinate_systems: list
new_table = table[table.obs[region_key].isin(coordinate_systems)].copy()
new_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = new_table.obs[region_key].unique().tolist()
return new_table


def _fix_ensure_unique_element_names(
sdatas: dict[str, SpatialData],
rename_tables: bool,
rename_obs_names: bool,
modify_tables_inplace: bool,
) -> list[SpatialData]:
elements_by_sdata: list[dict[str, SpatialElement]] = []
tables_by_sdata: list[dict[str, AnnData]] = []
for suffix, sdata in sdatas.items():
elements = {f"{name}-{suffix}": el for _, name, el in sdata.gen_spatial_elements()}
elements_by_sdata.append(elements)
tables = {}
for name, table in sdata.tables.items():
if not modify_tables_inplace:
table = table.copy()

# fix the region_key column
region, region_key, _ = get_table_keys(table)
table.obs[region_key] = (table.obs[region_key].astype("str") + f"-{suffix}").astype("category")
table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = f"{region}-{suffix}"

# fix the obs names
if rename_obs_names:
table.obs.index = table.obs.index.to_series().apply(lambda x, suffix=suffix: f"{x}-{suffix}")

# fix the table name
new_name = f"{name}-{suffix}" if rename_tables else name
tables[new_name] = table
tables_by_sdata.append(tables)
sdatas_fixed = []
for elements, tables in zip(elements_by_sdata, tables_by_sdata):
sdata = SpatialData.init_from_elements(elements, tables=tables)
sdatas_fixed.append(sdata)
return sdatas_fixed
2 changes: 1 addition & 1 deletion src/spatialdata/_core/operations/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def _aggregate_image_by_labels(

X = sparse.csr_matrix(df.values)

index = kwargs.get("zone_ids", None) # `zone_ids` allows the user to select specific labels to aggregate by
index = kwargs.get("zone_ids") # `zone_ids` allows the user to select specific labels to aggregate by
if index is None:
index = np.array(da.array.unique(by.data))
assert np.array(index == np.insert(zones, 0, 0)).all(), "Index mismatch between zonal stats and labels."
Expand Down
Loading

0 comments on commit 7ee1b55

Please sign in to comment.