
Adding a Metric to LightningModule prevents loading of a checkpoint / weights #4361

Closed
NumesSanguis opened this issue Oct 26, 2020 · 5 comments · Fixed by #4482
Labels: bug (Something isn't working) · checkpointing (Related to checkpointing) · help wanted (Open to be worked on)

NumesSanguis (Contributor) commented Oct 26, 2020

🐛 Bug

Adding a Metric like Accuracy prevents the loading of a .ckpt due to missing keys:

RuntimeError: Error(s) in loading state_dict for BoringModel2:
	Missing key(s) in state_dict: "pl_accuracy.correct", "pl_accuracy.total". 

Please reproduce using the BoringModel and post here

https://colab.research.google.com/drive/1km0SE2TVRuif6R4uF8vihqUl-FshXu_u?usp=sharing

To Reproduce

from pytorch_lightning.metrics import Accuracy  # metric location in PL 1.0.x

class BoringModel2(BoringModel):  # BoringModel as defined in the linked Colab
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # NEW: adding a metric as a module attribute
        self.pl_accuracy = Accuracy()

model2 = BoringModel2.load_from_checkpoint(checkpoint_path="example.ckpt")
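
For completeness, the example.ckpt used above can be produced roughly as follows (a sketch assuming the BoringModel and Trainer setup from the linked Colab; the exact arguments there may differ):

from pytorch_lightning import Trainer

# Train the plain BoringModel briefly and save a checkpoint without the metric.
model = BoringModel()
trainer = Trainer(max_epochs=1, limit_train_batches=8, limit_val_batches=8)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")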

Expected behavior

It should be possible to add new metrics to a model without changing its layers (e.g. in a transfer-learning setting) and still load the weights of a model saved without those metrics.

Actual behavior

RuntimeError                              Traceback (most recent call last)

<ipython-input-15-50d6f204e428> in <module>()
----> 1 model2 = BoringModel2.load_from_checkpoint(checkpoint_path="example.ckpt")  # , strict=False

2 frames

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1043         if len(error_msgs) > 0:
   1044             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1045                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1046         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1047 

RuntimeError: Error(s) in loading state_dict for BoringModel2:
	Missing key(s) in state_dict: "pl_accuracy.correct", "pl_accuracy.total". 

Environment

Note: Bugs with code are solved faster! Colab Notebook should be made public!

You can get the script and run it with:

wget https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/tests/collect_env_details.py
# For security purposes, please check the contents of collect_env_details.py before running it.
python collect_env_details.py

Colab:

  • CUDA:
    • GPU:
      • Tesla T4
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.18.5
    • pyTorch_debug: False
    • pyTorch_version: 1.6.0+cu101
    • pytorch-lightning: 1.0.3
    • tqdm: 4.41.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.6.9
    • version: #1 SMP Thu Jul 23 08:00:38 PDT 2020

Additional context

NumesSanguis added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Oct 26, 2020
NumesSanguis (Contributor, Author) commented Oct 26, 2020

Adding strict=False to:

model2 = BoringModel2.load_from_checkpoint(checkpoint_path="example.ckpt", strict=False)

does allow the weights to be loaded (thank you @ananthsub), but it increases the risk that improper keys in your .ckpt files go unnoticed, where strict=True would have caught them.
This is not something you want in bigger projects, where you want to be sure that the keys match.

Since the Accuracy metric uses default values, I have no idea what "pl_accuracy.correct" or "pl_accuracy.total" should be initialized with, and even if I did, needing to provide these values to load_from_checkpoint() is probably not a good approach.

NumesSanguis (Contributor, Author) commented Oct 26, 2020

Using BoringModel as a submodule is also not desirable, as you cannot simply do:

from pytorch_lightning import LightningModule

class BoringModelTransfer(LightningModule):

    def __init__(self):
        super().__init__()
        self.boring = BoringModel()

    def forward(self, x):
        return self.boring(x)

You would need to copy all the other functions, such as training_step(), as well. Lightning should take engineering away, not add to it (which this approach would).

Also, hparams would end up in self.boring.hparams instead of self.hparams, which creates new issues.

Shreeyak commented:

The best approach might be to modify the old checkpoints to make them compatible with a common format.
Assuming the keys epoch and total_iter_num don't exist, here is sample code that adds those keys to the checkpoint (with a default value of 0):

import torch

checkpoint = torch.load(path_checkpoint, map_location='cpu')
checkpoint['epoch'] = 0
checkpoint['total_iter_num'] = 0

torch.save(checkpoint, path_new_checkpoint)
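
Applied to this issue, the same idea would be to add the missing metric keys to the Lightning checkpoint's nested "state_dict" (a sketch only; the zero-tensor defaults follow from the explanation in the next comment, and example_patched.ckpt is just an illustrative path):

import torch

# Sketch: add the missing metric states so strict loading succeeds.
checkpoint = torch.load("example.ckpt", map_location="cpu")
# Lightning stores the model weights under the "state_dict" key.
checkpoint["state_dict"]["pl_accuracy.correct"] = torch.tensor(0)
checkpoint["state_dict"]["pl_accuracy.total"] = torch.tensor(0)
torch.save(checkpoint, "example_patched.ckpt")

model2 = BoringModel2.load_from_checkpoint(checkpoint_path="example_patched.ckpt")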

SkafteNicki (Member) commented:

Metrics are plain nn.Modules and therefore also have their own state_dict. In the case of the Accuracy metric, the correct and total tensors refer to the number of correctly labeled data points seen and the total number of data points seen. When initialized, both of these numbers are 0. Having them in the state_dict allows, for example, the metric to be computed on some fraction of the dataset, saved, loaded, and computed on the remaining part, and still give the right result.
Therefore, Metrics should be treated like any other nn.Module, which reduces the options to either loading with strict=False or adding the missing state_dict keys to the checkpoint (the default values for any metric can be accessed through metric._defaults).
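
For instance, those defaults can be inspected directly (a quick sketch; _defaults is a private attribute, and the printed values are what one would expect for Accuracy rather than verified output):

from pytorch_lightning.metrics import Accuracy

# Inspect the default state values registered via add_state().
print(Accuracy()._defaults)
# expected along the lines of: {'correct': tensor(0), 'total': tensor(0)}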

teddykoker (Contributor) commented:

As @SkafteNicki said, this comes down more to a limitation of torch.load: if you add new nn.Modules as attributes to a model, you will not be able to load an older checkpoint without specifying strict=False.

If you do not want this behavior when implementing your own metric, you can simply pass persistent=False when you call add_state (this only works in PyTorch 1.7+). See the sketch below.

Do people think every metric should have a persistent flag that would apply this behavior to every state variable? That would be possible and would solve your issue.
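
For reference, here is a minimal sketch of that route: a custom metric whose states are registered with persistent=False so they are left out of the state_dict (the metric below is illustrative, not the built-in Accuracy, and persistent=False requires PyTorch 1.7+ as noted above):

import torch
from pytorch_lightning.metrics import Metric

class NonPersistentAccuracy(Metric):
    def __init__(self):
        super().__init__()
        # persistent=False keeps these states out of the state_dict,
        # so older checkpoints without them still load with strict=True.
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum", persistent=False)
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum", persistent=False)

    def update(self, preds, target):
        preds = torch.argmax(preds, dim=-1)
        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        return self.correct.float() / self.total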
