Skip to content

Commit

Permalink
Corrected R1 loss calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
rosinality committed Aug 21, 2019
1 parent c63e7c2 commit bc4bd85
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ def train(args, dataset, generator, discriminator):

elif args.loss == 'r1':
real_image.requires_grad = True
real_predict = discriminator(real_image, step=step, alpha=alpha)
real_predict = F.softplus(-real_predict).mean()
real_scores = discriminator(real_image, step=step, alpha=alpha)
real_predict = F.softplus(-real_scores).mean()
real_predict.backward(retain_graph=True)

grad_real = grad(
outputs=real_predict.sum(), inputs=real_image, create_graph=True
outputs=real_scores.sum(), inputs=real_image, create_graph=True
)[0]
grad_penalty = (
grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
Expand Down

0 comments on commit bc4bd85

Please sign in to comment.