Skip to content

Commit

Permalink
Test ws subprotocols
Browse files Browse the repository at this point in the history
  • Loading branch information
coolreader18 committed Jan 6, 2025
1 parent b5ee864 commit 95f2535
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions axum/src/extract/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1090,10 +1090,11 @@ mod tests {
#[crate::test]
async fn integration_test() {
let addr = spawn_service(echo_app());
let (socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
.await
.unwrap();
test_echo_app(socket).await;
let uri = format!("ws://{addr}/echo").try_into().unwrap();
let req = tungstenite::client::ClientRequestBuilder::new(uri)
.with_sub_protocol(TEST_ECHO_APP_REQ_SUBPROTO);
let (socket, response) = tokio_tungstenite::connect_async(req).await.unwrap();
test_echo_app(socket, response.headers()).await;
}

#[crate::test]
Expand Down Expand Up @@ -1121,21 +1122,22 @@ mod tests {
.extension(hyper::ext::Protocol::from_static("websocket"))
.uri("/echo")
.header("sec-websocket-version", "13")
.header("sec-websocket-protocol", TEST_ECHO_APP_REQ_SUBPROTO)
.header("Host", "server.example.com")
.body(Body::empty())
.unwrap();

let response = send_request.send_request(req).await.unwrap();
let mut response = send_request.send_request(req).await.unwrap();
let status = response.status();
if status != 200 {
let body = response.into_body().collect().await.unwrap().to_bytes();
let body = std::str::from_utf8(&body).unwrap();
panic!("response status was {status}: {body}");
}
let upgraded = hyper::upgrade::on(response).await.unwrap();
let upgraded = hyper::upgrade::on(&mut response).await.unwrap();
let upgraded = TokioIo::new(upgraded);
let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await;
test_echo_app(socket).await;
test_echo_app(socket, response.headers()).await;
}

fn echo_app() -> Router {
Expand All @@ -1156,11 +1158,19 @@ mod tests {

Router::new().route(
"/echo",
any(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
any(|ws: WebSocketUpgrade| {
ready(ws.protocols(["echo2", "echo"]).on_upgrade(handle_socket))
}),
)
}

async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(mut socket: WebSocketStream<S>) {
const TEST_ECHO_APP_REQ_SUBPROTO: &str = "echo3, echo";
async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(
mut socket: WebSocketStream<S>,
headers: &http::HeaderMap,
) {
assert_eq!(headers[http::header::SEC_WEBSOCKET_PROTOCOL], "echo");

let input = tungstenite::Message::Text(tungstenite::Utf8Bytes::from_static("foobar"));
socket.send(input.clone()).await.unwrap();
let output = socket.next().await.unwrap().unwrap();
Expand Down

0 comments on commit 95f2535

Please sign in to comment.