Skip to content

Commit

Permalink
model registry fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
  • Loading branch information
rbhavna committed Oct 2, 2023
1 parent 6e0d949 commit 144357a
Show file tree
Hide file tree
Showing 9 changed files with 264 additions and 283 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
public static final String ACCESS_MODE_FIELD = "access_mode";
public static final String BACKEND_ROLES_FIELD = "backend_roles";
public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles";

public static final String IS_THIS_VERSION_CREATING_MODEL_GROUP = "is_this_version_creating_model_group";
private FunctionName functionName;
private String modelName;
private String modelGroupId;
Expand All @@ -72,6 +72,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
private List<String> backendRoles;
private Boolean addAllBackendRoles;
private AccessMode accessMode;
private Boolean isThisVersionCreatingModelGroup;

@Builder(toBuilder = true)
public MLRegisterModelInput(FunctionName functionName,
Expand All @@ -89,7 +90,8 @@ public MLRegisterModelInput(FunctionName functionName,
String connectorId,
List<String> backendRoles,
Boolean addAllBackendRoles,
AccessMode accessMode
AccessMode accessMode,
Boolean isThisVersionCreatingModelGroup
) {
if (functionName == null) {
this.functionName = FunctionName.TEXT_EMBEDDING;
Expand Down Expand Up @@ -122,6 +124,7 @@ public MLRegisterModelInput(FunctionName functionName,
this.backendRoles = backendRoles;
this.addAllBackendRoles = addAllBackendRoles;
this.accessMode = accessMode;
this.isThisVersionCreatingModelGroup = isThisVersionCreatingModelGroup;
}


Expand Down Expand Up @@ -152,6 +155,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.accessMode = in.readEnum(AccessMode.class);
}
this.isThisVersionCreatingModelGroup = in.readOptionalBoolean();
}

@Override
Expand Down Expand Up @@ -197,6 +201,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isThisVersionCreatingModelGroup);
}

@Override
Expand Down Expand Up @@ -244,6 +249,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (accessMode != null) {
builder.field(ACCESS_MODE_FIELD, accessMode);
}
if (isThisVersionCreatingModelGroup != null) {
builder.field(IS_THIS_VERSION_CREATING_MODEL_GROUP, isThisVersionCreatingModelGroup);
}
builder.endObject();
return builder;
}
Expand All @@ -262,6 +270,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
List<String> backendRoles = new ArrayList<>();
Boolean addAllBackendRoles = null;
AccessMode accessMode = null;
Boolean isThisVersionCreatingModelGroup = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -313,12 +322,15 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
case ACCESS_MODE_FIELD:
accessMode = AccessMode.from(parser.text());
break;
case IS_THIS_VERSION_CREATING_MODEL_GROUP:
isThisVersionCreatingModelGroup = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode);
return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, isThisVersionCreatingModelGroup);
}

public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException {
Expand All @@ -337,6 +349,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
List<String> backendRoles = new ArrayList<>();
AccessMode accessMode = null;
Boolean addAllBackendRoles = null;
Boolean isThisVersionCreatingModelGroup = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -395,11 +408,14 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
case ACCESS_MODE_FIELD:
accessMode = AccessMode.from(parser.text());
break;
case IS_THIS_VERSION_CREATING_MODEL_GROUP:
isThisVersionCreatingModelGroup = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode);
return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, isThisVersionCreatingModelGroup);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
Expand Down Expand Up @@ -90,39 +89,39 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUpda
MLUpdateModelGroupInput updateModelGroupInput = updateModelGroupRequest.getUpdateModelGroupInput();
String modelGroupId = updateModelGroupInput.getModelGroupID();
User user = RestActionUtils.getUserContext(client);
if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) {
GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> {
if (modelGroup.isExists()) {
try (
XContentParser parser = MLNodeUtils
.createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, modelGroup.getSourceAsBytesRef())
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModelGroup mlModelGroup = MLModelGroup.parse(parser);
GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> {
if (modelGroup.isExists()) {
try (
XContentParser parser = MLNodeUtils
.createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, modelGroup.getSourceAsBytesRef())
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModelGroup mlModelGroup = MLModelGroup.parse(parser);
if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) {
validateRequestForAccessControl(updateModelGroupInput, user, mlModelGroup);
updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, listener, user);
} else {
validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput);
}
} else {
listener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND));
updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, listener, user);
}
}, e -> {
if (e instanceof IndexNotFoundException) {
listener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
} else {
logException("Failed to get model group", e, log);
listener.onFailure(e);
}
}));
} catch (Exception e) {
logException("Failed to Update model group", e, log);
listener.onFailure(e);
}
} else {
validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput);
updateModelGroup(modelGroupId, new HashMap<>(), updateModelGroupInput, listener, user);
} else {
listener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
listener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
} else {
logException("Failed to get model group", e, log);
listener.onFailure(e);
}
}));
} catch (Exception e) {
logException("Failed to Update model group", e, log);
listener.onFailure(e);
}

}

private void updateModelGroup(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,19 @@
package org.opensearch.ml.action.models;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD;
import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
Expand All @@ -34,8 +30,6 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.DeleteByQueryAction;
Expand All @@ -48,7 +42,6 @@
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand Down Expand Up @@ -109,6 +102,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
algorithmName = getResponse.getSource().get(ALGORITHM_FIELD).toString();
}
MLModel mlModel = MLModel.parse(parser, algorithmName);
MLModelState mlModelState = mlModel.getModelState();

modelAccessControlHelper
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
Expand All @@ -117,37 +111,19 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
.onFailure(
new MLValidationException("User doesn't have privilege to perform this operation on this model")
);
} else if (mlModelState.equals(MLModelState.LOADED)
|| mlModelState.equals(MLModelState.LOADING)
|| mlModelState.equals(MLModelState.PARTIALLY_LOADED)
|| mlModelState.equals(MLModelState.DEPLOYED)
|| mlModelState.equals(MLModelState.DEPLOYING) | mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED)) {
actionListener
.onFailure(
new Exception(
"Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete"
)
);
} else {
MLModelState mlModelState = mlModel.getModelState();
if (mlModelState.equals(MLModelState.LOADED)
|| mlModelState.equals(MLModelState.LOADING)
|| mlModelState.equals(MLModelState.PARTIALLY_LOADED)
|| mlModelState.equals(MLModelState.DEPLOYED)
|| mlModelState.equals(MLModelState.DEPLOYING)
|| mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED)) {
actionListener
.onFailure(
new Exception(
"Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete"
)
);
} else if (StringUtils.isNotEmpty(mlModel.getModelGroupId())) {
searchModel(mlModel.getModelGroupId(), ActionListener.wrap(response -> {
boolean isLastModelOfGroup = false;
if (response != null
&& response.getHits() != null
&& response.getHits().getTotalHits() != null
&& response.getHits().getTotalHits().value == 1) {
isLastModelOfGroup = true;
}
deleteModel(modelId, mlModel.getModelGroupId(), isLastModelOfGroup, actionListener);
}, e -> {
log.error("Failed to Search Model index " + modelId, e);
actionListener.onFailure(e);
}));
} else {
deleteModel(modelId, mlModel.getModelGroupId(), false, actionListener);
}
deleteModel(modelId, actionListener);
}
}, e -> {
log.error("Failed to validate Access for Model Id " + modelId, e);
Expand All @@ -167,18 +143,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
}
}

private void searchModel(String modelGroupId, ActionListener<SearchResponse> listener) {
BoolQueryBuilder query = new BoolQueryBuilder();
query.filter(new TermQueryBuilder(MLModel.MODEL_GROUP_ID_FIELD, modelGroupId));

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query);
SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX).source(searchSourceBuilder);
client.search(searchRequest, ActionListener.wrap(response -> { listener.onResponse(response); }, e -> {
log.error("Failed to search Model index", e);
listener.onFailure(e);
}));
}

@VisibleForTesting
void deleteModelChunks(String modelId, DeleteResponse deleteResponse, ActionListener<DeleteResponse> actionListener) {
DeleteByQueryRequest deleteModelsRequest = new DeleteByQueryRequest(ML_MODEL_INDEX);
Expand Down Expand Up @@ -217,19 +181,11 @@ private void returnFailure(BulkByScrollResponse response, String modelId, Action
actionListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR));
}

private void deleteModel(
String modelId,
String modelGroupId,
boolean isLastModelOfGroup,
ActionListener<DeleteResponse> actionListener
) {
private void deleteModel(String modelId, ActionListener<DeleteResponse> actionListener) {
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId);
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
if (isLastModelOfGroup) {
deleteModelGroup(modelGroupId);
}
deleteModelChunks(modelId, deleteResponse, actionListener);
}

Expand All @@ -243,19 +199,4 @@ public void onFailure(Exception e) {
}
});
}

private void deleteModelGroup(String modelGroupId) {
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_GROUP_INDEX, modelGroupId);
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.debug("Completed Delete Model Group for modelGroupId:{}", modelGroupId);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML Model Group with Id:{} " + modelGroupId, e);
}
});
}
}
Loading

0 comments on commit 144357a

Please sign in to comment.