Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FromRef trait #1268

Merged
merged 5 commits into from
Aug 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions axum-core/src/extract/from_ref.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/// Used to do reference-to-value conversions thus not consuming the input value.
///
/// This is mainly used with [`State`] to extract "substates" from a reference to main application
/// state.
///
/// See [`State`] for more details on how library authors should use this trait.
///
/// [`State`]: https://docs.rs/axum/0.6/axum/extract/struct.State.html
// NOTE: This trait is defined in axum-core, even though it is mainly used with `State` which is
// defined in axum. That allows crate authors to use it when implementing extractors.
pub trait FromRef<T> {
/// Converts to this type from a reference to the input type.
fn from_ref(input: &T) -> Self;
}

impl<T> FromRef<T> for T
where
T: Clone,
{
fn from_ref(input: &T) -> Self {
input.clone()
}
}
3 changes: 3 additions & 0 deletions axum-core/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ use std::convert::Infallible;

pub mod rejection;

mod from_ref;
mod request_parts;
mod tuple;

pub use self::from_ref::FromRef;

/// Types that can be created from requests.
///
/// See [`axum::extract`] for more details.
Expand Down
14 changes: 7 additions & 7 deletions axum-extra/src/extract/cookie/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ fn set_cookies(jar: cookie::CookieJar, headers: &mut HeaderMap) {
#[cfg(test)]
mod tests {
use super::*;
use axum::{body::Body, http::Request, routing::get, Router};
use axum::{body::Body, extract::FromRef, http::Request, routing::get, Router};
use tower::ServiceExt;

macro_rules! cookie_test {
Expand Down Expand Up @@ -308,15 +308,15 @@ mod tests {
custom_key: CustomKey,
}

impl From<AppState> for Key {
fn from(state: AppState) -> Key {
state.key
impl FromRef<AppState> for Key {
fn from_ref(state: &AppState) -> Key {
state.key.clone()
}
}

impl From<AppState> for CustomKey {
fn from(state: AppState) -> CustomKey {
state.custom_key
impl FromRef<AppState> for CustomKey {
fn from_ref(state: &AppState) -> CustomKey {
state.custom_key.clone()
}
}

Expand Down
19 changes: 9 additions & 10 deletions axum-extra/src/extract/cookie/private.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{cookies_from_request, set_cookies, Cookie, Key};
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
extract::{FromRef, FromRequest, RequestParts},
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
};
use cookie::PrivateJar;
Expand All @@ -23,7 +23,7 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// use axum::{
/// Router,
/// routing::{post, get},
/// extract::TypedHeader,
/// extract::{TypedHeader, FromRef},
/// response::{IntoResponse, Redirect},
/// headers::authorization::{Authorization, Bearer},
/// http::StatusCode,
Expand Down Expand Up @@ -51,9 +51,9 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// }
///
/// // this impl tells `SignedCookieJar` how to access the key from our state
/// impl From<AppState> for Key {
/// fn from(state: AppState) -> Self {
/// state.key
/// impl FromRef<AppState> for Key {
/// fn from_ref(state: &AppState) -> Self {
/// state.key.clone()
/// }
/// }
///
Expand Down Expand Up @@ -90,15 +90,14 @@ impl<K> fmt::Debug for PrivateCookieJar<K> {
impl<S, B, K> FromRequest<S, B> for PrivateCookieJar<K>
where
B: Send,
S: Into<K> + Clone + Send,
K: Into<Key> + Clone + Send + Sync + 'static,
S: Send,
K: FromRef<S> + Into<Key>,
{
type Rejection = Infallible;

async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let state = req.state().clone();
let key: K = state.into();
let key: Key = key.into();
let k = K::from_ref(req.state());
let key = k.into();
let PrivateCookieJar {
jar,
key,
Expand Down
19 changes: 9 additions & 10 deletions axum-extra/src/extract/cookie/signed.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{cookies_from_request, set_cookies};
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
extract::{FromRef, FromRequest, RequestParts},
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
};
use cookie::SignedJar;
Expand All @@ -24,7 +24,7 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// use axum::{
/// Router,
/// routing::{post, get},
/// extract::TypedHeader,
/// extract::{TypedHeader, FromRef},
/// response::{IntoResponse, Redirect},
/// headers::authorization::{Authorization, Bearer},
/// http::StatusCode,
Expand Down Expand Up @@ -69,9 +69,9 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// }
///
/// // this impl tells `SignedCookieJar` how to access the key from our state
/// impl From<AppState> for Key {
/// fn from(state: AppState) -> Self {
/// state.key
/// impl FromRef<AppState> for Key {
/// fn from_ref(state: &AppState) -> Self {
/// state.key.clone()
/// }
/// }
///
Expand Down Expand Up @@ -108,15 +108,14 @@ impl<K> fmt::Debug for SignedCookieJar<K> {
impl<S, B, K> FromRequest<S, B> for SignedCookieJar<K>
where
B: Send,
S: Into<K> + Clone + Send,
K: Into<Key> + Clone + Send + Sync + 'static,
S: Send,
K: FromRef<S> + Into<Key>,
{
type Rejection = Infallible;

async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let state = req.state().clone();
let key: K = state.into();
let key: Key = key.into();
let k = K::from_ref(req.state());
let key = k.into();
let SignedCookieJar {
jar,
key,
Expand Down
2 changes: 1 addition & 1 deletion axum/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ mod request_parts;
mod state;

#[doc(inline)]
pub use axum_core::extract::{FromRequest, RequestParts};
pub use axum_core::extract::{FromRef, FromRequest, RequestParts};

#[doc(inline)]
#[allow(deprecated)]
Expand Down
27 changes: 14 additions & 13 deletions axum/src/extract/state.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use async_trait::async_trait;
use axum_core::extract::{FromRequest, RequestParts};
use axum_core::extract::{FromRef, FromRequest, RequestParts};
use std::{
convert::Infallible,
ops::{Deref, DerefMut},
Expand Down Expand Up @@ -91,7 +91,7 @@ use std::{
/// [`State`] only allows a single state type but you can use [`From`] to extract "substates":
///
/// ```
/// use axum::{Router, routing::get, extract::State};
/// use axum::{Router, routing::get, extract::{State, FromRef}};
///
/// // the application state
/// #[derive(Clone)]
Expand All @@ -105,9 +105,9 @@ use std::{
/// struct ApiState {}
///
/// // support converting an `AppState` in an `ApiState`
/// impl From<AppState> for ApiState {
/// fn from(app_state: AppState) -> ApiState {
/// app_state.api_state
/// impl FromRef<AppState> for ApiState {
/// fn from_ref(app_state: &AppState) -> ApiState {
/// app_state.api_state.clone()
/// }
/// }
///
Expand Down Expand Up @@ -139,7 +139,7 @@ use std::{
/// to do it:
///
/// ```rust
/// use axum_core::extract::{FromRequest, RequestParts};
/// use axum_core::extract::{FromRequest, RequestParts, FromRef};
/// use async_trait::async_trait;
/// use std::convert::Infallible;
///
Expand All @@ -151,14 +151,15 @@ use std::{
/// where
/// B: Send,
/// // keep `S` generic but require that it can produce a `MyLibraryState`
/// // this means users will have to implement `From<UserState> for MyLibraryState`
/// S: Into<MyLibraryState> + Clone + Send,
/// // this means users will have to implement `FromRef<UserState> for MyLibraryState`
/// MyLibraryState: FromRef<S>,
/// S: Send,
/// {
/// type Rejection = Infallible;
///
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// // get a `MyLibraryState` from the shared application state
/// let state: MyLibraryState = req.state().clone().into();
/// // get a `MyLibraryState` from a reference to the state
/// let state = MyLibraryState::from_ref(req.state());
///
/// // ...
/// # todo!()
Expand All @@ -180,13 +181,13 @@ pub struct State<S>(pub S);
impl<B, OuterState, InnerState> FromRequest<OuterState, B> for State<InnerState>
where
B: Send,
OuterState: Clone + Into<InnerState> + Send,
InnerState: FromRef<OuterState>,
OuterState: Send,
{
type Rejection = Infallible;

async fn from_request(req: &mut RequestParts<OuterState, B>) -> Result<Self, Self::Rejection> {
let outer_state = req.state().clone();
let inner_state = outer_state.into();
let inner_state = InnerState::from_ref(req.state());
Ok(Self(inner_state))
}
}
Expand Down
8 changes: 4 additions & 4 deletions axum/src/routing/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
body::{Bytes, Empty},
error_handling::HandleErrorLayer,
extract::{self, Path, State},
extract::{self, FromRef, Path, State},
handler::{Handler, HandlerWithoutStateExt},
response::IntoResponse,
routing::{delete, get, get_service, on, on_service, patch, patch_service, post, MethodFilter},
Expand Down Expand Up @@ -654,9 +654,9 @@ async fn extracting_state() {
value: i32,
}

impl From<AppState> for InnerState {
fn from(state: AppState) -> Self {
state.inner
impl FromRef<AppState> for InnerState {
fn from_ref(state: &AppState) -> Self {
state.inner.clone()
}
}

Expand Down
15 changes: 8 additions & 7 deletions examples/oauth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ use async_session::{MemoryStore, Session, SessionStore};
use axum::{
async_trait,
extract::{
rejection::TypedHeaderRejectionReason, FromRequest, Query, RequestParts, State, TypedHeader,
rejection::TypedHeaderRejectionReason, FromRef, FromRequest, Query, RequestParts, State,
TypedHeader,
},
http::{header::SET_COOKIE, HeaderMap},
response::{IntoResponse, Redirect, Response},
Expand Down Expand Up @@ -69,15 +70,15 @@ struct AppState {
oauth_client: BasicClient,
}

impl From<AppState> for MemoryStore {
fn from(state: AppState) -> Self {
state.store
impl FromRef<AppState> for MemoryStore {
fn from_ref(state: &AppState) -> Self {
state.store.clone()
}
}

impl From<AppState> for BasicClient {
fn from(state: AppState) -> Self {
state.oauth_client
impl FromRef<AppState> for BasicClient {
fn from_ref(state: &AppState) -> Self {
state.oauth_client.clone()
}
}

Expand Down
2 changes: 1 addition & 1 deletion examples/query-params-with-empty-strings/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async fn main() {
.unwrap();
}

fn app() -> Router<()> {
fn app() -> Router {
Router::new().route("/", get(handler))
}

Expand Down
8 changes: 4 additions & 4 deletions examples/routes-and-handlers-close-together/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,30 @@ async fn main() {
.unwrap();
}

fn root() -> Router<()> {
fn root() -> Router {
async fn handler() -> &'static str {
"Hello, World!"
}

route("/", get(handler))
}

fn get_foo() -> Router<()> {
fn get_foo() -> Router {
async fn handler() -> &'static str {
"Hi from `GET /foo`"
}

route("/foo", get(handler))
}

fn post_foo() -> Router<()> {
fn post_foo() -> Router {
async fn handler() -> &'static str {
"Hi from `POST /foo`"
}

route("/foo", post(handler))
}

fn route(path: &str, method_router: MethodRouter<()>) -> Router<()> {
fn route(path: &str, method_router: MethodRouter<()>) -> Router {
Router::new().route(path, method_router)
}
2 changes: 1 addition & 1 deletion examples/testing/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async fn main() {
/// Having a function that produces our app makes it easy to call it from tests
/// without having to create an HTTP server.
#[allow(dead_code)]
fn app() -> Router<()> {
fn app() -> Router {
Router::new()
.route("/", get(|| async { "Hello, World!" }))
.route(
Expand Down