Implement EnsembleModule #1359
Conversation
I've rebased this upon the latest
Superb, I love this feature!
It is meant to be part of RL, but there could be a use for this in tensordict, wdyt?
Note to self: this should be integrated more broadly in the losses to alleviate the burden of expanding params for multiple Q-value nets.
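For context on the "expanding params" burden mentioned in the note above, here is a rough sketch (not the actual torchrl loss code) of how parameters get duplicated for several Q-value nets by hand, assuming a recent tensordict that provides TensorDict.from_module; qnet and num_qvalue_nets are illustrative names, not real loss attributes.

```python
from torch import nn
from tensordict import TensorDict

# Illustrative Q-value network and copy count.
qnet = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 1))
num_qvalue_nets = 2

params = TensorDict.from_module(qnet)      # capture the parameters as a TensorDict
stacked = params.expand(num_qvalue_nets)   # prepend a "copy" dimension to every leaf
stacked = stacked.clone()                  # expand() aliases storage; clone so the copies can diverge
print(stacked["0", "weight"].shape)        # torch.Size([2, 64, 4])
```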
@@ -0,0 +1,102 @@
import torch
from tensordict import TensorDict, TensorDictBase
missing headers
from torch import nn

class EnsembleModule(TensorDictModuleBase):
Should be added manually to the doc under docs/source/reference/modules.rst
assert (
    TensorDictBase is not None
), "Ensembles are functional and require passing a TensorDict of parameters to reset_parameters_recursive"
I guess we want to check that parameters are not None? The function will break before this since we don't have a default value.
Other comment: we don't use assert in the lib, only in the tests. In this case, a TypeError or a ValueError would be appropriate.
Oops yes, it should be parameters, and this is indeed unreachable code. I could make the default parameters=None here, then raise an exception if parameters is None, but that seems confusing.
All I wanted was a descriptive error message, because it will not be clear to the user why my_ensemble.reset_parameters_recursive() is failing (they need to explicitly pass in params).
I agree, let's make the default None and raise the ValueError if it is the default
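A minimal sketch of the agreed change (illustrative, not the merged code): give parameters a None default and raise a ValueError carrying the same descriptive message, rather than the assert.

```python
def reset_parameters_recursive(self, parameters=None):
    """Resets the parameters of all the copies of the module."""
    if parameters is None:
        raise ValueError(
            "Ensembles are functional and require passing a TensorDict of "
            "parameters to reset_parameters_recursive"
        )
    # ... reset logic operating on the provided parameters ...
```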
"""Resets the parameters of all the copies of the module. | ||
|
||
Args: | ||
stacked_params_td: A TensorDict of parameters for self.module. The batch dimension(s) of the tensordict |
Suggested change:
- stacked_params_td: A TensorDict of parameters for self.module. The batch dimension(s) of the tensordict
+ parameters: A TensorDict of parameters for self.module. The batch dimension(s) of the tensordict
Are you suggesting to move
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Closing in favor of pytorch/tensordict#485
Description
This adds support for module ensembles once pytorch/tensordict#478 lands. cc @vmoens, @matteobettini, @Acciorocketships.
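For illustration, a rough usage sketch of what the ensemble wrapper enables, assuming the constructor shape proposed in pytorch/tensordict#485 (roughly EnsembleModule(module, num_copies=...)); the exact API is defined there, not in this PR description.

```python
import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, EnsembleModule  # EnsembleModule assumed from pytorch/tensordict#485

# Wrap a single Q-network and replicate it into an ensemble of 3 copies.
qnet = TensorDictModule(nn.Linear(4, 2), in_keys=["obs"], out_keys=["action_value"])
ensemble = EnsembleModule(qnet, num_copies=3)

td = TensorDict({"obs": torch.randn(10, 4)}, batch_size=[10])
out = ensemble(td)  # all copies run on the same input; outputs gain a leading ensemble dimension
```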
Motivation and Context
This is necessary for implementing REDQ and various other Q-learning algorithms that use ensembles at collection time. See #1344 for more.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist