Skip to content

Commit 5f21060

Browse files
committed
Add UpgradeRequestStrategy for WildFly/Undertow
Issue: SPR-11237
1 parent 5ee89a3 commit 5f21060

File tree

11 files changed

+342
-37
lines changed

11 files changed

+342
-37
lines changed

build.gradle

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,13 @@ project("spring-websocket") {
593593
exclude group: "javax.servlet", module: "javax.servlet"
594594
}
595595
optional("org.eclipse.jetty.websocket:websocket-client:${jettyVersion}")
596+
optional("io.undertow:undertow-core:1.0.0.Beta31")
597+
optional("io.undertow:undertow-servlet:1.0.0.Beta31") {
598+
exclude group: "org.jboss.spec.javax.servlet", module: "jboss-servlet-api_3.1_spec"
599+
}
600+
optional("io.undertow:undertow-websockets-jsr:1.0.0.Beta31") {
601+
exclude group: "org.jboss.spec.javax.websocket", module: "jboss-websocket-api_1.0_spec"
602+
}
596603
optional("com.fasterxml.jackson.core:jackson-databind:2.3.0")
597604
testCompile("org.apache.tomcat.embed:tomcat-embed-core:8.0.0-RC10")
598605
testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}")

spring-websocket/src/main/java/org/springframework/web/socket/server/standard/AbstractStandardUpgradeStrategy.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.Map;
2424
import javax.servlet.ServletContext;
2525
import javax.servlet.http.HttpServletRequest;
26+
import javax.servlet.http.HttpServletResponse;
2627
import javax.websocket.Endpoint;
2728
import javax.websocket.Extension;
2829
import javax.websocket.WebSocketContainer;
@@ -35,6 +36,7 @@
3536
import org.springframework.http.server.ServerHttpRequest;
3637
import org.springframework.http.server.ServerHttpResponse;
3738
import org.springframework.http.server.ServletServerHttpRequest;
39+
import org.springframework.http.server.ServletServerHttpResponse;
3840
import org.springframework.util.Assert;
3941
import org.springframework.web.socket.WebSocketExtension;
4042
import org.springframework.web.socket.WebSocketHandler;
@@ -85,6 +87,16 @@ protected List<WebSocketExtension> getInstalledExtensions(WebSocketContainer con
8587
return result;
8688
}
8789

90+
protected final HttpServletResponse getHttpServletResponse(ServerHttpResponse response) {
91+
Assert.isTrue(response instanceof ServletServerHttpResponse);
92+
return ((ServletServerHttpResponse) response).getServletResponse();
93+
}
94+
95+
protected final HttpServletRequest getHttpServletRequest(ServerHttpRequest request) {
96+
Assert.isTrue(request instanceof ServletServerHttpRequest);
97+
return ((ServletServerHttpRequest) request).getServletRequest();
98+
}
99+
88100
@Override
89101
public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
90102
String selectedProtocol, List<WebSocketExtension> selectedExtensions, Principal user,

spring-websocket/src/main/java/org/springframework/web/socket/server/standard/GlassFishRequestUpgradeStrategy.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2013 the original author or authors.
2+
* Copyright 2002-2014 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -45,9 +45,6 @@
4545
import org.springframework.http.HttpHeaders;
4646
import org.springframework.http.server.ServerHttpRequest;
4747
import org.springframework.http.server.ServerHttpResponse;
48-
import org.springframework.http.server.ServletServerHttpRequest;
49-
import org.springframework.http.server.ServletServerHttpResponse;
50-
import org.springframework.util.Assert;
5148
import org.springframework.util.ReflectionUtils;
5249
import org.springframework.util.StringUtils;
5350
import org.springframework.web.socket.WebSocketExtension;
@@ -120,11 +117,8 @@ public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse respon
120117
String selectedProtocol, List<Extension> selectedExtensions,
121118
Endpoint endpoint) throws HandshakeFailureException {
122119

123-
Assert.isTrue(request instanceof ServletServerHttpRequest);
124-
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
125-
126-
Assert.isTrue(response instanceof ServletServerHttpResponse);
127-
HttpServletResponse servletResponse = ((ServletServerHttpResponse) response).getServletResponse();
120+
HttpServletRequest servletRequest = getHttpServletRequest(request);
121+
HttpServletResponse servletResponse = getHttpServletResponse(response);
128122

129123
WebSocketApplication webSocketApplication = createTyrusEndpoint(endpoint, selectedProtocol, selectedExtensions);
130124

spring-websocket/src/main/java/org/springframework/web/socket/server/standard/TomcatRequestUpgradeStrategy.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2013 the original author or authors.
2+
* Copyright 2002-2014 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -31,9 +31,6 @@
3131

3232
import org.springframework.http.server.ServerHttpRequest;
3333
import org.springframework.http.server.ServerHttpResponse;
34-
import org.springframework.http.server.ServletServerHttpRequest;
35-
import org.springframework.http.server.ServletServerHttpResponse;
36-
import org.springframework.util.Assert;
3734
import org.springframework.web.socket.server.HandshakeFailureException;
3835

3936
/**
@@ -60,11 +57,8 @@ public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse respon
6057
String selectedProtocol, List<Extension> selectedExtensions,
6158
Endpoint endpoint) throws HandshakeFailureException {
6259

63-
Assert.isTrue(request instanceof ServletServerHttpRequest);
64-
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
65-
66-
Assert.isTrue(response instanceof ServletServerHttpResponse);
67-
HttpServletResponse servletResponse = ((ServletServerHttpResponse) response).getServletResponse();
60+
HttpServletRequest servletRequest = getHttpServletRequest(request);
61+
HttpServletResponse servletResponse = getHttpServletResponse(response);
6862

6963
StringBuffer requestUrl = servletRequest.getRequestURL();
7064
String path = servletRequest.getRequestURI(); // shouldn't matter
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
/*
2+
* Copyright 2002-2014 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.web.socket.server.standard;
18+
19+
import io.undertow.server.HttpServerExchange;
20+
import io.undertow.server.HttpUpgradeListener;
21+
import io.undertow.servlet.api.InstanceFactory;
22+
import io.undertow.servlet.api.InstanceHandle;
23+
import io.undertow.servlet.websockets.ServletWebSocketHttpExchange;
24+
import io.undertow.websockets.core.WebSocketChannel;
25+
import io.undertow.websockets.core.protocol.Handshake;
26+
import io.undertow.websockets.core.protocol.version07.Hybi07Handshake;
27+
import io.undertow.websockets.core.protocol.version08.Hybi08Handshake;
28+
import io.undertow.websockets.core.protocol.version13.Hybi13Handshake;
29+
import io.undertow.websockets.jsr.ConfiguredServerEndpoint;
30+
import io.undertow.websockets.jsr.EncodingFactory;
31+
import io.undertow.websockets.jsr.EndpointSessionHandler;
32+
import io.undertow.websockets.jsr.ServerWebSocketContainer;
33+
import io.undertow.websockets.jsr.handshake.HandshakeUtil;
34+
import org.springframework.http.server.ServerHttpRequest;
35+
import org.springframework.http.server.ServerHttpResponse;
36+
import org.springframework.web.socket.server.HandshakeFailureException;
37+
import org.xnio.StreamConnection;
38+
39+
import javax.servlet.http.HttpServletRequest;
40+
import javax.servlet.http.HttpServletResponse;
41+
import javax.websocket.Decoder;
42+
import javax.websocket.Encoder;
43+
import javax.websocket.Endpoint;
44+
import javax.websocket.Extension;
45+
import java.util.*;
46+
47+
48+
/**
49+
* A {@link org.springframework.web.socket.server.RequestUpgradeStrategy} for use
50+
* with WildFly and its underlying Undertow web server.
51+
*
52+
* @author Rossen Stoyanchev
53+
* @since 4.0.1
54+
*/
55+
public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy {
56+
57+
private final Handshake[] handshakes;
58+
59+
private final String[] supportedVersions;
60+
61+
62+
public UndertowRequestUpgradeStrategy() {
63+
this.handshakes = new Handshake[] { new Hybi13Handshake(), new Hybi08Handshake(), new Hybi07Handshake() };
64+
this.supportedVersions = initSupportedVersions(this.handshakes);
65+
}
66+
67+
private String[] initSupportedVersions(Handshake[] handshakes) {
68+
String[] versions = new String[handshakes.length];
69+
for (int i=0; i < versions.length; i++) {
70+
versions[i] = handshakes[i].getVersion().toHttpHeaderValue();
71+
}
72+
return versions;
73+
}
74+
75+
@Override
76+
public String[] getSupportedVersions() {
77+
return this.supportedVersions;
78+
}
79+
80+
@Override
81+
protected void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol,
82+
List<Extension> selectedExtensions, final Endpoint endpoint) throws HandshakeFailureException {
83+
84+
HttpServletRequest servletRequest = getHttpServletRequest(request);
85+
HttpServletResponse servletResponse = getHttpServletResponse(response);
86+
87+
final ServletWebSocketHttpExchange exchange = new ServletWebSocketHttpExchange(servletRequest, servletResponse);
88+
exchange.putAttachment(HandshakeUtil.PATH_PARAMS, Collections.<String, String>emptyMap());
89+
90+
ServerWebSocketContainer wsContainer = (ServerWebSocketContainer) getContainer(servletRequest);
91+
final EndpointSessionHandler endpointSessionHandler = new EndpointSessionHandler(wsContainer);
92+
93+
final Handshake handshake = getHandshakeToUse(exchange);
94+
95+
final ConfiguredServerEndpoint configuredServerEndpoint = createConfiguredServerEndpoint(
96+
selectedProtocol, selectedExtensions, endpoint, servletRequest);
97+
98+
exchange.upgradeChannel(new HttpUpgradeListener() {
99+
@Override
100+
public void handleUpgrade(StreamConnection connection, HttpServerExchange serverExchange) {
101+
WebSocketChannel channel = handshake.createChannel(exchange, connection, exchange.getBufferPool());
102+
HandshakeUtil.setConfig(channel, configuredServerEndpoint);
103+
endpointSessionHandler.onConnect(exchange, channel);
104+
}
105+
});
106+
107+
handshake.handshake(exchange);
108+
}
109+
110+
private Handshake getHandshakeToUse(ServletWebSocketHttpExchange exchange) {
111+
for (Handshake handshake : this.handshakes) {
112+
if (handshake.matches(exchange)) {
113+
return handshake;
114+
}
115+
}
116+
// Should never occur
117+
throw new HandshakeFailureException("No matching Undertow Handshake found: " + exchange.getRequestHeaders());
118+
}
119+
120+
private ConfiguredServerEndpoint createConfiguredServerEndpoint(String selectedProtocol,
121+
List<Extension> selectedExtensions, Endpoint endpoint, HttpServletRequest servletRequest) {
122+
123+
String path = servletRequest.getRequestURI(); // shouldn't matter
124+
ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration(path, endpoint);
125+
endpointRegistration.setSubprotocols(Arrays.asList(selectedProtocol));
126+
endpointRegistration.setExtensions(selectedExtensions);
127+
128+
return new ConfiguredServerEndpoint(endpointRegistration,
129+
new EndpointInstanceFactory(endpoint), null,
130+
new EncodingFactory(
131+
Collections.<Class<?>, List<InstanceFactory<? extends Encoder>>>emptyMap(),
132+
Collections.<Class<?>, List<InstanceFactory<? extends Decoder>>>emptyMap(),
133+
Collections.<Class<?>, List<InstanceFactory<? extends Encoder>>>emptyMap(),
134+
Collections.<Class<?>, List<InstanceFactory<? extends Decoder>>>emptyMap()));
135+
}
136+
137+
138+
private static class EndpointInstanceFactory implements InstanceFactory<Endpoint> {
139+
140+
private final Endpoint endpoint;
141+
142+
public EndpointInstanceFactory(Endpoint endpoint) {
143+
this.endpoint = endpoint;
144+
}
145+
146+
@Override
147+
public InstanceHandle<Endpoint> createInstance() throws InstantiationException {
148+
149+
return new InstanceHandle<Endpoint>() {
150+
@Override
151+
public Endpoint getInstance() {
152+
return endpoint;
153+
}
154+
@Override
155+
public void release() {
156+
}
157+
};
158+
}
159+
}
160+
161+
}

spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,17 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
6262

6363
protected Log logger = LogFactory.getLog(getClass());
6464

65+
private static final boolean glassFishWsPresent = ClassUtils.isPresent(
66+
"org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", DefaultHandshakeHandler.class.getClassLoader());
67+
6568
private static final boolean jettyWsPresent = ClassUtils.isPresent(
6669
"org.eclipse.jetty.websocket.server.WebSocketServerFactory", DefaultHandshakeHandler.class.getClassLoader());
6770

6871
private static final boolean tomcatWsPresent = ClassUtils.isPresent(
6972
"org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", DefaultHandshakeHandler.class.getClassLoader());
7073

71-
private static final boolean glassFishWsPresent = ClassUtils.isPresent(
72-
"org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", DefaultHandshakeHandler.class.getClassLoader());
74+
private static final boolean undertowWsPresent = ClassUtils.isPresent(
75+
"io.undertow.websockets.jsr.ServerWebSocketContainer", DefaultHandshakeHandler.class.getClassLoader());
7376

7477

7578
private final RequestUpgradeStrategy requestUpgradeStrategy;
@@ -97,6 +100,9 @@ else if (tomcatWsPresent) {
97100
else if (glassFishWsPresent) {
98101
className = "org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy";
99102
}
103+
else if (undertowWsPresent) {
104+
className = "org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy";
105+
}
100106
else {
101107
throw new IllegalStateException("No suitable default RequestUpgradeStrategy found");
102108
}

spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
/*
2+
* Copyright 2002-2014 the original author or authors.
23
*
34
* Licensed under the Apache License, Version 2.0 (the "License");
45
* you may not use this file except in compliance with the License.
@@ -29,6 +30,7 @@
2930
import org.springframework.util.concurrent.ListenableFuture;
3031
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
3132
import org.springframework.web.socket.client.WebSocketClient;
33+
import org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy;
3234
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
3335
import org.springframework.web.socket.server.RequestUpgradeStrategy;
3436
import org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy;
@@ -48,6 +50,7 @@ public abstract class AbstractWebSocketIntegrationTests {
4850
static {
4951
upgradeStrategyConfigTypes.put(JettyWebSocketTestServer.class, JettyUpgradeStrategyConfig.class);
5052
upgradeStrategyConfigTypes.put(TomcatWebSocketTestServer.class, TomcatUpgradeStrategyConfig.class);
53+
upgradeStrategyConfigTypes.put(UndertowTestServer.class, UndertowUpgradeStrategyConfig.class);
5154
}
5255

5356
@Parameter(0)
@@ -141,4 +144,13 @@ public RequestUpgradeStrategy requestUpgradeStrategy() {
141144
}
142145
}
143146

147+
@Configuration
148+
static class UndertowUpgradeStrategyConfig extends AbstractRequestUpgradeStrategyConfig {
149+
150+
@Bean
151+
public RequestUpgradeStrategy requestUpgradeStrategy() {
152+
return new UndertowRequestUpgradeStrategy();
153+
}
154+
}
155+
144156
}

0 commit comments

Comments
 (0)