
Commit 6676990

tgasser-nv and Pouyanpi authored

feat(benchmark): Create mock LLM server for use in benchmarks (#1403)

* Initial scaffold of mock OpenAI-compatible server
* Refactor mock LLM, fix tests
* Added tests to load YAML config. Still debugging dependency-injection of this into endpoints
* Move FastAPI app import **after** the dependencies are loaded and cached
* Remove debugging print statements
* Temporary checkin
* Add refusal probability and tests to check it
* Use YAML configs for Nemoguard and app LLMs
* Add Mock configs for content-safety and App LLM
* Add async sleep statements and logging to record request time
* Change content-safety mock to have latency of 0.5s
* Add unit-tests to mock llm
* Check for config file
* Rename test files to avoid conflicts with other tests
* Remove example_usage.py script and type-clean config.py
* Regenerate headers with 2023 - 2025
* Removed commented-out code
* review: PR#1403 (#1453)
* test: run_server test coverage pragma no cover
* Update licence

---------

Co-authored-by: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com>

* Apply greptile fixes
* Last couple of cleanups

---------

Co-authored-by: Pouyan <13303554+Pouyanpi@users.noreply.github.com>

1 parent 2d773cc, commit 6676990
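The commit adds an OpenAI-compatible FastAPI app plus its settings module (shown in the file diffs below). As a rough sketch of how a benchmark run might start the mock server: the module path "nemoguardrails.benchmark.mock_llm_server.api" and the settings file name are assumptions for illustration (the file names are collapsed in this diff view); only the MOCK_LLM_CONFIG_FILE environment variable and the app object are taken from the code in this commit.

    import os

    import uvicorn

    # Point the server at a mock-model settings file before the app (and its
    # cached settings) are imported. The env-var name comes from config.py in
    # this commit; the file name is illustrative.
    os.environ["MOCK_LLM_CONFIG_FILE"] = "model_settings.yml"

    # The module path below is an assumption; the actual file name is not
    # visible in this diff view.
    uvicorn.run("nemoguardrails.benchmark.mock_llm_server.api:app", host="0.0.0.0", port=8000)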

File tree

17 files changed: +2189 -0 lines changed

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Lines changed: 237 additions & 0 deletions
@@ -0,0 +1,237 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import asyncio
import logging
import time
from typing import Annotated, Union

from fastapi import Depends, FastAPI, HTTPException, Request

from nemoguardrails.benchmark.mock_llm_server.config import ModelSettings, get_settings
from nemoguardrails.benchmark.mock_llm_server.models import (
    ChatCompletionChoice,
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionChoice,
    CompletionRequest,
    CompletionResponse,
    Message,
    Model,
    ModelsResponse,
    Usage,
)
from nemoguardrails.benchmark.mock_llm_server.response_data import (
    calculate_tokens,
    generate_id,
    get_latency_seconds,
    get_response,
)

# Create a console logging handler
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)  # TODO Control this from the CLi args

# Create a formatter to define the log message format
formatter = logging.Formatter(
    "%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)

# Create a console handler to print logs to the console
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)  # DEBUG and higher will go to the console
console_handler.setFormatter(formatter)

# Add console handler to logs
log.addHandler(console_handler)


ModelSettingsDep = Annotated[ModelSettings, Depends(get_settings)]


def _validate_request_model(
    config: ModelSettingsDep,
    request: Union[CompletionRequest, ChatCompletionRequest],
) -> None:
    """Check the Completion or Chat Completion `model` field is in our supported model list"""
    if request.model != config.model:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{request.model}' not found. Available models: {config.model}",
        )


app = FastAPI(
    title="Mock LLM Server",
    description="OpenAI-compatible mock LLM server for testing and benchmarking",
    version="0.0.1",
)


@app.middleware("http")
async def log_http_duration(request: Request, call_next):
    """
    Middleware to log incoming requests and their responses.
    """
    request_time = time.time()
    response = await call_next(request)
    response_time = time.time()

    duration_seconds = response_time - request_time
    log.info(
        "Request finished: %s, took %.3f seconds",
        response.status_code,
        duration_seconds,
    )
    return response


@app.get("/")
async def root(config: ModelSettingsDep):
    """Root endpoint with basic server information."""
    return {
        "message": "Mock LLM Server",
        "version": "0.0.1",
        "description": f"OpenAI-compatible mock LLM server for model: {config.model}",
        "endpoints": ["/v1/models", "/v1/chat/completions", "/v1/completions"],
        "model_configuration": config,
    }


@app.get("/v1/models", response_model=ModelsResponse)
async def list_models(config: ModelSettingsDep):
    """List available models."""
    log.debug("/v1/models request")

    model = Model(
        id=config.model, object="model", created=int(time.time()), owned_by="system"
    )
    response = ModelsResponse(object="list", data=[model])
    log.debug("/v1/models response: %s", response)
    return response


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(
    request: ChatCompletionRequest, config: ModelSettingsDep
) -> ChatCompletionResponse:
    """Create a chat completion."""

    log.debug("/v1/chat/completions request: %s", request)

    # Validate model exists
    _validate_request_model(config, request)

    # Generate dummy response
    response_content = get_response(config)
    response_latency_seconds = get_latency_seconds(config)

    # Calculate token usage
    prompt_text = " ".join([msg.content for msg in request.messages])
    prompt_tokens = calculate_tokens(prompt_text)
    completion_tokens = calculate_tokens(response_content)

    # Create response
    completion_id = generate_id("chatcmpl")
    created_timestamp = int(time.time())

    choices = []
    for i in range(request.n or 1):
        choice = ChatCompletionChoice(
            index=i,
            message=Message(role="assistant", content=response_content),
            finish_reason="stop",
        )
        choices.append(choice)

    response = ChatCompletionResponse(
        id=completion_id,
        object="chat.completion",
        created=created_timestamp,
        model=request.model,
        choices=choices,
        usage=Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
    await asyncio.sleep(response_latency_seconds)
    log.debug("/v1/chat/completions response: %s", response)
    return response


@app.post("/v1/completions", response_model=CompletionResponse)
async def completions(
    request: CompletionRequest, config: ModelSettingsDep
) -> CompletionResponse:
    """Create a text completion."""

    log.debug("/v1/completions request: %s", request)

    # Validate model exists
    _validate_request_model(config, request)

    # Handle prompt (can be string or list)
    if isinstance(request.prompt, list):
        prompt_text = " ".join(request.prompt)
    else:
        prompt_text = request.prompt

    # Generate dummy response
    response_text = get_response(config)
    response_latency_seconds = get_latency_seconds(config)

    # Calculate token usage
    prompt_tokens = calculate_tokens(prompt_text)
    completion_tokens = calculate_tokens(response_text)

    # Create response
    completion_id = generate_id("cmpl")
    created_timestamp = int(time.time())

    choices = []
    for i in range(request.n or 1):
        choice = CompletionChoice(
            text=response_text, index=i, logprobs=None, finish_reason="stop"
        )
        choices.append(choice)

    response = CompletionResponse(
        id=completion_id,
        object="text_completion",
        created=created_timestamp,
        model=request.model,
        choices=choices,
        usage=Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )

    await asyncio.sleep(response_latency_seconds)
    log.debug("/v1/completions response: %s", response)
    return response


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    log.debug("/health request")
    response = {"status": "healthy", "timestamp": int(time.time())}
    log.debug("/health response: %s", response)
    return response
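For reference, the endpoints defined above can be exercised with any OpenAI-style HTTP client. The sketch below uses requests against a locally running instance; the port and model name are illustrative and must match the server's ModelSettings.

    import requests

    BASE_URL = "http://localhost:8000"  # assumed local port for the mock server

    # Health check endpoint defined above
    print(requests.get(f"{BASE_URL}/health").json())

    # Chat completion; "model" must equal the `model` field in the server's settings
    payload = {
        "model": "meta/llama-3.3-70b-instruct",  # illustrative value
        "messages": [{"role": "user", "content": "Hello"}],
    }
    resp = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload)
    print(resp.json()["choices"][0]["message"]["content"])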
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import lru_cache
from pathlib import Path

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

CONFIG_FILE_ENV_VAR = "MOCK_LLM_CONFIG_FILE"
config_file_path = os.getenv(CONFIG_FILE_ENV_VAR, "model_settings.yml")
CONFIG_FILE = Path(config_file_path)


class ModelSettings(BaseSettings):
    """Pydantic model to configure the Mock LLM Server."""

    # Mandatory fields
    model: str = Field(..., description="Model name served by mock server")
    unsafe_probability: float = Field(
        default=0.1, description="Probability of unsafe response (between 0 and 1)"
    )
    unsafe_text: str = Field(..., description="Refusal response to unsafe prompt")
    safe_text: str = Field(..., description="Safe response")

    # Config with default values
    # Latency sampled from a truncated-normal distribution.
    # Plain Normal distributions have infinite support, and can be negative
    latency_min_seconds: float = Field(
        default=0.1, description="Minimum latency in seconds"
    )
    latency_max_seconds: float = Field(
        default=5, description="Maximum latency in seconds"
    )
    latency_mean_seconds: float = Field(
        default=0.5, description="The average response time in seconds"
    )
    latency_std_seconds: float = Field(
        default=0.1, description="Standard deviation of response time"
    )

    model_config = SettingsConfigDict(env_file=CONFIG_FILE)


@lru_cache()
def get_settings() -> ModelSettings:
    """Singleton-pattern to get settings once via lru_cache"""
    settings = ModelSettings()  # type: ignore (These are filled in by loading from CONFIG_FILE)
    return settings
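ModelSettings above requires a model name plus safe and unsafe response texts, with optional refusal-probability and latency parameters, all loaded from the file named by MOCK_LLM_CONFIG_FILE. A possible model_settings.yml is sketched below; the values are illustrative, and the exact on-disk format handled by the settings loader is not shown in this diff.

    # Illustrative model_settings.yml; keys mirror the ModelSettings fields above.
    model: meta/llama-3.3-70b-instruct
    unsafe_probability: 0.1
    unsafe_text: "I'm sorry, I can't help with that request."
    safe_text: "This is a mock response from the benchmark LLM server."
    latency_min_seconds: 0.1
    latency_max_seconds: 5.0
    latency_mean_seconds: 0.5
    latency_std_seconds: 0.1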
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
models:
  - type: main
    engine: nim
    model: meta/llama-3.3-70b-instruct
    parameters:
      base_url: http://localhost:8000

  - type: content_safety
    engine: nim
    model: nvidia/llama-3.1-nemoguard-8b-content-safety
    parameters:
      base_url: http://localhost:8001


rails:
  input:
    flows:
      - content safety check input $model=content_safety
  output:
    flows:
      - content safety check output $model=content_safety
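This guardrails config points both the main model and the content_safety model at localhost ports 8000 and 8001, which is where the mock servers are expected to listen during a benchmark. A minimal sketch of driving it from Python is shown below; the config directory path is a placeholder, and both mock endpoints must already be running.

    from nemoguardrails import LLMRails, RailsConfig

    # Load the YAML config above from its directory (placeholder path) and send
    # one request through the content-safety input/output rails backed by the
    # mock servers.
    config = RailsConfig.from_path("path/to/benchmark/config")
    rails = LLMRails(config)

    response = rails.generate(messages=[{"role": "user", "content": "Hello!"}])
    print(response["content"])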
