[BUG] Loading a model from checkpoint prevents it from saving a .ckpt file correctly. #1561

Closed
JLC827 opened this issue Feb 12, 2023 · 8 comments · Fixed by #1501
Labels: bug, triage

Comments

@JLC827

JLC827 commented Feb 12, 2023

Describe the bug
Hi, I've noticed some unexpected behaviour relating to loading and saving of model checkpoints when using the .load_from_checkpoint method.

This is a bit tricky to explain so the best way is to give examples.

Scenario 1:

  • Train a model
  • Save the model.
  • Load the model and use it to predict.
    This works fine.

Scenario 2:

  • Train a model.
  • Load an earlier checkpoint (for example the epoch with the lowest validation loss).
  • Save the model.
  • Load the model and use it to predict.
    This does not work because using .load_from_checkpoint removes the model.trainer object. When .save is called the model does not save a .ckpt file because it only saves a .ckpt file if the model.trainer object exists.

I am attempting to implement a workaround in my project by saving the checkpoint file manually, but this is not ideal.

To Reproduce

import numpy as np
from darts.datasets import AirPassengersDataset
from darts.models import NBEATSModel

# Load example dataset
series_air = AirPassengersDataset().load().astype(np.float32)

# Train a model and save it without loading a previous checkpoint. This works fine.
# Create a model.
model_name = "nbeats_example_1"
model = NBEATSModel(input_chunk_length=24, output_chunk_length=12, model_name=model_name, save_checkpoints=True,
                    force_reset=True)
# Train the model. Yes I'm validating on the same data as the training, but it's not important for this example.
model.fit(series=series_air, val_series=series_air, epochs=5, verbose=True)
# Save the model.
model.save("nbeats_example_1.pt")
model_loaded = NBEATSModel.load("nbeats_example_1.pt")
predictions = model_loaded.predict(series=series_air, n=12)
print(predictions)

# Train a model, load an earlier checkpoint, and then save it. This does not work.
# Create a model.
model_name = "nbeats_example_2"
model = NBEATSModel(input_chunk_length=24, output_chunk_length=12, model_name=model_name, save_checkpoints=True,
                    force_reset=True)
# Train the model. Yes I'm validating on the same data as the training, but it's not important for this example.
model.fit(series=series_air, val_series=series_air, epochs=5, verbose=True)
# Load the best checkpoint. This removes the trainer from the model and prevents it from saving a .ckpt file.
model = NBEATSModel.load_from_checkpoint(model_name, best=True)
# Save the model.
model.save("nbeats_example_2.pt")
model_loaded = NBEATSModel.load("nbeats_example_2.pt")
# This line fails with the exception AttributeError: 'NoneType' object has no attribute 'set_predict_parameters'.
predictions = model_loaded.predict(series=series_air, n=12)
print(predictions)

Expected behavior
The model should save a .ckpt file regardless of whether it has been loaded from a checkpoint or not.

System (please complete the following information):

  • Python version: 3.10
  • darts version 0.23.1
  • Windows 10.0.19045

Additional context
Unfortunately I don't know enough about the internal workings of PyTorch Lightning to be able to suggest a complete solution for this.

Line 1275 of torch_forecasting_model.py in the darts library is relevant:

        if self.trainer is not None:
            path_ptl_ckpt = path + ".ckpt"
            self.trainer.save_checkpoint(path_ptl_ckpt)

It relies on using the trainer object to save the checkpoint. So if the trainer does not exist then it cannot save.
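
Until this is resolved, one defensive option is to check after saving that both files actually exist. The following is only a minimal sketch and not part of darts: save_and_verify is a hypothetical helper, and it assumes nothing beyond what the snippet above shows, namely that the checkpoint is written to path + ".ckpt".

```python
import os


def save_and_verify(model, path):
    """Call model.save(path), then fail loudly if the companion .ckpt is missing.

    Hypothetical helper: `model` is any object exposing a darts-style save(path).
    """
    model.save(path)
    ckpt_path = path + ".ckpt"  # same naming as in the darts snippet above
    if not os.path.isfile(ckpt_path):
        raise RuntimeError(
            f"{path} was saved without {ckpt_path}; "
            "the model probably has no trainer attached."
        )
    return path, ckpt_path
```

Calling save_and_verify(model, "nbeats_example_2.pt") in scenario 2 would then raise at save time instead of failing later in predict.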

I don't know if it makes sense to preserve the trainer on the model after calling load_from_checkpoint, as that trainer object then relates to a later epoch than the parameters actually loaded into the model. I guess it would make sense to just try to save a .ckpt file from the model without calling the trainer.

Hope this is helpful!

@solalatus
Contributor

I am not completely sure, but my initial guess is that this PR will have a positive impact here.

What do you think @madtoinou ?

@madtoinou
Collaborator

madtoinou commented Feb 13, 2023

Hi,

Thank you @solalatus for linking this issue.

The root of the issue is the "intermediary" model save, which I am not sure I understand. Why would you want to perform such an operation, since the "best" checkpoint is already exported by the PyTorch Lightning trainer and you're already loading it (nothing happens between the load_from_checkpoint and the load in scenario 2)?

Regardless of the motivation of such a step, the PR #1501 will introduce two methods, load_weights and load_weights_from_checkpoint to more easily load weights, potentially making the intermediary save unnecessary.

@JLC827
Author

JLC827 commented Feb 14, 2023

Hi, thank you both for responding so promptly.

@madtoinou
Nothing happens between load_from_checkpoint and load in this scenario because I created a simplified example. Sorry, I should have added a note explaining this. In practice, my python process is exiting after training, and I am importing the models into a separate python process for later steps. This is why I am saving the models rather than keeping them in memory. Hopefully the snippet below makes it clearer.

model = NBEATSModel.load_from_checkpoint(model_name, best=True)
# Save the model.
model.save("nbeats_example_2.pt")
exit()

# Return later and do stuff with the model.
model_loaded = NBEATSModel.load("nbeats_example_2.pt")

The motivation for using load_from_checkpoint is to load the model with the weights that achieved the highest validation score before saving it. This is similar to what is being done in this Optuna example in the docs. As far as I'm aware this is the best way to manually save and use the 'best' model, please correct me if I'm wrong! I guess an alternative would be to manually copy the autosaved 'best' .ckpt file to a new location, but that seems more complicated.

I looked into this a bit more and the model saves fine if predict is called before save is called, because predict creates a new trainer object.

# This saves a model and ckpt, because predict creates a new trainer.
model = NBEATSModel.load_from_checkpoint(model_name, best=True)
model.predict(n=1, series=series_air)
model.save("nbeats_example_2.pt")
# This saves a model only, because no trainer exists when save is called.
model = NBEATSModel.load_from_checkpoint(model_name, best=True)
model.save("nbeats_example_2.pt")

While I can indeed use the methods from #1501, it really seems like there's some odd behaviour going on here that would be good to look into. I'm wondering if it's even intended for a model to be able to save without creating a corresponding .ckpt file, as the .pt file by itself doesn't seem to be usable for anything (although there may be use cases I am not aware of). The documentation for save even states that two files are created. It seems like the solutions would be to:

  • Prevent load_from_checkpoint from creating a model without a trainer attribute
  • (and optionally) Prevent save from only saving a .pt file

or

  • Allow predictions to be made using only a .pt file.

Thanks for looking into this!

@JLC827
Author

JLC827 commented Feb 14, 2023

For now a simple fix is just to save the trainer attribute and add it back after calling load_from_checkpoint.

I'm not sure if this means there'll be a conflict between values in the trainer object and on other parts of the model, but if there is I don't think it's important for me because I'm not retraining, just performing inference.

trainer = model.trainer
model = NBEATSModel.load_from_checkpoint(model_name, best=True)
model.trainer = trainer

@madtoinou
Collaborator

Sorry if my question was not clear enough, but the point I was trying to make is that you could directly use load_from_checkpoint to reload your model in your second process, especially if you want the one carrying the "best" prefix, without having to call save().

The saving system (which also encompasses the checkpoint system) indeed generates two files:

  • a .pt file containing the TorchForecastingModel, which is a kind of empty shell containing the attributes not affected by training.
  • a .ckpt file containing a PLForecastingModule, which directly inherits from pl.LightningModule. This module contains all the weights as well as the state of the optimizer and learning rate scheduler (typically used to resume training).

This separation was implemented to avoid saving the TorchForecastingModel at each checkpoint, since it does not change during the training.

Saving a model without a trainer is pointless because it would indicate that the model was not trained and contains either the initialization weights or weights from an existing checkpoint (which could be loaded directly if necessary). You could probably "trick" darts by creating a new trainer using the predict method (as the library does not distinguish a "training" trainer from a "prediction" trainer), but the weights of the model would remain the same, and it would be simpler to just load the "best" checkpoint that you're trying to duplicate (by calling save()).

@JLC827
Author

JLC827 commented Feb 15, 2023

My use case

I think I understand what you mean now. I am calling save() as I am saving the model to a separate folder, rather than relying on the autosave created by Darts. The idea is that I can optimise the hyperparameters, and save the model manually if the performance of those current hyperparameters exceeds an earlier model with different hyperparameters. As each run of hyperparameter optimisation overwrites the 'best' model autosave I need to save it manually.

I'm not sure if this is actually the best way to do it though, as it doesn't really seem to fit into the Darts workflow. So I might change this, especially as I'm considering only performing hyperparameter optimisation on a subset of the training data. So I'll probably export the best hyperparameters and train up my main model separately with a different model name. Then as you say I don't have to call save, and I can just reload the model from the autosaved checkpoint.
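
Exporting the best hyperparameters could be as simple as writing them to a JSON file. This is a rough sketch under my own assumptions: export_hyperparameters and load_hyperparameters are hypothetical helper names, and it assumes the hyperparameters are plain JSON-serialisable values such as the input_chunk_length and output_chunk_length used in the example above.

```python
import json


def export_hyperparameters(params, path):
    """Write a dict of hyperparameters to a JSON file (hypothetical helper)."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(params, f, indent=2, sort_keys=True)


def load_hyperparameters(path):
    """Read the hyperparameters back as a dict (hypothetical helper)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
```

The main model could then be created with something like NBEATSModel(model_name="nbeats_main", save_checkpoints=True, **load_hyperparameters("best_params.json")), where both the model name and the file name are placeholders.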

Saving system

Thank you for clarifying the usage of the files generated by the saving system. I was kind of confused there. The source of my confusion was that vanilla PyTorch saves weights in a .pt file, whereas PyTorch Lightning saves weights in a .ckpt file.

I understand that you can load from the checkpoint that already exists, but it seems like save should create all the files that are necessary to reload the model in the save directory, and that reloading the model in full should not have to rely on files that are stored somewhere else. If this would conflict with existing internal Darts functionality it seems like there should be two different methods created. Or I would like to see a mention in the documentation that this kind of thing can occur, maybe a recommendation to use the autosaves if that is the recommended workflow.

But of course it is up to you. This issue is not a blocker for my project now as this discussion has identified several workarounds. Thank you for the great help and clarifications!

@madtoinou
Collaborator

madtoinou commented Feb 15, 2023

Thank you for detailing your use case, everything makes perfect sense to me now :) I agree with you that for such a scenario, one would expect save() to create all the files necessary to load the model later, even if the model did not undergo any kind of training.

One approach could be to try to retrieve the "automatic/PL" checkpoint (should be doable using the work_dir and model_name attributes of the model) if the model's trainer is None and copy it to the target path.
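
As a rough illustration of this idea (not the eventual darts implementation), the copy step could look like the sketch below. copy_best_checkpoint is a hypothetical helper, and the directory layout <work_dir>/<model_name>/checkpoints/ with files named best-*.ckpt is an assumption about how the automatic checkpoints are stored; the target name path + ".ckpt" follows the save() snippet quoted earlier in this issue.

```python
import glob
import os
import shutil


def copy_best_checkpoint(work_dir, model_name, target_path):
    """Copy the newest 'best-*.ckpt' of a model next to a manually saved .pt file.

    Assumed layout (not confirmed in this thread):
    <work_dir>/<model_name>/checkpoints/best-*.ckpt
    """
    ckpt_dir = os.path.join(work_dir, model_name, "checkpoints")
    # Escape the directory part so special characters are not treated as wildcards.
    pattern = os.path.join(glob.escape(ckpt_dir), "best-*.ckpt")
    candidates = glob.glob(pattern)
    if not candidates:
        raise FileNotFoundError(f"no checkpoint matches {pattern}")
    # If several "best" files exist, take the most recently written one.
    newest = max(candidates, key=os.path.getmtime)
    destination = target_path + ".ckpt"
    shutil.copyfile(newest, destination)
    return destination
```

In save(), this would only be needed when the model's trainer is None, with work_dir and model_name taken from the model's own attributes.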

And now that I think about it, we should probably also prevent users from saving a model using the "prediction trainer"...

WDYT @dennisbader ?

@dennisbader
Collaborator

dennisbader commented Feb 19, 2023

As per our discussion, @madtoinou:

  • Retrieving the automatic checkpoint and saving it along with the manual save files is probably the best solution for now. If PyTorch Lightning allowed saving the model without having called fit/predict/validate/test, we could adapt this. I asked them here.
  • Saving models from the prediction trainer should not be an issue.
