Skip to content

Commit

Permalink
[generate] fix eos/pad id check on mps devices
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchit-gandhi committed Jun 28, 2024
1 parent 0cf60f1 commit 18837f2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1510,7 +1510,7 @@ def _tensor_or_none(token_kwargs, token_self, device=None):
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")

# we can't infer attn mask if pad token is set to be eos token in model's generation config
if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
if eos_token_id is not None and pad_token_id in eos_token_id:
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning_once(
"The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
Expand Down

0 comments on commit 18837f2

Please sign in to comment.