Skip to content

Commit

Permalink
fix NamedArray.imag and NamedArray.real typing info (#8369)
Browse files Browse the repository at this point in the history
Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>
  • Loading branch information
andersy005 and Illviljan authored Oct 25, 2023
1 parent ccc8f99 commit 70c4ee7
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 10 deletions.
22 changes: 22 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2365,6 +2365,28 @@ def notnull(self, keep_attrs: bool | None = None):
keep_attrs=keep_attrs,
)

@property
def imag(self) -> Variable:
"""
The imaginary part of the variable.
See Also
--------
numpy.ndarray.imag
"""
return self._new(data=self.data.imag)

@property
def real(self) -> Variable:
"""
The real part of the variable.
See Also
--------
numpy.ndarray.real
"""
return self._new(data=self.data.real)

def __array_wrap__(self, obj, context=None):
return Variable(self.dims, obj)

Expand Down
8 changes: 8 additions & 0 deletions xarray/namedarray/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,14 @@ def __array_function__(
) -> Any:
...

@property
def imag(self) -> _arrayfunction[_ShapeType_co, Any]:
...

@property
def real(self) -> _arrayfunction[_ShapeType_co, Any]:
...


# Corresponds to np.typing.NDArray:
_ArrayFunction = _arrayfunction[Any, np.dtype[_ScalarType_co]]
Expand Down
25 changes: 21 additions & 4 deletions xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@
from xarray.core import dtypes, formatting, formatting_html
from xarray.namedarray._aggregations import NamedArrayAggregations
from xarray.namedarray._typing import (
_arrayapi,
_arrayfunction_or_api,
_chunkedarray,
_dtype,
_DType_co,
_ScalarType_co,
_ShapeType_co,
_SupportsImag,
_SupportsReal,
)
from xarray.namedarray.utils import _default, is_duck_dask_array, to_0d_object_array

Expand Down Expand Up @@ -513,26 +517,39 @@ def data(self, data: duckarray[Any, _DType_co]) -> None:
self._data = data

@property
def imag(self) -> Self:
def imag(
self: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], # type: ignore[type-var]
) -> NamedArray[_ShapeType, _dtype[_ScalarType]]:
"""
The imaginary part of the array.
See Also
--------
numpy.ndarray.imag
"""
return self._replace(data=self.data.imag) # type: ignore
if isinstance(self._data, _arrayapi):
from xarray.namedarray._array_api import imag

return imag(self)

return self._new(data=self._data.imag)

@property
def real(self) -> Self:
def real(
self: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], # type: ignore[type-var]
) -> NamedArray[_ShapeType, _dtype[_ScalarType]]:
"""
The real part of the array.
See Also
--------
numpy.ndarray.real
"""
return self._replace(data=self.data.real) # type: ignore
if isinstance(self._data, _arrayapi):
from xarray.namedarray._array_api import real

return real(self)
return self._new(data=self._data.real)

def __dask_tokenize__(self) -> Hashable:
# Use v.data, instead of v._data, in order to cope with the wrappers
Expand Down
26 changes: 20 additions & 6 deletions xarray/tests/test_namedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,25 @@ def test_data(random_inputs: np.ndarray[Any, Any]) -> None:


def test_real_and_imag() -> None:
named_array: NamedArray[Any, Any]
named_array = NamedArray(["x"], np.arange(3) - 1j * np.arange(3))
expected_real = np.arange(3)
assert np.array_equal(named_array.real.data, expected_real)
expected_real: np.ndarray[Any, np.dtype[np.float64]]
expected_real = np.arange(3, dtype=np.float64)

expected_imag: np.ndarray[Any, np.dtype[np.float64]]
expected_imag = -np.arange(3, dtype=np.float64)

arr: np.ndarray[Any, np.dtype[np.complex128]]
arr = expected_real + 1j * expected_imag

named_array: NamedArray[Any, np.dtype[np.complex128]]
named_array = NamedArray(["x"], arr)

actual_real: duckarray[Any, np.dtype[np.float64]] = named_array.real.data
assert np.array_equal(actual_real, expected_real)
assert actual_real.dtype == expected_real.dtype

expected_imag = -np.arange(3)
assert np.array_equal(named_array.imag.data, expected_imag)
actual_imag: duckarray[Any, np.dtype[np.float64]] = named_array.imag.data
assert np.array_equal(actual_imag, expected_imag)
assert actual_imag.dtype == expected_imag.dtype


# Additional tests as per your original class-based code
Expand Down Expand Up @@ -347,7 +359,9 @@ def _new(

def test_replace_namedarray() -> None:
dtype_float = np.dtype(np.float32)
np_val: np.ndarray[Any, np.dtype[np.float32]]
np_val = np.array([1.5, 3.2], dtype=dtype_float)
np_val2: np.ndarray[Any, np.dtype[np.float32]]
np_val2 = 2 * np_val

narr_float: NamedArray[Any, np.dtype[np.float32]]
Expand Down

0 comments on commit 70c4ee7

Please sign in to comment.