[RLlib] PPO torch RLTrainer #31801
Conversation
I have to rebase once the pre-reqs are merged.
I have some comments.
rllib/algorithms/algorithm.py
trainer_bundle = [
    {
        "CPU": cf.num_cpus_per_trainer_worker,
        "GPU": int(cf.num_gpus_per_trainer_worker > 0),
Overriding user configuration is confusing. Maybe consider validating and raising an error instead?
It just implements what's in the docs: if num_gpus_per_trainer_worker > 1 and num_trainer_workers = 0, we will just use one GPU. This is just enforcing that from Tune's perspective.
OK, I just want to offer my personal perspective here.
I feel like we always try to "auto-correct" for our users, for example by computing and overriding some parameters based on the values of other parameters, and then we put up a lot of explanations / docs around the "corrections" we may potentially make.
If it were 100% up to me, I would instead just raise an error telling users that I am seeing contradictory configs, because num_trainer_workers=0 doesn't work with num_gpus_per_trainer_worker > 0. Right now, we have no idea what the actual user intention is: did they mis-specify num_trainer_workers, or did they mis-specify num_gpus_per_trainer_worker? The only one who can fix this is the user themselves.
These are just some of my thoughts. I can't enforce this, so hopefully you understand what I mean.
Got it, makes sense.
done
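For illustration, here is a minimal sketch of the "validate and raise" approach settled on above. The function name and the exact hook where this check would live are assumptions, not the actual code in this PR; only the two config fields come from the thread.

```python
# Hedged sketch of the config validation discussed above. The attribute
# names mirror the fields mentioned in the thread; the real validation
# hook in AlgorithmConfig may look different.
def validate_trainer_resources(config) -> None:
    """Raise instead of silently auto-correcting contradictory settings."""
    if config.num_trainer_workers == 0 and config.num_gpus_per_trainer_worker > 1:
        raise ValueError(
            "num_trainer_workers=0 runs local training and supports at most one "
            f"GPU; got num_gpus_per_trainer_worker={config.num_gpus_per_trainer_worker}. "
            "Either set num_trainer_workers >= 1 or reduce num_gpus_per_trainer_worker."
        )
```

The point of raising (rather than clamping the GPU count) is that the user, not the framework, resolves which of the two fields was mis-specified.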
rllib/algorithms/ppo/ppo.py
@@ -201,12 +216,16 @@ def training(
            self.lr_schedule = lr_schedule
        if use_critic is not NotProvided:
            self.use_critic = use_critic
            # TODO (Kourosh) This is experimental. Set rl_trainer_hps parameters as
            # well. Don't forget to remote .use_critic from algorithm config.
What does "remote .use_critic" mean?
Typo: it should be "remove", not "remote" :)
# subtract that to get the total set of pids to update.
# TODO (Kourosh): We need to make a better design for the hierarchy of the
# train results, so that all the policy ids end up in the same level.
policies_to_update = set(train_results["loss"].keys()) - {"total_loss"}
I actually think you should be explicit and not rely on keys in the result dict to tell you which policies need to be updated. train_results should not be used as control messages, basically.
I'd like to hear more about what you mean by being more explicit. I was planning to revisit the train_results structure to remove these requirements in the next round of updates, but I'd love to hear your thoughts on how it should ideally look.
I'm not opinionated about how the result dict should look. But I do think we shouldn't use it as a control message, meaning that I can only get my policies updated if I add something to the result dict. These two things probably shouldn't go together.
I see your point, and I actually have a better idea. Right now I haven't even made trainer_runner only update the policies that are allowed (via policies_to_train). I think with that variable lingering around, I can infer the policies to update; then the returned results won't be used as the message-passing medium. I'll add a TODO with a better design guideline in the next PR.
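For reference, a minimal sketch of the explicit approach discussed in this thread. The `get_policies_to_train()` accessor and the `local_worker` argument are assumptions standing in for the real RLlib objects; only the fallback expression comes from the diff above.

```python
# Hedged sketch: derive the set of policies to update from the explicit
# training configuration (policies_to_train) rather than from the keys of
# the result dict. Names are illustrative; the actual API may differ.
def get_policies_to_update(local_worker, train_results):
    explicit = local_worker.get_policies_to_train()
    if explicit is not None:
        # Explicit control path: update exactly what the user configured.
        return set(explicit)
    # Fallback: infer from the per-policy loss keys (current behavior).
    return set(train_results["loss"].keys()) - {"total_loss"}
```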
samples_to_concat = []
# cycle through the batch until we have enough samples
while e >= len(module_batch):
You duplicate module_batches multiple times to make a minibatch large enough? This block of code is actually very confusing...
Yes. Generally speaking, if the number of samples within each policy batch is uneven and skewed, you need to make sure they are sharded almost equally when you pass them to the RLTrainer. Say policy1 has 100 samples and policy2 has 20 samples; when I want to pass a minibatch of size 40, from policy1 I will select samples 0-39, but from policy2 I select 0-19 + 0-19 to make up 40 samples. This is the problem of sharding across policies; I couldn't figure out an easy-to-understand sharding strategy. But I want to think about this more when I think about the double-batch communication overhead, since the two are very related.
I see. Maybe make a util function out of this so it's pluggable, and we can write and compare a few different schemes.
BTW, not sure if you want to add these as comments in the code; that would help a lot.
Will update, taking the suggestions into account 👍
OK, I am creating a MiniBatchIterator utility that can be reused once we move this iteration to sharded batches inside RLTrainer as well. Thanks for the suggestion; I really like the design breakdown.
👍
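To make the cycling behavior concrete, here is a minimal sketch of the idea described in this thread, matching the policy1 (100 samples) / policy2 (20 samples) example above. It operates on plain dicts of NumPy arrays and is illustrative only; the real MiniBatchIterator in RLlib may be structured differently.

```python
import numpy as np

# Hedged sketch of the cyclic sharding idea: each policy's batch is cycled
# until the requested minibatch size is reached, so skewed per-policy batch
# sizes still yield equally sized per-policy minibatches.
def cyclic_minibatch(per_policy_samples: dict, minibatch_size: int) -> dict:
    minibatch = {}
    for policy_id, samples in per_policy_samples.items():
        chunks, collected = [], 0
        # Cycle through this policy's samples until we have enough of them.
        while collected < minibatch_size:
            take = min(len(samples), minibatch_size - collected)
            chunks.append(samples[:take])
            collected += take
        minibatch[policy_id] = np.concatenate(chunks)
    return minibatch

# Example: policy1 contributes samples 0-39; policy2 is cycled twice (0-19, 0-19).
batch = {"policy1": np.arange(100), "policy2": np.arange(20)}
mb = cyclic_minibatch(batch, minibatch_size=40)
assert len(mb["policy1"]) == 40 and len(mb["policy2"]) == 40
```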
Why are these changes needed?
This PR creates a PPO torch RLTrainer.
PRs that have to merge first:
PRs that clean things up and make this less hacky:
Here are the learning curves for CartPole-v1. Blue is the old training stack with RLModules; red is the new training stack on one CPU. Throughput is also relatively good compared to before.
Let's check it out in the multi-GPU case:
Blue is zero GPUs, red is one GPU, and cyan is two GPUs.
Related issue number
Checks
I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
I've run scripts/format.sh to lint the changes in this PR.