"""Gradient-penalty and gradient-norm helpers for WGAN-GP training.

`params` is assumed to be defined elsewhere in this repo (e.g. a config
module) and to expose the target device under params['DEVICE'].
"""
import torch


def compute_gradient_penalty(D, real_samples, fake_samples, image_c, z_speech, noise):
    """Calculates the gradient penalty loss for WGAN-GP:
    E[(||grad_x D(x_hat)||_2 - 1)^2] over random interpolates x_hat."""
    # Random weight term for interpolation between real and fake samples
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, 1, device=params['DEVICE'])
    # Random interpolation between real and fake samples
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(image_c, interpolates, z_speech, noise)
    # Ones matching the critic output, used as grad_outputs below
    ones = torch.ones_like(d_interpolates)
    # Gradient of the critic output w.r.t. the interpolated samples
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=ones,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    # Penalize the deviation of each per-sample gradient L2 norm from 1
    gradient_penalty = 0
    for grad in gradients:
        grad = grad.view(grad.size(0), -1)  # flatten per sample before taking the norm
        gradient_penalty += ((grad.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty


def compute_gradient_penalty_F(D, real_samples, fake_samples, image_c):
    """Calculates the gradient penalty loss for WGAN-GP, for a critic
    called as D(samples, image_c)."""
    # Random weight term for interpolation between real and fake samples
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, 1, device=params['DEVICE'])
    # Random interpolation between real and fake samples
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates, image_c)
    # Ones matching the critic output, used as grad_outputs below
    ones = torch.ones_like(d_interpolates)
    # Gradient of the critic output w.r.t. the interpolated samples
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=ones,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    # Penalize the deviation of each per-sample gradient L2 norm from 1
    gradient_penalty = 0
    for grad in gradients:
        grad = grad.view(grad.size(0), -1)  # flatten per sample before taking the norm
        gradient_penalty += ((grad.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
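

# A minimal usage sketch (not part of the original file): one critic update
# with the WGAN-GP objective, using compute_gradient_penalty_F above. The
# critic/generator signatures, argument names, and optimizer are hypothetical
# placeholders; lambda_gp = 10 is the default from Gulrajani et al. (2017).
def _example_critic_step(critic, generator, optimizer_D, real_video, image_c, z, lambda_gp=10.0):
    # Generate fakes without tracking generator gradients during the critic step
    fake_video = generator(image_c, z).detach()
    gp = compute_gradient_penalty_F(critic, real_video, fake_video, image_c)
    # Wasserstein critic loss plus the weighted gradient penalty
    d_loss = -critic(real_video, image_c).mean() + critic(fake_video, image_c).mean() + lambda_gp * gp
    optimizer_D.zero_grad()
    d_loss.backward()
    optimizer_D.step()
    return d_loss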


def compute_grad_P(D, video, speech):
    """Mean per-sample L2 norm of the critic's gradient w.r.t. the video input."""
    d_out = D(video.requires_grad_(True), speech)
    # Gradient of the critic output w.r.t. the video input
    grad_dout = torch.autograd.grad(
        outputs=d_out,
        inputs=video,
        grad_outputs=torch.ones_like(d_out),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    grad_norm = 0
    for grad in grad_dout:
        grad = grad.view(grad.size(0), -1)  # flatten per sample
        grad_norm += grad.norm(2, dim=1).mean()
    return grad_norm


def compute_grad_F(D, video, image_c):
    """Mean per-sample L2 norm of the critic's gradient w.r.t. the video input,
    for a critic called as D(video, image_c)."""
    d_out = D(video.requires_grad_(True), image_c)
    # Gradient of the critic output w.r.t. the video input
    grad_dout = torch.autograd.grad(
        outputs=d_out,
        inputs=video,
        grad_outputs=torch.ones_like(d_out),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    grad_norm = 0
    for grad in grad_dout:
        grad = grad.view(grad.size(0), -1)  # flatten per sample
        grad_norm += grad.norm(2, dim=1).mean()
    return grad_norm
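

# Illustrative sketch (not part of the original file): compute_grad_P returns
# the mean gradient norm of the critic w.r.t. its video input, which can be
# logged to monitor critic stability. `writer` is assumed to be a
# torch.utils.tensorboard.SummaryWriter; all names here are hypothetical.
def _example_log_grad_norm(critic, video_batch, speech_batch, writer, step):
    grad_norm = compute_grad_P(critic, video_batch, speech_batch)
    # .item() detaches and converts to a Python float, so logging keeps no graph
    writer.add_scalar('critic/input_grad_norm', grad_norm.item(), step)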