Skip to content

Commit

Permalink
feat: add support for tool calls in assistant output (#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
saikatmitra91 authored May 2, 2024
1 parent 2b03d24 commit 93e12e0
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 39 deletions.
9 changes: 9 additions & 0 deletions .changeset/seven-oranges-decide.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
"@empiricalrun/types": minor
"@empiricalrun/core": minor
"@empiricalrun/ai": minor
"web": minor
"@empiricalrun/cli": minor
---

feat: add support for assistant tool calls
12 changes: 9 additions & 3 deletions apps/web/components/json-as-tab.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ export function JsonAsTab({
onSampleRemove,
onEditorContentUpdate,
onClickRunAll,
showExpandOption = true,
readonlyContent = false,
scrollableContent = true,
}: {
storeKey: string;
data: { [key: string]: any };
Expand All @@ -35,6 +38,9 @@ export function JsonAsTab({
onSampleRemove?: () => void;
onEditorContentUpdate?: (key: string, value: string) => void;
onClickRunAll?: () => void;
showExpandOption?: boolean;
readonlyContent?: boolean;
scrollableContent?: boolean;
}) {
const tabs = useMemo(
() => defaultTabs || Object.keys(data),
Expand Down Expand Up @@ -70,7 +76,7 @@ export function JsonAsTab({
<>
<div className="flex flex-row space-x-2 justify-end">
<>
{activeTabValue && (
{activeTabValue && showExpandOption && (
<Sheet>
<SheetTrigger asChild>
<Button
Expand Down Expand Up @@ -169,8 +175,8 @@ export function JsonAsTab({
: JSON.stringify(value, null, 2)
}
language="text"
readOnly={false}
scrollable
readOnly={readonlyContent}
scrollable={scrollableContent}
onChange={(value) => onEditorContentUpdate?.(key, value!)}
/>
</TabsContent>
Expand Down
60 changes: 35 additions & 25 deletions apps/web/components/sample-output-card.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import SampleCompletionError from "./sample-completion-error";
import { JsonAsTab } from "./json-as-tab";
import { RunSampleOutputMetric } from "./run-response-metadata";
import { Scores } from "./scores";
import { ToolCalls } from "./tool-calls-view";

type Diff = {
type: string;
Expand Down Expand Up @@ -56,12 +57,13 @@ export default function SampleOutputCard({

const showCompareAgainst = useMemo(
() =>
baseSample?.expected?.value ||
comparisonSamples?.some(
(comparisonSample, index) =>
comparisonSample?.output &&
comparisonResults?.[index]?.id !== baseResult?.id,
),
baseSample?.output.value &&
(baseSample?.expected?.value ||
comparisonSamples?.some(
(comparisonSample, index) =>
comparisonSample?.output &&
comparisonResults?.[index]?.id !== baseResult?.id,
)),
[
baseResult?.id,
baseSample?.expected?.value,
Expand Down Expand Up @@ -139,12 +141,12 @@ export default function SampleOutputCard({
<CardTitle className="flex flex-row space-x-2 items-center">
<Scores scores={baseSample?.scores || []} />
<div className="flex flex-row space-x-2 justify-end items-start self-baseline">
<DropdownMenu>
<DropdownMenuTrigger>
<DotsVerticalIcon />
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
{showCompareAgainst ? (
{showCompareAgainst && (
<DropdownMenu>
<DropdownMenuTrigger>
<DotsVerticalIcon />
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<>
<DropdownMenuLabel className="text-xs">
Compare against
Expand Down Expand Up @@ -200,9 +202,9 @@ export default function SampleOutputCard({
);
})}
</>
) : null}
</DropdownMenuContent>
</DropdownMenu>
</DropdownMenuContent>
</DropdownMenu>
)}
</div>
</CardTitle>
)}
Expand All @@ -216,9 +218,6 @@ export default function SampleOutputCard({
ref={containerWrapper}
>
<section className="flex flex-col">
{showOutput && hasMetadata && (
<p className=" text-sm font-medium mb-2">Output</p>
)}
{diffView.enabled && baseSample && (
<DiffEditor
original={baseSample?.output.value || ""}
Expand All @@ -236,11 +235,16 @@ export default function SampleOutputCard({
/>
)}
{showOutput && (
<CodeViewer
value={baseSample?.output.value || ""}
language="json"
readOnly
/>
<>
{!baseSample?.output.tool_calls && (
<CodeViewer
value={baseSample?.output.value || ""}
language="json"
readOnly
/>
)}
<ToolCalls toolCalls={baseSample?.output.tool_calls} />
</>
)}
</section>
{showOutput && (
Expand All @@ -250,7 +254,11 @@ export default function SampleOutputCard({
value={baseSample?.output?.tokens_used}
hideSeparator
/>
<RunSampleOutputMetric title="Latency" value={latency} />
<RunSampleOutputMetric
title="Latency"
value={latency}
hideSeparator={!baseSample?.output?.tokens_used}
/>
<RunSampleOutputMetric
title="Finish reason"
value={baseSample?.output?.finish_reason}
Expand All @@ -259,11 +267,13 @@ export default function SampleOutputCard({
)}
{!diffView.enabled && hasMetadata && (
<section className="flex flex-col h-[200px] mt-2">
<p className=" text-sm font-medium mt-2">Metadata</p>
<section className="relative flex flex-col flex-1">
<JsonAsTab
storeKey={baseResult?.id!}
data={baseSample?.output.metadata!}
showExpandOption={false}
scrollableContent={false}
readonlyContent
/>
</section>
</section>
Expand Down
53 changes: 53 additions & 0 deletions apps/web/components/tool-calls-view.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import { ChatCompletionMessageToolCall } from "@empiricalrun/types";
import CodeViewer from "./ui/code-viewer";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "./ui/tabs";

/**
 * Pretty-prints the raw `arguments` string of a function tool call.
 *
 * The arguments are produced by the model and are not guaranteed to be valid
 * JSON, so a parse failure falls back to the raw string instead of throwing
 * during render.
 *
 * @param rawArguments - The `function.arguments` string of a tool call.
 * @returns The arguments re-serialized with 2-space indentation, or the
 *   unmodified input when it cannot be parsed as JSON.
 */
function formatToolCallArguments(rawArguments: string): string {
  try {
    return JSON.stringify(JSON.parse(rawArguments), null, 2);
  } catch {
    return rawArguments;
  }
}

/**
 * Renders the tool calls of an output as a tab strip: one tab per call of
 * type "function", labelled with the function name and showing that call's
 * arguments in a read-only code viewer.
 *
 * @param toolCalls - Tool calls to display; may be undefined.
 * @returns The tabbed view, or `null` when there is no function tool call to
 *   show.
 */
export function ToolCalls({
  toolCalls,
}: {
  toolCalls: ChatCompletionMessageToolCall[] | undefined;
}) {
  const tabs = (toolCalls ?? []).filter((t) => t.type === "function");
  // Nothing to render: avoids showing a "Tool Calls" strip with no tabs.
  if (tabs.length === 0) {
    return null;
  }
  return (
    <>
      <Tabs defaultValue={tabs[0]?.id} className="h-full">
        <TabsList className=" rounded-sm w-full overflow-x-scroll justify-start no-scrollbar">
          <p className=" font-semibold text-sm mr-2 text-white">Tool Calls</p>
          {tabs.map((tab) => (
            <TabsTrigger
              key={tab.id}
              value={tab.id}
              className="text-xs rounded-sm"
            >
              {tab.function.name}
            </TabsTrigger>
          ))}
        </TabsList>
        {tabs.map((tab) => (
          <TabsContent
            key={tab.id}
            value={tab.id}
            // The tab list is h-9 (2.25rem) by default; change this offset if the tab height changes.
            // NOTE(review): the offset below is 3rem, not 2.25rem — presumably the extra 0.75rem is spacing; confirm.
            className="h-[calc(100%-3rem)]"
          >
            <CodeViewer
              value={formatToolCallArguments(tab.function.arguments)}
              language="text"
              readOnly
            />
          </TabsContent>
        ))}
      </Tabs>
    </>
  );
}
2 changes: 1 addition & 1 deletion apps/web/components/ui/tabs.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ const TabsContent = React.forwardRef<
<TabsPrimitive.Content
ref={ref}
className={cn(
"mt-2 ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2",
"ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2",
className,
)}
{...props}
Expand Down
7 changes: 2 additions & 5 deletions packages/ai/src/providers/openai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,9 @@ const runAssistant: ICreateAndRunAssistantThread = async (body) => {
if (eventData.event === "thread.run.requires_action") {
const { tool_calls } = eventData.data.required_action
?.submit_tool_outputs || {
tool_calls: [],
tool_calls: undefined,
};
const toolSummary = tool_calls.map((tc) => {
return `${tc.function.name} with args ${tc.function.arguments}`;
});
asstRunResp.content = `Attempting to make tool call: ${toolSummary.join(", ") || ""}`;
asstRunResp.content = "";
asstRunResp.tool_calls = tool_calls;
}

Expand Down
5 changes: 1 addition & 4 deletions packages/core/src/executors/run/transformers/assistant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,18 @@ export const assistantExecutor: Transformer = async function (
});

const hasCitation = message.citations && message.citations.length > 0;
const hasToolCall = message.tool_calls && message.tool_calls.length > 0;
let metadata: any = {};
if (hasCitation) {
metadata.citations = message.citations;
}
if (hasToolCall) {
metadata.tool_calls = message.tool_calls;
}

return {
output: {
value: message.content,
metadata,
tokens_used: message.usage?.total_tokens,
latency: message.latency,
tool_calls: message.tool_calls,
},
};
} catch (e: any) {
Expand Down
5 changes: 4 additions & 1 deletion packages/types/src/ai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ export interface ICreateChatCompletion {
(body: IChatCompletionCreateParams): Promise<IChatCompletion>;
}

/**
 * A tool call attached to an assistant message.
 *
 * Exposes `OpenAI.ChatCompletionMessageToolCall` under this package's own
 * name; the interface body is empty, so no members are added or changed.
 */
export interface ChatCompletionMessageToolCall
extends OpenAI.ChatCompletionMessageToolCall {}

export interface Citation {
file_id?: string;
quote?: string;
Expand All @@ -23,7 +26,7 @@ export interface Citation {
export interface IAssistantRunResponse {
content: string;
citations: Citation[];
tool_calls?: any[];
tool_calls?: ChatCompletionMessageToolCall[];
usage?: OpenAI.CompletionUsage;
latency?: number;
}
Expand Down
3 changes: 3 additions & 0 deletions packages/types/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { ChatCompletionMessageToolCall } from "./ai";

export * from "./ai";

export enum RoleType {
Expand Down Expand Up @@ -193,6 +195,7 @@ export type RunOutput = {
finish_reason?: string;
tokens_used?: number;
latency?: number;
tool_calls?: ChatCompletionMessageToolCall[];
};

export type RunSampleOutput = {
Expand Down

0 comments on commit 93e12e0

Please sign in to comment.