Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding attrs at the SpatialData object level #711

Merged
merged 13 commits into from
Dec 1, 2024
Merged
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ and this project adheres to [Semantic Versioning][].

## [0.2.3] - 2024-09-25

### Major

- Added attributes at the SpatialData object level (`.attrs`)

### Minor

- Added `clip: bool = False` parameter to `polygon_query()` #670
Expand Down
3 changes: 2 additions & 1 deletion src/spatialdata/_core/_deepcopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def _(sdata: SpatialData) -> SpatialData:
elements_dict = {}
for _, element_name, element in sdata.gen_elements():
elements_dict[element_name] = deepcopy(element)
return SpatialData.from_elements_dict(elements_dict)
deepcopied_attrs = _deepcopy(sdata.attrs)
return SpatialData.from_elements_dict(elements_dict, attrs=deepcopied_attrs)


@deepcopy.register(DataArray)
Expand Down
10 changes: 9 additions & 1 deletion src/spatialdata/_core/concatenate.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from copy import copy # Should probably go up at the top
from itertools import chain
from typing import Any
from warnings import warn

import numpy as np
from anndata import AnnData
from anndata._core.merge import StrategiesLiteral, resolve_merge_strategy

from spatialdata._core._utils import _find_common_table_keys
from spatialdata._core.spatialdata import SpatialData
Expand Down Expand Up @@ -80,6 +81,7 @@ def concatenate(
concatenate_tables: bool = False,
obs_names_make_unique: bool = True,
modify_tables_inplace: bool = False,
attrs_merge: StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None = None,
**kwargs: Any,
) -> SpatialData:
"""
Expand Down Expand Up @@ -108,6 +110,8 @@ def concatenate(
modify_tables_inplace
Whether to modify the tables in place. If `True`, the tables will be modified in place. If `False`, the tables
will be copied before modification. Copying is enabled by default but can be disabled for performance reasons.
attrs_merge
How the elements of `.attrs` are selected. Uses the same set of strategies as the `uns_merge` argument of [anndata.concat](https://anndata.readthedocs.io/en/latest/generated/anndata.concat.html)
kwargs
See :func:`anndata.concat` for more details.

Expand Down Expand Up @@ -188,12 +192,16 @@ def concatenate(
else:
merged_tables[k] = v

attrs_merge = resolve_merge_strategy(attrs_merge)
attrs = attrs_merge([sdata.attrs for sdata in sdatas])

sdata = SpatialData(
images=merged_images,
labels=merged_labels,
points=merged_points,
shapes=merged_shapes,
tables=merged_tables,
attrs=attrs,
)
if obs_names_make_unique:
for table in sdata.tables.values():
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata/_core/operations/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def transform_to_data_extent(
set_transformation(el, transformation={coordinate_system: Identity()}, set_all=True)
for k, v in sdata.tables.items():
sdata_to_return_elements[k] = v.copy()
return SpatialData.from_elements_dict(sdata_to_return_elements)
return SpatialData.from_elements_dict(sdata_to_return_elements, attrs=sdata.attrs)


def _parse_element(
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata/_core/operations/rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def rasterize(
new_labels[new_name] = rasterized
else:
raise RuntimeError(f"Unsupported model {model} detected as return type of rasterize().")
return SpatialData(images=new_images, labels=new_labels, tables=data.tables)
return SpatialData(images=new_images, labels=new_labels, tables=data.tables, attrs=data.attrs)

parsed_data = _parse_element(element=data, sdata=sdata, element_var_name="data", sdata_var_name="sdata")
model = get_model(parsed_data)
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata/_core/operations/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def _(
new_elements[element_type][k] = transform(
v, transformation, to_coordinate_system=to_coordinate_system, maintain_positioning=maintain_positioning
)
return SpatialData(**new_elements)
return SpatialData(**new_elements, attrs=data.attrs)


@transform.register(DataArray)
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def _(

tables = _get_filtered_or_unfiltered_tables(filter_table, new_elements, sdata)

return SpatialData(**new_elements, tables=tables)
return SpatialData(**new_elements, tables=tables, attrs=sdata.attrs)


@bounding_box_query.register(DataArray)
Expand Down Expand Up @@ -885,7 +885,7 @@ def _(

tables = _get_filtered_or_unfiltered_tables(filter_table, new_elements, sdata)

return SpatialData(**new_elements, tables=tables)
return SpatialData(**new_elements, tables=tables, attrs=sdata.attrs)


@polygon_query.register(DataArray)
Expand Down
158 changes: 121 additions & 37 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import hashlib
import os
import warnings
from collections.abc import Generator
from collections.abc import Generator, Mapping
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
Expand Down Expand Up @@ -122,6 +122,7 @@ def __init__(
points: dict[str, DaskDataFrame] | None = None,
shapes: dict[str, GeoDataFrame] | None = None,
tables: dict[str, AnnData] | Tables | None = None,
attrs: Mapping[Any, Any] | None = None,
) -> None:
self._path: Path | None = None

Expand All @@ -131,6 +132,7 @@ def __init__(
self._points: Points = Points(shared_keys=self._shared_keys)
self._shapes: Shapes = Shapes(shared_keys=self._shared_keys)
self._tables: Tables = Tables(shared_keys=self._shared_keys)
self.attrs = attrs if attrs else {} # type: ignore[assignment]

# Workaround to allow for backward compatibility
if isinstance(tables, AnnData):
Expand Down Expand Up @@ -216,7 +218,9 @@ def validate_table_in_spatialdata(self, table: AnnData) -> None:
)

@staticmethod
def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> SpatialData:
def from_elements_dict(
elements_dict: dict[str, SpatialElement | AnnData], attrs: Mapping[Any, Any] | None = None
) -> SpatialData:
"""
Create a SpatialData object from a dict of elements.

Expand All @@ -225,38 +229,20 @@ def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> Sp
elements_dict
Dict of elements. The keys are the names of the elements and the values are the elements.
A table can be present in the dict, but only at most one; its name is not used and can be anything.
attrs
Additional attributes to store in the SpatialData object.

Returns
-------
The SpatialData object.
"""
d: dict[str, dict[str, SpatialElement] | AnnData | None] = {
"images": {},
"labels": {},
"points": {},
"shapes": {},
"tables": {},
}
for k, e in elements_dict.items():
schema = get_model(e)
if schema in (Image2DModel, Image3DModel):
assert isinstance(d["images"], dict)
d["images"][k] = e
elif schema in (Labels2DModel, Labels3DModel):
assert isinstance(d["labels"], dict)
d["labels"][k] = e
elif schema == PointsModel:
assert isinstance(d["points"], dict)
d["points"][k] = e
elif schema == ShapesModel:
assert isinstance(d["shapes"], dict)
d["shapes"][k] = e
elif schema == TableModel:
assert isinstance(d["tables"], dict)
d["tables"][k] = e
else:
raise ValueError(f"Unknown schema {schema}")
return SpatialData(**d) # type: ignore[arg-type]
warnings.warn(
'This method is deprecated and will be removed in a future release. Use "SpatialData.init_from_elements('
')" instead. For the momment, such methods will be automatically called.',
DeprecationWarning,
stacklevel=2,
)
return SpatialData.init_from_elements(elements=elements_dict, attrs=attrs)

@staticmethod
def get_annotated_regions(table: AnnData) -> str | list[str]:
Expand Down Expand Up @@ -712,7 +698,7 @@ def filter_by_coordinate_system(
set(), filter_tables, "cs", include_orphan_tables, element_names=element_names_in_coordinate_system
)

return SpatialData(**elements, tables=tables)
return SpatialData(**elements, tables=tables, attrs=self.attrs)

# TODO: move to relational query with refactor
def _filter_tables(
Expand Down Expand Up @@ -954,7 +940,7 @@ def transform_to_coordinate_system(
if element_type not in elements:
elements[element_type] = {}
elements[element_type][element_name] = transformed
return SpatialData(**elements, tables=sdata.tables)
return SpatialData(**elements, tables=sdata.tables, attrs=self.attrs)

def elements_are_self_contained(self) -> dict[str, bool]:
"""
Expand Down Expand Up @@ -1179,7 +1165,8 @@ def write(
self._validate_can_safely_write_to_path(file_path, overwrite=overwrite)

store = parse_url(file_path, mode="w").store
_ = zarr.group(store=store, overwrite=overwrite)
zarr_group = zarr.group(store=store, overwrite=overwrite)
self.write_attrs(zarr_group=zarr_group)
store.close()

for element_type, element_name, element in self.gen_elements():
Expand Down Expand Up @@ -1583,7 +1570,36 @@ def _element_type_and_name_from_element_path(self, element_path: str) -> tuple[s
element_type, element_name = element_path.split("/")
return element_type, element_name

def write_metadata(self, element_name: str | None = None, consolidate_metadata: bool | None = None) -> None:
def write_attrs(self, format: SpatialDataFormat | None = None, zarr_group: zarr.Group | None = None) -> None:
from spatialdata._io.format import _parse_formats

parsed = _parse_formats(formats=format)

store = None

if zarr_group is None:
assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs."
store = parse_url(self.path, mode="r+").store
zarr_group = zarr.group(store=store, overwrite=False)

version = parsed["SpatialData"].spatialdata_format_version
version_specific_attrs = parsed["SpatialData"].attrs_to_dict()
attrs_to_write = {"spatialdata_attrs": {"version": version} | version_specific_attrs} | self.attrs

try:
zarr_group.attrs.put(attrs_to_write)
except TypeError as e:
raise TypeError("Invalid attribute in SpatialData.attrs") from e

if store is not None:
store.close()

def write_metadata(
self,
element_name: str | None = None,
consolidate_metadata: bool | None = None,
write_attrs: bool = True,
) -> None:
"""
Write the metadata of a single element, or of all elements, to the Zarr store, without rewriting the data.

Expand Down Expand Up @@ -1618,6 +1634,9 @@ def write_metadata(self, element_name: str | None = None, consolidate_metadata:
# TODO: write .uns['spatialdata_attrs'] metadata for AnnData.
# TODO: write .attrs['spatialdata_attrs'] metadata for DaskDataFrame.

if write_attrs:
self.write_attrs()

if consolidate_metadata is None and self.has_consolidated_metadata():
consolidate_metadata = True
if consolidate_metadata:
Expand Down Expand Up @@ -2103,9 +2122,11 @@ def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement | A
return found[0]

@classmethod
@_deprecation_alias(table="tables", version="0.1.0")
def init_from_elements(
cls, elements: dict[str, SpatialElement], tables: AnnData | dict[str, AnnData] | None = None
cls,
elements: dict[str, SpatialElement],
tables: AnnData | dict[str, AnnData] | None = None,
attrs: Mapping[Any, Any] | None = None,
) -> SpatialData:
"""
Create a SpatialData object from a dict of named elements and an optional table.
Expand All @@ -2116,6 +2137,8 @@ def init_from_elements(
A dict of named elements.
tables
An optional table or dictionary of tables
attrs
Additional attributes to store in the SpatialData object.

Returns
-------
Expand All @@ -2130,11 +2153,33 @@ def init_from_elements(
element_type = "labels"
elif model == PointsModel:
element_type = "points"
elif model == TableModel:
element_type = "tables"
else:
assert model == ShapesModel
element_type = "shapes"
elements_dict.setdefault(element_type, {})[name] = element
return cls(**elements_dict, tables=tables)
# when the "tables" argument is removed, we can remove all this if block
if tables is not None:
warnings.warn(
'The "tables" argument is deprecated and will be removed in a future version. Please '
"specifies the tables in the `elements` argument. Until the removal occurs, the `elements` "
"variable will be automatically populated with the tables if the `tables` argument is not None.",
DeprecationWarning,
stacklevel=2,
)
if "tables" in elements_dict:
raise ValueError(
"The tables key is already present in the elements dictionary. Please do not specify "
"the `tables` argument."
)
elements_dict["tables"] = {}
if isinstance(tables, AnnData):
elements_dict["tables"]["table"] = tables
else:
for name, table in tables.items():
elements_dict["tables"][name] = table
return cls(**elements_dict, attrs=attrs)

def subset(
self, element_names: list[str], filter_tables: bool = True, include_orphan_tables: bool = False
Expand Down Expand Up @@ -2173,7 +2218,7 @@ def subset(
include_orphan_tables,
elements_dict=elements_dict,
)
return SpatialData(**elements_dict, tables=tables)
return SpatialData(**elements_dict, tables=tables, attrs=self.attrs)

def __getitem__(self, item: str) -> SpatialElement:
"""
Expand Down Expand Up @@ -2255,6 +2300,45 @@ def __delitem__(self, key: str) -> None:
element_type, _, _ = self._find_element(key)
getattr(self, element_type).__delitem__(key)

@property
def attrs(self) -> dict[Any, Any]:
"""
Dictionary of global attributes on this SpatialData object.

Notes
-----
Operations on SpatialData objects such as `subset()`, `query()`, ..., will pass the `.attrs` by
reference. If you want to modify the `.attrs` without affecting the original object, you should
either use `copy.deepcopy(sdata.attrs)` or eventually copy the SpatialData object using
`spatialdata.deepcopy()`.
"""
return self._attrs

@attrs.setter
def attrs(self, value: Mapping[Any, Any]) -> None:
"""
Set the global attributes on this SpatialData object.

Parameters
----------
value
The new attributes to set.

Notes
-----
If a dict is passed, the attrs will be passed by reference, else if a mapping is passed,
the mapping will be casted to a dict (shallow copy), i.e. if the mapping contains a dict inside,
that dict will be passed by reference.
"""
if isinstance(value, dict):
# even if we call dict(value), we still get a shallow copy. For example, dict({'a': {'b': 1}}) will return
# a new dict, {'b': 1} is passed by reference. For this reason, we just pass .attrs by reference, which is
# more performant. The user can always use copy.deepcopy(sdata.attrs), or spatialdata.deepcopy(sdata), to
# get the attrs deepcopied.
self._attrs = value
else:
self._attrs = dict(value)


class QueryManager:
"""Perform queries on SpatialData objects."""
Expand Down
Loading
Loading