Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor connect network request extension #210

Merged
merged 6 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 79 additions & 50 deletions src/client/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::time::Duration;
use std::{collections::HashMap, convert::TryInto, net::SocketAddr};
use std::{fmt, str};

use crate::util::client::{InnerRequest, NetworkScheme};
use crate::util::{
self, client::connect::HttpConnector, client::Builder, common::Exec, rt::TokioExecutor,
};
Expand All @@ -24,7 +25,7 @@ use std::task::{Context, Poll};
use tokio::time::Sleep;

use super::decoder::Accepts;
use super::request::{InnerRequest, Request, RequestBuilder};
use super::request::{Request, RequestBuilder};
use super::response::Response;
use super::Body;
use crate::connect::Connector;
Expand Down Expand Up @@ -204,8 +205,6 @@ impl ClientBuilder {
if config.auto_sys_proxy {
proxies.push(Proxy::system());
}
let proxies = Arc::new(proxies);

let proxies_maybe_http_auth = proxies.iter().any(|p| p.maybe_has_http_auth());

let mut connector = {
Expand Down Expand Up @@ -233,17 +232,7 @@ impl ClientBuilder {
http.set_connect_timeout(config.connect_timeout);

let tls = BoringTlsConnector::new(config.tls)?;
Connector::new_boring_tls(
http,
tls,
proxies,
config.local_address_ipv4,
config.local_address_ipv6,
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
config.interface,
config.nodelay,
config.tls_info,
)
Connector::new_boring_tls(http, tls, config.nodelay, config.tls_info)
};

connector.set_timeout(config.connect_timeout);
Expand Down Expand Up @@ -272,6 +261,12 @@ impl ClientBuilder {
proxies_maybe_http_auth,
base_url: config.base_url.map(Arc::new),
http2_max_retry_count: config.http2_max_retry_count,

proxies,
local_addr_v4: config.local_address_ipv4,
local_addr_v6: config.local_address_ipv6,
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
interface: config.interface,
}),
})
}
Expand Down Expand Up @@ -1303,15 +1298,14 @@ impl Client {
self.inner.proxy_auth(&uri, &mut headers);

let in_flight = {
let pool_key = self.inner.hyper.pool_key(&uri);
let req = InnerRequest::new()
let req = InnerRequest::<Body>::builder()
.network_scheme(self.inner.network_scheme(&uri, None))
.uri(uri)
.method(method.clone())
.version(version)
.headers(headers.clone())
.headers_order(self.inner.headers_order.as_deref())
.pool_key(pool_key)
.build(body);
.body(body);

ResponseFuture::Default(self.inner.hyper.request(req))
};
Expand Down Expand Up @@ -1411,33 +1405,31 @@ impl Client {
///
/// Returns the old proxies.
#[inline]
pub fn set_proxies(&mut self, proxies: impl Into<Cow<'static, [Proxy]>>) -> Vec<Proxy> {
let (inner, proxies) = self.apply_proxies(proxies);
inner.set_proxies(proxies)
pub fn set_proxies(&mut self, proxies: impl Into<Cow<'static, [Proxy]>>) {
self.apply_proxies(proxies, true);
}

/// Append the proxies to the client.
#[inline]
pub fn append_proxies(&mut self, proxies: impl Into<Cow<'static, [Proxy]>>) {
let (inner, proxies) = self.apply_proxies(proxies);
inner.append_proxies(proxies);
}

/// Private helper to handle setting or appending proxies.
fn apply_proxies(
&mut self,
proxies: impl Into<Cow<'static, [Proxy]>>,
) -> (&mut HyperClient, Cow<'static, [Proxy]>) {
let proxies = proxies.into();
let inner = self.inner_mut();
inner.proxies_maybe_http_auth = proxies.iter().any(|p| p.maybe_has_http_auth());
(&mut inner.hyper, proxies)
self.apply_proxies(proxies, false);
}

/// Unset the proxies for this client.
#[inline]
pub fn unset_proxies(&mut self) {
self.inner_mut().hyper.clear_proxies();
self.inner_mut().proxies.clear();
}

/// Private helper to handle setting or appending proxies.
fn apply_proxies(&mut self, proxies: impl Into<Cow<'static, [Proxy]>>, r#override: bool) {
let inner = self.inner_mut();
let proxies = proxies.into();
inner.proxies_maybe_http_auth = proxies.iter().any(|p| p.maybe_has_http_auth());
if r#override {
inner.proxies.clear();
}
inner.proxies.extend(proxies.into_owned());
}

/// Set that all sockets are bound to the configured address before connection.
Expand All @@ -1450,23 +1442,28 @@ impl Client {
where
T: Into<Option<IpAddr>>,
{
self.inner_mut().hyper.set_local_address(addr.into());
let inner = self.inner_mut();
match addr.into() {
Some(IpAddr::V4(a)) => inner.local_addr_v4 = Some(a),
Some(IpAddr::V6(a)) => inner.local_addr_v6 = Some(a),
_ => (),
}
}

/// Set that all sockets are bound to the configured IPv4 or IPv6 address
/// (depending on host's preferences) before connection.
#[inline]
pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
self.inner_mut()
.hyper
.set_local_addresses(addr_ipv4, addr_ipv6);
let inner = self.inner_mut();
inner.local_addr_v4 = Some(addr_ipv4);
inner.local_addr_v6 = Some(addr_ipv6);
}

/// Bind to an interface by `SO_BINDTODEVICE`.
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
#[inline]
pub fn set_interface(&mut self, interface: impl Into<std::borrow::Cow<'static, str>>) {
self.inner_mut().hyper.set_interface(interface.into());
self.inner_mut().interface = Some(interface.into());
}

/// Set the headers order for this client.
Expand Down Expand Up @@ -1676,6 +1673,12 @@ struct ClientRef {
proxies_maybe_http_auth: bool,
base_url: Option<Arc<Url>>,
http2_max_retry_count: usize,

proxies: Vec<Proxy>,
local_addr_v4: Option<Ipv4Addr>,
local_addr_v6: Option<Ipv6Addr>,
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
interface: Option<std::borrow::Cow<'static, str>>,
}

impl_debug!(
Expand Down Expand Up @@ -1712,15 +1715,43 @@ impl ClientRef {
// Find the first proxy that matches the destination URI
// If a matching proxy provides an HTTP basic auth header, insert it into the headers
if let Some(header) = self
.hyper
.get_proxies()
.proxies
.iter()
.find(|proxy| proxy.maybe_has_http_auth() && proxy.is_match(dst))
.and_then(|proxy| proxy.http_basic_auth(dst))
{
headers.insert(PROXY_AUTHORIZATION, header);
}
}

#[inline]
fn network_scheme(&self, uri: &Uri, request_proxy: Option<Proxy>) -> NetworkScheme {
// If the request has no proxy, use the client's local addresses
#[cfg(not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))]
let mut builder = NetworkScheme::builder().iface((self.local_addr_v4, self.local_addr_v6));

// Use the client's interface if it's set
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
let mut builder = NetworkScheme::builder().iface((
self.interface.clone(),
self.local_addr_v4,
self.local_addr_v6,
));

// If the request has a proxy, use it
if let Some(proxy_scheme) = request_proxy.and_then(|p| p.intercept(uri)) {
builder = builder.proxy(proxy_scheme);
} else {
// Otherwise, use the client's proxies
for proxy in self.proxies.iter() {
if let Some(proxy_scheme) = proxy.intercept(uri) {
return builder.proxy(proxy_scheme).build();
}
}
}

builder.build()
}
}

pin_project! {
Expand Down Expand Up @@ -1821,15 +1852,14 @@ impl PendingRequest {
};

*self.as_mut().in_flight().get_mut() = {
let pool_key = self.client.hyper.pool_key(&uri);
let req = InnerRequest::new()
let req = InnerRequest::<Body>::builder()
.network_scheme(self.client.network_scheme(&uri, None))
.uri(uri)
.method(self.method.clone())
.version(self.version)
.headers(self.headers.clone())
.headers_order(self.client.headers_order.as_deref())
.pool_key(pool_key)
.build(body);
.body(body);
ResponseFuture::Default(self.client.hyper.request(req))
};

Expand Down Expand Up @@ -2068,15 +2098,14 @@ impl Future for PendingRequest {
self.client.proxy_auth(&uri, &mut headers);

*self.as_mut().in_flight().get_mut() = {
let pool_key = self.client.hyper.pool_key(&uri);
let req = InnerRequest::new()
let req = InnerRequest::<Body>::builder()
.network_scheme(self.client.network_scheme(&uri, None))
.uri(uri)
.method(self.method.clone())
.version(self.version)
.headers(headers.clone())
.headers_order(self.client.headers_order.as_deref())
.pool_key(pool_key)
.build(body);
.body(body);
std::mem::swap(self.as_mut().headers(), &mut headers);
ResponseFuture::Default(self.client.hyper.request(req))
};
Expand Down
112 changes: 1 addition & 111 deletions src/client/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ use std::fmt;
use std::future::Future;
use std::time::Duration;

use http::header::CONTENT_LENGTH;
use http::{request::Parts, Request as HttpRequest, Uri, Version};
use http::{request::Parts, Request as HttpRequest, Version};
use serde::Serialize;
#[cfg(feature = "json")]
use serde_json;
Expand All @@ -14,11 +13,9 @@ use super::http::{Client, Pending};
#[cfg(feature = "multipart")]
use super::multipart;
use super::response::Response;
use super::HttpVersionPref;
#[cfg(feature = "cookies")]
use crate::cookie;
use crate::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE, HOST};
use crate::util::ext::{ConnectExtension, PoolKeyExtension, VersionExtension};
use crate::{redirect, Method, Url};
#[cfg(feature = "cookies")]
use std::sync::Arc;
Expand Down Expand Up @@ -706,110 +703,3 @@ impl TryFrom<Request> for HttpRequest<Body> {
Ok(req)
}
}

/// A builder for constructing HTTP requests.
pub(crate) struct InnerRequest<'a> {
builder: http::request::Builder,
headers_order: Option<&'a [HeaderName]>,
}

impl<'a> InnerRequest<'a> {
/// Create a new `RequestBuilder` with required fields.
#[inline]
pub fn new() -> Self {
Self {
builder: hyper2::Request::builder(),
headers_order: None,
}
}

/// Set the method for the request.
#[inline]
pub fn method(mut self, method: Method) -> Self {
self.builder = self.builder.method(method);
self
}

/// Set the URI for the request.
#[inline]
pub fn uri(mut self, uri: Uri) -> Self {
self.builder = self.builder.uri(uri);
self
}

/// Set the version for the request.
#[inline]
pub fn version(mut self, version: Option<Version>) -> Self {
if let Some(version) = version {
match version {
Version::HTTP_11 | Version::HTTP_10 | Version::HTTP_09 => {
self.builder = self
.builder
.extension(ConnectExtension(VersionExtension(HttpVersionPref::Http1)));
}
Version::HTTP_2 => {
self.builder = self
.builder
.extension(ConnectExtension(VersionExtension(HttpVersionPref::Http2)));
}
_ => {}
};
}
self
}

/// Set the headers for the request.
#[inline]
pub fn headers(mut self, mut headers: HeaderMap) -> Self {
if let Some(h) = self.builder.headers_mut() {
std::mem::swap(h, &mut headers);
}
self
}

/// Set the headers order for the request.
#[inline]
pub fn headers_order(mut self, order: Option<&'a [HeaderName]>) -> Self {
self.headers_order = order;
self
}

/// Set an pool key extension for the request.
#[inline]
pub fn pool_key(mut self, pool_key: Option<PoolKeyExtension>) -> Self {
if let Some(pool_key) = pool_key {
self.builder = self.builder.extension(ConnectExtension(pool_key));
}
self
}

/// Build and return the constructed request.
pub fn build(mut self, body: Body) -> http::Request<Body> {
// Sort headers if headers_order is provided
if let Some(order) = self.headers_order {
let method = self.builder.method_ref().cloned();
let headers_mut = self.builder.headers_mut();
if let (Some(headers), Some(method)) = (headers_mut, method) {
{
// Add CONTENT_LENGTH header if required
if let Some(len) = http_body::Body::size_hint(&body).exact() {
let needs_content_length = len != 0
|| !matches!(
method,
Method::GET | Method::HEAD | Method::DELETE | Method::CONNECT
);
if needs_content_length {
headers
.entry(CONTENT_LENGTH)
.or_insert_with(|| HeaderValue::from(len));
}
}
// Sort headers
crate::util::sort_headers(headers, order);
}
}
}

self.builder.body(body).expect("valid request parts")
}
}
Loading