1 change: 1 addition & 0 deletions tests/serve/conftest.py
@@ -20,6 +20,7 @@

# List of models used in the serve tests
SERVE_TEST_MODELS = [
"Qwen/Qwen3-0.6B",
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"llava-hf/llava-1.5-7b-hf",
]
271 changes: 271 additions & 0 deletions tests/serve/test_vllm.py
@@ -0,0 +1,271 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
import os
import time
from dataclasses import dataclass
from typing import Any, Callable, List

import pytest
import requests

from tests.utils.deployment_graph import (
Payload,
chat_completions_response_handler,
completions_response_handler,
)
from tests.utils.managed_process import ManagedProcess

logger = logging.getLogger(__name__)

text_prompt = "Tell me a short joke about AI."


def create_payload_for_config(config: "VLLMConfig") -> Payload:
"""Create a payload using the model from the vLLM config"""
return Payload(
payload_chat={
"model": config.model,
"messages": [
{
"role": "user",
"content": text_prompt,
}
],
"max_tokens": 150,
"temperature": 0.1,
},
payload_completions={
"model": config.model,
"prompt": text_prompt,
"max_tokens": 150,
"temperature": 0.1,
},
repeat_count=1,
expected_log=[],
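        # The prompt asks for a joke about AI, so "AI" should show up in any
        # sensible completion.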
expected_response=["AI"],
)


@dataclass
class VLLMConfig:
"""Configuration for vLLM test scenarios"""

name: str
directory: str
script_name: str
marks: List[Any]
endpoints: List[str]
response_handlers: List[Callable[[Any], str]]
model: str
    timeout: int = 60  # seconds; budget for readiness checks and requests
    delayed_start: int = 0  # extra startup grace period, in seconds


class VLLMProcess(ManagedProcess):
"""Simple process manager for vllm shell scripts"""

def __init__(self, config: VLLMConfig, request):
self.port = 8080
self.config = config
self.dir = config.directory
script_path = os.path.join(self.dir, "launch", config.script_name)

if not os.path.exists(script_path):
raise FileNotFoundError(f"vLLM script not found: {script_path}")

command = ["bash", script_path]

super().__init__(
command=command,
timeout=config.timeout,
display_output=True,
working_dir=self.dir,
health_check_ports=[], # Disable port health check
health_check_urls=[
(f"http://localhost:{self.port}/v1/models", self._check_models_api)
],
delayed_start=config.delayed_start,
            terminate_existing=False,  # If True, this would kill every bash process, including the test session itself
stragglers=[], # Don't kill any stragglers automatically
log_dir=request.node.name,
)

def _check_models_api(self, response):
"""Check if models API is working and returns models"""
try:
if response.status_code != 200:
return False
data = response.json()
return data.get("data") and len(data["data"]) > 0
except Exception:
return False

def _check_url(self, url, timeout=30, sleep=2.0):
"""Override to use a more reasonable retry interval"""
return super()._check_url(url, timeout, sleep)

def check_response(
self, payload, response, response_handler, logger=logging.getLogger()
):
        assert response.status_code == 200, f"Unexpected status code: {response.status_code}"
content = response_handler(response)
logger.info("Received Content: %s", content)
# Check for expected responses
assert content, "Empty response content"
for expected in payload.expected_response:
assert expected in content, "Expected '%s' not found in response" % expected

def wait_for_ready(self, payload, logger=logging.getLogger()):
url = f"http://localhost:{self.port}/{self.config.endpoints[0]}"
start_time = time.time()
        retry_delay = 5  # seconds between readiness retries
elapsed = 0.0
logger.info("Waiting for Deployment Ready")
json_payload = (
payload.payload_chat
if self.config.endpoints[0] == "v1/chat/completions"
else payload.payload_completions
)

while time.time() - start_time < self.config.timeout:
elapsed = time.time() - start_time
try:
response = requests.post(
url,
json=json_payload,
timeout=self.config.timeout - elapsed,
)
except (requests.RequestException, requests.Timeout) as e:
logger.warning("Retrying due to Request failed: %s", e)
time.sleep(retry_delay)
continue
logger.info("Response%r", response)
if response.status_code == 500:
error = response.json().get("error", "")
if "no instances" in error:
logger.warning("Retrying due to no instances available")
time.sleep(retry_delay)
continue
if response.status_code == 404:
error = response.json().get("error", "")
if "Model not found" in error:
logger.warning("Retrying due to model not found")
time.sleep(retry_delay)
continue
# Process the response
if response.status_code != 200:
logger.error(
"Service returned status code %s: %s",
response.status_code,
response.text,
)
pytest.fail(
"Service returned status code %s: %s"
% (response.status_code, response.text)
)
else:
break
        else:  # while loop exhausted the timeout without a break
logger.error(
"Service did not return a successful response within %s s",
self.config.timeout,
)
pytest.fail(
"Service did not return a successful response within %s s"
% self.config.timeout
)

self.check_response(payload, response, self.config.response_handlers[0], logger)

logger.info("Deployment Ready")


# vLLM test configurations
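# "aggregated" serves the model from a single worker (one GPU);
# "disaggregated" splits serving across workers (two GPUs). This reading is
# inferred from the gpu_1/gpu_2 marks and the agg.sh/disagg.sh scripts.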
vllm_configs = {
"aggregated": VLLMConfig(
name="aggregated",
directory="/workspace/examples/llm",
script_name="agg.sh",
marks=[pytest.mark.gpu_1, pytest.mark.vllm],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="Qwen/Qwen3-0.6B",
delayed_start=45,
),
"disaggregated": VLLMConfig(
name="disaggregated",
directory="/workspace/examples/llm",
script_name="disagg.sh",
marks=[pytest.mark.gpu_2, pytest.mark.vllm],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="Qwen/Qwen3-0.6B",
delayed_start=45,
),
}
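# Each config needs a matching launch/<script_name> under its directory;
# VLLMProcess raises FileNotFoundError if the script is missing.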


@pytest.fixture(
params=[
pytest.param(config_name, marks=config.marks)
for config_name, config in vllm_configs.items()
]
)
def vllm_config_test(request):
"""Fixture that provides different vLLM test configurations"""
return vllm_configs[request.param]


@pytest.mark.e2e
@pytest.mark.slow
def test_serve_deployment(vllm_config_test, request, runtime_services):
"""
Test dynamo serve deployments with different graph configurations.
"""

# runtime_services is used to start nats and etcd

logger = logging.getLogger(request.node.name)
logger.info("Starting test_deployment")

config = vllm_config_test
payload = create_payload_for_config(config)

logger.info("Using model: %s", config.model)
logger.info("Script: %s", config.script_name)

with VLLMProcess(config, request) as server_process:
server_process.wait_for_ready(payload, logger)

for endpoint, response_handler in zip(
config.endpoints, config.response_handlers
):
url = f"http://localhost:{server_process.port}/{endpoint}"
start_time = time.time()
elapsed = 0.0

request_body = (
payload.payload_chat
if endpoint == "v1/chat/completions"
else payload.payload_completions
)

for _ in range(payload.repeat_count):
elapsed = time.time() - start_time

response = requests.post(
url,
json=request_body,
timeout=config.timeout - elapsed,
)
server_process.check_response(
payload, response, response_handler, logger
)
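
A plausible way to run just the single-GPU scenario locally (a sketch, assuming the markers above are registered with pytest in this repo):

    pytest tests/serve/test_vllm.py -m "e2e and vllm and gpu_1"
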
2 changes: 2 additions & 0 deletions tests/utils/managed_process.py
@@ -166,6 +166,7 @@ def _start_process(self):
stdin=stdin,
stdout=stdout,
stderr=stderr,
start_new_session=True, # Isolate process group to prevent kill 0 from affecting parent
)
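        # start_new_session=True makes Popen call setsid() in the child, so
        # the script runs in its own session and process group; a `kill 0`
        # inside the script then signals only that group, not the test runner.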
self._sed_proc = subprocess.Popen(
["sed", "-u", f"s/^/[{self._command_name.upper()}] /"],
@@ -186,6 +187,7 @@ def _start_process(self):
stdin=stdin,
stdout=stdout,
stderr=stderr,
start_new_session=True, # Isolate process group to prevent kill 0 from affecting parent
)

self._sed_proc = subprocess.Popen(