
Commit 170129e

Authored by morgendave, zhiweiz, aarnphm, simon-mo, and houseroad
[gpt-oss] Harmony changes with container tool support (#23386)
Signed-off-by: zhiweiz <zhiweiz@fb.com>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
Co-authored-by: zhiweiz <zhiweiz@fb.com>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
1 parent 955c624 commit 170129e

File tree

5 files changed: +170 additions, -27 deletions


vllm/entrypoints/context.py

Lines changed: 79 additions & 8 deletions
```diff
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
+import contextlib
 import json
 import logging
 from abc import ABC, abstractmethod
@@ -57,9 +59,14 @@ def render_for_completion(self) -> list[int]:
 
     @abstractmethod
     async def init_tool_sessions(self, tool_server: Optional[ToolServer],
-                                 exit_stack: AsyncExitStack) -> None:
+                                 exit_stack: AsyncExitStack,
+                                 request_id: str) -> None:
         pass
 
+    @abstractmethod
+    async def cleanup_session(self) -> None:
+        raise NotImplementedError("Should not be called.")
+
 
 class SimpleContext(ConversationContext):
 
@@ -89,9 +96,13 @@ def render_for_completion(self) -> list[int]:
         raise NotImplementedError("Should not be called.")
 
     async def init_tool_sessions(self, tool_server: Optional[ToolServer],
-                                 exit_stack: AsyncExitStack) -> None:
+                                 exit_stack: AsyncExitStack,
+                                 request_id: str) -> None:
         pass
 
+    async def cleanup_session(self) -> None:
+        raise NotImplementedError("Should not be called.")
+
 
 class HarmonyContext(ConversationContext):
 
@@ -103,6 +114,7 @@ def __init__(
         self._messages = messages
         self.available_tools = available_tools
         self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
+        self.called_tools: set[str] = set()
 
         self.parser = get_streamable_parser_for_assistant()
         self.num_init_messages = len(messages)
@@ -234,7 +246,8 @@ def need_builtin_tool_call(self) -> bool:
        last_msg = self.messages[-1]
        recipient = last_msg.recipient
        return recipient is not None and (recipient.startswith("browser.")
-                                          or recipient.startswith("python"))
+                                          or recipient.startswith("python") or
+                                          recipient.startswith("container."))
 
     async def call_tool(self) -> list[Message]:
         if not self.messages:
@@ -248,6 +261,9 @@ async def call_tool(self) -> list[Message]:
         elif recipient.startswith("python"):
             return await self.call_python_tool(
                 self._tool_sessions["python"], last_msg)
+        elif recipient.startswith("container."):
+            return await self.call_container_tool(
+                self._tool_sessions["container"], last_msg)
         raise ValueError("No tool call found")
 
     def render_for_completion(self) -> list[int]:
@@ -256,6 +272,7 @@ def render_for_completion(self) -> list[int]:
     async def call_search_tool(self, tool_session: Union["ClientSession",
                                                          Tool],
                                last_msg: Message) -> list[Message]:
+        self.called_tools.add("browser")
         if isinstance(tool_session, Tool):
             return await tool_session.get_result(self)
         tool_name = last_msg.recipient.split(".")[1]
@@ -265,12 +282,16 @@ async def call_search_tool(self, tool_session: Union["ClientSession",
         content = TextContent(text=result_str)
         author = Author(role=Role.TOOL, name=last_msg.recipient)
         return [
-            Message(author=author, content=[content], recipient=Role.ASSISTANT)
+            Message(author=author,
+                    content=[content],
+                    recipient=Role.ASSISTANT,
+                    channel=last_msg.channel)
         ]
 
     async def call_python_tool(self, tool_session: Union["ClientSession",
                                                          Tool],
                                last_msg: Message) -> list[Message]:
+        self.called_tools.add("python")
         if isinstance(tool_session, Tool):
             return await tool_session.get_result(self)
         param = {
@@ -290,13 +311,63 @@ async def call_python_tool(self, tool_session: Union["ClientSession",
         ]
 
     async def init_tool_sessions(self, tool_server: Optional[ToolServer],
-                                 exit_stack: AsyncExitStack) -> None:
+                                 exit_stack: AsyncExitStack,
+                                 request_id: str) -> None:
         if tool_server:
             for tool_name in self.available_tools:
                 if tool_name not in self._tool_sessions:
-                    self._tool_sessions[
-                        tool_name] = await exit_stack.enter_async_context(
-                            tool_server.new_session(tool_name))
+                    tool_session = await exit_stack.enter_async_context(
+                        tool_server.new_session(tool_name, request_id))
+                    self._tool_sessions[tool_name] = tool_session
+            exit_stack.push_async_exit(self.cleanup_session)
+
+    async def call_container_tool(self, tool_session: Union["ClientSession",
+                                                            Tool],
+                                  last_msg: Message) -> list[Message]:
+        """
+        Call container tool. Expect this to be run in a stateful docker
+        with command line terminal.
+        The official container tool would at least
+        expect the following format:
+        - for tool name: exec
+        - args:
+            {
+                "cmd":List[str] "command to execute",
+                "workdir":optional[str] "current working directory",
+                "env":optional[object/dict] "environment variables",
+                "session_name":optional[str] "session name",
+                "timeout":optional[int] "timeout in seconds",
+                "user":optional[str] "user name",
+            }
+        """
+        self.called_tools.add("container")
+        if isinstance(tool_session, Tool):
+            return await tool_session.get_result(self)
+        tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
+        args = json.loads(last_msg.content[0].text)
+        result = await tool_session.call_tool(tool_name, args)
+        result_str = result.content[0].text
+        content = TextContent(text=result_str)
+        author = Author(role=Role.TOOL, name=last_msg.recipient)
+        return [
+            Message(author=author,
+                    content=[content],
+                    recipient=Role.ASSISTANT,
+                    channel=last_msg.channel)
+        ]
+
+    async def cleanup_session(self, *args, **kwargs) -> None:
+        """Can be used as coro to used in __aexit__"""
+
+        async def cleanup_tool_session(tool_session):
+            if not isinstance(tool_session, Tool):
+                logger.info("Cleaning up tool session for %s",
+                            tool_session._client_info)
+                with contextlib.suppress(Exception):
+                    await tool_session.call_tool("cleanup_session", {})
+
+        await asyncio.gather(*(cleanup_tool_session(self._tool_sessions[tool])
+                               for tool in self.called_tools))
 
 
 class StreamingHarmonyContext(HarmonyContext):
```

vllm/entrypoints/harmony_utils.py

Lines changed: 44 additions & 6 deletions
```diff
@@ -16,11 +16,13 @@
 from openai.types.responses.response_reasoning_item import (
     Content as ResponseReasoningTextContent)
 from openai.types.responses.tool import Tool
-from openai_harmony import (Author, Conversation, DeveloperContent,
-                            HarmonyEncodingName, Message, ReasoningEffort,
-                            Role, StreamableParser, SystemContent, TextContent,
-                            ToolDescription, load_harmony_encoding)
+from openai_harmony import (Author, ChannelConfig, Conversation,
+                            DeveloperContent, HarmonyEncodingName, Message,
+                            ReasoningEffort, Role, StreamableParser,
+                            SystemContent, TextContent, ToolDescription,
+                            load_harmony_encoding)
 
+from vllm import envs
 from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam,
                                               ResponseInputOutputItem)
 from vllm.utils import random_uuid
@@ -33,6 +35,20 @@
 
 _harmony_encoding = None
 
+# Builtin tools that should be included in the system message when
+# they are available and requested by the user.
+# Tool args are provided by MCP tool descriptions. Output
+# of the tools are stringified.
+BUILTIN_TOOLS = {
+    "web_search_preview",
+    "code_interpreter",
+    "container",
+}
+
+
+def has_custom_tools(tool_types: list[str]) -> bool:
+    return not set(tool_types).issubset(BUILTIN_TOOLS)
+
 
 def get_encoding():
     global _harmony_encoding
@@ -48,10 +64,19 @@ def get_system_message(
     start_date: Optional[str] = None,
     browser_description: Optional[str] = None,
     python_description: Optional[str] = None,
+    container_description: Optional[str] = None,
+    instructions: Optional[str] = None,
+    with_custom_tools: bool = False,
 ) -> Message:
     sys_msg_content = SystemContent.new()
     if model_identity is not None:
         sys_msg_content = sys_msg_content.with_model_identity(model_identity)
+    if (instructions is not None
+            and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS):
+        current_identity = sys_msg_content.model_identity
+        new_identity = (f'{current_identity}\n{instructions}'
+                        if current_identity else instructions)
+        sys_msg_content = sys_msg_content.with_model_identity(new_identity)
     if reasoning_effort is not None:
         sys_msg_content = sys_msg_content.with_reasoning_effort(
             REASONING_EFFORT[reasoning_effort])
@@ -63,6 +88,14 @@
         sys_msg_content = sys_msg_content.with_tools(browser_description)
     if python_description is not None:
         sys_msg_content = sys_msg_content.with_tools(python_description)
+    if container_description is not None:
+        sys_msg_content = sys_msg_content.with_tools(container_description)
+    if not with_custom_tools:
+        channel_config = sys_msg_content.channel_config
+        invalid_channel = "commentary"
+        new_config = ChannelConfig.require_channels(
+            [c for c in channel_config.valid_channels if c != invalid_channel])
+        sys_msg_content = sys_msg_content.with_channel_config(new_config)
     sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
     return sys_msg
 
@@ -86,14 +119,17 @@ def get_developer_message(
     tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
 ) -> Message:
     dev_msg_content = DeveloperContent.new()
-    if instructions is not None:
+    if (instructions is not None
+            and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS):
         dev_msg_content = dev_msg_content.with_instructions(instructions)
     if tools is not None:
         function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
         for tool in tools:
-            if tool.type in ("web_search_preview", "code_interpreter"):
+            if tool.type in ("web_search_preview", "code_interpreter",
+                             "container"):
                 # These are built-in tools that are added to the system message.
                 pass
+
             elif tool.type == "function":
                 function_tools.append(tool)
             else:
@@ -136,6 +172,8 @@ def parse_response_input(
             TextContent(text=text_prefix + c["text"]) for c in content
         ]
         msg = Message.from_role_and_contents(role, contents)
+        if role == "assistant":
+            msg = msg.with_channel("final")
     elif response_msg["type"] == "function_call_output":
         call_id = response_msg["call_id"]
         call_response: Optional[ResponseFunctionToolCall] = None
```
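The new `has_custom_tools` helper is a pure set check over the requested tool types, and it is what later decides whether the `commentary` channel stays in the system message's channel config. A quick sketch of its behaviour, copied from the hunk above so it runs on its own:

```python
# Behaviour of has_custom_tools: only non-builtin (e.g. "function") tools
# count as custom. When no custom tools are present, get_system_message drops
# the "commentary" channel from the Harmony channel config.
BUILTIN_TOOLS = {"web_search_preview", "code_interpreter", "container"}


def has_custom_tools(tool_types: list[str]) -> bool:
    return not set(tool_types).issubset(BUILTIN_TOOLS)


assert has_custom_tools([]) is False                        # no tools requested
assert has_custom_tools(["code_interpreter"]) is False      # builtins only
assert has_custom_tools(["container", "function"]) is True  # custom tool present
```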

vllm/entrypoints/openai/serving_responses.py

Lines changed: 26 additions & 7 deletions
```diff
@@ -44,8 +44,9 @@
     SimpleContext, StreamingHarmonyContext)
 from vllm.entrypoints.harmony_utils import (
     get_developer_message, get_stop_tokens_for_assistant_actions,
-    get_system_message, get_user_message, parse_output_message,
-    parse_remaining_state, parse_response_input, render_for_completion)
+    get_system_message, get_user_message, has_custom_tools,
+    parse_output_message, parse_remaining_state, parse_response_input,
+    render_for_completion)
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -266,6 +267,8 @@ async def create_responses(
                    builtin_tool_list.append("browser")
                if self.tool_server.has_tool("python"):
                    builtin_tool_list.append("python")
+               if self.tool_server.has_tool("container"):
+                   builtin_tool_list.append("container")
 
             if self.tool_server is not None:
                 available_tools = builtin_tool_list
@@ -448,7 +451,8 @@ async def responses_full_generator(
 
         async with AsyncExitStack() as exit_stack:
             try:
-                await context.init_tool_sessions(self.tool_server, exit_stack)
+                await context.init_tool_sessions(self.tool_server, exit_stack,
+                                                 request.request_id)
                 async for _ in result_generator:
                     pass
             except asyncio.CancelledError:
@@ -710,13 +714,21 @@ def _construct_input_messages_with_harmony(
             # New conversation.
             reasoning_effort = (request.reasoning.effort
                                 if request.reasoning else None)
+            # Temporary: OpenAI types doesn't have container tool
+            # so we used MCP to cover that, up for change
             tool_types = [tool.type for tool in request.tools]
+            if envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL:
+                tool_types.append("container")
             enable_browser = ("web_search_preview" in tool_types
                               and self.tool_server is not None
                               and self.tool_server.has_tool("browser"))
             enable_code_interpreter = ("code_interpreter" in tool_types
                                        and self.tool_server is not None
                                        and self.tool_server.has_tool("python"))
+            enable_container = ("container" in tool_types
+                                and self.tool_server is not None
+                                and self.tool_server.has_tool("container"))
+            with_custom_tools = has_custom_tools(tool_types)
             sys_msg = get_system_message(
                 reasoning_effort=reasoning_effort,
                 browser_description=self.tool_server.get_tool_description(
@@ -725,11 +737,17 @@
                 python_description=self.tool_server.get_tool_description(
                     "python") if enable_code_interpreter
                 and self.tool_server is not None else None,
+                container_description=self.tool_server.get_tool_description(
+                    "container")
+                if enable_container and self.tool_server is not None else None,
+                instructions=request.instructions,
+                with_custom_tools=with_custom_tools,
             )
             messages.append(sys_msg)
-            dev_msg = get_developer_message(request.instructions,
-                                            request.tools)
-            messages.append(dev_msg)
+            if with_custom_tools:
+                dev_msg = get_developer_message(
+                    instructions=request.instructions, tools=request.tools)
+                messages.append(dev_msg)
         else:
             # Continue the previous conversation.
             # FIXME(woosuk): Currently, request params like reasoning and
@@ -1613,7 +1631,8 @@ def _send_event(event: BaseModel):
         async with AsyncExitStack() as exit_stack:
             processer = None
             if self.use_harmony:
-                await context.init_tool_sessions(self.tool_server, exit_stack)
+                await context.init_tool_sessions(self.tool_server, exit_stack,
+                                                 request.request_id)
                 processer = self._process_harmony_streaming_events
             else:
                 processer = self._process_simple_streaming_events
```
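Putting the serving-side changes together: the request path now optionally injects a `container` tool type via the `VLLM_GPT_OSS_USE_CONTAINER_TOOL` switch, passes a container description only when the tool server actually exposes the tool, and emits a developer message only when custom function tools exist. A hedged sketch of that branching, with `FakeToolServer` as a hypothetical stand-in for vLLM's `ToolServer` (only `has_tool` and `get_tool_description` are mimicked; the real logic lives in `_construct_input_messages_with_harmony`):

```python
# Sketch of the decision logic in the hunks above; names here are illustrative,
# not the vLLM implementation.
BUILTIN_TOOLS = {"web_search_preview", "code_interpreter", "container"}


class FakeToolServer:
    def __init__(self, tools: set[str]):
        self._tools = tools

    def has_tool(self, name: str) -> bool:
        return name in self._tools

    def get_tool_description(self, name: str) -> str:
        return f"<description of {name}>"


def plan_messages(tool_types: list[str], tool_server: FakeToolServer,
                  use_container_env: bool = False):
    if use_container_env:  # stands in for envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL
        tool_types = tool_types + ["container"]
    enable_container = ("container" in tool_types
                        and tool_server.has_tool("container"))
    with_custom_tools = not set(tool_types).issubset(BUILTIN_TOOLS)
    container_description = (tool_server.get_tool_description("container")
                             if enable_container else None)
    # A developer message is only emitted when with_custom_tools is True;
    # otherwise the instructions travel in the system message instead.
    return container_description, with_custom_tools


server = FakeToolServer({"browser", "python", "container"})
print(plan_messages(["code_interpreter"], server, use_container_env=True))
# -> ('<description of container>', False): container enabled, no dev message
```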

vllm/entrypoints/tool_server.py

Lines changed: 10 additions & 6 deletions
```diff
@@ -86,7 +86,8 @@ def get_tool_description(self,
         pass
 
     @abstractmethod
-    def new_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]:
+    def new_session(self, tool_name: str,
+                    session_id: str) -> AbstractAsyncContextManager[Any]:
         """
         Create a session for the tool.
         """
@@ -124,7 +125,8 @@ async def add_tool_server(self, server_url: str):
                     description=tool.description,
                     parameters=tool.inputSchema)
                 for tool in list_tools_response.tools
-            ])
+            ],
+            )
             self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
             if tool_from_mcp.name not in self.urls:
                 self.urls[tool_from_mcp.name] = url
@@ -142,14 +144,16 @@ def get_tool_description(self, tool_name: str):
         return self.harmony_tool_descriptions.get(tool_name)
 
     @asynccontextmanager
-    async def new_session(self, tool_name: str):
+    async def new_session(self, tool_name: str, session_id: str):
         from mcp import ClientSession
         from mcp.client.sse import sse_client
         url = self.urls.get(tool_name)
+        headers = {"x-session-id": session_id}
         if not url:
             raise KeyError(f"Tool '{tool_name}' is not supported")
-        async with sse_client(url=url) as streams, ClientSession(
-                *streams) as session:
+        async with sse_client(url=url,
+                              headers=headers) as streams, ClientSession(
+                                  *streams) as session:
             await session.initialize()
             yield session
 
@@ -182,7 +186,7 @@ def get_tool_description(self,
             raise ValueError(f"Unknown tool {tool_name}")
 
     @asynccontextmanager
-    async def new_session(self, tool_name: str):
+    async def new_session(self, tool_name: str, session_id: str):
         if tool_name not in self.tools:
             raise KeyError(f"Tool '{tool_name}' is not supported")
         yield self.tools[tool_name]
```
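Finally, the request id that `init_tool_sessions` now receives is threaded into `new_session(tool_name, session_id)`; for the MCP-backed server it becomes an `x-session-id` header on the SSE connection, letting the tool backend keep per-request state (for example a dedicated container). Below is a minimal sketch of that lifecycle under an `AsyncExitStack`, using a hypothetical `ToyToolServer` that follows the same `new_session` contract but stays in-process so it runs without an MCP endpoint:

```python
# Per-request tool session lifecycle, sketched with a hypothetical in-process
# ToyToolServer; the MCP-backed server instead opens an SSE connection and
# forwards the id as an "x-session-id" header.
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


class ToyToolServer:

    @asynccontextmanager
    async def new_session(self, tool_name: str, session_id: str):
        print(f"open {tool_name} session for request {session_id}")
        try:
            yield {"tool": tool_name, "session": session_id}
        finally:
            print(f"close {tool_name} session for request {session_id}")


async def main():
    server = ToyToolServer()
    async with AsyncExitStack() as exit_stack:
        session = await exit_stack.enter_async_context(
            server.new_session("container", "req-123"))
        print("using", session)
        # HarmonyContext.init_tool_sessions also registers cleanup_session on
        # the same stack, so every tool that was actually called receives a
        # "cleanup_session" tool call before its session is closed.


asyncio.run(main())
```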
