forked from xavysp/TEED
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss2.py
92 lines (71 loc) · 3.35 KB
/
loss2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import torch.nn.functional as F
from utils.AF.Fsmish import smish as Fsmish
def bdcn_loss2(inputs, targets, l_weight=1.1):
    """Class-balanced binary cross-entropy (BDCN loss as modified in DexiNed).

    Args:
        inputs: Raw logits; ``torch.sigmoid`` is applied inside this function.
            The per-sample mean is taken over dims (1, 2, 3), so a 4-D tensor
            is required.
        targets: Ground-truth map with the same shape as ``inputs``. It is
            truncated with ``.long()`` first, so only values >= 1 are treated
            as positive; everything else is negative.
        l_weight: Scalar multiplier applied to the final loss.

    Returns:
        A scalar tensor: ``l_weight`` times the batch sum of the per-sample
        mean of the weighted BCE.
    """
    labels = targets.long().float()

    # Classify every pixel once, BEFORE any weight is written. Writing the
    # weights into the label tensor with two successive masked assignments
    # lets the second assignment (`<= 0`) also match positives whose weight
    # just became 0 (which happens when there are no negative pixels).
    positive = labels > 0.0
    num_positive = positive.sum().float()
    num_negative = (~positive).sum().float()
    total = num_positive + num_negative

    # Each class is weighted by the frequency of the *other* class; the
    # negative class carries an extra fixed 1.1 factor.
    pos_weight = 1.0 * num_negative / total
    neg_weight = 1.1 * num_positive / total
    weights = torch.where(positive, pos_weight, neg_weight)

    probs = torch.sigmoid(inputs)
    cost = F.binary_cross_entropy(probs, labels, weight=weights, reduction='none')
    # Mean per sample, then sum across the batch.
    cost = torch.sum(cost.float().mean((1, 2, 3)))
    return l_weight * cost
# ------------ cats losses ----------
def bdrloss(prediction, label, radius, device='cpu'):
    """Boundary tracing loss that handles the confusing pixels.

    For every labelled pixel, the loss is ``-log`` of the ratio between the
    prediction mass on labelled pixels and the prediction mass on labelled
    plus nearby unlabelled pixels, both summed over a
    ``(2 * radius + 1)``-wide square window.

    Args:
        prediction: Probabilities (already passed through a sigmoid by the
            caller in this file). Must be 4-D with a single channel, because
            the box filter has one input channel and the mean is taken over
            dims (1, 2, 3).
        label: Map of the same shape as ``prediction``; pixels equal to 1 are
            the ones that receive a loss.
        radius: Half-width of the square neighbourhood.
        device: Device on which the box filter is created; it has to match
            the device of ``prediction`` and ``label``.

    Returns:
        A scalar tensor: the batch sum of the per-sample mean cost.
    """
    # Work in float like textureloss does, so double/half predictions and
    # integer/bool labels do not hit a dtype mismatch inside conv2d.
    prediction = prediction.float()
    label = label.float()

    kernel_size = 2 * radius + 1
    # Box filter created directly on the target device; a freshly created
    # tensor never requires grad.
    filt = torch.ones(1, 1, kernel_size, kernel_size, dtype=prediction.dtype, device=device)

    # Window sum of the predictions on labelled pixels, kept only at
    # labelled positions.
    bdr_pred = prediction * label
    pred_bdr_sum = label * F.conv2d(bdr_pred, filt, bias=None, stride=1, padding=radius)

    # Unlabelled positions whose window sum of labels is non-zero.
    texture_mask = F.conv2d(label, filt, bias=None, stride=1, padding=radius)
    mask = (texture_mask != 0).float()
    mask[label == 1] = 0
    pred_texture_sum = F.conv2d(
        prediction * (1 - label) * mask, filt, bias=None, stride=1, padding=radius)

    # The clamp keeps the log finite when either sum is zero.
    softmax_map = torch.clamp(
        pred_bdr_sum / (pred_texture_sum + pred_bdr_sum + 1e-10), 1e-10, 1 - 1e-10)
    cost = -label * torch.log(softmax_map)
    cost[label == 0] = 0
    return torch.sum(cost.float().mean((1, 2, 3)))
def textureloss(prediction, label, mask_radius, device='cpu'):
    """Texture suppression loss that smooths the texture regions.

    Penalises prediction mass, summed over 3x3 neighbourhoods, at positions
    whose ``(2 * mask_radius + 1)``-wide window contains no positive label sum.

    Args:
        prediction: Probabilities; cast to float. Must be 4-D with a single
            channel (the box filters have one input channel and the mean is
            taken over dims (1, 2, 3)).
        label: Map of the same shape as ``prediction``; cast to float.
        mask_radius: Half-width of the window used to exclude positions that
            are close to a labelled pixel.
        device: Device the box filters are moved to.

    Returns:
        A scalar tensor: the batch sum of the per-sample mean penalty.
    """
    mask_size = 2 * mask_radius + 1
    box3 = torch.ones((1, 1, 3, 3)).to(device)
    box_mask = torch.ones((1, 1, mask_size, mask_size)).to(device)

    # Sum of the predictions in each 3x3 neighbourhood.
    local_pred = F.conv2d(prediction.float(), box3, bias=None, stride=1, padding=1)
    # Sum of the label values in each mask window.
    local_label = F.conv2d(label.float(), box_mask, bias=None, stride=1, padding=mask_radius)

    # Dividing by 9 (the 3x3 window size) turns the sum into a mean; the
    # clamp keeps the log finite.
    penalty = -torch.log(torch.clamp(1 - local_pred / 9, 1e-10, 1 - 1e-10))
    # Positions with a positive label sum in range contribute nothing.
    penalty = penalty.masked_fill(local_label > 0, 0)
    return penalty.float().mean((1, 2, 3)).sum()
def cats_loss(prediction, label, l_weight=(0., 0.), device='cpu'):
    """CATS tracing loss: balanced BCE plus boundary and texture terms.

    Args:
        prediction: Raw logits; ``torch.sigmoid`` is applied inside this
            function. Must be 4-D (the mean is taken over dims (1, 2, 3)).
        label: Map of the same shape as ``prediction``. In the BCE term,
            1 is weighted by ``beta``, 0 by ``1.1 * (1 - beta)`` and 2 gets
            weight 0; any other value is used as its own weight.
        l_weight: Pair ``(tex_factor, bdr_factor)`` scaling the texture and
            boundary terms.
        device: Forwarded to ``textureloss`` and ``bdrloss`` for their filters.

    Returns:
        A scalar tensor:
        ``bce + bdr_factor * bdrloss + tex_factor * textureloss``.
    """
    tex_factor, bdr_factor = l_weight
    balanced_w = 1.1
    label = label.float()
    prediction = prediction.float()

    with torch.no_grad():
        # Classify pixels from the labels BEFORE writing any weight, so a
        # weight that happens to equal 0 (beta == 0 when there are no
        # negative pixels) is not mistaken for a negative label and
        # overwritten by the next assignment.
        is_edge = label == 1
        is_background = label == 0
        is_ignored = label == 2

        num_positive = is_edge.sum().float()
        num_negative = is_background.sum().float()
        beta = num_negative / (num_positive + num_negative)

        mask = label.clone()
        mask[is_edge] = beta
        mask[is_background] = balanced_w * (1 - beta)
        mask[is_ignored] = 0

    prediction = torch.sigmoid(prediction)
    cost = F.binary_cross_entropy(prediction, label, weight=mask, reduction='none')
    # Mean per sample, then sum across the batch.
    cost = torch.sum(cost.float().mean((1, 2, 3)))

    # The auxiliary terms treat every non-zero label (including 2) as labelled.
    label_w = (label != 0).float()
    textcost = textureloss(prediction, label_w, mask_radius=4, device=device)
    bdrcost = bdrloss(prediction, label_w, radius=4, device=device)
    return cost + bdr_factor * bdrcost + tex_factor * textcost