
Rethink usage pattern for pretrained models #200

Open
nateraw opened this issue Sep 10, 2020 · 5 comments
Labels: enhancement, help wanted

Comments

@nateraw (Contributor) commented Sep 10, 2020

🚀 Feature

Switch to using SomeModel.from_pretrained('pretrained-model-name') for pretrained models.

Motivation

Seems we are following torchvision's pattern of having a 'pretrained' argument in the init of our models to initialize a pretrained model. In my opinion, this is extremely confusing, as it makes the other init args + kwargs ambiguous or useless.
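
For contrast, the current torchvision-style call looks roughly like this (illustrative only; the exact bolts signature may differ):

# Roughly the pattern being criticized (illustrative, not the exact bolts API):
# it is unclear whether input_height configures a fresh model or must match
# the hosted checkpoint.
vae = VAE(input_height=32, pretrained=True)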

Pitch

Add a .from_pretrained classmethod to models and initialize an instance of the class from it. Pretrained models should incorporate any hparams needed to fill out the init, I guess.

from pl_bolts.models import VAE

model = VAE.from_pretrained('imagenet2012')
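
A minimal sketch of how the classmethod could work, assuming a hypothetical PRETRAINED_URLS registry with a placeholder URL; the real hosting and implementation may differ:

import os
import torch
import pytorch_lightning as pl

# Hypothetical registry; the real checkpoints would be hosted on the PL side.
PRETRAINED_URLS = {
    'imagenet2012': 'https://example.com/checkpoints/vae-imagenet2012.ckpt',
}

class VAE(pl.LightningModule):
    ...

    @classmethod
    def from_pretrained(cls, identifier):
        if identifier not in PRETRAINED_URLS:
            raise KeyError(f'unknown pretrained identifier: {identifier}')
        url = PRETRAINED_URLS[identifier]
        ckpt = os.path.join(torch.hub.get_dir(), os.path.basename(url))
        if not os.path.exists(ckpt):
            os.makedirs(torch.hub.get_dir(), exist_ok=True)
            torch.hub.download_url_to_file(url, ckpt)
        # load_from_checkpoint restores the hparams stored in the checkpoint,
        # so the pretrained model fills out its own init args.
        return cls.load_from_checkpoint(ckpt)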

Alternatives

Additional context

nateraw added the enhancement and help wanted labels Sep 10, 2020
nateraw self-assigned this Sep 11, 2020
@williamFalcon (Contributor)

Yeah, agree... although this is basically just the same as load_from_checkpoint, no? Sounds like we're looking for checkpoint nicknames instead?

Doesn't it read better as:

VAE.pretrained_on('xyz')

@nateraw (Contributor, Author) commented Sep 11, 2020

Right, I think the distinction here is that load_from_checkpoint is for checkpoints you have saved locally, but this function would be for pretrained models that we are hosting (i.e. these guys).

So, yes! We are looking for something that can point to a nickname/identifier for a pretrained model.


I think 'pretrained_on' is a limiting name: a model could be pretrained on the same dataset twice with different settings, and then it would be ambiguous which one to load under that name. That's why I suggest something a little more open, such as from_pretrained(identifier).

This is just my opinion... I could be convinced otherwise haha 😄. Let's have others weigh in to reach a consensus.

CC: @PyTorchLightning/core-contributors

@williamFalcon (Contributor)

Oh, I see. It's an id, not a dataset. Yeah, that works.

For instance, we can have many backbones with different datasets as well:

CPC.from_pretrained('resnet18-imagenet')
CPC.from_pretrained('resnet50-imagenet')
CPC.from_pretrained('resnet18-stl10')

@Borda (Member) commented Sep 11, 2020

Yes, they are trained on a defined dataset; in this case the dataset name just serves as a look-up-table key to a specific checkpoint path on the PL side...
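
As a hedged sketch with hypothetical paths, such a look-up table for the CPC identifiers above might be as simple as:

# Hypothetical look-up table; the real paths/URLs would live on the PL side.
CPC_CHECKPOINTS = {
    'resnet18-imagenet': 'checkpoints/cpc/resnet18-imagenet.ckpt',
    'resnet50-imagenet': 'checkpoints/cpc/resnet50-imagenet.ckpt',
    'resnet18-stl10': 'checkpoints/cpc/resnet18-stl10.ckpt',
}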

@ananyahjha93 (Contributor) commented Sep 12, 2020

@williamFalcon @Borda @nateraw I included this pattern in the latest AE and VAE commits to bolts. A few points that I realized:

  1. We could shift from_pretrained() into Lightning itself as a method to override.
  2. from_pretrained() needs to be an instance method, not a static method. In most cases, you will initialize the LightningModule with specific params according to the weights being loaded:

vae = VAE(input_height=32, first_conv=True)
vae = vae.from_pretrained('cifar10-resnet18')

In this example, stl10 weights have a different configuration for the encoder of the VAE. But the internal method loads with a strict=False flag, so users can still load stl10 weights into the cifar10 encoder configuration (see the sketch after this comment).

  3. Having this pattern allows us to test the correct loading of weights through the from_pretrained() function. @williamFalcon cases like the corrupt ImageNet weights for CPC will be caught automatically.

I have added all of this + tests for the AE and VAE classes I have updated for bolts.
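
For illustration, a minimal sketch of the instance-method variant with strict=False, reusing the hypothetical PRETRAINED_URLS registry from the earlier sketch (not the exact bolts implementation):

import torch
import pytorch_lightning as pl

class VAE(pl.LightningModule):
    ...

    def from_pretrained(self, identifier):
        # PRETRAINED_URLS is the hypothetical registry from the earlier sketch.
        url = PRETRAINED_URLS[identifier]
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location='cpu')
        # strict=False skips missing/unexpected keys, so e.g. stl10 weights can
        # be loaded into a module built with the cifar10 encoder configuration.
        self.load_state_dict(checkpoint['state_dict'], strict=False)
        return self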
