diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index e79a09c5b2..7c733d5a7e 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -52,7 +52,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { public static final String ACCESS_MODE_FIELD = "access_mode"; public static final String BACKEND_ROLES_FIELD = "backend_roles"; public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles"; - + public static final String IS_THIS_VERSION_CREATING_MODEL_GROUP = "is_this_version_creating_model_group"; private FunctionName functionName; private String modelName; private String modelGroupId; @@ -72,6 +72,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { private List backendRoles; private Boolean addAllBackendRoles; private AccessMode accessMode; + private Boolean isThisVersionCreatingModelGroup; @Builder(toBuilder = true) public MLRegisterModelInput(FunctionName functionName, @@ -89,7 +90,8 @@ public MLRegisterModelInput(FunctionName functionName, String connectorId, List backendRoles, Boolean addAllBackendRoles, - AccessMode accessMode + AccessMode accessMode, + Boolean isThisVersionCreatingModelGroup ) { if (functionName == null) { this.functionName = FunctionName.TEXT_EMBEDDING; @@ -122,6 +124,7 @@ public MLRegisterModelInput(FunctionName functionName, this.backendRoles = backendRoles; this.addAllBackendRoles = addAllBackendRoles; this.accessMode = accessMode; + this.isThisVersionCreatingModelGroup = isThisVersionCreatingModelGroup; } @@ -152,6 +155,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException { if (in.readBoolean()) { this.accessMode = in.readEnum(AccessMode.class); } + this.isThisVersionCreatingModelGroup = in.readOptionalBoolean(); } @Override @@ -197,6 +201,7 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + out.writeOptionalBoolean(isThisVersionCreatingModelGroup); } @Override @@ -244,6 +249,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (accessMode != null) { builder.field(ACCESS_MODE_FIELD, accessMode); } + if (isThisVersionCreatingModelGroup != null) { + builder.field(IS_THIS_VERSION_CREATING_MODEL_GROUP, isThisVersionCreatingModelGroup); + } builder.endObject(); return builder; } @@ -262,6 +270,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName List backendRoles = new ArrayList<>(); Boolean addAllBackendRoles = null; AccessMode accessMode = null; + Boolean isThisVersionCreatingModelGroup = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -313,12 +322,15 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName case ACCESS_MODE_FIELD: accessMode = AccessMode.from(parser.text()); break; + case IS_THIS_VERSION_CREATING_MODEL_GROUP: + isThisVersionCreatingModelGroup = parser.booleanValue(); + break; default: parser.skipChildren(); break; } } - return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode); + return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, isThisVersionCreatingModelGroup); } public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException { @@ -337,6 +349,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo List backendRoles = new ArrayList<>(); AccessMode accessMode = null; Boolean addAllBackendRoles = null; + Boolean isThisVersionCreatingModelGroup = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -395,11 +408,14 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo case ACCESS_MODE_FIELD: accessMode = AccessMode.from(parser.text()); break; + case IS_THIS_VERSION_CREATING_MODEL_GROUP: + isThisVersionCreatingModelGroup = parser.booleanValue(); + break; default: parser.skipChildren(); break; } } - return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode); + return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, isThisVersionCreatingModelGroup); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index 5d53c4dea9..8f2966a815 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -9,7 +9,6 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.utils.MLExceptionUtils.logException; -import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.Map; @@ -90,39 +89,39 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - if (modelGroup.isExists()) { - try ( - XContentParser parser = MLNodeUtils - .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, modelGroup.getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLModelGroup mlModelGroup = MLModelGroup.parse(parser); + GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> { + if (modelGroup.isExists()) { + try ( + XContentParser parser = MLNodeUtils + .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, modelGroup.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLModelGroup mlModelGroup = MLModelGroup.parse(parser); + if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) { validateRequestForAccessControl(updateModelGroupInput, user, mlModelGroup); - updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, listener, user); + } else { + validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput); } - } else { - listener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND)); + updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, listener, user); } - }, e -> { - if (e instanceof IndexNotFoundException) { - listener.onFailure(new MLResourceNotFoundException("Fail to find model group")); - } else { - logException("Failed to get model group", e, log); - listener.onFailure(e); - } - })); - } catch (Exception e) { - logException("Failed to Update model group", e, log); - listener.onFailure(e); - } - } else { - validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput); - updateModelGroup(modelGroupId, new HashMap<>(), updateModelGroupInput, listener, user); + } else { + listener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND)); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + listener.onFailure(new MLResourceNotFoundException("Fail to find model group")); + } else { + logException("Failed to get model group", e, log); + listener.onFailure(e); + } + })); + } catch (Exception e) { + logException("Failed to Update model group", e, log); + listener.onFailure(e); } + } private void updateModelGroup( diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index 2948b255d1..c71b5434c0 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -6,14 +6,12 @@ package org.opensearch.ml.action.models; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; -import org.apache.commons.lang3.StringUtils; import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionRequest; @@ -21,8 +19,6 @@ import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; @@ -34,8 +30,6 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.index.query.TermsQueryBuilder; import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.DeleteByQueryAction; @@ -48,7 +42,6 @@ import org.opensearch.ml.common.transport.model.MLModelGetRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.utils.RestActionUtils; -import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -109,6 +102,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { @@ -117,37 +111,19 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - boolean isLastModelOfGroup = false; - if (response != null - && response.getHits() != null - && response.getHits().getTotalHits() != null - && response.getHits().getTotalHits().value == 1) { - isLastModelOfGroup = true; - } - deleteModel(modelId, mlModel.getModelGroupId(), isLastModelOfGroup, actionListener); - }, e -> { - log.error("Failed to Search Model index " + modelId, e); - actionListener.onFailure(e); - })); - } else { - deleteModel(modelId, mlModel.getModelGroupId(), false, actionListener); - } + deleteModel(modelId, actionListener); } }, e -> { log.error("Failed to validate Access for Model Id " + modelId, e); @@ -167,18 +143,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener) { - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(MLModel.MODEL_GROUP_ID_FIELD, modelGroupId)); - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); - SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX).source(searchSourceBuilder); - client.search(searchRequest, ActionListener.wrap(response -> { listener.onResponse(response); }, e -> { - log.error("Failed to search Model index", e); - listener.onFailure(e); - })); - } - @VisibleForTesting void deleteModelChunks(String modelId, DeleteResponse deleteResponse, ActionListener actionListener) { DeleteByQueryRequest deleteModelsRequest = new DeleteByQueryRequest(ML_MODEL_INDEX); @@ -217,19 +181,11 @@ private void returnFailure(BulkByScrollResponse response, String modelId, Action actionListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR)); } - private void deleteModel( - String modelId, - String modelGroupId, - boolean isLastModelOfGroup, - ActionListener actionListener - ) { + private void deleteModel(String modelId, ActionListener actionListener) { DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId); client.delete(deleteRequest, new ActionListener() { @Override public void onResponse(DeleteResponse deleteResponse) { - if (isLastModelOfGroup) { - deleteModelGroup(modelGroupId); - } deleteModelChunks(modelId, deleteResponse, actionListener); } @@ -243,19 +199,4 @@ public void onFailure(Exception e) { } }); } - - private void deleteModelGroup(String modelGroupId) { - DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_GROUP_INDEX, modelGroupId); - client.delete(deleteRequest, new ActionListener() { - @Override - public void onResponse(DeleteResponse deleteResponse) { - log.debug("Completed Delete Model Group for modelGroupId:{}", modelGroupId); - } - - @Override - public void onFailure(Exception e) { - log.error("Failed to delete ML Model Group with Id:{} " + modelGroupId, e); - } - }); - } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 23006f5464..13a935c1a3 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -136,14 +136,51 @@ public TransportRegisterModelAction( @Override protected void doExecute(Task task, ActionRequest request, ActionListener listener) { - User user = RestActionUtils.getUserContext(client); MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request); MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput(); + if (registerModelInput.getModelGroupId() == null) { + mlModelGroupManager.validateUniqueModelGroupName(registerModelInput.getModelName(), ActionListener.wrap(modelGroups -> { + if (modelGroups != null + && modelGroups.getHits().getTotalHits() != null + && modelGroups.getHits().getTotalHits().value != 0) { + String modelGroupIdOfTheNameProvided = modelGroups.getHits().getAt(0).getId(); + registerModelInput.setModelGroupId(modelGroupIdOfTheNameProvided); + checkUserAccess(registerModelInput, listener, true); + } else { + checkUserAccess(registerModelInput, listener, false); + } + }, e -> { + log.error("Failed to search model group index", e); + listener.onFailure(e); + })); + } else { + checkUserAccess(registerModelInput, listener, false); + } + } + + private void checkUserAccess( + MLRegisterModelInput registerModelInput, + ActionListener listener, + Boolean isModelNameAlreadyExisting + ) { + User user = RestActionUtils.getUserContext(client); modelAccessControlHelper .validateModelGroupAccess(user, registerModelInput.getModelGroupId(), client, ActionListener.wrap(access -> { if (!access) { - log.error("You don't have permissions to perform this operation on this model."); - listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model.")); + if (isModelNameAlreadyExisting) { + listener + .onFailure( + new IllegalArgumentException( + "The name \"" + + registerModelInput.getModelName() + + "\" you provided is already being used by another model group \"" + + registerModelInput.getModelGroupId() + + "\" to which you do not have access. Please provide a different name." + ) + ); + } else + listener + .onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model.")); } else { doRegister(registerModelInput, listener); } @@ -196,6 +233,7 @@ private void createModelGroup(MLRegisterModelInput registerModelInput, ActionLis MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(registerModelInput); mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> { registerModelInput.setModelGroupId(modelGroupId); + registerModelInput.setIsThisVersionCreatingModelGroup(true); registerModel(registerModelInput, listener); }, e -> { logException("Failed to create Model Group", e, log); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index efc78edf20..aa7a16c873 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -145,15 +145,13 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us AccessMode modelAccessMode = input.getModelAccessMode(); Boolean isAddAllBackendRoles = input.getIsAddAllBackendRoles(); if (modelAccessMode == null) { - if (modelAccessMode == null) { - if (!CollectionUtils.isEmpty(input.getBackendRoles()) && Boolean.TRUE.equals(isAddAllBackendRoles)) { - throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time."); - } else if (Boolean.TRUE.equals(isAddAllBackendRoles) || !CollectionUtils.isEmpty(input.getBackendRoles())) { - input.setModelAccessMode(AccessMode.RESTRICTED); - modelAccessMode = AccessMode.RESTRICTED; - } else { - input.setModelAccessMode(AccessMode.PRIVATE); - } + if (!CollectionUtils.isEmpty(input.getBackendRoles()) && Boolean.TRUE.equals(isAddAllBackendRoles)) { + throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time."); + } else if (Boolean.TRUE.equals(isAddAllBackendRoles) || !CollectionUtils.isEmpty(input.getBackendRoles())) { + input.setModelAccessMode(AccessMode.RESTRICTED); + modelAccessMode = AccessMode.RESTRICTED; + } else { + input.setModelAccessMode(AccessMode.PRIVATE); } } if ((AccessMode.PUBLIC == modelAccessMode || AccessMode.PRIVATE == modelAccessMode) @@ -183,20 +181,29 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us } public void validateUniqueModelGroupName(String name, ActionListener listener) throws IllegalArgumentException { - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name)); - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); - SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder); - - client.search(searchRequest, ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> { - if (e instanceof IndexNotFoundException) { - listener.onResponse(null); - } else { - log.error("Failed to search model group index", e); - listener.onFailure(e); - } - })); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name)); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); + SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder); + + client + .search( + searchRequest, + ActionListener.runBefore(ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> { + if (e instanceof IndexNotFoundException) { + listener.onResponse(null); + } else { + log.error("Failed to search model group index", e); + listener.onFailure(e); + } + }), () -> context.restore()) + ); + } catch (Exception e) { + log.error("Failed to search model group index", e); + listener.onFailure(e); + } } private void validateSecurityDisabledOrModelAccessControlDisabled(MLRegisterModelGroupInput input) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 090b82443f..9fc7a00a4d 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -325,10 +325,6 @@ public void registerMLRemoteModel( String modelGroupId = mlRegisterModelInput.getModelGroupId(); GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); - if (Strings.isBlank(modelGroupId)) { - indexRemoteModel(mlRegisterModelInput, mlTask, "1", listener); - } - client.get(getModelGroupRequest, ActionListener.wrap(getModelGroupResponse -> { if (getModelGroupResponse.isExists()) { Map modelGroupSourceMap = getModelGroupResponse.getSourceAsMap(); @@ -396,9 +392,6 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa String modelGroupId = registerModelInput.getModelGroupId(); GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); - if (Strings.isBlank(modelGroupId)) { - uploadModel(registerModelInput, mlTask, "1"); - } try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> { if (modelGroup.isExists()) { @@ -730,7 +723,11 @@ private void registerModel( handleException(functionName, taskId, e); deleteFileQuietly(file); // remove model doc as failed to upload model - deleteModel(modelId); + deleteModel( + modelId, + registerModelInput.getModelGroupId(), + registerModelInput.getIsThisVersionCreatingModelGroup() + ); semaphore.release(); deleteFileQuietly(mlEngine.getRegisterModelPath(modelId)); })); @@ -738,7 +735,7 @@ private void registerModel( }, e -> { log.error("Failed to index chunk file", e); deleteFileQuietly(mlEngine.getRegisterModelPath(modelId)); - deleteModel(modelId); + deleteModel(modelId, registerModelInput.getModelGroupId(), registerModelInput.getIsThisVersionCreatingModelGroup()); handleException(functionName, taskId, e); }) ); @@ -805,7 +802,7 @@ private void updateModelRegisterStateAsDone( }, e -> { log.error("Failed to update model", e); handleException(functionName, taskId, e); - deleteModel(modelId); + deleteModel(modelId, registerModelInput.getModelGroupId(), registerModelInput.getIsThisVersionCreatingModelGroup()); })); } @@ -818,7 +815,7 @@ private void deployModelAfterRegistering(MLRegisterModelInput registerModelInput client.execute(MLDeployModelAction.INSTANCE, request, listener); } - private void deleteModel(String modelId) { + private void deleteModel(String modelId, String modelGroupID, Boolean isThisVersionCreatingModelGroup) { DeleteRequest deleteRequest = new DeleteRequest(); deleteRequest.index(ML_MODEL_INDEX).id(modelId).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); client.delete(deleteRequest); @@ -827,6 +824,15 @@ private void deleteModel(String modelId) { .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) .setAbortOnVersionConflict(false); client.execute(DeleteByQueryAction.INSTANCE, deleteChunksRequest); + if (isThisVersionCreatingModelGroup) { + deleteModelGroup(modelGroupID); + } + } + + private void deleteModelGroup(String modelGroupID) { + DeleteRequest deleteModelGroupRequest = new DeleteRequest(); + deleteModelGroupRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupID).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.delete(deleteModelGroupRequest); } private void handleException(FunctionName functionName, String taskId, Exception e) { diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java index 1a67977291..a99488d0e9 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java @@ -38,6 +38,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; @@ -367,6 +368,20 @@ public void test_FailedToGetModelGroupException() { assertEquals("Failed to get model group", argumentCaptor.getValue().getMessage()); } + public void test_ModelGroupIndexNotFoundException() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new IndexNotFoundException("Fail to find model group")); + return null; + }).when(client).get(any(), any()); + + MLUpdateModelGroupRequest actionRequest = prepareRequest(null, AccessMode.RESTRICTED, null); + transportUpdateModelGroupAction.doExecute(task, actionRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Fail to find model group", argumentCaptor.getValue().getMessage()); + } + public void test_FailedToUpdatetModelGroupException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -414,15 +429,16 @@ public void test_ModelGroupNameNotUnique() throws IOException { } public void test_ExceptionSecurityDisabledCluster() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule - .expectMessage( - "You cannot specify model access control parameters because the Security plugin or model access control is disabled on your cluster." - ); when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(false); MLUpdateModelGroupRequest actionRequest = prepareRequest(null, null, true); transportUpdateModelGroupAction.doExecute(task, actionRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "You cannot specify model access control parameters because the Security plugin or model access control is disabled on your cluster.", + argumentCaptor.getValue().getMessage() + ); } private MLUpdateModelGroupRequest prepareRequest(List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) { diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java index c0f431e64f..57051643a6 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java @@ -6,9 +6,7 @@ package org.opensearch.ml.action.models; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -22,7 +20,6 @@ import java.util.ArrayList; import java.util.Arrays; -import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; @@ -35,7 +32,6 @@ import org.opensearch.action.bulk.BulkItemResponse; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetResponse; -import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -50,15 +46,11 @@ import org.opensearch.index.get.GetResult; import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.ScrollableHitSource; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; -import org.opensearch.ml.utils.TestHelper; -import org.opensearch.search.SearchHit; -import org.opensearch.search.SearchHits; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -184,115 +176,6 @@ public void testDeleteModel_Success_AlgorithmNotNull() throws IOException { verify(actionListener).onResponse(deleteResponse); } - public void test_Success_ModelGroupIDNotNull_LastModelOfGroup() throws IOException { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(deleteResponse); - return null; - }).when(client).delete(any(), any()); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); - listener.onResponse(response); - return null; - }).when(client).execute(any(), any(), any()); - - SearchResponse searchResponse = createModelGroupSearchResponse(1); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), isA(ActionListener.class)); - - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID"); - - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(getResponse); - return null; - }).when(client).get(any(), any()); - - deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); - verify(actionListener).onResponse(deleteResponse); - } - - public void test_Success_ModelGroupIDNotNull_NotLastModelOfGroup() throws IOException { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(deleteResponse); - return null; - }).when(client).delete(any(), any()); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); - listener.onResponse(response); - return null; - }).when(client).execute(any(), any(), any()); - - SearchResponse searchResponse = createModelGroupSearchResponse(2); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), isA(ActionListener.class)); - - MLModel mlModel = MLModel - .builder() - .modelId("test_id") - .modelGroupId("modelGroupID") - .modelState(MLModelState.REGISTERED) - .algorithm(FunctionName.TEXT_EMBEDDING) - .build(); - XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse getResponse = new GetResponse(getResult); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(getResponse); - return null; - }).when(client).get(any(), any()); - - deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); - verify(actionListener).onResponse(deleteResponse); - } - - public void test_Failure_FailedToSearchLastModel() throws IOException { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(deleteResponse); - return null; - }).when(client).delete(any(), any()); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); - listener.onResponse(response); - return null; - }).when(client).execute(any(), any(), any()); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new Exception("Failed to search Model index")); - return null; - }).when(client).search(any(), isA(ActionListener.class)); - - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID"); - - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(getResponse); - return null; - }).when(client).get(any(), any()); - - deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Failed to search Model index", argumentCaptor.getValue().getMessage()); - } - public void test_UserHasNoAccessException() throws IOException { GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID"); doAnswer(invocation -> { @@ -517,20 +400,4 @@ public GetResponse prepareMLModel(MLModelState mlModelState, String modelGroupID GetResponse getResponse = new GetResponse(getResult); return getResponse; } - - private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException { - SearchResponse searchResponse = mock(SearchResponse.class); - String modelContent = "{\n" - + " \"created_time\": 1684981986069,\n" - + " \"access\": \"public\",\n" - + " \"latest_version\": 0,\n" - + " \"last_updated_time\": 1684981986069,\n" - + " \"name\": \"model_group_IT\",\n" - + " \"description\": \"This is an example description\"\n" - + " }"; - SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent)); - SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN); - when(searchResponse.getHits()).thenReturn(hits); - return searchResponse; - } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index 1a8384f45f..32c280a75b 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -18,9 +18,11 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; import static org.opensearch.ml.utils.TestHelper.clusterSetting; +import java.io.IOException; import java.util.List; import java.util.Map; +import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; @@ -30,6 +32,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; @@ -61,6 +64,9 @@ import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -144,7 +150,7 @@ public class TransportRegisterModelActionTests extends OpenSearchTestCase { private ConnectorAccessControlHelper connectorAccessControlHelper; @Before - public void setup() { + public void setup() throws IOException { MockitoAnnotations.openMocks(this); settings = Settings .builder() @@ -199,6 +205,13 @@ public void setup() { return null; }).when(mlTaskDispatcher).dispatch(any(), any()); + SearchResponse searchResponse = createModelGroupSearchResponse(0); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + when(clusterService.localNode()).thenReturn(node2); when(node2.getId()).thenReturn("node2Id"); @@ -461,6 +474,66 @@ public void test_execute_registerRemoteModel_withInternalConnector_predictEndpoi ); } + public void test_ModelNameAlreadyExists() throws IOException { + when(node1.getId()).thenReturn("NodeId1"); + when(node2.getId()).thenReturn("NodeId2"); + MLForwardResponse forwardResponse = Mockito.mock(MLForwardResponse.class); + doAnswer(invocation -> { + ActionListenerResponseHandler handler = invocation.getArgument(3); + handler.handleResponse(forwardResponse); + return null; + }).when(transportService).sendRequest(any(), any(), any(), any()); + SearchResponse searchResponse = createModelGroupSearchResponse(1); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + transportRegisterModelAction.doExecute(task, prepareRequest("http://test_url", null), actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + } + + public void test_FailureWhenSearchingModelGroupName() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Runtime exception")); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Runtime exception", argumentCaptor.getValue().getMessage()); + } + + public void test_NoAccessWhenModelNameAlreadyExists() throws IOException { + + SearchResponse searchResponse = createModelGroupSearchResponse(1); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "The name \"Test Model\" you provided is already being used by another model group \"model_group_ID\" to which you do not have access. Please provide a different name.", + argumentCaptor.getValue().getMessage() + ); + } + private MLRegisterModelRequest prepareRequest(String url, String modelGroupID) { MLRegisterModelInput registerModelInput = MLRegisterModelInput .builder() @@ -485,4 +558,22 @@ private MLRegisterModelRequest prepareRequest(String url, String modelGroupID) { return new MLRegisterModelRequest(registerModelInput); } + private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException { + + SearchResponse searchResponse = mock(SearchResponse.class); + String modelContent = "{\n" + + " \"created_time\": 1684981986069,\n" + + " \"access\": \"public\",\n" + + " \"latest_version\": 0,\n" + + " \"last_updated_time\": 1684981986069,\n" + + " \"_id\": \"model_group_ID\",\n" + + " \"name\": \"Test Model\",\n" + + " \"description\": \"This is an example description\"\n" + + " }"; + SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent)); + SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN); + when(searchResponse.getHits()).thenReturn(hits); + return searchResponse; + } + }