From 65da2538341d32c0a9721a2eb2c8bc277a745e20 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Sat, 23 Nov 2024 17:12:29 +0100 Subject: [PATCH 1/8] gpt neox flex attention + refactor --- src/transformers/modeling_utils.py | 45 +- .../models/gpt_neox/modeling_gpt_neox.py | 503 +++++++++--------- src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 11 + .../models/gpt_neox/test_modeling_gpt_neox.py | 24 + 5 files changed, 338 insertions(+), 246 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4703c415e42fbb..57044636dad4c8 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -89,6 +89,7 @@ is_peft_available, is_remote_url, is_safetensors_available, + is_torch_flex_attn_available, is_torch_greater_or_equal, is_torch_sdpa_available, is_torch_xla_available, @@ -1338,6 +1339,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # SDPA support _supports_sdpa = False + # Flex Attention support + _supports_flex_attn = False + # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`? _supports_cache_class = False _supports_static_cache = False @@ -1544,6 +1548,10 @@ def _autoset_attn_implementation( message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' if cls._supports_sdpa: message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' + if cls._supports_flex_attn: + message += ( + ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)' + ) raise ValueError(message + ".") # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. @@ -1578,6 +1586,8 @@ def _autoset_attn_implementation( hard_check_only=False, check_device_map=check_device_map, ) + elif requested_attn_implementation == "flex_attention": + config = cls._check_and_enable_flex_attn(config, hard_check_only=True) elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available(): # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. config = cls._check_and_enable_sdpa( @@ -1774,7 +1784,7 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> Pretra """ Checks the availability of SDPA for a given model. - If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. + If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module. """ if hard_check_only: if not cls._supports_sdpa: @@ -1799,6 +1809,39 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> Pretra config._attn_implementation = "sdpa" return config + @classmethod + def _check_and_enable_flex_attn(cls, config, hard_check_only: bool = False) -> PretrainedConfig: + """ + Checks the availability of Flex Attention for a given model. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flex_attention" so that the model can initialize the correct attention module. 
+ """ + if hard_check_only: + if not cls._supports_flex_attn: + # TODO: add contribution notice? + raise ValueError( + f"{cls.__name__} does not support an attention implementation through torch's flex_attention." + " If you believe" + ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' + ) + if not is_torch_flex_attn_available(): + raise ImportError( + "PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0." + ) + + if not is_torch_flex_attn_available() or not cls._supports_flex_attn: + return config + + # TODO check for more edge cases as done in the other implementations + # _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + # if _is_bettertransformer: + # return config + + if not hard_check_only: + config._attn_implementation = "flex_attention" + + return config + def enable_input_require_grads(self): """ Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 359996983eed74..c1537316a01620 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -45,6 +45,7 @@ get_torch_version, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torch_greater_or_equal, logging, ) from .configuration_gpt_neox import GPTNeoXConfig @@ -53,6 +54,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_greater_or_equal("2.5"): + from torch.nn.attention.flex_attention import flex_attention + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM" @@ -76,6 +80,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): _supports_quantized_cache = True _supports_static_cache = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" @@ -92,6 +97,189 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) +def eager_attention_forward( + query, key, value, attention_mask, head_mask, norm_factor, attention_dropout, training, **_kwargs +): + # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] + batch_size, num_attention_heads, query_length, attn_head_size = query.size() + key_length = key.size(-2) + + query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) + key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) + attn_scores = torch.zeros( + batch_size * num_attention_heads, + query_length, + key_length, + dtype=query.dtype, + device=key.device, + ) + attn_scores = torch.baddbmm( + attn_scores, + query, + key.transpose(1, 2), + beta=1.0, + alpha=norm_factor, + ) + attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_scores = attn_scores + causal_mask + + attn_weights = nn.functional.softmax(attn_scores, dim=-1) + attn_weights = attn_weights.to(value.dtype) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_weights = 
nn.functional.dropout(attn_weights, p=attention_dropout, training=training) + attn_output = torch.matmul(attn_weights, value) + + # Reshape outputs + attn_output = GPTNeoXAttention._merge_heads(attn_output, num_attention_heads, attn_head_size) + + return attn_output, attn_weights + + +def flash_attention_forward( + query, + key, + value, + attention_mask, + norm_factor, + attention_dropout, + training, + target_dtype=torch.float16, + _flash_attn_uses_top_left_mask=False, + **_kwargs, +): + query_length = query.shape[-2] + + # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision + value_dtype = value.dtype + if query.dtype != value_dtype: + query = query.to(value_dtype) + if key.dtype != value_dtype: + key = key.to(value_dtype) + + # Permute to get the expected shape for Flash Attention + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 / bfloat16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + input_dtype = query.dtype + if input_dtype == torch.float32: + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attention_dropout = attention_dropout if training else 0.0 + + # Compute attention + attn_output = _flash_attention_forward( + query, + key, + value, + attention_mask, + query_length, + dropout=attention_dropout, + softmax_scale=norm_factor, + is_causal=True, + use_top_left_mask=_flash_attn_uses_top_left_mask, + ) + + return attn_output, None + + +def sdpa_attention_forward( + query, key, value, attention_mask, attention_dropout, training, require_contiguous_qkv=False, **_kwargs +): + q_len = query.shape[-2] + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision + value_dtype = value.dtype + if query.dtype != value_dtype: + query = query.to(value_dtype) + if key.dtype != value_dtype: + key = key.to(value_dtype) + + # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA + if require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=causal_mask, + dropout_p=attention_dropout.p if training else 0.0, + is_causal=is_causal, + ) + + # Reshape outputs + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, None + + +def flex_attention_forward(query, key, value, attention_mask, norm_factor, output_attentions=False, **_kwargs): + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + def causal_mod(score, b, h, q_idx, kv_idx): + if causal_mask is not None: + return score + causal_mask[b][0][q_idx][kv_idx] + return score + + attn_output = flex_attention( + query, + key, + value, + score_mod=causal_mod, + enable_gqa=True, + scale=norm_factor, + return_lse=output_attentions, + ) + + # Reshape outputs + attn_output = attn_output.transpose(1, 2).contiguous() + + if not output_attentions: + return attn_output, None + else: + return attn_output[0], attn_output[1] + + +GPTNEOX_ATTENTION_FUNCTION = { + "eager": eager_attention_forward, + "flash_attention_2": flash_attention_forward, + "sdpa": sdpa_attention_forward, + "flex_attention": flex_attention_forward, +} + + class GPTNeoXAttention(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() @@ -123,6 +311,13 @@ def __init__(self, config, layer_idx=None): self.is_causal = True self.layer_idx = layer_idx + # Attention specific information for the implementations + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + def _init_bias(self, max_positions, device=None): self.register_buffer( "bias", @@ -147,20 +342,71 @@ def forward( cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ): + bsz, seq_len, _ = hidden_states.shape + # Apply attention-specific projections and rope query, key, value, present = self._attn_projections_and_rope( hidden_states=hidden_states, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache, + cache_position=cache_position, position_embeddings=position_embeddings, ) + # Flash Attention 2 specific handling for PEFT integration + target_dtype = None + if self.config._attn_implementation == "flash_attention_2": + input_dtype = query.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.query_key_value.weight.dtype + + # Checking for fallbacks in case an unsupported feature is requested + attention_type = self.config._attn_implementation + if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: + logger.warning_once( + f"Setting `attention_type` to `eager` because `output_attentions=True` is not supported in {attention_type}" + ) + attention_type = "eager" + + if ( + self.training + and self.config.attention_dropout > 0 + and self.config._attn_implementation == "flex_attention" + ): + logger.warning_once( + f"Setting `attention_type` to `eager` because `dropout` is not supported in {attention_type}" + ) + attention_type = "eager" + # Compute attention - attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + attn_output, attn_weights = GPTNEOX_ATTENTION_FUNCTION[attention_type]( + query, + key, + value, + attention_mask=attention_mask, + head_mask=head_mask, + norm_factor=self.norm_factor, + attention_dropout=self.config.attention_dropout, + training=self.training, + # Flash Attention 2 specific + target_dtype=target_dtype, + _flash_attn_uses_top_left_mask=self._flash_attn_uses_top_left_mask, + # SDPA specific + require_contiguous_qkv=self.require_contiguous_qkv, + # Flex Attention specific + output_attentions=output_attentions, + ) - # Reshape outputs - attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) + # Reshape outputs and final projection + attn_output = attn_output.contiguous() + attn_output = attn_output.view(bsz, seq_len, -1) attn_output = self.dense(attn_output) outputs = (attn_output, present) @@ -250,262 +496,28 @@ def _attn_projections_and_rope( return query, key, value, layer_past - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] - # compute causal mask from causal mask buffer - batch_size, num_attention_heads, query_length, attn_head_size = query.size() - key_length = key.size(-2) - - # dynamically increase the causal mask with the key length, if needed. 
- if key_length > self.bias.shape[-1]: - self._init_bias(key_length, device=key.device) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] - - query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) - key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) - attn_scores = torch.zeros( - batch_size * num_attention_heads, - query_length, - key_length, - dtype=query.dtype, - device=key.device, - ) - attn_scores = torch.baddbmm( - attn_scores, - query, - key.transpose(1, 2), - beta=1.0, - alpha=self.norm_factor, - ) - attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) - - mask_value = torch.finfo(attn_scores.dtype).min - # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. - # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device) - attn_scores = torch.where(causal_mask, attn_scores, mask_value) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key.shape[-2]] - attn_scores = attn_scores + causal_mask - - attn_weights = nn.functional.softmax(attn_scores, dim=-1) - attn_weights = attn_weights.to(value.dtype) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_weights = self.attention_dropout(attn_weights) - - attn_output = torch.matmul(attn_weights, value) - return attn_output, attn_weights - +# TODO Remove in deprecation cycle class GPTNeoXFlashAttention2(GPTNeoXAttention): - """ - GPTNeoX flash attention module. This module inherits from `GPTNeoXAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: torch.FloatTensor, - position_ids: torch.LongTensor, - head_mask: Optional[torch.FloatTensor] = None, - layer_past: Optional[Cache] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ): - # Apply attention-specific projections and rope - query, key, value, present = self._attn_projections_and_rope( - hidden_states=hidden_states, - position_ids=position_ids, - layer_past=layer_past, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - query_length = query.shape[-2] - - # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision - target_dtype = value.dtype - if query.dtype != target_dtype: - query = query.to(target_dtype) - if key.dtype != target_dtype: - key = key.to(target_dtype) - - # Permute to get the expected shape for Flash Attention - query = query.permute(0, 2, 1, 3) - key = key.permute(0, 2, 1, 3) - value = value.permute(0, 2, 1, 3) - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 / bfloat16 just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - input_dtype = query.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.query_key_value.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query = query.to(target_dtype) - key = key.to(target_dtype) - value = value.to(target_dtype) - - attention_dropout = self.config.attention_dropout if self.training else 0.0 - - # Compute attention - attn_weights = _flash_attention_forward( - query, - key, - value, - attention_mask, - query_length, - dropout=attention_dropout, - softmax_scale=self.norm_factor, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - # Reshape outputs - attn_output = attn_weights.reshape( - attn_weights.shape[0], attn_weights.shape[1], self.num_attention_heads * self.head_size + logger.warning_once( + "The `GPTNeoXFlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `GPTNeoXAttention` class! It will be removed in v4.48" ) - attn_output = self.dense(attn_output) - - outputs = (attn_output, layer_past) - if output_attentions: - outputs += (attn_weights,) - - return outputs +# TODO Remove in deprecation cycle class GPTNeoXSdpaAttention(GPTNeoXAttention): - """ - GPTNeoX attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `GPTNeoXAttention` as the weights of the module stays untouched. 
The only changes are on the forward pass - to adapt to the SDPA API. - """ - def __init__(self, config, layer_idx=None): super().__init__(config, layer_idx=layer_idx) - # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom - # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0. - # Reference: https://github.com/pytorch/pytorch/issues/112577 - self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: torch.FloatTensor, - position_ids: torch.LongTensor, - head_mask: Optional[torch.FloatTensor] = None, - layer_past: Optional[Tuple[torch.Tensor]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ): - if output_attentions or head_mask is not None: - logger.warning_once( - "`GPTNeoXSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " - "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " - "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " - 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - layer_past=layer_past, - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - # Apply attention-specific projections and rope - query, key, value, present = self._attn_projections_and_rope( - hidden_states=hidden_states, - position_ids=position_ids, - layer_past=layer_past, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] - - # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision - target_dtype = value.dtype - if query.dtype != target_dtype: - query = query.to(target_dtype) - if key.dtype != target_dtype: - key = key.to(target_dtype) - - # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA - if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=causal_mask, - dropout_p=self.attention_dropout.p if self.training else 0.0, - is_causal=is_causal, + logger.warning_once( + "The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `GPTNeoXAttention` class! 
It will be removed in v4.48" ) - # Reshape outputs - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.dense(attn_output) - - return attn_output, present, None - - -def attention_mask_func(attention_scores, ltor_mask): - attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min) - return attention_scores - # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->GPTNeoX class GPTNeoXRotaryEmbedding(nn.Module): @@ -675,6 +687,7 @@ def forward(self, hidden_states): "eager": GPTNeoXAttention, "flash_attention_2": GPTNeoXFlashAttention2, "sdpa": GPTNeoXSdpaAttention, + "flex_attention": GPTNeoXAttention, } diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 492642d61babb5..f7e962bec346fb 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -206,6 +206,7 @@ is_torch_compile_available, is_torch_cuda_available, is_torch_deterministic, + is_torch_flex_attn_available, is_torch_fp16_available_on_device, is_torch_fx_available, is_torch_fx_proxy, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 70bd236e3bb4ac..1c0bf4861e3b3f 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -358,6 +358,17 @@ def is_torch_sdpa_available(): return version.parse(_torch_version) >= version.parse("2.1.1") +def is_torch_flex_attn_available(): + if not is_torch_available(): + return False + elif _torch_version == "N/A": + return False + + # TODO check if some bugs cause push backs on the exact version + # NOTE: We require torch>=2.5.0 as it is the first release + return version.parse(_torch_version) >= version.parse("2.5.0") + + def is_torchvision_available(): return _torchvision_available diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index 2c3319f02475cc..5c23af4a01891d 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -459,6 +459,30 @@ def test_lm_generate_gptneox(self): self.assertEqual(output_str, expected_output) + @slow + def test_lm_generate_flex_attn_gptneox(self): + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped") + for checkpointing in [True, False]: + model = GPTNeoXForCausalLM.from_pretrained( + "EleutherAI/pythia-410m-deduped", attn_implementation="flex_attention" + ) + + if checkpointing: + model.gradient_checkpointing_enable() + else: + model.gradient_checkpointing_disable() + model.to(torch_device) + + inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device) + # The hub repo. is updated on 2023-04-04, resulting in poor outputs. 
+ # See: https://github.com/huggingface/transformers/pull/24193 + expected_output = "My favorite food is a good old-fashioned, old-fashioned, old-fashioned.\n\nI'm not sure" + + output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20) + output_str = tokenizer.batch_decode(output_ids)[0] + + self.assertEqual(output_str, expected_output) + def pythia_integration_test(self): model_name_or_path = "EleutherAI/pythia-70m" model = GPTNeoXForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16).to(torch_device) From ebfdc1f95071c27db9daa5fa6c15cfdac7034e5e Mon Sep 17 00:00:00 2001 From: Vasqu Date: Sat, 23 Nov 2024 17:33:35 +0100 Subject: [PATCH 2/8] some formatting --- src/transformers/modeling_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 57044636dad4c8..b47cd73904a2c6 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1821,8 +1821,9 @@ def _check_and_enable_flex_attn(cls, config, hard_check_only: bool = False) -> P # TODO: add contribution notice? raise ValueError( f"{cls.__name__} does not support an attention implementation through torch's flex_attention." - " If you believe" - ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' + " If you believe this error is a bug, please open an issue in Transformers GitHub repository" + ' and load your model with the argument `attn_implementation="eager"` meanwhile.' + ' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' ) if not is_torch_flex_attn_available(): raise ImportError( From b3c2b11b6f89cde5c242e7c1f9d4b2f702d1a367 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Sat, 23 Nov 2024 17:42:44 +0100 Subject: [PATCH 3/8] small fix on dropout --- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index c1537316a01620..9d9865d4938640 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -233,7 +233,7 @@ def sdpa_attention_forward( key=key, value=value, attn_mask=causal_mask, - dropout_p=attention_dropout.p if training else 0.0, + dropout_p=attention_dropout if training else 0.0, is_causal=is_causal, ) From 06685f82757db8cb343334fac11e6ac2fea095dd Mon Sep 17 00:00:00 2001 From: Vasqu Date: Sat, 23 Nov 2024 17:56:30 +0100 Subject: [PATCH 4/8] add assertion on flex attn test --- tests/models/gpt_neox/test_modeling_gpt_neox.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index 5c23af4a01891d..435133e93860ac 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -466,6 +466,7 @@ def test_lm_generate_flex_attn_gptneox(self): model = GPTNeoXForCausalLM.from_pretrained( "EleutherAI/pythia-410m-deduped", attn_implementation="flex_attention" ) + self.assertTrue(model.config._attn_implementation == "flex_attention") if checkpointing: model.gradient_checkpointing_enable() From 53f731906d446e4ce1f0a723322e4db4fb95c62f Mon Sep 
17 00:00:00 2001 From: Vasqu Date: Sat, 23 Nov 2024 18:01:43 +0100 Subject: [PATCH 5/8] flaky ci :( From 21941c5cd04c008d79f1fe7e7ca61284ab30ce59 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Sat, 23 Nov 2024 19:15:41 +0100 Subject: [PATCH 6/8] add head mask support --- .../models/gpt_neox/modeling_gpt_neox.py | 34 +++++++++++++++---- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 9d9865d4938640..bdc1a252b6f90b 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -243,14 +243,19 @@ def sdpa_attention_forward( return attn_output, None -def flex_attention_forward(query, key, value, attention_mask, norm_factor, output_attentions=False, **_kwargs): +def flex_attention_forward(query, key, value, attention_mask, head_mask, norm_factor, output_attentions=False, **_kwargs): + causal_mask_exists = attention_mask is not None + head_mask_exists = head_mask is not None + causal_mask = attention_mask - if attention_mask is not None: + if causal_mask_exists: causal_mask = causal_mask[:, :, :, : key.shape[-2]] def causal_mod(score, b, h, q_idx, kv_idx): - if causal_mask is not None: - return score + causal_mask[b][0][q_idx][kv_idx] + if causal_mask_exists: + score += causal_mask[b][0][q_idx][kv_idx] + if head_mask_exists: + score += head_mask[b][h][0][0] return score attn_output = flex_attention( @@ -369,9 +374,19 @@ def forward( # Checking for fallbacks in case an unsupported feature is requested attention_type = self.config._attn_implementation - if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: + if (output_attentions or head_mask is not None) and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: + warning_msg = "Setting `attention_type` to `eager` because" + + if output_attentions: + warning_msg += f" `output_attentions=True`" + if output_attentions and head_mask is not None: + warning_msg += f" and `head_mask` not None" + elif head_mask is not None: + warning_msg += f" `head_mask` not None" + warning_msg += f" not supported in {attention_type}" + logger.warning_once( - f"Setting `attention_type` to `eager` because `output_attentions=True` is not supported in {attention_type}" + warning_msg ) attention_type = "eager" @@ -932,7 +947,12 @@ def forward( # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + converted_head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + # Flex Attention converts it to a separate mask + if head_mask is not None: + converted_head_mask = torch.where(converted_head_mask < 1.0, torch.finfo(inputs_embeds.dtype).min, 0).to(device=self.device) + head_mask = converted_head_mask + hidden_states = self.emb_dropout(inputs_embeds) # create position embeddings to be shared across the decoder layers From 02089056b6dbc7bfe4d7a8d5c61334627da33e76 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Sat, 23 Nov 2024 19:20:52 +0100 Subject: [PATCH 7/8] style --- .../models/gpt_neox/modeling_gpt_neox.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py 
b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index bdc1a252b6f90b..554926da5778f7 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -243,7 +243,9 @@ def sdpa_attention_forward( return attn_output, None -def flex_attention_forward(query, key, value, attention_mask, head_mask, norm_factor, output_attentions=False, **_kwargs): +def flex_attention_forward( + query, key, value, attention_mask, head_mask, norm_factor, output_attentions=False, **_kwargs +): causal_mask_exists = attention_mask is not None head_mask_exists = head_mask is not None @@ -374,20 +376,21 @@ def forward( # Checking for fallbacks in case an unsupported feature is requested attention_type = self.config._attn_implementation - if (output_attentions or head_mask is not None) and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: + if (output_attentions or head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + ]: warning_msg = "Setting `attention_type` to `eager` because" if output_attentions: - warning_msg += f" `output_attentions=True`" + warning_msg += " `output_attentions=True`" if output_attentions and head_mask is not None: - warning_msg += f" and `head_mask` not None" + warning_msg += " and `head_mask` not None" elif head_mask is not None: - warning_msg += f" `head_mask` not None" + warning_msg += " `head_mask` not None" warning_msg += f" not supported in {attention_type}" - logger.warning_once( - warning_msg - ) + logger.warning_once(warning_msg) attention_type = "eager" if ( @@ -950,7 +953,8 @@ def forward( converted_head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) # Flex Attention converts it to a separate mask if head_mask is not None: - converted_head_mask = torch.where(converted_head_mask < 1.0, torch.finfo(inputs_embeds.dtype).min, 0).to(device=self.device) + converted_head_mask = torch.where(converted_head_mask < 1.0, torch.finfo(inputs_embeds.dtype).min, 0) + converted_head_mask = converted_head_mask.to(device=self.device) head_mask = converted_head_mask hidden_states = self.emb_dropout(inputs_embeds) From e689d28cf742c9f89efa06bc65f41d33a23108bf Mon Sep 17 00:00:00 2001 From: Vasqu Date: Sat, 23 Nov 2024 20:24:20 +0100 Subject: [PATCH 8/8] handle dtype, replace torch where --- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 554926da5778f7..7c3d01bef4bdd6 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -953,8 +953,8 @@ def forward( converted_head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) # Flex Attention converts it to a separate mask if head_mask is not None: - converted_head_mask = torch.where(converted_head_mask < 1.0, torch.finfo(inputs_embeds.dtype).min, 0) - converted_head_mask = converted_head_mask.to(device=self.device) + converted_head_mask = ~converted_head_mask.bool() * torch.finfo(inputs_embeds.dtype).min + converted_head_mask = converted_head_mask.to(dtype=self.dtype, device=self.device) head_mask = converted_head_mask hidden_states = self.emb_dropout(inputs_embeds)
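
For reference, a minimal usage sketch of the new backend, mirroring test_lm_generate_flex_attn_gptneox added in PATCH 1/8. It is an illustration of the intended API, not part of the patch series, and assumes torch>=2.5.0 plus a Transformers build that already contains these commits; the checkpoint name is the one used in that test.

    from transformers import AutoTokenizer, GPTNeoXForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
    # Requesting flex_attention routes through PreTrainedModel._check_and_enable_flex_attn,
    # which verifies _supports_flex_attn and the torch>=2.5.0 requirement.
    model = GPTNeoXForCausalLM.from_pretrained(
        "EleutherAI/pythia-410m-deduped", attn_implementation="flex_attention"
    )
    assert model.config._attn_implementation == "flex_attention"

    inputs = tokenizer("My favorite food is", return_tensors="pt")
    output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20)
    print(tokenizer.batch_decode(output_ids)[0])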
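
The flex_attention_forward added in PATCH 1/8 relies on torch's score_mod hook to fold the model's additive causal mask into the pre-softmax attention scores. Below is a standalone sketch of that mechanism with synthetic shapes chosen only for illustration; it assumes torch>=2.5.0 and is not the patched module itself.

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    bsz, n_heads, q_len, head_dim = 1, 4, 8, 16
    query = torch.randn(bsz, n_heads, q_len, head_dim)
    key = torch.randn(bsz, n_heads, q_len, head_dim)
    value = torch.randn(bsz, n_heads, q_len, head_dim)

    # Additive causal mask shaped [bsz, 1, q_len, kv_len]: 0.0 where attention is
    # allowed, a large negative value where it is masked, matching what the model's
    # mask utilities hand to flex_attention_forward.
    causal_mask = torch.triu(
        torch.full((q_len, q_len), torch.finfo(query.dtype).min), diagonal=1
    )[None, None, :, :]

    def causal_mod(score, b, h, q_idx, kv_idx):
        # Bias the pre-softmax score per (batch, head, query, key) position,
        # the same pattern used in flex_attention_forward.
        return score + causal_mask[b][0][q_idx][kv_idx]

    attn_output = flex_attention(
        query, key, value, score_mod=causal_mod, scale=head_dim**-0.5
    )
    print(attn_output.shape)  # torch.Size([1, 4, 8, 16])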
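
PATCH 6/8 through 8/8 iterate on head-mask support: the multiplicative 0/1 mask returned by get_head_mask is converted into an additive bias so it can be added to the scores inside score_mod, next to the causal mask. A small worked example of that conversion follows, using hypothetical values but the same expression as the final PATCH 8/8.

    import torch

    dtype = torch.float32
    head_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])  # keep heads 0, 2 and 3, drop head 1
    # ~mask.bool() flips kept heads to False and dropped heads to True; multiplying by
    # torch.finfo(dtype).min yields 0 for kept heads and a large negative bias otherwise.
    additive_head_mask = (~head_mask.bool() * torch.finfo(dtype).min).to(dtype)
    print(additive_head_mask)  # approximately [0, -3.4028e+38, 0, 0]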