-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
30 lines (25 loc) · 811 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
import torch
def save_checkpoint(save_path, model, criterion, optimizer, scheduler, epoch):
    """Save the state dicts of a training run's components for one epoch.

    Each component is written to its own file inside ``save_path`` named
    ``<component>_<NNNNN>.pt``, where ``NNNNN`` is ``epoch + 1`` zero-padded
    to five digits (so ``epoch=0`` is saved with suffix ``00001``).

    Args:
        save_path: Directory to write the files into. It is created (with any
            missing parents) if it does not exist; an empty string means the
            current working directory.
        model: Object exposing ``state_dict()``; always saved.
        criterion: Object exposing ``state_dict()``, or None to skip it.
        optimizer: Object exposing ``state_dict()``, or None to skip it.
        scheduler: Object exposing ``state_dict()``, or None to skip it.
        epoch: Integer epoch index; ``epoch + 1`` is used in the file names.
    """
    # torch.save does not create missing parent directories. os.makedirs('')
    # would raise, and '' already refers to the cwd, so only create real paths.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    suffix = f'{epoch + 1:05d}.pt'

    torch.save(model.state_dict(), os.path.join(save_path, f'model_{suffix}'))

    optional_components = (
        ('criterion', criterion),
        ('optimizer', optimizer),
        ('scheduler', scheduler),
    )
    for prefix, component in optional_components:
        # Compare with None instead of relying on truthiness: containers such
        # as an empty nn.Sequential define __len__ and would be falsy,
        # silently dropping their checkpoint.
        if component is None:
            continue
        torch.save(
            component.state_dict(),
            os.path.join(save_path, f'{prefix}_{suffix}'),
        )
def load_checkpoint(model, model_path, device):
    """Load saved weights from ``model_path`` into ``model`` in place.

    Does nothing when ``model_path`` is falsy (e.g. None or an empty string),
    so callers can pass an unset path to mean "start from scratch".

    Args:
        model: Object exposing ``load_state_dict()``; updated in place.
        model_path: Path to a file produced by ``torch.save``, or a falsy
            value to skip loading.
        device: Forwarded to ``torch.load`` as ``map_location``, so tensors
            are placed there instead of the device they were saved from.

    Returns:
        None.
    """
    if not model_path:
        return
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources.
    saved_state = torch.load(model_path, map_location=device)
    model.load_state_dict(saved_state)