From a10f6dc2aa6eead11ee289974a590fa92157c61f Mon Sep 17 00:00:00 2001 From: mikhailantoshkin Date: Tue, 14 Nov 2023 07:57:47 +0900 Subject: [PATCH] Move OptionalQuery to axum_extra::query --- axum-extra/src/extract/mod.rs | 8 +- axum-extra/src/extract/optional_query.rs | 189 --------------------- axum-extra/src/extract/query.rs | 203 ++++++++++++++++++++++- 3 files changed, 203 insertions(+), 197 deletions(-) delete mode 100644 axum-extra/src/extract/optional_query.rs diff --git a/axum-extra/src/extract/mod.rs b/axum-extra/src/extract/mod.rs index 43c1835e0e..8435fc8422 100644 --- a/axum-extra/src/extract/mod.rs +++ b/axum-extra/src/extract/mod.rs @@ -10,9 +10,6 @@ mod form; #[cfg(feature = "cookie")] pub mod cookie; -#[cfg(feature = "query")] -mod optional_query; - #[cfg(feature = "query")] mod query; @@ -34,10 +31,7 @@ pub use self::cookie::SignedCookieJar; pub use self::form::{Form, FormRejection}; #[cfg(feature = "query")] -pub use self::optional_query::{OptionalQuery, OptionalQueryRejection}; - -#[cfg(feature = "query")] -pub use self::query::{Query, QueryRejection}; +pub use self::query::{OptionalQuery, OptionalQueryRejection, Query, QueryRejection}; #[cfg(feature = "multipart")] pub use self::multipart::Multipart; diff --git a/axum-extra/src/extract/optional_query.rs b/axum-extra/src/extract/optional_query.rs deleted file mode 100644 index c25ef19803..0000000000 --- a/axum-extra/src/extract/optional_query.rs +++ /dev/null @@ -1,189 +0,0 @@ -use axum::{ - async_trait, - extract::FromRequestParts, - response::{IntoResponse, Response}, - Error, -}; -use http::{request::Parts, StatusCode}; -use serde::de::DeserializeOwned; -use std::fmt; - -/// Extractor that deserializes query strings into `None` if no query parameters are present and `Some(T)` otherwise. -/// -/// `T` is expected to implement [`serde::Deserialize`]. -/// -/// # Example -/// -/// ```rust,no_run -/// use axum::{routing::get, Router}; -/// use axum_extra::extract::OptionalQuery; -/// use serde::Deserialize; -/// -/// #[derive(Deserialize)] -/// struct Pagination { -/// page: usize, -/// per_page: usize, -/// } -/// -/// // This will parse query strings like `?page=2&per_page=30` into `Some(Pagination)` and -/// // empty query string into `None` -/// async fn list_things(OptionalQuery(pagination): OptionalQuery) { -/// match pagination { -/// Some(Pagination{page, per_page}) => { /* return specified page */ }, -/// None => { /* return fist page */ } -/// } -/// // ... -/// } -/// -/// let app = Router::new().route("/list_things", get(list_things)); -/// # let _: Router = app; -/// ``` -/// -/// If the query string cannot be parsed it will reject the request with a `400 -/// Bad Request` response. -/// -/// For handling values being empty vs missing see the [query-params-with-empty-strings][example] -/// example. -/// -/// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs -#[cfg_attr(docsrs, doc(cfg(feature = "query")))] -#[derive(Debug, Clone, Copy, Default)] -pub struct OptionalQuery(pub Option); - -#[async_trait] -impl FromRequestParts for OptionalQuery -where - T: DeserializeOwned, - S: Send + Sync, -{ - type Rejection = OptionalQueryRejection; - - async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { - if let Some(query) = parts.uri.query() { - let value = serde_html_form::from_str(query).map_err(|err| { - OptionalQueryRejection::FailedToDeserializeQueryString(Error::new(err)) - })?; - Ok(OptionalQuery(value)) - } else { - Ok(OptionalQuery(None)) - } - } -} - -impl std::ops::Deref for OptionalQuery { - type Target = Option; - - #[inline] - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl std::ops::DerefMut for OptionalQuery { - #[inline] - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -/// Rejection used for [`OptionalQuery`]. -/// -/// Contains one variant for each way the [`OptionalQuery`] extractor can fail. -#[derive(Debug)] -#[non_exhaustive] -#[cfg(feature = "query")] -pub enum OptionalQueryRejection { - #[allow(missing_docs)] - FailedToDeserializeQueryString(Error), -} - -impl IntoResponse for OptionalQueryRejection { - fn into_response(self) -> Response { - match self { - Self::FailedToDeserializeQueryString(inner) => ( - StatusCode::BAD_REQUEST, - format!("Failed to deserialize query string: {inner}"), - ) - .into_response(), - } - } -} - -impl fmt::Display for OptionalQueryRejection { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::FailedToDeserializeQueryString(inner) => inner.fmt(f), - } - } -} - -impl std::error::Error for OptionalQueryRejection { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - Self::FailedToDeserializeQueryString(inner) => Some(inner), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::test_helpers::*; - use axum::{routing::post, Router}; - use http::{header::CONTENT_TYPE, StatusCode}; - use serde::Deserialize; - - #[tokio::test] - async fn no_parameters_deserialize_into_none() { - #[derive(Deserialize)] - struct Data { - value: String, - } - - let app = Router::new().route( - "/", - post(|OptionalQuery(data): OptionalQuery| async move { - match data { - None => "None".into(), - Some(data) => data.value, - } - }), - ); - - let client = TestClient::new(app); - - let res = client.post("/").body("").send().await; - - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await, "None"); - } - - #[tokio::test] - async fn parsing_errors_are_preserved() { - #[derive(Deserialize)] - struct Data { - value: String, - } - - let app = Router::new().route( - "/", - post(|OptionalQuery(data): OptionalQuery| async move { - match data { - None => "None".into(), - Some(data) => data.value, - } - }), - ); - - let client = TestClient::new(app); - - let res = client - .post("/?other=something") - .header(CONTENT_TYPE, "application/x-www-form-urlencoded") - .body("") - .send() - .await; - - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - } -} diff --git a/axum-extra/src/extract/query.rs b/axum-extra/src/extract/query.rs index b4f5bebdc2..9a2ee81a5f 100644 --- a/axum-extra/src/extract/query.rs +++ b/axum-extra/src/extract/query.rs @@ -112,6 +112,124 @@ impl std::error::Error for QueryRejection { } } +/// Extractor that deserializes query strings into `None` if no query parameters are present. +/// Otherwise behaviour is identical to [`Query`] +/// +/// `T` is expected to implement [`serde::Deserialize`]. +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::{routing::get, Router}; +/// use axum_extra::extract::OptionalQuery; +/// use serde::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct Pagination { +/// page: usize, +/// per_page: usize, +/// } +/// +/// // This will parse query strings like `?page=2&per_page=30` into `Some(Pagination)` and +/// // empty query string into `None` +/// async fn list_things(OptionalQuery(pagination): OptionalQuery) { +/// match pagination { +/// Some(Pagination{ page, per_page }) => { /* return specified page */ }, +/// None => { /* return fist page */ } +/// } +/// // ... +/// } +/// +/// let app = Router::new().route("/list_things", get(list_things)); +/// # let _: Router = app; +/// ``` +/// +/// If the query string cannot be parsed it will reject the request with a `400 +/// Bad Request` response. +/// +/// For handling values being empty vs missing see the [query-params-with-empty-strings][example] +/// example. +/// +/// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs +#[cfg_attr(docsrs, doc(cfg(feature = "query")))] +#[derive(Debug, Clone, Copy, Default)] +pub struct OptionalQuery(pub Option); + +#[async_trait] +impl FromRequestParts for OptionalQuery +where + T: DeserializeOwned, + S: Send + Sync, +{ + type Rejection = OptionalQueryRejection; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + if let Some(query) = parts.uri.query() { + let value = serde_html_form::from_str(query).map_err(|err| { + OptionalQueryRejection::FailedToDeserializeQueryString(Error::new(err)) + })?; + Ok(OptionalQuery(Some(value))) + } else { + Ok(OptionalQuery(None)) + } + } +} + +impl std::ops::Deref for OptionalQuery { + type Target = Option; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for OptionalQuery { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +/// Rejection used for [`OptionalQuery`]. +/// +/// Contains one variant for each way the [`OptionalQuery`] extractor can fail. +#[derive(Debug)] +#[non_exhaustive] +#[cfg(feature = "query")] +pub enum OptionalQueryRejection { + #[allow(missing_docs)] + FailedToDeserializeQueryString(Error), +} + +impl IntoResponse for OptionalQueryRejection { + fn into_response(self) -> Response { + match self { + Self::FailedToDeserializeQueryString(inner) => ( + StatusCode::BAD_REQUEST, + format!("Failed to deserialize query string: {inner}"), + ) + .into_response(), + } + } +} + +impl fmt::Display for OptionalQueryRejection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::FailedToDeserializeQueryString(inner) => inner.fmt(f), + } + } +} + +impl std::error::Error for OptionalQueryRejection { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::FailedToDeserializeQueryString(inner) => Some(inner), + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -121,7 +239,7 @@ mod tests { use serde::Deserialize; #[tokio::test] - async fn supports_multiple_values() { + async fn query_supports_multiple_values() { #[derive(Deserialize)] struct Data { #[serde(rename = "value")] @@ -145,4 +263,87 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "one,two"); } + + #[tokio::test] + async fn optional_query_supports_multiple_values() { + #[derive(Deserialize)] + struct Data { + #[serde(rename = "value")] + values: Vec, + } + + let app = Router::new().route( + "/", + post(|OptionalQuery(data): OptionalQuery| async move { + data.map(|Data { values }| values.join(",")) + .unwrap_or("None".to_string()) + }), + ); + + let client = TestClient::new(app); + + let res = client + .post("/?value=one&value=two") + .header(CONTENT_TYPE, "application/x-www-form-urlencoded") + .body("") + .send() + .await; + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "one,two"); + } + + #[tokio::test] + async fn optional_query_deserializes_no_parameters_into_none() { + #[derive(Deserialize)] + struct Data { + value: String, + } + + let app = Router::new().route( + "/", + post(|OptionalQuery(data): OptionalQuery| async move { + match data { + None => "None".into(), + Some(data) => data.value, + } + }), + ); + + let client = TestClient::new(app); + + let res = client.post("/").body("").send().await; + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "None"); + } + + #[tokio::test] + async fn optional_query_preserves_parsing_errors() { + #[derive(Deserialize)] + struct Data { + value: String, + } + + let app = Router::new().route( + "/", + post(|OptionalQuery(data): OptionalQuery| async move { + match data { + None => "None".into(), + Some(data) => data.value, + } + }), + ); + + let client = TestClient::new(app); + + let res = client + .post("/?other=something") + .header(CONTENT_TYPE, "application/x-www-form-urlencoded") + .body("") + .send() + .await; + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + } }