From 02e9be41f27daf822575444fdd2b3067433a5996 Mon Sep 17 00:00:00 2001 From: Thomas de Zeeuw Date: Tue, 28 Sep 2021 19:32:35 +0200 Subject: [PATCH] Remove TcpSocket type The socket2 crate provide all the functionality and more. Furthermore supporting all socket options is beyond the scope of Mio. The easier migration is to the socket2 crate, using the Socket or SockRef types. The migration for Tokio is tracked in https://github.com/tokio-rs/tokio/issues/4135. --- Cargo.toml | 2 +- src/net/mod.rs | 2 +- src/net/tcp/listener.rs | 18 +- src/net/tcp/mod.rs | 3 - src/net/tcp/socket.rs | 490 ---------------------------------------- src/net/tcp/stream.rs | 11 +- src/sys/shell/tcp.rs | 114 +--------- src/sys/unix/tcp.rs | 407 ++------------------------------- src/sys/windows/tcp.rs | 336 +++------------------------ tests/tcp.rs | 9 +- tests/tcp_socket.rs | 203 ----------------- tests/tcp_stream.rs | 17 +- tests/util/mod.rs | 47 ++++ 13 files changed, 133 insertions(+), 1526 deletions(-) delete mode 100644 src/net/tcp/socket.rs delete mode 100644 tests/tcp_socket.rs diff --git a/Cargo.toml b/Cargo.toml index cdd601047..15ab7adfe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,7 +47,7 @@ libc = "0.2.86" [target.'cfg(windows)'.dependencies] miow = "0.3.6" -winapi = { version = "0.3", features = ["winsock2", "mswsock", "mstcpip"] } +winapi = { version = "0.3", features = ["winsock2", "mswsock"] } ntapi = "0.3" [dev-dependencies] diff --git a/src/net/mod.rs b/src/net/mod.rs index 6ea5c6d77..ec4016710 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -8,7 +8,7 @@ //! [portability guidelines]: ../struct.Poll.html#portability mod tcp; -pub use self::tcp::{TcpKeepalive, TcpListener, TcpSocket, TcpStream}; +pub use self::tcp::{TcpListener, TcpStream}; mod udp; pub use self::udp::UdpSocket; diff --git a/src/net/tcp/listener.rs b/src/net/tcp/listener.rs index da276f3b6..21bffbaff 100644 --- a/src/net/tcp/listener.rs +++ b/src/net/tcp/listener.rs @@ -5,8 +5,11 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::{fmt, io}; -use super::{TcpSocket, TcpStream}; use crate::io_source::IoSource; +use crate::net::TcpStream; +#[cfg(unix)] +use crate::sys::tcp::set_reuseaddr; +use crate::sys::tcp::{bind, listen, new_for_addr}; use crate::{event, sys, Interest, Registry, Token}; /// A structure representing a socket server @@ -50,7 +53,11 @@ impl TcpListener { /// 3. Bind the socket to the specified address. /// 4. Calls `listen` on the socket to prepare it to receive new connections. pub fn bind(addr: SocketAddr) -> io::Result { - let socket = TcpSocket::new_for_addr(addr)?; + let socket = new_for_addr(addr)?; + #[cfg(unix)] + let listener = unsafe { TcpListener::from_raw_fd(socket) }; + #[cfg(windows)] + let listener = unsafe { TcpListener::from_raw_socket(socket as _) }; // On platforms with Berkeley-derived sockets, this allows to quickly // rebind a socket, without needing to wait for the OS to clean up the @@ -60,10 +67,11 @@ impl TcpListener { // which allows “socket hijacking”, so we explicitly don't set it here. // https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse #[cfg(not(windows))] - socket.set_reuseaddr(true)?; + set_reuseaddr(&listener.inner, true)?; - socket.bind(addr)?; - socket.listen(1024) + bind(&listener.inner, addr)?; + listen(&listener.inner, 1024)?; + Ok(listener) } /// Creates a new `TcpListener` from a standard `net::TcpListener`. diff --git a/src/net/tcp/mod.rs b/src/net/tcp/mod.rs index 7658bdfc4..94af5c10e 100644 --- a/src/net/tcp/mod.rs +++ b/src/net/tcp/mod.rs @@ -1,8 +1,5 @@ mod listener; pub use self::listener::TcpListener; -mod socket; -pub use self::socket::{TcpKeepalive, TcpSocket}; - mod stream; pub use self::stream::TcpStream; diff --git a/src/net/tcp/socket.rs b/src/net/tcp/socket.rs deleted file mode 100644 index 69fbacf68..000000000 --- a/src/net/tcp/socket.rs +++ /dev/null @@ -1,490 +0,0 @@ -use std::io; -use std::mem; -use std::net::SocketAddr; -#[cfg(unix)] -use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; -#[cfg(windows)] -use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; -use std::time::Duration; - -use crate::net::{TcpListener, TcpStream}; -use crate::sys; - -/// A non-blocking TCP socket used to configure a stream or listener. -/// -/// The `TcpSocket` type wraps the operating-system's socket handle. This type -/// is used to configure the socket before establishing a connection or start -/// listening for inbound connections. -/// -/// The socket will be closed when the value is dropped. -#[derive(Debug)] -pub struct TcpSocket { - sys: sys::tcp::TcpSocket, -} - -/// Configures a socket's TCP keepalive parameters. -#[derive(Debug, Default, Clone)] -pub struct TcpKeepalive { - pub(crate) time: Option, - #[cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", - target_os = "windows", - ))] - pub(crate) interval: Option, - #[cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", - ))] - pub(crate) retries: Option, -} - -impl TcpSocket { - /// Create a new IPv4 TCP socket. - /// - /// This calls `socket(2)`. - pub fn new_v4() -> io::Result { - sys::tcp::new_v4_socket().map(|sys| TcpSocket { sys }) - } - - /// Create a new IPv6 TCP socket. - /// - /// This calls `socket(2)`. - pub fn new_v6() -> io::Result { - sys::tcp::new_v6_socket().map(|sys| TcpSocket { sys }) - } - - pub(crate) fn new_for_addr(addr: SocketAddr) -> io::Result { - if addr.is_ipv4() { - TcpSocket::new_v4() - } else { - TcpSocket::new_v6() - } - } - - /// Bind `addr` to the TCP socket. - pub fn bind(&self, addr: SocketAddr) -> io::Result<()> { - sys::tcp::bind(self.sys, addr) - } - - /// Connect the socket to `addr`. - /// - /// This consumes the socket and performs the connect operation. Once the - /// connection completes, the socket is now a non-blocking `TcpStream` and - /// can be used as such. - pub fn connect(self, addr: SocketAddr) -> io::Result { - let stream = sys::tcp::connect(self.sys, addr)?; - - // Don't close the socket - mem::forget(self); - Ok(TcpStream::from_std(stream)) - } - - /// Listen for inbound connections, converting the socket to a - /// `TcpListener`. - pub fn listen(self, backlog: u32) -> io::Result { - let listener = sys::tcp::listen(self.sys, backlog)?; - - // Don't close the socket - mem::forget(self); - Ok(TcpListener::from_std(listener)) - } - - /// Sets the value of `SO_REUSEADDR` on this socket. - pub fn set_reuseaddr(&self, reuseaddr: bool) -> io::Result<()> { - sys::tcp::set_reuseaddr(self.sys, reuseaddr) - } - - /// Get the value of `SO_REUSEADDR` set on this socket. - pub fn get_reuseaddr(&self) -> io::Result { - sys::tcp::get_reuseaddr(self.sys) - } - - /// Sets the value of `SO_REUSEPORT` on this socket. - /// Only supported available in unix - #[cfg(all(unix, not(any(target_os = "solaris", target_os = "illumos"))))] - pub fn set_reuseport(&self, reuseport: bool) -> io::Result<()> { - sys::tcp::set_reuseport(self.sys, reuseport) - } - - /// Get the value of `SO_REUSEPORT` set on this socket. - /// Only supported available in unix - #[cfg(all(unix, not(any(target_os = "solaris", target_os = "illumos"))))] - pub fn get_reuseport(&self) -> io::Result { - sys::tcp::get_reuseport(self.sys) - } - - /// Sets the value of `SO_LINGER` on this socket. - pub fn set_linger(&self, dur: Option) -> io::Result<()> { - sys::tcp::set_linger(self.sys, dur) - } - - /// Gets the value of `SO_LINGER` on this socket - pub fn get_linger(&self) -> io::Result> { - sys::tcp::get_linger(self.sys) - } - - /// Sets the value of `SO_RCVBUF` on this socket. - pub fn set_recv_buffer_size(&self, size: u32) -> io::Result<()> { - sys::tcp::set_recv_buffer_size(self.sys, size) - } - - /// Get the value of `SO_RCVBUF` set on this socket. - /// - /// Note that if [`set_recv_buffer_size`] has been called on this socket - /// previously, the value returned by this function may not be the same as - /// the argument provided to `set_recv_buffer_size`. This is for the - /// following reasons: - /// - /// * Most operating systems have minimum and maximum allowed sizes for the - /// receive buffer, and will clamp the provided value if it is below the - /// minimum or above the maximum. The minimum and maximum buffer sizes are - /// OS-dependent. - /// * Linux will double the buffer size to account for internal bookkeeping - /// data, and returns the doubled value from `getsockopt(2)`. As per `man - /// 7 socket`: - /// > Sets or gets the maximum socket receive buffer in bytes. The - /// > kernel doubles this value (to allow space for bookkeeping - /// > overhead) when it is set using `setsockopt(2)`, and this doubled - /// > value is returned by `getsockopt(2)`. - /// - /// [`set_recv_buffer_size`]: #method.set_recv_buffer_size - pub fn get_recv_buffer_size(&self) -> io::Result { - sys::tcp::get_recv_buffer_size(self.sys) - } - - /// Sets the value of `SO_SNDBUF` on this socket. - pub fn set_send_buffer_size(&self, size: u32) -> io::Result<()> { - sys::tcp::set_send_buffer_size(self.sys, size) - } - - /// Get the value of `SO_SNDBUF` set on this socket. - /// - /// Note that if [`set_send_buffer_size`] has been called on this socket - /// previously, the value returned by this function may not be the same as - /// the argument provided to `set_send_buffer_size`. This is for the - /// following reasons: - /// - /// * Most operating systems have minimum and maximum allowed sizes for the - /// receive buffer, and will clamp the provided value if it is below the - /// minimum or above the maximum. The minimum and maximum buffer sizes are - /// OS-dependent. - /// * Linux will double the buffer size to account for internal bookkeeping - /// data, and returns the doubled value from `getsockopt(2)`. As per `man - /// 7 socket`: - /// > Sets or gets the maximum socket send buffer in bytes. The - /// > kernel doubles this value (to allow space for bookkeeping - /// > overhead) when it is set using `setsockopt(2)`, and this doubled - /// > value is returned by `getsockopt(2)`. - /// - /// [`set_send_buffer_size`]: #method.set_send_buffer_size - pub fn get_send_buffer_size(&self) -> io::Result { - sys::tcp::get_send_buffer_size(self.sys) - } - - /// Sets whether keepalive messages are enabled to be sent on this socket. - /// - /// This will set the `SO_KEEPALIVE` option on this socket. - pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> { - sys::tcp::set_keepalive(self.sys, keepalive) - } - - /// Returns whether or not TCP keepalive probes will be sent by this socket. - pub fn get_keepalive(&self) -> io::Result { - sys::tcp::get_keepalive(self.sys) - } - - /// Sets parameters configuring TCP keepalive probes for this socket. - /// - /// The supported parameters depend on the operating system, and are - /// configured using the [`TcpKeepalive`] struct. At a minimum, all systems - /// support configuring the [keepalive time]: the time after which the OS - /// will start sending keepalive messages on an idle connection. - /// - /// # Notes - /// - /// * This will enable TCP keepalive on this socket, if it is not already - /// enabled. - /// * On some platforms, such as Windows, any keepalive parameters *not* - /// configured by the `TcpKeepalive` struct passed to this function may be - /// overwritten with their default values. Therefore, this function should - /// either only be called once per socket, or the same parameters should - /// be passed every time it is called. - /// - /// # Examples - #[cfg_attr(feature = "os-poll", doc = "```")] - #[cfg_attr(not(feature = "os-poll"), doc = "```ignore")] - /// use mio::net::{TcpSocket, TcpKeepalive}; - /// use std::time::Duration; - /// - /// # fn main() -> Result<(), std::io::Error> { - /// let socket = TcpSocket::new_v6()?; - /// let keepalive = TcpKeepalive::default() - /// .with_time(Duration::from_secs(4)); - /// // Depending on the target operating system, we may also be able to - /// // configure the keepalive probe interval and/or the number of retries - /// // here as well. - /// - /// socket.set_keepalive_params(keepalive)?; - /// # Ok(()) } - /// ``` - /// - /// [`TcpKeepalive`]: ../struct.TcpKeepalive.html - /// [keepalive time]: ../struct.TcpKeepalive.html#method.with_time - pub fn set_keepalive_params(&self, keepalive: TcpKeepalive) -> io::Result<()> { - self.set_keepalive(true)?; - sys::tcp::set_keepalive_params(self.sys, keepalive) - } - - /// Returns the amount of time after which TCP keepalive probes will be sent - /// on idle connections. - /// - /// If `None`, then keepalive messages are disabled. - /// - /// This returns the value of `SO_KEEPALIVE` + `IPPROTO_TCP` on OpenBSD, - /// NetBSD, and Haiku, `TCP_KEEPALIVE` on macOS and iOS, and `TCP_KEEPIDLE` - /// on all other Unix operating systems. On Windows, it is not possible to - /// access the value of TCP keepalive parameters after they have been set. - /// - /// Some platforms specify this value in seconds, so sub-second - /// specifications may be omitted. - #[cfg_attr(docsrs, doc(cfg(not(target_os = "windows"))))] - #[cfg(not(target_os = "windows"))] - pub fn get_keepalive_time(&self) -> io::Result> { - sys::tcp::get_keepalive_time(self.sys) - } - - /// Returns the time interval between TCP keepalive probes, if TCP keepalive is - /// enabled on this socket. - /// - /// If `None`, then keepalive messages are disabled. - /// - /// This returns the value of `TCP_KEEPINTVL` on supported Unix operating - /// systems. On Windows, it is not possible to access the value of TCP - /// keepalive parameters after they have been set.. - /// - /// Some platforms specify this value in seconds, so sub-second - /// specifications may be omitted. - #[cfg_attr( - docsrs, - doc(cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", - ))) - )] - #[cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", - ))] - pub fn get_keepalive_interval(&self) -> io::Result> { - sys::tcp::get_keepalive_interval(self.sys) - } - - /// Returns the maximum number of TCP keepalive probes that will be sent before - /// dropping a connection, if TCP keepalive is enabled on this socket. - /// - /// If `None`, then keepalive messages are disabled. - /// - /// This returns the value of `TCP_KEEPCNT` on Unix operating systems that - /// support this option. On Windows, it is not possible to access the value - /// of TCP keepalive parameters after they have been set. - #[cfg_attr( - docsrs, - doc(cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", - ))) - )] - #[cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", - ))] - pub fn get_keepalive_retries(&self) -> io::Result> { - sys::tcp::get_keepalive_retries(self.sys) - } - - /// Returns the local address of this socket - /// - /// Will return `Err` result in windows if called before calling `bind` - pub fn get_localaddr(&self) -> io::Result { - sys::tcp::get_localaddr(self.sys) - } -} - -impl Drop for TcpSocket { - fn drop(&mut self) { - sys::tcp::close(self.sys); - } -} - -#[cfg(unix)] -impl IntoRawFd for TcpSocket { - fn into_raw_fd(self) -> RawFd { - let ret = self.sys; - // Avoid closing the socket - mem::forget(self); - ret - } -} - -#[cfg(unix)] -impl AsRawFd for TcpSocket { - fn as_raw_fd(&self) -> RawFd { - self.sys - } -} - -#[cfg(unix)] -impl FromRawFd for TcpSocket { - /// Converts a `RawFd` to a `TcpSocket`. - /// - /// # Notes - /// - /// The caller is responsible for ensuring that the socket is in - /// non-blocking mode. - unsafe fn from_raw_fd(fd: RawFd) -> TcpSocket { - TcpSocket { sys: fd } - } -} - -#[cfg(windows)] -impl IntoRawSocket for TcpSocket { - fn into_raw_socket(self) -> RawSocket { - // The winapi crate defines `SOCKET` as `usize`. The Rust std - // conditionally defines `RawSocket` as a fixed size unsigned integer - // matching the pointer width. These end up being the same type but we - // must cast between them. - let ret = self.sys as RawSocket; - - // Avoid closing the socket - mem::forget(self); - - ret - } -} - -#[cfg(windows)] -impl AsRawSocket for TcpSocket { - fn as_raw_socket(&self) -> RawSocket { - self.sys as RawSocket - } -} - -#[cfg(windows)] -impl FromRawSocket for TcpSocket { - /// Converts a `RawSocket` to a `TcpSocket`. - /// - /// # Notes - /// - /// The caller is responsible for ensuring that the socket is in - /// non-blocking mode. - unsafe fn from_raw_socket(socket: RawSocket) -> TcpSocket { - TcpSocket { - sys: socket as sys::tcp::TcpSocket, - } - } -} - -impl TcpKeepalive { - // Sets the amount of time after which TCP keepalive probes will be sent - /// on idle connections. - /// - /// This will set the value of `SO_KEEPALIVE` + `IPPROTO_TCP` on OpenBSD, - /// NetBSD, and Haiku, `TCP_KEEPALIVE` on macOS and iOS, and `TCP_KEEPIDLE` - /// on all other Unix operating systems. On Windows, this sets the value of - /// the `tcp_keepalive` struct's `keepalivetime` field. - /// - /// Some platforms specify this value in seconds, so sub-second - /// specifications may be omitted. - pub fn with_time(self, time: Duration) -> Self { - Self { - time: Some(time), - ..self - } - } - - /// Sets the time interval between TCP keepalive probes. - /// This sets the value of `TCP_KEEPINTVL` on supported Unix operating - /// systems. On Windows, this sets the value of the `tcp_keepalive` struct's - /// `keepaliveinterval` field. - /// - /// Some platforms specify this value in seconds, so sub-second - /// specifications may be omitted. - #[cfg_attr( - docsrs, - doc(cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", - target_os = "windows" - ))) - )] - #[cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", - target_os = "windows" - ))] - pub fn with_interval(self, interval: Duration) -> Self { - Self { - interval: Some(interval), - ..self - } - } - - /// Sets the maximum number of TCP keepalive probes that will be sent before - /// dropping a connection, if TCP keepalive is enabled on this socket. - /// - /// This will set the value of `TCP_KEEPCNT` on Unix operating systems that - /// support this option. - #[cfg_attr( - docsrs, - doc(cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", - ))) - )] - #[cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", - ))] - pub fn with_retries(self, retries: u32) -> Self { - Self { - retries: Some(retries), - ..self - } - } - - /// Returns a new, empty set of TCP keepalive parameters. - pub fn new() -> Self { - Self::default() - } -} diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index 72bfdebc4..ecc850fec 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -7,7 +7,7 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use crate::io_source::IoSource; -use crate::net::TcpSocket; +use crate::sys::tcp::{connect, new_for_addr}; use crate::{event, Interest, Registry, Token}; /// A non-blocking TCP stream between a local socket and a remote socket. @@ -50,8 +50,13 @@ impl TcpStream { /// Create a new TCP stream and issue a non-blocking connect to the /// specified address. pub fn connect(addr: SocketAddr) -> io::Result { - let socket = TcpSocket::new_for_addr(addr)?; - socket.connect(addr) + let socket = new_for_addr(addr)?; + #[cfg(unix)] + let stream = unsafe { TcpStream::from_raw_fd(socket) }; + #[cfg(windows)] + let stream = unsafe { TcpStream::from_raw_socket(socket as _) }; + connect(&stream.inner, addr)?; + Ok(stream) } /// Creates a new `TcpStream` from a standard `net::TcpStream`. diff --git a/src/sys/shell/tcp.rs b/src/sys/shell/tcp.rs index 0ed225f71..60dfe70f6 100644 --- a/src/sys/shell/tcp.rs +++ b/src/sys/shell/tcp.rs @@ -1,127 +1,27 @@ -use crate::net::TcpKeepalive; use std::io; use std::net::{self, SocketAddr}; -use std::time::Duration; -pub(crate) type TcpSocket = i32; - -pub(crate) fn new_v4_socket() -> io::Result { - os_required!(); -} - -pub(crate) fn new_v6_socket() -> io::Result { - os_required!(); -} - -pub(crate) fn bind(_socket: TcpSocket, _addr: SocketAddr) -> io::Result<()> { - os_required!(); -} - -pub(crate) fn connect(_: TcpSocket, _addr: SocketAddr) -> io::Result { - os_required!(); -} - -pub(crate) fn listen(_: TcpSocket, _: u32) -> io::Result { +pub(crate) fn new_for_addr(_: SocketAddr) -> io::Result { os_required!(); } -pub(crate) fn close(_: TcpSocket) { +pub(crate) fn bind(_: &net::TcpListener, _: SocketAddr) -> io::Result<()> { os_required!(); } -pub(crate) fn set_reuseaddr(_: TcpSocket, _: bool) -> io::Result<()> { +pub(crate) fn connect(_: &net::TcpStream, _: SocketAddr) -> io::Result<()> { os_required!(); } -pub(crate) fn get_reuseaddr(_: TcpSocket) -> io::Result { +pub(crate) fn listen(_: &net::TcpListener, _: u32) -> io::Result<()> { os_required!(); } -#[cfg(all(unix, not(any(target_os = "solaris", target_os = "illumos"))))] -pub(crate) fn set_reuseport(_: TcpSocket, _: bool) -> io::Result<()> { - os_required!(); -} - -#[cfg(all(unix, not(any(target_os = "solaris", target_os = "illumos"))))] -pub(crate) fn get_reuseport(_: TcpSocket) -> io::Result { - os_required!(); -} - -pub(crate) fn set_linger(_: TcpSocket, _: Option) -> io::Result<()> { - os_required!(); -} - -pub(crate) fn get_linger(_: TcpSocket) -> io::Result> { - os_required!(); -} - -pub(crate) fn set_recv_buffer_size(_: TcpSocket, _: u32) -> io::Result<()> { - os_required!(); -} - -pub(crate) fn get_recv_buffer_size(_: TcpSocket) -> io::Result { - os_required!(); -} - -pub(crate) fn set_send_buffer_size(_: TcpSocket, _: u32) -> io::Result<()> { - os_required!(); -} - -pub(crate) fn get_send_buffer_size(_: TcpSocket) -> io::Result { - os_required!(); -} - -pub(crate) fn set_keepalive(_: TcpSocket, _: bool) -> io::Result<()> { - os_required!(); -} - -pub(crate) fn get_keepalive(_: TcpSocket) -> io::Result { - os_required!(); -} - -pub(crate) fn set_keepalive_params(_: TcpSocket, _: TcpKeepalive) -> io::Result<()> { - os_required!() -} - -#[cfg(any( - target_os = "android", - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", - target_os = "solaris", -))] -pub(crate) fn get_keepalive_time(_: TcpSocket) -> io::Result> { - os_required!() -} - -#[cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", -))] -pub(crate) fn get_keepalive_interval(_: TcpSocket) -> io::Result> { - os_required!() -} - -#[cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", -))] -pub(crate) fn get_keepalive_retries(_: TcpSocket) -> io::Result> { - os_required!() -} - -pub fn accept(_: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> { +#[cfg(unix)] +pub(crate) fn set_reuseaddr(_: &net::TcpListener, _: bool) -> io::Result<()> { os_required!(); } -pub(crate) fn get_localaddr(_: TcpSocket) -> io::Result { +pub(crate) fn accept(_: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> { os_required!(); } diff --git a/src/sys/unix/tcp.rs b/src/sys/unix/tcp.rs index 8a8bcb606..f26bb303d 100644 --- a/src/sys/unix/tcp.rs +++ b/src/sys/unix/tcp.rs @@ -1,422 +1,57 @@ use std::convert::TryInto; use std::io; -use std::mem; use std::mem::{size_of, MaybeUninit}; use std::net::{self, SocketAddr}; use std::os::unix::io::{AsRawFd, FromRawFd}; -use std::time::Duration; -use crate::net::TcpKeepalive; use crate::sys::unix::net::{new_socket, socket_addr, to_socket_addr}; -#[cfg(any(target_os = "openbsd", target_os = "netbsd"))] -use libc::SO_KEEPALIVE as KEEPALIVE_TIME; -#[cfg(any(target_os = "macos", target_os = "ios"))] -use libc::TCP_KEEPALIVE as KEEPALIVE_TIME; -#[cfg(not(any( - target_os = "macos", - target_os = "ios", - target_os = "openbsd", - target_os = "netbsd", -)))] -use libc::TCP_KEEPIDLE as KEEPALIVE_TIME; -pub type TcpSocket = libc::c_int; - -pub(crate) fn new_v4_socket() -> io::Result { - new_socket(libc::AF_INET, libc::SOCK_STREAM) -} - -pub(crate) fn new_v6_socket() -> io::Result { - new_socket(libc::AF_INET6, libc::SOCK_STREAM) +pub(crate) fn new_for_addr(address: SocketAddr) -> io::Result { + let domain = match address { + SocketAddr::V4(_) => libc::AF_INET, + SocketAddr::V6(_) => libc::AF_INET6, + }; + new_socket(domain, libc::SOCK_STREAM) } -pub(crate) fn bind(socket: TcpSocket, addr: SocketAddr) -> io::Result<()> { +pub(crate) fn bind(socket: &net::TcpListener, addr: SocketAddr) -> io::Result<()> { let (raw_addr, raw_addr_length) = socket_addr(&addr); - syscall!(bind(socket, raw_addr.as_ptr(), raw_addr_length))?; + syscall!(bind(socket.as_raw_fd(), raw_addr.as_ptr(), raw_addr_length))?; Ok(()) } -pub(crate) fn connect(socket: TcpSocket, addr: SocketAddr) -> io::Result { +pub(crate) fn connect(socket: &net::TcpStream, addr: SocketAddr) -> io::Result<()> { let (raw_addr, raw_addr_length) = socket_addr(&addr); - match syscall!(connect(socket, raw_addr.as_ptr(), raw_addr_length)) { + match syscall!(connect( + socket.as_raw_fd(), + raw_addr.as_ptr(), + raw_addr_length + )) { Err(err) if err.raw_os_error() != Some(libc::EINPROGRESS) => Err(err), - _ => Ok(unsafe { net::TcpStream::from_raw_fd(socket) }), + _ => Ok(()), } } -pub(crate) fn listen(socket: TcpSocket, backlog: u32) -> io::Result { +pub(crate) fn listen(socket: &net::TcpListener, backlog: u32) -> io::Result<()> { let backlog = backlog.try_into().unwrap_or(i32::max_value()); - syscall!(listen(socket, backlog))?; - Ok(unsafe { net::TcpListener::from_raw_fd(socket) }) -} - -pub(crate) fn close(socket: TcpSocket) { - let _ = unsafe { net::TcpStream::from_raw_fd(socket) }; + syscall!(listen(socket.as_raw_fd(), backlog))?; + Ok(()) } -pub(crate) fn set_reuseaddr(socket: TcpSocket, reuseaddr: bool) -> io::Result<()> { +pub(crate) fn set_reuseaddr(socket: &net::TcpListener, reuseaddr: bool) -> io::Result<()> { let val: libc::c_int = if reuseaddr { 1 } else { 0 }; syscall!(setsockopt( - socket, + socket.as_raw_fd(), libc::SOL_SOCKET, libc::SO_REUSEADDR, &val as *const libc::c_int as *const libc::c_void, size_of::() as libc::socklen_t, - )) - .map(|_| ()) -} - -pub(crate) fn get_reuseaddr(socket: TcpSocket) -> io::Result { - let mut optval: libc::c_int = 0; - let mut optlen = mem::size_of::() as libc::socklen_t; - - syscall!(getsockopt( - socket, - libc::SOL_SOCKET, - libc::SO_REUSEADDR, - &mut optval as *mut _ as *mut _, - &mut optlen, - ))?; - - Ok(optval != 0) -} - -#[cfg(all(unix, not(any(target_os = "solaris", target_os = "illumos"))))] -pub(crate) fn set_reuseport(socket: TcpSocket, reuseport: bool) -> io::Result<()> { - let val: libc::c_int = if reuseport { 1 } else { 0 }; - - syscall!(setsockopt( - socket, - libc::SOL_SOCKET, - libc::SO_REUSEPORT, - &val as *const libc::c_int as *const libc::c_void, - size_of::() as libc::socklen_t, - )) - .map(|_| ()) -} - -#[cfg(all(unix, not(any(target_os = "solaris", target_os = "illumos"))))] -pub(crate) fn get_reuseport(socket: TcpSocket) -> io::Result { - let mut optval: libc::c_int = 0; - let mut optlen = mem::size_of::() as libc::socklen_t; - - syscall!(getsockopt( - socket, - libc::SOL_SOCKET, - libc::SO_REUSEPORT, - &mut optval as *mut _ as *mut _, - &mut optlen, - ))?; - - Ok(optval != 0) -} - -pub(crate) fn get_localaddr(socket: TcpSocket) -> io::Result { - let mut addr: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; - let mut length = size_of::() as libc::socklen_t; - - syscall!(getsockname( - socket, - &mut addr as *mut _ as *mut _, - &mut length - ))?; - - unsafe { to_socket_addr(&addr) } -} - -pub(crate) fn set_linger(socket: TcpSocket, dur: Option) -> io::Result<()> { - let val: libc::linger = libc::linger { - l_onoff: if dur.is_some() { 1 } else { 0 }, - l_linger: dur - .map(|dur| dur.as_secs() as libc::c_int) - .unwrap_or_default(), - }; - syscall!(setsockopt( - socket, - libc::SOL_SOCKET, - #[cfg(target_vendor = "apple")] - libc::SO_LINGER_SEC, - #[cfg(not(target_vendor = "apple"))] - libc::SO_LINGER, - &val as *const libc::linger as *const libc::c_void, - size_of::() as libc::socklen_t, - )) - .map(|_| ()) -} - -pub(crate) fn get_linger(socket: TcpSocket) -> io::Result> { - let mut val: libc::linger = unsafe { std::mem::zeroed() }; - let mut len = mem::size_of::() as libc::socklen_t; - - syscall!(getsockopt( - socket, - libc::SOL_SOCKET, - #[cfg(target_vendor = "apple")] - libc::SO_LINGER_SEC, - #[cfg(not(target_vendor = "apple"))] - libc::SO_LINGER, - &mut val as *mut _ as *mut _, - &mut len, - ))?; - - if val.l_onoff == 0 { - Ok(None) - } else { - Ok(Some(Duration::from_secs(val.l_linger as u64))) - } -} - -pub(crate) fn set_recv_buffer_size(socket: TcpSocket, size: u32) -> io::Result<()> { - let size = size.try_into().ok().unwrap_or_else(i32::max_value); - syscall!(setsockopt( - socket, - libc::SOL_SOCKET, - libc::SO_RCVBUF, - &size as *const _ as *const libc::c_void, - size_of::() as libc::socklen_t - )) - .map(|_| ()) -} - -pub(crate) fn get_recv_buffer_size(socket: TcpSocket) -> io::Result { - let mut optval: libc::c_int = 0; - let mut optlen = size_of::() as libc::socklen_t; - syscall!(getsockopt( - socket, - libc::SOL_SOCKET, - libc::SO_RCVBUF, - &mut optval as *mut _ as *mut _, - &mut optlen, - ))?; - - Ok(optval as u32) -} - -pub(crate) fn set_send_buffer_size(socket: TcpSocket, size: u32) -> io::Result<()> { - let size = size.try_into().ok().unwrap_or_else(i32::max_value); - syscall!(setsockopt( - socket, - libc::SOL_SOCKET, - libc::SO_SNDBUF, - &size as *const _ as *const libc::c_void, - size_of::() as libc::socklen_t - )) - .map(|_| ()) -} - -pub(crate) fn get_send_buffer_size(socket: TcpSocket) -> io::Result { - let mut optval: libc::c_int = 0; - let mut optlen = size_of::() as libc::socklen_t; - - syscall!(getsockopt( - socket, - libc::SOL_SOCKET, - libc::SO_SNDBUF, - &mut optval as *mut _ as *mut _, - &mut optlen, - ))?; - - Ok(optval as u32) -} - -pub(crate) fn set_keepalive(socket: TcpSocket, keepalive: bool) -> io::Result<()> { - let val: libc::c_int = if keepalive { 1 } else { 0 }; - syscall!(setsockopt( - socket, - libc::SOL_SOCKET, - libc::SO_KEEPALIVE, - &val as *const _ as *const libc::c_void, - size_of::() as libc::socklen_t - )) - .map(|_| ()) -} - -pub(crate) fn get_keepalive(socket: TcpSocket) -> io::Result { - let mut optval: libc::c_int = 0; - let mut optlen = mem::size_of::() as libc::socklen_t; - - syscall!(getsockopt( - socket, - libc::SOL_SOCKET, - libc::SO_KEEPALIVE, - &mut optval as *mut _ as *mut _, - &mut optlen, ))?; - - Ok(optval != 0) -} - -pub(crate) fn set_keepalive_params(socket: TcpSocket, keepalive: TcpKeepalive) -> io::Result<()> { - if let Some(dur) = keepalive.time { - set_keepalive_time(socket, dur)?; - } - - #[cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", - ))] - { - if let Some(dur) = keepalive.interval { - set_keepalive_interval(socket, dur)?; - } - - if let Some(retries) = keepalive.retries { - set_keepalive_retries(socket, retries)?; - } - } - Ok(()) } -fn set_keepalive_time(socket: TcpSocket, time: Duration) -> io::Result<()> { - let time_secs = time - .as_secs() - .try_into() - .ok() - .unwrap_or_else(i32::max_value); - syscall!(setsockopt( - socket, - libc::IPPROTO_TCP, - KEEPALIVE_TIME, - &(time_secs as libc::c_int) as *const _ as *const libc::c_void, - size_of::() as libc::socklen_t - )) - .map(|_| ()) -} - -pub(crate) fn get_keepalive_time(socket: TcpSocket) -> io::Result> { - if !get_keepalive(socket)? { - return Ok(None); - } - - let mut optval: libc::c_int = 0; - let mut optlen = mem::size_of::() as libc::socklen_t; - syscall!(getsockopt( - socket, - libc::IPPROTO_TCP, - KEEPALIVE_TIME, - &mut optval as *mut _ as *mut _, - &mut optlen, - ))?; - - Ok(Some(Duration::from_secs(optval as u64))) -} - -/// Linux, FreeBSD, and NetBSD support setting the keepalive interval via -/// `TCP_KEEPINTVL`. -/// See: -/// - https://man7.org/linux/man-pages/man7/tcp.7.html -/// - https://www.freebsd.org/cgi/man.cgi?query=tcp#end -/// - http://man.netbsd.org/tcp.4#DESCRIPTION -/// -/// OpenBSD does not: -/// https://man.openbsd.org/tcp -#[cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", -))] -fn set_keepalive_interval(socket: TcpSocket, interval: Duration) -> io::Result<()> { - let interval_secs = interval - .as_secs() - .try_into() - .ok() - .unwrap_or_else(i32::max_value); - syscall!(setsockopt( - socket, - libc::IPPROTO_TCP, - libc::TCP_KEEPINTVL, - &(interval_secs as libc::c_int) as *const _ as *const libc::c_void, - size_of::() as libc::socklen_t - )) - .map(|_| ()) -} - -#[cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", -))] -pub(crate) fn get_keepalive_interval(socket: TcpSocket) -> io::Result> { - if !get_keepalive(socket)? { - return Ok(None); - } - - let mut optval: libc::c_int = 0; - let mut optlen = mem::size_of::() as libc::socklen_t; - syscall!(getsockopt( - socket, - libc::IPPROTO_TCP, - libc::TCP_KEEPINTVL, - &mut optval as *mut _ as *mut _, - &mut optlen, - ))?; - - Ok(Some(Duration::from_secs(optval as u64))) -} - -/// Linux, macOS/iOS, FreeBSD, and NetBSD support setting the number of TCP -/// keepalive retries via `TCP_KEEPCNT`. -/// See: -/// - https://man7.org/linux/man-pages/man7/tcp.7.html -/// - https://www.freebsd.org/cgi/man.cgi?query=tcp#end -/// - http://man.netbsd.org/tcp.4#DESCRIPTION -/// -/// OpenBSD does not: -/// https://man.openbsd.org/tcp -#[cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", -))] -fn set_keepalive_retries(socket: TcpSocket, retries: u32) -> io::Result<()> { - let retries = retries.try_into().ok().unwrap_or_else(i32::max_value); - syscall!(setsockopt( - socket, - libc::IPPROTO_TCP, - libc::TCP_KEEPCNT, - &(retries as libc::c_int) as *const _ as *const libc::c_void, - size_of::() as libc::socklen_t - )) - .map(|_| ()) -} - -#[cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", -))] -pub(crate) fn get_keepalive_retries(socket: TcpSocket) -> io::Result> { - if !get_keepalive(socket)? { - return Ok(None); - } - - let mut optval: libc::c_int = 0; - let mut optlen = mem::size_of::() as libc::socklen_t; - syscall!(getsockopt( - socket, - libc::IPPROTO_TCP, - libc::TCP_KEEPCNT, - &mut optval as *mut _ as *mut _, - &mut optlen, - ))?; - - Ok(Some(optval as u32)) -} - -pub fn accept(listener: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> { +pub(crate) fn accept(listener: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> { let mut addr: MaybeUninit = MaybeUninit::uninit(); let mut length = size_of::() as libc::socklen_t; diff --git a/src/sys/windows/tcp.rs b/src/sys/windows/tcp.rs index 442f49517..b3f05aec6 100644 --- a/src/sys/windows/tcp.rs +++ b/src/sys/windows/tcp.rs @@ -1,345 +1,67 @@ -use std::convert::TryInto; use std::io; -use std::mem::size_of; -use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -use std::os::windows::io::FromRawSocket; -use std::os::windows::raw::SOCKET as StdSocket; -use std::ptr; -use std::time::Duration; // winapi uses usize, stdlib uses u32/u64. +use std::net::{self, SocketAddr}; +use std::os::windows::io::AsRawSocket; -use winapi::ctypes::{c_char, c_int, c_ulong, c_ushort}; -use winapi::shared::mstcpip; -use winapi::shared::ws2def::{AF_INET, AF_INET6, SOCKADDR_IN, SOCKADDR_STORAGE}; -use winapi::shared::ws2ipdef::SOCKADDR_IN6_LH; +use winapi::um::winsock2::{self, PF_INET, PF_INET6, SOCKET, SOCKET_ERROR, SOCK_STREAM}; -use winapi::shared::minwindef::{BOOL, DWORD, FALSE, LPDWORD, LPVOID, TRUE}; -use winapi::um::winsock2::{ - self, closesocket, getsockname, getsockopt, linger, setsockopt, WSAIoctl, LPWSAOVERLAPPED, - PF_INET, PF_INET6, SOCKET, SOCKET_ERROR, SOCK_STREAM, SOL_SOCKET, SO_KEEPALIVE, SO_LINGER, - SO_RCVBUF, SO_REUSEADDR, SO_SNDBUF, -}; - -use crate::net::TcpKeepalive; use crate::sys::windows::net::{init, new_socket, socket_addr}; -pub(crate) type TcpSocket = SOCKET; - -pub(crate) fn new_v4_socket() -> io::Result { - init(); - new_socket(PF_INET, SOCK_STREAM) -} - -pub(crate) fn new_v6_socket() -> io::Result { +pub(crate) fn new_for_addr(address: SocketAddr) -> io::Result { init(); - new_socket(PF_INET6, SOCK_STREAM) + let domain = match address { + SocketAddr::V4(_) => PF_INET, + SocketAddr::V6(_) => PF_INET6, + }; + new_socket(domain, SOCK_STREAM) } -pub(crate) fn bind(socket: TcpSocket, addr: SocketAddr) -> io::Result<()> { +pub(crate) fn bind(socket: &net::TcpListener, addr: SocketAddr) -> io::Result<()> { use winsock2::bind; let (raw_addr, raw_addr_length) = socket_addr(&addr); syscall!( - bind(socket, raw_addr.as_ptr(), raw_addr_length), + bind( + socket.as_raw_socket() as _, + raw_addr.as_ptr(), + raw_addr_length + ), PartialEq::eq, SOCKET_ERROR )?; Ok(()) } -pub(crate) fn connect(socket: TcpSocket, addr: SocketAddr) -> io::Result { +pub(crate) fn connect(socket: &net::TcpStream, addr: SocketAddr) -> io::Result<()> { use winsock2::connect; let (raw_addr, raw_addr_length) = socket_addr(&addr); - let res = syscall!( - connect(socket, raw_addr.as_ptr(), raw_addr_length), + connect( + socket.as_raw_socket() as _, + raw_addr.as_ptr(), + raw_addr_length + ), PartialEq::eq, SOCKET_ERROR ); match res { Err(err) if err.kind() != io::ErrorKind::WouldBlock => Err(err), - _ => Ok(unsafe { net::TcpStream::from_raw_socket(socket as StdSocket) }), + _ => Ok(()), } } -pub(crate) fn listen(socket: TcpSocket, backlog: u32) -> io::Result { +pub(crate) fn listen(socket: &net::TcpListener, backlog: u32) -> io::Result<()> { use std::convert::TryInto; use winsock2::listen; let backlog = backlog.try_into().unwrap_or(i32::max_value()); - syscall!(listen(socket, backlog), PartialEq::eq, SOCKET_ERROR)?; - Ok(unsafe { net::TcpListener::from_raw_socket(socket as StdSocket) }) -} - -pub(crate) fn close(socket: TcpSocket) { - let _ = unsafe { closesocket(socket) }; -} - -pub(crate) fn set_reuseaddr(socket: TcpSocket, reuseaddr: bool) -> io::Result<()> { - let val: BOOL = if reuseaddr { TRUE } else { FALSE }; - - match unsafe { - setsockopt( - socket, - SOL_SOCKET, - SO_REUSEADDR, - &val as *const _ as *const c_char, - size_of::() as c_int, - ) - } { - SOCKET_ERROR => Err(io::Error::last_os_error()), - _ => Ok(()), - } -} - -pub(crate) fn get_reuseaddr(socket: TcpSocket) -> io::Result { - let mut optval: c_char = 0; - let mut optlen = size_of::() as c_int; - - match unsafe { - getsockopt( - socket, - SOL_SOCKET, - SO_REUSEADDR, - &mut optval as *mut _ as *mut _, - &mut optlen, - ) - } { - SOCKET_ERROR => Err(io::Error::last_os_error()), - _ => Ok(optval != 0), - } -} - -pub(crate) fn get_localaddr(socket: TcpSocket) -> io::Result { - let mut storage: SOCKADDR_STORAGE = unsafe { std::mem::zeroed() }; - let mut length = std::mem::size_of_val(&storage) as c_int; - - match unsafe { getsockname(socket, &mut storage as *mut _ as *mut _, &mut length) } { - SOCKET_ERROR => Err(io::Error::last_os_error()), - _ => { - if storage.ss_family as c_int == AF_INET { - // Safety: if the ss_family field is AF_INET then storage must be a sockaddr_in. - let addr: &SOCKADDR_IN = unsafe { &*(&storage as *const _ as *const SOCKADDR_IN) }; - let ip_bytes = unsafe { addr.sin_addr.S_un.S_un_b() }; - let ip = - Ipv4Addr::from([ip_bytes.s_b1, ip_bytes.s_b2, ip_bytes.s_b3, ip_bytes.s_b4]); - let port = u16::from_be(addr.sin_port); - Ok(SocketAddr::V4(SocketAddrV4::new(ip, port))) - } else if storage.ss_family as c_int == AF_INET6 { - // Safety: if the ss_family field is AF_INET6 then storage must be a sockaddr_in6. - let addr: &SOCKADDR_IN6_LH = - unsafe { &*(&storage as *const _ as *const SOCKADDR_IN6_LH) }; - let ip = Ipv6Addr::from(*unsafe { addr.sin6_addr.u.Byte() }); - let port = u16::from_be(addr.sin6_port); - let scope_id = unsafe { *addr.u.sin6_scope_id() }; - Ok(SocketAddr::V6(SocketAddrV6::new( - ip, - port, - addr.sin6_flowinfo, - scope_id, - ))) - } else { - Err(std::io::ErrorKind::InvalidInput.into()) - } - } - } -} - -pub(crate) fn set_linger(socket: TcpSocket, dur: Option) -> io::Result<()> { - let val: linger = linger { - l_onoff: if dur.is_some() { 1 } else { 0 }, - l_linger: dur.map(|dur| dur.as_secs() as c_ushort).unwrap_or_default(), - }; - - match unsafe { - setsockopt( - socket, - SOL_SOCKET, - SO_LINGER, - &val as *const _ as *const c_char, - size_of::() as c_int, - ) - } { - SOCKET_ERROR => Err(io::Error::last_os_error()), - _ => Ok(()), - } -} - -pub(crate) fn get_linger(socket: TcpSocket) -> io::Result> { - let mut val: linger = unsafe { std::mem::zeroed() }; - let mut len = size_of::() as c_int; - - match unsafe { - getsockopt( - socket, - SOL_SOCKET, - SO_LINGER, - &mut val as *mut _ as *mut _, - &mut len, - ) - } { - SOCKET_ERROR => Err(io::Error::last_os_error()), - _ => { - if val.l_onoff == 0 { - Ok(None) - } else { - Ok(Some(Duration::from_secs(val.l_linger as u64))) - } - } - } -} - -pub(crate) fn set_recv_buffer_size(socket: TcpSocket, size: u32) -> io::Result<()> { - let size = size.try_into().ok().unwrap_or_else(i32::max_value); - match unsafe { - setsockopt( - socket, - SOL_SOCKET, - SO_RCVBUF, - &size as *const _ as *const c_char, - size_of::() as c_int, - ) - } { - SOCKET_ERROR => Err(io::Error::last_os_error()), - _ => Ok(()), - } -} - -pub(crate) fn get_recv_buffer_size(socket: TcpSocket) -> io::Result { - let mut optval: c_int = 0; - let mut optlen = size_of::() as c_int; - match unsafe { - getsockopt( - socket, - SOL_SOCKET, - SO_RCVBUF, - &mut optval as *mut _ as *mut _, - &mut optlen as *mut _, - ) - } { - SOCKET_ERROR => Err(io::Error::last_os_error()), - _ => Ok(optval as u32), - } -} - -pub(crate) fn set_send_buffer_size(socket: TcpSocket, size: u32) -> io::Result<()> { - let size = size.try_into().ok().unwrap_or_else(i32::max_value); - match unsafe { - setsockopt( - socket, - SOL_SOCKET, - SO_SNDBUF, - &size as *const _ as *const c_char, - size_of::() as c_int, - ) - } { - SOCKET_ERROR => Err(io::Error::last_os_error()), - _ => Ok(()), - } -} - -pub(crate) fn get_send_buffer_size(socket: TcpSocket) -> io::Result { - let mut optval: c_int = 0; - let mut optlen = size_of::() as c_int; - match unsafe { - getsockopt( - socket, - SOL_SOCKET, - SO_SNDBUF, - &mut optval as *mut _ as *mut _, - &mut optlen as *mut _, - ) - } { - SOCKET_ERROR => Err(io::Error::last_os_error()), - _ => Ok(optval as u32), - } -} - -pub(crate) fn set_keepalive(socket: TcpSocket, keepalive: bool) -> io::Result<()> { - let val: BOOL = if keepalive { TRUE } else { FALSE }; - match unsafe { - setsockopt( - socket, - SOL_SOCKET, - SO_KEEPALIVE, - &val as *const _ as *const c_char, - size_of::() as c_int, - ) - } { - SOCKET_ERROR => Err(io::Error::last_os_error()), - _ => Ok(()), - } -} - -pub(crate) fn get_keepalive(socket: TcpSocket) -> io::Result { - let mut optval: c_char = 0; - let mut optlen = size_of::() as c_int; - - match unsafe { - getsockopt( - socket, - SOL_SOCKET, - SO_KEEPALIVE, - &mut optval as *mut _ as *mut _, - &mut optlen, - ) - } { - SOCKET_ERROR => Err(io::Error::last_os_error()), - _ => Ok(optval != FALSE as c_char), - } -} - -pub(crate) fn set_keepalive_params(socket: TcpSocket, keepalive: TcpKeepalive) -> io::Result<()> { - /// Windows configures keepalive time/interval in a u32 of milliseconds. - fn dur_to_ulong_ms(dur: Duration) -> c_ulong { - dur.as_millis() - .try_into() - .ok() - .unwrap_or_else(u32::max_value) - } - - // If any of the fields on the `tcp_keepalive` struct were not provided by - // the user, just leaving them zero will clobber any existing value. - // Unfortunately, we can't access the current value, so we will use the - // defaults if a value for the time or interval was not not provided. - let time = keepalive.time.unwrap_or_else(|| { - // The default value is two hours, as per - // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-keepalive-vals - let two_hours = 2 * 60 * 60; - Duration::from_secs(two_hours) - }); - - let interval = keepalive.interval.unwrap_or_else(|| { - // The default value is one second, as per - // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-keepalive-vals - Duration::from_secs(1) - }); - - let mut keepalive = mstcpip::tcp_keepalive { - // Enable keepalive - onoff: 1, - keepalivetime: dur_to_ulong_ms(time), - keepaliveinterval: dur_to_ulong_ms(interval), - }; - - let mut out = 0; - match unsafe { - WSAIoctl( - socket, - mstcpip::SIO_KEEPALIVE_VALS, - &mut keepalive as *mut _ as LPVOID, - size_of::() as DWORD, - ptr::null_mut() as LPVOID, - 0 as DWORD, - &mut out as *mut _ as LPDWORD, - 0 as LPWSAOVERLAPPED, - None, - ) - } { - 0 => Ok(()), - _ => Err(io::Error::last_os_error()), - } + syscall!( + listen(socket.as_raw_socket() as _, backlog), + PartialEq::eq, + SOCKET_ERROR + )?; + Ok(()) } pub(crate) fn accept(listener: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> { diff --git a/tests/tcp.rs b/tests/tcp.rs index ab5b7e64a..6ff38d2ca 100644 --- a/tests/tcp.rs +++ b/tests/tcp.rs @@ -1,6 +1,6 @@ #![cfg(all(feature = "os-poll", feature = "net"))] -use mio::net::{TcpListener, TcpSocket, TcpStream}; +use mio::net::{TcpListener, TcpStream}; use mio::{Events, Interest, Poll, Token}; use std::io::{self, Read, Write}; use std::net::{self, Shutdown}; @@ -12,7 +12,7 @@ use std::time::Duration; mod util; use util::{ any_local_address, assert_send, assert_sync, expect_events, expect_no_events, init, - init_with_poll, ExpectEvent, + init_with_poll, set_linger_zero, ExpectEvent, }; const LISTEN: Token = Token(0); @@ -481,9 +481,8 @@ fn connection_reset_by_peer() { let addr = listener.local_addr().unwrap(); // Connect client - let client = TcpSocket::new_v4().unwrap(); - client.set_linger(Some(Duration::from_millis(0))).unwrap(); - let mut client = client.connect(addr).unwrap(); + let mut client = TcpStream::connect(addr).unwrap(); + set_linger_zero(&client); // Register server poll.registry() diff --git a/tests/tcp_socket.rs b/tests/tcp_socket.rs deleted file mode 100644 index 345cd8efb..000000000 --- a/tests/tcp_socket.rs +++ /dev/null @@ -1,203 +0,0 @@ -#![cfg(all(feature = "os-poll", feature = "net"))] - -use mio::net::{TcpKeepalive, TcpSocket}; -use std::io; -use std::time::Duration; - -#[test] -fn is_send_and_sync() { - fn is_send() {} - fn is_sync() {} - - is_send::(); - is_sync::(); -} - -#[test] -fn set_reuseaddr() { - let addr = "127.0.0.1:0".parse().unwrap(); - - let socket = TcpSocket::new_v4().unwrap(); - socket.set_reuseaddr(true).unwrap(); - assert!(socket.get_reuseaddr().unwrap()); - - socket.bind(addr).unwrap(); - - let _ = socket.listen(128).unwrap(); -} - -#[cfg(all(unix, not(any(target_os = "solaris", target_os = "illumos"))))] -#[test] -fn set_reuseport() { - let addr = "127.0.0.1:0".parse().unwrap(); - - let socket = TcpSocket::new_v4().unwrap(); - socket.set_reuseport(true).unwrap(); - assert!(socket.get_reuseport().unwrap()); - - socket.bind(addr).unwrap(); - - let _ = socket.listen(128).unwrap(); -} - -#[test] -fn set_keepalive() { - let addr = "127.0.0.1:0".parse().unwrap(); - - let socket = TcpSocket::new_v4().unwrap(); - socket.set_keepalive(false).unwrap(); - assert!(!socket.get_keepalive().unwrap()); - - socket.set_keepalive(true).unwrap(); - assert!(socket.get_keepalive().unwrap()); - - socket.bind(addr).unwrap(); - - let _ = socket.listen(128).unwrap(); -} - -#[test] -fn set_keepalive_time() { - let dur = Duration::from_secs(4); // Chosen by fair dice roll, guaranteed to be random - let addr = "127.0.0.1:0".parse().unwrap(); - - let socket = TcpSocket::new_v4().unwrap(); - socket - .set_keepalive_params(TcpKeepalive::default().with_time(dur)) - .unwrap(); - - // It's not possible to access keepalive parameters on Windows... - #[cfg(not(target_os = "windows"))] - assert_eq!(Some(dur), socket.get_keepalive_time().unwrap()); - - socket.bind(addr).unwrap(); - - let _ = socket.listen(128).unwrap(); -} - -#[cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", - target_os = "windows" -))] -#[test] -fn set_keepalive_interval() { - let dur = Duration::from_secs(4); // Chosen by fair dice roll, guaranteed to be random - let addr = "127.0.0.1:0".parse().unwrap(); - - let socket = TcpSocket::new_v4().unwrap(); - socket - .set_keepalive_params(TcpKeepalive::default().with_interval(dur)) - .unwrap(); - // It's not possible to access keepalive parameters on Windows... - #[cfg(not(target_os = "windows"))] - assert_eq!(Some(dur), socket.get_keepalive_interval().unwrap()); - - socket.bind(addr).unwrap(); - - let _ = socket.listen(128).unwrap(); -} - -#[cfg(any( - target_os = "linux", - target_os = "macos", - target_os = "ios", - target_os = "freebsd", - target_os = "netbsd", -))] -#[test] -fn set_keepalive_retries() { - let addr = "127.0.0.1:0".parse().unwrap(); - - let socket = TcpSocket::new_v4().unwrap(); - socket - .set_keepalive_params(TcpKeepalive::default().with_retries(16)) - .unwrap(); - assert_eq!(Some(16), socket.get_keepalive_retries().unwrap()); - - socket.bind(addr).unwrap(); - - let _ = socket.listen(128).unwrap(); -} - -#[test] -fn get_localaddr() { - let expected_addr = "127.0.0.1:0".parse().unwrap(); - let socket = TcpSocket::new_v4().unwrap(); - - //Windows doesn't support calling getsockname before calling `bind` - #[cfg(not(windows))] - assert_eq!("0.0.0.0:0", socket.get_localaddr().unwrap().to_string()); - - socket.bind(expected_addr).unwrap(); - - let actual_addr = socket.get_localaddr().unwrap(); - - assert_eq!(expected_addr.ip(), actual_addr.ip()); - assert!(actual_addr.port() > 0); - - let _ = socket.listen(128).unwrap(); -} - -#[test] -fn set_linger() { - let addr = "127.0.0.1:0".parse().unwrap(); - - let socket = TcpSocket::new_v4().unwrap(); - socket.set_linger(Some(Duration::from_secs(1))).unwrap(); - assert_eq!(socket.get_linger().unwrap().unwrap().as_secs(), 1); - - let _ = socket.set_linger(None); - assert_eq!(socket.get_linger().unwrap(), None); - - socket.bind(addr).unwrap(); - - let _ = socket.listen(128).unwrap(); -} - -#[test] -fn send_buffer_size_roundtrips() { - test_buffer_sizes( - TcpSocket::set_send_buffer_size, - TcpSocket::get_send_buffer_size, - ) -} - -#[test] -fn recv_buffer_size_roundtrips() { - test_buffer_sizes( - TcpSocket::set_recv_buffer_size, - TcpSocket::get_recv_buffer_size, - ) -} - -// Helper for testing send/recv buffer size. -fn test_buffer_sizes( - set: impl Fn(&TcpSocket, u32) -> io::Result<()>, - get: impl Fn(&TcpSocket) -> io::Result, -) { - let test = |size: u32| { - println!("testing buffer size: {}", size); - let socket = TcpSocket::new_v4().unwrap(); - set(&socket, size).unwrap(); - // Note that this doesn't assert that the values are equal: on Linux, - // the kernel doubles the requested buffer size, and returns the doubled - // value from `getsockopt`. As per `man socket(7)`: - // > Sets or gets the maximum socket send buffer in bytes. The - // > kernel doubles this value (to allow space for bookkeeping - // > overhead) when it is set using setsockopt(2), and this doubled - // > value is returned by getsockopt(2). - // - // Additionally, the buffer size may be clamped above a minimum value, - // and this minimum value is OS-dependent. - let actual = get(&socket).unwrap(); - assert!(actual >= size, "\tactual: {}\n\texpected: {}", actual, size); - }; - - test(256); - test(4096); - test(65512); -} diff --git a/tests/tcp_stream.rs b/tests/tcp_stream.rs index be4282b45..d60dc8504 100644 --- a/tests/tcp_stream.rs +++ b/tests/tcp_stream.rs @@ -1,17 +1,14 @@ #![cfg(all(feature = "os-poll", feature = "net"))] use std::io::{self, IoSlice, IoSliceMut, Read, Write}; -use std::mem::forget; use std::net::{self, Shutdown, SocketAddr}; #[cfg(unix)] use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd}; -#[cfg(windows)] -use std::os::windows::io::{AsRawSocket, FromRawSocket}; use std::sync::{mpsc::channel, Arc, Barrier}; use std::thread; use std::time::Duration; -use mio::net::{TcpSocket, TcpStream}; +use mio::net::TcpStream; use mio::{Interest, Token}; #[macro_use] @@ -21,7 +18,7 @@ use util::init; use util::{ any_local_address, any_local_ipv6_address, assert_send, assert_socket_close_on_exec, assert_socket_non_blocking, assert_sync, assert_would_block, expect_events, expect_no_events, - init_with_poll, ExpectEvent, Readiness, + init_with_poll, set_linger_zero, ExpectEvent, Readiness, }; const DATA1: &[u8] = b"Hello world!"; @@ -800,13 +797,3 @@ fn hup_event_on_disconnect() { vec![ExpectEvent::new(Token(1), Interest::READABLE)], ); } - -fn set_linger_zero(socket: &TcpStream) { - #[cfg(windows)] - let s = unsafe { TcpSocket::from_raw_socket(socket.as_raw_socket()) }; - #[cfg(unix)] - let s = unsafe { TcpSocket::from_raw_fd(socket.as_raw_fd()) }; - - s.set_linger(Some(Duration::from_millis(0))).unwrap(); - forget(s); -} diff --git a/tests/util/mod.rs b/tests/util/mod.rs index 14138ea61..9b1d12633 100644 --- a/tests/util/mod.rs +++ b/tests/util/mod.rs @@ -2,6 +2,7 @@ #![allow(dead_code, unused_macros)] #![cfg(any(feature = "os-poll", feature = "net"))] +use std::mem::size_of; use std::net::SocketAddr; use std::ops::BitOr; #[cfg(unix)] @@ -13,6 +14,7 @@ use std::{env, fmt, fs, io}; use log::{error, warn}; use mio::event::Event; +use mio::net::TcpStream; use mio::{Events, Interest, Poll, Token}; pub fn init() { @@ -236,6 +238,51 @@ pub fn any_local_ipv6_address() -> SocketAddr { "[::1]:0".parse().unwrap() } +#[cfg(unix)] +pub fn set_linger_zero(socket: &TcpStream) { + let val = libc::linger { + l_onoff: 1, + l_linger: 0, + }; + let res = unsafe { + libc::setsockopt( + socket.as_raw_fd(), + libc::SOL_SOCKET, + #[cfg(target_vendor = "apple")] + libc::SO_LINGER_SEC, + #[cfg(not(target_vendor = "apple"))] + libc::SO_LINGER, + &val as *const libc::linger as *const libc::c_void, + size_of::() as libc::socklen_t, + ) + }; + assert_eq!(res, 0); +} + +#[cfg(windows)] +pub fn set_linger_zero(socket: &TcpStream) { + use std::os::windows::io::AsRawSocket; + use winapi::um::winsock2::{linger, setsockopt, SOCKET_ERROR, SOL_SOCKET, SO_LINGER}; + + let val = linger { + l_onoff: 1, + l_linger: 0, + }; + + match unsafe { + setsockopt( + socket.as_raw_socket() as _, + SOL_SOCKET, + SO_LINGER, + &val as *const _ as *const _, + size_of::() as _, + ) + } { + SOCKET_ERROR => panic!("error setting linger: {}", io::Error::last_os_error()), + _ => {} + } +} + /// Returns a path to a temporary file using `name` as filename. pub fn temp_file(name: &'static str) -> PathBuf { let mut path = temp_dir();