Pytorch vmap limitation #1031

Open
Ch0ronomato opened this issue Oct 12, 2024 · 5 comments


Ch0ronomato commented Oct 12, 2024

Description

If you call .item() (or inadvertently trigger a call to .item()) in torch, vmap will fail. This means supporting Ops like ScalarLoop is difficult whenever vmap is involved.

for i in steps: <--- this calls .item()

The error in question:

RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.

pytorch documentation: https://pytorch.org/functorch/stable/ux_limitations.html#data-dependent-operations-item
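
A minimal sketch of how this shows up (loop_fn is just a stand-in for what a compiled ScalarLoop-style function would do, not the actual backend code; torch.func.vmap is the torch >= 2.0 spelling):

import torch
from torch.func import vmap

def loop_fn(n_steps, x):
    # range(int(...)) calls Tensor.item() on n_steps, which vmap cannot trace
    acc = x
    for _ in range(int(n_steps)):
        acc = acc + 1.0
    return acc

n_steps = torch.tensor([2, 3, 4])
x = torch.zeros(3)
# vmap(loop_fn)(n_steps, x)  # raises the RuntimeError quoted above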


ricardoV94 commented Oct 14, 2024

If you call .item() (or inadvertently trigger a call to .item()) in torch, vmap will fail. This means supporting Ops like ScalarLoop is difficult whenever vmap is involved.

That may mean we cannot implement something with a batched number of steps. That's fine if we mention it as a limitation and raise NotImplementedError conditionally in the dispatch. We can know whether we need batched steps or not by checking if the type of the nsteps parameter has any non-broadcastable dimensions.
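
A rough sketch of that guard (the helper name and the assumption that the number-of-steps variable is the first input of the node are illustrative, not the actual dispatch code):

def _check_nsteps_not_batched(node):
    # Assumed for illustration: the number-of-steps variable is node.inputs[0].
    n_steps_type = node.inputs[0].type
    # Any non-broadcastable dimension means a per-batch number of steps,
    # which the torch backend cannot express (vmap would hit .item()).
    if not all(getattr(n_steps_type, "broadcastable", ())):
        raise NotImplementedError(
            "Batched number of steps is not supported in the PyTorch backend"
        )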


Ch0ronomato commented Oct 14, 2024 via email


ricardoV94 commented Oct 14, 2024

The ScalarLoop Op can't really be vectorized in that sense, so that may be fine. What does unbind do?


Ch0ronomato commented Oct 14, 2024

It just gives you back the slices along the dimension you specify (a tuple you can iterate over); it's almost like "breaking" the tensor at that dim.

https://pytorch.org/docs/stable/generated/torch.unbind.html

>>> import torch
>>> x = torch.concat(tuple(torch.tril(torch.ones(3, 3) * i) for i in range(3))).reshape(-1, 3, 3)
>>> print(x)
tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[1., 0., 0.],
         [1., 1., 0.],
         [1., 1., 1.]],

        [[2., 0., 0.],
         [2., 2., 0.],
         [2., 2., 2.]]])
>>> for t in x.unbind(0):
...     print(t.sum())
...
tensor(0.)
tensor(6.)
tensor(12.)

You can also stack the results the way we would with vmap; it just looks a little different.
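
For instance, reusing the x from above, stacking the results of an unbind loop gives the same tensor as vmap (a small sketch, using torch.func.vmap from torch >= 2.0):

import torch
from torch.func import vmap

x = torch.concat(tuple(torch.tril(torch.ones(3, 3) * i) for i in range(3))).reshape(-1, 3, 3)

looped = torch.stack([t.sum() for t in x.unbind(0)])  # tensor([ 0.,  6., 12.])
batched = vmap(torch.sum)(x)                          # tensor([ 0.,  6., 12.])
assert torch.equal(looped, batched)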


Ch0ronomato commented Oct 15, 2024

I didn't give an example of the stacking; here it is:

>>> import torch
>>> x = torch.concat([torch.ones((1, 3, 3)) * i for i in range(1, 5)]).repeat((5, 1, 1, 1))
>>> x.shape
torch.Size([5, 4, 3, 3])
>>> counter = 0
>>> for t1 in x.unbind(0):
...     for t2 in t1.unbind(0):
...             print(counter, t2.shape)
...             counter += 1
... 
0 torch.Size([3, 3])
1 torch.Size([3, 3])
2 torch.Size([3, 3])
3 torch.Size([3, 3])
4 torch.Size([3, 3])
5 torch.Size([3, 3])
6 torch.Size([3, 3])
7 torch.Size([3, 3])
8 torch.Size([3, 3])
9 torch.Size([3, 3])
10 torch.Size([3, 3])
11 torch.Size([3, 3])
12 torch.Size([3, 3])
13 torch.Size([3, 3])
14 torch.Size([3, 3])
15 torch.Size([3, 3])
16 torch.Size([3, 3])
17 torch.Size([3, 3])
18 torch.Size([3, 3])
19 torch.Size([3, 3])
