From 82177e9b704be0a812c02dbaffcb6c55f7300094 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20du=20Garreau?= Date: Fri, 5 Apr 2024 16:13:44 +0200 Subject: [PATCH] Improve several `Read` methods on `ZipFile` --- src/crc32.rs | 35 ++++++++++++++++++++-- src/read.rs | 84 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+), 3 deletions(-) diff --git a/src/crc32.rs b/src/crc32.rs index ebace898d..9e25d6e5a 100644 --- a/src/crc32.rs +++ b/src/crc32.rs @@ -36,20 +36,49 @@ impl Crc32Reader { } } +#[cold] +fn invalid_checksum() -> io::Error { + io::Error::new(io::ErrorKind::InvalidData, "Invalid checksum") +} + impl Read for Crc32Reader { fn read(&mut self, buf: &mut [u8]) -> io::Result { let invalid_check = !buf.is_empty() && !self.check_matches() && !self.ae2_encrypted; let count = match self.inner.read(buf) { - Ok(0) if invalid_check => { - return Err(io::Error::new(io::ErrorKind::Other, "Invalid checksum")) - } + Ok(0) if invalid_check => return Err(invalid_checksum()), Ok(n) => n, Err(e) => return Err(e), }; self.hasher.update(&buf[0..count]); Ok(count) } + + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + let start = buf.len(); + let n = self.inner.read_to_end(buf)?; + + self.hasher.update(&buf[start..]); + + if !self.check_matches() && !self.ae2_encrypted { + return Err(invalid_checksum()); + } + + Ok(n) + } + + fn read_to_string(&mut self, buf: &mut String) -> io::Result { + let start = buf.len(); + let n = self.inner.read_to_string(buf)?; + + self.hasher.update(&buf.as_bytes()[start..]); + + if !self.check_matches() && !self.ae2_encrypted { + return Err(invalid_checksum()); + } + + Ok(n) + } } #[cfg(test)] diff --git a/src/read.rs b/src/read.rs index b702b4f21..e4add3bf4 100644 --- a/src/read.rs +++ b/src/read.rs @@ -91,6 +91,24 @@ impl<'a> Read for CryptoReader<'a> { CryptoReader::Aes { reader: r, .. } => r.read(buf), } } + + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + match self { + CryptoReader::Plaintext(r) => r.read_to_end(buf), + CryptoReader::ZipCrypto(r) => r.read_to_end(buf), + #[cfg(feature = "aes-crypto")] + CryptoReader::Aes { reader: r, .. } => r.read_to_end(buf), + } + } + + fn read_to_string(&mut self, buf: &mut String) -> io::Result { + match self { + CryptoReader::Plaintext(r) => r.read_to_string(buf), + CryptoReader::ZipCrypto(r) => r.read_to_string(buf), + #[cfg(feature = "aes-crypto")] + CryptoReader::Aes { reader: r, .. } => r.read_to_string(buf), + } + } } impl<'a> CryptoReader<'a> { @@ -153,6 +171,60 @@ impl<'a> Read for ZipFileReader<'a> { ZipFileReader::Zstd(r) => r.read(buf), } } + + fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { + match self { + ZipFileReader::NoReader => panic!("ZipFileReader was in an invalid state"), + ZipFileReader::Raw(r) => r.read_exact(buf), + ZipFileReader::Stored(r) => r.read_exact(buf), + #[cfg(any( + feature = "deflate", + feature = "deflate-miniz", + feature = "deflate-zlib" + ))] + ZipFileReader::Deflated(r) => r.read_exact(buf), + #[cfg(feature = "bzip2")] + ZipFileReader::Bzip2(r) => r.read_exact(buf), + #[cfg(feature = "zstd")] + ZipFileReader::Zstd(r) => r.read_exact(buf), + } + } + + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + match self { + ZipFileReader::NoReader => panic!("ZipFileReader was in an invalid state"), + ZipFileReader::Raw(r) => r.read_to_end(buf), + ZipFileReader::Stored(r) => r.read_to_end(buf), + #[cfg(any( + feature = "deflate", + feature = "deflate-miniz", + feature = "deflate-zlib" + ))] + ZipFileReader::Deflated(r) => r.read_to_end(buf), + #[cfg(feature = "bzip2")] + ZipFileReader::Bzip2(r) => r.read_to_end(buf), + #[cfg(feature = "zstd")] + ZipFileReader::Zstd(r) => r.read_to_end(buf), + } + } + + fn read_to_string(&mut self, buf: &mut String) -> io::Result { + match self { + ZipFileReader::NoReader => panic!("ZipFileReader was in an invalid state"), + ZipFileReader::Raw(r) => r.read_to_string(buf), + ZipFileReader::Stored(r) => r.read_to_string(buf), + #[cfg(any( + feature = "deflate", + feature = "deflate-miniz", + feature = "deflate-zlib" + ))] + ZipFileReader::Deflated(r) => r.read_to_string(buf), + #[cfg(feature = "bzip2")] + ZipFileReader::Bzip2(r) => r.read_to_string(buf), + #[cfg(feature = "zstd")] + ZipFileReader::Zstd(r) => r.read_to_string(buf), + } + } } impl<'a> ZipFileReader<'a> { @@ -979,6 +1051,18 @@ impl<'a> Read for ZipFile<'a> { fn read(&mut self, buf: &mut [u8]) -> io::Result { self.get_reader().read(buf) } + + fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { + self.get_reader().read_exact(buf) + } + + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + self.get_reader().read_to_end(buf) + } + + fn read_to_string(&mut self, buf: &mut String) -> io::Result { + self.get_reader().read_to_string(buf) + } } impl<'a> Drop for ZipFile<'a> {