diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py
index 46551f4bd2..ba7d56ec22 100644
--- a/optimum/bettertransformer/models/attention.py
+++ b/optimum/bettertransformer/models/attention.py
@@ -49,14 +49,15 @@ def gpt2_wrapped_scaled_dot_product(
     if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
         raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")
 
+    dropout_p = self.dropout_prob_attn if self.training else 0.0
     if batch_size == 1 or self.training:
         if query.shape[2] > 1:
             sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
             )
         else:
             sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
             )
     else:
         query_length, key_length = query.size(-2), key.size(-2)
@@ -73,7 +74,7 @@ def gpt2_wrapped_scaled_dot_product(
                 attention_mask = causal_mask + attention_mask
 
         sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
         )
 
     # in gpt-neo-x and gpt-j the query and keys are always in fp32
@@ -103,14 +104,15 @@ def gpt_neo_wrapped_scaled_dot_product(
     if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
         raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")
 
+    dropout_p = self.dropout_prob_attn if self.training else 0.0
     if (batch_size == 1 or self.training) and self.attention_type == "global":
         if query.shape[2] > 1:
             sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
             )
         else:
             sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
             )
     else:
         query_length, key_length = query.size(-2), key.size(-2)
@@ -125,7 +127,7 @@ def gpt_neo_wrapped_scaled_dot_product(
             attention_mask = causal_mask + attention_mask
 
         sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
        )
 
     return sdpa_result, None
@@ -153,18 +155,19 @@ def codegen_wrapped_scaled_dot_product(
         query = query.to(value.dtype)
         key = key.to(value.dtype)
 
+    dropout_p = self.dropout_prob_attn if self.training else 0.0
     if batch_size == 1 or self.training:
         if query.shape[2] > 1:
             # first step of the decoding
             sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
             )
         else:
             # in this case, which is the later decoding steps, the `causal_mask`` in
             # https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195
             # is [True, ..., True] so actually not causal
             sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
             )
     else:
         query_length, key_length = query.size(-2), key.size(-2)
@@ -183,7 +186,7 @@ def codegen_wrapped_scaled_dot_product(
             attention_mask = torch.min(causal_mask, attention_mask)
 
         sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
         )
 
     return sdpa_result, None
@@ -247,18 +250,20 @@ def opt_forward(
     query_states = self._shape(query_states, tgt_len, batch_size)
 
     query_states = query_states * self.scale
+
+    dropout_p = self.dropout if self.training else 0.0
     if batch_size == 1 or self.training:
         if query_states.shape[2] > 1:
             attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=True
+                query_states, key_states, value_states, attn_mask=None, dropout_p=dropout_p, is_causal=True
             )
         else:
             attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False
+                query_states, key_states, value_states, attn_mask=None, dropout_p=dropout_p, is_causal=False
             )
     else:
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
         )
 
     if attn_output.size() != (batch_size, self.num_heads, tgt_len, self.head_dim):
@@ -361,15 +366,16 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
         past_key_value[1] if past_key_value is not None else None,
     )
 
+    dropout_p = self.dropout if self.training else 0.0
     query_states = self.scale * query_states
     if position_bias is None and not self.has_relative_attention_bias:
         if mask is None:
             attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False
+                query_states, key_states, value_states, attn_mask=None, dropout_p=dropout_p, is_causal=False
             )
         elif mask is not None:
             attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_states, key_states, value_states, attn_mask=mask, dropout_p=0.0, is_causal=False
+                query_states, key_states, value_states, attn_mask=mask, dropout_p=dropout_p, is_causal=False
             )
 
     if position_bias is None:
@@ -394,11 +400,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
 
     if self.has_relative_attention_bias:
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False
+            query_states, key_states, value_states, attn_mask=position_bias, dropout_p=dropout_p, is_causal=False
         )
     else:
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False
+            query_states, key_states, value_states, attn_mask=position_bias, dropout_p=dropout_p, is_causal=False
         )
 
     attn_output = unshape(attn_output)  # (batch_size, seq_length, dim)
@@ -471,7 +477,12 @@ def bart_forward(
     value_states = value_states
 
     attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask,
+        dropout_p=self.dropout if self.training else 0.0,
+        is_causal=False,
     )
 
     if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
diff --git a/optimum/bettertransformer/models/decoder_models.py b/optimum/bettertransformer/models/decoder_models.py
index e246f244a6..7052cf5ce2 100644
--- a/optimum/bettertransformer/models/decoder_models.py
+++ b/optimum/bettertransformer/models/decoder_models.py
@@ -65,6 +65,7 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
 
         self.supports_training = True
         self.downcast_qk = False
+        self.dropout_prob_attn = config.attn_pdrop
 
     def forward(self, *args, **kwargs):
         super().forward_checker()
@@ -91,19 +92,19 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
             "scale_attn",
             "masked_bias",
         ]
+        # Attribute only for transformers>=4.28
+        if hasattr(layer, "embed_positions"):
+            submodules.append("embed_positions")
+
         for attr in submodules:
             setattr(self, attr, getattr(layer, attr))
 
         self.module_mapping = None
         self.original_layers_mapping = {submodule: submodule for submodule in submodules}
 
-        # this attributes does not exist in transformers<=4.27.4
-        if hasattr(self, "embed_positions"):
-            self.original_layers_mapping["embed_positions"] = "embed_positions"
-            setattr(self, "embed_positions", getattr(layer, "embed_positions"))
-
         self.downcast_qk = True
         self.supports_training = True
+        self.dropout_prob_attn = config.attn_pdrop
 
     def forward(self, *args, **kwargs):
         super().forward_checker()
@@ -127,6 +128,7 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
 
         self.downcast_qk = True
         self.supports_training = True
+        self.dropout_prob_attn = 0.0  # no dropout for gpt-neox
 
     def forward(self, *args, **kwargs):
         super().forward_checker()
@@ -156,6 +158,7 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
         self.scale = torch.sqrt(torch.tensor(layer.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
 
         self.supports_training = True
+        self.dropout_prob_attn = float(config.attention_dropout)
 
     def forward(self, *args, **kwargs):
         super().forward_checker()
@@ -173,12 +176,18 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
 
         self.module_mapping = None
         submodules = ["attn_dropout", "resid_dropout", "qkv_proj", "out_proj", "causal_mask", "scale_attn"]
+
+        # Attribute only for transformers>=4.28
+        if hasattr(layer, "embed_positions"):
+            submodules.append("embed_positions")
+
         for attr in submodules:
             setattr(self, attr, getattr(layer, attr))
 
         self.original_layers_mapping = {submodule: submodule for submodule in submodules}
 
         self.supports_training = True
+        self.dropout_prob_attn = config.attn_pdrop
 
     def forward(self, *args, **kwargs):
         super().forward_checker()
diff --git a/optimum/bettertransformer/transformation.py b/optimum/bettertransformer/transformation.py
index c7d694ab4c..a8e7a9add2 100644
--- a/optimum/bettertransformer/transformation.py
+++ b/optimum/bettertransformer/transformation.py
@@ -28,7 +28,7 @@
 logger = logging.getLogger(__name__)
 
 if is_accelerate_available():
-    from accelerate import dispatch_model
+    from accelerate import dispatch_model, infer_auto_device_map
     from accelerate.hooks import remove_hook_from_module
 
 ERROR_MESSAGE = r"The Better Transformers implementation for the model {model_name} has not been implemented yet. Please open an issue requesting the addition of this model with its `BetterTransformer` implementation."
@@ -254,7 +254,18 @@ def transform(
         setattr(model_fast, "use_bettertransformer", True)
 
         if load_accelerate:
-            model_fast = dispatch_model(model_fast, hf_device_map)
+            all_model_tensors = [name for name, _ in model_fast.state_dict().items()]
+            for module_name in hf_device_map.keys():
+                all_model_tensors = [name for name in all_model_tensors if not name.startswith(module_name)]
+
+            if len(all_model_tensors) > 0:
+                # This is the case where a transformed submodule is broken into several devices:
+                # as the submodules map may differ, we need to reinfer the device map
+                bt_device_map = infer_auto_device_map(model_fast, max_memory=max_memory)
+            else:
+                bt_device_map = hf_device_map
+
+            model_fast = dispatch_model(model_fast, bt_device_map)
 
         # It is not recommended to have `keep_original_model=True` with a model
         # that is loaded with accelerate but just in case
diff --git a/tests/bettertransformer/test_decoder.py b/tests/bettertransformer/test_decoder.py
index 1877e33bf4..2ffc0afea8 100644
--- a/tests/bettertransformer/test_decoder.py
+++ b/tests/bettertransformer/test_decoder.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import unittest
 
 import pytest
@@ -22,7 +23,7 @@
 
 from optimum.bettertransformer import BetterTransformer
 from optimum.utils import DummyPastKeyValuesGenerator, NormalizedConfigManager
-from optimum.utils.testing_utils import grid_parameters, require_torch_20, require_torch_gpu
+from optimum.utils.testing_utils import grid_parameters, require_accelerate, require_torch_20, require_torch_gpu
 
 
 class BetterTransformersDecoderTest(BetterTransformersTestMixin, unittest.TestCase):
@@ -193,3 +194,40 @@ def test_invert_model_logits(self, test_name: str, model_type: str, keep_original_model: bool):
         self._test_invert_model_logits(
             model_id=model_id, model_type=model_type, keep_original_model=keep_original_model
         )
+
+    @parameterized.expand(
+        grid_parameters(
+            {"keep_original_model": [True], "max_memory": [{0: "300MB", "cpu": "3GB"}, {0: "2GB"}]},
+            add_test_name=False,
+        )
+    )
+    @require_torch_gpu
+    @require_accelerate
+    def test_accelerate_compatibility_cpu_gpu(self, keep_original_model=True, max_memory=None):
+        hf_model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto", max_memory=max_memory).eval()
+        bt_model = BetterTransformer.transform(
+            hf_model, keep_original_model=keep_original_model, max_memory=max_memory
+        )
+
+        inputs_ids = torch.LongTensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]])
+        attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]])
+
+        # Check that the model has been dispatched on CPU and GPU
+        self.assertSetEqual(set(hf_model.hf_device_map.values()), set(max_memory))
+        self.assertSetEqual(set(bt_model.hf_device_map.values()), set(max_memory))
+
+        # Check that the model has weights on GPU and CPU
+        self.assertEqual(bt_model.transformer.h[0].mlp.c_fc.weight.device, torch.device("cuda:0"))
+
+        # Weights that are offloaded on the CPU are offloaded on the `meta` device
+        if "cpu" in set(max_memory):
+            self.assertEqual(bt_model.transformer.h[-1].mlp.c_fc.weight.device, torch.device("meta"))
+
+        with torch.inference_mode():
+            output_bt = bt_model(inputs_ids, attention_mask=attention_mask)
+            output_hf = hf_model(inputs_ids, attention_mask=attention_mask)
+
+        self.assertEqual(output_bt[0].device, torch.device("cpu"))
+        self.assertTrue(torch.allclose(output_bt[0], output_hf[0], atol=1e-3))
+
+        gc.collect()
diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py
index 5bc08a51db..93b9a2cf75 100644
--- a/tests/exporters/onnx/test_onnx_export.py
+++ b/tests/exporters/onnx/test_onnx_export.py
@@ -209,8 +209,8 @@ def _onnx_export(
         monolith: bool,
         device="cpu",
     ):
-        model_class = TasksManager.get_model_class_for_task(task)
         config = AutoConfig.from_pretrained(model_name)
+        model_class = TasksManager.get_model_class_for_task(task, model_type=config.model_type.replace("_", "-"))
         model = model_class.from_config(config)
 
         # Dynamic axes aren't supported for YOLO-like models. This means they cannot be exported to ONNX on CUDA devices.