Skip to content

Commit

Permalink
Added subset API; fix behavior with zero-len table (#426)
Browse files Browse the repository at this point in the history
* added subset API, returning None instead of empty table for APIs with  filter_table=True

* fix 3.9
  • Loading branch information
LucaMarconato authored Jan 9, 2024
1 parent 9470a61 commit 4567d2c
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning][].

### Added

- added SpatialData.subset() API

### Fixed

#### Minor
Expand Down
6 changes: 6 additions & 0 deletions src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,8 @@ def _(
new_elements[element_type] = queried_elements

table = _filter_table_by_elements(sdata.table, new_elements) if filter_table else sdata.table
if len(table) == 0:
table = None
return SpatialData(**new_elements, table=table)


Expand Down Expand Up @@ -641,6 +643,8 @@ def _polygon_query(

if filter_table and sdata.table is not None:
table = _filter_table_by_elements(sdata.table, {"shapes": new_shapes, "points": new_points})
if table is not None and len(table) == 0:
table = None
else:
table = sdata.table
return SpatialData(shapes=new_shapes, points=new_points, images=new_images, table=table)
Expand Down Expand Up @@ -749,5 +753,7 @@ def polygon_query(
geodataframes[k] = vv

table = _filter_table_by_elements(sdata.table, {"shapes": geodataframes}) if filter_table else sdata.table
if len(table) == 0:
table = None

return SpatialData(shapes=geodataframes, table=table)
27 changes: 27 additions & 0 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,33 @@ def init_from_elements(cls, elements: dict[str, SpatialElement], table: AnnData
elements_dict.setdefault(element_type, {})[name] = element
return cls(**elements_dict, table=table)

def subset(self, element_names: list[str], filter_table: bool = True) -> SpatialData:
"""
Subset the SpatialData object.
Parameters
----------
element_names
The names of the element_names to subset.
filter_table
If True (default), the table is filtered to only contain rows that are annotating regions
contained within the element_names.
Returns
-------
The subsetted SpatialData object.
"""
from spatialdata._core.query.relational_query import _filter_table_by_elements

elements_dict: dict[str, SpatialElement] = {}
for element_type, element_name, element in self._gen_elements():
if element_name in element_names:
elements_dict.setdefault(element_type, {})[element_name] = element
table = _filter_table_by_elements(self.table, elements_dict=elements_dict) if filter_table else self.table
if len(table) == 0:
table = None
return SpatialData(**elements_dict, table=table)

def __getitem__(self, item: str) -> SpatialElement:
"""
Return the element with the given name.
Expand Down
30 changes: 28 additions & 2 deletions tests/core/operations/test_spatialdata_operations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import numpy as np
import pytest
from anndata import AnnData
Expand Down Expand Up @@ -123,8 +125,8 @@ def _assert_elements_left_to_right_seem_identical(sdata0: SpatialData, sdata1: S
raise TypeError(f"Unsupported type {type(element)}")


def _assert_tables_seem_identical(table0: AnnData, table1: AnnData) -> None:
assert table0.shape == table1.shape
def _assert_tables_seem_identical(table0: AnnData | None, table1: AnnData | None) -> None:
assert table0 is None and table1 is None or table0.shape == table1.shape


def _assert_spatialdata_objects_seem_identical(sdata0: SpatialData, sdata1: SpatialData) -> None:
Expand Down Expand Up @@ -360,3 +362,27 @@ def test_init_from_elements(full_sdata: SpatialData) -> None:
sdata = SpatialData.init_from_elements(all_elements, table=full_sdata.table)
for element_type in ["images", "labels", "points", "shapes"]:
assert set(getattr(sdata, element_type).keys()) == set(getattr(full_sdata, element_type).keys())


def test_subset(full_sdata: SpatialData) -> None:
element_names = ["image2d", "labels2d", "points_0", "circles", "poly"]
subset0 = full_sdata.subset(element_names)
unique_names = set()
for _, k, _ in subset0._gen_elements():
unique_names.add(k)
assert "image3d_xarray" in full_sdata.images
assert unique_names == set(element_names)
assert subset0.table is None

adata = AnnData(
shape=(10, 0),
obs={"region": ["circles"] * 5 + ["poly"] * 5, "instance_id": [0, 1, 2, 3, 4, "a", "b", "c", "d", "e"]},
)
del full_sdata.table
full_sdata.table = TableModel.parse(
adata, region=["circles", "poly"], region_key="region", instance_key="instance_id"
)
subset1 = full_sdata.subset(["poly"])
assert subset1.table is not None
assert len(subset1.table) == 5
assert subset1.table.obs["region"].unique().tolist() == ["poly"]
2 changes: 1 addition & 1 deletion tests/core/query/test_spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def test_polygon_query_points(sdata_query_aggregation):
queried = polygon_query(sdata, polygons=polygon, target_coordinate_system="global", shapes=False, points=True)
points = queried["points"].compute()
assert len(points) == 6
assert len(queried.table) == 0
assert queried.table is None

# TODO: the case of querying points with multiple polygons is not currently implemented

Expand Down

0 comments on commit 4567d2c

Please sign in to comment.