[RLlib] PPO torch RLTrainer #31801

Status: Merged (134 commits, Feb 8, 2023)

Commits
28679ca
added quick cleanups to trainer_runner.
kouroshHakha Jan 10, 2023
a2f9439
created test_trainer_runner
kouroshHakha Jan 10, 2023
d8b36c1
added TODO tag
kouroshHakha Jan 10, 2023
d24bff5
Merge branch 'master' into trainer-runner-quick-cleanups
kouroshHakha Jan 11, 2023
2b67577
fixed imports
kouroshHakha Jan 11, 2023
fe60e20
typo in BUILD
kouroshHakha Jan 11, 2023
916d674
started to create torch_rl_trainer
kouroshHakha Jan 12, 2023
71026e5
added bc_rl_trainer
kouroshHakha Jan 12, 2023
ae61014
torch trainer test works now
kouroshHakha Jan 12, 2023
ef1ffb8
lint
kouroshHakha Jan 12, 2023
f393cea
Merge branch 'master' into torch-trainer
kouroshHakha Jan 13, 2023
16f64f9
updated TODOs and BUILD
kouroshHakha Jan 13, 2023
9719518
wip: trainer_runner multi-gpu test
kouroshHakha Jan 13, 2023
77730d8
torch version runs but the parameters are not synced
kouroshHakha Jan 14, 2023
97573dc
wip
kouroshHakha Jan 14, 2023
091c406
got the multi-gpu gradient sync up working
kouroshHakha Jan 17, 2023
a42b0f1
fixed add/remove multi-gpu tests
kouroshHakha Jan 17, 2023
62fe11f
moved the DDPRLModuleWrapper outside of RLTrainer + lint
kouroshHakha Jan 17, 2023
2cc5185
merged tf and torch train_runner tests
kouroshHakha Jan 17, 2023
435c352
fixed trainer_runner auto-scaling on a cluster where autoscaling is e…
kouroshHakha Jan 18, 2023
ff845c3
fix rl_trainer unittest failures.
kouroshHakha Jan 18, 2023
7c3eed7
1. renamed the DDP wrapper
kouroshHakha Jan 18, 2023
d56ce2c
removed in_test from the production code
kouroshHakha Jan 18, 2023
200b5f7
clarified todo
kouroshHakha Jan 18, 2023
f747e50
comments
kouroshHakha Jan 18, 2023
7ce81f0
renamed make_distributed to make_distributed_module
kouroshHakha Jan 18, 2023
b2ddd2d
fixed test torch rl_trainer lint
kouroshHakha Jan 18, 2023
5bc625c
fixed marl_module stuff
kouroshHakha Jan 18, 2023
3302db7
fixed the import issue
kouroshHakha Jan 18, 2023
5455e29
lint
kouroshHakha Jan 18, 2023
a2d042f
Merge branch 'master' into torch-trainer
kouroshHakha Jan 19, 2023
b9159a8
fixed lint
kouroshHakha Jan 19, 2023
f3edd50
Merge branch 'master' into torch-trainer
kouroshHakha Jan 19, 2023
2aec198
test trainer runner updated
kouroshHakha Jan 19, 2023
873cdd5
fixed the scaling config and in_test issues introduced after the merge.
kouroshHakha Jan 19, 2023
13e19aa
fixed the scaling config and in_test issues introduced after the merge.
kouroshHakha Jan 19, 2023
68b72e4
Merge branch 'torch-trainer' of github.com:kouroshHakha/ray into torc…
kouroshHakha Jan 19, 2023
ca3e225
wip
kouroshHakha Jan 20, 2023
ff3b335
Merge branch 'master' into torch-trainer
kouroshHakha Jan 20, 2023
3234aaf
wip
kouroshHakha Jan 20, 2023
15b99ee
wip
kouroshHakha Jan 20, 2023
8b9ae92
wip
kouroshHakha Jan 20, 2023
eb82e67
wip
kouroshHakha Jan 20, 2023
dac0d6b
fixed trainer_runner config test
kouroshHakha Jan 20, 2023
cfdaa04
removed the stuff that got moved to SARLTrainer made easy PR
kouroshHakha Jan 20, 2023
7b5938b
fixed torch import
kouroshHakha Jan 20, 2023
f4cbe5a
removed the override decorator for nn.Module
kouroshHakha Jan 20, 2023
9fb749c
Merge branch 'master' into ppo-torch-trainer
kouroshHakha Jan 20, 2023
9645137
added unittest (wip)
kouroshHakha Jan 20, 2023
eac5223
fixed import torch in bc_module.py
kouroshHakha Jan 20, 2023
0baa3cf
fixed the bazel bug where the working directory gets switched to wher…
kouroshHakha Jan 20, 2023
78860f8
Merge branch 'torch-trainer' into ppo-torch-trainer
kouroshHakha Jan 20, 2023
c125e20
wip
kouroshHakha Jan 20, 2023
5aaa603
Merge branch 'master' into ppo-torch-trainer
kouroshHakha Jan 21, 2023
c77be0c
fixed the unittest
kouroshHakha Jan 21, 2023
0e6f511
wip
kouroshHakha Jan 21, 2023
1778c44
Merge branch 'master' into policy-with-marl
kouroshHakha Jan 23, 2023
4bdd949
added dataclass specs for RLModule and MARLModule for easier construc…
kouroshHakha Jan 24, 2023
58bbe82
test trainer runner local passed
kouroshHakha Jan 24, 2023
93b27ec
add_module() api is now update to accept a module_spec instead of mod…
kouroshHakha Jan 24, 2023
7aaee18
get_trainer_runner_config() now gets an optional ModuleSpec object
kouroshHakha Jan 24, 2023
7c831d3
Algorithm can now construct the trainer_runner based on the policy_maps
kouroshHakha Jan 24, 2023
d3d610e
lint and clean up
kouroshHakha Jan 24, 2023
97acdb7
fixed imports
kouroshHakha Jan 24, 2023
f3ccf54
Merge branch 'policy-with-marl' into ppo-torch-trainer
kouroshHakha Jan 24, 2023
93f3ce9
fixed the unittest for ppo_rl_trainer
kouroshHakha Jan 24, 2023
c58293a
got the PPO POC running
kouroshHakha Jan 25, 2023
3f36f0d
Merge branch 'master' into ppo-torch-trainer
kouroshHakha Jan 25, 2023
77ff585
lint
kouroshHakha Jan 25, 2023
87bda01
wip
kouroshHakha Jan 25, 2023
6346a20
wip
kouroshHakha Jan 25, 2023
eb5106f
wip
kouroshHakha Jan 25, 2023
b2c01ad
multi-gpu test works now
kouroshHakha Jan 25, 2023
ea3d9c6
removed left out api get_weight()
kouroshHakha Jan 25, 2023
45830c6
get_weights() updated
kouroshHakha Jan 26, 2023
94ca772
trying out a new configuration pattern for trainer runner and rl trai…
kouroshHakha Jan 26, 2023
29ac2fb
wip
kouroshHakha Jan 26, 2023
4714e20
lint
kouroshHakha Jan 26, 2023
abd5e5e
rl_trainer tf test passes again
kouroshHakha Jan 27, 2023
a44c370
torch rl trainer test passed
kouroshHakha Jan 27, 2023
e496fcc
trainer_runner_config test works too
kouroshHakha Jan 27, 2023
477795d
tested the multigpu
kouroshHakha Jan 27, 2023
58fe5df
docstring updated
kouroshHakha Jan 27, 2023
d5bcd3b
updated the docstring
kouroshHakha Jan 27, 2023
d8841d1
wip
kouroshHakha Jan 28, 2023
026899e
renamed the classes and variables to backend
kouroshHakha Jan 28, 2023
f04e99d
renamed
kouroshHakha Jan 28, 2023
d280887
wip
kouroshHakha Jan 28, 2023
97d80b1
refactor
kouroshHakha Jan 28, 2023
85387e5
lin
kouroshHakha Jan 28, 2023
1a70b6e
fix the lint and tf_dependency test issue via adding tf stubs
kouroshHakha Jan 29, 2023
317a9fd
wip on unittest trianer_runner
kouroshHakha Jan 29, 2023
cbc9b02
wip
kouroshHakha Jan 29, 2023
869717e
Merge branch 'master' into trainer-runner-scaling-config
kouroshHakha Jan 29, 2023
defa5f1
test_trainer_runner updated to support all variations of scaling config
kouroshHakha Jan 29, 2023
e0a0bcf
removed test trainer runner local and moved it to test_trainer_runner.py
kouroshHakha Jan 29, 2023
cf4041e
fixed the test failures
kouroshHakha Jan 29, 2023
d4cd654
1. Removed tf due to flakiness from test_trainer_runner
kouroshHakha Jan 29, 2023
54eb315
removed backend class definitions
kouroshHakha Jan 29, 2023
1c826db
Removed Hyperparams class
kouroshHakha Jan 29, 2023
e8cf7e1
introed FrameworkHPs to differebntiate between tf/torch specific stuf…
kouroshHakha Jan 29, 2023
2ea3a4d
the unittests pass
kouroshHakha Jan 29, 2023
92dd832
Merge branch 'trainer-runner-scaling-config' into ppo-torch-trainer
kouroshHakha Jan 30, 2023
fd84f7a
addressed comments and fixed some introduced bug
kouroshHakha Jan 30, 2023
4c38455
Merge branch 'master' into ppo-torch-trainer
kouroshHakha Jan 30, 2023
d40c48d
fix from_worker_or_trainer renaming issue
kouroshHakha Jan 30, 2023
b0fed29
fixed tests
kouroshHakha Jan 30, 2023
b92eee9
fixed test_ppo_rl_trainer.py
kouroshHakha Jan 30, 2023
1c3ccb5
rerunning ci
kouroshHakha Jan 31, 2023
37c9fca
lint
kouroshHakha Jan 31, 2023
1687241
added TODO
kouroshHakha Jan 31, 2023
d3ee81a
empty commit
kouroshHakha Jan 31, 2023
24abebd
Merge branch 'master' into ppo-torch-trainer
kouroshHakha Jan 31, 2023
3ff0668
fixed weights to numpy
kouroshHakha Jan 31, 2023
d91eca5
[release] minor fix to pytorch_pbt_failure test when using gpu. (#32070)
xwjiang2010 Jan 31, 2023
5b209be
Merge branch 'master' into err-out-marl-env
kouroshHakha Jan 31, 2023
833e491
Merge branch 'err-out-marl-env' into ppo-torch-trainer
kouroshHakha Jan 31, 2023
e29021d
error out when no agent is passed in in the indepenent MARL case
kouroshHakha Jan 31, 2023
d113d3a
Merge branch 'err-out-marl-env' into ppo-torch-trainer
kouroshHakha Jan 31, 2023
f28a385
1. set resources for trainable 2. convert_to_numpy weights on RLTrain…
kouroshHakha Jan 31, 2023
993932f
added examples as a unittest to BUILD kite
kouroshHakha Jan 31, 2023
05c8297
fixed test name conflict
kouroshHakha Jan 31, 2023
839ff90
removed the wrong tag from docs
kouroshHakha Feb 1, 2023
320b116
Merge branch 'master' into ppo-torch-trainer
kouroshHakha Feb 1, 2023
8466fc8
fixed as test flag
kouroshHakha Feb 1, 2023
4c4f7cc
Merge branch 'master' into ppo-torch-trainer
kouroshHakha Feb 2, 2023
d78f3d5
made the sync_weights equivalent to the implementation before this PR
kouroshHakha Feb 6, 2023
7a68bb4
addressed jun's comments, created a minibatchCycleIterator
kouroshHakha Feb 7, 2023
0008833
Merge branch 'ppo-torch-trainer' of github.com:kouroshHakha/ray into …
kouroshHakha Feb 7, 2023
edbf081
added annotations
kouroshHakha Feb 7, 2023
2698973
Merge branch 'master' into ppo-torch-trainer
kouroshHakha Feb 7, 2023
4c8ce18
empty
kouroshHakha Feb 7, 2023
9f29038
empty
kouroshHakha Feb 7, 2023
b1d3f63
empty
kouroshHakha Feb 7, 2023
Files changed
46 changes: 46 additions & 0 deletions rllib/BUILD
@@ -1133,6 +1133,14 @@ py_test(
srcs = ["algorithms/ppo/tests/test_ppo_rl_module.py"]
)


py_test(
name = "test_ppo_rl_trainer",
tags = ["team:rllib", "algorithms_dir"],
size = "medium",
srcs = ["algorithms/ppo/tests/test_ppo_rl_trainer.py"]
)

# PPO Reproducibility
py_test(
name = "test_repro_ppo",
@@ -3840,6 +3848,44 @@ py_test(
]
)

# --------------------------------------------------------------------
# examples/rl_trainer directory
#
#
# Description: These are RLlib tests for the new multi-gpu enabled
# training stack via RLTrainers.
#
# NOTE: Add tests alphabetically to this list.
# --------------------------------------------------------------------

py_test(
name = "examples/rl_trainer/multi_agent_cartpole_ppo_torch",
main = "examples/rl_trainer/multi_agent_cartpole_ppo.py",
tags = ["team:rllib", "exclusive", "examples", "no-gpu"],
size = "medium",
srcs = ["examples/rl_trainer/multi_agent_cartpole_ppo.py"],
args = ["--as-test", "--framework=torch", "--num-gpus=0"]
)

py_test(
name = "examples/rl_trainer/multi_agent_cartpole_ppo_torch_gpu",
main = "examples/rl_trainer/multi_agent_cartpole_ppo.py",
tags = ["team:rllib", "exclusive", "examples", "gpu"],
size = "medium",
srcs = ["examples/rl_trainer/multi_agent_cartpole_ppo.py"],
args = ["--as-test", "--framework=torch", "--num-gpus=1"]
)


py_test(
name = "examples/rl_trainer/multi_agent_cartpole_ppo_torch_multi_gpu",
main = "examples/rl_trainer/multi_agent_cartpole_ppo.py",
tags = ["team:rllib", "exclusive", "examples", "multi-gpu"],
size = "medium",
srcs = ["examples/rl_trainer/multi_agent_cartpole_ppo.py"],
args = ["--as-test", "--framework=torch", "--num-gpus=2"]
)

# --------------------------------------------------------------------
# examples/documentation directory
#
47 changes: 38 additions & 9 deletions rllib/algorithms/algorithm.py
@@ -695,7 +695,7 @@ def setup(self, config: AlgorithmConfig) -> None:
# TODO (Kourosh): This is an interim solution where policies and modules
# co-exist. In this world we have both policy_map and MARLModule that need
# to be consistent with one another. To make a consistent parity between
# the two we need to loop throught the policy modules and create a simple
# the two we need to loop through the policy modules and create a simple
# MARLModule from the RLModule within each policy.
local_worker = self.workers.local_worker()
module_specs = {}
@@ -715,6 +715,10 @@ def setup(self, config: AlgorithmConfig) -> None:
trainer_runner_config = self.config.get_trainer_runner_config(module_spec)
self.trainer_runner = trainer_runner_config.build()

# sync the weights from local rollout worker to trainers
weights = local_worker.get_weights()
self.trainer_runner.set_weights(weights)
Review comment (Member): Oh, the weights are formatted correctly already? Cool cool cool.

Review comment (Contributor Author): Yes, the local worker returns a mapping from policy id to the weights of the underlying RL module; the trainer runner returns the MARL module weights, which by default map module id to RL module weights.

Review comment (Contributor): Typo above: "throught" -> "through".
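As a minimal sketch of the two layouts described in this thread: the outer keys match on both sides, which is why setup() can hand the rollout worker's weights straight to the trainer runner. The policy/module ids, parameter names, and shapes below are invented for illustration.

import numpy as np

# Rollout-worker side: policy id -> that policy's RLModule weights.
rollout_worker_weights = {
    "default_policy": {
        "encoder.weight": np.zeros((32, 4), dtype=np.float32),
        "pi.bias": np.zeros(2, dtype=np.float32),
    },
}

# Trainer-runner side: MARL module weights, i.e. module id -> RLModule weights.
# Because the outer keys line up with the policy ids, the sync above is just
# trainer_runner.set_weights(local_worker.get_weights()).
trainer_runner_weights = {
    "default_policy": {
        "encoder.weight": np.zeros((32, 4), dtype=np.float32),
        "pi.bias": np.zeros(2, dtype=np.float32),
    },
}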


# Run `on_algorithm_init` callback after initialization is done.
self.callbacks.on_algorithm_init(algorithm=self)

@@ -858,7 +862,7 @@ def evaluate(
# Sync weights to the evaluation WorkerSet.
if self.evaluation_workers is not None:
self.evaluation_workers.sync_weights(
from_worker=self.workers.local_worker()
from_worker_or_trainer=self.workers.local_worker()
)
self._sync_filters_if_needed(
from_worker=self.workers.local_worker(),
@@ -1376,11 +1380,11 @@ def training_step(self) -> ResultDict:
# TODO (Avnish): Implement this on trainer_runner.get_weights().
# TODO (Kourosh): figure out how we are going to sync MARLModule
# weights to MARLModule weights under the policy_map objects?
from_worker = None
from_worker_or_trainer = None
if self.config._enable_rl_trainer_api:
from_worker = self.trainer_runner
from_worker_or_trainer = self.trainer_runner
self.workers.sync_weights(
from_worker=from_worker,
from_worker_or_trainer=from_worker_or_trainer,
policies=list(train_results.keys()),
global_vars=global_vars,
)
@@ -2132,10 +2136,13 @@ def default_resource_request(
eval_cf.freeze()

# resources for local worker
local_worker = {
"CPU": cf.num_cpus_for_local_worker,
"GPU": 0 if cf._fake_gpus else cf.num_gpus,
}
if cf._enable_rl_trainer_api:
local_worker = {"CPU": cf.num_cpus_for_local_worker, "GPU": 0}
else:
local_worker = {
"CPU": cf.num_cpus_for_local_worker,
"GPU": 0 if cf._fake_gpus else cf.num_gpus,
}

bundles = [local_worker]

@@ -2179,6 +2186,28 @@ def default_resource_request(

bundles += rollout_workers + evaluation_bundle

if cf._enable_rl_trainer_api:
# resources for the trainer
if cf.num_trainer_workers == 0:
# if num_trainer_workers is 0, then we need to allocate one gpu if
# num_gpus_per_trainer_worker is greater than 0.
trainer_bundle = [
{
"CPU": cf.num_cpus_per_trainer_worker,
"GPU": cf.num_gpus_per_trainer_worker,
}
]
else:
trainer_bundle = [
{
"CPU": cf.num_cpus_per_trainer_worker,
"GPU": cf.num_gpus_per_trainer_worker,
}
for _ in range(cf.num_trainer_workers)
]

bundles += trainer_bundle
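For a concrete picture of what this branch adds, here is a rough sketch of the bundle list for a config with two rollout workers and two trainer workers; the per-worker CPU/GPU numbers are assumptions picked for the example, not defaults, and evaluation workers are left out.

# Assumed example config: num_cpus_for_local_worker=1, num_rollout_workers=2,
# num_cpus_per_worker=1, num_trainer_workers=2,
# num_cpus_per_trainer_worker=1, num_gpus_per_trainer_worker=1.
bundles = [
    {"CPU": 1, "GPU": 0},  # local worker (GPU pinned to 0 when _enable_rl_trainer_api is on)
    {"CPU": 1, "GPU": 0},  # rollout worker 1
    {"CPU": 1, "GPU": 0},  # rollout worker 2
    {"CPU": 1, "GPU": 1},  # trainer worker 1
    {"CPU": 1, "GPU": 1},  # trainer worker 2
]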

# Return PlacementGroupFactory containing all needed resources
# (already properly defined as device bundles).
return PlacementGroupFactory(
7 changes: 7 additions & 0 deletions rllib/algorithms/algorithm_config.py
@@ -883,6 +883,13 @@ def validate(self) -> None:
rl_module_class_path = self.get_default_rl_module_class()
self.rl_module_class = _resolve_class_path(rl_module_class_path)

# make sure the resource requirements for trainer runner is valid
if self.num_trainer_workers == 0 and self.num_gpus_per_worker > 1:
raise ValueError(
"num_gpus_per_worker must be 0 (cpu) or 1 (gpu) when using local mode "
"(i.e. num_trainer_workers = 0)"
)

# resolve rl_trainer class
if self._enable_rl_trainer_api and self.rl_trainer_class is None:
rl_trainer_class_path = self.get_default_rl_trainer_class()
82 changes: 71 additions & 11 deletions rllib/algorithms/ppo/ppo.py
@@ -16,6 +16,7 @@
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.pg import PGConfig
from ray.rllib.algorithms.ppo.ppo_rl_trainer_config import PPORLTrainerHPs
from ray.rllib.execution.rollout_ops import (
standardize_fields,
)
@@ -42,6 +43,8 @@

if TYPE_CHECKING:
from ray.rllib.core.rl_module import RLModule
from ray.rllib.core.rl_trainer.rl_trainer import RLTrainer


logger = logging.getLogger(__name__)

@@ -89,6 +92,7 @@ def __init__(self, algo_class=None):
# fmt: off
# __sphinx_doc_begin__
# PPO specific settings:
self._rl_trainer_hps = PPORLTrainerHPs()
self.use_critic = True
self.use_gae = True
self.lambda_ = 1.0
@@ -131,6 +135,17 @@ def get_default_rl_module_class(self) -> Union[Type["RLModule"], str]:
else:
raise ValueError(f"The framework {self.framework_str} is not supported.")

@override(AlgorithmConfig)
def get_default_rl_trainer_class(self) -> Union[Type["RLTrainer"], str]:
if self.framework_str == "torch":
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_trainer import (
PPOTorchRLTrainer,
)

return PPOTorchRLTrainer
else:
raise ValueError(f"The framework {self.framework_str} is not supported.")

@override(AlgorithmConfig)
def training(
self,
@@ -201,12 +216,16 @@ def training(
self.lr_schedule = lr_schedule
if use_critic is not NotProvided:
self.use_critic = use_critic
# TODO (Kourosh) This is experimental. Set rl_trainer_hps parameters as
# well. Don't forget to remove .use_critic from algorithm config.
self._rl_trainer_hps.use_critic = use_critic
if use_gae is not NotProvided:
self.use_gae = use_gae
if lambda_ is not NotProvided:
self.lambda_ = lambda_
if kl_coeff is not NotProvided:
self.kl_coeff = kl_coeff
self._rl_trainer_hps.kl_coeff = kl_coeff
if sgd_minibatch_size is not NotProvided:
self.sgd_minibatch_size = sgd_minibatch_size
if num_sgd_iter is not NotProvided:
@@ -215,18 +234,24 @@
self.shuffle_sequences = shuffle_sequences
if vf_loss_coeff is not NotProvided:
self.vf_loss_coeff = vf_loss_coeff
self._rl_trainer_hps.vf_loss_coeff = vf_loss_coeff
if entropy_coeff is not NotProvided:
self.entropy_coeff = entropy_coeff
self._rl_trainer_hps.entropy_coeff = entropy_coeff
if entropy_coeff_schedule is not NotProvided:
self.entropy_coeff_schedule = entropy_coeff_schedule
self._rl_trainer_hps.entropy_coeff_schedule = entropy_coeff_schedule
if clip_param is not NotProvided:
self.clip_param = clip_param
self._rl_trainer_hps.clip_param = clip_param
if vf_clip_param is not NotProvided:
self.vf_clip_param = vf_clip_param
self._rl_trainer_hps.vf_clip_param = vf_clip_param
if grad_clip is not NotProvided:
self.grad_clip = grad_clip
if kl_target is not NotProvided:
self.kl_target = kl_target
self._rl_trainer_hps.kl_target = kl_target

return self

@@ -366,14 +391,39 @@ def training_step(self) -> ResultDict:
train_batch = standardize_fields(train_batch, ["advantages"])
# Train
if self.config._enable_rl_trainer_api:
train_results = self.trainer_runner.update(train_batch)
# TODO (Kourosh) Clearly define what train_batch_size
# vs. sgd_minibatch_size and num_sgd_iter is in the config.

Review comment (Contributor): +100, let's change the naming of these params.

# TODO (Kourosh) Do this inside the RL Trainer so
# that we don't have to do this back and forth
# communication between driver and the remote
# trainer workers

train_results = self.trainer_runner.fit(
train_batch,
minibatch_size=self.config.sgd_minibatch_size,
num_iters=self.config.num_sgd_iter,
)

elif self.config.simple_optimizer:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)
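
As context for the minibatch_size / num_iters arguments passed to trainer_runner.fit() above, here is a rough sketch of the minibatch cycling the trainer side performs (this PR adds a minibatch cycle iterator for it); it is an illustration of the idea only, with shuffling and multi-agent handling omitted.

def minibatch_cycle(batch, minibatch_size, num_iters):
    # Make num_iters passes (SGD epochs) over the batch, yielding fixed-size
    # chunks each pass. `batch` is anything sliceable, e.g. a list of samples.
    for _ in range(num_iters):
        for start in range(0, len(batch), minibatch_size):
            yield batch[start:start + minibatch_size]


# Example: a batch of 8 samples, minibatches of 3, two passes -> 6 chunks.
chunks = list(minibatch_cycle(list(range(8)), minibatch_size=3, num_iters=2))
assert len(chunks) == 6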

policies_to_update = list(train_results.keys())
if self.config._enable_rl_trainer_api:
# the train results's loss keys are pids to their loss values. But we also
# return a total_loss key at the same level as the pid keys. So we need to
# subtract that to get the total set of pids to update.
# TODO (Kourosh): We need to make a better design for the hierarchy of the
# train results, so that all the policy ids end up in the same level.
# TODO (Kourosh): We should also not be using train_results as a message
# passing medium to infer whcih policies to update. We could use
# policies_to_train variable that is given by the user to infer this.
policies_to_update = set(train_results["loss"].keys()) - {"total_loss"}
Review comment (Member): Maybe we can return a tuple that contains the total-loss stats along with the per-policy stats, so we don't have to do this weird bit of logic here. If we add a TODO here, I think that'd be enough for now.

Review comment (Contributor Author): Yes, I don't like this either. The format of what we return has to change anyway, so I'll revisit it then regardless.

Review comment (Contributor): Can we add a TODO here describing why we are doing this odd set subtraction and how we will fix it in the future?

Review comment (Member): I actually think you should be explicit and not rely on certain keys in the result dict to tell you which policies need to be updated. train_results should not be used as control messages, basically.

Review comment (Contributor Author): I'd like to hear more about what you mean by being more explicit. I was planning to revisit the train_results structure to remove these requirements in the next round of updates, but I'd love to hear your thoughts on how it should ideally look.

Review comment (Member): I'm not opinionated about how the result dict should look, but I do think we shouldn't use it as a control message, meaning "I can only get my policies updated if I add something to the result dict." Those two things probably shouldn't go together.

Review comment (Contributor Author, Feb 6, 2023): I see your point, and I actually have a better idea. Right now the trainer_runner doesn't even restrict updates to the policies that are allowed (via policies_to_train). With that variable around, I can infer the policies to update, so the returned results won't be used as a message-passing medium. I'll add a TODO with a better design guideline in the next PR.
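
To make the set subtraction above concrete, here is a hedged sketch of the result layout this branch appears to rely on; the policy ids and loss values are invented, and the exact keys are taken from this diff rather than a documented contract.

# Assumed shape of train_results under the new stack: "loss" holds a scalar
# aggregate plus one stats dict per policy id.
train_results = {
    "loss": {
        "total_loss": 1.23,
        "default_policy": {"total_loss": 0.61, "mean_kl_loss": 0.02},
        "opponent_policy": {"total_loss": 0.62, "mean_kl_loss": 0.05},
    },
}

# Dropping the aggregate key leaves exactly the set of policy ids to update.
policies_to_update = set(train_results["loss"].keys()) - {"total_loss"}
assert policies_to_update == {"default_policy", "opponent_policy"}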

else:
policies_to_update = list(train_results.keys())

# TODO (Kourosh): num_grad_updates per each policy should be accessible via
# train_results
global_vars = {
"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
"num_grad_updates_per_policy": {
@@ -384,24 +434,34 @@

# Update weights - after learning on the local worker - on all remote
# workers.
if self.workers.num_remote_workers() > 0:
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
from_worker = None
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
if self.workers.num_remote_workers() > 0:
from_worker_or_trainer = None
if self.config._enable_rl_trainer_api:
from_worker = self.trainer_runner
# sync weights from trainer_runner to all rollout workers
from_worker_or_trainer = self.trainer_runner
self.workers.sync_weights(
from_worker=from_worker,
policies=list(train_results.keys()),
from_worker_or_trainer=from_worker_or_trainer,
policies=policies_to_update,
global_vars=global_vars,
)
elif self.config._enable_rl_trainer_api:
weights = self.trainer_runner.get_weights()
self.workers.local_worker().set_weights(weights)

if self.config._enable_rl_trainer_api:
kl_dict = {
pid: pinfo[LEARNER_STATS_KEY].get("kl")
for pid, pinfo in train_results.items()
# TODO (Kourosh): Train results don't match the old format. The thing
# that used to be under `kl` is now under `mean_kl_loss`. Fix this. Do
# we need get here?
pid: train_results["loss"][pid].get("mean_kl_loss")
Review comment (Contributor): What's the case where this KL key is NOT present in the ["loss"][pid] sub-dict? Do we need the .get() here?

Review comment (Contributor Author): Captured it in a TODO.

for pid in policies_to_update
}
# triggers a special update method on RLOptimizer to update the KL values.
self.trainer_runner.additional_update(kl_values=kl_dict)
self.trainer_runner.additional_update(
sampled_kl_values=kl_dict,
timestep=self._counters[NUM_AGENT_STEPS_SAMPLED],
)

return train_results
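
For background on the additional_update(sampled_kl_values=...) call above: PPO's adaptive-KL variant typically nudges the KL coefficient toward a target after each training round. The sketch below shows the common rule; the function name and the 2x / 0.5x thresholds are illustrative assumptions, not necessarily what PPOTorchRLTrainer implements.

def adapt_kl_coeff(kl_coeff: float, sampled_kl: float, kl_target: float) -> float:
    # Increase the penalty when the observed KL overshoots the target,
    # relax it when the KL stays well below the target.
    if sampled_kl > 2.0 * kl_target:
        return kl_coeff * 1.5
    if sampled_kl < 0.5 * kl_target:
        return kl_coeff * 0.5
    return kl_coeff


# Example: sampled KL well above the target -> coefficient grows by 1.5x.
assert adapt_kl_coeff(0.2, sampled_kl=0.03, kl_target=0.01) == 0.2 * 1.5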

21 changes: 21 additions & 0 deletions rllib/algorithms/ppo/ppo_rl_trainer_config.py
@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Optional, Union

from ray.rllib.core.rl_trainer.rl_trainer import RLTrainerHPs


@dataclass
class PPORLTrainerHPs(RLTrainerHPs):
"""Hyperparameters for the PPO RL Trainer"""

kl_coeff: float = 0.2
kl_target: float = 0.01
use_critic: bool = True
clip_param: float = 0.3
vf_clip_param: float = 10.0
entropy_coeff: float = 0.0
vf_loss_coeff: float = 1.0

# experimental placeholder for things that could be part of the base RLTrainerHPs
lr_schedule: Optional[List[List[Union[int, float]]]] = None
entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = None
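
A small usage sketch for the dataclass above, mirroring how PPOConfig.training() populates _rl_trainer_hps in this PR. The values are arbitrary, and the import path is the one introduced by this diff, so it may move in later Ray versions.

from ray.rllib.algorithms.ppo.ppo_rl_trainer_config import PPORLTrainerHPs

# Fields that are not passed keep the defaults defined above.
hps = PPORLTrainerHPs(kl_coeff=0.3, clip_param=0.2, vf_clip_param=10.0)
assert hps.kl_target == 0.01 and hps.entropy_coeff == 0.0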