Skip to content

Commit

Permalink
Add nd-support to dpnp.trim_zeros (#2241)
Browse files Browse the repository at this point in the history
The PR extends `dpnp.trim_zeros` implement to align with changes
introduced in NumPy 2.2.

It adds support for trimming nd-arrays while preserving the old behavior
for 1-d input. The new parameter `axis` can specify a single dimension
to be trimmed (reducing all other dimensions to the envelope of absolute
values). By default (when `None` is specified), all dimensions are
trimmed iteratively.
  • Loading branch information
antonwolfy authored Dec 20, 2024
1 parent fdde3d0 commit 94bea60
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 42 deletions.
93 changes: 66 additions & 27 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3900,25 +3900,40 @@ def transpose(a, axes=None):
permute_dims = transpose # permute_dims is an alias for transpose


def trim_zeros(filt, trim="fb"):
def trim_zeros(filt, trim="fb", axis=None):
"""
Trim the leading and/or trailing zeros from a 1-D array.
Remove values along a dimension which are zero along all other.
For full documentation refer to :obj:`numpy.trim_zeros`.
Parameters
----------
filt : {dpnp.ndarray, usm_ndarray}
Input 1-D array.
trim : str, optional
A string with 'f' representing trim from front and 'b' to trim from
back. By defaults, trim zeros from both front and back of the array.
Input array.
trim : {"fb", "f", "b"}, optional
A string with `"f"` representing trim from front and `"b"` to trim from
back. By default, zeros are trimmed on both sides. Front and back refer
to the edges of a dimension, with "front" referring to the side with
the lowest index 0, and "back" referring to the highest index
(or index -1).
Default: ``"fb"``.
axis : {None, int}, optional
If ``None``, `filt` is cropped such, that the smallest bounding box is
returned that still contains all values which are not zero.
If an `axis` is specified, `filt` will be sliced in that dimension only
on the sides specified by `trim`. The remaining area will be the
smallest that still contains all values which are not zero.
Default: ``None``.
Returns
-------
out : dpnp.ndarray
The result of trimming the input.
The result of trimming the input. The number of dimensions and the
input data type are preserved.
Notes
-----
For all-zero arrays, the first axis is trimmed first.
Examples
--------
Expand All @@ -3927,42 +3942,66 @@ def trim_zeros(filt, trim="fb"):
>>> np.trim_zeros(a)
array([1, 2, 3, 0, 2, 1])
>>> np.trim_zeros(a, 'b')
>>> np.trim_zeros(a, trim='b')
array([0, 0, 0, 1, 2, 3, 0, 2, 1])
Multiple dimensions are supported:
>>> b = np.array([[0, 0, 2, 3, 0, 0],
... [0, 1, 0, 3, 0, 0],
... [0, 0, 0, 0, 0, 0]])
>>> np.trim_zeros(b)
array([[0, 2, 3],
[1, 0, 3]])
>>> np.trim_zeros(b, axis=-1)
array([[0, 2, 3],
[1, 0, 3],
[0, 0, 0]])
"""

dpnp.check_supported_arrays_type(filt)
if filt.ndim == 0:
raise TypeError("0-d array cannot be trimmed")
if filt.ndim > 1:
raise ValueError("Multi-dimensional trim is not supported")

if not isinstance(trim, str):
raise TypeError("only string trim is supported")

trim = trim.upper()
if not any(x in trim for x in "FB"):
return filt # no trim rule is specified
trim = trim.lower()
if trim not in ["fb", "bf", "f", "b"]:
raise ValueError(f"unexpected character(s) in `trim`: {trim!r}")

nd = filt.ndim
if axis is not None:
axis = normalize_axis_index(axis, nd)

if filt.size == 0:
return filt # no trailing zeros in empty array

a = dpnp.nonzero(filt)[0]
a_size = a.size
if a_size == 0:
# 'filt' is array of zeros
return dpnp.empty_like(filt, shape=(0,))
non_zero = dpnp.argwhere(filt)
if non_zero.size == 0:
# `filt` has all zeros, so assign `start` and `stop` to the same value,
# then the resulting slice will be empty
start = stop = dpnp.zeros_like(filt, shape=nd, dtype=dpnp.intp)
else:
if "f" in trim:
start = non_zero.min(axis=0)
else:
start = (None,) * nd

first = 0
if "F" in trim:
first = a[0]
if "b" in trim:
stop = non_zero.max(axis=0)
stop += 1 # Adjust for slicing
else:
stop = (None,) * nd

last = filt.size
if "B" in trim:
last = a[-1] + 1
if axis is None:
# trim all axes
sl = tuple(slice(*x) for x in zip(start, stop))
else:
# only trim single axis
sl = (slice(None),) * axis + (slice(start[axis], stop[axis]),) + (...,)

return filt[first:last]
return filt[sl]


def unique(
Expand Down
43 changes: 32 additions & 11 deletions dpnp/tests/test_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,20 @@ def test_basic(self, dtype):
expected = numpy.trim_zeros(a)
assert_array_equal(result, expected)

@testing.with_requires("numpy>=2.2")
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("trim", ["F", "B", "fb"])
@pytest.mark.parametrize("ndim", [0, 1, 2, 3])
def test_basic_nd(self, dtype, trim, ndim):
a = numpy.ones((2,) * ndim, dtype=dtype)
a = numpy.pad(a, (2, 1), mode="constant", constant_values=0)
ia = dpnp.array(a)

for axis in list(range(ndim)) + [None]:
result = dpnp.trim_zeros(ia, trim=trim, axis=axis)
expected = numpy.trim_zeros(a, trim=trim, axis=axis)
assert_array_equal(result, expected)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("trim", ["F", "B"])
def test_trim(self, dtype, trim):
Expand All @@ -1398,6 +1412,19 @@ def test_all_zero(self, dtype, trim):
expected = numpy.trim_zeros(a, trim)
assert_array_equal(result, expected)

@testing.with_requires("numpy>=2.2")
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("trim", ["F", "B", "fb"])
@pytest.mark.parametrize("ndim", [0, 1, 2, 3])
def test_all_zero_nd(self, dtype, trim, ndim):
a = numpy.zeros((3,) * ndim, dtype=dtype)
ia = dpnp.array(a)

for axis in list(range(ndim)) + [None]:
result = dpnp.trim_zeros(ia, trim=trim, axis=axis)
expected = numpy.trim_zeros(a, trim=trim, axis=axis)
assert_array_equal(result, expected)

def test_size_zero(self):
a = numpy.zeros(0)
ia = dpnp.array(a)
Expand All @@ -1416,17 +1443,11 @@ def test_overflow(self, a):
expected = numpy.trim_zeros(a)
assert_array_equal(result, expected)

# TODO: modify once SAT-7616
# numpy 2.2 validates trim rules
@testing.with_requires("numpy<2.2")
def test_trim_no_rule(self):
a = numpy.array([0, 0, 1, 0, 2, 3, 4, 0])
ia = dpnp.array(a)
trim = "ADE" # no "F" or "B" in trim string

result = dpnp.trim_zeros(ia, trim)
expected = numpy.trim_zeros(a, trim)
assert_array_equal(result, expected)
@testing.with_requires("numpy>=2.2")
@pytest.mark.parametrize("xp", [numpy, dpnp])
def test_trim_no_fb_in_rule(self, xp):
a = xp.array([0, 0, 1, 0, 2, 3, 4, 0])
assert_raises(ValueError, xp.trim_zeros, a, "ADE")

def test_list_array(self):
assert_raises(TypeError, dpnp.trim_zeros, [0, 0, 1, 0, 2, 3, 4, 0])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,17 +387,15 @@ def test_trim_back_zeros(self, xp, dtype):
a = xp.array([1, 0, 2, 3, 0, 5, 0, 0, 0], dtype=dtype)
return xp.trim_zeros(a, trim=self.trim)

# TODO: remove once SAT-7616
@testing.with_requires("numpy<2.2")
@pytest.mark.skip("0-d array is supported")
@testing.for_all_dtypes()
def test_trim_zero_dim(self, dtype):
for xp in (numpy, cupy):
a = testing.shaped_arange((), xp, dtype)
with pytest.raises(TypeError):
xp.trim_zeros(a, trim=self.trim)

# TODO: remove once SAT-7616
@testing.with_requires("numpy<2.2")
@pytest.mark.skip("nd array is supported")
@testing.for_all_dtypes()
def test_trim_ndim(self, dtype):
for xp in (numpy, cupy):
Expand Down

0 comments on commit 94bea60

Please sign in to comment.