
[RLlib] Chaining Models in RLModules #31469

Merged: 52 commits merged into ray-project:master on Feb 7, 2023

Conversation


@ArturNiederfahrenhorst (Contributor) commented on Jan 5, 2023

Why are these changes needed?

After sketching Solution 2 (this PR) and Solution 1, we have decided to pursue the solution in this PR further.

With this PR, we introduce a hierarchy of models that should be generated by the ModelCatalog.
This PR also removes vf_encoder and pi_encoder, dividing the PPORLModule into a shared encoder and separate pi/vf networks.

Checks

  • I've signed off every commit (using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@ArturNiederfahrenhorst added the do-not-merge ("Do not merge this PR!") label on Jan 5, 2023
@ArturNiederfahrenhorst changed the title from "[RLlib] Chaining sub-models in RLModules at configuration time ("Solution 2")" to "[RLlib] Chaining sub-models in RLModules" on Jan 19, 2023
@ArturNiederfahrenhorst changed the title from "[RLlib] Chaining sub-models in RLModules" to "[RLlib] Chaining Models in RLModules" on Jan 19, 2023
@kouroshHakha (Contributor) left a comment

I really like the direction this is headed. I took a high-level look for 15 minutes, so don't take this as my detailed review yet, but here is my feedback:

  • I really like that the setup of RLModules comes down to a bunch of config.build(framework=...) calls, while the forward calls are also kept minimal and simple. So if we end up with a general yet flexible abstraction for configs and encoders/trunks, the currently proposed API for the RLModule is close to optimal in my head.
  • I think from_model_config() has a lot of components that will recur across many RLModules, e.g. how to parse model_config into an encoder config and trunk configs. I think we need an extensible abstraction layer for these kinds of things. That is what I have in mind for the catalog:
def from_model_config(...):
    catalog = Catalog.from_model_config(...)
    encoder_config = catalog.build_encoder_config()
    # A utility method that returns action_space.n or 2/1 * action_space.shape[0] or ...
    action_dim = get_action_dim(action_space, free_log_std)
    pi_config = catalog.build_trunk_config(out_dim=action_dim)
    vf_config = catalog.build_trunk_config(out_dim=1)
    config_ = PPOModuleConfig(
        observation_space=observation_space,
        action_space=action_space,
        shared_encoder_config=encoder_config,
        pi_config=pi_config,
        vf_config=vf_config,
        free_log_std=free_log_std,
    )
    module = PPOTorchRLModule(config_)
    return module

A couple of things to notice here:

  • Catalog does not build neural networks; it just returns the pre-defined neural network configs, which can be constructed at run time via a simple .build(framework) API.
  • I think Catalog should still be a deep module with generic APIs here, such as build_encoder() and build_trunk(), that have simple interfaces. There is a trade-off between how simple the interfaces of these APIs are and how deep they will become. We need to find a sweet spot here.
  • I really love that the spec checking is delegated to the sub-components now. It will make them much cleaner and easier to build.
  • Sub-modules should not inherit from base_model.Model. That class was an early version of modules with spec-checking capability; I think we have a better version with the decorators in place now. For encoders specifically, we can create a base interface class with a get_initial_state() abstract API: FCEncoder(Encoder) will contain an FCNet inside and override get_initial_state() to return an empty dict, while LSTMEncoder(Encoder) will contain an LSTMCell and override get_initial_state() to return h and c values (see the sketch after this list). For trunks, I think we can also create a base class with some simple interface APIs to standardize them as well.
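A rough sketch of that encoder hierarchy, assuming PyTorch; apart from Encoder, FCEncoder, LSTMEncoder, and get_initial_state(), all names and dimensions below are illustrative placeholders rather than the actual RLlib implementation:

import abc

import torch
import torch.nn as nn


class Encoder(abc.ABC):
    """Base interface: every encoder declares its (possibly empty) initial state."""

    @abc.abstractmethod
    def get_initial_state(self) -> dict:
        ...


class FCEncoder(Encoder, nn.Module):
    """Stateless encoder wrapping a small fully connected net."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        nn.Module.__init__(self)
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, out_dim)
        )

    def get_initial_state(self) -> dict:
        # No recurrent state.
        return {}

    def forward(self, x):
        return self.net(x)


class LSTMEncoder(Encoder, nn.Module):
    """Recurrent encoder wrapping an LSTMCell; exposes h and c as its initial state."""

    def __init__(self, in_dim: int, hidden_dim: int):
        nn.Module.__init__(self)
        self.hidden_dim = hidden_dim
        self.cell = nn.LSTMCell(in_dim, hidden_dim)

    def get_initial_state(self) -> dict:
        # Unbatched zeros; callers are expected to add the batch dimension.
        return {"h": torch.zeros(self.hidden_dim), "c": torch.zeros(self.hidden_dim)}

    def forward(self, x, state):
        h, c = self.cell(x, (state["h"], state["c"]))
        return h, {"h": h, "c": c}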

@kouroshHakha (Contributor) left a comment

Leaving the comments we discussed offline.

# If no checking is needed, we can simply return an empty spec.
return SpecDict()

@check_input_specs("input_spec", filter=True, cache=True)
Contributor:

Discussed offline: rename forward -> _forward() so that users are not exposed to spec checking.

Contributor (Author):

I think it's fair to simply wrap Model.forward in the constructor to circumvent this.
Torch users will attempt to override forward in any case.
I've made it so that we now detect when forward is not wrapped and "autowrap" it in that case.
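A minimal sketch of what such auto-wrapping could look like; the check_specs helper, the __checked__ marker, and MyTorchModel below are hypothetical stand-ins for the actual RLlib spec-checking machinery:

import functools


def check_specs(fn):
    """Hypothetical spec-checking wrapper; marks the function as already wrapped."""

    @functools.wraps(fn)
    def wrapper(self, batch, **kwargs):
        # ... validate `batch` against self.input_spec here ...
        out = fn(self, batch, **kwargs)
        # ... validate `out` against self.output_spec here ...
        return out

    wrapper.__checked__ = True
    return wrapper


class Model:
    def __init__(self):
        # If a subclass overrode forward without the decorator, bind a wrapped
        # version on this instance so users never have to think about spec checking.
        if not getattr(self.forward, "__checked__", False):
            self.forward = check_specs(type(self).forward).__get__(self)

    def forward(self, batch):
        raise NotImplementedError


class MyTorchModel(Model):
    # Torch users typically override forward directly, without any decorator;
    # the base constructor wraps it automatically.
    def forward(self, batch):
        return {"out": batch["obs"]}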

raise NotImplementedError


class Model:
Contributor:

This class abstracts two things:

  1. spec checking on the forward method
  2. unifying the forward call between Torch and TF. The expectation is that the RLModule / model builder will only rely on the assumptions of this API definition.

Contributor (Author):

I've updated it a little. tf.Model and torch.nn now provide the unification, so RLModule can simply call them.
Model only defines the minimal input_spec, output_spec and get_initial_state interface. It's pretty shallow now, but I think we can leave it here for the moment because models might need other things soon, possibly a name plus a sequence number for a richer repr.
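For reference, a sketch of what that minimal interface could look like (the exact signatures in the PR may differ; this only illustrates the shape):

import abc


class Model(abc.ABC):
    """Minimal interface: specs plus initial state, nothing framework-specific."""

    def __init__(self, config):
        self.config = config

    @property
    @abc.abstractmethod
    def input_spec(self):
        """Spec that incoming batches are checked against."""

    @property
    @abc.abstractmethod
    def output_spec(self):
        """Spec that produced outputs are checked against."""

    @abc.abstractmethod
    def get_initial_state(self):
        """Return the (possibly empty) initial recurrent state."""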


@ExperimentalAPI
@dataclass
class ModelConfig(abc.ABC):
Contributor:

Just one comment: this will most likely end up being a very shallow module that actually adds to the complexity rather than reducing it. Maybe we end up removing it later, once we see more examples of extending this class.

Contributor:

My concern is mostly about the build method. The dataclass itself is fine.

Contributor:

We can make this class totally framework-agnostic and only require the caller to pass the same object into different framework-specific model constructors.

Contributor (Author):

As discussed offline yesterday, we'll keep the build method because it abstracts away the class to be built. Any model_config-inferring code will simply return a config that can be built. Otherwise we'd have to return a class (which we don't want, because that is not framework-agnostic) to resolve this issue.
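A sketch of that design choice, with hypothetical TorchFC and TfFC placeholder classes: the config stays framework-agnostic, and only the build(framework=...) call picks the concrete class.

from dataclasses import dataclass


@dataclass
class FCConfig:
    """Framework-agnostic description of a fully connected net."""

    input_dim: int = 4
    hidden_dim: int = 256
    output_dim: int = 2

    def build(self, framework: str = "torch"):
        # The config, not the caller, decides which framework-specific class
        # to construct, so config-inferring code only ever returns configs.
        if framework == "torch":
            return TorchFC(self)
        if framework == "tf":
            return TfFC(self)
        raise ValueError(f"Unsupported framework: {framework}")


class TorchFC:
    def __init__(self, config: FCConfig):
        self.config = config  # Would build a torch.nn.Sequential here.


class TfFC:
    def __init__(self, config: FCConfig):
        self.config = config  # Would build a tf.Module here.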

)


class TfMLPModel(Model, tf.Module):
Contributor:

should this be tf.Model or tf.keras.Model?

Contributor:

for consistency we should stick to one.

Contributor (Author):

It should be tf.Module. tf.keras.Model is an extension of tf.Module, and if we ever run into a situation where we need its features, we can simply start inheriting from it. But today we don't need it.

https://www.tensorflow.org/api_docs/python/tf/keras/Model
vs
https://www.tensorflow.org/api_docs/python/tf/Module

"A module is a named container for tf.Variables, other tf.Modules and functions which apply to user input"
That's all we want. keras.Model has much richer features and all sorts of machinery that we don't necessarily want to guarantee (see the tf.keras.Model API docs linked above).
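As a small illustration of why tf.Module is sufficient here (a generic sketch, not the PR's actual TfMLP class), tf.Module already tracks variables and sub-modules without any keras.Model machinery:

import tensorflow as tf


class MinimalMLP(tf.Module):
    """A plain tf.Module: variables and sub-modules are tracked automatically."""

    def __init__(self, in_dim, hidden_dim, out_dim, name=None):
        super().__init__(name=name)
        self.w1 = tf.Variable(tf.random.normal([in_dim, hidden_dim]), name="w1")
        self.b1 = tf.Variable(tf.zeros([hidden_dim]), name="b1")
        self.w2 = tf.Variable(tf.random.normal([hidden_dim, out_dim]), name="w2")
        self.b2 = tf.Variable(tf.zeros([out_dim]), name="b2")

    def __call__(self, x):
        h = tf.nn.relu(tf.matmul(x, self.w1) + self.b1)
        return tf.matmul(h, self.w2) + self.b2


# mlp = MinimalMLP(4, 32, 2)
# len(mlp.trainable_variables)  # -> 4, no keras.Model needed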

)


class TfMLPModel(Model, tf.Module):
Contributor:

This may have been meant to be the TfModel base class?

Contributor (Author):

Yep, "typo"

raise NotImplementedError


class TfMLP(tf.Module):
Contributor:

Subclass TfModel?

) -> PPOModuleConfig:
    """Get a PPOModuleConfig that we would expect from the catalog otherwise.

    Args:
        env: Environment for which we build the model later.
        lstm: If True, build a recurrent pi encoder.
        shared_encoder: If True, build a shared encoder for pi and vf, where the
            pi encoder and vf encoder will be identity. If False, the shared
            encoder will be identity.
Contributor (Author):

I'll reintroduce this in an upcoming PR with the ActorCriticEncoder.

)


def get_expected_model_config_tf(
Contributor (Author):

I've unified these into one, since model configs are planned to be framework-agnostic.

@@ -343,6 +265,9 @@ def test_forward_train(self):
    for param in module.parameters():
        self.assertIsNotNone(param.grad)
else:
    batch = tree.map_structure(
        lambda x: tf.convert_to_tensor(x, dtype=tf.float32), batch
    )
Contributor (Author):

This is because TF does not accept NumPy arrays here, so the batch is converted to tensors first.

        super().__init__()
        self.config = config
        self.setup()

    def setup(self) -> None:
        assert self.config.pi_config, "pi_config must be provided."
        assert self.config.vf_config, "vf_config must be provided."
        self.shared_encoder = self.config.shared_encoder_config.build()
        self.encoder = self.config.encoder_config.build(framework="tf")
Contributor (Author):

From here on, encoder will encapsulate the concept of shared/non-shared layers.

        # Shared encoder
        encoder_out = self.encoder(batch)
        if STATE_OUT in encoder_out:
            output[STATE_OUT] = encoder_out[STATE_OUT]
Contributor (Author):

We'll generally expect a state here in the future and hand it over, even if it's empty.
I'll update this when the Encoders are updated in a follow-up PR.
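A sketch of the intended pattern; STATE_OUT is the constant used in the PR, while its assumed value and the helper function below are only illustrative:

from typing import Any, Dict

STATE_OUT = "state_out"  # Assumed value, mirroring the key used in the PR.


def forward_pass(encoder, batch: Dict[str, Any]) -> Dict[str, Any]:
    """Hand the encoder's state over unconditionally; stateless encoders emit {}."""
    output: Dict[str, Any] = {}
    encoder_out = encoder(batch)
    # Once encoders always return a (possibly empty) state, the
    # `if STATE_OUT in encoder_out` guard can go away.
    output[STATE_OUT] = encoder_out.get(STATE_OUT, {})
    return output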

        self.config = config

    @abc.abstractmethod
    def get_initial_state(self):
Contributor (Author):

I think at some point we should change this to simply be a property "initial_state".
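A tiny sketch of that suggestion (illustrative only):

class SomeStatelessModel:
    @property
    def initial_state(self) -> dict:
        # A property instead of get_initial_state(); stateless models return {}.
        return {}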


@check_input_specs("input_spec", cache=True)
@check_output_specs("output_spec", cache=True)
@abc.abstractmethod
Contributor (Author):

I'm spec-checking here even though this method is abstract, because Model also serves as an example of what forward should look like.

@ArturNiederfahrenhorst (Contributor, Author):

@kouroshHakha I've addressed all of your remarks and, as far as I can see, the PR looks clean now. Could you please have a closer look?

@kouroshHakha (Contributor) left a comment

LGTM. I only have one major concern regarding the consistency between TF and Torch, and a clean-up comment on the rl_module directory. There is also a nit :) feel free to ignore it.

@@ -40,7 +40,7 @@ def build(self):


@dataclass
class FCConfig(EncoderConfig):
Contributor:

Shouldn't we just move the encoder.py files in rl_module entirely to the experimental folder, to clean up the rl_module folder?

Contributor (Author):

Interesting. I thought I had moved them while writing the new files under .../experimental, but they were obviously not deleted.
The two encoder files in the .../rl_module folder are not even in use anymore. Many thanks for noticing that something is off here 😃

@dataclass
class PPOModuleConfig(RLModuleConfig):
"""Configuration for the PPO module.
class PPOModuleConfig(RLModuleConfig): # TODO (Artur): Move to Torch-unspecific file
Contributor:

nit: non-torch-specific or torch-agnostic :)

Contributor (Author):

Changed 🙂

encoder_out = self.shared_encoder(obs)
action_logits = self.pi(encoder_out)
vf = self.vf(encoder_out)
encoder_out = self.encoder(batch)
Contributor:

Can you make sure we take care of STATE_OUT here as well, similar to Torch? TF and Torch should be maximally consistent going forward.

Contributor (Author):

Done! 🙂

Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
@kouroshHakha (Contributor) left a comment

@gjoliver Let's merge?

@gjoliver gjoliver merged commit 027965b into ray-project:master Feb 7, 2023
edoakes pushed a commit to edoakes/ray that referenced this pull request Mar 22, 2023