diff --git a/examples/consume-body-in-extractor-or-middleware/src/main.rs b/examples/consume-body-in-extractor-or-middleware/src/main.rs index f9179bb42d..b3363d9976 100644 --- a/examples/consume-body-in-extractor-or-middleware/src/main.rs +++ b/examples/consume-body-in-extractor-or-middleware/src/main.rs @@ -7,7 +7,7 @@ use axum::{ async_trait, body::{self, BoxBody, Bytes, Full}, - extract::{FromRequest, RequestParts}, + extract::FromRequest, http::{Request, StatusCode}, middleware::{self, Next}, response::{IntoResponse, Response}, @@ -72,31 +72,28 @@ fn do_thing_with_request_body(bytes: Bytes) { tracing::debug!(body = ?bytes); } -async fn handler(_: PrintRequestBody, body: Bytes) { +async fn handler(BufferRequestBody(body): BufferRequestBody) { tracing::debug!(?body, "handler received body"); } // extractor that shows how to consume the request body upfront -struct PrintRequestBody; +struct BufferRequestBody(Bytes); +// we must implement `FromRequest` (and not `FromRequestParts`) to consume the body #[async_trait] -impl FromRequest for PrintRequestBody +impl FromRequest for BufferRequestBody where - S: Clone + Send + Sync, + S: Send + Sync, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { - let state = req.state().clone(); - - let request = Request::from_request(req) + async fn from_request(req: Request, state: &S) -> Result { + let body = Bytes::from_request(req, state) .await .map_err(|err| err.into_response())?; - let request = buffer_request_body(request).await?; - - *req = RequestParts::with_state(state, request); + do_thing_with_request_body(body.clone()); - Ok(Self) + Ok(Self(body)) } } diff --git a/examples/customize-extractor-error/src/custom_extractor.rs b/examples/customize-extractor-error/src/custom_extractor.rs index c7e9d0f954..10aa0f0047 100644 --- a/examples/customize-extractor-error/src/custom_extractor.rs +++ b/examples/customize-extractor-error/src/custom_extractor.rs @@ -4,15 +4,13 @@ //! and `async/await`. This means that you can create more powerful rejections //! - Boilerplate: Requires creating a new extractor for every custom rejection //! - Complexity: Manually implementing `FromRequest` results on more complex code -use axum::extract::MatchedPath; use axum::{ async_trait, - extract::{rejection::JsonRejection, FromRequest, RequestParts}, + extract::{rejection::JsonRejection, FromRequest, FromRequestParts, MatchedPath}, + http::Request, http::StatusCode, response::IntoResponse, - BoxError, }; -use serde::de::DeserializeOwned; use serde_json::{json, Value}; pub async fn handler(Json(value): Json) -> impl IntoResponse { @@ -25,31 +23,33 @@ pub struct Json(pub T); #[async_trait] impl FromRequest for Json where + axum::Json: FromRequest, S: Send + Sync, - // these trait bounds are copied from `impl FromRequest for axum::Json` - // `T: Send` is required to send this future across an await - T: DeserializeOwned + Send, - B: axum::body::HttpBody + Send, - B::Data: Send, - B::Error: Into, + B: Send + 'static, { type Rejection = (StatusCode, axum::Json); - async fn from_request(req: &mut RequestParts) -> Result { - match axum::Json::::from_request(req).await { + async fn from_request(req: Request, state: &S) -> Result { + let (mut parts, body) = req.into_parts(); + + // We can use other extractors to provide better rejection + // messages. For example, here we are using + // `axum::extract::MatchedPath` to provide a better error + // message + // + // Have to run that first since `Json::from_request` consumes + // the request + let path = MatchedPath::from_request_parts(&mut parts, state) + .await + .map(|path| path.as_str().to_owned()) + .ok(); + + let req = Request::from_parts(parts, body); + + match axum::Json::::from_request(req, state).await { Ok(value) => Ok(Self(value.0)), // convert the error from `axum::Json` into whatever we want Err(rejection) => { - let path = req - .extract::() - .await - .map(|x| x.as_str().to_owned()) - .ok(); - - // We can use other extractors to provide better rejection - // messages. For example, here we are using - // `axum::extract::MatchedPath` to provide a better error - // message let payload = json!({ "message": rejection.to_string(), "origin": "custom_extractor", diff --git a/examples/customize-extractor-error/src/derive_from_request.rs b/examples/customize-extractor-error/src/derive_from_request.rs index 2a1625c008..762d602e5f 100644 --- a/examples/customize-extractor-error/src/derive_from_request.rs +++ b/examples/customize-extractor-error/src/derive_from_request.rs @@ -47,7 +47,7 @@ impl From for ApiError { } } -// We implement `IntoResponse` so ApiError can be used as a response +// We implement `IntoResponse` so `ApiError` can be used as a response impl IntoResponse for ApiError { fn into_response(self) -> axum::response::Response { let payload = json!({ diff --git a/examples/customize-path-rejection/src/main.rs b/examples/customize-path-rejection/src/main.rs index c4923069e5..68baf8f879 100644 --- a/examples/customize-path-rejection/src/main.rs +++ b/examples/customize-path-rejection/src/main.rs @@ -6,8 +6,8 @@ use axum::{ async_trait, - extract::{path::ErrorKind, rejection::PathRejection, FromRequest, RequestParts}, - http::StatusCode, + extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts}, + http::{request::Parts, StatusCode}, response::IntoResponse, routing::get, Router, @@ -52,17 +52,16 @@ struct Params { struct Path(T); #[async_trait] -impl FromRequest for Path +impl FromRequestParts for Path where // these trait bounds are copied from `impl FromRequest for axum::extract::path::Path` T: DeserializeOwned + Send, - B: Send, S: Send + Sync, { type Rejection = (StatusCode, axum::Json); - async fn from_request(req: &mut RequestParts) -> Result { - match axum::extract::Path::::from_request(req).await { + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + match axum::extract::Path::::from_request_parts(parts, state).await { Ok(value) => Ok(Self(value.0)), Err(rejection) => { let (status, body) = match rejection { diff --git a/examples/error-handling-and-dependency-injection/src/main.rs b/examples/error-handling-and-dependency-injection/src/main.rs index 914ae18155..32f5704923 100644 --- a/examples/error-handling-and-dependency-injection/src/main.rs +++ b/examples/error-handling-and-dependency-injection/src/main.rs @@ -65,8 +65,8 @@ async fn users_show( /// Handler for `POST /users`. async fn users_create( - Json(params): Json, State(user_repo): State, + Json(params): Json, ) -> Result, AppError> { let user = user_repo.create(params).await?; diff --git a/examples/jwt/src/main.rs b/examples/jwt/src/main.rs index 84d51f4409..3cef04c74b 100644 --- a/examples/jwt/src/main.rs +++ b/examples/jwt/src/main.rs @@ -8,9 +8,9 @@ use axum::{ async_trait, - extract::{FromRequest, RequestParts, TypedHeader}, + extract::{FromRequestParts, TypedHeader}, headers::{authorization::Bearer, Authorization}, - http::StatusCode, + http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, routing::{get, post}, Json, Router, @@ -122,17 +122,16 @@ impl AuthBody { } #[async_trait] -impl FromRequest for Claims +impl FromRequestParts for Claims where S: Send + Sync, - B: Send, { type Rejection = AuthError; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { // Extract the token from the authorization header let TypedHeader(Authorization(bearer)) = - TypedHeader::>::from_request(req) + TypedHeader::>::from_request_parts(parts, state) .await .map_err(|_| AuthError::InvalidToken)?; // Decode the user data diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs index c65ee75a0f..4fb9e45bd0 100644 --- a/examples/key-value-store/src/main.rs +++ b/examples/key-value-store/src/main.rs @@ -96,8 +96,8 @@ async fn kv_get( async fn kv_set( Path(key): Path, - ContentLengthLimit(bytes): ContentLengthLimit, // ~5mb State(state): State, + ContentLengthLimit(bytes): ContentLengthLimit, // ~5mb ) { state.write().unwrap().db.insert(key, bytes); } diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index a61113b97d..079c65eb15 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -12,15 +12,14 @@ use async_session::{MemoryStore, Session, SessionStore}; use axum::{ async_trait, extract::{ - rejection::TypedHeaderRejectionReason, FromRef, FromRequest, Query, RequestParts, State, - TypedHeader, + rejection::TypedHeaderRejectionReason, FromRef, FromRequestParts, Query, State, TypedHeader, }, http::{header::SET_COOKIE, HeaderMap}, response::{IntoResponse, Redirect, Response}, routing::get, Router, }; -use http::header; +use http::{header, request::Parts}; use oauth2::{ basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl, @@ -139,7 +138,7 @@ async fn discord_auth(State(client): State) -> impl IntoResponse { .url(); // Redirect to Discord's oauth service - Redirect::to(&auth_url.to_string()) + Redirect::to(auth_url.as_ref()) } // Valid user session required. If there is none, redirect to the auth page @@ -224,17 +223,18 @@ impl IntoResponse for AuthRedirect { } #[async_trait] -impl FromRequest for User +impl FromRequestParts for User where - B: Send, + MemoryStore: FromRef, + S: Send + Sync, { // If anything goes wrong or no session is found, redirect to the auth page type Rejection = AuthRedirect; - async fn from_request(req: &mut RequestParts) -> Result { - let store = req.state().clone().store; + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let store = MemoryStore::from_ref(state); - let cookies = TypedHeader::::from_request(req) + let cookies = TypedHeader::::from_request_parts(parts, state) .await .map_err(|e| match *e.name() { header::COOKIE => match e.reason() { diff --git a/examples/sessions/src/main.rs b/examples/sessions/src/main.rs index cd0d41a1f6..05d676faf0 100644 --- a/examples/sessions/src/main.rs +++ b/examples/sessions/src/main.rs @@ -7,11 +7,12 @@ use async_session::{MemoryStore, Session, SessionStore as _}; use axum::{ async_trait, - extract::{FromRequest, RequestParts, TypedHeader}, + extract::{FromRef, FromRequestParts, TypedHeader}, headers::Cookie, http::{ self, header::{HeaderMap, HeaderValue}, + request::Parts, StatusCode, }, response::IntoResponse, @@ -80,16 +81,19 @@ enum UserIdFromSession { } #[async_trait] -impl FromRequest for UserIdFromSession +impl FromRequestParts for UserIdFromSession where - B: Send, + MemoryStore: FromRef, + S: Send + Sync, { type Rejection = (StatusCode, &'static str); - async fn from_request(req: &mut RequestParts) -> Result { - let store = req.state().clone(); + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let store = MemoryStore::from_ref(state); - let cookie = req.extract::>>().await.unwrap(); + let cookie = Option::>::from_request_parts(parts, state) + .await + .unwrap(); let session_cookie = cookie .as_ref() diff --git a/examples/sqlx-postgres/src/main.rs b/examples/sqlx-postgres/src/main.rs index 6548cdeb97..9ba41ed804 100644 --- a/examples/sqlx-postgres/src/main.rs +++ b/examples/sqlx-postgres/src/main.rs @@ -15,8 +15,8 @@ use axum::{ async_trait, - extract::{FromRequest, RequestParts, State}, - http::StatusCode, + extract::{FromRef, FromRequestParts, State}, + http::{request::Parts, StatusCode}, routing::get, Router, }; @@ -75,14 +75,15 @@ async fn using_connection_pool_extractor( struct DatabaseConnection(sqlx::pool::PoolConnection); #[async_trait] -impl FromRequest for DatabaseConnection +impl FromRequestParts for DatabaseConnection where - B: Send, + PgPool: FromRef, + S: Send + Sync, { type Rejection = (StatusCode, String); - async fn from_request(req: &mut RequestParts) -> Result { - let pool = req.state().clone(); + async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result { + let pool = PgPool::from_ref(state); let conn = pool.acquire().await.map_err(internal_error)?; diff --git a/examples/todos/src/main.rs b/examples/todos/src/main.rs index b82a308db4..f323ebb318 100644 --- a/examples/todos/src/main.rs +++ b/examples/todos/src/main.rs @@ -105,7 +105,7 @@ struct CreateTodo { text: String, } -async fn todos_create(Json(input): Json, State(db): State) -> impl IntoResponse { +async fn todos_create(State(db): State, Json(input): Json) -> impl IntoResponse { let todo = Todo { id: Uuid::new_v4(), text: input.text, @@ -125,8 +125,8 @@ struct UpdateTodo { async fn todos_update( Path(id): Path, - Json(input): Json, State(db): State, + Json(input): Json, ) -> Result { let mut todo = db .read() diff --git a/examples/tokio-postgres/src/main.rs b/examples/tokio-postgres/src/main.rs index e0c60453e3..5e60c2bb78 100644 --- a/examples/tokio-postgres/src/main.rs +++ b/examples/tokio-postgres/src/main.rs @@ -6,8 +6,8 @@ use axum::{ async_trait, - extract::{FromRequest, RequestParts, State}, - http::StatusCode, + extract::{FromRef, FromRequestParts, State}, + http::{request::Parts, StatusCode}, routing::get, Router, }; @@ -68,16 +68,15 @@ async fn using_connection_pool_extractor( struct DatabaseConnection(PooledConnection<'static, PostgresConnectionManager>); #[async_trait] -impl FromRequest for DatabaseConnection +impl FromRequestParts for DatabaseConnection where - B: Send, + ConnectionPool: FromRef, + S: Send + Sync, { type Rejection = (StatusCode, String); - async fn from_request( - req: &mut RequestParts, - ) -> Result { - let pool = req.state().clone(); + async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result { + let pool = ConnectionPool::from_ref(state); let conn = pool.get_owned().await.map_err(internal_error)?; diff --git a/examples/validator/src/main.rs b/examples/validator/src/main.rs index e5988de290..359be614d3 100644 --- a/examples/validator/src/main.rs +++ b/examples/validator/src/main.rs @@ -12,11 +12,11 @@ use async_trait::async_trait; use axum::{ - extract::{Form, FromRequest, RequestParts}, - http::StatusCode, + extract::{rejection::FormRejection, Form, FromRequest}, + http::{Request, StatusCode}, response::{Html, IntoResponse, Response}, routing::get, - BoxError, Router, + Router, }; use serde::{de::DeserializeOwned, Deserialize}; use std::net::SocketAddr; @@ -64,14 +64,13 @@ impl FromRequest for ValidatedForm where T: DeserializeOwned + Validate, S: Send + Sync, - B: http_body::Body + Send, - B::Data: Send, - B::Error: Into, + Form: FromRequest, + B: Send + 'static, { type Rejection = ServerError; - async fn from_request(req: &mut RequestParts) -> Result { - let Form(value) = Form::::from_request(req).await?; + async fn from_request(req: Request, state: &S) -> Result { + let Form(value) = Form::::from_request(req, state).await?; value.validate()?; Ok(ValidatedForm(value)) } @@ -83,7 +82,7 @@ pub enum ServerError { ValidationError(#[from] validator::ValidationErrors), #[error(transparent)] - AxumFormRejection(#[from] axum::extract::rejection::FormRejection), + AxumFormRejection(#[from] FormRejection), } impl IntoResponse for ServerError { diff --git a/examples/versioning/src/main.rs b/examples/versioning/src/main.rs index 6b53f77e91..2f67e33501 100644 --- a/examples/versioning/src/main.rs +++ b/examples/versioning/src/main.rs @@ -6,8 +6,8 @@ use axum::{ async_trait, - extract::{FromRequest, Path, RequestParts}, - http::StatusCode, + extract::{FromRequestParts, Path}, + http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, routing::get, Router, @@ -48,15 +48,14 @@ enum Version { } #[async_trait] -impl FromRequest for Version +impl FromRequestParts for Version where - B: Send, S: Send + Sync, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { - let params = Path::>::from_request(req) + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let params = Path::>::from_request_parts(parts, state) .await .map_err(IntoResponse::into_response)?;