
Commit 055dab6

Merge pull request #7 from kkleidal/noblind
Removed blind spot in masked convolutions (see
2 parents 9a5c9a3 + dc5e67b commit 055dab6

5 files changed: +69 -40 lines changed

Diff for: autoencoder.py

+3 -3

@@ -27,11 +27,11 @@ def trainAE(conf, data):

         if os.path.exists(conf.ckpt_file):
             saver.restore(sess, conf.ckpt_file)
-            print "Model Restored"
+            print("Model Restored")

         # TODO The training part below and in main.py could be generalized
         if conf.epochs > 0:
-            print "Started Model Training..."
+            print("Started Model Training...")
             pointer = 0
             step = 0
             for i in range(conf.epochs):
@@ -45,7 +45,7 @@ def trainAE(conf, data):
                     writer.add_summary(summary, step)
                     step += 1

-                print "Epoch: %d, Cost: %f"%(i, l)
+                print("Epoch: %d, Cost: %f"%(i, l))
                 if (i+1)%10 == 0:
                     saver.save(sess, conf.ckpt_file)
                     generate_ae(sess, encoder_X, decoder_X, y, data, conf, str(i))

Diff for: layers.py

+47 -18

@@ -1,23 +1,43 @@
 import tensorflow as tf
 import numpy as np

-def get_weights(shape, name, mask=None):
+def get_weights(shape, name, horizontal, mask_mode='noblind', mask=None):
     weights_initializer = tf.contrib.layers.xavier_initializer()
     W = tf.get_variable(name, shape, tf.float32, weights_initializer)

     '''
     Use of masking to hide subsequent pixel values
     '''
     if mask:
-        filter_mid_x = shape[0]//2
-        filter_mid_y = shape[1]//2
+        filter_mid_y = shape[0]//2
+        filter_mid_x = shape[1]//2
         mask_filter = np.ones(shape, dtype=np.float32)
-        mask_filter[filter_mid_x, filter_mid_y+1:, :, :] = 0.
-        mask_filter[filter_mid_x+1:, :, :, :] = 0.
+        if mask_mode == 'noblind':
+            if horizontal:
+                # All rows after center must be zero
+                mask_filter[filter_mid_y+1:, :, :, :] = 0.0
+                # All columns after center in center row must be zero
+                mask_filter[filter_mid_y, filter_mid_x+1:, :, :] = 0.0
+            else:
+                if mask == 'a':
+                    # In the first layer, can ONLY access pixels above it
+                    mask_filter[filter_mid_y:, :, :, :] = 0.0
+                else:
+                    # In the second layer, can access pixels above or even with it.
+                    # Reason being that the pixels to the right or left of the current pixel
+                    # only have a receptive field of the layer above the current layer and up.
+                    mask_filter[filter_mid_y+1:, :, :, :] = 0.0
+
+            if mask == 'a':
+                # Center must be zero in first layer
+                mask_filter[filter_mid_y, filter_mid_x, :, :] = 0.0
+        else:
+            mask_filter[filter_mid_y, filter_mid_x+1:, :, :] = 0.
+            mask_filter[filter_mid_y+1:, :, :, :] = 0.

-        if mask == 'a':
-            mask_filter[filter_mid_x, filter_mid_y, :, :] = 0.
-
+            if mask == 'a':
+                mask_filter[filter_mid_y, filter_mid_x, :, :] = 0.
+
         W *= mask_filter
     return W
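The effect of the new masking is easiest to see on a concrete kernel. Below is a minimal NumPy sketch, not part of the commit, of the masks the 'noblind' branch above would produce for a hypothetical 3x3, single-channel filter; build_mask is an illustrative 2-D restatement of the logic in get_weights, not the function itself.

import numpy as np

def build_mask(kernel_hw, horizontal, mask):
    # Illustrative 2-D restatement of the 'noblind' branch of get_weights.
    mid_y, mid_x = kernel_hw[0] // 2, kernel_hw[1] // 2
    m = np.ones(kernel_hw, dtype=np.float32)
    if horizontal:
        m[mid_y + 1:, :] = 0.0       # rows below the center row
        m[mid_y, mid_x + 1:] = 0.0   # pixels to the right of center in the center row
    else:
        if mask == 'a':
            m[mid_y:, :] = 0.0       # first layer: strictly the rows above
        else:
            m[mid_y + 1:, :] = 0.0   # later layers: the center row is allowed too
    if mask == 'a':
        m[mid_y, mid_x] = 0.0        # the first layer never sees the current pixel
    return m

print(build_mask((3, 3), horizontal=False, mask='b'))
# [[1. 1. 1.]
#  [1. 1. 1.]
#  [0. 0. 0.]]
print(build_mask((3, 3), horizontal=True, mask='b'))
# [[1. 1. 1.]
#  [1. 1. 0.]
#  [0. 0. 0.]]

The vertical-stack mask keeps whole rows, so, stacked over layers, the vertical stack can reach any pixel in the rows above, including positions up and to the right; the horizontal-stack mask only has to add the pixels to the left in the current row.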

@@ -31,39 +51,48 @@ def max_pool_2x2(x):
     return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')

 class GatedCNN():
-    def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activation=True, conditional=None):
+    def __init__(self, W_shape, fan_in, horizontal, gated=True, payload=None, mask=None, activation=True, conditional=None, conditional_image=None):
         self.fan_in = fan_in
         in_dim = self.fan_in.get_shape()[-1]
         self.W_shape = [W_shape[0], W_shape[1], in_dim, W_shape[2]]
         self.b_shape = W_shape[2]

+        self.in_dim = in_dim
         self.payload = payload
         self.mask = mask
         self.activation = activation
         self.conditional = conditional
+        self.conditional_image = conditional_image
+        self.horizontal = horizontal

         if gated:
             self.gated_conv()
         else:
             self.simple_conv()

     def gated_conv(self):
-        W_f = get_weights(self.W_shape, "v_W", mask=self.mask)
-        W_g = get_weights(self.W_shape, "h_W", mask=self.mask)
+        W_f = get_weights(self.W_shape, "v_W", self.horizontal, mask=self.mask)
+        W_g = get_weights(self.W_shape, "h_W", self.horizontal, mask=self.mask)
+
+        b_f_total = get_bias(self.b_shape, "v_b")
+        b_g_total = get_bias(self.b_shape, "h_b")
         if self.conditional is not None:
             h_shape = int(self.conditional.get_shape()[1])
-            V_f = get_weights([h_shape, self.W_shape[3]], "v_V")
+            V_f = get_weights([h_shape, self.W_shape[3]], "v_V", self.horizontal)
             b_f = tf.matmul(self.conditional, V_f)
-            V_g = get_weights([h_shape, self.W_shape[3]], "h_V")
+            V_g = get_weights([h_shape, self.W_shape[3]], "h_V", self.horizontal)
             b_g = tf.matmul(self.conditional, V_g)

             b_f_shape = tf.shape(b_f)
             b_f = tf.reshape(b_f, (b_f_shape[0], 1, 1, b_f_shape[1]))
             b_g_shape = tf.shape(b_g)
             b_g = tf.reshape(b_g, (b_g_shape[0], 1, 1, b_g_shape[1]))
-        else:
-            b_f = get_bias(self.b_shape, "v_b")
-            b_g = get_bias(self.b_shape, "h_b")
+
+            b_f_total = b_f_total + b_f
+            b_g_total = b_g_total + b_g
+        if self.conditional_image is not None:
+            b_f_total = b_f_total + tf.layers.conv2d(self.conditional_image, self.in_dim, 1, use_bias=False, name="ci_f")
+            b_g_total = b_g_total + tf.layers.conv2d(self.conditional_image, self.in_dim, 1, use_bias=False, name="ci_g")

         conv_f = conv_op(self.fan_in, W_f)
         conv_g = conv_op(self.fan_in, W_g)
@@ -72,10 +101,10 @@ def gated_conv(self):
             conv_f += self.payload
             conv_g += self.payload

-        self.fan_out = tf.multiply(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g))
+        self.fan_out = tf.multiply(tf.tanh(conv_f + b_f_total), tf.sigmoid(conv_g + b_g_total))

     def simple_conv(self):
-        W = get_weights(self.W_shape, "W", mask=self.mask)
+        W = get_weights(self.W_shape, "W", self.horizontal, mask_mode="standard", mask=self.mask)
         b = get_bias(self.b_shape, "b")
         conv = conv_op(self.fan_in, W)
         if self.activation:
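For readers following the gated_conv changes above: the conditional vector and the new conditional_image now contribute to a running bias (b_f_total, b_g_total) instead of replacing the plain one. A rough NumPy sketch, not part of the commit, of how the tanh branch's total bias is assembled and fed into the gate; every shape, the 4-channel conditioning image, and the names C_f and labels are made up for illustration.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

batch, H, W, f_map, n_classes = 2, 28, 28, 16, 10

# Stand-ins for the two masked-convolution outputs (the v_W and h_W branches).
conv_f = np.random.randn(batch, H, W, f_map).astype(np.float32)
conv_g = np.random.randn(batch, H, W, f_map).astype(np.float32)

b_f_total = np.zeros(f_map, dtype=np.float32)   # plain per-feature-map bias ("v_b")
b_g_total = np.zeros(f_map, dtype=np.float32)   # plain per-feature-map bias ("h_b")

# Class-conditional term: one vector per sample, broadcast over every pixel.
labels = np.eye(n_classes, dtype=np.float32)[[3, 7]]       # one-hot labels for the batch
V_f = np.random.randn(n_classes, f_map).astype(np.float32)
b_f_total = b_f_total + (labels @ V_f).reshape(batch, 1, 1, f_map)

# Image-conditional term: a 1x1 convolution is a per-pixel linear map of the
# conditioning image, so it yields a different bias at every spatial location.
cond_img = np.random.randn(batch, H, W, 4).astype(np.float32)
C_f = np.random.randn(4, f_map).astype(np.float32)
b_f_total = b_f_total + cond_img @ C_f                      # (batch, H, W, f_map)

# Gated activation: a tanh feature branch modulated by a sigmoid gate
# (b_g_total would pick up its own conditional terms the same way).
fan_out = np.tanh(conv_f + b_f_total) * sigmoid(conv_g + b_g_total)
print(fan_out.shape)   # (2, 28, 28, 16)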

Diff for: main.py

+3 -3

@@ -21,10 +21,10 @@ def train(conf, data):
         sess.run(tf.initialize_all_variables())
         if os.path.exists(conf.ckpt_file):
             saver.restore(sess, conf.ckpt_file)
-            print "Model Restored"
+            print("Model Restored")

         if conf.epochs > 0:
-            print "Started Model Training..."
+            print("Started Model Training...")
             pointer = 0
             for i in range(conf.epochs):
                 for j in range(conf.num_batches):
@@ -39,7 +39,7 @@ def train(conf, data):
                     if conf.conditional is True:
                         data_dict[model.h] = batch_y
                     _, cost = sess.run([optimizer, model.loss], feed_dict=data_dict)
-                print "Epoch: %d, Cost: %f"%(i, cost)
+                print("Epoch: %d, Cost: %f"%(i, cost))
                 if (i+1)%10 == 0:
                     saver.save(sess, conf.ckpt_file)
                     generate_samples(sess, X, model.h, model.pred, conf, "")

Diff for: models.py

+8 -8

@@ -2,7 +2,7 @@
 from layers import *

 class PixelCNN(object):
-    def __init__(self, X, conf, h=None):
+    def __init__(self, X, conf, full_horizontal=True, h=None):
         self.X = X
         if conf.data == "mnist":
             self.X_norm = X
@@ -27,33 +27,33 @@ def __init__(self, X, conf, h=None):
             residual = True if i > 0 else False
             i = str(i)
             with tf.variable_scope("v_stack"+i):
-                v_stack = GatedCNN([filter_size, filter_size, conf.f_map], v_stack_in, mask=mask, conditional=self.h).output()
+                v_stack = GatedCNN([filter_size, filter_size, conf.f_map], v_stack_in, False, mask=mask, conditional=self.h).output()
                 v_stack_in = v_stack

             with tf.variable_scope("v_stack_1"+i):
-                v_stack_1 = GatedCNN([1, 1, conf.f_map], v_stack_in, gated=False, mask=mask).output()
+                v_stack_1 = GatedCNN([1, 1, conf.f_map], v_stack_in, False, gated=False, mask=mask).output()

             with tf.variable_scope("h_stack"+i):
-                h_stack = GatedCNN([1, filter_size, conf.f_map], h_stack_in, payload=v_stack_1, mask=mask, conditional=self.h).output()
+                h_stack = GatedCNN([filter_size if full_horizontal else 1, filter_size, conf.f_map], h_stack_in, True, payload=v_stack_1, mask=mask, conditional=self.h).output()

             with tf.variable_scope("h_stack_1"+i):
-                h_stack_1 = GatedCNN([1, 1, conf.f_map], h_stack, gated=False, mask=mask).output()
+                h_stack_1 = GatedCNN([1, 1, conf.f_map], h_stack, True, gated=False, mask=mask).output()
                 if residual:
                     h_stack_1 += h_stack_in # Residual connection
                 h_stack_in = h_stack_1

         with tf.variable_scope("fc_1"):
-            fc1 = GatedCNN([1, 1, conf.f_map], h_stack_in, gated=False, mask='b').output()
+            fc1 = GatedCNN([1, 1, conf.f_map], h_stack_in, True, gated=False, mask='b').output()

         if conf.data == "mnist":
             with tf.variable_scope("fc_2"):
-                self.fc2 = GatedCNN([1, 1, 1], fc1, gated=False, mask='b', activation=False).output()
+                self.fc2 = GatedCNN([1, 1, 1], fc1, True, gated=False, mask='b', activation=False).output()
             self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fc2, labels=self.X))
             self.pred = tf.nn.sigmoid(self.fc2)
         else:
             color_dim = 256
             with tf.variable_scope("fc_2"):
-                self.fc2 = GatedCNN([1, 1, conf.channel * color_dim], fc1, gated=False, mask='b', activation=False).output()
+                self.fc2 = GatedCNN([1, 1, conf.channel * color_dim], fc1, True, gated=False, mask='b', activation=False).output()
                 self.fc2 = tf.reshape(self.fc2, (-1, color_dim))

             self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(self.fc2, tf.cast(tf.reshape(self.X, [-1]), dtype=tf.int32)))
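The reason every layer now builds an explicit vertical stack (horizontal=False) and feeds it into the horizontal stack (horizontal=True) through the 1x1 payload is the blind spot named in the commit title: no matter how many layers of the original combined mask are stacked, some pixels above and to the right can never enter the receptive field. A small self-contained check, not part of the commit and purely illustrative:

# Offsets (row, col) visible through ONE layer of the original combined 3x3
# mask 'b': the full row above, the left neighbour, and the pixel itself.
SINGLE_STACK = {(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 0)}

def reachable(step_offsets, n_layers):
    # All relative offsets reachable by composing n_layers masked convolutions.
    reach = {(0, 0)}
    for _ in range(n_layers):
        reach = {(r + dr, c + dc) for (r, c) in reach for (dr, dc) in step_offsets}
    return reach

offsets = reachable(SINGLE_STACK, 8)
print((-1, -2) in offsets)   # True: far to the upper-left is visible
print((-1, 2) in offsets)    # False: one row up, two columns right is never visible

With the split stacks, the vertical stack's masks cover the rows above outright and the horizontal stack adds the pixels to the left in the current row, so the combined receptive field has no such gap.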

Diff for: utils.py

+8 -8

@@ -8,15 +8,15 @@ def binarize(images):
     return (np.random.uniform(size=images.shape) < images).astype(np.float32)

 def generate_samples(sess, X, h, pred, conf, suff):
-    print "Generating Sample Images..."
+    print("Generating Sample Images...")
     n_row, n_col = 10,10
     samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
     # TODO make it generic
     labels = one_hot(np.array([0,1,2,3,4,5,6,7,8,9]*10), conf.num_classes)

-    for i in xrange(conf.img_height):
-        for j in xrange(conf.img_width):
-            for k in xrange(conf.channel):
+    for i in range(conf.img_height):
+        for j in range(conf.img_width):
+            for k in range(conf.channel):
                 data_dict = {X:samples}
                 if conf.conditional is True:
                     data_dict[h] = labels
@@ -29,17 +29,17 @@ def generate_samples(sess, X, h, pred, conf, suff):


 def generate_ae(sess, encoder_X, decoder_X, y, data, conf, suff=''):
-    print "Generating Sample Images..."
+    print("Generating Sample Images...")
     n_row, n_col = 10,10
     samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
     if conf.data == 'mnist':
         labels = binarize(data.train.next_batch(n_row*n_col)[0].reshape(n_row*n_col, conf.img_height, conf.img_width, conf.channel))
     else:
         labels = get_batch(data, 0, n_row*n_col)

-    for i in xrange(conf.img_height):
-        for j in xrange(conf.img_width):
-            for k in xrange(conf.channel):
+    for i in range(conf.img_height):
+        for j in range(conf.img_width):
+            for k in range(conf.channel):
                 next_sample = sess.run(y, {encoder_X: labels, decoder_X: samples})
                 if conf.data == 'mnist':
                     next_sample = binarize(next_sample)
