Skip to content

Commit

Permalink
Poll evented mutability (#37)
Browse files Browse the repository at this point in the history
Generally speaking, it is unsafe to access to perform asynchronous
operations using `&self`. Taking `&self` allows usage from a `Sync`
context, which has unexpected results.

Taking `&mut self` to perform these operations prevents using these
asynchronous values from across tasks (unless they are wrapped in
`RefCell` or `Mutex`.
  • Loading branch information
carllerche authored Feb 1, 2018
1 parent a616220 commit 65cbfce
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 193 deletions.
14 changes: 7 additions & 7 deletions examples/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ extern crate futures_cpupool;
extern crate tokio;
extern crate tokio_io;

use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::env;
use std::net::{Shutdown, SocketAddr};
use std::io::{self, Read, Write};
Expand Down Expand Up @@ -60,9 +60,9 @@ fn main() {
//
// As a result, we wrap up our client/server manually in arcs and
// use the impls below on our custom `MyTcpStream` type.
let client_reader = MyTcpStream(Arc::new(client));
let client_reader = MyTcpStream(Arc::new(Mutex::new(client)));
let client_writer = client_reader.clone();
let server_reader = MyTcpStream(Arc::new(server));
let server_reader = MyTcpStream(Arc::new(Mutex::new(server)));
let server_writer = server_reader.clone();

// Copy the data (in parallel) between the client and the server.
Expand Down Expand Up @@ -99,17 +99,17 @@ fn main() {
// `AsyncWrite::shutdown` method which actually calls `TcpStream::shutdown` to
// notify the remote end that we're done writing.
#[derive(Clone)]
struct MyTcpStream(Arc<TcpStream>);
struct MyTcpStream(Arc<Mutex<TcpStream>>);

impl Read for MyTcpStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
(&*self.0).read(buf)
self.0.lock().unwrap().read(buf)
}
}

impl Write for MyTcpStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
(&*self.0).write(buf)
self.0.lock().unwrap().write(buf)
}

fn flush(&mut self) -> io::Result<()> {
Expand All @@ -121,7 +121,7 @@ impl AsyncRead for MyTcpStream {}

impl AsyncWrite for MyTcpStream {
fn shutdown(&mut self) -> Poll<(), io::Error> {
try!(self.0.shutdown(Shutdown::Write));
try!(self.0.lock().unwrap().shutdown(Shutdown::Write));
Ok(().into())
}
}
86 changes: 8 additions & 78 deletions src/net/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,6 @@ impl TcpListener {
Ok(TcpListener { io: io })
}

/// Test whether this socket is ready to be read or not.
///
/// # Panics
///
/// This function will panic if called outside the context of a future's
/// task.
pub fn poll_read(&self) -> Async<()> {
self.io.poll_read()
}

/// Returns the local address that this listener is bound to.
///
/// This can be useful, for example, when binding to port 0 to figure out
Expand Down Expand Up @@ -287,32 +277,6 @@ impl TcpStream {
TcpStreamNew { inner: inner }
}

/// Test whether this stream is ready to be read or not.
///
/// If the stream is *not* readable then the current task is scheduled to
/// get a notification when the stream does become readable.
///
/// # Panics
///
/// This function will panic if called outside the context of a future's
/// task.
pub fn poll_read(&self) -> Async<()> {
self.io.poll_read()
}

/// Test whether this stream is ready to be written or not.
///
/// If the stream is *not* writable then the current task is scheduled to
/// get a notification when the stream does become writable.
///
/// # Panics
///
/// This function will panic if called outside the context of a future's
/// task.
pub fn poll_write(&self) -> Async<()> {
self.io.poll_write()
}

/// Returns the local address that this stream is bound to.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.io.get_ref().local_addr()
Expand All @@ -329,8 +293,8 @@ impl TcpStream {
///
/// Successive calls return the same data. This is accomplished by passing
/// `MSG_PEEK` as a flag to the underlying recv system call.
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
if let Async::NotReady = self.poll_read() {
pub fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if let Async::NotReady = self.io.poll_read() {
return Err(io::ErrorKind::WouldBlock.into())
}

Expand Down Expand Up @@ -497,45 +461,10 @@ impl AsyncRead for TcpStream {
}

fn read_buf<B: BufMut>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
<&TcpStream>::read_buf(&mut &*self, buf)
}
}

impl AsyncWrite for TcpStream {
fn shutdown(&mut self) -> Poll<(), io::Error> {
<&TcpStream>::shutdown(&mut &*self)
}

fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
<&TcpStream>::write_buf(&mut &*self, buf)
}
}

impl<'a> Read for &'a TcpStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
(&self.io).read(buf)
}
}

impl<'a> Write for &'a TcpStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
(&self.io).write(buf)
}

fn flush(&mut self) -> io::Result<()> {
(&self.io).flush()
}
}

impl<'a> AsyncRead for &'a TcpStream {
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}

fn read_buf<B: BufMut>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
if let Async::NotReady = <TcpStream>::poll_read(self) {
if let Async::NotReady = self.io.poll_read() {
return Ok(Async::NotReady)
}

let r = unsafe {
// The `IoVec` type can't have a 0-length size, so we create a bunch
// of dummy versions on the stack with 1 length which we'll quickly
Expand Down Expand Up @@ -580,15 +509,16 @@ impl<'a> AsyncRead for &'a TcpStream {
}
}

impl<'a> AsyncWrite for &'a TcpStream {
impl AsyncWrite for TcpStream {
fn shutdown(&mut self) -> Poll<(), io::Error> {
Ok(().into())
}

fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
if let Async::NotReady = <TcpStream>::poll_write(self) {
if let Async::NotReady = self.io.poll_write() {
return Ok(Async::NotReady)
}

let r = {
// The `IoVec` type can't have a zero-length size, so create a dummy
// version from a 1-length slice which we'll overwrite with the
Expand Down Expand Up @@ -635,7 +565,7 @@ impl Future for TcpStreamNewState {
fn poll(&mut self) -> Poll<TcpStream, io::Error> {
{
let stream = match *self {
TcpStreamNewState::Waiting(ref s) => s,
TcpStreamNewState::Waiting(ref mut s) => s,
TcpStreamNewState::Error(_) => {
let e = match mem::replace(self, TcpStreamNewState::Empty) {
TcpStreamNewState::Error(e) => e,
Expand Down
45 changes: 13 additions & 32 deletions src/net/udp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,11 @@ impl UdpSocket {
///
/// This function will panic if called outside the context of a future's
/// task.
pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
pub fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
if let Async::NotReady = self.io.poll_write() {
return Err(io::ErrorKind::WouldBlock.into())
}

match self.io.get_ref().send(buf) {
Ok(n) => Ok(n),
Err(e) => {
Expand All @@ -107,10 +108,11 @@ impl UdpSocket {
///
/// This function will panic if called outside the context of a future's
/// task.
pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
pub fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if let Async::NotReady = self.io.poll_read() {
return Err(io::ErrorKind::WouldBlock.into())
}

match self.io.get_ref().recv(buf) {
Ok(n) => Ok(n),
Err(e) => {
Expand All @@ -122,43 +124,21 @@ impl UdpSocket {
}
}

/// Test whether this socket is ready to be read or not.
///
/// If the socket is *not* readable then the current task is scheduled to
/// get a notification when the socket does become readable.
///
/// # Panics
///
/// This function will panic if called outside the context of a future's
/// task.
pub fn poll_read(&self) -> Async<()> {
self.io.poll_read()
}

/// Test whether this socket is ready to be written to or not.
///
/// If the socket is *not* writable then the current task is scheduled to
/// get a notification when the socket does become writable.
///
/// # Panics
///
/// This function will panic if called outside the context of a future's
/// task.
pub fn poll_write(&self) -> Async<()> {
self.io.poll_write()
}

/// Sends data on the socket to the given address. On success, returns the
/// number of bytes written.
///
/// Address type can be any implementer of `ToSocketAddrs` trait. See its
/// documentation for concrete examples.
///
/// # Panics
///
/// This function will panic if called outside the context of a future's
/// task.
pub fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> {
pub fn send_to(&mut self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> {
if let Async::NotReady = self.io.poll_write() {
return Err(io::ErrorKind::WouldBlock.into())
}

match self.io.get_ref().send_to(buf, target) {
Ok(n) => Ok(n),
Err(e) => {
Expand Down Expand Up @@ -197,10 +177,11 @@ impl UdpSocket {
///
/// This function will panic if called outside the context of a future's
/// task.
pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
pub fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
if let Async::NotReady = self.io.poll_read() {
return Err(io::ErrorKind::WouldBlock.into())
}

match self.io.get_ref().recv_from(buf) {
Ok(n) => Ok(n),
Err(e) => {
Expand Down Expand Up @@ -423,8 +404,8 @@ impl<T> Future for SendDgram<T>

fn poll(&mut self) -> Poll<(UdpSocket, T), io::Error> {
{
let ref inner =
self.state.as_ref().expect("SendDgram polled after completion");
let ref mut inner =
self.state.as_mut().expect("SendDgram polled after completion");
let n = try_nb!(inner.socket.send_to(inner.buffer.as_ref(), &inner.addr));
if n != inner.buffer.as_ref().len() {
return Err(incomplete_write("failed to send entire message \
Expand Down
Loading

0 comments on commit 65cbfce

Please sign in to comment.