Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

push down url validation to avoid security check false alarm #2685

Merged
merged 14 commits into from
Oct 9, 2023
1 change: 1 addition & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ the backend workers convert "Bytearray to utf-8 string" when the Content-Type of
* `limit_max_image_pixels` : Default value is true (Use default [PIL.Image.MAX_IMAGE_PIXELS](https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.MAX_IMAGE_PIXELS)). If this is set to "false", set PIL.Image.MAX_IMAGE_PIXELS = None in backend default vision handler for large image payload.
* `allowed_urls` : Comma separated regex of allowed source URL(s) from where models can be registered. Default: `file://.*|http(s)?://.*` (all URLs and local file system)
e.g. : To allow base URLs `https://s3.amazonaws.com/` and `https://torchserve.pytorch.org/` use the following regex string `allowed_urls=https://s3.amazonaws.com/.*,https://torchserve.pytorch.org/.*`
* For security reason, `use_env_allowed_urls=true` is required in config.properties to read `allowed_urls` from environment variable.
* `workflow_store` : Path of workflow store directory. Defaults to model store directory.
* `disable_system_metrics` : Disable collection of system metrics when set to "true". Default value is "false".

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
import java.net.HttpURLConnection;
import java.net.URL;
import java.net.URLEncoder;
import java.nio.file.FileAlreadyExistsException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.pytorch.serve.archive.utils.ArchiveUtils;
import org.pytorch.serve.archive.utils.InvalidArchiveURLException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -19,20 +23,42 @@ public final class HttpUtils {
private HttpUtils() {}

/** Copy model from S3 url to local model store */
public static void copyURLToFile(URL endpointUrl, File modelLocation, boolean s3SseKmsEnabled)
throws IOException {
// for a simple GET, we have no body so supply the precomputed 'empty' hash
Map<String, String> headers;
if (s3SseKmsEnabled) {
String awsAccessKey = System.getenv("AWS_ACCESS_KEY_ID");
String awsSecretKey = System.getenv("AWS_SECRET_ACCESS_KEY");
String regionName = System.getenv("AWS_DEFAULT_REGION");
if (!regionName.isEmpty() && !awsAccessKey.isEmpty() && !awsSecretKey.isEmpty()) {
public static boolean copyURLToFile(
List<String> allowedUrls,
String url,
File modelLocation,
boolean s3SseKmsEnabled,
String archiveName)
throws FileAlreadyExistsException, IOException, InvalidArchiveURLException {
if (ArchiveUtils.validateURL(allowedUrls, url)) {
if (modelLocation.exists()) {
throw new FileAlreadyExistsException(archiveName);
}

if (archiveName.contains("/") || archiveName.contains("\\")) {
throw new IOException(
"Security alert slash or backslash appear in archiveName:" + archiveName);
}

Copy link
Member

@msaroufim msaroufim Oct 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's some more validaiton we can do

  • Ensure the url is using the https protocol and not http or ftp
  • Potentially check that the IP is not a litteral IP and instead a domain?
  • Limit redirections
  • ..%2F is used for path traveral attacks
  • Optionally some additional checks on user-agent and content-type

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these validation is covered by function validateURL

// for a simple GET, we have no body so supply the precomputed 'empty' hash
Map<String, String> headers;
if (s3SseKmsEnabled) {
String awsAccessKey = System.getenv("AWS_ACCESS_KEY_ID");
String awsSecretKey = System.getenv("AWS_SECRET_ACCESS_KEY");
String regionName = System.getenv("AWS_DEFAULT_REGION");
if (regionName.isEmpty() || awsAccessKey.isEmpty() || awsSecretKey.isEmpty()) {
throw new IOException(
"Miss environment variables "
+ "AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY or AWS_DEFAULT_REGION");
}

HttpURLConnection connection = (HttpURLConnection) new URL(url).openConnection();

Check failure

Code scanning / CodeQL

Server-side request forgery

Potential server-side request forgery due to a [user-provided value](1).
headers = new HashMap<>();
headers.put("x-amz-content-sha256", AWS4SignerBase.EMPTY_BODY_SHA256);

AWS4SignerForAuthorizationHeader signer =
new AWS4SignerForAuthorizationHeader(endpointUrl, "GET", "s3", regionName);
new AWS4SignerForAuthorizationHeader(
connection.getURL(), "GET", "s3", regionName);
String authorization =
signer.computeSignature(
headers,
Expand All @@ -44,7 +70,7 @@ public static void copyURLToFile(URL endpointUrl, File modelLocation, boolean s3
// place the computed signature into a formatted 'Authorization' header
// and call S3
headers.put("Authorization", authorization);
HttpURLConnection connection = createHttpConnection(endpointUrl, "GET", headers);
setHttpConnection(connection, "GET", headers);
try {
FileUtils.copyInputStreamToFile(connection.getInputStream(), modelLocation);
} finally {
Expand All @@ -53,28 +79,23 @@ public static void copyURLToFile(URL endpointUrl, File modelLocation, boolean s3
}
}
} else {
throw new IOException(
"Miss environment variables "
+ "AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY or AWS_DEFAULT_REGION");
URL endpointUrl = new URL(url);
FileUtils.copyURLToFile(endpointUrl, modelLocation);

Check failure

Code scanning / CodeQL

Server-side request forgery

Potential server-side request forgery due to a [user-provided value](1).

Check failure

Code scanning / CodeQL

Uncontrolled data used in path expression

This path depends on a [user-provided value](1).
}
} else {
FileUtils.copyURLToFile(endpointUrl, modelLocation);
}
return false;
}

public static HttpURLConnection createHttpConnection(
URL endpointUrl, String httpMethod, Map<String, String> headers) throws IOException {

HttpURLConnection connection = (HttpURLConnection) endpointUrl.openConnection();
public static void setHttpConnection(
HttpURLConnection connection, String httpMethod, Map<String, String> headers)
throws IOException {
connection.setRequestMethod(httpMethod);

if (headers != null) {
for (String headerKey : headers.keySet()) {
connection.setRequestProperty(headerKey, headers.get(headerKey));
}
}

return connection;
}

public static String urlEncode(String url, boolean keepPathSlash)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,14 @@ public static boolean downloadArchive(
boolean s3SseKmsEnabled)
throws FileAlreadyExistsException, FileNotFoundException, DownloadArchiveException,
InvalidArchiveURLException {
if (validateURL(allowedUrls, url)) {
if (location.exists()) {
throw new FileAlreadyExistsException(archiveName);
}
try {
HttpUtils.copyURLToFile(new URL(url), location, s3SseKmsEnabled);
} catch (IOException e) {
FileUtils.deleteQuietly(location);
throw new DownloadArchiveException("Failed to download archive from: " + url, e);
}
try {
return HttpUtils.copyURLToFile(
allowedUrls, url, location, s3SseKmsEnabled, archiveName);
} catch (InvalidArchiveURLException | FileAlreadyExistsException e) {
throw e;
} catch (IOException e) {
FileUtils.deleteQuietly(location);
throw new DownloadArchiveException("Failed to download archive from: " + url, e);
}

return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ public final class ConfigManager {

// Configuration default values
private static final String DEFAULT_TS_ALLOWED_URLS = "file://.*|http(s)?://.*";
private static final String USE_ENV_ALLOWED_URLS = "use_env_allowed_urls";

// Variables which are local
public static final String MODEL_METRICS_LOGGER = "MODEL_METRICS";
Expand Down Expand Up @@ -277,6 +278,14 @@ private void setSystemVars() {
Class<ConfigManager> configClass = ConfigManager.class;
Field[] fields = configClass.getDeclaredFields();
for (Field f : fields) {
// For security, disable TS_ALLOWED_URLS in env.
if ("TS_ALLOWED_URLS".equals(f.getName())
&& !"true"
.equals(
prop.getProperty(USE_ENV_ALLOWED_URLS, "false")
.toLowerCase())) {
continue;
}
if (f.getName().startsWith("TS_")) {
String val = System.getenv(f.getName());
if (val != null) {
Expand Down