diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index b6e932bccd..8e9968a38c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -13,6 +13,8 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.opensearch.client.Client; import org.opensearch.common.util.TokenBucket; @@ -70,7 +72,6 @@ public void invokeRemoteModel( ) { try { SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST); - MLHttpClientFactory.validateIp(request.getUri().getHost()); AsyncExecuteRequest executeRequest = AsyncExecuteRequest .builder() .request(signRequest(request)) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 8e9a3e503f..ff7844e213 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -16,6 +16,7 @@ import java.io.IOException; import java.net.URI; +import java.net.UnknownHostException; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.HashMap; @@ -23,6 +24,8 @@ import java.util.Map; import java.util.Optional; import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.apache.commons.lang3.StringUtils; import org.apache.commons.text.StringSubstitutor; @@ -38,6 +41,7 @@ import org.opensearch.ml.common.input.MLInput; import 
org.opensearch.ml.common.output.model.ModelTensor;
 import org.opensearch.ml.common.output.model.ModelTensors;
+import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
 import org.opensearch.script.ScriptService;
 
 import com.jayway.jsonpath.JsonPath;
@@ -255,8 +259,22 @@ public static SdkHttpFullRequest buildSdkRequest(
         Map parameters,
         String payload,
         SdkHttpMethod method
-    ) {
+    ) throws UnknownHostException {
         String endpoint = connector.getPredictEndpoint(parameters);
+        // The match is anchored at the start and must end at a path/query/fragment boundary (or end
+        // of input), so the host handed to validation is the whole authority host: hyphenated
+        // labels such as "us-west-2" are kept intact and a "user@" prefix is skipped instead of
+        // being mistaken for the host. Anything that does not fit is rejected as invalid.
+        Pattern pattern = Pattern.compile("^(?:(\\w+)://)?(?:[^/?#@]*@)?((?:[\\w-]+\\.)*[\\w-]+\\.?)(?::(\\w+))?(?=[/?#]|$)");
+        Matcher matcher = pattern.matcher(endpoint);
+        if (matcher.find()) {
+            String protocol = matcher.group(1);
+            String host = matcher.group(2);
+            String port = matcher.group(3);
+            MLHttpClientFactory.validate(protocol, host, port);
+        } else {
+            throw new IllegalArgumentException("Invalid endpoint: " + endpoint);
+        }
         String charset = parameters.getOrDefault("charset", "UTF-8");
         RequestBody requestBody;
         if (payload != null) {
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java
index fa6d95a66d..b56a365213 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java
@@ -78,11 +78,9 @@ public void invokeRemoteModel(
                 case "POST":
                     log.debug("original payload to remote model: " + payload);
                     request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST);
-                    MLHttpClientFactory.validateIp(request.getUri().getHost());
                     break;
                 case "GET":
                     request = ConnectorUtils.buildSdkRequest(connector, parameters, null, GET);
-                    MLHttpClientFactory.validateIp(request.getUri().getHost());
                     break;
                 default:
                     throw new IllegalArgumentException("unsupported http method");
diff --git
a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java
index 4f679a214b..f8e2f1385d 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java
@@ -12,8 +12,11 @@
 import java.security.PrivilegedActionException;
 import java.security.PrivilegedExceptionAction;
 import java.util.Arrays;
+import java.util.Locale;
+import java.util.Optional;
 
 import lombok.extern.log4j.Log4j2;
+import org.apache.commons.lang3.math.NumberUtils;
 import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
 import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
 
@@ -31,7 +34,45 @@ public static SdkAsyncHttpClient getAsyncHttpClient() {
         }
     }
 
-    public static void validateIp(String hostName) throws UnknownHostException {
+    /**
+     * Validates the scheme and port of a remote inference endpoint, then runs the host check.
+     *
+     * @param protocol URL scheme, or {@code null} if the endpoint has none; only http and https are
+     *                 accepted, compared case-insensitively
+     * @param host     host name or IP literal, passed on to {@code validateIp}
+     * @param port     port as text, or {@code null} to default to 443 for https and 80 otherwise
+     * @throws IllegalArgumentException if the scheme is not http/https, or the port is not a number
+     *                                  between 0 and 65535
+     * @throws UnknownHostException     if resolving {@code host} fails
+     */
+    public static void validate(String protocol, String host, String port) throws UnknownHostException {
+        // Locale.ROOT makes the lower-casing independent of the JVM's default locale.
+        final String scheme = protocol == null ? null : protocol.toLowerCase(Locale.ROOT);
+        if (scheme != null && !"http".equals(scheme) && !"https".equals(scheme)) {
+            log.error("Remote inference protocol is not http or https: " + protocol);
+            throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
+        }
+        String portStr = Optional.ofNullable(port).orElseGet(() -> "https".equals(scheme) ? "443" : "80");
+        if (!NumberUtils.isDigits(portStr)) {
+            log.error("Remote inference port is not a valid number: " + portStr);
+            throw new IllegalArgumentException("Port is not a valid number: " + portStr);
+        }
+        int portNum;
+        try {
+            portNum = Integer.parseInt(portStr);
+        } catch (NumberFormatException e) {
+            // All digits but too large for an int: report it as out of range below.
+            portNum = -1;
+        }
+        // 65535 is the highest valid port number.
+        if (portNum < 0 || portNum > 65535) {
+            log.error("Remote inference port out of range: " + port);
+            throw new IllegalArgumentException("Port out of range: " + port);
+        }
+        validateIp(host);
+    }
+
+    private static void
validateIp(String hostName) throws UnknownHostException { InetAddress[] addresses = InetAddress.getAllByName(hostName); if (hasPrivateIpAddress(addresses)) { log.error("Remote inference host name has private ip address: " + hostName); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index fdbe2b2dd9..889e20328c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -130,6 +130,43 @@ public void executePredict_RemoteInferenceInput_invalidIp() { assertEquals("Fail to execute predict in aws connector", exceptionCaptor.getValue().getMessage()); } + @Test + public void executePredict_RemoteInferenceInput_unformattedIp() { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://0177.1/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Map credential = ImmutableMap + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); + Connector connector = AwsConnector + .awsConnectorBuilder() + .name("test connector") + .version("1") + .protocol("http") + .parameters(parameters) + .credential(credential) + .actions(Arrays.asList(predictAction)) + .build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + 
when(threadPool.getThreadContext()).thenReturn(threadContext); + + MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); + executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); + assert exceptionCaptor.getValue() instanceof MLException; + assertEquals("Fail to execute predict in aws connector", exceptionCaptor.getValue().getMessage()); + } + @Test public void executePredict_RemoteInferenceInput_illegalIpAddress() { ConnectorAction predictAction = ConnectorAction diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java index 28748b159f..96acc018a6 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java @@ -6,17 +6,27 @@ package org.opensearch.ml.engine.httpclient; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static software.amazon.awssdk.http.SdkHttpMethod.POST; import java.net.URI; +import java.net.UnknownHostException; import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.Map; +import com.google.common.collect.ImmutableMap; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils; import software.amazon.awssdk.core.sync.RequestBody; 
import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.http.async.SdkAsyncHttpClient; public class MLHttpClientFactoryTests { @@ -31,120 +41,229 @@ public void test_getSdkAsyncHttpClient_success() { } @Test - public void test_validateIp_validIp_noException() { - RequestBody requestBody = RequestBody.fromString("hello world", Charset.defaultCharset()); - SdkHttpFullRequest request = SdkHttpFullRequest + public void test_validateIp_validIp_noException() throws UnknownHostException { + ConnectorAction predictAction = ConnectorAction .builder() - .method(POST) - .uri(URI.create("https://api.openai.com")) - .contentStreamProvider(requestBody.contentStreamProvider()) + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://api.openai.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") .build(); + Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); + SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); + assertNotNull(request); } @Test - public void test_validateIp_rarePrivateIp_throwException() { - RequestBody requestBody = RequestBody.fromString("hello world", Charset.defaultCharset()); + public void test_validateIp_rarePrivateIp_throwException() throws UnknownHostException { try { - SdkHttpFullRequest request = SdkHttpFullRequest + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://0254.020.00.01/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Connector connector = HttpConnector .builder() - .method(POST) - .uri(URI.create("http://0177.1/v1/completions")) - .contentStreamProvider(requestBody.contentStreamProvider()) + .name("test connector") + .version("1") + 
.protocol("http") + .actions(Arrays.asList(predictAction)) .build(); - } catch (NullPointerException e) { + ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); + } catch (IllegalArgumentException e) { assertNotNull(e); } try { - SdkHttpFullRequest request = SdkHttpFullRequest + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://172.1048577/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Connector connector = HttpConnector .builder() - .method(POST) - .uri(URI.create("http://172.1048577/v1/completions")) - .contentStreamProvider(requestBody.contentStreamProvider()) + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) .build(); - } catch (NullPointerException e) { + ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); + } catch (IllegalArgumentException e) { assertNotNull(e); } try { - SdkHttpFullRequest request = SdkHttpFullRequest + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://2886729729/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Connector connector = HttpConnector .builder() - .method(POST) - .uri(URI.create("http://2886729729/v1/completions")) - .contentStreamProvider(requestBody.contentStreamProvider()) + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) .build(); - } catch (NullPointerException e) { + ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); + } catch (IllegalArgumentException e) { assertNotNull(e); } try { - SdkHttpFullRequest request = SdkHttpFullRequest + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + 
.url("http://192.11010049/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Connector connector = HttpConnector .builder() - .method(POST) - .uri(URI.create("http://192.11010049/v1/completions")) - .contentStreamProvider(requestBody.contentStreamProvider()) + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) .build(); - } catch (NullPointerException e) { + ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); + } catch (IllegalArgumentException e) { assertNotNull(e); } try { - SdkHttpFullRequest request = SdkHttpFullRequest + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://3232300545/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Connector connector = HttpConnector .builder() - .method(POST) - .uri(URI.create("http://3232300545/v1/completions")) - .contentStreamProvider(requestBody.contentStreamProvider()) + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) .build(); - } catch (NullPointerException e) { + ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); + } catch (IllegalArgumentException e) { assertNotNull(e); } try { - SdkHttpFullRequest request = SdkHttpFullRequest + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://0:0:0:0:0:ffff:127.0.0.1/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Connector connector = HttpConnector .builder() - .method(POST) - .uri(URI.create("http://0:0:0:0:0:ffff:127.0.0.1/v1/completions")) - .contentStreamProvider(requestBody.contentStreamProvider()) + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) .build(); - } catch (NullPointerException e) 
{ + ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); + } catch (IllegalArgumentException e) { assertNotNull(e); } } @Test - public void test_validateSchemaAndPort_success() { - RequestBody requestBody = RequestBody.fromString("hello world", Charset.defaultCharset()); - SdkHttpFullRequest request = SdkHttpFullRequest + public void test_validateSchemaAndPort_success() throws UnknownHostException { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://api.openai.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Connector connector = HttpConnector .builder() - .method(POST) - .uri(URI.create("https://api.openai.com:8080/v1/completions")) - .contentStreamProvider(requestBody.contentStreamProvider()) + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) .build(); + SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); assertNotNull(request); } @Test - public void test_validateSchemaAndPort_notAllowedSchema_throwException() { + public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws UnknownHostException { expectedException.expect(IllegalArgumentException.class); - RequestBody requestBody = RequestBody.fromString("hello world", Charset.defaultCharset()); - SdkHttpFullRequest request = SdkHttpFullRequest + expectedException.expectMessage("Protocol is not http or https: ftp"); + ConnectorAction predictAction = ConnectorAction .builder() - .method(POST) - .uri(URI.create("ftp://api.openai.com:8080/v1/completions")) - .contentStreamProvider(requestBody.contentStreamProvider()) + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("ftp://api.openai.com:8080/v1/completions") + .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - assertNotNull(request); + 
Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); + SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); + assertNull(request); } @Test - public void test_validateSchemaAndPort_portNotInRange_throwException() { - RequestBody requestBody = RequestBody.fromString("hello world", Charset.defaultCharset()); - SdkHttpFullRequest request = SdkHttpFullRequest + public void test_validateSchemaAndPort_portNotInRange_throwException() throws UnknownHostException { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Port out of range: 65537"); + ConnectorAction predictAction = ConnectorAction .builder() - .method(POST) - .uri(URI.create("https://api.openai.com:65537/v1/completions")) - .contentStreamProvider(requestBody.contentStreamProvider()) + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("https://api.openai.com:65537/v1/completions") + .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - assertNotNull(request); + Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); + ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); + } + + @Test + public void test_validateSchemaAndPort_portNotANumber_throwException() throws UnknownHostException { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Port is not a valid number: abc"); + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("https://api.openai.com:abc/v1/completions") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Connector connector = HttpConnector + .builder() + .name("test 
connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); + ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); } }