diff --git a/src/compression.rs b/src/compression.rs index 0dd21017a..83a7669bd 100644 --- a/src/compression.rs +++ b/src/compression.rs @@ -1,6 +1,6 @@ //! Possible ZIP compression methods. -use std::fmt; +use std::{fmt, io}; #[allow(deprecated)] /// Identifies the storage format used to compress a file within a ZIP archive. @@ -189,6 +189,92 @@ pub const SUPPORTED_COMPRESSION_METHODS: &[CompressionMethod] = &[ CompressionMethod::Zstd, ]; +pub(crate) enum Decompressor { + Stored(R), + #[cfg(feature = "_deflate-any")] + Deflated(flate2::bufread::DeflateDecoder), + #[cfg(feature = "deflate64")] + Deflate64(deflate64::Deflate64Decoder), + #[cfg(feature = "bzip2")] + Bzip2(bzip2::bufread::BzDecoder), + #[cfg(feature = "zstd")] + Zstd(zstd::Decoder<'static, R>), + #[cfg(feature = "lzma")] + Lzma(Box>), + #[cfg(feature = "xz")] + Xz(crate::read::xz::XzDecoder), +} + +impl io::Read for Decompressor { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self { + Decompressor::Stored(r) => r.read(buf), + #[cfg(feature = "_deflate-any")] + Decompressor::Deflated(r) => r.read(buf), + #[cfg(feature = "deflate64")] + Decompressor::Deflate64(r) => r.read(buf), + #[cfg(feature = "bzip2")] + Decompressor::Bzip2(r) => r.read(buf), + #[cfg(feature = "zstd")] + Decompressor::Zstd(r) => r.read(buf), + #[cfg(feature = "lzma")] + Decompressor::Lzma(r) => r.read(buf), + #[cfg(feature = "xz")] + Decompressor::Xz(r) => r.read(buf), + } + } +} + +impl Decompressor { + pub fn new(reader: R, compression_method: CompressionMethod) -> crate::result::ZipResult { + Ok(match compression_method { + CompressionMethod::Stored => Decompressor::Stored(reader), + #[cfg(feature = "_deflate-any")] + CompressionMethod::Deflated => { + Decompressor::Deflated(flate2::bufread::DeflateDecoder::new(reader)) + } + #[cfg(feature = "deflate64")] + CompressionMethod::Deflate64 => { + Decompressor::Deflate64(deflate64::Deflate64Decoder::with_buffer(reader)) + } + #[cfg(feature = "bzip2")] + CompressionMethod::Bzip2 => Decompressor::Bzip2(bzip2::bufread::BzDecoder::new(reader)), + #[cfg(feature = "zstd")] + CompressionMethod::Zstd => Decompressor::Zstd(zstd::Decoder::with_buffer(reader)?), + #[cfg(feature = "lzma")] + CompressionMethod::Lzma => { + Decompressor::Lzma(Box::new(crate::read::lzma::LzmaDecoder::new(reader))) + } + #[cfg(feature = "xz")] + CompressionMethod::Xz => Decompressor::Xz(crate::read::xz::XzDecoder::new(reader)), + _ => { + return Err(crate::result::ZipError::UnsupportedArchive( + "Compression method not supported", + )) + } + }) + } + + /// Consumes this decoder, returning the underlying reader. + pub fn into_inner(self) -> R { + match self { + Decompressor::Stored(r) => r, + #[cfg(feature = "_deflate-any")] + Decompressor::Deflated(r) => r.into_inner(), + #[cfg(feature = "deflate64")] + Decompressor::Deflate64(r) => r.into_inner(), + #[cfg(feature = "bzip2")] + Decompressor::Bzip2(r) => r.into_inner(), + #[cfg(feature = "zstd")] + Decompressor::Zstd(r) => r.finish(), + #[cfg(feature = "lzma")] + Decompressor::Lzma(r) => r.into_inner(), + #[cfg(feature = "xz")] + Decompressor::Xz(r) => r.into_inner(), + } + } +} + #[cfg(test)] mod test { use super::{CompressionMethod, SUPPORTED_COMPRESSION_METHODS}; diff --git a/src/read.rs b/src/read.rs index 830a58514..6185ac5bc 100644 --- a/src/read.rs +++ b/src/read.rs @@ -2,7 +2,7 @@ #[cfg(feature = "aes-crypto")] use crate::aes::{AesReader, AesReaderValid}; -use crate::compression::CompressionMethod; +use crate::compression::{CompressionMethod, Decompressor}; use crate::cp437::FromCp437; use crate::crc32::Crc32Reader; use crate::extra_fields::{ExtendedTimestamp, ExtraField}; @@ -26,18 +26,6 @@ use std::path::{Path, PathBuf}; use std::rc::Rc; use std::sync::{Arc, OnceLock}; -#[cfg(feature = "deflate-flate2")] -use flate2::read::DeflateDecoder; - -#[cfg(feature = "deflate64")] -use deflate64::Deflate64Decoder; - -#[cfg(feature = "bzip2")] -use bzip2::read::BzDecoder; - -#[cfg(feature = "zstd")] -use zstd::stream::read::Decoder as ZstdDecoder; - mod config; pub use config::*; @@ -123,11 +111,7 @@ pub(crate) mod zip_archive { #[cfg(feature = "aes-crypto")] use crate::aes::PWD_VERIFY_LENGTH; use crate::extra_fields::UnicodeExtraField; -#[cfg(feature = "lzma")] -use crate::read::lzma::LzmaDecoder; -#[cfg(feature = "xz")] -use crate::read::xz::XzDecoder; -use crate::result::ZipError::{InvalidArchive, InvalidPassword, UnsupportedArchive}; +use crate::result::ZipError::{InvalidArchive, InvalidPassword}; use crate::spec::is_dir; use crate::types::ffi::S_IFLNK; use crate::unstable::{path_to_string, LittleEndianReadExt}; @@ -199,141 +183,69 @@ impl<'a> CryptoReader<'a> { } } +#[cold] +fn invalid_state() -> io::Result { + Err(io::Error::new( + io::ErrorKind::Other, + "ZipFileReader was in an invalid state", + )) +} + pub(crate) enum ZipFileReader<'a> { NoReader, Raw(io::Take<&'a mut dyn Read>), - Stored(Crc32Reader>), - #[cfg(feature = "_deflate-any")] - Deflated(Crc32Reader>>), - #[cfg(feature = "deflate64")] - Deflate64(Crc32Reader>>>), - #[cfg(feature = "bzip2")] - Bzip2(Crc32Reader>>), - #[cfg(feature = "zstd")] - Zstd(Crc32Reader>>>), - #[cfg(feature = "lzma")] - Lzma(Crc32Reader>>>), - #[cfg(feature = "xz")] - Xz(Crc32Reader>>), + Compressed(Box>>>>), } impl<'a> Read for ZipFileReader<'a> { fn read(&mut self, buf: &mut [u8]) -> io::Result { match self { - ZipFileReader::NoReader => panic!("ZipFileReader was in an invalid state"), + ZipFileReader::NoReader => invalid_state(), ZipFileReader::Raw(r) => r.read(buf), - ZipFileReader::Stored(r) => r.read(buf), - #[cfg(feature = "_deflate-any")] - ZipFileReader::Deflated(r) => r.read(buf), - #[cfg(feature = "deflate64")] - ZipFileReader::Deflate64(r) => r.read(buf), - #[cfg(feature = "bzip2")] - ZipFileReader::Bzip2(r) => r.read(buf), - #[cfg(feature = "zstd")] - ZipFileReader::Zstd(r) => r.read(buf), - #[cfg(feature = "lzma")] - ZipFileReader::Lzma(r) => r.read(buf), - #[cfg(feature = "xz")] - ZipFileReader::Xz(r) => r.read(buf), + ZipFileReader::Compressed(r) => r.read(buf), } } fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { match self { - ZipFileReader::NoReader => panic!("ZipFileReader was in an invalid state"), + ZipFileReader::NoReader => invalid_state(), ZipFileReader::Raw(r) => r.read_exact(buf), - ZipFileReader::Stored(r) => r.read_exact(buf), - #[cfg(feature = "_deflate-any")] - ZipFileReader::Deflated(r) => r.read_exact(buf), - #[cfg(feature = "deflate64")] - ZipFileReader::Deflate64(r) => r.read_exact(buf), - #[cfg(feature = "bzip2")] - ZipFileReader::Bzip2(r) => r.read_exact(buf), - #[cfg(feature = "zstd")] - ZipFileReader::Zstd(r) => r.read_exact(buf), - #[cfg(feature = "lzma")] - ZipFileReader::Lzma(r) => r.read_exact(buf), - #[cfg(feature = "xz")] - ZipFileReader::Xz(r) => r.read_exact(buf), + ZipFileReader::Compressed(r) => r.read_exact(buf), } } fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { match self { - ZipFileReader::NoReader => panic!("ZipFileReader was in an invalid state"), + ZipFileReader::NoReader => invalid_state(), ZipFileReader::Raw(r) => r.read_to_end(buf), - ZipFileReader::Stored(r) => r.read_to_end(buf), - #[cfg(feature = "_deflate-any")] - ZipFileReader::Deflated(r) => r.read_to_end(buf), - #[cfg(feature = "deflate64")] - ZipFileReader::Deflate64(r) => r.read_to_end(buf), - #[cfg(feature = "bzip2")] - ZipFileReader::Bzip2(r) => r.read_to_end(buf), - #[cfg(feature = "zstd")] - ZipFileReader::Zstd(r) => r.read_to_end(buf), - #[cfg(feature = "lzma")] - ZipFileReader::Lzma(r) => r.read_to_end(buf), - #[cfg(feature = "xz")] - ZipFileReader::Xz(r) => r.read_to_end(buf), + ZipFileReader::Compressed(r) => r.read_to_end(buf), } } fn read_to_string(&mut self, buf: &mut String) -> io::Result { match self { - ZipFileReader::NoReader => panic!("ZipFileReader was in an invalid state"), + ZipFileReader::NoReader => invalid_state(), ZipFileReader::Raw(r) => r.read_to_string(buf), - ZipFileReader::Stored(r) => r.read_to_string(buf), - #[cfg(feature = "_deflate-any")] - ZipFileReader::Deflated(r) => r.read_to_string(buf), - #[cfg(feature = "deflate64")] - ZipFileReader::Deflate64(r) => r.read_to_string(buf), - #[cfg(feature = "bzip2")] - ZipFileReader::Bzip2(r) => r.read_to_string(buf), - #[cfg(feature = "zstd")] - ZipFileReader::Zstd(r) => r.read_to_string(buf), - #[cfg(feature = "lzma")] - ZipFileReader::Lzma(r) => r.read_to_string(buf), - #[cfg(feature = "xz")] - ZipFileReader::Xz(r) => r.read_to_string(buf), + ZipFileReader::Compressed(r) => r.read_to_string(buf), } } } impl<'a> ZipFileReader<'a> { - /// Consumes this decoder, returning the underlying reader. - pub fn drain(self) { - let mut inner = match self { - ZipFileReader::NoReader => panic!("ZipFileReader was in an invalid state"), - ZipFileReader::Raw(r) => r, - ZipFileReader::Stored(r) => r.into_inner().into_inner(), - #[cfg(feature = "_deflate-any")] - ZipFileReader::Deflated(r) => r.into_inner().into_inner().into_inner(), - #[cfg(feature = "deflate64")] - ZipFileReader::Deflate64(r) => r.into_inner().into_inner().into_inner().into_inner(), - #[cfg(feature = "bzip2")] - ZipFileReader::Bzip2(r) => r.into_inner().into_inner().into_inner(), - #[cfg(feature = "zstd")] - ZipFileReader::Zstd(r) => r.into_inner().finish().into_inner().into_inner(), - #[cfg(feature = "lzma")] - ZipFileReader::Lzma(r) => { - // Lzma reader owns its buffer rather than mutably borrowing it, so we have to drop - // it separately - if let Ok(mut remaining) = r.into_inner().finish() { - let _ = copy(&mut remaining, &mut sink()); - } - return; + fn into_inner(self) -> io::Result> { + match self { + ZipFileReader::NoReader => invalid_state(), + ZipFileReader::Raw(r) => Ok(r), + ZipFileReader::Compressed(r) => { + Ok(r.into_inner().into_inner().into_inner().into_inner()) } - #[cfg(feature = "xz")] - ZipFileReader::Xz(r) => r.into_inner().into_inner().into_inner(), - }; - let _ = copy(&mut inner, &mut sink()); + } } } /// A struct for reading a zip file pub struct ZipFile<'a> { pub(crate) data: Cow<'a, ZipFileData>, - pub(crate) crypto_reader: Option>, pub(crate) reader: ZipFileReader<'a>, } @@ -382,18 +294,14 @@ fn find_data_start( #[allow(clippy::too_many_arguments)] pub(crate) fn make_crypto_reader<'a>( - compression_method: CompressionMethod, - crc32: u32, - mut last_modified_time: Option, - using_data_descriptor: bool, + data: &ZipFileData, reader: io::Take<&'a mut dyn Read>, password: Option<&[u8]>, aes_info: Option<(AesMode, AesVendorVersion, CompressionMethod)>, - #[cfg(feature = "aes-crypto")] compressed_size: u64, ) -> ZipResult> { #[allow(deprecated)] { - if let CompressionMethod::Unsupported(_) = compression_method { + if let CompressionMethod::Unsupported(_) = data.compression_method { return unsupported_zip_error("Compression method not supported"); } } @@ -407,17 +315,18 @@ pub(crate) fn make_crypto_reader<'a>( } #[cfg(feature = "aes-crypto")] (Some(password), Some((aes_mode, vendor_version, _))) => CryptoReader::Aes { - reader: AesReader::new(reader, aes_mode, compressed_size).validate(password)?, + reader: AesReader::new(reader, aes_mode, data.compressed_size).validate(password)?, vendor_version, }, (Some(password), None) => { - if !using_data_descriptor { + let mut last_modified_time = data.last_modified_time; + if !data.using_data_descriptor { last_modified_time = None; } let validator = if let Some(last_modified_time) = last_modified_time { ZipCryptoValidator::InfoZipMsdosTime(last_modified_time.timepart()) } else { - ZipCryptoValidator::PkzipCrc32(crc32) + ZipCryptoValidator::PkzipCrc32(data.crc32) }; CryptoReader::ZipCrypto(ZipCryptoReader::new(reader, password).validate(validator)?) } @@ -434,68 +343,11 @@ pub(crate) fn make_reader( ) -> ZipResult { let ae2_encrypted = reader.is_ae2_encrypted(); - match compression_method { - CompressionMethod::Stored => Ok(ZipFileReader::Stored(Crc32Reader::new( - reader, - crc32, - ae2_encrypted, - ))), - #[cfg(feature = "_deflate-any")] - CompressionMethod::Deflated => { - let deflate_reader = DeflateDecoder::new(reader); - Ok(ZipFileReader::Deflated(Crc32Reader::new( - deflate_reader, - crc32, - ae2_encrypted, - ))) - } - #[cfg(feature = "deflate64")] - CompressionMethod::Deflate64 => { - let deflate64_reader = Deflate64Decoder::new(reader); - Ok(ZipFileReader::Deflate64(Crc32Reader::new( - deflate64_reader, - crc32, - ae2_encrypted, - ))) - } - #[cfg(feature = "bzip2")] - CompressionMethod::Bzip2 => { - let bzip2_reader = BzDecoder::new(reader); - Ok(ZipFileReader::Bzip2(Crc32Reader::new( - bzip2_reader, - crc32, - ae2_encrypted, - ))) - } - #[cfg(feature = "zstd")] - CompressionMethod::Zstd => { - let zstd_reader = ZstdDecoder::new(reader).unwrap(); - Ok(ZipFileReader::Zstd(Crc32Reader::new( - zstd_reader, - crc32, - ae2_encrypted, - ))) - } - #[cfg(feature = "lzma")] - CompressionMethod::Lzma => { - let reader = LzmaDecoder::new(reader); - Ok(ZipFileReader::Lzma(Crc32Reader::new( - Box::new(reader), - crc32, - ae2_encrypted, - ))) - } - #[cfg(feature = "xz")] - CompressionMethod::Xz => { - let reader = XzDecoder::new(reader); - Ok(ZipFileReader::Xz(Crc32Reader::new( - reader, - crc32, - ae2_encrypted, - ))) - } - _ => Err(UnsupportedArchive("Compression method not supported")), - } + Ok(ZipFileReader::Compressed(Box::new(Crc32Reader::new( + Decompressor::new(io::BufReader::new(reader), compression_method)?, + crc32, + ae2_encrypted, + )))) } #[derive(Debug)] @@ -1207,7 +1059,6 @@ impl ZipArchive { .get_index(file_number) .ok_or(ZipError::FileNotFound)?; Ok(ZipFile { - crypto_reader: None, reader: ZipFileReader::Raw(find_content(data, reader)?), data: Cow::Borrowed(data), }) @@ -1231,21 +1082,11 @@ impl ZipArchive { } let limit_reader = find_content(data, &mut self.reader)?; - let crypto_reader = make_crypto_reader( - data.compression_method, - data.crc32, - data.last_modified_time, - data.using_data_descriptor, - limit_reader, - password, - data.aes_mode, - #[cfg(feature = "aes-crypto")] - data.compressed_size, - )?; + let crypto_reader = make_crypto_reader(data, limit_reader, password, data.aes_mode)?; + Ok(ZipFile { - crypto_reader: Some(crypto_reader), - reader: ZipFileReader::NoReader, data: Cow::Borrowed(data), + reader: make_reader(data.compression_method, data.crc32, crypto_reader)?, }) } @@ -1534,21 +1375,8 @@ pub(crate) fn parse_single_extra_field( /// Methods for retrieving information on zip files impl<'a> ZipFile<'a> { - fn get_reader(&mut self) -> ZipResult<&mut ZipFileReader<'a>> { - if let ZipFileReader::NoReader = self.reader { - let data = &self.data; - let crypto_reader = self.crypto_reader.take().expect("Invalid reader state"); - self.reader = make_reader(data.compression_method, data.crc32, crypto_reader)?; - } - Ok(&mut self.reader) - } - - pub(crate) fn get_raw_reader(&mut self) -> &mut dyn Read { - if let ZipFileReader::NoReader = self.reader { - let crypto_reader = self.crypto_reader.take().expect("Invalid reader state"); - self.reader = ZipFileReader::Raw(crypto_reader.into_inner()) - } - &mut self.reader + pub(crate) fn take_raw_reader(&mut self) -> io::Result> { + std::mem::replace(&mut self.reader, ZipFileReader::NoReader).into_inner() } /// Get the version of the file @@ -1700,19 +1528,19 @@ impl<'a> ZipFile<'a> { impl<'a> Read for ZipFile<'a> { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.get_reader()?.read(buf) + self.reader.read(buf) } fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { - self.get_reader()?.read_exact(buf) + self.reader.read_exact(buf) } fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { - self.get_reader()?.read_to_end(buf) + self.reader.read_to_end(buf) } fn read_to_string(&mut self, buf: &mut String) -> io::Result { - self.get_reader()?.read_to_string(buf) + self.reader.read_to_string(buf) } } @@ -1722,19 +1550,9 @@ impl<'a> Drop for ZipFile<'a> { // In this case, we want to exhaust the reader so that the next file is accessible. if let Cow::Owned(_) = self.data { // Get the inner `Take` reader so all decryption, decompression and CRC calculation is skipped. - match &mut self.reader { - ZipFileReader::NoReader => { - let innerreader = self.crypto_reader.take(); - let _ = copy( - &mut innerreader.expect("Invalid reader state").into_inner(), - &mut sink(), - ); - } - reader => { - let innerreader = std::mem::replace(reader, ZipFileReader::NoReader); - innerreader.drain(); - } - }; + if let Ok(mut inner) = self.take_raw_reader() { + let _ = copy(&mut inner, &mut sink()); + } } } } @@ -1785,21 +1603,10 @@ pub fn read_zipfile_from_stream<'a, R: Read>(reader: &'a mut R) -> ZipResult LzmaDecoder { } } - pub fn finish(mut self) -> Result> { - copy(&mut self.compressed_reader, &mut self.stream)?; - self.stream.finish().map_err(Error::from) + pub fn into_inner(self) -> R { + self.compressed_reader } } -impl Read for LzmaDecoder { +impl Read for LzmaDecoder { fn read(&mut self, buf: &mut [u8]) -> Result { let mut bytes_read = self.stream.get_output_mut().unwrap().read(buf)?; while bytes_read < buf.len() { - let mut next_compressed = [0u8; COMPRESSED_BYTES_TO_BUFFER]; - let compressed_bytes_read = self.compressed_reader.read(&mut next_compressed)?; - if compressed_bytes_read == 0 { + let compressed_bytes = self.compressed_reader.fill_buf()?; + if compressed_bytes.is_empty() { break; } - self.stream - .write_all(&next_compressed[..compressed_bytes_read])?; + self.stream.write_all(compressed_bytes)?; bytes_read += self .stream .get_output_mut() diff --git a/src/read/xz.rs b/src/read/xz.rs index 478ae1024..991df62b0 100644 --- a/src/read/xz.rs +++ b/src/read/xz.rs @@ -2,12 +2,12 @@ use crc32fast::Hasher; use lzma_rs::decompress::raw::Lzma2Decoder; use std::{ collections::VecDeque, - io::{BufRead, BufReader, Error, Read, Result, Write}, + io::{BufRead, Error, Read, Result, Write}, }; #[derive(Debug)] -pub struct XzDecoder { - compressed_reader: BufReader, +pub struct XzDecoder { + compressed_reader: R, stream_size: usize, buf: VecDeque, check_size: usize, @@ -15,10 +15,10 @@ pub struct XzDecoder { flags: [u8; 2], } -impl XzDecoder { +impl XzDecoder { pub fn new(inner: R) -> Self { XzDecoder { - compressed_reader: BufReader::new(inner), + compressed_reader: inner, stream_size: 0, buf: VecDeque::new(), check_size: 0, @@ -83,7 +83,7 @@ fn error(s: &'static str) -> Result { Err(Error::new(std::io::ErrorKind::InvalidData, s)) } -fn get_multibyte(input: &mut R, hasher: &mut Hasher) -> Result { +fn get_multibyte(input: &mut R, hasher: &mut Hasher) -> Result { let mut result = 0; for i in 0..9 { let mut b = [0u8; 1]; @@ -98,7 +98,7 @@ fn get_multibyte(input: &mut R, hasher: &mut Hasher) -> Result { error("Invalid multi-byte encoding") } -impl Read for XzDecoder { +impl Read for XzDecoder { fn read(&mut self, buf: &mut [u8]) -> Result { if !self.buf.is_empty() { let len = std::cmp::min(buf.len(), self.buf.len()); @@ -263,8 +263,8 @@ impl Read for XzDecoder { } } -impl XzDecoder { +impl XzDecoder { pub fn into_inner(self) -> R { - self.compressed_reader.into_inner() + self.compressed_reader } } diff --git a/src/write.rs b/src/write.rs index a8a31a53f..35ec5777a 100644 --- a/src/write.rs +++ b/src/write.rs @@ -1297,7 +1297,7 @@ impl ZipWriter { self.writing_to_file = true; self.writing_raw = true; - io::copy(file.get_raw_reader(), self)?; + io::copy(&mut file.take_raw_reader()?, self)?; Ok(()) }