Skip to content

Commit e4df714

Browse files
consoullgordonhart
andauthoredSep 18, 2024··
Add API to download judge votes (#53)
* Add API to download judge votes Adds an endpoint to download a judges votes under `/project/{project_slug}/judge/{judge_id}/download/votes`. The exported CSV contains the prompt, two responses and the judges choice. * Connect to frontend * Disable download button when no votes are present * Improve enabled/disabled pill display * Add test for judge vote download endpoint * Add second vote to test and catch bug where responses were not ordered --------- Co-authored-by: Gordon Hart <gordon.hart2@gmail.com>
1 parent 9e54ae3 commit e4df714

File tree

6 files changed

+100
-13
lines changed

6 files changed

+100
-13
lines changed
 

‎autoarena/api/router.py

+12
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,18 @@ def delete_judge(project_slug: str, judge_id: int, background_tasks: BackgroundT
188188
except NotFoundError:
189189
pass
190190

191+
@r.get("/project/{project_slug}/judge/{judge_id}/download/votes")
192+
async def download_judge_votes_csv(project_slug: str, judge_id: int) -> StreamingResponse:
193+
columns = ["prompt", "model_a", "model_b", "response_a", "response_b", "winner"]
194+
df_response = JudgeService.get_df_vote(project_slug, judge_id)
195+
# TODO: handle case where no votes exist, not a big problem for now as UI buttons are disabled for 0-vote judges
196+
judge_name = df_response.iloc[0].judge
197+
stream = StringIO()
198+
df_response[columns].to_csv(stream, index=False)
199+
response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv")
200+
response.headers["Content-Disposition"] = f'attachment; filename="{judge_name}.csv"'
201+
return response
202+
191203
@r.put("/project/{project_slug}/elo/reseed-scores")
192204
def reseed_elo_scores(project_slug: str) -> None:
193205
EloService.reseed_scores(project_slug)

‎autoarena/service/judge.py

+27
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from loguru import logger
2+
import pandas as pd
23

34
from autoarena.api import api
45
from autoarena.judge.factory import verify_judge_type_environment
@@ -44,6 +45,32 @@ def get_all(project_slug: str) -> list[api.Judge]:
4445
df["judge_type"] = df["judge_type"].apply(lambda j: j if j in judge_types else api.JudgeType.UNRECOGNIZED.value)
4546
return [api.Judge(**r) for _, r in df.iterrows()]
4647

48+
@staticmethod
49+
def get_df_vote(project_slug: str, judge_id: int) -> pd.DataFrame:
50+
with ProjectService.connect(project_slug) as conn:
51+
df_vote = conn.execute(
52+
"""
53+
SELECT
54+
j.name as judge,
55+
ra.prompt as prompt,
56+
ma.name as model_a,
57+
mb.name as model_b,
58+
ra.response as response_a,
59+
rb.response as response_b,
60+
h2h.winner as winner
61+
FROM judge j
62+
JOIN head_to_head h2h ON j.id = h2h.judge_id
63+
JOIN response ra ON ra.id = h2h.response_a_id
64+
JOIN response rb ON rb.id = h2h.response_b_id
65+
JOIN model ma ON ra.model_id = ma.id
66+
JOIN model mb ON rb.model_id = mb.id
67+
WHERE j.id = $judge_id
68+
ORDER BY h2h.id
69+
""",
70+
dict(judge_id=judge_id),
71+
).df()
72+
return df_vote
73+
4774
@staticmethod
4875
def create(project_slug: str, request: api.CreateJudgeRequest) -> api.Judge:
4976
with ProjectService.connect(project_slug) as conn:

‎tests/integration/api/test_judges.py

+29
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from io import StringIO
2+
3+
import pandas as pd
14
import pytest
25
from fastapi.testclient import TestClient
36

@@ -51,6 +54,32 @@ def test__judges__can_access__unrecognized__failed(project_client: TestClient) -
5154
assert not project_client.get("/judge/unrecognized/can-access").json()
5255

5356

57+
def test__judges__download_votes_csv(project_client: TestClient, model_id: int, model_b_id: int) -> None:
58+
h2h = project_client.put("/head-to-heads", json=dict(model_a_id=model_id, model_b_id=model_b_id)).json()
59+
judge_request = dict(response_a_id=h2h[0]["response_a_id"], response_b_id=h2h[0]["response_b_id"], winner="A")
60+
assert project_client.post("/head-to-head/vote", json=judge_request).json() is None
61+
judge_request = dict(response_a_id=h2h[1]["response_b_id"], response_b_id=h2h[1]["response_a_id"], winner="-")
62+
assert project_client.post("/head-to-head/vote", json=judge_request).json() is None
63+
judges = project_client.get("/judges").json()
64+
assert len(judges) == 1
65+
assert judges[0]["n_votes"] == 2
66+
human_judge_id = judges[0]["id"]
67+
response = project_client.get(f"/judge/{human_judge_id}/download/votes")
68+
assert response.status_code == 200
69+
models = project_client.get("/models").json()
70+
assert len(models) == 2
71+
model_a, model_b = [m for m in models if m["id"] == model_id][0], [m for m in models if m["id"] == model_b_id][0]
72+
df_vote_expected = pd.DataFrame(
73+
[
74+
(h2h[0]["prompt"], model_a["name"], model_b["name"], h2h[0]["response_a"], h2h[0]["response_b"], "A"),
75+
(h2h[1]["prompt"], model_b["name"], model_a["name"], h2h[1]["response_b"], h2h[1]["response_a"], "-"),
76+
],
77+
columns=["prompt", "model_a", "model_b", "response_a", "response_b", "winner"],
78+
)
79+
df_vote = pd.read_csv(StringIO(response.text))
80+
assert df_vote.equals(df_vote_expected)
81+
82+
5483
def test__judges__delete(project_client: TestClient, judge_id: int) -> None:
5584
assert project_client.delete(f"/judge/{judge_id}").json() is None
5685
assert len(project_client.get("/judges").json()) == 1 # only default judge is left

‎ui/src/components/Judges/JudgeAccordionItem.tsx

+28-11
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import { useState } from 'react';
2-
import { Accordion, Button, Checkbox, Collapse, Group, Pill, Stack, Text, Tooltip } from '@mantine/core';
2+
import { Accordion, Anchor, Button, Checkbox, Collapse, Group, Pill, Stack, Text, Tooltip } from '@mantine/core';
33
import { Link } from 'react-router-dom';
44
import { useDisclosure } from '@mantine/hooks';
5-
import { IconGavel } from '@tabler/icons-react';
5+
import { IconDownload, IconGavel, IconPrompt } from '@tabler/icons-react';
66
import { Judge } from '../../hooks/useJudges.ts';
77
import { useUrlState } from '../../hooks/useUrlState.ts';
88
import { useUpdateJudge } from '../../hooks/useUpdateJudge.ts';
99
import { MarkdownContent } from '../MarkdownContent.tsx';
1010
import { pluralize } from '../../lib/string.ts';
11+
import { API_ROUTES } from '../../lib/routes.ts';
1112
import { judgeTypeIconComponent, judgeTypeToHumanReadableName } from './types.ts';
1213
import { DeleteJudgeButton } from './DeleteJudgeButton.tsx';
1314
import { CanAccessJudgeStatusIndicator } from './CanAccessJudgeStatusIndicator.tsx';
@@ -30,6 +31,15 @@ export function JudgeAccordionItem({ judge }: Props) {
3031
setIsEnabled(prev => !prev);
3132
}
3233

34+
const canDownload = judge.n_votes > 0;
35+
const downloadUrl = canDownload ? API_ROUTES.downloadJudgeVotesCsv(projectSlug, judge.id) : undefined;
36+
const DownloadVotesComponent = (
37+
<Anchor href={downloadUrl} target="_blank">
38+
<Button variant="light" color="teal" size="xs" leftSection={<IconDownload size={20} />} disabled={!canDownload}>
39+
Download Votes CSV
40+
</Button>
41+
</Anchor>
42+
);
3343
const IconComponent = judgeTypeIconComponent(judge_type);
3444
return (
3545
<Accordion.Item key={id} value={`${judge_type}-${id}`}>
@@ -53,10 +63,15 @@ export function JudgeAccordionItem({ judge }: Props) {
5363
</Text>
5464
</Stack>
5565
<Group>
56-
{isEnabled && (
57-
<Pill bg="ice.0" c="gray.8">
66+
<Text c="dimmed" size="xs" fs="italic">
67+
{pluralize(judge.n_votes, 'vote')}
68+
</Text>
69+
{isEnabled ? (
70+
<Pill bg="ice.0" c="ice.9">
5871
Enabled
5972
</Pill>
73+
) : (
74+
<Pill c="gray">Disabled</Pill>
6075
)}
6176
</Group>
6277
</Group>
@@ -73,12 +88,16 @@ export function JudgeAccordionItem({ judge }: Props) {
7388
onChange={() => handleToggleEnabled()}
7489
/>
7590
<Group>
76-
<Text c="dimmed" size="xs" fs="italic">
77-
{pluralize(judge.n_votes, 'vote')} submitted
78-
</Text>
79-
<Button variant="light" color="gray" size="xs" onClick={toggleShowSystemPrompt}>
91+
<Button
92+
variant="light"
93+
color="gray"
94+
size="xs"
95+
leftSection={<IconPrompt size={20} />}
96+
onClick={toggleShowSystemPrompt}
97+
>
8098
{showSystemPrompt ? 'Hide' : 'Show'} System Prompt
8199
</Button>
100+
{DownloadVotesComponent}
82101
<Tooltip label="Judge must be enabled" disabled={judge.enabled}>
83102
<Button
84103
variant="light"
@@ -111,9 +130,7 @@ export function JudgeAccordionItem({ judge }: Props) {
111130
</Link>{' '}
112131
tab to provide ratings on head-to-head matchups between models.
113132
</Text>
114-
<Text c="dimmed" size="xs" fs="italic">
115-
{pluralize(judge.n_votes, 'vote')} submitted
116-
</Text>
133+
{DownloadVotesComponent}
117134
</Group>
118135
)}
119136
<Collapse in={showSystemPrompt} fz="sm">

‎ui/src/components/Leaderboard/ExpandedModelDetails.tsx

+2-2
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@ export function ExpandedModelDetails({ model }: Props) {
6464
</Link>
6565
<Anchor href={API_ROUTES.downloadModelResponsesCsv(projectSlug, model.id)} target="_blank">
6666
<Button color="teal" variant="light" size="xs" leftSection={<IconDownload size={20} />}>
67-
Download Response CSV
67+
Download Responses CSV
6868
</Button>
6969
</Anchor>
7070
<Anchor href={API_ROUTES.downloadModelHeadToHeadsCsv(projectSlug, model.id)} target="_blank">
7171
<Button color="teal" variant="light" size="xs" leftSection={<IconDownload size={20} />}>
72-
Download Head-to-Head CSV
72+
Download Head-to-Heads CSV
7373
</Button>
7474
</Anchor>
7575
<Button

‎ui/src/lib/routes.ts

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ export const API_ROUTES = {
6060
getDefaultSystemPrompt: (projectSlug: string) => `${getProjectApiUrl(projectSlug)}/judge/default-system-prompt`,
6161
createJudge: (projectSlug: string) => `${getProjectApiUrl(projectSlug)}/judge`,
6262
updateJudge: (projectSlug: string, judgeId: number) => `${getProjectApiUrl(projectSlug)}/judge/${judgeId}`,
63+
downloadJudgeVotesCsv: (projectSlug: string, judgeId: number) =>
64+
`${getProjectApiUrl(projectSlug)}/judge/${judgeId}/download/votes`,
6365
checkCanAccess: (projectSlug: string, judgeType: JudgeType) =>
6466
`${getProjectApiUrl(projectSlug)}/judge/${judgeType}/can-access`,
6567
deleteJudge: (projectSlug: string, judgeId: number) => `${getProjectApiUrl(projectSlug)}/judge/${judgeId}`,

0 commit comments

Comments
 (0)
Please sign in to comment.