Skip to content

Commit

Permalink
Remove use of deprecated NumPy list indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Sep 29, 2022
1 parent 1e36f60 commit 8086efe
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions aesara/tensor/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,18 +291,18 @@ def _topk_py_impl(op, x, k, axis, idx_dtype):
idx[axis] = slice(-k, None) if k > 0 else slice(-k)

if not op.return_indices:
zv = np.partition(x, -k, axis=axis)[idx]
zv = np.partition(x, -k, axis=axis)[tuple(idx)]
return zv
elif op.return_values:
zi = np.argpartition(x, -k, axis=axis)[idx]
zi = np.argpartition(x, -k, axis=axis)[tuple(idx)]
idx2 = tuple(
np.arange(s).reshape((s,) + (1,) * (ndim - i - 1)) if i != axis else zi
for i, s in enumerate(x.shape)
)
zv = x[idx2]
return zv, zi.astype(idx_dtype)
else:
zi = np.argpartition(x, -k, axis=axis)[idx]
zi = np.argpartition(x, -k, axis=axis)[tuple(idx)]
return zi.astype(idx_dtype)


Expand Down

0 comments on commit 8086efe

Please sign in to comment.