-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss_function.py
25 lines (20 loc) · 914 Bytes
/
loss_function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Pattern Recognition final project, Group 8 - Kaixiang Huang
# Support functions for CycleGAN
import keras.backend as K
def loss_fn(input, target):
    """Return the mean squared difference between ``input`` and ``target``.

    A constant offset of 1e-12 is added to ``input`` before ``target`` is
    subtracted. The squared difference is then averaged over the last axis
    only (``axis=-1``).

    Args:
        input: Backend tensor holding the values being scored.
        target: Backend tensor holding the reference values.

    Returns:
        A backend tensor with the last axis reduced by the mean.
    """
    # NOTE(review): `input` shadows the builtin of the same name; it is kept
    # because the parameter name is part of the callable's interface.
    shifted = input + 1e-12
    residual = shifted - target
    squared = K.square(residual)
    return K.mean(squared, axis=-1)
def cycle_variables(netG1, netG2):
    """Chain two generator models and expose the tensors of the cycle.

    The first input and first output tensors of ``netG1`` are taken, the
    output is passed through ``netG2``, and a backend function is built that
    maps the ``netG1`` input to both the ``netG1`` output and the ``netG2``
    result.

    Args:
        netG1: Model providing ``inputs[0]`` and ``outputs[0]``.
        netG2: Model called on the single-element list ``[netG1 output]``.

    Returns:
        A 4-tuple ``(real_input, fake_output, rec_input, fn_generate)`` where
        ``fn_generate`` is ``K.function([real_input],
        [fake_output, rec_input])``.
    """
    source = netG1.inputs[0]
    translated = netG1.outputs[0]

    # Feed the first generator's output through the second generator.
    reconstructed = netG2([translated])

    # One callable that evaluates both stages from the original input.
    generate = K.function([source], [translated, reconstructed])
    return source, translated, reconstructed, generate
def loss_generator(netD, real, fake, rec):
    """Build the discriminator, generator and cycle loss tensors.

    ``netD`` is applied to ``real`` and to ``fake``. Using ``loss_fn``:

    * the discriminator loss is the sum of the loss of the real output
      against an all-ones target and the loss of the fake output against an
      all-zeros target;
    * the generator loss is the loss of the fake output against an all-ones
      target.

    The cycle loss is the mean absolute difference between ``rec`` and
    ``real``, taken over all elements.

    Args:
        netD: Model called on ``[real]`` and on ``[fake]``.
        real: Tensor scored by ``netD`` and used as the cycle-loss reference.
        fake: Tensor scored by ``netD`` as the generated sample.
        rec: Tensor compared against ``real`` for the cycle loss.

    Returns:
        A 3-tuple ``(loss_D, loss_G, loss_cyc)``.
    """
    # Score the real sample first, then the fake one (same order as before).
    score_real, score_fake = netD([real]), netD([fake])

    d_loss_real = loss_fn(score_real, K.ones_like(score_real))
    d_loss_fake = loss_fn(score_fake, K.zeros_like(score_fake))

    # The generator is scored against the "real" (all-ones) target.
    g_loss = loss_fn(score_fake, K.ones_like(score_fake))

    cycle_loss = K.mean(K.abs(rec - real))
    return d_loss_real + d_loss_fake, g_loss, cycle_loss