
[BUG] TFTModel.load_from_checkpoint and .fit() is returning an error. #1090

Closed
criscapdechoy opened this issue Jul 21, 2022 · 6 comments
Labels
question Further information is requested

Comments

@criscapdechoy

Describe the bug
First, we train a TFTModel for 30 epochs. Then we aim to do transfer learning by re-training this model after loading it from its last checkpoint. When we execute .fit(..., epochs=additional_n_epochs), the following error occurs:

File "<string>", line 1, in <module>
File ".../python3.9/site-packages/darts/utils/torch.py", line 70, in decorator
  return decorated(self, *args, **kwargs)
File ".../python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 771, in fit
  return self.fit_from_dataset(
File ".../python3.9/site-packages/darts/utils/torch.py", line 70, in decorator
  return decorated(self, *args, **kwargs)
File ".../python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 930, in fit_from_dataset
  self._train(train_loader, val_loader)
File ".../python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 952, in _train
  self.trainer.fit(
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
  self._call_and_handle_interrupt(
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
  return trainer_fn(*args, **kwargs)
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
  results = self._run(model, ckpt_path=self.ckpt_path)
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1232, in _run
  self._checkpoint_connector.restore_training_state()
File ".../python3.9/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 199, in restore_training_state
  self.restore_loops()
File ".../python3.9/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 293, in restore_loops
  raise MisconfigurationException(
pytorch_lightning.utilities.exceptions.MisconfigurationException: You restored a checkpoint with current_epoch=29, but you have set Trainer(max_epochs=5).

To Reproduce

additional_n_epochs = 5
# load the previously trained model from its last checkpoint (not the best one)
my_model = TFTModel.load_from_checkpoint(mymodelname, work_dir=mymodeldir, best=False)
my_model.fit(..., epochs=additional_n_epochs)

Expected behavior
We aim to get a training process that resumes from the epoch of the last checkpoint and continues until the total number of epochs is my_model.n_epochs + additional_n_epochs.
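In other words, something like the following (series stands in for the training data, which is elided in the call above):

# hypothetical expected semantics: resume at epoch 30 and stop at epoch 35
my_model = TFTModel.load_from_checkpoint(mymodelname, work_dir=mymodeldir, best=False)
my_model.fit(series, epochs=additional_n_epochs)  # expected: epochs 30 -> 34, not 0 -> 4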

System (please complete the following information):

  • Python version: 3.9
  • darts version 0.18.0
@criscapdechoy criscapdechoy added the bug (Something isn't working) and triage (Issue waiting for triaging) labels on Jul 21, 2022
@dennisbader
Collaborator

Hi @criscapdechoy, can you try the following?

previous_n_epochs = 30  # for how many epochs you trained the first time
my_model.fit(..., epochs=previous_n_epochs + additional_n_epochs)

When the model is loaded from the checkpoint, it is at epoch 30 (index 29). You need to tell it to continue training until epoch 30 + additional_n_epochs = 35 (index 34).

@criscapdechoy
Author

Hi @dennisbader! Thank you for your fast reply.

If I use what you propose, the error no longer shows! However, something I was not expecting happens... When I retrain, i.e. after load_from_checkpoint(), the model starts from epoch 0 and runs until epoch 34 (resulting from previous_n_epochs + additional_n_epochs, indeed). But I was expecting the model to start from epoch 30 and run until epoch 34 this second time. Am I wrong?

@dennisbader
Collaborator

Hmm, it might be that we lost automatic support for that with the new PyTorch Lightning versions.
Could you try this manual approach to see if it works?

import os
import pytorch_lightning as pl
from darts.models import TFTModel
from darts.models.forecasting.torch_forecasting_model import _get_checkpoint_folder, _get_checkpoint_fname

epochs = 30
additional_epochs = 5
model_name = mymodelname
work_dir = mymodeldir

ckpt_dir = _get_checkpoint_folder(work_dir, model_name)
file_name = _get_checkpoint_fname(work_dir, model_name, best=False)
ckpt_path = os.path.join(ckpt_dir, file_name)

my_model = TFTModel.load_from_checkpoint(model_name, work_dir=work_dir, best=False)
trainer_params = my_model.trainer_params

# instantiate a PyTorch Lightning trainer and tell it to resume from your last checkpoint
trainer = pl.Trainer(resume_from_checkpoint=ckpt_path, **trainer_params)

my_model.fit(..., epochs=epochs + additional_epochs, trainer=trainer)

@criscapdechoy
Author

Hello again!
I've tried what you propose with a small change and now the retraining works!!! :) After instantiating the trainer from the checkpoint with the trainer_params, I added a line that redefines trainer_params["max_epochs"] to avoid the error: I set trainer_params["max_epochs"] = epochs + additional_epochs and then called my_model.fit(...).
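For reference, the working version looks roughly like this (series is a placeholder for the actual training data; model name and paths are as before):

import os
import pytorch_lightning as pl
from darts.models import TFTModel
from darts.models.forecasting.torch_forecasting_model import _get_checkpoint_folder, _get_checkpoint_fname

epochs = 30
additional_epochs = 5

ckpt_dir = _get_checkpoint_folder(mymodeldir, mymodelname)
file_name = _get_checkpoint_fname(mymodeldir, mymodelname, best=False)
ckpt_path = os.path.join(ckpt_dir, file_name)

my_model = TFTModel.load_from_checkpoint(mymodelname, work_dir=mymodeldir, best=False)

# override max_epochs so the restored trainer accepts the new total epoch count
trainer_params = my_model.trainer_params
trainer_params["max_epochs"] = epochs + additional_epochs

trainer = pl.Trainer(resume_from_checkpoint=ckpt_path, **trainer_params)
my_model.fit(series, epochs=epochs + additional_epochs, trainer=trainer)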

PS: To get the epoch count of the loaded model checkpoint, I have to call:

import torch
...
# the Lightning checkpoint dict stores the last completed epoch under "epoch"
epochs = torch.load(ckpt_path)["epoch"]

However, I am not sure this is the best way to know how many epochs the model has been trained for. Is there another attribute on the trainer containing this info?

@hrzn
Contributor

hrzn commented Aug 23, 2022

@dennisbader , what's your opinion, do you think we should rework the way we handle epochs?

@hrzn hrzn added the question (Further information is requested) label and removed the bug (Something isn't working) and triage (Issue waiting for triaging) labels on Aug 23, 2022
@madtoinou
Collaborator

Closing this, as @dennisbader detailed the reason for keeping the current behavior in #1689: load_weights_from_checkpoint() should be preferred for fine-tuning/retraining, whereas calling fit() after loading a model resumes the interrupted training (and it remains possible to increase the maximum number of epochs by taking the initial number of epochs into account).
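For readers landing here, a rough sketch of the two patterns (series, paths and hyperparameters are placeholders; load_weights_from_checkpoint() was added in a later darts release, so check the API of your version):

from darts.models import TFTModel

# 1) Resume an interrupted training: loading the full model restores the trainer
#    state, so fit() continues from the checkpointed epoch; pass the *total*
#    number of epochs (already trained + additional).
model = TFTModel.load_from_checkpoint(mymodelname, work_dir=mymodeldir, best=False)
model.fit(series, epochs=30 + 5)

# 2) Fine-tuning / transfer learning: create a fresh model with the same
#    architecture and load only the network weights, so training restarts
#    with a clean trainer state.
new_model = TFTModel(input_chunk_length=24, output_chunk_length=12)  # hypothetical hyperparameters
new_model.load_weights_from_checkpoint(model_name=mymodelname, work_dir=mymodeldir, best=False)
new_model.fit(series, epochs=5)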
