Skip to content

Commit

Permalink
Update core.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Sep 11, 2023
1 parent 06d77ad commit 96ac4ec
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from xarray.namedarray.utils import T_DuckArray

T_NamedArray = typing.TypeVar("T_NamedArray", bound="NamedArray")
DimsInput = typing.Union[str, Iterable[Hashable]]
Dims = tuple[Hashable, ...]


# TODO: Add tests!
Expand All @@ -43,10 +45,10 @@ def as_compatible_data(
# TODO: better that is_duck_array(ExplicitlyIndexed) -> True
return typing.cast(T_DuckArray, data)

if not isinstance(data, np.ndarray) and (
hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")
):
return typing.cast(T_DuckArray, data)
# if not isinstance(data, np.ndarray) and (
# hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")
# ):
# return typing.cast(T_DuckArray, data)
if isinstance(data, tuple):
data = to_0d_object_array(data)

Expand All @@ -59,12 +61,12 @@ class NamedArray:

def __init__(
self,
dims: str | Iterable[Hashable],
dims: DimsInput,
data: T_DuckArray | np.typing.ArrayLike,
attrs: dict | None = None,
):
self._data: T_DuckArray = as_compatible_data(data)
self._dims: tuple[Hashable, ...] = self._parse_dimensions(dims)
self._dims: Dims = self._parse_dimensions(dims)
self._attrs: dict | None = dict(attrs) if attrs else None

@property
Expand Down Expand Up @@ -134,17 +136,15 @@ def nbytes(self: T_NamedArray) -> int:
return self.size * self.dtype.itemsize

@property
def dims(self: T_NamedArray) -> tuple[Hashable, ...]:
def dims(self: T_NamedArray) -> Dims:
"""Tuple of dimension names with which this variable is associated."""
return self._dims

@dims.setter
def dims(self: T_NamedArray, value: str | Iterable[Hashable]) -> None:
def dims(self: T_NamedArray, value: DimsInput) -> None:
self._dims = self._parse_dimensions(value)

def _parse_dimensions(
self: T_NamedArray, dims: str | Iterable[Hashable]
) -> tuple[Hashable, ...]:
def _parse_dimensions(self: T_NamedArray, dims: DimsInput) -> Dims:
dims = (dims,) if isinstance(dims, str) else tuple(dims)
if len(dims) != self.ndim:
raise ValueError(
Expand Down Expand Up @@ -399,7 +399,7 @@ def _nonzero(self: T_NamedArray) -> tuple[T_NamedArray, ...]:
# TODO we should replace dask's native nonzero
# after https://github.com/dask/dask/issues/1076 is implemented.
nonzeros = np.nonzero(self.data)
return tuple(type(self)((dim), nz) for nz, dim in zip(nonzeros, self.dims))
return tuple(type(self)((dim,), nz) for nz, dim in zip(nonzeros, self.dims))

def _as_sparse(
self: T_NamedArray,
Expand Down

0 comments on commit 96ac4ec

Please sign in to comment.