-
Notifications
You must be signed in to change notification settings - Fork 38
/
losses.py
67 lines (46 loc) · 2.01 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import math
import torch
from torch import autograd
from torch.nn import functional as F
def viewpoints_loss(viewpoint_pred, viewpoint_target):
    """Return the mean smooth-L1 distance between predicted and target viewpoints.

    Args:
        viewpoint_pred: Tensor of predicted viewpoint values.
        viewpoint_target: Tensor of reference viewpoint values, compared
            element-wise against ``viewpoint_pred``.

    Returns:
        A scalar tensor (smooth-L1 with its default mean reduction).
    """
    return F.smooth_l1_loss(
        input=viewpoint_pred,
        target=viewpoint_target,
        reduction="mean",
    )
def eikonal_loss(eikonal_term, sdf=None, beta=100, deltasdf=False):
    """Compute the eikonal penalty and the minimal-surface penalty.

    Args:
        eikonal_term: Tensor whose last dimension is reduced with an L2 norm
            (presumably the SDF spatial gradients -- confirm with the caller),
            or ``None`` to skip the eikonal penalty.
        sdf: Tensor of SDF values, or ``None`` to skip the minimal-surface
            penalty.
        beta: Decay rate used in ``exp(-beta * |sdf|)``; only read when
            ``deltasdf`` is false.
        deltasdf: When true, the norm of ``eikonal_term`` is pushed toward 0
            instead of 1, and ``sdf`` is pulled toward 0 with a smooth-L1
            penalty instead of the exponential penalty.

    Returns:
        A ``(eikonal, minimal_surface)`` pair. ``eikonal`` is the int ``0``
        when ``eikonal_term`` is ``None`` and a scalar tensor otherwise;
        ``minimal_surface`` is always a scalar tensor.
    """
    # Identity checks: ``== None`` on a tensor goes through tensor comparison
    # machinery instead of simply testing for a missing argument.
    if eikonal_term is None:
        eikonal = 0
    else:
        grad_norm = eikonal_term.norm(dim=-1)
        if deltasdf:
            eikonal = (grad_norm ** 2).mean()
        else:
            eikonal = ((grad_norm - 1) ** 2).mean()

    if sdf is None:
        # Follow the device of ``eikonal_term`` when it exists; otherwise fall
        # back to the default device rather than dereferencing ``None``.
        device = None if eikonal_term is None else eikonal_term.device
        minimal_surface = torch.tensor(0.0, device=device)
    elif deltasdf:
        # smooth_l1_loss already mean-reduces to a scalar.
        minimal_surface = F.smooth_l1_loss(sdf, torch.zeros_like(sdf))
    else:
        minimal_surface = torch.exp(-beta * torch.abs(sdf)).mean()

    return eikonal, minimal_surface
def d_logistic_loss(real_pred, fake_pred):
    """Return the summed softplus losses on real and fake predictions.

    Args:
        real_pred: Tensor of scores for real samples; penalised through
            ``softplus(-real_pred)``.
        fake_pred: Tensor of scores for generated samples; penalised through
            ``softplus(fake_pred)``.

    Returns:
        Scalar tensor: mean real-side loss plus mean fake-side loss.
    """
    on_real = F.softplus(-real_pred).mean()
    on_fake = F.softplus(fake_pred).mean()
    return on_real + on_fake
def d_r1_loss(real_pred, real_img):
    """Return the mean squared gradient of ``real_pred`` w.r.t. ``real_img``.

    Args:
        real_pred: Tensor computed from ``real_img``; it is summed before
            differentiation.
        real_img: Input tensor that requires grad; its first dimension is
            treated as the batch dimension.

    Returns:
        Scalar tensor: per-sample sum of squared gradients, averaged over the
        batch. The graph is retained (``create_graph=True``) so the penalty
        itself can be back-propagated.
    """
    (image_grad,) = autograd.grad(
        outputs=real_pred.sum(),
        inputs=real_img,
        create_graph=True,
    )
    batch_size = image_grad.shape[0]
    per_sample = image_grad.pow(2).reshape(batch_size, -1).sum(1)
    return per_sample.mean()
def g_nonsaturating_loss(fake_pred):
    """Return the mean of ``softplus(-fake_pred)``.

    Args:
        fake_pred: Tensor of scores assigned to generated samples.

    Returns:
        Scalar tensor; it shrinks as the scores grow.
    """
    negated_scores = -fake_pred
    return F.softplus(negated_scores).mean()
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Compute a path-length penalty on the gradient of images w.r.t. latents.

    Args:
        fake_img: Tensor with at least four dimensions that was computed from
            ``latents``; dimensions 2 and 3 set the noise scale.
        latents: Tensor that requires grad; the gradient is squared, summed
            over dimension 2 and averaged over dimension 1.
        mean_path_length: Running mean of the path lengths (number or tensor).
        decay: Step size used to move the running mean toward the current
            batch mean.

    Returns:
        A ``(path_penalty, path_mean, path_lengths)`` triple: the mean squared
        deviation of the path lengths from the updated running mean, that
        running mean detached from the graph, and the path lengths themselves.
    """
    height, width = fake_img.shape[2], fake_img.shape[3]
    # Random projection of the image, scaled by 1 / sqrt(H * W).
    perturbation = torch.randn_like(fake_img) / math.sqrt(height * width)
    projected = (fake_img * perturbation).sum()

    (latent_grad,) = autograd.grad(
        outputs=projected,
        inputs=latents,
        create_graph=True,
        only_inputs=True,
    )
    path_lengths = latent_grad.pow(2).sum(2).mean(1).sqrt()

    # Move the running mean a ``decay`` fraction toward this batch's mean.
    batch_gap = path_lengths.mean() - mean_path_length
    path_mean = mean_path_length + decay * batch_gap

    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_mean.detach(), path_lengths