Skip to content

Commit

Permalink
feat (ai/core): forward abort signal to tools (#3277)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel authored Oct 16, 2024
1 parent 5817da3 commit a23da5b
Show file tree
Hide file tree
Showing 11 changed files with 141 additions and 6 deletions.
5 changes: 5 additions & 0 deletions .changeset/olive-roses-bathe.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

feat (ai/core): forward abort signal to tools
29 changes: 29 additions & 0 deletions content/docs/03-ai-sdk-core/15-tools-and-tool-calling.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,35 @@ const result = await generateText({
});
```

## Abort Signals

The abort signals from `generateText` and `streamText` are forwarded to the tool execution.
You can access them in the second parameter of the `execute` function and e.g. abort long-running computations or forward them to fetch calls inside tools.

```ts highlight="12,15"
import { z } from 'zod';
import { generateText, tool } from 'ai';

const result = await generateText({
model: yourModel,
tools: {
weather: tool({
description: 'Get the weather in a location',
parameters: z.object({
location: z.string().describe('The location to get the weather for'),
}),
execute: async ({ location }, { abortSignal }) => {
return fetch(
`https://api.weatherapi.com/v1/current.json?q=${location}`,
{ signal: abortSignal }, // forward the abort signal
);
},
}),
},
prompt: 'What is the weather in San Francisco?',
});
```

## Prompt Engineering with Tools

When you create prompts that include tools, getting good results can be tricky as the number and complexity of your tools increases.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ To see `generateText` in action, check out [these examples](#examples).
{
name: 'execute',
isOptional: true,
type: 'async (parameters) => any',
type: 'async (parameters: T, options: { abortSignal: AbortSignal }) => JSONValue',
description:
'An async function that is called with the arguments from the tool call and produces a result. If not provided, the tool will not be executed automatically.',
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ To see `streamText` in action, check out [these examples](#examples).
{
name: 'execute',
isOptional: true,
type: 'async (parameters) => any',
type: 'async (parameters: T, options: { abortSignal: AbortSignal }) => JSONValue',
description:
'An async function that is called with the arguments from the tool call and produces a result. If not provided, the tool will not be executed automatically.',
},
Expand Down
2 changes: 1 addition & 1 deletion content/docs/07-reference/01-ai-sdk-core/20-tool.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export const weatherTool = tool({
{
name: 'execute',
isOptional: true,
type: 'async (parameters) => any',
type: 'async (parameters: T, options: { abortSignal: AbortSignal }) => JSONValue',
description:
'An async function that is called with the arguments from the tool call and produces a result. If not provided, the tool will not be executed automatically.',
},
Expand Down
41 changes: 41 additions & 0 deletions packages/ai/core/generate-text/generate-text.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,47 @@ describe('options.providerMetadata', () => {
});
});

describe('options.abortSignal', () => {
it('should forward abort signal to tool execution', async () => {
const abortController = new AbortController();
const toolExecuteMock = vi.fn().mockResolvedValue('tool result');

const generateTextPromise = generateText({
model: new MockLanguageModelV1({
doGenerate: async () => ({
...dummyResponseValues,
toolCalls: [
{
toolCallType: 'function',
toolCallId: 'call-1',
toolName: 'tool1',
args: `{ "value": "value" }`,
},
],
}),
}),
tools: {
tool1: {
parameters: z.object({ value: z.string() }),
execute: toolExecuteMock,
},
},
prompt: 'test-input',
abortSignal: abortController.signal,
});

// Abort the operation
abortController.abort();

await generateTextPromise;

expect(toolExecuteMock).toHaveBeenCalledWith(
{ value: 'value' },
{ abortSignal: abortController.signal },
);
});
});

describe('telemetry', () => {
let tracer: MockTracer;

Expand Down
5 changes: 4 additions & 1 deletion packages/ai/core/generate-text/generate-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ functionality that can be fully encapsulated in the provider.
tools,
tracer,
telemetry,
abortSignal,
});

// token usage:
Expand Down Expand Up @@ -537,11 +538,13 @@ async function executeTools<TOOLS extends Record<string, CoreTool>>({
tools,
tracer,
telemetry,
abortSignal,
}: {
toolCalls: ToToolCallArray<TOOLS>;
tools: TOOLS;
tracer: Tracer;
telemetry: TelemetrySettings | undefined;
abortSignal: AbortSignal | undefined;
}): Promise<ToToolResultArray<TOOLS>> {
const toolResults = await Promise.all(
toolCalls.map(async toolCall => {
Expand Down Expand Up @@ -569,7 +572,7 @@ async function executeTools<TOOLS extends Record<string, CoreTool>>({
}),
tracer,
fn: async span => {
const result = await tool.execute!(toolCall.args);
const result = await tool.execute!(toolCall.args, { abortSignal });

try {
span.setAttributes(
Expand Down
4 changes: 3 additions & 1 deletion packages/ai/core/generate-text/run-tools-transformation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,14 @@ export function runToolsTransformation<TOOLS extends Record<string, CoreTool>>({
toolCallStreaming,
tracer,
telemetry,
abortSignal,
}: {
tools: TOOLS | undefined;
generatorStream: ReadableStream<LanguageModelV1StreamPart>;
toolCallStreaming: boolean;
tracer: Tracer;
telemetry: TelemetrySettings | undefined;
abortSignal: AbortSignal | undefined;
}): ReadableStream<SingleRequestTextStreamPart<TOOLS>> {
let canClose = false;
const outstandingToolCalls = new Set<string>();
Expand Down Expand Up @@ -195,7 +197,7 @@ export function runToolsTransformation<TOOLS extends Record<string, CoreTool>>({
}),
tracer,
fn: async span =>
tool.execute!(toolCall.args).then(
tool.execute!(toolCall.args, { abortSignal }).then(
(result: any) => {
toolResultsStreamController!.enqueue({
...toolCall,
Expand Down
46 changes: 46 additions & 0 deletions packages/ai/core/generate-text/stream-text.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2686,6 +2686,52 @@ describe('options.providerMetadata', () => {
});
});

describe('options.abortSignal', () => {
it('should forward abort signal to tool execution during streaming', async () => {
const abortController = new AbortController();
const toolExecuteMock = vi.fn().mockResolvedValue('tool result');

const result = await streamText({
model: new MockLanguageModelV1({
doStream: async () => ({
stream: convertArrayToReadableStream([
{
type: 'tool-call',
toolCallType: 'function',
toolCallId: 'call-1',
toolName: 'tool1',
args: `{ "value": "value" }`,
},
{
type: 'finish',
finishReason: 'stop',
usage: { promptTokens: 10, completionTokens: 20 },
},
]),
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
}),
}),
tools: {
tool1: {
parameters: z.object({ value: z.string() }),
execute: toolExecuteMock,
},
},
prompt: 'test-input',
abortSignal: abortController.signal,
});

convertAsyncIterableToArray(result.fullStream);

abortController.abort();

expect(toolExecuteMock).toHaveBeenCalledWith(
{ value: 'value' },
{ abortSignal: abortController.signal },
);
});
});

describe('telemetry', () => {
let tracer: MockTracer;

Expand Down
1 change: 1 addition & 0 deletions packages/ai/core/generate-text/stream-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ need to be added separately.
toolCallStreaming,
tracer,
telemetry,
abortSignal,
}),
warnings,
rawResponse,
Expand Down
10 changes: 9 additions & 1 deletion packages/ai/core/tool/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,16 @@ Use descriptions to make the input understandable for the language model.
/**
An async function that is called with the arguments from the tool call and produces a result.
If not provided, the tool will not be executed automatically.
@args is the input of the tool call.
@options.abortSignal is a signal that can be used to abort the tool call.
*/
execute?: (args: inferParameters<PARAMETERS>) => PromiseLike<RESULT>;
execute?: (
args: inferParameters<PARAMETERS>,
options: {
abortSignal?: AbortSignal;
},
) => PromiseLike<RESULT>;
}

/**
Expand Down

0 comments on commit a23da5b

Please sign in to comment.