Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: patch validation error on /messages endpoint #1750

Merged
merged 6 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions memgpt/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,18 @@ def _load_messages_from_recall(self, message_ids: List[str]) -> List[Message]:
"""Load a list of messages from recall storage"""

# Pull the message objects from the database
message_objs = [self.persistence_manager.recall_memory.storage.get(msg_id) for msg_id in message_ids]
assert all([isinstance(msg, Message) for msg in message_objs])
message_objs = []
for msg_id in message_ids:
msg_obj = self.persistence_manager.recall_memory.storage.get(msg_id)
if msg_obj:
if isinstance(msg_obj, Message):
message_objs.append(msg_obj)
else:
printd(f"Warning - message ID {msg_id} is not a Message object")
warnings.warn(f"Warning - message ID {msg_id} is not a Message object")
else:
printd(f"Warning - message ID {msg_id} not found in recall storage")
warnings.warn(f"Warning - message ID {msg_id} not found in recall storage")

return message_objs

Expand Down
2 changes: 1 addition & 1 deletion memgpt/agent_store/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def get_all(self, filters: Optional[Dict] = {}, limit=None):
results = self.collection.get(ids=ids, include=self.include, where=filters)
return self.results_to_records(results)

def get(self, id):
def get(self, id: str):
results = self.collection.get(ids=[str(id)])
if len(results["ids"]) == 0:
return None
Expand Down
2 changes: 1 addition & 1 deletion memgpt/agent_store/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[RecordType]:
)
return self._list_to_records(query_res)

def get(self, id: uuid.UUID) -> Optional[RecordType]:
def get(self, id: str) -> Optional[RecordType]:
res = self.client.get(collection_name=self.table_name, ids=str(id))
return self._list_to_records(res)[0] if res else None

Expand Down
2 changes: 1 addition & 1 deletion memgpt/agent_store/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[RecordType]:
)
return self.to_records(results)

def get(self, id: uuid.UUID) -> Optional[RecordType]:
def get(self, id: str) -> Optional[RecordType]:
results = self.qdrant_client.retrieve(
collection_name=self.table_name,
ids=[str(id)],
Expand Down
24 changes: 19 additions & 5 deletions memgpt/schemas/memgpt_message.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json
from datetime import datetime, timezone
from typing import Literal, Optional, Union
from typing import Annotated, Literal, Optional, Union

from pydantic import BaseModel, field_serializer, field_validator
from pydantic import BaseModel, Field, field_serializer, field_validator

# MemGPT API style responses (intended to be easier to use vs getting true Message types)

Expand All @@ -17,6 +17,9 @@ class MemGPTMessage(BaseModel):

"""

# NOTE: use Pydantic's discriminated unions feature: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
# see `message_type` attribute

id: str
date: datetime

Expand All @@ -39,6 +42,7 @@ class SystemMessage(MemGPTMessage):
date (datetime): The date the message was created in ISO format
"""

message_type: Literal["system_message"] = "system_message"
message: str


Expand All @@ -52,6 +56,7 @@ class UserMessage(MemGPTMessage):
date (datetime): The date the message was created in ISO format
"""

message_type: Literal["user_message"] = "user_message"
message: str


Expand All @@ -65,15 +70,18 @@ class InternalMonologue(MemGPTMessage):
date (datetime): The date the message was created in ISO format
"""

message_type: Literal["internal_monologue"] = "internal_monologue"
internal_monologue: str


class FunctionCall(BaseModel):

name: str
arguments: str


class FunctionCallDelta(BaseModel):

name: Optional[str]
arguments: Optional[str]

Expand All @@ -97,6 +105,7 @@ class FunctionCallMessage(MemGPTMessage):
date (datetime): The date the message was created in ISO format
"""

message_type: Literal["function_call"] = "function_call"
function_call: Union[FunctionCall, FunctionCallDelta]

# NOTE: this is required for the FunctionCallDelta exclude_none to work correctly
Expand Down Expand Up @@ -140,17 +149,16 @@ class FunctionReturn(MemGPTMessage):
date (datetime): The date the message was created in ISO format
"""

message_type: Literal["function_return"] = "function_return"
function_return: str
status: Literal["success", "error"]


# MemGPTMessage = Union[InternalMonologue, FunctionCallMessage, FunctionReturn]


# Legacy MemGPT API had an additional type "assistant_message" and the "function_call" was a formatted string


class AssistantMessage(MemGPTMessage):
message_type: Literal["assistant_message"] = "assistant_message"
assistant_message: str


Expand All @@ -159,3 +167,9 @@ class LegacyFunctionCallMessage(MemGPTMessage):


LegacyMemGPTMessage = Union[InternalMonologue, AssistantMessage, LegacyFunctionCallMessage, FunctionReturn]


MemGPTMessageUnion = Annotated[
Union[SystemMessage, UserMessage, InternalMonologue, FunctionCallMessage, FunctionReturn, AssistantMessage],
Field(discriminator="message_type"),
]
8 changes: 6 additions & 2 deletions memgpt/server/rest_api/routers/v1/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@

from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from memgpt.schemas.enums import MessageRole, MessageStreamStatus
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.memgpt_message import (
LegacyMemGPTMessage,
MemGPTMessage,
MemGPTMessageUnion,
)
from memgpt.schemas.memgpt_request import MemGPTRequest
from memgpt.schemas.memgpt_response import MemGPTResponse
from memgpt.schemas.memory import (
Expand Down Expand Up @@ -237,7 +241,7 @@ def delete_agent_archival_memory(
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})


@router.get("/{agent_id}/messages", response_model=List[Message], operation_id="list_agent_messages")
@router.get("/{agent_id}/messages", response_model=Union[List[Message], List[MemGPTMessageUnion]], operation_id="list_agent_messages")
def get_agent_messages(
agent_id: str,
server: "SyncServer" = Depends(get_memgpt_server),
Expand Down
Loading