diff --git a/examples/proxy.rs b/examples/proxy.rs index 98b86e9fe63..51735ba16da 100644 --- a/examples/proxy.rs +++ b/examples/proxy.rs @@ -21,7 +21,7 @@ extern crate futures_cpupool; extern crate tokio; extern crate tokio_io; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::env; use std::net::{Shutdown, SocketAddr}; use std::io::{self, Read, Write}; @@ -60,9 +60,9 @@ fn main() { // // As a result, we wrap up our client/server manually in arcs and // use the impls below on our custom `MyTcpStream` type. - let client_reader = MyTcpStream(Arc::new(client)); + let client_reader = MyTcpStream(Arc::new(Mutex::new(client))); let client_writer = client_reader.clone(); - let server_reader = MyTcpStream(Arc::new(server)); + let server_reader = MyTcpStream(Arc::new(Mutex::new(server))); let server_writer = server_reader.clone(); // Copy the data (in parallel) between the client and the server. @@ -99,17 +99,17 @@ fn main() { // `AsyncWrite::shutdown` method which actually calls `TcpStream::shutdown` to // notify the remote end that we're done writing. #[derive(Clone)] -struct MyTcpStream(Arc); +struct MyTcpStream(Arc>); impl Read for MyTcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - (&*self.0).read(buf) + self.0.lock().unwrap().read(buf) } } impl Write for MyTcpStream { fn write(&mut self, buf: &[u8]) -> io::Result { - (&*self.0).write(buf) + self.0.lock().unwrap().write(buf) } fn flush(&mut self) -> io::Result<()> { @@ -121,7 +121,7 @@ impl AsyncRead for MyTcpStream {} impl AsyncWrite for MyTcpStream { fn shutdown(&mut self) -> Poll<(), io::Error> { - try!(self.0.shutdown(Shutdown::Write)); + try!(self.0.lock().unwrap().shutdown(Shutdown::Write)); Ok(().into()) } } diff --git a/src/net/tcp.rs b/src/net/tcp.rs index 82dabdb733b..e7d068010e5 100644 --- a/src/net/tcp.rs +++ b/src/net/tcp.rs @@ -130,16 +130,6 @@ impl TcpListener { Ok(TcpListener { io: io }) } - /// Test whether this socket is ready to be read or not. - /// - /// # Panics - /// - /// This function will panic if called outside the context of a future's - /// task. - pub fn poll_read(&self) -> Async<()> { - self.io.poll_read() - } - /// Returns the local address that this listener is bound to. /// /// This can be useful, for example, when binding to port 0 to figure out @@ -287,32 +277,6 @@ impl TcpStream { TcpStreamNew { inner: inner } } - /// Test whether this stream is ready to be read or not. - /// - /// If the stream is *not* readable then the current task is scheduled to - /// get a notification when the stream does become readable. - /// - /// # Panics - /// - /// This function will panic if called outside the context of a future's - /// task. - pub fn poll_read(&self) -> Async<()> { - self.io.poll_read() - } - - /// Test whether this stream is ready to be written or not. - /// - /// If the stream is *not* writable then the current task is scheduled to - /// get a notification when the stream does become writable. - /// - /// # Panics - /// - /// This function will panic if called outside the context of a future's - /// task. - pub fn poll_write(&self) -> Async<()> { - self.io.poll_write() - } - /// Returns the local address that this stream is bound to. pub fn local_addr(&self) -> io::Result { self.io.get_ref().local_addr() @@ -329,8 +293,8 @@ impl TcpStream { /// /// Successive calls return the same data. This is accomplished by passing /// `MSG_PEEK` as a flag to the underlying recv system call. - pub fn peek(&self, buf: &mut [u8]) -> io::Result { - if let Async::NotReady = self.poll_read() { + pub fn peek(&mut self, buf: &mut [u8]) -> io::Result { + if let Async::NotReady = self.io.poll_read() { return Err(io::ErrorKind::WouldBlock.into()) } @@ -497,45 +461,10 @@ impl AsyncRead for TcpStream { } fn read_buf(&mut self, buf: &mut B) -> Poll { - <&TcpStream>::read_buf(&mut &*self, buf) - } -} - -impl AsyncWrite for TcpStream { - fn shutdown(&mut self) -> Poll<(), io::Error> { - <&TcpStream>::shutdown(&mut &*self) - } - - fn write_buf(&mut self, buf: &mut B) -> Poll { - <&TcpStream>::write_buf(&mut &*self, buf) - } -} - -impl<'a> Read for &'a TcpStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - (&self.io).read(buf) - } -} - -impl<'a> Write for &'a TcpStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - (&self.io).write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - (&self.io).flush() - } -} - -impl<'a> AsyncRead for &'a TcpStream { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { - false - } - - fn read_buf(&mut self, buf: &mut B) -> Poll { - if let Async::NotReady = ::poll_read(self) { + if let Async::NotReady = self.io.poll_read() { return Ok(Async::NotReady) } + let r = unsafe { // The `IoVec` type can't have a 0-length size, so we create a bunch // of dummy versions on the stack with 1 length which we'll quickly @@ -580,15 +509,16 @@ impl<'a> AsyncRead for &'a TcpStream { } } -impl<'a> AsyncWrite for &'a TcpStream { +impl AsyncWrite for TcpStream { fn shutdown(&mut self) -> Poll<(), io::Error> { Ok(().into()) } fn write_buf(&mut self, buf: &mut B) -> Poll { - if let Async::NotReady = ::poll_write(self) { + if let Async::NotReady = self.io.poll_write() { return Ok(Async::NotReady) } + let r = { // The `IoVec` type can't have a zero-length size, so create a dummy // version from a 1-length slice which we'll overwrite with the @@ -635,7 +565,7 @@ impl Future for TcpStreamNewState { fn poll(&mut self) -> Poll { { let stream = match *self { - TcpStreamNewState::Waiting(ref s) => s, + TcpStreamNewState::Waiting(ref mut s) => s, TcpStreamNewState::Error(_) => { let e = match mem::replace(self, TcpStreamNewState::Empty) { TcpStreamNewState::Error(e) => e, diff --git a/src/net/udp/mod.rs b/src/net/udp/mod.rs index ab79544db8c..f8715a82b2e 100644 --- a/src/net/udp/mod.rs +++ b/src/net/udp/mod.rs @@ -85,10 +85,11 @@ impl UdpSocket { /// /// This function will panic if called outside the context of a future's /// task. - pub fn send(&self, buf: &[u8]) -> io::Result { + pub fn send(&mut self, buf: &[u8]) -> io::Result { if let Async::NotReady = self.io.poll_write() { return Err(io::ErrorKind::WouldBlock.into()) } + match self.io.get_ref().send(buf) { Ok(n) => Ok(n), Err(e) => { @@ -107,10 +108,11 @@ impl UdpSocket { /// /// This function will panic if called outside the context of a future's /// task. - pub fn recv(&self, buf: &mut [u8]) -> io::Result { + pub fn recv(&mut self, buf: &mut [u8]) -> io::Result { if let Async::NotReady = self.io.poll_read() { return Err(io::ErrorKind::WouldBlock.into()) } + match self.io.get_ref().recv(buf) { Ok(n) => Ok(n), Err(e) => { @@ -122,43 +124,21 @@ impl UdpSocket { } } - /// Test whether this socket is ready to be read or not. - /// - /// If the socket is *not* readable then the current task is scheduled to - /// get a notification when the socket does become readable. - /// - /// # Panics - /// - /// This function will panic if called outside the context of a future's - /// task. - pub fn poll_read(&self) -> Async<()> { - self.io.poll_read() - } - - /// Test whether this socket is ready to be written to or not. - /// - /// If the socket is *not* writable then the current task is scheduled to - /// get a notification when the socket does become writable. - /// - /// # Panics - /// - /// This function will panic if called outside the context of a future's - /// task. - pub fn poll_write(&self) -> Async<()> { - self.io.poll_write() - } - /// Sends data on the socket to the given address. On success, returns the /// number of bytes written. /// + /// Address type can be any implementer of `ToSocketAddrs` trait. See its + /// documentation for concrete examples. + /// /// # Panics /// /// This function will panic if called outside the context of a future's /// task. - pub fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result { + pub fn send_to(&mut self, buf: &[u8], target: &SocketAddr) -> io::Result { if let Async::NotReady = self.io.poll_write() { return Err(io::ErrorKind::WouldBlock.into()) } + match self.io.get_ref().send_to(buf, target) { Ok(n) => Ok(n), Err(e) => { @@ -197,10 +177,11 @@ impl UdpSocket { /// /// This function will panic if called outside the context of a future's /// task. - pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + pub fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { if let Async::NotReady = self.io.poll_read() { return Err(io::ErrorKind::WouldBlock.into()) } + match self.io.get_ref().recv_from(buf) { Ok(n) => Ok(n), Err(e) => { @@ -423,8 +404,8 @@ impl Future for SendDgram fn poll(&mut self) -> Poll<(UdpSocket, T), io::Error> { { - let ref inner = - self.state.as_ref().expect("SendDgram polled after completion"); + let ref mut inner = + self.state.as_mut().expect("SendDgram polled after completion"); let n = try_nb!(inner.socket.send_to(inner.buffer.as_ref(), &inner.addr)); if n != inner.buffer.as_ref().len() { return Err(incomplete_write("failed to send entire message \ diff --git a/src/reactor/poll_evented.rs b/src/reactor/poll_evented.rs index eb750f87f03..fcd36e5d8f4 100644 --- a/src/reactor/poll_evented.rs +++ b/src/reactor/poll_evented.rs @@ -8,7 +8,7 @@ use std::fmt; use std::io::{self, Read, Write}; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::atomic::Ordering; use futures::{Async, Poll}; use mio::event::Evented; @@ -63,9 +63,9 @@ use reactor::{Handle, Direction}; /// method you want to also use `need_read` to signal blocking and you should /// otherwise probably avoid using two tasks on the same `PollEvented`. pub struct PollEvented { - handle: Handle, token: usize, - readiness: AtomicUsize, + handle: Handle, + readiness: usize, io: E, } @@ -92,8 +92,8 @@ impl PollEvented { Ok(PollEvented { token, + readiness: 0, handle: handle.clone(), - readiness: AtomicUsize::new(0), io: io, }) } @@ -112,7 +112,7 @@ impl PollEvented { /// /// This function will panic if called outside the context of a future's /// task. - pub fn poll_read(&self) -> Async<()> { + pub fn poll_read(&mut self) -> Async<()> { self.poll_ready(super::read_ready()) .map(|_| ()) } @@ -131,7 +131,7 @@ impl PollEvented { /// /// This function will panic if called outside the context of a future's /// task. - pub fn poll_write(&self) -> Async<()> { + pub fn poll_write(&mut self) -> Async<()> { self.poll_ready(Ready::writable()) .map(|_| ()) } @@ -159,9 +159,10 @@ impl PollEvented { /// /// This function will panic if called outside the context of a future's /// task. - pub fn poll_ready(&self, mask: Ready) -> Async { + pub fn poll_ready(&mut self, mask: Ready) -> Async { let bits = super::ready2usize(mask); - match self.readiness.load(Ordering::SeqCst) & bits { + + match self.readiness & bits { 0 => {} n => return Async::Ready(super::usize2ready(n)), } @@ -171,8 +172,9 @@ impl PollEvented { io_dispatch[self.token].readiness.swap(0, Ordering::SeqCst) }).unwrap_or(0); - self.readiness.fetch_or(token_readiness, Ordering::SeqCst); - match self.readiness.load(Ordering::SeqCst) & bits { + self.readiness |= token_readiness; + + match self.readiness & bits { 0 => { if mask.is_writable() { if self.need_write().is_err() { @@ -220,9 +222,9 @@ impl PollEvented { /// /// This function will panic if called outside the context of a future's /// task. - pub fn need_read(&self) -> io::Result<()> { + pub fn need_read(&mut self) -> io::Result<()> { let bits = super::ready2usize(super::read_ready()); - self.readiness.fetch_and(!bits, Ordering::SeqCst); + self.readiness &= !bits; let inner = match self.handle.inner() { Some(inner) => inner, @@ -260,9 +262,9 @@ impl PollEvented { /// /// This function will panic if called outside the context of a future's /// task. - pub fn need_write(&self) -> io::Result<()> { + pub fn need_write(&mut self) -> io::Result<()> { let bits = super::ready2usize(Ready::writable()); - self.readiness.fetch_and(!bits, Ordering::SeqCst); + self.readiness &= !bits; let inner = match self.handle.inner() { Some(inner) => inner, @@ -319,10 +321,13 @@ impl Read for PollEvented { if let Async::NotReady = self.poll_read() { return Err(io::ErrorKind::WouldBlock.into()) } + let r = self.get_mut().read(buf); + if is_wouldblock(&r) { self.need_read()?; } + return r } } @@ -332,10 +337,13 @@ impl Write for PollEvented { if let Async::NotReady = self.poll_write() { return Err(io::ErrorKind::WouldBlock.into()) } + let r = self.get_mut().write(buf); + if is_wouldblock(&r) { self.need_write()?; } + return r } @@ -343,72 +351,21 @@ impl Write for PollEvented { if let Async::NotReady = self.poll_write() { return Err(io::ErrorKind::WouldBlock.into()) } - let r = self.get_mut().flush(); - if is_wouldblock(&r) { - self.need_write()?; - } - return r - } -} - -impl AsyncRead for PollEvented { -} - -impl AsyncWrite for PollEvented { - fn shutdown(&mut self) -> Poll<(), io::Error> { - Ok(().into()) - } -} -impl<'a, E> Read for &'a PollEvented - where &'a E: Read, -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - if let Async::NotReady = self.poll_read() { - return Err(io::ErrorKind::WouldBlock.into()) - } - let r = self.get_ref().read(buf); - if is_wouldblock(&r) { - self.need_read()?; - } - return r - } -} + let r = self.get_mut().flush(); -impl<'a, E> Write for &'a PollEvented - where &'a E: Write, -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - if let Async::NotReady = self.poll_write() { - return Err(io::ErrorKind::WouldBlock.into()) - } - let r = self.get_ref().write(buf); if is_wouldblock(&r) { self.need_write()?; } - return r - } - fn flush(&mut self) -> io::Result<()> { - if let Async::NotReady = self.poll_write() { - return Err(io::ErrorKind::WouldBlock.into()) - } - let r = self.get_ref().flush(); - if is_wouldblock(&r) { - self.need_write()?; - } return r } } -impl<'a, E> AsyncRead for &'a PollEvented - where &'a E: Read, -{ +impl AsyncRead for PollEvented { } -impl<'a, E> AsyncWrite for &'a PollEvented - where &'a E: Write, -{ +impl AsyncWrite for PollEvented { fn shutdown(&mut self) -> Poll<(), io::Error> { Ok(().into()) } diff --git a/tests/udp.rs b/tests/udp.rs index 9c16860fbf7..f0a47d37c62 100644 --- a/tests/udp.rs +++ b/tests/udp.rs @@ -48,14 +48,14 @@ fn send_and_recv() { } trait SendFn { - fn send(&self, &UdpSocket, &[u8], &SocketAddr) -> Result; + fn send(&self, &mut UdpSocket, &[u8], &SocketAddr) -> Result; } #[derive(Debug, Clone)] struct SendTo {} impl SendFn for SendTo { - fn send(&self, socket: &UdpSocket, buf: &[u8], addr: &SocketAddr) -> Result { + fn send(&self, socket: &mut UdpSocket, buf: &[u8], addr: &SocketAddr) -> Result { socket.send_to(buf, addr) } } @@ -64,7 +64,7 @@ impl SendFn for SendTo { struct Send {} impl SendFn for Send { - fn send(&self, socket: &UdpSocket, buf: &[u8], addr: &SocketAddr) -> Result { + fn send(&self, socket: &mut UdpSocket, buf: &[u8], addr: &SocketAddr) -> Result { socket.connect(addr).expect("could not connect"); socket.send(buf) } @@ -93,7 +93,7 @@ impl Future for SendMessage { type Error = io::Error; fn poll(&mut self) -> Poll { - let n = try_nb!(self.send.send(self.socket.as_ref().unwrap(), &self.data[..], &self.addr)); + let n = try_nb!(self.send.send(self.socket.as_mut().unwrap(), &self.data[..], &self.addr)); assert_eq!(n, self.data.len()); @@ -102,14 +102,14 @@ impl Future for SendMessage { } trait RecvFn { - fn recv(&self, &UdpSocket, &mut [u8], &SocketAddr) -> Result; + fn recv(&self, &mut UdpSocket, &mut [u8], &SocketAddr) -> Result; } #[derive(Debug, Clone)] struct RecvFrom {} impl RecvFn for RecvFrom { - fn recv(&self, socket: &UdpSocket, buf: &mut [u8], + fn recv(&self, socket: &mut UdpSocket, buf: &mut [u8], expected_addr: &SocketAddr) -> Result { socket.recv_from(buf).map(|(s, addr)| { assert_eq!(addr, *expected_addr); @@ -122,7 +122,7 @@ impl RecvFn for RecvFrom { struct Recv {} impl RecvFn for Recv { - fn recv(&self, socket: &UdpSocket, buf: &mut [u8], _: &SocketAddr) -> Result { + fn recv(&self, socket: &mut UdpSocket, buf: &mut [u8], _: &SocketAddr) -> Result { socket.recv(buf) } } @@ -152,7 +152,7 @@ impl Future for RecvMessage { fn poll(&mut self) -> Poll { let mut buf = vec![0u8; 10 + self.expected_data.len() * 10]; - let n = try_nb!(self.recv.recv(&self.socket.as_ref().unwrap(), &mut buf[..], + let n = try_nb!(self.recv.recv(&mut self.socket.as_mut().unwrap(), &mut buf[..], &self.expected_addr)); assert_eq!(n, self.expected_data.len());