diff --git a/src/npz.rs b/src/npz.rs index 42c0c88..29eded8 100644 --- a/src/npz.rs +++ b/src/npz.rs @@ -105,6 +105,8 @@ impl NpzWriter { /// Adds an array with the specified `name` to the `.npz` file. /// + /// Note that a `.npy` extension will be appended to `name`; this matches NumPy's behavior. + /// /// To write a scalar value, create a zero-dimensional array using [`arr0`](ndarray::arr0) or /// [`aview0`](ndarray::aview0). pub fn add_array( @@ -118,7 +120,7 @@ impl NpzWriter { S: Data, D: Dimension, { - self.zip.start_file(name, self.options)?; + self.zip.start_file(name.into() + ".npy", self.options)?; // Buffering when writing individual arrays is beneficial even when the // underlying writer is `Cursor>` instead of a real file. The // only exception I saw in testing was the "compressed, in-memory @@ -221,20 +223,40 @@ impl NpzReader { } /// Returns the names of all of the arrays in the file. + /// + /// Note that a single ".npy" suffix (if present) will be stripped from each name; this matches + /// NumPy's behavior. pub fn names(&mut self) -> Result, ReadNpzError> { Ok((0..self.zip.len()) - .map(|i| Ok(self.zip.by_index(i)?.name().to_owned())) + .map(|i| { + let file = self.zip.by_index(i)?; + let name = file.name(); + let stripped = name.strip_suffix(".npy").unwrap_or(name); + Ok(stripped.to_owned()) + }) .collect::>()?) } /// Reads an array by name. + /// + /// Note that this first checks for `name` in the `.npz` file, and if that is not present, + /// checks for `format!("{name}.npy")`. This matches NumPy's behavior. pub fn by_name(&mut self, name: &str) -> Result, ReadNpzError> where S::Elem: ReadableElement, S: DataOwned, D: Dimension, { - Ok(ArrayBase::::read_npy(self.zip.by_name(name)?)?) + // TODO: Combine the two cases into a single `let file = match { ... }` once + // https://github.com/rust-lang/rust/issues/47680 is resolved. + match self.zip.by_name(name) { + Ok(file) => return Ok(ArrayBase::::read_npy(file)?), + Err(ZipError::FileNotFound) => {} + Err(err) => return Err(err.into()), + }; + Ok(ArrayBase::::read_npy( + self.zip.by_name(&format!("{name}.npy"))?, + )?) } /// Reads an array by index in the `.npz` file. diff --git a/tests/integration/main.rs b/tests/integration/main.rs index 25a20f4..a8b9197 100644 --- a/tests/integration/main.rs +++ b/tests/integration/main.rs @@ -6,6 +6,8 @@ use std::io::{self, Read}; use std::ops::{Deref, DerefMut}; mod examples; +#[cfg(feature = "npz")] +mod npz; mod primitive; mod round_trip; diff --git a/tests/integration/npz.rs b/tests/integration/npz.rs new file mode 100644 index 0000000..9be26d3 --- /dev/null +++ b/tests/integration/npz.rs @@ -0,0 +1,52 @@ +//! .npz examples. + +use ndarray::{array, Array2}; +use ndarray_npy::{NpzReader, NpzWriter}; +use std::{error::Error, io::Cursor}; + +#[test] +fn round_trip_npz() -> Result<(), Box> { + let mut buf = Vec::::new(); + + let arr1 = array![[1i32, 3, 0], [4, 7, -1]]; + let arr2 = array![[9i32, 6], [-5, 2], [3, -1]]; + + { + let mut writer = NpzWriter::new(Cursor::new(&mut buf)); + writer.add_array("arr1", &arr1)?; + writer.add_array("arr2", &arr2)?; + writer.finish()?; + } + + { + let mut reader = NpzReader::new(Cursor::new(&buf))?; + assert!(!reader.is_empty()); + assert_eq!(reader.len(), 2); + assert_eq!( + reader.names()?, + vec!["arr1".to_string(), "arr2".to_string()], + ); + { + let by_name: Array2 = reader.by_name("arr1")?; + assert_eq!(by_name, arr1); + } + { + let by_name: Array2 = reader.by_name("arr1.npy")?; + assert_eq!(by_name, arr1); + } + { + let by_name: Array2 = reader.by_name("arr2")?; + assert_eq!(by_name, arr2); + } + { + let by_name: Array2 = reader.by_name("arr2.npy")?; + assert_eq!(by_name, arr2); + } + { + let res: Result, _> = reader.by_name("arr1.npy.npy"); + assert!(res.is_err()); + } + } + + Ok(()) +}