diff --git a/xarray/conventions.py b/xarray/conventions.py index c3a05e42f82..8bd25776762 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -1,5 +1,6 @@ import warnings from collections import defaultdict +from copy import copy import numpy as np import pandas as pd @@ -9,6 +10,7 @@ from .core import duck_array_ops, indexing from .core.common import contains_cftime_datetimes from .core.pycompat import is_duck_dask_array +from .core.utils import maybe_coerce_to_dict from .core.variable import IndexVariable, Variable, as_variable CF_RELATED_DATA = ( @@ -95,7 +97,7 @@ def __getitem__(self, key): def _var_as_tuple(var): - return var.dims, var.data, var.attrs.copy(), var.encoding.copy() + return var.dims, var.data, copy(var.attrs), var.encoding.copy() def maybe_encode_nonstring_dtype(var, name=None): @@ -562,7 +564,7 @@ def stackable(dim): del var_attrs[attr_name] if decode_coords and "coordinates" in attributes: - attributes = dict(attributes) + attributes = maybe_coerce_to_dict(attributes) coord_names.update(attributes.pop("coordinates").split()) return new_vars, attributes, coord_names @@ -786,7 +788,7 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): # http://mailman.cgd.ucar.edu/pipermail/cf-metadata/2014/007571.html global_coordinates.difference_update(written_coords) if global_coordinates: - attributes = dict(attributes) + attributes = copy(attributes) if "coordinates" in attributes: warnings.warn( f"cannot serialize global coordinates {global_coordinates!r} because the global " diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 89f916db7f4..10ffe67aaa4 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2,6 +2,7 @@ import datetime import warnings +from copy import copy from typing import ( TYPE_CHECKING, Any, @@ -11,6 +12,7 @@ Iterable, List, Mapping, + MutableMapping, Optional, Sequence, Tuple, @@ -575,7 +577,7 @@ def to_dataset( result = self._to_dataset_whole(name) if promote_attrs: - result.attrs = dict(self.attrs) + result.attrs = copy(self.attrs) return result @@ -788,9 +790,9 @@ def loc(self) -> _LocIndexer: """Attribute for location based indexing like pandas.""" return _LocIndexer(self) - @property # Key type needs to be `Any` because of mypy#4167 - def attrs(self) -> Dict[Any, Any]: + @property + def attrs(self) -> MutableMapping[Any, Any]: """Dictionary storing arbitrary metadata with this array.""" return self.variable.attrs diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e882495dce5..43e01d99bfb 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -94,6 +94,7 @@ infix_dims, is_dict_like, is_scalar, + maybe_coerce_to_dict, maybe_wrap_array, ) from .variable import ( @@ -697,7 +698,7 @@ class Dataset(DataWithCoords, DatasetArithmetic, Mapping): description: Weather related data. """ - _attrs: Optional[Dict[Hashable, Any]] + _attrs: Optional[MutableMapping[Any, Any]] _cache: Dict[str, Any] _coord_names: Set[Hashable] _dims: Dict[Hashable, int] @@ -752,7 +753,9 @@ def __init__( data_vars, coords, compat="broadcast_equals" ) - self._attrs = dict(attrs) if attrs is not None else None + self._attrs = None + if attrs is not None: + self.attrs = attrs # type: ignore[assignment] # https://github.com/python/mypy/issues/3004 self._close = None self._encoding = None self._variables = variables @@ -784,7 +787,7 @@ def variables(self) -> Mapping[Hashable, Variable]: return Frozen(self._variables) @property - def attrs(self) -> Dict[Hashable, Any]: + def attrs(self) -> MutableMapping[Any, Any]: """Dictionary of global attributes on this dataset""" if self._attrs is None: self._attrs = {} @@ -792,7 +795,7 @@ def attrs(self) -> Dict[Hashable, Any]: @attrs.setter def attrs(self, value: Mapping[Any, Any]) -> None: - self._attrs = dict(value) + self._attrs = maybe_coerce_to_dict(value) @property def encoding(self) -> Dict: @@ -1096,8 +1099,8 @@ def _replace( variables: Dict[Hashable, Variable] = None, coord_names: Set[Hashable] = None, dims: Dict[Any, int] = None, - attrs: Union[Dict[Hashable, Any], None, Default] = _default, - indexes: Union[Dict[Hashable, Index], None, Default] = _default, + attrs: Union[MutableMapping[Any, Any], None, Default] = _default, + indexes: Union[Dict[Any, Index], None, Default] = _default, encoding: Union[dict, None, Default] = _default, inplace: bool = False, ) -> "Dataset": @@ -1145,7 +1148,7 @@ def _replace_with_new_dims( self, variables: Dict[Hashable, Variable], coord_names: set = None, - attrs: Union[Dict[Hashable, Any], None, Default] = _default, + attrs: Union[MutableMapping[Any, Any], None, Default] = _default, indexes: Union[Dict[Hashable, Index], None, Default] = _default, inplace: bool = False, ) -> "Dataset": @@ -1160,7 +1163,7 @@ def _replace_vars_and_dims( variables: Dict[Hashable, Variable], coord_names: set = None, dims: Dict[Hashable, int] = None, - attrs: Union[Dict[Hashable, Any], None, Default] = _default, + attrs: Union[MutableMapping[Any, Any], None, Default] = _default, inplace: bool = False, ) -> "Dataset": """Deprecated version of _replace_with_new_dims(). @@ -6996,7 +6999,7 @@ def polyfit( covariance = xr.DataArray(Vbase, dims=("cov_i", "cov_j")) * fac variables[name + "polyfit_covariance"] = covariance - return Dataset(data_vars=variables, attrs=self.attrs.copy()) + return Dataset(data_vars=variables, attrs=copy.copy(self.attrs)) def pad( self, @@ -7726,6 +7729,6 @@ def _wrapper(Y, *coords_, **kwargs): result = result.assign_coords( {"param": params, "cov_i": params, "cov_j": params} ) - result.attrs = self.attrs.copy() + result.attrs = copy.copy(self.attrs) return result diff --git a/xarray/core/merge.py b/xarray/core/merge.py index a89e767826d..d2f3724ef3c 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -1,5 +1,6 @@ from __future__ import annotations +from copy import copy from typing import ( TYPE_CHECKING, AbstractSet, @@ -524,9 +525,9 @@ def merge_attrs(variable_attrs, combine_attrs, context=None): elif combine_attrs == "drop": return {} elif combine_attrs == "override": - return dict(variable_attrs[0]) + return copy(variable_attrs[0]) elif combine_attrs == "no_conflicts": - result = dict(variable_attrs[0]) + result = copy(variable_attrs[0]) for attrs in variable_attrs[1:]: try: result = compat_dict_union(result, attrs) @@ -555,7 +556,7 @@ def merge_attrs(variable_attrs, combine_attrs, context=None): dropped_keys |= {key for key in attrs if key not in result} return result elif combine_attrs == "identical": - result = dict(variable_attrs[0]) + result = copy(variable_attrs[0]) for attrs in variable_attrs[1:]: if not dict_equiv(result, attrs): raise MergeError( diff --git a/xarray/core/utils.py b/xarray/core/utils.py index ebf6d7e28ed..1f66a4c6928 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -7,6 +7,7 @@ import re import sys import warnings +from copy import copy from enum import Enum from typing import ( TYPE_CHECKING, @@ -95,6 +96,45 @@ def maybe_coerce_to_str(index, original_coords): return index +# See GH5624, this is a convoluted way to allow type-checking to use `TypeGuard` without +# requiring typing_extensions as a required dependency to _run_ the code (it is required +# to type-check). +try: + if sys.version_info >= (3, 10): + from typing import TypeGuard + else: + from typing_extensions import TypeGuard +except (ImportError, NameError) as e: + if TYPE_CHECKING: + raise e + else: + + def _is_MutableMapping(obj: Mapping[Any, Any]) -> bool: + """Check if the object is a mutable mapping.""" + return hasattr(obj, "__setitem__") + + +else: + + def _is_MutableMapping( + obj: Mapping[Any, Any] + ) -> TypeGuard[MutableMapping[Any, Any]]: + """Check if the object is a mutable mapping.""" + return hasattr(obj, "__setitem__") + + +def maybe_coerce_to_dict(obj: Mapping[Any, Any]) -> MutableMapping[Any, Any]: + """Convert to dict if the object is not a valid dict-like.""" + # if isinstance(obj, MutableMapping): + if _is_MutableMapping(obj): + # if hasattr(obj, "update"): + # return obj.copy() + return copy(obj) + # return obj + else: + return dict(obj) + + def safe_cast_to_index(array: Any) -> pd.Index: """Given an array, safely cast it to a pandas.Index. @@ -417,7 +457,7 @@ def compat_dict_intersection( def compat_dict_union( - first_dict: Mapping[K, V], + first_dict: MutableMapping[K, V], second_dict: Mapping[K, V], compat: Callable[[V, V], bool] = equivalent, ) -> MutableMapping[K, V]: @@ -439,7 +479,7 @@ def compat_dict_union( union : dict union of the contents. """ - new_dict = dict(first_dict) + new_dict = copy(first_dict) update_safety_check(first_dict, second_dict, compat) new_dict.update(second_dict) return new_dict diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 52125ec4113..8de1ac17e9c 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -13,6 +13,7 @@ Hashable, List, Mapping, + MutableMapping, Optional, Sequence, Tuple, @@ -55,6 +56,7 @@ ensure_us_time_resolution, infix_dims, is_duck_array, + maybe_coerce_to_dict, maybe_coerce_to_str, ) @@ -286,9 +288,18 @@ class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic): they can use more complete metadata in context of coordinate labels. """ + _attrs: Optional[MutableMapping[Any, Any]] + __slots__ = ("_dims", "_data", "_attrs", "_encoding") - def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): + def __init__( + self, + dims, + data, + attrs: Optional[Mapping[Any, Any]] = None, + encoding=None, + fastpath=False, + ): """ Parameters ---------- @@ -313,7 +324,7 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): self._attrs = None self._encoding = None if attrs is not None: - self.attrs = attrs + self.attrs = attrs # type: ignore[assignment] # https://github.com/python/mypy/issues/3004 if encoding is not None: self.encoding = encoding @@ -863,7 +874,7 @@ def __setitem__(self, key, value): indexable[index_tuple] = value @property - def attrs(self) -> Dict[Hashable, Any]: + def attrs(self) -> MutableMapping[Any, Any]: """Dictionary of local attributes on this variable.""" if self._attrs is None: self._attrs = {} @@ -871,7 +882,7 @@ def attrs(self) -> Dict[Hashable, Any]: @attrs.setter def attrs(self, value: Mapping[Any, Any]) -> None: - self._attrs = dict(value) + self._attrs = maybe_coerce_to_dict(value) @property def encoding(self): @@ -2602,9 +2613,18 @@ class IndexVariable(Variable): unless another name is given. """ + _attrs: Optional[MutableMapping[Any, Any]] + __slots__ = () - def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): + def __init__( + self, + dims, + data, + attrs: Optional[Mapping[Any, Any]] = None, + encoding=None, + fastpath=False, + ): super().__init__(dims, data, attrs, encoding, fastpath) if self.ndim != 1: raise ValueError(f"{type(self).__name__} objects must be 1-dimensional")