Skip to content

Commit

Permalink
Merge pull request #113 from fengsh27/main
Browse files Browse the repository at this point in the history
Inject conversation factory to RagAgent
  • Loading branch information
slobentanzer authored Feb 3, 2024
2 parents 3fa797e + 90d28cb commit c1e98e5
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 49 deletions.
10 changes: 9 additions & 1 deletion biochatter/database_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def __init__(
model_name: str,
connection_args: dict,
schema_config_or_info_dict: dict,
conversation_factory: callable,
) -> None:
"""
Create a DatabaseAgent analogous to the VectorDatabaseAgentMilvus class,
Expand All @@ -19,10 +20,14 @@ def __init__(
Args:
connection_args (dict): A dictionary of arguments to connect to the
database. Contains database name, URI, user, and password.
conversation_factory (callable): A function to create a conversation
for creating the KG query.
"""
self.prompt_engine = BioCypherPromptEngine(
model_name=model_name,
schema_config_or_info_dict=schema_config_or_info_dict,
conversation_factory=conversation_factory,
)
self.connection_args = connection_args
self.driver = None
Expand All @@ -41,6 +46,7 @@ def connect(self) -> None:
user=user,
password=password,
)

def is_connected(self) -> bool:
    """
    Report whether a database driver is currently set on this agent.

    Returns:
        bool: True if `self.driver` is not None, False otherwise.
    """
    # `is not` is the idiomatic identity test (PEP 8); it evaluates to the
    # same result as the former `not self.driver is None`.
    return self.driver is not None

Expand Down Expand Up @@ -69,7 +75,9 @@ def get_query_results(self, query: str, k: int = 3) -> list[Document]:
# return first k results
# returned nodes can have any formatting, and can also be empty or fewer
# than k
for result in results:
if results is None or len(results) == 0 or results[0] is None:
return []
for result in results[0]:
documents.append(
Document(
page_content=json.dumps(result),
Expand Down
41 changes: 28 additions & 13 deletions biochatter/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def __init__(
schema_config_or_info_path: Optional[str] = None,
schema_config_or_info_dict: Optional[dict] = None,
model_name: str = "gpt-3.5-turbo",
conversation_factory: Optional[callable] = None,
) -> None:
"""
Expand All @@ -32,8 +33,13 @@ def __init__(
generated by BioCypher's `write_schema_info` function
(preferred).
Todo:
inject conversation directly instead of specifying model name?
model_name: The name of the model to use for the conversation.
DEPRECATED: This should now be set in the conversation factory.
conversation_factory: A function used to create a conversation for
creating the KG query. If not provided, a default function is
used (creating an OpenAI conversation with the specified model,
see `_get_conversation`).
"""

if not schema_config_or_info_path and not schema_config_or_info_dict:
Expand All @@ -48,6 +54,13 @@ def __init__(
"path to a file or as a dictionary, not both."
)

# set conversation factory or use default
self.conversation_factory = (
conversation_factory
if conversation_factory is not None
else self._get_conversation
)

if schema_config_or_info_path:
# read the schema configuration
with open(schema_config_or_info_path, "r") as f:
Expand Down Expand Up @@ -79,13 +92,13 @@ def __init__(
value["represented_as"] == "node"
and name_indicates_relationship
):
self.relationships[
sentencecase_to_pascalcase(key)
] = value
self.relationships[sentencecase_to_pascalcase(key)] = (
value
)
elif value["represented_as"] == "edge":
self.relationships[
sentencecase_to_pascalcase(key)
] = value
self.relationships[sentencecase_to_pascalcase(key)] = (
value
)
else:
for key, value in schema_config.items():
if not isinstance(value, dict):
Expand Down Expand Up @@ -134,7 +147,9 @@ def _capitalise_source_and_target(self, relationship: dict) -> dict:
]
return relationship

def generate_query(self, question: str, query_language: str) -> str:
def generate_query(
self, question: str, query_language: Optional[str] = "Cypher"
) -> str:
"""
Wrap entity and property selection and query generation; return the
generated query.
Expand All @@ -149,23 +164,23 @@ def generate_query(self, question: str, query_language: str) -> str:
"""

success1 = self._select_entities(
question=question, conversation=self._get_conversation()
question=question, conversation=self.conversation_factory()
)
if not success1:
raise ValueError(
"Entity selection failed. Please try again with a different "
"question."
)
success2 = self._select_relationships(
conversation=self._get_conversation()
conversation=self.conversation_factory()
)
if not success2:
raise ValueError(
"Relationship selection failed. Please try again with a "
"different question."
)
success3 = self._select_properties(
conversation=self._get_conversation()
conversation=self.conversation_factory()
)
if not success3:
raise ValueError(
Expand All @@ -179,7 +194,7 @@ def generate_query(self, question: str, query_language: str) -> str:
relationships=self.selected_relationship_labels,
properties=self.selected_properties,
query_language=query_language,
conversation=self._get_conversation(),
conversation=self.conversation_factory(),
)

def _get_conversation(
Expand Down
31 changes: 19 additions & 12 deletions biochatter/rag_agent.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Optional, List


class RagAgentModeEnum:
    """
    String constants naming the modes a RagAgent can be created with.

    These are plain class attributes (not `enum.Enum` members), so they
    compare equal to the corresponding raw strings.
    """

    # mode in which results come from the vector store query function
    VectorStore = "vectorstore"
    # mode in which results come from the knowledge graph database agent
    KG = "kg"


class RagAgent:
def __init__(
self,
Expand All @@ -13,8 +15,9 @@ def __init__(
n_results: Optional[int] = 3,
use_prompt: Optional[bool] = False,
schema_config_or_info_dict: Optional[dict] = None,
conversation_factory: Optional[callable] = None,
embedding_func: Optional[object] = None,
documentids_workspace: Optional[List[str]]=None
documentids_workspace: Optional[List[str]] = None,
) -> None:
"""
Create a RAG agent that can return results from a database or vector
Expand All @@ -31,23 +34,22 @@ def __init__(
n_results: the number of results to return for method
generate_response
schema_config_or_info_dict (dict): A dictionary of schema
information for the database. Required if mode is "kg".
conversation_factory (callable): A function used to create a
conversation for creating the KG query. Required if mode is
"kg".
embedding_func (object): An embedding function. Required if mode is
"vectorstore".
embedding_collection_name (str): The name of the embedding
collection. Required if mode is "vectorstore".
documentids_workspace (Optional[List[str]], optional): a list of
document IDs that defines the scope within which similarity
search occurs. Defaults to None, which means the operations will
be performed across all documents in the database.
metadata_collection_name (str): The name of the metadata
collection. Required if mode is "vectorstore".
documentids_workspace (Optional[List[str]], optional): a list of document IDs
that defines the scope within which similarity search occurs. Defaults
to None, which means the operations will be performed across all
documents in the database.
"""
self.mode = mode
self.model_name = model_name
Expand All @@ -65,6 +67,7 @@ def __init__(
model_name=model_name,
connection_args=connection_args,
schema_config_or_info_dict=self.schema_config_or_info_dict,
conversation_factory=conversation_factory,
)

self.agent.connect()
Expand Down Expand Up @@ -115,7 +118,11 @@ def generate_responses(self, user_question: str) -> list[tuple]:
for result in results
]
elif self.mode == RagAgentModeEnum.VectorStore:
results = self.query_func(user_question, self.n_results, doc_ids=self.documentids_workspace)
results = self.query_func(
user_question,
self.n_results,
doc_ids=self.documentids_workspace,
)
return [
(
result.page_content,
Expand Down
36 changes: 15 additions & 21 deletions biochatter/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,57 +47,51 @@ def __init__(
Args:
used (bool, optional): whether RAG has been used (ChatGSE setting).
Defaults to False.
Defaults to False.
online (bool, optional): whether we are running ChatGSE online.
Defaults to False.
Defaults to False.
chunk_size (int, optional): size of chunks to split text into.
Defaults to 1000.
Defaults to 1000.
chunk_overlap (int, optional): overlap between chunks. Defaults to 0.
split_by_characters (bool, optional): whether to split by characters
or tokens. Defaults to True.
or tokens. Defaults to True.
separators (Optional[list], optional): list of separators to use when
splitting by characters. Defaults to [" ", ",", "\n"].
splitting by characters. Defaults to [" ", ",", "\n"].
n_results (int, optional): number of results to return from
similarity search. Defaults to 3.
similarity search. Defaults to 3.
model (Optional[str], optional): name of model to use for embeddings.
Defaults to 'text-embedding-ada-002'.
Defaults to 'text-embedding-ada-002'.
vector_db_vendor (Optional[str], optional): name of vector database
to use. Defaults to Milvus.
to use. Defaults to Milvus.
connection_args (Optional[dict], optional): arguments to pass to
vector database connection. Defaults to None.
embedding_collection_name (Optional[str], optional): name of
collection to store embeddings in. Defaults to 'DocumentEmbeddings'.
metadata_collection_name (Optional[str], optional): name of
collection to store metadata in. Defaults to 'DocumentMetadata'.
vector database connection. Defaults to None.
api_key (Optional[str], optional): OpenAI API key. Defaults to None.
base_url (Optional[str], optional): base url of OpenAI API.
embeddings (Optional[OpenAIEmbeddings | XinferenceEmbeddings],
optional): Embeddings object to use. Defaults to OpenAI.
optional): Embeddings object to use. Defaults to OpenAI.
documentids_workspace (Optional[List[str]], optional): a list of document IDs
that defines the scope within which rag operations (remove, similarity search,
and get all) occur. Defaults to None, which means the operations will be
performed across all documents in the database.
that defines the scope within which rag operations (remove, similarity search,
and get all) occur. Defaults to None, which means the operations will be
performed across all documents in the database.
is_azure (Optional[bool], optional): if we are using Azure
azure_deployment (Optional[str], optional): Azure embeddings model deployment,
should work with azure_endpoint when is_azure is True
should work with azure_endpoint when is_azure is True
azure_endpoint (Optional[str], optional): Azure endpoint, should work with
azure_deployment when is_azure is True
azure_deployment when is_azure is True
"""
self.used = used
Expand Down
5 changes: 3 additions & 2 deletions test/test_database_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def test_get_query_results():
"password": "password",
},
{"schema_config": "test_schema"},
None,
)
db_agent.connect() # Call the connect method to initialize the driver

Expand All @@ -24,12 +25,12 @@ def test_get_query_results():
with mock.patch.object(
db_agent.driver,
"query",
return_value=[
return_value=[[
{"key": "value"},
{"key": "value"},
{"key": "value"},
{"key": "value"},
],
], {}],
):
result = db_agent.get_query_results("test_query", 3)

Expand Down

0 comments on commit c1e98e5

Please sign in to comment.