[RLlib] Cleanup examples folder 23: Curiosity (inverse dynamics model based) RLModule example. #46841
Conversation
LGTM. Awesome example. I added some comments where I had questions. Furthermore, if we decide to offer exploration as a built-in feature again, we might need wrappers for Learners and modules.
- # Prepend a NEXT_OBS from episodes to train batch connector piece (right
- # after the observation default piece).
+ # Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
+ # after the corresponding "add-OBS-..." default piece).
Not in this PR, but later we might also want a `remove` method for the connector pipeline. We can of course always override `build_learner_pipeline` in the config, but that means defining the complete pipeline instead of only the single pieces that need to be removed or replaced.
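To make the mechanics concrete, here is a minimal sketch of how such a piece gets wired in, assuming the `learner_connector` argument of `AlgorithmConfig.training()` and the connector class named in the diff comment above (both taken from the PR context, not guaranteed API):

```python
from ray.rllib.connectors.learner import (
    AddNextObservationsFromEpisodesToTrainBatch,  # class name from the comment above
)

# `config` is the example's PPO config instance. Custom learner-connector
# pieces get prepended to the default pipeline, so this piece runs right
# after the default "add-OBS-..." piece mentioned in the diff.
config.training(
    learner_connector=lambda obs_space, act_space: (
        AddNextObservationsFromEpisodesToTrainBatch()
    ),
)
```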
@@ -59,6 +59,9 @@ class Columns:
     ADVANTAGES = "advantages"
     VALUE_TARGETS = "value_targets"
+
+    # Intrinsic rewards (learning with curiosity).
+    INTRINSIC_REWARDS = "intrinsic_rewards"
Nice! Having this makes things less ugly :)
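A quick sketch of what the new column is for (hedged: where exactly intrinsic and extrinsic rewards get blended is up to the custom curiosity Learner, and `icm_out` and `curiosity_intrinsic_reward_coeff` are assumed names):

```python
from ray.rllib.core.columns import Columns

# Inside the curiosity Learner: mix the ICM's intrinsic rewards (published
# under the new column by the ICM's forward pass) into the env's extrinsic
# rewards before the loss/advantage computation.
batch[Columns.REWARDS] = (
    batch[Columns.REWARDS]
    + config.curiosity_intrinsic_reward_coeff * icm_out[Columns.INTRINSIC_REWARDS]
)
```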
@@ -886,7 +883,7 @@ def compute_loss_for_module(
         self,
         *,
         module_id: ModuleID,
-        config: Optional["AlgorithmConfig"] = None,
+        config: "AlgorithmConfig",
Do we now always need to provide a config? I think for most algorithms this is not needed, because `self.config` should be available.
I felt like this is the better solution for users, for two reasons:
- Users normally override `compute_loss_for_module`, so now they do NOT have to implement the logic for the case where `config` is None.
- Users do NOT normally override the more top-level `compute_loss`, so we can easily provide each module's individual config through our base implementations.

In other words, if we had left this arg optional, every user writing a custom loss function would have had to implement the (not too well known) logic for getting the module's individual config.
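A short sketch of what a user override looks like under the new, non-optional signature (placeholder loss body; the import path matches RLlib's PPO torch Learner):

```python
from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner


class MyPPOTorchLearner(PPOTorchLearner):
    def compute_loss_for_module(self, *, module_id, config, batch, fwd_out):
        # `config` is now always this module's own AlgorithmConfig; the old
        # `config = config or self.config` fallback is no longer needed.
        loss = super().compute_loss_for_module(
            module_id=module_id, config=config, batch=batch, fwd_out=fwd_out
        )
        # ... add custom loss terms here ...
        return loss
```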
@@ -163,13 +163,6 @@ def restore_from_path(self, *args, **kwargs):
     def get_metadata(self, *args, **kwargs):
         self.unwrapped().get_metadata(*args, **kwargs)

-    # TODO (sven): Figure out a better way to avoid having to method-spam this wrapper
Ah nice. This was still there.
    },
)
# Use our custom `curiosity` method to set up the ICM and our PPO/ICM-Learner.
.curiosity(
Nice!
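For readers skimming the thread: `curiosity()` is not a built-in config method. A hedged sketch of the pattern (method and attribute names assumed to mirror the example):

```python
from ray.rllib.algorithms.ppo import PPOConfig


class PPOConfigWithCuriosity(PPOConfig):
    # A custom, chainable config method, analogous to the built-in
    # `.training()`, `.environment()`, etc.
    def curiosity(self, *, intrinsic_reward_coeff=0.05, feature_dim=288):
        self.curiosity_intrinsic_reward_coeff = intrinsic_reward_coeff
        self.curiosity_feature_dim = feature_dim
        return self
```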
# `model_config_dict` property:
cfg = self.config.model_config_dict

feature_dim = cfg.get("feature_dim", 288)
:D How did you come up with this number? Is it in the paper?
yes :D
layers.append(nn.Linear(in_size, out_size))
if cfg.get("feature_net_activation"):
    layers.append(
        get_activation_fn(cfg["feature_net_activation"], "torch")()
Nice use of our helper function!
    in_size = out_size
# Last feature layer of n nodes (feature dimension).
layers.append(nn.Linear(in_size, feature_dim))
self._feature_net = nn.Sequential(*layers)
Dumb question: did you also want to show with this example how to build a Torch module from scratch? We could simply use our predefined MLPs.
Good point, but yes, I wanted to show a more bare-bones approach w/o using too many RLlib-related utilities. Using only torch makes this RLModule much more readable for new users.
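For comparison, the predefined-MLP route the reviewer mentions would look roughly like this (field names quoted from memory of RLlib's catalog configs and possibly version-dependent; treat the whole block as an approximation):

```python
from ray.rllib.core.models.configs import MLPHeadConfig

# Roughly equivalent encoder built from RLlib's predefined MLP config.
feature_net = MLPHeadConfig(
    input_dims=[obs_dim],          # e.g. observation_space.shape[0]
    hidden_layer_dims=[256, 256],
    hidden_layer_activation="relu",
    output_layer_dim=feature_dim,  # e.g. 288
).build(framework="torch")
```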
layers = []
dense_layers = cfg.get("feature_net_hiddens", (256, 256))
# `in_size` is the observation space (assume a simple Box(1D)).
in_size = self.config.observation_space.shape[0]
This works only for 1D spaces, doesn't it?
Currently, yes. If users need e.g. image-obs support, they can now easily write their own ICMs. I didn't want to make the example ICM too complicated.
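Pulling the snippets from the last few threads together, the feature net is built like this (re-assembled from the diff hunks above; the surrounding setup method and exact indentation are assumed):

```python
import torch.nn as nn
from ray.rllib.models.utils import get_activation_fn

# Inside the ICM RLModule's setup: a plain-torch MLP encoder phi(obs).
cfg = self.config.model_config_dict
feature_dim = cfg.get("feature_dim", 288)

layers = []
dense_layers = cfg.get("feature_net_hiddens", (256, 256))
# `in_size` is the observation space size (assumes a simple 1D Box).
in_size = self.config.observation_space.shape[0]
for out_size in dense_layers:
    layers.append(nn.Linear(in_size, out_size))
    if cfg.get("feature_net_activation"):
        layers.append(get_activation_fn(cfg["feature_net_activation"], "torch")())
    in_size = out_size
# Last feature layer of n nodes (the feature dimension).
layers.append(nn.Linear(in_size, feature_dim))
self._feature_net = nn.Sequential(*layers)
```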
# Forward loss term: Predicted phi - given phi and action - vs actually observed
# phi (square-root of L2 norm). Note that this is the intrinsic reward that
# will be used and the mean of this is the forward net loss.
forward_l2_norm_sqrt = 0.5 * torch.sum(
This looks to me like half the MSE loss - is this intended? For simplification, we could use `torch.nn.MSELoss` with `reduction="sum"` here - and, if it should be an L2 loss, a sqrt on top of it.
Yeah, correct. It's simple MSE loss. I'm always afraid of using these baked-in loss functions b/c they have no transparency and might do things differently from what my written-out code does :D
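For reference, the equivalence being discussed (tensor names and shapes below are placeholders: `phi_pred` stands for the forward net's predicted next features, `phi_next` for the encoder's actual next features):

```python
import torch

phi_pred = torch.randn(32, 288)  # forward net's predicted next features
phi_next = torch.randn(32, 288)  # encoder's actual next features

# Written-out version from the example: half the summed squared error per
# transition; this per-timestep vector doubles as the intrinsic reward.
forward_l2_norm_sqrt = 0.5 * torch.sum(torch.square(phi_pred - phi_next), dim=-1)

# Same value via the baked-in loss: element-wise MSE, summed over the
# feature axis, then halved.
mse = torch.nn.functional.mse_loss(phi_pred, phi_next, reduction="none")
assert torch.allclose(forward_l2_norm_sqrt, 0.5 * mse.sum(dim=-1))
```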
Great point. I think this example is a first step toward making something like "curiosity" pluggable into any algorithm. This same example would work for DQN with only a handful of lines changed (ICM RLModule: unchanged; PPOTorchLearnerWCuriosity: flip over to a DQNTorchLearnerWCuriosity without really having to change any code; config: stays the same except for swapping PPOConfig for DQNConfig). A sketch of that idea follows below.
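A hedged sketch of that pluggability idea (the mixin and its method are hypothetical; the Learner class names follow the comment above, and the import paths may differ across Ray versions):

```python
from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner
# Placeholder import; the DQN torch Learner's exact module path may differ:
from ray.rllib.algorithms.dqn.torch.dqn_torch_learner import DQNTorchLearner


class CuriosityLearnerMixin:
    """Algorithm-agnostic ICM wiring (intrinsic-reward computation etc.)."""

    def _add_intrinsic_rewards(self, batch):
        ...  # run the ICM and blend intrinsic into extrinsic rewards


# Only the base class flips between algorithms; the ICM code is shared:
class PPOTorchLearnerWCuriosity(CuriosityLearnerMixin, PPOTorchLearner):
    pass


class DQNTorchLearnerWCuriosity(CuriosityLearnerMixin, DQNTorchLearner):
    pass
```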
Cleanup examples folder 23: Curiosity (inverse dynamics model based) RLModule example.
Why are these changes needed?

Related issue number

Checks
- I've signed off every commit (git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I have added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.