-
Notifications
You must be signed in to change notification settings - Fork 6
/
sql_assistant.py
73 lines (63 loc) · 2.28 KB
/
sql_assistant.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from function import Function, Property
from dotenv import load_dotenv
from assistant import AIAssistant
from sqlalchemy import create_engine, inspect
from sqlalchemy import text
from sqlalchemy.orm import sessionmaker
load_dotenv()
class GetDBSchema(Function):
def __init__(self):
super().__init__(
name="get_db_schema",
description="Get the schema of the database",
)
def function(self):
engine = create_engine("sqlite:///assistants_files/Chinook.sqlite")
inspector = inspect(engine)
table_names = inspector.get_table_names()
schema_statements = []
for table_name in table_names:
table = inspector.get_columns(table_name)
create_statement = f"CREATE TABLE {table_name} (\n"
create_statement += ",\n".join(
f"{col['name']} {col['type']}" for col in table
)
create_statement += "\n);"
schema_statements.append(create_statement)
return "\n\n".join(schema_statements)
class RunSQLQuery(Function):
def __init__(self):
super().__init__(
name="run_sql_query",
description="Run a SQL query on the database",
parameters=[
Property(
name="query",
description="The SQL query to run",
type="string",
required=True,
),
],
)
def function(self, query):
engine = create_engine("sqlite:///assistants_files/Chinook.sqlite")
Session = sessionmaker(bind=engine)
session = Session()
try:
results = session.execute(text(query)).fetchall()
return '\n'.join([str(result) for result in results])
except Exception as e:
return str(e)
finally:
session.close()
if __name__ == "__main__":
assistant = AIAssistant(
instruction="""
You are a SQL expert. User asks you questions about the Chinook database.
First obtain the schema of the database to check the tables and columns, then generate SQL queries to answer the questions.
""",
model="gpt-3.5-turbo-1106",
functions=[GetDBSchema(), RunSQLQuery()],
use_code_interpreter=True,
)
assistant.chat()