Skip to content

Commit

Permalink
fix (ai/core): only append assistant response at the end when there i…
Browse files Browse the repository at this point in the history
…s a final user message (#4623)
  • Loading branch information
lgrammel authored Jan 31, 2025
1 parent d99c7fb commit ca89615
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 14 deletions.
5 changes: 5 additions & 0 deletions .changeset/six-jokes-join.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

fix (ai/core): only append assistant response at the end when there is a final user message
171 changes: 159 additions & 12 deletions packages/ai/core/prompt/convert-to-core-messages.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ describe('assistant message', () => {
{
role: 'assistant',
content: [
{
type: 'text',
text: 'Let me calculate that for you.',
},
{
type: 'tool-call',
toolCallId: 'call1',
Expand All @@ -249,10 +253,6 @@ describe('assistant message', () => {
},
],
},
{
role: 'assistant',
content: 'Let me calculate that for you.',
},
]);
});

Expand Down Expand Up @@ -290,6 +290,10 @@ describe('assistant message', () => {
{
role: 'assistant',
content: [
{
type: 'text',
text: 'Let me calculate that for you.',
},
{
type: 'tool-call',
toolCallId: 'call1',
Expand All @@ -310,10 +314,6 @@ describe('assistant message', () => {
},
],
},
{
role: 'assistant',
content: 'Let me calculate that for you.',
},
]);
});

Expand Down Expand Up @@ -400,6 +400,10 @@ describe('assistant message', () => {
{
role: 'assistant',
content: [
{
type: 'text',
text: 'response',
},
{
type: 'tool-call',
toolCallId: 'call-1',
Expand Down Expand Up @@ -475,10 +479,6 @@ describe('assistant message', () => {
},
],
},
{
role: 'assistant',
content: 'response',
},
]);
});
});
Expand Down Expand Up @@ -525,6 +525,153 @@ describe('multiple messages', () => {
},
]);
});

it('should handle conversation with multiple tool invocations and user message at the end', () => {
const tools = {
screenshot: tool({
parameters: z.object({ value: z.string() }),
execute: async () => 'imgbase64',
}),
};

const result = convertToCoreMessages(
[
{
role: 'assistant',
content: 'response',
toolInvocations: [
{
state: 'result',
toolCallId: 'call-1',
toolName: 'screenshot',
args: { value: 'value-1' },
result: 'result-1',
step: 0,
},
{
state: 'result',
toolCallId: 'call-2',
toolName: 'screenshot',
args: { value: 'value-2' },
result: 'result-2',
step: 1,
},

{
state: 'result',
toolCallId: 'call-3',
toolName: 'screenshot',
args: { value: 'value-3' },
result: 'result-3',
step: 1,
},
{
state: 'result',
toolCallId: 'call-4',
toolName: 'screenshot',
args: { value: 'value-4' },
result: 'result-4',
step: 2,
},
],
},
{
role: 'user',
content: 'Thanks!',
},
],
{ tools }, // separate tools to ensure that types are inferred correctly
);

expect(result).toEqual([
{
role: 'assistant',
content: [
{
type: 'tool-call',
toolCallId: 'call-1',
toolName: 'screenshot',
args: { value: 'value-1' },
},
],
},
{
role: 'tool',
content: [
{
type: 'tool-result',
toolCallId: 'call-1',
toolName: 'screenshot',
result: 'result-1',
},
],
},
{
role: 'assistant',
content: [
{
type: 'tool-call',
toolCallId: 'call-2',
toolName: 'screenshot',
args: { value: 'value-2' },
},
{
type: 'tool-call',
toolCallId: 'call-3',
toolName: 'screenshot',
args: { value: 'value-3' },
},
],
},
{
role: 'tool',
content: [
{
type: 'tool-result',
toolCallId: 'call-2',
toolName: 'screenshot',
result: 'result-2',
},
{
type: 'tool-result',
toolCallId: 'call-3',
toolName: 'screenshot',
result: 'result-3',
},
],
},
{
role: 'assistant',
content: [
{
type: 'tool-call',
toolCallId: 'call-4',
toolName: 'screenshot',
args: { value: 'value-4' },
},
],
},
{
role: 'tool',
content: [
{
type: 'tool-result',
toolCallId: 'call-4',
toolName: 'screenshot',
result: 'result-4',
},
],
},
{
role: 'assistant',
content: 'response',
},
{
role: 'user',
content: 'Thanks!',
},
]);
});
});

describe('error handling', () => {
Expand Down
9 changes: 7 additions & 2 deletions packages/ai/core/prompt/convert-to-core-messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ export function convertToCoreMessages<TOOLS extends ToolSet = never>(
const tools = options?.tools ?? ({} as TOOLS);
const coreMessages: CoreMessage[] = [];

for (const message of messages) {
for (let i = 0; i < messages.length; i++) {
const message = messages[i];
const isLastMessage = i === messages.length - 1;
const { role, content, toolInvocations, experimental_attachments } =
message;

Expand Down Expand Up @@ -64,6 +66,9 @@ export function convertToCoreMessages<TOOLS extends ToolSet = never>(
coreMessages.push({
role: 'assistant',
content: [
...(isLastMessage && content && i === 0
? [{ type: 'text' as const, text: content }]
: []),
...stepInvocations.map(
({ toolCallId, toolName, args }): ToolCallPart => ({
type: 'tool-call' as const,
Expand Down Expand Up @@ -110,7 +115,7 @@ export function convertToCoreMessages<TOOLS extends ToolSet = never>(
});
}

if (content) {
if (content && !isLastMessage) {
coreMessages.push({ role: 'assistant', content });
}

Expand Down

0 comments on commit ca89615

Please sign in to comment.