add legacy load utility #9166
Conversation
@Borda sorry I missed your question. Of course, I made sure that the legacy patch works with the old checkpoints. Below is the code I used to verify it. Since we know what the legacy format was, the unit tests I added should suffice.

```python
import torch
from argparse import ArgumentParser
from torch.utils.data import Dataset
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters(kwargs)
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def step(self, x):
        x = self.layer(x)
        out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
        return out

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


def run_1_2_7():
    # RUN WITH LIGHTNING 1.2.7
    # fake data
    train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    # model
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    trainer = Trainer.from_argparse_args(args, default_root_dir="legacy_logs1.2.7", max_steps=1)
    model = BoringModel(**vars(args))
    trainer.fit(model, train_data)
    print(trainer.checkpoint_callback.best_model_path)


def run_master():
    # RUN WITH LIGHTNING MASTER
    from pytorch_lightning.utilities.migration import pl_legacy_patch

    with pl_legacy_patch():  # without this, pickle error!
        # can unpickle!
        x = torch.load("legacy_logs1.2.7/lightning_logs/version_2/checkpoints/epoch=0-step=0.ckpt")
        assert callable(x["hyper_parameters"]["gpus"])


if __name__ == "__main__":
    run_master()
```
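To reproduce: first run `run_1_2_7()` with Lightning 1.2.7 installed to generate the checkpoint, then switch to this branch and run `run_master()`. Note that the `version_2` segment of the checkpoint path depends on how many runs already exist under `lightning_logs`; adjust it to match the path printed by the first run.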
great solution and pattern for future similar changes!
@awaelchli mind resolving the conflicts? 🐰
What does this PR do?
Splits changes off #8558
Adds a context manager to handle legacy checkpoints whose pickled contents reference attributes that have since been removed from Lightning. The currently known ones are:

- `pytorch_lightning.utilities.argparse_utils` (renamed, slated for removal)
- `pytorch_lightning.utilities.argparse._gpus_arg_default` (dead code)

We can remove these dead code pieces and instead dynamically patch the modules to re-route the imports.
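The mechanism can be pictured roughly as below. This is a simplified sketch of the `sys.modules` re-routing, not the exact implementation in this PR; the alias-module contents and the `lambda` default are illustrative.

```python
import sys
from types import ModuleType

import pytorch_lightning.utilities.argparse as new_argparse


class pl_legacy_patch:
    """Sketch: make legacy import paths resolvable while unpickling an old checkpoint."""

    def __enter__(self):
        # Alias the renamed module so `pytorch_lightning.utilities.argparse_utils`
        # resolves again during unpickling.
        legacy_module = ModuleType("pytorch_lightning.utilities.argparse_utils")
        legacy_module.__dict__.update(new_argparse.__dict__)
        sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_module
        # Restore the removed `_gpus_arg_default` attribute; old checkpoints
        # pickled a reference to this function as the `gpus` argparse default.
        new_argparse._gpus_arg_default = lambda x: x
        legacy_module._gpus_arg_default = new_argparse._gpus_arg_default
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Undo both patches so the running process is left untouched.
        if hasattr(new_argparse, "_gpus_arg_default"):
            del new_argparse._gpus_arg_default
        sys.modules.pop("pytorch_lightning.utilities.argparse_utils", None)
```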
The following works (the checkpoint path matches the one produced by the verification script above):
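```python
import torch
from pytorch_lightning.utilities.migration import pl_legacy_patch

# Without the patch, unpickling the 1.2.7 checkpoint on master raises a pickle error.
with pl_legacy_patch():
    checkpoint = torch.load("legacy_logs1.2.7/lightning_logs/version_2/checkpoints/epoch=0-step=0.ckpt")
```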
Does your PR introduce any breaking changes? If yes, please list them.
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the review guidelines.
Did you have fun?
I made sure I had fun coding 🙃