-
Notifications
You must be signed in to change notification settings - Fork 309
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Implement ensemble reduce #1360
base: main
Are you sure you want to change the base?
Conversation
@@ -16,17 +16,27 @@ | |||
UnboundedContinuousTensorSpec, | |||
) | |||
from torchrl.envs.utils import set_exploration_type, step_mdp | |||
from torchrl.modules import LSTMModule, NormalParamWrapper, SafeModule, TanhNormal | |||
from torchrl.modules import ( | |||
EnsembleModule, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from tensordict
return tensordict_reduced | ||
|
||
|
||
class EnsembleModule(TensorDictModuleBase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
duplicate with tensordict
self, | ||
in_keys: list[str], | ||
out_keys: list[str], | ||
reduce_function: Callable[[torch.Tensor], torch.Tensor] = lambda x: x.min(dim=0).values, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The best way to approach this is
reduce_function: Callable[[torch.Tensor], torch.Tensor] = lambda x: x.min(dim=0).values, | |
reduce_function: Callable[[torch.Tensor], torch.Tensor] = None, |
then
if reduce_function is None:
reduce_function = lambda x: x.min(dim=0).values
But do we really want to have "min"
as a default?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about Enum to select which op?
REDUCE_OP.min
REDUCE_OP.max
REDUCE_OP.sum
REDUCE_OP.any
REDUCE_OP.all
And then we can ask the user along which dim the op should be done.
That dim would also be the dim along which we do the tensordict indexing.
assert ( | ||
len(in_keys) == len(out_keys) == 1 | ||
), "Reduce only supports one input and one output" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be a ValueError
@smorad Friendly ping about this :) |
Description
This implements a reduction module for use with ensembles.
Motivation and Context
This relies on #1359 to make progress on #1344.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!