Skip to content

Commit

Permalink
fix tests and docs and remove
Browse files Browse the repository at this point in the history
  • Loading branch information
giovp committed Nov 27, 2023
1 parent fa539f5 commit 63e761d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/spatialdata/dataloader/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,9 @@ def _get_tile_coords(

# get extent, first by checking shape defaults, then by using the `tile_dim_in_units`
if tile_dim_in_units is None:
if elem.iloc[0][0].geom_type == "Point":
if elem.iloc[0, 0].geom_type == "Point":
extent = elem[ShapesModel.RADIUS_KEY].values * tile_scale
elif elem.iloc[0][0].geom_type in ["Polygon", "MultiPolygon"]:
elif elem.iloc[0, 0].geom_type in ["Polygon", "MultiPolygon"]:
extent = elem[ShapesModel.GEOMETRY_KEY].length * tile_scale
else:
raise ValueError("Only point and polygon shapes are supported.")
Expand Down
8 changes: 4 additions & 4 deletions src/spatialdata/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from multiscale_spatial_image import to_multiscale
from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
from multiscale_spatial_image.to_multiscale.to_multiscale import Methods
from pandas.api.types import is_categorical_dtype
from pandas import CategoricalDtype
from shapely._geometry import GeometryType
from shapely.geometry import MultiPolygon, Point, Polygon
from shapely.geometry.collection import GeometryCollection
Expand Down Expand Up @@ -470,7 +470,7 @@ def validate(cls, data: DaskDataFrame) -> None:
raise ValueError(f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`.")
if cls.ATTRS_KEY in data.attrs and "feature_key" in data.attrs[cls.ATTRS_KEY]:
feature_key = data.attrs[cls.ATTRS_KEY][cls.FEATURE_KEY]
if not is_categorical_dtype(data[feature_key]):
if not isinstance(data[feature_key], CategoricalDtype):
logger.info(f"Feature key `{feature_key}`could be of type `pd.Categorical`. Consider casting it.")

@singledispatchmethod
Expand Down Expand Up @@ -624,7 +624,7 @@ def _add_metadata_and_validate(
# Here we are explicitly importing the categories
# but it is a convenient way to ensure that the categories are known.
# It also just changes the state of the series, so it is not a big deal.
if is_categorical_dtype(data[c]) and not data[c].cat.known:
if isinstance(data[c], CategoricalDtype) and not data[c].cat.known:
try:
data[c] = data[c].cat.set_categories(data[c].head(1).cat.categories)
except ValueError:
Expand Down Expand Up @@ -729,7 +729,7 @@ def parse(
region_: list[str] = region if isinstance(region, list) else [region]
if not adata.obs[region_key].isin(region_).all():
raise ValueError(f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values.")
if not is_categorical_dtype(adata.obs[region_key]):
if not isinstance(adata.obs[region_key], CategoricalDtype):
warnings.warn(
f"Converting `{cls.REGION_KEY_KEY}: {region_key}` to categorical dtype.", UserWarning, stacklevel=2
)
Expand Down
15 changes: 11 additions & 4 deletions tests/dataloader/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,18 @@ def test_default(self, sdata_blobs, regions_element, raster):
if raster:
assert tile.shape == (3, 329, 329)
else:
assert tile.shape == (3, 164, 164)
assert tile.shape == (3, 165, 164)
else:
raise ValueError(f"Unexpected regions_element: {regions_element}")

# extent has units in pixel so should be the same as tile shape
if raster:
assert round(ds.tiles_coords.extent.unique()[0] * 2) == tile.shape[1]
else:
assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1]
if regions_element != "blobs_multipolygons":
assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1]
else:
assert int(ds.tiles_coords.extent.unique()[0]) + 1 == tile.shape[1]
assert np.all(sdata_tile.table.obs.columns == ds.sdata.table.obs.columns)
assert list(sdata_tile.images.keys())[0] == "blobs_image"

Expand All @@ -88,11 +92,14 @@ def test_return_annot(self, sdata_blobs, regions_element, return_annot):
elif regions_element == "blobs_polygons":
assert tile.shape == (3, 82, 82)
elif regions_element == "blobs_multipolygons":
assert tile.shape == (3, 164, 164)
assert tile.shape == (3, 165, 164)
else:
raise ValueError(f"Unexpected regions_element: {regions_element}")
# extent has units in pixel so should be the same as tile shape
assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1]
if regions_element != "blobs_multipolygons":
assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1]
else:
assert round(ds.tiles_coords.extent.unique()[0]) + 1 == tile.shape[1]
return_annot = [return_annot] if isinstance(return_annot, str) else return_annot
assert annot.shape[1] == len(return_annot)

Expand Down

0 comments on commit 63e761d

Please sign in to comment.