@@ -23,11 +23,14 @@
 import os
 from unittest.mock import patch
 
+import pytest
+import vllm  # noqa: F401
 from modelscope import snapshot_download  # type: ignore
 from vllm import SamplingParams
 from vllm.model_executor.models.registry import ModelRegistry
 
 from tests.conftest import VllmRunner
+from tests.model_utils import check_outputs_equal
 
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
 
@@ -47,21 +50,6 @@ def test_models_distributed_QwQ():
         vllm_model.generate_greedy(example_prompts, max_tokens)
 
 
-def test_models_distributed_DeepSeek():
-    example_prompts = [
-        "Hello, my name is",
-    ]
-    dtype = "half"
-    max_tokens = 5
-    with VllmRunner(
-            "deepseek-ai/DeepSeek-V2-Lite",
-            dtype=dtype,
-            tensor_parallel_size=4,
-            distributed_executor_backend="mp",
-    ) as vllm_model:
-        vllm_model.generate_greedy(example_prompts, max_tokens)
-
-
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1"})
 def test_models_distributed_topk() -> None:
     example_prompts = [
@@ -84,26 +72,6 @@ def test_models_distributed_topk() -> None:
         vllm_model.generate(example_prompts, sampling_params)
 
 
-@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
-def test_models_distributed_DeepSeek_dbo():
-    example_prompts = ["The president of the United States is"] * 41
-    dtype = "half"
-    sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
-    with VllmRunner(
-            "deepseek-ai/DeepSeek-V2-Lite",
-            dtype=dtype,
-            tensor_parallel_size=4,
-            distributed_executor_backend="mp",
-    ) as vllm_model:
-        model_arch = 'DeepseekV2ForCausalLM'
-        registed_models = ModelRegistry.models
-        assert registed_models[
-            model_arch].module_name == "vllm_ascend.models.deepseek_dbo"
-        assert registed_models[
-            model_arch].class_name == "CustomDeepseekDBOForCausalLM"
-        vllm_model.generate(example_prompts, sampling_params)
-
-
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
 def test_models_distributed_DeepSeekV3_dbo():
     example_prompts = ["The president of the United States is"] * 41
@@ -139,3 +107,41 @@ def test_models_distributed_DeepSeek_W8A8():
         quantization="ascend",
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+def test_models_distributed_DeepSeek_dbo(monkeypatch: pytest.MonkeyPatch):
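+    # Run the same prompts twice, once with dual batch overlap (DBO) enabled
+    # and once with it disabled, then check that the outputs match.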
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_ASCEND_ENABLE_DBO", "1")
+
+        example_prompts = ["The president of the United States is"] * 41
+        dtype = "half"
+        sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
+        with VllmRunner(
+                "deepseek-ai/DeepSeek-V2-Lite",
+                dtype=dtype,
+                tensor_parallel_size=4,
+                distributed_executor_backend="mp",
+        ) as vllm_model:
+            dbo_output = vllm_model.generate(example_prompts, sampling_params)
+
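+    # Reference run: same prompts and sampling parameters with DBO disabled.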
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_ASCEND_ENABLE_DBO", "0")
+        with VllmRunner(
+                "deepseek-ai/DeepSeek-V2-Lite",
+                dtype=dtype,
+                tensor_parallel_size=4,
+                distributed_executor_backend="mp",
+        ) as vllm_model:
+            output = vllm_model.generate(example_prompts, sampling_params)
+
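+    # Decoding is greedy (temperature=0.0), so enabling DBO should not change
+    # the generated tokens.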
+    check_outputs_equal(
+        outputs_0_lst=output,
+        outputs_1_lst=dbo_output,
+        name_0="vllm_outputs",
+        name_1="vllm_dbo_outputs",
+    )