diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py
index 148bc2bd18a356..7c2c6f46025363 100644
--- a/transformers/modeling_bert.py
+++ b/transformers/modeling_bert.py
@@ -170,7 +170,7 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
             position_ids = position_ids.unsqueeze(0).expand(input_shape)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
 
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
@@ -655,11 +655,11 @@ def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, posi
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
         if attention_mask is None:
-            attention_mask = torch.ones(input_shape)
+            attention_mask = torch.ones(input_shape, device=device)
         if encoder_attention_mask is None:
-            encoder_attention_mask = torch.ones(input_shape)
+            encoder_attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
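The patch ensures that the default `attention_mask`, `encoder_attention_mask`, and `token_type_ids` tensors are created on the same device as the model inputs instead of always on the CPU. A minimal sketch of the scenario this addresses is below, assuming a CUDA device and an installed `transformers` package with a pretrained `bert-base-uncased` checkpoint; the exact model name and example text are illustrative, not part of the patch.

```python
# Minimal sketch of the failure mode fixed here: the model lives on the GPU
# and only input_ids is passed, so attention_mask / token_type_ids fall back
# to their defaults. Before this patch those defaults were created on the CPU,
# which could raise a device-mismatch error inside forward().
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased").to("cuda")

# Encode a short example sentence directly onto the GPU.
input_ids = torch.tensor([tokenizer.encode("Hello, world")], device="cuda")

with torch.no_grad():
    # With the patch, the default masks and token type ids follow
    # input_ids.device, so this call works without passing them explicitly.
    outputs = model(input_ids)
```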