Skip to content

Commit

Permalink
Change Azure embedding and chat options 'model' property to 'deployment-name'
Browse files Browse the repository at this point in the history

* Azure uses 'deployment-name' when provisioning models, and this is what
  needs to be passed in to the client, not the model name.
  This differs from the OpenAI API, which does not have a deployment-name.
  This change aligns the terminology used with Azure so that there is
  less confusion when setting configuration property values.

Fixes #10
  • Loading branch information
markpollack committed Mar 12, 2024
1 parent 777b79e commit 1e98f82
Show file tree
Hide file tree
Showing 13 changed files with 56 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public class AzureOpenAiChatClient
extends AbstractFunctionCallSupport<ChatRequestMessage, ChatCompletionsOptions, ChatCompletions>
implements ChatClient, StreamingChatClient {

private static final String DEFAULT_MODEL = "gpt-35-turbo";
private static final String DEFAULT_DEPLOYMENT_NAME = "gpt-35-turbo";

private static final Float DEFAULT_TEMPERATURE = 0.7f;

Expand All @@ -93,7 +93,10 @@ public class AzureOpenAiChatClient

public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient) {
this(microsoftOpenAiClient,
AzureOpenAiChatOptions.builder().withModel(DEFAULT_MODEL).withTemperature(DEFAULT_TEMPERATURE).build());
AzureOpenAiChatOptions.builder()
.withDeploymentName(DEFAULT_DEPLOYMENT_NAME)
.withTemperature(DEFAULT_TEMPERATURE)
.build());
}

public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) {
Expand Down Expand Up @@ -131,12 +134,7 @@ public ChatResponse call(Prompt prompt) {
options.setStream(false);

logger.trace("Azure ChatCompletionsOptions: {}", options);

ChatCompletions chatCompletions = this.callWithFunctionSupport(options);

// ChatCompletions chatCompletions =
// this.openAIClient.getChatCompletions(options.getModel(), options);

logger.trace("Azure ChatCompletions: {}", chatCompletions);

List<Generation> generations = chatCompletions.getChoices()
Expand Down Expand Up @@ -323,7 +321,7 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions azureOptions, AzureO
mergedAzureOptions.setUser(azureOptions.getUser() != null ? azureOptions.getUser() : springAiOptions.getUser());

mergedAzureOptions
.setModel(azureOptions.getModel() != null ? azureOptions.getModel() : springAiOptions.getModel());
.setModel(azureOptions.getModel() != null ? azureOptions.getModel() : springAiOptions.getDeploymentName());

return mergedAzureOptions;
}
Expand Down Expand Up @@ -376,8 +374,8 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions springAiOptions, Cha
mergedAzureOptions.setUser(springAiOptions.getUser());
}

if (springAiOptions.getModel() != null) {
mergedAzureOptions.setModel(springAiOptions.getModel());
if (springAiOptions.getDeploymentName() != null) {
mergedAzureOptions.setModel(springAiOptions.getDeploymentName());
}

return mergedAzureOptions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,11 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
private Double frequencyPenalty;

/**
* The model name to provide as part of this completions request. Not applicable to
* Azure OpenAI, where deployment information should be included in the Azure resource
* URI that's connected to.
* The deployment name as defined in Azure Open AI Studio when creating a deployment
* backed by an Azure OpenAI base model.
*/
@JsonProperty(value = "model")
private String model;
@JsonProperty(value = "deployment_name")
private String deploymentName;

/**
* OpenAI Tool Function Callbacks to register with the ChatClient. For Prompt Options
Expand Down Expand Up @@ -169,8 +168,8 @@ public Builder(AzureOpenAiChatOptions options) {
this.options = options;
}

public Builder withModel(String model) {
this.options.model = model;
public Builder withDeploymentName(String deploymentName) {
this.options.deploymentName = deploymentName;
return this;
}

Expand Down Expand Up @@ -298,12 +297,12 @@ public void setFrequencyPenalty(Double frequencyPenalty) {
this.frequencyPenalty = frequencyPenalty;
}

public String getModel() {
return this.model;
public String getDeploymentName() {
return this.deploymentName;
}

public void setModel(String model) {
this.model = model;
public void setDeploymentName(String deploymentName) {
this.deploymentName = deploymentName;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

public class AzureOpenAiEmbeddingClient extends AbstractEmbeddingClient {

Expand All @@ -53,7 +54,7 @@ public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient) {

public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient, MetadataMode metadataMode) {
this(azureOpenAiClient, metadataMode,
AzureOpenAiEmbeddingOptions.builder().withModel("text-embedding-ada-002").build());
AzureOpenAiEmbeddingOptions.builder().withDeploymentName("text-embedding-ada-002").build());
}

public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient, MetadataMode metadataMode,
Expand Down Expand Up @@ -93,7 +94,7 @@ public EmbeddingResponse call(EmbeddingRequest embeddingRequest) {
EmbeddingsOptions toEmbeddingOptions(EmbeddingRequest embeddingRequest) {
var azureOptions = new EmbeddingsOptions(embeddingRequest.getInstructions());
if (this.defaultOptions != null) {
azureOptions.setModel(this.defaultOptions.getModel());
azureOptions.setModel(this.defaultOptions.getDeploymentName());
azureOptions.setUser(this.defaultOptions.getUser());
}
if (embeddingRequest.getOptions() != null && !EmbeddingOptions.EMPTY.equals(embeddingRequest.getOptions())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@ public class AzureOpenAiEmbeddingOptions implements EmbeddingOptions {
private String user;

/**
* The model name to provide as part of this embeddings request. Not applicable to
* Azure OpenAI, where deployment information should be included in the Azure resource
* URI that's connected to.
* The deployment name as defined in Azure Open AI Studio when creating a deployment
* backed by an Azure OpenAI base model. If using Azure OpenAI library to communicate
* with OpenAI (not Azure OpenAI) then this value will be used as the name of the
* model. The json serialization of this field is 'model'.
*/
@JsonProperty(value = "model")
private String model;
private String deploymentName;

public static Builder builder() {
return new Builder();
Expand All @@ -55,8 +56,8 @@ public Builder withUser(String user) {
return this;
}

public Builder withModel(String model) {
this.options.setModel(model);
public Builder withDeploymentName(String model) {
this.options.setDeploymentName(model);
return this;
}

Expand All @@ -74,12 +75,12 @@ public void setUser(String user) {
this.user = user;
}

public String getModel() {
return this.model;
public String getDeploymentName() {
return this.deploymentName;
}

public void setModel(String model) {
this.model = model;
public void setDeploymentName(String deploymentName) {
this.deploymentName = deploymentName;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public void createRequestWithChatOptions() {

OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);
var client = new AzureOpenAiChatClient(mockClient,
AzureOpenAiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build());
AzureOpenAiChatOptions.builder().withDeploymentName("DEFAULT_MODEL").withTemperature(66.6f).build());

var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content"));

Expand All @@ -43,7 +43,7 @@ public void createRequestWithChatOptions() {
assertThat(requestOptions.getTemperature()).isEqualTo(66.6f);

requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content",
AzureOpenAiChatOptions.builder().withModel("PROMPT_MODEL").withTemperature(99.9f).build()));
AzureOpenAiChatOptions.builder().withDeploymentName("PROMPT_MODEL").withTemperature(99.9f).build()));

assertThat(requestOptions.getMessages()).hasSize(1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ public void createRequestWithChatOptions() {

OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);
var client = new AzureOpenAiEmbeddingClient(mockClient, MetadataMode.EMBED,
AzureOpenAiEmbeddingOptions.builder().withModel("DEFAULT_MODEL").withUser("USER_TEST").build());
AzureOpenAiEmbeddingOptions.builder()
.withDeploymentName("DEFAULT_MODEL")
.withUser("USER_TEST")
.build());

var requestOptions = client.toEmbeddingOptions(new EmbeddingRequest(List.of("Test message content"), null));

Expand All @@ -47,7 +50,10 @@ public void createRequestWithChatOptions() {
assertThat(requestOptions.getUser()).isEqualTo("USER_TEST");

requestOptions = client.toEmbeddingOptions(new EmbeddingRequest(List.of("Test message content"),
AzureOpenAiEmbeddingOptions.builder().withModel("PROMPT_MODEL").withUser("PROMPT_USER").build()));
AzureOpenAiEmbeddingOptions.builder()
.withDeploymentName("PROMPT_MODEL")
.withUser("PROMPT_USER")
.build()));

assertThat(requestOptions.getInput()).hasSize(1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ public OpenAIClient openAIClient() {
@Bean
public AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient openAIClient) {
return new AzureOpenAiChatClient(openAIClient,
AzureOpenAiChatOptions.builder().withModel("gpt-35-turbo").withMaxTokens(200).build());
AzureOpenAiChatOptions.builder().withDeploymentName("gpt-35-turbo").withMaxTokens(200).build());

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ void functionCallTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = AzureOpenAiChatOptions.builder()
.withModel("gpt-4-0125-preview")
.withDeploymentName("gpt-4-0125-preview")
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription("Get the current weather in a given location")
Expand Down Expand Up @@ -88,8 +88,10 @@ public OpenAIClient openAIClient() {
@Bean
public AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient openAIClient) {
return new AzureOpenAiChatClient(openAIClient,
AzureOpenAiChatOptions.builder().withModel("gpt-35-turbo-0613").withMaxTokens(500).build());

AzureOpenAiChatOptions.builder()
.withDeploymentName("gpt-4-0125-preview")
.withMaxTokens(500)
.build());
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ The prefix `spring.ai.azure.openai.chat` is the property prefix that configures
| Property | Description | Default

| spring.ai.azure.openai.chat.enabled | Enable Azure OpenAI chat client. | true
| spring.ai.azure.openai.chat.options.model | * In use with Azure, this actually refers to the "Deployment Name" of your model, which you can find at https://oai.azure.com/portal. It's important to note that within an Azure OpenAI deployment, the "Deployment Name" is distinct from the model itself. The confusion around these terms stems from the intention to make the Azure OpenAI client library compatible with the original OpenAI endpoint. The deployment structures offered by Azure OpenAI and Sam Altman's OpenAI differ significantly. To clarify this distinction, we plan to rename this attribute to `deployment-name` in future updates.
| spring.ai.azure.openai.chat.options.deployment-name | * In use with Azure, this refers to the "Deployment Name" of your model, which you can find at https://oai.azure.com/portal. It's important to note that within an Azure OpenAI deployment, the "Deployment Name" is distinct from the model itself. The confusion around these terms stems from the intention to make the Azure OpenAI client library compatible with the original OpenAI endpoint. The deployment structures offered by Azure OpenAI and Sam Altman's OpenAI differ significantly.
The deployment name to provide as part of this completions request.
| gpt-35-turbo
| spring.ai.azure.openai.chat.options.maxTokens | The maximum number of tokens to generate. | -
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ The prefix `spring.ai.azure.openai.embeddings` is the property prefix that confi

| spring.ai.azure.openai.embedding.enabled | Enable Azure OpenAI embedding client. | true
| spring.ai.azure.openai.embedding.metadata-mode | Document content extraction mode | EMBED
| spring.ai.azure.openai.embedding.options.model | This is the value of the 'Deployment Name' as presented in the Azure AI Portal | text-embedding-ada-002
| spring.ai.azure.openai.embedding.options.deployment-name | This is the value of the 'Deployment Name' as presented in the Azure AI Portal | text-embedding-ada-002
| spring.ai.azure.openai.embedding.options.user | An identifier for the caller or end user of the operation. This may be used for tracking or rate-limiting purposes. | -
|====

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public class AzureOpenAiChatProperties {

public static final String CONFIG_PREFIX = "spring.ai.azure.openai.chat";

public static final String DEFAULT_CHAT_MODEL = "gpt-35-turbo";
public static final String DEFAULT_DEPLOYMENT_NAME = "gpt-35-turbo";

private static final Double DEFAULT_TEMPERATURE = 0.7;

Expand All @@ -35,7 +35,7 @@ public class AzureOpenAiChatProperties {

@NestedConfigurationProperty
private AzureOpenAiChatOptions options = AzureOpenAiChatOptions.builder()
.withModel(DEFAULT_CHAT_MODEL)
.withDeploymentName(DEFAULT_DEPLOYMENT_NAME)
.withTemperature(DEFAULT_TEMPERATURE.floatValue())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class AzureOpenAiEmbeddingProperties {

@NestedConfigurationProperty
private AzureOpenAiEmbeddingOptions options = AzureOpenAiEmbeddingOptions.builder()
.withModel("text-embedding-ada-002")
.withDeploymentName("text-embedding-ada-002")
.build();

private MetadataMode metadataMode = MetadataMode.EMBED;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public void embeddingPropertiesTest() {
assertThat(connectionProperties.getApiKey()).isEqualTo("TEST_API_KEY");
assertThat(connectionProperties.getEndpoint()).isEqualTo("TEST_ENDPOINT");

assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
assertThat(chatProperties.getOptions().getDeploymentName()).isEqualTo("MODEL_XYZ");
});
}

Expand Down Expand Up @@ -80,9 +80,9 @@ public void chatPropertiesTest() {
assertThat(connectionProperties.getEndpoint()).isEqualTo("ENDPOINT");
assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY");

assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("text-embedding-ada-002");
assertThat(embeddingProperties.getOptions().getDeploymentName()).isEqualTo("text-embedding-ada-002");

assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
assertThat(chatProperties.getOptions().getDeploymentName()).isEqualTo("MODEL_XYZ");
assertThat(chatProperties.getOptions().getFrequencyPenalty()).isEqualTo(-1.5f);
assertThat(chatProperties.getOptions().getLogitBias().get("myTokenId")).isEqualTo(-5);
assertThat(chatProperties.getOptions().getMaxTokens()).isEqualTo(123);
Expand Down

0 comments on commit 1e98f82

Please sign in to comment.