Why are hparams mandatory in the LightningModule definition? #599

Closed
akshaykulkarni07 opened this issue on Dec 6, 2019 · 10 comments
Labels: question (Further information is requested)

Comments

@akshaykulkarni07
Contributor

When I don't pass hparams to the LightningModule, it doesn't allow me to load a previously saved model from a checkpoint. In particular, hparams can't be passed when working in Jupyter Notebook/Lab, so how do I handle such a use case (for testing, etc.)?

I am able to train a model and checkpoint it, but I can't load it after I restart the kernel.

akshaykulkarni07 added the question label on Dec 6, 2019
@williamFalcon
Contributor

hparams are not required. We could also expand support to allow passing in a dictionary or something similar.

@akshaykulkarni07
Contributor Author

Yes, they are not mandatory as such, but then the saved weights in the checkpoint cannot be loaded. Is there a workaround for that?

@williamFalcon
Contributor

@neggert

@neggert
Contributor

neggert commented Dec 6, 2019

It's just the load_from_checkpoint method that doesn't work if you don't use hparams. I think the idea was that load_from_checkpoint as a classmethod doesn't make much sense unless you have some way of restoring the hyperparameters and a standard set of model constructor args.

If you don't want to follow the standard convention, you should be able to create your model manually, then restore your weights just as you would a standard Torch model.

import torch

# Construct the model manually, then load only the weights from the checkpoint.
model = MyModel(whatever, args, you, want)
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])

This is a little bit error-prone in cases where your model has variable layer sizes. If you don't create the model with the same layers as the model that the checkpoint was saved from, this will fail.

You can also let the trainer restore the weights by doing e.g.

trainer = Trainer(..., restore_from_checkpoint=checkpoint_path)
trainer.test()

Does either of these solutions meet your needs? Perhaps we could take a look at providing a convenience method for this? Or maybe even change the signature of load_from_checkpoint to

@classmethod
def load_from_checkpoint(cls, checkpoint_path, model_kwargs=None):

Then if we get model_kwargs, use them. If not, try to load hparams from the checkpoint.
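
For illustration, a rough sketch of how that might behave. This is not the actual Lightning implementation; the 'hparams' checkpoint key and the fallback order are assumptions:

import torch

@classmethod
def load_from_checkpoint(cls, checkpoint_path, model_kwargs=None):
    # Sketch only: load the checkpoint once, then decide how to build the model.
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    if model_kwargs is not None:
        # Explicit constructor kwargs take priority.
        model = cls(**model_kwargs)
    else:
        # Otherwise fall back to the hyperparameters stored in the checkpoint
        # (assumes they were saved under an 'hparams' key).
        model = cls(checkpoint['hparams'])
    model.load_state_dict(checkpoint['state_dict'])
    return model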

@akshaykulkarni07
Contributor Author

akshaykulkarni07 commented Dec 7, 2019

@neggert
The first solution using the normal PyTorch way works, but the second one using the Trainer doesn't: Trainer doesn't have a restore_from_checkpoint argument.

@Borda
Member

Borda commented Dec 7, 2019

BTW, we should include these use cases in the tests...

@jeffling
Contributor

Perhaps we could take a look at providing a convenience method for this?

I'm very much for this. For my use case, we use a completely separate configuration system outside of Lightning's for our model construction.

The first solution using the normal PyTorch way works, but the second one using the Trainer doesn't: Trainer doesn't have a restore_from_checkpoint argument.

@akshaykvnit I believe it's resume_from_checkpoint

@shijianjian
Contributor

@jeffling There is no resume_from_checkpoint either. I am not sure about previous versions, but in fact there is no such option to restore the previous state from the Trainer. Also, I believe the Trainer should not have that function, as it should only do the training.

@akshaykulkarni07
Contributor Author

@shijianjian @jeffling @neggert @williamFalcon @Borda There is a resume_from_checkpoint parameter in the Trainer source code. However, the parameter is not used anywhere in the rest of the code to actually load anything.

@neggert
Contributor

neggert commented Dec 12, 2019

Seems to have been implemented in #516. At a glance the implementation looks reasonable, but I haven't tested it.
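
For completeness, a minimal usage sketch, assuming a Lightning version whose Trainer accepts resume_from_checkpoint; the model and path below are placeholders:

from pytorch_lightning import Trainer

# Build the model the usual way (placeholder constructor args).
model = MyModel(whatever, args, you, want)

# Pass the checkpoint path to the Trainer so it restores the saved state before training.
trainer = Trainer(resume_from_checkpoint='path/to/checkpoint.ckpt')
trainer.fit(model)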
