-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlosses.py
74 lines (41 loc) · 2.13 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_zoo.loss import lovasz_hinge
from torch.nn.modules.loss import CrossEntropyLoss
__all__ = ['GT_BceDiceLoss_new2']
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_zoo.loss import lovasz_hinge
from torch.nn.modules.loss import CrossEntropyLoss
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _WeightedLoss
class BCEDiceLoss_newversion(nn.Module):
    """Combined BCE + two soft-Dice losses computed from raw logits.

    The returned scalar is ``bce + dice_loss + dice_loss1`` where

    * ``bce`` is binary cross-entropy on sigmoid(logits), evaluated in the
      fused, numerically stable logits form;
    * ``dice_loss`` = 1 - mean over samples of
      ``(2 * I**2 + s) / (P**2 + T**2 + s)``;
    * ``dice_loss1`` = 1 - mean over samples of
      ``(2 * I + s) / (P + T + s)`` (standard soft Dice);

    with, per sample, ``I = sum(p * t)``, ``P = sum(p)``, ``T = sum(t)``,
    ``p = sigmoid(logits)`` and ``s = 1e-5``.
    """

    def __init__(self):
        super().__init__()
        # Fused sigmoid + BCE: stable for large-magnitude logits and safe under
        # torch.autocast (nn.BCELoss on probabilities is autocast-unsafe).
        self.bceloss = nn.BCEWithLogitsLoss()

    def forward(self, input, target):
        """Compute the combined loss.

        Args:
            input: raw (pre-sigmoid) logits; the first dimension is treated as
                the batch dimension and all remaining dims are flattened.
            target: ground truth with the same number of elements per sample
                as ``input``. Presumably a {0, 1} mask; TODO confirm whether
                soft targets are expected. It is cast to the logits' dtype,
                so integer/bool masks are accepted.

        Returns:
            0-dim tensor: ``bce + dice_loss + dice_loss1``.
        """
        smooth = 1e-5
        num = target.size(0)
        # reshape (not view) so non-contiguous tensors (e.g. permuted) work.
        logits = input.reshape(num, -1)
        target = target.reshape(num, -1).to(logits.dtype)

        bce = self.bceloss(logits, target)

        prob = torch.sigmoid(logits)
        # Per-sample reductions, computed once and shared by both Dice terms.
        inter = (prob * target).sum(1)
        prob_sum = prob.sum(1)
        target_sum = target.sum(1)

        # NOTE(review): this variant squares the *summed* terms rather than the
        # per-element values (as in V-Net's squared Dice); kept unchanged —
        # confirm it is intended.
        dice = (2. * inter.pow(2) + smooth) / (prob_sum.pow(2) + target_sum.pow(2) + smooth)
        dice_loss = 1 - dice.mean()

        dice1 = (2. * inter + smooth) / (prob_sum + target_sum + smooth)
        dice_loss1 = 1 - dice1.mean()

        return bce + dice_loss + dice_loss1
class GT_BceDiceLoss_new2(nn.Module):
    """Deep-supervision loss: main-output BCE-Dice plus weighted side-output losses.

    The main loss plus the weighted side losses is multiplied by
    ``2 - sin(pi/2 * epoch / num_epoch)``, which is 2 at ``epoch == 0`` and
    decreases to 1 at ``epoch == num_epoch``.
    """

    # Weights for the side outputs, in the order they are unpacked from ``pre``
    # (gt_pre4, gt_pre3, gt_pre2, gt_pre1, gt_pre0).
    _GT_WEIGHTS = (0.1, 0.2, 0.4, 0.6, 0.8)

    def __init__(self):
        super().__init__()
        self.bcedice = BCEDiceLoss_newversion()

    def forward(self, pre, out, target, epoch, num_epoch):
        """Compute the epoch-scaled total loss and the individual side losses.

        Args:
            pre: iterable of exactly five side-output logit tensors, unpacked
                in the order gt_pre4, gt_pre3, gt_pre2, gt_pre1, gt_pre0.
            out: main-output logits.
            target: ground truth, shared by the main and all side outputs.
            epoch: current epoch.
            num_epoch: total number of epochs (must be non-zero).

        Returns:
            Tuple ``(total, loss4, loss3, loss2, loss1, loss0)`` where ``total``
            is the scaled weighted sum and ``loss4..loss0`` are the unweighted
            BCE-Dice losses of the corresponding side outputs.

        Raises:
            ValueError: if ``pre`` does not unpack into exactly five items.
        """
        import math

        bcediceloss = self.bcedice(out, target)
        gt_pre4, gt_pre3, gt_pre2, gt_pre1, gt_pre0 = pre
        # Each side loss is computed once and reused for both the weighted sum
        # and the returned tuple (previously each was evaluated twice).
        side_losses = tuple(
            self.bcedice(side, target)
            for side in (gt_pre4, gt_pre3, gt_pre2, gt_pre1, gt_pre0)
        )
        gt_loss = sum(loss * w for loss, w in zip(side_losses, self._GT_WEIGHTS))

        # Plain-float schedule; no need to allocate a CPU tensor for a scalar.
        scale = 2 - math.sin(epoch / num_epoch * math.pi / 2)
        return (scale * (bcediceloss + gt_loss), *side_losses)