-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] Chaining Models in RLModules #31469
[RLlib] Chaining Models in RLModules #31469
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I really like the direction this is headed to. I took a high-level look for 15 mins, so don't take this as my detailed review at all yet, but here is my feedback:
- I really like that the setup of RLModules is coming down to calling bunch of
config.build(framework=...)
calls while the forward calls are also kept minimal and simple. So if we end of with a general and yet flexible abstraction for configs and encoder / trunks the current proposed API for the RLModule is super optimal in my head. - I think
from_model_config()
has a lot of components that will be recurring across many RLModules, e.g. how to parse model_config into encoder config and trunk configs. I think we need an extendible abstraction layer for these kinds of stuff. That is what I have from catalog in mind.
def from_model_config(...):
catalog = Catalog.from_model_config(...)
encoder_config = catalog.build_encoder_config()
# a utility method that returns action_space.n or 2/1 * action_space.shape[0] or
action_dim = get_action_dim(action_space, free_log_std)
pi_config = catalog.build_trunk_config(out_dim=action_dim)
vf_config = catalog.build_trunk_config(out_dim=1)
config_ = PPOModuleConfig(
observation_space=observation_space,
action_space=action_space,
shared_encoder_config=encoder_config,
pi_config=pi_config,
vf_config=vf_config,
free_log_std=free_log_std,
)
module = PPOTorchRLModule(config_)
return module
A couple of things to notice here:
- Catalog does not build neural networks, it will just return the pre-defined neural network configs which can be constructed at run time via a simple
.build(framework)
API. - I think Catalog should still be a deep module with generic api here such as
build_encoder()
,build_trunk()
that have simple interfaces. There is a trade-off between making the interface of these apis simple vs. how deep they'll become. We need to find a sweet spot here.
- I really love that the spec checking is delegated to the sub-components now. It will make them much cleaner and easier to build.
- sub-modules should not inherit from
base_model.Model
. This class was an early version of modules with spec checking capability. I think we have a better version with the decorators in place now. For Encoders specifically, we can create a base interface class withget_initial_state()
abstract API.FCEncoder(Encoder)
will contain FCNet inside, and will overrideget_intial_state()
to return empty dict.LSTMEncoder(Encoder)
will contain LSTMCell, and will overrideget_initial_state()
to returnh, c
values. For Trunks I think we can also create a base-class with some simple interface apis to standardize trunks as well.
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
leaving comments discussed offline
rllib/models/experimental/base.py
Outdated
# If no checking is needed, we can simply return an empty spec. | ||
return SpecDict() | ||
|
||
@check_input_specs("input_spec", filter=True, cache=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
offline: forward -> _forward()
update to not expose users to spec checking.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's fair to simply wrap Model.forward in the constructor to circumvent this.
Torch users will attempt to overwrite forward in any case.
I've made it now so that we detect when forward is not wrapped and "autowrap" it in that case.
raise NotImplementedError | ||
|
||
|
||
class Model: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This class abstracts two things:
- spec checking on forward method
- unifies forward call between torch and tf. The expectation is that the RLModule / model builder will only work with the assumptions on this api definition.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've updated it a little. tf.Model and torch.nn now unify that RLModule can simply call them.
Model only defines the minimal input_spec, output_spec and get_initial_state interface. It's pretty shallow now but I think we can leave it here for the moment because models might need other things soon. Possibly a name + a sequence number for richer repr.
|
||
@ExperimentalAPI | ||
@dataclass | ||
class ModelConfig(abc.ABC): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just one comment, this most likely will end up being a very shallow module that actually adds to the complexity rather than reducing it. Maybe we end up removing it later, once we see more examples of extending this class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My concern is mostly about the build method. The dataclass itself is fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can make this class totally framework agnostic and only require the caller to pass in the same object to different framework specific model constructors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As discussed offline yesterday, we'll keep the build method because it abstracts the class to be built. Any model_config infering code will simply return a config that can be built. We'd otherwise have to return a class (which we don't want because it's not framework agnostic) to resolve this issue.
) | ||
|
||
|
||
class TfMLPModel(Model, tf.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should this be tf.Model
or tf.keras.Model
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for consistency we should stick to one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be tf.Module. tf.keras.Model is an extention of tf.Module and if we ever run into a situation where we need it's features, we can simply start inheriting from it. But today we don't need it.
https://www.tensorflow.org/api_docs/python/tf/keras/Model
vs
https://www.tensorflow.org/api_docs/python/tf/Module
"A module is a named container for tf.Variables, other tf.Modules and functions which apply to user input"
That's all we want. keras.Model has much richer features and all sorts of stuff that we don't necessarily want to guarantee:
) | ||
|
||
|
||
class TfMLPModel(Model, tf.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This may have been TfModel base class?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, "typo"
raise NotImplementedError | ||
|
||
|
||
class TfMLP(tf.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
subclass tfModel?
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
) -> PPOModuleConfig: | ||
"""Get a PPOModuleConfig that we would expect from the catalog otherwise. | ||
|
||
Args: | ||
env: Environment for which we build the model later | ||
lstm: If True, build recurrent pi encoder | ||
shared_encoder: If True, build a shared encoder for pi and vf, where pi | ||
encoder and vf encoder will be identity. If False, the shared encoder | ||
will be identity. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll reintroduce this in a upcoming PR with the ActorCriticEncoder
) | ||
|
||
|
||
def get_expected_model_config_tf( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've unified these into one, since model configs are planned to be framework agnostic.
@@ -343,6 +265,9 @@ def test_forward_train(self): | |||
for param in module.parameters(): | |||
self.assertIsNotNone(param.grad) | |||
else: | |||
batch = tree.map_structure( | |||
lambda x: tf.convert_to_tensor(x, dtype=tf.float32), batch | |||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
-> because tf does not accept numpy arrays.
super().__init__() | ||
self.config = config | ||
self.setup() | ||
|
||
def setup(self) -> None: | ||
assert self.config.pi_config, "pi_config must be provided." | ||
assert self.config.vf_config, "vf_config must be provided." | ||
self.shared_encoder = self.config.shared_encoder_config.build() | ||
self.encoder = self.config.encoder_config.build(framework="tf") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From here on, encoder will encapsulate the concept of shared/non-shared layers.
# Shared encoder | ||
encoder_out = self.encoder(batch) | ||
if STATE_OUT in encoder_out: | ||
output[STATE_OUT] = encoder_out[STATE_OUT] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'll generally expect a state here in the future and hand that over, even if it's empty.
I'll update this when Encoder are updated in a follow up PR.
self.config = config | ||
|
||
@abc.abstractmethod | ||
def get_initial_state(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think at some point we should change this to simply be a property "initial_state".
|
||
@check_input_specs("input_spec", cache=True) | ||
@check_output_specs("output_spec", cache=True) | ||
@abc.abstractmethod |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm spec checking here even though this method is abstract because Model() this also serves as an example of how forward should look like.
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
@kouroshHakha I've addressed all of your remarks and, afaics, the PR looks clean now. Could you please have a closer look? |
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. I only have one major concern regarding the consistency between tf and torch and also a clean up comment on rl_module repo. There is also a nit :) feel free to ignore it.
@@ -40,7 +40,7 @@ def build(self): | |||
|
|||
|
|||
@dataclass | |||
class FCConfig(EncoderConfig): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't we just move the entirety of encoder.pys in rl_module to the experimental folder? To clean up the rl_module folder?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting. I thought I'd moved them in the process of writing the new files that are under .../experimental but they where obviously not deleted.
The two encoder files in the .../rl_module folder are not even in use anymore. Many thanks for realizing that there's something off here 😃
@dataclass | ||
class PPOModuleConfig(RLModuleConfig): | ||
"""Configuration for the PPO module. | ||
class PPOModuleConfig(RLModuleConfig): # TODO (Artur): Move to Torch-unspecific file |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: non-torch-specific or torch-agnostic :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed 🙂
encoder_out = self.shared_encoder(obs) | ||
action_logits = self.pi(encoder_out) | ||
vf = self.vf(encoder_out) | ||
encoder_out = self.encoder(batch) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you make sure we take care of STATE_OUT here as well? similar to torch? TF and torch should be maximally consistent going fwd.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done! 🙂
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gjoliver Let's merge?
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com> Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
Why are these changes needed?
After sketching Solution2 (this PR) and Solution 1, we have decided to go with this PR to pursue this solution further.
With this PR, we introduce a hierarchy of models that fits should be generated by the ModelCatalog.
This PR also removes vf_encoder and pi_encoders to divide the PPORLModule into a shared encoder and vf/pi.
Checks
git commit -s
) in this PR.scripts/format.sh
to lint the changes in this PR.