Motivation
Ensembled modules should also support optimization through an API consistent with ordinary modules. However, passing EnsembleModule.parameters() to an optimizer, as one normally would, does not yield correct behavior.
```python
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import EnsembleModule, TensorDictModule

m = TensorDictModule(nn.Linear(128, 1), ["a"], ["a_out"])
m = EnsembleModule(m, 3, expand_input=True)
x = TensorDict({"a": torch.randn(32, 128)}, [32])

# this does not work
params = m.parameters()
# this works
params = list(m.params_td.values(True, True))
for param in params:
    param.retain_grad()  # cannot optimize non-leaf tensors

opt = torch.optim.Adam(params)
for i in range(10):
    y = m(x)
    loss = y["a_out"].sum()
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())
```
Solution
A direct solution would be to override EnsembleModule.named_parameters(), but that may lead to other inconsistencies. The issue seems to be with nn.Parameter, which shares data with the tensordict it is created from but does not receive gradients.
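Below is a minimal, self-contained sketch of the suspected failure mode (my own illustration, not taken from tensordict internals): two leaf tensors share the same storage, but only the one that actually participates in the forward graph accumulates a gradient, so the nn.Parameter handed to the optimizer never gets updated.

```python
import torch
import torch.nn as nn

# illustrative only: mimic a Parameter that shares storage with the tensor
# actually used in the forward pass
data = torch.randn(3)
p = nn.Parameter(data)            # what .parameters() would hand to the optimizer
t = p.detach().requires_grad_()   # second leaf, same storage, used in the graph

loss = (t * 2).sum()
loss.backward()

print(t.grad)  # tensor([2., 2., 2.])
print(p.grad)  # None -> the optimizer would silently never update p
```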
- [x] I have checked that there is no similar issue in the repo (required)
Thanks for reporting!
We need to make the params apparent within tensordict modules in some way, even if they're contained in the tensordict.
One option is to put them in a parameter list, but I'm not a fan of that solution...
I'm currently facing a similar problem with the loss modules in torchrl, where we hack our way into registering the params contained in a tensordict. I'll try to come up with an elegant solution for both and ping you once I'm done!
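For reference, here is a rough sketch of what "making the params apparent" could look like. This is my assumption of the general approach, not the actual torchrl/tensordict implementation; the helper name expose_tensordict_params is hypothetical. The key point is that the nn.Parameter must be written back into the tensordict so the forward pass uses the very object that is registered on the module.

```python
import torch.nn as nn

def expose_tensordict_params(module: nn.Module, params_td) -> None:
    # hypothetical helper: wrap each leaf of the params tensordict in an
    # nn.Parameter, write it back so the functional forward reads that exact
    # object, and register it so .parameters() / .named_parameters() see it
    for key, value in list(params_td.items(True, True)):  # nested keys, leaves only
        param = nn.Parameter(value.detach().clone())
        params_td.set(key, param)
        name = "_".join(key) if isinstance(key, tuple) else str(key)
        module.register_parameter(f"td_param_{name}", param)
```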