import numpy as np
import tensorflow as tf
def get_weights(shape, name, horizontal, mask_mode='noblind', mask=None):
    """Create a weight variable, optionally multiplied by a causal PixelCNN mask.

    The mask hides "future" pixel values so each output depends only on
    already-generated pixels (raster-scan order).

    Args:
        shape: 4-element filter shape [filter_h, filter_w, in_ch, out_ch].
        name: variable name passed to tf.get_variable.
        horizontal: True for the horizontal-stack filter, False for the
            vertical stack. Only consulted when mask_mode == 'noblind'.
        mask_mode: 'noblind' selects the blind-spot-free masking scheme;
            any other value falls back to the original single-mask scheme.
        mask: None for no masking, 'a' for the first-layer mask (center
            pixel excluded), any other truthy value (e.g. 'b') keeps the
            center pixel visible.

    Returns:
        The weight tensor, multiplied elementwise by the mask when `mask`
        is truthy.
    """
    weights_initializer = tf.contrib.layers.xavier_initializer()
    W = tf.get_variable(name, shape, tf.float32, weights_initializer)

    if mask:
        # shape[0] indexes rows (y), shape[1] indexes columns (x).
        filter_mid_y = shape[0] // 2
        filter_mid_x = shape[1] // 2
        mask_filter = np.ones(shape, dtype=np.float32)

        if mask_mode == 'noblind':
            if horizontal:
                # All rows after the center row must be zero.
                mask_filter[filter_mid_y + 1:, :, :, :] = 0.0
                # All columns after center in the center row must be zero.
                mask_filter[filter_mid_y, filter_mid_x + 1:, :, :] = 0.0
            else:
                if mask == 'a':
                    # In the first layer, can ONLY access pixels above it.
                    mask_filter[filter_mid_y:, :, :, :] = 0.0
                else:
                    # In later layers, can access pixels above or even with it.
                    # Reason: the pixels to the right or left of the current
                    # pixel only have a receptive field of the layer above the
                    # current layer and up, so no future information leaks.
                    mask_filter[filter_mid_y + 1:, :, :, :] = 0.0

            if mask == 'a':
                # Center must be zero in the first layer.
                mask_filter[filter_mid_y, filter_mid_x, :, :] = 0.0
        else:
            # Original PixelCNN mask (known to have a blind spot in the
            # receptive field).
            mask_filter[filter_mid_y, filter_mid_x + 1:, :, :] = 0.
            mask_filter[filter_mid_y + 1:, :, :, :] = 0.

            if mask == 'a':
                mask_filter[filter_mid_y, filter_mid_x, :, :] = 0.

        W *= mask_filter
    return W
23
43
def max_pool_2x2(x):
    """2x2 max-pooling with stride 2 and SAME padding (halves spatial dims)."""
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
32
52
33
53
class GatedCNN ():
34
- def __init__ (self , W_shape , fan_in , gated = True , payload = None , mask = None , activation = True , conditional = None ):
54
+ def __init__ (self , W_shape , fan_in , horizontal , gated = True , payload = None , mask = None , activation = True , conditional = None , conditional_image = None ):
35
55
self .fan_in = fan_in
36
56
in_dim = self .fan_in .get_shape ()[- 1 ]
37
57
self .W_shape = [W_shape [0 ], W_shape [1 ], in_dim , W_shape [2 ]]
38
58
self .b_shape = W_shape [2 ]
39
59
60
+ self .in_dim = in_dim
40
61
self .payload = payload
41
62
self .mask = mask
42
63
self .activation = activation
43
64
self .conditional = conditional
65
+ self .conditional_image = conditional_image
66
+ self .horizontal = horizontal
44
67
45
68
if gated :
46
69
self .gated_conv ()
47
70
else :
48
71
self .simple_conv ()
49
72
50
73
def gated_conv (self ):
51
- W_f = get_weights (self .W_shape , "v_W" , mask = self .mask )
52
- W_g = get_weights (self .W_shape , "h_W" , mask = self .mask )
74
+ W_f = get_weights (self .W_shape , "v_W" , self .horizontal , mask = self .mask )
75
+ W_g = get_weights (self .W_shape , "h_W" , self .horizontal , mask = self .mask )
76
+
77
+ b_f_total = get_bias (self .b_shape , "v_b" )
78
+ b_g_total = get_bias (self .b_shape , "h_b" )
53
79
if self .conditional is not None :
54
80
h_shape = int (self .conditional .get_shape ()[1 ])
55
- V_f = get_weights ([h_shape , self .W_shape [3 ]], "v_V" )
81
+ V_f = get_weights ([h_shape , self .W_shape [3 ]], "v_V" , self . horizontal )
56
82
b_f = tf .matmul (self .conditional , V_f )
57
- V_g = get_weights ([h_shape , self .W_shape [3 ]], "h_V" )
83
+ V_g = get_weights ([h_shape , self .W_shape [3 ]], "h_V" , self . horizontal )
58
84
b_g = tf .matmul (self .conditional , V_g )
59
85
60
86
b_f_shape = tf .shape (b_f )
61
87
b_f = tf .reshape (b_f , (b_f_shape [0 ], 1 , 1 , b_f_shape [1 ]))
62
88
b_g_shape = tf .shape (b_g )
63
89
b_g = tf .reshape (b_g , (b_g_shape [0 ], 1 , 1 , b_g_shape [1 ]))
64
- else :
65
- b_f = get_bias (self .b_shape , "v_b" )
66
- b_g = get_bias (self .b_shape , "h_b" )
90
+
91
+ b_f_total = b_f_total + b_f
92
+ b_g_total = b_g_total + b_g
93
+ if self .conditional_image is not None :
94
+ b_f_total = b_f_total + tf .layers .conv2d (self .conditional_image , self .in_dim , 1 , use_bias = False , name = "ci_f" )
95
+ b_g_total = b_g_total + tf .layers .conv2d (self .conditional_image , self .in_dim , 1 , use_bias = False , name = "ci_g" )
67
96
68
97
conv_f = conv_op (self .fan_in , W_f )
69
98
conv_g = conv_op (self .fan_in , W_g )
@@ -72,10 +101,10 @@ def gated_conv(self):
72
101
conv_f += self .payload
73
102
conv_g += self .payload
74
103
75
- self .fan_out = tf .multiply (tf .tanh (conv_f + b_f ), tf .sigmoid (conv_g + b_g ))
104
+ self .fan_out = tf .multiply (tf .tanh (conv_f + b_f_total ), tf .sigmoid (conv_g + b_g_total ))
76
105
77
106
def simple_conv (self ):
78
- W = get_weights (self .W_shape , "W" , mask = self .mask )
107
+ W = get_weights (self .W_shape , "W" , self . horizontal , mask_mode = "standard" , mask = self .mask )
79
108
b = get_bias (self .b_shape , "b" )
80
109
conv = conv_op (self .fan_in , W )
81
110
if self .activation :
0 commit comments