[Feature] Support for torch.autograd.grad
#1417
Conversation
Great feature! Why would lazy TDs not work? I think they should (tensors in the lazy stack can be leaves of the graph).
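For reference, a minimal sketch of that point (assuming a tensordict version where scalar arithmetic such as `td + 1` works, as in the snippets below): the tensors stored in each sub-tensordict of a lazy stack are ordinary autograd leaves, so gradients flow back to them.

```python
import torch
from tensordict import TensorDict, LazyStackedTensorDict

# Two sub-tensordicts whose tensors are autograd leaves.
td0 = TensorDict({"a": torch.ones(1, requires_grad=True)}, batch_size=[])
td1 = TensorDict({"a": torch.zeros(1, requires_grad=True)}, batch_size=[])
td = LazyStackedTensorDict(td0, td1, stack_dim=0)

out = td + 1                 # applied leaf-wise on each sub-tensordict
out["a"].sum().backward()    # out["a"] stacks the per-leaf results

print(td0["a"].grad, td1["a"].grad)  # both leaves receive a gradient of 1
```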
Okay thanks, I will work on that next week! As for the lazy ones, I just don't know how they are handled yet, but I will dig into it.
hello
Also took some time off, but I'll work on that on Friday. And thanks for the proposition, I might need help ^^
@vmoens I have an issue with:

```python
import torch
from tensordict import TensorDict, LazyStackedTensorDict

td = LazyStackedTensorDict(TensorDict(), TensorDict(), stack_dim=0)
td["a"] = [torch.ones(1), torch.zeros(1)]
td.requires_grad_()
out = td + 1
torch.autograd.grad(out["a"], td["a"], torch.ones_like(out["a"]))
```
I also put in a small bit of docs.
vmoens left a comment
Good progress!
Can you investigate these two errors?

```
FAILED test/test_tensordict.py::TestTensorDicts::test_autograd_grad[nested_stacked_td-device3] - RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
FAILED test/test_tensordict.py::TestTensorDicts::test_autograd_grad[stacked_td-device4] - RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
```

Also a similar test with tensorclass would be nice!
This is due to the mixed computation graph introduced by

```python
td = LazyStackedTensorDict(TensorDict(), TensorDict(), stack_dim=0)
td["a"] = [torch.ones(1), torch.zeros(1)]
td.requires_grad_()
out = td + 1
torch.autograd.grad(out["a"], td["a"], torch.ones_like(out["a"]))
```

The graph behind `(td + 1)["a"]` is: `td[0]["a"] -> td[0]["a"] + 1`; `td[1]["a"] -> td[1]["a"] + 1`; `(td[0]["a"] + 1, td[1]["a"] + 1) -> stack(...)`.
But `td["a"]`, i.e. `(td[0]["a"], td[1]["a"]) -> stack(...)`, never appears in the previous computation graph.
You could do
I see, thanks. I'll split all the lazy stacks then and merge them back after.
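For illustration only (a hedged sketch of the idea, not the PR's implementation), the split-and-merge approach could look roughly like this, reaching the stored sub-tensordicts through the lazy stack's `tensordicts` attribute:

```python
import torch
from tensordict import TensorDict, LazyStackedTensorDict

td = LazyStackedTensorDict(TensorDict(), TensorDict(), stack_dim=0)
td["a"] = [torch.ones(1), torch.zeros(1)]
td.requires_grad_()
out = td + 1

# Split: differentiate w.r.t. the true leaves held by each sub-tensordict.
leaf_inputs = tuple(sub_td["a"] for sub_td in td.tensordicts)
grads = torch.autograd.grad(out["a"], leaf_inputs, torch.ones_like(out["a"]))

# Merge back: rebuild a gradient tensordict with the same lazy structure.
grad_td = LazyStackedTensorDict(
    *(TensorDict({"a": g}, batch_size=[]) for g in grads), stack_dim=0
)
```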
@vmoens I made it work but it feels hacky, let me know what you think.
I used a simpler approach, can you check that it makes sense?
It definitely makes sense and is much simpler! But I am not following:
Do you want to add support for passing multiple tensordicts as inputs/outputs? Or do you mean when the fields are of different types? I think the former could be supported using nested TensorDicts (but I wanted to get the single-grad case to work first); for the latter, I don't see the problem. Could you share a quick snippet?
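In case a concrete picture helps, a small sketch of what grouping several tensordicts under one nested TensorDict could mean (the names here are hypothetical, not the PR's API):

```python
import torch
from tensordict import TensorDict

td1 = TensorDict({"x": torch.randn(3, requires_grad=True)}, batch_size=[3])
td2 = TensorDict({"y": torch.randn(3, requires_grad=True)}, batch_size=[3])

# One nested tensordict can carry both, so a single `inputs` argument suffices.
inputs = TensorDict({"first": td1, "second": td2}, batch_size=[3])

out = inputs["first", "x"] * 2 + inputs["second", "y"]
grads = torch.autograd.grad(out, (td1["x"], td2["y"]), torch.ones_like(out))
```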
I mean that if you have
Gotcha, yes, it is an edge case, I think. I might have an example where it could be useful though, and I'd say outputs and grad_outputs often do have the same shape. But I am not getting why converting to lazy would do the trick: will it split the stack dimension of the full tensordict? And even two stacked tensordicts could have the same shape but not be stacked similarly, right? I think we should keep tuples, since the issue was only with the inputs, i.e. keep something like this:

```python
if grad_outputs is not None:
    tup_grad_outputs = tuple(grad_outputs[k] for k in outputs.keys(True, True))
else:
    tup_grad_outputs = None
tup_outputs = tuple(outputs[k] for k in outputs.keys(True, True))
```
It also ensures the same order of outputs and grad_outputs.
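Putting the last two comments together, a hedged sketch of the tuple-based flow (the scalar arithmetic and names are assumptions for the example, not the PR's code). Iterating `outputs.keys(True, True)` once for both tuples keeps outputs and grad_outputs aligned pairwise:

```python
import torch
from tensordict import TensorDict

inputs = TensorDict({"a": torch.ones(3, requires_grad=True)}, batch_size=[3])
outputs = inputs * 2
grad_outputs = outputs.apply(torch.ones_like)

# A single pass over outputs.keys(True, True) fixes the order for both tuples.
keys = list(outputs.keys(True, True))
tup_outputs = tuple(outputs[k] for k in keys)
tup_grad_outputs = tuple(grad_outputs[k] for k in keys)
tup_inputs = tuple(inputs[k] for k in inputs.keys(True, True))

grads = torch.autograd.grad(tup_outputs, tup_inputs, tup_grad_outputs)

# Reassemble the gradients under the same keys as the inputs.
grad_td = TensorDict(dict(zip(inputs.keys(True, True), grads)), batch_size=inputs.batch_size)
```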
Description
Implement a `torch.autograd.grad` for TensorDict. It is still in progress, but comments are welcome. Especially on:
Motivation and Context
Closes #1416
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!