import random

import torch
import torch.nn as nn

from language_model import WordEmbedding, QuestionEmbedding
from classifier import SimpleClassifier
from fc import FCNet
from attention import Att_3  # only needed if the commented-out Att_3 branches below are re-enabled
class Model(nn.Module):
    """Stacked Attention Network (SAN) model for VQA."""

    def __init__(self, opt):
        super(Model, self).__init__()
        num_hid = opt.num_hid
        activation = opt.activation
        dropG = opt.dropG
        dropW = opt.dropW
        dropout = opt.dropout
        dropL = opt.dropL
        norm = opt.norm
        dropC = opt.dropC
        self.opt = opt
        self.num_layer = 2  # number of stacked attention layers

        # Word and question encoders.
        self.w_emb = WordEmbedding(opt.ntokens, emb_dim=300, dropout=dropW)
        # self.w_emb.init_embedding(opt.dataroot + 'glove6b_init_300d.npy')
        self.q_emb = QuestionEmbedding(in_dim=300, num_hid=num_hid, nlayers=1,
                                       bidirect=False, dropout=dropG, rnn_type='LSTM')

        # Project question and visual features into a shared num_hid space.
        self.q_net = FCNet([self.q_emb.num_hid, num_hid], dropout=dropL, norm=norm, act=activation)
        self.gv_net = FCNet([opt.v_dim, num_hid], dropout=dropL, norm=norm, act=activation)

        # self.gv_att_1 = Att_3(v_dim=opt.v_dim, q_dim=self.q_emb.num_hid, num_hid=num_hid,
        #                       dropout=dropout, norm=norm, act=activation)
        # self.gv_att_2 = Att_3(v_dim=opt.v_dim, q_dim=self.q_emb.num_hid, num_hid=num_hid,
        #                       dropout=dropout, norm=norm, act=activation)

        # One set of projections per stacked attention layer.
        self.gv_modules = nn.ModuleList()
        self.q_modules = nn.ModuleList()
        self.att_down = nn.ModuleList()
        for _ in range(self.num_layer):
            self.gv_modules.append(nn.Linear(num_hid, num_hid))
            self.q_modules.append(nn.Linear(num_hid, num_hid))
            self.att_down.append(nn.Linear(num_hid, 1))

        self.classifier = SimpleClassifier(in_dim=num_hid, hid_dim=2 * num_hid, out_dim=opt.ans_dim,
                                           dropout=dropC, norm=norm, act=activation)
        # self.normal = nn.BatchNorm1d(num_hid, affine=False)
    def forward(self, q, gv_pos, self_sup=True):
        """Forward pass.

        q:        [batch_size, seq_length] question token ids
        gv_pos:   [batch, K, v_dim] visual features of the paired (positive) image
        self_sup: if True, also score a randomly mismatched (negative) image per instance
        return:   answer logits, not probabilities
        """
        w_emb = self.w_emb(q)
        q_emb = self.q_emb(w_emb)  # run the LSTM over word embeddings -> [batch, q_dim]
        q_repr = self.q_net(q_emb)
        gv_pos = self.gv_net(gv_pos)
        batch_size = q.size(0)
        logits_pos = self.compute_predict(q_repr, gv_pos)
        if self_sup:
            # Construct an irrelevant Q-I pair for each instance by
            # permuting the images within the batch.
            index = random.sample(range(0, batch_size), batch_size)
            gv_neg = gv_pos[index]
            logits_neg = self.compute_predict(q_repr, gv_neg)
            return logits_pos, logits_neg
        else:
            return logits_pos
    def compute_predict(self, q_repr, v):
        # Attend over the visual features with the question, then classify.
        fea = self.san_att(v, q_repr)
        logits = self.classifier(fea)
        return logits
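    # The method below matches the stacked attention recurrence of SAN
    # (Yang et al., 2016), here with a residual query update:
    #   h_A^k = tanh(W_v v + (W_q u^{k-1}) 1^T)      joint features
    #   p_I^k = softmax(w_p^T h_A^k)                 region weights
    #   u^k   = u^{k-1} + sum_i p_{I,i}^k v_i        refined query
    # where u^0 is the question feature and v holds the K region features.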
    def san_att(self, gv_emb, q_emb):
        """Stacked attention over the image regions.

        gv_emb: [batch, K, num_hid] projected visual features (e.g. K=36 regions)
        q_emb:  [batch, num_hid] projected question feature
        """
        u = {0: q_emb}  # u[k] is the query after k attention layers
        h_A = {}
        p_I = {}
        for k in range(1, self.num_layer + 1):
            # Joint image-question representation: [batch, K, num_hid]
            h_A[k] = torch.tanh(self.gv_modules[k - 1](gv_emb) +
                                self.q_modules[k - 1](u[k - 1]).unsqueeze(1))
            # Attention weights over the K regions: [batch, K]
            p_I[k] = torch.softmax(self.att_down[k - 1](h_A[k]).squeeze(-1), dim=-1)
            # Attention-weighted visual feature: [batch, num_hid]
            fusion_fea = (p_I[k].unsqueeze(-1) * gv_emb).sum(1)
            # Residual update of the query.
            u[k] = u[k - 1] + fusion_fea
        return u[self.num_layer]
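

# --- Minimal usage sketch (not part of the original file) -------------------
# Builds the model from a hypothetical options namespace and runs one forward
# pass on random tensors. The attribute names mirror those read in __init__
# above; all concrete values (vocab size, 36 regions, 2048-d features, and the
# activation/norm strings, whose valid values depend on this repo's FCNet) are
# assumptions, not the repo's actual configuration.
if __name__ == '__main__':
    from argparse import Namespace

    opt = Namespace(num_hid=1280, activation='ReLU', dropG=0.2, dropW=0.4,
                    dropout=0.2, dropL=0.1, norm='weight', dropC=0.5,
                    ntokens=20000, v_dim=2048, ans_dim=3129)
    model = Model(opt).eval()

    q = torch.randint(0, opt.ntokens, (8, 14))   # [batch_size, seq_length]
    gv = torch.randn(8, 36, opt.v_dim)           # [batch, K, v_dim]
    with torch.no_grad():
        logits_pos, logits_neg = model(q, gv, self_sup=True)
    print(logits_pos.shape, logits_neg.shape)    # both [8, opt.ans_dim]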