diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 44cff22ac..74d749ec1 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -115,6 +115,8 @@ struct Config { #[cfg(feature = "__tls")] max_tls_version: Option, #[cfg(feature = "__tls")] + tls_info: bool, + #[cfg(feature = "__tls")] tls: TlsBackend, http_version_pref: HttpVersionPref, http09_responses: bool, @@ -198,6 +200,8 @@ impl ClientBuilder { #[cfg(feature = "__tls")] max_tls_version: None, #[cfg(feature = "__tls")] + tls_info: false, + #[cfg(feature = "__tls")] tls: TlsBackend::default(), http_version_pref: HttpVersionPref::All, http09_responses: false, @@ -408,6 +412,7 @@ impl ClientBuilder { user_agent(&config.headers), config.local_address, config.nodelay, + config.tls_info, )? } #[cfg(feature = "native-tls")] @@ -418,6 +423,7 @@ impl ClientBuilder { user_agent(&config.headers), config.local_address, config.nodelay, + config.tls_info, ), #[cfg(feature = "__rustls")] TlsBackend::BuiltRustls(conn) => { @@ -442,6 +448,7 @@ impl ClientBuilder { user_agent(&config.headers), config.local_address, config.nodelay, + config.tls_info, ) } #[cfg(feature = "__rustls")] @@ -586,6 +593,7 @@ impl ClientBuilder { user_agent(&config.headers), config.local_address, config.nodelay, + config.tls_info, ) } #[cfg(any(feature = "native-tls", feature = "__rustls",))] @@ -1483,6 +1491,26 @@ impl ClientBuilder { self } + /// Add TLS information as `TlsInfo` extension to responses. + /// + /// # Optional + /// + /// This requires the optional `default-tls`, `native-tls`, or `rustls-tls(-...)` + /// feature to be enabled. + #[cfg(feature = "__tls")] + #[cfg_attr( + docsrs, + doc(cfg(any( + feature = "default-tls", + feature = "native-tls", + feature = "rustls-tls" + ))) + )] + pub fn tls_info(mut self, tls_info: bool) -> ClientBuilder { + self.config.tls_info = tls_info; + self + } + /// Enables the [trust-dns](trust_dns_resolver) async resolver instead of a default threadpool using `getaddrinfo`. /// /// If the `trust-dns` feature is turned on, the default option is enabled. @@ -1987,6 +2015,8 @@ impl Config { } f.field("tls_sni", &self.tls_sni); + + f.field("tls_info", &self.tls_info); } #[cfg(all(feature = "native-tls-crate", feature = "__rustls"))] diff --git a/src/blocking/client.rs b/src/blocking/client.rs index 67e280f8a..d57f3a031 100644 --- a/src/blocking/client.rs +++ b/src/blocking/client.rs @@ -738,6 +738,25 @@ impl ClientBuilder { self.with_inner(move |inner| inner.use_rustls_tls()) } + /// Add TLS information as `TlsInfo` extension to responses. + /// + /// # Optional + /// + /// This requires the optional `default-tls`, `native-tls`, or `rustls-tls(-...)` + /// feature to be enabled. + #[cfg(feature = "__tls")] + #[cfg_attr( + docsrs, + doc(cfg(any( + feature = "default-tls", + feature = "native-tls", + feature = "rustls-tls" + ))) + )] + pub fn tls_info(self, tls_info: bool) -> ClientBuilder { + self.with_inner(|inner| inner.tls_info(tls_info)) + } + /// Use a preconfigured TLS backend. /// /// If the passed `Any` argument is not a TLS backend that reqwest diff --git a/src/connect.rs b/src/connect.rs index b80ffc767..c171dd18d 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -36,6 +36,8 @@ pub(crate) struct Connector { #[cfg(feature = "__tls")] nodelay: bool, #[cfg(feature = "__tls")] + tls_info: bool, + #[cfg(feature = "__tls")] user_agent: Option, } @@ -82,13 +84,14 @@ impl Connector { user_agent: Option, local_addr: T, nodelay: bool, + tls_info: bool, ) -> crate::Result where T: Into>, { let tls = tls.build().map_err(crate::error::builder)?; Ok(Self::from_built_default_tls( - http, tls, proxies, user_agent, local_addr, nodelay, + http, tls, proxies, user_agent, local_addr, nodelay, tls_info, )) } @@ -100,6 +103,7 @@ impl Connector { user_agent: Option, local_addr: T, nodelay: bool, + tls_info: bool, ) -> Connector where T: Into>, @@ -113,6 +117,7 @@ impl Connector { verbose: verbose::OFF, timeout: None, nodelay, + tls_info, user_agent, } } @@ -125,6 +130,7 @@ impl Connector { user_agent: Option, local_addr: T, nodelay: bool, + tls_info: bool, ) -> Connector where T: Into>, @@ -151,6 +157,7 @@ impl Connector { verbose: verbose::OFF, timeout: None, nodelay, + tls_info, user_agent, } } @@ -188,6 +195,7 @@ impl Connector { return Ok(Conn { inner: self.verbose.wrap(NativeTlsConn { inner: io }), is_proxy: false, + tls_info: self.tls_info, }); } } @@ -208,6 +216,7 @@ impl Connector { return Ok(Conn { inner: self.verbose.wrap(RustlsTlsConn { inner: io }), is_proxy: false, + tls_info: false, }); } } @@ -218,6 +227,7 @@ impl Connector { socks::connect(proxy, dst, dns).await.map(|tcp| Conn { inner: self.verbose.wrap(tcp), is_proxy: false, + tls_info: false, }) } @@ -229,6 +239,7 @@ impl Connector { Ok(Conn { inner: self.verbose.wrap(io), is_proxy, + tls_info: false, }) } #[cfg(feature = "default-tls")] @@ -253,11 +264,13 @@ impl Connector { Ok(Conn { inner: self.verbose.wrap(NativeTlsConn { inner: stream }), is_proxy, + tls_info: self.tls_info, }) } else { Ok(Conn { inner: self.verbose.wrap(io), is_proxy, + tls_info: false, }) } } @@ -283,11 +296,13 @@ impl Connector { Ok(Conn { inner: self.verbose.wrap(RustlsTlsConn { inner: stream }), is_proxy, + tls_info: self.tls_info, }) } else { Ok(Conn { inner: self.verbose.wrap(io), is_proxy, + tls_info: false, }) } } @@ -337,6 +352,7 @@ impl Connector { return Ok(Conn { inner: self.verbose.wrap(NativeTlsConn { inner: io }), is_proxy: false, + tls_info: false, }); } } @@ -369,6 +385,7 @@ impl Connector { return Ok(Conn { inner: self.verbose.wrap(RustlsTlsConn { inner: io }), is_proxy: false, + tls_info: false, }); } } @@ -444,6 +461,105 @@ impl Service for Connector { } } +#[cfg(feature = "__tls")] +trait TlsInfoFactory { + fn tls_info(&self) -> Option; +} + +#[cfg(feature = "__tls")] +impl TlsInfoFactory for tokio::net::TcpStream { + fn tls_info(&self) -> Option { + None + } +} + +#[cfg(feature = "default-tls")] +impl TlsInfoFactory for hyper_tls::MaybeHttpsStream { + fn tls_info(&self) -> Option { + match self { + hyper_tls::MaybeHttpsStream::Https(tls) => tls.tls_info(), + hyper_tls::MaybeHttpsStream::Http(_) => None, + } + } +} + +#[cfg(feature = "default-tls")] +impl TlsInfoFactory for hyper_tls::TlsStream> { + fn tls_info(&self) -> Option { + let peer_certificate = self + .get_ref() + .peer_certificate() + .ok() + .flatten() + .and_then(|c| c.to_der().ok()); + Some(crate::tls::TlsInfo { peer_certificate }) + } +} + +#[cfg(feature = "default-tls")] +impl TlsInfoFactory for tokio_native_tls::TlsStream { + fn tls_info(&self) -> Option { + let peer_certificate = self + .get_ref() + .peer_certificate() + .ok() + .flatten() + .and_then(|c| c.to_der().ok()); + Some(crate::tls::TlsInfo { peer_certificate }) + } +} + +#[cfg(feature = "__rustls")] +impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream { + fn tls_info(&self) -> Option { + match self { + hyper_rustls::MaybeHttpsStream::Https(tls) => tls.tls_info(), + hyper_rustls::MaybeHttpsStream::Http(_) => None, + } + } +} + +#[cfg(feature = "__rustls")] +impl TlsInfoFactory for tokio_rustls::TlsStream { + fn tls_info(&self) -> Option { + let peer_certificate = self + .get_ref() + .1 + .peer_certificates() + .and_then(|certs| certs.first()) + .map(|c| c.0.clone()); + Some(crate::tls::TlsInfo { peer_certificate }) + } +} + +#[cfg(feature = "__rustls")] +impl TlsInfoFactory + for tokio_rustls::client::TlsStream> +{ + fn tls_info(&self) -> Option { + let peer_certificate = self + .get_ref() + .1 + .peer_certificates() + .and_then(|certs| certs.first()) + .map(|c| c.0.clone()); + Some(crate::tls::TlsInfo { peer_certificate }) + } +} + +#[cfg(feature = "__rustls")] +impl TlsInfoFactory for tokio_rustls::client::TlsStream { + fn tls_info(&self) -> Option { + let peer_certificate = self + .get_ref() + .1 + .peer_certificates() + .and_then(|certs| certs.first()) + .map(|c| c.0.clone()); + Some(crate::tls::TlsInfo { peer_certificate }) + } +} + pub(crate) trait AsyncConn: AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static { @@ -451,7 +567,17 @@ pub(crate) trait AsyncConn: impl AsyncConn for T {} -type BoxConn = Box; +#[cfg(feature = "__tls")] +trait AsyncConnWithInfo: AsyncConn + TlsInfoFactory {} +#[cfg(not(feature = "__tls"))] +trait AsyncConnWithInfo: AsyncConn {} + +#[cfg(feature = "__tls")] +impl AsyncConnWithInfo for T {} +#[cfg(not(feature = "__tls"))] +impl AsyncConnWithInfo for T {} + +type BoxConn = Box; pin_project! { /// Note: the `is_proxy` member means *is plain text HTTP proxy*. @@ -462,12 +588,26 @@ pin_project! { #[pin] inner: BoxConn, is_proxy: bool, + // Only needed for __tls, but #[cfg()] on fields breaks pin_project! + tls_info: bool, } } impl Connection for Conn { fn connected(&self) -> Connected { - self.inner.connected().proxy(self.is_proxy) + let connected = self.inner.connected().proxy(self.is_proxy); + #[cfg(feature = "__tls")] + if self.tls_info { + if let Some(tls_info) = self.inner.tls_info() { + connected.extra(tls_info) + } else { + connected + } + } else { + connected + } + #[cfg(not(feature = "__tls"))] + connected } } @@ -595,6 +735,7 @@ fn tunnel_eof() -> BoxError { #[cfg(feature = "default-tls")] mod native_tls_conn { + use super::TlsInfoFactory; use hyper::client::connect::{Connected, Connection}; use pin_project_lite::pin_project; use std::{ @@ -682,10 +823,23 @@ mod native_tls_conn { AsyncWrite::poll_shutdown(this.inner, cx) } } + + impl TlsInfoFactory for NativeTlsConn { + fn tls_info(&self) -> Option { + self.inner.tls_info() + } + } + + impl TlsInfoFactory for NativeTlsConn> { + fn tls_info(&self) -> Option { + self.inner.tls_info() + } + } } #[cfg(feature = "__rustls")] mod rustls_tls_conn { + use super::TlsInfoFactory; use hyper::client::connect::{Connected, Connection}; use pin_project_lite::pin_project; use std::{ @@ -762,6 +916,18 @@ mod rustls_tls_conn { AsyncWrite::poll_shutdown(this.inner, cx) } } + + impl TlsInfoFactory for RustlsTlsConn { + fn tls_info(&self) -> Option { + self.inner.tls_info() + } + } + + impl TlsInfoFactory for RustlsTlsConn> { + fn tls_info(&self) -> Option { + self.inner.tls_info() + } + } } #[cfg(feature = "socks")] @@ -844,7 +1010,7 @@ mod verbose { pub(super) struct Wrapper(pub(super) bool); impl Wrapper { - pub(super) fn wrap(&self, conn: T) -> super::BoxConn { + pub(super) fn wrap(&self, conn: T) -> super::BoxConn { if self.0 && log::log_enabled!(log::Level::Trace) { Box::new(Verbose { // truncate is fine @@ -939,6 +1105,13 @@ mod verbose { } } + #[cfg(feature = "__tls")] + impl super::TlsInfoFactory for Verbose { + fn tls_info(&self) -> Option { + self.inner.tls_info() + } + } + struct Escape<'a>(&'a [u8]); impl fmt::Debug for Escape<'_> { diff --git a/src/tls.rs b/src/tls.rs index 07f3d4543..e873939ab 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -463,6 +463,26 @@ impl ServerCertVerifier for NoVerifier { } } +/// Hyper extension carrying extra TLS layer information. +/// Made available to clients on responses when `tls_info` is set. +#[derive(Clone)] +pub struct TlsInfo { + pub(crate) peer_certificate: Option>, +} + +impl TlsInfo { + /// Get the DER encoded leaf certificate of the peer. + pub fn peer_certificate(&self) -> Option<&[u8]> { + self.peer_certificate.as_ref().map(|der| &der[..]) + } +} + +impl std::fmt::Debug for TlsInfo { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("TlsInfo").finish() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/tests/blocking.rs b/tests/blocking.rs index f5f10175f..fa6c8d01c 100644 --- a/tests/blocking.rs +++ b/tests/blocking.rs @@ -362,3 +362,25 @@ fn blocking_update_json_content_type_if_set_manually() { assert_eq!("application/json", req.headers().get(CONTENT_TYPE).unwrap()); } + +#[test] +fn test_response_no_tls_info_for_http() { + let server = server::http(move |_req| async { http::Response::new("Hello".into()) }); + + let url = format!("http://{}/text", server.addr()); + + let client = reqwest::blocking::Client::builder() + .tls_info(true) + .build() + .unwrap(); + + let res = client.get(&url).send().unwrap(); + assert_eq!(res.url().as_str(), &url); + assert_eq!(res.status(), reqwest::StatusCode::OK); + assert_eq!(res.content_length(), Some(5)); + let tls_info = res.extensions().get::(); + assert_eq!(tls_info.is_none(), true); + + let body = res.text().unwrap(); + assert_eq!(b"Hello", body.as_bytes()); +} diff --git a/tests/client.rs b/tests/client.rs index cacca0429..e77cc6a4a 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -408,3 +408,33 @@ fn update_json_content_type_if_set_manually() { assert_eq!("application/json", req.headers().get(CONTENT_TYPE).unwrap()); } + +#[cfg(all(feature = "__tls", not(feature = "rustls-tls-manual-roots")))] +#[tokio::test] +async fn test_tls_info() { + let resp = reqwest::Client::builder() + .tls_info(true) + .build() + .expect("client builder") + .get("https://google.com") + .send() + .await + .expect("response"); + let tls_info = resp.extensions().get::(); + assert!(tls_info.is_some()); + let tls_info = tls_info.unwrap(); + let peer_certificate = tls_info.peer_certificate(); + assert!(peer_certificate.is_some()); + let der = peer_certificate.unwrap(); + assert_eq!(der[0], 0x30); // ASN.1 SEQUENCE + + let resp = reqwest::Client::builder() + .build() + .expect("client builder") + .get("https://google.com") + .send() + .await + .expect("response"); + let tls_info = resp.extensions().get::(); + assert!(tls_info.is_none()); +}