Commit
Merge pull request #41 from zefang-liu/fix/attnhp
Fix the AttNHP GPU issue: newly created masks were allocated on the CPU by default, so they are now moved to the device of the input tensors.
iLampard authored Sep 19, 2024
2 parents 01551fb + f70df24 commit d4731dd
Showing 1 changed file with 2 additions and 2 deletions.
easy_tpp/model/torch_model/torch_attnhp.py
@@ -151,7 +151,7 @@ def make_layer_mask(self, attention_mask):
             a diagonal matrix, [batch_size, seq_len, seq_len]
         """
         # [batch_size, seq_len, seq_len]
-        layer_mask = (torch.eye(attention_mask.size(1)) < 1).unsqueeze(0).expand_as(attention_mask)
+        layer_mask = (torch.eye(attention_mask.size(1)) < 1).unsqueeze(0).expand_as(attention_mask).to(attention_mask.device)
         return layer_mask
 
     def make_combined_att_mask(self, attention_mask, layer_mask):
@@ -303,7 +303,7 @@ def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_s
         if attention_mask is None:
             batch_size, seq_len = time_seqs.size()
             attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).unsqueeze(0)
-            attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool)
+            attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool).to(time_seqs.device)
 
         if sample_times.size()[1] < time_seqs.size()[1]:
            # we pass sample_dtimes for last time step here
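Why the one-line changes work: torch.eye and torch.ones allocate new tensors on the CPU by default, so when the model's inputs live on a GPU, combining these freshly built masks with CUDA-resident tensors raises a device-mismatch RuntimeError. Appending .to(attention_mask.device) / .to(time_seqs.device) moves each mask onto the same device as the tensors it is later combined with. A minimal standalone sketch of the failure and the fix follows; the mask shapes and the elementwise combination are illustrative assumptions, not code taken from the repository:

import torch

def demo():
    # Reproducing the original failure requires a CUDA device.
    if not torch.cuda.is_available():
        print("CUDA not available; the device mismatch cannot be reproduced on CPU.")
        return

    # Stand-in for the model's attention mask, already resident on the GPU.
    attention_mask = (
        torch.triu(torch.ones(10, 10, device="cuda"), diagonal=1)
        .unsqueeze(0)
        .expand(4, -1, -1)
        .to(torch.bool)
    )

    # Before the fix: torch.eye defaults to the CPU, so layer_mask and
    # attention_mask end up on different devices.
    layer_mask = (torch.eye(attention_mask.size(1)) < 1).unsqueeze(0).expand_as(attention_mask)
    try:
        _ = attention_mask | layer_mask  # illustrative elementwise combination
    except RuntimeError as err:
        print(f"Device mismatch, as expected: {err}")

    # After the fix: move the new mask onto attention_mask's device first.
    layer_mask = layer_mask.to(attention_mask.device)
    combined = attention_mask | layer_mask  # both tensors now on the GPU
    print(combined.device)

demo()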
