From 46df062f9f2d740f7093a988a51e298ef903fbc0 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 22 Aug 2022 17:02:26 +0200 Subject: [PATCH 1/3] Add `RequestExt` and `RequestPartsExt` --- axum/CHANGELOG.md | 2 + axum/src/ext_traits/mod.rs | 30 +++++ axum/src/ext_traits/request.rs | 191 +++++++++++++++++++++++++++ axum/src/ext_traits/request_parts.rs | 98 ++++++++++++++ axum/src/lib.rs | 3 + 5 files changed, 324 insertions(+) create mode 100644 axum/src/ext_traits/mod.rs create mode 100644 axum/src/ext_traits/request.rs create mode 100644 axum/src/ext_traits/request_parts.rs diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index b1ce5a204d..82900aad2d 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -342,6 +342,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `MethodRouter`, defaults to `()` - `FromRequest`, no default - `Handler`, no default +- **added:** Add `RequestExt` and `RequestPartsExt` which adds convenience + methods for running extractors to `http::Request` and `http::request::Parts` ## Middleware diff --git a/axum/src/ext_traits/mod.rs b/axum/src/ext_traits/mod.rs new file mode 100644 index 0000000000..131e99b159 --- /dev/null +++ b/axum/src/ext_traits/mod.rs @@ -0,0 +1,30 @@ +pub(crate) mod request; +pub(crate) mod request_parts; + +#[cfg(test)] +mod tests { + use std::convert::Infallible; + + use async_trait::async_trait; + use axum_core::extract::{FromRef, FromRequestParts}; + use http::request::Parts; + + // some extractor that requires the state, such as `SignedCookieJar` + pub(crate) struct RequiresState(pub(crate) String); + + #[async_trait] + impl FromRequestParts for RequiresState + where + S: Send + Sync, + String: FromRef, + { + type Rejection = Infallible; + + async fn from_request_parts( + _parts: &mut Parts, + state: &S, + ) -> Result { + Ok(Self(String::from_ref(state))) + } + } +} diff --git a/axum/src/ext_traits/request.rs b/axum/src/ext_traits/request.rs new file mode 100644 index 0000000000..9ec79c79f3 --- /dev/null +++ b/axum/src/ext_traits/request.rs @@ -0,0 +1,191 @@ +use async_trait::async_trait; +use axum_core::extract::{FromRequest, FromRequestParts}; +use http::Request; + +mod sealed { + pub trait Sealed {} + impl Sealed for http::Request {} +} + +/// Extension trait that adds additional methods to [`Request`]. +#[async_trait] +pub trait RequestExt: sealed::Sealed + Sized { + /// Apply an extractor to this `Request`. + /// + /// This is just a convenience for `E::from_request(req, &())`. + /// + /// Note this consumes the request. Use [`RequestExt::extract_parts`] if you're not extracting + /// the body and don't want to consume the request. + async fn extract(self) -> Result + where + E: FromRequest<(), B, M>; + + /// Apply an extractor that requires some state to this `Request`. + /// + /// This is just a convenience for `E::from_request(req, state)`. + /// + /// Note this consumes the request. Use [`RequestExt::extract_parts_with_state`] if you're not + /// extracting the body and don't want to consume the request. + async fn extract_with_state(self, state: &S) -> Result + where + E: FromRequest, + S: Send + Sync; + + /// Apply a parts extractor to this `Request`. + /// + /// This is just a convenience for `E::from_request_parts(parts, state)`. + async fn extract_parts(&mut self) -> Result + where + E: FromRequestParts<()>; + + /// Apply a parts extractor that requires some state to this `Request`. + /// + /// This is just a convenience for `E::from_request_parts(parts, state)`. + async fn extract_parts_with_state(&mut self, state: &S) -> Result + where + E: FromRequestParts, + S: Send + Sync; +} + +#[async_trait] +impl RequestExt for Request +where + B: Send, +{ + async fn extract(self) -> Result + where + E: FromRequest<(), B, M>, + { + self.extract_with_state(&()).await + } + + async fn extract_with_state(self, state: &S) -> Result + where + E: FromRequest, + S: Send + Sync, + { + E::from_request(self, state).await + } + + async fn extract_parts(&mut self) -> Result + where + E: FromRequestParts<()>, + { + self.extract_parts_with_state(&()).await + } + + async fn extract_parts_with_state(&mut self, state: &S) -> Result + where + E: FromRequestParts, + S: Send + Sync, + { + let mut req = Request::new(()); + *req.version_mut() = self.version(); + *req.method_mut() = self.method().clone(); + *req.uri_mut() = self.uri().clone(); + *req.headers_mut() = std::mem::take(self.headers_mut()); + *req.extensions_mut() = std::mem::take(self.extensions_mut()); + let (mut parts, _) = req.into_parts(); + + let result = E::from_request_parts(&mut parts, state).await; + + *self.version_mut() = parts.version; + *self.method_mut() = parts.method.clone(); + *self.uri_mut() = parts.uri.clone(); + *self.headers_mut() = std::mem::take(&mut parts.headers); + *self.extensions_mut() = std::mem::take(&mut parts.extensions); + + result + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ext_traits::tests::RequiresState, extract::State}; + use axum_core::extract::FromRef; + use http::Method; + use hyper::Body; + + #[tokio::test] + async fn extract_without_state() { + let req = Request::new(()); + + let method: Method = req.extract().await.unwrap(); + + assert_eq!(method, Method::GET); + } + + #[tokio::test] + async fn extract_body_without_state() { + let req = Request::new(Body::from("foobar")); + + let body: String = req.extract().await.unwrap(); + + assert_eq!(body, "foobar"); + } + + #[tokio::test] + async fn extract_with_state() { + let req = Request::new(()); + + let state = "state".to_owned(); + + let State(extracted_state): State = req.extract_with_state(&state).await.unwrap(); + + assert_eq!(extracted_state, state); + } + + #[tokio::test] + async fn extract_parts_without_state() { + let mut req = Request::builder().header("x-foo", "foo").body(()).unwrap(); + + let method: Method = req.extract_parts().await.unwrap(); + + assert_eq!(method, Method::GET); + assert_eq!(req.headers()["x-foo"], "foo"); + } + + #[tokio::test] + async fn extract_parts_with_state() { + let mut req = Request::builder().header("x-foo", "foo").body(()).unwrap(); + + let state = "state".to_owned(); + + let State(extracted_state): State = + req.extract_parts_with_state(&state).await.unwrap(); + + assert_eq!(extracted_state, state); + assert_eq!(req.headers()["x-foo"], "foo"); + } + + // this stuff just needs to compile + #[allow(dead_code)] + struct WorksForCustomExtractor { + method: Method, + from_state: String, + body: String, + } + + #[async_trait] + impl FromRequest for WorksForCustomExtractor + where + S: Send + Sync, + B: Send + 'static, + String: FromRef + FromRequest<(), B>, + { + type Rejection = >::Rejection; + + async fn from_request(mut req: Request, state: &S) -> Result { + let RequiresState(from_state) = req.extract_parts_with_state(state).await.unwrap(); + let method = req.extract_parts().await.unwrap(); + let body = req.extract().await?; + + Ok(Self { + method, + from_state, + body, + }) + } + } +} diff --git a/axum/src/ext_traits/request_parts.rs b/axum/src/ext_traits/request_parts.rs new file mode 100644 index 0000000000..c243b6e0e3 --- /dev/null +++ b/axum/src/ext_traits/request_parts.rs @@ -0,0 +1,98 @@ +use async_trait::async_trait; +use axum_core::extract::FromRequestParts; +use http::request::Parts; + +mod sealed { + pub trait Sealed {} + impl Sealed for http::request::Parts {} +} + +/// Extension trait that adds additional methods to [`Parts`]. +#[async_trait] +pub trait RequestPartsExt: sealed::Sealed + Sized { + /// Apply an extractor to this `Parts`. + /// + /// This is just a convenience for `E::from_request_parts(parts, &())`. + async fn extract(&mut self) -> Result + where + E: FromRequestParts<()>; + + /// Apply an extractor that requires some state to this `Parts`. + /// + /// This is just a convenience for `E::from_request_parts(parts, state)`. + async fn extract_with_state(&mut self, state: &S) -> Result + where + E: FromRequestParts, + S: Send + Sync; +} + +#[async_trait] +impl RequestPartsExt for Parts { + async fn extract(&mut self) -> Result + where + E: FromRequestParts<()>, + { + self.extract_with_state(&()).await + } + + async fn extract_with_state(&mut self, state: &S) -> Result + where + E: FromRequestParts, + S: Send + Sync, + { + E::from_request_parts(self, state).await + } +} + +#[cfg(test)] +mod tests { + use std::convert::Infallible; + + use super::*; + use crate::{ext_traits::tests::RequiresState, extract::State}; + use axum_core::extract::FromRef; + use http::{Method, Request}; + + #[tokio::test] + async fn extract_without_state() { + let (mut parts, _) = Request::new(()).into_parts(); + + let method: Method = parts.extract().await.unwrap(); + + assert_eq!(method, Method::GET); + } + + #[tokio::test] + async fn extract_with_state() { + let (mut parts, _) = Request::new(()).into_parts(); + + let state = "state".to_owned(); + + let State(extracted_state): State = parts.extract_with_state(&state).await.unwrap(); + + assert_eq!(extracted_state, state); + } + + // this stuff just needs to compile + #[allow(dead_code)] + struct WorksForCustomExtractor { + method: Method, + from_state: String, + } + + #[async_trait] + impl FromRequestParts for WorksForCustomExtractor + where + S: Send + Sync, + String: FromRef, + { + type Rejection = Infallible; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let RequiresState(from_state) = parts.extract_with_state(state).await?; + let method = parts.extract().await?; + + Ok(Self { method, from_state }) + } + } +} diff --git a/axum/src/lib.rs b/axum/src/lib.rs index df51086828..c74e80bb26 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -433,6 +433,7 @@ #[macro_use] pub(crate) mod macros; +mod ext_traits; mod extension; #[cfg(feature = "form")] mod form; @@ -484,3 +485,5 @@ pub use axum_core::{BoxError, Error}; #[cfg(feature = "macros")] pub use axum_macros::debug_handler; + +pub use self::ext_traits::{request::RequestExt, request_parts::RequestPartsExt}; From c569750b6f17efe1b07cd05208749a8f1b57244e Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 22 Aug 2022 17:12:57 +0200 Subject: [PATCH 2/3] don't double box futures --- axum/src/ext_traits/request.rs | 69 ++++++++++++++++------------ axum/src/ext_traits/request_parts.rs | 31 +++++++------ 2 files changed, 57 insertions(+), 43 deletions(-) diff --git a/axum/src/ext_traits/request.rs b/axum/src/ext_traits/request.rs index 9ec79c79f3..f4fad5b8be 100644 --- a/axum/src/ext_traits/request.rs +++ b/axum/src/ext_traits/request.rs @@ -1,5 +1,5 @@ -use async_trait::async_trait; use axum_core::extract::{FromRequest, FromRequestParts}; +use futures_util::future::BoxFuture; use http::Request; mod sealed { @@ -8,7 +8,6 @@ mod sealed { } /// Extension trait that adds additional methods to [`Request`]. -#[async_trait] pub trait RequestExt: sealed::Sealed + Sized { /// Apply an extractor to this `Request`. /// @@ -16,9 +15,10 @@ pub trait RequestExt: sealed::Sealed + Sized { /// /// Note this consumes the request. Use [`RequestExt::extract_parts`] if you're not extracting /// the body and don't want to consume the request. - async fn extract(self) -> Result + fn extract(self) -> BoxFuture<'static, Result> where - E: FromRequest<(), B, M>; + E: FromRequest<(), B, M> + 'static, + M: 'static; /// Apply an extractor that requires some state to this `Request`. /// @@ -26,57 +26,63 @@ pub trait RequestExt: sealed::Sealed + Sized { /// /// Note this consumes the request. Use [`RequestExt::extract_parts_with_state`] if you're not /// extracting the body and don't want to consume the request. - async fn extract_with_state(self, state: &S) -> Result + fn extract_with_state(self, state: &S) -> BoxFuture<'_, Result> where - E: FromRequest, + E: FromRequest + 'static, S: Send + Sync; /// Apply a parts extractor to this `Request`. /// /// This is just a convenience for `E::from_request_parts(parts, state)`. - async fn extract_parts(&mut self) -> Result + fn extract_parts(&mut self) -> BoxFuture<'_, Result> where - E: FromRequestParts<()>; + E: FromRequestParts<()> + 'static; /// Apply a parts extractor that requires some state to this `Request`. /// /// This is just a convenience for `E::from_request_parts(parts, state)`. - async fn extract_parts_with_state(&mut self, state: &S) -> Result + fn extract_parts_with_state<'a, E, S>( + &'a mut self, + state: &'a S, + ) -> BoxFuture<'a, Result> where - E: FromRequestParts, + E: FromRequestParts + 'static, S: Send + Sync; } -#[async_trait] impl RequestExt for Request where - B: Send, + B: Send + 'static, { - async fn extract(self) -> Result + fn extract(self) -> BoxFuture<'static, Result> where - E: FromRequest<(), B, M>, + E: FromRequest<(), B, M> + 'static, + M: 'static, { - self.extract_with_state(&()).await + self.extract_with_state(&()) } - async fn extract_with_state(self, state: &S) -> Result + fn extract_with_state(self, state: &S) -> BoxFuture<'_, Result> where - E: FromRequest, + E: FromRequest + 'static, S: Send + Sync, { - E::from_request(self, state).await + E::from_request(self, state) } - async fn extract_parts(&mut self) -> Result + fn extract_parts(&mut self) -> BoxFuture<'_, Result> where - E: FromRequestParts<()>, + E: FromRequestParts<()> + 'static, { - self.extract_parts_with_state(&()).await + self.extract_parts_with_state(&()) } - async fn extract_parts_with_state(&mut self, state: &S) -> Result + fn extract_parts_with_state<'a, E, S>( + &'a mut self, + state: &'a S, + ) -> BoxFuture<'a, Result> where - E: FromRequestParts, + E: FromRequestParts + 'static, S: Send + Sync, { let mut req = Request::new(()); @@ -87,15 +93,17 @@ where *req.extensions_mut() = std::mem::take(self.extensions_mut()); let (mut parts, _) = req.into_parts(); - let result = E::from_request_parts(&mut parts, state).await; + Box::pin(async move { + let result = E::from_request_parts(&mut parts, state).await; - *self.version_mut() = parts.version; - *self.method_mut() = parts.method.clone(); - *self.uri_mut() = parts.uri.clone(); - *self.headers_mut() = std::mem::take(&mut parts.headers); - *self.extensions_mut() = std::mem::take(&mut parts.extensions); + *self.version_mut() = parts.version; + *self.method_mut() = parts.method.clone(); + *self.uri_mut() = parts.uri.clone(); + *self.headers_mut() = std::mem::take(&mut parts.headers); + *self.extensions_mut() = std::mem::take(&mut parts.extensions); - result + result + }) } } @@ -103,6 +111,7 @@ where mod tests { use super::*; use crate::{ext_traits::tests::RequiresState, extract::State}; + use async_trait::async_trait; use axum_core::extract::FromRef; use http::Method; use hyper::Body; diff --git a/axum/src/ext_traits/request_parts.rs b/axum/src/ext_traits/request_parts.rs index c243b6e0e3..c35f7c7445 100644 --- a/axum/src/ext_traits/request_parts.rs +++ b/axum/src/ext_traits/request_parts.rs @@ -1,5 +1,5 @@ -use async_trait::async_trait; use axum_core::extract::FromRequestParts; +use futures_util::future::BoxFuture; use http::request::Parts; mod sealed { @@ -8,39 +8,43 @@ mod sealed { } /// Extension trait that adds additional methods to [`Parts`]. -#[async_trait] pub trait RequestPartsExt: sealed::Sealed + Sized { /// Apply an extractor to this `Parts`. /// /// This is just a convenience for `E::from_request_parts(parts, &())`. - async fn extract(&mut self) -> Result + fn extract(&mut self) -> BoxFuture<'_, Result> where - E: FromRequestParts<()>; + E: FromRequestParts<()> + 'static; /// Apply an extractor that requires some state to this `Parts`. /// /// This is just a convenience for `E::from_request_parts(parts, state)`. - async fn extract_with_state(&mut self, state: &S) -> Result + fn extract_with_state<'a, E, S>( + &'a mut self, + state: &'a S, + ) -> BoxFuture<'a, Result> where - E: FromRequestParts, + E: FromRequestParts + 'static, S: Send + Sync; } -#[async_trait] impl RequestPartsExt for Parts { - async fn extract(&mut self) -> Result + fn extract(&mut self) -> BoxFuture<'_, Result> where - E: FromRequestParts<()>, + E: FromRequestParts<()> + 'static, { - self.extract_with_state(&()).await + self.extract_with_state(&()) } - async fn extract_with_state(&mut self, state: &S) -> Result + fn extract_with_state<'a, E, S>( + &'a mut self, + state: &'a S, + ) -> BoxFuture<'a, Result> where - E: FromRequestParts, + E: FromRequestParts + 'static, S: Send + Sync, { - E::from_request_parts(self, state).await + E::from_request_parts(self, state) } } @@ -50,6 +54,7 @@ mod tests { use super::*; use crate::{ext_traits::tests::RequiresState, extract::State}; + use async_trait::async_trait; use axum_core::extract::FromRef; use http::{Method, Request}; From c63d93e3859171f6f13bb32d24b3eadd3b4f96a2 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 22 Aug 2022 17:16:07 +0200 Subject: [PATCH 3/3] changelog pr link --- axum/CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 82900aad2d..06b3f37b1a 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -343,7 +343,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `FromRequest`, no default - `Handler`, no default - **added:** Add `RequestExt` and `RequestPartsExt` which adds convenience - methods for running extractors to `http::Request` and `http::request::Parts` + methods for running extractors to `http::Request` and `http::request::Parts` ([#1301]) ## Middleware @@ -374,6 +374,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1239]: https://github.com/tokio-rs/axum/pull/1239 [#1248]: https://github.com/tokio-rs/axum/pull/1248 [#1272]: https://github.com/tokio-rs/axum/pull/1272 +[#1301]: https://github.com/tokio-rs/axum/pull/1301 [#924]: https://github.com/tokio-rs/axum/pull/924 # 0.5.15 (9. August, 2022)