[RLlib] PPO torch RLTrainer (#31801)
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
kouroshHakha authored Feb 8, 2023
1 parent 3f43969 commit 1f77e04
Showing 17 changed files with 839 additions and 75 deletions.
46 changes: 46 additions & 0 deletions rllib/BUILD
@@ -1133,6 +1133,14 @@ py_test(
srcs = ["algorithms/ppo/tests/test_ppo_rl_module.py"]
)


py_test(
name = "test_ppo_rl_trainer",
tags = ["team:rllib", "algorithms_dir"],
size = "medium",
srcs = ["algorithms/ppo/tests/test_ppo_rl_trainer.py"]
)

# PPO Reproducibility
py_test(
name = "test_repro_ppo",
@@ -3832,6 +3840,44 @@ py_test(
]
)

# --------------------------------------------------------------------
# examples/rl_trainer directory
#
#
# Description: These are RLlib tests for the new multi-gpu enabled
# training stack via RLTrainers.
#
# NOTE: Add tests alphabetically to this list.
# --------------------------------------------------------------------

py_test(
name = "examples/rl_trainer/multi_agent_cartpole_ppo_torch",
main = "examples/rl_trainer/multi_agent_cartpole_ppo.py",
tags = ["team:rllib", "exclusive", "examples", "no-gpu"],
size = "medium",
srcs = ["examples/rl_trainer/multi_agent_cartpole_ppo.py"],
args = ["--as-test", "--framework=torch", "--num-gpus=0"]
)

py_test(
name = "examples/rl_trainer/multi_agent_cartpole_ppo_torch_gpu",
main = "examples/rl_trainer/multi_agent_cartpole_ppo.py",
tags = ["team:rllib", "exclusive", "examples", "gpu"],
size = "medium",
srcs = ["examples/rl_trainer/multi_agent_cartpole_ppo.py"],
args = ["--as-test", "--framework=torch", "--num-gpus=1"]
)


py_test(
name = "examples/rl_trainer/multi_agent_cartpole_ppo_torch_multi_gpu",
main = "examples/rl_trainer/multi_agent_cartpole_ppo.py",
tags = ["team:rllib", "exclusive", "examples", "multi-gpu"],
size = "medium",
srcs = ["examples/rl_trainer/multi_agent_cartpole_ppo.py"],
args = ["--as-test", "--framework=torch", "--num-gpus=2"]
)
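
The three new targets above run the same example script with 0, 1, and 2 GPUs. The script body is not part of this diff, so the following is only a rough sketch of how flags like those in `args` are typically consumed; every name beyond the flag strings themselves is an assumption.

```python
# Illustrative sketch only: the example script is not shown in this diff, so
# this is an assumed, minimal way of parsing the flags passed by the BUILD
# targets above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--framework", type=str, default="torch")
parser.add_argument("--num-gpus", type=int, default=0)

# e.g. the flags used by the multi-GPU target:
args = parser.parse_args(["--as-test", "--framework=torch", "--num-gpus=2"])
print(args.as_test, args.framework, args.num_gpus)  # True torch 2
```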

# --------------------------------------------------------------------
# examples/documentation directory
#
47 changes: 38 additions & 9 deletions rllib/algorithms/algorithm.py
@@ -695,7 +695,7 @@ def setup(self, config: AlgorithmConfig) -> None:
# TODO (Kourosh): This is an interim solution where policies and modules
# co-exist. In this world we have both policy_map and MARLModule that need
# to be consistent with one another. To make a consistent parity between
# the two we need to loop throught the policy modules and create a simple
# the two we need to loop through the policy modules and create a simple
# MARLModule from the RLModule within each policy.
local_worker = self.workers.local_worker()
module_specs = {}
@@ -715,6 +715,10 @@ def setup(self, config: AlgorithmConfig) -> None:
trainer_runner_config = self.config.get_trainer_runner_config(module_spec)
self.trainer_runner = trainer_runner_config.build()

# sync the weights from local rollout worker to trainers
weights = local_worker.get_weights()
self.trainer_runner.set_weights(weights)

# Run `on_algorithm_init` callback after initialization is done.
self.callbacks.on_algorithm_init(algorithm=self)

@@ -858,7 +862,7 @@ def evaluate(
# Sync weights to the evaluation WorkerSet.
if self.evaluation_workers is not None:
self.evaluation_workers.sync_weights(
from_worker=self.workers.local_worker()
from_worker_or_trainer=self.workers.local_worker()
)
self._sync_filters_if_needed(
from_worker=self.workers.local_worker(),
@@ -1376,11 +1380,11 @@ def training_step(self) -> ResultDict:
# TODO (Avnish): Implement this on trainer_runner.get_weights().
# TODO (Kourosh): figure out how we are going to sync MARLModule
# weights to MARLModule weights under the policy_map objects?
from_worker = None
from_worker_or_trainer = None
if self.config._enable_rl_trainer_api:
from_worker = self.trainer_runner
from_worker_or_trainer = self.trainer_runner
self.workers.sync_weights(
from_worker=from_worker,
from_worker_or_trainer=from_worker_or_trainer,
policies=list(train_results.keys()),
global_vars=global_vars,
)
@@ -2132,10 +2136,13 @@ def default_resource_request(
eval_cf.freeze()

# resources for local worker
local_worker = {
"CPU": cf.num_cpus_for_local_worker,
"GPU": 0 if cf._fake_gpus else cf.num_gpus,
}
if cf._enable_rl_trainer_api:
local_worker = {"CPU": cf.num_cpus_for_local_worker, "GPU": 0}
else:
local_worker = {
"CPU": cf.num_cpus_for_local_worker,
"GPU": 0 if cf._fake_gpus else cf.num_gpus,
}

bundles = [local_worker]

@@ -2179,6 +2186,28 @@ def default_resource_request(

bundles += rollout_workers + evaluation_bundle

if cf._enable_rl_trainer_api:
# resources for the trainer
if cf.num_trainer_workers == 0:
# if num_trainer_workers is 0, then we need to allocate one gpu if
# num_gpus_per_trainer_worker is greater than 0.
trainer_bundle = [
{
"CPU": cf.num_cpus_per_trainer_worker,
"GPU": cf.num_gpus_per_trainer_worker,
}
]
else:
trainer_bundle = [
{
"CPU": cf.num_cpus_per_trainer_worker,
"GPU": cf.num_gpus_per_trainer_worker,
}
for _ in range(cf.num_trainer_workers)
]

bundles += trainer_bundle

# Return PlacementGroupFactory containing all needed resources
# (already properly defined as device bundles).
return PlacementGroupFactory(
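
With the RLTrainer API enabled, `default_resource_request` now appends one bundle per trainer worker, or a single bundle in local mode so that a GPU can still be reserved when `num_gpus_per_trainer_worker > 0`. A standalone sketch of that bundle logic with hypothetical values; nothing here calls into RLlib:

```python
# Standalone illustration of the trainer-bundle logic added above; the values
# are hypothetical and nothing here touches RLlib itself.
def trainer_bundles(num_trainer_workers, num_cpus_per_trainer_worker, num_gpus_per_trainer_worker):
    # Local mode (0 trainer workers) still reserves exactly one bundle.
    num_bundles = max(num_trainer_workers, 1)
    return [
        {"CPU": num_cpus_per_trainer_worker, "GPU": num_gpus_per_trainer_worker}
        for _ in range(num_bundles)
    ]

print(trainer_bundles(0, 1, 1))  # [{'CPU': 1, 'GPU': 1}]
print(trainer_bundles(2, 1, 1))  # [{'CPU': 1, 'GPU': 1}, {'CPU': 1, 'GPU': 1}]
```
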
7 changes: 7 additions & 0 deletions rllib/algorithms/algorithm_config.py
@@ -883,6 +883,13 @@ def validate(self) -> None:
rl_module_class_path = self.get_default_rl_module_class()
self.rl_module_class = _resolve_class_path(rl_module_class_path)

# make sure the resource requirements for trainer runner is valid
if self.num_trainer_workers == 0 and self.num_gpus_per_worker > 1:
raise ValueError(
"num_gpus_per_worker must be 0 (cpu) or 1 (gpu) when using local mode "
"(i.e. num_trainer_workers = 0)"
)

# resolve rl_trainer class
if self._enable_rl_trainer_api and self.rl_trainer_class is None:
rl_trainer_class_path = self.get_default_rl_trainer_class()
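
The new check in `validate()` rejects more than one GPU per worker when no trainer workers are configured. A plain-Python sketch of the rule with example values (not a call into RLlib):

```python
# Plain-Python sketch of the new validation rule; values are examples only.
def check_local_mode_gpus(num_trainer_workers: int, num_gpus_per_worker: int) -> None:
    if num_trainer_workers == 0 and num_gpus_per_worker > 1:
        raise ValueError(
            "num_gpus_per_worker must be 0 (cpu) or 1 (gpu) when using local mode "
            "(i.e. num_trainer_workers = 0)"
        )

check_local_mode_gpus(num_trainer_workers=0, num_gpus_per_worker=1)  # ok: local mode, single GPU
check_local_mode_gpus(num_trainer_workers=2, num_gpus_per_worker=2)  # ok: not local mode
# check_local_mode_gpus(num_trainer_workers=0, num_gpus_per_worker=2)  # raises ValueError
```
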
82 changes: 71 additions & 11 deletions rllib/algorithms/ppo/ppo.py
@@ -16,6 +16,7 @@
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.pg import PGConfig
from ray.rllib.algorithms.ppo.ppo_rl_trainer_config import PPORLTrainerHPs
from ray.rllib.execution.rollout_ops import (
standardize_fields,
)
@@ -42,6 +43,8 @@

if TYPE_CHECKING:
from ray.rllib.core.rl_module import RLModule
from ray.rllib.core.rl_trainer.rl_trainer import RLTrainer


logger = logging.getLogger(__name__)

@@ -89,6 +92,7 @@ def __init__(self, algo_class=None):
# fmt: off
# __sphinx_doc_begin__
# PPO specific settings:
self._rl_trainer_hps = PPORLTrainerHPs()
self.use_critic = True
self.use_gae = True
self.lambda_ = 1.0
@@ -131,6 +135,17 @@ def get_default_rl_module_class(self) -> Union[Type["RLModule"], str]:
else:
raise ValueError(f"The framework {self.framework_str} is not supported.")

@override(AlgorithmConfig)
def get_default_rl_trainer_class(self) -> Union[Type["RLTrainer"], str]:
if self.framework_str == "torch":
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_trainer import (
PPOTorchRLTrainer,
)

return PPOTorchRLTrainer
else:
raise ValueError(f"The framework {self.framework_str} is not supported.")

@override(AlgorithmConfig)
def training(
self,
@@ -201,12 +216,16 @@ def training(
self.lr_schedule = lr_schedule
if use_critic is not NotProvided:
self.use_critic = use_critic
# TODO (Kourosh) This is experimental. Set rl_trainer_hps parameters as
# well. Don't forget to remove .use_critic from algorithm config.
self._rl_trainer_hps.use_critic = use_critic
if use_gae is not NotProvided:
self.use_gae = use_gae
if lambda_ is not NotProvided:
self.lambda_ = lambda_
if kl_coeff is not NotProvided:
self.kl_coeff = kl_coeff
self._rl_trainer_hps.kl_coeff = kl_coeff
if sgd_minibatch_size is not NotProvided:
self.sgd_minibatch_size = sgd_minibatch_size
if num_sgd_iter is not NotProvided:
@@ -215,18 +234,24 @@
self.shuffle_sequences = shuffle_sequences
if vf_loss_coeff is not NotProvided:
self.vf_loss_coeff = vf_loss_coeff
self._rl_trainer_hps.vf_loss_coeff = vf_loss_coeff
if entropy_coeff is not NotProvided:
self.entropy_coeff = entropy_coeff
self._rl_trainer_hps.entropy_coeff = entropy_coeff
if entropy_coeff_schedule is not NotProvided:
self.entropy_coeff_schedule = entropy_coeff_schedule
self._rl_trainer_hps.entropy_coeff_schedule = entropy_coeff_schedule
if clip_param is not NotProvided:
self.clip_param = clip_param
self._rl_trainer_hps.clip_param = clip_param
if vf_clip_param is not NotProvided:
self.vf_clip_param = vf_clip_param
self._rl_trainer_hps.vf_clip_param = vf_clip_param
if grad_clip is not NotProvided:
self.grad_clip = grad_clip
if kl_target is not NotProvided:
self.kl_target = kl_target
self._rl_trainer_hps.kl_target = kl_target

return self
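
Based on the hunk above, values passed to `PPOConfig.training()` are now mirrored onto the experimental `_rl_trainer_hps` container in addition to the classic config attributes. A short sketch of that behavior, using only names that appear in the diff:

```python
# Sketch of the mirroring introduced above (attribute names taken from the diff).
from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().training(kl_coeff=0.3, clip_param=0.2, vf_clip_param=5.0)

assert config.kl_coeff == 0.3                  # classic config attribute
assert config._rl_trainer_hps.kl_coeff == 0.3  # mirrored RLTrainer hyperparameter
assert config._rl_trainer_hps.clip_param == 0.2
assert config._rl_trainer_hps.vf_clip_param == 5.0
```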

@@ -366,14 +391,39 @@ def training_step(self) -> ResultDict:
train_batch = standardize_fields(train_batch, ["advantages"])
# Train
if self.config._enable_rl_trainer_api:
train_results = self.trainer_runner.update(train_batch)
# TODO (Kourosh) Clearly define what train_batch_size
# vs. sgd_minibatch_size and num_sgd_iter is in the config.
# TODO (Kourosh) Do this inside the RL Trainer so
# that we don't have to do this back and forth
# communication between driver and the remote
# trainer workers

train_results = self.trainer_runner.fit(
train_batch,
minibatch_size=self.config.sgd_minibatch_size,
num_iters=self.config.num_sgd_iter,
)

elif self.config.simple_optimizer:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)

policies_to_update = list(train_results.keys())
if self.config._enable_rl_trainer_api:
# the train results' loss keys are pids to their loss values. But we also
# return a total_loss key at the same level as the pid keys. So we need to
# subtract that to get the total set of pids to update.
# TODO (Kourosh): We need to make a better design for the hierarchy of the
# train results, so that all the policy ids end up in the same level.
# TODO (Kourosh): We should also not be using train_results as a message
# passing medium to infer which policies to update. We could use
# policies_to_train variable that is given by the user to infer this.
policies_to_update = set(train_results["loss"].keys()) - {"total_loss"}
else:
policies_to_update = list(train_results.keys())

# TODO (Kourosh): num_grad_updates per each policy should be accessible via
# train_results
global_vars = {
"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
"num_grad_updates_per_policy": {
@@ -384,24 +434,34 @@

# Update weights - after learning on the local worker - on all remote
# workers.
if self.workers.num_remote_workers() > 0:
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
from_worker = None
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
if self.workers.num_remote_workers() > 0:
from_worker_or_trainer = None
if self.config._enable_rl_trainer_api:
from_worker = self.trainer_runner
# sync weights from trainer_runner to all rollout workers
from_worker_or_trainer = self.trainer_runner
self.workers.sync_weights(
from_worker=from_worker,
policies=list(train_results.keys()),
from_worker_or_trainer=from_worker_or_trainer,
policies=policies_to_update,
global_vars=global_vars,
)
elif self.config._enable_rl_trainer_api:
weights = self.trainer_runner.get_weights()
self.workers.local_worker().set_weights(weights)

if self.config._enable_rl_trainer_api:
kl_dict = {
pid: pinfo[LEARNER_STATS_KEY].get("kl")
for pid, pinfo in train_results.items()
# TODO (Kourosh): Train results don't match the old format. The thing
# that used to be under `kl` is now under `mean_kl_loss`. Fix this. Do
# we need get here?
pid: train_results["loss"][pid].get("mean_kl_loss")
for pid in policies_to_update
}
# triggers a special update method on RLOptimizer to update the KL values.
self.trainer_runner.additional_update(kl_values=kl_dict)
self.trainer_runner.additional_update(
sampled_kl_values=kl_dict,
timestep=self._counters[NUM_AGENT_STEPS_SAMPLED],
)

return train_results

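Under the RLTrainer API the train results are nested: a `loss` dict keyed by policy id, plus a `total_loss` entry at the same level (the TODOs above note this layout will be cleaned up). A sketch with made-up numbers showing how `policies_to_update` and the sampled-KL dict are derived from that layout:

```python
# Hypothetical result layout based on the code above; the numbers are made up.
train_results = {
    "loss": {
        "total_loss": 0.42,
        "policy_1": {"mean_kl_loss": 0.011},
        "policy_2": {"mean_kl_loss": 0.007},
    }
}

# Same expressions as in training_step() above.
policies_to_update = set(train_results["loss"].keys()) - {"total_loss"}
kl_dict = {pid: train_results["loss"][pid].get("mean_kl_loss") for pid in policies_to_update}

print(sorted(policies_to_update))  # ['policy_1', 'policy_2']
print(kl_dict)                     # {'policy_1': 0.011, 'policy_2': 0.007} (order may vary)
```
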
21 changes: 21 additions & 0 deletions rllib/algorithms/ppo/ppo_rl_trainer_config.py
@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Optional, Union

from ray.rllib.core.rl_trainer.rl_trainer import RLTrainerHPs


@dataclass
class PPORLTrainerHPs(RLTrainerHPs):
"""Hyperparameters for the PPO RL Trainer"""

kl_coeff: float = 0.2
kl_target: float = 0.01
use_critic: bool = True
clip_param: float = 0.3
vf_clip_param: float = 10.0
entropy_coeff: float = 0.0
vf_loss_coeff: float = 1.0

# experimental placeholder for things that could be part of the base RLTrainerHPs
lr_schedule: Optional[List[List[Union[int, float]]]] = None
entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = None
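
A minimal usage sketch of the new dataclass, assuming the base `RLTrainerHPs` has no required constructor fields; the import path matches the file added above.

```python
# Minimal usage sketch of the new hyperparameter container.
from ray.rllib.algorithms.ppo.ppo_rl_trainer_config import PPORLTrainerHPs

hps = PPORLTrainerHPs(kl_coeff=0.3, entropy_coeff=0.01)
print(hps.kl_target, hps.clip_param)    # 0.01 0.3 -> defaults from the dataclass
print(hps.kl_coeff, hps.entropy_coeff)  # 0.3 0.01 -> overridden values
```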
(The remaining changed files in this commit are not shown here.)
