[RLlib] New API stack: (Multi)RLModule overhaul vol 03 (Introduce generic `_forward` to further simplify the user experience). #47889
Conversation
LGTM. Some nits in the docstrings.
action_dist_class_exploration = (
    self.module[module_id].unwrapped().get_exploration_action_dist_cls()
)
action_dist_class_train = module.get_train_action_dist_cls()
Don't we need to use `unwrapped` in case DDP is used?
Great question! DDP already wraps this method to use the `unwrapped` underlying RLModule, so this is ok here.
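For context, the delegation being discussed can be sketched as follows. This is a hypothetical illustration, not RLlib's actual DDP wrapper; the class name `DDPWrapperSketch` is made up, and only `unwrapped()` and the getter mirror the API mentioned in the thread.

```python
# Hypothetical sketch: a DDP-style wrapper that forwards this method to the
# wrapped RLModule itself, so call sites don't need an explicit unwrapped().
class DDPWrapperSketch:
    def __init__(self, rl_module):
        self._rl_module = rl_module  # the actual RLModule being wrapped

    def unwrapped(self):
        # Return the underlying (non-DDP) RLModule.
        return self._rl_module

    def get_exploration_action_dist_cls(self):
        # The wrapper delegates to the unwrapped module, which is why the
        # reviewed code is safe without calling unwrapped() at the call site.
        return self.unwrapped().get_exploration_action_dist_cls()
```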
@@ -91,12 +89,14 @@ def possibly_masked_mean(data_):

        # Compute a value function loss.
        if config.use_critic:
-           value_fn_out = fwd_out[Columns.VF_PREDS]
+           value_fn_out = module.compute_values(
I wonder if this again causes problems in the DDP case. I remember similar problems with CQL and SAC when not running everything in `forward_train`, but I guess the problem there was that `forward_train` was run multiple times. So, my guess: it works here.
Yeah, good point, I think you are right. Let's see what the tests say ...
encoder_outs = self.encoder(batch)
output[Columns.FEATURES] = encoder_outs[ENCODER_OUT][CRITIC]
Imo, `features` is a misleading term here, as features are usually the inputs to a neural network (or model in general). `embeddings` might fit better.
You are right! Changed everywhere to `Columns.EMBEDDINGS` and argument name: `compute_values(self, batch, embedding=None)`.
    batch: Dict[str, Any],
    features: Optional[Any] = None,
) -> TensorType:
    if features is None:
Why not put `features` into `batch` instead of passing it in as an extra argument?
Good question. This would mean that we would have to change the batch (add a new key to it) during the update procedure, which might clash when we have to (torch-)compile this operation. We had the same problem with the tf static graph.
Also, design-wise, I think it's cleaner not to change the batch after it comes out of a connector pipeline. Separation of concerns: Only connector pipelines are ever allowed to write to a batch:
connector -> train_batch # <- read-only from here on
fwd_out = rl_module.forward_train(train_batch)
losses = rl_module.compute_losses(train_batch, fwd_out)
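The design above can be sketched in a toy form. This is an illustration only, not RLlib's implementation: the method and argument names (`compute_values`, `embeddings`) follow the discussion, while the encoder and value-head bodies are made-up stand-ins. The point is that the batch stays read-only and precomputed embeddings travel as an optional argument.

```python
from typing import Any, Dict, List, Optional


class ValueBranchSketch:
    def _encode(self, batch: Dict[str, Any]) -> List[float]:
        # Stand-in for the encoder forward pass.
        return [2.0 * x for x in batch["obs"]]

    def compute_values(
        self,
        batch: Dict[str, Any],
        embeddings: Optional[List[float]] = None,
    ) -> List[float]:
        # If no embeddings were handed in (e.g. reused from forward_train),
        # run the encoder here. The batch itself is never mutated.
        if embeddings is None:
            embeddings = self._encode(batch)
        # Stand-in for the value head.
        return [e + 1.0 for e in embeddings]
```

Passing `embeddings` explicitly lets the loss reuse the train-forward's encoder output without writing a new key into the connector-produced batch.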
batch: The batch of multi-agent data (i.e. mapping from module ids to
    individual modules' batches).

def items(self) -> ItemsView[ModuleID, RLModule]:
    """Returns a keys view over the module IDs in this MultiRLModule."""
"keys" -> "items"
) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]:
    """Runs the forward_exploration pass.

def values(self) -> ValuesView[ModuleID]:
    """Returns a keys view over the module IDs in this MultiRLModule."""
"keys" -> "values"
Great catch! Fixed for `values()` as well.
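The corrected docstrings can be illustrated with a minimal, dict-backed sketch (the internal attribute name `_rl_modules` is assumed here for illustration):

```python
from typing import Dict, ItemsView, KeysView, ValuesView


class MultiRLModuleViewsSketch:
    def __init__(self, rl_modules: Dict[str, object]):
        self._rl_modules = rl_modules

    def keys(self) -> KeysView:
        """Returns a keys view over the module IDs in this MultiRLModule."""
        return self._rl_modules.keys()

    def items(self) -> ItemsView:
        """Returns an items view over (ModuleID, RLModule) pairs."""
        return self._rl_modules.items()

    def values(self) -> ValuesView:
        """Returns a values view over the RLModules in this MultiRLModule."""
        return self._rl_modules.values()
```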
By default, RLlib assumes that the module is non-recurrent if the initial
state is an empty dict and recurrent otherwise.
This behavior can be overridden by implementing this method.
Note that RLlib's distribution classes all implement the `Distribution`
Very nice! This makes it clear why!
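The default described in that docstring reduces to a one-line check. A hedged sketch (the method names `get_initial_state` and `is_stateful` are taken from the discussion; the class bodies are stand-ins):

```python
class RecurrenceSketch:
    def get_initial_state(self) -> dict:
        # Non-recurrent modules return an empty dict by default.
        return {}

    def is_stateful(self) -> bool:
        # Default heuristic: recurrent iff the initial state is non-empty.
        # Subclasses can override this method to change the behavior.
        return bool(self.get_initial_state())


class LSTMLikeSketch(RecurrenceSketch):
    def get_initial_state(self) -> dict:
        # A recurrent module publishes a non-empty initial state.
        return {"h": [0.0], "c": [0.0]}
```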
values = self._values(features).squeeze(-1)

# Same logic as _forward, but also return features to be used by value function
# branch during training.
features, state_outs = self._compute_features_and_state_outs(batch)
As before, in my opinion "features" is a misleading name, as it is usually used for the inputs of a neural network.
Fixed everywhere. Great catch and suggestion! Makes things much clearer.
…odule_do_over_bc_default_module_03_common_forward
…eric `_forward` to further simplify the user experience). (ray-project#47889) Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
    self.module[module_id].unwrapped().get_exploration_action_dist_cls()
)
action_dist_class_train = module.get_train_action_dist_cls()
action_dist_class_exploration = module.get_exploration_action_dist_cls()
Hey, I have a question here: shouldn't the `exploration` or `inference` dist be used, in a similar fashion to the `GetActions` connector's logic? This affects the KL-loss calculation, which might end up using a different distribution class (`exploration_dist`) than the one used for the surrogate loss (`inference_dist`). It is somewhat of an edge case, since the two are actually the same as per TorchRLModule, but users sub-classing it would be unaware.
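The edge case can be made concrete with a toy sketch. All class names and the returned strings are stand-ins, not RLlib's actual distribution classes: by default both getters resolve to the same class, so a subclass that overrides only one of them silently changes which distribution a KL term uses.

```python
class DistDefaultsSketch:
    train_dist_cls = "CategoricalLike"  # stand-in for a distribution class

    def get_train_action_dist_cls(self):
        return self.train_dist_cls

    def get_exploration_action_dist_cls(self):
        # Default: identical to the train dist class ...
        return self.get_train_action_dist_cls()


class CustomExplorationSketch(DistDefaultsSketch):
    # ... but a subclass can diverge, so a KL loss computed with the
    # exploration class no longer matches a surrogate loss using the
    # train class.
    def get_exploration_action_dist_cls(self):
        return "SquashedGaussianLike"
```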
New API stack: (Multi)RLModule overhaul vol 03 (Introduce generic `_forward` to further simplify the user experience).
- New generic `_forward` method to be used by RLModule subclasses (by default, all `_forward_[inference|exploration|train]` call this).
- Override `_forward_[inference|exploration|train]` to individualize behavior for the different algo phases.

Why are these changes needed?
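The dispatch pattern named in the PR title can be sketched as follows. This is a simplified illustration, not RLlib's actual implementation; only the method names `_forward` and `_forward_[inference|exploration|train]` come from the PR.

```python
class ForwardDispatchSketch:
    def _forward(self, batch, **kwargs):
        # Implement the common model pass once, here.
        return {"out": batch["obs"]}

    # By default, all three phase-specific methods call the generic _forward;
    # subclasses override them only to individualize a particular phase.
    def _forward_inference(self, batch, **kwargs):
        return self._forward(batch, **kwargs)

    def _forward_exploration(self, batch, **kwargs):
        return self._forward(batch, **kwargs)

    def _forward_train(self, batch, **kwargs):
        return self._forward(batch, **kwargs)
```

With this layout, a user who only needs one behavior implements `_forward` once, and overrides a phase-specific method only when e.g. exploration should differ from inference.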
Related issue number
Checks
- I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.