diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 79a7be4363..8e9a3e503f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -264,9 +264,9 @@ public static SdkHttpFullRequest buildSdkRequest( } else { requestBody = RequestBody.empty(); } - if (requestBody.optionalContentLength().isEmpty()) { - log.error("Content length is empty. Aborting request to remote model"); - throw new IllegalArgumentException("Content length is empty. Aborting request to remote model"); + if (SdkHttpMethod.POST == method && "0".equals(requestBody.optionalContentLength().get().toString())) { + log.error("Content length is 0. Aborting request to remote model"); + throw new IllegalArgumentException("Content length is 0. Aborting request to remote model"); } SdkHttpFullRequest.Builder builder = SdkHttpFullRequest .builder() diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java index 991d1c2d02..4f679a214b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java @@ -35,7 +35,7 @@ public static void validateIp(String hostName) throws UnknownHostException { InetAddress[] addresses = InetAddress.getAllByName(hostName); if (hasPrivateIpAddress(addresses)) { log.error("Remote inference host name has private ip address: " + hostName); - throw new IllegalArgumentException(hostName); + throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index ae8feda445..fdbe2b2dd9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -5,7 +5,7 @@ package org.opensearch.ml.engine.algorithms.remote; -import static org.mockito.ArgumentMatchers.any; +import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.when; @@ -21,6 +21,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; @@ -41,7 +42,6 @@ import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.script.ScriptService; import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableList; @@ -62,9 +62,6 @@ public class AwsConnectorExecutorTest { ThreadContext threadContext; - @Mock - ScriptService scriptService; - @Mock ActionListener actionListener; @@ -127,7 +124,10 @@ public void executePredict_RemoteInferenceInput_invalidIp() { MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener); - Mockito.verify(actionListener, times(1)).onFailure(any(MLException.class)); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); + assert exceptionCaptor.getValue() instanceof MLException; + assertEquals("Fail to execute predict in aws connector", exceptionCaptor.getValue().getMessage()); } @Test @@ -136,7 +136,7 @@ public void executePredict_RemoteInferenceInput_illegalIpAddress() { .builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") - .url("http://test.com/mock") + .url("http://localhost/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); Map credential = ImmutableMap @@ -161,6 +161,10 @@ public void executePredict_RemoteInferenceInput_illegalIpAddress() { MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); + assert exceptionCaptor.getValue() instanceof IllegalArgumentException; + assertEquals("Remote inference host name has private ip address: localhost", exceptionCaptor.getValue().getMessage()); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 4d582fb902..f42be9b00a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -8,6 +8,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.CountDownLatch; import org.junit.Before; @@ -20,6 +21,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.HttpConnector; @@ -30,6 +32,13 @@ import com.google.common.collect.ImmutableMap; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.spy; +import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; +import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; +import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; +import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; + public class HttpJsonConnectorExecutorTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -62,6 +71,39 @@ public void invokeRemoteModel_WrongHttpMethod() { executor.invokeRemoteModel(null, null, null, null, null, actionListener); ArgumentCaptor captor = ArgumentCaptor.forClass(IllegalArgumentException.class); Mockito.verify(actionListener, Mockito.times(1)).onFailure(captor.capture()); + assertEquals("unsupported http method", captor.getValue().getMessage()); + } + + @Test + public void invokeRemoteModel_invalidIpAddress() { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://127.0.0.1/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); + HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + executor + .invokeRemoteModel( + createMLInput(), + new HashMap<>(), + "{\"input\": \"hello world\"}", + new HashMap<>(), + new WrappedCountDownLatch(0, new CountDownLatch(1)), + actionListener + ); + ArgumentCaptor captor = ArgumentCaptor.forClass(IllegalArgumentException.class); + Mockito.verify(actionListener, Mockito.times(1)).onFailure(captor.capture()); + assert captor.getValue() instanceof IllegalArgumentException; + assertEquals("Remote inference host name has private ip address: 127.0.0.1", captor.getValue().getMessage()); } @Test @@ -90,6 +132,10 @@ public void invokeRemoteModel_Empty_payload() { new WrappedCountDownLatch(0, new CountDownLatch(1)), actionListener ); + ArgumentCaptor captor = ArgumentCaptor.forClass(IllegalArgumentException.class); + Mockito.verify(actionListener, Mockito.times(1)).onFailure(captor.capture()); + assert captor.getValue() instanceof IllegalArgumentException; + assertEquals("Content length is 0. Aborting request to remote model", captor.getValue().getMessage()); } @Test