-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutil.py
29 lines (22 loc) · 841 Bytes
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os
import numpy as np
import torch
import torch.nn as nn
## Save the network
def save(ckpt_dir, net, optim, epoch):
    """Write model and optimizer state to ``<ckpt_dir>/model_epoch<epoch>.pth``.

    The checkpoint is a dict with two keys: ``'net'`` holding
    ``net.state_dict()`` and ``'optim'`` holding ``optim.state_dict()``.

    Args:
        ckpt_dir: Directory that receives the checkpoint; created (including
            parents) when it does not exist yet.
        net: Object exposing ``state_dict()``.
        optim: Object exposing ``state_dict()``.
        epoch: Integer epoch number embedded in the file name.
    """
    # exist_ok=True removes the check-then-create race of a separate
    # os.path.exists() test.
    os.makedirs(ckpt_dir, exist_ok=True)

    # os.path.join keeps an absolute ckpt_dir absolute; prefixing "./" would
    # redirect the file to a different, cwd-relative location than the
    # directory that was just created.
    ckpt_path = os.path.join(ckpt_dir, "model_epoch%d.pth" % epoch)

    torch.save({'net': net.state_dict(), 'optim': optim.state_dict()},
               ckpt_path)
def _epoch_from_name(fname):
    """Return the epoch number encoded in a checkpoint file name, or None."""
    if not fname.endswith('.pth') or 'epoch' not in fname:
        return None
    try:
        return int(fname.split('epoch')[1].split('.pth')[0])
    except ValueError:
        # Something other than a plain integer sits between 'epoch' and '.pth'.
        return None


## Load the network
def load(ckpt_dir, net, optim):
    """Restore ``net`` and ``optim`` from the highest-epoch checkpoint.

    Checkpoint files are those whose name contains ``epoch<N>`` followed by
    ``.pth`` (e.g. ``model_epoch12.pth``). Any other directory entries are
    ignored.

    Args:
        ckpt_dir: Directory to search for checkpoints.
        net: Object exposing ``load_state_dict()``; updated in place.
        optim: Object exposing ``load_state_dict()``; updated in place.

    Returns:
        A ``(net, optim, epoch)`` tuple. ``epoch`` is 0, and ``net``/``optim``
        are returned untouched, when the directory is missing or holds no
        checkpoint file.
    """
    try:
        entries = os.listdir(ckpt_dir)
    except FileNotFoundError:
        return net, optim, 0

    # Parse the epoch once per file. Sorting the names first makes the choice
    # deterministic if two files happen to encode the same epoch.
    candidates = []
    for fname in sorted(entries):
        parsed = _epoch_from_name(fname)
        if parsed is not None:
            candidates.append((parsed, fname))

    if not candidates:
        return net, optim, 0

    epoch, latest = max(candidates)

    # NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source.
    dict_model = torch.load(os.path.join(ckpt_dir, latest))

    net.load_state_dict(dict_model['net'])
    optim.load_state_dict(dict_model['optim'])
    return net, optim, epoch