#!/usr/bin/env python

from __future__ import print_function

import chainer
import chainer.functions as F
from chainer import Variable
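
# FacadeUpdater trains the three pix2pix networks together: an encoder and
# decoder that form the generator, and a conditional discriminator that
# scores (input, image) pairs. Each network has its own named optimizer
# ('enc', 'dec', 'dis'), looked up by name in update_core().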

class FacadeUpdater(chainer.training.StandardUpdater):

    def __init__(self, *args, **kwargs):
        self.enc, self.dec, self.dis = kwargs.pop('models')
        super(FacadeUpdater, self).__init__(*args, **kwargs)
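
    # Generator objective (shared by loss_enc and loss_dec): an L1
    # reconstruction term between the decoded image x_out and the target
    # t_out, weighted by lam1, plus an adversarial term weighted by lam2.
    # softplus(-y) equals -log(sigmoid(y)), so the adversarial term is the
    # sigmoid cross-entropy that rewards the discriminator scoring the
    # generated pair as real, normalized by batch size and spatial extent.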
    def loss_enc(self, enc, x_out, t_out, y_out, lam1=100, lam2=1):
        batchsize, _, w, h = y_out.data.shape
        loss_rec = lam1 * F.mean_absolute_error(x_out, t_out)
        loss_adv = lam2 * F.sum(F.softplus(-y_out)) / batchsize / w / h
        loss = loss_rec + loss_adv
        chainer.report({'loss': loss}, enc)
        return loss

    def loss_dec(self, dec, x_out, t_out, y_out, lam1=100, lam2=1):
        batchsize, _, w, h = y_out.data.shape
        loss_rec = lam1 * F.mean_absolute_error(x_out, t_out)
        loss_adv = lam2 * F.sum(F.softplus(-y_out)) / batchsize / w / h
        loss = loss_rec + loss_adv
        chainer.report({'loss': loss}, dec)
        return loss
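
    # Discriminator objective: softplus(-y_in) = -log(sigmoid(y_in)) pushes
    # real pairs towards a positive score, while softplus(y_out) =
    # -log(1 - sigmoid(y_out)) pushes generated pairs towards a negative
    # score, again normalized by batch size and spatial extent.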
    def loss_dis(self, dis, y_in, y_out):
        batchsize, _, w, h = y_in.data.shape
        L1 = F.sum(F.softplus(-y_in)) / batchsize / w / h
        L2 = F.sum(F.softplus(y_out)) / batchsize / w / h
        loss = L1 + L2
        chainer.report({'loss': loss}, dis)
        return loss

    def update_core(self):
        enc_optimizer = self.get_optimizer('enc')
        dec_optimizer = self.get_optimizer('dec')
        dis_optimizer = self.get_optimizer('dis')

        enc, dec, dis = self.enc, self.dec, self.dis
        xp = enc.xp

        batch = self.get_iterator('main').next()
        batchsize = len(batch)
        in_ch = batch[0][0].shape[0]
        out_ch = batch[0][1].shape[0]
        w_in = 256
        w_out = 256

        # Copy the batch into contiguous input/target arrays on the
        # generator's device (xp is numpy or cupy); images are assumed
        # to be 256x256.
        x_in = xp.zeros((batchsize, in_ch, w_in, w_in)).astype("f")
        t_out = xp.zeros((batchsize, out_ch, w_out, w_out)).astype("f")
        for i in range(batchsize):
            x_in[i, :] = xp.asarray(batch[i][0])
            t_out[i, :] = xp.asarray(batch[i][1])
        x_in = Variable(x_in)
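
        # Forward pass: encode the input, decode it back to an image, then
        # score both the generated pair and the real pair with the
        # conditional discriminator.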
        z = enc(x_in)
        x_out = dec(z)
        y_fake = dis(x_in, x_out)
        y_real = dis(x_in, t_out)
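
        # Update order matters: the encoder update backpropagates through
        # the whole graph; the latent features z are then unchained so the
        # decoder update does not reach back into the encoder; finally,
        # x_in and x_out are unchained so the discriminator update does not
        # backpropagate into the generator.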
        enc_optimizer.update(self.loss_enc, enc, x_out, t_out, y_fake)
        for z_ in z:
            z_.unchain_backward()
        dec_optimizer.update(self.loss_dec, dec, x_out, t_out, y_fake)
        x_in.unchain_backward()
        x_out.unchain_backward()
        dis_optimizer.update(self.loss_dis, dis, y_real, y_fake)
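

# A minimal wiring sketch (not part of the original file) showing how this
# updater is typically constructed. Encoder, Decoder, Discriminator, and
# train_iter are assumptions here -- stand-ins for the generator and
# discriminator models and the paired-image iterator defined elsewhere in
# the repository.
#
#     enc, dec, dis = Encoder(), Decoder(), Discriminator()
#     opt_enc = chainer.optimizers.Adam(alpha=0.0002, beta1=0.5)
#     opt_enc.setup(enc)
#     opt_dec = chainer.optimizers.Adam(alpha=0.0002, beta1=0.5)
#     opt_dec.setup(dec)
#     opt_dis = chainer.optimizers.Adam(alpha=0.0002, beta1=0.5)
#     opt_dis.setup(dis)
#     updater = FacadeUpdater(
#         models=(enc, dec, dis),
#         iterator=train_iter,
#         optimizer={'enc': opt_enc, 'dec': opt_dec, 'dis': opt_dis},
#         device=-1)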