-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils
60 lines (53 loc) · 2.03 KB
/
utils
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
import cv2 as cv
import os
import torch
def load_image(path):
    """Read an image file and return it as an RGB float32 array.

    The file is decoded with ``cv.imread`` using its default flags, the
    channel order is converted from BGR to RGB, and the pixel values are
    cast to ``float32`` and divided by 255.

    Args:
        path: Path of the image file to read.

    Returns:
        numpy.ndarray: The RGB image as ``float32`` divided by 255.

    Raises:
        IOError: If OpenCV could not read or decode the file.
    """
    bgr = cv.imread(path)
    if bgr is None:
        # cv.imread does not raise on a missing/corrupt file; it returns
        # None, which would otherwise fail later inside cvtColor with an
        # error message that never mentions the offending path.
        raise IOError('Failed to read image: {}'.format(path))
    rgb = cv.cvtColor(bgr, cv.COLOR_BGR2RGB)
    return rgb.astype(np.float32) / 255.0
def save_image(img, path):
    """Write a torch image tensor to ``path`` with OpenCV.

    The tensor is clamped to [0, 1], moved to the CPU, stripped of a
    leading size-1 dimension, reordered from channel-first to
    channel-last, scaled by 255 and converted from RGB to BGR before
    being handed to ``cv.imwrite``.

    Args:
        img: Tensor that is three-dimensional and channel-first once a
            leading size-1 dimension is squeezed away (the RGB->BGR
            conversion requires three channels).
        path: Destination file name; the extension selects the encoder.
    """
    clipped = torch.clamp(img, min=0.0, max=1.0)
    # detach() lets tensors that still track gradients be converted;
    # Tensor.numpy() raises RuntimeError for them otherwise.
    chw = clipped.detach().cpu().squeeze(0).numpy()
    hwc = chw.transpose([1, 2, 0])
    scaled = 255.0 * hwc.astype(np.float32)
    cv.imwrite(path, cv.cvtColor(scaled, cv.COLOR_RGB2BGR))
def make_path(path):
    """Create directory ``path`` (and any missing parents) if needed.

    Calling this for a directory that already exists is a no-op.

    Args:
        path: Directory to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # exist_ok=True removes the window between "check" and "create" in
    # which another process could create the directory and make
    # os.makedirs fail.
    os.makedirs(path, exist_ok=True)
def augmentation(inp,inf):
    """Apply the same random flips/transpose to a pair of arrays.

    Three independent coin flips (drawn from ``np.random``) decide, in
    order, whether to flip along axis 0, flip along axis 1, and swap
    axes 0 and 1. Each chosen transform is applied to both arrays so
    they stay aligned.

    Args:
        inp: First array; must have exactly three dimensions when the
            transpose is selected.
        inf: Second array, transformed identically to ``inp``.

    Returns:
        tuple: ``(inp, inf)`` after the selected transforms. When no
        transform is selected the inputs are returned unchanged.
    """
    transforms = (
        lambda arr: np.flip(arr, axis=0),
        lambda arr: np.flip(arr, axis=1),
        lambda arr: np.transpose(arr, (1, 0, 2)),
    )
    for apply in transforms:
        # One draw per transform, in a fixed order.
        coin = np.random.randint(2, size=1)[0]
        if coin == 1:
            inp, inf = apply(inp), apply(inf)
    return inp, inf
def save_checkpoint(net, optimizer, epoch, checkpoint_path=None):
    """Serialize model, optimizer and epoch into one checkpoint file.

    The file holds a dict with the keys ``'state_dict'`` (from
    ``net.state_dict()``), ``'optimizer'`` (from
    ``optimizer.state_dict()``) and ``'epoch'``.

    Args:
        net: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        epoch: Value stored under ``'epoch'``.
        checkpoint_path: Destination passed to ``torch.save``.
    """
    snapshot = {}
    snapshot['state_dict'] = net.state_dict()
    snapshot['optimizer'] = optimizer.state_dict()
    snapshot['epoch'] = epoch
    torch.save(snapshot, checkpoint_path)
def load_checkpoint(net, optimizer=None, epoch=None, checkpoint_path=None):
    """Restore model (and optionally optimizer) state from a checkpoint.

    The checkpoint is expected to be a dict with the keys
    ``'state_dict'``, ``'optimizer'`` and ``'epoch'``.

    Args:
        net: Object whose ``load_state_dict`` receives
            ``checkpoint['state_dict']``.
        optimizer: Optional object whose ``load_state_dict`` receives
            ``checkpoint['optimizer']``.
        epoch: Fallback value returned when no checkpoint is loaded or
            the checkpoint carries no ``'epoch'`` entry.
        checkpoint_path: File to load. When falsy, nothing is loaded.

    Returns:
        The epoch stored in the checkpoint, or ``epoch`` if none was
        read. (Previously the loaded epoch was assigned to a local
        variable and lost; callers that ignore the return value are
        unaffected.)
    """
    if not checkpoint_path:
        return epoch

    print('Load checkpoint: {}'.format(checkpoint_path))

    if torch.cuda.is_available():
        # Keep the original behaviour on GPU hosts: place every storage
        # on the current CUDA device.
        def map_location(storage, loc):
            return storage.cuda()
    else:
        # Without CUDA, storage.cuda() would raise; load onto the CPU.
        map_location = 'cpu'

    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    net.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint.get('epoch', epoch)
def token_cmp1(elem):
    """Return the numeric value of the file-name stem in ``elem[0]``.

    Presumably used as a ``key=`` function when ordering entries by
    file number -- confirm against the callers.

    Args:
        elem: Indexable whose first item is a ``'/'``-separated path
            string ending in a numeric file name, e.g. ``'seq/12.png'``.

    Returns:
        float: The file name without its extension, parsed by ``float``.

    Raises:
        ValueError: If the stem is not a valid number.
    """
    filename = elem[0].split('/')[-1]
    # splitext removes the extension whatever its length; the former
    # fixed four-character slice only suited three-letter extensions
    # (it mis-parsed '42.h5' and rejected '0.5.jpeg').
    stem, _extension = os.path.splitext(filename)
    return float(stem)
def token_cmp(elem):
    """Return the second-to-last ``'/'``-separated component of ``elem[0]``.

    For a path such as ``'data/scene1/001.png'`` this is the name of the
    directory that directly contains the file (``'scene1'``).

    Args:
        elem: Indexable whose first item is a ``'/'``-separated string.

    Returns:
        str: The component just before the final one.

    Raises:
        IndexError: If ``elem[0]`` contains no ``'/'``.
    """
    path = elem[0]
    # Only the last two separators matter, so split from the right.
    tail = path.rsplit('/', 2)
    return tail[-2]