Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Cleanup examples folder 23: Curiosity (inverse dynamics model based) RLModule example. #46841

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2499,7 +2499,7 @@ py_test(
name = "examples/curiosity/count_based_curiosity",
main = "examples/curiosity/count_based_curiosity.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
size = "large",
srcs = ["examples/curiosity/count_based_curiosity.py"],
args = ["--enable-new-api-stack", "--as-test"]
)
Expand All @@ -2508,11 +2508,21 @@ py_test(
name = "examples/curiosity/euclidian_distance_based_curiosity",
main = "examples/curiosity/euclidian_distance_based_curiosity.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
size = "large",
srcs = ["examples/curiosity/euclidian_distance_based_curiosity.py"],
args = ["--enable-new-api-stack", "--as-test"]
)

py_test(
name = "examples/curiosity/inverse_dynamics_model_based_curiosity",
main = "examples/curiosity/inverse_dynamics_model_based_curiosity.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "large",
srcs = ["examples/curiosity/inverse_dynamics_model_based_curiosity.py"],
args = ["--enable-new-api-stack", "--as-test"]
)


# subdirectory: curriculum/
# ....................................
py_test(
Expand Down
4 changes: 2 additions & 2 deletions rllib/algorithms/dqn/dqn_rainbow_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def build(self) -> None:
)
)

# Prepend a NEXT_OBS from episodes to train batch connector piece (right
# after the observation default piece).
# Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
# after the corresponding "add-OBS-..." default piece).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not in this PR, but later we might want to have also a remove method for the connector pipeline. We can of course always override build_learner_pipeline in the config, but that means to define the complete pipeline instead of single parts that need to be removed/replaced.

if self.config.add_default_connectors_to_learner_pipeline:
self._learner_connector.insert_after(
AddObservationsFromEpisodesToBatch,
Expand Down
4 changes: 2 additions & 2 deletions rllib/connectors/common/batch_individual_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def __call__(
# connector piece is called.
if not self._multi_agent:
continue
# If MA Off-Policy and independent sampling we need to overcome
# this check.
# If MA Off-Policy and independent sampling we need to overcome this
# check.
module_data = column_data
for col, col_data in module_data.copy().items():
if isinstance(col_data, list) and col != Columns.INFOS:
Expand Down
23 changes: 12 additions & 11 deletions rllib/connectors/connector_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,17 +415,18 @@ def add_batch_item(
`column`.
"""
sub_key = None
if (
single_agent_episode is not None
and single_agent_episode.agent_id is not None
):
sub_key = (
single_agent_episode.multi_agent_episode_id,
single_agent_episode.agent_id,
single_agent_episode.module_id,
)
elif single_agent_episode is not None:
sub_key = (single_agent_episode.id_,)
# SAEpisode is provided ...
if single_agent_episode is not None:
# ... and has `agent_id` -> Use agent ID and module ID from it.
if single_agent_episode.agent_id is not None:
sub_key = (
single_agent_episode.multi_agent_episode_id,
single_agent_episode.agent_id,
single_agent_episode.module_id,
)
# Otherwise, just use episode's ID.
else:
sub_key = (single_agent_episode.id_,)

if column not in batch:
batch[column] = [] if sub_key is None else {sub_key: []}
Expand Down
3 changes: 3 additions & 0 deletions rllib/core/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ class Columns:
ADVANTAGES = "advantages"
VALUE_TARGETS = "value_targets"

# Intrinsic rewards (learning with curiosity).
INTRINSIC_REWARDS = "intrinsic_rewards"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Having this makes things less ugly :)


# Loss mask. If provided in a train batch, a Learner's compute_loss_for_module
# method should respect the False-set value in here and mask out the respective
items from the loss.
Expand Down
9 changes: 3 additions & 6 deletions rllib/core/learner/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,11 +829,8 @@ def should_module_be_updated(self, module_id, multi_agent_batch=None):

@OverrideToImplementCustomLogic
def compute_loss(
self,
*,
fwd_out: Dict[str, Any],
batch: Dict[str, Any],
) -> Union[TensorType, Dict[str, Any]]:
self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any]
) -> Dict[str, Any]:
"""Computes the loss for the module being optimized.

This method must be overridden by multiagent-specific algorithm learners to
Expand Down Expand Up @@ -886,7 +883,7 @@ def compute_loss_for_module(
self,
*,
module_id: ModuleID,
config: Optional["AlgorithmConfig"] = None,
config: "AlgorithmConfig",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we now always need to provide a config? I think for most algorithms this is not needed because self.config should be available.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I felt like this is the better solution for users. 2 reasons:

  • Users normally override compute_loss_for_module, so now they do NOT have to implement a logic, where config is None.
  • Users do NOT normally override the more top-level compute_loss, so we can easily provide each module's individual config through our base implementations.

In other words, if we had left this arg to be optional, every user writing a custom loss function would have had to implement a (not too known) logic on how to get the module's individual config.

batch: Dict[str, Any],
fwd_out: Dict[str, TensorType],
) -> TensorType:
Expand Down
26 changes: 13 additions & 13 deletions rllib/core/models/specs/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def check_input_specs(
This is a stateful decorator
(https://realpython.com/primer-on-python-decorators/#stateful-decorators) to
enforce input specs for any instance method that has an argument named
`input_data` in its args.
`batch` in its args.

See more examples in ../tests/test_specs_dict.py)

Expand All @@ -183,7 +183,7 @@ def input_specs(self):
return {"obs": TensorSpec("b, d", d=64)}

@check_input_specs("input_specs", only_check_on_retry=False)
def forward(self, input_data, return_loss=False):
def forward(self, batch, return_loss=False):
...

model = MyModel()
Expand All @@ -194,11 +194,11 @@ def forward(self, input_data, return_loss=False):

Args:
func: The instance method to decorate. It should be a callable that takes
`self` as the first argument, `input_data` as the second argument and any
`self` as the first argument, `batch` as the second argument and any
other keyword argument thereafter.
input_specs: `self` should have an instance attribute whose name matches the
string in input_specs and returns the `SpecDict`, `Spec`, or simply the
`Type` that the `input_data` should comply with. It can also be None or
`Type` that the `batch` should comply with. It can also be None or
empty list / dict to enforce no input spec.
only_check_on_retry: If True, the spec will not be checked. Only if the
decorated method raises an Exception, we check the spec to provide a more
Expand All @@ -220,15 +220,15 @@ def forward(self, input_data, return_loss=False):

def decorator(func):
@functools.wraps(func)
def wrapper(self, input_data, **kwargs):
def wrapper(self, batch, **kwargs):
if cache and not hasattr(self, "__checked_input_specs_cache__"):
self.__checked_input_specs_cache__ = {}

initial_exception = None
if only_check_on_retry:
# Attempt to run the function without spec checking
try:
return func(self, input_data, **kwargs)
return func(self, batch, **kwargs)
except SpecCheckingError as e:
raise e
except Exception as e:
Expand All @@ -242,7 +242,7 @@ def wrapper(self, input_data, **kwargs):
)

# If the function was not executed successfully yet, we check specs
checked_data = input_data
checked_data = batch

if input_specs and (
initial_exception
Expand All @@ -262,7 +262,7 @@ def wrapper(self, input_data, **kwargs):
checked_data = _validate(
cls_instance=self,
method=func,
data=input_data,
data=batch,
spec=spec,
tag="input",
)
Expand Down Expand Up @@ -312,17 +312,17 @@ def output_specs(self):
return {"obs": TensorSpec("b, d", d=64)}

@check_output_specs("output_specs")
def forward(self, input_data, return_loss=False):
def forward(self, batch, return_loss=False):
return {"obs": torch.randn(32, 64)}

Args:
func: The instance method to decorate. It should be a callable that takes
`self` as the first argument, `input_data` as the second argument and any
`self` as the first argument, `batch` as the second argument and any
other keyword argument thereafter. It should return a single dict-like
object (i.e. not a tuple).
output_specs: `self` should have an instance attribute whose name matches the
string in output_specs and returns the `SpecDict`, `Spec`, or simply the
`Type` that the `input_data` should comply with. It can alos be None or
`Type` that the `batch` should comply with. It can alos be None or
empty list / dict to enforce no input spec.
cache: If True, only checks the data validation for the first time the
instance method is called.
Expand All @@ -338,11 +338,11 @@ def forward(self, input_data, return_loss=False):

def decorator(func):
@functools.wraps(func)
def wrapper(self, input_data, **kwargs):
def wrapper(self, batch, **kwargs):
if cache and not hasattr(self, "__checked_output_specs_cache__"):
self.__checked_output_specs_cache__ = {}

output_data = func(self, input_data, **kwargs)
output_data = func(self, batch, **kwargs)

if output_specs and (
not cache or func.__name__ not in self.__checked_output_specs_cache__
Expand Down
2 changes: 1 addition & 1 deletion rllib/core/models/tests/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def _determine_components(self):
module = spec.build()

module.forward_inference(
input_data={"obs": torch.ones((32, *env.observation_space.shape))}
batch={"obs": torch.ones((32, *env.observation_space.shape))}
)


Expand Down
7 changes: 0 additions & 7 deletions rllib/core/rl_module/torch/torch_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,6 @@ def restore_from_path(self, *args, **kwargs):
def get_metadata(self, *args, **kwargs):
self.unwrapped().get_metadata(*args, **kwargs)

# TODO (sven): Figure out a better way to avoid having to method-spam this wrapper
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah nice. This was still there.

# class, whenever we add a new API to any wrapped RLModule here. We could try
# auto generating the wrapper methods, but this will bring its own challenge
# (e.g. recursive calls due to __getattr__ checks, etc..).
def _compute_values(self, *args, **kwargs):
return self.unwrapped()._compute_values(*args, **kwargs)

@override(RLModule)
def unwrapped(self) -> "RLModule":
return self.module
Expand Down
2 changes: 1 addition & 1 deletion rllib/core/testing/testing_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_default_rl_module_spec(self) -> "RLModuleSpecType":

class BaseTestingLearner(Learner):
@override(Learner)
def compute_loss_for_module(self, *, module_id, config=None, batch, fwd_out):
def compute_loss_for_module(self, *, module_id, config, batch, fwd_out):
# This is to check if in the multi-gpu case, the weights across workers are
# the same. It is really only needed during testing.
if config.report_mean_weights:
Expand Down
4 changes: 2 additions & 2 deletions rllib/env/single_agent_env_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ def sample(
# For complete episodes mode, sample a single episode and
# leave coordination of sampling to `synchronous_parallel_sample`.
# TODO (simon, sven): The coordination will eventually move
# to `EnvRunnerGroup` in the future. So from the algorithm one
# would do `EnvRunnerGroup.sample()`.
# to `EnvRunnerGroup` in the future. So from the algorithm one
# would do `EnvRunnerGroup.sample()`.
else:
samples = self._sample_episodes(
num_episodes=1,
Expand Down
Loading