import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def Entropy(input_):
    # Per-sample entropy of a batch of probability vectors (e.g. softmax outputs):
    # H(p) = -sum_j p_j * log(p_j). Returns a tensor of shape (batch_size,).
    entropy = -input_ * torch.log(input_ + 1e-7)
    entropy = torch.sum(entropy, dim=1)
    return entropy

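# Illustrative sketch (editor's addition, not part of the original file): for a
# uniform distribution over C classes the per-sample entropy is log(C).
#
#   p = torch.full((2, 4), 0.25)
#   Entropy(p)   # ~tensor([1.3863, 1.3863]), i.e. log(4)
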
def grl_hook(coeff):
    # Gradient-reversal hook: during backward(), the gradient of the tensor this
    # hook is registered on is scaled by -coeff, i.e. reversed.
    def fun1(grad):
        return -coeff * grad.clone()
    return fun1

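# Illustrative sketch (editor's addition, not part of the original file): the hook
# flips the gradient of the tensor it is registered on during backward().
#
#   t = torch.ones(4, requires_grad=True)
#   s = t * 2.0
#   s.register_hook(grl_hook(coeff=1.0))
#   s.sum().backward()
#   # Without the hook t.grad would be all 2s; with it, t.grad is all -2s.
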
def DANN(features, ad_net, entropy=None, coeff=None, cls_weight=None, len_share=0):
    # Adversarial (DANN-style) domain loss. The mini-batch is laid out as
    # [source (train_bs), target (train_bs), extra shared samples (len_share)].
    ad_out = ad_net(features)
    train_bs = (ad_out.size(0) - len_share) // 2
    # Domain labels: 1 for the first train_bs samples, 0 for everything else.
    dc_target = torch.from_numpy(np.array([[1]] * train_bs + [[0]] * (train_bs + len_share))).float().cuda()
    if entropy is not None:
        # Reverse the gradient flowing back through the entropy term, then turn
        # the entropy into a confidence weight w = 1 + exp(-H).
        entropy.register_hook(grl_hook(coeff))
        entropy = 1.0 + torch.exp(-entropy)
    else:
        entropy = torch.ones(ad_out.size(0)).cuda()

    # Source weights: the first train_bs samples plus the trailing len_share samples.
    source_mask = torch.ones_like(entropy)
    source_mask[train_bs:2 * train_bs] = 0
    source_weight = entropy * source_mask
    source_weight = source_weight * cls_weight

    # Target weights: only the middle train_bs samples.
    target_mask = torch.ones_like(entropy)
    target_mask[0:train_bs] = 0
    target_mask[2 * train_bs:] = 0
    target_weight = entropy * target_mask
    target_weight = target_weight * cls_weight

    # Normalize each group's weights and combine them.
    weight = (1.0 + len_share / train_bs) * source_weight / torch.sum(source_weight).detach().item() + \
        target_weight / torch.sum(target_weight).detach().item()

    weight = weight.view(-1, 1)
    # Weighted binary cross-entropy between the discriminator output and the domain labels.
    return torch.sum(weight * nn.BCELoss(reduction='none')(ad_out, dc_target)) / (1e-8 + torch.sum(weight).detach().item())

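# Illustrative usage sketch (editor's addition, not part of the original file).
# `ad_net` is assumed to be a domain discriminator module ending in a sigmoid
# (so nn.BCELoss applies), and `features`/`logits` are placeholder names for the
# outputs of a feature extractor and classifier head.
#
#   softmax_out = F.softmax(logits, dim=1)
#   entropy = Entropy(softmax_out)
#   cls_weight = torch.ones(features.size(0)).cuda()
#   transfer_loss = DANN(features, ad_net, entropy=entropy, coeff=1.0,
#                        cls_weight=cls_weight, len_share=0)
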
def marginloss(yHat, y, classes=65, alpha=1, weight=None):
    # Margin-based loss over the non-ground-truth classes; each sample is
    # weighted by (1 - p_true)^alpha so that uncertain samples count more.
    batch_size = len(y)
    yHat = F.softmax(yHat, dim=1)
    # Probability assigned to the ground-truth class of each sample.
    Yg = torch.gather(yHat, 1, torch.unsqueeze(y, 1))
    Yg_ = (1 - Yg) + 1e-7  # avoiding numerical issues (first)
    # Probabilities of the remaining classes, renormalized by (1 - p_true).
    Px = yHat / Yg_.view(len(yHat), 1)
    Px_log = torch.log(Px + 1e-10)  # avoiding numerical issues (second)
    # Zero-hot mask: 0 at the ground-truth class, 1 elsewhere.
    y_zerohot = torch.ones(batch_size, classes).scatter_(
        1, y.view(batch_size, 1).cpu(), 0)

    # Per-sample score: sum over non-ground-truth classes of P * log(P),
    # normalized by log(classes - 1).
    output = Px * Px_log * y_zerohot.cuda()
    loss = torch.sum(output, dim=1) / np.log(classes - 1)
    Yg_ = Yg_ ** alpha
    if weight is not None:
        weight *= (Yg_.view(len(yHat)) / Yg_.sum())
    else:
        weight = (Yg_.view(len(yHat)) / Yg_.sum())

    weight = weight.detach()
    loss = torch.sum(weight * loss) / torch.sum(weight)

    return loss
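
# Illustrative usage sketch (editor's addition, not part of the original file).
# `logits` (shape: batch_size x classes) and `labels` (integer class indices)
# are placeholder names.
#
#   logits = torch.randn(8, 65).cuda()
#   labels = torch.randint(0, 65, (8,)).cuda()
#   loss = marginloss(logits, labels, classes=65, alpha=1)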