
Commit 4e56e27

Merge pull request vllm-project#6 from wenxcs/wenxh/fp8-on-a100
FP8 on A100 for PHIMOE
2 parents de23377 + e90dfdb commit 4e56e27

File tree (4 files changed: +308 −1 lines)

requirements-cuda.txt
vllm/model_executor/layers/fused_moe/ampere_fp8_fused_moe.py
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/models/mixtral.py

requirements-cuda.txt

Lines changed: 2 additions & 0 deletions
@@ -7,3 +7,5 @@ nvidia-ml-py # for pynvml package
 vllm-nccl-cu12>=2.18,<2.19  # for downloading nccl library
 torch == 2.2.1
 xformers == 0.0.25  # Requires PyTorch 2.2.1
+
+cupy-cuda12x
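
The new cupy-cuda12x dependency exists because the Ampere fallback below JIT-compiles its FP8-to-half conversion kernels with cupy.RawKernel. As a quick sanity check that the wheel matches the local CUDA 12 runtime (an illustrative snippet, not part of this commit):

import cupy

# cupy-cuda12x must see a CUDA 12 runtime; both calls below are part of
# cupy's public API and simply report the versions it was built against.
print(cupy.__version__)
print(cupy.cuda.runtime.runtimeGetVersion())  # e.g. 12020 for CUDA 12.2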
vllm/model_executor/layers/fused_moe/ampere_fp8_fused_moe.py

Lines changed: 294 additions & 0 deletions
@@ -0,0 +1,294 @@
"""Fused MoE kernel with FP8 weight using Ampere."""

import vllm
import torch
from vllm import _custom_ops as ops
import cupy
from typing import Dict, Any, Optional, Callable

import torch
import triton
import triton.language as tl

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.utils import is_hip

logger = init_logger(__name__)

from vllm.model_executor.layers.fused_moe.fused_moe import (
    get_moe_configs,
    moe_align_block_size,
    invoke_fused_moe_kernel,
)

# <todo:wenxh> Kernels performance needs to be optimized,
# such as one thread dealing with multiple elements to reduce memory transactions.

convert_fp8e4m3_to_half = cupy.RawKernel(
    r"""
#include "cuda_fp8.h"
#include "cuda_fp16.h"
extern "C" __global__
void convert_fp8e4m3_to_half(const __nv_fp8_storage_t* x, float *scale_p, half* y, int size) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    float scale = *scale_p;
    if (tid < size)
        y[tid] = __nv_cvt_fp8_to_halfraw(x[tid], __NV_E4M3) * scale;
}
""",
    "convert_fp8e4m3_to_half",
)

convert_fp8e4m3_to_bfloat16 = cupy.RawKernel(
    r"""
#include "cuda_fp8.h"
#include "cuda_fp16.h"
#include "cuda_bf16.h"
extern "C" __global__
void convert_fp8e4m3_to_bfloat16(const __nv_fp8_storage_t* x, float* scale_p, __nv_bfloat16* y, int size) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    float scale = *scale_p;
    if (tid < size)
        y[tid] = __float2bfloat16(__nv_cvt_fp8_to_halfraw(x[tid], __NV_E4M3) * scale);
}
""",
    "convert_fp8e4m3_to_bfloat16",
)


def dequantize_fp8(t_fp8, scales, dtype=torch.float16):
    s = torch.empty_like(t_fp8, dtype=dtype)
    convert = (
        convert_fp8e4m3_to_half
        if dtype == torch.float16
        else convert_fp8e4m3_to_bfloat16
    )

    expert_num = t_fp8.shape[0]

    expert_in = torch.chunk(t_fp8, expert_num, dim=0)
    expert_out = torch.chunk(s, expert_num, dim=0)

    for i in range(expert_num):
        scale = scales[i]
        convert(
            ((expert_in[i].numel() + 1024 - 1) // 1024,),
            (1024,),
            (expert_in[i].data_ptr(), scale.data_ptr(),
             expert_out[i].data_ptr(), t_fp8.numel()),
        )
    return s


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    training: bool = False,
    sparse_mixer: bool = False,
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,
    use_fp8: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    routing_func: Callable = torch.topk,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and a top-k gating mechanism.

    This layer works the same as fused_moe, but it is used for the Ampere arch,
    which does not support fp8. By default, to be more comparable to Hopper,
    we reuse the E4M3 configuration.
    <todo:wenxh> Use FP8E4b16 to reduce overhead:
    https://github.com/triton-lang/triton/blob/d7c8b3d7890125f5fc1b9f046e3189baa2665be4/python/triton/language/extra/cuda/utils.py#L34

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place.
        Defaults to False.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
    M, _ = hidden_states.shape
    E, N, _ = w1.shape

    if routing_func != torch.topk:
        topk_weights, topk_ids = routing_func(gating_output, topk)
    elif is_hip():
        # The MoE kernels are not yet supported on ROCm.
        routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
        topk_weights, topk_ids = routing_func(routing_weights, topk)
    else:
        import vllm._moe_C as moe_kernels

        topk_weights = torch.empty(
            M, topk, dtype=torch.float32, device=hidden_states.device
        )
        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
        token_expert_indicies = torch.empty(
            M, topk, dtype=torch.int32, device=hidden_states.device
        )
        moe_kernels.topk_softmax(
            topk_weights,
            topk_ids,
            token_expert_indicies,
            gating_output.float(),  # TODO(woosuk): Optimize this.
        )
        del token_expert_indicies  # Not used. Will be used in the future.
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    if override_config:
        config = override_config
    else:
        # First try to load optimal config from the file
        configs = get_moe_configs(E, w2.shape[2], None)

        if configs:
            # If an optimal configuration map has been found, look up the
            # optimal config
            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
        else:
            # Else use the default config
            config = {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 32,
                "GROUP_SIZE_M": 8,
            }

            if M <= E:
                config = {
                    "BLOCK_SIZE_M": 16,
                    "BLOCK_SIZE_N": 32,
                    "BLOCK_SIZE_K": 64,
                    "GROUP_SIZE_M": 1,
                }

    if M == 1:
        # expert, hs1, hs2
        topk_w1 = w1.view(torch.uint8)[topk_ids.flatten()]
        topk_w2 = w2.view(torch.uint8)[topk_ids.flatten()]
        topk_ids = torch.arange(
            topk, device=topk_ids.device, dtype=topk_ids.dtype
        ).unsqueeze(0)

        E = topk

        w1_scale = w1_scale[topk_ids.flatten()]
        w1 = dequantize_fp8(topk_w1, w1_scale, dtype=hidden_states.dtype)

    else:
        w1 = dequantize_fp8(w1, w1_scale, dtype=hidden_states.dtype)

    use_fp8 = False
    w1_scale = None
    a1_scale = None
    a2_scale = None

    intermediate_cache1 = torch.empty(
        (M, topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N // 2),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache3 = torch.empty(
        (M, topk_ids.shape[1], w2.shape[1]),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, config["BLOCK_SIZE_M"], E
    )
    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16

    invoke_fused_moe_kernel(
        hidden_states,
        w1,
        intermediate_cache1,
        a1_scale,
        w1_scale,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        False,
        topk_ids.shape[1],
        config,
        compute_type=compute_type,
        use_fp8=use_fp8,
    )

    del w1

    if M == 1:
        w2_scale = w2_scale[topk_ids.flatten()]
        w2 = dequantize_fp8(topk_w2, w2_scale, dtype=hidden_states.dtype)
    else:
        w2 = dequantize_fp8(w2, w2_scale, dtype=hidden_states.dtype)

    w2_scale = None

    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

    invoke_fused_moe_kernel(
        intermediate_cache2,
        w2,
        intermediate_cache3,
        a2_scale,
        w2_scale,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        True,
        1,
        config,
        compute_type=compute_type,
        use_fp8=use_fp8,
    )

    del w2

    if inplace:
        return torch.sum(
            intermediate_cache3.view(*intermediate_cache3.shape),
            dim=1,
            out=hidden_states,
        )
    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
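
For orientation, a minimal usage sketch of this Ampere path (not part of the commit). It assumes the module lives at vllm.model_executor.layers.fused_moe.ampere_fp8_fused_moe (the path mixtral.py imports below), that expert weights are stored as torch.float8_e4m3fn with one float32 scale per expert, and uses made-up tensor sizes; shapes follow the assertions above (hidden_states is (M, K), w1 is (E, N, K), w2 is (E, K, N // 2)).

import torch
from vllm.model_executor.layers.fused_moe import ampere_fp8_fused_moe

# Illustrative sizes only: M tokens, hidden size K, 2x intermediate size N,
# E experts, top-k routing.
M, K, N, E, topk = 4, 1024, 2048, 8, 2
device = "cuda"

hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device=device)
gating_output = torch.randn(M, E, dtype=torch.bfloat16, device=device)

# Random weights quantized to FP8 E4M3, with a per-expert scale of 1.0 for simplicity.
w1 = torch.randn(E, N, K, device=device).to(torch.float8_e4m3fn)
w2 = torch.randn(E, K, N // 2, device=device).to(torch.float8_e4m3fn)
w1_scale = torch.ones(E, dtype=torch.float32, device=device)
w2_scale = torch.ones(E, dtype=torch.float32, device=device)

out = ampere_fp8_fused_moe.fused_moe(
    hidden_states, w1, w2, gating_output, topk,
    renormalize=True, use_fp8=True,
    w1_scale=w1_scale, w2_scale=w2_scale,
)
print(out.shape)  # torch.Size([4, 1024])

Internally the weights are dequantized to half/bfloat16 with the cupy kernels before the Triton MoE kernels run, so use_fp8 is forced off for the actual matmuls; the M == 1 branch only dequantizes the top-k experts to cut decode-time overhead.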

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def get_min_capability(cls) -> int:
         # TODO: PyTorch 2.3.0+ is required to run FP8 on
         # SM 89 (e.g. Ada) GPUs. Specifically, this PR has to
         # be included: https://github.com/pytorch/pytorch/pull/118881
-        return 90
+        return 80

     @classmethod
     def get_config_filenames(cls) -> List[str]:
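
Lowering get_min_capability from 90 to 80 lets Fp8Config load on A100 (SM 8.0) rather than only Hopper (SM 9.0): vLLM compares this value against the device's compute capability expressed as major * 10 + minor. A quick illustrative check of the value being compared (not part of the commit):

import torch

# An A100 reports compute capability (8, 0) -> 80; an H100 reports (9, 0) -> 90.
major, minor = torch.cuda.get_device_capability(0)
capability = major * 10 + minor
print(capability, capability >= 80)  # on A100: "80 True" (it would fail the old 90 gate)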

vllm/model_executor/models/mixtral.py

Lines changed: 11 additions & 0 deletions
@@ -54,6 +54,16 @@
 from vllm.utils import print_warning_once


+def is_sm80(device_id=0):
+    if not torch.cuda.is_available():
+        return False
+    device_properties = torch.cuda.get_device_properties(device_id)
+    return (device_properties.major == 8 and device_properties.minor == 0)
+
+
+if is_sm80():
+    from vllm.model_executor.layers.fused_moe import ampere_fp8_fused_moe
+    fused_moe = ampere_fp8_fused_moe.fused_moe
+
 logger = logging.get_logger(__name__)


@@ -248,6 +258,7 @@ def __init__(
         # FIXME(pcmoritz): Make this more general to support different
         # quantization schemes
         self.use_fp8 = isinstance(quant_config, Fp8Config)
+        assert self.use_fp8, "USE FP8"

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
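
The module-level is_sm80() check rebinds fused_moe to the Ampere dequantize-then-compute variant only on devices reporting compute capability exactly 8.0, so other GPUs keep the original path; the added assert, however, means the MoE layer now refuses to run without an Fp8Config. A minimal, hedged example of loading the model with FP8 quantization requested (the model name and arguments here are illustrative, not taken from this commit):

from vllm import LLM, SamplingParams

# With the new assert, the MoE layer requires quant_config to be an Fp8Config,
# so FP8 quantization has to be requested explicitly at load time.
llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", quantization="fp8")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)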
