From 96ac4ec3f8c7be27b2e8ec9cf4cf3f613d17c7c0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 11 Sep 2023 23:42:34 +0200 Subject: [PATCH] Update core.py --- xarray/namedarray/core.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 08d3588cde7..bce0e1042c8 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -22,6 +22,8 @@ from xarray.namedarray.utils import T_DuckArray T_NamedArray = typing.TypeVar("T_NamedArray", bound="NamedArray") + DimsInput = typing.Union[str, Iterable[Hashable]] + Dims = tuple[Hashable, ...] # TODO: Add tests! @@ -43,10 +45,10 @@ def as_compatible_data( # TODO: better that is_duck_array(ExplicitlyIndexed) -> True return typing.cast(T_DuckArray, data) - if not isinstance(data, np.ndarray) and ( - hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__") - ): - return typing.cast(T_DuckArray, data) + # if not isinstance(data, np.ndarray) and ( + # hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__") + # ): + # return typing.cast(T_DuckArray, data) if isinstance(data, tuple): data = to_0d_object_array(data) @@ -59,12 +61,12 @@ class NamedArray: def __init__( self, - dims: str | Iterable[Hashable], + dims: DimsInput, data: T_DuckArray | np.typing.ArrayLike, attrs: dict | None = None, ): self._data: T_DuckArray = as_compatible_data(data) - self._dims: tuple[Hashable, ...] = self._parse_dimensions(dims) + self._dims: Dims = self._parse_dimensions(dims) self._attrs: dict | None = dict(attrs) if attrs else None @property @@ -134,17 +136,15 @@ def nbytes(self: T_NamedArray) -> int: return self.size * self.dtype.itemsize @property - def dims(self: T_NamedArray) -> tuple[Hashable, ...]: + def dims(self: T_NamedArray) -> Dims: """Tuple of dimension names with which this variable is associated.""" return self._dims @dims.setter - def dims(self: T_NamedArray, value: str | Iterable[Hashable]) -> None: + def dims(self: T_NamedArray, value: DimsInput) -> None: self._dims = self._parse_dimensions(value) - def _parse_dimensions( - self: T_NamedArray, dims: str | Iterable[Hashable] - ) -> tuple[Hashable, ...]: + def _parse_dimensions(self: T_NamedArray, dims: DimsInput) -> Dims: dims = (dims,) if isinstance(dims, str) else tuple(dims) if len(dims) != self.ndim: raise ValueError( @@ -399,7 +399,7 @@ def _nonzero(self: T_NamedArray) -> tuple[T_NamedArray, ...]: # TODO we should replace dask's native nonzero # after https://github.com/dask/dask/issues/1076 is implemented. nonzeros = np.nonzero(self.data) - return tuple(type(self)((dim), nz) for nz, dim in zip(nonzeros, self.dims)) + return tuple(type(self)((dim,), nz) for nz, dim in zip(nonzeros, self.dims)) def _as_sparse( self: T_NamedArray,