Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(client): strip path from Uri before calling Connector #2109

Merged
merged 1 commit into from
Jan 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ serde_derive = "1.0"
serde_json = "1.0"
tokio = { version = "0.2.2", features = ["fs", "macros", "io-std", "rt-util", "sync", "time", "test-util"] }
tokio-test = "0.2"
tower-util = "0.3"
url = "1.0"

[features]
Expand Down
36 changes: 21 additions & 15 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@

use std::fmt;
use std::mem;
use std::sync::Arc;
use std::time::Duration;

use futures_channel::oneshot;
Expand Down Expand Up @@ -230,14 +229,13 @@ where
other => return ResponseFuture::error_version(other),
};

let domain = match extract_domain(req.uri_mut(), is_http_connect) {
let pool_key = match extract_domain(req.uri_mut(), is_http_connect) {
Ok(s) => s,
Err(err) => {
return ResponseFuture::new(Box::new(future::err(err)));
}
};

let pool_key = Arc::new(domain);
ResponseFuture::new(Box::new(self.retryably_send_request(req, pool_key)))
}

Expand Down Expand Up @@ -281,7 +279,7 @@ where
mut req: Request<B>,
pool_key: PoolKey,
) -> impl Future<Output = Result<Response<Body>, ClientError<B>>> + Unpin {
let conn = self.connection_for(req.uri().clone(), pool_key);
let conn = self.connection_for(pool_key);

let set_host = self.config.set_host;
let executor = self.conn_builder.exec.clone();
Expand Down Expand Up @@ -377,7 +375,6 @@ where

fn connection_for(
&self,
uri: Uri,
pool_key: PoolKey,
) -> impl Future<Output = Result<Pooled<PoolClient<B>>, ClientError<B>>> {
// This actually races 2 different futures to try to get a ready
Expand All @@ -394,7 +391,7 @@ where
// connection future is spawned into the runtime to complete,
// and then be inserted into the pool as an idle connection.
let checkout = self.pool.checkout(pool_key.clone());
let connect = self.connect_to(uri, pool_key);
let connect = self.connect_to(pool_key);

let executor = self.conn_builder.exec.clone();
// The order of the `select` is depended on below...
Expand Down Expand Up @@ -455,7 +452,6 @@ where

fn connect_to(
&self,
uri: Uri,
pool_key: PoolKey,
) -> impl Lazy<Output = crate::Result<Pooled<PoolClient<B>>>> + Unpin {
let executor = self.conn_builder.exec.clone();
Expand All @@ -464,7 +460,7 @@ where
let ver = self.config.ver;
let is_ver_h2 = ver == Ver::Http2;
let connector = self.connector.clone();
let dst = uri;
let dst = domain_as_uri(pool_key.clone());
hyper_lazy(move || {
// Try to take a "connecting lock".
//
Expand Down Expand Up @@ -794,22 +790,22 @@ fn authority_form(uri: &mut Uri) {
};
}

fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> crate::Result<String> {
fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> crate::Result<PoolKey> {
let uri_clone = uri.clone();
match (uri_clone.scheme(), uri_clone.authority()) {
(Some(scheme), Some(auth)) => Ok(format!("{}://{}", scheme, auth)),
(Some(scheme), Some(auth)) => Ok((scheme.clone(), auth.clone())),
(None, Some(auth)) if is_http_connect => {
let scheme = match auth.port_u16() {
Some(443) => {
set_scheme(uri, Scheme::HTTPS);
"https"
Scheme::HTTPS
}
_ => {
set_scheme(uri, Scheme::HTTP);
"http"
Scheme::HTTP
}
};
Ok(format!("{}://{}", scheme, auth))
Ok((scheme, auth.clone()))
}
_ => {
debug!("Client requires absolute-form URIs, received: {:?}", uri);
Expand All @@ -818,6 +814,15 @@ fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> crate::Result<String>
}
}

fn domain_as_uri((scheme, auth): PoolKey) -> Uri {
http::uri::Builder::new()
.scheme(scheme)
.authority(auth)
.path_and_query("/")
.build()
.expect("domain is valid Uri")
}

fn set_scheme(uri: &mut Uri, scheme: Scheme) {
debug_assert!(
uri.scheme().is_none(),
Expand Down Expand Up @@ -1126,7 +1131,8 @@ mod unit_tests {
#[test]
fn test_extract_domain_connect_no_port() {
let mut uri = "hyper.rs".parse().unwrap();
let domain = extract_domain(&mut uri, true).expect("extract domain");
assert_eq!(domain, "http://hyper.rs");
let (scheme, host) = extract_domain(&mut uri, true).expect("extract domain");
assert_eq!(scheme, *"http");
assert_eq!(host, "hyper.rs");
}
}
23 changes: 13 additions & 10 deletions src/client/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub(super) enum Reservation<T> {
}

/// Simple type alias in case the key type needs to be adjusted.
pub(super) type Key = Arc<String>;
pub(super) type Key = (http::uri::Scheme, http::uri::Authority); //Arc<String>;

struct PoolInner<T> {
// A flag that a connection is being established, and the connection
Expand Down Expand Up @@ -755,7 +755,6 @@ impl<T> WeakOpt<T> {

#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;

Expand Down Expand Up @@ -787,6 +786,10 @@ mod tests {
}
}

fn host_key(s: &str) -> Key {
(http::uri::Scheme::HTTP, s.parse().expect("host key"))
}

fn pool_no_timer<T>() -> Pool<T> {
pool_max_idle_no_timer(::std::usize::MAX)
}
Expand All @@ -807,7 +810,7 @@ mod tests {
#[tokio::test]
async fn test_pool_checkout_smoke() {
let pool = pool_no_timer();
let key = Arc::new("foo".to_string());
let key = host_key("foo");
let pooled = pool.pooled(c(key.clone()), Uniq(41));

drop(pooled);
Expand Down Expand Up @@ -839,7 +842,7 @@ mod tests {
#[tokio::test]
async fn test_pool_checkout_returns_none_if_expired() {
let pool = pool_no_timer();
let key = Arc::new("foo".to_string());
let key = host_key("foo");
let pooled = pool.pooled(c(key.clone()), Uniq(41));

drop(pooled);
Expand All @@ -854,7 +857,7 @@ mod tests {
#[tokio::test]
async fn test_pool_checkout_removes_expired() {
let pool = pool_no_timer();
let key = Arc::new("foo".to_string());
let key = host_key("foo");

pool.pooled(c(key.clone()), Uniq(41));
pool.pooled(c(key.clone()), Uniq(5));
Expand All @@ -876,7 +879,7 @@ mod tests {
#[test]
fn test_pool_max_idle_per_host() {
let pool = pool_max_idle_no_timer(2);
let key = Arc::new("foo".to_string());
let key = host_key("foo");

pool.pooled(c(key.clone()), Uniq(41));
pool.pooled(c(key.clone()), Uniq(5));
Expand Down Expand Up @@ -904,7 +907,7 @@ mod tests {
&Exec::Default,
);

let key = Arc::new("foo".to_string());
let key = host_key("foo");

pool.pooled(c(key.clone()), Uniq(41));
pool.pooled(c(key.clone()), Uniq(5));
Expand All @@ -929,7 +932,7 @@ mod tests {
use futures_util::FutureExt;

let pool = pool_no_timer();
let key = Arc::new("foo".to_string());
let key = host_key("foo");
let pooled = pool.pooled(c(key.clone()), Uniq(41));

let checkout = join(pool.checkout(key), async {
Expand All @@ -948,7 +951,7 @@ mod tests {
#[tokio::test]
async fn test_pool_checkout_drop_cleans_up_waiters() {
let pool = pool_no_timer::<Uniq<i32>>();
let key = Arc::new("localhost:12345".to_string());
let key = host_key("foo");

let mut checkout1 = pool.checkout(key.clone());
let mut checkout2 = pool.checkout(key.clone());
Expand Down Expand Up @@ -993,7 +996,7 @@ mod tests {
#[test]
fn pooled_drop_if_closed_doesnt_reinsert() {
let pool = pool_no_timer();
let key = Arc::new("localhost:12345".to_string());
let key = host_key("foo");
pool.pooled(
c(key.clone()),
CanClose {
Expand Down
33 changes: 24 additions & 9 deletions src/client/tests.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
// FIXME: re-implement tests with `async/await`
/*
#![cfg(feature = "runtime")]
use std::io;

use futures_util::future;
use tokio::net::TcpStream;

use futures::{Async, Future, Stream};
use futures::future::poll_fn;
use futures::sync::oneshot;
use tokio::runtime::current_thread::Runtime;
use super::Client;

use crate::mock::MockConnector;
use super::*;
#[tokio::test]
async fn client_connect_uri_argument() {
let connector = tower_util::service_fn(|dst: http::Uri| {
assert_eq!(dst.scheme(), Some(&http::uri::Scheme::HTTP));
assert_eq!(dst.host(), Some("example.local"));
assert_eq!(dst.port(), None);
assert_eq!(dst.path(), "/", "path should be removed");

future::err::<TcpStream, _>(io::Error::new(io::ErrorKind::Other, "expect me"))
});

let client = Client::builder().build::<_, crate::Body>(connector);
let _ = client
.get("http://example.local/and/a/path".parse().unwrap())
.await
.expect_err("response should fail");
}

/*
// FIXME: re-implement tests with `async/await`
#[test]
fn retryable_request() {
let _ = pretty_env_logger::try_init();
Expand Down