-
Notifications
You must be signed in to change notification settings - Fork 0
/
losses.py
71 lines (54 loc) · 2.64 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""Contains losses used for performing image-to-image domain adaptation."""
import tensorflow as tf
def cycle_consistency_loss(real_images, generated_images):
    """Return the mean absolute (L1) error between two image batches.

    The element-wise absolute difference between the real images and their
    generated (fake) counterparts is averaged over every element of the
    batch, so the result is a mean rather than a sum.

    This definition is derived from Equation 2 in:
        Unpaired Image-to-Image Translation using Cycle-Consistent
        Adversarial Networks.
        Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros.

    Args:
        real_images: A batch of images from domain X, a `Tensor` of shape
            [batch_size, height, width, channels].
        generated_images: A batch of generated images made to look like
            they came from domain X, a `Tensor` of shape
            [batch_size, height, width, channels].

    Returns:
        The cycle consistency loss (a single averaged value).
    """
    absolute_error = tf.abs(real_images - generated_images)
    # No axis is given, so the mean runs over all dimensions at once.
    return tf.reduce_mean(absolute_error)
def lsgan_loss_generator(prob_fake_is_real):
    """Compute the LS-GAN loss as minimized by the generator.

    Rather than compute the negative log-likelihood, a least-squares loss
    is used: the generator is penalized by the squared distance between the
    discriminator's output on generated images and the "real" target 1.
    See Equation 2 in:
        Least Squares Generative Adversarial Networks
        Xudong Mao, Qing Li, Haoran Xie, Raymond Y.K. Lau, Zhen Wang, and
        Stephen Paul Smolley.
        https://arxiv.org/pdf/1611.04076.pdf

    Args:
        prob_fake_is_real: The discriminator's estimate that generated
            images made to look like real images are real.

    Returns:
        The generator's LS-GAN loss, averaged over all elements of
        `prob_fake_is_real`.
    """
    # tf.square(x - 1) is used instead of tf.squared_difference(x, 1): the
    # latter is a TF1-only top-level alias, while tf.square and tensor
    # subtraction are available in both TF1 and TF2. The value is the same.
    distance_from_real = prob_fake_is_real - 1
    return tf.reduce_mean(tf.square(distance_from_real))
def lsgan_loss_discriminator(prob_real_is_real, prob_fake_is_real):
    """Compute the LS-GAN loss as minimized by the discriminator.

    Rather than compute the negative log-likelihood, a least-squares loss
    is used to optimize the discriminator: its output on real images is
    pulled toward 1 and its output on generated images toward 0.
    See Equation 2 in:
        Least Squares Generative Adversarial Networks
        Xudong Mao, Qing Li, Haoran Xie, Raymond Y.K. Lau, Zhen Wang, and
        Stephen Paul Smolley.
        https://arxiv.org/pdf/1611.04076.pdf

    Args:
        prob_real_is_real: The discriminator's estimate that images
            actually drawn from the real domain are in fact real.
        prob_fake_is_real: The discriminator's estimate that generated
            images made to look like real images are real.

    Returns:
        The discriminator's LS-GAN loss: half the sum of the mean squared
        error on real images (target 1) and on generated images (target 0).
    """
    # tf.square is used instead of tf.squared_difference, which is a
    # TF1-only top-level alias; tf.square exists in both TF1 and TF2 and
    # yields the same value.
    real_term = tf.reduce_mean(tf.square(prob_real_is_real - 1))
    # Target for generated images is 0, so the squared difference is x**2.
    fake_term = tf.reduce_mean(tf.square(prob_fake_is_real))
    # The two terms are weighted equally; 0.5 averages them.
    return (real_term + fake_term) * 0.5