[RLlib] Incorporate VectorEncoder into PPORLModule and tests #31238

Closed
16 changes: 15 additions & 1 deletion rllib/BUILD
@@ -1095,6 +1095,13 @@ py_test(
srcs = ["algorithms/ppo/tests/test_ppo_with_rl_module.py"]
)

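# test PPORLModule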
py_test(
name = "test_ppo_rl_module",
tags = ["team:rllib", "algorithms_dir"],
size = "large",
srcs = ["algorithms/ppo/tests/test_ppo_rl_module.py"]
)

# PPO Reproducibility
py_test(
name = "test_repro_ppo",
@@ -1927,9 +1934,16 @@ py_test(
name = "test_torch_vector_encoder",
tags = ["team:rllib", "models"],
size = "small",
srcs = ["models/torch/encoders/tests/test_torch_vector_encoder.py"]
srcs = ["models/torch/encoders/tests/test_vector_encoder.py"]
)

# test TorchIdentity
py_test(
name = "test_torch_identity",
tags = ["team:rllib", "models"],
size = "small",
srcs = ["models/torch/tests/test_identity.py"]
)

# --------------------------------------------------------------------
# Offline
234 changes: 234 additions & 0 deletions rllib/algorithms/ppo/tests/test_ppo_rl_module.py
@@ -0,0 +1,234 @@
import unittest

import gym
import numpy as np
import torch
import tree

import ray

from ray.rllib import SampleBatch
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
PPOTorchRLModule,
get_ppo_loss,
PPOModuleConfig,
)
from ray.rllib.core.rl_module.encoder import (
LSTMConfig,
STATE_IN,
STATE_OUT,
)
from ray.rllib.models.configs.identity import IdentityConfig
from ray.rllib.models.configs.encoder import VectorEncoderConfig
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import convert_to_torch_tensor


def get_expected_model_config(env, lstm, shared_encoder) -> PPOModuleConfig:
"""Get the PPOModuleConfig that we would otherwise expect from the catalog.

Args:
env: Environment for which we build the model later.
lstm: If True, build a recurrent pi encoder.
shared_encoder: If True, build a shared encoder for pi and vf; the pi
and vf encoders are then identity. If False, the shared encoder is
identity and pi and vf get their own vector encoders.

Returns:
A PPOModuleConfig containing the relevant configs to build a PPORLModule.
"""
assert len(env.observation_space.shape) == 1, (
"Only 1D observation spaces are supported."
)
obs_dim = env.observation_space.shape[0]

if shared_encoder:
assert not lstm, "LSTM can only be used in PI"
shared_encoder_config = VectorEncoderConfig(
hidden_layer_sizes=[32],
)
pi_encoder_config = IdentityConfig()
vf_encoder_config = IdentityConfig()
else:
shared_encoder_config = IdentityConfig()
if lstm:
pi_encoder_config = LSTMConfig(
input_dim=obs_dim,
hidden_dim=32,
batch_first=True,
output_dim=32,
num_layers=1,
)
else:
pi_encoder_config = VectorEncoderConfig(
hidden_layer_sizes=[32],
)
vf_encoder_config = VectorEncoderConfig(
hidden_layer_sizes=[32],
)

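# Heads: pi maps encoder outputs to action-distribution inputs, vf to a value.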
pi_config = VectorEncoderConfig(hidden_layer_sizes=[32])

vf_config = VectorEncoderConfig(hidden_layer_sizes=[32])

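# Discrete: one output per action; continuous: two per action dim (mean, log-std).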
if isinstance(env.action_space, gym.spaces.Discrete):
pi_config.output_dim = env.action_space.n
else:
pi_config.output_dim = env.action_space.shape[0] * 2

return PPOModuleConfig(
observation_space=env.observation_space,
action_space=env.action_space,
shared_encoder_config=shared_encoder_config,
pi_encoder_config=pi_encoder_config,
vf_encoder_config=vf_encoder_config,
pi_config=pi_config,
vf_config=vf_config,
shared_encoder=shared_encoder,
)


class TestPPO(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init()

@classmethod
def tearDownClass(cls):
ray.shutdown()

def test_rollouts(self):
# TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space
for env_name in ["CartPole-v1", "Pendulum-v1"]:
for fwd_fn in ["forward_exploration", "forward_inference"]:
for shared_encoder in [False, True]:
for lstm in [True, False]:
if lstm and shared_encoder:
# Not yet implemented
# TODO (Artur): Implement
continue
print(
f"[ENV={env_name}] | [FWD={fwd_fn}] | "
f"[SHARED={shared_encoder}] | [LSTM={lstm}]"
)
env = gym.make(env_name)

config = get_expected_model_config(env, lstm, shared_encoder)
module = PPOTorchRLModule(config)

obs = env.reset()

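# Build a single-observation batch (add a batch dim of 1).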
batch = {
SampleBatch.OBS: convert_to_torch_tensor(obs)[None],
}

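# Recurrent encoders additionally need an initial state and seq_lens.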
if lstm:
state_in = module.pi_encoder.get_inital_state()
state_in = tree.map_structure(
lambda x: x[None], convert_to_torch_tensor(state_in)
)
batch[STATE_IN] = state_in
batch["seq_lens"] = torch.Tensor([1])

if fwd_fn == "forward_exploration":
module.forward_exploration(batch)
elif fwd_fn == "forward_inference":
module.forward_inference(batch)

def test_forward_train(self):
# TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space
for env_name in ["CartPole-v1", "Pendulum-v1"]:
for fwd_fn in ["forward_exploration", "forward_inference"]:
for shared_encoder in [False, True]:
for lstm in [True, False]:
if lstm and shared_encoder:
# Not yet implemented
# TODO (Artur): Implement
continue
print(
f"[ENV={env_name}] | [FWD={fwd_fn}] | "
f"[SHARED={shared_encoder}] | [LSTM={lstm}]"
)
env = gym.make(env_name)

config = get_expected_model_config(env, lstm, shared_encoder)
module = PPOTorchRLModule(config)

# collect a batch of data
batches = []
obs = env.reset()
tstep = 0
if lstm:
# TODO (Artur): Multiple states
state_in = module.pi_encoder.get_inital_state()
state_in = tree.map_structure(
lambda x: x[None], convert_to_torch_tensor(state_in)
)
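# Accumulate the per-step input states to stack into the train batch.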
output_states = state_in
while tstep < 10:
if lstm:
input_batch = {
SampleBatch.OBS: convert_to_torch_tensor(obs)[None],
STATE_IN: state_in,
SampleBatch.SEQ_LENS: np.array([1]),
}
else:
input_batch = {
SampleBatch.OBS: convert_to_torch_tensor(obs)[None]
}
fwd_out = module.forward_exploration(input_batch)
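# Sample an action from the exploration action distribution.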
action = convert_to_numpy(
fwd_out["action_dist"].sample().squeeze(0)
)
new_obs, reward, done, _ = env.step(action)
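# Record the transition for the training batch.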
output_batch = {
SampleBatch.OBS: obs,
SampleBatch.NEXT_OBS: new_obs,
SampleBatch.ACTIONS: action,
SampleBatch.REWARDS: np.array(reward),
SampleBatch.DONES: np.array(done),
}
if lstm:
assert STATE_OUT in fwd_out
if tstep > 0:  # First states are already added
# Extend nested batches of states
output_states = tree.map_structure(
lambda *s: torch.cat((s[0], s[1])),
output_states,
state_in,
)
state_in = fwd_out[STATE_OUT]
batches.append(output_batch)
obs = new_obs
tstep += 1

# convert the list of dicts to dict of lists
batch = tree.map_structure(lambda *x: list(x), *batches)
# convert dict of lists to dict of tensors
fwd_in = {
k: convert_to_torch_tensor(np.array(v))
for k, v in batch.items()
}
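# Recurrent case: 10 one-step sequences, hence seq_lens of all 1s.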
if lstm:
fwd_in[STATE_IN] = output_states
fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([1] * 10)

# forward train
# before training make sure module is on the right device and in
# training mode
module.to("cpu")
module.train()
fwd_out = module.forward_train(fwd_in)
loss = get_ppo_loss(fwd_in, fwd_out)
loss.backward()

# check that all neural net parameters have gradients
for param in module.parameters():
self.assertIsNotNone(param.grad)


if __name__ == "__main__":
import pytest
import sys

sys.exit(pytest.main(["-v", __file__]))
4 changes: 2 additions & 2 deletions rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py
@@ -1,9 +1,9 @@
-import numpy as np
import unittest

+import numpy as np

import ray
import ray.rllib.algorithms.ppo as ppo

from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY