[RLlib] Torch trainer #31628
Conversation
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
…nabled
2. don't do numpy conversion for the batch on the base class
# TODO (Kourosh): This method is built for multi-agent. While it is still
# possible to write single-agent losses, it may become confusing to users. We
# should find a way to allow them to specify single-agent losses as well,
# without having to think about one extra layer of hierarchy for module ids.
What if we have two functions: one called `compute_loss_single_agent` and one called `compute_loss_multi_agent`?
If you implement `compute_loss_single_agent`, `update` calls it for every agent. If you implement `compute_loss_multi_agent`, `update` calls that function instead.
If you implement both, we throw an error.
The one downside is that I think it involves using the `overrides` decorator from ray/rllib.
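The proposed dispatch could be sketched roughly like this (all names here, e.g. `RLTrainerSketch`, `compute_loss_single_agent`, are hypothetical illustrations, not the actual RLlib API):

```python
class RLTrainerSketch:
    """Hypothetical base class illustrating the proposed loss dispatch."""

    def compute_loss_single_agent(self, batch):
        raise NotImplementedError

    def compute_loss_multi_agent(self, multi_agent_batch):
        raise NotImplementedError

    def _is_overridden(self, name):
        # True if a subclass replaced the base-class implementation.
        return getattr(type(self), name) is not getattr(RLTrainerSketch, name)

    def update(self, multi_agent_batch):
        single = self._is_overridden("compute_loss_single_agent")
        multi = self._is_overridden("compute_loss_multi_agent")
        if single and multi:
            raise ValueError("Implement only one of the two loss methods.")
        if multi:
            return self.compute_loss_multi_agent(multi_agent_batch)
        # Fall back to applying the single-agent loss per module id.
        return {
            module_id: self.compute_loss_single_agent(batch)
            for module_id, batch in multi_agent_batch.items()
        }
```

A subclass then overrides exactly one of the two methods, and `update` picks the right code path without the user ever touching the module-id hierarchy.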
I think we should roll for a bit with multi-agent as a first-class citizen and see users' impressions. I have written the BC trainer and honestly it's not that confusing. Think about how people usually write this stuff the first time: they most likely put breakpoints in the loss-computation code to see what data they'll get, and go from there. My hypothesis is that for an average "advanced" user (who is writing their own loss / algorithm), having the input be a multi-agent sample batch is a good indicator of what they need to do, and they'll also see examples of how other algorithms' losses are written. The advantage is that there is less API for a user to cope with, hence a lower cognitive load.
I have a proposal for this in a follow up PR.
self._module[module_id].to(self._device)
if self.distributed:
    self._module.add_module(
        module_id, TorchDDPRLModule(self._module[module_id]), override=True
    )
this looks so nice
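The general pattern in the snippet above (move each sub-module to the device, then conditionally swap it for a DDP wrapper in distributed mode) can be sketched in plain torch; `wrap_modules` and the dict layout are hypothetical stand-ins for the trainer's internals, not the actual RLlib code:

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def wrap_modules(modules: dict, distributed: bool, device: str = "cpu") -> dict:
    """Move each sub-module to the target device and, in distributed mode,
    replace it in place with a DDP wrapper. In real distributed use,
    torch.distributed.init_process_group() must have been called first."""
    for module_id, module in modules.items():
        module = module.to(device)
        if distributed:
            module = DistributedDataParallel(module)
        modules[module_id] = module
    return modules
```

Overriding the entry in the module dict (rather than keeping a separate wrapped copy) keeps a single source of truth, which is what makes the `add_module(..., override=True)` call above read so cleanly.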
…e the unittest is located, and `import torch` would import the relative torch module instead of the global torch module
* added quick cleanups to trainer_runner
* created test_trainer_runner
* added bc_rl_trainer
* moved the DDPRLModuleWrapper outside of RLTrainer + lint
* merged tf and torch trainer_runner tests
Why are these changes needed?
This PR creates the torch trainer along with its unit test. A few cleanup PRs will follow this one.
Related issue number
Checks
- I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.