Skip to content

Commit

Permalink
Make ConnectInfo work with ListenerExt::tap_io (#3059)
Browse files Browse the repository at this point in the history
  • Loading branch information
jplatte authored Dec 2, 2024
1 parent 6d30c57 commit 9fab45a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
11 changes: 11 additions & 0 deletions axum/src/extract/connect_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,17 @@ const _: () = {
*stream.remote_addr()
}
}

impl<'a, L, F> Connected<serve::IncomingStream<'a, serve::TapIo<L, F>>> for L::Addr
where
L: serve::Listener,
L::Addr: Clone + Sync + 'static,
F: FnMut(&mut L::Io) + Send + 'static,
{
fn connect_info(stream: serve::IncomingStream<'a, serve::TapIo<L, F>>) -> Self {
stream.remote_addr().clone()
}
}
};

impl Connected<SocketAddr> for SocketAddr {
Expand Down
38 changes: 38 additions & 0 deletions axum/src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ mod tests {
extract::connect_info::Connected,
handler::{Handler, HandlerWithoutStateExt},
routing::get,
serve::ListenerExt,
Router,
};

Expand All @@ -452,14 +453,29 @@ mod tests {

let addr = "0.0.0.0:0";

let tcp_nodelay_listener = || async {
TcpListener::bind(addr).await.unwrap().tap_io(|tcp_stream| {
if let Err(err) = tcp_stream.set_nodelay(true) {
eprintln!("failed to set TCP_NODELAY on incoming connection: {err:#}");
}
})
};

// router
serve(TcpListener::bind(addr).await.unwrap(), router.clone());
serve(tcp_nodelay_listener().await, router.clone())
.await
.unwrap();
serve(UnixListener::bind("").unwrap(), router.clone());

serve(
TcpListener::bind(addr).await.unwrap(),
router.clone().into_make_service(),
);
serve(
tcp_nodelay_listener().await,
router.clone().into_make_service(),
);
serve(
UnixListener::bind("").unwrap(),
router.clone().into_make_service(),
Expand All @@ -471,19 +487,30 @@ mod tests {
.clone()
.into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
tcp_nodelay_listener().await,
router
.clone()
.into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
UnixListener::bind("").unwrap(),
router.into_make_service_with_connect_info::<UdsConnectInfo>(),
);

// method router
serve(TcpListener::bind(addr).await.unwrap(), get(handler));
serve(tcp_nodelay_listener().await, get(handler));
serve(UnixListener::bind("").unwrap(), get(handler));

serve(
TcpListener::bind(addr).await.unwrap(),
get(handler).into_make_service(),
);
serve(
tcp_nodelay_listener().await,
get(handler).into_make_service(),
);
serve(
UnixListener::bind("").unwrap(),
get(handler).into_make_service(),
Expand All @@ -493,6 +520,10 @@ mod tests {
TcpListener::bind(addr).await.unwrap(),
get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
tcp_nodelay_listener().await,
get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
UnixListener::bind("").unwrap(),
get(handler).into_make_service_with_connect_info::<UdsConnectInfo>(),
Expand All @@ -503,24 +534,31 @@ mod tests {
TcpListener::bind(addr).await.unwrap(),
handler.into_service(),
);
serve(tcp_nodelay_listener().await, handler.into_service());
serve(UnixListener::bind("").unwrap(), handler.into_service());

serve(
TcpListener::bind(addr).await.unwrap(),
handler.with_state(()),
);
serve(tcp_nodelay_listener().await, handler.with_state(()));
serve(UnixListener::bind("").unwrap(), handler.with_state(()));

serve(
TcpListener::bind(addr).await.unwrap(),
handler.into_make_service(),
);
serve(tcp_nodelay_listener().await, handler.into_make_service());
serve(UnixListener::bind("").unwrap(), handler.into_make_service());

serve(
TcpListener::bind(addr).await.unwrap(),
handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
tcp_nodelay_listener().await,
handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
UnixListener::bind("").unwrap(),
handler.into_make_service_with_connect_info::<UdsConnectInfo>(),
Expand Down
2 changes: 1 addition & 1 deletion axum/src/serve/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl<L: Listener> ListenerExt for L {}
/// Return type of [`ListenerExt::tap_io`].
///
/// See that method for details.
pub struct TapIo<L: Listener, F> {
pub struct TapIo<L, F> {
listener: L,
tap_fn: F,
}
Expand Down

0 comments on commit 9fab45a

Please sign in to comment.