diff --git a/axum-core/CHANGELOG.md b/axum-core/CHANGELOG.md index 62b1472c6b..7fe07dcfbb 100644 --- a/axum-core/CHANGELOG.md +++ b/axum-core/CHANGELOG.md @@ -7,10 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- **breaking:** `FromRequest` and `RequestParts` has a new `S` type param which - represents the state ([#1155]) +- **breaking:** `FromRequest` has been reworked and `RequestParts` has been + removed. See axum's changelog for more details ([#1272]) +- **added:** Added new `FromRequestParts` trait. See axum's changelog for more + details ([#1272]) +- **breaking:** `BodyAlreadyExtracted` has been removed ([#1272]) [#1155]: https://github.com/tokio-rs/axum/pull/1155 +[#1272]: https://github.com/tokio-rs/axum/pull/1272 # 0.2.7 (10. July, 2022) diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index 9c867321b0..a7c0412274 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -4,11 +4,10 @@ //! //! [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html -use self::rejection::*; use crate::response::IntoResponse; use async_trait::async_trait; -use http::{Extensions, HeaderMap, Method, Request, Uri, Version}; -use std::{convert::Infallible, sync::Arc}; +use http::{request::Parts, Request}; +use std::convert::Infallible; pub mod rejection; @@ -18,9 +17,44 @@ mod tuple; pub use self::from_ref::FromRef; +mod private { + #[derive(Debug, Clone, Copy)] + pub enum ViaParts {} + + #[derive(Debug, Clone, Copy)] + pub enum ViaRequest {} +} + +/// Types that can be created from request parts. +/// +/// Extractors that implement `FromRequestParts` cannot consume the request body and can thus be +/// run in any order for handlers. +/// +/// If your extractor needs to consume the request body then you should implement [`FromRequest`] +/// and not [`FromRequestParts`]. +/// +/// See [`axum::extract`] for more general docs about extraxtors. +/// +/// [`axum::extract`]: https://docs.rs/axum/0.6/axum/extract/index.html +#[async_trait] +pub trait FromRequestParts: Sized { + /// If the extractor fails it'll use this "rejection" type. A rejection is + /// a kind of error that can be converted into a response. + type Rejection: IntoResponse; + + /// Perform the extraction. + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result; +} + /// Types that can be created from requests. /// -/// See [`axum::extract`] for more details. +/// Extractors that implement `FromRequest` can consume the request body and can thus only be run +/// once for handlers. +/// +/// If your extractor doesn't need to consume the request body then you should implement +/// [`FromRequestParts`] and not [`FromRequest`]. +/// +/// See [`axum::extract`] for more general docs about extraxtors. /// /// # What is the `B` type parameter? /// @@ -39,7 +73,8 @@ pub use self::from_ref::FromRef; /// ```rust /// use axum::{ /// async_trait, -/// extract::{FromRequest, RequestParts}, +/// extract::FromRequest, +/// http::Request, /// }; /// /// struct MyExtractor; @@ -48,12 +83,12 @@ pub use self::from_ref::FromRef; /// impl FromRequest for MyExtractor /// where /// // these bounds are required by `async_trait` -/// B: Send, +/// B: Send + 'static, /// S: Send + Sync, /// { /// type Rejection = http::StatusCode; /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request(req: Request, state: &S) -> Result { /// // ... /// # unimplemented!() /// } @@ -63,245 +98,72 @@ pub use self::from_ref::FromRef; /// This ensures your extractor is as flexible as possible. /// /// [`http::Request`]: http::Request -/// [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html +/// [`axum::extract`]: https://docs.rs/axum/0.6/axum/extract/index.html #[async_trait] -pub trait FromRequest: Sized { +pub trait FromRequest: Sized { /// If the extractor fails it'll use this "rejection" type. A rejection is /// a kind of error that can be converted into a response. type Rejection: IntoResponse; /// Perform the extraction. - async fn from_request(req: &mut RequestParts) -> Result; + async fn from_request(req: Request, state: &S) -> Result; } -/// The type used with [`FromRequest`] to extract data from requests. -/// -/// Has several convenience methods for getting owned parts of the request. -#[derive(Debug)] -pub struct RequestParts { - pub(crate) state: Arc, - method: Method, - uri: Uri, - version: Version, - headers: HeaderMap, - extensions: Extensions, - body: Option, -} +#[async_trait] +impl FromRequest for T +where + B: Send + 'static, + S: Send + Sync, + T: FromRequestParts, +{ + type Rejection = >::Rejection; -impl RequestParts<(), B> { - /// Create a new `RequestParts` without any state. - /// - /// You generally shouldn't need to construct this type yourself, unless - /// using extractors outside of axum for example to implement a - /// [`tower::Service`]. - /// - /// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html - pub fn new(req: Request) -> Self { - Self::with_state((), req) + async fn from_request(req: Request, state: &S) -> Result { + let (mut parts, _) = req.into_parts(); + Self::from_request_parts(&mut parts, state).await } } -impl RequestParts { - /// Create a new `RequestParts` with the given state. - /// - /// You generally shouldn't need to construct this type yourself, unless - /// using extractors outside of axum for example to implement a - /// [`tower::Service`]. - /// - /// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html - pub fn with_state(state: S, req: Request) -> Self { - Self::with_state_arc(Arc::new(state), req) - } - - /// Create a new `RequestParts` with the given [`Arc`]'ed state. - /// - /// You generally shouldn't need to construct this type yourself, unless - /// using extractors outside of axum for example to implement a - /// [`tower::Service`]. - /// - /// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html - pub fn with_state_arc(state: Arc, req: Request) -> Self { - let ( - http::request::Parts { - method, - uri, - version, - headers, - extensions, - .. - }, - body, - ) = req.into_parts(); - - RequestParts { - state, - method, - uri, - version, - headers, - extensions, - body: Some(body), - } - } - - /// Apply an extractor to this `RequestParts`. - /// - /// `req.extract::()` is equivalent to `Extractor::from_request(req)`. - /// This function simply exists as a convenience. - /// - /// # Example - /// - /// ``` - /// # struct MyExtractor {} - /// - /// use std::convert::Infallible; - /// - /// use async_trait::async_trait; - /// use axum::extract::{FromRequest, RequestParts}; - /// use http::{Method, Uri}; - /// - /// #[async_trait] - /// impl FromRequest for MyExtractor - /// where - /// B: Send, - /// S: Send + Sync, - /// { - /// type Rejection = Infallible; - /// - /// async fn from_request(req: &mut RequestParts) -> Result { - /// let method = req.extract::().await?; - /// let path = req.extract::().await?.path().to_owned(); - /// - /// todo!() - /// } - /// } - /// ``` - pub async fn extract(&mut self) -> Result - where - E: FromRequest, - { - E::from_request(self).await - } - - /// Convert this `RequestParts` back into a [`Request`]. - /// - /// Fails if The request body has been extracted, that is [`take_body`] has - /// been called. - /// - /// [`take_body`]: RequestParts::take_body - pub fn try_into_request(self) -> Result, BodyAlreadyExtracted> { - let Self { - state: _, - method, - uri, - version, - headers, - extensions, - mut body, - } = self; - - let mut req = if let Some(body) = body.take() { - Request::new(body) - } else { - return Err(BodyAlreadyExtracted); - }; - - *req.method_mut() = method; - *req.uri_mut() = uri; - *req.version_mut() = version; - *req.headers_mut() = headers; - *req.extensions_mut() = extensions; - - Ok(req) - } - - /// Gets a reference to the request method. - pub fn method(&self) -> &Method { - &self.method - } - - /// Gets a mutable reference to the request method. - pub fn method_mut(&mut self) -> &mut Method { - &mut self.method - } - - /// Gets a reference to the request URI. - pub fn uri(&self) -> &Uri { - &self.uri - } - - /// Gets a mutable reference to the request URI. - pub fn uri_mut(&mut self) -> &mut Uri { - &mut self.uri - } - - /// Get the request HTTP version. - pub fn version(&self) -> Version { - self.version - } - - /// Gets a mutable reference to the request HTTP version. - pub fn version_mut(&mut self) -> &mut Version { - &mut self.version - } - - /// Gets a reference to the request headers. - pub fn headers(&self) -> &HeaderMap { - &self.headers - } - - /// Gets a mutable reference to the request headers. - pub fn headers_mut(&mut self) -> &mut HeaderMap { - &mut self.headers - } - - /// Gets a reference to the request extensions. - pub fn extensions(&self) -> &Extensions { - &self.extensions - } - - /// Gets a mutable reference to the request extensions. - pub fn extensions_mut(&mut self) -> &mut Extensions { - &mut self.extensions - } - - /// Gets a reference to the request body. - /// - /// Returns `None` if the body has been taken by another extractor. - pub fn body(&self) -> Option<&B> { - self.body.as_ref() - } +#[async_trait] +impl FromRequestParts for Option +where + T: FromRequestParts, + S: Send + Sync, +{ + type Rejection = Infallible; - /// Gets a mutable reference to the request body. - /// - /// Returns `None` if the body has been taken by another extractor. - // this returns `&mut Option` rather than `Option<&mut B>` such that users can use it to set the body. - pub fn body_mut(&mut self) -> &mut Option { - &mut self.body + async fn from_request_parts( + parts: &mut Parts, + state: &S, + ) -> Result, Self::Rejection> { + Ok(T::from_request_parts(parts, state).await.ok()) } +} - /// Takes the body out of the request, leaving a `None` in its place. - pub fn take_body(&mut self) -> Option { - self.body.take() - } +#[async_trait] +impl FromRequest for Option +where + T: FromRequest, + B: Send + 'static, + S: Send + Sync, +{ + type Rejection = Infallible; - /// Get a reference to the state. - pub fn state(&self) -> &S { - &self.state + async fn from_request(req: Request, state: &S) -> Result, Self::Rejection> { + Ok(T::from_request(req, state).await.ok()) } } #[async_trait] -impl FromRequest for Option +impl FromRequestParts for Result where - T: FromRequest, - B: Send, + T: FromRequestParts, S: Send + Sync, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result, Self::Rejection> { - Ok(T::from_request(req).await.ok()) + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + Ok(T::from_request_parts(parts, state).await) } } @@ -309,12 +171,12 @@ where impl FromRequest for Result where T: FromRequest, - B: Send, + B: Send + 'static, S: Send + Sync, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { - Ok(T::from_request(req).await) + async fn from_request(req: Request, state: &S) -> Result { + Ok(T::from_request(req, state).await) } } diff --git a/axum-core/src/extract/rejection.rs b/axum-core/src/extract/rejection.rs index e6f53b8224..8afe112a62 100644 --- a/axum-core/src/extract/rejection.rs +++ b/axum-core/src/extract/rejection.rs @@ -1,35 +1,6 @@ //! Rejection response types. -use crate::{ - response::{IntoResponse, Response}, - BoxError, -}; -use http::StatusCode; -use std::fmt; - -/// Rejection type used if you try and extract the request body more than -/// once. -#[derive(Debug, Default)] -#[non_exhaustive] -pub struct BodyAlreadyExtracted; - -impl BodyAlreadyExtracted { - const BODY: &'static str = "Cannot have two request body extractors for a single handler"; -} - -impl IntoResponse for BodyAlreadyExtracted { - fn into_response(self) -> Response { - (StatusCode::INTERNAL_SERVER_ERROR, Self::BODY).into_response() - } -} - -impl fmt::Display for BodyAlreadyExtracted { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", Self::BODY) - } -} - -impl std::error::Error for BodyAlreadyExtracted {} +use crate::BoxError; composite_rejection! { /// Rejection type for extractors that buffer the request body. Used if the @@ -85,7 +56,6 @@ composite_rejection! { /// Contains one variant for each way the [`Bytes`](bytes::Bytes) extractor /// can fail. pub enum BytesRejection { - BodyAlreadyExtracted, FailedToBufferBody, } } @@ -95,7 +65,6 @@ composite_rejection! { /// /// Contains one variant for each way the [`String`] extractor can fail. pub enum StringRejection { - BodyAlreadyExtracted, FailedToBufferBody, InvalidUtf8, } diff --git a/axum-core/src/extract/request_parts.rs b/axum-core/src/extract/request_parts.rs index ed3ff2022f..245657cf3d 100644 --- a/axum-core/src/extract/request_parts.rs +++ b/axum-core/src/extract/request_parts.rs @@ -1,9 +1,9 @@ -use super::{rejection::*, FromRequest, RequestParts}; +use super::{rejection::*, FromRequest, FromRequestParts}; use crate::BoxError; use async_trait::async_trait; use bytes::Bytes; -use http::{Extensions, HeaderMap, Method, Request, Uri, Version}; -use std::{convert::Infallible, sync::Arc}; +use http::{request::Parts, HeaderMap, Method, Request, Uri, Version}; +use std::convert::Infallible; #[async_trait] impl FromRequest for Request @@ -11,62 +11,46 @@ where B: Send, S: Send + Sync, { - type Rejection = BodyAlreadyExtracted; - - async fn from_request(req: &mut RequestParts) -> Result { - let req = std::mem::replace( - req, - RequestParts { - state: Arc::clone(&req.state), - method: req.method.clone(), - version: req.version, - uri: req.uri.clone(), - headers: HeaderMap::new(), - extensions: Extensions::default(), - body: None, - }, - ); - - req.try_into_request() + type Rejection = Infallible; + + async fn from_request(req: Request, _: &S) -> Result { + Ok(req) } } #[async_trait] -impl FromRequest for Method +impl FromRequestParts for Method where - B: Send, S: Send + Sync, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { - Ok(req.method().clone()) + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + Ok(parts.method.clone()) } } #[async_trait] -impl FromRequest for Uri +impl FromRequestParts for Uri where - B: Send, S: Send + Sync, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { - Ok(req.uri().clone()) + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + Ok(parts.uri.clone()) } } #[async_trait] -impl FromRequest for Version +impl FromRequestParts for Version where - B: Send, S: Send + Sync, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { - Ok(req.version()) + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + Ok(parts.version) } } @@ -76,30 +60,29 @@ where /// /// [`TypedHeader`]: https://docs.rs/axum/latest/axum/extract/struct.TypedHeader.html #[async_trait] -impl FromRequest for HeaderMap +impl FromRequestParts for HeaderMap where - B: Send, S: Send + Sync, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { - Ok(req.headers().clone()) + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + Ok(parts.headers.clone()) } } #[async_trait] impl FromRequest for Bytes where - B: http_body::Body + Send, + B: http_body::Body + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = BytesRejection; - async fn from_request(req: &mut RequestParts) -> Result { - let body = take_body(req)?; + async fn from_request(req: Request, _: &S) -> Result { + let body = req.into_body(); let bytes = crate::body::to_bytes(body) .await @@ -112,15 +95,15 @@ where #[async_trait] impl FromRequest for String where - B: http_body::Body + Send, + B: http_body::Body + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = StringRejection; - async fn from_request(req: &mut RequestParts) -> Result { - let body = take_body(req)?; + async fn from_request(req: Request, _: &S) -> Result { + let body = req.into_body(); let bytes = crate::body::to_bytes(body) .await @@ -134,40 +117,14 @@ where } #[async_trait] -impl FromRequest for http::request::Parts +impl FromRequest for Parts where - B: Send, + B: Send + 'static, S: Send + Sync, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { - let method = unwrap_infallible(Method::from_request(req).await); - let uri = unwrap_infallible(Uri::from_request(req).await); - let version = unwrap_infallible(Version::from_request(req).await); - let headers = unwrap_infallible(HeaderMap::from_request(req).await); - let extensions = std::mem::take(req.extensions_mut()); - - let mut temp_request = Request::new(()); - *temp_request.method_mut() = method; - *temp_request.uri_mut() = uri; - *temp_request.version_mut() = version; - *temp_request.headers_mut() = headers; - *temp_request.extensions_mut() = extensions; - - let (parts, _) = temp_request.into_parts(); - - Ok(parts) + async fn from_request(req: Request, _: &S) -> Result { + Ok(req.into_parts().0) } } - -fn unwrap_infallible(result: Result) -> T { - match result { - Ok(value) => value, - Err(err) => match err {}, - } -} - -pub(crate) fn take_body(req: &mut RequestParts) -> Result { - req.take_body().ok_or(BodyAlreadyExtracted) -} diff --git a/axum-core/src/extract/tuple.rs b/axum-core/src/extract/tuple.rs index 3a5938e2c2..03427b2bbb 100644 --- a/axum-core/src/extract/tuple.rs +++ b/axum-core/src/extract/tuple.rs @@ -1,41 +1,143 @@ -use super::{FromRequest, RequestParts}; +use super::{FromRequest, FromRequestParts}; use crate::response::{IntoResponse, Response}; use async_trait::async_trait; +use http::request::{Parts, Request}; use std::convert::Infallible; #[async_trait] -impl FromRequest for () +impl FromRequestParts for () where - B: Send, S: Send + Sync, { type Rejection = Infallible; - async fn from_request(_: &mut RequestParts) -> Result<(), Self::Rejection> { + async fn from_request_parts(_: &mut Parts, _: &S) -> Result<(), Self::Rejection> { Ok(()) } } macro_rules! impl_from_request { - () => {}; + ( + [$($ty:ident),*], $last:ident + ) => { + #[async_trait] + #[allow(non_snake_case, unused_mut, unused_variables)] + impl FromRequestParts for ($($ty,)* $last,) + where + $( $ty: FromRequestParts + Send, )* + $last: FromRequestParts + Send, + S: Send + Sync, + { + type Rejection = Response; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + $( + let $ty = $ty::from_request_parts(parts, state) + .await + .map_err(|err| err.into_response())?; + )* + let $last = $last::from_request_parts(parts, state) + .await + .map_err(|err| err.into_response())?; + + Ok(($($ty,)* $last,)) + } + } - ( $($ty:ident),* $(,)? ) => { + // This impl must not be generic over M, otherwise it would conflict with the blanket + // implementation of `FromRequest` for `T: FromRequestParts`. #[async_trait] - #[allow(non_snake_case)] - impl FromRequest for ($($ty,)*) + #[allow(non_snake_case, unused_mut, unused_variables)] + impl FromRequest for ($($ty,)* $last,) where - $( $ty: FromRequest + Send, )* - B: Send, + $( $ty: FromRequestParts + Send, )* + $last: FromRequest + Send, + B: Send + 'static, S: Send + Sync, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { - $( let $ty = $ty::from_request(req).await.map_err(|err| err.into_response())?; )* - Ok(($($ty,)*)) + async fn from_request(req: Request, state: &S) -> Result { + let (mut parts, body) = req.into_parts(); + + $( + let $ty = $ty::from_request_parts(&mut parts, state).await.map_err(|err| err.into_response())?; + )* + + let req = Request::from_parts(parts, body); + + let $last = $last::from_request(req, state).await.map_err(|err| err.into_response())?; + + Ok(($($ty,)* $last,)) } } }; } -all_the_tuples!(impl_from_request); +impl_from_request!([], T1); +impl_from_request!([T1], T2); +impl_from_request!([T1, T2], T3); +impl_from_request!([T1, T2, T3], T4); +impl_from_request!([T1, T2, T3, T4], T5); +impl_from_request!([T1, T2, T3, T4, T5], T6); +impl_from_request!([T1, T2, T3, T4, T5, T6], T7); +impl_from_request!([T1, T2, T3, T4, T5, T6, T7], T8); +impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8], T9); +impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9], T10); +impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10], T11); +impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11], T12); +impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12], T13); +impl_from_request!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13], + T14 +); +impl_from_request!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14], + T15 +); +impl_from_request!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15], + T16 +); + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use http::Method; + + use crate::extract::{FromRequest, FromRequestParts}; + + fn assert_from_request() + where + T: FromRequest<(), http_body::Full, M>, + { + } + + fn assert_from_request_parts>() {} + + #[test] + fn unit() { + assert_from_request_parts::<()>(); + assert_from_request::<_, ()>(); + } + + #[test] + fn tuple_of_one() { + assert_from_request_parts::<(Method,)>(); + assert_from_request::<_, (Method,)>(); + assert_from_request::<_, (Bytes,)>(); + } + + #[test] + fn tuple_of_two() { + assert_from_request_parts::<((), ())>(); + assert_from_request::<_, ((), ())>(); + assert_from_request::<_, (Method, Bytes)>(); + } + + #[test] + fn nested_tuple() { + assert_from_request_parts::<(((Method,),),)>(); + assert_from_request::<_, ((((Bytes,),),),)>(); + } +} diff --git a/axum-extra/src/either.rs b/axum-extra/src/either.rs index d6d1c8ec62..2e3af8b8ea 100755 --- a/axum-extra/src/either.rs +++ b/axum-extra/src/either.rs @@ -4,15 +4,50 @@ //! //! ``` //! use axum_extra::either::Either3; -//! use axum::{body::Bytes, Json}; +//! use axum::{ +//! body::Bytes, +//! Router, +//! async_trait, +//! routing::get, +//! extract::FromRequestParts, +//! }; +//! +//! // extractors for checking permissions +//! struct AdminPermissions {} +//! +//! #[async_trait] +//! impl FromRequestParts for AdminPermissions +//! where +//! S: Send + Sync, +//! { +//! // check for admin permissions... +//! # type Rejection = (); +//! # async fn from_request_parts(parts: &mut axum::http::request::Parts, state: &S) -> Result { +//! # todo!() +//! # } +//! } +//! +//! struct User {} +//! +//! #[async_trait] +//! impl FromRequestParts for User +//! where +//! S: Send + Sync, +//! { +//! // check for a logged in user... +//! # type Rejection = (); +//! # async fn from_request_parts(parts: &mut axum::http::request::Parts, state: &S) -> Result { +//! # todo!() +//! # } +//! } //! //! async fn handler( -//! body: Either3, String, Bytes>, +//! body: Either3, //! ) { //! match body { -//! Either3::E1(json) => { /* ... */ } -//! Either3::E2(string) => { /* ... */ } -//! Either3::E3(bytes) => { /* ... */ } +//! Either3::E1(admin) => { /* ... */ } +//! Either3::E2(user) => { /* ... */ } +//! Either3::E3(guest) => { /* ... */ } //! } //! } //! # @@ -60,9 +95,10 @@ use axum::{ async_trait, - extract::{FromRequest, RequestParts}, + extract::FromRequestParts, response::{IntoResponse, Response}, }; +use http::request::Parts; /// Combines two extractors or responses into a single type. /// @@ -190,23 +226,22 @@ macro_rules! impl_traits_for_either { $last:ident $(,)? ) => { #[async_trait] - impl FromRequest for $either<$($ident),*, $last> + impl FromRequestParts for $either<$($ident),*, $last> where - $($ident: FromRequest),*, - $last: FromRequest, - B: Send, + $($ident: FromRequestParts),*, + $last: FromRequestParts, S: Send + Sync, { type Rejection = $last::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { $( - if let Ok(value) = req.extract().await { + if let Ok(value) = FromRequestParts::from_request_parts(parts, state).await { return Ok(Self::$ident(value)); } )* - req.extract().await.map(Self::$last) + FromRequestParts::from_request_parts(parts, state).await.map(Self::$last) } } diff --git a/axum-extra/src/extract/cached.rs b/axum-extra/src/extract/cached.rs index 64519bf419..548a256245 100644 --- a/axum-extra/src/extract/cached.rs +++ b/axum-extra/src/extract/cached.rs @@ -1,7 +1,8 @@ use axum::{ async_trait, - extract::{Extension, FromRequest, RequestParts}, + extract::{Extension, FromRequest, FromRequestParts}, }; +use http::{request::Parts, Request}; use std::ops::{Deref, DerefMut}; /// Cache results of other extractors. @@ -20,24 +21,23 @@ use std::ops::{Deref, DerefMut}; /// use axum_extra::extract::Cached; /// use axum::{ /// async_trait, -/// extract::{FromRequest, RequestParts}, +/// extract::FromRequestParts, /// body::BoxBody, /// response::{IntoResponse, Response}, -/// http::StatusCode, +/// http::{StatusCode, request::Parts}, /// }; /// /// #[derive(Clone)] /// struct Session { /* ... */ } /// /// #[async_trait] -/// impl FromRequest for Session +/// impl FromRequestParts for Session /// where -/// B: Send, /// S: Send + Sync, /// { /// type Rejection = (StatusCode, String); /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { /// // load session... /// # unimplemented!() /// } @@ -46,19 +46,18 @@ use std::ops::{Deref, DerefMut}; /// struct CurrentUser { /* ... */ } /// /// #[async_trait] -/// impl FromRequest for CurrentUser +/// impl FromRequestParts for CurrentUser /// where -/// B: Send, /// S: Send + Sync, /// { /// type Rejection = Response; /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { /// // loading a `CurrentUser` requires first loading the `Session` /// // /// // by using `Cached` we avoid extracting the session more than /// // once, in case other extractors for the same request also loads the session -/// let session: Session = Cached::::from_request(req) +/// let session: Session = Cached::::from_request_parts(parts, state) /// .await /// .map_err(|err| err.into_response())? /// .0; @@ -92,18 +91,40 @@ struct CachedEntry(T); #[async_trait] impl FromRequest for Cached where - B: Send, + B: Send + 'static, S: Send + Sync, - T: FromRequest + Clone + Send + Sync + 'static, + T: FromRequestParts + Clone + Send + Sync + 'static, { type Rejection = T::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { - match Extension::>::from_request(req).await { + async fn from_request(req: Request, state: &S) -> Result { + let (mut parts, _) = req.into_parts(); + + match Extension::>::from_request_parts(&mut parts, state).await { + Ok(Extension(CachedEntry(value))) => Ok(Self(value)), + Err(_) => { + let value = T::from_request_parts(&mut parts, state).await?; + parts.extensions.insert(CachedEntry(value.clone())); + Ok(Self(value)) + } + } + } +} + +#[async_trait] +impl FromRequestParts for Cached +where + S: Send + Sync, + T: FromRequestParts + Clone + Send + Sync + 'static, +{ + type Rejection = T::Rejection; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + match Extension::>::from_request_parts(parts, state).await { Ok(Extension(CachedEntry(value))) => Ok(Self(value)), Err(_) => { - let value = T::from_request(req).await?; - req.extensions_mut().insert(CachedEntry(value.clone())); + let value = T::from_request_parts(parts, state).await?; + parts.extensions.insert(CachedEntry(value.clone())); Ok(Self(value)) } } @@ -127,7 +148,8 @@ impl DerefMut for Cached { #[cfg(test)] mod tests { use super::*; - use axum::http::Request; + use axum::{extract::FromRequestParts, http::Request}; + use http::request::Parts; use std::{ convert::Infallible, sync::atomic::{AtomicU32, Ordering}, @@ -142,25 +164,33 @@ mod tests { struct Extractor(Instant); #[async_trait] - impl FromRequest for Extractor + impl FromRequestParts for Extractor where - B: Send, S: Send + Sync, { type Rejection = Infallible; - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request_parts( + _parts: &mut Parts, + _state: &S, + ) -> Result { COUNTER.fetch_add(1, Ordering::SeqCst); Ok(Self(Instant::now())) } } - let mut req = RequestParts::new(Request::new(())); + let (mut parts, _) = Request::new(()).into_parts(); - let first = Cached::::from_request(&mut req).await.unwrap().0; + let first = Cached::::from_request_parts(&mut parts, &()) + .await + .unwrap() + .0; assert_eq!(COUNTER.load(Ordering::SeqCst), 1); - let second = Cached::::from_request(&mut req).await.unwrap().0; + let second = Cached::::from_request_parts(&mut parts, &()) + .await + .unwrap() + .0; assert_eq!(COUNTER.load(Ordering::SeqCst), 1); assert_eq!(first, second); diff --git a/axum-extra/src/extract/cookie/mod.rs b/axum-extra/src/extract/cookie/mod.rs index 3edbe68f35..20d015d1b9 100644 --- a/axum-extra/src/extract/cookie/mod.rs +++ b/axum-extra/src/extract/cookie/mod.rs @@ -4,11 +4,12 @@ use axum::{ async_trait, - extract::{FromRequest, RequestParts}, + extract::FromRequestParts, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use http::{ header::{COOKIE, SET_COOKIE}, + request::Parts, HeaderMap, }; use std::convert::Infallible; @@ -88,15 +89,14 @@ pub struct CookieJar { } #[async_trait] -impl FromRequest for CookieJar +impl FromRequestParts for CookieJar where - B: Send, S: Send + Sync, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { - Ok(Self::from_headers(req.headers())) + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + Ok(Self::from_headers(&parts.headers)) } } @@ -115,7 +115,9 @@ impl CookieJar { /// The cookies in `headers` will be added to the jar. /// /// This is inteded to be used in middleware and other places where it might be difficult to - /// run extractors. Normally you should create `CookieJar`s through [`FromRequest`]. + /// run extractors. Normally you should create `CookieJar`s through [`FromRequestParts`]. + /// + /// [`FromRequestParts`]: axum::extract::FromRequestParts pub fn from_headers(headers: &HeaderMap) -> Self { let mut jar = cookie::CookieJar::new(); for cookie in cookies_from_request(headers) { @@ -127,10 +129,12 @@ impl CookieJar { /// Create a new empty `CookieJar`. /// /// This is inteded to be used in middleware and other places where it might be difficult to - /// run extractors. Normally you should create `CookieJar`s through [`FromRequest`]. + /// run extractors. Normally you should create `CookieJar`s through [`FromRequestParts`]. /// /// If you need a jar that contains the headers from a request use `impl From<&HeaderMap> for /// CookieJar`. + /// + /// [`FromRequestParts`]: axum::extract::FromRequestParts pub fn new() -> Self { Self::default() } diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs index 7a88380c10..0b08fdc9e3 100644 --- a/axum-extra/src/extract/cookie/private.rs +++ b/axum-extra/src/extract/cookie/private.rs @@ -1,11 +1,11 @@ use super::{cookies_from_request, set_cookies, Cookie, Key}; use axum::{ async_trait, - extract::{FromRef, FromRequest, RequestParts}, + extract::{FromRef, FromRequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use cookie::PrivateJar; -use http::HeaderMap; +use http::{request::Parts, HeaderMap}; use std::{convert::Infallible, fmt, marker::PhantomData}; /// Extractor that grabs private cookies from the request and manages the jar. @@ -87,22 +87,21 @@ impl fmt::Debug for PrivateCookieJar { } #[async_trait] -impl FromRequest for PrivateCookieJar +impl FromRequestParts for PrivateCookieJar where - B: Send, S: Send + Sync, K: FromRef + Into, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { - let k = K::from_ref(req.state()); + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let k = K::from_ref(state); let key = k.into(); let PrivateCookieJar { jar, key, _marker: _, - } = PrivateCookieJar::from_headers(req.headers(), key); + } = PrivateCookieJar::from_headers(&parts.headers, key); Ok(PrivateCookieJar { jar, key, @@ -117,7 +116,9 @@ impl PrivateCookieJar { /// The valid cookies in `headers` will be added to the jar. /// /// This is inteded to be used in middleware and other where places it might be difficult to - /// run extractors. Normally you should create `PrivateCookieJar`s through [`FromRequest`]. + /// run extractors. Normally you should create `PrivateCookieJar`s through [`FromRequestParts`]. + /// + /// [`FromRequestParts`]: axum::extract::FromRequestParts pub fn from_headers(headers: &HeaderMap, key: Key) -> Self { let mut jar = cookie::CookieJar::new(); let mut private_jar = jar.private_mut(&key); @@ -137,7 +138,9 @@ impl PrivateCookieJar { /// Create a new empty `PrivateCookieJarIter`. /// /// This is inteded to be used in middleware and other places where it might be difficult to - /// run extractors. Normally you should create `PrivateCookieJar`s through [`FromRequest`]. + /// run extractors. Normally you should create `PrivateCookieJar`s through [`FromRequestParts`]. + /// + /// [`FromRequestParts`]: axum::extract::FromRequestParts pub fn new(key: Key) -> Self { Self { jar: Default::default(), diff --git a/axum-extra/src/extract/cookie/signed.rs b/axum-extra/src/extract/cookie/signed.rs index 05ffcc926d..ca0aa4ca19 100644 --- a/axum-extra/src/extract/cookie/signed.rs +++ b/axum-extra/src/extract/cookie/signed.rs @@ -1,12 +1,12 @@ use super::{cookies_from_request, set_cookies}; use axum::{ async_trait, - extract::{FromRef, FromRequest, RequestParts}, + extract::{FromRef, FromRequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use cookie::SignedJar; use cookie::{Cookie, Key}; -use http::HeaderMap; +use http::{request::Parts, HeaderMap}; use std::{convert::Infallible, fmt, marker::PhantomData}; /// Extractor that grabs signed cookies from the request and manages the jar. @@ -105,22 +105,21 @@ impl fmt::Debug for SignedCookieJar { } #[async_trait] -impl FromRequest for SignedCookieJar +impl FromRequestParts for SignedCookieJar where - B: Send, S: Send + Sync, K: FromRef + Into, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { - let k = K::from_ref(req.state()); + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let k = K::from_ref(state); let key = k.into(); let SignedCookieJar { jar, key, _marker: _, - } = SignedCookieJar::from_headers(req.headers(), key); + } = SignedCookieJar::from_headers(&parts.headers, key); Ok(SignedCookieJar { jar, key, @@ -135,7 +134,9 @@ impl SignedCookieJar { /// The valid cookies in `headers` will be added to the jar. /// /// This is inteded to be used in middleware and other places where it might be difficult to - /// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequest`]. + /// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequestParts`]. + /// + /// [`FromRequestParts`]: axum::extract::FromRequestParts pub fn from_headers(headers: &HeaderMap, key: Key) -> Self { let mut jar = cookie::CookieJar::new(); let mut signed_jar = jar.signed_mut(&key); @@ -155,7 +156,9 @@ impl SignedCookieJar { /// Create a new empty `SignedCookieJar`. /// /// This is inteded to be used in middleware and other places where it might be difficult to - /// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequest`]. + /// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequestParts`]. + /// + /// [`FromRequestParts`]: axum::extract::FromRequestParts pub fn new(key: Key) -> Self { Self { jar: Default::default(), diff --git a/axum-extra/src/extract/form.rs b/axum-extra/src/extract/form.rs index 08c36755fd..254c7cce41 100644 --- a/axum-extra/src/extract/form.rs +++ b/axum-extra/src/extract/form.rs @@ -3,12 +3,12 @@ use axum::{ body::HttpBody, extract::{ rejection::{FailedToDeserializeQueryString, FormRejection, InvalidFormContentType}, - FromRequest, RequestParts, + FromRequest, }, BoxError, }; use bytes::Bytes; -use http::{header, Method}; +use http::{header, HeaderMap, Method, Request}; use serde::de::DeserializeOwned; use std::ops::Deref; @@ -58,25 +58,25 @@ impl Deref for Form { impl FromRequest for Form where T: DeserializeOwned, - B: HttpBody + Send, + B: HttpBody + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = FormRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: Request, state: &S) -> Result { if req.method() == Method::GET { let query = req.uri().query().unwrap_or_default(); let value = serde_html_form::from_str(query) .map_err(FailedToDeserializeQueryString::__private_new)?; Ok(Form(value)) } else { - if !has_content_type(req, &mime::APPLICATION_WWW_FORM_URLENCODED) { + if !has_content_type(req.headers(), &mime::APPLICATION_WWW_FORM_URLENCODED) { return Err(InvalidFormContentType::default().into()); } - let bytes = Bytes::from_request(req).await?; + let bytes = Bytes::from_request(req, state).await?; let value = serde_html_form::from_bytes(&bytes) .map_err(FailedToDeserializeQueryString::__private_new)?; @@ -86,8 +86,8 @@ where } // this is duplicated in `axum/src/extract/mod.rs` -fn has_content_type(req: &RequestParts, expected_content_type: &mime::Mime) -> bool { - let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { +fn has_content_type(headers: &HeaderMap, expected_content_type: &mime::Mime) -> bool { + let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) { content_type } else { return false; diff --git a/axum-extra/src/extract/query.rs b/axum-extra/src/extract/query.rs index feae007e43..4a8d6f8676 100644 --- a/axum-extra/src/extract/query.rs +++ b/axum-extra/src/extract/query.rs @@ -2,9 +2,10 @@ use axum::{ async_trait, extract::{ rejection::{FailedToDeserializeQueryString, QueryRejection}, - FromRequest, RequestParts, + FromRequestParts, }, }; +use http::request::Parts; use serde::de::DeserializeOwned; use std::ops::Deref; @@ -58,16 +59,15 @@ use std::ops::Deref; pub struct Query(pub T); #[async_trait] -impl FromRequest for Query +impl FromRequestParts for Query where T: DeserializeOwned, - B: Send, S: Send + Sync, { type Rejection = QueryRejection; - async fn from_request(req: &mut RequestParts) -> Result { - let query = req.uri().query().unwrap_or_default(); + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let query = parts.uri.query().unwrap_or_default(); let value = serde_html_form::from_str(query) .map_err(FailedToDeserializeQueryString::__private_new)?; Ok(Query(value)) diff --git a/axum-extra/src/extract/with_rejection.rs b/axum-extra/src/extract/with_rejection.rs index e9abc40886..f3a0f04e87 100644 --- a/axum-extra/src/extract/with_rejection.rs +++ b/axum-extra/src/extract/with_rejection.rs @@ -1,6 +1,8 @@ use axum::async_trait; -use axum::extract::{FromRequest, RequestParts}; +use axum::extract::{FromRequest, FromRequestParts}; use axum::response::IntoResponse; +use http::request::Parts; +use http::Request; use std::fmt::Debug; use std::marker::PhantomData; use std::ops::{Deref, DerefMut}; @@ -109,23 +111,40 @@ impl DerefMut for WithRejection { #[async_trait] impl FromRequest for WithRejection where - B: Send, + B: Send + 'static, S: Send + Sync, E: FromRequest, R: From + IntoResponse, { type Rejection = R; - async fn from_request(req: &mut RequestParts) -> Result { - let extractor = req.extract::().await?; + async fn from_request(req: Request, state: &S) -> Result { + let extractor = E::from_request(req, state).await?; + Ok(WithRejection(extractor, PhantomData)) + } +} + +#[async_trait] +impl FromRequestParts for WithRejection +where + S: Send + Sync, + E: FromRequestParts, + R: From + IntoResponse, +{ + type Rejection = R; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let extractor = E::from_request_parts(parts, state).await?; Ok(WithRejection(extractor, PhantomData)) } } #[cfg(test)] mod tests { + use axum::extract::FromRequestParts; use axum::http::Request; use axum::response::Response; + use http::request::Parts; use super::*; @@ -135,14 +154,16 @@ mod tests { struct TestRejection; #[async_trait] - impl FromRequest for TestExtractor + impl FromRequestParts for TestExtractor where - B: Send, S: Send + Sync, { type Rejection = (); - async fn from_request(_: &mut RequestParts) -> Result { + async fn from_request_parts( + _parts: &mut Parts, + _state: &S, + ) -> Result { Err(()) } } @@ -159,12 +180,14 @@ mod tests { } } - let mut req = RequestParts::new(Request::new(())); - - let result = req - .extract::>() - .await; + let req = Request::new(()); + let result = WithRejection::::from_request(req, &()).await; + assert!(matches!(result, Err(TestRejection))); - assert!(matches!(result, Err(TestRejection))) + let (mut parts, _) = Request::new(()).into_parts(); + let result = + WithRejection::::from_request_parts(&mut parts, &()) + .await; + assert!(matches!(result, Err(TestRejection))); } } diff --git a/axum-extra/src/handler/mod.rs b/axum-extra/src/handler/mod.rs index ef12f896e9..75b4d2cf1d 100644 --- a/axum-extra/src/handler/mod.rs +++ b/axum-extra/src/handler/mod.rs @@ -1,7 +1,7 @@ //! Additional handler utilities. use axum::{ - extract::{FromRequest, RequestParts}, + extract::FromRequest, handler::Handler, response::{IntoResponse, Response}, }; @@ -26,8 +26,8 @@ pub trait HandlerCallWithExtractors: Sized { /// Call the handler with the extracted inputs. fn call( self, - state: Arc, extractors: T, + state: Arc, ) -> >::Future; /// Conver this `HandlerCallWithExtractors` into [`Handler`]. @@ -51,7 +51,7 @@ pub trait HandlerCallWithExtractors: Sized { /// Router, /// async_trait, /// routing::get, - /// extract::FromRequest, + /// extract::FromRequestParts, /// }; /// /// // handlers for varying levels of access @@ -71,14 +71,13 @@ pub trait HandlerCallWithExtractors: Sized { /// struct AdminPermissions {} /// /// #[async_trait] - /// impl FromRequest for AdminPermissions + /// impl FromRequestParts for AdminPermissions /// where - /// B: Send, /// S: Send + Sync, /// { /// // check for admin permissions... /// # type Rejection = (); - /// # async fn from_request(req: &mut axum::extract::RequestParts) -> Result { + /// # async fn from_request_parts(parts: &mut http::request::Parts, state: &S) -> Result { /// # todo!() /// # } /// } @@ -86,14 +85,13 @@ pub trait HandlerCallWithExtractors: Sized { /// struct User {} /// /// #[async_trait] - /// impl FromRequest for User + /// impl FromRequestParts for User /// where - /// B: Send, /// S: Send + Sync, /// { /// // check for a logged in user... /// # type Rejection = (); - /// # async fn from_request(req: &mut axum::extract::RequestParts) -> Result { + /// # async fn from_request_parts(parts: &mut http::request::Parts, state: &S) -> Result { /// # todo!() /// # } /// } @@ -121,27 +119,27 @@ pub trait HandlerCallWithExtractors: Sized { } macro_rules! impl_handler_call_with { - ( $($ty:ident),* $(,)? ) => { - #[allow(non_snake_case)] - impl HandlerCallWithExtractors<($($ty,)*), S, B> for F - where - F: FnOnce($($ty,)*) -> Fut, - Fut: Future + Send + 'static, - Fut::Output: IntoResponse, - { - // this puts `futures_util` in our public API but thats fine in axum-extra - type Future = Map Response>; + ( $($ty:ident),* $(,)? ) => { + #[allow(non_snake_case)] + impl HandlerCallWithExtractors<($($ty,)*), S, B> for F + where + F: FnOnce($($ty,)*) -> Fut, + Fut: Future + Send + 'static, + Fut::Output: IntoResponse, + { + // this puts `futures_util` in our public API but thats fine in axum-extra + type Future = Map Response>; - fn call( - self, - _state: Arc, - ($($ty,)*): ($($ty,)*), - ) -> >::Future { - self($($ty,)*).map(IntoResponse::into_response) - } - } - }; -} + fn call( + self, + ($($ty,)*): ($($ty,)*), + _state: Arc, + ) -> >::Future { + self($($ty,)*).map(IntoResponse::into_response) + } + } + }; + } impl_handler_call_with!(); impl_handler_call_with!(T1); @@ -180,11 +178,10 @@ where { type Future = BoxFuture<'static, Response>; - fn call(self, state: Arc, req: http::Request) -> Self::Future { + fn call(self, req: http::Request, state: Arc) -> Self::Future { Box::pin(async move { - let mut req = RequestParts::with_state_arc(Arc::clone(&state), req); - match req.extract::().await { - Ok(t) => self.handler.call(state, t).await, + match T::from_request(req, &state).await { + Ok(t) => self.handler.call(t, state).await, Err(rejection) => rejection.into_response(), } }) diff --git a/axum-extra/src/handler/or.rs b/axum-extra/src/handler/or.rs index fb307ccf7a..cf470ea5cc 100644 --- a/axum-extra/src/handler/or.rs +++ b/axum-extra/src/handler/or.rs @@ -1,13 +1,12 @@ use super::HandlerCallWithExtractors; use crate::either::Either; use axum::{ - extract::{FromRequest, RequestParts}, + extract::{FromRequest, FromRequestParts}, handler::Handler, http::Request, response::{IntoResponse, Response}, }; use futures_util::future::{BoxFuture, Either as EitherFuture, FutureExt, Map}; -use http::StatusCode; use std::{future::Future, marker::PhantomData, sync::Arc}; /// [`Handler`] that runs one [`Handler`] and if that rejects it'll fallback to another @@ -37,30 +36,30 @@ where fn call( self, - state: Arc, extractors: Either, + state: Arc, ) -> , S, B>>::Future { match extractors { Either::E1(lt) => self .lhs - .call(state, lt) + .call(lt, state) .map(IntoResponse::into_response as _) .left_future(), Either::E2(rt) => self .rhs - .call(state, rt) + .call(rt, state) .map(IntoResponse::into_response as _) .right_future(), } } } -impl Handler<(Lt, Rt), S, B> for Or +impl Handler<(M, Lt, Rt), S, B> for Or where L: HandlerCallWithExtractors + Clone + Send + 'static, R: HandlerCallWithExtractors + Clone + Send + 'static, - Lt: FromRequest + Send + 'static, - Rt: FromRequest + Send + 'static, + Lt: FromRequestParts + Send + 'static, + Rt: FromRequest + Send + 'static, Lt::Rejection: Send, Rt::Rejection: Send, B: Send + 'static, @@ -69,19 +68,20 @@ where // this puts `futures_util` in our public API but thats fine in axum-extra type Future = BoxFuture<'static, Response>; - fn call(self, state: Arc, req: Request) -> Self::Future { + fn call(self, req: Request, state: Arc) -> Self::Future { Box::pin(async move { - let mut req = RequestParts::with_state_arc(Arc::clone(&state), req); + let (mut parts, body) = req.into_parts(); - if let Ok(lt) = req.extract::().await { - return self.lhs.call(state, lt).await; + if let Ok(lt) = Lt::from_request_parts(&mut parts, &state).await { + return self.lhs.call(lt, state).await; } - if let Ok(rt) = req.extract::().await { - return self.rhs.call(state, rt).await; - } + let req = Request::from_parts(parts, body); - StatusCode::NOT_FOUND.into_response() + match Rt::from_request(req, &state).await { + Ok(rt) => self.rhs.call(rt, state).await, + Err(rejection) => rejection.into_response(), + } }) } } diff --git a/axum-extra/src/json_lines.rs b/axum-extra/src/json_lines.rs index 46ddc35ea8..83336311ca 100644 --- a/axum-extra/src/json_lines.rs +++ b/axum-extra/src/json_lines.rs @@ -3,15 +3,17 @@ use axum::{ async_trait, body::{HttpBody, StreamBody}, - extract::{rejection::BodyAlreadyExtracted, FromRequest, RequestParts}, + extract::FromRequest, response::{IntoResponse, Response}, BoxError, }; use bytes::{BufMut, Bytes, BytesMut}; use futures_util::stream::{BoxStream, Stream, TryStream, TryStreamExt}; +use http::Request; use pin_project_lite::pin_project; use serde::{de::DeserializeOwned, Serialize}; use std::{ + convert::Infallible, io::{self, Write}, marker::PhantomData, pin::Pin, @@ -106,14 +108,14 @@ where T: DeserializeOwned, S: Send + Sync, { - type Rejection = BodyAlreadyExtracted; + type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: Request, _state: &S) -> Result { // `Stream::lines` isn't a thing so we have to convert it into an `AsyncRead` // so we can call `AsyncRead::lines` and then convert it back to a `Stream` - - let body = req.take_body().ok_or_else(BodyAlreadyExtracted::default)?; - let body = BodyStream { body }; + let body = BodyStream { + body: req.into_body(), + }; let stream = body .map_ok(Into::into) diff --git a/axum-extra/src/protobuf.rs b/axum-extra/src/protobuf.rs index 906ab5f33b..ddb70122e6 100644 --- a/axum-extra/src/protobuf.rs +++ b/axum-extra/src/protobuf.rs @@ -3,12 +3,12 @@ use axum::{ async_trait, body::{Bytes, HttpBody}, - extract::{rejection::BytesRejection, FromRequest, RequestParts}, + extract::{rejection::BytesRejection, FromRequest}, response::{IntoResponse, Response}, BoxError, }; use bytes::BytesMut; -use http::StatusCode; +use http::{Request, StatusCode}; use prost::Message; use std::ops::{Deref, DerefMut}; @@ -100,15 +100,15 @@ pub struct ProtoBuf(pub T); impl FromRequest for ProtoBuf where T: Message + Default, - B: HttpBody + Send, + B: HttpBody + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = ProtoBufRejection; - async fn from_request(req: &mut RequestParts) -> Result { - let mut bytes = Bytes::from_request(req).await?; + async fn from_request(req: Request, state: &S) -> Result { + let mut bytes = Bytes::from_request(req, state).await?; match T::decode(&mut bytes) { Ok(value) => Ok(ProtoBuf(value)), diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index d64e1ff986..ad348292a0 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -24,7 +24,7 @@ pub use self::resource::Resource; pub use axum_macros::TypedPath; #[cfg(feature = "typed-routing")] -pub use self::typed::{FirstElementIs, TypedPath}; +pub use self::typed::{SecondElementIs, TypedPath}; #[cfg(feature = "spa")] pub use self::spa::SpaRouter; @@ -41,7 +41,7 @@ pub trait RouterExt: sealed::Sealed { fn typed_get(self, handler: H) -> Self where H: Handler, - T: FirstElementIs

+ 'static, + T: SecondElementIs

+ 'static, P: TypedPath; /// Add a typed `DELETE` route to the router. @@ -54,7 +54,7 @@ pub trait RouterExt: sealed::Sealed { fn typed_delete(self, handler: H) -> Self where H: Handler, - T: FirstElementIs

+ 'static, + T: SecondElementIs

+ 'static, P: TypedPath; /// Add a typed `HEAD` route to the router. @@ -67,7 +67,7 @@ pub trait RouterExt: sealed::Sealed { fn typed_head(self, handler: H) -> Self where H: Handler, - T: FirstElementIs

+ 'static, + T: SecondElementIs

+ 'static, P: TypedPath; /// Add a typed `OPTIONS` route to the router. @@ -80,7 +80,7 @@ pub trait RouterExt: sealed::Sealed { fn typed_options(self, handler: H) -> Self where H: Handler, - T: FirstElementIs

+ 'static, + T: SecondElementIs

+ 'static, P: TypedPath; /// Add a typed `PATCH` route to the router. @@ -93,7 +93,7 @@ pub trait RouterExt: sealed::Sealed { fn typed_patch(self, handler: H) -> Self where H: Handler, - T: FirstElementIs

+ 'static, + T: SecondElementIs

+ 'static, P: TypedPath; /// Add a typed `POST` route to the router. @@ -106,7 +106,7 @@ pub trait RouterExt: sealed::Sealed { fn typed_post(self, handler: H) -> Self where H: Handler, - T: FirstElementIs

+ 'static, + T: SecondElementIs

+ 'static, P: TypedPath; /// Add a typed `PUT` route to the router. @@ -119,7 +119,7 @@ pub trait RouterExt: sealed::Sealed { fn typed_put(self, handler: H) -> Self where H: Handler, - T: FirstElementIs

+ 'static, + T: SecondElementIs

+ 'static, P: TypedPath; /// Add a typed `TRACE` route to the router. @@ -132,7 +132,7 @@ pub trait RouterExt: sealed::Sealed { fn typed_trace(self, handler: H) -> Self where H: Handler, - T: FirstElementIs

+ 'static, + T: SecondElementIs

+ 'static, P: TypedPath; /// Add another route to the router with an additional "trailing slash redirect" route. @@ -184,7 +184,7 @@ where fn typed_get(self, handler: H) -> Self where H: Handler, - T: FirstElementIs

+ 'static, + T: SecondElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::get(handler)) @@ -194,7 +194,7 @@ where fn typed_delete(self, handler: H) -> Self where H: Handler, - T: FirstElementIs

+ 'static, + T: SecondElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::delete(handler)) @@ -204,7 +204,7 @@ where fn typed_head(self, handler: H) -> Self where H: Handler, - T: FirstElementIs

+ 'static, + T: SecondElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::head(handler)) @@ -214,7 +214,7 @@ where fn typed_options(self, handler: H) -> Self where H: Handler, - T: FirstElementIs

+ 'static, + T: SecondElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::options(handler)) @@ -224,7 +224,7 @@ where fn typed_patch(self, handler: H) -> Self where H: Handler, - T: FirstElementIs

+ 'static, + T: SecondElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::patch(handler)) @@ -234,7 +234,7 @@ where fn typed_post(self, handler: H) -> Self where H: Handler, - T: FirstElementIs

+ 'static, + T: SecondElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::post(handler)) @@ -244,7 +244,7 @@ where fn typed_put(self, handler: H) -> Self where H: Handler, - T: FirstElementIs

+ 'static, + T: SecondElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::put(handler)) @@ -254,7 +254,7 @@ where fn typed_trace(self, handler: H) -> Self where H: Handler, - T: FirstElementIs

+ 'static, + T: SecondElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::trace(handler)) diff --git a/axum-extra/src/routing/typed.rs b/axum-extra/src/routing/typed.rs index 159472a063..aeb9936c0b 100644 --- a/axum-extra/src/routing/typed.rs +++ b/axum-extra/src/routing/typed.rs @@ -231,10 +231,10 @@ pub trait TypedPath: std::fmt::Display { } } -/// Utility trait used with [`RouterExt`] to ensure the first element of a tuple type is a +/// Utility trait used with [`RouterExt`] to ensure the second element of a tuple type is a /// given type. /// -/// If you see it in type errors its most likely because the first argument to your handler doesn't +/// If you see it in type errors its most likely because the second argument to your handler doesn't /// implement [`TypedPath`]. /// /// You normally shouldn't have to use this trait directly. @@ -242,56 +242,56 @@ pub trait TypedPath: std::fmt::Display { /// It is sealed such that it cannot be implemented outside this crate. /// /// [`RouterExt`]: super::RouterExt -pub trait FirstElementIs

: Sealed {} +pub trait SecondElementIs

: Sealed {} -macro_rules! impl_first_element_is { +macro_rules! impl_second_element_is { ( $($ty:ident),* $(,)? ) => { - impl FirstElementIs

for (P, $($ty,)*) + impl SecondElementIs

for (M, P, $($ty,)*) where P: TypedPath {} - impl Sealed for (P, $($ty,)*) + impl Sealed for (M, P, $($ty,)*) where P: TypedPath {} - impl FirstElementIs

for (Option

, $($ty,)*) + impl SecondElementIs

for (M, Option

, $($ty,)*) where P: TypedPath {} - impl Sealed for (Option

, $($ty,)*) + impl Sealed for (M, Option

, $($ty,)*) where P: TypedPath {} - impl FirstElementIs

for (Result, $($ty,)*) + impl SecondElementIs

for (M, Result, $($ty,)*) where P: TypedPath {} - impl Sealed for (Result, $($ty,)*) + impl Sealed for (M, Result, $($ty,)*) where P: TypedPath {} }; } -impl_first_element_is!(); -impl_first_element_is!(T1); -impl_first_element_is!(T1, T2); -impl_first_element_is!(T1, T2, T3); -impl_first_element_is!(T1, T2, T3, T4); -impl_first_element_is!(T1, T2, T3, T4, T5); -impl_first_element_is!(T1, T2, T3, T4, T5, T6); -impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7); -impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8); -impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9); -impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10); -impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11); -impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12); -impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13); -impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14); -impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15); -impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); +impl_second_element_is!(); +impl_second_element_is!(T1); +impl_second_element_is!(T1, T2); +impl_second_element_is!(T1, T2, T3); +impl_second_element_is!(T1, T2, T3, T4); +impl_second_element_is!(T1, T2, T3, T4, T5); +impl_second_element_is!(T1, T2, T3, T4, T5, T6); +impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7); +impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8); +impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9); +impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10); +impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11); +impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12); +impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13); +impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14); +impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15); +impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); diff --git a/axum-macros/CHANGELOG.md b/axum-macros/CHANGELOG.md index 1f71e638af..619c5bd7c6 100644 --- a/axum-macros/CHANGELOG.md +++ b/axum-macros/CHANGELOG.md @@ -10,9 +10,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **change:** axum-macro's MSRV is now 1.60 ([#1239]) - **added:** Support using a different rejection for `#[derive(FromRequest)]` with `#[from_request(rejection(MyRejection))]` ([#1256]) +- **breaking:** `#[derive(FromRequest)]` will no longer generate a rejection + enum but instead generate `type Rejection = axum::response::Response`. Use the + new `#[from_request(rejection(MyRejection))]` attribute to change this. + The `rejection_derive` attribute has also been removed ([#1272]) [#1239]: https://github.com/tokio-rs/axum/pull/1239 [#1256]: https://github.com/tokio-rs/axum/pull/1256 +[#1272]: https://github.com/tokio-rs/axum/pull/1272 # 0.2.3 (27. June, 2022) diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index 0fe8ad6af4..7c261b500e 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -1,14 +1,11 @@ -use std::collections::HashSet; - use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; +use std::collections::HashSet; use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, Token, Type}; pub(crate) fn expand(mut attr: Attrs, item_fn: ItemFn) -> TokenStream { let check_extractor_count = check_extractor_count(&item_fn); - let check_request_last_extractor = check_request_last_extractor(&item_fn); let check_path_extractor = check_path_extractor(&item_fn); - let check_multiple_body_extractors = check_multiple_body_extractors(&item_fn); let check_output_impls_into_response = check_output_impls_into_response(&item_fn); // If the function is generic, we can't reliably check its inputs or whether the future it @@ -39,9 +36,7 @@ pub(crate) fn expand(mut attr: Attrs, item_fn: ItemFn) -> TokenStream { quote! { #item_fn #check_extractor_count - #check_request_last_extractor #check_path_extractor - #check_multiple_body_extractors #check_output_impls_into_response #check_inputs_and_future_send } @@ -135,22 +130,6 @@ fn extractor_idents(item_fn: &ItemFn) -> impl Iterator Option { - let request_extractor_ident = - extractor_idents(item_fn).find(|(_, _, ident)| *ident == "Request"); - - if let Some((idx, fn_arg, _)) = request_extractor_ident { - if idx != item_fn.sig.inputs.len() - 1 { - return Some( - syn::Error::new_spanned(fn_arg, "`Request` extractor should always be last") - .to_compile_error(), - ); - } - } - - None -} - fn check_path_extractor(item_fn: &ItemFn) -> TokenStream { let path_extractors = extractor_idents(item_fn) .filter(|(_, _, ident)| *ident == "Path") @@ -174,30 +153,14 @@ fn check_path_extractor(item_fn: &ItemFn) -> TokenStream { } } -fn check_multiple_body_extractors(item_fn: &ItemFn) -> TokenStream { - let body_extractors = extractor_idents(item_fn) - .filter(|(_, _, ident)| { - *ident == "String" - || *ident == "Bytes" - || *ident == "Json" - || *ident == "RawBody" - || *ident == "BodyStream" - || *ident == "Multipart" - || *ident == "Request" - }) - .collect::>(); - - if body_extractors.len() > 1 { - body_extractors - .into_iter() - .map(|(_, arg, _)| { - syn::Error::new_spanned(arg, "Only one body extractor can be applied") - .to_compile_error() - }) - .collect() +fn is_self_pat_type(typed: &syn::PatType) -> bool { + let ident = if let syn::Pat::Ident(ident) = &*typed.pat { + &ident.ident } else { - quote! {} - } + return false; + }; + + ident == "self" } fn check_inputs_impls_from_request( @@ -205,6 +168,11 @@ fn check_inputs_impls_from_request( body_ty: &Type, state_ty: Type, ) -> TokenStream { + let takes_self = item_fn.sig.inputs.first().map_or(false, |arg| match arg { + FnArg::Receiver(_) => true, + FnArg::Typed(typed) => is_self_pat_type(typed), + }); + item_fn .sig .inputs @@ -227,21 +195,53 @@ fn check_inputs_impls_from_request( FnArg::Typed(typed) => { let ty = &typed.ty; let span = ty.span(); - (span, ty.clone()) + + if is_self_pat_type(typed) { + (span, syn::parse_quote!(Self)) + } else { + (span, ty.clone()) + } } }; - let name = format_ident!( - "__axum_macros_check_{}_{}_from_request", + let check_fn = format_ident!( + "__axum_macros_check_{}_{}_from_request_check", item_fn.sig.ident, - idx + idx, + span = span, ); + + let call_check_fn = format_ident!( + "__axum_macros_check_{}_{}_from_request_call_check", + item_fn.sig.ident, + idx, + span = span, + ); + + let call_check_fn_body = if takes_self { + quote_spanned! {span=> + Self::#check_fn(); + } + } else { + quote_spanned! {span=> + #check_fn(); + } + }; + quote_spanned! {span=> #[allow(warnings)] - fn #name() + fn #check_fn() where - #ty: ::axum::extract::FromRequest<#state_ty, #body_ty> + Send, + #ty: ::axum::extract::FromRequest<#state_ty, #body_ty, M> + Send, {} + + // we have to call the function to actually trigger a compile error + // since the function is generic, just defining it is not enough + #[allow(warnings)] + fn #call_check_fn() + { + #call_check_fn_body + } } }) .collect::() @@ -380,11 +380,11 @@ fn check_future_send(item_fn: &ItemFn) -> TokenStream { } fn self_receiver(item_fn: &ItemFn) -> Option { - let takes_self = item_fn - .sig - .inputs - .iter() - .any(|arg| matches!(arg, syn::FnArg::Receiver(_))); + let takes_self = item_fn.sig.inputs.iter().any(|arg| match arg { + FnArg::Receiver(_) => true, + FnArg::Typed(typed) => is_self_pat_type(typed), + }); + if takes_self { return Some(quote! { Self:: }); } diff --git a/axum-macros/src/from_request.rs b/axum-macros/src/from_request.rs index e1e3d457c7..596d23d362 100644 --- a/axum-macros/src/from_request.rs +++ b/axum-macros/src/from_request.rs @@ -1,10 +1,8 @@ use self::attr::{ parse_container_attrs, parse_field_attrs, FromRequestContainerAttr, FromRequestFieldAttr, - RejectionDeriveOptOuts, }; -use heck::ToUpperCamelCase; use proc_macro2::{Span, TokenStream}; -use quote::{format_ident, quote, quote_spanned}; +use quote::{quote, quote_spanned}; use syn::{punctuated::Punctuated, spanned::Spanned, Ident, Token}; mod attr; @@ -18,7 +16,7 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result { generics, fields, semi_token: _, - vis, + vis: _, struct_token: _, } = item; @@ -34,32 +32,15 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result { generic_ident, ) } - FromRequestContainerAttr::RejectionDerive(_, opt_outs) => { - error_on_generic_ident(generic_ident)?; - - impl_struct_by_extracting_each_field(ident, fields, vis, opt_outs, None) - } FromRequestContainerAttr::Rejection(rejection) => { error_on_generic_ident(generic_ident)?; - impl_struct_by_extracting_each_field( - ident, - fields, - vis, - RejectionDeriveOptOuts::default(), - Some(rejection), - ) + impl_struct_by_extracting_each_field(ident, fields, Some(rejection)) } FromRequestContainerAttr::None => { error_on_generic_ident(generic_ident)?; - impl_struct_by_extracting_each_field( - ident, - fields, - vis, - RejectionDeriveOptOuts::default(), - None, - ) + impl_struct_by_extracting_each_field(ident, fields, None) } } } @@ -88,12 +69,6 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result { FromRequestContainerAttr::Via { path, rejection } => { impl_enum_by_extracting_all_at_once(ident, variants, path, rejection) } - FromRequestContainerAttr::RejectionDerive(rejection_derive, _) => { - Err(syn::Error::new_spanned( - rejection_derive, - "cannot use `rejection_derive` on enums", - )) - } FromRequestContainerAttr::Rejection(rejection) => Err(syn::Error::new_spanned( rejection, "cannot use `rejection` without `via`", @@ -197,22 +172,16 @@ fn error_on_generic_ident(generic_ident: Option) -> syn::Result<()> { fn impl_struct_by_extracting_each_field( ident: syn::Ident, fields: syn::Fields, - vis: syn::Visibility, - rejection_derive_opt_outs: RejectionDeriveOptOuts, rejection: Option, ) -> syn::Result { let extract_fields = extract_fields(&fields, &rejection)?; - let (rejection_ident, rejection) = if let Some(rejection) = rejection { - let rejection_ident = syn::parse_quote!(#rejection); - (rejection_ident, None) + let rejection_ident = if let Some(rejection) = rejection { + quote!(#rejection) } else if has_no_fields(&fields) { - (syn::parse_quote!(::std::convert::Infallible), None) + quote!(::std::convert::Infallible) } else { - let rejection_ident = rejection_ident(&ident); - let rejection = - extract_each_field_rejection(&ident, &fields, &vis, rejection_derive_opt_outs)?; - (rejection_ident, Some(rejection)) + quote!(::axum::response::Response) }; Ok(quote! { @@ -228,15 +197,14 @@ fn impl_struct_by_extracting_each_field( type Rejection = #rejection_ident; async fn from_request( - req: &mut ::axum::extract::RequestParts, + mut req: axum::http::Request, + state: &S, ) -> ::std::result::Result { ::std::result::Result::Ok(Self { #(#extract_fields)* }) } } - - #rejection }) } @@ -248,11 +216,6 @@ fn has_no_fields(fields: &syn::Fields) -> bool { } } -fn rejection_ident(ident: &syn::Ident) -> syn::Type { - let ident = format_ident!("{}Rejection", ident); - syn::parse_quote!(#ident) -} - fn extract_fields( fields: &syn::Fields, rejection: &Option, @@ -261,6 +224,8 @@ fn extract_fields( .iter() .enumerate() .map(|(index, field)| { + let is_last = fields.len() - 1 == index; + let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?; let member = if let Some(ident) = &field.ident { @@ -286,40 +251,79 @@ fn extract_fields( } }; - let rejection_variant_name = rejection_variant_name(field)?; - if peel_option(&field.ty).is_some() { - Ok(quote_spanned! {ty_span=> - #member: { - ::axum::extract::FromRequest::from_request(req) - .await - .ok() - .map(#into_inner) - }, - }) + if is_last { + Ok(quote_spanned! {ty_span=> + #member: { + ::axum::extract::FromRequest::from_request(req, state) + .await + .ok() + .map(#into_inner) + }, + }) + } else { + Ok(quote_spanned! {ty_span=> + #member: { + let (mut parts, body) = req.into_parts(); + let value = ::axum::extract::FromRequestParts::from_request_parts(&mut parts, state) + .await + .ok() + .map(#into_inner); + req = ::axum::http::Request::from_parts(parts, body); + value + }, + }) + } } else if peel_result_ok(&field.ty).is_some() { - Ok(quote_spanned! {ty_span=> - #member: { - ::axum::extract::FromRequest::from_request(req) - .await - .map(#into_inner) - }, - }) + if is_last { + Ok(quote_spanned! {ty_span=> + #member: { + ::axum::extract::FromRequest::from_request(req, state) + .await + .map(#into_inner) + }, + }) + } else { + Ok(quote_spanned! {ty_span=> + #member: { + let (mut parts, body) = req.into_parts(); + let value = ::axum::extract::FromRequestParts::from_request_parts(&mut parts, state) + .await + .map(#into_inner); + req = ::axum::http::Request::from_parts(parts, body); + value + }, + }) + } } else { let map_err = if let Some(rejection) = rejection { quote! { <#rejection as ::std::convert::From<_>>::from } } else { - quote! { Self::Rejection::#rejection_variant_name } + quote! { ::axum::response::IntoResponse::into_response } }; - Ok(quote_spanned! {ty_span=> - #member: { - ::axum::extract::FromRequest::from_request(req) - .await - .map(#into_inner) - .map_err(#map_err)? - }, - }) + if is_last { + Ok(quote_spanned! {ty_span=> + #member: { + ::axum::extract::FromRequest::from_request(req, state) + .await + .map(#into_inner) + .map_err(#map_err)? + }, + }) + } else { + Ok(quote_spanned! {ty_span=> + #member: { + let (mut parts, body) = req.into_parts(); + let value = ::axum::extract::FromRequestParts::from_request_parts(&mut parts, state) + .await + .map(#into_inner) + .map_err(#map_err)?; + req = ::axum::http::Request::from_parts(parts, body); + value + }, + }) + } } }) .collect() @@ -387,199 +391,6 @@ fn peel_result_ok(ty: &syn::Type) -> Option<&syn::Type> { } } -fn extract_each_field_rejection( - ident: &syn::Ident, - fields: &syn::Fields, - vis: &syn::Visibility, - rejection_derive_opt_outs: RejectionDeriveOptOuts, -) -> syn::Result { - let rejection_ident = rejection_ident(ident); - - let variants = fields - .iter() - .map(|field| { - let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?; - - let field_ty = &field.ty; - let ty_span = field_ty.span(); - - let variant_name = rejection_variant_name(field)?; - - let extractor_ty = if let Some((_, path)) = via { - if let Some(inner) = peel_option(field_ty) { - quote_spanned! {ty_span=> - ::std::option::Option<#path<#inner>> - } - } else if let Some(inner) = peel_result_ok(field_ty) { - quote_spanned! {ty_span=> - ::std::result::Result<#path<#inner>, TypedHeaderRejection> - } - } else { - quote_spanned! {ty_span=> #path<#field_ty> } - } - } else { - quote_spanned! {ty_span=> #field_ty } - }; - - Ok(quote_spanned! {ty_span=> - #[allow(non_camel_case_types)] - #variant_name(<#extractor_ty as ::axum::extract::FromRequest<(), ::axum::body::Body>>::Rejection), - }) - }) - .collect::>>()?; - - let impl_into_response = { - let arms = fields - .iter() - .map(|field| { - let variant_name = rejection_variant_name(field)?; - Ok(quote! { - Self::#variant_name(inner) => inner.into_response(), - }) - }) - .collect::>>()?; - - quote! { - #[automatically_derived] - impl ::axum::response::IntoResponse for #rejection_ident { - fn into_response(self) -> ::axum::response::Response { - match self { - #(#arms)* - } - } - } - } - }; - - let impl_display = if rejection_derive_opt_outs.derive_display() { - let arms = fields - .iter() - .map(|field| { - let variant_name = rejection_variant_name(field)?; - Ok(quote! { - Self::#variant_name(inner) => inner.fmt(f), - }) - }) - .collect::>>()?; - - Some(quote! { - #[automatically_derived] - impl ::std::fmt::Display for #rejection_ident { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - match self { - #(#arms)* - } - } - } - }) - } else { - None - }; - - let impl_error = if rejection_derive_opt_outs.derive_error() { - let arms = fields - .iter() - .map(|field| { - let variant_name = rejection_variant_name(field)?; - Ok(quote! { - Self::#variant_name(inner) => Some(inner), - }) - }) - .collect::>>()?; - - Some(quote! { - #[automatically_derived] - impl ::std::error::Error for #rejection_ident { - fn source(&self) -> ::std::option::Option<&(dyn ::std::error::Error + 'static)> { - match self { - #(#arms)* - } - } - } - }) - } else { - None - }; - - let impl_debug = rejection_derive_opt_outs.derive_debug().then(|| { - quote! { #[derive(Debug)] } - }); - - Ok(quote! { - #impl_debug - #vis enum #rejection_ident { - #(#variants)* - } - - #impl_into_response - #impl_display - #impl_error - }) -} - -fn rejection_variant_name(field: &syn::Field) -> syn::Result { - fn rejection_variant_name_for_type(out: &mut String, ty: &syn::Type) -> syn::Result<()> { - if let syn::Type::Path(type_path) = ty { - let segment = type_path - .path - .segments - .last() - .ok_or_else(|| syn::Error::new_spanned(ty, "Empty type path"))?; - - out.push_str(&segment.ident.to_string()); - - match &segment.arguments { - syn::PathArguments::AngleBracketed(args) => { - let ty = if args.args.len() == 1 { - args.args.last().unwrap() - } else if args.args.len() == 2 { - if segment.ident == "Result" { - args.args.first().unwrap() - } else { - return Err(syn::Error::new_spanned( - segment, - "Only `Result` is supported with two generics type paramters", - )); - } - } else { - return Err(syn::Error::new_spanned( - &args.args, - "Expected exactly one or two type paramters", - )); - }; - - if let syn::GenericArgument::Type(ty) = ty { - rejection_variant_name_for_type(out, ty) - } else { - Err(syn::Error::new_spanned(ty, "Expected type path")) - } - } - syn::PathArguments::Parenthesized(args) => { - Err(syn::Error::new_spanned(args, "Unsupported")) - } - syn::PathArguments::None => Ok(()), - } - } else { - Err(syn::Error::new_spanned(ty, "Expected type path")) - } - } - - if let Some(ident) = &field.ident { - Ok(format_ident!("{}", ident.to_string().to_upper_camel_case())) - } else { - let mut out = String::new(); - rejection_variant_name_for_type(&mut out, &field.ty)?; - - let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?; - if let Some((_, path)) = via { - let via_ident = &path.segments.last().unwrap().ident; - Ok(format_ident!("{}{}", via_ident, out)) - } else { - Ok(format_ident!("{}", out)) - } - } -} - fn impl_struct_by_extracting_all_at_once( ident: syn::Ident, fields: syn::Fields, @@ -606,12 +417,16 @@ fn impl_struct_by_extracting_all_at_once( let path_span = path.span(); - let associated_rejection_type = if let Some(rejection) = &rejection { - quote! { #rejection } + let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection { + let rejection = quote! { #rejection }; + let map_err = quote! { ::std::convert::From::from }; + (rejection, map_err) } else { - quote! { - <#path as ::axum::extract::FromRequest>::Rejection - } + let rejection = quote! { + ::axum::response::Response + }; + let map_err = quote! { ::axum::response::IntoResponse::into_response }; + (rejection, map_err) }; let rejection_bound = rejection.as_ref().map(|rejection| { @@ -658,18 +473,19 @@ fn impl_struct_by_extracting_all_at_once( where #path<#via_type_generics>: ::axum::extract::FromRequest, #rejection_bound - B: ::std::marker::Send, + B: ::std::marker::Send + 'static, S: ::std::marker::Send + ::std::marker::Sync, { type Rejection = #associated_rejection_type; async fn from_request( - req: &mut ::axum::extract::RequestParts, + req: ::axum::http::Request, + state: &S ) -> ::std::result::Result { - ::axum::extract::FromRequest::::from_request(req) + ::axum::extract::FromRequest::from_request(req, state) .await .map(|#path(value)| #value_to_self) - .map_err(::std::convert::From::from) + .map_err(#map_err) } } }) @@ -707,12 +523,16 @@ fn impl_enum_by_extracting_all_at_once( } } - let associated_rejection_type = if let Some(rejection) = rejection { - quote! { #rejection } + let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection { + let rejection = quote! { #rejection }; + let map_err = quote! { ::std::convert::From::from }; + (rejection, map_err) } else { - quote! { - <#path as ::axum::extract::FromRequest>::Rejection - } + let rejection = quote! { + ::axum::response::Response + }; + let map_err = quote! { ::axum::response::IntoResponse::into_response }; + (rejection, map_err) }; let path_span = path.span(); @@ -730,12 +550,13 @@ fn impl_enum_by_extracting_all_at_once( type Rejection = #associated_rejection_type; async fn from_request( - req: &mut ::axum::extract::RequestParts, + req: ::axum::http::Request, + state: &S ) -> ::std::result::Result { - ::axum::extract::FromRequest::::from_request(req) + ::axum::extract::FromRequest::from_request(req, state) .await .map(|#path(inner)| inner) - .map_err(::std::convert::From::from) + .map_err(#map_err) } } }) diff --git a/axum-macros/src/from_request/attr.rs b/axum-macros/src/from_request/attr.rs index 46b2e518ee..9fc1080251 100644 --- a/axum-macros/src/from_request/attr.rs +++ b/axum-macros/src/from_request/attr.rs @@ -16,13 +16,11 @@ pub(crate) enum FromRequestContainerAttr { rejection: Option, }, Rejection(syn::Path), - RejectionDerive(kw::rejection_derive, RejectionDeriveOptOuts), None, } pub(crate) mod kw { syn::custom_keyword!(via); - syn::custom_keyword!(rejection_derive); syn::custom_keyword!(rejection); syn::custom_keyword!(Display); syn::custom_keyword!(Debug); @@ -55,7 +53,6 @@ pub(crate) fn parse_container_attrs( let attrs = parse_attrs::(attrs)?; let mut out_via = None; - let mut out_rejection_derive = None; let mut out_rejection = None; // we track the index of the attribute to know which comes last @@ -69,16 +66,6 @@ pub(crate) fn parse_container_attrs( out_via = Some((idx, via, path)); } } - ContainerAttr::RejectionDerive { - rejection_derive, - opt_outs, - } => { - if out_rejection_derive.is_some() { - return Err(double_attr_error("rejection_derive", rejection_derive)); - } else { - out_rejection_derive = Some((idx, rejection_derive, opt_outs)); - } - } ContainerAttr::Rejection { rejection, path } => { if out_rejection.is_some() { return Err(double_attr_error("rejection", rejection)); @@ -89,55 +76,20 @@ pub(crate) fn parse_container_attrs( } } - match (out_via, out_rejection_derive, out_rejection) { - (Some((via_idx, via, _)), Some((rejection_derive_idx, rejection_derive, _)), _) => { - if via_idx > rejection_derive_idx { - Err(syn::Error::new_spanned( - via, - "cannot use both `rejection_derive` and `via`", - )) - } else { - Err(syn::Error::new_spanned( - rejection_derive, - "cannot use both `via` and `rejection_derive`", - )) - } - } - - ( - _, - Some((rejection_derive_idx, rejection_derive, _)), - Some((rejection_idx, rejection, _)), - ) => { - if rejection_idx > rejection_derive_idx { - Err(syn::Error::new_spanned( - rejection, - "cannot use both `rejection_derive` and `rejection`", - )) - } else { - Err(syn::Error::new_spanned( - rejection_derive, - "cannot use both `rejection` and `rejection_derive`", - )) - } - } - - (Some((_, _, path)), None, None) => Ok(FromRequestContainerAttr::Via { + match (out_via, out_rejection) { + (Some((_, _, path)), None) => Ok(FromRequestContainerAttr::Via { path, rejection: None, }), - (Some((_, _, path)), None, Some((_, _, rejection))) => Ok(FromRequestContainerAttr::Via { + + (Some((_, _, path)), Some((_, _, rejection))) => Ok(FromRequestContainerAttr::Via { path, rejection: Some(rejection), }), - (None, Some((_, rejection_derive, opt_outs)), _) => Ok( - FromRequestContainerAttr::RejectionDerive(rejection_derive, opt_outs), - ), - - (None, None, Some((_, _, rejection))) => Ok(FromRequestContainerAttr::Rejection(rejection)), + (None, Some((_, _, rejection))) => Ok(FromRequestContainerAttr::Rejection(rejection)), - (None, None, None) => Ok(FromRequestContainerAttr::None), + (None, None) => Ok(FromRequestContainerAttr::None), } } @@ -172,10 +124,6 @@ enum ContainerAttr { rejection: kw::rejection, path: syn::Path, }, - RejectionDerive { - rejection_derive: kw::rejection_derive, - opt_outs: RejectionDeriveOptOuts, - }, } impl Parse for ContainerAttr { @@ -186,14 +134,6 @@ impl Parse for ContainerAttr { let content; syn::parenthesized!(content in input); content.parse().map(|path| Self::Via { via, path }) - } else if lh.peek(kw::rejection_derive) { - let rejection_derive = input.parse::()?; - let content; - syn::parenthesized!(content in input); - content.parse().map(|opt_outs| Self::RejectionDerive { - rejection_derive, - opt_outs, - }) } else if lh.peek(kw::rejection) { let rejection = input.parse::()?; let content; @@ -224,82 +164,3 @@ impl Parse for FieldAttr { } } } - -#[derive(Default)] -pub(crate) struct RejectionDeriveOptOuts { - debug: Option, - display: Option, - error: Option, -} - -impl RejectionDeriveOptOuts { - pub(crate) fn derive_debug(&self) -> bool { - self.debug.is_none() - } - - pub(crate) fn derive_display(&self) -> bool { - self.display.is_none() - } - - pub(crate) fn derive_error(&self) -> bool { - self.error.is_none() - } -} - -impl Parse for RejectionDeriveOptOuts { - fn parse(input: ParseStream) -> syn::Result { - fn parse_opt_out(out: &mut Option, ident: &str, input: ParseStream) -> syn::Result<()> - where - T: Parse, - { - if out.is_some() { - Err(input.error(format!("`{}` opt out specified more than once", ident))) - } else { - *out = Some(input.parse()?); - Ok(()) - } - } - - let mut debug = None::; - let mut display = None::; - let mut error = None::; - - while !input.is_empty() { - input.parse::()?; - - let lh = input.lookahead1(); - if lh.peek(kw::Debug) { - parse_opt_out(&mut debug, "Debug", input)?; - } else if lh.peek(kw::Display) { - parse_opt_out(&mut display, "Display", input)?; - } else if lh.peek(kw::Error) { - parse_opt_out(&mut error, "Error", input)?; - } else { - return Err(lh.error()); - } - - input.parse::().ok(); - } - - if error.is_none() { - match (debug, display) { - (Some(debug), Some(_)) => { - return Err(syn::Error::new_spanned(debug, "opt out of `Debug` and `Display` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Debug, !Display, !Error))]`")); - } - (Some(debug), None) => { - return Err(syn::Error::new_spanned(debug, "opt out of `Debug` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Debug, !Error))]`")); - } - (None, Some(display)) => { - return Err(syn::Error::new_spanned(display, "opt out of `Display` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Display, !Error))]`")); - } - (None, None) => {} - } - } - - Ok(Self { - debug, - display, - error, - }) - } -} diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index 0242ddca66..47442724bb 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -86,6 +86,20 @@ mod typed_path; /// /// This requires that each field is an extractor (i.e. implements [`FromRequest`]). /// +/// ```compile_fail +/// use axum_macros::FromRequest; +/// use axum::body::Bytes; +/// +/// #[derive(FromRequest)] +/// struct MyExtractor { +/// // only the last field can implement `FromRequest` +/// // other fields must only implement `FromRequestParts` +/// bytes: Bytes, +/// string: String, +/// } +/// ``` +/// Note that only the last field can consume the request body. Therefore this doesn't compile: +/// /// ## Extracting via another extractor /// /// You can use `#[from_request(via(...))]` to extract a field via another extractor, meaning the @@ -157,95 +171,15 @@ mod typed_path; /// /// ## The rejection /// -/// A rejection enum is also generated. It has a variant for each field: -/// -/// ``` -/// use axum_macros::FromRequest; -/// use axum::{ -/// extract::{Extension, TypedHeader}, -/// headers::ContentType, -/// body::Bytes, -/// }; -/// -/// #[derive(FromRequest)] -/// struct MyExtractor { -/// #[from_request(via(Extension))] -/// state: State, -/// #[from_request(via(TypedHeader))] -/// content_type: ContentType, -/// request_body: Bytes, -/// } -/// -/// // also generates -/// // -/// // #[derive(Debug)] -/// // enum MyExtractorRejection { -/// // State(ExtensionRejection), -/// // ContentType(TypedHeaderRejection), -/// // RequestBody(BytesRejection), -/// // } -/// // -/// // impl axum::response::IntoResponse for MyExtractor { ... } -/// // -/// // impl std::fmt::Display for MyExtractor { ... } -/// // -/// // impl std::error::Error for MyExtractor { ... } -/// -/// #[derive(Clone)] -/// struct State { -/// // ... -/// } -/// ``` -/// -/// The rejection's `std::error::Error::source` implementation returns the inner rejection. This -/// can be used to access source errors for example to customize rejection responses. Note this -/// means the inner rejection types must themselves implement `std::error::Error`. All extractors -/// in axum does this. -/// -/// You can opt out of this using `#[from_request(rejection_derive(...))]`: -/// -/// ``` -/// use axum_macros::FromRequest; -/// use axum::{ -/// extract::{FromRequest, RequestParts}, -/// http::StatusCode, -/// headers::ContentType, -/// body::Bytes, -/// async_trait, -/// }; -/// -/// #[derive(FromRequest)] -/// #[from_request(rejection_derive(!Display, !Error))] -/// struct MyExtractor { -/// other: OtherExtractor, -/// } -/// -/// struct OtherExtractor; -/// -/// #[async_trait] -/// impl FromRequest for OtherExtractor -/// where -/// B: Send, -/// S: Send + Sync, -/// { -/// // this rejection doesn't implement `Display` and `Error` -/// type Rejection = (StatusCode, String); -/// -/// async fn from_request(_req: &mut RequestParts) -> Result { -/// // ... -/// # unimplemented!() -/// } -/// } -/// ``` -/// -/// You can also use your own rejection type with `#[from_request(rejection(YourType))]`: +/// By default [`axum::response::Response`] will be used as the rejection. You can also use your own +/// rejection type with `#[from_request(rejection(YourType))]`: /// /// ``` /// use axum_macros::FromRequest; /// use axum::{ /// extract::{ /// rejection::{ExtensionRejection, StringRejection}, -/// FromRequest, RequestParts, +/// FromRequest, /// }, /// Extension, /// response::{Response, IntoResponse}, @@ -414,6 +348,7 @@ mod typed_path; /// ``` /// /// [`FromRequest`]: https://docs.rs/axum/latest/axum/extract/trait.FromRequest.html +/// [`axum::response::Response`]: https://docs.rs/axum/0.6/axum/response/type.Response.html /// [`axum::extract::rejection::ExtensionRejection`]: https://docs.rs/axum/latest/axum/extract/rejection/enum.ExtensionRejection.html #[proc_macro_derive(FromRequest, attributes(from_request))] pub fn derive_from_request(item: TokenStream) -> TokenStream { diff --git a/axum-macros/src/typed_path.rs b/axum-macros/src/typed_path.rs index 9df4702881..efbf733cc0 100644 --- a/axum-macros/src/typed_path.rs +++ b/axum-macros/src/typed_path.rs @@ -127,15 +127,17 @@ fn expand_named_fields( let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequestParts for #ident where - B: Send, S: Send + Sync, { type Rejection = #rejection_assoc_type; - async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { - ::axum::extract::Path::from_request(req) + async fn from_request_parts( + parts: &mut ::axum::http::request::Parts, + state: &S, + ) -> ::std::result::Result { + ::axum::extract::Path::from_request_parts(parts, state) .await .map(|path| path.0) #map_err_rejection @@ -230,15 +232,17 @@ fn expand_unnamed_fields( let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequestParts for #ident where - B: Send, S: Send + Sync, { type Rejection = #rejection_assoc_type; - async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { - ::axum::extract::Path::from_request(req) + async fn from_request_parts( + parts: &mut ::axum::http::request::Parts, + state: &S, + ) -> ::std::result::Result { + ::axum::extract::Path::from_request_parts(parts, state) .await .map(|path| path.0) #map_err_rejection @@ -312,15 +316,17 @@ fn expand_unit_fields( let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequestParts for #ident where - B: Send, S: Send + Sync, { type Rejection = #rejection_assoc_type; - async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { - if req.uri().path() == ::PATH { + async fn from_request_parts( + parts: &mut ::axum::http::request::Parts, + _state: &S, + ) -> ::std::result::Result { + if parts.uri.path() == ::PATH { Ok(Self) } else { #create_rejection @@ -390,7 +396,7 @@ enum Segment { fn path_rejection() -> TokenStream { quote! { - <::axum::extract::Path as ::axum::extract::FromRequest>::Rejection + <::axum::extract::Path as ::axum::extract::FromRequestParts>::Rejection } } diff --git a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr index 265258419e..8d46455ca2 100644 --- a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr +++ b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr @@ -1,17 +1,22 @@ -error[E0277]: the trait bound `bool: FromRequest<(), Body>` is not satisfied +error[E0277]: the trait bound `bool: FromRequestParts<()>` is not satisfied --> tests/debug_handler/fail/argument_not_extractor.rs:4:23 | 4 | async fn handler(foo: bool) {} - | ^^^^ the trait `FromRequest<(), Body>` is not implemented for `bool` + | ^^^^ the trait `FromRequestParts<()>` is not implemented for `bool` | - = help: the following other types implement trait `FromRequest`: - <() as FromRequest> - <(T1, T2) as FromRequest> - <(T1, T2, T3) as FromRequest> - <(T1, T2, T3, T4) as FromRequest> - <(T1, T2, T3, T4, T5) as FromRequest> - <(T1, T2, T3, T4, T5, T6) as FromRequest> - <(T1, T2, T3, T4, T5, T6, T7) as FromRequest> - <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequest> - and 34 others - = help: see issue #48214 + = help: the following other types implement trait `FromRequestParts`: + <() as FromRequestParts> + <(T1, T2) as FromRequestParts> + <(T1, T2, T3) as FromRequestParts> + <(T1, T2, T3, T4) as FromRequestParts> + <(T1, T2, T3, T4, T5) as FromRequestParts> + <(T1, T2, T3, T4, T5, T6) as FromRequestParts> + <(T1, T2, T3, T4, T5, T6, T7) as FromRequestParts> + <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequestParts> + and 26 others + = note: required because of the requirements on the impl of `FromRequest<(), Body, axum_core::extract::private::ViaParts>` for `bool` +note: required by a bound in `__axum_macros_check_handler_0_from_request_check` + --> tests/debug_handler/fail/argument_not_extractor.rs:4:23 + | +4 | async fn handler(foo: bool) {} + | ^^^^ required by this bound in `__axum_macros_check_handler_0_from_request_check` diff --git a/axum-macros/tests/debug_handler/fail/extract_self_mut.rs b/axum-macros/tests/debug_handler/fail/extract_self_mut.rs index 168a1c8177..d20426e22f 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_mut.rs +++ b/axum-macros/tests/debug_handler/fail/extract_self_mut.rs @@ -1,6 +1,7 @@ use axum::{ async_trait, - extract::{FromRequest, RequestParts}, + extract::FromRequest, + http::Request, }; use axum_macros::debug_handler; @@ -9,12 +10,12 @@ struct A; #[async_trait] impl FromRequest for A where - B: Send, + B: Send + 'static, S: Send + Sync, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: Request, _state: &S) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr b/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr index 3d80dffbca..1e1a9ec384 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr +++ b/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr @@ -1,5 +1,5 @@ error: Handlers must only take owned values - --> tests/debug_handler/fail/extract_self_mut.rs:24:22 + --> tests/debug_handler/fail/extract_self_mut.rs:25:22 | -24 | async fn handler(&mut self) {} +25 | async fn handler(&mut self) {} | ^^^^^^^^^ diff --git a/axum-macros/tests/debug_handler/fail/extract_self_ref.rs b/axum-macros/tests/debug_handler/fail/extract_self_ref.rs index 4090265cd6..77940e2996 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_ref.rs +++ b/axum-macros/tests/debug_handler/fail/extract_self_ref.rs @@ -1,6 +1,7 @@ use axum::{ async_trait, - extract::{FromRequest, RequestParts}, + extract::FromRequest, + http::Request, }; use axum_macros::debug_handler; @@ -9,12 +10,12 @@ struct A; #[async_trait] impl FromRequest for A where - B: Send, + B: Send + 'static, S: Send + Sync, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: Request, _state: &S) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr b/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr index 82d9a89ff5..79f9d190f5 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr +++ b/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr @@ -1,5 +1,5 @@ error: Handlers must only take owned values - --> tests/debug_handler/fail/extract_self_ref.rs:24:22 + --> tests/debug_handler/fail/extract_self_ref.rs:25:22 | -24 | async fn handler(&self) {} +25 | async fn handler(&self) {} | ^^^^^ diff --git a/axum-macros/tests/debug_handler/fail/multiple_body_extractors.rs b/axum-macros/tests/debug_handler/fail/multiple_body_extractors.rs deleted file mode 100644 index 875c75407e..0000000000 --- a/axum-macros/tests/debug_handler/fail/multiple_body_extractors.rs +++ /dev/null @@ -1,7 +0,0 @@ -use axum_macros::debug_handler; -use axum::body::Bytes; - -#[debug_handler] -async fn handler(_: String, _: Bytes) {} - -fn main() {} diff --git a/axum-macros/tests/debug_handler/fail/multiple_body_extractors.stderr b/axum-macros/tests/debug_handler/fail/multiple_body_extractors.stderr deleted file mode 100644 index 098f3675d0..0000000000 --- a/axum-macros/tests/debug_handler/fail/multiple_body_extractors.stderr +++ /dev/null @@ -1,11 +0,0 @@ -error: Only one body extractor can be applied - --> tests/debug_handler/fail/multiple_body_extractors.rs:5:18 - | -5 | async fn handler(_: String, _: Bytes) {} - | ^^^^^^^^^ - -error: Only one body extractor can be applied - --> tests/debug_handler/fail/multiple_body_extractors.rs:5:29 - | -5 | async fn handler(_: String, _: Bytes) {} - | ^^^^^^^^ diff --git a/axum-macros/tests/debug_handler/fail/request_not_last.rs b/axum-macros/tests/debug_handler/fail/request_not_last.rs deleted file mode 100644 index 153d35ef3f..0000000000 --- a/axum-macros/tests/debug_handler/fail/request_not_last.rs +++ /dev/null @@ -1,7 +0,0 @@ -use axum::{body::Body, extract::Extension, http::Request}; -use axum_macros::debug_handler; - -#[debug_handler] -async fn handler(_: Request, _: Extension) {} - -fn main() {} diff --git a/axum-macros/tests/debug_handler/fail/request_not_last.stderr b/axum-macros/tests/debug_handler/fail/request_not_last.stderr deleted file mode 100644 index a3482e6486..0000000000 --- a/axum-macros/tests/debug_handler/fail/request_not_last.stderr +++ /dev/null @@ -1,5 +0,0 @@ -error: `Request` extractor should always be last - --> tests/debug_handler/fail/request_not_last.rs:5:18 - | -5 | async fn handler(_: Request, _: Extension) {} - | ^^^^^^^^^^^^^^^^ diff --git a/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr b/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr index c3ca7e1e80..89a5ed55ad 100644 --- a/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr +++ b/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr @@ -13,7 +13,7 @@ error[E0277]: the trait bound `bool: IntoResponse` is not satisfied (Response<()>, T1, T2, R) (Response<()>, T1, T2, T3, R) (Response<()>, T1, T2, T3, T4, R) - and 123 others + and 122 others note: required by a bound in `__axum_macros_check_handler_into_response::{closure#0}::check` --> tests/debug_handler/fail/wrong_return_type.rs:4:23 | diff --git a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs index 4941f59638..c4be8b52af 100644 --- a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs +++ b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs @@ -1,8 +1,4 @@ -use axum::{ - async_trait, - extract::{FromRequest, RequestParts}, - response::IntoResponse, -}; +use axum::{async_trait, extract::FromRequest, http::Request, response::IntoResponse}; use axum_macros::debug_handler; fn main() {} @@ -122,12 +118,12 @@ impl A { #[async_trait] impl FromRequest for A where - B: Send, + B: Send + 'static, S: Send + Sync, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: Request, _state: &S) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/debug_handler/pass/self_receiver.rs b/axum-macros/tests/debug_handler/pass/self_receiver.rs index a926eb7f44..e7bf81ce6c 100644 --- a/axum-macros/tests/debug_handler/pass/self_receiver.rs +++ b/axum-macros/tests/debug_handler/pass/self_receiver.rs @@ -1,6 +1,7 @@ use axum::{ async_trait, - extract::{FromRequest, RequestParts}, + extract::FromRequest, + http::Request, }; use axum_macros::debug_handler; @@ -9,12 +10,25 @@ struct A; #[async_trait] impl FromRequest for A where - B: Send, + B: Send + 'static, S: Send + Sync, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: Request, _state: &S) -> Result { + unimplemented!() + } +} + +#[async_trait] +impl FromRequest for Box +where + B: Send + 'static, + S: Send + Sync, +{ + type Rejection = (); + + async fn from_request(_req: Request, _state: &S) -> Result { unimplemented!() } } @@ -22,6 +36,9 @@ where impl A { #[debug_handler] async fn handler(self) {} + + #[debug_handler] + async fn handler_with_qualified_self(self: Box) {} } fn main() {} diff --git a/axum-macros/tests/debug_handler/pass/set_state.rs b/axum-macros/tests/debug_handler/pass/set_state.rs index 12afaf1059..5c84dbd25b 100644 --- a/axum-macros/tests/debug_handler/pass/set_state.rs +++ b/axum-macros/tests/debug_handler/pass/set_state.rs @@ -1,6 +1,7 @@ use axum_macros::debug_handler; -use axum::extract::{FromRef, FromRequest, RequestParts}; +use axum::extract::{FromRef, FromRequest}; use axum::async_trait; +use axum::http::Request; #[debug_handler(state = AppState)] async fn handler(_: A) {} @@ -13,13 +14,13 @@ struct A; #[async_trait] impl FromRequest for A where - B: Send, + B: Send + 'static, S: Send + Sync, AppState: FromRef, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: Request, _state: &S) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/from_request/fail/derive_opt_out_debug_and_display_without_error.rs b/axum-macros/tests/from_request/fail/derive_opt_out_debug_and_display_without_error.rs deleted file mode 100644 index a8cb0818aa..0000000000 --- a/axum-macros/tests/from_request/fail/derive_opt_out_debug_and_display_without_error.rs +++ /dev/null @@ -1,9 +0,0 @@ -use axum_macros::FromRequest; - -#[derive(FromRequest)] -#[from_request(rejection_derive(!Debug, !Display))] -struct Extractor { - body: String, -} - -fn main() {} diff --git a/axum-macros/tests/from_request/fail/derive_opt_out_debug_and_display_without_error.stderr b/axum-macros/tests/from_request/fail/derive_opt_out_debug_and_display_without_error.stderr deleted file mode 100644 index 656c8a546a..0000000000 --- a/axum-macros/tests/from_request/fail/derive_opt_out_debug_and_display_without_error.stderr +++ /dev/null @@ -1,5 +0,0 @@ -error: opt out of `Debug` and `Display` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Debug, !Display, !Error))]` - --> tests/from_request/fail/derive_opt_out_debug_and_display_without_error.rs:4:34 - | -4 | #[from_request(rejection_derive(!Debug, !Display))] - | ^^^^^ diff --git a/axum-macros/tests/from_request/fail/derive_opt_out_debug_without_error.rs b/axum-macros/tests/from_request/fail/derive_opt_out_debug_without_error.rs deleted file mode 100644 index dbc0aed82f..0000000000 --- a/axum-macros/tests/from_request/fail/derive_opt_out_debug_without_error.rs +++ /dev/null @@ -1,9 +0,0 @@ -use axum_macros::FromRequest; - -#[derive(FromRequest)] -#[from_request(rejection_derive(!Debug))] -struct Extractor { - body: String, -} - -fn main() {} diff --git a/axum-macros/tests/from_request/fail/derive_opt_out_debug_without_error.stderr b/axum-macros/tests/from_request/fail/derive_opt_out_debug_without_error.stderr deleted file mode 100644 index 1d8c287568..0000000000 --- a/axum-macros/tests/from_request/fail/derive_opt_out_debug_without_error.stderr +++ /dev/null @@ -1,5 +0,0 @@ -error: opt out of `Debug` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Debug, !Error))]` - --> tests/from_request/fail/derive_opt_out_debug_without_error.rs:4:34 - | -4 | #[from_request(rejection_derive(!Debug))] - | ^^^^^ diff --git a/axum-macros/tests/from_request/fail/derive_opt_out_display_without_error.rs b/axum-macros/tests/from_request/fail/derive_opt_out_display_without_error.rs deleted file mode 100644 index 2f4787685b..0000000000 --- a/axum-macros/tests/from_request/fail/derive_opt_out_display_without_error.rs +++ /dev/null @@ -1,9 +0,0 @@ -use axum_macros::FromRequest; - -#[derive(FromRequest)] -#[from_request(rejection_derive(!Display))] -struct Extractor { - body: String, -} - -fn main() {} diff --git a/axum-macros/tests/from_request/fail/derive_opt_out_display_without_error.stderr b/axum-macros/tests/from_request/fail/derive_opt_out_display_without_error.stderr deleted file mode 100644 index 0db0375958..0000000000 --- a/axum-macros/tests/from_request/fail/derive_opt_out_display_without_error.stderr +++ /dev/null @@ -1,5 +0,0 @@ -error: opt out of `Display` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Display, !Error))]` - --> tests/from_request/fail/derive_opt_out_display_without_error.rs:4:34 - | -4 | #[from_request(rejection_derive(!Display))] - | ^^^^^^^ diff --git a/axum-macros/tests/from_request/fail/derive_opt_out_duplicate.rs b/axum-macros/tests/from_request/fail/derive_opt_out_duplicate.rs deleted file mode 100644 index 0a0b2eb895..0000000000 --- a/axum-macros/tests/from_request/fail/derive_opt_out_duplicate.rs +++ /dev/null @@ -1,9 +0,0 @@ -use axum_macros::FromRequest; - -#[derive(FromRequest)] -#[from_request(rejection_derive(!Error, !Error))] -struct Extractor { - body: String, -} - -fn main() {} diff --git a/axum-macros/tests/from_request/fail/derive_opt_out_duplicate.stderr b/axum-macros/tests/from_request/fail/derive_opt_out_duplicate.stderr deleted file mode 100644 index 7ae523d4b8..0000000000 --- a/axum-macros/tests/from_request/fail/derive_opt_out_duplicate.stderr +++ /dev/null @@ -1,5 +0,0 @@ -error: `Error` opt out specified more than once - --> tests/from_request/fail/derive_opt_out_duplicate.rs:4:42 - | -4 | #[from_request(rejection_derive(!Error, !Error))] - | ^^^^^ diff --git a/axum-macros/tests/from_request/fail/enum_rejection_derive.rs b/axum-macros/tests/from_request/fail/enum_rejection_derive.rs deleted file mode 100644 index f343d54483..0000000000 --- a/axum-macros/tests/from_request/fail/enum_rejection_derive.rs +++ /dev/null @@ -1,7 +0,0 @@ -use axum_macros::FromRequest; - -#[derive(FromRequest, Clone)] -#[from_request(rejection_derive(!Error))] -enum Extractor {} - -fn main() {} diff --git a/axum-macros/tests/from_request/fail/enum_rejection_derive.stderr b/axum-macros/tests/from_request/fail/enum_rejection_derive.stderr deleted file mode 100644 index 1e721d760e..0000000000 --- a/axum-macros/tests/from_request/fail/enum_rejection_derive.stderr +++ /dev/null @@ -1,5 +0,0 @@ -error: cannot use `rejection_derive` on enums - --> tests/from_request/fail/enum_rejection_derive.rs:4:16 - | -4 | #[from_request(rejection_derive(!Error))] - | ^^^^^^^^^^^^^^^^ diff --git a/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.rs b/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.rs deleted file mode 100644 index ec5bb80099..0000000000 --- a/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.rs +++ /dev/null @@ -1,12 +0,0 @@ -use axum::{body::Body, routing::get, Router}; -use axum_macros::FromRequest; - -#[derive(FromRequest, Clone)] -#[from_request(rejection_derive(!Error))] -struct Extractor(T); - -async fn foo(_: Extractor<()>) {} - -fn main() { - Router::<(), Body>::new().route("/", get(foo)); -} diff --git a/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.stderr b/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.stderr deleted file mode 100644 index 10b674c150..0000000000 --- a/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.stderr +++ /dev/null @@ -1,21 +0,0 @@ -error: #[derive(FromRequest)] only supports generics when used with #[from_request(via)] - --> tests/from_request/fail/generic_without_via_rejection_derive.rs:6:18 - | -6 | struct Extractor(T); - | ^ - -error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {foo}: Handler<_, _, _>` is not satisfied - --> tests/from_request/fail/generic_without_via_rejection_derive.rs:11:46 - | -11 | Router::<(), Body>::new().route("/", get(foo)); - | --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future {foo}` - | | - | required by a bound introduced by this call - | - = help: the trait `Handler` is implemented for `Layered` -note: required by a bound in `axum::routing::get` - --> $WORKSPACE/axum/src/routing/method_routing.rs - | - | top_level_handler_fn!(get, GET); - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `axum::routing::get` - = note: this error originates in the macro `top_level_handler_fn` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum-macros/tests/from_request/fail/rejection_derive_and_via.rs b/axum-macros/tests/from_request/fail/rejection_derive_and_via.rs deleted file mode 100644 index a369150d60..0000000000 --- a/axum-macros/tests/from_request/fail/rejection_derive_and_via.rs +++ /dev/null @@ -1,9 +0,0 @@ -use axum_macros::FromRequest; - -#[derive(FromRequest, Clone)] -#[from_request(rejection_derive(!Error), via(axum::Extension))] -struct Extractor { - config: String, -} - -fn main() {} diff --git a/axum-macros/tests/from_request/fail/rejection_derive_and_via.stderr b/axum-macros/tests/from_request/fail/rejection_derive_and_via.stderr deleted file mode 100644 index 3a50044f83..0000000000 --- a/axum-macros/tests/from_request/fail/rejection_derive_and_via.stderr +++ /dev/null @@ -1,5 +0,0 @@ -error: cannot use both `rejection_derive` and `via` - --> tests/from_request/fail/rejection_derive_and_via.rs:4:42 - | -4 | #[from_request(rejection_derive(!Error), via(axum::Extension))] - | ^^^ diff --git a/axum-macros/tests/from_request/fail/unknown_attr_container.stderr b/axum-macros/tests/from_request/fail/unknown_attr_container.stderr index 9d89c0851e..25eeca56a5 100644 --- a/axum-macros/tests/from_request/fail/unknown_attr_container.stderr +++ b/axum-macros/tests/from_request/fail/unknown_attr_container.stderr @@ -1,4 +1,4 @@ -error: expected one of: `via`, `rejection_derive`, `rejection` +error: expected `via` or `rejection` --> tests/from_request/fail/unknown_attr_container.rs:4:16 | 4 | #[from_request(foo)] diff --git a/axum-macros/tests/from_request/fail/via_and_rejection_derive.rs b/axum-macros/tests/from_request/fail/via_and_rejection_derive.rs deleted file mode 100644 index 5f42ef0cf7..0000000000 --- a/axum-macros/tests/from_request/fail/via_and_rejection_derive.rs +++ /dev/null @@ -1,9 +0,0 @@ -use axum_macros::FromRequest; - -#[derive(FromRequest, Clone)] -#[from_request(via(axum::Extension), rejection_derive(!Error))] -struct Extractor { - config: String, -} - -fn main() {} diff --git a/axum-macros/tests/from_request/fail/via_and_rejection_derive.stderr b/axum-macros/tests/from_request/fail/via_and_rejection_derive.stderr deleted file mode 100644 index af45e8f811..0000000000 --- a/axum-macros/tests/from_request/fail/via_and_rejection_derive.stderr +++ /dev/null @@ -1,5 +0,0 @@ -error: cannot use both `via` and `rejection_derive` - --> tests/from_request/fail/via_and_rejection_derive.rs:4:38 - | -4 | #[from_request(via(axum::Extension), rejection_derive(!Error))] - | ^^^^^^^^^^^^^^^^ diff --git a/axum-macros/tests/from_request/pass/container.rs b/axum-macros/tests/from_request/pass/container.rs index e8eaa0a58a..c125902bad 100644 --- a/axum-macros/tests/from_request/pass/container.rs +++ b/axum-macros/tests/from_request/pass/container.rs @@ -1,6 +1,7 @@ use axum::{ body::Body, - extract::{rejection::JsonRejection, FromRequest, Json}, + extract::{FromRequest, Json}, + response::Response, }; use axum_macros::FromRequest; use serde::Deserialize; @@ -15,7 +16,7 @@ struct Extractor { fn assert_from_request() where - Extractor: FromRequest<(), Body, Rejection = JsonRejection>, + Extractor: FromRequest<(), Body, Rejection = Response>, { } diff --git a/axum-macros/tests/from_request/pass/derive_opt_out.rs b/axum-macros/tests/from_request/pass/derive_opt_out.rs deleted file mode 100644 index c5ef9deb90..0000000000 --- a/axum-macros/tests/from_request/pass/derive_opt_out.rs +++ /dev/null @@ -1,38 +0,0 @@ -use axum::{ - async_trait, - extract::{FromRequest, RequestParts}, - response::{IntoResponse, Response}, -}; -use axum_macros::FromRequest; - -#[derive(FromRequest)] -#[from_request(rejection_derive(!Display, !Error))] -struct Extractor { - other: OtherExtractor, -} - -struct OtherExtractor; - -#[async_trait] -impl FromRequest for OtherExtractor -where - B: Send, - S: Send + Sync, -{ - type Rejection = OtherExtractorRejection; - - async fn from_request(_req: &mut RequestParts) -> Result { - unimplemented!() - } -} - -#[derive(Debug)] -struct OtherExtractorRejection; - -impl IntoResponse for OtherExtractorRejection { - fn into_response(self) -> Response { - unimplemented!() - } -} - -fn main() {} diff --git a/axum-macros/tests/from_request/pass/named.rs b/axum-macros/tests/from_request/pass/named.rs index 89fb8da004..092989d17e 100644 --- a/axum-macros/tests/from_request/pass/named.rs +++ b/axum-macros/tests/from_request/pass/named.rs @@ -1,10 +1,10 @@ use axum::{ body::Body, - extract::{FromRequest, TypedHeader, rejection::{TypedHeaderRejection, StringRejection}}, + extract::{FromRequest, TypedHeader, rejection::TypedHeaderRejection}, + response::Response, headers::{self, UserAgent}, }; use axum_macros::FromRequest; -use std::convert::Infallible; #[derive(FromRequest)] struct Extractor { @@ -18,34 +18,8 @@ struct Extractor { fn assert_from_request() where - Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>, + Extractor: FromRequest<(), Body, Rejection = Response>, { } -fn assert_rejection(rejection: ExtractorRejection) -where - ExtractorRejection: std::fmt::Debug + std::fmt::Display + std::error::Error, -{ - match rejection { - ExtractorRejection::Uri(inner) => { - let _: Infallible = inner; - } - ExtractorRejection::Body(inner) => { - let _: StringRejection = inner; - } - ExtractorRejection::UserAgent(inner) => { - let _: TypedHeaderRejection = inner; - } - ExtractorRejection::ContentType(inner) => { - let _: TypedHeaderRejection = inner; - } - ExtractorRejection::Etag(inner) => { - let _: Infallible = inner; - } - ExtractorRejection::Host(inner) => { - let _: Infallible = inner; - } - } -} - fn main() {} diff --git a/axum-macros/tests/from_request/pass/named_via.rs b/axum-macros/tests/from_request/pass/named_via.rs index 8a81869d1a..23da2ac621 100644 --- a/axum-macros/tests/from_request/pass/named_via.rs +++ b/axum-macros/tests/from_request/pass/named_via.rs @@ -1,13 +1,13 @@ use axum::{ body::Body, + response::Response, extract::{ - rejection::{ExtensionRejection, TypedHeaderRejection}, + rejection::TypedHeaderRejection, Extension, FromRequest, TypedHeader, }, headers::{self, UserAgent}, }; use axum_macros::FromRequest; -use std::convert::Infallible; #[derive(FromRequest)] struct Extractor { @@ -25,33 +25,10 @@ struct Extractor { fn assert_from_request() where - Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>, + Extractor: FromRequest<(), Body, Rejection = Response>, { } -fn assert_rejection(rejection: ExtractorRejection) -where - ExtractorRejection: std::fmt::Debug + std::fmt::Display + std::error::Error, -{ - match rejection { - ExtractorRejection::State(inner) => { - let _: ExtensionRejection = inner; - } - ExtractorRejection::UserAgent(inner) => { - let _: TypedHeaderRejection = inner; - } - ExtractorRejection::ContentType(inner) => { - let _: TypedHeaderRejection = inner; - } - ExtractorRejection::Etag(inner) => { - let _: Infallible = inner; - } - ExtractorRejection::Host(inner) => { - let _: Infallible = inner; - } - } -} - #[derive(Clone)] struct State; diff --git a/axum-macros/tests/from_request/pass/override_rejection.rs b/axum-macros/tests/from_request/pass/override_rejection.rs index 40a25d5870..7167ffe0ea 100644 --- a/axum-macros/tests/from_request/pass/override_rejection.rs +++ b/axum-macros/tests/from_request/pass/override_rejection.rs @@ -1,7 +1,7 @@ use axum::{ async_trait, - extract::{rejection::ExtensionRejection, FromRequest, RequestParts}, - http::StatusCode, + extract::{rejection::ExtensionRejection, FromRequest}, + http::{StatusCode, Request}, response::{IntoResponse, Response}, routing::get, Extension, Router, @@ -36,7 +36,7 @@ where // this rejection doesn't implement `Display` and `Error` type Rejection = (StatusCode, String); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: Request, _state: &S) -> Result { todo!() } } diff --git a/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs b/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs index 00b6dd78df..227e4a3c8f 100644 --- a/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs +++ b/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs @@ -1,4 +1,4 @@ -use axum::extract::{Query, rejection::*}; +use axum::extract::Query; use axum_macros::FromRequest; use serde::Deserialize; @@ -17,18 +17,4 @@ where { } -fn assert_rejection(rejection: ExtractorRejection) -where - ExtractorRejection: std::fmt::Debug + std::fmt::Display + std::error::Error, -{ - match rejection { - ExtractorRejection::QueryPayload(inner) => { - let _: QueryRejection = inner; - } - ExtractorRejection::JsonPayload(inner) => { - let _: JsonRejection = inner; - } - } -} - fn main() {} diff --git a/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs b/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs index 0b148ebc50..82342c56c5 100644 --- a/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs +++ b/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs @@ -1,4 +1,5 @@ -use axum::extract::{Query, rejection::*}; +use axum::extract::Query; +use axum::response::Response; use axum_macros::FromRequest; use serde::Deserialize; @@ -8,26 +9,12 @@ struct Extractor( #[from_request(via(axum::extract::Json))] Payload, ); -fn assert_rejection(rejection: ExtractorRejection) -where - ExtractorRejection: std::fmt::Debug + std::fmt::Display + std::error::Error, -{ - match rejection { - ExtractorRejection::QueryPayload(inner) => { - let _: QueryRejection = inner; - } - ExtractorRejection::JsonPayload(inner) => { - let _: JsonRejection = inner; - } - } -} - #[derive(Deserialize)] struct Payload {} fn assert_from_request() where - Extractor: axum::extract::FromRequest<(), axum::body::Body>, + Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = Response>, { } diff --git a/axum-macros/tests/typed_path/fail/not_deserialize.stderr b/axum-macros/tests/typed_path/fail/not_deserialize.stderr index 9aabf3625f..91f3c3e30a 100644 --- a/axum-macros/tests/typed_path/fail/not_deserialize.stderr +++ b/axum-macros/tests/typed_path/fail/not_deserialize.stderr @@ -15,5 +15,5 @@ error[E0277]: the trait bound `for<'de> MyPath: serde::de::Deserialize<'de>` is (T0, T1, T2, T3) and 138 others = note: required because of the requirements on the impl of `serde::de::DeserializeOwned` for `MyPath` - = note: required because of the requirements on the impl of `FromRequest` for `axum::extract::Path` + = note: required because of the requirements on the impl of `FromRequestParts` for `axum::extract::Path` = note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum-macros/tests/typed_path/pass/option_result.rs b/axum-macros/tests/typed_path/pass/option_result.rs index 252bde137f..bd4c6dc282 100644 --- a/axum-macros/tests/typed_path/pass/option_result.rs +++ b/axum-macros/tests/typed_path/pass/option_result.rs @@ -16,7 +16,6 @@ async fn result_handler(_: Result) {} #[typed_path("/users")] struct UsersIndex; -#[axum_macros::debug_handler] async fn result_handler_unit_struct(_: Result) {} fn main() { diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 554206ed3d..b1ce5a204d 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -237,18 +237,48 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 } } ``` -- **breaking:** The following types or traits have a new `S` type param - (`()` by default) which represents the state ([#1155]): - - `FromRequest` - - `RequestParts` - - `Router` - - `MethodRouter` - - `Handler` +- **breaking:** It is now only possible for one extractor per handler to consume + the request body. In 0.5 doing so would result in runtime errors but in 0.6 it + is a compile error ([#1272]) + + axum enforces this by only allowing the _last_ extractor to consume the + request. + + For example: + + ```rust + use axum::{Json, http::HeaderMap}; + + // This wont compile on 0.6 because both `Json` and `String` need to consume + // the request body. You can use either `Json` or `String`, but not both. + async fn handler_1( + json: Json, + string: String, + ) {} + + // This won't work either since `Json` is not the last extractor. + async fn handler_2( + json: Json, + headers: HeaderMap, + ) {} + + // This works! + async fn handler_3( + headers: HeaderMap, + json: Json, + ) {} + ``` + + This is done by reworking the `FromRequest` trait and introducing a new + `FromRequestParts` trait. + + If your extractor needs to consume the request body then you should implement + `FromRequest`, otherwise implement `FromRequestParts`. This extractor in 0.5: ```rust - struct MyExtractor; + struct MyExtractor { /* ... */ } #[async_trait] impl FromRequest for MyExtractor @@ -266,22 +296,53 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Becomes this in 0.6: ```rust - struct MyExtractor; + use axum::{ + extract::{FromRequest, FromRequestParts}, + http::{StatusCode, Request, request::Parts}, + async_trait, + }; + + struct MyExtractor { /* ... */ } + // implement `FromRequestParts` if you don't need to consume the request body + #[async_trait] + impl FromRequestParts for MyExtractor + where + S: Send + Sync, + B: Send + 'static, + { + type Rejection = StatusCode; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + // ... + } + } + + // implement `FromRequest` if you do need to consume the request body #[async_trait] impl FromRequest for MyExtractor where S: Send + Sync, - B: Send, + B: Send + 'static, { type Rejection = StatusCode; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: Request, state: &S) -> Result { // ... } } ``` +- **breaking:** `RequestParts` has been removed as part of the `FromRequest` + rework ([#1272]) +- **breaking:** `BodyAlreadyExtracted` has been removed ([#1272]) +- **breaking:** The following types or traits have a new `S` type param + which represents the state ([#1155]): + - `Router`, defaults to `()` + - `MethodRouter`, defaults to `()` + - `FromRequest`, no default + - `Handler`, no default + ## Middleware - **breaking:** Remove `extractor_middleware` which was previously deprecated. @@ -310,6 +371,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1155]: https://github.com/tokio-rs/axum/pull/1155 [#1239]: https://github.com/tokio-rs/axum/pull/1239 [#1248]: https://github.com/tokio-rs/axum/pull/1248 +[#1272]: https://github.com/tokio-rs/axum/pull/1272 [#924]: https://github.com/tokio-rs/axum/pull/924 # 0.5.15 (9. August, 2022) diff --git a/axum/src/docs/extract.md b/axum/src/docs/extract.md index d87eb978d1..c7ead05e5e 100644 --- a/axum/src/docs/extract.md +++ b/axum/src/docs/extract.md @@ -5,14 +5,15 @@ Types and traits for extracting data from requests. - [Intro](#intro) - [Common extractors](#common-extractors) - [Applying multiple extractors](#applying-multiple-extractors) -- [Be careful when extracting `Request`](#be-careful-when-extracting-request) +- [The order of extractors](#the-order-of-extractors) - [Optional extractors](#optional-extractors) - [Customizing extractor responses](#customizing-extractor-responses) - [Accessing inner errors](#accessing-inner-errors) - [Defining custom extractors](#defining-custom-extractors) -- [Accessing other extractors in `FromRequest` implementations](#accessing-other-extractors-in-fromrequest-implementations) +- [Accessing other extractors in `FromRequest` or `FromRequestParts` implementations](#accessing-other-extractors-in-fromrequest-or-fromrequestparts-implementations) - [Request body extractors](#request-body-extractors) - [Running extractors from middleware](#running-extractors-from-middleware) +- [Wrapping extractors](#wrapping-extractors) # Intro @@ -152,83 +153,74 @@ async fn get_user_things( # }; ``` +# The order of extractors + Extractors always run in the order of the function parameters that is from left to right. -# Be careful when extracting `Request` +The request body is an asynchronous stream that can only be consumed once. +Therefore you can only have one extractor that consumes the request body. axum +enforces by that requiring such extractors to be the _last_ argument your +handler takes. -[`Request`] is itself an extractor: +For example -```rust,no_run -use axum::{http::Request, body::Body}; +```rust +use axum::http::{Method, HeaderMap}; -async fn handler(request: Request) { +async fn handler( + // `Method` and `HeaderMap` don't consume the request body so they can + // put anywhere in the argument list + method: Method, + headers: HeaderMap, + // `String` consumes the request body and thus must be the last extractor + body: String, +) { // ... } +# +# let _: axum::routing::MethodRouter = axum::routing::get(handler); ``` -However be careful when combining it with other extractors since it will consume -all extensions and the request body. Therefore it is recommended to always apply -the request extractor last: +We get a compile error if `String` isn't the last extractor: -```rust,no_run -use axum::{http::Request, Extension, body::Body}; +```rust,compile_fail +use axum::http::Method; -// this will fail at runtime since `Request` will have consumed all the -// extensions so `Extension` will be missing -async fn broken( - request: Request, - Extension(state): Extension, -) { - // ... -} - -// this will work since we extract `Extension` before `Request` -async fn works( - Extension(state): Extension, - request: Request, +async fn handler( + // this doesn't work since `String` must be the last argument + body: String, + method: Method, ) { // ... } - -#[derive(Clone)] -struct State {}; +# +# let _: axum::routing::MethodRouter = axum::routing::get(handler); ``` -# Extracting request bodies - -Since request bodies are asynchronous streams they can only be extracted once: +This also means you cannot consume the request body twice: -```rust,no_run -use axum::{Json, http::Request, body::{Bytes, Body}}; -use serde_json::Value; +```rust,compile_fail +use axum::Json; +use serde::Deserialize; -// this will fail at runtime since `Json` and `Bytes` both attempt to extract -// the body -// -// the solution is to only extract the body once so remove either -// `body_json: Json` or `body_bytes: Bytes` -async fn broken( - body_json: Json, - body_bytes: Bytes, -) { - // ... -} +#[derive(Deserialize)] +struct Payload {} -// this doesn't work either for the same reason: `Bytes` and `Request` -// both extract the body -async fn also_broken( - body_json: Json, - request: Request, +async fn handler( + // `String` and `Json` both consume the request body + // so they cannot both be used + string_body: String, + json_body: Json, ) { // ... } +# +# let _: axum::routing::MethodRouter = axum::routing::get(handler); ``` -Also keep this in mind if you extract or otherwise consume the body in -middleware. You either need to not extract the body in handlers or make sure -your middleware reinserts the body using [`RequestParts::body_mut`] so it's -available to handlers. +axum enforces this by requiring the last extractor implements [`FromRequest`] +and all others implement [`FromRequestParts`]. # Optional extractors @@ -407,29 +399,38 @@ happen without major breaking versions. # Defining custom extractors -You can also define your own extractors by implementing [`FromRequest`]: +You can also define your own extractors by implementing either +[`FromRequestParts`] or [`FromRequest`]. + +## Implementing `FromRequestParts` + +Implement `FromRequestParts` if your extractor doesn't need access to the +request body: ```rust,no_run use axum::{ async_trait, - extract::{FromRequest, RequestParts}, + extract::FromRequestParts, routing::get, Router, + http::{ + StatusCode, + header::{HeaderValue, USER_AGENT}, + request::Parts, + }, }; -use http::{StatusCode, header::{HeaderValue, USER_AGENT}}; struct ExtractUserAgent(HeaderValue); #[async_trait] -impl FromRequest for ExtractUserAgent +impl FromRequestParts for ExtractUserAgent where - B: Send, S: Send + Sync, { type Rejection = (StatusCode, &'static str); - async fn from_request(req: &mut RequestParts) -> Result { - if let Some(user_agent) = req.headers().get(USER_AGENT) { + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + if let Some(user_agent) = parts.headers.get(USER_AGENT) { Ok(ExtractUserAgent(user_agent.clone())) } else { Err((StatusCode::BAD_REQUEST, "`User-Agent` header is missing")) @@ -447,7 +448,58 @@ let app = Router::new().route("/foo", get(handler)); # }; ``` -# Accessing other extractors in [`FromRequest`] implementations +## Implementing `FromRequest` + +If your extractor needs to consume the request body you must implement [`FromRequest`] + +```rust,no_run +use axum::{ + async_trait, + extract::FromRequest, + response::{Response, IntoResponse}, + body::Bytes, + routing::get, + Router, + http::{ + StatusCode, + header::{HeaderValue, USER_AGENT}, + Request, + }, +}; + +struct ValidatedBody(Bytes); + +#[async_trait] +impl FromRequest for ValidatedBody +where + Bytes: FromRequest, + B: Send + 'static, + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request(req: Request, state: &S) -> Result { + let body = Bytes::from_request(req, state) + .await + .map_err(IntoResponse::into_response)?; + + // do validation... + + Ok(Self(body)) + } +} + +async fn handler(ValidatedBody(body): ValidatedBody) { + // ... +} + +let app = Router::new().route("/foo", get(handler)); +# async { +# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +# }; +``` + +# Accessing other extractors in `FromRequest` or `FromRequestParts` implementations When defining custom extractors you often need to access another extractors in your implementation. @@ -455,9 +507,9 @@ in your implementation. ```rust use axum::{ async_trait, - extract::{Extension, FromRequest, RequestParts, TypedHeader}, + extract::{Extension, FromRequestParts, TypedHeader}, headers::{authorization::Bearer, Authorization}, - http::StatusCode, + http::{StatusCode, request::Parts}, response::{IntoResponse, Response}, routing::get, Router, @@ -473,20 +525,19 @@ struct AuthenticatedUser { } #[async_trait] -impl FromRequest for AuthenticatedUser +impl FromRequestParts for AuthenticatedUser where - B: Send, S: Send + Sync, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let TypedHeader(Authorization(token)) = - TypedHeader::>::from_request(req) + TypedHeader::>::from_request_parts(parts, state) .await .map_err(|err| err.into_response())?; - let Extension(state): Extension = Extension::from_request(req) + let Extension(state): Extension = Extension::from_request_parts(parts, state) .await .map_err(|err| err.into_response())?; @@ -584,14 +635,13 @@ let app = Router::new() # Running extractors from middleware -Extractors can also be run from middleware by making a [`RequestParts`] and -running your extractor: +Extractors can also be run from middleware: ```rust use axum::{ Router, middleware::{self, Next}, - extract::{RequestParts, TypedHeader}, + extract::{TypedHeader, FromRequestParts}, http::{Request, StatusCode}, response::Response, headers::authorization::{Authorization, Bearer}, @@ -604,12 +654,11 @@ async fn auth_middleware( where B: Send, { - // running extractors requires a `RequestParts` - let mut request_parts = RequestParts::new(request); + // running extractors requires a `axum::http::request::Parts` + let (mut parts, body) = request.into_parts(); - // `TypedHeader>` extracts the auth token but - // `RequestParts::extract` works with anything that implements `FromRequest` - let auth = request_parts.extract::>>() + // `TypedHeader>` extracts the auth token + let auth = TypedHeader::>::from_request_parts(&mut parts, &()) .await .map_err(|_| StatusCode::UNAUTHORIZED)?; @@ -617,14 +666,8 @@ where return Err(StatusCode::UNAUTHORIZED); } - // get the request back so we can run `next` - // - // `try_into_request` will fail if you have extracted the request body. We - // know that `TypedHeader` never does that. - // - // see the `consume-body-in-extractor-or-middleware` example if you need to - // extract the body - let request = request_parts.try_into_request().expect("body extracted"); + // reconstruct the request + let request = Request::from_parts(parts, body); Ok(next.run(request).await) } @@ -638,8 +681,81 @@ let app = Router::new().layer(middleware::from_fn(auth_middleware)); # let _: Router<()> = app; ``` +# Wrapping extractors + +If you want write an extractor that generically wraps another extractor (that +may or may not consume the request body) you should implement both +[`FromRequest`] and [`FromRequestParts`]: + +```rust +use axum::{ + Router, + routing::get, + extract::{FromRequest, FromRequestParts}, + http::{Request, HeaderMap, request::Parts}, + async_trait, +}; +use std::time::{Instant, Duration}; + +// an extractor that wraps another and measures how long time it takes to run +struct Timing { + extractor: E, + duration: Duration, +} + +// we must implement both `FromRequestParts` +#[async_trait] +impl FromRequestParts for Timing +where + S: Send + Sync, + T: FromRequestParts, +{ + type Rejection = T::Rejection; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let start = Instant::now(); + let extractor = T::from_request_parts(parts, state).await?; + let duration = start.elapsed(); + Ok(Timing { + extractor, + duration, + }) + } +} + +// and `FromRequest` +#[async_trait] +impl FromRequest for Timing +where + B: Send + 'static, + S: Send + Sync, + T: FromRequest, +{ + type Rejection = T::Rejection; + + async fn from_request(req: Request, state: &S) -> Result { + let start = Instant::now(); + let extractor = T::from_request(req, state).await?; + let duration = start.elapsed(); + Ok(Timing { + extractor, + duration, + }) + } +} + +async fn handler( + // this uses the `FromRequestParts` impl + _: Timing, + // this uses the `FromRequest` impl + _: Timing, +) {} +# let _: axum::routing::MethodRouter = axum::routing::get(handler); +``` + [`body::Body`]: crate::body::Body [customize-extractor-error]: https://github.com/tokio-rs/axum/blob/main/examples/customize-extractor-error/src/main.rs [`HeaderMap`]: https://docs.rs/http/latest/http/header/struct.HeaderMap.html [`Request`]: https://docs.rs/http/latest/http/struct.Request.html [`RequestParts::body_mut`]: crate::extract::RequestParts::body_mut +[`JsonRejection::JsonDataError`]: rejection::JsonRejection::JsonDataError diff --git a/axum/src/error_handling/mod.rs b/axum/src/error_handling/mod.rs index 6a72d82069..fc6eacdca0 100644 --- a/axum/src/error_handling/mod.rs +++ b/axum/src/error_handling/mod.rs @@ -1,8 +1,8 @@ #![doc = include_str!("../docs/error_handling.md")] use crate::{ - extract::{FromRequest, RequestParts}, - http::{Request, StatusCode}, + extract::FromRequestParts, + http::Request, response::{IntoResponse, Response}, }; use std::{ @@ -161,7 +161,7 @@ macro_rules! impl_service { F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static, Fut: Future + Send, Res: IntoResponse, - $( $ty: FromRequest<(), B> + Send,)* + $( $ty: FromRequestParts<()> + Send,)* B: Send + 'static, { type Response = Response; @@ -181,21 +181,16 @@ macro_rules! impl_service { let inner = std::mem::replace(&mut self.inner, clone); let future = Box::pin(async move { - let mut req = RequestParts::new(req); + let (mut parts, body) = req.into_parts(); $( - let $ty = match $ty::from_request(&mut req).await { + let $ty = match $ty::from_request_parts(&mut parts, &()).await { Ok(value) => value, Err(rejection) => return Ok(rejection.into_response()), }; )* - let req = match req.try_into_request() { - Ok(req) => req, - Err(err) => { - return Ok((StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response()); - } - }; + let req = Request::from_parts(parts, body); match inner.oneshot(req).await { Ok(res) => Ok(res.into_response()), diff --git a/axum/src/extension.rs b/axum/src/extension.rs index 4c93ce1b47..575d62ca24 100644 --- a/axum/src/extension.rs +++ b/axum/src/extension.rs @@ -1,10 +1,10 @@ -use crate::{ - extract::{rejection::*, FromRequest, RequestParts}, - response::IntoResponseParts, -}; +use crate::{extract::rejection::*, response::IntoResponseParts}; use async_trait::async_trait; -use axum_core::response::{IntoResponse, Response, ResponseParts}; -use http::Request; +use axum_core::{ + extract::FromRequestParts, + response::{IntoResponse, Response, ResponseParts}, +}; +use http::{request::Parts, Request}; use std::{ convert::Infallible, ops::Deref, @@ -73,17 +73,16 @@ use tower_service::Service; pub struct Extension(pub T); #[async_trait] -impl FromRequest for Extension +impl FromRequestParts for Extension where T: Clone + Send + Sync + 'static, - B: Send, S: Send + Sync, { type Rejection = ExtensionRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request_parts(req: &mut Parts, _state: &S) -> Result { let value = req - .extensions() + .extensions .get::() .ok_or_else(|| { MissingExtension::from_err(format!( diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index ba8c301c4c..fff2b46f2b 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -4,9 +4,10 @@ //! //! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info -use super::{Extension, FromRequest, RequestParts}; +use super::{Extension, FromRequestParts}; use crate::middleware::AddExtension; use async_trait::async_trait; +use http::request::Parts; use hyper::server::conn::AddrStream; use std::{ convert::Infallible, @@ -128,16 +129,15 @@ opaque_future! { pub struct ConnectInfo(pub T); #[async_trait] -impl FromRequest for ConnectInfo +impl FromRequestParts for ConnectInfo where - B: Send, S: Send + Sync, T: Clone + Send + Sync + 'static, { - type Rejection = as FromRequest>::Rejection; + type Rejection = as FromRequestParts>::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { - let Extension(connect_info) = Extension::::from_request(req).await?; + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let Extension(connect_info) = Extension::::from_request_parts(parts, state).await?; Ok(connect_info) } } diff --git a/axum/src/extract/content_length_limit.rs b/axum/src/extract/content_length_limit.rs index ae27a6c29c..a3d1f711c9 100644 --- a/axum/src/extract/content_length_limit.rs +++ b/axum/src/extract/content_length_limit.rs @@ -1,7 +1,7 @@ -use super::{rejection::*, FromRequest, RequestParts}; +use super::{rejection::*, FromRequest}; use async_trait::async_trait; -use axum_core::response::IntoResponse; -use http::Method; +use axum_core::{extract::FromRequestParts, response::IntoResponse}; +use http::{request::Parts, Method, Request}; use std::ops::Deref; /// Extractor that will reject requests with a body larger than some size. @@ -40,43 +40,17 @@ impl FromRequest for ContentLengthLimit where T: FromRequest, T::Rejection: IntoResponse, - B: Send, + B: Send + 'static, S: Send + Sync, { type Rejection = ContentLengthLimitRejection; - async fn from_request(req: &mut RequestParts) -> Result { - let content_length = req - .headers() - .get(http::header::CONTENT_LENGTH) - .and_then(|value| value.to_str().ok()?.parse::().ok()); - - match (content_length, req.method()) { - (content_length, &(Method::GET | Method::HEAD | Method::OPTIONS)) => { - if content_length.is_some() { - return Err(ContentLengthLimitRejection::ContentLengthNotAllowed( - ContentLengthNotAllowed, - )); - } else if req - .headers() - .get(http::header::TRANSFER_ENCODING) - .map_or(false, |value| value.as_bytes() == b"chunked") - { - return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired)); - } - } - (Some(content_length), _) if content_length > N => { - return Err(ContentLengthLimitRejection::PayloadTooLarge( - PayloadTooLarge, - )); - } - (None, _) => { - return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired)); - } - _ => {} - } + async fn from_request(req: Request, state: &S) -> Result { + let (parts, body) = req.into_parts(); + validate::<_, N>(&parts)?; - let value = T::from_request(req) + let req = Request::from_parts(parts, body); + let value = T::from_request(req, state) .await .map_err(ContentLengthLimitRejection::Inner)?; @@ -84,6 +58,60 @@ where } } +#[async_trait] +impl FromRequestParts for ContentLengthLimit +where + T: FromRequestParts, + T::Rejection: IntoResponse, + S: Send + Sync, +{ + type Rejection = ContentLengthLimitRejection; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + validate::<_, N>(parts)?; + + let value = T::from_request_parts(parts, state) + .await + .map_err(ContentLengthLimitRejection::Inner)?; + + Ok(Self(value)) + } +} + +fn validate(parts: &Parts) -> Result<(), ContentLengthLimitRejection> { + let content_length = parts + .headers + .get(http::header::CONTENT_LENGTH) + .and_then(|value| value.to_str().ok()?.parse::().ok()); + + match (content_length, &parts.method) { + (content_length, &(Method::GET | Method::HEAD | Method::OPTIONS)) => { + if content_length.is_some() { + return Err(ContentLengthLimitRejection::ContentLengthNotAllowed( + ContentLengthNotAllowed, + )); + } else if parts + .headers + .get(http::header::TRANSFER_ENCODING) + .map_or(false, |value| value.as_bytes() == b"chunked") + { + return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired)); + } + } + (Some(content_length), _) if content_length > N => { + return Err(ContentLengthLimitRejection::PayloadTooLarge( + PayloadTooLarge, + )); + } + (None, _) => { + return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired)); + } + _ => {} + } + + Ok(()) +} + impl Deref for ContentLengthLimit { type Target = T; diff --git a/axum/src/extract/host.rs b/axum/src/extract/host.rs index 6137c64ef6..f92a62273a 100644 --- a/axum/src/extract/host.rs +++ b/axum/src/extract/host.rs @@ -1,9 +1,12 @@ use super::{ rejection::{FailedToResolveHost, HostRejection}, - FromRequest, RequestParts, + FromRequestParts, }; use async_trait::async_trait; -use http::header::{HeaderMap, FORWARDED}; +use http::{ + header::{HeaderMap, FORWARDED}, + request::Parts, +}; const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host"; @@ -21,35 +24,34 @@ const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host"; pub struct Host(pub String); #[async_trait] -impl FromRequest for Host +impl FromRequestParts for Host where - B: Send, S: Send + Sync, { type Rejection = HostRejection; - async fn from_request(req: &mut RequestParts) -> Result { - if let Some(host) = parse_forwarded(req.headers()) { + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + if let Some(host) = parse_forwarded(&parts.headers) { return Ok(Host(host.to_owned())); } - if let Some(host) = req - .headers() + if let Some(host) = parts + .headers .get(X_FORWARDED_HOST_HEADER_KEY) .and_then(|host| host.to_str().ok()) { return Ok(Host(host.to_owned())); } - if let Some(host) = req - .headers() + if let Some(host) = parts + .headers .get(http::header::HOST) .and_then(|host| host.to_str().ok()) { return Ok(Host(host.to_owned())); } - if let Some(host) = req.uri().host() { + if let Some(host) = parts.uri.host() { return Ok(Host(host.to_owned())); } diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index 35a076b2cf..6af783d956 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -1,5 +1,6 @@ -use super::{rejection::*, FromRequest, RequestParts}; +use super::{rejection::*, FromRequestParts}; use async_trait::async_trait; +use http::request::Parts; use std::sync::Arc; /// Access the path in the router that matches the request. @@ -64,16 +65,15 @@ impl MatchedPath { } #[async_trait] -impl FromRequest for MatchedPath +impl FromRequestParts for MatchedPath where - B: Send, S: Send + Sync, { type Rejection = MatchedPathRejection; - async fn from_request(req: &mut RequestParts) -> Result { - let matched_path = req - .extensions() + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let matched_path = parts + .extensions .get::() .ok_or(MatchedPathRejection::MatchedPathMissing(MatchedPathMissing))? .clone(); diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index 081793a83c..c3d0ce3329 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -1,7 +1,6 @@ #![doc = include_str!("../docs/extract.md")] -use http::header; -use rejection::*; +use http::header::{self, HeaderMap}; pub mod connect_info; pub mod path; @@ -17,7 +16,7 @@ mod request_parts; mod state; #[doc(inline)] -pub use axum_core::extract::{FromRef, FromRequest, RequestParts}; +pub use axum_core::extract::{FromRef, FromRequest, FromRequestParts}; #[doc(inline)] #[allow(deprecated)] @@ -75,16 +74,9 @@ pub use self::ws::WebSocketUpgrade; #[doc(no_inline)] pub use crate::TypedHeader; -pub(crate) fn take_body(req: &mut RequestParts) -> Result { - req.take_body().ok_or_else(BodyAlreadyExtracted::default) -} - // this is duplicated in `axum-extra/src/extract/form.rs` -pub(super) fn has_content_type( - req: &RequestParts, - expected_content_type: &mime::Mime, -) -> bool { - let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { +pub(super) fn has_content_type(headers: &HeaderMap, expected_content_type: &mime::Mime) -> bool { + let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) { content_type } else { return false; diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index 3063a45c9e..af0fe20236 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -2,12 +2,13 @@ //! //! See [`Multipart`] for more details. -use super::{rejection::*, BodyStream, FromRequest, RequestParts}; +use super::{BodyStream, FromRequest}; use crate::body::{Bytes, HttpBody}; use crate::BoxError; use async_trait::async_trait; use futures_util::stream::Stream; use http::header::{HeaderMap, CONTENT_TYPE}; +use http::Request; use std::{ fmt, pin::Pin, @@ -58,10 +59,12 @@ where { type Rejection = MultipartRejection; - async fn from_request(req: &mut RequestParts) -> Result { - let stream = BodyStream::from_request(req).await?; - let headers = req.headers(); - let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?; + async fn from_request(req: Request, state: &S) -> Result { + let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?; + let stream = match BodyStream::from_request(req, state).await { + Ok(stream) => stream, + Err(err) => match err {}, + }; let multipart = multer::Multipart::new(stream, boundary); Ok(Self { inner: multipart }) } @@ -224,7 +227,6 @@ composite_rejection! { /// /// Contains one variant for each way the [`Multipart`] extractor can fail. pub enum MultipartRejection { - BodyAlreadyExtracted, InvalidBoundary, } } diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index 16ed753b94..0575c8a34c 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -4,12 +4,12 @@ mod de; use crate::{ - extract::{rejection::*, FromRequest, RequestParts}, + extract::{rejection::*, FromRequestParts}, routing::url_params::UrlParams, }; use async_trait::async_trait; use axum_core::response::{IntoResponse, Response}; -use http::StatusCode; +use http::{request::Parts, StatusCode}; use serde::de::DeserializeOwned; use std::{ fmt, @@ -163,16 +163,15 @@ impl DerefMut for Path { } #[async_trait] -impl FromRequest for Path +impl FromRequestParts for Path where T: DeserializeOwned + Send, - B: Send, S: Send + Sync, { type Rejection = PathRejection; - async fn from_request(req: &mut RequestParts) -> Result { - let params = match req.extensions_mut().get::() { + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let params = match parts.extensions.get::() { Some(UrlParams::Params(params)) => params, Some(UrlParams::InvalidUtf8InPathParam { key }) => { let err = PathDeserializationError { @@ -413,8 +412,7 @@ impl std::error::Error for FailedToDeserializePathParams {} mod tests { use super::*; use crate::{routing::get, test_helpers::*, Router}; - use http::{Request, StatusCode}; - use hyper::Body; + use http::StatusCode; use std::collections::HashMap; #[tokio::test] @@ -519,20 +517,6 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); } - #[tokio::test] - async fn when_extensions_are_missing() { - let app = Router::new().route("/:key", get(|_: Request, _: Path| async {})); - - let client = TestClient::new(app); - - let res = client.get("/foo").send().await; - assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!( - res.text().await, - "No paths parameters found for matched route. Are you also extracting `Request<_>`?" - ); - } - #[tokio::test] async fn str_reference_deserialize() { struct Param(String); diff --git a/axum/src/extract/query.rs b/axum/src/extract/query.rs index 6eedd8c1fc..b4c34f942a 100644 --- a/axum/src/extract/query.rs +++ b/axum/src/extract/query.rs @@ -1,5 +1,6 @@ -use super::{rejection::*, FromRequest, RequestParts}; +use super::{rejection::*, FromRequestParts}; use async_trait::async_trait; +use http::request::Parts; use serde::de::DeserializeOwned; use std::ops::Deref; @@ -49,16 +50,15 @@ use std::ops::Deref; pub struct Query(pub T); #[async_trait] -impl FromRequest for Query +impl FromRequestParts for Query where T: DeserializeOwned, - B: Send, S: Send + Sync, { type Rejection = QueryRejection; - async fn from_request(req: &mut RequestParts) -> Result { - let query = req.uri().query().unwrap_or_default(); + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let query = parts.uri.query().unwrap_or_default(); let value = serde_urlencoded::from_str(query) .map_err(FailedToDeserializeQueryString::__private_new)?; Ok(Query(value)) @@ -76,15 +76,17 @@ impl Deref for Query { #[cfg(test)] mod tests { use super::*; - use crate::extract::RequestParts; + use axum_core::extract::FromRequest; use http::Request; use serde::Deserialize; use std::fmt::Debug; - async fn check(uri: impl AsRef, value: T) { + async fn check(uri: impl AsRef, value: T) + where + T: DeserializeOwned + PartialEq + Debug, + { let req = Request::builder().uri(uri.as_ref()).body(()).unwrap(); - let mut req = RequestParts::new(req); - assert_eq!(Query::::from_request(&mut req).await.unwrap().0, value); + assert_eq!(Query::::from_request(req, &()).await.unwrap().0, value); } #[tokio::test] diff --git a/axum/src/extract/raw_query.rs b/axum/src/extract/raw_query.rs index b0090957a3..0e507cfcc6 100644 --- a/axum/src/extract/raw_query.rs +++ b/axum/src/extract/raw_query.rs @@ -1,5 +1,6 @@ -use super::{FromRequest, RequestParts}; +use super::FromRequestParts; use async_trait::async_trait; +use http::request::Parts; use std::convert::Infallible; /// Extractor that extracts the raw query string, without parsing it. @@ -27,15 +28,14 @@ use std::convert::Infallible; pub struct RawQuery(pub Option); #[async_trait] -impl FromRequest for RawQuery +impl FromRequestParts for RawQuery where - B: Send, S: Send + Sync, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { - let query = req.uri().query().map(|query| query.to_owned()); + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let query = parts.uri.query().map(|query| query.to_owned()); Ok(Self(query)) } } diff --git a/axum/src/extract/rejection.rs b/axum/src/extract/rejection.rs index 21dfd14ad1..135ad0a230 100644 --- a/axum/src/extract/rejection.rs +++ b/axum/src/extract/rejection.rs @@ -73,7 +73,7 @@ define_rejection! { define_rejection! { #[status = INTERNAL_SERVER_ERROR] - #[body = "No paths parameters found for matched route. Are you also extracting `Request<_>`?"] + #[body = "No paths parameters found for matched route"] /// Rejection type used if axum's internal representation of path parameters /// is missing. This is commonly caused by extracting `Request<_>`. `Path` /// must be extracted first. diff --git a/axum/src/extract/request_parts.rs b/axum/src/extract/request_parts.rs index 5a0da46064..eff77a2b34 100644 --- a/axum/src/extract/request_parts.rs +++ b/axum/src/extract/request_parts.rs @@ -1,11 +1,11 @@ -use super::{rejection::*, take_body, Extension, FromRequest, RequestParts}; +use super::{Extension, FromRequest, FromRequestParts}; use crate::{ body::{Body, Bytes, HttpBody}, BoxError, Error, }; use async_trait::async_trait; use futures_util::stream::Stream; -use http::Uri; +use http::{request::Parts, Request, Uri}; use std::{ convert::Infallible, fmt, @@ -86,17 +86,16 @@ pub struct OriginalUri(pub Uri); #[cfg(feature = "original-uri")] #[async_trait] -impl FromRequest for OriginalUri +impl FromRequestParts for OriginalUri where - B: Send, S: Send + Sync, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { - let uri = Extension::::from_request(req) + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let uri = Extension::::from_request_parts(parts, state) .await - .unwrap_or_else(|_| Extension(OriginalUri(req.uri().clone()))) + .unwrap_or_else(|_| Extension(OriginalUri(parts.uri.clone()))) .0; Ok(uri) } @@ -148,10 +147,11 @@ where B::Error: Into, S: Send + Sync, { - type Rejection = BodyAlreadyExtracted; + type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { - let body = take_body(req)? + async fn from_request(req: Request, _state: &S) -> Result { + let body = req + .into_body() .map_data(Into::into) .map_err(|err| Error::new(err.into())); let stream = BodyStream(SyncWrapper::new(Box::pin(body))); @@ -203,40 +203,17 @@ where B: Send, S: Send + Sync, { - type Rejection = BodyAlreadyExtracted; + type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { - let body = take_body(req)?; - Ok(Self(body)) + async fn from_request(req: Request, _state: &S) -> Result { + Ok(Self(req.into_body())) } } #[cfg(test)] mod tests { - use crate::{ - body::Body, - extract::Extension, - routing::{get, post}, - test_helpers::*, - Router, - }; - use http::{Method, Request, StatusCode}; - - #[tokio::test] - async fn multiple_request_extractors() { - async fn handler(_: Request, _: Request) {} - - let app = Router::new().route("/", post(handler)); - - let client = TestClient::new(app); - - let res = client.post("/").body("hi there").send().await; - assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!( - res.text().await, - "Cannot have two request body extractors for a single handler" - ); - } + use crate::{extract::Extension, routing::get, test_helpers::*, Router}; + use http::{Method, StatusCode}; #[tokio::test] async fn extract_request_parts() { @@ -256,19 +233,4 @@ mod tests { let res = client.get("/").header("x-foo", "123").send().await; assert_eq!(res.status(), StatusCode::OK); } - - #[tokio::test] - async fn extract_request_parts_doesnt_consume_the_body() { - #[derive(Clone)] - struct Ext; - - async fn handler(_parts: http::request::Parts, body: String) { - assert_eq!(body, "foo"); - } - - let client = TestClient::new(Router::new().route("/", get(handler))); - - let res = client.get("/").body("foo").send().await; - assert_eq!(res.status(), StatusCode::OK); - } } diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs index 1385dfec2a..eb5ae16ef4 100644 --- a/axum/src/extract/state.rs +++ b/axum/src/extract/state.rs @@ -1,5 +1,6 @@ use async_trait::async_trait; -use axum_core::extract::{FromRef, FromRequest, RequestParts}; +use axum_core::extract::{FromRef, FromRequestParts}; +use http::request::Parts; use std::{ convert::Infallible, ops::{Deref, DerefMut}, @@ -139,7 +140,8 @@ use std::{ /// to do it: /// /// ```rust -/// use axum_core::extract::{FromRequest, RequestParts, FromRef}; +/// use axum_core::extract::{FromRequestParts, FromRef}; +/// use http::request::Parts; /// use async_trait::async_trait; /// use std::convert::Infallible; /// @@ -147,9 +149,8 @@ use std::{ /// struct MyLibraryExtractor; /// /// #[async_trait] -/// impl FromRequest for MyLibraryExtractor +/// impl FromRequestParts for MyLibraryExtractor /// where -/// B: Send, /// // keep `S` generic but require that it can produce a `MyLibraryState` /// // this means users will have to implement `FromRef for MyLibraryState` /// MyLibraryState: FromRef, @@ -157,9 +158,9 @@ use std::{ /// { /// type Rejection = Infallible; /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { /// // get a `MyLibraryState` from a reference to the state -/// let state = MyLibraryState::from_ref(req.state()); +/// let state = MyLibraryState::from_ref(state); /// /// // ... /// # todo!() @@ -171,23 +172,22 @@ use std::{ /// // ... /// } /// ``` -/// -/// Note that you don't need to use the `State` extractor since you can access the state directly -/// from [`RequestParts`]. #[derive(Debug, Default, Clone, Copy)] pub struct State(pub S); #[async_trait] -impl FromRequest for State +impl FromRequestParts for State where - B: Send, InnerState: FromRef, OuterState: Send + Sync, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { - let inner_state = InnerState::from_ref(req.state()); + async fn from_request_parts( + _parts: &mut Parts, + state: &OuterState, + ) -> Result { + let inner_state = InnerState::from_ref(state); Ok(Self(inner_state)) } } diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index 976d12a71f..212abdadeb 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -95,7 +95,7 @@ //! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split use self::rejection::*; -use super::{FromRequest, RequestParts}; +use super::FromRequestParts; use crate::{ body::{self, Bytes}, response::Response, @@ -107,7 +107,8 @@ use futures_util::{ stream::{Stream, StreamExt}, }; use http::{ - header::{self, HeaderName, HeaderValue}, + header::{self, HeaderMap, HeaderName, HeaderValue}, + request::Parts, Method, StatusCode, }; use hyper::upgrade::{OnUpgrade, Upgraded}; @@ -275,41 +276,40 @@ impl WebSocketUpgrade { } #[async_trait] -impl FromRequest for WebSocketUpgrade +impl FromRequestParts for WebSocketUpgrade where - B: Send, S: Send + Sync, { type Rejection = WebSocketUpgradeRejection; - async fn from_request(req: &mut RequestParts) -> Result { - if req.method() != Method::GET { + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + if parts.method != Method::GET { return Err(MethodNotGet.into()); } - if !header_contains(req, header::CONNECTION, "upgrade") { + if !header_contains(&parts.headers, header::CONNECTION, "upgrade") { return Err(InvalidConnectionHeader.into()); } - if !header_eq(req, header::UPGRADE, "websocket") { + if !header_eq(&parts.headers, header::UPGRADE, "websocket") { return Err(InvalidUpgradeHeader.into()); } - if !header_eq(req, header::SEC_WEBSOCKET_VERSION, "13") { + if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") { return Err(InvalidWebSocketVersionHeader.into()); } - let sec_websocket_key = req - .headers_mut() + let sec_websocket_key = parts + .headers .remove(header::SEC_WEBSOCKET_KEY) .ok_or(WebSocketKeyHeaderMissing)?; - let on_upgrade = req - .extensions_mut() + let on_upgrade = parts + .extensions .remove::() .ok_or(ConnectionNotUpgradable)?; - let sec_websocket_protocol = req.headers().get(header::SEC_WEBSOCKET_PROTOCOL).cloned(); + let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned(); Ok(Self { config: Default::default(), @@ -321,16 +321,16 @@ where } } -fn header_eq(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { - if let Some(header) = req.headers().get(&key) { +fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { + if let Some(header) = headers.get(&key) { header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) } else { false } } -fn header_contains(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { - let header = if let Some(header) = req.headers().get(&key) { +fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { + let header = if let Some(header) = headers.get(&key) { header } else { return false; diff --git a/axum/src/form.rs b/axum/src/form.rs index 9477eff2b5..414e1fd527 100644 --- a/axum/src/form.rs +++ b/axum/src/form.rs @@ -1,10 +1,10 @@ use crate::body::{Bytes, HttpBody}; -use crate::extract::{has_content_type, rejection::*, FromRequest, RequestParts}; +use crate::extract::{has_content_type, rejection::*, FromRequest}; use crate::BoxError; use async_trait::async_trait; use axum_core::response::{IntoResponse, Response}; use http::header::CONTENT_TYPE; -use http::{Method, StatusCode}; +use http::{Method, Request, StatusCode}; use serde::de::DeserializeOwned; use serde::Serialize; use std::ops::Deref; @@ -59,25 +59,25 @@ pub struct Form(pub T); impl FromRequest for Form where T: DeserializeOwned, - B: HttpBody + Send, + B: HttpBody + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = FormRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: Request, state: &S) -> Result { if req.method() == Method::GET { let query = req.uri().query().unwrap_or_default(); let value = serde_urlencoded::from_str(query) .map_err(FailedToDeserializeQueryString::__private_new)?; Ok(Form(value)) } else { - if !has_content_type(req, &mime::APPLICATION_WWW_FORM_URLENCODED) { + if !has_content_type(req.headers(), &mime::APPLICATION_WWW_FORM_URLENCODED) { return Err(InvalidFormContentType.into()); } - let bytes = Bytes::from_request(req).await?; + let bytes = Bytes::from_request(req, state).await?; let value = serde_urlencoded::from_bytes(&bytes) .map_err(FailedToDeserializeQueryString::__private_new)?; @@ -114,7 +114,6 @@ impl Deref for Form { mod tests { use super::*; use crate::body::{Empty, Full}; - use crate::extract::RequestParts; use http::Request; use serde::{Deserialize, Serialize}; use std::fmt::Debug; @@ -130,8 +129,7 @@ mod tests { .uri(uri.as_ref()) .body(Empty::::new()) .unwrap(); - let mut req = RequestParts::new(req); - assert_eq!(Form::::from_request(&mut req).await.unwrap().0, value); + assert_eq!(Form::::from_request(req, &()).await.unwrap().0, value); } async fn check_body(value: T) { @@ -146,8 +144,7 @@ mod tests { serde_urlencoded::to_string(&value).unwrap().into(), )) .unwrap(); - let mut req = RequestParts::new(req); - assert_eq!(Form::::from_request(&mut req).await.unwrap().0, value); + assert_eq!(Form::::from_request(req, &()).await.unwrap().0, value); } #[tokio::test] @@ -216,9 +213,8 @@ mod tests { .into(), )) .unwrap(); - let mut req = RequestParts::new(req); assert!(matches!( - Form::::from_request(&mut req) + Form::::from_request(req, &()) .await .unwrap_err(), FormRejection::InvalidFormContentType(InvalidFormContentType) diff --git a/axum/src/handler/into_service.rs b/axum/src/handler/into_service.rs index b902eb909f..2bb7aa043f 100644 --- a/axum/src/handler/into_service.rs +++ b/axum/src/handler/into_service.rs @@ -88,7 +88,7 @@ where use futures_util::future::FutureExt; let handler = self.handler.clone(); - let future = Handler::call(handler, Arc::clone(&self.state), req); + let future = Handler::call(handler, req, Arc::clone(&self.state)); let future = future.map(Ok as _); super::future::IntoServiceFuture::new(future) diff --git a/axum/src/handler/into_service_state_in_extension.rs b/axum/src/handler/into_service_state_in_extension.rs index 3c59d043dd..0070d0a156 100644 --- a/axum/src/handler/into_service_state_in_extension.rs +++ b/axum/src/handler/into_service_state_in_extension.rs @@ -78,7 +78,7 @@ where .expect("state extension missing. This is a bug in axum, please file an issue"); let handler = self.handler.clone(); - let future = Handler::call(handler, state, req); + let future = Handler::call(handler, req, state); let future = future.map(Ok as _); super::future::IntoServiceFuture::new(future) diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index 80514cf01c..0471122c5e 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -37,7 +37,7 @@ use crate::{ body::Body, - extract::{connect_info::IntoMakeServiceWithConnectInfo, FromRequest, RequestParts}, + extract::{connect_info::IntoMakeServiceWithConnectInfo, FromRequest, FromRequestParts}, response::{IntoResponse, Response}, routing::IntoMakeService, }; @@ -95,12 +95,12 @@ pub use self::{into_service::IntoService, with_state::WithState}; /// {} /// ``` #[doc = include_str!("../docs/debugging_handler_type_errors.md")] -pub trait Handler: Clone + Send + Sized + 'static { +pub trait Handler: Clone + Send + Sized + 'static { /// The type of future calling this handler returns. type Future: Future + Send + 'static; /// Call the handler with the given request. - fn call(self, state: Arc, req: Request) -> Self::Future; + fn call(self, req: Request, state: Arc) -> Self::Future; /// Apply a [`tower::Layer`] to the handler. /// @@ -162,7 +162,7 @@ pub trait Handler: Clone + Send + Sized + 'static { } } -impl Handler<(), S, B> for F +impl Handler<((),), S, B> for F where F: FnOnce() -> Fut + Clone + Send + 'static, Fut: Future + Send, @@ -171,37 +171,48 @@ where { type Future = Pin + Send>>; - fn call(self, _state: Arc, _req: Request) -> Self::Future { + fn call(self, _req: Request, _state: Arc) -> Self::Future { Box::pin(async move { self().await.into_response() }) } } macro_rules! impl_handler { - ( $($ty:ident),* $(,)? ) => { - #[allow(non_snake_case)] - impl Handler<($($ty,)*), S, B> for F + ( + [$($ty:ident),*], $last:ident + ) => { + #[allow(non_snake_case, unused_mut)] + impl Handler<(M, $($ty,)* $last,), S, B> for F where - F: FnOnce($($ty,)*) -> Fut + Clone + Send + 'static, + F: FnOnce($($ty,)* $last,) -> Fut + Clone + Send + 'static, Fut: Future + Send, B: Send + 'static, S: Send + Sync + 'static, Res: IntoResponse, - $( $ty: FromRequest + Send,)* + $( $ty: FromRequestParts + Send, )* + $last: FromRequest + Send, { type Future = Pin + Send>>; - fn call(self, state: Arc, req: Request) -> Self::Future { + fn call(self, req: Request, state: Arc) -> Self::Future { Box::pin(async move { - let mut req = RequestParts::with_state_arc(state, req); + let (mut parts, body) = req.into_parts(); + let state = &state; $( - let $ty = match $ty::from_request(&mut req).await { + let $ty = match $ty::from_request_parts(&mut parts, state).await { Ok(value) => value, Err(rejection) => return rejection.into_response(), }; )* - let res = self($($ty,)*).await; + let req = Request::from_parts(parts, body); + + let $last = match $last::from_request(req, state).await { + Ok(value) => value, + Err(rejection) => return rejection.into_response(), + }; + + let res = self($($ty,)* $last,).await; res.into_response() }) @@ -210,7 +221,31 @@ macro_rules! impl_handler { }; } -all_the_tuples!(impl_handler); +impl_handler!([], T1); +impl_handler!([T1], T2); +impl_handler!([T1, T2], T3); +impl_handler!([T1, T2, T3], T4); +impl_handler!([T1, T2, T3, T4], T5); +impl_handler!([T1, T2, T3, T4, T5], T6); +impl_handler!([T1, T2, T3, T4, T5, T6], T7); +impl_handler!([T1, T2, T3, T4, T5, T6, T7], T8); +impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8], T9); +impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9], T10); +impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10], T11); +impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11], T12); +impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12], T13); +impl_handler!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13], + T14 +); +impl_handler!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14], + T15 +); +impl_handler!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15], + T16 +); /// A [`Service`] created from a [`Handler`] by applying a Tower middleware. /// @@ -259,7 +294,7 @@ where { type Future = future::LayeredFuture; - fn call(self, state: Arc, req: Request) -> Self::Future { + fn call(self, req: Request, state: Arc) -> Self::Future { use futures_util::future::{FutureExt, Map}; let svc = self.handler.with_state_arc(state); diff --git a/axum/src/json.rs b/axum/src/json.rs index e60edf803a..d49630fa54 100644 --- a/axum/src/json.rs +++ b/axum/src/json.rs @@ -1,14 +1,14 @@ use crate::{ body::{Bytes, HttpBody}, - extract::{rejection::*, FromRequest, RequestParts}, + extract::{rejection::*, FromRequest}, BoxError, }; use async_trait::async_trait; use axum_core::response::{IntoResponse, Response}; use bytes::{BufMut, BytesMut}; use http::{ - header::{self, HeaderValue}, - StatusCode, + header::{self, HeaderMap, HeaderValue}, + Request, StatusCode, }; use serde::{de::DeserializeOwned, Serialize}; use std::ops::{Deref, DerefMut}; @@ -97,16 +97,16 @@ pub struct Json(pub T); impl FromRequest for Json where T: DeserializeOwned, - B: HttpBody + Send, + B: HttpBody + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = JsonRejection; - async fn from_request(req: &mut RequestParts) -> Result { - if json_content_type(req) { - let bytes = Bytes::from_request(req).await?; + async fn from_request(req: Request, state: &S) -> Result { + if json_content_type(req.headers()) { + let bytes = Bytes::from_request(req, state).await?; let value = match serde_json::from_slice(&bytes) { Ok(value) => value, @@ -137,8 +137,8 @@ where } } -fn json_content_type(req: &RequestParts) -> bool { - let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { +fn json_content_type(headers: &HeaderMap) -> bool { + let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) { content_type } else { return false; diff --git a/axum/src/lib.rs b/axum/src/lib.rs index a9b1bfb144..df51086828 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -93,8 +93,8 @@ //! //! # Extractors //! -//! An extractor is a type that implements [`FromRequest`]. Extractors is how -//! you pick apart the incoming request to get the parts your handler needs. +//! An extractor is a type that implements [`FromRequest`] or [`FromRequestParts`]. Extractors is +//! how you pick apart the incoming request to get the parts your handler needs. //! //! ```rust //! use axum::extract::{Path, Query, Json}; @@ -302,9 +302,10 @@ //! //! # Building integrations for axum //! -//! Libraries authors that want to provide [`FromRequest`] or [`IntoResponse`] implementations -//! should depend on the [`axum-core`] crate, instead of `axum` if possible. [`axum-core`] contains -//! core types and traits and is less likely to receive breaking changes. +//! Libraries authors that want to provide [`FromRequest`], [`FromRequestParts`], or +//! [`IntoResponse`] implementations should depend on the [`axum-core`] crate, instead of `axum` if +//! possible. [`axum-core`] contains core types and traits and is less likely to receive breaking +//! changes. //! //! # Required dependencies //! @@ -376,6 +377,7 @@ //! [tower-guides]: https://github.com/tower-rs/tower/tree/master/guides //! [`Uuid`]: https://docs.rs/uuid/latest/uuid/ //! [`FromRequest`]: crate::extract::FromRequest +//! [`FromRequestParts`]: crate::extract::FromRequestParts //! [`HeaderMap`]: http::header::HeaderMap //! [`Request`]: http::Request //! [customize-extractor-error]: https://github.com/tokio-rs/axum/blob/main/examples/customize-extractor-error/src/main.rs diff --git a/axum/src/middleware/from_extractor.rs b/axum/src/middleware/from_extractor.rs index f9391731e1..042c872068 100644 --- a/axum/src/middleware/from_extractor.rs +++ b/axum/src/middleware/from_extractor.rs @@ -1,5 +1,5 @@ use crate::{ - extract::{FromRequest, RequestParts}, + extract::FromRequestParts, response::{IntoResponse, Response}, }; use futures_util::{future::BoxFuture, ready}; @@ -33,28 +33,27 @@ use tower_service::Service; /// /// ```rust /// use axum::{ -/// extract::{FromRequest, RequestParts}, +/// extract::FromRequestParts, /// middleware::from_extractor, /// routing::{get, post}, /// Router, +/// http::{header, StatusCode, request::Parts}, /// }; -/// use http::{header, StatusCode}; /// use async_trait::async_trait; /// /// // An extractor that performs authorization. /// struct RequireAuth; /// /// #[async_trait] -/// impl FromRequest for RequireAuth +/// impl FromRequestParts for RequireAuth /// where -/// B: Send, /// S: Send + Sync, /// { /// type Rejection = StatusCode; /// -/// async fn from_request(req: &mut RequestParts) -> Result { -/// let auth_header = req -/// .headers() +/// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { +/// let auth_header = parts +/// .headers /// .get(header::AUTHORIZATION) /// .and_then(|value| value.to_str().ok()); /// @@ -169,7 +168,7 @@ where impl Service> for FromExtractor where - E: FromRequest<(), B> + 'static, + E: FromRequestParts<()> + 'static, B: Default + Send + 'static, S: Service> + Clone, S::Response: IntoResponse, @@ -185,8 +184,9 @@ where fn call(&mut self, req: Request) -> Self::Future { let extract_future = Box::pin(async move { - let mut req = RequestParts::new(req); - let extracted = E::from_request(&mut req).await; + let (mut parts, body) = req.into_parts(); + let extracted = E::from_request_parts(&mut parts, &()).await; + let req = Request::from_parts(parts, body); (req, extracted) }); @@ -204,7 +204,7 @@ pin_project! { #[allow(missing_debug_implementations)] pub struct ResponseFuture where - E: FromRequest<(), B>, + E: FromRequestParts<()>, S: Service>, { #[pin] @@ -217,11 +217,11 @@ pin_project! { #[project = StateProj] enum State where - E: FromRequest<(), B>, + E: FromRequestParts<()>, S: Service>, { Extracting { - future: BoxFuture<'static, (RequestParts<(), B>, Result)>, + future: BoxFuture<'static, (Request, Result)>, }, Call { #[pin] future: S::Future }, } @@ -229,7 +229,7 @@ pin_project! { impl Future for ResponseFuture where - E: FromRequest<(), B>, + E: FromRequestParts<()>, S: Service>, S::Response: IntoResponse, B: Default, @@ -247,7 +247,6 @@ where match extracted { Ok(_) => { let mut svc = this.svc.take().expect("future polled after completion"); - let req = req.try_into_request().unwrap_or_default(); let future = svc.call(req); State::Call { future } } @@ -273,23 +272,25 @@ where mod tests { use super::*; use crate::{handler::Handler, routing::get, test_helpers::*, Router}; - use http::{header, StatusCode}; + use http::{header, request::Parts, StatusCode}; #[tokio::test] async fn test_from_extractor() { struct RequireAuth; #[async_trait::async_trait] - impl FromRequest for RequireAuth + impl FromRequestParts for RequireAuth where - B: Send, S: Send + Sync, { type Rejection = StatusCode; - async fn from_request(req: &mut RequestParts) -> Result { - if let Some(auth) = req - .headers() + async fn from_request_parts( + parts: &mut Parts, + _state: &S, + ) -> Result { + if let Some(auth) = parts + .headers .get(header::AUTHORIZATION) .and_then(|v| v.to_str().ok()) { diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index 0d37c61863..9728a43a5c 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -1,5 +1,5 @@ use crate::response::{IntoResponse, Response}; -use axum_core::extract::{FromRequest, RequestParts}; +use axum_core::extract::{FromRequest, FromRequestParts}; use futures_util::future::BoxFuture; use http::Request; use std::{ @@ -249,12 +249,15 @@ where } macro_rules! impl_service { - ( $($ty:ident),* $(,)? ) => { - #[allow(non_snake_case)] - impl Service> for FromFn + ( + [$($ty:ident),*], $last:ident + ) => { + #[allow(non_snake_case, unused_mut)] + impl Service> for FromFn where - F: FnMut($($ty),*, Next) -> Fut + Clone + Send + 'static, - $( $ty: FromRequest<(), B> + Send, )* + F: FnMut($($ty,)* $last, Next) -> Fut + Clone + Send + 'static, + $( $ty: FromRequestParts<()> + Send, )* + $last: FromRequest<(), B> + Send, Fut: Future + Send + 'static, Out: IntoResponse + 'static, S: Service, Error = Infallible> @@ -280,21 +283,29 @@ macro_rules! impl_service { let mut f = self.f.clone(); let future = Box::pin(async move { - let mut parts = RequestParts::new(req); + let (mut parts, body) = req.into_parts(); + $( - let $ty = match $ty::from_request(&mut parts).await { + let $ty = match $ty::from_request_parts(&mut parts, &()).await { Ok(value) => value, Err(rejection) => return rejection.into_response(), }; )* + let req = Request::from_parts(parts, body); + + let $last = match $last::from_request(req, &()).await { + Ok(value) => value, + Err(rejection) => return rejection.into_response(), + }; + let inner = ServiceBuilder::new() .boxed_clone() .map_response(IntoResponse::into_response) .service(ready_inner); let next = Next { inner }; - f($($ty),*, next).await.into_response() + f($($ty,)* $last, next).await.into_response() }); ResponseFuture { @@ -305,7 +316,31 @@ macro_rules! impl_service { }; } -all_the_tuples!(impl_service); +impl_service!([], T1); +impl_service!([T1], T2); +impl_service!([T1, T2], T3); +impl_service!([T1, T2, T3], T4); +impl_service!([T1, T2, T3, T4], T5); +impl_service!([T1, T2, T3, T4, T5], T6); +impl_service!([T1, T2, T3, T4, T5, T6], T7); +impl_service!([T1, T2, T3, T4, T5, T6, T7], T8); +impl_service!([T1, T2, T3, T4, T5, T6, T7, T8], T9); +impl_service!([T1, T2, T3, T4, T5, T6, T7, T8, T9], T10); +impl_service!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10], T11); +impl_service!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11], T12); +impl_service!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12], T13); +impl_service!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13], + T14 +); +impl_service!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14], + T15 +); +impl_service!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15], + T16 +); impl fmt::Debug for FromFn where diff --git a/axum/src/typed_header.rs b/axum/src/typed_header.rs index 60ab204132..87805a8230 100644 --- a/axum/src/typed_header.rs +++ b/axum/src/typed_header.rs @@ -1,7 +1,8 @@ -use crate::extract::{FromRequest, RequestParts}; +use crate::extract::FromRequestParts; use async_trait::async_trait; use axum_core::response::{IntoResponse, IntoResponseParts, Response, ResponseParts}; use headers::HeaderMapExt; +use http::request::Parts; use std::{convert::Infallible, ops::Deref}; /// Extractor and response that works with typed header values from [`headers`]. @@ -52,16 +53,15 @@ use std::{convert::Infallible, ops::Deref}; pub struct TypedHeader(pub T); #[async_trait] -impl FromRequest for TypedHeader +impl FromRequestParts for TypedHeader where T: headers::Header, - B: Send, S: Send + Sync, { type Rejection = TypedHeaderRejection; - async fn from_request(req: &mut RequestParts) -> Result { - match req.headers().typed_try_get::() { + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + match parts.headers.typed_try_get::() { Ok(Some(value)) => Ok(Self(value)), Ok(None) => Err(TypedHeaderRejection { name: T::name(), diff --git a/examples/consume-body-in-extractor-or-middleware/src/main.rs b/examples/consume-body-in-extractor-or-middleware/src/main.rs index f9179bb42d..b3363d9976 100644 --- a/examples/consume-body-in-extractor-or-middleware/src/main.rs +++ b/examples/consume-body-in-extractor-or-middleware/src/main.rs @@ -7,7 +7,7 @@ use axum::{ async_trait, body::{self, BoxBody, Bytes, Full}, - extract::{FromRequest, RequestParts}, + extract::FromRequest, http::{Request, StatusCode}, middleware::{self, Next}, response::{IntoResponse, Response}, @@ -72,31 +72,28 @@ fn do_thing_with_request_body(bytes: Bytes) { tracing::debug!(body = ?bytes); } -async fn handler(_: PrintRequestBody, body: Bytes) { +async fn handler(BufferRequestBody(body): BufferRequestBody) { tracing::debug!(?body, "handler received body"); } // extractor that shows how to consume the request body upfront -struct PrintRequestBody; +struct BufferRequestBody(Bytes); +// we must implement `FromRequest` (and not `FromRequestParts`) to consume the body #[async_trait] -impl FromRequest for PrintRequestBody +impl FromRequest for BufferRequestBody where - S: Clone + Send + Sync, + S: Send + Sync, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { - let state = req.state().clone(); - - let request = Request::from_request(req) + async fn from_request(req: Request, state: &S) -> Result { + let body = Bytes::from_request(req, state) .await .map_err(|err| err.into_response())?; - let request = buffer_request_body(request).await?; - - *req = RequestParts::with_state(state, request); + do_thing_with_request_body(body.clone()); - Ok(Self) + Ok(Self(body)) } } diff --git a/examples/customize-extractor-error/src/custom_extractor.rs b/examples/customize-extractor-error/src/custom_extractor.rs index c7e9d0f954..10aa0f0047 100644 --- a/examples/customize-extractor-error/src/custom_extractor.rs +++ b/examples/customize-extractor-error/src/custom_extractor.rs @@ -4,15 +4,13 @@ //! and `async/await`. This means that you can create more powerful rejections //! - Boilerplate: Requires creating a new extractor for every custom rejection //! - Complexity: Manually implementing `FromRequest` results on more complex code -use axum::extract::MatchedPath; use axum::{ async_trait, - extract::{rejection::JsonRejection, FromRequest, RequestParts}, + extract::{rejection::JsonRejection, FromRequest, FromRequestParts, MatchedPath}, + http::Request, http::StatusCode, response::IntoResponse, - BoxError, }; -use serde::de::DeserializeOwned; use serde_json::{json, Value}; pub async fn handler(Json(value): Json) -> impl IntoResponse { @@ -25,31 +23,33 @@ pub struct Json(pub T); #[async_trait] impl FromRequest for Json where + axum::Json: FromRequest, S: Send + Sync, - // these trait bounds are copied from `impl FromRequest for axum::Json` - // `T: Send` is required to send this future across an await - T: DeserializeOwned + Send, - B: axum::body::HttpBody + Send, - B::Data: Send, - B::Error: Into, + B: Send + 'static, { type Rejection = (StatusCode, axum::Json); - async fn from_request(req: &mut RequestParts) -> Result { - match axum::Json::::from_request(req).await { + async fn from_request(req: Request, state: &S) -> Result { + let (mut parts, body) = req.into_parts(); + + // We can use other extractors to provide better rejection + // messages. For example, here we are using + // `axum::extract::MatchedPath` to provide a better error + // message + // + // Have to run that first since `Json::from_request` consumes + // the request + let path = MatchedPath::from_request_parts(&mut parts, state) + .await + .map(|path| path.as_str().to_owned()) + .ok(); + + let req = Request::from_parts(parts, body); + + match axum::Json::::from_request(req, state).await { Ok(value) => Ok(Self(value.0)), // convert the error from `axum::Json` into whatever we want Err(rejection) => { - let path = req - .extract::() - .await - .map(|x| x.as_str().to_owned()) - .ok(); - - // We can use other extractors to provide better rejection - // messages. For example, here we are using - // `axum::extract::MatchedPath` to provide a better error - // message let payload = json!({ "message": rejection.to_string(), "origin": "custom_extractor", diff --git a/examples/customize-extractor-error/src/derive_from_request.rs b/examples/customize-extractor-error/src/derive_from_request.rs index 2a1625c008..762d602e5f 100644 --- a/examples/customize-extractor-error/src/derive_from_request.rs +++ b/examples/customize-extractor-error/src/derive_from_request.rs @@ -47,7 +47,7 @@ impl From for ApiError { } } -// We implement `IntoResponse` so ApiError can be used as a response +// We implement `IntoResponse` so `ApiError` can be used as a response impl IntoResponse for ApiError { fn into_response(self) -> axum::response::Response { let payload = json!({ diff --git a/examples/customize-path-rejection/src/main.rs b/examples/customize-path-rejection/src/main.rs index c4923069e5..68baf8f879 100644 --- a/examples/customize-path-rejection/src/main.rs +++ b/examples/customize-path-rejection/src/main.rs @@ -6,8 +6,8 @@ use axum::{ async_trait, - extract::{path::ErrorKind, rejection::PathRejection, FromRequest, RequestParts}, - http::StatusCode, + extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts}, + http::{request::Parts, StatusCode}, response::IntoResponse, routing::get, Router, @@ -52,17 +52,16 @@ struct Params { struct Path(T); #[async_trait] -impl FromRequest for Path +impl FromRequestParts for Path where // these trait bounds are copied from `impl FromRequest for axum::extract::path::Path` T: DeserializeOwned + Send, - B: Send, S: Send + Sync, { type Rejection = (StatusCode, axum::Json); - async fn from_request(req: &mut RequestParts) -> Result { - match axum::extract::Path::::from_request(req).await { + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + match axum::extract::Path::::from_request_parts(parts, state).await { Ok(value) => Ok(Self(value.0)), Err(rejection) => { let (status, body) = match rejection { diff --git a/examples/error-handling-and-dependency-injection/src/main.rs b/examples/error-handling-and-dependency-injection/src/main.rs index 914ae18155..32f5704923 100644 --- a/examples/error-handling-and-dependency-injection/src/main.rs +++ b/examples/error-handling-and-dependency-injection/src/main.rs @@ -65,8 +65,8 @@ async fn users_show( /// Handler for `POST /users`. async fn users_create( - Json(params): Json, State(user_repo): State, + Json(params): Json, ) -> Result, AppError> { let user = user_repo.create(params).await?; diff --git a/examples/jwt/src/main.rs b/examples/jwt/src/main.rs index 84d51f4409..3cef04c74b 100644 --- a/examples/jwt/src/main.rs +++ b/examples/jwt/src/main.rs @@ -8,9 +8,9 @@ use axum::{ async_trait, - extract::{FromRequest, RequestParts, TypedHeader}, + extract::{FromRequestParts, TypedHeader}, headers::{authorization::Bearer, Authorization}, - http::StatusCode, + http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, routing::{get, post}, Json, Router, @@ -122,17 +122,16 @@ impl AuthBody { } #[async_trait] -impl FromRequest for Claims +impl FromRequestParts for Claims where S: Send + Sync, - B: Send, { type Rejection = AuthError; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { // Extract the token from the authorization header let TypedHeader(Authorization(bearer)) = - TypedHeader::>::from_request(req) + TypedHeader::>::from_request_parts(parts, state) .await .map_err(|_| AuthError::InvalidToken)?; // Decode the user data diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs index c65ee75a0f..4fb9e45bd0 100644 --- a/examples/key-value-store/src/main.rs +++ b/examples/key-value-store/src/main.rs @@ -96,8 +96,8 @@ async fn kv_get( async fn kv_set( Path(key): Path, - ContentLengthLimit(bytes): ContentLengthLimit, // ~5mb State(state): State, + ContentLengthLimit(bytes): ContentLengthLimit, // ~5mb ) { state.write().unwrap().db.insert(key, bytes); } diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index a61113b97d..079c65eb15 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -12,15 +12,14 @@ use async_session::{MemoryStore, Session, SessionStore}; use axum::{ async_trait, extract::{ - rejection::TypedHeaderRejectionReason, FromRef, FromRequest, Query, RequestParts, State, - TypedHeader, + rejection::TypedHeaderRejectionReason, FromRef, FromRequestParts, Query, State, TypedHeader, }, http::{header::SET_COOKIE, HeaderMap}, response::{IntoResponse, Redirect, Response}, routing::get, Router, }; -use http::header; +use http::{header, request::Parts}; use oauth2::{ basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl, @@ -139,7 +138,7 @@ async fn discord_auth(State(client): State) -> impl IntoResponse { .url(); // Redirect to Discord's oauth service - Redirect::to(&auth_url.to_string()) + Redirect::to(auth_url.as_ref()) } // Valid user session required. If there is none, redirect to the auth page @@ -224,17 +223,18 @@ impl IntoResponse for AuthRedirect { } #[async_trait] -impl FromRequest for User +impl FromRequestParts for User where - B: Send, + MemoryStore: FromRef, + S: Send + Sync, { // If anything goes wrong or no session is found, redirect to the auth page type Rejection = AuthRedirect; - async fn from_request(req: &mut RequestParts) -> Result { - let store = req.state().clone().store; + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let store = MemoryStore::from_ref(state); - let cookies = TypedHeader::::from_request(req) + let cookies = TypedHeader::::from_request_parts(parts, state) .await .map_err(|e| match *e.name() { header::COOKIE => match e.reason() { diff --git a/examples/sessions/src/main.rs b/examples/sessions/src/main.rs index cd0d41a1f6..05d676faf0 100644 --- a/examples/sessions/src/main.rs +++ b/examples/sessions/src/main.rs @@ -7,11 +7,12 @@ use async_session::{MemoryStore, Session, SessionStore as _}; use axum::{ async_trait, - extract::{FromRequest, RequestParts, TypedHeader}, + extract::{FromRef, FromRequestParts, TypedHeader}, headers::Cookie, http::{ self, header::{HeaderMap, HeaderValue}, + request::Parts, StatusCode, }, response::IntoResponse, @@ -80,16 +81,19 @@ enum UserIdFromSession { } #[async_trait] -impl FromRequest for UserIdFromSession +impl FromRequestParts for UserIdFromSession where - B: Send, + MemoryStore: FromRef, + S: Send + Sync, { type Rejection = (StatusCode, &'static str); - async fn from_request(req: &mut RequestParts) -> Result { - let store = req.state().clone(); + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let store = MemoryStore::from_ref(state); - let cookie = req.extract::>>().await.unwrap(); + let cookie = Option::>::from_request_parts(parts, state) + .await + .unwrap(); let session_cookie = cookie .as_ref() diff --git a/examples/sqlx-postgres/src/main.rs b/examples/sqlx-postgres/src/main.rs index 6548cdeb97..9ba41ed804 100644 --- a/examples/sqlx-postgres/src/main.rs +++ b/examples/sqlx-postgres/src/main.rs @@ -15,8 +15,8 @@ use axum::{ async_trait, - extract::{FromRequest, RequestParts, State}, - http::StatusCode, + extract::{FromRef, FromRequestParts, State}, + http::{request::Parts, StatusCode}, routing::get, Router, }; @@ -75,14 +75,15 @@ async fn using_connection_pool_extractor( struct DatabaseConnection(sqlx::pool::PoolConnection); #[async_trait] -impl FromRequest for DatabaseConnection +impl FromRequestParts for DatabaseConnection where - B: Send, + PgPool: FromRef, + S: Send + Sync, { type Rejection = (StatusCode, String); - async fn from_request(req: &mut RequestParts) -> Result { - let pool = req.state().clone(); + async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result { + let pool = PgPool::from_ref(state); let conn = pool.acquire().await.map_err(internal_error)?; diff --git a/examples/todos/src/main.rs b/examples/todos/src/main.rs index b82a308db4..f323ebb318 100644 --- a/examples/todos/src/main.rs +++ b/examples/todos/src/main.rs @@ -105,7 +105,7 @@ struct CreateTodo { text: String, } -async fn todos_create(Json(input): Json, State(db): State) -> impl IntoResponse { +async fn todos_create(State(db): State, Json(input): Json) -> impl IntoResponse { let todo = Todo { id: Uuid::new_v4(), text: input.text, @@ -125,8 +125,8 @@ struct UpdateTodo { async fn todos_update( Path(id): Path, - Json(input): Json, State(db): State, + Json(input): Json, ) -> Result { let mut todo = db .read() diff --git a/examples/tokio-postgres/src/main.rs b/examples/tokio-postgres/src/main.rs index e0c60453e3..5e60c2bb78 100644 --- a/examples/tokio-postgres/src/main.rs +++ b/examples/tokio-postgres/src/main.rs @@ -6,8 +6,8 @@ use axum::{ async_trait, - extract::{FromRequest, RequestParts, State}, - http::StatusCode, + extract::{FromRef, FromRequestParts, State}, + http::{request::Parts, StatusCode}, routing::get, Router, }; @@ -68,16 +68,15 @@ async fn using_connection_pool_extractor( struct DatabaseConnection(PooledConnection<'static, PostgresConnectionManager>); #[async_trait] -impl FromRequest for DatabaseConnection +impl FromRequestParts for DatabaseConnection where - B: Send, + ConnectionPool: FromRef, + S: Send + Sync, { type Rejection = (StatusCode, String); - async fn from_request( - req: &mut RequestParts, - ) -> Result { - let pool = req.state().clone(); + async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result { + let pool = ConnectionPool::from_ref(state); let conn = pool.get_owned().await.map_err(internal_error)?; diff --git a/examples/validator/src/main.rs b/examples/validator/src/main.rs index e5988de290..359be614d3 100644 --- a/examples/validator/src/main.rs +++ b/examples/validator/src/main.rs @@ -12,11 +12,11 @@ use async_trait::async_trait; use axum::{ - extract::{Form, FromRequest, RequestParts}, - http::StatusCode, + extract::{rejection::FormRejection, Form, FromRequest}, + http::{Request, StatusCode}, response::{Html, IntoResponse, Response}, routing::get, - BoxError, Router, + Router, }; use serde::{de::DeserializeOwned, Deserialize}; use std::net::SocketAddr; @@ -64,14 +64,13 @@ impl FromRequest for ValidatedForm where T: DeserializeOwned + Validate, S: Send + Sync, - B: http_body::Body + Send, - B::Data: Send, - B::Error: Into, + Form: FromRequest, + B: Send + 'static, { type Rejection = ServerError; - async fn from_request(req: &mut RequestParts) -> Result { - let Form(value) = Form::::from_request(req).await?; + async fn from_request(req: Request, state: &S) -> Result { + let Form(value) = Form::::from_request(req, state).await?; value.validate()?; Ok(ValidatedForm(value)) } @@ -83,7 +82,7 @@ pub enum ServerError { ValidationError(#[from] validator::ValidationErrors), #[error(transparent)] - AxumFormRejection(#[from] axum::extract::rejection::FormRejection), + AxumFormRejection(#[from] FormRejection), } impl IntoResponse for ServerError { diff --git a/examples/versioning/src/main.rs b/examples/versioning/src/main.rs index 6b53f77e91..2f67e33501 100644 --- a/examples/versioning/src/main.rs +++ b/examples/versioning/src/main.rs @@ -6,8 +6,8 @@ use axum::{ async_trait, - extract::{FromRequest, Path, RequestParts}, - http::StatusCode, + extract::{FromRequestParts, Path}, + http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, routing::get, Router, @@ -48,15 +48,14 @@ enum Version { } #[async_trait] -impl FromRequest for Version +impl FromRequestParts for Version where - B: Send, S: Send + Sync, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { - let params = Path::>::from_request(req) + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let params = Path::>::from_request_parts(parts, state) .await .map_err(IntoResponse::into_response)?;