diff --git a/test/test_vmap.py b/test/test_vmap.py index 6dbec6a7e..6250c7ebc 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -3542,7 +3542,7 @@ def f3(x, idx): def test_nested_advanced_indexing(self, device): e = torch.rand(7, 4, device=device) - idx = torch.LongTensor([0, 1], device=device).view(2, 1) + idx = torch.tensor([0, 1], device=device).view(2, 1) # simple reference implementation for comparison def _fake_vmap(f, in_dims=0, out_dims=0):