Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added #[serde(case_insensitive)] container attribute for case-insensitive identifier deserialization #1902

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
205 changes: 205 additions & 0 deletions serde/src/private/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2665,6 +2665,22 @@ where
}
}

#[cfg(any(feature = "std", feature = "alloc"))]
pub struct CaseInsensitiveFlatMapDeserializer<'a, 'de: 'a, E>(
pub &'a mut Vec<Option<(Content<'de>, Content<'de>)>>,
pub PhantomData<E>,
);

#[cfg(any(feature = "std", feature = "alloc"))]
impl<'a, 'de, E> CaseInsensitiveFlatMapDeserializer<'a, 'de, E>
where
E: Error,
{
fn deserialize_other<V>() -> Result<V, E> {
Err(Error::custom("can only flatten structs and maps"))
}
}

#[cfg(any(feature = "std", feature = "alloc"))]
macro_rules! forward_to_deserialize_other {
($($func:ident ( $($arg:ty),* ))*) => {
Expand Down Expand Up @@ -2796,6 +2812,128 @@ where
}
}

#[cfg(any(feature = "std", feature = "alloc"))]
impl<'a, 'de, E> Deserializer<'de> for CaseInsensitiveFlatMapDeserializer<'a, 'de, E>
where
E: Error,
{
type Error = E;

fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_map(FlatInternallyTaggedAccess {
iter: self.0.iter_mut(),
pending: None,
_marker: PhantomData,
})
}

fn deserialize_enum<V>(
self,
name: &'static str,
variants: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
for item in self.0.iter_mut() {
// items in the vector are nulled out when used. So we can only use
// an item if it's still filled in and if the field is one we care
// about.
let use_item = match *item {
None => false,
Some((ref c, _)) => c.as_str().map_or(false, |x| {
variants.iter().any(|v| v.eq_ignore_ascii_case(x))
}),
};

if use_item {
let (key, value) = item.take().unwrap();
return visitor.visit_enum(EnumDeserializer::new(key, Some(value)));
}
}

Err(Error::custom(format_args!(
"no variant of enum {} found in flattened data",
name
)))
}

fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_map(FlatMapAccess::new(self.0.iter()))
}

fn deserialize_struct<V>(
self,
_: &'static str,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_map(CaseInsensitiveFlatStructAccess::new(
self.0.iter_mut(),
fields,
))
}

fn deserialize_newtype_struct<V>(self, _name: &str, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}

fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
match visitor.__private_visit_untagged_option(self) {
Ok(value) => Ok(value),
Err(()) => Self::deserialize_other(),
}
}

fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_unit()
}

forward_to_deserialize_other! {
deserialize_bool()
deserialize_i8()
deserialize_i16()
deserialize_i32()
deserialize_i64()
deserialize_u8()
deserialize_u16()
deserialize_u32()
deserialize_u64()
deserialize_f32()
deserialize_f64()
deserialize_char()
deserialize_str()
deserialize_string()
deserialize_bytes()
deserialize_byte_buf()
deserialize_unit_struct(&'static str)
deserialize_seq()
deserialize_tuple(usize)
deserialize_tuple_struct(&'static str, usize)
deserialize_identifier()
deserialize_ignored_any()
}
}

#[cfg(any(feature = "std", feature = "alloc"))]
pub struct FlatMapAccess<'a, 'de: 'a, E> {
iter: slice::Iter<'a, Option<(Content<'de>, Content<'de>)>>,
Expand Down Expand Up @@ -2911,6 +3049,73 @@ where
}
}

#[cfg(any(feature = "std", feature = "alloc"))]
pub struct CaseInsensitiveFlatStructAccess<'a, 'de: 'a, E> {
iter: slice::IterMut<'a, Option<(Content<'de>, Content<'de>)>>,
pending_content: Option<Content<'de>>,
fields: &'static [&'static str],
_marker: PhantomData<E>,
}

#[cfg(any(feature = "std", feature = "alloc"))]
impl<'a, 'de, E> CaseInsensitiveFlatStructAccess<'a, 'de, E> {
fn new(
iter: slice::IterMut<'a, Option<(Content<'de>, Content<'de>)>>,
fields: &'static [&'static str],
) -> CaseInsensitiveFlatStructAccess<'a, 'de, E> {
CaseInsensitiveFlatStructAccess {
iter: iter,
pending_content: None,
fields: fields,
_marker: PhantomData,
}
}
}

#[cfg(any(feature = "std", feature = "alloc"))]
impl<'a, 'de, E> MapAccess<'de> for CaseInsensitiveFlatStructAccess<'a, 'de, E>
where
E: Error,
{
type Error = E;

fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
where
T: DeserializeSeed<'de>,
{
while let Some(item) = self.iter.next() {
// items in the vector are nulled out when used. So we can only use
// an item if it's still filled in and if the field is one we care
// about. In case we do not know which fields we want, we take them all.
let use_item = match *item {
None => false,
Some((ref c, _)) => c.as_str().map_or(false, |key| {
self.fields
.iter()
.any(|field| field.eq_ignore_ascii_case(key))
}),
};

if use_item {
let (key, content) = item.take().unwrap();
self.pending_content = Some(content);
return seed.deserialize(ContentDeserializer::new(key)).map(Some);
}
}
Ok(None)
}

fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value, Self::Error>
where
T: DeserializeSeed<'de>,
{
match self.pending_content.take() {
Some(value) => seed.deserialize(ContentDeserializer::new(value)),
None => Err(Error::custom("value is missing")),
}
}
}

#[cfg(any(feature = "std", feature = "alloc"))]
pub struct FlatInternallyTaggedAccess<'a, 'de: 'a, E> {
iter: slice::IterMut<'a, Option<(Content<'de>, Content<'de>)>>,
Expand Down
80 changes: 67 additions & 13 deletions serde_derive/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1910,6 +1910,7 @@ fn deserialize_generated_identifier(
is_variant,
fallthrough,
!is_variant && cattrs.has_flatten(),
cattrs.case_insensitive(),
));

let lifetime = if !is_variant && cattrs.has_flatten() {
Expand Down Expand Up @@ -2017,6 +2018,7 @@ fn deserialize_custom_identifier(
is_variant,
fallthrough,
false,
cattrs.case_insensitive(),
));

quote_block! {
Expand Down Expand Up @@ -2047,6 +2049,7 @@ fn deserialize_identifier(
is_variant: bool,
fallthrough: Option<TokenStream>,
collect_other_fields: bool,
case_insensitive: bool,
) -> Fragment {
let mut flat_fields = Vec::new();
for (_, ident, aliases) in fields {
Expand Down Expand Up @@ -2123,6 +2126,55 @@ fn deserialize_identifier(
}
};

let (visit_arms_str, visit_arms_borrowed_str, visit_arms_bytes, visit_arms_borrowed_bytes) =
if case_insensitive {
(
quote! {
#(
__value if __value.eq_ignore_ascii_case(#field_strs) => _serde::__private::Ok(#constructors),
)*
},
quote! {
#(
__value if __value.eq_ignore_ascii_case(#field_borrowed_strs) => _serde::__private::Ok(#constructors),
)*
},
quote! {
#(
__value if __value.eq_ignore_ascii_case(#field_bytes) => _serde::__private::Ok(#constructors),
)*
},
quote! {
#(
__value if __value.eq_ignore_ascii_case(#field_borrowed_bytes) => _serde::__private::Ok(#constructors),
)*
},
)
} else {
(
quote! {
#(
#field_strs => _serde::__private::Ok(#constructors),
)*
},
quote! {
#(
#field_borrowed_strs => _serde::__private::Ok(#constructors),
)*
},
quote! {
#(
#field_bytes => _serde::__private::Ok(#constructors),
)*
},
quote! {
#(
#field_borrowed_bytes => _serde::__private::Ok(#constructors),
)*
},
)
};

let variant_indices = 0_u64..;
let fallthrough_msg = format!("{} index 0 <= i < {}", index_expecting, fields.len());
let visit_other = if collect_other_fields {
Expand Down Expand Up @@ -2223,9 +2275,7 @@ fn deserialize_identifier(
__E: _serde::de::Error,
{
match __value {
#(
#field_borrowed_strs => _serde::__private::Ok(#constructors),
)*
#visit_arms_borrowed_str
_ => {
#value_as_borrowed_str_content
#fallthrough_arm
Expand All @@ -2238,9 +2288,7 @@ fn deserialize_identifier(
__E: _serde::de::Error,
{
match __value {
#(
#field_borrowed_bytes => _serde::__private::Ok(#constructors),
)*
#visit_arms_borrowed_bytes
_ => {
#bytes_to_str
#value_as_borrowed_bytes_content
Expand Down Expand Up @@ -2280,9 +2328,7 @@ fn deserialize_identifier(
__E: _serde::de::Error,
{
match __value {
#(
#field_strs => _serde::__private::Ok(#constructors),
)*
#visit_arms_str
_ => {
#value_as_str_content
#fallthrough_arm
Expand All @@ -2295,9 +2341,7 @@ fn deserialize_identifier(
__E: _serde::de::Error,
{
match __value {
#(
#field_bytes => _serde::__private::Ok(#constructors),
)*
#visit_arms_bytes
_ => {
#bytes_to_str
#value_as_bytes_content
Expand Down Expand Up @@ -2496,6 +2540,16 @@ fn deserialize_map(
}
});

let flat_map_deserializer = if cattrs.case_insensitive() {
quote! {
_serde::__private::de::CaseInsensitiveFlatMapDeserializer
}
} else {
quote! {
_serde::__private::de::FlatMapDeserializer
}
};

let extract_collected = fields_names
.iter()
.filter(|&&(field, _)| field.attrs.flatten() && !field.attrs.skip_deserializing())
Expand All @@ -2510,7 +2564,7 @@ fn deserialize_map(
};
quote! {
let #name: #field_ty = try!(#func(
_serde::__private::de::FlatMapDeserializer(
#flat_map_deserializer(
&mut __collect,
_serde::__private::PhantomData)));
}
Expand Down
Loading