diff --git a/autogen/agentchat/contrib/graph_rag/__init__.py b/autogen/agentchat/contrib/graph_rag/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/autogen/agentchat/contrib/graph_rag/document.py b/autogen/agentchat/contrib/graph_rag/document.py
new file mode 100644
index 00000000000..9730269c7ab
--- /dev/null
+++ b/autogen/agentchat/contrib/graph_rag/document.py
@@ -0,0 +1,28 @@
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import Optional
+
+
+class DocumentType(Enum):
+    """
+    Enum for supported document types.
+    """
+
+    TEXT = auto()
+    HTML = auto()
+    PDF = auto()
+
+
+@dataclass
+class Document:
+    """
+    A wrapper of an input document that is used to build the graph.
+
+    doctype: type of the document, one of DocumentType.
+    data: optional data of the document.
+    path_or_url: optional path or URL of the document; defaults to an empty string.
+    """
+
+    doctype: DocumentType
+    data: Optional[object] = None
+    path_or_url: Optional[str] = ""
diff --git a/autogen/agentchat/contrib/graph_rag/graph_query_engine.py b/autogen/agentchat/contrib/graph_rag/graph_query_engine.py
new file mode 100644
index 00000000000..28ef6ede84a
--- /dev/null
+++ b/autogen/agentchat/contrib/graph_rag/graph_query_engine.py
@@ -0,0 +1,51 @@
+from dataclasses import dataclass, field
+from typing import List, Optional, Protocol
+
+from .document import Document
+
+
+@dataclass
+class GraphStoreQueryResult:
+    """
+    A wrapper of graph store query results.
+
+    answer: human readable answer to question/query.
+    results: intermediate results to question/query, e.g. node entities.
+    """
+
+    answer: Optional[str] = None
+    results: list = field(default_factory=list)
+
+
+class GraphQueryEngine(Protocol):
+    """A protocol that represents a graph query engine on top of an underlying graph database.
+
+    This interface defines the basic methods for graph rag.
+    """
+
+    def init_db(self, input_doc: Optional[List[Document]] = None):
+        """
+        This method initializes graph database with the input documents or records.
+        Usually, it takes the following steps,
+        1. connect to a graph database.
+        2. extract graph nodes and edges based on input data, graph schema, etc.
+        3. build indexes etc.
+
+        Args:
+            input_doc: a list of input documents that are used to build the graph in database.
+
+        Returns: GraphStore
+        """
+        pass
+
+    def add_records(self, new_records: List) -> bool:
+        """
+        Add new records to the underlying database and add to the graph if required.
+        """
+        pass
+
+    def query(self, question: str, n_results: int = 1, **kwargs) -> GraphStoreQueryResult:
+        """
+        This method transforms a string format question into a database query and returns the result.
+        """
+        pass
diff --git a/autogen/agentchat/contrib/graph_rag/graph_rag_capability.py b/autogen/agentchat/contrib/graph_rag/graph_rag_capability.py
new file mode 100644
index 00000000000..b6412305e06
--- /dev/null
+++ b/autogen/agentchat/contrib/graph_rag/graph_rag_capability.py
@@ -0,0 +1,60 @@
+from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
+from autogen.agentchat.conversable_agent import ConversableAgent
+
+from .graph_query_engine import GraphQueryEngine
+
+
+class GraphRagCapability(AgentCapability):
+    """
+    A graph rag capability uses a graph query engine to give a conversable agent the graph rag ability.
+
+    An agent class with graph rag capability could
+    1. create a graph in the underlying database with input documents.
+    2. retrieve relevant information based on messages received by the agent.
+    3. generate answers from retrieved information and send messages back.
+
+    For example,
+    graph_query_engine = GraphQueryEngine(...)
+    graph_query_engine.init_db([Document(doc1), Document(doc2), ...])
+
+    graph_rag_agent = ConversableAgent(
+        name="graph_rag_agent",
+        max_consecutive_auto_reply=3,
+        ...
+    )
+    graph_rag_capability = GraphRagCapability(graph_query_engine)
+    graph_rag_capability.add_to_agent(graph_rag_agent)
+
+    user_proxy = UserProxyAgent(
+        name="user_proxy",
+        code_execution_config=False,
+        is_termination_msg=lambda msg: "TERMINATE" in msg["content"],
+        human_input_mode="ALWAYS",
+    )
+    user_proxy.initiate_chat(graph_rag_agent, message="Name a few actors who've played in 'The Matrix'")
+
+    # ChatResult(
+        # chat_id=None,
+        # chat_history=[
+            # {'content': 'Name a few actors who've played in \'The Matrix\'', 'role': 'graph_rag_agent'},
+            # {'content': 'A few actors who have played in The Matrix are:
+            #   - Keanu Reeves
+            #   - Laurence Fishburne
+            #   - Carrie-Anne Moss
+            #   - Hugo Weaving',
+            #   'role': 'user_proxy'},
+        # ...)
+
+    """
+
+    def __init__(self, query_engine: GraphQueryEngine):
+        """
+        Initialize graph rag capability with a graph query engine (placeholder, the engine is not stored yet).
+        """
+        ...
+
+    def add_to_agent(self, agent: ConversableAgent):
+        """
+        Add the graph rag capability to the given agent (placeholder, no behavior implemented yet).
+        """
+        ...
diff --git a/test/agentchat/contrib/graph_rag/test_graph_rag_basic.py b/test/agentchat/contrib/graph_rag/test_graph_rag_basic.py
new file mode 100644
index 00000000000..7c4a5094947
--- /dev/null
+++ b/test/agentchat/contrib/graph_rag/test_graph_rag_basic.py
@@ -0,0 +1,17 @@
+from unittest.mock import Mock
+
+from autogen.agentchat.contrib.graph_rag.graph_query_engine import GraphQueryEngine
+from autogen.agentchat.contrib.graph_rag.graph_rag_capability import GraphRagCapability
+from autogen.agentchat.conversable_agent import ConversableAgent
+
+
+def test_dry_run():
+    """Dry run for basic graph rag objects."""
+    mock_graph_query_engine = Mock(spec=GraphQueryEngine)
+
+    graph_rag_agent = ConversableAgent(
+        name="graph_rag_agent",
+        max_consecutive_auto_reply=3,
+    )
+    graph_rag_capability = GraphRagCapability(mock_graph_query_engine)
+    graph_rag_capability.add_to_agent(graph_rag_agent)