vgg model checkpoint needs a change of classifier weight names #3

Open
soumith opened this issue Apr 8, 2017 · 2 comments

soumith commented Apr 8, 2017

See https://discuss.pytorch.org/t/upgrading-torchvision-module-makes-old-model-useless/1719

The names changed because of pytorch/vision#107, where Sam realized that he had put dropout in the wrong location.

So the classifier parameter names in the checkpoint's state_dict need to be changed to match.
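As a rough illustration of why the names shift, the sketch below derives the key renames from two classifier layouts. Both layouts are assumptions about how the `nn.Sequential` classifier looked before and after pytorch/vision#107 (they are not quoted from this issue), and the helper names `linear_indices` and `classifier_key_renames` are hypothetical.

```python
# ASSUMED layouts of the VGG classifier before and after pytorch/vision#107.
# Only the Linear layers own parameters, so only their indices matter.
OLD_LAYOUT = ["Dropout", "Linear", "ReLU", "Dropout", "Linear", "ReLU", "Linear"]
NEW_LAYOUT = ["Linear", "ReLU", "Dropout", "Linear", "ReLU", "Dropout", "Linear"]


def linear_indices(layout):
    """Return the positions of the Linear layers in a layout."""
    return [i for i, name in enumerate(layout) if name == "Linear"]


def classifier_key_renames(old_layout, new_layout):
    """Map old state_dict keys to new ones for Linear layers that moved."""
    renames = {}
    pairs = zip(linear_indices(old_layout), linear_indices(new_layout))
    for old_i, new_i in pairs:
        if old_i == new_i:
            continue  # this layer kept its index, nothing to rename
        for suffix in ("weight", "bias"):
            old_key = "classifier.%d.%s" % (old_i, suffix)
            new_key = "classifier.%d.%s" % (new_i, suffix)
            renames[old_key] = new_key
    return renames


renames = classifier_key_renames(OLD_LAYOUT, NEW_LAYOUT)
for old_key, new_key in renames.items():
    print(old_key, "->", new_key)
```

Under these assumed layouts, the last Linear layer stays at index 6, so only the `classifier.1.*` and `classifier.4.*` entries need new names.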

alykhantejani commented May 16, 2017

Just ran into this myself. The change is pretty simple; this should do it:

import torch
from torch.utils.model_zoo import load_url

sd = load_url("https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg19-d01eb7cb.pth")

# The classifier weights stored under indices 1 and 4 now belong
# under indices 0 and 3, so move each entry to its new key.
sd['classifier.0.weight'] = sd.pop('classifier.1.weight')
sd['classifier.0.bias'] = sd.pop('classifier.1.bias')
sd['classifier.3.weight'] = sd.pop('classifier.4.weight')
sd['classifier.3.bias'] = sd.pop('classifier.4.bias')

# Save the renamed state_dict under the original file name.
torch.save(sd, "vgg19-d01eb7cb.pth")

It would be great if you could upload the updated checkpoints to S3.

ZhengRui commented Aug 24, 2017

To keep the parameter order (the state_dict is an OrderedDict), rebuild it instead of deleting and re-adding keys:

from collections import OrderedDict
from torch.utils.model_zoo import load_url
import torch

sd = load_url("https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg19-d01eb7cb.pth")

# Old key -> new key; named `renames` to avoid shadowing the built-in `map`.
renames = {
    'classifier.1.weight': 'classifier.0.weight',
    'classifier.1.bias': 'classifier.0.bias',
    'classifier.4.weight': 'classifier.3.weight',
    'classifier.4.bias': 'classifier.3.bias',
}

# Rebuild in the original order; .items() works on Python 2 and 3.
sd = OrderedDict((renames.get(k, k), v) for k, v in sd.items())
torch.save(sd, "vgg19-d01eb7cb.pth")
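
For anyone who wants to check the renaming logic without downloading the checkpoint, here is a minimal sketch of the same order-preserving approach. The `rename_keys` helper is hypothetical, and the string values are stand-ins for tensors. It also raises if an expected old key is missing, which guards against converting the same file twice.

```python
from collections import OrderedDict

# Same old -> new key mapping as in the snippets above.
RENAMES = {
    "classifier.1.weight": "classifier.0.weight",
    "classifier.1.bias": "classifier.0.bias",
    "classifier.4.weight": "classifier.3.weight",
    "classifier.4.bias": "classifier.3.bias",
}


def rename_keys(state_dict, renames):
    """Return a new OrderedDict with keys renamed and order preserved."""
    missing = [k for k in renames if k not in state_dict]
    if missing:
        # Most likely the checkpoint has already been converted.
        raise KeyError("keys not found in state_dict: %s" % missing)
    return OrderedDict((renames.get(k, k), v) for k, v in state_dict.items())


# Stand-in for a real checkpoint: strings take the place of tensors.
old_sd = OrderedDict([
    ("features.0.weight", "w_f0"),
    ("classifier.1.weight", "w_c1"),
    ("classifier.1.bias", "b_c1"),
    ("classifier.4.weight", "w_c4"),
    ("classifier.4.bias", "b_c4"),
    ("classifier.6.weight", "w_c6"),
])

new_sd = rename_keys(old_sd, RENAMES)
print(list(new_sd))
```

With a real checkpoint, the `state_dict` argument would be the object returned by `load_url`, and the result can be passed to `torch.save` as in the snippets above.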
