-
Notifications
You must be signed in to change notification settings - Fork 0
/
losses.py
21 lines (17 loc) · 838 Bytes
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from torch.nn import functional as F
from utils import compute_mu_sigma
# def content_loss(feat_g_t, feat_content):
# w = torch.abs(feat_content) + 0.1
# return (w * (feat_g_t - feat_content) ** 2).mean()
def content_loss(feat_g_t, feat_content):
    """Return the mean-squared error between two feature tensors.

    Args:
        feat_g_t: Feature tensor presumably extracted from the generated
            image (TODO confirm against caller).
        feat_content: Target feature tensor compared against ``feat_g_t``.

    Returns:
        The result of ``F.mse_loss`` with its default ``reduction='mean'``,
        i.e. the squared difference averaged over all elements.
    """
    return F.mse_loss(feat_g_t, feat_content)
def style_loss(phi_list_g_t, phi_list_style):
    """Match mean/std statistics of generated and style features, averaged per layer pair.

    For each pair of feature tensors, ``compute_mu_sigma`` yields a
    ``(mu, sigma)`` pair (presumably per-channel mean and standard deviation;
    confirm against ``utils.compute_mu_sigma``). For each pair, the loss adds
    ``mse(mu_g_t, mu_style) + mse(sigma_g_t, sigma_style)``. The sum is then
    divided by the number of pairs actually compared.

    Args:
        phi_list_g_t: Mapping whose values are feature tensors of the
            generated image.
        phi_list_style: Mapping whose values are the corresponding style
            feature tensors. Pairs are formed from ``.values()`` in iteration
            order, NOT by key, so both mappings must list their layers in the
            same order.

    Returns:
        The per-pair average of the accumulated MSE terms. This is presumably
        a scalar tensor, since ``F.mse_loss`` defaults to ``reduction='mean'``.

    Raises:
        ValueError: If either mapping is empty, so no pair can be compared.
    """
    total = 0.
    num_pairs = 0
    for phi_g_t, phi_style in zip(phi_list_g_t.values(), phi_list_style.values()):
        mu_g_t, sigma_g_t = compute_mu_sigma(phi_g_t)
        mu_style, sigma_style = compute_mu_sigma(phi_style)
        total += F.mse_loss(mu_g_t, mu_style)
        total += F.mse_loss(sigma_g_t, sigma_style)
        num_pairs += 1
    if num_pairs == 0:
        # Previously this surfaced as a bare ZeroDivisionError from 0. / 0.
        raise ValueError("style_loss requires at least one pair of feature tensors")
    # zip() stops at the shorter mapping. Dividing by len(phi_list_g_t) would
    # mis-scale the average whenever the two mappings differ in length.
    return total / num_pairs