model_utils.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class Online_Contrastive_Loss(nn.Module):
    """Contrastive loss over all positive/negative pairs within a batch."""
    def __init__(self, margin=2.0, num_classes=200):
        super(Online_Contrastive_Loss, self).__init__()
        self.margin = margin
        self.num_classes = num_classes

    def forward(self, x, label):
        # Pairwise squared Euclidean distances: d_ij = ||x_i||^2 + ||x_j||^2 - 2 x_i . x_j
        n = x.size(0)
        xxt = torch.matmul(x, x.t())
        xn = torch.sum(torch.mul(x, x), keepdim=True, dim=-1)  # n x 1
        dist = torch.clamp(xn.t() + xn - 2.0 * xxt, min=0.0)  # clamp guards tiny negatives from rounding
        # Build positive/negative pair masks from the labels
        one_hot_label = torch.zeros(n, self.num_classes, device=x.device, dtype=x.dtype)
        one_hot_label.scatter_(1, label.long().unsqueeze(-1), 1)
        pmask = torch.matmul(one_hot_label, one_hot_label.t())
        nmask = 1 - pmask  # computed before zeroing the diagonal, so self-pairs fall in neither mask
        pmask[torch.eye(n, device=x.device) > 0] = 0.0  # exclude self-pairs from positives
        pmask = pmask > 0
        nmask = nmask > 0
        # Pull positive pairs together; push negative pairs beyond the margin
        ploss = torch.sum(torch.masked_select(dist, pmask)) / torch.clamp(pmask.sum(), min=1)
        nloss = torch.sum(torch.clamp(self.margin - torch.masked_select(dist, nmask), min=0.0)) / torch.clamp(nmask.sum(), min=1)
        # Optional: mine only the top-k hardest negative pairs instead
        # neg_dist = torch.masked_select(-dist, nmask)
        # k = int(pmask.sum().item())
        # neg_dist, _ = neg_dist.topk(k=k)
        # nloss = torch.sum(torch.clamp(self.margin + neg_dist, min=0.0)) / k
        return ploss + nloss
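
# Illustrative usage sketch (shapes and hyperparameters below are assumed values,
# not fixed by this module). Embeddings are typically L2-normalized first, so the
# squared pairwise distances lie in [0, 4] and the margin is meaningful:
#
#     x = F.normalize(torch.randn(8, 128))               # 8 embeddings of dim 128
#     label = torch.randint(0, 200, (8,))                # integer class ids
#     criterion = Online_Contrastive_Loss(margin=2.0, num_classes=200)
#     loss = criterion(x, label)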


class Pooling_Classifier(nn.Module):
    """Pools per-image features into one feature per set and classifies it."""
    def __init__(self, feat_dim, num_classes, pooling_type='AVE'):
        super(Pooling_Classifier, self).__init__()
        self.feat_dim = feat_dim
        self.num_classes = num_classes
        self.pooling_type = pooling_type
        self.fc = nn.Linear(feat_dim, num_classes)
        if pooling_type == 'NAN':
            self.attention_layer = NAN_Attention(feat_dim)

    def forward(self, x, lst_lens):
        # x: all per-image features concatenated along dim 0; lst_lens[i] is the size of set i
        x = F.normalize(x)  # L2-normalize each feature vector
        if self.pooling_type == 'AVE':
            # Average pooling: the set feature is the mean of its members
            lst_mx = []
            idx = 0
            for i in range(lst_lens.size(0)):
                n_i = int(lst_lens[i])
                set_x = x[idx:idx + n_i, :]
                idx += n_i
                lst_mx.append(set_x.mean(dim=0, keepdim=True))
            lst_mx = torch.cat(lst_mx, dim=0)
            logits = self.fc(lst_mx)
            return lst_mx, logits
        else:
            # NAN attention pooling; note torch.cat below assumes all sets
            # have the same length so the (1, d, n_i) slices can be stacked
            lst_mx = []
            idx = 0
            for i in range(lst_lens.size(0)):
                n_i = int(lst_lens[i])
                set_x = x[idx:idx + n_i, :]  # n_i x d
                idx += n_i
                lst_mx.append(set_x.t().unsqueeze(0))
            lst_mx = torch.cat(lst_mx, dim=0)
            feats = self.attention_layer(lst_mx)
            logits = self.fc(feats)
            return feats, logits
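
# Illustrative usage sketch (shapes are assumed values): three sets of per-image
# features flattened into one tensor, with lst_lens recording each set's length.
# With pooling_type='NAN', all sets must share the same length:
#
#     x = torch.randn(3 * 20, 128)                       # 3 sets of 20 features
#     lst_lens = torch.tensor([20, 20, 20])
#     model = Pooling_Classifier(feat_dim=128, num_classes=200, pooling_type='AVE')
#     feats, logits = model(x, lst_lens)                 # (3, 128) and (3, 200)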


class NAN_Attention(nn.Module):
    """Two-pass attention aggregation in the style of Neural Aggregation Networks."""
    def __init__(self, feat_dim=128, set_size=20):  # set_size is currently unused
        super(NAN_Attention, self).__init__()
        # q starts at zero, so the first attention pass reduces to uniform averaging
        self.q = nn.Parameter(torch.zeros(1, 1, feat_dim))
        # Alternative: random initialization
        # self.q = nn.Parameter(torch.empty(1, 1, feat_dim))
        # nn.init.xavier_uniform_(self.q)
        self.fc = nn.Linear(feat_dim, feat_dim)
        self.tanh = nn.Tanh()
        # Zero-initialized transfer layer, so the adapted query also starts at zero
        self.fc.bias.data.zero_()
        self.fc.weight.data.zero_()

    def forward(self, Xs):
        # Xs: N x C x K (N sets, each holding K feature vectors of dimension C)
        N, C, K = Xs.shape
        # First pass: score the K features against the learned query q
        score = torch.matmul(self.q, Xs)  # N x 1 x K
        score = F.softmax(score, dim=-1)
        r = torch.sum(torch.mul(Xs, score), dim=-1)  # N x C, first aggregate
        # Second pass: derive a content-adaptive query from the first aggregate
        new_q = self.tanh(self.fc(r)).view(N, 1, C)
        new_score = F.softmax(torch.matmul(new_q, Xs), dim=-1)  # N x 1 x K
        o = torch.sum(torch.mul(Xs, new_score), dim=-1)  # N x C
        return o
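

# Minimal smoke test under assumed shapes (4 sets of 20 features of dimension 128);
# with the zero-initialized query, the output equals the per-set mean of Xs.
if __name__ == '__main__':
    Xs = torch.randn(4, 128, 20)
    att = NAN_Attention(feat_dim=128)
    out = att(Xs)
    print(out.shape)  # torch.Size([4, 128])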