diff --git a/easy_tpp/model/torch_model/torch_attnhp.py b/easy_tpp/model/torch_model/torch_attnhp.py
index a847dd1..2be437c 100644
--- a/easy_tpp/model/torch_model/torch_attnhp.py
+++ b/easy_tpp/model/torch_model/torch_attnhp.py
@@ -151,7 +151,7 @@ def make_layer_mask(self, attention_mask):
             a diagonal matrix, [batch_size, seq_len, seq_len]
         """
         # [batch_size, seq_len, seq_len]
-        layer_mask = (torch.eye(attention_mask.size(1)) < 1).unsqueeze(0).expand_as(attention_mask)
+        layer_mask = (torch.eye(attention_mask.size(1)) < 1).unsqueeze(0).expand_as(attention_mask).to(attention_mask.device)
         return layer_mask
 
     def make_combined_att_mask(self, attention_mask, layer_mask):
@@ -303,7 +303,7 @@ def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_s
         if attention_mask is None:
             batch_size, seq_len = time_seqs.size()
             attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).unsqueeze(0)
-            attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool)
+            attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool).to(time_seqs.device)
 
         if sample_times.size()[1] < time_seqs.size()[1]:
             # we pass sample_dtimes for last time step here
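
Both hunks fix the same class of bug: masks built with `torch.eye` / `torch.ones` are allocated on CPU by default, so they must be moved to the device of the tensors they will be combined with before the model runs on GPU. Below is a minimal standalone sketch of the pattern; the shapes and variable names are illustrative only, not part of the EasyTPP API.

```python
import torch

# torch.eye / torch.ones allocate on CPU by default. If the model inputs live
# on a GPU, combining a CPU-born mask with them raises
# "RuntimeError: Expected all tensors to be on the same device".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size, seq_len = 2, 5
# Upper-triangular attention mask, created directly on the target device.
attention_mask = (torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
                  .unsqueeze(0)
                  .expand(batch_size, -1, -1)
                  .to(torch.bool))

# The derived mask starts on CPU (expand_as only copies shape, not device) ...
layer_mask = (torch.eye(seq_len) < 1).unsqueeze(0).expand_as(attention_mask)
# ... so move it onto the source tensor's device, as the patch does.
layer_mask = layer_mask.to(attention_mask.device)

combined = attention_mask | layer_mask  # works: both masks share a device
assert combined.device == attention_mask.device
```

An alternative design would be to pass `device=` at tensor creation time (as done for `attention_mask` above) rather than calling `.to(...)` afterwards; the patch uses the trailing `.to(...)` form, which keeps the change minimal.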