forked from d1ggs/cycleGAN-keras
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresidual.py
49 lines (39 loc) · 2.21 KB
/
residual.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from keras.layers import Conv2D, Activation
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.merge import add
from keras import initializers
from keras_contrib.layers import InstanceNormalization
'''
Keras Customizable Residual Unit
This is a simplified implementation of the basic (no bottlenecks) full pre-activation residual unit from He, K., Zhang, X., Ren, S., Sun, J., "Identity Mappings in Deep Residual Networks" (http://arxiv.org/abs/1603.05027v2).
'''
# Random-normal initializers shared by the layers built in this module.
# Convolution kernels: mean 0, standard deviation 0.02.
conv_init = initializers.RandomNormal(mean=0, stddev=0.02)
# Gamma (scale) of the InstanceNormalization layers: mean 1, standard deviation 0.02.
gamma_init = initializers.RandomNormal(mean=1., stddev=0.02)
def conv_block(feat_maps_out, prev):
    '''
    Build the convolutional path of the residual unit.

    Applies the sequence InstanceNormalization -> ReLU -> 3x3 Conv2D
    (padding='same') twice in a row, i.e. normalization and activation come
    before each convolution.

    Args:
        feat_maps_out: number of filters used by both 3x3 convolutions
        prev: the tensor the block is applied to

    Returns:
        The output tensor of the second convolution.
    '''
    out = prev
    # Two identical stages; layers are created in the same order as if the
    # stage were written out twice.
    for _ in range(2):
        # NOTE(review): `training=1` is forwarded to the layer call as in the
        # original code; confirm the installed InstanceNormalization uses it.
        out = InstanceNormalization(gamma_initializer=gamma_init)(out, training=1)
        # Open question kept from the original author: is plain ReLU the
        # better choice here?
        out = Activation('relu')(out)
        out = Conv2D(feat_maps_out, (3, 3), padding='same',
                     kernel_initializer=conv_init)(out)
    return out
def skip_block(feat_maps_in, feat_maps_out, prev):
    '''
    Build the shortcut path of the residual unit.

    Args:
        feat_maps_in: number of channels/filters of the incoming tensor
        feat_maps_out: number of channels/filters the unit must produce
        prev: the tensor the shortcut starts from

    Returns:
        `prev` itself when the two channel counts compare equal; otherwise the
        result of a 1x1 Conv2D with `feat_maps_out` filters applied to `prev`,
        so that the shortcut can be added to the convolutional path.
    '''
    if feat_maps_in != feat_maps_out:
        # Channel counts differ: project the shortcut with a 1x1 convolution.
        return Conv2D(feat_maps_out, (1, 1), padding='same',
                      kernel_initializer=conv_init)(prev)
    # Channel counts match: identity shortcut, no new layer is created.
    return prev
def Residual(feat_maps_in, feat_maps_out, prev_layer):
    '''
    A customizable residual unit with convolutional and shortcut blocks.

    Args:
        feat_maps_in: number of channels/filters coming in, from input or previous layer
        feat_maps_out: how many output channels/filters this block will produce
        prev_layer: the previous layer

    Returns:
        The element-wise sum (keras `add`) of the shortcut path and the
        convolutional path, both built from `prev_layer`.
    '''
    # Build the shortcut first, then the convolutional path, keeping the
    # original layer creation order.
    shortcut = skip_block(feat_maps_in, feat_maps_out, prev_layer)
    residual = conv_block(feat_maps_out, prev_layer)
    return add([shortcut, residual])  # the residual connection