diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index 243ff20ec4..7990a62183 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -1090,10 +1090,11 @@ mod tests { #[crate::test] async fn integration_test() { let addr = spawn_service(echo_app()); - let (socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo")) - .await - .unwrap(); - test_echo_app(socket).await; + let uri = format!("ws://{addr}/echo").try_into().unwrap(); + let req = tungstenite::client::ClientRequestBuilder::new(uri) + .with_sub_protocol(TEST_ECHO_APP_REQ_SUBPROTO); + let (socket, response) = tokio_tungstenite::connect_async(req).await.unwrap(); + test_echo_app(socket, response.headers()).await; } #[crate::test] @@ -1121,21 +1122,22 @@ mod tests { .extension(hyper::ext::Protocol::from_static("websocket")) .uri("/echo") .header("sec-websocket-version", "13") + .header("sec-websocket-protocol", TEST_ECHO_APP_REQ_SUBPROTO) .header("Host", "server.example.com") .body(Body::empty()) .unwrap(); - let response = send_request.send_request(req).await.unwrap(); + let mut response = send_request.send_request(req).await.unwrap(); let status = response.status(); if status != 200 { let body = response.into_body().collect().await.unwrap().to_bytes(); let body = std::str::from_utf8(&body).unwrap(); panic!("response status was {status}: {body}"); } - let upgraded = hyper::upgrade::on(response).await.unwrap(); + let upgraded = hyper::upgrade::on(&mut response).await.unwrap(); let upgraded = TokioIo::new(upgraded); let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await; - test_echo_app(socket).await; + test_echo_app(socket, response.headers()).await; } fn echo_app() -> Router { @@ -1156,11 +1158,19 @@ mod tests { Router::new().route( "/echo", - any(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))), + any(|ws: WebSocketUpgrade| { + ready(ws.protocols(["echo2", "echo"]).on_upgrade(handle_socket)) + }), ) } - async fn test_echo_app(mut socket: WebSocketStream) { + const TEST_ECHO_APP_REQ_SUBPROTO: &str = "echo3, echo"; + async fn test_echo_app( + mut socket: WebSocketStream, + headers: &http::HeaderMap, + ) { + assert_eq!(headers[http::header::SEC_WEBSOCKET_PROTOCOL], "echo"); + let input = tungstenite::Message::Text(tungstenite::Utf8Bytes::from_static("foobar")); socket.send(input.clone()).await.unwrap(); let output = socket.next().await.unwrap().unwrap();