diff --git a/mem0/memory/graph_memory.py b/mem0/memory/graph_memory.py
index 7fcfdd7b07..896150229b 100644
--- a/mem0/memory/graph_memory.py
+++ b/mem0/memory/graph_memory.py
@@ -48,18 +48,8 @@ def __init__(self, config):
         self.user_id = None
         self.threshold = 0.7
 
-    def add(self, data, filters):
-        """
-        Adds data to the graph.
-
-        Args:
-            data (str): The data to add to the graph.
-            filters (dict): A dictionary containing filters to be applied during the addition.
-        """
-
-        # retrieve the search results
-        search_output = self._search(data, filters)
-
+    # extracts nodes and relations from data
+    def _llm_extract_entities(self, data):
         if self.config.graph_store.custom_prompt:
             messages = [
                 {
@@ -94,8 +84,10 @@ def add(self, data, filters):
             extracted_entities = []
 
         logger.debug(f"Extracted entities: {extracted_entities}")
+        return extracted_entities
 
-        update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
+    def _llm_update_existing_memory(self, existing_entities, extracted_entities):
+        update_memory_prompt = get_update_memory_messages(existing_entities, extracted_entities)
 
         _tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
         if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
@@ -111,20 +103,71 @@ def add(self, data, filters):
         )
 
         to_be_added = []
+        to_be_updated = []
 
         for item in memory_updates["tool_calls"]:
             if item["name"] == "add_graph_memory":
                 to_be_added.append(item["arguments"])
             elif item["name"] == "update_graph_memory":
-                self._update_relationship(
-                    item["arguments"]["source"],
-                    item["arguments"]["destination"],
-                    item["arguments"]["relationship"],
-                    filters,
-                )
+                to_be_updated.append(item["arguments"])
             elif item["name"] == "noop":
                 continue
 
+        return to_be_added, to_be_updated
+
+    # extracts nodes from query, used for searching
+    def _llm_extract_nodes(self, query, filters):
+        _tools = [SEARCH_TOOL]
+        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
+            _tools = [SEARCH_STRUCT_TOOL]
+        search_results = self.llm.generate_response(
+            messages=[
+                {
+                    "role": "system",
+                    "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities. ***DO NOT*** answer the question itself if the given text is a question.",
+                },
+                {"role": "user", "content": query},
+            ],
+            tools=_tools,
+        )
+
+        node_list = []
+
+        for item in search_results["tool_calls"]:
+            if item["name"] == "search":
+                try:
+                    node_list.extend(item["arguments"]["nodes"])
+                except Exception as e:
+                    logger.error(f"Error in search tool: {e}")
+
+        node_list = list(set(node_list))
+        node_list = [node.lower().replace(" ", "_") for node in node_list]
+
+        logger.debug(f"Node list for search query : {node_list}")
+        return node_list
+
+    def add(self, data, filters):
+        """
+        Adds data to the graph.
+
+        Args:
+            data (str): The data to add to the graph.
+            filters (dict): A dictionary containing filters to be applied during the addition.
+ """ + + # retrieve the search results + existing_entities = self._search(data, filters) + extracted_entities = self._llm_extract_entities(data) + to_be_added, to_be_updated = self._llm_update_existing_memory(existing_entities, extracted_entities) + + for item in to_be_updated: + self._update_relationship( + item["arguments"]["source"], + item["arguments"]["destination"], + item["arguments"]["relationship"], + filters, + ) + returned_entities = [] for item in to_be_added: @@ -168,34 +211,8 @@ def add(self, data, filters): return returned_entities def _search(self, query, filters, limit=100): - _tools = [SEARCH_TOOL] - if self.llm_provider in ["azure_openai_structured", "openai_structured"]: - _tools = [SEARCH_STRUCT_TOOL] - search_results = self.llm.generate_response( - messages=[ - { - "role": "system", - "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities. ***DO NOT*** answer the question itself if the given text is a question.", - }, - {"role": "user", "content": query}, - ], - tools=_tools, - ) - - node_list = [] - - for item in search_results["tool_calls"]: - if item["name"] == "search": - try: - node_list.extend(item["arguments"]["nodes"]) - except Exception as e: - logger.error(f"Error in search tool: {e}") - - node_list = list(set(node_list)) - node_list = [node.lower().replace(" ", "_") for node in node_list] - - logger.debug(f"Node list for search query : {node_list}") + node_list = self._llm_extract_nodes(query, filters) result_relations = [] for node in node_list: diff --git a/mem0/memory/main.py b/mem0/memory/main.py index f0f9cbaeed..6b89f6dd20 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -134,7 +134,7 @@ def add( ) return vector_store_result - def _add_to_vector_store(self, messages, metadata, filters): + def _llm_extract_facts(self, messages): parsed_messages = parse_messages(messages) if self.custom_prompt: @@ -157,6 +157,21 @@ def _add_to_vector_store(self, messages, metadata, filters): logging.error(f"Error in new_retrieved_facts: {e}") new_retrieved_facts = [] + return new_retrieved_facts + + def _llm_new_memories_with_actions(self, retrieved_old_memory, new_retrieved_facts): + function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts) + + new_memories_with_actions = self.llm.generate_response( + messages=[{"role": "user", "content": function_calling_prompt}], + response_format={"type": "json_object"}, + ) + new_memories_with_actions = json.loads(new_memories_with_actions) + return new_memories_with_actions + + def _add_to_vector_store(self, messages, metadata, filters): + new_retrieved_facts = self._llm_extract_facts(messages) + retrieved_old_memory = [] new_message_embeddings = {} for new_mem in new_retrieved_facts: @@ -178,13 +193,7 @@ def _add_to_vector_store(self, messages, metadata, filters): temp_uuid_mapping[str(idx)] = item["id"] retrieved_old_memory[idx]["id"] = str(idx) - function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts) - - new_memories_with_actions = self.llm.generate_response( - messages=[{"role": "user", "content": function_calling_prompt}], - response_format={"type": "json_object"}, - ) - new_memories_with_actions = json.loads(new_memories_with_actions) + new_memories_with_actions = self._llm_new_memories_with_actions(retrieved_old_memory, 
new_retrieved_facts) returned_memories = [] try:
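For review context, below is a minimal, runnable sketch of the control flow this refactor produces in the graph `add` pipeline: `_llm_update_existing_memory` now only plans the changes, and `add` applies them afterwards. `FakeGraphLLM` and its canned tool calls are hypothetical stand-ins for the configured LLM provider, not mem0 code. Note that the items in `to_be_updated` are already the tool calls' `arguments` dicts, which is why `add` indexes them directly.

```python
# Minimal sketch of the refactored graph `add` pipeline; illustration only.
# FakeGraphLLM and its canned tool calls are hypothetical stand-ins for
# mem0's configured LLM provider.


class FakeGraphLLM:
    def generate_response(self, messages, tools=None):
        # Pretend the model planned one relationship update and one addition.
        return {
            "tool_calls": [
                {
                    "name": "update_graph_memory",
                    "arguments": {"source": "alice", "destination": "tea", "relationship": "likes"},
                },
                {
                    "name": "add_graph_memory",
                    "arguments": {"source": "alice", "destination": "kyoto", "relationship": "visited"},
                },
            ]
        }


class GraphMemorySketch:
    def __init__(self):
        self.llm = FakeGraphLLM()

    def _llm_update_existing_memory(self, existing_entities, extracted_entities):
        # Plan the changes only; the caller applies them.
        memory_updates = self.llm.generate_response(messages=[], tools=[])
        to_be_added, to_be_updated = [], []
        for item in memory_updates["tool_calls"]:
            if item["name"] == "add_graph_memory":
                to_be_added.append(item["arguments"])
            elif item["name"] == "update_graph_memory":
                to_be_updated.append(item["arguments"])
        return to_be_added, to_be_updated

    def add(self, data, filters):
        # data/filters drive the real prompts and search; unused in this sketch.
        to_be_added, to_be_updated = self._llm_update_existing_memory([], [])
        # Each item is already a tool call's `arguments` dict, so index directly.
        for item in to_be_updated:
            print("update:", item["source"], item["relationship"], item["destination"])
        for item in to_be_added:
            print("add:", item["source"], item["relationship"], item["destination"])


GraphMemorySketch().add("Alice likes tea and visited Kyoto", {"user_id": "alice"})
```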