Skip to content

Commit

Permalink
adding more test coverage
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
  • Loading branch information
rbhavna committed Oct 4, 2023
1 parent eb90be2 commit 3160692
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,9 @@

package org.opensearch.ml.action.register;

import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;

import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import lombok.extern.log4j.Log4j2;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.ActionRequest;
Expand Down Expand Up @@ -63,10 +54,17 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;

import lombok.extern.log4j.Log4j2;
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;

@Log4j2
public class TransportRegisterModelAction extends HandledTransportAction<ActionRequest, MLRegisterModelResponse> {
Expand Down Expand Up @@ -176,10 +174,9 @@ private void checkUserAccess(
if (isModelNameAlreadyExisting) {
// This case handles when user is using the same pre-trained model already registered by another user on the cluster.
// The only way here is for the user to first create model group and use its ID in the request
if (registerModelInput.getModelGroupId() != null
&& (registerModelInput.getUrl() == null
if (registerModelInput.getUrl() == null
&& registerModelInput.getFunctionName() != FunctionName.REMOTE
&& registerModelInput.getConnectorId() == null)) {
&& registerModelInput.getConnectorId() == null) {
listener
.onFailure(
new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,7 @@

package org.opensearch.ml.action.register;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import com.google.common.collect.ImmutableList;
import org.apache.lucene.search.TotalHits;
import org.junit.Before;
import org.junit.Rule;
Expand Down Expand Up @@ -72,7 +56,22 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;

public class TransportRegisterModelActionTests extends OpenSearchTestCase {
@Rule
Expand Down Expand Up @@ -495,6 +494,36 @@ public void test_ModelNameAlreadyExists() throws IOException {
verify(actionListener).onResponse(argumentCaptor.capture());
}

public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException {
SearchResponse searchResponse = createModelGroupSearchResponse(1);
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(searchResponse);
return null;
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(3);
listener.onResponse(false);
return null;
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any());

MLRegisterModelInput registerModelInput = MLRegisterModelInput
.builder()
.modelName("huggingface/sentence-transformers/all-MiniLM-L12-v2")
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.version("1")
.build();

transportRegisterModelAction.doExecute(task, new MLRegisterModelRequest(registerModelInput), actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"Without a model group ID, the system will use the model name {huggingface/sentence-transformers/all-MiniLM-L12-v2} to create a new model group. However, this name is taken by another group {model_group_ID} you can't access. To register this pre-trained model, create a new model group and use its ID in your request.",
argumentCaptor.getValue().getMessage()
);
}

public void test_FailureWhenSearchingModelGroupName() throws IOException {
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
Expand Down

0 comments on commit 3160692

Please sign in to comment.