Skip to content

Commit

Permalink
fix test device loc, add nested version
Browse files Browse the repository at this point in the history
  • Loading branch information
samdow committed May 9, 2022
1 parent f313ca4 commit 5edf834
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions test/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3529,15 +3529,43 @@ def f2(x, idx):
def f3(x, idx):
return x[:, :, idx]

inps = (torch.randn(5, 5, 5), torch.randn(5, 5, 5, 5), torch.randn(5, 5, 5, 5, 5))
idxes = (torch.tensor([0, 1, 2]),
torch.tensor([0, 1, 2]).reshape(3, 1),
torch.tensor([0, 1, 2]).reshape(3, 1, 1))
inps = (torch.randn(5, 5, 5, device=device),
torch.randn(5, 5, 5, 5, device=device),
torch.randn(5, 5, 5, 5, 5, device=device))
idxes = (torch.tensor([0, 1, 2], device=device),
torch.tensor([0, 1, 2], device=device).reshape(3, 1),
torch.tensor([0, 1, 2], device=device).reshape(3, 1, 1))
for (inp, idx) in itertools.product(inps, idxes):
test(f, (inp, idx))
test(f2, (inp, idx))
test(f3, (inp, idx))

def test_nested_advanced_indexing(self, device):
e = torch.rand(7, 4, device=device)
idx = torch.LongTensor([0, 1], device=device).view(2, 1)

# simple reference implementation for comparison
def _fake_vmap(f, in_dims=0, out_dims=0):
def w(input):
r = [f(input.select(in_dims, i)) for i in range(input.size(in_dims))]
return torch.stack(r, out_dims)

return w

def with_vmap(_vmap):
def g(idx_):
def f(e_):
return e_[idx_]

return _vmap(f, in_dims=1)(e)

r = _vmap(g)(idx)
return r

a = with_vmap(vmap)
b = with_vmap(_fake_vmap)
self.assertEqual(a, b)


class TestRandomness(TestCase):
def _reset_random(self, generator, orig_state, use_generator, seed):
Expand Down

0 comments on commit 5edf834

Please sign in to comment.