From dfd53f0413609ee6125915d84c762c85d55dfb21 Mon Sep 17 00:00:00 2001 From: Blampey Quentin Date: Mon, 23 Sep 2024 11:36:22 +0200 Subject: [PATCH 01/10] read/write attrs to/from disk --- CHANGELOG.md | 4 ++++ src/spatialdata/_core/spatialdata.py | 19 +++++++++++++++++-- src/spatialdata/_io/io_zarr.py | 1 + 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7cdb72db..ea6cffcc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning][]. ## [0.x.x] - 2024-xx-xx +### Major + +- Added attributes at the SpatialData object level (`.attrs`) + ### Minor - Added `clip: bool = False` parameter to `polygon_query()` #670 diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 1c3affa2..7c63d86f 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -3,7 +3,7 @@ import hashlib import os import warnings -from collections.abc import Generator +from collections.abc import Generator, Mapping from itertools import chain from pathlib import Path from typing import TYPE_CHECKING, Any, Literal @@ -115,6 +115,7 @@ def __init__( points: dict[str, DaskDataFrame] | None = None, shapes: dict[str, GeoDataFrame] | None = None, tables: dict[str, AnnData] | Tables | None = None, + attrs: Mapping[Any, Any] | None = None, ) -> None: self._path: Path | None = None @@ -124,6 +125,7 @@ def __init__( self._points: Points = Points(shared_keys=self._shared_keys) self._shapes: Shapes = Shapes(shared_keys=self._shared_keys) self._tables: Tables = Tables(shared_keys=self._shared_keys) + self._attrs: dict[Any, Any] = dict(attrs) if attrs else {} # Workaround to allow for backward compatibility if isinstance(tables, AnnData): @@ -1152,7 +1154,11 @@ def write( self._validate_can_safely_write_to_path(file_path, overwrite=overwrite) store = parse_url(file_path, mode="w").store - _ = zarr.group(store=store, overwrite=overwrite) + 
zarr_group = zarr.group(store=store, overwrite=overwrite) + try: + zarr_group.attrs.put(self.attrs) + except TypeError as e: + raise TypeError("Invalid attribute in SpatialData.attrs") from e store.close() for element_type, element_name, element in self.gen_elements(): @@ -2188,6 +2194,15 @@ def __delitem__(self, key: str) -> None: element_type, _, _ = self._find_element(key) getattr(self, element_type).__delitem__(key) + @property + def attrs(self) -> dict[Any, Any]: + """Dictionary of global attributes on this SpatialData object.""" + return self._attrs + + @attrs.setter + def attrs(self, value: Mapping[Any, Any]) -> None: + self._attrs = dict(value) + class QueryManager: """Perform queries on SpatialData objects.""" diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index bb14fbc8..bb4d2540 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -145,6 +145,7 @@ def read_zarr(store: Union[str, Path, zarr.Group], selection: Optional[tuple[str points=points, shapes=shapes, tables=tables, + attrs=f.attrs.asdict(), ) sdata.path = Path(store) return sdata From bb49102d6a23f5e4f4bc5dad1a9ceb5bbe8a3043 Mon Sep 17 00:00:00 2001 From: Blampey Quentin Date: Thu, 31 Oct 2024 17:21:05 +0100 Subject: [PATCH 02/10] pass attrs to queries --- src/spatialdata/_core/operations/rasterize.py | 2 +- src/spatialdata/_core/operations/transform.py | 2 +- src/spatialdata/_core/query/spatial_query.py | 4 ++-- src/spatialdata/_core/spatialdata.py | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spatialdata/_core/operations/rasterize.py b/src/spatialdata/_core/operations/rasterize.py index bbe2a400..e7172bf1 100644 --- a/src/spatialdata/_core/operations/rasterize.py +++ b/src/spatialdata/_core/operations/rasterize.py @@ -305,7 +305,7 @@ def rasterize( new_labels[new_name] = rasterized else: raise RuntimeError(f"Unsupported model {model} detected as return type of rasterize().") - return 
SpatialData(images=new_images, labels=new_labels, tables=data.tables) + return SpatialData(images=new_images, labels=new_labels, tables=data.tables, attrs=data.attrs) parsed_data = _parse_element(element=data, sdata=sdata, element_var_name="data", sdata_var_name="sdata") model = get_model(parsed_data) diff --git a/src/spatialdata/_core/operations/transform.py b/src/spatialdata/_core/operations/transform.py index 15872cf1..3282809c 100644 --- a/src/spatialdata/_core/operations/transform.py +++ b/src/spatialdata/_core/operations/transform.py @@ -300,7 +300,7 @@ def _( new_elements[element_type][k] = transform( v, transformation, to_coordinate_system=to_coordinate_system, maintain_positioning=maintain_positioning ) - return SpatialData(**new_elements) + return SpatialData(**new_elements, attrs=data.attrs) @transform.register(DataArray) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index a8d0b475..32d93f1a 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -529,7 +529,7 @@ def _( tables = _get_filtered_or_unfiltered_tables(filter_table, new_elements, sdata) - return SpatialData(**new_elements, tables=tables) + return SpatialData(**new_elements, tables=tables, attrs=sdata.attrs) @bounding_box_query.register(DataArray) @@ -881,7 +881,7 @@ def _( tables = _get_filtered_or_unfiltered_tables(filter_table, new_elements, sdata) - return SpatialData(**new_elements, tables=tables) + return SpatialData(**new_elements, tables=tables, attrs=sdata.attrs) @polygon_query.register(DataArray) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 7c63d86f..ce77c39c 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -687,7 +687,7 @@ def filter_by_coordinate_system( set(), filter_tables, "cs", include_orphan_tables, element_names=element_names_in_coordinate_system ) - return 
SpatialData(**elements, tables=tables) + return SpatialData(**elements, tables=tables, attrs=self.attrs) # TODO: move to relational query with refactor def _filter_tables( @@ -929,7 +929,7 @@ def transform_to_coordinate_system( if element_type not in elements: elements[element_type] = {} elements[element_type][element_name] = transformed - return SpatialData(**elements, tables=sdata.tables) + return SpatialData(**elements, tables=sdata.tables, attrs=self.attrs) def elements_are_self_contained(self) -> dict[str, bool]: """ @@ -2112,7 +2112,7 @@ def subset( include_orphan_tables, elements_dict=elements_dict, ) - return SpatialData(**elements_dict, tables=tables) + return SpatialData(**elements_dict, tables=tables, attrs=self.attrs) def __getitem__(self, item: str) -> SpatialElement: """ From 85a6efeb01cb03f0d7b4657410227f7a810fc573 Mon Sep 17 00:00:00 2001 From: Blampey Quentin Date: Tue, 12 Nov 2024 14:20:38 +0100 Subject: [PATCH 03/10] use same strategy as anndata to concat attrs --- src/spatialdata/_core/concatenate.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index b8548f74..c0398cc5 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -4,11 +4,12 @@ from collections.abc import Iterable from copy import copy # Should probably go up at the top from itertools import chain -from typing import Any +from typing import Any, Callable from warnings import warn import numpy as np from anndata import AnnData +from anndata._core.merge import StrategiesLiteral, resolve_merge_strategy from spatialdata._core._utils import _find_common_table_keys from spatialdata._core.spatialdata import SpatialData @@ -80,6 +81,7 @@ def concatenate( concatenate_tables: bool = False, obs_names_make_unique: bool = True, modify_tables_inplace: bool = False, + attrs_merge: StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None 
= None, **kwargs: Any, ) -> SpatialData: """ @@ -108,6 +110,8 @@ def concatenate( modify_tables_inplace Whether to modify the tables in place. If `True`, the tables will be modified in place. If `False`, the tables will be copied before modification. Copying is enabled by default but can be disabled for performance reasons. + attrs_merge + How the elements of `.attrs` are selected. Uses the same set of strategies as the `uns_merge` argument of [anndata.concat](https://anndata.readthedocs.io/en/latest/generated/anndata.concat.html) kwargs See :func:`anndata.concat` for more details. @@ -188,12 +192,16 @@ def concatenate( else: merged_tables[k] = v + attrs_merge = resolve_merge_strategy(attrs_merge) + attrs = attrs_merge([sdata.attrs for sdata in sdatas]) + sdata = SpatialData( images=merged_images, labels=merged_labels, points=merged_points, shapes=merged_shapes, tables=merged_tables, + attrs=attrs, ) if obs_names_make_unique: for table in sdata.tables.values(): From 427c3fe3c6004b7fded3dc14fb798270d1d10d07 Mon Sep 17 00:00:00 2001 From: Blampey Quentin Date: Tue, 12 Nov 2024 14:23:23 +0100 Subject: [PATCH 04/10] import callable from collections --- src/spatialdata/_core/concatenate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index c0398cc5..6be34850 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -1,10 +1,10 @@ from __future__ import annotations from collections import defaultdict -from collections.abc import Iterable +from collections.abc import Callable, Iterable from copy import copy # Should probably go up at the top from itertools import chain -from typing import Any, Callable +from typing import Any from warnings import warn import numpy as np From 37ebf37af25d5c4177f84d74cc0c22143ee55b0a Mon Sep 17 00:00:00 2001 From: Blampey Quentin Date: Tue, 12 Nov 2024 17:06:47 +0100 Subject: [PATCH 05/10] add 
SpatialData.write_attrs method --- src/spatialdata/_core/spatialdata.py | 31 +++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index ce77c39c..77bd6df9 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1155,10 +1155,7 @@ def write( store = parse_url(file_path, mode="w").store zarr_group = zarr.group(store=store, overwrite=overwrite) - try: - zarr_group.attrs.put(self.attrs) - except TypeError as e: - raise TypeError("Invalid attribute in SpatialData.attrs") from e + self.write_attrs(zarr_group=zarr_group) store.close() for element_type, element_name, element in self.gen_elements(): @@ -1522,7 +1519,28 @@ def _element_type_and_name_from_element_path(self, element_path: str) -> tuple[s element_type, element_name = element_path.split("/") return element_type, element_name - def write_metadata(self, element_name: str | None = None, consolidate_metadata: bool | None = None) -> None: + def write_attrs(self, overwrite: bool = True, zarr_group: zarr.Group | None = None) -> None: + store = None + + if zarr_group is None: + assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs." + store = parse_url(self.path, mode="w").store + zarr_group = zarr.group(store=store, overwrite=overwrite) + + try: + zarr_group.attrs.put(self.attrs) + except TypeError as e: + raise TypeError("Invalid attribute in SpatialData.attrs") from e + + if store is not None: + store.close() + + def write_metadata( + self, + element_name: str | None = None, + consolidate_metadata: bool | None = None, + write_attrs: bool = True, + ) -> None: """ Write the metadata of a single element, or of all elements, to the Zarr store, without rewriting the data. 
@@ -1547,6 +1565,9 @@ def write_metadata(self, element_name: str | None = None, consolidate_metadata: ----- When using the methods `write()` and `write_element()`, the metadata is written automatically. """ + if write_attrs: + self.write_attrs() + from spatialdata._core._elements import Elements if element_name is not None: From ffbac849b7902650fb6e1f42c3d9b412e62bca06 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Wed, 13 Nov 2024 09:09:12 +0100 Subject: [PATCH 06/10] fixed tests --- docs/tutorials/notebooks | 2 +- src/spatialdata/_core/spatialdata.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 45c4d0ed..a08575fa 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 45c4d0edd826dcf472725991ad688f80f1d1dd5a +Subproject commit a08575fae319d8de0734a9b99f5acaa429040c9e diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index c0caf8aa..0b50fb2d 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1564,9 +1564,6 @@ def write_metadata( ----- When using the methods `write()` and `write_element()`, the metadata is written automatically. """ - if write_attrs: - self.write_attrs() - from spatialdata._core._elements import Elements if element_name is not None: @@ -1577,6 +1574,9 @@ def write_metadata( # TODO: write .attrs['spatialdata_attrs'] metadata for DaskDataFrame. # TODO: write omero metadata for the channel name of images. 
+ if write_attrs: + self.write_attrs() + if consolidate_metadata is None and self.has_consolidated_metadata(): consolidate_metadata = True if consolidate_metadata: From a4b91143bff57623e76d713bb80be671faa57525 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Fri, 29 Nov 2024 14:37:35 +0100 Subject: [PATCH 07/10] fix initializers, deprecated from_elements_dict() --- docs/tutorials/notebooks | 2 +- src/spatialdata/_core/spatialdata.py | 72 +++++++++++-------- tests/core/operations/test_rasterize.py | 4 +- .../operations/test_spatialdata_operations.py | 10 ++- 4 files changed, 52 insertions(+), 36 deletions(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index a08575fa..45c4d0ed 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit a08575fae319d8de0734a9b99f5acaa429040c9e +Subproject commit 45c4d0edd826dcf472725991ad688f80f1d1dd5a diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 75fbf82f..ea7b34f5 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -218,7 +218,9 @@ def validate_table_in_spatialdata(self, table: AnnData) -> None: ) @staticmethod - def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> SpatialData: + def from_elements_dict( + elements_dict: dict[str, SpatialElement | AnnData], attrs: Mapping[Any, Any] | None = None + ) -> SpatialData: """ Create a SpatialData object from a dict of elements. @@ -227,38 +229,20 @@ def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> Sp elements_dict Dict of elements. The keys are the names of the elements and the values are the elements. A table can be present in the dict, but only at most one; its name is not used and can be anything. + attrs + Additional attributes to store in the SpatialData object. Returns ------- The SpatialData object. 
""" - d: dict[str, dict[str, SpatialElement] | AnnData | None] = { - "images": {}, - "labels": {}, - "points": {}, - "shapes": {}, - "tables": {}, - } - for k, e in elements_dict.items(): - schema = get_model(e) - if schema in (Image2DModel, Image3DModel): - assert isinstance(d["images"], dict) - d["images"][k] = e - elif schema in (Labels2DModel, Labels3DModel): - assert isinstance(d["labels"], dict) - d["labels"][k] = e - elif schema == PointsModel: - assert isinstance(d["points"], dict) - d["points"][k] = e - elif schema == ShapesModel: - assert isinstance(d["shapes"], dict) - d["shapes"][k] = e - elif schema == TableModel: - assert isinstance(d["tables"], dict) - d["tables"][k] = e - else: - raise ValueError(f"Unknown schema {schema}") - return SpatialData(**d) # type: ignore[arg-type] + warnings.warn( + 'This method is deprecated and will be removed in a future release. Use "SpatialData.init_from_elements(' + ')" instead. For the momment, such methods will be automatically called.', + DeprecationWarning, + stacklevel=2, + ) + return SpatialData.init_from_elements(elements=elements_dict, attrs=attrs) @staticmethod def get_annotated_regions(table: AnnData) -> str | list[str]: @@ -2130,9 +2114,11 @@ def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement | A return found[0] @classmethod - @_deprecation_alias(table="tables", version="0.1.0") def init_from_elements( - cls, elements: dict[str, SpatialElement], tables: AnnData | dict[str, AnnData] | None = None + cls, + elements: dict[str, SpatialElement], + tables: AnnData | dict[str, AnnData] | None = None, + attrs: Mapping[Any, Any] | None = None, ) -> SpatialData: """ Create a SpatialData object from a dict of named elements and an optional table. @@ -2143,6 +2129,8 @@ def init_from_elements( A dict of named elements. tables An optional table or dictionary of tables + attrs + Additional attributes to store in the SpatialData object. 
Returns ------- @@ -2157,11 +2145,33 @@ def init_from_elements( element_type = "labels" elif model == PointsModel: element_type = "points" + elif model == TableModel: + element_type = "tables" else: assert model == ShapesModel element_type = "shapes" elements_dict.setdefault(element_type, {})[name] = element - return cls(**elements_dict, tables=tables) + # when the "tables" argument is removed, we can remove all this if block + if tables is not None: + warnings.warn( + 'The "tables" argument is deprecated and will be removed in a future version. Please ' + "specifies the tables in the `elements` argument. Until the removal occurs, the `elements` " + "variable will be automatically populated with the tables if the `tables` argument is not None.", + DeprecationWarning, + stacklevel=2, + ) + if "tables" in elements_dict: + raise ValueError( + "The tables key is already present in the elements dictionary. Please do not specify " + "the `tables` argument." + ) + elements_dict["tables"] = {} + if isinstance(tables, AnnData): + elements_dict["tables"]["table"] = tables + else: + for name, table in tables.items(): + elements_dict["tables"][name] = table + return cls(**elements_dict) def subset( self, element_names: list[str], filter_tables: bool = True, include_orphan_tables: bool = False diff --git a/tests/core/operations/test_rasterize.py b/tests/core/operations/test_rasterize.py index 9ce5618b..ea8a9d63 100644 --- a/tests/core/operations/test_rasterize.py +++ b/tests/core/operations/test_rasterize.py @@ -205,7 +205,7 @@ def test_rasterize_shapes(): ) adata.obs["cat_values"] = adata.obs["cat_values"].astype("category") adata = TableModel.parse(adata, region=element_name, region_key="region", instance_key="instance_id") - sdata = SpatialData.init_from_elements({element_name: gdf[["geometry"]]}, table=adata) + sdata = SpatialData.init_from_elements({element_name: gdf[["geometry"]], "table": adata}) def _rasterize(element: GeoDataFrame, **kwargs) -> SpatialImage: return 
_rasterize_test_alternative_calls(element=element, sdata=sdata, element_name=element_name, **kwargs) @@ -320,7 +320,7 @@ def test_rasterize_points(): ) adata.obs["gene"] = adata.obs["gene"].astype("category") adata = TableModel.parse(adata, region=element_name, region_key="region", instance_key="instance_id") - sdata = SpatialData.init_from_elements({element_name: ddf[["x", "y"]]}, table=adata) + sdata = SpatialData.init_from_elements({element_name: ddf[["x", "y"]], "table": adata}) def _rasterize(element: DaskDataFrame, **kwargs) -> SpatialImage: return _rasterize_test_alternative_calls(element=element, sdata=sdata, element_name=element_name, **kwargs) diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index a0d7ea2c..bf63fc4b 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -389,9 +389,15 @@ def test_no_shared_transformations() -> None: def test_init_from_elements(full_sdata: SpatialData) -> None: + # this first code block needs to be removed when the tables argument is removed from init_from_elements() all_elements = {name: el for _, name, el in full_sdata._gen_elements()} - sdata = SpatialData.init_from_elements(all_elements, table=full_sdata.table) - for element_type in ["images", "labels", "points", "shapes"]: + sdata = SpatialData.init_from_elements(all_elements, tables=full_sdata["table"]) + for element_type in ["images", "labels", "points", "shapes", "tables"]: + assert set(getattr(sdata, element_type).keys()) == set(getattr(full_sdata, element_type).keys()) + + all_elements = {name: el for _, name, el in full_sdata._gen_elements(include_table=True)} + sdata = SpatialData.init_from_elements(all_elements) + for element_type in ["images", "labels", "points", "shapes", "tables"]: assert set(getattr(sdata, element_type).keys()) == set(getattr(full_sdata, element_type).keys()) From 
f417bd4bbc5e581545fd3898e7ad95fb654b8728 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Sat, 30 Nov 2024 17:17:32 +0100 Subject: [PATCH 08/10] fixes around attrs behavior; tests --- src/spatialdata/_core/_deepcopy.py | 3 ++- src/spatialdata/_core/operations/_utils.py | 2 +- src/spatialdata/_core/spatialdata.py | 19 +++++++++++------ tests/core/test_deepcopy.py | 16 +++++++++++++++ tests/io/test_readwrite.py | 24 ++++++++++++++++++++++ 5 files changed, 56 insertions(+), 8 deletions(-) diff --git a/src/spatialdata/_core/_deepcopy.py b/src/spatialdata/_core/_deepcopy.py index 1ce04409..cb38d46d 100644 --- a/src/spatialdata/_core/_deepcopy.py +++ b/src/spatialdata/_core/_deepcopy.py @@ -47,7 +47,8 @@ def _(sdata: SpatialData) -> SpatialData: elements_dict = {} for _, element_name, element in sdata.gen_elements(): elements_dict[element_name] = deepcopy(element) - return SpatialData.from_elements_dict(elements_dict) + deepcopied_attrs = _deepcopy(sdata.attrs) + return SpatialData.from_elements_dict(elements_dict, attrs=deepcopied_attrs) @deepcopy.register(DataArray) diff --git a/src/spatialdata/_core/operations/_utils.py b/src/spatialdata/_core/operations/_utils.py index e78fccd3..00b65500 100644 --- a/src/spatialdata/_core/operations/_utils.py +++ b/src/spatialdata/_core/operations/_utils.py @@ -134,7 +134,7 @@ def transform_to_data_extent( set_transformation(el, transformation={coordinate_system: Identity()}, set_all=True) for k, v in sdata.tables.items(): sdata_to_return_elements[k] = v.copy() - return SpatialData.from_elements_dict(sdata_to_return_elements) + return SpatialData.from_elements_dict(sdata_to_return_elements, attrs=sdata.attrs) def _parse_element( diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index ea7b34f5..85e8e0d1 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -132,7 +132,7 @@ def __init__( self._points: Points = Points(shared_keys=self._shared_keys) 
self._shapes: Shapes = Shapes(shared_keys=self._shared_keys) self._tables: Tables = Tables(shared_keys=self._shared_keys) - self._attrs: dict[Any, Any] = dict(attrs) if attrs else {} + self.attrs = attrs if attrs else {} # type: ignore[assignment] # Workaround to allow for backward compatibility if isinstance(tables, AnnData): @@ -1570,13 +1570,13 @@ def _element_type_and_name_from_element_path(self, element_path: str) -> tuple[s element_type, element_name = element_path.split("/") return element_type, element_name - def write_attrs(self, overwrite: bool = True, zarr_group: zarr.Group | None = None) -> None: + def write_attrs(self, zarr_group: zarr.Group | None = None) -> None: store = None if zarr_group is None: assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs." - store = parse_url(self.path, mode="w").store - zarr_group = zarr.group(store=store, overwrite=overwrite) + store = parse_url(self.path, mode="r+").store + zarr_group = zarr.group(store=store, overwrite=False) try: zarr_group.attrs.put(self.attrs) @@ -2171,7 +2171,7 @@ def init_from_elements( else: for name, table in tables.items(): elements_dict["tables"][name] = table - return cls(**elements_dict) + return cls(**elements_dict, attrs=attrs) def subset( self, element_names: list[str], filter_tables: bool = True, include_orphan_tables: bool = False @@ -2299,7 +2299,14 @@ def attrs(self) -> dict[Any, Any]: @attrs.setter def attrs(self, value: Mapping[Any, Any]) -> None: - self._attrs = dict(value) + if isinstance(value, dict): + # even if we call dict(value), we still get a shallow copy. For example, dict({'a': {'b': 1}}) will return + # a new dict, {'b': 1} is passed by reference. For this reason, we just pass .attrs by reference, which is + # more performant. The user can always use copy.deepcopy(sdata.attrs), or spatialdata.deepcopy(sdata), to + # get the attrs deepcopied. 
+ self._attrs = value + else: + self._attrs = dict(value) class QueryManager: diff --git a/tests/core/test_deepcopy.py b/tests/core/test_deepcopy.py index 8e1c427a..b21cc925 100644 --- a/tests/core/test_deepcopy.py +++ b/tests/core/test_deepcopy.py @@ -1,5 +1,6 @@ from pandas.testing import assert_frame_equal +from spatialdata import SpatialData from spatialdata._core._deepcopy import deepcopy as _deepcopy from spatialdata.testing import assert_spatial_data_objects_are_identical @@ -45,3 +46,18 @@ def test_deepcopy(full_sdata): assert_spatial_data_objects_are_identical(full_sdata, copied) assert_spatial_data_objects_are_identical(full_sdata, copied_again) + + +def test_deepcopy_attrs(points: SpatialData) -> None: + some_attrs = {"a": {"b": 0}} + points.attrs = some_attrs + + # before deepcopy + sub_points = points.subset(["points_0"]) + assert sub_points.attrs is some_attrs + assert sub_points.attrs["a"] is some_attrs["a"] + + # after deepcopy + sub_points_deepcopy = _deepcopy(sub_points) + assert sub_points_deepcopy.attrs is not some_attrs + assert sub_points_deepcopy.attrs["a"] is not some_attrs["a"] diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index cc54fe04..abfe4eaa 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -581,6 +581,30 @@ def test_incremental_io_valid_name(points: SpatialData) -> None: _check_valid_name(points.delete_element_from_disk) +def test_incremental_io_attrs(points: SpatialData) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + f = os.path.join(tmpdir, "data.zarr") + my_attrs = {"a": "b", "c": 1} + points.attrs = my_attrs + points.write(f) + + # test that the attributes are written to disk + sdata = SpatialData.read(f) + assert sdata.attrs == my_attrs + + # test incremental io attrs (write_attrs()) + sdata.attrs["c"] = 2 + sdata.write_attrs() + sdata2 = SpatialData.read(f) + assert sdata2.attrs["c"] == 2 + + # test incremental io attrs (write_metadata()) + sdata.attrs["c"] = 3 + 
sdata.write_metadata() + sdata2 = SpatialData.read(f) + assert sdata2.attrs["c"] == 3 + + cached_sdata_blobs = blobs() From f6475078340c025c55e1536a2fe835d31536f0d2 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Sun, 1 Dec 2024 13:32:16 +0100 Subject: [PATCH 09/10] added root-level SpatialData versioning --- src/spatialdata/_core/spatialdata.py | 38 +++++++++++++++++++++++++--- src/spatialdata/_io/format.py | 19 ++++++++++++++ src/spatialdata/_io/io_zarr.py | 11 +++++++- 3 files changed, 64 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 85e8e0d1..0f56ca5b 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1570,7 +1570,11 @@ def _element_type_and_name_from_element_path(self, element_path: str) -> tuple[s element_type, element_name = element_path.split("/") return element_type, element_name - def write_attrs(self, zarr_group: zarr.Group | None = None) -> None: + def write_attrs(self, format: SpatialDataFormat | None = None, zarr_group: zarr.Group | None = None) -> None: + from spatialdata._io.format import _parse_formats + + parsed = _parse_formats(formats=format) + store = None if zarr_group is None: @@ -1578,8 +1582,13 @@ def write_attrs(self, zarr_group: zarr.Group | None = None) -> None: store = parse_url(self.path, mode="r+").store zarr_group = zarr.group(store=store, overwrite=False) + version = parsed["SpatialData"].spatialdata_format_version + # we currently do not save any specific root level metadata, so we don't need to call + # parsed['SpatialData'].dict_to_attrs() + attrs_to_write = {"spatialdata_attrs": {"version": version}} | self.attrs + try: - zarr_group.attrs.put(self.attrs) + zarr_group.attrs.put(attrs_to_write) except TypeError as e: raise TypeError("Invalid attribute in SpatialData.attrs") from e @@ -2294,11 +2303,34 @@ def __delitem__(self, key: str) -> None: @property def attrs(self) -> dict[Any, Any]: - """Dictionary 
of global attributes on this SpatialData object.""" + """ + Dictionary of global attributes on this SpatialData object. + + Notes + ----- + Operations on SpatialData objects such as `subset()`, `query()`, ..., will pass the `.attrs` by + reference. If you want to modify the `.attrs` without affecting the original object, you should + either use `copy.deepcopy(sdata.attrs)` or alternatively copy the SpatialData object using + `spatialdata.deepcopy()`. + """ return self._attrs @attrs.setter def attrs(self, value: Mapping[Any, Any]) -> None: + """ + Set the global attributes on this SpatialData object. + + Parameters + ---------- + value + The new attributes to set. + + Notes + ----- + If a dict is passed, the attrs will be passed by reference, else if a mapping is passed, + the mapping will be cast to a dict (shallow copy), i.e. if the mapping contains a dict inside, + that dict will be passed by reference. + """ if isinstance(value, dict): # even if we call dict(value), we still get a shallow copy. For example, dict({'a': {'b': 1}}) will return # a new dict, {'b': 1} is passed by reference. 
For this reason, we just pass .attrs by reference, which is diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index b992c287..ea3120f1 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -48,6 +48,16 @@ class SpatialDataFormat(CurrentFormat): pass +class SpatialDataContainerFormatV01(SpatialDataFormat): + @property + def spatialdata_format_version(self) -> str: + return "0.1" + + # no need for attrs_from_dict as we are not saving specific metadata at the root level + def attrs_to_dict(self, data: dict[str, Any]) -> dict[str, str | dict[str, Any]]: + return {} + + class RasterFormatV01(SpatialDataFormat): """Formatter for raster data.""" @@ -201,6 +211,7 @@ def validate_table( CurrentShapesFormat = ShapesFormatV02 CurrentPointsFormat = PointsFormatV01 CurrentTablesFormat = TablesFormatV01 +CurrentSpatialDataContainerFormats = SpatialDataContainerFormatV01 ShapesFormats = { "0.1": ShapesFormatV01(), @@ -215,6 +226,9 @@ def validate_table( RasterFormats = { "0.1": RasterFormatV01(), } +SpatialDataContainerFormats = { + "0.1": SpatialDataContainerFormatV01(), +} def _parse_formats(formats: SpatialDataFormat | list[SpatialDataFormat] | None) -> dict[str, SpatialDataFormat]: @@ -223,6 +237,7 @@ def _parse_formats(formats: SpatialDataFormat | list[SpatialDataFormat] | None) "shapes": CurrentShapesFormat(), "points": CurrentPointsFormat(), "tables": CurrentTablesFormat(), + "SpatialData": CurrentSpatialDataContainerFormats(), } if formats is None: return parsed @@ -236,6 +251,7 @@ def _parse_formats(formats: SpatialDataFormat | list[SpatialDataFormat] | None) "shapes": False, "points": False, "tables": False, + "SpatialData": False, } def _check_modified(element_type: str) -> None: @@ -256,6 +272,9 @@ def _check_modified(element_type: str) -> None: elif any(isinstance(fmt, type(v)) for v in RasterFormats.values()): _check_modified("raster") parsed["raster"] = fmt + elif any(isinstance(fmt, type(v)) for v in 
SpatialDataContainerFormats.values()): + _check_modified("SpatialData") + parsed["SpatialData"] = fmt else: raise ValueError(f"Unsupported format {fmt}") return parsed diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 2737d24b..8e6806dc 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -138,13 +138,22 @@ def read_zarr(store: str | Path | zarr.Group, selection: None | tuple[str] = Non logger.debug(f"Found {count} elements in {group}") + # read attrs metadata + attrs = f.attrs.asdict() + if "spatialdata_attrs" in attrs: + # no need to call for SpatialDataContainerFormatV01.attrs_to_dict since currently we do not save any root-level + # metadata + attrs.pop("spatialdata_attrs") + else: + attrs = None + sdata = SpatialData( images=images, labels=labels, points=points, shapes=shapes, tables=tables, - attrs=f.attrs.asdict(), + attrs=attrs, ) sdata.path = Path(store) return sdata From 4546beab9f17eb5cbf67c3837807e8b4957a7e9f Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Sun, 1 Dec 2024 14:14:24 +0100 Subject: [PATCH 10/10] writing spatialdata software version in .zattrs root-level metadata --- src/spatialdata/_core/spatialdata.py | 5 ++--- src/spatialdata/_io/format.py | 8 ++++++-- src/spatialdata/_io/io_zarr.py | 4 ++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 0f56ca5b..db9b91ab 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1583,9 +1583,8 @@ def write_attrs(self, format: SpatialDataFormat | None = None, zarr_group: zarr. 
zarr_group = zarr.group(store=store, overwrite=False) version = parsed["SpatialData"].spatialdata_format_version - # we currently do not save any specific root level metadata, so we don't need to call - # parsed['SpatialData'].dict_to_attrs() - attrs_to_write = {"spatialdata_attrs": {"version": version}} | self.attrs + version_specific_attrs = parsed["SpatialData"].attrs_to_dict() + attrs_to_write = {"spatialdata_attrs": {"version": version} | version_specific_attrs} | self.attrs try: zarr_group.attrs.put(attrs_to_write) diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index ea3120f1..abd1700e 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -53,10 +53,14 @@ class SpatialDataContainerFormatV01(SpatialDataFormat): def spatialdata_format_version(self) -> str: return "0.1" - # no need for attrs_from_dict as we are not saving specific metadata at the root level - def attrs_to_dict(self, data: dict[str, Any]) -> dict[str, str | dict[str, Any]]: + def attrs_from_dict(self, metadata: dict[str, Any]) -> dict[str, Any]: return {} + def attrs_to_dict(self) -> dict[str, str | dict[str, Any]]: + from spatialdata import __version__ + + return {"spatialdata_software_version": __version__} + class RasterFormatV01(SpatialDataFormat): """Formatter for raster data.""" diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 8e6806dc..0be7c8f4 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -141,8 +141,8 @@ def read_zarr(store: str | Path | zarr.Group, selection: None | tuple[str] = Non # read attrs metadata attrs = f.attrs.asdict() if "spatialdata_attrs" in attrs: - # no need to call for SpatialDataContainerFormatV01.attrs_to_dict since currently we do not save any root-level - # metadata + # when refactoring the read_zarr function into reading components separately (and according to the version), + # we can move the code below (.pop()) into attrs_from_dict() 
attrs.pop("spatialdata_attrs") else: attrs = None