Refactoring attention #1182

Merged: 16 commits, May 8, 2024
228 changes: 140 additions & 88 deletions llmfoundry/models/layers/attention.py
@@ -5,10 +5,9 @@

import math
import warnings
from typing import Any, Optional
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
import transformers
from einops import rearrange
from packaging import version
@@ -233,7 +232,6 @@ def flash_attn_fn(
dropout_p: float = 0.0,
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
should_repeat_kv_for_gqa: Optional[bool] = True,
sliding_window_size: int = -1,
alibi_slopes: Optional[torch.Tensor] = None,
@@ -506,6 +504,54 @@ def forward(
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
query, key, value = self.get_qkv(x)

if rotary_emb_w_meta_info is not None:
query, key, value = self._apply_rotary_embeddings(
rotary_emb_w_meta_info,
query,
key,
value,
)

extra_attn_kwargs = self.get_implementation_specific_args(
attention_mask,
alibi_slopes,
flash_attn_padding_info,
)

context, attn_weights, past_key_value = self.attn_fn(
query,
key,
value,
n_heads=self.n_heads,
kv_n_heads=self.kv_n_heads,
past_key_value=past_key_value,
softmax_scale=self.softmax_scale,
attn_bias=attn_bias,
is_causal=is_causal,
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights,
**extra_attn_kwargs,
)

return self.out_proj(context), attn_weights, past_key_value

def get_qkv(
self,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Computes and returns the query, key, and value tensors.

Args:
x (torch.Tensor): The input tensor.

Returns:
query (torch.Tensor): The query tensor.
key (torch.Tensor): The key tensor.
value (torch.Tensor): The value tensor.
"""
qkv = self.Wqkv(x)

if self.clip_qkv:
@@ -520,8 +566,6 @@ def forward(
dim=2,
)

key_padding_mask = attention_mask

if self.qk_ln or self.qk_gn:
# Applying layernorm to qk
q_shape, k_shape = query.shape, key.shape
@@ -533,97 +577,105 @@ def forward(
query = self.q_ln(query).to(dtype).view(q_shape)
key = self.k_ln(key).to(dtype).view(k_shape)

if rotary_emb_w_meta_info is not None:
rotary_emb = rotary_emb_w_meta_info['rotary_emb']
seq_len = rotary_emb_w_meta_info['seq_len']
offset_info = rotary_emb_w_meta_info['offset_info']
bsz, seqlen = query.shape[:2]
query = query.view(bsz, seqlen, -1, self.head_dim)
key = key.view(bsz, seqlen, -1, self.head_dim)

if rotary_emb_w_meta_info['impl'] == 'dail':
value = value.view(bsz, seqlen, -1, self.head_dim)

kv = torch.stack([key, value], dim=2)
query, kv = rotary_emb(
query,
kv,
seqlen_offset=offset_info,
max_seqlen=seq_len,
return query, key, value

def _apply_rotary_embeddings(
self,
rotary_emb_w_meta_info: Dict[str, Any],
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
rotary_emb = rotary_emb_w_meta_info['rotary_emb']
seq_len = rotary_emb_w_meta_info['seq_len']
offset_info = rotary_emb_w_meta_info['offset_info']
bsz, seqlen = query.shape[:2]
query = query.view(bsz, seqlen, -1, self.head_dim)
key = key.view(bsz, seqlen, -1, self.head_dim)

if rotary_emb_w_meta_info['impl'] == 'dail':
value = value.view(bsz, seqlen, -1, self.head_dim)

kv = torch.stack([key, value], dim=2)
query, kv = rotary_emb(
query,
kv,
seqlen_offset=offset_info,
max_seqlen=seq_len,
)
[key, value] = torch.unbind(kv, dim=2)

value = value.view(bsz, seqlen, -1)
elif rotary_emb_w_meta_info['impl'] == 'hf':
if is_transformers_version_gte('4.38'):
(cos, sin) = rotary_emb(
x=value,
position_ids=offset_info,
)
else:
(cos, sin) = rotary_emb(x=value, seq_len=seq_len)
if is_transformers_version_gte('4.38'):
query, key = apply_rotary_pos_emb(
q=query,
k=key,
cos=cos,
sin=sin,
position_ids=None,
unsqueeze_dim=2,
)
elif is_transformers_version_gte('4.36'):
query, key = apply_rotary_pos_emb(
q=query,
k=key,
cos=cos,
sin=sin,
position_ids=offset_info,
unsqueeze_dim=2,
)
else:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
query, key = apply_rotary_pos_emb(
q=query,
k=key,
cos=cos,
sin=sin,
position_ids=offset_info,
)
[key, value] = torch.unbind(kv, dim=2)

value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
elif rotary_emb_w_meta_info['impl'] == 'hf':
if is_transformers_version_gte('4.38'):
(cos, sin) = rotary_emb(
x=value,
position_ids=offset_info,
)
else:
(cos, sin) = rotary_emb(x=value, seq_len=seq_len)
if is_transformers_version_gte('4.38'):
query, key = apply_rotary_pos_emb(
q=query,
k=key,
cos=cos,
sin=sin,
position_ids=None,
unsqueeze_dim=2,
)
elif is_transformers_version_gte('4.36'):
query, key = apply_rotary_pos_emb(
q=query,
k=key,
cos=cos,
sin=sin,
position_ids=offset_info,
unsqueeze_dim=2,
)
else:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
query, key = apply_rotary_pos_emb(
q=query,
k=key,
cos=cos,
sin=sin,
position_ids=offset_info,
)
query = query.transpose(1, 2)
key = key.transpose(1, 2)

query = query.view(bsz, seqlen, self.d_model)
key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)

extra_attn_kwargs = {}
query = query.transpose(1, 2)
key = key.transpose(1, 2)

query = query.view(bsz, seqlen, -1)
key = key.view(bsz, seqlen, -1)
return query, key, value

def get_implementation_specific_args(
self,
attention_mask: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
) -> dict[str, Any]:
"""Returns attention implementation specific args.

Args:
attention_mask (Optional[torch.Tensor]): The attention mask.
alibi_slopes (Optional[torch.Tensor]): The alibi slopes.
flash_attn_padding_info (Optional[dict[str, torch.Tensor]]): The padding information, only required for flash attention.

Returns:
extra_attn_kwargs (dict[str, Any]): Implementation specific args.
"""
if self.attn_impl == 'flash':
key_padding_mask = None
extra_attn_kwargs = {
'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
'sliding_window_size': self.sliding_window_size,
'alibi_slopes': alibi_slopes,
'flash_attn_padding_info': flash_attn_padding_info,
'key_padding_mask': None,
}

context, attn_weights, past_key_value = self.attn_fn(
query,
key,
value,
self.n_heads,
self.kv_n_heads,
past_key_value=past_key_value,
softmax_scale=self.softmax_scale,
attn_bias=attn_bias,
key_padding_mask=key_padding_mask,
is_causal=is_causal,
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights,
**extra_attn_kwargs,
)

return self.out_proj(context), attn_weights, past_key_value
else:
extra_attn_kwargs = {'key_padding_mask': attention_mask}
return extra_attn_kwargs


@attention_classes.register_class('multihead_attention')
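Note on the attention.py changes above: the monolithic `forward` is split into `get_qkv`, `_apply_rotary_embeddings`, and `get_implementation_specific_args` hooks, with the kernel call itself staying in `forward`. The module below is a minimal, self-contained sketch of that hook structure in plain PyTorch; `TinyAttention` and its simplified `attn_fn` are illustrative stand-ins, not the real `GroupedQueryAttention` class.

```python
import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    """Toy module mirroring the refactored hook structure."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.Wqkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def get_qkv(self, x: torch.Tensor):
        # Hook 1: compute query/key/value; subclasses can override just this.
        return self.Wqkv(x).chunk(3, dim=-1)

    def get_implementation_specific_args(self, attention_mask=None):
        # Hook 2: backend-dependent kwargs. The real class returns flash-attn
        # specific kwargs (sliding window, alibi slopes, padding info) when
        # attn_impl == 'flash'; this toy only forwards a key padding mask.
        return {'key_padding_mask': attention_mask}

    def attn_fn(self, q, k, v, key_padding_mask=None):
        # Plain scaled dot-product attention in place of the flash/torch kernels.
        b, s, _ = q.shape
        q, k, v = (t.view(b, s, self.n_heads, self.head_dim).transpose(1, 2)
                   for t in (q, k, v))
        scores = q @ k.transpose(-2, -1) / self.head_dim**0.5
        if key_padding_mask is not None:
            scores = scores.masked_fill(
                ~key_padding_mask.bool()[:, None, None, :], float('-inf'))
        ctx = torch.softmax(scores, dim=-1) @ v
        return ctx.transpose(1, 2).reshape(b, s, self.n_heads * self.head_dim)

    def forward(self, x: torch.Tensor, attention_mask=None):
        query, key, value = self.get_qkv(x)
        extra_attn_kwargs = self.get_implementation_specific_args(attention_mask)
        context = self.attn_fn(query, key, value, **extra_attn_kwargs)
        return self.out_proj(context)

attn = TinyAttention(d_model=16, n_heads=4)
out = attn(torch.randn(2, 5, 16), attention_mask=torch.ones(2, 5))
```

With this shape, a subclass only has to override `get_qkv` (e.g. separate q/k/v projections) or `get_implementation_specific_args` (e.g. extra backend kwargs) instead of re-implementing `forward`.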
66 changes: 39 additions & 27 deletions llmfoundry/models/layers/blocks.py
@@ -3,7 +3,7 @@

"""GPT Blocks used for the GPT Model."""

from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Set, Tuple

import torch
import torch.nn as nn
@@ -88,6 +88,8 @@ def __init__(
self.norm_attn_norm = FusedNormAttentionNorm(
d_model=d_model,
n_heads=n_heads,
args_to_exclude_in_attn_class=self.
args_to_exclude_in_attn_class,
attn_config=attn_config,
ffn_has_norm=ffn_has_norm,
fc_type=fc_type,
@@ -99,21 +101,10 @@ def __init__(
else:
assert isinstance(attn_config['attn_type'], str)
# Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
args_to_exclude_in_attn_class = {
'attn_type',
'alibi',
'attn_uses_sequence_id',
'alibi_bias_max',
'rope',
'rope_theta',
'rope_impl',
'rope_dail_config',
'rope_hf_config',
}
attn_config_subset_for_attn_class = {
k: v
for k, v in attn_config.items()
if k not in args_to_exclude_in_attn_class
if k not in self.args_to_exclude_in_attn_class
}

self.norm_1 = build_norm(
@@ -153,6 +144,20 @@ def __init__(
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
self.use_pad_tok_in_ffn = use_pad_tok_in_ffn

@property
def args_to_exclude_in_attn_class(self):
return {
'attn_type',
'alibi',
'attn_uses_sequence_id',
'alibi_bias_max',
'rope',
'rope_theta',
'rope_impl',
'rope_dail_config',
'rope_hf_config',
}

def forward(
self,
x: torch.Tensor,
@@ -196,6 +201,24 @@ def forward(
if self.norm_2 is not None:
m = self.norm_2(x)

n = self.apply_ffn(attention_mask, m)
x = x + self.resid_ffn_dropout(n)
return x, attn_weights, past_key_value

def apply_ffn(
self,
attention_mask: Optional[torch.ByteTensor],
m: torch.Tensor,
) -> torch.Tensor:
"""Apply feed forward layers to the input.

Args:
attention_mask (Optional[torch.ByteTensor]): The attention mask.
m (torch.Tensor): The input.

Returns:
n (torch.Tensor): The output.
"""
batch_size, seq_len = m.size()[:2]
indices = None
if not self.use_pad_tok_in_ffn:
@@ -205,8 +228,7 @@ def forward(
if not self.use_pad_tok_in_ffn:
assert pad_input is not None
n = pad_input(n, indices, batch_size, seq_len)
x = x + self.resid_ffn_dropout(n)
return x, attn_weights, past_key_value
return n


class FusedNormAttentionNorm(nn.Module):
@@ -215,6 +237,7 @@ def __init__(
self,
d_model: int,
n_heads: int,
args_to_exclude_in_attn_class: Set[str],
attn_config: Optional[Dict] = None,
ffn_has_norm: bool = False,
fc_type: str = 'torch',
@@ -228,18 +251,7 @@ def __init__(
assert attn_config is not None
assert isinstance(attn_config['attn_type'], str)

# necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
args_to_exclude_in_attn_class = {
'attn_type',
'alibi',
'attn_uses_sequence_id',
'alibi_bias_max',
'rope',
'rope_theta',
'rope_impl',
'rope_dail_config',
'rope_hf_config',
}
# Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
attn_config_subset_for_attn_class = {
k: v
for k, v in attn_config.items()
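Note on the blocks.py changes: the `args_to_exclude_in_attn_class` set literal, previously duplicated in `MPTBlock.__init__` and `FusedNormAttentionNorm.__init__`, is now defined once as an `MPTBlock` property and passed into `FusedNormAttentionNorm`. A likely side benefit (an assumption about intent, not stated in the PR) is that subclasses can extend the exclusion set without copying the literal. A minimal sketch, assuming an llm-foundry install that includes this change:

```python
from llmfoundry.models.layers.blocks import MPTBlock

class MyBlock(MPTBlock):
    @property
    def args_to_exclude_in_attn_class(self):
        # 'my_rope_variant' is a hypothetical attn_config key used only for
        # illustration; excluding it keeps it out of the **kwargs forwarded to
        # the attention class constructor.
        return super().args_to_exclude_in_attn_class | {'my_rope_variant'}
```

Because the base `__init__` reads the property through `self`, the subclass override takes effect both when building `attn_config_subset_for_attn_class` and when constructing `FusedNormAttentionNorm`.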
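The new `apply_ffn` helper likewise isolates the unpad-run-repad step used when `use_pad_tok_in_ffn` is False. Below is a self-contained sketch of that pattern using plain tensor indexing instead of flash-attn's `unpad_input`/`pad_input`; `apply_ffn_skipping_pad` and its shapes are illustrative only.

```python
import torch
import torch.nn as nn

def apply_ffn_skipping_pad(
    ffn: nn.Module,
    m: torch.Tensor,               # (batch, seq_len, d_model)
    attention_mask: torch.Tensor,  # (batch, seq_len), 1 = real token, 0 = pad
) -> torch.Tensor:
    """Run the FFN only on non-pad positions, then scatter results back."""
    bsz, seq_len, d_model = m.shape
    keep = attention_mask.bool().reshape(-1)   # flatten to (bsz * seq_len,)
    flat = m.reshape(-1, d_model)
    out = torch.zeros_like(flat)
    out[keep] = ffn(flat[keep])                # FFN sees only real tokens
    return out.view(bsz, seq_len, d_model)     # pad positions stay zero

ffn = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))
m = torch.randn(2, 5, 8)
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
n = apply_ffn_skipping_pad(ffn, m, attention_mask)
```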