# bi_tempered_loss.py
import torch


def log_t(u, t):
    """Compute log_t for `u`."""
    if t == 1.0:
        return torch.log(u)
    else:
        return (u ** (1.0 - t) - 1.0) / (1.0 - t)


def exp_t(u, t):
    """Compute exp_t for `u`."""
    if t == 1.0:
        return torch.exp(u)
    else:
        return torch.relu(1.0 + (1.0 - t) * u) ** (1.0 / (1.0 - t))
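

# Quick illustrative check (added, not part of the original code): log_t and
# exp_t are inverses of one another wherever the relu in exp_t does not clip,
# and both reduce to the ordinary log/exp at t = 1.0. For example:
#
#     u = torch.tensor([0.5, 1.0, 2.0])
#     torch.allclose(log_t(exp_t(u, 0.5), 0.5), u)    # True
#     torch.allclose(exp_t(u, 1.0), torch.exp(u))     # True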


def compute_normalization_fixed_point(activations, t, num_iters=5):
    """Returns the normalization value for each example (t > 1.0).
    Args:
        activations: A multi-dimensional tensor with last dimension `num_classes`.
        t: Temperature 2 (> 1.0 for tail heaviness).
        num_iters: Number of iterations to run the method.
    Returns:
        A tensor of the same rank as `activations` with the last dimension being 1.
    """
    mu = torch.max(activations, dim=-1).values.view(-1, 1)
    normalized_activations_step_0 = activations - mu
    normalized_activations = normalized_activations_step_0
    # Fixed-point iteration for the log_t partition function.
    i = 0
    while i < num_iters:
        i += 1
        logt_partition = torch.sum(exp_t(normalized_activations, t), dim=-1).view(-1, 1)
        normalized_activations = normalized_activations_step_0 * (logt_partition ** (1.0 - t))
    # One final partition evaluation after the loop, then undo the shift by mu.
    logt_partition = torch.sum(exp_t(normalized_activations, t), dim=-1).view(-1, 1)
    return -log_t(1.0 / logt_partition, t) + mu
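

# Illustrative note (added): the value returned above is the normalization
# that makes the tempered exponentials sum to (approximately) one per
# example, so a quick sanity check under the t > 1.0 assumption is:
#
#     a = torch.randn(2, 10)
#     norm = compute_normalization_fixed_point(a, 1.5)
#     torch.sum(exp_t(a - norm, 1.5), dim=-1)    # ~= tensor([1., 1.])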


def compute_normalization(activations, t, num_iters=5):
    """Returns the normalization value for each example.
    Args:
        activations: A multi-dimensional tensor with last dimension `num_classes`.
        t: Temperature 2 (< 1.0 for finite support, > 1.0 for tail heaviness).
        num_iters: Number of iterations to run the method.
    Returns:
        A tensor of the same rank as `activations` with the last dimension being 1.
    """
    if t < 1.0:
        # Not implemented: the t < 1.0 case does not occur in the authors' experiments.
        return None
    else:
        return compute_normalization_fixed_point(activations, t, num_iters)


def tempered_softmax(activations, t, num_iters=5):
    """Tempered softmax function.
    Args:
        activations: A multi-dimensional tensor with last dimension `num_classes`.
        t: Temperature tensor > 0.0.
        num_iters: Number of iterations to run the method.
    Returns:
        A probabilities tensor.
    """
    if t == 1.0:
        # keepdim=True keeps the class dimension so the subtraction below
        # broadcasts correctly against `activations`.
        normalization_constants = torch.log(torch.sum(torch.exp(activations), dim=-1, keepdim=True))
    else:
        normalization_constants = compute_normalization(activations, t, num_iters)
    return exp_t(activations - normalization_constants, t)
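

# Minimal sanity sketch (added for illustration): at t = 1.0 this coincides
# with the standard softmax, and for t > 1.0 the outputs still form a
# probability distribution over classes:
#
#     a = torch.randn(3, 4)
#     torch.allclose(tempered_softmax(a, 1.0), torch.softmax(a, dim=-1))    # True
#     torch.sum(tempered_softmax(a, 1.2), dim=-1)                           # ~= ones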


def bi_tempered_logistic_loss(activations, labels, t1, t2, label_smoothing=0.0, num_iters=5):
    """Bi-Tempered Logistic Loss.
    Args:
        activations: A multi-dimensional tensor with last dimension `num_classes`.
        labels: A tensor with the same shape and dtype as `activations` (e.g. one-hot or soft labels).
        t1: Temperature 1 (< 1.0 for boundedness).
        t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
        label_smoothing: Label smoothing parameter in [0, 1).
        num_iters: Number of iterations to run the method.
    Returns:
        A loss tensor.
    """
    if label_smoothing > 0.0:
        num_classes = labels.shape[-1]
        labels = (1 - num_classes / (num_classes - 1) * label_smoothing) * labels + label_smoothing / (num_classes - 1)
    probabilities = tempered_softmax(activations, t2, num_iters)
    temp1 = (log_t(labels + 1e-10, t1) - log_t(probabilities, t1)) * labels
    temp2 = (1 / (2 - t1)) * (torch.pow(labels, 2 - t1) - torch.pow(probabilities, 2 - t1))
    loss_values = temp1 - temp2
    return torch.sum(loss_values, dim=-1)
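

# A minimal usage sketch (added for illustration, not part of the original
# file). The temperatures t1=0.8 and t2=1.2 are just example values; with
# t1 = t2 = 1.0 the loss reduces to ordinary softmax cross-entropy, which
# gives a cheap consistency check.
if __name__ == "__main__":
    torch.manual_seed(0)
    targets = torch.tensor([0, 1, 2, 3])
    activations = torch.randn(4, 5)  # 4 examples, 5 classes
    labels = torch.nn.functional.one_hot(targets, num_classes=5).float()

    loss = bi_tempered_logistic_loss(activations, labels, t1=0.8, t2=1.2)
    print("bi-tempered loss:", loss)

    # Consistency check: at t1 = t2 = 1.0 the result should match cross-entropy.
    loss_t1 = bi_tempered_logistic_loss(activations, labels, t1=1.0, t2=1.0)
    ce = torch.nn.functional.cross_entropy(activations, targets, reduction="none")
    print("max |bi-tempered(t=1) - cross_entropy|:", (loss_t1 - ce).abs().max().item())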