diff --git a/cot/src/error.rs b/cot/src/error.rs index f2990d0..dc0f1a5 100644 --- a/cot/src/error.rs +++ b/cot/src/error.rs @@ -82,6 +82,7 @@ impl_error_from_repr!(crate::forms::FormError); impl_error_from_repr!(crate::auth::AuthError); #[cfg(feature = "json")] impl_error_from_repr!(serde_json::Error); +impl_error_from_repr!(crate::request::PathParamsDeserializerError); #[derive(Debug, Error)] #[non_exhaustive] @@ -140,6 +141,9 @@ pub(crate) enum ErrorRepr { #[source] source: Box, }, + /// An error occurred while trying to parse path parameters. + #[error("Could not parse path parameters: {0}")] + PathParametersParse(#[from] crate::request::PathParamsDeserializerError), } #[cfg(test)] diff --git a/cot/src/request.rs b/cot/src/request.rs index 25bbb78..4fc1b1a 100644 --- a/cot/src/request.rs +++ b/cot/src/request.rs @@ -20,6 +20,7 @@ use bytes::Bytes; #[cfg(feature = "json")] use cot::headers::JSON_CONTENT_TYPE; use indexmap::IndexMap; +pub use path_params_deserializer::PathParamsDeserializerError; use tower_sessions::Session; #[cfg(feature = "db")] @@ -29,6 +30,8 @@ use crate::headers::FORM_CONTENT_TYPE; use crate::router::Router; use crate::{Body, Result}; +mod path_params_deserializer; + /// HTTP request type. pub type Request = http::Request; @@ -249,10 +252,39 @@ impl PathParams { self.params.insert(name, value); } + pub fn iter(&self) -> impl Iterator { + self.params + .iter() + .map(|(name, value)| (name.as_str(), value.as_str())) + } + + #[must_use] + pub fn len(&self) -> usize { + self.params.len() + } + #[must_use] pub fn get(&self, name: &str) -> Option<&str> { self.params.get(name).map(String::as_str) } + + #[must_use] + pub fn get_index(&self, index: usize) -> Option<&str> { + self.params + .get_index(index) + .map(|(_, value)| value.as_str()) + } + + #[must_use] + pub fn key_at_index(&self, index: usize) -> Option<&str> { + self.params.get_index(index).map(|(key, _)| key.as_str()) + } + + pub fn parse<'de, T: serde::Deserialize<'de>>( + &'de self, + ) -> std::result::Result { + T::deserialize(path_params_deserializer::PathParamsDeserializer::new(self)) + } } pub(crate) fn query_pairs(bytes: &Bytes) -> impl Iterator, Cow)> { @@ -264,7 +296,7 @@ mod tests { use super::*; #[tokio::test] - async fn test_form_data() { + async fn form_data() { let mut request = http::Request::builder() .method(http::Method::POST) .header(http::header::CONTENT_TYPE, FORM_CONTENT_TYPE) @@ -277,7 +309,7 @@ mod tests { #[cfg(feature = "json")] #[tokio::test] - async fn test_json() { + async fn json() { let mut request = http::Request::builder() .method(http::Method::POST) .header(http::header::CONTENT_TYPE, JSON_CONTENT_TYPE) @@ -289,7 +321,7 @@ mod tests { } #[test] - fn test_path_params() { + fn path_params() { let mut path_params = PathParams::new(); path_params.insert("name".into(), "world".into()); @@ -298,7 +330,29 @@ mod tests { } #[test] - fn test_query_pairs() { + fn path_params_parse() { + let mut path_params = PathParams::new(); + path_params.insert("hello".into(), "world".into()); + path_params.insert("foo".into(), "bar".into()); + + #[derive(Debug, PartialEq, Eq, serde::Deserialize)] + struct Params { + hello: String, + foo: String, + } + + let params: Params = path_params.parse().unwrap(); + assert_eq!( + params, + Params { + hello: "world".to_string(), + foo: "bar".to_string(), + } + ); + } + + #[test] + fn create_query_pairs() { let bytes = Bytes::from_static(b"hello=world&foo=bar"); let pairs: Vec<_> = query_pairs(&bytes).collect(); assert_eq!( diff --git a/cot/src/request/path_params_deserializer.rs b/cot/src/request/path_params_deserializer.rs new file mode 100644 index 0000000..0b7170b --- /dev/null +++ b/cot/src/request/path_params_deserializer.rs @@ -0,0 +1,855 @@ +use std::fmt::Display; + +use cot::request::PathParams; +use serde::de::{DeserializeSeed, EnumAccess, MapAccess, SeqAccess, VariantAccess, Visitor}; +use serde::Deserializer; +use thiserror::Error; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Error)] +pub enum PathParamsDeserializerError { + /// Invalid number of path parameters + #[error("Invalid number of path parameters: expected {expected}, got {actual}")] + InvalidParamNumber { expected: usize, actual: usize }, + /// A value cannot be parsed into given type. + #[error("Failed to parse value `{value}` as `{expected_type}`")] + ParseError { + value: String, + expected_type: &'static str, + }, + /// Deserialization into given type is not supported. + #[error("Deserializing `{type_name}` is not supported")] + UnsupportedType { type_name: &'static str }, + /// An error that doesn't fit any other variant. + #[error("{0}")] + Custom(String), +} + +impl PathParamsDeserializerError { + fn unsupported_type<'de, V>() -> Self + where + V: Visitor<'de>, + { + Self::UnsupportedType { + type_name: std::any::type_name::(), + } + } +} + +impl serde::de::Error for PathParamsDeserializerError { + fn custom(msg: T) -> Self + where + T: Display, + { + Self::Custom(msg.to_string()) + } +} + +#[derive(Debug)] +pub(super) struct PathParamsDeserializer<'de> { + path_params: &'de PathParams, +} + +impl<'de> PathParamsDeserializer<'de> { + #[must_use] + pub(super) fn new(path_params: &'de PathParams) -> Self { + Self { path_params } + } + + fn get_single_value(&self) -> Result<&'de str, PathParamsDeserializerError> { + self.check_param_num(1)?; + + let value = self + .path_params + .get_index(0) + .expect("we checked for len == 1"); + Ok(value) + } + + fn check_param_num(&self, expected: usize) -> Result<(), PathParamsDeserializerError> { + if self.path_params.len() == expected { + Ok(()) + } else { + Err(PathParamsDeserializerError::InvalidParamNumber { + expected, + actual: self.path_params.len(), + }) + } + } +} + +macro_rules! deserialize_value { + ($deserialize_fn_name:ident, $visit_fn_name:ident, $type_name:ident) => { + fn $deserialize_fn_name(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let value = self.get_single_value()?; + + let value = value + .parse() + .map_err(|_| PathParamsDeserializerError::ParseError { + value: value.to_string(), + expected_type: stringify!($type_name), + })?; + + visitor.$visit_fn_name(value) + } + }; +} + +macro_rules! deserialize_not_supported { + ($deserialize_fn_name:ident) => { + fn $deserialize_fn_name(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(PathParamsDeserializerError::unsupported_type::()) + } + }; +} + +impl<'de> Deserializer<'de> for PathParamsDeserializer<'de> { + type Error = PathParamsDeserializerError; + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_str(visitor) + } + + deserialize_value!(deserialize_bool, visit_bool, bool); + deserialize_value!(deserialize_i8, visit_i8, i8); + deserialize_value!(deserialize_i16, visit_i16, i16); + deserialize_value!(deserialize_i32, visit_i32, i32); + deserialize_value!(deserialize_i64, visit_i64, i64); + deserialize_value!(deserialize_i128, visit_i128, i128); + deserialize_value!(deserialize_u8, visit_u8, u8); + deserialize_value!(deserialize_u16, visit_u16, u16); + deserialize_value!(deserialize_u32, visit_u32, u32); + deserialize_value!(deserialize_u64, visit_u64, u64); + deserialize_value!(deserialize_u128, visit_u128, u128); + deserialize_value!(deserialize_f32, visit_f32, f32); + deserialize_value!(deserialize_f64, visit_f64, f64); + deserialize_value!(deserialize_char, visit_char, char); + deserialize_value!(deserialize_string, visit_string, String); + + fn deserialize_str(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let value = self.get_single_value()?; + visitor.visit_borrowed_str(value) + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let value = self.get_single_value()?; + visitor.visit_bytes(value.as_bytes()) + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let value = self.get_single_value()?; + visitor.visit_byte_buf(value.as_bytes().to_owned()) + } + + deserialize_not_supported!(deserialize_option); + + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(SequenceDeserializer::new(self.path_params)) + } + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.check_param_num(len)?; + + visitor.visit_seq(SequenceDeserializer::new(self.path_params)) + } + + fn deserialize_tuple_struct( + self, + _name: &'static str, + len: usize, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.check_param_num(len)?; + + visitor.visit_seq(SequenceDeserializer::new(self.path_params)) + } + + fn deserialize_map(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_map(MapDeserializer::new(self.path_params)) + } + + fn deserialize_struct( + self, + _name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.check_param_num(fields.len())?; + + visitor.visit_map(MapDeserializer::new(self.path_params)) + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_enum(EnumDeserializer::new(self.get_single_value()?)) + } + + deserialize_not_supported!(deserialize_identifier); + + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_any(visitor) + } +} + +#[derive(Debug)] +struct SequenceDeserializer<'de> { + path_params: &'de PathParams, + index: usize, +} + +impl<'de> SequenceDeserializer<'de> { + fn new(path_params: &'de PathParams) -> Self { + Self { + path_params, + index: 0, + } + } +} + +impl<'de> SeqAccess<'de> for SequenceDeserializer<'de> { + type Error = PathParamsDeserializerError; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: DeserializeSeed<'de>, + { + if let Some(value) = self.path_params.get_index(self.index) { + let key = self + .path_params + .key_at_index(self.index) + .expect("a value should always have a key"); + self.index += 1; + + let deserialized = seed.deserialize(ValueDeserializer::new(key, value))?; + Ok(Some(deserialized)) + } else { + Ok(None) + } + } +} + +#[derive(Debug)] +struct MapDeserializer<'de> { + path_params: &'de PathParams, + index: usize, +} + +impl<'de> MapDeserializer<'de> { + #[must_use] + fn new(path_params: &'de PathParams) -> Self { + Self { + path_params, + index: 0, + } + } +} + +impl<'de> MapAccess<'de> for MapDeserializer<'de> { + type Error = PathParamsDeserializerError; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: DeserializeSeed<'de>, + { + if let Some(key) = self.path_params.key_at_index(self.index) { + let deserialized = seed.deserialize(ValueDeserializer::new_value(key))?; + Ok(Some(deserialized)) + } else { + Ok(None) + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + let key = self + .path_params + .key_at_index(self.index) + .expect("next_key_seed should've been called first and checked if the key exists"); + let value = self + .path_params + .get_index(self.index) + .expect("next_key_seed should've been called first and checked if the value exists"); + self.index += 1; + + let deserialized = seed.deserialize(ValueDeserializer::new(key, value))?; + Ok(deserialized) + } +} + +#[derive(Debug)] +struct ValueDeserializer<'de> { + key: Option<&'de str>, + value: &'de str, +} + +impl<'de> ValueDeserializer<'de> { + #[must_use] + fn new(key: &'de str, value: &'de str) -> Self { + Self { + key: Some(key), + value, + } + } + + #[must_use] + fn new_value(value: &'de str) -> Self { + Self { key: None, value } + } + + fn get_single_value(&self) -> Result<&'de str, PathParamsDeserializerError> { + Ok(self.value) + } +} + +impl<'de> Deserializer<'de> for ValueDeserializer<'de> { + type Error = PathParamsDeserializerError; + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_str(visitor) + } + + deserialize_value!(deserialize_bool, visit_bool, bool); + deserialize_value!(deserialize_i8, visit_i8, i8); + deserialize_value!(deserialize_i16, visit_i16, i16); + deserialize_value!(deserialize_i32, visit_i32, i32); + deserialize_value!(deserialize_i64, visit_i64, i64); + deserialize_value!(deserialize_i128, visit_i128, i128); + deserialize_value!(deserialize_u8, visit_u8, u8); + deserialize_value!(deserialize_u16, visit_u16, u16); + deserialize_value!(deserialize_u32, visit_u32, u32); + deserialize_value!(deserialize_u64, visit_u64, u64); + deserialize_value!(deserialize_u128, visit_u128, u128); + deserialize_value!(deserialize_f32, visit_f32, f32); + deserialize_value!(deserialize_f64, visit_f64, f64); + deserialize_value!(deserialize_char, visit_char, char); + deserialize_value!(deserialize_string, visit_string, String); + + fn deserialize_str(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_borrowed_str(self.value) + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_borrowed_bytes(self.value.as_bytes()) + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_byte_buf(self.value.as_bytes().to_owned()) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_some(self) + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + deserialize_not_supported!(deserialize_seq); + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + if let Some(key) = self.key { + if len == 2 { + return visitor.visit_seq(ArrayDeserializer::new([key, self.value])); + } + } + + Err(PathParamsDeserializerError::unsupported_type::()) + } + + fn deserialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + Err(PathParamsDeserializerError::unsupported_type::()) + } + + deserialize_not_supported!(deserialize_map); + + fn deserialize_struct( + self, + _name: &'static str, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + Err(PathParamsDeserializerError::unsupported_type::()) + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_enum(EnumDeserializer::new(self.value)) + } + + fn deserialize_identifier(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_borrowed_str(self.value) + } + + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } +} + +#[derive(Debug)] +struct ArrayDeserializer<'de, const LEN: usize> { + sequence: [&'de str; LEN], + index: usize, +} + +impl<'de, const LEN: usize> ArrayDeserializer<'de, LEN> { + #[must_use] + fn new(sequence: [&'de str; LEN]) -> Self { + Self { sequence, index: 0 } + } +} + +impl<'de, const LEN: usize> SeqAccess<'de> for ArrayDeserializer<'de, LEN> { + type Error = PathParamsDeserializerError; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: DeserializeSeed<'de>, + { + if let Some(value) = self.sequence.get(self.index) { + self.index += 1; + + seed.deserialize(ValueDeserializer::new_value(value)) + .map(Some) + } else { + Ok(None) + } + } +} + +#[derive(Debug)] +struct EnumDeserializer<'de> { + value: &'de str, +} + +impl<'de> EnumDeserializer<'de> { + #[must_use] + fn new(value: &'de str) -> Self { + Self { value } + } +} + +impl<'de> EnumAccess<'de> for EnumDeserializer<'de> { + type Error = PathParamsDeserializerError; + type Variant = UnitVariant; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: DeserializeSeed<'de>, + { + Ok(( + seed.deserialize(ValueDeserializer::new_value(self.value))?, + UnitVariant, + )) + } +} + +#[derive(Debug)] +struct UnitVariant; + +impl<'de> VariantAccess<'de> for UnitVariant { + type Error = PathParamsDeserializerError; + + fn unit_variant(self) -> Result<(), Self::Error> { + Ok(()) + } + + fn newtype_variant_seed(self, _seed: T) -> Result + where + T: DeserializeSeed<'de>, + { + Err(PathParamsDeserializerError::UnsupportedType { + type_name: "newtype enum variant", + }) + } + + fn tuple_variant(self, _len: usize, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(PathParamsDeserializerError::UnsupportedType { + type_name: "tuple enum variant", + }) + } + + fn struct_variant( + self, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + Err(PathParamsDeserializerError::UnsupportedType { + type_name: "struct enum variant", + }) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use serde::Deserialize; + + use super::*; + + macro_rules! test_deserialize_value { + ($test_name:ident, $ty:ty, $value:literal, $expected:literal) => { + #[test] + fn $test_name() { + let path_params = create_path_params([("some_name", $value)]); + let deserializer = PathParamsDeserializer::new(&path_params); + let test_val = <$ty>::deserialize(deserializer).unwrap(); + assert_eq!(test_val, $expected); + } + }; + } + + test_deserialize_value!(deserialize_str, &str, "test", "test"); + test_deserialize_value!(deserialize_string, String, "test", "test"); + test_deserialize_value!(deserialize_bool_true, bool, "true", true); + test_deserialize_value!(deserialize_bool_false, bool, "false", false); + test_deserialize_value!(deserialize_i8, i8, "42", 42); + test_deserialize_value!(deserialize_i16, i16, "2137", 2137); + test_deserialize_value!(deserialize_i32, i32, "2137420", 2137420); + test_deserialize_value!(deserialize_i64, i64, "2137420691337", 2137_420_691_337); + test_deserialize_value!( + deserialize_i128, + i128, + "21372137213721372137", + 21372137213721372137 + ); + test_deserialize_value!(deserialize_u8, u8, "42", 42); + test_deserialize_value!(deserialize_u16, u16, "2137", 2137); + test_deserialize_value!(deserialize_u32, u32, "2137420", 2137420); + test_deserialize_value!(deserialize_u64, u64, "2137420691337", 2137420691337); + test_deserialize_value!( + deserialize_u128, + u128, + "21372137213721372137", + 21372137213721372137 + ); + test_deserialize_value!(deserialize_f32, f32, "2.137", 2.137); + test_deserialize_value!(deserialize_f64, f64, "2.137", 2.137); + test_deserialize_value!(deserialize_char, char, "a", 'a'); + + #[test] + fn deserialize_tuple() { + let path_params = create_path_params([("a", "test"), ("b", "123"), ("c", "true")]); + let actual = + <(String, i32, bool)>::deserialize(PathParamsDeserializer::new(&path_params)).unwrap(); + assert_eq!(actual, ("test".to_string(), 123, true)); + } + + #[test] + fn deserialize_tuple_pairs() { + let path_params = create_path_params([("a", "test"), ("b", "123"), ("c", "true")]); + let actual = <((String, String), (String, i32), (char, bool))>::deserialize( + PathParamsDeserializer::new(&path_params), + ) + .unwrap(); + assert_eq!( + actual, + ( + ("a".to_string(), "test".to_string()), + ("b".to_string(), 123), + ('c', true) + ) + ); + } + + #[test] + fn deserialize_vec() { + let path_params = create_path_params([("a", "1"), ("b", "2"), ("c", "3")]); + let actual = >::deserialize(PathParamsDeserializer::new(&path_params)).unwrap(); + assert_eq!(actual, vec![1, 2, 3]); + } + + #[test] + fn deserialize_struct() { + #[derive(Debug, Deserialize, PartialEq, Eq)] + struct Params { + a: String, + b: i32, + c: bool, + } + + let path_params = create_path_params([("a", "test"), ("b", "123"), ("c", "true")]); + let actual = Params::deserialize(PathParamsDeserializer::new(&path_params)).unwrap(); + assert_eq!( + actual, + Params { + a: "test".to_string(), + b: 123, + c: true, + } + ); + } + + #[test] + fn deserialize_map() { + let path_params = create_path_params([("a", "test"), ("b", "123"), ("c", "true")]); + let actual = + >::deserialize(PathParamsDeserializer::new(&path_params)).unwrap(); + assert_eq!( + actual, + HashMap::from([("a", "test"), ("b", "123"), ("c", "true")]) + ); + } + + #[test] + fn deserialize_map_ints() { + let path_params = create_path_params([("1", "123"), ("2", "456"), ("3", "789")]); + let actual = + >::deserialize(PathParamsDeserializer::new(&path_params)).unwrap(); + assert_eq!(actual, HashMap::from([(1, 123), (2, 456), (3, 789)])); + } + + #[test] + fn deserialize_enum() { + #[derive(Debug, PartialEq, Eq, Deserialize)] + enum ParamEnum { + A, + B, + } + + let path_params = create_path_params([("x", "A")]); + + let actual = ParamEnum::deserialize(PathParamsDeserializer::new(&path_params)).unwrap(); + assert_eq!(actual, ParamEnum::A); + } + + #[test] + fn deserialize_enum_vec() { + #[derive(Debug, PartialEq, Eq, Deserialize)] + enum ParamEnum { + A, + B, + #[serde(rename = "foo")] + C, + } + + let path_params = create_path_params([("x", "A"), ("y", "B"), ("z", "foo")]); + + let actual = + >::deserialize(PathParamsDeserializer::new(&path_params)).unwrap(); + assert_eq!(actual, vec![ParamEnum::A, ParamEnum::B, ParamEnum::C]); + } + + #[test] + fn deserialize_wrong_param_num_tuple_error() { + let path_params = create_path_params([("x", "a")]); + + let actual = + <(String, String)>::deserialize(PathParamsDeserializer::new(&path_params)).unwrap_err(); + assert_eq!( + actual, + PathParamsDeserializerError::InvalidParamNumber { + expected: 2, + actual: 1, + } + ); + } + + #[test] + fn deserialize_wrong_param_num_struct_error() { + #[derive(Debug, PartialEq, Eq, Deserialize)] + struct Params { + a: String, + b: String, + } + + let path_params = create_path_params([("x", "a")]); + + let actual = Params::deserialize(PathParamsDeserializer::new(&path_params)).unwrap_err(); + assert_eq!( + actual, + PathParamsDeserializerError::InvalidParamNumber { + expected: 2, + actual: 1, + } + ); + } + + #[test] + fn deserialize_parse_error() { + let path_params = create_path_params([("x", "a")]); + + let actual = i32::deserialize(PathParamsDeserializer::new(&path_params)).unwrap_err(); + assert_eq!( + actual, + PathParamsDeserializerError::ParseError { + value: "a".to_string(), + expected_type: "i32", + } + ); + } + + #[test] + fn deserialize_unsupported_type_error() { + let path_params = create_path_params([("x", "a")]); + + let actual = + >::deserialize(PathParamsDeserializer::new(&path_params)).unwrap_err(); + assert_eq!( + actual, + PathParamsDeserializerError::UnsupportedType { + type_name: "core::option::Option", + } + ); + } + + fn create_path_params(items: I) -> PathParams + where + A: ToString, + B: ToString, + I: IntoIterator, + { + let mut path_params = PathParams::new(); + for (a, b) in items { + path_params.insert(a.to_string(), b.to_string()); + } + + path_params + } +} diff --git a/examples/todo-list/src/main.rs b/examples/todo-list/src/main.rs index fb9651e..d462127 100644 --- a/examples/todo-list/src/main.rs +++ b/examples/todo-list/src/main.rs @@ -57,11 +57,7 @@ async fn add_todo(mut request: Request) -> cot::Result { } async fn remove_todo(request: Request) -> cot::Result { - let todo_id = request - .path_params() - .get("todo_id") - .expect("todo_id not found"); - let todo_id = todo_id.parse::().expect("todo_id is not a number"); + let todo_id: i32 = request.path_params().parse()?; { query!(TodoItem, $id == todo_id)