Skip to content

Commit 173f243

Browse files
committed
rebase main
Signed-off-by: wuhang <wuhang6@huawei.com>
1 parent 2fe095c commit 173f243

File tree

2 files changed

+39
-54
lines changed

2 files changed

+39
-54
lines changed

vllm/entrypoints/context.py

Lines changed: 24 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -79,16 +79,6 @@ def need_builtin_tool_call(self) -> bool:
7979
def render_for_completion(self) -> list[int]:
8080
pass
8181

82-
@abstractmethod
83-
async def init_tool_sessions(
84-
self,
85-
tool_server: Optional[ToolServer],
86-
exit_stack: AsyncExitStack,
87-
request_id: str,
88-
mcp_tools: dict[str, Mcp],
89-
) -> None:
90-
pass
91-
9282
@abstractmethod
9383
async def __aenter__(self):
9484
pass
@@ -128,15 +118,6 @@ async def call_tool(self) -> list[Message]:
128118
def render_for_completion(self) -> list[int]:
129119
raise NotImplementedError("Should not be called.")
130120

131-
async def init_tool_sessions(
132-
self,
133-
tool_server: Optional[ToolServer],
134-
exit_stack: AsyncExitStack,
135-
request_id: str,
136-
mcp_tools: dict[str, Mcp],
137-
) -> None:
138-
pass
139-
140121
async def __aenter__(self):
141122
return self
142123

@@ -153,13 +134,17 @@ def __init__(
153134
messages: list,
154135
available_tools: list[str],
155136
tool_server: Optional[ToolServer],
137+
request_id: str,
138+
mcp_tools: dict[str, Mcp],
156139
):
157140
self._messages = messages
158141
self.finish_reason: Optional[str] = None
159142
self.available_tools = available_tools
160143
self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
161144
self.called_tools: set[str] = set()
162145
self._tool_server = tool_server
146+
self.request_id = request_id
147+
self.mcp_tools = mcp_tools
163148
self._async_exit_stack: Optional[AsyncExitStack] = None
164149
self._reference_count = 0
165150
self._reference_count_lock = asyncio.Lock()
@@ -307,18 +292,6 @@ def need_builtin_tool_call(self) -> bool:
307292
or recipient.startswith("container.")
308293
)
309294

310-
async def _get_tool_session(self, tool_name: str) -> Union["ClientSession", Tool]:
311-
if tool_name not in self._tool_sessions and self._tool_server is not None:
312-
assert self._async_exit_stack is not None, (
313-
"Async exit stack not set. Please report this issue."
314-
)
315-
self._tool_sessions[
316-
tool_name
317-
] = await self._async_exit_stack.enter_async_context(
318-
self._tool_server.new_session(tool_name)
319-
)
320-
return self._tool_sessions[tool_name]
321-
322295
async def call_tool(self) -> list[Message]:
323296
if not self.messages:
324297
return []
@@ -342,6 +315,24 @@ async def call_tool(self) -> list[Message]:
342315
def render_for_completion(self) -> list[int]:
343316
return render_for_completion(self.messages)
344317

318+
async def _get_tool_session(self, tool_name: str) -> Union["ClientSession", Tool]:
319+
if tool_name not in self._tool_sessions and self._tool_server is not None:
320+
assert self._async_exit_stack is not None, (
321+
"Async exit stack not set. Please report this issue."
322+
)
323+
tool_type = _map_tool_name_to_tool_type(tool_name)
324+
headers = (
325+
self.mcp_tools[tool_type].headers
326+
if tool_type in self.mcp_tools
327+
else None
328+
)
329+
self._tool_sessions[
330+
tool_name
331+
] = await self._async_exit_stack.enter_async_context(
332+
self._tool_server.new_session(tool_name, self.request_id, headers)
333+
)
334+
return self._tool_sessions[tool_name]
335+
345336
async def call_search_tool(
346337
self, tool_session: Union["ClientSession", Tool], last_msg: Message
347338
) -> list[Message]:
@@ -387,26 +378,6 @@ async def call_python_tool(
387378
)
388379
]
389380

390-
async def init_tool_sessions(
391-
self,
392-
tool_server: Optional[ToolServer],
393-
exit_stack: AsyncExitStack,
394-
request_id: str,
395-
mcp_tools: dict[str, Mcp],
396-
):
397-
if tool_server:
398-
for tool_name in self.available_tools:
399-
if tool_name not in self._tool_sessions:
400-
tool_type = _map_tool_name_to_tool_type(tool_name)
401-
headers = (
402-
mcp_tools[tool_type].headers if tool_type in mcp_tools else None
403-
)
404-
tool_session = await exit_stack.enter_async_context(
405-
tool_server.new_session(tool_name, request_id, headers)
406-
)
407-
self._tool_sessions[tool_name] = tool_session
408-
exit_stack.push_async_exit(self.cleanup_session)
409-
410381
async def call_container_tool(
411382
self, tool_session: Union["ClientSession", Tool], last_msg: Message
412383
) -> list[Message]:
@@ -468,10 +439,11 @@ async def __aenter__(self):
468439
if self._async_exit_stack is None:
469440
assert self._reference_count == 1, (
470441
"Reference count of exit stack should be "
442+
"1 when initializing exit stack."
471443
)
472-
"1 when initializing exit stack."
473444
self._async_exit_stack = AsyncExitStack()
474445
await self._async_exit_stack.__aenter__()
446+
self._async_exit_stack.push_async_callback(self.cleanup_session)
475447
return self
476448

477449
async def __aexit__(self, exc_type, exc, tb):

vllm/entrypoints/openai/serving_responses.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -360,15 +360,28 @@ async def create_responses(
360360
else await self._get_trace_headers(raw_request.headers)
361361
)
362362

363+
mcp_tools = {
364+
tool.server_label: tool
365+
for tool in request.tools
366+
if tool.type == "mcp"
367+
}
363368
context: ConversationContext
364369
if self.use_harmony:
365370
if request.stream:
366371
context = StreamingHarmonyContext(
367-
messages, available_tools, self.tool_server
372+
messages,
373+
available_tools,
374+
self.tool_server,
375+
request.request_id,
376+
mcp_tools,
368377
)
369378
else:
370379
context = HarmonyContext(
371-
messages, available_tools, self.tool_server
380+
messages,
381+
available_tools,
382+
self.tool_server,
383+
request.request_id,
384+
mcp_tools,
372385
)
373386
else:
374387
context = SimpleContext()

0 commit comments

Comments
 (0)