'''
-- Our loss implementations for wlsep, lsep, warp, bp_mll, and bce with optional weighted learning
--
-- If you use these implementations in your paper, please cite our paper https://arxiv.org/abs/1911.00232
--
-- scores is the output of the model
-- labels is a binary vector with a 1 indicating a positive class for each batch member
-- Both scores and labels have size BxC with B = batch size and C = number of classes
-- weights is an optional tensor of size C used to weight the learning for unbalanced training sets
-- We used w_i = min(count)/count_i for the weights to train the Multi-Moments model, where
--   count_i is the number of examples in the training set with a positive label for class i
--   and min(count) is the number of examples with a positive label for the least common class.
--
-- By Mathew Monfort, mmonfort@mit.edu
'''
import torch
from torch.nn import functional as F
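# A minimal sketch (not part of the original file) of the class-weight computation
# described in the docstring above; `train_labels` is a hypothetical N x C binary
# matrix of training-set labels.
def make_class_weights(train_labels):
    counts = train_labels.sum(0)  # positive-example count per class
    return counts.min() / counts  # w_i = min(count) / count_i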
# https://arxiv.org/abs/1911.00232
def wlsep(scores, labels, weights=None):
    # mask[b, i, j] = 1 when class j is positive and class i is negative for batch member b
    mask = ((labels.unsqueeze(1).expand(labels.size(0), labels.size(1), labels.size(1)) -
             labels.unsqueeze(2).expand(labels.size(0), labels.size(1), labels.size(1))) > 0).float()
    # diffs[b, i, j] = scores[b, i] - scores[b, j], i.e. negative-class score minus positive-class score
    diffs = (scores.unsqueeze(2).expand(labels.size(0), labels.size(1), labels.size(1)) -
             scores.unsqueeze(1).expand(labels.size(0), labels.size(1), labels.size(1)))
    # masking out invalid pairs with -1e10 and padding a row of zeros makes the logsumexp
    # over dim 1 equal to log(1 + sum_i exp(s_i - s_j)) for each positive class j
    if weights is not None:
        return F.pad(diffs.add(-(1 - mask) * 1e10),
                     pad=(0, 0, 0, 1)).logsumexp(dim=1).mul(weights).masked_select(labels.bool()).mean()
    else:
        return F.pad(diffs.add(-(1 - mask) * 1e10),
                     pad=(0, 0, 0, 1)).logsumexp(dim=1).masked_select(labels.bool()).mean()
# http://openaccess.thecvf.com/content_cvpr_2017/html/Li_Improving_Pairwise_Ranking_CVPR_2017_paper.html
def lsep(scores, labels, weights=None):
    # note: weights is accepted for API consistency but is not used by lsep
    # mask[b, i, j] = 1 when class j is positive and class i is negative for batch member b
    mask = ((labels.unsqueeze(1).expand(labels.size(0), labels.size(1), labels.size(1)) -
             labels.unsqueeze(2).expand(labels.size(0), labels.size(1), labels.size(1))) > 0).float()
    # diffs[b, i, j] = scores[b, i] - scores[b, j] for each (negative i, positive j) pair
    diffs = (scores.unsqueeze(2).expand(labels.size(0), labels.size(1), labels.size(1)) -
             scores.unsqueeze(1).expand(labels.size(0), labels.size(1), labels.size(1)))
    # log(1 + sum of exp(s_neg - s_pos) over all valid pairs in the batch)
    return diffs.exp().mul(mask).sum().add(1).log().mean()
""" https://www.aaai.org/ocs/index.php/IJCAI/IJCAI11/paper/viewPaper/2926
We pre-compute the rank weights (rank_w) into a tensor as below:
rank_w = torch.zeros(num_classes)
sum = 0.
for i in range(num_classes):
sum += 1./(i+1)
rank_w[i] = sum
"""
def warp(scores, labels, rank_w, weights=None):
    # mask[b, i, j] = 1 when class j is positive and class i is negative for batch member b
    mask = ((labels.unsqueeze(1).expand(labels.size(0), labels.size(1), labels.size(1)) -
             labels.unsqueeze(2).expand(labels.size(0), labels.size(1), labels.size(1))) > 0).float()
    # hinge margins: 1 + s_neg - s_pos, clamped below at 0 after masking
    diffs = (scores.unsqueeze(2).expand(labels.size(0), labels.size(1), labels.size(1)) -
             scores.unsqueeze(1).expand(labels.size(0), labels.size(1), labels.size(1))).add(1)
    # average the hinge violations over the negative classes for each positive class,
    # then scale by pre-computed rank weights indexed from the descending score sort
    if weights is not None:
        return (diffs.clamp(0, 1e10).mul(mask).sum(1).div(mask.sum(1)).mul(weights).masked_select(labels.bool())
                .mul(rank_w.index_select(0, scores.sort(descending=True)[1].masked_select(labels.bool()))).mean())
    else:
        return (diffs.clamp(0, 1e10).mul(mask).sum(1).div(mask.sum(1)).masked_select(labels.bool())
                .mul(rank_w.index_select(0, scores.sort(descending=True)[1].masked_select(labels.bool()))).mean())
# https://ieeexplore.ieee.org/abstract/document/1683770
def bp_mll(scores, labels, weights=None):
    # mask[b, i, j] = 1 when class j is positive and class i is negative for batch member b
    mask = ((labels.unsqueeze(1).expand(labels.size(0), labels.size(1), labels.size(1)) -
             labels.unsqueeze(2).expand(labels.size(0), labels.size(1), labels.size(1))) > 0).float()
    # diffs[b, i, j] = scores[b, i] - scores[b, j] for each (negative i, positive j) pair
    diffs = (scores.unsqueeze(2).expand(labels.size(0), labels.size(1), labels.size(1)) -
             scores.unsqueeze(1).expand(labels.size(0), labels.size(1), labels.size(1)))
    # sum exp(s_neg - s_pos) over the negative classes for each positive class, then over classes
    if weights is not None:
        return diffs.exp().mul(mask).sum(1).mul(weights).sum(1).mean()
    else:
        return diffs.exp().mul(mask).sum(1).sum(1).mean()
# bceCriterion is used but never defined in the original file; an elementwise binary
# cross-entropy on raw scores (an assumption) keeps this module self-contained.
bceCriterion = torch.nn.BCEWithLogitsLoss(reduction='none')
def bce(output, labels, weights=None):
    # per-class weighting: positive-label terms scaled by (1 - w_i), negative-label terms by w_i
    if weights is not None:
        return (((1. - weights) * labels + weights * (1. - labels)) *
                bceCriterion(output, labels)).sum(1).mean()
    else:
        return bceCriterion(output, labels).sum(1).mean()
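# A quick sanity check (a sketch, not part of the original file) exercising each loss
# on random inputs; B = 4 batch members, C = 10 classes.
if __name__ == '__main__':
    torch.manual_seed(0)
    B, C = 4, 10
    scores = torch.randn(B, C)
    labels = (torch.rand(B, C) < 0.3).float()
    labels[labels.sum(1) == 0, 0] = 1.   # ensure at least one positive label per example
    labels[labels.sum(1) == C, -1] = 0.  # ensure at least one negative label per example
    rank_w = (1. / torch.arange(1, C + 1, dtype=torch.float)).cumsum(0)
    for name, loss in [('wlsep', wlsep(scores, labels)),
                       ('lsep', lsep(scores, labels)),
                       ('warp', warp(scores, labels, rank_w)),
                       ('bp_mll', bp_mll(scores, labels)),
                       ('bce', bce(scores, labels))]:
        print('%-6s %.4f' % (name, loss.item()))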