
Commit

'comment'
DRL36 committed Mar 3, 2019
1 parent 6811dc6 commit 58697e8
Showing 2 changed files with 6 additions and 11 deletions.
12 changes: 3 additions & 9 deletions models.py
@@ -59,24 +59,18 @@ def forward(self, cw_idxs, qw_idxs, cwf):

c_emb = self.emb(cw_idxs) # (batch_size, c_len, hidden_size)
q_emb = self.emb(qw_idxs) # (batch_size, q_len, hidden_size)


# s = c_emb.shape
# cf_emb = torch.zeros(s[0],s[1],1,device='cuda')
# add the exact-match word feature to the context embeddings
cwf = torch.unsqueeze(cwf, dim=2)
cwf = cwf.float().to(c_emb.device)

ct_emb = torch.cat((c_emb, cwf), dim=2)

# for dimension consistency, append an all-zero feature to the question embeddings
s = q_emb.shape
qf_emb = torch.zeros(s[0], s[1], 1, device=q_emb.device)
qt_emb = torch.cat((q_emb, qf_emb), dim=2)

# for index in range(len(cw_idxs)):
# for i, word_id in enumerate(cw_idxs[index]):
# if word_id in qw_idxs[index]:
# ct_emb[index][i][-1] = 1

c_enc = self.enc(ct_emb, c_len) # (batch_size, c_len, 2 * hidden_size)
q_enc = self.enc(qt_emb, q_len) # (batch_size, q_len, 2 * hidden_size)
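
Because the forward pass now concatenates one extra feature onto both the context and question embeddings, the encoder that consumes ct_emb and qt_emb must be built with a matching input size. A minimal sketch of that constructor change, assuming the encoder is the starter code's layers.RNNEncoder (its definition is not part of this diff):

# Hypothetical tweak in BiDAF.__init__ (not shown in this diff); the +1
# accounts for the exact-match feature concatenated above.
self.enc = layers.RNNEncoder(input_size=hidden_size + 1,
                             hidden_size=hidden_size,
                             num_layers=1,
                             drop_prob=drop_prob)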
5 changes: 3 additions & 2 deletions util.py
@@ -82,6 +82,7 @@ def __init__(self, data_path, use_v2=True):
# self.context_word_features[i][j][0] = 1


# compute word features (exact match against question words)
def compute_context_word_features(self, idx):
s = self.context_idxs[idx].shape
context_word_features = torch.zeros(s[0])
@@ -95,8 +96,6 @@ def compute_context_word_features(self, idx):
return context_word_features
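
The body of compute_context_word_features is collapsed in this view, so the following is an illustrative sketch only, not the committed implementation. It marks context position j with 1 when that context word id also appears among the question word ids of the same example (the question_idxs attribute name and padding id 0 are assumptions):

# Sketch of an exact-match feature; torch is already imported in util.py.
def compute_context_word_features_sketch(self, idx):
    s = self.context_idxs[idx].shape
    context_word_features = torch.zeros(s[0])
    question_ids = set(self.question_idxs[idx].tolist())  # attribute name assumed
    for j, word_id in enumerate(self.context_idxs[idx].tolist()):
        if word_id != 0 and word_id in question_ids:       # 0 assumed to be padding
            context_word_features[j] = 1.0
    return context_word_features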




def __getitem__(self, idx):
idx = self.valid_idxs[idx]
example = (self.context_idxs[idx],
@@ -106,6 +105,7 @@ def __getitem__(self, idx):
self.y1s[idx],
self.y2s[idx],
self.ids[idx],
# per-example exact-match word features for the context
self.compute_context_word_features(idx)
)

@@ -153,6 +153,7 @@ def merge_2d(matrices, dtype=torch.int64, pad_value=0):
return padded

# Group by tensor type
# add cwf: exact-match word features for the context
context_idxs, context_char_idxs, \
question_idxs, question_char_idxs, \
y1s, y2s, ids, cwf = zip(*examples)
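
After unzipping, the new cwf field still has to be padded to a common length like the other per-example tensors. A sketch of that collate step, assuming the starter code provides a merge_1d helper analogous to merge_2d above and that a float dtype is wanted for the feature:

# Pad the per-example exact-match features, mirroring how context_idxs are
# padded; merge_1d and its signature are assumed from the starter collate_fn.
cwf = merge_1d(cwf, dtype=torch.float32)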
