# loss.py
import numpy as np
import torch
import torch.nn as nn


def Entropy(input_):
    """Per-sample entropy of a batch of probability vectors of shape (B, C)."""
    epsilon = 1e-5  # avoids log(0)
    entropy = -input_ * torch.log(input_ + epsilon)
    entropy = torch.sum(entropy, dim=1)
    return entropy
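
# Hypothetical usage sketch (names below are illustrative, not defined in this file):
#   probs = nn.Softmax(dim=1)(logits)   # logits: (B, C) raw classifier scores
#   ent = Entropy(probs)                # ent: (B,) per-sample entropies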


def grl_hook(coeff):
    """Gradient reversal hook: flips the sign of the gradient and scales it by coeff."""
    def fun1(grad):
        return -coeff * grad.clone()
    return fun1
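
# Hypothetical usage sketch: reverse gradients flowing into a tensor `feat`
# during backprop (assumes `feat` requires grad and `coeff` is ramped up over
# training, as CDAN does below with the entropy tensor):
#   feat.register_hook(grl_hook(coeff))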


def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None):
    """Conditional adversarial loss (CDAN).

    input_list: [features, softmax_outputs], each holding the source batch in
    the first half and the target batch in the second half.
    """
    softmax_output = input_list[1].detach()
    feature = input_list[0]
    if random_layer is None:
        # Multilinear conditioning: outer product of predictions and features.
        op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1))
        ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1)))
    else:
        # Randomized multilinear map for high-dimensional features.
        random_out = random_layer.forward([feature, softmax_output])
        ad_out = ad_net(random_out.view(-1, random_out.size(1)))
    batch_size = softmax_output.size(0) // 2
    # Domain labels: 1 for source (first half), 0 for target (second half).
    dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda()
    if entropy is not None:
        # Entropy conditioning: weight confident (low-entropy) examples more
        # heavily, with a gradient reversal hook on the entropy tensor.
        entropy.register_hook(grl_hook(coeff))
        entropy = 1.0 + torch.exp(-entropy)
        source_mask = torch.ones_like(entropy)
        source_mask[feature.size(0) // 2:] = 0
        source_weight = entropy * source_mask
        target_mask = torch.ones_like(entropy)
        target_mask[0:feature.size(0) // 2] = 0
        target_weight = entropy * target_mask
        weight = source_weight / torch.sum(source_weight).detach().item() + \
                 target_weight / torch.sum(target_weight).detach().item()
        # Per-sample BCE (reduction='none') so each example can be reweighted;
        # the original line recomputed a mean-reduced BCELoss here, which
        # collapses to a scalar and defeats the weighting.
        per_sample_loss = nn.BCELoss(reduction='none')(ad_out, dc_target)
        return torch.sum(weight.view(-1, 1) * per_sample_loss) / torch.sum(weight).detach().item()
    else:
        return nn.BCELoss()(ad_out, dc_target)
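
# Hypothetical usage sketch (base_network, ad_net, and the coeff schedule are
# assumed to come from the training script, not this file):
#   features, logits = base_network(inputs)      # inputs: source+target batch
#   softmax_out = nn.Softmax(dim=1)(logits)
#   ent = Entropy(softmax_out)
#   loss = CDAN([features, softmax_out], ad_net, ent, coeff, random_layer)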


def mdd_loss(features, labels, left_weight=1, right_weight=1):
    """Discrepancy loss between the two halves of a batch, plus a same-label
    pair loss computed within each half."""
    softmax_out = nn.Softmax(dim=1)(features)
    batch_size = features.size(0)
    if batch_size % 2 != 0:
        raise ValueError('batch size must be even')
    batch_left = softmax_out[:int(0.5 * batch_size)]
    batch_right = softmax_out[int(0.5 * batch_size):]
    # Mean L2 distance between position-wise paired rows of the two halves.
    loss = torch.norm((batch_left - batch_right).abs(), 2, 1).sum() / float(batch_size)
    labels_left = labels[:int(0.5 * batch_size)]
    batch_left_loss = get_pair_loss1(labels_left, batch_left)
    labels_right = labels[int(0.5 * batch_size):]
    batch_right_loss = get_pair_loss1(labels_right, batch_right)
    return loss + left_weight * batch_left_loss + right_weight * batch_right_loss


def mdd_digit(features, labels, left_weight=1, right_weight=1, weight=1):
    """Variant of mdd_loss for the digit experiments: each half of the batch is
    split again into quarters and same-label pairs are matched position-wise
    across the quarter batches."""
    softmax_out = nn.Softmax(dim=1)(features)
    batch_size = features.size(0)
    if batch_size % 2 != 0:
        raise ValueError('batch size must be even')
    batch_left = softmax_out[:int(0.5 * batch_size)]
    batch_right = softmax_out[int(0.5 * batch_size):]
    loss = torch.norm((batch_left - batch_right).abs(), 2, 1).sum() / float(batch_size)
    labels_left = labels[:int(0.5 * batch_size)]
    labels_left_left = labels_left[:int(0.25 * batch_size)]
    labels_left_right = labels_left[int(0.25 * batch_size):]
    batch_left_left = batch_left[:int(0.25 * batch_size)]
    batch_left_right = batch_left[int(0.25 * batch_size):]
    batch_left_loss = get_pair_loss(labels_left_left, labels_left_right, batch_left_left, batch_left_right)
    labels_right = labels[int(0.5 * batch_size):]
    labels_right_left = labels_right[:int(0.25 * batch_size)]
    labels_right_right = labels_right[int(0.25 * batch_size):]
    batch_right_left = batch_right[:int(0.25 * batch_size)]
    batch_right_right = batch_right[int(0.25 * batch_size):]
    batch_right_loss = get_pair_loss(labels_right_left, labels_right_right, batch_right_left, batch_right_right)
    return weight * loss + left_weight * batch_left_loss + right_weight * batch_right_loss


def get_pair_loss(labels_left, labels_right, features_left, features_right):
    """Sum of L2 distances between feature rows whose labels match position-wise."""
    loss = 0
    for i in range(len(labels_left)):
        if labels_left[i] == labels_right[i]:
            loss += torch.norm((features_left[i] - features_right[i]).abs(), 2, 0).sum()
    return loss


def get_pair_loss1(labels, features):
    """Mean L2 distance over all same-label pairs within a batch.

    Note: assumes at least one same-label pair exists; otherwise count stays 0
    and the division below raises. (Renamed from the original get_pari_loss1 typo.)
    """
    loss = 0
    count = 0
    for i in range(len(labels)):
        for j in range(i + 1, len(labels)):
            if labels[i] == labels[j]:
                count += 1
                loss += torch.norm((features[i] - features[j]).abs(), 2, 0).sum()
    return loss / count


def EntropicConfusion(features):
    """Entropic confusion loss: minimizing mean(p * log p), i.e. the negative
    entropy, pushes predictions toward the uniform distribution."""
    softmax_out = nn.Softmax(dim=1)(features)
    batch_size = features.size(0)
    loss = torch.mul(softmax_out, torch.log(softmax_out)).sum() * (1.0 / batch_size)
    return loss
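

if __name__ == "__main__":
    # Minimal CPU smoke test: a sketch with fabricated inputs, not part of the
    # original training pipeline. CDAN is skipped because it calls .cuda() on
    # the domain labels and requires an adversarial network.
    torch.manual_seed(0)
    logits = torch.randn(8, 10)                      # fake class scores, batch of 8
    labels = torch.tensor([0, 1, 0, 1, 2, 3, 2, 3])  # guarantees same-label pairs
    probs = nn.Softmax(dim=1)(logits)
    print("Entropy:", Entropy(probs))
    print("mdd_loss:", mdd_loss(logits, labels).item())
    print("mdd_digit:", mdd_digit(logits, labels).item())
    print("EntropicConfusion:", EntropicConfusion(logits).item())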