Skip to content

Commit

Permalink
new FA2 flow if position_ids is provided
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhui Dih Lee authored and Rhui Dih Lee committed Jul 12, 2024
1 parent 33ca44b commit c3451db
Show file tree
Hide file tree
Showing 14 changed files with 64 additions and 0 deletions.
51 changes: 51 additions & 0 deletions src/transformers/modeling_flash_attention_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,30 @@ def _upad_input(
)


def prepare_fa2_from_position_ids(query, key, value, position_ids, query_length):

query = query.view(-1, query.size(-2), query.size(-1))
key = key.view(-1, key.size(-2), key.size(-1))
value = value.view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)

cu_seq_lens = torch.cat((
indices_q[position_ids==0],
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32)
))

max_length = position_ids.max()+1

return (
query,
key,
value,
indices_q,
(cu_seq_lens, cu_seq_lens),
(max_length, max_length)
)

def _flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
Expand All @@ -137,6 +161,7 @@ def _flash_attention_forward(
query_length: int,
is_causal: bool,
dropout: float = 0.0,
position_ids: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
sliding_window: Optional[int] = None,
use_top_left_mask: bool = False,
Expand Down Expand Up @@ -203,6 +228,32 @@ def _flash_attention_forward(
**flash_kwargs,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)

# if position_ids is provided and not all examples (row) contain only 1 sequence
elif position_ids is not None and not (position_ids[:,-1]==position_ids.size(1)-1).all():
batch_size = query_states.size(0)
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids, query_length
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)

attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))

else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/dbrx/modeling_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,7 @@ def forward(
value_layer,
attention_mask,
query_length,
position_ids=position_ids,
dropout=attn_dropout,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self.config, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self.config, "sliding_window", None),
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/olmo/modeling_olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/phi/modeling_phi.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=attn_dropout,
softmax_scale=None,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/phi3/modeling_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=attn_dropout,
sliding_window=getattr(self.config, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/qwen2/modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=sliding_window,
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=sliding_window,
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/stablelm/modeling_stablelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/starcoder2/modeling_starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self.config, "sliding_window", None),
is_causal=self.is_causal,
Expand Down

0 comments on commit c3451db

Please sign in to comment.