# -*- coding: utf-8 -*-
"""Loss functions.

* Author: Minseong Kim (tyui592@gmail.com)
"""
import torch
import torch.nn.functional as F
from typing import List, Optional

def calc_meanstd_loss(features: List[torch.Tensor],
                      targets: List[torch.Tensor],
                      weights: Optional[List[float]] = None) -> torch.Tensor:
    """Calculate the mean/std loss over a list of features."""
    if weights is None:
        # Default to uniform weights across the feature levels.
        weights = [1 / len(features)] * len(features)
    loss = 0
    for f, t, w in zip(features, targets, weights):
        # Flatten (B, C, H, W) to (B, C, H*W) and take per-channel
        # statistics over the spatial dimension.
        f_std, f_mean = torch.std_mean(f.flatten(2), dim=2)
        t_std, t_mean = torch.std_mean(t.flatten(2), dim=2)
        loss += (F.mse_loss(f_std, t_std) + F.mse_loss(f_mean, t_mean)) * w
    return loss / len(features)
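
# A minimal usage sketch of calc_meanstd_loss on random multi-level feature
# maps. The (B, C, H, W) shapes and the helper name `_demo_meanstd_loss` are
# illustrative assumptions, not anything this module defines.
def _demo_meanstd_loss() -> None:
    """Run calc_meanstd_loss on two random feature/target pairs."""
    feats = [torch.randn(2, 64, 32, 32), torch.randn(2, 128, 16, 16)]
    tgts = [torch.randn_like(f) for f in feats]
    print('mean/std loss:', calc_meanstd_loss(feats, tgts).item())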

def calc_l2_loss(features: List[torch.Tensor],
                 targets: List[torch.Tensor],
                 weights: Optional[List[float]] = None) -> torch.Tensor:
    """Calculate the content (L2) loss over a list of features."""
    if weights is None:
        # Default to uniform weights across the feature levels.
        weights = [1 / len(features)] * len(features)
    loss = 0
    for f, t, w in zip(features, targets, weights):
        loss += F.mse_loss(f, t) * w
    return loss / len(features)
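
# A minimal usage sketch of calc_l2_loss with explicit per-level weights.
# The shapes, the chosen weights, and the helper name `_demo_l2_loss` are
# illustrative assumptions.
def _demo_l2_loss() -> None:
    """Run calc_l2_loss, weighting the second feature level more heavily."""
    feats = [torch.randn(2, 256, 8, 8), torch.randn(2, 512, 4, 4)]
    tgts = [torch.randn_like(f) for f in feats]
    # Any weight list matching len(features) works; None means uniform.
    print('content loss:', calc_l2_loss(feats, tgts, weights=[0.3, 0.7]).item())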

def calc_uncorrelation_loss(features: List[torch.Tensor],
                            weights: Optional[List[float]] = None,
                            eps: float = 1e-5) -> torch.Tensor:
    """Calculate the uncorrelation loss over a list of features."""
    if weights is None:
        # Default to uniform weights across the feature levels.
        weights = [1 / len(features)] * len(features)
    loss = 0
    for f, w in zip(features, weights):
        # Flatten each feature map to per-channel vectors: (B, C, H*W).
        v = f.flatten(2)
        # Per-channel mean vector.
        m = torch.mean(v, dim=2, keepdim=True)
        # Shift to zero mean.
        zm = v - m
        # Channel-by-channel covariance: (B, C, C).
        cov = torch.bmm(zm, zm.transpose(2, 1))
        # Normalize the covariance into a correlation coefficient.
        zm_std = torch.sqrt(torch.sum(torch.pow(zm, 2), dim=2, keepdim=True))
        denominator = torch.bmm(zm_std, zm_std.transpose(2, 1))
        corr = cov / (denominator + eps)
        # Sum all off-diagonal terms (the diagonal is the self-correlation).
        num_ch = corr.shape[1]
        ones = torch.ones(num_ch).unsqueeze(0).type_as(corr)
        diag = torch.eye(num_ch).unsqueeze(0).type_as(corr)
        offdiag = ones - diag
        # Normalize by the number of off-diagonal entries.
        uncorr_loss = torch.sum(torch.abs(corr) * offdiag) / torch.sum(offdiag)
        loss += uncorr_loss * w
    return loss / len(features)
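
# A minimal usage sketch of calc_uncorrelation_loss. Duplicated channels are
# perfectly correlated, so copying one channel across the whole map should
# drive the loss toward 1, while random channels stay near 0. The shapes and
# the helper name `_demo_uncorrelation_loss` are illustrative assumptions.
def _demo_uncorrelation_loss() -> None:
    """Compare random channels against fully correlated (duplicated) ones."""
    random_feat = torch.randn(2, 16, 8, 8)
    duplicated = random_feat[:, :1].repeat(1, 16, 1, 1)  # every channel equal
    print('random channels:', calc_uncorrelation_loss([random_feat]).item())
    print('duplicated channels:', calc_uncorrelation_loss([duplicated]).item())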

def calc_channel_loss(feature: torch.Tensor,
                      eps: float = 1e-5) -> torch.Tensor:
    """Calculate a soft count (approximate l0-norm) of nonzero-response channels."""
    B, C = feature.shape[:2]
    vector = feature.flatten(2)
    # A channel's l2-norm is ~0 when the channel is all zeros, so
    # l2 / (l2 + eps) acts as a soft indicator in [0, 1) of channel activity.
    l2_norm = torch.norm(vector, p=2, dim=2)
    l0_norm = l2_norm / (l2_norm + eps)
    channel_loss = torch.sum(l0_norm)
    return channel_loss / (B * C)
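
# A minimal usage sketch of calc_channel_loss. Zeroing half the channels
# should roughly halve the soft channel count. The shapes and the helper
# name `_demo_channel_loss` are illustrative assumptions.
def _demo_channel_loss() -> None:
    """Show that zeroed channels lower the soft l0-norm."""
    feat = torch.randn(2, 32, 8, 8)
    print('dense:', calc_channel_loss(feat).item())    # close to 1.0
    feat[:, 16:] = 0.0                                 # silence half the channels
    print('sparse:', calc_channel_loss(feat).item())   # close to 0.5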

def calc_xor_loss(feature: torch.Tensor,
                  eps: float = 1e-5) -> torch.Tensor:
    """Calculate the variation of nonzero/zero channel responses across the batch."""
    B, C = feature.shape[:2]
    vector = feature.flatten(2)
    l2_norm = torch.norm(vector, p=2, dim=2)
    # Soft per-channel activity indicator, as in calc_channel_loss: (B, C).
    l0_norm = l2_norm / (l2_norm + eps)
    # Pairwise absolute differences between samples act as a soft XOR of
    # their channel-activity patterns: (B, B, C), summed over channels.
    l0_norm_adddim = l0_norm.unsqueeze(0)
    diff = torch.abs(l0_norm_adddim - l0_norm_adddim.transpose(1, 0))
    xor = torch.sum(diff, dim=2)
    # The matrix is symmetric, so count each unordered pair once via triu,
    # then normalize by the number of pairs.
    xor_loss = torch.sum(torch.triu(xor))
    return xor_loss / (C * B * (B - 1) / 2)
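
# A minimal usage sketch of calc_xor_loss. The pair normalization
# B * (B - 1) / 2 requires a batch size of at least 2. A batch whose samples
# share one activity pattern scores near 0; making one sample's pattern
# differ raises the loss. The shapes and the helper name `_demo_xor_loss`
# are illustrative assumptions.
def _demo_xor_loss() -> None:
    """Compare uniform channel-activity patterns against a mixed batch."""
    same = torch.randn(4, 8, 8, 8)
    print('same pattern:', calc_xor_loss(same).item())    # near 0
    mixed = same.clone()
    mixed[0, :4] = 0.0  # sample 0 silences half of its channels
    print('mixed pattern:', calc_xor_loss(mixed).item())  # noticeably larger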