diff --git a/web/src/main/java/org/springframework/security/web/FilterInvocation.java b/web/src/main/java/org/springframework/security/web/FilterInvocation.java
index 5e762118d8d..e85ce6fe390 100644
--- a/web/src/main/java/org/springframework/security/web/FilterInvocation.java
+++ b/web/src/main/java/org/springframework/security/web/FilterInvocation.java
@@ -228,10 +228,15 @@ public String getQueryString() {
public void setQueryString(String queryString) {
this.queryString = queryString;
}
+
+ @Override
+ public String getServerName() {
+ return null;
+ }
}
final class UnsupportedOperationExceptionInvocationHandler implements InvocationHandler {
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
throw new UnsupportedOperationException(method + " is not supported");
}
-}
\ No newline at end of file
+}
diff --git a/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java b/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java
index 1746546b3ab..6d0e8e3e93f 100644
--- a/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java
+++ b/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2012-2017 the original author or authors.
+ * Copyright 2012-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,14 +16,14 @@
package org.springframework.security.web.firewall;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
/**
*
@@ -59,10 +59,15 @@
* Rejects URLs that contain a URL encoded percent. See
* {@link #setAllowUrlEncodedPercent(boolean)}
*
+ *
+ * Rejects hosts that are not allowed. See
+ * {@link #setAllowedHostnames(Collection)}
+ *
*
*
* @see DefaultHttpFirewall
* @author Rob Winch
+ * @author Eddú Meléndez
* @since 4.2.4
*/
public class StrictHttpFirewall implements HttpFirewall {
@@ -82,6 +87,8 @@ public class StrictHttpFirewall implements HttpFirewall {
private Set decodedUrlBlacklist = new HashSet();
+ private Collection allowedHostnames;
+
public StrictHttpFirewall() {
urlBlacklistsAddAll(FORBIDDEN_SEMICOLON);
urlBlacklistsAddAll(FORBIDDEN_FORWARDSLASH);
@@ -230,6 +237,13 @@ public void setAllowUrlEncodedPercent(boolean allowUrlEncodedPercent) {
}
}
+ public void setAllowedHostnames(Collection allowedHostnames) {
+ if (allowedHostnames == null) {
+ throw new IllegalArgumentException("allowedHostnames cannot be null");
+ }
+ this.allowedHostnames = allowedHostnames;
+ }
+
private void urlBlacklistsAddAll(Collection values) {
this.encodedUrlBlacklist.addAll(values);
this.decodedUrlBlacklist.addAll(values);
@@ -243,6 +257,7 @@ private void urlBlacklistsRemoveAll(Collection values) {
@Override
public FirewalledRequest getFirewalledRequest(HttpServletRequest request) throws RequestRejectedException {
rejectedBlacklistedUrls(request);
+ rejectedUntrustedHosts(request);
if (!isNormalized(request)) {
throw new RequestRejectedException("The request was rejected because the URL was not normalized.");
@@ -272,6 +287,19 @@ private void rejectedBlacklistedUrls(HttpServletRequest request) {
}
}
+ private void rejectedUntrustedHosts(HttpServletRequest request) {
+ String serverName = request.getServerName();
+ if (serverName == null) {
+ return;
+ }
+ if (this.allowedHostnames == null) {
+ return;
+ }
+ if (!this.allowedHostnames.contains(serverName)) {
+ throw new RequestRejectedException("The request was rejected because the domain " + serverName + " is untrusted.");
+ }
+ }
+
@Override
public HttpServletResponse getFirewalledResponse(HttpServletResponse response) {
return new FirewalledResponse(response);
diff --git a/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java b/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java
index 1c12a66a1b3..46074d24e7d 100644
--- a/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java
+++ b/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2012-2017 the original author or authors.
+ * Copyright 2012-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@
package org.springframework.security.web.firewall;
+import java.util.Arrays;
+
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
@@ -23,6 +25,7 @@
/**
* @author Rob Winch
+ * @author Eddú Meléndez
*/
public class StrictHttpFirewallTests {
public String[] unnormalizedPaths = { "/..", "/./path/", "/path/path/.", "/path/path//.", "./path/../path//.",
@@ -373,4 +376,42 @@ public void getFirewalledRequestWhenAllowUrlEncodedSlashAndUppercaseEncodedPathT
this.firewall.getFirewalledRequest(request);
}
+
+ @Test
+ public void getFirewalledRequestWhenTrustedDomainThenNoException() {
+ String host = "example.org";
+ this.request.addHeader("Host", host);
+ this.firewall.setAllowedHostnames(Arrays.asList(host));
+
+ try {
+ this.firewall.getFirewalledRequest(this.request);
+ } catch (RequestRejectedException fail) {
+ fail("Host " + host + " was rejected");
+ }
+ }
+
+ @Test
+ public void getFirewalledRequestWhenUntrustedDomainThenException() {
+ String host = "example.org";
+ this.request.addHeader("Host", host);
+ this.firewall.setAllowedHostnames(Arrays.asList("myexample.org"));
+
+ try {
+ this.firewall.getFirewalledRequest(this.request);
+ fail("Host " + host + " was accepted");
+ } catch (RequestRejectedException expected) {
+ }
+ }
+
+ @Test
+ public void getFirewalledRequestWhenDefaultsThenNoException() {
+ String host = "example.org";
+ this.request.addHeader("Host", host);
+
+ try {
+ this.firewall.getFirewalledRequest(this.request);
+ } catch (RequestRejectedException fail) {
+ fail("Host " + host + " was rejected");
+ }
+ }
}