Skip to content

Commit d97da63

Browse files
committed
first-upload
1 parent 9d6280a commit d97da63

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

my_loss.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import numpy as np
2+
import torch
3+
import torch.nn as nn
4+
from torch.autograd import Variable
5+
import math
6+
import torch.nn.functional as F
7+
import pdb
8+
9+
def Entropy(input_):
10+
bs = input_.size(0)
11+
entropy = -input_ * torch.log(input_ + 1e-7)
12+
entropy = torch.sum(entropy, dim=1)
13+
return entropy
14+
15+
def grl_hook(coeff):
16+
def fun1(grad):
17+
return -coeff*grad.clone()
18+
return fun1
19+
20+
def DANN(features, ad_net, entropy=None, coeff=None, cls_weight=None, len_share=0):
21+
ad_out = ad_net(features)
22+
train_bs = (ad_out.size(0) - len_share) // 2
23+
dc_target = torch.from_numpy(np.array([[1]] * train_bs + [[0]] * (train_bs + len_share))).float().cuda()
24+
if entropy is not None:
25+
entropy.register_hook(grl_hook(coeff))
26+
entropy = 1.0 + torch.exp(-entropy)
27+
else:
28+
entropy = torch.ones(ad_out.size(0)).cuda()
29+
30+
source_mask = torch.ones_like(entropy)
31+
source_mask[train_bs : 2 * train_bs] = 0
32+
source_weight = entropy * source_mask
33+
source_weight = source_weight * cls_weight
34+
35+
target_mask = torch.ones_like(entropy)
36+
target_mask[0 : train_bs] = 0
37+
target_mask[2 * train_bs::] = 0
38+
target_weight = entropy * target_mask
39+
target_weight = target_weight * cls_weight
40+
41+
weight = (1.0 + len_share / train_bs) * source_weight / (torch.sum(source_weight).detach().item()) + \
42+
target_weight / torch.sum(target_weight).detach().item()
43+
44+
weight = weight.view(-1, 1)
45+
return torch.sum(weight * nn.BCELoss(reduction='none')(ad_out, dc_target)) / (1e-8 + torch.sum(weight).detach().item())
46+
47+
def marginloss(yHat, y, classes=65, alpha=1, weight=None):
48+
batch_size = len(y)
49+
classes = classes
50+
yHat = F.softmax(yHat, dim=1)
51+
Yg = torch.gather(yHat, 1, torch.unsqueeze(y, 1))#.detach()
52+
Yg_ = (1 - Yg) + 1e-7 # avoiding numerical issues (first)
53+
Px = yHat / Yg_.view(len(yHat), 1)
54+
Px_log = torch.log(Px + 1e-10) # avoiding numerical issues (second)
55+
y_zerohot = torch.ones(batch_size, classes).scatter_(
56+
1, y.view(batch_size, 1).data.cpu(), 0)
57+
58+
output = Px * Px_log * y_zerohot.cuda()
59+
loss = torch.sum(output, dim=1)/ np.log(classes - 1)
60+
Yg_ = Yg_ ** alpha
61+
if weight is not None:
62+
weight *= (Yg_.view(len(yHat), )/ Yg_.sum())
63+
else:
64+
weight = (Yg_.view(len(yHat), )/ Yg_.sum())
65+
66+
weight = weight.detach()
67+
loss = torch.sum(weight * loss) / torch.sum(weight)
68+
69+
return loss

0 commit comments

Comments
 (0)