
Weight dropout #48

Open · wants to merge 3 commits into base: main
Conversation

Judyxujj (Contributor):
This PR adds weight dropout.

albertz (Member) commented Apr 23, 2024:

I think following torch.nn.utils.weight_norm or torch.nn.utils.parametrizations.weight_norm would make more sense here, no? This does not need to be (and rather should not be) specific to one module (here it is specific to torch.nn.Linear); it should be generic, less hacky, and simpler.

Judyxujj closed this Apr 23, 2024
Judyxujj (Contributor, Author) commented Apr 23, 2024:

> I think following torch.nn.utils.weight_norm or torch.nn.utils.parametrizations.weight_norm would make more sense here, no? This does not need to be (and rather should not be) specific to one module (here it is specific to torch.nn.Linear); it should be generic, less hacky, and simpler.

@albertz But here we can also use the WeightDrop class to wrap other torch.nn.Module instances, like this:

import torch

# Wrap a GRUCell so that dropout is applied to its hidden-to-hidden weight matrix.
gru = torch.nn.GRUCell(2, 2)
weights = ['weight_hh']
weight_drop_gru = WeightDrop(gru, weights, dropout=0.9)

I added WeightDropLinear because I think it would be the most frequently used variant.
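For context, a generic wrapper of this kind is commonly implemented along the lines of the AWD-LSTM WeightDrop pattern. The following is a hypothetical sketch, not necessarily this PR's implementation: the raw weight is re-registered under a "_raw" suffix and a dropped-out copy is recomputed on every forward pass.

```python
import torch

class WeightDropSketch(torch.nn.Module):
    """Hypothetical sketch of a generic weight-dropout wrapper
    (not necessarily what this PR implements)."""

    def __init__(self, module, weight_names, dropout=0.0):
        super().__init__()
        self.module = module
        self.weight_names = weight_names
        self.dropout = dropout
        for name in weight_names:
            # Move the parameter aside under a "_raw" suffix,
            # so we can overwrite the original attribute per forward.
            raw = getattr(module, name)
            del module._parameters[name]
            module.register_parameter(name + "_raw", torch.nn.Parameter(raw.data))

    def forward(self, *args, **kwargs):
        for name in self.weight_names:
            raw = getattr(self.module, name + "_raw")
            # Recompute the dropped-out weight on every forward pass;
            # in eval mode dropout is the identity.
            setattr(self.module, name,
                    torch.nn.functional.dropout(raw, p=self.dropout,
                                                training=self.training))
        return self.module(*args, **kwargs)
```

With this sketch, the GRUCell example above would work the same way: `WeightDropSketch(torch.nn.GRUCell(2, 2), ['weight_hh'], dropout=0.9)`.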

Judyxujj reopened this Apr 23, 2024
albertz (Member) commented Apr 23, 2024:

Ah, sorry, I misread the code regarding nn.Linear. You are right, your WeightDrop is generic. But I still think it is too complicated, too hacky, and non-PyTorch-like compared to how torch.nn.utils.weight_norm or torch.nn.utils.parametrizations.weight_norm work. I would follow that standard PyTorch code and add weight dropout in the same way.
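To illustrate the suggestion: the parametrizations-style approach would register weight dropout through torch.nn.utils.parametrize, just as weight_norm does. A minimal sketch, assuming a hypothetical `weight_dropout` helper (the name and signature are this sketch's invention, not the PR's):

```python
import torch
import torch.nn.utils.parametrize as parametrize

class _WeightDropout(torch.nn.Module):
    """Parametrization module that applies dropout to a weight tensor."""

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        # Resampled on every access in training mode; identity in eval mode.
        return torch.nn.functional.dropout(weight, p=self.p, training=self.training)

def weight_dropout(module: torch.nn.Module, name: str = "weight", p: float = 0.5):
    # Hypothetical helper, analogous to torch.nn.utils.parametrizations.weight_norm.
    parametrize.register_parametrization(module, name, _WeightDropout(p))
    return module

linear = weight_dropout(torch.nn.Linear(4, 4), "weight", p=0.5)
linear.eval()  # in eval mode accessing linear.weight returns the original weight
```

This keeps the wrapped module's class unchanged; the original tensor remains available as `module.parametrizations.weight.original`, and `train()`/`eval()` propagate to the parametrization automatically.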
