vgg model checkpoint needs a change of classifier weight names #3
Just ran into this myself. The change is pretty simple; this should do it:

```python
from collections import OrderedDict
from torch.utils.model_zoo import load_url
import torch

sd = load_url("https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg19-d01eb7cb.pth")
# Old classifier keys -> new keys (the Linear layers shifted index after the dropout fix).
rename_map = {'classifier.1.weight': 'classifier.0.weight', 'classifier.1.bias': 'classifier.0.bias',
              'classifier.4.weight': 'classifier.3.weight', 'classifier.4.bias': 'classifier.3.bias'}
# Keep the parameter order, since the state_dict is an OrderedDict (.items() replaces Python 2's .iteritems()).
sd = OrderedDict((rename_map.get(k, k), v) for k, v in sd.items())
torch.save(sd, "vgg19-d01eb7cb.pth")
```

It would be great if you could upload the newer versions to S3.
See https://discuss.pytorch.org/t/upgrading-torchvision-module-makes-old-model-useless/1719
This changed because of pytorch/vision#107, where Sam realized he had put dropout in the wrong location. Moving the dropout layers shifted the positions of the Linear layers inside the classifier, so the corresponding state_dict keys need to be renamed.
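The index shift can also be derived from the layer order instead of being written by hand. Below is a minimal pure-Python sketch. The two layouts are an assumption, reconstructed to be consistent with the key mapping in the comment above (only the Linear layers own `weight`/`bias` tensors), and `param_indices`/`build_rename_map` are hypothetical helpers, not part of torchvision.

```python
# Assumed layer order of the VGG classifier before and after pytorch/vision#107.
OLD_LAYOUT = ["Dropout", "Linear", "ReLU", "Dropout", "Linear", "ReLU", "Linear"]
NEW_LAYOUT = ["Linear", "ReLU", "Dropout", "Linear", "ReLU", "Dropout", "Linear"]


def param_indices(layout):
    """Return the positions in the Sequential that own weight/bias tensors."""
    return [i for i, name in enumerate(layout) if name == "Linear"]


def build_rename_map(old_layout, new_layout, prefix="classifier"):
    """Map old state_dict keys to new ones by pairing parametric layers in order."""
    old_idx = param_indices(old_layout)
    new_idx = param_indices(new_layout)
    if len(old_idx) != len(new_idx):
        raise ValueError("layouts must contain the same number of parametric layers")
    rename = {}
    for o, n in zip(old_idx, new_idx):
        if o == n:
            continue  # layer did not move, so its keys stay the same
        for suffix in ("weight", "bias"):
            rename[f"{prefix}.{o}.{suffix}"] = f"{prefix}.{n}.{suffix}"
    return rename


rename_map = build_rename_map(OLD_LAYOUT, NEW_LAYOUT)
for old, new in sorted(rename_map.items()):
    print(old, "->", new)
```

The last Linear layer sits at index 6 in both layouts, which is why the hand-written map only touches indices 1 and 4.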