4 changes: 4 additions & 0 deletions tests/ut/worker/test_model_runner_v1.py
@@ -68,6 +68,8 @@ def test_select_moe_comm_method(soc_version, enable_expert_parallel,
with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version',
return_value=soc_version), \
patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
return_value=True), \
patch('vllm_ascend.worker.model_runner_v1.is_moe_model',
return_value=True):

# Bind the real method to the mock object
@@ -102,6 +104,8 @@ def test_select_moe_comm_method_unsupported_soc():
return_value=unsupported_soc), \
patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
return_value=True), \
patch('vllm_ascend.worker.model_runner_v1.is_moe_model',
return_value=True), \
pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"):

NPUModelRunner._select_moe_comm_method(mock_runner, 100, False)
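
A hypothetical companion test (not part of this PR) would exercise the new early-return path: with is_moe_model patched to report a dense model, _select_moe_comm_method is expected to return None, matching the model_runner_v1.py hunk further below.

from unittest.mock import MagicMock, patch

from vllm_ascend.worker.model_runner_v1 import NPUModelRunner


def test_select_moe_comm_method_dense_model():
    # No soc_version patch is needed: the dense-model check runs first.
    mock_runner = MagicMock()
    with patch('vllm_ascend.worker.model_runner_v1.is_moe_model',
               return_value=False):
        assert NPUModelRunner._select_moe_comm_method(mock_runner, 100,
                                                      False) is None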
7 changes: 3 additions & 4 deletions vllm_ascend/ascend_forward_context.py
@@ -11,7 +11,8 @@
set_forward_context)

import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import enable_sp, is_moe_model, version_check
from vllm_ascend.utils import (enable_sp, has_layer_idx, is_moe_model,
version_check)

if TYPE_CHECKING:
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
@@ -134,9 +135,7 @@ def set_ascend_forward_context(
# set layer_idx to enable optimization features that depend on this information.
# This is only applicable to models that contain these necessary attributes.
forward_context.layer_idx = None
if model_instance is not None and \
hasattr(model_instance, "model") and \
hasattr(model_instance.model, "start_layer"):
if has_layer_idx(model_instance):
forward_context.layer_idx = model_instance.model.start_layer

# TODO(rjg-lyh): refactor mlp weight prefetch method
2 changes: 1 addition & 1 deletion vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -37,7 +37,7 @@

def get_moe_comm_method(
moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:
return _MoECommMethods.get(moe_comm_type)
return _MoECommMethods.get(moe_comm_type, None)


def setup_moe_comm_method(moe_config):
12 changes: 12 additions & 0 deletions vllm_ascend/utils.py
@@ -58,6 +58,7 @@
_MIN_DP_BUFFER_SIZE = 50
_IS_MOE_MODEL = None
_ENABLE_SP = None
_HAS_LAYER_IDX = None


def is_310p():
@@ -796,3 +797,14 @@ def version_check():
if full_date >= "20250919":
return True
return False


def has_layer_idx(model_instance: torch.nn.Module) -> bool:
if model_instance is None:
return False

global _HAS_LAYER_IDX
if _HAS_LAYER_IDX is None:
_HAS_LAYER_IDX = hasattr(model_instance, "model") and \
hasattr(model_instance.model, "start_layer")
return _HAS_LAYER_IDX
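
Worth noting: the helper caches its first answer in the module-level _HAS_LAYER_IDX, so the hasattr probes run only once per process. A minimal sketch of that behaviour with a made-up wrapper module (illustrative only, assuming vllm_ascend is importable):

import torch

from vllm_ascend.utils import has_layer_idx


class _Inner(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.start_layer = 0


class _Wrapper(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.model = _Inner()


assert has_layer_idx(_Wrapper())              # probes hasattr once, caches True
assert has_layer_idx(torch.nn.Linear(2, 2))   # cached result is reused as-is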
34 changes: 14 additions & 20 deletions vllm_ascend/worker/model_runner_v1.py
@@ -132,7 +132,7 @@
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendSocVersion, ProfileExecuteDuration,
enable_sp, get_ascend_soc_version, is_310p,
is_enable_nz, lmhead_tp_enable,
is_enable_nz, is_moe_model, lmhead_tp_enable,
prefill_context_parallel_enable,
vllm_version_is)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
@@ -506,11 +506,14 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.in_profile_run = False

self._init_mc2_tokens_capacity()
self.reserved_mc2_mask = torch.zeros(
self.mc2_tokens_capacity,
dtype=torch.bool,
device=self.device,
)
if is_moe_model(vllm_config):
self.reserved_mc2_mask = torch.zeros(
self.mc2_tokens_capacity,
dtype=torch.bool,
device=self.device,
)
else:
self.reserved_mc2_mask = None
self.dynamic_eplb = self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path
if self.dynamic_eplb:
EPLBParamUtils.check_dynamic_eplb(self.ascend_config.dynamic_eplb)
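
The guarded allocation above reserves the boolean MC2 mask only when the loaded model is an MoE model; dense models keep reserved_mc2_mask as None. A standalone sketch of the same guard pattern (names, capacity, and device are stand-ins, not the runner's real attributes):

from typing import Optional

import torch


def make_reserved_mc2_mask(is_moe: bool, capacity: int,
                           device: str = "cpu") -> Optional[torch.Tensor]:
    # Only MoE models dispatch tokens through MC2, so dense models skip the buffer.
    if not is_moe:
        return None
    return torch.zeros(capacity, dtype=torch.bool, device=device)


assert make_reserved_mc2_mask(False, 512) is None
assert make_reserved_mc2_mask(True, 512).numel() == 512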
@@ -1484,9 +1487,7 @@ def _prepare_inputs(
self.query_lens = torch.from_numpy(num_scheduled_tokens)

# Copy the tensors to the NPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)

self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
self.positions[:num_input_tokens].copy_(
self.positions_cpu[:num_input_tokens], non_blocking=True)
@@ -1508,16 +1509,6 @@
self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
attn_metadata: dict[str, Any] = {}

# Prepare input_ids
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices),
out=self.input_ids_cpu[:total_num_scheduled_tokens])
# Copy the tensors to the NPU.
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)

# _prepare_inputs may reorder the batch, so we must gather
# multi-modal outputs after that to ensure the correct order
if self.is_multimodal_model:
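
For reference, the block removed above gathered token ids through flat indices into the per-request token table (index = position + request_row * row_width); the rewritten code appears to fold that work into the earlier _prepare_input_ids call. A tiny standalone sketch of the flattened-gather trick, with made-up shapes:

import numpy as np
import torch

token_ids_cpu = torch.arange(12).reshape(3, 4)   # 3 requests, 4 token slots each
req_indices = np.array([0, 0, 2])                # owning request of each scheduled token
positions_np = np.array([1, 2, 3])               # position inside that request
token_indices = positions_np + req_indices * token_ids_cpu.shape[1]
out = torch.empty(3, dtype=token_ids_cpu.dtype)
torch.index_select(token_ids_cpu.flatten(), 0,
                   torch.from_numpy(token_indices), out=out)
# out is tensor([1, 2, 11])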
@@ -2031,7 +2022,7 @@ def _pool(
)

def _select_moe_comm_method(self, num_tokens: int,
with_prefill: bool) -> MoECommType:
with_prefill: bool) -> Optional[MoECommType]:
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
are designed for expert parallelism.
2. If expert parallel is enabled, we need to consider the soc version and the
Expand All @@ -2054,6 +2045,9 @@ def _select_moe_comm_method(self, num_tokens: int,
Returns:
MoECommType: The selected MoE communication method.
"""
if not is_moe_model(self.vllm_config):
return None

soc_version = get_ascend_soc_version()
quant_type = getattr(self.vllm_config.model_config.hf_config,
'moe_quantize', None)