Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wip/dependencies #40

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ Flask==2.3.2
Flask-Cors==4.0.0
Flask-RESTful==0.3.10
requests==2.31.0
tiktoken==0.4.0
psycopg2-binary==2.9.7 # you can also install from source if it works
tiktoken>=0.4.0
pglast==5.3
litellm==1.34.34
platformdirs>=4.0.0
litellm>=1.34.34
platformdirs>=4.0.0
39 changes: 17 additions & 22 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,21 @@

# Define your dependencies
install_requires = [
'Jinja2==3.1.2',
'Flask==2.3.2',
'Flask-Cors==4.0.0',
'Flask-RESTful==0.3.10',
'requests==2.31.0',
'tiktoken==0.4.0',
'psycopg2-binary==2.9.7',
'pglast==5.3',
'litellm==1.34.34',
'platformdirs>=4.0.0',
'sqlparse~=0.5.0'
"Jinja2==3.1.2",
"Flask==2.3.2",
"Flask-Cors==4.0.0",
"Flask-RESTful==0.3.10",
"requests==2.31.0",
"tiktoken>=0.4.0",
"pglast>=6.10",
"litellm>=1.34.34",
"platformdirs>=4.0.0",
"sqlparse~=0.5.0",
]

install_dev_requires = [
'spacy==3.6.0',
'FlagEmbedding~=1.2.5',
"spacy==3.6.0",
"FlagEmbedding~=1.2.5",
]

# Additional package information
Expand All @@ -46,19 +45,15 @@
name=name,
version=version,
description=description,
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
author=author,
author_email=author_email,
packages=packages,
package_dir={"": "src"},
install_requires=install_requires,
extra_requires={
"dev": install_dev_requires
},
extra_requires={"dev": install_dev_requires},
url=url,
classifiers=classifiers,
package_data={
"": ["*.prompt"]
}
)
package_data={"": ["*.prompt"]},
)
33 changes: 23 additions & 10 deletions src/suql/free_text_fcns_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,21 @@
def _answer(
source,
query,
type_prompt = None,
type_prompt=None,
k=5,
max_input_token=10000,
engine="gpt-3.5-turbo-0125"
engine="gpt-3.5-turbo-0125",
api_base=None,
api_version=None,
):
from suql.prompt_continuation import llm_generate

if not source:
return {"result": "no information"}

text_res = []
if isinstance(source, list):
documents = compute_top_similarity_documents(
source, query, top=k
)
documents = compute_top_similarity_documents(source, query, top=k)
for i in documents:
if num_tokens_from_string("\n".join(text_res + [i])) < max_input_token:
text_res.append(i)
Expand Down Expand Up @@ -63,11 +64,20 @@ def _answer(
temperature=0.0,
stop_tokens=[],
postprocess=False,
api_base=api_base,
api_version=api_version,
)
return {"result": continuation}


def start_free_text_fncs_server(
host="127.0.0.1", port=8500, k=5, max_input_token=3800, engine="gpt-4o-mini"
host="127.0.0.1",
port=8500,
k=5,
max_input_token=3800,
engine="gpt-4o-mini",
api_base=None,
api_version=None,
):
"""
Set up a free text functions server for the free text
Expand Down Expand Up @@ -115,11 +125,12 @@ def answer():
data["text"],
data["question"],
type_prompt=data["type_prompt"] if "type_prompt" in data else None,
k = k,
max_input_token = max_input_token,
engine = engine
k=k,
max_input_token=max_input_token,
engine=engine,
api_base=api_base,
api_version=api_version,
)


@app.route("/summary", methods=["POST"])
def summary():
Expand Down Expand Up @@ -166,6 +177,8 @@ def summary():
temperature=0.0,
stop_tokens=["\n"],
postprocess=False,
api_base=api_base,
api_version=api_version,
)

res = {"result": continuation}
Expand Down
58 changes: 34 additions & 24 deletions src/suql/postgresql_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ def execute_sql(
data=None,
commit_in_lieu_fetch=False,
no_print=False,
unprotected=False
unprotected=False,
host="127.0.0.1",
port="5432",
):
start_time = time.time()

Expand All @@ -21,16 +23,16 @@ def execute_sql(
dbname=database,
user=user,
host="/var/run/postgresql",
port="5432",
port=port,
options="-c statement_timeout=30000 -c client_encoding=UTF8",
)
else:
conn = psycopg2.connect(
database=database,
user=user,
password=password,
host="127.0.0.1",
port="5432",
host=host,
port=port,
options="-c statement_timeout=30000 -c client_encoding=UTF8",
)

Expand All @@ -57,7 +59,7 @@ def sql_unprotected():
else:
results = cursor.fetchall()
column_names = [desc[0] for desc in cursor.description]

return results, column_names

try:
Expand Down Expand Up @@ -85,14 +87,16 @@ def execute_sql_with_column_info(
user="select_user",
password="select_user",
unprotected=False,
host="127.0.0.1",
port="5432",
):
# Establish a connection to the PostgreSQL database
conn = psycopg2.connect(
database=database,
user=user,
password=password,
host="127.0.0.1",
port="5432",
host=host,
port=port,
options="-c statement_timeout=30000 -c client_encoding=UTF8",
)

Expand Down Expand Up @@ -125,7 +129,7 @@ def sql_unprotected():

column_types = [type_map[oid] for oid in column_type_oids]
column_info = list(zip(column_names, column_types))

return results, column_info

try:
Expand All @@ -141,12 +145,15 @@ def sql_unprotected():
conn.close()
return list(results), column_info


def split_sql_statements(query):
def strip_trailing_comments(stmt):
idx = len(stmt.tokens) - 1
while idx >= 0:
tok = stmt.tokens[idx]
if tok.is_whitespace or sqlparse.utils.imt(tok, i=sqlparse.sql.Comment, t=sqlparse.tokens.Comment):
if tok.is_whitespace or sqlparse.utils.imt(
tok, i=sqlparse.sql.Comment, t=sqlparse.tokens.Comment
):
stmt.tokens[idx] = sqlparse.sql.Token(sqlparse.tokens.Whitespace, " ")
else:
break
Expand All @@ -159,8 +166,13 @@ def strip_trailing_semicolon(stmt):
tok = stmt.tokens[idx]
# we expect that trailing comments already are removed
if not tok.is_whitespace:
if sqlparse.utils.imt(tok, t=sqlparse.tokens.Punctuation) and tok.value == ";":
stmt.tokens[idx] = sqlparse.sql.Token(sqlparse.tokens.Whitespace, " ")
if (
sqlparse.utils.imt(tok, t=sqlparse.tokens.Punctuation)
and tok.value == ";"
):
stmt.tokens[idx] = sqlparse.sql.Token(
sqlparse.tokens.Whitespace, " "
)
break
idx -= 1
return stmt
Expand All @@ -187,15 +199,16 @@ def is_empty_statement(stmt):

return [""] # if all statements were empty - return a single empty statement


def query_is_select_no_limit(query):
limit_keywords = ["LIMIT", "OFFSET"]

def find_last_keyword_idx(parsed_query):
for i in reversed(range(len(parsed_query.tokens))):
if parsed_query.tokens[i].ttype in sqlparse.tokens.Keyword:
return i
return -1

parsed_query = sqlparse.parse(query)[0]
last_keyword_idx = find_last_keyword_idx(parsed_query)
# Either invalid query or query that is not select
Expand All @@ -206,10 +219,8 @@ def find_last_keyword_idx(parsed_query):

return no_limit

def add_limit_to_query(
query,
limit_query = " LIMIT 1000"
):

def add_limit_to_query(query, limit_query=" LIMIT 1000"):
parsed_query = sqlparse.parse(query)[0]
limit_tokens = sqlparse.parse(limit_query)[0].tokens
length = len(parsed_query.tokens)
Expand All @@ -220,22 +231,21 @@ def add_limit_to_query(

return str(parsed_query)

def apply_auto_limit(
query_text,
limit_query = " LIMIT 1000"
):

def apply_auto_limit(query_text, limit_query=" LIMIT 1000"):
def combine_sql_statements(queries):
return ";\n".join(queries)

queries = split_sql_statements(query_text)
res = []
for query in queries:
if query_is_select_no_limit(query):
query = add_limit_to_query(query, limit_query=limit_query)
res.append(query)

return combine_sql_statements(res)


if __name__ == "__main__":
print(apply_auto_limit("SELECT * FROM restaurants LIMIT 1;"))
print(apply_auto_limit("SELECT * FROM restaurants;"))
print(apply_auto_limit("SELECT * FROM restaurants;"))
22 changes: 16 additions & 6 deletions src/suql/prompt_continuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,18 @@
"""

import logging
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import List

import os
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from threading import Thread
from typing import List

from jinja2 import Environment, FileSystemLoader, select_autoescape

from suql.utils import num_tokens_from_string
from litellm import completion, completion_cost

from suql.utils import num_tokens_from_string

logger = logging.getLogger(__name__)
# create file handler which logs even debug messages
Expand All @@ -36,11 +33,14 @@
ENABLE_CACHING = False
if ENABLE_CACHING:
import pymongo

mongo_client = pymongo.MongoClient("localhost", 27017)
prompt_cache_db = mongo_client["open_ai_prompts"]["caches"]


total_cost = 0 # in USD


def get_total_cost():
global total_cost
return total_cost
Expand Down Expand Up @@ -75,6 +75,8 @@ def _generate(
postprocess,
max_tries,
ban_line_break_start,
api_base=None,
api_version=None,
):
# don't try multiple times if the temperature is 0, because the results will be the same
if max_tries > 1 and temperature == 0:
Expand All @@ -96,6 +98,8 @@ def _generate(
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"stop": stop_tokens,
"api_base": api_base,
"api_version": api_version,
}

generation_output = chat_completion_with_backoff(**kwargs)
Expand Down Expand Up @@ -198,6 +202,8 @@ def llm_generate(
filled_prompt=None,
attempts=2,
max_wait_time=None,
api_base=None,
api_version=None,
):
"""
filled_prompt gives direct access to the underlying model, without having to load a prompt template from a .prompt file. Used for testing.
Expand Down Expand Up @@ -247,6 +253,8 @@ def llm_generate(
postprocess,
max_tries,
ban_line_break_start,
api_base,
api_version,
)
if success:
final_result = result
Expand All @@ -265,6 +273,8 @@ def llm_generate(
postprocess,
max_tries,
ban_line_break_start,
api_version,
api_base,
)

end_time = time.time()
Expand Down
Loading