feat(agent): support multiple tool groups #1556
Merged
Changes from all commits (2 commits)
@@ -614,117 +614,132 @@ async def _run(
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
input_messages = input_messages + [message]
else:
logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}")
# 1. Start the tool execution step and progress
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
input_messages = input_messages + [message]

# Process tool calls in the message
client_tool_calls = []
non_client_tool_calls = []

# Separate client and non-client tool calls
for tool_call in message.tool_calls:
if tool_call.tool_name in client_tools:
client_tool_calls.append(tool_call)
else:
non_client_tool_calls.append(tool_call)

# Process non-client tool calls first
for tool_call in non_client_tool_calls:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)
)
tool_call = message.tool_calls[0]
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call=tool_call,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
tool_call=tool_call,
),

yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
tool_call=tool_call,
),
)
)
)
)

# If tool is a client tool, yield CompletionMessage and return
if tool_call.tool_name in client_tools:
# NOTE: mark end_of_message to indicate to client that it may
# call the tool and continue the conversation with the tool's response.
message.stop_reason = StopReason.end_of_message
# Execute the tool call
async with tracing.span(
"tool_execution",
{
"tool_name": tool_call.tool_name,
"input": message.model_dump_json(),
},
) as span:
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
tool_result = await self.execute_tool_call_maybe(
Review comment on execute_tool_call_maybe: can this fail or throw?
Reply: Yes, maybe we can have a catch-all and return some generic error message. (A hedged sketch of such a catch-all appears after this hunk.)
session_id,
tool_call,
)
if tool_result.content is None:
raise ValueError(
f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
)
result_message = ToolResponseMessage(
call_id=tool_call.call_id,
content=tool_result.content,
)
span.set_attribute("output", result_message.model_dump_json())

# Store tool execution step
tool_execution_step = ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[
ToolResponse(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=tool_result.content,
metadata=tool_result.metadata,
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now(timezone.utc).isoformat(),
)

# Yield the step completion event
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=tool_execution_step,
)
)
)

# Add the result message to input_messages for the next iteration
input_messages.append(result_message)

# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(result_message.content)
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
output_attachments.append(out_attachment)

# If there are client tool calls, yield a message with only those tool calls
if client_tool_calls:
await self.storage.set_in_progress_tool_call_step(
session_id,
turn_id,
ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_calls=client_tool_calls,
tool_responses=[],
started_at=datetime.now(timezone.utc).isoformat(),
),
)
yield message
return

# If tool is a builtin server tool, execute it
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
async with tracing.span(
"tool_execution",
{
"tool_name": tool_name,
"input": message.model_dump_json(),
},
) as span:
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
tool_call = message.tool_calls[0]
tool_result = await self.execute_tool_call_maybe(
session_id,
tool_call,
)
if tool_result.content is None:
raise ValueError(
f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
)
result_messages = [
ToolResponseMessage(
call_id=tool_call.call_id,
content=tool_result.content,
)
]
assert len(result_messages) == 1, "Currently not supporting multiple messages"
result_message = result_messages[0]
span.set_attribute("output", result_message.model_dump_json())

yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[
ToolResponse(
call_id=result_message.call_id,
tool_name=tool_call.tool_name,
content=result_message.content,
metadata=tool_result.metadata,
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now(timezone.utc).isoformat(),
),
)
)
)

# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(result_message.content)
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
output_attachments.append(out_attachment)
# Create a copy of the message with only client tool calls
client_message = message.model_copy(deep=True)
client_message.tool_calls = client_tool_calls
# NOTE: mark end_of_message to indicate to client that it may
# call the tool and continue the conversation with the tool's response.
client_message.stop_reason = StopReason.end_of_message

input_messages = input_messages + [message, result_message]
# Yield the message with client tool calls
yield client_message
return

async def _initialize_tools(
self,
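Following up on the review thread above about execute_tool_call_maybe failing: a minimal sketch of the suggested catch-all, assuming the call returns an object exposing a content field (the diff reads tool_result.content and tool_result.metadata). ToolInvocationResult is used here only as an illustrative stand-in for whatever result type the method actually returns; this is not part of the PR.

# Sketch only: catch-all around the server-side tool call, per the review
# suggestion, so a failing tool produces a generic error result instead of
# aborting the turn. ToolInvocationResult is an assumed stand-in type.
try:
    tool_result = await self.execute_tool_call_maybe(
        session_id,
        tool_call,
    )
except Exception as e:
    tool_result = ToolInvocationResult(
        content=f"Tool call '{tool_call.tool_name}' failed with error: {e}",
    )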
@@ -891,16 +906,14 @@ async def handle_documents(
if memory_tool and code_interpreter_tool:
# if both memory and code_interpreter are available, we download the URLs
# and attach the data to the last message.
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
await attachment_message(self.tempdir, url_items, input_messages[-1])
# Since memory is present, add all the data to the memory bank
await self.add_to_session_vector_db(session_id, documents)
elif code_interpreter_tool:
# if only code_interpreter is available, we download the URLs to a tempdir
# and attach the path to them as a message to inference with the
# assumption that the model invokes the code_interpreter tool with the path
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
await attachment_message(self.tempdir, url_items, input_messages[-1])
elif memory_tool:
# if only memory is available, we load the data from the URLs and content items to the memory bank
await self.add_to_session_vector_db(session_id, documents)

@@ -967,8 +980,8 @@ async def load_data_from_urls(urls: List[URL]) -> List[str]:
return data

async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
content = []
async def attachment_message(tempdir: str, urls: List[URL], message: UserMessage) -> None:
contents = []

for url in urls:
uri = url.uri

@@ -988,16 +1001,19 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
else:
raise ValueError(f"Unsupported URL {url}")

content.append(
contents.append(
TextContentItem(
text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
)
)

return ToolResponseMessage(
call_id="",
content=content,
)
if isinstance(message.content, list):
message.content.extend(contents)
else:
if isinstance(message.content, str):
message.content = [TextContentItem(text=message.content)] + contents
else:
message.content = [message.content] + contents

def _interpret_content_as_attachment(
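For reference, a self-contained sketch of the in-place content merge that the reworked attachment_message now performs, using minimal stand-ins for the message and content types rather than the project's actual UserMessage and TextContentItem classes; the file path in the usage example is hypothetical.

from dataclasses import dataclass
from typing import List, Union

@dataclass
class TextItem:  # stand-in for TextContentItem
    text: str

@dataclass
class Msg:  # stand-in for UserMessage
    content: Union[str, TextItem, List]

def merge_into_message(message: Msg, contents: List[TextItem]) -> None:
    # Mirrors the diff's logic: extend list content, or normalize str /
    # single-item content into a list before appending the new items.
    if isinstance(message.content, list):
        message.content.extend(contents)
    elif isinstance(message.content, str):
        message.content = [TextItem(text=message.content)] + contents
    else:
        message.content = [message.content] + contents

# Example: a plain-string user message gains the attachment note in place.
m = Msg(content="Summarize the attached file.")
merge_into_message(m, [TextItem(text='# User provided a file accessible to you at "/tmp/data.csv"')])
assert isinstance(m.content, list) and len(m.content) == 2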
Review comment: all of this code inline in this method is becoming quite hard to scan through. I think we need another file where we write some simple functions (keeping all of these inside the agent_instance object is also a FAIL mode that I initiated long ago, back with our first agent impl lol). A rough sketch of such an extraction follows.
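As a rough illustration of the extraction being suggested, the client/server split from this diff could live as a small free function in a separate module; the module and function names below are hypothetical and not part of this PR.

# e.g. in a new helper module such as tool_execution_utils.py (hypothetical name)
def split_tool_calls(tool_calls, client_tools):
    """Separate a message's tool calls into client-handled and server-handled lists,
    mirroring the inline loop in _run above."""
    client_tool_calls = []
    non_client_tool_calls = []
    for tool_call in tool_calls:
        if tool_call.tool_name in client_tools:
            client_tool_calls.append(tool_call)
        else:
            non_client_tool_calls.append(tool_call)
    return client_tool_calls, non_client_tool_calls

# Usage inside _run would then shrink to:
# client_tool_calls, non_client_tool_calls = split_tool_calls(message.tool_calls, client_tools)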