From a77c7a8e39f33d81fa56d9ccfc3b718716abc5ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cclaudio-code=E2=80=9D?= Date: Wed, 25 Sep 2024 22:02:34 -0300 Subject: [PATCH 1/2] Added support for Anthropic prompt cache --- .../ai/anthropic/AnthropicChatModel.java | 11 ++++++- .../ai/anthropic/api/AnthropicApi.java | 30 +++++++++++++---- .../ai/anthropic/api/AnthropicCacheType.java | 19 +++++++++++ .../ai/anthropic/api/AnthropicApiIT.java | 23 +++++++++++++ .../ai/chat/messages/AbstractMessage.java | 32 +++++++++++++++++++ .../ai/chat/messages/UserMessage.java | 20 ++++++++++++ 6 files changed, 128 insertions(+), 7 deletions(-) create mode 100644 models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index c79b611dc6..890b405e68 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -35,7 +35,9 @@ import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type; import org.springframework.ai.anthropic.api.AnthropicApi.Role; +import org.springframework.ai.anthropic.api.AnthropicCacheType; import org.springframework.ai.anthropic.metadata.AnthropicUsage; +import org.springframework.ai.chat.messages.AbstractMessage; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -362,7 +364,14 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { .filter(message -> message.getMessageType() != MessageType.SYSTEM) .map(message -> { if 
(message.getMessageType() == MessageType.USER) { - List contents = new ArrayList<>(List.of(new ContentBlock(message.getContent()))); + AbstractMessage abstractMessage = (AbstractMessage) message; + List contents; + if (abstractMessage.getCache() != null) { + AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache()); + contents = new ArrayList<>(List.of(new ContentBlock(message.getContent(), cacheType.cacheControl()))); + } else { + contents = new ArrayList<>(List.of(new ContentBlock(message.getContent()))); + } if (message instanceof UserMessage userMessage) { if (!CollectionUtils.isEmpty(userMessage.getMedia())) { List mediaContent = userMessage.getMedia() diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index ca5368ee4c..29751d26d5 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -23,6 +23,7 @@ import java.util.function.Consumer; import java.util.function.Predicate; +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl; import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder; import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; @@ -71,6 +72,8 @@ public class AnthropicApi { public static final String BETA_MAX_TOKENS = "max-tokens-3-5-sonnet-2024-07-15"; + public static final String BETA_PROMPT_CACHING = "prompt-caching-2024-07-31"; + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; private final RestClient restClient; @@ -267,6 +270,14 @@ public ChatCompletionRequest(String model, List messages, Stri public record Metadata(@JsonProperty("user_id") String userId) { } + /** 
+ * @param type is the cache type supported by anthropic. + * Doc + */ + @JsonInclude(Include.NON_NULL) + public record CacheControl(String type) { + } + public static ChatCompletionRequestBuilder builder() { return new ChatCompletionRequestBuilder(); } @@ -433,7 +444,10 @@ public record ContentBlock( // @formatter:off // tool_result response only @JsonProperty("tool_use_id") String toolUseId, - @JsonProperty("content") String content + @JsonProperty("content") String content, + + // cache object + @JsonProperty("cache_control") CacheControl cacheControl ) { // @formatter:on @@ -442,25 +456,29 @@ public ContentBlock(String mediaType, String data) { } public ContentBlock(Source source) { - this(Type.IMAGE, source, null, null, null, null, null, null, null); + this(Type.IMAGE, source, null, null, null, null, null, null, null, null); } public ContentBlock(String text) { - this(Type.TEXT, null, text, null, null, null, null, null, null); + this(Type.TEXT, null, text, null, null, null, null, null, null, null); + } + + public ContentBlock(String text, CacheControl cache) { + this(Type.TEXT, null, text, null, null, null, null, null, null, cache); } // Tool result public ContentBlock(Type type, String toolUseId, String content) { - this(type, null, null, null, null, null, null, toolUseId, content); + this(type, null, null, null, null, null, null, toolUseId, content, null); } public ContentBlock(Type type, Source source, String text, Integer index) { - this(type, source, text, index, null, null, null, null, null); + this(type, source, text, index, null, null, null, null, null, null); } // Tool use input JSON delta streaming public ContentBlock(Type type, String id, String name, Map input) { - this(type, null, null, null, id, name, input, null, null); + this(type, null, null, null, id, name, input, null, null, null); } /** diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java 
b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java new file mode 100644 index 0000000000..44617cdf18 --- /dev/null +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java @@ -0,0 +1,19 @@ +package org.springframework.ai.anthropic.api; + +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl; + +import java.util.function.Supplier; + +public enum AnthropicCacheType { + EPHEMERAL(() -> new CacheControl("ephemeral")); + + private Supplier value; + + AnthropicCacheType(Supplier value) { + this.value = value; + } + + public CacheControl cacheControl() { + return this.value.get(); + } +} diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java index ceaebdfe6d..9ff5667337 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java @@ -26,8 +26,11 @@ import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock; import org.springframework.ai.anthropic.api.AnthropicApi.Role; +import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; import reactor.core.publisher.Flux; /** @@ -38,6 +41,26 @@ public class AnthropicApiIT { AnthropicApi anthropicApi = new AnthropicApi(System.getenv("ANTHROPIC_API_KEY")); + + @Test + void chatWithPromptCache() { + AnthropicApi anthropicApiBeta = new AnthropicApi(AnthropicApi.DEFAULT_BASE_URL, + System.getenv("ANTHROPIC_API_KEY"), + 
AnthropicApi.DEFAULT_ANTHROPIC_VERSION, + RestClient.builder(), + WebClient.builder(), + RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER, + AnthropicApi.BETA_PROMPT_CACHING); + AnthropicMessage chatCompletionMessage = new AnthropicMessage(List.of(new ContentBlock("Tell me a Joke?", AnthropicCacheType.EPHEMERAL.cacheControl())), + Role.USER); + + ResponseEntity response = anthropicApiBeta + .chatCompletionEntity(new ChatCompletionRequest(AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue(), + List.of(chatCompletionMessage), null, 100, 0.8, false)); + + assertThat(response).isNotNull(); + } + @Test void chatCompletionEntity() { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java index 05a89117c6..17dfe99327 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java @@ -41,11 +41,25 @@ public abstract class AbstractMessage implements Message { protected final String textContent; + protected String cache; + /** * Additional options for the message to influence the response, not a generative map. 
*/ protected final Map metadata; + protected AbstractMessage(MessageType messageType, String textContent, Map metadata, String cache) { + Assert.notNull(messageType, "Message type must not be null"); + if (messageType == MessageType.SYSTEM || messageType == MessageType.USER) { + Assert.notNull(textContent, "Content must not be null for SYSTEM or USER messages"); + } + this.messageType = messageType; + this.textContent = textContent; + this.metadata = new HashMap<>(metadata); + this.metadata.put(MESSAGE_TYPE, messageType); + this.cache = cache; + } + protected AbstractMessage(MessageType messageType, String textContent, Map metadata) { Assert.notNull(messageType, "Message type must not be null"); if (messageType == MessageType.SYSTEM || messageType == MessageType.USER) { @@ -70,6 +84,20 @@ protected AbstractMessage(MessageType messageType, Resource resource, Map metadata, String cache) { + Assert.notNull(resource, "Resource must not be null"); + try (InputStream inputStream = resource.getInputStream()) { + this.textContent = StreamUtils.copyToString(inputStream, Charset.defaultCharset()); + } + catch (IOException ex) { + throw new RuntimeException("Failed to read resource", ex); + } + this.messageType = messageType; + this.metadata = new HashMap<>(metadata); + this.metadata.put(MESSAGE_TYPE, messageType); + this.cache = cache; + } + @Override public String getContent() { return this.textContent; @@ -85,6 +113,10 @@ public MessageType getMessageType() { return this.messageType; } + public String getCache() { + return cache; + } + @Override public boolean equals(Object o) { if (this == o) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java index 53c3242572..26e3e9b8e8 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java +++ 
b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java @@ -35,6 +35,10 @@ public class UserMessage extends AbstractMessage implements MediaContent { protected final List media; + public UserMessage(String textContent, String cache) { + this(MessageType.USER, textContent, new ArrayList<>(), Map.of(), cache); + } + public UserMessage(String textContent) { this(MessageType.USER, textContent, new ArrayList<>(), Map.of()); } @@ -44,6 +48,11 @@ public UserMessage(Resource resource) { this.media = new ArrayList<>(); } + public UserMessage(Resource resource, String cache) { + super(MessageType.USER, resource, Map.of(), cache); + this.media = new ArrayList<>(); + } + public UserMessage(String textContent, List media) { this(MessageType.USER, textContent, media, Map.of()); } @@ -63,6 +72,13 @@ public UserMessage(MessageType messageType, String textContent, Collection(media); } + public UserMessage(MessageType messageType, String textContent, Collection media, + Map metadata, String cache) { + super(messageType, textContent, metadata, cache); + Assert.notNull(media, "media data must not be null"); + this.media = new ArrayList<>(media); + } + public List getMedia(String... 
dummy) { return this.media; } @@ -83,4 +99,8 @@ public String getContent() { return this.textContent; } + @Override + public String getCache() { + return super.getCache(); + } } From 0631de8be4e7c39ce671c342a6a9df37be530a8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cclaudio-code=E2=80=9D?= Date: Wed, 25 Sep 2024 22:09:25 -0300 Subject: [PATCH 2/2] Adjusted source code format --- .../ai/anthropic/AnthropicChatModel.java | 6 ++++-- .../ai/anthropic/api/AnthropicApi.java | 4 ++-- .../ai/anthropic/api/AnthropicCacheType.java | 2 ++ .../ai/anthropic/api/AnthropicApiIT.java | 17 ++++++----------- .../ai/chat/messages/UserMessage.java | 3 ++- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 890b405e68..cd933d3a4c 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -368,8 +368,10 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { List contents; if (abstractMessage.getCache() != null) { AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache()); - contents = new ArrayList<>(List.of(new ContentBlock(message.getContent(), cacheType.cacheControl()))); - } else { + contents = new ArrayList<>( + List.of(new ContentBlock(message.getContent(), cacheType.cacheControl()))); + } + else { contents = new ArrayList<>(List.of(new ContentBlock(message.getContent()))); } if (message instanceof UserMessage userMessage) { diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index 29751d26d5..39fab4329d 100644 
--- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -271,8 +271,8 @@ public record Metadata(@JsonProperty("user_id") String userId) { } /** - * @param type is the cache type supported by anthropic. - * Doc + * @param type is the cache type supported by anthropic. Doc */ @JsonInclude(Include.NON_NULL) public record CacheControl(String type) { diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java index 44617cdf18..06a756be42 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java @@ -5,6 +5,7 @@ import java.util.function.Supplier; public enum AnthropicCacheType { + EPHEMERAL(() -> new CacheControl("ephemeral")); private Supplier value; @@ -16,4 +17,5 @@ public enum AnthropicCacheType { public CacheControl cacheControl() { return this.value.get(); } + } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java index 9ff5667337..af4e8ca12b 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java @@ -41,22 +41,17 @@ public class AnthropicApiIT { AnthropicApi anthropicApi = new AnthropicApi(System.getenv("ANTHROPIC_API_KEY")); - @Test void chatWithPromptCache() { AnthropicApi anthropicApiBeta = new AnthropicApi(AnthropicApi.DEFAULT_BASE_URL, - System.getenv("ANTHROPIC_API_KEY"), - 
AnthropicApi.DEFAULT_ANTHROPIC_VERSION, - RestClient.builder(), - WebClient.builder(), - RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER, - AnthropicApi.BETA_PROMPT_CACHING); - AnthropicMessage chatCompletionMessage = new AnthropicMessage(List.of(new ContentBlock("Tell me a Joke?", AnthropicCacheType.EPHEMERAL.cacheControl())), - Role.USER); + System.getenv("ANTHROPIC_API_KEY"), AnthropicApi.DEFAULT_ANTHROPIC_VERSION, RestClient.builder(), + WebClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER, AnthropicApi.BETA_PROMPT_CACHING); + AnthropicMessage chatCompletionMessage = new AnthropicMessage( + List.of(new ContentBlock("Tell me a Joke?", AnthropicCacheType.EPHEMERAL.cacheControl())), Role.USER); ResponseEntity response = anthropicApiBeta - .chatCompletionEntity(new ChatCompletionRequest(AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue(), - List.of(chatCompletionMessage), null, 100, 0.8, false)); + .chatCompletionEntity(new ChatCompletionRequest(AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue(), + List.of(chatCompletionMessage), null, 100, 0.8, false)); assertThat(response).isNotNull(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java index 26e3e9b8e8..3ed2092da0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java @@ -73,7 +73,7 @@ public UserMessage(MessageType messageType, String textContent, Collection media, - Map metadata, String cache) { + Map metadata, String cache) { super(messageType, textContent, metadata, cache); Assert.notNull(media, "media data must not be null"); this.media = new ArrayList<>(media); @@ -103,4 +103,5 @@ public String getContent() { public String getCache() { return super.getCache(); } + }