diff --git a/src/spooled.rs b/src/spooled.rs index 2c8eaa4e5..9b2e7e7b2 100644 --- a/src/spooled.rs +++ b/src/spooled.rs @@ -96,7 +96,7 @@ impl SpooledTempFile { } pub fn set_len(&mut self, size: u64) -> Result<(), io::Error> { - if size as usize > self.max_size { + if size > self.max_size as u64 { self.roll()?; // does nothing if already rolled over } match &mut self.inner { @@ -157,7 +157,7 @@ impl Write for SpooledTempFile { // roll over to file if necessary if matches! { &self.inner, SpooledData::InMemory(cursor) - if cursor.position() as usize + buf.len() > self.max_size + if cursor.position().saturating_add(buf.len() as u64) > self.max_size as u64 } { self.roll()?; } @@ -173,8 +173,10 @@ impl Write for SpooledTempFile { if matches! { &self.inner, SpooledData::InMemory(cursor) // Borrowed from the rust standard library. - if cursor.position() as usize + bufs.iter() - .fold(0usize, |a, b| a.saturating_add(b.len())) > self.max_size + if bufs + .iter() + .fold(cursor.position(), |a, b| a.saturating_add(b.len() as u64)) + > self.max_size as u64 } { self.roll()?; } diff --git a/src/util.rs b/src/util.rs index 8c04953a3..e91150f0a 100644 --- a/src/util.rs +++ b/src/util.rs @@ -5,7 +5,11 @@ use std::{io, iter::repeat_with}; use crate::error::IoResultExt; fn tmpname(prefix: &OsStr, suffix: &OsStr, rand_len: usize) -> OsString { - let mut buf = OsString::with_capacity(prefix.len() + suffix.len() + rand_len); + let capacity = prefix + .len() + .saturating_add(suffix.len()) + .saturating_add(rand_len); + let mut buf = OsString::with_capacity(capacity); buf.push(prefix); let mut char_buf = [0u8; 4]; for c in repeat_with(fastrand::alphanumeric).take(rand_len) { diff --git a/tests/spooled.rs b/tests/spooled.rs index a3bcc03d8..406021436 100644 --- a/tests/spooled.rs +++ b/tests/spooled.rs @@ -305,3 +305,18 @@ fn test_set_len_rollover() { assert_eq!(t.read_to_end(&mut buf).unwrap(), 20); assert_eq!(buf.as_slice(), b"abcde\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"); } + +#[test] +fn test_write_overflow() { + let mut t = spooled_tempfile(10); + t.seek(SeekFrom::Start(u64::MAX)).unwrap(); + assert!(t.write(b"abcde").is_err()); +} + +#[cfg(target_pointer_width = "32")] +#[test] +fn test_set_len_truncation() { + let mut t = spooled_tempfile(100); + assert!(t.set_len(usize::MAX as u64 + 5).is_ok()); + assert!(t.is_rolled()); +}