Skip to content

Commit

Permalink
add needed attributes to dictionary for v1.9
Browse files Browse the repository at this point in the history
  • Loading branch information
ejm714 committed Feb 10, 2023
1 parent 22975e8 commit 3090b65
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions zamba/pytorch_lightning/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule, LightningModule
from sklearn.metrics import f1_score, top_k_accuracy_score, accuracy_score
import torch
Expand Down Expand Up @@ -273,9 +274,17 @@ def configure_optimizers(self):
}

def to_disk(self, path: os.PathLike):
"""Save out model weights to a checkpoint file on disk.
Note: this does not include callbacks, optimizer_states, or lr_schedulers.
To include those, use `Trainer.save_checkpoint()` instead.
"""

checkpoint = {
"state_dict": self.state_dict(),
"hyper_parameters": self.hparams,
"global_step": self.global_step,
"pytorch-lightning_version": pl.__version__,
}
torch.save(checkpoint, path)

Expand Down

0 comments on commit 3090b65

Please sign in to comment.