diff --git a/mem0/memory/main.py b/mem0/memory/main.py
index 14f980f271..6be222c4d2 100644
--- a/mem0/memory/main.py
+++ b/mem0/memory/main.py
@@ -67,6 +67,7 @@ def add(
         metadata=None,
         filters=None,
         prompt=None,
+        infer=True,
     ):
         """
         Create a new memory.
@@ -79,7 +80,8 @@ def add(
             metadata (dict, optional): Metadata to store with the memory. Defaults to None.
             filters (dict, optional): Filters to apply to the search. Defaults to None.
             prompt (str, optional): Prompt to use for memory deduction. Defaults to None.
-
+            infer (bool, optional): Whether to use inference to add the memory. Defaults to True.
+
         Returns:
             dict: A dictionary containing the result of the memory addition operation.
             result: dict of affected events with each dict has the following key:
@@ -111,7 +113,7 @@ def add(
             messages = [{"role": "user", "content": messages}]
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
-            future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters)
+            future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters, infer)
             future2 = executor.submit(self._add_to_graph, messages, filters)
 
             concurrent.futures.wait([future1, future2])
@@ -134,9 +136,17 @@ def add(
             )
             return vector_store_result
 
-    def _add_to_vector_store(self, messages, metadata, filters):
+    def _add_to_vector_store(self, messages, metadata, filters, infer=True):
         parsed_messages = parse_messages(messages)
 
+        if not infer:
+            messages_embeddings = self.embedding_model.embed(parsed_messages)
+            new_message_embeddings = {parsed_messages: messages_embeddings}
+            memory_id = self._create_memory(
+                data=parsed_messages, existing_embeddings=new_message_embeddings, metadata=metadata
+            )
+            return [{"id": memory_id, "memory": parsed_messages, "event": "ADD"}]
+
         if self.custom_prompt:
             system_prompt = self.custom_prompt
             user_prompt = f"Input: {parsed_messages}"
diff --git a/tests/test_main.py b/tests/test_main.py
index a311f854b2..f281868451 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -47,7 +47,7 @@ def test_add(memory_instance, version, enable_graph):
     assert result["relations"] == []
 
     memory_instance._add_to_vector_store.assert_called_once_with(
-        [{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"}
+        [{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"}, True
     )
 
     # Remove the conditional assertion for _add_to_graph
@@ -55,6 +55,79 @@ def test_add(memory_instance, version, enable_graph):
         [{"role": "user", "content": "Test message"}], {"user_id": "test_user"}
     )
 
+@pytest.mark.parametrize("version, enable_graph, infer", [
+    ("v1.0", False, False),
+    ("v1.1", True, True),
+    ("v1.1", True, False)
+])
+def test_add_with_inference(memory_instance, version, enable_graph, infer):
+    memory_instance.config.version = version
+    memory_instance.enable_graph = enable_graph
+
+    # Setup mocks
+    memory_instance.embedding_model.embed = Mock(return_value=[0.1, 0.2, 0.3])
+    memory_instance.vector_store.insert = Mock()
+    memory_instance.vector_store.search = Mock(return_value=[])
+    memory_instance.db.add_history = Mock()
+    memory_instance._add_to_graph = Mock(return_value=[])
+
+    # Mock LLM responses for inference case
+    if infer:
+        memory_instance.llm.generate_response = Mock(side_effect=[
+            '{"facts": ["Test fact 1", "Test fact 2"]}',  # First call for fact retrieval
+            '{"memory": [{"event": "ADD", "text": "Test fact 1"},{"event": "ADD", "text": "Test fact 2"}]}'  # Second call for memory actions
+        ])
+    else:
+        memory_instance.llm.generate_response = Mock()
+
+    # Execute
+    result = memory_instance.add(
+        messages=[{"role": "user", "content": "Test fact 1 Text fact 2"}],
+        user_id="test_user",
+        infer=infer
+    )
+
+    # Verify basic structure of result
+    assert "results" in result
+    assert "relations" in result
+    assert isinstance(result["results"], list)
+    assert isinstance(result["relations"], list)
+
+    # Verify LLM behavior
+    if infer:
+        # Should be called twice: once for fact retrieval, once for memory actions
+        assert memory_instance.llm.generate_response.call_count == 2
+
+        # Verify first call (fact retrieval)
+        first_call = memory_instance.llm.generate_response.call_args_list[0]
+        assert len(first_call[1]['messages']) == 2
+        assert first_call[1]['messages'][0]['role'] == 'system'
+        assert first_call[1]['messages'][1]['role'] == 'user'
+
+        # Verify embedding was called for the facts
+        assert memory_instance.embedding_model.embed.call_count == 2
+
+        # Verify vector store operations
+        assert memory_instance.vector_store.insert.call_count == 2
+    else:
+        # For non-inference case, should directly create memory without LLM
+        memory_instance.llm.generate_response.assert_not_called()
+        # Should still embed the original message
+        memory_instance.embedding_model.embed.assert_called_once_with("user: Test fact 1 Text fact 2\n")
+        memory_instance.vector_store.insert.assert_called_once()
+
+    # Verify graph behavior
+    memory_instance._add_to_graph.assert_called_once_with(
+        [{"role": "user", "content": "Test fact 1 Text fact 2"}], {"user_id": "test_user"}
+    )
+
+    if version == "v1.1":
+        assert isinstance(result, dict)
+        assert "results" in result
+        assert "relations" in result
+    else:
+        assert isinstance(result["results"], list)
+
 
 def test_get(memory_instance):
     mock_memory = Mock(
diff --git a/tests/test_proxy.py b/tests/test_proxy.py
index 8088f380ed..c3e42691e4 100644
--- a/tests/test_proxy.py
+++ b/tests/test_proxy.py
@@ -3,7 +3,6 @@
 import pytest
 
 from mem0 import Memory, MemoryClient
-from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
 from mem0.proxy.main import Chat, Completions, Mem0
 
 
@@ -94,4 +93,4 @@ def test_completions_create_with_system_message(mock_memory_client, mock_litellm
 
     call_args = mock_litellm.completion.call_args[1]
     assert call_args["messages"][0]["role"] == "system"
-    assert call_args["messages"][0]["content"] == MEMORY_ANSWER_PROMPT
+    assert call_args["messages"][0]["content"] == "You are a helpful assistant."
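
For readers of this change, here is a minimal usage sketch of the new flag. It is illustrative only and not part of the diff; it assumes a default `Memory()` configuration with working embedder and LLM credentials, and the sample messages and user id are made up:

```python
from mem0 import Memory

m = Memory()  # assumes a default config with a working embedder/LLM

# Default behavior (infer=True): the LLM first extracts facts from the
# messages, then decides which memory events (e.g. ADD) to apply.
m.add([{"role": "user", "content": "I moved to Berlin last month."}], user_id="alice")

# With infer=False the LLM is bypassed entirely: the parsed messages are
# embedded as-is and stored as a single memory, returned with an "ADD" event.
m.add([{"role": "user", "content": "I moved to Berlin last month."}], user_id="alice", infer=False)
```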