Commit c94d4c0

kliuae, tjtanaa, and vllmellm authored and committed
[ROCm] Add aiter tkw1 kernel for Llama4 fp8 (vllm-project#16727)
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
1 parent 705f5f1 commit c94d4c0

File tree

6 files changed: +136 −50 lines changed

docker/Dockerfile.rocm_base
vllm/envs.py
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
vllm/model_executor/layers/quantization/fp8.py


docker/Dockerfile.rocm_base

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="1a7f4dfa"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="8970b25b"
+ARG AITER_BRANCH="5a77249"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"

 FROM ${BASE_IMAGE} AS base

vllm/envs.py

Lines changed: 0 additions & 8 deletions
@@ -77,7 +77,6 @@
     VLLM_ROCM_USE_AITER: bool = False
     VLLM_ROCM_USE_AITER_LINEAR: bool = True
     VLLM_ROCM_USE_AITER_MOE: bool = True
-    VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True

@@ -546,13 +545,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
              ("true", "1")),

-    # Whether to use aiter block scaled moe kernel.
-    # By default this is disabled.
-    "VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
-    lambda:
-    (os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in
-     ("true", "1")),
-
     # use aiter rms norm op if aiter ops are enabled.
     "VLLM_ROCM_USE_AITER_RMSNORM":
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 3 additions & 3 deletions
@@ -23,9 +23,7 @@
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op

-from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
-                                   rocm_aiter_fused_experts,
-                                   rocm_aiter_topk_softmax)
+from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled

 logger = init_logger(__name__)

@@ -846,6 +844,7 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,

 def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
     if is_rocm_aiter_moe_enabled():
+        from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
         return rocm_aiter_topk_softmax
     return vllm_topk_softmax

@@ -1102,6 +1101,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:

 def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
     if is_rocm_aiter_moe_enabled():
+        from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
        return rocm_aiter_fused_experts
     if inplace:
         return torch_vllm_inplace_fused_experts
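The aiter-backed functions are now imported inside the dispatch functions instead of at module scope, so `aiter` only has to be importable when the ROCm AITER path is actually taken. A stripped-down sketch of that pattern; the module and function names below are illustrative, not vLLM's real layout:

from typing import Callable

AITER_ENABLED = False  # stand-in for is_rocm_aiter_moe_enabled()

def default_topk_softmax(*args, **kwargs):
    """Portable fallback used on every platform."""
    raise NotImplementedError

def dispatch_topk_func() -> Callable:
    if AITER_ENABLED:
        # Lazy import: only evaluated when the optional backend is enabled,
        # so merely importing this module never requires the backend to exist.
        from aiter_backend import aiter_topk_softmax  # hypothetical module
        return aiter_topk_softmax
    return default_topk_softmax

print(dispatch_topk_func().__name__)  # -> "default_topk_softmax"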

vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

Lines changed: 107 additions & 33 deletions
@@ -10,28 +10,68 @@
 def is_rocm_aiter_moe_enabled() -> bool:
     return current_platform.is_rocm() \
         and envs.VLLM_ROCM_USE_AITER_MOE \
-        and envs.VLLM_ROCM_USE_AITER \
-
-
-def is_rocm_aiter_block_scaled_moe_enabled() -> bool:
-    return is_rocm_aiter_moe_enabled() and \
-        envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
+        and envs.VLLM_ROCM_USE_AITER
+
+
+def rocm_aiter_asm_moe_tkw1(hidden_states,
+                            w1,
+                            w2,
+                            topk_weight,
+                            topk_ids,
+                            fc1_scale=None,
+                            fc2_scale=None,
+                            fc1_smooth_scale=None,
+                            fc2_smooth_scale=None,
+                            a16=False,
+                            per_tensor_quant_scale=None,
+                            expert_mask=None,
+                            activation_str: str = "silu") -> None:
+
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe_tkw1
+
+    activation = \
+        ActivationType.Gelu if activation_str == "gelu" else ActivationType.Silu
+
+    return asm_moe_tkw1(hidden_states,
+                        w1,
+                        w2,
+                        topk_weight,
+                        topk_ids,
+                        fc1_scale=fc1_scale,
+                        fc2_scale=fc2_scale,
+                        fc1_smooth_scale=fc1_smooth_scale,
+                        fc2_smooth_scale=fc2_smooth_scale,
+                        a16=a16,
+                        per_tensor_quant_scale=per_tensor_quant_scale,
+                        expert_mask=expert_mask,
+                        activation=activation)


 def rocm_aiter_fused_experts(
-    *,
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    use_fp8_w8a8: bool = False,
-    apply_router_weight_on_input: bool = False,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    block_shape: Optional[List[int]] = None,
-    expert_mask: Optional[torch.Tensor] = None,
-    **kwagrs  # Ignore additional keyword arguments
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
+    allow_deep_gemm: bool = False,
 ) -> torch.Tensor:

     import aiter as rocm_aiter

@@ -40,25 +80,21 @@ def rocm_aiter_fused_experts(
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)

-    if apply_router_weight_on_input:
-        assert (topk_weights.dim() == 2
-                ), "`topk_weights` should be in shape (num_tokens, topk)"
-        _, topk = topk_weights.shape
-        assert (
-            topk == 1
-        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+    # All AITER Fused MoE kernels are expecting the following datatypes
+    topk_weights = topk_weights.to(torch.float32)
+    topk_ids = topk_ids.to(torch.int32)

-        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
-        topk_ids = topk_ids.to(torch.int32)
-        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+    if (block_shape is not None) and use_fp8_w8a8:
+        assert not apply_router_weight_on_input, (
+            "apply_router_weight_on_input is not supported for block scaled moe"
+        )

-    if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
         assert w1_scale is not None
         assert w2_scale is not None

         local_E = E = w1.shape[0]
-        if expert_mask is not None:
-            E = expert_mask.numel()
+        if expert_map is not None:
+            E = expert_map.numel()

         topk = topk_ids.shape[1]
         model_dim = w1.shape[-1]

@@ -80,7 +116,7 @@ def rocm_aiter_fused_experts(
             E,
             model_dim,
             dtype,
-            expert_mask=expert_mask)
+            expert_mask=expert_map)

         a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
         rocm_aiter.fmoe_fp8_blockscale_g1u1(

@@ -102,7 +138,33 @@
         )
         return out_asm

+    elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
+        # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
+        # This applies topk_weights on the GEMM output of the first FC layer
+        # rather than the second FC.
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        assert topk_weights.shape[-1] == 1, (
+            "Only support topk=1 when"
+            " `apply_router_weight_on_input` is True")
+
+        return rocm_aiter_asm_moe_tkw1(hidden_states,
+                                       w1,
+                                       w2,
+                                       topk_weights,
+                                       topk_ids,
+                                       fc1_scale=w1_scale,
+                                       fc2_scale=w2_scale,
+                                       fc1_smooth_scale=None,
+                                       fc2_smooth_scale=None,
+                                       a16=False,
+                                       per_tensor_quant_scale=None,
+                                       expert_mask=expert_map,
+                                       activation_str=activation)
+
     elif use_fp8_w8a8:
+        assert not apply_router_weight_on_input, (
+            "apply_router_weight_on_input is not supported for fp8_w8a8")
         return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
                                            w1=w1,
                                            w2=w2,

@@ -114,6 +176,18 @@ def rocm_aiter_fused_experts(
                                            fc2_smooth_scale=None,
                                            a16=False)

+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+
     return rocm_aiter.ck_moe(hidden_states=hidden_states,
                              w1=w1,
                              w2=w2,
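The comment in the new tkw1 branch is the key idea: when the router weight must be applied on the expert input (topk = 1, as in Llama4), multiplying the output of the first, bias-free GEMM by that weight is mathematically the same as scaling the input, whereas scaling the second FC's output is not once a non-linear activation sits in between. A small numerical sketch of that identity in plain PyTorch, not the AITER kernel; a single expert with an ungated SiLU MLP is assumed for brevity:

import torch

torch.manual_seed(0)
x = torch.randn(4, 16)        # (num_tokens, hidden), all routed to one expert
router_w = torch.rand(4, 1)   # topk = 1 router weight per token
fc1 = torch.nn.Linear(16, 32, bias=False)
fc2 = torch.nn.Linear(32, 16, bias=False)
act = torch.nn.functional.silu

# apply_router_weight_on_input semantics: weight the expert *input*.
ref = fc2(act(fc1(x * router_w)))

# tkw1 behaviour: weight the *output of the first GEMM* instead.
tkw1 = fc2(act(fc1(x) * router_w))

# Weighting the second FC's output is NOT equivalent once the
# non-linear activation sits between the two GEMMs.
post = fc2(act(fc1(x))) * router_w

print(torch.allclose(ref, tkw1, atol=1e-6))   # True
print(torch.allclose(ref, post, atol=1e-6))   # generally False

The fallback ck_moe path at the end of the diff achieves the same semantics outside the kernel by pre-scaling hidden_states and resetting topk_weights to ones.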

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 23 additions & 2 deletions
@@ -250,6 +250,28 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                     requires_grad=False)

+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            is_rocm_aiter_moe_enabled)
+
+        # Property to determine if AITER is used
+        if is_rocm_aiter_moe_enabled():
+            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa E501
+                rocm_aiter_fused_experts, shuffle_weights)
+
+            # reshaping weights is required for aiter moe kernel.
+            shuffled_w13, shuffled_w2 = shuffle_weights(
+                layer.w13_weight.data, layer.w2_weight.data)
+
+            layer.w13_weight = torch.nn.Parameter(shuffled_w13,
+                                                  requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
+                                                 requires_grad=False)
+
+            self.fused_experts_func = rocm_aiter_fused_experts
+        else:
+            from vllm.model_executor.layers.fused_moe import fused_experts
+            self.fused_experts_func = fused_experts
+
     def apply(
         self,
         layer: torch.nn.Module,

@@ -268,7 +290,6 @@ def apply(
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe import fused_experts

         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,

@@ -282,7 +303,7 @@ def apply(
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)

-        return fused_experts(
+        return self.fused_experts_func(
             x,
             layer.w13_weight,
             layer.w2_weight,
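The method now binds the expert kernel once, after weights are processed, instead of importing `fused_experts` on every forward call. Condensed, the pattern looks roughly like this; a hypothetical class, not the real compressed-tensors method, with the weight shuffling from the diff elided:

import torch

class MoEMethodSketch:
    """Illustrative only: choose the fused-experts entry point at load time."""

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled)
        if is_rocm_aiter_moe_enabled():
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
                rocm_aiter_fused_experts)
            # The AITER path also expects shuffled w13/w2 weights (see diff above).
            self.fused_experts_func = rocm_aiter_fused_experts
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts
            self.fused_experts_func = fused_experts

    def apply(self, layer, x, topk_weights, topk_ids):
        # Both entry points take the same leading positional arguments,
        # so apply() no longer needs to know which backend was picked.
        return self.fused_experts_func(x, layer.w13_weight, layer.w2_weight,
                                       topk_weights, topk_ids)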

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 2 additions & 3 deletions
@@ -575,8 +575,7 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
-            is_rocm_aiter_moe_enabled, shuffle_weights)
+            expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)

         # TODO (rob): refactor block quant into separate class.
         if self.block_quant:

@@ -603,7 +602,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             layer.w2_weight = Parameter(w2_weight, requires_grad=False)
             layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                   requires_grad=False)
-            if is_rocm_aiter_block_scaled_moe_enabled():
+            if is_rocm_aiter_moe_enabled():
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = shuffle_weights(
                     layer.w13_weight.data, layer.w2_weight.data)
