From c1aa0edb48217f416f4bbe6e3a9db1500284513b Mon Sep 17 00:00:00 2001
From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Date: Fri, 2 Aug 2024 17:32:50 +0800
Subject: [PATCH] [generate] only require an attention mask for mps with torch<2.4 (#32367)

* up

* style

* stopping
---
 src/transformers/generation/stopping_criteria.py | 5 ++++-
 src/transformers/generation/utils.py             | 7 ++++---
 src/transformers/pytorch_utils.py                | 1 +
 3 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py
index b1bf3dee9ae1..f8e94f6f86a0 100644
--- a/src/transformers/generation/stopping_criteria.py
+++ b/src/transformers/generation/stopping_criteria.py
@@ -9,6 +9,8 @@
 import torch
 from torch.nn import functional as F
 
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
+
 from ..tokenization_utils_base import PreTrainedTokenizerBase
 from ..utils import add_start_docstrings, logging
 
@@ -485,7 +487,8 @@ def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor]):
     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
         self.eos_token_id = self.eos_token_id.to(input_ids.device)
-        if input_ids.device.type == "mps":
+        if input_ids.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
+            # TODO: remove this workaround when we stop supporting torch<=2.3
             # https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
             is_done = (
                 input_ids[:, -1]
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 385d68cfbef3..ccaa1d80e3f8 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -47,6 +47,7 @@
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING,
 )
+from ..pytorch_utils import is_torch_greater_or_equal_than_2_4
 from ..tokenization_utils import ExtensionsTrie
 from ..utils import (
     ModelOutput,
@@ -488,10 +489,10 @@ def _prepare_attention_mask_for_generation(
             return default_attention_mask
 
         # Otherwise we have may have information -> try to infer the attention mask
-        if inputs.device.type == "mps":
-            # mps does not support torch.isin (https://github.com/pytorch/pytorch/issues/77764)
+        if inputs.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
+            # mps does not support torch.isin for torch<2.4 (https://github.com/pytorch/pytorch/issues/77764)
             raise ValueError(
-                "Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device."
+                "Can't infer missing attention mask on `mps` device for torch<2.4. Please provide an `attention_mask` or upgrade to torch>=2.4"
             )
 
         is_pad_token_in_inputs = (pad_token_id is not None) and (
diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index ae6c0627bb26..2982864d883c 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -28,6 +28,7 @@
 
 parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
 
+is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4")
 is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
 is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
 is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
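
Context for the change: on torch<2.4 the mps backend does not implement torch.isin, so
the last-token EOS check has to fall back to a broadcast comparison. Below is a minimal,
self-contained sketch of the gating pattern this patch introduces. The version-flag name
mirrors the patch; the helper last_token_is_eos and its fallback expression are
illustrative assumptions modeled on the pre-existing mps branch in EosTokenCriteria, not
code taken verbatim from the diff.

from packaging import version
import torch

# Same version gate the patch adds to src/transformers/pytorch_utils.py.
is_torch_greater_or_equal_than_2_4 = version.parse(
    version.parse(torch.__version__).base_version
) >= version.parse("2.4")

def last_token_is_eos(input_ids: torch.LongTensor, eos_token_id: torch.Tensor) -> torch.BoolTensor:
    """Return a (batch,) bool tensor: True where the last token is one of the EOS ids."""
    if input_ids.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
        # torch.isin is unavailable on mps before torch 2.4
        # (https://github.com/pytorch/pytorch/issues/77764): compare the last
        # token against every EOS id by broadcasting, then reduce over EOS ids.
        return (input_ids[:, -1].unsqueeze(0) == eos_token_id.unsqueeze(1)).any(dim=0)
    # On torch>=2.4, or on any other device, the direct primitive works.
    return torch.isin(input_ids[:, -1], eos_token_id)

# Example: last tokens are [2, 7] and the EOS ids are {2, 3}:
#   last_token_is_eos(torch.tensor([[5, 2], [5, 7]]), torch.tensor([2, 3]))
#   -> tensor([True, False])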