Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dpnp.trim_zeros implementation #1941

Merged
merged 1 commit into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/reference/manipulation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ Adding and removing elements
dpnp.resize
dpnp.unique
dpnp.trim_zeros
dpnp.pad


Rearranging elements
Expand Down
66 changes: 66 additions & 0 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
"swapaxes",
"tile",
"transpose",
"trim_zeros",
"unique",
"vstack",
]
Expand Down Expand Up @@ -1927,6 +1928,71 @@ def transpose(a, axes=None):
return array.transpose(*axes)


def trim_zeros(filt, trim="fb"):
"""
Trim the leading and/or trailing zeros from a 1-D array.

For full documentation refer to :obj:`numpy.trim_zeros`.

Parameters
----------
filt : {dpnp.ndarray, usm_ndarray}
Input 1-D array.
trim : str, optional
A string with 'f' representing trim from front and 'b' to trim from
back. By defaults, trim zeros from both front and back of the array.
Default: ``"fb"``.

Returns
-------
out : dpnp.ndarray
The result of trimming the input.

Examples
--------
>>> import dpnp as np
>>> a = np.array((0, 0, 0, 1, 2, 3, 0, 2, 1, 0))
>>> np.trim_zeros(a)
array([1, 2, 3, 0, 2, 1])

>>> np.trim_zeros(a, 'b')
array([0, 0, 0, 1, 2, 3, 0, 2, 1])

"""

dpnp.check_supported_arrays_type(filt)
if filt.ndim == 0:
raise TypeError("0-d array cannot be trimmed")
if filt.ndim > 1:
raise ValueError("Multi-dimensional trim is not supported")

if not isinstance(trim, str):
raise TypeError("only string trim is supported")

trim = trim.upper()
if not any(x in trim for x in "FB"):
return filt # no trim rule is specified

if filt.size == 0:
return filt # no trailing zeros in empty array

a = dpnp.nonzero(filt)[0]
a_size = a.size
if a_size == 0:
# 'filt' is array of zeros
return dpnp.empty_like(filt, shape=(0,))

first = 0
if "F" in trim:
first = a[0]

last = filt.size
if "B" in trim:
last = a[-1] + 1

return filt[first:last]


def unique(ar, **kwargs):
"""
Find the unique elements of an array.
Expand Down
71 changes: 71 additions & 0 deletions tests/test_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,74 @@ def test_ndarray_axes_n_int(self):
expected = na.transpose(1, 0, 2)
result = da.transpose(1, 0, 2)
assert_array_equal(expected, result)


class TestTrimZeros:
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
def test_basic(self, dtype):
a = numpy.array([0, 0, 1, 0, 2, 3, 4, 0], dtype=dtype)
ia = dpnp.array(a)

result = dpnp.trim_zeros(ia)
expected = numpy.trim_zeros(a)
assert_array_equal(expected, result)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("trim", ["F", "B"])
def test_trim(self, dtype, trim):
a = numpy.array([0, 0, 1, 0, 2, 3, 4, 0], dtype=dtype)
ia = dpnp.array(a)

result = dpnp.trim_zeros(ia, trim)
expected = numpy.trim_zeros(a, trim)
assert_array_equal(expected, result)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("trim", ["F", "B"])
def test_all_zero(self, dtype, trim):
a = numpy.zeros((8,), dtype=dtype)
ia = dpnp.array(a)

result = dpnp.trim_zeros(ia, trim)
expected = numpy.trim_zeros(a, trim)
assert_array_equal(expected, result)

def test_size_zero(self):
a = numpy.zeros(0)
ia = dpnp.array(a)

result = dpnp.trim_zeros(ia)
expected = numpy.trim_zeros(a)
assert_array_equal(expected, result)

@pytest.mark.parametrize(
"a", [numpy.array([0, 2**62, 0]), numpy.array([0, 2**63, 0])]
)
def test_overflow(self, a):
ia = dpnp.array(a)

result = dpnp.trim_zeros(ia)
expected = numpy.trim_zeros(a)
assert_array_equal(expected, result)

def test_trim_no_rule(self):
a = numpy.array([0, 0, 1, 0, 2, 3, 4, 0])
ia = dpnp.array(a)
trim = "ADE" # no "F" or "B" in trim string

result = dpnp.trim_zeros(ia, trim)
expected = numpy.trim_zeros(a, trim)
assert_array_equal(expected, result)

def test_list_array(self):
assert_raises(TypeError, dpnp.trim_zeros, [0, 0, 1, 0, 2, 3, 4, 0])

@pytest.mark.parametrize(
"trim", [1, ["F"], numpy.array("B")], ids=["int", "list", "array"]
)
def test_unsupported_trim(self, trim):
a = numpy.array([0, 0, 1, 0, 2, 3, 4, 0])
ia = dpnp.array(a)

assert_raises(TypeError, dpnp.trim_zeros, ia, trim)
assert_raises(AttributeError, numpy.trim_zeros, a, trim)
1 change: 1 addition & 0 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ def test_meshgrid(device):
"trace", [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
),
pytest.param("trapz", [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
pytest.param("trim_zeros", [0, 0, 0, 1, 2, 3, 0, 2, 1, 0]),
pytest.param("trunc", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
pytest.param("var", [1.0, 2.0, 4.0, 7.0]),
],
Expand Down
1 change: 1 addition & 0 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ def test_norm(usm_type, ord, axis):
pytest.param(
"trace", [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
),
pytest.param("trim_zeros", [0, 0, 0, 1, 2, 3, 0, 2, 1, 0]),
pytest.param("trunc", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
pytest.param("var", [1.0, 2.0, 4.0, 7.0]),
],
Expand Down
Loading
Loading