Skip to content

Commit

Permalink
add exception to pre-trained model
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
  • Loading branch information
rbhavna committed Oct 4, 2023
1 parent 74fbe08 commit 7d39173
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
public static final String ACCESS_MODE_FIELD = "access_mode";
public static final String BACKEND_ROLES_FIELD = "backend_roles";
public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles";
public static final String IS_THIS_VERSION_CREATING_MODEL_GROUP = "is_this_version_creating_model_group";
public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group";
private FunctionName functionName;
private String modelName;
private String modelGroupId;
Expand All @@ -72,7 +72,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
private List<String> backendRoles;
private Boolean addAllBackendRoles;
private AccessMode accessMode;
private Boolean isThisVersionCreatingModelGroup;
private Boolean doesVersionCreateModelGroup;

@Builder(toBuilder = true)
public MLRegisterModelInput(FunctionName functionName,
Expand All @@ -91,7 +91,7 @@ public MLRegisterModelInput(FunctionName functionName,
List<String> backendRoles,
Boolean addAllBackendRoles,
AccessMode accessMode,
Boolean isThisVersionCreatingModelGroup
Boolean doesVersionCreateModelGroup
) {
if (functionName == null) {
this.functionName = FunctionName.TEXT_EMBEDDING;
Expand Down Expand Up @@ -124,7 +124,7 @@ public MLRegisterModelInput(FunctionName functionName,
this.backendRoles = backendRoles;
this.addAllBackendRoles = addAllBackendRoles;
this.accessMode = accessMode;
this.isThisVersionCreatingModelGroup = isThisVersionCreatingModelGroup;
this.doesVersionCreateModelGroup = doesVersionCreateModelGroup;
}


Expand Down Expand Up @@ -155,7 +155,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.accessMode = in.readEnum(AccessMode.class);
}
this.isThisVersionCreatingModelGroup = in.readOptionalBoolean();
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
}

@Override
Expand Down Expand Up @@ -201,7 +201,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isThisVersionCreatingModelGroup);
out.writeOptionalBoolean(doesVersionCreateModelGroup);
}

@Override
Expand Down Expand Up @@ -249,8 +249,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (accessMode != null) {
builder.field(ACCESS_MODE_FIELD, accessMode);
}
if (isThisVersionCreatingModelGroup != null) {
builder.field(IS_THIS_VERSION_CREATING_MODEL_GROUP, isThisVersionCreatingModelGroup);
if (doesVersionCreateModelGroup != null) {
builder.field(DOES_VERSION_CREATE_MODEL_GROUP, doesVersionCreateModelGroup);
}
builder.endObject();
return builder;
Expand All @@ -270,7 +270,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
List<String> backendRoles = new ArrayList<>();
Boolean addAllBackendRoles = null;
AccessMode accessMode = null;
Boolean isThisVersionCreatingModelGroup = null;
Boolean doesVersionCreateModelGroup = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -322,15 +322,15 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
case ACCESS_MODE_FIELD:
accessMode = AccessMode.from(parser.text());
break;
case IS_THIS_VERSION_CREATING_MODEL_GROUP:
isThisVersionCreatingModelGroup = parser.booleanValue();
case DOES_VERSION_CREATE_MODEL_GROUP:
doesVersionCreateModelGroup = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, isThisVersionCreatingModelGroup);
return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup);
}

public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException {
Expand All @@ -349,7 +349,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
List<String> backendRoles = new ArrayList<>();
AccessMode accessMode = null;
Boolean addAllBackendRoles = null;
Boolean isThisVersionCreatingModelGroup = null;
Boolean doesVersionCreateModelGroup = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -408,14 +408,14 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
case ACCESS_MODE_FIELD:
accessMode = AccessMode.from(parser.text());
break;
case IS_THIS_VERSION_CREATING_MODEL_GROUP:
isThisVersionCreatingModelGroup = parser.booleanValue();
case DOES_VERSION_CREATE_MODEL_GROUP:
doesVersionCreateModelGroup = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, isThisVersionCreatingModelGroup);
return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUpda
validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput);
}
updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, listener, user);
} catch (Exception e) {
log.error("Failed to parse ml model group" + modelGroup.getId(), e);
listener.onFailure(e);
}
} else {
listener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,24 +166,47 @@ private void checkUserAccess(
User user = RestActionUtils.getUserContext(client);
modelAccessControlHelper
.validateModelGroupAccess(user, registerModelInput.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
if (isModelNameAlreadyExisting) {
if (access) {
doRegister(registerModelInput, listener);
return;
}
// if the user does not have access, we need to check three more conditions before throwing exception.
// if we are checking the access based on the name provided in the input, we let user know the name is already used by a
// model group they do not have access to.
if (isModelNameAlreadyExisting) {
// This case handles when user is using the same pre-trained model already registered by another user on the cluster.
// The only way here is for the user to first create model group and use its ID in the request
if (registerModelInput.getModelGroupId() != null
&& (registerModelInput.getUrl() == null
&& registerModelInput.getFunctionName() != FunctionName.REMOTE
&& registerModelInput.getConnectorId() == null)) {
listener
.onFailure(
new IllegalArgumentException(
"The name \""
"Without a model group ID, the system will use the model name {"
+ registerModelInput.getModelName()
+ "\" you provided is already being used by another model group \""
+ "} to create a new model group. However, this name is taken by another group {"
+ registerModelInput.getModelGroupId()
+ "\" to which you do not have access. Please provide a different name."
+ "} you can't access. To register this pre-trained model, create a new model group and use its ID in your request."
)
);
} else
} else {
listener
.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model."));
} else {
doRegister(registerModelInput, listener);
.onFailure(
new IllegalArgumentException(
"The name {"
+ registerModelInput.getModelName()
+ "} you provided is unavailable because it is used by another model group {"
+ registerModelInput.getModelGroupId()
+ "} to which you do not have access. Please provide a different name."
)
);
}
return;
}
// if user does not have access to the model group ID provided in the input, we let user know they do not have access to the
// specified model group
listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model."));
}, listener::onFailure));
}

Expand Down Expand Up @@ -233,7 +256,7 @@ private void createModelGroup(MLRegisterModelInput registerModelInput, ActionLis
MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(registerModelInput);
mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> {
registerModelInput.setModelGroupId(modelGroupId);
registerModelInput.setIsThisVersionCreatingModelGroup(true);
registerModelInput.setDoesVersionCreateModelGroup(true);
registerModel(registerModelInput, listener);
}, e -> {
logException("Failed to create Model Group", e, log);
Expand Down
16 changes: 7 additions & 9 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -723,19 +723,15 @@ private void registerModel(
handleException(functionName, taskId, e);
deleteFileQuietly(file);
// remove model doc as failed to upload model
deleteModel(
modelId,
registerModelInput.getModelGroupId(),
registerModelInput.getIsThisVersionCreatingModelGroup()
);
deleteModel(modelId, registerModelInput.getModelGroupId(), registerModelInput.getDoesVersionCreateModelGroup());
semaphore.release();
deleteFileQuietly(mlEngine.getRegisterModelPath(modelId));
}));
}
}, e -> {
log.error("Failed to index chunk file", e);
deleteFileQuietly(mlEngine.getRegisterModelPath(modelId));
deleteModel(modelId, registerModelInput.getModelGroupId(), registerModelInput.getIsThisVersionCreatingModelGroup());
deleteModel(modelId, registerModelInput.getModelGroupId(), registerModelInput.getDoesVersionCreateModelGroup());
handleException(functionName, taskId, e);
})
);
Expand Down Expand Up @@ -802,7 +798,7 @@ private void updateModelRegisterStateAsDone(
}, e -> {
log.error("Failed to update model", e);
handleException(functionName, taskId, e);
deleteModel(modelId, registerModelInput.getModelGroupId(), registerModelInput.getIsThisVersionCreatingModelGroup());
deleteModel(modelId, registerModelInput.getModelGroupId(), registerModelInput.getDoesVersionCreateModelGroup());
}));
}

Expand All @@ -815,7 +811,7 @@ private void deployModelAfterRegistering(MLRegisterModelInput registerModelInput
client.execute(MLDeployModelAction.INSTANCE, request, listener);
}

private void deleteModel(String modelId, String modelGroupID, Boolean isThisVersionCreatingModelGroup) {
private void deleteModel(String modelId, String modelGroupID, Boolean doesVersionCreateModelGroup) {
DeleteRequest deleteRequest = new DeleteRequest();
deleteRequest.index(ML_MODEL_INDEX).id(modelId).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.delete(deleteRequest);
Expand All @@ -824,7 +820,9 @@ private void deleteModel(String modelId, String modelGroupID, Boolean isThisVers
.setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN)
.setAbortOnVersionConflict(false);
client.execute(DeleteByQueryAction.INSTANCE, deleteChunksRequest);
if (isThisVersionCreatingModelGroup) {
// This checks if model group is created when registering the version and deletes the model group since the version registration had
// failed
if (doesVersionCreateModelGroup) {
deleteModelGroup(modelGroupID);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ public void test_NoAccessWhenModelNameAlreadyExists() throws IOException {
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"The name \"Test Model\" you provided is already being used by another model group \"model_group_ID\" to which you do not have access. Please provide a different name.",
"The name {Test Model} you provided is unavailable because it is used by another model group {model_group_ID} to which you do not have access. Please provide a different name.",
argumentCaptor.getValue().getMessage()
);
}
Expand Down

0 comments on commit 7d39173

Please sign in to comment.