diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index 895b7ffca0..016dd6c1eb 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -22,18 +22,17 @@ def get_etp_group() -> GroupCoordinator:
 
 
 def init_ascend_model_parallel(
-    tensor_model_parallel_size: int = 1,
-    pipeline_model_parallel_size: int = 1,
+    expert_parallel_size: int = 1,
     expert_tensor_parallel_size: int = 1,
+    world_size: Optional[int] = None,
     backend: Optional[str] = None,
 ):
     assert torch.distributed.is_initialized()
-    world_size: int = torch.distributed.get_world_size()
+    world_size = world_size or torch.distributed.get_world_size()
     backend = backend or torch.distributed.get_backend(
         get_world_group().device_group)
-    num_expert_parallel_groups: int = expert_tensor_parallel_size
-    num_expert_tensor_parallel_groups: int = (world_size //
-                                              expert_tensor_parallel_size)
+    num_expert_parallel_groups = expert_tensor_parallel_size
+    num_expert_tensor_parallel_groups = expert_parallel_size
 
     global _EP
     group_ranks = []
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 8a5b115723..8da897b668 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -119,6 +119,26 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         from vllm.config import CompilationLevel  # noqa: E402
         compilation_config = vllm_config.compilation_config
         model_config = vllm_config.model_config
+        additional_config = vllm_config.additional_config
+        parallel_config = vllm_config.parallel_config
+        cache_config = vllm_config.cache_config
+
+        if parallel_config:
+            # Default value for expert tensor parallel size
+            parallel_config.expert_tensor_parallel_size = parallel_config.tensor_parallel_size
+
+            # NOTE: When enable_expert_parallel is True, we follow vLLM convention:
+            # ep_size = world_size, which means expert_tensor_parallel_size must be 1
+            if (additional_config
+                    and "expert_tensor_parallel_size" in additional_config
+                    and not parallel_config.enable_expert_parallel):
+                parallel_config.expert_tensor_parallel_size = int(
+                    additional_config["expert_tensor_parallel_size"])
+
+            # Calculate expert parallel size based on world size
+            parallel_config.expert_parallel_size = (
+                parallel_config.world_size //
+                parallel_config.expert_tensor_parallel_size)
 
         if model_config is None:
             logger.warning("Model config is missing. This may indicate "
@@ -127,9 +147,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         else:
             enforce_eager = getattr(model_config, "enforce_eager", False)
 
-        if vllm_config.additional_config is not None:
-            enable_graph_mode = vllm_config.additional_config.get(
-                "enable_graph_mode", False)
+        if additional_config is not None:
+            enable_graph_mode = additional_config.get("enable_graph_mode",
+                                                      False)
             if enable_graph_mode:
                 if enforce_eager:
                     raise RuntimeError(
@@ -139,7 +159,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                     logger.warning(
                         "NPU graph mode is still experimental and not supported for V1 without mla currently, "
                         "it has been disabled automatically.")
-                    vllm_config.additional_config["enable_graph_mode"] = False
+                    additional_config["enable_graph_mode"] = False
                 if model_config:
                     model_type = model_config.hf_config.model_type
                     if "deepseek" not in model_type:
@@ -178,7 +198,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 ["vllm.unified_ascend_attention_with_output"])
             update_aclgraph_sizes(vllm_config)
 
-        parallel_config = vllm_config.parallel_config
         if parallel_config and parallel_config.worker_cls == "auto":
             if envs.VLLM_USE_V1:
                 parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
@@ -190,7 +209,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             else:
                 parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
 
-        cache_config = vllm_config.cache_config
         if cache_config:
             if cache_config.block_size is None:
                 cache_config.block_size = 128
@@ -202,11 +220,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
         if envs.VLLM_USE_V1:
             # Activate custom ops for v1.
-            vllm_config.compilation_config.custom_ops = ["all"]
+            compilation_config.custom_ops = ["all"]
 
             # If ascend_scheduler_config exists in additional_config,
             # extents original scheduler_config to use AscendScheduler.
-            additional_config = vllm_config.additional_config
             if additional_config and additional_config.get(
                     "ascend_scheduler_config", None) is not None:
                 additional_scheduler_config = additional_config.get(
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 67cc0b8f4d..34b8da1d31 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -126,14 +126,16 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     original_sizes, compilation_config.cudagraph_capture_sizes = \
         compilation_config.cudagraph_capture_sizes, None
 
-    # Calculate parallel configuration factor (increases with DP or TP)
-    # TODO(Yizhou): This is a temporary solution, need to be improved
-    # in the future, taking into account the other parallel configurations.
+    # Calculate parallel configuration factor
     num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
     parallel_config = vllm_config.parallel_config
+
+    # TODO: Find out whether we need to take into account the pp_size
     parallel_factor = 1 + sum(size > 1 for size in [
-        parallel_config.data_parallel_size,
-        parallel_config.tensor_parallel_size
+        parallel_config.data_parallel_size_local,
+        parallel_config.tensor_parallel_size,
+        parallel_config.expert_parallel_size,
+        parallel_config.expert_tensor_parallel_size,
     ])
 
     # Calculate maximum supported batch sizes considering model architecture
diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py
index 3e1515d123..d98b0fe886 100644
--- a/vllm_ascend/worker/worker.py
+++ b/vllm_ascend/worker/worker.py
@@ -534,7 +534,6 @@ def _init_worker_distributed_environment(
             backend: str = "hccl") -> None:
         """Initialize the distributed environment."""
         parallel_config = self.parallel_config
-        additional_config = self.vllm_config.additional_config
         set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
         init_distributed_environment(parallel_config.world_size, rank,
                                      distributed_init_method, local_rank,
@@ -542,13 +541,11 @@ def _init_worker_distributed_environment(
         ensure_model_parallel_initialized(
             parallel_config.tensor_parallel_size,
             parallel_config.pipeline_parallel_size)
-        expert_tensor_parallel_size = 1
-        if additional_config:
-            expert_tensor_parallel_size = additional_config.get(
-                "expert_tensor_parallel_size", 1)
-        init_ascend_model_parallel(parallel_config.tensor_parallel_size,
-                                   parallel_config.pipeline_parallel_size,
-                                   expert_tensor_parallel_size)
+        init_ascend_model_parallel(
+            parallel_config.expert_parallel_size,
+            parallel_config.expert_tensor_parallel_size,
+            parallel_config.world_size,
+        )
         ensure_kv_transfer_initialized(vllm_config)
 
 
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index ae6a59eb00..69476e256f 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -234,7 +234,6 @@ def execute_dummy_batch(self) -> None:
 
     def _init_worker_distributed_environment(self) -> None:
         """Initialize the distributed environment."""
-        additional_config = self.vllm_config.additional_config
         parallel_config = self.vllm_config.parallel_config
         set_custom_all_reduce(
             not self.parallel_config.disable_custom_all_reduce)
@@ -244,13 +243,11 @@ def _init_worker_distributed_environment(self) -> None:
         ensure_model_parallel_initialized(
             self.parallel_config.tensor_parallel_size,
             self.parallel_config.pipeline_parallel_size)
-        expert_tensor_parallel_size = 1
-        if additional_config is not None and "expert_tensor_parallel_size" in additional_config:
-            expert_tensor_parallel_size = int(
-                additional_config["expert_tensor_parallel_size"])
-        init_ascend_model_parallel(parallel_config.tensor_parallel_size,
-                                   parallel_config.pipeline_parallel_size,
-                                   expert_tensor_parallel_size)
+        init_ascend_model_parallel(
+            parallel_config.expert_parallel_size,
+            parallel_config.expert_tensor_parallel_size,
+            parallel_config.world_size,
+        )
         ensure_kv_transfer_initialized(self.vllm_config)
 
     def _init_profiler(self):
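
Note on the `init_ascend_model_parallel` change: the hunk renames the group counters so that `num_expert_parallel_groups` equals the ETP size and `num_expert_tensor_parallel_groups` equals the new `expert_parallel_size` argument, but the actual construction of `group_ranks` lies below the diff context. The sketch below is only an illustration of how EP/ETP rank groups with those counts are commonly laid out (strided EP groups, contiguous ETP groups), assuming `world_size == expert_parallel_size * expert_tensor_parallel_size`; `build_expert_group_ranks` is a hypothetical helper, not code from `parallel_state.py`.

```python
# Illustrative sketch only: EP groups stride across ranks, ETP groups take
# contiguous blocks. This mirrors the renamed counters in the hunk above but
# is NOT the actual group construction in parallel_state.py.
def build_expert_group_ranks(world_size: int, expert_parallel_size: int,
                             expert_tensor_parallel_size: int):
    assert world_size == expert_parallel_size * expert_tensor_parallel_size

    num_expert_parallel_groups = expert_tensor_parallel_size
    num_expert_tensor_parallel_groups = expert_parallel_size

    # Each EP group takes every ETP-th rank: [i, i + etp, i + 2 * etp, ...]
    ep_group_ranks = [
        list(range(i, world_size, expert_tensor_parallel_size))
        for i in range(num_expert_parallel_groups)
    ]
    # Each ETP group takes a contiguous block of ETP ranks.
    etp_group_ranks = [
        list(range(j * expert_tensor_parallel_size,
                   (j + 1) * expert_tensor_parallel_size))
        for j in range(num_expert_tensor_parallel_groups)
    ]
    return ep_group_ranks, etp_group_ranks


if __name__ == "__main__":
    ep, etp = build_expert_group_ranks(world_size=8, expert_parallel_size=4,
                                       expert_tensor_parallel_size=2)
    print(ep)   # [[0, 2, 4, 6], [1, 3, 5, 7]]    -> 2 EP groups of 4 ranks
    print(etp)  # [[0, 1], [2, 3], [4, 5], [6, 7]] -> 4 ETP groups of 2 ranks
```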
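
The new block in `check_and_update_config` fixes a clear precedence: `expert_tensor_parallel_size` defaults to the dense `tensor_parallel_size`, an `expert_tensor_parallel_size` entry in `additional_config` overrides that default only while `enable_expert_parallel` is off, and `expert_parallel_size` is then derived as `world_size // expert_tensor_parallel_size`. A minimal self-contained restatement of that precedence for reference; `SimpleParallelConfig` is a stand-in dataclass, not vLLM's real `ParallelConfig`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SimpleParallelConfig:
    # Stand-in for vLLM's ParallelConfig, only the fields used here.
    world_size: int
    tensor_parallel_size: int
    enable_expert_parallel: bool = False
    expert_parallel_size: int = 1
    expert_tensor_parallel_size: int = 1


def resolve_expert_parallel_sizes(parallel_config: SimpleParallelConfig,
                                  additional_config: Optional[dict]) -> None:
    # Default: expert tensor parallelism follows the dense TP size.
    parallel_config.expert_tensor_parallel_size = \
        parallel_config.tensor_parallel_size

    # The additional_config override is honored only when expert parallelism
    # is not enabled (see the NOTE in the platform.py hunk).
    if (additional_config
            and "expert_tensor_parallel_size" in additional_config
            and not parallel_config.enable_expert_parallel):
        parallel_config.expert_tensor_parallel_size = int(
            additional_config["expert_tensor_parallel_size"])

    # EP is whatever remains of the world after ETP is carved out.
    parallel_config.expert_parallel_size = (
        parallel_config.world_size //
        parallel_config.expert_tensor_parallel_size)


cfg = SimpleParallelConfig(world_size=8, tensor_parallel_size=4)
resolve_expert_parallel_sizes(cfg, {"expert_tensor_parallel_size": 2})
print(cfg.expert_tensor_parallel_size, cfg.expert_parallel_size)  # 2 4
```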
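
In `update_aclgraph_sizes`, only the `parallel_factor` computation changes: it now also counts EP and ETP and uses `data_parallel_size_local` instead of the global DP size, adding one unit per parallel dimension that is actually larger than 1. How the factor is consumed further down in the function is unchanged and not shown in this diff. A quick worked example with made-up sizes:

```python
# Hypothetical deployment: two nodes of 4 NPUs, TP=4 within a node, DP=2
# across nodes (so data_parallel_size_local == 1), EP=2, ETP=4.
data_parallel_size_local = 1
tensor_parallel_size = 4
expert_parallel_size = 2
expert_tensor_parallel_size = 4

parallel_factor = 1 + sum(size > 1 for size in [
    data_parallel_size_local,
    tensor_parallel_size,
    expert_parallel_size,
    expert_tensor_parallel_size,
])
print(parallel_factor)  # 4: baseline 1 plus one per dimension in active use
```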