Merge pull request #271 from scverse/fix/aggregation
Fix/aggregation
LucaMarconato authored May 22, 2023
2 parents ae61666 + 1ab3361 commit a0063cc
Showing 10 changed files with 311 additions and 208 deletions.
18 changes: 10 additions & 8 deletions src/spatialdata/_core/operations/aggregate.py
@@ -24,7 +24,6 @@
ShapesModel,
get_model,
)
from spatialdata.models._utils import get_axes_names
from spatialdata.transformations import BaseTransformation, Identity, get_transformation

__all__ = ["aggregate"]
@@ -112,8 +111,8 @@ def _aggregate_points_by_shapes(
value_key: str | None = None,
agg_func: str | list[str] = "count",
) -> ad.AnnData:
# Have to get dims on dask dataframe, can't get from pandas
dims = get_axes_names(points)
from spatialdata.models import points_dask_dataframe_to_geopandas

# Default value for id_key
if id_key is None:
id_key = points.attrs[PointsModel.ATTRS_KEY][PointsModel.FEATURE_KEY]
@@ -123,9 +122,8 @@
"`FEATURE_KEY` for the points."
)

if isinstance(points, ddf.DataFrame):
points = points.compute()
points = gpd.GeoDataFrame(points, geometry=gpd.points_from_xy(*[points[dim] for dim in dims]))
points = points_dask_dataframe_to_geopandas(points, suppress_z_warning=True)
shapes = circles_to_polygons(shapes)

return _aggregate_shapes(points, shapes, id_key, value_key, agg_func)

@@ -253,8 +251,8 @@ def _aggregate_shapes(
value_key: point_values,
}
)
##
aggregated = to_agg.groupby([by_id_key, id_key]).agg(agg_func).reset_index()
obs_id_categorical = pd.Categorical(aggregated[by_id_key])

# use all shapes in "by" as categories, not only those that intersect something in "value"
obs_id_categorical_categories = by.index.tolist()
obs_id_categorical = pd.Categorical(aggregated[by_id_key], categories=obs_id_categorical_categories)

X = sparse.coo_matrix(
(
@@ -265,7 +267,7 @@
).tocsr()
return ad.AnnData(
X,
obs=pd.DataFrame(index=obs_id_categorical.categories),
obs=pd.DataFrame(index=pd.Categorical(obs_id_categorical_categories).categories),
var=pd.DataFrame(index=joined[id_key].cat.categories),
dtype=X.dtype,
)
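
The core of the fix above is passing explicit categories to pd.Categorical: shapes in "by" that intersect nothing in "value" now still get an (all-zero) row in the output AnnData. A minimal sketch of the pandas behavior, with hypothetical shape ids:

import pandas as pd

# Hypothetical: three "by" shapes, only two of which intersect a value.
by_index = ["by_0", "by_1", "by_2"]
aggregated_ids = pd.Series(["by_0", "by_2", "by_0"])

# Without explicit categories, "by_1" is silently dropped:
print(pd.Categorical(aggregated_ids).categories.tolist())  # ['by_0', 'by_2']

# With explicit categories, every shape in "by" keeps its row:
print(pd.Categorical(aggregated_ids, categories=by_index).categories.tolist())  # ['by_0', 'by_1', 'by_2']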
18 changes: 12 additions & 6 deletions src/spatialdata/_core/query/spatial_query.py
@@ -247,12 +247,18 @@ def _bounding_box_mask_points(
def _dict_query_dispatcher(
elements: dict[str, SpatialElement], query_function: Callable[[SpatialElement], SpatialElement], **kwargs: Any
) -> dict[str, SpatialElement]:
from spatialdata.transformations import get_transformation

queried_elements = {}
for key, element in elements.items():
result = query_function(element, **kwargs)
if result is not None:
# query returns None if it is empty
queried_elements[key] = result
target_coordinate_system = kwargs["target_coordinate_system"]
d = get_transformation(element, get_all=True)
assert isinstance(d, dict)
if target_coordinate_system in d:
result = query_function(element, **kwargs)
if result is not None:
# query returns None if it is empty
queried_elements[key] = result
return queried_elements
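
The dispatcher now skips elements that have no transformation to the requested coordinate system instead of failing on them. A minimal sketch of the guard, assuming an element registered only in the default "global" system:

import numpy as np
from spatialdata.models import PointsModel
from spatialdata.transformations import get_transformation

# Hypothetical 2D points element; PointsModel.parse registers an Identity
# transformation to "global" by default.
points = PointsModel.parse(np.array([[0.0, 0.0], [1.0, 1.0]]))
d = get_transformation(points, get_all=True)  # maps coordinate system name -> transformation
assert "global" in d
# A query targeting a coordinate system absent from this dict is now skipped
# for this element rather than raising.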


@@ -649,12 +655,12 @@ def _polygon_query(
new_points = {}
if points:
for points_name, p in sdata.points.items():
points_gdf = points_dask_dataframe_to_geopandas(p)
points_gdf = points_dask_dataframe_to_geopandas(p, suppress_z_warning=True)
indices = points_gdf.geometry.intersects(polygon)
if np.sum(indices) == 0:
raise ValueError("we expect at least one point")
queried_points = points_gdf[indices]
ddf = points_geopandas_to_dask_dataframe(queried_points)
ddf = points_geopandas_to_dask_dataframe(queried_points, suppress_z_warning=True)
transformation = get_transformation(p, target_coordinate_system)
if "z" in ddf.columns:
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y", "z": "z"})
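The polygon query round-trips points through geopandas to use its spatial predicates, and suppress_z_warning=True keeps the conversion from logging on every call. A sketch of the round trip, assuming both helpers are exported from spatialdata.models as in the aggregation hunk above:

from shapely.geometry import Polygon
from spatialdata.models import (
    points_dask_dataframe_to_geopandas,
    points_geopandas_to_dask_dataframe,
)

# `points` is assumed to be a parsed points element (a Dask DataFrame).
polygon = Polygon([(0, 0), (0, 10), (10, 10), (10, 0)])  # hypothetical query region

points_gdf = points_dask_dataframe_to_geopandas(points, suppress_z_warning=True)
queried = points_gdf[points_gdf.geometry.intersects(polygon)]
ddf = points_geopandas_to_dask_dataframe(queried, suppress_z_warning=True)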
12 changes: 12 additions & 0 deletions src/spatialdata/_io/io_points.py
@@ -54,6 +54,18 @@ def write_points(

points_groups = group.require_group(name)
path = Path(points_groups._store.path) / points_groups.path / "points.parquet"

# The following code iterates through all columns in the 'points' DataFrame. If the column's datatype is
# 'category', it checks whether the categories of this column are known. If not, it explicitly converts the
# categories to known categories using 'c.cat.as_known()' and assigns the transformed Series back to the original
# DataFrame. This step is crucial when the number of categories exceeds 127, since pyarrow defaults to int8 for
# unknown categories, which can only hold values from -128 to 127.
for column_name in points.columns:
c = points[column_name]
if c.dtype == "category" and not c.cat.known:
c = c.cat.as_known()
points[column_name] = c

points.to_parquet(path)

attrs = fmt.attrs_to_dict(points.attrs)
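The comment in the hunk above describes an int8 overflow in the parquet dictionary encoding. A sketch of the failure mode and the fix, assuming a dask Series with more than 127 categories:

import dask.dataframe as dd
import pandas as pd

# Hypothetical categorical column with 280 distinct values.
s = pd.Series([str(i) for i in range(280)], dtype="category")
ds = dd.from_pandas(s, npartitions=2).cat.as_unknown()

assert not ds.cat.known
# as_known() scans the data so the dictionary index dtype can be sized correctly
# instead of defaulting to int8 (which only covers -128..127).
ds = ds.cat.as_known()
assert ds.cat.known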
34 changes: 24 additions & 10 deletions src/spatialdata/models/_utils.py
@@ -190,7 +190,7 @@ def _validate_dims(dims: tuple[str, ...]) -> None:
raise ValueError(f"Invalid dimensions: {dims}")


def points_dask_dataframe_to_geopandas(points: DaskDataFrame) -> GeoDataFrame:
def points_dask_dataframe_to_geopandas(points: DaskDataFrame, suppress_z_warning: bool = False) -> GeoDataFrame:
"""
Convert a Dask DataFrame to a GeoDataFrame.
@@ -212,16 +212,25 @@ def points_dask_dataframe_to_geopandas(points: DaskDataFrame) -> GeoDataFrame:
points need to be saved as a Dask DataFrame. We will be restructuring the models to allow for GeoDataFrames soon.
"""
if "z" in points.columns:
from spatialdata.transformations import get_transformation, set_transformation

if "z" in points.columns and not suppress_z_warning:
logger.warning("Constructing the GeoDataFrame without considering the z coordinate in the geometry.")

points_gdf = GeoDataFrame(geometry=geopandas.points_from_xy(points["x"], points["y"]))
for c in points.columns:
points_gdf[c] = points[c]
transformations = get_transformation(points, get_all=True)
assert isinstance(transformations, dict)
assert len(transformations) > 0
points = points.compute()
points_gdf = GeoDataFrame(points, geometry=geopandas.points_from_xy(points["x"], points["y"]))
points_gdf.reset_index(drop=True, inplace=True)
# keep the x and y either in the geometry or as columns: we don't duplicate them because this redundancy could
# lead to subtle bugs when converting back to dask dataframes
points_gdf.drop(columns=["x", "y"], inplace=True)
set_transformation(points_gdf, transformations, set_all=True)
return points_gdf
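
In short, the new implementation computes the dask frame once, builds point geometries from the x/y columns, and then drops those columns so coordinates live only in the geometry. A simplified sketch without the spatialdata transformation bookkeeping:

import dask.dataframe as dd
import geopandas
import pandas as pd

# Hypothetical 2D points table.
df = pd.DataFrame({"x": [0.0, 1.0], "y": [2.0, 3.0], "genes": ["a", "b"]})
ddf = dd.from_pandas(df, npartitions=1)

pdf = ddf.compute()
gdf = geopandas.GeoDataFrame(pdf, geometry=geopandas.points_from_xy(pdf["x"], pdf["y"]))
gdf = gdf.reset_index(drop=True).drop(columns=["x", "y"])  # coordinates kept only in the geometry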


def points_geopandas_to_dask_dataframe(gdf: GeoDataFrame) -> DaskDataFrame:
def points_geopandas_to_dask_dataframe(gdf: GeoDataFrame, suppress_z_warning: bool = False) -> DaskDataFrame:
"""
Convert a GeoDataFrame which represents 2D or 3D points to a Dask DataFrame that passes the schema validation.
@@ -241,15 +250,20 @@
"""
from spatialdata.models import PointsModel

# transformations are transferred automatically
ddf = dd.from_pandas(gdf[gdf.columns.drop("geometry")], npartitions=1)
# we don't want redundancy in the columns since this could lead to subtle bugs when converting back to geopandas
assert "x" not in ddf.columns
assert "y" not in ddf.columns
ddf["x"] = gdf.geometry.x
ddf["y"] = gdf.geometry.y
# parse
if "z" in ddf.columns:
logger.warning(
"Constructing the Dask DataFrame using the x and y coordinates from the geometry and the z from an "
"additional column."
)
if not suppress_z_warning:
logger.warning(
"Constructing the Dask DataFrame using the x and y coordinates from the geometry and the z from an "
"additional column."
)
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y", "z": "z"})
else:
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y"})
123 changes: 121 additions & 2 deletions tests/conftest.py
@@ -4,6 +4,7 @@
os.environ["USE_PYGEOS"] = "0"
# isort:on

from shapely import linearrings, polygons
from pathlib import Path
from typing import Union
from spatialdata._types import ArrayLike
@@ -29,6 +30,7 @@
)
from xarray import DataArray
from spatialdata.datasets import BlobsDataset
import geopandas as gpd

RNG = default_rng()

@@ -249,10 +251,16 @@ def _get_points() -> dict[str, DaskDataFrame]:
out = {}
for i in range(2):
name = f"{name}_{i}"
arr = RNG.normal(size=(100, 2))
arr = RNG.normal(size=(300, 2))
# randomly assign some values from v to the points
points_assignment0 = RNG.integers(0, 10, size=arr.shape[0]).astype(np.int_)
genes = RNG.choice(["a", "b"], size=arr.shape[0])
if i == 0:
genes = RNG.choice(["a", "b"], size=arr.shape[0])
else:
# we need to test the case in which we have a categorical column with more than 127 categories, see full
# explanation in write_points() (the parser will convert this column to a categorical since
# feature_key="genes")
genes = np.tile(np.array(list(map(str, range(280)))), 2)[:300]
annotation = pd.DataFrame(
{
"genes": genes,
@@ -299,3 +307,114 @@ def sdata_blobs() -> SpatialData:
sdata.labels["blobs_multiscale_labels"]
)
return sdata


def _make_points(coordinates: np.ndarray) -> DaskDataFrame:
"""Helper function to make a Points element."""
k0 = int(len(coordinates) / 3)
k1 = len(coordinates) - k0
genes = np.hstack((np.repeat("a", k0), np.repeat("b", k1)))
return PointsModel.parse(coordinates, annotation=pd.DataFrame({"genes": genes}), feature_key="genes")


def _make_squares(centroid_coordinates: np.ndarray, half_widths: list[float]) -> GeoDataFrame:
linear_rings = []
for centroid, half_width in zip(centroid_coordinates, half_widths):
min_coords = centroid - half_width
max_coords = centroid + half_width

linear_rings.append(
linearrings(
[
[min_coords[0], min_coords[1]],
[min_coords[0], max_coords[1]],
[max_coords[0], max_coords[1]],
[max_coords[0], min_coords[1]],
]
)
)
s = polygons(linear_rings)
polygon_series = gpd.GeoSeries(s)
cell_polygon_table = gpd.GeoDataFrame(geometry=polygon_series)
return ShapesModel.parse(cell_polygon_table)


def _make_circles(centroid_coordinates: np.ndarray, radius: list[float]) -> GeoDataFrame:
return ShapesModel.parse(centroid_coordinates, geometry=0, radius=radius)


def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData:
"""
Creates a SpatialData object with many edge cases for testing querying and aggregation.
Returns
-------
The SpatialData object.
Notes
-----
Description of what is tested (for a quick visualization, plot the returned SpatialData object):
- values to query/aggregate: polygons, points, circles
- values to query by: polygons, circles
- the shapes are completely inside, outside, or intersecting the query region (with the centroid inside or outside
the query region)
Additional cases:
- concave shape intersecting the same shape multiple times; used both as query and as value
- shape intersecting multiple shapes; used both as query and as value
"""
values_centroids_squares = np.array([[x * 18, 0] for x in range(8)] + [[8 * 18 + 7, 0]] + [[0, 90], [50, 90]])
values_centroids_circles = np.array([[x * 18, 30] for x in range(8)] + [[8 * 18 + 7, 30]])
by_centroids_squares = np.array([[119, 15], [100, 90], [150, 90], [210, 15]])
by_centroids_circles = np.array([[24, 15], [290, 15]])
values_points = _make_points(np.vstack((values_centroids_squares, values_centroids_circles)))
values_squares = _make_squares(values_centroids_squares, half_widths=[6] * 9 + [15, 15])
values_circles = _make_circles(values_centroids_circles, radius=[6] * 9)
by_squares = _make_squares(by_centroids_squares, half_widths=[30, 15, 15, 30])
by_circles = _make_circles(by_centroids_circles, radius=[30, 30])

from shapely.geometry import Polygon

polygon = Polygon([(100, 90 - 10), (100 + 30, 90), (100, 90 + 10), (150, 90)])
values_squares.loc[len(values_squares)] = [polygon]
ShapesModel.validate(values_squares)

polygon = Polygon([(0, 90 - 10), (0 + 30, 90), (0, 90 + 10), (50, 90)])
by_squares.loc[len(by_squares)] = [polygon]
ShapesModel.validate(by_squares)

sdata = SpatialData(
points={"points": values_points},
shapes={
"values_polygons": values_squares,
"values_circles": values_circles,
"by_polygons": by_squares,
"by_circles": by_circles,
},
)
# to visualize the cases considered in the test; much more immediate than reading the textual description above
PLOT = False
if PLOT:
##
import matplotlib.pyplot as plt

ax = plt.gca()
sdata.pl.render_shapes(element="values_polygons", na_color=(0.5, 0.2, 0.5, 0.5)).pl.render_points().pl.show(
ax=ax
)
sdata.pl.render_shapes(element="values_circles", na_color=(0.5, 0.2, 0.5, 0.5)).pl.show(ax=ax)
sdata.pl.render_shapes(element="by_polygons", na_color=(1.0, 0.7, 0.7, 0.5)).pl.show(ax=ax)
sdata.pl.render_shapes(element="by_circles", na_color=(1.0, 0.7, 0.7, 0.5)).pl.show(ax=ax)
plt.show()
##

# generate table
x = np.ones((21, 2)) * np.array([1, 2])
region = np.array(["values_circles"] * 9 + ["values_polygons"] * 12)
instance_id = np.array(list(range(9)) + list(range(12)))
table = AnnData(x, obs=pd.DataFrame({"region": region, "instance_id": instance_id}))
table = TableModel.parse(
table, region=["values_circles", "values_polygons"], region_key="region", instance_key="instance_id"
)
sdata.table = table
return sdata