
Commit bca1c33
Update to include low-precision groupnorm
vchiley committed Jan 20, 2024
1 parent 6c0472b commit bca1c33
Showing 2 changed files with 80 additions and 10 deletions.
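
For context, a minimal sketch (not part of this commit) of how the normalization variant is selected; the `qk_gn`/`qk_ln` config keys are assumed to mirror the `self.qk_gn` and `self.qk_ln` attributes this diff references:

# Hypothetical MPT attn_config sketch: qk_gn normalizes each head of q/k with
# GroupNorm (now with a low-precision path), qk_ln uses LayerNorm instead.
attn_config = {
    'attn_impl': 'triton',
    'qk_gn': True,   # per-head GroupNorm on query/key projections
    'qk_ln': False,  # alternative LayerNorm variant
}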
67 changes: 57 additions & 10 deletions llmfoundry/models/layers/attention.py
@@ -12,10 +12,10 @@
 import transformers
 from einops import rearrange
 from packaging import version
 from torch import nn

 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
-from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
+from llmfoundry.models.layers.norm import (NORM_CLASS_REGISTRY, LPLayerNorm,
+                                           low_precision_groupnorm)

def is_flash_v2_installed(v2_version: str = '2.0.0'):
@@ -498,6 +498,47 @@ def triton_flash_attn_fn(
return output, None, past_key_value


def _expand_params(heads: int, param: Optional[torch.Tensor] = None):
    # Tile a per-head affine param (e.g. a LayerNorm weight of size head_dim)
    # across heads so it can serve as a per-channel group_norm param.
    if param is None:
        return None
    return param.repeat(heads)


def _apply_qk_gn(
    query: torch.Tensor,
    key: torch.Tensor,
    n_heads: int,
    kv_n_heads: int,
    q_ln: nn.Module,
    k_ln: nn.Module,
):
    dtype = query.dtype

    # F.group_norm expects (N, C, *) input whose channel dim matches the
    # affine params, so flatten (b, s, d) -> (b * s, d); each head then forms
    # one channel group and is normalized per token, matching the per-head
    # LayerNorm it replaces.
    q_shape, k_shape = query.shape, key.shape
    query = query.reshape(-1, q_shape[-1])
    key = key.reshape(-1, k_shape[-1])

    w = _expand_params(n_heads, q_ln.weight)
    b = _expand_params(n_heads, q_ln.bias)
    if isinstance(q_ln, LPLayerNorm):
        query = low_precision_groupnorm(query, n_heads, w, b,
                                        eps=q_ln.eps).to(dtype)
    elif isinstance(q_ln, nn.LayerNorm):
        query = nn.functional.group_norm(query, n_heads, w, b, eps=q_ln.eps)
    else:
        raise ValueError(
            f'qk_gn not applicable for given q_ln type ({type(q_ln)=}).')

    w = _expand_params(kv_n_heads, k_ln.weight)
    b = _expand_params(kv_n_heads, k_ln.bias)
    if isinstance(k_ln, LPLayerNorm):
        key = low_precision_groupnorm(key, kv_n_heads, w, b,
                                      eps=k_ln.eps).to(dtype)
    elif isinstance(k_ln, nn.LayerNorm):
        key = nn.functional.group_norm(key, kv_n_heads, w, b, eps=k_ln.eps)
    else:
        raise ValueError(
            f'qk_gn not applicable for given k_ln type ({type(k_ln)=}).')

    return query.view(q_shape), key.view(k_shape)


class GroupedQueryAttention(nn.Module):
"""Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
@@ -629,16 +670,22 @@ def forward(

key_padding_mask = attention_mask

-        if self.qk_ln or self.qk_gn:
+        if self.qk_gn:
+            # Applying groupnorm to qk
+            query, key = _apply_qk_gn(
+                query,
+                key,
+                self.n_heads,
+                self.kv_n_heads,
+                self.q_ln,
+                self.k_ln,
+            )
+
+        if self.qk_ln:
             # Applying layernorm to qk
-            q_shape, k_shape = query.shape, key.shape
-            if self.qk_gn:
-                b, s = query.shape[:2]
-                query = query.view(b, s, self.n_heads, -1)
-                key = key.view(b, s, self.kv_n_heads, -1)
             dtype = query.dtype
-            query = self.q_ln(query).to(dtype).view(q_shape)
-            key = self.k_ln(key).to(dtype).view(k_shape)
+            query = self.q_ln(query).to(dtype)
+            key = self.k_ln(key).to(dtype)

if rotary_emb_w_meta_info is not None:
rotary_emb = rotary_emb_w_meta_info['rotary_emb']
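
As a sanity check (not part of the commit), a minimal sketch of the identity the qk_gn path relies on: GroupNorm with `n_heads` groups and per-head affine params tiled by `_expand_params` matches applying LayerNorm to each head independently, once (b, s, d) inputs are flattened so heads form channel groups. All names and shapes below are illustrative only.

import torch
import torch.nn.functional as F

b, s, n_heads, head_dim = 2, 4, 8, 16
d = n_heads * head_dim
x = torch.randn(b, s, d)
w, bias = torch.randn(head_dim), torch.randn(head_dim)

# Reference: the view-based per-head LayerNorm this commit replaces.
ref = F.layer_norm(x.view(b, s, n_heads, head_dim), (head_dim,), w, bias,
                   eps=1e-5).view(b, s, d)

# GroupNorm path: flatten (b, s, d) -> (b*s, d) so each head is one group,
# with the per-head affine params repeated across heads.
out = F.group_norm(x.reshape(-1, d), n_heads, w.repeat(n_heads),
                   bias.repeat(n_heads), eps=1e-5).view(b, s, d)

assert torch.allclose(ref, out, atol=1e-5)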
23 changes: 23 additions & 0 deletions llmfoundry/models/layers/norm.py
@@ -53,6 +53,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)


def low_precision_groupnorm(
    x: torch.Tensor,
    groups: int,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-05,
):
    # Downcast x/weight/bias to the autocast dtype, then run group_norm with
    # autocast disabled so the kernel executes in low precision instead of
    # being autocast back to fp32.
    device = x.device
    downcast_x = _cast_if_autocast_enabled(x)
    downcast_weight = _cast_if_autocast_enabled(
        weight) if weight is not None else weight
    downcast_bias = _cast_if_autocast_enabled(
        bias) if bias is not None else bias
    with torch.autocast(enabled=False, device_type=device.type):
        return torch.nn.functional.group_norm(
            downcast_x,
            groups,
            downcast_weight,
            downcast_bias,
            eps,
        )


def rms_norm(x: torch.Tensor,
weight: Optional[torch.Tensor] = None,
eps: float = 1e-5) -> torch.Tensor:
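
A small usage sketch (assumes a CUDA device; tensors and shapes are illustrative): under autocast, low_precision_groupnorm downcasts its inputs to the autocast dtype and runs group_norm with autocast disabled, so the normalization itself executes in low precision.

import torch

from llmfoundry.models.layers.norm import low_precision_groupnorm

x = torch.randn(8, 64, device='cuda')    # (tokens, channels) in fp32
weight = torch.ones(64, device='cuda')
bias = torch.zeros(64, device='cuda')

with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    y = low_precision_groupnorm(x, 4, weight, bias)  # 4 channel groups

print(y.dtype)  # torch.bfloat16 under autocast; callers cast back as needed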
