Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add I/O safety to sockopt and some socket functions #1915

Merged
merged 1 commit into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/sys/socket/addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -875,9 +875,10 @@ pub trait SockaddrLike: private::SockaddrLikePriv {
/// One common use is to match on the family of a union type, like this:
/// ```
/// # use nix::sys::socket::*;
/// # use std::os::unix::io::AsRawFd;
/// let fd = socket(AddressFamily::Inet, SockType::Stream,
/// SockFlag::empty(), None).unwrap();
/// let ss: SockaddrStorage = getsockname(fd).unwrap();
/// let ss: SockaddrStorage = getsockname(fd.as_raw_fd()).unwrap();
/// match ss.family().unwrap() {
/// AddressFamily::Inet => println!("{}", ss.as_sockaddr_in().unwrap()),
/// AddressFamily::Inet6 => println!("{}", ss.as_sockaddr_in6().unwrap()),
Expand Down Expand Up @@ -1261,11 +1262,12 @@ impl std::str::FromStr for SockaddrIn6 {
/// ```
/// # use nix::sys::socket::*;
/// # use std::str::FromStr;
/// # use std::os::unix::io::AsRawFd;
/// let localhost = SockaddrIn::from_str("127.0.0.1:8081").unwrap();
/// let fd = socket(AddressFamily::Inet, SockType::Stream, SockFlag::empty(),
/// None).unwrap();
/// bind(fd, &localhost).expect("bind");
/// let ss: SockaddrStorage = getsockname(fd).expect("getsockname");
/// bind(fd.as_raw_fd(), &localhost).expect("bind");
/// let ss: SockaddrStorage = getsockname(fd.as_raw_fd()).expect("getsockname");
/// assert_eq!(&localhost, ss.as_sockaddr_in().unwrap());
/// ```
#[derive(Clone, Copy, Eq)]
Expand Down
66 changes: 39 additions & 27 deletions src/sys/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use libc::{
use std::io::{IoSlice, IoSliceMut};
#[cfg(feature = "net")]
use std::net;
use std::os::unix::io::RawFd;
use std::os::unix::io::{AsFd, AsRawFd, FromRawFd, RawFd, OwnedFd};
use std::{mem, ptr};

#[deny(missing_docs)]
Expand Down Expand Up @@ -693,6 +693,7 @@ pub enum ControlMessageOwned {
/// # use std::io::{IoSlice, IoSliceMut};
/// # use std::time::*;
/// # use std::str::FromStr;
/// # use std::os::unix::io::AsRawFd;
/// # fn main() {
/// // Set up
/// let message = "Ohayō!".as_bytes();
Expand All @@ -701,22 +702,22 @@ pub enum ControlMessageOwned {
/// SockType::Datagram,
/// SockFlag::empty(),
/// None).unwrap();
/// setsockopt(in_socket, sockopt::ReceiveTimestamp, &true).unwrap();
/// setsockopt(&in_socket, sockopt::ReceiveTimestamp, &true).unwrap();
/// let localhost = SockaddrIn::from_str("127.0.0.1:0").unwrap();
/// bind(in_socket, &localhost).unwrap();
/// let address: SockaddrIn = getsockname(in_socket).unwrap();
/// bind(in_socket.as_raw_fd(), &localhost).unwrap();
/// let address: SockaddrIn = getsockname(in_socket.as_raw_fd()).unwrap();
/// // Get initial time
/// let time0 = SystemTime::now();
/// // Send the message
/// let iov = [IoSlice::new(message)];
/// let flags = MsgFlags::empty();
/// let l = sendmsg(in_socket, &iov, &[], flags, Some(&address)).unwrap();
/// let l = sendmsg(in_socket.as_raw_fd(), &iov, &[], flags, Some(&address)).unwrap();
/// assert_eq!(message.len(), l);
/// // Receive the message
/// let mut buffer = vec![0u8; message.len()];
/// let mut cmsgspace = cmsg_space!(TimeVal);
/// let mut iov = [IoSliceMut::new(&mut buffer)];
/// let r = recvmsg::<SockaddrIn>(in_socket, &mut iov, Some(&mut cmsgspace), flags)
/// let r = recvmsg::<SockaddrIn>(in_socket.as_raw_fd(), &mut iov, Some(&mut cmsgspace), flags)
/// .unwrap();
/// let rtime = match r.cmsgs().next() {
/// Some(ControlMessageOwned::ScmTimestamp(rtime)) => rtime,
Expand All @@ -732,7 +733,6 @@ pub enum ControlMessageOwned {
/// assert!(time0.duration_since(UNIX_EPOCH).unwrap() <= rduration);
/// assert!(rduration <= time1.duration_since(UNIX_EPOCH).unwrap());
/// // Close socket
/// nix::unistd::close(in_socket).unwrap();
/// # }
/// ```
ScmTimestamp(TimeVal),
Expand Down Expand Up @@ -1451,6 +1451,7 @@ impl<'a> ControlMessage<'a> {
/// # use nix::sys::socket::*;
/// # use nix::unistd::pipe;
/// # use std::io::IoSlice;
/// # use std::os::unix::io::AsRawFd;
/// let (fd1, fd2) = socketpair(AddressFamily::Unix, SockType::Stream, None,
/// SockFlag::empty())
/// .unwrap();
Expand All @@ -1459,14 +1460,15 @@ impl<'a> ControlMessage<'a> {
/// let iov = [IoSlice::new(b"hello")];
/// let fds = [r];
/// let cmsg = ControlMessage::ScmRights(&fds);
/// sendmsg::<()>(fd1, &iov, &[cmsg], MsgFlags::empty(), None).unwrap();
/// sendmsg::<()>(fd1.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), None).unwrap();
/// ```
/// When directing to a specific address, the generic type will be inferred.
/// ```
/// # use nix::sys::socket::*;
/// # use nix::unistd::pipe;
/// # use std::io::IoSlice;
/// # use std::str::FromStr;
/// # use std::os::unix::io::AsRawFd;
/// let localhost = SockaddrIn::from_str("1.2.3.4:8080").unwrap();
/// let fd = socket(AddressFamily::Inet, SockType::Datagram, SockFlag::empty(),
/// None).unwrap();
Expand All @@ -1475,7 +1477,7 @@ impl<'a> ControlMessage<'a> {
/// let iov = [IoSlice::new(b"hello")];
/// let fds = [r];
/// let cmsg = ControlMessage::ScmRights(&fds);
/// sendmsg(fd, &iov, &[cmsg], MsgFlags::empty(), Some(&localhost)).unwrap();
/// sendmsg(fd.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), Some(&localhost)).unwrap();
/// ```
pub fn sendmsg<S>(fd: RawFd, iov: &[IoSlice<'_>], cmsgs: &[ControlMessage],
flags: MsgFlags, addr: Option<&S>) -> Result<usize>
Expand Down Expand Up @@ -1823,6 +1825,7 @@ mod test {
use crate::sys::socket::{AddressFamily, ControlMessageOwned};
use crate::*;
use std::str::FromStr;
use std::os::unix::io::AsRawFd;

#[cfg_attr(qemu, ignore)]
#[test]
Expand All @@ -1849,9 +1852,9 @@ mod test {
None,
)?;

crate::sys::socket::bind(rsock, &sock_addr)?;
crate::sys::socket::bind(rsock.as_raw_fd(), &sock_addr)?;

setsockopt(rsock, Timestamping, &TimestampingFlag::all())?;
setsockopt(&rsock, Timestamping, &TimestampingFlag::all())?;

let sbuf = (0..400).map(|i| i as u8).collect::<Vec<_>>();

Expand All @@ -1873,13 +1876,13 @@ mod test {
let iov1 = [IoSlice::new(&sbuf)];

let cmsg = cmsg_space!(crate::sys::socket::Timestamps);
sendmsg(ssock, &iov1, &[], flags, Some(&sock_addr)).unwrap();
sendmsg(ssock.as_raw_fd(), &iov1, &[], flags, Some(&sock_addr)).unwrap();

let mut data = super::MultiHeaders::<()>::preallocate(recv_iovs.len(), Some(cmsg));

let t = sys::time::TimeSpec::from_duration(std::time::Duration::from_secs(10));

let recv = super::recvmmsg(rsock, &mut data, recv_iovs.iter(), flags, Some(t))?;
let recv = super::recvmmsg(rsock.as_raw_fd(), &mut data, recv_iovs.iter(), flags, Some(t))?;

for rmsg in recv {
#[cfg(not(any(qemu, target_arch = "aarch64")))]
Expand Down Expand Up @@ -2091,7 +2094,7 @@ pub fn socket<T: Into<Option<SockProtocol>>>(
ty: SockType,
flags: SockFlag,
protocol: T,
) -> Result<RawFd> {
) -> Result<OwnedFd> {
let protocol = match protocol.into() {
None => 0,
Some(p) => p as c_int,
Expand All @@ -2105,7 +2108,13 @@ pub fn socket<T: Into<Option<SockProtocol>>>(

let res = unsafe { libc::socket(domain as c_int, ty, protocol) };

Errno::result(res)
match res {
-1 => Err(Errno::last()),
fd => {
// Safe because libc::socket returned success
unsafe { Ok(OwnedFd::from_raw_fd(fd)) }
}
}
asomers marked this conversation as resolved.
Show resolved Hide resolved
}

/// Create a pair of connected sockets
Expand All @@ -2116,7 +2125,7 @@ pub fn socketpair<T: Into<Option<SockProtocol>>>(
ty: SockType,
protocol: T,
flags: SockFlag,
) -> Result<(RawFd, RawFd)> {
) -> Result<(OwnedFd, OwnedFd)> {
let protocol = match protocol.into() {
None => 0,
Some(p) => p as c_int,
Expand All @@ -2135,14 +2144,18 @@ pub fn socketpair<T: Into<Option<SockProtocol>>>(
};
Errno::result(res)?;

Ok((fds[0], fds[1]))
// Safe because socketpair returned success.
unsafe {
Ok((OwnedFd::from_raw_fd(fds[0]), OwnedFd::from_raw_fd(fds[1])))
}
}

/// Listen for connections on a socket
///
/// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/listen.html)
pub fn listen(sockfd: RawFd, backlog: usize) -> Result<()> {
let res = unsafe { libc::listen(sockfd, backlog as c_int) };
pub fn listen<F: AsFd>(sock: &F, backlog: usize) -> Result<()> {
let fd = sock.as_fd().as_raw_fd();
let res = unsafe { libc::listen(fd, backlog as c_int) };

Errno::result(res).map(drop)
}
Expand Down Expand Up @@ -2302,21 +2315,21 @@ pub trait GetSockOpt: Copy {
type Val;

/// Look up the value of this socket option on the given socket.
fn get(&self, fd: RawFd) -> Result<Self::Val>;
fn get<F: AsFd>(&self, fd: &F) -> Result<Self::Val>;
}

/// Represents a socket option that can be set.
pub trait SetSockOpt: Clone {
type Val;

/// Set the value of this socket option on the given socket.
fn set(&self, fd: RawFd, val: &Self::Val) -> Result<()>;
fn set<F: AsFd>(&self, fd: &F, val: &Self::Val) -> Result<()>;
}

/// Get the current value for the requested socket option
///
/// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/getsockopt.html)
pub fn getsockopt<O: GetSockOpt>(fd: RawFd, opt: O) -> Result<O::Val> {
pub fn getsockopt<F: AsFd, O: GetSockOpt>(fd: &F, opt: O) -> Result<O::Val> {
opt.get(fd)
}

Expand All @@ -2330,15 +2343,14 @@ pub fn getsockopt<O: GetSockOpt>(fd: RawFd, opt: O) -> Result<O::Val> {
/// use nix::sys::socket::setsockopt;
/// use nix::sys::socket::sockopt::KeepAlive;
/// use std::net::TcpListener;
/// use std::os::unix::io::AsRawFd;
///
/// let listener = TcpListener::bind("0.0.0.0:0").unwrap();
/// let fd = listener.as_raw_fd();
/// let res = setsockopt(fd, KeepAlive, &true);
/// let fd = listener;
/// let res = setsockopt(&fd, KeepAlive, &true);
/// assert!(res.is_ok());
/// ```
pub fn setsockopt<O: SetSockOpt>(
fd: RawFd,
pub fn setsockopt<F: AsFd, O: SetSockOpt>(
fd: &F,
opt: O,
val: &O::Val,
) -> Result<()> {
Expand Down
41 changes: 17 additions & 24 deletions src/sys/socket/sockopt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::ffi::{OsStr, OsString};
use std::mem::{self, MaybeUninit};
#[cfg(target_family = "unix")]
use std::os::unix::ffi::OsStrExt;
use std::os::unix::io::RawFd;
use std::os::unix::io::{AsFd, AsRawFd};

// Constants
// TCP_CA_NAME_MAX isn't defined in user space include files
Expand Down Expand Up @@ -44,12 +44,12 @@ macro_rules! setsockopt_impl {
impl SetSockOpt for $name {
type Val = $ty;

fn set(&self, fd: RawFd, val: &$ty) -> Result<()> {
fn set<F: AsFd>(&self, fd: &F, val: &$ty) -> Result<()> {
unsafe {
let setter: $setter = Set::new(val);

let res = libc::setsockopt(
fd,
fd.as_fd().as_raw_fd(),
$level,
$flag,
setter.ffi_ptr(),
Expand Down Expand Up @@ -89,12 +89,12 @@ macro_rules! getsockopt_impl {
impl GetSockOpt for $name {
type Val = $ty;

fn get(&self, fd: RawFd) -> Result<$ty> {
fn get<F: AsFd>(&self, fd: &F) -> Result<$ty> {
unsafe {
let mut getter: $getter = Get::uninit();

let res = libc::getsockopt(
fd,
fd.as_fd().as_raw_fd(),
$level,
$flag,
getter.ffi_ptr(),
Expand Down Expand Up @@ -1033,10 +1033,10 @@ pub struct AlgSetAeadAuthSize;
impl SetSockOpt for AlgSetAeadAuthSize {
type Val = usize;

fn set(&self, fd: RawFd, val: &usize) -> Result<()> {
fn set<F: AsFd>(&self, fd: &F, val: &usize) -> Result<()> {
unsafe {
let res = libc::setsockopt(
fd,
fd.as_fd().as_raw_fd(),
libc::SOL_ALG,
libc::ALG_SET_AEAD_AUTHSIZE,
::std::ptr::null(),
Expand Down Expand Up @@ -1067,10 +1067,10 @@ where
{
type Val = T;

fn set(&self, fd: RawFd, val: &T) -> Result<()> {
fn set<F: AsFd>(&self, fd: &F, val: &T) -> Result<()> {
unsafe {
let res = libc::setsockopt(
fd,
fd.as_fd().as_raw_fd(),
libc::SOL_ALG,
libc::ALG_SET_KEY,
val.as_ref().as_ptr() as *const _,
Expand Down Expand Up @@ -1383,34 +1383,30 @@ mod test {
SockFlag::empty(),
)
.unwrap();
let a_cred = getsockopt(a, super::PeerCredentials).unwrap();
let b_cred = getsockopt(b, super::PeerCredentials).unwrap();
let a_cred = getsockopt(&a, super::PeerCredentials).unwrap();
let b_cred = getsockopt(&b, super::PeerCredentials).unwrap();
assert_eq!(a_cred, b_cred);
assert_ne!(a_cred.pid(), 0);
}

#[test]
fn is_socket_type_unix() {
use super::super::*;
use crate::unistd::close;

let (a, b) = socketpair(
let (a, _b) = socketpair(
AddressFamily::Unix,
SockType::Stream,
None,
SockFlag::empty(),
)
.unwrap();
let a_type = getsockopt(a, super::SockType).unwrap();
let a_type = getsockopt(&a, super::SockType).unwrap();
assert_eq!(a_type, SockType::Stream);
close(a).unwrap();
close(b).unwrap();
}

#[test]
fn is_socket_type_dgram() {
use super::super::*;
use crate::unistd::close;

let s = socket(
AddressFamily::Inet,
Expand All @@ -1419,16 +1415,14 @@ mod test {
None,
)
.unwrap();
let s_type = getsockopt(s, super::SockType).unwrap();
let s_type = getsockopt(&s, super::SockType).unwrap();
assert_eq!(s_type, SockType::Datagram);
close(s).unwrap();
}

#[cfg(any(target_os = "freebsd", target_os = "linux"))]
#[test]
fn can_get_listen_on_tcp_socket() {
use super::super::*;
use crate::unistd::close;

let s = socket(
AddressFamily::Inet,
Expand All @@ -1437,11 +1431,10 @@ mod test {
None,
)
.unwrap();
let s_listening = getsockopt(s, super::AcceptConn).unwrap();
let s_listening = getsockopt(&s, super::AcceptConn).unwrap();
assert!(!s_listening);
listen(s, 10).unwrap();
let s_listening2 = getsockopt(s, super::AcceptConn).unwrap();
listen(&s, 10).unwrap();
let s_listening2 = getsockopt(&s, super::AcceptConn).unwrap();
assert!(s_listening2);
close(s).unwrap();
}
}
Loading