Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Model API #1350

Merged
merged 19 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class MLModelGroup implements ToXContentObject {
@Setter
private String name;
private String description;
@Setter
private int latestVersion;
private List<String> backendRoles;
private User owner;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import org.opensearch.action.ActionType;
import org.opensearch.action.update.UpdateResponse;

public class MLUpdateModelAction extends ActionType<UpdateResponse> {
public static MLUpdateModelAction INSTANCE = new MLUpdateModelAction();
public static final String NAME = "cluster:admin/opensearch/ml/models/update";

private MLUpdateModelAction() {
super(NAME, UpdateResponse::new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import lombok.Data;
import lombok.Builder;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;

import java.io.IOException;
import java.util.Map;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.connector.Connector.createConnector;

@Data
public class MLUpdateModelInput implements ToXContentObject, Writeable {

public static final String MODEL_ID_FIELD = "model_id"; // mandatory
public static final String DESCRIPTION_FIELD = "description"; // optional
public static final String MODEL_VERSION_FIELD = "model_version"; // optional
public static final String MODEL_NAME_FIELD = "name"; // optional
public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // optional
public static final String MODEL_CONFIG_FIELD = "model_config"; // optional
public static final String CONNECTOR_ID_FIELD = "connector_id"; // optional

@Getter
private String modelId;
private String description;
private String version;
private String name;
private String modelGroupId;
private MLModelConfig modelConfig;
private String connectorId;

@Builder(toBuilder = true)
public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, MLModelConfig modelConfig, String connectorId) {
this.modelId = modelId;
b4sjoo marked this conversation as resolved.
Show resolved Hide resolved
this.description = description;
this.version = version;
this.name = name;
this.modelGroupId = modelGroupId;
this.modelConfig = modelConfig;
this.connectorId = connectorId;
}

public MLUpdateModelInput(StreamInput in) throws IOException {
this.modelId = in.readString();
this.description = in.readOptionalString();
this.version = in.readOptionalString();
this.name = in.readOptionalString();
this.modelGroupId = in.readOptionalString();
if (in.readBoolean()) {
modelConfig = new TextEmbeddingModelConfig(in);
}
this.connectorId = in.readOptionalString();
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(MODEL_ID_FIELD, modelId);
if (name != null) {
builder.field(MODEL_NAME_FIELD, name);
}
if (description != null) {
builder.field(DESCRIPTION_FIELD, description);
}
if (version != null) {
builder.field(MODEL_VERSION_FIELD, version);
}
if (modelGroupId != null) {
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId);
}
if (modelConfig != null) {
builder.field(MODEL_CONFIG_FIELD, modelConfig);
}
if (connectorId != null) {
builder.field(CONNECTOR_ID_FIELD, connectorId);
}
builder.endObject();
return builder;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
out.writeOptionalString(description);
out.writeOptionalString(version);
out.writeOptionalString(name);
out.writeOptionalString(modelGroupId);
if (modelConfig != null) {
out.writeBoolean(true);
modelConfig.writeTo(out);
} else {
out.writeBoolean(false);
}
out.writeOptionalString(connectorId);
}

public static MLUpdateModelInput parse(XContentParser parser) throws IOException {
String modelId = null;
String description = null;
String version = null;
String name = null;
String modelGroupId = null;
MLModelConfig modelConfig = null;
String connectorId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();
switch (fieldName) {
case MODEL_ID_FIELD:
modelId = parser.text();
break;
case DESCRIPTION_FIELD:
description = parser.text();
break;
case MODEL_NAME_FIELD:
name = parser.text();
break;
case MODEL_VERSION_FIELD:
version = parser.text();
break;
case MODEL_GROUP_ID_FIELD:
modelGroupId = parser.text();
break;
case MODEL_CONFIG_FIELD:
modelConfig = TextEmbeddingModelConfig.parse(parser);
break;
case CONNECTOR_ID_FIELD:
connectorId = parser.text();
break;
default:
parser.skipChildren();
break;
}
}
// Model ID can only be set through RestRequest. Model version can only be set automatically.
return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, modelConfig, connectorId);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLUpdateModelRequest extends ActionRequest {

MLUpdateModelInput updateModelInput;

@Builder
public MLUpdateModelRequest(MLUpdateModelInput updateModelInput) {
this.updateModelInput = updateModelInput;
}

public MLUpdateModelRequest(StreamInput in) throws IOException {
super(in);
updateModelInput = new MLUpdateModelInput(in);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (updateModelInput == null) {
exception = addValidationError("Update Model Input can't be null", exception);
}

return exception;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
this.updateModelInput.writeTo(out);
}

public static MLUpdateModelRequest fromActionRequest(ActionRequest actionRequest){
if (actionRequest instanceof MLUpdateModelRequest) {
return (MLUpdateModelRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput in = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLUpdateModelRequest(in);
}
} catch (IOException e) {
throw new UncheckedIOException("Failed to parse ActionRequest into MLUpdateModelRequest", e);
}
}
}
Loading
Loading