diff --git a/docs/configuration.md b/docs/configuration.md index 78339af20e..df1399d27d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -294,6 +294,7 @@ the backend workers convert "Bytearray to utf-8 string" when the Content-Type of * `limit_max_image_pixels` : Default value is true (Use default [PIL.Image.MAX_IMAGE_PIXELS](https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.MAX_IMAGE_PIXELS)). If this is set to "false", set PIL.Image.MAX_IMAGE_PIXELS = None in backend default vision handler for large image payload. * `allowed_urls` : Comma separated regex of allowed source URL(s) from where models can be registered. Default: `file://.*|http(s)?://.*` (all URLs and local file system) e.g. : To allow base URLs `https://s3.amazonaws.com/` and `https://torchserve.pytorch.org/` use the following regex string `allowed_urls=https://s3.amazonaws.com/.*,https://torchserve.pytorch.org/.*` + * For security reason, `use_env_allowed_urls=true` is required in config.properties to read `allowed_urls` from environment variable. * `workflow_store` : Path of workflow store directory. Defaults to model store directory. * `disable_system_metrics` : Disable collection of system metrics when set to "true". Default value is "false". diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java index 8c03f78875..5c8c910db5 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java @@ -6,9 +6,13 @@ import java.net.HttpURLConnection; import java.net.URL; import java.net.URLEncoder; +import java.nio.file.FileAlreadyExistsException; import java.util.HashMap; +import java.util.List; import java.util.Map; import org.apache.commons.io.FileUtils; +import org.pytorch.serve.archive.utils.ArchiveUtils; +import org.pytorch.serve.archive.utils.InvalidArchiveURLException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -19,20 +23,42 @@ public final class HttpUtils { private HttpUtils() {} /** Copy model from S3 url to local model store */ - public static void copyURLToFile(URL endpointUrl, File modelLocation, boolean s3SseKmsEnabled) - throws IOException { - // for a simple GET, we have no body so supply the precomputed 'empty' hash - Map headers; - if (s3SseKmsEnabled) { - String awsAccessKey = System.getenv("AWS_ACCESS_KEY_ID"); - String awsSecretKey = System.getenv("AWS_SECRET_ACCESS_KEY"); - String regionName = System.getenv("AWS_DEFAULT_REGION"); - if (!regionName.isEmpty() && !awsAccessKey.isEmpty() && !awsSecretKey.isEmpty()) { + public static boolean copyURLToFile( + List allowedUrls, + String url, + File modelLocation, + boolean s3SseKmsEnabled, + String archiveName) + throws FileAlreadyExistsException, IOException, InvalidArchiveURLException { + if (ArchiveUtils.validateURL(allowedUrls, url)) { + if (modelLocation.exists()) { + throw new FileAlreadyExistsException(archiveName); + } + + if (archiveName.contains("/") || archiveName.contains("\\")) { + throw new IOException( + "Security alert slash or backslash appear in archiveName:" + archiveName); + } + + // for a simple GET, we have no body so supply the precomputed 'empty' hash + Map headers; + if (s3SseKmsEnabled) { + String awsAccessKey = System.getenv("AWS_ACCESS_KEY_ID"); + String awsSecretKey = System.getenv("AWS_SECRET_ACCESS_KEY"); + String regionName = System.getenv("AWS_DEFAULT_REGION"); + if (regionName.isEmpty() || awsAccessKey.isEmpty() || awsSecretKey.isEmpty()) { + throw new IOException( + "Miss environment variables " + + "AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY or AWS_DEFAULT_REGION"); + } + + HttpURLConnection connection = (HttpURLConnection) new URL(url).openConnection(); headers = new HashMap<>(); headers.put("x-amz-content-sha256", AWS4SignerBase.EMPTY_BODY_SHA256); AWS4SignerForAuthorizationHeader signer = - new AWS4SignerForAuthorizationHeader(endpointUrl, "GET", "s3", regionName); + new AWS4SignerForAuthorizationHeader( + connection.getURL(), "GET", "s3", regionName); String authorization = signer.computeSignature( headers, @@ -44,7 +70,7 @@ public static void copyURLToFile(URL endpointUrl, File modelLocation, boolean s3 // place the computed signature into a formatted 'Authorization' header // and call S3 headers.put("Authorization", authorization); - HttpURLConnection connection = createHttpConnection(endpointUrl, "GET", headers); + setHttpConnection(connection, "GET", headers); try { FileUtils.copyInputStreamToFile(connection.getInputStream(), modelLocation); } finally { @@ -53,19 +79,16 @@ public static void copyURLToFile(URL endpointUrl, File modelLocation, boolean s3 } } } else { - throw new IOException( - "Miss environment variables " - + "AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY or AWS_DEFAULT_REGION"); + URL endpointUrl = new URL(url); + FileUtils.copyURLToFile(endpointUrl, modelLocation); } - } else { - FileUtils.copyURLToFile(endpointUrl, modelLocation); } + return false; } - public static HttpURLConnection createHttpConnection( - URL endpointUrl, String httpMethod, Map headers) throws IOException { - - HttpURLConnection connection = (HttpURLConnection) endpointUrl.openConnection(); + public static void setHttpConnection( + HttpURLConnection connection, String httpMethod, Map headers) + throws IOException { connection.setRequestMethod(httpMethod); if (headers != null) { @@ -73,8 +96,6 @@ public static HttpURLConnection createHttpConnection( connection.setRequestProperty(headerKey, headers.get(headerKey)); } } - - return connection; } public static String urlEncode(String url, boolean keepPathSlash) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java index 370763a6ae..948f233b69 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java @@ -110,18 +110,14 @@ public static boolean downloadArchive( boolean s3SseKmsEnabled) throws FileAlreadyExistsException, FileNotFoundException, DownloadArchiveException, InvalidArchiveURLException { - if (validateURL(allowedUrls, url)) { - if (location.exists()) { - throw new FileAlreadyExistsException(archiveName); - } - try { - HttpUtils.copyURLToFile(new URL(url), location, s3SseKmsEnabled); - } catch (IOException e) { - FileUtils.deleteQuietly(location); - throw new DownloadArchiveException("Failed to download archive from: " + url, e); - } + try { + return HttpUtils.copyURLToFile( + allowedUrls, url, location, s3SseKmsEnabled, archiveName); + } catch (InvalidArchiveURLException | FileAlreadyExistsException e) { + throw e; + } catch (IOException e) { + FileUtils.deleteQuietly(location); + throw new DownloadArchiveException("Failed to download archive from: " + url, e); } - - return true; } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 294abd9fbe..198e8ae6fe 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -114,6 +114,7 @@ public final class ConfigManager { // Configuration default values private static final String DEFAULT_TS_ALLOWED_URLS = "file://.*|http(s)?://.*"; + private static final String USE_ENV_ALLOWED_URLS = "use_env_allowed_urls"; // Variables which are local public static final String MODEL_METRICS_LOGGER = "MODEL_METRICS"; @@ -277,6 +278,14 @@ private void setSystemVars() { Class configClass = ConfigManager.class; Field[] fields = configClass.getDeclaredFields(); for (Field f : fields) { + // For security, disable TS_ALLOWED_URLS in env. + if ("TS_ALLOWED_URLS".equals(f.getName()) + && !"true" + .equals( + prop.getProperty(USE_ENV_ALLOWED_URLS, "false") + .toLowerCase())) { + continue; + } if (f.getName().startsWith("TS_")) { String val = System.getenv(f.getName()); if (val != null) {