diff --git a/axum-core/src/extract/from_ref.rs b/axum-core/src/extract/from_ref.rs new file mode 100644 index 0000000000..c0124140e5 --- /dev/null +++ b/axum-core/src/extract/from_ref.rs @@ -0,0 +1,23 @@ +/// Used to do reference-to-value conversions thus not consuming the input value. +/// +/// This is mainly used with [`State`] to extract "substates" from a reference to main application +/// state. +/// +/// See [`State`] for more details on how library authors should use this trait. +/// +/// [`State`]: https://docs.rs/axum/0.6/axum/extract/struct.State.html +// NOTE: This trait is defined in axum-core, even though it is mainly used with `State` which is +// defined in axum. That allows crate authors to use it when implementing extractors. +pub trait FromRef { + /// Converts to this type from a reference to the input type. + fn from_ref(input: &T) -> Self; +} + +impl FromRef for T +where + T: Clone, +{ + fn from_ref(input: &T) -> Self { + input.clone() + } +} diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index f9de0f399d..ade7bf0345 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -12,9 +12,12 @@ use std::convert::Infallible; pub mod rejection; +mod from_ref; mod request_parts; mod tuple; +pub use self::from_ref::FromRef; + /// Types that can be created from requests. /// /// See [`axum::extract`] for more details. diff --git a/axum-extra/src/extract/cookie/mod.rs b/axum-extra/src/extract/cookie/mod.rs index 38ecbbe54b..44842d039c 100644 --- a/axum-extra/src/extract/cookie/mod.rs +++ b/axum-extra/src/extract/cookie/mod.rs @@ -227,7 +227,7 @@ fn set_cookies(jar: cookie::CookieJar, headers: &mut HeaderMap) { #[cfg(test)] mod tests { use super::*; - use axum::{body::Body, http::Request, routing::get, Router}; + use axum::{body::Body, extract::FromRef, http::Request, routing::get, Router}; use tower::ServiceExt; macro_rules! cookie_test { @@ -308,15 +308,15 @@ mod tests { custom_key: CustomKey, } - impl From for Key { - fn from(state: AppState) -> Key { - state.key + impl FromRef for Key { + fn from_ref(state: &AppState) -> Key { + state.key.clone() } } - impl From for CustomKey { - fn from(state: AppState) -> CustomKey { - state.custom_key + impl FromRef for CustomKey { + fn from_ref(state: &AppState) -> CustomKey { + state.custom_key.clone() } } diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs index 7540b1764f..d3705fb2be 100644 --- a/axum-extra/src/extract/cookie/private.rs +++ b/axum-extra/src/extract/cookie/private.rs @@ -1,7 +1,7 @@ use super::{cookies_from_request, set_cookies, Cookie, Key}; use axum::{ async_trait, - extract::{FromRequest, RequestParts}, + extract::{FromRef, FromRequest, RequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use cookie::PrivateJar; @@ -23,7 +23,7 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// use axum::{ /// Router, /// routing::{post, get}, -/// extract::TypedHeader, +/// extract::{TypedHeader, FromRef}, /// response::{IntoResponse, Redirect}, /// headers::authorization::{Authorization, Bearer}, /// http::StatusCode, @@ -51,9 +51,9 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// } /// /// // this impl tells `SignedCookieJar` how to access the key from our state -/// impl From for Key { -/// fn from(state: AppState) -> Self { -/// state.key +/// impl FromRef for Key { +/// fn from_ref(state: &AppState) -> Self { +/// state.key.clone() /// } /// } /// @@ -90,15 +90,14 @@ impl fmt::Debug for PrivateCookieJar { impl FromRequest for PrivateCookieJar where B: Send, - S: Into + Clone + Send, - K: Into + Clone + Send + Sync + 'static, + S: Send, + K: FromRef + Into, { type Rejection = Infallible; async fn from_request(req: &mut RequestParts) -> Result { - let state = req.state().clone(); - let key: K = state.into(); - let key: Key = key.into(); + let k = K::from_ref(req.state()); + let key = k.into(); let PrivateCookieJar { jar, key, diff --git a/axum-extra/src/extract/cookie/signed.rs b/axum-extra/src/extract/cookie/signed.rs index d56a8eb567..74da2a11ae 100644 --- a/axum-extra/src/extract/cookie/signed.rs +++ b/axum-extra/src/extract/cookie/signed.rs @@ -1,7 +1,7 @@ use super::{cookies_from_request, set_cookies}; use axum::{ async_trait, - extract::{FromRequest, RequestParts}, + extract::{FromRef, FromRequest, RequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use cookie::SignedJar; @@ -24,7 +24,7 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// use axum::{ /// Router, /// routing::{post, get}, -/// extract::TypedHeader, +/// extract::{TypedHeader, FromRef}, /// response::{IntoResponse, Redirect}, /// headers::authorization::{Authorization, Bearer}, /// http::StatusCode, @@ -69,9 +69,9 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// } /// /// // this impl tells `SignedCookieJar` how to access the key from our state -/// impl From for Key { -/// fn from(state: AppState) -> Self { -/// state.key +/// impl FromRef for Key { +/// fn from_ref(state: &AppState) -> Self { +/// state.key.clone() /// } /// } /// @@ -108,15 +108,14 @@ impl fmt::Debug for SignedCookieJar { impl FromRequest for SignedCookieJar where B: Send, - S: Into + Clone + Send, - K: Into + Clone + Send + Sync + 'static, + S: Send, + K: FromRef + Into, { type Rejection = Infallible; async fn from_request(req: &mut RequestParts) -> Result { - let state = req.state().clone(); - let key: K = state.into(); - let key: Key = key.into(); + let k = K::from_ref(req.state()); + let key = k.into(); let SignedCookieJar { jar, key, diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index c4aeb3ad70..081793a83c 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -17,7 +17,7 @@ mod request_parts; mod state; #[doc(inline)] -pub use axum_core::extract::{FromRequest, RequestParts}; +pub use axum_core::extract::{FromRef, FromRequest, RequestParts}; #[doc(inline)] #[allow(deprecated)] diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs index 9991ad45c2..94ccf5b1a6 100644 --- a/axum/src/extract/state.rs +++ b/axum/src/extract/state.rs @@ -1,5 +1,5 @@ use async_trait::async_trait; -use axum_core::extract::{FromRequest, RequestParts}; +use axum_core::extract::{FromRef, FromRequest, RequestParts}; use std::{ convert::Infallible, ops::{Deref, DerefMut}, @@ -91,7 +91,7 @@ use std::{ /// [`State`] only allows a single state type but you can use [`From`] to extract "substates": /// /// ``` -/// use axum::{Router, routing::get, extract::State}; +/// use axum::{Router, routing::get, extract::{State, FromRef}}; /// /// // the application state /// #[derive(Clone)] @@ -105,9 +105,9 @@ use std::{ /// struct ApiState {} /// /// // support converting an `AppState` in an `ApiState` -/// impl From for ApiState { -/// fn from(app_state: AppState) -> ApiState { -/// app_state.api_state +/// impl FromRef for ApiState { +/// fn from_ref(app_state: &AppState) -> ApiState { +/// app_state.api_state.clone() /// } /// } /// @@ -139,7 +139,7 @@ use std::{ /// to do it: /// /// ```rust -/// use axum_core::extract::{FromRequest, RequestParts}; +/// use axum_core::extract::{FromRequest, RequestParts, FromRef}; /// use async_trait::async_trait; /// use std::convert::Infallible; /// @@ -151,14 +151,15 @@ use std::{ /// where /// B: Send, /// // keep `S` generic but require that it can produce a `MyLibraryState` -/// // this means users will have to implement `From for MyLibraryState` -/// S: Into + Clone + Send, +/// // this means users will have to implement `FromRef for MyLibraryState` +/// MyLibraryState: FromRef, +/// S: Send, /// { /// type Rejection = Infallible; /// /// async fn from_request(req: &mut RequestParts) -> Result { -/// // get a `MyLibraryState` from the shared application state -/// let state: MyLibraryState = req.state().clone().into(); +/// // get a `MyLibraryState` from a reference to the state +/// let state = MyLibraryState::from_ref(req.state()); /// /// // ... /// # todo!() @@ -180,13 +181,13 @@ pub struct State(pub S); impl FromRequest for State where B: Send, - OuterState: Clone + Into + Send, + InnerState: FromRef, + OuterState: Send, { type Rejection = Infallible; async fn from_request(req: &mut RequestParts) -> Result { - let outer_state = req.state().clone(); - let inner_state = outer_state.into(); + let inner_state = InnerState::from_ref(req.state()); Ok(Self(inner_state)) } } diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index a90ca2110c..fa46cdf86c 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -1,7 +1,7 @@ use crate::{ body::{Bytes, Empty}, error_handling::HandleErrorLayer, - extract::{self, Path, State}, + extract::{self, FromRef, Path, State}, handler::{Handler, HandlerWithoutStateExt}, response::IntoResponse, routing::{delete, get, get_service, on, on_service, patch, patch_service, post, MethodFilter}, @@ -654,9 +654,9 @@ async fn extracting_state() { value: i32, } - impl From for InnerState { - fn from(state: AppState) -> Self { - state.inner + impl FromRef for InnerState { + fn from_ref(state: &AppState) -> Self { + state.inner.clone() } } diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index 303ca1647f..a61113b97d 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -12,7 +12,8 @@ use async_session::{MemoryStore, Session, SessionStore}; use axum::{ async_trait, extract::{ - rejection::TypedHeaderRejectionReason, FromRequest, Query, RequestParts, State, TypedHeader, + rejection::TypedHeaderRejectionReason, FromRef, FromRequest, Query, RequestParts, State, + TypedHeader, }, http::{header::SET_COOKIE, HeaderMap}, response::{IntoResponse, Redirect, Response}, @@ -69,15 +70,15 @@ struct AppState { oauth_client: BasicClient, } -impl From for MemoryStore { - fn from(state: AppState) -> Self { - state.store +impl FromRef for MemoryStore { + fn from_ref(state: &AppState) -> Self { + state.store.clone() } } -impl From for BasicClient { - fn from(state: AppState) -> Self { - state.oauth_client +impl FromRef for BasicClient { + fn from_ref(state: &AppState) -> Self { + state.oauth_client.clone() } } diff --git a/examples/query-params-with-empty-strings/src/main.rs b/examples/query-params-with-empty-strings/src/main.rs index 7e9a08894c..0af20111d7 100644 --- a/examples/query-params-with-empty-strings/src/main.rs +++ b/examples/query-params-with-empty-strings/src/main.rs @@ -16,7 +16,7 @@ async fn main() { .unwrap(); } -fn app() -> Router<()> { +fn app() -> Router { Router::new().route("/", get(handler)) } diff --git a/examples/routes-and-handlers-close-together/src/main.rs b/examples/routes-and-handlers-close-together/src/main.rs index 6fc75c9e41..41aaa49db5 100644 --- a/examples/routes-and-handlers-close-together/src/main.rs +++ b/examples/routes-and-handlers-close-together/src/main.rs @@ -25,7 +25,7 @@ async fn main() { .unwrap(); } -fn root() -> Router<()> { +fn root() -> Router { async fn handler() -> &'static str { "Hello, World!" } @@ -33,7 +33,7 @@ fn root() -> Router<()> { route("/", get(handler)) } -fn get_foo() -> Router<()> { +fn get_foo() -> Router { async fn handler() -> &'static str { "Hi from `GET /foo`" } @@ -41,7 +41,7 @@ fn get_foo() -> Router<()> { route("/foo", get(handler)) } -fn post_foo() -> Router<()> { +fn post_foo() -> Router { async fn handler() -> &'static str { "Hi from `POST /foo`" } @@ -49,6 +49,6 @@ fn post_foo() -> Router<()> { route("/foo", post(handler)) } -fn route(path: &str, method_router: MethodRouter<()>) -> Router<()> { +fn route(path: &str, method_router: MethodRouter<()>) -> Router { Router::new().route(path, method_router) } diff --git a/examples/testing/src/main.rs b/examples/testing/src/main.rs index f31935188a..0bb9b352a0 100644 --- a/examples/testing/src/main.rs +++ b/examples/testing/src/main.rs @@ -34,7 +34,7 @@ async fn main() { /// Having a function that produces our app makes it easy to call it from tests /// without having to create an HTTP server. #[allow(dead_code)] -fn app() -> Router<()> { +fn app() -> Router { Router::new() .route("/", get(|| async { "Hello, World!" })) .route(