Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for axes as list in dpnp.ndarray.transpose #1770

Merged
merged 6 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,21 +1317,31 @@ def transpose(self, *axes):

For full documentation refer to :obj:`numpy.ndarray.transpose`.

Parameters
----------
axes : None, tuple or list of ints, n ints, optional
* ``None`` or no argument: reverses the order of the axes.
* tuple or list of ints: `i` in the `j`-th place in the tuple/list
means that the array’s `i`-th axis becomes the transposed
array’s `j`-th axis.
* n ints: same as an n-tuple/n-list of the same ints (this form is
vtavana marked this conversation as resolved.
Show resolved Hide resolved
intended simply as a “convenience” alternative to the tuple form).

Returns
-------
y : dpnp.ndarray
out : dpnp.ndarray
View of the array with its axes suitably permuted.

See Also
--------
:obj:`dpnp.transpose` : Equivalent function.
:obj:`dpnp.ndarray.ndarray.T` : Array property returning the array transposed.
:obj:`dpnp.ndarray.reshape` : Give a new shape to an array without changing its data.
:obj:`dpnp.transpose` : Equivalent function.
:obj:`dpnp.ndarray.ndarray.T` : Array property returning the array transposed.
:obj:`dpnp.ndarray.reshape` : Give a new shape to an array without changing its data.

Examples
--------
>>> import dpnp as dp
>>> a = dp.array([[1, 2], [3, 4]])
>>> import dpnp as np
>>> a = np.array([[1, 2], [3, 4]])
>>> a
array([[1, 2],
[3, 4]])
Expand All @@ -1342,7 +1352,7 @@ def transpose(self, *axes):
array([[1, 3],
[2, 4]])

>>> a = dp.array([1, 2, 3, 4])
>>> a = np.array([1, 2, 3, 4])
>>> a
array([1, 2, 3, 4])
>>> a.transpose()
Expand All @@ -1355,7 +1365,7 @@ def transpose(self, *axes):
return self

axes_len = len(axes)
if axes_len == 1 and isinstance(axes[0], tuple):
if axes_len == 1 and isinstance(axes[0], (tuple, list)):
vtavana marked this conversation as resolved.
Show resolved Hide resolved
axes = axes[0]

res = self.__new__(dpnp_array)
Expand Down
27 changes: 14 additions & 13 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1861,12 +1861,13 @@ def transpose(a, axes=None):
----------
a : {dpnp.ndarray, usm_ndarray}
Input array.
axes : tuple or list of ints, optional
axes : None, tuple or list of ints, optional
If specified, it must be a tuple or list which contains a permutation
of [0, 1, ..., N-1] where N is the number of axes of `a`.
The `i`'th axis of the returned array will correspond to the axis
numbered ``axes[i]`` of the input. If not specified, defaults to
``range(a.ndim)[::-1]``, which reverses the order of the axes.
numbered ``axes[i]`` of the input. If not specified or ``None``,
defaults to ``range(a.ndim)[::-1]``, which reverses the order of
the axes.

Returns
-------
Expand All @@ -1881,35 +1882,35 @@ def transpose(a, axes=None):

Examples
--------
>>> import dpnp as dp
>>> a = dp.array([[1, 2], [3, 4]])
>>> import dpnp as np
>>> a = np.array([[1, 2], [3, 4]])
>>> a
array([[1, 2],
[3, 4]])
>>> dp.transpose(a)
>>> np.transpose(a)
array([[1, 3],
[2, 4]])

>>> a = dp.array([1, 2, 3, 4])
>>> a = np.array([1, 2, 3, 4])
>>> a
array([1, 2, 3, 4])
>>> dp.transpose(a)
>>> np.transpose(a)
array([1, 2, 3, 4])

>>> a = dp.ones((1, 2, 3))
>>> dp.transpose(a, (1, 0, 2)).shape
>>> a = np.ones((1, 2, 3))
>>> np.transpose(a, (1, 0, 2)).shape
(2, 1, 3)

>>> a = dp.ones((2, 3, 4, 5))
>>> dp.transpose(a).shape
>>> a = np.ones((2, 3, 4, 5))
>>> np.transpose(a).shape
(5, 4, 3, 2)

"""

if isinstance(a, dpnp_array):
array = a
elif isinstance(a, dpt.usm_ndarray):
array = dpnp_array._create_from_usm_ndarray(a.get_array())
array = dpnp_array._create_from_usm_ndarray(a)
else:
raise TypeError(
f"An array must be any of supported type, but got {type(a)}"
Expand Down
36 changes: 34 additions & 2 deletions tests/test_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_unique(array):


class TestTranspose:
@pytest.mark.parametrize("axes", [(0, 1), (1, 0)])
@pytest.mark.parametrize("axes", [(0, 1), (1, 0), [0, 1]])
def test_2d_with_axes(self, axes):
na = numpy.array([[1, 2], [3, 4]])
da = dpnp.array(na)
Expand All @@ -124,7 +124,22 @@ def test_2d_with_axes(self, axes):
result = dpnp.transpose(da, axes)
assert_array_equal(expected, result)

@pytest.mark.parametrize("axes", [(1, 0, 2), ((1, 0, 2),)])
# ndarray
expected = na.transpose(axes)
result = da.transpose(axes)
assert_array_equal(expected, result)

@pytest.mark.parametrize(
"axes",
[
(1, 0, 2),
[1, 0, 2],
((1, 0, 2),),
([1, 0, 2],),
[(1, 0, 2)],
[[1, 0, 2]],
],
)
def test_3d_with_packed_axes(self, axes):
na = numpy.ones((1, 2, 3))
da = dpnp.array(na)
Expand All @@ -133,10 +148,27 @@ def test_3d_with_packed_axes(self, axes):
result = da.transpose(*axes)
assert_array_equal(expected, result)

# ndarray
expected = na.transpose(*axes)
result = da.transpose(*axes)
assert_array_equal(expected, result)

@pytest.mark.parametrize("shape", [(10,), (2, 4), (5, 3, 7), (3, 8, 4, 1)])
def test_none_axes(self, shape):
na = numpy.ones(shape)
da = dpnp.ones(shape)

assert_array_equal(numpy.transpose(na), dpnp.transpose(da))
assert_array_equal(numpy.transpose(na, None), dpnp.transpose(da, None))

# ndarray
assert_array_equal(na.transpose(), da.transpose())
assert_array_equal(na.transpose(None), da.transpose(None))

def test_ndarray_axes_n_int(self):
na = numpy.ones((1, 2, 3))
da = dpnp.array(na)

expected = na.transpose(1, 0, 2)
result = da.transpose(1, 0, 2)
assert_array_equal(expected, result)
22 changes: 20 additions & 2 deletions tests/third_party/cupy/manipulation_tests/test_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,38 @@ def test_moveaxis_invalid2_2(self):
with pytest.raises(numpy.AxisError):
xp.moveaxis(a, [0, -4], [1, 2])

def test_moveaxis_invalid2_3(self):
for xp in (numpy, cupy):
a = testing.shaped_arange((2, 3, 4), xp)
with pytest.raises(numpy.AxisError):
xp.moveaxis(a, -4, 0)

# len(source) != len(destination)
def test_moveaxis_invalid3(self):
def test_moveaxis_invalid3_1(self):
for xp in (numpy, cupy):
a = testing.shaped_arange((2, 3, 4), xp)
with pytest.raises(ValueError):
xp.moveaxis(a, [0, 1, 2], [1, 2])

def test_moveaxis_invalid3_2(self):
for xp in (numpy, cupy):
a = testing.shaped_arange((2, 3, 4), xp)
with pytest.raises(ValueError):
xp.moveaxis(a, 0, [1, 2])

# len(source) != len(destination)
def test_moveaxis_invalid4(self):
def test_moveaxis_invalid4_1(self):
for xp in (numpy, cupy):
a = testing.shaped_arange((2, 3, 4), xp)
with pytest.raises(ValueError):
xp.moveaxis(a, [0, 1], [1, 2, 0])

def test_moveaxis_invalid4_2(self):
for xp in (numpy, cupy):
a = testing.shaped_arange((2, 3, 4), xp)
with pytest.raises(ValueError):
xp.moveaxis(a, [0, 1], 1)

# Use the same axis twice
def test_moveaxis_invalid5_1(self):
for xp in (numpy, cupy):
Expand Down
Loading