88import pytest
99from prometheus_client import REGISTRY
1010
11+ import vllm .envs as envs
1112from vllm import SamplingParams
1213from vllm .core .scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT ,
1314 ENABLE_ARTIFICIAL_PREEMPT )
2425 "tests/basic_correctness/test_preemption.py`" )
2526
2627
@pytest.fixture
def worker_use_ray() -> bool:
    """Value for the ``worker_use_ray`` argument used by the tests below.

    When the SPMD worker is enabled (``VLLM_USE_RAY_SPMD_WORKER``), the
    tests run with ``worker_use_ray=True`` so that the delta-input
    optimization is exercised together with preemption.
    """
    # NOTE: the original comment said "ray_use_worker=True"; the actual
    # parameter name is ``worker_use_ray`` (see the test signatures below).
    return envs.VLLM_USE_RAY_SPMD_WORKER
34+
2735@pytest .mark .parametrize ("model" , MODELS )
2836@pytest .mark .parametrize ("dtype" , ["half" ])
2937@pytest .mark .parametrize ("max_tokens" , [96 ])
@@ -36,6 +44,7 @@ def test_chunked_prefill_recompute(
3644 dtype : str ,
3745 max_tokens : int ,
3846 chunked_prefill_token_size : int ,
47+ worker_use_ray : bool ,
3948) -> None :
4049 """Ensure that chunked prefill works with preemption."""
4150 max_num_seqs = min (chunked_prefill_token_size , 256 )
@@ -54,6 +63,7 @@ def test_chunked_prefill_recompute(
5463 max_num_batched_tokens = max_num_batched_tokens ,
5564 enable_chunked_prefill = enable_chunked_prefill ,
5665 max_num_seqs = max_num_seqs ,
66+ worker_use_ray = worker_use_ray ,
5767 ) as vllm_model :
5868 vllm_outputs = vllm_model .generate_greedy (example_prompts , max_tokens )
5969 assert (vllm_model .model .llm_engine .scheduler [0 ].artificial_preempt_cnt
@@ -79,6 +89,7 @@ def test_preemption(
7989 model : str ,
8090 dtype : str ,
8191 max_tokens : int ,
92+ worker_use_ray : bool ,
8293) -> None :
8394 """By default, recompute preemption is enabled"""
8495
@@ -89,6 +100,7 @@ def test_preemption(
89100 model ,
90101 dtype = dtype ,
91102 disable_log_stats = False ,
103+ worker_use_ray = worker_use_ray ,
92104 ) as vllm_model :
93105 vllm_outputs = vllm_model .generate_greedy (example_prompts , max_tokens )
94106 assert (vllm_model .model .llm_engine .scheduler [0 ].artificial_preempt_cnt
@@ -132,6 +144,7 @@ def test_swap(
132144 dtype : str ,
133145 max_tokens : int ,
134146 beam_width : int ,
147+ worker_use_ray : bool ,
135148) -> None :
136149 """Use beam search enables swapping."""
137150 example_prompts = example_prompts [:1 ]
@@ -144,6 +157,7 @@ def test_swap(
144157 dtype = dtype ,
145158 swap_space = 10 ,
146159 disable_log_stats = False ,
160+ worker_use_ray = worker_use_ray ,
147161 ) as vllm_model :
148162 vllm_outputs = vllm_model .generate_beam_search (example_prompts ,
149163 beam_width , max_tokens )
@@ -188,6 +202,7 @@ def test_swap_infeasible(
188202 dtype : str ,
189203 max_tokens : int ,
190204 beam_width : int ,
205+ worker_use_ray : bool ,
191206) -> None :
192207 """Verify infeasible swap request will be ignored."""
193208 BLOCK_SIZE = 16
@@ -204,6 +219,7 @@ def test_swap_infeasible(
204219 # decode blocks are not enough to finish.
205220 num_gpu_blocks_override = prefill_blocks + decode_blocks ,
206221 max_model_len = (prefill_blocks + decode_blocks ) * BLOCK_SIZE ,
222+ worker_use_ray = worker_use_ray ,
207223 ) as vllm_model :
208224 sampling_params = SamplingParams (n = beam_width ,
209225 use_beam_search = True ,
@@ -230,6 +246,7 @@ def test_preemption_infeasible(
230246 model : str ,
231247 dtype : str ,
232248 max_tokens : int ,
249+ worker_use_ray : bool ,
233250) -> None :
234251 """Verify infeasible preemption request will be ignored."""
235252 BLOCK_SIZE = 16
@@ -244,6 +261,7 @@ def test_preemption_infeasible(
244261 # ignored instead of hanging forever.
245262 num_gpu_blocks_override = prefill_blocks + decode_blocks // 2 ,
246263 max_model_len = ((prefill_blocks + decode_blocks // 2 ) * BLOCK_SIZE ),
264+ worker_use_ray = worker_use_ray ,
247265 ) as vllm_model :
248266 sampling_params = SamplingParams (max_tokens = max_tokens ,
249267 ignore_eos = True )
0 commit comments