[RLlib] New API stack: (Multi)RLModule overhaul vol 04 (deprecate RLModuleConfig; cleanups, DefaultModelConfig dataclass). (ray-project#47908)

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
sven1977 authored and ujjawal-khare committed Oct 15, 2024
1 parent fb34398 commit e993de6
Showing 49 changed files with 396 additions and 496 deletions.
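
The core API change, condensed into a before/after sketch (not verbatim from the diff; obs_space and act_space stand for gymnasium spaces, while the import paths are the ones used in the files below):

# Deprecated: bundle everything into an RLModuleConfig first.
# config = RLModuleConfig(
#     observation_space=obs_space,
#     action_space=act_space,
#     model_config_dict={"use_lstm": False},
#     catalog_class=PPOCatalog,
# )
# module = PPOTorchRLModule(config)

# New: pass the settings directly; the loose dict becomes a typed dataclass.
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

module = PPOTorchRLModule(
    observation_space=obs_space,
    action_space=act_space,
    model_config=DefaultModelConfig(use_lstm=False),
    catalog_class=PPOCatalog,
)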
8 changes: 4 additions & 4 deletions rllib/algorithms/algorithm_config.py
@@ -3362,9 +3362,9 @@ def rl_module(
         """Sets the config's RLModule settings.
         Args:
-            model_config_dict: The default model config dictionary for `RLModule`s. This
-                is used for any `RLModule` if not otherwise specified in the
-                `rl_module_spec`.
+            model_config: The DefaultModelConfig object (or a config dictionary) passed
+                as `model_config` arg into each RLModule's constructor. This is used
+                for all RLModules, if not otherwise specified through `rl_module_spec`.
             rl_module_spec: The RLModule spec to use for this config. It can be either
                 a RLModuleSpec or a MultiRLModuleSpec. If the
                 observation_space, action_space, catalog_class, or the model config is
@@ -4186,7 +4186,7 @@ def model_config(self):
         This method combines the auto configuration `self._model_config_auto_includes`
         defined by an algorithm with the user-defined configuration in
-        `self._model_config_dict`. This configuration dictionary is used to
+        `self._model_config`. This configuration dictionary is used to
         configure the `RLModule` in the new stack and the `ModelV2` in the old
         stack.
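
In user code, the renamed argument is used like this (a sketch; PPOConfig, the `rl_module()` setter, and the dataclass field come straight from this commit's diffs, the concrete values are illustrative):

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    # Old: .rl_module(model_config_dict={"fcnet_hiddens": [32, 32]})
    .rl_module(model_config=DefaultModelConfig(fcnet_hiddens=[32, 32]))
)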
5 changes: 2 additions & 3 deletions rllib/algorithms/bc/torch/bc_torch_rl_module.py
@@ -20,13 +20,12 @@ def _forward(self, batch: Dict, **kwargs) -> Dict[str, Any]:
         """Generic BC forward pass (for all phases of training/evaluation)."""
         output = {}

-        # State encodings.
+        # Encoder forward pass.
         encoder_outs = self.encoder(batch)
         if Columns.STATE_OUT in encoder_outs:
             output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]

         # Actions.
-        action_logits = self.pi(encoder_outs[ENCODER_OUT])
-        output[Columns.ACTION_DIST_INPUTS] = action_logits
+        output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT])

         return output
9 changes: 4 additions & 5 deletions rllib/algorithms/marwil/tests/test_marwil.py
@@ -167,14 +167,13 @@ def possibly_masked_mean(data_):

            # Calculate our own expected values (to then compare against the
            # agent's loss output).
-           fwd_out = (
-               algo.learner_group._learner.module[DEFAULT_MODULE_ID]
-               .unwrapped()
-               .forward_train({k: v for k, v in batch[DEFAULT_MODULE_ID].items()})
+           module = algo.learner_group._learner.module[DEFAULT_MODULE_ID].unwrapped()
+           fwd_out = module.forward_train(
+               {k: v for k, v in batch[DEFAULT_MODULE_ID].items()}
            )
            advantages = (
                batch[DEFAULT_MODULE_ID][Columns.VALUE_TARGETS].detach().cpu().numpy()
-               - fwd_out["vf_preds"].detach().cpu().numpy()
+               - module.compute_values(batch[DEFAULT_MODULE_ID]).detach().cpu().numpy()
            )
            advantages_squared = possibly_masked_mean(np.square(advantages))
            c_2 = 100.0 + 1e-8 * (advantages_squared - 100.0)
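
The test now reads value estimates through the module's value-function interface instead of picking "vf_preds" out of the train-forward output. The pattern in isolation (a sketch; `module` and `batch` stand for any ValueFunctionAPI-implementing RLModule and a matching sample batch):

import torch

with torch.no_grad():
    # compute_values() returns one value estimate per batch row,
    # independent of any forward_train() call.
    vf_preds = module.compute_values(batch)
print(vf_preds.shape)  # -> (batch_size,)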
4 changes: 2 additions & 2 deletions rllib/algorithms/ppo/ppo_rl_module.py
@@ -32,7 +32,7 @@ def setup(self):
         self.inference_only = False
         # If this is an `inference_only` Module, we'll have to pass this information
         # to the encoder config as well.
-        if self.config.inference_only and self.framework == "torch":
+        if self.inference_only and self.framework == "torch":
             self.catalog.actor_critic_encoder_config.inference_only = True

         # Build models from catalog.
@@ -54,6 +54,6 @@ def get_non_inference_attributes(self) -> List[str]:
         """Return attributes, which are NOT inference-only (only used for training)."""
         return ["vf"] + (
             []
-            if self.config.model_config_dict.get("vf_share_layers")
+            if self.model_config.get("vf_share_layers")
             else ["encoder.critic_encoder"]
         )
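
Config lookups inside a module drop the `self.config` indirection throughout this commit; `self.model_config` now behaves like a plain dict for default modules. A minimal sketch of the new pattern in a custom `setup()` (class name and field defaults are illustrative):

from ray.rllib.core.rl_module.torch import TorchRLModule


class MyPPOStyleModule(TorchRLModule):
    def setup(self):
        # Replaces the deprecated self.config.model_config_dict lookups.
        hiddens = self.model_config.get("fcnet_hiddens", [256, 256])
        share_layers = self.model_config.get("vf_share_layers", False)
        # ... build encoder/heads from these settings ...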
20 changes: 9 additions & 11 deletions rllib/algorithms/ppo/tests/test_ppo.py
@@ -8,12 +8,12 @@
 from ray.rllib.algorithms.callbacks import DefaultCallbacks
 from ray.rllib.core import DEFAULT_MODULE_ID
 from ray.rllib.core.learner.learner import DEFAULT_OPTIMIZER, LR_KEY
-
+from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
 from ray.rllib.utils.metrics import LEARNER_RESULTS
 from ray.rllib.utils.test_utils import check, check_train_results_new_api_stack


-def get_model_config(framework, lstm=False):
+def get_model_config(lstm=False):
     return (
         dict(
             use_lstm=True,
@@ -102,9 +102,7 @@ def test_ppo_compilation_and_schedule_mixins(self):
                 print("Env={}".format(env))
                 for lstm in [False]:
                     print("LSTM={}".format(lstm))
-                    config.rl_module(
-                        model_config_dict=get_model_config("torch", lstm=lstm)
-                    ).framework(eager_tracing=False)
+                    config.rl_module(model_config=get_model_config(lstm=lstm))

                     algo = config.build(env=env)
                     # TODO: Maybe add an API to get the Learner(s) instances within
@@ -143,12 +141,12 @@ def test_ppo_free_log_std(self):
                 num_env_runners=1,
             )
             .rl_module(
-                model_config_dict={
-                    "fcnet_hiddens": [10],
-                    "fcnet_activation": "linear",
-                    "free_log_std": True,
-                    "vf_share_layers": True,
-                }
+                model_config=DefaultModelConfig(
+                    fcnet_hiddens=[10],
+                    fcnet_activation="linear",
+                    free_log_std=True,
+                    vf_share_layers=True,
+                ),
             )
             .training(
                 gamma=0.99,
21 changes: 5 additions & 16 deletions rllib/algorithms/ppo/tests/test_ppo_rl_module.py
@@ -21,17 +21,6 @@
 torch, nn = try_import_torch()


-def get_expected_module_config(env, model_config_dict, observation_space):
-    config = RLModuleConfig(
-        observation_space=observation_space,
-        action_space=env.action_space,
-        model_config_dict=model_config_dict,
-        catalog_class=PPOCatalog,
-    )
-
-    return config
-
-
 def dummy_torch_ppo_loss(module, batch, fwd_out):
     adv = batch[Columns.REWARDS] - module.compute_values(batch)
     action_dist_class = module.get_train_action_dist_cls()
@@ -46,12 +35,12 @@ def dummy_torch_ppo_loss(module, batch, fwd_out):


 def _get_ppo_module(env, lstm, observation_space):
-    model_config_dict = {"use_lstm": lstm}
-    config = get_expected_module_config(
-        env, model_config_dict=model_config_dict, observation_space=observation_space
+    return PPOTorchRLModule(
+        observation_space=observation_space,
+        action_space=env.action_space,
+        model_config=DefaultModelConfig(use_lstm=lstm),
+        catalog_class=PPOCatalog,
     )
-    module = PPOTorchRLModule(config)
-    return module


 def _get_input_batch_from_obs(obs, lstm):
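
A module built through the new constructor can be exercised directly; a usage sketch under the same assumptions as the helper above (gymnasium CartPole, torch tensors, Columns-keyed batches):

import gymnasium as gym
import torch
from ray.rllib.core.columns import Columns

env = gym.make("CartPole-v1")
module = _get_ppo_module(env, lstm=False, observation_space=env.observation_space)

obs, _ = env.reset()
batch = {Columns.OBS: torch.from_numpy(obs).unsqueeze(0).float()}
out = module.forward_inference(batch)
print(out[Columns.ACTION_DIST_INPUTS].shape)  # action logits: (1, num_actions)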
6 changes: 2 additions & 4 deletions rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -3,7 +3,7 @@
 from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.models.base import ACTOR, CRITIC, ENCODER_OUT
-from ray.rllib.core.rl_module.apis import ValueFunctionAPI
+from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
 from ray.rllib.core.rl_module.rl_module import RLModule
 from ray.rllib.core.rl_module.torch import TorchRLModule
 from ray.rllib.utils.annotations import override
@@ -14,8 +14,6 @@


 class PPOTorchRLModule(TorchRLModule, PPORLModule):
-    framework: str = "torch"
-
     @override(RLModule)
     def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
         """Default forward pass (used for inference and exploration)."""
@@ -31,7 +29,7 @@ def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:

     @override(RLModule)
     def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
-        """Train forward pass (keep features for possible shared value func. call)."""
+        """Train forward pass (keep embeddings for possible shared value func. call)."""
         output = {}
         encoder_outs = self.encoder(batch)
         output[Columns.EMBEDDINGS] = encoder_outs[ENCODER_OUT][CRITIC]
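
The reworded docstring points at why the critic embeddings are cached under Columns.EMBEDDINGS: a later value-function call can skip re-encoding. A sketch of a compute_values() taking advantage of that (the `embeddings` parameter follows ValueFunctionAPI; `self.vf` as the value-head name is an assumption mirroring the PPO module):

    # Sketch only; assumes self.encoder and self.vf as built by the PPO catalog.
    def compute_values(self, batch, embeddings=None):
        if embeddings is None:
            # No cached embeddings passed in: run the critic encoder branch.
            embeddings = self.encoder(batch)[ENCODER_OUT][CRITIC]
        # Value head emits (B, 1); squeeze to (B,).
        return self.vf(embeddings).squeeze(-1)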
2 changes: 1 addition & 1 deletion rllib/algorithms/sac/sac_rl_module.py
@@ -75,7 +75,7 @@ def setup(self):
         # Build heads.
         self.pi = self.catalog.build_pi_head(framework=self.framework)

-        if not self.config.inference_only or self.framework != "torch":
+        if not self.inference_only or self.framework != "torch":
             self.qf = self.catalog.build_qf_head(framework=self.framework)
             # If necessary build also a twin Q heads.
             if self.twin_q:
48 changes: 0 additions & 48 deletions rllib/algorithms/tests/test_algorithm_config.py
@@ -404,30 +404,6 @@ def get_default_rl_module_spec(self):
         spec, expected = self._get_expected_marl_spec(config, CustomRLModule1)
         self._assertEqualMARLSpecs(spec, expected)

-        # expected module should become the passed module if we pass it in.
-        spec, expected = self._get_expected_marl_spec(
-            config, CustomRLModule2, passed_module_class=CustomRLModule2
-        )
-        self._assertEqualMARLSpecs(spec, expected)
-        ########################################
-        # This is an alternative way to ask the algorithm to assign a specific type of
-        # RLModule class to ALL module_ids.
-        config = (
-            SingleAgentAlgoConfig()
-            .api_stack(
-                enable_rl_module_and_learner=True,
-                enable_env_runner_and_connector_v2=True,
-            )
-            .rl_module(
-                rl_module_spec=MultiRLModuleSpec(
-                    module_specs=RLModuleSpec(module_class=CustomRLModule1)
-                ),
-            )
-        )
-
-        spec, expected = self._get_expected_marl_spec(config, CustomRLModule1)
-        self._assertEqualMARLSpecs(spec, expected)
-
         # expected module should become the passed module if we pass it in.
         spec, expected = self._get_expected_marl_spec(
             config, CustomRLModule2, passed_module_class=CustomRLModule2
@@ -490,30 +466,6 @@ def get_default_rl_module_spec(self):
             lambda: config.rl_module_spec,
         )

-        ########################################
-        # This is the case where we ask the algorithm to use its default
-        # MultiRLModuleSpec, and the MultiRLModuleSpec has defined its
-        # RLModuleSpecs.
-        config = MultiAgentAlgoConfig().api_stack(
-            enable_rl_module_and_learner=True,
-            enable_env_runner_and_connector_v2=True,
-        )
-
-        spec, expected = self._get_expected_marl_spec(
-            config,
-            DiscreteBCTorchModule,
-            expected_multi_rl_module_class=CustomMultiRLModule1,
-        )
-        self._assertEqualMARLSpecs(spec, expected)
-
-        spec, expected = self._get_expected_marl_spec(
-            config,
-            CustomRLModule1,
-            passed_module_class=CustomRLModule1,
-            expected_multi_rl_module_class=CustomMultiRLModule1,
-        )
-        self._assertEqualMARLSpecs(spec, expected)
-

 if __name__ == "__main__":
     import pytest
4 changes: 2 additions & 2 deletions rllib/algorithms/tests/test_algorithm_rl_module_restore.py
@@ -143,7 +143,7 @@ def test_e2e_load_complex_multi_rl_module(self):
             module_class=PPOTorchRLModule,
             observation_space=env.get_observation_space(0),
             action_space=env.get_action_space(0),
-            model_config_dict={"fcnet_hiddens": [64]},
+            model_config=DefaultModelConfig(fcnet_hiddens=[64]),
             catalog_class=PPOCatalog,
             load_state_path=module_to_swap_in_path,
         )
@@ -283,7 +283,7 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self):
             module_class=PPOTorchRLModule,
             observation_space=env.get_observation_space(0),
             action_space=env.get_action_space(0),
-            model_config_dict={"fcnet_hiddens": [64]},
+            model_config=DefaultModelConfig(fcnet_hiddens=[64]),
             catalog_class=PPOCatalog,
             load_state_path=module_to_swap_in_path,
         )
14 changes: 12 additions & 2 deletions rllib/core/models/catalog.py
@@ -102,11 +102,21 @@ def __init__(
         if view_requirements != DEPRECATED_VALUE:
             deprecation_warning(old="Catalog(view_requirements=..)", error=True)

+        # TODO (sven): The following logic won't be needed anymore, once we get rid of
+        #  Catalogs entirely. We will assert directly inside the algo's DefaultRLModule
+        #  class that the `model_config` is a DefaultModelConfig. Thus users won't be
+        #  able to pass in partial config dicts into a default model (alternatively, we
+        #  could automatically augment the user provided dict by the default config
+        #  dataclass object only(!) for default modules).
+        if dataclasses.is_dataclass(model_config_dict):
+            model_config_dict = dataclasses.asdict(model_config_dict)
+        default_config = dataclasses.asdict(DefaultModelConfig())
+        # end: TODO
+
         self.observation_space = observation_space
         self.action_space = action_space

-        # TODO (Artur): Make model defaults a dataclass
-        self._model_config_dict = {**MODEL_DEFAULTS, **model_config_dict}
+        self._model_config_dict = default_config | model_config_dict
         self._latent_dims = None

         self._determine_components_hook()
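
The replacement merge relies on PEP 584 dict union, so user-supplied keys override the dataclass defaults. The behavior in isolation (field name taken from the tests in this commit):

import dataclasses
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

defaults = dataclasses.asdict(DefaultModelConfig())
user = {"fcnet_hiddens": [64]}

merged = defaults | user  # right-hand operand wins on key conflicts
assert merged["fcnet_hiddens"] == [64]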
4 changes: 2 additions & 2 deletions rllib/core/rl_module/apis/inference_only_api.py
@@ -7,9 +7,9 @@ class InferenceOnlyAPI(abc.ABC):
     Only the `get_non_inference_attributes` method needs to get implemented for
     an RLModule to have the following functionality:
-    - On EnvRunners (or when self.config.inference_only=True), RLlib will remove
+    - On EnvRunners (or when self.inference_only=True), RLlib will remove
       those parts of the model not required for action computation.
-    - An RLModule on a Learner (where `self.config.inference_only=False`) will
+    - An RLModule on a Learner (where `self.inference_only=False`) will
       return only those weights from `get_state()` that are part of its inference-only
       version, thus possibly saving network traffic/time.
     """
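
Implementing the API stays a one-method affair; a sketch mirroring the PPO module earlier in this commit (the attribute names assume an actor-critic encoder with a separate critic branch):

from typing import List

from ray.rllib.core.rl_module.apis.inference_only_api import InferenceOnlyAPI


class MyTrainAndServeModule(InferenceOnlyAPI):  # mixed into an RLModule subclass
    def get_non_inference_attributes(self) -> List[str]:
        # Everything listed here is stripped from inference-only (EnvRunner)
        # copies: the value head and the critic half of the encoder.
        return ["vf", "encoder.critic_encoder"]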