-
Notifications
You must be signed in to change notification settings - Fork 2
/
decoupledsegnet_relax_boundary_loss.py
129 lines (115 loc) · 5.05 KB
/
decoupledsegnet_relax_boundary_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle import nn
import paddle.nn.functional as F
from scipy.ndimage import shift
from paddleseg.cvlibs import manager
@manager.LOSSES.add_component
class RelaxBoundaryLoss(nn.Layer):
    """
    Implements the relaxed boundary loss used by DecoupledSegNet.

    Each pixel's target is relaxed to the set of classes that occur inside the
    (2 * border + 1) x (2 * border + 1) window centred on it. The per-pixel loss
    is -log of the summed predicted probability of that set (scaled by the
    mean class weight of the set when ``calculate_weights`` is enabled). Losses
    are averaged over the non-ignored pixels of each image and summed over the
    batch.

    Args:
        border (int, optional): Half-size of the relaxation window; 0 disables
            relaxation (plain one-hot target). Default: 1.
        calculate_weights (bool, optional): Whether to compute per-class weights
            from each image's relaxed label. Default: False.
        upper_bound (float, optional): Scale of the per-class weights when
            ``calculate_weights`` is True. Default: 1.0.
        ignore_index (int64): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default: 255.
    """

    def __init__(self,
                 border=1,
                 calculate_weights=False,
                 upper_bound=1.0,
                 ignore_index=255):
        super(RelaxBoundaryLoss, self).__init__()
        self.border = border
        # NOTE: this bool flag is an instance attribute, so no method of this
        # class may be named ``calculate_weights`` (it would be shadowed).
        self.calculate_weights = calculate_weights
        self.upper_bound = upper_bound
        self.ignore_index = ignore_index
        # Currently unused by the loss computation; kept for compatibility.
        self.EPS = 1e-5

    def relax_onehot(self, label, num_classes):
        """
        Build the relaxed multi-hot label.

        Args:
            label (Tensor): Integer label of shape (N, H, W) or (N, 1, H, W).
            num_classes (int): Number of real classes C.

        Returns:
            Tensor: int64 tensor of shape (N, C + 1, H, W). Channel c is 1 if
            class c occurs anywhere in the pixel's relaxation window; the last
            channel collects ``ignore_index`` pixels and out-of-image padding.
        """
        if len(label.shape) == 3:
            label = label.unsqueeze(1)
        h, w = label.shape[-2], label.shape[-1]
        # Pad with num_classes so neighbours outside the image fall into the
        # extra (ignore) channel rather than any real class.
        label = F.pad(label, [self.border] * 4, value=num_classes)
        label = label.squeeze(1)
        # Map ignore_index onto the same extra channel.
        ignore_mask = (label == self.ignore_index).astype('int64')
        label = label * (1 - ignore_mask) + num_classes * ignore_mask

        onehot = 0
        for i in range(-self.border, self.border + 1):
            for j in range(-self.border, self.border + 1):
                # Offsets are relative to the padded origin, i.e. `border`,
                # so every shifted window has exactly h x w pixels.
                h_start = self.border + i
                w_start = self.border + j
                shifted = label[:, h_start:h_start + h, w_start:w_start + w]
                onehot += F.one_hot(shifted, num_classes + 1)
        onehot = (onehot > 0).astype('int64')
        onehot = paddle.transpose(onehot, (0, 3, 1, 2))
        return onehot

    def _compute_class_weights(self, label):
        """
        Compute per-class weights from one image's relaxed label.

        Args:
            label (Tensor): Multi-hot label of shape (C + 1, H, W).

        Returns:
            Tensor: float32 tensor of shape (C,). Classes that are present get
            ``1 + upper_bound * (1 - frequency)``; absent classes get 1. The
            trailing ignore channel is dropped.
        """
        label = label.astype('float32')
        hist = paddle.sum(label, axis=(1, 2)) / label.sum()
        present = (hist != 0).astype('float32')
        hist = present * self.upper_bound * (1 - hist) + 1
        return hist[:-1]

    def custom_nll(self,
                   logit,
                   label,
                   class_weights=None,
                   border_weights=None,
                   ignore_mask=None):
        """
        Relaxed negative log-likelihood for a single image.

        Args:
            logit (Tensor): Logits of shape (1, C, H, W).
            label (Tensor): Multi-hot relaxed label of shape (1, C + 1, H, W).
            class_weights (Tensor, optional): Per-class weights of shape (C,).
            border_weights (Tensor, optional): Number of relaxed classes per
                pixel (zeros replaced by 1), broadcastable to (1, H, W).
                Derived from ``label`` when omitted.
            ignore_mask (Tensor, optional): 1 where the pixel has no real
                class, broadcastable to (1, H, W). Derived from ``label`` when
                omitted.

        Returns:
            Tensor: Scalar loss averaged over non-ignored pixels.
        """
        soft = F.softmax(logit, axis=1)
        # Cast the multi-hot mask once so all arithmetic stays in the logit's
        # float dtype instead of relying on implicit int/float casting.
        target = label[:, :-1, :, :].astype(soft.dtype)

        if border_weights is None or ignore_mask is None:
            counts = target.sum(1)
            empty = (counts == 0).astype(soft.dtype)
            if ignore_mask is None:
                ignore_mask = empty
            if border_weights is None:
                border_weights = counts + empty

        # Total probability of all relaxed classes, placed on every labelled
        # channel. Non-label channels are set to 1 (log 1 == 0): they are
        # masked out below anyway, and this avoids 0 * log(0) = NaN when a
        # softmax entry underflows to zero.
        soft_label = (soft * target).sum(1, keepdim=True) * target
        logsoft = paddle.log(soft_label + (1. - target))
        if class_weights is not None:
            logsoft = class_weights.astype(soft.dtype).unsqueeze(
                (0, 2, 3)) * logsoft
        logsoft = (target * logsoft).sum(1)
        # The sum above counts the same log-probability once per relaxed
        # class, so divide by that count (border loss is divided equally).
        logsoft = -1 / border_weights * logsoft * (1. - ignore_mask)
        n, _, h, w = label.shape
        # +1 prevents division by zero when every pixel is ignored.
        return logsoft.sum() / (n * h * w - ignore_mask.sum() + 1)

    def forward(self, logit, label):
        """
        Forward computation.

        Args:
            logit (Tensor): Logit tensor of shape (N, C, H, W), dtype float32
                or float64.
            label (Tensor): Integer label tensor of shape (N, H, W) or
                (N, 1, H, W) with values in [0, C - 1] or ``ignore_index``.

        Returns:
            Tensor: Scalar loss, the sum over the batch of each image's
            relaxed NLL averaged over its non-ignored pixels.
        """
        n, c, h, w = logit.shape
        label.stop_gradient = True
        label = self.relax_onehot(label, c)
        # Number of real classes in each pixel's relaxation window.
        weights = label[:, :-1, :, :].sum(1).astype(logit.dtype)
        ignore_mask = (weights == 0).astype(logit.dtype)
        # Border pixels get their class count (> 1); others get 1.
        border_weights = weights + ignore_mask

        loss = 0
        class_weights = None
        for i in range(n):
            if self.calculate_weights:
                class_weights = self._compute_class_weights(label[i])
            loss = loss + self.custom_nll(
                logit[i].unsqueeze(0),
                label[i].unsqueeze(0),
                class_weights=class_weights,
                border_weights=border_weights[i],
                ignore_mask=ignore_mask[i])
        return loss