-
Notifications
You must be signed in to change notification settings - Fork 23
/
saml_func.py
336 lines (276 loc) · 21.4 KB
/
saml_func.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
from __future__ import print_function
import numpy as np
import sys
import tensorflow as tf
from tensorflow.image import resize_images
# try:
# import special_grads
# except KeyError as e:
# print('WARN: Cannot define MaxPoolGrad, likely already defined for this version of tensorflow: %s' % e, file=sys.stderr)
from tensorflow.python.platform import flags
from layer import conv_block, deconv_block, fc, max_pool, concat2d
from utils import xent, kd, _get_segmentation_cost, _get_compactness_cost
class SAML:
    """SAML meta-learning segmentation model, built as a TensorFlow 1.x graph.

    Typical use: construct the object, call ``construct_model_train()`` (which
    creates the shared U-Net weights and the training ops), then optionally
    ``construct_model_test()`` which reuses those weights.
    """

    def __init__(self, args):
        """Store hyper-parameters; call construct_model_*() after initializing SAML.

        Args:
            args: configuration object. Read here: meta_batch_size,
                test_batch_size, volume_size, n_class, compactness_loss_weight,
                smoothness_loss_weight, margin. construct_model_train also reads
                gradients_clip_value, inner_lr, outer_lr and metric_lr.
        """
        self.args = args
        self.batch_size = args.meta_batch_size
        self.test_batch_size = args.test_batch_size
        self.volume_size = args.volume_size
        self.n_class = args.n_class
        self.compactness_loss_weight = args.compactness_loss_weight
        self.smoothness_loss_weight = args.smoothness_loss_weight
        self.margin = args.margin
        # The network, its weight factory and the loss functions are pluggable
        # attributes; this class wires in the U-Net and the utils losses.
        self.forward = self.forward_unet
        self.construct_weights = self.construct_unet_weights
        self.seg_loss = _get_segmentation_cost
        self.get_compactness_cost = _get_compactness_cost

    def construct_model_train(self, prefix='metatrain_'):
        """Build the meta-learning training graph.

        Two meta-train batches (``a`` and ``a1``) drive one first-order inner
        gradient step on the U-Net weights. The meta-test batch ``b`` is then
        evaluated with those fast weights, and its segmentation, compactness
        and smoothness losses form the meta (target) objective.

        Args:
            prefix: prefix for the scalar summary names. Optimisation ops are
                only created when it contains ``'train'``.
        """
        image_shape = [self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]]
        label_shape = [self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class]

        # a / a1: meta-train batches for the inner update; b: meta-test batch
        # for the meta loss; *_group: inputs of the contour-embedding metric loss.
        self.inputa = tf.placeholder("float", shape=image_shape)
        self.labela = tf.placeholder("float", shape=label_shape)
        self.inputa1 = tf.placeholder("float", shape=image_shape)
        self.labela1 = tf.placeholder("float", shape=label_shape)
        self.inputb = tf.placeholder("float", shape=image_shape)
        self.labelb = tf.placeholder("float", shape=label_shape)
        self.input_group = tf.placeholder("float", shape=image_shape)
        self.label_group = tf.placeholder("float", shape=label_shape)
        self.contour_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], 1])
        self.metric_label_group = tf.placeholder(tf.int32, shape=[self.batch_size, 1])
        # Passed to every forward() call as is_training; defaults to True.
        self.training_mode = tf.placeholder_with_default(True, shape=None, name="training_mode_for_bn_moving")
        self.clip_value = self.args.gradients_clip_value
        # NOTE(review): KEEP_PROB is not consumed anywhere in this class.
        self.KEEP_PROB = tf.placeholder(tf.float32)

        with tf.variable_scope('model', reuse=None) as training_scope:
            if hasattr(self, 'weights'):
                print('weights already defined')
                training_scope.reuse_variables()
                weights = self.weights
            else:
                # Define the weights
                self.weights = weights = self.construct_weights()

            def task_metalearn(inp, reuse=True):
                """Compute the meta-train and meta-test losses for one episode.

                Args:
                    inp: 9-tuple of tensors, in the order of ``input_tensors``
                        built below.
                    reuse: unused; kept for signature compatibility.

                Returns:
                    [seg loss b, weighted compactness loss b, weighted smoothness
                    loss b, predicted mask b, boundary b, length, area,
                    seg loss a, seg loss a1]
                """
                (inputa, inputa1, inputb, labela, labela1, labelb,
                 input_group, contour_group, metric_label_group) = inp

                # Conventional segmentation loss on both meta-train batches.
                task_outputa, _, _ = self.forward(inputa, weights, is_training=self.training_mode)
                task_lossa = self.seg_loss(task_outputa, labela)
                task_outputa1, _, _ = self.forward(inputa1, weights, is_training=self.training_mode)
                task_lossa1 = self.seg_loss(task_outputa1, labela1)

                # Inner update: one plain gradient-descent step on the averaged
                # meta-train loss. stop_gradient gives the first-order
                # approximation (no second derivatives reach the meta loss).
                # Each per-tensor gradient is norm-clipped before the step.
                keys = list(weights.keys())
                grads = tf.gradients((task_lossa + task_lossa1) / 2.0, [weights[key] for key in keys])
                grads = [tf.stop_gradient(grad) for grad in grads]
                fast_weights = {
                    key: weights[key] - self.inner_lr * tf.clip_by_norm(grad, clip_norm=self.clip_value)
                    for key, grad in zip(keys, grads)
                }

                # Meta-test segmentation and compactness losses with fast weights.
                task_outputb, task_predmaskb, _ = self.forward(inputb, fast_weights, is_training=self.training_mode)
                task_lossb = self.seg_loss(task_outputb, labelb)
                compactness_loss_b, length, area, boundary_b = self.get_compactness_cost(task_outputb, labelb)
                compactness_loss_b = self.compactness_loss_weight * compactness_loss_b

                # Smoothness loss: average the fast-weight embeddings inside each
                # contour mask, project them with the metric net, and apply a
                # semi-hard triplet loss using metric_label_group as labels.
                _, _, embeddings = self.forward(input_group, fast_weights, is_training=self.training_mode)
                coutour_embeddings = self.extract_coutour_embedding(contour_group, embeddings)
                metric_embeddings = self.forward_metric_net(coutour_embeddings)
                # Debug output: static shapes, printed once at graph-build time.
                print(metric_label_group.shape)
                print(metric_embeddings.shape)
                smoothness_loss_b = tf.contrib.losses.metric_learning.triplet_semihard_loss(labels=metric_label_group[..., 0], embeddings=metric_embeddings, margin=self.margin)
                smoothness_loss_b = self.smoothness_loss_weight * smoothness_loss_b

                return [task_lossb, compactness_loss_b, smoothness_loss_b, task_predmaskb, boundary_b, length, area, task_lossa, task_lossa1]

            self.global_step = tf.Variable(0, trainable=False)
            # Learning rates are non-trainable variables so they can be
            # reassigned at run time.
            self.inner_lr = tf.Variable(self.args.inner_lr, trainable=False)
            self.outer_lr = tf.Variable(self.args.outer_lr, trainable=False)
            self.metric_lr = tf.Variable(self.args.metric_lr, trainable=False)

            input_tensors = (self.inputa, self.inputa1, self.inputb, self.labela, self.labela1, self.labelb, self.input_group, self.contour_group, self.metric_label_group)
            result = task_metalearn(inp=input_tensors)
            (self.seg_loss_b, self.compactness_loss_b, self.smoothness_loss_b, self.task_predmaskb,
             self.boundary_b, self.length, self.area, self.seg_loss_a, self.seg_loss_a1) = result

        ## Performance & Optimization
        if 'train' in prefix:
            self.source_loss = (self.seg_loss_a + self.seg_loss_a1) / 2.0
            self.target_loss = self.seg_loss_b + self.compactness_loss_b + self.smoothness_loss_b
            # Variables whose name has a 'metric' path component belong to the
            # metric net; everything else is trained as the segmentor.
            var_list_segmentor = [v for v in tf.trainable_variables() if 'metric' not in v.name.split('/')]
            var_list_metric = [v for v in tf.trainable_variables() if 'metric' in v.name.split('/')]
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self.task_train_op = tf.train.AdamOptimizer(learning_rate=self.inner_lr).minimize(self.source_loss, global_step=self.global_step)
                optimizer = tf.train.AdamOptimizer(self.outer_lr)
                gvs = optimizer.compute_gradients(self.target_loss, var_list=var_list_segmentor)
                # Clip each meta gradient. Variables not connected to the target
                # loss get a None gradient, which clip_by_norm cannot take but
                # apply_gradients skips, so None is passed through.
                gvs = [(None if grad is None else tf.clip_by_norm(grad, clip_norm=self.clip_value), var) for grad, var in gvs]
                self.meta_train_op = optimizer.apply_gradients(gvs)
                self.metric_train_op = tf.train.AdamOptimizer(self.metric_lr).minimize(self.smoothness_loss_b, var_list=var_list_metric)

        ## Summaries
        tf.summary.scalar(prefix+'source_1 loss', self.seg_loss_a)
        tf.summary.scalar(prefix+'source_2 loss', self.seg_loss_a1)
        tf.summary.scalar(prefix+'target_loss', self.seg_loss_b)
        tf.summary.scalar(prefix+'target_coutour_loss', self.compactness_loss_b)
        tf.summary.scalar(prefix+'target_length', self.length)
        tf.summary.scalar(prefix+'target_area', self.area)
        tf.summary.image("meta_test_mask", tf.expand_dims(tf.cast(self.task_predmaskb, tf.float32), 3))
        tf.summary.image("meta_test_gth", tf.expand_dims(tf.cast(self.labelb[:, :, :, 1], tf.float32), 3))
        tf.summary.image("meta_test_image", tf.expand_dims(tf.cast(self.inputb[:, :, :, 1], tf.float32), 3))
        tf.summary.image("meta_test_boundary", tf.expand_dims(tf.cast(self.boundary_b[:, :, :], tf.float32), 3))
        tf.summary.image("meta_test_ct_bg_sample", tf.expand_dims(tf.cast(self.contour_group[:, :, :, 0], tf.float32), 3))
        tf.summary.image("meta_input_group", tf.expand_dims(tf.cast(self.input_group[:, :, :, 1], tf.float32), 3))
        tf.summary.image("label_group", tf.expand_dims(tf.cast(self.label_group[:, :, :, 1], tf.float32), 3))

    def extract_coutour_embedding(self, coutour, embeddings):
        """Average ``embeddings`` over the pixels selected by the ``coutour`` mask.

        Args:
            coutour: per-pixel mask, fed from ``contour_group`` ([batch, H, W, 1]),
                broadcast against ``embeddings``.
            embeddings: feature map, presumably [batch, H, W, C] — TODO confirm.

        Returns:
            Mask-weighted mean embedding per sample (spatial axes 1 and 2 summed out).
        """
        coutour_embeddings = coutour * embeddings
        # Clamp the mask area so an empty contour yields a zero embedding
        # instead of 0/0 = NaN, which would poison the triplet loss and every
        # gradient flowing from it. For a non-negative, non-empty mask the
        # result is unchanged.
        mask_area = tf.maximum(tf.reduce_sum(coutour, [1, 2]), 1e-8)
        return tf.reduce_sum(coutour_embeddings, [1, 2]) / mask_area

    def construct_model_test(self, prefix='test'):
        """Build the test graph, reusing the weights created for training.

        Args:
            prefix: unused; kept for interface compatibility.

        Raises:
            ValueError: if construct_model_train() has not created ``self.weights``.
        """
        self.test_input = tf.placeholder("float", shape=[self.test_batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
        self.test_label = tf.placeholder("float", shape=[self.test_batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
        with tf.variable_scope('model', reuse=None) as testing_scope:
            if hasattr(self, 'weights'):
                testing_scope.reuse_variables()
                weights = self.weights
            else:
                raise ValueError('Weights not initilized. Create training model before testing model')
            # NOTE(review): forward() is called with its default is_training=True;
            # how that affects normalisation at test time depends on
            # layer.conv_block — confirm this is intended.
            outputs, mask, _ = self.forward(self.test_input, weights)
            losses = self.seg_loss(outputs, self.test_label)
        # outputs are the softmax probabilities, mask the per-pixel argmax.
        self.outputs = mask
        self.test_loss = losses

    def forward_metric_net(self, x):
        """Two fully connected layers projecting contour embeddings for the triplet loss.

        Variables are created in a 'metric' variable scope with AUTO_REUSE;
        construct_model_train relies on that scope name to separate them from
        the segmentor's variables.

        Args:
            x: [batch, 48] embeddings. 48 = 32 (conv82) + 16 (conv92) channels
                concatenated in forward_unet.

        Returns:
            Output of the second ``fc`` layer (presumably [batch, 16], given w2).
        """
        with tf.variable_scope('metric', reuse=tf.AUTO_REUSE):
            w1 = tf.get_variable('w1', shape=[48, 24])
            b1 = tf.get_variable('b1', shape=[24])
            hidden = fc(x, w1, b1, activation='leaky_relu')
            w2 = tf.get_variable('w2', shape=[24, 16])
            b2 = tf.get_variable('b2', shape=[16])
            return fc(hidden, w2, b2, activation='leaky_relu')

    def construct_unet_weights(self):
        """Create the U-Net convolution/deconvolution weights.

        Variable scopes and names are part of the checkpoint format and must not
        change. The output layer has ``self.n_class`` channels so that logits
        match the label placeholders (it was previously hard-coded to 2).

        Returns:
            dict mapping weight names (e.g. 'conv11_weights') to tf.Variables.
        """
        weights = {}
        conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=tf.float32)
        with tf.variable_scope('conv1'):
            weights['conv11_weights'] = tf.get_variable('weights', shape=[5, 5, 3, 16], initializer=conv_initializer)
            weights['conv11_biases'] = tf.get_variable('biases', [16])
            weights['conv12_weights'] = tf.get_variable('weights2', shape=[5, 5, 16, 16], initializer=conv_initializer)
            weights['conv12_biases'] = tf.get_variable('biases2', [16])
        with tf.variable_scope('conv2'):
            weights['conv21_weights'] = tf.get_variable('weights', shape=[5, 5, 16, 32], initializer=conv_initializer)
            weights['conv21_biases'] = tf.get_variable('biases', [32])
            weights['conv22_weights'] = tf.get_variable('weights2', shape=[5, 5, 32, 32], initializer=conv_initializer)
            weights['conv22_biases'] = tf.get_variable('biases2', [32])
        ## Network has downsample here
        with tf.variable_scope('conv3'):
            weights['conv31_weights'] = tf.get_variable('weights', shape=[3, 3, 32, 64], initializer=conv_initializer)
            weights['conv31_biases'] = tf.get_variable('biases', [64])
            weights['conv32_weights'] = tf.get_variable('weights2', shape=[3, 3, 64, 64], initializer=conv_initializer)
            weights['conv32_biases'] = tf.get_variable('biases2', [64])
        with tf.variable_scope('conv4'):
            weights['conv41_weights'] = tf.get_variable('weights', shape=[3, 3, 64, 128], initializer=conv_initializer)
            weights['conv41_biases'] = tf.get_variable('biases', [128])
            weights['conv42_weights'] = tf.get_variable('weights2', shape=[3, 3, 128, 128], initializer=conv_initializer)
            weights['conv42_biases'] = tf.get_variable('biases2', [128])
        ## Network has downsample here
        with tf.variable_scope('conv5'):
            weights['conv51_weights'] = tf.get_variable('weights', shape=[3, 3, 128, 256], initializer=conv_initializer)
            weights['conv51_biases'] = tf.get_variable('biases', [256])
            weights['conv52_weights'] = tf.get_variable('weights2', shape=[3, 3, 256, 256], initializer=conv_initializer)
            weights['conv52_biases'] = tf.get_variable('biases2', [256])
        # Decoder stages: 'weights0'/'biases0' feed deconv_block (their biases
        # use the Xavier initializer); the two convs that follow take a
        # concatenation with twice the deconv output channels.
        with tf.variable_scope('deconv6'):
            weights['deconv6_weights'] = tf.get_variable('weights0', shape=[3, 3, 128, 256], initializer=conv_initializer)
            weights['deconv6_biases'] = tf.get_variable('biases0', shape=[128], initializer=conv_initializer)
            weights['conv61_weights'] = tf.get_variable('weights', shape=[3, 3, 256, 128], initializer=conv_initializer)
            weights['conv61_biases'] = tf.get_variable('biases', [128])
            weights['conv62_weights'] = tf.get_variable('weights2', shape=[3, 3, 128, 128], initializer=conv_initializer)
            weights['conv62_biases'] = tf.get_variable('biases2', [128])
        with tf.variable_scope('deconv7'):
            weights['deconv7_weights'] = tf.get_variable('weights0', shape=[3, 3, 64, 128], initializer=conv_initializer)
            weights['deconv7_biases'] = tf.get_variable('biases0', shape=[64], initializer=conv_initializer)
            weights['conv71_weights'] = tf.get_variable('weights', shape=[3, 3, 128, 64], initializer=conv_initializer)
            weights['conv71_biases'] = tf.get_variable('biases', [64])
            weights['conv72_weights'] = tf.get_variable('weights2', shape=[3, 3, 64, 64], initializer=conv_initializer)
            weights['conv72_biases'] = tf.get_variable('biases2', [64])
        with tf.variable_scope('deconv8'):
            weights['deconv8_weights'] = tf.get_variable('weights0', shape=[3, 3, 32, 64], initializer=conv_initializer)
            weights['deconv8_biases'] = tf.get_variable('biases0', shape=[32], initializer=conv_initializer)
            weights['conv81_weights'] = tf.get_variable('weights', shape=[3, 3, 64, 32], initializer=conv_initializer)
            weights['conv81_biases'] = tf.get_variable('biases', [32])
            weights['conv82_weights'] = tf.get_variable('weights2', shape=[3, 3, 32, 32], initializer=conv_initializer)
            weights['conv82_biases'] = tf.get_variable('biases2', [32])
        with tf.variable_scope('deconv9'):
            weights['deconv9_weights'] = tf.get_variable('weights0', shape=[3, 3, 16, 32], initializer=conv_initializer)
            weights['deconv9_biases'] = tf.get_variable('biases0', shape=[16], initializer=conv_initializer)
            weights['conv91_weights'] = tf.get_variable('weights', shape=[3, 3, 32, 16], initializer=conv_initializer)
            weights['conv91_biases'] = tf.get_variable('biases', [16])
            weights['conv92_weights'] = tf.get_variable('weights2', shape=[3, 3, 16, 16], initializer=conv_initializer)
            weights['conv92_biases'] = tf.get_variable('biases2', [16])
        with tf.variable_scope('output'):
            weights['output_weights'] = tf.get_variable('weights', shape=[3, 3, 16, self.n_class], initializer=conv_initializer)
            weights['output_biases'] = tf.get_variable('biases', [self.n_class])
        return weights

    def forward_unet(self, inp, weights, is_training=True):
        """Run the U-Net on ``inp`` with the given weight dict.

        Intermediate tensors are stored on ``self`` (conv11 ... logits), so each
        call overwrites those from the previous call.

        Args:
            inp: input batch, [batch, H, W, channels] as fed by the placeholders.
            weights: dict produced by construct_unet_weights (or fast weights
                derived from it).
            is_training: forwarded to every conv_block/deconv_block call.

        Returns:
            (softmax probabilities, per-pixel argmax mask, per-pixel embeddings
            formed by concatenating upsampled conv82 with conv92).
        """
        # Spatial sizes in the comments below are for a 384x384 input.
        self.conv11 = conv_block(inp, weights['conv11_weights'], weights['conv11_biases'], scope='conv1/bn1', bn=False, is_training=is_training)
        self.conv12 = conv_block(self.conv11, weights['conv12_weights'], weights['conv12_biases'], scope='conv1/bn2', is_training=is_training)
        self.pool11 = max_pool(self.conv12, 2, 2, 2, 2, padding='VALID')
        # 192x192x16
        self.conv21 = conv_block(self.pool11, weights['conv21_weights'], weights['conv21_biases'], scope='conv2/bn1', is_training=is_training)
        self.conv22 = conv_block(self.conv21, weights['conv22_weights'], weights['conv22_biases'], scope='conv2/bn2', is_training=is_training)
        self.pool21 = max_pool(self.conv22, 2, 2, 2, 2, padding='VALID')
        # 96x96x32
        self.conv31 = conv_block(self.pool21, weights['conv31_weights'], weights['conv31_biases'], scope='conv3/bn1', is_training=is_training)
        self.conv32 = conv_block(self.conv31, weights['conv32_weights'], weights['conv32_biases'], scope='conv3/bn2', is_training=is_training)
        self.pool31 = max_pool(self.conv32, 2, 2, 2, 2, padding='VALID')
        # 48x48x64
        self.conv41 = conv_block(self.pool31, weights['conv41_weights'], weights['conv41_biases'], scope='conv4/bn1', is_training=is_training)
        self.conv42 = conv_block(self.conv41, weights['conv42_weights'], weights['conv42_biases'], scope='conv4/bn2', is_training=is_training)
        self.pool41 = max_pool(self.conv42, 2, 2, 2, 2, padding='VALID')
        # 24x24x128
        self.conv51 = conv_block(self.pool41, weights['conv51_weights'], weights['conv51_biases'], scope='conv5/bn1', is_training=is_training)
        self.conv52 = conv_block(self.conv51, weights['conv52_weights'], weights['conv52_biases'], scope='conv5/bn2', is_training=is_training)
        # 24x24x256

        # Decoder: upsample and halve the channel count at each stage.
        # NOTE(review): every stage concatenates the upsampled tensor with
        # itself (e.g. concat2d(deconv6, deconv6)), so no encoder skip
        # connection is used. The encoder tensor at the same resolution
        # (conv42/conv32/conv22/conv12) has the channel count the next conv
        # expects. Left unchanged to preserve trained-model behaviour —
        # confirm whether this is intentional.
        self.deconv6 = deconv_block(self.conv52, weights['deconv6_weights'], weights['deconv6_biases'], scope='deconv/bn6', is_training=is_training)
        # 48x48x128
        self.sum6 = concat2d(self.deconv6, self.deconv6)
        self.conv61 = conv_block(self.sum6, weights['conv61_weights'], weights['conv61_biases'], scope='conv6/bn1', is_training=is_training)
        self.conv62 = conv_block(self.conv61, weights['conv62_weights'], weights['conv62_biases'], scope='conv6/bn2', is_training=is_training)
        # 48x48x128
        self.deconv7 = deconv_block(self.conv62, weights['deconv7_weights'], weights['deconv7_biases'], scope='deconv/bn7', is_training=is_training)
        # 96x96x64
        self.sum7 = concat2d(self.deconv7, self.deconv7)
        self.conv71 = conv_block(self.sum7, weights['conv71_weights'], weights['conv71_biases'], scope='conv7/bn1', is_training=is_training)
        self.conv72 = conv_block(self.conv71, weights['conv72_weights'], weights['conv72_biases'], scope='conv7/bn2', is_training=is_training)
        # 96x96x64
        self.deconv8 = deconv_block(self.conv72, weights['deconv8_weights'], weights['deconv8_biases'], scope='deconv/bn8', is_training=is_training)
        # 192x192x32
        self.sum8 = concat2d(self.deconv8, self.deconv8)
        self.conv81 = conv_block(self.sum8, weights['conv81_weights'], weights['conv81_biases'], scope='conv8/bn1', is_training=is_training)
        self.conv82 = conv_block(self.conv81, weights['conv82_weights'], weights['conv82_biases'], scope='conv8/bn2', is_training=is_training)
        # 192x192x32
        self.deconv9 = deconv_block(self.conv82, weights['deconv9_weights'], weights['deconv9_biases'], scope='deconv/bn9', is_training=is_training)
        # 384x384x16
        self.sum9 = concat2d(self.deconv9, self.deconv9)
        self.conv91 = conv_block(self.sum9, weights['conv91_weights'], weights['conv91_biases'], scope='conv9/bn1', is_training=is_training)
        self.conv92 = conv_block(self.conv91, weights['conv92_weights'], weights['conv92_biases'], scope='conv9/bn2', is_training=is_training)
        # 384x384x16
        self.logits = conv_block(self.conv92, weights['output_weights'], weights['output_biases'], scope='outpu/bn', bn=False, is_training=is_training)
        # 384x384xn_class
        self.pred_prob = tf.nn.softmax(self.logits)  # shape [batch, w, h, num_classes]
        self.pred_compact = tf.argmax(self.pred_prob, axis=-1)  # shape [batch, w, h]

        # Per-pixel embedding: bring conv82 up to conv92's spatial size so the
        # two can be concatenated. The size was previously hard-coded to
        # 384x384; use the static size when known, else the dynamic one.
        static_hw = self.conv92.get_shape()[1:3]
        if static_hw.is_fully_defined():
            target_hw = static_hw.as_list()
        else:
            target_hw = tf.shape(self.conv92)[1:3]
        self.conv82_resize = tf.image.resize_images(self.conv82, target_hw, method=tf.image.ResizeMethod.BILINEAR, align_corners=False)
        self.embeddings = concat2d(self.conv82_resize, self.conv92)
        return self.pred_prob, self.pred_compact, self.embeddings