From 6c95dc6e4776f05dcf63dc096de2c228abbc2243 Mon Sep 17 00:00:00 2001 From: Jared Wiltshire <jazdw@users.noreply.github.com> Date: Mon, 3 Feb 2025 16:45:47 -0700 Subject: [PATCH 1/2] Fix HTTP/2 CONNECT WebSocket upgrades (RFC 8441) Closes gh-34362 Signed-off-by: Jared Wiltshire <jazdw@users.noreply.github.com> --- .../support/HandshakeWebSocketService.java | 24 +++++++------- .../web/socket/WebSocketHttpHeaders.java | 2 +- .../support/AbstractHandshakeHandler.java | 32 +++++++++++-------- 3 files changed, 32 insertions(+), 26 deletions(-) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java index c54f38d9bc5b..d7577a02517e 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java @@ -205,23 +205,25 @@ public Mono<Void> handleRequest(ServerWebExchange exchange, WebSocketHandler han HttpMethod method = request.getMethod(); HttpHeaders headers = request.getHeaders(); - if (HttpMethod.GET != method && CONNECT_METHOD != method) { + if (HttpMethod.GET != method && !CONNECT_METHOD.equals(method)) { return Mono.error(new MethodNotAllowedException( request.getMethod(), Set.of(HttpMethod.GET, CONNECT_METHOD))); } - if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { - return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers); - } + if (HttpMethod.GET == method) { + if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { + return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers); + } - List<String> connectionValue = headers.getConnection(); - if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) { - return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers); - } + List<String> connectionValue = headers.getConnection(); + if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) { + return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers); + } - String key = headers.getFirst(SEC_WEBSOCKET_KEY); - if (key == null) { - return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header"); + String key = headers.getFirst(SEC_WEBSOCKET_KEY); + if (key == null) { + return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header"); + } } String protocol = selectProtocol(headers, handler); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java index fa4c9037b83c..1a4fac7f881e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java @@ -151,7 +151,7 @@ public void setSecWebSocketProtocol(List<String> secWebSocketProtocols) { } /** - * Returns the value of the {@code Sec-WebSocket-Key} header. + * Returns the value of the {@code Sec-WebSocket-Protocol} header. * @return the value of the header */ public List<String> getSecWebSocketProtocol() { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java index acde43c3cc59..fce20644c16f 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java @@ -215,7 +215,7 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r } try { HttpMethod httpMethod = request.getMethod(); - if (HttpMethod.GET != httpMethod && CONNECT_METHOD != httpMethod) { + if (HttpMethod.GET != httpMethod && !CONNECT_METHOD.equals(httpMethod)) { response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED); response.getHeaders().setAllow(Set.of(HttpMethod.GET, CONNECT_METHOD)); if (logger.isErrorEnabled()) { @@ -223,13 +223,15 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r } return false; } - if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { - handleInvalidUpgradeHeader(request, response); - return false; - } - if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) { - handleInvalidConnectHeader(request, response); - return false; + if (HttpMethod.GET == httpMethod) { + if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { + handleInvalidUpgradeHeader(request, response); + return false; + } + if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) { + handleInvalidConnectHeader(request, response); + return false; + } } if (!isWebSocketVersionSupported(headers)) { handleWebSocketVersionNotSupported(request, response); @@ -239,13 +241,15 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r response.setStatusCode(HttpStatus.FORBIDDEN); return false; } - String wsKey = headers.getSecWebSocketKey(); - if (wsKey == null) { - if (logger.isErrorEnabled()) { - logger.error("Missing \"Sec-WebSocket-Key\" header"); + if (HttpMethod.GET == httpMethod) { + String wsKey = headers.getSecWebSocketKey(); + if (wsKey == null) { + if (logger.isErrorEnabled()) { + logger.error("Missing \"Sec-WebSocket-Key\" header"); + } + response.setStatusCode(HttpStatus.BAD_REQUEST); + return false; } - response.setStatusCode(HttpStatus.BAD_REQUEST); - return false; } } catch (IOException ex) { From 695937718420752b6eb7ae156ef3ca3306573b53 Mon Sep 17 00:00:00 2001 From: Jared Wiltshire <jazdw@users.noreply.github.com> Date: Wed, 5 Feb 2025 16:04:09 -0700 Subject: [PATCH 2/2] Only attempt RFC8441 upgrade if we know the strategy supports it Signed-off-by: Jared Wiltshire <jazdw@users.noreply.github.com> --- .../socket/server/RequestUpgradeStrategy.java | 22 +++++++++ .../jetty/JettyRequestUpgradeStrategy.java | 7 +++ .../support/AbstractHandshakeHandler.java | 48 +++++++++++-------- 3 files changed, 57 insertions(+), 20 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java index 288193de2efa..be2e308f7ccb 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java @@ -17,9 +17,12 @@ package org.springframework.web.socket.server; import java.security.Principal; +import java.util.EnumSet; import java.util.List; import java.util.Map; +import java.util.Set; +import org.springframework.http.HttpMethod; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.lang.Nullable; @@ -35,6 +38,25 @@ */ public interface RequestUpgradeStrategy { + enum OpeningHandshake { + RFC6455(HttpMethod.GET), + RFC8441(HttpMethod.valueOf("CONNECT")); + + private final HttpMethod method; + + OpeningHandshake(HttpMethod method) { + this.method = method; + } + + public HttpMethod getMethod() { + return method; + } + } + + default Set<OpeningHandshake> getSupportedOpeningHandshake() { + return EnumSet.of(OpeningHandshake.RFC6455); + } + /** * Return the supported WebSocket protocol versions. */ diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java index 2ed4542e3111..04dbd16f7f70 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java @@ -19,8 +19,10 @@ import java.lang.reflect.UndeclaredThrowableException; import java.security.Principal; import java.util.Collections; +import java.util.EnumSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.Consumer; import jakarta.servlet.ServletContext; @@ -59,6 +61,11 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Serv private Consumer<Configurable> webSocketConfigurer; + @Override + public Set<OpeningHandshake> getSupportedOpeningHandshake() { + return EnumSet.of(OpeningHandshake.RFC6455, OpeningHandshake.RFC8441); + } + @Override public String[] getSupportedVersions() { return SUPPORTED_VERSIONS; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java index fce20644c16f..babfe33995a2 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java @@ -26,6 +26,7 @@ import java.util.Locale; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -48,6 +49,7 @@ import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.RequestUpgradeStrategy; +import org.springframework.web.socket.server.RequestUpgradeStrategy.OpeningHandshake; import org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy; import org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy; import org.springframework.web.socket.server.standard.StandardWebSocketUpgradeStrategy; @@ -78,9 +80,6 @@ */ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Lifecycle { - // For WebSocket upgrades in HTTP/2 (see RFC 8441) - private static final HttpMethod CONNECT_METHOD = HttpMethod.valueOf("CONNECT"); - private static final boolean tomcatWsPresent; private static final boolean jettyWsPresent; @@ -215,15 +214,16 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r } try { HttpMethod httpMethod = request.getMethod(); - if (HttpMethod.GET != httpMethod && !CONNECT_METHOD.equals(httpMethod)) { - response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED); - response.getHeaders().setAllow(Set.of(HttpMethod.GET, CONNECT_METHOD)); - if (logger.isErrorEnabled()) { - logger.error("Handshake failed due to unexpected HTTP method: " + httpMethod); + Set<OpeningHandshake> supportedHandshakes = requestUpgradeStrategy.getSupportedOpeningHandshake(); + OpeningHandshake handshake = null; + for (OpeningHandshake h : supportedHandshakes) { + if (h.getMethod().equals(httpMethod)) { + handshake = h; + break; } - return false; } - if (HttpMethod.GET == httpMethod) { + + if (handshake == OpeningHandshake.RFC6455) { if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { handleInvalidUpgradeHeader(request, response); return false; @@ -232,16 +232,6 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r handleInvalidConnectHeader(request, response); return false; } - } - if (!isWebSocketVersionSupported(headers)) { - handleWebSocketVersionNotSupported(request, response); - return false; - } - if (!isValidOrigin(request)) { - response.setStatusCode(HttpStatus.FORBIDDEN); - return false; - } - if (HttpMethod.GET == httpMethod) { String wsKey = headers.getSecWebSocketKey(); if (wsKey == null) { if (logger.isErrorEnabled()) { @@ -250,6 +240,24 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r response.setStatusCode(HttpStatus.BAD_REQUEST); return false; } + } else if (handshake == null) { + response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED); + Set<HttpMethod> methods = supportedHandshakes.stream() + .map(OpeningHandshake::getMethod) + .collect(Collectors.toUnmodifiableSet()); + response.getHeaders().setAllow(methods); + if (logger.isErrorEnabled()) { + logger.error("Handshake failed due to unexpected HTTP method: " + httpMethod); + } + return false; + } + if (!isWebSocketVersionSupported(headers)) { + handleWebSocketVersionNotSupported(request, response); + return false; + } + if (!isValidOrigin(request)) { + response.setStatusCode(HttpStatus.FORBIDDEN); + return false; } } catch (IOException ex) {