Skip to content

Commit

Permalink
Improved concatenate() for non-unique element names (#720)
Browse files Browse the repository at this point in the history
* accept Iterable in concatenate

* concatenate: automatic non-unique names resolution

* docs, changelog

* add test for len 1 iterable (grst code review)
  • Loading branch information
LucaMarconato authored Oct 4, 2024
1 parent 3988452 commit c09f35a
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning][].

- Added `shortest_path` parameter to `get_transformation_between_coordinate_systems`
- Added `get_pyramid_levels()` utils API
- Improved ergonomics of `concatenate()` when element names are non-unique #720

## [0.2.3] - 2024-09-25

Expand Down
15 changes: 10 additions & 5 deletions src/spatialdata/_core/_utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
from collections.abc import Iterable

from spatialdata._core.spatialdata import SpatialData


def _find_common_table_keys(sdatas: Iterable[SpatialData]) -> set[str]:
    """
    Find the table keys shared by all the given SpatialData objects.

    Parameters
    ----------
    sdatas
        An `Iterable` of SpatialData objects. It is consumed in a single pass, so one-shot iterators are supported.

    Returns
    -------
    The set of table keys present in the tables of every SpatialData object. The set is empty when `sdatas` is
    empty or when the objects share no table key.
    """
    # `None` marks "no object seen yet". An empty set cannot be used as the sentinel: the running intersection may
    # legitimately become empty, and it must then stay empty instead of being re-seeded from the next object.
    common_keys: set[str] | None = None

    for sdata in sdatas:
        if common_keys is None:
            common_keys = set(sdata.tables.keys())
        else:
            common_keys.intersection_update(sdata.tables.keys())

    return set() if common_keys is None else common_keys
109 changes: 95 additions & 14 deletions src/spatialdata/_core/concatenate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable
from copy import copy # Should probably go up at the top
from itertools import chain
from typing import Any
Expand All @@ -11,7 +12,7 @@

from spatialdata._core._utils import _find_common_table_keys
from spatialdata._core.spatialdata import SpatialData
from spatialdata.models import TableModel
from spatialdata.models import SpatialElement, TableModel, get_table_keys

__all__ = [
"concatenate",
Expand Down Expand Up @@ -73,10 +74,12 @@ def _concatenate_tables(


def concatenate(
sdatas: list[SpatialData],
sdatas: Iterable[SpatialData] | dict[str, SpatialData],
region_key: str | None = None,
instance_key: str | None = None,
concatenate_tables: bool = False,
obs_names_make_unique: bool = True,
modify_tables_inplace: bool = False,
**kwargs: Any,
) -> SpatialData:
"""
Expand All @@ -85,36 +88,74 @@ def concatenate(
Parameters
----------
sdatas
The spatial data objects to concatenate.
The spatial data objects to concatenate. The names of the elements across the `SpatialData` objects must be
unique. If they are not unique, you can pass a dictionary with the suffixes as keys and the spatial data objects
as values. This will rename the names of each `SpatialElement` to ensure uniqueness of names across
`SpatialData` objects. See more on the notes.
region_key
The key to use for the region column in the concatenated object.
If all region_keys are the same, the `region_key` is used.
If `None` and all region_keys are the same, the `region_key` is used.
instance_key
The key to use for the instance column in the concatenated object.
If `None` and all instance_keys are the same, the `instance_key` is used.
concatenate_tables
Whether to merge the tables in case of having the same element name.
obs_names_make_unique
Whether to make the `obs_names` unique by calling `AnnData.obs_names_make_unique()` on each table of the
concatenated object. If you passed a dictionary with the suffixes as keys and the `SpatialData` objects as
values and if `concatenate_tables` is `True`, the `obs_names` will be made unique by adding the corresponding
suffix instead.
modify_tables_inplace
Whether to modify the tables in place. If `True`, the tables will be modified in place. If `False`, the tables
will be copied before modification. Copying is enabled by default but can be disabled for performance reasons.
kwargs
See :func:`anndata.concat` for more details.
Returns
-------
The concatenated :class:`spatialdata.SpatialData` object.
Notes
-----
If you pass a dictionary with the suffixes as keys and the `SpatialData` objects as values, the names of each
`SpatialElement` will be renamed to ensure uniqueness of names across `SpatialData` objects by adding the
corresponding suffix. To ensure the matching between existing table annotations, the `region` metadata of each
table, and the values of the `region_key` column in each table, will be altered by adding the suffix. In addition,
if both `obs_names_make_unique` and `concatenate_tables` are `True`, the `obs_names` of each table will be altered
(a suffix will be added). Finally, a suffix will be added to the name of each table iff `concatenate_tables` is
`False`.
If you need more control in the renaming, please give us feedback, as we are still trying to find the right balance
between ergonomics and control. Also, you are welcome to copy and adjust the code of
`_fix_ensure_unique_element_names()` directly.
"""
if not isinstance(sdatas, Iterable):
raise TypeError("`sdatas` must be a `Iterable`")

if isinstance(sdatas, dict):
sdatas = _fix_ensure_unique_element_names(
sdatas,
rename_tables=not concatenate_tables,
rename_obs_names=obs_names_make_unique and concatenate_tables,
modify_tables_inplace=modify_tables_inplace,
)

ERROR_STR = (
" must have unique names across the SpatialData objects to concatenate. Please pass a `dict[str, SpatialData]`"
" to `concatenate()` to address this (see docstring)."
)

merged_images = {**{k: v for sdata in sdatas for k, v in sdata.images.items()}}
if len(merged_images) != np.sum([len(sdata.images) for sdata in sdatas]):
raise KeyError("Images must have unique names across the SpatialData objects to concatenate")
raise KeyError("Images" + ERROR_STR)
merged_labels = {**{k: v for sdata in sdatas for k, v in sdata.labels.items()}}
if len(merged_labels) != np.sum([len(sdata.labels) for sdata in sdatas]):
raise KeyError("Labels must have unique names across the SpatialData objects to concatenate")
raise KeyError("Labels" + ERROR_STR)
merged_points = {**{k: v for sdata in sdatas for k, v in sdata.points.items()}}
if len(merged_points) != np.sum([len(sdata.points) for sdata in sdatas]):
raise KeyError("Points must have unique names across the SpatialData objects to concatenate")
raise KeyError("Points" + ERROR_STR)
merged_shapes = {**{k: v for sdata in sdatas for k, v in sdata.shapes.items()}}
if len(merged_shapes) != np.sum([len(sdata.shapes) for sdata in sdatas]):
raise KeyError("Shapes must have unique names across the SpatialData objects to concatenate")

assert isinstance(sdatas, list), "sdatas must be a list"
assert len(sdatas) > 0, "sdatas must be a non-empty list"
raise KeyError("Shapes" + ERROR_STR)

if not concatenate_tables:
key_counts: dict[str, int] = defaultdict(int)
Expand All @@ -124,8 +165,8 @@ def concatenate(

if any(value > 1 for value in key_counts.values()):
warn(
"Duplicate table names found. Tables will be added with integer suffix. Set concatenate_tables to True"
"if concatenation is wished for instead.",
"Duplicate table names found. Tables will be added with integer suffix. Set `concatenate_tables` to "
"`True` if concatenation is wished for instead.",
UserWarning,
stacklevel=2,
)
Expand All @@ -147,13 +188,17 @@ def concatenate(
else:
merged_tables[k] = v

return SpatialData(
sdata = SpatialData(
images=merged_images,
labels=merged_labels,
points=merged_points,
shapes=merged_shapes,
tables=merged_tables,
)
if obs_names_make_unique:
for table in sdata.tables.values():
table.obs_names_make_unique()
return sdata


def _filter_table_in_coordinate_systems(table: AnnData, coordinate_systems: list[str]) -> AnnData:
Expand All @@ -162,3 +207,39 @@ def _filter_table_in_coordinate_systems(table: AnnData, coordinate_systems: list
new_table = table[table.obs[region_key].isin(coordinate_systems)].copy()
new_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = new_table.obs[region_key].unique().tolist()
return new_table


def _fix_ensure_unique_element_names(
    sdatas: dict[str, SpatialData],
    rename_tables: bool,
    rename_obs_names: bool,
    modify_tables_inplace: bool,
) -> list[SpatialData]:
    """
    Append a per-object suffix to element names so that they are unique across `SpatialData` objects.

    Parameters
    ----------
    sdatas
        Mapping from suffix to `SpatialData` object. `-{suffix}` is appended to the name of each spatial element,
        to the values of the `region_key` column of each table and to the `region` metadata of each table.
    rename_tables
        Whether to also append `-{suffix}` to the name of each table.
    rename_obs_names
        Whether to append `-{suffix}` to each entry of the `obs` index of each table.
    modify_tables_inplace
        If `False`, each table is copied before being modified. If `True`, the tables of the input objects are
        modified directly.

    Returns
    -------
    A list with one new `SpatialData` object per entry of `sdatas`, in the iteration order of `sdatas`.
    """
    sdatas_fixed: list[SpatialData] = []
    for suffix, sdata in sdatas.items():
        elements: dict[str, SpatialElement] = {
            f"{name}-{suffix}": el for _, name, el in sdata.gen_spatial_elements()
        }

        tables: dict[str, AnnData] = {}
        for name, table in sdata.tables.items():
            if not modify_tables_inplace:
                table = table.copy()

            # keep the table annotations in sync with the renamed elements
            region, region_key, _ = get_table_keys(table)
            # NOTE(review): the `None` guards assume `get_table_keys()` may return `None` for tables that do not
            # annotate any element; confirm against its implementation.
            if region_key is not None:
                table.obs[region_key] = (table.obs[region_key].astype("str") + f"-{suffix}").astype("category")
            if region is not None:
                # the region metadata can be a single name or a list of names; suffix every name rather than
                # the string representation of the list
                new_region: str | list[str]
                if isinstance(region, str):
                    new_region = f"{region}-{suffix}"
                else:
                    new_region = [f"{r}-{suffix}" for r in region]
                table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = new_region

            # fix the obs names
            if rename_obs_names:
                table.obs.index = table.obs.index.to_series().apply(lambda x, suffix=suffix: f"{x}-{suffix}")

            # fix the table name
            new_name = f"{name}-{suffix}" if rename_tables else name
            tables[new_name] = table

        sdatas_fixed.append(SpatialData.init_from_elements(elements, tables=tables))
    return sdatas_fixed
49 changes: 48 additions & 1 deletion tests/core/operations/test_spatialdata_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@
from spatialdata._core.operations._utils import transform_to_data_extent
from spatialdata._core.spatialdata import SpatialData
from spatialdata.datasets import blobs
from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel, get_table_keys
from spatialdata.models import (
Image2DModel,
Labels2DModel,
PointsModel,
ShapesModel,
TableModel,
get_table_keys,
)
from spatialdata.testing import assert_elements_dict_are_identical, assert_spatial_data_objects_are_identical
from spatialdata.transformations.operations import get_transformation, set_transformation
from spatialdata.transformations.transformations import (
Expand Down Expand Up @@ -284,6 +291,46 @@ def test_concatenate_sdatas(full_sdata: SpatialData) -> None:
assert len(list(concatenated.gen_elements())) == 3


@pytest.mark.parametrize("concatenate_tables", [True, False])
@pytest.mark.parametrize("obs_names_make_unique", [True, False])
def test_concatenate_sdatas_from_iterable(concatenate_tables: bool, obs_names_make_unique: bool) -> None:
    """Check `concatenate()` on an iterable and on a dict of `SpatialData` objects with clashing element names."""
    first = blobs()
    second = blobs()
    by_sample = {"sample0": first, "sample1": second}

    # a plain iterable of objects sharing element names is rejected
    with pytest.raises(KeyError, match="Images must have unique names across the SpatialData objects"):
        _ = concatenate(
            by_sample.values(),
            concatenate_tables=concatenate_tables,
            obs_names_make_unique=obs_names_make_unique,
        )

    # passing the dict itself is accepted
    merged = concatenate(
        by_sample,
        obs_names_make_unique=obs_names_make_unique,
        concatenate_tables=concatenate_tables,
    )

    if not concatenate_tables:
        assert merged["table-sample0"].obs_names[0] == "1"
    else:
        assert len(merged.tables) == 1
        table = merged["table"]
        if not obs_names_make_unique:
            assert table.obs_names[0] == "1"
        else:
            assert table.obs_names[0] == "1-sample0"
            assert table.obs_names[-1] == "30-sample1"

    # the table of the input object still has its original obs names
    assert first["table"].obs_names[0] == "1"


def test_concatenate_sdatas_single_item() -> None:
    """Check that concatenating a single `SpatialData` object keeps the number of elements unchanged."""
    single = blobs()

    def _count_elements(obj: SpatialData) -> int:
        return sum(1 for _, _, _ in obj.gen_elements())

    expected = _count_elements(single)

    # one-item list and one-item dict view
    for iterable in ([single], {"sample": single}.values()):
        assert expected == _count_elements(concatenate(iterable))

    # one-item dict: same number of elements, with the suffix appended to the element names
    renamed = concatenate({"sample": single})
    assert expected == _count_elements(renamed)
    assert "blobs_image-sample" in renamed.images


def test_locate_spatial_element(full_sdata: SpatialData) -> None:
assert full_sdata.locate_element(full_sdata.images["image2d"])[0] == "images/image2d"
im = full_sdata.images["image2d"]
Expand Down

0 comments on commit c09f35a

Please sign in to comment.