Skip to content

Commit 2fe095c

Browse files
committed
lazy init mcp session
Signed-off-by: wuhang <wuhang6@huawei.com>
1 parent 467a4f9 commit 2fe095c

File tree

2 files changed

+61
-28
lines changed

2 files changed

+61
-28
lines changed

vllm/entrypoints/context.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,14 @@ async def init_tool_sessions(
8989
) -> None:
9090
pass
9191

92+
@abstractmethod
async def __aenter__(self):
    """Async context-manager entry hook.

    Declared abstract; this stub performs no work and yields ``None``.
    """
    return None
95+
96+
@abstractmethod
async def __aexit__(self, exc_type, exc, tb):
    """Async context-manager exit hook.

    Declared abstract; this stub ignores the exception triple and yields
    ``None`` (so it never suppresses an in-flight exception).
    """
    return None
99+
92100
@abstractmethod
async def cleanup_session(self) -> None:
    """Abstract session-cleanup hook.

    Raises:
        NotImplementedError: always, when this stub itself is awaited.
    """
    message = "Should not be called."
    raise NotImplementedError(message)
@@ -129,6 +137,12 @@ async def init_tool_sessions(
129137
) -> None:
130138
pass
131139

140+
async def __aenter__(self):
    """Enter the async context; no setup is performed.

    Returns:
        The context object itself, so ``async with ctx as c`` binds ``c``
        to ``ctx``.
    """
    entered = self
    return entered
142+
143+
async def __aexit__(self, exc_type, exc, tb):
    """Leave the async context; nothing is released here.

    The exception triple is ignored and ``None`` is returned, so an
    exception raised inside the ``async with`` body keeps propagating.
    """
    return None
145+
132146
async def cleanup_session(self) -> None:
    """Session-cleanup hook that is not supported on this context.

    Raises:
        NotImplementedError: always.
    """
    message = "Should not be called."
    raise NotImplementedError(message)
134148

@@ -138,12 +152,17 @@ def __init__(
138152
self,
139153
messages: list,
140154
available_tools: list[str],
155+
tool_server: Optional[ToolServer],
141156
):
142157
self._messages = messages
143158
self.finish_reason: Optional[str] = None
144159
self.available_tools = available_tools
145160
self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
146161
self.called_tools: set[str] = set()
162+
self._tool_server = tool_server
163+
self._async_exit_stack: Optional[AsyncExitStack] = None
164+
self._reference_count = 0
165+
self._reference_count_lock = asyncio.Lock()
147166

148167
self.parser = get_streamable_parser_for_assistant()
149168
self.num_init_messages = len(messages)
@@ -288,6 +307,18 @@ def need_builtin_tool_call(self) -> bool:
288307
or recipient.startswith("container.")
289308
)
290309

310+
async def _get_tool_session(self, tool_name: str) -> Union["ClientSession", Tool]:
    """Return the session for ``tool_name``, opening it on first use.

    When no session is cached for ``tool_name`` and a tool server is
    configured, a new session is requested from the server and entered on
    this context's async exit stack, then cached in ``_tool_sessions``.
    Later calls for the same name return the cached object.

    Args:
        tool_name: Key under which the session is cached and which is
            passed to ``self._tool_server.new_session``.

    Returns:
        The cached (or freshly opened) session object.

    Raises:
        AssertionError: if a session must be opened but the exit stack has
            not been set (i.e. is ``None``).
        KeyError: if no session is cached and no tool server is configured.
    """
    must_open = (
        tool_name not in self._tool_sessions and self._tool_server is not None
    )
    if must_open:
        exit_stack = self._async_exit_stack
        assert exit_stack is not None, (
            "Async exit stack not set. Please report this issue."
        )
        # Entering via the exit stack ties the session's teardown to the
        # stack's own exit rather than to this call.
        session_cm = self._tool_server.new_session(tool_name)
        self._tool_sessions[tool_name] = await exit_stack.enter_async_context(
            session_cm
        )
    return self._tool_sessions[tool_name]
321+
291322
async def call_tool(self) -> list[Message]:
292323
if not self.messages:
293324
return []
@@ -296,15 +327,15 @@ async def call_tool(self) -> list[Message]:
296327
if recipient is not None:
297328
if recipient.startswith("browser."):
298329
return await self.call_search_tool(
299-
self._tool_sessions["browser"], last_msg
330+
await self._get_tool_session("browser"), last_msg
300331
)
301332
elif recipient.startswith("python"):
302333
return await self.call_python_tool(
303-
self._tool_sessions["python"], last_msg
334+
await self._get_tool_session("python"), last_msg
304335
)
305336
elif recipient.startswith("container."):
306337
return await self.call_container_tool(
307-
self._tool_sessions["container"], last_msg
338+
await self._get_tool_session("container"), last_msg
308339
)
309340
raise ValueError("No tool call found")
310341

@@ -431,6 +462,25 @@ async def cleanup_tool_session(tool_session):
431462
)
432463
)
433464

465+
async def __aenter__(self):
    """Enter the context, creating the shared async exit stack on first entry.

    Each entry increments ``_reference_count`` under
    ``_reference_count_lock``. If no exit stack exists yet, one is created
    and entered; subsequent nested/concurrent entries reuse it.

    Returns:
        This context object.

    Raises:
        AssertionError: if the exit stack is missing while the reference
            count (after incrementing) is not exactly 1.
    """
    async with self._reference_count_lock:
        self._reference_count += 1
        if self._async_exit_stack is None:
            # Both halves of the message must sit inside the parentheses;
            # otherwise the second literal is a dead expression statement
            # and the assertion message is cut off mid-sentence.
            assert self._reference_count == 1, (
                "Reference count of exit stack should be "
                "1 when initializing exit stack."
            )
            self._async_exit_stack = AsyncExitStack()
            await self._async_exit_stack.__aenter__()
    return self
476+
477+
async def __aexit__(self, exc_type, exc, tb):
    """Leave the context, closing the shared exit stack on the last exit.

    Each exit decrements ``_reference_count`` under
    ``_reference_count_lock``. Only when the count reaches zero and an exit
    stack exists is the stack exited (with the given exception triple) and
    the attribute reset to ``None``.

    Returns:
        ``None`` in every case; the exit stack's own return value is not
        forwarded, so exceptions from the ``async with`` body propagate.
    """
    async with self._reference_count_lock:
        self._reference_count -= 1
        still_referenced = self._reference_count != 0
        if still_referenced or self._async_exit_stack is None:
            return
        await self._async_exit_stack.__aexit__(exc_type, exc, tb)
        self._async_exit_stack = None
483+
434484

435485
class StreamingHarmonyContext(HarmonyContext):
436486
def __init__(self, *args, **kwargs):

vllm/entrypoints/openai/serving_responses.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import uuid
88
from collections import deque
99
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
10-
from contextlib import AsyncExitStack
1110
from copy import copy
1211
from http import HTTPStatus
1312
from typing import Callable, Final, Optional, Union
@@ -364,9 +363,13 @@ async def create_responses(
364363
context: ConversationContext
365364
if self.use_harmony:
366365
if request.stream:
367-
context = StreamingHarmonyContext(messages, available_tools)
366+
context = StreamingHarmonyContext(
367+
messages, available_tools, self.tool_server
368+
)
368369
else:
369-
context = HarmonyContext(messages, available_tools)
370+
context = HarmonyContext(
371+
messages, available_tools, self.tool_server
372+
)
370373
else:
371374
context = SimpleContext()
372375
generator = self._generate_with_builtin_tools(
@@ -510,22 +513,6 @@ def _make_request_with_harmony(
510513

511514
return messages, [prompt_token_ids], [engine_prompt]
512515

513-
async def _initialize_tool_sessions(
514-
self,
515-
request: ResponsesRequest,
516-
context: ConversationContext,
517-
exit_stack: AsyncExitStack,
518-
):
519-
# we should only initialize the tool session if the request needs tools
520-
if len(request.tools) == 0:
521-
return
522-
mcp_tools = {
523-
tool.server_label: tool for tool in request.tools if tool.type == "mcp"
524-
}
525-
await context.init_tool_sessions(
526-
self.tool_server, exit_stack, request.request_id, mcp_tools
527-
)
528-
529516
async def responses_full_generator(
530517
self,
531518
request: ResponsesRequest,
@@ -540,9 +527,8 @@ async def responses_full_generator(
540527
if created_time is None:
541528
created_time = int(time.time())
542529

543-
async with AsyncExitStack() as exit_stack:
530+
async with context:
544531
try:
545-
await self._initialize_tool_sessions(request, context, exit_stack)
546532
async for _ in result_generator:
547533
pass
548534
except asyncio.CancelledError:
@@ -1809,12 +1795,9 @@ def _increment_sequence_number_and_return(
18091795
sequence_number += 1
18101796
return event
18111797

1812-
async with AsyncExitStack() as exit_stack:
1798+
async with context:
18131799
processer = None
18141800
if self.use_harmony:
1815-
# TODO: in streaming, we noticed this bug:
1816-
# https://github.com/vllm-project/vllm/issues/25697
1817-
await self._initialize_tool_sessions(request, context, exit_stack)
18181801
processer = self._process_harmony_streaming_events
18191802
else:
18201803
processer = self._process_simple_streaming_events

0 commit comments

Comments
 (0)