Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Direct Import in the open source lib #2107

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions mem0/memory/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def add(
metadata=None,
filters=None,
prompt=None,
infer=True,
):
"""
Create a new memory.
Expand All @@ -79,7 +80,8 @@ def add(
metadata (dict, optional): Metadata to store with the memory. Defaults to None.
filters (dict, optional): Filters to apply to the search. Defaults to None.
prompt (str, optional): Prompt to use for memory deduction. Defaults to None.

infer (bool, optional): Whether to use inference to add the memory. Defaults to True.

Returns:
dict: A dictionary containing the result of the memory addition operation.
result: dict of affected events with each dict has the following key:
Expand Down Expand Up @@ -111,7 +113,7 @@ def add(
messages = [{"role": "user", "content": messages}]

with concurrent.futures.ThreadPoolExecutor() as executor:
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters)
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters, infer)
future2 = executor.submit(self._add_to_graph, messages, filters)

concurrent.futures.wait([future1, future2])
Expand All @@ -134,9 +136,17 @@ def add(
)
return vector_store_result

def _add_to_vector_store(self, messages, metadata, filters):
def _add_to_vector_store(self, messages, metadata, filters, infer=True):
parsed_messages = parse_messages(messages)

if not infer:
messages_embeddings = self.embedding_model.embed(parsed_messages)
new_message_embeddings = {parsed_messages: messages_embeddings}
memory_id = self._create_memory(
data=parsed_messages, existing_embeddings=new_message_embeddings, metadata=metadata
)
return [{"id": memory_id, "memory": parsed_messages, "event": "ADD"}]

if self.custom_prompt:
system_prompt = self.custom_prompt
user_prompt = f"Input: {parsed_messages}"
Expand Down
75 changes: 74 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,87 @@ def test_add(memory_instance, version, enable_graph):
assert result["relations"] == []

memory_instance._add_to_vector_store.assert_called_once_with(
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"}
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"}, True
)

# Remove the conditional assertion for _add_to_graph
memory_instance._add_to_graph.assert_called_once_with(
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}
)

@pytest.mark.parametrize("version, enable_graph, infer", [
("v1.0", False, False),
("v1.1", True, True),
("v1.1", True, False)
])
def test_add_with_inference(memory_instance, version, enable_graph, infer):
memory_instance.config.version = version
memory_instance.enable_graph = enable_graph

# Setup mocks
memory_instance.embedding_model.embed = Mock(return_value=[0.1, 0.2, 0.3])
memory_instance.vector_store.insert = Mock()
memory_instance.vector_store.search = Mock(return_value=[])
memory_instance.db.add_history = Mock()
memory_instance._add_to_graph = Mock(return_value=[])

# Mock LLM responses for inference case
if infer:
memory_instance.llm.generate_response = Mock(side_effect=[
'{"facts": ["Test fact 1", "Test fact 2"]}', # First call for fact retrieval
'{"memory": [{"event": "ADD", "text": "Test fact 1"},{"event": "ADD", "text": "Test fact 2"}]}' # Second call for memory actions
])
else:
memory_instance.llm.generate_response = Mock()

# Execute
result = memory_instance.add(
messages=[{"role": "user", "content": "Test fact 1 Text fact 2"}],
user_id="test_user",
infer=infer
)

# Verify basic structure of result
assert "results" in result
assert "relations" in result
assert isinstance(result["results"], list)
assert isinstance(result["relations"], list)

# Verify LLM behavior
if infer:
# Should be called twice: once for fact retrieval, once for memory actions
assert memory_instance.llm.generate_response.call_count == 2

# Verify first call (fact retrieval)
first_call = memory_instance.llm.generate_response.call_args_list[0]
assert len(first_call[1]['messages']) == 2
assert first_call[1]['messages'][0]['role'] == 'system'
assert first_call[1]['messages'][1]['role'] == 'user'

# Verify embedding was called for the facts
assert memory_instance.embedding_model.embed.call_count == 2

# Verify vector store operations
assert memory_instance.vector_store.insert.call_count == 2
else:
# For non-inference case, should directly create memory without LLM
memory_instance.llm.generate_response.assert_not_called()
# Should still embed the original message
memory_instance.embedding_model.embed.assert_called_once_with("user: Test fact 1 Text fact 2\n")
memory_instance.vector_store.insert.assert_called_once()

# Verify graph behavior
memory_instance._add_to_graph.assert_called_once_with(
[{"role": "user", "content": "Test fact 1 Text fact 2"}], {"user_id": "test_user"}
)

if version == "v1.1":
assert isinstance(result, dict)
assert "results" in result
assert "relations" in result
else:
assert isinstance(result["results"], list)


def test_get(memory_instance):
mock_memory = Mock(
Expand Down
3 changes: 1 addition & 2 deletions tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pytest

from mem0 import Memory, MemoryClient
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
from mem0.proxy.main import Chat, Completions, Mem0


Expand Down Expand Up @@ -94,4 +93,4 @@ def test_completions_create_with_system_message(mock_memory_client, mock_litellm

call_args = mock_litellm.completion.call_args[1]
assert call_args["messages"][0]["role"] == "system"
assert call_args["messages"][0]["content"] == MEMORY_ANSWER_PROMPT
assert call_args["messages"][0]["content"] == "You are a helpful assistant."