Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RequestExt and RequestPartsExt #1301

Merged
merged 3 commits into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `MethodRouter`, defaults to `()`
- `FromRequest`, no default
- `Handler`, no default
- **added:** Add `RequestExt` and `RequestPartsExt` which adds convenience
methods for running extractors to `http::Request` and `http::request::Parts` ([#1301])

## Middleware

Expand Down Expand Up @@ -372,6 +374,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1239]: https://github.com/tokio-rs/axum/pull/1239
[#1248]: https://github.com/tokio-rs/axum/pull/1248
[#1272]: https://github.com/tokio-rs/axum/pull/1272
[#1301]: https://github.com/tokio-rs/axum/pull/1301
[#924]: https://github.com/tokio-rs/axum/pull/924

# 0.5.15 (9. August, 2022)
Expand Down
30 changes: 30 additions & 0 deletions axum/src/ext_traits/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
pub(crate) mod request;
pub(crate) mod request_parts;

#[cfg(test)]
mod tests {
use std::convert::Infallible;

use async_trait::async_trait;
use axum_core::extract::{FromRef, FromRequestParts};
use http::request::Parts;

// some extractor that requires the state, such as `SignedCookieJar`
pub(crate) struct RequiresState(pub(crate) String);

#[async_trait]
impl<S> FromRequestParts<S> for RequiresState
where
S: Send + Sync,
String: FromRef<S>,
{
type Rejection = Infallible;

async fn from_request_parts(
_parts: &mut Parts,
state: &S,
) -> Result<Self, Self::Rejection> {
Ok(Self(String::from_ref(state)))
}
}
}
200 changes: 200 additions & 0 deletions axum/src/ext_traits/request.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
use axum_core::extract::{FromRequest, FromRequestParts};
use futures_util::future::BoxFuture;
use http::Request;

mod sealed {
pub trait Sealed<B> {}
impl<B> Sealed<B> for http::Request<B> {}
}

/// Extension trait that adds additional methods to [`Request`].
pub trait RequestExt<B>: sealed::Sealed<B> + Sized {
/// Apply an extractor to this `Request`.
///
/// This is just a convenience for `E::from_request(req, &())`.
///
/// Note this consumes the request. Use [`RequestExt::extract_parts`] if you're not extracting
/// the body and don't want to consume the request.
fn extract<E, M>(self) -> BoxFuture<'static, Result<E, E::Rejection>>
where
E: FromRequest<(), B, M> + 'static,
M: 'static;

/// Apply an extractor that requires some state to this `Request`.
///
/// This is just a convenience for `E::from_request(req, state)`.
///
/// Note this consumes the request. Use [`RequestExt::extract_parts_with_state`] if you're not
/// extracting the body and don't want to consume the request.
fn extract_with_state<E, S, M>(self, state: &S) -> BoxFuture<'_, Result<E, E::Rejection>>
where
E: FromRequest<S, B, M> + 'static,
S: Send + Sync;

/// Apply a parts extractor to this `Request`.
///
/// This is just a convenience for `E::from_request_parts(parts, state)`.
fn extract_parts<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
where
E: FromRequestParts<()> + 'static;

/// Apply a parts extractor that requires some state to this `Request`.
///
/// This is just a convenience for `E::from_request_parts(parts, state)`.
fn extract_parts_with_state<'a, E, S>(
&'a mut self,
state: &'a S,
) -> BoxFuture<'a, Result<E, E::Rejection>>
where
E: FromRequestParts<S> + 'static,
S: Send + Sync;
}

impl<B> RequestExt<B> for Request<B>
where
B: Send + 'static,
{
fn extract<E, M>(self) -> BoxFuture<'static, Result<E, E::Rejection>>
where
E: FromRequest<(), B, M> + 'static,
M: 'static,
{
self.extract_with_state(&())
}

fn extract_with_state<E, S, M>(self, state: &S) -> BoxFuture<'_, Result<E, E::Rejection>>
where
E: FromRequest<S, B, M> + 'static,
S: Send + Sync,
{
E::from_request(self, state)
}

fn extract_parts<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
where
E: FromRequestParts<()> + 'static,
{
self.extract_parts_with_state(&())
}

fn extract_parts_with_state<'a, E, S>(
&'a mut self,
state: &'a S,
) -> BoxFuture<'a, Result<E, E::Rejection>>
where
E: FromRequestParts<S> + 'static,
S: Send + Sync,
{
let mut req = Request::new(());
*req.version_mut() = self.version();
*req.method_mut() = self.method().clone();
*req.uri_mut() = self.uri().clone();
*req.headers_mut() = std::mem::take(self.headers_mut());
*req.extensions_mut() = std::mem::take(self.extensions_mut());
let (mut parts, _) = req.into_parts();
Comment on lines +88 to +94
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, this is kinda unfortunate. As an alternative, we could require B: Default and use mem::take, wdyt about that? Pretty much everything in http-body including the boxed body types seems to implement Default if the inner buffer does, the only exception being Limited but I'm sure there it can be added too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah thats probably fine. I'll change it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm actually its kinda weird having to write this on your extractor. Feels kinda arbitrary from the user perspective 🤔

@@ -180,7 +171,7 @@ mod tests {
     impl<S, B> FromRequest<S, B> for WorksForCustomExtractor
     where
         S: Send + Sync,
-        B: Send + 'static,
+        B: Send + Default + 'static,
         String: FromRef<S> + FromRequest<(), B>,
     {
         type Rejection = <String as FromRequest<(), B>>::Rejection;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I didn't consider that you'd want to use this from functions generic over B. Maybe http can be updated to provide .parts(&self) -> &Parts and .parts_mut(&mut self) -> &mut Parts on Request and Response, the current implementation of the types would trivially allow this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember there was some discussion around that a while ago hyperium/http#511.


Box::pin(async move {
let result = E::from_request_parts(&mut parts, state).await;

*self.version_mut() = parts.version;
*self.method_mut() = parts.method.clone();
*self.uri_mut() = parts.uri.clone();
*self.headers_mut() = std::mem::take(&mut parts.headers);
*self.extensions_mut() = std::mem::take(&mut parts.extensions);

result
})
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{ext_traits::tests::RequiresState, extract::State};
use async_trait::async_trait;
use axum_core::extract::FromRef;
use http::Method;
use hyper::Body;

#[tokio::test]
async fn extract_without_state() {
let req = Request::new(());

let method: Method = req.extract().await.unwrap();

assert_eq!(method, Method::GET);
}

#[tokio::test]
async fn extract_body_without_state() {
let req = Request::new(Body::from("foobar"));

let body: String = req.extract().await.unwrap();

assert_eq!(body, "foobar");
}

#[tokio::test]
async fn extract_with_state() {
let req = Request::new(());

let state = "state".to_owned();

let State(extracted_state): State<String> = req.extract_with_state(&state).await.unwrap();

assert_eq!(extracted_state, state);
}

#[tokio::test]
async fn extract_parts_without_state() {
let mut req = Request::builder().header("x-foo", "foo").body(()).unwrap();

let method: Method = req.extract_parts().await.unwrap();

assert_eq!(method, Method::GET);
assert_eq!(req.headers()["x-foo"], "foo");
}

#[tokio::test]
async fn extract_parts_with_state() {
let mut req = Request::builder().header("x-foo", "foo").body(()).unwrap();

let state = "state".to_owned();

let State(extracted_state): State<String> =
req.extract_parts_with_state(&state).await.unwrap();

assert_eq!(extracted_state, state);
assert_eq!(req.headers()["x-foo"], "foo");
}

// this stuff just needs to compile
#[allow(dead_code)]
struct WorksForCustomExtractor {
method: Method,
from_state: String,
body: String,
}

#[async_trait]
impl<S, B> FromRequest<S, B> for WorksForCustomExtractor
where
S: Send + Sync,
B: Send + 'static,
String: FromRef<S> + FromRequest<(), B>,
{
type Rejection = <String as FromRequest<(), B>>::Rejection;

async fn from_request(mut req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let RequiresState(from_state) = req.extract_parts_with_state(state).await.unwrap();
let method = req.extract_parts().await.unwrap();
let body = req.extract().await?;

Ok(Self {
method,
from_state,
body,
})
}
}
}
103 changes: 103 additions & 0 deletions axum/src/ext_traits/request_parts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
use axum_core::extract::FromRequestParts;
use futures_util::future::BoxFuture;
use http::request::Parts;

mod sealed {
pub trait Sealed {}
impl Sealed for http::request::Parts {}
}

/// Extension trait that adds additional methods to [`Parts`].
pub trait RequestPartsExt: sealed::Sealed + Sized {
/// Apply an extractor to this `Parts`.
///
/// This is just a convenience for `E::from_request_parts(parts, &())`.
fn extract<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
where
E: FromRequestParts<()> + 'static;

/// Apply an extractor that requires some state to this `Parts`.
///
/// This is just a convenience for `E::from_request_parts(parts, state)`.
fn extract_with_state<'a, E, S>(
&'a mut self,
state: &'a S,
) -> BoxFuture<'a, Result<E, E::Rejection>>
where
E: FromRequestParts<S> + 'static,
S: Send + Sync;
}

impl RequestPartsExt for Parts {
fn extract<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
where
E: FromRequestParts<()> + 'static,
{
self.extract_with_state(&())
}

fn extract_with_state<'a, E, S>(
&'a mut self,
state: &'a S,
) -> BoxFuture<'a, Result<E, E::Rejection>>
where
E: FromRequestParts<S> + 'static,
S: Send + Sync,
{
E::from_request_parts(self, state)
}
}

#[cfg(test)]
mod tests {
use std::convert::Infallible;

use super::*;
use crate::{ext_traits::tests::RequiresState, extract::State};
use async_trait::async_trait;
use axum_core::extract::FromRef;
use http::{Method, Request};

#[tokio::test]
async fn extract_without_state() {
let (mut parts, _) = Request::new(()).into_parts();

let method: Method = parts.extract().await.unwrap();

assert_eq!(method, Method::GET);
}

#[tokio::test]
async fn extract_with_state() {
let (mut parts, _) = Request::new(()).into_parts();

let state = "state".to_owned();

let State(extracted_state): State<String> = parts.extract_with_state(&state).await.unwrap();

assert_eq!(extracted_state, state);
}

// this stuff just needs to compile
#[allow(dead_code)]
struct WorksForCustomExtractor {
method: Method,
from_state: String,
}

#[async_trait]
impl<S> FromRequestParts<S> for WorksForCustomExtractor
where
S: Send + Sync,
String: FromRef<S>,
{
type Rejection = Infallible;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let RequiresState(from_state) = parts.extract_with_state(state).await?;
let method = parts.extract().await?;

Ok(Self { method, from_state })
}
}
}
3 changes: 3 additions & 0 deletions axum/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@
#[macro_use]
pub(crate) mod macros;

mod ext_traits;
mod extension;
#[cfg(feature = "form")]
mod form;
Expand Down Expand Up @@ -484,3 +485,5 @@ pub use axum_core::{BoxError, Error};

#[cfg(feature = "macros")]
pub use axum_macros::debug_handler;

pub use self::ext_traits::{request::RequestExt, request_parts::RequestPartsExt};