Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the proper netlink buffer size with large kernel pages #258

Merged
merged 1 commit into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions netlink-request/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ netlink-sys = "0.8"
netlink-packet-core = "0.4"
netlink-packet-generic = "0.3"
netlink-packet-route = "0.13"
nix = { version = "0.25", features = ["feature"] }
once_cell = "1"
37 changes: 28 additions & 9 deletions netlink-request/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
#[cfg(target_os = "linux")]
mod linux {
pub const MAX_NETLINK_BUFFER_LENGTH: usize = 4096;
pub const MAX_GENL_PAYLOAD_LENGTH: usize =
MAX_NETLINK_BUFFER_LENGTH - NETLINK_HEADER_LEN - GENL_HDRLEN;

use netlink_packet_core::{
NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable,
NETLINK_HEADER_LEN, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST,
Expand All @@ -15,6 +11,8 @@ mod linux {
};
use netlink_packet_route::RtnlMessage;
use netlink_sys::{constants::NETLINK_GENERIC, protocols::NETLINK_ROUTE, Socket};
use nix::unistd::{sysconf, SysconfVar};
use once_cell::sync::OnceCell;
use std::{fmt::Debug, io};

macro_rules! get_nla_value {
Expand All @@ -26,6 +24,26 @@ mod linux {
};
}

pub fn max_netlink_buffer_length() -> usize {
static LENGTH: OnceCell<usize> = OnceCell::new();
*LENGTH.get_or_init(|| {
// https://www.kernel.org/doc/html/v6.2/userspace-api/netlink/intro.html#buffer-sizing
// "Netlink expects that the user buffer will be at least 8kB or a page
// size of the CPU architecture, whichever is bigger."
const MIN_NELINK_BUFFER_LENGTH: usize = 8 * 1024;
// Note that sysconf only returns Err / Ok(None) when the parameter is
// invalid, unsupported on the current OS, or an unset limit. PAGE_SIZE
// is *required* to be supported and is not considered a limit, so this
// should never fail unless something has gone massively wrong.
let page_size = sysconf(SysconfVar::PAGE_SIZE).unwrap().unwrap() as usize;
std::cmp::max(MIN_NELINK_BUFFER_LENGTH, page_size)
})
}

pub fn max_genl_payload_length() -> usize {
max_netlink_buffer_length() - NETLINK_HEADER_LEN - GENL_HDRLEN
}

pub fn netlink_request_genl<F>(
mut message: GenlMessage<F>,
flags: Option<u16>,
Expand Down Expand Up @@ -84,21 +102,22 @@ mod linux {
{
let mut req = NetlinkMessage::from(message);

if req.buffer_len() > MAX_NETLINK_BUFFER_LENGTH {
let max_buffer_len = max_netlink_buffer_length();
if req.buffer_len() > max_buffer_len {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!(
"Serialized netlink packet ({} bytes) larger than maximum size {}: {:?}",
req.buffer_len(),
MAX_NETLINK_BUFFER_LENGTH,
max_buffer_len,
req
),
));
}

req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE);
req.finalize();
let mut buf = [0; MAX_NETLINK_BUFFER_LENGTH];
let mut buf = vec![0; max_buffer_len];
req.serialize(&mut buf);
let len = req.buffer_len();

Expand Down Expand Up @@ -141,6 +160,6 @@ mod linux {

#[cfg(target_os = "linux")]
pub use linux::{
netlink_request, netlink_request_genl, netlink_request_rtnl, MAX_GENL_PAYLOAD_LENGTH,
MAX_NETLINK_BUFFER_LENGTH,
max_genl_payload_length, max_netlink_buffer_length, netlink_request, netlink_request_genl,
netlink_request_rtnl,
};
18 changes: 11 additions & 7 deletions wireguard-control/src/backends/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use netlink_packet_wireguard::{
nlas::{WgAllowedIp, WgAllowedIpAttrs, WgDeviceAttrs, WgPeer, WgPeerAttrs},
Wireguard, WireguardCmd,
};
use netlink_request::{netlink_request_genl, netlink_request_rtnl, MAX_GENL_PAYLOAD_LENGTH};
use netlink_request::{max_genl_payload_length, netlink_request_genl, netlink_request_rtnl};

use std::{convert::TryFrom, io};

Expand Down Expand Up @@ -285,13 +285,15 @@ impl ApplyPayload {

/// Push a device attribute which will be optimally packed into 1 or more netlink messages
pub fn push(&mut self, nla: WgDeviceAttrs) -> io::Result<()> {
let max_payload_len = max_genl_payload_length();

let nla_buffer_len = nla.buffer_len();
if (self.current_buffer_len + nla_buffer_len) > MAX_GENL_PAYLOAD_LENGTH {
if (self.current_buffer_len + nla_buffer_len) > max_payload_len {
self.flush_nlas();
}

// If the NLA *still* doesn't fit...
if (self.current_buffer_len + nla_buffer_len) > MAX_GENL_PAYLOAD_LENGTH {
if (self.current_buffer_len + nla_buffer_len) > max_payload_len {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("encoded NLA ({nla_buffer_len} bytes) is too large: {nla:?}"),
Expand All @@ -305,6 +307,7 @@ impl ApplyPayload {
/// A helper function to assist in breaking up large peer lists across multiple netlink messages
pub fn push_peer(&mut self, peer: WgPeer) -> io::Result<()> {
const EMPTY_PEERS: WgDeviceAttrs = WgDeviceAttrs::Peers(vec![]);
let max_payload_len = max_genl_payload_length();
let mut needs_peer_nla = !self
.nlas
.iter()
Expand All @@ -314,7 +317,7 @@ impl ApplyPayload {
if needs_peer_nla {
additional_buffer_len += EMPTY_PEERS.buffer_len();
}
if (self.current_buffer_len + additional_buffer_len) > MAX_GENL_PAYLOAD_LENGTH {
if (self.current_buffer_len + additional_buffer_len) > max_payload_len {
self.flush_nlas();
needs_peer_nla = true;
}
Expand All @@ -324,7 +327,7 @@ impl ApplyPayload {
}

// If the peer *still* doesn't fit...
if (self.current_buffer_len + peer_buffer_len) > MAX_GENL_PAYLOAD_LENGTH {
if (self.current_buffer_len + peer_buffer_len) > max_payload_len {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("encoded peer ({peer_buffer_len} bytes) is too large: {peer:?}"),
Expand Down Expand Up @@ -397,7 +400,7 @@ pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> {
mod tests {
use super::*;
use netlink_packet_wireguard::nlas::WgAllowedIp;
use netlink_request::MAX_NETLINK_BUFFER_LENGTH;
use netlink_request::max_netlink_buffer_length;
use std::str::FromStr;

#[test]
Expand Down Expand Up @@ -455,8 +458,9 @@ mod tests {
let messages = payload.finish();
println!("generated {} messages", messages.len());
assert!(messages.len() > 1);
let max_buffer_len = max_netlink_buffer_length();
for message in messages {
assert!(NetlinkMessage::from(message).buffer_len() <= MAX_NETLINK_BUFFER_LENGTH);
assert!(NetlinkMessage::from(message).buffer_len() <= max_buffer_len);
}
}
}