Skip to content

Commit

Permalink
Update to Mio v0.8
Browse files Browse the repository at this point in the history
The major breaking change in Mio v0.8 is TcpSocket type being removed.

Replacing Mio's TcpSocket we switch to the socket2 library which
provides a similar type Socket, as well as SockRef, which provide all
options TcpSocket provided (and more!).

Tokio's TcpSocket type is now backed by Socket2 instead of Mio's
TcpSocket. The main pitfall here is that socket2 isn't non-blocking by
default, which Mio obviously is. As a result we have to do potentially
blocking calls more carefully, specifically we need to handle
would-block-like errors when connecting the TcpSocket ourselves.

One benefit for this change is that adding more socket options to
TcpSocket is now merely a single function call away (in most cases
anyway).
  • Loading branch information
Thomasdezeeuw committed Feb 13, 2022
1 parent ac0f894 commit 8fb15da
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 50 deletions.
18 changes: 9 additions & 9 deletions tokio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,18 @@ macros = ["tokio-macros"]
net = [
"libc",
"mio/os-poll",
"mio/os-util",
"mio/tcp",
"mio/udp",
"mio/uds",
"mio/os-ext",
"mio/net",
"socket2",
"winapi/namedpipeapi",
]
process = [
"bytes",
"once_cell",
"libc",
"mio/os-poll",
"mio/os-util",
"mio/uds",
"mio/os-ext",
"mio/net",
"signal-hook-registry",
"winapi/threadpoollegacyapiset",
]
Expand All @@ -75,8 +74,8 @@ signal = [
"once_cell",
"libc",
"mio/os-poll",
"mio/uds",
"mio/os-util",
"mio/net",
"mio/os-ext",
"signal-hook-registry",
"winapi/consoleapi",
]
Expand All @@ -98,7 +97,8 @@ pin-project-lite = "0.2.0"
bytes = { version = "1.0.0", optional = true }
once_cell = { version = "1.5.2", optional = true }
memchr = { version = "2.2", optional = true }
mio = { version = "0.7.11", optional = true }
mio = { version = "0.8.0", optional = true }
socket2 = { version = "0.4.4", optional = true, features = [ "all" ] }
num_cpus = { version = "1.8.0", optional = true }
parking_lot = { version = "0.12.0", optional = true }

Expand Down
117 changes: 97 additions & 20 deletions tokio/src/net/tcp/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ cfg_net! {
/// [`socket2`]: https://docs.rs/socket2/
#[cfg_attr(docsrs, doc(alias = "connect_std"))]
pub struct TcpSocket {
inner: mio::net::TcpSocket,
inner: socket2::Socket,
}
}

Expand Down Expand Up @@ -119,8 +119,7 @@ impl TcpSocket {
/// }
/// ```
pub fn new_v4() -> io::Result<TcpSocket> {
let inner = mio::net::TcpSocket::new_v4()?;
Ok(TcpSocket { inner })
TcpSocket::new(socket2::Domain::IPV4)
}

/// Creates a new socket configured for IPv6.
Expand Down Expand Up @@ -153,7 +152,34 @@ impl TcpSocket {
/// }
/// ```
pub fn new_v6() -> io::Result<TcpSocket> {
let inner = mio::net::TcpSocket::new_v6()?;
TcpSocket::new(socket2::Domain::IPV6)
}

fn new(domain: socket2::Domain) -> io::Result<TcpSocket> {
let ty = socket2::Type::STREAM;
#[cfg(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "illumos",
target_os = "linux",
target_os = "netbsd",
target_os = "openbsd"
))]
let ty = ty.nonblocking();
let inner = socket2::Socket::new(domain, ty, Some(socket2::Protocol::TCP))?;
#[cfg(not(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "illumos",
target_os = "linux",
target_os = "netbsd",
target_os = "openbsd"
)))]
inner.set_nonblocking(true)?;
Ok(TcpSocket { inner })
}

Expand Down Expand Up @@ -184,7 +210,7 @@ impl TcpSocket {
/// }
/// ```
pub fn set_reuseaddr(&self, reuseaddr: bool) -> io::Result<()> {
self.inner.set_reuseaddr(reuseaddr)
self.inner.set_reuse_address(reuseaddr)
}

/// Retrieves the value set for `SO_REUSEADDR` on this socket.
Expand All @@ -210,7 +236,7 @@ impl TcpSocket {
/// }
/// ```
pub fn reuseaddr(&self) -> io::Result<bool> {
self.inner.get_reuseaddr()
self.inner.reuse_address()
}

/// Allows the socket to bind to an in-use port. Only available for unix systems
Expand Down Expand Up @@ -244,7 +270,7 @@ impl TcpSocket {
doc(cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos"))))
)]
pub fn set_reuseport(&self, reuseport: bool) -> io::Result<()> {
self.inner.set_reuseport(reuseport)
self.inner.set_reuse_port(reuseport)
}

/// Allows the socket to bind to an in-use port. Only available for unix systems
Expand Down Expand Up @@ -279,14 +305,14 @@ impl TcpSocket {
doc(cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos"))))
)]
pub fn reuseport(&self) -> io::Result<bool> {
self.inner.get_reuseport()
self.inner.reuse_port()
}

/// Sets the size of the TCP send buffer on this socket.
///
/// On most operating systems, this sets the `SO_SNDBUF` socket option.
pub fn set_send_buffer_size(&self, size: u32) -> io::Result<()> {
self.inner.set_send_buffer_size(size)
self.inner.set_send_buffer_size(size as usize)
}

/// Returns the size of the TCP send buffer for this socket.
Expand All @@ -313,14 +339,14 @@ impl TcpSocket {
///
/// [`set_send_buffer_size`]: #method.set_send_buffer_size
pub fn send_buffer_size(&self) -> io::Result<u32> {
self.inner.get_send_buffer_size()
self.inner.send_buffer_size().map(|n| n as u32)
}

/// Sets the size of the TCP receive buffer on this socket.
///
/// On most operating systems, this sets the `SO_RCVBUF` socket option.
pub fn set_recv_buffer_size(&self, size: u32) -> io::Result<()> {
self.inner.set_recv_buffer_size(size)
self.inner.set_recv_buffer_size(size as usize)
}

/// Returns the size of the TCP receive buffer for this socket.
Expand All @@ -347,7 +373,7 @@ impl TcpSocket {
///
/// [`set_recv_buffer_size`]: #method.set_recv_buffer_size
pub fn recv_buffer_size(&self) -> io::Result<u32> {
self.inner.get_recv_buffer_size()
self.inner.recv_buffer_size().map(|n| n as u32)
}

/// Sets the linger duration of this socket by setting the SO_LINGER option.
Expand All @@ -369,7 +395,7 @@ impl TcpSocket {
///
/// [`set_linger`]: TcpSocket::set_linger
pub fn linger(&self) -> io::Result<Option<Duration>> {
self.inner.get_linger()
self.inner.linger()
}

/// Gets the local address of this socket.
Expand All @@ -395,7 +421,7 @@ impl TcpSocket {
/// }
/// ```
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_localaddr()
self.inner.local_addr().and_then(convert_address)
}

/// Binds the socket to the given address.
Expand Down Expand Up @@ -427,7 +453,7 @@ impl TcpSocket {
/// }
/// ```
pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
self.inner.bind(addr)
self.inner.bind(&addr.into())
}

/// Establishes a TCP connection with a peer at the specified socket address.
Expand Down Expand Up @@ -463,7 +489,32 @@ impl TcpSocket {
/// }
/// ```
pub async fn connect(self, addr: SocketAddr) -> io::Result<TcpStream> {
let mio = self.inner.connect(addr)?;
if let Err(err) = self.inner.connect(&addr.into()) {
#[cfg(unix)]
if err.raw_os_error() != Some(libc::EINPROGRESS) {
return Err(err);
}
#[cfg(windows)]
if err.kind() != io::ErrorKind::WouldBlock {
return Err(err);
}
}
#[cfg(unix)]
let mio = {
use std::os::unix::io::{FromRawFd, IntoRawFd};

let raw_fd = self.inner.into_raw_fd();
unsafe { mio::net::TcpStream::from_raw_fd(raw_fd) }
};

#[cfg(windows)]
let mio = {
use std::os::windows::io::{FromRawSocket, IntoRawSocket};

let raw_socket = self.inner.into_raw_socket();
unsafe { mio::net::TcpStream::from_raw_socket(raw_socket) }
};

TcpStream::connect_mio(mio).await
}

Expand Down Expand Up @@ -503,7 +554,23 @@ impl TcpSocket {
/// }
/// ```
pub fn listen(self, backlog: u32) -> io::Result<TcpListener> {
let mio = self.inner.listen(backlog)?;
self.inner.listen(backlog as i32)?;
#[cfg(unix)]
let mio = {
use std::os::unix::io::{FromRawFd, IntoRawFd};

let raw_fd = self.inner.into_raw_fd();
unsafe { mio::net::TcpListener::from_raw_fd(raw_fd) }
};

#[cfg(windows)]
let mio = {
use std::os::windows::io::{FromRawSocket, IntoRawSocket};

let raw_socket = self.inner.into_raw_socket();
unsafe { mio::net::TcpListener::from_raw_socket(raw_socket) }
};

TcpListener::new(mio)
}

Expand All @@ -523,7 +590,7 @@ impl TcpSocket {
///
/// #[tokio::main]
/// async fn main() -> std::io::Result<()> {
///
///
/// let socket2_socket = Socket::new(Domain::IPV4, Type::STREAM, None)?;
///
/// let socket = TcpSocket::from_std_stream(socket2_socket.into());
Expand All @@ -550,6 +617,16 @@ impl TcpSocket {
}
}

fn convert_address(address: socket2::SockAddr) -> io::Result<SocketAddr> {
match address.as_socket() {
Some(address) => Ok(address),
None => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid address family (not IPv4 or IPv6)",
)),
}
}

impl fmt::Debug for TcpSocket {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
self.inner.fmt(fmt)
Expand All @@ -572,7 +649,7 @@ impl FromRawFd for TcpSocket {
/// The caller is responsible for ensuring that the socket is in
/// non-blocking mode.
unsafe fn from_raw_fd(fd: RawFd) -> TcpSocket {
let inner = mio::net::TcpSocket::from_raw_fd(fd);
let inner = socket2::Socket::from_raw_fd(fd);
TcpSocket { inner }
}
}
Expand Down Expand Up @@ -607,7 +684,7 @@ impl FromRawSocket for TcpSocket {
/// The caller is responsible for ensuring that the socket is in
/// non-blocking mode.
unsafe fn from_raw_socket(socket: RawSocket) -> TcpSocket {
let inner = mio::net::TcpSocket::from_raw_socket(socket);
let inner = socket2::Socket::from_raw_socket(socket);
TcpSocket { inner }
}
}
24 changes: 3 additions & 21 deletions tokio/src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ impl TcpStream {
/// // if the readiness event is a false positive.
/// match stream.try_read(&mut data) {
/// Ok(n) => {
/// println!("read {} bytes", n);
/// println!("read {} bytes", n);
/// }
/// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
/// continue;
Expand Down Expand Up @@ -1090,9 +1090,7 @@ impl TcpStream {
/// # }
/// ```
pub fn linger(&self) -> io::Result<Option<Duration>> {
let mio_socket = std::mem::ManuallyDrop::new(self.to_mio());

mio_socket.get_linger()
socket2::SockRef::from(self).linger()
}

/// Sets the linger duration of this socket by setting the SO_LINGER option.
Expand All @@ -1117,23 +1115,7 @@ impl TcpStream {
/// # }
/// ```
pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
let mio_socket = std::mem::ManuallyDrop::new(self.to_mio());

mio_socket.set_linger(dur)
}

fn to_mio(&self) -> mio::net::TcpSocket {
#[cfg(windows)]
{
use std::os::windows::io::{AsRawSocket, FromRawSocket};
unsafe { mio::net::TcpSocket::from_raw_socket(self.as_raw_socket()) }
}

#[cfg(unix)]
{
use std::os::unix::io::{AsRawFd, FromRawFd};
unsafe { mio::net::TcpSocket::from_raw_fd(self.as_raw_fd()) }
}
socket2::SockRef::from(self).set_linger(dur)
}

/// Gets the value of the `IP_TTL` option for this socket.
Expand Down

0 comments on commit 8fb15da

Please sign in to comment.