Model load_from_checkpoint #525
Comments
IIRC, that was a hack to work around an edge case where the hparams weren't pickleable. Seems like the original ticket #433 is still open. @williamFalcon do we still need this hack?
Maybe the problem is about saving an object with a lambda function. I see this line in the issue log:
Pickle doesn't allow lambda functions to be saved, but if this is the reason, it's an easy fix I believe. https://stackoverflow.com/questions/25348532/can-python-pickle-lambda-functions
I'm also having the same issue, but I'm not using any lambda.
In general, lambda functions are not serializable, so items like these should be removed before saving.
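To illustrate the point, here is a minimal sketch (not from this thread) showing plain `pickle` rejecting a lambda stored in hparams:

```python
import pickle
from argparse import Namespace

hparams = Namespace(batch_size=32, activation=lambda x: x)

try:
    pickle.dumps(hparams)
except (pickle.PicklingError, AttributeError, TypeError) as err:
    # the lambda has no importable qualified name, so pickling fails
    print(f"cannot pickle hparams: {err}")

# dropping the lambda (or swapping in a named, module-level function) fixes it
hparams.activation = None
pickle.dumps(hparams)  # works
```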
I've encountered the same problem. Looks like the culprit is this line: https://github.com/williamFalcon/pytorch-lightning/blob/6666ca5af39aa2d3e5a483da3d7f6bb76514cc9f/pytorch_lightning/trainer/trainer_io.py#L321
This exception in turn gets handled in https://github.com/williamFalcon/pytorch-lightning/blob/6666ca5af39aa2d3e5a483da3d7f6bb76514cc9f/pytorch_lightning/trainer/trainer_io.py#L264. Do we need this?
Ahh, good catch. So this works with an
I just encountered the same issue. If you still have access to hparams, here is a quick fix:

```python
@classmethod
def load_from_checkpoint(cls, checkpoint_path, hparams, map_location=None):
    """
    Primary way of loading a model from a checkpoint.
    :param checkpoint_path: path to the checkpoint file
    :param hparams: hyperparameters used to build the model
    :param map_location: dict for mapping storage, e.g. {'cuda:1': 'cuda:0'}
    :return: the loaded model
    """
    if map_location is not None:
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
    else:
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

    # the original implementation read hparams from the checkpoint; commented out
    # here because the checkpoint may not contain them:
    # try:
    #     ckpt_hparams = checkpoint['hparams']
    # except KeyError:
    #     raise IOError(
    #         "Checkpoint does not contain hyperparameters. Are your model hyperparameters stored "
    #         "in self.hparams?"
    #     )
    # hparams = Namespace(**ckpt_hparams)

    # load the state_dict on the model automatically
    model = cls(hparams)
    model.load_state_dict(checkpoint['state_dict'])

    # give model a chance to load something
    model.on_load_checkpoint(checkpoint)

    return model
```

This is how to use it:

```python
from model_utils.model_definitions.my_classifier import MyCoolModule
from argparse import Namespace

checkpoint_path = '/home/verena/.../checkpoints/_ckpt_epoch_18.ckpt'
hparams = {
    "batch_size": 32,
    # ...
}
namespace = Namespace(**hparams)
model = MyCoolModule.load_from_checkpoint(checkpoint_path=checkpoint_path, hparams=namespace)
```
I faced the same issue; thanks @expectopatronum for the workaround, it helped me a lot.
Here's a solution that doesn't require modifying your model (from #599):

```python
model = MyModel(whatever, args, you, want)
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])
```
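For completeness, what typically comes next for inference with the manually restored model. This is a hedged sketch: `MyModel`, its constructor arguments, and the input shape above and below are placeholders, not from this thread.

```python
import torch

model.eval()    # disable dropout / batch-norm updates before predicting
model.freeze()  # LightningModule convenience: eval() plus requires_grad = False

with torch.no_grad():
    dummy_input = torch.randn(1, 28 * 28)  # placeholder input shape
    predictions = model(dummy_input)
```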
Hi guys, it seems that in my case the
Solved in 0.7.1.
@williamFalcon I initialized a trainer as follows (disabled TensorBoard because it was erroring due to the TF dependency):

```python
trainer = pl.Trainer(gpus=1, val_check_interval=0.25, use_amp=True, logger=False)
```

Then I observed that the models were saved in
What is best practice for loading/saving a model with no logger? Is this possible? Thank you in advance.
@pertschuk As I remember, auto loading is disabled in the latest master; can you check your Lightning version?
@Ir1d 0.7.1
@pertschuk I believe that auto restoring is removed, and you should load the weights on your own. The doc you linked is not updated yet.
@Ir1d Is there a callback or function to override to integrate weight loading/saving with PL checkpointing? For example, I'm training a huggingface/transformers model and want to save checkpoints in that format.
@pertschuk Sorry, I don't understand the transformer model. You see, a PL checkpoint wraps a lot of things. You can get this by calling
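One possible shape for this, sketched under the assumption that the LightningModule wraps a transformers model as `self.model` and that the `on_save_checkpoint` hook is available; the output directory is illustrative, not something from this thread:

```python
import pytorch_lightning as pl

class TransformersClassifier(pl.LightningModule):
    # __init__, forward, training_step, configure_optimizers, ... omitted

    def on_save_checkpoint(self, checkpoint):
        # Lightning calls this while writing its own .ckpt; also dump the wrapped
        # model in native huggingface/transformers format alongside it
        self.model.save_pretrained("transformers_checkpoint/")

    def on_load_checkpoint(self, checkpoint):
        # the Lightning state_dict already restores self.model's weights,
        # so nothing extra is needed here
        pass
```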
I also have been using Lightning with pytorch-transformers. I save checkpoints normally without changing anything in Lightning. If for some reason I need to resume training from a given checkpoint, I just use the `resume_from_checkpoint` Trainer argument. If I just want to load weights from a pretrained model, I use the `load_weights` flag:

```python
parser = HyperOptArgumentParser(strategy="random_search", add_help=False)
parser.add_argument(
    "--resume_from_checkpoint",
    default=None,
    type=str,
    help=(
        "To resume training from a specific checkpoint pass in the path here. "
        "(e.g. 'some/path/to/my_checkpoint.ckpt')"
    ),
)
parser.add_argument(
    "--load_weights",
    default=None,
    type=str,
    help=(
        "Loads the model weights from a given checkpoint while resume_from_checkpoint "
        "resumes the entire training session (model/optimizer/scheduler etc..). "
        "If architectures are different this will load only the common module parts."
    ),
)
# ...
trainer = Trainer(
    logger=setup_testube_logger(),
    checkpoint_callback=True,
    early_stop_callback=early_stop_callback,
    default_save_path="experiments/",
    gradient_clip_val=hparams.gradient_clip_val,
    gpus=hparams.gpus,
    show_progress_bar=False,
    overfit_pct=hparams.overfit_pct,
    check_val_every_n_epoch=hparams.check_val_every_n_epoch,
    fast_dev_run=False,
    accumulate_grad_batches=hparams.accumulate_grad_batches,
    max_epochs=hparams.max_epochs,
    min_epochs=hparams.min_epochs,
    train_percent_check=hparams.train_percent_check,
    val_percent_check=hparams.val_percent_check,
    val_check_interval=hparams.val_check_interval,
    log_save_interval=hparams.log_save_interval,
    row_log_interval=hparams.row_log_interval,
    distributed_backend=hparams.distributed_backend,
    precision=hparams.precision,
    weights_summary=hparams.weights_summary,
    resume_from_checkpoint=hparams.resume_from_checkpoint,
    profiler=hparams.profiler,
    log_gpu_memory="all",
)

model = build_model(hparams)
if hparams.load_weights:
    model.load_weights_from_checkpoint(hparams.load_weights)

log.info(f"{model.__class__.__name__} train starting:")
trainer.fit(model)
```

My `load_weights_from_checkpoint` function:

```python
def load_weights_from_checkpoint(self, checkpoint: str) -> None:
    """Loads the weights from a given checkpoint file.

    Note:
        If the checkpoint model architecture is different from `self`, only
        the common parts will be loaded.

    :param checkpoint: Path to the checkpoint containing the weights to be loaded.
    """
    log.info(f"loading model weights from {checkpoint}.")
    checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    pretrained_dict = checkpoint["state_dict"]
    model_dict = self.state_dict()

    # 1. filter out keys that don't exist in the current model
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the merged state dict
    self.load_state_dict(model_dict)
```

Does this solve your problem of loading pre-trained weights and resuming training sessions?
PS: The initial purpose of this issue was solved some versions ago and it's now working.
@ricardorei yes it does, thank you
EDIT: this seems to be an apex/amp fp16 precision bug.

Okay, sorry to keep posting here, but I have run into a VERY confusing issue and would appreciate any ideas for guidance @ricardorei. I am trying to export the models to save in huggingface/transformers format for reuse. The saved model appears to have an identical state_dict to the model wrapped in the PyTorch Lightning module, but the results of passing the same inputs through are not the same?

```python
import os

os.makedirs('./test-1', exist_ok=True)

# model is a PyTorch Lightning module and model.model is the transformers model
model.model.save_pretrained('./test-1')
loaded_model = AlbertForSequenceClassification.from_pretrained('./test-1')
loaded_model.cuda()

for k, v in loaded_model.state_dict().items():
    assert torch.all(model.model.state_dict()[k].eq(v))  # this assert works

correct = 0
total = 0

def call_model(inputs, model):
    return model(inputs['input_ids'].cuda(),
                 token_type_ids=inputs['token_type_ids'].cuda(),
                 attention_mask=inputs['attention_mask'].cuda())[0]

for ex in get_data():
    label = 0 if ex['is_impossible'] else 1
    inputs = tokenizer.encode_plus(ex['question'],
                                   ex['context'],
                                   add_special_tokens=True,
                                   max_length=256,
                                   return_tensors='pt')
    lightning_logits = call_model(inputs, model.model)
    transformers_logits = call_model(inputs, loaded_model)
    assert torch.all(lightning_logits.eq(transformers_logits))  # this assert fails ???
```

Note: I also tried saving/loading the state_dict for the PyTorch Lightning module itself and had the same problem; the state dicts match up, but the outputs during inference differ? I'm totally lost.
@pertschuk you should check how big the difference is. I noticed some small differences when using big transformer models. I actually have an open issue in Lightning regarding this subject, and in Fairseq: facebookresearch/fairseq#1605. If the difference is really small, it should not affect your results and is basically a precision issue.
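A quick way to quantify that gap instead of asserting exact equality, reusing the two logit tensors from the snippet above (the tolerances here are illustrative, not values from this thread):

```python
import torch

diff = (lightning_logits - transformers_logits).abs()
print(f"max abs diff:  {diff.max().item():.3e}")
print(f"mean abs diff: {diff.mean().item():.3e}")

# passes for small fp16/fp32 rounding noise, fails for a genuinely broken export
assert torch.allclose(lightning_logits, transformers_logits, rtol=1e-4, atol=1e-5)
```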
@ricardorei unfortunately it was a very large error, but I fixed it by disabling mixed precision training, in case anyone else finds this thread. Frustrating, as training is much slower now... but at least it works!
For those that have the issue of not being able to load the model using the `load_from_checkpoint` method: to be more specific, the weights were being loaded into the model but there was no error message.
I checked: the problem arises when we are using a
I encountered the same issue when passing
For some reason, even after the fix, I am forced to use the quoted solution. The normal load_from_checkpoint function still gives me
My version of PL is
Mind trying v0.8rc1 or the latest?
Describe the bug
When loading a model directly from a checkpoint I get an error "OSError: Checkpoint does not contain hyperparameters. Are your model hyperparameters storedin self.hparams?"
But my model clearly has the hparams.
To Reproduce
Just create a model, save a checkpoint, and try to load it as explained in the documentation:
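Roughly, the reproduction looks like this (a sketch; the module name and checkpoint path are placeholders):

```python
from my_project.my_module import MyCoolModule  # placeholder import

model = MyCoolModule.load_from_checkpoint("path/to/checkpoint.ckpt")
# OSError: Checkpoint does not contain hyperparameters. Are your model
# hyperparameters stored in self.hparams?
```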
Possible reason
I found this code in trainer_io.py at line 301:
Obviously, if the code that saves the checkpoint deletes the hparams, the load-checkpoint function will not find them...
Expected behavior
A more concise way to easily load a checkpoint without the need for the load_from_metrics function.