diff --git a/.changeset/violet-horses-accept.md b/.changeset/violet-horses-accept.md new file mode 100644 index 000000000000..a3f52f3ea3c8 --- /dev/null +++ b/.changeset/violet-horses-accept.md @@ -0,0 +1,5 @@ +--- +'ai': patch +--- + +feat (rsc): add streamUI onFinish callback diff --git a/content/docs/07-reference/ai-sdk-rsc/01-stream-ui.mdx b/content/docs/07-reference/ai-sdk-rsc/01-stream-ui.mdx index a69dfbb26076..caaee98cfc7d 100644 --- a/content/docs/07-reference/ai-sdk-rsc/01-stream-ui.mdx +++ b/content/docs/07-reference/ai-sdk-rsc/01-stream-ui.mdx @@ -333,6 +333,77 @@ A helper function to create a streamable UI from LLM providers. This function is }, ], }, + { + name: 'onFinish', + type: '(result: OnFinishResult) => void', + isOptional: true, + description: + 'Callback that is called when the LLM response and all request tool executions (for tools that have a `generate` function) are finished.', + properties: [ + { + type: 'OnFinishResult', + parameters: [ + { + name: 'usage', + type: 'TokenUsage', + description: 'The token usage of the generated text.', + properties: [ + { + type: 'TokenUsage', + parameters: [ + { + name: 'promptTokens', + type: 'number', + description: 'The total number of tokens in the prompt.', + }, + { + name: 'completionTokens', + type: 'number', + description: + 'The total number of tokens in the completion.', + }, + { + name: 'totalTokens', + type: 'number', + description: 'The total number of tokens generated.', + }, + ], + }, + ], + }, + { + name: 'value', + type: 'ReactNode', + description: 'The final ui node that was generated.', + }, + { + name: 'warnings', + type: 'Warning[] | undefined', + description: + 'Warnings from the model provider (e.g. unsupported settings).', + }, + { + name: 'rawResponse', + type: 'RawResponse', + description: 'Optional raw response data.', + properties: [ + { + type: 'RawResponse', + parameters: [ + { + name: 'header', + optional: true, + type: 'Record', + description: 'Response headers.', + }, + ], + }, + ], + }, + ], + }, + ], + }, ]} /> diff --git a/content/examples/01-next-app/05-interface/03-token-usage.mdx b/content/examples/01-next-app/05-interface/03-token-usage.mdx new file mode 100644 index 000000000000..4d20762055c4 --- /dev/null +++ b/content/examples/01-next-app/05-interface/03-token-usage.mdx @@ -0,0 +1,156 @@ +--- +title: Recording Token Usage +description: Examples of how to record token usage when streaming user interfaces. +--- + +# Recording Token Usage + +When you're streaming structured data with [`streamUI`](/docs/reference/ai-sdk-rsc/stream-ui), +you may want to record the token usage for billing purposes. + +## `onFinish` Callback + +You can use the `onFinish` callback to record token usage. +It is called when the stream is finished. + +```tsx filename='app/page.tsx' +'use client'; + +import { useState } from 'react'; +import { ClientMessage } from './actions'; +import { useActions, useUIState } from 'ai/rsc'; +import { generateId } from 'ai'; + +// Force the page to be dynamic and allow streaming responses up to 30 seconds +export const dynamic = 'force-dynamic'; +export const maxDuration = 30; + +export default function Home() { + const [input, setInput] = useState(''); + const [conversation, setConversation] = useUIState(); + const { continueConversation } = useActions(); + + return ( +
<div>
+      <div>
+        {conversation.map((message: ClientMessage) => (
+          <div key={message.id}>
+            {message.role}: {message.display}
+          </div>
+        ))}
+      </div>
+
+      <div>
+        <input
+          type="text"
+          value={input}
+          onChange={event => {
+            setInput(event.target.value);
+          }}
+        />
+        <button
+          onClick={async () => {
+            setConversation((currentConversation: ClientMessage[]) => [
+              ...currentConversation,
+              { id: generateId(), role: 'user', display: input },
+            ]);
+
+            const message = await continueConversation(input);
+
+            setConversation((currentConversation: ClientMessage[]) => [
+              ...currentConversation,
+              message,
+            ]);
+          }}
+        >
+          Send Message
+        </button>
+      </div>
+    </div>
+  );
+}
+```
+
+## Server
+
+```tsx filename='app/actions.tsx' highlight={"57-63"}
+'use server';
+
+import { createAI, getMutableAIState, streamUI } from 'ai/rsc';
+import { openai } from '@ai-sdk/openai';
+import { ReactNode } from 'react';
+import { z } from 'zod';
+import { generateId } from 'ai';
+
+export interface ServerMessage {
+  role: 'user' | 'assistant';
+  content: string;
+}
+
+export interface ClientMessage {
+  id: string;
+  role: 'user' | 'assistant';
+  display: ReactNode;
+}
+
+export async function continueConversation(
+  input: string,
+): Promise<ClientMessage> {
+  'use server';
+
+  const history = getMutableAIState();
+
+  const result = await streamUI({
+    model: openai('gpt-3.5-turbo'),
+    messages: [...history.get(), { role: 'user', content: input }],
+    text: ({ content, done }) => {
+      if (done) {
+        history.done((messages: ServerMessage[]) => [
+          ...messages,
+          { role: 'assistant', content },
+        ]);
+      }
+
+      return <div>{content}</div>;
+    },
+    tools: {
+      deploy: {
+        description: 'Deploy repository to vercel',
+        parameters: z.object({
+          repositoryName: z
+            .string()
+            .describe('The name of the repository, example: vercel/ai-chatbot'),
+        }),
+        generate: async function* ({ repositoryName }) {
+          yield <div>Cloning repository {repositoryName}...</div>; // [!code highlight:5]
+          await new Promise(resolve => setTimeout(resolve, 3000));
+          yield <div>Building repository {repositoryName}...</div>;
+          await new Promise(resolve => setTimeout(resolve, 2000));
+          return <div>{repositoryName} deployed!</div>;
+        },
+      },
+    },
+    onFinish: ({ usage }) => {
+      const { promptTokens, completionTokens, totalTokens } = usage;
+      // your own logic, e.g. for saving the chat history or recording usage
+      console.log('Prompt tokens:', promptTokens);
+      console.log('Completion tokens:', completionTokens);
+      console.log('Total tokens:', totalTokens);
+    },
+  });
+
+  return {
+    id: generateId(),
+    role: 'assistant',
+    display: result.value,
+  };
+}
+
+export const AI = createAI({
+  actions: {
+    continueConversation,
+  },
+  initialAIState: [],
+  initialUIState: [],
+});
+```
diff --git a/examples/next-openai/app/stream-ui/actions.tsx b/examples/next-openai/app/stream-ui/actions.tsx
new file mode 100644
index 000000000000..c46dca1e2ee2
--- /dev/null
+++ b/examples/next-openai/app/stream-ui/actions.tsx
@@ -0,0 +1,119 @@
+import { openai } from '@ai-sdk/openai';
+import { CoreMessage, generateId } from 'ai';
+import {
+  createAI,
+  createStreamableValue,
+  getMutableAIState as $getMutableAIState,
+  streamUI,
+} from 'ai/rsc';
+import { Message, BotMessage } from './message';
+import { z } from 'zod';
+
+type AIProviderNoActions = ReturnType<typeof createAI<AIState, UIState>>;
+// typed wrapper *without* actions defined to avoid circular dependencies
+const getMutableAIState = $getMutableAIState<AIProviderNoActions>;
+
+// mock function to fetch weather data
+const fetchWeatherData = async (location: string) => {
+  await new Promise(resolve => setTimeout(resolve, 1000));
+  return { temperature: '72°F' };
+};
+
+export async function submitUserMessage(content: string) {
+  'use server';
+
+  const aiState = getMutableAIState();
+
+  aiState.update({
+    ...aiState.get(),
+    messages: [
+      ...aiState.get().messages,
+      { id: generateId(), role: 'user', content },
+    ],
+  });
+
+  let textStream: undefined | ReturnType<typeof createStreamableValue<string>>;
+  let textNode: React.ReactNode;
+
+  const result = await streamUI({
+    model: openai('gpt-4-turbo'),
+    initial: <Message role="assistant">Working on that...</Message>,
+    system: 'You are a weather assistant.',
+    messages: aiState
+      .get()
+      .messages.map(({ role, content }) => ({ role, content } as CoreMessage)),
+
+    text: ({ content, done, delta }) => {
+      if (!textStream) {
+        textStream = createStreamableValue('');
+        textNode = <BotMessage textStream={textStream.value} />;
+      }
+
+      if (done) {
+        textStream.done();
+        aiState.update({
+          ...aiState.get(),
+          messages: [
+            ...aiState.get().messages,
+            { id: generateId(), role: 'assistant', content },
+          ],
+        });
+      } else {
+        textStream.append(delta);
+      }
+
+      return textNode;
+    },
+    tools: {
+      get_current_weather: {
+        description: 'Get the current weather',
+        parameters: z.object({
+          location: z.string(),
+        }),
+        generate: async function* ({ location }) {
+          yield (
+            <Message role="assistant">Loading weather for {location}</Message>
+          );
+          const { temperature } = await fetchWeatherData(location);
+          return (
+            <Message role="assistant">
+              <span>
+                The temperature in {location} is{' '}
+                <span>{temperature}</span>
+              </span>
+            </Message>
+          );
+        },
+      },
+    },
+    onFinish: event => {
+      // your own logic, e.g. for saving the chat history or recording usage
+      console.log(`[onFinish]: ${JSON.stringify(event, null, 2)}`);
+    },
+  });
+
+  return {
+    id: generateId(),
+    display: result.value,
+  };
+}
+
+export type ClientMessage = CoreMessage & {
+  id: string;
+};
+
+export type AIState = {
+  chatId: string;
+  messages: ClientMessage[];
+};
+
+export type UIState = {
+  id: string;
+  display: React.ReactNode;
+}[];
+
+export const AI = createAI({
+  actions: { submitUserMessage },
+  initialUIState: [] as UIState,
+  initialAIState: { chatId: generateId(), messages: [] } as AIState,
+});
diff --git a/examples/next-openai/app/stream-ui/layout.tsx b/examples/next-openai/app/stream-ui/layout.tsx
new file mode 100644
index 000000000000..8c0e44c57225
--- /dev/null
+++ b/examples/next-openai/app/stream-ui/layout.tsx
@@ -0,0 +1,5 @@
+import { AI } from './actions';
+
+export default function Layout({ children }: { children: React.ReactNode }) {
+  return <AI>{children}</AI>;
+}
diff --git a/examples/next-openai/app/stream-ui/message.tsx b/examples/next-openai/app/stream-ui/message.tsx
new file mode 100644
index 000000000000..4553a6b66753
--- /dev/null
+++ b/examples/next-openai/app/stream-ui/message.tsx
@@ -0,0 +1,25 @@
+'use client';
+
+import { StreamableValue, useStreamableValue } from 'ai/rsc';
+
+export function BotMessage({ textStream }: { textStream: StreamableValue }) {
+  const [text] = useStreamableValue(textStream);
+  return <Message role="assistant">{text}</Message>;
+}
+
+export function Message({
+  role,
+  children,
+}: {
+  role: string;
+  children: React.ReactNode;
+}) {
+  return (
+    <div>
+      <div>
+        <div>{role}</div>
+      </div>
+      {children}
+    </div>
+  );
+}
diff --git a/examples/next-openai/app/stream-ui/page.tsx b/examples/next-openai/app/stream-ui/page.tsx
new file mode 100644
index 000000000000..364b408a2dd0
--- /dev/null
+++ b/examples/next-openai/app/stream-ui/page.tsx
@@ -0,0 +1,61 @@
+'use client';
+
+import { Fragment, useState } from 'react';
+import type { AI } from './actions';
+import { useActions } from 'ai/rsc';
+
+import { useAIState, useUIState } from 'ai/rsc';
+import { generateId } from 'ai';
+import { Message } from './message';
+
+export default function Home() {
+  const [input, setInput] = useState('');
+  const [messages, setMessages] = useUIState<typeof AI>();
+  const { submitUserMessage } = useActions<typeof AI>();
+
+  const handleSubmission = async () => {
+    setMessages(currentMessages => [
+      ...currentMessages,
+      {
+        id: generateId(),
+        display: <Message role="user">{input}</Message>,
+      },
+    ]);
+
+    const response = await submitUserMessage(input);
+    setMessages(currentMessages => [...currentMessages, response]);
+    setInput('');
+  };
+
+  return (
+    <div>
+      <div>
+        <input
+          value={input}
+          onChange={event => setInput(event.target.value)}
+          placeholder="Ask a question"
+          onKeyDown={event => {
+            if (event.key === 'Enter') {
+              handleSubmission();
+            }
+          }}
+        />
+        <button onClick={handleSubmission}>Send</button>
+      </div>
+
+      <div>
+        {messages.map(message => (
+          <Fragment key={message.id}>{message.display}</Fragment>
+        ))}
+      </div>
+    </div>
+ ); +} diff --git a/packages/core/rsc/stream-ui/__snapshots__/stream-ui.ui.test.tsx.snap b/packages/core/rsc/stream-ui/__snapshots__/stream-ui.ui.test.tsx.snap new file mode 100644 index 000000000000..d67250518908 --- /dev/null +++ b/packages/core/rsc/stream-ui/__snapshots__/stream-ui.ui.test.tsx.snap @@ -0,0 +1,131 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`rsc - streamUI() > should emit React Nodes with async render function 1`] = ` +{ + "children": {}, + "props": { + "s": { + "curr": undefined, + "next": { + "curr":
+ Weather +
, + }, + "type": Symbol(ui.streamable.value), + }, + }, + "type": "InternalStreamableUIClient", +} +`; + +exports[`rsc - streamUI() > should emit React Nodes with async streamUI function 1`] = ` +{ + "children": {}, + "props": { + "s": { + "curr": undefined, + "next": { + "curr":
+ Weather +
, + }, + "type": Symbol(ui.streamable.value), + }, + }, + "type": "InternalStreamableUIClient", +} +`; + +exports[`rsc - streamUI() > should emit React Nodes with generator render function 1`] = ` +{ + "children": {}, + "props": { + "s": { + "curr": undefined, + "next": { + "curr":
+ Loading... +
, + "next": { + "curr":
+ Weather +
, + }, + }, + "type": Symbol(ui.streamable.value), + }, + }, + "type": "InternalStreamableUIClient", +} +`; + +exports[`rsc - streamUI() > should emit React Nodes with generator streamUI function 1`] = ` +{ + "children": {}, + "props": { + "s": { + "curr": undefined, + "next": { + "curr":
+ Loading... +
, + "next": { + "curr":
+ Weather +
, + }, + }, + "type": Symbol(ui.streamable.value), + }, + }, + "type": "InternalStreamableUIClient", +} +`; + +exports[`rsc - streamUI() > should emit React Nodes with sync render function 1`] = ` +{ + "children": {}, + "props": { + "s": { + "curr": undefined, + "next": { + "curr":
+ Weather +
, + }, + "type": Symbol(ui.streamable.value), + }, + }, + "type": "InternalStreamableUIClient", +} +`; + +exports[`rsc - streamUI() > should emit React Nodes with sync streamUI function 1`] = ` +{ + "children": {}, + "props": { + "s": { + "curr": undefined, + "next": { + "curr":
+ Weather +
, + }, + "type": Symbol(ui.streamable.value), + }, + }, + "type": "InternalStreamableUIClient", +} +`; + +exports[`rsc - streamUI() onFinish callback > should contain final React node 1`] = ` + +`; diff --git a/packages/core/rsc/stream-ui/stream-ui.tsx b/packages/core/rsc/stream-ui/stream-ui.tsx index 8cca5cef6623..d358125e9fd9 100644 --- a/packages/core/rsc/stream-ui/stream-ui.tsx +++ b/packages/core/rsc/stream-ui/stream-ui.tsx @@ -13,10 +13,14 @@ import { getValidatedPrompt } from '../../core/prompt/get-validated-prompt'; import { prepareCallSettings } from '../../core/prompt/prepare-call-settings'; import { prepareToolsAndToolChoice } from '../../core/prompt/prepare-tools-and-tool-choice'; import { Prompt } from '../../core/prompt/prompt'; -import { CoreToolChoice } from '../../core/types'; +import { CallWarning, CoreToolChoice, FinishReason } from '../../core/types'; import { retryWithExponentialBackoff } from '../../core/util/retry-with-exponential-backoff'; import { createStreamableUI } from '../streamable'; import { createResolvablePromise } from '../utils'; +import { + TokenUsage, + calculateTokenUsage, +} from '../../core/generate-text/token-usage'; type Streamable = ReactNode | Promise; @@ -84,6 +88,7 @@ export async function streamUI< abortSignal, initial, text, + onFinish, ...settings }: CallSettings & Prompt & { @@ -100,12 +105,42 @@ export async function streamUI< }; /** -The tool choice strategy. Default: 'auto'. + * The tool choice strategy. Default: 'auto'. */ toolChoice?: CoreToolChoice; text?: RenderText; initial?: ReactNode; + /** + * Callback that is called when the LLM response and the final object validation are finished. + */ + onFinish?: (event: { + /** + * The reason why the generation finished. + */ + finishReason: FinishReason; + /** + * The token usage of the generated response. + */ + usage: TokenUsage; + /** + * The final ui node that was generated. + */ + value: ReactNode; + /** + * Warnings from the model provider (e.g. unsupported settings) + */ + warnings?: CallWarning[]; + /** + * Optional raw response data. + */ + rawResponse?: { + /** + * Response headers. + */ + headers?: Record; + }; + }) => Promise | void; }): Promise { // TODO: Remove these errors after the experimental phase. if (typeof model === 'string') { @@ -311,7 +346,13 @@ The tool choice strategy. Default: 'auto'. } case 'finish': { - // Nothing to do here. 
+ onFinish?.({ + finishReason: value.finishReason, + usage: calculateTokenUsage(value.usage), + value: ui.value, + warnings: result.warnings, + rawResponse: result.rawResponse, + }); } } } diff --git a/packages/core/rsc/stream-ui/stream-ui.ui.test.tsx b/packages/core/rsc/stream-ui/stream-ui.ui.test.tsx new file mode 100644 index 000000000000..255b2f772c7d --- /dev/null +++ b/packages/core/rsc/stream-ui/stream-ui.ui.test.tsx @@ -0,0 +1,217 @@ +import { convertArrayToReadableStream } from '@ai-sdk/provider-utils/test'; +import assert from 'node:assert'; +import { z } from 'zod'; +import { MockLanguageModelV1 } from '../../core/test/mock-language-model-v1'; +import { + openaiChatCompletionChunks, + openaiFunctionCallChunks, +} from '../../tests/snapshots/openai-chat'; +import { + DEFAULT_TEST_URL, + createMockServer, +} from '../../tests/utils/mock-server'; +import { streamUI } from './stream-ui'; + +const FUNCTION_CALL_TEST_URL = DEFAULT_TEST_URL + 'mock-func-call'; + +const server = createMockServer([ + { + url: DEFAULT_TEST_URL, + chunks: openaiChatCompletionChunks, + formatChunk: chunk => `data: ${JSON.stringify(chunk)}\n\n`, + suffix: 'data: [DONE]', + }, + { + url: FUNCTION_CALL_TEST_URL, + chunks: openaiFunctionCallChunks, + formatChunk: chunk => `data: ${JSON.stringify(chunk)}\n\n`, + suffix: 'data: [DONE]', + }, +]); + +beforeAll(() => { + server.listen(); +}); + +afterEach(() => { + server.resetHandlers(); +}); + +afterAll(() => { + server.close(); +}); + +async function recursiveResolve(val: any): Promise { + if (val && typeof val === 'object' && typeof val.then === 'function') { + return await recursiveResolve(await val); + } + + if (Array.isArray(val)) { + return await Promise.all(val.map(recursiveResolve)); + } + + if (val && typeof val === 'object') { + const result: any = {}; + for (const key in val) { + result[key] = await recursiveResolve(val[key]); + } + return result; + } + + return val; +} + +async function simulateFlightServerRender(node: React.ReactNode) { + async function traverse(node: React.ReactNode): Promise { + if (!node || typeof node !== 'object' || !('props' in node)) return {}; // only traverse React elements + + // Let's only do one level of promise resolution here. As it's only for testing purposes. + const props = await recursiveResolve({ ...node.props } || {}); + + const { type } = node; + const { children, ...otherProps } = props; + const typeName = typeof type === 'function' ? type.name : String(type); + + return { + type: typeName, + props: otherProps, + children: + typeof children === 'string' + ? children + : Array.isArray(children) + ? children.map(traverse) + : await traverse(children), + }; + } + + return traverse(node); +} + +const mockToolModel = new MockLanguageModelV1({ + doStream: async () => { + return { + stream: convertArrayToReadableStream([ + { + type: 'tool-call', + toolCallType: 'function', + toolCallId: 'call-1', + toolName: 'get_current_weather', + args: `{}`, + }, + { + type: 'finish', + finishReason: 'stop', + logprobs: undefined, + usage: { completionTokens: 10, promptTokens: 3 }, + }, + ]), + rawCall: { rawPrompt: 'prompt', rawSettings: {} }, + }; + }, +}); + +describe('rsc - streamUI()', () => { + it('should emit React Nodes with sync streamUI function', async () => { + const ui = await streamUI({ + model: mockToolModel, + messages: [], + tools: { + get_current_weather: { + description: 'Get the current weather', + parameters: z.object({}), + generate: () => { + return
<div>Weather</div>;
+          },
+        },
+      },
+    });
+
+    const rendered = await simulateFlightServerRender(ui.value);
+    expect(rendered).toMatchSnapshot();
+  });
+
+  it('should emit React Nodes with async streamUI function', async () => {
+    const ui = await streamUI({
+      model: mockToolModel,
+      messages: [],
+      tools: {
+        get_current_weather: {
+          description: 'Get the current weather',
+          parameters: z.object({}),
+          generate: async () => {
+            await new Promise(resolve => setTimeout(resolve, 100));
+            return <div>Weather</div>;
+          },
+        },
+      },
+    });
+
+    const rendered = await simulateFlightServerRender(ui.value);
+    expect(rendered).toMatchSnapshot();
+  });
+
+  it('should emit React Nodes with generator streamUI function', async () => {
+    const ui = await streamUI({
+      model: mockToolModel,
+      messages: [],
+      tools: {
+        get_current_weather: {
+          description: 'Get the current weather',
+          parameters: z.object({}),
+          generate: async function* () {
+            yield <div>Loading...</div>;
+            await new Promise(resolve => setTimeout(resolve, 100));
+            return <div>Weather</div>;
+          },
+        },
+      },
+    });
+
+    const rendered = await simulateFlightServerRender(ui.value);
+    expect(rendered).toMatchSnapshot();
+  });
+});
+
+describe('rsc - streamUI() onFinish callback', () => {
+  let result: Parameters<
+    Required<Parameters<typeof streamUI>[0]>['onFinish']
+  >[0];
+
+  beforeEach(async () => {
+    const ui = await streamUI({
+      model: mockToolModel,
+      messages: [],
+      tools: {
+        get_current_weather: {
+          description: 'Get the current weather',
+          parameters: z.object({}),
+          generate: () => {
+            return 'Weather';
+          },
+        },
+      },
+      onFinish: event => {
+        result = event;
+      },
+    });
+
+    // consume stream
+    await simulateFlightServerRender(ui.value);
+  });
+
+  it('should contain token usage', () => {
+    assert.deepStrictEqual(result.usage, {
+      completionTokens: 10,
+      promptTokens: 3,
+      totalTokens: 13,
+    });
+  });
+
+  it('should contain finish reason', async () => {
+    assert.strictEqual(result.finishReason, 'stop');
+  });
+
+  it('should contain final React node', async () => {
+    expect(result.value).toMatchSnapshot();
+  });
+});
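Taken together, the changes above give `streamUI` callers a single completion hook. The sketch below shows one way an application might consume the new callback to persist usage per request; it is illustrative only, not part of this diff, and `saveTokenUsage` is a hypothetical application-defined helper.

```tsx
// app/actions.tsx — minimal sketch, assuming a hypothetical saveTokenUsage helper
'use server';

import { openai } from '@ai-sdk/openai';
import { streamUI } from 'ai/rsc';
import { saveTokenUsage } from '@/lib/usage'; // hypothetical persistence helper

export async function answer(question: string) {
  const result = await streamUI({
    model: openai('gpt-3.5-turbo'),
    prompt: question,
    text: ({ content }) => <div>{content}</div>,
    // called once the LLM response and all tool `generate` calls have finished
    onFinish: async ({ usage, finishReason, warnings }) => {
      // usage carries promptTokens, completionTokens and totalTokens
      await saveTokenUsage({ ...usage, finishReason });

      if (warnings?.length) {
        console.warn('provider warnings:', warnings);
      }
    },
  });

  return result.value;
}
```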