-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathtraining.py
130 lines (106 loc) · 4.46 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import keras.backend as K
from keras.applications import vgg16
'''
Module that defines loss functions and other auxiliary functions used when
training a pastiche model.
'''
def gram_matrix(x, norm_by_channels=False):
    '''
    Returns the Gram matrix of the tensor x.

    Args:
        x: a 3D (H, W, C) activation tensor, or a 4D (B, H, W, C) batch.
        norm_by_channels: if True, divide by C * H * W (Johnson's
            normalization); otherwise divide by H * W only (Google's).

    Returns:
        A (C, C) Gram matrix for 3D input, or a (B, C, C) batch of Gram
        matrices for 4D input, normalized as requested.
    '''
    if K.ndim(x) == 3:
        # Move channels first and flatten each channel map: (C, H*W).
        features = K.batch_flatten(K.permute_dimensions(x, (2, 0, 1)))
        shape = K.shape(x)
        # BUG FIX: the input is (H, W, C), so unpack the axes in that
        # order. The old code read C from axis 0 (and H, W from axes 1, 2),
        # which made the H * W normalization below divide by W * C instead.
        # The C * H * W branch only worked by commutativity of the product.
        H, W, C = shape[0], shape[1], shape[2]
        gram = K.dot(features, K.transpose(features))
    elif K.ndim(x) == 4:
        # Swap from (B, H, W, C) to (B, C, H, W)
        x = K.permute_dimensions(x, (0, 3, 1, 2))
        shape = K.shape(x)
        B, C, H, W = shape[0], shape[1], shape[2], shape[3]
        # Reshape as a batch of 2D matrices with vectorized channels
        features = K.reshape(x, K.stack([B, C, H*W]))
        # This is a batch of Gram matrices (B, C, C).
        gram = K.batch_dot(features, features, axes=2)
    else:
        raise ValueError('The input tensor should be either a 3d (H, W, C) or 4d (B, H, W, C) tensor.')
    # Normalize the Gram matrix
    if norm_by_channels:
        denominator = C * H * W  # Normalization from Johnson
    else:
        denominator = H * W  # Normalization from Google
    gram = gram / K.cast(denominator, x.dtype)
    return gram
def content_loss(x, target):
    '''
    Content loss: mean squared error between a layer's activations and the
    target activations, reduced over spatial and channel axes so the result
    is one scalar per batch item.
    '''
    diff = target - x
    return K.mean(K.square(diff), axis=(1, 2, 3))
def style_loss(x, target, norm_by_channels=False):
    '''
    Style loss: mean squared error between the Gram matrix of the
    activations x and a precomputed target Gram matrix, reduced over the
    two Gram-matrix axes (one scalar per batch item).
    '''
    gram = gram_matrix(x, norm_by_channels=norm_by_channels)
    diff = target - gram
    return K.mean(K.square(diff), axis=(1, 2))
def tv_loss(x):
    '''
    Total variation loss: penalizes differences between neighboring pixels
    so the generated image stays locally coherent. Expects a 4D
    (B, H, W, C) tensor; returns one scalar per batch item.
    '''
    assert K.ndim(x) == 4
    # Squared differences along the vertical and horizontal directions,
    # both cropped to the same (H-1, W-1) spatial extent.
    dy = K.square(x[:, :-1, :-1, :] - x[:, 1:, :-1, :])
    dx = K.square(x[:, :-1, :-1, :] - x[:, :-1, 1:, :])
    return K.sum(dy + dx, axis=(1, 2, 3))
def get_content_features(out_dict, layer_names):
    '''
    Picks the activation tensors of the requested layers out of out_dict,
    preserving the order of layer_names.
    '''
    return [out_dict[name] for name in layer_names]
def get_style_features(out_dict, layer_names, norm_by_channels=False):
    '''
    Computes the Gram matrix of each requested layer's activations,
    returned in the order of layer_names.
    '''
    return [gram_matrix(out_dict[name], norm_by_channels=norm_by_channels)
            for name in layer_names]
def get_loss_net(pastiche_net_output, input_tensor=None):
    '''
    Instantiates a VGG net and applies its layers on top of the pastiche
    net's output.

    Returns the VGG16 model, a dict mapping layer names to activations
    computed on the pastiche output, and a dict mapping layer names to the
    VGG net's own activations (the loss targets).
    '''
    loss_net = vgg16.VGG16(weights='imagenet', include_top=False,
                           input_tensor=input_tensor)
    targets_dict = {layer.name: layer.output for layer in loss_net.layers}
    # Chain every VGG layer onto the pastiche output, recording each
    # intermediate activation. The input layer is skipped.
    outputs_dict = {}
    tensor = pastiche_net_output
    for layer in loss_net.layers[1:]:
        tensor = layer(tensor)
        outputs_dict[layer.name] = tensor
    return loss_net, outputs_dict, targets_dict
def get_style_losses(outputs_dict, targets_dict, style_layers,
                     norm_by_channels=False):
    '''
    Returns the style loss of each desired layer, comparing the pastiche
    activations against the target activations, in style_layers order.
    '''
    losses = []
    for name in style_layers:
        losses.append(style_loss(outputs_dict[name], targets_dict[name],
                                 norm_by_channels=norm_by_channels))
    return losses
def get_content_losses(outputs_dict, targets_dict, content_layers):
    '''
    Returns the content loss of each desired layer, comparing the pastiche
    activations against the target activations, in content_layers order.
    '''
    losses = []
    for name in content_layers:
        losses.append(content_loss(outputs_dict[name], targets_dict[name]))
    return losses
def get_total_loss(content_losses, style_losses, total_var_loss,
                   content_weights, style_weights, tv_weights, class_targets):
    '''
    Combines content, style and total-variation losses into one scalar.

    Each per-sample loss is weighted by gathering the per-class weight for
    that sample's target class, then averaged over the batch.

    Args:
        content_losses: list of per-sample content loss tensors.
        style_losses: list of per-sample style loss tensors.
        total_var_loss: per-sample total variation loss tensor.
        content_weights, style_weights, tv_weights: per-class weight
            tensors, indexed by class id.
        class_targets: tensor of class ids, one per batch sample.

    Returns:
        A tuple (total_loss, weighted_content_losses,
        weighted_style_losses, weighted_tv_loss).
    '''
    total_loss = K.variable(0.)
    # BUG FIX: these accumulators were appended to without ever being
    # initialized, raising NameError on the first loss.
    weighted_content_losses = []
    weighted_style_losses = []
    # Compute content losses
    for loss in content_losses:
        weighted_loss = K.mean(K.gather(content_weights, class_targets) * loss)
        weighted_content_losses.append(weighted_loss)
        total_loss += weighted_loss
    # Compute style losses
    for loss in style_losses:
        weighted_loss = K.mean(K.gather(style_weights, class_targets) * loss)
        weighted_style_losses.append(weighted_loss)
        total_loss += weighted_loss
    # Compute tv loss
    weighted_tv_loss = K.mean(K.gather(tv_weights, class_targets) *
                              total_var_loss)
    total_loss += weighted_tv_loss
    return (total_loss, weighted_content_losses, weighted_style_losses,
            weighted_tv_loss)