-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathsolver.py
172 lines (152 loc) · 5.71 KB
/
solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
In Ictu Oculi: Exposing AI Created Fake Videos by Detecting Eye Blinking
IEEE International Workshop on Information Forensics and Security (WIFS), 2018
Yuezun Li, Ming-ching Chang and Siwei Lyu
"""
import os

import cv2
import tensorflow as tf
import yaml
from easydict import EasyDict as edict

from deep_base.ops import get_restore_var_list

# Directory containing this file; used to locate the per-mode YAML configs.
pwd = os.path.dirname(__file__)
class Solver(object):
    """
    Solver for training and testing the blink-detection networks.

    Wraps a TensorFlow session around a network object and drives it in one
    of two modes:

    * ``'cnn'``  -- per-frame eye-state classifier.
    * ``'lrcn'`` -- sequence model (CNN + RNN) over eye-crop sequences.

    Configuration is read from ``blink_<mode>.yml`` located next to this
    file.  Call :meth:`init` once after construction to build the training
    ops, summaries, and checkpoint machinery.
    """

    def __init__(self, sess, net, mode='cnn'):
        """Load the YAML config for *mode* and store the session/network.

        Args:
            sess: a live ``tf.Session``.
            net: the network object (must expose ``input``, ``prob``, etc.).
            mode: ``'cnn'`` or ``'lrcn'``; selects the config file and the
                train/test code paths.
        """
        cfg_file = os.path.join(pwd, 'blink_{}.yml'.format(mode))
        with open(cfg_file, 'r') as f:
            # safe_load: yaml.load() without an explicit Loader is deprecated
            # (PyYAML >= 5.1) and unsafe; a plain config file needs no more.
            cfg = edict(yaml.safe_load(f))
        self.sess = sess
        self.net = net
        self.cfg = cfg
        self.mode = mode

    def init(self):
        """Build output dirs, training ops, summaries; then init vars and load.

        Must be called once before :meth:`train` / :meth:`test`.  When
        ``self.net.is_train`` is set, also builds the loss, optimizer, and
        TensorBoard summaries.
        """
        cfg = self.cfg
        self.img_size = cfg.IMG_SIZE
        pwd = os.path.dirname(os.path.abspath(__file__))
        self.summary_dir = os.path.join(pwd, cfg.SUMMARY_DIR)
        if not os.path.exists(self.summary_dir):
            os.makedirs(self.summary_dir)
        self.model_dir = os.path.join(pwd, cfg.MODEL_DIR)
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.model_path = os.path.join(self.model_dir, 'model.ckpt')
        self.global_step = tf.Variable(0, trainable=False, name='global_step')
        self.saver = tf.train.Saver(max_to_keep=5)
        # initialize the graph
        if self.net.is_train:
            self.num_epoch = cfg.TRAIN.NUM_EPOCH
            self.learning_rate = cfg.TRAIN.LEARNING_RATE
            self.decay_rate = cfg.TRAIN.DECAY_RATE
            self.decay_step = cfg.TRAIN.DECAY_STEP
            self.net.loss()
            self.set_optimizer()
            # Add summary
            self.loss_summary = tf.summary.scalar('loss_summary', self.net.total_loss)
            self.lr_summary = tf.summary.scalar('learning_rate_summary', self.LR)
            self.summary = tf.summary.merge([self.loss_summary, self.lr_summary])
            self.writer = tf.summary.FileWriter(self.summary_dir, self.sess.graph)
        # Variables are initialized first, then (if a checkpoint exists)
        # overwritten by load() -- so a failed load still leaves a usable net.
        self.sess.run(tf.global_variables_initializer())
        self.load()

    def train(self, *args):
        """Dispatch one training step to the mode-specific trainer.

        Args (positional, by mode):
            cnn:  (images, labels)
            lrcn: (seq_tensor, len_list, state_list)

        Returns:
            The fetch list from the underlying ``sess.run``:
            [train_op result, summary, prob, net_loss].

        Raises:
            ValueError: if ``self.mode`` is not ``'cnn'`` or ``'lrcn'``.
        """
        if self.mode == 'cnn':
            return self.train_cnn(images=args[0], labels=args[1])
        elif self.mode == 'lrcn':
            return self.train_lrcn(seq_tensor=args[0],
                                   len_list=args[1],
                                   state_list=args[2])
        else:
            raise ValueError('We only support mode = [cnn, lrcn]...')

    def test(self, *args):
        """Dispatch one inference call to the mode-specific tester.

        Args (positional, by mode):
            cnn:  (images,)
            lrcn: (seq_tensor, len_list)

        Returns:
            ``[prob]`` as produced by ``sess.run``.

        Raises:
            ValueError: if ``self.mode`` is not ``'cnn'`` or ``'lrcn'``.
        """
        if self.mode == 'cnn':
            return self.test_cnn(images=args[0])
        elif self.mode == 'lrcn':
            return self.test_lrcn(seq_tensor=args[0],
                                  len_list=args[1])
        else:
            raise ValueError('We only support mode = [cnn, lrcn]...')

    def test_cnn(self, images):
        """Run the CNN forward pass on a batch of eye crops.

        Each image is resized to ``cfg.IMG_SIZE`` before being fed.
        NOTE: resizing is done on a local copy -- the caller's ``images``
        list is no longer mutated in place (previous behavior clobbered it).

        Args:
            images: iterable of HxWxC arrays.

        Returns:
            ``[prob]`` -- the network's probability output.
        """
        # cv2.resize takes dsize as (width, height); this assumes
        # IMG_SIZE = [width, height, ...] -- TODO confirm against the config.
        resized = [cv2.resize(im, (self.img_size[0], self.img_size[1]))
                   for im in images]
        feed_dict = {
            self.net.input: resized,
        }
        fetch_list = [
            self.net.prob,
        ]
        return self.sess.run(fetch_list, feed_dict=feed_dict)

    def train_cnn(self, images, labels):
        """One CNN training step.

        Args:
            images: batch of input images (already sized for the net).
            labels: ground-truth labels fed to ``self.net.gt``.

        Returns:
            [train_op result, summary, prob, net_loss].
        """
        feed_dict = {
            self.net.input: images,
            self.net.gt: labels
        }
        fetch_list = [
            self.train_op,
            self.summary,
            self.net.prob,
            self.net.net_loss,
        ]
        return self.sess.run(fetch_list, feed_dict=feed_dict)

    def test_lrcn(self, seq_tensor, len_list):
        """Run the LRCN forward pass on padded sequences.

        Args:
            seq_tensor: batched, padded frame sequences.
            len_list: true (unpadded) length of each sequence.

        Returns:
            ``[prob]`` -- the network's probability output.
        """
        feed_dict = {
            self.net.input: seq_tensor,
            self.net.seq_len: len_list,
        }
        fetch_list = [
            self.net.prob,
        ]
        return self.sess.run(fetch_list, feed_dict=feed_dict)

    def train_lrcn(self, seq_tensor, len_list, state_list):
        """One LRCN training step.

        Args:
            seq_tensor: batched, padded frame sequences.
            len_list: true length of each sequence.
            state_list: per-frame eye-state ground truth.

        Returns:
            [train_op result, summary, prob, net_loss].
        """
        feed_dict = {
            self.net.input: seq_tensor,
            self.net.seq_len: len_list,
            self.net.eye_state_gt: state_list
        }
        fetch_list = [
            self.train_op,
            self.summary,
            self.net.prob,
            self.net.net_loss,
        ]
        return self.sess.run(fetch_list, feed_dict=feed_dict)

    def save(self, step):
        """Save a checkpoint at *step* (keeps at most 5, per the Saver)."""
        save_path = self.saver.save(self.sess, self.model_path, global_step=step)
        print('Model {} saved in file.'.format(save_path))

    def load(self):
        """Restore weights from ``self.model_path`` if a checkpoint exists.

        Only variables present in the checkpoint (per
        ``get_restore_var_list``) are restored; anything else keeps its
        initialized value.  Missing checkpoints are reported, not fatal.
        """
        if os.path.isfile(self.model_path + '.meta'):
            variables_to_restore = get_restore_var_list(self.model_path)
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(self.sess, self.model_path)
            print('Loading checkpoint {}'.format(self.model_path))
        else:
            print('Loading failed.')

    def set_optimizer(self):
        """Build the LR-decay schedule, optimizer, and ``self.train_op``.

        Uses staircase exponential decay driven by ``self.global_step``.

        Raises:
            ValueError: if ``cfg.TRAIN.METHOD`` is not ``'SGD'`` or ``'Adam'``.
        """
        # Set learning rate decay
        self.LR = tf.train.exponential_decay(
            learning_rate=self.learning_rate,
            global_step=self.global_step,
            decay_steps=self.decay_step,
            decay_rate=self.decay_rate,
            staircase=True
        )
        if self.cfg.TRAIN.METHOD == 'SGD':
            optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=self.LR,
            )
        elif self.cfg.TRAIN.METHOD == 'Adam':
            optimizer = tf.train.AdamOptimizer(
                learning_rate=self.LR,
            )
        else:
            raise ValueError('We only support [SGD, Adam] right now...')
        # minimize() increments global_step, which in turn advances the
        # exponential-decay schedule above.
        self.train_op = optimizer.minimize(
            loss=self.net.total_loss,
            global_step=self.global_step,
            var_list=None)