OPT/BioGPT: Improved attention mask shape exception (#23270)
gante authored May 16, 2023
1 parent 21741e8 commit 466af1a
Showing 3 changed files with 20 additions and 0 deletions. All three decoders now validate the attention mask length up front and raise a descriptive error when it does not equal the combined length of past and current inputs, instead of failing later with an opaque shape mismatch.
6 changes: 6 additions & 0 deletions src/transformers/models/biogpt/modeling_biogpt.py
@@ -546,6 +546,12 @@ def forward(

         if attention_mask is None:
             attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
+        elif attention_mask.shape[1] != past_key_values_length + input_shape[1]:
+            raise ValueError(
+                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
+            )
+
         # embed positions
         positions = self.embed_positions(attention_mask, past_key_values_length)

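For context, a minimal repro sketch of the new error path (the checkpoint, prompt, and greedy step are illustrative, not part of this commit): once past key/values are cached, a mask covering only the current token now fails with the descriptive message above.

import torch
from transformers import AutoTokenizer, BioGptForCausalLM

model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt")

out = model(**tokenizer("COVID-19 is", return_tensors="pt"), use_cache=True)
next_token = out.logits[:, -1:].argmax(-1)

# Wrong on purpose: length 1, but the check expects past length + 1.
bad_mask = torch.ones(1, 1, dtype=torch.long)
try:
    model(input_ids=next_token, attention_mask=bad_mask, past_key_values=out.past_key_values)
except ValueError as err:
    print(err)  # "The provided attention mask has length 1, but its length should be ..."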
5 changes: 5 additions & 0 deletions src/transformers/models/opt/modeling_opt.py
@@ -642,6 +642,11 @@ def forward(
         # embed positions
         if attention_mask is None:
             attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+        elif attention_mask.shape[1] != mask_seq_length:
+            raise ValueError(
+                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
+            )
         causal_attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
         )
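And the shape the check expects, sketched for manual incremental decoding (checkpoint and prompt again illustrative): the mask must grow with the cache so that its length equals mask_seq_length, i.e. past plus current tokens.

import torch
from transformers import AutoTokenizer, OPTForCausalLM

model = OPTForCausalLM.from_pretrained("facebook/opt-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

inputs = tokenizer("Hello, my dog is", return_tensors="pt")
attention_mask = inputs["attention_mask"]
out = model(**inputs, use_cache=True)

next_token = out.logits[:, -1:].argmax(-1)
# Extend the mask by one position per step so it covers past + current.
attention_mask = torch.cat([attention_mask, torch.ones(1, 1, dtype=torch.long)], dim=-1)
out = model(input_ids=next_token, attention_mask=attention_mask, past_key_values=out.past_key_values)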
9 changes: 9 additions & 0 deletions src/transformers/models/opt/modeling_tf_opt.py
@@ -645,6 +645,15 @@ def call(

         if attention_mask is None:
             attention_mask = tf.ones(inputs_embeds.shape[:2], dtype=tf.bool)
+        else:
+            tf.debugging.assert_equal(
+                attention_mask.shape[1],
+                past_key_values_length + input_shape[1],
+                message=(
+                    f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                    f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
+                ),
+            )

         pos_embeds = self.embed_positions(attention_mask, past_key_values_length)

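The TF port expresses the same check with tf.debugging.assert_equal, which fails eagerly with tf.errors.InvalidArgumentError and compiles to an assertion op under tf.function. A standalone sketch with illustrative values:

import tensorflow as tf

past_key_values_length = 4  # illustrative: 4 cached positions
current_length = 1          # illustrative: 1 new token
attention_mask = tf.ones((1, 1), dtype=tf.bool)  # too short on purpose

try:
    tf.debugging.assert_equal(
        attention_mask.shape[1],
        past_key_values_length + current_length,
        message="attention mask must cover past + current positions",
    )
except tf.errors.InvalidArgumentError as err:
    print(err.message)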
