diff --git a/Cargo.lock b/Cargo.lock index be6264e..97f7069 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -624,6 +624,8 @@ dependencies = [ "netlink-packet-generic", "netlink-packet-route", "netlink-sys", + "nix", + "once_cell", ] [[package]] @@ -663,9 +665,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.17.0" +version = "1.17.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66" +checksum = "9670a07f94779e00908f3e686eab508878ebb390ba6e604d3a284c00e8d0487b" [[package]] name = "os_str_bytes" diff --git a/netlink-request/Cargo.toml b/netlink-request/Cargo.toml index 0235e31..71ee703 100644 --- a/netlink-request/Cargo.toml +++ b/netlink-request/Cargo.toml @@ -8,3 +8,5 @@ netlink-sys = "0.8" netlink-packet-core = "0.4" netlink-packet-generic = "0.3" netlink-packet-route = "0.13" +nix = { version = "0.25", features = ["feature"] } +once_cell = "1" diff --git a/netlink-request/src/lib.rs b/netlink-request/src/lib.rs index 152a152..686498e 100644 --- a/netlink-request/src/lib.rs +++ b/netlink-request/src/lib.rs @@ -1,9 +1,5 @@ #[cfg(target_os = "linux")] mod linux { - pub const MAX_NETLINK_BUFFER_LENGTH: usize = 4096; - pub const MAX_GENL_PAYLOAD_LENGTH: usize = - MAX_NETLINK_BUFFER_LENGTH - NETLINK_HEADER_LEN - GENL_HDRLEN; - use netlink_packet_core::{ NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable, NETLINK_HEADER_LEN, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST, @@ -15,6 +11,8 @@ mod linux { }; use netlink_packet_route::RtnlMessage; use netlink_sys::{constants::NETLINK_GENERIC, protocols::NETLINK_ROUTE, Socket}; + use nix::unistd::{sysconf, SysconfVar}; + use once_cell::sync::OnceCell; use std::{fmt::Debug, io}; macro_rules! get_nla_value { @@ -26,6 +24,26 @@ mod linux { }; } + pub fn max_netlink_buffer_length() -> usize { + static LENGTH: OnceCell = OnceCell::new(); + *LENGTH.get_or_init(|| { + // https://www.kernel.org/doc/html/v6.2/userspace-api/netlink/intro.html#buffer-sizing + // "Netlink expects that the user buffer will be at least 8kB or a page + // size of the CPU architecture, whichever is bigger." + const MIN_NELINK_BUFFER_LENGTH: usize = 8 * 1024; + // Note that sysconf only returns Err / Ok(None) when the parameter is + // invalid, unsupported on the current OS, or an unset limit. PAGE_SIZE + // is *required* to be supported and is not considered a limit, so this + // should never fail unless something has gone massively wrong. + let page_size = sysconf(SysconfVar::PAGE_SIZE).unwrap().unwrap() as usize; + std::cmp::max(MIN_NELINK_BUFFER_LENGTH, page_size) + }) + } + + pub fn max_genl_payload_length() -> usize { + max_netlink_buffer_length() - NETLINK_HEADER_LEN - GENL_HDRLEN + } + pub fn netlink_request_genl( mut message: GenlMessage, flags: Option, @@ -84,13 +102,14 @@ mod linux { { let mut req = NetlinkMessage::from(message); - if req.buffer_len() > MAX_NETLINK_BUFFER_LENGTH { + let max_buffer_len = max_netlink_buffer_length(); + if req.buffer_len() > max_buffer_len { return Err(io::Error::new( io::ErrorKind::InvalidInput, format!( "Serialized netlink packet ({} bytes) larger than maximum size {}: {:?}", req.buffer_len(), - MAX_NETLINK_BUFFER_LENGTH, + max_buffer_len, req ), )); @@ -98,7 +117,7 @@ mod linux { req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE); req.finalize(); - let mut buf = [0; MAX_NETLINK_BUFFER_LENGTH]; + let mut buf = vec![0; max_buffer_len]; req.serialize(&mut buf); let len = req.buffer_len(); @@ -141,6 +160,6 @@ mod linux { #[cfg(target_os = "linux")] pub use linux::{ - netlink_request, netlink_request_genl, netlink_request_rtnl, MAX_GENL_PAYLOAD_LENGTH, - MAX_NETLINK_BUFFER_LENGTH, + max_genl_payload_length, max_netlink_buffer_length, netlink_request, netlink_request_genl, + netlink_request_rtnl, }; diff --git a/wireguard-control/src/backends/kernel.rs b/wireguard-control/src/backends/kernel.rs index bf7edaf..9c60bf1 100644 --- a/wireguard-control/src/backends/kernel.rs +++ b/wireguard-control/src/backends/kernel.rs @@ -21,7 +21,7 @@ use netlink_packet_wireguard::{ nlas::{WgAllowedIp, WgAllowedIpAttrs, WgDeviceAttrs, WgPeer, WgPeerAttrs}, Wireguard, WireguardCmd, }; -use netlink_request::{netlink_request_genl, netlink_request_rtnl, MAX_GENL_PAYLOAD_LENGTH}; +use netlink_request::{max_genl_payload_length, netlink_request_genl, netlink_request_rtnl}; use std::{convert::TryFrom, io}; @@ -285,13 +285,15 @@ impl ApplyPayload { /// Push a device attribute which will be optimally packed into 1 or more netlink messages pub fn push(&mut self, nla: WgDeviceAttrs) -> io::Result<()> { + let max_payload_len = max_genl_payload_length(); + let nla_buffer_len = nla.buffer_len(); - if (self.current_buffer_len + nla_buffer_len) > MAX_GENL_PAYLOAD_LENGTH { + if (self.current_buffer_len + nla_buffer_len) > max_payload_len { self.flush_nlas(); } // If the NLA *still* doesn't fit... - if (self.current_buffer_len + nla_buffer_len) > MAX_GENL_PAYLOAD_LENGTH { + if (self.current_buffer_len + nla_buffer_len) > max_payload_len { return Err(io::Error::new( io::ErrorKind::InvalidInput, format!("encoded NLA ({nla_buffer_len} bytes) is too large: {nla:?}"), @@ -305,6 +307,7 @@ impl ApplyPayload { /// A helper function to assist in breaking up large peer lists across multiple netlink messages pub fn push_peer(&mut self, peer: WgPeer) -> io::Result<()> { const EMPTY_PEERS: WgDeviceAttrs = WgDeviceAttrs::Peers(vec![]); + let max_payload_len = max_genl_payload_length(); let mut needs_peer_nla = !self .nlas .iter() @@ -314,7 +317,7 @@ impl ApplyPayload { if needs_peer_nla { additional_buffer_len += EMPTY_PEERS.buffer_len(); } - if (self.current_buffer_len + additional_buffer_len) > MAX_GENL_PAYLOAD_LENGTH { + if (self.current_buffer_len + additional_buffer_len) > max_payload_len { self.flush_nlas(); needs_peer_nla = true; } @@ -324,7 +327,7 @@ impl ApplyPayload { } // If the peer *still* doesn't fit... - if (self.current_buffer_len + peer_buffer_len) > MAX_GENL_PAYLOAD_LENGTH { + if (self.current_buffer_len + peer_buffer_len) > max_payload_len { return Err(io::Error::new( io::ErrorKind::InvalidInput, format!("encoded peer ({peer_buffer_len} bytes) is too large: {peer:?}"), @@ -397,7 +400,7 @@ pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> { mod tests { use super::*; use netlink_packet_wireguard::nlas::WgAllowedIp; - use netlink_request::MAX_NETLINK_BUFFER_LENGTH; + use netlink_request::max_netlink_buffer_length; use std::str::FromStr; #[test] @@ -455,8 +458,9 @@ mod tests { let messages = payload.finish(); println!("generated {} messages", messages.len()); assert!(messages.len() > 1); + let max_buffer_len = max_netlink_buffer_length(); for message in messages { - assert!(NetlinkMessage::from(message).buffer_len() <= MAX_NETLINK_BUFFER_LENGTH); + assert!(NetlinkMessage::from(message).buffer_len() <= max_buffer_len); } } }