-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
74 lines (55 loc) · 2.59 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy as np
import tensorflow.keras as kr
import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization
from FrEIA.keras.framework import ReversibleSequential
from FrEIA.keras.modules import AllInOneBlock, HaarDownsampling
# NOTE(review): process-wide monkeypatch of a *private* Keras class attribute.
# Because it is set on the class itself, it affects every BatchNormalization
# layer in this process, not only the ones built by SubnetFactory below.
# In TF 2.x Keras this flag appears to switch BatchNormalization to its legacy
# (V1) code paths, e.g. how `fused` is defaulted and whether `trainable=False`
# forces inference mode. Confirm this against the installed TensorFlow version,
# since private attributes can change or disappear without notice.
BatchNormalization._USE_V2_BEHAVIOR = False
class SubnetFactory:
    """Build the convolutional subnetworks used inside the coupling blocks.

    An instance is handed to ``AllInOneBlock`` as ``subnet_constructor``.
    Every call returns a *new* network:
    3x3 conv -> BatchNorm -> ReLU, twice, then a 3x3 conv with bias that
    produces ``ch_out`` channels.

    Args:
        ch_hidden: number of filters in the two hidden convolutions.
    """

    def __init__(self, ch_hidden):
        self.ch = ch_hidden
        # Template kernel initializer. It is cloned for every layer (see
        # ``_kernel_init``) instead of being shared. In TF >= 2.10 Keras
        # initializers are stateless, so reusing one *unseeded* instance
        # returns identical values on every call. That would give same-shaped
        # kernels, across layers and across subnets, identical starting
        # weights.
        self.init = kr.initializers.he_normal()
        # L2 weight decay applied to every kernel. Regularizers hold no random
        # state, so one instance can be shared safely.
        self.regu = kr.regularizers.l2(1e-4)

    def _kernel_init(self):
        """Return a fresh initializer with the same config as ``self.init``."""
        return type(self.init).from_config(self.init.get_config())

    def _conv(self, filters, use_bias):
        """Return a 3x3, 'same'-padded Conv2D with its own kernel initializer."""
        return kr.layers.Conv2D(filters, 3, padding='same', use_bias=use_bias,
                                kernel_initializer=self._kernel_init(),
                                kernel_regularizer=self.regu)

    def __call__(self, ch_in, ch_out):
        """Return a new subnet that maps its input to ``ch_out`` channels.

        Args:
            ch_in: input channel count. It is not used here; presumably it is
                part of the ``subnet_constructor`` calling convention, with
                Keras inferring the input shape when the model is first called.
            ch_out: number of output channels of the final convolution.

        Returns:
            A ``kr.models.Sequential`` named ``'subnet'``.
        """
        return kr.models.Sequential([
            self._conv(self.ch, use_bias=False),
            BatchNormalization(),
            kr.layers.ReLU(),
            self._conv(self.ch, use_bias=False),
            BatchNormalization(),
            kr.layers.ReLU(),
            self._conv(ch_out, use_bias=True),
        ], name='subnet')
def build_model(args):
    """Assemble a multi-resolution invertible network from a config dict.

    ``args`` is a nested dict whose leaf values are *strings* that get
    evaluated as Python expressions:

    * ``args['data']['data_dimensions']``: input dimensions, unpacked into
      ``ReversibleSequential`` (e.g. ``'(32, 32, 3)'``).
    * ``args['model']['inn_coupling_blocks']``: number of ``AllInOneBlock``
      couplings at each resolution level. Its length sets the number of
      levels.
    * ``args['model']['inn_subnet_channels']``: hidden channels of the
      ``SubnetFactory`` subnets at each level.
    * ``args['model']['affine_clamp']``: ``affine_clamping`` per level.
    * ``args['model']['global_affine_init']``: ``global_affine_init`` per
      level.

    A ``HaarDownsampling`` step is appended between consecutive levels. None
    is appended after the last level.

    Returns:
        The assembled ``ReversibleSequential`` model.

    Raises:
        ValueError: if ``inn_subnet_channels``, ``affine_clamp`` or
            ``global_affine_init`` has fewer entries than there are levels.
    """
    # SECURITY: eval() runs arbitrary code, so ``args`` must come from a
    # trusted source. It is kept because config values such as '[1.5] * 4'
    # are expressions, not plain literals, and ast.literal_eval rejects them.
    input_data_dims = eval(args['data']['data_dimensions'])
    n_blocks_per_res = eval(args['model']['inn_coupling_blocks'])
    channels_per_res = eval(args['model']['inn_subnet_channels'])
    clamps_per_res = eval(args['model']['affine_clamp'])
    actnorm_per_res = eval(args['model']['global_affine_init'])

    n_res_levels = len(n_blocks_per_res)
    # Fail fast with a clear message, rather than an IndexError raised part-way
    # through building the model.
    for key, values in (('inn_subnet_channels', channels_per_res),
                        ('affine_clamp', clamps_per_res),
                        ('global_affine_init', actnorm_per_res)):
        if len(values) < n_res_levels:
            raise ValueError(
                f"args['model']['{key}'] has {len(values)} entries but "
                f"'inn_coupling_blocks' defines {n_res_levels} resolution "
                f"levels")

    model = ReversibleSequential(*input_data_dims)
    for level in range(n_res_levels):
        # All couplings within one level share these settings. The factory
        # builds a new subnet on each call.
        kwargs = {'affine_clamping': clamps_per_res[level],
                  'global_affine_init': actnorm_per_res[level],
                  'global_affine_type': 'SOFTPLUS',
                  'subnet_constructor': SubnetFactory(channels_per_res[level])}
        for _ in range(n_blocks_per_res[level]):
            model.append(AllInOneBlock, **kwargs)
        if level < n_res_levels - 1:
            model.append(HaarDownsampling)
    return model
if __name__ == '__main__':
    # Smoke test: build an example four-level model for 32x32x3 inputs and
    # push a single random batch of 16 samples through it.
    model_cfg = {
        'inn_coupling_blocks': '[2, 4, 4, 4]',
        'inn_subnet_channels': '[16, 32, 64, 128]',
        'affine_clamp': '[1.5] * 4',
        'global_affine_init': '[0.8] * 4',
    }
    args = {'data': {'data_dimensions': '(32, 32, 3)'}, 'model': model_cfg}
    test_model = build_model(args)

    x = tf.constant(np.random.randn(16, 32, 32, 3).astype(np.float32))
    z = test_model(x)