
Model load_from_checkpoint #525

Closed
ricardorei opened this issue Nov 19, 2019 · 29 comments
Labels
bug Something isn't working

Comments

@ricardorei

Describe the bug
When loading a model directly from a checkpoint I get an error: "OSError: Checkpoint does not contain hyperparameters. Are your model hyperparameters stored in self.hparams?"

But my model clearly sets self.hparams.

To Reproduce
Create a model, save a checkpoint, and try to load it as explained in the documentation:

pretrained_model = MyLightningModule.load_from_checkpoint(
    checkpoint_path='/path/to/pytorch_checkpoint.ckpt'
)

Possible reason
I found this code in trainer_io.py around line 301:

try:
    torch.save(checkpoint, filepath)
except AttributeError:
    if 'hparams' in checkpoint:
        del checkpoint['hparams']
    torch.save(checkpoint, filepath)

Obviously, if the code that saves the checkpoint deletes the hparams, the load function will not find them later...

Expected behavior
A concise way to load a checkpoint directly, without needing the load_from_metrics function.

@ricardorei ricardorei added the bug Something isn't working label Nov 19, 2019
@neggert
Contributor

neggert commented Nov 19, 2019

IIRC, that was a hack to work around an edge case where the hparams weren't pickleable. Seems like the original ticket #433 is still open. @williamFalcon do we still need this hack?

@ricardorei
Author

Maybe the problem is saving an object that holds a lambda function. I see this line in the issue log:

File "/private/home/falc/.local/lib/python3.7/site-packages/torch/serialization.py", line 224, in save
return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))

Pickle doesn't allow lambda functions to be saved, but if that's the reason, it should be an easy fix.

https://stackoverflow.com/questions/25348532/can-python-pickle-lambda-functions
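A minimal standalone sketch of that failure mode (plain Python, not Lightning code): pickle refuses a lambda but accepts a module-level named function.

import pickle

square = lambda x: x ** 2

try:
    pickle.dumps(square)
except Exception as err:
    # e.g. PicklingError: Can't pickle <function <lambda> at 0x...>
    print(f"lambda is not picklable: {err}")

def square_fn(x):
    return x ** 2

# a plain module-level function serializes without issue
pickle.dumps(square_fn)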

@wakandan
Contributor

wakandan commented Dec 6, 2019

I'm also having the same issue, but I'm not using any lambdas.

@Borda
Member

Borda commented Dec 6, 2019

In general a lambda function is not serializable, so all such items should be removed before saving.
@wakandan @ricardorei @neggert interested in sending a PR?

@antvconst
Contributor

antvconst commented Dec 6, 2019

I've encountered the same problem. Looks like the culprit is this line: https://github.com/williamFalcon/pytorch-lightning/blob/6666ca5af39aa2d3e5a483da3d7f6bb76514cc9f/pytorch_lightning/trainer/trainer_io.py#L321
After a bit of debugging I've figured out that the return of vars actually contains all of the bound methods of hparams! This seems to produce a pickling error. In my case it goes like this:

AttributeError: Can't pickle local object 'ArgumentParser.__init__.<locals>.identity'

This exception in turn gets handled in https://github.com/williamFalcon/pytorch-lightning/blob/6666ca5af39aa2d3e5a483da3d7f6bb76514cc9f/pytorch_lightning/trainer/trainer_io.py#L264
by removing the hparams altogether.

Do we need this vars call at all? The TTNamespace that is normally used here is perfectly picklable on its own.
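A rough sketch of the change being discussed here, assuming the checkpoint dict is built roughly like this (illustrative only, not the actual PL source):

# before: vars() on a TTNamespace can pull in bound methods and other
# attributes that pickle cannot handle
# checkpoint['hparams'] = vars(model.hparams)

# proposed: store the Namespace / TTNamespace object directly; it pickles fine
checkpoint['hparams'] = model.hparams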

@neggert
Contributor

neggert commented Dec 6, 2019

Ahh, good catch. So this works with an argparse.Namespace, but will fail with a TTNamespace? Want to send a PR that changes it to just pickle the Namespace directly and removes the hacky exception handling?

@expectopatronum
Contributor

expectopatronum commented Jan 9, 2020

I just encountered the same issue - does this mean I can't load the models I trained in the past couple of days or is there some workaround until this is fixed?

If you still have access to hparams, here is a quick fix for load_from_checkpoint (I am not suggesting changing the method, just in case someone needs this functionality before it is fixed).

@classmethod
def load_from_checkpoint(cls, checkpoint_path, hparams, map_location=None):
    """
    Primary way of loading a model from a checkpoint.
    :param checkpoint_path: path to the checkpoint file
    :param hparams: hyperparameters used to instantiate the model
    :param map_location: dict for mapping storage, e.g. {'cuda:1': 'cuda:0'}
    :return: model with the checkpoint weights loaded
    """

    if map_location is not None:
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
    else:
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

    # the original hparams lookup in the checkpoint is skipped here, because the
    # checkpoint may not contain them (see the bug above); hparams are passed in
    # explicitly instead

    # load the state_dict on the model automatically
    model = cls(hparams)
    model.load_state_dict(checkpoint['state_dict'])

    # give model a chance to load something
    model.on_load_checkpoint(checkpoint)

    return model

This is how to use it:

from model_utils.model_definitions.my_classifier import MyCoolModule
from argparse import Namespace

checkpoint_path = '/home/verena/.../checkpoints/_ckpt_epoch_18.ckpt'

hparams = {
    "batch_size": 32,
    ...
}
namespace = Namespace(**hparams)

model = MyCoolModule.load_from_checkpoint(checkpoint_path=checkpoint_path, hparams=namespace)

@esadr

esadr commented Jan 9, 2020

I faced the same issue. Thanks @expectopatronum for the workaround, it helped me a lot.

@neggert
Contributor

neggert commented Jan 9, 2020

Here's a solution that doesn't require modifying your model (from #599).

model = MyModel(whatever, args, you, want)
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])

@Ir1d
Contributor

Ir1d commented Feb 20, 2020

Hi guys, it seems that in my case the load_from_checkpoint function didn't load the params for me.
I had to use the code posted above by @neggert instead (manually calling load_state_dict).
Hope it helps.

@williamFalcon
Contributor

solved in 0.7.1
if not, we can reopen

@pertschuk
Contributor

@williamFalcon
Based on the documentation I found (this seemed most relevant), I'm not totally sure about the best practice for how saving / loading is supposed to work:

I initialized a trainer as follows (I disabled TensorBoard logging because it was erroring due to the TF dependency):

trainer = pl.Trainer(gpus=1, val_check_interval=0.25, use_amp=True, logger=False)

I then observed that the models were saved in ./checkpoints/ by default and assumed that when I restarted training with the same dir it would load the weights. But that did not seem to be the case; instead I got this message:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:82: UserWarning: Checkpoint directory ~/answerbot/accuracy/checkpoints exists and is not empty with save_top_k != 0.All files in this directory will be deleted when a checkpoint is saved!

What is the best practice for loading / saving a model with no logger? Is this possible? Thank you in advance.
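A hedged sketch (0.7-era API; argument names may differ between versions, and the paths are hypothetical) of pointing checkpoints at an explicit directory when no logger is used, and resuming from a specific file:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# save checkpoints to a fixed directory instead of the logger-derived default
checkpoint_callback = ModelCheckpoint(filepath='./my_checkpoints/', save_top_k=1)

trainer = pl.Trainer(
    gpus=1,
    logger=False,
    checkpoint_callback=checkpoint_callback,
    # resume a full training session from a previously saved checkpoint
    resume_from_checkpoint='./my_checkpoints/_ckpt_epoch_3.ckpt',
)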

@Ir1d
Contributor

Ir1d commented Mar 8, 2020

@pertschuk As I remember, auto loading is disabled in the latest master. Can you check your lightning version?

@pertschuk
Contributor

@Ir1d 0.7.1

@Ir1d
Contributor

Ir1d commented Mar 8, 2020

@pertschuk I believe that auto restoring has been removed, and you should load the weights on your own. The doc you linked has not been updated yet.

@pertschuk
Contributor

pertschuk commented Mar 8, 2020

@Ir1d Is there a callback or function to override to integrate weight loading / saving with PL checkpointing?

For example I'm training a huggingface/transformers model and want to save checkpoints in that format.

@Ir1d
Contributor

Ir1d commented Mar 8, 2020

@pertschuk sorry, I don't know the transformers model. A PL checkpoint wraps a lot of things; you can see this by calling dict.keys(), and you'll find that model.load_state_dict(dict['state_dict']) is exactly the weight loading you would do in pure PyTorch. Hope this helps. Also, the above script by @neggert works perfectly for me.
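A small sketch of what this describes, assuming model is an already constructed instance of your LightningModule and the checkpoint path is hypothetical (the exact keys can vary between Lightning versions):

import torch

# a Lightning checkpoint is a plain dict wrapping the weights plus trainer state
checkpoint = torch.load('some/path/to/my_checkpoint.ckpt', map_location='cpu')
print(checkpoint.keys())  # e.g. dict_keys(['epoch', 'global_step', 'state_dict', ...])

# the weights load like any pure PyTorch state dict
model.load_state_dict(checkpoint['state_dict'])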

@ricardorei
Author

ricardorei commented Mar 9, 2020

@Ir1d Is there a callback or function to override to integrate weight loading / saving with PL checkpointing?

For example I'm training a huggingface/transformers model and want to save checkpoints in that format.

I also have been using lightning with pytorch transformers. I save checkpoints normally without changing anything in lightning.

If for some reason I need to resume training from a given checkpoint I just use the resume_from_checkpoint Trainer attribute.

If I just want to load weights from a pretrained model I use the load_weights flag and call the function load_weights_from_checkpoint that is implemented in my "base" model.

parser = HyperOptArgumentParser(strategy="random_search", add_help=False)
parser.add_argument(
    "--resume_from_checkpoint",
    default=None,
    type=str,
    help=(
        "To resume training from a specific checkpoint pass in the path here."
        "(e.g. 'some/path/to/my_checkpoint.ckpt')"
    ),
)
parser.add_argument(
    "--load_weights",
    default=None,
    type=str,
    help=(
        "Loads the model weights from a given checkpoint while resume_from_checkpoint "
        "resumes the entire training session (model/optimizer/scheduler etc..). "
        "If architectures are different this will load only the common module parts."
    ),
)

.....


trainer = Trainer(
        logger=setup_testube_logger(),
        checkpoint_callback=True,
        early_stop_callback=early_stop_callback,
        default_save_path="experiments/",
        gradient_clip_val=hparams.gradient_clip_val,
        gpus=hparams.gpus,
        show_progress_bar=False,
        overfit_pct=hparams.overfit_pct,
        check_val_every_n_epoch=hparams.check_val_every_n_epoch,
        fast_dev_run=False,
        accumulate_grad_batches=hparams.accumulate_grad_batches,
        max_epochs=hparams.max_epochs,
        min_epochs=hparams.min_epochs,
        train_percent_check=hparams.train_percent_check,
        val_percent_check=hparams.val_percent_check,
        val_check_interval=hparams.val_check_interval,
        log_save_interval=hparams.log_save_interval,
        row_log_interval=hparams.row_log_interval,
        distributed_backend=hparams.distributed_backend,
        precision=hparams.precision,
        weights_summary=hparams.weights_summary,
        resume_from_checkpoint=hparams.resume_from_checkpoint,
        profiler=hparams.profiler,
        log_gpu_memory="all",
    )

model = build_model(hparams)
if hparams.load_weights:
    model.load_weights_from_checkpoint(hparams.load_weights)

log.info(f"{model.__class__.__name__} train starting:")
trainer.fit(model)

My load_weights_from_checkpoint function:

def load_weights_from_checkpoint(self, checkpoint: str) -> None:
    """ Function that loads the weights from a given checkpoint file.
    Note:
        If the checkpoint model architecture is different from `self`, only
        the common parts will be loaded.

    :param checkpoint: Path to the checkpoint containing the weights to be loaded.
    """
    log.info(f"loading model weights from {checkpoint}.")
    checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    pretrained_dict = checkpoint["state_dict"]
    model_dict = self.state_dict()

    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the merged state dict (loading pretrained_dict directly would fail
    #    with strict=True when the architectures differ)
    self.load_state_dict(model_dict)

Does this solve your problem of loading pre-trained weights and resuming training sessions?

@ricardorei
Author

PS: The initial purpose of this issue was solved some versions ago and it's now working.

@pertschuk
Contributor

@ricardorei yes it does, thank you

@pertschuk
Contributor

pertschuk commented Mar 10, 2020

EDIT: this seems to be an apex/amp fp16 precision bug

Okay, sorry to keep posting here, but I have run into a VERY confusing issue and would appreciate any ideas or guidance @ricardorei. I am trying to export the models to save in huggingface/transformers format for reuse, and the saved model appears to have a state_dict identical to the model wrapped in the PyTorch Lightning module, but the results of passing the same inputs through are not the same?

import os
import torch
from transformers import AlbertForSequenceClassification

os.makedirs('./test-1', exist_ok=True)
# model is a PyTorch Lightning module and model.model is the Transformers model
model.model.save_pretrained('./test-1')
loaded_model = AlbertForSequenceClassification.from_pretrained('./test-1')
loaded_model.cuda()

for k, v in loaded_model.state_dict().items():
    assert torch.all(model.model.state_dict()[k].eq(v)) # this assert works
    
correct = 0
total = 0

def call_model(inputs, model):
    return model(inputs['input_ids'].cuda(), 
                 token_type_ids=inputs['token_type_ids'].cuda(),
                 attention_mask=inputs['attention_mask'].cuda())[0]

for ex in get_data():
    label = 0 if ex['is_impossible'] else 1
    inputs = tokenizer.encode_plus(ex['question'],  
                                   ex['context'],
                                   add_special_tokens=True,
                                   max_length=256,
                                   return_tensors='pt')
    
    lightning_logits = call_model(inputs, model.model)
    transformers_logits = call_model(inputs, loaded_model)
    
    assert torch.all(lightning_logits.eq(transformers_logits)) # this assert fails ???

Note: I also tried saving / loading the state_dict for the PyTorch Lightning module itself and got the same problem: the state dicts match up but the outputs differ during inference. I'm totally lost.

@ricardorei
Author

@pertschuk you should check how big the difference is. I noticed some small differences when using big transformer models. I actually have open issues about this in Lightning and in Fairseq.

facebookresearch/fairseq#1605
#669

If the difference is really small this should not affect your results and is basically a precision issue.
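A hedged way to quantify that difference, reusing lightning_logits and transformers_logits from the snippet above instead of testing exact equality (the tolerances are illustrative):

import torch

if torch.allclose(lightning_logits, transformers_logits, rtol=1e-4, atol=1e-5):
    print("outputs match within floating-point tolerance")
else:
    max_diff = (lightning_logits - transformers_logits).abs().max().item()
    print(f"max absolute difference: {max_diff}")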

@pertschuk
Contributor

pertschuk commented Mar 10, 2020

@ricardorei unfortunately it was a very large error, but I fixed it by disabling mixed precision training, in case anyone else finds this thread. Frustrating, as training is much slower now... but at least it works!

@ssakhavi
Contributor

For those that have the issue of not being able to load the model using the load_from_checkpoint method, I tried using the workaround here. After playing around, I noticed that there is a problem with the way the state_dict is being loaded. After using the code snippet from a post on PyTorch's forum, I managed to solve the problem.

To be more specific, the weights were being loaded into the model but there was no error message.
@williamFalcon Please take note of this.

@ssakhavi
Contributor


I checked: the problem arises when we use a self.model attribute to define our forward pass and also our parameters.
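One common workaround in this situation is to strip the attribute prefix from the checkpoint keys before loading into the bare network (a sketch under the assumption that the keys carry a 'model.' prefix; the checkpoint path and bare_model are hypothetical):

import torch

checkpoint = torch.load('path/to/checkpoint.ckpt', map_location='cpu')

# Lightning prefixes state_dict keys with the attribute name, e.g. 'model.encoder.weight';
# strip the prefix to load the weights into the unwrapped network
state_dict = {
    k[len('model.'):]: v
    for k, v in checkpoint['state_dict'].items()
    if k.startswith('model.')
}
bare_model.load_state_dict(state_dict)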

@deng-cy
Contributor

deng-cy commented May 31, 2020

I encountered the same issue when passing self.model


@BenisonSam

Here's a solution that doesn't require modifying your model (from #599).

model = MyModel(whatever, args, you, want)
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])

For some reason, even after the fix, I am forced to use the quoted solution. The normal load_from_checkpoint function still gives me: pytorch_lightning.utilities.exceptions.MisconfigurationException: Checkpoint contains hyperparameters but MyModule's __init__ is missing the argument 'hparams'. Are you loading the correct checkpoint?
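For reference, a minimal sketch of what that exception is asking for in the 0.7.x line: load_from_checkpoint expects the module's __init__ to accept the saved hyperparameters. The class and helper names here are illustrative:

import pytorch_lightning as pl

class MyModule(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        # keeping the hyperparameters on self.hparams lets Lightning save them
        # into the checkpoint and pass them back in load_from_checkpoint
        self.hparams = hparams
        self.model = build_network(hparams)  # hypothetical helper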

@BenisonSam

solved in 0.7.1
if not, we can reopen

My version of PL is 0.7.6

@Borda
Member

Borda commented Jun 10, 2020

solved in 0.7.1
if not, we can reopen

My version of PL is 0.7.6

mind trying v0.8rc1 or the latest master?
