Skip to content

Commit

Permalink
Change HeaderMap extractor to clone the headers (#698)
Browse files Browse the repository at this point in the history
* Change `HeaderMap` extractor to clone the headers

* fix docs

* changelog

* inline variable

* also add changelog item to axum

* don't list types from axum in axum-core's changelog

* document that `HeaderMap::from_request` clones the headers

* fix typo

* a few more typos
  • Loading branch information
davidpdrsn authored Jan 11, 2022
1 parent 216bdcf commit 5beab52
Show file tree
Hide file tree
Showing 16 changed files with 92 additions and 144 deletions.
17 changes: 16 additions & 1 deletion axum-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

# Unreleased

- None.
- **breaking:** Using `HeaderMap` as an extractor will no longer remove the headers and thus
they'll still be accessible to other extractors, such as `axum::extract::Json`. Instead
`HeaderMap` will clone the headers. You should prefer to use `TypedHeader` to extract only the
headers you need ([#698])

This includes these breaking changes:
- `RequestParts::take_headers` has been removed.
- `RequestParts::headers` returns `&HeaderMap`.
- `RequestParts::headers_mut` returns `&mut HeaderMap`.
- `HeadersAlreadyExtracted` has been removed.
- The `HeadersAlreadyExtracted` variant has been removed from these rejections:
- `RequestAlreadyExtracted`
- `RequestPartsAlreadyExtracted`
- `<HeaderMap as FromRequest<_>>::Error` has been changed to `std::convert::Infallible`.

[#698]: https://github.com/tokio-rs/axum/pull/698

# 0.1.1 (06. December, 2021)

Expand Down
35 changes: 8 additions & 27 deletions axum-core/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub struct RequestParts<B> {
method: Method,
uri: Uri,
version: Version,
headers: Option<HeaderMap>,
headers: HeaderMap,
extensions: Option<Extensions>,
body: Option<B>,
}
Expand Down Expand Up @@ -107,7 +107,7 @@ impl<B> RequestParts<B> {
method,
uri,
version,
headers: Some(headers),
headers,
extensions: Some(extensions),
body: Some(body),
}
Expand All @@ -117,22 +117,19 @@ impl<B> RequestParts<B> {
///
/// Fails if
///
/// - The full [`HeaderMap`] has been extracted, that is [`take_headers`]
/// have been called.
/// - The full [`Extensions`] has been extracted, that is
/// [`take_extensions`] have been called.
/// - The request body has been extracted, that is [`take_body`] have been
/// called.
///
/// [`take_headers`]: RequestParts::take_headers
/// [`take_extensions`]: RequestParts::take_extensions
/// [`take_body`]: RequestParts::take_body
pub fn try_into_request(self) -> Result<Request<B>, RequestAlreadyExtracted> {
let Self {
method,
uri,
version,
mut headers,
headers,
mut extensions,
mut body,
} = self;
Expand All @@ -148,14 +145,7 @@ impl<B> RequestParts<B> {
*req.method_mut() = method;
*req.uri_mut() = uri;
*req.version_mut() = version;

if let Some(headers) = headers.take() {
*req.headers_mut() = headers;
} else {
return Err(RequestAlreadyExtracted::HeadersAlreadyExtracted(
HeadersAlreadyExtracted,
));
}
*req.headers_mut() = headers;

if let Some(extensions) = extensions.take() {
*req.extensions_mut() = extensions;
Expand Down Expand Up @@ -199,22 +189,13 @@ impl<B> RequestParts<B> {
}

/// Gets a reference to the request headers.
///
/// Returns `None` if the headers has been taken by another extractor.
pub fn headers(&self) -> Option<&HeaderMap> {
self.headers.as_ref()
pub fn headers(&self) -> &HeaderMap {
&self.headers
}

/// Gets a mutable reference to the request headers.
///
/// Returns `None` if the headers has been taken by another extractor.
pub fn headers_mut(&mut self) -> Option<&mut HeaderMap> {
self.headers.as_mut()
}

/// Takes the headers out of the request, leaving a `None` in its place.
pub fn take_headers(&mut self) -> Option<HeaderMap> {
self.headers.take()
pub fn headers_mut(&mut self) -> &mut HeaderMap {
&mut self.headers
}

/// Gets a reference to the request extensions.
Expand Down
9 changes: 0 additions & 9 deletions axum-core/src/extract/rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,6 @@ define_rejection! {
pub struct BodyAlreadyExtracted;
}

define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Headers taken by other extractor"]
/// Rejection used if the headers has been taken by another extractor.
pub struct HeadersAlreadyExtracted;
}

define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Extensions taken by other extractor"]
Expand Down Expand Up @@ -47,7 +40,6 @@ composite_rejection! {
/// [`Request<_>`]: http::Request
pub enum RequestAlreadyExtracted {
BodyAlreadyExtracted,
HeadersAlreadyExtracted,
ExtensionsAlreadyExtracted,
}
}
Expand Down Expand Up @@ -79,7 +71,6 @@ composite_rejection! {
///
/// Contains one variant for each way the [`http::request::Parts`] extractor can fail.
pub enum RequestPartsAlreadyExtracted {
HeadersAlreadyExtracted,
ExtensionsAlreadyExtracted,
}
}
16 changes: 12 additions & 4 deletions axum-core/src/extract/request_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ where
method: req.method.clone(),
version: req.version,
uri: req.uri.clone(),
headers: None,
headers: HeaderMap::new(),
extensions: None,
body: None,
},
Expand Down Expand Up @@ -65,15 +65,20 @@ where
}
}

/// Clone the headers from the request.
///
/// Prefer using [`TypedHeader`] to extract only the headers you need.
///
/// [`TypedHeader`]: https://docs.rs/axum/latest/axum/extract/struct.TypedHeader.html
#[async_trait]
impl<B> FromRequest<B> for HeaderMap
where
B: Send,
{
type Rejection = HeadersAlreadyExtracted;
type Rejection = Infallible;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
req.take_headers().ok_or(HeadersAlreadyExtracted)
Ok(req.headers().clone())
}
}

Expand Down Expand Up @@ -143,7 +148,10 @@ where
let method = unwrap_infallible(Method::from_request(req).await);
let uri = unwrap_infallible(Uri::from_request(req).await);
let version = unwrap_infallible(Version::from_request(req).await);
let headers = HeaderMap::from_request(req).await?;
let headers = match HeaderMap::from_request(req).await {
Ok(headers) => headers,
Err(err) => match err {},
};
let extensions = Extensions::from_request(req).await?;

let mut temp_request = Request::new(());
Expand Down
19 changes: 19 additions & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,28 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
overwriting old values.
- **breaking:** Require `Output = ()` on `WebSocketStream::on_upgrade` ([#644])
- **breaking:** Make `TypedHeaderRejectionReason` `#[non_exhaustive]` ([#665])
- **breaking:** Using `HeaderMap` as an extractor will no longer remove the headers and thus
they'll still be accessible to other extractors, such as `axum::extract::Json`. Instead
`HeaderMap` will clone the headers. You should prefer to use `TypedHeader` to extract only the
headers you need ([#698])

This includes these breaking changes:
- `RequestParts::take_headers` has been removed.
- `RequestParts::headers` returns `&HeaderMap`.
- `RequestParts::headers_mut` returns `&mut HeaderMap`.
- `HeadersAlreadyExtracted` has been removed.
- The `HeadersAlreadyExtracted` removed variant has been removed from these rejections:
- `RequestAlreadyExtracted`
- `RequestPartsAlreadyExtracted`
- `JsonRejection`
- `FormRejection`
- `ContentLengthLimitRejection`
- `WebSocketUpgradeRejection`
- `<HeaderMap as FromRequest<_>>::Error` has been changed to `std::convert::Infallible`.

[#644]: https://github.com/tokio-rs/axum/pull/644
[#665]: https://github.com/tokio-rs/axum/pull/665
[#698]: https://github.com/tokio-rs/axum/pull/698

# 0.4.3 (21. December, 2021)

Expand Down
8 changes: 1 addition & 7 deletions axum/src/docs/extract.md
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,6 @@ async fn handler(result: Result<Json<Value>, JsonRejection>) -> impl IntoRespons
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to buffer request body".to_string(),
)),
JsonRejection::HeadersAlreadyExtracted(_) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Headers already extracted".to_string(),
)),
// we must provide a catch-all case since `JsonRejection` is marked
// `#[non_exhaustive]`
_ => Err((
Expand Down Expand Up @@ -377,9 +373,7 @@ where
type Rejection = (StatusCode, &'static str);
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let user_agent = req.headers().and_then(|headers| headers.get(USER_AGENT));
if let Some(user_agent) = user_agent {
if let Some(user_agent) = req.headers().get(USER_AGENT) {
Ok(ExtractUserAgent(user_agent.clone()))
} else {
Err((StatusCode::BAD_REQUEST, "`User-Agent` header is missing"))
Expand Down
9 changes: 1 addition & 8 deletions axum/src/extract/content_length_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,7 @@ where
type Rejection = ContentLengthLimitRejection<T::Rejection>;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let content_length = req
.headers()
.ok_or_else(|| {
ContentLengthLimitRejection::HeadersAlreadyExtracted(
HeadersAlreadyExtracted::default(),
)
})?
.get(http::header::CONTENT_LENGTH);
let content_length = req.headers().get(http::header::CONTENT_LENGTH);

let content_length =
content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
Expand Down
3 changes: 1 addition & 2 deletions axum/src/extract/extractor_middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ use tower_service::Service;
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
/// let auth_header = req
/// .headers()
/// .and_then(|headers| headers.get(http::header::AUTHORIZATION))
/// .get(http::header::AUTHORIZATION)
/// .and_then(|value| value.to_str().ok());
///
/// match auth_header {
Expand Down Expand Up @@ -291,7 +291,6 @@ mod tests {
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
if let Some(auth) = req
.headers()
.expect("headers already extracted")
.get("authorization")
.and_then(|v| v.to_str().ok())
{
Expand Down
2 changes: 1 addition & 1 deletion axum/src/extract/form.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ where
.map_err(FailedToDeserializeQueryString::new::<T, _>)?;
Ok(Form(value))
} else {
if !has_content_type(req, &mime::APPLICATION_WWW_FORM_URLENCODED)? {
if !has_content_type(req, &mime::APPLICATION_WWW_FORM_URLENCODED) {
return Err(InvalidFormContentType.into());
}

Expand Down
14 changes: 5 additions & 9 deletions axum/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,24 +78,20 @@ pub use self::typed_header::TypedHeader;
pub(crate) fn has_content_type<B>(
req: &RequestParts<B>,
expected_content_type: &mime::Mime,
) -> Result<bool, HeadersAlreadyExtracted> {
let content_type = if let Some(content_type) = req
.headers()
.ok_or_else(HeadersAlreadyExtracted::default)?
.get(header::CONTENT_TYPE)
{
) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
content_type
} else {
return Ok(false);
return false;
};

let content_type = if let Ok(content_type) = content_type.to_str() {
content_type
} else {
return Ok(false);
return false;
};

Ok(content_type.starts_with(expected_content_type.as_ref()))
content_type.starts_with(expected_content_type.as_ref())
}

pub(crate) fn take_body<B>(req: &mut RequestParts<B>) -> Result<B, BodyAlreadyExtracted> {
Expand Down
3 changes: 1 addition & 2 deletions axum/src/extract/multipart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ where

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let stream = BodyStream::from_request(req).await?;
let headers = req.headers().ok_or_else(HeadersAlreadyExtracted::default)?;
let headers = req.headers();
let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?;
let multipart = multer::Multipart::new(stream, boundary);
Ok(Self { inner: multipart })
Expand Down Expand Up @@ -179,7 +179,6 @@ composite_rejection! {
pub enum MultipartRejection {
BodyAlreadyExtracted,
InvalidBoundary,
HeadersAlreadyExtracted,
}
}

Expand Down
7 changes: 0 additions & 7 deletions axum/src/extract/rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ composite_rejection! {
InvalidFormContentType,
FailedToDeserializeQueryString,
BytesRejection,
HeadersAlreadyExtracted,
}
}

Expand All @@ -139,7 +138,6 @@ composite_rejection! {
InvalidJsonBody,
MissingJsonContentType,
BytesRejection,
HeadersAlreadyExtracted,
}
}

Expand Down Expand Up @@ -195,8 +193,6 @@ pub enum ContentLengthLimitRejection<T> {
#[allow(missing_docs)]
LengthRequired(LengthRequired),
#[allow(missing_docs)]
HeadersAlreadyExtracted(HeadersAlreadyExtracted),
#[allow(missing_docs)]
Inner(T),
}

Expand All @@ -208,7 +204,6 @@ where
match self {
Self::PayloadTooLarge(inner) => inner.into_response(),
Self::LengthRequired(inner) => inner.into_response(),
Self::HeadersAlreadyExtracted(inner) => inner.into_response(),
Self::Inner(inner) => inner.into_response(),
}
}
Expand All @@ -222,7 +217,6 @@ where
match self {
Self::PayloadTooLarge(inner) => inner.fmt(f),
Self::LengthRequired(inner) => inner.fmt(f),
Self::HeadersAlreadyExtracted(inner) => inner.fmt(f),
Self::Inner(inner) => inner.fmt(f),
}
}
Expand All @@ -236,7 +230,6 @@ where
match self {
Self::PayloadTooLarge(inner) => Some(inner),
Self::LengthRequired(inner) => Some(inner),
Self::HeadersAlreadyExtracted(inner) => Some(inner),
Self::Inner(inner) => Some(inner),
}
}
Expand Down
11 changes: 1 addition & 10 deletions axum/src/extract/typed_header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,7 @@ where
type Rejection = TypedHeaderRejection;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let headers = if let Some(headers) = req.headers() {
headers
} else {
return Err(TypedHeaderRejection {
name: T::name(),
reason: TypedHeaderRejectionReason::Missing,
});
};

match headers.typed_try_get::<T>() {
match req.headers().typed_try_get::<T>() {
Ok(Some(value)) => Ok(Self(value)),
Ok(None) => Err(TypedHeaderRejection {
name: T::name(),
Expand Down
Loading

0 comments on commit 5beab52

Please sign in to comment.