diff --git a/src/client.rs b/src/client.rs index ae1cfb479..16a239354 100644 --- a/src/client.rs +++ b/src/client.rs @@ -16,7 +16,7 @@ use serde_json; use serde_urlencoded; use ::body::{self, Body}; -use ::redirect::{RedirectPolicy, check_redirect}; +use ::redirect::{self, RedirectPolicy, check_redirect}; use ::response::Response; static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); @@ -36,7 +36,7 @@ pub struct Client { impl Client { /// Constructs a new `Client`. pub fn new() -> ::Result { - let mut client = try!(new_hyper_client()); + let mut client = try_!(new_hyper_client()); client.set_redirect_policy(::hyper::client::RedirectPolicy::FollowNone); Ok(Client { inner: Arc::new(ClientRef { @@ -133,7 +133,7 @@ fn new_hyper_client() -> ::Result<::hyper::Client> { ::hyper::client::Pool::with_connector( Default::default(), ::hyper::net::HttpsConnector::new( - try!(NativeTlsClient::new() + try_!(NativeTlsClient::new() .map_err(|e| ::hyper::Error::Ssl(Box::new(e))))) ) )) @@ -198,7 +198,7 @@ impl RequestBuilder { /// .send(); /// ``` pub fn form(mut self, form: &T) -> RequestBuilder { - let body = serde_urlencoded::to_string(form).map_err(::Error::from); + let body = serde_urlencoded::to_string(form).map_err(::error::from); self.headers.set(ContentType::form_url_encoded()); self.body = Some(body.map(|b| b.into())); self @@ -242,10 +242,10 @@ impl RequestBuilder { } let client = self.client; let mut method = self.method; - let mut url = try!(self.url); + let mut url = try_!(self.url); let mut headers = self.headers; let mut body = match self.body { - Some(b) => Some(try!(b)), + Some(b) => Some(try_!(b)), None => None, }; @@ -263,7 +263,7 @@ impl RequestBuilder { req = req.body(body); } - try!(req.send()) + try_!(req.send(), &url) }; let should_redirect = match res.status { @@ -304,12 +304,20 @@ impl RequestBuilder { Ok(loc) => { headers.set(Referer(url.to_string())); urls.push(url); - if check_redirect(&client.redirect_policy.lock().unwrap(), &loc, &urls)? { - loc - } else { - debug!("redirect_policy disallowed redirection to '{}'", loc); - - return Ok(::response::new(res, client.auto_ungzip.load(Ordering::Relaxed))); + let action = check_redirect(&client.redirect_policy.lock().unwrap(), &loc, &urls); + match action { + redirect::Action::Follow => loc, + redirect::Action::Stop => { + debug!("redirect_policy disallowed redirection to '{}'", loc); + + return Ok(::response::new(res, client.auto_ungzip.load(Ordering::Relaxed))); + }, + redirect::Action::LoopDetected => { + return Err(::error::loop_detected(res.url.clone())); + }, + redirect::Action::TooManyRedirects => { + return Err(::error::too_many_redirects(res.url.clone())); + } } }, Err(e) => { diff --git a/src/error.rs b/src/error.rs index ced68f13e..46a20073f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,87 +1,191 @@ use std::error::Error as StdError; use std::fmt; +use ::Url; + /// The Errors that may occur when processing a `Request`. #[derive(Debug)] -pub enum Error { - /// An HTTP error from the `hyper` crate. - Http(::hyper::Error), - /// An error trying to serialize a value. - /// - /// This may be serializing a value that is illegal in JSON or - /// form-url-encoded bodies. - Serialize(Box), - /// A request tried to redirect too many times. - TooManyRedirects, - /// An infinite redirect loop was detected. - RedirectLoop, - #[doc(hidden)] - __DontMatchMe, +pub struct Error { + kind: Kind, + url: Option, +} + +/// A `Result` alias where the `Err` case is `reqwest::Error`. +pub type Result = ::std::result::Result; + +impl Error { + /// Returns a possible URL related to this error. + #[inline] + pub fn url(&self) -> Option<&Url> { + self.url.as_ref() + } + + /// Returns true if the error is related to HTTP. + #[inline] + pub fn is_http(&self) -> bool { + match self.kind { + Kind::Http(_) => true, + _ => false, + } + } + + /// Returns true if the error is serialization related. + #[inline] + pub fn is_serialization(&self) -> bool { + match self.kind { + Kind::Json(_) | + Kind::UrlEncoded(_) => true, + _ => false, + } + } + + /// Returns true if the error is from a `RedirectPolicy`. + #[inline] + pub fn is_redirect(&self) -> bool { + match self.kind { + Kind::TooManyRedirects | + Kind::RedirectLoop => true, + _ => false, + } + } } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Error::Http(ref e) => fmt::Display::fmt(e, f), - Error::Serialize(ref e) => fmt::Display::fmt(e, f), - Error::TooManyRedirects => f.pad("Too many redirects"), - Error::RedirectLoop => f.pad("Infinite redirect loop"), - Error::__DontMatchMe => unreachable!() + if let Some(ref url) = self.url { + try!(fmt::Display::fmt(url, f)); + try!(f.write_str(": ")); + } + match self.kind { + Kind::Http(ref e) => fmt::Display::fmt(e, f), + Kind::UrlEncoded(ref e) => fmt::Display::fmt(e, f), + Kind::Json(ref e) => fmt::Display::fmt(e, f), + Kind::TooManyRedirects => f.write_str("Too many redirects"), + Kind::RedirectLoop => f.write_str("Infinite redirect loop"), } } } impl StdError for Error { fn description(&self) -> &str { - match *self { - Error::Http(ref e) => e.description(), - Error::Serialize(ref e) => e.description(), - Error::TooManyRedirects => "Too many redirects", - Error::RedirectLoop => "Infinite redirect loop", - Error::__DontMatchMe => unreachable!() + match self.kind { + Kind::Http(ref e) => e.description(), + Kind::UrlEncoded(ref e) => e.description(), + Kind::Json(ref e) => e.description(), + Kind::TooManyRedirects => "Too many redirects", + Kind::RedirectLoop => "Infinite redirect loop", } } fn cause(&self) -> Option<&StdError> { - match *self { - Error::Http(ref e) => Some(e), - Error::Serialize(ref e) => Some(&**e), - Error::TooManyRedirects | - Error::RedirectLoop => None, - Error::__DontMatchMe => unreachable!() + match self.kind { + Kind::Http(ref e) => Some(e), + Kind::UrlEncoded(ref e) => Some(e), + Kind::Json(ref e) => Some(e), + Kind::TooManyRedirects | + Kind::RedirectLoop => None, } } } -fn _assert_types() { - fn _assert_send() { +// pub(crate) + +#[derive(Debug)] +pub enum Kind { + Http(::hyper::Error), + UrlEncoded(::serde_urlencoded::ser::Error), + Json(::serde_json::Error), + TooManyRedirects, + RedirectLoop, +} + + +impl From<::hyper::Error> for Kind { + #[inline] + fn from(err: ::hyper::Error) -> Kind { + Kind::Http(err) + } +} + +impl From<::url::ParseError> for Kind { + #[inline] + fn from(err: ::url::ParseError) -> Kind { + Kind::Http(::hyper::Error::Uri(err)) } - _assert_send::(); } -impl From<::hyper::Error> for Error { - fn from(err: ::hyper::Error) -> Error { - Error::Http(err) +impl From<::serde_urlencoded::ser::Error> for Kind { + #[inline] + fn from(err: ::serde_urlencoded::ser::Error) -> Kind { + Kind::UrlEncoded(err) } } -impl From<::url::ParseError> for Error { - fn from(err: ::url::ParseError) -> Error { - Error::Http(::hyper::Error::Uri(err)) +impl From<::serde_json::Error> for Kind { + #[inline] + fn from(err: ::serde_json::Error) -> Kind { + Kind::Json(err) } } -impl From<::serde_urlencoded::ser::Error> for Error { - fn from(err: ::serde_urlencoded::ser::Error) -> Error { - Error::Serialize(Box::new(err)) +pub struct InternalFrom(pub T, pub Option); + +impl From> for Error { + #[inline] + fn from(other: InternalFrom) -> Error { + other.0 } } -impl From<::serde_json::Error> for Error { - fn from(err: ::serde_json::Error) -> Error { - Error::Serialize(Box::new(err)) +impl From> for Error +where T: Into { + #[inline] + fn from(other: InternalFrom) -> Error { + Error { + kind: other.0.into(), + url: other.1, + } } } -/// A `Result` alias where the `Err` case is `reqwest::Error`. -pub type Result = ::std::result::Result; +#[inline] +pub fn from(err: T) -> Error +where T: Into { + InternalFrom(err, None).into() +} + +#[inline] +pub fn loop_detected(url: Url) -> Error { + Error { + kind: Kind::RedirectLoop, + url: Some(url), + } +} + +#[inline] +pub fn too_many_redirects(url: Url) -> Error { + Error { + kind: Kind::TooManyRedirects, + url: Some(url), + } +} + +#[macro_export] +macro_rules! try_ { + ($e:expr) => ( + match $e { + Ok(v) => v, + Err(err) => { + return Err(::Error::from(::error::InternalFrom(err, None))); + } + } + ); + ($e:expr, $url:expr) => ( + match $e { + Ok(v) => v, + Err(err) => { + return Err(::Error::from(::error::InternalFrom(err, Some($url.clone())))); + } + } + ) +} diff --git a/src/lib.rs b/src/lib.rs index c6a13da8d..7ecbe99de 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -124,9 +124,9 @@ pub use self::body::Body; pub use self::redirect::RedirectPolicy; pub use self::response::Response; +#[macro_use] mod error; mod body; mod client; -mod error; mod redirect; mod response; @@ -161,4 +161,8 @@ fn _assert_impls() { assert_send::(); assert_send::(); + + + assert_send::(); + assert_sync::(); } diff --git a/src/redirect.rs b/src/redirect.rs index d714ecd8c..919d0dcc0 100644 --- a/src/redirect.rs +++ b/src/redirect.rs @@ -11,10 +11,22 @@ pub struct RedirectPolicy { inner: Policy, } +#[derive(Debug)] +pub struct RedirectAttempt<'a> { + next: &'a Url, + previous: &'a [Url], +} + +/// An action to perform when a redirect status code is found. +#[derive(Debug)] +pub struct RedirectAction { + inner: Action, +} + impl RedirectPolicy { /// Create a RedirectPolicy with a maximum number of redirects. /// - /// A `Error::TooManyRedirects` will be returned if the max is reached. + /// An `Error` will be returned if the max is reached. pub fn limited(max: usize) -> RedirectPolicy { RedirectPolicy { inner: Policy::Limit(max), @@ -36,45 +48,48 @@ impl RedirectPolicy { /// chain, but the custom variant does not do that for you automatically. /// The custom policy should have some way of handling those. /// - /// There are variants on `::Error` for both cases that can be used as - /// return values. + /// Information on the next request and previous requests can be found + /// on the `RedirectAttempt` argument passed to the closure. + /// + /// Actions can be conveniently created from methods on the + /// `RedirectAttempt`. /// /// # Example /// /// ```no_run /// # use reqwest::RedirectPolicy; /// # let mut client = reqwest::Client::new().unwrap(); - /// client.redirect(RedirectPolicy::custom(|next, previous| { - /// if previous.len() > 5 { - /// Err(reqwest::Error::TooManyRedirects) - /// } else if next.host_str() == Some("example.domain") { + /// client.redirect(RedirectPolicy::custom(|attempt| { + /// if attempt.previous().len() > 5 { + /// attempt.too_many_redirects() + /// } else if attempt.url().host_str() == Some("example.domain") { /// // prevent redirects to 'example.domain' - /// Ok(false) + /// attempt.stop() /// } else { - /// Ok(true) + /// attempt.follow() /// } /// })); /// ``` pub fn custom(policy: T) -> RedirectPolicy - where T: Fn(&Url, &[Url]) -> ::Result + Send + Sync + 'static { + where T: Fn(RedirectAttempt) -> RedirectAction + Send + Sync + 'static { RedirectPolicy { inner: Policy::Custom(Box::new(policy)), } } - fn redirect(&self, next: &Url, previous: &[Url]) -> ::Result { + fn redirect(&self, attempt: RedirectAttempt) -> RedirectAction { match self.inner { - Policy::Custom(ref custom) => custom(next, previous), + Policy::Custom(ref custom) => custom(attempt), Policy::Limit(max) => { - if previous.len() == max { - Err(::Error::TooManyRedirects) - } else if previous.contains(next) { - Err(::Error::RedirectLoop) + if attempt.previous.len() == max { + attempt.too_many_redirects() + } else if attempt.previous.contains(attempt.next) { + attempt.loop_detected() } else { - Ok(true) + attempt.follow() } }, - Policy::None => Ok(false), + Policy::None => attempt.stop(), } } } @@ -85,8 +100,53 @@ impl Default for RedirectPolicy { } } +impl<'a> RedirectAttempt<'a> { + /// Get the next URL to redirect to. + pub fn url(&self) -> &Url { + self.next + } + + /// Get the list of previous URLs that have already been requested in this chain. + pub fn previous(&self) -> &[Url] { + self.previous + } + /// Returns an action meaning reqwest should follow the next URL. + pub fn follow(self) -> RedirectAction { + RedirectAction { + inner: Action::Follow, + } + } + + /// Returns an action meaning reqwest should not follow the next URL. + /// + /// The 30x response will be returned as the `Ok` result. + pub fn stop(self) -> RedirectAction { + RedirectAction { + inner: Action::Stop, + } + } + + /// Returns an action meaning there was a loop of redirects found. + /// + /// An `Error` will be returned for the result of the sent request. + pub fn loop_detected(self) -> RedirectAction { + RedirectAction { + inner: Action::LoopDetected, + } + } + + /// Returns an action meaning there was a loop of redirects found. + /// + /// An `Error` will be returned for the result of the sent request. + pub fn too_many_redirects(self) -> RedirectAction { + RedirectAction { + inner: Action::TooManyRedirects, + } + } +} + enum Policy { - Custom(Box ::Result + Send + Sync + 'static>), + Custom(Box RedirectAction + Send + Sync + 'static>), Limit(usize), None, } @@ -101,8 +161,22 @@ impl fmt::Debug for Policy { } } -pub fn check_redirect(policy: &RedirectPolicy, next: &Url, previous: &[Url]) -> ::Result { - policy.redirect(next, previous) +// pub(crate) + +#[derive(Debug, PartialEq)] +pub enum Action { + Follow, + Stop, + LoopDetected, + TooManyRedirects, +} + +#[inline] +pub fn check_redirect(policy: &RedirectPolicy, next: &Url, previous: &[Url]) -> Action { + policy.redirect(RedirectAttempt { + next: next, + previous: previous, + }).inner } /* @@ -132,32 +206,26 @@ fn test_redirect_policy_limit() { .collect::>(); - match policy.redirect(&next, &previous) { - Ok(true) => {}, - other => panic!("expected Ok(true), got: {:?}", other) - } + assert_eq!(check_redirect(&policy, &next, &previous), Action::Follow); previous.push(Url::parse("http://a.b.d/e/33").unwrap()); - match policy.redirect(&next, &previous) { - Err(::Error::TooManyRedirects) => {}, - other => panic!("expected TooManyRedirects, got: {:?}", other) - } + assert_eq!(check_redirect(&policy, &next, &previous), Action::TooManyRedirects); } #[test] fn test_redirect_policy_custom() { - let policy = RedirectPolicy::custom(|next, _previous| { - if next.host_str() == Some("foo") { - Ok(false) + let policy = RedirectPolicy::custom(|attempt| { + if attempt.url().host_str() == Some("foo") { + attempt.stop() } else { - Ok(true) + attempt.follow() } }); let next = Url::parse("http://bar/baz").unwrap(); - assert_eq!(policy.redirect(&next, &[]).unwrap(), true); + assert_eq!(check_redirect(&policy, &next, &[]), Action::Follow); let next = Url::parse("http://foo/baz").unwrap(); - assert_eq!(policy.redirect(&next, &[]).unwrap(), false); + assert_eq!(check_redirect(&policy, &next, &[]), Action::Stop); } diff --git a/src/response.rs b/src/response.rs index a488c52bf..465fe564c 100644 --- a/src/response.rs +++ b/src/response.rs @@ -89,7 +89,7 @@ impl Response { /// Try and deserialize the response body as JSON. #[inline] pub fn json(&mut self) -> ::Result { - serde_json::from_reader(self).map_err(::Error::from) + serde_json::from_reader(self).map_err(::error::from) } } diff --git a/tests/client.rs b/tests/client.rs index a36a3ed83..08e34556a 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -207,10 +207,7 @@ fn test_redirect_policy_can_return_errors() { }; let err = reqwest::get(&format!("http://{}/loop", server.addr())).unwrap_err(); - match err { - reqwest::Error::RedirectLoop => (), - e => panic!("wrong error received: {:?}", e), - } + assert!(err.is_redirect()); } #[test]