Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: openai assistants #174

Merged
merged 21 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .changeset/itchy-jobs-breathe.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
"@empiricalrun/scorer": minor
"@empiricalrun/types": minor
"@empiricalrun/core": minor
"@empiricalrun/cli": minor
"@empiricalrun/ai": minor
"web": minor
---

feat: add support for openai assistants
29 changes: 26 additions & 3 deletions apps/web/components/json-as-tab.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { useMemo } from "react";
import { useCallback, useEffect, useMemo, useState } from "react";
import { useSyncedTabs } from "../hooks/useSyncedTab";
import {
Sheet,
Expand Down Expand Up @@ -40,14 +40,32 @@ export function JsonAsTab({
() => defaultTabs || Object.keys(data),
[data, defaultTabs],
);
const { activeTab, onChangeTab } = useSyncedTabs(tabs, storeKey);
const { activeTab: remoteActiveTab, onChangeTab: remoteOnChangeTab } =
useSyncedTabs(tabs, storeKey);
const [activeTab, setActiveTab] = useState<string | undefined>();
const activeTabValue = useMemo(() => {
if (activeTab && data) {
return data[activeTab];
}
return undefined;
}, [activeTab, data]);

useEffect(() => {
if (remoteActiveTab && data[remoteActiveTab]) {
setActiveTab(remoteActiveTab);
} else if (!activeTab) {
setActiveTab(Object.keys(data)[0]);
}
}, [remoteActiveTab]);

const onChangeTab = useCallback(
(tab: string) => {
setActiveTab(tab);
remoteOnChangeTab(tab);
},
[remoteOnChangeTab],
);

return (
<>
<div className="flex flex-row space-x-2 justify-end">
Expand Down Expand Up @@ -117,7 +135,12 @@ export function JsonAsTab({
</>
</div>
{tabs.length > 0 && (
<Tabs value={activeTab} className="h-full" onValueChange={onChangeTab}>
<Tabs
value={activeTab}
defaultValue={activeTab}
className="h-full"
onValueChange={onChangeTab}
>
<TabsList className=" rounded-sm w-full overflow-x-scroll justify-start no-scrollbar">
{tabs.map((name) => (
<TabsTrigger
Expand Down
15 changes: 7 additions & 8 deletions apps/web/components/sample-output-card.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import { DotsVerticalIcon } from "@radix-ui/react-icons";
import EmptySampleCompletion from "./empty-sample-completion";
import { RunResult } from "../types";
import SampleCompletionError from "./sample-completion-error";
import { Separator } from "./ui/separator";
import { JsonAsTab } from "./json-as-tab";
import { RunSampleOutputMetric } from "./run-response-metadata";
import { Scores } from "./scores";
Expand Down Expand Up @@ -122,6 +121,10 @@ export default function SampleOutputCard({
: 0,
[baseSample],
);
const hasMetadata = useMemo(
() => !!Object.keys(baseSample?.output.metadata || {}).length,
[baseSample?.output.metadata],
);
return (
<Card
className={`flex flex-col flex-1 ${
Expand Down Expand Up @@ -213,7 +216,7 @@ export default function SampleOutputCard({
ref={containerWrapper}
>
<section className="flex flex-col">
{showOutput && baseSample?.output.metadata && (
{showOutput && hasMetadata && (
<p className=" text-sm font-medium mb-2">Output</p>
)}
{diffView.enabled && baseSample && (
Expand Down Expand Up @@ -241,7 +244,7 @@ export default function SampleOutputCard({
)}
</section>
{showOutput && (
<div className="flex gap-2 items-center px-2 mt-2">
<div className="flex gap-2 items-center mt-2">
<RunSampleOutputMetric
title="Total tokens"
value={baseSample?.output?.tokens_used}
Expand All @@ -254,12 +257,8 @@ export default function SampleOutputCard({
/>
</div>
)}
{!diffView.enabled && baseSample?.output.metadata && (
{!diffView.enabled && hasMetadata && (
<section className="flex flex-col h-[200px] mt-2">
<Separator
orientation="horizontal"
className="w-[60%] self-center"
/>
<p className=" text-sm font-medium mt-2">Metadata</p>
<section className="relative flex flex-col flex-1">
<JsonAsTab
Expand Down
26 changes: 26 additions & 0 deletions examples/assistants/empiricalrc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{
"runs": [
{
"type": "assistant",
"assistant_id": "asst_xyK085ub8c30tG3iwkaW9Moh",
"prompt": "{{ question }}",
"parameters": {
"temperature": 0.1
}
}
],
"dataset": {
"path": "https://docs.google.com/spreadsheets/d/1U8fBQ9TxtR5pUS0Bg0n1xcU_6EXkhGUOj1CS-TAAHmU/edit#gid=0"
},
"scorers": [
{
"type": "llm-critic",
"criteria": "{{ success criteria }}",
"name": "success-criteria"
},
{
"type": "py-script",
"path": "has-citations.py"
}
]
}
14 changes: 14 additions & 0 deletions examples/assistants/has-citations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
def evaluate(output, inputs):
    """Score whether the output contains citations, when the dataset row expects them.

    Args:
        output: sample output dict; citations are read from output["metadata"]["citations"].
        inputs: dataset row dict; the "should have citations" column gates the check.

    Returns:
        A one-element list with a has-citations score when the row is marked
        "yes", otherwise an empty list (no score recorded).
    """
    # Rows not marked "yes" are not evaluated at all.
    if inputs.get("should have citations") != "yes":
        return []

    found = output.get("metadata", {}).get("citations")
    has_citations = bool(found)
    return [
        {
            "score": 1 if has_citations else 0,
            "message": "" if has_citations else "No citations found",
            "name": "has-citations",
        }
    ]
6 changes: 3 additions & 3 deletions examples/chatbot/empiricalrc.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"prompt": "{{user_message}}",
"scorers": [
{
"type": "llm-criteria",
"type": "llm-critic",
"criteria": "Never call yourself an AI, language model, or any variant of the term",
"name": "self-referencing"
}
Expand All @@ -23,12 +23,12 @@
"prompt": "You are Sarah, a political scientist. Respond to the user with your best answer. Make sure to respond to them with their name.\n\n{{user_name}}: {{user_message}}",
"scorers": [
{
"type": "llm-criteria",
"type": "llm-critic",
"criteria": "Never call yourself an AI, language model, or any variant of the term",
"name": "self-referencing"
},
{
"type": "llm-criteria",
"type": "llm-critic",
"criteria": "Mention the user's name {{user_name}}",
"name": "personal"
}
Expand Down
2 changes: 1 addition & 1 deletion packages/ai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"@google/generative-ai": "^0.7.1",
"@mistralai/mistralai": "^0.1.3",
"anthropic": "^0.0.0",
"openai": "^4.29.0",
"openai": "^4.38.5",
"promise-retry": "^2.0.1"
}
}
2 changes: 2 additions & 0 deletions packages/ai/src/error/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ export enum AIErrorEnum {
MISSING_PARAMETERS = "AI202",
// failed completions
FAILED_CHAT_COMPLETION = "AI301",
// unsupported response type like image / audio
UNSUPPORTED_COMPLETION_TYPE = "AI302",
// rate limiting
RATE_LIMITED = "AI401",
}
Expand Down
31 changes: 29 additions & 2 deletions packages/ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import {
IModel,
IChatCompletions,
ICreateChatCompletion,
ICreateAndRunAssistantThread,
} from "@empiricalrun/types";
import { chatProvider } from "./providers";
import { assistantProvider, chatProvider } from "./providers";
import { OpenAIProvider } from "./providers/openai";
import { AIError, AIErrorEnum } from "./error";
export * from "./utils";
Expand All @@ -28,7 +29,7 @@ class ChatCompletions implements IChatCompletions {
if (err instanceof AIError) {
throw err;
} else {
const message = `Failed chat completion for ${this.provider} provider ${body.model} model`;
const message = `Failed chat completion for ${this.provider} model: ${body.model}`;
throw new AIError(AIErrorEnum.UNKNOWN, message);
}
}
Expand All @@ -42,6 +43,30 @@ class Chat implements IChat {
}
}

class Assistant {
  constructor(private provider: string) {}

  // Looks up the assistant runner registered for this provider and executes
  // it; unknown providers fail fast, unexpected runner errors are wrapped.
  runAssistant: ICreateAndRunAssistantThread = async (body) => {
    const impl = assistantProvider.get(this.provider);
    if (!impl) {
      throw new AIError(
        AIErrorEnum.INCORRECT_PARAMETERS,
        ` ${this.provider} ai provider is not supported`,
      );
    }
    try {
      return await impl(body);
    } catch (err) {
      // Known AI errors propagate unchanged; anything else is wrapped as UNKNOWN.
      if (err instanceof AIError) {
        throw err;
      }
      const message = `Failed assistant run for ${this.provider} assistant: ${body.assistant_id}`;
      throw new AIError(AIErrorEnum.UNKNOWN, message);
    }
  };
}

class Models {
constructor() {}
// get the list of supported models by empiricalrun
Expand All @@ -52,9 +77,11 @@ class Models {

// Facade over a single AI provider, exposing chat, assistant, and model APIs.
export class EmpiricalAI implements AI {
chat;
assistant;
models;
// Defaults to the OpenAI provider when no provider name is given.
constructor(private provider: string = OpenAIProvider.name) {
this.chat = new Chat(this.provider);
this.assistant = new Assistant(this.provider);
this.models = new Models();
}
}
9 changes: 8 additions & 1 deletion packages/ai/src/providers/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import { ICreateChatCompletion } from "@empiricalrun/types";
import {
ICreateAndRunAssistantThread,
ICreateChatCompletion,
} from "@empiricalrun/types";
import { OpenAIProvider } from "./openai";
import { MistralAIProvider } from "./mistral";
import { GoogleAIProvider } from "./google";
Expand All @@ -14,3 +17,7 @@ export const chatProvider = new Map<string, ICreateChatCompletion>([
[FireworksAIProvider.name, FireworksAIProvider.chat],
[AzureOpenAIProvider.name, AzureOpenAIProvider.chat],
]);

// Registry mapping provider name -> assistant thread runner; only the OpenAI
// provider registers one here.
// NOTE(review): the non-null assertion assumes OpenAIProvider.assistant is
// always defined — confirm against the provider's declaration.
export const assistantProvider = new Map<string, ICreateAndRunAssistantThread>([
[OpenAIProvider.name, OpenAIProvider.assistant!],
]);
Loading