diff --git a/netlink-request/src/lib.rs b/netlink-request/src/lib.rs index 3f73dda..68bbcfc 100644 --- a/netlink-request/src/lib.rs +++ b/netlink-request/src/lib.rs @@ -1,7 +1,7 @@ #[cfg(target_os = "linux")] mod linux { - - const NETLINK_BUFFER_LENGTH: usize = 4096; + pub const MAX_NETLINK_BUFFER_LENGTH: usize = 4096; + pub const MAX_GENL_PAYLOAD_LENGTH: usize = MAX_NETLINK_BUFFER_LENGTH - GENL_HDRLEN; use netlink_packet_core::{ NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable, NLM_F_ACK, @@ -9,7 +9,7 @@ mod linux { }; use netlink_packet_generic::{ ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd}, - GenlFamily, GenlMessage, + GenlFamily, GenlMessage, constants::GENL_HDRLEN, }; use netlink_packet_route::RtnlMessage; use netlink_sys::{constants::NETLINK_GENERIC, protocols::NETLINK_ROUTE, Socket}; @@ -82,16 +82,16 @@ mod linux { { let mut req = NetlinkMessage::from(message); - if req.buffer_len() > NETLINK_BUFFER_LENGTH { + if req.buffer_len() > MAX_NETLINK_BUFFER_LENGTH { return Err(io::Error::new( io::ErrorKind::InvalidInput, - format!("Serialized netlink packet larger than maximum size {}", NETLINK_BUFFER_LENGTH), + format!("Serialized netlink packet larger than maximum size {}", MAX_NETLINK_BUFFER_LENGTH), )); } req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE); req.finalize(); - let mut buf = [0; NETLINK_BUFFER_LENGTH]; + let mut buf = [0; MAX_NETLINK_BUFFER_LENGTH]; req.serialize(&mut buf); let len = req.buffer_len(); @@ -133,4 +133,4 @@ mod linux { } #[cfg(target_os = "linux")] -pub use linux::{netlink_request, netlink_request_genl, netlink_request_rtnl}; +pub use linux::{netlink_request, netlink_request_genl, netlink_request_rtnl, MAX_NETLINK_BUFFER_LENGTH, MAX_GENL_PAYLOAD_LENGTH}; diff --git a/wireguard-control/src/backends/kernel.rs b/wireguard-control/src/backends/kernel.rs index aea685a..1922647 100644 --- a/wireguard-control/src/backends/kernel.rs +++ b/wireguard-control/src/backends/kernel.rs @@ -12,7 +12,7 @@ use netlink_packet_route::{ self, nlas::{Info, InfoKind}, }, - LinkMessage, RtnlMessage, + LinkMessage, RtnlMessage, traits::Emitable, }; use netlink_packet_wireguard::{ self, @@ -20,7 +20,7 @@ use netlink_packet_wireguard::{ nlas::{WgAllowedIpAttrs, WgDeviceAttrs, WgPeerAttrs}, Wireguard, WireguardCmd, }; -use netlink_request::{netlink_request_genl, netlink_request_rtnl}; +use netlink_request::{netlink_request_genl, netlink_request_rtnl, MAX_GENL_PAYLOAD_LENGTH}; use std::{convert::TryFrom, io}; @@ -216,33 +216,105 @@ fn add_del(iface: &InterfaceName, add: bool) -> io::Result<()> { pub fn apply(builder: &DeviceUpdate, iface: &InterfaceName) -> io::Result<()> { add_del(iface, true)?; - let mut nlas = vec![WgDeviceAttrs::IfName(iface.as_str_lossy().to_string())]; + let mut payload = ApplyPayload::new(iface); if let Some(Key(k)) = builder.private_key { - nlas.push(WgDeviceAttrs::PrivateKey(k)); + payload.push(WgDeviceAttrs::PrivateKey(k)); } if let Some(f) = builder.fwmark { - nlas.push(WgDeviceAttrs::Fwmark(f)); + payload.push(WgDeviceAttrs::Fwmark(f)); } if let Some(f) = builder.listen_port { - nlas.push(WgDeviceAttrs::ListenPort(f)); + payload.push(WgDeviceAttrs::ListenPort(f)); } if builder.replace_peers { - nlas.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)); + payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)); } - let peers: Vec> = builder - .peers + + builder.peers .iter() - .map(PeerConfigBuilder::to_attrs) - .collect(); - nlas.push(WgDeviceAttrs::Peers(peers)); - let genlmsg: GenlMessage = GenlMessage::from_payload(Wireguard { - cmd: WireguardCmd::SetDevice, - nlas, - }); - netlink_request_genl(genlmsg, Some(NLM_F_REQUEST | NLM_F_ACK))?; + .for_each(|peer| { + payload.push_peer(peer.to_attrs()) + }); + + for message in payload.finish() { + netlink_request_genl(message, Some(NLM_F_REQUEST | NLM_F_ACK))?; + } Ok(()) } +struct ApplyPayload { + iface: String, + nlas: Vec, + current_buffer_len: usize, + messages: Vec>, +} + +impl ApplyPayload { + fn new(iface: &InterfaceName) -> Self { + Self { + iface: iface.as_str_lossy().to_string(), + nlas: vec![], + messages: vec![], + current_buffer_len: 0, + } + } + + fn flush_nlas(&mut self) { + // cleanup: clear out any empty peer lists. + self.nlas.retain(|nla| !matches!(nla, WgDeviceAttrs::Peers(peers) if peers.len() == 0)); + + let name = WgDeviceAttrs::IfName(self.iface.clone()); + self.current_buffer_len = name.buffer_len(); + + if !self.nlas.is_empty() { + self.messages.push(GenlMessage::from_payload(Wireguard { + cmd: WireguardCmd::SetDevice, + nlas: std::mem::replace(&mut self.nlas, vec![name]), + })); + } + } + + /// Push a device attribute which will be optimally packed into 1 or more netlink messages + pub fn push(&mut self, nla: WgDeviceAttrs) { + let nla_buffer_len = nla.buffer_len(); + if (self.current_buffer_len + nla_buffer_len) > MAX_GENL_PAYLOAD_LENGTH { + self.flush_nlas(); + } + self.nlas.push(nla); + self.current_buffer_len += nla_buffer_len; + } + + /// A helper function to assist in breaking up large peer lists across multiple netlink messages + pub fn push_peer(&mut self, peer: Vec) { + const EMPTY_PEERS: WgDeviceAttrs = WgDeviceAttrs::Peers(vec![]); + let mut needs_peer_nla = !self.nlas.iter().any(|nla| matches!(nla, WgDeviceAttrs::Peers(_))); + let peer_buffer_len = peer.as_slice().buffer_len() + 4; + let additional_buffer_len = peer_buffer_len + if needs_peer_nla { EMPTY_PEERS.buffer_len() } else { 0 }; + if (self.current_buffer_len + additional_buffer_len) > MAX_GENL_PAYLOAD_LENGTH { + self.flush_nlas(); + needs_peer_nla = true; + } + + if needs_peer_nla { + self.push(EMPTY_PEERS); + } + let peers_nla = self.nlas.iter_mut().find_map(|nla| { + match nla { + WgDeviceAttrs::Peers(peers) => Some(peers), + _ => None, + } + }).expect("WgDeviceAttrs::Peers missing from NLAs when it should exist."); + + peers_nla.push(peer); + self.current_buffer_len += peer_buffer_len; + } + + pub fn finish(mut self) -> Vec> { + self.flush_nlas(); + self.messages + } +} + pub fn get_by_name(name: &InterfaceName) -> Result { let genlmsg: GenlMessage = GenlMessage::from_payload(Wireguard { cmd: WireguardCmd::GetDevice, @@ -265,3 +337,57 @@ pub fn get_by_name(name: &InterfaceName) -> Result { pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> { add_del(iface, false) } + +#[cfg(test)] +mod tests { + use super::*; + use netlink_request::MAX_NETLINK_BUFFER_LENGTH; + use std::str::FromStr; + + #[test] + fn test_simple_payload() { + let mut payload = ApplyPayload::new(&InterfaceName::from_str("wg0").unwrap()); + payload.push(WgDeviceAttrs::PrivateKey([1u8; 32])); + payload.push(WgDeviceAttrs::Fwmark(111)); + payload.push(WgDeviceAttrs::ListenPort(12345)); + payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)); + payload.push_peer(vec![ + WgPeerAttrs::PublicKey([2u8; 32]), + WgPeerAttrs::PersistentKeepalive(25), + WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()), + WgPeerAttrs::AllowedIps(vec![vec![ + WgAllowedIpAttrs::Family(AF_INET), + WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()), + WgAllowedIpAttrs::Cidr(24) + ]]), + ]); + assert_eq!(payload.finish().len(), 1); + } + + #[test] + fn test_massive_payload() { + let mut payload = ApplyPayload::new(&InterfaceName::from_str("wg0").unwrap()); + payload.push(WgDeviceAttrs::PrivateKey([1u8; 32])); + payload.push(WgDeviceAttrs::Fwmark(111)); + payload.push(WgDeviceAttrs::ListenPort(12345)); + payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)); + for _ in 0..10_000 { + payload.push_peer(vec![ + WgPeerAttrs::PublicKey([2u8; 32]), + WgPeerAttrs::PersistentKeepalive(25), + WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()), + WgPeerAttrs::AllowedIps(vec![vec![ + WgAllowedIpAttrs::Family(AF_INET), + WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()), + WgAllowedIpAttrs::Cidr(24) + ]]), + ]); + } + + let messages = payload.finish(); + assert!(messages.len() > 1); + for message in messages { + assert!(message.buffer_len() < MAX_NETLINK_BUFFER_LENGTH); + } + } +} \ No newline at end of file