Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/aggregation #271

Merged
merged 5 commits into from
May 22, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/spatialdata/_core/operations/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
ShapesModel,
get_model,
)
from spatialdata.models._utils import get_axes_names
from spatialdata.transformations import BaseTransformation, Identity, get_transformation

__all__ = ["aggregate"]
Expand Down Expand Up @@ -112,8 +111,8 @@ def _aggregate_points_by_shapes(
value_key: str | None = None,
agg_func: str | list[str] = "count",
) -> ad.AnnData:
# Have to get dims on dask dataframe, can't get from pandas
dims = get_axes_names(points)
from spatialdata.models import points_dask_dataframe_to_geopandas

# Default value for id_key
if id_key is None:
id_key = points.attrs[PointsModel.ATTRS_KEY][PointsModel.FEATURE_KEY]
Expand All @@ -123,9 +122,8 @@ def _aggregate_points_by_shapes(
"`FEATURE_KEY` for the points."
)

if isinstance(points, ddf.DataFrame):
points = points.compute()
points = gpd.GeoDataFrame(points, geometry=gpd.points_from_xy(*[points[dim] for dim in dims]))
points = points_dask_dataframe_to_geopandas(points, suppress_z_warning=True)
shapes = circles_to_polygons(shapes)

return _aggregate_shapes(points, shapes, id_key, value_key, agg_func)

Expand Down Expand Up @@ -253,8 +251,12 @@ def _aggregate_shapes(
value_key: point_values,
}
)
##
aggregated = to_agg.groupby([by_id_key, id_key]).agg(agg_func).reset_index()
obs_id_categorical = pd.Categorical(aggregated[by_id_key])

# this is for only shapes in "by" that intersect with something in "value"
obs_id_categorical_categories = by.index.tolist()
obs_id_categorical = pd.Categorical(aggregated[by_id_key], categories=obs_id_categorical_categories)

X = sparse.coo_matrix(
(
Expand All @@ -265,7 +267,8 @@ def _aggregate_shapes(
).tocsr()
return ad.AnnData(
X,
obs=pd.DataFrame(index=obs_id_categorical.categories),
# obs=pd.DataFrame(index=obs_id_categorical.categories),
obs=pd.DataFrame(index=pd.Categorical(by.index.tolist()).categories),
LucaMarconato marked this conversation as resolved.
Show resolved Hide resolved
var=pd.DataFrame(index=joined[id_key].cat.categories),
dtype=X.dtype,
)
18 changes: 12 additions & 6 deletions src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,18 @@ def _bounding_box_mask_points(
def _dict_query_dispatcher(
elements: dict[str, SpatialElement], query_function: Callable[[SpatialElement], SpatialElement], **kwargs: Any
) -> dict[str, SpatialElement]:
from spatialdata.transformations import get_transformation

queried_elements = {}
for key, element in elements.items():
result = query_function(element, **kwargs)
if result is not None:
# query returns None if it is empty
queried_elements[key] = result
target_coordinate_system = kwargs["target_coordinate_system"]
d = get_transformation(element, get_all=True)
assert isinstance(d, dict)
if target_coordinate_system in d:
result = query_function(element, **kwargs)
if result is not None:
# query returns None if it is empty
queried_elements[key] = result
return queried_elements


Expand Down Expand Up @@ -649,12 +655,12 @@ def _polygon_query(
new_points = {}
if points:
for points_name, p in sdata.points.items():
points_gdf = points_dask_dataframe_to_geopandas(p)
points_gdf = points_dask_dataframe_to_geopandas(p, suppress_z_warning=True)
LucaMarconato marked this conversation as resolved.
Show resolved Hide resolved
indices = points_gdf.geometry.intersects(polygon)
if np.sum(indices) == 0:
raise ValueError("we expect at least one point")
queried_points = points_gdf[indices]
ddf = points_geopandas_to_dask_dataframe(queried_points)
ddf = points_geopandas_to_dask_dataframe(queried_points, suppress_z_warning=True)
transformation = get_transformation(p, target_coordinate_system)
if "z" in ddf.columns:
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y", "z": "z"})
Expand Down
12 changes: 12 additions & 0 deletions src/spatialdata/_io/io_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ def write_points(

points_groups = group.require_group(name)
path = Path(points_groups._store.path) / points_groups.path / "points.parquet"

# The following code iterates through all columns in the 'points' DataFrame. If the column's datatype is
# 'category', it checks whether the categories of this column are known. If not, it explicitly converts the
# categories to known categories using 'c.cat.as_known()' and assigns the transformed Series back to the original
# DataFrame. This step is crucial when the number of categories exceeds 127, as pyarrow defaults to int8 for
# unknown categories which can only hold values from -128 to 127.
for column_name in points.columns:
c = points[column_name]
if c.dtype == "category" and not c.cat.known:
c = c.cat.as_known()
points[column_name] = c

points.to_parquet(path)

attrs = fmt.attrs_to_dict(points.attrs)
Expand Down
34 changes: 24 additions & 10 deletions src/spatialdata/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def _validate_dims(dims: tuple[str, ...]) -> None:
raise ValueError(f"Invalid dimensions: {dims}")


def points_dask_dataframe_to_geopandas(points: DaskDataFrame) -> GeoDataFrame:
def points_dask_dataframe_to_geopandas(points: DaskDataFrame, suppress_z_warning: bool = False) -> GeoDataFrame:
"""
Convert a Dask DataFrame to a GeoDataFrame.

Expand All @@ -212,16 +212,25 @@ def points_dask_dataframe_to_geopandas(points: DaskDataFrame) -> GeoDataFrame:
points need to be saved as a Dask DataFrame. We will be restructuring the models to allow for GeoDataFrames soon.

"""
if "z" in points.columns:
from spatialdata.transformations import get_transformation, set_transformation

if "z" in points.columns and not suppress_z_warning:
logger.warning("Constructing the GeoDataFrame without considering the z coordinate in the geometry.")

points_gdf = GeoDataFrame(geometry=geopandas.points_from_xy(points["x"], points["y"]))
for c in points.columns:
points_gdf[c] = points[c]
transformations = get_transformation(points, get_all=True)
assert isinstance(transformations, dict)
assert len(transformations) > 0
points = points.compute()
points_gdf = GeoDataFrame(points, geometry=geopandas.points_from_xy(points["x"], points["y"]))
points_gdf.reset_index(drop=True, inplace=True)
# keep the x and y either in the geometry either as columns: we don't duplicate because having this redundancy could
# lead to subtle bugs when coverting back to dask dataframes
points_gdf.drop(columns=["x", "y"], inplace=True)
set_transformation(points_gdf, transformations, set_all=True)
return points_gdf


def points_geopandas_to_dask_dataframe(gdf: GeoDataFrame) -> DaskDataFrame:
def points_geopandas_to_dask_dataframe(gdf: GeoDataFrame, suppress_z_warning: bool = False) -> DaskDataFrame:
"""
Convert a GeoDataFrame which represents 2D or 3D points to a Dask DataFrame that passes the schema validation.

Expand All @@ -241,15 +250,20 @@ def points_geopandas_to_dask_dataframe(gdf: GeoDataFrame) -> DaskDataFrame:
"""
from spatialdata.models import PointsModel

# transformations are transferred automatically
ddf = dd.from_pandas(gdf[gdf.columns.drop("geometry")], npartitions=1)
# we don't want redundancy in the columns since this could lead to subtle bugs when converting back to geopandas
assert "x" not in ddf.columns
assert "y" not in ddf.columns
ddf["x"] = gdf.geometry.x
ddf["y"] = gdf.geometry.y
# parse
if "z" in ddf.columns:
logger.warning(
"Constructing the Dask DataFrame using the x and y coordinates from the geometry and the z from an "
"additional column."
)
if not suppress_z_warning:
logger.warning(
"Constructing the Dask DataFrame using the x and y coordinates from the geometry and the z from an "
"additional column."
)
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y", "z": "z"})
else:
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y"})
Expand Down
123 changes: 121 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
os.environ["USE_PYGEOS"] = "0"
# isort:on

from shapely import linearrings, polygons
from pathlib import Path
from typing import Union
from spatialdata._types import ArrayLike
Expand All @@ -29,6 +30,7 @@
)
from xarray import DataArray
from spatialdata.datasets import BlobsDataset
import geopandas as gpd

RNG = default_rng()

Expand Down Expand Up @@ -249,10 +251,16 @@ def _get_points() -> dict[str, DaskDataFrame]:
out = {}
for i in range(2):
name = f"{name}_{i}"
arr = RNG.normal(size=(100, 2))
arr = RNG.normal(size=(300, 2))
# randomly assign some values from v to the points
points_assignment0 = RNG.integers(0, 10, size=arr.shape[0]).astype(np.int_)
genes = RNG.choice(["a", "b"], size=arr.shape[0])
if i == 0:
genes = RNG.choice(["a", "b"], size=arr.shape[0])
else:
# we need to test the case in which we have a categorical column with more than 127 categories, see full
# explanation in write_points() (the parser will convert this column to a categorical since
# feature_key="genes")
genes = np.tile(np.array(list(map(str, range(280)))), 2)[:300]
annotation = pd.DataFrame(
{
"genes": genes,
Expand Down Expand Up @@ -299,3 +307,114 @@ def sdata_blobs() -> SpatialData:
sdata.labels["blobs_multiscale_labels"]
)
return sdata


def _make_points(coordinates: np.ndarray) -> DaskDataFrame:
"""Helper function to make a Points element."""
k0 = int(len(coordinates) / 3)
k1 = len(coordinates) - k0
genes = np.hstack((np.repeat("a", k0), np.repeat("b", k1)))
return PointsModel.parse(coordinates, annotation=pd.DataFrame({"genes": genes}), feature_key="genes")


def _make_squares(centroid_coordinates: np.ndarray, half_widths: list[float]) -> polygons:
linear_rings = []
for centroid, half_width in zip(centroid_coordinates, half_widths):
min_coords = centroid - half_width
max_coords = centroid + half_width

linear_rings.append(
linearrings(
[
[min_coords[0], min_coords[1]],
[min_coords[0], max_coords[1]],
[max_coords[0], max_coords[1]],
[max_coords[0], min_coords[1]],
]
)
)
s = polygons(linear_rings)
polygon_series = gpd.GeoSeries(s)
cell_polygon_table = gpd.GeoDataFrame(geometry=polygon_series)
return ShapesModel.parse(cell_polygon_table)


def _make_circles(centroid_coordinates: np.ndarray, radius: list[float]) -> GeoDataFrame:
return ShapesModel.parse(centroid_coordinates, geometry=0, radius=radius)


def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData:
"""
Creates a SpatialData object with many edge cases for testing querying and aggregation.

Returns
-------
The SpatialData object.

Notes
-----
Description of what is tested (for a quick visualization, plot the returned SpatialData object):
- values to query/aggregate: polygons, points, circles
- values to query by: polygons, circles
- the shapes are completely inside, outside, or intersecting the query region (with the centroid inside or outside
the query region)

Additional cases:
- concave shape intersecting multiple times the same shape; used both as query and as value
- shape intersecting multiple shapes; used both as query and as value
"""
values_centroids_squares = np.array([[x * 18, 0] for x in range(8)] + [[8 * 18 + 7, 0]] + [[0, 90], [50, 90]])
values_centroids_circles = np.array([[x * 18, 30] for x in range(8)] + [[8 * 18 + 7, 30]])
by_centroids_squares = np.array([[119, 15], [100, 90], [150, 90], [210, 15]])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New addition: added one extra by square and one extra by circle, the rest is identical as before, just moved between files.

by_centroids_circles = np.array([[24, 15], [290, 15]])
values_points = _make_points(np.vstack((values_centroids_squares, values_centroids_circles)))
values_squares = _make_squares(values_centroids_squares, half_widths=[6] * 9 + [15, 15])
values_circles = _make_circles(values_centroids_circles, radius=[6] * 9)
by_squares = _make_squares(by_centroids_squares, half_widths=[30, 15, 15, 30])
by_circles = _make_circles(by_centroids_circles, radius=[30, 30])

from shapely.geometry import Polygon

polygon = Polygon([(100, 90 - 10), (100 + 30, 90), (100, 90 + 10), (150, 90)])
values_squares.loc[len(values_squares)] = [polygon]
ShapesModel.validate(values_squares)

polygon = Polygon([(0, 90 - 10), (0 + 30, 90), (0, 90 + 10), (50, 90)])
by_squares.loc[len(by_squares)] = [polygon]
ShapesModel.validate(by_squares)

sdata = SpatialData(
points={"points": values_points},
shapes={
"values_polygons": values_squares,
"values_circles": values_circles,
"by_polygons": by_squares,
"by_circles": by_circles,
},
)
# to visualize the cases considered in the test, much more immediate than reading them as text as done above
PLOT = False
if PLOT:
##
import matplotlib.pyplot as plt

ax = plt.gca()
sdata.pl.render_shapes(element="values_polygons", na_color=(0.5, 0.2, 0.5, 0.5)).pl.render_points().pl.show(
ax=ax
)
sdata.pl.render_shapes(element="values_circles", na_color=(0.5, 0.2, 0.5, 0.5)).pl.show(ax=ax)
sdata.pl.render_shapes(element="by_polygons", na_color=(1.0, 0.7, 0.7, 0.5)).pl.show(ax=ax)
sdata.pl.render_shapes(element="by_circles", na_color=(1.0, 0.7, 0.7, 0.5)).pl.show(ax=ax)
plt.show()
##

# generate table
x = np.ones((21, 2)) * np.array([1, 2])
region = np.array(["values_circles"] * 9 + ["values_polygons"] * 12)
instance_id = np.array(list(range(9)) + list(range(12)))
table = AnnData(x, obs=pd.DataFrame({"region": region, "instance_id": instance_id}))
table = TableModel.parse(
table, region=["values_circles", "values_polygons"], region_key="region", instance_key="instance_id"
)
sdata.table = table
return sdata
Loading