Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update dpnp.ediff1d function #1983

Merged
merged 4 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 34 additions & 32 deletions dpnp/dpnp_iface_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,30 @@ def _gradient_num_diff_edges(
)


def _process_ediff1d_args(arg, arg_name, ary_dtype, ary_sycl_queue, usm_type):
"""Process the argument for ediff1d."""
if not dpnp.is_supported_array_type(arg):
arg = dpnp.asarray(arg, usm_type=usm_type, sycl_queue=ary_sycl_queue)
else:
usm_type = dpu.get_coerced_usm_type([usm_type, arg.usm_type])
# check that arrays have the same allocation queue
if dpu.get_execution_queue([ary_sycl_queue, arg.sycl_queue]) is None:
raise dpu.ExecutionPlacementError(
f"ary and {arg_name} must be allocated on the same SYCL queue"
)

if not dpnp.can_cast(arg, ary_dtype, casting="same_kind"):
raise TypeError(
f"dtype of {arg_name} must be compatible "
"with input ary under the `same_kind` rule."
)

if arg.ndim > 1:
arg = dpnp.ravel(arg)

return arg, usm_type


_ABS_DOCSTRING = """
Calculates the absolute value for each element `x_i` of input array `x`.

Expand Down Expand Up @@ -1332,52 +1356,30 @@ def ediff1d(ary, to_end=None, to_begin=None):
return ary[1:] - ary[:-1]

ary_dtype = ary.dtype
ary_usm_type = ary.usm_type
ary_sycl_queue = ary.sycl_queue
usm_type = ary.usm_type

if to_begin is None:
l_begin = 0
else:
if not dpnp.is_supported_array_type(to_begin):
to_begin = dpnp.asarray(
to_begin, usm_type=ary_usm_type, sycl_queue=ary_sycl_queue
)
if not dpnp.can_cast(to_begin, ary_dtype, casting="same_kind"):
raise TypeError(
"dtype of `to_begin` must be compatible "
"with input `ary` under the `same_kind` rule."
)

to_begin_ndim = to_begin.ndim

if to_begin_ndim > 1:
to_begin = dpnp.ravel(to_begin)

to_begin, usm_type = _process_ediff1d_args(
to_begin, "to_begin", ary_dtype, ary_sycl_queue, usm_type
)
l_begin = to_begin.size

if to_end is None:
l_end = 0
else:
if not dpnp.is_supported_array_type(to_end):
to_end = dpnp.asarray(
to_end, usm_type=ary_usm_type, sycl_queue=ary_sycl_queue
)
if not dpnp.can_cast(to_end, ary_dtype, casting="same_kind"):
raise TypeError(
"dtype of `to_end` must be compatible "
"with input `ary` under the `same_kind` rule."
)

to_end_ndim = to_end.ndim

if to_end_ndim > 1:
to_end = dpnp.ravel(to_end)

to_end, usm_type = _process_ediff1d_args(
to_end, "to_end", ary_dtype, ary_sycl_queue, usm_type
)
l_end = to_end.size

# calculating using in place operation
l_diff = max(len(ary) - 1, 0)
result = dpnp.empty_like(ary, shape=l_diff + l_begin + l_end)
result = dpnp.empty_like(
ary, shape=l_diff + l_begin + l_end, usm_type=usm_type
)

if l_begin > 0:
result[:l_begin] = to_begin
Expand Down
12 changes: 12 additions & 0 deletions tests/test_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2151,6 +2151,18 @@ def test_ediff1d_errors(self):
to_end = dpnp.array([5], dtype="f4")
assert_raises(TypeError, dpnp.ediff1d, a_dp, to_end=to_end)

# another `to_begin` sycl queue
to_begin = dpnp.array([-20, -15], sycl_queue=dpctl.SyclQueue())
antonwolfy marked this conversation as resolved.
Show resolved Hide resolved
assert_raises(
ExecutionPlacementError, dpnp.ediff1d, a_dp, to_begin=to_begin
)

# another `to_end` sycl queue
to_end = dpnp.array([15, 20], sycl_queue=dpctl.SyclQueue())
assert_raises(
ExecutionPlacementError, dpnp.ediff1d, a_dp, to_end=to_end
)


@pytest.mark.usefixtures("allow_fall_back_on_numpy")
class TestTrapz:
Expand Down
15 changes: 5 additions & 10 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2407,12 +2407,7 @@ def test_nan_to_num(copy, device):


@pytest.mark.parametrize(
"device_x",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
@pytest.mark.parametrize(
"device_args",
"device",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
Expand All @@ -2424,15 +2419,15 @@ def test_nan_to_num(copy, device):
(10, -10),
],
)
def test_ediff1d(device_x, device_args, to_end, to_begin):
def test_ediff1d(device, to_end, to_begin):
data = [1, 3, 5, 7]

x = dpnp.array(data, device=device_x)
x = dpnp.array(data, device=device)
if to_end:
to_end = dpnp.array(to_end, device=device_args)
to_end = dpnp.array(to_end, device=device)

if to_begin:
to_begin = dpnp.array(to_begin, device=device_args)
to_begin = dpnp.array(to_begin, device=device)

res = dpnp.ediff1d(x, to_end=to_end, to_begin=to_begin)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,4 +1431,4 @@ def test_ediff1d(usm_type_x, usm_type_args, to_end, to_begin):

res = dp.ediff1d(x, to_end=to_end, to_begin=to_begin)

assert res.usm_type == x.usm_type
assert res.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_args])
Loading