Skip to content

Commit

Permalink
[RLlib] Chaining Models in RLModules (#31469)
Browse files Browse the repository at this point in the history
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
  • Loading branch information
ArturNiederfahrenhorst authored Feb 7, 2023
1 parent c83111a commit 027965b
Show file tree
Hide file tree
Showing 24 changed files with 1,075 additions and 813 deletions.
8 changes: 0 additions & 8 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1975,14 +1975,6 @@ py_test(
srcs = ["models/specs/tests/test_spec_dict.py"]
)

# test TorchVectorEncoder
py_test(
name = "test_torch_vector_encoder",
tags = ["team:rllib", "models"],
size = "small",
srcs = ["models/torch/encoders/tests/test_torch_vector_encoder.py"]
)


# --------------------------------------------------------------------
# Offline
Expand Down
177 changes: 51 additions & 126 deletions rllib/algorithms/ppo/tests/test_ppo_rl_module.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,43 @@
import itertools
import ray
import unittest
import numpy as np

import gymnasium as gym
import torch
import numpy as np
import tensorflow as tf
import torch
import tree

import ray
from ray.rllib import SampleBatch
from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import (
PPOTfRLModule,
)
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOModuleConfig
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
PPOTorchRLModule,
PPOModuleConfig,
)
from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import (
PPOTfRLModule,
PPOTfModuleConfig,
from ray.rllib.models.experimental.configs import (
MLPConfig,
MLPEncoderConfig,
LSTMEncoderConfig,
)
from ray.rllib.core.rl_module.encoder import (
FCConfig,
IdentityConfig,
LSTMConfig,
from ray.rllib.models.experimental.torch.encoder import (
STATE_IN,
STATE_OUT,
)
from ray.rllib.core.rl_module.encoder_tf import (
FCTfConfig,
IdentityTfConfig,
)
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import convert_to_torch_tensor


def get_expected_model_config_torch(
env: gym.Env, lstm: bool, shared_encoder: bool
def get_expected_model_config(
env: gym.Env,
lstm: bool,
) -> PPOModuleConfig:
"""Get a PPOModuleConfig that we would expect from the catalog otherwise.
Args:
env: Environment for which we build the model later
lstm: If True, build recurrent pi encoder
shared_encoder: If True, build a shared encoder for pi and vf, where pi
encoder and vf encoder will be identity. If False, the shared encoder
will be identity.
Returns:
A PPOModuleConfig containing the relevant configs to build PPORLModule
Expand All @@ -51,107 +47,44 @@ def get_expected_model_config_torch(
)
obs_dim = env.observation_space.shape[0]

if shared_encoder:
assert not lstm, "LSTM can only be used in PI"
shared_encoder_config = FCConfig(
if lstm:
encoder_config = LSTMEncoderConfig(
input_dim=obs_dim,
hidden_layers=[32],
activation="ReLU",
hidden_dim=32,
batch_first=True,
num_layers=1,
output_dim=32,
)
pi_encoder_config = IdentityConfig(output_dim=32)
vf_encoder_config = IdentityConfig(output_dim=32)
else:
shared_encoder_config = IdentityConfig(output_dim=obs_dim)
if lstm:
pi_encoder_config = LSTMConfig(
input_dim=obs_dim,
hidden_dim=32,
batch_first=True,
output_dim=32,
num_layers=1,
)
else:
pi_encoder_config = FCConfig(
input_dim=obs_dim,
output_dim=32,
hidden_layers=[32],
activation="ReLU",
)
vf_encoder_config = FCConfig(
encoder_config = MLPEncoderConfig(
input_dim=obs_dim,
hidden_layer_dims=[32],
hidden_layer_activation="ReLU",
output_dim=32,
hidden_layers=[32],
activation="ReLU",
)

pi_config = FCConfig()
vf_config = FCConfig()

if isinstance(env.action_space, gym.spaces.Discrete):
pi_config.output_dim = env.action_space.n
else:
pi_config.output_dim = env.action_space.shape[0] * 2

return PPOModuleConfig(
observation_space=env.observation_space,
action_space=env.action_space,
shared_encoder_config=shared_encoder_config,
pi_encoder_config=pi_encoder_config,
vf_encoder_config=vf_encoder_config,
pi_config=pi_config,
vf_config=vf_config,
shared_encoder=shared_encoder,
pi_config = MLPConfig(
input_dim=32,
hidden_layer_dims=[32],
hidden_layer_activation="ReLU",
)


def get_expected_model_config_tf(
env: gym.Env, shared_encoder: bool
) -> PPOTfModuleConfig:
"""Get a PPOTfModuleConfig that we would expect from the catalog otherwise.
Args:
env: Environment for which we build the model later
shared_encoder: If True, build a shared encoder for pi and vf, where pi
encoder and vf encoder will be identity. If False, the shared encoder
will be identity.
Returns:
A PPOTfModuleConfig containing the relevant configs to build PPOTfRLModule.
"""
assert len(env.observation_space.shape) == 1, (
"No multidimensional obs space " "supported."
vf_config = MLPConfig(
input_dim=32,
hidden_layer_dims=[32, 1],
hidden_layer_activation="ReLU",
)
obs_dim = env.observation_space.shape[0]

if shared_encoder:
shared_encoder_config = FCTfConfig(
input_dim=obs_dim,
hidden_layers=[32],
activation="ReLU",
output_dim=32,
)
else:
shared_encoder_config = IdentityTfConfig(output_dim=obs_dim)
pi_config = FCConfig()
vf_config = FCConfig()
pi_config.input_dim = vf_config.input_dim = shared_encoder_config.output_dim

if isinstance(env.action_space, gym.spaces.Discrete):
pi_config.output_dim = env.action_space.n
else:
pi_config.output_dim = env.action_space.shape[0] * 2

pi_config.hidden_layers = vf_config.hidden_layers = [32]
pi_config.activation = vf_config.activation = "ReLU"

return PPOTfModuleConfig(
return PPOModuleConfig(
observation_space=env.observation_space,
action_space=env.action_space,
shared_encoder_config=shared_encoder_config,
encoder_config=encoder_config,
pi_config=pi_config,
vf_config=vf_config,
shared_encoder=shared_encoder,
)


Expand Down Expand Up @@ -207,12 +140,11 @@ def setUpClass(cls):
def tearDownClass(cls):
ray.shutdown()

def get_ppo_module(self, framwework, env, lstm, shared_encoder):
if framwework == "torch":
config = get_expected_model_config_torch(env, lstm, shared_encoder)
def get_ppo_module(self, framework, env, lstm):
config = get_expected_model_config(env, lstm)
if framework == "torch":
module = PPOTorchRLModule(config)
else:
config = get_expected_model_config_tf(env, shared_encoder)
module = PPOTfRLModule(config)
return module

Expand All @@ -222,29 +154,24 @@ def get_input_batch_from_obs(self, framework, obs):
SampleBatch.OBS: convert_to_torch_tensor(obs)[None],
}
else:
batch = {SampleBatch.OBS: np.array([obs])}
batch = {SampleBatch.OBS: tf.convert_to_tensor([obs])}
return batch

def test_rollouts(self):
# TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space
frameworks = ["torch", "tf2"]
env_names = ["CartPole-v1", "Pendulum-v1"]
fwd_fns = ["forward_exploration", "forward_inference"]
shared_encoders = [False, True]
ltsms = [False, True]
config_combinations = [frameworks, env_names, fwd_fns, shared_encoders, ltsms]
lstm = [False, True]
config_combinations = [frameworks, env_names, fwd_fns, lstm]
for config in itertools.product(*config_combinations):
fw, env_name, fwd_fn, shared_encoder, lstm = config
if lstm and shared_encoder:
# Not yet implemented
# TODO (Artur): Implement
continue
fw, env_name, fwd_fn, lstm = config
if lstm and fw == "tf2":
# LSTM not implemented in TF2 yet
continue
print(f"[ENV={env_name}] | [SHARED={shared_encoder}] | LSTM" f"={lstm}")
print(f"[FW={fw} | [ENV={env_name}] | [FWD={fwd_fn}] | LSTM" f"={lstm}")
env = gym.make(env_name)
module = self.get_ppo_module(fw, env, lstm, shared_encoder)
module = self.get_ppo_module(framework=fw, env=env, lstm=lstm)

obs, _ = env.reset()

Expand All @@ -267,22 +194,17 @@ def test_forward_train(self):
# TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space
frameworks = ["torch", "tf2"]
env_names = ["CartPole-v1", "Pendulum-v1"]
shared_encoders = [False, True]
ltsms = [False, True]
config_combinations = [frameworks, env_names, shared_encoders, ltsms]
lstm = [False, True]
config_combinations = [frameworks, env_names, lstm]
for config in itertools.product(*config_combinations):
fw, env_name, shared_encoder, lstm = config
if lstm and shared_encoder:
# Not yet implemented
# TODO (Artur): Implement
continue
fw, env_name, lstm = config
if lstm and fw == "tf2":
# LSTM not implemented in TF2 yet
continue
print(f"[ENV={env_name}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}")
print(f"[FW={fw} | [ENV={env_name}] | LSTM={lstm}")
env = gym.make(env_name)

module = self.get_ppo_module(fw, env, lstm, shared_encoder)
module = self.get_ppo_module(fw, env, lstm)

# collect a batch of data
batches = []
Expand Down Expand Up @@ -343,6 +265,9 @@ def test_forward_train(self):
for param in module.parameters():
self.assertIsNotNone(param.grad)
else:
batch = tree.map_structure(
lambda x: tf.convert_to_tensor(x, dtype=tf.float32), batch
)
with tf.GradientTape() as tape:
fwd_out = module.forward_train(batch)
loss = dummy_tf_ppo_loss(batch, fwd_out)
Expand Down
Loading

0 comments on commit 027965b

Please sign in to comment.