diff --git a/zamba/pytorch_lightning/utils.py b/zamba/pytorch_lightning/utils.py
index 40473dca..be397b76 100644
--- a/zamba/pytorch_lightning/utils.py
+++ b/zamba/pytorch_lightning/utils.py
@@ -6,6 +6,7 @@
 import numpy as np
 import pandas as pd
+import pytorch_lightning as pl
 from pytorch_lightning import LightningDataModule, LightningModule
 from sklearn.metrics import f1_score, top_k_accuracy_score, accuracy_score
 import torch
 
@@ -273,9 +274,17 @@ def configure_optimizers(self):
         }
 
     def to_disk(self, path: os.PathLike):
+        """Save out model weights to a checkpoint file on disk.
+
+        Note: this does not include callbacks, optimizer_states, or lr_schedulers.
+        To include those, use `Trainer.save_checkpoint()` instead.
+        """
+
         checkpoint = {
             "state_dict": self.state_dict(),
             "hyper_parameters": self.hparams,
+            "global_step": self.global_step,
+            "pytorch-lightning_version": pl.__version__,
         }
         torch.save(checkpoint, path)
 