Skip to content

Commit 68ecb92

Browse files
committed
Allow "ws" and "wss" for isValidCorsOrigin checks
Issue: SPR-12956
1 parent 222f699 commit 68ecb92

File tree

7 files changed

+117
-148
lines changed

7 files changed

+117
-148
lines changed

spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,29 @@ public static UriComponentsBuilder fromHttpRequest(HttpRequest request) {
317317
}
318318

319319

320+
/**
321+
* Create an instance by parsing the "origin" header of an HTTP request.
322+
*/
323+
public static UriComponentsBuilder fromOriginHeader(String origin) {
324+
UriComponentsBuilder builder = UriComponentsBuilder.newInstance();
325+
if (StringUtils.hasText(origin)) {
326+
int schemaIdx = origin.indexOf("://");
327+
String schema = (schemaIdx != -1 ? origin.substring(0, schemaIdx) : "http");
328+
builder.scheme(schema);
329+
String hostString = (schemaIdx != -1 ? origin.substring(schemaIdx + 3) : origin);
330+
if (hostString.contains(":")) {
331+
String[] hostAndPort = StringUtils.split(hostString, ":");
332+
builder.host(hostAndPort[0]);
333+
builder.port(Integer.parseInt(hostAndPort[1]));
334+
}
335+
else {
336+
builder.host(hostString);
337+
}
338+
}
339+
return builder;
340+
}
341+
342+
320343
// build methods
321344

322345
/**

spring-web/src/main/java/org/springframework/web/util/WebUtils.java

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.Map;
2424
import java.util.StringTokenizer;
2525
import java.util.TreeMap;
26+
2627
import javax.servlet.ServletContext;
2728
import javax.servlet.ServletRequest;
2829
import javax.servlet.ServletRequestWrapper;
@@ -38,6 +39,7 @@
3839

3940
import org.springframework.http.HttpRequest;
4041
import org.springframework.util.Assert;
42+
import org.springframework.util.CollectionUtils;
4143
import org.springframework.util.LinkedMultiValueMap;
4244
import org.springframework.util.MultiValueMap;
4345
import org.springframework.util.StringUtils;
@@ -790,21 +792,10 @@ public static boolean isValidOrigin(HttpRequest request, Collection<String> allo
790792
if (origin == null || allowedOrigins.contains("*")) {
791793
return true;
792794
}
793-
else if (allowedOrigins.isEmpty()) {
794-
UriComponents originComponents;
795-
try {
796-
originComponents = UriComponentsBuilder.fromHttpUrl(origin).build();
797-
}
798-
catch (IllegalArgumentException ex) {
799-
if (logger.isWarnEnabled()) {
800-
logger.warn("Failed to parse Origin header value [" + origin + "]");
801-
}
802-
return false;
803-
}
804-
UriComponents requestComponents = UriComponentsBuilder.fromHttpRequest(request).build();
805-
int originPort = getPort(originComponents);
806-
int requestPort = getPort(requestComponents);
807-
return (originComponents.getHost().equals(requestComponents.getHost()) && originPort == requestPort);
795+
else if (CollectionUtils.isEmpty(allowedOrigins)) {
796+
UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build();
797+
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
798+
return (actualUrl.getHost().equals(originUrl.getHost()) && getPort(actualUrl) == getPort(originUrl));
808799
}
809800
else {
810801
return allowedOrigins.contains(origin);
@@ -814,10 +805,10 @@ else if (allowedOrigins.isEmpty()) {
814805
private static int getPort(UriComponents component) {
815806
int port = component.getPort();
816807
if (port == -1) {
817-
if ("http".equals(component.getScheme())) {
808+
if ("http".equals(component.getScheme()) || "ws".equals(component.getScheme())) {
818809
port = 80;
819810
}
820-
else if ("https".equals(component.getScheme())) {
811+
else if ("https".equals(component.getScheme()) || "wss".equals(component.getScheme())) {
821812
port = 443;
822813
}
823814
}

spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java

Lines changed: 38 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
package org.springframework.web.util;
1818

19-
import java.util.ArrayList;
2019
import java.util.Arrays;
20+
import java.util.Collections;
2121
import java.util.HashMap;
2222
import java.util.List;
2323
import java.util.Map;
@@ -106,60 +106,45 @@ public void parseMatrixVariablesString() {
106106
}
107107

108108
@Test
109-
public void isValidOrigin() {
110-
List<String> allowedOrigins = new ArrayList<>();
109+
public void isValidOriginSuccess() {
110+
111+
List<String> allowed = Collections.emptyList();
112+
assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain1.com", allowed));
113+
assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain1.com:80", allowed));
114+
assertTrue(checkOrigin("mydomain1.com", 443, "https://mydomain1.com", allowed));
115+
assertTrue(checkOrigin("mydomain1.com", 443, "https://mydomain1.com:443", allowed));
116+
assertTrue(checkOrigin("mydomain1.com", 123, "http://mydomain1.com:123", allowed));
117+
assertTrue(checkOrigin("mydomain1.com", -1, "ws://mydomain1.com", allowed));
118+
assertTrue(checkOrigin("mydomain1.com", 443, "wss://mydomain1.com", allowed));
119+
120+
allowed = Collections.singletonList("*");
121+
assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain2.com", allowed));
122+
123+
allowed = Collections.singletonList("http://mydomain1.com");
124+
assertTrue(checkOrigin("mydomain2.com", -1, "http://mydomain1.com", allowed));
125+
}
126+
127+
@Test
128+
public void isValidOriginFailure() {
129+
130+
List<String> allowed = Collections.emptyList();
131+
assertFalse(checkOrigin("mydomain1.com", -1, "http://mydomain2.com", allowed));
132+
assertFalse(checkOrigin("mydomain1.com", -1, "https://mydomain1.com", allowed));
133+
assertFalse(checkOrigin("mydomain1.com", -1, "invalid-origin", allowed));
134+
135+
allowed = Collections.singletonList("http://mydomain1.com");
136+
assertFalse(checkOrigin("mydomain2.com", -1, "http://mydomain3.com", allowed));
137+
}
138+
139+
private boolean checkOrigin(String serverName, int port, String originHeader, List<String> allowed) {
111140
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
112141
ServerHttpRequest request = new ServletServerHttpRequest(servletRequest);
113-
114-
servletRequest.setServerName("mydomain1.com");
115-
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com");
116-
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
117-
118-
servletRequest.setServerName("mydomain1.com");
119-
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com:80");
120-
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
121-
122-
servletRequest.setServerName("mydomain1.com");
123-
servletRequest.setServerPort(443);
124-
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com");
125-
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
126-
127-
servletRequest.setServerName("mydomain1.com");
128-
servletRequest.setServerPort(443);
129-
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com:443");
130-
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
131-
132-
servletRequest.setServerName("mydomain1.com");
133-
servletRequest.setServerPort(123);
134-
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com:123");
135-
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
136-
137-
servletRequest.setServerName("mydomain1.com");
138-
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain2.com");
139-
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
140-
141-
servletRequest.setServerName("mydomain1.com");
142-
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com");
143-
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
144-
145-
servletRequest.setServerName("invalid-origin");
146-
request.getHeaders().set(HttpHeaders.ORIGIN, "invalid-origin");
147-
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
148-
149-
allowedOrigins = Arrays.asList("*");
150-
servletRequest.setServerName("mydomain1.com");
151-
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain2.com");
152-
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
153-
154-
allowedOrigins = Arrays.asList("http://mydomain1.com");
155-
servletRequest.setServerName("mydomain2.com");
156-
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com");
157-
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
158-
159-
allowedOrigins = Arrays.asList("http://mydomain1.com");
160-
servletRequest.setServerName("mydomain2.com");
161-
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain3.com");
162-
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
142+
servletRequest.setServerName(serverName);
143+
if (port != -1) {
144+
servletRequest.setServerPort(port);
145+
}
146+
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
147+
return WebUtils.isValidOrigin(request, allowed);
163148
}
164149

165150
}

spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,22 +65,18 @@ public OriginHandshakeInterceptor(Collection<String> allowedOrigins) {
6565
}
6666

6767
/**
68-
* Configure allowed {@code Origin} header values. This check is mostly designed for
69-
* browser clients. There is nothing preventing other types of client to modify the
70-
* {@code Origin} header value.
68+
* Configure allowed {@code Origin} header values. This check is mostly
69+
* designed for browsers. There is nothing preventing other types of client
70+
* to modify the {@code Origin} header value.
7171
*
72-
* <p>Each provided allowed origin must start by "http://", "https://" or be "*"
73-
* (means that all origins are allowed).
72+
* <p>Each provided allowed origin must have a scheme, and optionally a port
73+
* (e.g. "http://example.org", "http://example.org:9090"). An allowed origin
74+
* string may also be "*" in which case all origins are allowed.
7475
*
7576
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
7677
*/
7778
public void setAllowedOrigins(Collection<String> allowedOrigins) {
7879
Assert.notNull(allowedOrigins, "Allowed origin Collection must not be null");
79-
for (String allowedOrigin : allowedOrigins) {
80-
Assert.isTrue(allowedOrigin.equals("*") || allowedOrigin.startsWith("http://") ||
81-
allowedOrigin.startsWith("https://"), "Invalid allowed origin provided: \"" +
82-
allowedOrigin + "\". It must start with \"http://\", \"https://\" or be \"*\"");
83-
}
8480
this.allowedOrigins.clear();
8581
this.allowedOrigins.addAll(allowedOrigins);
8682
}
@@ -93,6 +89,7 @@ public Collection<String> getAllowedOrigins() {
9389
return Collections.unmodifiableList(this.allowedOrigins);
9490
}
9591

92+
9693
@Override
9794
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
9895
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -276,31 +276,25 @@ public boolean isWebSocketEnabled() {
276276
}
277277

278278
/**
279-
* Configure allowed {@code Origin} header values. This check is mostly designed for
280-
* browser clients. There is nothing preventing other types of client to modify the
281-
* {@code Origin} header value.
279+
* Configure allowed {@code Origin} header values. This check is mostly
280+
* designed for browsers. There is nothing preventing other types of client
281+
* to modify the {@code Origin} header value.
282282
*
283-
* <p>When SockJS is enabled and origins are restricted, transport types that do not
284-
* allow to check request origin (JSONP and Iframe based transports) are disabled.
285-
* As a consequence, IE 6 to 9 are not supported when origins are restricted.
283+
* <p>When SockJS is enabled and origins are restricted, transport types
284+
* that do not allow to check request origin (JSONP and Iframe based
285+
* transports) are disabled. As a consequence, IE 6 to 9 are not supported
286+
* when origins are restricted.
286287
*
287-
* <p>Each provided allowed origin must start by "http://", "https://" or be "*"
288-
* (means that all origins are allowed).
288+
* <p>Each provided allowed origin must have a scheme, and optionally a port
289+
* (e.g. "http://example.org", "http://example.org:9090"). An allowed origin
290+
* string may also be "*" in which case all origins are allowed.
289291
*
290292
* @since 4.1.2
291293
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
292294
* @see <a href="https://github.com/sockjs/sockjs-client#supported-transports-by-browser-html-served-from-http-or-https">SockJS supported transports by browser</a>
293295
*/
294296
public void setAllowedOrigins(List<String> allowedOrigins) {
295297
Assert.notNull(allowedOrigins, "Allowed origin List must not be null");
296-
for (String allowedOrigin : allowedOrigins) {
297-
Assert.isTrue(
298-
allowedOrigin.equals("*") || allowedOrigin.startsWith("http://") ||
299-
allowedOrigin.startsWith("https://"),
300-
"Invalid allowed origin provided: \"" +
301-
allowedOrigin +
302-
"\". It must start with \"http://\", \"https://\" or be \"*\"");
303-
}
304298
this.allowedOrigins.clear();
305299
this.allowedOrigins.addAll(allowedOrigins);
306300
}
@@ -451,7 +445,9 @@ protected abstract void handleRawWebSocketRequest(ServerHttpRequest request,
451445
protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response,
452446
WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException;
453447

454-
protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) throws IOException {
448+
protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response,
449+
HttpMethod... httpMethods) throws IOException {
450+
455451
String origin = request.getHeaders().getOrigin();
456452

457453
if (origin == null) {
@@ -514,7 +510,8 @@ public void handle(ServerHttpRequest request, ServerHttpResponse response) throw
514510
addNoCacheHeaders(response);
515511
if (checkOrigin(request, response)) {
516512
response.getHeaders().setContentType(new MediaType("application", "json", UTF8_CHARSET));
517-
String content = String.format(INFO_CONTENT, random.nextInt(), isSessionCookieNeeded(), isWebSocketEnabled());
513+
String content = String.format(INFO_CONTENT, random.nextInt(),
514+
isSessionCookieNeeded(), isWebSocketEnabled());
518515
response.getBody().write(content.getBytes());
519516
}
520517

spring-websocket/src/test/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptorTests.java

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
package org.springframework.web.socket.server.support;
1818

1919
import java.util.Arrays;
20+
import java.util.Collections;
2021
import java.util.HashMap;
22+
import java.util.List;
2123
import java.util.Map;
2224
import java.util.Set;
2325
import java.util.concurrent.ConcurrentSkipListSet;
@@ -39,31 +41,17 @@
3941
public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
4042

4143
@Test(expected = IllegalArgumentException.class)
42-
public void nullAllowedOriginList() {
44+
public void invalidInput() {
4345
new OriginHandshakeInterceptor(null);
4446
}
4547

46-
@Test(expected = IllegalArgumentException.class)
47-
public void invalidAllowedOrigin() {
48-
new OriginHandshakeInterceptor(Arrays.asList("domain.com"));
49-
}
50-
51-
@Test
52-
public void emtpyAllowedOriginList() {
53-
new OriginHandshakeInterceptor(Arrays.asList());
54-
}
55-
56-
@Test
57-
public void validAllowedOrigins() {
58-
new OriginHandshakeInterceptor(Arrays.asList("http://domain.com", "https://domain.com", "*"));
59-
}
60-
6148
@Test
6249
public void originValueMatch() throws Exception {
6350
Map<String, Object> attributes = new HashMap<String, Object>();
6451
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
6552
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
66-
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com"));
53+
List<String> allowed = Collections.singletonList("http://mydomain1.com");
54+
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
6755
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
6856
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
6957
}
@@ -73,7 +61,8 @@ public void originValueNoMatch() throws Exception {
7361
Map<String, Object> attributes = new HashMap<String, Object>();
7462
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
7563
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
76-
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain2.com"));
64+
List<String> allowed = Collections.singletonList("http://mydomain2.com");
65+
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
7766
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
7867
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
7968
}
@@ -83,7 +72,8 @@ public void originListMatch() throws Exception {
8372
Map<String, Object> attributes = new HashMap<String, Object>();
8473
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
8574
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com");
86-
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
75+
List<String> allowed = Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com");
76+
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
8777
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
8878
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
8979
}
@@ -93,7 +83,8 @@ public void originListNoMatch() throws Exception {
9383
Map<String, Object> attributes = new HashMap<String, Object>();
9484
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
9585
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain4.com");
96-
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
86+
List<String> allowed = Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com");
87+
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
9788
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
9889
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
9990
}
@@ -117,7 +108,7 @@ public void originMatchAll() throws Exception {
117108
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
118109
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
119110
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
120-
interceptor.setAllowedOrigins(Arrays.asList("*"));
111+
interceptor.setAllowedOrigins(Collections.singletonList("*"));
121112
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
122113
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
123114
}
@@ -128,7 +119,7 @@ public void sameOriginMatch() throws Exception {
128119
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
129120
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com");
130121
this.servletRequest.setServerName("mydomain2.com");
131-
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList());
122+
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Collections.emptyList());
132123
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
133124
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
134125
}
@@ -139,7 +130,7 @@ public void sameOriginNoMatch() throws Exception {
139130
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
140131
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain3.com");
141132
this.servletRequest.setServerName("mydomain2.com");
142-
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList());
133+
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Collections.emptyList());
143134
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
144135
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
145136
}

0 commit comments

Comments
 (0)