diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java index 9c27ea4e33b0..c3a639a7b3ad 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java @@ -16,17 +16,6 @@ package org.springframework.web.socket.sockjs.transport; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ScheduledFuture; - import org.springframework.context.Lifecycle; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; @@ -46,6 +35,17 @@ import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; import org.springframework.web.socket.sockjs.support.AbstractSockJsService; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ScheduledFuture; + /** * A basic implementation of {@link org.springframework.web.socket.sockjs.SockJsService} * with support for SPI-based transport handling and session management. @@ -322,7 +322,8 @@ else if (transportType.supportsCors()) { @Override protected boolean validateRequest(String serverId, String sessionId, String transport) { - if (!getAllowedOrigins().contains("*") && !TransportType.fromValue(transport).supportsOrigin()) { + TransportType transportType = TransportType.fromValue(transport); + if (!getAllowedOrigins().contains("*") && (transportType == null || !transportType.supportsOrigin())) { if (logger.isWarnEnabled()) { logger.warn("Origin check has been enabled, but transport " + transport + " does not support it"); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java index 6b07be824116..239e107a699e 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java @@ -16,19 +16,10 @@ package org.springframework.web.socket.sockjs.transport.handler; -import static org.junit.Assert.*; -import static org.mockito.BDDMockito.*; - -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Map; - import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; - import org.springframework.http.HttpHeaders; import org.springframework.scheduling.TaskScheduler; import org.springframework.web.socket.AbstractHttpRequestTests; @@ -43,6 +34,24 @@ import org.springframework.web.socket.sockjs.transport.session.StubSockJsServiceConfig; import org.springframework.web.socket.sockjs.transport.session.TestSockJsSession; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.mockito.BDDMockito.any; +import static org.mockito.BDDMockito.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.reset; +import static org.mockito.BDDMockito.verify; +import static org.mockito.BDDMockito.verifyNoMoreInteractions; + /** * Test fixture for {@link org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService}. * @@ -186,6 +195,18 @@ public void handleTransportRequestXhrSameOrigin() throws Exception { assertEquals(200, this.servletResponse.getStatus()); } + @Test // SPR-13545 + public void handleInvalidTransportType() throws Exception { + String sockJsPath = sessionUrlPrefix + "invalid"; + setRequest("POST", sockJsPrefix + sockJsPath); + this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); + this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com"); + this.servletRequest.setServerName("mydomain2.com"); + this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); + + assertEquals(404, this.servletResponse.getStatus()); + } + @Test public void handleTransportRequestXhrOptions() throws Exception { String sockJsPath = sessionUrlPrefix + "xhr";