[BUG] Loading a model from checkpoint prevents it from saving a .ckpt file correctly. #1561
I am not completely sure, but my initial guess is that this PR will have a positive impact here. What do you think @madtoinou?
Hi, thank you @solalatus for linking this issue. The root of the issue is the "intermediary" model save, which I am not sure I understand. Why would you like to perform such an operation, since the "best" checkpoint is already exported by the PyTorch Lightning trainer and you are already loading it (nothing happens between the call to load_from_checkpoint() and the call to save())? Regardless of the motivation for such a step, PR #1501 will introduce two methods that should cover this use case.
Hi, thank you both for responding so promptly. @madtoinou, here is roughly what I am doing:

```python
model = NBEATSModel.load_from_checkpoint(model_name, best=True)

# Save the model.
model.save("nbeats_example_2.pt")
exit()

# Return later and do stuff with the model.
model_loaded = NBEATSModel.load("nbeats_example_2.pt")
```

The motivation for using save() is to be able to come back later and reload the model for inference.

I looked into this a bit more, and the model saves fine if .predict() is called before .save():

```python
# This saves a model only.
model = NBEATSModel.load_from_checkpoint(model_name, best=True)
model.save("nbeats_example_2.pt")
```

```python
# This saves a model and ckpt.
model = NBEATSModel.load_from_checkpoint(model_name, best=True)
model.predict(n=1, series=series_air)
model.save("nbeats_example_2.pt")
```

While, yes, I can use the methods from #1501, it really seems like there is some odd behaviour going on here that would be good to look into. I'm wondering if it is even intended for a model to be able to save without creating a corresponding .ckpt file, as the .pt file by itself doesn't seem to be usable for anything (although there may be use cases I am not aware of).

Thanks for looking into this!
For now a simple fix is just to save the trainer attribute and add it back after calling load_from_checkpoint. I'm not sure if this means there will be a conflict between values in the trainer object and other parts of the model, but if there is, I don't think it's important for me because I'm not retraining, just performing inference.

```python
trainer = model.trainer
model = NBEATSModel.load_from_checkpoint(model_name, best=True)
model.trainer = trainer
```
Sorry if my question was not clear enough, but the point I was trying to make is that you could just directly use load_from_checkpoint and rely on the checkpoint that Darts already exports.

The saving system (which also encompasses the checkpoint system) indeed generates two files: a .pt file containing the Darts model object, and a .ckpt file containing the PyTorch Lightning checkpoint (the network weights and trainer state).

This separation was implemented to avoid duplicating inside the .pt file what is already stored in the checkpoint. Saving a model without a trainer is pointless, because it would indicate that the model was not trained, and it either contains the initialization weights or weights already contained in an existing checkpoint (which could be loaded directly if necessary). You could probably "trick" Darts by creating a new trainer and attaching it to the model.
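To make the two-file layout concrete, here is a minimal sketch. It assumes the checkpoint is written next to the .pt file with a .ckpt suffix, as described in the thread; the hyperparameters and file names are illustrative only.

```python
from darts.datasets import AirPassengersDataset
from darts.models import NBEATSModel

series = AirPassengersDataset().load()

model = NBEATSModel(input_chunk_length=24, output_chunk_length=12, n_epochs=1)
model.fit(series)

# save() writes two files next to each other:
#   nbeats_example_2.pt       -> the Darts wrapper object
#   nbeats_example_2.pt.ckpt  -> the PyTorch Lightning checkpoint (weights, trainer state)
model.save("nbeats_example_2.pt")

# Reloading restores the wrapper and the weights from both files.
model_loaded = NBEATSModel.load("nbeats_example_2.pt")
```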
My use case

I think I understand what you mean now. I am calling save() as I am saving the model to a separate folder, rather than relying on the autosave created by Darts. The idea is that I can optimise the hyperparameters and save the model manually if the performance with the current hyperparameters exceeds an earlier model with different hyperparameters. As each run of hyperparameter optimisation overwrites the 'best' model autosave, I need to save it manually.

I'm not sure if this is actually the best way to do it though, as it doesn't really seem to fit into the Darts workflow. So I might change this, especially as I'm considering only performing hyperparameter optimisation on a subset of the training data. So I'll probably export the best hyperparameters and train up my main model separately with a different model name. Then, as you say, I don't have to call save, and I can just reload the model from the autosaved checkpoint.

Saving system

Thank you for clarifying the usage of the files generated by the saving system; I was a bit confused there. The source of my confusion was that vanilla PyTorch saves weights in a .pt file, whereas PyTorch Lightning saves weights in a .ckpt file. I understand that you can load from the checkpoint that already exists, but it seems like save() should produce a usable .ckpt file as well. But of course it is up to you.

This issue is not a blocker for my project now, as this discussion has identified several workarounds. Thank you for the great help and clarifications!
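For reference, a rough sketch of the kind of loop described above. The candidate hyperparameters, metric, split and file names are hypothetical; only fit(), predict(), save() and mape() are actual Darts calls.

```python
from darts.datasets import AirPassengersDataset
from darts.metrics import mape
from darts.models import NBEATSModel

series = AirPassengersDataset().load()
train, val = series[:-36], series[-36:]

# Hypothetical grid of candidate hyperparameters.
candidates = [{"num_blocks": 1}, {"num_blocks": 2}, {"num_blocks": 3}]

best_score = float("inf")
for params in candidates:
    model = NBEATSModel(input_chunk_length=24, output_chunk_length=12,
                        n_epochs=5, **params)
    model.fit(train)
    score = mape(val, model.predict(n=len(val)))

    # Manually keep the overall best model, since each new run would
    # otherwise overwrite the per-run 'best' autosave.
    if score < best_score:
        best_score = score
        model.save("best_nbeats_so_far.pt")
```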
Thank you for detailing your use case, everything makes perfect sense to me now :)

I agree with you that for such a scenario, one would expect save() to also export a working checkpoint. One approach could be to try to retrieve the "automatic/PL" checkpoint so that save() can export it as well.

And now that I think about it, we should also probably prevent users from saving a model using the "prediction trainer"... WDYT @dennisbader?
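As an illustration of that idea, here is a hypothetical helper (not a Darts API) that copies the automatic checkpoint next to a manually saved model. The "darts_logs/<model_name>/checkpoints" layout and the "best-*" file prefix are assumptions based on Darts defaults.

```python
import glob
import os
import shutil
from typing import Optional


def copy_best_checkpoint(model_name: str, target_path: str,
                         work_dir: str = "darts_logs") -> Optional[str]:
    """Copy the most recent 'best-*' checkpoint written by PyTorch Lightning
    next to a manually saved model file (hypothetical helper)."""
    ckpt_dir = os.path.join(work_dir, model_name, "checkpoints")
    candidates = glob.glob(os.path.join(ckpt_dir, "best-*.ckpt"))
    if not candidates:
        return None
    best_ckpt = max(candidates, key=os.path.getmtime)
    target_ckpt = target_path + ".ckpt"
    shutil.copy(best_ckpt, target_ckpt)
    return target_ckpt


# Usage after saving the wrapper manually:
# model.save("nbeats_example_2.pt")
# copy_best_checkpoint("nbeats_model", "nbeats_example_2.pt")
```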
As per our discussion @madtoinou,
Describe the bug
Hi, I've noticed some unexpected behaviour relating to loading and saving of model checkpoints when using the .load_from_checkpoint method.
This is a bit tricky to explain so the best way is to give examples.
Scenario 1:
This works fine.
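(The original code block for Scenario 1 was not preserved; the following is a plausible reconstruction based on the rest of the report: train a model with checkpointing enabled, then save it manually. Hyperparameters and file names are illustrative.)

```python
from darts.datasets import AirPassengersDataset
from darts.models import NBEATSModel

series_air = AirPassengersDataset().load()

model = NBEATSModel(input_chunk_length=24, output_chunk_length=12,
                    n_epochs=5, model_name="nbeats_example",
                    save_checkpoints=True, force_reset=True)
model.fit(series_air)

# The trainer still exists on the model, so this writes both
# nbeats_example_2.pt and nbeats_example_2.pt.ckpt.
model.save("nbeats_example_2.pt")
```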
Scenario 2:
This does not work, because using .load_from_checkpoint removes the model.trainer object. When .save is called, the model does not write a .ckpt file, since it only does so when the model.trainer object exists.
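(Again, the original code block was not preserved; this is a plausible reconstruction of Scenario 2, reusing the names from the Scenario 1 sketch above.)

```python
from darts.models import NBEATSModel

# Load the best checkpoint written during training in Scenario 1.
model = NBEATSModel.load_from_checkpoint("nbeats_example", best=True)

# model.trainer is now gone, so only nbeats_example_2.pt is written;
# no corresponding .ckpt file is produced.
model.save("nbeats_example_2.pt")
```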
I am attempting to implement a workaround in my project by saving the checkpoint file manually, but this is not ideal.
To Reproduce
Expected behavior
The model should save a .ckpt regardless of whether it has been loaded from a checkpoint or not.
System (please complete the following information):
Additional context
Unfortunately I don't know enough about the internal workings of Pytorch Lightning to be able to suggest a complete solution for this.
Line 1275 of torch_forecasting_model.py in the darts library is relevant: it relies on the trainer object to save the checkpoint, so if the trainer does not exist, the checkpoint cannot be saved.
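In rough terms, the logic looks something like the sketch below. This is a paraphrase for illustration, not the actual Darts source; only Trainer.save_checkpoint() is a real PyTorch Lightning call.

```python
from typing import Optional

import pytorch_lightning as pl


class ModelWrapper:
    """Minimal stand-in for a forecasting-model wrapper that holds a PL trainer."""

    def __init__(self, trainer: Optional[pl.Trainer] = None):
        self.trainer = trainer

    def save(self, path: str) -> None:
        # ... persist the wrapper object itself to `path` (omitted) ...

        # The Lightning checkpoint is only written when a trainer is attached.
        if self.trainer is not None:
            self.trainer.save_checkpoint(path + ".ckpt")
        # If the trainer was removed (e.g. after loading from a checkpoint),
        # this branch is skipped and no .ckpt file is produced.
```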
I don't know if it makes sense to preserve the trainer on the model after calling load_from_checkpoint, as that trainer object then relates to a later epoch than the parameters actually loaded into the model. I guess it would make sense to just try to save a .ckpt file from the model without calling the trainer.
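One very rough way to do that would be something like the sketch below. It is only an illustration of the idea: it assumes a loaded Darts model in model whose underlying Lightning module is reachable as model.model (an assumption), and the result is not a complete Lightning checkpoint.

```python
import torch

# `model.model` is assumed to be the underlying PyTorch Lightning module
# held by the Darts model; the attribute name is for illustration only.
pl_module = model.model

# Write only the state_dict. A real Lightning checkpoint also carries
# hyperparameters, optimizer state and epoch/step counters, which are
# missing here, so load_from_checkpoint() may not accept this file as-is.
torch.save({"state_dict": pl_module.state_dict()}, "nbeats_example_2.pt.ckpt")
```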
Hope this is helpful!