-
Notifications
You must be signed in to change notification settings - Fork 6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] Chaining Models in RLModules #31469
Merged
gjoliver
merged 52 commits into
ray-project:master
from
ArturNiederfahrenhorst:solution2
Feb 7, 2023
+1,075
−813
Merged
Changes from all commits
Commits
Show all changes
52 commits
Select commit
Hold shift + click to select a range
b1ddd73
initial
ArturNiederfahrenhorst 5489c89
tests complete
ArturNiederfahrenhorst 0dc91e2
wip
ArturNiederfahrenhorst a9a9498
wip
ArturNiederfahrenhorst 9930d79
mutually exclusive encoders, tests passing
ArturNiederfahrenhorst 109e56d
add lstm code
ArturNiederfahrenhorst 532c8c0
better docs for get expected model config
ArturNiederfahrenhorst 31ae2a0
kourosh's comments
ArturNiederfahrenhorst 018e223
lstm fixed, tests working
ArturNiederfahrenhorst b38d501
add state out
ArturNiederfahrenhorst 462fc4d
add __main__ to test
ArturNiederfahrenhorst f93795a
change lstm testing according to kourosh's comment
ArturNiederfahrenhorst 6ae3443
fix get_initial_state
ArturNiederfahrenhorst 548f42a
remove useless forward_exploration/forward_inference branch
ArturNiederfahrenhorst 92ae510
revert changes to test_ppo_with_rl_module.py
ArturNiederfahrenhorst 04228db
remove pass
ArturNiederfahrenhorst 0ab7809
fix gym incompatability
ArturNiederfahrenhorst 0b7536c
add missing ray.
ArturNiederfahrenhorst 77f991a
test_ppo_rl_module working
ArturNiederfahrenhorst 096a612
ppo_torch_rl_module tests working
ArturNiederfahrenhorst a373538
feedback from kourosh from last week
ArturNiederfahrenhorst c2ba97c
solution 3
ArturNiederfahrenhorst 6ef042e
some larger refactors
ArturNiederfahrenhorst d8d8c72
rename PPOModuleConfig for upcoming release
ArturNiederfahrenhorst 503b7dd
cleanup
ArturNiederfahrenhorst 3c30b42
rename configs
ArturNiederfahrenhorst 5687454
remove rebase artifacts
ArturNiederfahrenhorst bec7f47
minor fixes
ArturNiederfahrenhorst da887b6
fix import
ArturNiederfahrenhorst dbd7c0a
Merge remote-tracking branch 'upstream/master' into solution2
ArturNiederfahrenhorst 42e0d49
fix misspelling
ArturNiederfahrenhorst aef9875
fix import
ArturNiederfahrenhorst d0d5277
also fix other imports
ArturNiederfahrenhorst 3011262
typo
ArturNiederfahrenhorst 3591fab
delete unneeded configs
ArturNiederfahrenhorst ab2c301
remove unneeded model configs
ArturNiederfahrenhorst e7aa528
sven's nits
ArturNiederfahrenhorst 2d709d5
sven's nits
ArturNiederfahrenhorst 871fe75
more refactors
ArturNiederfahrenhorst 0bafdd7
another nit
ArturNiederfahrenhorst 69d9655
some more renaming
ArturNiederfahrenhorst bd687ab
delete vectorencoder
ArturNiederfahrenhorst 6085e05
add back torch folder
ArturNiederfahrenhorst f46b569
fix model names and some nits
ArturNiederfahrenhorst 868d442
renaming and lint
ArturNiederfahrenhorst 8326719
self-review
ArturNiederfahrenhorst cede51f
output activations
ArturNiederfahrenhorst 8bc5c08
remove useless init and add comment to torch encoder
ArturNiederfahrenhorst 7846d1b
remove useless constructor
ArturNiederfahrenhorst d2b4aee
lint
ArturNiederfahrenhorst 166e4eb
kourohs's nits
ArturNiederfahrenhorst c70224f
unify torch + tf
ArturNiederfahrenhorst File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,43 @@ | ||
import itertools | ||
import ray | ||
import unittest | ||
import numpy as np | ||
|
||
import gymnasium as gym | ||
import torch | ||
import numpy as np | ||
import tensorflow as tf | ||
import torch | ||
import tree | ||
|
||
import ray | ||
from ray.rllib import SampleBatch | ||
from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import ( | ||
PPOTfRLModule, | ||
) | ||
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOModuleConfig | ||
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( | ||
PPOTorchRLModule, | ||
PPOModuleConfig, | ||
) | ||
from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import ( | ||
PPOTfRLModule, | ||
PPOTfModuleConfig, | ||
from ray.rllib.models.experimental.configs import ( | ||
MLPConfig, | ||
MLPEncoderConfig, | ||
LSTMEncoderConfig, | ||
) | ||
from ray.rllib.core.rl_module.encoder import ( | ||
FCConfig, | ||
IdentityConfig, | ||
LSTMConfig, | ||
from ray.rllib.models.experimental.torch.encoder import ( | ||
STATE_IN, | ||
STATE_OUT, | ||
) | ||
from ray.rllib.core.rl_module.encoder_tf import ( | ||
FCTfConfig, | ||
IdentityTfConfig, | ||
) | ||
from ray.rllib.utils.numpy import convert_to_numpy | ||
from ray.rllib.utils.torch_utils import convert_to_torch_tensor | ||
|
||
|
||
def get_expected_model_config_torch( | ||
env: gym.Env, lstm: bool, shared_encoder: bool | ||
def get_expected_model_config( | ||
env: gym.Env, | ||
lstm: bool, | ||
) -> PPOModuleConfig: | ||
"""Get a PPOModuleConfig that we would expect from the catalog otherwise. | ||
|
||
Args: | ||
env: Environment for which we build the model later | ||
lstm: If True, build recurrent pi encoder | ||
shared_encoder: If True, build a shared encoder for pi and vf, where pi | ||
encoder and vf encoder will be identity. If False, the shared encoder | ||
will be identity. | ||
|
||
Returns: | ||
A PPOModuleConfig containing the relevant configs to build PPORLModule | ||
|
@@ -51,107 +47,44 @@ def get_expected_model_config_torch( | |
) | ||
obs_dim = env.observation_space.shape[0] | ||
|
||
if shared_encoder: | ||
assert not lstm, "LSTM can only be used in PI" | ||
shared_encoder_config = FCConfig( | ||
if lstm: | ||
encoder_config = LSTMEncoderConfig( | ||
input_dim=obs_dim, | ||
hidden_layers=[32], | ||
activation="ReLU", | ||
hidden_dim=32, | ||
batch_first=True, | ||
num_layers=1, | ||
output_dim=32, | ||
) | ||
pi_encoder_config = IdentityConfig(output_dim=32) | ||
vf_encoder_config = IdentityConfig(output_dim=32) | ||
else: | ||
shared_encoder_config = IdentityConfig(output_dim=obs_dim) | ||
if lstm: | ||
pi_encoder_config = LSTMConfig( | ||
input_dim=obs_dim, | ||
hidden_dim=32, | ||
batch_first=True, | ||
output_dim=32, | ||
num_layers=1, | ||
) | ||
else: | ||
pi_encoder_config = FCConfig( | ||
input_dim=obs_dim, | ||
output_dim=32, | ||
hidden_layers=[32], | ||
activation="ReLU", | ||
) | ||
vf_encoder_config = FCConfig( | ||
encoder_config = MLPEncoderConfig( | ||
input_dim=obs_dim, | ||
hidden_layer_dims=[32], | ||
hidden_layer_activation="ReLU", | ||
output_dim=32, | ||
hidden_layers=[32], | ||
activation="ReLU", | ||
) | ||
|
||
pi_config = FCConfig() | ||
vf_config = FCConfig() | ||
|
||
if isinstance(env.action_space, gym.spaces.Discrete): | ||
pi_config.output_dim = env.action_space.n | ||
else: | ||
pi_config.output_dim = env.action_space.shape[0] * 2 | ||
|
||
return PPOModuleConfig( | ||
observation_space=env.observation_space, | ||
action_space=env.action_space, | ||
shared_encoder_config=shared_encoder_config, | ||
pi_encoder_config=pi_encoder_config, | ||
vf_encoder_config=vf_encoder_config, | ||
pi_config=pi_config, | ||
vf_config=vf_config, | ||
shared_encoder=shared_encoder, | ||
pi_config = MLPConfig( | ||
input_dim=32, | ||
hidden_layer_dims=[32], | ||
hidden_layer_activation="ReLU", | ||
) | ||
|
||
|
||
def get_expected_model_config_tf( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've unified these into one, since model configs are planned to be framework agnostic. |
||
env: gym.Env, shared_encoder: bool | ||
) -> PPOTfModuleConfig: | ||
"""Get a PPOTfModuleConfig that we would expect from the catalog otherwise. | ||
|
||
Args: | ||
env: Environment for which we build the model later | ||
shared_encoder: If True, build a shared encoder for pi and vf, where pi | ||
encoder and vf encoder will be identity. If False, the shared encoder | ||
will be identity. | ||
|
||
Returns: | ||
A PPOTfModuleConfig containing the relevant configs to build PPOTfRLModule. | ||
""" | ||
assert len(env.observation_space.shape) == 1, ( | ||
"No multidimensional obs space " "supported." | ||
vf_config = MLPConfig( | ||
input_dim=32, | ||
hidden_layer_dims=[32, 1], | ||
hidden_layer_activation="ReLU", | ||
) | ||
obs_dim = env.observation_space.shape[0] | ||
|
||
if shared_encoder: | ||
shared_encoder_config = FCTfConfig( | ||
input_dim=obs_dim, | ||
hidden_layers=[32], | ||
activation="ReLU", | ||
output_dim=32, | ||
) | ||
else: | ||
shared_encoder_config = IdentityTfConfig(output_dim=obs_dim) | ||
pi_config = FCConfig() | ||
vf_config = FCConfig() | ||
pi_config.input_dim = vf_config.input_dim = shared_encoder_config.output_dim | ||
|
||
if isinstance(env.action_space, gym.spaces.Discrete): | ||
pi_config.output_dim = env.action_space.n | ||
else: | ||
pi_config.output_dim = env.action_space.shape[0] * 2 | ||
|
||
pi_config.hidden_layers = vf_config.hidden_layers = [32] | ||
pi_config.activation = vf_config.activation = "ReLU" | ||
|
||
return PPOTfModuleConfig( | ||
return PPOModuleConfig( | ||
observation_space=env.observation_space, | ||
action_space=env.action_space, | ||
shared_encoder_config=shared_encoder_config, | ||
encoder_config=encoder_config, | ||
pi_config=pi_config, | ||
vf_config=vf_config, | ||
shared_encoder=shared_encoder, | ||
) | ||
|
||
|
||
|
@@ -207,12 +140,11 @@ def setUpClass(cls): | |
def tearDownClass(cls): | ||
ray.shutdown() | ||
|
||
def get_ppo_module(self, framwework, env, lstm, shared_encoder): | ||
if framwework == "torch": | ||
config = get_expected_model_config_torch(env, lstm, shared_encoder) | ||
def get_ppo_module(self, framework, env, lstm): | ||
config = get_expected_model_config(env, lstm) | ||
if framework == "torch": | ||
module = PPOTorchRLModule(config) | ||
else: | ||
config = get_expected_model_config_tf(env, shared_encoder) | ||
module = PPOTfRLModule(config) | ||
return module | ||
|
||
|
@@ -222,29 +154,24 @@ def get_input_batch_from_obs(self, framework, obs): | |
SampleBatch.OBS: convert_to_torch_tensor(obs)[None], | ||
} | ||
else: | ||
batch = {SampleBatch.OBS: np.array([obs])} | ||
batch = {SampleBatch.OBS: tf.convert_to_tensor([obs])} | ||
return batch | ||
|
||
def test_rollouts(self): | ||
# TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space | ||
frameworks = ["torch", "tf2"] | ||
env_names = ["CartPole-v1", "Pendulum-v1"] | ||
fwd_fns = ["forward_exploration", "forward_inference"] | ||
shared_encoders = [False, True] | ||
ltsms = [False, True] | ||
config_combinations = [frameworks, env_names, fwd_fns, shared_encoders, ltsms] | ||
lstm = [False, True] | ||
config_combinations = [frameworks, env_names, fwd_fns, lstm] | ||
for config in itertools.product(*config_combinations): | ||
fw, env_name, fwd_fn, shared_encoder, lstm = config | ||
if lstm and shared_encoder: | ||
# Not yet implemented | ||
# TODO (Artur): Implement | ||
continue | ||
fw, env_name, fwd_fn, lstm = config | ||
if lstm and fw == "tf2": | ||
# LSTM not implemented in TF2 yet | ||
continue | ||
print(f"[ENV={env_name}] | [SHARED={shared_encoder}] | LSTM" f"={lstm}") | ||
print(f"[FW={fw} | [ENV={env_name}] | [FWD={fwd_fn}] | LSTM" f"={lstm}") | ||
env = gym.make(env_name) | ||
module = self.get_ppo_module(fw, env, lstm, shared_encoder) | ||
module = self.get_ppo_module(framework=fw, env=env, lstm=lstm) | ||
|
||
obs, _ = env.reset() | ||
|
||
|
@@ -267,22 +194,17 @@ def test_forward_train(self): | |
# TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space | ||
frameworks = ["torch", "tf2"] | ||
env_names = ["CartPole-v1", "Pendulum-v1"] | ||
shared_encoders = [False, True] | ||
ltsms = [False, True] | ||
config_combinations = [frameworks, env_names, shared_encoders, ltsms] | ||
lstm = [False, True] | ||
config_combinations = [frameworks, env_names, lstm] | ||
for config in itertools.product(*config_combinations): | ||
fw, env_name, shared_encoder, lstm = config | ||
if lstm and shared_encoder: | ||
# Not yet implemented | ||
# TODO (Artur): Implement | ||
continue | ||
fw, env_name, lstm = config | ||
if lstm and fw == "tf2": | ||
# LSTM not implemented in TF2 yet | ||
continue | ||
print(f"[ENV={env_name}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}") | ||
print(f"[FW={fw} | [ENV={env_name}] | LSTM={lstm}") | ||
env = gym.make(env_name) | ||
|
||
module = self.get_ppo_module(fw, env, lstm, shared_encoder) | ||
module = self.get_ppo_module(fw, env, lstm) | ||
|
||
# collect a batch of data | ||
batches = [] | ||
|
@@ -343,6 +265,9 @@ def test_forward_train(self): | |
for param in module.parameters(): | ||
self.assertIsNotNone(param.grad) | ||
else: | ||
batch = tree.map_structure( | ||
lambda x: tf.convert_to_tensor(x, dtype=tf.float32), batch | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. -> because tf does not accept numpy arrays. |
||
with tf.GradientTape() as tape: | ||
fwd_out = module.forward_train(batch) | ||
loss = dummy_tf_ppo_loss(batch, fwd_out) | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll reintroduce this in a upcoming PR with the ActorCriticEncoder