Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Leveraged dpctl.tensor.stack() implementation #1509

Merged
merged 2 commits into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ env:
CHANNELS: '-c dppy/label/dev -c intel -c conda-forge --override-channels'
TEST_SCOPE: >-
test_arraycreation.py
test_arraymanipulation.py
test_dot.py
test_dparray.py
test_fft.py
Expand All @@ -23,6 +24,7 @@ env:
test_umath.py
test_usm_type.py
third_party/cupy/linalg_tests/test_product.py
third_party/cupy/manipulation_tests/test_join.py
third_party/cupy/math_tests/test_explog.py
third_party/cupy/math_tests/test_misc.py
third_party/cupy/math_tests/test_trigonometric.py
Expand Down
71 changes: 66 additions & 5 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def broadcast_to(x, /, shape, subok=False):
return call_origin(numpy.broadcast_to, x, shape=shape, subok=subok)


def concatenate(arrays, *, axis=0, out=None, dtype=None, **kwargs):
def concatenate(arrays, /, *, axis=0, out=None, dtype=None, **kwargs):
"""
Join a sequence of arrays along an existing axis.

Expand All @@ -253,8 +253,7 @@ def concatenate(arrays, *, axis=0, out=None, dtype=None, **kwargs):
Each array in `arrays` is supported as either :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`. Otherwise ``TypeError`` exeption
will be raised.
Parameter `out` is supported with default value.
Parameter `dtype` is supported with default value.
Parameters `out` and `dtype are supported with default value.
Keyword argument ``kwargs`` is currently unsupported.
Otherwise the function will be executed sequentially on CPU.

Expand Down Expand Up @@ -834,15 +833,77 @@ def squeeze(x, /, axis=None):
return call_origin(numpy.squeeze, x, axis)


def stack(arrays, axis=0, out=None):
def stack(arrays, /, *, axis=0, out=None, dtype=None, **kwargs):
"""
Join a sequence of arrays along a new axis.

For full documentation refer to :obj:`numpy.stack`.

Returns
-------
out : dpnp.ndarray
The stacked array which has one more dimension than the input arrays.

Limitations
-----------
Each array in `arrays` is supported as either :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`. Otherwise ``TypeError`` exeption
will be raised.
Parameters `out` and `dtype are supported with default value.
Keyword argument ``kwargs`` is currently unsupported.
Otherwise the function will be executed sequentially on CPU.

See Also
--------
:obj:`dpnp.concatenate` : Join a sequence of arrays along an existing axis.
:obj:`dpnp.block` : Assemble an nd-array from nested lists of blocks.
:obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal size.

Examples
--------
>>> import dpnp as np
>>> arrays = [np.random.randn(3, 4) for _ in range(10)]
>>> np.stack(arrays, axis=0).shape
(10, 3, 4)

>>> np.stack(arrays, axis=1).shape
(3, 10, 4)

>>> np.stack(arrays, axis=2).shape
(3, 4, 10)

>>> a = np.array([1, 2, 3])
>>> b = np.array([4, 5, 6])
>>> np.stack((a, b))
array([[1, 2, 3],
[4, 5, 6]])

>>> np.stack((a, b), axis=-1)
array([[1, 4],
[2, 5],
[3, 6]])

"""

return call_origin(numpy.stack, arrays, axis, out)
if kwargs:
pass
elif out is not None:
pass
elif dtype is not None:
pass
else:
usm_arrays = [dpnp.get_usm_ndarray(x) for x in arrays]
usm_res = dpt.stack(usm_arrays, axis=axis)
return dpnp_array._create_from_usm_ndarray(usm_res)

return call_origin(
numpy.stack,
arrays,
axis=axis,
out=out,
dtype=dtype,
**kwargs,
)


def swapaxes(x1, axis1, axis2):
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ def allow_fall_back_on_numpy(monkeypatch):
)


@pytest.fixture
def suppress_complex_warning():
sup = numpy.testing.suppress_warnings("always")
sup.filter(numpy.ComplexWarning)
with sup:
yield


@pytest.fixture
def suppress_divide_numpy_warnings():
# divide: treatment for division by zero (infinite result obtained from finite numbers)
Expand Down
Loading