diff --git a/src/io.rs b/src/io.rs
index 54ed881c3..a8bc97d03 100644
--- a/src/io.rs
+++ b/src/io.rs
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 use anyhow::{bail, ensure, Context, Error, Result};
+use bincode::Options;
 use flate2::read::GzDecoder;
 use openssl::hash::{Hasher, MessageDigest};
 use openssl::sha;
@@ -257,6 +258,49 @@ impl<R: Read> Read for LimitReader<R> {
     }
 }
 
+pub struct LimitWriter<W: Write> {
+    sink: W,
+    length: u64,
+    remaining: u64,
+    conflict: String,
+}
+
+impl<W: Write> LimitWriter<W> {
+    pub fn new(sink: W, length: u64, conflict: String) -> Self {
+        Self {
+            sink,
+            length,
+            remaining: length,
+            conflict,
+        }
+    }
+}
+
+impl<W: Write> Write for LimitWriter<W> {
+    fn write(&mut self, buf: &[u8]) -> result::Result<usize, io::Error> {
+        if buf.is_empty() {
+            return Ok(0);
+        }
+        let allowed = self.remaining.min(buf.len() as u64);
+        if allowed == 0 {
+            return Err(io::Error::new(
+                io::ErrorKind::Other,
+                format!("collision with {} at offset {}", self.conflict, self.length),
+            ));
+        }
+        let count = self.sink.write(&buf[..allowed as usize])?;
+        self.remaining = self
+            .remaining
+            .checked_sub(count as u64)
+            .expect("wrote more bytes than allowed");
+        Ok(count)
+    }
+
+    fn flush(&mut self) -> result::Result<(), io::Error> {
+        self.sink.flush()
+    }
+}
+
 #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Default)]
 pub struct Sha256Digest(pub [u8; 32]);
 
@@ -355,6 +399,16 @@ impl TryFrom<Vec<u8>> for Sha256Digest {
     }
 }
 
+/// Provides uniform bincode options for all our serialization operations.
+pub fn bincoder() -> impl bincode::Options {
+    bincode::options()
+        .allow_trailing_bytes()
+        // make the defaults explicit
+        .with_no_limit()
+        .with_little_endian()
+        .with_varint_encoding()
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -468,10 +522,7 @@ mod tests {
     #[test]
     fn limit_reader_test() {
         // build input data
-        let mut data: Vec<u8> = Vec::new();
-        for i in 0..100 {
-            data.push(i);
-        }
+        let data: Vec<u8> = (0..100).collect();
 
         // limit larger than file
         let mut file = Cursor::new(data.clone());
@@ -528,4 +579,43 @@ mod tests {
             "collision with foo at offset 50"
         );
     }
+
+    #[test]
+    fn limit_writer_test() {
+        let data: Vec<u8> = (0..100).collect();
+
+        // limit larger than data
+        let mut outbuf: Vec<u8> = Vec::new();
+        let mut lim = LimitWriter::new(&mut outbuf, 150, "foo".into());
+        lim.write_all(&data).unwrap();
+        lim.flush().unwrap();
+        assert_eq!(data, outbuf);
+
+        // limit exactly equal to data
+        let mut outbuf: Vec<u8> = Vec::new();
+        let mut lim = LimitWriter::new(&mut outbuf, 100, "foo".into());
+        lim.write_all(&data).unwrap();
+        lim.flush().unwrap();
+        assert_eq!(data, outbuf);
+
+        // limit smaller than data
+        let mut outbuf: Vec<u8> = Vec::new();
+        let mut lim = LimitWriter::new(&mut outbuf, 90, "foo".into());
+        assert_eq!(
+            lim.write_all(&data).unwrap_err().to_string(),
+            "collision with foo at offset 90"
+        );
+
+        // directly test writing in multiple chunks
+        let mut outbuf: Vec<u8> = Vec::new();
+        let mut lim = LimitWriter::new(&mut outbuf, 90, "foo".into());
+        assert_eq!(lim.write(&data[0..60]).unwrap(), 60);
+        assert_eq!(lim.write(&data[60..100]).unwrap(), 30); // short write
+        assert_eq!(
+            lim.write(&data[90..100]).unwrap_err().to_string(),
+            "collision with foo at offset 90"
+        );
+        assert_eq!(lim.write(&data[0..0]).unwrap(), 0);
+        assert_eq!(&data[0..90], &outbuf);
+    }
 }
diff --git a/src/iso9660.rs b/src/iso9660.rs
index 623219ef1..b4a4ed91a 100644
--- a/src/iso9660.rs
+++ b/src/iso9660.rs
@@ -27,14 +27,14 @@
 // straightforward to see to what they correspond using the referenced linked above.
 
 use std::fs;
-use std::io::{BufReader, Read, Seek, SeekFrom};
+use std::io::{BufReader, Read, Seek, SeekFrom, Write};
 use std::path::{Path, PathBuf};
 
 use anyhow::{anyhow, bail, Context, Result};
 use bytes::{Buf, Bytes};
-use serde::Serialize;
+use serde::{Deserialize, Serialize};
 
-use crate::io::BUFFER_SIZE;
+use crate::io::*;
 
 // technically the standard supports others, but this is the only one we support
 const ISO9660_SECTOR_SIZE: usize = 2048;
@@ -132,7 +132,7 @@ impl IsoFs {
     /// Returns a reader for a file record.
     pub fn read_file(&mut self, file: &File) -> Result<impl Read + '_> {
         self.file
-            .seek(SeekFrom::Start(file.address))
+            .seek(SeekFrom::Start(file.address.as_offset()))
             .with_context(|| format!("seeking to file {}", file.name))?;
         Ok(BufReader::with_capacity(
             BUFFER_SIZE,
@@ -140,6 +140,18 @@ impl IsoFs {
         ))
     }
 
+    /// Returns a writer for a file record.
+    pub fn overwrite_file(&mut self, file: &File) -> Result<impl Write + '_> {
+        self.file
+            .seek(SeekFrom::Start(file.address.as_offset()))
+            .with_context(|| format!("seeking to file {}", file.name))?;
+        Ok(LimitWriter::new(
+            &mut self.file,
+            file.length as u64,
+            format!("end of file {}", file.name),
+        ))
+    }
+
     fn get_primary_volume_descriptor(&self) -> Result<&PrimaryVolumeDescriptor> {
         for d in &self.descriptors {
             if let VolumeDescriptor::Primary(p) = d {
@@ -198,17 +210,30 @@ impl DirectoryRecord {
 #[derive(Debug, Serialize, Clone)]
 pub struct Directory {
     pub name: String,
-    pub address: u64,
+    pub address: Address,
     pub length: u32,
 }
 
 #[derive(Debug, Serialize, Clone)]
 pub struct File {
     pub name: String,
-    pub address: u64,
+    pub address: Address,
     pub length: u32,
 }
 
+#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
+pub struct Address(u32);
+
+impl Address {
+    pub fn as_offset(&self) -> u64 {
+        self.0 as u64 * ISO9660_SECTOR_SIZE as u64
+    }
+
+    pub fn as_sector(&self) -> u32 {
+        self.0
+    }
+}
+
 /// Requested path was not found.
 #[derive(Debug, thiserror::Error)]
 #[error("{0}")]
@@ -216,8 +241,8 @@ pub struct NotFound(String);
 
 /// Reads all the volume descriptors.
 fn get_volume_descriptors(f: &mut fs::File) -> Result<Vec<VolumeDescriptor>> {
-    const ISO9660_VOLUME_DESCRIPTORS: u64 = 0x10 * (ISO9660_SECTOR_SIZE as u64);
-    f.seek(SeekFrom::Start(ISO9660_VOLUME_DESCRIPTORS))
+    const ISO9660_VOLUME_DESCRIPTORS: Address = Address(0x10);
+    f.seek(SeekFrom::Start(ISO9660_VOLUME_DESCRIPTORS.as_offset()))
         .context("seeking to volume descriptors")?;
 
     let mut descriptors: Vec<VolumeDescriptor> = Vec::new();
@@ -310,7 +335,7 @@ pub struct IsoFsIterator {
 
 impl IsoFsIterator {
     fn new(iso: &mut fs::File, dir: &Directory) -> Result<Self> {
-        iso.seek(SeekFrom::Start(dir.address))
+        iso.seek(SeekFrom::Start(dir.address.as_offset()))
             .with_context(|| format!("seeking to directory {}", dir.name))?;
 
         let mut buf = vec![0; dir.length as usize];
@@ -409,7 +434,7 @@ fn get_next_directory_record(
         bail!("incomplete directory record; corrupt ISO?");
     }
 
-    let address = (eat(buf, 1).get_u32_le() as u64) * (ISO9660_SECTOR_SIZE as u64);
+    let address = Address(eat(buf, 1).get_u32_le());
     let length = eat(buf, 4).get_u32_le();
     let flags = eat(buf, 25 - 14).get_u8();
     let name_length = eat(buf, 32 - 26).get_u8() as usize;
diff --git a/src/live.rs b/src/live.rs
index ed00f539a..6677ea362 100644
--- a/src/live.rs
+++ b/src/live.rs
@@ -697,21 +697,26 @@ pub fn iso_extract_pxe(config: &IsoExtractPxeConfig) -> Result<()> {
                 s
             };
             let path = Path::new(&config.output_dir).join(&filename);
-            let mut outf = OpenOptions::new()
-                .write(true)
-                .create_new(true)
-                .open(&path)
-                .with_context(|| format!("opening {}", path.display()))?;
-            let mut bufw = BufWriter::with_capacity(BUFFER_SIZE, &mut outf);
             println!("{}", path.display());
-            copy(&mut iso.read_file(&file)?, &mut bufw)?;
-            bufw.flush()?;
+            copy_file_from_iso(&mut iso, &file, &path)?;
            }
         }
     }
     Ok(())
 }
 
+fn copy_file_from_iso(iso: &mut IsoFs, file: &iso9660::File, output_path: &Path) -> Result<()> {
+    let mut outf = OpenOptions::new()
+        .write(true)
+        .create_new(true)
+        .open(&output_path)
+        .with_context(|| format!("opening {}", output_path.display()))?;
+    let mut bufw = BufWriter::with_capacity(BUFFER_SIZE, &mut outf);
+    copy(&mut iso.read_file(file)?, &mut bufw)?;
+    bufw.flush()?;
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/src/osmet/file.rs b/src/osmet/file.rs
index 9fd1d7ac2..6acf85873 100644
--- a/src/osmet/file.rs
+++ b/src/osmet/file.rs
@@ -22,7 +22,7 @@ use serde::{Deserialize, Serialize};
 use structopt::clap::crate_version;
 use xz2::read::XzDecoder;
 
-use crate::io::BUFFER_SIZE;
+use crate::io::{bincoder, BUFFER_SIZE};
 
 use super::*;
 
@@ -85,10 +85,11 @@ pub(super) fn osmet_file_write(
             .tempfile_in(path.parent().unwrap())?,
     );
 
-    bincoder()
+    let coder = &mut bincoder();
+    coder
         .serialize_into(&mut f, &header)
         .context("failed to serialize osmet file header")?;
-    bincoder()
+    coder
         .serialize_into(&mut f, &osmet)
         .context("failed to serialize osmet")?;
 
@@ -201,12 +202,3 @@ fn verify_canonical(mappings: &[Mapping]) -> Result<u64> {
 
     Ok(cursor)
 }
-
-fn bincoder() -> impl bincode::Options {
-    bincode::options()
-        .allow_trailing_bytes()
-        // make the defaults explicit
-        .with_no_limit()
-        .with_little_endian()
-        .with_varint_encoding()
-}