diff --git a/tokio/src/fs/file.rs b/tokio/src/fs/file.rs index 033b2982c7d..efce9fda990 100644 --- a/tokio/src/fs/file.rs +++ b/tokio/src/fs/file.rs @@ -3,7 +3,7 @@ //! [`File`]: File use crate::fs::{asyncify, OpenOptions}; -use crate::io::blocking::Buf; +use crate::io::blocking::{Buf, DEFAULT_MAX_BUF_SIZE}; use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use crate::sync::Mutex; @@ -90,6 +90,7 @@ use std::fs::File as StdFile; pub struct File { std: Arc, inner: Mutex, + max_buf_size: usize, } struct Inner { @@ -241,6 +242,7 @@ impl File { last_write_err: None, pos: 0, }), + max_buf_size: DEFAULT_MAX_BUF_SIZE, } } @@ -508,6 +510,34 @@ impl File { let std = self.std.clone(); asyncify(move || std.set_permissions(perm)).await } + + /// Sets the maximum buffer size for the underlying [`AsyncRead`] / [`AsyncWrite`] operation. + /// + /// Although Tokio uses a sensible default value for this buffer size, this function is + /// useful for changing that default depending on the situation. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::File; + /// use tokio::io::AsyncWriteExt; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut file = File::create("foo.txt").await?; + /// + /// // Set maximum buffer size to 8 MiB + /// file.set_max_buf_size(8 * 1024 * 1024); + /// + /// let mut buf = vec![1; 1024 * 1024 * 1024]; + /// + /// // Write the 1 GiB buffer in chunks up to 8 MiB each. 
+ /// file.write_all(&mut buf).await?; + /// # Ok(()) + /// # } + /// ``` + pub fn set_max_buf_size(&mut self, max_buf_size: usize) { + self.max_buf_size = max_buf_size; + } } impl AsyncRead for File { @@ -531,7 +561,7 @@ impl AsyncRead for File { return Poll::Ready(Ok(())); } - buf.ensure_capacity_for(dst); + buf.ensure_capacity_for(dst, me.max_buf_size); let std = me.std.clone(); inner.state = State::Busy(spawn_blocking(move || { @@ -668,7 +698,7 @@ impl AsyncWrite for File { None }; - let n = buf.copy_from(src); + let n = buf.copy_from(src, me.max_buf_size); let std = me.std.clone(); let blocking_task_join_handle = spawn_mandatory_blocking(move || { @@ -739,7 +769,7 @@ impl AsyncWrite for File { None }; - let n = buf.copy_from_bufs(bufs); + let n = buf.copy_from_bufs(bufs, me.max_buf_size); let std = me.std.clone(); let blocking_task_join_handle = spawn_mandatory_blocking(move || { diff --git a/tokio/src/fs/file/tests.rs b/tokio/src/fs/file/tests.rs index 7c61b3c4b31..e824876c131 100644 --- a/tokio/src/fs/file/tests.rs +++ b/tokio/src/fs/file/tests.rs @@ -231,7 +231,7 @@ fn flush_while_idle() { #[cfg_attr(miri, ignore)] // takes a really long time with miri fn read_with_buffer_larger_than_max() { // Chunks - let chunk_a = crate::io::blocking::MAX_BUF; + let chunk_a = crate::io::blocking::DEFAULT_MAX_BUF_SIZE; let chunk_b = chunk_a * 2; let chunk_c = chunk_a * 3; let chunk_d = chunk_a * 4; @@ -303,7 +303,7 @@ fn read_with_buffer_larger_than_max() { #[cfg_attr(miri, ignore)] // takes a really long time with miri fn write_with_buffer_larger_than_max() { // Chunks - let chunk_a = crate::io::blocking::MAX_BUF; + let chunk_a = crate::io::blocking::DEFAULT_MAX_BUF_SIZE; let chunk_b = chunk_a * 2; let chunk_c = chunk_a * 3; let chunk_d = chunk_a * 4; diff --git a/tokio/src/fs/mocks.rs b/tokio/src/fs/mocks.rs index b718ed54f95..a2ce1cd6ca3 100644 --- a/tokio/src/fs/mocks.rs +++ b/tokio/src/fs/mocks.rs @@ -30,6 +30,7 @@ mock! 
{ pub fn open(pb: PathBuf) -> io::Result; pub fn set_len(&self, size: u64) -> io::Result<()>; pub fn set_permissions(&self, _perm: Permissions) -> io::Result<()>; + pub fn set_max_buf_size(&self, max_buf_size: usize); pub fn sync_all(&self) -> io::Result<()>; pub fn sync_data(&self) -> io::Result<()>; pub fn try_clone(&self) -> io::Result; diff --git a/tokio/src/io/blocking.rs b/tokio/src/io/blocking.rs index b5d7dca2b5c..52aa798c4fe 100644 --- a/tokio/src/io/blocking.rs +++ b/tokio/src/io/blocking.rs @@ -23,7 +23,7 @@ pub(crate) struct Buf { pos: usize, } -pub(crate) const MAX_BUF: usize = 2 * 1024 * 1024; +pub(crate) const DEFAULT_MAX_BUF_SIZE: usize = 2 * 1024 * 1024; #[derive(Debug)] enum State { @@ -64,7 +64,7 @@ where return Poll::Ready(Ok(())); } - buf.ensure_capacity_for(dst); + buf.ensure_capacity_for(dst, DEFAULT_MAX_BUF_SIZE); let mut inner = self.inner.take().unwrap(); self.state = State::Busy(sys::run(move || { @@ -111,7 +111,7 @@ where assert!(buf.is_empty()); - let n = buf.copy_from(src); + let n = buf.copy_from(src, DEFAULT_MAX_BUF_SIZE); let mut inner = self.inner.take().unwrap(); self.state = State::Busy(sys::run(move || { @@ -214,10 +214,10 @@ impl Buf { n } - pub(crate) fn copy_from(&mut self, src: &[u8]) -> usize { + pub(crate) fn copy_from(&mut self, src: &[u8], max_buf_size: usize) -> usize { assert!(self.is_empty()); - let n = cmp::min(src.len(), MAX_BUF); + let n = cmp::min(src.len(), max_buf_size); self.buf.extend_from_slice(&src[..n]); n @@ -227,10 +227,10 @@ impl Buf { &self.buf[self.pos..] } - pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>) { + pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>, max_buf_size: usize) { assert!(self.is_empty()); - let len = cmp::min(bytes.remaining(), MAX_BUF); + let len = cmp::min(bytes.remaining(), max_buf_size); if self.buf.len() < len { self.buf.reserve(len - self.buf.len()); @@ -274,10 +274,10 @@ cfg_fs! 
{ ret } - pub(crate) fn copy_from_bufs(&mut self, bufs: &[io::IoSlice<'_>]) -> usize { + pub(crate) fn copy_from_bufs(&mut self, bufs: &[io::IoSlice<'_>], max_buf_size: usize) -> usize { assert!(self.is_empty()); - let mut rem = MAX_BUF; + let mut rem = max_buf_size; for buf in bufs { if rem == 0 { break @@ -288,7 +288,7 @@ cfg_fs! { rem -= len; } - MAX_BUF - rem + max_buf_size - rem } } } diff --git a/tokio/src/io/stdio_common.rs b/tokio/src/io/stdio_common.rs index c32b889e582..4adbfe23606 100644 --- a/tokio/src/io/stdio_common.rs +++ b/tokio/src/io/stdio_common.rs @@ -4,7 +4,7 @@ use std::pin::Pin; use std::task::{Context, Poll}; /// # Windows /// [`AsyncWrite`] adapter that finds last char boundary in given buffer and does not write the rest, -/// if buffer contents seems to be `utf8`. Otherwise it only trims buffer down to `MAX_BUF`. +/// if buffer contents seems to be `utf8`. Otherwise it only trims buffer down to `DEFAULT_MAX_BUF_SIZE`. /// That's why, wrapped writer will always receive well-formed utf-8 bytes. /// # Other platforms /// Passes data to `inner` as is. @@ -45,12 +45,13 @@ where // 2. If buffer is small, it will not be shrunk. // That's why, it's "textness" will not change, so we don't have // to fixup it. - if cfg!(not(any(target_os = "windows", test))) || buf.len() <= crate::io::blocking::MAX_BUF + if cfg!(not(any(target_os = "windows", test))) + || buf.len() <= crate::io::blocking::DEFAULT_MAX_BUF_SIZE { return call_inner(buf); } - buf = &buf[..crate::io::blocking::MAX_BUF]; + buf = &buf[..crate::io::blocking::DEFAULT_MAX_BUF_SIZE]; // Now there are two possibilities. 
// If caller gave is binary buffer, we **should not** shrink it @@ -108,7 +109,7 @@ where #[cfg(test)] #[cfg(not(loom))] mod tests { - use crate::io::blocking::MAX_BUF; + use crate::io::blocking::DEFAULT_MAX_BUF_SIZE; use crate::io::AsyncWriteExt; use std::io; use std::pin::Pin; @@ -123,7 +124,7 @@ mod tests { _cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - assert!(buf.len() <= MAX_BUF); + assert!(buf.len() <= DEFAULT_MAX_BUF_SIZE); assert!(std::str::from_utf8(buf).is_ok()); Poll::Ready(Ok(buf.len())) } @@ -158,7 +159,7 @@ mod tests { _cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - assert!(buf.len() <= MAX_BUF); + assert!(buf.len() <= DEFAULT_MAX_BUF_SIZE); self.write_history.push(buf.len()); Poll::Ready(Ok(buf.len())) } @@ -178,7 +179,7 @@ mod tests { #[test] #[cfg_attr(miri, ignore)] fn test_splitter() { - let data = str::repeat("█", MAX_BUF); + let data = str::repeat("█", DEFAULT_MAX_BUF_SIZE); let mut wr = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter); let fut = async move { wr.write_all(data.as_bytes()).await.unwrap(); @@ -197,7 +198,7 @@ mod tests { // was not shrunk too much. let checked_count = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR; let mut data: Vec = str::repeat("a", checked_count).into(); - data.extend(std::iter::repeat(0b1010_1010).take(MAX_BUF - checked_count + 1)); + data.extend(std::iter::repeat(0b1010_1010).take(DEFAULT_MAX_BUF_SIZE - checked_count + 1)); let mut writer = LoggingMockWriter::new(); let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer); crate::runtime::Builder::new_current_thread() @@ -214,7 +215,7 @@ mod tests { data.len() ); // Check that at most MAX_BYTES_PER_CHAR + 1 (i.e. 5) bytes were shrunk - // from the buffer: one because it was outside of MAX_BUF boundary, and + // from the buffer: one because it was outside of DEFAULT_MAX_BUF_SIZE boundary, and // up to one "utf8 code point". 
assert!(data.len() - writer.write_history[0] <= super::MAX_BYTES_PER_CHAR + 1); } diff --git a/tokio/tests/fs_file.rs b/tokio/tests/fs_file.rs index 6a8b07a7ffe..520c4ec8438 100644 --- a/tokio/tests/fs_file.rs +++ b/tokio/tests/fs_file.rs @@ -180,6 +180,28 @@ fn tempfile() -> NamedTempFile { NamedTempFile::new().unwrap() } +#[tokio::test] +async fn set_max_buf_size_read() { + let mut tempfile = tempfile(); + tempfile.write_all(HELLO).unwrap(); + let mut file = File::open(tempfile.path()).await.unwrap(); + let mut buf = [0; 1024]; + file.set_max_buf_size(1); + + // A single read operation reads a maximum of 1 byte. + assert_eq!(file.read(&mut buf).await.unwrap(), 1); +} + +#[tokio::test] +async fn set_max_buf_size_write() { + let tempfile = tempfile(); + let mut file = File::create(tempfile.path()).await.unwrap(); + file.set_max_buf_size(1); + + // A single write operation writes a maximum of 1 byte. + assert_eq!(file.write(HELLO).await.unwrap(), 1); +} + #[tokio::test] #[cfg(unix)] async fn file_debug_fmt() {