[RLlib] - Fix APPO RLModule inference-only problems. (#45111)

simonsays1980 authored May 3, 2024
1 parent 6ab48be commit 45d5640
Showing 4 changed files with 64 additions and 22 deletions.
26 changes: 26 additions & 0 deletions rllib/algorithms/appo/appo_rl_module.py
@@ -0,0 +1,26 @@
"""
This file holds framework-agnostic components for APPO's RLModules.
"""

import abc

from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule
from ray.rllib.utils.annotations import ExperimentalAPI

# TODO (simon): Write a light-weight version of this class for the `TFRLModule`


@ExperimentalAPI
class APPORLModule(PPORLModule, abc.ABC):
def setup(self):
super().setup()

# If the module is not for inference only, set up the target networks.
if not self.inference_only:
catalog = self.config.get_catalog()
# Old pi and old encoder are the "target networks" that are used for
# the stabilization of the updates of the current pi and encoder.
self.old_pi = catalog.build_pi_head(framework=self.framework)
self.old_encoder = catalog.build_actor_critic_encoder(
framework=self.framework
)
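For orientation, here is a minimal, framework-free sketch (an illustration, not part of the commit) of the pattern this new base class establishes: modules built with `inference_only=True` skip the frozen target networks entirely, so sampling workers carry fewer parameters than the learner.

class ModuleSketch:
    """Illustrative stand-in for an RLModule with an `inference_only` flag."""

    def __init__(self, inference_only: bool):
        self.inference_only = inference_only
        self.pi = "pi-head"       # stand-in for the current policy head
        self.encoder = "encoder"  # stand-in for the actor-critic encoder
        if not inference_only:
            # Learner-only copies that stabilize the updates of `pi`/`encoder`.
            self.old_pi = "old-pi-head"
            self.old_encoder = "old-encoder"

# Sampling workers get the light module; the learner gets the full one.
assert not hasattr(ModuleSketch(inference_only=True), "old_pi")
assert hasattr(ModuleSketch(inference_only=False), "old_pi")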
20 changes: 10 additions & 10 deletions rllib/algorithms/appo/tf/appo_tf_rl_module.py
@@ -1,6 +1,7 @@
 from typing import List
 
 from ray.rllib.algorithms.appo.appo import OLD_ACTION_DIST_LOGITS_KEY
+from ray.rllib.algorithms.appo.appo_rl_module import APPORLModule
 from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.models.base import ACTOR
@@ -15,18 +16,17 @@
 _, tf, _ = try_import_tf()
 
 
-class APPOTfRLModule(PPOTfRLModule, RLModuleWithTargetNetworksInterface):
+class APPOTfRLModule(PPOTfRLModule, RLModuleWithTargetNetworksInterface, APPORLModule):
     @override(PPOTfRLModule)
     def setup(self):
         super().setup()
-        catalog = self.config.get_catalog()
-        # old pi and old encoder are the "target networks" that are used for
-        # the stabilization of the updates of the current pi and encoder.
-        self.old_pi = catalog.build_pi_head(framework=self.framework)
-        self.old_encoder = catalog.build_actor_critic_encoder(framework=self.framework)
-        self.old_pi.set_weights(self.pi.get_weights())
-        self.old_encoder.set_weights(self.encoder.get_weights())
-        self.old_pi.trainable = False
-        self.old_encoder.trainable = False
+        # If the module is not for inference only, set up the target networks.
+        if not self.inference_only:
+            self.old_pi.set_weights(self.pi.get_weights())
+            self.old_encoder.set_weights(self.encoder.get_weights())
+            self.old_pi.trainable = False
+            self.old_encoder.trainable = False
 
     @override(RLModuleWithTargetNetworksInterface)
     def get_target_network_pairs(self):
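The `get_target_network_pairs()` method kept as context above is how the learner discovers these frozen networks. A hedged sketch (our illustration, not RLlib's actual learner code) of how such (target, current) pairs are typically consumed:

def sync_target_networks(module) -> None:
    # Periodically copy current weights into the frozen targets; the
    # (target, current) ordering is an assumption about the pairs returned
    # by `get_target_network_pairs()`.
    for target_net, current_net in module.get_target_network_pairs():
        # Keras-style weight copy, matching the `set_weights`/`get_weights`
        # calls in the TF module above.
        target_net.set_weights(current_net.get_weights())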
34 changes: 24 additions & 10 deletions rllib/algorithms/appo/torch/appo_torch_rl_module.py
@@ -3,6 +3,7 @@
 from ray.rllib.algorithms.appo.appo import (
     OLD_ACTION_DIST_LOGITS_KEY,
 )
+from ray.rllib.algorithms.appo.appo_rl_module import APPORLModule
 from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.models.base import ACTOR
@@ -14,19 +15,20 @@
 from ray.rllib.utils.nested_dict import NestedDict
 
 
-class APPOTorchRLModule(PPOTorchRLModule, RLModuleWithTargetNetworksInterface):
+class APPOTorchRLModule(
+    PPOTorchRLModule, RLModuleWithTargetNetworksInterface, APPORLModule
+):
     @override(PPOTorchRLModule)
     def setup(self):
         super().setup()
-        catalog = self.config.get_catalog()
-        # Old pi and old encoder are the "target networks" that are used for
-        # the stabilization of the updates of the current pi and encoder.
-        self.old_pi = catalog.build_pi_head(framework=self.framework)
-        self.old_encoder = catalog.build_actor_critic_encoder(framework=self.framework)
-        self.old_pi.load_state_dict(self.pi.state_dict())
-        self.old_encoder.load_state_dict(self.encoder.state_dict())
-        self.old_pi.trainable = False
-        self.old_encoder.trainable = False
+        # If the module is not for inference only, update the target networks.
+        if not self.inference_only:
+            self.old_pi.load_state_dict(self.pi.state_dict())
+            self.old_encoder.load_state_dict(self.encoder.state_dict())
+            # We do not train the targets.
+            self.old_pi.requires_grad_(False)
+            self.old_encoder.requires_grad_(False)
 
     @override(RLModuleWithTargetNetworksInterface)
     def get_target_network_pairs(self):
@@ -47,3 +49,15 @@ def _forward_train(self, batch: NestedDict):
         old_action_dist_logits = self.old_pi(old_pi_inputs_encoded)
         outs[OLD_ACTION_DIST_LOGITS_KEY] = old_action_dist_logits
         return outs
+
+    @override(PPOTorchRLModule)
+    def _set_inference_only_state_dict_keys(self) -> None:
+        # Get the model_parameters from the `PPOTorchRLModule`.
+        super()._set_inference_only_state_dict_keys()
+        # Get the model_parameters.
+        state_dict = self.state_dict()
+        # Note, these keys are only known to the learner module. Furthermore,
+        # we want this to be run once during setup and not for each worker.
+        self._inference_only_state_dict_keys["unexpected_keys"].extend(
+            [name for name in state_dict if "old" in name]
+        )
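The keys recorded here let RLlib drop learner-only parameters (value function, critic encoder, and now APPO's `old_*` target networks) before shipping weights to inference-only workers. A simplified sketch of that filtering; the helper name and the exact direction of the `expected_keys` remapping are our assumptions:

def filter_state_dict_for_inference(state_dict, keys):
    # Drop entries an inference-only module does not have.
    filtered = {
        name: tensor
        for name, tensor in state_dict.items()
        if name not in keys["unexpected_keys"]
    }
    # Remap learner-side names to the names the inference module expects.
    for learner_name, inference_name in keys.get("expected_keys", {}).items():
        if learner_name in filtered:
            filtered[inference_name] = filtered.pop(learner_name)
    return filtered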
6 changes: 4 additions & 2 deletions rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -141,7 +141,9 @@ def _set_inference_only_state_dict_keys(self) -> None:
         # Note, these keys are only known to the learner module. Furthermore,
         # we want this to be run once during setup and not for each worker.
         self._inference_only_state_dict_keys["unexpected_keys"] = [
-            name for name in state_dict if "vf" in name or "critic_encoder" in name
+            name
+            for name in state_dict
+            if "vf" in name or name.startswith("encoder.critic_encoder")
         ]
         # Do we use a separate encoder for the actor and critic?
         # if not self.config.model_config_dict.get("vf_share_layers", True):
@@ -153,7 +155,7 @@ def _set_inference_only_state_dict_keys(self) -> None:
         self._inference_only_state_dict_keys["expected_keys"] = {
             name: name.replace("actor_encoder", "encoder")
             for name in state_dict
-            if "actor_encoder" in name
+            if name.startswith("encoder.actor_encoder")
         }
 
     @override(TorchRLModule)
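Why the switch from substring tests to `startswith` matters: with APPO's frozen target encoder now in the learner's state dict, a plain `in` check would also catch `old_encoder.*` parameters, which the APPO subclass handles separately via its own "old" filter. A small demonstration with assumed key names:

keys = [
    "encoder.actor_encoder.net.0.weight",      # current actor encoder
    "encoder.critic_encoder.net.0.weight",     # critic encoder (learner-only)
    "old_encoder.actor_encoder.net.0.weight",  # APPO target net (learner-only)
]
substring_hits = [k for k in keys if "actor_encoder" in k]
prefix_hits = [k for k in keys if k.startswith("encoder.actor_encoder")]
assert "old_encoder.actor_encoder.net.0.weight" in substring_hits   # wrong match
assert "old_encoder.actor_encoder.net.0.weight" not in prefix_hits  # now excluded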
