Skip to content

Commit

Permalink
Make EosTokenCriteria compatible with mps (#30376)
Browse files Browse the repository at this point in the history
  • Loading branch information
pcuenca authored and ArthurZucker committed Apr 23, 2024
1 parent 745bbfe commit f8fec6b
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/transformers/generation/stopping_criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,18 @@ def __init__(self, eos_token_id: Union[int, List[int]]):

@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
if input_ids.device.type == "mps":
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
is_done = (
input_ids[:, -1]
.tile(self.eos_token_id.shape[0], 1)
.eq(self.eos_token_id.unsqueeze(1).to(input_ids.device))
.sum(dim=0)
.bool()
.squeeze()
)
else:
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
return is_done


Expand Down

0 comments on commit f8fec6b

Please sign in to comment.