Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat (ai/core): forward abort signal to tools #3277

Merged
merged 8 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/olive-roses-bathe.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

feat (ai/core): forward abort signal to tools
29 changes: 29 additions & 0 deletions content/docs/03-ai-sdk-core/15-tools-and-tool-calling.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,35 @@ const result = await generateText({
});
```

## Abort Signals

The abort signals from `generateText` and `streamText` are forwarded to the tool execution.
You can access them in the second parameter of the `execute` function and e.g. abort long-running computations or forward them to fetch calls inside tools.

```ts highlight="12,15"
import { z } from 'zod';
import { generateText, tool } from 'ai';

const result = await generateText({
model: yourModel,
tools: {
weather: tool({
description: 'Get the weather in a location',
parameters: z.object({
location: z.string().describe('The location to get the weather for'),
}),
execute: async ({ location }, { abortSignal }) => {
return fetch(
`https://api.weatherapi.com/v1/current.json?q=${location}`,
{ signal: abortSignal }, // forward the abort signal
);
},
}),
},
prompt: 'What is the weather in San Francisco?',
});
```

## Prompt Engineering with Tools

When you create prompts that include tools, getting good results can be tricky as the number and complexity of your tools increases.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ To see `generateText` in action, check out [these examples](#examples).
{
name: 'execute',
isOptional: true,
type: 'async (parameters) => any',
type: 'async (parameters: T, options: { abortSignal: AbortSignal }) => JSONValue',
description:
'An async function that is called with the arguments from the tool call and produces a result. If not provided, the tool will not be executed automatically.',
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ To see `streamText` in action, check out [these examples](#examples).
{
name: 'execute',
isOptional: true,
type: 'async (parameters) => any',
type: 'async (parameters: T, options: { abortSignal: AbortSignal }) => JSONValue',
description:
'An async function that is called with the arguments from the tool call and produces a result. If not provided, the tool will not be executed automatically.',
},
Expand Down
2 changes: 1 addition & 1 deletion content/docs/07-reference/01-ai-sdk-core/20-tool.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export const weatherTool = tool({
{
name: 'execute',
isOptional: true,
type: 'async (parameters) => any',
type: 'async (parameters: T, options: { abortSignal: AbortSignal }) => JSONValue',
description:
'An async function that is called with the arguments from the tool call and produces a result. If not provided, the tool will not be executed automatically.',
},
Expand Down
41 changes: 41 additions & 0 deletions packages/ai/core/generate-text/generate-text.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,47 @@ describe('options.providerMetadata', () => {
});
});

describe('options.abortSignal', () => {
it('should forward abort signal to tool execution', async () => {
const abortController = new AbortController();
const toolExecuteMock = vi.fn().mockResolvedValue('tool result');

const generateTextPromise = generateText({
model: new MockLanguageModelV1({
doGenerate: async () => ({
...dummyResponseValues,
toolCalls: [
{
toolCallType: 'function',
toolCallId: 'call-1',
toolName: 'tool1',
args: `{ "value": "value" }`,
},
],
}),
}),
tools: {
tool1: {
parameters: z.object({ value: z.string() }),
execute: toolExecuteMock,
},
},
prompt: 'test-input',
abortSignal: abortController.signal,
});

// Abort the operation
abortController.abort();

await generateTextPromise;

expect(toolExecuteMock).toHaveBeenCalledWith(
{ value: 'value' },
{ abortSignal: abortController.signal },
);
});
});

describe('telemetry', () => {
let tracer: MockTracer;

Expand Down
5 changes: 4 additions & 1 deletion packages/ai/core/generate-text/generate-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ functionality that can be fully encapsulated in the provider.
tools,
tracer,
telemetry,
abortSignal,
});

// token usage:
Expand Down Expand Up @@ -537,11 +538,13 @@ async function executeTools<TOOLS extends Record<string, CoreTool>>({
tools,
tracer,
telemetry,
abortSignal,
}: {
toolCalls: ToToolCallArray<TOOLS>;
tools: TOOLS;
tracer: Tracer;
telemetry: TelemetrySettings | undefined;
abortSignal: AbortSignal | undefined;
}): Promise<ToToolResultArray<TOOLS>> {
const toolResults = await Promise.all(
toolCalls.map(async toolCall => {
Expand Down Expand Up @@ -569,7 +572,7 @@ async function executeTools<TOOLS extends Record<string, CoreTool>>({
}),
tracer,
fn: async span => {
const result = await tool.execute!(toolCall.args);
const result = await tool.execute!(toolCall.args, { abortSignal });

try {
span.setAttributes(
Expand Down
4 changes: 3 additions & 1 deletion packages/ai/core/generate-text/run-tools-transformation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,14 @@ export function runToolsTransformation<TOOLS extends Record<string, CoreTool>>({
toolCallStreaming,
tracer,
telemetry,
abortSignal,
}: {
tools: TOOLS | undefined;
generatorStream: ReadableStream<LanguageModelV1StreamPart>;
toolCallStreaming: boolean;
tracer: Tracer;
telemetry: TelemetrySettings | undefined;
abortSignal: AbortSignal | undefined;
}): ReadableStream<SingleRequestTextStreamPart<TOOLS>> {
let canClose = false;
const outstandingToolCalls = new Set<string>();
Expand Down Expand Up @@ -195,7 +197,7 @@ export function runToolsTransformation<TOOLS extends Record<string, CoreTool>>({
}),
tracer,
fn: async span =>
tool.execute!(toolCall.args).then(
tool.execute!(toolCall.args, { abortSignal }).then(
(result: any) => {
toolResultsStreamController!.enqueue({
...toolCall,
Expand Down
46 changes: 46 additions & 0 deletions packages/ai/core/generate-text/stream-text.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2686,6 +2686,52 @@ describe('options.providerMetadata', () => {
});
});

describe('options.abortSignal', () => {
it('should forward abort signal to tool execution during streaming', async () => {
const abortController = new AbortController();
const toolExecuteMock = vi.fn().mockResolvedValue('tool result');

const result = await streamText({
model: new MockLanguageModelV1({
doStream: async () => ({
stream: convertArrayToReadableStream([
{
type: 'tool-call',
toolCallType: 'function',
toolCallId: 'call-1',
toolName: 'tool1',
args: `{ "value": "value" }`,
},
{
type: 'finish',
finishReason: 'stop',
usage: { promptTokens: 10, completionTokens: 20 },
},
]),
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
}),
}),
tools: {
tool1: {
parameters: z.object({ value: z.string() }),
execute: toolExecuteMock,
},
},
prompt: 'test-input',
abortSignal: abortController.signal,
});

convertAsyncIterableToArray(result.fullStream);

abortController.abort();

expect(toolExecuteMock).toHaveBeenCalledWith(
{ value: 'value' },
{ abortSignal: abortController.signal },
);
});
});

describe('telemetry', () => {
let tracer: MockTracer;

Expand Down
1 change: 1 addition & 0 deletions packages/ai/core/generate-text/stream-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ need to be added separately.
toolCallStreaming,
tracer,
telemetry,
abortSignal,
}),
warnings,
rawResponse,
Expand Down
10 changes: 9 additions & 1 deletion packages/ai/core/tool/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,16 @@ Use descriptions to make the input understandable for the language model.
/**
An async function that is called with the arguments from the tool call and produces a result.
If not provided, the tool will not be executed automatically.

@args is the input of the tool call.
@options.abortSignal is a signal that can be used to abort the tool call.
*/
execute?: (args: inferParameters<PARAMETERS>) => PromiseLike<RESULT>;
execute?: (
args: inferParameters<PARAMETERS>,
options: {
abortSignal?: AbortSignal;
},
) => PromiseLike<RESULT>;
}

/**
Expand Down
Loading