From 6c95dc6e4776f05dcf63dc096de2c228abbc2243 Mon Sep 17 00:00:00 2001
From: Jared Wiltshire <jazdw@users.noreply.github.com>
Date: Mon, 3 Feb 2025 16:45:47 -0700
Subject: [PATCH 1/2] Fix HTTP/2 CONNECT WebSocket upgrades (RFC 8441)

Closes gh-34362

Signed-off-by: Jared Wiltshire <jazdw@users.noreply.github.com>
---
 .../support/HandshakeWebSocketService.java    | 24 +++++++-------
 .../web/socket/WebSocketHttpHeaders.java      |  2 +-
 .../support/AbstractHandshakeHandler.java     | 32 +++++++++++--------
 3 files changed, 32 insertions(+), 26 deletions(-)

diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java
index c54f38d9bc5b..d7577a02517e 100644
--- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java
+++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java
@@ -205,23 +205,25 @@ public Mono<Void> handleRequest(ServerWebExchange exchange, WebSocketHandler han
 		HttpMethod method = request.getMethod();
 		HttpHeaders headers = request.getHeaders();
 
-		if (HttpMethod.GET != method && CONNECT_METHOD != method) {
+		if (HttpMethod.GET != method && !CONNECT_METHOD.equals(method)) {
 			return Mono.error(new MethodNotAllowedException(
 					request.getMethod(), Set.of(HttpMethod.GET, CONNECT_METHOD)));
 		}
 
-		if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
-			return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers);
-		}
+		if (HttpMethod.GET == method) {
+			if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
+				return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers);
+			}
 
-		List<String> connectionValue = headers.getConnection();
-		if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) {
-			return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers);
-		}
+			List<String> connectionValue = headers.getConnection();
+			if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) {
+				return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers);
+			}
 
-		String key = headers.getFirst(SEC_WEBSOCKET_KEY);
-		if (key == null) {
-			return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header");
+			String key = headers.getFirst(SEC_WEBSOCKET_KEY);
+			if (key == null) {
+				return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header");
+			}
 		}
 
 		String protocol = selectProtocol(headers, handler);
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java
index fa4c9037b83c..1a4fac7f881e 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java
@@ -151,7 +151,7 @@ public void setSecWebSocketProtocol(List<String> secWebSocketProtocols) {
 	}
 
 	/**
-	 * Returns the value of the {@code Sec-WebSocket-Key} header.
+	 * Returns the value of the {@code Sec-WebSocket-Protocol} header.
 	 * @return the value of the header
 	 */
 	public List<String> getSecWebSocketProtocol() {
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java
index acde43c3cc59..fce20644c16f 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java
@@ -215,7 +215,7 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
 		}
 		try {
 			HttpMethod httpMethod = request.getMethod();
-			if (HttpMethod.GET != httpMethod && CONNECT_METHOD != httpMethod) {
+			if (HttpMethod.GET != httpMethod && !CONNECT_METHOD.equals(httpMethod)) {
 				response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED);
 				response.getHeaders().setAllow(Set.of(HttpMethod.GET, CONNECT_METHOD));
 				if (logger.isErrorEnabled()) {
@@ -223,13 +223,15 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
 				}
 				return false;
 			}
-			if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
-				handleInvalidUpgradeHeader(request, response);
-				return false;
-			}
-			if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) {
-				handleInvalidConnectHeader(request, response);
-				return false;
+			if (HttpMethod.GET == httpMethod) {
+				if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
+					handleInvalidUpgradeHeader(request, response);
+					return false;
+				}
+				if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) {
+					handleInvalidConnectHeader(request, response);
+					return false;
+				}
 			}
 			if (!isWebSocketVersionSupported(headers)) {
 				handleWebSocketVersionNotSupported(request, response);
@@ -239,13 +241,15 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
 				response.setStatusCode(HttpStatus.FORBIDDEN);
 				return false;
 			}
-			String wsKey = headers.getSecWebSocketKey();
-			if (wsKey == null) {
-				if (logger.isErrorEnabled()) {
-					logger.error("Missing \"Sec-WebSocket-Key\" header");
+			if (HttpMethod.GET == httpMethod) {
+				String wsKey = headers.getSecWebSocketKey();
+				if (wsKey == null) {
+					if (logger.isErrorEnabled()) {
+						logger.error("Missing \"Sec-WebSocket-Key\" header");
+					}
+					response.setStatusCode(HttpStatus.BAD_REQUEST);
+					return false;
 				}
-				response.setStatusCode(HttpStatus.BAD_REQUEST);
-				return false;
 			}
 		}
 		catch (IOException ex) {

From 695937718420752b6eb7ae156ef3ca3306573b53 Mon Sep 17 00:00:00 2001
From: Jared Wiltshire <jazdw@users.noreply.github.com>
Date: Wed, 5 Feb 2025 16:04:09 -0700
Subject: [PATCH 2/2] Only attempt RFC8441 upgrade if we know the strategy
 supports it

Signed-off-by: Jared Wiltshire <jazdw@users.noreply.github.com>
---
 .../socket/server/RequestUpgradeStrategy.java | 22 +++++++++
 .../jetty/JettyRequestUpgradeStrategy.java    |  7 +++
 .../support/AbstractHandshakeHandler.java     | 48 +++++++++++--------
 3 files changed, 57 insertions(+), 20 deletions(-)

diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java
index 288193de2efa..be2e308f7ccb 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java
@@ -17,9 +17,12 @@
 package org.springframework.web.socket.server;
 
 import java.security.Principal;
+import java.util.EnumSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
+import org.springframework.http.HttpMethod;
 import org.springframework.http.server.ServerHttpRequest;
 import org.springframework.http.server.ServerHttpResponse;
 import org.springframework.lang.Nullable;
@@ -35,6 +38,25 @@
  */
 public interface RequestUpgradeStrategy {
 
+	enum OpeningHandshake {
+		RFC6455(HttpMethod.GET),
+		RFC8441(HttpMethod.valueOf("CONNECT"));
+
+		private final HttpMethod method;
+
+		OpeningHandshake(HttpMethod method) {
+			this.method = method;
+		}
+
+		public HttpMethod getMethod() {
+			return method;
+		}
+	}
+
+	default Set<OpeningHandshake> getSupportedOpeningHandshake() {
+		return EnumSet.of(OpeningHandshake.RFC6455);
+	}
+
 	/**
 	 * Return the supported WebSocket protocol versions.
 	 */
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java
index 2ed4542e3111..04dbd16f7f70 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java
@@ -19,8 +19,10 @@
 import java.lang.reflect.UndeclaredThrowableException;
 import java.security.Principal;
 import java.util.Collections;
+import java.util.EnumSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.function.Consumer;
 
 import jakarta.servlet.ServletContext;
@@ -59,6 +61,11 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Serv
 	private Consumer<Configurable> webSocketConfigurer;
 
 
+	@Override
+	public Set<OpeningHandshake> getSupportedOpeningHandshake() {
+		return EnumSet.of(OpeningHandshake.RFC6455, OpeningHandshake.RFC8441);
+	}
+
 	@Override
 	public String[] getSupportedVersions() {
 		return SUPPORTED_VERSIONS;
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java
index fce20644c16f..babfe33995a2 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java
@@ -26,6 +26,7 @@
 import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -48,6 +49,7 @@
 import org.springframework.web.socket.server.HandshakeFailureException;
 import org.springframework.web.socket.server.HandshakeHandler;
 import org.springframework.web.socket.server.RequestUpgradeStrategy;
+import org.springframework.web.socket.server.RequestUpgradeStrategy.OpeningHandshake;
 import org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy;
 import org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy;
 import org.springframework.web.socket.server.standard.StandardWebSocketUpgradeStrategy;
@@ -78,9 +80,6 @@
  */
 public abstract class AbstractHandshakeHandler implements HandshakeHandler, Lifecycle {
 
-	// For WebSocket upgrades in HTTP/2 (see RFC 8441)
-	private static final HttpMethod CONNECT_METHOD = HttpMethod.valueOf("CONNECT");
-
 	private static final boolean tomcatWsPresent;
 
 	private static final boolean jettyWsPresent;
@@ -215,15 +214,16 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
 		}
 		try {
 			HttpMethod httpMethod = request.getMethod();
-			if (HttpMethod.GET != httpMethod && !CONNECT_METHOD.equals(httpMethod)) {
-				response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED);
-				response.getHeaders().setAllow(Set.of(HttpMethod.GET, CONNECT_METHOD));
-				if (logger.isErrorEnabled()) {
-					logger.error("Handshake failed due to unexpected HTTP method: " + httpMethod);
+			Set<OpeningHandshake> supportedHandshakes = requestUpgradeStrategy.getSupportedOpeningHandshake();
+			OpeningHandshake handshake = null;
+			for (OpeningHandshake h : supportedHandshakes) {
+				if (h.getMethod().equals(httpMethod)) {
+					handshake = h;
+					break;
 				}
-				return false;
 			}
-			if (HttpMethod.GET == httpMethod) {
+
+			if (handshake == OpeningHandshake.RFC6455) {
 				if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
 					handleInvalidUpgradeHeader(request, response);
 					return false;
@@ -232,16 +232,6 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
 					handleInvalidConnectHeader(request, response);
 					return false;
 				}
-			}
-			if (!isWebSocketVersionSupported(headers)) {
-				handleWebSocketVersionNotSupported(request, response);
-				return false;
-			}
-			if (!isValidOrigin(request)) {
-				response.setStatusCode(HttpStatus.FORBIDDEN);
-				return false;
-			}
-			if (HttpMethod.GET == httpMethod) {
 				String wsKey = headers.getSecWebSocketKey();
 				if (wsKey == null) {
 					if (logger.isErrorEnabled()) {
@@ -250,6 +240,24 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
 					response.setStatusCode(HttpStatus.BAD_REQUEST);
 					return false;
 				}
+			} else if (handshake == null) {
+				response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED);
+				Set<HttpMethod> methods = supportedHandshakes.stream()
+						.map(OpeningHandshake::getMethod)
+						.collect(Collectors.toUnmodifiableSet());
+				response.getHeaders().setAllow(methods);
+				if (logger.isErrorEnabled()) {
+					logger.error("Handshake failed due to unexpected HTTP method: " + httpMethod);
+				}
+				return false;
+			}
+			if (!isWebSocketVersionSupported(headers)) {
+				handleWebSocketVersionNotSupported(request, response);
+				return false;
+			}
+			if (!isValidOrigin(request)) {
+				response.setStatusCode(HttpStatus.FORBIDDEN);
+				return false;
 			}
 		}
 		catch (IOException ex) {