Skip to content

Commit 53b4738

Browse files
committed
Add support for Netlink socket addresses
1 parent 880bbd3 commit 53b4738

File tree

6 files changed

+179
-0
lines changed

6 files changed

+179
-0
lines changed

src/backend/libc/net/read_sockaddr.rs

+13
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use crate::backend::c;
88
#[cfg(not(windows))]
99
use crate::ffi::CStr;
1010
use crate::io::Errno;
11+
#[cfg(linux_kernel)]
12+
use crate::net::netlink::SocketAddrNetlink;
1113
#[cfg(target_os = "linux")]
1214
use crate::net::xdp::{SockaddrXdpFlags, SocketAddrXdp};
1315
use crate::net::{AddressFamily, Ipv4Addr, Ipv6Addr, SocketAddrAny, SocketAddrV4, SocketAddrV6};
@@ -239,3 +241,14 @@ pub(crate) fn read_sockaddr_xdp(addr: &SocketAddrAny) -> Result<SocketAddrXdp, E
239241
u32::from_be(decode.sxdp_shared_umem_fd),
240242
))
241243
}
244+
245+
#[cfg(linux_kernel)]
246+
#[inline]
247+
pub(crate) fn read_sockaddr_netlink(addr: &SocketAddrAny) -> Result<SocketAddrNetlink, Errno> {
248+
if addr.address_family() != AddressFamily::NETLINK {
249+
return Err(Errno::AFNOSUPPORT);
250+
}
251+
assert!(addr.len() >= size_of::<c::sockaddr_nl>());
252+
let decode = unsafe { &*addr.as_ptr().cast::<c::sockaddr_nl>() };
253+
Ok(SocketAddrNetlink::new(decode.nl_pid, decode.nl_groups))
254+
}

src/backend/linux_raw/net/read_sockaddr.rs

+11
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
use crate::backend::c;
66
use crate::io::Errno;
7+
use crate::net::netlink::SocketAddrNetlink;
78
#[cfg(target_os = "linux")]
89
use crate::net::xdp::{SockaddrXdpFlags, SocketAddrXdp};
910
use crate::net::{AddressFamily, SocketAddrAny};
@@ -133,3 +134,13 @@ pub(crate) fn read_sockaddr_xdp(addr: &SocketAddrAny) -> Result<SocketAddrXdp, E
133134
u32::from_be(decode.sxdp_shared_umem_fd),
134135
))
135136
}
137+
138+
#[inline]
139+
pub(crate) fn read_sockaddr_netlink(addr: &SocketAddrAny) -> Result<SocketAddrNetlink, Errno> {
140+
if addr.address_family() != AddressFamily::NETLINK {
141+
return Err(Errno::AFNOSUPPORT);
142+
}
143+
assert!(addr.len() >= size_of::<c::sockaddr_nl>());
144+
let decode = unsafe { &*addr.as_ptr().cast::<c::sockaddr_nl>() };
145+
Ok(SocketAddrNetlink::new(decode.nl_pid, decode.nl_groups))
146+
}

src/net/socket_addr_any.rs

+6
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,12 @@ impl fmt::Debug for SocketAddrAny {
191191
return addr.fmt(f);
192192
}
193193
}
194+
#[cfg(linux_kernel)]
195+
AddressFamily::NETLINK => {
196+
if let Ok(addr) = crate::net::netlink::SocketAddrNetlink::try_from(self.clone()) {
197+
return addr.fmt(f);
198+
}
199+
}
194200
_ => {}
195201
}
196202

src/net/types.rs

+83
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,12 @@ pub mod netlink {
10031003
use {
10041004
super::{new_raw_protocol, Protocol},
10051005
crate::backend::c,
1006+
crate::backend::net::read_sockaddr::read_sockaddr_netlink,
1007+
crate::net::{
1008+
addr::{call_with_sockaddr, SocketAddrArg, SocketAddrOpaque},
1009+
SocketAddrAny,
1010+
},
1011+
core::mem,
10061012
};
10071013

10081014
/// `NETLINK_UNUSED`
@@ -1112,6 +1118,83 @@ pub mod netlink {
11121118
/// `NETLINK_GET_STRICT_CHK`
11131119
#[cfg(linux_kernel)]
11141120
pub const GET_STRICT_CHK: Protocol = Protocol(new_raw_protocol(c::NETLINK_GET_STRICT_CHK as _));
1121+
1122+
/// A Netlink socket address.
1123+
///
1124+
/// Used to bind to a Netlink socket.
1125+
///
1126+
/// Not ABI compatible with `struct sockaddr_nl`
1127+
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Debug)]
1128+
#[cfg(linux_kernel)]
1129+
pub struct SocketAddrNetlink {
1130+
/// Port ID
1131+
pid: u32,
1132+
1133+
/// Multicast groups mask
1134+
groups: u32,
1135+
}
1136+
1137+
#[cfg(linux_kernel)]
1138+
impl SocketAddrNetlink {
1139+
/// Construct a netlink address
1140+
#[inline]
1141+
pub fn new(pid: u32, groups: u32) -> Self {
1142+
Self { pid, groups }
1143+
}
1144+
1145+
/// Return port id.
1146+
#[inline]
1147+
pub fn pid(&self) -> u32 {
1148+
self.pid
1149+
}
1150+
1151+
/// Set port id.
1152+
#[inline]
1153+
pub fn set_pid(&mut self, pid: u32) {
1154+
self.pid = pid;
1155+
}
1156+
1157+
/// Return multicast groups mask.
1158+
#[inline]
1159+
pub fn groups(&self) -> u32 {
1160+
self.groups
1161+
}
1162+
1163+
/// Set multicast groups mask.
1164+
#[inline]
1165+
pub fn set_groups(&mut self, groups: u32) {
1166+
self.groups = groups;
1167+
}
1168+
}
1169+
1170+
#[cfg(linux_kernel)]
1171+
#[allow(unsafe_code)]
1172+
unsafe impl SocketAddrArg for SocketAddrNetlink {
1173+
fn with_sockaddr<R>(&self, f: impl FnOnce(*const SocketAddrOpaque, usize) -> R) -> R {
1174+
let mut addr: c::sockaddr_nl = unsafe { mem::zeroed() };
1175+
addr.nl_family = c::AF_NETLINK as _;
1176+
addr.nl_pid = self.pid;
1177+
addr.nl_groups = self.groups;
1178+
call_with_sockaddr(&addr, f)
1179+
}
1180+
}
1181+
1182+
#[cfg(linux_kernel)]
1183+
impl From<SocketAddrNetlink> for SocketAddrAny {
1184+
#[inline]
1185+
fn from(from: SocketAddrNetlink) -> Self {
1186+
from.as_any()
1187+
}
1188+
}
1189+
1190+
#[cfg(linux_kernel)]
1191+
impl TryFrom<SocketAddrAny> for SocketAddrNetlink {
1192+
type Error = crate::io::Errno;
1193+
1194+
fn try_from(addr: SocketAddrAny) -> Result<Self, Self::Error> {
1195+
read_sockaddr_netlink(&addr)
1196+
}
1197+
}
11151198
}
11161199

11171200
/// `ETH_P_*` constants.

tests/net/main.rs

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ mod addr;
1010
mod cmsg;
1111
mod connect_bind_send;
1212
mod dgram;
13+
#[cfg(linux_kernel)]
14+
mod netlink;
1315
#[cfg(feature = "event")]
1416
mod poll;
1517
#[cfg(unix)]

tests/net/netlink.rs

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
use rustix::net::netlink::{self, SocketAddrNetlink};
2+
use rustix::net::{
3+
bind, getsockname, recvfrom, sendto, socket_with, AddressFamily, RecvFlags, SendFlags,
4+
SocketAddrAny, SocketFlags, SocketType,
5+
};
6+
7+
#[test]
8+
fn encode_decode() {
9+
let orig = SocketAddrNetlink::new(0x12345678, 0x9abcdef0);
10+
let encoded = SocketAddrAny::from(orig);
11+
let decoded = SocketAddrNetlink::try_from(encoded).unwrap();
12+
assert_eq!(decoded, orig);
13+
}
14+
15+
#[test]
16+
fn test_bind_kobject_uevent() {
17+
let server = socket_with(
18+
AddressFamily::NETLINK,
19+
SocketType::RAW,
20+
SocketFlags::CLOEXEC,
21+
Some(netlink::KOBJECT_UEVENT),
22+
)
23+
.unwrap();
24+
25+
bind(&server, &SocketAddrNetlink::new(0, 1)).unwrap();
26+
}
27+
28+
#[test]
29+
#[cfg_attr(
30+
not(any(target_arch = "x86", target_arch = "x86_64")),
31+
ignore = "qemu used in CI does not support NETLINK_USERSOCK"
32+
)]
33+
fn test_usersock() {
34+
let server = socket_with(
35+
AddressFamily::NETLINK,
36+
SocketType::RAW,
37+
SocketFlags::CLOEXEC,
38+
Some(netlink::USERSOCK),
39+
)
40+
.unwrap();
41+
42+
bind(&server, &SocketAddrNetlink::new(0, 0)).unwrap();
43+
let addr = getsockname(&server).unwrap();
44+
let addr = SocketAddrNetlink::try_from(addr).unwrap();
45+
46+
let client = socket_with(
47+
AddressFamily::NETLINK,
48+
SocketType::RAW,
49+
SocketFlags::CLOEXEC,
50+
Some(netlink::USERSOCK),
51+
)
52+
.unwrap();
53+
54+
let data = b"ABCDEF";
55+
56+
sendto(client, data, SendFlags::empty(), &addr).unwrap();
57+
58+
let mut buffer = [0u8; 4096];
59+
let (len, src) = recvfrom(&server, &mut buffer, RecvFlags::empty()).unwrap();
60+
61+
assert_eq!(&buffer[..len], data);
62+
let src = SocketAddrNetlink::try_from(src.unwrap()).unwrap();
63+
assert_eq!(src.groups(), 0);
64+
}

0 commit comments

Comments
 (0)