diff --git a/biochatter/database_agent.py b/biochatter/database_agent.py index 23695524..3f79a8d8 100644 --- a/biochatter/database_agent.py +++ b/biochatter/database_agent.py @@ -1,5 +1,6 @@ from collections.abc import Callable import json +from typing import Dict, List, Optional from langchain.schema import Document import neo4j_utils as nu @@ -81,6 +82,44 @@ def _generate_query(self, query: str): results = self.driver.query(query=query) return query, results + def _build_response( + self, + results: List[Dict], + cypher_query: str, + results_num: Optional[int] = 3, + ) -> List[Document]: + if len(results) == 0: + return [ + Document( + page_content=( + "I didn't find any result in knowledge graph, " + f"but here is the query I used: {cypher_query}. " + "You can ask user to refine the question. " + "Note: please ensure to include the query in a code " + "block in your response so that the user can refine " + "their question effectively." + ), + metadata={"cypher_query": cypher_query}, + ) + ] + + clipped_results = results[:results_num] if results_num > 0 else results + results_dump = json.dumps(clipped_results) + + return [ + Document( + page_content=( + "The results retrieved from knowledge graph are: " + f"{results_dump}. " + f"The query used is: {cypher_query}. " + "Note: please ensure to include the query in a code block " + "in your response so that the user can refine " + "their question effectively." + ), + metadata={"cypher_query": cypher_query}, + ) + ] + def get_query_results(self, query: str, k: int = 3) -> list[Document]: """ Generate a query using the prompt engine and return the results. @@ -109,40 +148,14 @@ def get_query_results(self, query: str, k: int = 3) -> list[Document]: else: results = self.driver.query(query=cypher_query) - documents = [] # return first k results # returned nodes can have any formatting, and can also be empty or fewer # than k if results is None or len(results) == 0 or results[0] is None: return [] - if len(results[0]) == 0: - return [ - Document( - page_content = ( - "I didn't find any result in knowledge graph, " - f"but here is the query I used: {cypher_query}. " - "You can ask user to refine the question, " - "but don't make up anything." - ), - metadata={ - "cypher_query": cypher_query, - }, - ) - ] - - for result in results[0]: - documents.append( - Document( - page_content=json.dumps(result), - metadata={ - "cypher_query": cypher_query, - }, - ) - ) - if len(documents) == k: - break - - return documents + return self._build_response( + results=results[0], cypher_query=cypher_query, results_num=k + ) def get_description(self): result = self.driver.query("MATCH (n:Schema_info) RETURN n LIMIT 1") diff --git a/test/test_database_agent.py b/test/test_database_agent.py index 2936f0cd..6be6c60d 100644 --- a/test/test_database_agent.py +++ b/test/test_database_agent.py @@ -44,20 +44,14 @@ def test_get_query_results_with_reflexion(): result = db_agent.get_query_results("test_query", 3) # Check if the result is as expected - expected_result = [ - Document( - page_content='{"key": "value"}', - metadata={"cypher_query": "test_query"}, - ), - Document( - page_content='{"key": "value"}', - metadata={"cypher_query": "test_query"}, - ), - Document( - page_content='{"key": "value"}', - metadata={"cypher_query": "test_query"}, - ), - ] + expected_result = db_agent._build_response( + [ + {"key": "value"}, + {"key": "value"}, + {"key": "value"}, + ], + "test_query", + ) assert result == expected_result @@ -98,18 +92,12 @@ def test_get_query_results_without_reflexion(): result = db_agent.get_query_results("test_query", 3) # Check if the result is as expected - expected_result = [ - Document( - page_content='{"key": "value"}', - metadata={"cypher_query": "test_query"}, - ), - Document( - page_content='{"key": "value"}', - metadata={"cypher_query": "test_query"}, - ), - Document( - page_content='{"key": "value"}', - metadata={"cypher_query": "test_query"}, - ), - ] + expected_result = db_agent._build_response( + [ + {"key": "value"}, + {"key": "value"}, + {"key": "value"}, + ], + "test_query", + ) assert result == expected_result