From 65793463a3398488a50a92ec502c45188ca76497 Mon Sep 17 00:00:00 2001 From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com> Date: Fri, 19 May 2023 16:07:20 +0200 Subject: [PATCH 1/5] adjusted tests hierarchy --- tests/core/query/test_spatial_query.py | 3 ++ tests/dataloader/test_datasets.py | 62 ++++++++++++++++++++++++++ tests/test_dataloader/test_datasets.py | 62 -------------------------- 3 files changed, 65 insertions(+), 62 deletions(-) delete mode 100644 tests/test_dataloader/test_datasets.py diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index 9180059a..8c77c11b 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -553,18 +553,21 @@ def test_polygon_query_image2d(): pass +@pytest.mark.skip def test_polygon_query_image3d(): # single image case # multiscale case pass +@pytest.mark.skip def test_polygon_query_labels2d(): # single image case # multiscale case pass +@pytest.mark.skip def test_polygon_query_labels3d(): # single image case # multiscale case diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py index e69de29b..b5772531 100644 --- a/tests/dataloader/test_datasets.py +++ b/tests/dataloader/test_datasets.py @@ -0,0 +1,62 @@ +import contextlib + +import numpy as np +import pandas as pd +import pytest +from anndata import AnnData +from spatialdata.dataloader import ImageTilesDataset +from spatialdata.models import TableModel + + +@pytest.mark.parametrize("image_element", ["blobs_image", "blobs_multiscale_image"]) +@pytest.mark.parametrize( + "regions_element", + ["blobs_labels", "blobs_multiscale_labels", "blobs_circles", "blobs_polygons", "blobs_multipolygons"], +) +def test_tiles_dataset(sdata_blobs, image_element, regions_element): + if regions_element in ["blobs_labels", "blobs_multipolygons", "blobs_multiscale_labels"]: + cm = pytest.raises(NotImplementedError) + else: + cm = contextlib.nullcontext() + with cm: + ds = ImageTilesDataset( + sdata=sdata_blobs, + regions_to_images={regions_element: image_element}, + tile_dim_in_units=10, + tile_dim_in_pixels=32, + target_coordinate_system="global", + ) + tile = ds[0].images.values().__iter__().__next__() + assert tile.shape == (3, 32, 32) + + +def test_tiles_table(sdata_blobs): + new_table = AnnData( + X=np.random.default_rng().random((3, 10)), + obs=pd.DataFrame({"region": "blobs_circles", "instance_id": np.array([0, 1, 2])}), + ) + new_table = TableModel.parse(new_table, region="blobs_circles", region_key="region", instance_key="instance_id") + del sdata_blobs.table + sdata_blobs.table = new_table + ds = ImageTilesDataset( + sdata=sdata_blobs, + regions_to_images={"blobs_circles": "blobs_image"}, + tile_dim_in_units=10, + tile_dim_in_pixels=32, + target_coordinate_system="global", + ) + assert len(ds) == 3 + assert len(ds[0].table) == 1 + assert np.all(ds[0].table.X == new_table[0].X) + + +def test_tiles_multiple_elements(sdata_blobs): + ds = ImageTilesDataset( + sdata=sdata_blobs, + regions_to_images={"blobs_circles": "blobs_image", "blobs_polygons": "blobs_multiscale_image"}, + tile_dim_in_units=10, + tile_dim_in_pixels=32, + target_coordinate_system="global", + ) + assert len(ds) == 6 + _ = ds[0] diff --git a/tests/test_dataloader/test_datasets.py b/tests/test_dataloader/test_datasets.py deleted file mode 100644 index b5772531..00000000 --- a/tests/test_dataloader/test_datasets.py +++ /dev/null @@ -1,62 +0,0 @@ -import contextlib - -import numpy as np -import pandas as pd -import pytest -from anndata import AnnData -from spatialdata.dataloader import ImageTilesDataset -from spatialdata.models import TableModel - - -@pytest.mark.parametrize("image_element", ["blobs_image", "blobs_multiscale_image"]) -@pytest.mark.parametrize( - "regions_element", - ["blobs_labels", "blobs_multiscale_labels", "blobs_circles", "blobs_polygons", "blobs_multipolygons"], -) -def test_tiles_dataset(sdata_blobs, image_element, regions_element): - if regions_element in ["blobs_labels", "blobs_multipolygons", "blobs_multiscale_labels"]: - cm = pytest.raises(NotImplementedError) - else: - cm = contextlib.nullcontext() - with cm: - ds = ImageTilesDataset( - sdata=sdata_blobs, - regions_to_images={regions_element: image_element}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", - ) - tile = ds[0].images.values().__iter__().__next__() - assert tile.shape == (3, 32, 32) - - -def test_tiles_table(sdata_blobs): - new_table = AnnData( - X=np.random.default_rng().random((3, 10)), - obs=pd.DataFrame({"region": "blobs_circles", "instance_id": np.array([0, 1, 2])}), - ) - new_table = TableModel.parse(new_table, region="blobs_circles", region_key="region", instance_key="instance_id") - del sdata_blobs.table - sdata_blobs.table = new_table - ds = ImageTilesDataset( - sdata=sdata_blobs, - regions_to_images={"blobs_circles": "blobs_image"}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", - ) - assert len(ds) == 3 - assert len(ds[0].table) == 1 - assert np.all(ds[0].table.X == new_table[0].X) - - -def test_tiles_multiple_elements(sdata_blobs): - ds = ImageTilesDataset( - sdata=sdata_blobs, - regions_to_images={"blobs_circles": "blobs_image", "blobs_polygons": "blobs_multiscale_image"}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", - ) - assert len(ds) == 6 - _ = ds[0] From c6ba2109eb79e7719c94d5d23f9fb154c3c0034a Mon Sep 17 00:00:00 2001 From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com> Date: Fri, 19 May 2023 16:14:18 +0200 Subject: [PATCH 2/5] moved data for tests in conftest --- tests/conftest.py | 111 +++++++++++++++++++++++ tests/core/query/test_spatial_query.py | 116 +------------------------ 2 files changed, 112 insertions(+), 115 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 4ff8975e..f5456ab6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ os.environ["USE_PYGEOS"] = "0" # isort:on +from shapely import linearrings, polygons from pathlib import Path from typing import Union from spatialdata._types import ArrayLike @@ -29,6 +30,7 @@ ) from xarray import DataArray from spatialdata.datasets import BlobsDataset +import geopandas as gpd RNG = default_rng() @@ -299,3 +301,112 @@ def sdata_blobs() -> SpatialData: sdata.labels["blobs_multiscale_labels"] ) return sdata + + +def _make_points(coordinates: np.ndarray) -> DaskDataFrame: + """Helper function to make a Points element.""" + k0 = int(len(coordinates) / 3) + k1 = len(coordinates) - k0 + genes = np.hstack((np.repeat("a", k0), np.repeat("b", k1))) + return PointsModel.parse(coordinates, annotation=pd.DataFrame({"genes": genes}), feature_key="genes") + + +def _make_squares(centroid_coordinates: np.ndarray, half_widths: list[float]) -> polygons: + linear_rings = [] + for centroid, half_width in zip(centroid_coordinates, half_widths): + min_coords = centroid - half_width + max_coords = centroid + half_width + + linear_rings.append( + linearrings( + [ + [min_coords[0], min_coords[1]], + [min_coords[0], max_coords[1]], + [max_coords[0], max_coords[1]], + [max_coords[0], min_coords[1]], + ] + ) + ) + s = polygons(linear_rings) + polygon_series = gpd.GeoSeries(s) + cell_polygon_table = gpd.GeoDataFrame(geometry=polygon_series) + return ShapesModel.parse(cell_polygon_table) + + +def _make_circles(centroid_coordinates: np.ndarray, radius: list[float]) -> GeoDataFrame: + return ShapesModel.parse(centroid_coordinates, geometry=0, radius=radius) + + +def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData: + """ + Creates a SpatialData object with many edge cases for testing querying and aggregation. + + Returns + ------- + The SpatialData object. + + Notes + ----- + Description of what is tested (for a quick visualization, plot the returned SpatialData object): + - values to query/aggregate: polygons, points, circles + - values to query by: polygons, circles + - the shapes are completely inside, outside, or intersecting the query region (with the centroid inside or outside + the query region) + + Additional cases: + - concave shape intersecting multiple times the same shape; used both as query and as value + - shape intersecting multiple shapes; used both as query and as value + """ + values_centroids_squares = np.array([[x * 18, 0] for x in range(8)] + [[8 * 18 + 7, 0]] + [[0, 90], [50, 90]]) + values_centroids_circles = np.array([[x * 18, 30] for x in range(8)] + [[8 * 18 + 7, 30]]) + by_centroids_squares = np.array([[119, 15], [100, 90], [150, 90]]) + by_centroids_circles = np.array([[24, 15]]) + values_points = _make_points(np.vstack((values_centroids_squares, values_centroids_circles))) + values_squares = _make_squares(values_centroids_squares, half_widths=[6] * 9 + [15, 15]) + values_circles = _make_circles(values_centroids_circles, radius=[6] * 9) + by_squares = _make_squares(by_centroids_squares, half_widths=[30, 15, 15]) + by_circles = _make_circles(by_centroids_circles, radius=[30]) + + from shapely.geometry import Polygon + + polygon = Polygon([(100, 90 - 10), (100 + 30, 90), (100, 90 + 10), (150, 90)]) + values_squares.loc[len(values_squares)] = [polygon] + ShapesModel.validate(values_squares) + + polygon = Polygon([(0, 90 - 10), (0 + 30, 90), (0, 90 + 10), (50, 90)]) + by_squares.loc[len(by_squares)] = [polygon] + ShapesModel.validate(by_squares) + + sdata = SpatialData( + points={"points": values_points}, + shapes={ + "values_polygons": values_squares, + "values_circles": values_circles, + "by_polygons": by_squares, + "by_circles": by_circles, + }, + ) + # to visualize the cases considered in the test, much more immediate than reading them as text as done above + PLOT = False + if PLOT: + import matplotlib.pyplot as plt + + ax = plt.gca() + sdata.pl.render_shapes(element="values_polygons", na_color=(0.5, 0.2, 0.5, 0.5)).pl.render_points().pl.show( + ax=ax + ) + sdata.pl.render_shapes(element="values_circles", na_color=(0.5, 0.2, 0.5, 0.5)).pl.show(ax=ax) + sdata.pl.render_shapes(element="by_polygons", na_color=(1.0, 0.7, 0.7, 0.5)).pl.show(ax=ax) + sdata.pl.render_shapes(element="by_circles", na_color=(1.0, 0.7, 0.7, 0.5)).pl.show(ax=ax) + plt.show() + + # generate table + x = np.ones((21, 2)) * np.array([1, 2]) + region = np.array(["values_circles"] * 9 + ["values_polygons"] * 12) + instance_id = np.array(list(range(9)) + list(range(12))) + table = AnnData(x, obs=pd.DataFrame({"region": region, "instance_id": instance_id})) + table = TableModel.parse( + table, region=["values_circles", "values_polygons"], region_key="region", instance_key="instance_id" + ) + sdata.table = table + return sdata diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index 8c77c11b..1304614d 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -1,14 +1,9 @@ from dataclasses import FrozenInstanceError -import geopandas as gpd import numpy as np -import pandas as pd import pytest from anndata import AnnData -from dask.dataframe.core import DataFrame as DaskDataFrame -from geopandas import GeoDataFrame from multiscale_spatial_image import MultiscaleSpatialImage -from shapely import linearrings, polygons from spatial_image import SpatialImage from spatialdata import SpatialData from spatialdata._core.query.spatial_query import ( @@ -31,44 +26,10 @@ set_transformation, ) - -def _make_points(coordinates: np.ndarray) -> DaskDataFrame: - """Helper function to make a Points element.""" - k0 = int(len(coordinates) / 3) - k1 = len(coordinates) - k0 - genes = np.hstack((np.repeat("a", k0), np.repeat("b", k1))) - return PointsModel.parse(coordinates, annotation=pd.DataFrame({"genes": genes}), feature_key="genes") - - -def _make_squares(centroid_coordinates: np.ndarray, half_widths: list[float]) -> polygons: - linear_rings = [] - for centroid, half_width in zip(centroid_coordinates, half_widths): - min_coords = centroid - half_width - max_coords = centroid + half_width - - linear_rings.append( - linearrings( - [ - [min_coords[0], min_coords[1]], - [min_coords[0], max_coords[1]], - [max_coords[0], max_coords[1]], - [max_coords[0], min_coords[1]], - ] - ) - ) - s = polygons(linear_rings) - polygon_series = gpd.GeoSeries(s) - cell_polygon_table = gpd.GeoDataFrame(geometry=polygon_series) - return ShapesModel.parse(cell_polygon_table) - - -def _make_circles(centroid_coordinates: np.ndarray, radius: list[float]) -> GeoDataFrame: - return ShapesModel.parse(centroid_coordinates, geometry=0, radius=radius) +from tests.conftest import _make_points, _make_sdata_for_testing_querying_and_aggretation, _make_squares # ---------------- test bounding box queries ---------------[ - - def test_bounding_box_request_immutable(): """Test that the bounding box request is immutable.""" request = BoundingBoxRequest( @@ -396,81 +357,6 @@ def test_bounding_box_filter_table(): # ----------------- test polygon query ----------------- -def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData: - """ - Creates a SpatialData object with many edge cases for testing querying and aggregation. - - Returns - ------- - The SpatialData object. - - Notes - ----- - Description of what is tested (for a quick visualization, plot the returned SpatialData object): - - values to query/aggregate: polygons, points, circles - - values to query by: polygons, circles - - the shapes are completely inside, outside, or intersecting the query region (with the centroid inside or outside - the query region) - - Additional cases: - - concave shape intersecting multiple times the same shape; used both as query and as value - - shape intersecting multiple shapes; used both as query and as value - """ - values_centroids_squares = np.array([[x * 18, 0] for x in range(8)] + [[8 * 18 + 7, 0]] + [[0, 90], [50, 90]]) - values_centroids_circles = np.array([[x * 18, 30] for x in range(8)] + [[8 * 18 + 7, 30]]) - by_centroids_squares = np.array([[119, 15], [100, 90], [150, 90]]) - by_centroids_circles = np.array([[24, 15]]) - values_points = _make_points(np.vstack((values_centroids_squares, values_centroids_circles))) - values_squares = _make_squares(values_centroids_squares, half_widths=[6] * 9 + [15, 15]) - values_circles = _make_circles(values_centroids_circles, radius=[6] * 9) - by_squares = _make_squares(by_centroids_squares, half_widths=[30, 15, 15]) - by_circles = _make_circles(by_centroids_circles, radius=[30]) - - from shapely.geometry import Polygon - - polygon = Polygon([(100, 90 - 10), (100 + 30, 90), (100, 90 + 10), (150, 90)]) - values_squares.loc[len(values_squares)] = [polygon] - ShapesModel.validate(values_squares) - - polygon = Polygon([(0, 90 - 10), (0 + 30, 90), (0, 90 + 10), (50, 90)]) - by_squares.loc[len(by_squares)] = [polygon] - ShapesModel.validate(by_squares) - - sdata = SpatialData( - points={"points": values_points}, - shapes={ - "values_polygons": values_squares, - "values_circles": values_circles, - "by_polygons": by_squares, - "by_circles": by_circles, - }, - ) - # to visualize the cases considered in the test, much more immediate than reading them as text as done above - PLOT = False - if PLOT: - import matplotlib.pyplot as plt - - ax = plt.gca() - sdata.pl.render_shapes(element="values_polygons", na_color=(0.5, 0.2, 0.5, 0.5)).pl.render_points().pl.show( - ax=ax - ) - sdata.pl.render_shapes(element="values_circles", na_color=(0.5, 0.2, 0.5, 0.5)).pl.show(ax=ax) - sdata.pl.render_shapes(element="by_polygons", na_color=(1.0, 0.7, 0.7, 0.5)).pl.show(ax=ax) - sdata.pl.render_shapes(element="by_circles", na_color=(1.0, 0.7, 0.7, 0.5)).pl.show(ax=ax) - plt.show() - - # generate table - x = np.ones((21, 2)) * np.array([1, 2]) - region = np.array(["values_circles"] * 9 + ["values_polygons"] * 12) - instance_id = np.array(list(range(9)) + list(range(12))) - table = AnnData(x, obs=pd.DataFrame({"region": region, "instance_id": instance_id})) - table = TableModel.parse( - table, region=["values_circles", "values_polygons"], region_key="region", instance_key="instance_id" - ) - sdata.table = table - return sdata - - def test_polygon_query_points(): sdata = _make_sdata_for_testing_querying_and_aggretation() polygon = sdata["by_polygons"].geometry.iloc[0] From 5c13e8a12a97009c78b4e8be45e375586fde2ee0 Mon Sep 17 00:00:00 2001 From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com> Date: Fri, 19 May 2023 16:18:57 +0200 Subject: [PATCH 3/5] adding "by" shapes without values inside to the datasets for the tests --- tests/conftest.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index f5456ab6..6a136ecb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -359,13 +359,13 @@ def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData: """ values_centroids_squares = np.array([[x * 18, 0] for x in range(8)] + [[8 * 18 + 7, 0]] + [[0, 90], [50, 90]]) values_centroids_circles = np.array([[x * 18, 30] for x in range(8)] + [[8 * 18 + 7, 30]]) - by_centroids_squares = np.array([[119, 15], [100, 90], [150, 90]]) - by_centroids_circles = np.array([[24, 15]]) + by_centroids_squares = np.array([[119, 15], [100, 90], [150, 90], [210, 15]]) + by_centroids_circles = np.array([[24, 15], [290, 15]]) values_points = _make_points(np.vstack((values_centroids_squares, values_centroids_circles))) values_squares = _make_squares(values_centroids_squares, half_widths=[6] * 9 + [15, 15]) values_circles = _make_circles(values_centroids_circles, radius=[6] * 9) - by_squares = _make_squares(by_centroids_squares, half_widths=[30, 15, 15]) - by_circles = _make_circles(by_centroids_circles, radius=[30]) + by_squares = _make_squares(by_centroids_squares, half_widths=[30, 15, 15, 30]) + by_circles = _make_circles(by_centroids_circles, radius=[30, 30]) from shapely.geometry import Polygon From c0c104138d2fa652830d752dea3df389bf9ef25a Mon Sep 17 00:00:00 2001 From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com> Date: Sat, 20 May 2023 14:23:52 +0200 Subject: [PATCH 4/5] fix aggregate points into circles when circles have no points inside; fix bug with categories and to_parquet(); add tests --- src/spatialdata/_core/operations/aggregate.py | 19 +++++---- src/spatialdata/_core/query/spatial_query.py | 18 +++++--- src/spatialdata/_io/io_points.py | 12 ++++++ src/spatialdata/models/_utils.py | 34 ++++++++++----- tests/conftest.py | 12 +++++- tests/core/operations/test_aggregations.py | 41 ++++++++++++++++--- tests/models/test_models.py | 30 ++++++++++++++ 7 files changed, 135 insertions(+), 31 deletions(-) diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py index 51895f36..8fc69379 100644 --- a/src/spatialdata/_core/operations/aggregate.py +++ b/src/spatialdata/_core/operations/aggregate.py @@ -24,7 +24,6 @@ ShapesModel, get_model, ) -from spatialdata.models._utils import get_axes_names from spatialdata.transformations import BaseTransformation, Identity, get_transformation __all__ = ["aggregate"] @@ -112,8 +111,8 @@ def _aggregate_points_by_shapes( value_key: str | None = None, agg_func: str | list[str] = "count", ) -> ad.AnnData: - # Have to get dims on dask dataframe, can't get from pandas - dims = get_axes_names(points) + from spatialdata.models import points_dask_dataframe_to_geopandas + # Default value for id_key if id_key is None: id_key = points.attrs[PointsModel.ATTRS_KEY][PointsModel.FEATURE_KEY] @@ -123,9 +122,8 @@ def _aggregate_points_by_shapes( "`FEATURE_KEY` for the points." ) - if isinstance(points, ddf.DataFrame): - points = points.compute() - points = gpd.GeoDataFrame(points, geometry=gpd.points_from_xy(*[points[dim] for dim in dims])) + points = points_dask_dataframe_to_geopandas(points, suppress_z_warning=True) + shapes = circles_to_polygons(shapes) return _aggregate_shapes(points, shapes, id_key, value_key, agg_func) @@ -253,8 +251,12 @@ def _aggregate_shapes( value_key: point_values, } ) + ## aggregated = to_agg.groupby([by_id_key, id_key]).agg(agg_func).reset_index() - obs_id_categorical = pd.Categorical(aggregated[by_id_key]) + + # this is for only shapes in "by" that intersect with something in "value" + obs_id_categorical_categories = by.index.tolist() + obs_id_categorical = pd.Categorical(aggregated[by_id_key], categories=obs_id_categorical_categories) X = sparse.coo_matrix( ( @@ -265,7 +267,8 @@ def _aggregate_shapes( ).tocsr() return ad.AnnData( X, - obs=pd.DataFrame(index=obs_id_categorical.categories), + # obs=pd.DataFrame(index=obs_id_categorical.categories), + obs=pd.DataFrame(index=pd.Categorical(by.index.tolist()).categories), var=pd.DataFrame(index=joined[id_key].cat.categories), dtype=X.dtype, ) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index f94af552..0ff5c571 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -247,12 +247,18 @@ def _bounding_box_mask_points( def _dict_query_dispatcher( elements: dict[str, SpatialElement], query_function: Callable[[SpatialElement], SpatialElement], **kwargs: Any ) -> dict[str, SpatialElement]: + from spatialdata.transformations import get_transformation + queried_elements = {} for key, element in elements.items(): - result = query_function(element, **kwargs) - if result is not None: - # query returns None if it is empty - queried_elements[key] = result + target_coordinate_system = kwargs["target_coordinate_system"] + d = get_transformation(element, get_all=True) + assert isinstance(d, dict) + if target_coordinate_system in d: + result = query_function(element, **kwargs) + if result is not None: + # query returns None if it is empty + queried_elements[key] = result return queried_elements @@ -649,12 +655,12 @@ def _polygon_query( new_points = {} if points: for points_name, p in sdata.points.items(): - points_gdf = points_dask_dataframe_to_geopandas(p) + points_gdf = points_dask_dataframe_to_geopandas(p, suppress_z_warning=True) indices = points_gdf.geometry.intersects(polygon) if np.sum(indices) == 0: raise ValueError("we expect at least one point") queried_points = points_gdf[indices] - ddf = points_geopandas_to_dask_dataframe(queried_points) + ddf = points_geopandas_to_dask_dataframe(queried_points, suppress_z_warning=True) transformation = get_transformation(p, target_coordinate_system) if "z" in ddf.columns: ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y", "z": "z"}) diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index 21f09534..c951fa23 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -54,6 +54,18 @@ def write_points( points_groups = group.require_group(name) path = Path(points_groups._store.path) / points_groups.path / "points.parquet" + + # The following code iterates through all columns in the 'points' DataFrame. If the column's datatype is + # 'category', it checks whether the categories of this column are known. If not, it explicitly converts the + # categories to known categories using 'c.cat.as_known()' and assigns the transformed Series back to the original + # DataFrame. This step is crucial when the number of categories exceeds 127, as pyarrow defaults to int8 for + # unknown categories which can only hold values from -128 to 127. + for column_name in points.columns: + c = points[column_name] + if c.dtype == "category" and not c.cat.known: + c = c.cat.as_known() + points[column_name] = c + points.to_parquet(path) attrs = fmt.attrs_to_dict(points.attrs) diff --git a/src/spatialdata/models/_utils.py b/src/spatialdata/models/_utils.py index ddfe0945..9ab0f4dc 100644 --- a/src/spatialdata/models/_utils.py +++ b/src/spatialdata/models/_utils.py @@ -190,7 +190,7 @@ def _validate_dims(dims: tuple[str, ...]) -> None: raise ValueError(f"Invalid dimensions: {dims}") -def points_dask_dataframe_to_geopandas(points: DaskDataFrame) -> GeoDataFrame: +def points_dask_dataframe_to_geopandas(points: DaskDataFrame, suppress_z_warning: bool = False) -> GeoDataFrame: """ Convert a Dask DataFrame to a GeoDataFrame. @@ -212,16 +212,25 @@ def points_dask_dataframe_to_geopandas(points: DaskDataFrame) -> GeoDataFrame: points need to be saved as a Dask DataFrame. We will be restructuring the models to allow for GeoDataFrames soon. """ - if "z" in points.columns: + from spatialdata.transformations import get_transformation, set_transformation + + if "z" in points.columns and not suppress_z_warning: logger.warning("Constructing the GeoDataFrame without considering the z coordinate in the geometry.") - points_gdf = GeoDataFrame(geometry=geopandas.points_from_xy(points["x"], points["y"])) - for c in points.columns: - points_gdf[c] = points[c] + transformations = get_transformation(points, get_all=True) + assert isinstance(transformations, dict) + assert len(transformations) > 0 + points = points.compute() + points_gdf = GeoDataFrame(points, geometry=geopandas.points_from_xy(points["x"], points["y"])) + points_gdf.reset_index(drop=True, inplace=True) + # keep the x and y either in the geometry either as columns: we don't duplicate because having this redundancy could + # lead to subtle bugs when coverting back to dask dataframes + points_gdf.drop(columns=["x", "y"], inplace=True) + set_transformation(points_gdf, transformations, set_all=True) return points_gdf -def points_geopandas_to_dask_dataframe(gdf: GeoDataFrame) -> DaskDataFrame: +def points_geopandas_to_dask_dataframe(gdf: GeoDataFrame, suppress_z_warning: bool = False) -> DaskDataFrame: """ Convert a GeoDataFrame which represents 2D or 3D points to a Dask DataFrame that passes the schema validation. @@ -241,15 +250,20 @@ def points_geopandas_to_dask_dataframe(gdf: GeoDataFrame) -> DaskDataFrame: """ from spatialdata.models import PointsModel + # transformations are transferred automatically ddf = dd.from_pandas(gdf[gdf.columns.drop("geometry")], npartitions=1) + # we don't want redundancy in the columns since this could lead to subtle bugs when converting back to geopandas + assert "x" not in ddf.columns + assert "y" not in ddf.columns ddf["x"] = gdf.geometry.x ddf["y"] = gdf.geometry.y # parse if "z" in ddf.columns: - logger.warning( - "Constructing the Dask DataFrame using the x and y coordinates from the geometry and the z from an " - "additional column." - ) + if not suppress_z_warning: + logger.warning( + "Constructing the Dask DataFrame using the x and y coordinates from the geometry and the z from an " + "additional column." + ) ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y", "z": "z"}) else: ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y"}) diff --git a/tests/conftest.py b/tests/conftest.py index 6a136ecb..a595416e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -251,10 +251,16 @@ def _get_points() -> dict[str, DaskDataFrame]: out = {} for i in range(2): name = f"{name}_{i}" - arr = RNG.normal(size=(100, 2)) + arr = RNG.normal(size=(300, 2)) # randomly assign some values from v to the points points_assignment0 = RNG.integers(0, 10, size=arr.shape[0]).astype(np.int_) - genes = RNG.choice(["a", "b"], size=arr.shape[0]) + if i == 0: + genes = RNG.choice(["a", "b"], size=arr.shape[0]) + else: + # we need to test the case in which we have a categorical column with more than 127 categories, see full + # explanation in write_points() (the parser will convert this column to a categorical since + # feature_key="genes") + genes = np.tile(np.array(list(map(str, range(280)))), 2)[:300] annotation = pd.DataFrame( { "genes": genes, @@ -389,6 +395,7 @@ def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData: # to visualize the cases considered in the test, much more immediate than reading them as text as done above PLOT = False if PLOT: + ## import matplotlib.pyplot as plt ax = plt.gca() @@ -399,6 +406,7 @@ def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData: sdata.pl.render_shapes(element="by_polygons", na_color=(1.0, 0.7, 0.7, 0.5)).pl.show(ax=ax) sdata.pl.render_shapes(element="by_circles", na_color=(1.0, 0.7, 0.7, 0.5)).pl.show(ax=ax) plt.show() + ## # generate table x = np.ones((21, 2)) * np.array([1, 2]) diff --git a/tests/core/operations/test_aggregations.py b/tests/core/operations/test_aggregations.py index 8c3a8d7c..1641e27e 100644 --- a/tests/core/operations/test_aggregations.py +++ b/tests/core/operations/test_aggregations.py @@ -10,6 +10,8 @@ from spatialdata._core.operations.aggregate import aggregate from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel +from tests.conftest import _make_sdata_for_testing_querying_and_aggretation + RNG = default_rng(42) @@ -25,26 +27,55 @@ def test_aggregate_points_by_polygons() -> None: coordinates={"x": "x", "y": "y"}, feature_key="gene", ) + # shape_0 doesn't contain points, the other two shapes do shapes = ShapesModel.parse( gpd.GeoDataFrame( geometry=[ + shapely.Polygon([(0.0, 10.0), (2.0, 10.0), (0.0, 9.0)]), shapely.Polygon([(0.5, 7.0), (4.0, 2.0), (5.0, 8.0)]), shapely.Polygon([(3.0, 8.0), (7.0, 2.0), (10.0, 6.0), (7.0, 10.0)]), ], - index=["shape_0", "shape_1"], + index=["shape_0", "shape_1", "shape_2"], ) ) result_adata = aggregate(points, shapes, "gene", agg_func="sum") - assert result_adata.obs_names.to_list() == ["shape_0", "shape_1"] + assert result_adata.obs_names.to_list() == ["shape_0", "shape_1", "shape_2"] assert result_adata.var_names.to_list() == ["a", "b"] - np.testing.assert_equal(result_adata.X.A, np.array([[2, 0], [1, 3]])) + np.testing.assert_equal(result_adata.X.A, np.array([[0, 0], [2, 0], [1, 3]])) # id_key can be implicit for points result_adata_implicit = aggregate(points, shapes, agg_func="sum") assert_equal(result_adata, result_adata_implicit) +def test_aggregate_points_by_circles(): + sdata = _make_sdata_for_testing_querying_and_aggretation() + # checks also that cound and sum behave the same for categorical variables + adata0 = aggregate( + values=sdata["points"], + by=sdata["by_circles"], + id_key="genes", + agg_func="count", + target_coordinate_system="global", + ) + adata1 = aggregate( + values=sdata["points"], + by=sdata["by_circles"], + id_key="genes", + agg_func="sum", + target_coordinate_system="global", + ) + + assert adata0.var_names.tolist() == ["a", "b"] + assert adata1.var_names.tolist() == ["a", "b"] + X0 = adata0.X.todense() + X1 = adata1.X.todense() + + assert np.all(np.matrix([[3, 3], [0, 0]]) == X0) + assert np.all(np.matrix([[3, 3], [0, 0]]) == X1) + + def test_aggregate_polygons_by_polygons() -> None: cellular = ShapesModel.parse( gpd.GeoDataFrame( @@ -123,6 +154,6 @@ def test_aggregate_image_by_labels(labels_blobs, image_schema, labels_schema) -> def test_aggregate_spatialdata(sdata_blobs: SpatialData) -> None: sdata = sdata_blobs.aggregate(sdata_blobs.points["blobs_points"], by="blobs_polygons") assert isinstance(sdata, SpatialData) - assert len(sdata.shapes["blobs_polygons"]) == 2 - assert sdata.table.shape == (2, 2) + assert len(sdata.shapes["blobs_polygons"]) == 3 + assert sdata.table.shape == (3, 2) assert len(sdata.points["points"].compute()) == 300 diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 423b35ea..285c4584 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -31,6 +31,8 @@ ShapesModel, TableModel, get_model, + points_dask_dataframe_to_geopandas, + points_geopandas_to_dask_dataframe, ) from spatialdata.models._utils import validate_axis_name from spatialdata.models.models import RasterSchema @@ -333,3 +335,31 @@ def test_get_schema(): assert schema == ShapesModel schema = get_model(table) assert schema == TableModel + + +def test_points_and_shapes_conversions(shapes, points): + from spatialdata.transformations import get_transformation + + circles0 = shapes["circles"] + circles1 = points_geopandas_to_dask_dataframe(circles0) + circles2 = points_dask_dataframe_to_geopandas(circles1) + circles0 = circles0[circles2.columns] + assert np.all(circles0.values == circles2.values) + + t0 = get_transformation(circles0, get_all=True) + t1 = get_transformation(circles1, get_all=True) + t2 = get_transformation(circles2, get_all=True) + assert t0 == t1 + assert t0 == t2 + + points0 = points["points_0"] + points1 = points_dask_dataframe_to_geopandas(points0) + points2 = points_geopandas_to_dask_dataframe(points1) + points0 = points0[points2.columns] + assert np.all(points0.values == points2.values) + + t0 = get_transformation(points0, get_all=True) + t1 = get_transformation(points1, get_all=True) + t2 = get_transformation(points2, get_all=True) + assert t0 == t1 + assert t0 == t2 From 1ab3361c7d639822ce5bbdbb01bd3e4c349c309b Mon Sep 17 00:00:00 2001 From: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Date: Mon, 22 May 2023 11:23:39 +0200 Subject: [PATCH 5/5] code review from Giovanni --- src/spatialdata/_core/operations/aggregate.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py index 8fc69379..faef5134 100644 --- a/src/spatialdata/_core/operations/aggregate.py +++ b/src/spatialdata/_core/operations/aggregate.py @@ -267,8 +267,7 @@ def _aggregate_shapes( ).tocsr() return ad.AnnData( X, - # obs=pd.DataFrame(index=obs_id_categorical.categories), - obs=pd.DataFrame(index=pd.Categorical(by.index.tolist()).categories), + obs=pd.DataFrame(index=pd.Categorical(obs_id_categorical_categories).categories), var=pd.DataFrame(index=joined[id_key].cat.categories), dtype=X.dtype, )