diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index b1ce5a204d..06b3f37b1a 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -342,6 +342,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `MethodRouter`, defaults to `()` - `FromRequest`, no default - `Handler`, no default +- **added:** Add `RequestExt` and `RequestPartsExt` which adds convenience + methods for running extractors to `http::Request` and `http::request::Parts` ([#1301]) ## Middleware @@ -372,6 +374,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1239]: https://github.com/tokio-rs/axum/pull/1239 [#1248]: https://github.com/tokio-rs/axum/pull/1248 [#1272]: https://github.com/tokio-rs/axum/pull/1272 +[#1301]: https://github.com/tokio-rs/axum/pull/1301 [#924]: https://github.com/tokio-rs/axum/pull/924 # 0.5.15 (9. August, 2022) diff --git a/axum/src/ext_traits/mod.rs b/axum/src/ext_traits/mod.rs new file mode 100644 index 0000000000..131e99b159 --- /dev/null +++ b/axum/src/ext_traits/mod.rs @@ -0,0 +1,30 @@ +pub(crate) mod request; +pub(crate) mod request_parts; + +#[cfg(test)] +mod tests { + use std::convert::Infallible; + + use async_trait::async_trait; + use axum_core::extract::{FromRef, FromRequestParts}; + use http::request::Parts; + + // some extractor that requires the state, such as `SignedCookieJar` + pub(crate) struct RequiresState(pub(crate) String); + + #[async_trait] + impl FromRequestParts for RequiresState + where + S: Send + Sync, + String: FromRef, + { + type Rejection = Infallible; + + async fn from_request_parts( + _parts: &mut Parts, + state: &S, + ) -> Result { + Ok(Self(String::from_ref(state))) + } + } +} diff --git a/axum/src/ext_traits/request.rs b/axum/src/ext_traits/request.rs new file mode 100644 index 0000000000..f4fad5b8be --- /dev/null +++ b/axum/src/ext_traits/request.rs @@ -0,0 +1,200 @@ +use axum_core::extract::{FromRequest, FromRequestParts}; +use futures_util::future::BoxFuture; +use http::Request; + +mod sealed { + pub trait Sealed {} + impl Sealed for http::Request {} +} + +/// Extension trait that adds additional methods to [`Request`]. +pub trait RequestExt: sealed::Sealed + Sized { + /// Apply an extractor to this `Request`. + /// + /// This is just a convenience for `E::from_request(req, &())`. + /// + /// Note this consumes the request. Use [`RequestExt::extract_parts`] if you're not extracting + /// the body and don't want to consume the request. + fn extract(self) -> BoxFuture<'static, Result> + where + E: FromRequest<(), B, M> + 'static, + M: 'static; + + /// Apply an extractor that requires some state to this `Request`. + /// + /// This is just a convenience for `E::from_request(req, state)`. + /// + /// Note this consumes the request. Use [`RequestExt::extract_parts_with_state`] if you're not + /// extracting the body and don't want to consume the request. + fn extract_with_state(self, state: &S) -> BoxFuture<'_, Result> + where + E: FromRequest + 'static, + S: Send + Sync; + + /// Apply a parts extractor to this `Request`. + /// + /// This is just a convenience for `E::from_request_parts(parts, state)`. + fn extract_parts(&mut self) -> BoxFuture<'_, Result> + where + E: FromRequestParts<()> + 'static; + + /// Apply a parts extractor that requires some state to this `Request`. + /// + /// This is just a convenience for `E::from_request_parts(parts, state)`. + fn extract_parts_with_state<'a, E, S>( + &'a mut self, + state: &'a S, + ) -> BoxFuture<'a, Result> + where + E: FromRequestParts + 'static, + S: Send + Sync; +} + +impl RequestExt for Request +where + B: Send + 'static, +{ + fn extract(self) -> BoxFuture<'static, Result> + where + E: FromRequest<(), B, M> + 'static, + M: 'static, + { + self.extract_with_state(&()) + } + + fn extract_with_state(self, state: &S) -> BoxFuture<'_, Result> + where + E: FromRequest + 'static, + S: Send + Sync, + { + E::from_request(self, state) + } + + fn extract_parts(&mut self) -> BoxFuture<'_, Result> + where + E: FromRequestParts<()> + 'static, + { + self.extract_parts_with_state(&()) + } + + fn extract_parts_with_state<'a, E, S>( + &'a mut self, + state: &'a S, + ) -> BoxFuture<'a, Result> + where + E: FromRequestParts + 'static, + S: Send + Sync, + { + let mut req = Request::new(()); + *req.version_mut() = self.version(); + *req.method_mut() = self.method().clone(); + *req.uri_mut() = self.uri().clone(); + *req.headers_mut() = std::mem::take(self.headers_mut()); + *req.extensions_mut() = std::mem::take(self.extensions_mut()); + let (mut parts, _) = req.into_parts(); + + Box::pin(async move { + let result = E::from_request_parts(&mut parts, state).await; + + *self.version_mut() = parts.version; + *self.method_mut() = parts.method.clone(); + *self.uri_mut() = parts.uri.clone(); + *self.headers_mut() = std::mem::take(&mut parts.headers); + *self.extensions_mut() = std::mem::take(&mut parts.extensions); + + result + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ext_traits::tests::RequiresState, extract::State}; + use async_trait::async_trait; + use axum_core::extract::FromRef; + use http::Method; + use hyper::Body; + + #[tokio::test] + async fn extract_without_state() { + let req = Request::new(()); + + let method: Method = req.extract().await.unwrap(); + + assert_eq!(method, Method::GET); + } + + #[tokio::test] + async fn extract_body_without_state() { + let req = Request::new(Body::from("foobar")); + + let body: String = req.extract().await.unwrap(); + + assert_eq!(body, "foobar"); + } + + #[tokio::test] + async fn extract_with_state() { + let req = Request::new(()); + + let state = "state".to_owned(); + + let State(extracted_state): State = req.extract_with_state(&state).await.unwrap(); + + assert_eq!(extracted_state, state); + } + + #[tokio::test] + async fn extract_parts_without_state() { + let mut req = Request::builder().header("x-foo", "foo").body(()).unwrap(); + + let method: Method = req.extract_parts().await.unwrap(); + + assert_eq!(method, Method::GET); + assert_eq!(req.headers()["x-foo"], "foo"); + } + + #[tokio::test] + async fn extract_parts_with_state() { + let mut req = Request::builder().header("x-foo", "foo").body(()).unwrap(); + + let state = "state".to_owned(); + + let State(extracted_state): State = + req.extract_parts_with_state(&state).await.unwrap(); + + assert_eq!(extracted_state, state); + assert_eq!(req.headers()["x-foo"], "foo"); + } + + // this stuff just needs to compile + #[allow(dead_code)] + struct WorksForCustomExtractor { + method: Method, + from_state: String, + body: String, + } + + #[async_trait] + impl FromRequest for WorksForCustomExtractor + where + S: Send + Sync, + B: Send + 'static, + String: FromRef + FromRequest<(), B>, + { + type Rejection = >::Rejection; + + async fn from_request(mut req: Request, state: &S) -> Result { + let RequiresState(from_state) = req.extract_parts_with_state(state).await.unwrap(); + let method = req.extract_parts().await.unwrap(); + let body = req.extract().await?; + + Ok(Self { + method, + from_state, + body, + }) + } + } +} diff --git a/axum/src/ext_traits/request_parts.rs b/axum/src/ext_traits/request_parts.rs new file mode 100644 index 0000000000..c35f7c7445 --- /dev/null +++ b/axum/src/ext_traits/request_parts.rs @@ -0,0 +1,103 @@ +use axum_core::extract::FromRequestParts; +use futures_util::future::BoxFuture; +use http::request::Parts; + +mod sealed { + pub trait Sealed {} + impl Sealed for http::request::Parts {} +} + +/// Extension trait that adds additional methods to [`Parts`]. +pub trait RequestPartsExt: sealed::Sealed + Sized { + /// Apply an extractor to this `Parts`. + /// + /// This is just a convenience for `E::from_request_parts(parts, &())`. + fn extract(&mut self) -> BoxFuture<'_, Result> + where + E: FromRequestParts<()> + 'static; + + /// Apply an extractor that requires some state to this `Parts`. + /// + /// This is just a convenience for `E::from_request_parts(parts, state)`. + fn extract_with_state<'a, E, S>( + &'a mut self, + state: &'a S, + ) -> BoxFuture<'a, Result> + where + E: FromRequestParts + 'static, + S: Send + Sync; +} + +impl RequestPartsExt for Parts { + fn extract(&mut self) -> BoxFuture<'_, Result> + where + E: FromRequestParts<()> + 'static, + { + self.extract_with_state(&()) + } + + fn extract_with_state<'a, E, S>( + &'a mut self, + state: &'a S, + ) -> BoxFuture<'a, Result> + where + E: FromRequestParts + 'static, + S: Send + Sync, + { + E::from_request_parts(self, state) + } +} + +#[cfg(test)] +mod tests { + use std::convert::Infallible; + + use super::*; + use crate::{ext_traits::tests::RequiresState, extract::State}; + use async_trait::async_trait; + use axum_core::extract::FromRef; + use http::{Method, Request}; + + #[tokio::test] + async fn extract_without_state() { + let (mut parts, _) = Request::new(()).into_parts(); + + let method: Method = parts.extract().await.unwrap(); + + assert_eq!(method, Method::GET); + } + + #[tokio::test] + async fn extract_with_state() { + let (mut parts, _) = Request::new(()).into_parts(); + + let state = "state".to_owned(); + + let State(extracted_state): State = parts.extract_with_state(&state).await.unwrap(); + + assert_eq!(extracted_state, state); + } + + // this stuff just needs to compile + #[allow(dead_code)] + struct WorksForCustomExtractor { + method: Method, + from_state: String, + } + + #[async_trait] + impl FromRequestParts for WorksForCustomExtractor + where + S: Send + Sync, + String: FromRef, + { + type Rejection = Infallible; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let RequiresState(from_state) = parts.extract_with_state(state).await?; + let method = parts.extract().await?; + + Ok(Self { method, from_state }) + } + } +} diff --git a/axum/src/lib.rs b/axum/src/lib.rs index df51086828..c74e80bb26 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -433,6 +433,7 @@ #[macro_use] pub(crate) mod macros; +mod ext_traits; mod extension; #[cfg(feature = "form")] mod form; @@ -484,3 +485,5 @@ pub use axum_core::{BoxError, Error}; #[cfg(feature = "macros")] pub use axum_macros::debug_handler; + +pub use self::ext_traits::{request::RequestExt, request_parts::RequestPartsExt};