-
Notifications
You must be signed in to change notification settings - Fork 4
/
model.py
113 lines (89 loc) · 5.06 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import math
import keras.backend as K
from keras.layers import Conv2D, BatchNormalization, Activation, Add, \
AveragePooling2D, Input, Dense, Flatten, UpSampling2D, Layer, Reshape, Concatenate, Lambda
from keras.models import Model
from keras.regularizers import l2
def encoder_layers_introvae(image_size, base_channels, bn_allowed):
    """Build the layer list for the IntroVAE encoder.

    Args:
        image_size: spatial size tuple; only image_size[0] (assumed square
            input — TODO confirm) drives the downsampling schedule.
        base_channels: channel count of the first convolution.
        bn_allowed: whether to insert BatchNormalization layers.

    Returns:
        A list of Keras layers/callables to be applied in order, ending
        with a Flatten.
    """
    layers = [Conv2D(base_channels, (5, 5), strides=(1, 1), padding='same',
                     kernel_initializer='he_normal', name='encoder_conv_0')]
    if bn_allowed:
        layers.append(BatchNormalization(axis=1, name='encoder_bn_0'))
    layers.append(Activation('relu'))
    layers.append(AveragePooling2D(pool_size=(2, 2), strides=None,
                                   padding='valid', name='encoder_avgpool_0'))

    map_size = image_size[0] // 2
    block = 1
    channels = base_channels * 2
    # Keep halving the feature map with residual block + avg-pool stages
    # until it reaches 4x4, doubling channels each stage (capped at 512).
    while map_size > 4:
        layers.append(residual_block('encoder', [(3, 3), (3, 3)], channels,
                                     block=block, bn_allowed=bn_allowed))
        layers.append(AveragePooling2D(pool_size=(2, 2), strides=None,
                                       padding='valid',
                                       name='encoder_avgpool_' + str(block)))
        block += 1
        map_size //= 2
        channels = channels * 2 if channels <= 256 else 512

    # Final residual block is linear so the latent head sees pre-activation
    # features.
    layers.append(residual_block('encoder', kernels=[(3, 3), (3, 3)],
                                 filters=channels, block=block,
                                 bn_allowed=bn_allowed,
                                 last_activation="linear"))
    layers.append(Flatten(name='encoder_reshape'))
    return layers
def generator_layers_introvae(image_size, base_channels, bn_allowed):
    """Build the layer list for the IntroVAE generator (decoder).

    Args:
        image_size: target spatial size tuple; only image_size[0] is used.
        base_channels: unused here, kept for signature parity with the
            encoder builder.
        bn_allowed: whether residual blocks insert BatchNormalization.

    Returns:
        A list of Keras layers/callables that map a latent vector to a
        3-channel image with sigmoid outputs.
    """
    layers = [Dense(512 * 4 * 4, name='generator_dense'),
              Activation('relu'),
              Reshape((512, 4, 4), name='generator_reshape'),
              residual_block('generator', kernels=[(3, 3), (3, 3)],
                             filters=512, block=1, bn_allowed=bn_allowed)]

    map_size = 4
    block = 2
    channels = 512
    # Total doublings needed to go from 4x4 to the target resolution.
    num_upsamples = int(math.log2(image_size[0]) - 2)

    # For resolutions above 256 the first extra stages stay at 512 channels;
    # for <= 256 this range is empty and the loop is skipped.
    for _ in range(num_upsamples - 6):
        layers.append(UpSampling2D(size=(2, 2),
                                   name='generator_upsample_' + str(block)))
        layers.append(residual_block('generator', [(3, 3), (3, 3)], 512,
                                     block=block, bn_allowed=bn_allowed))
        map_size *= 2
        block += 1

    # Remaining stages halve the channel count (floored at 16) until the
    # target resolution is reached.
    while map_size < image_size[0]:
        channels = channels // 2 if channels >= 32 else 16
        layers.append(UpSampling2D(size=(2, 2),
                                   name='generator_upsample_' + str(block)))
        layers.append(residual_block('generator', [(3, 3), (3, 3)], channels,
                                     block=block, bn_allowed=bn_allowed))
        map_size *= 2
        block += 1

    layers.append(Conv2D(3, (5, 5), padding='same',
                         kernel_initializer='he_normal',
                         name='generator_conv_0'))
    layers.append(Activation('sigmoid'))
    return layers
def residual_block(model_type, kernels, filters, block, bn_allowed, stage='a', last_activation="relu"):
    """Return a closure that applies a named residual block to a tensor.

    Args:
        model_type: prefix for layer names ('encoder' / 'generator').
        kernels: list of kernel sizes, one per conv in the branch.
        filters: int (replicated per kernel) or list of per-conv filter counts.
        block: integer index used in layer names.
        bn_allowed: whether to insert BatchNormalization after each conv.
        stage: letter used in layer names.
        last_activation: activation applied after the residual addition.

    Returns:
        A function ``identity_block(input_tensor, filters=filters)`` that
        builds the block on ``input_tensor`` and returns the output tensor.
    """
    def identity_block(input_tensor, filters=filters):
        if isinstance(filters, int):
            filters = [filters] * len(kernels)
        assert len(filters) == len(kernels), 'Number of filters and number of kernels differs.'
        bn_axis = 3 if K.image_data_format() == 'channels_last' else 1
        bn_name_base = model_type + '_resblock_bn_' + stage + str(block) + '_branch_'
        conv_name_base = model_type + '_resblock_conv_' + stage + str(block) + '_branch_'
        # Project the shortcut with a 1x1 conv when the channel count changes.
        # BUG FIX: original wrote K.int_shape(input_tensor[-1]), i.e. the shape
        # of a slice of the tensor. The channel count lives at shape index 1
        # because every conv here uses data_format='channels_first'.
        if K.int_shape(input_tensor)[1] != filters[0]:
            input_tensor = Conv2D(filters[0], (1, 1), padding='same', kernel_initializer='glorot_normal', name=conv_name_base + str('00'), data_format='channels_first')(input_tensor)
            if bn_allowed:
                input_tensor = BatchNormalization(axis=bn_axis, name=bn_name_base + str('00'))(input_tensor)
            input_tensor = Activation('relu')(input_tensor)
        x = input_tensor
        for idx in range(len(filters)):
            x = Conv2D(filters[idx], kernels[idx], padding='same', kernel_initializer='he_normal', name=conv_name_base + str(idx), data_format='channels_first')(x)
            if bn_allowed:
                x = BatchNormalization(axis=bn_axis, name=bn_name_base + str(idx))(x)
            # ReLU on all but the last conv of the branch; the final activation
            # (`last_activation`) is applied after the residual addition.
            # BUG FIX: the original condition `idx <= len(filters) - 1` was
            # always true, so the last conv also got a ReLU, which defeated
            # last_activation="linear" used by the encoder's final block.
            if idx < len(filters) - 1:
                x = Activation('relu')(x)
        x = Add(name=model_type + '_resblock_add_' + stage + str(block))([x, input_tensor])
        x = Activation(last_activation)(x)
        return x
    return identity_block
def add_sampling(hidden, sampling, sampling_std, batch_size, latent_dim, wd):
    """Attach a latent head (optionally stochastic) to the encoder output.

    Args:
        hidden: flattened encoder output tensor.
        sampling: if falsy, the latent is deterministic (z == z_mean) and
            z_log_var is a zero tensor.
        sampling_std: if > 0, use a fixed log-variance log(sampling_std^2)
            instead of learning it with a Dense layer.
        batch_size: static batch size used to shape the noise tensor.
        latent_dim: dimensionality of the latent space.
        wd: L2 weight-decay coefficient for the Dense layers.

    Returns:
        Tuple ``(z, z_mean, z_log_var)`` of tensors.
    """
    z_mean = Dense(latent_dim, kernel_regularizer=l2(wd))(hidden)
    if not sampling:
        # Deterministic autoencoder path: zero log-variance, z is just z_mean.
        z_log_var = Lambda(lambda x: 0 * x, output_shape=[latent_dim])(z_mean)
        return z_mean, z_mean, z_log_var

    if sampling_std > 0:
        # Fixed std: broadcast the constant log(std^2) to the latent shape
        # via a zeroed copy of z_mean.
        z_log_var = Lambda(lambda x: 0 * x + K.log(K.square(sampling_std)),
                           output_shape=[latent_dim])(z_mean)
    else:
        z_log_var = Dense(latent_dim, kernel_regularizer=l2(wd))(hidden)

    # BUG FIX (naming): the inner function was called `sampling`, shadowing
    # the boolean parameter of the same name; renamed to `_sample`.
    def _sample(inputs):
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
        mu, log_var = inputs
        epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0.)
        return mu + K.exp(log_var / 2) * epsilon

    z = Lambda(_sample, output_shape=(latent_dim,))([z_mean, z_log_var])
    return z, z_mean, z_log_var