|
| 1 | +use futures_core::future::Future; |
| 2 | +use futures_core::task::{Context, Poll}; |
| 3 | +use futures_io::AsyncWrite; |
| 4 | +use futures_io::IoSlice; |
| 5 | +use std::io; |
| 6 | +use std::mem; |
| 7 | +use std::pin::Pin; |
| 8 | + |
| 9 | +/// Future for the |
| 10 | +/// [`write_all_vectored`](super::AsyncWriteExt::write_all_vectored) method. |
| 11 | +#[derive(Debug)] |
| 12 | +#[must_use = "futures do nothing unless you `.await` or poll them"] |
| 13 | +pub struct WriteAllVectored<'a, W: ?Sized + Unpin> { |
| 14 | + writer: &'a mut W, |
| 15 | + bufs: &'a mut [IoSlice<'a>], |
| 16 | +} |
| 17 | + |
| 18 | +impl<W: ?Sized + Unpin> Unpin for WriteAllVectored<'_, W> {} |
| 19 | + |
| 20 | +impl<'a, W: AsyncWrite + ?Sized + Unpin> WriteAllVectored<'a, W> { |
| 21 | + pub(super) fn new(writer: &'a mut W, bufs: &'a mut [IoSlice<'a>]) -> Self { |
| 22 | + WriteAllVectored { writer, bufs } |
| 23 | + } |
| 24 | +} |
| 25 | + |
| 26 | +impl<W: AsyncWrite + ?Sized + Unpin> Future for WriteAllVectored<'_, W> { |
| 27 | + type Output = io::Result<()>; |
| 28 | + |
| 29 | + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| 30 | + let this = &mut *self; |
| 31 | + while !this.bufs.is_empty() { |
| 32 | + let n = ready!(Pin::new(&mut this.writer).poll_write_vectored(cx, this.bufs))?; |
| 33 | + if n == 0 { |
| 34 | + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); |
| 35 | + } else { |
| 36 | + this.bufs = IoSlice::advance(mem::take(&mut this.bufs), n); |
| 37 | + } |
| 38 | + } |
| 39 | + |
| 40 | + Poll::Ready(Ok(())) |
| 41 | + } |
| 42 | +} |
| 43 | + |
| 44 | +#[cfg(test)] |
| 45 | +mod tests { |
| 46 | + use std::cmp::min; |
| 47 | + use std::future::Future; |
| 48 | + use std::io; |
| 49 | + use std::pin::Pin; |
| 50 | + use std::task::{Context, Poll}; |
| 51 | + |
| 52 | + use crate::io::{AsyncWrite, AsyncWriteExt, IoSlice}; |
| 53 | + use crate::task::noop_waker; |
| 54 | + |
| 55 | + /// Create a new writer that reads from at most `n_bufs` and reads |
| 56 | + /// `per_call` bytes (in total) per call to write. |
| 57 | + fn test_writer(n_bufs: usize, per_call: usize) -> TestWriter { |
| 58 | + TestWriter { |
| 59 | + n_bufs, |
| 60 | + per_call, |
| 61 | + written: Vec::new(), |
| 62 | + } |
| 63 | + } |
| 64 | + |
| 65 | + // TODO: maybe move this the future-test crate? |
| 66 | + struct TestWriter { |
| 67 | + n_bufs: usize, |
| 68 | + per_call: usize, |
| 69 | + written: Vec<u8>, |
| 70 | + } |
| 71 | + |
| 72 | + impl AsyncWrite for TestWriter { |
| 73 | + fn poll_write( |
| 74 | + self: Pin<&mut Self>, |
| 75 | + cx: &mut Context<'_>, |
| 76 | + buf: &[u8], |
| 77 | + ) -> Poll<io::Result<usize>> { |
| 78 | + self.poll_write_vectored(cx, &[IoSlice::new(buf)]) |
| 79 | + } |
| 80 | + |
| 81 | + fn poll_write_vectored( |
| 82 | + mut self: Pin<&mut Self>, |
| 83 | + _cx: &mut Context<'_>, |
| 84 | + bufs: &[IoSlice<'_>], |
| 85 | + ) -> Poll<io::Result<usize>> { |
| 86 | + let mut left = self.per_call; |
| 87 | + let mut written = 0; |
| 88 | + for buf in bufs.iter().take(self.n_bufs) { |
| 89 | + let n = min(left, buf.len()); |
| 90 | + self.written.extend_from_slice(&buf[0..n]); |
| 91 | + left -= n; |
| 92 | + written += n; |
| 93 | + } |
| 94 | + Poll::Ready(Ok(written)) |
| 95 | + } |
| 96 | + |
| 97 | + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| 98 | + Poll::Ready(Ok(())) |
| 99 | + } |
| 100 | + |
| 101 | + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| 102 | + Poll::Ready(Ok(())) |
| 103 | + } |
| 104 | + } |
| 105 | + |
| 106 | + // TODO: maybe move this the future-test crate? |
| 107 | + macro_rules! assert_poll_ok { |
| 108 | + ($e:expr, $expected:expr) => { |
| 109 | + let expected = $expected; |
| 110 | + match $e { |
| 111 | + Poll::Ready(Ok(ok)) if ok == expected => {} |
| 112 | + got => panic!( |
| 113 | + "unexpected result, got: {:?}, wanted: Ready(Ok({:?}))", |
| 114 | + got, expected |
| 115 | + ), |
| 116 | + } |
| 117 | + }; |
| 118 | + } |
| 119 | + |
| 120 | + #[test] |
| 121 | + fn test_writer_read_from_one_buf() { |
| 122 | + let waker = noop_waker(); |
| 123 | + let mut cx = Context::from_waker(&waker); |
| 124 | + |
| 125 | + let mut dst = test_writer(1, 2); |
| 126 | + let mut dst = Pin::new(&mut dst); |
| 127 | + |
| 128 | + assert_poll_ok!(dst.as_mut().poll_write(&mut cx, &[]), 0); |
| 129 | + assert_poll_ok!(dst.as_mut().poll_write_vectored(&mut cx, &[]), 0); |
| 130 | + |
| 131 | + // Read at most 2 bytes. |
| 132 | + assert_poll_ok!(dst.as_mut().poll_write(&mut cx, &[1, 1, 1]), 2); |
| 133 | + let bufs = &[IoSlice::new(&[2, 2, 2])]; |
| 134 | + assert_poll_ok!(dst.as_mut().poll_write_vectored(&mut cx, bufs), 2); |
| 135 | + |
| 136 | + // Only read from first buf. |
| 137 | + let bufs = &[IoSlice::new(&[3]), IoSlice::new(&[4, 4])]; |
| 138 | + assert_poll_ok!(dst.as_mut().poll_write_vectored(&mut cx, bufs), 1); |
| 139 | + |
| 140 | + assert_eq!(dst.written, &[1, 1, 2, 2, 3]); |
| 141 | + } |
| 142 | + |
| 143 | + #[test] |
| 144 | + fn test_writer_read_from_multiple_bufs() { |
| 145 | + let waker = noop_waker(); |
| 146 | + let mut cx = Context::from_waker(&waker); |
| 147 | + |
| 148 | + let mut dst = test_writer(3, 3); |
| 149 | + let mut dst = Pin::new(&mut dst); |
| 150 | + |
| 151 | + // Read at most 3 bytes from two buffers. |
| 152 | + let bufs = &[IoSlice::new(&[1]), IoSlice::new(&[2, 2, 2])]; |
| 153 | + assert_poll_ok!(dst.as_mut().poll_write_vectored(&mut cx, bufs), 3); |
| 154 | + |
| 155 | + // Read at most 3 bytes from three buffers. |
| 156 | + let bufs = &[ |
| 157 | + IoSlice::new(&[3]), |
| 158 | + IoSlice::new(&[4]), |
| 159 | + IoSlice::new(&[5, 5]), |
| 160 | + ]; |
| 161 | + assert_poll_ok!(dst.as_mut().poll_write_vectored(&mut cx, bufs), 3); |
| 162 | + |
| 163 | + assert_eq!(dst.written, &[1, 2, 2, 3, 4, 5]); |
| 164 | + } |
| 165 | + |
| 166 | + #[test] |
| 167 | + fn test_write_all_vectored() { |
| 168 | + let waker = noop_waker(); |
| 169 | + let mut cx = Context::from_waker(&waker); |
| 170 | + |
| 171 | + #[rustfmt::skip] // Becomes unreadable otherwise. |
| 172 | + let tests: Vec<(_, &'static [u8])> = vec![ |
| 173 | + (vec![], &[]), |
| 174 | + (vec![IoSlice::new(&[1])], &[1]), |
| 175 | + (vec![IoSlice::new(&[1, 2])], &[1, 2]), |
| 176 | + (vec![IoSlice::new(&[1, 2, 3])], &[1, 2, 3]), |
| 177 | + (vec![IoSlice::new(&[1, 2, 3, 4])], &[1, 2, 3, 4]), |
| 178 | + (vec![IoSlice::new(&[1, 2, 3, 4, 5])], &[1, 2, 3, 4, 5]), |
| 179 | + (vec![IoSlice::new(&[1]), IoSlice::new(&[2])], &[1, 2]), |
| 180 | + (vec![IoSlice::new(&[1, 1]), IoSlice::new(&[2, 2])], &[1, 1, 2, 2]), |
| 181 | + (vec![IoSlice::new(&[1, 1, 1]), IoSlice::new(&[2, 2, 2])], &[1, 1, 1, 2, 2, 2]), |
| 182 | + (vec![IoSlice::new(&[1, 1, 1, 1]), IoSlice::new(&[2, 2, 2, 2])], &[1, 1, 1, 1, 2, 2, 2, 2]), |
| 183 | + (vec![IoSlice::new(&[1]), IoSlice::new(&[2]), IoSlice::new(&[3])], &[1, 2, 3]), |
| 184 | + (vec![IoSlice::new(&[1, 1]), IoSlice::new(&[2, 2]), IoSlice::new(&[3, 3])], &[1, 1, 2, 2, 3, 3]), |
| 185 | + (vec![IoSlice::new(&[1, 1, 1]), IoSlice::new(&[2, 2, 2]), IoSlice::new(&[3, 3, 3])], &[1, 1, 1, 2, 2, 2, 3, 3, 3]), |
| 186 | + ]; |
| 187 | + |
| 188 | + for (mut input, wanted) in tests.into_iter() { |
| 189 | + let mut dst = test_writer(2, 2); |
| 190 | + { |
| 191 | + let mut future = dst.write_all_vectored(&mut *input); |
| 192 | + match Pin::new(&mut future).poll(&mut cx) { |
| 193 | + Poll::Ready(Ok(())) => {} |
| 194 | + other => panic!("unexpected result polling future: {:?}", other), |
| 195 | + } |
| 196 | + } |
| 197 | + assert_eq!(&*dst.written, &*wanted); |
| 198 | + } |
| 199 | + } |
| 200 | +} |
0 commit comments