Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add implementation of dpnp.unstack() #2106

Merged
merged 7 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
"transpose",
"trim_zeros",
"unique",
"unstack",
"vsplit",
"vstack",
]
Expand Down Expand Up @@ -1723,6 +1724,8 @@ def hstack(tup, *, dtype=None, casting="same_kind"):
:obj:`dpnp.block` : Assemble an ndarray from nested lists of blocks.
:obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal
size.
:obj:`dpnp.unstack` : Split an array into a tuple of sub-arrays along
an axis.

Examples
--------
Expand Down Expand Up @@ -2913,6 +2916,8 @@ def stack(arrays, /, *, axis=0, out=None, dtype=None, casting="same_kind"):
:obj:`dpnp.block` : Assemble an ndarray from nested lists of blocks.
:obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal
size.
:obj:`dpnp.unstack` : Split an array into a tuple of sub-arrays along
an axis.

Examples
--------
Expand Down Expand Up @@ -3413,6 +3418,84 @@ def unique(
return _unpack_tuple(result)


def unstack(x, /, *, axis=0):
"""
Split an array into a sequence of arrays along the given axis.

The `axis` parameter specifies the dimension along which the array will
be split. For example, if ``axis=0`` (the default) it will be the first
dimension and if ``axis=-1`` it will be the last dimension.

The result is a tuple of arrays split along `axis`.

For full documentation refer to :obj:`numpy.unstack`.

Parameters
----------
x : {dpnp.ndarray, usm_ndarray}
The array to be unstacked.
axis : int, optional
Axis along which the array will be split.
Default: ``0``.

Returns
-------
unstacked : tuple of dpnp.ndarray
The unstacked arrays.

See Also
--------
:obj:`dpnp.stack` : Join a sequence of arrays along a new axis.
:obj:`dpnp.concatenate` : Join a sequence of arrays along an existing axis.
:obj:`dpnp.block` : Assemble an ndarray from nested lists of blocks.
:obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal
size.

Notes
-----
:obj:`dpnp.unstack` serves as the reverse operation of :obj:`dpnp.stack`,
i.e., ``dpnp.stack(dpnp.unstack(x, axis=axis), axis=axis) == x``.

This function is equivalent to ``tuple(dpnp.moveaxis(x, axis, 0))``, since
iterating on an array iterates along the first axis.

Examples
--------
>>> import dpnp as np
>>> arr = np.arange(24).reshape((2, 3, 4))
>>> np.unstack(arr)
(array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]]),
array([[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]))

>>> np.unstack(arr, axis=1)
(array([[ 0, 1, 2, 3],
[12, 13, 14, 15]]),
array([[ 4, 5, 6, 7],
[16, 17, 18, 19]]),
array([[ 8, 9, 10, 11],
[20, 21, 22, 23]]))

>>> arr2 = np.stack(np.unstack(arr, axis=1), axis=1)
>>> arr2.shape
(2, 3, 4)
>>> np.all(arr == arr2)
array(True)

"""

usm_x = dpnp.get_usm_ndarray(x)

if usm_x.ndim == 0:
raise ValueError("Input array must be at least 1-d.")

res = dpt.unstack(usm_x, axis=axis)
return tuple(dpnp_array._create_from_usm_ndarray(a) for a in res)


def vsplit(ary, indices_or_sections):
"""
Split an array into multiple sub-arrays vertically (row-wise).
Expand Down Expand Up @@ -3521,6 +3604,8 @@ def vstack(tup, *, dtype=None, casting="same_kind"):
:obj:`dpnp.block` : Assemble an ndarray from nested lists of blocks.
:obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal
size.
:obj:`dpnp.unstack` : Split an array into a tuple of sub-arrays along
an axis.

Examples
--------
Expand Down
79 changes: 79 additions & 0 deletions tests/test_arraymanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,85 @@ def test_generator(self):
dpnp.stack(map(lambda x: x, dpnp.ones((3, 2))))


# numpy.unstack() is available since numpy >= 2.1
@testing.with_requires("numpy>=2.1")
class TestUnstack:
def test_non_array_input(self):
with pytest.raises(TypeError):
dpnp.unstack(1)

@pytest.mark.parametrize(
"input", [([1, 2, 3],), [dpnp.int32(1), dpnp.int32(2), dpnp.int32(3)]]
)
def test_scalar_input(self, input):
with pytest.raises(TypeError):
dpnp.unstack(input)

@pytest.mark.parametrize("dtype", get_all_dtypes())
def test_0d_array_input(self, dtype):
np_a = numpy.array(1, dtype=dtype)
dp_a = dpnp.array(np_a, dtype=dtype)

with pytest.raises(ValueError):
numpy.unstack(np_a)
with pytest.raises(ValueError):
dpnp.unstack(dp_a)

@pytest.mark.parametrize("dtype", get_all_dtypes())
def test_1d_array(self, dtype):
np_a = numpy.array([1, 2, 3], dtype=dtype)
dp_a = dpnp.array(np_a, dtype=dtype)

np_res = numpy.unstack(np_a)
dp_res = dpnp.unstack(dp_a)
assert len(dp_res) == len(np_res)
for dp_arr, np_arr in zip(dp_res, np_res):
assert_array_equal(dp_arr.asnumpy(), np_arr)

@pytest.mark.parametrize("dtype", get_all_dtypes())
def test_2d_array(self, dtype):
np_a = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
dp_a = dpnp.array(np_a, dtype=dtype)

np_res = numpy.unstack(np_a, axis=0)
dp_res = dpnp.unstack(dp_a, axis=0)
assert len(dp_res) == len(np_res)
for dp_arr, np_arr in zip(dp_res, np_res):
assert_array_equal(dp_arr.asnumpy(), np_arr)

@pytest.mark.parametrize("axis", [0, 1, -1])
@pytest.mark.parametrize("dtype", get_all_dtypes())
def test_2d_array_axis(self, axis, dtype):
np_a = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
dp_a = dpnp.array(np_a, dtype=dtype)

np_res = numpy.unstack(np_a, axis=axis)
dp_res = dpnp.unstack(dp_a, axis=axis)
assert len(dp_res) == len(np_res)
for dp_arr, np_arr in zip(dp_res, np_res):
assert_array_equal(dp_arr.asnumpy(), np_arr)

@pytest.mark.parametrize("axis", [2, -3])
@pytest.mark.parametrize("dtype", get_all_dtypes())
def test_invalid_axis(self, axis, dtype):
np_a = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
dp_a = dpnp.array(np_a, dtype=dtype)

with pytest.raises(AxisError):
numpy.unstack(np_a, axis=axis)
with pytest.raises(AxisError):
dpnp.unstack(dp_a, axis=axis)

@pytest.mark.parametrize("dtype", get_all_dtypes())
def test_empty_array(self, dtype):
np_a = numpy.array([], dtype=dtype)
dp_a = dpnp.array(np_a, dtype=dtype)

np_res = numpy.unstack(np_a)
dp_res = dpnp.unstack(dp_a)
assert len(dp_res) == len(np_res)


class TestVstack:
def test_non_iterable(self):
assert_raises(TypeError, dpnp.vstack, 1)
Expand Down
Loading