Skip to content

Commit

Permalink
Add middleware::from_extractor_with_state
Browse files Browse the repository at this point in the history
Fixes #1373
  • Loading branch information
davidpdrsn committed Sep 19, 2022
1 parent 4ade706 commit 990d9cf
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 41 deletions.
2 changes: 2 additions & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **fixed:** Support streaming/chunked requests in `ContentLengthLimit` ([#1389])
- **fixed:** Used `400 Bad Request` for `FailedToDeserializeQueryString`
rejections, instead of `422 Unprocessable Entity` ([#1387])
- **added:** Add `middleware::from_extractor_with_state` and
`middleware::from_extractor_with_state_arc`

[#1371]: https://github.com/tokio-rs/axum/pull/1371
[#1387]: https://github.com/tokio-rs/axum/pull/1387
Expand Down
125 changes: 85 additions & 40 deletions axum/src/middleware/from_extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tower_layer::Layer;
Expand Down Expand Up @@ -90,8 +91,25 @@ use tower_service::Service;
/// ```
///
/// [`Bytes`]: bytes::Bytes
pub fn from_extractor<E>() -> FromExtractorLayer<E> {
FromExtractorLayer(PhantomData)
pub fn from_extractor<E>() -> FromExtractorLayer<E, ()> {
from_extractor_with_state(())
}

/// Create a middleware from an extractor with the given state.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn from_extractor_with_state<E, S>(state: S) -> FromExtractorLayer<E, S> {
from_extractor_with_state_arc(Arc::new(state))
}

/// Create a middleware from an extractor with the given [`Arc`]'ed state.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn from_extractor_with_state_arc<E, S>(state: Arc<S>) -> FromExtractorLayer<E, S> {
FromExtractorLayer {
state,
_marker: PhantomData,
}
}

/// [`Layer`] that applies [`FromExtractor`] that runs an extractor and
Expand All @@ -100,28 +118,39 @@ pub fn from_extractor<E>() -> FromExtractorLayer<E> {
/// See [`from_extractor`] for more details.
///
/// [`Layer`]: tower::Layer
pub struct FromExtractorLayer<E>(PhantomData<fn() -> E>);
pub struct FromExtractorLayer<E, S> {
state: Arc<S>,
_marker: PhantomData<fn() -> E>,
}

impl<E> Clone for FromExtractorLayer<E> {
impl<E, S> Clone for FromExtractorLayer<E, S> {
fn clone(&self) -> Self {
Self(PhantomData)
Self {
state: Arc::clone(&self.state),
_marker: PhantomData,
}
}
}

impl<E> fmt::Debug for FromExtractorLayer<E> {
impl<E, S> fmt::Debug for FromExtractorLayer<E, S>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FromExtractorLayer")
.field("state", &self.state)
.field("extractor", &format_args!("{}", std::any::type_name::<E>()))
.finish()
}
}

impl<E, S> Layer<S> for FromExtractorLayer<E> {
type Service = FromExtractor<S, E>;
impl<E, T, S> Layer<T> for FromExtractorLayer<E, S> {
type Service = FromExtractor<T, E, S>;

fn layer(&self, inner: S) -> Self::Service {
fn layer(&self, inner: T) -> Self::Service {
FromExtractor {
inner,
state: Arc::clone(&self.state),
_extractor: PhantomData,
}
}
Expand All @@ -130,62 +159,68 @@ impl<E, S> Layer<S> for FromExtractorLayer<E> {
/// Middleware that runs an extractor and discards the value.
///
/// See [`from_extractor`] for more details.
pub struct FromExtractor<S, E> {
inner: S,
pub struct FromExtractor<T, E, S> {
inner: T,
state: Arc<S>,
_extractor: PhantomData<fn() -> E>,
}

#[test]
fn traits() {
use crate::test_helpers::*;
assert_send::<FromExtractor<(), NotSendSync>>();
assert_sync::<FromExtractor<(), NotSendSync>>();
assert_send::<FromExtractor<(), NotSendSync, ()>>();
assert_sync::<FromExtractor<(), NotSendSync, ()>>();
}

impl<S, E> Clone for FromExtractor<S, E>
impl<T, E, S> Clone for FromExtractor<T, E, S>
where
S: Clone,
T: Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
state: Arc::clone(&self.state),
_extractor: PhantomData,
}
}
}

impl<S, E> fmt::Debug for FromExtractor<S, E>
impl<T, E, S> fmt::Debug for FromExtractor<T, E, S>
where
T: fmt::Debug,
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FromExtractor")
.field("inner", &self.inner)
.field("state", &self.state)
.field("extractor", &format_args!("{}", std::any::type_name::<E>()))
.finish()
}
}

impl<S, E, B> Service<Request<B>> for FromExtractor<S, E>
impl<T, E, B, S> Service<Request<B>> for FromExtractor<T, E, S>
where
E: FromRequestParts<()> + 'static,
E: FromRequestParts<S> + 'static,
B: Default + Send + 'static,
S: Service<Request<B>> + Clone,
S::Response: IntoResponse,
T: Service<Request<B>> + Clone,
T::Response: IntoResponse,
S: Send + Sync + 'static,
{
type Response = Response;
type Error = S::Error;
type Future = ResponseFuture<B, S, E>;
type Error = T::Error;
type Future = ResponseFuture<B, T, E, S>;

#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, req: Request<B>) -> Self::Future {
let state = Arc::clone(&self.state);
let extract_future = Box::pin(async move {
let (mut parts, body) = req.into_parts();
let extracted = E::from_request_parts(&mut parts, &()).await;
let extracted = E::from_request_parts(&mut parts, &state).await;
let req = Request::from_parts(parts, body);
(req, extracted)
});
Expand All @@ -202,39 +237,39 @@ where
pin_project! {
/// Response future for [`FromExtractor`].
#[allow(missing_debug_implementations)]
pub struct ResponseFuture<B, S, E>
pub struct ResponseFuture<B, T, E, S>
where
E: FromRequestParts<()>,
S: Service<Request<B>>,
E: FromRequestParts<S>,
T: Service<Request<B>>,
{
#[pin]
state: State<B, S, E>,
svc: Option<S>,
state: State<B, T, E, S>,
svc: Option<T>,
}
}

pin_project! {
#[project = StateProj]
enum State<B, S, E>
enum State<B, T, E, S>
where
E: FromRequestParts<()>,
S: Service<Request<B>>,
E: FromRequestParts<S>,
T: Service<Request<B>>,
{
Extracting {
future: BoxFuture<'static, (Request<B>, Result<E, E::Rejection>)>,
},
Call { #[pin] future: S::Future },
Call { #[pin] future: T::Future },
}
}

impl<B, S, E> Future for ResponseFuture<B, S, E>
impl<B, T, E, S> Future for ResponseFuture<B, T, E, S>
where
E: FromRequestParts<()>,
S: Service<Request<B>>,
S::Response: IntoResponse,
E: FromRequestParts<S>,
T: Service<Request<B>>,
T::Response: IntoResponse,
B: Default,
{
type Output = Result<Response, S::Error>;
type Output = Result<Response, T::Error>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
Expand Down Expand Up @@ -272,29 +307,35 @@ where
mod tests {
use super::*;
use crate::{handler::Handler, routing::get, test_helpers::*, Router};
use axum_core::extract::FromRef;
use http::{header, request::Parts, StatusCode};

#[tokio::test]
async fn test_from_extractor() {
#[derive(Clone)]
struct Secret(&'static str);

struct RequireAuth;

#[async_trait::async_trait]
impl<S> FromRequestParts<S> for RequireAuth
where
S: Send + Sync,
Secret: FromRef<S>,
{
type Rejection = StatusCode;

async fn from_request_parts(
parts: &mut Parts,
_state: &S,
state: &S,
) -> Result<Self, Self::Rejection> {
let Secret(secret) = Secret::from_ref(state);
if let Some(auth) = parts
.headers
.get(header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
{
if auth == "secret" {
if auth == secret {
return Ok(Self);
}
}
Expand All @@ -305,7 +346,11 @@ mod tests {

async fn handler() {}

let app = Router::new().route("/", get(handler.layer(from_extractor::<RequireAuth>())));
let state = Secret("secret");
let app = Router::new().route(
"/",
get(handler.layer(from_extractor_with_state::<RequireAuth, _>(state))),
);

let client = TestClient::new(app);

Expand Down
5 changes: 4 additions & 1 deletion axum/src/middleware/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
mod from_extractor;
mod from_fn;

pub use self::from_extractor::{from_extractor, FromExtractor, FromExtractorLayer};
pub use self::from_extractor::{
from_extractor, from_extractor_with_state, from_extractor_with_state_arc, FromExtractor,
FromExtractorLayer,
};
pub use self::from_fn::{
from_fn, from_fn_with_state, from_fn_with_state_arc, FromFn, FromFnLayer, Next,
};
Expand Down

0 comments on commit 990d9cf

Please sign in to comment.