[generate] fix eos/pad id check on mps devices (huggingface#31695)
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
sanchit-gandhi and gante authored Jul 22, 2024
1 parent f2a1e3c commit 5a649ff
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions src/transformers/generation/utils.py
```diff
@@ -1542,10 +1542,7 @@ def _tensor_or_none(token, device=None):
             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")

         # we can't infer attn mask if pad token is set to be eos token in model's generation config
-        if (
-            eos_token_tensor is not None
-            and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
-        ):
+        if eos_token_tensor is not None and pad_token_tensor in eos_token_tensor:
             if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
                 logger.warning_once(
                     "The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
```
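The change replaces `torch.isin`, which was not supported on the MPS backend at the time, with the Python `in` operator: `Tensor.__contains__` reduces to an equality comparison plus `.any()`, which relies only on operations available on every device. A minimal sketch of the two equivalent membership checks (the token ids below are illustrative, not from any real model config):

```python
import torch

# Illustrative token ids: a model may define several EOS ids, and the pad id
# here happens to equal one of them.
eos_token_tensor = torch.tensor([2, 32000])
pad_token_tensor = torch.tensor(2)

# Old check (failed on backends without torch.isin support, e.g. MPS then):
# torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()

# New check: `in` on a tensor is equivalent to (tensor == item).any(),
# built from broadcasting and comparison ops available on all backends.
print(pad_token_tensor in eos_token_tensor)   # True
print(torch.tensor(99) in eos_token_tensor)   # False
```

Note the operands are swapped relative to the old call: `torch.isin(elements=eos, test_elements=pad).any()` asks whether any EOS id appears in the pad set, while `pad in eos` asks whether the pad id appears among the EOS ids; for a scalar pad id the two are equivalent.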
