-
Notifications
You must be signed in to change notification settings - Fork 559
/
SfMLearner.py
337 lines (311 loc) · 15.5 KB
/
SfMLearner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
from __future__ import division
import os
import time
import math
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from data_loader import DataLoader
from nets import *
from utils import *
class SfMLearner(object):
    """Depth and ego-motion learner built on TensorFlow 1.x graphs.

    Training is driven by ``train(opt)``; inference by ``setup_inference()``
    followed by ``inference()``.
    """

    def __init__(self):
        # Configuration is attached later: train() sets self.opt and
        # setup_inference() sets image size, mode and batch size.
        pass

    def _scaled_size(self, s):
        """Return [height, width] of the input image downscaled by 2**s."""
        opt = self.opt
        return [int(opt.img_height / (2 ** s)), int(opt.img_width / (2 ** s))]

    def _epoch_progress(self, gs):
        """Map a global step onto (epoch, step-within-epoch), both 1-based.

        Uses self.steps_per_epoch (set by build_train_graph from the loader).
        """
        train_epoch = math.ceil(gs / self.steps_per_epoch)
        train_step = gs - (train_epoch - 1) * self.steps_per_epoch
        return train_epoch, train_step

    def build_train_graph(self):
        """Build the full training graph from ``self.opt``.

        Creates the input pipeline, the depth network (disp_net), the
        pose/explainability network (pose_exp_net), the multi-scale losses
        and the Adam train op. The tensors needed later (summaries and the
        training loop) are stored as attributes on ``self``.
        """
        opt = self.opt
        use_exp = opt.explain_reg_weight > 0
        loader = DataLoader(opt.dataset_dir,
                            opt.batch_size,
                            opt.img_height,
                            opt.img_width,
                            opt.num_source,
                            opt.num_scales)
        with tf.name_scope("data_loading"):
            tgt_image, src_image_stack, intrinsics = loader.load_train_batch()
            tgt_image = self.preprocess_image(tgt_image)
            src_image_stack = self.preprocess_image(src_image_stack)

        with tf.name_scope("depth_prediction"):
            # pred_disp is indexed per scale (index s pairs with images
            # downscaled by 2**s below); depth is its reciprocal.
            pred_disp, _ = disp_net(tgt_image, is_training=True)
            pred_depth = [1./d for d in pred_disp]

        with tf.name_scope("pose_and_explainability_prediction"):
            pred_poses, pred_exp_logits, _ = \
                pose_exp_net(tgt_image,
                             src_image_stack,
                             do_exp=use_exp,
                             is_training=True)

        with tf.name_scope("compute_loss"):
            pixel_loss = 0
            exp_loss = 0
            smooth_loss = 0
            tgt_image_all = []
            src_image_stack_all = []
            proj_image_stack_all = []
            proj_error_stack_all = []
            exp_mask_stack_all = []
            for s in range(opt.num_scales):
                if use_exp:
                    # Reference explainability mask: every pixel is labelled
                    # explainable (one-hot class 1).
                    ref_exp_mask = self.get_reference_explain_mask(s)
                # Resize target and sources so the loss at scale s compares
                # images at that scale's resolution.
                curr_size = self._scaled_size(s)
                curr_tgt_image = tf.image.resize_area(tgt_image, curr_size)
                curr_src_image_stack = tf.image.resize_area(src_image_stack,
                                                            curr_size)

                if opt.smooth_weight > 0:
                    # Coarser scales get a proportionally smaller weight.
                    smooth_loss += opt.smooth_weight/(2**s) * \
                        self.compute_smooth_loss(pred_disp[s])

                # Per-source tensors, concatenated on the channel axis after
                # the loop for TensorBoard summaries.
                proj_images = []
                proj_errors = []
                exp_masks = []
                for i in range(opt.num_source):
                    # Inverse warp source i (channels 3i..3i+2 of the stack)
                    # into the target frame.
                    curr_proj_image = projective_inverse_warp(
                        curr_src_image_stack[:, :, :, 3*i:3*(i+1)],
                        tf.squeeze(pred_depth[s], axis=3),
                        pred_poses[:, i, :],
                        intrinsics[:, s, :, :])
                    curr_proj_error = tf.abs(curr_proj_image - curr_tgt_image)
                    if use_exp:
                        # Two logit channels per source; cross-entropy
                        # against the all-explainable reference regularizes
                        # the mask.
                        curr_exp_logits = tf.slice(pred_exp_logits[s],
                                                   [0, 0, 0, i*2],
                                                   [-1, -1, -1, 2])
                        exp_loss += opt.explain_reg_weight * \
                            self.compute_exp_reg_loss(curr_exp_logits,
                                                      ref_exp_mask)
                        curr_exp = tf.nn.softmax(curr_exp_logits)
                        curr_exp_mask = tf.expand_dims(curr_exp[:, :, :, 1], -1)
                        # Photo-consistency loss weighted by explainability.
                        pixel_loss += tf.reduce_mean(curr_proj_error *
                                                     curr_exp_mask)
                        exp_masks.append(curr_exp_mask)
                    else:
                        pixel_loss += tf.reduce_mean(curr_proj_error)
                    proj_images.append(curr_proj_image)
                    proj_errors.append(curr_proj_error)
                tgt_image_all.append(curr_tgt_image)
                src_image_stack_all.append(curr_src_image_stack)
                proj_image_stack_all.append(tf.concat(proj_images, axis=3))
                proj_error_stack_all.append(tf.concat(proj_errors, axis=3))
                if use_exp:
                    exp_mask_stack_all.append(tf.concat(exp_masks, axis=3))
            total_loss = pixel_loss + smooth_loss + exp_loss

        with tf.name_scope("train_op"):
            optim = tf.train.AdamOptimizer(opt.learning_rate, opt.beta1)
            self.train_op = slim.learning.create_train_op(total_loss, optim)
            # Step counter included in the checkpoints written by train();
            # advanced once per iteration through incr_global_step.
            self.global_step = tf.Variable(0,
                                           name='global_step',
                                           trainable=False)
            self.incr_global_step = tf.assign(self.global_step,
                                              self.global_step+1)

        # Collect tensors that are useful later (e.g. tf summary)
        self.pred_depth = pred_depth
        self.pred_poses = pred_poses
        self.steps_per_epoch = loader.steps_per_epoch
        self.total_loss = total_loss
        self.pixel_loss = pixel_loss
        self.exp_loss = exp_loss
        self.smooth_loss = smooth_loss
        self.tgt_image_all = tgt_image_all
        self.src_image_stack_all = src_image_stack_all
        self.proj_image_stack_all = proj_image_stack_all
        self.proj_error_stack_all = proj_error_stack_all
        self.exp_mask_stack_all = exp_mask_stack_all

    def get_reference_explain_mask(self, downscaling):
        """Return the constant target for the explainability regularizer.

        Shape is (batch_size, H / 2**downscaling, W / 2**downscaling, 2), with
        every pixel set to the one-hot label [0, 1] (float32).
        """
        opt = self.opt
        height, width = self._scaled_size(downscaling)
        # np.tile promotes [0, 1] to shape (1, 1, 1, 2) and repeats it.
        ref_exp_mask = np.tile(np.array([0, 1]),
                               (opt.batch_size, height, width, 1))
        return tf.constant(ref_exp_mask, dtype=tf.float32)

    def compute_exp_reg_loss(self, pred, ref):
        """Mean softmax cross-entropy between 2-class logits and labels.

        Both tensors are flattened to (-1, 2) before the loss.
        """
        l = tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.reshape(ref, [-1, 2]),
            logits=tf.reshape(pred, [-1, 2]))
        return tf.reduce_mean(l)

    def compute_smooth_loss(self, pred_disp):
        """Second-order smoothness penalty on a rank-4 disparity map.

        Sums the mean absolute values of the four second-order finite
        differences (xx, xy, yx, yy) taken over axes 1 (y) and 2 (x).
        """
        def gradient(pred):
            # Forward differences along axis 1 (y) and axis 2 (x).
            D_dy = pred[:, 1:, :, :] - pred[:, :-1, :, :]
            D_dx = pred[:, :, 1:, :] - pred[:, :, :-1, :]
            return D_dx, D_dy
        dx, dy = gradient(pred_disp)
        dx2, dxdy = gradient(dx)
        dydx, dy2 = gradient(dy)
        return tf.reduce_mean(tf.abs(dx2)) + \
               tf.reduce_mean(tf.abs(dxdy)) + \
               tf.reduce_mean(tf.abs(dydx)) + \
               tf.reduce_mean(tf.abs(dy2))

    def collect_summaries(self):
        """Register TensorBoard summaries for losses, images and poses.

        Reads the tensors stored by build_train_graph(), so it must run after
        it.
        """
        opt = self.opt
        tf.summary.scalar("total_loss", self.total_loss)
        tf.summary.scalar("pixel_loss", self.pixel_loss)
        tf.summary.scalar("smooth_loss", self.smooth_loss)
        tf.summary.scalar("exp_loss", self.exp_loss)
        for s in range(opt.num_scales):
            tf.summary.histogram("scale%d_depth" % s, self.pred_depth[s])
            tf.summary.image('scale%d_disparity_image' % s, 1./self.pred_depth[s])
            tf.summary.image('scale%d_target_image' % s,
                             self.deprocess_image(self.tgt_image_all[s]))
            for i in range(opt.num_source):
                rgb = slice(3 * i, 3 * (i + 1))  # channels of source i
                if opt.explain_reg_weight > 0:
                    tf.summary.image(
                        'scale%d_exp_mask_%d' % (s, i),
                        tf.expand_dims(self.exp_mask_stack_all[s][:, :, :, i], -1))
                tf.summary.image(
                    'scale%d_source_image_%d' % (s, i),
                    self.deprocess_image(self.src_image_stack_all[s][:, :, :, rgb]))
                tf.summary.image(
                    'scale%d_projected_image_%d' % (s, i),
                    self.deprocess_image(self.proj_image_stack_all[s][:, :, :, rgb]))
                # Shift the error by -1 and clip to [-1, 1] so deprocess_image
                # maps it into the displayable uint8 range.
                tf.summary.image(
                    'scale%d_proj_error_%d' % (s, i),
                    self.deprocess_image(tf.clip_by_value(
                        self.proj_error_stack_all[s][:, :, :, rgb] - 1, -1, 1)))
        # One histogram per pose component, in the order of the last axis.
        for k, name in enumerate(["tx", "ty", "tz", "rx", "ry", "rz"]):
            tf.summary.histogram(name, self.pred_poses[:, :, k])

    def train(self, opt):
        """Build the training graph and run the optimization loop.

        Args:
            opt: options object read by attribute. Uses seq_length,
                checkpoint_dir, continue_train, init_checkpoint_file,
                max_steps, summary_freq, save_latest_freq, and everything
                build_train_graph reads. num_source and num_scales are
                written onto it.

        Runs exactly opt.max_steps iterations. Every save_latest_freq
        iterations the 'latest' checkpoint is written. A numbered checkpoint
        is written whenever the global step reaches an epoch boundary.
        """
        opt.num_source = opt.seq_length - 1
        # TODO: currently fixed to 4
        opt.num_scales = 4
        self.opt = opt
        self.build_train_graph()
        self.collect_summaries()
        with tf.name_scope("parameter_count"):
            parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v))
                                             for v in tf.trainable_variables()])
        self.saver = tf.train.Saver(list(tf.model_variables()) +
                                    [self.global_step],
                                    max_to_keep=10)
        # The Supervisor only manages the session. Checkpoints and summaries
        # are written explicitly below (saver=None, save_summaries_secs=0).
        sv = tf.train.Supervisor(logdir=opt.checkpoint_dir,
                                 save_summaries_secs=0,
                                 saver=None)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with sv.managed_session(config=config) as sess:
            print('Trainable variables: ')
            for var in tf.trainable_variables():
                print(var.name)
            print("parameter_count =", sess.run(parameter_count))
            if opt.continue_train:
                if opt.init_checkpoint_file is None:
                    checkpoint = tf.train.latest_checkpoint(opt.checkpoint_dir)
                else:
                    checkpoint = opt.init_checkpoint_file
                print("Resume training from previous checkpoint: %s" % checkpoint)
                self.saver.restore(sess, checkpoint)
            start_time = time.time()
            for step in range(1, opt.max_steps + 1):
                fetches = {
                    "train": self.train_op,
                    # tf.assign yields the post-increment value. Fetching
                    # self.global_step in the same run would be unordered
                    # relative to this assignment.
                    "incr_global_step": self.incr_global_step,
                }
                log_now = step % opt.summary_freq == 0
                if log_now:
                    fetches["loss"] = self.total_loss
                    fetches["summary"] = sv.summary_op
                results = sess.run(fetches)
                gs = results["incr_global_step"]
                if log_now:
                    sv.summary_writer.add_summary(results["summary"], gs)
                    train_epoch, train_step = self._epoch_progress(gs)
                    print("Epoch: [%2d] [%5d/%5d] time: %4.4f/it loss: %.3f" \
                            % (train_epoch, train_step, self.steps_per_epoch, \
                                (time.time() - start_time)/opt.summary_freq,
                                results["loss"]))
                    start_time = time.time()
                if step % opt.save_latest_freq == 0:
                    self.save(sess, opt.checkpoint_dir, 'latest')
                # Key epoch checkpoints on the (restored) global step, not the
                # per-run counter, so they stay aligned after resuming.
                if gs % self.steps_per_epoch == 0:
                    self.save(sess, opt.checkpoint_dir, gs)

    def build_depth_test_graph(self):
        """Build the inference graph for single-image depth prediction.

        Feeds a uint8 placeholder of shape (batch_size, img_height,
        img_width, 3). Only the scale-0 (finest) depth map is exposed.
        """
        input_uint8 = tf.placeholder(tf.uint8, [self.batch_size,
                    self.img_height, self.img_width, 3], name='raw_input')
        input_mc = self.preprocess_image(input_uint8)
        with tf.name_scope("depth_prediction"):
            pred_disp, depth_net_endpoints = disp_net(
                input_mc, is_training=False)
            # Only the first scale is returned, so only invert that one.
            pred_depth = 1./pred_disp[0]
        self.inputs = input_uint8
        self.pred_depth = pred_depth
        self.depth_epts = depth_net_endpoints

    def build_pose_test_graph(self):
        """Build the inference graph for relative-pose prediction.

        The uint8 placeholder holds seq_length frames placed side by side
        along the width axis. DataLoader.batch_unpack_image_sequence splits
        it into the target frame and the stacked source frames.
        """
        input_uint8 = tf.placeholder(tf.uint8, [self.batch_size,
            self.img_height, self.img_width * self.seq_length, 3],
            name='raw_input')
        input_mc = self.preprocess_image(input_uint8)
        loader = DataLoader()
        tgt_image, src_image_stack = \
            loader.batch_unpack_image_sequence(
                input_mc, self.img_height, self.img_width, self.num_source)
        with tf.name_scope("pose_prediction"):
            pred_poses, _, _ = pose_exp_net(
                tgt_image, src_image_stack, do_exp=False, is_training=False)
        self.inputs = input_uint8
        self.pred_poses = pred_poses

    def preprocess_image(self, image):
        """Convert an image tensor to float32 in [-1, 1].

        Assumes a uint8 input: convert_image_dtype rescales it to [0, 1]
        before the affine map to [-1, 1].
        """
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        return image * 2. - 1.

    def deprocess_image(self, image):
        """Invert preprocess_image: map float32 [-1, 1] back to uint8."""
        image = (image + 1.)/2.
        return tf.image.convert_image_dtype(image, dtype=tf.uint8)

    def setup_inference(self,
                        img_height,
                        img_width,
                        mode,
                        seq_length=3,
                        batch_size=1):
        """Store the inference configuration and build the graph for `mode`.

        mode 'depth' builds the depth graph; mode 'pose' also records
        seq_length and num_source (= seq_length - 1) and builds the pose
        graph. Any other mode builds nothing.
        """
        self.img_height = img_height
        self.img_width = img_width
        self.mode = mode
        self.batch_size = batch_size
        if self.mode == 'depth':
            self.build_depth_test_graph()
        elif self.mode == 'pose':
            self.seq_length = seq_length
            self.num_source = seq_length - 1
            self.build_pose_test_graph()

    def inference(self, inputs, sess, mode='depth'):
        """Run the inference graph on `inputs` and return a results dict.

        The dict is keyed 'depth' or 'pose' according to `mode`. For any
        other mode nothing is fetched.
        """
        fetches = {}
        if mode == 'depth':
            fetches['depth'] = self.pred_depth
        elif mode == 'pose':
            fetches['pose'] = self.pred_poses
        return sess.run(fetches, feed_dict={self.inputs: inputs})

    def save(self, sess, checkpoint_dir, step):
        """Write a checkpoint into `checkpoint_dir`.

        step == 'latest' writes the fixed prefix 'model.latest'. Any other
        value is passed to Saver.save as global_step, producing 'model-<step>'.
        """
        model_name = 'model'
        print(" [*] Saving checkpoint to %s..." % checkpoint_dir)
        if step == 'latest':
            self.saver.save(sess,
                            os.path.join(checkpoint_dir, model_name + '.latest'))
        else:
            self.saver.save(sess,
                            os.path.join(checkpoint_dir, model_name),
                            global_step=step)