Skip to content

Commit

Permalink
Allow ellipsis (...) in transpose (#3421)
Browse files Browse the repository at this point in the history
* infix_dims function

* implement transpose with ellipsis

* also infix in dataarray

* check errors centrally, remove boilerplate from transpose methods

* whatsnew

* docs

* remove old comments

* generator->iterator

* test for differently ordered dimensions
  • Loading branch information
max-sixty authored and dcherian committed Oct 28, 2019
1 parent fb0cf7b commit 02288b4
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 12 deletions.
4 changes: 3 additions & 1 deletion doc/reshaping.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ Reordering dimensions
---------------------

To reorder dimensions on a :py:class:`~xarray.DataArray` or across all variables
on a :py:class:`~xarray.Dataset`, use :py:meth:`~xarray.DataArray.transpose`:
on a :py:class:`~xarray.Dataset`, use :py:meth:`~xarray.DataArray.transpose`. An
ellipsis (`...`) can be use to represent all other dimensions:

.. ipython:: python
ds = xr.Dataset({'foo': (('x', 'y', 'z'), [[[42]]]), 'bar': (('y', 'z'), [[24]])})
ds.transpose('y', 'z', 'x')
ds.transpose(..., 'x') # equivalent
ds.transpose() # reverses all dimensions
Expand and squeeze dimensions
Expand Down
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ Breaking changes

New Features
~~~~~~~~~~~~
- :py:meth:`Dataset.transpose` and :py:meth:`DataArray.transpose` now support an ellipsis (`...`)
to represent all 'other' dimensions. For example, to move one dimension to the front,
use `.transpose('x', ...)`. (:pull:`3421`)
By `Maximilian Roos <https://github.com/max-sixty>`_
- Changed `xr.ALL_DIMS` to equal python's `Ellipsis` (`...`), and changed internal usages to use
`...` directly. As before, you can use this to instruct a `groupby` operation
to reduce over all dimensions. While we have no plans to remove `xr.ALL_DIMS`, we suggest
Expand Down
5 changes: 4 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,7 @@ tag_prefix = v
parentdir_prefix = xarray-

[aliases]
test = pytest
test = pytest

[pytest-watch]
nobeep = True
7 changes: 1 addition & 6 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1863,12 +1863,7 @@ def transpose(self, *dims: Hashable, transpose_coords: bool = None) -> "DataArra
Dataset.transpose
"""
if dims:
if set(dims) ^ set(self.dims):
raise ValueError(
"arguments to transpose (%s) must be "
"permuted array dimensions (%s)" % (dims, tuple(self.dims))
)

dims = tuple(utils.infix_dims(dims, self.dims))
variable = self.variable.transpose(*dims)
if transpose_coords:
coords: Dict[Hashable, Variable] = {}
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3712,14 +3712,14 @@ def transpose(self, *dims: Hashable) -> "Dataset":
DataArray.transpose
"""
if dims:
if set(dims) ^ set(self.dims):
if set(dims) ^ set(self.dims) and ... not in dims:
raise ValueError(
"arguments to transpose (%s) must be "
"permuted dataset dimensions (%s)" % (dims, tuple(self.dims))
)
ds = self.copy()
for name, var in self._variables.items():
var_dims = tuple(dim for dim in dims if dim in var.dims)
var_dims = tuple(dim for dim in dims if dim in (var.dims + (...,)))
ds._variables[name] = var.transpose(*var_dims)
return ds

Expand Down
25 changes: 25 additions & 0 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
AbstractSet,
Any,
Callable,
Collection,
Container,
Dict,
Hashable,
Expand Down Expand Up @@ -660,6 +661,30 @@ def __len__(self) -> int:
return len(self._data) - num_hidden


def infix_dims(dims_supplied: Collection, dims_all: Collection) -> Iterator:
"""
Resolves a supplied list containing an ellispsis representing other items, to
a generator with the 'realized' list of all items
"""
if ... in dims_supplied:
if len(set(dims_all)) != len(dims_all):
raise ValueError("Cannot use ellipsis with repeated dims")
if len([d for d in dims_supplied if d == ...]) > 1:
raise ValueError("More than one ellipsis supplied")
other_dims = [d for d in dims_all if d not in dims_supplied]
for d in dims_supplied:
if d == ...:
yield from other_dims
else:
yield d
else:
if set(dims_supplied) ^ set(dims_all):
raise ValueError(
f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included"
)
yield from dims_supplied


def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable:
""" Get an new dimension name based on new_dim, that is not used in dims.
If the same name exists, we add an underscore(s) in the head.
Expand Down
2 changes: 2 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
OrderedSet,
decode_numpy_dict_values,
either_dict_or_kwargs,
infix_dims,
ensure_us_time_resolution,
)

Expand Down Expand Up @@ -1228,6 +1229,7 @@ def transpose(self, *dims) -> "Variable":
"""
if len(dims) == 0:
dims = self.dims[::-1]
dims = tuple(infix_dims(dims, self.dims))
axes = self.get_axis_num(dims)
if len(dims) < 2: # no need to transpose if only one dimension
return self.copy(deep=False)
Expand Down
3 changes: 3 additions & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,18 +158,21 @@ def source_ndarray(array):


def assert_equal(a, b):
__tracebackhide__ = True
xarray.testing.assert_equal(a, b)
xarray.testing._assert_internal_invariants(a)
xarray.testing._assert_internal_invariants(b)


def assert_identical(a, b):
__tracebackhide__ = True
xarray.testing.assert_identical(a, b)
xarray.testing._assert_internal_invariants(a)
xarray.testing._assert_internal_invariants(b)


def assert_allclose(a, b, **kwargs):
__tracebackhide__ = True
xarray.testing.assert_allclose(a, b, **kwargs)
xarray.testing._assert_internal_invariants(a)
xarray.testing._assert_internal_invariants(b)
4 changes: 4 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2068,6 +2068,10 @@ def test_transpose(self):
)
assert_equal(expected, actual)

# same as previous but with ellipsis
actual = da.transpose("z", ..., "x", transpose_coords=True)
assert_equal(expected, actual)

with pytest.raises(ValueError):
da.transpose("x", "y")

Expand Down
27 changes: 25 additions & 2 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4675,6 +4675,10 @@ def test_dataset_transpose(self):
)
assert_identical(expected, actual)

actual = ds.transpose(...)
expected = ds
assert_identical(expected, actual)

actual = ds.transpose("x", "y")
expected = ds.apply(lambda x: x.transpose("x", "y", transpose_coords=True))
assert_identical(expected, actual)
Expand All @@ -4690,13 +4694,32 @@ def test_dataset_transpose(self):
expected_dims = tuple(d for d in new_order if d in ds[k].dims)
assert actual[k].dims == expected_dims

with raises_regex(ValueError, "arguments to transpose"):
# same as above but with ellipsis
new_order = ("dim2", "dim3", "dim1", "time")
actual = ds.transpose("dim2", "dim3", ...)
for k in ds.variables:
expected_dims = tuple(d for d in new_order if d in ds[k].dims)
assert actual[k].dims == expected_dims

with raises_regex(ValueError, "permuted"):
ds.transpose("dim1", "dim2", "dim3")
with raises_regex(ValueError, "arguments to transpose"):
with raises_regex(ValueError, "permuted"):
ds.transpose("dim1", "dim2", "dim3", "time", "extra_dim")

assert "T" not in dir(ds)

def test_dataset_ellipsis_transpose_different_ordered_vars(self):
# https://github.com/pydata/xarray/issues/1081#issuecomment-544350457
ds = Dataset(
dict(
a=(("w", "x", "y", "z"), np.ones((2, 3, 4, 5))),
b=(("x", "w", "y", "z"), np.zeros((3, 2, 4, 5))),
)
)
result = ds.transpose(..., "z", "y")
assert list(result["a"].dims) == list("wxzy")
assert list(result["b"].dims) == list("xwzy")

def test_dataset_retains_period_index_on_transpose(self):

ds = create_test_data()
Expand Down
24 changes: 24 additions & 0 deletions xarray/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,27 @@ def test_either_dict_or_kwargs():

with pytest.raises(ValueError, match=r"foo"):
result = either_dict_or_kwargs(dict(a=1), dict(a=1), "foo")


@pytest.mark.parametrize(
["supplied", "all_", "expected"],
[
(list("abc"), list("abc"), list("abc")),
(["a", ..., "c"], list("abc"), list("abc")),
(["a", ...], list("abc"), list("abc")),
(["c", ...], list("abc"), list("cab")),
([..., "b"], list("abc"), list("acb")),
([...], list("abc"), list("abc")),
],
)
def test_infix_dims(supplied, all_, expected):
result = list(utils.infix_dims(supplied, all_))
assert result == expected


@pytest.mark.parametrize(
["supplied", "all_"], [([..., ...], list("abc")), ([...], list("aac"))]
)
def test_infix_dims_errors(supplied, all_):
with pytest.raises(ValueError):
list(utils.infix_dims(supplied, all_))
3 changes: 3 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,9 @@ def test_transpose(self):
w2 = Variable(["d", "b", "c", "a"], np.einsum("abcd->dbca", x))
assert w2.shape == (5, 3, 4, 2)
assert_identical(w2, w.transpose("d", "b", "c", "a"))
assert_identical(w2, w.transpose("d", ..., "a"))
assert_identical(w2, w.transpose("d", "b", "c", ...))
assert_identical(w2, w.transpose(..., "b", "c", "a"))
assert_identical(w, w2.transpose("a", "b", "c", "d"))
w3 = Variable(["b", "c", "d", "a"], np.einsum("abcd->bcda", x))
assert_identical(w, w3.transpose("a", "b", "c", "d"))
Expand Down

0 comments on commit 02288b4

Please sign in to comment.