fix: Update message schema / data type to match OAI tools style #783

Merged
7 commits merged on Jan 4, 2024
2 changes: 2 additions & 0 deletions memgpt/agent.py
@@ -411,12 +411,14 @@ def handle_ai_response(self, response_message):
response_message["tool_call_id"] = tool_call_id
# role: assistant (requesting tool call, set tool call ID)
messages.append(response_message) # extend conversation with assistant's reply
printd(f"Function call message: {messages[-1]}")

# Step 3: call the function
# Note: the JSON response may not always be valid; be sure to handle errors

# Failure case 1: function name is wrong
function_name = response_message["function_call"]["name"]
printd(f"Request to call function {function_name} with tool_call_id: {tool_call_id}")
try:
    function_to_call = self.functions_python[function_name]
except KeyError as e:
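For context, the schema this PR aligns with is the OpenAI Chat Completions move from the legacy function_call field to the tools / tool_calls style. A minimal sketch of the two assistant-message shapes and their follow-up result messages (the IDs and argument values below are illustrative):

# Legacy (deprecated) function-calling shape
legacy_assistant_message = {
    "role": "assistant",
    "content": None,
    "function_call": {"name": "send_message", "arguments": '{"message": "Hello!"}'},
}
legacy_function_result = {"role": "function", "name": "send_message", "content": "OK"}

# OAI tools-style shape: a list of tool calls, each carrying its own id
tools_assistant_message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "call_abc123",  # illustrative id
            "type": "function",
            "function": {"name": "send_message", "arguments": '{"message": "Hello!"}'},
        }
    ],
}
tools_result_message = {"role": "tool", "tool_call_id": "call_abc123", "content": "OK"}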
28 changes: 23 additions & 5 deletions memgpt/connectors/db.py
@@ -26,7 +26,7 @@
from memgpt.config import AgentConfig, MemGPTConfig
from memgpt.constants import MEMGPT_DIR
from memgpt.utils import printd
-from memgpt.data_types import Record, Message, Passage, Source
+from memgpt.data_types import Record, Message, Passage, Source, ToolCall

from datetime import datetime

@@ -71,6 +71,26 @@ def process_result_value(self, value, dialect):
return np.array(list_value)


class ToolCalls(TypeDecorator):
    """Custom type for storing List[ToolCall] as JSON"""

    impl = JSON

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        if value:
            return [vars(v) for v in value]
        return value

    def process_result_value(self, value, dialect):
        if value:
            return [ToolCall(**v) for v in value]
        return value


Base = declarative_base()


@@ -155,8 +175,7 @@ class MessageModel(Base):
# if role == "assistant", this MAY be specified
# if role != "assistant", this must be null
# TODO align with OpenAI spec of multiple tool calls
-tool_name = Column(String)
-tool_args = Column(String)
+tool_calls = Column(ToolCalls)

# tool call response info
# if role == "tool", then this must be specified
@@ -185,8 +204,7 @@ def to_record(self):
user=self.user,
text=self.text,
model=self.model,
-tool_name=self.tool_name,
-tool_args=self.tool_args,
+tool_calls=self.tool_calls,
tool_call_id=self.tool_call_id,
embedding=self.embedding,
created_at=self.created_at,
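To make the round-trip behavior of the new column type concrete, here is a minimal sketch, assuming SQLAlchemy 1.4+, an in-memory SQLite database, and the ToolCall / ToolCalls definitions from this PR; the DemoMessage table and all values are illustrative, not part of the real MemGPT schema:

from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.orm import Session, declarative_base

from memgpt.connectors.db import ToolCalls
from memgpt.data_types import ToolCall

DemoBase = declarative_base()

class DemoMessage(DemoBase):
    __tablename__ = "demo_messages"
    id = Column(Integer, primary_key=True)
    tool_calls = Column(ToolCalls)  # stored as JSON, surfaced as List[ToolCall]

engine = create_engine("sqlite://")
DemoBase.metadata.create_all(engine)

with Session(engine) as session:
    tc = ToolCall(
        id="call_abc123",  # illustrative id
        tool_call_type="function",
        function={"name": "send_message", "arguments": '{"message": "Hello!"}'},
    )
    session.add(DemoMessage(id=1, tool_calls=[tc]))
    session.commit()
    row = session.get(DemoMessage, 1)
    # process_result_value() rebuilds ToolCall objects from the stored JSON
    assert isinstance(row.tool_calls[0], ToolCall)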
38 changes: 31 additions & 7 deletions memgpt/data_types.py
@@ -1,7 +1,7 @@
""" This module contains the data types used by MemGPT. Each data type must include a function to create a DB model. """
import uuid
from abc import abstractmethod
-from typing import Optional
+from typing import Optional, List, Dict
import numpy as np


@@ -24,8 +24,28 @@ def __init__(self, id: Optional[str] = None):
assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"


class ToolCall(object):
    def __init__(
        self,
        id: str,
        # TODO should we include this? it's fixed to 'function' only (for now) in OAI schema
        tool_call_type: str,  # only 'function' is supported
        # function: { 'name': ..., 'arguments': ...}
        function: Dict[str, str],
    ):
        self.id = id
        self.tool_call_type = tool_call_type
        self.function = function


class Message(Record):
-    """Representation of a message sent from the agent -> user. Also includes function calls."""
    """Representation of a message sent.

    Messages can be:
    - agent->user (role=='agent')
    - user->agent and system->agent (role=='user')
    - or function/tool call returns (role=='function'/'tool').
    """

    def __init__(
        self,
@@ -36,9 +56,8 @@ def __init__(
        model: str,  # model used to make function call
        user: Optional[str] = None,  # optional participant name
        created_at: Optional[str] = None,
-        tool_name: Optional[str] = None,  # name of tool used
-        tool_args: Optional[str] = None,  # args of tool used
-        tool_call_id: Optional[str] = None,  # id of tool call
+        tool_calls: Optional[List[ToolCall]] = None,  # list of tool calls requested
+        tool_call_id: Optional[str] = None,
        embedding: Optional[np.ndarray] = None,
        id: Optional[str] = None,
    ):
@@ -54,8 +73,13 @@ def __init__(
        self.user = user

        # tool (i.e. function) call info (optional)
-        self.tool_name = tool_name
-        self.tool_args = tool_args

        # if role == "assistant", this MAY be specified
        # if role != "assistant", this must be null
        self.tool_calls = tool_calls

        # if role == "tool", then this must be specified
        # if role != "tool", this must be null
        self.tool_call_id = tool_call_id

        # embedding (optional)
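A short usage sketch of the new fields, mirroring the keyword arguments used at the call site in memgpt/persistence_manager.py further down; every ID, name, and model string here is illustrative:

from memgpt.data_types import Message, ToolCall

# role == "assistant": the agent requests a tool call, so tool_calls is set
assistant_msg = Message(
    user_id="00000000-0000-0000-0000-000000000000",  # illustrative client id
    agent_id="agent_1",  # illustrative agent name
    role="assistant",
    text="Sending a greeting.",
    model="gpt-4",
    tool_calls=[
        ToolCall(
            id="call_abc123",
            tool_call_type="function",
            function={"name": "send_message", "arguments": '{"message": "Hello!"}'},
        )
    ],
)

# role == "tool": the tool result references the request via tool_call_id
tool_msg = Message(
    user_id="00000000-0000-0000-0000-000000000000",
    agent_id="agent_1",
    role="tool",
    text='{"status": "OK"}',
    model="gpt-4",
    tool_call_id="call_abc123",
)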
21 changes: 18 additions & 3 deletions memgpt/persistence_manager.py
@@ -7,7 +7,7 @@
EmbeddingArchivalMemory,
)
from memgpt.utils import get_local_time, printd
-from memgpt.data_types import Message
+from memgpt.data_types import Message, ToolCall
from memgpt.config import MemGPTConfig

from datetime import datetime
@@ -116,15 +116,30 @@ def json_to_message(self, message_json) -> Message:
        timestamp = message_json["timestamp"]
        message = message_json["message"]

        # TODO: change this when we fully migrate to tool calls API
        if "function_call" in message:
            tool_calls = [
                ToolCall(
                    id=message["tool_call_id"],
                    tool_call_type="function",
                    function={
                        "name": message["function_call"]["name"],
                        "arguments": message["function_call"]["arguments"],
                    },
                )
            ]
            printd(f"Saving tool calls {[vars(tc) for tc in tool_calls]}")
        else:
            tool_calls = None

        return Message(
            user_id=self.config.anon_clientid,
            agent_id=self.agent_config.name,
            role=message["role"],
            text=message["content"],
            model=self.agent_config.model,
            created_at=parse_formatted_time(timestamp),
-            tool_name=message["function_name"] if "function_name" in message else None,
-            tool_args=message["function_args"] if "function_args" in message else None,
            tool_calls=tool_calls,
            tool_call_id=message["tool_call_id"] if "tool_call_id" in message else None,
            id=message["id"] if "id" in message else None,
        )
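For reference, a sketch of the legacy saved-message shape that the function_call branch above handles, and the tool_calls list it produces (all field values are illustrative):

legacy_message = {
    "role": "assistant",
    "content": "Sending a greeting.",
    "tool_call_id": "call_abc123",
    "function_call": {"name": "send_message", "arguments": '{"message": "Hello!"}'},
}

# json_to_message maps the single legacy function_call onto a one-element list:
# tool_calls == [ToolCall(id="call_abc123", tool_call_type="function",
#                         function={"name": "send_message",
#                                   "arguments": '{"message": "Hello!"}'})]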