feat (ai/core): add onStepFinish callback to streamText (#3019)
lgrammel authored Sep 16, 2024
1 parent 255457d commit 83da52c
Showing 8 changed files with 589 additions and 412 deletions.
5 changes: 5 additions & 0 deletions .changeset/lemon-pandas-know.md
@@ -0,0 +1,5 @@
---
'ai': patch
---

feat (ai/core): add onStepFinish callback to streamText
24 changes: 22 additions & 2 deletions content/docs/03-ai-sdk-core/05-generating-text.mdx
@@ -118,17 +118,37 @@ const result = await streamText({
});
```

### `onStepFinish` callback

When using `streamText`, you can provide an `onStepFinish` callback that is triggered when a step is finished,
i.e. when all text deltas, tool calls, and tool results for the step are available.
When there are multiple steps, the callback is triggered once for each step.

```tsx highlight="7-9"
import { streamText } from 'ai';

const result = await streamText({
  model: yourModel,
  prompt: 'Invent a new holiday and describe its traditions.',
  maxSteps: 5, // more than one step
  onStepFinish({ text, toolCalls, toolResults, finishReason, usage }) {
    // your own logic, e.g. for saving the chat history or recording usage
  },
});
```
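Because `onStepFinish` fires once per step, the per-step `usage` values can be folded into a running total over a multi-step run. A minimal sketch, assuming the `TokenUsage` shape shown in the reference docs (`promptTokens`, `completionTokens`, `totalTokens`); the `addUsage` helper is illustrative and not part of the SDK:

```typescript
type TokenUsage = {
  promptTokens: number;
  completionTokens: number;
  totalTokens: number;
};

// Illustrative helper: fold one step's usage into a running total.
function addUsage(total: TokenUsage, step: TokenUsage): TokenUsage {
  return {
    promptTokens: total.promptTokens + step.promptTokens,
    completionTokens: total.completionTokens + step.completionTokens,
    totalTokens: total.totalTokens + step.totalTokens,
  };
}

let usageTotal: TokenUsage = {
  promptTokens: 0,
  completionTokens: 0,
  totalTokens: 0,
};

// Inside the streamText options:
//   onStepFinish({ usage }) {
//     usageTotal = addUsage(usageTotal, usage);
//   },
```

The accumulated total should line up with the combined usage that `onFinish` reports after the last step.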

### `onFinish` callback

-When using `streamText`, you can provide an `onFinish` callback that is triggered when the model finishes generating the response and all tool executions.
+When using `streamText`, you can provide an `onFinish` callback that is triggered when all steps are finished.
It contains the text and tool calls from the last step, the combined usage of all steps, and an array of all steps.

```tsx highlight="6-8"
import { streamText } from 'ai';

const result = await streamText({
  model: yourModel,
  prompt: 'Invent a new holiday and describe its traditions.',
-  onFinish({ text, toolCalls, toolResults, finishReason, usage }) {
+  onFinish({ text, finishReason, usage, steps }) {
    // your own logic, e.g. for saving the chat history or recording usage
  },
});
105 changes: 105 additions & 0 deletions content/docs/07-reference/ai-sdk-core/02-stream-text.mdx
@@ -566,6 +566,111 @@ To see `streamText` in action, check out [these examples](#examples).
        },
      ],
    },
    {
      name: 'onStepFinish',
      type: '(result: onStepFinishResult) => Promise<void> | void',
      isOptional: true,
      description: 'Callback that is called when a step is finished.',
      properties: [
        {
          type: 'onStepFinishResult',
          parameters: [
            {
              name: 'finishReason',
              type: '"stop" | "length" | "content-filter" | "tool-calls" | "error" | "other" | "unknown"',
              description:
                'The reason the model finished generating the text for the step.',
            },
            {
              name: 'usage',
              type: 'TokenUsage',
              description: 'The token usage of the step.',
              properties: [
                {
                  type: 'TokenUsage',
                  parameters: [
                    {
                      name: 'promptTokens',
                      type: 'number',
                      description: 'The total number of tokens in the prompt.',
                    },
                    {
                      name: 'completionTokens',
                      type: 'number',
                      description:
                        'The total number of tokens in the completion.',
                    },
                    {
                      name: 'totalTokens',
                      type: 'number',
                      description: 'The total number of tokens generated.',
                    },
                  ],
                },
              ],
            },
            {
              name: 'text',
              type: 'string',
              description: 'The full text that has been generated.',
            },
            {
              name: 'toolCalls',
              type: 'ToolCall[]',
              description: 'The tool calls that have been executed.',
            },
            {
              name: 'toolResults',
              type: 'ToolResult[]',
              description: 'The tool results that have been generated.',
            },
            {
              name: 'warnings',
              type: 'Warning[] | undefined',
              description:
                'Warnings from the model provider (e.g. unsupported settings).',
            },
            {
              name: 'response',
              type: 'Response',
              optional: true,
              description: 'Response metadata.',
              properties: [
                {
                  type: 'Response',
                  parameters: [
                    {
                      name: 'id',
                      type: 'string',
                      description:
                        'The response identifier. The AI SDK uses the ID from the provider response when available, and generates an ID otherwise.',
                    },
                    {
                      name: 'model',
                      type: 'string',
                      description:
                        'The model that was used to generate the response. The AI SDK uses the response model from the provider response when available, and the model from the function call otherwise.',
                    },
                    {
                      name: 'timestamp',
                      type: 'Date',
                      description:
                        'The timestamp of the response. The AI SDK uses the response timestamp from the provider response when available, and creates a timestamp otherwise.',
                    },
                    {
                      name: 'headers',
                      optional: true,
                      type: 'Record<string, string>',
                      description: 'Optional response headers.',
                    },
                  ],
                },
              ],
            },
          ],
        },
      ],
    },
    {
      name: 'onFinish',
      type: '(result: OnFinishResult) => Promise<void> | void',
1 change: 1 addition & 0 deletions examples/ai-core/src/stream-text/openai-on-finish-steps.ts
@@ -22,6 +22,7 @@ async function main() {
    prompt: 'What is the current weather in San Francisco?',
  });

  // consume the text stream
  for await (const textPart of result.textStream) {
  }
}
30 changes: 30 additions & 0 deletions examples/ai-core/src/stream-text/openai-on-step-finish.ts
@@ -0,0 +1,30 @@
import { openai } from '@ai-sdk/openai';
import { streamText, tool } from 'ai';
import 'dotenv/config';
import { z } from 'zod';

async function main() {
  const result = await streamText({
    model: openai('gpt-4o'),
    tools: {
      weather: tool({
        description: 'Get the weather in a location',
        parameters: z.object({ location: z.string() }),
        execute: async () => ({
          temperature: 72 + Math.floor(Math.random() * 21) - 10,
        }),
      }),
    },
    maxSteps: 5,
    onStepFinish(step) {
      console.log(JSON.stringify(step, null, 2));
    },
    prompt: 'What is the current weather in San Francisco?',
  });

  // consume the text stream
  for await (const textPart of result.textStream) {
  }
}

main().catch(console.error);
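The example above logs each full step object; when only a few fields matter, the callback can project them first. A sketch under the assumption that the step argument carries the `finishReason`, `usage`, and `toolCalls` fields of `onStepFinishResult`; `summarizeStep` is a hypothetical helper, not an SDK API:

```typescript
// Hypothetical projection of a step result; the field names mirror the
// onStepFinishResult shape from the reference docs above.
type StepSummary = {
  finishReason: string;
  totalTokens: number;
  toolNames: string[];
};

function summarizeStep(step: {
  finishReason: string;
  usage: { totalTokens: number };
  toolCalls: { toolName: string }[];
}): StepSummary {
  return {
    finishReason: step.finishReason,
    totalTokens: step.usage.totalTokens,
    toolNames: step.toolCalls.map(call => call.toolName),
  };
}

// Inside the streamText options:
//   onStepFinish(step) {
//     console.log(summarizeStep(step));
//   },
```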
112 changes: 106 additions & 6 deletions packages/ai/core/generate-text/__snapshots__/stream-text.test.ts.snap
@@ -68,10 +68,11 @@ exports[`multiple stream consumption > should support text stream, ai stream, fu
}
`;

-exports[`options.maxSteps > 2 steps > onFinish should send correct information 1`] = `
+exports[`options.maxSteps > 2 steps > callbacks > onFinish should send correct information 1`] = `
{
"experimental_providerMetadata": undefined,
"finishReason": "stop",
"logprobs": undefined,
"rawResponse": undefined,
"response": {
"headers": undefined,
@@ -154,7 +155,7 @@ exports[`options.maxSteps > 2 steps > onFinish should send correct information 1
}
`;

-exports[`options.maxSteps > 2 steps > should contain all steps 1`] = `
+exports[`options.maxSteps > 2 steps > callbacks > onStepFinish should send correct information 1`] = `
[
{
"finishReason": "tool-calls",
@@ -427,14 +428,81 @@ exports[`options.maxSteps > 2 steps > should record telemetry data for each step
]
`;

exports[`options.onFinish should send correct information 1`] = `
exports[`options.maxSteps > 2 steps > value promises > result.steps should contain all steps 1`] = `
[
{
"finishReason": "tool-calls",
"logprobs": undefined,
"rawResponse": undefined,
"response": {
"id": "id-0",
"modelId": "mock-model-id",
"timestamp": 1970-01-01T00:00:00.000Z,
},
"text": "",
"toolCalls": [
{
"args": {
"value": "value",
},
"toolCallId": "call-1",
"toolName": "tool1",
"type": "tool-call",
},
],
"toolResults": [
{
"args": {
"value": "value",
},
"result": "result1",
"toolCallId": "call-1",
"toolName": "tool1",
"type": "tool-result",
},
],
"usage": {
"completionTokens": 10,
"promptTokens": 3,
"totalTokens": 13,
},
"warnings": undefined,
},
{
"finishReason": "stop",
"logprobs": undefined,
"rawResponse": {
"headers": {
"call": "2",
},
},
"response": {
"id": "id-1",
"modelId": "mock-model-id",
"timestamp": 1970-01-01T00:00:01.000Z,
},
"text": "Hello, world!",
"toolCalls": [],
"toolResults": [],
"usage": {
"completionTokens": 5,
"promptTokens": 1,
"totalTokens": 6,
},
"warnings": undefined,
},
]
`;

exports[`options.onFinish > options.onFinish should send correct information 1`] = `
{
"experimental_providerMetadata": {
"testProvider": {
"testKey": "testValue",
},
},
"finishReason": "stop",
"logprobs": undefined,
"rawResponse": {
"headers": {
"call": "2",
@@ -960,7 +1028,37 @@ exports[`result.fullStream > should use fallback response metadata when response
]
`;

exports[`result.toAIStream > should send tool call and tool result stream parts 1`] = `
exports[`result.toAIStream > should transform textStream through callbacks and data transformers 1`] = `
[
"0:"Hello"
",
"0:", "
",
"0:"world!"
",
"e:{"finishReason":"stop","usage":{"promptTokens":3,"completionTokens":10}}
",
"d:{"finishReason":"stop","usage":{"promptTokens":3,"completionTokens":10}}
",
]
`;

exports[`result.toDataStream > should create a data stream 1`] = `
[
"0:"Hello"
",
"0:", "
",
"0:"world!"
",
"e:{"finishReason":"stop","usage":{"promptTokens":3,"completionTokens":10}}
",
"d:{"finishReason":"stop","usage":{"promptTokens":3,"completionTokens":10}}
",
]
`;

exports[`result.toDataStream > should send tool call and tool result stream parts 1`] = `
[
"9:{"toolCallId":"call-1","toolName":"tool1","args":{"value":"value"}}
",
@@ -973,7 +1071,7 @@ exports[`result.toAIStream > should send tool call and tool result stream parts
]
`;

-exports[`result.toAIStream > should send tool call, tool call stream start, tool call deltas, and tool result stream parts when tool call delta flag is enabled 1`] = `
+exports[`result.toDataStream > should send tool call, tool call stream start, tool call deltas, and tool result stream parts when tool call delta flag is enabled 1`] = `
[
"b:{"toolCallId":"call-1","toolName":"tool1"}
",
Expand All @@ -992,8 +1090,10 @@ exports[`result.toAIStream > should send tool call, tool call stream start, tool
]
`;

exports[`result.toAIStream > should transform textStream through callbacks and data transformers 1`] = `
exports[`result.toDataStream > should support merging with existing stream data 1`] = `
[
"2:["stream-data-value"]
",
"0:"Hello"
",
"0:", "