Commit 14f3320
Showing 1 changed file with 14 additions and 8 deletions.
src/transformers/models/distilbert/modeling_distilbert.py: 22 changes (14 additions & 8 deletions)
@@ -105,15 +105,22 @@ def __init__(self, config: PretrainedConfig):
             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
         )
 
-    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def forward(self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
         """
         Parameters:
-            input_ids: torch.tensor(bs, max_seq_length) The token ids to embed.
+            input_ids (torch.Tensor):
+                torch.tensor(bs, max_seq_length) The token ids to embed.
+            input_embeds (*optional*, torch.Tensor):
+                The pre-computed word embeddings. Can only be passed if the input ids are `None`.
 
         Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
         embeddings)
         """
-        seq_length = input_ids.size(1)
+        if input_ids is not None:
+            input_embeds = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
+
+        seq_length = input_embeds.size(1)
 
         # Setting the position-ids to the registered buffer in constructor, it helps
         # when tracing the model without passing position-ids, solves
@@ -124,10 +131,9 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
 
-        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
         position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
 
-        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
+        embeddings = input_embeds + position_embeddings  # (bs, max_seq_length, dim)
         embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
         embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
         return embeddings
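
The two hunks above make the embedding lookup optional: when `input_ids` is given it is embedded exactly as before, and when it is `None` the caller-supplied `input_embeds` tensor is used directly; position embeddings, LayerNorm and dropout are applied in both cases. A minimal sketch of how the patched layer could be exercised, assuming the public distilbert-base-uncased checkpoint (not part of this commit):

import torch
from transformers import DistilBertModel

model = DistilBertModel.from_pretrained("distilbert-base-uncased")  # loads in eval mode, so dropout is inactive
emb_layer = model.embeddings

input_ids = torch.tensor([[101, 7592, 2088, 102]])    # (bs=1, seq_len=4)
word_embeds = emb_layer.word_embeddings(input_ids)    # (1, 4, dim): the same lookup done by hand

out_from_ids = emb_layer(input_ids)                   # path 1: ids are embedded inside forward
out_from_embeds = emb_layer(None, word_embeds)        # path 2: pre-computed embeddings, input_ids=None

assert torch.allclose(out_from_ids, out_from_embeds)  # both paths should yield identical embeddings
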
@@ -573,10 +579,10 @@ def forward(
         # Prepare head mask if needed
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
-        if inputs_embeds is None:
-            inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)
+        embeddings = self.embeddings(input_ids, inputs_embeds)  # (bs, seq_length, dim)
 
         return self.transformer(
-            x=inputs_embeds,
+            x=embeddings,
             attn_mask=attention_mask,
             head_mask=head_mask,
             output_attentions=output_attentions,
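
At the model level, the change means `inputs_embeds` is now routed through the embeddings module (so it picks up position embeddings, LayerNorm and dropout) instead of being handed to the transformer raw. A hedged end-to-end sketch, again assuming the public distilbert-base-uncased checkpoint and tokenizer (not part of this commit):

import torch
from transformers import AutoTokenizer, DistilBertModel

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased")

inputs = tokenizer("Hello world", return_tensors="pt")

# Usual path: the model embeds the token ids itself.
out_ids = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

# Alternative path: pre-compute the word embeddings and pass them as inputs_embeds.
word_embeds = model.get_input_embeddings()(inputs["input_ids"])  # (bs, seq_len, dim)
out_embeds = model(inputs_embeds=word_embeds, attention_mask=inputs["attention_mask"])

# With this change both paths should produce the same hidden states.
assert torch.allclose(out_ids.last_hidden_state, out_embeds.last_hidden_state, atol=1e-6)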
