diff --git a/axum-core/src/extract/tuple.rs b/axum-core/src/extract/tuple.rs index 8e0518c6fc..5a42320f23 100644 --- a/axum-core/src/extract/tuple.rs +++ b/axum-core/src/extract/tuple.rs @@ -1,78 +1,142 @@ -// TODO: Bring this back, and make sure it works with all the different possible permutations -// See https://github.com/tokio-rs/axum/pull/1277#issuecomment-1220358420 for details - -// use super::{FromRequest, FromRequestParts}; -// use crate::response::{IntoResponse, Response}; -// use async_trait::async_trait; -// use http::request::{Parts, Request}; -// use std::convert::Infallible; - -// #[async_trait] -// impl FromRequestParts for () -// where -// S: Send + Sync, -// { -// type Rejection = Infallible; - -// async fn from_request_parts(_: &mut Parts, _: &S) -> Result<(), Self::Rejection> { -// Ok(()) -// } -// } - -// macro_rules! impl_from_request { -// ( -// [$($ty:ident),*], $last:ident -// ) => { -// #[async_trait] -// #[allow(non_snake_case, unused_mut, unused_variables)] -// impl FromRequest for ($($ty,)* $last,) -// where -// $( $ty: FromRequestParts + Send, )* -// $last: FromRequest + Send, -// B: Send + 'static, -// S: Send + Sync, -// { -// type Rejection = Response; - -// async fn from_request(req: Request, state: &S) -> Result { -// let (mut parts, body) = req.into_parts(); - -// $( -// let $ty = $ty::from_request_parts(&mut parts, state).await.map_err(|err| err.into_response())?; -// )* - -// let req = Request::from_parts(parts, body); - -// let $last = $last::from_request(req, state).await.map_err(|err| err.into_response())?; - -// Ok(($($ty,)* $last,)) -// } -// } -// }; -// } - -// impl_from_request!([], T1); -// impl_from_request!([T1], T2); -// impl_from_request!([T1, T2], T3); -// impl_from_request!([T1, T2, T3], T4); -// impl_from_request!([T1, T2, T3, T4], T5); -// impl_from_request!([T1, T2, T3, T4, T5], T6); -// impl_from_request!([T1, T2, T3, T4, T5, T6], T7); -// impl_from_request!([T1, T2, T3, T4, T5, T6, T7], T8); -// impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8], T9); -// impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9], T10); -// impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10], T11); -// impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11], T12); -// impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12], T13); -// impl_from_request!( -// [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13], -// T14 -// ); -// impl_from_request!( -// [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14], -// T15 -// ); -// impl_from_request!( -// [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15], -// T16 -// ); +use super::{FromRequest, FromRequestParts}; +use crate::response::{IntoResponse, Response}; +use async_trait::async_trait; +use http::request::{Parts, Request}; +use std::convert::Infallible; + +#[async_trait] +impl FromRequestParts for () +where + S: Send + Sync, +{ + type Rejection = Infallible; + + async fn from_request_parts(_: &mut Parts, _: &S) -> Result<(), Self::Rejection> { + Ok(()) + } +} + +mod private { + #[derive(Debug, Clone, Copy)] + pub enum TupleOnce {} +} + +macro_rules! impl_from_request { + ( + [$($ty:ident),*], $last:ident + ) => { + #[async_trait] + #[allow(non_snake_case, unused_mut, unused_variables)] + impl FromRequestParts for ($($ty,)* $last,) + where + $( $ty: FromRequestParts + Send, )* + $last: FromRequestParts + Send, + S: Send + Sync, + { + type Rejection = Response; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + $( + let $ty = $ty::from_request_parts(parts, state) + .await + .map_err(|err| err.into_response())?; + )* + let $last = $last::from_request_parts(parts, state) + .await + .map_err(|err| err.into_response())?; + + Ok(($($ty,)* $last,)) + } + } + + // This impl must not be generic over M, otherwise it would conflict with the blanket + // implementation of `FromRequest` for `T: FromRequestParts`. + #[async_trait] + #[allow(non_snake_case, unused_mut, unused_variables)] + impl FromRequest for ($($ty,)* $last,) + where + $( $ty: FromRequestParts + Send, )* + $last: FromRequest + Send, + B: Send + 'static, + S: Send + Sync, + { + type Rejection = Response; + + async fn from_request(req: Request, state: &S) -> Result { + let (mut parts, body) = req.into_parts(); + + $( + let $ty = $ty::from_request_parts(&mut parts, state).await.map_err(|err| err.into_response())?; + )* + + let req = Request::from_parts(parts, body); + + let $last = $last::from_request(req, state).await.map_err(|err| err.into_response())?; + + Ok(($($ty,)* $last,)) + } + } + }; +} + +impl_from_request!([], T1); +impl_from_request!([T1], T2); +impl_from_request!([T1, T2], T3); +impl_from_request!([T1, T2, T3], T4); +impl_from_request!([T1, T2, T3, T4], T5); +impl_from_request!([T1, T2, T3, T4, T5], T6); +impl_from_request!([T1, T2, T3, T4, T5, T6], T7); +impl_from_request!([T1, T2, T3, T4, T5, T6, T7], T8); +impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8], T9); +impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9], T10); +impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10], T11); +impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11], T12); +impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12], T13); +impl_from_request!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13], + T14 +); +impl_from_request!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14], + T15 +); +impl_from_request!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15], + T16 +); + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use http::Method; + + use crate::extract::{FromRequest, FromRequestParts}; + + fn assert_from_request() + where + T: FromRequest<(), http_body::Full, M>, + { + } + + fn assert_from_request_parts>() {} + + #[test] + fn unit() { + assert_from_request_parts::<()>(); + assert_from_request::<_, ()>(); + } + + #[test] + fn tuple_of_one() { + assert_from_request_parts::<(Method,)>(); + assert_from_request::<_, (Method,)>(); + assert_from_request::<_, (Bytes,)>(); + } + + #[test] + fn tuple_of_two() { + assert_from_request_parts::<((), ())>(); + assert_from_request::<_, ((), ())>(); + assert_from_request::<_, (Method, Bytes)>(); + } +}