Commit d531979: Fix BERT

julien-c committed Nov 6, 2019
1 parent 27e015b · commit d531979
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions transformers/modeling_bert.py
@@ -170,7 +170,7 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
             position_ids = position_ids.unsqueeze(0).expand(input_shape)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
@@ -655,11 +655,11 @@ def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, posi
         device = input_ids.device if input_ids is not None else inputs_embeds.device

         if attention_mask is None:
-            attention_mask = torch.ones(input_shape)
+            attention_mask = torch.ones(input_shape, device=device)
         if encoder_attention_mask is None:
-            encoder_attention_mask = torch.ones(input_shape)
+            encoder_attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
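
For context, here is a minimal standalone sketch (not part of the commit) of the failure mode these four one-line changes address. Tensors created without an explicit device= argument are allocated on the CPU, so the old defaults raised a RuntimeError about mismatched devices as soon as the model ran with CUDA inputs:

import torch

# Simulated model inputs, possibly living on a GPU.
input_ids = torch.tensor([[101, 2023, 102]])
if torch.cuda.is_available():
    input_ids = input_ids.to("cuda")

device = input_ids.device
input_shape = input_ids.size()

# Before the fix: allocated on the CPU regardless of where input_ids lives.
token_type_ids_old = torch.zeros(input_shape, dtype=torch.long)

# After the fix: allocated on the same device as the inputs, so downstream
# embedding lookups and mask arithmetic never mix CPU and CUDA tensors.
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
attention_mask = torch.ones(input_shape, device=device)

assert token_type_ids.device == input_ids.device
assert attention_mask.device == input_ids.device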
