feat: Get in-context Message.id values from server #851

Merged · 3 commits · Jan 18, 2024

Changes from all commits

11 changes: 9 additions & 2 deletions memgpt/memory.py
@@ -307,13 +307,20 @@ def __init__(self, agent_state, restrict_search_to_summaries=False):
         # TODO: have some mechanism for cleanup otherwise will lead to OOM
         self.cache = {}

+    def get_all(self, start=0, count=None):
+        results = self.storage.get_all(start, count)
+        results_json = [message.to_openai_dict() for message in results]
+        return results_json, len(results)
+
     def text_search(self, query_string, count=None, start=None):
         results = self.storage.query_text(query_string, count, start)
-        return results, len(results)
+        results_json = [message.to_openai_dict() for message in results]
+        return results_json, len(results)

     def date_search(self, start_date, end_date, count=None, start=None):
         results = self.storage.query_date(start_date, end_date, count, start)
-        return results, len(results)
+        results_json = [message.to_openai_dict() for message in results]
+        return results_json, len(results)

     def __repr__(self) -> str:
         total = self.storage.size()
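The memory.py change above makes all three recall-memory lookups (get_all, text_search, date_search) return JSON-serializable OpenAI-style dicts via Message.to_openai_dict() instead of raw Message objects, so the three paths share one return convention. A minimal consumption sketch, assuming an already-populated recall-memory instance named recall_memory (the variable name and query string are illustrative):

```python
import json

# recall_memory is assumed to be a recall-memory object backed by a populated
# storage connector, e.g. the one held by a running MemGPT agent.
results_json, num_results = recall_memory.text_search("flight to Paris", count=10, start=0)

# Each result is now a plain dict in OpenAI chat format, so it serializes cleanly
# for server responses; default=str guards any stray non-JSON types.
print(json.dumps(results_json, indent=2, default=str))
print(f"{num_results} results returned")
```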
23 changes: 22 additions & 1 deletion memgpt/server/server.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Union, Callable, Optional, Tuple
+from typing import Union, Callable, Optional, Tuple, List
 import uuid
 import json
 import logging
@@ -35,6 +35,21 @@
     Message,
     ToolCall,
 )
+from memgpt.data_types import (
+    Source,
+    Passage,
+    Document,
+    User,
+    AgentState,
+    LLMConfig,
+    EmbeddingConfig,
+    Message,
+    ToolCall,
+    LLMConfig,
+    EmbeddingConfig,
+    Message,
+    ToolCall,
+)

 # TODO use custom interface
 from memgpt.interface import CLIInterface  # for printing to terminal
@@ -677,6 +692,12 @@ def get_agent_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:

         return memory_obj

+    def get_in_context_message_ids(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> List[uuid.UUID]:
+        """Get the message ids of the in-context messages in the agent's memory"""
+        # Get the agent object (loaded in memory)
+        memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
+        return [m.id for m in memgpt_agent._messages]
+
     def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
         """Paginated query of all messages in agent message queue"""
         if self.ms.get_user(user_id=user_id) is None:
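The new server helper above reads Message.id straight off the agent's in-context buffer (memgpt_agent._messages), so clients can tell which persisted messages are still inside the context window. A hedged usage sketch, assuming server, user, and agent_state are constructed as in tests/test_server.py and that recall-cursor payloads carry an "id" field, as the test below asserts:

```python
# Assumes `server`, `user`, and `agent_state` are set up as in tests/test_server.py.
in_context_ids = server.get_in_context_message_ids(user_id=user.id, agent_id=agent_state.id)

# Cross-reference against the persisted recall store: messages still in the LLM
# context window appear in both, evicted messages only in the recall store.
cursor, messages = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True)
for message in messages:
    status = "in-context" if message["id"] in in_context_ids else "evicted"
    print(status, message["id"])
```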
8 changes: 8 additions & 0 deletions tests/test_server.py
@@ -114,6 +114,14 @@ def test_server():
     cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, before=cursor1)
     assert len(messages_4) == 1

+    # test in-context message ids
+    in_context_ids = server.get_in_context_message_ids(user_id=user.id, agent_id=agent_state.id)
+    assert len(in_context_ids) == len(messages_3)
+    assert isinstance(in_context_ids[0], uuid.UUID)
+    message_ids = [m["id"] for m in messages_3]
+    for message_id in message_ids:
+        assert message_id in in_context_ids, f"{message_id} not in {in_context_ids}"
+
     # test archival memory cursor pagination
     cursor1, passages_1 = server.get_agent_archival_cursor(
         user_id=user.id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"
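To exercise the new assertions in isolation, the updated test can be run directly with pytest; a convenience runner is sketched below and assumes the same environment and LLM/embedding configuration that tests/test_server.py already requires:

```python
# Run only the updated server test; pytest's node-id syntax selects test_server.
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["-s", "tests/test_server.py::test_server"]))
```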