logit_margin_l1.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class LogitMarginL1(nn.Module):
"""Add marginal penalty to logits:
CE + alpha * max(0, max(l^n) - l^n - margin)
Args:
margin (float, optional): The margin value. Defaults to 10.
alpha (float, optional): The balancing weight. Defaults to 0.1.
ignore_index (int, optional):
Specifies a target value that is ignored
during training. Defaults to -100.
The following args are related to balancing weight (alpha) scheduling.
Note all the results presented in our paper are obtained without the scheduling strategy.
So it's fine to ignore if you don't want to try it.
schedule (str, optional):
Different stragety to schedule the balancing weight alpha or not:
"" | add | multiply | step. Defaults to "" (no scheduling).
To activate schedule, you should call function
`schedula_alpha` every epoch in your training code.
mu (float, optional): scheduling weight. Defaults to 0.
max_alpha (float, optional): Defaults to 100.0.
step_size (int, optional): The step size for updating alpha. Defaults to 100.
"""
    def __init__(self,
                 margin: float = 10,
                 alpha: float = 0.1,
                 ignore_index: int = -100,
                 schedule: str = "",
                 mu: float = 0,
                 max_alpha: float = 100.0,
                 step_size: int = 100):
        super().__init__()
        assert schedule in ("", "add", "multiply", "step")
        self.margin = margin
        self.alpha = alpha
        self.ignore_index = ignore_index
        self.mu = mu
        self.schedule = schedule
        self.max_alpha = max_alpha
        self.step_size = step_size
        self.cross_entropy = nn.CrossEntropyLoss()

    @property
    def names(self):
        return "loss", "loss_ce", "loss_margin_l1"

    def schedule_alpha(self, epoch):
        """Should be called once per epoch in the training pipeline if you
        want to schedule alpha."""
        if self.schedule == "add":
            self.alpha = min(self.alpha + self.mu, self.max_alpha)
        elif self.schedule == "multiply":
            self.alpha = min(self.alpha * self.mu, self.max_alpha)
        elif self.schedule == "step":
            if (epoch + 1) % self.step_size == 0:
                self.alpha = min(self.alpha * self.mu, self.max_alpha)

    def get_diff(self, inputs):
        # Distance of every logit to the max logit of its sample; keepdim
        # plus broadcasting replaces the explicit repeat over the class dim.
        max_values = inputs.max(dim=1, keepdim=True).values
        diff = max_values - inputs
        return diff

    def forward(self, inputs, targets):
        if inputs.dim() > 2:
            inputs = inputs.view(inputs.size(0), inputs.size(1), -1)  # N,C,H,W => N,C,H*W
            inputs = inputs.transpose(1, 2)  # N,C,H*W => N,H*W,C
            inputs = inputs.contiguous().view(-1, inputs.size(2))  # N,H*W,C => N*H*W,C
            targets = targets.view(-1)

        if self.ignore_index >= 0:
            # Boolean masking keeps the batch dimension even when only one
            # valid element remains (nonzero().squeeze() would drop it and
            # produce a 0-d index tensor).
            valid = targets != self.ignore_index
            inputs = inputs[valid]
            targets = targets[valid]

        loss_ce = self.cross_entropy(inputs, targets)

        # get logit distance
        diff = self.get_diff(inputs)
        # linear penalty where logit distances are larger than the margin
        loss_margin = F.relu(diff - self.margin).mean()
        loss = loss_ce + self.alpha * loss_margin

        return loss, loss_ce, loss_margin
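

# --- Usage sketch ------------------------------------------------------------
# A minimal, hedged example of how this loss might be wired into a training
# loop. The random logits, batch/class sizes, and the commented training-loop
# skeleton below are illustrative assumptions, not part of the original file;
# only LogitMarginL1 itself comes from this module.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_classes, batch_size = 10, 4

    # Stand-in for a real network's output: random logits of shape (N, C).
    logits = torch.randn(batch_size, num_classes, requires_grad=True)
    targets = torch.randint(0, num_classes, (batch_size,))

    criterion = LogitMarginL1(margin=10, alpha=0.1)
    loss, loss_ce, loss_margin = criterion(logits, targets)
    print(dict(zip(criterion.names,
                   (loss.item(), loss_ce.item(), loss_margin.item()))))

    # Gradients flow through both the CE term and the margin penalty.
    loss.backward()

    # If alpha scheduling is wanted, call schedule_alpha once per epoch, e.g.:
    # for epoch in range(num_epochs):
    #     train_one_epoch(...)          # hypothetical helper
    #     criterion.schedule_alpha(epoch)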