Pytorch vmap limitation #1031
Comments
That may mean we cannot implement something with a batched number of steps. That's fine if we mention it as a limitation and raise `NotImplementedError` conditionally in the dispatch. We can know whether we need batched steps by checking if the type of the nsteps parameter has any non-broadcastable dimensions (a rough sketch of that check follows).
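A rough sketch of that check, with hypothetical names (this is not PyTensor's actual dispatch code) and assuming the nsteps variable exposes a `type.broadcastable` tuple the way PyTensor tensor types do:

```python
def funcify_scalar_loop(op, node, **kwargs):
    nsteps = node.inputs[0]  # assumed position of the nsteps input
    # A False entry in `broadcastable` marks a dimension that may be larger
    # than one, i.e. the number of steps could differ across the batch.
    if any(not b for b in getattr(nsteps.type, "broadcastable", ())):
        raise NotImplementedError(
            "torch backend: a batched number of steps is not supported, "
            "because torch.vmap cannot trace the data-dependent .item() call"
        )
    ...  # otherwise build the torch implementation as usual
```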
There is the alternative of using `.unbind` and then just looping, but it's not really vectorized at that point (a minimal sketch of that fallback follows).
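A minimal sketch of that fallback (illustrative names, not actual backend code): unbind the batched inputs and run the loop element by element, trading vectorization for compatibility.

```python
import torch

def batched_scalar_loop(nsteps, init):
    # Explicit Python loop over the batch dimension instead of torch.vmap,
    # so each element may take a different, data-dependent number of steps.
    results = []
    for n, x in zip(nsteps.unbind(0), init.unbind(0)):
        acc = x
        for _ in range(int(n.item())):  # .item() is fine outside of vmap
            acc = acc + 1.0
        results.append(acc)
    return torch.stack(results)

print(batched_scalar_loop(torch.tensor([1, 2, 3]), torch.zeros(3)))
# tensor([1., 2., 3.])
```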
The ScalarLoop Op can't really be vectorized in that sense, so that may be fine. What does unbind do?
It'll just give you an iterator over the dimension you specify; almost like it's "breaking" the tensor at that dim. https://pytorch.org/docs/stable/generated/torch.unbind.html

```python
x = torch.concat(tuple(torch.tril(torch.ones(3, 3) * i) for i in range(3))).reshape(-1, 3, 3)
print(x)
# tensor([[[0., 0., 0.],
#          [0., 0., 0.],
#          [0., 0., 0.]],
#
#         [[1., 0., 0.],
#          [1., 1., 0.],
#          [1., 1., 1.]],
#
#         [[2., 0., 0.],
#          [2., 2., 0.],
#          [2., 2., 2.]]])

for t in x.unbind(0):
    print(t.sum())
# tensor(0.)
# tensor(6.)
# tensor(12.)
```

You can also stack it how we would vmap; it's just a little different looking.
I didn't give an example of it stacking, here it is:

```python
>>> x = torch.concat([torch.ones((1, 3, 3)) * i for i in range(1, 5)]).repeat((5, 1, 1, 1))
>>> x.shape
torch.Size([5, 4, 3, 3])
>>> counter = 0
>>> for t1 in x.unbind(0):
...     for t2 in t1.unbind(0):
...         print(counter, t2.shape)
...         counter += 1
...
0 torch.Size([3, 3])
1 torch.Size([3, 3])
2 torch.Size([3, 3])
3 torch.Size([3, 3])
4 torch.Size([3, 3])
5 torch.Size([3, 3])
6 torch.Size([3, 3])
7 torch.Size([3, 3])
8 torch.Size([3, 3])
9 torch.Size([3, 3])
10 torch.Size([3, 3])
11 torch.Size([3, 3])
12 torch.Size([3, 3])
13 torch.Size([3, 3])
14 torch.Size([3, 3])
15 torch.Size([3, 3])
16 torch.Size([3, 3])
17 torch.Size([3, 3])
18 torch.Size([3, 3])
19 torch.Size([3, 3])
```
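For comparison, a small sketch (not from the thread) of how `torch.vmap` covers the same two batch dimensions without the explicit Python loops:

```python
import torch

x = torch.concat([torch.ones((1, 3, 3)) * i for i in range(1, 5)]).repeat((5, 1, 1, 1))

# Nesting vmap maps torch.sum over the two leading batch dimensions, giving
# one sum per (3, 3) slice, analogous to the nested unbind loops above.
per_slice_sums = torch.vmap(torch.vmap(torch.sum))(x)
print(per_slice_sums.shape)  # torch.Size([5, 4])
```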
Description
If you call `.item()` (or inadvertently call `.item()`) in torch, vmap will fail. This means supporting things like ScalarLoop is difficult when anything is vmapped.

The error in question is the data-dependent-operation limitation described in the pytorch documentation: https://pytorch.org/functorch/stable/ux_limitations.html#data-dependent-operations-item
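A minimal reproduction of that limitation (a sketch, not actual PyTensor-generated code): any data-dependent `.item()` call inside the vmapped function triggers the failure.

```python
import torch

def loop_with_data_dependent_steps(nsteps):
    # The trip count depends on the tensor's value, which forces a .item()
    # call; torch.vmap cannot trace through this.
    acc = torch.zeros(())
    for _ in range(int(nsteps.item())):
        acc = acc + 1.0
    return acc

nsteps = torch.tensor([3, 5, 2])
try:
    torch.vmap(loop_with_data_dependent_steps)(nsteps)
except RuntimeError as exc:
    # vmap rejects .item() on a batched tensor (a data-dependent operation)
    print(exc)
```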