Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CORS throwing 500 upon encountering a malformed URL #33682 #33688

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.springframework.http.HttpMethod;
import org.springframework.lang.Nullable;
import org.springframework.util.ObjectUtils;
import org.springframework.web.util.InvalidUrlException;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;

Expand All @@ -30,27 +31,35 @@
* <a href="https://www.w3.org/TR/cors/">CORS W3C recommendation</a>.
*
* @author Sebastien Deleuze
* @author Igor Durbek
* @since 4.2
*/
public abstract class CorsUtils {

/**
* Returns {@code true} if the request is a valid CORS one by checking {@code Origin}
* header presence and ensuring that origins are different.
* Returns {@code IsCorsRequestResult.IS_CORS_REQUEST} if the request is a valid CORS one by checking {@code Origin}
* header presence and ensuring that origins are different. Returns {@code IsCorsRequestResult.IS_NOT_CORS_REQUEST}
* otherwise, or {@code IsCorsRequestResult.MALFORMED_ORIGIN} if the origin url is malformed.
*/
public static boolean isCorsRequest(HttpServletRequest request) {
public static IsCorsRequestResult isCorsRequest(HttpServletRequest request) {
IgorOffline marked this conversation as resolved.
Show resolved Hide resolved

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it OK to introduce a breaking public API change for this bugfix? If so, it's probably worth documenting in the release notes.

String origin = request.getHeader(HttpHeaders.ORIGIN);
if (origin == null) {
return false;
return IsCorsRequestResult.IS_NOT_CORS_REQUEST;
}
try {
UriComponentsBuilder.fromUriString(origin);
}
catch (InvalidUrlException ex) {
return IsCorsRequestResult.MALFORMED_ORIGIN;
}
UriComponents originUrl = UriComponentsBuilder.fromUriString(origin).build();
String scheme = request.getScheme();
String host = request.getServerName();
int port = request.getServerPort();
return !(ObjectUtils.nullSafeEquals(scheme, originUrl.getScheme()) &&
boolean isCorsRequest = !(ObjectUtils.nullSafeEquals(scheme, originUrl.getScheme()) &&
ObjectUtils.nullSafeEquals(host, originUrl.getHost()) &&
getPort(scheme, port) == getPort(originUrl.getScheme(), originUrl.getPort()));

return isCorsRequest ? IsCorsRequestResult.IS_CORS_REQUEST : IsCorsRequestResult.IS_NOT_CORS_REQUEST;
}

private static int getPort(@Nullable String scheme, int port) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,14 @@ public boolean processRequest(@Nullable CorsConfiguration config, HttpServletReq
response.addHeader(HttpHeaders.VARY, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS);
}

if (!CorsUtils.isCorsRequest(request)) {
IsCorsRequestResult isCorsRequestResult = CorsUtils.isCorsRequest(request);
IgorOffline marked this conversation as resolved.
Show resolved Hide resolved
if (isCorsRequestResult == IsCorsRequestResult.IS_NOT_CORS_REQUEST) {
return true;
}
else if (isCorsRequestResult == IsCorsRequestResult.MALFORMED_ORIGIN) {
rejectRequest(new ServletServerHttpResponse(response));
return false;
}

if (response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN) != null) {
logger.trace("Skip: response already contains \"Access-Control-Allow-Origin\"");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.web.cors;

/**
* Used to enumerate the CORS request result.
*
* @author Igor Durbek
* @since 6.2
*/
public enum IsCorsRequestResult {

/**
* Is CORS request.
*/
IS_CORS_REQUEST,

/**
* Is not a CORS request.
*/
IS_NOT_CORS_REQUEST,

/**
* Invalid origin - reject the request.
* See test for an example of a malformed origin request.
*/
MALFORMED_ORIGIN
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.cors.IsCorsRequestResult;
import org.springframework.web.util.InvalidUrlException;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;

Expand All @@ -31,17 +33,28 @@
* <a href="https://www.w3.org/TR/cors/">CORS W3C recommendation</a>.
*
* @author Sebastien Deleuze
* @author Igor Durbek
* @since 5.0
*/
public abstract class CorsUtils {

/**
* Returns {@code true} if the request is a valid CORS one by checking {@code Origin}
* header presence and ensuring that origins are different via {@link #isSameOrigin}.
* Returns {@code IsCorsRequestResult.IS_CORS_REQUEST} if the request is a valid CORS one by checking {@code Origin}
* header presence and ensuring that origins are different via {@link #isSameOrigin}. Returns
* {@code IsCorsRequestResult.IS_NOT_CORS_REQUEST} otherwise, or {@code IsCorsRequestResult.MALFORMED_ORIGIN}
* in case the origin url is malformed.
*/
@SuppressWarnings("deprecation")
public static boolean isCorsRequest(ServerHttpRequest request) {
return request.getHeaders().containsKey(HttpHeaders.ORIGIN) && !isSameOrigin(request);
public static IsCorsRequestResult isCorsRequest(ServerHttpRequest request) {
boolean containsOrigin = request.getHeaders().containsKey(HttpHeaders.ORIGIN);
try {
boolean hasDifferentOrigin = !isSameOrigin(request);
boolean originOk = containsOrigin && hasDifferentOrigin;
return originOk ? IsCorsRequestResult.IS_CORS_REQUEST : IsCorsRequestResult.IS_NOT_CORS_REQUEST;
}
catch (InvalidUrlException ex) {
return IsCorsRequestResult.MALFORMED_ORIGIN;
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.IsCorsRequestResult;
import org.springframework.web.server.ServerWebExchange;

/**
Expand All @@ -43,6 +44,7 @@
*
* @author Sebastien Deleuze
* @author Rossen Stoyanchev
* @author Igor Durbek
* @since 5.0
*/
public class DefaultCorsProcessor implements CorsProcessor {
Expand Down Expand Up @@ -83,9 +85,14 @@ public boolean process(@Nullable CorsConfiguration config, ServerWebExchange exc
}
}

if (!CorsUtils.isCorsRequest(request)) {
IsCorsRequestResult isCorsRequestResult = CorsUtils.isCorsRequest(request);
if (isCorsRequestResult == IsCorsRequestResult.IS_NOT_CORS_REQUEST) {
return true;
}
else if (isCorsRequestResult == IsCorsRequestResult.MALFORMED_ORIGIN) {
rejectRequest(response);
return false;
}

if (responseHeaders.getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN) != null) {
logger.trace("Skip: response already contains \"Access-Control-Allow-Origin\"");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,20 @@ class CorsUtilsTests {
void isCorsRequest() {
MockHttpServletRequest request = new MockHttpServletRequest();
request.addHeader(HttpHeaders.ORIGIN, "https://domain.com");
assertThat(CorsUtils.isCorsRequest(request)).isTrue();
assertThat(CorsUtils.isCorsRequest(request)).isEqualTo(IsCorsRequestResult.IS_CORS_REQUEST);
}

@Test
void isNotCorsRequest() {
MockHttpServletRequest request = new MockHttpServletRequest();
assertThat(CorsUtils.isCorsRequest(request)).isFalse();
assertThat(CorsUtils.isCorsRequest(request)).isEqualTo(IsCorsRequestResult.IS_NOT_CORS_REQUEST);
}

@Test
void isMalformedUrlCorsRequest() {
MockHttpServletRequest request = new MockHttpServletRequest();
request.addHeader(HttpHeaders.ORIGIN, "http://*@:;");
assertThat(CorsUtils.isCorsRequest(request)).isEqualTo(IsCorsRequestResult.MALFORMED_ORIGIN);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class DefaultCorsProcessorTests {
void setup() {
this.request = new MockHttpServletRequest();
this.request.setRequestURI("/test.html");
this.request.setServerName("domain1.example");
this.request.setServerName("domain1.example.com");
this.conf = new CorsConfiguration();
this.response = new MockHttpServletResponse();
this.response.setStatus(HttpServletResponse.SC_OK);
Expand All @@ -73,7 +73,7 @@ void requestWithoutOriginHeader() throws Exception {
@Test
void sameOriginRequest() throws Exception {
this.request.setMethod(HttpMethod.GET.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain1.example");
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain1.example.com");

this.processor.processRequest(this.conf, this.request, this.response);
assertThat(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)).isFalse();
Expand All @@ -82,6 +82,18 @@ void sameOriginRequest() throws Exception {
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
}

@Test
void invalidOriginRequest() throws Exception {
this.request.setMethod(HttpMethod.GET.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain1.example");

this.processor.processRequest(this.conf, this.request, this.response);
assertThat(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)).isFalse();
assertThat(this.response.getHeaders(HttpHeaders.VARY)).contains(HttpHeaders.ORIGIN,
HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
}

@Test
void actualRequestWithOriginHeader() throws Exception {
this.request.setMethod(HttpMethod.GET.name());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.web.cors.IsCorsRequestResult;
import org.springframework.web.server.adapter.ForwardedHeaderTransformer;
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest;
import org.springframework.web.testfixture.server.MockServerWebExchange;
Expand All @@ -39,13 +40,19 @@ class CorsUtilsTests {
@Test
void isCorsRequest() {
ServerHttpRequest request = get("http://domain.example/").header(HttpHeaders.ORIGIN, "https://domain.com").build();
assertThat(CorsUtils.isCorsRequest(request)).isTrue();
assertThat(CorsUtils.isCorsRequest(request)).isEqualTo(IsCorsRequestResult.IS_CORS_REQUEST);
}

@Test
void isNotCorsRequest() {
ServerHttpRequest request = get("/").build();
assertThat(CorsUtils.isCorsRequest(request)).isFalse();
assertThat(CorsUtils.isCorsRequest(request)).isEqualTo(IsCorsRequestResult.IS_NOT_CORS_REQUEST);
}

@Test
void isMalformedOriginCorsRequest() {
ServerHttpRequest request = get("http://example.com").header(HttpHeaders.ORIGIN, "http://*@:;").build();
assertThat(CorsUtils.isCorsRequest(request)).isEqualTo(IsCorsRequestResult.MALFORMED_ORIGIN);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import java.util.stream.Stream;

import jakarta.servlet.http.HttpServletRequest;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Named;

import org.springframework.beans.DirectFieldAccessor;
Expand Down Expand Up @@ -92,11 +91,9 @@ static Stream<Named<TestRequestMappingInfoHandlerMapping>> pathPatternsArguments
return Stream.of(named("PathPatternParser", mapping1), named("AntPathMatcher", mapping2));
}


private final MockHttpServletRequest request = new MockHttpServletRequest();


@BeforeEach
void setup() {
this.request.setMethod("GET");
this.request.addHeader(HttpHeaders.ORIGIN, "https://domain.com/");
Expand Down Expand Up @@ -135,6 +132,7 @@ void noAnnotationWithPreflightRequest(TestRequestMappingInfoHandlerMapping mappi

@PathPatternsParameterizedTest // SPR-12931
void noAnnotationWithOrigin(TestRequestMappingInfoHandlerMapping mapping) throws Exception {
setup();
mapping.registerHandler(new MethodLevelController());
this.request.setRequestURI("/no");
HandlerExecutionChain chain = mapping.getHandler(request);
Expand All @@ -152,6 +150,7 @@ void noAnnotationPostWithOrigin(TestRequestMappingInfoHandlerMapping mapping) th

@PathPatternsParameterizedTest
void defaultAnnotation(TestRequestMappingInfoHandlerMapping mapping) throws Exception {
setup();
mapping.registerHandler(new MethodLevelController());
this.request.setRequestURI("/default");
HandlerExecutionChain chain = mapping.getHandler(request);
Expand All @@ -167,6 +166,7 @@ void defaultAnnotation(TestRequestMappingInfoHandlerMapping mapping) throws Exce

@PathPatternsParameterizedTest
void customized(TestRequestMappingInfoHandlerMapping mapping) throws Exception {
setup();
mapping.registerHandler(new MethodLevelController());
this.request.setRequestURI("/customized");
HandlerExecutionChain chain = mapping.getHandler(request);
Expand Down Expand Up @@ -250,6 +250,7 @@ void allowCredentialsWithWildcardOrigin(TestRequestMappingInfoHandlerMapping map

@PathPatternsParameterizedTest
void classLevel(TestRequestMappingInfoHandlerMapping mapping) throws Exception {
setup();
mapping.registerHandler(new ClassLevelController());

this.request.setRequestURI("/foo");
Expand Down Expand Up @@ -280,6 +281,7 @@ void classLevel(TestRequestMappingInfoHandlerMapping mapping) throws Exception {

@PathPatternsParameterizedTest // SPR-13468
void classLevelComposedAnnotation(TestRequestMappingInfoHandlerMapping mapping) throws Exception {
setup();
mapping.registerHandler(new ClassLevelMappingWithComposedAnnotation());

this.request.setRequestURI("/foo");
Expand All @@ -293,6 +295,7 @@ void classLevelComposedAnnotation(TestRequestMappingInfoHandlerMapping mapping)

@PathPatternsParameterizedTest // SPR-13468
void methodLevelComposedAnnotation(TestRequestMappingInfoHandlerMapping mapping) throws Exception {
setup();
mapping.registerHandler(new MethodLevelMappingWithComposedAnnotation());

this.request.setRequestURI("/foo");
Expand All @@ -306,6 +309,7 @@ void methodLevelComposedAnnotation(TestRequestMappingInfoHandlerMapping mapping)

@PathPatternsParameterizedTest
void preFlightRequest(TestRequestMappingInfoHandlerMapping mapping) throws Exception {
setup();
mapping.registerHandler(new MethodLevelController());
this.request.setMethod("OPTIONS");
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
Expand All @@ -323,6 +327,7 @@ void preFlightRequest(TestRequestMappingInfoHandlerMapping mapping) throws Excep

@PathPatternsParameterizedTest
void ambiguousHeaderPreFlightRequest(TestRequestMappingInfoHandlerMapping mapping) throws Exception {
setup();
mapping.registerHandler(new MethodLevelController());
this.request.setMethod("OPTIONS");
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
Expand All @@ -342,6 +347,7 @@ void ambiguousHeaderPreFlightRequest(TestRequestMappingInfoHandlerMapping mappin

@PathPatternsParameterizedTest
void ambiguousProducesPreFlightRequest(TestRequestMappingInfoHandlerMapping mapping) throws Exception {
setup();
mapping.registerHandler(new MethodLevelController());
this.request.setMethod("OPTIONS");
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
Expand All @@ -367,6 +373,7 @@ void preFlightRequestWithoutRequestMethodHeader(TestRequestMappingInfoHandlerMap

@PathPatternsParameterizedTest
void maxAgeWithDefaultOrigin(TestRequestMappingInfoHandlerMapping mapping) throws Exception {
setup();
mapping.registerHandler(new MaxAgeWithDefaultOriginController());

this.request.setRequestURI("/classAge");
Expand Down