[RLlib] Single agent RLTrainer made easy (#31802)
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
kouroshHakha authored Jan 20, 2023
1 parent fdc2722 commit f31a0ad
Showing 3 changed files with 80 additions and 40 deletions.
68 changes: 66 additions & 2 deletions rllib/core/rl_trainer/rl_trainer.py
@@ -19,7 +19,7 @@
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.core.rl_module.rl_module import RLModule, ModuleID
 from ray.rllib.core.rl_module.marl_module import MultiAgentRLModule
-from ray.rllib.policy.sample_batch import MultiAgentBatch
+from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
 from ray.rllib.utils.nested_dict import NestedDict
 from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.typing import TensorType
Expand Down Expand Up @@ -145,7 +145,6 @@ def configure_optimizers(self) -> ParamOptimizerPairs:
"""

@abc.abstractmethod
def compute_loss(
self,
*,
@@ -178,6 +177,48 @@ def compute_loss(
         # should find a way to allow them to specify single-agent losses as well,
         # without having to think about one extra layer of hierarchy for module ids.

+        loss_total = None
+        results_all_modules = {}
+        for module_id in fwd_out:
+            module_batch = batch[module_id]
+            module_fwd_out = fwd_out[module_id]
+
+            module_results = self._compute_loss_per_module(
+                module_id, module_batch, module_fwd_out
+            )
+            results_all_modules[module_id] = module_results
+            loss = module_results[self.TOTAL_LOSS_KEY]
+
+            if loss_total is None:
+                loss_total = loss
+            else:
+                loss_total += loss
+
+        results_all_modules[self.TOTAL_LOSS_KEY] = loss_total
+
+        return results_all_modules
+
+    def _compute_loss_per_module(
+        self, module_id: str, batch: SampleBatch, fwd_out: Mapping[str, TensorType]
+    ) -> Mapping[str, Any]:
+        """Computes the loss for a single module.
+
+        Think of this as computing loss for a
+        single agent. For multi-agent use-cases that require more complicated
+        computation for loss, consider overriding the `compute_loss` method instead.
+
+        Args:
+            module_id: The id of the module.
+            batch: The sample batch for this particular module.
+            fwd_out: The output of the forward pass for this particular module.
+
+        Returns:
+            A dictionary of losses. NOTE that the dictionary
+            must contain one protected key "total_loss" which will be used for
+            computing gradients through.
+        """
+        raise NotImplementedError
+
     def postprocess_gradients(
         self, gradients_dict: Mapping[str, Any]
     ) -> Mapping[str, Any]:
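
For intuition, here is a minimal, self-contained sketch (not part of this commit) of what the new default compute_loss() aggregation yields. Module ids and loss values are made up, and plain floats stand in for framework tensors:

per_module_results = {
    "policy_1": {"total_loss": 0.7},
    "policy_2": {"total_loss": 1.3},
}

loss_total = None
results_all_modules = {}
for module_id, module_results in per_module_results.items():
    results_all_modules[module_id] = module_results
    loss = module_results["total_loss"]  # the protected TOTAL_LOSS_KEY
    loss_total = loss if loss_total is None else loss_total + loss

results_all_modules["total_loss"] = loss_total
print(results_all_modules["total_loss"])  # 2.0: the per-module losses summed

The per-module results are reported unchanged, while the protected "total_loss" entry sums the module losses so gradients can be taken through a single scalar.
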
@@ -288,6 +329,29 @@ def additional_update(self, *args, **kwargs) -> Mapping[str, Any]:
         Returns:
             A dictionary of results from the update
         """
+        results_all_modules = {}
+        for module_id in self._module.keys():
+            module_results = self._additional_update_per_module(
+                module_id, *args, **kwargs
+            )
+            results_all_modules[module_id] = module_results
+
+        return results_all_modules
+
+    def _additional_update_per_module(
+        self, module_id: str, *args, **kwargs
+    ) -> Mapping[str, Any]:
+        """Apply additional non-gradient based updates for a single module.
+
+        Args:
+            module_id: The id of the module to update.
+            *args: Arguments to use for the update.
+            **kwargs: Keyword arguments to use for the additional update.
+
+        Returns:
+            A dictionary of results from the update
+        """
+
+        raise NotImplementedError
+
     @abc.abstractmethod
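
The same fan-out pattern now applies to non-gradient updates. Below is a hypothetical sketch of a trainer using the new per-module hook; ToyTrainer, entropy_coeff, and decay are illustrative stand-ins, not RLlib API:

class ToyTrainer:
    def __init__(self):
        # Two made-up module ids, mirroring the keys of self._module above.
        self._module = {"policy_1": None, "policy_2": None}
        self.entropy_coeff = {mid: 0.01 for mid in self._module}

    def additional_update(self, *args, **kwargs):
        # Mirrors the new default implementation: fan out to each module.
        return {
            module_id: self._additional_update_per_module(module_id, *args, **kwargs)
            for module_id in self._module.keys()
        }

    def _additional_update_per_module(self, module_id, decay=0.99, **kwargs):
        # A non-gradient-based update, e.g. decaying an exploration coefficient.
        self.entropy_coeff[module_id] *= decay
        return {"entropy_coeff": self.entropy_coeff[module_id]}

print(ToyTrainer().additional_update(decay=0.5))
# {'policy_1': {'entropy_coeff': 0.005}, 'policy_2': {'entropy_coeff': 0.005}}
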
25 changes: 7 additions & 18 deletions rllib/core/testing/tf/bc_rl_trainer.py
@@ -3,28 +3,17 @@

 from ray.rllib.core.rl_trainer.tf.tf_rl_trainer import TfRLTrainer
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.sample_batch import MultiAgentBatch
 from ray.rllib.core.testing.testing_trainer import BaseTestingTrainer
+from ray.rllib.utils.typing import TensorType


 class BCTfRLTrainer(TfRLTrainer, BaseTestingTrainer):
-    def compute_loss(
-        self, fwd_out: MultiAgentBatch, batch: MultiAgentBatch
+    def _compute_loss_per_module(
+        self, module_id: str, batch: SampleBatch, fwd_out: Mapping[str, TensorType]
     ) -> Mapping[str, Any]:
-
-        loss_dict = {}
-        loss_total = None
-        for module_id in fwd_out:
-            action_dist = fwd_out[module_id]["action_dist"]
-            loss = -tf.math.reduce_mean(
-                action_dist.log_prob(batch[module_id][SampleBatch.ACTIONS])
-            )
-            loss_dict[module_id] = loss
-            if loss_total is None:
-                loss_total = loss
-            else:
-                loss_total += loss
-
-        loss_dict[self.TOTAL_LOSS_KEY] = loss_total
-
-        return loss_dict
+        action_dist = fwd_out["action_dist"]
+        loss = -tf.math.reduce_mean(action_dist.log_prob(batch[SampleBatch.ACTIONS]))
+        return {self.TOTAL_LOSS_KEY: loss}
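
For reference, a standalone sketch of the behavioral-cloning loss this method computes, here with a toy categorical action distribution built via tensorflow_probability; the logits and actions are made up:

import tensorflow as tf
import tensorflow_probability as tfp

# Toy stand-ins for fwd_out["action_dist"] and batch[SampleBatch.ACTIONS].
logits = tf.constant([[2.0, 0.5], [0.1, 1.5]])
action_dist = tfp.distributions.Categorical(logits=logits)
actions = tf.constant([0, 1])

# Mean negative log-likelihood of the actions taken in the batch.
loss = -tf.math.reduce_mean(action_dist.log_prob(actions))
print(float(loss))
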
27 changes: 7 additions & 20 deletions rllib/core/testing/torch/bc_rl_trainer.py
Expand Up @@ -2,29 +2,16 @@
from typing import Any, Mapping

from ray.rllib.core.rl_trainer.torch.torch_rl_trainer import TorchRLTrainer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.core.testing.testing_trainer import BaseTestingTrainer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import TensorType


class BCTorchRLTrainer(TorchRLTrainer, BaseTestingTrainer):
def compute_loss(
self, fwd_out: MultiAgentBatch, batch: MultiAgentBatch
def _compute_loss_per_module(
self, module_id: str, batch: SampleBatch, fwd_out: Mapping[str, TensorType]
) -> Mapping[str, Any]:

loss_dict = {}
loss_total = None
for module_id in fwd_out:
action_dist = fwd_out[module_id]["action_dist"]
loss = -torch.mean(
action_dist.log_prob(batch[module_id][SampleBatch.ACTIONS])
)
loss_dict[module_id] = loss
if loss_total is None:
loss_total = loss
else:
loss_total += loss

loss_dict[self.TOTAL_LOSS_KEY] = loss_total

return loss_dict
action_dist = fwd_out["action_dist"]
loss = -torch.mean(action_dist.log_prob(batch[SampleBatch.ACTIONS]))
return {self.TOTAL_LOSS_KEY: loss}
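
And the torch analogue of the same negative log-likelihood loss, again with toy stand-ins for fwd_out["action_dist"] and batch[SampleBatch.ACTIONS]:

import torch

# Made-up logits and actions, for illustration only.
logits = torch.tensor([[2.0, 0.5], [0.1, 1.5]])
action_dist = torch.distributions.Categorical(logits=logits)
actions = torch.tensor([0, 1])

# Mean negative log-likelihood, identical in spirit to the tf version above.
loss = -torch.mean(action_dist.log_prob(actions))
print(loss.item())
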
