Skip to content

Commit 8c2072c

Browse files
authored
fix: [trtllm] add wait_for_instance before register_llm (#2683)
Signed-off-by: alec-flowers <aflowers@nvidia.com>
1 parent 63f5bbc commit 8c2072c

File tree

5 files changed: 91 additions and 45 deletions.

components/backends/trtllm/src/dynamo/trtllm/main.py

Lines changed: 17 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -239,17 +239,6 @@ async def init(runtime: DistributedRuntime, config: Config):
239239
runtime_config.reasoning_parser = config.reasoning_parser
240240
runtime_config.tool_call_parser = config.tool_call_parser
241241

242-
if is_first_worker(config):
243-
# Register the model with runtime config
244-
await register_llm(
245-
modelType,
246-
endpoint,
247-
config.model_path,
248-
config.served_model_name,
249-
kv_cache_block_size=config.kv_block_size,
250-
migration_limit=config.migration_limit,
251-
runtime_config=runtime_config,
252-
)
253242
# publisher will be set later if publishing is enabled.
254243
handler_config = RequestHandlerConfig(
255244
component=component,
@@ -262,6 +251,23 @@ async def init(runtime: DistributedRuntime, config: Config):
262251
multimodal_processor=multimodal_processor,
263252
)
264253

254+
if next_client:
255+
logging.info(
256+
f"Waiting for the next endpoint to be ready: {config.next_endpoint}"
257+
)
258+
await next_client.wait_for_instances()
259+
260+
if is_first_worker(config):
261+
# Register the model with runtime config
262+
await register_llm(
263+
modelType,
264+
endpoint,
265+
config.model_path,
266+
config.served_model_name,
267+
kv_cache_block_size=config.kv_block_size,
268+
migration_limit=config.migration_limit,
269+
)
270+
265271
if config.publish_events_and_metrics and is_first_worker(config):
266272
# Initialize and pass in the publisher to the request handler to
267273
# publish events and metrics.

tests/serve/common.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ class EngineConfig:
2323
endpoints: List[str]
2424
response_handlers: List[Callable[[Any], str]]
2525
model: str
26-
timeout: int = 120
26+
timeout: int = 600
2727
delayed_start: int = 0
2828

2929

tests/serve/test_trtllm.py

Lines changed: 4 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,6 @@
2222
class TRTLLMConfig(EngineConfig):
2323
"""Configuration for trtllm test scenarios"""
2424

25-
timeout: int = 60
26-
2725

2826
class TRTLLMProcess(EngineProcess):
2927
"""Simple process manager for trtllm shell scripts"""
@@ -71,9 +69,7 @@ def __init__(self, config: TRTLLMConfig, request):
7169
chat_completions_response_handler,
7270
completions_response_handler,
7371
],
74-
model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
75-
delayed_start=0,
76-
timeout=360,
72+
model="Qwen/Qwen3-0.6B",
7773
),
7874
"disaggregated": TRTLLMConfig(
7975
name="disaggregated",
@@ -85,9 +81,7 @@ def __init__(self, config: TRTLLMConfig, request):
8581
chat_completions_response_handler,
8682
completions_response_handler,
8783
],
88-
model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
89-
delayed_start=0,
90-
timeout=360,
84+
model="Qwen/Qwen3-0.6B",
9185
),
9286
# TODO: These are sanity tests that the kv router examples launch
9387
# and inference without error, but do not do detailed checks on the
@@ -102,9 +96,7 @@ def __init__(self, config: TRTLLMConfig, request):
10296
chat_completions_response_handler,
10397
completions_response_handler,
10498
],
105-
model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
106-
delayed_start=0,
107-
timeout=360,
99+
model="Qwen/Qwen3-0.6B",
108100
),
109101
"disaggregated_router": TRTLLMConfig(
110102
name="disaggregated_router",
@@ -116,9 +108,7 @@ def __init__(self, config: TRTLLMConfig, request):
116108
chat_completions_response_handler,
117109
completions_response_handler,
118110
],
119-
model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
120-
delayed_start=0,
121-
timeout=360,
111+
model="Qwen/Qwen3-0.6B",
122112
),
123113
}
124114

tests/serve/test_vllm.py

Lines changed: 1 addition & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -133,8 +133,6 @@ def __init__(self, config: VLLMConfig, request):
133133
completions_response_handler,
134134
],
135135
model="Qwen/Qwen3-0.6B",
136-
delayed_start=0,
137-
timeout=360,
138136
),
139137
"agg-router": VLLMConfig(
140138
name="agg-router",
@@ -147,8 +145,6 @@ def __init__(self, config: VLLMConfig, request):
147145
completions_response_handler,
148146
],
149147
model="Qwen/Qwen3-0.6B",
150-
delayed_start=0,
151-
timeout=360,
152148
),
153149
"disaggregated": VLLMConfig(
154150
name="disaggregated",
@@ -161,8 +157,6 @@ def __init__(self, config: VLLMConfig, request):
161157
completions_response_handler,
162158
],
163159
model="Qwen/Qwen3-0.6B",
164-
delayed_start=0,
165-
timeout=360,
166160
),
167161
"deepep": VLLMConfig(
168162
name="deepep",
@@ -179,7 +173,6 @@ def __init__(self, config: VLLMConfig, request):
179173
completions_response_handler,
180174
],
181175
model="deepseek-ai/DeepSeek-V2-Lite",
182-
delayed_start=0,
183176
args=[
184177
"--model",
185178
"deepseek-ai/DeepSeek-V2-Lite",
@@ -190,7 +183,7 @@ def __init__(self, config: VLLMConfig, request):
190183
"--gpus-per-node",
191184
"2",
192185
],
193-
timeout=560,
186+
timeout=700,
194187
),
195188
"multimodal_agg_llava": VLLMConfig(
196189
name="multimodal_agg_llava",
@@ -202,9 +195,7 @@ def __init__(self, config: VLLMConfig, request):
202195
chat_completions_response_handler,
203196
],
204197
model="llava-hf/llava-1.5-7b-hf",
205-
delayed_start=0,
206198
args=["--model", "llava-hf/llava-1.5-7b-hf"],
207-
timeout=360,
208199
),
209200
"multimodal_agg_qwen": VLLMConfig(
210201
name="multimodal_agg_qwen",

tests/utils/managed_process.py

Lines changed: 68 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import logging
1818
import os
1919
import shutil
20+
import signal
2021
import socket
2122
import subprocess
2223
import time
@@ -82,6 +83,10 @@ class ManagedProcess:
8283
straggler_commands: List[str] = field(default_factory=list)
8384
log_dir: str = os.getcwd()
8485

86+
# Ensure attributes exist even if startup fails early
87+
proc: Optional[subprocess.Popen] = None
88+
_pgid: Optional[int] = None
89+
8590
_logger = logging.getLogger()
8691
_command_name = None
8792
_log_path = None
@@ -107,20 +112,30 @@ def __enter__(self):
107112

108113
return self
109114

110-
except Exception as e:
111-
self.__exit__(None, None, None)
112-
raise e
115+
except Exception:
116+
try:
117+
self.__exit__(None, None, None)
118+
except Exception as cleanup_err:
119+
self._logger.warning(
120+
"Error during cleanup in __enter__: %s", cleanup_err
121+
)
122+
raise
113123

114124
def __exit__(self, exc_type, exc_val, exc_tb):
125+
self._terminate_process_group()
126+
115127
process_list = [self.proc, self._tee_proc, self._sed_proc]
116128
for process in process_list:
117129
if process:
118-
if process.stdout:
119-
process.stdout.close()
120-
if process.stdin:
121-
process.stdin.close()
122-
terminate_process_tree(process.pid, self._logger)
123-
process.wait()
130+
try:
131+
if process.stdout:
132+
process.stdout.close()
133+
if process.stdin:
134+
process.stdin.close()
135+
terminate_process_tree(process.pid, self._logger)
136+
process.wait()
137+
except Exception as e:
138+
self._logger.warning("Error terminating process: %s", e)
124139
if self.data_dir:
125140
self._remove_directory(self.data_dir)
126141

@@ -169,6 +184,12 @@ def _start_process(self):
169184
stderr=stderr,
170185
start_new_session=True, # Isolate process group to prevent kill 0 from affecting parent
171186
)
187+
# Capture the child's process group id for robust cleanup even if parent shell exits
188+
try:
189+
self._pgid = os.getpgid(self.proc.pid)
190+
except Exception as e:
191+
self._logger.warning("Could not get process group id: %s", e)
192+
self._pgid = None
172193
self._sed_proc = subprocess.Popen(
173194
["sed", "-u", f"s/^/[{self._command_name.upper()}] /"],
174195
stdin=self.proc.stdout,
@@ -190,6 +211,12 @@ def _start_process(self):
190211
stderr=stderr,
191212
start_new_session=True, # Isolate process group to prevent kill 0 from affecting parent
192213
)
214+
# Capture the child's process group id for robust cleanup even if parent shell exits
215+
try:
216+
self._pgid = os.getpgid(self.proc.pid)
217+
except Exception as e:
218+
self._logger.warning("Could not get process group id: %s", e)
219+
self._pgid = None
193220

194221
self._sed_proc = subprocess.Popen(
195222
["sed", "-u", f"s/^/[{self._command_name.upper()}] /"],
@@ -198,6 +225,38 @@ def _start_process(self):
198225
)
199226
self._tee_proc = None
200227

228+
def _terminate_process_group(self, timeout: float = 5.0):
229+
"""Terminate the entire process group/session started for the child.
230+
231+
This catches cases where the launcher shell exits and its children are reparented,
232+
leaving no parent PID to traverse, but they remain in the same process group.
233+
"""
234+
if self._pgid is None:
235+
return
236+
try:
237+
self._logger.info("Terminating process group: %s", self._pgid)
238+
os.killpg(self._pgid, signal.SIGTERM)
239+
except ProcessLookupError:
240+
return
241+
except Exception as e:
242+
self._logger.warning(
243+
"Error sending SIGTERM to process group %s: %s", self._pgid, e
244+
)
245+
return
246+
247+
# Give processes a brief moment to exit gracefully
248+
time.sleep(timeout)
249+
250+
# Force kill if anything remains
251+
try:
252+
os.killpg(self._pgid, signal.SIGKILL)
253+
except ProcessLookupError:
254+
pass
255+
except Exception as e:
256+
self._logger.warning(
257+
"Error sending SIGKILL to process group %s: %s", self._pgid, e
258+
)
259+
201260
def _remove_directory(self, path: str) -> None:
202261
"""Remove a directory."""
203262
try:

0 commit comments

Comments (0)