Skip to content

Commit

Permalink
[ML] Stream Inference API (elastic#113158) (elastic#113423)
Browse files Browse the repository at this point in the history
Create `POST _inference/<task>/<id>/_stream` and
`POST _inference/<id>/_stream` API.

REST Streaming API will reuse InferenceAction.
For now, all services and task types will return an
HTTP 405 status code and error message.

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
  • Loading branch information
prwhelan and elasticmachine authored Sep 24, 2024
1 parent 9a21ca6 commit cb42fd4
Show file tree
Hide file tree
Showing 24 changed files with 798 additions and 121 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/113158.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 113158
summary: Adds a new Inference API for streaming responses back to the user.
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,21 @@ default boolean isInClusterService() {
* @return {@link TransportVersion} specifying the version
*/
TransportVersion getMinimalSupportedVersion();

/**
* The set of tasks where this service provider supports using the streaming API.
* @return set of supported task types. Defaults to empty.
*/
default Set<TaskType> supportedStreamingTasks() {
return Set.of();
}

/**
* Checks the task type against the set of supported streaming tasks returned by {@link #supportedStreamingTasks()}.
* @param taskType the task that supports streaming
* @return true if the taskType is supported
*/
default boolean canStream(TaskType taskType) {
return supportedStreamingTasks().contains(taskType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
private final Map<String, Object> taskSettings;
private final InputType inputType;
private final TimeValue inferenceTimeout;
private final boolean stream;

public Request(
TaskType taskType,
Expand All @@ -100,7 +101,8 @@ public Request(
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue inferenceTimeout
TimeValue inferenceTimeout,
boolean stream
) {
this.taskType = taskType;
this.inferenceEntityId = inferenceEntityId;
Expand All @@ -109,6 +111,7 @@ public Request(
this.taskSettings = taskSettings;
this.inputType = inputType;
this.inferenceTimeout = inferenceTimeout;
this.stream = stream;
}

public Request(StreamInput in) throws IOException {
Expand All @@ -134,6 +137,9 @@ public Request(StreamInput in) throws IOException {
this.query = null;
this.inferenceTimeout = DEFAULT_TIMEOUT;
}

// streaming is not supported yet for transport traffic
this.stream = false;
}

public TaskType getTaskType() {
Expand Down Expand Up @@ -165,7 +171,7 @@ public TimeValue getInferenceTimeout() {
}

public boolean isStreaming() {
return false;
return stream;
}

@Override
Expand Down Expand Up @@ -261,6 +267,7 @@ public static class Builder {
private Map<String, Object> taskSettings = Map.of();
private String query;
private TimeValue timeout = DEFAULT_TIMEOUT;
private boolean stream = false;

private Builder() {}

Expand Down Expand Up @@ -303,8 +310,13 @@ private Builder setInferenceTimeout(String inferenceTimeout) {
return setInferenceTimeout(TimeValue.parseTimeValue(inferenceTimeout, TIMEOUT.getPreferredName()));
}

public Builder setStream(boolean stream) {
this.stream = stream;
return this;
}

public Request build() {
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout);
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ protected InferenceAction.Request createTestInstance() {
randomList(1, 5, () -> randomAlphaOfLength(8)),
randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
randomFrom(InputType.values()),
TimeValue.timeValueMillis(randomLongBetween(1, 2048))
TimeValue.timeValueMillis(randomLongBetween(1, 2048)),
false
);
}

Expand Down Expand Up @@ -80,7 +81,8 @@ public void testValidation_TextEmbedding() {
List.of("input"),
null,
null,
null
null,
false
);
ActionRequestValidationException e = request.validate();
assertNull(e);
Expand All @@ -94,7 +96,8 @@ public void testValidation_Rerank() {
List.of("input"),
null,
null,
null
null,
false
);
ActionRequestValidationException e = request.validate();
assertNull(e);
Expand All @@ -108,7 +111,8 @@ public void testValidation_TextEmbedding_Null() {
null,
null,
null,
null
null,
false
);
ActionRequestValidationException inputNullError = inputNullRequest.validate();
assertNotNull(inputNullError);
Expand All @@ -123,7 +127,8 @@ public void testValidation_TextEmbedding_Empty() {
List.of(),
null,
null,
null
null,
false
);
ActionRequestValidationException inputEmptyError = inputEmptyRequest.validate();
assertNotNull(inputEmptyError);
Expand All @@ -138,7 +143,8 @@ public void testValidation_Rerank_Null() {
List.of("input"),
null,
null,
null
null,
false
);
ActionRequestValidationException queryNullError = queryNullRequest.validate();
assertNotNull(queryNullError);
Expand All @@ -153,7 +159,8 @@ public void testValidation_Rerank_Empty() {
List.of("input"),
null,
null,
null
null,
false
);
ActionRequestValidationException queryEmptyError = queryEmptyRequest.validate();
assertNotNull(queryEmptyError);
Expand Down Expand Up @@ -185,7 +192,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
instance.getInput(),
instance.getTaskSettings(),
instance.getInputType(),
instance.getInferenceTimeout()
instance.getInferenceTimeout(),
false
);
}
case 1 -> new InferenceAction.Request(
Expand All @@ -195,7 +203,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
instance.getInput(),
instance.getTaskSettings(),
instance.getInputType(),
instance.getInferenceTimeout()
instance.getInferenceTimeout(),
false
);
case 2 -> {
var changedInputs = new ArrayList<String>(instance.getInput());
Expand All @@ -207,7 +216,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
changedInputs,
instance.getTaskSettings(),
instance.getInputType(),
instance.getInferenceTimeout()
instance.getInferenceTimeout(),
false
);
}
case 3 -> {
Expand All @@ -225,7 +235,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
instance.getInput(),
taskSettings,
instance.getInputType(),
instance.getInferenceTimeout()
instance.getInferenceTimeout(),
false
);
}
case 4 -> {
Expand All @@ -237,7 +248,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
instance.getInput(),
instance.getTaskSettings(),
nextInputType,
instance.getInferenceTimeout()
instance.getInferenceTimeout(),
false
);
}
case 5 -> new InferenceAction.Request(
Expand All @@ -247,7 +259,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
instance.getInput(),
instance.getTaskSettings(),
instance.getInputType(),
instance.getInferenceTimeout()
instance.getInferenceTimeout(),
false
);
case 6 -> {
var newDuration = Duration.of(
Expand All @@ -262,7 +275,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
instance.getInput(),
instance.getTaskSettings(),
instance.getInputType(),
TimeValue.timeValueMillis(newDuration.plus(additionalTime).toMillis())
TimeValue.timeValueMillis(newDuration.plus(additionalTime).toMillis()),
false
);
}
default -> throw new UnsupportedOperationException();
Expand All @@ -279,7 +293,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque
instance.getInput().subList(0, 1),
instance.getTaskSettings(),
InputType.UNSPECIFIED,
InferenceAction.Request.DEFAULT_TIMEOUT
InferenceAction.Request.DEFAULT_TIMEOUT,
false
);
} else if (version.before(TransportVersions.V_8_13_0)) {
return new InferenceAction.Request(
Expand All @@ -289,7 +304,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque
instance.getInput(),
instance.getTaskSettings(),
InputType.UNSPECIFIED,
InferenceAction.Request.DEFAULT_TIMEOUT
InferenceAction.Request.DEFAULT_TIMEOUT,
false
);
} else if (version.before(TransportVersions.V_8_13_0)
&& (instance.getInputType() == InputType.UNSPECIFIED
Expand All @@ -302,7 +318,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque
instance.getInput(),
instance.getTaskSettings(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT
InferenceAction.Request.DEFAULT_TIMEOUT,
false
);
} else if (version.before(TransportVersions.V_8_13_0)
&& (instance.getInputType() == InputType.CLUSTERING || instance.getInputType() == InputType.CLASSIFICATION)) {
Expand All @@ -313,7 +330,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque
instance.getInput(),
instance.getTaskSettings(),
InputType.UNSPECIFIED,
InferenceAction.Request.DEFAULT_TIMEOUT
InferenceAction.Request.DEFAULT_TIMEOUT,
false
);
} else if (version.before(TransportVersions.V_8_14_0)) {
return new InferenceAction.Request(
Expand All @@ -323,7 +341,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque
instance.getInput(),
instance.getTaskSettings(),
instance.getInputType(),
InferenceAction.Request.DEFAULT_TIMEOUT
InferenceAction.Request.DEFAULT_TIMEOUT,
false
);
}

Expand All @@ -339,7 +358,8 @@ public void testWriteTo_WhenVersionIsOnAfterUnspecifiedAdded() throws IOExceptio
List.of(),
Map.of(),
InputType.UNSPECIFIED,
InferenceAction.Request.DEFAULT_TIMEOUT
InferenceAction.Request.DEFAULT_TIMEOUT,
false
),
TransportVersions.V_8_13_0
);
Expand All @@ -353,7 +373,8 @@ public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUn
List.of(),
Map.of(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT
InferenceAction.Request.DEFAULT_TIMEOUT,
false
);

InferenceAction.Request deserializedInstance = copyWriteable(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference;

import org.apache.http.HttpEntity;
import org.apache.http.HttpResponse;
import org.apache.http.entity.ContentType;
import org.apache.http.nio.ContentDecoder;
import org.apache.http.nio.IOControl;
import org.apache.http.nio.protocol.AbstractAsyncResponseConsumer;
import org.apache.http.nio.util.SimpleInputBuffer;
import org.apache.http.protocol.HttpContext;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.concurrent.atomic.AtomicReference;

class AsyncInferenceResponseConsumer extends AbstractAsyncResponseConsumer<HttpResponse> {
private final AtomicReference<HttpResponse> httpResponse = new AtomicReference<>();
private final Deque<ServerSentEvent> collector = new ArrayDeque<>();
private final ServerSentEventParser sseParser = new ServerSentEventParser();
private final SimpleInputBuffer inputBuffer = new SimpleInputBuffer(4096);

@Override
protected void onResponseReceived(HttpResponse httpResponse) {
this.httpResponse.set(httpResponse);
}

@Override
protected void onContentReceived(ContentDecoder contentDecoder, IOControl ioControl) throws IOException {
inputBuffer.consumeContent(contentDecoder);
}

@Override
protected void onEntityEnclosed(HttpEntity httpEntity, ContentType contentType) {
httpResponse.updateAndGet(response -> {
response.setEntity(httpEntity);
return response;
});
}

@Override
protected HttpResponse buildResult(HttpContext httpContext) {
var allBytes = new byte[inputBuffer.length()];
try {
inputBuffer.read(allBytes);
sseParser.parse(allBytes).forEach(collector::offer);
} catch (IOException e) {
failed(e);
}
return httpResponse.get();
}

@Override
protected void releaseResources() {}

Deque<ServerSentEvent> events() {
return collector;
}
}
Loading

0 comments on commit cb42fd4

Please sign in to comment.