
Implement EnsembleModule #1359

Closed
wants to merge 10 commits into from

Conversation

@smorad (Contributor) commented Jul 5, 2023

Description

This adds support for module ensembles once pytorch/tensordict#478 lands. cc @vmoens, @matteobettini, @Acciorocketships.

Motivation and Context

This is necessary for implementing REDQ and various other Q learning algorithms that use ensembles at collection time. See #1344 for more.
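The ensemble-at-collection-time pattern this PR targets can be sketched with plain `torch.func` primitives (this is the generic PyTorch ensembling recipe, not the PR's `EnsembleModule` API; the network sizes and names below are illustrative assumptions):

```python
import torch
from torch import nn
from torch.func import functional_call, stack_module_state

def make_qnet():
    # Toy Q-network: 4-dim observation -> 2 action values (sizes assumed)
    return nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))

# Stack the parameters of N independent copies along a new leading dim
models = [make_qnet() for _ in range(3)]
params, buffers = stack_module_state(models)

# A "meta" skeleton supplies the architecture; weights come from `params`
base = make_qnet().to("meta")

def forward_one(p, b, x):
    return functional_call(base, (p, b), (x,))

x = torch.randn(8, 4)  # batch of 8 observations, shared across members
q_values = torch.vmap(forward_one, in_dims=(0, 0, None))(params, buffers, x)
print(q_values.shape)  # torch.Size([3, 8, 2]): (ensemble, batch, actions)
```

`EnsembleModule` builds essentially this vmap-over-stacked-parameters idea on top of a TensorDict of parameters, which is why the module is functional and needs its parameters passed in explicitly.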

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce?

  • New feature (non-breaking change which adds core functionality)

Checklist

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Jul 5, 2023
@smorad (Contributor, Author) commented Jul 10, 2023

I've rebased this upon the latest reset_parameters_recursive commit. The parameter tensordict approach to reset_parameters_recursive really simplifies this module, thanks for the tips!

@vmoens (Contributor) left a comment:

Superb, I love this feature!
It is meant to be part of RL, but there could be a usage for this in tensordict, wdyt?

Note to self: this should be integrated more broadly in the losses to alleviate the burden of expanding params for multiple Q-value nets.

@@ -0,0 +1,102 @@
import torch
from tensordict import TensorDict, TensorDictBase
Review comment (Contributor): missing headers

from torch import nn


class EnsembleModule(TensorDictModuleBase):
Review comment (Contributor): Should be added manually to the docs under docs/source/reference/modules.rst

torchrl/modules/tensordict_module/ensemble.py: further review threads (resolved)
Comment on lines 92 to 94
assert (
TensorDictBase is not None
), "Ensembles are functional and require passing a TensorDict of parameters to reset_parameters_recursive"
@vmoens (Contributor):
I guess we want to check that parameters are not None?
The function will break before this since we don't have a default value.

Other comment: We don't use assert in the lib, only in the tests. In this case, a TypeError or a ValueError would be appropriate

@smorad (Contributor, Author):
Oops, yes, it should be parameters, and this is indeed unreachable code. I could make the default parameters=None here, then raise an exception if parameters is None, but that seems confusing.

All I wanted was a descriptive error message because it will not be clear to the user why my_ensemble.reset_parameters_recursive() is failing (they need to explicitly pass in params).

@vmoens (Contributor):
I agree, let's make the default None and raise the ValueError if it is the default
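A minimal sketch of the fix agreed above (default `parameters=None`, raise a `ValueError` instead of using `assert`); the class name and error message follow the discussion, but the body is elided since the merged code may differ:

```python
class EnsembleModule:
    """Stripped-down stand-in for the PR's module (illustrative only)."""

    def reset_parameters_recursive(self, parameters=None):
        # Functional module: callers must supply a stacked params TensorDict.
        if parameters is None:
            raise ValueError(
                "Ensembles are functional and require passing a TensorDict "
                "of parameters to reset_parameters_recursive"
            )
        # ... re-initialize each ensemble member from `parameters` ...
        return parameters

# The explicit ValueError gives the descriptive failure mode smorad wanted:
try:
    EnsembleModule().reset_parameters_recursive()
except ValueError as err:
    print(f"raised: {err}")
```

Unlike the original `assert`, this check cannot be stripped by `python -O` and names the exact argument the caller forgot.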

"""Resets the parameters of all the copies of the module.

Args:
stacked_params_td: A TensorDict of parameters for self.module. The batch dimension(s) of the tensordict
Suggested change (Contributor):
stacked_params_td: A TensorDict of parameters for self.module. The batch dimension(s) of the tensordict
parameters: A TensorDict of parameters for self.module. The batch dimension(s) of the tensordict

@smorad (Contributor, Author) commented Jul 10, 2023

Superb I love this feature! It is meant to be part of RL, but there could be a usage for this in tensordict, wdyt?

Are you suggesting to move EnsembleModule into tensordict.nn? Yeah, I suppose that makes sense. Let's get it to the point where you are happy with it here, then I will abort this PR and move it to tensordict.nn.

smorad and others added 3 commits July 10, 2023 17:42
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
@smorad (Contributor, Author) commented Jul 11, 2023

Closing in favor of pytorch/tensordict#485

@smorad smorad closed this Jul 11, 2023