"""Pytorch implementation of Class-Balanced-Loss
Reference: "Class-Balanced Loss Based on Effective Number of Samples"
Authors: Yin Cui and
Menglin Jia and
Tsung Yi Lin and
Yang Song and
Serge J. Belongie
https://arxiv.org/abs/1901.05555, CVPR'19.
"""
import numpy as np
import torch
import torch.nn.functional as F


def focal_loss(labels, logits, alpha, gamma):
    """Compute the focal loss between `logits` and the ground truth `labels`.

    Focal loss = -alpha_t * (1 - pt)^gamma * log(pt),
    where pt is the probability of the true class:
    pt = p if the true label is 1, otherwise pt = 1 - p, with p = sigmoid(logit).

    Args:
      labels: A float tensor of size [batch, num_classes].
      logits: A float tensor of size [batch, num_classes].
      alpha: A float tensor of size [batch, num_classes] specifying the
        per-example weight for balanced cross entropy.
      gamma: A float scalar modulating the loss from hard and easy examples.

    Returns:
      focal_loss: A float32 scalar representing the normalized total loss.
    """
    bce_loss = F.binary_cross_entropy_with_logits(
        input=logits, target=labels, reduction="none")

    if gamma == 0.0:
        modulator = 1.0
    else:
        # Numerically stable form of the modulator (1 - pt)^gamma, using
        # log(1 - pt) = -labels * logits - log(1 + exp(-logits)).
        modulator = torch.exp(-gamma * labels * logits
                              - gamma * torch.log1p(torch.exp(-1.0 * logits)))

    loss = modulator * bce_loss
    weighted_loss = alpha * loss

    # Sum over the batch and normalize by the number of positive labels.
    focal_loss = torch.sum(weighted_loss) / torch.sum(labels)
    return focal_loss
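

# Illustrative sanity check (an addition for clarity, not part of the original
# file): with gamma = 0 and a uniform alpha of 1, focal_loss reduces to plain
# sigmoid cross entropy summed over the batch and normalized by the number of
# positive labels.
def _focal_loss_sanity_check():
    torch.manual_seed(0)
    labels = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
    logits = torch.randn(2, 2)
    fl = focal_loss(labels, logits, alpha=torch.ones_like(logits), gamma=0.0)
    bce = F.binary_cross_entropy_with_logits(logits, labels, reduction="sum")
    assert torch.allclose(fl, bce / labels.sum())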


def CB_loss(logits, labels, samples_per_cls=None, no_of_classes=4,
            loss_type='focal', beta=0.9999, gamma=1.0):
    """Compute the Class-Balanced Loss between `logits` and the ground truth `labels`.

    Class-Balanced Loss: ((1 - beta) / (1 - beta^n)) * Loss(labels, logits),
    where Loss is one of the standard losses used for neural networks and n is
    the number of training samples of the ground-truth class.

    Args:
      logits: A float tensor of size [batch, no_of_classes].
      labels: An int tensor of size [batch].
      samples_per_cls: A python list of size [no_of_classes] holding the
        training sample count of each class. Defaults to the class counts of
        the Rachel train set.
      no_of_classes: int. Total number of classes.
      loss_type: string. One of "sigmoid", "focal", "softmax".
      beta: float. Hyperparameter for the class-balanced loss.
      gamma: float. Hyperparameter for the focal loss.

    Returns:
      cb_loss: A float tensor representing the class-balanced loss.
    """
    if samples_per_cls is None:
        # Class counts of the Rachel train set, used when no counts are given.
        samples_per_cls = [1022, 841, 472, 70]

    # Effective number of samples per class: E_n = (1 - beta^n) / (1 - beta).
    # The per-class weight is proportional to 1 / E_n, normalized so that the
    # weights sum to no_of_classes.
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    weights = (1.0 - beta) / np.array(effective_num)
    weights = weights / np.sum(weights) * no_of_classes

    labels_one_hot = F.one_hot(labels.long(), no_of_classes).float()

    # Expand the per-class weights to a [batch, no_of_classes] tensor in which
    # every row carries the weight of that example's ground-truth class.
    # The tensor is built on the same device as the logits.
    weights = torch.tensor(weights, dtype=torch.float32, device=logits.device)
    weights = weights.unsqueeze(0).repeat(labels_one_hot.shape[0], 1) * labels_one_hot
    weights = weights.sum(1).unsqueeze(1).repeat(1, no_of_classes)

    if loss_type == "focal":
        cb_loss = focal_loss(labels_one_hot, logits, weights, gamma)
    elif loss_type == "sigmoid":
        cb_loss = F.binary_cross_entropy_with_logits(
            input=logits, target=labels_one_hot, weight=weights)
    elif loss_type == "softmax":
        pred = logits.softmax(dim=1)
        cb_loss = F.binary_cross_entropy(
            input=pred, target=labels_one_hot, weight=weights)
    else:
        raise ValueError("Unknown loss_type: {}".format(loss_type))
    return cb_loss
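

# Worked example of the effective-number weighting (an illustration added for
# clarity, not part of the original file): with the default counts
# [1022, 841, 472, 70] and beta = 0.9999, E_n = (1 - beta**n) / (1 - beta)
# gives the rarest class (70 samples) roughly 14x the weight of the most
# frequent one (1022 samples).
def _effective_number_example():
    beta = 0.9999
    samples_per_cls = [1022, 841, 472, 70]
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    weights = (1.0 - beta) / effective_num
    weights = weights / weights.sum() * len(samples_per_cls)
    print('class-balanced weights ', weights)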


if __name__ == '__main__':
    no_of_classes = 5
    logits = torch.rand(10, no_of_classes).float()
    labels = torch.randint(0, no_of_classes, size=(10,))
    print('labels ', labels)
    beta = 0.9999
    gamma = 2.0
    samples_per_cls = [2, 3, 1, 2, 2]
    loss_type = "focal"
    cb_loss = CB_loss(logits, labels, samples_per_cls, no_of_classes,
                      loss_type, beta, gamma)
    print(cb_loss)
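

# Minimal usage sketch in a training loop (illustrative; `model`, `optimizer`,
# and `train_loader` are assumed names, not part of this file):
#
#   for images, targets in train_loader:
#       out = model(images)
#       loss = CB_loss(out, targets, samples_per_cls=[1022, 841, 472, 70],
#                      no_of_classes=4, loss_type='focal', beta=0.9999,
#                      gamma=1.0)
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()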