
Error unstacking array API compliant class #8666

Closed
@TomNicholas

Description


What happened?

Unstacking fails for array types that strictly follow the array API standard.

What did you expect to happen?

Unstacking should succeed for an array API-compliant array, just as it does with a normal numpy array.

Minimal Complete Verifiable Example

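import numpy as np
import numpy.array_api as nxp  # experimental in numpy 1.x (warns on import); removed in NumPy 2.0
import xarray as xr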

arr = nxp.asarray([[1, 2, 3], [4, 5, 6]], dtype=np.dtype('float32'))

da = xr.DataArray(
    arr,
    coords=[("x", ["a", "b"]), ("y", [0, 1, 2])],
)
da
stacked = da.stack(z=("x", "y"))
stacked.indexes["z"]
stacked.unstack()

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[65], line 8
      6 stacked = da.stack(z=("x", "y"))
      7 stacked.indexes["z"]
----> 8 roundtripped = stacked.unstack()
      9 arr.identical(roundtripped)

File ~/Documents/Work/Code/xarray/xarray/util/deprecation_helpers.py:115, in _deprecate_positional_args.<locals>._decorator.<locals>.inner(*args, **kwargs)
    111     kwargs.update({name: arg for name, arg in zip_args})
    113     return func(*args[:-n_extra_args], **kwargs)
--> 115 return func(*args, **kwargs)

File ~/Documents/Work/Code/xarray/xarray/core/dataarray.py:2913, in DataArray.unstack(self, dim, fill_value, sparse)
   2851 @_deprecate_positional_args("v2023.10.0")
   2852 def unstack(
   2853     self,
   (...)
   2857     sparse: bool = False,
   2858 ) -> Self:
   2859     """
   2860     Unstack existing dimensions corresponding to MultiIndexes into
   2861     multiple new dimensions.
   (...)
   2911     DataArray.stack
   2912     """
-> 2913     ds = self._to_temp_dataset().unstack(dim, fill_value=fill_value, sparse=sparse)
   2914     return self._from_temp_dataset(ds)

File ~/Documents/Work/Code/xarray/xarray/util/deprecation_helpers.py:115, in _deprecate_positional_args.<locals>._decorator.<locals>.inner(*args, **kwargs)
    111     kwargs.update({name: arg for name, arg in zip_args})
    113     return func(*args[:-n_extra_args], **kwargs)
--> 115 return func(*args, **kwargs)

File ~/Documents/Work/Code/xarray/xarray/core/dataset.py:5581, in Dataset.unstack(self, dim, fill_value, sparse)
   5579 for d in dims:
   5580     if needs_full_reindex:
-> 5581         result = result._unstack_full_reindex(
   5582             d, stacked_indexes[d], fill_value, sparse
   5583         )
   5584     else:
   5585         result = result._unstack_once(d, stacked_indexes[d], fill_value, sparse)

File ~/Documents/Work/Code/xarray/xarray/core/dataset.py:5474, in Dataset._unstack_full_reindex(self, dim, index_and_vars, fill_value, sparse)
   5472 if name not in index_vars:
   5473     if dim in var.dims:
-> 5474         variables[name] = var.unstack({dim: new_dim_sizes})
   5475     else:
   5476         variables[name] = var

File ~/Documents/Work/Code/xarray/xarray/core/variable.py:1684, in Variable.unstack(self, dimensions, **dimensions_kwargs)
   1682 result = self
   1683 for old_dim, dims in dimensions.items():
-> 1684     result = result._unstack_once_full(dims, old_dim)
   1685 return result

File ~/Documents/Work/Code/xarray/xarray/core/variable.py:1574, in Variable._unstack_once_full(self, dim, old_dim)
   1571 reordered = self.transpose(*dim_order)
   1573 new_shape = reordered.shape[: len(other_dims)] + new_dim_sizes
-> 1574 new_data = reordered.data.reshape(new_shape)
   1575 new_dims = reordered.dims[: len(other_dims)] + new_dim_names
   1577 return type(self)(
   1578     new_dims, new_data, self._attrs, self._encoding, fastpath=True
   1579 )

AttributeError: 'Array' object has no attribute 'reshape'

MVCE confirmation

  • Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • Complete example — the example is self-contained, including all data and the text of any traceback.
  • Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • New issue — a search of GitHub Issues suggests this is not a duplicate.
  • Recent environment — the issue occurs with the latest version of xarray and its dependencies.

Relevant log output

No response

Anything else we need to know?

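It fails on the reordered.data.reshape(new_shape) call in Variable._unstack_once_full, because the array API standard defines reshape as a function in the array's namespace, not as a method on the array object.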
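To make the distinction concrete, here is a small pure-Python sketch. StrictArray, _reshape and strict_namespace are hypothetical toys, not numpy.array_api: like a strictly compliant array, the toy has no .reshape method and exposes reshape only as a function on the namespace returned by __array_namespace__. It ignores the standard's copy keyword and only supports row-major reshapes.

```python
import math
from types import SimpleNamespace


class StrictArray:
    """Hypothetical stand-in for a strictly array-API-compliant array.

    Data is stored as a flat row-major list plus a shape. There is
    deliberately no .reshape method; reshaping is only reachable through
    the namespace returned by __array_namespace__.
    """

    def __init__(self, flat, shape):
        if math.prod(shape) != len(flat):
            raise ValueError("shape does not match number of elements")
        self._flat = list(flat)
        self.shape = tuple(shape)

    def __array_namespace__(self, api_version=None):
        return strict_namespace


def _reshape(x, /, shape):
    # A row-major reshape keeps the element order and only changes the shape.
    return StrictArray(x._flat, shape)


# Hypothetical namespace object exposing reshape as a free function.
strict_namespace = SimpleNamespace(reshape=_reshape)

arr = StrictArray([1, 2, 3, 4, 5, 6], (2, 3))

try:
    arr.reshape((6,))  # method-style call, as in Variable._unstack_once_full
except AttributeError as err:
    print(err)  # 'StrictArray' object has no attribute 'reshape'

xp = arr.__array_namespace__()
flat = xp.reshape(arr, (6,))  # function-style call, as the standard specifies
print(flat.shape)  # (6,)
```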

We do in fact have an array API-compatible version of reshape defined in duck_array_ops.py; it just isn't used yet everywhere we call reshape.

def reshape(array, shape):
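A helper of this kind can look up the array's namespace and call reshape as a function. The sketch below is a hypothetical, dependency-free illustration, not a copy of the code in duck_array_ops.py: get_namespace, StrictArray and MethodArray are made-up names, and where a NumPy-backed helper would presumably fall back to NumPy, this one falls back to the array's own .reshape method.

```python
import math
from types import SimpleNamespace


def get_namespace(array):
    """Return the array API namespace of `array`, or None if it has none."""
    if hasattr(array, "__array_namespace__"):
        return array.__array_namespace__()
    return None


def reshape(array, shape):
    """Reshape via the namespace function when available, else via the method."""
    xp = get_namespace(array)
    if xp is not None:
        return xp.reshape(array, shape)
    # Fallback for duck arrays that only offer a NumPy-style .reshape method.
    return array.reshape(shape)


class StrictArray:
    """Toy array-API-style array: reshape exists only as a namespace function."""

    def __init__(self, flat, shape):
        if math.prod(shape) != len(flat):
            raise ValueError("shape does not match number of elements")
        self.flat, self.shape = list(flat), tuple(shape)

    def __array_namespace__(self, api_version=None):
        return SimpleNamespace(reshape=lambda x, shape: StrictArray(x.flat, shape))


class MethodArray:
    """Toy NumPy-like duck array: reshape exists only as a method."""

    def __init__(self, flat, shape):
        if math.prod(shape) != len(flat):
            raise ValueError("shape does not match number of elements")
        self.flat, self.shape = list(flat), tuple(shape)

    def reshape(self, shape):
        return MethodArray(self.flat, shape)


# The same call site now handles both kinds of array.
new_shape = (2, 3)
a = reshape(StrictArray(range(6), (6,)), new_shape)
b = reshape(MethodArray(range(6), (6,)), new_shape)
print(type(a).__name__, a.shape)  # StrictArray (2, 3)
print(type(b).__name__, b.shape)  # MethodArray (2, 3)
```

Routing the reshape in Variable._unstack_once_full through such a helper, i.e. calling reshape(reordered.data, new_shape) instead of reordered.data.reshape(new_shape), is one plausible way to let the example above unstack without the AttributeError.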

Environment

main branch of xarray, numpy 1.26.0

Metadata

Assignees: No one assigned
Labels: none
Type: No type
Projects: Status Done
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
