forked from unit8co/darts
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feat/improved training from ckpt (unit8co#1501)
* feat: new function fit_from_checkpoint that loads one ckpt from the model, allows the user to change the optimizer, scheduler or trainer, and exports the ckpt of this fine-tuned model into another folder. Fine-tuning cannot be chained using this method (original model ckpt must be reloaded)
* fix: improved the model saving to allow chaining of fine-tuning, better control over the logger, made the function static
* feat: allow saving the checkpoint in the same folder (loaded checkpoint is likely to be overwritten if the model is trained with default parameters)
* fix: ordered arguments in a more intuitive way
* fix: saving model after updating all the parameters to facilitate chained fine-tuning
* feat: support for load_from_checkpoint kwargs, support for force_reset argument
* feat: adding test for setup_finetuning
* fix: fused the setup_finetuning and load_from_checkpoint methods, added docstring, updated tests
* fix: changed the API/approach: instead of trying to overwrite attributes of an existing model, load the weights into a new model (but not the other attributes such as the optimizer, trainer, ...
* fix: conversion of hyper-parameters to list when checking compatibility between checkpoint and instantiated model
* fix: skip the None attribute during the hp check
* fix: removed unnecessary attribute initialization
* feat: pl_forecasting_module also saves the train_sample in the checkpoints
* fix: saving only the shape instead of the sample itself
* fix: restore the self.train_sample in TorchForecastingModel
* fix: update fit_called attribute to enable inference without retraining
* fix: the mock train_sample must be converted to tuple
* fix: tweaked model parameters to improve convergence
* fix: increased number of epochs to improve convergence/test stability
* fix: addressing review comments; added load_weights method and corresponding tests, updated documentation
* fix: changed default checkpoint path name for compatibility with Windows OS
* feat: raise an error if the checkpoint being loaded does not contain the train_sample_shape entry, to make the break more transparent to users
* fix: saving the model manually directly after loading it from a checkpoint will retrieve and copy the original .ckpt file to avoid unexpected behaviors
* fix: use random_state to fix randomness in tests
* fix: restore newlines
* fix: casting dtype of PLModule before loading the weights
* doc: model_name docstring and code were not consistent
* doc: improve phrasing
* Apply suggestions from code review

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>

* fix: removed warning in saving about trainer/ckpt not being found; a warning will be raised in the load() call if no weights can be loaded
* fix: uniformised filename convention using '_' to separate hours, minutes and seconds, updated doc accordingly
* fix: removed typo
* Update darts/models/forecasting/torch_forecasting_model.py

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>

* fix: more consistent use of the path argument during save and load

---------

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
1 parent 3e43174, commit e373217
Showing 13 changed files with 395 additions and 32 deletions.