
Commit 8ceca57

Update w4a8_dynamic.py
1 parent 2235c42 commit 8ceca57

File tree

1 file changed: 7 additions, 29 deletions


vllm_ascend/quantization/w4a8_dynamic.py

Lines changed: 7 additions & 29 deletions
@@ -275,9 +275,9 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
             global_num_experts=global_num_experts)
 
-        fused_moe_state = get_forward_context().moe_comm_method_name
+        fused_moe_state = get_forward_context().fused_moe_state
         shared_gate_up, shared_dequant_scale = None, None
-        if shared_experts is not None and fused_moe_state == "mc2commimpl":
+        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
             share_up_out, _ = shared_experts.gate_up_proj(
                 (quantized_x_for_share, dynamic_scale_for_share))
             shared_gate_up, shared_dequant_scale = share_up_out[
@@ -290,10 +290,8 @@ def apply(
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
 
         topk_weights = topk_weights.to(x.dtype)
-
-        moe_comm_method = get_forward_context().moe_comm_method
-        print("using w4a8")
-        return moe_comm_method.fused_experts(
+
+        return unified_fused_experts_eager(
             hidden_states=x,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
@@ -304,34 +302,14 @@ def apply(
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             row_idx=row_idx,
-            use_int4_w4a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
             global_redundant_expert_num=global_redundant_expert_num,
             shared_experts=shared_experts,
             shared_gate_up=shared_gate_up,
-            shared_dequant_scale=shared_dequant_scale
-        )
-
-        # return unified_fused_experts_eager(
-        #     hidden_states=x,
-        #     w1=layer.w13_weight,
-        #     w2=layer.w2_weight,
-        #     w1_scale=layer.w13_weight_scale_second,
-        #     w2_scale=layer.w2_weight_scale_second,
-        #     w1_scale_bias=layer.w13_scale_bias,
-        #     w2_scale_bias=layer.w2_scale_bias,
-        #     topk_weights=topk_weights,
-        #     topk_ids=topk_ids,
-        #     row_idx=row_idx,
-        #     expert_map=expert_map,
-        #     log2phy=log2phy,
-        #     global_redundant_expert_num=global_redundant_expert_num,
-        #     shared_experts=shared_experts,
-        #     shared_gate_up=shared_gate_up,
-        #     shared_dequant_scale=shared_dequant_scale,
-        #     mc2_mask=kwargs.get("mc2_mask", None),
-        #     with_quant=True)
+            shared_dequant_scale=shared_dequant_scale,
+            mc2_mask=kwargs.get("mc2_mask", None),
+            with_quant=True)
 
     def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
         group_num, k, n = weight.shape
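
One substantive change in this commit is that the shared-experts guard now compares against an enum member (FusedMoEState.MC2) instead of matching the lowercase class-name string "mc2commimpl". A minimal sketch of that pattern follows; the FusedMoEState members other than MC2, and the helper function name, are hypothetical illustrations, not taken from vllm_ascend.

    from enum import Enum

    class FusedMoEState(Enum):
        # Only MC2 appears in this diff; the other members are hypothetical
        # placeholders for whatever states vllm_ascend actually defines.
        AllGather = 0
        All2All = 1
        MC2 = 2

    def wants_shared_expert_fusion(shared_experts, fused_moe_state) -> bool:
        # Comparing enum members is typo-safe and survives renames better
        # than matching the string "mc2commimpl" did in the old code.
        return (shared_experts is not None
                and fused_moe_state == FusedMoEState.MC2)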
