diff --git a/docs/changelog/134960.yaml b/docs/changelog/134960.yaml new file mode 100644 index 0000000000000..6be94846d785e --- /dev/null +++ b/docs/changelog/134960.yaml @@ -0,0 +1,5 @@ +pr: 134960 +summary: Adding custom headers support openai text embeddings +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/resources/transport/definitions/referable/inference_api_openai_embeddings_headers.csv b/server/src/main/resources/transport/definitions/referable/inference_api_openai_embeddings_headers.csv new file mode 100644 index 0000000000000..cec6381023b6b --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/inference_api_openai_embeddings_headers.csv @@ -0,0 +1 @@ +9169000 diff --git a/server/src/main/resources/transport/upper_bounds/9.2.csv b/server/src/main/resources/transport/upper_bounds/9.2.csv index 6e7d51d3d3020..b1209b927d8a5 100644 --- a/server/src/main/resources/transport/upper_bounds/9.2.csv +++ b/server/src/main/resources/transport/upper_bounds/9.2.csv @@ -1 +1 @@ -security_stats_endpoint,9168000 +inference_api_openai_embeddings_headers,9169000 diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index fe0468455767b..d2b7dcc527aaa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -485,9 +485,8 @@ public static InferenceServiceConfiguration get() { configurationMap.put( HEADERS, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)).setDescription( - "Custom headers to include in the requests to OpenAI." - ) + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION)) + .setDescription("Custom headers to include in the requests to OpenAI.") .setLabel("Custom Headers") .setRequired(false) .setSensitive(false) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiTaskSettings.java new file mode 100644 index 0000000000000..c0fd0e50fa9e6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiTaskSettings.java @@ -0,0 +1,126 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openai; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMapRemoveNulls; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues; +import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.HEADERS; +import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER; + +public abstract class OpenAiTaskSettings> implements TaskSettings { + private static final Settings EMPTY_SETTINGS = new Settings(null, null); + + private final Settings taskSettings; + + public OpenAiTaskSettings(Map map) { + this(fromMap(map)); + } + + public record Settings(@Nullable String user, @Nullable Map headers) {} + + public static Settings createSettings(String user, Map stringHeaders) { + if (user == null && stringHeaders == null) { + return EMPTY_SETTINGS; + } else { + return new Settings(user, stringHeaders); + } + } + + private static Settings fromMap(Map map) { + if (map.isEmpty()) { + return EMPTY_SETTINGS; + } + + ValidationException validationException = new ValidationException(); + + String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException); + Map headers = extractOptionalMapRemoveNulls(map, HEADERS, validationException); + var stringHeaders = validateMapStringValues(headers, HEADERS, validationException, false, null); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return createSettings(user, stringHeaders); + } + + public OpenAiTaskSettings(@Nullable String user, @Nullable Map headers) { + this(new Settings(user, headers)); + } + + protected OpenAiTaskSettings(Settings taskSettings) { + this.taskSettings = Objects.requireNonNull(taskSettings); + } + + public String user() { + return taskSettings.user(); + } + + public Map headers() { + return taskSettings.headers(); + } + + @Override + public boolean isEmpty() { + return taskSettings.user() == null && (taskSettings.headers() == null || taskSettings.headers().isEmpty()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if (taskSettings.user() != null) { + builder.field(USER, taskSettings.user()); + } + + if (taskSettings.headers() != null && taskSettings.headers().isEmpty() == false) { + builder.field(HEADERS, taskSettings.headers()); + } + + builder.endObject(); + + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OpenAiTaskSettings that = (OpenAiTaskSettings) o; + return Objects.equals(taskSettings, that.taskSettings); + } + + @Override + public int hashCode() { + return Objects.hash(taskSettings); + } + + @Override + public T updatedTaskSettings(Map newSettings) { + Settings updatedSettings = fromMap(new HashMap<>(newSettings)); + + var userToUse = updatedSettings.user() == null ? taskSettings.user() : updatedSettings.user(); + var headersToUse = updatedSettings.headers() == null ? taskSettings.headers() : updatedSettings.headers(); + return create(userToUse, headersToUse); + } + + protected abstract T create(@Nullable String user, @Nullable Map headers); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java index 67457e2102ef3..08b582ce8d13e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java @@ -35,8 +35,7 @@ public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, Map< return model; } - var requestTaskSettings = OpenAiChatCompletionRequestTaskSettings.fromMap(taskSettings); - return new OpenAiChatCompletionModel(model, OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + return new OpenAiChatCompletionModel(model, model.getTaskSettings().updatedTaskSettings(taskSettings)); } public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, UnifiedCompletionRequest request) { @@ -73,7 +72,7 @@ public OpenAiChatCompletionModel( taskType, service, OpenAiChatCompletionServiceSettings.fromMap(serviceSettings, context), - OpenAiChatCompletionTaskSettings.fromMap(taskSettings), + new OpenAiChatCompletionTaskSettings(taskSettings), DefaultSecretSettings.fromMap(secrets) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java deleted file mode 100644 index 889c14ed6a3d9..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.openai.completion; - -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.ModelConfigurations; - -import java.util.Map; - -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMapRemoveNulls; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues; -import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.HEADERS; -import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER; - -/** - * This class handles extracting OpenAI task settings from a request. The difference between this class and - * {@link OpenAiChatCompletionTaskSettings} is that this class considers all fields as optional. It will not throw an error if a field - * is missing. This allows overriding persistent task settings. - * @param user a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse - * @param headers additional headers to include in the request to the OpenAI API - */ -public record OpenAiChatCompletionRequestTaskSettings(@Nullable String user, @Nullable Map headers) { - - public static final OpenAiChatCompletionRequestTaskSettings EMPTY_SETTINGS = new OpenAiChatCompletionRequestTaskSettings(null, null); - - /** - * Extracts the task settings from a map. All settings are considered optional and the absence of a setting - * does not throw an error. - * - * @param map the settings received from a request - * @return a {@link OpenAiChatCompletionRequestTaskSettings} - */ - public static OpenAiChatCompletionRequestTaskSettings fromMap(Map map) { - if (map.isEmpty()) { - return OpenAiChatCompletionRequestTaskSettings.EMPTY_SETTINGS; - } - - ValidationException validationException = new ValidationException(); - - String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException); - Map headers = extractOptionalMapRemoveNulls(map, HEADERS, validationException); - var stringHeaders = validateMapStringValues(headers, HEADERS, validationException, false, null); - - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - return new OpenAiChatCompletionRequestTaskSettings(user, stringHeaders); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettings.java index 4744773478a76..40013263f440d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettings.java @@ -9,100 +9,44 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.TaskSettings; -import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.openai.OpenAiTaskSettings; import java.io.IOException; -import java.util.HashMap; import java.util.Map; -import java.util.Objects; import static org.elasticsearch.TransportVersions.INFERENCE_API_OPENAI_HEADERS; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMapRemoveNulls; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues; -import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.HEADERS; -import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER; -public class OpenAiChatCompletionTaskSettings implements TaskSettings { +public class OpenAiChatCompletionTaskSettings extends OpenAiTaskSettings { public static final String NAME = "openai_completion_task_settings"; - public static OpenAiChatCompletionTaskSettings fromMap(Map map) { - ValidationException validationException = new ValidationException(); - - String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException); - var headers = extractOptionalMapRemoveNulls(map, HEADERS, validationException); - var stringHeaders = validateMapStringValues(headers, HEADERS, validationException, false, null); - - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - return new OpenAiChatCompletionTaskSettings(user, stringHeaders); + public OpenAiChatCompletionTaskSettings(Map map) { + super(map); } - private final String user; - @Nullable - private final Map headers; - public OpenAiChatCompletionTaskSettings(@Nullable String user, @Nullable Map headers) { - this.user = user; - this.headers = headers; + super(user, headers); } public OpenAiChatCompletionTaskSettings(StreamInput in) throws IOException { - this.user = in.readOptionalString(); + super(readTaskSettingsFromStream(in)); + } + + private static Settings readTaskSettingsFromStream(StreamInput in) throws IOException { + var user = in.readOptionalString(); + + Map headers; if (in.getTransportVersion().onOrAfter(INFERENCE_API_OPENAI_HEADERS)) { headers = in.readOptionalImmutableMap(StreamInput::readString, StreamInput::readString); } else { headers = null; } - } - - @Override - public boolean isEmpty() { - return user == null && (headers == null || headers.isEmpty()); - } - - public static OpenAiChatCompletionTaskSettings of( - OpenAiChatCompletionTaskSettings originalSettings, - OpenAiChatCompletionRequestTaskSettings requestSettings - ) { - var userToUse = requestSettings.user() == null ? originalSettings.user : requestSettings.user(); - var headersToUse = requestSettings.headers() == null ? originalSettings.headers : requestSettings.headers(); - return new OpenAiChatCompletionTaskSettings(userToUse, headersToUse); - } - - public String user() { - return user; - } - public Map headers() { - return headers; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - - if (user != null) { - builder.field(USER, user); - } - - if (headers != null && headers.isEmpty() == false) { - builder.field(HEADERS, headers); - } - - builder.endObject(); - - return builder; + return createSettings(user, headers); } @Override @@ -117,30 +61,14 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalString(user); + out.writeOptionalString(user()); if (out.getTransportVersion().onOrAfter(INFERENCE_API_OPENAI_HEADERS)) { - out.writeOptionalMap(headers, StreamOutput::writeString, StreamOutput::writeString); + out.writeOptionalMap(headers(), StreamOutput::writeString, StreamOutput::writeString); } } @Override - public boolean equals(Object object) { - if (this == object) return true; - if (object == null || getClass() != object.getClass()) return false; - OpenAiChatCompletionTaskSettings that = (OpenAiChatCompletionTaskSettings) object; - return Objects.equals(user, that.user) && Objects.equals(headers, that.headers); - } - - @Override - public int hashCode() { - return Objects.hash(user, headers); - } - - @Override - public TaskSettings updatedTaskSettings(Map newSettings) { - OpenAiChatCompletionRequestTaskSettings updatedSettings = OpenAiChatCompletionRequestTaskSettings.fromMap( - new HashMap<>(newSettings) - ); - return of(this, updatedSettings); + protected OpenAiChatCompletionTaskSettings create(@Nullable String user, @Nullable Map headers) { + return new OpenAiChatCompletionTaskSettings(user, headers); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java index 6d47334da43ae..b071890814cee 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java @@ -34,8 +34,7 @@ public static OpenAiEmbeddingsModel of(OpenAiEmbeddingsModel model, Mapsee the openai docs for more details */ -public class OpenAiEmbeddingsTaskSettings implements TaskSettings { +public class OpenAiEmbeddingsTaskSettings extends OpenAiTaskSettings { public static final String NAME = "openai_embeddings_task_settings"; - public static OpenAiEmbeddingsTaskSettings fromMap(Map map, ConfigurationParseContext context) { - ValidationException validationException = new ValidationException(); - - String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException); - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } + // default for testing + static final TransportVersion INFERENCE_API_OPENAI_EMBEDDINGS_HEADERS = TransportVersion.fromName( + "inference_api_openai_embeddings_headers" + ); - return new OpenAiEmbeddingsTaskSettings(user); + public OpenAiEmbeddingsTaskSettings(Map map) { + super(map); } - /** - * Creates a new {@link OpenAiEmbeddingsTaskSettings} object by overriding the values in originalSettings with the ones - * passed in via requestSettings if the fields are not null. - * @param originalSettings the original task settings from the inference entity configuration from storage - * @param requestSettings the task settings from the request - * @return a new {@link OpenAiEmbeddingsTaskSettings} - */ - public static OpenAiEmbeddingsTaskSettings of( - OpenAiEmbeddingsTaskSettings originalSettings, - OpenAiEmbeddingsRequestTaskSettings requestSettings - ) { - var userToUse = requestSettings.user() == null ? originalSettings.user : requestSettings.user(); - return new OpenAiEmbeddingsTaskSettings(userToUse); + public OpenAiEmbeddingsTaskSettings(@Nullable String user, @Nullable Map headers) { + super(user, headers); } - private final String user; - - public OpenAiEmbeddingsTaskSettings(@Nullable String user) { - this.user = user; + public OpenAiEmbeddingsTaskSettings(StreamInput in) throws IOException { + super(readTaskSettingsFromStream(in)); } - @Override - public boolean isEmpty() { - return user == null; - } + private static Settings readTaskSettingsFromStream(StreamInput in) throws IOException { + String user; - public OpenAiEmbeddingsTaskSettings(StreamInput in) throws IOException { if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) { - this.user = in.readOptionalString(); + user = in.readOptionalString(); } else { var discard = in.readString(); - this.user = in.readOptionalString(); + user = in.readOptionalString(); } - } - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - if (user != null) { - builder.field(USER, user); + Map headers; + + if (in.getTransportVersion().supports(INFERENCE_API_OPENAI_EMBEDDINGS_HEADERS)) { + headers = in.readOptionalImmutableMap(StreamInput::readString, StreamInput::readString); + } else { + headers = null; } - builder.endObject(); - return builder; - } - public String user() { - return user; + return createSettings(user, headers); } @Override @@ -109,29 +78,19 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) { - out.writeOptionalString(user); + out.writeOptionalString(user()); } else { out.writeString("m"); // write any string - out.writeOptionalString(user); + out.writeOptionalString(user()); } - } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - OpenAiEmbeddingsTaskSettings that = (OpenAiEmbeddingsTaskSettings) o; - return Objects.equals(user, that.user); - } - - @Override - public int hashCode() { - return Objects.hash(user); + if (out.getTransportVersion().supports(INFERENCE_API_OPENAI_EMBEDDINGS_HEADERS)) { + out.writeOptionalMap(headers(), StreamOutput::writeString, StreamOutput::writeString); + } } @Override - public TaskSettings updatedTaskSettings(Map newSettings) { - OpenAiEmbeddingsRequestTaskSettings requestSettings = OpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(newSettings)); - return of(this, requestSettings); + protected OpenAiEmbeddingsTaskSettings create(@Nullable String user, @Nullable Map headers) { + return new OpenAiEmbeddingsTaskSettings(user, headers); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiEmbeddingsRequest.java index 356db6630cbdc..1413938d3684b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiEmbeddingsRequest.java @@ -60,6 +60,12 @@ public HttpRequest createHttpRequest() { httpPost.setHeader(createOrgHeader(org)); } + if (model.getTaskSettings().headers() != null) { + for (var header : model.getTaskSettings().headers().entrySet()) { + httpPost.setHeader(header.getKey(), header.getValue()); + } + } + return new HttpRequest(httpPost, getInferenceEntityId()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index f5d1c71384d07..676dca2778141 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -87,7 +87,7 @@ import static org.elasticsearch.xpack.inference.services.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettingsTests.getServiceSettingsMap; -import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettingsTests.getTaskSettingsMap; +import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettingsTests.getOpenAiTaskSettingsMap; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; @@ -140,7 +140,7 @@ public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModel() throws IOExc TaskType.TEXT_EMBEDDING, getRequestConfigMap( getServiceSettingsMap("model", "url", "org"), - getTaskSettingsMap("user"), + getOpenAiTaskSettingsMap("user"), getSecretSettingsMap("secret") ), modelVerificationListener @@ -174,7 +174,7 @@ public void testParseRequestConfig_CreatesAnOpenAiChatCompletionsModel() throws TaskType.COMPLETION, getRequestConfigMap( getServiceSettingsMap(model, url, organization), - getTaskSettingsMap(user), + getOpenAiTaskSettingsMap(user), getSecretSettingsMap(secret) ), modelVerificationListener @@ -197,7 +197,7 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti TaskType.SPARSE_EMBEDDING, getRequestConfigMap( getServiceSettingsMap("model", "url", "org"), - getTaskSettingsMap("user"), + getOpenAiTaskSettingsMap("user"), getSecretSettingsMap("secret") ), modelVerificationListener @@ -209,7 +209,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I try (var service = createOpenAiService()) { var config = getRequestConfigMap( getServiceSettingsMap("model", "url", "org"), - getTaskSettingsMap("user"), + getOpenAiTaskSettingsMap("user"), getSecretSettingsMap("secret") ); config.put("extra_key", "value"); @@ -234,7 +234,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa var serviceSettings = getServiceSettingsMap("model", "url", "org"); serviceSettings.put("extra_key", "value"); - var config = getRequestConfigMap(serviceSettings, getTaskSettingsMap("user"), getSecretSettingsMap("secret")); + var config = getRequestConfigMap(serviceSettings, getOpenAiTaskSettingsMap("user"), getSecretSettingsMap("secret")); ActionListener modelVerificationListener = ActionListener.wrap((model) -> { fail("Expected exception, but got model: " + model); @@ -249,7 +249,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { try (var service = createOpenAiService()) { - var taskSettingsMap = getTaskSettingsMap("user"); + var taskSettingsMap = getOpenAiTaskSettingsMap("user"); taskSettingsMap.put("extra_key", "value"); var config = getRequestConfigMap(getServiceSettingsMap("model", "url", "org"), taskSettingsMap, getSecretSettingsMap("secret")); @@ -270,7 +270,11 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var secretSettingsMap = getSecretSettingsMap("secret"); secretSettingsMap.put("extra_key", "value"); - var config = getRequestConfigMap(getServiceSettingsMap("model", "url", "org"), getTaskSettingsMap("user"), secretSettingsMap); + var config = getRequestConfigMap( + getServiceSettingsMap("model", "url", "org"), + getOpenAiTaskSettingsMap("user"), + secretSettingsMap + ); ActionListener modelVerificationListener = ActionListener.wrap((model) -> { fail("Expected exception, but got model: " + model); @@ -299,7 +303,11 @@ public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlO service.parseRequestConfig( "id", TaskType.TEXT_EMBEDDING, - getRequestConfigMap(getServiceSettingsMap("model", null, null), getTaskSettingsMap(null), getSecretSettingsMap("secret")), + getRequestConfigMap( + getServiceSettingsMap("model", null, null), + getOpenAiTaskSettingsMap(null), + getSecretSettingsMap("secret") + ), modelVerificationListener ); } @@ -325,7 +333,7 @@ public void testParseRequestConfig_CreatesAnOpenAiChatCompletionsModelWithoutUse service.parseRequestConfig( "id", TaskType.COMPLETION, - getRequestConfigMap(getServiceSettingsMap(model, null, null), getTaskSettingsMap(null), getSecretSettingsMap(secret)), + getRequestConfigMap(getServiceSettingsMap(model, null, null), getOpenAiTaskSettingsMap(null), getSecretSettingsMap(secret)), modelVerificationListener ); } @@ -349,7 +357,7 @@ public void testParseRequestConfig_MovesModel() throws IOException { TaskType.TEXT_EMBEDDING, getRequestConfigMap( getServiceSettingsMap("model", "url", "org"), - getTaskSettingsMap("user"), + getOpenAiTaskSettingsMap("user"), getSecretSettingsMap("secret") ), modelVerificationListener @@ -376,7 +384,7 @@ public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSet TaskType.TEXT_EMBEDDING, getRequestConfigMap( getServiceSettingsMap("model", null, null), - getTaskSettingsMap(null), + getOpenAiTaskSettingsMap(null), createRandomChunkingSettingsMap(), getSecretSettingsMap("secret") ), @@ -402,7 +410,11 @@ public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSet service.parseRequestConfig( "id", TaskType.TEXT_EMBEDDING, - getRequestConfigMap(getServiceSettingsMap("model", null, null), getTaskSettingsMap(null), getSecretSettingsMap("secret")), + getRequestConfigMap( + getServiceSettingsMap("model", null, null), + getOpenAiTaskSettingsMap(null), + getSecretSettingsMap("secret") + ), modelVerificationListener ); } @@ -412,7 +424,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModel() try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap("model", "url", "org", 100, null, false), - getTaskSettingsMap("user"), + getOpenAiTaskSettingsMap("user"), getSecretSettingsMap("secret") ); @@ -438,7 +450,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap("model", "url", "org"), - getTaskSettingsMap("user"), + getOpenAiTaskSettingsMap("user"), getSecretSettingsMap("secret") ); @@ -463,7 +475,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWi try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap("model", null, null, null, null, true), - getTaskSettingsMap(null), + getOpenAiTaskSettingsMap(null), getSecretSettingsMap("secret") ); @@ -489,7 +501,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWh try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap("model", null, null, null, null, true), - getTaskSettingsMap(null), + getOpenAiTaskSettingsMap(null), createRandomChunkingSettingsMap(), getSecretSettingsMap("secret") ); @@ -517,7 +529,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWh try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap("model", null, null, null, null, true), - getTaskSettingsMap(null), + getOpenAiTaskSettingsMap(null), getSecretSettingsMap("secret") ); @@ -544,7 +556,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap("model", "url", "org", null, null, true), - getTaskSettingsMap("user"), + getOpenAiTaskSettingsMap("user"), getSecretSettingsMap("secret") ); persistedConfig.config().put("extra_key", "value"); @@ -575,7 +587,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists var persistedConfig = getPersistedConfigMap( getServiceSettingsMap("model", "url", "org", null, null, true), - getTaskSettingsMap("user"), + getOpenAiTaskSettingsMap("user"), secretSettingsMap ); @@ -601,7 +613,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap("model", "url", "org", null, null, true), - getTaskSettingsMap("user"), + getOpenAiTaskSettingsMap("user"), getSecretSettingsMap("secret") ); persistedConfig.secrets().put("extra_key", "value"); @@ -630,7 +642,11 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, null, true); serviceSettingsMap.put("extra_key", "value"); - var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMap("user"), getSecretSettingsMap("secret")); + var persistedConfig = getPersistedConfigMap( + serviceSettingsMap, + getOpenAiTaskSettingsMap("user"), + getSecretSettingsMap("secret") + ); var model = service.parsePersistedConfigWithSecrets( "id", @@ -652,7 +668,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { try (var service = createOpenAiService()) { - var taskSettingsMap = getTaskSettingsMap("user"); + var taskSettingsMap = getOpenAiTaskSettingsMap("user"); taskSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( @@ -683,7 +699,7 @@ public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModel() throws IOE try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap("model", "url", "org", null, null, true), - getTaskSettingsMap("user") + getOpenAiTaskSettingsMap("user") ); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); @@ -701,7 +717,7 @@ public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModel() throws IOE public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException { try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org"), getTaskSettingsMap("user")); + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org"), getOpenAiTaskSettingsMap("user")); var thrownException = expectThrows( ElasticsearchStatusException.class, @@ -719,7 +735,7 @@ public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUr try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap("model", null, null, null, null, true), - getTaskSettingsMap(null) + getOpenAiTaskSettingsMap(null) ); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); @@ -739,7 +755,7 @@ public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingS try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap("model", null, null, null, null, true), - getTaskSettingsMap(null), + getOpenAiTaskSettingsMap(null), createRandomChunkingSettingsMap() ); @@ -761,7 +777,7 @@ public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingS try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap("model", null, null, null, null, true), - getTaskSettingsMap(null) + getOpenAiTaskSettingsMap(null) ); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); @@ -782,7 +798,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap("model", "url", "org", null, null, true), - getTaskSettingsMap("user") + getOpenAiTaskSettingsMap("user") ); persistedConfig.config().put("extra_key", "value"); @@ -804,7 +820,7 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettin var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, null, true); serviceSettingsMap.put("extra_key", "value"); - var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMap("user")); + var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getOpenAiTaskSettingsMap("user")); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); @@ -821,7 +837,7 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettin public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { try (var service = createOpenAiService()) { - var taskSettingsMap = getTaskSettingsMap("user"); + var taskSettingsMap = getOpenAiTaskSettingsMap("user"); taskSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org", null, null, true), taskSettingsMap); @@ -1644,7 +1660,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": true, "type": "map", - "supported_task_types": ["completion", "chat_completion"] + "supported_task_types": ["text_embedding", "completion", "chat_completion"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiTaskSettingsTests.java new file mode 100644 index 0000000000000..9dd1c167b374c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiTaskSettingsTests.java @@ -0,0 +1,238 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.anEmptyMap; +import static org.hamcrest.Matchers.is; + +public abstract class OpenAiTaskSettingsTests> extends AbstractBWCWireSerializationTestCase { + + private enum HeadersDefinition { + NULL(null), + EMPTY(Map.of()), + DEFINED(Map.of(randomAlphaOfLength(15), randomAlphaOfLength(15))); + + private final Map headers; + + HeadersDefinition(@Nullable Map headers) { + this.headers = headers; + } + } + + public T createRandom() { + var user = randomBoolean() ? null : randomAlphaOfLength(15); + var headers = randomFrom(HeadersDefinition.values()).headers; + + return create(user, headers); + } + + public void testIsEmpty() { + var bothNull = create(null, null); + assertTrue(bothNull.isEmpty()); + + var nullUserEmptyHeaders = create(null, Map.of()); + assertTrue(nullUserEmptyHeaders.isEmpty()); + + var nullHeaders = create("user", null); + assertFalse(nullHeaders.isEmpty()); + + var nullUser = create(null, Map.of("K", "v")); + assertFalse(nullUser.isEmpty()); + + var neitherNull = create("user", Map.of("K", "v")); + assertFalse(neitherNull.isEmpty()); + } + + public void testUpdatedTaskSettings() { + var initialSettings = createRandom(); + var newSettings = createRandom(); + + Map newSettingsMap = new HashMap<>(); + if (newSettings.user() != null) { + newSettingsMap.put(OpenAiServiceFields.USER, newSettings.user()); + } + + if (newSettings.headers() != null) { + newSettingsMap.put(OpenAiServiceFields.HEADERS, newSettings.headers()); + } + + var updatedSettings = initialSettings.updatedTaskSettings(Collections.unmodifiableMap(newSettingsMap)); + + if (newSettings.user() == null) { + assertEquals(initialSettings.user(), updatedSettings.user()); + } else { + assertEquals(newSettings.user(), updatedSettings.user()); + } + + if (newSettings.headers() == null) { + assertEquals(initialSettings.headers(), updatedSettings.headers()); + } else { + assertEquals(newSettings.headers(), updatedSettings.headers()); + } + } + + public void testUpdatedTaskSettings_ApplyingEmptyHeaders() { + var user = "user"; + var initialSettingsNullHeaders = create(user, null); + Map newSettingsMap = Map.of(OpenAiServiceFields.HEADERS, Map.of()); + + var updatedSettings = initialSettingsNullHeaders.updatedTaskSettings(newSettingsMap); + assertThat(updatedSettings, is(create(user, Map.of()))); + + var initialSettingsDefinedHeaders = create(user, Map.of("key", "value")); + updatedSettings = initialSettingsDefinedHeaders.updatedTaskSettings(newSettingsMap); + assertThat(updatedSettings, is(create(user, Map.of()))); + } + + public void testUpdatedTaskSettings_KeepsOriginalValuesWithOverridesAreNull() { + var taskSettings = createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user"))); + + assertThat(taskSettings.updatedTaskSettings(Map.of()), is(taskSettings)); + } + + public void testUpdatedTaskSettings_UsesOverriddenSettings() { + var taskSettings = createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user"))); + + assertThat(taskSettings.updatedTaskSettings(Map.of(OpenAiServiceFields.USER, "user2")), is(create("user2", null))); + } + + public void testUpdatedTaskSettings_UsesOverriddenSettings_ForHeaders() { + var user = "user"; + var taskSettings = createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, user))); + + var headers = Map.of("key", "value"); + assertThat(taskSettings.updatedTaskSettings(Map.of(OpenAiServiceFields.HEADERS, headers)), is(create(user, headers))); + } + + public void testFromMap_WithUserAndHeaders() { + assertThat( + createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user", OpenAiServiceFields.HEADERS, Map.of("key", "value")))), + is(create("user", Map.of("key", "value"))) + ); + } + + public void testFromMap_UserIsEmptyString() { + var thrownException = expectThrows( + ValidationException.class, + () -> createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, ""))) + ); + + assertThat( + thrownException.getMessage(), + is(Strings.format("Validation Failed: 1: [task_settings] Invalid value empty string. [user] must be a non-empty string;")) + ); + } + + public void testFromMap_MissingUser_DoesNotThrowException() { + var taskSettings = createFromMap(new HashMap<>(Map.of())); + assertNull(taskSettings.user()); + } + + public void testFromMap_ReturnsEmptySettings_WhenTheMapDoesNotContainTheFields() { + var settings = createFromMap(new HashMap<>(Map.of("key", "value"))); + assertNull(settings.user()); + assertNull(settings.headers()); + } + + public void testFromMap_ParsesCorrectly_WhenUserIsNull() { + var settings = createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.HEADERS, new HashMap<>(Map.of("key", "value"))))); + + assertNull(settings.user()); + assertThat(settings.headers(), is(Map.of("key", "value"))); + } + + public void testFromMap_ParsesCorrectly_WhenHeadersIsNull() { + var settings = createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user"))); + + assertThat(settings.user(), is("user")); + assertNull(settings.headers()); + } + + public void testFromMap_ParsesCorrectly_WhenHeadersIsEmptyMap() { + var settings = createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user", OpenAiServiceFields.HEADERS, Map.of()))); + + assertThat(settings.user(), is("user")); + assertThat(settings.headers(), anEmptyMap()); + } + + public void testFromMap_ParsesCorrectly_WhenHeadersMapOfNulls() { + var headersMap = new HashMap(); + headersMap.put("key1", null); + headersMap.put("key2", null); + var settings = createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user", OpenAiServiceFields.HEADERS, headersMap))); + + assertThat(settings.user(), is("user")); + assertThat(settings.headers(), anEmptyMap()); + } + + public void testFromMap_ParsesCorrectly_WhenHeadersContainsAnInteger() { + var exception = expectThrows( + ValidationException.class, + () -> createFromMap( + new HashMap<>(Map.of(OpenAiServiceFields.USER, "user", OpenAiServiceFields.HEADERS, new HashMap<>(Map.of("key", 1)))) + ) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Map field [headers] has an entry that is not valid, " + + "[key => 1]. Value type of [1] is not one of [String].;" + ) + ); + } + + @Override + protected T mutateInstance(T instance) throws IOException { + var setNull = randomBoolean(); + var fieldToMutate = randomIntBetween(0, 1); + + return switch (fieldToMutate) { + case 0 -> create( + instance.user() == null ? randomAlphaOfLength(15) : (setNull ? null : instance.user() + "modified"), + instance.headers() + ); + case 1 -> { + if (instance.headers() == null) { + yield create(instance.user(), Map.of(randomAlphaOfLength(15), randomAlphaOfLength(15))); + } else if (setNull) { + yield create(instance.user(), null); + } else { + var instanceHeaders = new HashMap<>(instance.headers() == null ? Map.of() : instance.headers()); + instanceHeaders.put(randomAlphaOfLength(15), randomAlphaOfLength(15)); + yield create(instance.user(), instanceHeaders); + } + } + default -> throw new IllegalStateException("Unexpected value: " + fieldToMutate); + }; + } + + protected abstract T create(@Nullable String user, @Nullable Map headers); + + protected abstract T createFromMap(Map map); + + public static Map getOpenAiTaskSettingsMap(@Nullable String user) { + var map = new HashMap(); + + if (user != null) { + map.put(OpenAiServiceFields.USER, user); + } + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java index a8343c4b75af7..6b63212a308a1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java @@ -44,9 +44,9 @@ import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.openai.OpenAiTaskSettingsTests.getOpenAiTaskSettingsMap; import static org.elasticsearch.xpack.inference.services.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel; -import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionRequestTaskSettingsTests.getChatCompletionRequestTaskSettingsMap; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsRequestTaskSettingsTests.createRequestTaskSettingsMap; import static org.hamcrest.Matchers.equalTo; @@ -348,7 +348,7 @@ public void testCreate_OpenAiChatCompletionModel() throws IOException { var model = createCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); var actionCreator = new OpenAiActionCreator(sender, createWithEmptySettings(threadPool)); - var overriddenTaskSettings = getChatCompletionRequestTaskSettingsMap("overridden_user"); + var overriddenTaskSettings = getOpenAiTaskSettingsMap("overridden_user"); var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); @@ -412,7 +412,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOExceptio var model = createCompletionModel(getUrl(webServer), "org", "secret", "model", null); var actionCreator = new OpenAiActionCreator(sender, createWithEmptySettings(threadPool)); - var overriddenTaskSettings = getChatCompletionRequestTaskSettingsMap(null); + var overriddenTaskSettings = getOpenAiTaskSettingsMap(null); var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); @@ -475,7 +475,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IO var model = createCompletionModel(getUrl(webServer), null, "secret", "model", null); var actionCreator = new OpenAiActionCreator(sender, createWithEmptySettings(threadPool)); - var overriddenTaskSettings = getChatCompletionRequestTaskSettingsMap("overridden_user"); + var overriddenTaskSettings = getOpenAiTaskSettingsMap("overridden_user"); var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); @@ -544,7 +544,7 @@ public void testCreate_OpenAiChatCompletionModel_FailsFromInvalidResponseFormat( var model = createCompletionModel(getUrl(webServer), null, "secret", "model", null); var actionCreator = new OpenAiActionCreator(sender, createWithEmptySettings(threadPool)); - var overriddenTaskSettings = getChatCompletionRequestTaskSettingsMap("overridden_user"); + var overriddenTaskSettings = getOpenAiTaskSettingsMap("overridden_user"); var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java index f92335e93fefd..1385780e2d6ea 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java @@ -17,7 +17,7 @@ import java.util.List; import java.util.Map; -import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionRequestTaskSettingsTests.getChatCompletionRequestTaskSettingsMap; +import static org.elasticsearch.xpack.inference.services.openai.OpenAiTaskSettingsTests.getOpenAiTaskSettingsMap; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.sameInstance; @@ -25,7 +25,7 @@ public class OpenAiChatCompletionModelTests extends ESTestCase { public void testOverrideWith_OverridesUser() { var model = createCompletionModel("url", "org", "api_key", "model_name", null); - var requestTaskSettingsMap = getChatCompletionRequestTaskSettingsMap("user_override"); + var requestTaskSettingsMap = getOpenAiTaskSettingsMap("user_override"); var overriddenModel = OpenAiChatCompletionModel.of(model, requestTaskSettingsMap); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettingsTests.java deleted file mode 100644 index abd4f0f853353..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettingsTests.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.openai.completion; - -import org.elasticsearch.core.Nullable; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields; - -import java.util.HashMap; -import java.util.Map; - -import static org.hamcrest.Matchers.is; - -public class OpenAiChatCompletionRequestTaskSettingsTests extends ESTestCase { - - public void testFromMap_ReturnsEmptySettings_WhenTheMapIsEmpty() { - var settings = OpenAiChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of())); - assertNull(settings.user()); - } - - public void testFromMap_ReturnsEmptySettings_WhenTheMapDoesNotContainTheFields() { - var settings = OpenAiChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "value"))); - assertNull(settings.user()); - } - - public void testFromMap_ParsesCorrectly() { - var settings = OpenAiChatCompletionRequestTaskSettings.fromMap( - new HashMap<>(Map.of(OpenAiServiceFields.USER, "user", OpenAiServiceFields.HEADERS, new HashMap<>(Map.of("key", "value")))) - ); - - assertThat(settings.user(), is("user")); - assertThat(settings.headers(), is(Map.of("key", "value"))); - } - - public void testFromMap_ParsesCorrectly_WhenUserIsNull() { - var settings = OpenAiChatCompletionRequestTaskSettings.fromMap( - new HashMap<>(Map.of(OpenAiServiceFields.HEADERS, new HashMap<>(Map.of("key", "value")))) - ); - - assertNull(settings.user()); - assertThat(settings.headers(), is(Map.of("key", "value"))); - } - - public void testFromMap_ParsesCorrectly_WhenHeadersIsNull() { - var settings = OpenAiChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user"))); - - assertThat(settings.user(), is("user")); - assertNull(settings.headers()); - } - - public static Map getChatCompletionRequestTaskSettingsMap(@Nullable String user) { - var map = new HashMap(); - - if (user != null) { - map.put(OpenAiServiceFields.USER, user); - } - - return map; - } - -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettingsTests.java index 305392f299da8..a1410ec2a32fa 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettingsTests.java @@ -8,124 +8,15 @@ package org.elasticsearch.xpack.inference.services.openai.completion; import org.elasticsearch.TransportVersion; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; -import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.inference.services.openai.OpenAiTaskSettingsTests; -import java.io.IOException; -import java.util.Collections; -import java.util.HashMap; import java.util.Map; import static org.elasticsearch.TransportVersions.INFERENCE_API_OPENAI_HEADERS; -import static org.hamcrest.Matchers.is; -public class OpenAiChatCompletionTaskSettingsTests extends AbstractBWCWireSerializationTestCase { - - public static OpenAiChatCompletionTaskSettings createRandomWithUser() { - return new OpenAiChatCompletionTaskSettings( - randomBoolean() ? null : randomAlphaOfLength(15), - randomBoolean() ? null : Map.of(randomAlphaOfLength(15), randomAlphaOfLength(15)) - ); - } - - public void testIsEmpty() { - var randomSettings = new OpenAiChatCompletionTaskSettings( - randomBoolean() ? null : "username", - randomBoolean() ? null : Map.of("key", "value") - ); - var stringRep = Strings.toString(randomSettings); - - assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); - } - - public void testUpdatedTaskSettings() { - var initialSettings = createRandomWithUser(); - var newSettings = createRandomWithUser(); - - Map newSettingsMap = new HashMap<>(); - if (newSettings.user() != null) { - newSettingsMap.put(OpenAiServiceFields.USER, newSettings.user()); - } - - if (newSettings.headers() != null && newSettings.headers().isEmpty() == false) { - newSettingsMap.put(OpenAiServiceFields.HEADERS, newSettings.headers()); - } - - OpenAiChatCompletionTaskSettings updatedSettings = (OpenAiChatCompletionTaskSettings) initialSettings.updatedTaskSettings( - Collections.unmodifiableMap(newSettingsMap) - ); - - if (newSettings.user() == null) { - assertEquals(initialSettings.user(), updatedSettings.user()); - } else { - assertEquals(newSettings.user(), updatedSettings.user()); - } - - if (newSettings.headers() == null) { - assertEquals(initialSettings.headers(), updatedSettings.headers()); - } else { - assertEquals(newSettings.headers(), updatedSettings.headers()); - } - } - - public void testFromMap_WithUser() { - assertEquals( - new OpenAiChatCompletionTaskSettings("user", null), - OpenAiChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user"))) - ); - } - - public void testFromMap_UserIsEmptyString() { - var thrownException = expectThrows( - ValidationException.class, - () -> OpenAiChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, ""))) - ); - - assertThat( - thrownException.getMessage(), - is(Strings.format("Validation Failed: 1: [task_settings] Invalid value empty string. [user] must be a non-empty string;")) - ); - } - - public void testFromMap_MissingUser_DoesNotThrowException() { - var taskSettings = OpenAiChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of())); - assertNull(taskSettings.user()); - } - - public void testOf_KeepsOriginalValuesWithOverridesAreNull() { - var taskSettings = OpenAiChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user"))); - - var overriddenTaskSettings = OpenAiChatCompletionTaskSettings.of( - taskSettings, - OpenAiChatCompletionRequestTaskSettings.EMPTY_SETTINGS - ); - assertThat(overriddenTaskSettings, is(taskSettings)); - } - - public void testOf_UsesOverriddenSettings() { - var taskSettings = OpenAiChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user"))); - - var requestTaskSettings = OpenAiChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user2"))); - - var overriddenTaskSettings = OpenAiChatCompletionTaskSettings.of(taskSettings, requestTaskSettings); - assertThat(overriddenTaskSettings, is(new OpenAiChatCompletionTaskSettings("user2", null))); - } - - public void testOf_UsesOverriddenSettings_ForHeaders() { - var user = "user"; - var taskSettings = OpenAiChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, user))); - - var headers = Map.of("key", "value"); - var requestTaskSettings = OpenAiChatCompletionRequestTaskSettings.fromMap( - new HashMap<>(Map.of(OpenAiServiceFields.HEADERS, headers)) - ); - - var overriddenTaskSettings = OpenAiChatCompletionTaskSettings.of(taskSettings, requestTaskSettings); - assertThat(overriddenTaskSettings, is(new OpenAiChatCompletionTaskSettings(user, headers))); - } +public class OpenAiChatCompletionTaskSettingsTests extends OpenAiTaskSettingsTests { @Override protected Writeable.Reader instanceReader() { @@ -134,31 +25,7 @@ protected Writeable.Reader instanceReader() { @Override protected OpenAiChatCompletionTaskSettings createTestInstance() { - return createRandomWithUser(); - } - - @Override - protected OpenAiChatCompletionTaskSettings mutateInstance(OpenAiChatCompletionTaskSettings instance) throws IOException { - var setNull = randomBoolean(); - var fieldToMutate = randomIntBetween(0, 1); - return switch (fieldToMutate) { - case 0 -> new OpenAiChatCompletionTaskSettings( - instance.user() == null ? randomAlphaOfLength(15) : (setNull ? null : instance.user() + "modified"), - instance.headers() - ); - case 1 -> { - if (instance.headers() == null) { - yield new OpenAiChatCompletionTaskSettings(instance.user(), Map.of(randomAlphaOfLength(15), randomAlphaOfLength(15))); - } else if (setNull) { - yield new OpenAiChatCompletionTaskSettings(instance.user(), null); - } else { - var instanceHeaders = new HashMap<>(instance.headers() == null ? Map.of() : instance.headers()); - instanceHeaders.put(randomAlphaOfLength(15), randomAlphaOfLength(15)); - yield new OpenAiChatCompletionTaskSettings(instance.user(), instanceHeaders); - } - } - default -> throw new IllegalStateException("Unexpected value: " + fieldToMutate); - }; + return createRandom(); } @Override @@ -170,6 +37,16 @@ protected OpenAiChatCompletionTaskSettings mutateInstanceForVersion( return instance; } - return new OpenAiChatCompletionTaskSettings(instance.user(), null); + return create(instance.user(), null); + } + + @Override + protected OpenAiChatCompletionTaskSettings create(@Nullable String user, @Nullable Map headers) { + return new OpenAiChatCompletionTaskSettings(user, headers); + } + + @Override + protected OpenAiChatCompletionTaskSettings createFromMap(@Nullable Map map) { + return new OpenAiChatCompletionTaskSettings(map); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java index 0e9179792b92b..c57fb24a09d12 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java @@ -61,7 +61,7 @@ public static OpenAiEmbeddingsModel createModel( TaskType.TEXT_EMBEDDING, "service", new OpenAiEmbeddingsServiceSettings(modelName, url, org, SimilarityMeasure.DOT_PRODUCT, 1536, null, false, null), - new OpenAiEmbeddingsTaskSettings(user), + new OpenAiEmbeddingsTaskSettings(user, null), null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); @@ -80,7 +80,7 @@ public static OpenAiEmbeddingsModel createModel( TaskType.TEXT_EMBEDDING, "service", new OpenAiEmbeddingsServiceSettings(modelName, url, org, SimilarityMeasure.DOT_PRODUCT, 1536, null, false, null), - new OpenAiEmbeddingsTaskSettings(user), + new OpenAiEmbeddingsTaskSettings(user, null), chunkingSettings, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); @@ -98,7 +98,7 @@ public static OpenAiEmbeddingsModel createModel( TaskType.TEXT_EMBEDDING, "service", new OpenAiEmbeddingsServiceSettings(modelName, url, org, SimilarityMeasure.DOT_PRODUCT, 1536, null, false, null), - new OpenAiEmbeddingsTaskSettings(user), + new OpenAiEmbeddingsTaskSettings(user, null), null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); @@ -117,7 +117,7 @@ public static OpenAiEmbeddingsModel createModel( TaskType.TEXT_EMBEDDING, "service", new OpenAiEmbeddingsServiceSettings(modelName, url, org, SimilarityMeasure.DOT_PRODUCT, 1536, tokenLimit, false, null), - new OpenAiEmbeddingsTaskSettings(user), + new OpenAiEmbeddingsTaskSettings(user, null), null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); @@ -137,7 +137,7 @@ public static OpenAiEmbeddingsModel createModel( TaskType.TEXT_EMBEDDING, "service", new OpenAiEmbeddingsServiceSettings(modelName, url, org, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit, false, null), - new OpenAiEmbeddingsTaskSettings(user), + new OpenAiEmbeddingsTaskSettings(user, null), null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); @@ -159,7 +159,7 @@ public static OpenAiEmbeddingsModel createModel( TaskType.TEXT_EMBEDDING, "service", new OpenAiEmbeddingsServiceSettings(modelName, url, org, similarityMeasure, dimensions, tokenLimit, dimensionsSetByUser, null), - new OpenAiEmbeddingsTaskSettings(user), + new OpenAiEmbeddingsTaskSettings(user, null), null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java index a898620ce1d6c..426a8e2c6c318 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java @@ -7,107 +7,15 @@ package org.elasticsearch.xpack.inference.services.openai.embeddings; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.ValidationException; +import org.elasticsearch.TransportVersion; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; -import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields; -import org.hamcrest.MatcherAssert; +import org.elasticsearch.xpack.inference.services.openai.OpenAiTaskSettingsTests; -import java.io.IOException; -import java.util.Collections; -import java.util.HashMap; import java.util.Map; -import static org.hamcrest.Matchers.is; +import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings.INFERENCE_API_OPENAI_EMBEDDINGS_HEADERS; -public class OpenAiEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase { - - public static OpenAiEmbeddingsTaskSettings createRandomWithUser() { - return new OpenAiEmbeddingsTaskSettings(randomAlphaOfLength(15)); - } - - /** - * The created settings can have the user set to null. - */ - public static OpenAiEmbeddingsTaskSettings createRandom() { - var user = randomBoolean() ? randomAlphaOfLength(15) : null; - return new OpenAiEmbeddingsTaskSettings(user); - } - - public void testIsEmpty() { - var randomSettings = new OpenAiEmbeddingsTaskSettings(randomBoolean() ? null : "username"); - var stringRep = Strings.toString(randomSettings); - assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); - } - - public void testUpdatedTaskSettings() { - var initialSettings = createRandom(); - var newSettings = createRandom(); - Map newSettingsMap = new HashMap<>(); - if (newSettings.user() != null) { - newSettingsMap.put(OpenAiServiceFields.USER, newSettings.user()); - } - OpenAiEmbeddingsTaskSettings updatedSettings = (OpenAiEmbeddingsTaskSettings) initialSettings.updatedTaskSettings( - Collections.unmodifiableMap(newSettingsMap) - ); - if (newSettings.user() == null) { - assertEquals(initialSettings.user(), updatedSettings.user()); - } else { - assertEquals(newSettings.user(), updatedSettings.user()); - } - } - - public void testFromMap_WithUser() { - assertEquals( - new OpenAiEmbeddingsTaskSettings("user"), - OpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")), ConfigurationParseContext.REQUEST) - ); - } - - public void testFromMap_UserIsEmptyString() { - var thrownException = expectThrows( - ValidationException.class, - () -> OpenAiEmbeddingsTaskSettings.fromMap( - new HashMap<>(Map.of(OpenAiServiceFields.USER, "")), - ConfigurationParseContext.REQUEST - ) - ); - - MatcherAssert.assertThat( - thrownException.getMessage(), - is(Strings.format("Validation Failed: 1: [task_settings] Invalid value empty string. [user] must be a non-empty string;")) - ); - } - - public void testFromMap_MissingUser_DoesNotThrowException() { - var taskSettings = OpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of()), ConfigurationParseContext.PERSISTENT); - assertNull(taskSettings.user()); - } - - public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() { - var taskSettings = OpenAiEmbeddingsTaskSettings.fromMap( - new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")), - ConfigurationParseContext.PERSISTENT - ); - - var overriddenTaskSettings = OpenAiEmbeddingsTaskSettings.of(taskSettings, OpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS); - MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings)); - } - - public void testOverrideWith_UsesOverriddenSettings() { - var taskSettings = OpenAiEmbeddingsTaskSettings.fromMap( - new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")), - ConfigurationParseContext.PERSISTENT - ); - - var requestTaskSettings = OpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user2"))); - - var overriddenTaskSettings = OpenAiEmbeddingsTaskSettings.of(taskSettings, requestTaskSettings); - MatcherAssert.assertThat(overriddenTaskSettings, is(new OpenAiEmbeddingsTaskSettings("user2"))); - } +public class OpenAiEmbeddingsTaskSettingsTests extends OpenAiTaskSettingsTests { @Override protected Writeable.Reader instanceReader() { @@ -116,21 +24,25 @@ protected Writeable.Reader instanceReader() { @Override protected OpenAiEmbeddingsTaskSettings createTestInstance() { - return createRandomWithUser(); + return createRandom(); } @Override - protected OpenAiEmbeddingsTaskSettings mutateInstance(OpenAiEmbeddingsTaskSettings instance) throws IOException { - return randomValueOtherThan(instance, OpenAiEmbeddingsTaskSettingsTests::createRandomWithUser); + protected OpenAiEmbeddingsTaskSettings create(String user, Map headers) { + return new OpenAiEmbeddingsTaskSettings(user, headers); } - public static Map getTaskSettingsMap(@Nullable String user) { - var map = new HashMap(); + @Override + protected OpenAiEmbeddingsTaskSettings createFromMap(Map map) { + return new OpenAiEmbeddingsTaskSettings(map); + } - if (user != null) { - map.put(OpenAiServiceFields.USER, user); + @Override + protected OpenAiEmbeddingsTaskSettings mutateInstanceForVersion(OpenAiEmbeddingsTaskSettings instance, TransportVersion version) { + if (version.supports(INFERENCE_API_OPENAI_EMBEDDINGS_HEADERS)) { + return instance; } - return map; + return create(instance.user(), null); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiEmbeddingsRequestTests.java index 6cc909efaf12a..ee4df9e35d9f5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiEmbeddingsRequestTests.java @@ -9,16 +9,24 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import java.io.IOException; +import java.net.URI; import java.net.URISyntaxException; import java.util.List; +import java.util.Map; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.services.openai.OpenAiUtils.ORGANIZATION_HEADER; @@ -28,17 +36,37 @@ import static org.hamcrest.Matchers.is; public class OpenAiEmbeddingsRequestTests extends ESTestCase { - public void testCreateRequest_WithUrlOrganizationUserDefined() throws URISyntaxException, IOException { - var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user"); + public void testCreateRequest_WithUrlOrganizationUser_AndCustomHeadersDefined() throws IOException { + + var headerKey = "key"; + var headerValue = "value"; + + var model = new OpenAiEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "service", + new OpenAiEmbeddingsServiceSettings("model", URI.create("www.elastic.co"), "org", null, null, null, false, null), + new OpenAiEmbeddingsTaskSettings("user", Map.of(headerKey, headerValue)), + null, + new DefaultSecretSettings(new SecureString("secret".toCharArray())) + ); + + var request = new OpenAiEmbeddingsRequest( + TruncatorTests.createTruncator(), + new Truncator.TruncationResult(List.of("abc"), new boolean[] { false }), + model + ); + var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); - assertThat(httpPost.getURI().toString(), is("www.google.com")); + assertThat(httpPost.getURI().toString(), is("www.elastic.co")); assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org")); + assertThat(httpPost.getLastHeader(headerKey).getValue(), is(headerValue)); var requestMap = entityAsMap(httpPost.getEntity().getContent()); assertThat(requestMap, aMapWithSize(3));