Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions .github/workflows/pr-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,26 @@ jobs:
distribution: 'temurin'
cache: 'maven'

- name: Run tests
- name: Build spring-ai-model module
run: |
./mvnw --batch-mode test
./mvnw clean install -pl spring-ai-model -DskipTests

- name: Build spring-ai-openai module
run: |
./mvnw clean install -pl models/spring-ai-openai -DskipTests

- name: Run OpenAI Image Model integration tests
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
./mvnw -pl models/spring-ai-openai -Pintegration-tests -Dit.test=OpenAiImageModelIT,OpenAiImageModelNoOpApiKeysIT,OpenAiImageModelObservationIT,OpenAiImageModelStreamingIT,OpenAiImageModelWithImageResponseMetadataTests,OpenAiImageApiBuilderTests,OpenAiImageApiIT,OpenAiImageApiStreamingIT verify

- name: Run OpenAI Image AutoConfiguration tests
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
./mvnw -pl auto-configurations/models/spring-ai-autoconfigure-model-openai -Pintegration-tests -Dit.test=OpenAiImageAutoConfigurationIT verify

- name: Run OpenAI Image Properties tests
run: |
./mvnw test -pl auto-configurations/models/spring-ai-autoconfigure-model-openai -Dtest=OpenAiPropertiesTests#imageOptionsTest,OpenAiPropertiesTests#imageGptImageOptionsTest
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.springframework.ai.model.SpringAIModelProperties;
import org.springframework.ai.model.SpringAIModels;
import org.springframework.ai.openai.OpenAiImageModel;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
import org.springframework.beans.factory.ObjectProvider;
Expand All @@ -38,6 +37,7 @@
import org.springframework.retry.support.RetryTemplate;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;

import static org.springframework.ai.model.openai.autoconfigure.OpenAIAutoConfigurationUtil.resolveConnectionProperties;

Expand All @@ -53,31 +53,38 @@
*/
@AutoConfiguration(after = { RestClientAutoConfiguration.class, WebClientAutoConfiguration.class,
SpringAiRetryAutoConfiguration.class })
@ConditionalOnClass(OpenAiApi.class)
@ConditionalOnClass(OpenAiImageApi.class)
@ConditionalOnProperty(name = SpringAIModelProperties.IMAGE_MODEL, havingValue = SpringAIModels.OPENAI,
matchIfMissing = true)
@EnableConfigurationProperties({ OpenAiConnectionProperties.class, OpenAiImageProperties.class })
public class OpenAiImageAutoConfiguration {

@Bean
@ConditionalOnMissingBean
public OpenAiImageModel openAiImageModel(OpenAiConnectionProperties commonProperties,
public OpenAiImageApi openAiImageApi(OpenAiConnectionProperties commonProperties,
OpenAiImageProperties imageProperties, ObjectProvider<RestClient.Builder> restClientBuilderProvider,
RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler,
ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<ImageModelObservationConvention> observationConvention) {
ObjectProvider<WebClient.Builder> webClientBuilderProvider, ResponseErrorHandler responseErrorHandler) {

OpenAIAutoConfigurationUtil.ResolvedConnectionProperties resolved = resolveConnectionProperties(
commonProperties, imageProperties, "image");

var openAiImageApi = OpenAiImageApi.builder()
return OpenAiImageApi.builder()
.baseUrl(resolved.baseUrl())
.apiKey(new SimpleApiKey(resolved.apiKey()))
.headers(resolved.headers())
.imagesPath(imageProperties.getImagesPath())
.restClientBuilder(restClientBuilderProvider.getIfAvailable(RestClient::builder))
.webClientBuilder(webClientBuilderProvider.getIfAvailable(WebClient::builder))
.responseErrorHandler(responseErrorHandler)
.build();
}

@Bean
@ConditionalOnMissingBean
public OpenAiImageModel openAiImageModel(OpenAiImageApi openAiImageApi, OpenAiImageProperties imageProperties,
RetryTemplate retryTemplate, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<ImageModelObservationConvention> observationConvention) {

var imageModel = new OpenAiImageModel(openAiImageApi, imageProperties.getOptions(), retryTemplate,
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP));

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.model.openai.autoconfigure;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration;
import org.springframework.ai.openai.OpenAiImageModel;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.autoconfigure.web.reactive.function.client.WebClientAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Integration tests for {@link OpenAiImageAutoConfiguration}.
*
* @author Alexandros Pappas
* @since 1.1.0
*/
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class OpenAiImageAutoConfigurationIT {

private static final Log logger = LogFactory.getLog(OpenAiImageAutoConfigurationIT.class);

private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY"))
.withConfiguration(
AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, RestClientAutoConfiguration.class,
WebClientAutoConfiguration.class, ToolCallingAutoConfiguration.class));

@Test
void imageModelAutoConfigured() {
this.contextRunner.withConfiguration(AutoConfigurations.of(OpenAiImageAutoConfiguration.class)).run(context -> {
assertThat(context.getBeansOfType(OpenAiImageModel.class)).isNotEmpty();
assertThat(context.getBeansOfType(OpenAiImageApi.class)).isNotEmpty();
});
}

@Test
void generateImage() {
this.contextRunner
.withPropertyValues("spring.ai.openai.image.options.model=dall-e-2",
"spring.ai.openai.image.options.response-format=b64_json")
.withConfiguration(AutoConfigurations.of(OpenAiImageAutoConfiguration.class))
.run(context -> {
OpenAiImageModel imageModel = context.getBean(OpenAiImageModel.class);
ImagePrompt prompt = new ImagePrompt("A simple red circle");
ImageResponse response = imageModel.call(prompt);

assertThat(response).isNotNull();
assertThat(response.getResults()).hasSize(1);
assertThat(response.getResult().getOutput().getB64Json()).isNotEmpty();

logger.info("Generated image with base64 length: "
+ response.getResult().getOutput().getB64Json().length());
});
}

@Test
void imageModelDisabled() {
this.contextRunner.withPropertyValues("spring.ai.model.image=none")
.withConfiguration(AutoConfigurations.of(OpenAiImageAutoConfiguration.class))
.run(context -> assertThat(context.getBeansOfType(OpenAiImageModel.class)).isEmpty());
}

@Test
void imageModelExplicitlyEnabled() {
this.contextRunner.withPropertyValues("spring.ai.model.image=openai")
.withConfiguration(AutoConfigurations.of(OpenAiImageAutoConfiguration.class))
.run(context -> assertThat(context.getBeansOfType(OpenAiImageModel.class)).isNotEmpty());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,46 @@ public void imageOptionsTest() {
});
}

@Test
public void imageGptImageOptionsTest() {
this.contextRunner.withPropertyValues(
// @formatter:off
"spring.ai.openai.api-key=API_KEY",
"spring.ai.openai.base-url=TEST_BASE_URL",

"spring.ai.openai.image.options.model=gpt-image-1",
"spring.ai.openai.image.options.quality=high",
"spring.ai.openai.image.options.size=1024x1024",
"spring.ai.openai.image.options.background=transparent",
"spring.ai.openai.image.options.moderation=low",
"spring.ai.openai.image.options.output_compression=85",
"spring.ai.openai.image.options.output_format=png",
"spring.ai.openai.image.options.partial_images=2",
"spring.ai.openai.image.options.stream=true",
"spring.ai.openai.image.options.user=userXYZ"
)
// @formatter:on
.withConfiguration(AutoConfigurations.of(OpenAiImageAutoConfiguration.class))
.run(context -> {
var imageProperties = context.getBean(OpenAiImageProperties.class);
var connectionProperties = context.getBean(OpenAiConnectionProperties.class);

assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY");

assertThat(imageProperties.getOptions().getModel()).isEqualTo("gpt-image-1");
assertThat(imageProperties.getOptions().getQuality()).isEqualTo("high");
assertThat(imageProperties.getOptions().getSize()).isEqualTo("1024x1024");
assertThat(imageProperties.getOptions().getBackground()).isEqualTo("transparent");
assertThat(imageProperties.getOptions().getModeration()).isEqualTo("low");
assertThat(imageProperties.getOptions().getOutputCompression()).isEqualTo(85);
assertThat(imageProperties.getOptions().getOutputFormat()).isEqualTo("png");
assertThat(imageProperties.getOptions().getPartialImages()).isEqualTo(2);
assertThat(imageProperties.getOptions().getStream()).isTrue();
assertThat(imageProperties.getOptions().getUser()).isEqualTo("userXYZ");
});
}

@Test
void embeddingActivation() {

Expand Down
5 changes: 5 additions & 0 deletions models/spring-ai-openai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.projectreactor</groupId>
<artifactId>reactor-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

import org.springframework.ai.image.Image;
import org.springframework.ai.image.ImageGeneration;
Expand All @@ -29,6 +30,7 @@
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.image.ImageResponseMetadata;
import org.springframework.ai.image.StreamingImageModel;
import org.springframework.ai.image.observation.DefaultImageModelObservationConvention;
import org.springframework.ai.image.observation.ImageModelObservationContext;
import org.springframework.ai.image.observation.ImageModelObservationConvention;
Expand All @@ -43,16 +45,26 @@
import org.springframework.util.Assert;

/**
* OpenAiImageModel is a class that implements the ImageModel interface. It provides a
* client for calling the OpenAI image generation API.
* OpenAiImageModel is a class that implements the ImageModel and StreamingImageModel
* interfaces. It provides a client for calling the OpenAI image generation API with both
* synchronous and streaming capabilities.
*
* <p>
* Streaming image generation is supported for GPT-Image models (gpt-image-1,
* gpt-image-1-mini) and allows receiving partial images as they are generated. DALL-E
* models do not support streaming.
* </p>
*
* @author Mark Pollack
* @author Christian Tzolov
* @author Hyunjoon Choi
* @author Thomas Vitale
* @author Alexandros Pappas
* @since 0.8.0
* @see ImageModel
* @see StreamingImageModel
*/
public class OpenAiImageModel implements ImageModel {
public class OpenAiImageModel implements ImageModel, StreamingImageModel {

private static final Logger logger = LoggerFactory.getLogger(OpenAiImageModel.class);

Expand Down Expand Up @@ -205,6 +217,51 @@ private ImagePrompt buildRequestImagePrompt(ImagePrompt imagePrompt) {
return new ImagePrompt(imagePrompt.getInstructions(), requestOptions);
}

@Override
public Flux<ImageResponse> stream(ImagePrompt imagePrompt) {
// Before moving any further, build the final request ImagePrompt,
// merging runtime and default options.
ImagePrompt requestImagePrompt = buildRequestImagePrompt(imagePrompt);

OpenAiImageApi.OpenAiImageRequest imageRequest = createRequest(requestImagePrompt);

// Validate that streaming is only used with GPT-Image models
String model = imageRequest.model();
if (model != null && !model.startsWith("gpt-image-")) {
return Flux.error(new IllegalArgumentException(
"Streaming is only supported for GPT-Image models (gpt-image-1, gpt-image-1-mini). "
+ "Current model: " + model));
}

// Ensure stream is set to true
if (imageRequest.stream() == null || !imageRequest.stream()) {
imageRequest = new OpenAiImageApi.OpenAiImageRequest(imageRequest.prompt(), imageRequest.model(),
imageRequest.n(), imageRequest.quality(), imageRequest.responseFormat(), imageRequest.size(),
imageRequest.style(), imageRequest.user(), imageRequest.background(), imageRequest.moderation(),
imageRequest.outputCompression(), imageRequest.outputFormat(), imageRequest.partialImages(), true);
}

var observationContext = ImageModelObservationContext.builder()
.imagePrompt(imagePrompt)
.provider(OpenAiApiConstants.PROVIDER_NAME)
.build();

OpenAiImageApi.OpenAiImageRequest finalImageRequest = imageRequest;

// Stream the image generation events
Flux<OpenAiImageApi.OpenAiImageStreamEvent> eventStream = this.openAiImageApi.streamImage(finalImageRequest);

// Convert streaming events to ImageResponse
return eventStream.map(event -> {
Image image = new Image(null, event.b64Json());
OpenAiImageGenerationMetadata metadata = new OpenAiImageGenerationMetadata(null);
ImageGeneration generation = new ImageGeneration(image, metadata);
ImageResponseMetadata responseMetadata = event.createdAt() != null
? new ImageResponseMetadata(event.createdAt()) : new ImageResponseMetadata(null);
return new ImageResponse(List.of(generation), responseMetadata);
});
}

/**
* Use the provided convention for reporting observation data
* @param observationConvention The provided convention
Expand Down
Loading