🚀 Feature
Support stop-resume and saving the model with the best validation performance when using Stochastic Weight Averaging (SWA).
Motivation
If training does not reach the final epoch, the SWA model is never saved to the checkpoint.
For example:
- One may want to save the model with the best validation performance before it overfits.
- Training may become unstable before reaching the final epoch.
I observe that the saved checkpoint still contains the original model weights even when I use the SWA callback.
Pitch
- Run the validation step with the SWA model instead of the original model.
  - We could load the SWA weights before validation and restore the original weights after validation.
- Save both the original model weights and the SWA model weights in the checkpoint so that training can resume.
  - Someone loading the checkpoint for testing would likely expect the model to carry the SWA weights. Therefore, we could save the SWA weights to `checkpoint['state_dict']` and save the original model weights as callback state.
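The pitched behaviour can be sketched in plain PyTorch (no Lightning dependency). This is only an illustration of the two ideas above: temporarily swapping in the averaged weights for validation, and laying out a checkpoint with the SWA weights under `state_dict` and the original weights as callback state. The function and key names (`validate_with_swa`, `build_checkpoint`, `"SWACallback"`) are hypothetical, not part of any existing API.

```python
import copy

import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel


def validate_with_swa(model: nn.Module, swa_model: AveragedModel, run_validation):
    """Load the SWA weights before validation, restore the originals after."""
    backup = copy.deepcopy(model.state_dict())
    model.load_state_dict(swa_model.module.state_dict())
    try:
        return run_validation(model)
    finally:
        # Restore the original (non-averaged) weights after validation.
        model.load_state_dict(backup)


def build_checkpoint(model: nn.Module, swa_model: AveragedModel) -> dict:
    """Pitched checkpoint layout: SWA weights in 'state_dict',
    original weights tucked away as callback state for resuming."""
    return {
        "state_dict": copy.deepcopy(swa_model.module.state_dict()),
        "callbacks": {
            "SWACallback": {
                "original_state_dict": copy.deepcopy(model.state_dict()),
            }
        },
    }
```

With this layout, `torch.load(...)['state_dict']` hands a test-time user the SWA weights directly, while a resuming trainer can dig the original weights out of the callback state.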
Alternatives
Additional context
@Borda @MilesCranmer, I do not have much coding experience, nor am I familiar with PyTorch Lightning internals. Would you help? Thanks a lot. 😉