[question] How to get the model architecture when using recurrent policy? #1145
Comments
I'd recommend sticking with stable-baselines3, as it is more up to date and TF1 is really outdated at this point. The (boring) answer is that the code structure is different and that TF models work differently from PyTorch ones (IIRC, printing the model in TF1 was more difficult than a simple print statement).
For LSTM with SB3, take a look at this comment: DLR-RM/stable-baselines3#18 (comment)
Does ppo-lstm work properly now?
It has been tested, yes; only dict observation support is missing. But I would still recommend trying frame-stacking first (and using it together with the LSTM).
OK, that's cool. But why hasn't the feat/ppo-lstm branch been merged into master yet?
How can I resolve this error?
Traceback (most recent call last):
  File "test_lstm.py", line 132, in <module>
    test_cnn()
  File "test_lstm.py", line 50, in test_cnn
    model.learn(total_timesteps=32)
  File "/home/yan/miniconda3/envs/sb3_env/lib/python3.7/site-packages/sb3_contrib-1.3.1a3-py3.7.egg/sb3_contrib/ppo_lstm/ppo_lstm.py", line 496, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/home/yan/miniconda3/envs/sb3_env/lib/python3.7/site-packages/sb3_contrib-1.3.1a3-py3.7.egg/sb3_contrib/ppo_lstm/ppo_lstm.py", line 276, in collect_rollouts
    actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)
  File "/home/yan/miniconda3/envs/sb3_env/lib/python3.7/site-packages/sb3_contrib-1.3.1a3-py3.7.egg/sb3_contrib/common/recurrent/policies.py", line 189, in forward
    latent_pi = self.mlp_extractor.forward_actor(latent_pi)
  File "/home/yan/miniconda3/envs/sb3_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1131, in __getattr__
    type(self).__name__, name))
AttributeError: 'MlpExtractor' object has no attribute 'forward_actor'
This issue is not the place for that. It has not been merged into master because we still need to polish it (docs, tests, comments, ...).
When I run
The output is
Why can't I get the model architecture? If I use stable-baselines3, it works when written like this.