Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid unpad/pad repeated calls when use_cache=False #5

Open
wants to merge 1 commit into
base: add-flash-attn-2
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 149 additions & 67 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# limitations under the License.
""" PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, Dict

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -211,6 +211,16 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed

def apply_rotary_pos_emb_unpad(q, key_states, cos, sin, position_ids_unpad):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(0, 1) # [seq_len, dim]
sin = sin.squeeze(0, 1) # [seq_len, dim]
cos = cos[position_ids_unpad] # [total_tokens, 1, dim]
sin = sin[position_ids_unpad] # [total_tokens, 1, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
key_states.copy_((key_states * cos) + (rotate_half(key_states) * sin))
return q_embed


class LlamaMLP(nn.Module):
def __init__(self, config):
Expand Down Expand Up @@ -323,9 +333,14 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
flash_kwargs: Optional[Dict] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()


if flash_kwargs is not None and flash_kwargs["is_unpadded"]:
raise ValueError("Non flash does not support the unpadded path")

if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
Expand Down Expand Up @@ -478,89 +493,123 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
flash_kwargs: Optional[Dict] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention attention does not support output_attentions
output_attentions = False

bsz, q_len, _ = hidden_states.size()
if not flash_kwargs["is_unpadded"]:
bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dime x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dime x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

# TODO: llama does not have dropout in the config??
# It is recommended to use dropout with FA according to the docs
# when training.
dropout_rate = 0.0 # if not self.training else self.attn_dropout

# contains at least one padding token
if flash_kwargs["masking"]:
indices_k = flash_kwargs["indices_k"]
cu_seqlens_k = flash_kwargs["cu_seqlens_k"]
max_seqlen_in_batch_k = flash_kwargs["max_seqlen_in_batch_k"]

key_states = index_first_axis(rearrange(key_states, "b s ... -> (b s) ..."), indices_k)
value_states = index_first_axis(rearrange(value_states, "b s ... -> (b s) ..."), indices_k)

# In an ideal world, at least for the path q_len == kv_seq_len and q_len == 1, we should collect the
if q_len == kv_seq_len:
query_states = index_first_axis(rearrange(query_states, "b s ... -> (b s) ..."), indices_k)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif q_len == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = flash_kwargs["cu_seqlens_q"]
indices_q = flash_kwargs["indices_q"]
query_states = query_states.squeeze(1) # [batch_size, 1, num_heads, head_dim] -> [batch_size, num_heads, head_dim]
else:
# The -q_len: slice assumes left padding.
padding_mask = padding_mask[:, -q_len:]
query_states, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_states, padding_mask)

attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=0.0,
softmax_scale=None,
causal=True,
)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
else:
attn_output = flash_attn_func(query_states, key_states, value_states, dropout_rate, causal=True)

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
else:
# is_unpadded path
total_tokens, _ = hidden_states.size()

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states = self.q_proj(hidden_states).view(total_tokens, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).view(total_tokens, self.num_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(total_tokens, self.num_heads, self.head_dim)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
max_seqlen = flash_kwargs["max_seqlen_in_batch_k"]

past_key_value = (key_states, value_states) if use_cache else None
cos, sin = self.rotary_emb(value_states, seq_len=max_seqlen)

query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

# TODO: llama does not have dropout in the config??
# It is recommended to use dropout with FA according to the docs
# when training.
dropout_rate = 0.0 # if not self.training else self.attn_dropout

# contains at least one padding token
if padding_mask is not None:
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
key_states = index_first_axis(rearrange(key_states, "b s ... -> (b s) ..."), indices_k)
value_states = index_first_axis(rearrange(value_states, "b s ... -> (b s) ..."), indices_k)

# In an ideal world, at least for the path q_len == kv_seq_len and q_len == 1, we should collect the
if q_len == kv_seq_len:
query_states = index_first_axis(rearrange(query_states, "b s ... -> (b s) ..."), indices_k)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif q_len == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
bsz + 1, dtype=torch.int32, device=query_states.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_states = query_states.squeeze(1)
else:
# The -q_len: slice assumes left padding.
padding_mask = padding_mask[:, -q_len:]
query_states, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_states, padding_mask)
# key_states modified in place
query_states = apply_rotary_pos_emb_unpad(query_states, key_states, cos, sin, position_ids)

attn_output_unpad = flash_attn_varlen_func(
# It would be nice to use rather the flash_attn_kvpacked_func interface, with a single nn.Linear to compute keys/values
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
cu_seqlens_q=flash_kwargs["cu_seqlens_k"],
cu_seqlens_k=flash_kwargs["cu_seqlens_k"],
max_seqlen_q=flash_kwargs["max_seqlen_in_batch_k"],
max_seqlen_k=flash_kwargs["max_seqlen_in_batch_k"],
dropout_p=0.0,
softmax_scale=None,
causal=True,
)

attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
else:
attn_output = flash_attn_func(query_states, key_states, value_states, dropout_rate, causal=True)
attn_output = attn_output.reshape(-1, self.num_heads * self.head_dim)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)

if not output_attentions:
Expand Down Expand Up @@ -591,6 +640,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
flash_kwargs: Optional[Dict] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
Expand Down Expand Up @@ -619,6 +669,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
flash_kwargs=flash_kwargs,
)
hidden_states = residual + hidden_states

Expand Down Expand Up @@ -770,6 +821,7 @@ def __init__(self, config: LlamaConfig):
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.gradient_checkpointing = False
self._flash = getattr(config, "_flash_attn_2_enabled", False)
# Initialize weights and apply final processing
self.post_init()

Expand Down Expand Up @@ -863,13 +915,39 @@ def forward(
padding_mask = attention_mask
else:
padding_mask = None

attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)


hidden_states = inputs_embeds

if not self._flash:
flash_kwargs = None
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
else:
flash_kwargs = {}
flash_kwargs["masking"] = padding_mask is not None

if padding_mask is not None:
if not use_cache:
hidden_states, indices_k, cu_seqlens_k, max_seqlen_in_batch_k = unpad_input(hidden_states, padding_mask)
position_ids = position_ids.expand(batch_size, seq_length)
position_ids, _, _, _ = unpad_input(position_ids.unsqueeze(-1), padding_mask)
is_unpadded = True
flash_kwargs["is_unpadded"] = True
else:
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
flash_kwargs["is_unpadded"] = False
if seq_length == 1:
flash_kwargs["cu_seqlens_q"] = torch.arange(
batch_size + 1, dtype=torch.int32, device=input_ids.device
) # There is a memcpy here, that is very bad. At least happening only once.
flash_kwargs["indices_q"] = flash_kwargs["cu_seqlens_q"][:-1]
flash_kwargs["indices_k"] = indices_k
flash_kwargs["cu_seqlens_k"] = cu_seqlens_k
flash_kwargs["max_seqlen_in_batch_k"] = max_seqlen_in_batch_k
else:
flash_kwargs["is_unpadded"] = False

if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
Expand Down Expand Up @@ -909,6 +987,7 @@ def custom_forward(*inputs):
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
flash_kwargs=flash_kwargs,
)

hidden_states = layer_outputs[0]
Expand All @@ -921,6 +1000,9 @@ def custom_forward(*inputs):

hidden_states = self.norm(hidden_states)

if self._flash and padding_mask is not None and not use_cache:
hidden_states = pad_input(hidden_states, indices_k, batch_size, max_seqlen_in_batch_k)

# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
Expand Down