@@ -54,7 +54,7 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-        use_a8: bool,
+        apply_a8_quantization: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         """Pre-process before MLP.
 
@@ -65,6 +65,7 @@ def permute(
             expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
                 Mapping from global expert IDs to local expert IDs.
             num_experts (int): Number of local experts (experts on this device).
+            apply_a8_quantization (bool): Whether to apply A8 activation quantization (W4A8 and W8A8).
 
         Returns:
-            tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing:
+            tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: A tuple containing:
@@ -73,6 +74,8 @@ def permute(
                 hidden_states based on topk_ids.
             - expert_tokens (torch.Tensor): Tensor of shape (num_experts, )
                 Number of tokens assigned to each expert.
+            - dynamic_scale (torch.Tensor, optional): Tensor of shape (num_experts, )
+                Dynamic scale for each expert, used for quantization.
             - group_list_type (int): Type of group list, 0 for `cumsum`
                 and 1 for `count`. This is mainly for `npu_grouped_matmul`
                 to determine how to handle the output.
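
To make the `group_list_type` convention above concrete, here is a minimal sketch of the two encodings; the tensor values are illustrative, and only the 0 = `cumsum` / 1 = `count` mapping comes from the docstring.

```python
import torch

# Illustrative counts of tokens routed to each of 4 local experts.
expert_tokens = torch.tensor([3, 0, 5, 2])

group_list_count = expert_tokens             # group_list_type == 1: per-expert counts
group_list_cumsum = expert_tokens.cumsum(0)  # group_list_type == 0: running offsets
# group_list_cumsum -> tensor([3, 3, 8, 10]); both encodings describe the same
# row boundaries that npu_grouped_matmul uses to split the permuted tokens.
```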
@@ -160,7 +163,7 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,  # noqa: F841
         num_experts: int,
-        use_a8: bool,
+        apply_a8_quantization: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         num_tokens = hidden_states.shape[0]
 
@@ -221,7 +224,7 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-        use_a8: bool,
+        apply_a8_quantization: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         num_tokens = hidden_states.shape[0]
 
@@ -378,7 +381,7 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-        use_a8: bool,
+        apply_a8_quantization: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         # Store tensors needed for post_process
         self.topk_ids = topk_ids
@@ -392,7 +395,7 @@ def permute(
             "moe_expert_num": self.moe_config.num_experts,
             "global_bs": 0,
             "scales": None,
-            "quant_mode": 2 if use_a8 else 0,
+            "quant_mode": 2 if apply_a8_quantization else 0,
             "group_ep": self.mc2_comm_name,
             "ep_world_size": self.moe_config.ep_size,
             "ep_rank_id": self.moe_config.ep_rank,
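
For context, a hedged sketch of how a caller might derive `apply_a8_quantization`; the `select_quant_mode` helper and the `quant_scheme` argument are hypothetical, while the W4A8/W8A8 pairing and the `2 if ... else 0` mapping come from this diff and its docstring.

```python
def select_quant_mode(quant_scheme: str) -> int:
    # Hypothetical helper, not part of this PR. A8 activation quantization
    # applies to the W4A8 and W8A8 schemes; quant_mode 2 appears to select
    # dynamic quantization in the MC2 dispatch op, 0 disables it.
    apply_a8_quantization = quant_scheme in ("W4A8", "W8A8")
    return 2 if apply_a8_quantization else 0

assert select_quant_mode("W8A8") == 2
assert select_quant_mode("W8A16") == 0
```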
@@ -536,13 +539,15 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-        use_a8: bool,
+        apply_a8_quantization: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
-        results = self.token_dispatcher.token_dispatch(hidden_states,
-                                                       topk_weights,
-                                                       topk_ids,
-                                                       None,
-                                                       log2phy=None)
+        results = self.token_dispatcher.token_dispatch(
+            hidden_states,
+            topk_weights,
+            topk_ids,
+            None,
+            log2phy=None,
+            with_quant=apply_a8_quantization)
         return results["hidden_states"], results["group_list"], results[
             "dynamic_scale"], results["group_list_type"]
 
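Putting the pieces together, a usage sketch of the updated `permute` contract; the `dispatcher` object and the two leading parameters are assumptions, since the hunks above never show the start of the signature.

```python
import torch

# Illustrative inputs; shapes follow the docstring, values are arbitrary.
num_tokens, hidden_dim, topk, num_experts = 8, 16, 2, 4
hidden_states = torch.randn(num_tokens, hidden_dim)
topk_ids = torch.randint(0, num_experts, (num_tokens, topk))
topk_weights = torch.rand(num_tokens, topk)
expert_map = torch.arange(num_experts)

# `dispatcher` stands in for whichever class in this file provides permute;
# the two leading positional parameters are assumed.
out, group_list, dynamic_scale, group_list_type = dispatcher.permute(
    hidden_states, topk_ids, topk_weights, expert_map, num_experts,
    apply_a8_quantization=True)

# dynamic_scale is Optional[torch.Tensor]: expect None on the unquantized
# path and per-expert scales when A8 quantization is applied.
```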