diff --git a/src/de.rs b/src/de.rs index 7aad50b96..9975b40a5 100644 --- a/src/de.rs +++ b/src/de.rs @@ -2203,6 +2203,41 @@ where deserialize_numeric_key!(deserialize_f32, deserialize_f32); deserialize_numeric_key!(deserialize_f64); + fn deserialize_bool(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.de.eat_char(); + + let peek = match tri!(self.de.next_char()) { + Some(b) => b, + None => { + return Err(self.de.peek_error(ErrorCode::EofWhileParsingValue)); + } + }; + + let value = match peek { + b't' => { + tri!(self.de.parse_ident(b"rue\"")); + visitor.visit_bool(true) + } + b'f' => { + tri!(self.de.parse_ident(b"alse\"")); + visitor.visit_bool(false) + } + _ => { + self.de.scratch.clear(); + let s = tri!(self.de.read.parse_str(&mut self.de.scratch)); + Err(de::Error::invalid_type(Unexpected::Str(&s), &visitor)) + } + }; + + match value { + Ok(value) => Ok(value), + Err(err) => Err(self.de.fix_position(err)), + } + } + #[inline] fn deserialize_option(self, visitor: V) -> Result where @@ -2258,7 +2293,7 @@ where } forward_to_deserialize_any! { - bool char str string unit unit_struct seq tuple tuple_struct map struct + char str string unit unit_struct seq tuple tuple_struct map struct identifier ignored_any } } diff --git a/src/ser.rs b/src/ser.rs index 6bb6fd761..3742e0bef 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -827,8 +827,21 @@ where type SerializeStruct = Impossible<(), Error>; type SerializeStructVariant = Impossible<(), Error>; - fn serialize_bool(self, _value: bool) -> Result<()> { - Err(key_must_be_a_string()) + fn serialize_bool(self, value: bool) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(Error::io)); + tri!(self + .ser + .formatter + .write_bool(&mut self.ser.writer, value) + .map_err(Error::io)); + self.ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(Error::io) } fn serialize_i8(self, value: i8) -> Result<()> { diff --git a/src/value/de.rs b/src/value/de.rs index 2090dd009..1e8b5acbb 100644 --- a/src/value/de.rs +++ b/src/value/de.rs @@ -1183,6 +1183,22 @@ impl<'de> serde::Deserializer<'de> for MapKeyDeserializer<'de> { deserialize_numeric_key!(deserialize_i128, do_deserialize_i128); deserialize_numeric_key!(deserialize_u128, do_deserialize_u128); + fn deserialize_bool(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.key == "true" { + visitor.visit_bool(true) + } else if self.key == "false" { + visitor.visit_bool(false) + } else { + Err(serde::de::Error::invalid_type( + Unexpected::Str(&self.key), + &visitor, + )) + } + } + #[inline] fn deserialize_option(self, visitor: V) -> Result where @@ -1219,8 +1235,8 @@ impl<'de> serde::Deserializer<'de> for MapKeyDeserializer<'de> { } forward_to_deserialize_any! { - bool char str string bytes byte_buf unit unit_struct seq tuple - tuple_struct map struct identifier ignored_any + char str string bytes byte_buf unit unit_struct seq tuple tuple_struct + map struct identifier ignored_any } } diff --git a/src/value/ser.rs b/src/value/ser.rs index 6ca53d4c5..835fa9080 100644 --- a/src/value/ser.rs +++ b/src/value/ser.rs @@ -483,8 +483,8 @@ impl serde::Serializer for MapKeySerializer { value.serialize(self) } - fn serialize_bool(self, _value: bool) -> Result { - Err(key_must_be_a_string()) + fn serialize_bool(self, value: bool) -> Result { + Ok(value.to_string()) } fn serialize_i8(self, value: i8) -> Result { diff --git a/tests/test.rs b/tests/test.rs index 8d9a5942a..e548b7dae 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1654,17 +1654,6 @@ fn test_deserialize_from_stream() { assert_eq!(request, response); } -#[test] -fn test_serialize_rejects_bool_keys() { - let map = treemap!( - true => 2, - false => 4, - ); - - let err = to_vec(&map).unwrap_err(); - assert_eq!(err.to_string(), "key must be a string"); -} - #[test] fn test_serialize_rejects_adt_keys() { let map = treemap!( @@ -2018,6 +2007,14 @@ fn test_deny_non_finite_f64_key() { assert!(serde_json::to_value(map).is_err()); } +#[test] +fn test_boolean_key() { + let map = treemap!(false => 0, true => 1); + let j = r#"{"false":0,"true":1}"#; + test_encode_ok(&[(&map, j)]); + test_parse_ok(vec![(j, map)]); +} + #[test] fn test_borrowed_key() { let map: BTreeMap<&str, ()> = from_str("{\"borrowed\":null}").unwrap();