Skip to content

Commit

Permalink
Streamline ImageOptions
Browse files Browse the repository at this point in the history
* Add style to option abstraction
* Use abstraction in Observations directly instead of dedicated implementation
* Clean-up the merge of runtime and default image options in OpenAI and Stability AI

Related to spring-projects#1148

Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
  • Loading branch information
ThomasVitale authored and tzolov committed Aug 7, 2024
1 parent 80007d4 commit 17ba1fc
Show file tree
Hide file tree
Showing 17 changed files with 193 additions and 381 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
* The configuration information for a image generation request.
*
* @author Benoit Moussaud
* @author Thomas Vitale
* @since 1.0.0 M1
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
Expand Down Expand Up @@ -88,10 +89,15 @@ public class AzureOpenAiImageOptions implements ImageOptions {
@JsonProperty("user")
private String user;

@Override
public Integer getN() {
return n;
}

public void setN(Integer n) {
this.n = n;
}

@Override
public String getModel() {
return model;
Expand All @@ -101,10 +107,7 @@ public void setModel(String model) {
this.model = model;
}

public void setN(Integer n) {
this.n = n;
}

@Override
public Integer getWidth() {
return width;
}
Expand All @@ -114,6 +117,7 @@ public void setWidth(Integer width) {
this.size = this.width + "x" + this.height;
}

@Override
public Integer getHeight() {
return height;
}
Expand All @@ -123,6 +127,7 @@ public void setHeight(Integer height) {
this.size = this.width + "x" + this.height;
}

@Override
public String getResponseFormat() {
return responseFormat;
}
Expand Down Expand Up @@ -158,6 +163,7 @@ public void setQuality(String quality) {
this.quality = quality;
}

@Override
public String getStyle() {
return style;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.springframework.ai.image.observation.ImageModelObservationConvention;
import org.springframework.ai.image.observation.ImageModelObservationContext;
import org.springframework.ai.image.observation.ImageModelObservationDocumentation;
import org.springframework.ai.image.observation.ImageModelRequestOptions;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.observation.AiOperationMetadata;
import org.springframework.ai.observation.conventions.AiOperationType;
Expand All @@ -40,7 +39,6 @@
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.util.List;

Expand Down Expand Up @@ -128,12 +126,13 @@ public OpenAiImageModel(OpenAiImageApi openAiImageApi, OpenAiImageOptions option

@Override
public ImageResponse call(ImagePrompt imagePrompt) {
OpenAiImageApi.OpenAiImageRequest imageRequest = createRequest(imagePrompt);
OpenAiImageOptions requestImageOptions = mergeOptions(imagePrompt.getOptions(), this.defaultOptions);
OpenAiImageApi.OpenAiImageRequest imageRequest = createRequest(imagePrompt, requestImageOptions);

var observationContext = ImageModelObservationContext.builder()
.imagePrompt(imagePrompt)
.operationMetadata(buildOperationMetadata())
.requestOptions(buildRequestOptions(imageRequest))
.requestOptions(requestImageOptions)
.build();

return ImageModelObservationDocumentation.IMAGE_MODEL_OPERATION
Expand All @@ -151,23 +150,14 @@ public ImageResponse call(ImagePrompt imagePrompt) {
});
}

private OpenAiImageApi.OpenAiImageRequest createRequest(ImagePrompt imagePrompt) {
private OpenAiImageApi.OpenAiImageRequest createRequest(ImagePrompt imagePrompt,
OpenAiImageOptions requestImageOptions) {
String instructions = imagePrompt.getInstructions().get(0).getText();

OpenAiImageApi.OpenAiImageRequest imageRequest = new OpenAiImageApi.OpenAiImageRequest(instructions,
OpenAiImageApi.DEFAULT_IMAGE_MODEL);

if (this.defaultOptions != null) {
imageRequest = ModelOptionsUtils.merge(this.defaultOptions, imageRequest,
OpenAiImageApi.OpenAiImageRequest.class);
}

if (imagePrompt.getOptions() != null) {
imageRequest = ModelOptionsUtils.merge(toOpenAiImageOptions(imagePrompt.getOptions()), imageRequest,
OpenAiImageApi.OpenAiImageRequest.class);
}

return imageRequest;
return ModelOptionsUtils.merge(requestImageOptions, imageRequest, OpenAiImageApi.OpenAiImageRequest.class);
}

private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity,
Expand All @@ -188,44 +178,27 @@ private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageR
}

/**
* Convert the {@link ImageOptions} into {@link OpenAiImageOptions}.
* @param runtimeImageOptions the image options to use.
* @return the converted {@link OpenAiImageOptions}.
* Merge runtime and default {@link ImageOptions} to compute the final options to use
* in the request.
*/
private OpenAiImageOptions toOpenAiImageOptions(ImageOptions runtimeImageOptions) {
OpenAiImageOptions.Builder openAiImageOptionsBuilder = OpenAiImageOptions.builder();
if (runtimeImageOptions != null) {
private OpenAiImageOptions mergeOptions(ImageOptions runtimeOptions, OpenAiImageOptions defaultOptions) {
if (runtimeOptions == null) {
return defaultOptions;
}

return OpenAiImageOptions.builder()
// Handle portable image options
if (runtimeImageOptions.getN() != null) {
openAiImageOptionsBuilder.withN(runtimeImageOptions.getN());
}
if (runtimeImageOptions.getModel() != null) {
openAiImageOptionsBuilder.withModel(runtimeImageOptions.getModel());
}
if (runtimeImageOptions.getResponseFormat() != null) {
openAiImageOptionsBuilder.withResponseFormat(runtimeImageOptions.getResponseFormat());
}
if (runtimeImageOptions.getWidth() != null) {
openAiImageOptionsBuilder.withWidth(runtimeImageOptions.getWidth());
}
if (runtimeImageOptions.getHeight() != null) {
openAiImageOptionsBuilder.withHeight(runtimeImageOptions.getHeight());
}
.withModel(ModelOptionsUtils.mergeOption(runtimeOptions.getModel(), defaultOptions.getModel()))
.withN(ModelOptionsUtils.mergeOption(runtimeOptions.getN(), defaultOptions.getN()))
.withResponseFormat(ModelOptionsUtils.mergeOption(runtimeOptions.getResponseFormat(),
defaultOptions.getResponseFormat()))
.withWidth(ModelOptionsUtils.mergeOption(runtimeOptions.getWidth(), defaultOptions.getWidth()))
.withHeight(ModelOptionsUtils.mergeOption(runtimeOptions.getHeight(), defaultOptions.getHeight()))
.withStyle(ModelOptionsUtils.mergeOption(runtimeOptions.getStyle(), defaultOptions.getStyle()))
// Handle OpenAI specific image options
if (runtimeImageOptions instanceof OpenAiImageOptions) {
OpenAiImageOptions runtimeOpenAiImageOptions = (OpenAiImageOptions) runtimeImageOptions;
if (runtimeOpenAiImageOptions.getQuality() != null) {
openAiImageOptionsBuilder.withQuality(runtimeOpenAiImageOptions.getQuality());
}
if (runtimeOpenAiImageOptions.getStyle() != null) {
openAiImageOptionsBuilder.withStyle(runtimeOpenAiImageOptions.getStyle());
}
if (runtimeOpenAiImageOptions.getUser() != null) {
openAiImageOptionsBuilder.withUser(runtimeOpenAiImageOptions.getUser());
}
}
}
return openAiImageOptionsBuilder.build();
.withQuality(defaultOptions.getQuality())
.withUser(defaultOptions.getUser())
.build();
}

private AiOperationMetadata buildOperationMetadata() {
Expand All @@ -235,17 +208,6 @@ private AiOperationMetadata buildOperationMetadata() {
.build();
}

private ImageModelRequestOptions buildRequestOptions(OpenAiImageApi.OpenAiImageRequest request) {
return ImageModelRequestOptions.builder()
.model(StringUtils.hasText(request.model()) ? request.model() : "unknown")
.n(request.n())
.width(request.size() != null ? Integer.parseInt(request.size().split("x")[0]) : null)
.height(request.size() != null ? Integer.parseInt(request.size().split("x")[1]) : null)
.responseFormat(request.responseFormat())
.style(request.style())
.build();
}

/**
* Use the provided convention for reporting observation data
* @param observationConvention The provided convention
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ public void setHeight(Integer height) {
this.size = this.width + "x" + this.height;
}

@Override
public String getStyle() {
return this.style;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,16 +164,17 @@ public Integer getHeight() {
return this.height;
}

@Override
public String getResponseFormat() {
return null;
}

public void setHeight(Integer height) {
this.height = height;
this.size = this.width + "x" + this.height;
}

@Override
public String getResponseFormat() {
return null;
}

@Override
public String getStyle() {
return this.style;
}
Expand All @@ -190,18 +191,17 @@ public void setUser(String user) {
this.user = user;
}

public void setSize(String size) {
this.size = size;
}

public String getSize() {

if (this.size != null) {
return this.size;
}
return (this.width != null && this.height != null) ? this.width + "x" + this.height : null;
}

public void setSize(String size) {
this.size = size;
}

@Override
public boolean equals(Object o) {
if (this == o)
Expand Down
Loading

0 comments on commit 17ba1fc

Please sign in to comment.