Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Leverage dpctl.tensor.expand_dims()/swapaxes() implementation #1532

Merged
merged 12 commits into from
Aug 23, 2023
45 changes: 0 additions & 45 deletions dpnp/dpnp_algo/dpnp_algo_manipulation.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ and the rest of the library
__all__ += [
"dpnp_atleast_2d",
"dpnp_atleast_3d",
"dpnp_expand_dims",
"dpnp_repeat",
"dpnp_reshape",
]


Expand Down Expand Up @@ -104,35 +102,6 @@ cpdef utils.dpnp_descriptor dpnp_atleast_3d(utils.dpnp_descriptor arr):
return arr


cpdef utils.dpnp_descriptor dpnp_expand_dims(utils.dpnp_descriptor in_array, axis):
axis_tuple = utils._object_to_tuple(axis)
result_ndim = len(axis_tuple) + in_array.ndim

if len(axis_tuple) == 0:
axis_ndim = 0
else:
axis_ndim = max(-min(0, min(axis_tuple)), max(0, max(axis_tuple))) + 1

axis_norm = utils._object_to_tuple(utils.normalize_axis(axis_tuple, result_ndim))

if axis_ndim - len(axis_norm) > in_array.ndim:
utils.checker_throw_axis_error("dpnp_expand_dims", "axis", axis, axis_ndim)

if len(axis_norm) > len(set(axis_norm)):
utils.checker_throw_value_error("dpnp_expand_dims", "axis", axis, "no repeated axis")

cdef shape_type_c shape_list
axis_idx = 0
for i in range(result_ndim):
if i in axis_norm:
shape_list.push_back(1)
else:
shape_list.push_back(in_array.shape[axis_idx])
axis_idx = axis_idx + 1

return dpnp_reshape(in_array, shape_list)
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved


cpdef utils.dpnp_descriptor dpnp_repeat(utils.dpnp_descriptor array1, repeats, axes=None):
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(array1.dtype)

Expand Down Expand Up @@ -165,17 +134,3 @@ cpdef utils.dpnp_descriptor dpnp_repeat(utils.dpnp_descriptor array1, repeats, a
c_dpctl.DPCTLEvent_Delete(event_ref)

return result


cpdef utils.dpnp_descriptor dpnp_reshape(utils.dpnp_descriptor array1, newshape, order="C"):
# return dpnp.get_dpnp_descriptor(dpctl.tensor.usm_ndarray(newshape, dtype=numpy.dtype(array1.dtype).name, buffer=array1.get_pyobj()))
# return dpnp.get_dpnp_descriptor(dpctl.tensor.reshape(array1.get_pyobj(), newshape))
array1_obj = array1.get_array()
array_obj = dpctl.tensor.reshape(array1_obj, newshape, order=order)
return dpnp.get_dpnp_descriptor(dpnp_array(array_obj.shape,
buffer=array_obj,
order=order,
device=array1_obj.sycl_device,
usm_type=array1_obj.usm_type,
sycl_queue=array1_obj.sycl_queue),
copy_when_nondefault_queue=False)
9 changes: 8 additions & 1 deletion dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,14 @@ def sum(
where=where,
)

# 'swapaxes',
def swapaxes(self, axis1, axis2):
"""
Interchange two axes of an array.

For full documentation refer to :obj:`numpy.swapaxes`.
"""

return dpnp.swapaxes(self, axis1=axis1, axis2=axis2)

def take(self, indices, /, *, axis=None, out=None, mode="wrap"):
"""
Expand Down
Loading