Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add converter to support outbound calls to OpenAI servers that want tool not function #619

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions memgpt/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,11 +423,14 @@ def verify_first_message_correctness(self, response, require_send_message=True,
response_message = response.choices[0].message

# First message should be a call to send_message with a non-empty content
if require_send_message and not response_message.get("function_call"):
if require_send_message and not (response_message.get("function_call") or response_message.get("tool_calls")):
printd(f"First message didn't include function call: {response_message}")
return False

function_call = response_message.get("function_call")
if function_call is None and "tool_calls" in response_message and len(response_message["tool_calls"]) >= 1:
# Support for responses that have "tool_calls" instead of "function"
function_call = response_message["tool_calls"][0].get("function")
function_name = function_call.get("name") if function_call is not None else ""
if require_send_message and function_name != "send_message" and function_name != "archival_memory_search":
printd(f"First message function call wasn't send_message or archival_memory_search: {response_message}")
Expand Down Expand Up @@ -463,7 +466,14 @@ def handle_ai_response(self, response_message):
messages = [] # append these to the history when done

# Step 2: check if LLM wanted to call a function
if response_message.get("function_call"):
if response_message.get("function_call") or response_message.get("tool_calls"):
# TODO handle parallel function calling / move internal representations to tool_calls from function_call
if "tool_calls" in response_message:
# Hack to go backwards from tool_calls to function_call representation
if len(response_message["tool_calls"]) > 1:
print(f"{CLI_WARNING_PREFIX}multiple tool calls not supported (got {len(response_message['tool_calls'])})")
response_message["function_call"] = response_message["tool_calls"][0]["function"]

# The content is then internal monologue, not chat
self.interface.internal_monologue(response_message.content)
messages.append(response_message) # extend conversation with assistant's reply
Expand Down Expand Up @@ -803,7 +813,7 @@ def get_ai_reply(
raise Exception("Finish reason was length (maximum context length)")

# catches for soft errors
if response.choices[0].finish_reason not in ["stop", "function_call"]:
if response.choices[0].finish_reason not in ["stop", "function_call", "tool_calls"]:
raise Exception(f"API call finish with bad finish reason: {response}")

# unpack with response.choices[0].message.content
Expand Down
118 changes: 116 additions & 2 deletions memgpt/openai_tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import random
import os
import string
import time
import requests
import time
Expand All @@ -20,6 +22,111 @@
}


def convert_from_functions_to_tools(data: dict, generate_tool_call_ids: bool = True, allow_content_in_tool_calls: bool = True) -> dict:
    """Convert from the old style of 'functions' to 'tools'

    Main differences that need to be handled in the ChatCompletion request object:
    (https://platform.openai.com/docs/api-reference/chat/create)

    - data.function_call
      -> data.tool_choice ("none" or "auto")
    - data.functions (array of {description/name/parameters})
      -> data.tools (array of (type: 'function', function: {description/name/arguments}))
    - data.messages
      - role == 'assistant'
        - function_call ({arguments/name})
          -> tool_calls (array of (id, type: 'function', function: {name/arguments}))
      - role == 'function'
        -> role == 'tool'
        - name
          -> tool_call_id

    Args:
        data: ChatCompletion request payload in the old 'functions' style
        generate_tool_call_ids: if True, synthesize random OpenAI-style ids
            ("call_" + 22 alphanumerics) for assistant tool_calls; otherwise ids are None
        allow_content_in_tool_calls: if False, null out 'content' on assistant
            messages that carry a tool call

    Returns:
        A new request dict in the 'tools' style. Neither the input dict nor its
        message dicts are mutated.

    Raises:
        ValueError: if 'function_call' is set to something other than None,
            "none"/"auto", a {"name": ...} dict, or a bare function-name string.
    """

    def create_tool_call_id(prefix: str = "call_", length: int = 22) -> str:
        # Generate a random string of letters and digits
        random_str = "".join(random.choices(string.ascii_letters + string.digits, k=length))
        return prefix + random_str

    data = data.copy()

    # function_call -> tool_choice
    # function_call = None -> tool_choice = "none"/"auto"
    if "function_call" in data:
        function_call = data.pop("function_call")
        if function_call is None:
            # None = default option ("auto" only makes sense when functions are present)
            data["tool_choice"] = "auto" if "functions" in data else "none"
        elif function_call in ["none", "auto"]:
            # != None = was manually set, pass the valid sentinel through unchanged
            data["tool_choice"] = function_call
        elif isinstance(function_call, dict) and "name" in function_call:
            # Function call was forced to a specific function via {"name": ...}
            data["tool_choice"] = {"type": "function", "function": {"name": function_call["name"]}}
        elif isinstance(function_call, str):
            # Function call was forced to a specific function via a bare name string
            # NOTE: original code had `isinstance(data["tool_choice", str])` (tuple-key
            # lookup + one-arg isinstance), which crashed on any string function_call
            data["tool_choice"] = {"type": "function", "function": {"name": function_call}}
        else:
            # NOTE: original code built the ValueError but forgot to raise it
            raise ValueError(function_call)

    # functions -> tools
    if "functions" in data:
        data["tools"] = [{"type": "function", "function": json_schema} for json_schema in data.pop("functions")]

    # need to correct for assistant role (that calls functions)
    # and function role (renamed to "tool" role)
    if "messages" in data:
        renamed_messages = []
        for i, msg in enumerate(data["messages"]):
            # Shallow-copy each message so the caller's message dicts aren't mutated
            msg = dict(msg)

            # correct the function role
            if msg["role"] == "function":
                msg["role"] = "tool"
                if "name" in msg:
                    # Match this tool result to the tool_call id of the preceding
                    # assistant message (guard i > 0: a negative index would silently
                    # read the LAST message instead of the prior one)
                    prior_message = renamed_messages[i - 1] if i > 0 else None
                    if prior_message is not None and prior_message["role"] == "assistant":
                        # NOTE assumes len(tool_calls) == 1
                        try:
                            msg["tool_call_id"] = prior_message["tool_calls"][0]["id"]
                        except (KeyError, IndexError, TypeError):
                            print(f"Warning: couldn't find tool_call id to match with tool result")
                            # TODO figure out what we should do here if we can't find the relevant tool message (use 'name' or None?)
                            # msg["tool_call_id"] = msg["name"]
                            msg["tool_call_id"] = None
                    else:
                        # TODO figure out what we should do here if we can't find the relevant tool message (use 'name' or None?)
                        # msg["tool_call_id"] = msg["name"]
                        msg["tool_call_id"] = None

                # NOTE: According to the official API docs, 'tool' role shouldn't have name
                # However, it appears in their example docs + API throws an error when it's not included
                # https://cookbook.openai.com/examples/how_to_call_functions_with_chat_models
                # msg.pop("name")

            # correct the assistant role
            elif msg["role"] == "assistant":
                if "function_call" in msg:
                    msg["tool_calls"] = [
                        {
                            # TODO should we use 'name' instead of None here?
                            "id": create_tool_call_id() if generate_tool_call_ids else None,
                            "type": "function",
                            "function": msg.pop("function_call"),
                        }
                    ]
                    if not allow_content_in_tool_calls:
                        msg["content"] = None
                        # TODO need backup of moving content into inner monologue parameter
                        # (vs just deleting it)
                        # raise NotImplementedError
                        print(f"Warning: deleting 'content' in function call assistant message without replacement")

            renamed_messages.append(msg)
        data["messages"] = renamed_messages

    return data


def is_context_overflow_error(exception):
from memgpt.utils import printd

Expand Down Expand Up @@ -164,7 +271,7 @@ def azure_openai_get_model_list(url: str, api_key: Union[str, None], api_version
raise e


def openai_chat_completions_request(url, api_key, data):
def openai_chat_completions_request(url, api_key, data, use_tool_naming=True):
"""https://platform.openai.com/docs/guides/text-generation?lang=curl"""
from memgpt.utils import printd

Expand All @@ -176,6 +283,9 @@ def openai_chat_completions_request(url, api_key, data):
data.pop("functions")
data.pop("function_call", None) # extra safe, should exist always (default="auto")

if use_tool_naming:
data = convert_from_functions_to_tools(data=data)

printd(f"Sending request to {url}")
try:
# Example code to trigger a rate limit response:
Expand Down Expand Up @@ -239,7 +349,7 @@ def openai_embeddings_request(url, api_key, data):
raise e


def azure_openai_chat_completions_request(resource_name, deployment_id, api_version, api_key, data):
def azure_openai_chat_completions_request(resource_name, deployment_id, api_version, api_key, data, use_tool_naming=True):
"""https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions"""
from memgpt.utils import printd

Expand All @@ -252,6 +362,10 @@ def azure_openai_chat_completions_request(resource_name, deployment_id, api_vers
data.pop("functions")
data.pop("function_call", None) # extra safe, should exist always (default="auto")

if use_tool_naming:
# TODO azure doesn't seem to handle tool roles properly atm
data = convert_from_functions_to_tools(data=data)

printd(f"Sending request to {url}")
try:
response = requests.post(url, headers=headers, json=data)
Expand Down