diff --git a/.changeset/plenty-wombats-fry.md b/.changeset/plenty-wombats-fry.md new file mode 100644 index 0000000000..9908e3e671 --- /dev/null +++ b/.changeset/plenty-wombats-fry.md @@ -0,0 +1,5 @@ +--- +'@aws-amplify/ai-constructs': minor +--- + +Use message history instead of event payload for conversational route diff --git a/packages/ai-constructs/API.md b/packages/ai-constructs/API.md index 95ab27f8e1..8141429128 100644 --- a/packages/ai-constructs/API.md +++ b/packages/ai-constructs/API.md @@ -91,7 +91,14 @@ type ConversationTurnEvent = { authorization: string; }; }; - messages: Array; + messages?: Array; + messageHistoryQuery: { + getQueryName: string; + getQueryInputTypeName: string; + listQueryName: string; + listQueryInputTypeName: string; + listQueryLimit?: number; + }; toolsConfiguration?: { dataTools?: Array { const commonEvent: Readonly = { conversationId: '', currentMessageId: '', graphqlApiEndpoint: '', - messages: [ - { - role: 'user', - content: [ - { - text: 'event message', - }, - ], - }, - ], + messageHistoryQuery: { + getQueryName: '', + getQueryInputTypeName: '', + listQueryName: '', + listQueryInputTypeName: '', + }, modelConfiguration: { modelId: 'testModelId', systemPrompt: 'testSystemPrompt', @@ -46,6 +48,27 @@ void describe('Bedrock converse adapter', () => { }, }; + const messages: Array = [ + { + role: 'user', + content: [ + { + text: 'event message', + }, + ], + }, + ]; + const messageHistoryRetriever = new ConversationMessageHistoryRetriever( + commonEvent + ); + const messageHistoryRetrieverMockGetEventMessages = mock.method( + messageHistoryRetriever, + 'getMessageHistory', + () => { + return Promise.resolve(messages); + } + ); + void it('calls bedrock to get conversation response', async () => { const event: ConversationTurnEvent = { ...commonEvent, @@ -78,7 +101,9 @@ void describe('Bedrock converse adapter', () => { const responseContent = await new BedrockConverseAdapter( event, [], - bedrockClient + bedrockClient, + undefined, + messageHistoryRetriever ).askBedrock(); assert.deepStrictEqual( @@ -90,7 +115,7 @@ void describe('Bedrock converse adapter', () => { const bedrockRequest = bedrockClientSendMock.mock.calls[0] .arguments[0] as unknown as ConverseCommand; const expectedBedrockInput: ConverseCommandInput = { - messages: event.messages as Array, + messages: messages as Array, modelId: event.modelConfiguration.modelId, inferenceConfig: event.modelConfiguration.inferenceConfiguration, system: [ @@ -211,7 +236,8 @@ void describe('Bedrock converse adapter', () => { event, [additionalTool], bedrockClient, - eventToolsProvider + eventToolsProvider, + messageHistoryRetriever ).askBedrock(); assert.deepStrictEqual( @@ -251,7 +277,7 @@ void describe('Bedrock converse adapter', () => { const bedrockRequest1 = bedrockClientSendMock.mock.calls[0] .arguments[0] as unknown as ConverseCommand; const expectedBedrockInput1: ConverseCommandInput = { - messages: event.messages as Array, + messages: messages as Array, ...expectedBedrockInputCommonProperties, }; assert.deepStrictEqual(bedrockRequest1.input, expectedBedrockInput1); @@ -264,7 +290,7 @@ void describe('Bedrock converse adapter', () => { ); const expectedBedrockInput2: ConverseCommandInput = { messages: [ - ...(event.messages as Array), + ...(messages as Array), additionalToolUseBedrockResponse.output?.message, { role: 'user', @@ -447,7 +473,9 @@ void describe('Bedrock converse adapter', () => { const responseContent = await new BedrockConverseAdapter( event, [tool], - bedrockClient + bedrockClient, + undefined, + messageHistoryRetriever ).askBedrock(); assert.deepStrictEqual( @@ -543,7 +571,9 @@ void describe('Bedrock converse adapter', () => { const responseContent = await new BedrockConverseAdapter( event, [tool], - bedrockClient + bedrockClient, + undefined, + messageHistoryRetriever ).askBedrock(); assert.deepStrictEqual( @@ -645,7 +675,9 @@ void describe('Bedrock converse adapter', () => { const responseContent = await new BedrockConverseAdapter( event, [additionalTool], - bedrockClient + bedrockClient, + undefined, + messageHistoryRetriever ).askBedrock(); assert.deepStrictEqual(responseContent, [clientToolUseBlock]); @@ -682,7 +714,7 @@ void describe('Bedrock converse adapter', () => { const bedrockRequest = bedrockClientSendMock.mock.calls[0] .arguments[0] as unknown as ConverseCommand; const expectedBedrockInput: ConverseCommandInput = { - messages: event.messages as Array, + messages: messages as Array, ...expectedBedrockInputCommonProperties, }; assert.deepStrictEqual(bedrockRequest.input, expectedBedrockInput); @@ -695,21 +727,27 @@ void describe('Bedrock converse adapter', () => { const fakeImagePayload = randomBytes(32); - event.messages = [ - { - role: 'user', - content: [ + messageHistoryRetrieverMockGetEventMessages.mock.mockImplementationOnce( + () => { + return Promise.resolve([ { - image: { - format: 'png', - source: { - bytes: fakeImagePayload.toString('base64'), + id: '', + conversationId: '', + role: 'user', + content: [ + { + image: { + format: 'png', + source: { + bytes: fakeImagePayload.toString('base64'), + }, + }, }, - }, + ], }, - ], - }, - ]; + ]); + } + ); const bedrockClient = new BedrockRuntimeClient(); const bedrockResponse: ConverseCommandOutput = { @@ -735,7 +773,13 @@ void describe('Bedrock converse adapter', () => { Promise.resolve(bedrockResponse) ); - await new BedrockConverseAdapter(event, [], bedrockClient).askBedrock(); + await new BedrockConverseAdapter( + event, + [], + bedrockClient, + undefined, + messageHistoryRetriever + ).askBedrock(); assert.strictEqual(bedrockClientSendMock.mock.calls.length, 1); const bedrockRequest = bedrockClientSendMock.mock.calls[0] diff --git a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts index 7c7a572387..5ab89d09f9 100644 --- a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts +++ b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts @@ -14,6 +14,7 @@ import { ToolDefinition, } from './types.js'; import { ConversationTurnEventToolsProvider } from './event-tools-provider'; +import { ConversationMessageHistoryRetriever } from './conversation_message_history_retriever'; /** * This class is responsible for interacting with Bedrock Converse API @@ -36,7 +37,10 @@ export class BedrockConverseAdapter { private readonly bedrockClient: BedrockRuntimeClient = new BedrockRuntimeClient( { region: event.modelConfiguration.region } ), - eventToolsProvider = new ConversationTurnEventToolsProvider(event) + eventToolsProvider = new ConversationTurnEventToolsProvider(event), + private readonly messageHistoryRetriever = new ConversationMessageHistoryRetriever( + event + ) ) { this.executableTools = [ ...eventToolsProvider.getEventTools(), @@ -73,7 +77,8 @@ export class BedrockConverseAdapter { const { modelId, systemPrompt, inferenceConfiguration } = this.event.modelConfiguration; - const messages: Array = this.getEventMessagesAsBedrockMessages(); + const messages: Array = + await this.getEventMessagesAsBedrockMessages(); let bedrockResponse: ConverseCommandOutput; do { @@ -124,9 +129,13 @@ export class BedrockConverseAdapter { * 1. Makes a copy so that we don't mutate event. * 2. Decodes Base64 encoded images. */ - private getEventMessagesAsBedrockMessages = (): Array => { + private getEventMessagesAsBedrockMessages = async (): Promise< + Array + > => { const messages: Array = []; - for (const message of this.event.messages) { + const eventMessages = + await this.messageHistoryRetriever.getMessageHistory(); + for (const message of eventMessages) { const messageContent: Array = []; for (const contentElement of message.content) { if (typeof contentElement.image?.source?.bytes === 'string') { diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_message_history_retriever.test.ts b/packages/ai-constructs/src/conversation/runtime/conversation_message_history_retriever.test.ts new file mode 100644 index 0000000000..c53963f3f9 --- /dev/null +++ b/packages/ai-constructs/src/conversation/runtime/conversation_message_history_retriever.test.ts @@ -0,0 +1,413 @@ +import { describe, it, mock } from 'node:test'; +import assert from 'node:assert'; +import { MutationResponseInput } from './conversation_turn_response_sender'; +import { ConversationMessage, ConversationTurnEvent } from './types'; +import { + GraphqlRequest, + GraphqlRequestExecutor, +} from './graphql_request_executor'; +import { + ConversationHistoryMessageItem, + ConversationMessageHistoryRetriever, + GetQueryOutput, + ListQueryOutput, +} from './conversation_message_history_retriever'; + +type TestCase = { + name: string; + mockListResponseMessages: Array; + mockGetCurrentMessage?: ConversationHistoryMessageItem; + expectedMessages: Array; +}; + +void describe('Conversation message history retriever', () => { + const event: ConversationTurnEvent = { + conversationId: 'testConversationId', + currentMessageId: 'testCurrentMessageId', + graphqlApiEndpoint: '', + messageHistoryQuery: { + getQueryName: 'testGetQueryName', + getQueryInputTypeName: 'testGetQueryInputTypeName', + listQueryName: 'testListQueryName', + listQueryInputTypeName: 'testListQueryInputTypeName', + }, + modelConfiguration: { modelId: '', systemPrompt: '' }, + request: { headers: { authorization: '' } }, + responseMutation: { + name: '', + inputTypeName: '', + selectionSet: '', + }, + }; + + const testCases: Array = [ + { + name: 'Retrieves message history that includes current message', + mockListResponseMessages: [ + { + id: 'someNonCurrentMessageId1', + conversationId: event.conversationId, + role: 'user', + content: [ + { + text: 'message1', + }, + ], + }, + { + id: 'someNonCurrentMessageId2', + associatedUserMessageId: 'someNonCurrentMessageId1', + conversationId: event.conversationId, + role: 'assistant', + content: [ + { + text: 'message2', + }, + ], + }, + { + id: event.currentMessageId, + conversationId: event.conversationId, + role: 'user', + content: [ + { + text: 'message3', + }, + ], + }, + ], + expectedMessages: [ + { + role: 'user', + content: [ + { + text: 'message1', + }, + ], + }, + { + role: 'assistant', + content: [ + { + text: 'message2', + }, + ], + }, + { + role: 'user', + content: [ + { + text: 'message3', + }, + ], + }, + ], + }, + { + name: 'Retrieves message history that does not include current message with fallback to get it directly', + mockListResponseMessages: [ + { + id: 'someNonCurrentMessageId1', + conversationId: event.conversationId, + role: 'user', + content: [ + { + text: 'message1', + }, + ], + }, + { + id: 'someNonCurrentMessageId2', + associatedUserMessageId: 'someNonCurrentMessageId1', + conversationId: event.conversationId, + role: 'assistant', + content: [ + { + text: 'message2', + }, + ], + }, + ], + mockGetCurrentMessage: { + id: event.currentMessageId, + conversationId: event.conversationId, + role: 'user', + content: [ + { + text: 'message3', + }, + ], + }, + expectedMessages: [ + { + role: 'user', + content: [ + { + text: 'message1', + }, + ], + }, + { + role: 'assistant', + content: [ + { + text: 'message2', + }, + ], + }, + { + role: 'user', + content: [ + { + text: 'message3', + }, + ], + }, + ], + }, + { + name: 'Re-orders delayed assistant responses', + mockListResponseMessages: [ + // Simulate that two first messages were sent without waiting for assistant response + { + id: 'userMessage1', + conversationId: event.conversationId, + role: 'user', + content: [ + { + text: 'userMessage1', + }, + ], + }, + { + id: 'userMessage2', + conversationId: event.conversationId, + role: 'user', + content: [ + { + text: 'userMessage2', + }, + ], + }, + // also simulate that responses came back out of order + { + id: 'assistantResponse2', + associatedUserMessageId: 'userMessage2', + conversationId: event.conversationId, + role: 'assistant', + content: [ + { + text: 'assistantResponse2', + }, + ], + }, + { + id: 'assistantResponse1', + associatedUserMessageId: 'userMessage1', + conversationId: event.conversationId, + role: 'assistant', + content: [ + { + text: 'assistantResponse1', + }, + ], + }, + { + id: event.currentMessageId, + conversationId: event.conversationId, + role: 'user', + content: [ + { + text: 'currentUserMessage', + }, + ], + }, + ], + expectedMessages: [ + { + role: 'user', + content: [ + { + text: 'userMessage1', + }, + ], + }, + { + role: 'assistant', + content: [ + { + text: 'assistantResponse1', + }, + ], + }, + { + role: 'user', + content: [ + { + text: 'userMessage2', + }, + ], + }, + { + role: 'assistant', + content: [ + { + text: 'assistantResponse2', + }, + ], + }, + { + role: 'user', + content: [ + { + text: 'currentUserMessage', + }, + ], + }, + ], + }, + { + name: 'Skips user message that does not have response yet', + mockListResponseMessages: [ + // Simulate that two first messages were sent without waiting for assistant response + // and none was responded to yet. + { + id: 'userMessage1', + conversationId: event.conversationId, + role: 'user', + content: [ + { + text: 'userMessage1', + }, + ], + }, + { + id: 'userMessage2', + conversationId: event.conversationId, + role: 'user', + content: [ + { + text: 'userMessage2', + }, + ], + }, + { + id: event.currentMessageId, + conversationId: event.conversationId, + role: 'user', + content: [ + { + text: 'currentUserMessage', + }, + ], + }, + ], + expectedMessages: [ + { + role: 'user', + content: [ + { + text: 'currentUserMessage', + }, + ], + }, + ], + }, + { + name: 'Injects aiContext', + mockListResponseMessages: [ + { + id: event.currentMessageId, + conversationId: event.conversationId, + role: 'user', + aiContext: { some: { ai: 'context' } }, + content: [ + { + text: 'currentUserMessage', + }, + ], + }, + ], + expectedMessages: [ + { + role: 'user', + content: [ + { + text: 'currentUserMessage', + }, + { + text: '{"some":{"ai":"context"}}', + }, + ], + }, + ], + }, + ]; + + for (const testCase of testCases) { + void it(testCase.name, async () => { + const graphqlRequestExecutor = new GraphqlRequestExecutor('', ''); + const executeGraphqlMock = mock.method( + graphqlRequestExecutor, + 'executeGraphql', + (request: GraphqlRequest) => { + if (request.query.match(/ListMessages/)) { + const mockListResponse: ListQueryOutput = { + data: { + [event.messageHistoryQuery.listQueryName]: { + // clone array + items: [...testCase.mockListResponseMessages], + }, + }, + }; + return Promise.resolve(mockListResponse); + } + if ( + request.query.match(/GetMessage/) && + testCase.mockGetCurrentMessage + ) { + const mockGetResponse: GetQueryOutput = { + data: { + [event.messageHistoryQuery.getQueryName]: + testCase.mockGetCurrentMessage, + }, + }; + return Promise.resolve(mockGetResponse); + } + throw new Error('The query is not mocked'); + } + ); + + const retriever = new ConversationMessageHistoryRetriever( + event, + graphqlRequestExecutor + ); + const messages = await retriever.getMessageHistory(); + + assert.strictEqual( + executeGraphqlMock.mock.calls.length, + testCase.mockGetCurrentMessage ? 2 : 1 + ); + const listRequest = executeGraphqlMock.mock.calls[0] + .arguments[0] as GraphqlRequest; + assert.match(listRequest.query, /ListMessages/); + assert.deepStrictEqual(listRequest.variables, { + filter: { + conversationId: { + eq: 'testConversationId', + }, + }, + limit: 1000, + }); + if (testCase.mockGetCurrentMessage) { + const getRequest = executeGraphqlMock.mock.calls[1] + .arguments[0] as GraphqlRequest; + assert.match(getRequest.query, /GetMessage/); + assert.deepStrictEqual(getRequest.variables, { + id: event.currentMessageId, + }); + } + assert.deepStrictEqual(messages, testCase.expectedMessages); + }); + } +}); diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_message_history_retriever.ts b/packages/ai-constructs/src/conversation/runtime/conversation_message_history_retriever.ts new file mode 100644 index 0000000000..223f372cf0 --- /dev/null +++ b/packages/ai-constructs/src/conversation/runtime/conversation_message_history_retriever.ts @@ -0,0 +1,239 @@ +import { ConversationMessage, ConversationTurnEvent } from './types'; +import { GraphqlRequestExecutor } from './graphql_request_executor'; + +export type ConversationHistoryMessageItem = ConversationMessage & { + id: string; + conversationId: string; + associatedUserMessageId?: string; + aiContext?: unknown; +}; + +export type GetQueryInput = { + id: string; +}; + +export type GetQueryOutput = { + data: Record; +}; + +export type ListQueryInput = { + filter: { + conversationId: { + eq: string; + }; + }; + limit: number; +}; + +export type ListQueryOutput = { + data: Record< + string, + { + items: Array; + } + >; +}; + +/** + * These are all properties we have to pull. + * Unfortunately, GQL doesn't support wildcards. + * https://github.com/graphql/graphql-spec/issues/127 + */ +const messageItemSelectionSet = ` + id + conversationId + associatedUserMessageId + aiContext + role + content { + text + document { + source { + bytes + } + format + name + } + image { + format + source { + bytes + } + } + toolResult { + content { + document { + format + name + source { + bytes + } + } + image { + format + source { + bytes + } + } + json + text + } + status + toolUseId + } + toolUse { + input + name + toolUseId + } + } +`; + +/** + * This class is responsible for retrieving message history that belongs to conversation turn event. + * It queries AppSync to list messages that belong to conversation. + * Additionally, it looks up a current message in case it's missing in the list due to eventual consistency. + */ +export class ConversationMessageHistoryRetriever { + /** + * Creates conversation message history retriever. + */ + constructor( + private readonly event: ConversationTurnEvent, + private readonly graphqlRequestExecutor = new GraphqlRequestExecutor( + event.graphqlApiEndpoint, + event.request.headers.authorization + ) + ) {} + + getMessageHistory = async (): Promise> => { + if (this.event.messages?.length) { + // This is for backwards compatibility and should be removed with messages property. + return this.event.messages; + } + const messages = await this.listMessages(); + + let currentMessage = messages.find( + (m) => m.id === this.event.currentMessageId + ); + + // This is a fallback in case current message is not available in the message list. + // I.e. in a situation when freshly written message is not yet visible in + // eventually consistent reads. + if (!currentMessage) { + currentMessage = await this.getCurrentMessage(); + messages.push(currentMessage); + } + + // Index assistant messages by corresponding user message. + const assistantMessageByUserMessageId: Map< + string, + ConversationHistoryMessageItem + > = new Map(); + messages.forEach((message) => { + if (message.role === 'assistant' && message.associatedUserMessageId) { + assistantMessageByUserMessageId.set( + message.associatedUserMessageId, + message + ); + } + }); + + // Reconcile history and inject aiContext + return messages.reduce((acc, current) => { + // Bedrock expects that message history is user->assistant->user->assistant->... and so on. + // The chronological order doesn't assure this ordering if there were any concurrent messages sent. + // Therefore, conversation is ordered by user's messages only and corresponding assistant messages are inserted + // into right place regardless of their createdAt value. + // This algorithm assumes that GQL query returns messages sorted by createdAt. + if (current.role === 'assistant') { + // Initially, skip assistant messages, these might be out of chronological order. + return acc; + } + if ( + current.role === 'user' && + !assistantMessageByUserMessageId.has(current.id) && + current.id !== this.event.currentMessageId + ) { + // Skip user messages that didn't get answer from assistant yet. + // These might be still "in-flight", i.e. assistant is still working on them in separate invocation. + // Except current message, we want to process that one. + return acc; + } + const aiContext = current.aiContext; + const content = aiContext + ? [...current.content, { text: JSON.stringify(aiContext) }] + : current.content; + + acc.push({ role: current.role, content }); + + // Find and insert corresponding assistant message. + const correspondingAssistantMessage = assistantMessageByUserMessageId.get( + current.id + ); + if (correspondingAssistantMessage) { + acc.push({ + role: correspondingAssistantMessage.role, + content: correspondingAssistantMessage.content, + }); + } + return acc; + }, [] as Array); + }; + + private getCurrentMessage = + async (): Promise => { + const query = ` + query GetMessage($id: ${this.event.messageHistoryQuery.getQueryInputTypeName}!) { + ${this.event.messageHistoryQuery.getQueryName}(id: $id) { + ${messageItemSelectionSet} + } + } + `; + const variables: GetQueryInput = { + id: this.event.currentMessageId, + }; + + const response = await this.graphqlRequestExecutor.executeGraphql< + GetQueryInput, + GetQueryOutput + >({ + query, + variables, + }); + + return response.data[this.event.messageHistoryQuery.getQueryName]; + }; + + private listMessages = async (): Promise< + Array + > => { + const query = ` + query ListMessages($filter: ${this.event.messageHistoryQuery.listQueryInputTypeName}!, $limit: Int) { + ${this.event.messageHistoryQuery.listQueryName}(filter: $filter, limit: $limit) { + items { + ${messageItemSelectionSet} + } + } + } + `; + const variables: ListQueryInput = { + filter: { + conversationId: { + eq: this.event.conversationId, + }, + }, + limit: this.event.messageHistoryQuery.listQueryLimit ?? 1000, + }; + + const response = await this.graphqlRequestExecutor.executeGraphql< + ListQueryInput, + ListQueryOutput + >({ + query, + variables, + }); + + return response.data[this.event.messageHistoryQuery.listQueryName].items; + }; +} diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts index e9a0664750..8231b8daef 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts @@ -12,6 +12,12 @@ void describe('Conversation turn executor', () => { currentMessageId: 'testCurrentMessageId', graphqlApiEndpoint: '', messages: [], + messageHistoryQuery: { + getQueryName: '', + getQueryInputTypeName: '', + listQueryName: '', + listQueryInputTypeName: '', + }, modelConfiguration: { modelId: '', systemPrompt: '' }, request: { headers: { authorization: '' } }, responseMutation: { diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts index c7201ca1b5..8680592760 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts @@ -1,9 +1,15 @@ import { describe, it, mock } from 'node:test'; import assert from 'node:assert'; -import { text } from 'node:stream/consumers'; -import { ConversationTurnResponseSender } from './conversation_turn_response_sender'; +import { + ConversationTurnResponseSender, + MutationResponseInput, +} from './conversation_turn_response_sender'; import { ConversationTurnEvent } from './types'; import { ContentBlock } from '@aws-sdk/client-bedrock-runtime'; +import { + GraphqlRequest, + GraphqlRequestExecutor, +} from './graphql_request_executor'; void describe('Conversation turn response sender', () => { const event: ConversationTurnEvent = { @@ -11,6 +17,12 @@ void describe('Conversation turn response sender', () => { currentMessageId: 'testCurrentMessageId', graphqlApiEndpoint: 'http://fake.endpoint/', messages: [], + messageHistoryQuery: { + getQueryName: '', + getQueryInputTypeName: '', + listQueryName: '', + listQueryInputTypeName: '', + }, modelConfiguration: { modelId: '', systemPrompt: '' }, request: { headers: { authorization: 'testToken' } }, responseMutation: { @@ -21,13 +33,18 @@ void describe('Conversation turn response sender', () => { }; void it('sends response back to appsync', async () => { - const fetchMock = mock.fn( - fetch, - (): Promise => + const graphqlRequestExecutor = new GraphqlRequestExecutor('', ''); + const executeGraphqlMock = mock.method( + graphqlRequestExecutor, + 'executeGraphql', + () => // Mock successful Appsync response - Promise.resolve(new Response('{}', { status: 200 })) + Promise.resolve() + ); + const sender = new ConversationTurnResponseSender( + event, + graphqlRequestExecutor ); - const sender = new ConversationTurnResponseSender(event, fetchMock); const response: Array = [ { text: 'block1', @@ -36,20 +53,10 @@ void describe('Conversation turn response sender', () => { ]; await sender.sendResponse(response); - assert.strictEqual(fetchMock.mock.calls.length, 1); - const request: Request = fetchMock.mock.calls[0].arguments[0] as Request; - assert.strictEqual(request.url, event.graphqlApiEndpoint); - assert.strictEqual(request.method, 'POST'); - assert.strictEqual( - request.headers.get('Content-Type'), - 'application/graphql' - ); - assert.strictEqual( - request.headers.get('Authorization'), - event.request.headers.authorization - ); - assert.ok(request.body); - assert.deepStrictEqual(JSON.parse(await text(request.body)), { + assert.strictEqual(executeGraphqlMock.mock.calls.length, 1); + const request = executeGraphqlMock.mock.calls[0] + .arguments[0] as GraphqlRequest; + assert.deepStrictEqual(request, { query: '\n' + ' mutation PublishModelResponse($input: testResponseMutationInputTypeName!) {\n' + @@ -73,73 +80,19 @@ void describe('Conversation turn response sender', () => { }); }); - void it('throws if response is not 2xx', async () => { - const fetchMock = mock.fn( - fetch, - (): Promise => - // Mock successful Appsync response - Promise.resolve( - new Response('Body with error', { - status: 400, - headers: { testHeaderKey: 'testHeaderValue' }, - }) - ) - ); - const sender = new ConversationTurnResponseSender(event, fetchMock); - const response: Array = []; - await assert.rejects( - () => sender.sendResponse(response), - (error: Error) => { - assert.strictEqual( - error.message, - // eslint-disable-next-line spellcheck/spell-checker - 'Assistant response mutation request was not successful, response headers={"content-type":"text/plain;charset=UTF-8","testheaderkey":"testHeaderValue"}, body=Body with error' - ); - return true; - } - ); - }); - - void it('throws if graphql returns errors', async () => { - const fetchMock = mock.fn( - fetch, - (): Promise => - // Mock successful Appsync response - Promise.resolve( - new Response( - JSON.stringify({ - errors: ['Some GQL error'], - }), - { - status: 200, - headers: { testHeaderKey: 'testHeaderValue' }, - } - ) - ) - ); - const sender = new ConversationTurnResponseSender(event, fetchMock); - const response: Array = []; - await assert.rejects( - () => sender.sendResponse(response), - (error: Error) => { - assert.strictEqual( - error.message, - // eslint-disable-next-line spellcheck/spell-checker - 'Assistant response mutation request was not successful, response headers={"content-type":"text/plain;charset=UTF-8","testheaderkey":"testHeaderValue"}, body={"errors":["Some GQL error"]}' - ); - return true; - } - ); - }); - void it('serializes tool use input to JSON', async () => { - const fetchMock = mock.fn( - fetch, - (): Promise => + const graphqlRequestExecutor = new GraphqlRequestExecutor('', ''); + const executeGraphqlMock = mock.method( + graphqlRequestExecutor, + 'executeGraphql', + () => // Mock successful Appsync response - Promise.resolve(new Response('{}', { status: 200 })) + Promise.resolve() + ); + const sender = new ConversationTurnResponseSender( + event, + graphqlRequestExecutor ); - const sender = new ConversationTurnResponseSender(event, fetchMock); const toolUseBlock: ContentBlock.ToolUseMember = { toolUse: { name: 'testTool', @@ -152,10 +105,10 @@ void describe('Conversation turn response sender', () => { const response: Array = [toolUseBlock]; await sender.sendResponse(response); - assert.strictEqual(fetchMock.mock.calls.length, 1); - const request: Request = fetchMock.mock.calls[0].arguments[0] as Request; - assert.ok(request.body); - assert.deepStrictEqual(JSON.parse(await text(request.body)), { + assert.strictEqual(executeGraphqlMock.mock.calls.length, 1); + const request = executeGraphqlMock.mock.calls[0] + .arguments[0] as GraphqlRequest; + assert.deepStrictEqual(request, { query: '\n' + ' mutation PublishModelResponse($input: testResponseMutationInputTypeName!) {\n' + diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts index 9fe979aac8..d3bb9accdc 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts @@ -1,7 +1,8 @@ import { ConversationTurnEvent } from './types.js'; import type { ContentBlock } from '@aws-sdk/client-bedrock-runtime'; +import { GraphqlRequestExecutor } from './graphql_request_executor'; -type MutationResponseInput = { +export type MutationResponseInput = { input: { conversationId: string; content: ContentBlock[]; @@ -19,30 +20,21 @@ export class ConversationTurnResponseSender { */ constructor( private readonly event: ConversationTurnEvent, - private readonly _fetch = fetch + private readonly graphqlRequestExecutor = new GraphqlRequestExecutor( + event.graphqlApiEndpoint, + event.request.headers.authorization + ) ) {} sendResponse = async (message: ContentBlock[]) => { - const request = this.createMutationRequest(message); - const res = await this._fetch(request); - const responseHeaders: Record = {}; - res.headers.forEach((value, key) => (responseHeaders[key] = value)); - if (!res.ok) { - const body = await res.text(); - throw new Error( - `Assistant response mutation request was not successful, response headers=${JSON.stringify( - responseHeaders - )}, body=${body}` - ); - } - const body = await res.json(); - if (body && typeof body === 'object' && 'errors' in body) { - throw new Error( - `Assistant response mutation request was not successful, response headers=${JSON.stringify( - responseHeaders - )}, body=${JSON.stringify(body)}` - ); - } + const { query, variables } = this.createMutationRequest(message); + await this.graphqlRequestExecutor.executeGraphql< + MutationResponseInput, + void + >({ + query, + variables, + }); }; private createMutationRequest = (content: ContentBlock[]) => { @@ -70,13 +62,6 @@ export class ConversationTurnResponseSender { associatedUserMessageId: this.event.currentMessageId, }, }; - return new Request(this.event.graphqlApiEndpoint, { - method: 'POST', - headers: { - 'Content-Type': 'application/graphql', - Authorization: this.event.request.headers.authorization, - }, - body: JSON.stringify({ query, variables }), - }); + return { query, variables }; }; } diff --git a/packages/ai-constructs/src/conversation/runtime/event-tools-provider/event_tools_provider.test.ts b/packages/ai-constructs/src/conversation/runtime/event-tools-provider/event_tools_provider.test.ts index 12bea0403e..8b3db55ea7 100644 --- a/packages/ai-constructs/src/conversation/runtime/event-tools-provider/event_tools_provider.test.ts +++ b/packages/ai-constructs/src/conversation/runtime/event-tools-provider/event_tools_provider.test.ts @@ -12,6 +12,12 @@ void describe('events tool provider', () => { currentMessageId: '', graphqlApiEndpoint: '', messages: [], + messageHistoryQuery: { + getQueryName: '', + getQueryInputTypeName: '', + listQueryName: '', + listQueryInputTypeName: '', + }, modelConfiguration: { modelId: '', systemPrompt: '' }, request: { headers: { authorization: '' } }, responseMutation: { @@ -62,6 +68,12 @@ void describe('events tool provider', () => { currentMessageId: '', graphqlApiEndpoint: '', messages: [], + messageHistoryQuery: { + getQueryName: '', + getQueryInputTypeName: '', + listQueryName: '', + listQueryInputTypeName: '', + }, modelConfiguration: { modelId: '', systemPrompt: '' }, request: { headers: { authorization: '' } }, responseMutation: { diff --git a/packages/ai-constructs/src/conversation/runtime/event-tools-provider/graphql_tool.test.ts b/packages/ai-constructs/src/conversation/runtime/event-tools-provider/graphql_tool.test.ts index 9989429ef5..6d556e4a2e 100644 --- a/packages/ai-constructs/src/conversation/runtime/event-tools-provider/graphql_tool.test.ts +++ b/packages/ai-constructs/src/conversation/runtime/event-tools-provider/graphql_tool.test.ts @@ -1,14 +1,20 @@ import { describe, it, mock } from 'node:test'; import assert from 'node:assert'; -import { text } from 'node:stream/consumers'; import { GraphQlTool } from './graphql_tool'; +import { + GraphqlRequest, + GraphqlRequestExecutor, +} from '../graphql_request_executor'; +import { DocumentType } from '@smithy/types'; void describe('GraphQl tool', () => { const graphQlEndpoint = 'http://test.endpoint/'; const query = 'testQuery'; const accessToken = 'testAccessToken'; - const createGraphQlTool = (fetchMock: typeof fetch): GraphQlTool => { + const createGraphQlTool = ( + graphqlRequestExecutor: GraphqlRequestExecutor + ): GraphQlTool => { return new GraphQlTool( 'testName', 'testDescription', @@ -16,7 +22,7 @@ void describe('GraphQl tool', () => { graphQlEndpoint, query, accessToken, - fetchMock + graphqlRequestExecutor ); }; @@ -24,28 +30,21 @@ void describe('GraphQl tool', () => { const testResponse = { test: 'response', }; - const fetchMock = mock.fn( - fetch, - (): Promise => + const graphqlRequestExecutor = new GraphqlRequestExecutor('', ''); + const executeGraphqlMock = mock.method( + graphqlRequestExecutor, + 'executeGraphql', + () => // Mock successful Appsync response - Promise.resolve( - new Response(JSON.stringify(testResponse), { status: 200 }) - ) + Promise.resolve(testResponse) ); - const tool = createGraphQlTool(fetchMock); + const tool = createGraphQlTool(graphqlRequestExecutor); const toolResult = await tool.execute({ test: 'input' }); - assert.strictEqual(fetchMock.mock.calls.length, 1); - const request: Request = fetchMock.mock.calls[0].arguments[0] as Request; - assert.strictEqual(request.url, graphQlEndpoint); - assert.strictEqual(request.method, 'POST'); - assert.strictEqual( - request.headers.get('Content-Type'), - 'application/graphql' - ); - assert.strictEqual(request.headers.get('Authorization'), accessToken); - assert.ok(request.body); - assert.deepStrictEqual(JSON.parse(await text(request.body)), { + assert.strictEqual(executeGraphqlMock.mock.calls.length, 1); + const request = executeGraphqlMock.mock.calls[0] + .arguments[0] as GraphqlRequest; + assert.deepStrictEqual(request, { query: 'testQuery', variables: { test: 'input', @@ -57,61 +56,4 @@ void describe('GraphQl tool', () => { }, }); }); - - void it('throws if response is not 2xx', async () => { - const fetchMock = mock.fn( - fetch, - (): Promise => - // Mock successful Appsync response - Promise.resolve( - new Response('Body with error', { - status: 400, - headers: { testHeaderKey: 'testHeaderValue' }, - }) - ) - ); - const tool = createGraphQlTool(fetchMock); - await assert.rejects( - () => tool.execute({ test: 'input' }), - (error: Error) => { - assert.strictEqual( - error.message, - // eslint-disable-next-line spellcheck/spell-checker - 'GraphQl tool \'testName\' failed, response headers={"content-type":"text/plain;charset=UTF-8","testheaderkey":"testHeaderValue"}, body=Body with error' - ); - return true; - } - ); - }); - - void it('throws if graphql returns errors', async () => { - const fetchMock = mock.fn( - fetch, - (): Promise => - // Mock successful Appsync response - Promise.resolve( - new Response( - JSON.stringify({ - errors: ['Some GQL error'], - }), - { - status: 200, - headers: { testHeaderKey: 'testHeaderValue' }, - } - ) - ) - ); - const tool = createGraphQlTool(fetchMock); - await assert.rejects( - () => tool.execute({ test: 'input' }), - (error: Error) => { - assert.strictEqual( - error.message, - // eslint-disable-next-line spellcheck/spell-checker - 'GraphQl tool \'testName\' failed, response headers={"content-type":"text/plain;charset=UTF-8","testheaderkey":"testHeaderValue"}, body={"errors":["Some GQL error"]}' - ); - return true; - } - ); - }); }); diff --git a/packages/ai-constructs/src/conversation/runtime/event-tools-provider/graphql_tool.ts b/packages/ai-constructs/src/conversation/runtime/event-tools-provider/graphql_tool.ts index a6f9cce949..6173e3ac63 100644 --- a/packages/ai-constructs/src/conversation/runtime/event-tools-provider/graphql_tool.ts +++ b/packages/ai-constructs/src/conversation/runtime/event-tools-provider/graphql_tool.ts @@ -4,6 +4,7 @@ import type { ToolResultContentBlock, } from '@aws-sdk/client-bedrock-runtime'; import { DocumentType } from '@smithy/types'; +import { GraphqlRequestExecutor } from '../graphql_request_executor'; /** * A tool that use GraphQl queries. @@ -16,10 +17,13 @@ export class GraphQlTool implements ExecutableTool { public name: string, public description: string, public inputSchema: ToolInputSchema, - private readonly graphQlEndpoint: string, + readonly graphQlEndpoint: string, private readonly query: string, - private readonly accessToken: string, - private readonly _fetch = fetch + readonly accessToken: string, + private readonly graphqlRequestExecutor = new GraphqlRequestExecutor( + graphQlEndpoint, + accessToken + ) ) {} execute = async ( @@ -29,37 +33,13 @@ export class GraphQlTool implements ExecutableTool { throw Error(`GraphQl tool '${this.name}' requires input to execute.`); } - const options: RequestInit = { - method: 'POST', - headers: { - 'Content-Type': 'application/graphql', - Authorization: this.accessToken, - }, - body: JSON.stringify({ query: this.query, variables: input }), - }; - - const req = new Request(this.graphQlEndpoint, options); - const res = await this._fetch(req); - - const responseHeaders: Record = {}; - res.headers.forEach((value, key) => (responseHeaders[key] = value)); - if (!res.ok) { - const body = await res.text(); - throw new Error( - `GraphQl tool '${this.name}' failed, response headers=${JSON.stringify( - responseHeaders - )}, body=${body}` - ); - } - const body = await res.json(); - if (body && typeof body === 'object' && 'errors' in body) { - throw new Error( - `GraphQl tool '${this.name}' failed, response headers=${JSON.stringify( - responseHeaders - )}, body=${JSON.stringify(body)}` - ); - } - - return { json: body as DocumentType }; + const response = await this.graphqlRequestExecutor.executeGraphql< + DocumentType, + DocumentType + >({ + query: this.query, + variables: input, + }); + return { json: response as DocumentType }; }; } diff --git a/packages/ai-constructs/src/conversation/runtime/graphql_request_executor.test.ts b/packages/ai-constructs/src/conversation/runtime/graphql_request_executor.test.ts new file mode 100644 index 0000000000..17820e1aff --- /dev/null +++ b/packages/ai-constructs/src/conversation/runtime/graphql_request_executor.test.ts @@ -0,0 +1,119 @@ +import { describe, it, mock } from 'node:test'; +import assert from 'node:assert'; +import { text } from 'node:stream/consumers'; +import { GraphqlRequestExecutor } from './graphql_request_executor'; + +void describe('Graphql executor test', () => { + const graphqlEndpoint = 'http://fake.endpoint/'; + const accessToken = 'testToken'; + + void it('sends request to appsync', async () => { + const fetchMock = mock.fn( + fetch, + (): Promise => + // Mock successful Appsync response + Promise.resolve(new Response('{}', { status: 200 })) + ); + const executor = new GraphqlRequestExecutor( + graphqlEndpoint, + accessToken, + fetchMock + ); + const query = 'testQuery'; + const variables = { + testVariableKey: 'testVariableValue', + }; + await executor.executeGraphql({ + query, + variables, + }); + + assert.strictEqual(fetchMock.mock.calls.length, 1); + const request: Request = fetchMock.mock.calls[0].arguments[0] as Request; + assert.strictEqual(request.url, graphqlEndpoint); + assert.strictEqual(request.method, 'POST'); + assert.strictEqual( + request.headers.get('Content-Type'), + 'application/graphql' + ); + assert.strictEqual(request.headers.get('Authorization'), accessToken); + assert.ok(request.body); + assert.deepStrictEqual(JSON.parse(await text(request.body)), { + query: 'testQuery', + variables: { testVariableKey: 'testVariableValue' }, + }); + }); + + void it('throws if response is not 2xx', async () => { + const fetchMock = mock.fn( + fetch, + (): Promise => + // Mock successful Appsync response + Promise.resolve( + new Response('Body with error', { + status: 400, + headers: { testHeaderKey: 'testHeaderValue' }, + }) + ) + ); + const executor = new GraphqlRequestExecutor( + graphqlEndpoint, + accessToken, + fetchMock + ); + const query = 'testQuery'; + const variables = { + testVariableKey: 'testVariableValue', + }; + await assert.rejects( + () => executor.executeGraphql({ query, variables }), + (error: Error) => { + assert.strictEqual( + error.message, + // eslint-disable-next-line spellcheck/spell-checker + 'GraphQL request failed, response headers={"content-type":"text/plain;charset=UTF-8","testheaderkey":"testHeaderValue"}, body=Body with error' + ); + return true; + } + ); + }); + + void it('throws if graphql returns errors', async () => { + const fetchMock = mock.fn( + fetch, + (): Promise => + // Mock successful Appsync response + Promise.resolve( + new Response( + JSON.stringify({ + errors: ['Some GQL error'], + }), + { + status: 200, + headers: { testHeaderKey: 'testHeaderValue' }, + } + ) + ) + ); + const executor = new GraphqlRequestExecutor( + graphqlEndpoint, + accessToken, + fetchMock + ); + const query = 'testQuery'; + const variables = { + testVariableKey: 'testVariableValue', + }; + await assert.rejects( + () => executor.executeGraphql({ query, variables }), + (error: Error) => { + assert.strictEqual( + error.message, + // eslint-disable-next-line spellcheck/spell-checker + 'GraphQL request failed, response headers={"content-type":"text/plain;charset=UTF-8","testheaderkey":"testHeaderValue"}, body={"errors":["Some GQL error"]}' + ); + return true; + } + ); + }); +}); diff --git a/packages/ai-constructs/src/conversation/runtime/graphql_request_executor.ts b/packages/ai-constructs/src/conversation/runtime/graphql_request_executor.ts new file mode 100644 index 0000000000..17d759deae --- /dev/null +++ b/packages/ai-constructs/src/conversation/runtime/graphql_request_executor.ts @@ -0,0 +1,57 @@ +export type GraphqlRequest = { + query: string; + variables: TVariables; +}; + +/** + * This class is responsible for executing GraphQL requests. + * Serializing query and it's inputs, adding authorization headers, + * inspecting response for errors and de-serializing output. + */ +export class GraphqlRequestExecutor { + /** + * Creates GraphQL request executor. + */ + constructor( + private readonly graphQlEndpoint: string, + private readonly accessToken: string, + private readonly _fetch = fetch + ) {} + + executeGraphql = async ( + request: GraphqlRequest + ): Promise => { + const httpRequest = new Request(this.graphQlEndpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/graphql', + Authorization: this.accessToken, + }, + body: JSON.stringify({ + query: request.query, + variables: request.variables, + }), + }); + + const res = await this._fetch(httpRequest); + const responseHeaders: Record = {}; + res.headers.forEach((value, key) => (responseHeaders[key] = value)); + if (!res.ok) { + const body = await res.text(); + throw new Error( + `GraphQL request failed, response headers=${JSON.stringify( + responseHeaders + )}, body=${body}` + ); + } + const body = await res.json(); + if (body && typeof body === 'object' && 'errors' in body) { + throw new Error( + `GraphQL request failed, response headers=${JSON.stringify( + responseHeaders + )}, body=${JSON.stringify(body)}` + ); + } + return body as TReturn; + }; +} diff --git a/packages/ai-constructs/src/conversation/runtime/types.ts b/packages/ai-constructs/src/conversation/runtime/types.ts index 3cab0c9925..5794323fb5 100644 --- a/packages/ai-constructs/src/conversation/runtime/types.ts +++ b/packages/ai-constructs/src/conversation/runtime/types.ts @@ -57,7 +57,17 @@ export type ConversationTurnEvent = { authorization: string; }; }; - messages: Array; + /** + * @deprecated This field is going to be removed in upcoming releases. + */ + messages?: Array; + messageHistoryQuery: { + getQueryName: string; + getQueryInputTypeName: string; + listQueryName: string; + listQueryInputTypeName: string; + listQueryLimit?: number; + }; toolsConfiguration?: { dataTools?: Array< ToolDefinition & { diff --git a/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts b/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts index 9c1c52b5a7..1ab7ab2175 100644 --- a/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts +++ b/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts @@ -7,7 +7,10 @@ import { AmplifyClient } from '@aws-sdk/client-amplify'; import { BackendIdentifier } from '@aws-amplify/plugin-types'; import { InvokeCommand, LambdaClient } from '@aws-sdk/client-lambda'; import { DeployedResourcesFinder } from '../find_deployed_resource.js'; -import { ConversationTurnEvent } from '@aws-amplify/ai-constructs/conversation/runtime'; +import { + ConversationMessage, + ConversationTurnEvent, +} from '@aws-amplify/ai-constructs/conversation/runtime'; import { randomUUID } from 'crypto'; import { generateClientConfig } from '@aws-amplify/client-config'; import { AmplifyAuthCredentialsFactory } from '../amplify_auth_credentials_factory.js'; @@ -48,6 +51,12 @@ type ConversationTurnAppSyncResponse = { content: string; }; +type CreateConversationMessageChatInput = ConversationMessage & { + conversationId: string; + id: string; + associatedUserMessageId?: string; +}; + const commonEventProperties = { responseMutation: { name: 'createConversationMessageAssistantResponse', @@ -56,12 +65,17 @@ const commonEventProperties = { 'id', 'conversationId', 'content', - 'sender', 'owner', 'createdAt', 'updatedAt', ].join('\n'), }, + messageHistoryQuery: { + getQueryName: 'getConversationMessageChat', + getQueryInputTypeName: 'ID', + listQueryName: 'listConversationMessageChats', + listQueryInputTypeName: 'ModelConversationMessageChatFilterInput', + }, modelConfiguration: { modelId: bedrockModelId, systemPrompt: 'You are helpful bot.', @@ -189,7 +203,28 @@ class ConversationHandlerTestProject extends TestProjectBase { backendId, authenticatedUserCredentials.accessToken, clientConfig.data.url, - apolloClient + apolloClient, + // Does not use message history lookup. + // This case should be removed when event.messages field is removed. + false + ); + + await this.assertDefaultConversationHandlerCanExecuteTurn( + backendId, + authenticatedUserCredentials.accessToken, + clientConfig.data.url, + apolloClient, + true + ); + + await this.assertDefaultConversationHandlerCanExecuteTurn( + backendId, + authenticatedUserCredentials.accessToken, + clientConfig.data.url, + apolloClient, + true, + // Simulate eventual consistency + true ); await this.assertCustomConversationHandlerCanExecuteTurn( @@ -217,7 +252,16 @@ class ConversationHandlerTestProject extends TestProjectBase { backendId, authenticatedUserCredentials.accessToken, clientConfig.data.url, - apolloClient + apolloClient, + false + ); + + await this.assertDefaultConversationHandlerCanExecuteTurnWithImage( + backendId, + authenticatedUserCredentials.accessToken, + clientConfig.data.url, + apolloClient, + true ); } @@ -225,7 +269,9 @@ class ConversationHandlerTestProject extends TestProjectBase { backendId: BackendIdentifier, accessToken: string, graphqlApiEndpoint: string, - apolloClient: ApolloClient + apolloClient: ApolloClient, + useMessageHistory: boolean, + withoutMessageAvailableInTheMessageList = false ): Promise => { const defaultConversationHandlerFunction = ( await this.resourceFinder.findByBackendIdentifier( @@ -235,26 +281,51 @@ class ConversationHandlerTestProject extends TestProjectBase { ) )[0]; - // send event - const event: ConversationTurnEvent = { + const message: CreateConversationMessageChatInput = { + id: randomUUID().toString(), conversationId: randomUUID().toString(), - currentMessageId: randomUUID().toString(), - graphqlApiEndpoint: graphqlApiEndpoint, - messages: [ + role: 'user', + content: [ { - role: 'user', - content: [ - { - text: 'What is the value of PI?', - }, - ], + text: 'What is the value of PI?', }, ], + }; + + // send event + const event: ConversationTurnEvent = { + conversationId: message.conversationId, + currentMessageId: message.id, + graphqlApiEndpoint: graphqlApiEndpoint, request: { headers: { authorization: accessToken }, }, ...commonEventProperties, }; + + if (useMessageHistory) { + if (withoutMessageAvailableInTheMessageList) { + // This tricks conversation handler to think that message is not available in the list. + // I.e. it simulates eventually consistency read at list operation where item is not yet visible. + // In this case handler should fall back to lookup by current message id. + message.conversationId = randomUUID().toString(); + } + await this.insertMessage(apolloClient, message); + } else { + event.messageHistoryQuery = { + getQueryName: '', + getQueryInputTypeName: '', + listQueryName: '', + listQueryInputTypeName: '', + }; + event.messages = [ + { + role: message.role, + content: message.content, + }, + ]; + } + const response = await this.executeConversationTurn( event, defaultConversationHandlerFunction, @@ -267,7 +338,8 @@ class ConversationHandlerTestProject extends TestProjectBase { backendId: BackendIdentifier, accessToken: string, graphqlApiEndpoint: string, - apolloClient: ApolloClient + apolloClient: ApolloClient, + useMessageHistory: boolean ): Promise => { const defaultConversationHandlerFunction = ( await this.resourceFinder.findByBackendIdentifier( @@ -291,32 +363,49 @@ class ConversationHandlerTestProject extends TestProjectBase { const imageSource = await fs.readFile(imagePath, 'base64'); - // send event - const event: ConversationTurnEvent = { + const message: CreateConversationMessageChatInput = { + id: randomUUID().toString(), conversationId: randomUUID().toString(), - currentMessageId: randomUUID().toString(), - graphqlApiEndpoint: graphqlApiEndpoint, - messages: [ + role: 'user', + content: [ { - role: 'user', - content: [ - { - text: 'What is on the attached image?', - }, - { - image: { - format: 'png', - source: { bytes: imageSource }, - }, - }, - ], + text: 'What is on the attached image?', + }, + { + image: { + format: 'png', + source: { bytes: imageSource }, + }, }, ], + }; + + // send event + const event: ConversationTurnEvent = { + conversationId: message.conversationId, + currentMessageId: message.id, + graphqlApiEndpoint: graphqlApiEndpoint, request: { headers: { authorization: accessToken }, }, ...commonEventProperties, }; + if (useMessageHistory) { + await this.insertMessage(apolloClient, message); + } else { + event.messageHistoryQuery = { + getQueryName: '', + getQueryInputTypeName: '', + listQueryName: '', + listQueryInputTypeName: '', + }; + event.messages = [ + { + role: message.role, + content: message.content, + }, + ]; + } const response = await this.executeConversationTurn( event, defaultConversationHandlerFunction, @@ -341,21 +430,23 @@ class ConversationHandlerTestProject extends TestProjectBase { ) )[0]; - // send event - const event: ConversationTurnEvent = { + const message: CreateConversationMessageChatInput = { conversationId: randomUUID().toString(), - currentMessageId: randomUUID().toString(), - graphqlApiEndpoint: graphqlApiEndpoint, - messages: [ + id: randomUUID().toString(), + role: 'user', + content: [ { - role: 'user', - content: [ - { - text: 'What is the temperature in Seattle?', - }, - ], + text: 'What is the temperature in Seattle?', }, ], + }; + await this.insertMessage(apolloClient, message); + + // send event + const event: ConversationTurnEvent = { + conversationId: message.conversationId, + currentMessageId: message.id, + graphqlApiEndpoint: graphqlApiEndpoint, request: { headers: { authorization: accessToken }, }, @@ -414,21 +505,23 @@ class ConversationHandlerTestProject extends TestProjectBase { ) )[0]; - // send event - const event: ConversationTurnEvent = { + const message: CreateConversationMessageChatInput = { conversationId: randomUUID().toString(), - currentMessageId: randomUUID().toString(), - graphqlApiEndpoint: graphqlApiEndpoint, - messages: [ + id: randomUUID().toString(), + role: 'user', + content: [ { - role: 'user', - content: [ - { - text: 'What is the temperature in Seattle?', - }, - ], + text: 'What is the temperature in Seattle?', }, ], + }; + await this.insertMessage(apolloClient, message); + + // send event + const event: ConversationTurnEvent = { + conversationId: message.conversationId, + currentMessageId: message.id, + graphqlApiEndpoint: graphqlApiEndpoint, request: { headers: { authorization: accessToken }, }, @@ -482,21 +575,23 @@ class ConversationHandlerTestProject extends TestProjectBase { ) )[0]; - // send event - const event: ConversationTurnEvent = { + const message: CreateConversationMessageChatInput = { conversationId: randomUUID().toString(), - currentMessageId: randomUUID().toString(), - graphqlApiEndpoint: graphqlApiEndpoint, - messages: [ + id: randomUUID().toString(), + role: 'user', + content: [ { - role: 'user', - content: [ - { - text: 'What is the temperature in Seattle?', - }, - ], + text: 'What is the temperature in Seattle?', }, ], + }; + await this.insertMessage(apolloClient, message); + + // send event + const event: ConversationTurnEvent = { + conversationId: message.conversationId, + currentMessageId: message.id, + graphqlApiEndpoint: graphqlApiEndpoint, request: { headers: { authorization: accessToken }, }, @@ -535,10 +630,9 @@ class ConversationHandlerTestProject extends TestProjectBase { }>({ query: gql` query ListMessages { - listConversationMessageAssistantResponses { + listConversationMessageAssistantResponses(limit: 1000) { items { conversationId - sender id updatedAt createdAt @@ -557,4 +651,22 @@ class ConversationHandlerTestProject extends TestProjectBase { assert.ok(response); return response; }; + + private insertMessage = async ( + apolloClient: ApolloClient, + message: CreateConversationMessageChatInput + ): Promise => { + await apolloClient.mutate({ + mutation: gql` + mutation InsertMessage($input: CreateConversationMessageChatInput!) { + createConversationMessageChat(input: $input) { + id + } + } + `, + variables: { + input: message, + }, + }); + }; } diff --git a/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts b/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts index 2ef65c955b..f698574310 100644 --- a/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts +++ b/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts @@ -19,13 +19,89 @@ const schema = a.schema({ ) ), - // This schema mocks expected model where conversation responses are supposed to be recorded. + // These schemas below mock models normally generated by conversational routes. + MockConversationParticipantRole: a.enum(['user', 'assistant']), + + MockDocumentBlockSource: a.customType({ + bytes: a.string(), + }), + + MockDocumentBlock: a.customType({ + format: a.string().required(), + name: a.string().required(), + source: a.ref('MockDocumentBlockSource').required(), + }), + + MockImageBlockSource: a.customType({ + bytes: a.string(), + }), + + MockImageBlock: a.customType({ + format: a.string().required(), + source: a.ref('MockImageBlockSource').required(), + }), + + MockToolResultContentBlock: a.customType({ + document: a.ref('MockDocumentBlock'), + image: a.ref('MockImageBlock'), + json: a.json(), + text: a.string(), + }), + + MockToolResultBlock: a.customType({ + toolUseId: a.string().required(), + status: a.string(), + content: a.ref('MockToolResultContentBlock').array().required(), + }), + + MockToolUseBlock: a.customType({ + toolUseId: a.string().required(), + name: a.string().required(), + input: a.json().required(), + }), + + MockContentBlock: a.customType({ + text: a.string(), + document: a.ref('MockDocumentBlock'), + image: a.ref('MockImageBlock'), + toolResult: a.ref('MockToolResultBlock'), + toolUse: a.ref('MockToolUseBlock'), + }), + + MockToolInputSchema: a.customType({ + json: a.json(), + }), + + MockToolSpecification: a.customType({ + name: a.string().required(), + description: a.string(), + inputSchema: a.ref('MockToolInputSchema').required(), + }), + + MockTool: a.customType({ + toolSpec: a.ref('MockToolSpecification'), + }), + + MockToolConfiguration: a.customType({ + tools: a.ref('MockTool').array(), + }), + ConversationMessageAssistantResponse: a .model({ conversationId: a.id(), associatedUserMessageId: a.id(), content: a.string(), - sender: a.enum(['user', 'assistant']), + }) + .authorization((allow) => [allow.authenticated(), allow.owner()]), + + ConversationMessageChat: a + .model({ + conversationId: a.id(), + associatedUserMessageId: a.id(), + role: a.ref('MockConversationParticipantRole'), + content: a.ref('MockContentBlock').array(), + aiContext: a.json(), + toolConfiguration: a.ref('MockToolConfiguration'), }) .authorization((allow) => [allow.authenticated(), allow.owner()]), });