-
Notifications
You must be signed in to change notification settings - Fork 105
/
loss.py
31 lines (23 loc) · 1.12 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
# Loss functions
def loss_coteaching(y_1, y_2, t, forget_rate, ind, noise_or_not):
    """Co-teaching loss: each network keeps its lowest-loss (likely clean)
    samples and trains its peer on them.

    Args:
        y_1, y_2: logits from network 1 / network 2, shape (batch, classes).
        t: ground-truth (possibly noisy) labels, shape (batch,).
        forget_rate: fraction of highest-loss samples to drop.
        ind: indices of this batch within the full dataset (numpy array),
            used to look up `noise_or_not`.
        noise_or_not: boolean numpy array over the full dataset; True where
            the label is clean.

    Returns:
        (loss_1, loss_2, pure_ratio_1, pure_ratio_2): exchanged small-loss
        losses for each network, plus the fraction of truly-clean samples
        among those each network selected.
    """
    # Per-sample losses. reduction='none' replaces the long-deprecated
    # (and now removed) reduce=False keyword.
    loss_1 = F.cross_entropy(y_1, t, reduction='none')
    loss_2 = F.cross_entropy(y_2, t, reduction='none')

    # Rank samples by loss on the logits' own device. torch.argsort replaces
    # np.argsort(...).cuda(), which returned a NumPy array (no .cuda()) and
    # hard-crashed on CPU-only machines.
    ind_1_sorted = torch.argsort(loss_1.detach())
    ind_2_sorted = torch.argsort(loss_2.detach())

    remember_rate = 1 - forget_rate
    # Keep at least one sample so the divisions below never hit zero.
    num_remember = max(1, int(remember_rate * len(loss_1)))

    ind_1_update = ind_1_sorted[:num_remember]
    ind_2_update = ind_2_sorted[:num_remember]

    # Move selection indices to CPU NumPy before indexing the NumPy arrays
    # (indexing NumPy with a CUDA tensor raises).
    sel_1 = ind_1_update.cpu().numpy()
    sel_2 = ind_2_update.cpu().numpy()
    pure_ratio_1 = np.sum(noise_or_not[ind[sel_1]]) / float(num_remember)
    pure_ratio_2 = np.sum(noise_or_not[ind[sel_2]]) / float(num_remember)

    # Exchange: each network is trained on the samples its PEER judged clean.
    loss_1_update = F.cross_entropy(y_1[ind_2_update], t[ind_2_update])
    loss_2_update = F.cross_entropy(y_2[ind_1_update], t[ind_1_update])

    # NOTE(review): cross_entropy already returns the mean, so dividing by
    # num_remember again scales the loss by 1/num_remember. Preserved from
    # the original implementation to keep training dynamics identical.
    return (torch.sum(loss_1_update) / num_remember,
            torch.sum(loss_2_update) / num_remember,
            pure_ratio_1, pure_ratio_2)