Commit 4812f00

feat(graph): Refactor and optimize MoE with unified W8A8 support
Refactors the Fused MoE implementation by unifying the quantized and non-quantized execution paths into a single `fused_experts` function. This simplifies the codebase and centralizes MoE logic.

Adds support for W8A8 dynamic quantization within the unified MoE kernel. Communication methods are updated to handle dynamic scales for quantized activations.

Additionally, this change introduces a weight pre-processing step that transposes and converts weights to the `NZ` format, optimizing `matmul` performance on NPU hardware.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent dfc7eb3 commit 4812f00
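
For orientation, the sketch below condenses the unified expert MLP that this change introduces (the full version is in the `vllm_ascend/ops/common_fused_moe.py` diff further down). It is illustrative only: the wrapper name and argument list are invented here, it assumes an Ascend NPU environment with `torch_npu` available, and the kernel calls mirror the ones added in the diff.

import torch
import torch_npu  # Ascend NPU extension; the calls below mirror the diff

def grouped_expert_mlp(permuted, expert_tokens, group_list_type, w1, w2,
                       w1_scale=None, w2_scale=None, dynamic_scale=None,
                       use_int8_w8a8=False):
    """Condensed sketch of the unified expert MLP (illustrative wrapper name)."""
    # Gate/up projection for all experts in a single grouped matmul.
    gate_up = torch_npu.npu_grouped_matmul(
        x=[permuted], weight=[w1], split_item=2, group_type=0,
        group_list_type=group_list_type, group_list=expert_tokens,
        output_dtype=torch.int32 if use_int8_w8a8 else None)[0]

    if use_int8_w8a8:
        # W8A8: dequantize, apply SwiGLU, and requantize in one fused kernel.
        act, act_scale = torch_npu.npu_dequant_swiglu_quant(
            x=gate_up, weight_scale=w1_scale.to(torch.float32),
            activation_scale=dynamic_scale, bias=None, quant_scale=None,
            quant_offset=None, group_index=expert_tokens, activate_left=True,
            quant_mode=1)
    else:
        # Unquantized: plain SwiGLU, no activation scale.
        act, act_scale = torch_npu.npu_swiglu(gate_up), None

    # Down projection; the W8A8 path dequantizes via weight and per-token scales.
    return torch_npu.npu_grouped_matmul(
        x=[act], weight=[w2],
        scale=[w2_scale] if use_int8_w8a8 else None,
        per_token_scale=[act_scale] if use_int8_w8a8 else None,
        split_item=2, group_type=0, group_list_type=group_list_type,
        group_list=expert_tokens,
        output_dtype=w2_scale.dtype if use_int8_w8a8 else None)[0]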

File tree

3 files changed: +94 −21 lines

  vllm_ascend/distributed/moe_comm_method.py
  vllm_ascend/ops/common_fused_moe.py
  vllm_ascend/quantization/w8a8_dynamic.py

vllm_ascend/distributed/moe_comm_method.py

Lines changed: 9 additions & 5 deletions
@@ -54,6 +54,7 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
+        use_a8: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
         """Pre-process before MLP.

@@ -159,6 +160,7 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,  # noqa: F841
         num_experts: int,
+        use_a8: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
         num_tokens = hidden_states.shape[0]

@@ -194,7 +196,7 @@ def permute(

         group_list_type = 1  # `count` mode

-        return permuted_hidden_states, expert_tokens, group_list_type
+        return permuted_hidden_states, expert_tokens, None, group_list_type

     def unpermute(self, mlp_output: torch.Tensor,
                   hidden_states: torch.Tensor) -> None:
@@ -219,6 +221,7 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
+        use_a8: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
         num_tokens = hidden_states.shape[0]

@@ -269,7 +272,7 @@ def permute(

         group_list_type = 1  # `count` mode

-        return permuted_hidden_states, expert_tokens, group_list_type
+        return permuted_hidden_states, expert_tokens, None, group_list_type

     def unpermute(self, mlp_output: torch.Tensor,
                   hidden_states: torch.Tensor) -> None:
@@ -375,6 +378,7 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
+        use_a8: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
         # Store tensors needed for post_process
         self.topk_ids = topk_ids
@@ -388,7 +392,7 @@ def permute(
             "moe_expert_num": self.moe_config.num_experts,
             "global_bs": 0,
             "scales": None,
-            "quant_mode": 0,
+            "quant_mode": 2 if use_a8 else 0,
             "group_ep": self.mc2_comm_name,
             "ep_world_size": self.moe_config.ep_size,
             "ep_rank_id": self.moe_config.ep_rank,
@@ -409,7 +413,7 @@ def permute(

         (
             permuted_hidden_states,
-            _,  # dynamic_scale is not used
+            dynamic_scale,
             self.assist_info_for_combine,
             expert_tokens,
             self.ep_recv_counts,
@@ -418,7 +422,7 @@ def permute(

         group_list_type = 1

-        return permuted_hidden_states, expert_tokens, group_list_type
+        return permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type

     def unpermute(self, mlp_output: torch.Tensor,
                   hidden_states: torch.Tensor) -> None:
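
A minimal caller-side sketch of the updated `permute()` contract, assuming an Ascend NPU environment with `torch_npu`. The helper name and variables are illustrative; only the `permute()` call and the `npu_dynamic_quant` fallback follow the diff (the MC2 backend quantizes during dispatch with `quant_mode=2` and returns per-token scales, the other backends return `None`).

import torch_npu

def permute_and_maybe_quantize(comm, hidden_states, topk_ids, topk_weights,
                               expert_map, num_experts, use_a8):
    # permute() now returns a 4-tuple that includes the dynamic scale
    # produced during dispatch (or None if the backend does not quantize).
    permuted, expert_tokens, dynamic_scale, group_list_type = comm.permute(
        hidden_states, topk_ids, topk_weights, expert_map, num_experts, use_a8)
    # If the backend did not quantize, quantize the permuted activations here.
    if use_a8 and dynamic_scale is None:
        permuted, dynamic_scale = torch_npu.npu_dynamic_quant(permuted)
    return permuted, expert_tokens, dynamic_scale, group_list_type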

vllm_ascend/ops/common_fused_moe.py

Lines changed: 69 additions & 16 deletions
@@ -18,6 +18,7 @@
 from typing import Any, Callable, Optional

 import torch
+import torch_npu
 from vllm.config import CompilationLevel, get_current_vllm_config
 from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
 from vllm.forward_context import get_forward_context
@@ -31,7 +32,7 @@
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe import apply_mlp, fused_experts_moge
 from vllm_ascend.ops.layers.experts_selector import select_experts
-from vllm_ascend.utils import is_310p
+from vllm_ascend.utils import is_310p, ACL_FORMAT_FRACTAL_NZ

 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__

@@ -52,7 +53,6 @@ def fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     w1_scale_bias: torch.Tensor = None,
     w2_scale_bias: torch.Tensor = None,
-    moe_comm_method: Optional[MoECommMethod] = None,
     # For TorchAir graph
     is_torchair: bool = False,
     # For Cube/Vector parallel
@@ -64,8 +64,8 @@ def fused_experts(
     global_redundant_expert_num: int = 0,
 ) -> torch.Tensor:
     # Check constraints
-    assert hidden_states.shape[1] == w1.shape[2], (
-        f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
+    assert hidden_states.shape[1] == w1.shape[1], (
+        f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}")

     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
@@ -74,20 +74,58 @@ def fused_experts(
     assert hidden_states.dtype in [
         torch.float32, torch.float16, torch.bfloat16
     ]
+
+    moe_comm_method = get_forward_context().moe_comm_method
     assert moe_comm_method is not None, "Missing communication context"

     num_experts = w1.shape[0]

-    permuted_hidden_states, expert_tokens, group_list_type = moe_comm_method.permute(
-        hidden_states, topk_ids, topk_weights, expert_map, num_experts)
-    mlp_output = apply_mlp(
-        permuted_hidden_states,
-        w1,
-        w2,
-        expert_tokens,
+    permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = moe_comm_method.permute(
+        hidden_states, topk_ids, topk_weights, expert_map, num_experts, use_int8_w8a8 or use_int4_w4a8)
+
+    if (use_int8_w8a8 or use_int4_w4a8) and dynamic_scale is None:
+        permuted_hidden_states, dynamic_scale = torch_npu.npu_dynamic_quant(
+            permuted_hidden_states)
+
+    gate_up_output = torch_npu.npu_grouped_matmul(
+        x=[permuted_hidden_states],
+        weight=[w1],
+        split_item=2,
         group_list_type=group_list_type,
-    )
-    moe_comm_method.unpermute(mlp_output, hidden_states)
+        group_type=0,
+        group_list=expert_tokens,
+        output_dtype=torch.int32 if use_int8_w8a8 else None,
+    )[0]
+
+    if use_int8_w8a8:
+        activated_output, activated_output_scale = torch_npu.npu_dequant_swiglu_quant(
+            x=gate_up_output,
+            weight_scale=w1_scale.to(torch.float32),
+            activation_scale=dynamic_scale,
+            bias=None,
+            quant_scale=None,
+            quant_offset=None,
+            group_index=expert_tokens,
+            activate_left=True,
+            quant_mode=1,
+        )
+    else:
+        activated_output = torch_npu.npu_swiglu(gate_up_output)
+        activated_output_scale = None
+
+    down_output = torch_npu.npu_grouped_matmul(
+        x=[activated_output],
+        weight=[w2],
+        scale=[w2_scale] if use_int8_w8a8 else None,
+        per_token_scale=[activated_output_scale] if use_int8_w8a8 else None,
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=expert_tokens,
+        output_dtype=w2_scale.dtype if use_int8_w8a8 else None,
+    )[0]
+
+    moe_comm_method.unpermute(down_output, hidden_states)

     return hidden_states

@@ -156,8 +194,6 @@ def forward_oot(
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input)

-    moe_comm_method = get_forward_context().moe_comm_method
-
     return fused_experts(
         hidden_states=x,
         w1=layer.w13_weight,
@@ -166,10 +202,26 @@ def forward_oot(
         topk_ids=topk_ids,
         global_num_experts=global_num_experts,
         expert_map=expert_map,
-        moe_comm_method=moe_comm_method,
     )


+def process_weights_after_loading(self, layer):
+    super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
+    w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
+        1, 2).contiguous()
+    layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
+
+    w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
+        1, 2).contiguous()
+    layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
+
+    if not is_310p():
+        layer.w13_weight.data = torch_npu.npu_format_cast(
+            layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
+        layer.w2_weight.data = torch_npu.npu_format_cast(
+            layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
+
+
 class AscendFusedMoE(FusedMoE):

     def __init__(
@@ -281,4 +333,5 @@ def forward_impl(self, hidden_states: torch.Tensor,


 UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
+UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
 UnquantizedFusedMoEMethod.forward_oot = forward_oot
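
The weight pre-processing added above, extracted into a hypothetical standalone helper for illustration. The `_maybe_pad_weight` step is omitted and `preprocess_expert_weight` is not a name from the commit; the transpose and format cast follow the diff. Putting the hidden dimension on axis 1 is what lets the hidden-size assert in `fused_experts` check `w1.shape[1]`.

import torch
import torch_npu
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p

def preprocess_expert_weight(weight: torch.nn.Parameter) -> torch.nn.Parameter:
    # Swap the last two axes so the layout becomes (num_experts, hidden_size, n),
    # matching the updated `hidden_states.shape[1] == w1.shape[1]` assertion.
    data = weight.data.transpose(1, 2).contiguous()
    param = torch.nn.Parameter(data, requires_grad=False)
    # Cast to the NZ (fractal) layout to speed up matmul on NPU hardware;
    # skipped on 310P, mirroring the `is_310p()` guard in the diff.
    if not is_310p():
        param.data = torch_npu.npu_format_cast(param.data, ACL_FORMAT_FRACTAL_NZ)
    return param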

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 16 additions & 0 deletions
@@ -26,6 +26,8 @@
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
+from vllm_ascend.ops.common_fused_moe import \
+    fused_experts as unified_fused_experts
 from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, dispose_tensor

@@ -375,6 +377,20 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
             global_num_experts=global_num_experts)

+        moe_comm_method = get_forward_context().moe_comm_method
+
+        return unified_fused_experts(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            use_int8_w8a8=True,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            expert_map=expert_map,
+        )
+
         fused_moe_state = get_forward_context().fused_moe_state
         shared_gate_up, shared_dequant_scale = None, None
         if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
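
With this change both the unquantized and the W8A8 dynamic methods funnel into the same kernel. A sketch of the two call shapes, where `x`, `layer`, and the routing tensors are placeholders standing in for the surrounding method context:

from vllm_ascend.ops.common_fused_moe import fused_experts

# Unquantized path (UnquantizedFusedMoEMethod.forward_oot):
out = fused_experts(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight,
                    topk_weights=topk_weights, topk_ids=topk_ids,
                    global_num_experts=global_num_experts, expert_map=expert_map)

# W8A8 dynamic path (the apply() above): same kernel, plus weight scales.
out = fused_experts(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight,
                    topk_weights=topk_weights, topk_ids=topk_ids,
                    use_int8_w8a8=True, w1_scale=layer.w13_weight_scale,
                    w2_scale=layer.w2_weight_scale, expert_map=expert_map)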
