 import torch
 import torch_npu
 from vllm.config import get_current_vllm_config
-from vllm.distributed import tensor_model_parallel_all_reduce
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
-from vllm.model_executor.layers.quantization.base_config import \
-    QuantizeMethodBase
+
+from vllm_ascend.utils import vllm_version_is
+
+if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
+    from vllm.model_executor.layers.fused_moe.layer import (
+        FusedMoEParallelConfig, MoEConfig)
+else:
+    MoEConfig = None
+
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
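
The added imports are gated on the installed vLLM release via `vllm_version_is` from `vllm_ascend.utils`. A minimal sketch of what that gate is assumed to do (the real helper may be more elaborate, but the intent is a plain version match):

```python
# Hedged sketch of the assumed version gate; the real implementation lives in
# vllm_ascend/utils.py and may differ in detail.
import vllm


def vllm_version_is(target_version: str) -> bool:
    # True only when the installed vLLM release string matches exactly,
    # e.g. vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1").
    return vllm.__version__ == target_version
```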
@@ -437,8 +447,11 @@ def select_experts(

 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

-    def __init__(self):
-        super().__init__()
+    def __init__(self, moe: MoEConfig = None):
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            super().__init__()
+        else:
+            super().__init__(moe=moe)
         vllm_config = get_current_vllm_config()

         ep_group = get_ep_group()
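
Stripped of the vLLM specifics, the constructor above just dispatches between two base-class signatures: the 0.8.5 releases take no arguments, newer ones expect a `MoEConfig`. A self-contained toy with invented class names shows the pattern:

```python
# Toy illustration of the version-gated constructor dispatch; OldBase/NewBase
# are invented stand-ins for the two UnquantizedFusedMoEMethod signatures.
IS_OLD_VLLM = False  # stands in for vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")


class OldBase:
    def __init__(self):  # 0.8.5-style: no arguments
        self.moe = None


class NewBase:
    def __init__(self, moe=None):  # newer style: accepts a MoE config object
        self.moe = moe


class AscendMethod(OldBase if IS_OLD_VLLM else NewBase):
    def __init__(self, moe=None):
        if IS_OLD_VLLM:
            super().__init__()         # old base takes no arguments
        else:
            super().__init__(moe=moe)  # new base expects the config


method = AscendMethod(moe={"experts_per_token": 2})
```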
@@ -535,37 +548,54 @@ def apply(

 class AscendFusedMoE(FusedMoE):

-    def __init__(self,
-                 num_experts,
-                 top_k,
-                 hidden_size,
-                 intermediate_size,
-                 params_dtype=None,
-                 reduce_results=False,
-                 renormalize=True,
-                 use_grouped_topk=False,
-                 num_expert_group=None,
-                 topk_group=None,
-                 quant_config=None,
-                 tp_size=None,
-                 ep_size=None,
-                 dp_size=None,
-                 prefix="",
-                 custom_routing_function=None,
-                 scoring_func="softmax",
-                 e_score_correction_bias=None,
-                 activation="silu"):
+    def __init__(
+        self,
+        num_experts: int,  # Global number of experts
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = False,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        ep_size: Optional[int] = None,
+        dp_size: Optional[int] = None,
+        prefix: str = "",
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+    ):
+        # TODO: This does not initialize the FusedMoE base class;
+        # fix this and make AscendFusedMoE.__init__() clearer.
         super(FusedMoE, self).__init__()

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

-        self.ep_size = get_ep_group().world_size
-        self.tp_size = get_etp_group().world_size
-        self.dp_size = (dp_size
-                        if dp_size is not None else get_dp_group().world_size)
-        self.dp_rank = (0
-                        if self.dp_size == 1 else get_dp_group().rank_in_group)
+        vllm_config = get_current_vllm_config()
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            self.ep_size = get_ep_group().world_size
+            self.tp_size = get_etp_group().world_size
+            self.dp_size = (dp_size if dp_size is not None else
+                            get_dp_group().world_size)
+            self.dp_rank = (0 if self.dp_size == 1 else
+                            get_dp_group().rank_in_group)
+        else:
+            self.moe_parallel_config: FusedMoEParallelConfig = (
+                FusedMoEParallelConfig.make(
+                    tp_size_=(tp_size if tp_size is not None else
+                              get_tensor_model_parallel_world_size()),
+                    dp_size_=(dp_size if dp_size is not None else
+                              get_dp_group().world_size),
+                    vllm_parallel_config=vllm_config.parallel_config))
+
+            self.moe_parallel_config.ep_size = get_ep_group().world_size

         self.top_k = top_k
         self.num_experts = num_experts
@@ -590,27 +620,55 @@ def __init__(self,
             self.local_num_experts, self.expert_map = determine_expert_map(
                 self.ep_size,
                 get_ep_group().rank_in_group, self.global_num_experts)
-            self.tp_rank = get_etp_group().rank_in_group
-            self.ep_rank = get_ep_group().rank_in_group
+            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+                self.tp_rank = get_etp_group().rank_in_group
+                self.ep_rank = get_ep_group().rank_in_group
+            else:
+                self.moe_parallel_config.tp_rank = get_etp_group(
+                ).rank_in_group
+                self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
+
         else:
             # Adjust TP size for DP attention
             # haven't tested its functionality yet, may remove in the future
-            self.tp_rank = self.tp_size * self.dp_rank
-            self.ep_rank = 0
-            self.tp_size = self.tp_size * self.dp_size
-            self.ep_size = 1
-            self.local_num_experts = self.global_num_experts
-            self.expert_map = None
-
+            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+                self.tp_rank = self.tp_size * self.dp_rank
+                self.ep_rank = 0
+                self.tp_size = self.tp_size * self.dp_size
+                self.ep_size = 1
+            else:
+                self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
+                self.moe_parallel_config.ep_rank = 0
+                self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
+                self.moe_parallel_config.ep_size = 1
+
+            self.local_num_experts, self.expert_map = (self.global_num_experts,
+                                                       None)
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
-
-        if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                AscendUnquantizedFusedMoEMethod())
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            if quant_config is None:
+                self.quant_method: Optional[QuantizeMethodBase] = (
+                    AscendUnquantizedFusedMoEMethod())
+            else:
+                self.quant_method = quant_config.get_quant_method(self, prefix)
         else:
-            self.quant_method = quant_config.get_quant_method(self, prefix)
+            moe = MoEConfig(
+                num_experts=self.global_num_experts,
+                experts_per_token=top_k,
+                hidden_dim=hidden_size,
+                num_local_experts=self.local_num_experts,
+                moe_parallel_config=self.moe_parallel_config,
+                # TODO (bnell): this needs to be fixed for quantized types.
+                in_dtype=params_dtype,
+            )
+
+            if quant_config is None:
+                self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
+            else:
+                self.quant_method = quant_config.get_quant_method(self, prefix)
+
         assert self.quant_method is not None

         local_num_experts = torch.sum(self.expert_map != -1) \