diff --git a/web/src/main/java/org/springframework/security/web/FilterInvocation.java b/web/src/main/java/org/springframework/security/web/FilterInvocation.java index 5e762118d8d..e85ce6fe390 100644 --- a/web/src/main/java/org/springframework/security/web/FilterInvocation.java +++ b/web/src/main/java/org/springframework/security/web/FilterInvocation.java @@ -228,10 +228,15 @@ public String getQueryString() { public void setQueryString(String queryString) { this.queryString = queryString; } + + @Override + public String getServerName() { + return null; + } } final class UnsupportedOperationExceptionInvocationHandler implements InvocationHandler { public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { throw new UnsupportedOperationException(method + " is not supported"); } -} \ No newline at end of file +} diff --git a/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java b/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java index 1746546b3ab..6d0e8e3e93f 100644 --- a/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java +++ b/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2017 the original author or authors. + * Copyright 2012-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,14 +16,14 @@ package org.springframework.security.web.firewall; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; /** *

@@ -59,10 +59,15 @@ * Rejects URLs that contain a URL encoded percent. See * {@link #setAllowUrlEncodedPercent(boolean)} * + *

  • + * Rejects hosts that are not allowed. See + * {@link #setAllowedHostnames(Collection)} + *
  • * * * @see DefaultHttpFirewall * @author Rob Winch + * @author Eddú Meléndez * @since 4.2.4 */ public class StrictHttpFirewall implements HttpFirewall { @@ -82,6 +87,8 @@ public class StrictHttpFirewall implements HttpFirewall { private Set decodedUrlBlacklist = new HashSet(); + private Collection allowedHostnames; + public StrictHttpFirewall() { urlBlacklistsAddAll(FORBIDDEN_SEMICOLON); urlBlacklistsAddAll(FORBIDDEN_FORWARDSLASH); @@ -230,6 +237,13 @@ public void setAllowUrlEncodedPercent(boolean allowUrlEncodedPercent) { } } + public void setAllowedHostnames(Collection allowedHostnames) { + if (allowedHostnames == null) { + throw new IllegalArgumentException("allowedHostnames cannot be null"); + } + this.allowedHostnames = allowedHostnames; + } + private void urlBlacklistsAddAll(Collection values) { this.encodedUrlBlacklist.addAll(values); this.decodedUrlBlacklist.addAll(values); @@ -243,6 +257,7 @@ private void urlBlacklistsRemoveAll(Collection values) { @Override public FirewalledRequest getFirewalledRequest(HttpServletRequest request) throws RequestRejectedException { rejectedBlacklistedUrls(request); + rejectedUntrustedHosts(request); if (!isNormalized(request)) { throw new RequestRejectedException("The request was rejected because the URL was not normalized."); @@ -272,6 +287,19 @@ private void rejectedBlacklistedUrls(HttpServletRequest request) { } } + private void rejectedUntrustedHosts(HttpServletRequest request) { + String serverName = request.getServerName(); + if (serverName == null) { + return; + } + if (this.allowedHostnames == null) { + return; + } + if (!this.allowedHostnames.contains(serverName)) { + throw new RequestRejectedException("The request was rejected because the domain " + serverName + " is untrusted."); + } + } + @Override public HttpServletResponse getFirewalledResponse(HttpServletResponse response) { return new FirewalledResponse(response); diff --git a/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java b/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java index 1c12a66a1b3..46074d24e7d 100644 --- a/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java +++ b/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2017 the original author or authors. + * Copyright 2012-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.springframework.security.web.firewall; +import java.util.Arrays; + import org.junit.Test; import org.springframework.mock.web.MockHttpServletRequest; @@ -23,6 +25,7 @@ /** * @author Rob Winch + * @author Eddú Meléndez */ public class StrictHttpFirewallTests { public String[] unnormalizedPaths = { "/..", "/./path/", "/path/path/.", "/path/path//.", "./path/../path//.", @@ -373,4 +376,42 @@ public void getFirewalledRequestWhenAllowUrlEncodedSlashAndUppercaseEncodedPathT this.firewall.getFirewalledRequest(request); } + + @Test + public void getFirewalledRequestWhenTrustedDomainThenNoException() { + String host = "example.org"; + this.request.addHeader("Host", host); + this.firewall.setAllowedHostnames(Arrays.asList(host)); + + try { + this.firewall.getFirewalledRequest(this.request); + } catch (RequestRejectedException fail) { + fail("Host " + host + " was rejected"); + } + } + + @Test + public void getFirewalledRequestWhenUntrustedDomainThenException() { + String host = "example.org"; + this.request.addHeader("Host", host); + this.firewall.setAllowedHostnames(Arrays.asList("myexample.org")); + + try { + this.firewall.getFirewalledRequest(this.request); + fail("Host " + host + " was accepted"); + } catch (RequestRejectedException expected) { + } + } + + @Test + public void getFirewalledRequestWhenDefaultsThenNoException() { + String host = "example.org"; + this.request.addHeader("Host", host); + + try { + this.firewall.getFirewalledRequest(this.request); + } catch (RequestRejectedException fail) { + fail("Host " + host + " was rejected"); + } + } }