[Feature] Implement ensemble reduce #1360

Open

smorad wants to merge 3 commits into base: main

Conversation

smorad (Contributor) commented on Jul 5, 2023

Description

This implements a reduction module for use with ensembles.
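
For illustration, this is the kind of reduction the module performs (a minimal sketch in plain torch; the stacked-values setup is assumed here, and the min over the ensemble dim mirrors the default discussed in the review below):

    import torch

    # Three ensemble members each produce four value estimates; reducing
    # over the ensemble dim (dim=0) with min gives one conservative estimate.
    values = torch.stack([torch.randn(4) for _ in range(3)], dim=0)  # [3, 4]
    reduced = values.min(dim=0).values                               # [4]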

Motivation and Context

This relies on #1359 to make progress on #1344.

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot added the CLA Signed label on Jul 5, 2023
    @@ -16,17 +16,27 @@
         UnboundedContinuousTensorSpec,
     )
     from torchrl.envs.utils import set_exploration_type, step_mdp
    -from torchrl.modules import LSTMModule, NormalParamWrapper, SafeModule, TanhNormal
    +from torchrl.modules import (
    +    EnsembleModule,
Contributor (review comment on the new import): This should be imported from tensordict.

        return tensordict_reduced


    class EnsembleModule(TensorDictModuleBase):
Contributor (review comment on EnsembleModule): This duplicates the EnsembleModule in tensordict.

        self,
        in_keys: list[str],
        out_keys: list[str],
        reduce_function: Callable[[torch.Tensor], torch.Tensor] = lambda x: x.min(dim=0).values,
Contributor (review comment on the default): The best way to approach this is:

Suggested change:

    -reduce_function: Callable[[torch.Tensor], torch.Tensor] = lambda x: x.min(dim=0).values,
    +reduce_function: Callable[[torch.Tensor], torch.Tensor] = None,

then

if reduce_function is None:
    reduce_function = lambda x: x.min(dim=0).values
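
Spelled out inside the module, that pattern would look roughly like this (a sketch with a hypothetical stand-in class, not the PR's exact code):

    from collections.abc import Callable

    import torch

    class ReduceModule:  # hypothetical stand-in for the PR's module
        def __init__(
            self,
            in_keys: list[str],
            out_keys: list[str],
            reduce_function: Callable[[torch.Tensor], torch.Tensor] | None = None,
        ):
            # Resolve the default here instead of in the signature.
            if reduce_function is None:
                reduce_function = lambda x: x.min(dim=0).values
            self.in_keys = in_keys
            self.out_keys = out_keys
            self.reduce_function = reduce_function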

But do we really want to have "min" as a default?

Contributor (review comment): What about an Enum to select which op?

    REDUCE_OP.min
    REDUCE_OP.max
    REDUCE_OP.sum
    REDUCE_OP.any
    REDUCE_OP.all

And then we can ask the user which dim the op should be applied along.
That dim would also be the dim along which we index the tensordict.
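
A minimal sketch of that idea (the ReduceOp name and the dispatch table are assumptions for illustration, not part of this PR):

    from enum import Enum

    import torch

    class ReduceOp(Enum):
        MIN = "min"
        MAX = "max"
        SUM = "sum"
        ANY = "any"
        ALL = "all"

    # Each op reduces over a user-chosen dim; the same dim would drive
    # the tensordict indexing mentioned above.
    _REDUCE_FNS = {
        ReduceOp.MIN: lambda x, dim: x.min(dim=dim).values,
        ReduceOp.MAX: lambda x, dim: x.max(dim=dim).values,
        ReduceOp.SUM: lambda x, dim: x.sum(dim=dim),
        ReduceOp.ANY: lambda x, dim: x.any(dim=dim),
        ReduceOp.ALL: lambda x, dim: x.all(dim=dim),
    }

    def reduce(x: torch.Tensor, op: ReduceOp, dim: int = 0) -> torch.Tensor:
        return _REDUCE_FNS[op](x, dim)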

Comment on lines +52 to +54:

    assert (
        len(in_keys) == len(out_keys) == 1
    ), "Reduce only supports one input and one output"
Contributor (review comment): This should be a ValueError.
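
Something like the following would do it (a direct rewrite of the assert above):

    if len(in_keys) != 1 or len(out_keys) != 1:
        raise ValueError("Reduce only supports one input and one output")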

@vmoens changed the title from "Implement ensemble reduce" to "[Feature] Implement ensemble reduce" on Aug 30, 2023
vmoens (Contributor) commented on Sep 7, 2023

@smorad Friendly ping about this :)
