From 0dc62ca2ff5d2c1cc05ae4c4dc7c781556e5d7fd Mon Sep 17 00:00:00 2001 From: Michael Lamparski Date: Thu, 6 Jun 2019 11:37:26 -0400 Subject: [PATCH 01/20] add type_str --- src/lib.rs | 2 + src/type_str.rs | 639 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 641 insertions(+) create mode 100644 src/type_str.rs diff --git a/src/lib.rs b/src/lib.rs index 9cc711f..7d71001 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -132,8 +132,10 @@ mod header; mod serializable; mod npy_data; mod out_file; +mod type_str; pub use serializable::Serializable; pub use header::{DType, Field}; pub use npy_data::NpyData; pub use out_file::{to_file, OutFile}; +pub use type_str::{TypeStr, ParseTypeStrError}; diff --git a/src/type_str.rs b/src/type_str.rs new file mode 100644 index 0000000..3168411 --- /dev/null +++ b/src/type_str.rs @@ -0,0 +1,639 @@ +use std::fmt; + +/// Represents an Array Interface type-string. +/// +/// This is more or less the `DType` of a scalar type. +/// Exposes a `FromStr` impl for construction, and a `Display` impl for writing. +/// +/// ``` +/// # fn main() -> Result<(), Box> { +/// use npy::TypeStr; +/// +/// let ts = "|i1".parse::()?; +/// +/// assert_eq!(format!("{}", ts), "|i1"); +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct TypeStr { + pub(crate) endianness: Endianness, + pub(crate) type_kind: TypeKind, + pub(crate) size: u64, + pub(crate) time_units: Option, +} + +/// Represents the first character in a type-string. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub(crate) enum Endianness { + /// Code `<`. + Little, + /// Code `>`. + Big, + /// Code `|`. Used when endianness is irrelevant. + /// + /// Only valid when the size is `1`, or when `kind` is `TypeKind::Other` + /// or `TypeKind::ByteStr`. + Irrelevant, +} + +impl Endianness { + fn from_char(s: char) -> Option { + match s { + '<' => Some(Endianness::Little), + '>' => Some(Endianness::Big), + '|' => Some(Endianness::Irrelevant), + _ => None, + } + } + + fn to_str(self) -> &'static str { + match self { + Endianness::Little => "<", + Endianness::Big => ">", + Endianness::Irrelevant => "|", + } + } +} + +impl Endianness { + pub(crate) fn of_machine() -> Self { + match i32::from_be(0x00_00_00_01) { + 0x00_00_00_01 => Endianness::Big, + 0x01_00_00_00 => Endianness::Little, + _ => unreachable!(), + } + } + + /// Returns `true` if byteorder swapping is necessary between two types. + pub(crate) fn requires_swap(self, other: Endianness) -> bool { + match (self, other) { + (Endianness::Little, Endianness::Big) | + (Endianness::Big, Endianness::Little) => true, + + _ => false, + } + } +} + +/// Represents the second character in a type-string. +/// +/// Indicates the type of data stored. Affects the interpretation of `size` and `endianness`. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub(crate) enum TypeKind { + /// Code `b`. + /// + /// `size` must be 1, and legal values are `0x00` (`false`) or `0x01` (`true`). + Bool, + /// Code `i`. + /// + /// Notice that numpy does not support 128-bit integers. + Int, + /// Code `u`. + /// + /// Notice that numpy does not support 128-bit integers. + Uint, + /// Code `f`. + /// + /// Notice that numpy **does** support 128-bit floats. + Float, + /// Code `c`. Represents a complex number. + /// + /// The real part followed by the imaginary part, with `size` bytes total between the two of + /// them. Each part has the specified endianness, but the real part always comes first. + Complex, + /// Code `m`. Represents a `numpy.timedelta64`. + /// + /// Can use `i64` for serialization. `size` must be 8. + /// Check [`PlainDtype::time_units`] for the units. + TimeDelta, + /// Code `M`. Represents a `numpy.datetime64`. + /// + /// Can use `u64` for serialization. `size` must be 8. + /// Check [`PlainDtype::time_units`] for the units. + DateTime, + /// Code `S`. Represents a Python 3 `bytes` (`str` in Python 2). + /// + /// Can use `Vec` for serialization. + /// + /// A `bytes` of length `size`. Strings shorter than this length are zero-padded on the right. + /// This implies that they cannot contain trailing `NUL`s. (They can, however, contain interior + /// `NUL`s). To preserve trailing `NUL`s, use `RawData` (`V`) instead. + ByteStr, + /// Code `U`. Represents a Python 3 `str` (`unicode` in Python 2). + /// + /// A `str` that contains `size` code points (**not bytes!**). Each code unit is encoded as a + /// 32-bit integer of the given endianness. Strings with fewer than `size` code units are + /// zero-padded on the right. (thus they cannot contain trailing copies of U+0000 'NULL'; + /// they can, however, contain interior copies) + /// + /// Like Rust's `char`, the code points must have a value in `[0, 0x110000)`. However, unlike + /// `char`, surrogate code points are allowed. + UnicodeStr, + /// Code `V`. Represents a binary blob of `size` bytes. + /// + /// Can use `Vec` for serialization. + RawData, +} + +impl TypeKind { + fn from_char(s: char) -> Option { + match s { + 'b' => Some(TypeKind::Bool), + 'i' => Some(TypeKind::Int), + 'u' => Some(TypeKind::Uint), + 'f' => Some(TypeKind::Float), + 'c' => Some(TypeKind::Complex), + 'm' => Some(TypeKind::TimeDelta), + 'M' => Some(TypeKind::DateTime), + 'S' => Some(TypeKind::ByteStr), + 'U' => Some(TypeKind::UnicodeStr), + 'V' => Some(TypeKind::RawData), + _ => None, + } + } + + fn to_str(self) -> &'static str { + match self { + TypeKind::Bool => "b", + TypeKind::Int => "i", + TypeKind::Uint => "u", + TypeKind::Float => "f", + TypeKind::Complex => "c", + TypeKind::TimeDelta => "m", + TypeKind::DateTime => "M", + TypeKind::ByteStr => "S", + TypeKind::UnicodeStr => "U", + TypeKind::RawData => "V", + } + } +} + +impl TypeKind { + // `None` means all sizes are valid. + fn valid_sizes(self) -> Option<&'static [u64]> { + match self { + TypeKind::Bool => Some(&[1]), + + // numpy doesn't actually support 128-bit ints + TypeKind::Int | + TypeKind::Uint => Some(&[1, 2, 4, 8]), + + // yes, 128-bit floats are supported by numpy + TypeKind::Float => Some(&[2, 4, 8, 16]), + + // 4-byte complex numbers are mysteriously missing from numpy + TypeKind::Complex => Some(&[8, 16, 32]), + + TypeKind::TimeDelta | + TypeKind::DateTime => Some(&[8]), + + // (Note: numpy does support types `|S0` and `|U0`, though for some reason `numpy.save` + // changes them to `|S1` and `|U1`.) + TypeKind::ByteStr | + TypeKind::UnicodeStr | + TypeKind::RawData => None, + } + } + + /// Returns `true` if `|` endianness is illegal. + fn requires_endianness(self, size: u64) -> bool { + match self { + TypeKind::Bool | + TypeKind::Int | + TypeKind::Uint | + TypeKind::Float | + TypeKind::TimeDelta | + TypeKind::DateTime | + TypeKind::Complex => size != 1, + + TypeKind::UnicodeStr => true, + + TypeKind::ByteStr | + TypeKind::RawData => false, + } + } + + /// Returns `true` if `|` endianness is illegal. + fn has_units(self) -> bool { + match self { + TypeKind::TimeDelta | + TypeKind::DateTime => true, + + _ => false, + } + } +} + +impl TypeStr { + pub(crate) fn with_auto_endianness(type_kind: TypeKind, size: u64, time_units: Option) -> Self { + let endianness = match type_kind.requires_endianness(size) { + true => Endianness::of_machine(), + false => Endianness::Irrelevant, + }; + TypeStr { endianness, type_kind, size, time_units }.validate().unwrap() + } + + /// The number of bytes for a single scalar value. + pub(crate) fn num_bytes(&self) -> usize { + match self.type_kind { + TypeKind::Bool | + TypeKind::Int | + TypeKind::Uint | + TypeKind::Float | + TypeKind::Complex | + TypeKind::TimeDelta | + TypeKind::DateTime | + TypeKind::ByteStr | + TypeKind::RawData => self.size as usize, + + TypeKind::UnicodeStr => self.size as usize * 4, + } + } +} + +/// Represents the units of the `m` and `M` datatypes. +/// +/// These appear inside square brackets at the end of the `descr` string for these datatypes. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub(crate) enum TimeUnits { + /// Code `Y`. + Year, + /// Code `M`. + Month, + /// Code `W`. + Week, + /// Code `D`. + Day, + /// Code `h`. + Hour, + /// Code `m`. + Minute, + /// Code `s`. + Second, + /// Code `ms`. + Millisecond, + /// Code `us`. + Microsecond, + /// Code `ns`. + Nanosecond, + /// Code `ps`. + Picosecond, + /// Code `fs`. + Femtosecond, + /// Code `as`. + Attosecond, +} + +impl TimeUnits { + fn from_str(s: &str) -> Option { + match s { + "Y" => Some(TimeUnits::Year), + "M" => Some(TimeUnits::Month), + "W" => Some(TimeUnits::Week), + "D" => Some(TimeUnits::Day), + "h" => Some(TimeUnits::Hour), + "m" => Some(TimeUnits::Minute), + "s" => Some(TimeUnits::Second), + "ms" => Some(TimeUnits::Millisecond), + "us" => Some(TimeUnits::Microsecond), + "ns" => Some(TimeUnits::Nanosecond), + "ps" => Some(TimeUnits::Picosecond), + "fs" => Some(TimeUnits::Femtosecond), + "as" => Some(TimeUnits::Attosecond), + _ => None, + } + } + + fn to_str(self) -> &'static str { + match self { + TimeUnits::Year => "Y", + TimeUnits::Month => "M", + TimeUnits::Week => "W", + TimeUnits::Day => "D", + TimeUnits::Hour => "h", + TimeUnits::Minute => "m", + TimeUnits::Second => "s", + TimeUnits::Millisecond => "ms", + TimeUnits::Microsecond => "us", + TimeUnits::Nanosecond => "ns", + TimeUnits::Picosecond => "ps", + TimeUnits::Femtosecond => "fs", + TimeUnits::Attosecond => "as", + } + } +} + +impl fmt::Display for Endianness { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(self.to_str(), f) + } +} + +impl fmt::Display for TypeKind { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(self.to_str(), f) + } +} + +impl fmt::Display for TimeUnits { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(self.to_str(), f) + } +} + +impl fmt::Display for TypeStr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}{}{}", self.endianness, self.type_kind, self.size)?; + if let Some(time_units) = self.time_units { + write!(f, "[{}]", time_units)?; + } + Ok(()) + } +} + +pub use self::parse::ParseTypeStrError; +mod parse { + use super::*; + + /// Error type returned by `::parse`. + #[derive(Debug, Clone)] + pub struct ParseTypeStrError(ErrorKind); + + #[derive(Debug, Clone)] + enum ErrorKind { + SyntaxError, + ParseIntError(std::num::ParseIntError), + InvalidEndianness(TypeStr), + InvalidSize(TypeStr), + MissingOrUnexpectedUnits(TypeStr), + } + + impl fmt::Display for ParseTypeStrError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use self::ErrorKind::*; + + match &self.0 { + SyntaxError => write!(f, "Invalid type-string"), + InvalidEndianness(ty) => write!(f, "Type string '{}' has invalid endianness", ty), + InvalidSize(ty) => { + write!(f, "Type string '{}' has invalid size.", ty)?; + write!(f, " Valid sizes are: {:?}", ty.type_kind.valid_sizes().unwrap())?; + Ok(()) + }, + MissingOrUnexpectedUnits(ty) => { + if ty.type_kind.has_units() { + write!(f, "Type string '{}' is missing time units.", ty) + } else { + write!(f, "Unexpected time units in type string '{}'.", ty) + } + }, + ParseIntError(e) => write!(f, "{}", e), + } + } + } + + macro_rules! bail { + ($variant:expr) => { + return Err(ParseTypeStrError($variant)) + }; + } + + impl std::error::Error for ParseTypeStrError {} + + impl std::str::FromStr for TypeStr { + type Err = ParseTypeStrError; + + fn from_str(input: &str) -> Result { + use self::ErrorKind::*; + + if input.len() < 3 { + bail!(SyntaxError); + } + + let mut chars = input.chars(); + + let c = chars.next().unwrap(); + let endianness = match Endianness::from_char(c) { + None => bail!(SyntaxError), + Some(v) => v, + }; + + let c = chars.next().unwrap(); + let type_kind = match TypeKind::from_char(c) { + None => bail!(SyntaxError), + Some(v) => v, + }; + + let remainder = chars.as_str(); + let size_end = { + remainder.bytes().position(|b| !b.is_ascii_digit()) + .unwrap_or(remainder.len()) + }; + if size_end == 0 { + bail!(SyntaxError); + } + let (size, remainder) = remainder.split_at(size_end); + let size = match size.parse() { + Err(e) => bail!(ParseIntError(e)), // probably overflow + Ok(v) => v, + }; + + let time_units = if remainder.is_empty() { + None + } else { + let mut chars = remainder.chars(); + match (chars.next(), chars.next_back()) { + (Some('['), Some(']')) => {}, + _ => bail!(SyntaxError), + } + + match TimeUnits::from_str(chars.as_str()) { + None => bail!(SyntaxError), + Some(v) => Some(v), + } + }; + + TypeStr { endianness, type_kind, size, time_units } + .validate() + } + } + + impl TypeStr { + pub(crate) fn validate(self) -> Result { + use self::ErrorKind::*; + + let TypeStr { endianness, type_kind, size, time_units } = self; + + if type_kind.requires_endianness(size) && endianness == Endianness::Irrelevant { + bail!(InvalidEndianness(self)); + } + + if let Some(valid_sizes) = type_kind.valid_sizes() { + if !valid_sizes.contains(&size) { + bail!(InvalidSize(self)); + } + } + + if type_kind.has_units() != time_units.is_some() { + bail!(MissingOrUnexpectedUnits(self)); + } + + Ok(self) + } + } + + #[cfg(test)] + #[deny(unused)] + mod tests { + use super::*; + + macro_rules! assert_matches { + ($expr:expr, $pat:pat) => { + match $expr { + $pat => {}, + actual => panic!("Expected: {}\nGot: {:?}", stringify!($pat), actual), + } + }; + } + + macro_rules! check_ok { + ($s:expr) => { + assert_matches!($s.parse::(), Ok(_)); + }; + } + macro_rules! check_err { + ($s:expr, $p:pat) => { + assert_matches!($s.parse::(), Err(ParseTypeStrError($p))); + }; + } + + #[test] + fn errors() { + use self::ErrorKind::*; + + check_err!("", SyntaxError); + check_err!(">", SyntaxError); + check_err!(">m", SyntaxError); + check_err!(">m8[", SyntaxError); + check_err!(">m8[us", SyntaxError); + check_ok!(">m8[us]"); + check_ok!(">m8[D]"); + check_err!(">m8[us]garbage", SyntaxError); + check_err!(">m8[us]]", SyntaxError); + + + check_err!("", SyntaxError); + check_err!(">", SyntaxError); + check_err!(">i", SyntaxError); + check_ok!(">i8"); + check_ok!(">c16"); + check_err!(">i8garbage", SyntaxError); + + // length-zero integer + check_err!(">m[us]", SyntaxError); + check_err!(">i", SyntaxError); + + // make sure integer overflow doesn't panic + check_err!(">m999999999999999999999999999999[us]", _); + check_err!(">i999999999999999999999999999999", _); + + // Unrecognized specifiers + check_ok!("m8[us]"); + check_err!(">m8[bus]", _); + check_err!(">m8[usb]", _); + check_err!(">m8[xq]", _); + + // Required endianness + check_ok!("|i1"); + check_ok!("|S7"); + check_ok!("|V7"); + check_err!("|i8", InvalidEndianness { .. }); + check_err!("|U1", InvalidEndianness { .. }); + + // Size + check_ok!(">i8"); + check_err!(">i9", InvalidSize { .. }); + check_err!(">m4[us]", InvalidSize { .. }); + check_err!(">b4", InvalidSize { .. }); + check_ok!("|S0"); + check_ok!(">U0"); + check_ok!("|V0"); + check_ok!("|V7"); + + // Presence or absence of units + check_ok!(">i8"); + check_ok!(">m8[us]"); + check_err!(">i8[us]", MissingOrUnexpectedUnits { .. }); + check_err!(">m8", MissingOrUnexpectedUnits { .. }); + } + } +} + +#[cfg(test)] +#[deny(unused)] +mod tests { + use super::*; + + #[test] + fn display_simple() { + assert_eq!( + TypeStr { + endianness: Endianness::Little, + type_kind: TypeKind::Int, + size: 8, + time_units: None, + }.to_string(), + "m8[ns]", + ); + } + + #[test] + fn roundtrip() { + macro_rules! check_roundtrip { + ($text:expr) => { + let text = $text.to_string(); + match text.parse::() { + Err(e) => panic!("Failed to parse {:?}: {}", text, e), + Ok(v) => assert_eq!(text, v.to_string()), + } + }; + } + + check_roundtrip!(">i8"); + check_roundtrip!(">f16"); + check_roundtrip!("i1"); + check_roundtrip!("|i1"); + check_roundtrip!("|S7"); + check_roundtrip!("|S0"); + check_roundtrip!("U3"); + check_roundtrip!("m8[ms]"); + } +} From 57a79b26b5e512c7544b30277e8493a7511c38d6 Mon Sep 17 00:00:00 2001 From: Michael Lamparski Date: Mon, 3 Jun 2019 13:30:55 -0400 Subject: [PATCH 02/20] Make DType contain TypeStr --- src/header.rs | 73 ++++++++++++++++++++++++++++++++------------- src/serializable.rs | 20 ++++++------- tests/roundtrip.rs | 2 +- 3 files changed, 63 insertions(+), 32 deletions(-) diff --git a/src/header.rs b/src/header.rs index 580221c..dece2f0 100644 --- a/src/header.rs +++ b/src/header.rs @@ -2,6 +2,7 @@ use nom::IResult; use std::collections::HashMap; use std::io::Result; +use type_str::TypeStr; /// Representation of a Numpy type #[derive(PartialEq, Eq, Debug)] @@ -11,7 +12,7 @@ pub enum DType { /// Numpy type string. First character is `'>'` for big endian, `'<'` for little endian. /// /// Examples: `>i4`, `f8`. The number corresponds to the number of bytes. - ty: String, + ty: TypeStr, /// Shape of a type. /// @@ -64,7 +65,10 @@ impl DType { pub fn from_descr(descr: Value) -> Result { use DType::*; match descr { - Value::String(string) => Ok(Plain { ty: string, shape: vec![] }), + Value::String(string) => { + let ty = convert_string_to_type_str(&string)?; + Ok(Plain { ty, shape: vec![] }) + }, Value::List(ref list) => Ok(Record(convert_list_to_record_fields(list)?)), _ => invalid_data("must be string or list") } @@ -86,7 +90,7 @@ fn convert_tuple_to_record_field(tuple: &[Value]) -> Result { 2 | 3 => match (&tuple[0], &tuple[1], tuple.get(2)) { (&String(ref name), &String(ref dtype), ref shape) => Ok(Field { name: name.clone(), dtype: DType::Plain { - ty: dtype.clone(), + ty: convert_string_to_type_str(dtype)?, shape: if let &Some(ref s) = shape { convert_value_to_shape(s)? } else { @@ -127,6 +131,13 @@ fn convert_value_to_positive_integer(number: &Value) -> Result { } } +fn convert_string_to_type_str(string: &str) -> Result { + match string.parse() { + Ok(ty) => Ok(ty), + Err(e) => invalid_data(&format!("invalid type string: {}", e)), + } +} + fn first_error(results: I) -> Result> where I: IntoIterator> { @@ -244,84 +255,103 @@ mod parser { #[cfg(test)] mod tests { use super::*; + use std::error::Error; + + type TestResult = std::result::Result<(), Box>; #[test] - fn description_of_record_array_as_python_list_of_tuples() { + fn description_of_record_array_as_python_list_of_tuples() -> TestResult { let dtype = DType::Record(vec![ Field { name: "float".to_string(), - dtype: DType::Plain { ty: ">f4".to_string(), shape: vec![] } + dtype: DType::Plain { ty: ">f4".parse()?, shape: vec![] } }, Field { name: "byte".to_string(), - dtype: DType::Plain { ty: "f8".to_string(), shape: vec![] }; + fn description_of_unstructured_primitive_array() -> TestResult { + let dtype = DType::Plain { ty: ">f8".parse()?, shape: vec![] }; assert_eq!(dtype.descr(), "'>f8'"); + Ok(()) } #[test] - fn description_of_nested_record_dtype() { + fn description_of_nested_record_dtype() -> TestResult { let dtype = DType::Record(vec![ Field { name: "parent".to_string(), dtype: DType::Record(vec![ Field { name: "child".to_string(), - dtype: DType::Plain { ty: " TestResult { + let dtype = ">f8"; + assert_eq!( + DType::from_descr(Value::String(dtype.to_string())).unwrap(), + DType::Plain { ty: dtype.parse()?, shape: vec![] } + ); + Ok(()) } #[test] - fn converts_simple_description_to_record_dtype() { - let dtype = ">f8".to_string(); + fn converts_non_endian_description_to_record_dtype() -> TestResult { + let dtype = "|u1"; assert_eq!( - DType::from_descr(Value::String(dtype.clone())).unwrap(), - DType::Plain { ty: dtype, shape: vec![] } + DType::from_descr(Value::String(dtype.to_string())).unwrap(), + DType::Plain { ty: dtype.parse()?, shape: vec![] } ); + Ok(()) } #[test] - fn converts_record_description_to_record_dtype() { + fn converts_record_description_to_record_dtype() -> TestResult { let descr = parse("[('a', ' TestResult { let descr = parse("[('a', '>f8', (1,))]"); let expected_dtype = DType::Record(vec![ Field { name: "a".to_string(), - dtype: DType::Plain { ty: ">f8".to_string(), shape: vec![1] } + dtype: DType::Plain { ty: ">f8".parse()?, shape: vec![1] } } ]); assert_eq!(DType::from_descr(descr).unwrap(), expected_dtype); + Ok(()) } #[test] - fn record_description_with_nested_record_field() { + fn record_description_with_nested_record_field() -> TestResult { let descr = parse("[('parent', [('child', ' DType { - DType::Plain { ty: " usize { 1 } @@ -44,7 +44,7 @@ impl Serializable for i8 { impl Serializable for i16 { #[inline] fn dtype() -> DType { - DType::Plain { ty: " usize { 2 } @@ -61,7 +61,7 @@ impl Serializable for i16 { impl Serializable for i32 { #[inline] fn dtype() -> DType { - DType::Plain { ty: " usize { 4 } @@ -78,7 +78,7 @@ impl Serializable for i32 { impl Serializable for i64 { #[inline] fn dtype() -> DType { - DType::Plain { ty: " usize { 8 } @@ -95,7 +95,7 @@ impl Serializable for i64 { impl Serializable for u8 { #[inline] fn dtype() -> DType { - DType::Plain { ty: " usize { 1 } @@ -112,7 +112,7 @@ impl Serializable for u8 { impl Serializable for u16 { #[inline] fn dtype() -> DType { - DType::Plain { ty: " usize { 2 } @@ -129,7 +129,7 @@ impl Serializable for u16 { impl Serializable for u32 { #[inline] fn dtype() -> DType { - DType::Plain { ty: " usize { 4 } @@ -146,7 +146,7 @@ impl Serializable for u32 { impl Serializable for u64 { #[inline] fn dtype() -> DType { - DType::Plain { ty: " usize { 8 } @@ -163,7 +163,7 @@ impl Serializable for u64 { impl Serializable for f32 { #[inline] fn dtype() -> DType { - DType::Plain { ty: " usize { 4 } @@ -180,7 +180,7 @@ impl Serializable for f32 { impl Serializable for f64 { #[inline] fn dtype() -> DType { - DType::Plain { ty: " usize { 8 } diff --git a/tests/roundtrip.rs b/tests/roundtrip.rs index 334a616..b7f6824 100644 --- a/tests/roundtrip.rs +++ b/tests/roundtrip.rs @@ -38,7 +38,7 @@ struct Vector5(Vec); impl Serializable for Vector5 { #[inline] fn dtype() -> DType { - DType::Plain { ty: " Date: Thu, 6 Jun 2019 12:05:15 -0400 Subject: [PATCH 03/20] add new serialization traits, impl for ints --- src/header.rs | 29 ++++ src/lib.rs | 3 + src/serialize.rs | 421 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 453 insertions(+) create mode 100644 src/serialize.rs diff --git a/src/header.rs b/src/header.rs index dece2f0..502dfa0 100644 --- a/src/header.rs +++ b/src/header.rs @@ -73,6 +73,35 @@ impl DType { _ => invalid_data("must be string or list") } } + + #[cfg(test)] + pub(crate) fn parse(source: &str) -> Result { + let descr = match parser::item(source.as_bytes()) { + IResult::Done(_, header) => { + Ok(header) + }, + IResult::Incomplete(needed) => { + invalid_data(&format!("could not parse Python expression: {:?}", needed)) + }, + IResult::Error(err) => { + invalid_data(&format!("could not parse Python expression: {:?}", err)) + }, + }?; + Self::from_descr(descr) + } + + /// Construct a scalar `DType`. (one which is not a nested array or record type) + pub fn new_scalar(ty: TypeStr) -> Self { + DType::Plain { ty, shape: vec![] } + } + + /// Return a `TypeStr` only if the `DType` is a primitive scalar. (no arrays or record types) + pub(crate) fn as_scalar(&self) -> Option<&TypeStr> { + match self { + DType::Plain { ty, shape } if shape.is_empty() => Some(ty), + _ => None, + } + } } fn convert_list_to_record_fields(values: &[Value]) -> Result> { diff --git a/src/lib.rs b/src/lib.rs index 7d71001..85291e0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -133,9 +133,12 @@ mod serializable; mod npy_data; mod out_file; mod type_str; +mod serialize; pub use serializable::Serializable; pub use header::{DType, Field}; pub use npy_data::NpyData; pub use out_file::{to_file, OutFile}; +pub use serialize::{Serialize, Deserialize, AutoSerialize}; +pub use serialize::{TypeRead, TypeWrite, DTypeError}; pub use type_str::{TypeStr, ParseTypeStrError}; diff --git a/src/serialize.rs b/src/serialize.rs new file mode 100644 index 0000000..7f487b5 --- /dev/null +++ b/src/serialize.rs @@ -0,0 +1,421 @@ +use header::DType; +use type_str::{TypeStr, Endianness, TypeKind}; +use byteorder::{ByteOrder, NativeEndian, WriteBytesExt}; +use self::{TypeKind::*}; +use std::io; +use std::fmt; + +/// Trait that permits reading a type from an `.npy` file. +/// +/// For an example of how to implement this, please see the +/// [roundtrip test](https://github.com/potocpav/npy-rs/tree/master/tests/roundtrip.rs). +pub trait Deserialize: Sized { + /// Think of this as like a `Fn(&[u8]) -> (Self, &[u8])`. + /// + /// There is no closure-like sugar for these; you must manually define a type that + /// implements [`TypeRead`]. + type Reader: TypeRead; + + /// Get a function that deserializes a single data field at a time + /// + /// The function receives a byte buffer containing at least + /// `dtype.num_bytes()` bytes. + /// + /// # Errors + /// + /// Returns `Err` if the `DType` is not compatible with `Self`. + fn reader(dtype: &DType) -> Result; +} + +/// Trait that permits writing a type to an `.npy` file. +/// +/// For an example of how to implement this, please see the +/// [roundtrip test](https://github.com/potocpav/npy-rs/tree/master/tests/roundtrip.rs). +pub trait Serialize { + /// Think of this as some sort of `for Fn(W, &Self) -> io::Result<()>`. + /// + /// There is no closure-like sugar for these; you must manually define a type that + /// implements [`TypeWrite`]. + type Writer: TypeWrite; + + /// Get a function that serializes a single data field at a time. + /// + /// # Errors + /// + /// Returns `Err` if the `DType` is not compatible with `Self`. + fn writer(dtype: &DType) -> Result; +} + +/// Subtrait of [`Serialize`] for types which have a reasonable default [`DType`]. +/// +/// This opens up some simpler APIs for serialization. (e.g. [`::to_file`]) +/// +/// For an example of how to implement this, please see the +/// [roundtrip test](https://github.com/potocpav/npy-rs/tree/master/tests/roundtrip.rs). +pub trait AutoSerialize: Serialize { + /// A suggested format for serialization. + /// + /// The builtin implementations for primitive types generally prefer `|` endianness if possible, + /// else the machine endian format. + fn default_dtype() -> DType; +} + +/// Like a `Fn(&[u8]) -> (T, &[u8])`. +/// +/// It is a separate trait from `Fn` for consistency with [`TypeWrite`], and so that +/// default methods can potentially be added in the future that may be overriden +/// for efficiency. +/// +/// For an example of how to implement this, please see the +/// [roundtrip test](https://github.com/potocpav/npy-rs/tree/master/tests/roundtrip.rs). +pub trait TypeRead { + /// Type returned by the function. + type Value; + + /// The function. + /// + /// Receives *at least* enough bytes to read `Self::Value`, and returns the remainder. + fn read_one<'a>(&self, bytes: &'a [u8]) -> (Self::Value, &'a [u8]); +} + +/// Like some sort of `for Fn(W, &T) -> io::Result<()>`. +/// +/// For an example of how to implement this, please see the +/// [roundtrip test](https://github.com/potocpav/npy-rs/tree/master/tests/roundtrip.rs). +pub trait TypeWrite { + /// Type accepted by the function. + type Value: ?Sized; + + /// The function. + fn write_one(&self, writer: W, value: &Self::Value) -> io::Result<()> + where Self: Sized; +} + +/// Indicates that a particular rust type does not support serialization or deserialization +/// as a given [`DType`]. +#[derive(Debug, Clone)] +pub struct DTypeError(ErrorKind); + +#[derive(Debug, Clone)] +enum ErrorKind { + Custom(String), + ExpectedScalar { + dtype: String, + rust_type: &'static str, + }, + BadScalar { + type_str: TypeStr, + rust_type: &'static str, + verb: &'static str, + }, +} + +impl std::error::Error for DTypeError {} + +impl DTypeError { + /// Construct with a custom error message. + pub fn custom>(msg: S) -> Self { + DTypeError(ErrorKind::Custom(msg.as_ref().to_string())) + } + + // verb should be "read" or "write" + fn bad_scalar(verb: &'static str, type_str: &TypeStr, rust_type: &'static str) -> Self { + let type_str = type_str.clone(); + DTypeError(ErrorKind::BadScalar { type_str, rust_type, verb }) + } + + fn expected_scalar(dtype: &DType, rust_type: &'static str) -> Self { + let dtype = dtype.descr(); + DTypeError(ErrorKind::ExpectedScalar { dtype, rust_type }) + } +} + +impl fmt::Display for DTypeError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match &self.0 { + ErrorKind::Custom(msg) => { + write!(f, "{}", msg) + }, + ErrorKind::ExpectedScalar { dtype, rust_type } => { + write!(f, "type {} requires a scalar (string) dtype, not {}", rust_type, dtype) + }, + ErrorKind::BadScalar { type_str, rust_type, verb } => { + write!(f, "cannot {} type {} with type-string '{}'", verb, rust_type, type_str) + }, + } + } +} + +// Takes info about each data size, from largest to smallest. +macro_rules! impl_integer_serializable { + ( @iterate + meta: $meta:tt + remaining: [] + ) => {}; + + ( @iterate + meta: $meta:tt + remaining: [$first:tt $($smaller:tt)*] + ) => { + impl_integer_serializable! { + @generate + meta: $meta + current: $first + } + + impl_integer_serializable! { + @iterate + meta: $meta + remaining: [ $($smaller)* ] + } + }; + + ( + @generate + meta: [ (main_ty: $Int:ident) (date_ty: $DateTime:ident) ] + current: [ $size:literal $int:ident + (size1: $size1_cfg:meta) $read_int:ident $write_int:ident + ] + ) => { + mod $int { + use super::*; + + pub struct AnyEndianReader { pub(super) swap_byteorder: bool } + pub struct AnyEndianWriter { pub(super) swap_byteorder: bool } + + pub(super) fn expect_scalar_dtype(dtype: &DType) -> Result<&TypeStr, DTypeError> { + dtype.as_scalar().ok_or_else(|| { + DTypeError::expected_scalar(dtype, stringify!($int)) + }) + } + + #[inline] + fn maybe_swap(swap: bool, x: $int) -> $int { + match swap { + true => x.to_be().to_le(), + false => x, + } + } + + impl TypeRead for AnyEndianReader { + type Value = $int; + + #[inline(always)] + fn read_one<'a>(&self, bytes: &'a [u8]) -> (Self::Value, &'a [u8]) { + let value = maybe_swap(self.swap_byteorder, NativeEndian::$read_int(bytes)); + (value, &bytes[$size..]) + } + } + + impl TypeWrite for AnyEndianWriter { + type Value = $int; + + #[inline(always)] + fn write_one(&self, mut writer: W, &value: &Self::Value) -> io::Result<()> { + writer.$write_int::(maybe_swap(self.swap_byteorder, value)) + } + } + } + + impl Deserialize for $int { + type Reader = $int::AnyEndianReader; + + fn reader(dtype: &DType) -> Result { + match $int::expect_scalar_dtype(dtype)? { + // Read an integer of the same size and signedness. + // + // DateTime is an unsigned integer and TimeDelta is a signed integer, + // so we support those too. + TypeStr { size: $size, endianness, type_kind: $Int, .. } | + TypeStr { size: $size, endianness, type_kind: $DateTime, .. } => { + let swap_byteorder = endianness.requires_swap(Endianness::of_machine()); + Ok($int::AnyEndianReader { swap_byteorder }) + }, + type_str => Err(DTypeError::bad_scalar("read", type_str, stringify!($int))), + } + } + } + + impl Serialize for $int { + type Writer = $int::AnyEndianWriter; + + fn writer(dtype: &DType) -> Result { + match $int::expect_scalar_dtype(dtype)? { + // Write a signed integer of the correct size + TypeStr { size: $size, endianness, type_kind: $Int, .. } | + TypeStr { size: $size, endianness, type_kind: $DateTime, .. } => { + let swap_byteorder = endianness.requires_swap(Endianness::of_machine()); + Ok($int::AnyEndianWriter { swap_byteorder }) + }, + type_str => Err(DTypeError::bad_scalar("write", type_str, stringify!($int))), + } + } + } + + impl AutoSerialize for $int { + fn default_dtype() -> DType { + DType::new_scalar(TypeStr::with_auto_endianness($Int, $size, None)) + } + } + }; +} + +// Needed by the macro: Methods missing from byteorder +trait ReadSingleByteExt { + #[inline(always)] fn read_u8_(bytes: &[u8]) -> u8 { bytes[0] } + #[inline(always)] fn read_i8_(bytes: &[u8]) -> i8 { i8::from_ne_bytes([bytes[0]]) } +} + +impl ReadSingleByteExt for E {} + +/// Needed by the macro: Methods modified to take a generic type param +trait WriteSingleByteExt: WriteBytesExt { + #[inline(always)] fn write_u8_(&mut self, value: u8) -> io::Result<()> { self.write_u8(value) } + #[inline(always)] fn write_i8_(&mut self, value: i8) -> io::Result<()> { self.write_i8(value) } +} + +impl WriteSingleByteExt for W {} + +// `all()` means "true", `any()` means "false". (these get put inside `cfg`) +impl_integer_serializable! { + @iterate + meta: [ (main_ty: Int) (date_ty: TimeDelta) ] + remaining: [ + // numpy doesn't support i128 + [ 8 i64 (size1: any()) read_i64 write_i64 ] + [ 4 i32 (size1: any()) read_i32 write_i32 ] + [ 2 i16 (size1: any()) read_i16 write_i16 ] + [ 1 i8 (size1: all()) read_i8_ write_i8_ ] + ] +} + +impl_integer_serializable! { + @iterate + meta: [ (main_ty: Uint) (date_ty: DateTime) ] + remaining: [ + // numpy doesn't support i128 + [ 8 u64 (size1: any()) read_u64 write_u64 ] + [ 4 u32 (size1: any()) read_u32 write_u32 ] + [ 2 u16 (size1: any()) read_u16 write_u16 ] + [ 1 u8 (size1: all()) read_u8_ write_u8_ ] + ] +} + +#[cfg(test)] +mod tests { + use super::*; + + fn reader_output(dtype: &DType, bytes: &[u8]) -> T { + T::reader(dtype).unwrap_or_else(|e| panic!("{}", e)).read_one(bytes).0 + } + + fn reader_expect_err(dtype: &DType) { + T::reader(dtype).err().expect("reader_expect_err failed!"); + } + + fn writer_output(dtype: &DType, value: &T) -> Vec { + let mut vec = vec![]; + T::writer(dtype).unwrap_or_else(|e| panic!("{}", e)) + .write_one(&mut vec, value).unwrap(); + vec + } + + fn writer_expect_err(dtype: &DType) { + T::writer(dtype).err().expect("writer_expect_err failed!"); + } + + fn writer_expect_write_err(dtype: &DType, value: &T) { + let mut vec = vec![]; + T::writer(dtype).unwrap_or_else(|e| panic!("{}", e)) + .write_one(&mut vec, value) + .err().expect("writer_expect_write_err failed!"); + } + + const BE_ONE_64: &[u8] = &[0, 0, 0, 0, 0, 0, 0, 1]; + const LE_ONE_64: &[u8] = &[1, 0, 0, 0, 0, 0, 0, 0]; + const BE_ONE_32: &[u8] = &[0, 0, 0, 1]; + const LE_ONE_32: &[u8] = &[1, 0, 0, 0]; + + #[test] + fn identity() { + let be = DType::parse("'>i4'").unwrap(); + let le = DType::parse("'(&be, BE_ONE_32), 1); + assert_eq!(reader_output::(&le, LE_ONE_32), 1); + assert_eq!(writer_output::(&be, &1), BE_ONE_32); + assert_eq!(writer_output::(&le, &1), LE_ONE_32); + + let be = DType::parse("'>u4'").unwrap(); + let le = DType::parse("'(&be, BE_ONE_32), 1); + assert_eq!(reader_output::(&le, LE_ONE_32), 1); + assert_eq!(writer_output::(&be, &1), BE_ONE_32); + assert_eq!(writer_output::(&le, &1), LE_ONE_32); + + for &dtype in &["'>i1'", "'(&dtype, &[1]), 1); + assert_eq!(writer_output::(&dtype, &1), &[1][..]); + } + + for &dtype in &["'>u1'", "'(&dtype, &[1]), 1); + assert_eq!(writer_output::(&dtype, &1), &[1][..]); + } + } + + #[test] + fn datetime_as_int() { + let be = DType::parse("'>m8[ns]'").unwrap(); + let le = DType::parse("'(&be, BE_ONE_64), 1); + assert_eq!(reader_output::(&le, LE_ONE_64), 1); + assert_eq!(writer_output::(&be, &1), BE_ONE_64); + assert_eq!(writer_output::(&le, &1), LE_ONE_64); + + let be = DType::parse("'>M8[ns]'").unwrap(); + let le = DType::parse("'(&be, BE_ONE_64), 1); + assert_eq!(reader_output::(&le, LE_ONE_64), 1); + assert_eq!(writer_output::(&be, &1), BE_ONE_64); + assert_eq!(writer_output::(&le, &1), LE_ONE_64); + } + + #[test] + fn wrong_size_int() { + let t_i32 = DType::parse("'(&t_i32); + reader_expect_err::(&t_i32); + reader_expect_err::(&t_u32); + reader_expect_err::(&t_u32); + writer_expect_err::(&t_i32); + writer_expect_err::(&t_i32); + writer_expect_err::(&t_u32); + writer_expect_err::(&t_u32); + } + + #[test] + fn default_simple_type_strs() { + assert_eq!(i8::default_dtype().descr(), "'|i1'"); + assert_eq!(u8::default_dtype().descr(), "'|u1'"); + + if 1 == i32::from_be(1) { + assert_eq!(i16::default_dtype().descr(), "'>i2'"); + assert_eq!(i32::default_dtype().descr(), "'>i4'"); + assert_eq!(i64::default_dtype().descr(), "'>i8'"); + assert_eq!(u32::default_dtype().descr(), "'>u4'"); + } else { + assert_eq!(i16::default_dtype().descr(), "' Date: Thu, 6 Jun 2019 12:08:47 -0400 Subject: [PATCH 04/20] impl Serialize traits for floats --- src/serialize.rs | 110 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/src/serialize.rs b/src/serialize.rs index 7f487b5..897eaef 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -301,6 +301,93 @@ impl_integer_serializable! { ] } +// Takes info about each data size, from largest to smallest. +macro_rules! impl_float_serializable { + ( $( [ $size:literal $float:ident $read_float:ident $write_float:ident ] )+ ) => { $( + mod $float { + use super::*; + + pub struct AnyEndianReader { pub(super) swap_byteorder: bool } + pub struct AnyEndianWriter { pub(super) swap_byteorder: bool } + + #[inline] + fn maybe_swap(swap: bool, x: $float) -> $float { + match swap { + true => $float::from_bits(x.to_bits().to_be().to_le()), + false => x, + } + } + + pub(super) fn expect_scalar_dtype(dtype: &DType) -> Result<&TypeStr, DTypeError> { + dtype.as_scalar().ok_or_else(|| { + DTypeError::expected_scalar(dtype, stringify!($float)) + }) + } + + impl TypeRead for AnyEndianReader { + type Value = $float; + + #[inline(always)] + fn read_one<'a>(&self, bytes: &'a [u8]) -> ($float, &'a [u8]) { + let value = maybe_swap(self.swap_byteorder, NativeEndian::$read_float(bytes)); + (value, &bytes[$size..]) + } + } + + impl TypeWrite for AnyEndianWriter { + type Value = $float; + + #[inline(always)] + fn write_one(&self, mut writer: W, &value: &$float) -> io::Result<()> { + writer.$write_float::(maybe_swap(self.swap_byteorder, value)) + } + } + } + + impl Deserialize for $float { + type Reader = $float::AnyEndianReader; + + fn reader(dtype: &DType) -> Result { + match $float::expect_scalar_dtype(dtype)? { + // Read a float of the correct size + TypeStr { size: $size, endianness, type_kind: Float, .. } => { + let swap_byteorder = endianness.requires_swap(Endianness::of_machine()); + Ok($float::AnyEndianReader { swap_byteorder }) + }, + type_str => Err(DTypeError::bad_scalar("read", type_str, stringify!($float))), + } + } + } + + impl Serialize for $float { + type Writer = $float::AnyEndianWriter; + + fn writer(dtype: &DType) -> Result { + match $float::expect_scalar_dtype(dtype)? { + // Write a float of the correct size + TypeStr { size: $size, endianness, type_kind: Float, .. } => { + let swap_byteorder = endianness.requires_swap(Endianness::of_machine()); + Ok($float::AnyEndianWriter { swap_byteorder }) + }, + type_str => Err(DTypeError::bad_scalar("write", type_str, stringify!($float))), + } + } + } + + impl AutoSerialize for $float { + fn default_dtype() -> DType { + DType::new_scalar(TypeStr::with_auto_endianness(Float, $size, None)) + } + } + )+}; +} + +impl_float_serializable! { + // TODO: numpy supports f16, f128 + [ 8 f64 read_f64 write_f64 ] + [ 4 f32 read_f32 write_f32 ] +} + #[cfg(test)] mod tests { use super::*; @@ -367,6 +454,29 @@ mod tests { } } + #[test] + fn native_float_types() { + let be_bytes = 42.0_f64.to_bits().to_be_bytes(); + let le_bytes = 42.0_f64.to_bits().to_le_bytes(); + let be = DType::parse("'>f8'").unwrap(); + let le = DType::parse("'(&be, &be_bytes), 42.0); + assert_eq!(reader_output::(&le, &le_bytes), 42.0); + assert_eq!(writer_output::(&be, &42.0), &be_bytes); + assert_eq!(writer_output::(&le, &42.0), &le_bytes); + + let be_bytes = 42.0_f32.to_bits().to_be_bytes(); + let le_bytes = 42.0_f32.to_bits().to_le_bytes(); + let be = DType::parse("'>f4'").unwrap(); + let le = DType::parse("'(&be, &be_bytes), 42.0); + assert_eq!(reader_output::(&le, &le_bytes), 42.0); + assert_eq!(writer_output::(&be, &42.0), &be_bytes); + assert_eq!(writer_output::(&le, &42.0), &le_bytes); + } + #[test] fn datetime_as_int() { let be = DType::parse("'>m8[ns]'").unwrap(); From b6ce98b748e59ca27c89f23542ed6e90304244d9 Mon Sep 17 00:00:00 2001 From: Michael Lamparski Date: Thu, 6 Jun 2019 12:13:50 -0400 Subject: [PATCH 05/20] impl Serialize traits for bytestrings --- src/serialize.rs | 246 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 246 insertions(+) diff --git a/src/serialize.rs b/src/serialize.rs index 897eaef..dd45e4c 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -4,6 +4,7 @@ use byteorder::{ByteOrder, NativeEndian, WriteBytesExt}; use self::{TypeKind::*}; use std::io; use std::fmt; +use std::convert::TryFrom; /// Trait that permits reading a type from an `.npy` file. /// @@ -108,6 +109,7 @@ enum ErrorKind { rust_type: &'static str, verb: &'static str, }, + UsizeOverflow(u64), } impl std::error::Error for DTypeError {} @@ -128,6 +130,10 @@ impl DTypeError { let dtype = dtype.descr(); DTypeError(ErrorKind::ExpectedScalar { dtype, rust_type }) } + + fn bad_usize(x: u64) -> Self { + DTypeError(ErrorKind::UsizeOverflow(x)) + } } impl fmt::Display for DTypeError { @@ -142,10 +148,17 @@ impl fmt::Display for DTypeError { ErrorKind::BadScalar { type_str, rust_type, verb } => { write!(f, "cannot {} type {} with type-string '{}'", verb, rust_type, type_str) }, + ErrorKind::UsizeOverflow(value) => { + write!(f, "cannot cast {} as usize", value) + }, } } } +fn invalid_data(message: &str) -> io::Result { + Err(io::Error::new(io::ErrorKind::InvalidData, message.to_string())) +} + // Takes info about each data size, from largest to smallest. macro_rules! impl_integer_serializable { ( @iterate @@ -388,7 +401,171 @@ impl_float_serializable! { [ 4 f32 read_f32 write_f32 ] } +pub struct BytesReader { + size: usize, + is_byte_str: bool, +} + +impl TypeRead for BytesReader { + type Value = Vec; + + fn read_one<'a>(&self, bytes: &'a [u8]) -> (Vec, &'a [u8]) { + let mut vec = vec![]; + + let (src, remainder) = bytes.split_at(self.size); + vec.resize(self.size, 0); + vec.copy_from_slice(src); + + // truncate trailing zeros for type 'S' + if self.is_byte_str { + let end = vec.iter().rposition(|x| x != &0).map_or(0, |ind| ind + 1); + vec.truncate(end); + } + + (vec, remainder) + } +} + +impl Deserialize for Vec { + type Reader = BytesReader; + + fn reader(type_str: &DType) -> Result { + let type_str = type_str.as_scalar().ok_or_else(|| DTypeError::expected_scalar(type_str, "Vec"))?; + let size = match usize::try_from(type_str.size) { + Ok(size) => size, + Err(_) => return Err(DTypeError::bad_usize(type_str.size)), + }; + + let is_byte_str = match *type_str { + TypeStr { type_kind: ByteStr, .. } => true, + TypeStr { type_kind: RawData, .. } => false, + _ => return Err(DTypeError::bad_scalar("read", type_str, "Vec")), + }; + Ok(BytesReader { size, is_byte_str }) + } +} + +pub struct BytesWriter { + type_str: TypeStr, + size: usize, + is_byte_str: bool, +} + +impl TypeWrite for BytesWriter { + type Value = [u8]; + + fn write_one(&self, mut w: W, bytes: &[u8]) -> io::Result<()> { + use std::cmp::Ordering; + + match (bytes.len().cmp(&self.size), self.is_byte_str) { + (Ordering::Greater, _) | + (Ordering::Less, false) => return invalid_data( + &format!("bad item length {} for type-string '{}'", bytes.len(), self.type_str), + ), + _ => {}, + } + + w.write_all(bytes)?; + if self.is_byte_str { + w.write_all(&vec![0; self.size - bytes.len()])?; + } + Ok(()) + } +} + +impl Serialize for [u8] { + type Writer = BytesWriter; + + fn writer(dtype: &DType) -> Result { + let type_str = dtype.as_scalar().ok_or_else(|| DTypeError::expected_scalar(dtype, "[u8]"))?; + + let size = match usize::try_from(type_str.size) { + Ok(size) => size, + Err(_) => return Err(DTypeError::bad_usize(type_str.size)), + }; + + let type_str = type_str.clone(); + let is_byte_str = match type_str { + TypeStr { type_kind: ByteStr, .. } => true, + TypeStr { type_kind: RawData, .. } => false, + _ => return Err(DTypeError::bad_scalar("read", &type_str, "[u8]")), + }; + Ok(BytesWriter { type_str, size, is_byte_str }) + } +} + +#[macro_use] +mod helper { + use super::*; + use std::ops::Deref; + + pub struct TypeWriteViaDeref + where + T: Deref, + ::Target: Serialize, + { + pub(crate) inner: <::Target as Serialize>::Writer, + } + + impl TypeWrite for TypeWriteViaDeref + where + T: Deref, + U: Serialize, + { + type Value = T; + + #[inline(always)] + fn write_one(&self, writer: W, value: &T) -> io::Result<()> { + self.inner.write_one(writer, value) + } + } + + macro_rules! impl_serialize_by_deref { + ([$($generics:tt)*] $T:ty => $Target:ty $(where $($bounds:tt)+)*) => { + impl<$($generics)*> Serialize for $T + $(where $($bounds)+)* + { + type Writer = helper::TypeWriteViaDeref<$T>; + + #[inline(always)] + fn writer(dtype: &DType) -> Result { + Ok(helper::TypeWriteViaDeref { inner: <$Target>::writer(dtype)? }) + } + } + }; + } + + macro_rules! impl_auto_serialize { + ([$($generics:tt)*] $T:ty as $Delegate:ty $(where $($bounds:tt)+)*) => { + impl<$($generics)*> AutoSerialize for $T + $(where $($bounds)+)* + { + #[inline(always)] + fn default_dtype() -> DType { + <$Delegate>::default_dtype() + } + } + }; + } +} + +impl_serialize_by_deref!{[] Vec => [u8]} + +impl_serialize_by_deref!{['a, T: ?Sized] &'a T => T where T: Serialize} +impl_serialize_by_deref!{['a, T: ?Sized] &'a mut T => T where T: Serialize} +impl_serialize_by_deref!{[T: ?Sized] Box => T where T: Serialize} +impl_serialize_by_deref!{[T: ?Sized] std::rc::Rc => T where T: Serialize} +impl_serialize_by_deref!{[T: ?Sized] std::sync::Arc => T where T: Serialize} +impl_serialize_by_deref!{['a, T: ?Sized] std::borrow::Cow<'a, T> => T where T: Serialize + std::borrow::ToOwned} +impl_auto_serialize!{[T: ?Sized] &T as T where T: AutoSerialize} +impl_auto_serialize!{[T: ?Sized] &mut T as T where T: AutoSerialize} +impl_auto_serialize!{[T: ?Sized] Box as T where T: AutoSerialize} +impl_auto_serialize!{[T: ?Sized] std::rc::Rc as T where T: AutoSerialize} +impl_auto_serialize!{[T: ?Sized] std::sync::Arc as T where T: AutoSerialize} +impl_auto_serialize!{[T: ?Sized] std::borrow::Cow<'_, T> as T where T: AutoSerialize + std::borrow::ToOwned} + #[cfg(test)] +#[deny(unused)] mod tests { use super::*; @@ -511,6 +688,67 @@ mod tests { writer_expect_err::(&t_u32); } + #[test] + fn bytes_any_endianness() { + for ty in vec!["'S3'", "'|S3'"] { + let ty = DType::parse(ty).unwrap(); + assert_eq!(writer_output(&ty, &[1, 3, 5][..]), vec![1, 3, 5]); + assert_eq!(reader_output::>(&ty, &[1, 3, 5][..]), vec![1, 3, 5]); + } + } + + #[test] + fn bytes_size_zero() { + let ts = DType::parse("'|S0'").unwrap(); + assert_eq!(reader_output::>(&ts, &[]), vec![]); + assert_eq!(writer_output(&ts, &[][..]), vec![]); + + let ts = DType::parse("'|V0'").unwrap(); + assert_eq!(reader_output::>(&ts, &[]), vec![]); + assert_eq!(writer_output::<[u8]>(&ts, &[]), vec![]); + } + + #[test] + fn wrong_size_bytes() { + let s_3 = DType::parse("'|S3'").unwrap(); + let v_3 = DType::parse("'|V3'").unwrap(); + + assert_eq!(writer_output(&s_3, &[1, 3, 5][..]), vec![1, 3, 5]); + assert_eq!(writer_output(&v_3, &[1, 3, 5][..]), vec![1, 3, 5]); + + assert_eq!(writer_output(&s_3, &[1][..]), vec![1, 0, 0]); + writer_expect_write_err(&v_3, &[1][..]); + + assert_eq!(writer_output(&s_3, &[][..]), vec![0, 0, 0]); + writer_expect_write_err(&v_3, &[][..]); + + writer_expect_write_err(&s_3, &[1, 3, 5, 7][..]); + writer_expect_write_err(&v_3, &[1, 3, 5, 7][..]); + } + + #[test] + fn read_bytes_with_trailing_zeros() { + let ts = DType::parse("'|S2'").unwrap(); + assert_eq!(reader_output::>(&ts, &[1, 3]), vec![1, 3]); + assert_eq!(reader_output::>(&ts, &[1, 0]), vec![1]); + assert_eq!(reader_output::>(&ts, &[0, 0]), vec![]); + + let ts = DType::parse("'|V2'").unwrap(); + assert_eq!(reader_output::>(&ts, &[1, 3]), vec![1, 3]); + assert_eq!(reader_output::>(&ts, &[1, 0]), vec![1, 0]); + assert_eq!(reader_output::>(&ts, &[0, 0]), vec![0, 0]); + } + + #[test] + fn bytestr_preserves_interior_zeros() { + const DATA: &[u8] = &[0, 1, 0, 0, 3, 5]; + + let ts = DType::parse("'|S6'").unwrap(); + + assert_eq!(reader_output::>(&ts, DATA), DATA.to_vec()); + assert_eq!(writer_output(&ts, DATA), DATA.to_vec()); + } + #[test] fn default_simple_type_strs() { assert_eq!(i8::default_dtype().descr(), "'|i1'"); @@ -528,4 +766,12 @@ mod tests { assert_eq!(u32::default_dtype().descr(), "'>(&ts, &vec![1, 3, 5]), vec![1, 3, 5]); + assert_eq!(writer_output::<&[u8]>(&ts, &&[1, 3, 5][..]), vec![1, 3, 5]); + } } From a50046535f300bd7319694b755be187cd2cd27c4 Mon Sep 17 00:00:00 2001 From: Michael Lamparski Date: Thu, 6 Jun 2019 12:21:27 -0400 Subject: [PATCH 06/20] support dynamic readers and writers --- src/lib.rs | 2 +- src/serialize.rs | 50 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 85291e0..b164598 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -140,5 +140,5 @@ pub use header::{DType, Field}; pub use npy_data::NpyData; pub use out_file::{to_file, OutFile}; pub use serialize::{Serialize, Deserialize, AutoSerialize}; -pub use serialize::{TypeRead, TypeWrite, DTypeError}; +pub use serialize::{TypeRead, TypeWrite, TypeWriteDyn, DTypeError}; pub use type_str::{TypeStr, ParseTypeStrError}; diff --git a/src/serialize.rs b/src/serialize.rs index dd45e4c..3f7af85 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -92,6 +92,22 @@ pub trait TypeWrite { where Self: Sized; } +/// The proper trait to use for trait objects of [`TypeWrite`]. +/// +/// `Box` is useless because `dyn TypeWrite` has no object-safe methods. +/// The workaround is to use `Box` instead, which itself implements `TypeWrite`. +pub trait TypeWriteDyn: TypeWrite { + #[doc(hidden)] + fn write_one_dyn(&self, writer: &mut dyn io::Write, value: &Self::Value) -> io::Result<()>; +} + +impl TypeWriteDyn for T { + #[inline(always)] + fn write_one_dyn(&self, writer: &mut dyn io::Write, value: &Self::Value) -> io::Result<()> { + self.write_one(writer, value) + } +} + /// Indicates that a particular rust type does not support serialization or deserialization /// as a given [`DType`]. #[derive(Debug, Clone)] @@ -155,6 +171,30 @@ impl fmt::Display for DTypeError { } } +impl TypeRead for Box> { + type Value = T; + + #[inline(always)] + fn read_one<'a>(&self, bytes: &'a [u8]) -> (T, &'a [u8]) { + (**self).read_one(bytes) + } +} + +impl TypeWrite for Box> { + type Value = T; + + #[inline(always)] + fn write_one(&self, mut writer: W, value: &T) -> io::Result<()> + where Self: Sized, + { + // Boxes must always go through two virtual dispatches. + // + // (one on the TypeWrite trait object, and one on the Writer which must be + // cast to the monomorphic type `&mut dyn io::write`) + (**self).write_one_dyn(&mut writer, value) + } +} + fn invalid_data(message: &str) -> io::Result { Err(io::Error::new(io::ErrorKind::InvalidData, message.to_string())) } @@ -774,4 +814,14 @@ mod tests { assert_eq!(writer_output::>(&ts, &vec![1, 3, 5]), vec![1, 3, 5]); assert_eq!(writer_output::<&[u8]>(&ts, &&[1, 3, 5][..]), vec![1, 3, 5]); } + + #[test] + fn dynamic_readers_and_writers() { + let writer: Box> = Box::new(i32::writer(&i32::default_dtype()).unwrap()); + let reader: Box> = Box::new(i32::reader(&i32::default_dtype()).unwrap()); + + let mut buf = vec![]; + writer.write_one(&mut buf, &4000).unwrap(); + assert_eq!(reader.read_one(&buf).0, 4000); + } } From f9d2faac0c63138a1ab2a952951f768b85368843 Mon Sep 17 00:00:00 2001 From: Michael Lamparski Date: Tue, 4 Jun 2019 14:05:31 -0400 Subject: [PATCH 07/20] Add feature "derive" to reexport the macros This enables these derives to be qualified under `npy`, for disambiguation from `serde`: extern crate npy; #[derive(npy::Serialize, npy::Deserialize)] struct MyStruct { ... } This has a couple of downsides with regard to maintainence: * npy_derive can no longer be a dev-dependency because it must become an optional dependency. * Many tests and examples need the feature. We need to list all of these in Cargo.toml. * Because this crate is 2015 edition, as soon as we list *any* tests and examples, we must list *all* of them; including the ones that don't need the feature! --- This commit had to update `.travis.yml` to start using the feature. I took this opportunity to also add `--examples` (which the default script does not do) to ensure that examples build correctly. --- .travis.yml | 4 ++++ Cargo.toml | 39 ++++++++++++++++++++++++++++++++++++++- examples/large.rs | 5 +---- examples/roundtrip.rs | 1 - examples/simple.rs | 4 +--- src/lib.rs | 38 ++++++++++++++++++++++++++++---------- 6 files changed, 72 insertions(+), 19 deletions(-) diff --git a/.travis.yml b/.travis.yml index 42c29fa..5ee3e3b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,3 +7,7 @@ matrix: allow_failures: - rust: nightly +script: + - cargo build --verbose + - cargo build --verbose --features derive --examples + - cargo test --verbose --features derive diff --git a/Cargo.toml b/Cargo.toml index d7f0d30..756c6c3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,43 @@ members = [ "npy-derive" ] byteorder = "1" nom = "3" +[dependencies.npy-derive] +path = "npy-derive" +version = "0.4" +optional = true +default-features = false + [dev-dependencies] memmap = "0.6" -npy-derive = { path = "npy-derive", version = "0.4" } + +[features] +default = [] + +# Reexports the derive macros so that you can use them qualified under `npy::`: +# +# #[derive(npy::Serialize, npy::Deserialize, npy::AutoSerialize)] +# struct Struct { ... } +# +# This is a nicer alternative to `#[macro_use] extern crate npy_derive`, which +# directly imports things like `#[derive(Serialize)]` that may conflict with +# other crates (e.g. `serde`). +derive = ["npy-derive"] + +[[example]] +name = "plain" + +[[example]] +name = "large" +required-features = ["derive"] + +[[example]] +name = "simple" +required-features = ["derive"] + +[[example]] +name = "roundtrip" +required-features = ["derive"] + +[[test]] +name = "roundtrip" +required-features = ["derive"] diff --git a/examples/large.rs b/examples/large.rs index c0767c3..08eb310 100644 --- a/examples/large.rs +++ b/examples/large.rs @@ -1,14 +1,11 @@ extern crate memmap; -#[macro_use] -extern crate npy_derive; extern crate npy; use std::fs::File; use memmap::MmapOptions; - -#[derive(Serializable, Debug, Default)] +#[derive(npy::Serializable, Debug, Default)] struct Array { a: i32, b: f32, diff --git a/examples/roundtrip.rs b/examples/roundtrip.rs index ef75eaf..c17685a 100644 --- a/examples/roundtrip.rs +++ b/examples/roundtrip.rs @@ -1,4 +1,3 @@ - #[macro_use] extern crate npy_derive; extern crate npy; diff --git a/examples/simple.rs b/examples/simple.rs index 2a7eb43..86af8d9 100644 --- a/examples/simple.rs +++ b/examples/simple.rs @@ -1,6 +1,4 @@ -#[macro_use] -extern crate npy_derive; extern crate npy; use std::io::Read; @@ -12,7 +10,7 @@ use npy::NpyData; // a = np.array([(1,2.5,4), (2,3.1,5)], dtype=[('a', 'i4'),('b', 'f4'),('c', 'i8')]) // np.save('examples/simple.npy', a) -#[derive(Serializable, Debug)] +#[derive(npy::Serializable, Debug)] struct Array { a: i32, b: f32, diff --git a/src/lib.rs b/src/lib.rs index b164598..244d8f9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,10 +20,10 @@ are supported. These are: [Structured arrays](https://docs.scipy.org/doc/numpy/user/basics.rec.html). They can contain the following field types: * primitive types, - * other [`Serializable`](trait.Serializable.html) structs, - * arrays of [`Serializable`](trait.Serializable.html) types (including arrays) of length ≤ 16. - * `struct`s with manual [`Serializable`](trait.Serializable.html) implementations. An example - this can be found in the [roundtrip test](https://github.com/potocpav/npy-rs/tree/master/tests/roundtrip.rs). + * other structs that implement the traits, + * arrays of types that implement the traits (including arrays) of length ≤ 16. + * `struct`s with manual trait implementations. An example of this can be found in the + [roundtrip test](https://github.com/potocpav/npy-rs/tree/master/tests/roundtrip.rs). To successfully import an array from NPY using the `#[derive(Serializable)]` mechanism, the target struct must contain: @@ -85,18 +85,31 @@ a = np.array([(1,2.5,4), (2,3.1,5)], dtype=[('a', 'i4'),('b', 'f4'),('c', 'i8')] np.save('examples/simple.npy', a) ``` -To load this in Rust, we need to create a corresponding struct, that derives `Serializable`. Make sure -the field names and types all match up: +To load this in Rust, we need to create a corresponding struct. +There are three derivable traits we can define for it: +* [`Deserialize`] — Enables easy reading of `.npy` files. +* [`AutoSerialize`] — Enables easy writing of `.npy` files. (in a default format) +* [`Serialize`] — Supertrait of `AutoSerialize` that allows one to specify a custom [`DType`]. + +**Enable the `"derive"` feature in `Cargo.toml`,** +and make sure the field names and types all match up: +*/ + +// It is not currently possible in Cargo.toml to specify that an optional dependency should +// also be a dev-dependency. Therefore, we discretely remove this example when generating +// doctests, so that: +// - It always appears in documentation (`cargo doc`) +// - It is only tested when the feature is present (`cargo test --features derive`) +#![cfg_attr(any(not(test), feature="derive"), doc = r##" ``` -#[macro_use] -extern crate npy_derive; +// make sure to add `features = ["derive"]` in Cargo.toml! extern crate npy; use std::io::Read; use npy::NpyData; -#[derive(Serializable, Debug)] +#[derive(npy::Serializable, Debug)] struct Array { a: i32, b: f32, @@ -114,7 +127,8 @@ fn main() { } } ``` - +"##)] +/*! The output is: ```text @@ -124,6 +138,10 @@ Array { a: 2, b: 3.1, c: 5 } */ +// Reexport the macros. +#[cfg(feature = "derive")] extern crate npy_derive; +#[cfg(feature = "derive")] pub use npy_derive::*; + extern crate byteorder; #[macro_use] extern crate nom; From 6aee13e9396a54bd369721ddace1ecb06ce137d0 Mon Sep 17 00:00:00 2001 From: Michael Lamparski Date: Thu, 6 Jun 2019 12:36:51 -0400 Subject: [PATCH 08/20] Derive Clone for DType The new derive macros will need this... --- src/header.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/header.rs b/src/header.rs index 502dfa0..d847b4e 100644 --- a/src/header.rs +++ b/src/header.rs @@ -5,7 +5,7 @@ use std::io::Result; use type_str::TypeStr; /// Representation of a Numpy type -#[derive(PartialEq, Eq, Debug)] +#[derive(PartialEq, Eq, Debug, Clone)] pub enum DType { /// A simple array with only a single field Plain { @@ -25,7 +25,7 @@ pub enum DType { Record(Vec) } -#[derive(PartialEq, Eq, Debug)] +#[derive(PartialEq, Eq, Debug, Clone)] /// A field of a record dtype pub struct Field { /// The name of the field From af1e445776e6a0075fabcc5bfd856d086d188bb2 Mon Sep 17 00:00:00 2001 From: Michael Lamparski Date: Thu, 6 Jun 2019 12:52:28 -0400 Subject: [PATCH 09/20] add the new derives --- Cargo.toml | 4 + npy-derive/Cargo.toml | 1 + npy-derive/src/lib.rs | 271 +++++++++++++++++++++++++++++++++++++++- src/serialize.rs | 32 +++++ tests/derive_hygiene.rs | 18 +++ 5 files changed, 321 insertions(+), 5 deletions(-) create mode 100644 tests/derive_hygiene.rs diff --git a/Cargo.toml b/Cargo.toml index 756c6c3..786349d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,10 @@ required-features = ["derive"] name = "roundtrip" required-features = ["derive"] +[[test]] +name = "derive_hygiene" +required-features = ["derive"] + [[test]] name = "roundtrip" required-features = ["derive"] diff --git a/npy-derive/Cargo.toml b/npy-derive/Cargo.toml index eced55c..e355b55 100644 --- a/npy-derive/Cargo.toml +++ b/npy-derive/Cargo.toml @@ -10,5 +10,6 @@ repository = "https://github.com/potocpav/npy-rs" proc-macro = true [dependencies] +proc-macro2 = "0.2" quote = "0.4" syn = "0.12" diff --git a/npy-derive/src/lib.rs b/npy-derive/src/lib.rs index beb8999..dd6ffa2 100644 --- a/npy-derive/src/lib.rs +++ b/npy-derive/src/lib.rs @@ -1,20 +1,22 @@ -#![recursion_limit = "128"] +#![recursion_limit = "256"] /*! Derive `trait Serializable` for a structure. -Using this crate, it is enough to `#[derive(Serializable)]` on a struct to be able to serialize and -deserialize it. All the fields must implement [`Serializable`](../npy/trait.Serializable.html). +Using this crate, it is enough to `#[derive(npy::Serialize, npy::Deserialize)]` on a struct to be able to +serialize and deserialize it. All of the fields must implement [`Serialize`](../npy/trait.Serialize.html) +and [`Deserialize`](../npy/trait.Deserialize.html) respectively. */ extern crate proc_macro; +extern crate proc_macro2; extern crate syn; #[macro_use] extern crate quote; use proc_macro::TokenStream; -use syn::Data; +use proc_macro2::Span; use quote::{Tokens, ToTokens}; /// Macros 1.1-based custom derive function @@ -36,7 +38,7 @@ pub fn npy_data(input: TokenStream) -> TokenStream { fn impl_npy_data(ast: &syn::DeriveInput) -> quote::Tokens { let name = &ast.ident; let fields = match ast.data { - Data::Struct(ref data) => &data.fields, + syn::Data::Struct(ref data) => &data.fields, _ => panic!("#[derive(Serializable)] can only be used with structs"), }; // Helper is provided for handling complex generic types correctly and effortlessly @@ -100,3 +102,262 @@ fn impl_npy_data(ast: &syn::DeriveInput) -> quote::Tokens { } } } + +/// Macros 1.1-based custom derive function +#[proc_macro_derive(Serialize)] +pub fn npy_serialize(input: TokenStream) -> TokenStream { + // Parse the string representation + let ast = syn::parse(input).unwrap(); + + // Build the impl + let expanded = impl_npy_serialize(&ast); + + // Return the generated impl + expanded.into() +} + +#[proc_macro_derive(Deserialize)] +pub fn npy_deserialize(input: TokenStream) -> TokenStream { + // Parse the string representation + let ast = syn::parse(input).unwrap(); + + // Build the impl + let expanded = impl_npy_deserialize(&ast); + + // Return the generated impl + expanded.into() +} + +#[proc_macro_derive(AutoSerialize)] +pub fn npy_auto_serialize(input: TokenStream) -> TokenStream { + // Parse the string representation + let ast = syn::parse(input).unwrap(); + + // Build the impl + let expanded = impl_npy_auto_serialize(&ast); + + // Return the generated impl + expanded.into() +} + +struct FieldData { + idents: Vec, + idents_str: Vec, + types: Vec, +} + +impl FieldData { + fn extract(ast: &syn::DeriveInput) -> Self { + let fields = match ast.data { + syn::Data::Struct(ref data) => &data.fields, + _ => panic!("npy derive macros can only be used with structs"), + }; + + let idents: Vec = fields.iter().map(|f| { + f.ident.clone().expect("Tuple structs not supported") + }).collect(); + let idents_str = idents.iter().map(|t| unraw(t)).collect::>(); + + let types: Vec = fields.iter().map(|f| { + let ty = &f.ty; + quote!( #ty ) + }).collect::>(); + + FieldData { idents, idents_str, types } + } +} + +fn impl_npy_serialize(ast: &syn::DeriveInput) -> Tokens { + let name = &ast.ident; + let vis = &ast.vis; + let FieldData { ref idents, ref idents_str, ref types } = FieldData::extract(ast); + + let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl(); + let field_dtypes_struct = gen_field_dtypes_struct(idents, idents_str); + + let idents_1 = idents; + + wrap_in_const("Serialize", &name, quote! { + use ::std::io; + + #vis struct GeneratedWriter #ty_generics #where_clause { + writers: FieldWriters #ty_generics, + } + + struct FieldWriters #ty_generics #where_clause { + #( #idents: <#types as _npy::Serialize>::Writer ,)* + } + + #field_dtypes_struct + + impl #impl_generics _npy::TypeWrite for GeneratedWriter #ty_generics #where_clause { + type Value = #name #ty_generics; + + #[allow(unused_mut)] + fn write_one(&self, mut w: W, value: &Self::Value) -> io::Result<()> { + #( + let method = <<#types as _npy::Serialize>::Writer as _npy::TypeWrite>::write_one; + method(&self.writers.#idents, &mut w, &value.#idents_1)?; + )* + p::Ok(()) + } + } + + impl #impl_generics _npy::Serialize for #name #ty_generics #where_clause { + type Writer = GeneratedWriter #ty_generics; + + fn writer(dtype: &_npy::DType) -> p::Result { + let dtypes = FieldDTypes::extract(dtype)?; + let writers = FieldWriters { + #( #idents: <#types as _npy::Serialize>::writer(&dtypes.#idents_1)? ,)* + }; + + p::Ok(GeneratedWriter { writers }) + } + } + }) +} + +fn impl_npy_deserialize(ast: &syn::DeriveInput) -> Tokens { + let name = &ast.ident; + let vis = &ast.vis; + let FieldData { ref idents, ref idents_str, ref types } = FieldData::extract(ast); + + let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl(); + let field_dtypes_struct = gen_field_dtypes_struct(idents, idents_str); + + let idents_1 = idents; + + wrap_in_const("Deserialize", &name, quote! { + #vis struct GeneratedReader #ty_generics #where_clause { + readers: FieldReaders #ty_generics, + } + + struct FieldReaders #ty_generics #where_clause { + #( #idents: <#types as _npy::Deserialize>::Reader ,)* + } + + #field_dtypes_struct + + impl #impl_generics _npy::TypeRead for GeneratedReader #ty_generics #where_clause { + type Value = #name #ty_generics; + + #[allow(unused_mut)] + fn read_one<'a>(&self, mut remainder: &'a [u8]) -> (Self::Value, &'a [u8]) { + #( + let func = <<#types as _npy::Deserialize>::Reader as _npy::TypeRead>::read_one; + let (#idents, new_remainder) = func(&self.readers.#idents_1, remainder); + remainder = new_remainder; + )* + (#name { #( #idents ),* }, remainder) + } + } + + impl #impl_generics _npy::Deserialize for #name #ty_generics #where_clause { + type Reader = GeneratedReader #ty_generics; + + fn reader(dtype: &_npy::DType) -> p::Result { + let dtypes = FieldDTypes::extract(dtype)?; + let readers = FieldReaders { + #( #idents: <#types as _npy::Deserialize>::reader(&dtypes.#idents_1)? ,)* + }; + + p::Ok(GeneratedReader { readers }) + } + } + }) +} + +fn impl_npy_auto_serialize(ast: &syn::DeriveInput) -> Tokens { + let name = &ast.ident; + let FieldData { idents: _, ref idents_str, ref types } = FieldData::extract(ast); + + let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl(); + + wrap_in_const("AutoSerialize", &name, quote! { + impl #impl_generics _npy::AutoSerialize for #name #ty_generics #where_clause { + fn default_dtype() -> _npy::DType { + _npy::DType::Record(vec![#( + _npy::Field { + name: #idents_str.to_string(), + dtype: <#types as _npy::AutoSerialize>::default_dtype() + } + ),*]) + } + } + }) +} + +fn gen_field_dtypes_struct( + idents: &[syn::Ident], + idents_str: &[String], +) -> Tokens { + assert_eq!(idents.len(), idents_str.len()); + quote!{ + struct FieldDTypes { + #( #idents : _npy::DType ,)* + } + + impl FieldDTypes { + fn extract(dtype: &_npy::DType) -> p::Result { + let fields = match dtype { + _npy::DType::Record(fields) => fields, + _npy::DType::Plain { ty, .. } => return p::Err(_npy::DTypeError::expected_record(ty)), + }; + + let correct_names: &[&str] = &[ #(#idents_str),* ]; + + if p::Iterator::ne( + p::Iterator::map(fields.iter(), |f| &f.name[..]), + p::Iterator::cloned(correct_names.iter()), + ) { + let actual_names = p::Iterator::map(fields.iter(), |f| &f.name[..]); + return p::Err(_npy::DTypeError::wrong_fields(actual_names, correct_names)); + } + + #[allow(unused_mut)] + let mut fields = p::IntoIterator::into_iter(fields); + p::Result::Ok(FieldDTypes { + #( #idents : { + let field = p::Iterator::next(&mut fields).unwrap(); + p::Clone::clone(&field.dtype) + },)* + }) + } + } + } +} + +// from the wonderful folks working on serde +fn wrap_in_const( + trait_: &str, + ty: &syn::Ident, + code: Tokens, +) -> Tokens { + let dummy_const = syn::Ident::new( + &format!("__IMPL_npy_{}_FOR_{}", trait_, unraw(ty)), + Span::call_site(), + ); + + quote! { + #[allow(non_upper_case_globals, unused_attributes, unused_qualifications)] + const #dummy_const: () = { + #[allow(unknown_lints)] + #[cfg_attr(feature = "cargo-clippy", allow(useless_attribute))] + #[allow(rust_2018_idioms)] + extern crate npy as _npy; + + // if our generated code directly imports any traits, then the #[no_implicit_prelude] + // test won't catch accidental use of method syntax on trait methods (which can fail + // due to ambiguity with similarly-named methods on other traits). So if we want to + // abbreviate paths, we need to do this instead: + use ::std::prelude::v1 as p; + + #code + }; + } +} + +fn unraw(ident: &syn::Ident) -> String { + ident.to_string().trim_start_matches("r#").to_owned() +} diff --git a/src/serialize.rs b/src/serialize.rs index 3f7af85..ceee6ab 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -120,6 +120,13 @@ enum ErrorKind { dtype: String, rust_type: &'static str, }, + ExpectedRecord { + type_str: TypeStr, + }, + WrongFields { + expected: Vec, + actual: Vec, + }, BadScalar { type_str: TypeStr, rust_type: &'static str, @@ -150,6 +157,25 @@ impl DTypeError { fn bad_usize(x: u64) -> Self { DTypeError(ErrorKind::UsizeOverflow(x)) } + + // used by derives + #[doc(hidden)] + pub fn expected_record(type_str: &TypeStr) -> Self { + let type_str = type_str.clone(); + DTypeError(ErrorKind::ExpectedRecord { type_str }) + } + + // used by derives + #[doc(hidden)] + pub fn wrong_fields, S2: AsRef>( + expected: impl IntoIterator, + actual: impl IntoIterator, + ) -> Self { + DTypeError(ErrorKind::WrongFields { + expected: expected.into_iter().map(|s| s.as_ref().to_string()).collect(), + actual: actual.into_iter().map(|s| s.as_ref().to_string()).collect(), + }) + } } impl fmt::Display for DTypeError { @@ -161,6 +187,12 @@ impl fmt::Display for DTypeError { ErrorKind::ExpectedScalar { dtype, rust_type } => { write!(f, "type {} requires a scalar (string) dtype, not {}", rust_type, dtype) }, + ErrorKind::ExpectedRecord { type_str } => { + write!(f, "expected a record type; got a scalar type '{}'", type_str) + }, + ErrorKind::WrongFields { actual, expected } => { + write!(f, "field names do not match (expected {:?}, got {:?})", expected, actual) + }, ErrorKind::BadScalar { type_str, rust_type, verb } => { write!(f, "cannot {} type {} with type-string '{}'", verb, rust_type, type_str) }, diff --git a/tests/derive_hygiene.rs b/tests/derive_hygiene.rs new file mode 100644 index 0000000..b8eed08 --- /dev/null +++ b/tests/derive_hygiene.rs @@ -0,0 +1,18 @@ +extern crate npy_derive; +extern crate npy as lol; + +#[no_implicit_prelude] +mod not_root { + use ::npy_derive; + + #[derive(npy_derive::Serialize, npy_derive::Deserialize)] + struct Struct { + foo: i32, + bar: LocalType, + } + + #[derive(npy_derive::Serialize, npy_derive::Deserialize)] + struct LocalType; +} + +fn main() {} From 2a6672b0fb7faea648f5a1b831145b89e1a986db Mon Sep 17 00:00:00 2001 From: Michael Lamparski Date: Thu, 6 Jun 2019 14:04:45 -0400 Subject: [PATCH 10/20] impl Serialize for arrays This had to wait until after the derives were added so that the tests could use the derives. --- Cargo.toml | 4 ++ src/header.rs | 5 +- src/serialize.rs | 133 +++++++++++++++++++++++++++++++++++++++ tests/serialize_array.rs | 132 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 272 insertions(+), 2 deletions(-) create mode 100644 tests/serialize_array.rs diff --git a/Cargo.toml b/Cargo.toml index 786349d..d5b1a1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,3 +62,7 @@ required-features = ["derive"] [[test]] name = "roundtrip" required-features = ["derive"] + +[[test]] +name = "serialize_array" +required-features = ["derive"] diff --git a/src/header.rs b/src/header.rs index d847b4e..0e92573 100644 --- a/src/header.rs +++ b/src/header.rs @@ -74,8 +74,9 @@ impl DType { } } - #[cfg(test)] - pub(crate) fn parse(source: &str) -> Result { + // not part of stable API, but needed by the serialize_array test + #[doc(hidden)] + pub fn parse(source: &str) -> Result { let descr = match parser::item(source.as_bytes()) { IResult::Done(_, header) => { Ok(header) diff --git a/src/serialize.rs b/src/serialize.rs index ceee6ab..b2d7d07 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -120,6 +120,13 @@ enum ErrorKind { dtype: String, rust_type: &'static str, }, + ExpectedArray { + got: &'static str, // "a scalar", "a record" + }, + WrongArrayLen { + expected: u64, + actual: u64, + }, ExpectedRecord { type_str: TypeStr, }, @@ -190,6 +197,12 @@ impl fmt::Display for DTypeError { ErrorKind::ExpectedRecord { type_str } => { write!(f, "expected a record type; got a scalar type '{}'", type_str) }, + ErrorKind::ExpectedArray { got } => { + write!(f, "rust array types require an array dtype (got {})", got) + }, + ErrorKind::WrongArrayLen { actual, expected } => { + write!(f, "wrong array size (expected {}, got {})", expected, actual) + }, ErrorKind::WrongFields { actual, expected } => { write!(f, "field names do not match (expected {:?}, got {:?})", expected, actual) }, @@ -636,11 +649,131 @@ impl_auto_serialize!{[T: ?Sized] std::rc::Rc as T where T: AutoSerialize} impl_auto_serialize!{[T: ?Sized] std::sync::Arc as T where T: AutoSerialize} impl_auto_serialize!{[T: ?Sized] std::borrow::Cow<'_, T> as T where T: AutoSerialize + std::borrow::ToOwned} +impl DType { + /// Expect an array dtype, get the length of the array and the inner dtype. + fn array_inner_dtype(&self, expected_len: u64) -> Result { + match *self { + DType::Record { .. } => Err(DTypeError(ErrorKind::ExpectedArray { got: "a record" })), + DType::Plain { ref ty, ref shape } => { + let ty = ty.clone(); + let mut shape = shape.to_vec(); + + let len = match shape.is_empty() { + true => return Err(DTypeError(ErrorKind::ExpectedArray { got: "a scalar" })), + false => shape.remove(0), + }; + + if len != expected_len { + return Err(DTypeError(ErrorKind::WrongArrayLen { + actual: len, + expected: expected_len, + })); + } + + Ok(DType::Plain { ty, shape }) + }, + } + } +} + +macro_rules! gen_array_serializable { + ($([$n:tt in mod $mod_name:ident])+) => { $( + mod $mod_name { + use super::*; + + pub struct ArrayReader{ inner: I } + pub struct ArrayWriter{ inner: I } + + impl TypeRead for ArrayReader + where I::Value: Copy + Default, + { + type Value = [I::Value; $n]; + + #[inline] + fn read_one<'a>(&self, bytes: &'a [u8]) -> (Self::Value, &'a [u8]) { + let mut value = [I::Value::default(); $n]; + + let mut remainder = bytes; + for place in &mut value { + let (item, new_remainder) = self.inner.read_one(remainder); + *place = item; + remainder = new_remainder; + } + + (value, remainder) + } + } + + impl TypeWrite for ArrayWriter + where I::Value: Sized, + { + type Value = [I::Value; $n]; + + #[inline] + fn write_one(&self, mut writer: W, value: &Self::Value) -> io::Result<()> + where Self: Sized, + { + for item in value { + self.inner.write_one(&mut writer, item)?; + } + Ok(()) + } + } + + impl AutoSerialize for [T; $n] { + #[inline] + fn default_dtype() -> DType { + use DType::*; + + match T::default_dtype() { + Plain { ty, mut shape } => DType::Plain { + ty, + shape: { shape.insert(0, $n); shape }, + }, + Record(_) => unimplemented!("arrays of nested records") + } + } + } + + impl Deserialize for [T; $n] { + type Reader = ArrayReader<::Reader>; + + #[inline] + fn reader(dtype: &DType) -> Result { + let inner_dtype = dtype.array_inner_dtype($n)?; + let inner = ::reader(&inner_dtype)?; + Ok(ArrayReader { inner }) + } + } + + impl Serialize for [T; $n] { + type Writer = ArrayWriter<::Writer>; + + #[inline] + fn writer(dtype: &DType) -> Result { + let inner = ::writer(&dtype.array_inner_dtype($n)?)?; + Ok(ArrayWriter { inner }) + } + } + } + )+ } +} + +gen_array_serializable!{ + /* no size 0 */ [ 1 in mod arr1] [ 2 in mod arr2] [ 3 in mod arr3] + [ 4 in mod arr4] [ 5 in mod arr5] [ 6 in mod arr6] [ 7 in mod arr7] + [ 8 in mod arr8] [ 9 in mod arr9] [10 in mod arr10] [11 in mod arr11] + [12 in mod arr12] [13 in mod arr13] [14 in mod arr14] [15 in mod arr15] + [16 in mod arr16] +} + #[cfg(test)] #[deny(unused)] mod tests { use super::*; + // NOTE: Tests for arrays are in tests/serialize_array.rs because they require derives + fn reader_output(dtype: &DType, bytes: &[u8]) -> T { T::reader(dtype).unwrap_or_else(|e| panic!("{}", e)).read_one(bytes).0 } diff --git a/tests/serialize_array.rs b/tests/serialize_array.rs new file mode 100644 index 0000000..7db0f60 --- /dev/null +++ b/tests/serialize_array.rs @@ -0,0 +1,132 @@ +extern crate npy; + +use npy::{Deserialize, Serialize, AutoSerialize, DType, TypeStr, Field}; +use npy::{TypeRead, TypeWrite}; + +// These tests ideally would be in npy::serialize::tests, but they require "derive" +// because arrays can only exist as record fields. + +fn reader_output(dtype: &DType, bytes: &[u8]) -> T { + T::reader(dtype).unwrap_or_else(|e| panic!("{}", e)).read_one(bytes).0 +} + +fn reader_expect_err(dtype: &DType) { + T::reader(dtype).err().expect("reader_expect_err failed!"); +} + +fn writer_output(dtype: &DType, value: &T) -> Vec { + let mut vec = vec![]; + T::writer(dtype).unwrap_or_else(|e| panic!("{}", e)) + .write_one(&mut vec, value).unwrap(); + vec +} + +fn writer_expect_err(dtype: &DType) { + T::writer(dtype).err().expect("writer_expect_err failed!"); +} + +fn writer_expect_write_err(dtype: &DType, value: &T) { + let mut vec = vec![]; + T::writer(dtype).unwrap_or_else(|e| panic!("{}", e)) + .write_one(&mut vec, value) + .err().expect("writer_expect_write_err failed!"); +} + +#[derive(npy::Serialize, npy::Deserialize, npy::AutoSerialize)] +#[derive(Debug, PartialEq)] +struct Array3 { + field: [i32; 3], +} + +#[derive(npy::Serialize, npy::Deserialize, npy::AutoSerialize)] +#[derive(Debug, PartialEq)] +struct Array23 { + field: [[i32; 3]; 2], +} + +const ARRAY3_DESCR_LE: &str = "[('field', '(&dtype, &bytes), value); + assert_eq!(writer_output::(&dtype, &value), bytes); + reader_expect_err::(&dtype); + writer_expect_err::(&dtype); +} + +#[test] +fn read_write_nested() { + let dtype = DType::parse(ARRAY23_DESCR_LE).unwrap(); + let value = Array23 { field: [[1, 3, 5], [7, 9, 11]] }; + let mut bytes = vec![]; + for n in vec![1, 3, 5, 7, 9, 11] { + bytes.extend_from_slice(&i32::to_le_bytes(n)); + } + + assert_eq!(reader_output::(&dtype, &bytes), value); + assert_eq!(writer_output::(&dtype, &value), bytes); + reader_expect_err::(&dtype); + writer_expect_err::(&dtype); +} + +#[test] +fn incompatible() { + // wrong size + let dtype = DType::parse(ARRAY2_DESCR_LE).unwrap(); + writer_expect_err::(&dtype); + reader_expect_err::(&dtype); + + // scalar instead of array + let dtype = DType::parse(ARRAY_SCALAR_DESCR_LE).unwrap(); + writer_expect_err::(&dtype); + reader_expect_err::(&dtype); + + // record instead of array + let dtype = DType::parse(ARRAY_RECORD_DESCR_LE).unwrap(); + writer_expect_err::(&dtype); + reader_expect_err::(&dtype); +} + +#[test] +fn default_dtype() { + let int_ty: TypeStr = { + if 1 == i32::from_be(1) { + ">i4".parse().unwrap() + } else { + " Date: Thu, 6 Jun 2019 14:36:15 -0400 Subject: [PATCH 11/20] The breaking changes: Update NpyData and OutFile This is the single most important commit in the PR. All breaking changes to existing public APIs are contained in here. Serialize is completely removed. Examples and tests are not yet updated, so they are broken in this commit. --- npy-derive/src/lib.rs | 86 +--------------- src/header.rs | 12 +++ src/lib.rs | 2 - src/npy_data.rs | 54 +++++----- src/out_file.rs | 41 +++++--- src/serializable.rs | 234 ------------------------------------------ 6 files changed, 66 insertions(+), 363 deletions(-) delete mode 100644 src/serializable.rs diff --git a/npy-derive/src/lib.rs b/npy-derive/src/lib.rs index dd6ffa2..7a50777 100644 --- a/npy-derive/src/lib.rs +++ b/npy-derive/src/lib.rs @@ -17,91 +17,7 @@ extern crate quote; use proc_macro::TokenStream; use proc_macro2::Span; -use quote::{Tokens, ToTokens}; - -/// Macros 1.1-based custom derive function -#[proc_macro_derive(Serializable)] -pub fn npy_data(input: TokenStream) -> TokenStream { - // Construct a string representation of the type definition - // let s = input.to_string(); - - // Parse the string representation - let ast = syn::parse(input).unwrap(); - - // Build the impl - let expanded = impl_npy_data(&ast); - - // Return the generated impl - expanded.into() -} - -fn impl_npy_data(ast: &syn::DeriveInput) -> quote::Tokens { - let name = &ast.ident; - let fields = match ast.data { - syn::Data::Struct(ref data) => &data.fields, - _ => panic!("#[derive(Serializable)] can only be used with structs"), - }; - // Helper is provided for handling complex generic types correctly and effortlessly - let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl(); - - let idents = fields.iter().map(|f| { - let mut t = Tokens::new(); - f.ident.clone().expect("Tuple structs not supported").to_tokens(&mut t); - t - }).collect::>(); - let types = fields.iter().map(|f| { - let mut t = Tokens::new(); - f.ty.to_tokens(&mut t); - t - }).collect::>(); - - let idents_c = idents.clone(); - let idents_str = idents.clone().into_iter().map(|t| t.to_string()).collect::>(); - let idents_str_c1 = idents_str.clone(); - let types_c1 = types.clone(); - let types_c2 = types.clone(); - let types_c3 = types.clone(); - - let nats_0 = 0usize..; - let nats_1 = 0usize..; - let n_fields = types.len(); - - quote! { - impl #impl_generics ::npy::Serializable for #name #ty_generics #where_clause { - fn dtype() -> ::npy::DType { - ::npy::DType::Record(vec![#( - ::npy::Field { - name: #idents_str_c1.to_string(), - dtype: <#types_c1 as ::npy::Serializable>::dtype() - } - ),*]) - } - - fn n_bytes() -> usize { - #( <#types_c2 as ::npy::Serializable>::n_bytes() )+* - } - - #[allow(unused_assignments)] - fn read(buf: &[u8]) -> Self { - let mut offset = 0; - let mut offsets = [0; #n_fields + 1]; - #( - offset += <#types_c3 as ::npy::Serializable>::n_bytes(); - offsets[#nats_0 + 1] = offset; - )* - - #name { #( - #idents: ::npy::Serializable::read(&buf[offsets[#nats_1]..]) - ),* } - } - - fn write(&self, writer: &mut W) -> ::std::io::Result<()> { - #( ::npy::Serializable::write(&self.#idents_c, writer)?; )* - Ok(()) - } - } - } -} +use quote::Tokens; /// Macros 1.1-based custom derive function #[proc_macro_derive(Serialize)] diff --git a/src/header.rs b/src/header.rs index 0e92573..f239d8a 100644 --- a/src/header.rs +++ b/src/header.rs @@ -103,6 +103,18 @@ impl DType { _ => None, } } + + /// Get the number of bytes that each item of this type occupies. + pub fn num_bytes(&self) -> usize { + match self { + DType::Plain { ty, shape } => { + ty.num_bytes() * shape.iter().product::() as usize + }, + DType::Record(fields) => { + fields.iter().map(|field| field.dtype.num_bytes()).sum() + }, + } + } } fn convert_list_to_record_fields(values: &[Value]) -> Result> { diff --git a/src/lib.rs b/src/lib.rs index 244d8f9..ae686d3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -147,13 +147,11 @@ extern crate byteorder; extern crate nom; mod header; -mod serializable; mod npy_data; mod out_file; mod type_str; mod serialize; -pub use serializable::Serializable; pub use header::{DType, Field}; pub use npy_data::NpyData; pub use out_file::{to_file, OutFile}; diff --git a/src/npy_data.rs b/src/npy_data.rs index 59a4404..183b2bc 100644 --- a/src/npy_data.rs +++ b/src/npy_data.rs @@ -1,11 +1,8 @@ - use nom::*; use std::io::{Result, ErrorKind, Error}; -use std::marker::PhantomData; use header::{Value, DType, parse_header}; -use serializable::Serializable; - +use serialize::{Deserialize, TypeRead}; /// The data structure representing a deserialized `npy` file. /// @@ -13,17 +10,29 @@ use serializable::Serializable; /// as a byte array, and deserialized only on-demand to minimize unnecessary allocations. /// The whole contents of the file can be deserialized by the [`to_vec`](#method.to_vec) /// member function. -pub struct NpyData<'a, T> { +pub struct NpyData<'a, T: Deserialize> { data: &'a [u8], + dtype: DType, n_records: usize, - _t: PhantomData, + item_size: usize, + reader: ::Reader, } -impl<'a, T: Serializable> NpyData<'a, T> { +impl<'a, T: Deserialize> NpyData<'a, T> { /// Deserialize a NPY file represented as bytes pub fn from_bytes(bytes: &'a [u8]) -> ::std::io::Result> { - let (data_slice, ns) = Self::get_data_slice(bytes)?; - Ok(NpyData { data: data_slice, n_records: ns as usize, _t: PhantomData }) + let (dtype, data, ns) = Self::get_data_slice(bytes)?; + let reader = match T::reader(&dtype) { + Ok(reader) => reader, + Err(e) => return Err(Error::new(ErrorKind::InvalidData, e.to_string())), + }; + let item_size = dtype.num_bytes(); + Ok(NpyData { data, dtype, n_records: ns as usize, item_size, reader }) + } + + /// Get the dtype as written in the file. + pub fn dtype(&self) -> DType { + self.dtype.clone() } /// Gets a single data-record with the specified index. Returns None, if the index is @@ -46,9 +55,9 @@ impl<'a, T: Serializable> NpyData<'a, T> { self.n_records == 0 } - /// Gets a single data-record wit the specified index. Panics, if the index is out of bounds. + /// Gets a single data-record with the specified index. Panics if the index is out of bounds. pub fn get_unchecked(&self, i: usize) -> T { - T::read(&self.data[i * T::n_bytes()..]) + self.reader.read_one(&self.data[i * self.item_size..]).0 } /// Construct a vector with the deserialized contents of the whole file @@ -60,7 +69,7 @@ impl<'a, T: Serializable> NpyData<'a, T> { v } - fn get_data_slice(bytes: &[u8]) -> Result<(&[u8], i64)> { + fn get_data_slice(bytes: &[u8]) -> Result<(DType, &[u8], i64)> { let (data, header) = match parse_header(bytes) { IResult::Done(data, header) => { Ok((data, header)) @@ -95,35 +104,28 @@ impl<'a, T: Serializable> NpyData<'a, T> { "\'descr\' field is not present or doesn't contain a list."))?; if let Ok(dtype) = DType::from_descr(descr.clone()) { - let expected_dtype = T::dtype(); - if dtype != expected_dtype { - return Err(Error::new(ErrorKind::InvalidData, - format!("Types don't match! found: {:?}, expected: {:?}", dtype, expected_dtype) - )); - } + Ok((dtype, data, ns)) } else { - return Err(Error::new(ErrorKind::InvalidData, format!("fail?!?"))); + Err(Error::new(ErrorKind::InvalidData, format!("fail?!?"))) } - - Ok((data, ns)) } } /// A result of NPY file deserialization. /// /// It is an iterator to offer a lazy interface in case the data don't fit into memory. -pub struct IntoIter<'a, T: 'a> { +pub struct IntoIter<'a, T: 'a + Deserialize> { data: NpyData<'a, T>, i: usize, } -impl<'a, T> IntoIter<'a, T> { +impl<'a, T> IntoIter<'a, T> where T: Deserialize { fn new(data: NpyData<'a, T>) -> Self { IntoIter { data, i: 0 } } } -impl<'a, T: 'a + Serializable> IntoIterator for NpyData<'a, T> { +impl<'a, T: 'a> IntoIterator for NpyData<'a, T> where T: Deserialize { type Item = T; type IntoIter = IntoIter<'a, T>; @@ -132,7 +134,7 @@ impl<'a, T: 'a + Serializable> IntoIterator for NpyData<'a, T> { } } -impl<'a, T> Iterator for IntoIter<'a, T> where T: Serializable { +impl<'a, T> Iterator for IntoIter<'a, T> where T: Deserialize { type Item = T; fn next(&mut self) -> Option { @@ -145,4 +147,4 @@ impl<'a, T> Iterator for IntoIter<'a, T> where T: Serializable { } } -impl<'a, T> ExactSizeIterator for IntoIter<'a, T> where T: Serializable {} +impl<'a, T> ExactSizeIterator for IntoIter<'a, T> where T: Deserialize {} diff --git a/src/out_file.rs b/src/out_file.rs index f11fc21..6e69ec2 100644 --- a/src/out_file.rs +++ b/src/out_file.rs @@ -1,30 +1,34 @@ - use std::io::{self,Write,BufWriter,Seek,SeekFrom}; use std::fs::File; use std::path::Path; -use std::marker::PhantomData; use byteorder::{WriteBytesExt, LittleEndian}; -use serializable::Serializable; +use serialize::{AutoSerialize, Serialize, TypeWrite}; use header::DType; const FILLER: &'static [u8] = &[42; 19]; /// Serialize into a file one row at a time. To serialize an iterator, use the /// [`to_file`](fn.to_file.html) function. -pub struct OutFile { +pub struct OutFile { shape_pos: usize, len: usize, fw: BufWriter, - _t: PhantomData + writer: ::Writer, } -impl OutFile { - /// Open a file +impl OutFile { + /// Create a file, using the default format for the given type. pub fn open>(path: P) -> io::Result { - let dtype = Row::dtype(); - if let &DType::Plain { ref shape, .. } = &dtype { + Self::open_with_dtype(&Row::default_dtype(), path) + } +} + +impl OutFile { + /// Create a file, using the provided dtype. + pub fn open_with_dtype>(dtype: &DType, path: P) -> io::Result { + if let &DType::Plain { ref shape, .. } = dtype { assert!(shape.len() == 0, "plain non-scalar dtypes not supported"); } let mut fw = BufWriter::new(File::create(path)?); @@ -32,7 +36,12 @@ impl OutFile { fw.write_all(b"NUMPY")?; fw.write_all(&[0x01u8, 0x00])?; - let (header, shape_pos) = create_header(&dtype); + let (header, shape_pos) = create_header(dtype); + + let writer = match Row::writer(dtype) { + Ok(writer) => writer, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), + }; let mut padding: Vec = vec![]; padding.extend(&::std::iter::repeat(b' ').take(15 - ((header.len() + 10) % 16)).collect::>()); @@ -44,21 +53,21 @@ impl OutFile { fw.write_u16::(len as u16)?; fw.write_all(&header)?; - // Padding to 8 bytes + // Padding to 16 bytes fw.write_all(&padding)?; Ok(OutFile { shape_pos: shape_pos, len: 0, fw: fw, - _t: PhantomData, + writer: writer, }) } /// Append a single row to the file pub fn push(&mut self, row: &Row) -> io::Result<()> { self.len += 1; - row.write(&mut self.fw) + self.writer.write_one(&mut self.fw, row) } fn close_(&mut self) -> io::Result<()> { @@ -90,7 +99,7 @@ fn create_header(dtype: &DType) -> (Vec, usize) { (header, shape_pos) } -impl Drop for OutFile { +impl Drop for OutFile { fn drop(&mut self) { let _ = self.close_(); // Ignore the errors } @@ -101,9 +110,9 @@ impl Drop for OutFile { /// Serialize an iterator over a struct to a NPY file /// /// A single-statement alternative to saving row by row using the [`OutFile`](struct.OutFile.html). -pub fn to_file<'a, S, T, P>(filename: P, data: T) -> ::std::io::Result<()> where +pub fn to_file(filename: P, data: T) -> ::std::io::Result<()> where P: AsRef, - S: Serializable + 'a, + S: AutoSerialize, T: IntoIterator { let mut of = OutFile::open(filename)?; diff --git a/src/serializable.rs b/src/serializable.rs deleted file mode 100644 index 5f3083b..0000000 --- a/src/serializable.rs +++ /dev/null @@ -1,234 +0,0 @@ - -use std::io::{Write,Result}; -use byteorder::{WriteBytesExt, LittleEndian}; -use header::DType; -use byteorder::ByteOrder; - -/// This trait contains information on how to serialize and deserialize a type. -/// -/// An example illustrating a `Serializable` implementation for a fixed-size vector is in -/// [the roundtrip test](https://github.com/potocpav/npy-rs/tree/master/tests/roundtrip.rs). -/// It is strongly advised to annotate the `Serializable` functions as `#[inline]` for good -/// performance. -pub trait Serializable : Sized { - /// Convert a type to a structure representing a Numpy type - fn dtype() -> DType; - - /// Get the number of bytes of the binary repr - fn n_bytes() -> usize; - - /// Deserialize a single data field, advancing the cursor in the process. - fn read(c: &[u8]) -> Self; - - /// Serialize a single data field into a writer. - fn write(&self, writer: &mut W) -> Result<()>; -} - -impl Serializable for i8 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 1 } - #[inline] - fn read(buf: &[u8]) -> Self { - unsafe { ::std::mem::transmute(buf[0]) } // TODO: a better way - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_i8(*self) - } -} - -impl Serializable for i16 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 2 } - #[inline] - fn read(buf: &[u8]) -> Self { - LittleEndian::read_i16(buf) - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_i16::(*self) - } -} - -impl Serializable for i32 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 4 } - #[inline] - fn read(buf: &[u8]) -> Self { - LittleEndian::read_i32(buf) - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_i32::(*self) - } -} - -impl Serializable for i64 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 8 } - #[inline] - fn read(buf: &[u8]) -> Self { - LittleEndian::read_i64(buf) - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_i64::(*self) - } -} - -impl Serializable for u8 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 1 } - #[inline] - fn read(buf: &[u8]) -> Self { - buf[0] - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_u8(*self) - } -} - -impl Serializable for u16 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 2 } - #[inline] - fn read(buf: &[u8]) -> Self { - LittleEndian::read_u16(buf) - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_u16::(*self) - } -} - -impl Serializable for u32 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 4 } - #[inline] - fn read(buf: &[u8]) -> Self { - LittleEndian::read_u32(buf) - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_u32::(*self) - } -} - -impl Serializable for u64 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 8 } - #[inline] - fn read(buf: &[u8]) -> Self { - LittleEndian::read_u64(buf) - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_u64::(*self) - } -} - -impl Serializable for f32 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 4 } - #[inline] - fn read(buf: &[u8]) -> Self { - LittleEndian::read_f32(buf) - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_f32::(*self) - } -} - -impl Serializable for f64 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 8 } - #[inline] - fn read(buf: &[u8]) -> Self { - LittleEndian::read_f64(buf) - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_f64::(*self) - } -} - -macro_rules! gen_array_serializable { - ($($n:tt),+) => { $( - impl Serializable for [T; $n] { - #[inline] - fn dtype() -> DType { - use DType::*; - match T::dtype() { - Plain { ref ty, ref shape } => DType::Plain { - ty: ty.clone(), - shape: shape.clone().into_iter().chain(Some($n)).collect() - }, - Record(_) => unimplemented!("arrays of nested records") - } - } - #[inline] - fn n_bytes() -> usize { T::n_bytes() * $n } - #[inline] - fn read(buf: &[u8]) -> Self { - let mut a = [T::default(); $n]; - let mut off = 0; - for x in &mut a { - *x = T::read(&buf[off..]); - off += T::n_bytes(); - } - a - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - for item in self { - item.write(writer)?; - } - Ok(()) - } - } - )+ } -} - -gen_array_serializable!(1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16); From f0c499b4b28a5ab274f1a626f683ce37db35bffb Mon Sep 17 00:00:00 2001 From: Michael Lamparski Date: Thu, 6 Jun 2019 14:50:06 -0400 Subject: [PATCH 12/20] update benches, examples, and tests to use new traits --- benches/bench.rs | 104 ++++++++++---- examples/large.rs | 2 +- examples/roundtrip.rs | 2 +- examples/simple.rs | 3 +- src/lib.rs | 13 +- tests/roundtrip.rs | 313 +++++++++++++++++++++++++++++++++++++++--- 6 files changed, 378 insertions(+), 59 deletions(-) diff --git a/benches/bench.rs b/benches/bench.rs index 35c40e2..9a46809 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -1,44 +1,94 @@ #![feature(test)] -#[macro_use] -extern crate npy_derive; extern crate npy; extern crate test; -use npy::Serializable; +use npy::{Serialize, Deserialize, AutoSerialize, TypeWrite, TypeRead}; use test::Bencher; use test::black_box as bb; -#[derive(Serializable, Debug, PartialEq)] -struct Array { - a: i32, - b: f32, +const NITER: usize = 100_000; + +macro_rules! gen_benches { + ($T:ty, $new:expr) => { + #[inline(never)] + fn test_data() -> Vec { + let mut raw = Vec::new(); + let writer = <$T>::writer(&<$T>::default_dtype()).unwrap(); + for i in 0usize..NITER { + writer.write_one(&mut raw, &$new(i)).unwrap(); + } + raw + } + + #[bench] + fn read(b: &mut Bencher) { + let raw = test_data(); + b.iter(|| { + let dtype = <$T>::default_dtype(); + let reader = <$T>::reader(&dtype).unwrap(); + + let mut remainder = &raw[..]; + for _ in 0usize..NITER { + let (value, new_remainder) = reader.read_one(remainder); + bb(value); + remainder = new_remainder; + } + assert_eq!(remainder.len(), 0); + }); + } + + #[bench] + fn write(b: &mut Bencher) { + b.iter(|| { + bb(test_data()) + }); + } + }; } -const NITER: usize = 100_000; +#[cfg(feature = "derive")] +mod simple { + use super::*; + + #[derive(npy::Serialize, npy::Deserialize, npy::AutoSerialize)] + #[derive(Debug, PartialEq)] + struct Simple { + a: i32, + b: f32, + } + + gen_benches!(Simple, |i| Simple { a: i as i32, b: i as f32 }); +} -fn test_data() -> Vec { - let mut raw = Vec::new(); - for i in 0..NITER { - let arr = Array { a: i as i32, b: i as f32 }; - arr.write(&mut raw).unwrap(); +#[cfg(feature = "derive")] +mod one_field { + use super::*; + + #[derive(npy::Serialize, npy::Deserialize, npy::AutoSerialize)] + #[derive(Debug, PartialEq)] + struct OneField { + a: i32, } - raw + + gen_benches!(OneField, |i| OneField { a: i as i32 }); } -#[bench] -fn read(b: &mut Bencher) { - let raw = test_data(); - b.iter(|| { - for i in 0..NITER { - bb(Array::read(&raw[i*8..])); - } - }); +#[cfg(feature = "derive")] +mod array { + use super::*; + + #[derive(npy::Serialize, npy::Deserialize, npy::AutoSerialize)] + #[derive(Debug, PartialEq)] + struct Array { + a: [f32; 8], + } + + gen_benches!(Array, |i| Array { a: [i as f32; 8] }); } -#[bench] -fn write(b: &mut Bencher) { - b.iter(|| { - bb(test_data()) - }); +mod plain_f32 { + use super::*; + + gen_benches!(f32, |i| i as f32); } diff --git a/examples/large.rs b/examples/large.rs index 08eb310..3652e49 100644 --- a/examples/large.rs +++ b/examples/large.rs @@ -5,7 +5,7 @@ extern crate npy; use std::fs::File; use memmap::MmapOptions; -#[derive(npy::Serializable, Debug, Default)] +#[derive(npy::Serialize, npy::Deserialize, Debug, Default)] struct Array { a: i32, b: f32, diff --git a/examples/roundtrip.rs b/examples/roundtrip.rs index c17685a..b188611 100644 --- a/examples/roundtrip.rs +++ b/examples/roundtrip.rs @@ -4,7 +4,7 @@ extern crate npy; use std::io::Read; -#[derive(Serializable, Debug, PartialEq, Clone)] +#[derive(Serialize, Deserialize, AutoSerialize, Debug, PartialEq, Clone)] struct Array { a: i32, b: f32, diff --git a/examples/simple.rs b/examples/simple.rs index 86af8d9..c8b76a6 100644 --- a/examples/simple.rs +++ b/examples/simple.rs @@ -1,4 +1,5 @@ +extern crate npy_derive; extern crate npy; use std::io::Read; @@ -10,7 +11,7 @@ use npy::NpyData; // a = np.array([(1,2.5,4), (2,3.1,5)], dtype=[('a', 'i4'),('b', 'f4'),('c', 'i8')]) // np.save('examples/simple.npy', a) -#[derive(npy::Serializable, Debug)] +#[derive(npy::Deserialize, Debug)] struct Array { a: i32, b: f32, diff --git a/src/lib.rs b/src/lib.rs index ae686d3..d3cee86 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,12 +11,12 @@ It stores the type, shape and endianness information in a header, which is followed by a flat binary data field. This crate offers a simple, mostly type-safe way to read and write *.npy files. Files are handled using iterators, so they don't need to fit in memory. -One-dimensional arrays of types that implement the [`Serializable`](trait.Serializable.html) trait -are supported. These are: +One-dimensional arrays of types that implement the [`Serialize`], [`Deserialize`], +and/or [`AutoSerialize`] traits are supported. These are: * primitive types: `i8`, `u8`, `i16`, `u16`, `i32`, `u32`, `f32`, `f64`. These map to the `numpy` types of `int8`, `uint8`, `int16`, etc. - * `struct`s annotated as `#[derive(Serializable)]`. These map to `numpy`'s + * `struct`s annotated as e.g. `#[derive(npy::Serialize)]`. These map to `numpy`'s [Structured arrays](https://docs.scipy.org/doc/numpy/user/basics.rec.html). They can contain the following field types: * primitive types, @@ -25,13 +25,12 @@ are supported. These are: * `struct`s with manual trait implementations. An example of this can be found in the [roundtrip test](https://github.com/potocpav/npy-rs/tree/master/tests/roundtrip.rs). -To successfully import an array from NPY using the `#[derive(Serializable)]` mechanism, the target -struct must contain: +To successfully import an array from NPY using the `#[derive(npy::Serialize)]` mechanism, +you must enable the `"derive"` feature, and the target struct must contain: * corresponding number of fields in the same order, * corresponding names of fields, * compatible field types. -* only little endian fields # Examples @@ -109,7 +108,7 @@ extern crate npy; use std::io::Read; use npy::NpyData; -#[derive(npy::Serializable, Debug)] +#[derive(npy::Deserialize, Debug)] struct Array { a: i32, b: f32, diff --git a/tests/roundtrip.rs b/tests/roundtrip.rs index b7f6824..48eb769 100644 --- a/tests/roundtrip.rs +++ b/tests/roundtrip.rs @@ -1,20 +1,20 @@ -#[macro_use] -extern crate npy_derive; extern crate npy; extern crate byteorder; use byteorder::ByteOrder; use std::io::{Read, Write}; use byteorder::{WriteBytesExt, LittleEndian}; -use npy::{DType, Serializable}; +use npy::{DType, Field, OutFile, Serialize, Deserialize, AutoSerialize}; -#[derive(Serializable, Debug, PartialEq, Clone)] +#[derive(Serialize, Deserialize, AutoSerialize)] +#[derive(Debug, PartialEq, Clone)] struct Nested { v1: f32, v2: f32, } -#[derive(Serializable, Debug, PartialEq, Clone)] +#[derive(Serialize, Deserialize, AutoSerialize)] +#[derive(Debug, PartialEq, Clone)] struct Array { v_i8: i8, v_i16: i16, @@ -35,35 +35,66 @@ struct Array { #[derive(Debug, PartialEq, Clone)] struct Vector5(Vec); -impl Serializable for Vector5 { +impl AutoSerialize for Vector5 { #[inline] - fn dtype() -> DType { + fn default_dtype() -> DType { DType::Plain { ty: " usize { 5 * 4 } +impl Serialize for Vector5 { + type Writer = Vector5Writer; - #[inline] - fn read(buf: &[u8]) -> Self { - let mut ret = Vector5(vec![]); - let mut off = 0; - for _ in 0..5 { - ret.0.push(LittleEndian::read_i32(&buf[off..])); - off += i32::n_bytes(); + fn writer(dtype: &DType) -> Result { + if dtype == &Self::default_dtype() { + Ok(Vector5Writer) + } else { + Err(npy::DTypeError::custom("Vector5 only supports ' Result { + if dtype == &Self::default_dtype() { + Ok(Vector5Reader) + } else { + Err(npy::DTypeError::custom("Vector5 only supports '(&self, writer: &mut W) -> std::io::Result<()> { + fn write_one(&self, mut writer: W, value: &Self::Value) -> std::io::Result<()> { for i in 0..5 { - writer.write_i32::(self.0[i])? + writer.write_i32::(value.0[i])? } Ok(()) } } +impl npy::TypeRead for Vector5Reader { + type Value = Vector5; + + #[inline] + fn read_one<'a>(&self, mut remainder: &'a [u8]) -> (Self::Value, &'a [u8]) { + let mut ret = Vector5(vec![]); + for _ in 0..5 { + ret.0.push(LittleEndian::read_i32(remainder)); + remainder = &remainder[4..]; + } + (ret, remainder) + } +} + #[test] fn roundtrip() { let n = 100i64; @@ -101,16 +132,254 @@ fn roundtrip() { assert_eq!(arrays, arrays2); } +fn plain_field(name: &str, dtype: &str) -> Field { + Field { + name: name.to_string(), + dtype: DType::new_scalar(dtype.parse().unwrap()), + } +} + #[test] -fn roundtrip_with_simple_dtype() { +fn roundtrip_with_plain_dtype() { let array_written = vec![2., 3., 4., 5.]; - npy::to_file("tests/roundtrip_simple.npy", array_written.clone()).unwrap(); + npy::to_file("tests/roundtrip_plain.npy", array_written.clone()).unwrap(); let mut buffer = vec![]; - std::fs::File::open("tests/roundtrip_simple.npy").unwrap() + std::fs::File::open("tests/roundtrip_plain.npy").unwrap() .read_to_end(&mut buffer).unwrap(); let array_read = npy::NpyData::from_bytes(&buffer).unwrap().to_vec(); assert_eq!(array_written, array_read); } + +#[test] +fn roundtrip_byteorder() { + let path = "tests/roundtrip_byteorder.npy"; + + #[derive(npy::Serialize, npy::Deserialize)] + #[derive(Debug, PartialEq, Clone)] + struct Row { + be_u32: u32, + le_u32: u32, + be_f32: f32, + le_f32: f32, + be_i8: i8, + le_i8: i8, + na_i8: i8, + } + + let dtype = DType::Record(vec![ + plain_field("be_u32", ">u4"), + plain_field("le_u32", "f4"), + plain_field("le_f32", "i1"), + plain_field("le_i8", "::from_bytes(&buffer).unwrap(); + assert_eq!(data.to_vec(), vec![row]); + assert_eq!(data.dtype(), dtype); +} + +#[test] +fn roundtrip_datetime() { + let path = "tests/roundtrip_datetime.npy"; + + // Similar to: + // + // ``` + // import numpy.datetime64 as dt + // import numpy as np + // + // arr = np.array([( + // dt('2011-01-01', 'ns'), + // dt('2011-01-02') - dt('2011-01-01'), + // dt('2011-01-02') - dt('2011-01-01'), + // )], dtype=[ + // ('datetime', 'm8[D]'), + // ]) + // ``` + #[derive(npy::Serialize, npy::Deserialize)] + #[derive(Debug, PartialEq, Clone)] + struct Row { + datetime: u64, + timedelta_le: i64, + timedelta_be: i64, + } + + let dtype = DType::Record(vec![ + plain_field("datetime", "m8[D]"), + ]); + + let row = Row { + datetime: 1_293_840_000_000_000_000, + timedelta_le: 1, + timedelta_be: 1, + }; + + let expected_data_bytes = { + let mut buf = vec![]; + buf.extend_from_slice(&i64::to_le_bytes(1_293_840_000_000_000_000)); + buf.extend_from_slice(&i64::to_le_bytes(1)); + buf.extend_from_slice(&i64::to_be_bytes(1)); + buf + }; + + let mut out_file = OutFile::open_with_dtype(&dtype, path).unwrap(); + out_file.push(&row).unwrap(); + out_file.close().unwrap(); + + let buffer = std::fs::read(path).unwrap(); + assert!(buffer.ends_with(&expected_data_bytes)); + + let data = npy::NpyData::::from_bytes(&buffer).unwrap(); + assert_eq!(data.to_vec(), vec![row]); + assert_eq!(data.dtype(), dtype); +} + +#[test] +fn roundtrip_bytes() { + let path = "tests/roundtrip_bytes.npy"; + + // Similar to: + // + // ``` + // import numpy as np + // + // arr = np.array([( + // b"\x00such\x00wow", + // b"\x00such\x00wow\x00\x00\x00", + // )], dtype=[ + // ('bytestr', '|S12'), + // ('raw', '|V12'), + // ]) + // ``` + #[derive(npy::Serialize, npy::Deserialize)] + #[derive(Debug, PartialEq, Clone)] + struct Row { + bytestr: Vec, + raw: Vec, + } + + let dtype = DType::Record(vec![ + plain_field("bytestr", "|S12"), + plain_field("raw", "|V12"), + ]); + + let row = Row { + // checks that: + // * bytestr can be shorter than the len + // * bytestr can contain non-trailing NULs + bytestr: b"\x00lol\x00lol".to_vec(), + // * raw can contain trailing NULs + raw: b"\x00lol\x00lol\x00\x00\x00\x00".to_vec(), + }; + + let expected_data_bytes = { + let mut buf = vec![]; + // check that bytestr is nul-padded + buf.extend_from_slice(b"\x00lol\x00lol\x00\x00\x00\x00"); + buf.extend_from_slice(b"\x00lol\x00lol\x00\x00\x00\x00"); + buf + }; + + let mut out_file = OutFile::open_with_dtype(&dtype, path).unwrap(); + out_file.push(&row).unwrap(); + out_file.close().unwrap(); + + let buffer = std::fs::read(path).unwrap(); + assert!(buffer.ends_with(&expected_data_bytes)); + + let data = npy::NpyData::::from_bytes(&buffer).unwrap(); + assert_eq!(data.to_vec(), vec![row]); + assert_eq!(data.dtype(), dtype); +} + +// check that all byte orders are identical for bytestrings +// (i.e. don't accidentally reverse the bytestrings) +#[test] +fn roundtrip_bytes_byteorder() { + let path = "tests/roundtrip_bytes_byteorder.npy"; + + #[derive(npy::Serialize, npy::Deserialize)] + #[derive(Debug, PartialEq, Clone)] + struct Row { + s_le: Vec, + s_be: Vec, + s_na: Vec, + v_le: Vec, + v_be: Vec, + v_na: Vec, + }; + + let dtype = DType::Record(vec![ + plain_field("s_le", "S4"), + plain_field("s_na", "|S4"), + plain_field("v_le", "V4"), + plain_field("v_na", "|V4"), + ]); + + let row = Row { + s_le: b"abcd".to_vec(), + s_be: b"abcd".to_vec(), + s_na: b"abcd".to_vec(), + v_le: b"abcd".to_vec(), + v_be: b"abcd".to_vec(), + v_na: b"abcd".to_vec(), + }; + + let expected_data_bytes = { + let mut buf = vec![]; + for _ in 0..6 { + buf.extend_from_slice(b"abcd"); + } + buf + }; + + let mut out_file = OutFile::open_with_dtype(&dtype, path).unwrap(); + out_file.push(&row).unwrap(); + out_file.close().unwrap(); + + let buffer = std::fs::read(path).unwrap(); + assert!(buffer.ends_with(&expected_data_bytes)); + + let data = npy::NpyData::::from_bytes(&buffer).unwrap(); + assert_eq!(data.to_vec(), vec![row]); + assert_eq!(data.dtype(), dtype); +} From d177998c2192ecebe48a13aa8edca98b86dfb5d4 Mon Sep 17 00:00:00 2001 From: Michael Lamparski Date: Thu, 6 Jun 2019 14:54:34 -0400 Subject: [PATCH 13/20] small bits of cleanup Fix a couple of things I missed while rewriting and reorganizing the commit history. --- src/header.rs | 11 +++++------ tests/serialize_array.rs | 7 ------- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/src/header.rs b/src/header.rs index f239d8a..dc881c8 100644 --- a/src/header.rs +++ b/src/header.rs @@ -9,9 +9,11 @@ use type_str::TypeStr; pub enum DType { /// A simple array with only a single field Plain { - /// Numpy type string. First character is `'>'` for big endian, `'<'` for little endian. + /// Numpy type string. First character is `'>'` for big endian, `'<'` for little endian, + /// or can be `'|'` if it doesn't matter. /// - /// Examples: `>i4`, `f8`. The number corresponds to the number of bytes. + /// Examples: `>i4`, `f8`, `|S7`. The number usually corresponds to the number of + /// bytes (with the single exception of unicode strings `|U3`). ty: TypeStr, /// Shape of a type. @@ -65,10 +67,7 @@ impl DType { pub fn from_descr(descr: Value) -> Result { use DType::*; match descr { - Value::String(string) => { - let ty = convert_string_to_type_str(&string)?; - Ok(Plain { ty, shape: vec![] }) - }, + Value::String(ref string) => Ok(Self::new_scalar(convert_string_to_type_str(string)?)), Value::List(ref list) => Ok(Record(convert_list_to_record_fields(list)?)), _ => invalid_data("must be string or list") } diff --git a/tests/serialize_array.rs b/tests/serialize_array.rs index 7db0f60..4f746b1 100644 --- a/tests/serialize_array.rs +++ b/tests/serialize_array.rs @@ -25,13 +25,6 @@ fn writer_expect_err(dtype: &DType) { T::writer(dtype).err().expect("writer_expect_err failed!"); } -fn writer_expect_write_err(dtype: &DType, value: &T) { - let mut vec = vec![]; - T::writer(dtype).unwrap_or_else(|e| panic!("{}", e)) - .write_one(&mut vec, value) - .err().expect("writer_expect_write_err failed!"); -} - #[derive(npy::Serialize, npy::Deserialize, npy::AutoSerialize)] #[derive(Debug, PartialEq)] struct Array3 { From 7344b12a783dd8fb5de38516bc1550ade849c6df Mon Sep 17 00:00:00 2001 From: Michael Lamparski Date: Thu, 6 Jun 2019 15:39:45 -0400 Subject: [PATCH 14/20] add a benchmark for to_vec I tried a variety of things to optimize this function: * Replacing usage of get_unchecked with reuse of the remainder returned by read_one, so that the stride can be statically known rather than having to be looked up. (this is what optimized the old read benchmark) * Putting an assertion up front to prove that the data vector is long enough. But whatever I do, performance won't budge. In the f32 benchmark, a very hot bounds check still occurs on every read to ensure that the length of the data is at least 4 bytes. So I'm adding the benchmark, but leaving the function itself alone. --- .gitignore | 1 + benches/bench.rs | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/.gitignore b/.gitignore index baa1df9..183ce61 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ target Cargo.lock tests/*.npy +benches/*.npy diff --git a/benches/bench.rs b/benches/bench.rs index 9a46809..001641a 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -38,6 +38,19 @@ macro_rules! gen_benches { }); } + #[bench] + fn read_to_vec(b: &mut Bencher) { + // FIXME: Write to a Cursor> once #16 is merged + let path = concat!("benches/bench_", stringify!($T), ".npy"); + + npy::to_file(path, (0usize..NITER).map($new)).unwrap(); + let bytes = std::fs::read(path).unwrap(); + + b.iter(|| { + bb(npy::NpyData::<$T>::from_bytes(&bytes).unwrap().to_vec()) + }); + } + #[bench] fn write(b: &mut Bencher) { b.iter(|| { From dda98d04dd25566f4c1452301e75bf48090b929b Mon Sep 17 00:00:00 2001 From: Michael Lamparski Date: Tue, 11 Jun 2019 22:49:19 -0400 Subject: [PATCH 15/20] fix npy-derive to not rely on NLL 2015 edition crates do not use NLL yet in the latest stable compiler, so our derive macro must be conservative. Apparently, in the latest nightly, this was changed; 2015 edition will at some point use NLL in the future. This is why I did not notice the problem at first! --- npy-derive/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/npy-derive/src/lib.rs b/npy-derive/src/lib.rs index 7a50777..4333c03 100644 --- a/npy-derive/src/lib.rs +++ b/npy-derive/src/lib.rs @@ -111,10 +111,10 @@ fn impl_npy_serialize(ast: &syn::DeriveInput) -> Tokens { #[allow(unused_mut)] fn write_one(&self, mut w: W, value: &Self::Value) -> io::Result<()> { - #( + #({ // braces for pre-NLL let method = <<#types as _npy::Serialize>::Writer as _npy::TypeWrite>::write_one; method(&self.writers.#idents, &mut w, &value.#idents_1)?; - )* + })* p::Ok(()) } } From ec6266482b84e4311ed3d3643e4370bbe46ad277 Mon Sep 17 00:00:00 2001 From: Michael Lamparski Date: Wed, 12 Jun 2019 11:45:16 -0400 Subject: [PATCH 16/20] Apply suggestions from code review Co-Authored-By: Pavel Potocek --- src/serialize.rs | 4 ++-- src/type_str.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/serialize.rs b/src/serialize.rs index b2d7d07..307cda9 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -290,7 +290,7 @@ macro_rules! impl_integer_serializable { #[inline] fn maybe_swap(swap: bool, x: $int) -> $int { match swap { - true => x.to_be().to_le(), + true => x.swap_bytes(), false => x, } } @@ -411,7 +411,7 @@ macro_rules! impl_float_serializable { #[inline] fn maybe_swap(swap: bool, x: $float) -> $float { match swap { - true => $float::from_bits(x.to_bits().to_be().to_le()), + true => $float::from_bits(x.to_bits().swap_bytes()), false => x, } } diff --git a/src/type_str.rs b/src/type_str.rs index 3168411..c448068 100644 --- a/src/type_str.rs +++ b/src/type_str.rs @@ -214,7 +214,7 @@ impl TypeKind { } } - /// Returns `true` if `|` endianness is illegal. + /// Returns `true` if unit specification is required. fn has_units(self) -> bool { match self { TypeKind::TimeDelta | From 5adcee0e4553d5d7bd70f9d0a4e5274be96f6239 Mon Sep 17 00:00:00 2001 From: Michael Lamparski Date: Wed, 12 Jun 2019 12:15:40 -0400 Subject: [PATCH 17/20] remove size1_cfg It was an artefact of an old design. Out of paranoia, I added some assertions to the Serialize/Deserialize impls to make sure the endianness is valid. These are redundant since it is checked in TypeStr::from_str, but most other such properties are at least implicitly checked by the `_` arm in impls of `reader` and `writer` and I wanted to be safe. --- src/serialize.rs | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/src/serialize.rs b/src/serialize.rs index 307cda9..9bc68ee 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -271,9 +271,7 @@ macro_rules! impl_integer_serializable { ( @generate meta: [ (main_ty: $Int:ident) (date_ty: $DateTime:ident) ] - current: [ $size:literal $int:ident - (size1: $size1_cfg:meta) $read_int:ident $write_int:ident - ] + current: [ $size:literal $int:ident $read_int:ident $write_int:ident ] ) => { mod $int { use super::*; @@ -326,6 +324,8 @@ macro_rules! impl_integer_serializable { // so we support those too. TypeStr { size: $size, endianness, type_kind: $Int, .. } | TypeStr { size: $size, endianness, type_kind: $DateTime, .. } => { + assert!($size == 1 || endianness != &Endianness::Irrelevant, "(BUG) invalid dtype constructed?"); + let swap_byteorder = endianness.requires_swap(Endianness::of_machine()); Ok($int::AnyEndianReader { swap_byteorder }) }, @@ -342,6 +342,8 @@ macro_rules! impl_integer_serializable { // Write a signed integer of the correct size TypeStr { size: $size, endianness, type_kind: $Int, .. } | TypeStr { size: $size, endianness, type_kind: $DateTime, .. } => { + assert!($size == 1 || endianness != &Endianness::Irrelevant, "(BUG) invalid dtype constructed?"); + let swap_byteorder = endianness.requires_swap(Endianness::of_machine()); Ok($int::AnyEndianWriter { swap_byteorder }) }, @@ -374,16 +376,15 @@ trait WriteSingleByteExt: WriteBytesExt { impl WriteSingleByteExt for W {} -// `all()` means "true", `any()` means "false". (these get put inside `cfg`) impl_integer_serializable! { @iterate meta: [ (main_ty: Int) (date_ty: TimeDelta) ] remaining: [ // numpy doesn't support i128 - [ 8 i64 (size1: any()) read_i64 write_i64 ] - [ 4 i32 (size1: any()) read_i32 write_i32 ] - [ 2 i16 (size1: any()) read_i16 write_i16 ] - [ 1 i8 (size1: all()) read_i8_ write_i8_ ] + [ 8 i64 read_i64 write_i64 ] + [ 4 i32 read_i32 write_i32 ] + [ 2 i16 read_i16 write_i16 ] + [ 1 i8 read_i8_ write_i8_ ] ] } @@ -392,10 +393,10 @@ impl_integer_serializable! { meta: [ (main_ty: Uint) (date_ty: DateTime) ] remaining: [ // numpy doesn't support i128 - [ 8 u64 (size1: any()) read_u64 write_u64 ] - [ 4 u32 (size1: any()) read_u32 write_u32 ] - [ 2 u16 (size1: any()) read_u16 write_u16 ] - [ 1 u8 (size1: all()) read_u8_ write_u8_ ] + [ 8 u64 read_u64 write_u64 ] + [ 4 u32 read_u32 write_u32 ] + [ 2 u16 read_u16 write_u16 ] + [ 1 u8 read_u8_ write_u8_ ] ] } @@ -449,6 +450,8 @@ macro_rules! impl_float_serializable { match $float::expect_scalar_dtype(dtype)? { // Read a float of the correct size TypeStr { size: $size, endianness, type_kind: Float, .. } => { + assert_ne!(endianness, &Endianness::Irrelevant, "(BUG) invalid dtype constructed?"); + let swap_byteorder = endianness.requires_swap(Endianness::of_machine()); Ok($float::AnyEndianReader { swap_byteorder }) }, @@ -464,6 +467,8 @@ macro_rules! impl_float_serializable { match $float::expect_scalar_dtype(dtype)? { // Write a float of the correct size TypeStr { size: $size, endianness, type_kind: Float, .. } => { + assert_ne!(endianness, &Endianness::Irrelevant, "(BUG) invalid dtype constructed?"); + let swap_byteorder = endianness.requires_swap(Endianness::of_machine()); Ok($float::AnyEndianWriter { swap_byteorder }) }, @@ -806,7 +811,7 @@ mod tests { const LE_ONE_32: &[u8] = &[1, 0, 0, 0]; #[test] - fn identity() { + fn native_int_types() { let be = DType::parse("'>i4'").unwrap(); let le = DType::parse("'(&le, &1), LE_ONE_64); } + #[test] + fn illegal_endianness() { + // There is currently no need to test that each type rejects '|' endianness in their + // (De)Serialize impls, because this is checked up-front during DType construction. + assert!(DType::parse("'|i4'").is_err()); + } + #[test] fn wrong_size_int() { let t_i32 = DType::parse("' Date: Wed, 12 Jun 2019 12:22:09 -0400 Subject: [PATCH 18/20] change NpyData::dtype to return a reference I'm not really sure why I had it return a clone in the first place... --- src/npy_data.rs | 4 ++-- tests/roundtrip.rs | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/npy_data.rs b/src/npy_data.rs index 183b2bc..f7736e4 100644 --- a/src/npy_data.rs +++ b/src/npy_data.rs @@ -31,8 +31,8 @@ impl<'a, T: Deserialize> NpyData<'a, T> { } /// Get the dtype as written in the file. - pub fn dtype(&self) -> DType { - self.dtype.clone() + pub fn dtype(&self) -> &DType { + &self.dtype } /// Gets a single data-record with the specified index. Returns None, if the index is diff --git a/tests/roundtrip.rs b/tests/roundtrip.rs index 48eb769..f4f3935 100644 --- a/tests/roundtrip.rs +++ b/tests/roundtrip.rs @@ -208,7 +208,7 @@ fn roundtrip_byteorder() { let data = npy::NpyData::::from_bytes(&buffer).unwrap(); assert_eq!(data.to_vec(), vec![row]); - assert_eq!(data.dtype(), dtype); + assert_eq!(data.dtype(), &dtype); } #[test] @@ -268,7 +268,7 @@ fn roundtrip_datetime() { let data = npy::NpyData::::from_bytes(&buffer).unwrap(); assert_eq!(data.to_vec(), vec![row]); - assert_eq!(data.dtype(), dtype); + assert_eq!(data.dtype(), &dtype); } #[test] @@ -326,7 +326,7 @@ fn roundtrip_bytes() { let data = npy::NpyData::::from_bytes(&buffer).unwrap(); assert_eq!(data.to_vec(), vec![row]); - assert_eq!(data.dtype(), dtype); + assert_eq!(data.dtype(), &dtype); } // check that all byte orders are identical for bytestrings @@ -381,5 +381,5 @@ fn roundtrip_bytes_byteorder() { let data = npy::NpyData::::from_bytes(&buffer).unwrap(); assert_eq!(data.to_vec(), vec![row]); - assert_eq!(data.dtype(), dtype); + assert_eq!(data.dtype(), &dtype); } From c4bf3ca27e234544dbb9bcdc738c49e3c80ddd02 Mon Sep 17 00:00:00 2001 From: Michael Lamparski Date: Fri, 12 Jul 2019 19:35:10 -0400 Subject: [PATCH 19/20] fixup some outdated comments --- src/serialize.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/serialize.rs b/src/serialize.rs index 9bc68ee..a97c8e1 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -244,7 +244,6 @@ fn invalid_data(message: &str) -> io::Result { Err(io::Error::new(io::ErrorKind::InvalidData, message.to_string())) } -// Takes info about each data size, from largest to smallest. macro_rules! impl_integer_serializable { ( @iterate meta: $meta:tt @@ -318,7 +317,7 @@ macro_rules! impl_integer_serializable { fn reader(dtype: &DType) -> Result { match $int::expect_scalar_dtype(dtype)? { - // Read an integer of the same size and signedness. + // Read an integer of the correct size and signedness. // // DateTime is an unsigned integer and TimeDelta is a signed integer, // so we support those too. @@ -339,7 +338,7 @@ macro_rules! impl_integer_serializable { fn writer(dtype: &DType) -> Result { match $int::expect_scalar_dtype(dtype)? { - // Write a signed integer of the correct size + // Write an integer of the correct size and signedness. TypeStr { size: $size, endianness, type_kind: $Int, .. } | TypeStr { size: $size, endianness, type_kind: $DateTime, .. } => { assert!($size == 1 || endianness != &Endianness::Irrelevant, "(BUG) invalid dtype constructed?"); From 646d05377672e7679658e137509ba27a16bb18a4 Mon Sep 17 00:00:00 2001 From: Michael Lamparski Date: Mon, 15 Jul 2019 10:45:03 -0400 Subject: [PATCH 20/20] remove datetime; simplify macro; fix comments I don't want to yet commit to a specific API for serializing DateTime/TimeDelta, so that we can keep open the option of widening conversions. --- src/serialize.rs | 73 ++++++++---------------------------------------- src/type_str.rs | 12 ++++---- 2 files changed, 19 insertions(+), 66 deletions(-) diff --git a/src/serialize.rs b/src/serialize.rs index a97c8e1..452b97c 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -245,33 +245,12 @@ fn invalid_data(message: &str) -> io::Result { } macro_rules! impl_integer_serializable { - ( @iterate - meta: $meta:tt - remaining: [] - ) => {}; - - ( @iterate - meta: $meta:tt - remaining: [$first:tt $($smaller:tt)*] - ) => { - impl_integer_serializable! { - @generate - meta: $meta - current: $first - } - - impl_integer_serializable! { - @iterate - meta: $meta - remaining: [ $($smaller)* ] - } - }; - ( - @generate - meta: [ (main_ty: $Int:ident) (date_ty: $DateTime:ident) ] - current: [ $size:literal $int:ident $read_int:ident $write_int:ident ] - ) => { + meta: [ (main_ty: $Int:ident) ] + types: [ $( + [ $size:literal $int:ident $read_int:ident $write_int:ident ] + )* ] + ) => { $( mod $int { use super::*; @@ -318,11 +297,7 @@ macro_rules! impl_integer_serializable { fn reader(dtype: &DType) -> Result { match $int::expect_scalar_dtype(dtype)? { // Read an integer of the correct size and signedness. - // - // DateTime is an unsigned integer and TimeDelta is a signed integer, - // so we support those too. - TypeStr { size: $size, endianness, type_kind: $Int, .. } | - TypeStr { size: $size, endianness, type_kind: $DateTime, .. } => { + TypeStr { size: $size, endianness, type_kind: $Int, .. } => { assert!($size == 1 || endianness != &Endianness::Irrelevant, "(BUG) invalid dtype constructed?"); let swap_byteorder = endianness.requires_swap(Endianness::of_machine()); @@ -339,8 +314,7 @@ macro_rules! impl_integer_serializable { fn writer(dtype: &DType) -> Result { match $int::expect_scalar_dtype(dtype)? { // Write an integer of the correct size and signedness. - TypeStr { size: $size, endianness, type_kind: $Int, .. } | - TypeStr { size: $size, endianness, type_kind: $DateTime, .. } => { + TypeStr { size: $size, endianness, type_kind: $Int, .. } => { assert!($size == 1 || endianness != &Endianness::Irrelevant, "(BUG) invalid dtype constructed?"); let swap_byteorder = endianness.requires_swap(Endianness::of_machine()); @@ -356,7 +330,7 @@ macro_rules! impl_integer_serializable { DType::new_scalar(TypeStr::with_auto_endianness($Int, $size, None)) } } - }; + )*}; } // Needed by the macro: Methods missing from byteorder @@ -376,9 +350,8 @@ trait WriteSingleByteExt: WriteBytesExt { impl WriteSingleByteExt for W {} impl_integer_serializable! { - @iterate - meta: [ (main_ty: Int) (date_ty: TimeDelta) ] - remaining: [ + meta: [ (main_ty: Int) ] + types: [ // numpy doesn't support i128 [ 8 i64 read_i64 write_i64 ] [ 4 i32 read_i32 write_i32 ] @@ -388,9 +361,8 @@ impl_integer_serializable! { } impl_integer_serializable! { - @iterate - meta: [ (main_ty: Uint) (date_ty: DateTime) ] - remaining: [ + meta: [ (main_ty: Uint) ] + types: [ // numpy doesn't support i128 [ 8 u64 read_u64 write_u64 ] [ 4 u32 read_u32 write_u32 ] @@ -804,8 +776,6 @@ mod tests { .err().expect("writer_expect_write_err failed!"); } - const BE_ONE_64: &[u8] = &[0, 0, 0, 0, 0, 0, 0, 1]; - const LE_ONE_64: &[u8] = &[1, 0, 0, 0, 0, 0, 0, 0]; const BE_ONE_32: &[u8] = &[0, 0, 0, 1]; const LE_ONE_32: &[u8] = &[1, 0, 0, 0]; @@ -863,25 +833,6 @@ mod tests { assert_eq!(writer_output::(&le, &42.0), &le_bytes); } - #[test] - fn datetime_as_int() { - let be = DType::parse("'>m8[ns]'").unwrap(); - let le = DType::parse("'(&be, BE_ONE_64), 1); - assert_eq!(reader_output::(&le, LE_ONE_64), 1); - assert_eq!(writer_output::(&be, &1), BE_ONE_64); - assert_eq!(writer_output::(&le, &1), LE_ONE_64); - - let be = DType::parse("'>M8[ns]'").unwrap(); - let le = DType::parse("'(&be, BE_ONE_64), 1); - assert_eq!(reader_output::(&le, LE_ONE_64), 1); - assert_eq!(writer_output::(&be, &1), BE_ONE_64); - assert_eq!(writer_output::(&le, &1), LE_ONE_64); - } - #[test] fn illegal_endianness() { // There is currently no need to test that each type rejects '|' endianness in their diff --git a/src/type_str.rs b/src/type_str.rs index c448068..d71f316 100644 --- a/src/type_str.rs +++ b/src/type_str.rs @@ -95,7 +95,9 @@ pub(crate) enum TypeKind { Uint, /// Code `f`. /// - /// Notice that numpy **does** support 128-bit floats. + /// Notice that numpy supports half-precision floats (`np.float16`), as well as possibly + /// `