From 2daa3d4760ed73c50ff6d7e96117f69a4a03efef Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Fri, 6 Oct 2023 13:48:11 -0700 Subject: [PATCH] register new versions to a model group based on the name provided (#1452) Signed-off-by: Bhavana Ramaram --- .../TransportRegisterModelAction.java | 70 +++++++++- .../TransportRegisterModelMetaAction.java | 72 +++++++--- .../TransportRegisterModelActionTests.java | 124 +++++++++++++++++- ...TransportRegisterModelMetaActionTests.java | 91 ++++++++++++- 4 files changed, 334 insertions(+), 23 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 4556073d7d..a1d1d54258 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -17,6 +17,7 @@ import java.util.List; import java.util.regex.Pattern; +import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.util.Strings; import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.ActionRequest; @@ -136,17 +137,76 @@ public TransportRegisterModelAction( @Override protected void doExecute(Task task, ActionRequest request, ActionListener listener) { - User user = RestActionUtils.getUserContext(client); MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request); MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput(); + if (StringUtils.isEmpty(registerModelInput.getModelGroupId())) { + mlModelGroupManager.validateUniqueModelGroupName(registerModelInput.getModelName(), ActionListener.wrap(modelGroups -> { + if (modelGroups != null + && modelGroups.getHits().getTotalHits() != null + && modelGroups.getHits().getTotalHits().value != 0) { + String modelGroupIdOfTheNameProvided = modelGroups.getHits().getAt(0).getId(); + registerModelInput.setModelGroupId(modelGroupIdOfTheNameProvided); + checkUserAccess(registerModelInput, listener, true); + } else { + doRegister(registerModelInput, listener); + } + }, e -> { + log.error("Failed to search model group index", e); + listener.onFailure(e); + })); + } else { + checkUserAccess(registerModelInput, listener, false); + } + } + + private void checkUserAccess( + MLRegisterModelInput registerModelInput, + ActionListener listener, + Boolean isModelNameAlreadyExisting + ) { + User user = RestActionUtils.getUserContext(client); modelAccessControlHelper .validateModelGroupAccess(user, registerModelInput.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { - log.error("You don't have permissions to perform this operation on this model."); - listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model.")); - } else { + if (access) { doRegister(registerModelInput, listener); + return; + } + // if the user does not have access, we need to check three more conditions before throwing exception. + // if we are checking the access based on the name provided in the input, we let user know the name is already used by a + // model group they do not have access to. + if (isModelNameAlreadyExisting) { + // This case handles when user is using the same pre-trained model already registered by another user on the cluster. + // The only way here is for the user to first create model group and use its ID in the request + if (registerModelInput.getUrl() == null + && registerModelInput.getFunctionName() != FunctionName.REMOTE + && registerModelInput.getConnectorId() == null) { + listener + .onFailure( + new IllegalArgumentException( + "Without a model group ID, the system will use the model name {" + + registerModelInput.getModelName() + + "} to create a new model group. However, this name is taken by another group with id {" + + registerModelInput.getModelGroupId() + + "} you can't access. To register this pre-trained model, create a new model group and use its ID in your request." + ) + ); + } else { + listener + .onFailure( + new IllegalArgumentException( + "The name {" + + registerModelInput.getModelName() + + "} you provided is unavailable because it is used by another model group with id {" + + registerModelInput.getModelGroupId() + + "} to which you do not have access. Please provide a different name." + ) + ); + } + return; } + // if user does not have access to the model group ID provided in the input, we let user know they do not have access to the + // specified model group + listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model.")); }, listener::onFailure)); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java index 01d8abb96c..da350ef039 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java @@ -63,25 +63,52 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + if (modelGroups != null + && modelGroups.getHits().getTotalHits() != null + && modelGroups.getHits().getTotalHits().value != 0) { + String modelGroupIdOfTheNameProvided = modelGroups.getHits().getAt(0).getId(); + mlUploadInput.setModelGroupId(modelGroupIdOfTheNameProvided); + checkUserAccess(mlUploadInput, listener, true); + } else { + createModelGroup(mlUploadInput, listener); + } + }, e -> { + log.error("Failed to search model group index", e); + listener.onFailure(e); + })); + } else { + checkUserAccess(mlUploadInput, listener, false); + } + } + private void checkUserAccess( + MLRegisterModelMetaInput mlUploadInput, + ActionListener listener, + Boolean isModelNameAlreadyExisting + ) { + + User user = RestActionUtils.getUserContext(client); modelAccessControlHelper.validateModelGroupAccess(user, mlUploadInput.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { + if (access) { + createModelGroup(mlUploadInput, listener); + return; + } + if (isModelNameAlreadyExisting) { + listener + .onFailure( + new IllegalArgumentException( + "The name {" + + mlUploadInput.getName() + + "} you provided is unavailable because it is used by another model group with id {" + + mlUploadInput.getModelGroupId() + + "} to which you do not have access. Please provide a different name." + ) + ); + } else { log.error("You don't have permissions to perform this operation on this model."); listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model.")); - } else { - if (StringUtils.isEmpty(mlUploadInput.getModelGroupId())) { - MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(mlUploadInput); - mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> { - mlUploadInput.setModelGroupId(modelGroupId); - registerModelMeta(mlUploadInput, listener); - }, e -> { - logException("Failed to create Model Group", e, log); - listener.onFailure(e); - })); - } else { - registerModelMeta(mlUploadInput, listener); - } } }, e -> { logException("Failed to validate model access", e, log); @@ -89,6 +116,21 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + if (StringUtils.isEmpty(mlUploadInput.getModelGroupId())) { + MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(mlUploadInput); + mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> { + mlUploadInput.setModelGroupId(modelGroupId); + registerModelMeta(mlUploadInput, listener); + }, e -> { + logException("Failed to create Model Group", e, log); + listener.onFailure(e); + })); + } else { + registerModelMeta(mlUploadInput, listener); + } + } + private MLRegisterModelGroupInput createRegisterModelGroupRequest(MLRegisterModelMetaInput mlUploadInput) { return MLRegisterModelGroupInput .builder() diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index 1a8384f45f..ac1f09dea1 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -18,9 +18,11 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; import static org.opensearch.ml.utils.TestHelper.clusterSetting; +import java.io.IOException; import java.util.List; import java.util.Map; +import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; @@ -30,6 +32,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; @@ -61,6 +64,9 @@ import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -144,7 +150,7 @@ public class TransportRegisterModelActionTests extends OpenSearchTestCase { private ConnectorAccessControlHelper connectorAccessControlHelper; @Before - public void setup() { + public void setup() throws IOException { MockitoAnnotations.openMocks(this); settings = Settings .builder() @@ -199,6 +205,13 @@ public void setup() { return null; }).when(mlTaskDispatcher).dispatch(any(), any()); + SearchResponse searchResponse = createModelGroupSearchResponse(0); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + when(clusterService.localNode()).thenReturn(node2); when(node2.getId()).thenReturn("node2Id"); @@ -461,6 +474,97 @@ public void test_execute_registerRemoteModel_withInternalConnector_predictEndpoi ); } + public void test_ModelNameAlreadyExists() throws IOException { + when(node1.getId()).thenReturn("NodeId1"); + when(node2.getId()).thenReturn("NodeId2"); + MLForwardResponse forwardResponse = Mockito.mock(MLForwardResponse.class); + doAnswer(invocation -> { + ActionListenerResponseHandler handler = invocation.getArgument(3); + handler.handleResponse(forwardResponse); + return null; + }).when(transportService).sendRequest(any(), any(), any(), any()); + SearchResponse searchResponse = createModelGroupSearchResponse(1); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + transportRegisterModelAction.doExecute(task, prepareRequest("http://test_url", null), actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + } + + public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException { + SearchResponse searchResponse = createModelGroupSearchResponse(1); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + MLRegisterModelInput registerModelInput = MLRegisterModelInput + .builder() + .modelName("huggingface/sentence-transformers/all-MiniLM-L12-v2") + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .version("1") + .build(); + + transportRegisterModelAction.doExecute(task, new MLRegisterModelRequest(registerModelInput), actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Without a model group ID, the system will use the model name {huggingface/sentence-transformers/all-MiniLM-L12-v2} to create a new model group. However, this name is taken by another group with id {model_group_ID} you can't access. To register this pre-trained model, create a new model group and use its ID in your request.", + argumentCaptor.getValue().getMessage() + + ); + } + + public void test_FailureWhenSearchingModelGroupName() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Runtime exception")); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Runtime exception", argumentCaptor.getValue().getMessage()); + } + + public void test_NoAccessWhenModelNameAlreadyExists() throws IOException { + + SearchResponse searchResponse = createModelGroupSearchResponse(1); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "The name {Test Model} you provided is unavailable because it is used by another model group with id {model_group_ID} to which you do not have access. Please provide a different name.", + argumentCaptor.getValue().getMessage() + ); + } + private MLRegisterModelRequest prepareRequest(String url, String modelGroupID) { MLRegisterModelInput registerModelInput = MLRegisterModelInput .builder() @@ -485,4 +589,22 @@ private MLRegisterModelRequest prepareRequest(String url, String modelGroupID) { return new MLRegisterModelRequest(registerModelInput); } + private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException { + + SearchResponse searchResponse = mock(SearchResponse.class); + String modelContent = "{\n" + + " \"created_time\": 1684981986069,\n" + + " \"access\": \"public\",\n" + + " \"latest_version\": 0,\n" + + " \"last_updated_time\": 1684981986069,\n" + + " \"_id\": \"model_group_ID\",\n" + + " \"name\": \"Test Model\",\n" + + " \"description\": \"This is an example description\"\n" + + " }"; + SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent)); + SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN); + when(searchResponse.getHits()).thenReturn(hits); + return searchResponse; + } + } diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java index 26b2f3f091..f7eb64c8eb 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java @@ -7,13 +7,18 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.io.IOException; + +import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; @@ -30,6 +35,9 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -67,7 +75,7 @@ public class TransportRegisterModelMetaActionTests extends OpenSearchTestCase { private ModelAccessControlHelper modelAccessControlHelper; @Before - public void setup() { + public void setup() throws IOException { MockitoAnnotations.openMocks(this); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -93,6 +101,13 @@ public void setup() { return null; }).when(mlModelManager).registerModelMeta(any(), any()); + SearchResponse searchResponse = createModelGroupSearchResponse(0); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); } @@ -169,10 +184,64 @@ public void test_ValidationFailedException() { assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage()); } + public void testDoExecute_ModelNameAlreadyExists() throws IOException { + + SearchResponse searchResponse = createModelGroupSearchResponse(1); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + MLRegisterModelMetaRequest actionRequest = prepareRequest(null); + action.doExecute(task, actionRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelMetaResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + } + + public void testDoExecute_NoAccessWhenModelNameAlreadyExists() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + SearchResponse searchResponse = createModelGroupSearchResponse(1); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + MLRegisterModelMetaRequest actionRequest = prepareRequest(null); + action.doExecute(task, actionRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "The name {Test Model} you provided is unavailable because it is used by another model group with id {model_group_ID} to which you do not have access. Please provide a different name.", + argumentCaptor.getValue().getMessage() + ); + } + + public void test_FailureWhenSearchingModelGroupName() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Runtime exception")); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + MLRegisterModelMetaRequest actionRequest = prepareRequest(null); + action.doExecute(task, actionRequest, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Runtime exception", argumentCaptor.getValue().getMessage()); + } + private MLRegisterModelMetaRequest prepareRequest(String modelGroupID) { MLRegisterModelMetaInput input = MLRegisterModelMetaInput .builder() - .name("Model Name") + .name("Test Model") .modelGroupId(modelGroupID) .description("Custom Model Test") .modelFormat(MLModelFormat.TORCH_SCRIPT) @@ -195,4 +264,22 @@ private MLRegisterModelMetaRequest prepareRequest(String modelGroupID) { return new MLRegisterModelMetaRequest(input); } + private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException { + + SearchResponse searchResponse = mock(SearchResponse.class); + String modelContent = "{\n" + + " \"created_time\": 1684981986069,\n" + + " \"access\": \"public\",\n" + + " \"latest_version\": 0,\n" + + " \"last_updated_time\": 1684981986069,\n" + + " \"_id\": \"model_group_ID\",\n" + + " \"name\": \"Test Model\",\n" + + " \"description\": \"This is an example description\"\n" + + " }"; + SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent)); + SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN); + when(searchResponse.getHits()).thenReturn(hits); + return searchResponse; + } + }