diff --git a/src/vanna/pgvector/__init__.py b/src/vanna/pgvector/__init__.py
index 1633653c..dd152a30 100644
--- a/src/vanna/pgvector/__init__.py
+++ b/src/vanna/pgvector/__init__.py
@@ -1,2 +1 @@
 from .pgvector import PG_VectorStore
-from .pgvecto_rs import PG_Vecto_rsStore
diff --git a/src/vanna/pgvector/pgvecto_rs.py b/src/vanna/pgvector/pgvecto_rs.py
deleted file mode 100644
index 0970e553..00000000
--- a/src/vanna/pgvector/pgvecto_rs.py
+++ /dev/null
@@ -1,269 +0,0 @@
-import ast
-import json
-import logging
-import uuid
-
-import pandas as pd
-from langchain_core.documents import Document
-from langchain_community.vectorstores.pgvecto_rs import PGVecto_rs
-from sqlalchemy import create_engine, text
-
-from .. import ValidationError
-from ..base import VannaBase
-from ..types import TrainingPlan, TrainingPlanItem
-from ..utils import deterministic_uuid
-
-
-class PG_Vecto_rsStore(VannaBase):
-    def __init__(self, config=None):
-        if not config or "connection_string" not in config:
-            raise ValueError(
-                "A valid 'config' dictionary with a 'connection_string' is required.")
-
-        VannaBase.__init__(self, config=config)
-
-        if config and "connection_string" in config:
-            self.connection_string = config.get("connection_string")
-            self.n_results = config.get("n_results", 10)
-
-        if config and "embedding_function" in config:
-            self.embedding_function = config.get("embedding_function")
-            self.vector_dimension = config.get("vector_dimension")
-        else:
-            from langchain_huggingface import HuggingFaceEmbeddings
-            self.embedding_function = HuggingFaceEmbeddings(
-                model_name="all-MiniLM-L6-v2")
-            self.vector_dimension = 384
-
-        self.sql_collection = PGVecto_rs(
-            embedding=self.embedding_function,
-            collection_name="sql",
-            db_url=self.connection_string,
-            dimension=self.vector_dimension,
-        )
-        self.ddl_collection = PGVecto_rs(
-            embedding=self.embedding_function,
-            collection_name="ddl",
-            db_url=self.connection_string,
-            dimension=self.vector_dimension,
-        )
-        self.documentation_collection = PGVecto_rs(
-            embedding=self.embedding_function,
-            collection_name="documentation",
-            db_url=self.connection_string,
-            dimension=self.vector_dimension,
-        )
-
-    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
-        question_sql_json = json.dumps(
-            {
-                "question": question,
-                "sql": sql,
-            },
-            ensure_ascii=False,
-        )
-        id = deterministic_uuid(question_sql_json) + "-sql"
-        createdat = kwargs.get("createdat")
-        doc = Document(
-            page_content=question_sql_json,
-            metadata={"id": id, "createdat": createdat},
-        )
-        self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]])
-
-        return id
-
-    def add_ddl(self, ddl: str, **kwargs) -> str:
-        _id = deterministic_uuid(ddl) + "-ddl"
-        doc = Document(
-            page_content=ddl,
-            metadata={"id": _id},
-        )
-        self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]])
-        return _id
-
-    def add_documentation(self, documentation: str, **kwargs) -> str:
-        _id = deterministic_uuid(documentation) + "-doc"
-        doc = Document(
-            page_content=documentation,
-            metadata={"id": _id},
-        )
-        self.documentation_collection.add_documents([doc],
-                                                    ids=[doc.metadata["id"]])
-        return _id
-
-    def get_collection(self, collection_name):
-        match collection_name:
-            case "sql":
-                return self.sql_collection
-            case "ddl":
-                return self.ddl_collection
-            case "documentation":
-                return self.documentation_collection
-            case _:
-                raise ValueError("Specified collection does not exist.")
-
-    def get_similar_question_sql(self, question: str, **kwargs) -> list:
-        documents = self.sql_collection.similarity_search(query=question,
-                                                          k=self.n_results)
-        return [ast.literal_eval(document.page_content) for document in documents]
-
-    def get_related_ddl(self, question: str, **kwargs) -> list:
-        documents = self.ddl_collection.similarity_search(query=question,
-                                                          k=self.n_results)
-        return [document.page_content for document in documents]
-
-    def get_related_documentation(self, question: str, **kwargs) -> list:
-        documents = self.documentation_collection.similarity_search(query=question,
-                                                                    k=self.n_results)
-        return [document.page_content for document in documents]
-
-    def train(
-        self,
-        question: str | None = None,
-        sql: str | None = None,
-        ddl: str | None = None,
-        documentation: str | None = None,
-        plan: TrainingPlan | None = None,
-        createdat: str | None = None,
-    ):
-        if question and not sql:
-            raise ValidationError("Please provide a SQL query.")
-
-        if documentation:
-            logging.info(f"Adding documentation: {documentation}")
-            return self.add_documentation(documentation)
-
-        if sql and question:
-            return self.add_question_sql(question=question, sql=sql,
-                                         createdat=createdat)
-
-        if ddl:
-            logging.info(f"Adding ddl: {ddl}")
-            return self.add_ddl(ddl)
-
-        if plan:
-            for item in plan._plan:
-                if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
-                    self.add_ddl(item.item_value)
-                elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
-                    self.add_documentation(item.item_value)
-                elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
-                    self.add_question_sql(question=item.item_name, sql=item.item_value)
-
-    def get_training_data(self, **kwargs) -> pd.DataFrame:
-        # Establishing the connection
-        engine = create_engine(self.connection_string)
-
-        # Querying the 'langchain_pg_embedding' table
-        query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
-        df_embedding = pd.read_sql(query_embedding, engine)
-
-        # List to accumulate the processed rows
-        processed_rows = []
-
-        # Process each row in the DataFrame
-        for _, row in df_embedding.iterrows():
-            custom_id = row["cmetadata"]["id"]
-            document = row["document"]
-            training_data_type = "documentation" if custom_id[
-                -3:] == "doc" else custom_id[-3:]
-
-            if training_data_type == "sql":
-                # Convert the document string to a dictionary
-                try:
-                    doc_dict = ast.literal_eval(document)
-                    question = doc_dict.get("question")
-                    content = doc_dict.get("sql")
-                except (ValueError, SyntaxError):
-                    logging.info(
-                        f"Skipping row with custom_id {custom_id} due to parsing error.")
-                    continue
-            elif training_data_type in ["documentation", "ddl"]:
-                question = None  # Default value for question
-                content = document
-            else:
-                # If the suffix is not recognized, skip this row
-                logging.info(
-                    f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
-                continue
-
-            # Append the processed data to the list
-            processed_rows.append(
-                {"id": custom_id, "question": question, "content": content,
-                 "training_data_type": training_data_type}
-            )
-
-        # Create a DataFrame from the list of processed rows
-        df_processed = pd.DataFrame(processed_rows)
-
-        return df_processed
-
-    def remove_training_data(self, id: str, **kwargs) -> bool:
-        # Create the database engine
-        engine = create_engine(self.connection_string)
-
-        # SQL DELETE statement
-        delete_statement = text(
-            """
-            DELETE FROM langchain_pg_embedding
-            WHERE cmetadata ->> 'id' = :id
-            """
-        )
-
-        # Connect to the database and execute the delete statement
-        with engine.connect() as connection:
-            # Start a transaction
-            with connection.begin() as transaction:
-                try:
-                    result = connection.execute(delete_statement, {"id": id})
-                    # Commit the transaction if the delete was successful
-                    transaction.commit()
-                    # Check if any row was deleted and return True or False accordingly
-                    return result.rowcount() > 0
-                except Exception as e:
-                    # Rollback the transaction in case of error
-                    logging.error(f"An error occurred: {e}")
-                    transaction.rollback()
-                    return False
-
-    def remove_collection(self, collection_name: str) -> bool:
-        engine = create_engine(self.connection_string)
-
-        # Determine the suffix to look for based on the collection name
-        suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"}
-        suffix = suffix_map.get(collection_name)
-
-        if not suffix:
-            logging.info(
-                "Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'.")
-            return False
-
-        # SQL query to delete rows based on the condition
-        query = text(
-            f"""
-            DELETE FROM langchain_pg_embedding
-            WHERE cmetadata->>'id' LIKE '%{suffix}'
-            """
-        )
-
-        # Execute the deletion within a transaction block
-        with engine.connect() as connection:
-            with connection.begin() as transaction:
-                try:
-                    result = connection.execute(query)
-                    transaction.commit()  # Explicitly commit the transaction
-                    if result.rowcount() > 0:
-                        logging.info(
-                            f"Deleted {result.rowcount()} rows from "
-                            f"langchain_pg_embedding where collection is {collection_name}."
-                        )
-                        return True
-                    else:
-                        logging.info(f"No rows deleted for collection {collection_name}.")
-                        return False
-                except Exception as e:
-                    logging.error(f"An error occurred: {e}")
-                    transaction.rollback()  # Rollback in case of error
-                    return False
-
-    def generate_embedding(self, *args, **kwargs):
-        pass
diff --git a/src/vanna/pgvector/pgvector.py b/src/vanna/pgvector/pgvector.py
index ede27497..3cddeb46 100644
--- a/src/vanna/pgvector/pgvector.py
+++ b/src/vanna/pgvector/pgvector.py
@@ -11,7 +11,6 @@
 from .. import ValidationError
 from ..base import VannaBase
 from ..types import TrainingPlan, TrainingPlanItem
-from ..utils import deterministic_uuid
 
 
 class PG_VectorStore(VannaBase):
@@ -56,7 +55,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
             },
             ensure_ascii=False,
         )
-        id = deterministic_uuid(question_sql_json) + "-sql"
+        id = str(uuid.uuid4()) + "-sql"
         createdat = kwargs.get("createdat")
         doc = Document(
             page_content=question_sql_json,
@@ -67,7 +66,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
         return id
 
     def add_ddl(self, ddl: str, **kwargs) -> str:
-        _id = deterministic_uuid(ddl) + "-ddl"
+        _id = str(uuid.uuid4()) + "-ddl"
        doc = Document(
             page_content=ddl,
             metadata={"id": _id},
@@ -76,7 +75,7 @@ def add_ddl(self, ddl: str, **kwargs) -> str:
         return _id
 
     def add_documentation(self, documentation: str, **kwargs) -> str:
-        _id = deterministic_uuid(documentation) + "-doc"
+        _id = str(uuid.uuid4()) + "-doc"
         doc = Document(
             page_content=documentation,
             metadata={"id": _id},
@@ -95,7 +94,7 @@ def get_collection(self, collection_name):
             case _:
                 raise ValueError("Specified collection does not exist.")
 
-    def get_similar_question_sql(self, question: str, **kwargs) -> list:
+    def get_similar_question_sql(self, question: str) -> list:
         documents = self.sql_collection.similarity_search(query=question,
                                                           k=self.n_results)
         return [ast.literal_eval(document.page_content) for document in documents]
@@ -204,7 +203,7 @@ def remove_training_data(self, id: str, **kwargs) -> bool:
                     # Commit the transaction if the delete was successful
                     transaction.commit()
                     # Check if any row was deleted and return True or False accordingly
-                    return result.rowcount() > 0
+                    return result.rowcount > 0
                 except Exception as e:
                     # Rollback the transaction in case of error
                     logging.error(f"An error occurred: {e}")
@@ -236,9 +235,9 @@ def remove_collection(self, collection_name: str) -> bool:
                 try:
                     result = connection.execute(query)
                     transaction.commit()  # Explicitly commit the transaction
-                    if result.rowcount() > 0:
+                    if result.rowcount > 0:
                         logging.info(
-                            f"Deleted {result.rowcount()} rows from "
+                            f"Deleted {result.rowcount} rows from "
                             f"langchain_pg_embedding where collection is {collection_name}."
                         )
                         return True
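Review note on the ID change: replacing deterministic_uuid(...) with str(uuid.uuid4()) is a behavioral change, not just a cleanup. Previously the ID was derived from the content itself, so re-adding identical training data produced the same ID and add_documents(..., ids=[...]) replaced the existing row; now every call mints a fresh ID, so identical content can be stored more than once. (The hunks also rely on uuid already being imported at the top of pgvector.py; the first hunk's context begins mid-imports, so earlier lines are not shown.) A minimal sketch of the contrast, where content_id is a hypothetical stand-in for vanna's deterministic_uuid, not its actual implementation:

    import hashlib
    import uuid

    def content_id(content: str) -> str:
        # Hypothetical stand-in for deterministic_uuid: same input, same ID.
        return str(uuid.UUID(hashlib.md5(content.encode("utf-8")).hexdigest()))

    ddl = "CREATE TABLE users (id INT)"
    assert content_id(ddl) == content_id(ddl)      # re-training dedupes in place
    assert str(uuid.uuid4()) != str(uuid.uuid4())  # every call creates a new row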
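Review note on the signature change: dropping **kwargs from get_similar_question_sql makes this override stricter than the sibling retrieval methods, so any caller that forwards extra keyword arguments will now hit a TypeError. A minimal sketch of the hazard (class names here are hypothetical, not vanna's):

    class Base:
        def get_similar_question_sql(self, question: str, **kwargs) -> list:
            raise NotImplementedError

    class Store(Base):
        # Override without **kwargs: narrower than the base signature.
        def get_similar_question_sql(self, question: str) -> list:
            return []

    try:
        Store().get_similar_question_sql("q", n_results=5)
    except TypeError as e:
        print(e)  # unexpected keyword argument 'n_results'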
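Review note on the rowcount fix: in SQLAlchemy, CursorResult.rowcount is a plain integer attribute, not a method, so the old result.rowcount() > 0 raised TypeError: 'int' object is not callable whenever the delete path actually executed. A minimal demonstration, assuming an in-memory SQLite database purely for illustration:

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite:///:memory:")
    with engine.connect() as connection:
        connection.execute(text("CREATE TABLE t (id INTEGER)"))
        connection.execute(text("INSERT INTO t VALUES (1), (2)"))
        result = connection.execute(text("DELETE FROM t WHERE id = 1"))
        # rowcount is an int attribute; result.rowcount() would raise TypeError.
        print(result.rowcount)  # 1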