Skip to content

Commit

Permalink
Merge pull request #76 from jturner314/names-in-npz
Browse files Browse the repository at this point in the history
Match NumPy regarding .npy for names in .npz files
  • Loading branch information
jturner314 authored Sep 14, 2024
2 parents 7e6ea69 + d3f0dcb commit 08bef59
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 3 deletions.
28 changes: 25 additions & 3 deletions src/npz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ impl<W: Write + Seek> NpzWriter<W> {

/// Adds an array with the specified `name` to the `.npz` file.
///
/// Note that a `.npy` extension will be appended to `name`; this matches NumPy's behavior.
///
/// To write a scalar value, create a zero-dimensional array using [`arr0`](ndarray::arr0) or
/// [`aview0`](ndarray::aview0).
pub fn add_array<N, S, D>(
Expand All @@ -118,7 +120,7 @@ impl<W: Write + Seek> NpzWriter<W> {
S: Data,
D: Dimension,
{
self.zip.start_file(name, self.options)?;
self.zip.start_file(name.into() + ".npy", self.options)?;
// Buffering when writing individual arrays is beneficial even when the
// underlying writer is `Cursor<Vec<u8>>` instead of a real file. The
// only exception I saw in testing was the "compressed, in-memory
Expand Down Expand Up @@ -221,20 +223,40 @@ impl<R: Read + Seek> NpzReader<R> {
}

/// Returns the names of all of the arrays in the file.
///
/// Note that a single ".npy" suffix (if present) will be stripped from each name; this matches
/// NumPy's behavior.
pub fn names(&mut self) -> Result<Vec<String>, ReadNpzError> {
Ok((0..self.zip.len())
.map(|i| Ok(self.zip.by_index(i)?.name().to_owned()))
.map(|i| {
let file = self.zip.by_index(i)?;
let name = file.name();
let stripped = name.strip_suffix(".npy").unwrap_or(name);
Ok(stripped.to_owned())
})
.collect::<Result<_, ZipError>>()?)
}

/// Reads an array by name.
///
/// Note that this first checks for `name` in the `.npz` file, and if that is not present,
/// checks for `format!("{name}.npy")`. This matches NumPy's behavior.
pub fn by_name<S, D>(&mut self, name: &str) -> Result<ArrayBase<S, D>, ReadNpzError>
where
S::Elem: ReadableElement,
S: DataOwned,
D: Dimension,
{
Ok(ArrayBase::<S, D>::read_npy(self.zip.by_name(name)?)?)
// TODO: Combine the two cases into a single `let file = match { ... }` once
// https://github.com/rust-lang/rust/issues/47680 is resolved.
match self.zip.by_name(name) {
Ok(file) => return Ok(ArrayBase::<S, D>::read_npy(file)?),
Err(ZipError::FileNotFound) => {}
Err(err) => return Err(err.into()),
};
Ok(ArrayBase::<S, D>::read_npy(
self.zip.by_name(&format!("{name}.npy"))?,
)?)
}

/// Reads an array by index in the `.npz` file.
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use std::io::{self, Read};
use std::ops::{Deref, DerefMut};

mod examples;
#[cfg(feature = "npz")]
mod npz;
mod primitive;
mod round_trip;

Expand Down
52 changes: 52 additions & 0 deletions tests/integration/npz.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//! .npz examples.
use ndarray::{array, Array2};
use ndarray_npy::{NpzReader, NpzWriter};
use std::{error::Error, io::Cursor};

#[test]
fn round_trip_npz() -> Result<(), Box<dyn Error>> {
let mut buf = Vec::<u8>::new();

let arr1 = array![[1i32, 3, 0], [4, 7, -1]];
let arr2 = array![[9i32, 6], [-5, 2], [3, -1]];

{
let mut writer = NpzWriter::new(Cursor::new(&mut buf));
writer.add_array("arr1", &arr1)?;
writer.add_array("arr2", &arr2)?;
writer.finish()?;
}

{
let mut reader = NpzReader::new(Cursor::new(&buf))?;
assert!(!reader.is_empty());
assert_eq!(reader.len(), 2);
assert_eq!(
reader.names()?,
vec!["arr1".to_string(), "arr2".to_string()],
);
{
let by_name: Array2<i32> = reader.by_name("arr1")?;
assert_eq!(by_name, arr1);
}
{
let by_name: Array2<i32> = reader.by_name("arr1.npy")?;
assert_eq!(by_name, arr1);
}
{
let by_name: Array2<i32> = reader.by_name("arr2")?;
assert_eq!(by_name, arr2);
}
{
let by_name: Array2<i32> = reader.by_name("arr2.npy")?;
assert_eq!(by_name, arr2);
}
{
let res: Result<Array2<i32>, _> = reader.by_name("arr1.npy.npy");
assert!(res.is_err());
}
}

Ok(())
}

0 comments on commit 08bef59

Please sign in to comment.