Skip to content

Commit

Permalink
[modelling] remove un-necessary transpose for fa2 attention (#31749)
Browse files Browse the repository at this point in the history
* [whisper] remove un-necessary transpose for fa2 attention

* propagate
  • Loading branch information
sanchit-gandhi authored and itazap committed Jul 25, 2024
1 parent 60130df commit 128f36e
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 8 deletions.
6 changes: 2 additions & 4 deletions src/transformers/models/idefics2/modeling_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def forward(
# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

Expand All @@ -311,7 +311,6 @@ def forward(

# TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view operations.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

Expand Down Expand Up @@ -817,7 +816,7 @@ def forward(
key_states = self.k_proj(torch.cat([context, latents], dim=-2))
value_states = self.v_proj(torch.cat([context, latents], dim=-2))

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

Expand Down Expand Up @@ -882,7 +881,6 @@ def forward(
value_states = value_states.to(target_dtype)

# Reshape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

Expand Down
3 changes: 1 addition & 2 deletions src/transformers/models/jamba/modeling_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def forward(
# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

Expand Down Expand Up @@ -469,7 +469,6 @@ def forward(
value_states = value_states.to(target_dtype)

# Reshape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

Expand Down
3 changes: 1 addition & 2 deletions src/transformers/models/whisper/modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def forward(
bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))

if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
Expand Down Expand Up @@ -416,7 +416,6 @@ def forward(

# TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim].
# We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view operations.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

Expand Down

0 comments on commit 128f36e

Please sign in to comment.