Skip to content

Commit

Permalink
Add Azure OpenAI Chat and Embedding Options
Browse files Browse the repository at this point in the history
 - Add AzureOpenAiChatOptions
   - Add default options field to AzureOpenAiChatClient.
   - Impl runtime (e.g. prompt) and default options on call.
   - Add options field to the AzureOpenAiChatProperties.
 - Add AzureOpenAiEmbeddingOptions
   - Add default options field to AzureOpenAiEmbeddingClient.
   - Implement runtime and default option merging on embedding request.
   - Add options field to AzureOpenAiEmbeddingProperties.
 - Add Unit and ITs.
 - Split the azure-openai.adoc into ./clients/azure-openai-chat.adoc and ./embeddings/azure-openai-embeddings.adoc.
   - Provide detailed explanation how to use the chat and embedding clients manually or via the auto-configuration.
  • Loading branch information
tzolov committed Jan 30, 2024
1 parent 5b4784f commit 4b88b5c
Show file tree
Hide file tree
Showing 17 changed files with 1,136 additions and 212 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,23 +27,22 @@
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.ChatResponseMessage;
import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
import com.azure.core.util.IterableStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import reactor.core.publisher.Flux;

import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.messages.Message;
import org.springframework.util.Assert;

/**
Expand All @@ -59,104 +58,39 @@
*/
public class AzureOpenAiChatClient implements ChatClient, StreamingChatClient {

/**
* The sampling temperature to use that controls the apparent creativity of generated
* completions. Higher values will make output more random while lower values will
* make results more focused and deterministic. It is not recommended to modify
* temperature and top_p for the same completions request as the interaction of these
* two settings is difficult to predict.
*/
private Double temperature = 0.7;
private static final String DEFAULT_MODEL = "gpt-35-turbo";

/**
* An alternative to sampling with temperature called nucleus sampling. This value
* causes the model to consider the results of tokens with the provided probability
* mass. As an example, a value of 0.15 will cause only the tokens comprising the top
* 15% of probability mass to be considered. It is not recommended to modify
* temperature and top_p for the same completions request as the interaction of these
* two settings is difficult to predict.
*/
private Double topP;
private static final Float DEFAULT_TEMPERATURE = 0.7f;

private final Logger logger = LoggerFactory.getLogger(getClass());

/**
* Creates an instance of ChatCompletionsOptions class.
* The configuration information for a chat completions request.
*/
private String model = "gpt-35-turbo";
private AzureOpenAiChatOptions defaultOptions;

/**
* The maximum number of tokens to generate.
* The {@link OpenAIClient} used to interact with the Azure OpenAI service.
*/
private Integer maxTokens;

private final Logger logger = LoggerFactory.getLogger(getClass());

private final OpenAIClient openAIClient;

/**
 * Creates a chat client backed by the given Azure SDK {@link OpenAIClient}.
 * <p>
 * The client starts with built-in default options (model {@code gpt-35-turbo},
 * temperature {@code 0.7}); these can be replaced via
 * {@code withDefaultOptions(AzureOpenAiChatOptions)}.
 * @param microsoftOpenAiClient the underlying Azure OpenAI client; must not be null
 */
public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient) {
Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
this.openAIClient = microsoftOpenAiClient;
this.defaultOptions = AzureOpenAiChatOptions.builder()
.withModel(DEFAULT_MODEL)
.withTemperature(DEFAULT_TEMPERATURE)
.build();
}

public String getModel() {
return this.model;
}

public AzureOpenAiChatClient withModel(String model) {
this.model = model;
return this;
}

public Double getTemperature() {
return this.temperature;
}

public AzureOpenAiChatClient withTemperature(Double temperature) {
this.temperature = temperature;
return this;
}

public Double getTopP() {
return topP;
}

public AzureOpenAiChatClient withTopP(Double topP) {
this.topP = topP;
public AzureOpenAiChatClient withDefaultOptions(AzureOpenAiChatOptions defaultOptions) {
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
this.defaultOptions = defaultOptions;
return this;
}

public Integer getMaxTokens() {
return maxTokens;
}

public AzureOpenAiChatClient withMaxTokens(Integer maxTokens) {
this.maxTokens = maxTokens;
return this;
}

@Override
public String call(String text) {

ChatRequestMessage azureChatMessage = new ChatRequestUserMessage(text);

ChatCompletionsOptions options = new ChatCompletionsOptions(List.of(azureChatMessage));
options.setTemperature(this.getTemperature());
options.setModel(this.getModel());

logger.trace("Azure Chat Message: {}", azureChatMessage);

ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(this.getModel(), options);
logger.trace("Azure ChatCompletions: {}", chatCompletions);

StringBuilder stringBuilder = new StringBuilder();

for (ChatChoice choice : chatCompletions.getChoices()) {
ChatResponseMessage message = choice.getMessage();
if (message != null && message.getContent() != null) {
stringBuilder.append(message.getContent());
}
}

return stringBuilder.toString();
public AzureOpenAiChatOptions getDefaultOptions() {
return defaultOptions;
}

@Override
Expand All @@ -167,7 +101,7 @@ public ChatResponse call(Prompt prompt) {

logger.trace("Azure ChatCompletionsOptions: {}", options);

ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(this.getModel(), options);
ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);

logger.trace("Azure ChatCompletions: {}", chatCompletions);

Expand All @@ -178,6 +112,7 @@ public ChatResponse call(Prompt prompt) {
.toList();

PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);

return new ChatResponse(generations,
AzureOpenAiChatResponseMetadata.from(chatCompletions, promptFilterMetadata));
}
Expand All @@ -189,7 +124,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
options.setStream(true);

IterableStream<ChatCompletions> chatCompletionsStream = this.openAIClient
.getChatCompletionsStream(this.getModel(), options);
.getChatCompletionsStream(options.getModel(), options);

return Flux.fromStream(chatCompletionsStream.stream()
// Note: the first chat completions can be ignored when using Azure OpenAI
Expand All @@ -205,7 +140,10 @@ public Flux<ChatResponse> stream(Prompt prompt) {
}));
}

private ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
/**
* Test access.
*/
ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {

List<ChatRequestMessage> azureMessages = prompt.getInstructions()
.stream()
Expand All @@ -214,10 +152,27 @@ private ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {

ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages);

options.setTemperature(this.getTemperature());
options.setModel(this.getModel());
options.setTopP(this.getTopP());
options.setMaxTokens(this.getMaxTokens());
if (this.defaultOptions != null) {
// JSON merge doesn't work due to an Azure OpenAI service bug:
// https://github.com/Azure/azure-sdk-for-java/issues/38183
// options = ModelOptionsUtils.merge(options, this.defaultOptions,
// ChatCompletionsOptions.class);
options = merge(options, this.defaultOptions);
}

if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof AzureOpenAiChatOptions runtimeOptions) {
// JSON merge doesn't work due to an Azure OpenAI service bug:
// https://github.com/Azure/azure-sdk-for-java/issues/38183
// options = ModelOptionsUtils.merge(runtimeOptions, options,
// ChatCompletionsOptions.class);
options = merge(runtimeOptions, options);
}
else {
throw new IllegalArgumentException("Prompt options are not of type ChatCompletionsOptions:"
+ prompt.getOptions().getClass().getSimpleName());
}
}

return options;
}
Expand Down Expand Up @@ -256,4 +211,121 @@ private <T> List<T> nullSafeList(List<T> list) {
return list != null ? list : Collections.emptyList();
}

// Hand-rolled merge: a plain JSON merge doesn't work due to an Azure OpenAI
// service bug, see https://github.com/Azure/azure-sdk-for-java/issues/38183
private ChatCompletionsOptions merge(ChatCompletionsOptions azureOptions, AzureOpenAiChatOptions springAiOptions) {

if (springAiOptions == null) {
return azureOptions;
}

ChatCompletionsOptions merged = new ChatCompletionsOptions(azureOptions.getMessages());
merged.setStream(azureOptions.isStream());

// For every field below the Azure-side value wins; the Spring AI options only
// fill in fields the Azure options left unset.
merged.setMaxTokens(azureOptions.getMaxTokens() != null ? azureOptions.getMaxTokens() : springAiOptions.getMaxTokens());

merged.setLogitBias(azureOptions.getLogitBias() != null ? azureOptions.getLogitBias() : springAiOptions.getLogitBias());

merged.setStop(azureOptions.getStop() != null ? azureOptions.getStop() : springAiOptions.getStop());

merged.setN(azureOptions.getN() != null ? azureOptions.getN() : springAiOptions.getN());

merged.setUser(azureOptions.getUser() != null ? azureOptions.getUser() : springAiOptions.getUser());

merged.setModel(azureOptions.getModel() != null ? azureOptions.getModel() : springAiOptions.getModel());

// The four numeric fields need a Float -> Double widening when they come from
// the Spring AI side, so they use explicit branches instead of ternaries.
if (azureOptions.getTemperature() != null) {
merged.setTemperature(azureOptions.getTemperature());
}
else if (springAiOptions.getTemperature() != null) {
merged.setTemperature(springAiOptions.getTemperature().doubleValue());
}

if (azureOptions.getTopP() != null) {
merged.setTopP(azureOptions.getTopP());
}
else if (springAiOptions.getTopP() != null) {
merged.setTopP(springAiOptions.getTopP().doubleValue());
}

if (azureOptions.getFrequencyPenalty() != null) {
merged.setFrequencyPenalty(azureOptions.getFrequencyPenalty());
}
else if (springAiOptions.getFrequencyPenalty() != null) {
merged.setFrequencyPenalty(springAiOptions.getFrequencyPenalty().doubleValue());
}

if (azureOptions.getPresencePenalty() != null) {
merged.setPresencePenalty(azureOptions.getPresencePenalty());
}
else if (springAiOptions.getPresencePenalty() != null) {
merged.setPresencePenalty(springAiOptions.getPresencePenalty().doubleValue());
}

return merged;
}

// Hand-rolled merge: a plain JSON merge doesn't work due to an Azure OpenAI
// service bug, see https://github.com/Azure/azure-sdk-for-java/issues/38183
/**
 * Merges runtime (per-prompt) Spring AI options on top of an already-built Azure
 * request: every non-null runtime value overrides the Azure-side value, while
 * fields the runtime options leave unset retain the Azure-side value.
 * @param springAiOptions the runtime options; may be null, in which case
 * {@code azureOptions} is returned unchanged
 * @param azureOptions the base Azure request options (typically already merged
 * with the client defaults); its messages and stream flag are always carried over
 * @return a new {@link ChatCompletionsOptions} with the merged values
 */
private ChatCompletionsOptions merge(AzureOpenAiChatOptions springAiOptions, ChatCompletionsOptions azureOptions) {
if (springAiOptions == null) {
return azureOptions;
}

ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(azureOptions.getMessages());
mergedAzureOptions.setStream(azureOptions.isStream());

// Seed with the Azure-side values first. Without this step, any field the
// runtime options do not set would be silently dropped from the request.
mergedAzureOptions.setMaxTokens(azureOptions.getMaxTokens());
mergedAzureOptions.setLogitBias(azureOptions.getLogitBias());
mergedAzureOptions.setStop(azureOptions.getStop());
mergedAzureOptions.setTemperature(azureOptions.getTemperature());
mergedAzureOptions.setTopP(azureOptions.getTopP());
mergedAzureOptions.setFrequencyPenalty(azureOptions.getFrequencyPenalty());
mergedAzureOptions.setPresencePenalty(azureOptions.getPresencePenalty());
mergedAzureOptions.setN(azureOptions.getN());
mergedAzureOptions.setUser(azureOptions.getUser());
mergedAzureOptions.setModel(azureOptions.getModel());

// Now overlay the non-null runtime values (Float fields are widened to Double).
if (springAiOptions.getMaxTokens() != null) {
mergedAzureOptions.setMaxTokens(springAiOptions.getMaxTokens());
}

if (springAiOptions.getLogitBias() != null) {
mergedAzureOptions.setLogitBias(springAiOptions.getLogitBias());
}

if (springAiOptions.getStop() != null) {
mergedAzureOptions.setStop(springAiOptions.getStop());
}

if (springAiOptions.getTemperature() != null) {
mergedAzureOptions.setTemperature(springAiOptions.getTemperature().doubleValue());
}

if (springAiOptions.getTopP() != null) {
mergedAzureOptions.setTopP(springAiOptions.getTopP().doubleValue());
}

if (springAiOptions.getFrequencyPenalty() != null) {
mergedAzureOptions.setFrequencyPenalty(springAiOptions.getFrequencyPenalty().doubleValue());
}

if (springAiOptions.getPresencePenalty() != null) {
mergedAzureOptions.setPresencePenalty(springAiOptions.getPresencePenalty().doubleValue());
}

if (springAiOptions.getN() != null) {
mergedAzureOptions.setN(springAiOptions.getN());
}

if (springAiOptions.getUser() != null) {
mergedAzureOptions.setUser(springAiOptions.getUser());
}

if (springAiOptions.getModel() != null) {
mergedAzureOptions.setModel(springAiOptions.getModel());
}

return mergedAzureOptions;
}

}
Loading

0 comments on commit 4b88b5c

Please sign in to comment.