45 | 45 | data_parallel_reduce_scatter
46 | 46 | from vllm_ascend.distributed.parallel_state import get_mc2_group |
47 | 47 | from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer |
| 48 | +from vllm_ascend.ops.layers.experts_selector import select_experts |
48 | 49 | from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( |
49 | 50 | MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) |
50 | 51 | from vllm_ascend.ops.sequence_parallel import MetadataForPadding |
@@ -863,143 +864,6 @@ def fused_experts( |
863 | 864 | return final_hidden_states |
864 | 865 |
865 | 866 |
866 | | -def native_grouped_topk( |
867 | | - topk_weights: torch.Tensor, |
868 | | - num_expert_group: Optional[int], |
869 | | - topk_group: Optional[int], |
870 | | -): |
871 | | - topk_group = 0 if topk_group is None else topk_group |
872 | | - num_expert_group = 0 if num_expert_group is None else num_expert_group |
873 | | - |
874 | | - num_token = topk_weights.shape[0] |
875 | | - grouped_weights = topk_weights.view(num_token, num_expert_group, |
876 | | - -1).max(dim=-1).values |
877 | | - topk_group_indices = torch.topk(grouped_weights.to(torch.float32), |
878 | | - k=topk_group, |
879 | | - dim=-1, |
880 | | - sorted=False)[1] |
881 | | - topk_group_mask = torch.zeros_like(grouped_weights) |
882 | | - topk_group_mask.scatter_(1, topk_group_indices, 1) |
883 | | - topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand( |
884 | | - num_token, num_expert_group, |
885 | | - topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1)) |
886 | | - topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0) |
887 | | - |
888 | | - return topk_weights |
889 | | - |
890 | | - |
891 | | -def select_experts( |
892 | | - hidden_states: torch.Tensor, |
893 | | - router_logits: torch.Tensor, |
894 | | - top_k: int, |
895 | | - use_grouped_topk: bool, |
896 | | - renormalize: bool, |
897 | | - topk_group: Optional[int] = None, |
898 | | - num_expert_group: Optional[int] = None, |
899 | | - custom_routing_function: Optional[Callable] = None, |
900 | | - scoring_func: str = "softmax", |
901 | | - e_score_correction_bias: Optional[torch.Tensor] = None, |
902 | | - global_num_experts: Optional[torch.Tensor] = None |
903 | | -) -> tuple[torch.Tensor, torch.Tensor]: |
904 | | - """ |
905 | | - Select top-k experts based on router logits. |
906 | | -
907 | | - Args: |
908 | | - hidden_states: Hidden states of shape (num_tokens, hidden_size). |
909 | | - router_logits: Router logits of shape (num_tokens, num_experts). |
910 | | - top_k: Number of experts to select. |
911 | | - use_grouped_topk: Whether to group experts before selecting top-k. |
912 | | - renormalize: Whether to renormalize the routing weights. |
913 | | - topk_group: Number of expert groups to select from. |
914 | | - num_expert_group: Number of experts in each group. |
915 | | - custom_routing_function: Custom routing function. |
916 | | - scoring_func: Scoring function to use. |
917 | | - e_score_correction_bias: Correction bias to apply to expert scores. |
918 | | -
919 | | - Returns: |
920 | | - topk_weights: Routing weights of shape (num_tokens, top_k). |
921 | | - topk_ids: Selected expert IDs of shape (num_tokens, top_k). |
922 | | -
923 | | - Raises: |
924 | | - ValueError: If an unsupported scoring function is provided. |
925 | | - """ |
926 | | - |
927 | | - def _renormalize_topk_weights( |
928 | | - topk_weights: torch.Tensor, |
929 | | - renormalize: bool, |
930 | | - ): |
931 | | - if renormalize: |
932 | | - topk_weights = topk_weights / topk_weights.sum(dim=-1, |
933 | | - keepdim=True) |
934 | | - return topk_weights |
935 | | - |
936 | | - if scoring_func == "softmax": |
937 | | - # NOTE: vLLM uses dtype=torch.float here
938 | | - if not use_grouped_topk and custom_routing_function is None: |
939 | | - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax( |
940 | | - x=router_logits, finished=None, k=top_k) |
941 | | - topk_ids = topk_ids.to(torch.int32) |
942 | | - topk_weights = _renormalize_topk_weights(topk_weights, renormalize) |
943 | | - return topk_weights, topk_ids |
944 | | - |
945 | | - topk_weights = router_logits.softmax(dim=-1) |
946 | | - elif scoring_func == "sigmoid": |
947 | | - topk_weights = router_logits.sigmoid() |
948 | | - else: |
949 | | - raise ValueError(f"Unsupported scoring function: {scoring_func}") |
950 | | - |
951 | | - if use_grouped_topk: |
952 | | - assert topk_group is not None |
953 | | - assert num_expert_group is not None |
954 | | - |
955 | | - if e_score_correction_bias is not None: |
956 | | - # Store original scores before applying correction bias. We use biased |
957 | | - # scores for expert selection but original scores for routing weights |
958 | | - original_weights = topk_weights |
959 | | - topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0) |
960 | | - |
961 | | - # TODO: Change to npu_group_topk when the latest CANN and NNAL is available |
962 | | - # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group) |
963 | | - topk_weights = native_grouped_topk(topk_weights, num_expert_group, |
964 | | - topk_group) |
965 | | - # TODO: bfloat16 is not supported by torch.topk in GE graph mode.
966 | | - if e_score_correction_bias is not None: |
967 | | - topk_ids = torch.topk(topk_weights.to(torch.float32), |
968 | | - k=top_k, |
969 | | - dim=-1, |
970 | | - sorted=False)[1] |
971 | | - # Use original unbiased scores for the routing weights |
972 | | - topk_weights = original_weights.gather(1, topk_ids) |
973 | | - else: |
974 | | - topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), |
975 | | - k=top_k, |
976 | | - dim=-1, |
977 | | - sorted=False) |
978 | | - topk_ids = topk_ids.to(torch.int32) |
979 | | - topk_weights = _renormalize_topk_weights(topk_weights, renormalize) |
980 | | - return topk_weights, topk_ids |
981 | | - |
982 | | - if custom_routing_function is not None: |
983 | | - topk_weights, topk_ids = custom_routing_function( |
984 | | - hidden_states=hidden_states, |
985 | | - gating_output=router_logits, |
986 | | - topk=top_k, |
987 | | - renormalize=renormalize, |
988 | | - global_num_experts=global_num_experts) |
989 | | - # Required by npu_moe_init_routing |
990 | | - topk_ids = topk_ids.to(torch.int32) |
991 | | - return topk_weights, topk_ids |
992 | | - |
993 | | - topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1) |
994 | | - topk_weights = topk_weights.to(hidden_states.dtype) |
995 | | - |
996 | | - # Required by npu_moe_init_routing |
997 | | - topk_ids = topk_ids.to(torch.int32) |
998 | | - topk_weights = _renormalize_topk_weights(topk_weights, renormalize) |
999 | | - |
1000 | | - return topk_weights, topk_ids |
1001 | | - |
1002 | | - |
1003 | 867 | class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): |
1004 | 868 |
1005 | 869 | def __init__(self, moe: FusedMoEConfig = None): |
@@ -1054,36 +918,19 @@ def apply( |
1054 | 918 | **kwargs, |
1055 | 919 | ) -> torch.Tensor: |
1056 | 920 |
1057 | | - is_deepseek_v3_r1 = global_num_experts == 256 |
1058 | | - # NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
1059 | | - if is_deepseek_v3_r1: |
1060 | | - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( |
1061 | | - router_logits, |
1062 | | - k=top_k, # topk currently is 8 |
1063 | | - bias=e_score_correction_bias, |
1064 | | - k_group=topk_group, # fix: 4 |
1065 | | - group_count=num_expert_group, # fix 8 |
1066 | | - group_select_mode= |
1067 | | - 1, # 0: the maximum in the group; 1: topk2.sum(fix) |
1068 | | - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax |
1069 | | - norm_type=1, # 0: softmax; 1: sigmoid(fix) |
1070 | | - # out_flag=False, # todo new api; should the third output be output |
1071 | | - # y2_flag=False, # old api; should the third output be output |
1072 | | - routed_scaling_factor=1, |
1073 | | - eps=float(1e-20)) |
1074 | | - else: |
1075 | | - topk_weights, topk_ids = select_experts( |
1076 | | - hidden_states=x, |
1077 | | - router_logits=router_logits, |
1078 | | - top_k=top_k, |
1079 | | - use_grouped_topk=use_grouped_topk, |
1080 | | - renormalize=renormalize, |
1081 | | - topk_group=topk_group, |
1082 | | - num_expert_group=num_expert_group, |
1083 | | - custom_routing_function=custom_routing_function, |
1084 | | - scoring_func=scoring_func, |
1085 | | - e_score_correction_bias=e_score_correction_bias, |
1086 | | - ) |
| 921 | + topk_weights, topk_ids = select_experts( |
| 922 | + hidden_states=x, |
| 923 | + router_logits=router_logits, |
| 924 | + top_k=top_k, |
| 925 | + use_grouped_topk=use_grouped_topk, |
| 926 | + renormalize=renormalize, |
| 927 | + topk_group=topk_group, |
| 928 | + num_expert_group=num_expert_group, |
| 929 | + custom_routing_function=custom_routing_function, |
| 930 | + scoring_func=scoring_func, |
| 931 | + e_score_correction_bias=e_score_correction_bias, |
| 932 | + global_num_experts=global_num_experts, |
| 933 | + is_unquantized=True) |
1087 | 934 |
1088 | 935 | topk_weights = topk_weights.to(x.dtype) |
1089 | 936 | # this is a naive implementation for experts load balance so as |
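
For orientation beyond the diff itself: after this change, `AscendUnquantizedFusedMoEMethod.apply` routes every request through the relocated `select_experts` helper instead of branching on the DeepSeek V3/R1 expert count. Below is a minimal sketch of the call, assuming the import path added at the top of the file and the keyword set shown at the new call site (whether `is_unquantized` has a default is not visible in this diff); the softmax fast path dispatches to `torch_npu.npu_moe_gating_top_k_softmax`, so it runs only on an Ascend NPU with `torch_npu` installed:

```python
import torch
from vllm_ascend.ops.layers.experts_selector import select_experts

# Hypothetical shapes: 4 tokens, hidden size 32, 16 experts, top-2 routing.
hidden_states = torch.randn(4, 32)
router_logits = torch.randn(4, 16)

topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=2,
    use_grouped_topk=False,
    renormalize=True,
    topk_group=None,
    num_expert_group=None,
    custom_routing_function=None,
    scoring_func="softmax",
    e_score_correction_bias=None,
    global_num_experts=16,
    is_unquantized=True,
)
# topk_weights: (4, 2) routing weights; topk_ids: (4, 2) int32 expert
# indices, the dtype required downstream by npu_moe_init_routing.
```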
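
The removed `native_grouped_topk` implements the group-limited routing step: experts are split into `num_expert_group` contiguous groups, each group is scored by its strongest expert, and scores outside the `topk_group` winning groups are zeroed so the subsequent top-k can only pick experts from those groups. A standalone, CPU-runnable condensation of the deleted function:

```python
import torch

def native_grouped_topk(topk_weights: torch.Tensor, num_expert_group: int,
                        topk_group: int) -> torch.Tensor:
    num_token = topk_weights.shape[0]
    # Score each expert group by its strongest expert.
    grouped = topk_weights.view(num_token, num_expert_group,
                                -1).max(dim=-1).values
    # Pick the top groups per token (float32, since the deleted code notes
    # bfloat16 topk is unsupported in GE graph mode).
    group_idx = torch.topk(grouped.to(torch.float32), k=topk_group, dim=-1,
                           sorted=False)[1]
    group_mask = torch.zeros_like(grouped)
    group_mask.scatter_(1, group_idx, 1)
    # Broadcast the group mask back to per-expert granularity.
    expert_mask = (group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
    return topk_weights.masked_fill(~expert_mask.bool(), 0.0)

# Example: 2 tokens, 8 experts in 4 groups of 2; keep the best 2 groups.
scores = torch.rand(2, 8).softmax(dim=-1)
masked = native_grouped_topk(scores, num_expert_group=4, topk_group=2)
topk_weights, topk_ids = masked.topk(2, dim=-1)  # top-k within winning groups
```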
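
The deleted grouped branch also shows the correction-bias trick used by DeepSeek-style routers: `e_score_correction_bias` is added to the scores only to decide which experts win, while the routing weights gathered afterwards come from the original, unbiased scores. A toy illustration with made-up numbers:

```python
import torch

scores = torch.tensor([[0.30, 0.25, 0.25, 0.20]])  # unbiased expert scores
bias = torch.tensor([-0.20, 0.00, 0.00, 0.00])     # per-expert correction bias

biased = scores + bias.unsqueeze(0)                # used only for selection
topk_ids = torch.topk(biased, k=2, dim=-1, sorted=False)[1]
topk_weights = scores.gather(1, topk_ids)          # weights stay unbiased
# Expert 0 has the highest raw score but is biased out of the selection;
# the chosen experts' weights are still read from the unbiased `scores`.
```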