Skip to content

Commit

Permalink
All input data can be arraylike
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Sep 11, 2023
1 parent d2971cc commit b25a8ff
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@
from xarray.namedarray.utils import T_DuckArray

T_NamedArray = typing.TypeVar("T_NamedArray", bound="NamedArray")
T_InputData = typing.Union[T_DuckArray, np.typing.ArrayLike]


# TODO: Add tests!
def as_compatible_data(
data: T_DuckArray | np.typing.ArrayLike, fastpath: bool = False
) -> T_DuckArray:
def as_compatible_data(data: T_InputData, fastpath: bool = False) -> T_DuckArray:
if fastpath and getattr(data, "ndim", 0) > 0:
# can't use fastpath (yet) for scalars
return typing.cast(T_DuckArray, data)
Expand All @@ -40,7 +39,8 @@ def as_compatible_data(
if isinstance(data, np.ma.MaskedArray):
raise ValueError
if isinstance(data, ExplicitlyIndexed):
return data
# TODO: better that is_duck_array(ExplicitlyIndexed) -> True
return typing.cast(T_DuckArray, data)

if not isinstance(data, np.ndarray) and (
hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")
Expand All @@ -59,7 +59,7 @@ class NamedArray:
def __init__(
self,
dims: str | Iterable[Hashable],
data: T_DuckArray | np.typing.ArrayLike,
data: T_InputData,
attrs: dict | None = None,
):
self._data: T_DuckArray = as_compatible_data(data)
Expand Down Expand Up @@ -186,7 +186,7 @@ def data(self: T_NamedArray):
return self._data

@data.setter
def data(self: T_NamedArray, data: T_DuckArray) -> None:
def data(self: T_NamedArray, data: T_InputData) -> None:
data = as_compatible_data(data)
self._check_shape(data)
self._data = data
Expand Down Expand Up @@ -306,7 +306,7 @@ def _replace(
def _copy(
self: T_NamedArray,
deep: bool = True,
data: T_DuckArray | None = None,
data: T_InputData | None = None,
memo: dict[int, typing.Any] | None = None,
) -> T_NamedArray:
if data is None:
Expand All @@ -332,7 +332,7 @@ def __deepcopy__(
return self._copy(deep=True, memo=memo)

def copy(
self: T_NamedArray, deep: bool = True, data: T_DuckArray | None = None
self: T_NamedArray, deep: bool = True, data: T_InputData | None = None
) -> T_NamedArray:
"""Returns a copy of this object.
Expand Down

0 comments on commit b25a8ff

Please sign in to comment.