diff --git a/json/src/de.rs b/json/src/de.rs index 67e3df5cfc..a4d6aa511d 100644 --- a/json/src/de.rs +++ b/json/src/de.rs @@ -9,14 +9,82 @@ use std::str; use std::marker::PhantomData; use serde::de; -use serde::iter::LineColIterator; use super::error::{Error, ErrorCode, Result}; +use read::{self, Read}; + +////////////////////////////////////////////////////////////////////////////// + /// A structure that deserializes JSON into Rust values. -pub struct Deserializer>> { - rdr: LineColIterator, - ch: Option, +pub struct Deserializer(DeserializerImpl>) + where Iter: Iterator>; + +impl Deserializer + where Iter: Iterator>, +{ + /// Creates the JSON parser from an `std::iter::Iterator`. + #[inline] + pub fn new(rdr: Iter) -> Self { + Deserializer(DeserializerImpl::new(read::IteratorRead::new(rdr))) + } + + /// The `Deserializer::end` method should be called after a value has been fully deserialized. + /// This allows the `Deserializer` to validate that the input stream is at the end or that it + /// only has trailing whitespace. + #[inline] + pub fn end(&mut self) -> Result<()> { + self.0.end() + } +} + +impl de::Deserializer for Deserializer + where Iter: Iterator>, +{ + type Error = Error; + + #[inline] + fn deserialize(&mut self, visitor: V) -> Result + where V: de::Visitor, + { + self.0.deserialize(visitor) + } + + /// Parses a `null` as a None, and any other values as a `Some(...)`. + #[inline] + fn deserialize_option(&mut self, visitor: V) -> Result + where V: de::Visitor, + { + self.0.deserialize_option(visitor) + } + + /// Parses a newtype struct as the underlying value. + #[inline] + fn deserialize_newtype_struct(&mut self, + name: &'static str, + visitor: V) -> Result + where V: de::Visitor, + { + self.0.deserialize_newtype_struct(name, visitor) + } + + /// Parses an enum as an object like `{"$KEY":$VALUE}`, where $VALUE is either a straight + /// value, a `[..]`, or a `{..}`. + #[inline] + fn deserialize_enum(&mut self, + name: &'static str, + variants: &'static [&'static str], + visitor: V) -> Result + where V: de::EnumVisitor, + { + self.0.deserialize_enum(name, variants, visitor) + } +} + +////////////////////////////////////////////////////////////////////////////// + +struct DeserializerImpl { + read: R, str_buf: Vec, } @@ -29,29 +97,20 @@ macro_rules! try_or_invalid { } } -impl Deserializer - where Iter: Iterator>, -{ - /// Creates the JSON parser from an `std::iter::Iterator`. - #[inline] - pub fn new(rdr: Iter) -> Deserializer { - Deserializer { - rdr: LineColIterator::new(rdr), - ch: None, +impl DeserializerImpl { + fn new(read: R) -> Self { + DeserializerImpl { + read: read, str_buf: Vec::with_capacity(128), } } - /// The `Deserializer::end` method should be called after a value has been fully deserialized. - /// This allows the `Deserializer` to validate that the input stream is at the end or that it - /// only has trailing whitespace. - #[inline] - pub fn end(&mut self) -> Result<()> { + fn end(&mut self) -> Result<()> { try!(self.parse_whitespace()); if try!(self.eof()) { Ok(()) } else { - Err(self.error(ErrorCode::TrailingCharacters)) + Err(self.peek_error(ErrorCode::TrailingCharacters)) } } @@ -60,19 +119,7 @@ impl Deserializer } fn peek(&mut self) -> Result> { - match self.ch { - Some(ch) => Ok(Some(ch)), - None => { - match self.rdr.next() { - Some(Err(err)) => Err(Error::Io(err)), - Some(Ok(ch)) => { - self.ch = Some(ch); - Ok(self.ch) - } - None => Ok(None), - } - } - } + self.read.peek().map_err(Error::Io) } fn peek_or_null(&mut self) -> Result { @@ -80,28 +127,27 @@ impl Deserializer } fn eat_char(&mut self) { - self.ch = None; + self.read.discard(); } fn next_char(&mut self) -> Result> { - match self.ch.take() { - Some(ch) => Ok(Some(ch)), - None => { - match self.rdr.next() { - Some(Err(err)) => Err(Error::Io(err)), - Some(Ok(ch)) => Ok(Some(ch)), - None => Ok(None), - } - } - } + self.read.next().map_err(Error::Io) } fn next_char_or_null(&mut self) -> Result { Ok(try!(self.next_char()).unwrap_or(b'\x00')) } + /// Error caused by a byte from next_char(). fn error(&mut self, reason: ErrorCode) -> Error { - Error::Syntax(reason, self.rdr.line(), self.rdr.col()) + let pos = self.read.position(); + Error::Syntax(reason, pos.line, pos.column) + } + + /// Error caused by a byte from peek(). + fn peek_error(&mut self, reason: ErrorCode) -> Error { + let pos = self.read.peek_position(); + Error::Syntax(reason, pos.line, pos.column) } fn parse_whitespace(&mut self) -> Result<()> { @@ -121,7 +167,7 @@ impl Deserializer try!(self.parse_whitespace()); if try!(self.eof()) { - return Err(self.error(ErrorCode::EOFWhileParsingValue)); + return Err(self.peek_error(ErrorCode::EOFWhileParsingValue)); } let value = match try!(self.peek_or_null()) { @@ -162,13 +208,19 @@ impl Deserializer visitor.visit_map(MapVisitor::new(self)) } _ => { - Err(self.error(ErrorCode::ExpectedSomeValue)) + Err(self.peek_error(ErrorCode::ExpectedSomeValue)) } }; match value { Ok(value) => Ok(value), - Err(Error::Syntax(code, _, _)) => Err(self.error(code)), + // The de::Error and From impls both create errors + // with unknown line and column. Fill in the position here by + // looking at the current index in the input. There is no way to + // tell whether this should call `error` or `peek_error` so pick the + // one that seems correct more often. Worst case, the position is + // off by one character. + Err(Error::Syntax(code, 0, 0)) => Err(self.error(code)), Err(err) => Err(err), } } @@ -191,7 +243,7 @@ impl Deserializer // There can be only one leading '0'. match try!(self.peek_or_null()) { b'0' ... b'9' => { - Err(self.error(ErrorCode::InvalidNumber)) + Err(self.peek_error(ErrorCode::InvalidNumber)) } _ => { self.parse_number(pos, 0, visitor) @@ -378,7 +430,7 @@ impl Deserializer let exp = if exp <= i32::MAX as u64 { 10_f64.powi(exp as i32) } else { - return Err(self.error(ErrorCode::InvalidNumber)); + return Err(self.peek_error(ErrorCode::InvalidNumber)); }; if pos_exp { @@ -517,15 +569,13 @@ impl Deserializer self.eat_char(); Ok(()) } - Some(_) => Err(self.error(ErrorCode::ExpectedColon)), - None => Err(self.error(ErrorCode::EOFWhileParsingObject)), + Some(_) => Err(self.peek_error(ErrorCode::ExpectedColon)), + None => Err(self.peek_error(ErrorCode::EOFWhileParsingObject)), } } } -impl de::Deserializer for Deserializer - where Iter: Iterator>, -{ +impl de::Deserializer for DeserializerImpl { type Error = Error; #[inline] @@ -599,19 +649,19 @@ impl de::Deserializer for Deserializer visitor.visit(KeyOnlyVariantVisitor::new(self)) } _ => { - Err(self.error(ErrorCode::ExpectedSomeValue)) + Err(self.peek_error(ErrorCode::ExpectedSomeValue)) } } } } -struct SeqVisitor<'a, Iter: 'a + Iterator>> { - de: &'a mut Deserializer, +struct SeqVisitor<'a, R: Read + 'a> { + de: &'a mut DeserializerImpl, first: bool, } -impl<'a, Iter: Iterator>> SeqVisitor<'a, Iter> { - fn new(de: &'a mut Deserializer) -> Self { +impl<'a, R: Read + 'a> SeqVisitor<'a, R> { + fn new(de: &'a mut DeserializerImpl) -> Self { SeqVisitor { de: de, first: true, @@ -619,9 +669,7 @@ impl<'a, Iter: Iterator>> SeqVisitor<'a, Iter> { } } -impl<'a, Iter> de::SeqVisitor for SeqVisitor<'a, Iter> - where Iter: Iterator>, -{ +impl<'a, R: Read + 'a> de::SeqVisitor for SeqVisitor<'a, R> { type Error = Error; fn visit(&mut self) -> Result> @@ -640,11 +688,11 @@ impl<'a, Iter> de::SeqVisitor for SeqVisitor<'a, Iter> if self.first { self.first = false; } else { - return Err(self.de.error(ErrorCode::ExpectedListCommaOrEnd)); + return Err(self.de.peek_error(ErrorCode::ExpectedListCommaOrEnd)); } } None => { - return Err(self.de.error(ErrorCode::EOFWhileParsingList)); + return Err(self.de.peek_error(ErrorCode::EOFWhileParsingList)); } } @@ -667,13 +715,13 @@ impl<'a, Iter> de::SeqVisitor for SeqVisitor<'a, Iter> } } -struct MapVisitor<'a, Iter: 'a + Iterator>> { - de: &'a mut Deserializer, +struct MapVisitor<'a, R: Read + 'a> { + de: &'a mut DeserializerImpl, first: bool, } -impl<'a, Iter: Iterator>> MapVisitor<'a, Iter> { - fn new(de: &'a mut Deserializer) -> Self { +impl<'a, R: Read + 'a> MapVisitor<'a, R> { + fn new(de: &'a mut DeserializerImpl) -> Self { MapVisitor { de: de, first: true, @@ -681,9 +729,7 @@ impl<'a, Iter: Iterator>> MapVisitor<'a, Iter> { } } -impl<'a, Iter> de::MapVisitor for MapVisitor<'a, Iter> - where Iter: Iterator> -{ +impl<'a, R: Read + 'a> de::MapVisitor for MapVisitor<'a, R> { type Error = Error; fn visit_key(&mut self) -> Result> @@ -703,11 +749,11 @@ impl<'a, Iter> de::MapVisitor for MapVisitor<'a, Iter> if self.first { self.first = false; } else { - return Err(self.de.error(ErrorCode::ExpectedObjectCommaOrEnd)); + return Err(self.de.peek_error(ErrorCode::ExpectedObjectCommaOrEnd)); } } None => { - return Err(self.de.error(ErrorCode::EOFWhileParsingObject)); + return Err(self.de.peek_error(ErrorCode::EOFWhileParsingObject)); } } @@ -716,10 +762,10 @@ impl<'a, Iter> de::MapVisitor for MapVisitor<'a, Iter> Ok(Some(try!(de::Deserialize::deserialize(self.de)))) } Some(_) => { - Err(self.de.error(ErrorCode::KeyMustBeAString)) + Err(self.de.peek_error(ErrorCode::KeyMustBeAString)) } None => { - Err(self.de.error(ErrorCode::EOFWhileParsingValue)) + Err(self.de.peek_error(ErrorCode::EOFWhileParsingValue)) } } } @@ -776,21 +822,19 @@ impl<'a, Iter> de::MapVisitor for MapVisitor<'a, Iter> } } -struct VariantVisitor<'a, Iter: 'a + Iterator>> { - de: &'a mut Deserializer, +struct VariantVisitor<'a, R: Read + 'a> { + de: &'a mut DeserializerImpl, } -impl<'a, Iter: Iterator>> VariantVisitor<'a, Iter> { - fn new(de: &'a mut Deserializer) -> Self { +impl<'a, R: Read + 'a> VariantVisitor<'a, R> { + fn new(de: &'a mut DeserializerImpl) -> Self { VariantVisitor { de: de, } } } -impl<'a, Iter> de::VariantVisitor for VariantVisitor<'a, Iter> - where Iter: Iterator>, -{ +impl<'a, R: Read + 'a> de::VariantVisitor for VariantVisitor<'a, R> { type Error = Error; fn visit_variant(&mut self) -> Result @@ -828,21 +872,19 @@ impl<'a, Iter> de::VariantVisitor for VariantVisitor<'a, Iter> } } -struct KeyOnlyVariantVisitor<'a, Iter: 'a + Iterator>> { - de: &'a mut Deserializer, +struct KeyOnlyVariantVisitor<'a, R: Read + 'a> { + de: &'a mut DeserializerImpl, } -impl<'a, Iter: Iterator>> KeyOnlyVariantVisitor<'a, Iter> { - fn new(de: &'a mut Deserializer) -> Self { +impl<'a, R: Read + 'a> KeyOnlyVariantVisitor<'a, R> { + fn new(de: &'a mut DeserializerImpl) -> Self { KeyOnlyVariantVisitor { de: de, } } } -impl<'a, Iter> de::VariantVisitor for KeyOnlyVariantVisitor<'a, Iter> - where Iter: Iterator>, -{ +impl<'a, R: Read + 'a> de::VariantVisitor for KeyOnlyVariantVisitor<'a, R> { type Error = Error; fn visit_variant(&mut self) -> Result @@ -885,7 +927,7 @@ pub struct StreamDeserializer where Iter: Iterator>, T: de::Deserialize { - deser: Deserializer, + deser: DeserializerImpl>, _marker: PhantomData, } @@ -897,7 +939,7 @@ impl StreamDeserializer /// `Iterator>`. pub fn new(iter: Iter) -> StreamDeserializer { StreamDeserializer { - deser: Deserializer::new(iter), + deser: DeserializerImpl::new(read::IteratorRead::new(iter)), _marker: PhantomData } } @@ -936,7 +978,7 @@ pub fn from_iter(iter: I) -> Result where I: Iterator>, T: de::Deserialize, { - let mut de = Deserializer::new(iter); + let mut de = DeserializerImpl::new(read::IteratorRead::new(iter)); let value = try!(de::Deserialize::deserialize(&mut de)); // Make sure the whole stream has been consumed. @@ -956,7 +998,12 @@ pub fn from_reader(rdr: R) -> Result pub fn from_slice(v: &[u8]) -> Result where T: de::Deserialize { - from_iter(v.iter().map(|byte| Ok(*byte))) + let mut de = DeserializerImpl::new(read::SliceRead::new(v)); + let value = try!(de::Deserialize::deserialize(&mut de)); + + // Make sure the whole stream has been consumed. + try!(de.end()); + Ok(value) } /// Decodes a json value from a `&str`. diff --git a/json/src/lib.rs b/json/src/lib.rs index a497b71f26..91c8000e1b 100644 --- a/json/src/lib.rs +++ b/json/src/lib.rs @@ -150,3 +150,5 @@ pub mod de; pub mod error; pub mod ser; pub mod value; + +mod read; diff --git a/json/src/read.rs b/json/src/read.rs new file mode 100644 index 0000000000..5445ba7e26 --- /dev/null +++ b/json/src/read.rs @@ -0,0 +1,186 @@ +use std::{cmp, io}; + +use serde::iter::LineColIterator; + +/// Trait used by the deserializer for iterating over input. This is manually +/// "specialized" for iterating over &[u8]. Once feature(specialization) is +/// stable we can use actual specialization. +pub trait Read { + fn next(&mut self) -> io::Result>; + fn peek(&mut self) -> io::Result>; + + /// Only valid after a call to peek(). Discards the peeked byte. + fn discard(&mut self); + + /// Position of the most recent call to next(). + /// + /// The most recent call was probably next() and not peek(), but this method + /// should try to return a sensible result if the most recent call was + /// actually peek() because we don't always know. + /// + /// Only called in case of an error, so performance is not important. + fn position(&self) -> Position; + + /// Position of the most recent call to peek(). + /// + /// The most recent call was probably peek() and not next(), but this method + /// should try to return a sensible result if the most recent call was + /// actually next() because we don't always know. + /// + /// Only called in case of an error, so performance is not important. + fn peek_position(&self) -> Position; +} + +pub struct Position { + pub line: usize, + pub column: usize, +} + +pub struct IteratorRead where Iter: Iterator> { + iter: LineColIterator, + /// Temporary storage of peeked byte. + ch: Option, +} + +/// Specialization for Iter=&[u8]. This is more efficient than other iterators +/// because peek() can be read-only and we can compute line/col position only if +/// an error happens. +pub struct SliceRead<'a> { + slice: &'a [u8], + /// Index of the *next* byte that will be returned by next() or peek(). + index: usize, +} + +////////////////////////////////////////////////////////////////////////////// + +impl IteratorRead + where Iter: Iterator>, +{ + pub fn new(iter: Iter) -> Self { + IteratorRead { + iter: LineColIterator::new(iter), + ch: None, + } + } +} + +impl Read for IteratorRead + where Iter: Iterator>, +{ + #[inline] + fn next(&mut self) -> io::Result> { + match self.ch.take() { + Some(ch) => Ok(Some(ch)), + None => { + match self.iter.next() { + Some(Err(err)) => Err(err), + Some(Ok(ch)) => Ok(Some(ch)), + None => Ok(None), + } + } + } + } + + #[inline] + fn peek(&mut self) -> io::Result> { + match self.ch { + Some(ch) => Ok(Some(ch)), + None => { + match self.iter.next() { + Some(Err(err)) => Err(err), + Some(Ok(ch)) => { + self.ch = Some(ch); + Ok(self.ch) + } + None => Ok(None), + } + } + } + } + + #[inline] + fn discard(&mut self) { + self.ch = None; + } + + fn position(&self) -> Position { + Position { + line: self.iter.line(), + column: self.iter.col(), + } + } + + fn peek_position(&self) -> Position { + // The LineColIterator updates its position during peek() so it has the + // right one here. + self.position() + } +} + +////////////////////////////////////////////////////////////////////////////// + +impl<'a> SliceRead<'a> { + pub fn new(slice: &'a [u8]) -> Self { + SliceRead { + slice: slice, + index: 0, + } + } + + fn position_of_index(&self, i: usize) -> Position { + let mut pos = Position { line: 1, column: 0 }; + for ch in &self.slice[..i] { + match *ch { + b'\n' => { + pos.line += 1; + pos.column = 0; + } + _ => { + pos.column += 1; + } + } + } + pos + } +} + +impl<'a> Read for SliceRead<'a> { + #[inline] + fn next(&mut self) -> io::Result> { + // `Ok(self.slice.get(self.index).map(|ch| { self.index += 1; *ch }))` + // is about 10% slower. + Ok(if self.index < self.slice.len() { + let ch = self.slice[self.index]; + self.index += 1; + Some(ch) + } else { + None + }) + } + + #[inline] + fn peek(&mut self) -> io::Result> { + // `Ok(self.slice.get(self.index).map(|ch| *ch))` is about 10% slower + // for some reason. + Ok(if self.index < self.slice.len() { + Some(self.slice[self.index]) + } else { + None + }) + } + + #[inline] + fn discard(&mut self) { + self.index += 1; + } + + fn position(&self) -> Position { + self.position_of_index(self.index) + } + + fn peek_position(&self) -> Position { + // Cap it at slice.len() just in case the most recent call was next() + // and it returned the last byte. + self.position_of_index(cmp::min(self.slice.len(), self.index + 1)) + } +} diff --git a/json_tests/tests/test_json.rs b/json_tests/tests/test_json.rs index d83d0759c9..90207de1ce 100644 --- a/json_tests/tests/test_json.rs +++ b/json_tests/tests/test_json.rs @@ -13,6 +13,7 @@ use serde_json::{ StreamDeserializer, Value, Map, + from_iter, from_str, from_value, to_value, @@ -653,6 +654,9 @@ fn test_parse_ok(errors: Vec<(&str, T)>) let v: T = from_str(s).unwrap(); assert_eq!(v, value.clone()); + let v: T = from_iter(s.bytes().map(Ok)).unwrap(); + assert_eq!(v, value.clone()); + // Make sure we can deserialize into a `Value`. let json_value: Value = from_str(s).unwrap(); assert_eq!(json_value, to_value(&value)); @@ -675,6 +679,9 @@ fn test_parse_unusual_ok(errors: Vec<(&str, T)>) for (s, value) in errors { let v: T = from_str(s).unwrap(); assert_eq!(v, value.clone()); + + let v: T = from_iter(s.bytes().map(Ok)).unwrap(); + assert_eq!(v, value.clone()); } } @@ -682,19 +689,27 @@ fn test_parse_unusual_ok(errors: Vec<(&str, T)>) fn test_parse_err(errors: Vec<(&'static str, Error)>) where T: Debug + PartialEq + de::Deserialize, { - for (s, err) in errors { - match (err, from_str::(s).unwrap_err()) { + for &(s, ref err) in &errors { + match (err, &from_str::(s).unwrap_err()) { ( - Error::Syntax(expected_code, expected_line, expected_col), - Error::Syntax(actual_code, actual_line, actual_col), - ) => { - assert_eq!( - (expected_code, expected_line, expected_col), - (actual_code, actual_line, actual_col) - ) + &Error::Syntax(ref expected_code, expected_line, expected_col), + &Error::Syntax(ref actual_code, actual_line, actual_col), + ) if expected_code == actual_code + && expected_line == actual_line + && expected_col == actual_col => { /* pass */ } + (expected_err, actual_err) => { + panic!("unexpected from_str error: {}, expected: {}", actual_err, expected_err) } + } + match (err, &from_iter::<_, T>(s.bytes().map(Ok)).unwrap_err()) { + ( + &Error::Syntax(ref expected_code, expected_line, expected_col), + &Error::Syntax(ref actual_code, actual_line, actual_col), + ) if expected_code == actual_code + && expected_line == actual_line + && expected_col == actual_col => { /* pass */ } (expected_err, actual_err) => { - panic!("unexpected errors {} != {}", expected_err, actual_err) + panic!("unexpected from_iter error: {}, expected: {}", actual_err, expected_err) } } } @@ -1011,7 +1026,9 @@ fn test_parse_enum_errors() { ("{\"Dog\":[0]}", Error::Syntax(ErrorCode::TrailingCharacters, 1, 9)), ("\"Frog\"", Error::Syntax(ErrorCode::EOFWhileParsingValue, 1, 6)), ("{\"Frog\":{}}", Error::Syntax(ErrorCode::InvalidType(de::Type::Map), 1, 9)), - ("{\"Cat\":[]}", Error::Syntax(ErrorCode::EOFWhileParsingValue, 1, 9)), + ("{\"Cat\":[]}", Error::Syntax(ErrorCode::InvalidLength(0), 1, 9)), + ("{\"Cat\":[0]}", Error::Syntax(ErrorCode::InvalidLength(1), 1, 10)), + ("{\"Cat\":[0, \"\", 2]}", Error::Syntax(ErrorCode::TrailingCharacters, 1, 14)), ( "{\"Cat\":{\"age\": 5, \"name\": \"Kate\", \"foo\":\"bar\"}", Error::Syntax(ErrorCode::UnknownField("foo".to_string()), 1, 39)