Skip to content

Commit

Permalink
Feature/os assistant (#1655)
Browse files Browse the repository at this point in the history
* Merge update model API and model level throttling/quota (#1624)

* Update Model API (#1350)

* Update Model API POC

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Using GetRequest to get model

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Finalize model update API

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Fix compile

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Fix compileTest

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Add Unit Test Cases for Update Model API

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Tune back test coverage thereshold

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Add more unit tests on Update model API

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Add unit test for TransportUpdateModelAction class

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Fix a test error

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Change exception thrown to failure response

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Move the function judgement to the outter block

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Check if model is undeployed before update model

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Add more unit test for update model API

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Fix unit test due to blocking java 11 CI workflow

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Enabling auto bumping model version during registering to a new model group and address reviewers' other concern

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Autobump new model groups' latest version when register to a new model

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Change the REST API method from POST to PUT

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Change the update REST API endpoint

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

---------

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Fix java compile when merging

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Fix model/connector update API to address security concern (#1595)

* Fix model/connector update API to address appsec concern

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Fix compile and build failure

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Improve unit test coverage

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Fix spotless

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Merge update connector feature flag to remote inference feature flag

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Fix compile

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Fix exception status

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Keep fixing exception status

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Spotless fix

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Add UT on parsing exception

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

---------

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* return parsing exception 400 for parsing errors (#1603)

add more ut in restupdateconnector

Signed-off-by: Xun Zhang <xunzh@amazon.com>
Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* throttling and quota feature on single node cluster

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Fix java compile when merging

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Enabling in-place update on multi-node

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Fix confidential rotation in update internal connector

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

---------

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>
Signed-off-by: Xun Zhang <xunzh@amazon.com>
Co-authored-by: Xun Zhang <xunzh@amazon.com>
Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* merge conflict

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Change rate limiter token capacity setting (#1635)

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* support step size for embedding model which outputs less embeddings (#1586)

* support step size for embedding model which outputs less embeddings

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* tune parameter name

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* fine tune processed doc to always respect step size

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

---------

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* validate step size (#1587)

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* backport to 2.11 (#1639)

Signed-off-by: xinyual <xinyual@amazon.com>
Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

---------

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>
Signed-off-by: Xun Zhang <xunzh@amazon.com>
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
Signed-off-by: xinyual <xinyual@amazon.com>
Co-authored-by: Xun Zhang <xunzh@amazon.com>
Co-authored-by: Yaliang Wu <ylwu@amazon.com>
Co-authored-by: xinyual <74362153+xinyual@users.noreply.github.com>
  • Loading branch information
4 people authored Nov 17, 2023
1 parent 17181b7 commit d90ae06
Show file tree
Hide file tree
Showing 39 changed files with 3,150 additions and 91 deletions.
11 changes: 10 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class CommonValue {
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
public static final String ML_TASK_INDEX = ".plugins-ml-task";
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 7;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 8;
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2;
Expand Down Expand Up @@ -196,6 +196,15 @@ public class CommonValue {
+ MODEL_MAX_LENGTH_FIELD + "\":{\"type\":\"integer\"},\""
+ ALL_CONFIG_FIELD + "\":{\"type\":\"text\"}}},\n"
+ " \""
+ MLModel.QUOTA_FLAG_FIELD
+ "\" : {\"type\": \"boolean\"},\n"
+ " \""
+ MLModel.RATE_LIMIT_NUMBER_FIELD
+ "\" : {\"type\": \"keyword\"},\n"
+ " \""
+ MLModel.RATE_LIMIT_UNIT_FIELD
+ "\" : {\"type\": \"keyword\"},\n"
+ " \""
+ MLModel.MODEL_CONTENT_HASH_VALUE_FIELD
+ "\" : {\"type\": \"keyword\"},\n"
+ " \""
Expand Down
57 changes: 55 additions & 2 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;

import java.io.IOException;
import java.sql.Time;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.TimeUnit;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.USER;
Expand All @@ -50,9 +52,14 @@ public class MLModel implements ToXContentObject {
public static final String MODEL_FORMAT_FIELD = "model_format";
public static final String MODEL_STATE_FIELD = "model_state";
public static final String MODEL_CONTENT_SIZE_IN_BYTES_FIELD = "model_content_size_in_bytes";
//SHA256 hash value of model content.
// SHA256 hash value of model content.
public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value";

// Model level quota and throttling control
public static final String QUOTA_FLAG_FIELD = "quota_flag";
public static final String RATE_LIMIT_NUMBER_FIELD = "rate_limit_number";
public static final String RATE_LIMIT_UNIT_FIELD = "rate_limit_unit";

public static final String MODEL_CONFIG_FIELD = "model_config";
public static final String CREATED_TIME_FIELD = "created_time";
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";
Expand Down Expand Up @@ -92,6 +99,9 @@ public class MLModel implements ToXContentObject {
private Long modelContentSizeInBytes;
private String modelContentHash;
private MLModelConfig modelConfig;
private Boolean quotaFlag;
private String rateLimitNumber;
private TimeUnit rateLimitUnit;
private Instant createdTime;
private Instant lastUpdateTime;
private Instant lastRegisteredTime;
Expand Down Expand Up @@ -126,6 +136,9 @@ public MLModel(String name,
MLModelState modelState,
Long modelContentSizeInBytes,
String modelContentHash,
Boolean quotaFlag,
String rateLimitNumber,
TimeUnit rateLimitUnit,
MLModelConfig modelConfig,
Instant createdTime,
Instant lastUpdateTime,
Expand All @@ -152,6 +165,9 @@ public MLModel(String name,
this.modelState = modelState;
this.modelContentSizeInBytes = modelContentSizeInBytes;
this.modelContentHash = modelContentHash;
this.quotaFlag = quotaFlag;
this.rateLimitNumber = rateLimitNumber;
this.rateLimitUnit = rateLimitUnit;
this.modelConfig = modelConfig;
this.createdTime = createdTime;
this.lastUpdateTime = lastUpdateTime;
Expand Down Expand Up @@ -197,6 +213,11 @@ public MLModel(StreamInput input) throws IOException{
modelConfig = new TextEmbeddingModelConfig(input);
}
}
quotaFlag = input.readOptionalBoolean();
rateLimitNumber = input.readOptionalString();
if (input.readBoolean()) {
rateLimitUnit = input.readEnum(TimeUnit.class);
}
createdTime = input.readOptionalInstant();
lastUpdateTime = input.readOptionalInstant();
lastRegisteredTime = input.readOptionalInstant();
Expand Down Expand Up @@ -250,6 +271,14 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(quotaFlag);
out.writeOptionalString(rateLimitNumber);
if (rateLimitUnit != null) {
out.writeBoolean(true);
out.writeEnum(rateLimitUnit);
} else {
out.writeBoolean(false);
}
out.writeOptionalInstant(createdTime);
out.writeOptionalInstant(lastUpdateTime);
out.writeOptionalInstant(lastRegisteredTime);
Expand Down Expand Up @@ -312,6 +341,15 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (modelConfig != null) {
builder.field(MODEL_CONFIG_FIELD, modelConfig);
}
if (quotaFlag != null) {
builder.field(QUOTA_FLAG_FIELD, quotaFlag);
}
if (rateLimitNumber != null) {
builder.field(RATE_LIMIT_NUMBER_FIELD, rateLimitNumber);
}
if (rateLimitUnit != null) {
builder.field(RATE_LIMIT_UNIT_FIELD, rateLimitUnit);
}
if (createdTime != null) {
builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli());
}
Expand Down Expand Up @@ -371,12 +409,15 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
String oldContent = null;
User user = null;

String description = null;;
String description = null;
MLModelFormat modelFormat = null;
MLModelState modelState = null;
Long modelContentSizeInBytes = null;
String modelContentHash = null;
MLModelConfig modelConfig = null;
Boolean quotaFlag = null;
String rateLimitNumber = null;
TimeUnit rateLimitUnit = null;
Instant createdTime = null;
Instant lastUpdateTime = null;
Instant lastUploadedTime = null;
Expand Down Expand Up @@ -461,6 +502,15 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
modelConfig = TextEmbeddingModelConfig.parse(parser);
}
break;
case QUOTA_FLAG_FIELD:
quotaFlag = parser.booleanValue();
break;
case RATE_LIMIT_NUMBER_FIELD:
rateLimitNumber = parser.text();
break;
case RATE_LIMIT_UNIT_FIELD:
rateLimitUnit = TimeUnit.valueOf(parser.text());
break;
case PLANNING_WORKER_NODE_COUNT_FIELD:
planningWorkerNodeCount = parser.intValue();
break;
Expand Down Expand Up @@ -524,6 +574,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
.modelContentSizeInBytes(modelContentSizeInBytes)
.modelContentHash(modelContentHash)
.modelConfig(modelConfig)
.quotaFlag(quotaFlag)
.rateLimitNumber(rateLimitNumber)
.rateLimitUnit(rateLimitUnit)
.createdTime(createdTime)
.lastUpdateTime(lastUpdateTime)
.lastRegisteredTime(lastRegisteredTime == null? lastUploadedTime : lastRegisteredTime)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import org.opensearch.action.ActionType;

public class MLInPlaceUpdateModelAction extends ActionType<MLInPlaceUpdateModelNodesResponse> {
public static final MLInPlaceUpdateModelAction INSTANCE = new MLInPlaceUpdateModelAction();
public static final String NAME = "cluster:admin/opensearch/ml/models/in_place_update";

private MLInPlaceUpdateModelAction() { super(NAME, MLInPlaceUpdateModelNodesResponse::new);}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import org.opensearch.action.support.nodes.BaseNodeRequest;
import java.io.IOException;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

public class MLInPlaceUpdateModelNodeRequest extends BaseNodeRequest {
@Getter
private MLInPlaceUpdateModelNodesRequest mlInPlaceUpdateModelNodesRequest;

public MLInPlaceUpdateModelNodeRequest(StreamInput in) throws IOException {
super(in);
this.mlInPlaceUpdateModelNodesRequest = new MLInPlaceUpdateModelNodesRequest(in);
}

public MLInPlaceUpdateModelNodeRequest(MLInPlaceUpdateModelNodesRequest request) {
this.mlInPlaceUpdateModelNodesRequest = request;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
mlInPlaceUpdateModelNodesRequest.writeTo(out);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.action.support.nodes.BaseNodeResponse;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentFragment;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Map;

@Getter
@Log4j2
public class MLInPlaceUpdateModelNodeResponse extends BaseNodeResponse implements ToXContentFragment {
private Map<String, String> modelUpdateStatus;

public MLInPlaceUpdateModelNodeResponse(DiscoveryNode node, Map<String, String> modelUpdateStatus) {
super(node);
this.modelUpdateStatus = modelUpdateStatus;
}

public MLInPlaceUpdateModelNodeResponse(StreamInput in) throws IOException {
super(in);
if (in.readBoolean()) {
this.modelUpdateStatus = in.readMap(StreamInput::readString, StreamInput::readString);
}
}

public static MLInPlaceUpdateModelNodeResponse readStats(StreamInput in) throws IOException {
return new MLInPlaceUpdateModelNodeResponse(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);

if (!isEmpty()) {
out.writeBoolean(true);
out.writeMap(modelUpdateStatus, StreamOutput::writeString, StreamOutput::writeString);
} else {
out.writeBoolean(false);
}
}

public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject("stats");
if (modelUpdateStatus != null && modelUpdateStatus.size() > 0) {
for (Map.Entry<String, String> stat : modelUpdateStatus.entrySet()) {
builder.field(stat.getKey(), stat.getValue());
}
}
builder.endObject();
return builder;
}

public boolean isEmpty() {
return modelUpdateStatus == null || modelUpdateStatus.size() == 0;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import lombok.Getter;
import org.opensearch.action.support.nodes.BaseNodesRequest;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import java.io.IOException;

public class MLInPlaceUpdateModelNodesRequest extends BaseNodesRequest<MLInPlaceUpdateModelNodesRequest> {

@Getter
private String modelId;
@Getter
private boolean updatePredictorFlag;

public MLInPlaceUpdateModelNodesRequest(StreamInput in) throws IOException {
super(in);
this.modelId = in.readString();
this.updatePredictorFlag = in.readBoolean();
}

public MLInPlaceUpdateModelNodesRequest(String[] nodeIds, String modelId, boolean updatePredictorFlag) {
super(nodeIds);
this.modelId = modelId;
this.updatePredictorFlag = updatePredictorFlag;
}

public MLInPlaceUpdateModelNodesRequest(DiscoveryNode[] nodeIds, String modelId, boolean updatePredictorFlag) {
super(nodeIds);
this.modelId = modelId;
this.updatePredictorFlag = updatePredictorFlag;
}

public MLInPlaceUpdateModelNodesRequest(DiscoveryNode... nodes) {
super(nodes);
this.updatePredictorFlag = false;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(modelId);
out.writeBoolean(updatePredictorFlag);
}
}
Loading

0 comments on commit d90ae06

Please sign in to comment.