
Commit de7fe38

feat: add vllm e2e integration tests (#1935)
1 parent 860f3f7 commit de7fe38

File tree

3 files changed: +274 −0 lines


tests/serve/conftest.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@

 # List of models used in the serve tests
 SERVE_TEST_MODELS = [
+    "Qwen/Qwen3-0.6B",
     "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
     "llava-hf/llava-1.5-7b-hf",
 ]

tests/serve/test_vllm.py

Lines changed: 271 additions & 0 deletions
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
import os
import time
from dataclasses import dataclass
from typing import Any, Callable, List

import pytest
import requests

from tests.utils.deployment_graph import (
    Payload,
    chat_completions_response_handler,
    completions_response_handler,
)
from tests.utils.managed_process import ManagedProcess

logger = logging.getLogger(__name__)

text_prompt = "Tell me a short joke about AI."


def create_payload_for_config(config: "VLLMConfig") -> Payload:
    """Create a payload using the model from the vLLM config"""
    return Payload(
        payload_chat={
            "model": config.model,
            "messages": [
                {
                    "role": "user",
                    "content": text_prompt,
                }
            ],
            "max_tokens": 150,
            "temperature": 0.1,
        },
        payload_completions={
            "model": config.model,
            "prompt": text_prompt,
            "max_tokens": 150,
            "temperature": 0.1,
        },
        repeat_count=1,
        expected_log=[],
        expected_response=["AI"],
    )


@dataclass
class VLLMConfig:
    """Configuration for vLLM test scenarios"""

    name: str
    directory: str
    script_name: str
    marks: List[Any]
    endpoints: List[str]
    response_handlers: List[Callable[[Any], str]]
    model: str
    timeout: int = 60
    delayed_start: int = 0


class VLLMProcess(ManagedProcess):
    """Simple process manager for vLLM shell scripts"""

    def __init__(self, config: VLLMConfig, request):
        self.port = 8080
        self.config = config
        self.dir = config.directory
        script_path = os.path.join(self.dir, "launch", config.script_name)

        if not os.path.exists(script_path):
            raise FileNotFoundError(f"vLLM script not found: {script_path}")

        command = ["bash", script_path]

        super().__init__(
            command=command,
            timeout=config.timeout,
            display_output=True,
            working_dir=self.dir,
            health_check_ports=[],  # Disable port health check
            health_check_urls=[
                (f"http://localhost:{self.port}/v1/models", self._check_models_api)
            ],
            delayed_start=config.delayed_start,
            terminate_existing=False,  # If True, would kill all bash processes, including this one
            stragglers=[],  # Don't kill any stragglers automatically
            log_dir=request.node.name,
        )

    def _check_models_api(self, response):
        """Check that the models API is working and returns at least one model"""
        try:
            if response.status_code != 200:
                return False
            data = response.json()
            return data.get("data") and len(data["data"]) > 0
        except Exception:
            return False

    def _check_url(self, url, timeout=30, sleep=2.0):
        """Override to use a more reasonable retry interval"""
        return super()._check_url(url, timeout, sleep)

    def check_response(
        self, payload, response, response_handler, logger=logging.getLogger()
    ):
        assert response.status_code == 200, "Response Error"
        content = response_handler(response)
        logger.info("Received Content: %s", content)
        # Check for expected responses
        assert content, "Empty response content"
        for expected in payload.expected_response:
            assert expected in content, "Expected '%s' not found in response" % expected

    def wait_for_ready(self, payload, logger=logging.getLogger()):
        url = f"http://localhost:{self.port}/{self.config.endpoints[0]}"
        start_time = time.time()
        retry_delay = 5
        elapsed = 0.0
        logger.info("Waiting for Deployment Ready")
        json_payload = (
            payload.payload_chat
            if self.config.endpoints[0] == "v1/chat/completions"
            else payload.payload_completions
        )

        while time.time() - start_time < self.config.timeout:
            elapsed = time.time() - start_time
            try:
                response = requests.post(
                    url,
                    json=json_payload,
                    timeout=self.config.timeout - elapsed,
                )
            except (requests.RequestException, requests.Timeout) as e:
                logger.warning("Retrying because the request failed: %s", e)
                time.sleep(retry_delay)
                continue
            logger.info("Response %r", response)
            if response.status_code == 500:
                error = response.json().get("error", "")
                if "no instances" in error:
                    logger.warning("Retrying because no instances are available")
                    time.sleep(retry_delay)
                    continue
            if response.status_code == 404:
                error = response.json().get("error", "")
                if "Model not found" in error:
                    logger.warning("Retrying because the model was not found")
                    time.sleep(retry_delay)
                    continue
            # Process the response
            if response.status_code != 200:
                logger.error(
                    "Service returned status code %s: %s",
                    response.status_code,
                    response.text,
                )
                pytest.fail(
                    "Service returned status code %s: %s"
                    % (response.status_code, response.text)
                )
            else:
                break
        else:
            logger.error(
                "Service did not return a successful response within %s s",
                self.config.timeout,
            )
            pytest.fail(
                "Service did not return a successful response within %s s"
                % self.config.timeout
            )

        self.check_response(payload, response, self.config.response_handlers[0], logger)

        logger.info("Deployment Ready")


# vLLM test configurations
vllm_configs = {
    "aggregated": VLLMConfig(
        name="aggregated",
        directory="/workspace/examples/llm",
        script_name="agg.sh",
        marks=[pytest.mark.gpu_1, pytest.mark.vllm],
        endpoints=["v1/chat/completions", "v1/completions"],
        response_handlers=[
            chat_completions_response_handler,
            completions_response_handler,
        ],
        model="Qwen/Qwen3-0.6B",
        delayed_start=45,
    ),
    "disaggregated": VLLMConfig(
        name="disaggregated",
        directory="/workspace/examples/llm",
        script_name="disagg.sh",
        marks=[pytest.mark.gpu_2, pytest.mark.vllm],
        endpoints=["v1/chat/completions", "v1/completions"],
        response_handlers=[
            chat_completions_response_handler,
            completions_response_handler,
        ],
        model="Qwen/Qwen3-0.6B",
        delayed_start=45,
    ),
}


@pytest.fixture(
    params=[
        pytest.param(config_name, marks=config.marks)
        for config_name, config in vllm_configs.items()
    ]
)
def vllm_config_test(request):
    """Fixture that provides different vLLM test configurations"""
    return vllm_configs[request.param]


@pytest.mark.e2e
@pytest.mark.slow
def test_serve_deployment(vllm_config_test, request, runtime_services):
    """
    Test dynamo serve deployments with different graph configurations.
    """

    # runtime_services is used to start nats and etcd

    logger = logging.getLogger(request.node.name)
    logger.info("Starting test_deployment")

    config = vllm_config_test
    payload = create_payload_for_config(config)

    logger.info("Using model: %s", config.model)
    logger.info("Script: %s", config.script_name)

    with VLLMProcess(config, request) as server_process:
        server_process.wait_for_ready(payload, logger)

        for endpoint, response_handler in zip(
            config.endpoints, config.response_handlers
        ):
            url = f"http://localhost:{server_process.port}/{endpoint}"
            start_time = time.time()
            elapsed = 0.0

            request_body = (
                payload.payload_chat
                if endpoint == "v1/chat/completions"
                else payload.payload_completions
            )

            for _ in range(payload.repeat_count):
                elapsed = time.time() - start_time

                response = requests.post(
                    url,
                    json=request_body,
                    timeout=config.timeout - elapsed,
                )
                server_process.check_response(
                    payload, response, response_handler, logger
                )
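Taken together, the test launches one of the example scripts, waits for the OpenAI-compatible endpoints to come up, and then checks that each endpoint returns the expected content. The same probe can be run by hand against a deployment started with one of the launch scripts; the sketch below is a simplified, standalone version of that readiness-and-response check (the base URL, port, and model name are assumptions taken from the defaults above, not guaranteed by this commit):

import time

import requests

BASE_URL = "http://localhost:8080"  # assumed default port, matching VLLMProcess above
MODEL = "Qwen/Qwen3-0.6B"           # assumed model, matching the configs above
DEADLINE_S = 60
RETRY_DELAY_S = 5

payload = {
    "model": MODEL,
    "messages": [{"role": "user", "content": "Tell me a short joke about AI."}],
    "max_tokens": 150,
    "temperature": 0.1,
}

start = time.time()
while time.time() - start < DEADLINE_S:
    try:
        resp = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, timeout=10)
    except requests.RequestException:
        time.sleep(RETRY_DELAY_S)  # server not listening yet; retry
        continue
    if resp.status_code in (404, 500):
        time.sleep(RETRY_DELAY_S)  # model or workers may still be registering; retry
        continue
    resp.raise_for_status()
    content = resp.json()["choices"][0]["message"]["content"]
    assert content, "Empty response content"
    print(content)
    break
else:
    raise TimeoutError(f"Service not ready within {DEADLINE_S}s")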

tests/utils/managed_process.py

Lines changed: 2 additions & 0 deletions
@@ -166,6 +166,7 @@ def _start_process(self):
             stdin=stdin,
             stdout=stdout,
             stderr=stderr,
+            start_new_session=True,  # Isolate process group to prevent kill 0 from affecting parent
         )
         self._sed_proc = subprocess.Popen(
             ["sed", "-u", f"s/^/[{self._command_name.upper()}] /"],
@@ -186,6 +187,7 @@ def _start_process(self):
             stdin=stdin,
             stdout=stdout,
             stderr=stderr,
+            start_new_session=True,  # Isolate process group to prevent kill 0 from affecting parent
         )

         self._sed_proc = subprocess.Popen(