From 308d6a3116c2b90f74b95f266ed0ed885fb4b970 Mon Sep 17 00:00:00 2001 From: Yusuke Sasaki Date: Sun, 2 Sep 2018 02:56:59 +0900 Subject: [PATCH] tweak rejection APIs * add combinator: OrReject and OrRejectWith * make Fixed and reject() deprecated --- src/endpoint/fixed.rs | 8 +++- src/endpoint/mod.rs | 30 +++++++++++++- src/endpoint/or_reject.rs | 87 +++++++++++++++++++++++++++++++++++++++ src/endpoint/reject.rs | 14 ++++++- src/endpoints/header.rs | 3 +- tests/endpoint/and.rs | 27 +----------- tests/endpoint/or.rs | 6 +-- tests/endpoint/recover.rs | 2 +- tests/endpoints/header.rs | 7 ++-- 9 files changed, 144 insertions(+), 40 deletions(-) create mode 100644 src/endpoint/or_reject.rs diff --git a/src/endpoint/fixed.rs b/src/endpoint/fixed.rs index aff1017b7..d297e3ef1 100644 --- a/src/endpoint/fixed.rs +++ b/src/endpoint/fixed.rs @@ -1,3 +1,5 @@ +#![allow(deprecated)] + use std::pin::PinMut; use futures_core::future::{Future, TryFuture}; @@ -8,7 +10,11 @@ use pin_utils::unsafe_unpinned; use crate::endpoint::{Context, Endpoint, EndpointError, EndpointResult}; use crate::error::Error; -#[allow(missing_docs)] +#[doc(hidden)] +#[deprecated( + since = "0.12.0-alpha.3", + note = "This struct is going to remove before releasing 0.12.0." +)] #[derive(Debug, Copy, Clone)] pub struct Fixed { pub(super) endpoint: E, diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index 2ab6b2a4f..bd319f1eb 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -10,6 +10,7 @@ mod fixed; mod lazy; mod map; mod or; +mod or_reject; mod recover; mod reject; mod then; @@ -24,13 +25,18 @@ pub use self::error::{EndpointError, EndpointResult}; pub use self::and::And; pub use self::and_then::AndThen; pub use self::boxed::{Boxed, BoxedLocal}; +#[allow(deprecated)] +#[doc(hidden)] pub use self::fixed::Fixed; pub use self::map::Map; pub use self::or::Or; +pub use self::or_reject::{OrReject, OrRejectWith}; pub use self::recover::Recover; pub use self::then::Then; pub use self::lazy::{lazy, Lazy}; +#[allow(deprecated)] +#[doc(hidden)] pub use self::reject::{reject, Reject}; pub use self::unit::{unit, Unit}; pub use self::value::{value, Value}; @@ -173,6 +179,23 @@ pub trait EndpointExt<'a>: Endpoint<'a> + Sized { (AndThen { endpoint: self, f }).output::<(::Ok,)>() } + /// Creates an endpoint which returns the error value returned from + /// `Endpoint::apply()` as the return value from the associated `Future`. + fn or_reject(self) -> OrReject { + (OrReject { endpoint: self }).output::() + } + + /// Creates an endpoint which converts the error value returned from + /// `Endpoint::apply()` to the specified type and returns it as + /// the return value from the associated `Future`. + fn or_reject_with(self, f: F) -> OrRejectWith + where + F: Fn(EndpointError, &mut Context<'_>) -> R + 'a, + R: Into + 'a, + { + (OrRejectWith { endpoint: self, f }).output::() + } + #[allow(missing_docs)] fn recover(self, f: F) -> Recover where @@ -182,7 +205,12 @@ pub trait EndpointExt<'a>: Endpoint<'a> + Sized { (Recover { endpoint: self, f }).output::<(self::recover::Recovered,)>() } - #[allow(missing_docs)] + #[doc(hidden)] + #[deprecated( + since = "0.12.0-alpha.3", + note = "this method is going to remove before releasing 0.12.0." + )] + #[allow(deprecated)] fn fixed(self) -> Fixed { Fixed { endpoint: self } } diff --git a/src/endpoint/or_reject.rs b/src/endpoint/or_reject.rs new file mode 100644 index 000000000..6c35166e8 --- /dev/null +++ b/src/endpoint/or_reject.rs @@ -0,0 +1,87 @@ +use std::pin::PinMut; + +use futures_core::future::{Future, TryFuture}; +use futures_core::task; +use futures_core::task::Poll; +use pin_utils::unsafe_unpinned; + +use crate::endpoint::{Context, Endpoint, EndpointError, EndpointResult}; +use crate::error::Error; + +#[allow(missing_docs)] +#[derive(Debug, Copy, Clone)] +pub struct OrReject { + pub(super) endpoint: E, +} + +impl<'a, E: Endpoint<'a>> Endpoint<'a> for OrReject { + type Output = E::Output; + type Future = OrRejectFuture; + + fn apply(&'a self, ecx: &mut Context<'_>) -> EndpointResult { + match self.endpoint.apply(ecx) { + Ok(future) => Ok(OrRejectFuture { inner: Ok(future) }), + Err(err) => { + while let Some(..) = ecx.next_segment() {} + Ok(OrRejectFuture { + inner: Err(Some(err.into())), + }) + } + } + } +} + +#[derive(Debug)] +pub struct OrRejectFuture { + inner: Result>, +} + +impl OrRejectFuture { + unsafe_unpinned!(inner: Result>); +} + +impl Future for OrRejectFuture +where + F: TryFuture, +{ + type Output = Result; + + fn poll(mut self: PinMut<'_, Self>, cx: &mut task::Context<'_>) -> Poll { + match self.inner() { + Ok(ref mut f) => unsafe { PinMut::new_unchecked(f).try_poll(cx) }, + Err(ref mut err) => Poll::Ready(Err(err.take().unwrap())), + } + } +} + +// ==== OrRejectWith ==== + +#[allow(missing_docs)] +#[derive(Debug, Copy, Clone)] +pub struct OrRejectWith { + pub(super) endpoint: E, + pub(super) f: F, +} + +impl<'a, E, F, R> Endpoint<'a> for OrRejectWith +where + E: Endpoint<'a>, + F: Fn(EndpointError, &mut Context<'_>) -> R + 'a, + R: Into + 'a, +{ + type Output = E::Output; + type Future = OrRejectFuture; + + fn apply(&'a self, ecx: &mut Context<'_>) -> EndpointResult { + match self.endpoint.apply(ecx) { + Ok(future) => Ok(OrRejectFuture { inner: Ok(future) }), + Err(err) => { + while let Some(..) = ecx.next_segment() {} + let err = (self.f)(err, ecx).into(); + Ok(OrRejectFuture { + inner: Err(Some(err)), + }) + } + } + } +} diff --git a/src/endpoint/reject.rs b/src/endpoint/reject.rs index 7de8d1861..88cf82c3f 100644 --- a/src/endpoint/reject.rs +++ b/src/endpoint/reject.rs @@ -1,3 +1,5 @@ +#![allow(deprecated)] + use std::marker::PhantomData; use std::pin::PinMut; @@ -7,7 +9,11 @@ use crate::endpoint::{Context, Endpoint, EndpointExt, EndpointResult}; use crate::error::Error; use crate::input::Input; -/// Creates an endpoint which always rejects the request with the specified error. +#[doc(hidden)] +#[deprecated( + since = "0.12.0-alpha.3", + note = "This endpoint is going to remove before releasing 0.12.0." +)] pub fn reject(f: F) -> Reject where F: Fn(PinMut<'_, Input>) -> E, @@ -19,7 +25,11 @@ where }).output::<()>() } -#[allow(missing_docs)] +#[doc(hidden)] +#[deprecated( + since = "0.12.0-alpha.3", + note = "This endpoint is going to remove before releasing 0.12.0." +)] #[derive(Debug)] pub struct Reject { f: F, diff --git a/src/endpoints/header.rs b/src/endpoints/header.rs index 89adde646..d64337ab5 100644 --- a/src/endpoints/header.rs +++ b/src/endpoints/header.rs @@ -216,11 +216,10 @@ where /// ``` /// # use finchers::endpoint::EndpointExt; /// # use finchers::endpoints::header; -/// use finchers::endpoint::reject; /// use finchers::error; /// /// let endpoint = header::matches("origin", "www.example.com") -/// .or(reject(|_| error::bad_request("The value of Origin is invalid"))); +/// .or_reject_with(|_, _| error::bad_request("invalid header value")); /// # drop(endpoint); /// ``` pub fn matches(name: K, value: V) -> Matches diff --git a/tests/endpoint/and.rs b/tests/endpoint/and.rs index 26d7e06e0..b12d97c27 100644 --- a/tests/endpoint/and.rs +++ b/tests/endpoint/and.rs @@ -1,9 +1,6 @@ -use failure::format_err; -use finchers::endpoint::{reject, unit, value, EndpointExt}; -use finchers::error::bad_request; +use finchers::endpoint::{unit, value, EndpointExt}; use finchers::local; -use http::StatusCode; use matches::assert_matches; #[test] @@ -13,28 +10,6 @@ fn test_and_all_ok() { assert_matches!(local::get("/").apply(&endpoint), Ok(("Hello", "world"))); } -#[test] -fn test_and_with_err_1() { - let endpoint = value("Hello").and(reject(|_| bad_request(format_err!(""))).output::<()>()); - - assert_matches!( - local::get("/").apply(&endpoint), - Err(ref e) if e.status_code() == StatusCode::BAD_REQUEST - ); -} - -#[test] -fn test_and_with_err_2() { - let endpoint = reject(|_| bad_request(format_err!(""))) - .output::<()>() - .and(value("Hello")); - - assert_matches!( - local::get("/").apply(&endpoint), - Err(ref e) if e.status_code() == StatusCode::BAD_REQUEST - ); -} - #[test] fn test_and_flatten() { let endpoint = value("Hello") diff --git a/tests/endpoint/or.rs b/tests/endpoint/or.rs index e0ee98ec4..2cbe87b1f 100644 --- a/tests/endpoint/or.rs +++ b/tests/endpoint/or.rs @@ -1,5 +1,5 @@ use failure::format_err; -use finchers::endpoint::{reject, value, EndpointExt}; +use finchers::endpoint::{value, EndpointExt}; use finchers::endpoints::path::path; use finchers::error::bad_request; use finchers::local; @@ -28,10 +28,10 @@ fn test_or_choose_longer_segments() { } #[test] -fn test_or_with_rejection_path() { +fn test_or_with_rejection() { let endpoint = path("foo") .or(path("bar")) - .or(reject(|_| bad_request(format_err!("custom rejection")))); + .or_reject_with(|_err, _cx| bad_request(format_err!("custom rejection"))); assert_matches!(local::get("/foo").apply(&endpoint), Ok(..)); diff --git a/tests/endpoint/recover.rs b/tests/endpoint/recover.rs index 705c6d8f6..e09e524bb 100644 --- a/tests/endpoint/recover.rs +++ b/tests/endpoint/recover.rs @@ -9,7 +9,7 @@ fn test_recover() { let endpoint = method::get(path::path("posts").and(path::param::())) .map(|id: u32| format!("param={}", id)); - let recovered = endpoint.fixed().recover(|err| { + let recovered = endpoint.or_reject().recover(|err| { if err.is::() { ready(Ok(Response::builder() .status(err.status_code()) diff --git a/tests/endpoints/header.rs b/tests/endpoints/header.rs index 6ef463b09..3f05974fa 100644 --- a/tests/endpoints/header.rs +++ b/tests/endpoints/header.rs @@ -1,4 +1,4 @@ -use finchers::endpoint::{reject, EndpointExt}; +use finchers::endpoint::EndpointExt; use finchers::endpoints::header; use finchers::error; use finchers::local; @@ -71,9 +71,8 @@ fn test_header_optional() { #[test] fn test_header_matches_with_rejection() { - let endpoint = header::matches("origin", "www.example.com").or(reject(|_| { - error::bad_request("The value of Origin is invalid") - })); + let endpoint = header::matches("origin", "www.example.com") + .or_reject_with(|_, _| error::bad_request("The value of Origin is invalid")); assert_matches!( local::get("/")