Skip to content

Commit

Permalink
fs: add set_max_buf_size to tokio::fs::File (#6411)
Browse files Browse the repository at this point in the history
  • Loading branch information
mox692 authored Mar 22, 2024
1 parent bb25a06 commit f9d78fb
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 25 deletions.
38 changes: 34 additions & 4 deletions tokio/src/fs/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! [`File`]: File
use crate::fs::{asyncify, OpenOptions};
use crate::io::blocking::Buf;
use crate::io::blocking::{Buf, DEFAULT_MAX_BUF_SIZE};
use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
use crate::sync::Mutex;

Expand Down Expand Up @@ -90,6 +90,7 @@ use std::fs::File as StdFile;
pub struct File {
std: Arc<StdFile>,
inner: Mutex<Inner>,
max_buf_size: usize,
}

struct Inner {
Expand Down Expand Up @@ -241,6 +242,7 @@ impl File {
last_write_err: None,
pos: 0,
}),
max_buf_size: DEFAULT_MAX_BUF_SIZE,
}
}

Expand Down Expand Up @@ -508,6 +510,34 @@ impl File {
let std = self.std.clone();
asyncify(move || std.set_permissions(perm)).await
}

/// Set the maximum buffer size for the underlying [`AsyncRead`] / [`AsyncWrite`] operation.
///
/// Although Tokio uses a sensible default value for this buffer size, this function would be
/// useful for changing that default depending on the situation.
///
/// # Examples
///
/// ```no_run
/// use tokio::fs::File;
/// use tokio::io::AsyncWriteExt;
///
/// # async fn dox() -> std::io::Result<()> {
/// let mut file = File::open("foo.txt").await?;
///
/// // Set maximum buffer size to 8 MiB
/// file.set_max_buf_size(8 * 1024 * 1024);
///
/// let mut buf = vec![1; 1024 * 1024 * 1024];
///
/// // Write the 1 GiB buffer in chunks up to 8 MiB each.
/// file.write_all(&mut buf).await?;
/// # Ok(())
/// # }
/// ```
pub fn set_max_buf_size(&mut self, max_buf_size: usize) {
self.max_buf_size = max_buf_size;
}
}

impl AsyncRead for File {
Expand All @@ -531,7 +561,7 @@ impl AsyncRead for File {
return Poll::Ready(Ok(()));
}

buf.ensure_capacity_for(dst);
buf.ensure_capacity_for(dst, me.max_buf_size);
let std = me.std.clone();

inner.state = State::Busy(spawn_blocking(move || {
Expand Down Expand Up @@ -668,7 +698,7 @@ impl AsyncWrite for File {
None
};

let n = buf.copy_from(src);
let n = buf.copy_from(src, me.max_buf_size);
let std = me.std.clone();

let blocking_task_join_handle = spawn_mandatory_blocking(move || {
Expand Down Expand Up @@ -739,7 +769,7 @@ impl AsyncWrite for File {
None
};

let n = buf.copy_from_bufs(bufs);
let n = buf.copy_from_bufs(bufs, me.max_buf_size);
let std = me.std.clone();

let blocking_task_join_handle = spawn_mandatory_blocking(move || {
Expand Down
4 changes: 2 additions & 2 deletions tokio/src/fs/file/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ fn flush_while_idle() {
#[cfg_attr(miri, ignore)] // takes a really long time with miri
fn read_with_buffer_larger_than_max() {
// Chunks
let chunk_a = crate::io::blocking::MAX_BUF;
let chunk_a = crate::io::blocking::DEFAULT_MAX_BUF_SIZE;
let chunk_b = chunk_a * 2;
let chunk_c = chunk_a * 3;
let chunk_d = chunk_a * 4;
Expand Down Expand Up @@ -303,7 +303,7 @@ fn read_with_buffer_larger_than_max() {
#[cfg_attr(miri, ignore)] // takes a really long time with miri
fn write_with_buffer_larger_than_max() {
// Chunks
let chunk_a = crate::io::blocking::MAX_BUF;
let chunk_a = crate::io::blocking::DEFAULT_MAX_BUF_SIZE;
let chunk_b = chunk_a * 2;
let chunk_c = chunk_a * 3;
let chunk_d = chunk_a * 4;
Expand Down
1 change: 1 addition & 0 deletions tokio/src/fs/mocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ mock! {
pub fn open(pb: PathBuf) -> io::Result<Self>;
pub fn set_len(&self, size: u64) -> io::Result<()>;
pub fn set_permissions(&self, _perm: Permissions) -> io::Result<()>;
pub fn set_max_buf_size(&self, max_buf_size: usize);
pub fn sync_all(&self) -> io::Result<()>;
pub fn sync_data(&self) -> io::Result<()>;
pub fn try_clone(&self) -> io::Result<Self>;
Expand Down
20 changes: 10 additions & 10 deletions tokio/src/io/blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub(crate) struct Buf {
pos: usize,
}

pub(crate) const MAX_BUF: usize = 2 * 1024 * 1024;
pub(crate) const DEFAULT_MAX_BUF_SIZE: usize = 2 * 1024 * 1024;

#[derive(Debug)]
enum State<T> {
Expand Down Expand Up @@ -64,7 +64,7 @@ where
return Poll::Ready(Ok(()));
}

buf.ensure_capacity_for(dst);
buf.ensure_capacity_for(dst, DEFAULT_MAX_BUF_SIZE);
let mut inner = self.inner.take().unwrap();

self.state = State::Busy(sys::run(move || {
Expand Down Expand Up @@ -111,7 +111,7 @@ where

assert!(buf.is_empty());

let n = buf.copy_from(src);
let n = buf.copy_from(src, DEFAULT_MAX_BUF_SIZE);
let mut inner = self.inner.take().unwrap();

self.state = State::Busy(sys::run(move || {
Expand Down Expand Up @@ -214,10 +214,10 @@ impl Buf {
n
}

pub(crate) fn copy_from(&mut self, src: &[u8]) -> usize {
pub(crate) fn copy_from(&mut self, src: &[u8], max_buf_size: usize) -> usize {
assert!(self.is_empty());

let n = cmp::min(src.len(), MAX_BUF);
let n = cmp::min(src.len(), max_buf_size);

self.buf.extend_from_slice(&src[..n]);
n
Expand All @@ -227,10 +227,10 @@ impl Buf {
&self.buf[self.pos..]
}

pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>) {
pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>, max_buf_size: usize) {
assert!(self.is_empty());

let len = cmp::min(bytes.remaining(), MAX_BUF);
let len = cmp::min(bytes.remaining(), max_buf_size);

if self.buf.len() < len {
self.buf.reserve(len - self.buf.len());
Expand Down Expand Up @@ -274,10 +274,10 @@ cfg_fs! {
ret
}

pub(crate) fn copy_from_bufs(&mut self, bufs: &[io::IoSlice<'_>]) -> usize {
pub(crate) fn copy_from_bufs(&mut self, bufs: &[io::IoSlice<'_>], max_buf_size: usize) -> usize {
assert!(self.is_empty());

let mut rem = MAX_BUF;
let mut rem = max_buf_size;
for buf in bufs {
if rem == 0 {
break
Expand All @@ -288,7 +288,7 @@ cfg_fs! {
rem -= len;
}

MAX_BUF - rem
max_buf_size - rem
}
}
}
19 changes: 10 additions & 9 deletions tokio/src/io/stdio_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::pin::Pin;
use std::task::{Context, Poll};
/// # Windows
/// [`AsyncWrite`] adapter that finds last char boundary in given buffer and does not write the rest,
/// if buffer contents seems to be `utf8`. Otherwise it only trims buffer down to `MAX_BUF`.
/// if buffer contents seems to be `utf8`. Otherwise it only trims buffer down to `DEFAULT_MAX_BUF_SIZE`.
/// That's why, wrapped writer will always receive well-formed utf-8 bytes.
/// # Other platforms
/// Passes data to `inner` as is.
Expand Down Expand Up @@ -45,12 +45,13 @@ where
// 2. If buffer is small, it will not be shrunk.
// That's why, it's "textness" will not change, so we don't have
// to fixup it.
if cfg!(not(any(target_os = "windows", test))) || buf.len() <= crate::io::blocking::MAX_BUF
if cfg!(not(any(target_os = "windows", test)))
|| buf.len() <= crate::io::blocking::DEFAULT_MAX_BUF_SIZE
{
return call_inner(buf);
}

buf = &buf[..crate::io::blocking::MAX_BUF];
buf = &buf[..crate::io::blocking::DEFAULT_MAX_BUF_SIZE];

// Now there are two possibilities.
// If caller gave is binary buffer, we **should not** shrink it
Expand Down Expand Up @@ -108,7 +109,7 @@ where
#[cfg(test)]
#[cfg(not(loom))]
mod tests {
use crate::io::blocking::MAX_BUF;
use crate::io::blocking::DEFAULT_MAX_BUF_SIZE;
use crate::io::AsyncWriteExt;
use std::io;
use std::pin::Pin;
Expand All @@ -123,7 +124,7 @@ mod tests {
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
assert!(buf.len() <= MAX_BUF);
assert!(buf.len() <= DEFAULT_MAX_BUF_SIZE);
assert!(std::str::from_utf8(buf).is_ok());
Poll::Ready(Ok(buf.len()))
}
Expand Down Expand Up @@ -158,7 +159,7 @@ mod tests {
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
assert!(buf.len() <= MAX_BUF);
assert!(buf.len() <= DEFAULT_MAX_BUF_SIZE);
self.write_history.push(buf.len());
Poll::Ready(Ok(buf.len()))
}
Expand All @@ -178,7 +179,7 @@ mod tests {
#[test]
#[cfg_attr(miri, ignore)]
fn test_splitter() {
let data = str::repeat("█", MAX_BUF);
let data = str::repeat("█", DEFAULT_MAX_BUF_SIZE);
let mut wr = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter);
let fut = async move {
wr.write_all(data.as_bytes()).await.unwrap();
Expand All @@ -197,7 +198,7 @@ mod tests {
// was not shrunk too much.
let checked_count = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR;
let mut data: Vec<u8> = str::repeat("a", checked_count).into();
data.extend(std::iter::repeat(0b1010_1010).take(MAX_BUF - checked_count + 1));
data.extend(std::iter::repeat(0b1010_1010).take(DEFAULT_MAX_BUF_SIZE - checked_count + 1));
let mut writer = LoggingMockWriter::new();
let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer);
crate::runtime::Builder::new_current_thread()
Expand All @@ -214,7 +215,7 @@ mod tests {
data.len()
);
// Check that at most MAX_BYTES_PER_CHAR + 1 (i.e. 5) bytes were shrunk
// from the buffer: one because it was outside of MAX_BUF boundary, and
// from the buffer: one because it was outside of DEFAULT_MAX_BUF_SIZE boundary, and
// up to one "utf8 code point".
assert!(data.len() - writer.write_history[0] <= super::MAX_BYTES_PER_CHAR + 1);
}
Expand Down
22 changes: 22 additions & 0 deletions tokio/tests/fs_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,28 @@ fn tempfile() -> NamedTempFile {
NamedTempFile::new().unwrap()
}

#[tokio::test]
async fn set_max_buf_size_read() {
let mut tempfile = tempfile();
tempfile.write_all(HELLO).unwrap();
let mut file = File::open(tempfile.path()).await.unwrap();
let mut buf = [0; 1024];
file.set_max_buf_size(1);

// A single read operation reads a maximum of 1 byte.
assert_eq!(file.read(&mut buf).await.unwrap(), 1);
}

#[tokio::test]
async fn set_max_buf_size_write() {
let tempfile = tempfile();
let mut file = File::create(tempfile.path()).await.unwrap();
file.set_max_buf_size(1);

// A single write operation writes a maximum of 1 byte.
assert_eq!(file.write(HELLO).await.unwrap(), 1);
}

#[tokio::test]
#[cfg(unix)]
async fn file_debug_fmt() {
Expand Down

0 comments on commit f9d78fb

Please sign in to comment.