diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index eabc6f2926d3..b83d5a973398 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -678,9 +678,10 @@ def prepare_inputs_for_generation( if encoder_attention_mask is not None: model_inputs["attention_mask"] = encoder_attention_mask + # 7. Prepare kwargs for flash attention to avoid recomputations if "flash" in self.config._attn_implementation and self._supports_attention_backend: - cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k = prepare_fa_kwargs_from_position_ids( - position_ids, is_packed_sequence=False + (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids( + model_inputs["position_ids"], is_packed_sequence=False ) model_inputs.update( cu_seq_lens_q=cu_seq_lens_q.to(self.device), @@ -689,12 +690,12 @@ def prepare_inputs_for_generation( max_length_k=max_length_k, ) - # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + # 8. Forward ALL kwargs that are uninitialized (e.g. `use_cache`). for key, value in kwargs.items(): if key not in model_inputs: model_inputs[key] = value - # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples) + # 9. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples) model_inputs.pop("labels", None) return model_inputs diff --git a/src/transformers/integrations/npu_flash_attention.py b/src/transformers/integrations/npu_flash_attention.py index ed1b30d9a6b0..716a3481a82a 100644 --- a/src/transformers/integrations/npu_flash_attention.py +++ b/src/transformers/integrations/npu_flash_attention.py @@ -10,20 +10,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import os import torch -import torch.nn.functional as F from ..utils.import_utils import is_torch_npu_available if is_torch_npu_available(): - import math - - import torch_npu - from einops import rearrange, repeat - from torch_npu import npu_rotary_mul + from torch_npu import npu_fusion_attention, npu_rotary_mul # FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default. @@ -52,117 +48,6 @@ def is_npu_fa2_top_left_aligned_causal_mask(): return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False -# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py -class IndexFirstAxis(torch.autograd.Function): - @staticmethod - def forward(ctx, input, indices): - ctx.save_for_backward(indices) - assert input.ndim >= 2 - ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = other_shape.numel() - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - # return input[indices] - return torch.gather( - rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim) - ).reshape(-1, *other_shape) - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - assert grad_output.ndim >= 2 - other_shape = grad_output.shape[1:] - grad_output = rearrange(grad_output, "b ... -> b (...)") - grad_input = torch.zeros( - [ctx.first_axis_dim, grad_output.shape[1]], - device=grad_output.device, - dtype=grad_output.dtype, - ) - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. 
- # grad_input[indices] = grad_output - grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) - return grad_input.reshape(ctx.first_axis_dim, *other_shape), None - - -index_first_axis = IndexFirstAxis.apply - - -# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py -class IndexPutFirstAxis(torch.autograd.Function): - @staticmethod - def forward(ctx, values, indices, first_axis_dim): - ctx.save_for_backward(indices) - assert indices.ndim == 1 - assert values.ndim >= 2 - output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. - output[indices] = values - # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) - return output - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - grad_values = grad_output[indices] - # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) - return grad_values, None, None - - -index_put_first_axis = IndexPutFirstAxis.apply - - -# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py -def pad_input(hidden_states, indices, batch, seqlen): - """ - Arguments: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. - batch: int, batch size for the padded sequence. - seqlen: int, maximum sequence length for the padded sequence. - Return: - hidden_states: (batch, seqlen, ...) - """ - # dim = hidden_states.shape[-1] - # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) - # output[indices] = hidden_states - output = index_put_first_axis(hidden_states, indices, batch * seqlen) - return rearrange(output, "(b s) ... -> b s ...", b=batch) - - -# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py -def unpad_input(hidden_states, attention_mask, unused_mask=None): - """ - Arguments: - hidden_states: (batch, seqlen, ...) - attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. - unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. - Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. - indices: (total_nnz), the indices of masked tokens from the flattened input sequence. - cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. - max_seqlen_in_batch: int - seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. - """ - all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask - seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) - used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the - # bool mask, then call nonzero to get the indices, then index with those. 
The indices is @dim - # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to - # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, - # so we write custom forward and backward to make it a bit faster. - return ( - index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), - indices, - cu_seqlens, - max_seqlen_in_batch, - used_seqlens_in_batch, - ) - - def npu_flash_attn_func( q, k, @@ -179,11 +64,11 @@ def npu_flash_attn_func( if not causal: head_num = q.shape[2] - output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0] + output = npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0] else: attn_mask_npu = get_attn_mask_npu(q.device) head_num = q.shape[2] - output = torch_npu.npu_fusion_attention( + output = npu_fusion_attention( q, k, v, @@ -218,7 +103,7 @@ def npu_flash_attn_varlen_func( if not causal: head_num = q.shape[1] - output = torch_npu.npu_fusion_attention( + output = npu_fusion_attention( q, k, v, @@ -234,7 +119,7 @@ def npu_flash_attn_varlen_func( else: attn_mask_npu = get_attn_mask_npu(q.device) head_num = q.shape[1] - output = torch_npu.npu_fusion_attention( + output = npu_fusion_attention( q, k, v, @@ -267,8 +152,3 @@ def npu_apply_rotary_emb(x, cos, sin, **kwargs): sin = sin.unsqueeze(0).unsqueeze(2) return npu_rotary_mul(x, cos, sin) - - -def get_npu_flash_attn_funcs(): - # return flash attention related functions used for Ascend NPU in order - return npu_flash_attn_func, npu_flash_attn_varlen_func, pad_input, unpad_input, False diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index e845e0cbc4a4..0d8906076829 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -14,17 +14,15 @@
 import inspect
 import os
 import warnings
+from functools import partial
 from typing import Optional, TypedDict
 
 import torch
 import torch.nn.functional as F
 
-from transformers.utils.import_utils import is_kernels_available
-
 from .utils import (
     is_flash_attn_2_available,
     is_flash_attn_3_available,
-    is_flash_attn_greater_or_equal,
     is_flash_attn_greater_or_equal_2_10,
     is_torch_npu_available,
     logging,
@@ -34,18 +32,135 @@
 logger = logging.get_logger(__name__)
 
 
-def _index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
-    reshaped = tensor.contiguous().reshape(-1, *tensor.shape[2:])
-    return reshaped[indices]
+# TODO Deprecate when all models have the attention interface
+def flash_attn_supports_top_left_mask():
+    if is_flash_attn_3_available():
+        return False
+    if is_flash_attn_2_available():
+        return not is_flash_attn_greater_or_equal_2_10()
+
+    from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
+
+    return is_npu_fa2_top_left_aligned_causal_mask()
+
+
+# TODO Deprecate when all models have the attention interface
+def is_flash_attn_available():
+    return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()
+
+
+# `globals()` is not compatible with dynamo, hence we have to define them in the global scope ourselves
+_flash_fn = None
+_flash_varlen_fn = None
+_pad_fn = None
+_unpad_fn = None
+
+# function that processes kwargs, generalized to handle any supported kwarg within the function
+_process_flash_kwargs_fn = None
+# exceptions where hf API doesn't match the original flash attention API
+_hf_api_to_flash_mapping = {
+    "dropout": "dropout_p",
+    "sliding_window": "window_size",
+}
+
+
+def _lazy_imports(implementation: Optional[str]):
+    """
+    Lazy loads the respective flash attention implementations.
+
+    Return:
+        flash_attn_func: The base flash attention function.
+        flash_attn_varlen_func: The flash attention function supporting variable sequence lengths,
+            e.g. for padding-free training.
+        pad_input: The function to pad inputs into one sequence and return the respective kwargs.
+        unpad_input: The function to unpad outputs based on the kwargs (from pad_input).
+    """
+    is_fa2 = is_flash_attn_2_available()
+    is_fa3 = is_flash_attn_3_available()
+    if implementation == "flash_attention_2" or (implementation is None and is_fa2 and not is_fa3):
+        from flash_attn import flash_attn_func, flash_attn_varlen_func
+        from flash_attn.bert_padding import pad_input, unpad_input
+    else:
+        pad_input, unpad_input = _pad_input, _unpad_input
+        if implementation == "flash_attention_3" or (implementation is None and is_fa3):
+            from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
+        elif is_torch_npu_available():
+            from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
+            from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
+        # Kernels fallback
+        else:
+            flash_attn_func = getattr(implementation, "flash_attn_func", None)
+            flash_attn_varlen_func = getattr(implementation, "flash_attn_varlen_func", None)
+            if flash_attn_varlen_func is None or flash_attn_func is None:
+                raise ValueError(
+                    f"Could not find the currently requested flash attention implementation at `{implementation}`. "
+                    f"Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn`."
+                )
+
+    return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input
+
+
+def _lazy_define_process_function(flash_function):
+    """
+    Depending on the version and kernel, some features are not supported. Due to limitations in
+    `torch.compile`, we opt to statically type which (optional) kwarg parameters are supported
+    within `_process_flash_attention_kwargs`.
+
+    NOTE: While all supported kwargs are marked as `True`, everything else is marked as `False`.
+    This might be confusing for kwargs that we use in any case, e.g. `is_causal`.
+    """
+    global _process_flash_kwargs_fn, _hf_api_to_flash_mapping
+
+    flash_parameters = inspect.signature(flash_function).parameters
+    process_parameters = inspect.signature(_process_flash_attention_kwargs).parameters
+
+    supports_mapping = {}
+    for param in process_parameters:
+        fa_param = _hf_api_to_flash_mapping.get(param, param)
+        supports_mapping[fa_param] = fa_param in flash_parameters
+
+    return partial(_process_flash_attention_kwargs, supports_mapping=supports_mapping)
+
+
+def lazy_import_flash_attention(implementation: Optional[str]):
+    """
+    Lazily loads flash attention and returns the respective functions + flags back.
+
+    NOTE: For fullgraph, this needs to be called before compile; no fullgraph can work
+    without preloading. See `_check_and_adjust_attn_implementation` in `modeling_utils`.
+    """
+    global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn
+    if any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
+        _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation)
+
+    global _process_flash_kwargs_fn
+    if _process_flash_kwargs_fn is None:
+        _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)
+    return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn
 
 
-def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
+
+def _index_first_axis(tensor, indices):
+    """
+    A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
+    after flattening the first two dimensions of the tensor. This is functionally equivalent to
+    FA2's `index_first_axis` and replaces the need to import it.
     """
-    FA3-compatible unpad_input function.
+    # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
+    # two dimensions to get (total_tokens, ...) before indexing.
+    reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
+    return reshaped_tensor[indices]
+
+
+def _unpad_input(hidden_states, attention_mask, unused_mask=None):
+    """
+    unpad_input function for flash attention variants that do not provide it in their package, e.g. fa3.
+
     Arguments:
         hidden_states: (batch, seqlen, ...)
         attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
         unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
+
     Return:
         hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
         indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
@@ -69,14 +184,16 @@ def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
     )
 
 
-def _fa3_pad_input(hidden_states, indices, batch, seqlen):
+def _pad_input(hidden_states, indices, batch, seqlen):
     """
-    FA3-compatible pad_input function.
+    pad_input function for flash attention variants that do not provide it in their package, e.g. fa3.
+ Arguments: hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. batch: int, batch size for the padded sequence. seqlen: int, maximum sequence length for the padded sequence. + Return: hidden_states: (batch, seqlen, ...) """ @@ -89,9 +206,11 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen): def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: """ Retrieves indexing data required to repad unpadded (ragged) tensors. + Arguments: attention_mask (`torch.Tensor`): Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + Return: indices (`torch.Tensor`): The indices of non-masked tokens from the flattened input sequence. @@ -125,6 +244,7 @@ def _upad_input( Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary tensors for query, key, value tensors. + Arguments: query_layer (`torch.Tensor`): Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). @@ -138,6 +258,7 @@ def _upad_input( Target length. unpad_input_func: The function to use for unpadding the input tensors. + Return: query_layer (`torch.Tensor`): Query state without padding. Shape: (total_target_length, num_heads, head_dim). @@ -193,13 +314,15 @@ def _upad_input( def prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence: bool = True): """ This function returns all the necessary kwargs to call `flash_attn_varlen_func` - extracted from position_ids.The `position_ids` can be either packed sequence or - the usual padded position ids, for example in inference time.. + extracted from position_ids. The `position_ids` can be either packed sequence or + the usual padded position ids, for example in inference time. + Arguments: position_ids (`torch.Tensor`): Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. is_packed_sequence (`bool`, *optional*, defaults to `True`): Whether the input position ids are a packed sequence or not. + Return: (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): The cumulative sequence lengths for the target (query) and source (key, value), used to index into @@ -212,19 +335,21 @@ def prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence: bool = # In that case the position ids will not always start with `0` and we need a better way to infer # cumulative seq lengths. 
if not is_packed_sequence: - tensor_kws = {"dtype": torch.int32, "device": position_ids.device} - last_position_ids = position_ids[:, -1] + tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device} + last_position_ids = position_ids[:, -1] + q_len = ( + torch.ones(position_ids.size(0), **tensor_kwargs) + if position_ids.shape[-1] == 1 + else last_position_ids.add(1) + ) + cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kwargs), q_len.cumsum(0).to(torch.int32)], 0) cu_seq_lens_k = torch.cat( - [torch.zeros(1, **tensor_kws), last_position_ids.cumsum(0).add(1).to(torch.int32)], 0 + [torch.zeros(1, **tensor_kwargs), last_position_ids.add(1).cumsum(0).to(torch.int32)], 0 ) - max_length_k = int(last_position_ids.max()) + 1 - q_len = ( - torch.ones(position_ids.size(0), **tensor_kws) if position_ids.shape[-1] == 1 else last_position_ids.add(1) - ) - cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0).to(torch.int32)], 0) max_length_q = int(q_len.max()) + max_length_k = int(last_position_ids.max()) + 1 else: position_ids = position_ids.flatten() indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) @@ -237,16 +362,18 @@ def prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence: bool = ) cu_seq_lens_k = cu_seq_lens_q + # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424 + # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing + # for some models (e.g. qwen2-vl). + max_length_q = cu_seq_lens_q.diff().max() # NOTE: With torch compile, this will cause a graph break if you don't set # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass. # This is a limitation of flash attention API, as the function `flash_attn_varlen_func` # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`. - # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424 - # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing - # for some models (e.g. qwen2-vl). - max_length_q = cu_seq_lens_q.diff().max().item() + max_length_q = max_length_q.item() max_length_k = max_length_q + return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) @@ -256,6 +383,7 @@ def _prepare_from_posids(query, key, value, position_ids, query_length): All three query, key, value states will be flattened. Cumulative lengths of each examples in the batch will be extracted from position_ids. NOTE: ideally cumulative lengths should be prepared at the data collator stage + Arguments: query (`torch.Tensor`): Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). @@ -267,6 +395,7 @@ def _prepare_from_posids(query, key, value, position_ids, query_length): Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. query_length (`int`): Sequence length of the input queries. + Return: query (`torch.Tensor`): Query state without padding. Shape: (total_target_length, num_heads, head_dim). @@ -275,121 +404,156 @@ def _prepare_from_posids(query, key, value, position_ids, query_length): value (`torch.Tensor`): Value state with padding. 
Shape: (total_source_length, num_key_value_heads, head_dim).
         (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
-            The cumulative sequence lengths for the target (query) and source (key, value), used to index into
-            ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
         (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
-            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
-            `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
     """
     kv_length = key.shape[1]
+    is_packed_sequence = query_length == kv_length
+
     query = query.contiguous().view(-1, query.size(-2), query.size(-1))
     key = key.contiguous().view(-1, key.size(-2), key.size(-1))
     value = value.contiguous().view(-1, value.size(-2), value.size(-1))
 
-    is_packed_sequence = query_length == kv_length
-
-    cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k = prepare_fa_kwargs_from_position_ids(
+    (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(
         position_ids, is_packed_sequence=is_packed_sequence
     )
+
     return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k))
 
 
 def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
     warnings.warn(
-        "prepare_fa2_from_position_ids is deprecated, use _prepare_from_posids",
+        "The function `_prepare_flash_attention_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_from_posids` instead.",
         FutureWarning,
     )
     return _prepare_from_posids(query, key, value, position_ids)
 
 
-def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None):
+def _is_packed_sequence(position_ids, batch_size):
+    """
+    Check whether the position ids indicate packed sequences or not:
+    1. Position ids exist
+    2. Flattened sequences only are supported
+    3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences
+    """
+    if position_ids is None:
+        return False
+
+    increasing_position_sequences = (
+        torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min()
+    )
+    return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool()
+
+
+def fa_peft_integration_check(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    target_dtype: Optional[torch.dtype] = None,
+):
+    """
+    PEFT usually casts the layer norms in float32 for training stability reasons,
+    therefore the input hidden states get silently cast to float32. Hence, we need to
+    cast them back to float16 / bfloat16 just to be sure everything works as expected.
+    This might slow down training & inference so it is recommended to not cast the LayerNorms!
+ """ if target_dtype and q.dtype == torch.float32: logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.") q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype) return q, k, v -def _lazy_imports(impl: Optional[str]): - # returns funcs and pad/unpad based on impl - is_fa2 = is_flash_attn_2_available() - is_fa3 = is_flash_attn_3_available() - if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3): - try: - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import pad_input, unpad_input - - return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False - - except ImportError as e: - if not globals().get("use_remote_fa2", None): - use_remote_fa2 = ( - input( - "Unable to import the official flash attention, do you want to try to use `kernels-community/flash-attn` (trust remote code) Yes or No? " - ) - .strip() - .lower() - ) - globals()["use_remote_fa2"] = use_remote_fa2 in {"yes", "y", "1"} - if globals()["use_remote_fa2"]: - if not is_kernels_available(): - raise ImportError("You need to install kernels: `pip install kernels`") - from kernels import get_kernel - - impl = get_kernel("kernels-community/flash-attn") - pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input - return ( - getattr(impl, "flash_attn_func", None), - getattr(impl, "flash_attn_varlen_func"), - pad_input, - unpad_input, - True, - ) - - else: - raise ImportError( - "Failed to import flash attention 2, please install it or use another implementation." - ) from e - elif is_torch_npu_available(): - # get flash attention related functions from `.integrations.npu_flash_attention` module for Ascend NPU - from .integrations.npu_flash_attention import get_npu_flash_attn_funcs - - return get_npu_flash_attn_funcs() - elif impl == "flash_attention_3" or (impl is None and is_fa3): - from flash_attn_interface import flash_attn_func, flash_attn_varlen_func - - pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input - return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, True - else: - pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input - return ( - getattr(impl, "flash_attn_func", None), - getattr(impl, "flash_attn_varlen_func"), - pad_input, - unpad_input, - True, - ) +class FlashAttentionKwargs(TypedDict, total=False): + """ + Keyword arguments for Flash Attention with Compile. + + Attributes: + cumulative_seqlens_q (`torch.LongTensor`, *optional*) + Gets cumulative sequence length for query state. + cumulative_seqlens_k (`torch.LongTensor`, *optional*) + Gets cumulative sequence length for key state. + max_length_q (`int`, *optional*): + Maximum sequence length for query state. + max_length_k (`int`, *optional*): + Maximum sequence length for key state. 
+    """
+
+    cumulative_seqlens_q: Optional[torch.LongTensor]
+    cumulative_seqlens_k: Optional[torch.LongTensor]
+    max_length_q: Optional[int]
+    max_length_k: Optional[int]
 
 
-_flash_supports_window = None
+def _process_flash_attention_kwargs(
+    query_length: int,
+    key_length: int,
+    is_causal: bool,
+    dropout: float = 0.0,
+    softmax_scale: Optional[float] = None,
+    sliding_window: Optional[int] = None,
+    use_top_left_mask: bool = False,
+    softcap: Optional[float] = None,
+    deterministic: Optional[bool] = None,
+    s_aux: Optional[torch.Tensor] = None,
+    supports_mapping: Optional[dict[str, bool]] = None,
+    **kwargs,
+):
+    """
+    Returns a set of kwargs that are passed down to the corresponding flash attention function based on
+    the requested features and whether they are supported - this depends on the version and kernel implementation
+    which is dynamically configured at `lazy_import_flash_attention`. The (un)supported features can be
+    inspected in `supports_mapping`, see `_lazy_define_process_function` for more details.
 
-def is_flash_attn_available():
-    return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()
+    Args:
+        query_length (`int`):
+            Length of the query states
+        key_length (`int`):
+            Length of the key states
+        is_causal (`bool`):
+            Whether we perform causal (decoder) attention or full attention.
+        dropout (`float`):
+            Attention dropout.
+        softmax_scale (`float`, *optional*):
+            The scaling of QK^T before applying softmax. Default to `1 / sqrt(head_dim)`.
+        sliding_window (`int`, *optional*):
+            The size of the sliding window, i.e. we look at a max of `sliding_window` tokens back.
+        use_top_left_mask (`bool`):
+            Deprecated behavior of older versions of flash attention requiring different masking.
+        softcap (`float`, *optional*):
+            Softcap for the attention logits, used e.g. in gemma2.
+        deterministic (`bool`, *optional*):
+            Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
+        s_aux (`torch.Tensor`, *optional*):
+            Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head.
+    Return:
+        flash_kwargs (`dict`):
+            A dict of kwargs that are requested and supported.
+ """ + flash_kwargs = { + "causal": is_causal and not (use_top_left_mask and query_length == 1), + "softmax_scale": softmax_scale, + } + if supports_mapping["dropout_p"]: + flash_kwargs["dropout_p"] = dropout -def flash_attn_supports_top_left_mask(): - if is_flash_attn_3_available(): - return False - if is_flash_attn_2_available(): - return not is_flash_attn_greater_or_equal_2_10() + if supports_mapping["window_size"] and sliding_window is not None and key_length > sliding_window: + flash_kwargs["window_size"] = (sliding_window, sliding_window) - from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask + if supports_mapping["deterministic"]: + flash_kwargs["deterministic"] = ( + deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + ) - return is_npu_fa2_top_left_aligned_causal_mask() + if supports_mapping["softcap"] and softcap is not None: + flash_kwargs["softcap"] = softcap + # Only within kernel implementation atm + if supports_mapping["s_aux"] and s_aux is not None: + flash_kwargs["s_aux"] = s_aux -class FlashAttentionKwargs(TypedDict, total=False): - cumulative_seqlens_q: Optional[torch.LongTensor] - cumulative_seqlens_k: Optional[torch.LongTensor] + return flash_kwargs def _flash_attention_forward( @@ -414,100 +578,121 @@ def _flash_attention_forward( implementation: Optional[str] = None, **kwargs, ): - if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn", "_is_fa3")): - flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation) - globals()["_flash_fn"] = flash_fn - globals()["_flash_varlen_fn"] = flash_varlen_fn - globals()["_pad_fn"] = pad_fn - globals()["_unpad_fn"] = unpad_fn - globals()["_is_fa3"] = is_fa3 - flash_supports_window = "window_size" in inspect.signature(flash_varlen_fn).parameters - globals()["_flash_supports_window"] = flash_supports_window - else: - flash_fn = globals()["_flash_fn"] - flash_varlen_fn = globals()["_flash_varlen_fn"] - pad_fn = globals()["_pad_fn"] - unpad_fn = globals()["_unpad_fn"] - is_fa3 = globals()["_is_fa3"] - flash_supports_window = globals()["_flash_supports_window"] - - causal = is_causal and not (use_top_left_mask and query_length == 1) - use_sw = ( - (_flash_supports_window or flash_supports_window) and sliding_window and key_states.shape[1] > sliding_window + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + (Optional) kwargs are described further in `_process_flash_attention_kwargs` and `FlashAttentionKwargs`. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`, *optional*): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + implementation (`str`, *optional*): + The attention implementation to use. If None, will default to the one based on the environment. 
+    """
+    (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_attention(
+        implementation
     )
-    flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {}
-    if not is_fa3:
-        flash_kwargs["dropout_p"] = dropout
-        if is_flash_attn_greater_or_equal("2.4.1"):
-            det = deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
-            flash_kwargs["deterministic"] = det
-    if softcap is not None:
-        flash_kwargs["softcap"] = softcap
-    if "s_aux" in kwargs:
-        flash_kwargs["s_aux"] = kwargs.get("s_aux")
+
+    # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
     query_states, key_states, value_states = fa_peft_integration_check(
         query_states, key_states, value_states, target_dtype
     )
-    use_mask = position_ids is not None or all(
-        k is not None for k in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]
+
+    # Extract the flash attention kwargs that have been requested (and are supported by the implementation)
+    flash_kwargs = process_flash_kwargs_fn(
+        query_length=query_length,
+        key_length=key_states.size(1),
+        is_causal=is_causal,
+        dropout=dropout,
+        softmax_scale=softmax_scale,
+        sliding_window=sliding_window,
+        use_top_left_mask=use_top_left_mask,
+        softcap=softcap,
+        deterministic=deterministic,
+        **kwargs,
+    )
+
+    # We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
+    # Case 1. If position ids are provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
+    # Case 2. Some models directly pass pre-computed `cu_seqlens` so we don't need to infer them from position ids. It is safe to
+    #         use `flash_varlen_fn` knowing we already have all the necessary kwargs.
+    #
+    # NOTE: it is the user's responsibility to take care of flattening `position_ids` if that's needed by the model.
+    # See #39121 for more information.
+    is_fa_with_position_ids = _is_packed_sequence(position_ids, batch_size=query_states.size(0))
+    is_fa_with_varlen_kwargs = all(
+        kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
     )
+
+    # Contains at least one padding token in the sequence
     if attention_mask is not None:
-        q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input(
+        q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input(
             query_states, key_states, value_states, attention_mask, query_length, unpad_fn
         )
-        # TODO for now this is required to work with https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.p
+
+        # TODO for now this is required to work with
+        # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
         if "mps" in str(q.device):
-            cu_k = cu_k.clone()
+            cu_seq_lens_k = cu_seq_lens_k.clone()
+
         out_unpad = flash_varlen_fn(
             q,
             k,
             v,
-            cu_seqlens_q=cu_q.to(torch.int32),
-            cu_seqlens_k=cu_k.to(torch.int32),
-            max_seqlen_q=mq,
-            max_seqlen_k=mk,
-            softmax_scale=softmax_scale,
-            causal=causal,
+            cu_seqlens_q=cu_seq_lens_q,
+            cu_seqlens_k=cu_seq_lens_k,
+            max_seqlen_q=max_length_q,
+            max_seqlen_k=max_length_k,
             **flash_kwargs,
         )
         if isinstance(out_unpad, tuple):
             out_unpad = out_unpad[0]
-        out = pad_fn(out_unpad, idx, query_states.shape[0], query_length)
-    elif use_mask:
+
+        out = pad_fn(out_unpad, indices_q, query_states.size(0), query_length)
+
+    # Padding free, i.e. 
sequences flattened into one total sequence + elif is_fa_with_varlen_kwargs or is_fa_with_position_ids: if cu_seq_lens_q is None or cu_seq_lens_k is None: - if position_ids is None: - raise ValueError( - "Position ids should be passed if the attention mask is not passed and the cu_seq-lens are not passed." - ) - q, k, v, (cu_q, cu_k), (mq, mk) = _prepare_from_posids( + q, k, v, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _prepare_from_posids( query_states, key_states, value_states, position_ids, query_length=query_length ) else: q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) - mq, mk = max_length_q, max_length_k - cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k + + # TODO for now this is required to work with + # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py if "mps" in str(q.device): - cu_k = cu_k.clone() + cu_seq_lens_k = cu_seq_lens_k.clone() + out = flash_varlen_fn( q, k, v, - cu_seqlens_q=cu_q.to(torch.int32), - cu_seqlens_k=cu_k.to(torch.int32), - max_seqlen_q=mq, - max_seqlen_k=mk, - softmax_scale=softmax_scale, - causal=causal, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, **flash_kwargs, ) if isinstance(out, tuple): out = out[0] - out = out.view(query_states.shape[0], -1, out.size(-2), out.size(-1)) + + out = out.view(query_states.size(0), -1, out.size(-2), out.size(-1)) + + # No padding else: - out = flash_fn( - query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs - ) + out = flash_fn(query_states, key_states, value_states, **flash_kwargs) + if isinstance(out, tuple): + out = out[0] - return out[0] if isinstance(out, tuple) else out + return out diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b15183b4821e..b8a7d6a44024 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -74,6 +74,7 @@ ) from .loss.loss_utils import LOSS_MAPPING from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS +from .modeling_flash_attention_utils import lazy_import_flash_attention from .pytorch_utils import ( # noqa: F401 Conv1D, apply_chunking_to_forward, @@ -2126,7 +2127,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH _pp_plan = None # This flag signal that the model can be used as an efficient backend in TGI and vLLM - # In practice, it means that they support attention interface functions, fully pass the kwargs + # In practice, it means that they support attention (mask) interface functions, fully pass the kwargs # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan _supports_attention_backend = False _can_record_outputs = None @@ -2748,6 +2749,7 @@ def _check_and_adjust_attn_implementation( if attention_wrapper is None: attention_wrapper = flash_attention_forward kernel_function = partial(attention_wrapper, implementation=kernel) + lazy_import_flash_attention(kernel) elif kernel_name is not None: kernel_function = getattr(kernel, kernel_name) ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function) @@ -2763,7 +2765,13 @@ def _check_and_adjust_attn_implementation( attn_implementation = "sdpa" # Try to fallback to sdpa in this case return attn_implementation 
else: - return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check) + attn_implementation = self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check) + + # preload flash attention here to allow compile with fullgraph + if applicable_attn_implementation.startswith("flash_attention"): + lazy_import_flash_attention(applicable_attn_implementation) + + return attn_implementation def get_correct_attn_implementation(self, _requested_attention: str, is_init_check: bool = False) -> str: requested_attention = "sdpa" if _requested_attention is None else _requested_attention diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 750c5c22324d..b7ca0e2d9b42 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3483,92 +3483,107 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid for model_class in self.all_model_classes: if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") + # Custom kernel which needs the mask interface to be properly usable on these models + if not model_class._supports_attention_backend and not attn_implementation.startswith("flash_attention"): + self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.head_dim = 64 # fa2 does not always support arbitrary headim - model = model_class(config) - - model.to(torch_device) - model.to(torch.bfloat16) - dummy_input = inputs_dict[model.main_input_name][:1] - if dummy_input.dtype in [torch.float32, torch.float16]: - dummy_input = dummy_input.to(torch.bfloat16) - - dummy_attention_mask = inputs_dict.get("attention_mask", None) - if dummy_attention_mask is not None: - dummy_attention_mask = dummy_attention_mask[:1] - if padding_side == "left": - dummy_attention_mask[:, 1:] = 1 - dummy_attention_mask[:, :1] = 0 - else: - dummy_attention_mask[:, :-1] = 1 - dummy_attention_mask[:, -1:] = 0 - if model.config.is_encoder_decoder: - decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] + # flash attention variants does not always support arbitrary headim + config = self._prepare_config_headdim(config, 16) - outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) - model.set_attn_implementation(attn_implementation) - outputs_fa = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) - else: - outputs = model(dummy_input, output_hidden_states=True) - model.set_attn_implementation(attn_implementation) - outputs_fa = model(dummy_input, output_hidden_states=True) + # TODO it is unclear why saving and reloading with dtype works while + # casting with `.to(dtype=..., device=...)` does not. + # Discovered on tests with `Bart` models. 
+ model = model_class(config) + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) - model.set_attn_implementation("sdpa") - logits = ( - outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1] - ) - logits_fa = ( - outputs_fa.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs_fa.decoder_hidden_states[-1] - ) + dummy_input = inputs_dict[model.main_input_name][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) - assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + dummy_attention_mask = inputs_dict.get("attention_mask", None) - if model.config.is_encoder_decoder: - other_inputs = { - "decoder_input_ids": decoder_input_ids, - "decoder_attention_mask": dummy_attention_mask, - "output_hidden_states": True, - } if dummy_attention_mask is not None: - other_inputs["attention_mask"] = dummy_attention_mask + dummy_attention_mask = dummy_attention_mask[:1] + if padding_side == "left": + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + else: + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + if model.config.is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] - outputs = model(dummy_input, **other_inputs) - model.set_attn_implementation(attn_implementation) - outputs_fa = model(dummy_input, **other_inputs) - else: - other_inputs = { - "output_hidden_states": True, - } - if dummy_attention_mask is not None: - other_inputs["attention_mask"] = dummy_attention_mask + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + else: + outputs = model(dummy_input, output_hidden_states=True) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, output_hidden_states=True) + + model.set_attn_implementation("sdpa") + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) - outputs = model(dummy_input, **other_inputs) - model.set_attn_implementation(attn_implementation) - outputs_fa = model(dummy_input, **other_inputs) + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) - model.set_attn_implementation("sdpa") - logits = ( - outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1] - ) - logits_fa = ( - outputs_fa.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs_fa.decoder_hidden_states[-1] - ) + if model.config.is_encoder_decoder: + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask - if padding_side == "left": - assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + outputs = model(dummy_input, **other_inputs) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, **other_inputs) + else: + other_inputs = { + 
"output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, **other_inputs) + + model.set_attn_implementation("sdpa") + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) - # check with inference + dropout - model.train() - model.set_attn_implementation(attn_implementation) - _ = model(dummy_input, **other_inputs) - else: - assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + if padding_side == "left": + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + + # check with inference + dropout + model.train() + model.set_attn_implementation(attn_implementation) + _ = model(dummy_input, **other_inputs) + else: + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) @require_kernels @require_torch_gpu @@ -4698,6 +4713,70 @@ def recursively_check(eager_outputs, exported_outputs): is_tested = recursively_check(eager_outputs, exported_outputs) self.assertTrue(is_tested, msg=f"No outputs were compared for {model_class.__name__}") + @staticmethod + def _prepare_config_headdim(config, requested_dim): + """ + This method allows to update the head dim for all model types including + composite models and models that do not support head dim by themselves. + + Why? A lot of kernels including flex attention rely on triton for compilation. + However, triton cannot handle hidden dimensions of less than 16 for example. 
+ (There are many more examples especially now that the `kernels` library is + supported) + """ + + def update_config_headdim(config, requested_dim): + # Flex Attention cannot use dropout + if hasattr(config, "attention_dropout"): + config.attention_dropout = 0 + if hasattr(config, "attention_probs_dropout_prob"): + config.attention_probs_dropout_prob = 0 + + # Update the head dim and try to update hidden size as well if present in config + # NOTE: some models may have none if the values in sub-config, thus we check for `Noneness` + head_dim = None + if hasattr(config, "head_dim") and config.head_dim is not None: + head_dim = config.head_dim + config.head_dim = max(requested_dim, config.head_dim) + + cross_head_dim = None + if hasattr(config, "cross_head_dim") and config.cross_head_dim is not None: + cross_head_dim = config.cross_head_dim + config.cross_head_dim = max(requested_dim, config.cross_head_dim) + + if ( + getattr(config, "hidden_size", None) is not None + and getattr(config, "num_attention_heads", None) is not None + ): + head_dim = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads + config.hidden_size *= max(requested_dim // head_dim, 1) + + if ( + getattr(config, "decoder_hidden_size", None) is not None + and getattr(config, "decoder_num_attention_heads", None) is not None + ): + decoder_head_dim = config.decoder_hidden_size // config.decoder_num_attention_heads + config.decoder_hidden_size *= max(requested_dim // decoder_head_dim, 1) + + if ( + getattr(config, "cross_hidden_size", None) is not None + and getattr(config, "cross_num_attention_heads", None) is not None + ): + cross_head_dim = ( + cross_head_dim + if cross_head_dim is not None + else config.cross_hidden_size // config.cross_num_attention_heads + ) + config.cross_hidden_size *= max(requested_dim // cross_head_dim, 1) + + # Update config values + update_config_headdim(config, requested_dim) + for key in config.sub_configs: + sub_config = getattr(config, key) + update_config_headdim(sub_config, requested_dim) + + return config + @require_torch_gpu def test_flex_attention_with_grads(self): for model_class in self.all_model_classes: @@ -4711,59 +4790,8 @@ def test_flex_attention_with_grads(self): ): self.skipTest(reason="At least some parts of this model do not support flex attention") - def update_config_for_flex(config): - # Flex Attention cannot use dropout - if hasattr(config, "attention_dropout"): - config.attention_dropout = 0 - if hasattr(config, "attention_probs_dropout_prob"): - config.attention_probs_dropout_prob = 0 - - # Flex attention relies on triton on compilation - # However, triton cannot handle hidden dimensions of less than 16 - # --> forcing at least a hidden dim of 16 - - # Update the head dim and try to update hidden size as well if present in config - # NOTE: some models may have none if the values in sub-config, thus we check for `Noneness` - head_dim = None - if hasattr(config, "head_dim") and config.head_dim is not None: - head_dim = config.head_dim - config.head_dim = max(16, config.head_dim) - - cross_head_dim = None - if hasattr(config, "cross_head_dim") and config.cross_head_dim is not None: - cross_head_dim = config.cross_head_dim - config.cross_head_dim = max(16, config.cross_head_dim) - - if ( - getattr(config, "hidden_size", None) is not None - and getattr(config, "num_attention_heads", None) is not None - ): - head_dim = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads - config.hidden_size *= max(16 // 
head_dim, 1) - - if ( - getattr(config, "decoder_hidden_size", None) is not None - and getattr(config, "decoder_num_attention_heads", None) is not None - ): - decoder_head_dim = config.decoder_hidden_size // config.decoder_num_attention_heads - config.decoder_hidden_size *= max(16 // decoder_head_dim, 1) - - if ( - getattr(config, "cross_hidden_size", None) is not None - and getattr(config, "cross_num_attention_heads", None) is not None - ): - cross_head_dim = ( - cross_head_dim - if cross_head_dim is not None - else config.cross_hidden_size // config.cross_num_attention_heads - ) - config.cross_hidden_size *= max(16 // cross_head_dim, 1) - # Set default attention to flex and update config values - update_config_for_flex(config) - for key in config.sub_configs: - sub_config = getattr(config, key) - update_config_for_flex(sub_config) + config = self._prepare_config_headdim(config, 16) # specific to triton if model_class._can_set_attn_implementation(): model = model_class(config).to(device=torch_device)