diff --git a/Cargo.lock b/Cargo.lock index a576f2e9..89d60d24 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,12 +2,27 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + [[package]] name = "base64" version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "bitflags" version = "2.6.0" @@ -48,12 +63,66 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +[[package]] +name = "der" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1a467a65c5e759bce6e65eaf91cc29f466cdc57cb65777bd646872a8a1fd4de" +dependencies = [ + "pem-rfc7468", + "zeroize", +] + +[[package]] +name = "env_logger" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580" +dependencies = [ + "humantime", + "is-terminal", + "log", + "regex", + "termcolor", +] + +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "fastrand" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" + 
[[package]] name = "fnv" version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -74,6 +143,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + [[package]] name = "hoot" version = "0.2.0" @@ -104,6 +179,12 @@ version = "1.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fcc0b4a115bf80b728eb8ea024ad5bd707b615bfed49e0665b6e0f86fd082d9" +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "idna" version = "0.5.0" @@ -114,6 +195,17 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "is-terminal" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys", +] + [[package]] name = "itoa" version = "1.0.11" @@ -126,30 +218,112 @@ version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +[[package]] +name = 
"linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" + [[package]] name = "log" version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "native-tls" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "once_cell" version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "openssl" +version = "0.10.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "openssl-probe" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +[[package]] +name = "openssl-sys" +version = "0.9.102" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "pem-rfc7468" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d159833a9105500e0398934e205e0773f0b27529557134ecfc51c27646adac" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + [[package]] name = "proc-macro2" version = "1.0.86" @@ -168,6 +342,35 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "regex" +version = "1.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebee201405406dbf528b8b672104ae6d6d63e6d118cb10e4d51abbc7b58044ff" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59b23e92ee4318893fa3fe3e6fb365258efbfe6ac6ab30f090cdcbb7aa37efa9" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" + [[package]] name = "ring" version = "0.17.8" @@ -183,6 +386,19 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "rustix" +version = "0.38.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + [[package]] name = "rustls" version = "0.23.11" @@ -299,6 +515,27 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tempfile" +version = "3.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +dependencies = [ + "cfg-if", + "fastrand", + "rustix", + "windows-sys", +] + +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + [[package]] name = "thiserror" version = "1.0.61" @@ -367,10 +604,14 @@ version = "3.0.0-beta1" dependencies = [ "base64", "cc", + "der", + "env_logger", "hoot", "http", "log", + "native-tls", "once_cell", + "regex", "rustls", "rustls-native-certs", "rustls-pemfile", @@ -389,12 +630,27 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "winapi-util" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" +dependencies = [ + "windows-sys", +] + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/Cargo.toml b/Cargo.toml index 207646a1..37d00c31 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,8 +22,9 @@ rust-version = "1.63" features = ["rustls"] [features] -default = 
["rustls", "native-roots"] +default = ["rustls", "native-roots", "native-tls"] rustls = ["dep:rustls", "tls"] +native-tls = ["dep:native-tls", "dep:der", "tls"] tls = ["dep:rustls-pemfile", "dep:rustls-pki-types"] native-roots = ["dep:rustls-native-certs"] @@ -36,15 +37,21 @@ thiserror = "1.0.61" once_cell = "1.19.0" # These are used regardless of TLS implementation. -rustls-pemfile = { version = "2.1.2", optional = true, default-features = false, features = ["std"] } +rustls-pemfile = { version = "2.1.2", optional = true, default-features = false, features = ["std"] } rustls-pki-types = { version = "1.7.0", optional = true, default-features = false, features = ["std"] } rustls-native-certs = { version = "0.7.1", optional = true, default-features = false } # ring has a higher chance of compiling cleanly without additional developer environment -rustls = { version = "0.23.11", optional = true, default-features = false, features = ["ring", "logging", "std", "tls12"] } +rustls = { version = "0.23.11", optional = true, default-features = false, features = ["ring", "logging", "std", "tls12"] } +native-tls = { version = "0.2.12", optional = true, default-features = false } +# Needed for MSRV 1.63 (0.7.0 require 1.65) +der = { version = "=0.6.1", optional = true, default-features = false, features = ["pem", "std"] } [build-dependencies] -# Needed for MSRV 1.63 +# Needed for MSRV 1.63 (1.0.106 requires 1.67) cc = "=1.0.105" [dev-dependencies] +# Needed for MSRV 1.63 (0.11 requires 1.71) +env_logger = "=0.10.2" +regex = "=1.9.6" diff --git a/src/agent.rs b/src/agent.rs index db200c92..ddb52a0a 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -11,7 +11,7 @@ use crate::proxy::Proxy; use crate::recv::RecvBody; use crate::resolver::{DefaultResolver, Resolver}; use crate::time::Instant; -use crate::transport::{Buffers, ConnectionDetails, Connector, DefaultConnector}; +use crate::transport::{ConnectionDetails, Connector, DefaultConnector, NoBuffers}; use crate::unit::{Event, 
Input, Unit}; use crate::{Body, Error}; @@ -137,8 +137,8 @@ impl Default for AgentConfig { max_redirects: 10, redirect_auth_headers: RedirectAuthHeaders::Never, user_agent: "ureq".to_string(), // TODO(martin): add version - input_buffer_size: 512 * 1024, - output_buffer_size: 512 * 1024, + input_buffer_size: 128 * 1024, + output_buffer_size: 128 * 1024, #[cfg(all(feature = "tls"))] tls_config: TlsConfig::with_native_roots(), @@ -195,14 +195,15 @@ impl Agent { let mut addr = None; let mut connection: Option = None; let mut response; + let mut no_buffers = NoBuffers; loop { // The buffer is owned by the connection. Before we have an open connection, // there are no buffers (and the code below should not need it). let buffers = connection .as_mut() - .map(|c| c.borrow_buffers(unit.need_input_as_tmp())) - .unwrap_or(Buffers::empty()); + .map(|c| c.buffers()) + .unwrap_or(&mut no_buffers); match unit.poll_event(current_time(), buffers)? { Event::Reset { must_close } => { @@ -244,7 +245,8 @@ impl Agent { let connection = connection.as_mut().expect("connection for AwaitInput"); match connection.await_input(timeout) { - Ok(Buffers { input, .. }) => { + Ok(_) => { + let input = connection.buffers().input(); unit.handle_input(current_time(), Input::Input { input }, &mut [])? 
} @@ -264,7 +266,8 @@ impl Agent { Event::AwaitInput { timeout } => { let connection = connection.as_mut().expect("connection for AwaitInput"); - let Buffers { input, output } = connection.await_input(timeout)?; + connection.await_input(timeout)?; + let (input, output) = connection.buffers().input_and_output(); let input_used = unit.handle_input(current_time(), Input::Input { input }, output)?; diff --git a/src/error.rs b/src/error.rs index 8881d130..41996ed8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -34,6 +34,14 @@ pub enum Error { #[cfg(feature = "rustls")] #[error("rustls: {0}")] Rustls(#[from] rustls::Error), + + #[cfg(feature = "native-tls")] + #[error("native-tls: {0}")] + NativeTls(#[from] native_tls::Error), + + #[cfg(feature = "native-tls")] + #[error("der: {0}")] + Der(#[from] der::Error), } impl Error { diff --git a/src/lib.rs b/src/lib.rs index c3dc50c0..63c373e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,6 +75,7 @@ mod test { #[test] fn simple_get() { + env_logger::init(); let mut response = get("https://www.lookback.com/").call().unwrap(); println!("{:#?}", response); let mut body = String::new(); diff --git a/src/pool.rs b/src/pool.rs index b34e2f10..246cae89 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -30,15 +30,15 @@ pub(crate) struct Connection { } impl Connection { - pub fn borrow_buffers(&mut self, input_as_tmp: bool) -> Buffers { - self.conn.borrow_buffers(input_as_tmp) + pub fn buffers(&mut self) -> &mut dyn Buffers { + self.conn.buffers() } pub fn transmit_output(&mut self, amount: usize, timeout: Duration) -> Result<(), Error> { self.conn.transmit_output(amount, timeout) } - pub fn await_input(&mut self, timeout: Duration) -> Result { + pub fn await_input(&mut self, timeout: Duration) -> Result<(), Error> { self.conn.await_input(timeout) } diff --git a/src/recv.rs b/src/recv.rs index 3dcab109..b5c07b7b 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -3,7 +3,6 @@ use std::io::{self, Read}; use crate::pool::Connection; use 
crate::time::Instant; -use crate::transport::Buffers; use crate::unit::{Event, Input, Unit}; use crate::Error; @@ -32,8 +31,7 @@ impl RecvBody { None => return Ok(0), }; - let buffers = connection.borrow_buffers(false); - let event = self.unit.poll_event((self.current_time)(), buffers)?; + let event = self.unit.poll_event((self.current_time)())?; let timeout = match event { Event::AwaitInput { timeout } => timeout, @@ -50,7 +48,8 @@ impl RecvBody { _ => unreachable!("expected event AwaitInput"), }; - let Buffers { input, .. } = connection.await_input(timeout)?; + connection.await_input(timeout)?; + let input = connection.buffers().input(); let max = input.len().min(buf.len()); let input = &input[..max]; @@ -61,8 +60,7 @@ impl RecvBody { connection.consume_input(input_used); - let buffers = connection.borrow_buffers(false); - let event = self.unit.poll_event((self.current_time)(), buffers)?; + let event = self.unit.poll_event((self.current_time)())?; let output_used = match event { Event::ResponseBody { amount } => amount, diff --git a/src/request.rs b/src/request.rs index 30ec8cff..7d7b86e4 100644 --- a/src/request.rs +++ b/src/request.rs @@ -30,7 +30,7 @@ impl RequestBuilder { /// # Examples /// /// ``` - /// let res = ureq::get("https://host.test/my-path") + /// let res = ureq::get("https://httpbin.org/get") /// .header("Accept", "text/html") /// .header("X-Custom-Foo", "bar") /// .call()?; diff --git a/src/tls/mod.rs b/src/tls/mod.rs index 35155684..d7352063 100644 --- a/src/tls/mod.rs +++ b/src/tls/mod.rs @@ -10,6 +10,11 @@ mod rustls; #[cfg(feature = "rustls")] pub use self::rustls::RustlsConnector; +#[cfg(feature = "native-tls")] +mod native_tls; +#[cfg(feature = "native-tls")] +pub use self::native_tls::NativeTlsConnector; + #[derive(Debug, Clone)] pub struct TlsConfig { pub client_cert: Option<(Vec>, Arc>)>, diff --git a/src/tls/native_tls.rs b/src/tls/native_tls.rs new file mode 100644 index 00000000..3a42377d --- /dev/null +++ b/src/tls/native_tls.rs @@ 
-0,0 +1,199 @@ +use std::convert::TryFrom; +use std::fmt; +use std::io::{Read, Write}; +use std::sync::Arc; +use std::time::Duration; + +use crate::{transport::*, Error}; +use der::pem::LineEnding; +use der::Document; +use http::uri::Scheme; +use native_tls::{Certificate, HandshakeError, Identity, TlsConnector, TlsStream}; +use once_cell::sync::OnceCell; + +use super::TlsConfig; + +#[derive(Default)] +pub struct NativeTlsConnector { + connector: OnceCell>, +} + +impl Connector for NativeTlsConnector { + fn connect( + &self, + details: &ConnectionDetails, + chained: Option>, + ) -> Result>, Error> { + let transport = match chained { + Some(v) => v, + None => panic!("RustlConnector requires a chained transport"), + }; + + // Only add TLS if we are connecting via HTTPS and the transport isn't TLS + // already, otherwise use chained transport as is. + if details.uri.scheme() != Some(&Scheme::HTTPS) || transport.is_tls() { + trace!("NativeTlsConnector skip"); + return Ok(Some(transport)); + } + + let tls_config = &details.config.tls_config; + + // Initialize the connector on first run. + let connector_ref = match self.connector.get() { + Some(v) => v, + None => { + // This is unlikely to be racy, but if it is, doesn't matter much. + let c = build_connector(tls_config)?; + // Maybe someone else set it first. Weird, but ok. 
+ let _ = self.connector.set(c); + self.connector.get().unwrap() + } + }; + let connector = connector_ref.clone(); // cheap clone due to Arc + + let domain = details + .uri + .authority() + .expect("uri authority for tls") + .host() + .to_string(); + + let adapter = TransportAdapter::new(transport); + let stream = LazyStream::Unstarted(Some((connector, domain, adapter))); + + let buffers = LazyBuffers::new( + details.config.input_buffer_size, + details.config.output_buffer_size, + ); + + let transport = Box::new(NativeTlsTransport { buffers, stream }); + + Ok(Some(transport)) + } +} + +fn build_connector(tls_config: &TlsConfig) -> Result, Error> { + let mut builder = TlsConnector::builder(); + + if tls_config.disable_verification { + builder.danger_accept_invalid_certs(true); + builder.danger_accept_invalid_hostnames(true); + } else { + for cert in &tls_config.root_certs { + let c = match Certificate::from_der(cert.der()) { + Ok(v) => v, + Err(e) => { + // Invalid/expired/broken root certs are expected + // in a native root store. 
+ trace!("Skip invalid root cert: {}", e); + continue; + } + }; + builder.add_root_certificate(c); + } + } + + if let Some((certs, key)) = &tls_config.client_cert { + let certs_pem = certs + .iter() + .map(|c| pemify(c.der(), "CERTIFICATE")) + .collect::>()?; + + let key_pem = pemify(key.der(), "PRIVATE KEY")?; + + let identity = Identity::from_pkcs8(certs_pem.as_bytes(), key_pem.as_bytes())?; + builder.identity(identity); + } + + builder.use_sni(tls_config.use_sni); + + let conn = builder.build()?; + + Ok(Arc::new(conn)) +} + +fn pemify(der: &[u8], label: &'static str) -> Result { + let doc = Document::try_from(der)?; + let pem = doc.to_pem(label, LineEnding::LF)?; + Ok(pem) +} + +struct NativeTlsTransport { + buffers: LazyBuffers, + stream: LazyStream, +} + +impl Transport for NativeTlsTransport { + fn buffers(&mut self) -> &mut dyn Buffers { + &mut self.buffers + } + + fn transmit_output(&mut self, amount: usize, timeout: Duration) -> Result<(), Error> { + let stream = self.stream.handshaken()?; + stream.get_mut().timeout = timeout; + + let output = &self.buffers.output()[..amount]; + stream.write_all(output)?; + + Ok(()) + } + + fn await_input(&mut self, timeout: Duration) -> Result<(), Error> { + if self.buffers.can_use_input() { + return Ok(()); + } + + let stream = self.stream.handshaken()?; + stream.get_mut().timeout = timeout; + + let input = self.buffers.input_mut(); + let amount = stream.read(input)?; + self.buffers.add_filled(amount); + + Ok(()) + } + + fn consume_input(&mut self, amount: usize) { + self.buffers.consume(amount); + } + + fn is_tls(&self) -> bool { + true + } +} + +/// Helper to delay the handshake until we are starting IO. +/// This normalizes native-tls to behave like rustls. 
+enum LazyStream { + Unstarted(Option<(Arc, String, TransportAdapter)>), + Started(TlsStream), +} + +impl LazyStream { + fn handshaken(&mut self) -> Result<&mut TlsStream, Error> { + match self { + LazyStream::Unstarted(v) => { + let (conn, domain, adapter) = v.take().unwrap(); + let stream = conn.connect(&domain, adapter).map_err(|e| match e { + HandshakeError::Failure(e) => e, + HandshakeError::WouldBlock(_) => unreachable!(), + })?; + *self = LazyStream::Started(stream); + // Next time we hit the other match arm + return self.handshaken(); + } + LazyStream::Started(v) => Ok(v), + } + } +} +impl fmt::Debug for NativeTlsConnector { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("NativeTlsConnector").finish() + } +} + +impl fmt::Debug for NativeTlsTransport { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("NativeTlsTransport").finish() + } +} diff --git a/src/tls/rustls.rs b/src/tls/rustls.rs index 5e9e46f4..92912475 100644 --- a/src/tls/rustls.rs +++ b/src/tls/rustls.rs @@ -37,9 +37,10 @@ impl Connector for RustlsConnector { None => panic!("RustlConnector requires a chained transport"), }; - // Only add TLS if we are connecting via HTTPS, otherwise - // use chained transport as is. - if details.uri.scheme() != Some(&Scheme::HTTPS) { + // Only add TLS if we are connecting via HTTPS and the transport isn't TLS + // already, otherwise use chained transport as is. 
+ if details.uri.scheme() != Some(&Scheme::HTTPS) || transport.is_tls() { + trace!("RustlsConnector skip"); return Ok(Some(transport)); } @@ -100,7 +101,7 @@ fn build_config(tls_config: &TlsConfig) -> Arc { builder.with_root_certificates(root_store) }; - let config = if let Some((certs, key)) = &tls_config.client_cert { + let mut config = if let Some((certs, key)) = &tls_config.client_cert { let cert_chain = certs .iter() .map(|c| CertificateDer::from(c.der()).into_owned()); @@ -119,6 +120,8 @@ fn build_config(tls_config: &TlsConfig) -> Arc { builder.with_no_client_auth() }; + config.enable_sni = tls_config.use_sni; + Arc::new(config) } @@ -128,38 +131,39 @@ struct RustlsTransport { } impl Transport for RustlsTransport { - fn borrow_buffers(&mut self, input_as_tmp: bool) -> Buffers { - self.buffers.borrow_mut(input_as_tmp) + fn buffers(&mut self) -> &mut dyn Buffers { + &mut self.buffers } fn transmit_output(&mut self, amount: usize, timeout: Duration) -> Result<(), Error> { - let buffers = self.buffers.borrow_mut(false); - self.stream.sock.timeout = timeout; - self.stream.write_all(&buffers.output[..amount])?; + self.stream.get_mut().timeout = timeout; + + let output = &self.buffers.output()[..amount]; + self.stream.write_all(output)?; + Ok(()) } - fn await_input(&mut self, timeout: Duration) -> Result { - if self.buffers.unconsumed() > 0 { - return Ok(self.buffers.borrow_mut(false)); + fn await_input(&mut self, timeout: Duration) -> Result<(), Error> { + if self.buffers.can_use_input() { + return Ok(()); } - // Ensure we get the entire input buffer to write to. 
/// Access to the input/output byte buffers owned by a transport.
///
/// The input buffer tracks two cursors: how much has been *filled* by the
/// transport and how much of that has been *consumed* by the reader.
pub trait Buffers {
    /// The output buffer as currently allocated.
    fn output(&self) -> &[u8];
    /// The output buffer for writing, allocating it if needed.
    fn output_mut(&mut self) -> &mut [u8];
    /// The filled but not yet consumed part of the input buffer.
    fn input(&self) -> &[u8];
    /// The free tail of the input buffer, for the transport to read into.
    fn input_mut(&mut self) -> &mut [u8];
    /// Unconsumed input and the whole output buffer, borrowed together.
    fn input_and_output(&mut self) -> (&[u8], &mut [u8]);
    /// Scratch space (the free tail of the input buffer) and the output buffer.
    fn tmp_and_output(&mut self) -> (&mut [u8], &mut [u8]);
    /// Record that `amount` more bytes were written into `input_mut()`.
    fn add_filled(&mut self, amount: usize);
    /// Record that `amount` bytes of `input()` have been used up.
    fn consume(&mut self, amount: usize);
    /// Whether the reader can make progress on the input already buffered.
    fn can_use_input(&self) -> bool;
}

/// Buffers that are not allocated until they are first needed.
pub struct LazyBuffers {
    input_size: usize,
    output_size: usize,

    input: Vec<u8>,
    output: Vec<u8>,

    // Bytes in `input[..filled]` came from the transport.
    filled: usize,
    // Bytes in `input[..consumed]` have already been handed to the reader.
    consumed: usize,
    // Whether the most recent `consume` used up at least one byte.
    progress: bool,
}

impl LazyBuffers {
    /// Zero-sized buffers that never allocate.
    ///
    /// This bypasses the size asserts in [`LazyBuffers::new`], which would
    /// otherwise reject sizes of 0.
    pub fn empty() -> Self {
        Self::with_sizes(0, 0)
    }

    /// Create buffers of the given sizes; nothing is allocated yet.
    ///
    /// # Panics
    ///
    /// Panics if either size is 0.
    pub fn new(input_size: usize, output_size: usize) -> Self {
        assert!(input_size > 0);
        assert!(output_size > 0);

        Self::with_sizes(input_size, output_size)
    }

    /// Shared constructor without size validation.
    fn with_sizes(input_size: usize, output_size: usize) -> Self {
        LazyBuffers {
            input_size,
            output_size,

            // Vectors don't allocate until they get a size.
            input: Vec::new(),
            output: Vec::new(),

            filled: 0,
            consumed: 0,

            progress: false,
        }
    }

    /// Grow the vectors to their configured sizes. Never shrinks, so an input
    /// buffer enlarged by `tmp_and_output` keeps its extra space.
    fn ensure_allocation(&mut self) {
        if self.output.len() < self.output_size {
            self.output.resize(self.output_size, 0);
        }
        if self.input.len() < self.input_size {
            self.input.resize(self.input_size, 0);
        }
    }
}

impl Buffers for LazyBuffers {
    fn output(&self) -> &[u8] {
        &self.output
    }

    fn output_mut(&mut self) -> &mut [u8] {
        self.ensure_allocation();
        &mut self.output
    }

    fn input(&self) -> &[u8] {
        &self.input[self.consumed..self.filled]
    }

    fn input_mut(&mut self) -> &mut [u8] {
        self.ensure_allocation();

        // Move the unconsumed bytes to the front when a good part of the
        // buffer is dead space, or when there is no free tail left at all
        // (an empty read slice would look like EOF to the caller).
        let no_free_space = self.filled == self.input.len();
        if self.consumed > 0 && (self.consumed >= self.input_size / 2 || no_free_space) {
            self.input.copy_within(self.consumed..self.filled, 0);
            // Order matters: shrink `filled` by the amount moved *before*
            // resetting `consumed`, or stale bytes are treated as input.
            self.filled -= self.consumed;
            self.consumed = 0;
        }

        &mut self.input[self.filled..]
    }

    fn input_and_output(&mut self) -> (&[u8], &mut [u8]) {
        self.ensure_allocation();
        (
            &self.input[self.consumed..self.filled],
            &mut self.output[..],
        )
    }

    fn tmp_and_output(&mut self) -> (&mut [u8], &mut [u8]) {
        self.ensure_allocation();
        const MIN_TMP_SIZE: usize = 10 * 1024;

        let tmp_available = self.input.len() - self.filled;

        if tmp_available < MIN_TMP_SIZE {
            // The tmp space is used for reading the request body from the
            // Body as a Read. There's an outside chance there isn't any space
            // left in the input buffer if we have done Await100 and the peer
            // started sending a ton of data before we asked for it.
            // It's a pathological situation that we don't need to make work well.
            let needed = MIN_TMP_SIZE - tmp_available;
            self.input.resize(self.input.len() + needed, 0);
        }

        (&mut self.input[self.filled..], &mut self.output[..])
    }

    /// # Panics
    ///
    /// Panics if more bytes are reported than the input buffer can hold.
    fn add_filled(&mut self, amount: usize) {
        assert!(
            amount <= self.input.len() - self.filled,
            "filled more than the input buffer holds"
        );
        self.filled += amount;
    }

    /// # Panics
    ///
    /// Panics if `amount` exceeds the unconsumed input.
    fn consume(&mut self, amount: usize) {
        assert!(
            amount <= self.filled - self.consumed,
            "consume more than unconsumed"
        );

        // Consuming nothing means the reader needs more data than is
        // buffered; `can_use_input` then tells the transport to read more.
        self.progress = amount > 0;

        self.consumed += amount;

        // Everything used up: rewind so the whole buffer is free again.
        if self.consumed == self.filled {
            self.consumed = 0;
            self.filled = 0;
        }
    }

    fn can_use_input(&self) -> bool {
        self.consumed < self.filled && self.progress
    }
}

/// Placeholder used before any connection (and therefore any buffer) exists.
///
/// All slices are empty. The bookkeeping methods are `unreachable!()` since
/// they are only meaningful on a real connection's buffers.
pub struct NoBuffers;

impl Buffers for NoBuffers {
    fn output(&self) -> &[u8] {
        &[]
    }

    fn output_mut(&mut self) -> &mut [u8] {
        &mut []
    }

    fn input(&self) -> &[u8] {
        &[]
    }

    fn input_mut(&mut self) -> &mut [u8] {
        &mut []
    }

    fn input_and_output(&mut self) -> (&[u8], &mut [u8]) {
        let input: &[u8] = &[];
        let output: &mut [u8] = &mut [];
        (input, output)
    }

    fn tmp_and_output(&mut self) -> (&mut [u8], &mut [u8]) {
        let tmp: &mut [u8] = &mut [];
        let output: &mut [u8] = &mut [];
        (tmp, output)
    }

    fn add_filled(&mut self, _amount: usize) {
        unreachable!()
    }

    fn consume(&mut self, _amount: usize) {
        unreachable!()
    }

    fn can_use_input(&self) -> bool {
        unreachable!()
    }
}
TransportAdapter { impl io::Write for TransportAdapter { fn write(&mut self, buf: &[u8]) -> io::Result { - let buffers = self.transport.borrow_buffers(false); + let output = self.transport.buffers().output_mut(); - let max = buf.len().min(buffers.output.len()); - buffers.output[..max].copy_from_slice(&buf[..max]); + let max = buf.len().min(output.len()); + output[..max].copy_from_slice(&buf[..max]); self.transport .transmit_output(max, self.timeout) .map_err(|e| e.into_io())?; diff --git a/src/transport/lazybuf.rs b/src/transport/lazybuf.rs deleted file mode 100644 index 4c2e60ec..00000000 --- a/src/transport/lazybuf.rs +++ /dev/null @@ -1,129 +0,0 @@ -use super::Buffers; - -pub struct LazyBuffers { - input_size: usize, - output_size: usize, - input: Vec, - output: Vec, - - // We have two modes. One where input is filled with some incoming data, - // and one where we can use it freely. These are represented by - // Some/None in this Option respectively. - input_filled: Option, - - // If we have input_filled: Some(value), this is the amount of that value - // we have consumed. - input_consumed: usize, -} - -impl LazyBuffers { - pub fn new(input_size: usize, output_size: usize) -> Self { - assert!(input_size > 0); - assert!(output_size > 0); - - LazyBuffers { - input_size, - output_size, - // Vectors don't allocate until they get a size. - input: vec![], - output: vec![], - - input_filled: None, - input_consumed: 0, - } - } - - /// Borrow the buffers. - /// - /// This allocates first time it's used. - /// - /// The input buffer might be scaled to what's left unconsumed if we are in "fill mode". - pub fn borrow_mut(&mut self, input_as_tmp: bool) -> Buffers<'_> { - if self.input.is_empty() { - self.input.resize(self.input_size, 0); - } - if self.output.is_empty() { - self.output.resize(self.output_size, 0); - } - - if input_as_tmp && self.input_filled.is_some() { - panic!("input used as tmp when filled"); - } - - // Unput is scaled to whatever is unconsumed. 
- let input = if let Some(filled) = self.input_filled { - &mut self.input[self.input_consumed..filled] - } else { - &mut self.input[..] - }; - - Buffers { - input, - output: &mut self.output, - } - } - - /// Query how much input is unconsumed. - pub fn unconsumed(&self) -> usize { - if let Some(filled) = self.input_filled { - filled - .checked_sub(self.input_consumed) - // This is an error condition. Something in the buffer handling - // has consumed more than is possible. - .expect("consumed is greater than filled") - } else { - 0 - } - } - - /// Switch mode to "filled input" by setting how much of the input was filled. - /// - /// There cannot be a previous set_input_filled that hasn't been entirely consumed. - pub fn set_input_filled(&mut self, input_filled: usize) { - // Assert there isn't unconsumed input. - self.assert_and_clear_input_filled(); - - self.input_filled = Some(input_filled); - } - - /// Switch mode to "free input" by unsetting the filled value. This checks the - /// entire input was consumed. - pub fn assert_and_clear_input_filled(&mut self) { - let unconsumed = self.unconsumed(); - - if unconsumed > 0 { - // This is a hard error. It indicates a state bug higher up in ureq. Ignoring - // it would be a security risk because we would silently discard input sent - // by the remote server potentially opening for request smuggling - // attacks etc. - panic!("input contains {} unconsumed bytes", unconsumed); - } - - self.input_filled = None; - self.input_consumed = 0; - } - - /// Mark some input as consumed. - /// - /// This ensure we are in the correct "fill mode" and that there are bytes left to consume. - pub fn consume_input(&mut self, amount: usize) { - // This indicates the order of calls is not correct. We must - // first set_input_fileld(), then consume_input() - assert!( - self.input_filled.is_some(), - "consume without a filled buffer" - ); - - // This indicates some state bug where the caller tries to consume - // more than is filled. 
- if amount > self.unconsumed() { - panic!( - "consume more than unconsumed {} > {}", - amount, - self.unconsumed() - ); - } - - self.input_consumed += amount; - } -} diff --git a/src/transport/mod.rs b/src/transport/mod.rs index cc1781a1..86017d54 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -6,6 +6,7 @@ use http::Uri; use crate::proxy::Proxy; use crate::resolver::Resolver; +use crate::tls::NativeTlsConnector; use crate::{AgentConfig, Error}; #[cfg(feature = "rustls")] @@ -13,8 +14,8 @@ use crate::tls::RustlsConnector; use self::tcp::TcpConnector; -mod lazybuf; -pub use lazybuf::LazyBuffers; +mod buf; +pub use buf::{Buffers, LazyBuffers, NoBuffers}; mod tcp; @@ -48,26 +49,21 @@ pub struct ConnectionDetails<'a> { } pub trait Transport: Debug + Send + Sync { - fn borrow_buffers(&mut self, input_as_tmp: bool) -> Buffers; + fn buffers(&mut self) -> &mut dyn Buffers; fn transmit_output(&mut self, amount: usize, timeout: Duration) -> Result<(), Error>; - fn await_input(&mut self, timeout: Duration) -> Result; + fn await_input(&mut self, timeout: Duration) -> Result<(), Error>; fn consume_input(&mut self, amount: usize); -} - -pub struct Buffers<'a> { - pub input: &'a mut [u8], - pub output: &'a mut [u8], -} - -impl Buffers<'_> { - pub(crate) fn empty() -> Buffers<'static> { - Buffers { - input: &mut [], - output: &mut [], - } + fn is_tls(&self) -> bool { + false } } +// fn buffer_output(&mut self) -> &mut [u8]; +// fn buffer_tmp_and_output(&mut self) -> (&mut [u8], &mut [u8]); +// fn transmit_output(&mut self, amount: usize, timeout: Duration) -> Result<(), Error>; +// fn await_input(&mut self, timeout: Duration) -> Result<&[u8], Error>; +// fn consume_input(&mut self, amount: usize); + #[derive(Debug)] pub struct ChainedConnector { chain: Vec>, @@ -108,6 +104,8 @@ impl DefaultConnector { TcpConnector.boxed(), #[cfg(feature = "rustls")] RustlsConnector::default().boxed(), + #[cfg(feature = "native-tls")] + NativeTlsConnector::default().boxed(), ]); 
DefaultConnector { chain } diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index 3d6b9d1b..bf89d385 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -72,8 +72,8 @@ fn maybe_update_timeout( } impl Transport for TcpTransport { - fn borrow_buffers(&mut self, input_as_tmp: bool) -> Buffers { - self.buffers.borrow_mut(input_as_tmp) + fn buffers(&mut self) -> &mut dyn Buffers { + &mut self.buffers } fn transmit_output(&mut self, amount: usize, timeout: Duration) -> Result<(), Error> { @@ -84,17 +84,15 @@ impl Transport for TcpTransport { TcpStream::set_write_timeout, )?; - let buffers = self.buffers.borrow_mut(false); - let output = &buffers.output[..amount]; + let output = &self.buffers.output()[..amount]; self.stream.write_all(output).normalize_would_block()?; Ok(()) } - fn await_input(&mut self, timeout: Duration) -> Result { - // There might be input left from the previous await_input. - if self.buffers.unconsumed() > 0 { - return Ok(self.buffers.borrow_mut(false)); + fn await_input(&mut self, timeout: Duration) -> Result<(), Error> { + if self.buffers.can_use_input() { + return Ok(()); } // Proceed to fill the buffers from the TcpStream @@ -105,20 +103,15 @@ impl Transport for TcpTransport { TcpStream::set_read_timeout, )?; - // Ensure we get the entire input buffer to write to. - self.buffers.assert_and_clear_input_filled(); + let input = self.buffers.input_mut(); + let amount = self.stream.read(input)?; + self.buffers.add_filled(amount); - let buffers = self.buffers.borrow_mut(false); - let amount = self.stream.read(buffers.input).normalize_would_block()?; - - // Cap the input buffer. 
- self.buffers.set_input_filled(amount); - - Ok(self.buffers.borrow_mut(false)) + Ok(()) } fn consume_input(&mut self, amount: usize) { - self.buffers.consume_input(amount); + self.buffers.consume(amount); } } diff --git a/src/unit.rs b/src/unit.rs index b52916ef..6be3149e 100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -85,7 +85,7 @@ impl<'b> Unit> { }) } - pub fn poll_event(&mut self, now: Instant, buffers: Buffers) -> Result { + pub fn poll_event(&mut self, now: Instant, buffers: &mut dyn Buffers) -> Result { // Queued events go first. if let Some(queued) = self.queued_event.pop_front() { return Ok(queued); @@ -110,19 +110,17 @@ impl<'b> Unit> { // since self.state is not borrowed. fn poll_event_static( &mut self, - buffers: Buffers, + buffers: &mut dyn Buffers, timeout: Duration, ) -> Result>, Error> { - let Buffers { input, output } = buffers; - Ok(match &mut self.state { State::Begin(_) => Some(Event::Reset { must_close: false }), // State::Resolve (see below) // State::OpenConnection (see below) - State::SendRequest(flow) => Some(send_request(flow, output, timeout)?), + State::SendRequest(flow) => Some(send_request(flow, buffers.output_mut(), timeout)?), - State::SendBody(flow) => Some(send_body(flow, input, output, timeout, &mut self.body)?), + State::SendBody(flow) => Some(send_body(flow, buffers, timeout, &mut self.body)?), State::Await100(_) => Some(Event::Await100 { timeout }), @@ -339,17 +337,11 @@ impl<'b> Unit> { redirect_count: self.redirect_count, } } - - // When we are doing SendBody, we user Buffers::input as a temporary scratch space - // for reading from the Body<'a> (as a reader) to write the output. - pub fn need_input_as_tmp(&self) -> bool { - matches!(self.state, State::SendBody(_)) - } } // Unit<()> is for receiving the body. We have let go of the input body. impl Unit<()> { - pub fn poll_event(&mut self, now: Instant, _buffers: Buffers) -> Result { + pub fn poll_event(&mut self, now: Instant) -> Result { // Queued events go first. 
if let Some(queued) = self.queued_event.pop_front() { return Ok(queued); @@ -459,12 +451,13 @@ fn send_request( fn send_body( flow: &mut Flow, - input: &mut [u8], - output: &mut [u8], + buffers: &mut dyn Buffers, timeout: Duration, body: &mut Body, ) -> Result, Error> { - let input_len = input.len(); + let (tmp, output) = buffers.tmp_and_output(); + + let input_len = tmp.len(); let overhead = flow.calculate_output_overhead(output.len())?; assert!(input_len > overhead); @@ -481,10 +474,10 @@ fn send_body( output_used } else { - let input = &mut input[..max_input]; - let n = body.read(input)?; + let tmp = &mut tmp[..max_input]; + let n = body.read(tmp)?; - let (input_used, output_used) = flow.write(&input[..n], output)?; + let (input_used, output_used) = flow.write(&tmp[..n], output)?; // Since output is "a bit" larger than the input (compensate for chunk ovherhead), // the entire input we read from the body should also be shipped to the output.