diff --git a/mindnlp/core/nn/modules/conv.py b/mindnlp/core/nn/modules/conv.py index 5906acb31..d4898bfbf 100644 --- a/mindnlp/core/nn/modules/conv.py +++ b/mindnlp/core/nn/modules/conv.py @@ -3,6 +3,8 @@ import math from typing import Optional, Tuple, Union, List from mindspore import Tensor, ops as mops +from mindspore.ops.auto_generate.gen_ops_prim import conv1d_ext_op, conv1d_padding_op +from mindspore.ops.function.nn_func import pad_ext from ..parameter import Parameter from .module import Module from ..common_types import _size_2_t, _size_1_t @@ -182,43 +184,18 @@ def __init__( super().__init__( in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, False, _single(0), groups, bias, padding_mode, **factory_kwargs) - - pad_mode = 'valid' - pad = padding - if isinstance(padding, tuple): - if padding[0] != 0: - pad_mode = 'pad' - pad = (0, 0, padding[0], padding[0]) - elif isinstance(padding, int): - if padding != 0: - pad_mode = 'pad' - pad = (0, 0) + (padding,) * 2 - if not isinstance(padding, (int, tuple)): - pad_mode = padding - pad = (0,) * 4 - - if self.padding_mode != 'zeros': - pad_mode = 'valid' - pad = (0,) * 4 - self.conv2d = mops.Conv2D(out_channel=self.out_channels, - kernel_size=(1,) + self.kernel_size, - mode=1, - pad_mode=pad_mode, - pad=pad, - stride=(1,) + self.stride, - dilation=(1,) + self.dilation, - group=self.groups) + + if isinstance(padding, str) and padding_mode == "zeros": + self.conv1d = conv1d_padding_op + else: + self.conv1d = conv1d_ext_op def forward(self, input): - if self.padding_mode != 'zeros': - input = F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode) - input = input.expand_dims(2) - output = self.conv2d(input, self.weight.expand_dims(2)) - - if self.bias is not None: - output = mops.bias_add(output, self.bias) - - output = output.squeeze(2) + if self.padding_mode != "zeros": + output = self.conv1d(pad_ext(input, self._reversed_padding, mode=self.padding_mode), self.weight, + self.bias, self.stride, (0,), self.dilation, self.groups) + else: + output = self.conv1d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return output diff --git a/mindnlp/core/ops/array.py b/mindnlp/core/ops/array.py index 1cd14c2c6..1eb318aee 100644 --- a/mindnlp/core/ops/array.py +++ b/mindnlp/core/ops/array.py @@ -130,7 +130,7 @@ def narrow(input, dim, start, length): has_nonzero = hasattr(mindspore.mint, 'nonzero') def nonzero(input, *, as_tuple=False): if use_pyboost() and has_nonzero: - return mindspore.mint.nonzero(input, as_tuple) + return mindspore.mint.nonzero(input, as_tuple=as_tuple) _nonzero = _get_cache_prim(ops.NonZero)() out = _nonzero(input) if as_tuple: diff --git a/mindnlp/core/ops/other.py b/mindnlp/core/ops/other.py index 5bcfcd21c..3e2f955d6 100644 --- a/mindnlp/core/ops/other.py +++ b/mindnlp/core/ops/other.py @@ -556,6 +556,13 @@ def einsum(equation, *operands): return result +# expand_dims +has_expand_dims = hasattr(mindspore.mint, 'expand_dims') +def expand_dims(input, axis): + if use_pyboost() and has_expand_dims: + return mindspore.mint.expand_dims(input, axis) + return ops.expand_dims(input, axis) + # flatten has_flatten = hasattr(mindspore.mint, 'flatten') diff --git a/mindnlp/transformers/configuration_utils.py b/mindnlp/transformers/configuration_utils.py index 12a3d72f8..b8ede61e8 100644 --- a/mindnlp/transformers/configuration_utils.py +++ b/mindnlp/transformers/configuration_utils.py @@ -342,6 +342,7 @@ def __init__(self, **kwargs): # Attention implementation to 
use, if relevant. self._attn_implementation_internal = kwargs.pop("attn_implementation", None) + self._attn_implementation_autoset = False # Drop the transformers version info self.transformers_version = kwargs.pop("transformers_version", None) diff --git a/mindnlp/transformers/integrations/npu_flash_attention.py b/mindnlp/transformers/integrations/npu_flash_attention.py new file mode 100644 index 000000000..82b27a04c --- /dev/null +++ b/mindnlp/transformers/integrations/npu_flash_attention.py @@ -0,0 +1,242 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default. +Set environment variable `NPU_FA2_SPARSE_MODE` to 2 when using top-left aligned causal mask. +""" + + +import os + +import math +from typing import Optional, Tuple +import mindspore +from mindspore.ops import flash_attention_score +from mindspore import nn +from mindnlp.core import ops + + +# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default. +# Set environment variable `NPU_FA2_SPARSE_MODE` to 2 when using top-left aligned causal mask. +TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2 +DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE = 3 + +SPARSE_MODE = int(os.getenv("NPU_FA2_SPARSE_MODE", default=str(DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE))) +if SPARSE_MODE not in [TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE]: + raise ValueError( + "Environment variable `NPU_FA2_SPARSE_MODE` can only be set as 2 (top-left aligned causal mask) " + "or 3 (down-right aligned causal mask)." 
+ ) + + +def is_npu_fa2_top_left_aligned_causal_mask(): + return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE + + +class IndexFirstAxis(nn.Cell): + def __init__(self): + super(IndexFirstAxis, self).__init__() + + def construct(self, input: mindspore.Tensor, indices: mindspore.Tensor): + assert input.ndim >= 2 + first_axis_dim, other_shape = input.shape[0], input.shape[1:] + input_flat = input.reshape(first_axis_dim, -1) + indices_expanded = ops.expand_dims(indices, -1) + indices_expanded = ops.broadcast_to(indices_expanded, (-1, input_flat.shape[1])) + output_flat = ops.gather(input_flat, 0, indices_expanded) + output = output_flat.reshape(-1, *other_shape) + return output + + def bprop(self, input, indices, out, dout): + assert dout.ndim >= 2 + other_shape = dout.shape[1:] + grad_output = dout + + grad_flat = grad_output.reshape(grad_output.shape[0], -1) + grad_shape = (input.shape[0], grad_flat.shape[1]) + grad_input = ops.zeros(grad_shape, grad_flat.dtype) + + indices_expanded = ops.expand_dims(indices, -1) + indices_expanded = ops.broadcast_to(indices_expanded, (-1, grad_flat.shape[1])) + grad_input.scatter_(0, indices_expanded, grad_flat) + + return grad_input.reshape(input.shape[0], *other_shape), None + + +index_first_axis = IndexFirstAxis() + + +class IndexPutFirstAxis(nn.Cell): + def __init__(self): + super(IndexPutFirstAxis, self).__init__() + + def construct(self, values: mindspore.Tensor, indices: mindspore.Tensor, first_axis_dim: int): + assert indices.ndim == 1 + assert values.ndim >= 2 + output = ops.zeros( + (first_axis_dim, *values.shape[1:]), + values.dtype + ) + output[indices] = values + return output + + def bprop(self, values, indices, first_axis_dim, out, dout): + grad_values = dout[indices] + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis() + + +def pad_input( + hidden_states: mindspore.Tensor, + indices: mindspore.Tensor, + batch: int, + seqlen: int +): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + Return: + hidden_states: (batch, seqlen, ...) + """ + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return output.reshape(batch, seqlen, *hidden_states.shape[1:]) + + +def unpad_input( + hidden_states: mindspore.Tensor, + attention_mask: mindspore.Tensor, + unused_mask: Optional[mindspore.Tensor] = None, +): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. 
+ """ + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=mindspore.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=mindspore.int32) + indices = ops.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = ops.pad(ops.cumsum(seqlens_in_batch, dim=0, dtype=mindspore.int32), (1, 0)) + + hidden_states_flat = hidden_states.reshape(-1, *hidden_states.shape[2:]) + hidden_states = index_first_axis(hidden_states_flat, indices) + return ( + hidden_states, + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def create_attn_mask(causal: bool, sparse_mode: int) -> Tuple[int, mindspore.Tensor]: + """ + Create a causal mask for the attention scores. + + Args: + causal (`bool`): + If `True`, the mask will be causal. + sparse_mode (`bool`): + If `True`, the mask will be top-left + aligned, otherwise it will be bottom-right aligned. + Returns: + `Tuple[bool, mindspore.Tensor]`: + A tuple containing sparse_mode and the mask tensor. + """ + if not causal: + sparse_mode = 0 + attn_mask = None + else: + if sparse_mode == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE: + attn_mask = ops.tril(ops.ones((2048, 2048)), diagonal=-1).bool() + else: + attn_mask = ops.triu(ops.ones((2048, 2048)), diagonal=1).bool() + return sparse_mode, attn_mask + + +def npu_flash_attn_func( + q: mindspore.Tensor, + k: mindspore.Tensor, + v: mindspore.Tensor, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + **kwargs, +): + head_num = q.shape[2] + sparse_mode, attn_mask = create_attn_mask(causal, SPARSE_MODE) + if softmax_scale is None: + softmax_scale = 1.0 / math.sqrt(q.shape[-1]) + output = flash_attention_score( + q, + k, + v, + head_num, + keep_prob=1.0 - dropout_p, + scalar_value=softmax_scale, + attn_mask=attn_mask, + input_layout="BSND", + sparse_mode=sparse_mode, + prefix=None, + ) + + return output + + +def npu_flash_attn_varlen_func( + q: mindspore.Tensor, + k: mindspore.Tensor, + v: mindspore.Tensor, + cu_seqlens_q: Optional[mindspore.Tensor] = None, + cu_seqlens_k: Optional[mindspore.Tensor] = None, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + **kwargs, +): + head_num = q.shape[1] + sparse_mode, attn_mask = create_attn_mask(causal, SPARSE_MODE) + if softmax_scale is None: + softmax_scale = 1.0 / math.sqrt(q.shape[-1]) + + output = flash_attention_score( + q, + k, + v, + head_num, + keep_prob=1.0 - dropout_p, + scalar_value=softmax_scale, + attn_mask=attn_mask, + input_layout="TND", + actual_seq_qlen=cu_seqlens_q[1:].asnumpy().tolist(), + actual_seq_kvlen=cu_seqlens_k[1:].asnumpy().tolist(), + sparse_mode=sparse_mode, + prefix=None, + ) + + return output diff --git a/mindnlp/transformers/modeling_flash_attention_utils.py b/mindnlp/transformers/modeling_flash_attention_utils.py new file mode 100644 index 000000000..6ed1c417f --- /dev/null +++ b/mindnlp/transformers/modeling_flash_attention_utils.py @@ -0,0 +1,372 @@ +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module provides utilities for flash attention in Transformers models.""" + + +import os +import inspect +from typing import Optional, Tuple +import mindspore +from mindnlp.core import ops +from ..utils import logging +from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input +from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func +from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func + + +logger = logging.get_logger(__name__) + + +if flash_attn_func is not None: + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + + +def flash_attn_supports_top_left_mask(): + # down-right mask is used on Ascend NPU by default, set env `NPU_FA2_SPARSE_MODE=2` to activate top-left mask. + from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask + return is_npu_fa2_top_left_aligned_causal_mask() + + +def _get_unpad_data(attention_mask: mindspore.Tensor) -> Tuple[mindspore.Tensor, mindspore.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`mindspore.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`mindspore.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`mindspore.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=mindspore.int32) + indices = ops.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = ops.pad(ops.cumsum(seqlens_in_batch, dim=0, dtype=mindspore.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _upad_input( + query_layer: mindspore.Tensor, + key_layer: mindspore.Tensor, + value_layer: mindspore.Tensor, + attention_mask: mindspore.Tensor, + query_length: int, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + + This function is used instead of `unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + + Arguments: + query_layer (`mindspore.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`mindspore.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`mindspore.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`mindspore.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + + Return: + query_layer (`mindspore.Tensor`): + Query state without padding. 
Shape: (total_target_length, num_heads, head_dim). + key_layer (`mindspore.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`mindspore.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`mindspore.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + + if query_length == kv_seq_len: + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = ops.arange(batch_size + 1, dtype=mindspore.int32) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def prepare_fa2_from_position_ids( + query: mindspore.Tensor, + key: mindspore.Tensor, + value: mindspore.Tensor, + position_ids: mindspore.Tensor, +): + """ + This function returns necessary arguments to call `flash_attn_varlen_func`. + All three query, key, value states will be flattened. + Cumulative lengths of each examples in the batch will be extracted from position_ids. + + NOTE: ideally cumulative lengths should be prepared at the data collator stage + + Arguments: + query (`mindspore.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key (`mindspore.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value (`mindspore.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + position_ids (`mindspore.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + query (`mindspore.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key (`mindspore.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value (`mindspore.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`mindspore.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. 
+ (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + query = query.view(-1, query.shape[-2], query.shape[-1]) + key = key.contiguous().view(-1, key.shape[-2], key.shape[-1]) + value = value.contiguous().view(-1, value.shape[-2], value.shape[-1]) + position_ids = position_ids.flatten() + indices_q = ops.arange(position_ids.shape[0], dtype=mindspore.int32) + + cu_seq_lens = ops.cat( + (indices_q[position_ids == 0], + mindspore.tensor(position_ids.shape, dtype=mindspore.int32) + ) + ) + max_length = position_ids.max() + 1 + + return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) + + +def fa_peft_integration_check( + query: mindspore.Tensor, + key: mindspore.Tensor, + value: mindspore.Tensor, + target_dtype: Optional[mindspore.dtype.TensorType] = None, +): + """ + PEFT usually casts the layer norms in float32 for training stability reasons + therefore the input hidden states gets silently casted in float32. Hence, we need + cast them back in float16 / bfloat16 just to be sure everything works as expected. + This might slowdown training & inference so it is recommended to not cast the LayerNorms! + + Args: + query (`mindspore.Tensor`): + Input query states to be passed to Flash Attention API + key (`mindspore.Tensor`): + Input key states to be passed to Flash Attention API + value (`mindspore.Tensor`): + Input value states to be passed to Flash Attention API + target_dtype (`mindspore.dtype`, *optional*): + The dtype to convert the attention tensors to. Conversion can be ignored by + not providing the target dtype. + """ + if target_dtype is None: + return query, key, value + + input_dtype = query.dtype + if input_dtype == mindspore.float32: + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + return query, key, value + + +flash_241 = False +deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + + +def _flash_attention_forward( + query_states: mindspore.Tensor, + key_states: mindspore.Tensor, + value_states: mindspore.Tensor, + attention_mask: mindspore.Tensor, + query_length: int, + is_causal: bool, + dropout: float = 0.0, + position_ids: Optional[mindspore.Tensor] = None, + softmax_scale: Optional[float] = None, + sliding_window: Optional[int] = None, + use_top_left_mask: bool = False, + softcap: Optional[float] = None, + deterministic: Optional[bool] = None, + cu_seq_lens_q: Optional[mindspore.Tensor] = None, + cu_seq_lens_k: Optional[mindspore.Tensor] = None, + max_length_q: Optional[int] = None, + max_length_k: Optional[int] = None, + target_dtype: Optional[mindspore.dtype.TensorType] = None, + **kwargs, +): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. 
+ + Args: + query_states (`mindspore.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`mindspore.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`mindspore.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`mindspore.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_top_left_mask (`bool`, defaults to `False`): + flash_attn<2.1 generates top-left aligned causal mask, + while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. + This attribute is used to handle this difference. + """ + if not use_top_left_mask: + causal = is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. + causal = is_causal and query_length != 1 + + # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). + use_sliding_windows = ( + _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window + ) + flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} + + if flash_241: + if deterministic is None: + deterministic = deterministic_g + flash_kwargs["deterministic"] = deterministic + + if softcap is not None: + flash_kwargs["softcap"] = softcap + + # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op + query_states, key_states, value_states = fa_peft_integration_check( + query_states, key_states, value_states, target_dtype + ) + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + + # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing + # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. 
+ # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach + elif position_ids is not None and ( + max_length_q is not None or (query_length != 1 and not (position_ids.diff(dim=-1) >= 0).all()) + ): + batch_size = query_states.size(0) + if cu_seq_lens_q is None or cu_seq_lens_k is None: + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = ( + prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids) + ) + + cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens + max_length_q, max_length_k = max_seq_lens + + else: + query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) + key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) + value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) + + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs + ) + + return attn_output diff --git a/mindnlp/transformers/modeling_utils.py b/mindnlp/transformers/modeling_utils.py index 414404aaf..7138cf0af 100644 --- a/mindnlp/transformers/modeling_utils.py +++ b/mindnlp/transformers/modeling_utils.py @@ -1349,7 +1349,20 @@ def _autoset_attn_implementation( # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. requested_attn_implementation = config._attn_implementation_internal - config._attn_implementation = "eager" + if use_flash_attention_2: + logger.warning_once( + 'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.' + ) + config._attn_implementation = "flash_attention_2" + + if config._attn_implementation == "flash_attention_2": + config._attn_implementation = "flash_attention_2" + elif isinstance(requested_attn_implementation, dict): + config._attn_implementation = None + else: + config._attn_implementation = "eager" + + config._attn_implementation_autoset = True return config diff --git a/mindnlp/transformers/models/whisper/modeling_whisper.py b/mindnlp/transformers/models/whisper/modeling_whisper.py index 1900818cd..c8f1698ec 100644 --- a/mindnlp/transformers/models/whisper/modeling_whisper.py +++ b/mindnlp/transformers/models/whisper/modeling_whisper.py @@ -41,6 +41,8 @@ from .configuration_whisper import WhisperConfig from .generation_whisper import WhisperGenerationMixin +from ...modeling_flash_attention_utils import _flash_attention_forward + logger = logging.get_logger(__name__) @@ -388,8 +390,118 @@ def forward( return attn_output, attn_weights, past_key_value +class WhisperFlashAttention2(WhisperAttention): + """ + Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays + untouched. 
The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._flash_attn_uses_top_left_mask = False + + def forward( + self, + hidden_states: mindspore.Tensor, + key_value_states: Optional[mindspore.Tensor] = None, + past_key_value: Optional[EncoderDecoderCache] = None, + attention_mask: Optional[mindspore.Tensor] = None, + layer_head_mask: Optional[mindspore.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[mindspore.Tensor] = None, + ) -> Tuple[mindspore.Tensor, Optional[mindspore.Tensor], Optional[Tuple[mindspore.Tensor]]]: + logger.warning_once( + "The `flash_attention_2` implementation is in beta and is subject to change. Please use with caution." + ) + if isinstance(past_key_value, StaticCache): + raise ValueError( + "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. " + ) + + if output_attentions: + raise ValueError("WhisperFlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.shape + + # get query proj + query_states = ops.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim)) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = self._shape(self.k_proj(current_states), -1, bsz) + value_states = self._shape(self.v_proj(current_states), -1, bsz) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + key_states = key_states.swapaxes(1, 2) + value_states = value_states.swapaxes(1, 2) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, : key_states.shape[-2]] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == mindspore.float32: + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + query_states = query_states.astype(target_dtype) + key_states = key_states.astype(target_dtype) + value_states = value_states.astype(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + causal_mask, + tgt_len, + dropout=self.dropout if self.training else 0.0, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1) + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + WHISPER_ATTENTION_CLASSES = { "eager": WhisperAttention, + "flash_attention_2": WhisperFlashAttention2, } @@ -794,6 +906,8 @@ def __init__(self, config: WhisperConfig): self.layers = nn.ModuleList( [WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)] ) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1052,8 +1166,13 @@ def _update_causal_mask( input_tensor: mindspore.Tensor, cache_position: mindspore.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache)
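
Note (not part of the diff): the sketch below exercises the new `unpad_input` / `pad_input` helpers added in `mindnlp/transformers/integrations/npu_flash_attention.py` on a toy batch, to illustrate how `_flash_attention_forward` strips padding before the varlen kernel and restores it afterwards. Shapes follow the docstrings in the diff; the toy sizes, tensors, and printed values are illustrative assumptions, not outputs verified on an Ascend NPU.

import numpy as np
import mindspore
from mindnlp.transformers.integrations.npu_flash_attention import pad_input, unpad_input

# Toy sizes (assumption): 2 sequences of length 4, 3 heads, head_dim 8.
batch, seqlen, num_heads, head_dim = 2, 4, 3, 8
hidden_states = mindspore.Tensor(np.random.randn(batch, seqlen, num_heads, head_dim), mindspore.float16)
# 1 marks valid tokens, 0 marks padding (the second sequence is right-padded).
attention_mask = mindspore.Tensor([[1, 1, 1, 1], [1, 1, 0, 0]], mindspore.int32)

# Drop padded positions: `unpadded` packs all valid tokens into (total_nnz, num_heads, head_dim).
unpadded, indices, cu_seqlens, max_seqlen, used_seqlens = unpad_input(hidden_states, attention_mask)
print(unpadded.shape)  # (6, 3, 8)
print(cu_seqlens)      # [0, 4, 6]; cu_seqlens[1:] is what feeds actual_seq_qlen/actual_seq_kvlen

# Scatter the packed tokens back to (batch, seqlen, num_heads, head_dim); padded slots are zero-filled.
repadded = pad_input(unpadded, indices, batch, seqlen)
print(repadded.shape)  # (2, 4, 3, 8)

Usage note: with the `WHISPER_ATTENTION_CLASSES` and `_autoset_attn_implementation` changes above, the new Whisper path is presumably selected by passing `attn_implementation="flash_attention_2"` to `from_pretrained` (or the deprecated `use_flash_attention_2=True`), mirroring the upstream transformers API.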