Fix Llama sdpa/ flash attention + adapters #722

Merged · 3 commits · Jul 20, 2024
Changes from all commits
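
For context (not part of this PR's diff): the code paths changed here are exercised when a Llama model is loaded with a non-eager attention implementation and an adapter method such as prefix tuning is attached. Below is a minimal usage sketch, assuming the adapters package is installed and a Llama checkpoint is available; the model id and adapter name are placeholders.

import adapters
import torch
from adapters import PrefixTuningConfig
from transformers import AutoModelForCausalLM

# Select the SDPA attention implementation ("flash_attention_2" also works on
# supported GPUs and is the other path touched by this PR).
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
)

# Enable adapter support and activate a prefix-tuning adapter; with it active,
# the adjusted attention forward passes in the diff below are the ones that run.
adapters.init(model)
model.add_adapter("prefix", config=PrefixTuningConfig())
model.set_active_adapters("prefix")

As the check added in the diff notes, the flash-attention path is not compatible with the static cache implementation, so generation with a static cache should use the sdpa path instead.
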
src/adapters/models/llama/modeling_llama.py (71 changes: 46 additions & 25 deletions)
@@ -28,8 +28,15 @@
from torch import nn

from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, apply_rotary_pos_emb, repeat_kv
from transformers.cache_utils import Cache, StaticCache
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaFlashAttention2,
LlamaSdpaAttention,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import logging

from .mixin_llama import LlamaAttentionMixin, LlamaDecoderLayerMixin
@@ -80,10 +87,12 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# >>> START AH Changes <<<
query_states, key_states, value_states = match_attn_matrices_for_parallel(
query_states, key_states, value_states
)
(attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
# >>> END AH Changes <<<

past_key_value = getattr(self, "past_key_value", past_key_value)
cos, sin = self.rotary_emb(value_states, position_ids)
@@ -97,15 +106,16 @@ def forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

# >>> START AH Changes <<<
key_states, value_states, attention_mask = self.prefix_tuning(
key_states, value_states, hidden_states, attention_mask
)
(query_states,) = adjust_tensors_for_parallel(key_states, query_states)
bsz = key_states.shape[0]
# >>> END AH Changes <<<

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

bsz = key_states.shape[0]

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
@@ -138,7 +148,7 @@ def forward(
return attn_output, attn_weights, past_key_value


class LlamaFlashAttention2WithAdapters(LlamaAttentionMixin, LlamaAttention):
class LlamaFlashAttention2WithAdapters(LlamaAttentionMixin, LlamaFlashAttention2):
def forward(
self,
hidden_states: torch.Tensor,
@@ -150,6 +160,12 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)

output_attentions = False

bsz, q_len, _ = hidden_states.size()
@@ -165,27 +181,29 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# >>> START AH Changes <<<
query_states, key_states, value_states = match_attn_matrices_for_parallel(
query_states, key_states, value_states
)
(attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
# >>> END AH Changes <<<

cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

# >>> START AH Changes <<<
key_states, value_states, attention_mask = self.prefix_tuning(
key_states, value_states, hidden_states, attention_mask
)
(query_states,) = adjust_tensors_for_parallel(key_states, query_states)

# Make adjustments since (parallel) prefix tuning changes the attention mask
bsz = key_states.shape[0]

past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# >>> END AH Changes <<<

# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
@@ -202,7 +220,7 @@ def forward(
# in fp32. (LlamaRMSNorm handles it correctly)

input_dtype = query_states.dtype
if input_dtype == torch.float32:
if input_dtype == torch.float32 or key_states.dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
@@ -212,8 +230,8 @@ def forward(
target_dtype = self.q_proj.weight.dtype

logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to the fact"
" you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)

@@ -234,7 +252,7 @@ def forward(
return attn_output, attn_weights, past_key_value


class LlamaSdpaAttentionWithAdapters(LlamaAttentionMixin, LlamaAttention):
class LlamaSdpaAttentionWithAdapters(LlamaAttentionMixin, LlamaSdpaAttention):

# Adapted from LlamaAttention.forward
def forward(
@@ -250,10 +268,8 @@ def forward(
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does"
" not support `output_attentions=True`. Falling back to the manual attention implementation, but"
" specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This"
' warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
@@ -275,17 +291,16 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# >>> START AH Changes <<<
query_states, key_states, value_states = match_attn_matrices_for_parallel(
query_states, key_states, value_states
)
(attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
# >>> END AH Changes <<<

cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

# In case static cache is used, it is an instance attribute.
past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -294,15 +309,16 @@ def forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

# >>> START AH Changes <<<
key_states, value_states, attention_mask = self.prefix_tuning(
key_states, value_states, hidden_states, attention_mask
)
(query_states,) = adjust_tensors_for_parallel(key_states, query_states)
# >>> END AH Changes <<<

bsz = key_states.shape[0]

causal_mask = attention_mask
# if attention_mask is not None and cache_position is not None:
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

@@ -313,12 +329,17 @@ def forward(
key_states = key_states.contiguous()
value_states = value_states.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
# inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
is_causal = True if causal_mask is None and q_len > 1 else False

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)

attn_output = attn_output.transpose(1, 2).contiguous()
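
For readers unfamiliar with the blocks marked ">>> AH Changes <<<" above: prefix tuning prepends trained key/value positions, and parallel composition replicates the batch, so after self.prefix_tuning(...) both the key/value sequence length and the batch size can differ from those of the incoming hidden states. That is why the diff re-reads bsz = key_states.shape[0] after the prefix-tuning call; the existing causal-mask slice to key_states.shape[-2] then stays consistent with the longer key/value length. An illustrative shape sketch with arbitrary example sizes (not taken from the PR):

import torch

# Arbitrary example sizes, for illustration only.
orig_bsz, heads, q_len, head_dim = 2, 8, 16, 64
prefix_len, n_parallel = 10, 2

# After prefix tuning (adds prefix_len key/value positions) and parallel
# composition (replicates the batch), the key tensor can look like this:
key_states = torch.randn(orig_bsz * n_parallel, heads, q_len + prefix_len, head_dim)

# Hence the batch size is re-read from the keys instead of reusing the bsz
# derived from hidden_states at the top of forward():
bsz = key_states.shape[0]  # 4 here, while hidden_states had batch size 2

# And the causal mask is sliced to the actual key/value length, as in the diff:
causal_mask = torch.zeros(bsz, 1, q_len, 64)  # e.g. a mask padded to a longer length
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]  # -> (4, 1, 16, 26)
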