diff --git a/src/deserializer.rs b/src/deserializer.rs
index c7c3217..a68eb86 100644
--- a/src/deserializer.rs
+++ b/src/deserializer.rs
@@ -1205,6 +1205,26 @@ mod tests {
         );
     }
 
+    #[test]
+    fn flatten_map() {
+        #[derive(Deserialize, Debug, PartialEq)]
+        struct Row {
+            x: f64,
+            y: f64,
+            #[serde(flatten)]
+            extra: HashMap<String, f64>,
+        }
+
+        let header = StringRecord::from(vec!["x", "y", "prop1", "prop2"]);
+        let record = StringRecord::from(vec!["1", "2", "3", "4"]);
+        let got: Row = record.deserialize(Some(&header)).unwrap();
+        let mut extra = HashMap::new();
+        extra.insert("prop1".to_string(), 3.0);
+        extra.insert("prop2".to_string(), 4.0);
+
+        assert_eq!(got, Row { x: 1.0, y: 2.0, extra });
+    }
+
     #[test]
     fn partially_invalid_utf8() {
         #[derive(Debug, Deserialize, PartialEq)]
diff --git a/src/serializer.rs b/src/serializer.rs
index a3f9ff8..81760a6 100644
--- a/src/serializer.rs
+++ b/src/serializer.rs
@@ -200,12 +200,7 @@ impl<'a, 'w, W: io::Write> Serializer for &'a mut SeRecord<'w, W> {
         self,
         _len: Option<usize>,
     ) -> Result<Self::SerializeMap, Self::Error> {
-        // The right behavior for serializing maps isn't clear.
-        Err(Error::custom(
-            "serializing maps is not supported, \
-             if you have a use case, please file an issue at \
-             https://github.com/BurntSushi/rust-csv",
-        ))
+        Ok(self)
     }
 
     fn serialize_struct(
@@ -297,20 +292,21 @@ impl<'a, 'w, W: io::Write> SerializeMap for &'a mut SeRecord<'w, W> {
 
     fn serialize_key<T: ?Sized + Serialize>(
         &mut self,
-        _key: &T,
+        key: &T,
     ) -> Result<(), Self::Error> {
-        unreachable!()
+        self.wtr.check_map_key(key)
     }
 
     fn serialize_value<T: ?Sized + Serialize>(
         &mut self,
-        _value: &T,
+        value: &T,
     ) -> Result<(), Self::Error> {
-        unreachable!()
+        value.serialize(&mut **self)
     }
 
     fn end(self) -> Result<(), Self::Error> {
-        unreachable!()
+        self.wtr.on_map_end();
+        Ok(())
     }
 }
 
@@ -646,12 +642,7 @@ impl<'a, 'w, W: io::Write> Serializer for &'a mut SeHeader<'w, W> {
         self,
         _len: Option<usize>,
     ) -> Result<Self::SerializeMap, Self::Error> {
-        // The right behavior for serializing maps isn't clear.
-        Err(Error::custom(
-            "serializing maps is not supported, \
-             if you have a use case, please file an issue at \
-             https://github.com/BurntSushi/rust-csv",
-        ))
+        self.handle_container("map")
     }
 
     fn serialize_struct(
@@ -743,20 +734,45 @@ impl<'a, 'w, W: io::Write> SerializeMap for &'a mut SeHeader<'w, W> {
 
     fn serialize_key<T: ?Sized + Serialize>(
         &mut self,
-        _key: &T,
+        key: &T,
     ) -> Result<(), Self::Error> {
-        unreachable!()
+        // Grab old state and update state to `EncounteredStructField`.
+        let old_state =
+            mem::replace(&mut self.state, HeaderState::EncounteredStructField);
+        if let HeaderState::ErrorIfWrite(err) = old_state {
+            return Err(err);
+        }
+
+        self.wtr.check_map_key(key)?;
+        self.state = HeaderState::InStructField;
+        // This does not actually serialize anything; it only checks that the
+        // key is a scalar value.
+        key.serialize(&mut **self)?;
+        if let HeaderState::ErrorIfWrite(err) =
+            mem::replace(&mut self.state, HeaderState::InStructField)
+        {
+            return Err(err);
+        }
+        let mut key_serializer = SeRecord { wtr: self.wtr };
+        key.serialize(&mut key_serializer)?;
+        Ok(())
     }
 
     fn serialize_value<T: ?Sized + Serialize>(
         &mut self,
-        _value: &T,
+        value: &T,
     ) -> Result<(), Self::Error> {
-        unreachable!()
+        if !matches!(self.state, HeaderState::InStructField) {
+            return Err(Error::new(ErrorKind::Serialize(
+                "Attempted to serialize value without key".to_string(),
+            )));
+        }
+        value.serialize(&mut **self)?;
+        self.state = HeaderState::EncounteredStructField;
+        Ok(())
     }
 
     fn end(self) -> Result<(), Self::Error> {
-        unreachable!()
+        self.wtr.on_map_end();
+        Ok(())
     }
 }
 
@@ -809,6 +825,8 @@ impl<'a, 'w, W: io::Write> SerializeStructVariant for &'a mut SeHeader<'w, W> {
 
 #[cfg(test)]
 mod tests {
+    use std::collections::BTreeMap;
+
     use {bstr::ByteSlice, serde::Serialize};
 
     use crate::{
@@ -847,6 +865,18 @@ mod tests {
         s.serialize(&mut SeHeader::new(&mut wtr)).unwrap_err()
     }
 
+    #[derive(Debug)]
+    struct CustomOrderMap(Vec<(&'static str, f64)>);
+
+    impl Serialize for CustomOrderMap {
+        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+        where
+            S: serde::Serializer,
+        {
+            serializer.collect_map(self.0.iter().copied())
+        }
+    }
+
     #[test]
     fn bool() {
         let got = serialize(true);
@@ -1117,6 +1147,58 @@ mod tests {
         }
     }
 
+    #[test]
+    fn ordered_map() {
+        let mut map = BTreeMap::new();
+        map.insert("a", 2.0);
+        map.insert("b", 1.0);
+
+        let got = serialize(&map);
+        assert_eq!(got, "2.0,1.0\n");
+        let (wrote, got) = serialize_header(map);
+        assert!(wrote);
+        assert_eq!(got, "a,b");
+    }
+
+    #[test]
+    fn ordered_map_with_collection_as_key() {
+        #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)]
+        struct MyKey {
+            name: &'static str,
+            other_attribute: u8,
+        }
+
+        let mut map = BTreeMap::new();
+        map.insert(MyKey { name: "a", other_attribute: 1 }, 2.0);
+
+        let error = serialize_header_err(map);
+        assert!(
+            matches!(error.kind(), ErrorKind::Serialize(_)),
+            "Expected ErrorKind::Serialize but got '{error}'"
+        );
+    }
+
+    #[test]
+    fn unordered_map() {
+        let mut writer = Writer::from_writer(vec![]);
+        writer
+            .serialize(CustomOrderMap(vec![("a", 2.0), ("b", 1.0)]))
+            .unwrap();
+        writer
+            .serialize(CustomOrderMap(vec![("a", 3.0), ("b", 4.0)]))
+            .unwrap();
+        writer.flush().unwrap();
+        let csv = String::from_utf8(writer.get_ref().clone()).unwrap();
+        assert_eq!(csv, "a,b\n2.0,1.0\n3.0,4.0\n");
+        let error = writer
+            .serialize(CustomOrderMap(vec![("b", 2.0), ("a", 1.0)])) // Wrong key order
+            .unwrap_err();
+        assert!(
+            matches!(error.kind(), ErrorKind::Serialize(_)),
+            "Got unexpected error: {error}"
+        );
+    }
+
     #[test]
     fn struct_no_headers() {
         #[derive(Serialize)]
@@ -1325,4 +1407,125 @@ mod tests {
         assert!(wrote);
         assert_eq!(got, "label,num,label2,value,empty,label,num");
     }
+
+    #[test]
+    fn flatten() {
+        #[derive(Clone, Serialize, Debug, PartialEq)]
+        struct Input {
+            x: f64,
+            y: f64,
+        }
+
+        #[derive(Clone, Serialize, Debug, PartialEq)]
+        struct Properties {
+            prop1: f64,
+            prop2: f64,
+        }
+
+        #[derive(Clone, Serialize, Debug, PartialEq)]
+        struct Row {
+            #[serde(flatten)]
+            input: Input,
+            #[serde(flatten)]
+            properties: Properties,
+        }
+        let row = Row {
+            input: Input { x: 1.0, y: 2.0 },
+            properties: Properties { prop1: 3.0, prop2: 4.0 },
+        };
+
+        let got = serialize(row.clone());
+        assert_eq!(got, "1.0,2.0,3.0,4.0\n");
+
+        let (wrote, got) = serialize_header(row.clone());
+        assert!(wrote);
+        assert_eq!(got, "x,y,prop1,prop2");
+    }
+
+    #[test]
+    fn flatten_map() {
+        #[derive(Clone, Serialize, Debug, PartialEq)]
+        struct Row {
+            x: f64,
+            y: f64,
+            #[serde(flatten)]
+            extra: BTreeMap<&'static str, f64>,
+        }
+        let mut extra = BTreeMap::new();
+        extra.insert("extra1", 3.0);
+        extra.insert("extra2", 4.0);
+        let row = Row { x: 1.0, y: 2.0, extra };
+
+        let got = serialize(row.clone());
+        assert_eq!(got, "1.0,2.0,3.0,4.0\n");
+
+        let (wrote, got) = serialize_header(row.clone());
+        assert!(wrote);
+        assert_eq!(got, "x,y,extra1,extra2");
+    }
+
+    #[test]
+    fn flatten_map_with_different_key_order() {
+        #[derive(Serialize, Debug)]
+        struct Row {
+            x: f64,
+            y: f64,
+            #[serde(flatten)]
+            extra: CustomOrderMap,
+        }
+        let mut writer = Writer::from_writer(vec![]);
+        writer
+            .serialize(Row {
+                x: 1.0,
+                y: 2.0,
+                extra: CustomOrderMap(vec![("extra1", 3.0), ("extra2", 4.0)]),
+            })
+            .unwrap();
+        let error = writer
+            .serialize(Row {
+                x: 1.0,
+                y: 2.0,
+                extra: CustomOrderMap(vec![("extra2", 4.0), ("extra1", 3.0)]),
+            })
+            .unwrap_err();
+        assert!(
+            matches!(error.kind(), ErrorKind::Serialize(_)),
+            "Expected ErrorKind::Serialize but got '{error}'"
+        );
+    }
+
+    #[test]
+    fn flatten_map_different_num_entries() {
+        #[derive(Clone, Serialize, Debug, PartialEq)]
+        struct Row {
+            x: f64,
+            y: f64,
+            #[serde(flatten)]
+            extra: BTreeMap<&'static str, f64>,
+        }
+        let mut writer = Writer::from_writer(vec![]);
+
+        let mut extra = BTreeMap::new();
+        extra.insert("extra1", 3.0);
+        extra.insert("extra2", 4.0);
+        let row = Row { x: 1.0, y: 2.0, extra };
+        writer.serialize(row).unwrap();
+
+        let mut extra = BTreeMap::new();
+        extra.insert("extra1", 3.0);
+        extra.insert("extra2", 4.0);
+        extra.insert("extra3", 5.0);
+        let row = Row { x: 1.0, y: 2.0, extra };
+        let error = writer.serialize(row).unwrap_err();
+        match *error.kind() {
+            ErrorKind::UnequalLengths {
+                pos: None,
+                expected_len: 4,
+                len: 5,
+            } => {}
+            ref x => {
+                panic!("expected ErrorKind::UnequalLengths but got '{x:?}'")
+            }
+        }
+    }
 }
diff --git a/src/writer.rs b/src/writer.rs
index 195e663..0bbd3e4 100644
--- a/src/writer.rs
+++ b/src/writer.rs
@@ -540,6 +540,15 @@ pub struct Writer<W: io::Write> {
     state: WriterState,
 }
 
+/// State for tracking headers while writing.
+#[derive(Debug)]
+struct HeaderTrackingState {
+    /// The serialized headers in their expected order.
+    expected_headers: Vec<Vec<u8>>,
+    /// The index into `expected_headers` of the next expected header.
+    next_expected_index: usize,
+}
+
 #[derive(Debug)]
 struct WriterState {
     /// Whether the Serde serializer should attempt to write a header row.
@@ -557,6 +566,9 @@ struct WriterState {
     /// immediately after flushing the buffer. This avoids flushing the buffer
     /// twice if the inner writer panics.
     panicked: bool,
+    /// Header tracking state for map-like data, to ensure that column order
+    /// is preserved across all rows.
+    header_tracking: Option<HeaderTrackingState>,
 }
 
 /// HeaderState encodes a small state machine for handling header writes.
@@ -638,6 +650,10 @@ impl<W: io::Write> Writer<W> {
                 first_field_count: None,
                 fields_written: 0,
                 panicked: false,
+                header_tracking: Some(HeaderTrackingState {
+                    expected_headers: Vec::new(),
+                    next_expected_index: 0,
+                }),
             },
         }
     }
@@ -1180,6 +1196,47 @@ impl<W: io::Write> Writer<W> {
         }
         Ok(())
     }
+
+    /// Track the `key` of a map entry. If this is not the first row, also
+    /// verify that the `key` matches the expected next key.
+    pub(crate) fn check_map_key<T: ?Sized + Serialize>(
+        &mut self,
+        key: &T,
+    ) -> Result<()> {
+        let Some(tracking) = &mut self.state.header_tracking else {
+            return Ok(());
+        };
+        let mut encoded_key_serializer = Writer::from_writer(Vec::new());
+        serialize(&mut encoded_key_serializer, key)?;
+        let encoded_key =
+            encoded_key_serializer.into_inner().map_err(|error| {
+                Error::new(ErrorKind::Serialize(format!(
+                    "Failed to serialize key to bytes: {error:?}"
+                )))
+            })?;
+        if let Some(expected_key) =
+            tracking.expected_headers.get(tracking.next_expected_index)
+        {
+            if expected_key != &encoded_key {
+                return Err(Error::new(ErrorKind::Serialize(format!(
+                    "Out of order key `{}`",
+                    String::from_utf8_lossy(&encoded_key)
+                ))));
+            }
+        } else {
+            // Even if this is not the first row, accept more keys; if the
+            // writer is flexible, adding more fields is allowed.
+            tracking.expected_headers.push(encoded_key);
+        }
+        tracking.next_expected_index += 1;
+        Ok(())
+    }
+
+    /// Reset the map key tracking at the end of a row.
+    pub(crate) fn on_map_end(&mut self) {
+        if let Some(tracking) = &mut self.state.header_tracking {
+            tracking.next_expected_index = 0;
+        }
+    }
 }
 
 impl Buffer {
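
Example usage (a minimal sketch, not part of the patch): this mirrors the `flatten_map` and `unordered_map` tests above and shows what the new map support enables through the public `csv::Writer` API. The `Row` struct and the `prop1`/`prop2` keys are illustrative only; every subsequent row must repeat the same keys in the same order, otherwise serialization fails with `ErrorKind::Serialize`.

use std::collections::BTreeMap;

use serde::Serialize;

#[derive(Serialize)]
struct Row {
    x: f64,
    y: f64,
    // Flattened map keys become additional CSV columns.
    #[serde(flatten)]
    extra: BTreeMap<&'static str, f64>,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut wtr = csv::Writer::from_writer(vec![]);

    let mut extra = BTreeMap::new();
    extra.insert("prop1", 3.0);
    extra.insert("prop2", 4.0);
    // The first serialized row also writes the header row "x,y,prop1,prop2".
    wtr.serialize(Row { x: 1.0, y: 2.0, extra: extra.clone() })?;

    // Later rows must use the same keys in the same order.
    extra.insert("prop1", 5.0);
    extra.insert("prop2", 6.0);
    wtr.serialize(Row { x: 7.0, y: 8.0, extra })?;

    let data = String::from_utf8(wtr.into_inner()?)?;
    assert_eq!(data, "x,y,prop1,prop2\n1.0,2.0,3.0,4.0\n7.0,8.0,5.0,6.0\n");
    Ok(())
}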