-
Notifications
You must be signed in to change notification settings - Fork 104
/
Copy pathmodule.py
74 lines (57 loc) · 2.61 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow.keras as keras
# ==============================================================================
# = networks =
# ==============================================================================
def _get_norm_layer(norm):
if norm == 'none':
return lambda: lambda x: x
elif norm == 'batch_norm':
return keras.layers.BatchNormalization
elif norm == 'instance_norm':
return tfa.layers.InstanceNormalization
elif norm == 'layer_norm':
return keras.layers.LayerNormalization
def ConvGenerator(input_shape=(1, 1, 128),
                  output_channels=3,
                  dim=64,
                  n_upsamplings=4,
                  norm='batch_norm',
                  name='ConvGenerator'):
    """Build a DCGAN-style transposed-convolution generator.

    Maps a latent tensor (default 1x1x128) to an image-shaped tensor in
    [-1, 1] via one 1x1 -> 4x4 projection followed by `n_upsamplings - 1`
    stride-2 doublings and a final stride-2 output layer.

    Args:
        input_shape: Shape of the latent input, excluding the batch axis.
        output_channels: Channel count of the generated output.
        dim: Base filter count; deeper layers use multiples, capped at 8x.
        n_upsamplings: Total number of resolution-doubling stages.
        norm: Normalization name understood by `_get_norm_layer`.
        name: Name given to the returned `keras.Model`.

    Returns:
        A `keras.Model` from latent input to generated output.
    """
    norm_layer = _get_norm_layer(norm)

    inputs = keras.Input(shape=input_shape)

    # Project the 1x1 latent input up to a 4x4 feature map.
    filters = min(dim * 2 ** (n_upsamplings - 1), dim * 8)
    x = keras.layers.Conv2DTranspose(filters, 4, strides=1, padding='valid', use_bias=False)(inputs)
    x = norm_layer()(x)
    x = tf.nn.relu(x)  # equivalently: keras.layers.ReLU()(x)

    # Double the spatial resolution each step: 4x4 -> 8x8 -> 16x16 -> ...
    # Filter counts halve toward the output, capped at dim * 8.
    for step in range(n_upsamplings - 1):
        filters = min(dim * 2 ** (n_upsamplings - 2 - step), dim * 8)
        x = keras.layers.Conv2DTranspose(filters, 4, strides=2, padding='same', use_bias=False)(x)
        x = norm_layer()(x)
        x = tf.nn.relu(x)  # equivalently: keras.layers.ReLU()(x)

    # Final doubling to the output resolution, squashed into [-1, 1].
    x = keras.layers.Conv2DTranspose(output_channels, 4, strides=2, padding='same')(x)
    x = tf.tanh(x)  # equivalently: keras.layers.Activation('tanh')(x)

    return keras.Model(inputs=inputs, outputs=x, name=name)
def ConvDiscriminator(input_shape=(64, 64, 3),
                      dim=64,
                      n_downsamplings=4,
                      norm='batch_norm',
                      name='ConvDiscriminator'):
    """Build a DCGAN-style convolutional discriminator.

    Halves the spatial resolution `n_downsamplings` times with stride-2
    convolutions and LeakyReLU activations, then emits an unbounded logit
    map via a final valid-padding convolution (no sigmoid applied).

    Args:
        input_shape: Shape of the image input, excluding the batch axis.
        dim: Base filter count; deeper layers use multiples, capped at 8x.
        n_downsamplings: Number of resolution-halving stages.
        norm: Normalization name understood by `_get_norm_layer`
            (not applied to the first convolution, per DCGAN convention).
        name: Name given to the returned `keras.Model`.

    Returns:
        A `keras.Model` from image input to logit output.
    """
    norm_layer = _get_norm_layer(norm)

    inputs = keras.Input(shape=input_shape)

    # First halving; no normalization on the input-facing layer.
    x = keras.layers.Conv2D(dim, 4, strides=2, padding='same')(inputs)
    x = tf.nn.leaky_relu(x, alpha=0.2)  # equivalently: keras.layers.LeakyReLU(alpha=0.2)(x)

    # Remaining halvings: ... -> 16x16 -> 8x8 -> 4x4, filters doubling
    # each step and capped at dim * 8.
    for step in range(n_downsamplings - 1):
        filters = min(dim * 2 ** (step + 1), dim * 8)
        x = keras.layers.Conv2D(filters, 4, strides=2, padding='same', use_bias=False)(x)
        x = norm_layer()(x)
        x = tf.nn.leaky_relu(x, alpha=0.2)  # equivalently: keras.layers.LeakyReLU(alpha=0.2)(x)

    # Collapse to a single-channel logit map.
    x = keras.layers.Conv2D(1, 4, strides=1, padding='valid')(x)

    return keras.Model(inputs=inputs, outputs=x, name=name)