Skip to content

Commit

Permalink
Add more UTs
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Feb 4, 2024
1 parent d1d7761 commit 6a99720
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,9 @@ public static SdkHttpFullRequest buildSdkRequest(
} else {
requestBody = RequestBody.empty();
}
if (requestBody.optionalContentLength().isEmpty()) {
log.error("Content length is empty. Aborting request to remote model");
throw new IllegalArgumentException("Content length is empty. Aborting request to remote model");
if (SdkHttpMethod.POST == method && "0".equals(requestBody.optionalContentLength().get().toString())) {
log.error("Content length is 0. Aborting request to remote model");
throw new IllegalArgumentException("Content length is 0. Aborting request to remote model");
}
SdkHttpFullRequest.Builder builder = SdkHttpFullRequest
.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public static void validateIp(String hostName) throws UnknownHostException {
InetAddress[] addresses = InetAddress.getAllByName(hostName);
if (hasPrivateIpAddress(addresses)) {
log.error("Remote inference host name has private ip address: " + hostName);
throw new IllegalArgumentException(hostName);
throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.mockito.ArgumentMatchers.any;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.when;
Expand All @@ -21,6 +21,7 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
Expand All @@ -41,7 +42,6 @@
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.script.ScriptService;
import org.opensearch.threadpool.ThreadPool;

import com.google.common.collect.ImmutableList;
Expand All @@ -62,9 +62,6 @@ public class AwsConnectorExecutorTest {

ThreadContext threadContext;

@Mock
ScriptService scriptService;

@Mock
ActionListener<MLTaskResponse> actionListener;

Expand Down Expand Up @@ -127,7 +124,10 @@ public void executePredict_RemoteInferenceInput_invalidIp() {

MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener);
Mockito.verify(actionListener, times(1)).onFailure(any(MLException.class));
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
assert exceptionCaptor.getValue() instanceof MLException;
assertEquals("Fail to execute predict in aws connector", exceptionCaptor.getValue().getMessage());
}

@Test
Expand All @@ -136,7 +136,7 @@ public void executePredict_RemoteInferenceInput_illegalIpAddress() {
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.url("http://localhost/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Map<String, String> credential = ImmutableMap
Expand All @@ -161,6 +161,10 @@ public void executePredict_RemoteInferenceInput_illegalIpAddress() {

MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener);
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
assert exceptionCaptor.getValue() instanceof IllegalArgumentException;
assertEquals("Remote inference host name has private ip address: localhost", exceptionCaptor.getValue().getMessage());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;

import org.junit.Before;
Expand All @@ -20,6 +21,7 @@
import org.mockito.MockitoAnnotations;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.HttpConnector;
Expand All @@ -30,6 +32,13 @@

import com.google.common.collect.ImmutableMap;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.spy;
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD;

public class HttpJsonConnectorExecutorTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();
Expand Down Expand Up @@ -62,6 +71,39 @@ public void invokeRemoteModel_WrongHttpMethod() {
executor.invokeRemoteModel(null, null, null, null, null, actionListener);
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(IllegalArgumentException.class);
Mockito.verify(actionListener, Mockito.times(1)).onFailure(captor.capture());
assertEquals("unsupported http method", captor.getValue().getMessage());
}

@Test
public void invokeRemoteModel_invalidIpAddress() {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://127.0.0.1/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector
.builder()
.name("test connector")
.version("1")
.protocol("http")
.actions(Arrays.asList(predictAction))
.build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
executor
.invokeRemoteModel(
createMLInput(),
new HashMap<>(),
"{\"input\": \"hello world\"}",
new HashMap<>(),
new WrappedCountDownLatch(0, new CountDownLatch(1)),
actionListener
);
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(IllegalArgumentException.class);
Mockito.verify(actionListener, Mockito.times(1)).onFailure(captor.capture());
assert captor.getValue() instanceof IllegalArgumentException;
assertEquals("Remote inference host name has private ip address: 127.0.0.1", captor.getValue().getMessage());
}

@Test
Expand Down Expand Up @@ -90,6 +132,10 @@ public void invokeRemoteModel_Empty_payload() {
new WrappedCountDownLatch(0, new CountDownLatch(1)),
actionListener
);
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(IllegalArgumentException.class);
Mockito.verify(actionListener, Mockito.times(1)).onFailure(captor.capture());
assert captor.getValue() instanceof IllegalArgumentException;
assertEquals("Content length is 0. Aborting request to remote model", captor.getValue().getMessage());
}

@Test
Expand Down

0 comments on commit 6a99720

Please sign in to comment.