[RLlib] Chaining sub-models in RLModules with dynamic spec keys inside forward methods ("Solution 1") #31310

Closed
Commits (30; the diff below shows the combined changes from all of them):
All commits authored by ArturNiederfahrenhorst.

a41c245  initial (Dec 16, 2022)
a15a82d  tests complete (Dec 16, 2022)
37d5708  wip (Dec 16, 2022)
2d40569  wip (Dec 16, 2022)
bedd45c  Merge branch 'master' into rlmoduletests (Dec 18, 2022)
74d213b  mutually exclusive encoders, tests passing (Dec 18, 2022)
ddf8596  add lstm code (Dec 19, 2022)
dae265d  Merge branch 'master' into rlmoduletests (Dec 19, 2022)
269a5dd  add underscore to forward method (Dec 19, 2022)
b00b9ee  better docs for get expected model config (Dec 21, 2022)
8157f63  kourosh's comments (Dec 21, 2022)
a174623  lstm fixed, tests working (Dec 21, 2022)
3a4ea01  add state out (Dec 21, 2022)
5ee9a4d  add __main__ to test (Dec 21, 2022)
eef53af  change lstm testing according to kourosh's comment (Dec 21, 2022)
3d1ebde  fix get_initial_state (Dec 21, 2022)
fdab59e  remove useless forward_exploration/forward_inference branch (Dec 21, 2022)
d7a9f17  revert changes to test_ppo_with_rl_module.py (Dec 21, 2022)
a3da1b6  merge master (Dec 22, 2022)
b1bb02f  remove pass (Dec 22, 2022)
fe9226e  wip (Dec 22, 2022)
8f36e45  fix gym incompatability (Dec 22, 2022)
8493100  merge rlmoduletest (Dec 22, 2022)
30be028  wip (Dec 25, 2022)
b1f8064  wip (Dec 26, 2022)
05cc03a  fix lstm test (Dec 27, 2022)
66ac43c  add missing ray. (Dec 27, 2022)
a74d9e8  merge rlmoduletest (Dec 27, 2022)
1d09f03  typo (Jan 3, 2023)
3ecc5e1  merge master (Jan 4, 2023)
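
The core pattern this PR introduces ("Solution 1") is easier to see in miniature before reading the diff: every sub-model owns a set of dynamically generated spec keys, and a forward method chains sub-models by mapping one model's output key onto the next model's input key. The sketch below is a self-contained toy illustration; ToyModel and its string keys are hypothetical stand-ins for RLlib's BaseModel / BaseModelIOKeys machinery, not the actual API.

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """A sub-model that reads and writes a shared dict under its own IO keys."""

    _count = 0

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        ToyModel._count += 1
        # Per-instance keys play the role of BaseModelIOKeys entries.
        self.io = {
            "in": f"model_{ToyModel._count}_in",
            "out": f"model_{ToyModel._count}_out",
        }
        self.net = nn.Linear(in_dim, out_dim)

    def forward(self, batch: dict, input_mapping: dict) -> dict:
        # The caller decides which existing key feeds this model's input.
        x = batch[input_mapping[self.io["in"]]]
        batch[self.io["out"]] = self.net(x)
        return batch


# Chain shared encoder -> pi head by wiring keys at call time, as the
# _forward_* methods in the diff below do via `input_mapping`.
shared = ToyModel(4, 8)
pi = ToyModel(8, 2)

batch = {"obs": torch.randn(1, 4)}
x = shared(batch, input_mapping={shared.io["in"]: "obs"})
x = pi(x, input_mapping={pi.io["in"]: shared.io["out"]})
print(x[pi.io["out"]].shape)  # torch.Size([1, 2])

The point of the indirection is that no sub-model hard-codes where its input lives in the batch; the caller decides per forward method, which is exactly what the rewritten _forward_inference / _forward_exploration / _forward_train below do with input_mapping.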
33 changes: 20 additions & 13 deletions rllib/algorithms/ppo/tests/test_ppo_rl_module.py
@@ -15,8 +15,6 @@
     FCConfig,
     IdentityConfig,
     LSTMConfig,
-    STATE_IN,
-    STATE_OUT,
 )
 from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.torch_utils import convert_to_torch_tensor
@@ -76,8 +74,18 @@ def get_expected_model_config(
         activation="ReLU",
     )

-    pi_config = FCConfig()
-    vf_config = FCConfig()
+    pi_config = FCConfig(
+        input_dim=pi_encoder_config.output_dim,
+        hidden_layers=[16],
+    )
+    if isinstance(env.action_space, gym.spaces.Discrete):
+        pi_config.output_dim = env.action_space.n
+    else:
+        pi_config.output_dim = env.action_space.shape[0] * 2
+
+    vf_config = FCConfig(
+        input_dim=vf_encoder_config.output_dim, hidden_layers=[16], output_dim=1
+    )

     if isinstance(env.action_space, gym.spaces.Discrete):
         pi_config.output_dim = env.action_space.n
@@ -110,14 +118,14 @@ def test_rollouts(self):
         for env_name in ["CartPole-v1", "Pendulum-v1"]:
             for fwd_fn in ["forward_exploration", "forward_inference"]:
                 for shared_encoder in [False, True]:
-                    for lstm in [True, False]:
+                    for lstm in [False, True]:
                         if lstm and shared_encoder:
                             # Not yet implemented
                             # TODO (Artur): Implement
                             continue
                         print(
                             f"[ENV={env_name}] | [SHARED={shared_encoder}] | LSTM"
-                            f"={lstm}"
+                            f"={lstm} | [FWD={fwd_fn}"
                         )
                         env = gym.make(env_name)

@@ -135,7 +143,7 @@
                             state_in = tree.map_structure(
                                 lambda x: x[None], convert_to_torch_tensor(state_in)
                             )
-                            batch[STATE_IN] = state_in
+                            batch["state_in_0"] = state_in
                             batch[SampleBatch.SEQ_LENS] = torch.Tensor([1])

                         if fwd_fn == "forward_exploration":
@@ -146,8 +154,8 @@
     def test_forward_train(self):
         # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space
         for env_name in ["CartPole-v1", "Pendulum-v1"]:
-            for shared_encoder in [False, True]:
-                for lstm in [True, False]:
+            for shared_encoder in [False]:
+                for lstm in [True]:
                     if lstm and shared_encoder:
                         # Not yet implemented
                         # TODO (Artur): Implement
@@ -175,7 +183,7 @@ def test_forward_train(self):
                     if lstm:
                         input_batch = {
                             SampleBatch.OBS: convert_to_torch_tensor(obs)[None],
-                            STATE_IN: state_in,
+                            "state_in_0": state_in,
                             SampleBatch.SEQ_LENS: np.array([1]),
                         }
                     else:
@@ -196,8 +204,7 @@
                         SampleBatch.TRUNCATEDS: np.array(truncated),
                     }
                     if lstm:
-                        assert STATE_OUT in fwd_out
-                        state_in = fwd_out[STATE_OUT]
+                        state_in = fwd_out["state_out_0"]
                     batches.append(output_batch)
                     obs = new_obs
                     tstep += 1
@@ -210,7 +217,7 @@
                     for k, v in batch.items()
                 }
                 if lstm:
-                    fwd_in[STATE_IN] = initial_state
+                    fwd_in["state_in_0"] = initial_state
                     fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([10])

                 # forward train
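Condensed from the recurrent branches of the tests above: the LSTM state now travels under the literal keys "state_in_0" / "state_out_0" instead of the removed STATE_IN / STATE_OUT constants. The toy below mirrors that contract in a self-contained way; ToyRecurrentModule is an illustrative stand-in for a built RLModule, not RLlib code.

import torch
import torch.nn as nn
import tree  # dm-tree, as used by the test above


class ToyRecurrentModule(nn.Module):
    """Stand-in for a recurrent module obeying the state_in_0/state_out_0 contract."""

    def __init__(self, obs_dim: int = 4, hidden: int = 8):
        super().__init__()
        self.cell = nn.LSTMCell(obs_dim, hidden)
        self.hidden = hidden

    def get_initial_state(self) -> dict:
        # Per-episode state without a batch dimension; the caller adds one.
        return {"h": torch.zeros(self.hidden), "c": torch.zeros(self.hidden)}

    def forward_exploration(self, batch: dict) -> dict:
        state = batch["state_in_0"]
        h, c = self.cell(batch["obs"], (state["h"], state["c"]))
        return {"features": h, "state_out_0": {"h": h, "c": c}}


module = ToyRecurrentModule()
obs = torch.randn(1, 4)  # a single observation with batch dim 1

# Mirror the test: fetch the initial state, add a batch dim, then feed
# "state_out_0" back in as "state_in_0" on every step.
state_in = tree.map_structure(lambda x: x[None], module.get_initial_state())
for _ in range(3):
    fwd_out = module.forward_exploration({"obs": obs, "state_in_0": state_in})
    state_in = fwd_out["state_out_0"]
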
173 changes: 116 additions & 57 deletions rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -16,14 +16,13 @@
     TorchDiagGaussian,
 )
 from ray.rllib.core.rl_module.encoder import (
-    FCNet,
     FCConfig,
     LSTMConfig,
     IdentityConfig,
     LSTMEncoder,
-    ENCODER_OUT,
 )
 from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space
+from ray.rllib.models.base_model import BaseModelIOKeys


 torch, nn = try_import_torch()
@@ -79,20 +78,8 @@ def setup(self) -> None:
         self.shared_encoder = self.config.shared_encoder_config.build()
         self.pi_encoder = self.config.pi_encoder_config.build()
         self.vf_encoder = self.config.vf_encoder_config.build()
-
-        self.pi = FCNet(
-            input_dim=self.config.pi_encoder_config.output_dim,
-            output_dim=self.config.pi_config.output_dim,
-            hidden_layers=self.config.pi_config.hidden_layers,
-            activation=self.config.pi_config.activation,
-        )
-
-        self.vf = FCNet(
-            input_dim=self.config.vf_encoder_config.output_dim,
-            output_dim=1,
-            hidden_layers=self.config.vf_config.hidden_layers,
-            activation=self.config.vf_config.activation,
-        )
+        self.pi = self.config.pi_config.build()
+        self.vf = self.config.vf_config.build()

         self._is_discrete = isinstance(
             convert_old_gym_space_to_gymnasium_space(self.config.action_space),
@@ -212,35 +199,46 @@ def get_initial_state(self) -> NestedDict:
         else:
             return NestedDict({})

-    @override(RLModule)
-    def input_specs_inference(self) -> SpecDict:
-        return self.input_specs_exploration()
-
     @override(RLModule)
     def output_specs_inference(self) -> SpecDict:
         return SpecDict({SampleBatch.ACTION_DIST: TorchDeterministic})

     @override(RLModule)
     def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]:
-        shared_enc_out = self.shared_encoder(batch)
-        pi_enc_out = self.pi_encoder(shared_enc_out)
-
-        action_logits = self.pi(pi_enc_out[ENCODER_OUT])
+        x = self.shared_encoder(
+            batch,
+            input_mapping={
+                self.shared_encoder.io[BaseModelIOKeys.IN]: SampleBatch.OBS,
+                self.shared_encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0",
+            },
+        )
+        x = self.pi_encoder(
+            x,
+            input_mapping={
+                self.pi_encoder.io[BaseModelIOKeys.IN]: self.shared_encoder.io[
+                    BaseModelIOKeys.OUT
+                ],
+                self.pi_encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0",
+            },
+        )

+        x = self.pi(
+            x,
+            input_mapping={
+                self.pi.io[BaseModelIOKeys.IN]: self.pi_encoder.io[BaseModelIOKeys.OUT],
+            },
+        )
+        action_logits = x[self.pi.io[BaseModelIOKeys.OUT]]
         if self._is_discrete:
             action = torch.argmax(action_logits, dim=-1)
         else:
             action, _ = action_logits.chunk(2, dim=-1)

         action_dist = TorchDeterministic(action)
         output = {SampleBatch.ACTION_DIST: action_dist}
-        output["state_out"] = pi_enc_out.get("state_out", {})
+        output["state_out_0"] = x.get("state_out", {})
         return output

-    @override(RLModule)
-    def input_specs_exploration(self):
-        return self.shared_encoder.input_spec()
-
     @override(RLModule)
     def output_specs_exploration(self) -> SpecDict:
         specs = {SampleBatch.ACTION_DIST: self.__get_action_dist_type()}
@@ -264,12 +262,40 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]:
         policy distribution to be used for computing KL divergence between the old
         policy and the new policy during training.
         """
-        encoder_out = self.shared_encoder(batch)
-        encoder_out_pi = self.pi_encoder(encoder_out)
-        encoder_out_vf = self.vf_encoder(encoder_out)
-        action_logits = self.pi(encoder_out_pi[ENCODER_OUT])
+        x = self.shared_encoder(
+            batch,
+            input_mapping={
+                self.shared_encoder.io[BaseModelIOKeys.IN]: SampleBatch.OBS,
+                self.shared_encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0",
+            },
+        )
+        x = self.pi_encoder(
+            x,
+            input_mapping={
+                self.pi_encoder.io[BaseModelIOKeys.IN]: self.shared_encoder.io[
+                    BaseModelIOKeys.OUT
+                ],
+                self.pi_encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0",
+            },
+        )
+        x = self.vf_encoder(
+            x,
+            input_mapping={
+                self.vf_encoder.io[BaseModelIOKeys.IN]: self.shared_encoder.io[
+                    BaseModelIOKeys.OUT
+                ],
+            },
+        )
+
+        x = self.pi(
+            x,
+            input_mapping={
+                self.pi.io[BaseModelIOKeys.IN]: self.pi_encoder.io[BaseModelIOKeys.OUT],
+            },
+        )

         output = {}
+        action_logits = x[self.pi.io[BaseModelIOKeys.OUT]]
         if self._is_discrete:
             action_dist = TorchCategorical(logits=action_logits)
             output[SampleBatch.ACTION_DIST_INPUTS] = {"logits": action_logits}
@@ -281,24 +307,23 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]:
             output[SampleBatch.ACTION_DIST] = action_dist

         # compute the value function
-        output[SampleBatch.VF_PREDS] = self.vf(encoder_out_vf[ENCODER_OUT]).squeeze(-1)
-        output["state_out"] = encoder_out_pi.get("state_out", {})
-        return output
+        vf_out = self.vf(
+            x,
+            input_mapping={
+                self.vf.io[BaseModelIOKeys.IN]: self.vf_encoder.io[BaseModelIOKeys.OUT],
+            },
+        )
+        output[SampleBatch.VF_PREDS] = vf_out[self.vf.io[BaseModelIOKeys.OUT]].squeeze(
+            -1
+        )

-    @override(RLModule)
-    def input_specs_train(self) -> SpecDict:
-        if self._is_discrete:
-            action_spec = TorchTensorSpec("b")
-        else:
-            action_dim = self.config.action_space.shape[0]
-            action_spec = TorchTensorSpec("b, h", h=action_dim)
-
-        spec_dict = self.shared_encoder.input_spec()
-        spec_dict.update({SampleBatch.ACTIONS: action_spec})
-        if SampleBatch.OBS in spec_dict:
-            spec_dict[SampleBatch.NEXT_OBS] = spec_dict[SampleBatch.OBS]
-        spec = SpecDict(spec_dict)
-        return spec
+        shared_encoder_state = x.get(self.shared_encoder.io[BaseModelIOKeys.STATE_OUT])
+        pi_encoder_state = x.get(self.pi_encoder.io[BaseModelIOKeys.STATE_OUT])
+
+        state_out = shared_encoder_state or pi_encoder_state
+        if state_out:
+            output["state_out_0"] = state_out
+        return output

     @override(RLModule)
     def output_specs_train(self) -> SpecDict:
@@ -314,12 +339,46 @@ def output_specs_train(self) -> SpecDict:

     @override(RLModule)
     def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]:
-        encoder_out = self.shared_encoder(batch)
-        encoder_out_pi = self.pi_encoder(encoder_out)
-        encoder_out_vf = self.vf_encoder(encoder_out)
+        x = self.shared_encoder(
+            batch,
+            input_mapping={
+                self.shared_encoder.io[BaseModelIOKeys.IN]: SampleBatch.OBS,
+                self.shared_encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0",
+            },
+        )
+        x = self.pi_encoder(
+            x,
+            input_mapping={
+                self.pi_encoder.io[BaseModelIOKeys.IN]: self.shared_encoder.io[
+                    BaseModelIOKeys.OUT
+                ],
+                self.pi_encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0",
+            },
+        )
+        x = self.vf_encoder(
+            x,
+            input_mapping={
+                self.vf_encoder.io[BaseModelIOKeys.IN]: self.shared_encoder.io[
+                    BaseModelIOKeys.OUT
+                ],
+            },
+        )

-        action_logits = self.pi(encoder_out_pi[ENCODER_OUT])
-        vf = self.vf(encoder_out_vf[ENCODER_OUT])
+        x = self.pi(
+            x,
+            input_mapping={
+                self.pi.io[BaseModelIOKeys.IN]: self.pi_encoder.io[BaseModelIOKeys.OUT],
+            },
+        )
+
+        action_logits = x[self.pi.io[BaseModelIOKeys.OUT]]
+
+        vf_out = self.vf(
+            x,
+            input_mapping={
+                self.vf.io[BaseModelIOKeys.IN]: self.vf_encoder.io[BaseModelIOKeys.OUT],
+            },
+        )

         if self._is_discrete:
             action_dist = TorchCategorical(logits=action_logits)
@@ -333,11 +392,11 @@ def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]:
         output = {
             SampleBatch.ACTION_DIST: action_dist,
             SampleBatch.ACTION_LOGP: logp,
-            SampleBatch.VF_PREDS: vf.squeeze(-1),
+            SampleBatch.VF_PREDS: vf_out[self.vf.io[BaseModelIOKeys.OUT]].squeeze(-1),
             "entropy": entropy,
         }

-        output["state_out"] = encoder_out_pi.get("state_out", {})
+        output["state_out_0"] = x.get("state_out", {})
         return output

     def __get_action_dist_type(self):
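For orientation, all three rewritten forward methods share the same dataflow, distilled from the diff above (abbreviated: shared.OUT stands for self.shared_encoder.io[BaseModelIOKeys.OUT], and so on):

    batch[SampleBatch.OBS] --shared_encoder--> x[shared.OUT]
    x[shared.OUT]          --pi_encoder------> x[pi_enc.OUT]  (plus LSTM state, when recurrent)
    x[shared.OUT]          --vf_encoder------> x[vf_enc.OUT]  (exploration and train only)
    x[pi_enc.OUT]          --pi--------------> action logits
    x[vf_enc.OUT]          --vf--------------> value estimate (SampleBatch.VF_PREDS)

_forward_inference stops after the pi branch and emits a deterministic action distribution; _forward_exploration and _forward_train also run the value branch, and the train pass additionally returns SampleBatch.ACTION_LOGP and "entropy". In every case the recurrent state, if any, is surfaced under "state_out_0".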