[RLlib] New API stack: (Multi)RLModule overhaul vol 04 (deprecate RLModuleConfig; cleanups, DefaultModelConfig dataclass). (ray-project#47908)

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
sven1977 authored and ujjawal-khare committed Oct 15, 2024
1 parent a994eec commit 08fc41e
Showing 54 changed files with 426 additions and 525 deletions.
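For orientation, a minimal sketch of the API change this commit rolls out; PPO serves only as an example, and everything beyond the names visible in the diffs below is illustrative:

```python
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

# Old style (now soft-deprecated): pass a plain dict via `model_config_dict`.
config = PPOConfig().rl_module(model_config_dict={"fcnet_hiddens": [64]})

# New style: pass a DefaultModelConfig dataclass (or a dict) via `model_config`.
config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .rl_module(model_config=DefaultModelConfig(fcnet_hiddens=[64]))
)
```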
19 changes: 13 additions & 6 deletions rllib/algorithms/algorithm_config.py
@@ -3362,9 +3362,9 @@ def rl_module(
"""Sets the config's RLModule settings.
Args:
model_config_dict: The default model config dictionary for `RLModule`s. This
is used for any `RLModule` if not otherwise specified in the
`rl_module_spec`.
model_config: The DefaultModelConfig object (or a config dictionary) passed
as `model_config` arg into each RLModule's constructor. This is used
for all RLModules, if not otherwise specified through `rl_module_spec`.
rl_module_spec: The RLModule spec to use for this config. It can be either
a RLModuleSpec or a MultiRLModuleSpec. If the
observation_space, action_space, catalog_class, or the model config is
@@ -3392,9 +3392,16 @@ def rl_module(
new="AlgorithmConfig.api_stack(enable_rl_module_and_learner=..)",
error=True,
)
if model_config_dict != DEPRECATED_VALUE:
deprecation_warning(
old="AlgorithmConfig.rl_module(model_config_dict=..)",
new="AlgorithmConfig.rl_module(model_config=..)",
error=False,
)
model_config = model_config_dict

if model_config_dict is not NotProvided:
self._model_config_dict = model_config_dict
if model_config is not NotProvided:
self._model_config = model_config
if rl_module_spec is not NotProvided:
self._rl_module_spec = rl_module_spec
if algorithm_config_overrides_per_module is not NotProvided:
@@ -4179,7 +4186,7 @@ def model_config(self):
This method combines the auto configuration `self._model_config_auto_includes`
defined by an algorithm with the user-defined configuration in
`self._model_config_dict`. This configuration dictionary is used to
`self._model_config`. This configuration dictionary is used to
configure the `RLModule` in the new stack and the `ModelV2` in the old
stack.
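The deprecation shim above follows the usual keyword-forwarding pattern: the old `model_config_dict=` argument still works but warns and hands its value over to `model_config=`. A self-contained sketch of that pattern (names here are illustrative, not RLlib's):

```python
import warnings

_NOT_PROVIDED = object()  # sentinel: "the user did not pass this argument"


def rl_module(model_config=_NOT_PROVIDED, model_config_dict=_NOT_PROVIDED):
    # Soft deprecation: still accept the old kwarg, warn, and forward its
    # value to the new kwarg so downstream code only ever sees `model_config`.
    if model_config_dict is not _NOT_PROVIDED:
        warnings.warn(
            "rl_module(model_config_dict=..) is deprecated; "
            "use rl_module(model_config=..) instead.",
            DeprecationWarning,
        )
        model_config = model_config_dict
    if model_config is not _NOT_PROVIDED:
        print(f"storing model_config={model_config!r}")


rl_module(model_config_dict={"fcnet_hiddens": [64]})  # warns, then forwards
```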
5 changes: 2 additions & 3 deletions rllib/algorithms/bc/torch/bc_torch_rl_module.py
@@ -20,13 +20,12 @@ def _forward(self, batch: Dict, **kwargs) -> Dict[str, Any]:
"""Generic BC forward pass (for all phases of training/evaluation)."""
output = {}

# State encodings.
# Encoder forward pass.
encoder_outs = self.encoder(batch)
if Columns.STATE_OUT in encoder_outs:
output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]

# Actions.
action_logits = self.pi(encoder_outs[ENCODER_OUT])
output[Columns.ACTION_DIST_INPUTS] = action_logits
output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT])

return output
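The simplified `_forward` above is just encoder → pi head; a stripped-down PyTorch sketch of that flow, with made-up layer sizes and key names:

```python
import torch
from torch import nn


class TinyBCModule(nn.Module):
    """Schematic encoder -> pi-head forward pass, mirroring the BC module above."""

    def __init__(self, obs_dim: int = 4, hidden: int = 32, num_actions: int = 2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.pi = nn.Linear(hidden, num_actions)

    def forward(self, batch: dict) -> dict:
        # Encoder forward pass, then action-distribution inputs from the pi head.
        embeddings = self.encoder(batch["obs"])
        return {"action_dist_inputs": self.pi(embeddings)}


out = TinyBCModule()({"obs": torch.randn(8, 4)})
print(out["action_dist_inputs"].shape)  # torch.Size([8, 2])
```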
13 changes: 5 additions & 8 deletions rllib/algorithms/cql/torch/cql_torch_rl_module.py
@@ -42,17 +42,14 @@ def _forward_train(self, batch: Dict) -> Dict[str, Any]:
# First for the random actions (from the mu-distribution as named by Kumar et
# al. (2020)).
low = torch.tensor(
self.config.action_space.low,
self.action_space.low,
device=fwd_out[QF_PREDS].device,
)
high = torch.tensor(
self.config.action_space.high,
self.action_space.high,
device=fwd_out[QF_PREDS].device,
)
num_samples = (
batch[Columns.ACTIONS].shape[0]
* self.config.model_config_dict["num_actions"]
)
num_samples = batch[Columns.ACTIONS].shape[0] * self.model_config["num_actions"]
actions_rand_repeat = low + (high - low) * torch.rand(
(num_samples, low.shape[0]), device=fwd_out[QF_PREDS].device
)
@@ -128,7 +125,7 @@ def _repeat_actions(
) -> Dict[str, TensorType]:
"""Generated actions and Q-values for repeated observations.
The `self.config.model_condfig_dict["num_actions"]` define a multiplier
The `self.model_config["num_actions"]` value defines a multiplier
used to generate `num_actions` times as many actions as the batch size.
Observations are repeated and then a model forward pass is made.
@@ -145,7 +142,7 @@ def _repeat_actions(
# Receive the batch size.
batch_size = obs.shape[0]
# Receive the number of actions to sample.
num_actions = self.config.model_config_dict["num_actions"]
num_actions = self.model_config["num_actions"]
# Repeat the observations `num_actions` times.
obs_repeat = tree.map_structure(
lambda t: self._repeat_tensor(t, num_actions), obs
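For context on the `num_actions` multiplier used above: the observation batch is repeated `num_actions` times before the Q-network forward pass, and one random action is drawn per repeated observation. A minimal PyTorch sketch with arbitrary shapes (not RLlib's `_repeat_tensor` helper):

```python
import torch

batch_size, obs_dim, num_actions = 4, 3, 10
obs = torch.randn(batch_size, obs_dim)

# Repeat every observation `num_actions` times -> (batch_size * num_actions, obs_dim).
obs_repeat = obs.repeat_interleave(num_actions, dim=0)

# One uniform random action per repeated observation, bounded by the action space.
low, high = torch.tensor([-1.0]), torch.tensor([1.0])
actions_rand = low + (high - low) * torch.rand(batch_size * num_actions, 1)

print(obs_repeat.shape, actions_rand.shape)  # torch.Size([40, 3]) torch.Size([40, 1])
```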
9 changes: 4 additions & 5 deletions rllib/algorithms/marwil/tests/test_marwil.py
@@ -167,14 +167,13 @@ def possibly_masked_mean(data_):

# Calculate our own expected values (to then compare against the
# agent's loss output).
fwd_out = (
algo.learner_group._learner.module[DEFAULT_MODULE_ID]
.unwrapped()
.forward_train({k: v for k, v in batch[DEFAULT_MODULE_ID].items()})
module = algo.learner_group._learner.module[DEFAULT_MODULE_ID].unwrapped()
fwd_out = module.forward_train(
{k: v for k, v in batch[DEFAULT_MODULE_ID].items()}
)
advantages = (
batch[DEFAULT_MODULE_ID][Columns.VALUE_TARGETS].detach().cpu().numpy()
- fwd_out["vf_preds"].detach().cpu().numpy()
- module.compute_values(batch[DEFAULT_MODULE_ID]).detach().cpu().numpy()
)
advantages_squared = possibly_masked_mean(np.square(advantages))
c_2 = 100.0 + 1e-8 * (advantages_squared - 100.0)
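The test now obtains value predictions through the module's `compute_values()` (the `ValueFunctionAPI`) instead of reading `vf_preds` from the train-forward output; the arithmetic on top is unchanged. A tiny NumPy illustration with arbitrary numbers:

```python
import numpy as np

value_targets = np.array([1.0, 0.5, 2.0, 1.5])
value_preds = np.array([0.8, 0.7, 1.6, 1.4])  # stand-in for module.compute_values(...)

advantages = value_targets - value_preds
print(advantages, np.mean(np.square(advantages)))
```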
4 changes: 2 additions & 2 deletions rllib/algorithms/ppo/ppo_rl_module.py
@@ -32,7 +32,7 @@ def setup(self):
self.inference_only = False
# If this is an `inference_only` Module, we'll have to pass this information
# to the encoder config as well.
if self.config.inference_only and self.framework == "torch":
if self.inference_only and self.framework == "torch":
self.catalog.actor_critic_encoder_config.inference_only = True

# Build models from catalog.
@@ -54,6 +54,6 @@ def get_non_inference_attributes(self) -> List[str]:
"""Return attributes, which are NOT inference-only (only used for training)."""
return ["vf"] + (
[]
if self.config.model_config_dict.get("vf_share_layers")
if self.model_config.get("vf_share_layers")
else ["encoder.critic_encoder"]
)
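In other words: when `vf_share_layers` is off, the critic owns its own encoder, so both `vf` and `encoder.critic_encoder` are training-only; with shared layers only the `vf` head is. A plain-Python sketch of that selection logic (attribute names taken from the diff, the standalone function is illustrative):

```python
from typing import List


def non_inference_attributes(vf_share_layers: bool) -> List[str]:
    """Attributes only needed for training, i.e. safe to strip for inference."""
    return ["vf"] + ([] if vf_share_layers else ["encoder.critic_encoder"])


print(non_inference_attributes(vf_share_layers=True))   # ['vf']
print(non_inference_attributes(vf_share_layers=False))  # ['vf', 'encoder.critic_encoder']
```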
20 changes: 9 additions & 11 deletions rllib/algorithms/ppo/tests/test_ppo.py
@@ -8,12 +8,12 @@
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.learner.learner import DEFAULT_OPTIMIZER, LR_KEY

from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.utils.metrics import LEARNER_RESULTS
from ray.rllib.utils.test_utils import check, check_train_results_new_api_stack


def get_model_config(framework, lstm=False):
def get_model_config(lstm=False):
return (
dict(
use_lstm=True,
@@ -102,9 +102,7 @@ def test_ppo_compilation_and_schedule_mixins(self):
print("Env={}".format(env))
for lstm in [False]:
print("LSTM={}".format(lstm))
config.rl_module(
model_config_dict=get_model_config("torch", lstm=lstm)
).framework(eager_tracing=False)
config.rl_module(model_config=get_model_config(lstm=lstm))

algo = config.build(env=env)
# TODO: Maybe add an API to get the Learner(s) instances within
@@ -143,12 +141,12 @@ def test_ppo_free_log_std(self):
num_env_runners=1,
)
.rl_module(
model_config_dict={
"fcnet_hiddens": [10],
"fcnet_activation": "linear",
"free_log_std": True,
"vf_share_layers": True,
}
model_config=DefaultModelConfig(
fcnet_hiddens=[10],
fcnet_activation="linear",
free_log_std=True,
vf_share_layers=True,
),
)
.training(
gamma=0.99,
21 changes: 5 additions & 16 deletions rllib/algorithms/ppo/tests/test_ppo_rl_module.py
@@ -21,17 +21,6 @@
torch, nn = try_import_torch()


def get_expected_module_config(env, model_config_dict, observation_space):
config = RLModuleConfig(
observation_space=observation_space,
action_space=env.action_space,
model_config_dict=model_config_dict,
catalog_class=PPOCatalog,
)

return config


def dummy_torch_ppo_loss(module, batch, fwd_out):
adv = batch[Columns.REWARDS] - module.compute_values(batch)
action_dist_class = module.get_train_action_dist_cls()
@@ -46,12 +35,12 @@ def dummy_torch_ppo_loss(module, batch, fwd_out):


def _get_ppo_module(env, lstm, observation_space):
model_config_dict = {"use_lstm": lstm}
config = get_expected_module_config(
env, model_config_dict=model_config_dict, observation_space=observation_space
return PPOTorchRLModule(
observation_space=observation_space,
action_space=env.action_space,
model_config=DefaultModelConfig(use_lstm=lstm),
catalog_class=PPOCatalog,
)
module = PPOTorchRLModule(config)
return module
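With `RLModuleConfig` removed, modules are constructed directly from keyword arguments, as the updated helper shows. A hedged end-to-end sketch using CartPole spaces (imports as used elsewhere in this commit; defaults are assumptions):

```python
import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

env = gym.make("CartPole-v1")
module = PPOTorchRLModule(
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(use_lstm=False),
    catalog_class=PPOCatalog,
)
```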


def _get_input_batch_from_obs(obs, lstm):
6 changes: 2 additions & 4 deletions rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -3,7 +3,7 @@
from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.base import ACTOR, CRITIC, ENCODER_OUT
from ray.rllib.core.rl_module.apis import ValueFunctionAPI
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.utils.annotations import override
@@ -14,8 +14,6 @@


class PPOTorchRLModule(TorchRLModule, PPORLModule):
framework: str = "torch"

@override(RLModule)
def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""Default forward pass (used for inference and exploration)."""
@@ -31,7 +29,7 @@ def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:

@override(RLModule)
def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""Train forward pass (keep features for possible shared value func. call)."""
"""Train forward pass (keep embeddings for possible shared value func. call)."""
output = {}
encoder_outs = self.encoder(batch)
output[Columns.EMBEDDINGS] = encoder_outs[ENCODER_OUT][CRITIC]
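Keeping the critic embeddings in the train-forward output lets a later value-function call skip a second encoder pass. A generic PyTorch sketch of that caching pattern (not RLlib's actual implementation):

```python
import torch
from torch import nn

encoder = nn.Linear(4, 8)
pi_head = nn.Linear(8, 2)
vf_head = nn.Linear(8, 1)


def forward_train(batch):
    embeddings = encoder(batch["obs"])
    return {
        "embeddings": embeddings,  # cached for a later value-function call
        "action_dist_inputs": pi_head(embeddings),
    }


def compute_values(batch, embeddings=None):
    # Reuse cached embeddings if the caller provides them; otherwise re-encode.
    if embeddings is None:
        embeddings = encoder(batch["obs"])
    return vf_head(embeddings).squeeze(-1)


out = forward_train({"obs": torch.randn(5, 4)})
values = compute_values({"obs": torch.randn(5, 4)}, embeddings=out["embeddings"])
print(values.shape)  # torch.Size([5])
```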
2 changes: 1 addition & 1 deletion rllib/algorithms/sac/sac_rl_module.py
@@ -75,7 +75,7 @@ def setup(self):
# Build heads.
self.pi = self.catalog.build_pi_head(framework=self.framework)

if not self.config.inference_only or self.framework != "torch":
if not self.inference_only or self.framework != "torch":
self.qf = self.catalog.build_qf_head(framework=self.framework)
# If necessary, also build twin Q heads.
if self.twin_q:
48 changes: 0 additions & 48 deletions rllib/algorithms/tests/test_algorithm_config.py
@@ -404,30 +404,6 @@ def get_default_rl_module_spec(self):
spec, expected = self._get_expected_marl_spec(config, CustomRLModule1)
self._assertEqualMARLSpecs(spec, expected)

# expected module should become the passed module if we pass it in.
spec, expected = self._get_expected_marl_spec(
config, CustomRLModule2, passed_module_class=CustomRLModule2
)
self._assertEqualMARLSpecs(spec, expected)
########################################
# This is an alternative way to ask the algorithm to assign a specific type of
# RLModule class to ALL module_ids.
config = (
SingleAgentAlgoConfig()
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.rl_module(
rl_module_spec=MultiRLModuleSpec(
module_specs=RLModuleSpec(module_class=CustomRLModule1)
),
)
)

spec, expected = self._get_expected_marl_spec(config, CustomRLModule1)
self._assertEqualMARLSpecs(spec, expected)

# expected module should become the passed module if we pass it in.
spec, expected = self._get_expected_marl_spec(
config, CustomRLModule2, passed_module_class=CustomRLModule2
@@ -490,30 +466,6 @@ def get_default_rl_module_spec(self):
lambda: config.rl_module_spec,
)

########################################
# This is the case where we ask the algorithm to use its default
# MultiRLModuleSpec, and the MultiRLModuleSpec has defined its
# RLModuleSpecs.
config = MultiAgentAlgoConfig().api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)

spec, expected = self._get_expected_marl_spec(
config,
DiscreteBCTorchModule,
expected_multi_rl_module_class=CustomMultiRLModule1,
)
self._assertEqualMARLSpecs(spec, expected)

spec, expected = self._get_expected_marl_spec(
config,
CustomRLModule1,
passed_module_class=CustomRLModule1,
expected_multi_rl_module_class=CustomMultiRLModule1,
)
self._assertEqualMARLSpecs(spec, expected)


if __name__ == "__main__":
import pytest
4 changes: 2 additions & 2 deletions rllib/algorithms/tests/test_algorithm_rl_module_restore.py
@@ -143,7 +143,7 @@ def test_e2e_load_complex_multi_rl_module(self):
module_class=PPOTorchRLModule,
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [64]},
model_config=DefaultModelConfig(fcnet_hiddens=[64]),
catalog_class=PPOCatalog,
load_state_path=module_to_swap_in_path,
)
@@ -283,7 +283,7 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self):
module_class=PPOTorchRLModule,
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [64]},
model_config=DefaultModelConfig(fcnet_hiddens=[64]),
catalog_class=PPOCatalog,
load_state_path=module_to_swap_in_path,
)
18 changes: 8 additions & 10 deletions rllib/connectors/common/add_states_from_episodes_to_batch.py
@@ -79,12 +79,10 @@ class AddStatesFromEpisodesToBatch(ConnectorV2):
from ray.rllib.utils.test_utils import check
# Create a simple dummy class, pretending to be an RLModule with
# `get_initial_state`, `is_stateful` and its `config` property defined:
# `get_initial_state`, `is_stateful` and `model_config` property defined:
class MyStateModule:
# dummy config class
class Cfg(dict):
model_config_dict = {"max_seq_len": 2}
config = Cfg()
# dummy config
model_config = {"max_seq_len": 2}
def is_stateful(self):
return True
@@ -300,7 +298,7 @@ def __call__(
if not sa_module.is_stateful():
continue

max_seq_len = sa_module.config.model_config_dict["max_seq_len"]
max_seq_len = sa_module.model_config["max_seq_len"]

# look_back_state.shape=([state-dim],)
look_back_state = (
@@ -390,14 +388,14 @@ def _get_max_seq_len(self, rl_module, module_id=None):
mod = rl_module[module_id]
else:
mod = next(iter(rl_module.values()))
if "max_seq_len" not in mod.config.model_config_dict:
if "max_seq_len" not in mod.model_config:
raise ValueError(
"You are using a stateful RLModule and are not providing a "
"'max_seq_len' key inside your model config dict. You can set this "
"'max_seq_len' key inside your `model_config`. You can set this "
"dict and/or override keys in it via `config.rl_module("
"model_config_dict={'max_seq_len': [some int]})`."
"model_config={'max_seq_len': [some int]})`."
)
return mod.config.model_config_dict["max_seq_len"]
return mod.model_config["max_seq_len"]
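As the updated error message states, a stateful (e.g. LSTM) RLModule needs a `max_seq_len` entry in its model config; under the new naming that is set roughly like this (a sketch; the plain-dict form is the one the error message suggests):

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    # `model_config` also accepts a plain dict; `max_seq_len` is what the
    # connector above looks up for stateful modules.
    .rl_module(model_config={"max_seq_len": 20})
)
```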


@Deprecated(
4 changes: 2 additions & 2 deletions rllib/connectors/learner/general_advantage_estimation.py
@@ -152,15 +152,15 @@ def __call__(
split_and_zero_pad_n_episodes(
module_advantages,
episode_lens=episode_lens,
max_seq_len=module.config.model_config_dict["max_seq_len"],
max_seq_len=module.model_config["max_seq_len"],
),
axis=0,
)
module_value_targets = np.stack(
split_and_zero_pad_n_episodes(
module_value_targets,
episode_lens=episode_lens,
max_seq_len=module.config.model_config_dict["max_seq_len"],
max_seq_len=module.model_config["max_seq_len"],
),
axis=0,
)
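Advantages and value targets are split back into per-episode sequences and zero-padded to `max_seq_len` before stacking. A small NumPy illustration of that padding step (arbitrary shapes; this is not RLlib's `split_and_zero_pad_n_episodes` helper):

```python
import numpy as np

max_seq_len = 4
episode_lens = [3, 5]
flat = np.arange(1.0, 9.0)  # 8 timesteps: episode 1 -> [1,2,3], episode 2 -> [4..8]

chunks, start = [], 0
for ep_len in episode_lens:
    ep = flat[start:start + ep_len]
    start += ep_len
    # Split each episode into max_seq_len-sized pieces and zero-pad the last one.
    for i in range(0, ep_len, max_seq_len):
        piece = ep[i:i + max_seq_len]
        chunks.append(np.pad(piece, (0, max_seq_len - len(piece))))

print(np.stack(chunks, axis=0))
# [[1. 2. 3. 0.]
#  [4. 5. 6. 7.]
#  [8. 0. 0. 0.]]
```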
1 change: 1 addition & 0 deletions rllib/core/learner/learner.py
@@ -154,6 +154,7 @@ class Learner(Checkpointable):
PPOTorchRLModule
)
from ray.rllib.core import COMPONENT_RL_MODULE, DEFAULT_MODULE_ID
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
env = gym.make("CartPole-v1")
(Diffs for the remaining changed files are not shown here.)