From b4d524d5de348a34f4f55582edd8e0eef96956e7 Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Fri, 17 Nov 2017 11:49:10 -0800 Subject: [PATCH 1/2] Have `PollEvented` fns take `&mut self` Generally speaking, it is unsafe to access to perform asynchronous operations using `&self`. Taking `&self` allows usage from a `Sync` context, which has unexpected results. Taking `&mut self` to perform these operations prevents using these asynchronous values from across tasks (unless they are wrapped in `RefCell` or `Mutex`. --- examples/proxy.rs | 14 ++++---- src/net/tcp.rs | 71 ++++--------------------------------- src/net/udp/mod.rs | 38 ++++++-------------- src/reactor/poll_evented.rs | 64 +++------------------------------ tests/udp.rs | 16 ++++----- 5 files changed, 37 insertions(+), 166 deletions(-) diff --git a/examples/proxy.rs b/examples/proxy.rs index 14a63a7f92e..9543f9cbba9 100644 --- a/examples/proxy.rs +++ b/examples/proxy.rs @@ -21,7 +21,7 @@ extern crate futures_cpupool; extern crate tokio; extern crate tokio_io; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::env; use std::net::{Shutdown, SocketAddr}; use std::io::{self, Read, Write}; @@ -65,9 +65,9 @@ fn main() { // // As a result, we wrap up our client/server manually in arcs and // use the impls below on our custom `MyTcpStream` type. - let client_reader = MyTcpStream(Arc::new(client)); + let client_reader = MyTcpStream(Arc::new(Mutex::new(client))); let client_writer = client_reader.clone(); - let server_reader = MyTcpStream(Arc::new(server)); + let server_reader = MyTcpStream(Arc::new(Mutex::new(server))); let server_writer = server_reader.clone(); // Copy the data (in parallel) between the client and the server. @@ -104,17 +104,17 @@ fn main() { // `AsyncWrite::shutdown` method which actually calls `TcpStream::shutdown` to // notify the remote end that we're done writing. #[derive(Clone)] -struct MyTcpStream(Arc); +struct MyTcpStream(Arc>); impl Read for MyTcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - (&*self.0).read(buf) + self.0.lock().unwrap().read(buf) } } impl Write for MyTcpStream { fn write(&mut self, buf: &[u8]) -> io::Result { - (&*self.0).write(buf) + self.0.lock().unwrap().write(buf) } fn flush(&mut self) -> io::Result<()> { @@ -126,7 +126,7 @@ impl AsyncRead for MyTcpStream {} impl AsyncWrite for MyTcpStream { fn shutdown(&mut self) -> Poll<(), io::Error> { - try!(self.0.shutdown(Shutdown::Write)); + try!(self.0.lock().unwrap().shutdown(Shutdown::Write)); Ok(().into()) } } diff --git a/src/net/tcp.rs b/src/net/tcp.rs index 16f6e7bfaa8..7eedee2f572 100644 --- a/src/net/tcp.rs +++ b/src/net/tcp.rs @@ -147,11 +147,6 @@ impl TcpListener { Ok(TcpListener { io: io, pending_accept: None }) } - /// Test whether this socket is ready to be read or not. - pub fn poll_read(&self) -> Async<()> { - self.io.poll_read() - } - /// Returns the local address that this listener is bound to. /// /// This can be useful, for example, when binding to port 0 to figure out @@ -313,26 +308,6 @@ impl TcpStream { Box::new(state) } - /// Test whether this socket is ready to be read or not. - /// - /// If the socket is *not* readable then the current task is scheduled to - /// get a notification when the socket does become readable. That is, this - /// is only suitable for calling in a `Future::poll` method and will - /// automatically handle ensuring a retry once the socket is readable again. - pub fn poll_read(&self) -> Async<()> { - self.io.poll_read() - } - - /// Test whether this socket is ready to be written to or not. - /// - /// If the socket is *not* writable then the current task is scheduled to - /// get a notification when the socket does become writable. That is, this - /// is only suitable for calling in a `Future::poll` method and will - /// automatically handle ensuring a retry once the socket is writable again. - pub fn poll_write(&self) -> Async<()> { - self.io.poll_write() - } - /// Returns the local address that this stream is bound to. pub fn local_addr(&self) -> io::Result { self.io.get_ref().local_addr() @@ -503,45 +478,10 @@ impl AsyncRead for TcpStream { } fn read_buf(&mut self, buf: &mut B) -> Poll { - <&TcpStream>::read_buf(&mut &*self, buf) - } -} - -impl AsyncWrite for TcpStream { - fn shutdown(&mut self) -> Poll<(), io::Error> { - <&TcpStream>::shutdown(&mut &*self) - } - - fn write_buf(&mut self, buf: &mut B) -> Poll { - <&TcpStream>::write_buf(&mut &*self, buf) - } -} - -impl<'a> Read for &'a TcpStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - (&self.io).read(buf) - } -} - -impl<'a> Write for &'a TcpStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - (&self.io).write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - (&self.io).flush() - } -} - -impl<'a> AsyncRead for &'a TcpStream { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { - false - } - - fn read_buf(&mut self, buf: &mut B) -> Poll { - if let Async::NotReady = ::poll_read(self) { + if let Async::NotReady = self.io.poll_read() { return Ok(Async::NotReady) } + let r = unsafe { // The `IoVec` type can't have a 0-length size, so we create a bunch // of dummy versions on the stack with 1 length which we'll quickly @@ -586,15 +526,16 @@ impl<'a> AsyncRead for &'a TcpStream { } } -impl<'a> AsyncWrite for &'a TcpStream { +impl AsyncWrite for TcpStream { fn shutdown(&mut self) -> Poll<(), io::Error> { Ok(().into()) } fn write_buf(&mut self, buf: &mut B) -> Poll { - if let Async::NotReady = ::poll_write(self) { + if let Async::NotReady = self.io.poll_write() { return Ok(Async::NotReady) } + let r = { // The `IoVec` type can't have a zero-length size, so create a dummy // version from a 1-length slice which we'll overwrite with the @@ -646,7 +587,7 @@ impl Future for TcpStreamNewState { fn poll(&mut self) -> Poll { { let stream = match *self { - TcpStreamNewState::Waiting(ref s) => s, + TcpStreamNewState::Waiting(ref mut s) => s, TcpStreamNewState::Error(_) => { let e = match mem::replace(self, TcpStreamNewState::Empty) { TcpStreamNewState::Error(e) => e, diff --git a/src/net/udp/mod.rs b/src/net/udp/mod.rs index 74257f8b8cc..9cb230029ad 100644 --- a/src/net/udp/mod.rs +++ b/src/net/udp/mod.rs @@ -81,10 +81,11 @@ impl UdpSocket { /// Sends data on the socket to the address previously bound via connect(). /// On success, returns the number of bytes written. - pub fn send(&self, buf: &[u8]) -> io::Result { + pub fn send(&mut self, buf: &[u8]) -> io::Result { if let Async::NotReady = self.io.poll_write() { return Err(io::ErrorKind::WouldBlock.into()) } + match self.io.get_ref().send(buf) { Ok(n) => Ok(n), Err(e) => { @@ -98,10 +99,11 @@ impl UdpSocket { /// Receives data from the socket previously bound with connect(). /// On success, returns the number of bytes read. - pub fn recv(&self, buf: &mut [u8]) -> io::Result { + pub fn recv(&mut self, buf: &mut [u8]) -> io::Result { if let Async::NotReady = self.io.poll_read() { return Err(io::ErrorKind::WouldBlock.into()) } + match self.io.get_ref().recv(buf) { Ok(n) => Ok(n), Err(e) => { @@ -113,35 +115,16 @@ impl UdpSocket { } } - /// Test whether this socket is ready to be read or not. - /// - /// If the socket is *not* readable then the current task is scheduled to - /// get a notification when the socket does become readable. That is, this - /// is only suitable for calling in a `Future::poll` method and will - /// automatically handle ensuring a retry once the socket is readable again. - pub fn poll_read(&self) -> Async<()> { - self.io.poll_read() - } - - /// Test whether this socket is ready to be written to or not. - /// - /// If the socket is *not* writable then the current task is scheduled to - /// get a notification when the socket does become writable. That is, this - /// is only suitable for calling in a `Future::poll` method and will - /// automatically handle ensuring a retry once the socket is writable again. - pub fn poll_write(&self) -> Async<()> { - self.io.poll_write() - } - /// Sends data on the socket to the given address. On success, returns the /// number of bytes written. /// /// Address type can be any implementer of `ToSocketAddrs` trait. See its /// documentation for concrete examples. - pub fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result { + pub fn send_to(&mut self, buf: &[u8], target: &SocketAddr) -> io::Result { if let Async::NotReady = self.io.poll_write() { return Err(io::ErrorKind::WouldBlock.into()) } + match self.io.get_ref().send_to(buf, target) { Ok(n) => Ok(n), Err(e) => { @@ -176,10 +159,11 @@ impl UdpSocket { /// Receives data from the socket. On success, returns the number of bytes /// read and the address from whence the data came. - pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + pub fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { if let Async::NotReady = self.io.poll_read() { return Err(io::ErrorKind::WouldBlock.into()) } + match self.io.get_ref().recv_from(buf) { Ok(n) => Ok(n), Err(e) => { @@ -397,8 +381,8 @@ impl Future for SendDgram fn poll(&mut self) -> Poll<(UdpSocket, T), io::Error> { { - let (ref sock, ref buf, ref addr) = - *self.0.as_ref().expect("SendDgram polled after completion"); + let (ref mut sock, ref buf, ref addr) = + *self.0.as_mut().expect("SendDgram polled after completion"); let n = try_nb!(sock.send_to(buf.as_ref(), addr)); if n != buf.as_ref().len() { return Err(incomplete_write("failed to send entire message \ @@ -425,7 +409,7 @@ impl Future for RecvDgram fn poll(&mut self) -> Poll { let (n, addr) = { - let (ref socket, ref mut buf) = + let (ref mut socket, ref mut buf) = *self.0.as_mut().expect("RecvDgram polled after completion"); try_nb!(socket.recv_from(buf.as_mut())) diff --git a/src/reactor/poll_evented.rs b/src/reactor/poll_evented.rs index aa19803f01e..8ba09fd6ab5 100644 --- a/src/reactor/poll_evented.rs +++ b/src/reactor/poll_evented.rs @@ -131,7 +131,7 @@ impl PollEvented { /// /// This function will panic if called outside the context of a future's /// task. - pub fn poll_read(&self) -> Async<()> { + pub fn poll_read(&mut self) -> Async<()> { self.poll_ready(super::read_ready()) .map(|_| ()) } @@ -150,7 +150,7 @@ impl PollEvented { /// /// This function will panic if called outside the context of a future's /// task. - pub fn poll_write(&self) -> Async<()> { + pub fn poll_write(&mut self) -> Async<()> { self.poll_ready(Ready::writable()) .map(|_| ()) } @@ -178,7 +178,7 @@ impl PollEvented { /// /// This function will panic if called outside the context of a future's /// task. - pub fn poll_ready(&self, mask: Ready) -> Async { + pub fn poll_ready(&mut self, mask: Ready) -> Async { let bits = super::ready2usize(mask); match self.readiness.load(Ordering::SeqCst) & bits { 0 => {} @@ -221,7 +221,7 @@ impl PollEvented { /// /// This function will panic if called outside the context of a future's /// task. - pub fn need_read(&self) { + pub fn need_read(&mut self) { let bits = super::ready2usize(super::read_ready()); self.readiness.fetch_and(!bits, Ordering::SeqCst); self.token.schedule_read(); @@ -247,7 +247,7 @@ impl PollEvented { /// /// This function will panic if called outside the context of a future's /// task. - pub fn need_write(&self) { + pub fn need_write(&mut self) { let bits = super::ready2usize(Ready::writable()); self.readiness.fetch_and(!bits, Ordering::SeqCst); self.token.schedule_write(); @@ -318,60 +318,6 @@ impl AsyncWrite for PollEvented { } } -impl<'a, E> Read for &'a PollEvented - where &'a E: Read, -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - if let Async::NotReady = self.poll_read() { - return Err(io::ErrorKind::WouldBlock.into()) - } - let r = self.get_ref().read(buf); - if is_wouldblock(&r) { - self.need_read(); - } - return r - } -} - -impl<'a, E> Write for &'a PollEvented - where &'a E: Write, -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - if let Async::NotReady = self.poll_write() { - return Err(io::ErrorKind::WouldBlock.into()) - } - let r = self.get_ref().write(buf); - if is_wouldblock(&r) { - self.need_write(); - } - return r - } - - fn flush(&mut self) -> io::Result<()> { - if let Async::NotReady = self.poll_write() { - return Err(io::ErrorKind::WouldBlock.into()) - } - let r = self.get_ref().flush(); - if is_wouldblock(&r) { - self.need_write(); - } - return r - } -} - -impl<'a, E> AsyncRead for &'a PollEvented - where &'a E: Read, -{ -} - -impl<'a, E> AsyncWrite for &'a PollEvented - where &'a E: Write, -{ - fn shutdown(&mut self) -> Poll<(), io::Error> { - Ok(().into()) - } -} - fn is_wouldblock(r: &io::Result) -> bool { match *r { Ok(_) => false, diff --git a/tests/udp.rs b/tests/udp.rs index e22af9347a3..80d4de9591f 100644 --- a/tests/udp.rs +++ b/tests/udp.rs @@ -51,14 +51,14 @@ fn send_and_recv() { } trait SendFn { - fn send(&self, &UdpSocket, &[u8], &SocketAddr) -> Result; + fn send(&self, &mut UdpSocket, &[u8], &SocketAddr) -> Result; } #[derive(Debug, Clone)] struct SendTo {} impl SendFn for SendTo { - fn send(&self, socket: &UdpSocket, buf: &[u8], addr: &SocketAddr) -> Result { + fn send(&self, socket: &mut UdpSocket, buf: &[u8], addr: &SocketAddr) -> Result { socket.send_to(buf, addr) } } @@ -67,7 +67,7 @@ impl SendFn for SendTo { struct Send {} impl SendFn for Send { - fn send(&self, socket: &UdpSocket, buf: &[u8], addr: &SocketAddr) -> Result { + fn send(&self, socket: &mut UdpSocket, buf: &[u8], addr: &SocketAddr) -> Result { socket.connect(addr).expect("could not connect"); socket.send(buf) } @@ -96,7 +96,7 @@ impl Future for SendMessage { type Error = io::Error; fn poll(&mut self) -> Poll { - let n = try_nb!(self.send.send(self.socket.as_ref().unwrap(), &self.data[..], &self.addr)); + let n = try_nb!(self.send.send(self.socket.as_mut().unwrap(), &self.data[..], &self.addr)); assert_eq!(n, self.data.len()); @@ -105,14 +105,14 @@ impl Future for SendMessage { } trait RecvFn { - fn recv(&self, &UdpSocket, &mut [u8], &SocketAddr) -> Result; + fn recv(&self, &mut UdpSocket, &mut [u8], &SocketAddr) -> Result; } #[derive(Debug, Clone)] struct RecvFrom {} impl RecvFn for RecvFrom { - fn recv(&self, socket: &UdpSocket, buf: &mut [u8], + fn recv(&self, socket: &mut UdpSocket, buf: &mut [u8], expected_addr: &SocketAddr) -> Result { socket.recv_from(buf).map(|(s, addr)| { assert_eq!(addr, *expected_addr); @@ -125,7 +125,7 @@ impl RecvFn for RecvFrom { struct Recv {} impl RecvFn for Recv { - fn recv(&self, socket: &UdpSocket, buf: &mut [u8], _: &SocketAddr) -> Result { + fn recv(&self, socket: &mut UdpSocket, buf: &mut [u8], _: &SocketAddr) -> Result { socket.recv(buf) } } @@ -155,7 +155,7 @@ impl Future for RecvMessage { fn poll(&mut self) -> Poll { let mut buf = vec![0u8; 10 + self.expected_data.len() * 10]; - let n = try_nb!(self.recv.recv(&self.socket.as_ref().unwrap(), &mut buf[..], + let n = try_nb!(self.recv.recv(&mut self.socket.as_mut().unwrap(), &mut buf[..], &self.expected_addr)); assert_eq!(n, self.expected_data.len()); From 2b9d2e95795412760f2fb2fb7f70383ca311e0f8 Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Fri, 17 Nov 2017 12:26:13 -0800 Subject: [PATCH 2/2] Remove atomic readiness cache in PollEvented --- src/reactor/poll_evented.rs | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/reactor/poll_evented.rs b/src/reactor/poll_evented.rs index 8ba09fd6ab5..3616ab10a70 100644 --- a/src/reactor/poll_evented.rs +++ b/src/reactor/poll_evented.rs @@ -8,7 +8,6 @@ use std::fmt; use std::io::{self, Read, Write}; -use std::sync::atomic::{AtomicUsize, Ordering}; use futures::{Async, Poll}; use mio::event::Evented; @@ -65,7 +64,7 @@ use reactor::io_token::IoToken; /// otherwise probably avoid using two tasks on the same `PollEvented`. pub struct PollEvented { token: IoToken, - readiness: AtomicUsize, + readiness: usize, io: E, } @@ -88,7 +87,7 @@ impl PollEvented { Ok(PollEvented { token, - readiness: AtomicUsize::new(0), + readiness: 0, io: io, }) } @@ -180,12 +179,15 @@ impl PollEvented { /// task. pub fn poll_ready(&mut self, mask: Ready) -> Async { let bits = super::ready2usize(mask); - match self.readiness.load(Ordering::SeqCst) & bits { + + match self.readiness & bits { 0 => {} n => return Async::Ready(super::usize2ready(n)), } - self.readiness.fetch_or(self.token.take_readiness(), Ordering::SeqCst); - match self.readiness.load(Ordering::SeqCst) & bits { + + self.readiness |= self.token.take_readiness(); + + match self.readiness & bits { 0 => { if mask.is_writable() { self.need_write(); @@ -223,7 +225,7 @@ impl PollEvented { /// task. pub fn need_read(&mut self) { let bits = super::ready2usize(super::read_ready()); - self.readiness.fetch_and(!bits, Ordering::SeqCst); + self.readiness &= !bits; self.token.schedule_read(); } @@ -249,7 +251,7 @@ impl PollEvented { /// task. pub fn need_write(&mut self) { let bits = super::ready2usize(Ready::writable()); - self.readiness.fetch_and(!bits, Ordering::SeqCst); + self.readiness &= !bits; self.token.schedule_write(); } @@ -277,10 +279,13 @@ impl Read for PollEvented { if let Async::NotReady = self.poll_read() { return Err(io::ErrorKind::WouldBlock.into()) } + let r = self.get_mut().read(buf); + if is_wouldblock(&r) { self.need_read(); } + return r } } @@ -290,10 +295,13 @@ impl Write for PollEvented { if let Async::NotReady = self.poll_write() { return Err(io::ErrorKind::WouldBlock.into()) } + let r = self.get_mut().write(buf); + if is_wouldblock(&r) { self.need_write(); } + return r } @@ -301,10 +309,13 @@ impl Write for PollEvented { if let Async::NotReady = self.poll_write() { return Err(io::ErrorKind::WouldBlock.into()) } + let r = self.get_mut().flush(); + if is_wouldblock(&r) { self.need_write(); } + return r } }