Skip to content

Commit 0f64cc7

Browse files
authored
Fix scatter scalar test (#122)
1 parent 29eac6e commit 0f64cc7

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

test/test_mps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4316,8 +4316,8 @@ def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, do_add=True)
43164316
helper((10, 3), 0, (5, 3), (5, 8), do_add=False)
43174317

43184318
# Test pytorch scatter_add and scatter for scalar input
4319-
def test_scatter_add(self):
4320-
def helper(shape, idx_dtype=torch.int64, do_add=True):
4319+
def test_scatter_add_scalar(self):
4320+
def helper(idx_dtype=torch.int64, do_add=True):
43214321
cpu_x = torch.tensor(2, device='cpu', dtype=torch.float, requires_grad=True)
43224322
x = cpu_x.detach().clone().to('mps').requires_grad_()
43234323

0 commit comments

Comments
 (0)