forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update Model API (opensearch-project#1350)
* Update Model API POC Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Using GetRequest to get model Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Finalize model update API Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Fix compile Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Fix compileTest Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Add Unit Test Cases for Update Model API Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Tune back test coverage thereshold Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Add more unit tests on Update model API Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Add unit test for TransportUpdateModelAction class Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Fix a test error Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Change exception thrown to failure response Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Move the function judgement to the outter block Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Check if model is undeployed before update model Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Add more unit test for update model API Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Fix unit test due to blocking java 11 CI workflow Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Enabling auto bumping model version during registering to a new model group and address reviewers' other concern Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Autobump new model groups' latest version when register to a new model Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Change the REST API method from POST to PUT Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Change the update REST API endpoint Signed-off-by: Sicheng Song <sicheng.song@outlook.com> --------- Signed-off-by: Sicheng Song <sicheng.song@outlook.com>
- Loading branch information
Showing
14 changed files
with
2,007 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
18 changes: 18 additions & 0 deletions
18
common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.model; | ||
|
||
import org.opensearch.action.ActionType; | ||
import org.opensearch.action.update.UpdateResponse; | ||
|
||
public class MLUpdateModelAction extends ActionType<UpdateResponse> { | ||
public static MLUpdateModelAction INSTANCE = new MLUpdateModelAction(); | ||
public static final String NAME = "cluster:admin/opensearch/ml/models/update"; | ||
|
||
private MLUpdateModelAction() { | ||
super(NAME, UpdateResponse::new); | ||
} | ||
} |
155 changes: 155 additions & 0 deletions
155
common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.model; | ||
|
||
import lombok.Data; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.common.io.stream.Writeable; | ||
import org.opensearch.core.xcontent.ToXContentObject; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.connector.Connector; | ||
import org.opensearch.ml.common.model.MLModelConfig; | ||
import org.opensearch.ml.common.model.TextEmbeddingModelConfig; | ||
|
||
import java.io.IOException; | ||
import java.util.Map; | ||
|
||
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; | ||
import static org.opensearch.ml.common.connector.Connector.createConnector; | ||
|
||
@Data | ||
public class MLUpdateModelInput implements ToXContentObject, Writeable { | ||
|
||
public static final String MODEL_ID_FIELD = "model_id"; // mandatory | ||
public static final String DESCRIPTION_FIELD = "description"; // optional | ||
public static final String MODEL_VERSION_FIELD = "model_version"; // optional | ||
public static final String MODEL_NAME_FIELD = "name"; // optional | ||
public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // optional | ||
public static final String MODEL_CONFIG_FIELD = "model_config"; // optional | ||
public static final String CONNECTOR_ID_FIELD = "connector_id"; // optional | ||
|
||
@Getter | ||
private String modelId; | ||
private String description; | ||
private String version; | ||
private String name; | ||
private String modelGroupId; | ||
private MLModelConfig modelConfig; | ||
private String connectorId; | ||
|
||
@Builder(toBuilder = true) | ||
public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, MLModelConfig modelConfig, String connectorId) { | ||
this.modelId = modelId; | ||
this.description = description; | ||
this.version = version; | ||
this.name = name; | ||
this.modelGroupId = modelGroupId; | ||
this.modelConfig = modelConfig; | ||
this.connectorId = connectorId; | ||
} | ||
|
||
public MLUpdateModelInput(StreamInput in) throws IOException { | ||
this.modelId = in.readString(); | ||
this.description = in.readOptionalString(); | ||
this.version = in.readOptionalString(); | ||
this.name = in.readOptionalString(); | ||
this.modelGroupId = in.readOptionalString(); | ||
if (in.readBoolean()) { | ||
modelConfig = new TextEmbeddingModelConfig(in); | ||
} | ||
this.connectorId = in.readOptionalString(); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(MODEL_ID_FIELD, modelId); | ||
if (name != null) { | ||
builder.field(MODEL_NAME_FIELD, name); | ||
} | ||
if (description != null) { | ||
builder.field(DESCRIPTION_FIELD, description); | ||
} | ||
if (version != null) { | ||
builder.field(MODEL_VERSION_FIELD, version); | ||
} | ||
if (modelGroupId != null) { | ||
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId); | ||
} | ||
if (modelConfig != null) { | ||
builder.field(MODEL_CONFIG_FIELD, modelConfig); | ||
} | ||
if (connectorId != null) { | ||
builder.field(CONNECTOR_ID_FIELD, connectorId); | ||
} | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeString(modelId); | ||
out.writeOptionalString(description); | ||
out.writeOptionalString(version); | ||
out.writeOptionalString(name); | ||
out.writeOptionalString(modelGroupId); | ||
if (modelConfig != null) { | ||
out.writeBoolean(true); | ||
modelConfig.writeTo(out); | ||
} else { | ||
out.writeBoolean(false); | ||
} | ||
out.writeOptionalString(connectorId); | ||
} | ||
|
||
public static MLUpdateModelInput parse(XContentParser parser) throws IOException { | ||
String modelId = null; | ||
String description = null; | ||
String version = null; | ||
String name = null; | ||
String modelGroupId = null; | ||
MLModelConfig modelConfig = null; | ||
String connectorId = null; | ||
|
||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
while (parser.nextToken() != XContentParser.Token.END_OBJECT) { | ||
String fieldName = parser.currentName(); | ||
parser.nextToken(); | ||
switch (fieldName) { | ||
case MODEL_ID_FIELD: | ||
modelId = parser.text(); | ||
break; | ||
case DESCRIPTION_FIELD: | ||
description = parser.text(); | ||
break; | ||
case MODEL_NAME_FIELD: | ||
name = parser.text(); | ||
break; | ||
case MODEL_VERSION_FIELD: | ||
version = parser.text(); | ||
break; | ||
case MODEL_GROUP_ID_FIELD: | ||
modelGroupId = parser.text(); | ||
break; | ||
case MODEL_CONFIG_FIELD: | ||
modelConfig = TextEmbeddingModelConfig.parse(parser); | ||
break; | ||
case CONNECTOR_ID_FIELD: | ||
connectorId = parser.text(); | ||
break; | ||
default: | ||
parser.skipChildren(); | ||
break; | ||
} | ||
} | ||
// Model ID can only be set through RestRequest. Model version can only be set automatically. | ||
return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, modelConfig, connectorId); | ||
} | ||
} |
75 changes: 75 additions & 0 deletions
75
common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.model; | ||
|
||
import lombok.AccessLevel; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.ToString; | ||
import lombok.experimental.FieldDefaults; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.core.common.io.stream.InputStreamStreamInput; | ||
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
|
||
import static org.opensearch.action.ValidateActions.addValidationError; | ||
|
||
@Getter | ||
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) | ||
@ToString | ||
public class MLUpdateModelRequest extends ActionRequest { | ||
|
||
MLUpdateModelInput updateModelInput; | ||
|
||
@Builder | ||
public MLUpdateModelRequest(MLUpdateModelInput updateModelInput) { | ||
this.updateModelInput = updateModelInput; | ||
} | ||
|
||
public MLUpdateModelRequest(StreamInput in) throws IOException { | ||
super(in); | ||
updateModelInput = new MLUpdateModelInput(in); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
ActionRequestValidationException exception = null; | ||
if (updateModelInput == null) { | ||
exception = addValidationError("Update Model Input can't be null", exception); | ||
} | ||
|
||
return exception; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
this.updateModelInput.writeTo(out); | ||
} | ||
|
||
public static MLUpdateModelRequest fromActionRequest(ActionRequest actionRequest){ | ||
if (actionRequest instanceof MLUpdateModelRequest) { | ||
return (MLUpdateModelRequest) actionRequest; | ||
} | ||
|
||
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); | ||
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { | ||
actionRequest.writeTo(osso); | ||
try (StreamInput in = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { | ||
return new MLUpdateModelRequest(in); | ||
} | ||
} catch (IOException e) { | ||
throw new UncheckedIOException("Failed to parse ActionRequest into MLUpdateModelRequest", e); | ||
} | ||
} | ||
} |
Oops, something went wrong.