-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathspectral_norm_v2.py
317 lines (273 loc) · 12.6 KB
/
spectral_norm_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
import tensorflow as tf
from tensorlayer.layers.core import Layer
from tensorlayer.layers.core import LayersConfig
from tensorlayer import tl_logging as logging
from tensorlayer.decorators import deprecated_alias
from tensorlayer.decorators import private_method
import warnings
# Sentinel for the ``update_collection`` argument of the spectral-norm layers: passing it means
# "register the power-iteration update op of ``u`` in no collection at all" (useful when the op was
# already collected by an earlier call that shares the same variables).
NO_OPS = "NO_OPS"
class SpectralConv2dLayer(Layer):
    """The :class:`SpectralConv2dLayer` class is a 2D CNN layer with optional spectral normalization.

    It wraps `tf.nn.conv2d <https://www.tensorflow.org/versions/master/api_docs/python/nn.html#conv2d>`__.
    When ``use_sn`` is True, the kernel is reshaped to a ``[-1, out_channels]`` matrix, divided by a
    power-iteration estimate of its largest singular value, reshaped back, and only then convolved.

    Parameters
    ----------
    prev_layer : :class:`Layer`
        Previous layer.
    act : activation function
        The activation function of this layer.
    shape : tuple of int
        The shape of the filters: (filter_height, filter_width, in_channels, out_channels).
    strides : tuple of int
        The sliding window strides for each dimension of the input, in the axis order given by
        ``data_format``.
    padding : str
        The padding algorithm type: "SAME" or "VALID".
    W_init : initializer
        The initializer for the weight matrix.
    b_init : initializer or None
        The initializer for the bias vector. If None, skip biases.
    W_init_args : dictionary
        The arguments for the weight matrix initializer.
    b_init_args : dictionary
        The arguments for the bias vector initializer.
    use_cudnn_on_gpu : bool or None
        Forwarded to ``tf.nn.conv2d``. Default is None, which defers to TensorFlow's own default.
    use_sn : bool
        If True, apply spectral normalization to the kernel. Default is False.
    update_collection : str or None
        Controls the update of the power-iteration vector ``u`` (only used when ``use_sn`` is True).
        None: ``u`` is re-assigned every time the normalized kernel is computed (a warning is issued).
        ``NO_OPS``: no update op is registered anywhere. Any other string: the update op is added to
        the graph collection of that name and the caller is responsible for running it.
    data_format : str or None
        "NHWC" or "NCHW". Forwarded to both ``tf.nn.conv2d`` and ``tf.nn.bias_add``; None defers to
        their default ("NHWC").
    name : str
        A unique layer name.

    Notes
    -----
    - shape = [h, w, the number of output channels of the previous layer, the number of output channels]
    - the number of output channels of a layer is its last dimension.

    Examples
    --------
    ``net`` is an existing :class:`Layer` whose output is a 4-D NHWC tensor with 3 channels:

    .. code-block:: python

        net = SpectralConv2dLayer(net, act=tf.nn.relu, shape=(5, 5, 3, 64), strides=(1, 2, 2, 1),
                                  use_sn=True, update_collection='SN_UPDATE_OPS', name='sn_conv1')
        # run the ops in tf.get_collection('SN_UPDATE_OPS') to advance the power iteration
    """

    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
    def __init__(
            self,
            prev_layer,
            act=None,
            shape=(5, 5, 1, 100),
            strides=(1, 1, 1, 1),
            padding='SAME',
            W_init=tf.truncated_normal_initializer(stddev=0.02),
            b_init=tf.constant_initializer(value=0.0),
            W_init_args=None,
            b_init_args=None,
            use_cudnn_on_gpu=None,
            use_sn=False,
            update_collection=None,
            data_format=None,
            name='cnn_layer',
    ):
        # Initialize the base Layer exactly once and forward ``act``. A second ``Layer.__init__``
        # call without ``act`` (as this class used to make) appears to reset ``self.act`` so the
        # activation was silently skipped -- confirm against the installed TensorLayer version.
        super(SpectralConv2dLayer, self).__init__(
            prev_layer=prev_layer, act=act, W_init_args=W_init_args, b_init_args=b_init_args, name=name
        )

        logging.info(
            "SpectralConv2dLayer %s: shape: %s strides: %s pad: %s act: %s spectral: %s" % (
                self.name, str(shape), str(strides), padding,
                self.act.__name__ if self.act is not None else 'No Activation', use_sn
            )
        )

        # The input of this layer is the output of the previous layer.
        self.inputs = prev_layer.outputs

        with tf.variable_scope(name) as scope:
            # If this scope already holds variables (layer built again), share them instead of failing.
            if self._scope_has_variables(scope):
                scope.reuse_variables()

            W = tf.get_variable(
                name='W_conv2d', shape=shape, initializer=W_init, dtype=LayersConfig.tf_dtype, **self.W_init_args
            )
            kernel = self._spectral_normed_weight(W, update_collection=update_collection) if use_sn else W

            self.outputs = tf.nn.conv2d(
                self.inputs, kernel, strides=strides, padding=padding, use_cudnn_on_gpu=use_cudnn_on_gpu,
                data_format=data_format
            )

            if b_init is not None:
                b = tf.get_variable(
                    name='b_conv2d', shape=(shape[-1],), initializer=b_init, dtype=LayersConfig.tf_dtype,
                    **self.b_init_args
                )
                # Forward data_format so the bias lands on the channel axis for NCHW outputs as well.
                self.outputs = tf.nn.bias_add(self.outputs, b, data_format=data_format, name='bias_add')

            self.outputs = self._apply_activation(self.outputs)

        self._add_layers(self.outputs)

        if b_init is not None:
            self._add_params([W, b])
        else:
            self._add_params(W)

    @private_method
    def _scope_has_variables(self, scope):
        """Return True if a global variable has already been created under ``scope``.

        ``tf.get_collection`` filters names with ``re.match`` (a prefix match), so the trailing ``/``
        keeps a scope such as ``conv`` from matching the variables of a sibling scope ``conv_2``.
        """
        return len(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope.name + '/')) > 0

    @private_method
    def _l2normalize(self, v, eps=1e-12):
        """Return ``v`` divided by its (whole-tensor) L2 norm; ``eps`` guards against division by zero."""
        return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)

    @private_method
    def _spectral_normed_weight(self, W, u=None, num_iters=1, update_collection=None, with_sigma=False):
        """Return ``W`` divided by a power-iteration estimate of its largest singular value.

        ``W`` is viewed as the matrix ``W_reshaped = reshape(W, [-1, W.shape[-1]])``. Starting from ``u``,
        ``num_iters`` rounds of power iteration produce row vectors ``v`` (one entry per row of
        ``W_reshaped``) and ``u`` (one entry per column), and ``sigma = v @ W_reshaped @ u^T``.

        Parameters
        ----------
        W : variable
            The weight to normalize.
        u : variable or None
            ``[1, W.shape[-1]]`` start vector that receives the updated estimate. If None, a
            non-trainable variable named ``u`` is created (or reused) in the current variable scope.
        num_iters : int
            Number of power-iteration rounds; usually 1 is enough.
        update_collection : str or None
            None: assign the new ``u`` as a control dependency of the returned tensor (warns).
            ``NO_OPS``: create no update op. Otherwise: add the update op to that graph collection.
        with_sigma : bool
            If True, also return ``sigma``.

        Returns
        -------
        ``W_bar`` with the shape of ``W``, or ``(W_bar, sigma)`` when ``with_sigma`` is True.
        """
        W_shape = W.shape.as_list()
        W_reshaped = tf.reshape(W, [-1, W_shape[-1]])
        # Follow W's dtype rather than a hard-coded float32 so the loop variables agree with
        # W_reshaped when LayersConfig.tf_dtype is not float32 (identical for the float32 default).
        dtype = W.dtype.base_dtype
        if u is None:
            u = tf.get_variable(
                "u", [1, W_shape[-1]], dtype=dtype, initializer=tf.truncated_normal_initializer(), trainable=False
            )

        def power_iteration(i, u_i, v_i):
            # One round of the power method; each half-step is re-normalized by its L2 norm.
            v_ip1 = self._l2normalize(tf.matmul(u_i, tf.transpose(W_reshaped)))
            u_ip1 = self._l2normalize(tf.matmul(v_ip1, W_reshaped))
            return i + 1, u_ip1, v_ip1

        _, u_final, v_final = tf.while_loop(
            cond=lambda i, _1, _2: i < num_iters,
            body=power_iteration,
            loop_vars=(
                tf.constant(0, dtype=tf.int32), u,
                tf.zeros(dtype=dtype, shape=[1, W_reshaped.shape.as_list()[0]])
            )
        )

        sigma = tf.matmul(tf.matmul(v_final, W_reshaped), tf.transpose(u_final))[0, 0]
        W_bar = W_reshaped / sigma

        if update_collection is None:
            warnings.warn(
                'Setting update_collection to None will make u being updated every W execution. This maybe undesirable'
                '. Please consider using a update collection instead.')
            # Tie the u update to the normalized weight so it runs whenever W_bar is evaluated.
            with tf.control_dependencies([u.assign(u_final)]):
                W_bar = tf.reshape(W_bar, W_shape)
        else:
            W_bar = tf.reshape(W_bar, W_shape)
            # Put NO_OPS to not update any collection. This is useful for the second call of a
            # discriminator when the update op has already been collected on the first call.
            if update_collection != NO_OPS:
                tf.add_to_collection(update_collection, u.assign(u_final))

        return (W_bar, sigma) if with_sigma else W_bar
class SpectralDenseLayer(Layer):
    """The :class:`SpectralDenseLayer` class is a fully connected layer with optional spectral normalization.

    When ``use_sn`` is True, the ``(n_in, n_units)`` weight matrix is divided by a power-iteration
    estimate of its largest singular value before the matrix multiplication.

    Parameters
    ----------
    prev_layer : :class:`Layer`
        Previous layer.
    n_units : int
        The number of units of this layer.
    act : activation function
        The activation function of this layer.
    W_init : initializer
        The initializer for the weight matrix.
    b_init : initializer or None
        The initializer for the bias vector. If None, skip biases.
    W_init_args : dictionary
        The arguments for the weight matrix initializer.
    b_init_args : dictionary
        The arguments for the bias vector initializer.
    use_sn : bool
        If True, apply spectral normalization to the weight matrix. Default is False.
    update_collection : str or None
        Controls the update of the power-iteration vector ``u`` (only used when ``use_sn`` is True).
        None: ``u`` is re-assigned every time the normalized weight is computed (a warning is issued).
        ``NO_OPS``: no update op is registered anywhere. Any other string: the update op is added to
        the graph collection of that name and the caller is responsible for running it.
    name : str
        A unique layer name.

    Raises
    ------
    AssertionError
        If the input of the layer is not rank 2.

    Notes
    -----
    If the layer input has more than two axes, it needs to be flattened first by using :class:`FlattenLayer`.
    """

    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
    def __init__(
            self,
            prev_layer,
            n_units=100,
            act=None,
            W_init=tf.truncated_normal_initializer(stddev=0.1),
            b_init=tf.constant_initializer(value=0.0),
            W_init_args=None,
            b_init_args=None,
            use_sn=False,
            update_collection=None,
            name='dense',
    ):
        # Initialize the base Layer exactly once and forward ``act``. A second ``Layer.__init__``
        # call without ``act`` (as this class used to make) appears to reset ``self.act`` so the
        # activation was silently skipped -- confirm against the installed TensorLayer version.
        super(SpectralDenseLayer, self).__init__(
            prev_layer=prev_layer, act=act, W_init_args=W_init_args, b_init_args=b_init_args, name=name
        )

        logging.info(
            "SpectralDenseLayer %s: %d %s spectral: %s" %
            (self.name, n_units, self.act.__name__ if self.act is not None else 'No Activation', use_sn)
        )

        # The input of this layer is the output of the previous layer.
        self.inputs = prev_layer.outputs
        self.n_units = n_units

        if self.inputs.get_shape().ndims != 2:
            raise AssertionError("The input dimension must be rank 2, please reshape or flatten it")

        n_in = int(self.inputs.get_shape()[-1])

        with tf.variable_scope(name) as scope:
            # If this scope already holds variables (layer built again), share them instead of failing.
            if self._scope_has_variables(scope):
                scope.reuse_variables()

            W = tf.get_variable(
                name='W', shape=(n_in, n_units), initializer=W_init, dtype=LayersConfig.tf_dtype, **self.W_init_args
            )
            weight = self._spectral_normed_weight(W, update_collection=update_collection) if use_sn else W
            self.outputs = tf.matmul(self.inputs, weight)

            if b_init is not None:
                try:
                    b = tf.get_variable(
                        name='b', shape=(n_units,), initializer=b_init, dtype=LayersConfig.tf_dtype,
                        **self.b_init_args
                    )
                except (TypeError, ValueError):
                    # A constant initializer carries its own shape and is rejected when a shape is
                    # also given, so retry without one.
                    b = tf.get_variable(name='b', initializer=b_init, dtype=LayersConfig.tf_dtype, **self.b_init_args)
                self.outputs = tf.nn.bias_add(self.outputs, b, name='bias_add')

            self.outputs = self._apply_activation(self.outputs)

        self._add_layers(self.outputs)

        if b_init is not None:
            self._add_params([W, b])
        else:
            self._add_params(W)

    @private_method
    def _scope_has_variables(self, scope):
        """Return True if a global variable has already been created under ``scope``.

        ``tf.get_collection`` filters names with ``re.match`` (a prefix match), so the trailing ``/``
        keeps a scope such as ``dense`` from matching the variables of a sibling scope ``dense_2``.
        """
        return len(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope.name + '/')) > 0

    @private_method
    def _l2normalize(self, v, eps=1e-12):
        """Return ``v`` divided by its (whole-tensor) L2 norm; ``eps`` guards against division by zero."""
        return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)

    @private_method
    def _spectral_normed_weight(self, W, u=None, num_iters=1, update_collection=None, with_sigma=False):
        """Return ``W`` divided by a power-iteration estimate of its largest singular value.

        ``W`` is viewed as the matrix ``W_reshaped = reshape(W, [-1, W.shape[-1]])``. Starting from ``u``,
        ``num_iters`` rounds of power iteration produce row vectors ``v`` (one entry per row of
        ``W_reshaped``) and ``u`` (one entry per column), and ``sigma = v @ W_reshaped @ u^T``.

        Parameters
        ----------
        W : variable
            The weight to normalize.
        u : variable or None
            ``[1, W.shape[-1]]`` start vector that receives the updated estimate. If None, a
            non-trainable variable named ``u`` is created (or reused) in the current variable scope.
        num_iters : int
            Number of power-iteration rounds; usually 1 is enough.
        update_collection : str or None
            None: assign the new ``u`` as a control dependency of the returned tensor (warns).
            ``NO_OPS``: create no update op. Otherwise: add the update op to that graph collection.
        with_sigma : bool
            If True, also return ``sigma``.

        Returns
        -------
        ``W_bar`` with the shape of ``W``, or ``(W_bar, sigma)`` when ``with_sigma`` is True.
        """
        W_shape = W.shape.as_list()
        W_reshaped = tf.reshape(W, [-1, W_shape[-1]])
        # Follow W's dtype rather than a hard-coded float32 so the loop variables agree with
        # W_reshaped when LayersConfig.tf_dtype is not float32 (identical for the float32 default).
        dtype = W.dtype.base_dtype
        if u is None:
            u = tf.get_variable(
                "u", [1, W_shape[-1]], dtype=dtype, initializer=tf.truncated_normal_initializer(), trainable=False
            )

        def power_iteration(i, u_i, v_i):
            # One round of the power method; each half-step is re-normalized by its L2 norm.
            v_ip1 = self._l2normalize(tf.matmul(u_i, tf.transpose(W_reshaped)))
            u_ip1 = self._l2normalize(tf.matmul(v_ip1, W_reshaped))
            return i + 1, u_ip1, v_ip1

        _, u_final, v_final = tf.while_loop(
            cond=lambda i, _1, _2: i < num_iters,
            body=power_iteration,
            loop_vars=(
                tf.constant(0, dtype=tf.int32), u,
                tf.zeros(dtype=dtype, shape=[1, W_reshaped.shape.as_list()[0]])
            )
        )

        sigma = tf.matmul(tf.matmul(v_final, W_reshaped), tf.transpose(u_final))[0, 0]
        W_bar = W_reshaped / sigma

        if update_collection is None:
            warnings.warn(
                'Setting update_collection to None will make u being updated every W execution. This maybe undesirable'
                '. Please consider using a update collection instead.')
            # Tie the u update to the normalized weight so it runs whenever W_bar is evaluated.
            with tf.control_dependencies([u.assign(u_final)]):
                W_bar = tf.reshape(W_bar, W_shape)
        else:
            W_bar = tf.reshape(W_bar, W_shape)
            # Put NO_OPS to not update any collection. This is useful for the second call of a
            # discriminator when the update op has already been collected on the first call.
            if update_collection != NO_OPS:
                tf.add_to_collection(update_collection, u.assign(u_final))

        return (W_bar, sigma) if with_sigma else W_bar