Use the proper netlink buffer size with large kernel pages

The recommended netlink buffer size is based on the system's page size,
which means that the current size is far too small for systems with 16k
or 64k pages, such as Asahi Linux or RHEL's kernel-64k for ARM64. On
these systems, the server fails to start with errors like this:

Error: Decode error occurred: invalid netlink buffer: length field says 1444 the buffer is 1260 bytes long

Instead, follow the kernel's own netlink docs to compute the buffer
size. The approach here matches the approach merged into Chromium
recently:

https://chromium-review.googlesource.com/c/chromium/src/+/4312885
pull/264/head
Ryan Gonzalez 2023-05-27 16:59:32 -05:00 committed by Matěj Laitl
parent ae96e05e90
commit f67457e0a4
4 changed files with 45 additions and 18 deletions

6
Cargo.lock generated
View File

@ -624,6 +624,8 @@ dependencies = [
"netlink-packet-generic", "netlink-packet-generic",
"netlink-packet-route", "netlink-packet-route",
"netlink-sys", "netlink-sys",
"nix",
"once_cell",
] ]
[[package]] [[package]]
@ -663,9 +665,9 @@ dependencies = [
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.17.0" version = "1.17.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66" checksum = "9670a07f94779e00908f3e686eab508878ebb390ba6e604d3a284c00e8d0487b"
[[package]] [[package]]
name = "os_str_bytes" name = "os_str_bytes"

View File

@ -8,3 +8,5 @@ netlink-sys = "0.8"
netlink-packet-core = "0.4" netlink-packet-core = "0.4"
netlink-packet-generic = "0.3" netlink-packet-generic = "0.3"
netlink-packet-route = "0.13" netlink-packet-route = "0.13"
nix = { version = "0.25", features = ["feature"] }
once_cell = "1"

View File

@ -1,9 +1,5 @@
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
mod linux { mod linux {
pub const MAX_NETLINK_BUFFER_LENGTH: usize = 4096;
pub const MAX_GENL_PAYLOAD_LENGTH: usize =
MAX_NETLINK_BUFFER_LENGTH - NETLINK_HEADER_LEN - GENL_HDRLEN;
use netlink_packet_core::{ use netlink_packet_core::{
NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable, NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable,
NETLINK_HEADER_LEN, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST, NETLINK_HEADER_LEN, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST,
@ -15,6 +11,8 @@ mod linux {
}; };
use netlink_packet_route::RtnlMessage; use netlink_packet_route::RtnlMessage;
use netlink_sys::{constants::NETLINK_GENERIC, protocols::NETLINK_ROUTE, Socket}; use netlink_sys::{constants::NETLINK_GENERIC, protocols::NETLINK_ROUTE, Socket};
use nix::unistd::{sysconf, SysconfVar};
use once_cell::sync::OnceCell;
use std::{fmt::Debug, io}; use std::{fmt::Debug, io};
macro_rules! get_nla_value { macro_rules! get_nla_value {
@ -26,6 +24,26 @@ mod linux {
}; };
} }
pub fn max_netlink_buffer_length() -> usize {
static LENGTH: OnceCell<usize> = OnceCell::new();
*LENGTH.get_or_init(|| {
// https://www.kernel.org/doc/html/v6.2/userspace-api/netlink/intro.html#buffer-sizing
// "Netlink expects that the user buffer will be at least 8kB or a page
// size of the CPU architecture, whichever is bigger."
const MIN_NELINK_BUFFER_LENGTH: usize = 8 * 1024;
// Note that sysconf only returns Err / Ok(None) when the parameter is
// invalid, unsupported on the current OS, or an unset limit. PAGE_SIZE
// is *required* to be supported and is not considered a limit, so this
// should never fail unless something has gone massively wrong.
let page_size = sysconf(SysconfVar::PAGE_SIZE).unwrap().unwrap() as usize;
std::cmp::max(MIN_NELINK_BUFFER_LENGTH, page_size)
})
}
pub fn max_genl_payload_length() -> usize {
max_netlink_buffer_length() - NETLINK_HEADER_LEN - GENL_HDRLEN
}
pub fn netlink_request_genl<F>( pub fn netlink_request_genl<F>(
mut message: GenlMessage<F>, mut message: GenlMessage<F>,
flags: Option<u16>, flags: Option<u16>,
@ -84,13 +102,14 @@ mod linux {
{ {
let mut req = NetlinkMessage::from(message); let mut req = NetlinkMessage::from(message);
if req.buffer_len() > MAX_NETLINK_BUFFER_LENGTH { let max_buffer_len = max_netlink_buffer_length();
if req.buffer_len() > max_buffer_len {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
format!( format!(
"Serialized netlink packet ({} bytes) larger than maximum size {}: {:?}", "Serialized netlink packet ({} bytes) larger than maximum size {}: {:?}",
req.buffer_len(), req.buffer_len(),
MAX_NETLINK_BUFFER_LENGTH, max_buffer_len,
req req
), ),
)); ));
@ -98,7 +117,7 @@ mod linux {
req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE); req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE);
req.finalize(); req.finalize();
let mut buf = [0; MAX_NETLINK_BUFFER_LENGTH]; let mut buf = vec![0; max_buffer_len];
req.serialize(&mut buf); req.serialize(&mut buf);
let len = req.buffer_len(); let len = req.buffer_len();
@ -141,6 +160,6 @@ mod linux {
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
pub use linux::{ pub use linux::{
netlink_request, netlink_request_genl, netlink_request_rtnl, MAX_GENL_PAYLOAD_LENGTH, max_genl_payload_length, max_netlink_buffer_length, netlink_request, netlink_request_genl,
MAX_NETLINK_BUFFER_LENGTH, netlink_request_rtnl,
}; };

View File

@ -21,7 +21,7 @@ use netlink_packet_wireguard::{
nlas::{WgAllowedIp, WgAllowedIpAttrs, WgDeviceAttrs, WgPeer, WgPeerAttrs}, nlas::{WgAllowedIp, WgAllowedIpAttrs, WgDeviceAttrs, WgPeer, WgPeerAttrs},
Wireguard, WireguardCmd, Wireguard, WireguardCmd,
}; };
use netlink_request::{netlink_request_genl, netlink_request_rtnl, MAX_GENL_PAYLOAD_LENGTH}; use netlink_request::{max_genl_payload_length, netlink_request_genl, netlink_request_rtnl};
use std::{convert::TryFrom, io}; use std::{convert::TryFrom, io};
@ -285,13 +285,15 @@ impl ApplyPayload {
/// Push a device attribute which will be optimally packed into 1 or more netlink messages /// Push a device attribute which will be optimally packed into 1 or more netlink messages
pub fn push(&mut self, nla: WgDeviceAttrs) -> io::Result<()> { pub fn push(&mut self, nla: WgDeviceAttrs) -> io::Result<()> {
let max_payload_len = max_genl_payload_length();
let nla_buffer_len = nla.buffer_len(); let nla_buffer_len = nla.buffer_len();
if (self.current_buffer_len + nla_buffer_len) > MAX_GENL_PAYLOAD_LENGTH { if (self.current_buffer_len + nla_buffer_len) > max_payload_len {
self.flush_nlas(); self.flush_nlas();
} }
// If the NLA *still* doesn't fit... // If the NLA *still* doesn't fit...
if (self.current_buffer_len + nla_buffer_len) > MAX_GENL_PAYLOAD_LENGTH { if (self.current_buffer_len + nla_buffer_len) > max_payload_len {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
format!("encoded NLA ({nla_buffer_len} bytes) is too large: {nla:?}"), format!("encoded NLA ({nla_buffer_len} bytes) is too large: {nla:?}"),
@ -305,6 +307,7 @@ impl ApplyPayload {
/// A helper function to assist in breaking up large peer lists across multiple netlink messages /// A helper function to assist in breaking up large peer lists across multiple netlink messages
pub fn push_peer(&mut self, peer: WgPeer) -> io::Result<()> { pub fn push_peer(&mut self, peer: WgPeer) -> io::Result<()> {
const EMPTY_PEERS: WgDeviceAttrs = WgDeviceAttrs::Peers(vec![]); const EMPTY_PEERS: WgDeviceAttrs = WgDeviceAttrs::Peers(vec![]);
let max_payload_len = max_genl_payload_length();
let mut needs_peer_nla = !self let mut needs_peer_nla = !self
.nlas .nlas
.iter() .iter()
@ -314,7 +317,7 @@ impl ApplyPayload {
if needs_peer_nla { if needs_peer_nla {
additional_buffer_len += EMPTY_PEERS.buffer_len(); additional_buffer_len += EMPTY_PEERS.buffer_len();
} }
if (self.current_buffer_len + additional_buffer_len) > MAX_GENL_PAYLOAD_LENGTH { if (self.current_buffer_len + additional_buffer_len) > max_payload_len {
self.flush_nlas(); self.flush_nlas();
needs_peer_nla = true; needs_peer_nla = true;
} }
@ -324,7 +327,7 @@ impl ApplyPayload {
} }
// If the peer *still* doesn't fit... // If the peer *still* doesn't fit...
if (self.current_buffer_len + peer_buffer_len) > MAX_GENL_PAYLOAD_LENGTH { if (self.current_buffer_len + peer_buffer_len) > max_payload_len {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
format!("encoded peer ({peer_buffer_len} bytes) is too large: {peer:?}"), format!("encoded peer ({peer_buffer_len} bytes) is too large: {peer:?}"),
@ -397,7 +400,7 @@ pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> {
mod tests { mod tests {
use super::*; use super::*;
use netlink_packet_wireguard::nlas::WgAllowedIp; use netlink_packet_wireguard::nlas::WgAllowedIp;
use netlink_request::MAX_NETLINK_BUFFER_LENGTH; use netlink_request::max_netlink_buffer_length;
use std::str::FromStr; use std::str::FromStr;
#[test] #[test]
@ -455,8 +458,9 @@ mod tests {
let messages = payload.finish(); let messages = payload.finish();
println!("generated {} messages", messages.len()); println!("generated {} messages", messages.len());
assert!(messages.len() > 1); assert!(messages.len() > 1);
let max_buffer_len = max_netlink_buffer_length();
for message in messages { for message in messages {
assert!(NetlinkMessage::from(message).buffer_len() <= MAX_NETLINK_BUFFER_LENGTH); assert!(NetlinkMessage::from(message).buffer_len() <= max_buffer_len);
} }
} }
} }