Skip to content

Commit

Permalink
feat: add execution accuracy in spider example (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
arjunattam authored Apr 16, 2024
1 parent d26cb56 commit 2bc5465
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 24 deletions.
8 changes: 8 additions & 0 deletions .changeset/wicked-laws-rhyme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
"@empiricalrun/scorer": patch
"@empiricalrun/cli": patch
"@empiricalrun/ai": patch
"web": patch
---

fix: minor improvements for execution accuracy in spider example
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,4 @@ packages/**/empiricalrc.json
.empiricalrun
arxiv-papers
examples/humaneval/HumanEval.jsonl
examples/spider/concert_singer.db
7 changes: 0 additions & 7 deletions apps/web/components/scores.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { Score } from "@empiricalrun/types";
import ScoreBadge from "./ui/score-badge";
import { Separator } from "./ui/separator";
import { DoubleArrowDownIcon, DoubleArrowUpIcon } from "@radix-ui/react-icons";
import { useScoresView } from "../hooks/useScoresView";
import { Button } from "./ui/button";
Expand Down Expand Up @@ -28,12 +27,6 @@ export function Scores({ scores }: { scores: Score[] }) {
</p>
)}
</div>
{expandState && (
<Separator
className="w-full self-center"
orientation="horizontal"
/>
)}
</>
);
})}
Expand Down
40 changes: 36 additions & 4 deletions examples/spider/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,37 @@
## Example for Spider
# Text-to-SQL

```
npx @empiricalrun/cli run
```
LLMs are good at converting natural language questions to SQL queries. This examples uses that
scenario to demo Empirical. This example is based on the [Spider](https://github.com/taoyds/spider) dataset.

In this example, we generate SQL queries, and score them on

1. SQL syntax (with the `sql-syntax` scorer): Checks if the output syntax is valid SQL. For example, if the output is in
markdown syntax (with backticks), it is not a valid SQL query.
2. Execution accuracy (with the `py-script` scorer): We run the generated SQL query against a test database, and check
if the query returns a result. This scorer cleans query outputs that have backticks
([see code](./execution_accuracy.py)).

This example requires Python.

## Usage

1. Run the prepare script to set up the example database with dummy data.
```sh
python prepare.py
```

1. Review the `empiricalrc.json` configuration, and make changes if any. The current configuration runs models
from OpenAI, Anthropic and Google, and thus, requires [relevant environment variables](https://docs.empirical.run/models/basic).
```sh
cat empiricalrc.json
```

1. Run with Empirical
```sh
npx @empiricalrun/cli run
```

1. See results on the Empirical web reporter
```sh
npx @empiricalrun/cli ui
```
23 changes: 20 additions & 3 deletions examples/spider/empiricalrc.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
"type": "sql-syntax"
},
{
"type": "sql-semantic"
"type": "py-script",
"path": "execution_accuracy.py"
}
]
},
Expand All @@ -25,12 +26,28 @@
"type": "sql-syntax"
},
{
"type": "sql-semantic"
"type": "py-script",
"path": "execution_accuracy.py"
}
]
},
{
"type": "model",
"provider": "google",
"model": "gemini-1.5-pro-latest",
"prompt": "You are an SQLite expert who can convert natural language questions to SQL queries for the database schema given below.\n\nDatabase schema:\n{{schema}}\n\nAnswer the following question with only the SQL query.\n\nQuestion: {{question}}",
"scorers": [
{
"type": "sql-syntax"
},
{
"type": "py-script",
"path": "execution_accuracy.py"
}
]
}
],
"dataset": {
"path": "https://assets.empirical.run/datasets/v2/json/spider-tiny.json"
"path": "https://docs.google.com/spreadsheets/d/1x_p0lX2pJEyGkFoe1A9nY3q87qOJUd547f2lz99ugiM/edit#gid=0"
}
}
30 changes: 30 additions & 0 deletions examples/spider/execution_accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import re
import sqlite3


def remove_backticks(text):
"""If we find backticks, return the code snippet from within them"""
pruned_text = text.replace("```sql", "```")
found = re.findall(r"```(.*?)```", pruned_text, re.DOTALL)
if len(found):
return found[0]
else:
return pruned_text


def evaluate(output, inputs):
con = sqlite3.connect("concert_singer.db")
cur = con.cursor()
try:
res = cur.execute(remove_backticks(output["value"]))
first_row = res.fetchone()
if first_row:
passed = 1
message = "Result preview: " + ", ".join([str(x) for x in first_row])
else:
passed = 0.5
message = "No results found"
except Exception as e:
passed, message = 0, repr(e)

return [{"score": passed, "message": message, "name": "exec-accuracy"}]
89 changes: 89 additions & 0 deletions examples/spider/prepare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import os
import sqlite3


def create_and_load_database():
sqlite_file_name = "concert_singer.db"
if os.path.isfile(sqlite_file_name):
os.remove(sqlite_file_name)
con = sqlite3.connect(sqlite_file_name)
cur = con.cursor()
# First create tables
cur.executescript(schema)
# Then load data
cur.executescript(data)
con.close()


schema = """CREATE TABLE stadium (
stadium_id NUMERIC PRIMARY KEY,
location TEXT,
name TEXT,
capacity NUMERIC,
highest NUMERIC,
lowest NUMERIC,
average NUMERIC
);
CREATE TABLE singer (
singer_id NUMERIC PRIMARY KEY,
name TEXT,
country TEXT,
song_name TEXT,
song_release_year TEXT,
age NUMERIC,
is_male TIMESTAMP
);
CREATE TABLE concert (
concert_id NUMERIC PRIMARY KEY,
concert_name TEXT,
theme TEXT,
stadium_id TEXT,
year TEXT,
FOREIGN KEY (stadium_id) REFERENCES stadium(stadium_id)
);
CREATE TABLE singer_in_concert (
concert_id NUMERIC PRIMARY KEY,
singer_id TEXT,
FOREIGN KEY (singer_id) REFERENCES singer(singer_id),
FOREIGN KEY (concert_id) REFERENCES concert(concert_id)
);"""


data = """-- INSERT INTO stadium
INSERT INTO stadium (stadium_id, location, name, capacity, highest, lowest, average) VALUES
(1, 'New York, USA', 'Madison Square Garden', 20789, 85, 70, 78),
(2, 'London, UK', 'Wembley Stadium', 90000, 92, 65, 80),
(3, 'Sydney, Australia', 'Sydney Opera House', 5738, 88, 68, 75),
(4, 'Paris, France', 'Stade de France', 81338, 90, 72, 82),
(5, 'Tokyo, Japan', 'Tokyo Dome', 55000, 87, 70, 78);
-- INSERT INTO singer
INSERT INTO singer (singer_id, name, country, song_name, song_release_year, age, is_male) VALUES
(1, 'Taylor Swift', 'USA', 'Shake It Off', '2014', 33, 0),
(2, 'Ed Sheeran', 'UK', 'Shape of You', '2017', 32, 1),
(3, 'Adele', 'UK', 'Hello', '2015', 34, 0),
(4, 'BTS', 'South Korea', 'Butter', '2021', 29, 0),
(5, 'Drake', 'Canada', "God's Plan", '2018', 36, 1);
-- INSERT INTO concert
INSERT INTO concert (concert_id, concert_name, theme, stadium_id, year) VALUES
(1, 'Reputation Tour', 'Pop', '1', '2018'),
(2, 'Divide Tour', 'Pop', '2', '2019'),
(3, '25 Tour', 'Pop', '3', '2016'),
(4, 'Love Yourself Tour', 'K-Pop', '4', '2018'),
(5, 'Scorpion Tour', 'Hip-Hop', '5', '2019');
-- INSERT INTO singer_in_concert
INSERT INTO singer_in_concert (concert_id, singer_id) VALUES
(1, '1'),
(2, '2'),
(3, '3'),
(4, '4'),
(5, '5');"""


if __name__ == "__main__":
create_and_load_database()
2 changes: 2 additions & 0 deletions packages/ai/src/providers/google/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ const createChatCompletion: ICreateChatCompletion = async (body) => {
.generateContent({ contents })
.catch((err: Error) => {
// TODO: Replace with instanceof checks when the Gemini SDK exports errors
console.log(err.message);
if (err.message.includes("[429 Too Many Requests]")) {
console.log("Attempting to retry");
retry(err);
}
throw err;
Expand Down
1 change: 1 addition & 0 deletions packages/cli/src/bin/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ program
console.log(buildErrorLog(`${code}: ${message}`));
process.exit(1);
}
process.exit(0);
});

const defaultWebUIPort = 1337;
Expand Down
33 changes: 32 additions & 1 deletion packages/scorer/src/provider/deterministic/sql.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,38 @@ test("sql-syntax works with markdown", async () => {
{
score: 0,
name: "sql-syntax",
message: "SQL is invalid",
message:
'Expected "#", "$", "(", "+", "-", "--", "/*", ";", "@", "@@", "ALTER", "CALL", "CREATE", "DELETE", "DESC", "DESCRIBE", "DROP", "GO", "GRANT", "INSERT", "LOCK", "RENAME", "REPLACE", "SELECT", "SET", "SHOW", "TRUNCATE", "UNLOCK", "UPDATE", "USE", "WITH", "return", [ \\t\\n\\r], [0-9], [A-Za-z_], or end of input but "`" found.',
},
]);
});

test("sql-syntax works with a correct query", async () => {
const query = `SELECT name, capacity
FROM stadium
WHERE stadium_id = (
SELECT stadium_id
FROM concert
WHERE year > 2013
GROUP BY stadium_id
ORDER BY COUNT(concert_id) DESC
LIMIT 1
);`;
expect(
await checkSqlSyntax({
sample: { id: "1", inputs: {} },
output: {
value: query,
},
config: {
type: "sql-syntax",
},
}),
).toStrictEqual([
{
score: 1,
name: "sql-syntax",
message: "",
},
]);
});
Expand Down
16 changes: 7 additions & 9 deletions packages/scorer/src/provider/deterministic/sql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@ import { ScoringFn } from "../../interface/scorer";
export const syntaxName = "sql-syntax";
export const semanticName = "sql-semantic";

//TODO: make this config driven
const parserOpt = { database: "sqlite" };

export const checkSqlSyntax: ScoringFn = async ({ output }) => {
let isSQLQuery = false;
let inValidSQLMsg = "SQL is invalid";
let errorMsg = "SQL is invalid";
const parser = new Parser();
if (!output || !output.value) {
return [
Expand All @@ -21,16 +18,17 @@ export const checkSqlSyntax: ScoringFn = async ({ output }) => {
];
}
try {
parser.parse(output.value!, parserOpt);
parser.parse(output.value!);
isSQLQuery = true;
} catch (e) {
} catch (e: any) {
isSQLQuery = false;
errorMsg = e.message;
}
return [
{
score: isSQLQuery ? 1 : 0,
name: syntaxName,
message: isSQLQuery ? "" : inValidSQLMsg,
message: isSQLQuery ? "" : errorMsg,
},
];
};
Expand All @@ -48,8 +46,8 @@ export const checkSqlSemantic: ScoringFn = async ({ sample, output }) => {
];
}
try {
const parsedOutput = parser.parse(cleanQuery(output.value!), parserOpt);
const parsedExpected = parser.parse(cleanQuery(expected), parserOpt);
const parsedOutput = parser.parse(cleanQuery(output.value!));
const parsedExpected = parser.parse(cleanQuery(expected));
cleanColumns(parsedOutput.ast as Select);
cleanColumns(parsedExpected.ast as Select);
const isEquivalent =
Expand Down

0 comments on commit 2bc5465

Please sign in to comment.