diff --git a/memgpt/agent.py b/memgpt/agent.py index c4a65c9df2..3597a00a0b 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -340,8 +340,18 @@ def _load_messages_from_recall(self, message_ids: List[str]) -> List[Message]: """Load a list of messages from recall storage""" # Pull the message objects from the database - message_objs = [self.persistence_manager.recall_memory.storage.get(msg_id) for msg_id in message_ids] - assert all([isinstance(msg, Message) for msg in message_objs]) + message_objs = [] + for msg_id in message_ids: + msg_obj = self.persistence_manager.recall_memory.storage.get(msg_id) + if msg_obj: + if isinstance(msg_obj, Message): + message_objs.append(msg_obj) + else: + printd(f"Warning - message ID {msg_id} is not a Message object") + warnings.warn(f"Warning - message ID {msg_id} is not a Message object") + else: + printd(f"Warning - message ID {msg_id} not found in recall storage") + warnings.warn(f"Warning - message ID {msg_id} not found in recall storage") return message_objs diff --git a/memgpt/agent_store/chroma.py b/memgpt/agent_store/chroma.py index 9069022dea..3bfe620f20 100644 --- a/memgpt/agent_store/chroma.py +++ b/memgpt/agent_store/chroma.py @@ -131,7 +131,7 @@ def get_all(self, filters: Optional[Dict] = {}, limit=None): results = self.collection.get(ids=ids, include=self.include, where=filters) return self.results_to_records(results) - def get(self, id): + def get(self, id: str): results = self.collection.get(ids=[str(id)]) if len(results["ids"]) == 0: return None diff --git a/memgpt/agent_store/milvus.py b/memgpt/agent_store/milvus.py index 0449586342..5fd067544d 100644 --- a/memgpt/agent_store/milvus.py +++ b/memgpt/agent_store/milvus.py @@ -91,7 +91,7 @@ def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[RecordType]: ) return self._list_to_records(query_res) - def get(self, id: uuid.UUID) -> Optional[RecordType]: + def get(self, id: str) -> Optional[RecordType]: res = self.client.get(collection_name=self.table_name, ids=str(id)) return self._list_to_records(res)[0] if res else None diff --git a/memgpt/agent_store/qdrant.py b/memgpt/agent_store/qdrant.py index 640ad91ac8..84c16f4527 100644 --- a/memgpt/agent_store/qdrant.py +++ b/memgpt/agent_store/qdrant.py @@ -73,7 +73,7 @@ def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[RecordType]: ) return self.to_records(results) - def get(self, id: uuid.UUID) -> Optional[RecordType]: + def get(self, id: str) -> Optional[RecordType]: results = self.qdrant_client.retrieve( collection_name=self.table_name, ids=[str(id)], diff --git a/memgpt/schemas/memgpt_message.py b/memgpt/schemas/memgpt_message.py index 1182ea528e..b99eb8c766 100644 --- a/memgpt/schemas/memgpt_message.py +++ b/memgpt/schemas/memgpt_message.py @@ -1,8 +1,8 @@ import json from datetime import datetime, timezone -from typing import Literal, Optional, Union +from typing import Annotated, Literal, Optional, Union -from pydantic import BaseModel, field_serializer, field_validator +from pydantic import BaseModel, Field, field_serializer, field_validator # MemGPT API style responses (intended to be easier to use vs getting true Message types) @@ -17,6 +17,9 @@ class MemGPTMessage(BaseModel): """ + # NOTE: use Pydantic's discriminated unions feature: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions + # see `message_type` attribute + id: str date: datetime @@ -39,6 +42,7 @@ class SystemMessage(MemGPTMessage): date (datetime): The date the message was created in ISO format """ + message_type: Literal["system_message"] = "system_message" message: str @@ -52,6 +56,7 @@ class UserMessage(MemGPTMessage): date (datetime): The date the message was created in ISO format """ + message_type: Literal["user_message"] = "user_message" message: str @@ -65,15 +70,18 @@ class InternalMonologue(MemGPTMessage): date (datetime): The date the message was created in ISO format """ + message_type: Literal["internal_monologue"] = "internal_monologue" internal_monologue: str class FunctionCall(BaseModel): + name: str arguments: str class FunctionCallDelta(BaseModel): + name: Optional[str] arguments: Optional[str] @@ -97,6 +105,7 @@ class FunctionCallMessage(MemGPTMessage): date (datetime): The date the message was created in ISO format """ + message_type: Literal["function_call"] = "function_call" function_call: Union[FunctionCall, FunctionCallDelta] # NOTE: this is required for the FunctionCallDelta exclude_none to work correctly @@ -140,17 +149,16 @@ class FunctionReturn(MemGPTMessage): date (datetime): The date the message was created in ISO format """ + message_type: Literal["function_return"] = "function_return" function_return: str status: Literal["success", "error"] -# MemGPTMessage = Union[InternalMonologue, FunctionCallMessage, FunctionReturn] - - # Legacy MemGPT API had an additional type "assistant_message" and the "function_call" was a formatted string class AssistantMessage(MemGPTMessage): + message_type: Literal["assistant_message"] = "assistant_message" assistant_message: str @@ -159,3 +167,9 @@ class LegacyFunctionCallMessage(MemGPTMessage): LegacyMemGPTMessage = Union[InternalMonologue, AssistantMessage, LegacyFunctionCallMessage, FunctionReturn] + + +MemGPTMessageUnion = Annotated[ + Union[SystemMessage, UserMessage, InternalMonologue, FunctionCallMessage, FunctionReturn, AssistantMessage], + Field(discriminator="message_type"), +] diff --git a/memgpt/server/rest_api/routers/v1/agents.py b/memgpt/server/rest_api/routers/v1/agents.py index fd8acdf890..809b7f62d7 100644 --- a/memgpt/server/rest_api/routers/v1/agents.py +++ b/memgpt/server/rest_api/routers/v1/agents.py @@ -8,7 +8,11 @@ from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState from memgpt.schemas.enums import MessageRole, MessageStreamStatus -from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage +from memgpt.schemas.memgpt_message import ( + LegacyMemGPTMessage, + MemGPTMessage, + MemGPTMessageUnion, +) from memgpt.schemas.memgpt_request import MemGPTRequest from memgpt.schemas.memgpt_response import MemGPTResponse from memgpt.schemas.memory import ( @@ -237,7 +241,7 @@ def delete_agent_archival_memory( return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"}) -@router.get("/{agent_id}/messages", response_model=List[Message], operation_id="list_agent_messages") +@router.get("/{agent_id}/messages", response_model=Union[List[Message], List[MemGPTMessageUnion]], operation_id="list_agent_messages") def get_agent_messages( agent_id: str, server: "SyncServer" = Depends(get_memgpt_server),