diff --git a/dpctl/tensor/_reshape.py b/dpctl/tensor/_reshape.py
index 1f057c636a..575c79c115 100644
--- a/dpctl/tensor/_reshape.py
+++ b/dpctl/tensor/_reshape.py
@@ -18,7 +18,6 @@
 import numpy as np
 
 import dpctl.tensor as dpt
-from dpctl.tensor._copy_utils import _copy_from_usm_ndarray_to_usm_ndarray
 from dpctl.tensor._tensor_impl import (
     _copy_usm_ndarray_for_reshape,
     _ravel_multi_index,
@@ -155,32 +154,37 @@ def reshape(X, /, shape, *, order="C", copy=None):
             "Reshaping the array requires a copy, but no copying was "
             "requested by using copy=False"
         )
+    copy_q = X.sycl_queue
     if copy_required or (copy is True):
         # must perform a copy
         flat_res = dpt.usm_ndarray(
             (X.size,),
             dtype=X.dtype,
             buffer=X.usm_type,
-            buffer_ctor_kwargs={"queue": X.sycl_queue},
+            buffer_ctor_kwargs={"queue": copy_q},
         )
         if order == "C":
             hev, _ = _copy_usm_ndarray_for_reshape(
-                src=X, dst=flat_res, sycl_queue=X.sycl_queue
+                src=X, dst=flat_res, sycl_queue=copy_q
             )
-            hev.wait()
         else:
-            for i in range(X.size):
-                _copy_from_usm_ndarray_to_usm_ndarray(
-                    flat_res[i], X[np.unravel_index(i, X.shape, order=order)]
-                )
+            X_t = dpt.permute_dims(X, range(X.ndim - 1, -1, -1))
+            hev, _ = _copy_usm_ndarray_for_reshape(
+                src=X_t, dst=flat_res, sycl_queue=copy_q
+            )
+        hev.wait()
         return dpt.usm_ndarray(
             tuple(shape), dtype=X.dtype, buffer=flat_res, order=order
         )
     # can form a view
+    if (len(shape) == X.ndim) and all(
+        s1 == s2 for s1, s2 in zip(shape, X.shape)
+    ):
+        return X
     return dpt.usm_ndarray(
         shape,
         dtype=X.dtype,
         buffer=X,
         strides=tuple(newsts),
-        offset=X.__sycl_usm_array_interface__.get("offset", 0),
+        offset=X._element_offset,
     )
diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py
index a0f2414fce..6d10c63f03 100644
--- a/dpctl/tests/test_usm_ndarray_ctor.py
+++ b/dpctl/tests/test_usm_ndarray_ctor.py
@@ -1454,6 +1454,33 @@ def test_reshape():
     assert A4.shape == requested_shape
 
 
+def test_reshape_orderF():
+    try:
+        a = dpt.arange(6 * 3 * 4, dtype="i4")
+    except dpctl.SyclDeviceCreationError:
+        pytest.skip("No SYCL devices available")
+    b = dpt.reshape(a, (6, 2, 6))
+    c = dpt.reshape(b, (9, 8), order="F")
+    assert c.flags.f_contiguous
+    assert c._pointer != b._pointer
+    assert b._pointer == a._pointer
+
+    a_np = np.arange(6 * 3 * 4, dtype="i4")
+    b_np = np.reshape(a_np, (6, 2, 6))
+    c_np = np.reshape(b_np, (9, 8), order="F")
+    assert np.array_equal(c_np, dpt.asnumpy(c))
+
+
+def test_reshape_noop():
+    """Per gh-1664"""
+    try:
+        a = dpt.ones((2, 1))
+    except dpctl.SyclDeviceCreationError:
+        pytest.skip("No SYCL devices available")
+    b = dpt.reshape(a, (2, 1))
+    assert b is a
+
+
 def test_reshape_zero_size():
     try:
         a = dpt.empty((0,))