Skip to content

Commit

Permalink
Use rustix instead of nix
Browse files Browse the repository at this point in the history
  • Loading branch information
ids1024 authored and elinorbgr committed Oct 31, 2023
1 parent 8581b9d commit edd0f60
Show file tree
Hide file tree
Showing 15 changed files with 136 additions and 137 deletions.
10 changes: 4 additions & 6 deletions wayland-backend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,13 @@ features = [
"const_new", # 1.51
]

[dependencies.nix]
version = "0.26.0"
default-features = false
[dependencies.rustix]
version = "0.38.17"
features = [
"event",
"fs",
"poll",
"socket",
"uio",
"net",
"process",
]

[build-dependencies]
Expand Down
23 changes: 13 additions & 10 deletions wayland-backend/src/rs/server_impl/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
ffi::CString,
os::unix::io::OwnedFd,
os::unix::io::{AsFd, BorrowedFd, OwnedFd},
os::unix::{io::RawFd, net::UnixStream},
sync::Arc,
};
Expand Down Expand Up @@ -295,13 +295,10 @@ impl<D> Client<D> {

#[cfg(any(target_os = "linux", target_os = "android"))]
pub(crate) fn get_credentials(&self) -> Credentials {
use std::os::unix::io::AsRawFd;
let creds = nix::sys::socket::getsockopt(
self.socket.as_raw_fd(),
nix::sys::socket::sockopt::PeerCredentials,
)
.expect("getsockopt failed!?");
Credentials { pid: creds.pid(), uid: creds.uid(), gid: creds.gid() }
let creds =
rustix::net::sockopt::get_socket_peercred(&self.socket).expect("getsockopt failed!?");
let pid = rustix::process::Pid::as_raw(Some(creds.pid));
Credentials { pid, uid: creds.uid.as_raw(), gid: creds.gid.as_raw() }
}

#[cfg(not(any(target_os = "linux", target_os = "android")))]
Expand Down Expand Up @@ -336,7 +333,7 @@ impl<D> Client<D> {
&mut self,
) -> std::io::Result<(Message<u32, OwnedFd>, Object<Data<D>>)> {
if self.killed {
return Err(nix::errno::Errno::EPIPE.into());
return Err(rustix::io::Errno::PIPE.into());
}
loop {
let map = &self.map;
Expand All @@ -358,7 +355,7 @@ impl<D> Client<D> {
}
Err(MessageParseError::Malformed) => {
self.kill(DisconnectReason::ConnectionClosed);
return Err(nix::errno::Errno::EPROTO.into());
return Err(rustix::io::Errno::PROTO.into());
}
};

Expand Down Expand Up @@ -659,6 +656,12 @@ impl<D> Client<D> {
}
}

impl<D> AsFd for Client<D> {
fn as_fd(&self) -> BorrowedFd<'_> {
self.socket.as_fd()
}
}

#[derive(Debug)]
pub(crate) struct ClientStore<D: 'static> {
clients: Vec<Option<Client<D>>>,
Expand Down
38 changes: 16 additions & 22 deletions wayland-backend/src/rs/server_impl/common_poll.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
os::unix::io::{AsRawFd, FromRawFd},
os::unix::io::AsRawFd,
os::unix::io::{BorrowedFd, OwnedFd},
sync::{Arc, Mutex},
};
Expand All @@ -16,15 +16,15 @@ use crate::{
};

#[cfg(any(target_os = "linux", target_os = "android"))]
use nix::sys::epoll::*;
use rustix::event::epoll;

#[cfg(any(
target_os = "dragonfly",
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd"
))]
use nix::sys::event::*;
use rustix::event::kqueue::*;
use smallvec::SmallVec;

#[derive(Debug)]
Expand All @@ -35,7 +35,7 @@ pub struct InnerBackend<D: 'static> {
impl<D> InnerBackend<D> {
pub fn new() -> Result<Self, InitError> {
#[cfg(any(target_os = "linux", target_os = "android"))]
let poll_fd = epoll_create1(EpollCreateFlags::EPOLL_CLOEXEC)
let poll_fd = epoll::create(epoll::CreateFlags::CLOEXEC)
.map_err(Into::into)
.map_err(InitError::Io)?;

Expand All @@ -47,9 +47,7 @@ impl<D> InnerBackend<D> {
))]
let poll_fd = kqueue().map_err(Into::into).map_err(InitError::Io)?;

Ok(Self {
state: Arc::new(Mutex::new(State::new(unsafe { OwnedFd::from_raw_fd(poll_fd) }))),
})
Ok(Self { state: Arc::new(Mutex::new(State::new(poll_fd))) })
}

pub fn flush(&self, client: Option<ClientId>) -> std::io::Result<()> {
Expand Down Expand Up @@ -80,18 +78,20 @@ impl<D> InnerBackend<D> {

#[cfg(any(target_os = "linux", target_os = "android"))]
pub fn dispatch_all_clients(&self, data: &mut D) -> std::io::Result<usize> {
use std::os::unix::io::AsFd;

let poll_fd = self.poll_fd();
let mut dispatched = 0;
loop {
let mut events = [EpollEvent::empty(); 32];
let nevents = epoll_wait(poll_fd.as_raw_fd(), &mut events, 0)?;
let mut events = epoll::EventVec::with_capacity(32);
epoll::wait(poll_fd.as_fd(), &mut events, 0)?;

if nevents == 0 {
if events.is_empty() {
break;
}

for event in events.iter().take(nevents) {
let id = InnerClientId::from_u64(event.data());
for event in events.iter() {
let id = InnerClientId::from_u64(event.data.u64());
// remove the cb while we call it, to gracefully handle reentrancy
if let Ok(count) = self.dispatch_events_for(data, id) {
dispatched += count;
Expand All @@ -111,19 +111,13 @@ impl<D> InnerBackend<D> {
target_os = "openbsd"
))]
pub fn dispatch_all_clients(&self, data: &mut D) -> std::io::Result<usize> {
use std::time::Duration;

let poll_fd = self.poll_fd();
let mut dispatched = 0;
loop {
let mut events = [KEvent::new(
0,
EventFilter::EVFILT_READ,
EventFlag::empty(),
FilterFlag::empty(),
0,
0,
); 32];

let nevents = kevent(poll_fd.as_raw_fd(), &[], &mut events, 0)?;
let mut events = Vec::with_capacity(32);
let nevents = unsafe { kevent(&poll_fd, &[], &mut events, Some(Duration::ZERO))? };

if nevents == 0 {
break;
Expand Down
33 changes: 17 additions & 16 deletions wayland-backend/src/rs/server_impl/handle.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use std::{
ffi::CString,
os::unix::io::OwnedFd,
os::unix::{
io::{AsRawFd, RawFd},
net::UnixStream,
},
os::unix::{io::RawFd, net::UnixStream},
sync::{Arc, Mutex, Weak},
};

Expand Down Expand Up @@ -314,15 +311,19 @@ impl<D> ErasedState for State<D> {
stream: UnixStream,
data: Arc<dyn ClientData>,
) -> std::io::Result<InnerClientId> {
let client_fd = stream.as_raw_fd();
let id = self.clients.create_client(stream, data);
let client = self.clients.get_client(id.clone()).unwrap();

// register the client to the internal epoll
#[cfg(any(target_os = "linux", target_os = "android"))]
let ret = {
use nix::sys::epoll::*;
let mut evt = EpollEvent::new(EpollFlags::EPOLLIN, id.as_u64());
epoll_ctl(self.poll_fd.as_raw_fd(), EpollOp::EpollCtlAdd, client_fd, &mut evt)
use rustix::event::epoll;
epoll::add(
&self.poll_fd,
client,
epoll::EventData::new_u64(id.as_u64()),
epoll::EventFlags::IN,
)
};

#[cfg(any(
Expand All @@ -332,17 +333,17 @@ impl<D> ErasedState for State<D> {
target_os = "openbsd"
))]
let ret = {
use nix::sys::event::*;
let evt = KEvent::new(
client_fd as usize,
EventFilter::EVFILT_READ,
EventFlag::EV_ADD | EventFlag::EV_RECEIPT,
FilterFlag::empty(),
0,
use rustix::event::kqueue::*;
use std::os::unix::io::{AsFd, AsRawFd};

let evt = Event::new(
EventFilter::Read(client.as_fd().as_raw_fd()),
EventFlags::ADD | EventFlags::RECEIPT,
id.as_u64() as isize,
);

kevent_ts(self.poll_fd.as_raw_fd(), &[evt], &mut [], None).map(|_| ())
let mut events = Vec::new();
unsafe { kevent(&self.poll_fd, &[evt], &mut events, None).map(|_| ()) }
};

match ret {
Expand Down
66 changes: 39 additions & 27 deletions wayland-backend/src/rs/socket.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
//! Wayland socket manipulation

use std::io::{ErrorKind, IoSlice, IoSliceMut, Result as IoResult};
use std::os::unix::io::{AsFd, BorrowedFd, OwnedFd};
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, IntoRawFd, OwnedFd, RawFd};
use std::os::unix::net::UnixStream;
use std::slice;

use nix::sys::socket;
use rustix::net::{
recvmsg, sendmsg, RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags, SendAncillaryBuffer,
SendAncillaryMessage, SendFlags,
};

use crate::protocol::{ArgumentType, Message};

Expand Down Expand Up @@ -35,14 +38,19 @@ impl Socket {
/// slice should not be longer than `MAX_BYTES_OUT` otherwise the receiving
/// end may lose some data.
pub fn send_msg(&self, bytes: &[u8], fds: &[RawFd]) -> IoResult<usize> {
let flags = socket::MsgFlags::MSG_DONTWAIT | socket::MsgFlags::MSG_NOSIGNAL;
let flags = SendFlags::DONTWAIT | SendFlags::NOSIGNAL;
let iov = [IoSlice::new(bytes)];

if !fds.is_empty() {
let cmsgs = [socket::ControlMessage::ScmRights(fds)];
Ok(socket::sendmsg::<()>(self.stream.as_raw_fd(), &iov, &cmsgs, flags, None)?)
let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(fds.len()))];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space);
let fds =
unsafe { slice::from_raw_parts(fds.as_ptr() as *const BorrowedFd, fds.len()) };
cmsg_buffer.push(SendAncillaryMessage::ScmRights(fds));
Ok(sendmsg(self, &iov, &mut cmsg_buffer, flags)?)
} else {
Ok(socket::sendmsg::<()>(self.stream.as_raw_fd(), &iov, &[], flags, None)?)
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut []);
Ok(sendmsg(self, &iov, &mut cmsg_buffer, flags)?)
}
}

Expand All @@ -58,25 +66,27 @@ impl Socket {
/// slice `MAX_FDS_OUT` long, otherwise some data of the received message may
/// be lost.
pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> IoResult<(usize, usize)> {
let mut cmsg = nix::cmsg_space!([RawFd; MAX_FDS_OUT]);
let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(MAX_FDS_OUT))];
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut iov = [IoSliceMut::new(buffer)];
let msg = socket::recvmsg::<()>(
self.stream.as_raw_fd(),
let msg = recvmsg(
&self.stream,
&mut iov[..],
Some(&mut cmsg),
socket::MsgFlags::MSG_DONTWAIT
| socket::MsgFlags::MSG_CMSG_CLOEXEC
| socket::MsgFlags::MSG_NOSIGNAL,
&mut cmsg_buffer,
RecvFlags::DONTWAIT | RecvFlags::CMSG_CLOEXEC,
)?;

let mut fd_count = 0;
let received_fds = msg.cmsgs().flat_map(|cmsg| match cmsg {
socket::ControlMessageOwned::ScmRights(s) => s,
_ => Vec::new(),
});
let received_fds = cmsg_buffer
.drain()
.filter_map(|cmsg| match cmsg {
RecvAncillaryMessage::ScmRights(fds) => Some(fds),
_ => None,
})
.flatten();
for (fd, place) in received_fds.zip(fds.iter_mut()) {
fd_count += 1;
*place = fd;
*place = fd.into_raw_fd();
}
Ok((msg.bytes, fd_count))
}
Expand Down Expand Up @@ -141,7 +151,7 @@ impl BufferedSocket {
let written = self.socket.send_msg(bytes, fds)?;
for &fd in fds {
// once the fds are sent, we can close them
let _ = ::nix::unistd::close(fd);
unsafe { rustix::io::close(fd) };
}
written
};
Expand Down Expand Up @@ -192,7 +202,7 @@ impl BufferedSocket {
if !self.attempt_write_message(msg)? {
// If this fails again, this means the message is too big
// to be transmitted at all
return Err(::nix::errno::Errno::E2BIG.into());
return Err(rustix::io::Errno::TOOBIG.into());
}
}
Ok(())
Expand All @@ -215,7 +225,7 @@ impl BufferedSocket {
};
if in_bytes == 0 {
// the other end of the socket was closed
return Err(::nix::errno::Errno::EPIPE.into());
return Err(rustix::io::Errno::PIPE.into());
}
// advance the storage
self.in_data.advance(in_bytes / 4 + usize::from(in_bytes % 4 > 0));
Expand Down Expand Up @@ -342,14 +352,14 @@ mod tests {
use crate::protocol::{AllowNull, Argument, ArgumentType, Message};

use std::ffi::CString;
use std::os::unix::io::RawFd;
use std::os::unix::io::BorrowedFd;
use std::os::unix::prelude::IntoRawFd;

use smallvec::smallvec;

fn same_file(a: RawFd, b: RawFd) -> bool {
let stat1 = ::nix::sys::stat::fstat(a).unwrap();
let stat2 = ::nix::sys::stat::fstat(b).unwrap();
fn same_file(a: BorrowedFd, b: BorrowedFd) -> bool {
let stat1 = rustix::fs::fstat(a).unwrap();
let stat2 = rustix::fs::fstat(b).unwrap();
stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino
}

Expand All @@ -366,7 +376,9 @@ mod tests {
assert_eq!(msg1.args.len(), msg2.args.len());
for (arg1, arg2) in msg1.args.iter().zip(msg2.args.iter()) {
if let (Argument::Fd(fd1), Argument::Fd(fd2)) = (arg1, arg2) {
assert!(same_file(fd1.as_raw_fd(), fd2.as_raw_fd()));
let fd1 = unsafe { BorrowedFd::borrow_raw(fd1.as_raw_fd()) };
let fd2 = unsafe { BorrowedFd::borrow_raw(fd2.as_raw_fd()) };
assert!(same_file(fd1, fd2));
} else {
assert_eq!(arg1, arg2);
}
Expand Down
2 changes: 1 addition & 1 deletion wayland-backend/src/sys/client_impl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ impl ConnectionState {
#[inline]
fn store_and_return_error(&mut self, err: std::io::Error) -> WaylandError {
// check if it was actually a protocol error
let err = if err.raw_os_error() == Some(nix::errno::Errno::EPROTO as i32) {
let err = if err.raw_os_error() == Some(rustix::io::Errno::PROTO.raw_os_error()) {
let mut object_id = 0;
let mut interface = std::ptr::null();
let code = unsafe {
Expand Down
Loading

0 comments on commit edd0f60

Please sign in to comment.