[RLlib] Introduce Checkpointable API for RLlib components and subcomponents. #46376

Merged · 6 commits · Jul 3, 2024

Changes from all commits
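The diff below converges Algorithm, Learner, LearnerGroup, and (multi-agent) RLModule on a shared `get_state`/`set_state` contract. As orientation, here is a minimal sketch of such a contract, inferred only from the signatures visible in this diff; the class name and the `save_to_path`/`restore_from_path` helpers are illustrative assumptions, not the PR's actual API:

```python
# A minimal sketch of a unified checkpointing contract, inferred from the
# get_state/set_state signatures in this diff. The class name and the
# save_to_path/restore_from_path helpers are illustrative assumptions,
# NOT this PR's actual API.
import abc
import json
import pathlib
from typing import Any, Dict

StateDict = Dict[str, Any]


class CheckpointableSketch(abc.ABC):
    @abc.abstractmethod
    def get_state(self) -> StateDict:
        """Returns the (possibly nested) state of this component."""

    @abc.abstractmethod
    def set_state(self, state: StateDict) -> None:
        """Restores this component in place from `state`."""

    def save_to_path(self, path: str) -> None:
        # JSON keeps the sketch dependency-free; real components would use
        # framework-native serialization for weights.
        p = pathlib.Path(path)
        p.mkdir(parents=True, exist_ok=True)
        (p / "state.json").write_text(json.dumps(self.get_state()))

    def restore_from_path(self, path: str) -> None:
        state = json.loads((pathlib.Path(path) / "state.json").read_text())
        self.set_state(state)
```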
14 changes: 7 additions & 7 deletions rllib/algorithms/algorithm.py
@@ -16,7 +16,7 @@
import time
from typing import (
Callable,
Container,
Collection,
DefaultDict,
Dict,
List,
@@ -285,11 +285,11 @@ class Algorithm(Trainable, AlgorithmBase):
@staticmethod
def from_checkpoint(
checkpoint: Union[str, Checkpoint],
policy_ids: Optional[Container[PolicyID]] = None,
policy_ids: Optional[Collection[PolicyID]] = None,
policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
policies_to_train: Optional[
Union[
Container[PolicyID],
Collection[PolicyID],
Callable[[PolicyID, Optional[SampleBatchType]], bool],
]
] = None,
@@ -2038,7 +2038,7 @@ def add_policy(
policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
policies_to_train: Optional[
Union[
Container[PolicyID],
Collection[PolicyID],
Callable[[PolicyID, Optional[SampleBatchType]], bool],
]
] = None,
@@ -2231,7 +2231,7 @@ def remove_policy(
policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
policies_to_train: Optional[
Union[
Container[PolicyID],
Collection[PolicyID],
Callable[[PolicyID, Optional[SampleBatchType]], bool],
]
] = None,
@@ -2945,11 +2945,11 @@ def _setup_eval_worker(w):
@staticmethod
def _checkpoint_info_to_algorithm_state(
checkpoint_info: dict,
policy_ids: Optional[Container[PolicyID]] = None,
policy_ids: Optional[Collection[PolicyID]] = None,
policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
policies_to_train: Optional[
Union[
Container[PolicyID],
Collection[PolicyID],
Callable[[PolicyID, Optional[SampleBatchType]], bool],
]
] = None,
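A note on the recurring `Container` → `Collection` swap in this file (and the files below): `typing.Container` guarantees only `__contains__`, while `typing.Collection` additionally guarantees `__iter__` and `__len__`, both of which the code relies on (e.g., `set(self.policies_to_train)`). A small demonstration:

```python
# Why Collection is the right bound here: Container promises only
# membership tests; Collection additionally promises iteration and length.
from typing import Collection, Container


class Everything:
    """A Container that is not a Collection: no __iter__ or __len__."""

    def __contains__(self, item: object) -> bool:
        return True


c: Container[str] = Everything()
print("any_policy" in c)   # OK: Container supports `in`.
# set(c)                   # Would raise TypeError: 'Everything' is not iterable.

ids: Collection[str] = ["pol1", "pol2"]  # lists, sets, tuples all qualify
print(set(ids), len(ids))  # Collection supports set(...), len(...), iteration.
```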
6 changes: 3 additions & 3 deletions rllib/algorithms/algorithm_config.py
@@ -7,7 +7,7 @@
from typing import (
Any,
Callable,
Container,
Collection,
Dict,
List,
Optional,
@@ -2501,7 +2501,7 @@ def multi_agent(
Callable[[AgentID, "OldEpisode"], PolicyID]
] = NotProvided,
policies_to_train: Optional[
Union[Container[PolicyID], Callable[[PolicyID, SampleBatchType], bool]]
Union[Collection[PolicyID], Callable[[PolicyID, SampleBatchType], bool]]
] = NotProvided,
policy_states_are_swappable: Optional[bool] = NotProvided,
observation_fn: Optional[Callable] = NotProvided,
@@ -3499,7 +3499,7 @@ def get_multi_agent_setup(
policies[pid].config or {}
)

# If container given, construct a simple default callable returning True
# If collection given, construct a simple default callable returning True
# if the PolicyID is found in the list/set of IDs.
if self.policies_to_train is not None and not callable(self.policies_to_train):
pols = set(self.policies_to_train)
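The branch behind that comment reduces to a set-membership closure; a self-contained sketch of the pattern (the variable names mirror but do not reproduce RLlib's actual config fields):

```python
# Sketch of the "collection given" branch: turn a collection of policy IDs
# into a default is-trainable callable via a set-membership closure.
# The names below are illustrative, not RLlib's actual config fields.
from typing import Callable, Collection, Optional, Union

PolicyID = str

policies_to_train: Union[
    Collection[PolicyID], Callable[[PolicyID, Optional[dict]], bool]
] = ["pol1", "pol2"]

if policies_to_train is not None and not callable(policies_to_train):
    pols = set(policies_to_train)

    def _default(pid: PolicyID, batch: Optional[dict] = None) -> bool:
        # True iff the PolicyID is found in the list/set of IDs.
        return pid in pols

    policies_to_train = _default

print(policies_to_train("pol1"), policies_to_train("pol3"))  # True False
```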
21 changes: 21 additions & 0 deletions rllib/core/__init__.py
@@ -5,9 +5,30 @@
DEFAULT_POLICY_ID = "default_policy"
DEFAULT_MODULE_ID = DEFAULT_POLICY_ID # TODO (sven): Change this to "default_module"

COMPONENT_AGENT_TO_MODULE_MAPPING_FN = "agent_to_module_mapping_fn"
COMPONENT_ENV_RUNNER = "env_runner"
COMPONENT_ENV_TO_MODULE_CONNECTOR = "env_to_module_connector"
COMPONENT_EVAL_ENV_RUNNER = "eval_env_runner"
COMPONENT_LEARNER = "learner"
COMPONENT_METRICS_LOGGER = "metrics_logger"
COMPONENT_MODULE_TO_ENV_CONNECTOR = "module_to_env_connector"
COMPONENT_SHOULD_MODULE_BE_UPDATED = "should_module_be_updated"
COMPONENT_OPTIMIZER = "optimizer"
COMPONENT_RL_MODULE = "rl_module"


__all__ = [
"Columns",
"COMPONENT_AGENT_TO_MODULE_MAPPING_FN",
"COMPONENT_ENV_RUNNER",
"COMPONENT_ENV_TO_MODULE_CONNECTOR",
"COMPONENT_EVAL_ENV_RUNNER",
"COMPONENT_LEARNER",
"COMPONENT_METRICS_LOGGER",
"COMPONENT_MODULE_TO_ENV_CONNECTOR",
"COMPONENT_SHOULD_MODULE_BE_UPDATED",
"COMPONENT_OPTIMIZER",
"COMPONENT_RL_MODULE",
"DEFAULT_AGENT_ID",
"DEFAULT_MODULE_ID",
"DEFAULT_POLICY_ID",
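These component constants are plain strings; their value lies in replacing typo-prone raw keys when drilling into nested state dicts. An illustrative example (the nesting shown is an assumption for illustration, not necessarily this PR's exact state layout):

```python
# Illustrative use of the new string constants as keys into a nested state
# dict. Per the diff above, these are importable from ray.rllib.core; they
# are re-defined here so the sketch runs stand-alone.
COMPONENT_LEARNER = "learner"
COMPONENT_OPTIMIZER = "optimizer"
COMPONENT_RL_MODULE = "rl_module"

state = {
    COMPONENT_LEARNER: {
        COMPONENT_RL_MODULE: {"default_policy": {"weights": [0.1, 0.2]}},
        COMPONENT_OPTIMIZER: {"lr": 1e-3},
    },
}

# Constants beat raw strings like "rl_moduel", which would fail only at runtime.
print(state[COMPONENT_LEARNER][COMPONENT_RL_MODULE]["default_policy"])
```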
79 changes: 40 additions & 39 deletions rllib/core/learner/learner.py
@@ -8,7 +8,7 @@
from typing import (
Any,
Callable,
Container,
Collection,
Dict,
List,
Hashable,
@@ -68,6 +68,7 @@
ParamRef,
ParamDict,
ResultDict,
StateDict,
TensorType,
)
from ray.util.annotations import PublicAPI
@@ -805,7 +806,7 @@ def should_module_be_updated(self, module_id, multi_agent_batch=None):
# If None, return True (by default, all modules should be updated).
if should_module_be_updated_fn is None:
return True
# If container given, return whether `module_id` is in that container.
# If collection given, return whether `module_id` is in that collection.
elif not callable(should_module_be_updated_fn):
return module_id in set(should_module_be_updated_fn)

@@ -1038,47 +1039,13 @@ def _update(

"""

def set_state(self, state: Dict[str, Any]) -> None:
"""Set the state of the learner.

Args:
state: The state of the optimizer and module. Can be obtained
from `get_state`. State is a dictionary with two keys:
"module_state" and "optimizer_state". The value of each key
is a dictionary that can be passed to `set_module_state` and
`set_optimizer_state` respectively.

"""
self._check_is_built()

# TODO (sven): Deprecate old state keys and create constants for new ones.
module_state = state.get("rl_module", state.get("module_state"))
# TODO: once we figure out the optimizer format, we can set/get the state
if module_state is None:
raise ValueError(
"state must have a key 'module_state' for the module weights"
)
self.set_module_state(module_state)

# TODO (sven): Deprecate old state keys and create constants for new ones.
optimizer_state = state.get("optimizer", state.get("optimizer_state"))
if optimizer_state is None:
raise ValueError(
"state must have a key 'optimizer_state' for the optimizer weights"
)
self.set_optimizer_state(optimizer_state)

# Update our trainable Modules information/function via our config.
# If not provided in state (None), all Modules will be trained by default.
self.config.multi_agent(policies_to_train=state.get("modules_to_train"))

def get_state(
self,
components: Optional[Union[str, List[str]]] = None,
*,
inference_only: bool = False,
module_ids: Optional[Container[ModuleID]] = None,
) -> Dict[str, Any]:
module_ids: Optional[Collection[ModuleID]] = None,
) -> StateDict:
"""Get (select components of) the state of this Learner.

Args:
@@ -1087,7 +1054,7 @@ def get_state(
inference_only: Whether to return the inference-only weight set of the
underlying RLModule. Note that this setting only has an effect if
components is None or the string "rl_module" is in components.
module_ids: Optional container of ModuleIDs to be returned only within the
module_ids: Optional collection of ModuleIDs to be returned only within the
state dict. If None (default), all module IDs' weights are returned.

Returns:
@@ -1110,6 +1077,40 @@
state["modules_to_be_updated"] = self.config.policies_to_train
return state

def set_state(self, state: StateDict) -> None:
"""Set the state of the learner.

Args:
state: The state of the optimizer and module. Can be obtained
from `get_state`. State is a dictionary with two keys:
"module_state" and "optimizer_state". The value of each key
is a dictionary that can be passed to `set_module_state` and
`set_optimizer_state` respectively.

"""
self._check_is_built()

# TODO (sven): Deprecate old state keys and create constants for new ones.
module_state = state.get("rl_module", state.get("module_state"))
# TODO: once we figure out the optimizer format, we can set/get the state
if module_state is None:
raise ValueError(
"state must have a key 'module_state' for the module weights"
)
self.set_module_state(module_state)

# TODO (sven): Deprecate old state keys and create constants for new ones.
optimizer_state = state.get("optimizer", state.get("optimizer_state"))
if optimizer_state is None:
raise ValueError(
"state must have a key 'optimizer_state' for the optimizer weights"
)
self.set_optimizer_state(optimizer_state)

# Update our trainable Modules information/function via our config.
# If not provided in state (None), all Modules will be trained by default.
self.config.multi_agent(policies_to_train=state.get("modules_to_train"))

def set_optimizer_state(self, state: Dict[str, Any]) -> None:
"""Sets the state of all optimizers currently registered in this Learner.

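To make the `components`/`module_ids` selection semantics of `get_state` concrete, here is a toy stand-in that mimics the documented filtering (`inference_only` is omitted for brevity; this is a sketch, not RLlib's implementation):

```python
# Toy stand-in mimicking the selection semantics of Learner.get_state as
# documented above: `components` picks top-level state keys, `module_ids`
# restricts which modules' weights are returned.
from typing import Any, Collection, Dict, List, Optional, Union

StateDict = Dict[str, Any]


class ToyLearner:
    def __init__(self) -> None:
        self._module_states = {"pol1": {"w": 1.0}, "pol2": {"w": 2.0}}
        self._optimizer_state = {"lr": 1e-3}

    def get_state(
        self,
        components: Optional[Union[str, List[str]]] = None,
        *,
        module_ids: Optional[Collection[str]] = None,
    ) -> StateDict:
        if isinstance(components, str):
            components = [components]
        state: StateDict = {}
        if components is None or "rl_module" in components:
            state["rl_module"] = {
                mid: mstate
                for mid, mstate in self._module_states.items()
                if module_ids is None or mid in module_ids
            }
        if components is None or "optimizer" in components:
            state["optimizer"] = dict(self._optimizer_state)
        return state


learner = ToyLearner()
print(learner.get_state(components="rl_module", module_ids={"pol1"}))
# -> {'rl_module': {'pol1': {'w': 1.0}}}
```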
13 changes: 7 additions & 6 deletions rllib/core/learner/learner_group.py
@@ -4,7 +4,7 @@
from typing import (
Any,
Callable,
Container,
Collection,
Dict,
List,
Optional,
@@ -42,6 +42,7 @@
EpisodeType,
ModuleID,
RLModuleSpec,
StateDict,
T,
)
from ray.train._internal.backend_executor import BackendExecutor
@@ -718,11 +719,11 @@ def set_weights(self, weights: Dict[str, Any]) -> None:

def get_state(
self,
components: Optional[Container[str]] = None,
components: Optional[Collection[str]] = None,
*,
inference_only: bool = False,
module_ids: Container[ModuleID] = None,
) -> Dict[str, Any]:
module_ids: Collection[ModuleID] = None,
) -> StateDict:
"""Get the states of this LearnerGroup.

Contains the Learners' state (which should be the same across Learners) and
@@ -738,7 +739,7 @@ def get_state(
modules. This is needed for algorithms in the new stack that
use inference-only modules. In this case only a part of the
parameters are synced to the workers. Default is False.
module_ids: Optional container of ModuleIDs to be returned only within the
module_ids: Optional collection of ModuleIDs to be returned only within the
state dict. If None (default), all module IDs' weights are returned.

Returns:
@@ -765,7 +766,7 @@

return {"learner_state": learner_state}

def set_state(self, state: Dict[str, Any]) -> None:
def set_state(self, state: StateDict) -> None:
"""Sets the state of this LearnerGroup.

Note that all Learners share the same state.
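Because the docstring notes that all Learners share the same state, a group can read state from any single Learner but must broadcast on restore. A toy sketch of that invariant (not RLlib's implementation):

```python
# Toy sketch of the invariant stated above: all Learners in a group share
# the same state, so get_state reads from one Learner and set_state
# broadcasts to all.
from typing import Any, Dict, List

StateDict = Dict[str, Any]


class ToyLearner:
    def __init__(self) -> None:
        self.state: StateDict = {}


class ToyLearnerGroup:
    def __init__(self, num_learners: int) -> None:
        self._learners: List[ToyLearner] = [
            ToyLearner() for _ in range(num_learners)
        ]

    def get_state(self) -> StateDict:
        # States are identical across Learners; query the first one.
        return {"learner_state": dict(self._learners[0].state)}

    def set_state(self, state: StateDict) -> None:
        for learner in self._learners:
            learner.state = dict(state["learner_state"])


group = ToyLearnerGroup(num_learners=2)
group.set_state({"learner_state": {"rl_module": {"pol1": {"w": 1.0}}}})
print(group.get_state())  # Same state visible through any Learner.
```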
35 changes: 21 additions & 14 deletions rllib/core/rl_module/marl_module.py
@@ -5,6 +5,7 @@
from typing import (
Any,
Callable,
Collection,
Dict,
KeysView,
List,
@@ -31,7 +32,7 @@
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.policy import validate_policy_id
from ray.rllib.utils.serialization import serialize_type, deserialize_type
from ray.rllib.utils.typing import ModuleID, T
from ray.rllib.utils.typing import ModuleID, StateDict, T
from ray.util import log_once
from ray.util.annotations import PublicAPI

@@ -44,7 +45,7 @@ class MultiAgentRLModule(RLModule):

This class holds a mapping from module_ids to the underlying RLModules. It provides
a convenient way of accessing each individual module, as well as accessing all of
them with only one API call. Whether or not a given module is trainable is
them with only one API call. Whether a given module is trainable is
determined by the caller of this class (not the instance of this class itself).

The extension of this class can include any arbitrary neural networks as part of
@@ -67,7 +68,8 @@ def __init__(self, config: Optional["MultiAgentRLModuleConfig"] = None) -> None:
"""Initializes a MultiagentRLModule instance.

Args:
config: The MultiAgentRLModuleConfig to use.
config: An optional MultiAgentRLModuleConfig to use. If None, will use
`MultiAgentRLModuleConfig()` as default config.
"""
super().__init__(config or MultiAgentRLModuleConfig())

@@ -212,13 +214,16 @@ def __contains__(self, item) -> bool:
return item in self._rl_modules

def __getitem__(self, module_id: ModuleID) -> RLModule:
"""Returns the module with the given module ID.
"""Returns the RLModule with the given module ID.

Args:
module_id: The module ID to get.

Returns:
The module with the given module ID.
The RLModule with the given module ID.

Raises:
KeyError: If `module_id` cannot be found in self.
"""
self._check_module_exists(module_id)
return self._rl_modules[module_id]
@@ -296,8 +301,10 @@ def _forward_exploration(

@override(RLModule)
def get_state(
self, module_ids: Optional[Set[ModuleID]] = None, inference_only: bool = False
) -> Dict[ModuleID, Any]:
self,
module_ids: Optional[Collection[ModuleID]] = None,
inference_only: bool = False,
) -> StateDict:
"""Returns the state of the multi-agent module.

This method returns the state of each module specified by module_ids. If
Expand All @@ -324,15 +331,15 @@ def get_state(
}

@override(RLModule)
def set_state(self, state_dict: Dict[ModuleID, Any]) -> None:
def set_state(self, state_dict: StateDict) -> None:
"""Sets the state of the multi-agent module.

It is assumed that the state_dict is a mapping from module IDs to their
corressponding state. This method sets the state of each module by calling
their set_state method. If you want to set the state of some of the RLModules
within this MultiAgentRLModule your state_dict can only include the state of
those RLModules. Override this method to customize the state_dict for custom
more advanced multi-agent use cases.
It is assumed that the state_dict is a mapping from module IDs to the
corresponding module's state. This method sets the state of each module by
calling their set_state method. To set the state of only some of the
RLModules within this MultiAgentRLModule, your state_dict may include just the
states of those RLModules. Override this method to customize the state_dict for
custom, more advanced multi-agent use cases.

Args:
state_dict: The state dict to set.
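The partial-update semantics described in the new `set_state` docstring: only module IDs present in the state dict are touched. A toy illustration (the classes are stand-ins, not RLlib's):

```python
# Toy illustration of MultiAgentRLModule.set_state's partial-update
# semantics: only modules whose IDs appear in state_dict are updated.
from typing import Any, Dict

StateDict = Dict[str, Any]


class ToyModule:
    def __init__(self, w: float) -> None:
        self.w = w

    def set_state(self, state: StateDict) -> None:
        self.w = state["w"]


class ToyMultiAgentModule:
    def __init__(self) -> None:
        self._rl_modules = {"pol1": ToyModule(1.0), "pol2": ToyModule(2.0)}

    def set_state(self, state_dict: StateDict) -> None:
        for module_id, module_state in state_dict.items():
            # KeyError on unknown IDs mirrors the _check_module_exists guard.
            self._rl_modules[module_id].set_state(module_state)


marl = ToyMultiAgentModule()
marl.set_state({"pol1": {"w": 10.0}})  # pol2 is left untouched.
print(marl._rl_modules["pol1"].w, marl._rl_modules["pol2"].w)  # 10.0 2.0
```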