Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Chaining Models in RLModules #31469

Merged
merged 52 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from 50 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
b1ddd73
initial
ArturNiederfahrenhorst Dec 16, 2022
5489c89
tests complete
ArturNiederfahrenhorst Dec 16, 2022
0dc91e2
wip
ArturNiederfahrenhorst Dec 16, 2022
a9a9498
wip
ArturNiederfahrenhorst Dec 16, 2022
9930d79
mutually exclusive encoders, tests passing
ArturNiederfahrenhorst Dec 18, 2022
109e56d
add lstm code
ArturNiederfahrenhorst Dec 19, 2022
532c8c0
better docs for get expected model config
ArturNiederfahrenhorst Dec 21, 2022
31ae2a0
kourosh's comments
ArturNiederfahrenhorst Dec 21, 2022
018e223
lstm fixed, tests working
ArturNiederfahrenhorst Dec 21, 2022
b38d501
add state out
ArturNiederfahrenhorst Dec 21, 2022
462fc4d
add __main__ to test
ArturNiederfahrenhorst Dec 21, 2022
f93795a
change lstm testing according to kourosh's comment
ArturNiederfahrenhorst Dec 21, 2022
6ae3443
fix get_initial_state
ArturNiederfahrenhorst Dec 21, 2022
548f42a
remove useless forward_exploration/forward_inference branch
ArturNiederfahrenhorst Dec 21, 2022
92ae510
revert changes to test_ppo_with_rl_module.py
ArturNiederfahrenhorst Dec 21, 2022
04228db
remove pass
ArturNiederfahrenhorst Dec 22, 2022
0ab7809
fix gym incompatability
ArturNiederfahrenhorst Dec 22, 2022
0b7536c
add missing ray.
ArturNiederfahrenhorst Dec 27, 2022
77f991a
test_ppo_rl_module working
ArturNiederfahrenhorst Jan 5, 2023
096a612
ppo_torch_rl_module tests working
ArturNiederfahrenhorst Jan 6, 2023
a373538
feedback from kourosh from last week
ArturNiederfahrenhorst Jan 19, 2023
c2ba97c
solution 3
ArturNiederfahrenhorst Jan 19, 2023
6ef042e
some larger refactors
ArturNiederfahrenhorst Jan 24, 2023
d8d8c72
rename PPOModuleConfig for upcoming release
ArturNiederfahrenhorst Jan 24, 2023
503b7dd
cleanup
ArturNiederfahrenhorst Jan 24, 2023
3c30b42
rename configs
ArturNiederfahrenhorst Jan 24, 2023
5687454
remove rebase artifacts
ArturNiederfahrenhorst Jan 24, 2023
bec7f47
minor fixes
ArturNiederfahrenhorst Jan 24, 2023
da887b6
fix import
ArturNiederfahrenhorst Jan 24, 2023
dbd7c0a
Merge remote-tracking branch 'upstream/master' into solution2
ArturNiederfahrenhorst Jan 24, 2023
42e0d49
fix misspelling
ArturNiederfahrenhorst Jan 24, 2023
aef9875
fix import
ArturNiederfahrenhorst Jan 24, 2023
d0d5277
also fix other imports
ArturNiederfahrenhorst Jan 24, 2023
3011262
typo
ArturNiederfahrenhorst Jan 24, 2023
3591fab
delete unneeded configs
ArturNiederfahrenhorst Jan 24, 2023
ab2c301
remove unneeded model configs
ArturNiederfahrenhorst Jan 24, 2023
e7aa528
sven's nits
ArturNiederfahrenhorst Jan 26, 2023
2d709d5
sven's nits
ArturNiederfahrenhorst Jan 26, 2023
871fe75
more refactors
ArturNiederfahrenhorst Jan 26, 2023
0bafdd7
another nit
ArturNiederfahrenhorst Jan 26, 2023
69d9655
some more renaming
ArturNiederfahrenhorst Jan 26, 2023
bd687ab
delete vectorencoder
ArturNiederfahrenhorst Jan 26, 2023
6085e05
add back torch folder
ArturNiederfahrenhorst Jan 26, 2023
f46b569
fix model names and some nits
ArturNiederfahrenhorst Jan 27, 2023
868d442
renaming and lint
ArturNiederfahrenhorst Jan 27, 2023
8326719
self-review
ArturNiederfahrenhorst Jan 27, 2023
cede51f
output activations
ArturNiederfahrenhorst Jan 27, 2023
8bc5c08
remove useless init and add comment to torch encoder
ArturNiederfahrenhorst Jan 27, 2023
7846d1b
remove useless constructor
ArturNiederfahrenhorst Jan 27, 2023
d2b4aee
lint
ArturNiederfahrenhorst Jan 30, 2023
166e4eb
kourohs's nits
ArturNiederfahrenhorst Jan 31, 2023
c70224f
unify torch + tf
ArturNiederfahrenhorst Jan 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1974,14 +1974,6 @@ py_test(
srcs = ["models/specs/tests/test_spec_dict.py"]
)

# test TorchVectorEncoder
py_test(
name = "test_torch_vector_encoder",
tags = ["team:rllib", "models"],
size = "small",
srcs = ["models/torch/encoders/tests/test_torch_vector_encoder.py"]
)


# --------------------------------------------------------------------
# Offline
Expand Down
177 changes: 51 additions & 126 deletions rllib/algorithms/ppo/tests/test_ppo_rl_module.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,43 @@
import itertools
import ray
import unittest
import numpy as np

import gymnasium as gym
import torch
import numpy as np
import tensorflow as tf
import torch
import tree

import ray
from ray.rllib import SampleBatch
from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import (
PPOTfRLModule,
)
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOModuleConfig
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
PPOTorchRLModule,
PPOModuleConfig,
)
from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import (
PPOTfRLModule,
PPOTfModuleConfig,
from ray.rllib.models.experimental.configs import (
MLPConfig,
MLPEncoderConfig,
LSTMEncoderConfig,
)
from ray.rllib.core.rl_module.encoder import (
FCConfig,
IdentityConfig,
LSTMConfig,
from ray.rllib.models.experimental.torch.encoder import (
STATE_IN,
STATE_OUT,
)
from ray.rllib.core.rl_module.encoder_tf import (
FCTfConfig,
IdentityTfConfig,
)
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import convert_to_torch_tensor


def get_expected_model_config_torch(
env: gym.Env, lstm: bool, shared_encoder: bool
def get_expected_model_config(
env: gym.Env,
lstm: bool,
) -> PPOModuleConfig:
"""Get a PPOModuleConfig that we would expect from the catalog otherwise.

Args:
env: Environment for which we build the model later
lstm: If True, build recurrent pi encoder
shared_encoder: If True, build a shared encoder for pi and vf, where pi
encoder and vf encoder will be identity. If False, the shared encoder
will be identity.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll reintroduce this in a upcoming PR with the ActorCriticEncoder


Returns:
A PPOModuleConfig containing the relevant configs to build PPORLModule
Expand All @@ -51,107 +47,44 @@ def get_expected_model_config_torch(
)
obs_dim = env.observation_space.shape[0]

if shared_encoder:
assert not lstm, "LSTM can only be used in PI"
shared_encoder_config = FCConfig(
if lstm:
encoder_config = LSTMEncoderConfig(
input_dim=obs_dim,
hidden_layers=[32],
activation="ReLU",
hidden_dim=32,
batch_first=True,
num_layers=1,
output_dim=32,
)
pi_encoder_config = IdentityConfig(output_dim=32)
vf_encoder_config = IdentityConfig(output_dim=32)
else:
shared_encoder_config = IdentityConfig(output_dim=obs_dim)
if lstm:
pi_encoder_config = LSTMConfig(
input_dim=obs_dim,
hidden_dim=32,
batch_first=True,
output_dim=32,
num_layers=1,
)
else:
pi_encoder_config = FCConfig(
input_dim=obs_dim,
output_dim=32,
hidden_layers=[32],
activation="ReLU",
)
vf_encoder_config = FCConfig(
encoder_config = MLPEncoderConfig(
input_dim=obs_dim,
hidden_layer_dims=[32],
hidden_layer_activation="ReLU",
output_dim=32,
hidden_layers=[32],
activation="ReLU",
)

pi_config = FCConfig()
vf_config = FCConfig()

if isinstance(env.action_space, gym.spaces.Discrete):
pi_config.output_dim = env.action_space.n
else:
pi_config.output_dim = env.action_space.shape[0] * 2

return PPOModuleConfig(
observation_space=env.observation_space,
action_space=env.action_space,
shared_encoder_config=shared_encoder_config,
pi_encoder_config=pi_encoder_config,
vf_encoder_config=vf_encoder_config,
pi_config=pi_config,
vf_config=vf_config,
shared_encoder=shared_encoder,
pi_config = MLPConfig(
input_dim=32,
hidden_layer_dims=[32],
hidden_layer_activation="ReLU",
)


def get_expected_model_config_tf(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've unified these into one, since model configs are planned to be framework agnostic.

env: gym.Env, shared_encoder: bool
) -> PPOTfModuleConfig:
"""Get a PPOTfModuleConfig that we would expect from the catalog otherwise.

Args:
env: Environment for which we build the model later
shared_encoder: If True, build a shared encoder for pi and vf, where pi
encoder and vf encoder will be identity. If False, the shared encoder
will be identity.

Returns:
A PPOTfModuleConfig containing the relevant configs to build PPOTfRLModule.
"""
assert len(env.observation_space.shape) == 1, (
"No multidimensional obs space " "supported."
vf_config = MLPConfig(
input_dim=32,
hidden_layer_dims=[32, 1],
hidden_layer_activation="ReLU",
)
obs_dim = env.observation_space.shape[0]

if shared_encoder:
shared_encoder_config = FCTfConfig(
input_dim=obs_dim,
hidden_layers=[32],
activation="ReLU",
output_dim=32,
)
else:
shared_encoder_config = IdentityTfConfig(output_dim=obs_dim)
pi_config = FCConfig()
vf_config = FCConfig()
pi_config.input_dim = vf_config.input_dim = shared_encoder_config.output_dim

if isinstance(env.action_space, gym.spaces.Discrete):
pi_config.output_dim = env.action_space.n
else:
pi_config.output_dim = env.action_space.shape[0] * 2

pi_config.hidden_layers = vf_config.hidden_layers = [32]
pi_config.activation = vf_config.activation = "ReLU"

return PPOTfModuleConfig(
return PPOModuleConfig(
observation_space=env.observation_space,
action_space=env.action_space,
shared_encoder_config=shared_encoder_config,
encoder_config=encoder_config,
pi_config=pi_config,
vf_config=vf_config,
shared_encoder=shared_encoder,
)


Expand Down Expand Up @@ -207,12 +140,11 @@ def setUpClass(cls):
def tearDownClass(cls):
ray.shutdown()

def get_ppo_module(self, framwework, env, lstm, shared_encoder):
if framwework == "torch":
config = get_expected_model_config_torch(env, lstm, shared_encoder)
def get_ppo_module(self, framework, env, lstm):
config = get_expected_model_config(env, lstm)
if framework == "torch":
module = PPOTorchRLModule(config)
else:
config = get_expected_model_config_tf(env, shared_encoder)
module = PPOTfRLModule(config)
return module

Expand All @@ -222,29 +154,24 @@ def get_input_batch_from_obs(self, framework, obs):
SampleBatch.OBS: convert_to_torch_tensor(obs)[None],
}
else:
batch = {SampleBatch.OBS: np.array([obs])}
batch = {SampleBatch.OBS: tf.convert_to_tensor([obs])}
return batch

def test_rollouts(self):
# TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space
frameworks = ["torch", "tf2"]
env_names = ["CartPole-v1", "Pendulum-v1"]
fwd_fns = ["forward_exploration", "forward_inference"]
shared_encoders = [False, True]
ltsms = [False, True]
config_combinations = [frameworks, env_names, fwd_fns, shared_encoders, ltsms]
lstm = [False, True]
config_combinations = [frameworks, env_names, fwd_fns, lstm]
for config in itertools.product(*config_combinations):
fw, env_name, fwd_fn, shared_encoder, lstm = config
if lstm and shared_encoder:
# Not yet implemented
# TODO (Artur): Implement
continue
fw, env_name, fwd_fn, lstm = config
if lstm and fw == "tf2":
# LSTM not implemented in TF2 yet
continue
print(f"[ENV={env_name}] | [SHARED={shared_encoder}] | LSTM" f"={lstm}")
print(f"[FW={fw} | [ENV={env_name}] | [FWD={fwd_fn}] | LSTM" f"={lstm}")
env = gym.make(env_name)
module = self.get_ppo_module(fw, env, lstm, shared_encoder)
module = self.get_ppo_module(framework=fw, env=env, lstm=lstm)

obs, _ = env.reset()

Expand All @@ -267,22 +194,17 @@ def test_forward_train(self):
# TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space
frameworks = ["torch", "tf2"]
env_names = ["CartPole-v1", "Pendulum-v1"]
shared_encoders = [False, True]
ltsms = [False, True]
config_combinations = [frameworks, env_names, shared_encoders, ltsms]
lstm = [False, True]
config_combinations = [frameworks, env_names, lstm]
for config in itertools.product(*config_combinations):
fw, env_name, shared_encoder, lstm = config
if lstm and shared_encoder:
# Not yet implemented
# TODO (Artur): Implement
continue
fw, env_name, lstm = config
if lstm and fw == "tf2":
# LSTM not implemented in TF2 yet
continue
print(f"[ENV={env_name}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}")
print(f"[FW={fw} | [ENV={env_name}] | LSTM={lstm}")
env = gym.make(env_name)

module = self.get_ppo_module(fw, env, lstm, shared_encoder)
module = self.get_ppo_module(fw, env, lstm)

# collect a batch of data
batches = []
Expand Down Expand Up @@ -343,6 +265,9 @@ def test_forward_train(self):
for param in module.parameters():
self.assertIsNotNone(param.grad)
else:
batch = tree.map_structure(
lambda x: tf.convert_to_tensor(x, dtype=tf.float32), batch
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-> because tf does not accept numpy arrays.

with tf.GradientTape() as tape:
fwd_out = module.forward_train(batch)
loss = dummy_tf_ppo_loss(batch, fwd_out)
Expand Down
Loading