import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable  # legacy wrapper; a no-op on PyTorch >= 0.4


def lr_poly(base_lr, iter, max_iter, power):
    """Polynomial decay: base_lr * (1 - iter / max_iter) ** power."""
    return base_lr * ((1 - float(iter) / max_iter) ** power)
def adjust_learning_rate(optimizer, i_iter, args):
    """Apply poly decay to the main optimizer; a second param group,
    if present, runs at 10x the base rate."""
    lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power)
    optimizer.param_groups[0]['lr'] = lr
    if len(optimizer.param_groups) > 1:
        optimizer.param_groups[1]['lr'] = lr * 10


def adjust_learning_rate_D(optimizer, i_iter, args):
    """Same poly schedule for the discriminator optimizer (args.learning_rate_D)."""
    lr = lr_poly(args.learning_rate_D, i_iter, args.num_steps, args.power)
    optimizer.param_groups[0]['lr'] = lr
    if len(optimizer.param_groups) > 1:
        optimizer.param_groups[1]['lr'] = lr * 10
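
# Usage sketch (illustrative, not part of the original file). The SGD setup
# and the backbone/head parameter split are assumptions based on the common
# DeepLab-style convention that this 10x-head schedule implies:
#
#   optimizer = torch.optim.SGD(
#       [{'params': backbone_params},    # group 0: base lr
#        {'params': head_params}],       # group 1: 10x lr
#       lr=args.learning_rate, momentum=0.9, weight_decay=5e-4)
#   for i_iter in range(args.num_steps):
#       adjust_learning_rate(optimizer, i_iter, args)
#       ...  # forward / backward / optimizer.step()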
def calc_mse_loss(item1, item2, batch_size):
    """Sum of squared errors, averaged over the batch."""
    criterion = nn.MSELoss(reduction='none')  # `reduce=False` is deprecated
    return criterion(item1, item2).sum() / batch_size


def calc_l1_loss(item1, item2, batch_size, gpu):
    """Mean absolute error, with item2 moved to the given GPU."""
    item2 = Variable(item2.float()).cuda(gpu)
    criterion = nn.L1Loss()  # mean reduction, so .sum() below is a no-op on the scalar
    return criterion(item1, item2).sum() / batch_size
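
# Usage sketch (illustrative):
#
#   a, b = torch.randn(4, 10), torch.randn(4, 10)
#   calc_mse_loss(a, b, batch_size=4)  # == F.mse_loss(a, b, reduction='sum') / 4
#   # calc_l1_loss additionally requires a CUDA device for item2.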
class LossMulti(nn.Module):
    """Weighted sum of cross-entropy and a log-Jaccard term:
    loss = (1 - w) * CE + w * sum_c -log((|A_c ∩ B_c| + eps) / (|A_c ∪ B_c| + eps)).
    """

    def __init__(self, jaccard_weight=0, class_weights=None, num_classes=1):
        super().__init__()
        self.nll_weight = class_weights  # may be None, meaning unweighted CE
        self.jaccard_weight = jaccard_weight
        self.num_classes = num_classes

    def forward(self, outputs, targets):
        loss = (1 - self.jaccard_weight) * F.cross_entropy(outputs, targets, weight=self.nll_weight)
        if self.jaccard_weight:
            eps = 1e-15
            outputs = F.softmax(outputs, dim=1)  # explicit dim; bare softmax is deprecated
            for cls in range(self.num_classes):
                jaccard_target = (targets == cls).float()
                jaccard_output = outputs[:, cls]
                intersection = (jaccard_output * jaccard_target).sum()
                union = jaccard_output.sum() + jaccard_target.sum()
                loss -= torch.log((intersection + eps) / (union - intersection + eps)) * self.jaccard_weight
        return loss
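
# Usage sketch (illustrative). Shapes follow F.cross_entropy:
# logits [B, C, H, W], integer targets [B, H, W].
#
#   criterion = LossMulti(jaccard_weight=0.5, num_classes=3)
#   logits = torch.randn(2, 3, 64, 64, requires_grad=True)
#   targets = torch.randint(0, 3, (2, 64, 64))
#   criterion(logits, targets).backward()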
def Weighted_Jaccard_loss(label, pred, class_weights=None, gpu=0):
    """Combined cross-entropy + Jaccard loss (jaccard_weight=0.5, 3 classes)
    for semantic segmentation, optionally with per-class weights."""
    # pred:  batch_size x channels x h x w (raw logits)
    # label: batch_size x h x w (class indices, as F.cross_entropy expects)
    label = Variable(label.long()).cuda(gpu)
    if class_weights is not None and class_weights != 0:
        class_weights = Variable(torch.Tensor(class_weights)).cuda(gpu)
        criterion = LossMulti(jaccard_weight=0.5, class_weights=class_weights, num_classes=3)
    else:
        criterion = LossMulti(jaccard_weight=0.5, num_classes=3)
    return criterion(pred, label)
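
# Usage sketch (illustrative; needs a CUDA device, since the helper calls
# .cuda(gpu)). class_weights is a plain Python sequence:
#
#   pred = torch.randn(2, 3, 64, 64).cuda(0)
#   label = torch.randint(0, 3, (2, 64, 64))
#   loss = Weighted_Jaccard_loss(label, pred, class_weights=[1.0, 2.0, 2.0], gpu=0)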
def dice_loss(true, logits, eps=1e-7):
    """Computes the Sørensen–Dice loss.

    Note that PyTorch optimizers minimize a loss, so to maximize the
    Dice coefficient we return one minus the soft Dice score.

    Args:
        true: a tensor of shape [B, 1, H, W] with integer class labels.
        logits: a tensor of shape [B, C, H, W]. Corresponds to
            the raw output or logits of the model.
        eps: added to the denominator for numerical stability.
    Returns:
        dice_loss: the Sørensen–Dice loss.

    Adapted from
    https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py
    """
    num_classes = logits.shape[1]
    if num_classes == 1:
        # Binary case: build a 2-channel one-hot whose order matches the
        # torch.cat([pos_prob, neg_prob]) below.
        true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        true_1_hot_f = true_1_hot[:, 0:1, :, :]
        true_1_hot_s = true_1_hot[:, 1:2, :, :]
        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
        pos_prob = torch.sigmoid(logits)
        neg_prob = 1 - pos_prob
        probas = torch.cat([pos_prob, neg_prob], dim=1)
    else:
        true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        probas = F.softmax(logits, dim=1)
    true_1_hot = true_1_hot.type(logits.type())
    dims = (0,) + tuple(range(2, true.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    dice = (2. * intersection / (cardinality + eps)).mean()
    return 1 - dice
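
# Usage sketch (illustrative). Multiclass case: integer labels [B, 1, H, W],
# logits [B, C, H, W]:
#
#   logits = torch.randn(2, 3, 64, 64, requires_grad=True)
#   true = torch.randint(0, 3, (2, 1, 64, 64))
#   dice_loss(true, logits).backward()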