@@ -16,7 +16,7 @@
 #

 import math
-from typing import Any, Callable, Dict, Optional, Tuple, Union, List
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -31,6 +31,7 @@
                                dispose_tensor, get_ascend_soc_version,
                                npu_stream_switch, npu_wait_tensor)

+
 def apply_mlp_decode(hidden_states_wrapper: List[torch.Tensor],
                      w1: torch.Tensor,
                      w1_scale: torch.Tensor,
@@ -80,7 +81,7 @@ def apply_mlp_decode(hidden_states_wrapper: List[torch.Tensor],

     # act_fn: swiglu
     hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
-        x=hidden_states,
+        x=hidden_states,
         weight_scale=w1_scale,
         activation_scale=pertoken_scale,
         bias=None,
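
For context, the fused call in this hunk is assumed to collapse three steps into one kernel: dequantize the int32 grouped-matmul output, apply SwiGLU, and re-quantize per token for the following down-projection matmul. The unfused PyTorch sketch below only illustrates that assumed behavior; the argument contract, gating convention, and scale shapes are assumptions for illustration, not the op's documented interface.

import torch


def dequant_swiglu_quant_ref(x_int32: torch.Tensor,
                             weight_scale: torch.Tensor,
                             activation_scale: torch.Tensor):
    # Dequantize the int32 accumulator (per-channel weight scale broadcast over
    # the last dim, per-token activation scale broadcast over channels).
    x = x_int32.float() * weight_scale * activation_scale.unsqueeze(-1)
    # SwiGLU: treat the first half as the gate (convention assumed here).
    gate, up = x.chunk(2, dim=-1)
    y = torch.nn.functional.silu(gate) * up
    # Per-token dynamic re-quantization to int8 for the next grouped matmul.
    out_scale = y.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    y_int8 = torch.clamp(torch.round(y / out_scale), -128, 127).to(torch.int8)
    return y_int8, out_scale.squeeze(-1)
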
@@ -269,17 +270,18 @@ def fused_experts_with_mc2(
     if shared_experts is not None:
         with npu_stream_switch("moe_secondary", 0):
             npu_wait_tensor(quantized_x_for_share, expand_x)
-            shared_act_out = shared_experts.act_fn((quantized_x_for_share, dynamic_scale_for_share))
+            shared_act_out = shared_experts.act_fn(
+                (quantized_x_for_share, dynamic_scale_for_share))
             shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]

     # `expand_x` will be disposed in the `apply_mlp` function
     down_out_list = apply_mlp_decode([expand_x],
-                                     w1,
-                                     w1_scale,
-                                     w2,
-                                     w2_scale,
-                                     expert_token_nums,
-                                     dynamic_scale=dynamic_scale)
+                                     w1,
+                                     w1_scale,
+                                     w2,
+                                     w2_scale,
+                                     expert_token_nums,
+                                     dynamic_scale=dynamic_scale)

     # moeCombine
     kwargs_mc2 = {
@@ -317,7 +319,8 @@ def fused_experts_with_mc2(
     else:
         with npu_stream_switch("moe_secondary", 0):
             npu_wait_tensor(shared_act, down_out_list)
-            shared_output, _ = shared_experts.down_proj((shared_act, swiglu_out_scale))
+            shared_output, _ = shared_experts.down_proj(
+                (shared_act, swiglu_out_scale))
     return hidden_states, shared_output


@@ -774,8 +777,10 @@ def apply(
         if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
             with npu_stream_switch("moe_secondary", 0):
                 npu_wait_tensor(quantized_x_for_share, router_logits)
-                share_up_out, _ = shared_experts.gate_up_proj((quantized_x_for_share, dynamic_scale_for_share))
-                shared_gate_up, shared_dequant_scale = share_up_out[0], share_up_out[1]
+                share_up_out, _ = shared_experts.gate_up_proj(
+                    (quantized_x_for_share, dynamic_scale_for_share))
+                shared_gate_up, shared_dequant_scale = share_up_out[
+                    0], share_up_out[1]

         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
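
Both shared-expert hunks above rely on the same overlap idiom: switch to the "moe_secondary" stream, wait on the producing tensors, then run the shared expert while the routed experts continue on the default stream. Below is a minimal sketch of that idiom, assuming only the helper behavior visible in this diff (npu_stream_switch as a context-manager stream selector, npu_wait_tensor as a tensor-ordering wait); shared_experts here is a stand-in module, not part of this change.

from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


def shared_expert_act_on_side_stream(shared_experts, quantized_x_for_share,
                                     dynamic_scale_for_share, expand_x):
    # Run the shared expert's activation on the secondary stream so it can
    # overlap with the routed experts' grouped matmul on the default stream.
    with npu_stream_switch("moe_secondary", 0):
        # Order the side-stream work after the dispatched tokens (expand_x)
        # are produced (assumed semantics of npu_wait_tensor).
        npu_wait_tensor(quantized_x_for_share, expand_x)
        shared_act_out = shared_experts.act_fn(
            (quantized_x_for_share, dynamic_scale_for_share))
        return shared_act_out[0], shared_act_out[1]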