Skip to content

Commit

Permalink
added feature to customize the OpenAIClientBuilder whilst retaining t…
Browse files Browse the repository at this point in the history
…he default auto-configuration

Signed-off-by: Manuel Andreo Garcia <manuel@magware.dev>
  • Loading branch information
magware-dev committed Jan 21, 2025
1 parent 1c41c6a commit dc3279b
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,18 @@ public class AzureOpenAiAutoConfiguration {

private static final String APPLICATION_ID = "spring-ai";

@Bean
@ConditionalOnMissingBean
public OpenAIClientBuilderCustomizer openAIClientBuilderCustomizer() {
return clientBuilder -> {
};
}

@Bean
@ConditionalOnMissingBean // ({ OpenAIClient.class, TokenCredential.class })
public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties connectionProperties) {
public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties connectionProperties,
ObjectProvider<OpenAIClientBuilderCustomizer> customizers) {

if (StringUtils.hasText(connectionProperties.getApiKey())) {

Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty");
Expand All @@ -77,17 +86,21 @@ public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties c
.map(entry -> new Header(entry.getKey(), entry.getValue()))
.collect(Collectors.toList());
ClientOptions clientOptions = new ClientOptions().setApplicationId(APPLICATION_ID).setHeaders(headers);
return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
.credential(new AzureKeyCredential(connectionProperties.getApiKey()))
.clientOptions(clientOptions);
applyOpenAIClientBuilderCustomizers(clientBuilder, customizers);
return clientBuilder;
}

// Connect to OpenAI (e.g. not the Azure OpenAI). The deploymentName property is
// used as OpenAI model name.
if (StringUtils.hasText(connectionProperties.getOpenAiApiKey())) {
return new OpenAIClientBuilder().endpoint("https://api.openai.com/v1")
OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint("https://api.openai.com/v1")
.credential(new KeyCredential(connectionProperties.getOpenAiApiKey()))
.clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID));
applyOpenAIClientBuilderCustomizers(clientBuilder, customizers);
return clientBuilder;
}

throw new IllegalArgumentException("Either API key or OpenAI API key must not be empty");
Expand All @@ -97,14 +110,16 @@ public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties c
@ConditionalOnMissingBean
@ConditionalOnBean(TokenCredential.class)
public OpenAIClientBuilder openAIClientWithTokenCredential(AzureOpenAiConnectionProperties connectionProperties,
TokenCredential tokenCredential) {
TokenCredential tokenCredential, ObjectProvider<OpenAIClientBuilderCustomizer> customizers) {

Assert.notNull(tokenCredential, "TokenCredential must not be null");
Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty");

return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
.credential(tokenCredential)
.clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID));
applyOpenAIClientBuilderCustomizers(clientBuilder, customizers);
return clientBuilder;
}

@Bean
Expand Down Expand Up @@ -169,4 +184,9 @@ public AzureOpenAiAudioTranscriptionModel azureOpenAiAudioTranscriptionModel(Ope
return new AzureOpenAiAudioTranscriptionModel(openAIClient.buildClient(), audioProperties.getOptions());
}

private void applyOpenAIClientBuilderCustomizers(OpenAIClientBuilder clientBuilder,
ObjectProvider<OpenAIClientBuilderCustomizer> customizers) {
customizers.orderedStream().forEach(customizer -> customizer.customize(clientBuilder));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package org.springframework.ai.autoconfigure.azure.openai;

import com.azure.ai.openai.OpenAIClientBuilder;

/**
* Callback interface that can be implemented by beans wishing to customize the
* {@link OpenAIClientBuilder} whilst retaining the default auto-configuration.
*/
@FunctionalInterface
public interface OpenAIClientBuilderCustomizer {

/**
* Customize the {@link OpenAIClientBuilder}.
* @param clientBuilder the {@link OpenAIClientBuilder} to customize
*/
void customize(OpenAIClientBuilder clientBuilder);

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

import com.azure.ai.openai.OpenAIClient;
Expand All @@ -33,6 +34,7 @@
import com.azure.core.http.HttpResponse;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.autoconfigure.azure.openai.OpenAIClientBuilderCustomizer;
import reactor.core.publisher.Flux;

import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
Expand Down Expand Up @@ -228,4 +230,20 @@ void audioTranscriptionActivation() {
.run(context -> assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty());
}

@Test
void openAIClientBuilderCustomizer() {
AtomicBoolean firstCustomizationApplied = new AtomicBoolean(false);
AtomicBoolean secondCustomizationApplied = new AtomicBoolean(false);
this.contextRunner
.withBean("first", OpenAIClientBuilderCustomizer.class,
() -> clientBuilder -> firstCustomizationApplied.set(true))
.withBean("second", OpenAIClientBuilderCustomizer.class,
() -> clientBuilder -> secondCustomizationApplied.set(true))
.run(context -> {
context.getBean(OpenAIClientBuilder.class);
assertThat(firstCustomizationApplied.get()).isTrue();
assertThat(secondCustomizationApplied.get()).isTrue();
});
}

}

0 comments on commit dc3279b

Please sign in to comment.