Skip to content

Commit e72ca86

Browse files
committed
lazy init mcp session
Signed-off-by: wuhang <wuhang6@huawei.com>
1 parent f0862ea commit e72ca86

File tree

2 files changed

+61
-27
lines changed

2 files changed

+61
-27
lines changed

vllm/entrypoints/context.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,14 @@ async def init_tool_sessions(
104104
) -> None:
105105
pass
106106

107+
@abstractmethod
async def __aenter__(self):
    """Enter the async context; concrete contexts must override this."""
    return None
110+
111+
@abstractmethod
async def __aexit__(self, exc_type, exc, tb):
    """Leave the async context; concrete contexts must override this."""
    return None
114+
107115
@abstractmethod
108116
async def cleanup_session(self) -> None:
109117
raise NotImplementedError("Should not be called.")
@@ -146,6 +154,12 @@ async def init_tool_sessions(
146154
) -> None:
147155
pass
148156

157+
async def __aenter__(self):
    """Enter the async context without acquiring any resources.

    Returns:
        This same object, so it can be bound by ``async with ... as``.
    """
    entered = self
    return entered
159+
160+
async def __aexit__(self, exc_type, exc, tb):
    """Leave the async context; there is nothing to release.

    Returns ``None``, so an exception raised inside the ``async with``
    body is not suppressed.
    """
    return None
162+
149163
async def cleanup_session(self) -> None:
150164
raise NotImplementedError("Should not be called.")
151165

@@ -155,12 +169,17 @@ def __init__(
155169
self,
156170
messages: list,
157171
available_tools: list[str],
172+
tool_server: Optional[ToolServer],
158173
):
159174
self._messages = messages
160175
self.finish_reason: str | None = None
161176
self.available_tools = available_tools
162177
self._tool_sessions: dict[str, ClientSession | Tool] = {}
163178
self.called_tools: set[str] = set()
179+
self._tool_server = tool_server
180+
self._async_exit_stack: Optional[AsyncExitStack] = None
181+
self._reference_count = 0
182+
self._reference_count_lock = asyncio.Lock()
164183

165184
self.parser = get_streamable_parser_for_assistant()
166185
self.num_init_messages = len(messages)
@@ -309,6 +328,18 @@ def need_builtin_tool_call(self) -> bool:
309328
or recipient.startswith("container.")
310329
)
311330

331+
async def _get_tool_session(self, tool_name: str) -> Union["ClientSession", Tool]:
    """Return the session for ``tool_name``, opening it on first use.

    When no session is cached for the tool and a tool server is set, a new
    session from ``self._tool_server.new_session(tool_name)`` is entered on
    the context's async exit stack and cached in ``self._tool_sessions``.

    Raises:
        AssertionError: A session must be opened but no exit stack is set.
        KeyError: No session is cached and no tool server is set.
    """
    not_cached = tool_name not in self._tool_sessions
    if not_cached and self._tool_server is not None:
        exit_stack = self._async_exit_stack
        assert exit_stack is not None, (
            "Async exit stack not set. Please report this issue."
        )
        # Entering through the exit stack ties the session's teardown to
        # the moment that stack is closed.
        session_cm = self._tool_server.new_session(tool_name)
        opened = await exit_stack.enter_async_context(session_cm)
        self._tool_sessions[tool_name] = opened
    return self._tool_sessions[tool_name]
342+
312343
async def call_tool(self) -> list[Message]:
313344
if not self.messages:
314345
return []
@@ -317,15 +348,15 @@ async def call_tool(self) -> list[Message]:
317348
if recipient is not None:
318349
if recipient.startswith("browser."):
319350
return await self.call_search_tool(
320-
self._tool_sessions["browser"], last_msg
351+
await self._get_tool_session("browser"), last_msg
321352
)
322353
elif recipient.startswith("python"):
323354
return await self.call_python_tool(
324-
self._tool_sessions["python"], last_msg
355+
await self._get_tool_session("python"), last_msg
325356
)
326357
elif recipient.startswith("container."):
327358
return await self.call_container_tool(
328-
self._tool_sessions["container"], last_msg
359+
await self._get_tool_session("container"), last_msg
329360
)
330361
raise ValueError("No tool call found")
331362

@@ -452,6 +483,25 @@ async def cleanup_tool_session(tool_session):
452483
)
453484
)
454485

486+
async def __aenter__(self):
    """Enter the context, creating the shared exit stack on first entry.

    Every entry increments ``self._reference_count`` while holding
    ``self._reference_count_lock``. An ``AsyncExitStack`` is created and
    entered only when ``self._async_exit_stack`` is currently ``None``.

    Returns:
        This context object.

    Raises:
        AssertionError: No exit stack exists yet, but the reference count
            is not exactly 1 after being incremented.
    """
    async with self._reference_count_lock:
        self._reference_count += 1
        if self._async_exit_stack is None:
            # Both string fragments must sit inside the parentheses so
            # they concatenate into one complete assertion message.
            assert self._reference_count == 1, (
                "Reference count of exit stack should be "
                "1 when initializing exit stack."
            )
            self._async_exit_stack = AsyncExitStack()
            await self._async_exit_stack.__aenter__()
    return self
497+
498+
async def __aexit__(self, exc_type, exc, tb):
    """Leave the context, closing the shared exit stack on the last exit.

    Decrements ``self._reference_count`` while holding
    ``self._reference_count_lock``. When the count reaches zero and an exit
    stack exists, the stack's ``__aexit__`` is awaited with the same
    exception details and ``self._async_exit_stack`` is reset to ``None``.

    Returns ``None``, so an exception raised inside the ``async with``
    body is not suppressed.
    """
    async with self._reference_count_lock:
        self._reference_count -= 1
        if self._reference_count == 0 and self._async_exit_stack is not None:
            try:
                await self._async_exit_stack.__aexit__(exc_type, exc, tb)
            finally:
                # Reset even when teardown raises or is cancelled, so a
                # zero reference count never coexists with a stale stack.
                self._async_exit_stack = None
504+
455505

456506
class StreamingHarmonyContext(HarmonyContext):
457507
def __init__(self, *args, **kwargs):

vllm/entrypoints/openai/serving_responses.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -352,9 +352,13 @@ async def create_responses(
352352
context: ConversationContext
353353
if self.use_harmony:
354354
if request.stream:
355-
context = StreamingHarmonyContext(messages, available_tools)
355+
context = StreamingHarmonyContext(
356+
messages, available_tools, self.tool_server
357+
)
356358
else:
357-
context = HarmonyContext(messages, available_tools)
359+
context = HarmonyContext(
360+
messages, available_tools, self.tool_server
361+
)
358362
else:
359363
context = SimpleContext()
360364
generator = self._generate_with_builtin_tools(
@@ -498,22 +502,6 @@ def _make_request_with_harmony(
498502

499503
return messages, [prompt_token_ids], [engine_prompt]
500504

501-
async def _initialize_tool_sessions(
502-
self,
503-
request: ResponsesRequest,
504-
context: ConversationContext,
505-
exit_stack: AsyncExitStack,
506-
):
507-
# we should only initialize the tool session if the request needs tools
508-
if len(request.tools) == 0:
509-
return
510-
mcp_tools = {
511-
tool.server_label: tool for tool in request.tools if tool.type == "mcp"
512-
}
513-
await context.init_tool_sessions(
514-
self.tool_server, exit_stack, request.request_id, mcp_tools
515-
)
516-
517505
async def responses_full_generator(
518506
self,
519507
request: ResponsesRequest,
@@ -528,9 +516,8 @@ async def responses_full_generator(
528516
if created_time is None:
529517
created_time = int(time.time())
530518

531-
async with AsyncExitStack() as exit_stack:
519+
async with context:
532520
try:
533-
await self._initialize_tool_sessions(request, context, exit_stack)
534521
async for _ in result_generator:
535522
pass
536523
except asyncio.CancelledError:
@@ -1894,12 +1881,9 @@ def _increment_sequence_number_and_return(
18941881
sequence_number += 1
18951882
return event
18961883

1897-
async with AsyncExitStack() as exit_stack:
1884+
async with context:
18981885
processer = None
18991886
if self.use_harmony:
1900-
# TODO: in streaming, we noticed this bug:
1901-
# https://github.com/vllm-project/vllm/issues/25697
1902-
await self._initialize_tool_sessions(request, context, exit_stack)
19031887
processer = self._process_harmony_streaming_events
19041888
else:
19051889
processer = self._process_simple_streaming_events

0 commit comments

Comments
 (0)