Skip to content

Commit

Permalink
Fix argpartition cuda bug in torch (#19634)
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 authored Apr 28, 2024
1 parent 54e15eb commit 81c0047
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
7 changes: 2 additions & 5 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1617,19 +1617,16 @@ def slogdet(x):

def argpartition(x, kth, axis=-1):
x = convert_to_tensor(x, "int32")

x = torch.transpose(x, axis, -1)
bottom_ind = torch.topk(-x, kth + 1)[1]

def set_to_zero(a, i):
a[i] = 0
a[i] = torch.zeros(1, dtype=a.dtype, device=a.device)
return a

for _ in range(x.dim() - 1):
set_to_zero = torch.vmap(set_to_zero)
proxy = set_to_zero(ones(x.shape, dtype=torch.int32), bottom_ind)

proxy = set_to_zero(torch.ones_like(x, dtype=torch.int32), bottom_ind)
top_ind = torch.topk(proxy, x.shape[-1] - kth - 1)[1]

out = torch.cat([bottom_ind, top_ind], dim=x.dim() - 1)
return cast(torch.transpose(out, -1, axis), "int32")
2 changes: 1 addition & 1 deletion keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4328,7 +4328,7 @@ def test_argpartition(self):
self.assertAllClose(knp.argpartition(x, 2), np.argpartition(x, 2))
self.assertAllClose(knp.Argpartition(2)(x), np.argpartition(x, 2))

x = np.array([[3, 4, 2], [1, 3, 1]])
x = np.array([[3, 4, 2], [1, 3, 4]])
self.assertAllClose(knp.argpartition(x, 1), np.argpartition(x, 1))
self.assertAllClose(knp.Argpartition(1)(x), np.argpartition(x, 1))

Expand Down

0 comments on commit 81c0047

Please sign in to comment.