Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix automatic broadcasting when wrapping array api class #8669

Merged
merged 5 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ Deprecations
Bug fixes
~~~~~~~~~

- Fix bug with broadcasting when wrapping array API-compliant classes. (:issue:`8665`, :pull:`8669`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- Ensure :py:meth:`DataArray.unstack` works when wrapping array API-compliant classes. (:issue:`8666`, :pull:`8668`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.

Expand Down
3 changes: 2 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1476,7 +1476,8 @@ def set_dims(self, dims, shape=None):
tmp_shape = tuple(dims_map[d] for d in expanded_dims)
expanded_data = duck_array_ops.broadcast_to(self.data, tmp_shape)
else:
expanded_data = self.data[(None,) * (len(expanded_dims) - self.ndim)]
indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,)
expanded_data = self.data[indexer]

expanded_var = Variable(
expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True
Expand Down
16 changes: 16 additions & 0 deletions xarray/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,22 @@ def test_broadcast(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
assert_equal(a, e)


def test_broadcast_during_arithmetic(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
np_arr, xp_arr = arrays
np_arr2 = xr.DataArray(np.array([1.0, 2.0]), dims="x")
xp_arr2 = xr.DataArray(xp.asarray([1.0, 2.0]), dims="x")

expected = np_arr * np_arr2
actual = xp_arr * xp_arr2
assert isinstance(actual.data, Array)
assert_equal(actual, expected)

expected = np_arr2 * np_arr
actual = xp_arr2 * xp_arr
assert isinstance(actual.data, Array)
assert_equal(actual, expected)


def test_concat(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
np_arr, xp_arr = arrays
expected = xr.concat((np_arr, np_arr), dim="x")
Expand Down
Loading