Skip to content

Commit

Permalink
make the input type + the output GAT unnameable in `connector_layer()
Browse files Browse the repository at this point in the history
  • Loading branch information
jlizen committed Dec 20, 2024
1 parent 3b6a3c8 commit f4a0412
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 81 deletions.
8 changes: 5 additions & 3 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ use crate::async_impl::h3_client::connect::H3Connector;
#[cfg(feature = "http3")]
use crate::async_impl::h3_client::{H3Client, H3ResponseFuture};
use crate::connect::{
BoxedConnectorLayer, BoxedConnectorService, Conn, Connector, ConnectorBuilder,
sealed::{Conn, Unnameable},
BoxedConnectorLayer, BoxedConnectorService, Connector, ConnectorBuilder,
};
#[cfg(feature = "cookies")]
use crate::cookie;
Expand Down Expand Up @@ -1987,8 +1988,9 @@ impl ClientBuilder {
pub fn connector_layer<L>(mut self, layer: L) -> ClientBuilder
where
L: Layer<BoxedConnectorService> + Clone + Send + Sync + 'static,
L::Service: Service<Uri, Response = Conn, Error = BoxError> + Clone + Send + Sync + 'static,
<L::Service as Service<Uri>>::Future: Send + 'static,
L::Service:
Service<Unnameable, Response = Conn, Error = BoxError> + Clone + Send + Sync + 'static,
<L::Service as Service<Unnameable>>::Future: Send + 'static,
{
let layer = BoxCloneSyncServiceLayer::new(layer);

Expand Down
8 changes: 4 additions & 4 deletions src/blocking/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use std::thread;
use std::time::Duration;

use http::header::HeaderValue;
use http::Uri;
use log::{error, trace};
use tokio::sync::{mpsc, oneshot};
use tower::Layer;
Expand All @@ -19,8 +18,8 @@ use tower::Service;
use super::request::{Request, RequestBuilder};
use super::response::Response;
use super::wait;
use crate::connect::sealed::{Conn, Unnameable};
use crate::connect::BoxedConnectorService;
use crate::connect::Conn;
use crate::dns::Resolve;
use crate::error::BoxError;
#[cfg(feature = "__tls")]
Expand Down Expand Up @@ -998,8 +997,9 @@ impl ClientBuilder {
pub fn connector_layer<L>(self, layer: L) -> ClientBuilder
where
L: Layer<BoxedConnectorService> + Clone + Send + Sync + 'static,
L::Service: Service<Uri, Response = Conn, Error = BoxError> + Clone + Send + Sync + 'static,
<L::Service as Service<Uri>>::Future: Send + 'static,
L::Service:
Service<Unnameable, Response = Conn, Error = BoxError> + Clone + Send + Sync + 'static,
<L::Service as Service<Unnameable>>::Future: Send + 'static,
{
self.with_inner(|inner| inner.connector_layer(layer))
}
Expand Down
152 changes: 78 additions & 74 deletions src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ use hyper_util::client::legacy::connect::{Connected, Connection};
use hyper_util::rt::TokioIo;
#[cfg(feature = "default-tls")]
use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
use tower::util::BoxCloneSyncServiceLayer;
use pin_project_lite::pin_project;
use tower::util::{BoxCloneSyncServiceLayer, MapRequestLayer};
use tower::{timeout::TimeoutLayer, util::BoxCloneSyncService, ServiceBuilder};
use tower_service::Service;

use pin_project_lite::pin_project;
use std::future::Future;
use std::io::{self, IoSlice};
use std::net::IpAddr;
Expand All @@ -28,6 +28,7 @@ use self::rustls_tls_conn::RustlsTlsConn;
use crate::dns::DynResolver;
use crate::error::{cast_to_internal_error, BoxError};
use crate::proxy::{Proxy, ProxyScheme};
use sealed::{Conn, Unnameable};

pub(crate) type HttpConnector = hyper_util::client::legacy::connect::HttpConnector<DynResolver>;

Expand All @@ -37,7 +38,7 @@ pub(crate) enum Connector {
Simple(ConnectorService),
// at least one custom layer along with maybe an outer timeout layer
// from `builder.connect_timeout()`
WithLayers(BoxCloneSyncService<Uri, Conn, BoxError>),
WithLayers(BoxCloneSyncService<Unnameable, Conn, BoxError>),
}

impl Service<Uri> for Connector {
Expand All @@ -55,15 +56,15 @@ impl Service<Uri> for Connector {
fn call(&mut self, dst: Uri) -> Self::Future {
match self {
Connector::Simple(service) => service.call(dst),
Connector::WithLayers(service) => service.call(dst),
Connector::WithLayers(service) => service.call(Unnameable(dst)),
}
}
}

pub(crate) type BoxedConnectorService = BoxCloneSyncService<Uri, Conn, BoxError>;
pub(crate) type BoxedConnectorService = BoxCloneSyncService<Unnameable, Conn, BoxError>;

pub(crate) type BoxedConnectorLayer =
BoxCloneSyncServiceLayer<BoxedConnectorService, Uri, Conn, BoxError>;
BoxCloneSyncServiceLayer<BoxedConnectorService, Unnameable, Conn, BoxError>;

pub(crate) struct ConnectorBuilder {
inner: Inner,
Expand Down Expand Up @@ -103,8 +104,11 @@ where {

// otherwise we have user provided layers
// so we need type erasure all the way through

let mut service = BoxCloneSyncService::new(base_service);
// as well as mapping the unnameable type of the layers back to Uri for the inner service
let unnameable_service = ServiceBuilder::new()
.layer(MapRequestLayer::new(|request: Unnameable| request.0))
.service(base_service);
let mut service = BoxCloneSyncService::new(nameable_service);

for layer in layers {
service = ServiceBuilder::new().layer(layer).service(service);
Expand Down Expand Up @@ -749,87 +753,87 @@ impl<T: AsyncConn> AsyncConnWithInfo for T {}

type BoxConn = Box<dyn AsyncConnWithInfo>;

pin_project! {
/// Note: the `is_proxy` member means *is plain text HTTP proxy*.
/// This tells hyper whether the URI should be written in
/// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or
/// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise.
// Currently Conn is public but has no implementation details exposed.
// We need this because we support pre-1.74 rust versions where `private-in-public`
// is a hard error rather than lint warning.
// Eventually we probably will want to expose some elements of the connection stream
// for layer handling on the tower backswing.
#[allow(missing_debug_implementations)]
pub struct Conn {
#[pin]
inner: BoxConn,
is_proxy: bool,
// Only needed for __tls, but #[cfg()] on fields breaks pin_project!
tls_info: bool,
pub(crate) mod sealed {
use super::*;
#[derive(Debug, Clone)]
pub struct Unnameable(pub(super) Uri);

pin_project! {
/// Note: the `is_proxy` member means *is plain text HTTP proxy*.
/// This tells hyper whether the URI should be written in
/// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or
/// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise.
#[allow(missing_debug_implementations)]
pub struct Conn {
#[pin]
pub(super)inner: BoxConn,
pub(super) is_proxy: bool,
// Only needed for __tls, but #[cfg()] on fields breaks pin_project!
pub(super) tls_info: bool,
}
}
}

impl Connection for Conn {
fn connected(&self) -> Connected {
let connected = self.inner.connected().proxy(self.is_proxy);
#[cfg(feature = "__tls")]
if self.tls_info {
if let Some(tls_info) = self.inner.tls_info() {
connected.extra(tls_info)
impl Connection for Conn {
fn connected(&self) -> Connected {
let connected = self.inner.connected().proxy(self.is_proxy);
#[cfg(feature = "__tls")]
if self.tls_info {
if let Some(tls_info) = self.inner.tls_info() {
connected.extra(tls_info)
} else {
connected
}
} else {
connected
}
} else {
#[cfg(not(feature = "__tls"))]
connected
}
#[cfg(not(feature = "__tls"))]
connected
}
}

impl Read for Conn {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: ReadBufCursor<'_>,
) -> Poll<io::Result<()>> {
let this = self.project();
Read::poll_read(this.inner, cx, buf)
impl Read for Conn {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: ReadBufCursor<'_>,
) -> Poll<io::Result<()>> {
let this = self.project();
Read::poll_read(this.inner, cx, buf)
}
}
}

impl Write for Conn {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
let this = self.project();
Write::poll_write(this.inner, cx, buf)
}
impl Write for Conn {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
let this = self.project();
Write::poll_write(this.inner, cx, buf)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
let this = self.project();
Write::poll_write_vectored(this.inner, cx, bufs)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
let this = self.project();
Write::poll_write_vectored(this.inner, cx, bufs)
}

fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
let this = self.project();
Write::poll_flush(this.inner, cx)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
let this = self.project();
Write::poll_flush(this.inner, cx)
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
let this = self.project();
Write::poll_shutdown(this.inner, cx)
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
let this = self.project();
Write::poll_shutdown(this.inner, cx)
}
}
}

Expand Down

0 comments on commit f4a0412

Please sign in to comment.