
Feat/improved training from ckpt #1501

Merged: 47 commits, Feb 21, 2023
Changes shown from 11 commits

Commits (47)
e7e92fa
feat: new function fit_from_checkpoint that loads one chkpt from the m…
madtoinou Jan 19, 2023
22828d1
fix: improved the model saving to allow chaining of fine-tuning, bett…
madtoinou Jan 20, 2023
c4f4370
feat: allow to save the checkpoint in the same folder (loaded checkpo…
madtoinou Jan 20, 2023
75acd53
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 20, 2023
c6eddc1
fix: ordered arguments in a more intuitive way
madtoinou Jan 20, 2023
4b38347
fix: saving model after updating all the parameters to facilitate the…
madtoinou Jan 20, 2023
30603ca
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 20, 2023
1abcb96
feat: support for load_from_checkpoint kwargs, support for force_rese…
madtoinou Jan 20, 2023
bd4f035
feat: adding test for setup_finetuning
madtoinou Jan 20, 2023
0e71805
Merge branch 'feat/improved-training-from-ckpt' of https://github.com…
madtoinou Jan 20, 2023
5ec58bc
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 20, 2023
a7be96f
fix: fused the setup_finetuning and load_from_checkpoint methods, add…
madtoinou Jan 23, 2023
07ac34a
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 23, 2023
206aa40
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 23, 2023
247b570
fix: changed the API/approach, instead of trying to overwrite attribu…
madtoinou Jan 30, 2023
83211be
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 30, 2023
5a39edd
fix: conversion of hyper-parameters to list when checking compatibili…
madtoinou Jan 30, 2023
4d2b77c
Merge branch 'feat/improved-training-from-ckpt' of https://github.com…
madtoinou Jan 30, 2023
44a3fa4
fix: skip the None attribute during the hp check
madtoinou Jan 30, 2023
ee00b89
fix: removed unnecessary attribute initialization
madtoinou Jan 30, 2023
9cc0ac8
feat: pl_forecasting_module also save the train_sample in the checkpo…
madtoinou Feb 5, 2023
8c93454
fix: saving only shape instead of the sample itself
madtoinou Feb 5, 2023
77447b2
fix: restore the self.train_sample in TorchForecastingModel
madtoinou Feb 6, 2023
17f9c3d
fix: update fit_called attribute to enable inference without retraining
madtoinou Feb 6, 2023
8e2462f
fix: the mock train_sample must be converted to tuple
madtoinou Feb 6, 2023
ce35e8a
fix: tweaked model parameters to improve convergence
madtoinou Feb 6, 2023
167498a
fix: increased number of epochs to improve convergence/test stability
madtoinou Feb 6, 2023
4a18301
fix: addressing review comments; added load_weights method and corres…
madtoinou Feb 13, 2023
192a423
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Feb 13, 2023
0c6a461
fix: changed default checkpoint path name for compatibility with Wind…
madtoinou Feb 14, 2023
e309390
feat: raise error if the checkpoint being loaded does not contain the…
madtoinou Feb 14, 2023
d13f4a7
fix: saving model manually directly after loading it from checkpoint …
madtoinou Feb 16, 2023
96812d8
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Feb 16, 2023
4304cf1
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Feb 17, 2023
867ad35
fix: use random_state to fix randomness in tests
madtoinou Feb 19, 2023
b42d6e1
fix: restore newlines
madtoinou Feb 19, 2023
6b0de3e
fix: casting dtype of PLModule before loading the weights
madtoinou Feb 19, 2023
845f96e
doc: model_name docstring and code were not consistent
madtoinou Feb 19, 2023
497420f
doc: improve phrasing
madtoinou Feb 19, 2023
72486f8
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Feb 19, 2023
39ba739
Apply suggestions from code review
madtoinou Feb 19, 2023
edab120
fix: removed warning in saving about trainer/ckpt not being found, wa…
madtoinou Feb 19, 2023
c002f3e
fix: uniformised filename convention using '_' to separate hours, min…
madtoinou Feb 19, 2023
aa735de
fix: removed typo
madtoinou Feb 19, 2023
3328835
Update darts/models/forecasting/torch_forecasting_model.py
madtoinou Feb 19, 2023
9d13eaf
fix: more consistent use of the path argument during save and load
madtoinou Feb 19, 2023
b60c9f2
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Feb 21, 2023
163 changes: 162 additions & 1 deletion darts/models/forecasting/torch_forecasting_model.py
@@ -24,7 +24,7 @@
import sys
from abc import ABC, abstractmethod
from glob import glob
from typing import List, Optional, Sequence, Tuple, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
@@ -916,6 +916,167 @@ def fit_from_dataset(
self._train(train_loader, val_loader)
return self

@staticmethod
def setup_finetuning(
old_model_name: str,
new_model_name: str = None,
additional_epochs: int = 0,
Review comment (Contributor):
Do we need this parameter, given that fit() already accepts epochs? I would find it cleaner to rely exclusively on fit()'s parameter. If there's a problem with it, could we maybe fix it there (i.e. handle the trainer correctly in fit() to handle epochs)?

Reply (madtoinou, Collaborator, author):
There is an issue with the fit() parameter (#1495); I think @rijkvandermeulen is already working on a fix. I will remove the epochs argument from this method and wait for the patch to be merged.
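For reference, the flow the reviewer suggests would look roughly like this once #1495 is addressed (a sketch only; the series, model_name and epoch count are illustrative and not part of this PR):

from darts.models import RNNModel
from darts.utils.timeseries_generation import sine_timeseries

series = sine_timeseries(length=100)  # placeholder training series

# resume from the last checkpoint of a previously trained model
model = RNNModel.load_from_checkpoint(model_name="original", best=False)

# let fit() control the number of additional epochs instead of a dedicated
# additional_epochs argument on setup_finetuning
model.fit(series, epochs=2)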

trainer_params: Optional[Dict] = None,
work_dir: str = None,
file_name: str = None,
best: bool = False,
save_inplace: bool = False,
force_reset: bool = False,
optimizer_cls: torch.optim.Optimizer = torch.optim.Adam,
optimizer_kwargs: Optional[Dict] = None,
lr_scheduler_cls: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
lr_scheduler_kwargs: Optional[Dict] = None,
**kwargs,
):
# call _get_checkpoint_fname if file_name is None
# TODO: support for load_from_checkpoint kwargs
model = TorchForecastingModel.load_from_checkpoint(
model_name=old_model_name,
work_dir=work_dir,
file_name=file_name,
best=best,
**kwargs,
)

if new_model_name is None:
model.model_name = old_model_name
else:
model.model_name = new_model_name
model.model_params["model_name"] = model.model_name

if work_dir is not None and work_dir != model.work_dir:
model.work_dir = work_dir

# checkpoint path
checkpoints_folder = _get_checkpoint_folder(model.work_dir, model.model_name)
checkpoint_exists = (
os.path.exists(checkpoints_folder)
and len(glob(os.path.join(checkpoints_folder, "*"))) > 0
)
raise_if(
save_inplace and force_reset,
"For safety reasons, `save_inplace` and `force_reset` cannot be both True to prevent "
" deletion of the loaded checkpoint.",
logger,
)
if checkpoint_exists and model.save_checkpoints:
if force_reset:
model.reset_model()
else:
raise_if_not(
save_inplace,
f"Some model data already exists for `model_name` '{model.model_name}'. Either provide a"
f" `new_model_name` to save the checkpoints in a new folder or set `save_inplace`"
f" to True to save them in the existing folder (calling `fit` on the loaded model"
f" will likely overwrite the loaded checkpoint).",
logger,
)
elif model.save_checkpoints:
model._create_save_dirs()
else:
pass

# TODO: avoid user warning about dirpath change
checkpoint_callback = pl.callbacks.ModelCheckpoint(
dirpath=checkpoints_folder,
save_last=True,
monitor="val_loss",
filename="best-{epoch}-{val_loss:.2f}",
)
checkpoint_callback.CHECKPOINT_NAME_LAST = "last-{epoch}"

# update trainer
old_trainer_params = model.trainer_params.copy()
new_max_epochs = old_trainer_params["max_epochs"] + additional_epochs
if trainer_params is not None:
for trainer_param in trainer_params.keys():
model.trainer_params[trainer_param] = trainer_params[trainer_param]

# special parameter handling
if "callbacks" in trainer_params.keys() and len(
trainer_params["callbacks"] > 0
):
model.trainer_params["callbacks"] = [
checkpoint_callback
] + trainer_params["callbacks"]
else:
model.trainer_params["callbacks"] = [checkpoint_callback]

if "max_epochs" in trainer_params.keys():
if additional_epochs != 0:
raise_if(
trainer_params["max_epochs"] != new_max_epochs,
"The number of epochs to retrain the model for was defined in"
" both the `trainer_params` and `additional_epochs` arguments"
f" with differents values ({trainer_params['max_epochs']} and"
f" {old_trainer_params['max_epochs']} + {additional_epochs})",
logger,
)
else:
new_max_epochs = trainer_params["max_epochs"]
raise_if(
new_max_epochs <= old_trainer_params["max_epochs"],
"The number of epochs to retrain passed in `trainer_params['max_epochs']`"
" or `additional_epochs`is smaller or equal to the number of epochs used"
" to train the model",
logger,
)
model.n_epochs = new_max_epochs
model.model_params["n_epochs"] = new_max_epochs
model.trainer_params["max_epochs"] = new_max_epochs

# update optimizer
if optimizer_cls == model.model.optimizer_cls:
if optimizer_kwargs is not None:
model.model.optimizer_kwargs.update(optimizer_kwargs)
else:
# using different optimizer
model.model.optimizer_cls = optimizer_cls
model.model.optimizer_kwargs = (
dict() if optimizer_kwargs is None else optimizer_kwargs
)
model.model_params["optimizer_scheduler_cls"] = model.model.optimizer_cls
model.model_params["optimizer_kwargs"] = model.model.optimizer_kwargs
model.pl_module_params["optimizer_cls"] = model.model.optimizer_cls
model.pl_module_params["optimizer_kwargs"] = model.model.optimizer_kwargs

# update scheduler
if lr_scheduler_cls == model.model.lr_scheduler_cls:
if lr_scheduler_kwargs is not None:
model.model.lr_scheduler_kwargs.update(lr_scheduler_kwargs)
else:
model.model.lr_scheduler_cls = lr_scheduler_cls
model.model.lr_scheduler_kwargs = (
dict() if lr_scheduler_kwargs is None else lr_scheduler_kwargs
)
model.model_params["lr_scheduler_cls"] = model.model.lr_scheduler_cls
model.model_params["lr_scheduler_kwargs"] = model.model.lr_scheduler_kwargs
model.pl_module_params["lr_scheduler_cls"] = model.model.lr_scheduler_cls
model.pl_module_params["lr_scheduler_kwargs"] = model.model.lr_scheduler_kwargs

# save the initialized TorchForecastingModel as PyTorch-Lightning only saves module checkpoints
# to allow finetuning of a fine-tuned model...
model.save(
os.path.join(
_get_runs_folder(model.work_dir, model.model_name), INIT_MODEL_NAME
)
)

new_trainer = model._init_trainer(
model.trainer_params, model.trainer_params["max_epochs"]
)
model.trainer = new_trainer
model.model.trainer = new_trainer
model.trainer.strategy.setup_optimizers(new_trainer)

model.model.setup("fit")
return model

def _train(
self, train_loader: DataLoader, val_loader: Optional[DataLoader]
) -> None:
159 changes: 159 additions & 0 deletions darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -290,6 +290,165 @@ def test_train_from_10_n_epochs_20_fit_15_epochs(self):
model1.fit(self.series, epochs=15)
self.assertEqual(15, model1.epochs_trained)

def test_setup_finetuning(self):
original_model_name = "original"
fintuned_model_name = "fintuned"
# original model, checkpoints are saved
model1 = RNNModel(
12,
"RNN",
10,
10,
n_epochs=2,
work_dir=self.temp_work_dir,
save_checkpoints=True,
model_name=original_model_name,
)
model1.fit(self.series)
self.assertEqual(2, model1.epochs_trained)

# load last checkpoint of original model, train it for 2 additional epochs
model_ft = RNNModel.setup_finetuning(
old_model_name=original_model_name,
new_model_name=fintuned_model_name,
work_dir=self.temp_work_dir,
additional_epochs=2,
)
model_ft.fit(self.series)
self.assertEqual(4, model_ft.epochs_trained)

# load last checkpoint of original model, train it for 4 additional epochs
model_ft = RNNModel.setup_finetuning(
old_model_name=original_model_name,
new_model_name=fintuned_model_name,
work_dir=self.temp_work_dir,
trainer_params={"max_epochs": 6},
force_reset=True,
)
model_ft.fit(self.series)
self.assertEqual(6, model_ft.epochs_trained)

# load last checkpoint of fine-tuned model, train it for 2 additional epochs
model_ft = RNNModel.setup_finetuning(
old_model_name=fintuned_model_name,
new_model_name=fintuned_model_name + "_twice",
work_dir=self.temp_work_dir,
additional_epochs=2,
)
model_ft.fit(self.series)
self.assertEqual(8, model_ft.epochs_trained)

# check saving last ckpt in same folder as original model
model_ft = RNNModel.setup_finetuning(
old_model_name=original_model_name,
new_model_name=original_model_name,
work_dir=self.temp_work_dir,
trainer_params={"max_epochs": 8},
save_inplace=True,
)
model_ft.fit(self.series)
self.assertEqual(8, model_ft.epochs_trained)

# raise Exception when the number of additional epochs is contradictory
with self.assertRaises(ValueError):
model_ft = RNNModel.setup_finetuning(
old_model_name=original_model_name,
new_model_name=fintuned_model_name,
work_dir=self.temp_work_dir,
additional_epochs=2,
trainer_params={"max_epochs": 10},
)

# raise Exception when the max_epochs trainer parameter is too small
with self.assertRaises(ValueError):
model_ft = RNNModel.setup_finetuning(
old_model_name=original_model_name,
new_model_name=fintuned_model_name,
work_dir=self.temp_work_dir,
trainer_params={"max_epochs": 1},
)

# raise Exception when the target checkpoint folder already exists
with self.assertRaises(ValueError):
model_ft = RNNModel.setup_finetuning(
old_model_name=original_model_name,
new_model_name=original_model_name,
work_dir=self.temp_work_dir,
additional_epochs=2,
)

# raise Exception when trying to save ckpt in place and force_reset simultaneously
with self.assertRaises(ValueError):
model_ft = RNNModel.setup_finetuning(
old_model_name=original_model_name,
new_model_name=original_model_name,
work_dir=self.temp_work_dir,
additional_epochs=2,
save_inplace=True,
force_reset=True,
)

def test_setup_finetuning_optimizer(self):
original_model_name = "original"
fintuned_model_name = "fintuned"
# original model, Adam optimizer
model1 = RNNModel(
12,
"RNN",
10,
10,
n_epochs=2,
work_dir=self.temp_work_dir,
save_checkpoints=True,
model_name=original_model_name,
)
model1.fit(self.series)
self.assertEqual(2, model1.epochs_trained)

# load last checkpoint of original model, change optimizer from Adam to RAdam
model_ft = RNNModel.setup_finetuning(
old_model_name=original_model_name,
new_model_name=fintuned_model_name,
work_dir=self.temp_work_dir,
additional_epochs=2,
optimizer_cls=torch.optim.RAdam,
optimizer_kwargs={"lr": 0.0001},
)
model_ft.fit(self.series, trainer=model_ft.trainer)
self.assertEqual(4, model_ft.epochs_trained)
self.assertEqual(type(model_ft.trainer.optimizers[0]), torch.optim.RAdam)

def test_setup_finetuning_scheduler(self):
original_model_name = "original"
fintuned_model_name = "fintuned"
# original model, without scheduler
model1 = RNNModel(
12,
"RNN",
10,
10,
n_epochs=2,
work_dir=self.temp_work_dir,
save_checkpoints=True,
model_name=original_model_name,
)
model1.fit(self.series)
self.assertEqual(2, model1.epochs_trained)

# load last checkpoint of original model, add a scheduler
model_ft = RNNModel.setup_finetuning(
old_model_name=original_model_name,
new_model_name=fintuned_model_name,
work_dir=self.temp_work_dir,
additional_epochs=2,
lr_scheduler_cls=torch.optim.lr_scheduler.StepLR,
lr_scheduler_kwargs={"step_size": 10},
)
model_ft.fit(self.series)
self.assertEqual(4, model_ft.epochs_trained)
# cannot check class, use length of config as a proxy
self.assertEqual(len(model_ft.trainer.lr_scheduler_configs), 1)

def test_optimizers(self):

optimizers = [