-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add execution accuracy in spider example (#140)
- Loading branch information
1 parent
d26cb56
commit 2bc5465
Showing
11 changed files
with
226 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
--- | ||
"@empiricalrun/scorer": patch | ||
"@empiricalrun/cli": patch | ||
"@empiricalrun/ai": patch | ||
"web": patch | ||
--- | ||
|
||
fix: minor improvements for execution accuracy in spider example |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,37 @@ | ||
## Example for Spider | ||
# Text-to-SQL | ||
|
||
``` | ||
npx @empiricalrun/cli run | ||
``` | ||
LLMs are good at converting natural language questions to SQL queries. This examples uses that | ||
scenario to demo Empirical. This example is based on the [Spider](https://github.com/taoyds/spider) dataset. | ||
|
||
In this example, we generate SQL queries, and score them on | ||
|
||
1. SQL syntax (with the `sql-syntax` scorer): Checks if the output syntax is valid SQL. For example, if the output is in | ||
markdown syntax (with backticks), it is not a valid SQL query. | ||
2. Execution accuracy (with the `py-script` scorer): We run the generated SQL query against a test database, and check | ||
if the query returns a result. This scorer cleans query outputs that have backticks | ||
([see code](./execution_accuracy.py)). | ||
|
||
This example requires Python. | ||
|
||
## Usage | ||
|
||
1. Run the prepare script to set up the example database with dummy data. | ||
```sh | ||
python prepare.py | ||
``` | ||
|
||
1. Review the `empiricalrc.json` configuration, and make changes if any. The current configuration runs models | ||
from OpenAI, Anthropic and Google, and thus, requires [relevant environment variables](https://docs.empirical.run/models/basic). | ||
```sh | ||
cat empiricalrc.json | ||
``` | ||
|
||
1. Run with Empirical | ||
```sh | ||
npx @empiricalrun/cli run | ||
``` | ||
|
||
1. See results on the Empirical web reporter | ||
```sh | ||
npx @empiricalrun/cli ui | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import re | ||
import sqlite3 | ||
|
||
|
||
def remove_backticks(text): | ||
"""If we find backticks, return the code snippet from within them""" | ||
pruned_text = text.replace("```sql", "```") | ||
found = re.findall(r"```(.*?)```", pruned_text, re.DOTALL) | ||
if len(found): | ||
return found[0] | ||
else: | ||
return pruned_text | ||
|
||
|
||
def evaluate(output, inputs): | ||
con = sqlite3.connect("concert_singer.db") | ||
cur = con.cursor() | ||
try: | ||
res = cur.execute(remove_backticks(output["value"])) | ||
first_row = res.fetchone() | ||
if first_row: | ||
passed = 1 | ||
message = "Result preview: " + ", ".join([str(x) for x in first_row]) | ||
else: | ||
passed = 0.5 | ||
message = "No results found" | ||
except Exception as e: | ||
passed, message = 0, repr(e) | ||
|
||
return [{"score": passed, "message": message, "name": "exec-accuracy"}] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import os | ||
import sqlite3 | ||
|
||
|
||
def create_and_load_database(): | ||
sqlite_file_name = "concert_singer.db" | ||
if os.path.isfile(sqlite_file_name): | ||
os.remove(sqlite_file_name) | ||
con = sqlite3.connect(sqlite_file_name) | ||
cur = con.cursor() | ||
# First create tables | ||
cur.executescript(schema) | ||
# Then load data | ||
cur.executescript(data) | ||
con.close() | ||
|
||
|
||
schema = """CREATE TABLE stadium ( | ||
stadium_id NUMERIC PRIMARY KEY, | ||
location TEXT, | ||
name TEXT, | ||
capacity NUMERIC, | ||
highest NUMERIC, | ||
lowest NUMERIC, | ||
average NUMERIC | ||
); | ||
CREATE TABLE singer ( | ||
singer_id NUMERIC PRIMARY KEY, | ||
name TEXT, | ||
country TEXT, | ||
song_name TEXT, | ||
song_release_year TEXT, | ||
age NUMERIC, | ||
is_male TIMESTAMP | ||
); | ||
CREATE TABLE concert ( | ||
concert_id NUMERIC PRIMARY KEY, | ||
concert_name TEXT, | ||
theme TEXT, | ||
stadium_id TEXT, | ||
year TEXT, | ||
FOREIGN KEY (stadium_id) REFERENCES stadium(stadium_id) | ||
); | ||
CREATE TABLE singer_in_concert ( | ||
concert_id NUMERIC PRIMARY KEY, | ||
singer_id TEXT, | ||
FOREIGN KEY (singer_id) REFERENCES singer(singer_id), | ||
FOREIGN KEY (concert_id) REFERENCES concert(concert_id) | ||
);""" | ||
|
||
|
||
data = """-- INSERT INTO stadium | ||
INSERT INTO stadium (stadium_id, location, name, capacity, highest, lowest, average) VALUES | ||
(1, 'New York, USA', 'Madison Square Garden', 20789, 85, 70, 78), | ||
(2, 'London, UK', 'Wembley Stadium', 90000, 92, 65, 80), | ||
(3, 'Sydney, Australia', 'Sydney Opera House', 5738, 88, 68, 75), | ||
(4, 'Paris, France', 'Stade de France', 81338, 90, 72, 82), | ||
(5, 'Tokyo, Japan', 'Tokyo Dome', 55000, 87, 70, 78); | ||
-- INSERT INTO singer | ||
INSERT INTO singer (singer_id, name, country, song_name, song_release_year, age, is_male) VALUES | ||
(1, 'Taylor Swift', 'USA', 'Shake It Off', '2014', 33, 0), | ||
(2, 'Ed Sheeran', 'UK', 'Shape of You', '2017', 32, 1), | ||
(3, 'Adele', 'UK', 'Hello', '2015', 34, 0), | ||
(4, 'BTS', 'South Korea', 'Butter', '2021', 29, 0), | ||
(5, 'Drake', 'Canada', "God's Plan", '2018', 36, 1); | ||
-- INSERT INTO concert | ||
INSERT INTO concert (concert_id, concert_name, theme, stadium_id, year) VALUES | ||
(1, 'Reputation Tour', 'Pop', '1', '2018'), | ||
(2, 'Divide Tour', 'Pop', '2', '2019'), | ||
(3, '25 Tour', 'Pop', '3', '2016'), | ||
(4, 'Love Yourself Tour', 'K-Pop', '4', '2018'), | ||
(5, 'Scorpion Tour', 'Hip-Hop', '5', '2019'); | ||
-- INSERT INTO singer_in_concert | ||
INSERT INTO singer_in_concert (concert_id, singer_id) VALUES | ||
(1, '1'), | ||
(2, '2'), | ||
(3, '3'), | ||
(4, '4'), | ||
(5, '5');""" | ||
|
||
|
||
if __name__ == "__main__": | ||
create_and_load_database() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters