diff --git a/optimum/bettertransformer/models/decoder_models.py b/optimum/bettertransformer/models/decoder_models.py
index 31c0c28b46..d585f773b4 100644
--- a/optimum/bettertransformer/models/decoder_models.py
+++ b/optimum/bettertransformer/models/decoder_models.py
@@ -38,8 +38,7 @@ def __init__(self, gpt_layer: "nn.Module", config: "PretrainedConfig"):
         self.gpt_layer = gpt_layer
         self.gpt_layer._attn = self.wrapped_scaled_dot_product

-        mask_value = torch.finfo(torch.float32).min
-        self._mask_value = torch.full([], mask_value, dtype=torch.float32)
+        self.downcast_qk = config.model_type in ["gptj", "gpt_neox"]

     def wrapped_scaled_dot_product(
         self,
@@ -52,6 +51,9 @@ def wrapped_scaled_dot_product(
     ):
         raise_on_head_mask(head_mask)
         batch_size = query.shape[0]
+        mask_value = torch.finfo(value.dtype).min
+        mask_value = torch.full([], mask_value, dtype=value.dtype)
+
         if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
             raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

@@ -74,12 +76,18 @@ def wrapped_scaled_dot_product(
                     torch.bool
                 )

-                causal_mask = torch.where(causal_mask, 0, self._mask_value)
+                causal_mask = torch.where(causal_mask, 0, mask_value)

                 # torch.Tensor.expand does no memory copy
                 causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
                 attention_mask = causal_mask + attention_mask

+            # in gpt-neo-x and gpt-j the query and keys are always in fp32
+            # thus we need to cast them to the value dtype
+            if self.downcast_qk:
+                query = query.to(value.dtype)
+                key = key.to(value.dtype)
+
             sdpa_result = torch.nn.functional.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             )
@@ -98,9 +106,6 @@ def __init__(self, gpt_layer: "nn.Module", config: "PretrainedConfig"):
         self.gpt_layer = gpt_layer
         self.gpt_layer._attn = self.wrapped_scaled_dot_product

-        mask_value = torch.finfo(torch.float32).min
-        self._mask_value = torch.full([], mask_value, dtype=torch.float32)
-
         if self.gpt_layer.bias[0][0][-1][0] == 1:
             self.attention_type = "global"
         else:
@@ -121,6 +126,10 @@ def wrapped_scaled_dot_product(
         raise_on_head_mask(head_mask)
         query = query * self.scale
         batch_size = query.shape[0]
+
+        mask_value = torch.finfo(value.dtype).min
+        mask_value = torch.full([], mask_value, dtype=value.dtype)
+
         if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
             raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

@@ -136,12 +145,13 @@ def wrapped_scaled_dot_product(
         else:
             query_length, key_length = query.size(-2), key.size(-2)

-            causal_mask = self.gpt_layer.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
+            causal_mask = self.gpt_layer.bias[:, :, key_length - query_length : key_length, :key_length]

-            causal_mask = torch.where(causal_mask, 0, self._mask_value)
+            causal_mask = torch.where(causal_mask, 0, mask_value)
             if batch_size > 1:
                 # torch.Tensor.expand does no memory copy
                 causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
+
             attention_mask = causal_mask + attention_mask

             sdpa_result = torch.nn.functional.scaled_dot_product_attention(
@@ -162,9 +172,6 @@ def __init__(self, gpt_layer: "nn.Module", config: "PretrainedConfig"):
         self.gpt_layer = gpt_layer
         self.gpt_layer._attn = self.wrapped_scaled_dot_product

-        mask_value = torch.finfo(torch.float32).min
-        self._mask_value = torch.full([], mask_value, dtype=torch.float32)
-
     def wrapped_scaled_dot_product(
         self,
         query: torch.Tensor,
@@ -175,6 +182,9 @@ def wrapped_scaled_dot_product(
     ):
         raise_on_head_mask(head_mask)
         batch_size = query.shape[0]
+        mask_value = torch.finfo(value.dtype).min
+        mask_value = torch.full([], mask_value, dtype=value.dtype)
+
         if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
             raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

@@ -201,7 +211,7 @@ def wrapped_scaled_dot_product(
                     torch.bool
                 )

-                causal_mask = torch.where(causal_mask, 0, self._mask_value)
+                causal_mask = torch.where(causal_mask, 0, mask_value)

                 # torch.Tensor.expand does no memory copy
                 causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
@@ -209,6 +219,11 @@ def wrapped_scaled_dot_product(
                 # we use torch.min to avoid having tensor(-inf)
                 attention_mask = torch.min(causal_mask, attention_mask)

+            # in codegen the query and key are always in fp32 regardless of the dtype of the model
+            # https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226
+            query = query.to(value.dtype)
+            key = key.to(value.dtype)
+
             sdpa_result = torch.nn.functional.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             )
@@ -226,9 +241,6 @@ def __init__(self, opt_layer: "nn.Module", config: "PretrainedConfig"):

         self.opt_layer = opt_layer

-        mask_value = torch.finfo(torch.float32).min
-        self._mask_value = torch.full([], mask_value, dtype=torch.float32)
-
         self.scale = torch.sqrt(torch.tensor(self.opt_layer.head_dim, dtype=torch.float32)).to(
             torch.get_default_dtype()
         )
@@ -245,6 +257,9 @@ def forward(
         super().forward_checker()
         raise_on_head_mask(layer_head_mask)

+        mask_value = torch.finfo(torch.float32).min
+        self._mask_value = torch.full([], mask_value, dtype=torch.float32)
+
         if output_attentions is True:
             raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

@@ -329,9 +344,6 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

         self.layer = layer

-        mask_value = torch.finfo(torch.float32).min
-        self._mask_value = torch.full([], mask_value, dtype=torch.float32)
-
         head_dim = self.layer.d_model // self.layer.n_heads  # hidden size / num attention heads
         self.scale = torch.sqrt(torch.tensor(head_dim, dtype=torch.float32)).to(torch.get_default_dtype())

@@ -349,6 +361,9 @@ def forward(
     ):
         super().forward_checker()
         raise_on_head_mask(layer_head_mask)
+        mask_value = torch.finfo(torch.float32).min
+        self._mask_value = torch.full([], mask_value, dtype=torch.float32)
+
         if len(self.layer.pruned_heads) > 0:
             raise ValueError(
                 f"Setting `pruned_heads` is unsupported with BetterTransformer, found {self.layer.pruned_heads}."
diff --git a/tests/bettertransformer/test_decoder.py b/tests/bettertransformer/test_decoder.py
index 7c05f5942b..fbe923ec48 100644
--- a/tests/bettertransformer/test_decoder.py
+++ b/tests/bettertransformer/test_decoder.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 import unittest

+import pytest
 import torch
 from packaging.version import parse
 from parameterized import parameterized
@@ -22,7 +23,7 @@

 from optimum.bettertransformer import BetterTransformer, BetterTransformerManager
 from optimum.utils import DummyPastKeyValuesGenerator, NormalizedConfigManager
-from optimum.utils.testing_utils import grid_parameters
+from optimum.utils.testing_utils import grid_parameters, require_torch_gpu


 class BetterTransformersDecoderTest(BetterTransformersTestMixin, unittest.TestCase):
@@ -62,6 +63,22 @@ def test_logits_without_cache(self, test_name: str, model_type: str, padding, ba
         model_id = MODELS_DICT[model_type]
         super()._test_logits(model_id, padding=padding, batch_size=batch_size)

+    @parameterized.expand(
+        grid_parameters(
+            {
+                "model_type": SUPPORTED_ARCH,
+                "use_to_operator": [True, False],
+            }
+        )
+    )
+    @pytest.mark.fp16
+    @require_torch_gpu
+    def test_fp16_inference(self, test_name: str, model_type: str, use_to_operator: bool):
+        self._skip_on_torch_version(model_type)
+
+        model_id = MODELS_DICT[model_type]
+        super()._test_fp16_inference(model_id, use_to_operator=use_to_operator)
+
     @parameterized.expand(
         grid_parameters(
             {
diff --git a/tests/bettertransformer/testing_utils.py b/tests/bettertransformer/testing_utils.py
index 8cf0d927e7..0ec86e1387 100644
--- a/tests/bettertransformer/testing_utils.py
+++ b/tests/bettertransformer/testing_utils.py
@@ -14,10 +14,10 @@
 # limitations under the License.

 import torch
-from transformers import AutoModel
+from transformers import AutoModel, AutoModelForCausalLM

 from optimum.bettertransformer import BetterTransformer
-from optimum.utils.testing_utils import flatten_dict
+from optimum.utils.testing_utils import flatten_dict, require_torch_gpu


 MODELS_DICT = {
@@ -79,6 +79,44 @@ class BetterTransformersTestMixin:
     def prepare_inputs_for_class(self, model_id=None):
         raise NotImplementedError

+    @require_torch_gpu
+    def _test_fp16_inference(self, model_id: str, use_to_operator=False, **preprocessor_kwargs):
+        r"""
+        This tests if the converted model runs fine under fp16.
+        """
+        # The first row of the attention mask needs to be all ones -> check: https://github.com/pytorch/pytorch/blob/19171a21ee8a9cc1a811ac46d3abd975f0b6fc3b/test/test_nn.py#L5283
+        inputs = self.prepare_inputs_for_class(model_id=model_id, **preprocessor_kwargs).to(0)
+
+        torch.manual_seed(0)
+        if not use_to_operator:
+            hf_random_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to(0)
+        else:
+            hf_random_model = AutoModelForCausalLM.from_pretrained(model_id).to(0, torch.float16)
+
+        torch.manual_seed(0)
+        converted_model = BetterTransformer.transform(hf_random_model, keep_original_model=True)
+
+        self.assertFalse(
+            hasattr(hf_random_model, "use_bettertransformer"),
+            f"The model {hf_random_model.__class__.__name__} has been converted to a `fast` model by mistake.",
+        )
+
+        with torch.no_grad():
+            r"""
+            Make sure the models are in eval mode! Make also sure that the original model
+            has not been converted to a fast model. The check is done above.
+            """
+            torch.manual_seed(0)
+            output_hf = hf_random_model.generate(**inputs)
+
+            torch.manual_seed(0)
+            output_bt = converted_model.generate(**inputs)
+
+            self.assertTrue(
+                torch.allclose(output_hf, output_bt, atol=1e-4),
+                f"The logits of the converted model {converted_model.__class__.__name__} are not equal to the logits of the original model {hf_random_model.__class__.__name__}.",
+            )
+
     def _test_logits(self, model_id: str, **preprocessor_kwargs):
         r"""
         This tests if the converted model produces the same logits