diff --git a/axum/Cargo.toml b/axum/Cargo.toml index 5795104e1a..f3d0f30d2e 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -53,7 +53,7 @@ tokio-tungstenite = { optional = true, version = "0.16" } [dev-dependencies] futures = "0.3" -reqwest = { version = "0.11", default-features = false, features = ["json", "stream"] } +reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "multipart"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread", "net"] } diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index c3366ee04e..d9bad33ed1 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -8,14 +8,13 @@ use crate::BoxError; use async_trait::async_trait; use futures_util::stream::Stream; use http::header::{HeaderMap, CONTENT_TYPE}; -use mime::Mime; use std::{ fmt, pin::Pin, task::{Context, Poll}, }; -/// Extractor that parses `multipart/form-data` requests commonly used with file uploads. +/// Extractor that parses `multipart/form-data` requests (commonly used with file uploads). /// /// # Example /// @@ -42,7 +41,7 @@ use std::{ /// # }; /// ``` /// -/// For security reasons its recommended to combine this with +/// For security reasons it's recommended to combine this with /// [`ContentLengthLimit`](super::ContentLengthLimit) to limit the size of the request payload. #[derive(Debug)] pub struct Multipart { @@ -120,9 +119,9 @@ impl<'a> Field<'a> { self.inner.file_name() } - /// Get the content type of the field. - pub fn content_type(&self) -> Option<&Mime> { - self.inner.content_type() + /// Get the [content type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) of the field. + pub fn content_type(&self) -> Option<&str> { + self.inner.content_type().map(|m| m.as_ref()) } /// Get a map of headers as [`HeaderMap`]. @@ -191,3 +190,40 @@ define_rejection! { /// missing or invalid. pub struct InvalidBoundary; } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{response::IntoResponse, routing::post, test_helpers::*, Router}; + + #[tokio::test] + async fn content_type_with_encoding() { + const BYTES: &[u8] = "🦀".as_bytes(); + const FILE_NAME: &str = "index.html"; + const CONTENT_TYPE: &str = "text/html; charset=utf-8"; + + async fn handle(mut multipart: Multipart) -> impl IntoResponse { + let field = multipart.next_field().await.unwrap().unwrap(); + + assert_eq!(field.file_name().unwrap(), FILE_NAME); + assert_eq!(field.content_type().unwrap(), CONTENT_TYPE); + assert_eq!(field.bytes().await.unwrap(), BYTES); + + assert!(multipart.next_field().await.unwrap().is_none()); + } + + let app = Router::new().route("/", post(handle)); + + let client = TestClient::new(app); + + let form = reqwest::multipart::Form::new().part( + "file", + reqwest::multipart::Part::bytes(BYTES) + .file_name(FILE_NAME) + .mime_str(CONTENT_TYPE) + .unwrap(), + ); + + client.post("/").multipart(form).send().await; + } +} diff --git a/axum/src/test_helpers.rs b/axum/src/test_helpers.rs index fef2f08b39..0113ec8fd4 100644 --- a/axum/src/test_helpers.rs +++ b/axum/src/test_helpers.rs @@ -100,6 +100,7 @@ impl RequestBuilder { self.builder = self.builder.json(json); self } + pub(crate) fn header(mut self, key: K, value: V) -> Self where HeaderName: TryFrom, @@ -110,6 +111,11 @@ impl RequestBuilder { self.builder = self.builder.header(key, value); self } + + pub(crate) fn multipart(mut self, form: reqwest::multipart::Form) -> Self { + self.builder = self.builder.multipart(form); + self + } } pub(crate) struct TestResponse { diff --git a/examples/multipart-form/src/main.rs b/examples/multipart-form/src/main.rs index b734e65a5f..6d1edf1ef8 100644 --- a/examples/multipart-form/src/main.rs +++ b/examples/multipart-form/src/main.rs @@ -65,8 +65,16 @@ async fn accept_form( ) { while let Some(field) = multipart.next_field().await.unwrap() { let name = field.name().unwrap().to_string(); + let file_name = field.file_name().unwrap().to_string(); + let content_type = field.content_type().unwrap().to_string(); let data = field.bytes().await.unwrap(); - println!("Length of `{}` is {} bytes", name, data.len()); + println!( + "Length of `{}` (`{}`: `{}`) is {} bytes", + name, + file_name, + content_type, + data.len() + ); } }