                               tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
-
-from vllm_ascend.utils import vllm_version_is
-
-if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
-    from vllm.model_executor.layers.fused_moe.layer import (
-        FusedMoEParallelConfig, MoEConfig)
-else:
-    MoEConfig = None
-
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig, QuantizeMethodBase)
+    FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod,
+    determine_expert_map)
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizationConfig

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
@@ -587,10 +579,8 @@ def select_experts(
 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

     def __init__(self, moe: MoEConfig = None):
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            super().__init__()
-        else:
-            super().__init__(moe=moe)
+
+        super().__init__(moe=moe)
         vllm_config = get_current_vllm_config()

         ep_group = get_ep_group()
@@ -731,24 +721,17 @@ def __init__(
         params_dtype = torch.get_default_dtype()

         vllm_config = get_current_vllm_config()
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            self.ep_size = get_ep_group().world_size
-            self.tp_size = get_etp_group().world_size
-            self.dp_size = (dp_size if dp_size is not None else
-                            get_dp_group().world_size)
-            self.dp_rank = (0 if self.dp_size == 1 else
-                            get_dp_group().rank_in_group)
-        else:
-            self.moe_parallel_config: FusedMoEParallelConfig = (
-                FusedMoEParallelConfig.make(
-                    tp_size_=(tp_size if tp_size is not None else
-                              get_tensor_model_parallel_world_size()),
-                    dp_size_=(dp_size if dp_size is not None else
-                              get_dp_group().world_size),
-                    vllm_parallel_config=vllm_config.parallel_config))

-            self.moe_parallel_config.ep_size = get_ep_group().world_size
-            self.moe_parallel_config.tp_size = get_etp_group().world_size
+        self.moe_parallel_config: FusedMoEParallelConfig = (
+            FusedMoEParallelConfig.make(
+                tp_size_=(tp_size if tp_size is not None else
+                          get_tensor_model_parallel_world_size()),
+                dp_size_=(dp_size if dp_size is not None else
+                          get_dp_group().world_size),
+                vllm_parallel_config=vllm_config.parallel_config))
+
+        self.moe_parallel_config.ep_size = get_ep_group().world_size
+        self.moe_parallel_config.tp_size = get_etp_group().world_size

         self.top_k = top_k
         self.num_experts = num_experts
@@ -773,54 +756,39 @@ def __init__(
             self.local_num_experts, self.expert_map = determine_expert_map(
                 self.ep_size,
                 get_ep_group().rank_in_group, self.global_num_experts)
-            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-                self.tp_rank = get_etp_group().rank_in_group
-                self.ep_rank = get_ep_group().rank_in_group
-            else:
-                self.moe_parallel_config.tp_rank = get_etp_group(
-                ).rank_in_group
-                self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
+
+            self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
+            self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group

         else:
             # Adjust TP size for DP attention
             # haven't test its functionality yet, may remove in the future
-            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-                self.tp_rank = self.tp_size * self.dp_rank
-                self.ep_rank = 0
-                self.tp_size = self.tp_size * self.dp_size
-                self.ep_size = 1
-            else:
-                self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
-                self.moe_parallel_config.ep_rank = 0
-                self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
-                self.moe_parallel_config.ep_size = 1
+
+            self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
+            self.moe_parallel_config.ep_rank = 0
+            self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
+            self.moe_parallel_config.ep_size = 1

         self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                    None)
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            if quant_config is None:
-                self.quant_method: Optional[QuantizeMethodBase] = (
-                    AscendUnquantizedFusedMoEMethod())
-            else:
-                self.quant_method = quant_config.get_quant_method(self, prefix)
-        else:
-            moe = MoEConfig(
-                num_experts=self.global_num_experts,
-                experts_per_token=top_k,
-                hidden_dim=hidden_size,
-                num_local_experts=self.local_num_experts,
-                moe_parallel_config=self.moe_parallel_config,
-                # TODO (bnell): this needs to be fixed for quantized types.
-                in_dtype=params_dtype,
-            )

-            if quant_config is None:
-                self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
-            else:
-                self.quant_method = quant_config.get_quant_method(self, prefix)
+        moe = MoEConfig(
+            num_experts=self.global_num_experts,
+            experts_per_token=top_k,
+            hidden_dim=hidden_size,
+            num_local_experts=self.local_num_experts,
+            moe_parallel_config=self.moe_parallel_config,
+            # TODO (bnell): this needs to be fixed for quantized types.
+            in_dtype=params_dtype,
+        )
+
+        if quant_config is None:
+            self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
+        else:
+            self.quant_method = quant_config.get_quant_method(self, prefix)

         assert self.quant_method is not None

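For reference, a minimal sketch of how the now unconditional code path assembles the MoEConfig that is handed to AscendUnquantizedFusedMoEMethod. The keyword arguments mirror the ones in the diff above; the numeric values and the build_moe_config helper are illustrative placeholders only, and a real FusedMoEParallelConfig would come from FusedMoEParallelConfig.make(...) inside an initialized distributed setup.

# Sketch only: mirrors the MoEConfig construction from the diff above.
# All sizes below are made-up placeholders, not values from this patch.
import torch
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEParallelConfig,
                                                         MoEConfig)


def build_moe_config(parallel_config: FusedMoEParallelConfig) -> MoEConfig:
    # Field names follow the non version-gated path kept by this change.
    return MoEConfig(
        num_experts=64,          # placeholder for self.global_num_experts
        experts_per_token=8,     # placeholder for top_k
        hidden_dim=4096,         # placeholder for hidden_size
        num_local_experts=8,     # placeholder for self.local_num_experts
        moe_parallel_config=parallel_config,
        in_dtype=torch.get_default_dtype(),
    )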