Format bug fix #12

Open · wants to merge 2 commits into main

154 changes: 154 additions & 0 deletions docs/docs-ch/Table cars_data has columns such as id, weight.md
@@ -0,0 +1,154 @@
Table cars_data has columns such as id, weight.\nid is the primary key.\nTable model_list has columns such as maker, model.

Table model_list has columns such as model, maker.\nTable cars_data has columns such as id, weight.\nid is the primary key.
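
The two serializations above describe the same two-table schema but differ in table and column order; they appear to be alternative prompt orderings for the order-sensitivity experiments referenced by the file paths below.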

```sh
python evaluation.py \
    --gold /Users/jizha/code/python/spider/dataset/spider/dev_gold.sql \
    --pred /Users/jizha/code/python/test-suite-sql-eval/二轮测试_gpt4_choice.json \
    --etype all \
    --db /Users/jizha/code/python/spider/dataset/spider/database \
    --table tables.json
```
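This is the test-suite-sql-eval scoring step: assuming the standard evaluation.py interface from that repository, `--etype all` reports both exact-set-match and execution accuracy of the predictions against the gold queries.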

```sh
python generate_question.py \
--data_type spider \
--split test \
--tokenizer gpt-3.5-turbo \
--max_seq_len 4096 \
--selector_type EUCDISMASKPRESKLSIMTHR \
--pre_test_result /Users/jizha/code/python/test-suite-sql-eval/随机列测试/union_test_20231201_random_table.sql \
--prompt_repr SQL \
--k_shot 9 \
--example_type QA

```
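
generate_question.py serializes the k-shot prompts into questions.json, which the prediction script below reads via QUESTION_FILE. A minimal sketch of the expected shape, inferred from how the script indexes the file (the db_id and prompt values are only illustrative; real files may carry extra fields):

```json
{
  "questions": [
    {
      "db_id": "concert_singer",
      "prompt": "/* schema serialization + 9 QA examples + target question */"
    }
  ]
}
```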

```python
import argparse
import os
import json

import openai
from tqdm import tqdm

from llm.chatgpt import init_chatgpt, ask_llm
from utils.enums import LLM
from torch.utils.data import DataLoader

from utils.post_process import process_duplication, get_sqls
import concurrent.futures

QUESTION_FILE = "questions.json"


def gen_predict_sql(batch_no, args, batch):
    """Ask the LLM for one batch of prompts; returns (batch_no, sqls, token_cnt)."""
    try:
        res = ask_llm(args.model, batch, args.temperature, args.n)
    except openai.error.InvalidRequestError:
        print(f"The {batch_no}-th batch has too many tokens! Return \"SELECT\" instead")
        return batch_no, ["SELECT"] * len(batch), 0
    results = []
    if args.n == 1:
        for sql in res["response"]:
            # remove \n and extra spaces
            sql = " ".join(sql.replace("\n", " ").split())
            sql = process_duplication(sql)
            # make sure every prediction starts with SELECT
            if sql.startswith("SELECT"):
                results.append(sql)
            elif sql.startswith(" "):
                results.append("SELECT" + sql)
            else:
                results.append("SELECT " + sql)
    else:
        # self-consistency: choose one SQL per question from the n samples
        cur_db_ids = db_ids[batch_no * args.batch_size: batch_no * args.batch_size + len(batch)]
        for sqls, db_id in zip(res["response"], cur_db_ids):
            processed_sqls = []
            for sql in sqls:
                sql = " ".join(sql.replace("\n", " ").split())
                sql = process_duplication(sql)
                if sql.startswith("SELECT"):
                    pass
                elif sql.startswith(" "):
                    sql = "SELECT" + sql
                else:
                    sql = "SELECT " + sql
                processed_sqls.append(sql)
            result = {
                'db_id': db_id,
                'p_sqls': processed_sqls
            }
            # accumulate, don't overwrite, the chosen SQL for each question
            results.extend(get_sqls([result], args.n, args.db_dir))
    return batch_no, results, res["total_tokens"]


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--question", type=str)
    parser.add_argument("--openai_api_key", type=str, default="eab38a33cc07467aae9b7d09783b75a8")
    parser.add_argument("--openai_group_id", type=str, default="luli.wjc")
    parser.add_argument("--openai_api_base", type=str,
                        default="https://codegencore.antgroup-inc.cn/api/chat/commonPower/v1")
    parser.add_argument("--model", type=str, choices=[LLM.TEXT_DAVINCI_003,
                                                      LLM.GPT_35_TURBO,
                                                      LLM.GPT_35_TURBO_0613,
                                                      LLM.TONG_YI_QIAN_WEN,
                                                      LLM.GPT_35_TURBO_16K,
                                                      LLM.GPT_4],
                        default=LLM.GPT_35_TURBO)
    parser.add_argument("--start_index", type=int, default=0)
    parser.add_argument("--end_index", type=int, default=1000000)
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--mini_index_path", type=str, default="")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--n", type=int, default=1, help="Size of self-consistent set")
    parser.add_argument("--db_dir", type=str, default="dataset/spider/database")
    args = parser.parse_args()

    # check args
    assert args.model in LLM.BATCH_FORWARD or \
           args.model not in LLM.BATCH_FORWARD and args.batch_size == 1, \
        f"{args.model} doesn't support batch_size > 1"

    questions_json = json.load(open(os.path.join(args.question, QUESTION_FILE), "r"))
    questions = [_["prompt"] for _ in questions_json["questions"]]
    db_ids = [_["db_id"] for _ in questions_json["questions"]]

    # init openai api
    init_chatgpt(args.openai_api_key, args.openai_group_id, args.openai_api_base, args.model)

    if args.start_index == 0:
        mode = "w"
    else:
        mode = "a"

    if args.mini_index_path:
        mini_index = json.load(open(args.mini_index_path, 'r'))
        questions = [questions[i] for i in mini_index]
        out_file = f"{args.question}/RESULTS_MODEL-{args.model}_MINI.txt"
    else:
        out_file = f"{args.question}/RESULTS_MODEL-{args.model}.txt"

    question_loader = DataLoader(questions, batch_size=args.batch_size, shuffle=False, drop_last=False)

    token_cnt = 0
    results = []
    # DataLoader is not subscriptable, so materialize the batches up front
    batches = list(question_loader)
    with open(out_file, mode) as f:
        # process batches in chunks of 10, one worker thread per batch
        for i in tqdm(range(0, len(batches), 10)):
            up = min(i + 10, len(batches))
            result_temp = [None] * (up - i)
            future_list = []
            with concurrent.futures.ThreadPoolExecutor() as executor:
                for batch_no, item in enumerate(batches[i:up], start=i):
                    future_list.append(executor.submit(gen_predict_sql, batch_no, args, item))
                for future in concurrent.futures.as_completed(future_list):
                    batch_no, p_sqls, used_tokens = future.result()
                    token_cnt += used_tokens
                    result_temp[batch_no - i] = p_sqls
            # write back in original order, one SQL per line
            for item in result_temp:
                f.write("\n".join(item) + "\n")
                results.extend(item)


```
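
A possible invocation of the script above (the script filename and output directory are assumptions, not from this PR; `--question` must point at the directory containing questions.json):

```sh
# Hypothetical example: predictions land in RESULTS_MODEL-<model>.txt
# inside the --question directory. The model string assumes
# LLM.GPT_35_TURBO == "gpt-3.5-turbo".
python ask_llm.py \
    --question ./outputs/spider_test \
    --model gpt-3.5-turbo \
    --batch_size 1 \
    --n 1
```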

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

 [project]
 name = "sqlgpt-parser"
-version = "0.0.1a5"
+version = "0.0.1a7"
 authors = [
     { name="luliwjc", email="chenxiaoxi_wjc@163.com" },
     { name="Ifffff", email="tingkai.ztk@antgroup.com" },
@@ -35,7 +35,7 @@ dependencies = [
 line-length=120

 [tool.black]
-skip-string-normalization = 1
+skip-string-normalization = false
 force-exclude = '''
     sqlgpt_parser/parser/mysql_parser/parser_table.py
     | sqlgpt_parser/parser/oceanbase_parser/parser_table.py
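
This is the actual format fix: with skip-string-normalization previously enabled (1), Black left quote styles alone; setting it to false lets Black normalize single-quoted strings to double quotes, which is what produces the mechanical quote churn in lexer.py below (e.g. `t_LPAREN = r'\('` becomes `t_LPAREN = r"\("` with the regex itself unchanged).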
151 changes: 78 additions & 73 deletions sqlgpt_parser/parser/mysql_parser/lexer.py
@@ -21,43 +21,43 @@

 tokens = (
     [
-        'IDENTIFIER',
-        'DIGIT_IDENTIFIER',
-        'QUOTED_IDENTIFIER',
-        'BACKQUOTED_IDENTIFIER',
-        'PERIOD',
-        'COMMA',
-        'PLUS',
-        'MINUS',
-        'LPAREN',
-        'RPAREN',
-        'ANDAND',
-        'ASSIGNMENTEQ',
-        'GT',
-        'GE',
-        'LT',
-        'LE',
-        'EQ',
-        'NE',
-        'NULL_SAFE_EQ',
-        'BIT_OR',
-        'BIT_AND',
-        'BIT_XOR',
-        'BIT_OPPOSITE',
-        'EXCLA_MARK',
-        'BIT_MOVE_LEFT',
-        'BIT_MOVE_RIGHT',
-        'PIPES',
-        'SLASH',
-        'ASTERISK',
-        'PERCENT',
-        'NUMBER',
-        'FRACTION',
-        'QM',
-        'SCONST',
-        'SINGLE_AT_IDENTIFIER',
-        'DOUBLE_AT_IDENTIFIER',
-        'HEX_NUMBER',
+        "IDENTIFIER",
+        "DIGIT_IDENTIFIER",
+        "QUOTED_IDENTIFIER",
+        "BACKQUOTED_IDENTIFIER",
+        "PERIOD",
+        "COMMA",
+        "PLUS",
+        "MINUS",
+        "LPAREN",
+        "RPAREN",
+        "ANDAND",
+        "ASSIGNMENTEQ",
+        "GT",
+        "GE",
+        "LT",
+        "LE",
+        "EQ",
+        "NE",
+        "NULL_SAFE_EQ",
+        "BIT_OR",
+        "BIT_AND",
+        "BIT_XOR",
+        "BIT_OPPOSITE",
+        "EXCLA_MARK",
+        "BIT_MOVE_LEFT",
+        "BIT_MOVE_RIGHT",
+        "PIPES",
+        "SLASH",
+        "ASTERISK",
+        "PERCENT",
+        "NUMBER",
+        "FRACTION",
+        "QM",
+        "SCONST",
+        "SINGLE_AT_IDENTIFIER",
+        "DOUBLE_AT_IDENTIFIER",
+        "HEX_NUMBER",
     ]
     + list(reversed)
     + list(nonreserved)
@@ -66,48 +66,48 @@

 sql_tokens = list(reversed) + list(nonreserved) + list(not_keyword_token)

-t_LPAREN = r'\('
-t_RPAREN = r'\)'
-
-t_ASSIGNMENTEQ = r':='
-t_EQ = r'='
-t_NE = r'<>|!='
-t_LT = r'<'
-t_LE = r'<='
-t_GT = r'>'
-t_GE = r'>='
-t_NULL_SAFE_EQ = r'<=>'
-t_PERIOD = r'\.'
-t_COMMA = r','
-t_PLUS = r'\+'
-t_MINUS = r'-'
-t_ASTERISK = r'\*'
-t_SLASH = r'/'
-t_PERCENT = r'%'
-t_QM = r'\?'
+t_LPAREN = r"\("
+t_RPAREN = r"\)"
+
+t_ASSIGNMENTEQ = r":="
+t_EQ = r"="
+t_NE = r"<>|!="
+t_LT = r"<"
+t_LE = r"<="
+t_GT = r">"
+t_GE = r">="
+t_NULL_SAFE_EQ = r"<=>"
+t_PERIOD = r"\."
+t_COMMA = r","
+t_PLUS = r"\+"
+t_MINUS = r"-"
+t_ASTERISK = r"\*"
+t_SLASH = r"/"
+t_PERCENT = r"%"
+t_QM = r"\?"

 # TODO
 # By default, || is a logical OR operator.
 # With PIPES_AS_CONCAT enabled, || is string concatenation.
 # Need support or semantics in future development
-t_PIPES = r'\|\|'
+t_PIPES = r"\|\|"

-t_ignore = ' \t'
+t_ignore = " \t"

-t_ANDAND = r'\&\&'
-t_BIT_OR = r'\|'
-t_BIT_AND = r'\&'
-t_BIT_XOR = r'\^'
-t_BIT_OPPOSITE = r'\~'
-t_BIT_MOVE_LEFT = r'<<'
-t_BIT_MOVE_RIGHT = r'>>'
-t_EXCLA_MARK = r'!'
+t_ANDAND = r"\&\&"
+t_BIT_OR = r"\|"
+t_BIT_AND = r"\&"
+t_BIT_XOR = r"\^"
+t_BIT_OPPOSITE = r"\~"
+t_BIT_MOVE_LEFT = r"<<"
+t_BIT_MOVE_RIGHT = r">>"
+t_EXCLA_MARK = r"!"


 def t_DOUBLE(t):
     r"[0-9]*\.[0-9]+([eE][-+]?[0-9]+)?|[-+]?[0-9]+([eE][-+]?[0-9]+)"
-    if 'e' in t.value or 'E' in t.value or '.' in t.value:
-        t.type = 'FRACTION'
+    if "e" in t.value or "E" in t.value or "." in t.value:
+        t.type = "FRACTION"
     else:
         t.type = "NUMBER"
     return t
@@ -129,7 +129,7 @@ def t_NUMBER_START_WITH_XB(t):
 def t_IDENTIFIER(t):
     r"""[a-zA-Z\u4e00-\u9fa50-9_$][a-zA-Z\u4e00-\u9fa50-9_@:$]*"""
     if re.match(
-        r'(^0[xX][0-9a-fA-F]+$)|(^0[bB][01]+$)|(^\d+$)',
+        r"(^0[xX][0-9a-fA-F]+$)|(^0[bB][01]+$)|(^\d+$)",
         t.value,
     ):
         t.type = "NUMBER"
@@ -155,21 +155,21 @@ def t_DOUBLE_AT_IDENTIFIER(t):


 def t_QUOTED_IDENTIFIER(t):
-    r'"(\\["\\]|[^"]|["]{2})*"'
+    r""" "(\\["\\]|[^"]|["]{2})*\" """
     t.type = "QUOTED_IDENTIFIER"
     return t


 def t_BACKQUOTED_IDENTIFIER(t):
-    r'`([^`]|``)*`'
+    r"""`([^`]|``)*`"""
     val = t.value.lower()
     if val in tokens:
         t.type = tokens[val]
     return t


 def t_newline(t):
-    r'[\r\n]+'
+    r"""[\r\n]+"""
     t.lexer.lineno += t.value.count("\n")


@@ -179,7 +179,12 @@ def t_error(t):


 def t_COMMENT(t):
-    r'(\/\*\*\/)|(/\*((?!\/\*).)+\*/)'
+    r"""(\/\*\*\/)|(/\*((?!\/\*).)+\*/)"""
     pass
+
+
+def t_SEMICOLON(t):
+    r""";"""
+    pass