Skip to content

Commit

Permalink
Use message history instead of event payload for conversation handler (
Browse files Browse the repository at this point in the history
…#2047)

* Use message history instead of event payload for conversational route

* refactor e2e

* refactor gql requests

* fallback

* lint

* add test for retriever

* refactor that

* todo comments

* lint

* refactor that

* rename

* process history

* process history test

* more tests

* more tests
  • Loading branch information
sobolk authored Sep 26, 2024
1 parent d538ecc commit d0a90b1
Show file tree
Hide file tree
Showing 17 changed files with 1,311 additions and 342 deletions.
5 changes: 5 additions & 0 deletions .changeset/plenty-wombats-fry.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@aws-amplify/ai-constructs': minor
---

Use message history instead of event payload for conversational route
9 changes: 8 additions & 1 deletion packages/ai-constructs/API.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,14 @@ type ConversationTurnEvent = {
authorization: string;
};
};
messages: Array<ConversationMessage>;
messages?: Array<ConversationMessage>;
messageHistoryQuery: {
getQueryName: string;
getQueryInputTypeName: string;
listQueryName: string;
listQueryInputTypeName: string;
listQueryLimit?: number;
};
toolsConfiguration?: {
dataTools?: Array<ToolDefinition & {
graphqlRequestInputDescriptor: {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import { describe, it, mock } from 'node:test';
import assert from 'node:assert';
import { ConversationTurnEvent, ExecutableTool, ToolDefinition } from './types';
import {
ConversationMessage,
ConversationTurnEvent,
ExecutableTool,
ToolDefinition,
} from './types';
import { BedrockConverseAdapter } from './bedrock_converse_adapter';
import {
BedrockRuntimeClient,
Expand All @@ -13,22 +18,19 @@ import {
} from '@aws-sdk/client-bedrock-runtime';
import { ConversationTurnEventToolsProvider } from './event-tools-provider';
import { randomBytes, randomUUID } from 'node:crypto';
import { ConversationMessageHistoryRetriever } from './conversation_message_history_retriever';

void describe('Bedrock converse adapter', () => {
const commonEvent: Readonly<ConversationTurnEvent> = {
conversationId: '',
currentMessageId: '',
graphqlApiEndpoint: '',
messages: [
{
role: 'user',
content: [
{
text: 'event message',
},
],
},
],
messageHistoryQuery: {
getQueryName: '',
getQueryInputTypeName: '',
listQueryName: '',
listQueryInputTypeName: '',
},
modelConfiguration: {
modelId: 'testModelId',
systemPrompt: 'testSystemPrompt',
Expand All @@ -46,6 +48,27 @@ void describe('Bedrock converse adapter', () => {
},
};

const messages: Array<ConversationMessage> = [
{
role: 'user',
content: [
{
text: 'event message',
},
],
},
];
const messageHistoryRetriever = new ConversationMessageHistoryRetriever(
commonEvent
);
const messageHistoryRetrieverMockGetEventMessages = mock.method(
messageHistoryRetriever,
'getMessageHistory',
() => {
return Promise.resolve(messages);
}
);

void it('calls bedrock to get conversation response', async () => {
const event: ConversationTurnEvent = {
...commonEvent,
Expand Down Expand Up @@ -78,7 +101,9 @@ void describe('Bedrock converse adapter', () => {
const responseContent = await new BedrockConverseAdapter(
event,
[],
bedrockClient
bedrockClient,
undefined,
messageHistoryRetriever
).askBedrock();

assert.deepStrictEqual(
Expand All @@ -90,7 +115,7 @@ void describe('Bedrock converse adapter', () => {
const bedrockRequest = bedrockClientSendMock.mock.calls[0]
.arguments[0] as unknown as ConverseCommand;
const expectedBedrockInput: ConverseCommandInput = {
messages: event.messages as Array<Message>,
messages: messages as Array<Message>,
modelId: event.modelConfiguration.modelId,
inferenceConfig: event.modelConfiguration.inferenceConfiguration,
system: [
Expand Down Expand Up @@ -211,7 +236,8 @@ void describe('Bedrock converse adapter', () => {
event,
[additionalTool],
bedrockClient,
eventToolsProvider
eventToolsProvider,
messageHistoryRetriever
).askBedrock();

assert.deepStrictEqual(
Expand Down Expand Up @@ -251,7 +277,7 @@ void describe('Bedrock converse adapter', () => {
const bedrockRequest1 = bedrockClientSendMock.mock.calls[0]
.arguments[0] as unknown as ConverseCommand;
const expectedBedrockInput1: ConverseCommandInput = {
messages: event.messages as Array<Message>,
messages: messages as Array<Message>,
...expectedBedrockInputCommonProperties,
};
assert.deepStrictEqual(bedrockRequest1.input, expectedBedrockInput1);
Expand All @@ -264,7 +290,7 @@ void describe('Bedrock converse adapter', () => {
);
const expectedBedrockInput2: ConverseCommandInput = {
messages: [
...(event.messages as Array<Message>),
...(messages as Array<Message>),
additionalToolUseBedrockResponse.output?.message,
{
role: 'user',
Expand Down Expand Up @@ -447,7 +473,9 @@ void describe('Bedrock converse adapter', () => {
const responseContent = await new BedrockConverseAdapter(
event,
[tool],
bedrockClient
bedrockClient,
undefined,
messageHistoryRetriever
).askBedrock();

assert.deepStrictEqual(
Expand Down Expand Up @@ -543,7 +571,9 @@ void describe('Bedrock converse adapter', () => {
const responseContent = await new BedrockConverseAdapter(
event,
[tool],
bedrockClient
bedrockClient,
undefined,
messageHistoryRetriever
).askBedrock();

assert.deepStrictEqual(
Expand Down Expand Up @@ -645,7 +675,9 @@ void describe('Bedrock converse adapter', () => {
const responseContent = await new BedrockConverseAdapter(
event,
[additionalTool],
bedrockClient
bedrockClient,
undefined,
messageHistoryRetriever
).askBedrock();

assert.deepStrictEqual(responseContent, [clientToolUseBlock]);
Expand Down Expand Up @@ -682,7 +714,7 @@ void describe('Bedrock converse adapter', () => {
const bedrockRequest = bedrockClientSendMock.mock.calls[0]
.arguments[0] as unknown as ConverseCommand;
const expectedBedrockInput: ConverseCommandInput = {
messages: event.messages as Array<Message>,
messages: messages as Array<Message>,
...expectedBedrockInputCommonProperties,
};
assert.deepStrictEqual(bedrockRequest.input, expectedBedrockInput);
Expand All @@ -695,21 +727,27 @@ void describe('Bedrock converse adapter', () => {

const fakeImagePayload = randomBytes(32);

event.messages = [
{
role: 'user',
content: [
messageHistoryRetrieverMockGetEventMessages.mock.mockImplementationOnce(
() => {
return Promise.resolve([
{
image: {
format: 'png',
source: {
bytes: fakeImagePayload.toString('base64'),
id: '',
conversationId: '',
role: 'user',
content: [
{
image: {
format: 'png',
source: {
bytes: fakeImagePayload.toString('base64'),
},
},
},
},
],
},
],
},
];
]);
}
);

const bedrockClient = new BedrockRuntimeClient();
const bedrockResponse: ConverseCommandOutput = {
Expand All @@ -735,7 +773,13 @@ void describe('Bedrock converse adapter', () => {
Promise.resolve(bedrockResponse)
);

await new BedrockConverseAdapter(event, [], bedrockClient).askBedrock();
await new BedrockConverseAdapter(
event,
[],
bedrockClient,
undefined,
messageHistoryRetriever
).askBedrock();

assert.strictEqual(bedrockClientSendMock.mock.calls.length, 1);
const bedrockRequest = bedrockClientSendMock.mock.calls[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
ToolDefinition,
} from './types.js';
import { ConversationTurnEventToolsProvider } from './event-tools-provider';
import { ConversationMessageHistoryRetriever } from './conversation_message_history_retriever';

/**
* This class is responsible for interacting with Bedrock Converse API
Expand All @@ -36,7 +37,10 @@ export class BedrockConverseAdapter {
private readonly bedrockClient: BedrockRuntimeClient = new BedrockRuntimeClient(
{ region: event.modelConfiguration.region }
),
eventToolsProvider = new ConversationTurnEventToolsProvider(event)
eventToolsProvider = new ConversationTurnEventToolsProvider(event),
private readonly messageHistoryRetriever = new ConversationMessageHistoryRetriever(
event
)
) {
this.executableTools = [
...eventToolsProvider.getEventTools(),
Expand Down Expand Up @@ -73,7 +77,8 @@ export class BedrockConverseAdapter {
const { modelId, systemPrompt, inferenceConfiguration } =
this.event.modelConfiguration;

const messages: Array<Message> = this.getEventMessagesAsBedrockMessages();
const messages: Array<Message> =
await this.getEventMessagesAsBedrockMessages();

let bedrockResponse: ConverseCommandOutput;
do {
Expand Down Expand Up @@ -124,9 +129,13 @@ export class BedrockConverseAdapter {
* 1. Makes a copy so that we don't mutate event.
* 2. Decodes Base64 encoded images.
*/
private getEventMessagesAsBedrockMessages = (): Array<Message> => {
private getEventMessagesAsBedrockMessages = async (): Promise<
Array<Message>
> => {
const messages: Array<Message> = [];
for (const message of this.event.messages) {
const eventMessages =
await this.messageHistoryRetriever.getMessageHistory();
for (const message of eventMessages) {
const messageContent: Array<ContentBlock> = [];
for (const contentElement of message.content) {
if (typeof contentElement.image?.source?.bytes === 'string') {
Expand Down
Loading

0 comments on commit d0a90b1

Please sign in to comment.