Adding dropout to a custom network #1666

Closed
4 tasks done
Karlheinzniebuhr opened this issue Sep 5, 2023 · 3 comments
Labels
duplicate (This issue or pull request already exists), question (Further information is requested)


@Karlheinzniebuhr

❓ Question

I'm trying to add dropout to my network, but everything I've tried results in an error.

```python
import gym
import torch
import torch.nn as nn
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3 import PPO

# Define a minimal Gym environment (CartPole-v1)
envs = gym.make("CartPole-v1")

class CustomMlpPolicy(ActorCriticPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomMlpPolicy, self).__init__(*args, **kwargs)

        # Modify the policy network to add dropout layers
        self.policy_net = nn.Sequential(
            nn.Linear(self.mlp_extractor.output_dim, 128),
            nn.Tanh(),
            nn.Dropout(p=0.5),  # Add dropout to the first hidden layer
            nn.Linear(128, 128),
            nn.Tanh(),
        )

        # Value network remains the same
        self.value_net = nn.Sequential(
            nn.Linear(self.mlp_extractor.output_dim, 256),
            nn.Tanh(),
            nn.Linear(256, 256),
            nn.Tanh(),
            nn.Linear(256, 256),
            nn.Tanh(),
            nn.Linear(256, 256),
            nn.Tanh(),
        )

# Create the PPO model using your custom policy
model = PPO(CustomMlpPolicy, envs)
model.learn(total_timesteps=10000)  # Train for a specified number of timesteps
```

```
AttributeError: 'MlpExtractor' object has no attribute 'output_dim'
```

I checked the documentation, but it doesn't mention dropout anywhere.

How can I add dropout layers to a PPO policy network?


Karlheinzniebuhr added the question label Sep 5, 2023
araffin added the duplicate label Sep 6, 2023
@araffin
Member

araffin commented Sep 6, 2023

Duplicate of #1069 (comment); a working example can be found in #1036.

@Karlheinzniebuhr
Author

I'm afraid #1036 only works for SAC() since PPO didn't get an update with that PR.

@araffin
Member

araffin commented Sep 7, 2023

> I'm afraid #1036 only works for SAC() since PPO didn't get an update with that PR.

Well, nothing prevents you from forking and making the update yourself ;)
I mostly wanted to link a working example with dropout.
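However dropout is wired in, it interacts with PPO's probability ratio. In SB3's on-policy algorithms, rollout collection calls `policy.set_training_mode(False)` (dropout disabled) and `PPO.train()` calls `set_training_mode(True)` (dropout enabled), so the log-probabilities stored during rollouts and those recomputed during the update come from differently behaving networks. The pure PyTorch sketch below uses a hypothetical toy actor (not SB3 code) to show that the ratio already deviates from 1 with identical weights and no gradient step, which can eat into the clipping range.

```python
import torch
from torch import nn
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical toy actor: 4-dim observation -> logits for 2 discrete actions.
actor = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Dropout(p=0.5), nn.Linear(64, 2))
obs = torch.randn(32, 4)
actions = torch.randint(0, 2, (32,))


def log_prob(model: nn.Module) -> torch.Tensor:
    return Categorical(logits=model(obs)).log_prob(actions)


with torch.no_grad():
    actor.eval()   # like rollout collection: dropout disabled, output deterministic
    old_log_prob = log_prob(actor)
    repeat_log_prob = log_prob(actor)
    actor.train()  # like the PPO update: dropout enabled
    new_log_prob = log_prob(actor)

# Same weights and no optimizer step, yet the ratio is no longer exactly 1.
ratio = torch.exp(new_log_prob - old_log_prob)
print(ratio.min().item(), ratio.max().item())
```

SB3 logs `train/clip_fraction` during PPO updates, which is one place to look when comparing runs with and without dropout.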

araffin closed this as not planned Sep 24, 2023