From 4b88b5cf7bd56a5551a8bb92d086d885c9260e72 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 30 Jan 2024 16:28:13 +0100 Subject: [PATCH] Add Azure OpenAI Chat and Embedding Options - Add AzureOpenAiChatOptions - Add default options field to AzureOpenAiChatClient. - Implement runtime (e.g. prompt) and default options on call. - Add options field to the AzureOpenAiChatProperties. - Add AzureOpenAiEmbeddingOptions - Add default options field to AzureOpenAiEmbeddingClient. - Implement runtime and default option merging on embedding request. - Add options field to AzureOpenAiEmbeddingProperties. - Add Unit and ITs. - Split the azure-openai.adoc into ./clients/azure-openai-chat.adoc and ./embeddings/azure-openai-embeddings.adoc. - Provide detailed explanation how to use the chat and embedding clients manually or via the auto-configuration. --- .../azure/openai/AzureOpenAiChatClient.java | 256 +++++++++------ .../azure/openai/AzureOpenAiChatOptions.java | 292 ++++++++++++++++++ .../openai/AzureOpenAiEmbeddingClient.java | 76 +++-- .../openai/AzureOpenAiEmbeddingOptions.java | 86 ++++++ .../AzureChatCompletionsOptionsTests.java | 55 ++++ .../openai/AzureEmbeddingsOptionsTests.java | 58 ++++ .../azure/openai/AzureOpenAiChatClientIT.java | 4 +- .../openai/AzureOpenAiEmbeddingClientIT.java | 2 +- .../ai/model/ModelOptionsUtils.java | 6 +- .../src/main/antora/modules/ROOT/nav.adoc | 3 +- .../pages/api/clients/azure-openai-chat.adoc | 176 +++++++++++ .../embeddings/azure-openai-embeddings.adoc | 167 ++++++++++ .../openai/AzureOpenAiAutoConfiguration.java | 7 +- .../openai/AzureOpenAiChatProperties.java | 69 +---- .../AzureOpenAiEmbeddingProperties.java | 18 +- .../openai/OpenAiAutoConfiguration.java | 5 +- ...eOpenAiAutoConfigurationPropertyTests.java | 68 ++-- 17 files changed, 1136 insertions(+), 212 deletions(-) create mode 100644 models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java create mode 100644 
models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java create mode 100644 models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java create mode 100644 models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java create mode 100644 spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/azure-openai-chat.adoc create mode 100644 spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/azure-openai-embeddings.adoc diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java index 1dc941c775..84d5504d71 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -27,12 +27,10 @@ import com.azure.ai.openai.models.ChatRequestMessage; import com.azure.ai.openai.models.ChatRequestSystemMessage; import com.azure.ai.openai.models.ChatRequestUserMessage; -import com.azure.ai.openai.models.ChatResponseMessage; import com.azure.ai.openai.models.ContentFilterResultsForPrompt; import com.azure.core.util.IterableStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import reactor.core.publisher.Flux; import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata; @@ -40,10 +38,11 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.StreamingChatClient; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.PromptMetadata; import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.messages.Message; import org.springframework.util.Assert; /** @@ -59,104 +58,39 @@ */ public class AzureOpenAiChatClient implements ChatClient, StreamingChatClient { - /** - * The sampling temperature to use that controls the apparent creativity of generated - * completions. Higher values will make output more random while lower values will - * make results more focused and deterministic. It is not recommended to modify - * temperature and top_p for the same completions request as the interaction of these - * two settings is difficult to predict. - */ - private Double temperature = 0.7; + private static final String DEFAULT_MODEL = "gpt-35-turbo"; - /** - * An alternative to sampling with temperature called nucleus sampling. This value - * causes the model to consider the results of tokens with the provided probability - * mass. 
As an example, a value of 0.15 will cause only the tokens comprising the top - * 15% of probability mass to be considered. It is not recommended to modify - * temperature and top_p for the same completions request as the interaction of these - * two settings is difficult to predict. - */ - private Double topP; + private static final Float DEFAULT_TEMPERATURE = 0.7f; + + private final Logger logger = LoggerFactory.getLogger(getClass()); /** - * Creates an instance of ChatCompletionsOptions class. + * The configuration information for a chat completions request. */ - private String model = "gpt-35-turbo"; + private AzureOpenAiChatOptions defaultOptions; /** - * The maximum number of tokens to generate. + * The {@link OpenAIClient} used to interact with the Azure OpenAI service. */ - private Integer maxTokens; - - private final Logger logger = LoggerFactory.getLogger(getClass()); - private final OpenAIClient openAIClient; public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient) { Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null"); this.openAIClient = microsoftOpenAiClient; + this.defaultOptions = AzureOpenAiChatOptions.builder() + .withModel(DEFAULT_MODEL) + .withTemperature(DEFAULT_TEMPERATURE) + .build(); } - public String getModel() { - return this.model; - } - - public AzureOpenAiChatClient withModel(String model) { - this.model = model; - return this; - } - - public Double getTemperature() { - return this.temperature; - } - - public AzureOpenAiChatClient withTemperature(Double temperature) { - this.temperature = temperature; - return this; - } - - public Double getTopP() { - return topP; - } - - public AzureOpenAiChatClient withTopP(Double topP) { - this.topP = topP; + public AzureOpenAiChatClient withDefaultOptions(AzureOpenAiChatOptions defaultOptions) { + Assert.notNull(defaultOptions, "DefaultOptions must not be null"); + this.defaultOptions = defaultOptions; return this; } - public Integer getMaxTokens() { - 
return maxTokens; - } - - public AzureOpenAiChatClient withMaxTokens(Integer maxTokens) { - this.maxTokens = maxTokens; - return this; - } - - @Override - public String call(String text) { - - ChatRequestMessage azureChatMessage = new ChatRequestUserMessage(text); - - ChatCompletionsOptions options = new ChatCompletionsOptions(List.of(azureChatMessage)); - options.setTemperature(this.getTemperature()); - options.setModel(this.getModel()); - - logger.trace("Azure Chat Message: {}", azureChatMessage); - - ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(this.getModel(), options); - logger.trace("Azure ChatCompletions: {}", chatCompletions); - - StringBuilder stringBuilder = new StringBuilder(); - - for (ChatChoice choice : chatCompletions.getChoices()) { - ChatResponseMessage message = choice.getMessage(); - if (message != null && message.getContent() != null) { - stringBuilder.append(message.getContent()); - } - } - - return stringBuilder.toString(); + public AzureOpenAiChatOptions getDefaultOptions() { + return defaultOptions; } @Override @@ -167,7 +101,7 @@ public ChatResponse call(Prompt prompt) { logger.trace("Azure ChatCompletionsOptions: {}", options); - ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(this.getModel(), options); + ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options); logger.trace("Azure ChatCompletions: {}", chatCompletions); @@ -178,6 +112,7 @@ public ChatResponse call(Prompt prompt) { .toList(); PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions); + return new ChatResponse(generations, AzureOpenAiChatResponseMetadata.from(chatCompletions, promptFilterMetadata)); } @@ -189,7 +124,7 @@ public Flux stream(Prompt prompt) { options.setStream(true); IterableStream chatCompletionsStream = this.openAIClient - .getChatCompletionsStream(this.getModel(), options); + .getChatCompletionsStream(options.getModel(), options); return 
Flux.fromStream(chatCompletionsStream.stream() // Note: the first chat completions can be ignored when using Azure OpenAI @@ -205,7 +140,10 @@ public Flux stream(Prompt prompt) { })); } - private ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { + /** + * Test access. + */ + ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { List azureMessages = prompt.getInstructions() .stream() @@ -214,10 +152,27 @@ private ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages); - options.setTemperature(this.getTemperature()); - options.setModel(this.getModel()); - options.setTopP(this.getTopP()); - options.setMaxTokens(this.getMaxTokens()); + if (this.defaultOptions != null) { + // JSON merge doesn't due to Azure OpenAI service bug: + // https://github.com/Azure/azure-sdk-for-java/issues/38183 + // options = ModelOptionsUtils.merge(options, this.defaultOptions, + // ChatCompletionsOptions.class); + options = merge(options, this.defaultOptions); + } + + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof AzureOpenAiChatOptions runtimeOptions) { + // JSON merge doesn't due to Azure OpenAI service bug: + // https://github.com/Azure/azure-sdk-for-java/issues/38183 + // options = ModelOptionsUtils.merge(runtimeOptions, options, + // ChatCompletionsOptions.class); + options = merge(runtimeOptions, options); + } + else { + throw new IllegalArgumentException("Prompt options are not of type ChatCompletionsOptions:" + + prompt.getOptions().getClass().getSimpleName()); + } + } return options; } @@ -256,4 +211,121 @@ private List nullSafeList(List list) { return list != null ? 
list : Collections.emptyList(); } + // JSON merge doesn't due to Azure OpenAI service bug: + // https://github.com/Azure/azure-sdk-for-java/issues/38183 + private ChatCompletionsOptions merge(ChatCompletionsOptions azureOptions, AzureOpenAiChatOptions springAiOptions) { + + if (springAiOptions == null) { + return azureOptions; + } + + ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(azureOptions.getMessages()); + mergedAzureOptions.setStream(azureOptions.isStream()); + + mergedAzureOptions.setMaxTokens(azureOptions.getMaxTokens()); + if (mergedAzureOptions.getMaxTokens() == null) { + mergedAzureOptions.setMaxTokens(springAiOptions.getMaxTokens()); + } + + mergedAzureOptions.setLogitBias(azureOptions.getLogitBias()); + if (mergedAzureOptions.getLogitBias() == null) { + mergedAzureOptions.setLogitBias(springAiOptions.getLogitBias()); + } + + mergedAzureOptions.setStop(azureOptions.getStop()); + if (mergedAzureOptions.getStop() == null) { + mergedAzureOptions.setStop(springAiOptions.getStop()); + } + + mergedAzureOptions.setTemperature(azureOptions.getTemperature()); + if (mergedAzureOptions.getTemperature() == null && springAiOptions.getTemperature() != null) { + mergedAzureOptions.setTemperature(springAiOptions.getTemperature().doubleValue()); + } + + mergedAzureOptions.setTopP(azureOptions.getTopP()); + if (mergedAzureOptions.getTopP() == null && springAiOptions.getTopP() != null) { + mergedAzureOptions.setTopP(springAiOptions.getTopP().doubleValue()); + } + + mergedAzureOptions.setFrequencyPenalty(azureOptions.getFrequencyPenalty()); + if (mergedAzureOptions.getFrequencyPenalty() == null && springAiOptions.getFrequencyPenalty() != null) { + mergedAzureOptions.setFrequencyPenalty(springAiOptions.getFrequencyPenalty().doubleValue()); + } + + mergedAzureOptions.setPresencePenalty(azureOptions.getPresencePenalty()); + if (mergedAzureOptions.getPresencePenalty() == null && springAiOptions.getPresencePenalty() != null) { + 
mergedAzureOptions.setPresencePenalty(springAiOptions.getPresencePenalty().doubleValue()); + } + + mergedAzureOptions.setN(azureOptions.getN()); + if (mergedAzureOptions.getN() == null) { + mergedAzureOptions.setN(springAiOptions.getN()); + } + + mergedAzureOptions.setUser(azureOptions.getUser()); + if (mergedAzureOptions.getUser() == null) { + mergedAzureOptions.setUser(springAiOptions.getUser()); + } + + mergedAzureOptions.setModel(azureOptions.getModel()); + if (mergedAzureOptions.getModel() == null) { + mergedAzureOptions.setModel(springAiOptions.getModel()); + } + + return mergedAzureOptions; + } + + // JSON merge doesn't due to Azure OpenAI service bug: + // https://github.com/Azure/azure-sdk-for-java/issues/38183 + private ChatCompletionsOptions merge(AzureOpenAiChatOptions springAiOptions, ChatCompletionsOptions azureOptions) { + if (springAiOptions == null) { + return azureOptions; + } + + ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(azureOptions.getMessages()); + mergedAzureOptions.setStream(azureOptions.isStream()); + + if (springAiOptions.getMaxTokens() != null) { + mergedAzureOptions.setMaxTokens(springAiOptions.getMaxTokens()); + } + + if (springAiOptions.getLogitBias() != null) { + mergedAzureOptions.setLogitBias(springAiOptions.getLogitBias()); + } + + if (springAiOptions.getStop() != null) { + mergedAzureOptions.setStop(springAiOptions.getStop()); + } + + if (springAiOptions.getTemperature() != null && springAiOptions.getTemperature() != null) { + mergedAzureOptions.setTemperature(springAiOptions.getTemperature().doubleValue()); + } + + if (springAiOptions.getTopP() != null && springAiOptions.getTopP() != null) { + mergedAzureOptions.setTopP(springAiOptions.getTopP().doubleValue()); + } + + if (springAiOptions.getFrequencyPenalty() != null && springAiOptions.getFrequencyPenalty() != null) { + mergedAzureOptions.setFrequencyPenalty(springAiOptions.getFrequencyPenalty().doubleValue()); + } + + if 
(springAiOptions.getPresencePenalty() != null && springAiOptions.getPresencePenalty() != null) { + mergedAzureOptions.setPresencePenalty(springAiOptions.getPresencePenalty().doubleValue()); + } + + if (springAiOptions.getN() != null) { + mergedAzureOptions.setN(springAiOptions.getN()); + } + + if (springAiOptions.getUser() != null) { + mergedAzureOptions.setUser(springAiOptions.getUser()); + } + + if (springAiOptions.getModel() != null) { + mergedAzureOptions.setModel(springAiOptions.getModel()); + } + + return mergedAzureOptions; + } + } \ No newline at end of file diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java new file mode 100644 index 0000000000..4b1878def8 --- /dev/null +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -0,0 +1,292 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.azure.openai; + +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.chat.ChatOptions; + +/** + * The configuration information for a chat completions request. Completions support a + * wide variety of tasks and generate text that continues from or "completes" provided + * prompt data. + * + * @author Christian Tzolov + */ +@JsonInclude(Include.NON_NULL) +public class AzureOpenAiChatOptions implements ChatOptions { + + /** + * The maximum number of tokens to generate. + */ + @JsonProperty(value = "max_tokens") + private Integer maxTokens; + + /** + * The sampling temperature to use that controls the apparent creativity of generated + * completions. Higher values will make output more random while lower values will + * make results more focused and deterministic. It is not recommended to modify + * temperature and top_p for the same completions request as the interaction of these + * two settings is difficult to predict. + */ + @JsonProperty(value = "temperature") + private Float temperature; + + /** + * An alternative to sampling with temperature called nucleus sampling. This value + * causes the model to consider the results of tokens with the provided probability + * mass. As an example, a value of 0.15 will cause only the tokens comprising the top + * 15% of probability mass to be considered. It is not recommended to modify + * temperature and top_p for the same completions request as the interaction of these + * two settings is difficult to predict. + */ + @JsonProperty(value = "top_p") + private Float topP; + + /** + * A map between GPT token IDs and bias scores that influences the probability of + * specific tokens appearing in a completions response. 
Token IDs are computed via + * external tokenizer tools, while bias scores reside in the range of -100 to 100 with + * minimum and maximum values corresponding to a full ban or exclusive selection of a + * token, respectively. The exact behavior of a given bias score varies by model. + */ + @JsonProperty(value = "logit_bias") + private Map logitBias; + + /** + * An identifier for the caller or end user of the operation. This may be used for + * tracking or rate-limiting purposes. + */ + @JsonProperty(value = "user") + private String user; + + /** + * The number of chat completions choices that should be generated for a chat + * completions response. Because this setting can generate many completions, it may + * quickly consume your token quota. Use carefully and ensure reasonable settings for + * max_tokens and stop. + */ + @JsonProperty(value = "n") + private Integer n; + + /** + * A collection of textual sequences that will end completions generation. + */ + @JsonProperty(value = "stop") + private List stop; + + /** + * A value that influences the probability of generated tokens appearing based on + * their existing presence in generated text. Positive values will make tokens less + * likely to appear when they already exist and increase the model's likelihood to + * output new topics. + */ + @JsonProperty(value = "presence_penalty") + private Double presencePenalty; + + /** + * A value that influences the probability of generated tokens appearing based on + * their cumulative frequency in generated text. Positive values will make tokens less + * likely to appear as their frequency increases and decrease the likelihood of the + * model repeating the same statements verbatim. + */ + @JsonProperty(value = "frequency_penalty") + private Double frequencyPenalty; + + /** + * The model name to provide as part of this completions request. 
Not applicable to + * Azure OpenAI, where deployment information should be included in the Azure resource + * URI that's connected to. + */ + @JsonProperty(value = "model") + private String model; + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + protected AzureOpenAiChatOptions options; + + public Builder() { + this.options = new AzureOpenAiChatOptions(); + } + + public Builder(AzureOpenAiChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withFrequencyPenalty(Float frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty.doubleValue(); + return this; + } + + public Builder withLogitBias(Map logitBias) { + this.options.logitBias = logitBias; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withN(Integer n) { + this.options.n = n; + return this; + } + + public Builder withPresencePenalty(Float presencePenalty) { + this.options.presencePenalty = presencePenalty.doubleValue(); + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withTemperature(Float temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Float topP) { + this.options.topP = topP; + return this; + } + + public Builder withUser(String user) { + this.options.user = user; + return this; + } + + public AzureOpenAiChatOptions build() { + return this.options; + } + + } + + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Map getLogitBias() { + return this.logitBias; + } + + public void setLogitBias(Map logitBias) { + this.logitBias = logitBias; + } + + public String getUser() { + return this.user; + } + + 
public void setUser(String user) { + this.user = user; + } + + public Integer getN() { + return this.n; + } + + public void setN(Integer n) { + this.n = n; + } + + public List getStop() { + return this.stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + public Double getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Float getTemperature() { + return this.temperature; + } + + @Override + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + @Override + public Float getTopP() { + return this.topP; + } + + @Override + public void setTopP(Float topP) { + this.topP = topP; + } + + @Override + @JsonIgnore + public Integer getTopK() { + throw new UnsupportedOperationException("Unimplemented method 'getTopK'"); + } + + @Override + @JsonIgnore + public void setTopK(Integer topK) { + throw new UnsupportedOperationException("Unimplemented method 'setTopK'"); + } + +} diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClient.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClient.java index 6e217f5e5a..76bfc77f19 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClient.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClient.java @@ -18,6 +18,7 @@ import org.springframework.ai.embedding.EmbeddingRequest; import 
org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.util.Assert; public class AzureOpenAiEmbeddingClient extends AbstractEmbeddingClient { @@ -26,52 +27,69 @@ public class AzureOpenAiEmbeddingClient extends AbstractEmbeddingClient { private final OpenAIClient azureOpenAiClient; - private final String model; + private AzureOpenAiEmbeddingOptions defaultOptions = AzureOpenAiEmbeddingOptions.builder() + .withModel("text-embedding-ada-002") + .build(); private final MetadataMode metadataMode; public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient) { - this(azureOpenAiClient, "text-embedding-ada-002"); + this(azureOpenAiClient, MetadataMode.EMBED); } - public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient, String model) { - this(azureOpenAiClient, model, MetadataMode.EMBED); - } - - public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient, String model, MetadataMode metadataMode) { + public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient, MetadataMode metadataMode) { Assert.notNull(azureOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null"); - Assert.notNull(model, "Model must not be null"); Assert.notNull(metadataMode, "Metadata mode must not be null"); this.azureOpenAiClient = azureOpenAiClient; - this.model = model; this.metadataMode = metadataMode; } @Override public List embed(Document document) { logger.debug("Retrieving embeddings"); - Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(this.model, - new EmbeddingsOptions(List.of(document.getFormattedContent(this.metadataMode)))); - logger.debug("Embeddings retrieved"); - return extractEmbeddingsList(embeddings); - } - private List extractEmbeddingsList(Embeddings embeddings) { - return embeddings.getData().stream().map(EmbeddingItem::getEmbedding).flatMap(List::stream).toList(); + EmbeddingResponse 
response = this + .call(new EmbeddingRequest(List.of(document.getFormattedContent(this.metadataMode)), null)); + logger.debug("Embeddings retrieved"); + return response.getResults().stream().map(embedding -> embedding.getOutput()).flatMap(List::stream).toList(); } @Override - public EmbeddingResponse call(EmbeddingRequest request) { + public EmbeddingResponse call(EmbeddingRequest embeddingRequest) { logger.debug("Retrieving embeddings"); - Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(this.model, - new EmbeddingsOptions(request.getInstructions())); + + EmbeddingsOptions azureOptions = new EmbeddingsOptions(embeddingRequest.getInstructions()); + if (this.defaultOptions != null) { + azureOptions = ModelOptionsUtils.merge(azureOptions, this.defaultOptions, EmbeddingsOptions.class); + } + if (embeddingRequest.getOptions() != null) { + azureOptions = ModelOptionsUtils.merge(embeddingRequest.getOptions(), azureOptions, + EmbeddingsOptions.class); + } + Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(azureOptions.getModel(), azureOptions); + logger.debug("Embeddings retrieved"); return generateEmbeddingResponse(embeddings); } + /** + * Test access + */ + EmbeddingsOptions toEmbeddingOptions(EmbeddingRequest embeddingRequest) { + var azureOptions = new EmbeddingsOptions(embeddingRequest.getInstructions()); + if (this.defaultOptions != null) { + azureOptions = ModelOptionsUtils.merge(azureOptions, this.defaultOptions, EmbeddingsOptions.class); + } + if (embeddingRequest.getOptions() != null) { + azureOptions = ModelOptionsUtils.merge(embeddingRequest.getOptions(), azureOptions, + EmbeddingsOptions.class); + } + return azureOptions; + } + private EmbeddingResponse generateEmbeddingResponse(Embeddings embeddings) { List data = generateEmbeddingList(embeddings.getData()); - EmbeddingResponseMetadata metadata = generateMetadata(this.model, embeddings.getUsage()); + EmbeddingResponseMetadata metadata = generateMetadata(embeddings.getUsage()); return 
new EmbeddingResponse(data, metadata); } @@ -86,12 +104,26 @@ private List generateEmbeddingList(List nativeData) { return data; } - private EmbeddingResponseMetadata generateMetadata(String model, EmbeddingsUsage embeddingsUsage) { + private EmbeddingResponseMetadata generateMetadata(EmbeddingsUsage embeddingsUsage) { EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(); - metadata.put("model", model); + // metadata.put("model", model); metadata.put("prompt-tokens", embeddingsUsage.getPromptTokens()); metadata.put("total-tokens", embeddingsUsage.getTotalTokens()); return metadata; } + public AzureOpenAiEmbeddingOptions getDefaultOptions() { + return this.defaultOptions; + } + + public void setDefaultOptions(AzureOpenAiEmbeddingOptions defaultOptions) { + Assert.notNull(defaultOptions, "Default options must not be null"); + this.defaultOptions = defaultOptions; + } + + public AzureOpenAiEmbeddingClient withDefaultOptions(AzureOpenAiEmbeddingOptions options) { + this.defaultOptions = options; + return this; + } + } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java new file mode 100644 index 0000000000..3d89c40ae0 --- /dev/null +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java @@ -0,0 +1,86 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.azure.openai; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.embedding.EmbeddingOptions; + +/** + * The configuration information for the embedding requests. + * + * @author Christian Tzolov + * @since 0.8.0 + */ +public class AzureOpenAiEmbeddingOptions implements EmbeddingOptions { + + /** + * An identifier for the caller or end user of the operation. This may be used for + * tracking or rate-limiting purposes. + */ + @JsonProperty(value = "user") + private String user; + + /** + * The model name to provide as part of this embeddings request. Not applicable to + * Azure OpenAI, where deployment information should be included in the Azure resource + * URI that's connected to. 
+ */ + @JsonProperty(value = "model") + private String model; + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private final AzureOpenAiEmbeddingOptions options = new AzureOpenAiEmbeddingOptions(); + + public Builder withUser(String user) { + this.options.setUser(user); + return this; + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public AzureOpenAiEmbeddingOptions build() { + return this.options; + } + + } + + public String getUser() { + return this.user; + } + + public void setUser(String user) { + this.user = user; + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + +} diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java new file mode 100644 index 0000000000..0d8a5509b9 --- /dev/null +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.azure.openai; + +import com.azure.ai.openai.OpenAIClient; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import org.springframework.ai.chat.prompt.Prompt; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Christian Tzolov + */ +public class AzureChatCompletionsOptionsTests { + + @Test + public void createRequestWithChatOptions() { + + OpenAIClient mockClient = Mockito.mock(OpenAIClient.class); + var client = new AzureOpenAiChatClient(mockClient).withDefaultOptions( + AzureOpenAiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build()); + + var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content")); + + assertThat(requestOptions.getMessages()).hasSize(1); + + assertThat(requestOptions.getModel()).isEqualTo("DEFAULT_MODEL"); + assertThat(requestOptions.getTemperature()).isEqualTo(66.6f); + + requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content", + AzureOpenAiChatOptions.builder().withModel("PROMPT_MODEL").withTemperature(99.9f).build())); + + assertThat(requestOptions.getMessages()).hasSize(1); + + assertThat(requestOptions.getModel()).isEqualTo("PROMPT_MODEL"); + assertThat(requestOptions.getTemperature()).isEqualTo(99.9f); + } + +} diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java new file mode 100644 index 0000000000..0aaa0abcd4 --- /dev/null +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java @@ -0,0 +1,58 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.azure.openai; + +import java.util.List; + +import com.azure.ai.openai.OpenAIClient; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import org.springframework.ai.embedding.EmbeddingRequest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Christian Tzolov + * @since 0.8.0 + */ +public class AzureEmbeddingsOptionsTests { + + @Test + public void createRequestWithChatOptions() { + + OpenAIClient mockClient = Mockito.mock(OpenAIClient.class); + var client = new AzureOpenAiEmbeddingClient(mockClient).withDefaultOptions( + AzureOpenAiEmbeddingOptions.builder().withModel("DEFAULT_MODEL").withUser("USER_TEST").build()); + + var requestOptions = client.toEmbeddingOptions(new EmbeddingRequest(List.of("Test message content"), null)); + + assertThat(requestOptions.getInput()).hasSize(1); + + assertThat(requestOptions.getModel()).isEqualTo("DEFAULT_MODEL"); + assertThat(requestOptions.getUser()).isEqualTo("USER_TEST"); + + requestOptions = client.toEmbeddingOptions(new EmbeddingRequest(List.of("Test message content"), + AzureOpenAiEmbeddingOptions.builder().withModel("PROMPT_MODEL").withUser("PROMPT_USER").build())); + + assertThat(requestOptions.getInput()).hasSize(1); + + assertThat(requestOptions.getModel()).isEqualTo("PROMPT_MODEL"); + assertThat(requestOptions.getUser()).isEqualTo("PROMPT_USER"); + } + +} diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java 
b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java index a65df0c955..0de353e601 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java @@ -180,7 +180,9 @@ public OpenAIClient openAIClient() { @Bean public AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient openAIClient) { - return new AzureOpenAiChatClient(openAIClient).withModel("gpt-35-turbo").withMaxTokens(200); + return new AzureOpenAiChatClient(openAIClient).withDefaultOptions( + AzureOpenAiChatOptions.builder().withModel("gpt-35-turbo").withMaxTokens(200).build()); + } } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClientIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClientIT.java index f23e2e6abb..00f301f72c 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClientIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClientIT.java @@ -60,7 +60,7 @@ public OpenAIClient openAIClient() { @Bean public AzureOpenAiEmbeddingClient azureEmbeddingClient(OpenAIClient openAIClient) { - return new AzureOpenAiEmbeddingClient(openAIClient, "text-embedding-ada-002"); + return new AzureOpenAiEmbeddingClient(openAIClient); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java index 96ddc1e3fb..87f4e509a0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -107,7 +107,11 @@ public 
static Map objectToMap(Object source) { try { String json = OBJECT_MAPPER.writeValueAsString(source); return OBJECT_MAPPER.readValue(json, new TypeReference>() { - }); + }) + .entrySet() + .stream() + .filter(e -> e.getValue() != null) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); } catch (JsonProcessingException e) { throw new RuntimeException(e); diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index aee3ddbbea..b9cec69e60 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -4,7 +4,7 @@ * xref:api/index.adoc[] ** xref:api/chatclient.adoc[] *** xref:api/clients/openai.adoc[] -*** xref:api/clients/azure-openai.adoc[] +*** xref:api/clients/azure-openai-chat.adoc[] *** xref:api/clients/bedrock.adoc[] *** xref:api/clients/huggingface.adoc[] *** xref:api/clients/ollama.adoc[] @@ -13,6 +13,7 @@ ** xref:api/etl-pipeline.adoc[] ** xref:api/embeddings.adoc[] *** xref:api/embeddings/onnx.adoc[] +*** xref:api/embeddings/azure-openai-embeddings.adoc[] ** xref:api/vectordbs.adoc[] *** xref:api/vectordbs/azure.adoc[] *** xref:api/vectordbs/chroma.adoc[] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/azure-openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/azure-openai-chat.adoc new file mode 100644 index 0000000000..41a8c47767 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/azure-openai-chat.adoc @@ -0,0 +1,176 @@ += Azure OpenAI Chat + +Azure's OpenAI offering, powered by ChatGPT, extends beyond traditional OpenAI capabilities, delivering AI-driven text generation with enhanced functionality. 
Azure offers additional AI safety and responsible AI features, as highlighted in their recent update https://techcommunity.microsoft.com/t5/ai-azure-ai-services-blog/announcing-new-ai-safety-amp-responsible-ai-features-in-azure/ba-p/3983686[here].
+
+Azure offers Java developers the opportunity to leverage AI's full potential by integrating it with an array of Azure services, which includes AI-related resources such as Vector Stores on Azure.
+
+== Getting Started
+
+Obtain your Azure OpenAI `endpoint` and `api-key` from the Azure OpenAI Service section on the link:https://portal.azure.com[Azure Portal].
+
+=== Configure the Azure OpenAI Chat Client Manually
+
+Add the `spring-ai-azure-openai` dependency to your project's Maven `pom.xml` file:
+[source, xml]
+----
+<dependency>
+    <groupId>org.springframework.ai</groupId>
+    <artifactId>spring-ai-azure-openai</artifactId>
+    <version>0.8.0-SNAPSHOT</version>
+</dependency>
+----
+
+or to your Gradle `build.gradle` build file.
+
+[source,gradle]
+----
+dependencies {
+    implementation 'org.springframework.ai:spring-ai-azure-openai:0.8.0-SNAPSHOT'
+}
+----
+
+NOTE: The `spring-ai-azure-openai` dependency also provides access to the `AzureOpenAiChatClient`. For more information about the `AzureOpenAiChatClient` refer to the link:../clients/azure-openai-chat.html[Azure OpenAI Chat] section.
+
+Next, create an `AzureOpenAiChatClient` instance and use it to generate text responses:
+
+[source,java]
+----
+var openAIClient = new OpenAIClientBuilder()
+    .credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
+    .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
+    .buildClient();
+
+var chatClient = new AzureOpenAiChatClient(openAIClient).withDefaultOptions(
+    AzureOpenAiChatOptions.builder()
+        .withModel("gpt-35-turbo")
+        .withTemperature(0.4f)
+        .withMaxTokens(200)
+        .build());
+
+ChatResponse response = chatClient.call(
+    new Prompt("Generate the names of 5 famous pirates."));
+
+// Or with streaming responses
+Flux<ChatResponse> response = chatClient.stream(
+    new Prompt("Generate the names of 5 famous pirates."));
+
+----
+
+NOTE: the `gpt-35-turbo` is actually the `Deployment Name` as presented in the Azure AI Portal.
+
+The `AzureOpenAiChatOptions` provides the configuration information for the chat requests.
+The `AzureOpenAiChatOptions` offers a builder to create the options.
+
+At start time, use `AzureOpenAiChatClient#withDefaultOptions()` to configure the default options used for all chat requests.
+Furthermore, at runtime, you can override the default options by passing an `AzureOpenAiChatOptions` instance with your custom options to the `Prompt` request.
+
+For example, to override the default model name for a specific request:
+
+[source,java]
+----
+ChatResponse response = chatClient.call(
+    new Prompt(
+        "Generate the names of 5 famous pirates.",
+        AzureOpenAiChatOptions.builder().withModel("gpt-4-32k").build()
+    ));
+----
+
+=== Spring Boot Auto-configuration
+
+Spring AI provides Spring Boot auto-configuration for the Azure OpenAI Chat Client.
+To enable it add the following dependency to your project's Maven `pom.xml` file:
+
+[source, xml]
+----
+<dependency>
+    <groupId>org.springframework.ai</groupId>
+    <artifactId>spring-ai-azure-openai-spring-boot-starter</artifactId>
+    <version>0.8.0-SNAPSHOT</version>
+</dependency>
+----
+
+or to your Gradle `build.gradle` build file.
+ +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-azure-openai-spring-boot-starter:0.8.0-SNAPSHOT' +} +---- + +Spring AI defines a configuration property named `spring.ai.azure.openai.api-key` that you should set to the value of the `API Key` obtained from Azure. +There is also a configuration property named `spring.ai.azure.openai.endpoint` that you should set to the endpoint URL obtained when provisioning your model in Azure. + +Exporting environment variables is one way to set these configuration properties: + +[source,shell] +---- +export SPRING_AI_AZURE_OPENAI_API_KEY= +export SPRING_AI_AZURE_OPENAI_ENDPOINT= +---- + +The `spring.ai.azure.openai.chat.options.*` properties are used to configure the default options used for all chat requests. + +==== Sample Code + +This will create a `ChatClient` implementation that you can inject into your class. +Here is an example of a simple `@Controller` class that uses the `ChatClient` implementation. + +[source,application.properties] +---- +spring.ai.azure.openai.api-key=YOUR_API_KEY +spring.ai.azure.openai.endpoint=YOUR_ENDPOINT +spring.ai.azure.openai.chat.options.model=gpt-35-turbo +spring.ai.azure.openai.chat.options.temperature=0.7 +---- + +[source,java] +---- +@RestController +public class ChatController { + + private final ChatClient chatClient; + + @Autowired + public ChatController(ChatClient chatClient) { + this.chatClient = chatClient; + } + + @GetMapping("/ai/generate") + public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + return Map.of("generation", chatClient.generate(message)); + } +} +---- + +== Azure OpenAI Chat Properties + +The prefix `spring.ai.azure.openai` is the property prefix to configure the connection to Azure OpenAI. 
+ +[cols="3,5,3"] +|==== +| Property | Description | Default + +| spring.ai.azure.openai.api-key | The Key from Azure AI OpenAI `Keys and Endpoint` section under `Resource Management` | - +| spring.ai.azure.openai.endpoint | The endpoint from the Azure AI OpenAI `Keys and Endpoint` section under `Resource Management` | - +|==== + + +The prefix `spring.ai.azure.openai.chat` is the property prefix that configures the `ChatClient` implementation for Azure OpenAI. + +[cols="3,5,3"] +|==== +| Property | Description | Default + +| spring.ai.azure.openai.chat.options.model | * The model name to provide as part of this completions request. Not applicable to Azure OpenAI, where deployment information should be included in the Azure resource URI that's connected to. + | gpt-35-turbo +| spring.ai.azure.openai.chat.options.maxTokens | The maximum number of tokens to generate. | - +| spring.ai.azure.openai.chat.options.temperature | The sampling temperature to use that controls the apparent creativity of generated completions. Higher values will make output more random while lower values will make results more focused and deterministic. It is not recommended to modify temperature and top_p for the same completions request as the interaction of these two settings is difficult to predict. | 0.7 +| spring.ai.azure.openai.chat.options.topP | An alternative to sampling with temperature called nucleus sampling. This value causes the model to consider the results of tokens with the provided probability mass. | - +| spring.ai.azure.openai.chat.options.logitBias | A map between GPT token IDs and bias scores that influences the probability of specific tokens appearing in a completions response. Token IDs are computed via external tokenizer tools, while bias scores reside in the range of -100 to 100 with minimum and maximum values corresponding to a full ban or exclusive selection of a token, respectively. The exact behavior of a given bias score varies by model. 
| - +| spring.ai.azure.openai.chat.options.user | An identifier for the caller or end user of the operation. This may be used for tracking or rate-limiting purposes. | - +| spring.ai.azure.openai.chat.options.n | The number of chat completions choices that should be generated for a chat completions response. | - +| spring.ai.azure.openai.chat.options.stop | A collection of textual sequences that will end completions generation. | - +| spring.ai.azure.openai.chat.options.presencePenalty | A value that influences the probability of generated tokens appearing based on their existing presence in generated text. Positive values will make tokens less likely to appear when they already exist and increase the model's likelihood to output new topics. | - +| spring.ai.azure.openai.chat.options.frequencyPenalty | A value that influences the probability of generated tokens appearing based on their cumulative frequency in generated text. Positive values will make tokens less likely to appear as their frequency increases and decrease the likelihood of the model repeating the same statements verbatim. | - +|==== diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/azure-openai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/azure-openai-embeddings.adoc new file mode 100644 index 0000000000..626757dbca --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/azure-openai-embeddings.adoc @@ -0,0 +1,167 @@ += Azure OpenAI Embeddings + +Azure's OpenAI extends the OpenAI capabilities, offering safe text generation and Embeddings computation models for various task: + +- Similarity embeddings are good at capturing semantic similarity between two or more pieces of text. +- Text search embeddings help measure whether long documents are relevant to a short query. +- Code search embeddings are useful for embedding code snippets and embedding natural language search queries. 
+
+The Azure OpenAI embeddings rely on `cosine similarity` to compute similarity between documents and a query.
+
+== Getting Started
+
+Obtain your Azure OpenAI `endpoint` and `api-key` from the Azure OpenAI Service section on the link:https://portal.azure.com[Azure Portal].
+
+
+=== Configure the Azure OpenAI Embedding Client Manually
+
+Add the `spring-ai-azure-openai` dependency to your project's Maven `pom.xml` file:
+[source, xml]
+----
+<dependency>
+    <groupId>org.springframework.ai</groupId>
+    <artifactId>spring-ai-azure-openai</artifactId>
+    <version>0.8.0-SNAPSHOT</version>
+</dependency>
+----
+
+or to your Gradle `build.gradle` build file.
+
+[source,gradle]
+----
+dependencies {
+    implementation 'org.springframework.ai:spring-ai-azure-openai:0.8.0-SNAPSHOT'
+}
+----
+
+NOTE: The `spring-ai-azure-openai` dependency also provides access to the `AzureOpenAiEmbeddingClient`. For more information about the `AzureOpenAiEmbeddingClient` refer to the link:../embeddings/azure-openai-embeddings.html[Azure OpenAI Embeddings] section.
+
+Next, create an `AzureOpenAiEmbeddingClient` instance and use it to compute the similarity between two input texts:
+
+[source,java]
+----
+var openAIClient = new OpenAIClientBuilder()
+    .credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
+    .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
+    .buildClient();
+
+var embeddingClient = new AzureOpenAiEmbeddingClient(openAIClient)
+    .withDefaultOptions(AzureOpenAiEmbeddingOptions.builder()
+        .withModel("text-embedding-ada-002")
+        .withUser("user-6")
+        .build());
+
+EmbeddingResponse embeddingResponse = embeddingClient
+    .embedForResponse(List.of("Hello World", "World is big and salvation is near"));
+----
+
+NOTE: the `text-embedding-ada-002` is actually the `Deployment Name` as presented in the Azure AI Portal.
+
+The `AzureOpenAiEmbeddingOptions` provides the configuration information for the embedding requests.
+The `AzureOpenAiEmbeddingOptions` offers a builder to create the options.
+ +At start time use the `AzureOpenAiEmbeddingClient#withDefaultOptions()` to configure the default options used for all embedding requests. +Furthermore you can override the default options, at runtime, by passing a `AzureOpenAiEmbeddingOptions` instance with your to the `EmbeddingRequest` request. + +For example to override the default model name for a specific request: + +[source,java] +---- +EmbeddingResponse embeddingResponse = embeddingClient.call( + new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"), + AzureOpenAiEmbeddingOptions.builder() + .withModel("Different-Embedding-Model-Deployment-Name") + .build())); +---- + +=== Spring Boot Auto-configuration + +Spring AI provides Spring Boot auto-configuration for the Azure OpenAI Embedding Client. +To enable it add the following dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-azure-openai-spring-boot-starter + 0.8.0-SNAPSHOT + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-azure-openai-spring-boot-starter:0.8.0-SNAPSHOT' +} +---- + +Spring AI defines a configuration property named `spring.ai.azure.openai.api-key` that you should set to the value of the `API Key` obtained from Azure. +There is also a configuration property named `spring.ai.azure.openai.endpoint` that you should set to the endpoint URL obtained when provisioning your model in Azure. + +Exporting environment variables is one way to set these configuration properties: + +[source,shell] +---- +export SPRING_AI_AZURE_OPENAI_API_KEY= +export SPRING_AI_AZURE_OPENAI_ENDPOINT= +---- + +The `spring.ai.azure.openai.embedding.options.*` properties are used to configure the default options used for all embedding requests. + +==== Sample Code + +This will create a `EmbeddingClient` implementation that you can inject into your class. 
+Here is an example of a simple `@Controller` class that uses the `EmbeddingClient` implementation.
+
+[source,application.properties]
+----
+spring.ai.azure.openai.api-key=YOUR_API_KEY
+spring.ai.azure.openai.endpoint=YOUR_ENDPOINT
+spring.ai.azure.openai.embedding.options.model=text-embedding-ada-002
+----
+
+[source,java]
+----
+@RestController
+public class EmbeddingController {
+
+    private final EmbeddingClient embeddingClient;
+
+    @Autowired
+    public EmbeddingController(EmbeddingClient embeddingClient) {
+        this.embeddingClient = embeddingClient;
+    }
+
+    @GetMapping("/ai/embedding")
+    public Map<String, EmbeddingResponse> embed(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
+        EmbeddingResponse embeddingResponse = this.embeddingClient.embedForResponse(List.of(message));
+        return Map.of("embedding", embeddingResponse);
+    }
+}
+----
+
+
+== Azure OpenAI Embedding Properties
+
+The prefix `spring.ai.azure.openai` is the property prefix to configure the connection to Azure OpenAI.
+
+[cols="3,5,3"]
+|====
+| Property | Description | Default
+
+| spring.ai.azure.openai.api-key | The Key from Azure AI OpenAI `Keys and Endpoint` section under `Resource Management` | -
+| spring.ai.azure.openai.endpoint | The endpoint from the Azure AI OpenAI `Keys and Endpoint` section under `Resource Management` | -
+|====
+
+
+The prefix `spring.ai.azure.openai.embedding` is the property prefix that configures the `EmbeddingClient` implementation for Azure OpenAI.
+
+[cols="3,5,3"]
+|====
+| Property | Description | Default
+
+| spring.ai.azure.openai.embedding.options.model | This is the value of the 'Deployment Name' as presented in the Azure AI Portal | text-embedding-ada-002
+| spring.ai.azure.openai.embedding.options.user | An identifier for the caller or end user of the operation. This may be used for tracking or rate-limiting purposes.
| - +|==== diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java index f2c2979021..67766126c1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java @@ -52,10 +52,7 @@ public AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient openAIClient, AzureOpenAiChatProperties chatProperties) { AzureOpenAiChatClient azureOpenAiChatClient = new AzureOpenAiChatClient(openAIClient) - .withModel(chatProperties.getModel()) - .withTemperature(chatProperties.getTemperature()) - .withMaxTokens(chatProperties.getMaxTokens()) - .withTopP(chatProperties.getTopP()); + .withDefaultOptions(chatProperties.getOptions()); return azureOpenAiChatClient; } @@ -63,7 +60,7 @@ public AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient openAIClient, @Bean public AzureOpenAiEmbeddingClient azureOpenAiEmbeddingClient(OpenAIClient openAIClient, AzureOpenAiEmbeddingProperties embeddingProperties) { - return new AzureOpenAiEmbeddingClient(openAIClient, embeddingProperties.getModel()); + return new AzureOpenAiEmbeddingClient(openAIClient).withDefaultOptions(embeddingProperties.getOptions()); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java index fadcafe1b7..310f3f0fa8 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java +++ 
b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,72 +16,31 @@ package org.springframework.ai.autoconfigure.azure.openai; +import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; @ConfigurationProperties(AzureOpenAiChatProperties.CONFIG_PREFIX) public class AzureOpenAiChatProperties { public static final String CONFIG_PREFIX = "spring.ai.azure.openai.chat"; - /** - * The sampling temperature to use that controls the apparent creativity of generated - * completions. Higher values will make output more random while lower values will - * make results more focused and deterministic. It is not recommended to modify - * temperature and top_p for the same completions request as the interaction of these - * two settings is difficult to predict. - */ - private Double temperature = 0.7; + public static final String DEFAULT_CHAT_MODEL = "gpt-3.5-turbo"; - /** - * An alternative to sampling with temperature called nucleus sampling. This value - * causes the generative to consider the results of tokens with the provided - * probability mass. As an example, a value of 0.15 will cause only the tokens - * comprising the top 15% of probability mass to be considered. It is not recommended - * to modify temperature and top_p for the same completions request as the interaction - * of these two settings is difficult to predict. - */ - private Double topP; + private static final Double DEFAULT_TEMPERATURE = 0.7; - /** - * Creates an instance of ChatCompletionsOptions class. 
- */ - private String model = "gpt-35-turbo"; + @NestedConfigurationProperty + private AzureOpenAiChatOptions options = AzureOpenAiChatOptions.builder() + .withModel(DEFAULT_CHAT_MODEL) + .withTemperature(DEFAULT_TEMPERATURE.floatValue()) + .build(); - /** - * The maximum number of tokens to generate. - */ - private Integer maxTokens; - - public Double getTemperature() { - return temperature; - } - - public void setTemperature(Double temperature) { - this.temperature = temperature; - } - - public String getModel() { - return model; - } - - public void setModel(String model) { - this.model = model; - } - - public Double getTopP() { - return topP; - } - - public void setTopP(Double topP) { - this.topP = topP; - } - - public Integer getMaxTokens() { - return maxTokens; + public AzureOpenAiChatOptions getOptions() { + return this.options; } - public void setMaxTokens(Integer maxTokens) { - this.maxTokens = maxTokens; + public void setOptions(AzureOpenAiChatOptions options) { + this.options = options; } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java index c7ddee8537..2edba84ee1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java @@ -16,24 +16,26 @@ package org.springframework.ai.autoconfigure.azure.openai; +import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingOptions; import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.util.Assert; @ConfigurationProperties(AzureOpenAiEmbeddingProperties.CONFIG_PREFIX) public class 
AzureOpenAiEmbeddingProperties { public static final String CONFIG_PREFIX = "spring.ai.azure.openai.embedding"; - /** - * The text embedding generative to use for the embedding client. - */ - private String model = "text-embedding-ada-002"; + private AzureOpenAiEmbeddingOptions options = AzureOpenAiEmbeddingOptions.builder() + .withModel("text-embedding-ada-002") + .build(); - public String getModel() { - return model; + public AzureOpenAiEmbeddingOptions getOptions() { + return options; } - public void setModel(String model) { - this.model = model; + public void setOptions(AzureOpenAiEmbeddingOptions options) { + Assert.notNull(options, "Options must not be null"); + this.options = options; } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index 7c28e39ea6..f488fbefce 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -36,6 +36,9 @@ @EnableConfigurationProperties({ OpenAiConnectionProperties.class, OpenAiChatProperties.class, OpenAiEmbeddingProperties.class }) @ImportRuntimeHints(NativeHints.class) +/** + * @author Christian Tzolov + */ public class OpenAiAutoConfiguration { @Bean diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java index 5a77074b8e..c439b03212 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -34,48 +34,66 @@ public class AzureOpenAiAutoConfigurationPropertyTests { @Test - public void chatPropertiesTest() { + public void embeddingPropertiesTest() { - new ApplicationContextRunner().withPropertyValues( - // @formatter:off - "spring.ai.azure.openai.api-key=TEST_API_KEY", - "spring.ai.azure.openai.endpoint=TEST_ENDPOINT", - "spring.ai.azure.openai.chat.model=MODEL_XYZ", - "spring.ai.azure.openai.chat.temperature=0.55", - "spring.ai.azure.openai.chat.topP=0.56", - "spring.ai.azure.openai.chat.maxTokens=123") - // @formatter:on + new ApplicationContextRunner() + .withPropertyValues("spring.ai.azure.openai.api-key=TEST_API_KEY", + "spring.ai.azure.openai.endpoint=TEST_ENDPOINT", + "spring.ai.azure.openai.embedding.options.model=MODEL_XYZ") .withConfiguration(AutoConfigurations.of(AzureOpenAiAutoConfiguration.class)) .run(context -> { - var chatProperties = context.getBean(AzureOpenAiChatProperties.class); + var chatProperties = context.getBean(AzureOpenAiEmbeddingProperties.class); var connectionProperties = context.getBean(AzureOpenAiConnectionProperties.class); assertThat(connectionProperties.getApiKey()).isEqualTo("TEST_API_KEY"); assertThat(connectionProperties.getEndpoint()).isEqualTo("TEST_ENDPOINT"); - assertThat(chatProperties.getModel()).isEqualTo("MODEL_XYZ"); - - assertThat(chatProperties.getTemperature()).isEqualTo(0.55); - assertThat(chatProperties.getTopP()).isEqualTo(0.56); - assertThat(chatProperties.getMaxTokens()).isEqualTo(123); + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); }); } @Test - public void embeddingPropertiesTest() { + public void chatPropertiesTest() { - new ApplicationContextRunner() - .withPropertyValues("spring.ai.azure.openai.api-key=TEST_API_KEY", - "spring.ai.azure.openai.endpoint=TEST_ENDPOINT", "spring.ai.azure.openai.embedding.model=MODEL_XYZ") + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.azure.openai.api-key=API_KEY", + 
"spring.ai.azure.openai.endpoint=ENDPOINT", + + "spring.ai.azure.openai.chat.options.model=MODEL_XYZ", + "spring.ai.azure.openai.chat.options.frequencyPenalty=-1.5", + "spring.ai.azure.openai.chat.options.logitBias.myTokenId=-5", + "spring.ai.azure.openai.chat.options.maxTokens=123", + "spring.ai.azure.openai.chat.options.n=10", + "spring.ai.azure.openai.chat.options.presencePenalty=0", + "spring.ai.azure.openai.chat.options.stop=boza,koza", + "spring.ai.azure.openai.chat.options.temperature=0.55", + "spring.ai.azure.openai.chat.options.topP=0.56", + "spring.ai.azure.openai.chat.options.user=userXYZ" + ) + // @formatter:on .withConfiguration(AutoConfigurations.of(AzureOpenAiAutoConfiguration.class)) .run(context -> { - var chatProperties = context.getBean(AzureOpenAiEmbeddingProperties.class); + var chatProperties = context.getBean(AzureOpenAiChatProperties.class); var connectionProperties = context.getBean(AzureOpenAiConnectionProperties.class); + var embeddingProperties = context.getBean(AzureOpenAiEmbeddingProperties.class); - assertThat(connectionProperties.getApiKey()).isEqualTo("TEST_API_KEY"); - assertThat(connectionProperties.getEndpoint()).isEqualTo("TEST_ENDPOINT"); + assertThat(connectionProperties.getEndpoint()).isEqualTo("ENDPOINT"); + assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("text-embedding-ada-002"); + + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getFrequencyPenalty()).isEqualTo(-1.5f); + assertThat(chatProperties.getOptions().getLogitBias().get("myTokenId")).isEqualTo(-5); + assertThat(chatProperties.getOptions().getMaxTokens()).isEqualTo(123); + assertThat(chatProperties.getOptions().getN()).isEqualTo(10); + assertThat(chatProperties.getOptions().getPresencePenalty()).isEqualTo(0); + assertThat(chatProperties.getOptions().getStop()).contains("boza", "koza"); + 
assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); + assertThat(chatProperties.getOptions().getTopP()).isEqualTo(0.56f); - assertThat(chatProperties.getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getUser()).isEqualTo("userXYZ"); }); }