
Support resuming training and saving the best model with SWA #6074

@b02202050


🚀 Feature

Support stopping and resuming training, and saving the model with the best validation performance, when using Stochastic Weight Averaging (SWA).

Motivation

If training does not reach the final epoch, the SWA weights are never saved to the checkpoint.
For example:

  1. One may want to save the model with the best validation performance before it starts overfitting.
  2. Training may become unstable before it reaches the final epoch.

I observe that the saved checkpoint still contains the original model weights, even when I use the SWA callback.

Pitch

  1. Use the SWA model instead of the original model to run the validation step.
    1. We could load the SWA weights before validation and restore the original model weights after validation.
  2. Save both the original model weights and the SWA weights into the checkpoint so that training can be resumed.
    1. Someone who loads the model for testing would probably expect the SWA weights. Therefore, we can save the SWA weights to checkpoint['state_dict'] and save the original model weights as callback state.
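A minimal sketch of the first point (validate with the SWA weights, then restore the originals), written in plain PyTorch rather than as a Lightning callback. It assumes the averaged weights are tracked with `torch.optim.swa_utils.AveragedModel`; the context manager name `swa_weights` is hypothetical. In a Lightning callback, the load and restore steps would presumably be split across the validation start and end hooks instead.

```python
import torch
from contextlib import contextmanager
from torch import nn
from torch.optim.swa_utils import AveragedModel


@contextmanager
def swa_weights(model: nn.Module, swa_model: AveragedModel):
    """Hypothetical helper: temporarily load the SWA weights into `model`, then restore the originals."""
    # Clone, because load_state_dict copies into the same tensors that state_dict() returns.
    original = {k: v.detach().clone() for k, v in model.state_dict().items()}
    model.load_state_dict(swa_model.module.state_dict())
    try:
        yield model
    finally:
        # Restore the raw (non-averaged) weights so training continues from them.
        model.load_state_dict(original)


# Toy usage: two "training steps" that set the weights to 1.0 and then 3.0.
model = nn.Linear(2, 1)
swa_model = AveragedModel(model)
for value in (1.0, 3.0):
    with torch.no_grad():
        model.weight.fill_(value)
        model.bias.fill_(value)
    swa_model.update_parameters(model)

with swa_weights(model, swa_model):
    # A validation loop would run here, using the averaged weights.
    swa_weight_during_val = model.weight.detach().clone()  # averaged: 2.0
weight_after_val = model.weight.detach().clone()  # restored: 3.0
```

If the model contains BatchNorm layers, the running statistics are not averaged by default, so `torch.optim.swa_utils.update_bn(loader, swa_model)` would need to be called before validating with the SWA weights.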
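A minimal sketch of the second point: store the SWA weights under `"state_dict"` (so a plain load for testing gets them) and keep the original weights plus the running-average count as callback state for resuming. The checkpoint layout only loosely mirrors a Lightning checkpoint; the key `"StochasticWeightAveraging"` and the helpers `build_checkpoint` / `restore_checkpoint` are hypothetical, and a real callback would also need to restore the optimizer and scheduler states.

```python
import io

import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel

# Hypothetical key under which the SWA callback keeps its resume state.
SWA_STATE_KEY = "StochasticWeightAveraging"


def build_checkpoint(model: nn.Module, swa_model: AveragedModel, epoch: int) -> dict:
    """Put the SWA weights in "state_dict" and the original weights in callback state."""
    return {
        "epoch": epoch,
        # Loading "state_dict" for testing/inference yields the averaged weights.
        "state_dict": swa_model.module.state_dict(),
        "callbacks": {
            SWA_STATE_KEY: {
                # Raw (non-averaged) weights, needed to continue training after a resume.
                "original_state_dict": model.state_dict(),
                # Number of models averaged so far, so the running mean continues correctly.
                "n_averaged": int(swa_model.n_averaged),
            }
        },
    }


def restore_checkpoint(checkpoint: dict, model: nn.Module, swa_model: AveragedModel) -> int:
    """Inverse of build_checkpoint; returns the stored epoch."""
    swa_state = checkpoint["callbacks"][SWA_STATE_KEY]
    model.load_state_dict(swa_state["original_state_dict"])
    swa_model.module.load_state_dict(checkpoint["state_dict"])
    swa_model.n_averaged.fill_(swa_state["n_averaged"])
    return checkpoint["epoch"]


# Toy usage: average weights of 1.0 and 3.0, save, then resume into fresh objects.
model = nn.Linear(2, 1)
swa_model = AveragedModel(model)
for value in (1.0, 3.0):
    with torch.no_grad():
        model.weight.fill_(value)
        model.bias.fill_(value)
    swa_model.update_parameters(model)

buffer = io.BytesIO()
torch.save(build_checkpoint(model, swa_model, epoch=5), buffer)
buffer.seek(0)

new_model = nn.Linear(2, 1)
new_swa = AveragedModel(new_model)
epoch = restore_checkpoint(torch.load(buffer), new_model, new_swa)
```

Keeping the averaged weights under `"state_dict"` means code that only knows about that key gets the SWA model, while resuming needs the extra callback entry to recover the raw weights.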

Alternatives

Additional context

@Borda @MilesCranmer, I don't have much coding experience, nor am I familiar with PyTorch Lightning. Would you be able to help? Thanks a lot. 😉

Labels: feature (Is an improvement or enhancement), help wanted (Open to be worked on)
