Skip to content

Commit

Permalink
Feat/improved training from ckpt (unit8co#1501)
Browse files Browse the repository at this point in the history
* feat: new function fit_from_checkpoint that load one chkpt from the mode, allows user to change the optimizer, scheduler or trainer and export the ckpt of this fine-tuned model into another folder. fine-tuning cannot be chained using this method (original model ckpt must be reloaded)

* fix: improved the model saving to allow chaining of fine-tuning, better control over the logger, made the function static

* feat: allow to save the checkpoint in the same folder (loaded checkpoint is likely to be overwritten if the model is trained with default parameters)

* fix: ordered arguments in a more intuitive way

* fix: saving model after updating all the parameters to facilitate the chain-fine tuning

* feat: support for load_from_checkpoint kwargs, support for force_reset argument

* feat: adding test for setup_finetuning

* fix: fused the setup_finetuning and load_from_checkpoint methods, added dcostring, updated tests

* fix: changed the API/approach, instead of trying to overwrite attributes of an existing model, rather load the weights into a new model (but not the other attributes such as the optimizer, trainer, ...

* fix: convertion of hyper-parameters to list when checking compatibility between checkpoint and instantiated model

* fix: skip the None attribute during the hp check

* fix: removed unecessary attribute initialization

* feat: pl_forecasting_module also save the train_sample in the checkpoints

* fix: saving only shape instead of the sample itself

* fix: restore the self.train_sample in TorchForecastingModel

* fix: update fit_called attribute to enable inference without retraining

* fix: the mock train_sample must be converted to tuple

* fix: tweaked model parameters to improve convergence

* fix: increased number of epochs to improve convergence/test stability

* fix: addressing review comments; added load_weights method and corresponding tests, updated documentation

* fix: changed default checkpoint path name for compatibility with Windows OS

* feat: raise error if the checkpoint being loaded does not contain the train_sample_shape entry, to make the break more transparent to users

* fix: saving model manually directly after laoding it from checkpoint will retrieve and copy the original .ckpt file to avoid unexpected behaviors

* fix: use random_state to fix randomness in tests

* fix: restore newlines

* fix: casting dtype of PLModule before loading the weights

* doc: model_name docstring and code were not consistent

* doc: improve phrasing

* Apply suggestions from code review

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>

* fix: removed warning in saving about trainer/ckpt not being found, warning will be raised in the load() call if no weights can be loaded

* fix: uniformised filename convention using '_' to separate hours, minutes and seconds, updated doc accordingly

* fix: removed typo

* Update darts/models/forecasting/torch_forecasting_model.py

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>

* fix: more consistent use of the path argument during save and load

---------

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
  • Loading branch information
2 people authored and alexcolpitts96 committed May 31, 2023
1 parent 3e43174 commit e373217
Show file tree
Hide file tree
Showing 13 changed files with 395 additions and 32 deletions.
4 changes: 2 additions & 2 deletions darts/models/forecasting/block_rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,10 @@ def __init__(
Number of epochs over which to train the model. Default: ``100``.
model_name
Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
of the name is formatted with the local date and time, while PID is the processed ID (preventing models
spawned at the same time by different processes to share the same model_name). E.g.,
``"2021-06-14_09:53:32_torch_model_run_44607"``.
``"2021-06-14_09_53_32_torch_model_run_44607"``.
work_dir
Path of the working directory, where to save checkpoints and Tensorboard summaries.
Default: current working directory.
Expand Down
4 changes: 2 additions & 2 deletions darts/models/forecasting/dlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,10 @@ def __init__(
Number of epochs over which to train the model. Default: ``100``.
model_name
Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
of the name is formatted with the local date and time, while PID is the processed ID (preventing models
spawned at the same time by different processes to share the same model_name). E.g.,
``"2021-06-14_09:53:32_torch_model_run_44607"``.
``"2021-06-14_09_53_32_torch_model_run_44607"``.
work_dir
Path of the working directory, where to save checkpoints and Tensorboard summaries.
Default: current working directory.
Expand Down
6 changes: 3 additions & 3 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,7 +1532,7 @@ def model_params(self) -> dict:

@classmethod
def _default_save_path(cls) -> str:
return f"{cls.__name__}_{datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
return f"{cls.__name__}_{datetime.datetime.now().strftime('%Y-%m-%d_%H_%M_%S')}"

def save(self, path: Optional[Union[str, BinaryIO]] = None, **pkl_kwargs) -> None:
"""
Expand All @@ -1555,8 +1555,8 @@ def save(self, path: Optional[Union[str, BinaryIO]] = None, **pkl_kwargs) -> Non
----------
path
Path or file handle under which to save the model at its current state. If no path is specified, the model
is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH:MM:SS}.pkl"``.
E.g., ``"RegressionModel_2020-01-01_12:00:00.pkl"``.
is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.pkl"``.
E.g., ``"RegressionModel_2020-01-01_12_00_00.pkl"``.
pkl_kwargs
Keyword arguments passed to `pickle.dump()`
"""
Expand Down
4 changes: 2 additions & 2 deletions darts/models/forecasting/nbeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,10 +622,10 @@ def __init__(
Number of epochs over which to train the model. Default: ``100``.
model_name
Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
of the name is formatted with the local date and time, while PID is the processed ID (preventing models
spawned at the same time by different processes to share the same model_name). E.g.,
``"2021-06-14_09:53:32_torch_model_run_44607"``.
``"2021-06-14_09_53_32_torch_model_run_44607"``.
work_dir
Path of the working directory, where to save checkpoints and Tensorboard summaries.
Default: current working directory.
Expand Down
4 changes: 2 additions & 2 deletions darts/models/forecasting/nhits.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,10 +558,10 @@ def __init__(
Number of epochs over which to train the model. Default: ``100``.
model_name
Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
of the name is formatted with the local date and time, while PID is the processed ID (preventing models
spawned at the same time by different processes to share the same model_name). E.g.,
``"2021-06-14_09:53:32_torch_model_run_44607"``.
``"2021-06-14_09_53_32_torch_model_run_44607"``.
work_dir
Path of the working directory, where to save checkpoints and Tensorboard summaries.
Default: current working directory.
Expand Down
4 changes: 2 additions & 2 deletions darts/models/forecasting/nlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,10 @@ def __init__(
Number of epochs over which to train the model. Default: ``100``.
model_name
Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
of the name is formatted with the local date and time, while PID is the processed ID (preventing models
spawned at the same time by different processes to share the same model_name). E.g.,
``"2021-06-14_09:53:32_torch_model_run_44607"``.
``"2021-06-14_09_53_32_torch_model_run_44607"``.
work_dir
Path of the working directory, where to save checkpoints and Tensorboard summaries.
Default: current working directory.
Expand Down
13 changes: 13 additions & 0 deletions darts/models/forecasting/pl_forecasting_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
self,
input_chunk_length: int,
output_chunk_length: int,
train_sample_shape: Optional[Tuple] = None,
loss_fn: nn.modules.loss._Loss = nn.MSELoss(),
torch_metrics: Optional[
Union[torchmetrics.Metric, torchmetrics.MetricCollection]
Expand Down Expand Up @@ -59,6 +60,9 @@ def __init__(
Number of input past time steps per chunk.
output_chunk_length
Number of output time steps per chunk.
train_sample_shape
Shape of the model's input, used to instantiate model without calling ``fit_from_dataset`` and
perform sanity check on new training/inference datasets used for re-training or prediction.
loss_fn
PyTorch loss function used for training.
This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
Expand Down Expand Up @@ -101,6 +105,9 @@ def __init__(
# by default models are deterministic (i.e. not probabilistic)
self.likelihood = likelihood

# saved in checkpoint to be able to instantiate a model without calling fit_from_dataset
self.train_sample_shape = train_sample_shape

# persist optimiser and LR scheduler parameters
self.optimizer_cls = optimizer_cls
self.optimizer_kwargs = dict() if optimizer_kwargs is None else optimizer_kwargs
Expand Down Expand Up @@ -383,11 +390,17 @@ def _produce_predict_output(self, x: Tuple):
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
# we must save the dtype for correct parameter precision at loading time
checkpoint["model_dtype"] = self.dtype
# we must save the shape of the input to be able to instanciate the model without calling fit_from_dataset
checkpoint["train_sample_shape"] = self.train_sample_shape

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
# by default our models are initialized as float32. For other dtypes, we need to cast to the correct precision
# before parameters are loaded by PyTorch-Lightning
dtype = checkpoint["model_dtype"]
self.to_dtype(dtype)

def to_dtype(self, dtype):
"""Cast module precision (float32 by default) to another precision."""
if dtype == torch.float16:
self.half()
if dtype == torch.float32:
Expand Down
4 changes: 2 additions & 2 deletions darts/models/forecasting/rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,10 @@ def __init__(
Number of epochs over which to train the model. Default: ``100``.
model_name
Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
of the name is formatted with the local date and time, while PID is the processed ID (preventing models
spawned at the same time by different processes to share the same model_name). E.g.,
``"2021-06-14_09:53:32_torch_model_run_44607"``.
``"2021-06-14_09_53_32_torch_model_run_44607"``.
work_dir
Path of the working directory, where to save checkpoints and Tensorboard summaries.
Default: current working directory.
Expand Down
4 changes: 2 additions & 2 deletions darts/models/forecasting/tcn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,10 @@ def __init__(
Number of epochs over which to train the model. Default: ``100``.
model_name
Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
of the name is formatted with the local date and time, while PID is the processed ID (preventing models
spawned at the same time by different processes to share the same model_name). E.g.,
``"2021-06-14_09:53:32_torch_model_run_44607"``.
``"2021-06-14_09_53_32_torch_model_run_44607"``.
work_dir
Path of the working directory, where to save checkpoints and Tensorboard summaries.
Default: current working directory.
Expand Down
4 changes: 2 additions & 2 deletions darts/models/forecasting/tft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,10 +763,10 @@ def __init__(
Number of epochs over which to train the model. Default: ``100``.
model_name
Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
of the name is formatted with the local date and time, while PID is the processed ID (preventing models
spawned at the same time by different processes to share the same model_name). E.g.,
``"2021-06-14_09:53:32_torch_model_run_44607"``.
``"2021-06-14_09_53_32_torch_model_run_44607"``.
work_dir
Path of the working directory, where to save checkpoints and Tensorboard summaries.
Default: current working directory.
Expand Down
Loading

0 comments on commit e373217

Please sign in to comment.