Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Leverage dpctl.tensor.expand_dims()/swapaxes() implementation #1532

Merged
merged 12 commits into from
Aug 23, 2023
30 changes: 0 additions & 30 deletions dpnp/dpnp_algo/dpnp_algo_manipulation.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ and the rest of the library
__all__ += [
"dpnp_atleast_2d",
"dpnp_atleast_3d",
"dpnp_expand_dims",
"dpnp_repeat",
"dpnp_reshape",
]
Expand Down Expand Up @@ -104,35 +103,6 @@ cpdef utils.dpnp_descriptor dpnp_atleast_3d(utils.dpnp_descriptor arr):
return arr


cpdef utils.dpnp_descriptor dpnp_expand_dims(utils.dpnp_descriptor in_array, axis):
axis_tuple = utils._object_to_tuple(axis)
result_ndim = len(axis_tuple) + in_array.ndim

if len(axis_tuple) == 0:
axis_ndim = 0
else:
axis_ndim = max(-min(0, min(axis_tuple)), max(0, max(axis_tuple))) + 1

axis_norm = utils._object_to_tuple(utils.normalize_axis(axis_tuple, result_ndim))

if axis_ndim - len(axis_norm) > in_array.ndim:
utils.checker_throw_axis_error("dpnp_expand_dims", "axis", axis, axis_ndim)

if len(axis_norm) > len(set(axis_norm)):
utils.checker_throw_value_error("dpnp_expand_dims", "axis", axis, "no repeated axis")

cdef shape_type_c shape_list
axis_idx = 0
for i in range(result_ndim):
if i in axis_norm:
shape_list.push_back(1)
else:
shape_list.push_back(in_array.shape[axis_idx])
axis_idx = axis_idx + 1

return dpnp_reshape(in_array, shape_list)
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved


cpdef utils.dpnp_descriptor dpnp_repeat(utils.dpnp_descriptor array1, repeats, axes=None):
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(array1.dtype)

Expand Down
9 changes: 8 additions & 1 deletion dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,14 @@ def sum(
where=where,
)

# 'swapaxes',
def swapaxes(self, axis1, axis2):
"""
Interchange two axes of an array.

For full documentation refer to :obj:`numpy.swapaxes`.
"""

return dpnp.swapaxes(self, axis1=axis1, axis2=axis2)

def take(self, indices, /, *, axis=None, out=None, mode="wrap"):
"""
Expand Down
82 changes: 59 additions & 23 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def copyto(dst, src, casting="same_kind", where=True):
dst_usm[mask_usm] = src_usm[mask_usm]


def expand_dims(x1, axis):
def expand_dims(x, axis):
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
"""
Expand the shape of an array.

Expand All @@ -397,6 +397,30 @@ def expand_dims(x1, axis):

For full documentation refer to :obj:`numpy.expand_dims`.

Returns
-------
dpnp.ndarray
An array with the number of dimensions increased.
A view is returned whenever possible.

Limitations
-----------
Parameters `x` is supported either as :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Input array data types are limited by supported DPNP :ref:`Data types`.
Otherwise ``TypeError`` exception will be raised.

Notes
-----
If `x` has rank (i.e, number of dimensions) `N`, a valid `axis` must reside
in the closed-interval `[-N-1, N]`.
If provided a negative `axis`, the `axis` position at which to insert a
singleton dimension is computed as `N + axis + 1`.
Hence, if provided `-1`, the resolved axis position is `N` (i.e.,
a singleton dimension must be appended to the input array `x`).
If provided `-N-1`, the resolved axis position is `0` (i.e., a
singleton dimension is prepended to the input array `x`).

See Also
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
--------
:obj:`dpnp.squeeze` : The inverse operation, removing singleton dimensions
Expand Down Expand Up @@ -446,11 +470,15 @@ def expand_dims(x1, axis):

"""

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
if x1_desc:
return dpnp_expand_dims(x1_desc, axis).get_pyobj()
if not dpnp.is_supported_array_type(x):
raise TypeError(
f"An array must be any of supported type, but got {type(x)}"
)
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved

return call_origin(numpy.expand_dims, x1, axis)
dpt_array = dpnp.get_usm_ndarray(x)
return dpnp_array._create_from_usm_ndarray(
dpt.expand_dims(dpt_array, axis=axis)
)


def hstack(tup):
Expand Down Expand Up @@ -921,12 +949,17 @@ def stack(arrays, /, *, axis=0, out=None, dtype=None, **kwargs):
)


def swapaxes(x1, axis1, axis2):
def swapaxes(x, axis1, axis2):
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
"""
Interchange two axes of an array.

For full documentation refer to :obj:`numpy.swapaxes`.

Returns
-------
dpnp.ndarray
An array with with swapped axes.
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved

Limitations
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
-----------
Input array is supported as :obj:`dpnp.ndarray`.
Expand All @@ -935,6 +968,18 @@ def swapaxes(x1, axis1, axis2):
Parameter ``axis2`` is limited by ``axis2 < x1.ndim``.
Input array data types are limited by supported DPNP :ref:`Data types`.

Limitations
-----------
Parameters `x` is supported either as :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Input array data types are limited by supported DPNP :ref:`Data types`.
Otherwise ``TypeError`` exception will be raised.

Notes
-----
If `x` has rank (i.e., number of dimensions) `N`,
a valid `axis` must be in the half-open interval `[-N, N)`.

Examples
--------
>>> import dpnp as np
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -947,24 +992,15 @@ def swapaxes(x1, axis1, axis2):

"""

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
if x1_desc:
if axis1 >= x1_desc.ndim:
pass
elif axis2 >= x1_desc.ndim:
pass
else:
# 'do nothing' pattern for transpose()
input_permute = [i for i in range(x1.ndim)]
# swap axes
input_permute[axis1], input_permute[axis2] = (
input_permute[axis2],
input_permute[axis1],
)

return transpose(x1_desc.get_pyobj(), axes=input_permute)
if not dpnp.is_supported_array_type(x):
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError(
f"An array must be any of supported type, but got {type(x)}"
)

return call_origin(numpy.swapaxes, x1, axis1, axis2)
dpt_array = dpnp.get_usm_ndarray(x)
return dpnp_array._create_from_usm_ndarray(
dpt.swapaxes(dpt_array, axis1=axis1, axis2=axis2)
)


def transpose(a, axes=None):
Expand Down
9 changes: 2 additions & 7 deletions tests/third_party/cupy/linalg_tests/test_eigenvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,10 @@


def _get_hermitian(xp, a, UPLO):
# TODO: remove wrapping, but now there is no dpnp_array.swapaxes()
a = _wrap_as_numpy_array(xp, a)
_xp = numpy

if UPLO == "U":
_a = _xp.triu(a) + _xp.triu(a, k=1).swapaxes(-2, -1).conj()
return xp.triu(a) + xp.triu(a, k=1).swapaxes(-2, -1).conj()
else:
_a = _xp.tril(a) + _xp.tril(a, k=-1).swapaxes(-2, -1).conj()
return xp.array(_a)
return xp.tril(a) + xp.tril(a, k=-1).swapaxes(-2, -1).conj()


# TODO: remove once all required functionality is supported
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
5 changes: 5 additions & 0 deletions tests/third_party/cupy/manipulation_tests/test_dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ def test_expand_dims_repeated_axis(self):
with pytest.raises(ValueError):
xp.expand_dims(a, (1, 1))

def test_expand_dims_invalid_type(self):
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
a = testing.shaped_arange((2, 2), numpy)
with pytest.raises(TypeError):
cupy.expand_dims(a, 1)

@testing.numpy_cupy_array_equal()
def test_squeeze1(self, xp):
a = testing.shaped_arange((1, 2, 1, 3, 1, 1, 4, 1), xp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,17 @@ def test_swapaxes(self, xp):
a = testing.shaped_arange((2, 3, 4), xp)
return xp.swapaxes(a, 2, 0)

@pytest.mark.usefixtures("allow_fall_back_on_numpy")
def test_swapaxes_failure(self):
for xp in (numpy, cupy):
a = testing.shaped_arange((2, 3, 4), xp)
with pytest.raises(ValueError):
xp.swapaxes(a, 3, 0)

def test_swapaxes_invalid_type(self):
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
a = testing.shaped_arange((2, 3, 4), numpy)
with pytest.raises(TypeError):
cupy.swapaxes(a, 1, 0)

@testing.numpy_cupy_array_equal()
def test_transpose(self, xp):
a = testing.shaped_arange((2, 3, 4), xp)
Expand Down