
[RLlib] New API stack: (Multi)RLModule overhaul vol 03 (Introduce generic _forward to further simplify the user experience). #47889
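The user-facing core of this PR is that an RLModule can now implement one generic `_forward()` that serves inference, exploration, and training alike; the specialized `_forward_inference()`, `_forward_exploration()`, and `_forward_train()` overrides remain available for modules whose behavior must differ per phase. A minimal sketch of the new pattern, modeled on the `BCTorchRLModule` rewrite below (the class name, hard-coded layer sizes, and input dimensions are illustrative assumptions, not part of this PR):

```python
from typing import Any, Dict

import torch

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.utils.annotations import override


class MyDiscretePolicyModule(TorchRLModule):
    """Hypothetical RLModule whose single `_forward()` serves every phase."""

    @override(RLModule)
    def setup(self):
        # Placeholder sizes; a real module would derive these from its
        # observation/action spaces or build its networks via the catalog.
        self.policy_net = torch.nn.Sequential(
            torch.nn.Linear(4, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 2),
        )

    @override(RLModule)
    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        # One code path for inference, exploration, and training.
        logits = self.policy_net(batch[Columns.OBS])
        return {Columns.ACTION_DIST_INPUTS: logits}
```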

13 changes: 8 additions & 5 deletions rllib/algorithms/appo/torch/appo_torch_learner.py
@@ -14,7 +14,7 @@
)
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.core.rl_module.apis import TargetNetworkAPI, ValueFunctionAPI
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
@@ -35,6 +35,10 @@ def compute_loss_for_module(
batch: Dict,
fwd_out: Dict[str, TensorType],
) -> TensorType:
module = self.module[module_id].unwrapped()
assert isinstance(module, TargetNetworkAPI)
assert isinstance(module, ValueFunctionAPI)

# TODO (sven): Now that we do the +1ts trick to be less vulnerable about
# bootstrap values at the end of rollouts in the new stack, we might make
# this a more flexible, configurable parameter for users, e.g.
@@ -51,10 +55,9 @@
)
size_loss_mask = torch.sum(loss_mask)

module = self.module[module_id].unwrapped()
assert isinstance(module, TargetNetworkAPI)

values = fwd_out[Columns.VF_PREDS]
values = module.compute_values(
batch, embeddings=fwd_out.get(Columns.EMBEDDINGS)
)

action_dist_cls_train = module.get_train_action_dist_cls()
target_policy_dist = action_dist_cls_train.from_logits(
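The learner-side counterpart of this overhaul, repeated in the IMPALA and MARWIL torch learners below, is that value predictions are no longer read out of `fwd_out[Columns.VF_PREDS]`; the loss asks the unwrapped module, which must implement `ValueFunctionAPI`, to compute them, reusing the embeddings from the train forward pass when available. A condensed sketch of that pattern inside `compute_loss_for_module()` (the surrounding loss math is omitted; `Columns` and `ValueFunctionAPI` are imported as in the file header above):

```python
# Sketch only: `self`, `module_id`, `batch`, and `fwd_out` are the objects
# available inside a torch Learner's `compute_loss_for_module()`.
module = self.module[module_id].unwrapped()
assert isinstance(module, ValueFunctionAPI)

# `embeddings` may be None, in which case the module re-encodes `batch` itself.
values = module.compute_values(
    batch, embeddings=fwd_out.get(Columns.EMBEDDINGS)
)
```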
2 changes: 1 addition & 1 deletion rllib/algorithms/bc/bc_catalog.py
@@ -64,7 +64,7 @@ def build_pi_head(self, framework: str) -> Model:

The default behavior is to build the head from the pi_head_config.
This can be overridden to build a custom policy head as a means of configuring
the behavior of a BCRLModule implementation.
the behavior of a BC specific RLModule implementation.

Args:
framework: The framework to use. Either "torch" or "tf2".
82 changes: 0 additions & 82 deletions rllib/algorithms/bc/bc_rl_module.py

This file was deleted.

32 changes: 29 additions & 3 deletions rllib/algorithms/bc/torch/bc_torch_rl_module.py
@@ -1,6 +1,32 @@
from ray.rllib.algorithms.bc.bc_rl_module import BCRLModule
from typing import Any, Dict

from ray.rllib.core import Columns
from ray.rllib.core.models.base import ENCODER_OUT
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.utils.annotations import override


class BCTorchRLModule(TorchRLModule):
@override(RLModule)
def setup(self):
# __sphinx_doc_begin__
# Build models from catalog
self.encoder = self.catalog.build_encoder(framework=self.framework)
self.pi = self.catalog.build_pi_head(framework=self.framework)

@override(RLModule)
def _forward(self, batch: Dict, **kwargs) -> Dict[str, Any]:
"""Generic BC forward pass (for all phases of training/evaluation)."""
output = {}

# State encodings.
encoder_outs = self.encoder(batch)
if Columns.STATE_OUT in encoder_outs:
output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]

# Actions.
action_logits = self.pi(encoder_outs[ENCODER_OUT])
output[Columns.ACTION_DIST_INPUTS] = action_logits

class BCTorchRLModule(TorchRLModule, BCRLModule):
pass
return output
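With the old stub `class BCTorchRLModule(TorchRLModule, BCRLModule): pass` and the `bc_rl_module.py` base class (deleted above) gone, the torch module is now self-contained and only overrides the generic `_forward()`. The public entry points all route through that one method by default; a rough usage sketch (the module construction and observation shape are assumed, not shown in this diff):

```python
import torch

from ray.rllib.core.columns import Columns

# `module` is assumed to be an already-built BCTorchRLModule; the 4-feature,
# 32-sample observation batch is a placeholder matching e.g. CartPole-v1.
batch = {Columns.OBS: torch.randn(32, 4)}

# All three public calls fall back to the single `_forward()` defined above,
# unless `_forward_inference/_forward_exploration/_forward_train` are overridden.
out_inf = module.forward_inference(batch)
out_exp = module.forward_exploration(batch)
out_train = module.forward_train(batch)
print(out_train[Columns.ACTION_DIST_INPUTS].shape)  # e.g. torch.Size([32, 2])
```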
14 changes: 8 additions & 6 deletions rllib/algorithms/impala/torch/impala_torch_learner.py
@@ -28,6 +28,8 @@ def compute_loss_for_module(
batch: Dict,
fwd_out: Dict[str, TensorType],
) -> TensorType:
module = self.module[module_id].unwrapped()

# TODO (sven): Now that we do the +1ts trick to be less vulnerable about
# bootstrap values at the end of rollouts in the new stack, we might make
# this a more flexible, configurable parameter for users, e.g.
@@ -46,17 +48,17 @@

# Behavior actions logp and target actions logp.
behaviour_actions_logp = batch[Columns.ACTION_LOGP]
target_policy_dist = (
self.module[module_id]
.unwrapped()
.get_train_action_dist_cls()
.from_logits(fwd_out[Columns.ACTION_DIST_INPUTS])
target_policy_dist = module.get_train_action_dist_cls().from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
)
target_actions_logp = target_policy_dist.logp(batch[Columns.ACTIONS])

# Values and bootstrap values.
values = module.compute_values(
batch, embeddings=fwd_out.get(Columns.EMBEDDINGS)
)
values_time_major = make_time_major(
fwd_out[Columns.VF_PREDS],
values,
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)
6 changes: 2 additions & 4 deletions rllib/algorithms/marwil/marwil_rl_module.py
@@ -10,6 +10,7 @@

@DeveloperAPI(stability="alpha")
class MARWILRLModule(RLModule, ValueFunctionAPI, abc.ABC):
@override(RLModule)
def setup(self):
# Build models from catalog
self.encoder = self.catalog.build_actor_critic_encoder(framework=self.framework)
@@ -47,7 +48,4 @@ def input_specs_train(self) -> SpecDict:

@override(RLModule)
def output_specs_train(self) -> SpecDict:
return [
Columns.VF_PREDS,
Columns.ACTION_DIST_INPUTS,
]
return [Columns.ACTION_DIST_INPUTS]
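Dropping `Columns.VF_PREDS` from `output_specs_train` mirrors the learner change: the train forward pass no longer has to emit value predictions, because the loss obtains them via `compute_values()`. Under the new convention a custom module's train output spec might plausibly look like this (a sketch, not prescribed by the PR):

```python
@override(RLModule)
def output_specs_train(self) -> SpecDict:
    # VF_PREDS is no longer required; ACTION_DIST_INPUTS (plus, optionally,
    # Columns.EMBEDDINGS for reuse in compute_values) is enough.
    return [Columns.ACTION_DIST_INPUTS]
```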
9 changes: 5 additions & 4 deletions rllib/algorithms/marwil/torch/marwil_torch_learner.py
@@ -31,6 +31,7 @@ def compute_loss_for_module(
batch: Dict[str, Any],
fwd_out: Dict[str, TensorType]
) -> TensorType:
module = self.module[module_id].unwrapped()

# Possibly apply masking to some sub loss terms and to the total loss term
# at the end. Masking could be used for RNN-based model (zero padded `batch`)
@@ -46,9 +47,7 @@ def possibly_masked_mean(data_):
else:
possibly_masked_mean = torch.mean

action_dist_class_train = (
self.module[module_id].unwrapped().get_train_action_dist_cls()
)
action_dist_class_train = module.get_train_action_dist_cls()
curr_action_dist = action_dist_class_train.from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
)
@@ -64,7 +63,9 @@ def possibly_masked_mean(data_):
# Otherwise, compute advantages.
else:
# cumulative_rewards = batch[Columns.ADVANTAGES]
value_fn_out = fwd_out[Columns.VF_PREDS]
value_fn_out = module.compute_values(
batch, embeddings=fwd_out.get(Columns.EMBEDDINGS)
)
advantages = batch[Columns.VALUE_TARGETS] - value_fn_out
advantages_squared_mean = possibly_masked_mean(torch.pow(advantages, 2.0))

32 changes: 18 additions & 14 deletions rllib/algorithms/marwil/torch/marwil_torch_rl_module.py
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional

from ray.rllib.algorithms.marwil.marwil_rl_module import MARWILRLModule
from ray.rllib.core.columns import Columns
@@ -42,7 +42,6 @@ def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
"flag `inference_only=False` when building the module."
)
output = {}

# Shared encoder.
encoder_outs = self.encoder(batch)
if Columns.STATE_OUT in encoder_outs:
@@ -63,18 +62,23 @@ def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
# (similar to IMPALA's v-trace architecture). This would also get rid of the
# second Connector pass currently necessary.
@override(ValueFunctionAPI)
def compute_values(self, batch: Dict[str, Any]) -> TensorType:
# Separate vf-encoder.
if hasattr(self.encoder, "critic_encoder"):
if self.is_stateful():
# The recurrent encoders expect a `(state_in, h)` key in the
# input dict while the key returned is `(state_in, critic, h)`.
batch[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC]
encoder_outs = self.encoder.critic_encoder(batch)[ENCODER_OUT]
# Shared encoder.
else:
encoder_outs = self.encoder(batch)[ENCODER_OUT][CRITIC]
def compute_values(
self,
batch: Dict[str, Any],
embeddings: Optional[Any] = None,
) -> TensorType:
if embeddings is None:
# Separate vf-encoder.
if hasattr(self.encoder, "critic_encoder"):
if self.is_stateful():
# The recurrent encoders expect a `(state_in, h)` key in the
# input dict while the key returned is `(state_in, critic, h)`.
batch[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC]
embeddings = self.encoder.critic_encoder(batch)[ENCODER_OUT]
# Shared encoder.
else:
embeddings = self.encoder(batch)[ENCODER_OUT][CRITIC]
# Value head.
vf_out = self.vf(encoder_outs)
vf_out = self.vf(embeddings)
# Squeeze out last dimension (single node value head).
return vf_out.squeeze(-1)
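The widened signature above, `compute_values(self, batch, embeddings=None)`, is what lets learners hand the encoder output from `fwd_out` back to the module instead of re-encoding the batch. A brief call-site comparison (assuming `module`, `batch`, and `fwd_out` as they appear in the learner files above):

```python
# Standalone use: no embeddings available, the module encodes `batch` itself.
values = module.compute_values(batch)

# Inside a loss: reuse the embeddings the train forward pass already produced,
# provided the module stored them in its output under Columns.EMBEDDINGS.
values = module.compute_values(batch, embeddings=fwd_out.get(Columns.EMBEDDINGS))
```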
5 changes: 1 addition & 4 deletions rllib/algorithms/ppo/ppo_rl_module.py
@@ -72,10 +72,7 @@ def input_specs_train(self) -> SpecDict:

@override(RLModule)
def output_specs_train(self) -> SpecDict:
return [
Columns.VF_PREDS,
Columns.ACTION_DIST_INPUTS,
]
return [Columns.ACTION_DIST_INPUTS]

@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(InferenceOnlyAPI)
49 changes: 4 additions & 45 deletions rllib/algorithms/ppo/tests/test_ppo_rl_module.py
@@ -26,21 +26,7 @@
torch, nn = try_import_torch()


def get_expected_module_config(
env: gym.Env,
model_config_dict: dict,
observation_space: gym.spaces.Space,
) -> RLModuleConfig:
"""Get a PPOModuleConfig that we would expect from the catalog otherwise.

Args:
env: Environment for which we build the model later
model_config_dict: Model config to use for the catalog
observation_space: Observation space to use for the catalog.

Returns:
A PPOModuleConfig containing the relevant configs to build PPORLModule
"""
def get_expected_module_config(env, model_config_dict, observation_space):
config = RLModuleConfig(
observation_space=observation_space,
action_space=env.action_space,
@@ -52,22 +38,7 @@


def dummy_torch_ppo_loss(module, batch, fwd_out):
"""Dummy PPO loss function for testing purposes.

Will eventually use the actual PPO loss function implemented in PPO.

Args:
batch: Batch used for training.
fwd_out: Forward output of the model.

Returns:
Loss tensor
"""
# TODO: we should replace these components later with real ppo components when
# RLOptimizer and RLModule are integrated together.
# this is not exactly a ppo loss, just something to show that the
# forward train works
adv = batch[Columns.REWARDS] - fwd_out[Columns.VF_PREDS]
adv = batch[Columns.REWARDS] - module.compute_values(batch)
action_dist_class = module.get_train_action_dist_cls()
action_probs = action_dist_class.from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
@@ -80,19 +51,7 @@ def dummy_torch_ppo_loss(module, batch, fwd_out):


def dummy_tf_ppo_loss(module, batch, fwd_out):
"""Dummy PPO loss function for testing purposes.

Will eventually use the actual PPO loss function implemented in PPO.

Args:
module: PPOTfRLModule
batch: Batch used for training.
fwd_out: Forward output of the model.

Returns:
Loss tensor
"""
adv = batch[Columns.REWARDS] - fwd_out[Columns.VF_PREDS]
adv = batch[Columns.REWARDS] - module.compute_values(batch)
action_dist_class = module.get_train_action_dist_cls()
action_probs = action_dist_class.from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
@@ -180,7 +139,7 @@ def test_rollouts(self):

def test_forward_train(self):
# TODO: Add FrozenLake-v1 to cover LSTM case.
frameworks = ["tf2", "torch"]
frameworks = ["torch", "tf2"]
env_names = ["CartPole-v1", "Pendulum-v1", "ALE/Breakout-v5"]
lstm = [False, True]
config_combinations = [frameworks, env_names, lstm]
2 changes: 1 addition & 1 deletion rllib/algorithms/ppo/tf/ppo_tf_rl_module.py
@@ -58,7 +58,7 @@ def _forward_train(self, batch: Dict):
return output

@override(ValueFunctionAPI)
def compute_values(self, batch: Dict[str, Any]) -> TensorType:
def compute_values(self, batch: Dict[str, Any], embeddings=None) -> TensorType:
infos = batch.pop(Columns.INFOS, None)
batch = tree.map_structure(lambda s: tf.convert_to_tensor(s), batch)
if infos is not None: