wireguard-control: break up large updates into multiple netlink messages

pull/186/head
Jake McGinty 2022-01-09 22:56:39 -06:00
parent 4784a695ad
commit 92b60f535d
2 changed files with 150 additions and 24 deletions

View File

@ -1,7 +1,7 @@
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
mod linux { mod linux {
pub const MAX_NETLINK_BUFFER_LENGTH: usize = 4096;
const NETLINK_BUFFER_LENGTH: usize = 4096; pub const MAX_GENL_PAYLOAD_LENGTH: usize = MAX_NETLINK_BUFFER_LENGTH - GENL_HDRLEN;
use netlink_packet_core::{ use netlink_packet_core::{
NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable, NLM_F_ACK, NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable, NLM_F_ACK,
@ -9,7 +9,7 @@ mod linux {
}; };
use netlink_packet_generic::{ use netlink_packet_generic::{
ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd}, ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd},
GenlFamily, GenlMessage, GenlFamily, GenlMessage, constants::GENL_HDRLEN,
}; };
use netlink_packet_route::RtnlMessage; use netlink_packet_route::RtnlMessage;
use netlink_sys::{constants::NETLINK_GENERIC, protocols::NETLINK_ROUTE, Socket}; use netlink_sys::{constants::NETLINK_GENERIC, protocols::NETLINK_ROUTE, Socket};
@ -82,16 +82,16 @@ mod linux {
{ {
let mut req = NetlinkMessage::from(message); let mut req = NetlinkMessage::from(message);
if req.buffer_len() > NETLINK_BUFFER_LENGTH { if req.buffer_len() > MAX_NETLINK_BUFFER_LENGTH {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
format!("Serialized netlink packet larger than maximum size {}", NETLINK_BUFFER_LENGTH), format!("Serialized netlink packet larger than maximum size {}", MAX_NETLINK_BUFFER_LENGTH),
)); ));
} }
req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE); req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE);
req.finalize(); req.finalize();
let mut buf = [0; NETLINK_BUFFER_LENGTH]; let mut buf = [0; MAX_NETLINK_BUFFER_LENGTH];
req.serialize(&mut buf); req.serialize(&mut buf);
let len = req.buffer_len(); let len = req.buffer_len();
@ -133,4 +133,4 @@ mod linux {
} }
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
pub use linux::{netlink_request, netlink_request_genl, netlink_request_rtnl}; pub use linux::{netlink_request, netlink_request_genl, netlink_request_rtnl, MAX_NETLINK_BUFFER_LENGTH, MAX_GENL_PAYLOAD_LENGTH};

View File

@ -12,7 +12,7 @@ use netlink_packet_route::{
self, self,
nlas::{Info, InfoKind}, nlas::{Info, InfoKind},
}, },
LinkMessage, RtnlMessage, LinkMessage, RtnlMessage, traits::Emitable,
}; };
use netlink_packet_wireguard::{ use netlink_packet_wireguard::{
self, self,
@ -20,7 +20,7 @@ use netlink_packet_wireguard::{
nlas::{WgAllowedIpAttrs, WgDeviceAttrs, WgPeerAttrs}, nlas::{WgAllowedIpAttrs, WgDeviceAttrs, WgPeerAttrs},
Wireguard, WireguardCmd, Wireguard, WireguardCmd,
}; };
use netlink_request::{netlink_request_genl, netlink_request_rtnl}; use netlink_request::{netlink_request_genl, netlink_request_rtnl, MAX_GENL_PAYLOAD_LENGTH};
use std::{convert::TryFrom, io}; use std::{convert::TryFrom, io};
@ -216,33 +216,105 @@ fn add_del(iface: &InterfaceName, add: bool) -> io::Result<()> {
pub fn apply(builder: &DeviceUpdate, iface: &InterfaceName) -> io::Result<()> { pub fn apply(builder: &DeviceUpdate, iface: &InterfaceName) -> io::Result<()> {
add_del(iface, true)?; add_del(iface, true)?;
let mut nlas = vec![WgDeviceAttrs::IfName(iface.as_str_lossy().to_string())]; let mut payload = ApplyPayload::new(iface);
if let Some(Key(k)) = builder.private_key { if let Some(Key(k)) = builder.private_key {
nlas.push(WgDeviceAttrs::PrivateKey(k)); payload.push(WgDeviceAttrs::PrivateKey(k));
} }
if let Some(f) = builder.fwmark { if let Some(f) = builder.fwmark {
nlas.push(WgDeviceAttrs::Fwmark(f)); payload.push(WgDeviceAttrs::Fwmark(f));
} }
if let Some(f) = builder.listen_port { if let Some(f) = builder.listen_port {
nlas.push(WgDeviceAttrs::ListenPort(f)); payload.push(WgDeviceAttrs::ListenPort(f));
} }
if builder.replace_peers { if builder.replace_peers {
nlas.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)); payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS));
} }
let peers: Vec<Vec<_>> = builder
.peers builder.peers
.iter() .iter()
.map(PeerConfigBuilder::to_attrs) .for_each(|peer| {
.collect(); payload.push_peer(peer.to_attrs())
nlas.push(WgDeviceAttrs::Peers(peers));
let genlmsg: GenlMessage<Wireguard> = GenlMessage::from_payload(Wireguard {
cmd: WireguardCmd::SetDevice,
nlas,
}); });
netlink_request_genl(genlmsg, Some(NLM_F_REQUEST | NLM_F_ACK))?;
for message in payload.finish() {
netlink_request_genl(message, Some(NLM_F_REQUEST | NLM_F_ACK))?;
}
Ok(()) Ok(())
} }
struct ApplyPayload {
iface: String,
nlas: Vec<WgDeviceAttrs>,
current_buffer_len: usize,
messages: Vec<GenlMessage<Wireguard>>,
}
impl ApplyPayload {
fn new(iface: &InterfaceName) -> Self {
Self {
iface: iface.as_str_lossy().to_string(),
nlas: vec![],
messages: vec![],
current_buffer_len: 0,
}
}
fn flush_nlas(&mut self) {
// cleanup: clear out any empty peer lists.
self.nlas.retain(|nla| !matches!(nla, WgDeviceAttrs::Peers(peers) if peers.len() == 0));
let name = WgDeviceAttrs::IfName(self.iface.clone());
self.current_buffer_len = name.buffer_len();
if !self.nlas.is_empty() {
self.messages.push(GenlMessage::from_payload(Wireguard {
cmd: WireguardCmd::SetDevice,
nlas: std::mem::replace(&mut self.nlas, vec![name]),
}));
}
}
/// Push a device attribute which will be optimally packed into 1 or more netlink messages
pub fn push(&mut self, nla: WgDeviceAttrs) {
let nla_buffer_len = nla.buffer_len();
if (self.current_buffer_len + nla_buffer_len) > MAX_GENL_PAYLOAD_LENGTH {
self.flush_nlas();
}
self.nlas.push(nla);
self.current_buffer_len += nla_buffer_len;
}
/// A helper function to assist in breaking up large peer lists across multiple netlink messages
pub fn push_peer(&mut self, peer: Vec<WgPeerAttrs>) {
const EMPTY_PEERS: WgDeviceAttrs = WgDeviceAttrs::Peers(vec![]);
let mut needs_peer_nla = !self.nlas.iter().any(|nla| matches!(nla, WgDeviceAttrs::Peers(_)));
let peer_buffer_len = peer.as_slice().buffer_len() + 4;
let additional_buffer_len = peer_buffer_len + if needs_peer_nla { EMPTY_PEERS.buffer_len() } else { 0 };
if (self.current_buffer_len + additional_buffer_len) > MAX_GENL_PAYLOAD_LENGTH {
self.flush_nlas();
needs_peer_nla = true;
}
if needs_peer_nla {
self.push(EMPTY_PEERS);
}
let peers_nla = self.nlas.iter_mut().find_map(|nla| {
match nla {
WgDeviceAttrs::Peers(peers) => Some(peers),
_ => None,
}
}).expect("WgDeviceAttrs::Peers missing from NLAs when it should exist.");
peers_nla.push(peer);
self.current_buffer_len += peer_buffer_len;
}
pub fn finish(mut self) -> Vec<GenlMessage<Wireguard>> {
self.flush_nlas();
self.messages
}
}
pub fn get_by_name(name: &InterfaceName) -> Result<Device, io::Error> { pub fn get_by_name(name: &InterfaceName) -> Result<Device, io::Error> {
let genlmsg: GenlMessage<Wireguard> = GenlMessage::from_payload(Wireguard { let genlmsg: GenlMessage<Wireguard> = GenlMessage::from_payload(Wireguard {
cmd: WireguardCmd::GetDevice, cmd: WireguardCmd::GetDevice,
@ -265,3 +337,57 @@ pub fn get_by_name(name: &InterfaceName) -> Result<Device, io::Error> {
pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> { pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> {
add_del(iface, false) add_del(iface, false)
} }
#[cfg(test)]
mod tests {
use super::*;
use netlink_request::MAX_NETLINK_BUFFER_LENGTH;
use std::str::FromStr;
#[test]
fn test_simple_payload() {
let mut payload = ApplyPayload::new(&InterfaceName::from_str("wg0").unwrap());
payload.push(WgDeviceAttrs::PrivateKey([1u8; 32]));
payload.push(WgDeviceAttrs::Fwmark(111));
payload.push(WgDeviceAttrs::ListenPort(12345));
payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS));
payload.push_peer(vec![
WgPeerAttrs::PublicKey([2u8; 32]),
WgPeerAttrs::PersistentKeepalive(25),
WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()),
WgPeerAttrs::AllowedIps(vec![vec![
WgAllowedIpAttrs::Family(AF_INET),
WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()),
WgAllowedIpAttrs::Cidr(24)
]]),
]);
assert_eq!(payload.finish().len(), 1);
}
#[test]
fn test_massive_payload() {
let mut payload = ApplyPayload::new(&InterfaceName::from_str("wg0").unwrap());
payload.push(WgDeviceAttrs::PrivateKey([1u8; 32]));
payload.push(WgDeviceAttrs::Fwmark(111));
payload.push(WgDeviceAttrs::ListenPort(12345));
payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS));
for _ in 0..10_000 {
payload.push_peer(vec![
WgPeerAttrs::PublicKey([2u8; 32]),
WgPeerAttrs::PersistentKeepalive(25),
WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()),
WgPeerAttrs::AllowedIps(vec![vec![
WgAllowedIpAttrs::Family(AF_INET),
WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()),
WgAllowedIpAttrs::Cidr(24)
]]),
]);
}
let messages = payload.finish();
assert!(messages.len() > 1);
for message in messages {
assert!(message.buffer_len() < MAX_NETLINK_BUFFER_LENGTH);
}
}
}