Skip to content

Commit

Permalink
Added get_completion_parse method
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Aug 13, 2024
1 parent c24be1c commit 7f2e891
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 25 deletions.
78 changes: 67 additions & 11 deletions agency_swarm/agency/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,37 @@
import os
import queue
import threading
import time
import uuid
from enum import Enum
from typing import List, TypedDict, Callable, Any, Dict, Literal, Union, Optional
from typing import Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, TypedDict, Union

from openai.lib._parsing._completions import type_to_response_format_param
from openai.types.beta.threads import Message
from openai.types.beta.threads.runs import RunStep
from pydantic import Field, field_validator, model_validator
from openai.types.beta.threads.runs.tool_call import (
CodeInterpreterToolCall,
FileSearchToolCall,
FunctionToolCall,
ToolCall,
)
from pydantic import BaseModel, Field, field_validator, model_validator
from rich.console import Console
from typing_extensions import override

from agency_swarm.agents import Agent
from agency_swarm.messages import MessageOutput
from agency_swarm.messages.message_output import MessageOutputLive
from agency_swarm.threads import Thread
from agency_swarm.tools import BaseTool, FileSearch, CodeInterpreter
from agency_swarm.tools import BaseTool, CodeInterpreter, FileSearch
from agency_swarm.user import User
from agency_swarm.util.errors import RefusalError
from agency_swarm.util.files import determine_file_type
from agency_swarm.util.shared_state import SharedState
from openai.types.beta.threads.runs.tool_call import ToolCall, FunctionToolCall, CodeInterpreterToolCall, FileSearchToolCall


from agency_swarm.util.streaming import AgencyEventHandler

console = Console()

T = TypeVar('T', bound=BaseModel)

class SettingsCallbacks(TypedDict):
load: Callable[[], List[Dict]]
Expand Down Expand Up @@ -127,7 +132,8 @@ def get_completion(self, message: str,
additional_instructions: str = None,
attachments: List[dict] = None,
tool_choice: dict = None,
verbose: bool = False):
verbose: bool = False,
response_format: dict = None):
"""
Retrieves the completion for a given message from the main thread.
Expand All @@ -141,6 +147,7 @@ def get_completion(self, message: str,
tool_choice (dict, optional): The tool choice for the recipient agent to use. Defaults to None.
parallel_tool_calls (bool, optional): Whether to enable parallel function calling during tool use. Defaults to True.
verbose (bool, optional): Whether to print the intermediary messages in console. Defaults to False.
response_format (dict, optional): The response format to use for the completion.
Returns:
Generator or final response: Depending on the 'yield_messages' flag, this method returns either a generator yielding intermediate messages or the final response from the main thread.
Expand All @@ -154,7 +161,9 @@ def get_completion(self, message: str,
recipient_agent=recipient_agent,
additional_instructions=additional_instructions,
tool_choice=tool_choice,
yield_messages=yield_messages or verbose)
yield_messages=yield_messages or verbose,
response_format=response_format)

if not yield_messages or verbose:
while True:
try:
Expand All @@ -174,7 +183,8 @@ def get_completion_stream(self,
recipient_agent: Agent = None,
additional_instructions: str = None,
attachments: List[dict] = None,
tool_choice: dict = None):
tool_choice: dict = None,
response_format: dict = None):
"""
Generates a stream of completions for a given message from the main thread.
Expand All @@ -200,14 +210,60 @@ def get_completion_stream(self,
attachments=attachments,
recipient_agent=recipient_agent,
additional_instructions=additional_instructions,
tool_choice=tool_choice)
tool_choice=tool_choice,
response_format=response_format)

while True:
try:
next(res)
except StopIteration as e:
event_handler.on_all_streams_end()

return e.value

def get_completion_parse(self, message: str,
                         response_format: Type[T],
                         message_files: List[str] = None,
                         recipient_agent: Agent = None,
                         additional_instructions: str = None,
                         attachments: List[dict] = None,
                         tool_choice: dict = None) -> T:
    """
    Retrieves the completion for a given message from the main thread and parses the response using the provided response format.

    Parameters:
        message (str): The message for which completion is to be retrieved.
        response_format (Type[T]): A pydantic BaseModel subclass describing the expected response schema.
        message_files (list, optional): A list of file ids to be sent as attachments with the message. When using this parameter, files will be assigned both to file_search and code_interpreter tools if available. It is recommended to assign files to the most suitable tool manually, using the attachments parameter. Defaults to None.
        recipient_agent (Agent, optional): The agent to which the message should be sent. Defaults to the first agent in the agency chart.
        additional_instructions (str, optional): Additional instructions to be sent with the message. Defaults to None.
        attachments (List[dict], optional): A list of attachments to be sent with the message, following openai format. Defaults to None.
        tool_choice (dict, optional): The tool choice for the recipient agent to use. Defaults to None.

    Returns:
        T: The final response from the main thread, validated against response_format.

    Raises:
        TypeError: If response_format is not a pydantic model class.
        RefusalError: If the model returned a refusal instead of the requested output.
        Exception: If the response could not be parsed into response_format.
    """
    # Fail fast with a clear message instead of an AttributeError later
    # when model_validate_json is called on None.
    if not isinstance(response_format, type):
        raise TypeError("response_format must be a pydantic BaseModel subclass")

    response_model = response_format
    # Convert the pydantic model into OpenAI's response_format dict.
    response_format = type_to_response_format_param(response_format)

    res = self.get_completion(message=message,
                              message_files=message_files,
                              recipient_agent=recipient_agent,
                              additional_instructions=additional_instructions,
                              attachments=attachments,
                              tool_choice=tool_choice,
                              response_format=response_format)

    try:
        return response_model.model_validate_json(res)
    except Exception as e:
        # Validation failed: the raw text may be a refusal payload
        # rather than the requested schema.
        try:
            parsed_res = json.loads(res)
        except json.JSONDecodeError:
            raise Exception("Failed to parse response: " + res) from e
        if 'refusal' in parsed_res:
            raise RefusalError(parsed_res['refusal'])
        raise Exception("Failed to parse response: " + res) from e

def demo_gradio(self, height=450, dark_mode=True, **kwargs):
"""
Expand Down
3 changes: 1 addition & 2 deletions agency_swarm/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from agency_swarm.util.openapi import validate_openapi_spec
from agency_swarm.util.shared_state import SharedState
from pydantic import BaseModel
from openai import pydantic_function_tool
from openai.lib._parsing._completions import type_to_response_format_param

class ExampleMessage(TypedDict):
Expand Down Expand Up @@ -106,7 +105,7 @@ def __init__(
tool_resources (ToolResources, optional): A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the code_interpreter tool requires a list of file IDs, while the file_search tool requires a list of vector store IDs. Defaults to None.
temperature (float, optional): The temperature parameter for the OpenAI API. Defaults to None.
top_p (float, optional): The top_p parameter for the OpenAI API. Defaults to None.
response_format (Dict, optional): The response format for the OpenAI API. Defaults to None.
response_format (Union[str, Dict, type], optional): The response format for the OpenAI API. If BaseModel is provided, it will be converted to a response format. Defaults to None.
tools_folder (str, optional): Path to a directory containing tools associated with the agent. Each tool must be defined in a separate file. File must be named as the class name of the tool. Defaults to None.
files_folder (Union[List[str], str], optional): Path or list of paths to directories containing files associated with the agent. Defaults to None.
schemas_folder (Union[List[str], str], optional): Path or list of paths to directories containing OpenAPI schemas associated with the agent. Defaults to None.
Expand Down
29 changes: 17 additions & 12 deletions agency_swarm/threads/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import os
import time
from typing import List, Optional, Union
from typing import List, Optional, Type, Union

from openai import BadRequestError
from openai.types.beta import AssistantToolChoice
Expand Down Expand Up @@ -59,9 +59,10 @@ def get_completion_stream(self,
event_handler: type(AgencyEventHandler),
message_files: List[str] = None,
attachments: Optional[List[Attachment]] = None,
recipient_agent=None,
recipient_agent:Agent=None,
additional_instructions: str = None,
tool_choice: AssistantToolChoice = None):
tool_choice: AssistantToolChoice = None,
response_format: Optional[dict] = None):

return self.get_completion(message,
message_files,
Expand All @@ -70,17 +71,19 @@ def get_completion_stream(self,
additional_instructions,
event_handler,
tool_choice,
yield_messages=False)
yield_messages=False,
response_format=response_format)

def get_completion(self,
message: str | List[dict],
message_files: List[str] = None,
attachments: Optional[List[dict]] = None,
recipient_agent=None,
recipient_agent: Agent = None,
additional_instructions: str = None,
event_handler: type(AgencyEventHandler) = None,
tool_choice: AssistantToolChoice = None,
yield_messages: bool = False
yield_messages: bool = False,
response_format: Optional[dict] = None
):
if not recipient_agent:
recipient_agent = self.recipient_agent
Expand Down Expand Up @@ -121,7 +124,7 @@ def get_completion(self,
if yield_messages:
yield MessageOutput("text", self.agent.name, recipient_agent.name, message, message_obj)

self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice)
self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format)

error_attempts = 0
validation_attempts = 0
Expand Down Expand Up @@ -235,14 +238,14 @@ def handle_output(tool_call, output):
# retry run 2 times
if error_attempts < 1 and "something went wrong" in self.run.last_error.message.lower():
time.sleep(1)
self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice)
self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format)
error_attempts += 1
elif 1 <= error_attempts < 5 and "something went wrong" in self.run.last_error.message.lower():
self.create_message(
message="Continue.",
role="user"
)
self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice)
self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format)
error_attempts += 1
else:
raise Exception("OpenAI Run Failed. Error: ", self.run.last_error.message)
Expand Down Expand Up @@ -292,13 +295,13 @@ def handle_output(tool_call, output):

validation_attempts += 1

self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice)
self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format)

continue

return last_message

def _create_run(self, recipient_agent, additional_instructions, event_handler, tool_choice, temperature=None):
def _create_run(self, recipient_agent, additional_instructions, event_handler, tool_choice, temperature=None, response_format: Optional[dict] = None):
if event_handler:
with self.client.beta.threads.runs.stream(
thread_id=self.thread.id,
Expand All @@ -311,6 +314,7 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t
truncation_strategy=recipient_agent.truncation_strategy,
temperature=temperature,
extra_body={"parallel_tool_calls": recipient_agent.parallel_tool_calls},
response_format=response_format
) as stream:
stream.until_done()
self.run = stream.get_final_run()
Expand All @@ -324,7 +328,8 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t
max_completion_tokens=recipient_agent.max_completion_tokens,
truncation_strategy=recipient_agent.truncation_strategy,
temperature=temperature,
parallel_tool_calls=recipient_agent.parallel_tool_calls
parallel_tool_calls=recipient_agent.parallel_tool_calls,
response_format=response_format
)
self.run = self.client.beta.threads.runs.poll(
thread_id=self.thread.id,
Expand Down
2 changes: 2 additions & 0 deletions agency_swarm/util/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class RefusalError(Exception):
    """Raised when a model response contains a 'refusal' payload instead of the requested structured output."""
    pass
5 changes: 5 additions & 0 deletions tests/test_agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,11 @@ class Step(BaseModel):
# check if result is a MathReasoning object
self.assertTrue(MathReasoning.model_validate_json(result))

result = agency.get_completion_parse("how can I solve 3x + 2 = 14", response_format=MathReasoning)

# check if result is a MathReasoning object
self.assertTrue(isinstance(result, MathReasoning))

# --- Helper methods ---

def get_class_folder_path(self):
Expand Down

0 comments on commit 7f2e891

Please sign in to comment.