Skip to content

Commit

Permalink
Change SSRF code to make it correct for return error stattus
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Feb 4, 2024
1 parent 9ca0330 commit 726dd25
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.opensearch.client.Client;
import org.opensearch.common.util.TokenBucket;
Expand Down Expand Up @@ -70,7 +72,6 @@ public void invokeRemoteModel(
) {
try {
SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST);
MLHttpClientFactory.validateIp(request.getUri().getHost());
AsyncExecuteRequest executeRequest = AsyncExecuteRequest
.builder()
.request(signRequest(request))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@

import java.io.IOException;
import java.net.URI;
import java.net.UnknownHostException;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringSubstitutor;
Expand All @@ -38,6 +41,7 @@
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import org.opensearch.script.ScriptService;

import com.jayway.jsonpath.JsonPath;
Expand Down Expand Up @@ -255,8 +259,18 @@ public static SdkHttpFullRequest buildSdkRequest(
Map<String, String> parameters,
String payload,
SdkHttpMethod method
) {
) throws UnknownHostException {
String endpoint = connector.getPredictEndpoint(parameters);
Pattern pattern = Pattern.compile("(?:(\\w+)://)?((\\w+\\.)*\\w+)(?::(\\w+))?");
Matcher matcher = pattern.matcher(endpoint);
if (matcher.find()) {
String protocol = matcher.group(1);
String host = matcher.group(2);
String port = matcher.group(4);
MLHttpClientFactory.validate(protocol, host, port);
} else {
throw new IllegalArgumentException("Invalid endpoint: " + endpoint);
}
String charset = parameters.getOrDefault("charset", "UTF-8");
RequestBody requestBody;
if (payload != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,9 @@ public void invokeRemoteModel(
case "POST":
log.debug("original payload to remote model: " + payload);
request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST);
MLHttpClientFactory.validateIp(request.getUri().getHost());
break;
case "GET":
request = ConnectorUtils.buildSdkRequest(connector, parameters, null, GET);
MLHttpClientFactory.validateIp(request.getUri().getHost());
break;
default:
throw new IllegalArgumentException("unsupported http method");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Arrays;
import java.util.Locale;
import java.util.Optional;

import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang3.math.NumberUtils;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;

Expand All @@ -31,7 +34,31 @@ public static SdkAsyncHttpClient getAsyncHttpClient() {
}
}

public static void validateIp(String hostName) throws UnknownHostException {
public static void validate(String protocol, String host, String port) throws UnknownHostException {
if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equals(protocol)) {
log.error("Remote inference protocol is not http or https: " + protocol);
throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
}
String portStr = Optional.ofNullable(port).orElseGet(() -> {
if (protocol == null || "http".equals(protocol.toLowerCase(Locale.getDefault()))) {
return "80";
} else {
return "443";
}
});
if (!NumberUtils.isDigits(portStr)) {
log.error("Remote inference port is not a valid number: " + portStr);
throw new IllegalArgumentException("Port is not a valid number: " + portStr);
}
int portNum = Integer.parseInt(portStr);
if (portNum < 0 || portNum > 65536) {
log.error("Remote inference port out of range: " + port);
throw new IllegalArgumentException("Port out of range: " + port);
}
validateIp(host);
}

private static void validateIp(String hostName) throws UnknownHostException {
InetAddress[] addresses = InetAddress.getAllByName(hostName);
if (hasPrivateIpAddress(addresses)) {
log.error("Remote inference host name has private ip address: " + hostName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,43 @@ public void executePredict_RemoteInferenceInput_invalidIp() {
assertEquals("Fail to execute predict in aws connector", exceptionCaptor.getValue().getMessage());
}

@Test
public void executePredict_RemoteInferenceInput_unformattedIp() {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://0177.1/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Map<String, String> credential = ImmutableMap
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol("http")
.parameters(parameters)
.credential(credential)
.actions(Arrays.asList(predictAction))
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(executor.getClient()).thenReturn(client);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);

MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener);
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
assert exceptionCaptor.getValue() instanceof MLException;
assertEquals("Fail to execute predict in aws connector", exceptionCaptor.getValue().getMessage());
}

@Test
public void executePredict_RemoteInferenceInput_illegalIpAddress() {
ConnectorAction predictAction = ConnectorAction
Expand Down
Loading

0 comments on commit 726dd25

Please sign in to comment.