diff --git a/src/lib.rs b/src/lib.rs index bd9bd60e..d9d7014c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,11 @@ extern crate tower_add_origin; -pub use tower_add_origin::AddOrigin; +pub mod add_origin { + pub use ::tower_add_origin::{ + AddOrigin, + Builder, + BuilderError, + }; +} + +pub use add_origin::AddOrigin; diff --git a/tower-add-origin/src/lib.rs b/tower-add-origin/src/lib.rs index 8206d8a3..bc98d5f6 100644 --- a/tower-add-origin/src/lib.rs +++ b/tower-add-origin/src/lib.rs @@ -3,17 +3,32 @@ extern crate http; extern crate tower; use futures::Poll; -use http::Request; -use http::uri::{Authority, Scheme}; +use http::{Request, HttpTryFrom}; +use http::uri::{self, Authority, Scheme, Uri}; use tower::Service; /// Wraps an HTTP service, injecting authority and scheme on every request. +#[derive(Debug)] pub struct AddOrigin { inner: T, scheme: Scheme, authority: Authority, } +/// Configure an `AddOrigin` instance +#[derive(Debug, Default)] +pub struct Builder { + uri: Option, +} + +/// Errors that can happen when building an `AddOrigin`. +#[derive(Debug)] +pub struct BuilderError { + _p: (), +} + +// ===== impl AddOrigin ====== + impl AddOrigin { /// Create a new `AddOrigin` pub fn new(inner: T, scheme: Scheme, authority: Authority) -> Self { @@ -83,3 +98,57 @@ where T: Service>, self.inner.call(request) } } + +// ===== impl Builder ====== + +impl Builder { + /// Return a new, default builder + pub fn new() -> Self { + Builder::default() + } + + /// Set the URI to use as the origin for all requests. + pub fn uri(&mut self, uri: T) -> &mut Self + where Uri: HttpTryFrom, + { + self.uri = Uri::try_from(uri) + .map(Some) + .unwrap_or(None); + + self + } + + pub fn build(&mut self, inner: T) -> Result, BuilderError> { + // Create the error just in case. It is a zero sized type anyway right + // now. + let err = BuilderError { _p: () }; + + let uri = match self.uri.take() { + Some(uri) => uri, + None => return Err(err), + }; + + let parts = uri::Parts::from(uri); + + // Get the scheme + let scheme = match parts.scheme { + Some(scheme) => scheme, + None => return Err(err), + }; + + // Get the authority + let authority = match parts.authority { + Some(authority) => authority, + None => return Err(err), + }; + + // Ensure that the path is unsued + match parts.path_and_query { + None => {} + Some(ref path) if path == "/" => {} + _ => return Err(err), + } + + Ok(AddOrigin::new(inner, scheme, authority)) + } +} diff --git a/tower-add-origin/tests/add_origin.rs b/tower-add-origin/tests/add_origin.rs index 971f54c8..f8439e0e 100644 --- a/tower-add-origin/tests/add_origin.rs +++ b/tower-add-origin/tests/add_origin.rs @@ -38,3 +38,19 @@ fn adds_origin_to_requests() { send_response.respond(response); } + +#[test] +fn does_not_build_with_relative_uri() { + let _ = Builder::new() + .uri("/") + .build(()) + .unwrap_err(); +} + +#[test] +fn does_not_build_with_path() { + let _ = Builder::new() + .uri("http://www.example.com/foo") + .build(()) + .unwrap_err(); +}