Skip to content

Commit

Permalink
Check type of input in dpnp.repeat to raise a proper validation exc…
Browse files Browse the repository at this point in the history
…eption if any (#1894)

* Check type of input to raise a proper validation exception if any

* Update dpnp/dpnp_iface_manipulation.py

Co-authored-by: vtavana <120411540+vtavana@users.noreply.github.com>

---------

Co-authored-by: vtavana <120411540+vtavana@users.noreply.github.com>
  • Loading branch information
antonwolfy and vtavana authored Jun 28, 2024
1 parent 067a784 commit 437f046
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 141 deletions.
27 changes: 16 additions & 11 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,12 +1248,16 @@ def repeat(a, repeats, axis=None):
----------
x : {dpnp.ndarray, usm_ndarray}
Input array.
repeat : int or array of int
repeats : {int, tuple, list, range, dpnp.ndarray, usm_ndarray}
The number of repetitions for each element. `repeats` is broadcasted to
fit the shape of the given axis.
axis : int, optional
If `repeats` is an array, it must have an integer data type.
Otherwise, `repeats` must be a Python integer or sequence of Python
integers (i.e., a tuple, list, or range).
axis : {None, int}, optional
The axis along which to repeat values. By default, use the flattened
input array, and return a flat output array.
Default: ``None``.
Returns
-------
Expand All @@ -1263,8 +1267,8 @@ def repeat(a, repeats, axis=None):
See Also
--------
:obj:`dpnp.tile` : Construct an array by repeating A the number of times
given by reps.
:obj:`dpnp.tile` : Tile an array.
:obj:`dpnp.unique` : Find the unique elements of an array.
Examples
--------
Expand All @@ -1286,14 +1290,15 @@ def repeat(a, repeats, axis=None):
"""

rep = repeats
if isinstance(repeats, dpnp_array):
rep = dpnp.get_usm_ndarray(repeats)
dpnp.check_supported_arrays_type(a)
if not isinstance(repeats, (int, tuple, list, range)):
repeats = dpnp.get_usm_ndarray(repeats)

if axis is None and a.ndim > 1:
usm_arr = dpnp.get_usm_ndarray(a.flatten())
else:
usm_arr = dpnp.get_usm_ndarray(a)
usm_arr = dpt.repeat(usm_arr, rep, axis=axis)
a = dpnp.ravel(a)

usm_arr = dpnp.get_usm_ndarray(a)
usm_arr = dpt.repeat(usm_arr, repeats, axis=axis)
return dpnp_array._create_from_usm_ndarray(usm_arr)


Expand Down
111 changes: 0 additions & 111 deletions tests/test_arraymanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,114 +1016,3 @@ def test_can_cast():
assert dpnp.can_cast(X, "float32") == numpy.can_cast(X_np, "float32")
assert dpnp.can_cast(X, dpnp.int32) == numpy.can_cast(X_np, numpy.int32)
assert dpnp.can_cast(X, dpnp.int64) == numpy.can_cast(X_np, numpy.int64)


def test_repeat_scalar_sequence_agreement():
x = dpnp.arange(5, dtype="i4")
expected_res = dpnp.empty(10, dtype="i4")
expected_res[1::2], expected_res[::2] = x, x

# scalar case
reps = 2
res = dpnp.repeat(x, reps)
assert dpnp.all(res == expected_res)

# tuple
reps = (2, 2, 2, 2, 2)
res = dpnp.repeat(x, reps)
assert dpnp.all(res == expected_res)


def test_repeat_as_broadcasting():
reps = 5
x = dpnp.arange(reps, dtype="i4")
x1 = x[:, dpnp.newaxis]
expected_res = dpnp.broadcast_to(x1, (reps, reps))

res = dpnp.repeat(x1, reps, axis=1)
assert dpnp.all(res == expected_res)

x2 = x[dpnp.newaxis, :]
expected_res = dpnp.broadcast_to(x2, (reps, reps))

res = dpnp.repeat(x2, reps, axis=0)
assert dpnp.all(res == expected_res)


def test_repeat_axes():
reps = 2
x = dpnp.reshape(dpnp.arange(5 * 10, dtype="i4"), (5, 10))
expected_res = dpnp.empty((x.shape[0] * 2, x.shape[1]), dtype=x.dtype)
expected_res[::2, :], expected_res[1::2] = x, x
res = dpnp.repeat(x, reps, axis=0)
assert dpnp.all(res == expected_res)

expected_res = dpnp.empty((x.shape[0], x.shape[1] * 2), dtype=x.dtype)
expected_res[:, ::2], expected_res[:, 1::2] = x, x
res = dpnp.repeat(x, reps, axis=1)
assert dpnp.all(res == expected_res)


def test_repeat_size_0_outputs():
x = dpnp.ones((3, 0, 5), dtype="i4")
reps = 10
res = dpnp.repeat(x, reps, axis=0)
assert res.size == 0
assert res.shape == (30, 0, 5)

res = dpnp.repeat(x, reps, axis=1)
assert res.size == 0
assert res.shape == (3, 0, 5)

res = dpnp.repeat(x, (2, 2, 2), axis=0)
assert res.size == 0
assert res.shape == (6, 0, 5)

x = dpnp.ones((3, 2, 5))
res = dpnp.repeat(x, 0, axis=1)
assert res.size == 0
assert res.shape == (3, 0, 5)

x = dpnp.ones((3, 2, 5))
res = dpnp.repeat(x, (0, 0), axis=1)
assert res.size == 0
assert res.shape == (3, 0, 5)


def test_repeat_strides():
reps = 2
x = dpnp.reshape(dpnp.arange(10 * 10, dtype="i4"), (10, 10))
x1 = x[:, ::-2]
expected_res = dpnp.empty((10, 10), dtype="i4")
expected_res[:, ::2], expected_res[:, 1::2] = x1, x1
res = dpnp.repeat(x1, reps, axis=1)
assert dpnp.all(res == expected_res)
res = dpnp.repeat(x1, (reps,) * x1.shape[1], axis=1)
assert dpnp.all(res == expected_res)

x1 = x[::-2, :]
expected_res = dpnp.empty((10, 10), dtype="i4")
expected_res[::2, :], expected_res[1::2, :] = x1, x1
res = dpnp.repeat(x1, reps, axis=0)
assert dpnp.all(res == expected_res)
res = dpnp.repeat(x1, (reps,) * x1.shape[0], axis=0)
assert dpnp.all(res == expected_res)


def test_repeat_casting():
x = dpnp.arange(5, dtype="i4")
# i4 is cast to i8
reps = dpnp.ones(5, dtype="i4")
res = dpnp.repeat(x, reps)
assert res.shape == x.shape
assert dpnp.all(res == x)


def test_repeat_strided_repeats():
x = dpnp.arange(5, dtype="i4")
reps = dpnp.ones(10, dtype="i8")
reps[::2] = 0
reps = reps[::-2]
res = dpnp.repeat(x, reps)
assert res.shape == x.shape
assert dpnp.all(res == x)
Loading

0 comments on commit 437f046

Please sign in to comment.