Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 12, 2024
1 parent 8502cfa commit 33da5ff
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
4 changes: 2 additions & 2 deletions sparse/_coo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,7 @@ def take(x, indices, /, *, axis=None):
if axis is None:
x = x.flatten()
return x[indices]

axis = normalize_axis(axis, x.ndim)
full_index = (slice(None),) * axis + (indices, ...)
return x[full_index]
Expand Down Expand Up @@ -1293,7 +1293,7 @@ def _sort_coo(
indices[pos:] += fill_value_count
break

Check warning on line 1294 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1290-L1294

Added lines #L1290 - L1294 were not covered by tests

result_indices[offset:offset+len(indices)] = indices
result_indices[offset : offset + len(indices)] = indices
offset += len(indices)

Check warning on line 1297 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1296-L1297

Added lines #L1296 - L1297 were not covered by tests

sort_coords[:] = result_indices

Check warning on line 1299 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1299

Added line #L1299 was not covered by tests
Expand Down
26 changes: 19 additions & 7 deletions sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1779,15 +1779,18 @@ def test_input_validation(self, func):
func(self.arr)


@pytest.mark.parametrize("arr", [
np.array([[0, 0, 1, 5, 3, 0], [1, 0, 4, 0, 3, 0], [0, 1, 0, 1, 1, 0]], dtype=np.int64),
np.array([[[2, 0], [0, 5]], [[1, 0], [4, 0]], [[0, 1], [0, -1]]], dtype=np.float64)
])
@pytest.mark.parametrize(
"arr",
[
np.array([[0, 0, 1, 5, 3, 0], [1, 0, 4, 0, 3, 0], [0, 1, 0, 1, 1, 0]], dtype=np.int64),
np.array([[[2, 0], [0, 5]], [[1, 0], [4, 0]], [[0, 1], [0, -1]]], dtype=np.float64),
],
)
@pytest.mark.parametrize("fill_value", [-1, 0, 1, 3])
@pytest.mark.parametrize("axis", [0, 1, -1])
@pytest.mark.parametrize("descending", [False, True])
def test_sort(arr, fill_value, axis, descending):
s_arr = sparse.COO.from_numpy(arr, fill_value)
s_arr = sparse.COO.from_numpy(arr, fill_value)

result = sparse.sort(s_arr, axis=axis, descending=descending)
expected = -np.sort(-arr, axis=axis) if descending else np.sort(arr, axis=axis)
Expand All @@ -1798,10 +1801,19 @@ def test_sort(arr, fill_value, axis, descending):
@pytest.mark.parametrize("fill_value", [-1, 0, 1, 3])
@pytest.mark.parametrize(
"indices,axis",
[([1], 0,), ([2, 1], 1), ([1, 2, 3], 2), ([2, 3], -1), ([5, 3, 7, 8], None)]
[
(
[1],
0,
),
([2, 1], 1),
([1, 2, 3], 2),
([2, 3], -1),
([5, 3, 7, 8], None),
],
)
def test_take(fill_value, indices, axis):
arr = np.arange(24).reshape((2,3,4))
arr = np.arange(24).reshape((2, 3, 4))

s_arr = sparse.COO.from_numpy(arr, fill_value)

Expand Down

0 comments on commit 33da5ff

Please sign in to comment.