[RLlib] PPO torch RLTrainer #31801
Diff — rllib/algorithms/ppo/ppo.py:
@@ -16,6 +16,7 @@
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.pg import PGConfig
from ray.rllib.algorithms.ppo.ppo_rl_trainer_config import PPORLTrainerHPs
from ray.rllib.execution.rollout_ops import (
    standardize_fields,
)
@@ -42,6 +43,8 @@

if TYPE_CHECKING:
    from ray.rllib.core.rl_module import RLModule
    from ray.rllib.core.rl_trainer.rl_trainer import RLTrainer


logger = logging.getLogger(__name__)
@@ -89,6 +92,7 @@ def __init__(self, algo_class=None):
        # fmt: off
        # __sphinx_doc_begin__
        # PPO specific settings:
        self._rl_trainer_hps = PPORLTrainerHPs()
        self.use_critic = True
        self.use_gae = True
        self.lambda_ = 1.0
@@ -131,6 +135,17 @@ def get_default_rl_module_class(self) -> Union[Type["RLModule"], str]:
        else:
            raise ValueError(f"The framework {self.framework_str} is not supported.")

    @override(AlgorithmConfig)
    def get_default_rl_trainer_class(self) -> Union[Type["RLTrainer"], str]:
        if self.framework_str == "torch":
            from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_trainer import (
                PPOTorchRLTrainer,
            )

            return PPOTorchRLTrainer
        else:
            raise ValueError(f"The framework {self.framework_str} is not supported.")

    @override(AlgorithmConfig)
    def training(
        self,
@@ -201,12 +216,16 @@ def training(
            self.lr_schedule = lr_schedule
        if use_critic is not NotProvided:
            self.use_critic = use_critic
            # TODO (Kourosh) This is experimental. Set rl_trainer_hps parameters as
            # well. Don't forget to remove .use_critic from algorithm config.
            self._rl_trainer_hps.use_critic = use_critic
        if use_gae is not NotProvided:
            self.use_gae = use_gae
        if lambda_ is not NotProvided:
            self.lambda_ = lambda_
        if kl_coeff is not NotProvided:
            self.kl_coeff = kl_coeff
            self._rl_trainer_hps.kl_coeff = kl_coeff
        if sgd_minibatch_size is not NotProvided:
            self.sgd_minibatch_size = sgd_minibatch_size
        if num_sgd_iter is not NotProvided:
@@ -215,18 +234,24 @@
            self.shuffle_sequences = shuffle_sequences
        if vf_loss_coeff is not NotProvided:
            self.vf_loss_coeff = vf_loss_coeff
            self._rl_trainer_hps.vf_loss_coeff = vf_loss_coeff
        if entropy_coeff is not NotProvided:
            self.entropy_coeff = entropy_coeff
            self._rl_trainer_hps.entropy_coeff = entropy_coeff
        if entropy_coeff_schedule is not NotProvided:
            self.entropy_coeff_schedule = entropy_coeff_schedule
            self._rl_trainer_hps.entropy_coeff_schedule = entropy_coeff_schedule
        if clip_param is not NotProvided:
            self.clip_param = clip_param
            self._rl_trainer_hps.clip_param = clip_param
        if vf_clip_param is not NotProvided:
            self.vf_clip_param = vf_clip_param
            self._rl_trainer_hps.vf_clip_param = vf_clip_param
        if grad_clip is not NotProvided:
            self.grad_clip = grad_clip
        if kl_target is not NotProvided:
            self.kl_target = kl_target
            self._rl_trainer_hps.kl_target = kl_target

        return self
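From a user's perspective, the mirroring added above behaves roughly as in this sketch (the attribute names come from the hunks above; the usage itself is my illustration, not part of the diff):

# Hypothetical usage sketch: PPO hyperparameters set via .training() are now also
# mirrored into the experimental, private RLTrainer hyperparameter container.
from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().training(kl_coeff=0.3, clip_param=0.2, vf_loss_coeff=0.5)

# Legacy fields are still populated on the config itself ...
assert config.kl_coeff == 0.3 and config.clip_param == 0.2

# ... and mirrored into the PPORLTrainerHPs instance created in __init__.
assert config._rl_trainer_hps.kl_coeff == 0.3
assert config._rl_trainer_hps.clip_param == 0.2
assert config._rl_trainer_hps.vf_loss_coeff == 0.5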
@@ -366,14 +391,39 @@ def training_step(self) -> ResultDict:
        train_batch = standardize_fields(train_batch, ["advantages"])
        # Train
        if self.config._enable_rl_trainer_api:
            train_results = self.trainer_runner.update(train_batch)
            # TODO (Kourosh) Clearly define what train_batch_size
            # vs. sgd_minibatch_size and num_sgd_iter is in the config.
[Review comment] +100, let's change the naming of these params.
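For readers new to these knobs, here is a rough sketch of how train_batch_size, sgd_minibatch_size, and num_sgd_iter relate in PPO-style training. The loop illustrates the semantics only; it is not the trainer_runner.fit implementation, and apply_one_gradient_step is a hypothetical helper:

# Illustrative sketch (not actual RLlib code): one sampled train batch of
# train_batch_size timesteps is swept over num_iters times, in chunks of
# minibatch_size, with one gradient step per chunk.
def fit_sketch(train_batch, minibatch_size: int, num_iters: int):
    n = len(train_batch)
    for _ in range(num_iters):
        for start in range(0, n, minibatch_size):
            minibatch = train_batch[start:start + minibatch_size]
            apply_one_gradient_step(minibatch)  # hypothetical helper

def apply_one_gradient_step(minibatch):
    # Placeholder for the loss/backward/optimizer-step logic.
    pass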
            # TODO (Kourosh) Do this inside the RL Trainer so
            # that we don't have to do this back and forth
            # communication between driver and the remote
            # trainer workers
            train_results = self.trainer_runner.fit(
                train_batch,
                minibatch_size=self.config.sgd_minibatch_size,
                num_iters=self.config.num_sgd_iter,
            )
        elif self.config.simple_optimizer:
            train_results = train_one_step(self, train_batch)
        else:
            train_results = multi_gpu_train_one_step(self, train_batch)

        policies_to_update = list(train_results.keys())
        if self.config._enable_rl_trainer_api:
            # The train results' loss keys are pids mapped to their loss values. But we
            # also return a total_loss key at the same level as the pid keys, so we
            # need to subtract that to get the total set of pids to update.
            # TODO (Kourosh): We need to make a better design for the hierarchy of the
            # train results, so that all the policy ids end up at the same level.
            # TODO (Kourosh): We should also not be using train_results as a message
            # passing medium to infer which policies to update. We could use the
            # policies_to_train variable that is given by the user to infer this.
            policies_to_update = set(train_results["loss"].keys()) - {"total_loss"}
[Review comment] Maybe we can return a tuple that contains the total-loss stats along with the per-policy stats; that way we don't have to do this weird line of logic here. If we add a TODO here, I think that'd be enough for now.
[Reply] Yes, I don't like this either. The format of what we return has to change anyway, so I'll revisit it then regardless.
[Review comment] Can we add a TODO here, then, describing why we are doing this weird set subtraction and how we will fix it in the future?
[Review comment] I actually think you should be explicit, and not rely on some keys in the result dict to tell you which policies need to be updated.
[Reply] I'd like to hear more about what you mean by being more explicit. I was planning to revisit the train_results structure to remove these requirements in the next round of updates, but I'd love to hear your thoughts on how it should ideally look.
[Review comment] I am not opinionated about how the result dict should look.
[Reply] I see your point, and I actually have a better idea: right now I haven't even made trainer_runner such that it only updates the policies that are allowed (via …
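To make the discussion concrete, here is a minimal sketch of the kind of result structure the reviewers are suggesting, where aggregate stats and per-policy stats live in separate fields so no set subtraction is needed. The container name FitResults and its fields are hypothetical illustrations, not part of this PR:

# Hypothetical sketch: separate aggregate stats from per-policy stats, so callers
# can read policy ids directly instead of set(train_results["loss"].keys()) - {"total_loss"}.
from dataclasses import dataclass, field
from typing import Dict

@dataclass
class FitResults:  # hypothetical name
    total_loss: float
    per_policy_loss: Dict[str, float] = field(default_factory=dict)

# Usage sketch:
results = FitResults(total_loss=1.23, per_policy_loss={"default_policy": 1.23})
policy_ids = list(results.per_policy_loss.keys())  # ["default_policy"]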
        else:
            policies_to_update = list(train_results.keys())

        # TODO (Kourosh): num_grad_updates per policy should be accessible via
        # train_results
        global_vars = {
            "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
            "num_grad_updates_per_policy": {
@@ -384,24 +434,34 @@

        # Update weights - after learning on the local worker - on all remote
        # workers.
        if self.workers.num_remote_workers() > 0:
            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                from_worker = None
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            if self.workers.num_remote_workers() > 0:
                from_worker_or_trainer = None
                if self.config._enable_rl_trainer_api:
                    from_worker = self.trainer_runner
                    # sync weights from trainer_runner to all rollout workers
                    from_worker_or_trainer = self.trainer_runner
                self.workers.sync_weights(
                    from_worker=from_worker,
                    policies=list(train_results.keys()),
                    from_worker_or_trainer=from_worker_or_trainer,
                    policies=policies_to_update,
                    global_vars=global_vars,
                )
            elif self.config._enable_rl_trainer_api:
                weights = self.trainer_runner.get_weights()
                self.workers.local_worker().set_weights(weights)
        if self.config._enable_rl_trainer_api:
            kl_dict = {
                pid: pinfo[LEARNER_STATS_KEY].get("kl")
                for pid, pinfo in train_results.items()
                # TODO (Kourosh): Train results don't match the old format. The thing
                # that used to be under `kl` is now under `mean_kl_loss`. Fix this.
                # Do we need the .get() here?
                pid: train_results["loss"][pid].get("mean_kl_loss")
[Review comment] What's the case where this KL key is NOT present in the ["loss"][pid] sub-dict? Do we need the …
[Reply] Captured it in a TODO.
                for pid in policies_to_update
            }
            # triggers a special update method on RLOptimizer to update the KL values.
            self.trainer_runner.additional_update(kl_values=kl_dict)
            self.trainer_runner.additional_update(
                sampled_kl_values=kl_dict,
                timestep=self._counters[NUM_AGENT_STEPS_SAMPLED],
            )

        return train_results
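For orientation, the adaptive-KL update that additional_update ultimately performs in PPO follows the standard heuristic of nudging kl_coeff so the sampled KL tracks kl_target. The helper below is an illustration of that rule, not the PPOTorchRLTrainer implementation:

# Illustrative sketch (hypothetical helper, not the actual trainer code): adjust the
# KL penalty coefficient based on the KL divergence sampled during the last update.
def update_kl_coeff(kl_coeff: float, sampled_kl: float, kl_target: float) -> float:
    if sampled_kl > 2.0 * kl_target:
        # Policy moved too far from the old policy: penalize KL more strongly.
        return kl_coeff * 1.5
    elif sampled_kl < 0.5 * kl_target:
        # Policy barely moved: relax the KL penalty.
        return kl_coeff * 0.5
    return kl_coeff

# Usage sketch with the defaults from PPORLTrainerHPs (kl_coeff=0.2, kl_target=0.01):
tightened = update_kl_coeff(0.2, sampled_kl=0.05, kl_target=0.01)   # 0.2 * 1.5
relaxed = update_kl_coeff(0.2, sampled_kl=0.001, kl_target=0.01)    # 0.2 * 0.5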
Diff — rllib/algorithms/ppo/ppo_rl_trainer_config.py (new file):
@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Optional, Union

from ray.rllib.core.rl_trainer.rl_trainer import RLTrainerHPs


@dataclass
class PPORLTrainerHPs(RLTrainerHPs):
    """Hyperparameters for the PPO RL Trainer"""

    kl_coeff: float = 0.2
    kl_target: float = 0.01
    use_critic: bool = True
    clip_param: float = 0.3
    vf_clip_param: float = 10.0
    entropy_coeff: float = 0.0
    vf_loss_coeff: float = 1.0

    # experimental placeholder for things that could be part of the base RLTrainerHPs
    lr_schedule: Optional[List[List[Union[int, float]]]] = None
    entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = None
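A brief usage sketch of this dataclass (illustrative only; overriding fields at construction time is standard dataclass behavior, not something the diff demonstrates):

# Hypothetical usage: construct the PPO trainer hyperparameters with a couple of
# overrides; unspecified fields keep the defaults defined above.
from ray.rllib.algorithms.ppo.ppo_rl_trainer_config import PPORLTrainerHPs

hps = PPORLTrainerHPs(kl_coeff=0.3, clip_param=0.2)
print(hps.kl_coeff)                # 0.3 (override)
print(hps.vf_clip_param)           # 10.0 (default)
print(hps.entropy_coeff_schedule)  # None (experimental placeholder)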
[Review comment] Oh, the weights are formatted correctly already? Cool cool cool.
[Reply] Yes, the local worker returns a mapping from pid to the weights of the underlying RL module; the trainer runner returns MARL module weights, which by default is a mapping from module id to RL module weights.
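A small sketch of the two weight-mapping shapes described in that reply (the dictionary contents are made up for illustration):

# Illustrative only: the two weight dictionaries described above.
# Local worker: policy id -> weights of the underlying RL module.
local_worker_weights = {
    "default_policy": {"encoder.weight": "...", "pi_head.weight": "..."},
}
# Trainer runner: MARL module weights, i.e. module id -> RL module weights.
trainer_runner_weights = {
    "default_policy": {"encoder.weight": "...", "pi_head.weight": "..."},
}
# Same nesting by default, which is why no extra reformatting is needed when syncing.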
[Review comment] Typo above: "throught" -> "through".