Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Add database engine and table name to support table DDL update #463

Merged
merged 10 commits into from
Nov 21, 2024
86 changes: 82 additions & 4 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
import sqlparse

from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
from ..types import TrainingPlan, TrainingPlanItem
from ..types import TrainingPlan, TrainingPlanItem, TableMetadata
from ..utils import validate_config_path


Expand Down Expand Up @@ -210,6 +210,54 @@ def extract_sql(self, llm_response: str) -> str:

return llm_response

@staticmethod
def extract_table_metadata(ddl: str) -> TableMetadata:
    """
    Example:
    ```python
    vn.extract_table_metadata("CREATE TABLE hive.bi_ads.customers (id INT, name TEXT, sales DECIMAL)")
    ```

    Extracts the table metadata (catalog, schema and table name) from a DDL
    statement. This is useful in case the DDL statement contains other
    information besides the table metadata.
    Override this function if your DDL statements need custom extraction logic.

    NOTE(review): declared as a @staticmethod — the original signature had no
    `self`, so calling it on an instance (as the example shows) would have
    passed the instance as `ddl`. The decorator keeps both instance calls and
    the existing `VannaBase.extract_table_metadata(ddl)` call sites working.

    Args:
        ddl (str): The DDL statement.

    Returns:
        TableMetadata: The extracted table metadata. All fields are None when
        the statement does not match any recognized CREATE TABLE form.
    """
    # Try the most-qualified form first: catalog.schema.table, then
    # schema.table, then a bare table name.
    qualified_name_patterns = [
        re.compile(r'CREATE TABLE\s+(\w+)\.(\w+)\.(\w+)\s*\(', re.IGNORECASE),
        re.compile(r'CREATE TABLE\s+(\w+)\.(\w+)\s*\(', re.IGNORECASE),
        re.compile(r'CREATE TABLE\s+(\w+)\s*\(', re.IGNORECASE),
    ]

    for pattern in qualified_name_patterns:
        match = pattern.search(ddl)
        if match:
            parts = match.groups()
            # Right-align the captured parts so the table name is always the
            # last argument: (catalog, schema, table), (None, schema, table)
            # or (None, None, table).
            padded = (None,) * (3 - len(parts)) + parts
            return TableMetadata(*padded)

    return TableMetadata()

def is_sql_valid(self, sql: str) -> bool:
"""
Example:
Expand Down Expand Up @@ -368,6 +416,31 @@ def get_related_ddl(self, question: str, **kwargs) -> list:
"""
pass

@abstractmethod
def search_tables_metadata(self,
                           engine: str = None,
                           catalog: str = None,
                           schema: str = None,
                           table_name: str = None,
                           ddl: str = None,
                           size: int = 10,
                           **kwargs) -> list:
    """
    Look up metadata for tables similar to the given criteria.

    Every filter argument is optional; concrete vector stores decide how
    the supplied ones are combined when matching and ranking tables.

    Args:
        engine (str): The database engine to filter on.
        catalog (str): The catalog to filter on.
        schema (str): The schema to filter on.
        table_name (str): The table name to filter on.
        ddl (str): A DDL statement to match against.
        size (int): The maximum number of tables to return.

    Returns:
        list: A list of tables metadata.
    """
    pass

@abstractmethod
def get_related_documentation(self, question: str, **kwargs) -> list:
"""
Expand Down Expand Up @@ -396,12 +469,13 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
pass

@abstractmethod
def add_ddl(self, ddl: str, **kwargs) -> str:
def add_ddl(self, ddl: str, engine: str = None, **kwargs) -> str:
"""
This method is used to add a DDL statement to the training data.

Args:
ddl (str): The DDL statement to add.
engine (str): The database engine that the DDL statement applies to.

Returns:
str: The ID of the training data that was added.
Expand Down Expand Up @@ -1738,6 +1812,7 @@ def train(
question: str = None,
sql: str = None,
ddl: str = None,
engine: str = None,
documentation: str = None,
plan: TrainingPlan = None,
) -> str:
Expand All @@ -1758,8 +1833,11 @@ def train(
question (str): The question to train on.
sql (str): The SQL query to train on.
ddl (str): The DDL statement.
engine (str): The database engine.
documentation (str): The documentation to train on.
plan (TrainingPlan): The training plan to train on.
Returns:
str: The ID of the training data that was added.
"""

if question and not sql:
Expand All @@ -1777,12 +1855,12 @@ def train(

if ddl:
print("Adding ddl:", ddl)
return self.add_ddl(ddl)
return self.add_ddl(ddl=ddl, engine=engine)

if plan:
for item in plan._plan:
if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
self.add_ddl(item.item_value)
self.add_ddl(ddl=item.item_value, engine=engine)
elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
self.add_documentation(item.item_value)
elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL:
Expand Down
93 changes: 87 additions & 6 deletions src/vanna/opensearch/opensearch_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

import pandas as pd
from opensearchpy import OpenSearch
from ..types import TableMetadata

from ..base import VannaBase
from ..utils import deterministic_uuid


class OpenSearch_VectorStore(VannaBase):
Expand Down Expand Up @@ -56,6 +58,18 @@ def __init__(self, config=None):
},
"mappings": {
"properties": {
"engine": {
"type": "keyword",
},
"catalog": {
"type": "keyword",
},
"schema": {
"type": "keyword",
},
"table_name": {
"type": "keyword",
},
"ddl": {
"type": "text",
},
Expand Down Expand Up @@ -92,6 +106,8 @@ def __init__(self, config=None):
if config is not None and "es_question_sql_index_settings" in config:
question_sql_index_settings = config["es_question_sql_index_settings"]

self.n_results = config.get("n_results", 10)

self.document_index_settings = document_index_settings
self.ddl_index_settings = ddl_index_settings
self.question_sql_index_settings = question_sql_index_settings
Expand Down Expand Up @@ -231,10 +247,29 @@ def create_index_if_not_exists(self, index_name: str,
print(f"Error creating index: {index_name} ", e)
return False

def add_ddl(self, ddl: str, **kwargs) -> str:
def calculate_md5(self, string: str) -> str:
    """
    Return the MD5 digest of *string* as a 32-character hex string.

    Args:
        string (str): The text to hash.

    Returns:
        str: The hexadecimal MD5 digest.
    """
    # Bug fix: hash the argument, not the instance — the original called
    # self.encode('utf-8'), which raises AttributeError.
    string_bytes = string.encode('utf-8')
    # Compute the MD5 hash of the encoded bytes.
    md5_hash = hashlib.md5(string_bytes)
    # Return the digest in hexadecimal form.
    return md5_hash.hexdigest()

def add_ddl(self, ddl: str, engine: str = None,
**kwargs) -> str:
# Assuming that you have a DDL index in your OpenSearch
id = str(uuid.uuid4()) + "-ddl"
table_metadata = VannaBase.extract_table_metadata(ddl)
full_table_name = table_metadata.get_full_table_name()
if full_table_name is not None and engine is not None:
id = deterministic_uuid(engine + "-" + full_table_name) + "-ddl"
else:
id = str(uuid.uuid4()) + "-ddl"
ddl_dict = {
"engine": engine,
"catalog": table_metadata.catalog,
"schema": table_metadata.schema,
"table_name": table_metadata.table_name,
"ddl": ddl
}
response = self.client.index(index=self.ddl_index, body=ddl_dict, id=id,
Expand Down Expand Up @@ -270,7 +305,8 @@ def get_related_ddl(self, question: str, **kwargs) -> List[str]:
"match": {
"ddl": question
}
}
},
"size": self.n_results
}
print(query)
response = self.client.search(index=self.ddl_index, body=query,
Expand All @@ -283,7 +319,8 @@ def get_related_documentation(self, question: str, **kwargs) -> List[str]:
"match": {
"doc": question
}
}
},
"size": self.n_results
}
print(query)
response = self.client.search(index=self.document_index,
Expand All @@ -297,7 +334,8 @@ def get_similar_question_sql(self, question: str, **kwargs) -> List[str]:
"match": {
"question": question
}
}
},
"size": self.n_results
}
print(query)
response = self.client.search(index=self.question_sql_index,
Expand All @@ -306,6 +344,50 @@ def get_similar_question_sql(self, question: str, **kwargs) -> List[str]:
return [(hit['_source']['question'], hit['_source']['sql']) for hit in
response['hits']['hits']]

def search_tables_metadata(self,
                           engine: str = None,
                           catalog: str = None,
                           schema: str = None,
                           table_name: str = None,
                           ddl: str = None,
                           size: int = 10,
                           **kwargs) -> list:
    """
    Search the DDL index for table metadata matching the given filters.

    Each non-None filter becomes an OpenSearch `should` clause, so documents
    matching any of the provided fields are returned. With no filters at all,
    a `match_all` query returns everything (up to `size`).

    Args:
        engine (str): The database engine to filter on.
        catalog (str): The catalog to filter on.
        schema (str): The schema to filter on.
        table_name (str): The table name to filter on.
        ddl (str): A DDL statement to match against.
        size (int): The maximum number of tables to return (ignored if <= 0).

    Returns:
        list: The `_source` payload of each matching document.
    """
    # Field order here mirrors the parameter order, so the generated query
    # clauses come out in the same order as before.
    filters = {
        "engine": engine,
        "catalog": catalog,
        "schema": schema,
        "table_name": table_name,
        "ddl": ddl,
    }
    should_clauses = [
        {"match": {field: value}}
        for field, value in filters.items()
        if value is not None
    ]

    if should_clauses:
        query = {"query": {"bool": {"should": should_clauses}}}
    else:
        query = {"query": {"match_all": {}}}

    if size > 0:
        query["size"] = size

    print(query)
    response = self.client.search(index=self.ddl_index, body=query, **kwargs)
    return [hit['_source'] for hit in response['hits']['hits']]

def get_training_data(self, **kwargs) -> pd.DataFrame:
# This will be a simple example pulling all data from an index
# WARNING: Do not use this approach in production for large indices!
Expand All @@ -315,7 +397,6 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:
body={"query": {"match_all": {}}},
size=1000
)
print(query)
# records = [hit['_source'] for hit in response['hits']['hits']]
for hit in response['hits']['hits']:
data.append(
Expand Down
26 changes: 26 additions & 0 deletions src/vanna/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,29 @@ def remove_item(self, item: str):
if str(plan_item) == item:
self._plan.remove(plan_item)
break


class TableMetadata:
    """Holds the qualified-name components (catalog.schema.table) of a table."""

    def __init__(self, catalog=None, schema=None, table_name=None):
        # Each component may be None when the source DDL did not specify it.
        self.catalog = catalog
        self.schema = schema
        self.table_name = table_name

    def __str__(self):
        """Return a human-readable summary of the known components."""
        parts = []
        if self.catalog:
            parts.append(f"Catalog: {self.catalog}")
        if self.schema:
            parts.append(f"Schema: {self.schema}")
        if self.table_name:
            parts.append(f"Table: {self.table_name}")
        return "\n".join(parts) if parts else "No match found"

    def get_full_table_name(self):
        """
        Return the dotted fully-qualified table name, or None when the table
        name is unknown.

        Bug fix: the original returned the literal string "None" for an empty
        TableMetadata (f-string of None), which defeated callers'
        `is not None` checks and made every unparseable DDL hash to the same
        deterministic id.
        """
        if self.table_name is None:
            return None
        if self.catalog and self.schema:
            return f"{self.catalog}.{self.schema}.{self.table_name}"
        if self.schema:
            return f"{self.schema}.{self.table_name}"
        return f"{self.table_name}"