I'm trying to add dropout to my network, but everything I've tried gives me some kind of error.
```python
import gym
import torch
import torch.nn as nn
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3 import PPO

# Define a minimal Gym environment (CartPole-v1)
envs = gym.make("CartPole-v1")


class CustomMlpPolicy(ActorCriticPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomMlpPolicy, self).__init__(*args, **kwargs)
        # Modify the policy network to add dropout layers
        self.policy_net = nn.Sequential(
            nn.Linear(self.mlp_extractor.output_dim, 128),
            nn.Tanh(),
            nn.Dropout(p=0.5),  # Add dropout to the first hidden layer
            nn.Linear(128, 128),
            nn.Tanh(),
        )
        # Value network remains the same
        self.value_net = nn.Sequential(
            nn.Linear(self.mlp_extractor.output_dim, 256),
            nn.Tanh(),
            nn.Linear(256, 256),
            nn.Tanh(),
            nn.Linear(256, 256),
            nn.Tanh(),
            nn.Linear(256, 256),
            nn.Tanh(),
        )


# Create the PPO model using your custom policy
model = PPO(CustomMlpPolicy, envs)
model.learn(total_timesteps=10000)  # Train for a specified number of timesteps
```
Running this fails with:

```
AttributeError: 'MlpExtractor' object has no attribute 'output_dim'
```
I checked the documentation, but it doesn't mention dropout anywhere.
How can I implement dropout?
Checklist
I have checked that there is no similar issue in the repo