Skip to content

Commit

Permalink
feat: Get in-context Message.id values from server (#851)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders authored Jan 18, 2024
1 parent 34c2ee2 commit 97b70f5
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 3 deletions.
11 changes: 9 additions & 2 deletions memgpt/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,13 +307,20 @@ def __init__(self, agent_state, restrict_search_to_summaries=False):
# TODO: have some mechanism for cleanup otherwise will lead to OOM
self.cache = {}

def get_all(self, start=0, count=None):
results = self.storage.get_all(start, count)
results_json = [message.to_openai_dict() for message in results]
return results_json, len(results)

def text_search(self, query_string, count=None, start=None):
results = self.storage.query_text(query_string, count, start)
return results, len(results)
results_json = [message.to_openai_dict() for message in results]
return results_json, len(results)

def date_search(self, start_date, end_date, count=None, start=None):
results = self.storage.query_date(start_date, end_date, count, start)
return results, len(results)
results_json = [message.to_openai_dict() for message in results]
return results_json, len(results)

def __repr__(self) -> str:
total = self.storage.size()
Expand Down
23 changes: 22 additions & 1 deletion memgpt/server/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Union, Callable, Optional, Tuple
from typing import Union, Callable, Optional, Tuple, List
import uuid
import json
import logging
Expand Down Expand Up @@ -35,6 +35,21 @@
Message,
ToolCall,
)
from memgpt.data_types import (
Source,
Passage,
Document,
User,
AgentState,
LLMConfig,
EmbeddingConfig,
Message,
ToolCall,
LLMConfig,
EmbeddingConfig,
Message,
ToolCall,
)

# TODO use custom interface
from memgpt.interface import CLIInterface # for printing to terminal
Expand Down Expand Up @@ -677,6 +692,12 @@ def get_agent_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:

return memory_obj

def get_in_context_message_ids(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> List[uuid.UUID]:
"""Get the message ids of the in-context messages in the agent's memory"""
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
return [m.id for m in memgpt_agent._messages]

def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
"""Paginated query of all messages in agent message queue"""
if self.ms.get_user(user_id=user_id) is None:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ def test_server():
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, before=cursor1)
assert len(messages_4) == 1

# test in-context message ids
in_context_ids = server.get_in_context_message_ids(user_id=user.id, agent_id=agent_state.id)
assert len(in_context_ids) == len(messages_3)
assert isinstance(in_context_ids[0], uuid.UUID)
message_ids = [m["id"] for m in messages_3]
for message_id in message_ids:
assert message_id in in_context_ids, f"{message_id} not in {in_context_ids}"

# test archival memory cursor pagination
cursor1, passages_1 = server.get_agent_archival_cursor(
user_id=user.id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"
Expand Down

0 comments on commit 97b70f5

Please sign in to comment.