Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support RFC 8441 WebSocket upgrade with HTTP/2 CONNECT #34362

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -205,23 +205,25 @@ public Mono<Void> handleRequest(ServerWebExchange exchange, WebSocketHandler han
HttpMethod method = request.getMethod();
HttpHeaders headers = request.getHeaders();

if (HttpMethod.GET != method && CONNECT_METHOD != method) {
if (HttpMethod.GET != method && !CONNECT_METHOD.equals(method)) {
return Mono.error(new MethodNotAllowedException(
request.getMethod(), Set.of(HttpMethod.GET, CONNECT_METHOD)));
}

if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers);
}
if (HttpMethod.GET == method) {
if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers);
}

List<String> connectionValue = headers.getConnection();
if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) {
return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers);
}
List<String> connectionValue = headers.getConnection();
if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) {
return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers);
}

String key = headers.getFirst(SEC_WEBSOCKET_KEY);
if (key == null) {
return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header");
String key = headers.getFirst(SEC_WEBSOCKET_KEY);
if (key == null) {
return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header");
}
}

String protocol = selectProtocol(headers, handler);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ public void setSecWebSocketProtocol(List<String> secWebSocketProtocols) {
}

/**
* Returns the value of the {@code Sec-WebSocket-Key} header.
* Returns the value of the {@code Sec-WebSocket-Protocol} header.
* @return the value of the header
*/
public List<String> getSecWebSocketProtocol() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
package org.springframework.web.socket.server;

import java.security.Principal;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.springframework.http.HttpMethod;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.lang.Nullable;
Expand All @@ -35,6 +38,25 @@
*/
public interface RequestUpgradeStrategy {

enum OpeningHandshake {
RFC6455(HttpMethod.GET),
RFC8441(HttpMethod.valueOf("CONNECT"));

private final HttpMethod method;

OpeningHandshake(HttpMethod method) {
this.method = method;
}

public HttpMethod getMethod() {
return method;
}
}

default Set<OpeningHandshake> getSupportedOpeningHandshake() {
return EnumSet.of(OpeningHandshake.RFC6455);
}

/**
* Return the supported WebSocket protocol versions.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import java.lang.reflect.UndeclaredThrowableException;
import java.security.Principal;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;

import jakarta.servlet.ServletContext;
Expand Down Expand Up @@ -59,6 +61,11 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Serv
private Consumer<Configurable> webSocketConfigurer;


@Override
public Set<OpeningHandshake> getSupportedOpeningHandshake() {
return EnumSet.of(OpeningHandshake.RFC6455, OpeningHandshake.RFC8441);
}

@Override
public String[] getSupportedVersions() {
return SUPPORTED_VERSIONS;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand All @@ -48,6 +49,7 @@
import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
import org.springframework.web.socket.server.RequestUpgradeStrategy.OpeningHandshake;
import org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy;
import org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy;
import org.springframework.web.socket.server.standard.StandardWebSocketUpgradeStrategy;
Expand Down Expand Up @@ -78,9 +80,6 @@
*/
public abstract class AbstractHandshakeHandler implements HandshakeHandler, Lifecycle {

// For WebSocket upgrades in HTTP/2 (see RFC 8441)
private static final HttpMethod CONNECT_METHOD = HttpMethod.valueOf("CONNECT");

private static final boolean tomcatWsPresent;

private static final boolean jettyWsPresent;
Expand Down Expand Up @@ -215,22 +214,43 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
}
try {
HttpMethod httpMethod = request.getMethod();
if (HttpMethod.GET != httpMethod && CONNECT_METHOD != httpMethod) {
Set<OpeningHandshake> supportedHandshakes = requestUpgradeStrategy.getSupportedOpeningHandshake();
OpeningHandshake handshake = null;
for (OpeningHandshake h : supportedHandshakes) {
if (h.getMethod().equals(httpMethod)) {
handshake = h;
break;
}
}

if (handshake == OpeningHandshake.RFC6455) {
if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
handleInvalidUpgradeHeader(request, response);
return false;
}
if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) {
handleInvalidConnectHeader(request, response);
return false;
}
String wsKey = headers.getSecWebSocketKey();
if (wsKey == null) {
if (logger.isErrorEnabled()) {
logger.error("Missing \"Sec-WebSocket-Key\" header");
}
response.setStatusCode(HttpStatus.BAD_REQUEST);
return false;
}
} else if (handshake == null) {
response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED);
response.getHeaders().setAllow(Set.of(HttpMethod.GET, CONNECT_METHOD));
Set<HttpMethod> methods = supportedHandshakes.stream()
.map(OpeningHandshake::getMethod)
.collect(Collectors.toUnmodifiableSet());
response.getHeaders().setAllow(methods);
if (logger.isErrorEnabled()) {
logger.error("Handshake failed due to unexpected HTTP method: " + httpMethod);
}
return false;
}
if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
handleInvalidUpgradeHeader(request, response);
return false;
}
if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) {
handleInvalidConnectHeader(request, response);
return false;
}
if (!isWebSocketVersionSupported(headers)) {
handleWebSocketVersionNotSupported(request, response);
return false;
Expand All @@ -239,14 +259,6 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
response.setStatusCode(HttpStatus.FORBIDDEN);
return false;
}
String wsKey = headers.getSecWebSocketKey();
if (wsKey == null) {
if (logger.isErrorEnabled()) {
logger.error("Missing \"Sec-WebSocket-Key\" header");
}
response.setStatusCode(HttpStatus.BAD_REQUEST);
return false;
}
}
catch (IOException ex) {
throw new HandshakeFailureException(
Expand Down