
[RLlib] New API stack: (Multi)RLModule overhaul vol 03 (Introduce generic _forward to further simplify the user experience). #47889

Conversation

sven1977 (Contributor) commented Oct 3, 2024

New API stack: (Multi)RLModule overhaul vol 03 (Introduce generic _forward to further simplify the user experience).

  • Adds a generic `_forward` method to be used by RLModule subclasses (by default, all of `_forward_[inference|exploration|train]` call this).
  • Users can still override `_forward_[inference|exploration|train]` to individualize behavior for the different algo phases.
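The delegation pattern described above can be sketched as follows. This is a minimal stand-in, not RLlib's actual class; the class names, method bodies, and toy computation are illustrative only:

```python
from typing import Any, Dict


class MyRLModule:
    """Minimal stand-in for RLlib's RLModule base class (illustrative only)."""

    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        # Default logic shared by inference, exploration, and training.
        return {"out": [x * 2 for x in batch["obs"]]}

    # By default, all three phase-specific methods delegate to _forward.
    def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        return self._forward(batch, **kwargs)

    def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        return self._forward(batch, **kwargs)

    def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        return self._forward(batch, **kwargs)


class MyExploringModule(MyRLModule):
    # Users can still specialize an individual phase.
    def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        out = self._forward(batch, **kwargs)
        out["explore"] = True
        return out
```

A module that only needs one shared computation now implements a single method instead of three.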

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: sven1977 <svenmika1977@gmail.com>
@simonsays1980 (Collaborator) left a comment:

LGTM. Some nits in the docstrings.

action_dist_class_exploration = (
self.module[module_id].unwrapped().get_exploration_action_dist_cls()
)
action_dist_class_train = module.get_train_action_dist_cls()
Collaborator:

Don't we need to use `unwrapped` in case DDP is used?

Contributor Author (sven1977):

Great question! DDP already wraps this method to use the unwrapped underlying RLModule, so this is fine here.
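The delegation the author describes might look roughly like this. A hedged sketch only: `DDPWrapper` and its methods are invented for illustration and are not RLlib's actual DDP wrapper code:

```python
class DDPWrapper:
    """Illustrative DDP-style wrapper that delegates to the wrapped module."""

    def __init__(self, module):
        self._module = module

    def unwrapped(self):
        # Returns the underlying (non-DDP) RLModule.
        return self._module

    def get_exploration_action_dist_cls(self):
        # The wrapper forwards this call to the unwrapped module, so
        # callers don't need to call unwrapped() themselves.
        return self._module.get_exploration_action_dist_cls()
```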

@@ -91,12 +89,14 @@ def possibly_masked_mean(data_):

# Compute a value function loss.
if config.use_critic:
value_fn_out = fwd_out[Columns.VF_PREDS]
value_fn_out = module.compute_values(
Collaborator:

I wonder if this again causes problems in the DDP case. I remember similar problems with CQL and SAC when not running everything in forward_train, but I think the issue there was that forward_train was run multiple times. So my guess: it works here.

Contributor Author (sven1977):

Yeah, good point, I think you are right. Let's see what the tests say ...

encoder_outs = self.encoder(batch)
output[Columns.FEATURES] = encoder_outs[ENCODER_OUT][CRITIC]
Collaborator:

IMO, "features" is a misleading term here, as features are usually the inputs to a neural network (or a model in general). "embeddings" might fit better.

Contributor Author (sven1977):

You are right! Changed everywhere to `Columns.EMBEDDINGS` and renamed the argument: `compute_values(self, batch, embeddings=None)`.
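The renamed API could be sketched like this (a toy illustration; the encoder and value head are stand-ins, and the exact RLlib signature may differ): passing cached embeddings lets compute_values skip re-running the encoder.

```python
from typing import Any, Dict, List, Optional


class ValueModule:
    """Toy module illustrating embedding reuse in compute_values."""

    def encoder(self, batch: Dict[str, Any]) -> List[float]:
        # Stand-in encoder: offsets each observation.
        return [o + 1.0 for o in batch["obs"]]

    def _values(self, embeddings: List[float]) -> List[float]:
        # Stand-in value head.
        return [e * 0.5 for e in embeddings]

    def compute_values(
        self,
        batch: Dict[str, Any],
        embeddings: Optional[List[float]] = None,
    ) -> List[float]:
        if embeddings is None:
            # No cached embeddings passed in -> run the encoder.
            embeddings = self.encoder(batch)
        return self._values(embeddings)
```

During training, embeddings already computed in forward_train can be handed in directly, avoiding a second encoder pass.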

batch: Dict[str, Any],
features: Optional[Any] = None,
) -> TensorType:
if features is None:
Collaborator:

Why not use `features` inside the batch, instead of passing it in as an extra argument?

Contributor Author (sven1977):

Good question. That would mean we'd have to change the batch (add a new key to it) during the update procedure, which might clash when we have to (torch-)compile this operation. We had the same problem with the tf static graph.
Also, design-wise, I think it's cleaner not to change the batch after it comes out of a connector pipeline. Separation of concerns: only connector pipelines are ever allowed to write to a batch:

connector -> train_batch  # <- read-only from here on
fwd_out = rl_module.forward_train(train_batch)
losses = rl_module.compute_losses(train_batch, fwd_out)
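One way to make that read-only contract explicit (a sketch only, not how RLlib enforces it) is to hand downstream code an immutable view of the batch:

```python
from types import MappingProxyType

# Connector pipelines build the batch; downstream code gets a read-only view.
train_batch = MappingProxyType({"obs": [0.1, 0.2]})

try:
    train_batch["embeddings"] = None  # any write attempt...
except TypeError:
    print("batch is read-only after the connector")  # ...fails fast
```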

batch: The batch of multi-agent data (i.e. mapping from module ids to
individual modules' batches).
def items(self) -> ItemsView[ModuleID, RLModule]:
"""Returns a keys view over the module IDs in this MultiRLModule."""
Collaborator:

"keys" -> "items"

) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]:
"""Runs the forward_exploration pass.
def values(self) -> ValuesView[ModuleID]:
"""Returns a keys view over the module IDs in this MultiRLModule."""
Collaborator:

"keys" -> "values"

Contributor Author (sven1977):

Great catch! Fixed for values() as well.

By default, RLlib assumes that the module is non-recurrent if the initial
state is an empty dict and recurrent otherwise.
This behavior can be overridden by implementing this method.
Note that RLlib's distribution classes all implement the `Distribution`
Collaborator:

Very nice! This makes it clear why!
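The convention described in the docstring can be illustrated as follows (hypothetical classes; the `is_stateful` method name is assumed for illustration):

```python
class StatelessModule:
    """A module with an empty initial state -> treated as non-recurrent."""

    def get_initial_state(self) -> dict:
        return {}

    def is_stateful(self) -> bool:
        # Non-empty initial state dict => recurrent (stateful) module.
        return bool(self.get_initial_state())


class LSTMModule(StatelessModule):
    """A recurrent module advertises its state via get_initial_state."""

    def get_initial_state(self) -> dict:
        return {"h": [0.0], "c": [0.0]}
```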

values = self._values(features).squeeze(-1)
# Same logic as _forward, but also return features to be used by value function
# branch during training.
features, state_outs = self._compute_features_and_state_outs(batch)
Collaborator:

As before, in my own opinion, "features" is a misleading name, as it is usually used for the inputs of a neural network.

Contributor Author (sven1977):

Fixed everywhere. Great catch and suggestion! Makes things much clearer.

…odule_do_over_bc_default_module_03_common_forward
@sven1977 sven1977 enabled auto-merge (squash) October 5, 2024 12:21
@github-actions github-actions bot added the go add ONLY when ready to merge, run all tests label Oct 5, 2024
@github-actions github-actions bot disabled auto-merge October 5, 2024 12:29
@sven1977 sven1977 enabled auto-merge (squash) October 5, 2024 13:51
@github-actions github-actions bot disabled auto-merge October 5, 2024 14:01
@sven1977 sven1977 enabled auto-merge (squash) October 5, 2024 14:30
@github-actions github-actions bot disabled auto-merge October 5, 2024 16:31
@sven1977 sven1977 enabled auto-merge (squash) October 5, 2024 17:48
@github-actions github-actions bot disabled auto-merge October 5, 2024 21:29
@sven1977 sven1977 enabled auto-merge (squash) October 5, 2024 21:52
@sven1977 sven1977 merged commit e182e19 into ray-project:master Oct 5, 2024
6 checks passed
@sven1977 sven1977 deleted the rl_module_do_over_bc_default_module_03_common_forward branch October 6, 2024 06:16
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 15, 2024
…eric `_forward` to further simplify the user experience). (ray-project#47889)

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
action_dist_class_exploration = (
    self.module[module_id].unwrapped().get_exploration_action_dist_cls()
)
action_dist_class_train = module.get_train_action_dist_cls()
action_dist_class_exploration = module.get_exploration_action_dist_cls()
@smanolloff (Oct 24, 2024):

Hey, I have a question here: shouldn't the exploration or inference dist be used, in a similar fashion to the GetActions connector's logic?

This affects the KL loss calculation, which might end up using a different distribution class (exploration_dist) than the one used for the surrogate loss (inference_dist). It is somewhat of an edge case, since the two are actually the same as per TorchRLModule, but users sub-classing it would be unaware.
