Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reuse dpctl.tensor.squeeze for dpnp.squeeze #1381

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions dpnp/dpnp_algo/dpnp_algo_manipulation.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ __all__ += [
"dpnp_repeat",
"dpnp_reshape",
"dpnp_transpose",
"dpnp_squeeze",
]


Expand Down Expand Up @@ -294,23 +293,3 @@ cpdef utils.dpnp_descriptor dpnp_transpose(utils.dpnp_descriptor array1, axes=No
c_dpctl.DPCTLEvent_Delete(event_ref)

return result


cpdef utils.dpnp_descriptor dpnp_squeeze(utils.dpnp_descriptor in_array, axis):
cdef shape_type_c shape_list
if axis is None:
for i in range(in_array.ndim):
if in_array.shape[i] != 1:
shape_list.push_back(in_array.shape[i])
else:
axis_norm = utils._object_to_tuple(utils.normalize_axis(utils._object_to_tuple(axis), in_array.ndim))
for i in range(in_array.ndim):
if i in axis_norm:
if in_array.shape[i] != 1:
utils.checker_throw_value_error("dpnp_squeeze", "axis", axis, "axis has size not equal to one")
else:
shape_list.push_back(in_array.shape[i])

in_array_obj = in_array.get_array()

return dpnp_reshape(dpnp_copy(in_array), shape_list)
37 changes: 22 additions & 15 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,12 +583,28 @@ def rollaxis(x1, axis, start=0):
return call_origin(numpy.rollaxis, x1, axis, start)


def squeeze(x1, axis=None):
def squeeze(x, axis=None):
antonwolfy marked this conversation as resolved.
Show resolved Hide resolved
"""
Remove single-dimensional entries from the shape of an array.
antonwolfy marked this conversation as resolved.
Show resolved Hide resolved

For full documentation refer to :obj:`numpy.squeeze`.

Returns
antonwolfy marked this conversation as resolved.
Show resolved Hide resolved
-------
out : dpnp.ndarray
Output array is a view, if possible,
and a copy otherwise, but with all or a subset of the
dimensions of length 1 removed. Output has the same data
type as the input, is allocated on the same device as the
input and has the same USM allocation type as the input
array `x`.

Limitations
-----------
Parameters `x` is supported as either :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Otherwise the function will be executed sequentially on CPU.

Examples
--------
>>> import dpnp as np
Expand All @@ -602,26 +618,17 @@ def squeeze(x1, axis=None):
>>> np.squeeze(x, axis=1).shape
Traceback (most recent call last):
...
ValueError: cannot select an axis to squeeze out which has size not equal to one
ValueError: Cannot select an axis to squeeze out which has size not equal to one
antonwolfy marked this conversation as resolved.
Show resolved Hide resolved
>>> np.squeeze(x, axis=2).shape
(1, 3)
>>> x = np.array([[1234]])
>>> x.shape
(1, 1)
>>> np.squeeze(x)
array(1234) # 0d array
>>> np.squeeze(x).shape
()
>>> np.squeeze(x)[()]
1234

"""

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
if x1_desc:
return dpnp_squeeze(x1_desc, axis).get_pyobj()
if isinstance(x, dpnp_array) or isinstance(x, dpt.usm_ndarray):
dpt_array = x.get_array() if isinstance(x, dpnp_array) else x
return dpnp_array._create_from_usm_ndarray(dpt.squeeze(dpt_array, axis))

return call_origin(numpy.squeeze, x1, axis)
return call_origin(numpy.squeeze, x, axis)


def stack(arrays, axis=0, out=None):
Expand Down
9 changes: 0 additions & 9 deletions tests/skipped_tests.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -667,15 +667,6 @@ tests/third_party/cupy/manipulation_tests/test_dims.py::TestBroadcast_param_8_{s
tests/third_party/cupy/manipulation_tests/test_dims.py::TestBroadcast_param_9_{shapes=[(0, 1, 1, 3), (2, 1, 0, 0, 3)]}::test_broadcast
tests/third_party/cupy/manipulation_tests/test_dims.py::TestBroadcast_param_9_{shapes=[(0, 1, 1, 3), (2, 1, 0, 0, 3)]}::test_broadcast_arrays

tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_int_axis_failure1
tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_int_axis_failure2
tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_scalar_failure1
tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_scalar_failure2
tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_scalar_failure3
tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_scalar_failure4
tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_tuple_axis_failure1
tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_tuple_axis_failure2
tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_tuple_axis_failure3
tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_param_0_{shapes=[(3,), (2,)]}::test_invalid_broadcast
tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_param_0_{shapes=[(3,), (2,)]}::test_invalid_broadcast_arrays
tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_param_1_{shapes=[(3, 2), (2, 3)]}::test_invalid_broadcast
Expand Down
9 changes: 0 additions & 9 deletions tests/skipped_tests_gpu.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -828,15 +828,6 @@ tests/third_party/cupy/manipulation_tests/test_dims.py::TestBroadcast_param_8_{s
tests/third_party/cupy/manipulation_tests/test_dims.py::TestBroadcast_param_9_{shapes=[(0, 1, 1, 3), (2, 1, 0, 0, 3)]}::test_broadcast
tests/third_party/cupy/manipulation_tests/test_dims.py::TestBroadcast_param_9_{shapes=[(0, 1, 1, 3), (2, 1, 0, 0, 3)]}::test_broadcast_arrays

tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_int_axis_failure1
tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_int_axis_failure2
tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_scalar_failure1
tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_scalar_failure2
tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_scalar_failure3
tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_scalar_failure4
tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_tuple_axis_failure1
tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_tuple_axis_failure2
tests/third_party/cupy/manipulation_tests/test_dims.py::TestDims::test_squeeze_tuple_axis_failure3
tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_param_0_{shapes=[(3,), (2,)]}::test_invalid_broadcast
tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_param_0_{shapes=[(3,), (2,)]}::test_invalid_broadcast_arrays
tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_param_1_{shapes=[(3, 2), (2, 3)]}::test_invalid_broadcast
Expand Down