From b25a8ff83f47a2cc9277925715a5f56ceef091a2 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 11 Sep 2023 21:54:41 +0200 Subject: [PATCH] All input data can be arraylike --- xarray/namedarray/core.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 19a5cfd66d5..91691d8ace3 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -22,12 +22,11 @@ from xarray.namedarray.utils import T_DuckArray T_NamedArray = typing.TypeVar("T_NamedArray", bound="NamedArray") + T_InputData = typing.Union[T_DuckArray, np.typing.ArrayLike] # TODO: Add tests! -def as_compatible_data( - data: T_DuckArray | np.typing.ArrayLike, fastpath: bool = False -) -> T_DuckArray: +def as_compatible_data(data: T_InputData, fastpath: bool = False) -> T_DuckArray: if fastpath and getattr(data, "ndim", 0) > 0: # can't use fastpath (yet) for scalars return typing.cast(T_DuckArray, data) @@ -40,7 +39,8 @@ def as_compatible_data( if isinstance(data, np.ma.MaskedArray): raise ValueError if isinstance(data, ExplicitlyIndexed): - return data + # TODO: better that is_duck_array(ExplicitlyIndexed) -> True + return typing.cast(T_DuckArray, data) if not isinstance(data, np.ndarray) and ( hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__") @@ -59,7 +59,7 @@ class NamedArray: def __init__( self, dims: str | Iterable[Hashable], - data: T_DuckArray | np.typing.ArrayLike, + data: T_InputData, attrs: dict | None = None, ): self._data: T_DuckArray = as_compatible_data(data) @@ -186,7 +186,7 @@ def data(self: T_NamedArray): return self._data @data.setter - def data(self: T_NamedArray, data: T_DuckArray) -> None: + def data(self: T_NamedArray, data: T_InputData) -> None: data = as_compatible_data(data) self._check_shape(data) self._data = data @@ -306,7 +306,7 @@ def _replace( def _copy( self: T_NamedArray, deep: bool = True, - data: T_DuckArray | None = None, + data: T_InputData | None = None, memo: dict[int, typing.Any] | None = None, ) -> T_NamedArray: if data is None: @@ -332,7 +332,7 @@ def __deepcopy__( return self._copy(deep=True, memo=memo) def copy( - self: T_NamedArray, deep: bool = True, data: T_DuckArray | None = None + self: T_NamedArray, deep: bool = True, data: T_InputData | None = None ) -> T_NamedArray: """Returns a copy of this object.