-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
losses.py
237 lines (196 loc) · 7.66 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
# ------------------------------------------------------------------------------
# Portions of this code are from
# CornerNet (https://github.com/princeton-vl/CornerNet)
# Copyright (c) 2018, University of Michigan
# Licensed under the BSD 3-Clause License
# ------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
from .utils import _transpose_and_gather_feat
import torch.nn.functional as F
def _slow_neg_loss(pred, gt):
    '''Focal loss from CornerNet, computed with boolean-mask indexing.

      Arguments:
        pred: predicted scores; log(pred) and log(1 - pred) are taken, so
          values are expected to lie strictly between 0 and 1.
        gt: ground-truth tensor indexed in lockstep with pred. Entries equal
          to 1 are positives, entries below 1 are negatives.
      Returns:
        The negated sum of the positive and negative terms, divided by the
        number of positives when at least one positive is present.
    '''
    is_pos = gt.eq(1)
    is_neg = gt.lt(1)

    p_pos = pred[is_pos]
    p_neg = pred[is_neg]

    # Negatives whose ground truth is close to 1 are down-weighted
    # by (1 - gt)^4.
    penalty = torch.pow(1 - gt[is_neg], 4)

    pos_term = (torch.log(p_pos) * torch.pow(1 - p_pos, 2)).sum()
    neg_term = (torch.log(1 - p_neg) * torch.pow(p_neg, 2) * penalty).sum()

    if p_pos.nelement() == 0:
        # Without positives there is nothing to normalise by.
        return 0 - neg_term
    return 0 - (pos_term + neg_term) / is_pos.float().sum()
def _neg_loss(pred, gt):
    ''' Modified focal loss, same formula as CornerNet.
        Uses dense 0/1 float masks instead of boolean indexing.
      Arguments:
        pred (batch x c x h x w)
        gt (batch x c x h x w)
      Returns:
        The negated sum of the positive and negative terms, divided by the
        number of positives (gt == 1) when there is at least one.
    '''
    pos_mask = gt.eq(1).float()
    neg_mask = gt.lt(1).float()

    pos_term = (torch.log(pred) * torch.pow(1 - pred, 2) * pos_mask).sum()
    # (1 - gt)^4 shrinks the penalty for negatives whose gt is close to 1.
    neg_term = (
        torch.log(1 - pred) * torch.pow(pred, 2)
        * torch.pow(1 - gt, 4) * neg_mask
    ).sum()

    num_pos = pos_mask.sum()
    if num_pos == 0:
        # pos_term is necessarily zero here, so only negatives contribute.
        return 0 - neg_term
    return 0 - (pos_term + neg_term) / num_pos
def _not_faster_neg_loss(pred, gt):
    '''Focal loss written as a single fused expression.

      Positives (gt == 1) and negatives (gt < 1) are folded into one tensor
      so that only one log/pow product is evaluated.
      Arguments:
        pred: predicted scores, expected strictly between 0 and 1.
        gt: ground-truth tensor of the same shape as pred.
      Returns:
        The negated summed loss, divided by the number of positives when
        that number is greater than zero.
    '''
    pos_mask = gt.eq(1).float()
    neg_mask = gt.lt(1).float()

    # pred on negatives, (1 - pred) on positives, so that
    # log(1 - flipped) * flipped^2 yields the matching term for each case.
    flipped = pred * neg_mask + (1 - pred) * pos_mask
    # (1 - gt)^4 on negatives, 1 on positives.
    scale = torch.pow(1 - gt, 4) * neg_mask + pos_mask

    total = (torch.log(1 - flipped) * torch.pow(flipped, 2) * scale).sum()

    num_pos = pos_mask.sum()
    if num_pos > 0:
        total = total / num_pos
    return 0 - total
def _slow_reg_loss(regr, gt_regr, mask):
    '''Smooth L1 regression loss over the entries selected by mask.

      Arguments:
        regr: predictions, same shape as gt_regr.
        gt_regr: regression targets; mask is unsqueezed at dim 2 and
          expanded to this shape.
        mask: per-object validity flags (non-zero means "use this object").
      Returns:
        Summed smooth L1 loss over the selected entries divided by
        (number of valid objects + 1e-4).
    '''
    num = mask.float().sum()
    # Compare against zero so the selector is a proper boolean mask whatever
    # dtype the caller passes (bool, uint8 or float).
    keep = mask.unsqueeze(2).expand_as(gt_regr) > 0
    regr = regr[keep]
    gt_regr = gt_regr[keep]
    # reduction='sum' replaces the deprecated size_average=False.
    regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, reduction='sum')
    # The small constant avoids dividing by zero when no object is valid.
    return regr_loss / (num + 1e-4)
def _reg_loss(regr, gt_regr, mask):
    ''' Smooth L1 regression loss
      Arguments:
        regr (batch x max_objects x dim)
        gt_regr (batch x max_objects x dim)
        mask (batch x max_objects)
      Returns:
        Summed smooth L1 loss divided by (number of valid objects + 1e-4).
    '''
    num = mask.float().sum()
    mask = mask.unsqueeze(2).expand_as(gt_regr).float()
    # Masked-out objects become 0 on both sides and so contribute no loss.
    regr = regr * mask
    gt_regr = gt_regr * mask
    # reduction='sum' replaces the deprecated size_average=False.
    regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, reduction='sum')
    # The small constant avoids dividing by zero when no object is valid.
    return regr_loss / (num + 1e-4)
class FocalLoss(nn.Module):
    '''nn.Module wrapper for focal loss'''

    def __init__(self):
        super(FocalLoss, self).__init__()
        # The loss function is kept on the instance; forward() calls
        # whatever is bound here.
        self.neg_loss = _neg_loss

    def forward(self, out, target):
        '''Return self.neg_loss(out, target).'''
        loss = self.neg_loss(out, target)
        return loss
class RegLoss(nn.Module):
    '''Regression loss for an output tensor
      Arguments:
        output (batch x dim x h x w)
        mask (batch x max_objects)
        ind (batch x max_objects)
        target (batch x max_objects x dim)
    '''

    def __init__(self):
        super(RegLoss, self).__init__()

    def forward(self, output, mask, ind, target):
        # Gather the predictions addressed by ind, then score them against
        # target with _reg_loss, which applies mask.
        gathered = _transpose_and_gather_feat(output, ind)
        return _reg_loss(gathered, target, mask)
class RegL1Loss(nn.Module):
    '''Masked L1 regression loss for an output tensor.

      forward() arguments:
        output: network output passed to _transpose_and_gather_feat.
        mask: per-object validity flags; unsqueezed at dim 2 and expanded
          to the shape of the gathered predictions.
        ind: indices passed to _transpose_and_gather_feat.
        target: regression targets, same shape as the gathered predictions.
      Returns:
        Summed absolute error over unmasked entries divided by
        (sum of the expanded mask + 1e-4), i.e. normalised per element
        rather than per object.
    '''

    def __init__(self):
        super(RegL1Loss, self).__init__()

    def forward(self, output, mask, ind, target):
        pred = _transpose_and_gather_feat(output, ind)
        mask = mask.unsqueeze(2).expand_as(pred).float()
        # reduction='sum' replaces the deprecated size_average=False.
        loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
        # The small constant avoids dividing by zero when the mask is empty.
        return loss / (mask.sum() + 1e-4)
class NormRegL1Loss(nn.Module):
    '''Masked L1 loss on the ratio of prediction to target.

      The gathered prediction is divided by (target + 1e-4) and compared
      against 1, so the error is relative to the target's magnitude.

      forward() arguments:
        output: network output passed to _transpose_and_gather_feat.
        mask: per-object validity flags; unsqueezed at dim 2 and expanded
          to the shape of the gathered predictions.
        ind: indices passed to _transpose_and_gather_feat.
        target: regression targets, same shape as the gathered predictions.
      Returns:
        Summed absolute error over unmasked entries divided by
        (sum of the expanded mask + 1e-4).
    '''

    def __init__(self):
        super(NormRegL1Loss, self).__init__()

    def forward(self, output, mask, ind, target):
        pred = _transpose_and_gather_feat(output, ind)
        mask = mask.unsqueeze(2).expand_as(pred).float()
        # The 1e-4 offset guards against division by a zero target.
        pred = pred / (target + 1e-4)
        # After normalisation the ideal prediction is 1 everywhere.
        # ones_like (unlike target * 0 + 1) does not propagate NaN/inf
        # from the original target values.
        target = torch.ones_like(target)
        # reduction='sum' replaces the deprecated size_average=False.
        loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
        return loss / (mask.sum() + 1e-4)
class RegWeightedL1Loss(nn.Module):
    '''L1 regression loss with an element-wise weight mask.

      Unlike the other masked losses in this file, mask is not unsqueezed
      or expanded here: it is cast to float and multiplied directly into
      both prediction and target, so it must broadcast against them
      (presumably it already has the gathered prediction's shape; verify
      against the caller).

      forward() arguments:
        output: network output passed to _transpose_and_gather_feat.
        mask: weights multiplied into prediction and target.
        ind: indices passed to _transpose_and_gather_feat.
        target: regression targets.
      Returns:
        Summed absolute weighted error divided by (mask.sum() + 1e-4).
    '''

    def __init__(self):
        super(RegWeightedL1Loss, self).__init__()

    def forward(self, output, mask, ind, target):
        pred = _transpose_and_gather_feat(output, ind)
        mask = mask.float()
        # reduction='sum' replaces the deprecated size_average=False.
        loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
        # The small constant avoids dividing by zero when the mask is empty.
        return loss / (mask.sum() + 1e-4)
class L1Loss(nn.Module):
    '''Masked L1 loss averaged over every element.

      Masked-out entries are zeroed on both sides, but the mean is still
      taken over all elements, masked or not.

      forward() arguments:
        output: network output passed to _transpose_and_gather_feat.
        mask: per-object validity flags; unsqueezed at dim 2 and expanded
          to the shape of the gathered predictions.
        ind: indices passed to _transpose_and_gather_feat.
        target: regression targets, same shape as the gathered predictions.
    '''

    def __init__(self):
        super(L1Loss, self).__init__()

    def forward(self, output, mask, ind, target):
        pred = _transpose_and_gather_feat(output, ind)
        mask = mask.unsqueeze(2).expand_as(pred).float()
        # reduction='mean' replaces the deprecated 'elementwise_mean'.
        return F.l1_loss(pred * mask, target * mask, reduction='mean')
class BinRotLoss(nn.Module):
    '''Rotation loss: gathers predictions at ind and delegates to
    compute_rot_loss.
    '''

    def __init__(self):
        super(BinRotLoss, self).__init__()

    def forward(self, output, mask, ind, rotbin, rotres):
        # rotbin / rotres are the bin-class and residual targets that
        # compute_rot_loss consumes.
        gathered = _transpose_and_gather_feat(output, ind)
        return compute_rot_loss(gathered, rotbin, rotres, mask)
def compute_res_loss(output, target):
    '''Return the mean smooth L1 loss between output and target.'''
    # reduction='mean' replaces the deprecated 'elementwise_mean'.
    return F.smooth_l1_loss(output, target, reduction='mean')
# TODO: weight
def compute_bin_loss(output, target, mask):
    '''Mean cross-entropy over bin-classification logits.

      Arguments:
        output: logits, one row per sample.
        target: class index per row.
        mask: validity flags, expanded to output's shape.
      Returns:
        Cross-entropy averaged over all rows. Rows with mask == 0 have
        their logits zeroed but are still included in the average.
    '''
    mask = mask.expand_as(output)
    output = output * mask.float()
    # reduction='mean' replaces the deprecated 'elementwise_mean'.
    return F.cross_entropy(output, target, reduction='mean')
def compute_rot_loss(output, target_bin, target_res, mask):
    '''Two-bin rotation loss: bin classification plus sin/cos residuals.

      Arguments:
        output: (B, 128, 8) [bin1_cls[0], bin1_cls[1], bin1_sin, bin1_cos,
                             bin2_cls[0], bin2_cls[1], bin2_sin, bin2_cos]
        target_bin: (B, 128, 2) [bin1_cls, bin2_cls]
        target_res: (B, 128, 2) [bin1_res, bin2_res]
        mask: (B, 128, 1)
      Returns:
        Sum of both bin classification losses and, for each bin that has
        at least one non-zero target_bin entry, the smooth L1 losses of the
        predicted sin/cos against sin/cos of the target residual.
    '''
    output = output.view(-1, 8)
    target_bin = target_bin.view(-1, 2)
    target_res = target_res.view(-1, 2)
    mask = mask.view(-1, 1)

    loss_bin1 = compute_bin_loss(output[:, 0:2], target_bin[:, 0], mask)
    loss_bin2 = compute_bin_loss(output[:, 4:6], target_bin[:, 1], mask)

    loss_res = torch.zeros_like(loss_bin1)
    # (bin index, sin column of output, cos column of output)
    for bin_id, sin_col, cos_col in ((0, 2, 3), (1, 6, 7)):
        # Evaluate nonzero() once per bin and reuse the indices.
        idx = target_bin[:, bin_id].nonzero()[:, 0]
        if idx.shape[0] == 0:
            # No row falls in this bin; skip so the mean is not taken
            # over an empty selection.
            continue
        valid_output = torch.index_select(output, 0, idx)
        valid_res = torch.index_select(target_res, 0, idx)[:, bin_id]
        loss_sin = compute_res_loss(
            valid_output[:, sin_col], torch.sin(valid_res))
        loss_cos = compute_res_loss(
            valid_output[:, cos_col], torch.cos(valid_res))
        loss_res = loss_res + (loss_sin + loss_cos)
    return loss_bin1 + loss_bin2 + loss_res