Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
5412548
Adding custom headers support openai text embeddings
jonathan-buttner Sep 17, 2025
579eb18
Update docs/changelog/134960.yaml
jonathan-buttner Sep 17, 2025
8277bb7
Adding headers to the service api result
jonathan-buttner Sep 17, 2025
a9eda90
[CI] Auto commit changes from spotless
Sep 17, 2025
f52efab
Merge branch 'main' of github.com:elastic/elasticsearch into ml-opena…
jonathan-buttner Sep 18, 2025
2894254
Merge branch 'ml-openai-headers-embedding' of github.com:jonathan-but…
jonathan-buttner Sep 18, 2025
fc3457b
Merge branch 'main' into ml-openai-headers-embedding
jonathan-buttner Sep 18, 2025
4890f2d
Addressing feedback
jonathan-buttner Sep 22, 2025
2428b29
Merge branch 'main' of github.com:elastic/elasticsearch into ml-opena…
jonathan-buttner Sep 22, 2025
6f9940a
Adding transport version change
jonathan-buttner Sep 22, 2025
a8087ac
[CI] Auto commit changes from spotless
Sep 22, 2025
d79dc7d
Cleaning up helpers
jonathan-buttner Sep 22, 2025
bd63b49
Merge branch 'ml-openai-headers-embedding' of github.com:jonathan-but…
jonathan-buttner Sep 22, 2025
41bea77
[CI] Auto commit changes from spotless
Sep 22, 2025
881d97a
Merge branch 'main' of github.com:elastic/elasticsearch into ml-opena…
jonathan-buttner Sep 22, 2025
a930d0c
Fixing transport version
jonathan-buttner Sep 22, 2025
a29f5c4
Merge branch 'ml-openai-headers-embedding' of github.com:jonathan-but…
jonathan-buttner Sep 22, 2025
df8d8fe
Merge branch 'main' into ml-openai-headers-embedding
jonathan-buttner Sep 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/134960.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 134960
summary: Adding custom headers support for OpenAI text embeddings
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
9169000
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.2.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
security_stats_endpoint,9168000
inference_api_openai_embeddings_headers,9169000
Original file line number Diff line number Diff line change
Expand Up @@ -485,9 +485,8 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
HEADERS,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)).setDescription(
"Custom headers to include in the requests to OpenAI."
)
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION))
.setDescription("Custom headers to include in the requests to OpenAI.")
.setLabel("Custom Headers")
.setRequired(false)
.setSensitive(false)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.openai;

import org.elasticsearch.common.ValidationException;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMapRemoveNulls;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues;
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.HEADERS;
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER;

/**
 * Shared base for OpenAI task settings made up of an optional {@code user} value and an
 * optional map of custom {@code headers}.
 * <p>
 * Subclasses supply {@link #create(String, Map)} so that {@link #updatedTaskSettings(Map)}
 * can hand back an instance of the concrete type.
 *
 * @param <T> the concrete task settings type returned by {@link #create(String, Map)}
 */
public abstract class OpenAiTaskSettings<T extends OpenAiTaskSettings<T>> implements TaskSettings {
    /** Shared instance returned whenever neither a user nor headers are present. */
    private static final Settings EMPTY_SETTINGS = new Settings(null, null);

    private final Settings taskSettings;

    /**
     * Value holder for the two optional task settings.
     *
     * @param user    the user value, or {@code null} when not set
     * @param headers the custom headers, or {@code null} when not set
     */
    public record Settings(@Nullable String user, @Nullable Map<String, String> headers) {}

    /**
     * Builds the settings by parsing the {@code user} and {@code headers} entries of the given map.
     *
     * @param map the raw task settings map; an empty map yields empty settings
     * @throws ValidationException if parsing recorded any validation error
     */
    public OpenAiTaskSettings(Map<String, Object> map) {
        this(fromMap(map));
    }

    /**
     * Builds the settings directly from already-typed values.
     *
     * @param user    the user value, may be {@code null}
     * @param headers the custom headers, may be {@code null}
     */
    public OpenAiTaskSettings(@Nullable String user, @Nullable Map<String, String> headers) {
        this(new Settings(user, headers));
    }

    /**
     * Builds the settings around an existing {@link Settings} holder.
     *
     * @param taskSettings the holder to wrap; must not be {@code null}
     */
    protected OpenAiTaskSettings(Settings taskSettings) {
        this.taskSettings = Objects.requireNonNull(taskSettings);
    }

    /**
     * Returns a {@link Settings} holder for the given values, reusing the shared empty
     * instance when both arguments are {@code null}.
     *
     * @param user          the user value, may be {@code null}
     * @param stringHeaders the custom headers, may be {@code null}
     * @return the shared empty holder, or a new holder wrapping the arguments
     */
    public static Settings createSettings(String user, Map<String, String> stringHeaders) {
        boolean nothingProvided = user == null && stringHeaders == null;
        return nothingProvided ? EMPTY_SETTINGS : new Settings(user, stringHeaders);
    }

    /**
     * Parses the {@code user} and {@code headers} entries out of a raw settings map.
     *
     * @param map the raw task settings map
     * @return the parsed holder; the shared empty holder when the map is empty
     * @throws ValidationException if any validation error was collected while parsing
     */
    private static Settings fromMap(Map<String, Object> map) {
        if (map.isEmpty()) {
            return EMPTY_SETTINGS;
        }

        // Errors are accumulated across both fields and reported together below.
        ValidationException errors = new ValidationException();

        String parsedUser = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, errors);
        Map<String, Object> rawHeaders = extractOptionalMapRemoveNulls(map, HEADERS, errors);
        var parsedHeaders = validateMapStringValues(rawHeaders, HEADERS, errors, false, null);

        if (errors.validationErrors().isEmpty() == false) {
            throw errors;
        }

        return createSettings(parsedUser, parsedHeaders);
    }

    /**
     * @return the configured user, or {@code null} when none is set
     */
    public String user() {
        return taskSettings.user();
    }

    /**
     * @return the configured custom headers, or {@code null} when none are set
     */
    public Map<String, String> headers() {
        return taskSettings.headers();
    }

    /**
     * @return {@code true} when there is no user and the headers are either {@code null} or empty
     */
    @Override
    public boolean isEmpty() {
        if (taskSettings.user() != null) {
            return false;
        }

        Map<String, String> currentHeaders = taskSettings.headers();
        return currentHeaders == null || currentHeaders.isEmpty();
    }

    /**
     * Writes the settings as an object, emitting the user field only when it is set and the
     * headers field only when the map is non-null and non-empty.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        String currentUser = taskSettings.user();
        Map<String, String> currentHeaders = taskSettings.headers();

        builder.startObject();

        if (currentUser != null) {
            builder.field(USER, currentUser);
        }

        boolean hasHeaders = currentHeaders != null && currentHeaders.isEmpty() == false;
        if (hasHeaders) {
            builder.field(HEADERS, currentHeaders);
        }

        builder.endObject();

        return builder;
    }

    /**
     * Two instances are equal when they have exactly the same runtime class and equal
     * underlying {@link Settings}.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) return true;
        if (o == null || o.getClass() != getClass()) return false;
        OpenAiTaskSettings<?> other = (OpenAiTaskSettings<?>) o;
        return Objects.equals(taskSettings, other.taskSettings);
    }

    @Override
    public int hashCode() {
        return Objects.hash(taskSettings);
    }

    /**
     * Produces new settings in which every value present in {@code newSettings} replaces the
     * current one, and every absent value falls back to the current one.
     *
     * @param newSettings the raw settings to overlay on top of the current values
     * @return a new instance of the concrete type holding the merged values
     * @throws ValidationException if parsing {@code newSettings} recorded any validation error
     */
    @Override
    public T updatedTaskSettings(Map<String, Object> newSettings) {
        // Parse a copy rather than the caller's map
        // NOTE(review): presumably because the extraction helpers remove entries — confirm.
        Settings overrides = fromMap(new HashMap<>(newSettings));

        String mergedUser = overrides.user() != null ? overrides.user() : taskSettings.user();
        Map<String, String> mergedHeaders = overrides.headers() != null ? overrides.headers() : taskSettings.headers();

        return create(mergedUser, mergedHeaders);
    }

    /**
     * Factory hook used to build an instance of the concrete settings type.
     *
     * @param user    the user value, may be {@code null}
     * @param headers the custom headers, may be {@code null}
     * @return a new instance of the concrete type
     */
    protected abstract T create(@Nullable String user, @Nullable Map<String, String> headers);

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, Map<
return model;
}

var requestTaskSettings = OpenAiChatCompletionRequestTaskSettings.fromMap(taskSettings);
return new OpenAiChatCompletionModel(model, OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings));
return new OpenAiChatCompletionModel(model, model.getTaskSettings().updatedTaskSettings(taskSettings));
}

public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, UnifiedCompletionRequest request) {
Expand Down Expand Up @@ -73,7 +72,7 @@ public OpenAiChatCompletionModel(
taskType,
service,
OpenAiChatCompletionServiceSettings.fromMap(serviceSettings, context),
OpenAiChatCompletionTaskSettings.fromMap(taskSettings),
new OpenAiChatCompletionTaskSettings(taskSettings),
DefaultSecretSettings.fromMap(secrets)
);
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,100 +9,44 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.openai.OpenAiTaskSettings;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.TransportVersions.INFERENCE_API_OPENAI_HEADERS;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMapRemoveNulls;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues;
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.HEADERS;
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER;

public class OpenAiChatCompletionTaskSettings implements TaskSettings {
public class OpenAiChatCompletionTaskSettings extends OpenAiTaskSettings<OpenAiChatCompletionTaskSettings> {

public static final String NAME = "openai_completion_task_settings";

public static OpenAiChatCompletionTaskSettings fromMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();

String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException);
var headers = extractOptionalMapRemoveNulls(map, HEADERS, validationException);
var stringHeaders = validateMapStringValues(headers, HEADERS, validationException, false, null);

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

return new OpenAiChatCompletionTaskSettings(user, stringHeaders);
public OpenAiChatCompletionTaskSettings(Map<String, Object> map) {
super(map);
}

private final String user;
@Nullable
private final Map<String, String> headers;

public OpenAiChatCompletionTaskSettings(@Nullable String user, @Nullable Map<String, String> headers) {
this.user = user;
this.headers = headers;
super(user, headers);
}

public OpenAiChatCompletionTaskSettings(StreamInput in) throws IOException {
this.user = in.readOptionalString();
super(readTaskSettingsFromStream(in));
}

private static Settings readTaskSettingsFromStream(StreamInput in) throws IOException {
var user = in.readOptionalString();

Map<String, String> headers;

if (in.getTransportVersion().onOrAfter(INFERENCE_API_OPENAI_HEADERS)) {
headers = in.readOptionalImmutableMap(StreamInput::readString, StreamInput::readString);
} else {
headers = null;
}
}

@Override
public boolean isEmpty() {
return user == null && (headers == null || headers.isEmpty());
}

public static OpenAiChatCompletionTaskSettings of(
OpenAiChatCompletionTaskSettings originalSettings,
OpenAiChatCompletionRequestTaskSettings requestSettings
) {
var userToUse = requestSettings.user() == null ? originalSettings.user : requestSettings.user();
var headersToUse = requestSettings.headers() == null ? originalSettings.headers : requestSettings.headers();
return new OpenAiChatCompletionTaskSettings(userToUse, headersToUse);
}

public String user() {
return user;
}

public Map<String, String> headers() {
return headers;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();

if (user != null) {
builder.field(USER, user);
}

if (headers != null && headers.isEmpty() == false) {
builder.field(HEADERS, headers);
}

builder.endObject();

return builder;
return createSettings(user, headers);
}

@Override
Expand All @@ -117,30 +61,14 @@ public TransportVersion getMinimalSupportedVersion() {

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(user);
out.writeOptionalString(user());
if (out.getTransportVersion().onOrAfter(INFERENCE_API_OPENAI_HEADERS)) {
out.writeOptionalMap(headers, StreamOutput::writeString, StreamOutput::writeString);
out.writeOptionalMap(headers(), StreamOutput::writeString, StreamOutput::writeString);
}
}

@Override
public boolean equals(Object object) {
if (this == object) return true;
if (object == null || getClass() != object.getClass()) return false;
OpenAiChatCompletionTaskSettings that = (OpenAiChatCompletionTaskSettings) object;
return Objects.equals(user, that.user) && Objects.equals(headers, that.headers);
}

@Override
public int hashCode() {
return Objects.hash(user, headers);
}

@Override
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
OpenAiChatCompletionRequestTaskSettings updatedSettings = OpenAiChatCompletionRequestTaskSettings.fromMap(
new HashMap<>(newSettings)
);
return of(this, updatedSettings);
protected OpenAiChatCompletionTaskSettings create(@Nullable String user, @Nullable Map<String, String> headers) {
return new OpenAiChatCompletionTaskSettings(user, headers);
}
}
Loading