Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/vllm_ascend_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ jobs:
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
fi

Expand Down
27 changes: 27 additions & 0 deletions tests/multicard/test_offline_inference_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams
from vllm.model_executor.models.registry import ModelRegistry

from tests.conftest import VllmRunner

Expand Down Expand Up @@ -94,6 +95,32 @@ def test_models_distributed_DeepSeek_dbo():
tensor_parallel_size=4,
distributed_executor_backend="mp",
) as vllm_model:
model_arch = 'DeepseekV2ForCausalLM'
registed_models = ModelRegistry.models
assert registed_models[
model_arch].module_name == "vllm_ascend.models.deepseek_dbo"
assert registed_models[
model_arch].class_name == "CustomDeepseekDBOForCausalLM"
vllm_model.generate(example_prompts, sampling_params)


@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
def test_models_distributed_DeepSeekV3_dbo():
example_prompts = ["The president of the United States is"] * 41
dtype = "half"
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
with VllmRunner(
"vllm-ascend/DeepSeek-V3-Pruning",
dtype=dtype,
tensor_parallel_size=4,
distributed_executor_backend="mp",
) as vllm_model:
model_arch = 'DeepseekV3ForCausalLM'
registed_models = ModelRegistry.models
assert registed_models[
model_arch].module_name == "vllm_ascend.models.deepseek_dbo"
assert registed_models[
model_arch].class_name == "CustomDeepseekDBOForCausalLM"
vllm_model.generate(example_prompts, sampling_params)


Expand Down
11 changes: 8 additions & 3 deletions vllm_ascend/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,19 @@ def register_model():
ModelRegistry.register_model(
"DeepseekV2ForCausalLM",
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")

ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")

else:
ModelRegistry.register_model(
"DeepseekV2ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")

ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")

ModelRegistry.register_model(
"Qwen3MoeForCausalLM",
Expand Down
8 changes: 4 additions & 4 deletions vllm_ascend/models/deepseek_dbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def _forward_ms_layer(

if self.mlp.tp_size > 1:
num_token, _ = hidden_states[i].shape
padded_num_tokens = (self.mlp.tp_size - num_token %
padded_num_tokens = (self.mlp.tp_size - num_tokens[i] %
self.mlp.tp_size) % self.mlp.tp_size
if padded_num_tokens > 0:
hidden_states[i] = nn.functional.pad(
Expand Down Expand Up @@ -851,16 +851,16 @@ def forward(
if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms()
else self.end_layer - self.start_layer)

for i in range(self.start_layer, self.start_layer + num_normal_layers):
moe_start_layer = self.start_layer + num_normal_layers
for i in range(self.start_layer, min(moe_start_layer, self.end_layer)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, residual,
kv_caches[i -
self.start_layer] if kv_caches is not None else None,
attn_metadata)

moe_start_layer = self.start_layer + num_normal_layers
if moe_start_layer != self.end_layer:
if moe_start_layer < self.end_layer:
# if we enable multistream/dbo, process sparse layers here
hidden_states, residual = self._forward_ms_layers(
positions=positions,
Expand Down