Skip to content

Commit

Permalink
feat: fix fp8 for MLA and support bmm fp8 for DeepSeek V2 (#1285)
Browse files Browse the repository at this point in the history
Co-authored-by: ispobock <ispobaoke@163.com>
  • Loading branch information
zhyncs and ispobock authored Sep 1, 2024
1 parent 1b5d56f commit 54772f7
Showing 1 changed file with 41 additions and 19 deletions.
60 changes: 41 additions & 19 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
from flashinfer import bmm_fp8
from torch import nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig
Expand Down Expand Up @@ -161,6 +162,15 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
return 0.1 * mscale * math.log(scale) + 1.0


def input_to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


class DeepseekV2Attention(nn.Module):

def __init__(
Expand Down Expand Up @@ -255,11 +265,6 @@ def __init__(
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale

# self.attn = Attention(self.num_heads,
# self.qk_head_dim,
# self.scaling,
# num_kv_heads=self.num_heads)

# TODO, support head_size 192
self.attn = RadixAttention(
self.num_local_heads,
Expand All @@ -283,7 +288,7 @@ def forward(
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
Expand Down Expand Up @@ -419,6 +424,7 @@ def __init__(

self.w_kc = None
self.w_vc = None
self.w_scale = None

def forward(
self,
Expand All @@ -439,8 +445,17 @@ def forward(
-1, self.num_local_heads, self.qk_head_dim
)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_nope_out = q_input[..., : self.kv_lora_rank]
torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))

if self.w_kc.dtype == torch.float8_e4m3fn:
q_nope_val, q_nope_scale = input_to_float8(
q_nope.transpose(0, 1), torch.float8_e4m3fn
)
q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
)
else:
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
v_input = latent_cache[..., : self.kv_lora_rank]
Expand All @@ -455,16 +470,21 @@ def forward(

attn_output = self.attn(q_input, k_input, v_input, input_metadata)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
attn_bmm_output = attn_output.new_empty(
q_len, self.num_local_heads, self.v_head_dim
)
torch.bmm(
attn_output.transpose(0, 1),
self.w_vc,
out=attn_bmm_output.transpose(0, 1),
)

attn_output = attn_bmm_output.flatten(1, 2)
if self.w_vc.dtype == torch.float8_e4m3fn:
attn_output_val, attn_output_scale = input_to_float8(
attn_output.transpose(0, 1), torch.float8_e4m3fn
)
attn_bmm_output = bmm_fp8(
attn_output_val,
self.w_vc,
attn_output_scale,
self.w_scale,
torch.bfloat16,
)
else:
attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
output, _ = self.o_proj(attn_output)

return output
Expand Down Expand Up @@ -717,8 +737,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
self_attn.w_kc = w_kc.contiguous()
self_attn.w_vc = w_vc.transpose(1, 2).contiguous()
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if hasattr(self_attn.kv_b_proj, "weight_scale"):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
del self_attn.kv_b_proj


Expand Down

0 comments on commit 54772f7

Please sign in to comment.