Fix Llama sdpa/flash attention + adapters (adapter-hub#722)
dainis-boumber committed Aug 30, 2024
1 parent 7a2ceb6 commit 8f84ff9
Showing 1 changed file with 47 additions and 27 deletions.
74 changes: 47 additions & 27 deletions src/adapters/models/llama/modeling_llama.py
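
The change below derives the adapter-enabled attention classes from the matching transformers implementations (LlamaFlashAttention2, LlamaSdpaAttention) rather than from the eager LlamaAttention, and aligns cache handling and masking with upstream, so adapters also work when a model is loaded with a non-eager attention implementation. As a rough usage sketch of the setup this fix targets (the checkpoint name, adapter name, and adapter config below are illustrative assumptions, not part of the commit):

import torch
import adapters
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint and adapter names -- substitute your own.
model_id = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # or "flash_attention_2"
)

adapters.init(model)  # enable adapter support on the plain transformers model
model.add_adapter("demo_adapter", config="seq_bn")  # bottleneck adapter; config name assumed
model.set_active_adapters("demo_adapter")

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, world", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

Deriving from LlamaFlashAttention2 and LlamaSdpaAttention presumably keeps the helper methods and flags those implementations define available on the adapter subclasses, instead of falling back to the eager attention base class.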
@@ -18,7 +18,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LLaMA model."""

import math
import warnings
from typing import Optional, Tuple
@@ -29,8 +28,15 @@
from torch import nn

from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, apply_rotary_pos_emb, repeat_kv
from transformers.cache_utils import Cache, StaticCache
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaFlashAttention2,
LlamaSdpaAttention,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import logging

from .mixin_llama import LlamaAttentionMixin, LlamaDecoderLayerMixin
@@ -81,10 +87,12 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# >>> START AH Changes <<<
query_states, key_states, value_states = match_attn_matrices_for_parallel(
query_states, key_states, value_states
)
(attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
# >>> END AH Changes <<<

past_key_value = getattr(self, "past_key_value", past_key_value)
cos, sin = self.rotary_emb(value_states, position_ids)
@@ -98,15 +106,16 @@
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

# >>> START AH Changes <<<
key_states, value_states, attention_mask = self.prefix_tuning(
key_states, value_states, hidden_states, attention_mask
)
(query_states,) = adjust_tensors_for_parallel(key_states, query_states)
bsz = key_states.shape[0]
# >>> END AH Changes <<<

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

bsz = key_states.shape[0]

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
@@ -139,7 +148,7 @@ def forward(
return attn_output, attn_weights, past_key_value


class LlamaFlashAttention2WithAdapters(LlamaAttentionMixin, LlamaAttention):
class LlamaFlashAttention2WithAdapters(LlamaAttentionMixin, LlamaFlashAttention2):
def forward(
self,
hidden_states: torch.Tensor,
@@ -151,6 +160,12 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)

output_attentions = False

bsz, q_len, _ = hidden_states.size()
@@ -166,27 +181,29 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# >>> START AH Changes <<<
query_states, key_states, value_states = match_attn_matrices_for_parallel(
query_states, key_states, value_states
)
(attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
# >>> END AH Changes <<<

cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

# >>> START AH Changes <<<
key_states, value_states, attention_mask = self.prefix_tuning(
key_states, value_states, hidden_states, attention_mask
)
(query_states,) = adjust_tensors_for_parallel(key_states, query_states)

# Make adjustments since (parallel) prefix tuning changes the attention mask
bsz = key_states.shape[0]

past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# >>> END AH Changes <<<

# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
@@ -203,7 +220,7 @@ def forward(
# in fp32. (LlamaRMSNorm handles it correctly)

input_dtype = query_states.dtype
if input_dtype == torch.float32:
if input_dtype == torch.float32 or key_states.dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
@@ -213,8 +230,8 @@ def forward(
target_dtype = self.q_proj.weight.dtype

logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to the fact"
" you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)

Expand All @@ -235,7 +252,7 @@ def forward(
return attn_output, attn_weights, past_key_value


class LlamaSdpaAttentionWithAdapters(LlamaAttentionMixin, LlamaAttention):
class LlamaSdpaAttentionWithAdapters(LlamaAttentionMixin, LlamaSdpaAttention):

# Adapted from LlamaAttention.forward
def forward(
@@ -251,10 +268,8 @@ def forward(
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does"
" not support `output_attentions=True`. Falling back to the manual attention implementation, but"
" specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This"
' warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
@@ -276,17 +291,16 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# >>> START AH Changes <<<
query_states, key_states, value_states = match_attn_matrices_for_parallel(
query_states, key_states, value_states
)
(attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
# >>> END AH Changes <<<

cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

# In case static cache is used, it is an instance attribute.
past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -295,15 +309,16 @@ def forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

# >>> START AH Changes <<<
key_states, value_states, attention_mask = self.prefix_tuning(
key_states, value_states, hidden_states, attention_mask
)
(query_states,) = adjust_tensors_for_parallel(key_states, query_states)
# >>> END AH Changes <<<

bsz = key_states.shape[0]

causal_mask = attention_mask
# if attention_mask is not None and cache_position is not None:
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

@@ -314,12 +329,17 @@ def forward(
key_states = key_states.contiguous()
value_states = value_states.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
# inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
is_causal = True if causal_mask is None and q_len > 1 else False

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)

attn_output = attn_output.transpose(1, 2).contiguous()
@@ -393,4 +413,4 @@ def forward(
if use_cache:
outputs += (present_key_value,)

return outputs
return outputs
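
For reference, the `is_causal` dispatch added in the SDPA path can be exercised standalone; the tensor shapes below are made up for illustration and are not taken from the commit:

import torch
import torch.nn.functional as F

# Pass an explicit boolean is_causal when no mask is supplied and the query
# covers more than one token; otherwise hand SDPA the prepared causal mask.
bsz, n_heads, q_len, head_dim = 2, 8, 16, 64
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, q_len, head_dim)
v = torch.randn(bsz, n_heads, q_len, head_dim)

causal_mask = None  # e.g. prefix tuning may instead have produced an explicit mask
is_causal = True if causal_mask is None and q_len > 1 else False

out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=causal_mask,
    dropout_p=0.0,
    is_causal=is_causal,
)
print(out.shape)  # torch.Size([2, 8, 16, 64])

As the comment in the diff notes, using an explicit if statement rather than an inline conditional assignment keeps the dispatch compatible with torch.compile's dynamic=True and fullgraph=True modes.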
