Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "[ML] Add queue_capacity setting to start deployment API (#79369)" #79374

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
public static final ParseField WAIT_FOR = new ParseField("wait_for");
public static final ParseField INFERENCE_THREADS = TaskParams.INFERENCE_THREADS;
public static final ParseField MODEL_THREADS = TaskParams.MODEL_THREADS;
public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;

public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);

Expand All @@ -70,7 +69,6 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
PARSER.declareString((request, waitFor) -> request.setWaitForState(AllocationStatus.State.fromString(waitFor)), WAIT_FOR);
PARSER.declareInt(Request::setInferenceThreads, INFERENCE_THREADS);
PARSER.declareInt(Request::setModelThreads, MODEL_THREADS);
PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
}

public static Request parseRequest(String modelId, XContentParser parser) {
Expand All @@ -89,7 +87,6 @@ public static Request parseRequest(String modelId, XContentParser parser) {
private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
private int modelThreads = 1;
private int inferenceThreads = 1;
private int queueCapacity = 1024;

private Request() {}

Expand All @@ -104,7 +101,6 @@ public Request(StreamInput in) throws IOException {
waitForState = in.readEnum(AllocationStatus.State.class);
modelThreads = in.readVInt();
inferenceThreads = in.readVInt();
queueCapacity = in.readVInt();
}

public final void setModelId(String modelId) {
Expand Down Expand Up @@ -148,14 +144,6 @@ public void setInferenceThreads(int inferenceThreads) {
this.inferenceThreads = inferenceThreads;
}

public int getQueueCapacity() {
return queueCapacity;
}

public void setQueueCapacity(int queueCapacity) {
this.queueCapacity = queueCapacity;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
Expand All @@ -164,7 +152,6 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(waitForState);
out.writeVInt(modelThreads);
out.writeVInt(inferenceThreads);
out.writeVInt(queueCapacity);
}

@Override
Expand All @@ -175,7 +162,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(WAIT_FOR.getPreferredName(), waitForState);
builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
builder.endObject();
return builder;
}
Expand All @@ -197,15 +183,12 @@ public ActionRequestValidationException validate() {
if (inferenceThreads < 1) {
validationException.addValidationError("[" + INFERENCE_THREADS + "] must be a positive integer");
}
if (queueCapacity < 1 || queueCapacity > 10000) {
validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be in [1, 10000]");
}
return validationException.validationErrors().isEmpty() ? null : validationException;
}

@Override
public int hashCode() {
return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads, queueCapacity);
return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads);
}

@Override
Expand All @@ -221,8 +204,7 @@ public boolean equals(Object obj) {
&& Objects.equals(timeout, other.timeout)
&& Objects.equals(waitForState, other.waitForState)
&& modelThreads == other.modelThreads
&& inferenceThreads == other.inferenceThreads
&& queueCapacity == other.queueCapacity;
&& inferenceThreads == other.inferenceThreads;
}

@Override
Expand All @@ -244,20 +226,16 @@ public static boolean mayAllocateToNode(DiscoveryNode node) {
private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
public static final ParseField MODEL_THREADS = new ParseField("model_threads");
public static final ParseField INFERENCE_THREADS = new ParseField("inference_threads");
public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");

private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
"trained_model_deployment_params",
true,
a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3], (int) a[4])
a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3])
);

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
PARSER.declareInt(ConstructingObjectParser.constructorArg(), INFERENCE_THREADS);
PARSER.declareInt(ConstructingObjectParser.constructorArg(), MODEL_THREADS);
PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
}

public static TaskParams fromXContent(XContentParser parser) {
Expand All @@ -275,9 +253,8 @@ public static TaskParams fromXContent(XContentParser parser) {
private final long modelBytes;
private final int inferenceThreads;
private final int modelThreads;
private final int queueCapacity;

public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads, int queueCapacity) {
public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads) {
this.modelId = Objects.requireNonNull(modelId);
this.modelBytes = modelBytes;
if (modelBytes < 0) {
Expand All @@ -291,18 +268,13 @@ public TaskParams(String modelId, long modelBytes, int inferenceThreads, int mod
if (modelThreads < 1) {
throw new IllegalArgumentException(MODEL_THREADS + " must be positive");
}
this.queueCapacity = queueCapacity;
if (queueCapacity < 1 || queueCapacity > 10000) {
throw new IllegalArgumentException(QUEUE_CAPACITY + " must be in [1, 10000]");
}
}

public TaskParams(StreamInput in) throws IOException {
this.modelId = in.readString();
this.modelBytes = in.readVLong();
this.inferenceThreads = in.readVInt();
this.modelThreads = in.readVInt();
this.queueCapacity = in.readVInt();
}

public String getModelId() {
Expand All @@ -324,7 +296,6 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeVLong(modelBytes);
out.writeVInt(inferenceThreads);
out.writeVInt(modelThreads);
out.writeVInt(queueCapacity);
}

@Override
Expand All @@ -334,14 +305,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(MODEL_BYTES.getPreferredName(), modelBytes);
builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads, queueCapacity);
return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads);
}

@Override
Expand All @@ -353,8 +323,7 @@ public boolean equals(Object o) {
return Objects.equals(modelId, other.modelId)
&& modelBytes == other.modelBytes
&& inferenceThreads == other.inferenceThreads
&& modelThreads == other.modelThreads
&& queueCapacity == other.queueCapacity;
&& modelThreads == other.modelThreads;
}

@Override
Expand All @@ -373,10 +342,6 @@ public int getInferenceThreads() {
public int getModelThreads() {
return modelThreads;
}

public int getQueueCapacity() {
return queueCapacity;
}
}

public interface TaskMatcher {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@ public class CreateTrainedModelAllocationActionRequestTests extends AbstractWire

@Override
protected Request createTestInstance() {
return new Request(StartTrainedModelDeploymentTaskParamsTests.createRandom());
return new Request(
new StartTrainedModelDeploymentAction.TaskParams(
randomAlphaOfLength(10),
randomNonNegativeLong(),
randomIntBetween(1, 8),
randomIntBetween(1, 8)
)
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import java.io.IOException;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
Expand Down Expand Up @@ -54,9 +53,6 @@ public static Request createRandom() {
if (randomBoolean()) {
request.setModelThreads(randomIntBetween(1, 8));
}
if (randomBoolean()) {
request.setQueueCapacity(randomIntBetween(1, 10000));
}
return request;
}

Expand Down Expand Up @@ -99,43 +95,4 @@ public void testValidate_GivenModelThreadsIsNegative() {
assertThat(e, is(not(nullValue())));
assertThat(e.getMessage(), containsString("[model_threads] must be a positive integer"));
}

public void testValidate_GivenQueueCapacityIsZero() {
Request request = createRandom();
request.setQueueCapacity(0);

ActionRequestValidationException e = request.validate();

assertThat(e, is(not(nullValue())));
assertThat(e.getMessage(), containsString("[queue_capacity] must be in [1, 10000]"));
}

public void testValidate_GivenQueueCapacityIsNegative() {
Request request = createRandom();
request.setQueueCapacity(randomIntBetween(Integer.MIN_VALUE, -1));

ActionRequestValidationException e = request.validate();

assertThat(e, is(not(nullValue())));
assertThat(e.getMessage(), containsString("[queue_capacity] must be in [1, 10000]"));
}

public void testValidate_GivenQueueCapacityIsGreaterThan10000() {
Request request = createRandom();
request.setQueueCapacity(randomIntBetween(10001, Integer.MAX_VALUE));

ActionRequestValidationException e = request.validate();

assertThat(e, is(not(nullValue())));
assertThat(e.getMessage(), containsString("[queue_capacity] must be in [1, 10000]"));
}

public void testDefaults() {
Request request = new Request(randomAlphaOfLength(10));
assertThat(request.getTimeout(), equalTo(TimeValue.timeValueSeconds(20)));
assertThat(request.getWaitForState(), equalTo(AllocationStatus.State.STARTED));
assertThat(request.getInferenceThreads(), equalTo(1));
assertThat(request.getModelThreads(), equalTo(1));
assertThat(request.getQueueCapacity(), equalTo(1024));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ public static StartTrainedModelDeploymentAction.TaskParams createRandom() {
randomAlphaOfLength(10),
randomNonNegativeLong(),
randomIntBetween(1, 8),
randomIntBetween(1, 8),
randomIntBetween(1, 10000)
randomIntBetween(1, 8)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentTaskParamsTests;

import java.io.IOException;
import java.util.List;
Expand All @@ -32,7 +31,9 @@
public class TrainedModelAllocationTests extends AbstractSerializingTestCase<TrainedModelAllocation> {

public static TrainedModelAllocation randomInstance() {
TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(randomParams());
TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(
new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1)
);
List<String> nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList());
for (String node : nodes) {
if (randomBoolean()) {
Expand Down Expand Up @@ -248,7 +249,7 @@ private static DiscoveryNode buildNode() {
}

private static StartTrainedModelDeploymentAction.TaskParams randomParams() {
return StartTrainedModelDeploymentTaskParamsTests.createRandom();
return new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1);
}

private static void assertUnchanged(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
Expand All @@ -34,7 +35,6 @@
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction;
Expand Down Expand Up @@ -161,8 +161,7 @@ protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Requ
trainedModelConfig.getModelId(),
modelBytes,
request.getInferenceThreads(),
request.getModelThreads(),
request.getQueueCapacity()
request.getModelThreads()
);
PersistentTasksCustomMetadata persistentTasks = clusterService.state().getMetadata().custom(
PersistentTasksCustomMetadata.TYPE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,6 @@ public void onFailure(Exception e) {

@Override
protected void doRun() throws Exception {
logger.info("Request [{}] running", requestId);
final String requestIdStr = String.valueOf(requestId);
try {
// The request builder expect a list of inputs which are then batched.
Expand Down Expand Up @@ -393,11 +392,7 @@ class ProcessContext {
this.task = Objects.requireNonNull(task);
resultProcessor = new PyTorchResultProcessor(task.getModelId());
this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry);
this.executorService = new ProcessWorkerExecutorService(
threadPool.getThreadContext(),
"pytorch_inference",
task.getParams().getQueueCapacity()
);
this.executorService = new ProcessWorkerExecutorService(threadPool.getThreadContext(), "pytorch_inference", 1024);
}

PyTorchResultProcessor getResultProcessor() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import static org.elasticsearch.rest.RestRequest.Method.POST;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.INFERENCE_THREADS;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.MODEL_THREADS;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.QUEUE_CAPACITY;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.TIMEOUT;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.WAIT_FOR;

Expand Down Expand Up @@ -60,7 +59,6 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
));
request.setInferenceThreads(restRequest.paramAsInt(INFERENCE_THREADS.getPreferredName(), request.getInferenceThreads()));
request.setModelThreads(restRequest.paramAsInt(MODEL_THREADS.getPreferredName(), request.getModelThreads()));
request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
}

return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,7 @@ private static DiscoveryNode buildOldNode(String name, boolean isML, long native
}

private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId, long modelSize) {
return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, 1, 1, 1024);
return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, 1, 1);
}

private static void assertNodeState(TrainedModelAllocationMetadata metadata, String modelId, String nodeId, RoutingState routingState) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,7 @@ private static StartTrainedModelDeploymentAction.TaskParams randomParams(String
modelId,
randomNonNegativeLong(),
randomIntBetween(1, 8),
randomIntBetween(1, 8),
randomIntBetween(1, 10000)
randomIntBetween(1, 8)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ private void withSearchingLoadFailure(String modelId) {
}

private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId) {
return new StartTrainedModelDeploymentAction.TaskParams(modelId, randomNonNegativeLong(), 1, 1, 1024);
return new StartTrainedModelDeploymentAction.TaskParams(modelId, randomNonNegativeLong(), 1, 1);
}

private TrainedModelAllocationNodeService createService() {
Expand Down
Loading