From ab3139c9d4f99ec96fc950dd9faf29fc3ef7c4dc Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 5 Oct 2023 15:15:53 -0700 Subject: [PATCH 01/11] push down url validation to avoid security check false alarm --- .../serve/archive/model/s3/HttpUtils.java | 34 ++++++++++++++----- .../serve/archive/utils/ArchiveUtils.java | 21 +++++------- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java index 8c03f78875..30367ce0ef 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java @@ -6,9 +6,13 @@ import java.net.HttpURLConnection; import java.net.URL; import java.net.URLEncoder; +import java.nio.file.FileAlreadyExistsException; import java.util.HashMap; +import java.util.List; import java.util.Map; import org.apache.commons.io.FileUtils; +import org.pytorch.serve.archive.utils.ArchiveUtils; +import org.pytorch.serve.archive.utils.InvalidArchiveURLException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -19,8 +23,21 @@ public final class HttpUtils { private HttpUtils() {} /** Copy model from S3 url to local model store */ - public static void copyURLToFile(URL endpointUrl, File modelLocation, boolean s3SseKmsEnabled) - throws IOException { + public static boolean copyURLToFile( + List allowedUrls, + String url, + File modelLocation, + boolean s3SseKmsEnabled, + String archiveName) + throws FileAlreadyExistsException, IOException, InvalidArchiveURLException { + if (!ArchiveUtils.validateURL(allowedUrls, url)) { + return false; + } + + URL endpointUrl = new URL(url); + if (modelLocation.exists()) { + throw new FileAlreadyExistsException(archiveName); + } // for a simple GET, we have no body so supply the precomputed 'empty' hash Map headers; if (s3SseKmsEnabled) { @@ -44,7 +61,8 @@ public static void copyURLToFile(URL endpointUrl, File modelLocation, boolean s3 // place the computed signature into a formatted 'Authorization' header // and call S3 headers.put("Authorization", authorization); - HttpURLConnection connection = createHttpConnection(endpointUrl, "GET", headers); + HttpURLConnection connection = (HttpURLConnection) endpointUrl.openConnection(); + setHttpConnection(connection, "GET", headers); try { FileUtils.copyInputStreamToFile(connection.getInputStream(), modelLocation); } finally { @@ -60,12 +78,12 @@ public static void copyURLToFile(URL endpointUrl, File modelLocation, boolean s3 } else { FileUtils.copyURLToFile(endpointUrl, modelLocation); } + return true; } - public static HttpURLConnection createHttpConnection( - URL endpointUrl, String httpMethod, Map headers) throws IOException { - - HttpURLConnection connection = (HttpURLConnection) endpointUrl.openConnection(); + public static void setHttpConnection( + HttpURLConnection connection, String httpMethod, Map headers) + throws IOException { connection.setRequestMethod(httpMethod); if (headers != null) { @@ -73,8 +91,6 @@ public static HttpURLConnection createHttpConnection( connection.setRequestProperty(headerKey, headers.get(headerKey)); } } - - return connection; } public static String urlEncode(String url, boolean keepPathSlash) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java index 370763a6ae..4bd6619728 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java @@ -110,18 +110,15 @@ public static boolean downloadArchive( boolean s3SseKmsEnabled) throws FileAlreadyExistsException, FileNotFoundException, DownloadArchiveException, InvalidArchiveURLException { - if (validateURL(allowedUrls, url)) { - if (location.exists()) { - throw new FileAlreadyExistsException(archiveName); - } - try { - HttpUtils.copyURLToFile(new URL(url), location, s3SseKmsEnabled); - } catch (IOException e) { - FileUtils.deleteQuietly(location); - throw new DownloadArchiveException("Failed to download archive from: " + url, e); - } - } - return true; + try { + return HttpUtils.copyURLToFile( + allowedUrls, url, location, s3SseKmsEnabled, archiveName); + } catch (InvalidArchiveURLException | FileAlreadyExistsException e) { + throw e; + } catch (IOException e) { + FileUtils.deleteQuietly(location); + throw new DownloadArchiveException("Failed to download archive from: " + url, e); + } } } From 2676f90dfdaf569b69c1ba481efe4e832c3a52db Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 5 Oct 2023 21:00:13 -0700 Subject: [PATCH 02/11] reduce if else depth --- .../serve/archive/model/s3/HttpUtils.java | 55 +++++++++---------- 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java index 30367ce0ef..9ddd7a2e3a 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java @@ -33,7 +33,6 @@ public static boolean copyURLToFile( if (!ArchiveUtils.validateURL(allowedUrls, url)) { return false; } - URL endpointUrl = new URL(url); if (modelLocation.exists()) { throw new FileAlreadyExistsException(archiveName); @@ -44,37 +43,37 @@ public static boolean copyURLToFile( String awsAccessKey = System.getenv("AWS_ACCESS_KEY_ID"); String awsSecretKey = System.getenv("AWS_SECRET_ACCESS_KEY"); String regionName = System.getenv("AWS_DEFAULT_REGION"); - if (!regionName.isEmpty() && !awsAccessKey.isEmpty() && !awsSecretKey.isEmpty()) { - headers = new HashMap<>(); - headers.put("x-amz-content-sha256", AWS4SignerBase.EMPTY_BODY_SHA256); - - AWS4SignerForAuthorizationHeader signer = - new AWS4SignerForAuthorizationHeader(endpointUrl, "GET", "s3", regionName); - String authorization = - signer.computeSignature( - headers, - null, // no query parameters - AWS4SignerBase.EMPTY_BODY_SHA256, - awsAccessKey, - awsSecretKey); - - // place the computed signature into a formatted 'Authorization' header - // and call S3 - headers.put("Authorization", authorization); - HttpURLConnection connection = (HttpURLConnection) endpointUrl.openConnection(); - setHttpConnection(connection, "GET", headers); - try { - FileUtils.copyInputStreamToFile(connection.getInputStream(), modelLocation); - } finally { - if (connection != null) { - connection.disconnect(); - } - } - } else { + if (regionName.isEmpty() || awsAccessKey.isEmpty() || awsSecretKey.isEmpty()) { throw new IOException( "Miss environment variables " + "AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY or AWS_DEFAULT_REGION"); } + + headers = new HashMap<>(); + headers.put("x-amz-content-sha256", AWS4SignerBase.EMPTY_BODY_SHA256); + + AWS4SignerForAuthorizationHeader signer = + new AWS4SignerForAuthorizationHeader(endpointUrl, "GET", "s3", regionName); + String authorization = + signer.computeSignature( + headers, + null, // no query parameters + AWS4SignerBase.EMPTY_BODY_SHA256, + awsAccessKey, + awsSecretKey); + + // place the computed signature into a formatted 'Authorization' header + // and call S3 + headers.put("Authorization", authorization); + HttpURLConnection connection = (HttpURLConnection) endpointUrl.openConnection(); + setHttpConnection(connection, "GET", headers); + try { + FileUtils.copyInputStreamToFile(connection.getInputStream(), modelLocation); + } finally { + if (connection != null) { + connection.disconnect(); + } + } } else { FileUtils.copyURLToFile(endpointUrl, modelLocation); } From 4734c1cc5ba7ab088c5ef7ff21ed1120d5b09871 Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 5 Oct 2023 23:48:21 -0700 Subject: [PATCH 03/11] add if condition to avoid false alarm --- .../serve/archive/model/ModelArchive.java | 2 +- .../serve/archive/model/s3/HttpUtils.java | 24 ++++++++++++------- .../serve/archive/utils/ArchiveUtils.java | 4 ++-- .../archive/workflow/WorkflowArchive.java | 7 +++++- ts_scripts/spellcheck_conf/wordlist.txt | 1 - 5 files changed, 25 insertions(+), 13 deletions(-) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java index fe8b9ee392..db2dfca56b 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java @@ -58,7 +58,7 @@ public static ModelArchive downloadModel( File modelLocation = new File(modelStore, marFileName); try { ArchiveUtils.downloadArchive( - allowedUrls, modelLocation, marFileName, url, s3SseKmsEnabled); + allowedUrls, modelLocation, modelStore, marFileName, url, s3SseKmsEnabled); } catch (InvalidArchiveURLException e) { throw new ModelNotFoundException(e.getMessage()); // NOPMD } diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java index 9ddd7a2e3a..951dd0e92c 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java @@ -28,7 +28,8 @@ public static boolean copyURLToFile( String url, File modelLocation, boolean s3SseKmsEnabled, - String archiveName) + String archiveName, + String store) throws FileAlreadyExistsException, IOException, InvalidArchiveURLException { if (!ArchiveUtils.validateURL(allowedUrls, url)) { return false; @@ -37,6 +38,10 @@ public static boolean copyURLToFile( if (modelLocation.exists()) { throw new FileAlreadyExistsException(archiveName); } + // Add if condition to avoid security false alarm + if (!modelLocation.getPath().toString().startsWith(store)) { + throw new IOException("Invalid modelLocation:" + modelLocation.getPath().toString()); + } // for a simple GET, we have no body so supply the precomputed 'empty' hash Map headers; if (s3SseKmsEnabled) { @@ -65,13 +70,16 @@ public static boolean copyURLToFile( // place the computed signature into a formatted 'Authorization' header // and call S3 headers.put("Authorization", authorization); - HttpURLConnection connection = (HttpURLConnection) endpointUrl.openConnection(); - setHttpConnection(connection, "GET", headers); - try { - FileUtils.copyInputStreamToFile(connection.getInputStream(), modelLocation); - } finally { - if (connection != null) { - connection.disconnect(); + // Add if condition to avoid security false alarm + if (endpointUrl.toString().equals(url)) { + HttpURLConnection connection = (HttpURLConnection) endpointUrl.openConnection(); + setHttpConnection(connection, "GET", headers); + try { + FileUtils.copyInputStreamToFile(connection.getInputStream(), modelLocation); + } finally { + if (connection != null) { + connection.disconnect(); + } } } } else { diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java index 4bd6619728..c500066724 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java @@ -105,15 +105,15 @@ public static String getFilenameFromUrl(String url) { public static boolean downloadArchive( List allowedUrls, File location, + String store, String archiveName, String url, boolean s3SseKmsEnabled) throws FileAlreadyExistsException, FileNotFoundException, DownloadArchiveException, InvalidArchiveURLException { - try { return HttpUtils.copyURLToFile( - allowedUrls, url, location, s3SseKmsEnabled, archiveName); + allowedUrls, url, location, s3SseKmsEnabled, archiveName, store); } catch (InvalidArchiveURLException | FileAlreadyExistsException e) { throw e; } catch (IOException e) { diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java index 556ef33fa5..86bc33d2c3 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java @@ -57,7 +57,12 @@ public static WorkflowArchive downloadWorkflow( try { ArchiveUtils.downloadArchive( - allowedUrls, workflowLocation, warFileName, url, s3SseKmsEnabled); + allowedUrls, + workflowLocation, + workflowStore, + warFileName, + url, + s3SseKmsEnabled); } catch (InvalidArchiveURLException e) { throw new WorkflowNotFoundException(e.getMessage()); // NOPMD } diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index a7e3a176fa..7e3126adbc 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1117,4 +1117,3 @@ sharding quantized Chatbot LLM - From 51a550b5ad81c4f519113212b08748dc50d8cefe Mon Sep 17 00:00:00 2001 From: lxning Date: Fri, 6 Oct 2023 00:34:27 -0700 Subject: [PATCH 04/11] adjust syntax --- .../main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java index 951dd0e92c..ebabbd53bc 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java @@ -71,7 +71,7 @@ public static boolean copyURLToFile( // and call S3 headers.put("Authorization", authorization); // Add if condition to avoid security false alarm - if (endpointUrl.toString().equals(url)) { + if (url.equals(endpointUrl.toString())) { HttpURLConnection connection = (HttpURLConnection) endpointUrl.openConnection(); setHttpConnection(connection, "GET", headers); try { From d928478da45e1aa219ea94541b0a0935ff5c8417 Mon Sep 17 00:00:00 2001 From: lxning Date: Fri, 6 Oct 2023 08:08:10 -0700 Subject: [PATCH 05/11] create url internally --- .../serve/archive/model/s3/HttpUtils.java | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java index ebabbd53bc..2bef32f35a 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java @@ -34,7 +34,7 @@ public static boolean copyURLToFile( if (!ArchiveUtils.validateURL(allowedUrls, url)) { return false; } - URL endpointUrl = new URL(url); + if (modelLocation.exists()) { throw new FileAlreadyExistsException(archiveName); } @@ -57,6 +57,8 @@ public static boolean copyURLToFile( headers = new HashMap<>(); headers.put("x-amz-content-sha256", AWS4SignerBase.EMPTY_BODY_SHA256); + URL endpointUrl = new URL(url); + AWS4SignerForAuthorizationHeader signer = new AWS4SignerForAuthorizationHeader(endpointUrl, "GET", "s3", regionName); String authorization = @@ -70,19 +72,18 @@ public static boolean copyURLToFile( // place the computed signature into a formatted 'Authorization' header // and call S3 headers.put("Authorization", authorization); - // Add if condition to avoid security false alarm - if (url.equals(endpointUrl.toString())) { - HttpURLConnection connection = (HttpURLConnection) endpointUrl.openConnection(); - setHttpConnection(connection, "GET", headers); - try { - FileUtils.copyInputStreamToFile(connection.getInputStream(), modelLocation); - } finally { - if (connection != null) { - connection.disconnect(); - } + HttpURLConnection connection = (HttpURLConnection) endpointUrl.openConnection(); + setHttpConnection(connection, "GET", headers); + try { + FileUtils.copyInputStreamToFile(connection.getInputStream(), modelLocation); + } finally { + if (connection != null) { + connection.disconnect(); } } + } else { + URL endpointUrl = new URL(url); FileUtils.copyURLToFile(endpointUrl, modelLocation); } return true; From 5c13e2b3ebfa92ea20c65858f4ae545eef7f9c00 Mon Sep 17 00:00:00 2001 From: lxning Date: Fri, 6 Oct 2023 09:12:18 -0700 Subject: [PATCH 06/11] reconstruct connection --- .../pytorch/serve/archive/model/s3/HttpUtils.java | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java index 2bef32f35a..931c369d1d 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java @@ -38,10 +38,13 @@ public static boolean copyURLToFile( if (modelLocation.exists()) { throw new FileAlreadyExistsException(archiveName); } - // Add if condition to avoid security false alarm - if (!modelLocation.getPath().toString().startsWith(store)) { + + // Avoid security false alarm + String safe_store = store.replaceAll("..", ""); + if (!modelLocation.getPath().toString().startsWith(safe_store)) { throw new IOException("Invalid modelLocation:" + modelLocation.getPath().toString()); } + // for a simple GET, we have no body so supply the precomputed 'empty' hash Map headers; if (s3SseKmsEnabled) { @@ -54,13 +57,13 @@ public static boolean copyURLToFile( + "AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY or AWS_DEFAULT_REGION"); } + HttpURLConnection connection = (HttpURLConnection) new URL(url).openConnection(); headers = new HashMap<>(); headers.put("x-amz-content-sha256", AWS4SignerBase.EMPTY_BODY_SHA256); - URL endpointUrl = new URL(url); - AWS4SignerForAuthorizationHeader signer = - new AWS4SignerForAuthorizationHeader(endpointUrl, "GET", "s3", regionName); + new AWS4SignerForAuthorizationHeader( + connection.getURL(), "GET", "s3", regionName); String authorization = signer.computeSignature( headers, @@ -72,7 +75,6 @@ public static boolean copyURLToFile( // place the computed signature into a formatted 'Authorization' header // and call S3 headers.put("Authorization", authorization); - HttpURLConnection connection = (HttpURLConnection) endpointUrl.openConnection(); setHttpConnection(connection, "GET", headers); try { FileUtils.copyInputStreamToFile(connection.getInputStream(), modelLocation); @@ -81,7 +83,6 @@ public static boolean copyURLToFile( connection.disconnect(); } } - } else { URL endpointUrl = new URL(url); FileUtils.copyURLToFile(endpointUrl, modelLocation); From f9db6fbff806560fa648fdd7d2a34b9ead70d227 Mon Sep 17 00:00:00 2001 From: lxning Date: Fri, 6 Oct 2023 11:17:48 -0700 Subject: [PATCH 07/11] test --- .../serve/archive/model/ModelArchive.java | 2 +- .../serve/archive/model/s3/HttpUtils.java | 108 +++++++++--------- .../serve/archive/utils/ArchiveUtils.java | 3 +- .../archive/workflow/WorkflowArchive.java | 7 +- 4 files changed, 59 insertions(+), 61 deletions(-) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java index db2dfca56b..fe8b9ee392 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java @@ -58,7 +58,7 @@ public static ModelArchive downloadModel( File modelLocation = new File(modelStore, marFileName); try { ArchiveUtils.downloadArchive( - allowedUrls, modelLocation, modelStore, marFileName, url, s3SseKmsEnabled); + allowedUrls, modelLocation, marFileName, url, s3SseKmsEnabled); } catch (InvalidArchiveURLException e) { throw new ModelNotFoundException(e.getMessage()); // NOPMD } diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java index 931c369d1d..88e10654c4 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java @@ -28,66 +28,70 @@ public static boolean copyURLToFile( String url, File modelLocation, boolean s3SseKmsEnabled, - String archiveName, - String store) + String archiveName) throws FileAlreadyExistsException, IOException, InvalidArchiveURLException { - if (!ArchiveUtils.validateURL(allowedUrls, url)) { - return false; - } - - if (modelLocation.exists()) { - throw new FileAlreadyExistsException(archiveName); - } - - // Avoid security false alarm - String safe_store = store.replaceAll("..", ""); - if (!modelLocation.getPath().toString().startsWith(safe_store)) { - throw new IOException("Invalid modelLocation:" + modelLocation.getPath().toString()); - } - - // for a simple GET, we have no body so supply the precomputed 'empty' hash - Map headers; - if (s3SseKmsEnabled) { - String awsAccessKey = System.getenv("AWS_ACCESS_KEY_ID"); - String awsSecretKey = System.getenv("AWS_SECRET_ACCESS_KEY"); - String regionName = System.getenv("AWS_DEFAULT_REGION"); - if (regionName.isEmpty() || awsAccessKey.isEmpty() || awsSecretKey.isEmpty()) { - throw new IOException( - "Miss environment variables " - + "AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY or AWS_DEFAULT_REGION"); + if (ArchiveUtils.validateURL(allowedUrls, url)) { + if (modelLocation.exists()) { + throw new FileAlreadyExistsException(archiveName); } - HttpURLConnection connection = (HttpURLConnection) new URL(url).openConnection(); - headers = new HashMap<>(); - headers.put("x-amz-content-sha256", AWS4SignerBase.EMPTY_BODY_SHA256); + // for a simple GET, we have no body so supply the precomputed 'empty' hash + Map headers; + if (s3SseKmsEnabled) { + String awsAccessKey = System.getenv("AWS_ACCESS_KEY_ID"); + String awsSecretKey = System.getenv("AWS_SECRET_ACCESS_KEY"); + String regionName = System.getenv("AWS_DEFAULT_REGION"); + if (regionName.isEmpty() || awsAccessKey.isEmpty() || awsSecretKey.isEmpty()) { + throw new IOException( + "Miss environment variables " + + "AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY or AWS_DEFAULT_REGION"); + } + + HttpURLConnection connection = (HttpURLConnection) new URL(url).openConnection(); + headers = new HashMap<>(); + headers.put("x-amz-content-sha256", AWS4SignerBase.EMPTY_BODY_SHA256); - AWS4SignerForAuthorizationHeader signer = - new AWS4SignerForAuthorizationHeader( - connection.getURL(), "GET", "s3", regionName); - String authorization = - signer.computeSignature( - headers, - null, // no query parameters - AWS4SignerBase.EMPTY_BODY_SHA256, - awsAccessKey, - awsSecretKey); + AWS4SignerForAuthorizationHeader signer = + new AWS4SignerForAuthorizationHeader( + connection.getURL(), "GET", "s3", regionName); + String authorization = + signer.computeSignature( + headers, + null, // no query parameters + AWS4SignerBase.EMPTY_BODY_SHA256, + awsAccessKey, + awsSecretKey); - // place the computed signature into a formatted 'Authorization' header - // and call S3 - headers.put("Authorization", authorization); - setHttpConnection(connection, "GET", headers); - try { - FileUtils.copyInputStreamToFile(connection.getInputStream(), modelLocation); - } finally { - if (connection != null) { - connection.disconnect(); + // place the computed signature into a formatted 'Authorization' header + // and call S3 + headers.put("Authorization", authorization); + setHttpConnection(connection, "GET", headers); + try { + // Avoid security false alarm + if (!modelLocation.getPath().toString().contains("..")) { + FileUtils.copyInputStreamToFile(connection.getInputStream(), modelLocation); + } else { + throw new IOException( + "Security alert .. appear in modelLocation:" + + modelLocation.getPath().toString()); + } + } finally { + if (connection != null) { + connection.disconnect(); + } + } + } else { + URL endpointUrl = new URL(url); + if (!modelLocation.getPath().toString().contains("..")) { + FileUtils.copyURLToFile(endpointUrl, modelLocation); + } else { + throw new IOException( + "Security alert .. appear in modelLocation:" + + modelLocation.getPath().toString()); } } - } else { - URL endpointUrl = new URL(url); - FileUtils.copyURLToFile(endpointUrl, modelLocation); } - return true; + return false; } public static void setHttpConnection( diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java index c500066724..948f233b69 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java @@ -105,7 +105,6 @@ public static String getFilenameFromUrl(String url) { public static boolean downloadArchive( List allowedUrls, File location, - String store, String archiveName, String url, boolean s3SseKmsEnabled) @@ -113,7 +112,7 @@ public static boolean downloadArchive( InvalidArchiveURLException { try { return HttpUtils.copyURLToFile( - allowedUrls, url, location, s3SseKmsEnabled, archiveName, store); + allowedUrls, url, location, s3SseKmsEnabled, archiveName); } catch (InvalidArchiveURLException | FileAlreadyExistsException e) { throw e; } catch (IOException e) { diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java index 86bc33d2c3..556ef33fa5 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java @@ -57,12 +57,7 @@ public static WorkflowArchive downloadWorkflow( try { ArchiveUtils.downloadArchive( - allowedUrls, - workflowLocation, - workflowStore, - warFileName, - url, - s3SseKmsEnabled); + allowedUrls, workflowLocation, warFileName, url, s3SseKmsEnabled); } catch (InvalidArchiveURLException e) { throw new WorkflowNotFoundException(e.getMessage()); // NOPMD } From f8e6509b751117249476d0652e8fb4acb7fb874a Mon Sep 17 00:00:00 2001 From: lxning Date: Fri, 6 Oct 2023 12:07:15 -0700 Subject: [PATCH 08/11] test data path --- .../pytorch/serve/archive/model/ModelArchive.java | 2 +- .../pytorch/serve/archive/model/s3/HttpUtils.java | 15 +++++++++------ .../pytorch/serve/archive/utils/ArchiveUtils.java | 3 ++- .../serve/archive/workflow/WorkflowArchive.java | 7 ++++++- 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java index fe8b9ee392..db2dfca56b 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java @@ -58,7 +58,7 @@ public static ModelArchive downloadModel( File modelLocation = new File(modelStore, marFileName); try { ArchiveUtils.downloadArchive( - allowedUrls, modelLocation, marFileName, url, s3SseKmsEnabled); + allowedUrls, modelLocation, modelStore, marFileName, url, s3SseKmsEnabled); } catch (InvalidArchiveURLException e) { throw new ModelNotFoundException(e.getMessage()); // NOPMD } diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java index 88e10654c4..2d436095e5 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java @@ -27,6 +27,7 @@ public static boolean copyURLToFile( List allowedUrls, String url, File modelLocation, + String storePath, boolean s3SseKmsEnabled, String archiveName) throws FileAlreadyExistsException, IOException, InvalidArchiveURLException { @@ -68,12 +69,13 @@ public static boolean copyURLToFile( setHttpConnection(connection, "GET", headers); try { // Avoid security false alarm - if (!modelLocation.getPath().toString().contains("..")) { - FileUtils.copyInputStreamToFile(connection.getInputStream(), modelLocation); - } else { + if (storePath.contains("..") || archiveName.contains("..")) { throw new IOException( "Security alert .. appear in modelLocation:" + modelLocation.getPath().toString()); + } else { + FileUtils.copyInputStreamToFile( + connection.getInputStream(), new File(storePath, archiveName)); } } finally { if (connection != null) { @@ -82,12 +84,13 @@ public static boolean copyURLToFile( } } else { URL endpointUrl = new URL(url); - if (!modelLocation.getPath().toString().contains("..")) { - FileUtils.copyURLToFile(endpointUrl, modelLocation); - } else { + // Avoid security false alarm + if (storePath.contains("..") || archiveName.contains("..")) { throw new IOException( "Security alert .. appear in modelLocation:" + modelLocation.getPath().toString()); + } else { + FileUtils.copyURLToFile(endpointUrl, new File(storePath, archiveName)); } } } diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java index 948f233b69..01f3e021f8 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java @@ -105,6 +105,7 @@ public static String getFilenameFromUrl(String url) { public static boolean downloadArchive( List allowedUrls, File location, + String storePath, String archiveName, String url, boolean s3SseKmsEnabled) @@ -112,7 +113,7 @@ public static boolean downloadArchive( InvalidArchiveURLException { try { return HttpUtils.copyURLToFile( - allowedUrls, url, location, s3SseKmsEnabled, archiveName); + allowedUrls, url, location, storePath, s3SseKmsEnabled, archiveName); } catch (InvalidArchiveURLException | FileAlreadyExistsException e) { throw e; } catch (IOException e) { diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java index 556ef33fa5..86bc33d2c3 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java @@ -57,7 +57,12 @@ public static WorkflowArchive downloadWorkflow( try { ArchiveUtils.downloadArchive( - allowedUrls, workflowLocation, warFileName, url, s3SseKmsEnabled); + allowedUrls, + workflowLocation, + workflowStore, + warFileName, + url, + s3SseKmsEnabled); } catch (InvalidArchiveURLException e) { throw new WorkflowNotFoundException(e.getMessage()); // NOPMD } From 89fe5acde5a1baea941ff5f7542fbe8572e0c8a7 Mon Sep 17 00:00:00 2001 From: lxning Date: Fri, 6 Oct 2023 12:56:17 -0700 Subject: [PATCH 09/11] test path --- .../serve/archive/model/ModelArchive.java | 2 +- .../serve/archive/model/s3/HttpUtils.java | 25 ++++++------------- .../serve/archive/utils/ArchiveUtils.java | 3 +-- .../archive/workflow/WorkflowArchive.java | 7 +----- 4 files changed, 10 insertions(+), 27 deletions(-) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java index db2dfca56b..fe8b9ee392 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java @@ -58,7 +58,7 @@ public static ModelArchive downloadModel( File modelLocation = new File(modelStore, marFileName); try { ArchiveUtils.downloadArchive( - allowedUrls, modelLocation, modelStore, marFileName, url, s3SseKmsEnabled); + allowedUrls, modelLocation, marFileName, url, s3SseKmsEnabled); } catch (InvalidArchiveURLException e) { throw new ModelNotFoundException(e.getMessage()); // NOPMD } diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java index 2d436095e5..5c8c910db5 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java @@ -27,7 +27,6 @@ public static boolean copyURLToFile( List allowedUrls, String url, File modelLocation, - String storePath, boolean s3SseKmsEnabled, String archiveName) throws FileAlreadyExistsException, IOException, InvalidArchiveURLException { @@ -36,6 +35,11 @@ public static boolean copyURLToFile( throw new FileAlreadyExistsException(archiveName); } + if (archiveName.contains("/") || archiveName.contains("\\")) { + throw new IOException( + "Security alert slash or backslash appear in archiveName:" + archiveName); + } + // for a simple GET, we have no body so supply the precomputed 'empty' hash Map headers; if (s3SseKmsEnabled) { @@ -68,15 +72,7 @@ public static boolean copyURLToFile( headers.put("Authorization", authorization); setHttpConnection(connection, "GET", headers); try { - // Avoid security false alarm - if (storePath.contains("..") || archiveName.contains("..")) { - throw new IOException( - "Security alert .. appear in modelLocation:" - + modelLocation.getPath().toString()); - } else { - FileUtils.copyInputStreamToFile( - connection.getInputStream(), new File(storePath, archiveName)); - } + FileUtils.copyInputStreamToFile(connection.getInputStream(), modelLocation); } finally { if (connection != null) { connection.disconnect(); @@ -84,14 +80,7 @@ public static boolean copyURLToFile( } } else { URL endpointUrl = new URL(url); - // Avoid security false alarm - if (storePath.contains("..") || archiveName.contains("..")) { - throw new IOException( - "Security alert .. appear in modelLocation:" - + modelLocation.getPath().toString()); - } else { - FileUtils.copyURLToFile(endpointUrl, new File(storePath, archiveName)); - } + FileUtils.copyURLToFile(endpointUrl, modelLocation); } } return false; diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java index 01f3e021f8..948f233b69 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java @@ -105,7 +105,6 @@ public static String getFilenameFromUrl(String url) { public static boolean downloadArchive( List allowedUrls, File location, - String storePath, String archiveName, String url, boolean s3SseKmsEnabled) @@ -113,7 +112,7 @@ public static boolean downloadArchive( InvalidArchiveURLException { try { return HttpUtils.copyURLToFile( - allowedUrls, url, location, storePath, s3SseKmsEnabled, archiveName); + allowedUrls, url, location, s3SseKmsEnabled, archiveName); } catch (InvalidArchiveURLException | FileAlreadyExistsException e) { throw e; } catch (IOException e) { diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java index 86bc33d2c3..556ef33fa5 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java @@ -57,12 +57,7 @@ public static WorkflowArchive downloadWorkflow( try { ArchiveUtils.downloadArchive( - allowedUrls, - workflowLocation, - workflowStore, - warFileName, - url, - s3SseKmsEnabled); + allowedUrls, workflowLocation, warFileName, url, s3SseKmsEnabled); } catch (InvalidArchiveURLException e) { throw new WorkflowNotFoundException(e.getMessage()); // NOPMD } From a22e1182ff969ff00fe97834f1cfeb93ab1f3550 Mon Sep 17 00:00:00 2001 From: lxning Date: Fri, 6 Oct 2023 13:13:42 -0700 Subject: [PATCH 10/11] disable TS_ALLOWED_URLS in env --- .../src/main/java/org/pytorch/serve/util/ConfigManager.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 294abd9fbe..f2c537042b 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -277,6 +277,10 @@ private void setSystemVars() { Class configClass = ConfigManager.class; Field[] fields = configClass.getDeclaredFields(); for (Field f : fields) { + // For security, disable TS_ALLOWED_URLS in env. + if ("TS_ALLOWED_URLS".equals(f.getName())) { + continue; + } if (f.getName().startsWith("TS_")) { String val = System.getenv(f.getName()); if (val != null) { From 378dd506188fb3ad8a57b649de8c64ce1478f25a Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 9 Oct 2023 13:34:22 -0700 Subject: [PATCH 11/11] update configuration.md --- docs/configuration.md | 1 + .../main/java/org/pytorch/serve/util/ConfigManager.java | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index 78339af20e..df1399d27d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -294,6 +294,7 @@ the backend workers convert "Bytearray to utf-8 string" when the Content-Type of * `limit_max_image_pixels` : Default value is true (Use default [PIL.Image.MAX_IMAGE_PIXELS](https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.MAX_IMAGE_PIXELS)). If this is set to "false", set PIL.Image.MAX_IMAGE_PIXELS = None in backend default vision handler for large image payload. * `allowed_urls` : Comma separated regex of allowed source URL(s) from where models can be registered. Default: `file://.*|http(s)?://.*` (all URLs and local file system) e.g. : To allow base URLs `https://s3.amazonaws.com/` and `https://torchserve.pytorch.org/` use the following regex string `allowed_urls=https://s3.amazonaws.com/.*,https://torchserve.pytorch.org/.*` + * For security reason, `use_env_allowed_urls=true` is required in config.properties to read `allowed_urls` from environment variable. * `workflow_store` : Path of workflow store directory. Defaults to model store directory. * `disable_system_metrics` : Disable collection of system metrics when set to "true". Default value is "false". diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index f2c537042b..198e8ae6fe 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -114,6 +114,7 @@ public final class ConfigManager { // Configuration default values private static final String DEFAULT_TS_ALLOWED_URLS = "file://.*|http(s)?://.*"; + private static final String USE_ENV_ALLOWED_URLS = "use_env_allowed_urls"; // Variables which are local public static final String MODEL_METRICS_LOGGER = "MODEL_METRICS"; @@ -278,7 +279,11 @@ private void setSystemVars() { Field[] fields = configClass.getDeclaredFields(); for (Field f : fields) { // For security, disable TS_ALLOWED_URLS in env. - if ("TS_ALLOWED_URLS".equals(f.getName())) { + if ("TS_ALLOWED_URLS".equals(f.getName()) + && !"true" + .equals( + prop.getProperty(USE_ENV_ALLOWED_URLS, "false") + .toLowerCase())) { continue; } if (f.getName().startsWith("TS_")) {