Skip to content

Commit

Permalink
Add FromRef trait (#1268)
Browse files Browse the repository at this point in the history
* Add `FromRef` trait

* Remove unnecessary type params

* format

* fix docs link

* format examples
  • Loading branch information
davidpdrsn authored Aug 17, 2022
1 parent e211b15 commit 96531b7
Show file tree
Hide file tree
Showing 12 changed files with 84 additions and 58 deletions.
23 changes: 23 additions & 0 deletions axum-core/src/extract/from_ref.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/// Used to do reference-to-value conversions thus not consuming the input value.
///
/// This is mainly used with [`State`] to extract "substates" from a reference to main application
/// state.
///
/// See [`State`] for more details on how library authors should use this trait.
///
/// [`State`]: https://docs.rs/axum/0.6/axum/extract/struct.State.html
// NOTE: This trait is defined in axum-core, even though it is mainly used with `State` which is
// defined in axum. That allows crate authors to use it when implementing extractors.
pub trait FromRef<T> {
/// Converts to this type from a reference to the input type.
fn from_ref(input: &T) -> Self;
}

impl<T> FromRef<T> for T
where
T: Clone,
{
fn from_ref(input: &T) -> Self {
input.clone()
}
}
3 changes: 3 additions & 0 deletions axum-core/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ use std::convert::Infallible;

pub mod rejection;

mod from_ref;
mod request_parts;
mod tuple;

pub use self::from_ref::FromRef;

/// Types that can be created from requests.
///
/// See [`axum::extract`] for more details.
Expand Down
14 changes: 7 additions & 7 deletions axum-extra/src/extract/cookie/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ fn set_cookies(jar: cookie::CookieJar, headers: &mut HeaderMap) {
#[cfg(test)]
mod tests {
use super::*;
use axum::{body::Body, http::Request, routing::get, Router};
use axum::{body::Body, extract::FromRef, http::Request, routing::get, Router};
use tower::ServiceExt;

macro_rules! cookie_test {
Expand Down Expand Up @@ -308,15 +308,15 @@ mod tests {
custom_key: CustomKey,
}

impl From<AppState> for Key {
fn from(state: AppState) -> Key {
state.key
impl FromRef<AppState> for Key {
fn from_ref(state: &AppState) -> Key {
state.key.clone()
}
}

impl From<AppState> for CustomKey {
fn from(state: AppState) -> CustomKey {
state.custom_key
impl FromRef<AppState> for CustomKey {
fn from_ref(state: &AppState) -> CustomKey {
state.custom_key.clone()
}
}

Expand Down
19 changes: 9 additions & 10 deletions axum-extra/src/extract/cookie/private.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{cookies_from_request, set_cookies, Cookie, Key};
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
extract::{FromRef, FromRequest, RequestParts},
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
};
use cookie::PrivateJar;
Expand All @@ -23,7 +23,7 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// use axum::{
/// Router,
/// routing::{post, get},
/// extract::TypedHeader,
/// extract::{TypedHeader, FromRef},
/// response::{IntoResponse, Redirect},
/// headers::authorization::{Authorization, Bearer},
/// http::StatusCode,
Expand Down Expand Up @@ -51,9 +51,9 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// }
///
/// // this impl tells `SignedCookieJar` how to access the key from our state
/// impl From<AppState> for Key {
/// fn from(state: AppState) -> Self {
/// state.key
/// impl FromRef<AppState> for Key {
/// fn from_ref(state: &AppState) -> Self {
/// state.key.clone()
/// }
/// }
///
Expand Down Expand Up @@ -90,15 +90,14 @@ impl<K> fmt::Debug for PrivateCookieJar<K> {
impl<S, B, K> FromRequest<S, B> for PrivateCookieJar<K>
where
B: Send,
S: Into<K> + Clone + Send,
K: Into<Key> + Clone + Send + Sync + 'static,
S: Send,
K: FromRef<S> + Into<Key>,
{
type Rejection = Infallible;

async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let state = req.state().clone();
let key: K = state.into();
let key: Key = key.into();
let k = K::from_ref(req.state());
let key = k.into();
let PrivateCookieJar {
jar,
key,
Expand Down
19 changes: 9 additions & 10 deletions axum-extra/src/extract/cookie/signed.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{cookies_from_request, set_cookies};
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
extract::{FromRef, FromRequest, RequestParts},
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
};
use cookie::SignedJar;
Expand All @@ -24,7 +24,7 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// use axum::{
/// Router,
/// routing::{post, get},
/// extract::TypedHeader,
/// extract::{TypedHeader, FromRef},
/// response::{IntoResponse, Redirect},
/// headers::authorization::{Authorization, Bearer},
/// http::StatusCode,
Expand Down Expand Up @@ -69,9 +69,9 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// }
///
/// // this impl tells `SignedCookieJar` how to access the key from our state
/// impl From<AppState> for Key {
/// fn from(state: AppState) -> Self {
/// state.key
/// impl FromRef<AppState> for Key {
/// fn from_ref(state: &AppState) -> Self {
/// state.key.clone()
/// }
/// }
///
Expand Down Expand Up @@ -108,15 +108,14 @@ impl<K> fmt::Debug for SignedCookieJar<K> {
impl<S, B, K> FromRequest<S, B> for SignedCookieJar<K>
where
B: Send,
S: Into<K> + Clone + Send,
K: Into<Key> + Clone + Send + Sync + 'static,
S: Send,
K: FromRef<S> + Into<Key>,
{
type Rejection = Infallible;

async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let state = req.state().clone();
let key: K = state.into();
let key: Key = key.into();
let k = K::from_ref(req.state());
let key = k.into();
let SignedCookieJar {
jar,
key,
Expand Down
2 changes: 1 addition & 1 deletion axum/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ mod request_parts;
mod state;

#[doc(inline)]
pub use axum_core::extract::{FromRequest, RequestParts};
pub use axum_core::extract::{FromRef, FromRequest, RequestParts};

#[doc(inline)]
#[allow(deprecated)]
Expand Down
27 changes: 14 additions & 13 deletions axum/src/extract/state.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use async_trait::async_trait;
use axum_core::extract::{FromRequest, RequestParts};
use axum_core::extract::{FromRef, FromRequest, RequestParts};
use std::{
convert::Infallible,
ops::{Deref, DerefMut},
Expand Down Expand Up @@ -91,7 +91,7 @@ use std::{
/// [`State`] only allows a single state type but you can use [`From`] to extract "substates":
///
/// ```
/// use axum::{Router, routing::get, extract::State};
/// use axum::{Router, routing::get, extract::{State, FromRef}};
///
/// // the application state
/// #[derive(Clone)]
Expand All @@ -105,9 +105,9 @@ use std::{
/// struct ApiState {}
///
/// // support converting an `AppState` in an `ApiState`
/// impl From<AppState> for ApiState {
/// fn from(app_state: AppState) -> ApiState {
/// app_state.api_state
/// impl FromRef<AppState> for ApiState {
/// fn from_ref(app_state: &AppState) -> ApiState {
/// app_state.api_state.clone()
/// }
/// }
///
Expand Down Expand Up @@ -139,7 +139,7 @@ use std::{
/// to do it:
///
/// ```rust
/// use axum_core::extract::{FromRequest, RequestParts};
/// use axum_core::extract::{FromRequest, RequestParts, FromRef};
/// use async_trait::async_trait;
/// use std::convert::Infallible;
///
Expand All @@ -151,14 +151,15 @@ use std::{
/// where
/// B: Send,
/// // keep `S` generic but require that it can produce a `MyLibraryState`
/// // this means users will have to implement `From<UserState> for MyLibraryState`
/// S: Into<MyLibraryState> + Clone + Send,
/// // this means users will have to implement `FromRef<UserState> for MyLibraryState`
/// MyLibraryState: FromRef<S>,
/// S: Send,
/// {
/// type Rejection = Infallible;
///
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// // get a `MyLibraryState` from the shared application state
/// let state: MyLibraryState = req.state().clone().into();
/// // get a `MyLibraryState` from a reference to the state
/// let state = MyLibraryState::from_ref(req.state());
///
/// // ...
/// # todo!()
Expand All @@ -180,13 +181,13 @@ pub struct State<S>(pub S);
impl<B, OuterState, InnerState> FromRequest<OuterState, B> for State<InnerState>
where
B: Send,
OuterState: Clone + Into<InnerState> + Send,
InnerState: FromRef<OuterState>,
OuterState: Send,
{
type Rejection = Infallible;

async fn from_request(req: &mut RequestParts<OuterState, B>) -> Result<Self, Self::Rejection> {
let outer_state = req.state().clone();
let inner_state = outer_state.into();
let inner_state = InnerState::from_ref(req.state());
Ok(Self(inner_state))
}
}
Expand Down
8 changes: 4 additions & 4 deletions axum/src/routing/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
body::{Bytes, Empty},
error_handling::HandleErrorLayer,
extract::{self, Path, State},
extract::{self, FromRef, Path, State},
handler::{Handler, HandlerWithoutStateExt},
response::IntoResponse,
routing::{delete, get, get_service, on, on_service, patch, patch_service, post, MethodFilter},
Expand Down Expand Up @@ -654,9 +654,9 @@ async fn extracting_state() {
value: i32,
}

impl From<AppState> for InnerState {
fn from(state: AppState) -> Self {
state.inner
impl FromRef<AppState> for InnerState {
fn from_ref(state: &AppState) -> Self {
state.inner.clone()
}
}

Expand Down
15 changes: 8 additions & 7 deletions examples/oauth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ use async_session::{MemoryStore, Session, SessionStore};
use axum::{
async_trait,
extract::{
rejection::TypedHeaderRejectionReason, FromRequest, Query, RequestParts, State, TypedHeader,
rejection::TypedHeaderRejectionReason, FromRef, FromRequest, Query, RequestParts, State,
TypedHeader,
},
http::{header::SET_COOKIE, HeaderMap},
response::{IntoResponse, Redirect, Response},
Expand Down Expand Up @@ -69,15 +70,15 @@ struct AppState {
oauth_client: BasicClient,
}

impl From<AppState> for MemoryStore {
fn from(state: AppState) -> Self {
state.store
impl FromRef<AppState> for MemoryStore {
fn from_ref(state: &AppState) -> Self {
state.store.clone()
}
}

impl From<AppState> for BasicClient {
fn from(state: AppState) -> Self {
state.oauth_client
impl FromRef<AppState> for BasicClient {
fn from_ref(state: &AppState) -> Self {
state.oauth_client.clone()
}
}

Expand Down
2 changes: 1 addition & 1 deletion examples/query-params-with-empty-strings/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async fn main() {
.unwrap();
}

fn app() -> Router<()> {
fn app() -> Router {
Router::new().route("/", get(handler))
}

Expand Down
8 changes: 4 additions & 4 deletions examples/routes-and-handlers-close-together/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,30 @@ async fn main() {
.unwrap();
}

fn root() -> Router<()> {
fn root() -> Router {
async fn handler() -> &'static str {
"Hello, World!"
}

route("/", get(handler))
}

fn get_foo() -> Router<()> {
fn get_foo() -> Router {
async fn handler() -> &'static str {
"Hi from `GET /foo`"
}

route("/foo", get(handler))
}

fn post_foo() -> Router<()> {
fn post_foo() -> Router {
async fn handler() -> &'static str {
"Hi from `POST /foo`"
}

route("/foo", post(handler))
}

fn route(path: &str, method_router: MethodRouter<()>) -> Router<()> {
fn route(path: &str, method_router: MethodRouter<()>) -> Router {
Router::new().route(path, method_router)
}
2 changes: 1 addition & 1 deletion examples/testing/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async fn main() {
/// Having a function that produces our app makes it easy to call it from tests
/// without having to create an HTTP server.
#[allow(dead_code)]
fn app() -> Router<()> {
fn app() -> Router {
Router::new()
.route("/", get(|| async { "Hello, World!" }))
.route(
Expand Down

0 comments on commit 96531b7

Please sign in to comment.