-
Notifications
You must be signed in to change notification settings - Fork 4
/
sn.py
62 lines (47 loc) · 2.34 KB
/
sn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import tensorflow as tf # TF 2.0
class SpectralNormalization(tf.keras.layers.Wrapper):
    """Wrapper that runs a layer with a spectrally normalized kernel.

    On every call the wrapped layer's ``kernel`` is viewed as a 2-D matrix
    ``W`` of shape ``(prod(kernel.shape[:-1]), kernel.shape[-1])``. Its
    largest singular value ``sigma`` is estimated with power iteration, the
    wrapped layer is executed with ``kernel / sigma``, and the un-normalized
    kernel is written back afterwards.

    Args:
        layer: A ``tf.keras.layers.Layer`` instance that exposes a ``kernel``
            variable once it is built.
        iteration: Number of power-iteration steps performed per call.
        eps: Small constant added to the vector norms to avoid dividing by
            zero.
        training: When true, power iteration is run on every call and the
            refined ``u``/``v`` estimates are stored. When false, ``sigma``
            is computed from the stored ``u``/``v`` without refining them.
        **kwargs: Forwarded to ``tf.keras.layers.Wrapper``.

    Raises:
        ValueError: If ``layer`` is not a ``tf.keras.layers.Layer`` instance.
    """

    def __init__(self, layer, iteration=1, eps=1e-12, training=True, **kwargs):
        if not isinstance(layer, tf.keras.layers.Layer):
            raise ValueError(
                'Please initialize `SpectralNormalization` layer with a '
                '`Layer` instance. You passed: {input}'.format(input=layer))
        # Initialise the Keras base class before attaching any state to self.
        super(SpectralNormalization, self).__init__(layer, **kwargs)
        self.iteration = iteration
        self.eps = eps
        self.do_power_iteration = training
        # Snapshot of the un-normalized kernel taken by `update_weights`.
        self._kernel_backup = None

    def build(self, input_shape):
        """Build the wrapped layer (if needed) and create the ``u``/``v`` vectors.

        Raises:
            AttributeError: If the wrapped layer has no ``kernel`` attribute
                after being built.
        """
        # Do not rebuild an already-built layer: that would re-create its
        # weights.
        if not self.layer.built:
            self.layer.build(input_shape)

        if not hasattr(self.layer, 'kernel'):
            raise AttributeError(
                '`SpectralNormalization` requires the wrapped layer to have '
                'a `kernel` attribute. You passed: {input}'.format(
                    input=self.layer))

        # NOTE: `self.w` is the kernel variable itself, not a copy of it.
        self.w = self.layer.kernel
        self.w_shape = self.w.shape.as_list()

        # Row count of the flattened kernel: every dimension but the last.
        # This matches the `[-1, w_shape[-1]]` reshape in `update_weights`
        # for kernels of any rank.
        rows = 1
        for dim in self.w_shape[:-1]:
            rows *= dim

        # `u` and `v` follow the kernel dtype so the matmuls in
        # `update_weights` are valid for non-float32 kernels too.
        self.v = self.add_weight(
            shape=(1, rows),
            initializer=tf.initializers.TruncatedNormal(stddev=0.02),
            trainable=False,
            name='sn_v',
            dtype=self.w.dtype)
        self.u = self.add_weight(
            shape=(1, self.w_shape[-1]),
            initializer=tf.initializers.TruncatedNormal(stddev=0.02),
            trainable=False,
            name='sn_u',
            dtype=self.w.dtype)

        super(SpectralNormalization, self).build(input_shape)

    def call(self, inputs):
        """Apply the wrapped layer to ``inputs`` using the normalized kernel."""
        self.update_weights()
        output = self.layer(inputs)
        # Put the un-normalized kernel back so the variable keeps holding W
        # rather than W / sigma between calls (original author's note:
        # "W = W - alpha * W_SN`").
        self.restore_weights()
        return output

    def update_weights(self):
        """Overwrite the layer kernel with ``kernel / sigma``.

        The current kernel value is snapshotted first so that
        `restore_weights` can undo the normalization.
        """
        # `self.w` aliases `self.layer.kernel`, so a value snapshot is needed;
        # keeping only the variable reference would make the restore a no-op.
        self._kernel_backup = tf.identity(self.w)
        w_reshaped = tf.reshape(self._kernel_backup, [-1, self.w_shape[-1]])

        u_hat = self.u
        v_hat = self.v

        if self.do_power_iteration:
            for _ in range(self.iteration):
                v_ = tf.matmul(u_hat, tf.transpose(w_reshaped))
                v_hat = v_ / (tf.reduce_sum(v_ ** 2) ** 0.5 + self.eps)

                u_ = tf.matmul(v_hat, w_reshaped)
                u_hat = u_ / (tf.reduce_sum(u_ ** 2) ** 0.5 + self.eps)

            # Persist the refined estimates; without power iteration they are
            # unchanged, so there is nothing to write back.
            self.u.assign(u_hat)
            self.v.assign(v_hat)

        # sigma = v W u^T, a (1, 1) tensor that broadcasts over the kernel.
        sigma = tf.matmul(tf.matmul(v_hat, w_reshaped), tf.transpose(u_hat))

        self.layer.kernel.assign(self._kernel_backup / sigma)

    def restore_weights(self):
        """Write the kernel snapshot taken by `update_weights` back to the layer."""
        if self._kernel_backup is None:
            # `update_weights` has not run yet, so there is nothing to undo.
            return
        self.layer.kernel.assign(self._kernel_backup)

    def get_config(self):
        """Return the wrapper config including the constructor arguments."""
        config = super(SpectralNormalization, self).get_config()
        config.update({
            'iteration': self.iteration,
            'eps': self.eps,
            'training': self.do_power_iteration,
        })
        return config