-
Notifications
You must be signed in to change notification settings - Fork 7
/
multi_database.py
132 lines (98 loc) · 4.67 KB
/
multi_database.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from typing import Callable, Dict, Iterable, List, Optional
from llama_hub.tools.database.base import DatabaseToolSpec
from llama_index import Document
from llama_index.readers.base import BaseReader
from llama_index.tools.tool_spec.base import BaseToolSpec
from sqlalchemy import text
from sqlalchemy.exc import InvalidRequestError
class NoSuchDatabaseError(InvalidRequestError):
    """Signals that a requested database does not exist or is not visible.

    Derives from SQLAlchemy's ``InvalidRequestError``, so handlers written for
    that exception also catch this one.
    """
class TrackingDatabaseToolSpec(DatabaseToolSpec):
    """A ``DatabaseToolSpec`` that reports each ``load_data`` query to a callback.

    After a query succeeds, the optional ``handler`` is called as
    ``handler(database_name, query, rows)``, where ``rows`` is the list returned
    by ``fetchall()``. The handler and the database name are assigned after
    construction via :meth:`set_handler` and :meth:`set_database_name`.
    """

    # Callback invoked after each successful load_data query; None disables it.
    # Defaults are class-level so load_data works even if the setters were never called.
    handler: Optional[Callable[[str, str, Iterable], None]] = None
    # Name passed to the handler as its first argument.
    database_name: str = ""

    def set_handler(self, func: Optional[Callable[[str, str, Iterable], None]]) -> None:
        """Set the query-tracking callback (``None`` disables tracking)."""
        self.handler = func

    def set_database_name(self, database_name: str) -> None:
        """Set the name reported to the handler as its first argument."""
        self.database_name = database_name

    def load_data(self, query: str) -> List[Document]:
        """Query and load data from the Database, returning a list of Documents.

        Args:
            query (str): an SQL query to filter tables and rows.

        Returns:
            List[Document]: One Document per result row; its text is the row's
            values converted with str() and joined by ", ".

        Raises:
            ValueError: if ``query`` is None or empty.
        """
        # Validate before opening a connection so a bad call never touches the DB.
        if not query:
            raise ValueError("A query parameter is necessary to filter the data")

        with self.sql_database.engine.connect() as connection:
            result = connection.execute(text(query))
            items = result.fetchall()
            if self.handler is not None:
                self.handler(self.database_name, query, items)

        # Rows are fully fetched above, so they can be rendered after the connection closes.
        return [Document(text=", ".join(str(entry) for entry in item)) for item in items]
class MultiDatabaseToolSpec(BaseToolSpec, BaseReader):
    """Tool spec that routes SQL tool calls to one of several named databases.

    Each database is backed by a ``TrackingDatabaseToolSpec``. The optional
    ``handler`` is installed on every registered spec together with the name
    it is registered under.
    """

    # Registered databases, keyed by the name callers pass as ``database``.
    database_specs: Dict[str, TrackingDatabaseToolSpec]
    # Callback shared by every registered spec; None disables tracking.
    handler: Optional[Callable[[str, str, Iterable], None]]
    # Methods exposed as tools; their docstrings serve as the tool descriptions.
    spec_functions = ["load_data", "describe_tables", "list_tables", "list_databases"]

    def __init__(
        self,
        database_toolspec_mapping: Optional[Dict[str, TrackingDatabaseToolSpec]] = None,
        handler: Optional[Callable[[str, str, Iterable], None]] = None,
    ) -> None:
        # Copy so later add_* calls do not mutate the caller's mapping.
        self.database_specs = dict(database_toolspec_mapping or {})
        self.handler = handler
        for database_name, spec in self.database_specs.items():
            # Set the name as well as the handler. Previously only the handler
            # was set, so the spec's load_data failed when reporting its name.
            self._attach(database_name, spec)

    def _attach(self, database_name: str, spec: TrackingDatabaseToolSpec) -> None:
        """Install the shared handler and ``database_name`` on ``spec``."""
        spec.set_handler(self.handler)
        spec.set_database_name(database_name)

    def _get_spec(self, database: str) -> TrackingDatabaseToolSpec:
        """Return the spec registered as ``database``.

        Raises:
            NoSuchDatabaseError: if no database with that name is registered.
        """
        try:
            return self.database_specs[database]
        except KeyError:
            raise NoSuchDatabaseError(f"Database '{database}' does not exist.") from None

    def add_connection(self, database_name: str, uri: str) -> None:
        """Create a tracked spec connected to ``uri`` and register it as ``database_name``."""
        spec = TrackingDatabaseToolSpec(uri=uri)
        self._attach(database_name, spec)
        self.database_specs[database_name] = spec

    def add_database_tool_spec(self, database_name: str, tool_spec: TrackingDatabaseToolSpec) -> None:
        """Register an existing ``tool_spec`` as ``database_name``."""
        self._attach(database_name, tool_spec)
        self.database_specs[database_name] = tool_spec

    def load_data(self, database: str, query: str) -> List[Document]:
        """Query and load data from the given Database, returning a list of Documents.

        Args:
            database (str): A database name to query and load data from
            query (str): an SQL query to filter tables and rows.

        Returns:
            List[Document]: A list of Document objects.
        """
        return self._get_spec(database).load_data(query)

    def describe_tables(self, database: str, tables: Optional[List[str]] = None) -> str:
        """
        Describes the specified tables in the given database

        Args:
            database (str): A database name to retrieve the table details from
            tables (List[str]): A list of table names to retrieve details about
        """
        return self._get_spec(database).describe_tables(tables)

    def list_tables(self, database: str) -> List[str]:
        """
        Returns a list of available tables in the database.
        To retrieve details about the columns of specific tables, use
        the describe_tables endpoint

        Args:
            database (str): A database name to retrieve the list of tables from
        """
        return self._get_spec(database).list_tables()

    def list_databases(self) -> List[str]:
        """
        Returns a list of available databases.
        To retrieve details about the tables of a specific database, use
        the list_tables endpoint
        """
        return list(self.database_specs)