[question] How to get the model architecture when using recurrent policy? #1145

Closed
borninfreedom opened this issue Dec 7, 2021 · 8 comments
Labels
question Further information is requested

Comments

@borninfreedom

When I run

from stable_baselines import PPO2
from stable_baselines.common.policies import LstmPolicy

class CustomLSTMPolicy(LstmPolicy):
    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=64, reuse=False, **_kwargs):
        super().__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
                         net_arch=[8, 'lstm', dict(vf=[5, 10], pi=[10])],
                         layer_norm=True, feature_extraction="mlp", **_kwargs)

model = PPO2(CustomLSTMPolicy, 'CartPole-v1', nminibatches=1, verbose=1)
print(model.policy)
print(model.policy_kwargs)

The output is

__main__.CustomLSTMPolicy
{}

Why can't I get the model architecture? With Stable-Baselines3, printing model.policy like this does show the architecture.

@Miffyli Miffyli added the question Further information is requested label Dec 7, 2021
@Miffyli
Collaborator

Miffyli commented Dec 7, 2021

I'd recommend sticking with stable-baselines3, as it is more up to date and TF1 is really outdated by this point. The (boring) answer is that the code structure is different and that TF models work differently from PyTorch ones (IIRC, "printing out the model" in TF1 took more than a simple print statement).
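
For readers who still need to inspect an SB2 model, one possible workaround is to list the model's variables rather than print the policy. This is a minimal sketch, assuming that `model.get_parameters()` returns a mapping of TF variable names to NumPy arrays, as the SB2 documentation describes. `summarize_parameters` is a hypothetical helper, not part of stable-baselines, and the stand-in dict below (with made-up variable names) replaces a real model so the snippet runs without TensorFlow.

```python
from functools import reduce
from operator import mul
from typing import Any, List, Mapping, Tuple


def summarize_parameters(params: Mapping[str, Any]) -> List[Tuple[str, Tuple[int, ...], int]]:
    """Return (name, shape, size) rows for a name -> array-like mapping."""
    rows = []
    for name, value in params.items():
        # Anything with a `shape` attribute works (e.g. a NumPy array).
        shape = tuple(int(d) for d in getattr(value, "shape", ()))
        rows.append((name, shape, reduce(mul, shape, 1)))
    return rows


class FakeArray:
    """Stand-in for a NumPy array; only the `shape` attribute is used."""

    def __init__(self, *shape: int) -> None:
        self.shape = shape


# Hypothetical variable names and shapes, for illustration only.
fake_params = {
    "model/shared_fc0/w:0": FakeArray(4, 8),
    "model/shared_fc0/b:0": FakeArray(8),
}

rows = summarize_parameters(fake_params)
for name, shape, size in rows:
    print(f"{name:<28} {str(shape):<10} {size}")
print("total parameters:", sum(size for _, _, size in rows))
```

With a real SB2 model, the call would be `summarize_parameters(model.get_parameters())`. The variable names and shapes give a rough picture of the layers, although not the formatted module tree that PyTorch prints.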

@araffin
Collaborator

araffin commented Dec 7, 2021

And for LSTM with SB3, you can take a look at this comment: DLR-RM/stable-baselines3#18 (comment)

@borninfreedom
Author

Does the ppo-lstm work properly now?

@araffin
Collaborator

araffin commented Dec 7, 2021

It has been tested, yes; only dict observation support is missing. But I would still recommend trying frame-stacking first (and using it together with the LSTM).
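
Frame-stacking, as suggested above, means feeding the policy the last few observations concatenated together, which gives a feed-forward network a short memory without recurrence. SB3 ships a `VecFrameStack` wrapper for this. The snippet below is only a standard-library sketch of the idea: `FrameStacker` is a hypothetical class, not SB3 code, and zero-filling at reset is an assumption of this sketch.

```python
from collections import deque
from typing import Deque, List, Sequence


class FrameStacker:
    """Keep the last `n_stack` observations and expose them as one flat list."""

    def __init__(self, n_stack: int, obs_size: int) -> None:
        self.n_stack = n_stack
        self.obs_size = obs_size
        self.frames: Deque[List[float]] = deque(maxlen=n_stack)
        self.reset()

    def reset(self) -> None:
        # Start each episode with zero-filled frames so the output size is constant.
        self.frames.clear()
        for _ in range(self.n_stack):
            self.frames.append([0.0] * self.obs_size)

    def push(self, obs: Sequence[float]) -> List[float]:
        # Appending to a full deque drops the oldest frame automatically.
        self.frames.append(list(obs))
        return [x for frame in self.frames for x in frame]


stacker = FrameStacker(n_stack=3, obs_size=2)
print(stacker.push([1.0, 2.0]))
print(stacker.push([3.0, 4.0]))
```

The stacked vector is what the policy would receive in place of the single latest observation, so the network input grows by a factor of `n_stack`.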

@borninfreedom
Author

OK, that's cool. But why hasn't the feat/ppo-lstm branch been merged into master yet?

@borninfreedom
Author

borninfreedom commented Dec 8, 2021

How can I resolve this error?

Traceback (most recent call last):
  File "test_lstm.py", line 132, in <module>
    test_cnn()
  File "test_lstm.py", line 50, in test_cnn
    model.learn(total_timesteps=32)
  File "/home/yan/miniconda3/envs/sb3_env/lib/python3.7/site-packages/sb3_contrib-1.3.1a3-py3.7.egg/sb3_contrib/ppo_lstm/ppo_lstm.py", line 496, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/home/yan/miniconda3/envs/sb3_env/lib/python3.7/site-packages/sb3_contrib-1.3.1a3-py3.7.egg/sb3_contrib/ppo_lstm/ppo_lstm.py", line 276, in collect_rollouts
    actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)
  File "/home/yan/miniconda3/envs/sb3_env/lib/python3.7/site-packages/sb3_contrib-1.3.1a3-py3.7.egg/sb3_contrib/common/recurrent/policies.py", line 189, in forward
    latent_pi = self.mlp_extractor.forward_actor(latent_pi)
  File "/home/yan/miniconda3/envs/sb3_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1131, in __getattr__
    type(self).__name__, name))
AttributeError: 'MlpExtractor' object has no attribute 'forward_actor'

@araffin
Collaborator

araffin commented Dec 8, 2021

This issue is not the place for that.
You need the master version of SB3.

It has not been merged into master yet because it still needs polishing (docs, tests, comments, ...).
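
An `AttributeError` like the one in the traceback usually points to mismatched package versions: here `sb3_contrib` calls `forward_actor`, which the installed SB3 `MlpExtractor` does not have. A quick way to see what is installed is sketched below. `installed_version` is a hypothetical helper, the distribution names `stable-baselines3` and `sb3-contrib` are assumed to be the PyPI names, and `importlib.metadata` needs Python 3.8 or newer.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Assumed distribution names; adjust if the packages were installed differently.
for package in ("stable-baselines3", "sb3-contrib"):
    print(package, "->", installed_version(package) or "not installed")
```

If the two versions do not correspond, installing SB3 from its master branch, as advised above, is the fix suggested in this thread.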

@araffin araffin closed this as completed Dec 8, 2021
@charlo1998

charlo1998 commented May 11, 2023

I'd recommend sticking with stable-baselines3 as it is more updated and TF1 is really outdated by this point. The (boring) answer is that the code structure is different and that TF models work in different ways than PyTorch ones (IIRC, to "print out the model" in TF1 was more difficult than just a simple print statement)
@Miffyli
Can you, by any chance, point to resources on how to "print out the model" with stable-baselines2? It's been pretty hard to find anything in the docs.
Thanks

4 participants