Commit 9302009

alec-flowers authored and nnshah1 committed

fix: pytest robustness and parsing error (#2676)

Signed-off-by: nnshah1 <neelays@nvidia.com>

1 parent 6a39968 commit 9302009
File tree

9 files changed: +449 -322 lines

examples/multimodal/components/encode_worker.py

Lines changed: 22 additions & 14 deletions
@@ -28,7 +28,7 @@
 from vllm.utils import FlexibleArgumentParser
 
 import dynamo.nixl_connect as connect
-from dynamo.runtime import DistributedRuntime, dynamo_worker
+from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
 from dynamo.runtime.logging import configure_dynamo_logging
 
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
@@ -56,8 +56,13 @@
 
 
 class VllmEncodeWorker:
-    def __init__(self, args: argparse.Namespace, engine_args: AsyncEngineArgs) -> None:
-        self.downstream_endpoint = args.downstream_endpoint
+    def __init__(
+        self,
+        args: argparse.Namespace,
+        engine_args: AsyncEngineArgs,
+        pd_worker_client: Client,
+    ) -> None:
+        self.pd_worker_client = pd_worker_client
         self.engine_args = engine_args
         self.model = self.engine_args.model
 
@@ -178,16 +183,6 @@ async def generate(
 
     async def async_init(self, runtime: DistributedRuntime):
         logger.info("Startup started.")
-        parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
-            self.downstream_endpoint
-        )
-        self.pd_worker_client = (
-            await runtime.namespace(parsed_namespace)
-            .component(parsed_component_name)
-            .endpoint(parsed_endpoint_name)
-            .client()
-        )
-
         # Create and initialize a dynamo connector for this worker.
         # We'll need this to move data between this worker and remote workers efficiently.
         self._connector = connect.Connector()
@@ -262,9 +257,22 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
 
     generate_endpoint = component.endpoint(config.endpoint)
 
-    handler = VllmEncodeWorker(args, config.engine_args)
+    parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
+        args.downstream_endpoint
+    )
+    pd_worker_client = (
+        await runtime.namespace(parsed_namespace)
+        .component(parsed_component_name)
+        .endpoint(parsed_endpoint_name)
+        .client()
+    )
+
+    handler = VllmEncodeWorker(args, config.engine_args, pd_worker_client)
     await handler.async_init(runtime)
 
+    logger.info("Waiting for PD Worker Instances ...")
+    await pd_worker_client.wait_for_instances()
+
     logger.info(f"Starting to serve the {args.endpoint} endpoint...")
 
     try:
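One practical upside of injecting the client through the constructor (instead of resolving it inside async_init) is testability: a test can now hand VllmEncodeWorker a stub client without standing up a full runtime. A minimal sketch, assuming only the constructor signature shown above; FakeClient is hypothetical and fakes just the methods a test would touch:

# Hypothetical stub for unit tests; not part of this commit.
class FakeClient:
    def instance_ids(self):
        return [0]

    async def wait_for_instances(self):
        return [0]

worker = VllmEncodeWorker(args, engine_args, FakeClient())  # no runtime needed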

examples/multimodal/components/processor.py

Lines changed: 22 additions & 16 deletions
@@ -33,7 +33,7 @@
 from vllm.utils import FlexibleArgumentParser
 
 from dynamo.llm import ModelType, register_llm
-from dynamo.runtime import DistributedRuntime, dynamo_worker
+from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
 from dynamo.runtime.logging import configure_dynamo_logging
 
 # To import example local module
@@ -96,9 +96,14 @@ def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
 
         return args, config
 
-    def __init__(self, args: argparse.Namespace, engine_args: AsyncEngineArgs):
+    def __init__(
+        self,
+        args: argparse.Namespace,
+        engine_args: AsyncEngineArgs,
+        encode_worker_client: Client,
+    ):
+        self.encode_worker_client = encode_worker_client
         self.prompt_template = args.prompt_template
-        self.downstream_endpoint = args.downstream_endpoint
         self.engine_args = engine_args
         self.model_config = self.engine_args.create_model_config()
         self.default_sampling_params = self.model_config.get_diff_sampling_param()
@@ -125,17 +130,6 @@ def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
         )
         return base_tokenizer
 
-    async def async_init(self, runtime: DistributedRuntime):
-        parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
-            self.downstream_endpoint
-        )
-        self.encode_worker_client = (
-            await runtime.namespace(parsed_namespace)
-            .component(parsed_component_name)
-            .endpoint(parsed_endpoint_name)
-            .client()
-        )
-
     # Main method to parse the request and send the request to the vllm worker.
     async def _generate(
         self,
@@ -300,8 +294,20 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
 
     generate_endpoint = component.endpoint(config.endpoint)
 
-    handler = Processor(args, config.engine_args)
-    await handler.async_init(runtime)
+    parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
+        args.downstream_endpoint
+    )
+    encode_worker_client = (
+        await runtime.namespace(parsed_namespace)
+        .component(parsed_component_name)
+        .endpoint(parsed_endpoint_name)
+        .client()
+    )
+
+    handler = Processor(args, config.engine_args, encode_worker_client)
+
+    logger.info("Waiting for Encoder Worker Instances ...")
+    await encode_worker_client.wait_for_instances()
 
     # Register the endpoint as entrypoint to a model
     await register_llm(
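As in the encode worker, the client is resolved in init and injected, but here the readiness wait also runs before register_llm, so the model is not advertised to the frontend until its encode dependency is reachable. The gate reduced to a sketch (ns, comp, and ep are stand-ins for the parsed endpoint parts):

# Resolve -> construct -> wait -> advertise, in that order.
encode_worker_client = (
    await runtime.namespace(ns).component(comp).endpoint(ep).client()
)
handler = Processor(args, config.engine_args, encode_worker_client)
await encode_worker_client.wait_for_instances()  # at least one encoder is live
# ... only now is register_llm(...) called, making the model discoverable.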

lib/bindings/python/src/dynamo/_core.pyi

Lines changed: 18 additions & 0 deletions
@@ -246,6 +246,24 @@ class Client:
 
         ...
 
+    def instance_ids(self) -> List[int]:
+        """
+        Get list of current instance IDs.
+
+        Returns:
+            A list of currently available instance IDs
+        """
+        ...
+
+    async def wait_for_instances(self) -> List[int]:
+        """
+        Wait for instances to be available for work and return their IDs.
+
+        Returns:
+            A list of instance IDs that are available for work
+        """
+        ...
+
     async def random(self, request: JsonLike) -> AsyncIterator[JsonLike]:
         """
         Pick a random instance of the endpoint and issue the request
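A minimal sketch of how a caller might combine the two new methods; the Client here is assumed to come from the usual runtime.namespace(...).component(...).endpoint(...).client() chain:

async def wait_then_report(client: Client) -> None:
    # Blocks until the endpoint has at least one instance available for work.
    ids = await client.wait_for_instances()
    print(f"instances available for work: {ids}")
    # Non-blocking snapshot of the currently known instance IDs.
    print(f"current snapshot: {client.instance_ids()}")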

tests/serve/common.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Common base classes and utilities for engine tests (vLLM, TRT-LLM, etc.)"""
+
+from dataclasses import dataclass
+from typing import Any, Callable, List
+
+from tests.utils.deployment_graph import Payload
+
+# Common text prompt used across tests
+TEXT_PROMPT = "Tell me a short joke about AI."
+
+
+@dataclass
+class EngineConfig:
+    """Base configuration for engine test scenarios"""
+
+    name: str
+    directory: str
+    script_name: str
+    marks: List[Any]
+    endpoints: List[str]
+    response_handlers: List[Callable[[Any], str]]
+    model: str
+    timeout: int = 120
+    delayed_start: int = 0
+
+
+def create_payload_for_config(config: EngineConfig) -> Payload:
+    """Create a standard payload using the model from the engine config.
+
+    This provides the default implementation for text-only models.
+    """
+    return Payload(
+        payload_chat={
+            "model": config.model,
+            "messages": [
+                {
+                    "role": "user",
+                    "content": TEXT_PROMPT,
+                }
+            ],
+            "max_tokens": 150,
+            "temperature": 0.1,
+            "stream": False,
+        },
+        payload_completions={
+            "model": config.model,
+            "prompt": TEXT_PROMPT,
+            "max_tokens": 150,
+            "temperature": 0.1,
+            "stream": False,
+        },
+        repeat_count=3,
+        expected_log=[],
+        expected_response=["AI"],
+    )
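For context, a sketch of how a test scenario might consume this helper; the config values (model name, script, endpoint) are illustrative placeholders, not part of this commit:

from tests.serve.common import EngineConfig, create_payload_for_config

config = EngineConfig(
    name="vllm-agg",                      # placeholder scenario name
    directory="examples/llm",             # placeholder launch directory
    script_name="agg.sh",                 # placeholder launch script
    marks=[],
    endpoints=["v1/chat/completions"],
    response_handlers=[lambda r: r["choices"][0]["message"]["content"]],
    model="Qwen/Qwen2.5-0.5B-Instruct",   # any served model name works here
)

payload = create_payload_for_config(config)
assert payload.repeat_count == 3          # defaults baked in by the helper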

0 commit comments