@@ -33,7 +33,6 @@ def conv2d(inpt, nb_filter, filter_size=5, strides=2, bias=True, stddev=0.02, pa
33
33
# Convolution 2D Transpose
34
34
def deconv2d (inpt , output_shape , filter_size = 5 , strides = 2 , bias = True , stddev = 0.02 ,
35
35
padding = "SAME" , name = "deconv2d" ):
36
-
37
36
in_channels = inpt .get_shape ().as_list ()[- 1 ]
38
37
with tf .variable_scope (name ):
39
38
# Note: filter with shape [height, width, output_channels, in_channels]
@@ -54,17 +53,25 @@ def lrelu(x, leak=0.2, name="lrelu"):
54
53
def linear(x, output_dim, stddev=0.02, name="linear"):
    """Apply a fully-connected (affine) layer, ``x * w + b``.

    :param x: tensor, the layer input; its last dimension must be statically
        known because it fixes the shape of the weight matrix
        (presumably 2-D ``[batch, features]`` as ``tf.nn.xw_plus_b`` expects)
    :param output_dim: int, the number of output units
    :param stddev: float, standard deviation of the normal initializer of ``w``
    :param name: str, the variable scope holding the ``w`` and ``b`` variables
    :return: tensor, the result of ``tf.nn.xw_plus_b(x, w, b)``
    :raises ValueError: if the last dimension of ``x`` is not statically known
    """
    input_dim = x.get_shape().as_list()[-1]
    if input_dim is None:
        # tf.get_variable would reject the undefined shape anyway; fail early
        # with a message that points at the actual cause.
        raise ValueError(
            "linear(): the last dimension of `x` must be statically known"
        )
    with tf.variable_scope(name):
        w = tf.get_variable(
            "w",
            shape=[input_dim, output_dim],
            initializer=tf.random_normal_initializer(stddev=stddev),
        )
        b = tf.get_variable(
            "b",
            shape=[output_dim],
            initializer=tf.constant_initializer(0.0),
        )
        return tf.nn.xw_plus_b(x, w, b)
60
60
61
61
class DCGAN (object ):
62
62
"""A class of DCGAN model"""
63
- def __init__ (self , z_dim = 100 , output_dim = 28 , batch_size = 100 , c_dim = 1 , df_dim = 64 , gf_dim = 64 , gfc_dim = 1024 ,
64
- dfc_dim = 1024 , n_conv = 2 , n_deconv = 2 ):
63
+ def __init__ (self , z_dim = 100 , output_dim = 28 , batch_size = 100 , c_dim = 1 , df_dim = 64 , gf_dim = 64 , dfc_dim = 1024 ,
64
+ n_conv = 3 , n_deconv = 2 ):
65
65
"""
66
66
:param z_dim: int, the dimension of z (the noise input of generator)
67
- :param output_dim: int,
67
+ :param output_dim: int, the resolution in pixels of the images (height, width)
68
+ :param batch_size: int, the size of the mini-batch
69
+ :param c_dim: int, the number of image color channels; for MNIST it is 1 (grayscale)
70
+ :param df_dim: int, the number of filters in the first convolution layer of discriminator
71
+ :param gf_dim: int, the number of filters in the penultimate deconvolution layer of generator (last is 1)
72
+ :param dfc_dim: int, the number of units in the penultimate fully-connected layer of discriminator (last is 1)
73
+ :param n_conv: int, number of convolution layers in the discriminator (the number of filters doubles at each layer)
74
+ :param n_deconv: int, number of deconvolution layers in the generator (the number of filters halves at each layer)
68
75
"""
69
76
self .z_dim = z_dim
70
77
self .output_dim = output_dim
@@ -102,7 +109,7 @@ def _build_model(self):
102
109
def _discriminator (self , input , reuse = False ):
103
110
with tf .variable_scope ("D" , reuse = reuse ):
104
111
h = lrelu (conv2d (input , nb_filter = self .df_dim , name = "d_conv0" ))
105
- for i in range (1 , self .n_conv + 1 ):
112
+ for i in range (1 , self .n_conv ):
106
113
conv = conv2d (h , nb_filter = self .df_dim * (2 ** i ), name = "d_conv{0}" .format (i ))
107
114
h = lrelu (batch_norm (conv , name = "d_bn{0}" .format (i )))
108
115
h = linear (tf .reshape (h , shape = [self .batch_size , - 1 ]), self .dfc_dim , name = "d_lin0" )
@@ -112,22 +119,25 @@ def _discriminator(self, input, reuse=False):
112
119
def _generator(self, input):
    """Build the generator network that turns noise into images.

    A linear layer projects the noise onto the smallest feature map, which is
    then up-sampled by ``self.n_deconv`` transposed convolutions (``deconv2d``
    defaults: stride 2, "SAME" padding) up to
    ``[batch_size, output_dim, output_dim, c_dim]``.

    :param input: tensor, the noise z fed to the generator
    :return: tensor, the generated images after a tanh activation
    """
    with tf.variable_scope("G"):
        # Both lists are built backwards from the output side: entry 0 is the
        # input of the last deconvolution, the final entry is the feature map
        # produced by the linear projection.
        n_filters = [self.gf_dim]
        # Ceiling division: a stride-2 "SAME" transposed convolution that
        # outputs size `s` needs an input of size ceil(s / 2). Floor division
        # gives the wrong size whenever an intermediate size is odd
        # (e.g. 7 -> 3 instead of 4).
        map_sizes = [-(-self.output_dim // 2)]
        for _ in range(1, self.n_deconv):
            n_filters.append(n_filters[-1] * 2)
            map_sizes.append(-(-map_sizes[-1] // 2))

        # Project the noise and reshape it into the smallest feature map.
        size, depth = map_sizes[-1], n_filters[-1]
        h = linear(input, depth * size * size, name="g_lin0")
        h = tf.reshape(h, shape=[-1, size, size, depth])
        h = tf.nn.relu(batch_norm(h, name="g_bn0"))

        # Intermediate up-sampling stages, walking the lists back to front.
        for i in range(1, self.n_deconv):
            size, depth = map_sizes[-i - 1], n_filters[-i - 1]
            h = deconv2d(h, [self.batch_size, size, size, depth],
                         name="g_deconv{0}".format(i - 1))
            h = tf.nn.relu(batch_norm(h, name="g_bn{0}".format(i)))

        # Last stage maps to the image resolution and color channels; no
        # batch normalization is applied before the tanh.
        h = deconv2d(h, [self.batch_size, self.output_dim, self.output_dim, self.c_dim],
                     name="g_deconv{0}".format(self.n_deconv - 1))
        return tf.nn.tanh(h)
129
138
130
139
def combine_images (images ):
140
+ """Combine a batch of images into a single image"""
131
141
num = images .shape [0 ]
132
142
width = int (np .sqrt (num ))
133
143
height = int (np .ceil (num / width ))
@@ -139,8 +149,8 @@ def combine_images(images):
139
149
img [i * h :(i + 1 )* h , j * w :(j + 1 )* w ] = m [:, :, 0 ]
140
150
return img
141
151
142
-
143
152
if __name__ == "__main__" :
153
+ # Load MNIST data
144
154
(X_train , y_train ), (X_test , y_test ) = mnist .load_data ()
145
155
X_train = (np .asarray (X_train , dtype = np .float32 ) - 127.5 )/ 127.5
146
156
X_train = np .reshape (X_train , [- 1 , 28 , 28 , 1 ])
@@ -152,15 +162,14 @@ def combine_images(images):
152
162
153
163
sess = tf .Session ()
154
164
dcgan = DCGAN (z_dim = z_dim , output_dim = 28 , batch_size = 128 , c_dim = 1 )
155
-
165
+ # The optimizers
156
166
d_train_op = tf .train .AdamOptimizer (lr , beta1 = 0.5 ).minimize (dcgan .d_loss ,
157
167
var_list = dcgan .d_vars )
158
168
g_train_op = tf .train .AdamOptimizer (lr , beta1 = 0.5 ).minimize (dcgan .g_loss ,
159
169
var_list = dcgan .g_vars )
160
170
sess .run (tf .global_variables_initializer ())
161
171
162
172
num_batches = int (len (X_train )/ batch_size )
163
-
164
173
for epoch in range (n_epochs ):
165
174
print ("Epoch" , epoch )
166
175
d_losses = 0
@@ -184,3 +193,7 @@ def combine_images(images):
184
193
img = combine_images (images )
185
194
img = img * 127.5 + 127.5
186
195
Image .fromarray (img .astype (np .uint8 )).save ("epoch{0}_g_images.png" .format (epoch ))
196
+
197
+
198
+
199
+
0 commit comments