Skip to content

Commit

Permalink
offer way to load from last checkpoint, with trainer.load(-1)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 22, 2022
1 parent a195615 commit 497b349
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'video-diffusion-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.4',
version = '0.1.5',
license='MIT',
description = 'Video Diffusion - Pytorch',
author = 'Phil Wang',
Expand Down
5 changes: 5 additions & 0 deletions video_diffusion_pytorch/video_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,11 @@ def save(self, milestone):
torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

def load(self, milestone):
if milestone == -1:
all_milestones = [int(p.stem.split('-')[-1]) for p in Path(self.results_folder).glob('**/*.pt')]
assert len(all_milestones) > 0, 'need to have at least one milestone to load from latest checkpoint (milestone == -1)'
milestone = max(all_milestones)

data = torch.load(str(self.results_folder / f'model-{milestone}.pt'))

self.step = data['step']
Expand Down

0 comments on commit 497b349

Please sign in to comment.