
[RLlib] Torch trainer #31628

Merged
merged 43 commits into ray-project:master on Jan 20, 2023
Conversation

@kouroshHakha (Contributor) commented Jan 12, 2023

Why are these changes needed?

This PR creates the torch trainer along with its unit test. There will be a few cleanup PRs after this:

  • Consolidate the torch and tf RLTrainer, ensuring the super-class/sub-class relation is intuitive and well-documented.
  • Increase the coverage of the unit tests (some public methods, like compile_results, are currently untested). Test multi-GPU, multi-node scenarios, as well as the simple non-distributed setups that people may use to get a sense of the implementation.

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@kouroshHakha added and then removed the tests-ok label on Jan 14, 2023
# TODO (Kourosh): This method is built for multi-agent. While it is still
# possible to write single-agent losses, it may become confusing to users. We
# should find a way to allow them to specify single-agent losses as well,
# without having to think about one extra layer of hierarchy for module ids.
Member

What if we have 2 functions:

one called compute_loss_single_agent, and one called compute_loss_multi_agent.

If you implement compute_loss_single_agent, update calls that for all agents. If you implement compute_loss_multi_agent, then update calls that function instead.

If you implement both, we throw an error.

The one downside is that I think it involves using the overrides decorator from ray/rllib.
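
Below is a minimal sketch of the dispatch this comment proposes. The hook names compute_loss_single_agent / compute_loss_multi_agent and the method-identity check (standing in for RLlib's overrides decorator) are assumptions for illustration, not the actual API:

class RLTrainerBase:
    def compute_loss_single_agent(self, fwd_out, batch):
        raise NotImplementedError

    def compute_loss_multi_agent(self, fwd_out, batch):
        raise NotImplementedError

    def _dispatch_loss(self, fwd_out, batch):
        # Detect which hook the subclass overrode by comparing the attribute
        # on the subclass against the base-class definition.
        has_single = (
            type(self).compute_loss_single_agent
            is not RLTrainerBase.compute_loss_single_agent
        )
        has_multi = (
            type(self).compute_loss_multi_agent
            is not RLTrainerBase.compute_loss_multi_agent
        )
        if has_single and has_multi:
            raise ValueError("Implement only one of the two loss hooks.")
        if has_multi:
            return self.compute_loss_multi_agent(fwd_out, batch)
        # Single-agent hook: apply it once per module id.
        return {
            module_id: self.compute_loss_single_agent(
                fwd_out[module_id], batch[module_id]
            )
            for module_id in batch
        }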

Contributor Author

I think we should roll with multi-agent as a first-class citizen for a bit and see how users respond. I have written the BC trainer and honestly it's not that confusing. Especially if you think about the way people usually write this stuff for the first time: they most likely put breakpoints in the loss-computation code to see what data they'll get, and go from there. My hypothesis is that for an average "advanced" user (one who is writing their own loss / algorithm), having the input be a multi-agent sample batch is a good indicator of what they need to do, plus they'll see examples of how other algorithms' losses are written. The advantage is that there is less API for a user to cope with, and hence a lower cognitive load.
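
For illustration, a hedged sketch of what a loss with multi-agent as the first-class citizen can look like, in the spirit of the BC trainer mentioned above; the batch/fwd_out keys and the use of torch.distributions are assumptions, not the exact RLlib interface:

import torch

def compute_bc_loss(fwd_out, batch):
    # Both fwd_out and batch are keyed by module id, even when there is
    # only a single module in play.
    losses = {}
    for module_id, module_batch in batch.items():
        logits = fwd_out[module_id]["action_logits"]  # assumed key
        # Behavioral cloning: negative log-likelihood of the logged
        # actions under the current policy.
        logp = torch.distributions.Categorical(logits=logits).log_prob(
            module_batch["actions"]
        )
        losses[module_id] = -logp.mean()
    total = sum(losses.values())
    losses["total_loss"] = total
    return losses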

Contributor Author

I have a proposal for this in a follow-up PR.

self._module[module_id].to(self._device)
if self.distributed:
    self._module.add_module(
        module_id, TorchDDPRLModule(self._module[module_id]), override=True
    )
Member
this looks so nice
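
For context, a hedged reconstruction of how a snippet like the one quoted above can sit inside the trainer's add_module(); the signature and the module_spec.build() helper are assumptions for illustration, not the exact PR code:

def add_module(self, module_id, module_spec):
    # Build and register the new sub-module (module_spec.build() is a
    # hypothetical constructor here).
    self._module[module_id] = module_spec.build()
    # Move the new sub-module onto the trainer's device (CPU or GPU).
    self._module[module_id].to(self._device)
    if self.distributed:
        # Under distributed training, re-register the sub-module wrapped
        # in TorchDDPRLModule so its gradients are synchronized across
        # workers, while the other modules remain untouched.
        self._module.add_module(
            module_id, TorchDDPRLModule(self._module[module_id]), override=True
        )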

@gjoliver merged commit 66d6ce6 into ray-project:master on Jan 20, 2023
andreapiso pushed a commit to andreapiso/ray that referenced this pull request Jan 22, 2023
* added quick cleanups to trainer_runner.
* created test_trainer_runner
* added bc_rl_trainer
* moved the DDPRLModuleWrapper outside of RLTrainer + lint
* merged tf and torch train_runner tests
