Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nd-support to dpnp.trim_zeros #2241

Merged
merged 6 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 66 additions & 27 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3900,25 +3900,40 @@ def transpose(a, axes=None):
permute_dims = transpose # permute_dims is an alias for transpose


def trim_zeros(filt, trim="fb"):
def trim_zeros(filt, trim="fb", axis=None):
"""
Trim the leading and/or trailing zeros from a 1-D array.
Remove values along a dimension which are zero along all other.

For full documentation refer to :obj:`numpy.trim_zeros`.

Parameters
----------
filt : {dpnp.ndarray, usm_ndarray}
Input 1-D array.
trim : str, optional
A string with 'f' representing trim from front and 'b' to trim from
back. By defaults, trim zeros from both front and back of the array.
Input array.
trim : {"fb", "f", "b"}, optional
A string with `"f"` representing trim from front and `"b"` to trim from
back. By default, zeros are trimmed on both sides. Front and back refer
to the edges of a dimension, with "front" referring to the side with
the lowest index 0, and "back" referring to the highest index
(or index -1).
Default: ``"fb"``.
axis : {None, int}, optional
If ``None``, `filt` is cropped such, that the smallest bounding box is
returned that still contains all values which are not zero.
If an `axis` is specified, `filt` will be sliced in that dimension only
on the sides specified by `trim`. The remaining area will be the
smallest that still contains all values which are not zero.
Default: ``None``.

Returns
-------
out : dpnp.ndarray
The result of trimming the input.
The result of trimming the input. The number of dimensions and the
input data type are preserved.

Notes
-----
For all-zero arrays, the first axis is trimmed first.

Examples
--------
Expand All @@ -3927,42 +3942,66 @@ def trim_zeros(filt, trim="fb"):
>>> np.trim_zeros(a)
array([1, 2, 3, 0, 2, 1])

>>> np.trim_zeros(a, 'b')
>>> np.trim_zeros(a, trim='b')
array([0, 0, 0, 1, 2, 3, 0, 2, 1])

Multiple dimensions are supported:

>>> b = np.array([[0, 0, 2, 3, 0, 0],
... [0, 1, 0, 3, 0, 0],
... [0, 0, 0, 0, 0, 0]])
>>> np.trim_zeros(b)
array([[0, 2, 3],
[1, 0, 3]])

>>> np.trim_zeros(b, axis=-1)
array([[0, 2, 3],
[1, 0, 3],
[0, 0, 0]])

"""

dpnp.check_supported_arrays_type(filt)
if filt.ndim == 0:
raise TypeError("0-d array cannot be trimmed")
if filt.ndim > 1:
raise ValueError("Multi-dimensional trim is not supported")

if not isinstance(trim, str):
raise TypeError("only string trim is supported")

trim = trim.upper()
if not any(x in trim for x in "FB"):
return filt # no trim rule is specified
trim = trim.lower()
if trim not in ["fb", "bf", "f", "b"]:
raise ValueError(f"unexpected character(s) in `trim`: {trim!r}")

nd = filt.ndim
if axis is not None:
axis = normalize_axis_index(axis, nd)

if filt.size == 0:
return filt # no trailing zeros in empty array

a = dpnp.nonzero(filt)[0]
a_size = a.size
if a_size == 0:
# 'filt' is array of zeros
return dpnp.empty_like(filt, shape=(0,))
non_zero = dpnp.argwhere(filt)
if non_zero.size == 0:
# `filt` has all zeros, so assign `start` and `stop` to the same value,
# then the resulting slice will be empty
start = stop = dpnp.zeros_like(filt, shape=nd, dtype=dpnp.intp)
else:
if "f" in trim:
start = non_zero.min(axis=0)
else:
start = (None,) * nd

first = 0
if "F" in trim:
first = a[0]
if "b" in trim:
stop = non_zero.max(axis=0)
stop += 1 # Adjust for slicing
else:
stop = (None,) * nd

last = filt.size
if "B" in trim:
last = a[-1] + 1
if axis is None:
# trim all axes
sl = tuple(slice(*x) for x in zip(start, stop))
else:
# only trim single axis
sl = (slice(None),) * axis + (slice(start[axis], stop[axis]),) + (...,)

return filt[first:last]
return filt[sl]


def unique(
Expand Down
43 changes: 32 additions & 11 deletions dpnp/tests/test_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,20 @@ def test_basic(self, dtype):
expected = numpy.trim_zeros(a)
assert_array_equal(result, expected)

@testing.with_requires("numpy>=2.2")
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("trim", ["F", "B", "fb"])
@pytest.mark.parametrize("ndim", [0, 1, 2, 3])
def test_basic_nd(self, dtype, trim, ndim):
a = numpy.ones((2,) * ndim, dtype=dtype)
a = numpy.pad(a, (2, 1), mode="constant", constant_values=0)
ia = dpnp.array(a)

for axis in list(range(ndim)) + [None]:
result = dpnp.trim_zeros(ia, trim=trim, axis=axis)
expected = numpy.trim_zeros(a, trim=trim, axis=axis)
assert_array_equal(result, expected)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("trim", ["F", "B"])
def test_trim(self, dtype, trim):
Expand All @@ -1398,6 +1412,19 @@ def test_all_zero(self, dtype, trim):
expected = numpy.trim_zeros(a, trim)
assert_array_equal(result, expected)

@testing.with_requires("numpy>=2.2")
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("trim", ["F", "B", "fb"])
@pytest.mark.parametrize("ndim", [0, 1, 2, 3])
def test_all_zero_nd(self, dtype, trim, ndim):
a = numpy.zeros((3,) * ndim, dtype=dtype)
ia = dpnp.array(a)

for axis in list(range(ndim)) + [None]:
result = dpnp.trim_zeros(ia, trim=trim, axis=axis)
expected = numpy.trim_zeros(a, trim=trim, axis=axis)
assert_array_equal(result, expected)

def test_size_zero(self):
a = numpy.zeros(0)
ia = dpnp.array(a)
Expand All @@ -1416,17 +1443,11 @@ def test_overflow(self, a):
expected = numpy.trim_zeros(a)
assert_array_equal(result, expected)

# TODO: modify once SAT-7616
# numpy 2.2 validates trim rules
@testing.with_requires("numpy<2.2")
def test_trim_no_rule(self):
a = numpy.array([0, 0, 1, 0, 2, 3, 4, 0])
ia = dpnp.array(a)
trim = "ADE" # no "F" or "B" in trim string

result = dpnp.trim_zeros(ia, trim)
expected = numpy.trim_zeros(a, trim)
assert_array_equal(result, expected)
@testing.with_requires("numpy>=2.2")
@pytest.mark.parametrize("xp", [numpy, dpnp])
def test_trim_no_fb_in_rule(self, xp):
a = xp.array([0, 0, 1, 0, 2, 3, 4, 0])
assert_raises(ValueError, xp.trim_zeros, a, "ADE")

def test_list_array(self):
assert_raises(TypeError, dpnp.trim_zeros, [0, 0, 1, 0, 2, 3, 4, 0])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,17 +387,15 @@ def test_trim_back_zeros(self, xp, dtype):
a = xp.array([1, 0, 2, 3, 0, 5, 0, 0, 0], dtype=dtype)
return xp.trim_zeros(a, trim=self.trim)

# TODO: remove once SAT-7616
@testing.with_requires("numpy<2.2")
@pytest.mark.skip("0-d array is supported")
@testing.for_all_dtypes()
def test_trim_zero_dim(self, dtype):
for xp in (numpy, cupy):
a = testing.shaped_arange((), xp, dtype)
with pytest.raises(TypeError):
xp.trim_zeros(a, trim=self.trim)

# TODO: remove once SAT-7616
@testing.with_requires("numpy<2.2")
@pytest.mark.skip("nd array is supported")
@testing.for_all_dtypes()
def test_trim_ndim(self, dtype):
for xp in (numpy, cupy):
Expand Down
Loading