Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.bedrock.converse.api.BedrockCacheOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -81,6 +82,9 @@ public class BedrockChatOptions implements ToolCallingChatOptions {
@JsonIgnore
private Boolean internalToolExecutionEnabled;

/**
 * Cache configuration carried with these chat options. No initial value is assigned
 * here, so the field stays {@code null} until {@code setCacheOptions} (or the
 * builder) supplies one. Annotated with {@code @JsonIgnore}, so Jackson is told to
 * leave it out of the JSON representation.
 */
@JsonIgnore
private BedrockCacheOptions cacheOptions;

/**
 * Creates a new {@link Builder} for assembling a {@code BedrockChatOptions}.
 * @return a fresh builder instance
 */
public static Builder builder() {
return new Builder();
}
Expand All @@ -101,6 +105,7 @@ public static BedrockChatOptions fromOptions(BedrockChatOptions fromOptions) {
.toolNames(new HashSet<>(fromOptions.getToolNames()))
.toolContext(new HashMap<>(fromOptions.getToolContext()))
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
.cacheOptions(fromOptions.getCacheOptions())
.build();
}

Expand Down Expand Up @@ -237,6 +242,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
}

/**
 * Returns the cache options currently held by this instance.
 * @return the stored {@link BedrockCacheOptions}, or {@code null} when none has been
 * set (the field has no default value)
 */
@JsonIgnore
public BedrockCacheOptions getCacheOptions() {
return this.cacheOptions;
}

/**
 * Replaces the cache options held by this instance.
 * @param cacheOptions the new cache options; the reference is stored as-is, without
 * copying or validation, so passing {@code null} clears the setting
 */
@JsonIgnore
public void setCacheOptions(BedrockCacheOptions cacheOptions) {
this.cacheOptions = cacheOptions;
}

@Override
@SuppressWarnings("unchecked")
public BedrockChatOptions copy() {
Expand All @@ -259,14 +274,15 @@ public boolean equals(Object o) {
&& Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topK, that.topK)
&& Objects.equals(this.topP, that.topP) && Objects.equals(this.toolCallbacks, that.toolCallbacks)
&& Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.toolContext, that.toolContext)
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled);
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
&& Objects.equals(this.cacheOptions, that.cacheOptions);
}

@Override
public int hashCode() {
return Objects.hash(this.model, this.frequencyPenalty, this.maxTokens, this.presencePenalty,
this.requestParameters, this.stopSequences, this.temperature, this.topK, this.topP, this.toolCallbacks,
this.toolNames, this.toolContext, this.internalToolExecutionEnabled);
this.toolNames, this.toolContext, this.internalToolExecutionEnabled, this.cacheOptions);
}

public static final class Builder {
Expand Down Expand Up @@ -356,6 +372,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut
return this;
}

/**
 * Sets the cache options on the options instance being built.
 * @param cacheOptions the value forwarded unchanged to {@code setCacheOptions} on
 * the instance under construction
 * @return this builder, to allow call chaining
 */
public Builder cacheOptions(BedrockCacheOptions cacheOptions) {
this.options.setCacheOptions(cacheOptions);
return this;
}

/**
 * Returns the options instance this builder has been populating.
 * <p>
 * The builder's own {@code options} object is handed back directly; no copy is made
 * in this method.
 * @return the configured {@code BedrockChatOptions}
 */
public BedrockChatOptions build() {
// NOTE(review): since the same object is returned, further builder calls after
// build() would presumably be visible through the returned reference - confirm.
return this.options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.time.Duration;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -46,6 +47,7 @@
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.CachePointBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseMetrics;
Expand Down Expand Up @@ -74,6 +76,8 @@
import software.amazon.awssdk.services.bedrockruntime.model.VideoFormat;
import software.amazon.awssdk.services.bedrockruntime.model.VideoSource;

import org.springframework.ai.bedrock.converse.api.BedrockCacheOptions;
import org.springframework.ai.bedrock.converse.api.BedrockCacheStrategy;
import org.springframework.ai.bedrock.converse.api.BedrockMediaFormat;
import org.springframework.ai.bedrock.converse.api.ConverseApiUtils;
import org.springframework.ai.bedrock.converse.api.URLValidator;
Expand Down Expand Up @@ -316,6 +320,8 @@ else if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOp
.internalToolExecutionEnabled(runtimeOptions.getInternalToolExecutionEnabled() != null
? runtimeOptions.getInternalToolExecutionEnabled()
: this.defaultOptions.getInternalToolExecutionEnabled())
.cacheOptions(runtimeOptions.getCacheOptions() != null ? runtimeOptions.getCacheOptions()
: this.defaultOptions.getCacheOptions())
.build();
}

Expand All @@ -326,93 +332,183 @@ else if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOp

ConverseRequest createRequest(Prompt prompt) {

List<Message> instructionMessages = prompt.getInstructions()
BedrockChatOptions updatedRuntimeOptions = prompt.getOptions().copy();

// Get cache options to determine strategy
BedrockCacheOptions cacheOptions = updatedRuntimeOptions.getCacheOptions();
boolean shouldCacheConversationHistory = cacheOptions != null
&& cacheOptions.getStrategy() == BedrockCacheStrategy.CONVERSATION_HISTORY;

// Get all non-system messages
List<org.springframework.ai.chat.messages.Message> allNonSystemMessages = prompt.getInstructions()
.stream()
.filter(message -> message.getMessageType() != MessageType.SYSTEM)
.map(message -> {
if (message.getMessageType() == MessageType.USER) {
List<ContentBlock> contents = new ArrayList<>();
if (message instanceof UserMessage userMessage) {
contents.add(ContentBlock.fromText(userMessage.getText()));

if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
List<ContentBlock> mediaContent = userMessage.getMedia()
.stream()
.map(this::mapMediaToContentBlock)
.toList();
contents.addAll(mediaContent);
}
}
return Message.builder().content(contents).role(ConversationRole.USER).build();
.toList();

// Find the last user message index for CONVERSATION_HISTORY caching
int lastUserMessageIndex = -1;
if (shouldCacheConversationHistory) {
for (int i = allNonSystemMessages.size() - 1; i >= 0; i--) {
if (allNonSystemMessages.get(i).getMessageType() == MessageType.USER) {
lastUserMessageIndex = i;
break;
}
else if (message.getMessageType() == MessageType.ASSISTANT) {
AssistantMessage assistantMessage = (AssistantMessage) message;
List<ContentBlock> contentBlocks = new ArrayList<>();
if (StringUtils.hasText(message.getText())) {
contentBlocks.add(ContentBlock.fromText(message.getText()));
}
if (logger.isDebugEnabled()) {
logger.debug("CONVERSATION_HISTORY caching: lastUserMessageIndex={}, totalMessages={}",
lastUserMessageIndex, allNonSystemMessages.size());
}
}

// Build instruction messages with potential caching
List<Message> instructionMessages = new ArrayList<>();
for (int i = 0; i < allNonSystemMessages.size(); i++) {
org.springframework.ai.chat.messages.Message message = allNonSystemMessages.get(i);

// Determine if this message should have a cache point
// For CONVERSATION_HISTORY: cache point goes on the last user message
boolean shouldApplyCachePoint = shouldCacheConversationHistory && i == lastUserMessageIndex;

if (message.getMessageType() == MessageType.USER) {
List<ContentBlock> contents = new ArrayList<>();
if (message instanceof UserMessage) {
var userMessage = (UserMessage) message;
contents.add(ContentBlock.fromText(userMessage.getText()));

if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
List<ContentBlock> mediaContent = userMessage.getMedia()
.stream()
.map(this::mapMediaToContentBlock)
.toList();
contents.addAll(mediaContent);
}
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
}

var argumentsDocument = ConverseApiUtils
.convertObjectToDocument(ModelOptionsUtils.jsonToMap(toolCall.arguments()));
// Apply cache point if this is the last user message
if (shouldApplyCachePoint) {
CachePointBlock cachePoint = CachePointBlock.builder().type("default").build();
contents.add(ContentBlock.fromCachePoint(cachePoint));
logger.debug("Applied cache point on last user message (conversation history caching)");
}

instructionMessages.add(Message.builder().content(contents).role(ConversationRole.USER).build());
}
else if (message.getMessageType() == MessageType.ASSISTANT) {
AssistantMessage assistantMessage = (AssistantMessage) message;
List<ContentBlock> contentBlocks = new ArrayList<>();
if (StringUtils.hasText(message.getText())) {
contentBlocks.add(ContentBlock.fromText(message.getText()));
}
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {

contentBlocks.add(ContentBlock.fromToolUse(ToolUseBlock.builder()
.toolUseId(toolCall.id())
.name(toolCall.name())
.input(argumentsDocument)
.build()));
var argumentsDocument = ConverseApiUtils
.convertObjectToDocument(ModelOptionsUtils.jsonToMap(toolCall.arguments()));

contentBlocks.add(ContentBlock.fromToolUse(ToolUseBlock.builder()
.toolUseId(toolCall.id())
.name(toolCall.name())
.input(argumentsDocument)
.build()));

}
}
return Message.builder().content(contentBlocks).role(ConversationRole.ASSISTANT).build();
}
else if (message.getMessageType() == MessageType.TOOL) {
List<ContentBlock> contentBlocks = ((ToolResponseMessage) message).getResponses()
.stream()
.map(toolResponse -> {

instructionMessages
.add(Message.builder().content(contentBlocks).role(ConversationRole.ASSISTANT).build());
}
else if (message.getMessageType() == MessageType.TOOL) {
List<ContentBlock> contentBlocks = new ArrayList<>(
((ToolResponseMessage) message).getResponses().stream().map(toolResponse -> {
ToolResultBlock toolResultBlock = ToolResultBlock.builder()
.toolUseId(toolResponse.id())
.content(ToolResultContentBlock.builder().text(toolResponse.responseData()).build())
.build();
return ContentBlock.fromToolResult(toolResultBlock);
})
.toList();
return Message.builder().content(contentBlocks).role(ConversationRole.USER).build();
}
else {
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
}
})
.toList();
}).toList());

instructionMessages.add(Message.builder().content(contentBlocks).role(ConversationRole.USER).build());
}
else {
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
}
}

// Determine if system message caching should be applied
boolean shouldCacheSystem = cacheOptions != null
&& (cacheOptions.getStrategy() == BedrockCacheStrategy.SYSTEM_ONLY
|| cacheOptions.getStrategy() == BedrockCacheStrategy.SYSTEM_AND_TOOLS);

if (logger.isDebugEnabled() && cacheOptions != null) {
logger.debug("Cache strategy: {}, shouldCacheSystem: {}", cacheOptions.getStrategy(), shouldCacheSystem);
}

List<SystemContentBlock> systemMessages = prompt.getInstructions()
// Build system messages with optional caching on last message
List<org.springframework.ai.chat.messages.Message> systemMessageList = prompt.getInstructions()
.stream()
.filter(m -> m.getMessageType() == MessageType.SYSTEM)
.map(sysMessage -> SystemContentBlock.builder().text(sysMessage.getText()).build())
.toList();

BedrockChatOptions updatedRuntimeOptions = prompt.getOptions().copy();
List<SystemContentBlock> systemMessages = new ArrayList<>();
for (int i = 0; i < systemMessageList.size(); i++) {
org.springframework.ai.chat.messages.Message sysMessage = systemMessageList.get(i);

// Add the text content block
SystemContentBlock textBlock = SystemContentBlock.builder().text(sysMessage.getText()).build();
systemMessages.add(textBlock);

// Apply cache point marker after last system message if caching is enabled
// SystemContentBlock is a UNION type - text and cachePoint must be separate
// blocks
boolean isLastSystem = (i == systemMessageList.size() - 1);
if (isLastSystem && shouldCacheSystem) {
CachePointBlock cachePoint = CachePointBlock.builder().type("default").build();
SystemContentBlock cachePointBlock = SystemContentBlock.builder().cachePoint(cachePoint).build();
systemMessages.add(cachePointBlock);
logger.debug("Applied cache point after system message");
}
}

ToolConfiguration toolConfiguration = null;

// Add the tool definitions to the request's tools parameter.
List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(updatedRuntimeOptions);

// Determine if tool caching should be applied
boolean shouldCacheTools = cacheOptions != null
&& (cacheOptions.getStrategy() == BedrockCacheStrategy.TOOLS_ONLY
|| cacheOptions.getStrategy() == BedrockCacheStrategy.SYSTEM_AND_TOOLS);

if (!CollectionUtils.isEmpty(toolDefinitions)) {
List<Tool> bedrockTools = toolDefinitions.stream().map(toolDefinition -> {
List<Tool> bedrockTools = new ArrayList<>();

for (int i = 0; i < toolDefinitions.size(); i++) {
ToolDefinition toolDefinition = toolDefinitions.get(i);
var description = toolDefinition.description();
var name = toolDefinition.name();
String inputSchema = toolDefinition.inputSchema();
return Tool.builder()

// Create tool specification
Tool tool = Tool.builder()
.toolSpec(ToolSpecification.builder()
.name(name)
.description(description)
.inputSchema(ToolInputSchema.fromJson(
ConverseApiUtils.convertObjectToDocument(ModelOptionsUtils.jsonToMap(inputSchema))))
.build())
.build();
}).toList();
bedrockTools.add(tool);

// Apply cache point marker after last tool if caching is enabled
// Tool is a UNION type - toolSpec and cachePoint must be separate objects
boolean isLastTool = (i == toolDefinitions.size() - 1);
if (isLastTool && shouldCacheTools) {
CachePointBlock cachePoint = CachePointBlock.builder().type("default").build();
Tool cachePointTool = Tool.builder().cachePoint(cachePoint).build();
bedrockTools.add(cachePointTool);
logger.debug("Applied cache point after tool definitions");
}
}

toolConfiguration = ToolConfiguration.builder().tools(bedrockTools).build();
}
Expand Down Expand Up @@ -635,12 +731,23 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv

ConverseMetrics metrics = response.metrics();

var chatResponseMetaData = ChatResponseMetadata.builder()
var metadataBuilder = ChatResponseMetadata.builder()
.id(response.responseMetadata() != null ? response.responseMetadata().requestId() : "Unknown")
.usage(usage)
.build();
.usage(usage);

// Add cache metrics if available
Map<String, Object> additionalMetadata = new HashMap<>();
if (response.usage().cacheReadInputTokens() != null) {
additionalMetadata.put("cacheReadInputTokens", response.usage().cacheReadInputTokens());
}
if (response.usage().cacheWriteInputTokens() != null) {
additionalMetadata.put("cacheWriteInputTokens", response.usage().cacheWriteInputTokens());
}
if (!additionalMetadata.isEmpty()) {
metadataBuilder.metadata(additionalMetadata);
}

return new ChatResponse(allGenerations, chatResponseMetaData);
return new ChatResponse(allGenerations, metadataBuilder.build());
}

/**
Expand Down
Loading