Skip to content

Commit

Permalink
feat: add path params deserializer
Browse files Browse the repository at this point in the history
This makes it much easier to retrieve the path params from the URLs, by
implementing a clean API based on serde that is capable of deserializing
tuples, structs, and single values while providing a nice error handling.
  • Loading branch information
m4tx committed Jan 27, 2025
1 parent 2be5248 commit 5b73ca8
Show file tree
Hide file tree
Showing 4 changed files with 923 additions and 9 deletions.
4 changes: 4 additions & 0 deletions cot/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ impl_error_from_repr!(crate::forms::FormError);
impl_error_from_repr!(crate::auth::AuthError);
#[cfg(feature = "json")]
impl_error_from_repr!(serde_json::Error);
impl_error_from_repr!(crate::request::PathParamsDeserializerError);

#[derive(Debug, Error)]
#[non_exhaustive]
Expand Down Expand Up @@ -140,6 +141,9 @@ pub(crate) enum ErrorRepr {
#[source]
source: Box<dyn std::error::Error + Send + Sync>,
},
/// An error occurred while trying to parse path parameters.
#[error("Could not parse path parameters: {0}")]
PathParametersParse(#[from] crate::request::PathParamsDeserializerError),
}

#[cfg(test)]
Expand Down
67 changes: 63 additions & 4 deletions cot/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use bytes::Bytes;
#[cfg(feature = "json")]
use cot::headers::JSON_CONTENT_TYPE;
use indexmap::IndexMap;
pub use path_params_deserializer::PathParamsDeserializerError;
use tower_sessions::Session;

#[cfg(feature = "db")]
Expand All @@ -29,6 +30,8 @@ use crate::headers::FORM_CONTENT_TYPE;
use crate::router::Router;
use crate::{Body, Result};

mod path_params_deserializer;

/// HTTP request type.
pub type Request = http::Request<Body>;

Expand Down Expand Up @@ -249,10 +252,44 @@ impl PathParams {
self.params.insert(name, value);
}

pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
self.params
.iter()
.map(|(name, value)| (name.as_str(), value.as_str()))
}

#[must_use]
pub fn len(&self) -> usize {
self.params.len()
}

#[must_use]
pub fn is_empty(&self) -> bool {
self.params.is_empty()
}

#[must_use]
pub fn get(&self, name: &str) -> Option<&str> {
self.params.get(name).map(String::as_str)
}

#[must_use]
pub fn get_index(&self, index: usize) -> Option<&str> {
self.params
.get_index(index)
.map(|(_, value)| value.as_str())
}

#[must_use]
pub fn key_at_index(&self, index: usize) -> Option<&str> {
self.params.get_index(index).map(|(key, _)| key.as_str())
}

pub fn parse<'de, T: serde::Deserialize<'de>>(
&'de self,
) -> std::result::Result<T, PathParamsDeserializerError> {
T::deserialize(path_params_deserializer::PathParamsDeserializer::new(self))
}
}

pub(crate) fn query_pairs(bytes: &Bytes) -> impl Iterator<Item = (Cow<str>, Cow<str>)> {
Expand All @@ -264,7 +301,7 @@ mod tests {
use super::*;

#[tokio::test]
async fn test_form_data() {
async fn form_data() {
let mut request = http::Request::builder()
.method(http::Method::POST)
.header(http::header::CONTENT_TYPE, FORM_CONTENT_TYPE)
Expand All @@ -277,7 +314,7 @@ mod tests {

#[cfg(feature = "json")]
#[tokio::test]
async fn test_json() {
async fn json() {
let mut request = http::Request::builder()
.method(http::Method::POST)
.header(http::header::CONTENT_TYPE, JSON_CONTENT_TYPE)
Expand All @@ -289,7 +326,7 @@ mod tests {
}

#[test]
fn test_path_params() {
fn path_params() {
let mut path_params = PathParams::new();
path_params.insert("name".into(), "world".into());

Expand All @@ -298,7 +335,29 @@ mod tests {
}

#[test]
fn test_query_pairs() {
fn path_params_parse() {
let mut path_params = PathParams::new();
path_params.insert("hello".into(), "world".into());
path_params.insert("foo".into(), "bar".into());

#[derive(Debug, PartialEq, Eq, serde::Deserialize)]
struct Params {
hello: String,
foo: String,
}

let params: Params = path_params.parse().unwrap();
assert_eq!(
params,
Params {
hello: "world".to_string(),
foo: "bar".to_string(),
}
);
}

#[test]
fn create_query_pairs() {
let bytes = Bytes::from_static(b"hello=world&foo=bar");
let pairs: Vec<_> = query_pairs(&bytes).collect();
assert_eq!(
Expand Down
Loading

0 comments on commit 5b73ca8

Please sign in to comment.