Skip to content

Commit 3e1c0b3

Browse files
committed
add hold to wait for next worker
1 parent 9b9f2ce commit 3e1c0b3

File tree

1 file changed

18 additions and 10 deletions (+18 / -10 lines changed)
  • components/backends/trtllm/src/dynamo/trtllm

1 file changed

18 additions and 10 deletions (+18 / -10 lines changed)

components/backends/trtllm/src/dynamo/trtllm/main.py

Lines changed: 18 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -228,16 +228,6 @@ async def init(runtime: DistributedRuntime, config: Config):
228228
async with get_llm_engine(engine_args) as engine:
229229
endpoint = component.endpoint(config.endpoint)
230230

231-
if is_first_worker(config):
232-
# Register the model with runtime config
233-
await register_llm(
234-
modelType,
235-
endpoint,
236-
config.model_path,
237-
config.served_model_name,
238-
kv_cache_block_size=config.kv_block_size,
239-
migration_limit=config.migration_limit,
240-
)
241231
# publisher will be set later if publishing is enabled.
242232
handler_config = RequestHandlerConfig(
243233
component=component,
@@ -250,6 +240,23 @@ async def init(runtime: DistributedRuntime, config: Config):
250240
multimodal_processor=multimodal_processor,
251241
)
252242

243+
if next_client:
244+
logging.info(
245+
f"Waiting for the next endpoint to be ready: {config.next_endpoint}"
246+
)
247+
next_client.wait_for_instances()
248+
249+
if is_first_worker(config):
250+
# Register the model with runtime config
251+
await register_llm(
252+
modelType,
253+
endpoint,
254+
config.model_path,
255+
config.served_model_name,
256+
kv_cache_block_size=config.kv_block_size,
257+
migration_limit=config.migration_limit,
258+
)
259+
253260
if config.publish_events_and_metrics and is_first_worker(config):
254261
# Initialize and pass in the publisher to the request handler to
255262
# publish events and metrics.
@@ -265,6 +272,7 @@ async def init(runtime: DistributedRuntime, config: Config):
265272
) as publisher:
266273
handler_config.publisher = publisher
267274
handler = RequestHandlerFactory().get_request_handler(handler_config)
275+
268276
await endpoint.serve_endpoint(handler.generate)
269277
else:
270278
handler = RequestHandlerFactory().get_request_handler(handler_config)

0 commit comments

Comments
 (0)