wireguard-control: break up large updates into multiple netlink messages
parent
4784a695ad
commit
92b60f535d
|
@ -1,7 +1,7 @@
|
||||||
#[cfg(target_os = "linux")]
|
#[cfg(target_os = "linux")]
|
||||||
mod linux {
|
mod linux {
|
||||||
|
pub const MAX_NETLINK_BUFFER_LENGTH: usize = 4096;
|
||||||
const NETLINK_BUFFER_LENGTH: usize = 4096;
|
pub const MAX_GENL_PAYLOAD_LENGTH: usize = MAX_NETLINK_BUFFER_LENGTH - GENL_HDRLEN;
|
||||||
|
|
||||||
use netlink_packet_core::{
|
use netlink_packet_core::{
|
||||||
NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable, NLM_F_ACK,
|
NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable, NLM_F_ACK,
|
||||||
|
@ -9,7 +9,7 @@ mod linux {
|
||||||
};
|
};
|
||||||
use netlink_packet_generic::{
|
use netlink_packet_generic::{
|
||||||
ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd},
|
ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd},
|
||||||
GenlFamily, GenlMessage,
|
GenlFamily, GenlMessage, constants::GENL_HDRLEN,
|
||||||
};
|
};
|
||||||
use netlink_packet_route::RtnlMessage;
|
use netlink_packet_route::RtnlMessage;
|
||||||
use netlink_sys::{constants::NETLINK_GENERIC, protocols::NETLINK_ROUTE, Socket};
|
use netlink_sys::{constants::NETLINK_GENERIC, protocols::NETLINK_ROUTE, Socket};
|
||||||
|
@ -82,16 +82,16 @@ mod linux {
|
||||||
{
|
{
|
||||||
let mut req = NetlinkMessage::from(message);
|
let mut req = NetlinkMessage::from(message);
|
||||||
|
|
||||||
if req.buffer_len() > NETLINK_BUFFER_LENGTH {
|
if req.buffer_len() > MAX_NETLINK_BUFFER_LENGTH {
|
||||||
return Err(io::Error::new(
|
return Err(io::Error::new(
|
||||||
io::ErrorKind::InvalidInput,
|
io::ErrorKind::InvalidInput,
|
||||||
format!("Serialized netlink packet larger than maximum size {}", NETLINK_BUFFER_LENGTH),
|
format!("Serialized netlink packet larger than maximum size {}", MAX_NETLINK_BUFFER_LENGTH),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE);
|
req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE);
|
||||||
req.finalize();
|
req.finalize();
|
||||||
let mut buf = [0; NETLINK_BUFFER_LENGTH];
|
let mut buf = [0; MAX_NETLINK_BUFFER_LENGTH];
|
||||||
req.serialize(&mut buf);
|
req.serialize(&mut buf);
|
||||||
let len = req.buffer_len();
|
let len = req.buffer_len();
|
||||||
|
|
||||||
|
@ -133,4 +133,4 @@ mod linux {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(target_os = "linux")]
|
#[cfg(target_os = "linux")]
|
||||||
pub use linux::{netlink_request, netlink_request_genl, netlink_request_rtnl};
|
pub use linux::{netlink_request, netlink_request_genl, netlink_request_rtnl, MAX_NETLINK_BUFFER_LENGTH, MAX_GENL_PAYLOAD_LENGTH};
|
||||||
|
|
|
@ -12,7 +12,7 @@ use netlink_packet_route::{
|
||||||
self,
|
self,
|
||||||
nlas::{Info, InfoKind},
|
nlas::{Info, InfoKind},
|
||||||
},
|
},
|
||||||
LinkMessage, RtnlMessage,
|
LinkMessage, RtnlMessage, traits::Emitable,
|
||||||
};
|
};
|
||||||
use netlink_packet_wireguard::{
|
use netlink_packet_wireguard::{
|
||||||
self,
|
self,
|
||||||
|
@ -20,7 +20,7 @@ use netlink_packet_wireguard::{
|
||||||
nlas::{WgAllowedIpAttrs, WgDeviceAttrs, WgPeerAttrs},
|
nlas::{WgAllowedIpAttrs, WgDeviceAttrs, WgPeerAttrs},
|
||||||
Wireguard, WireguardCmd,
|
Wireguard, WireguardCmd,
|
||||||
};
|
};
|
||||||
use netlink_request::{netlink_request_genl, netlink_request_rtnl};
|
use netlink_request::{netlink_request_genl, netlink_request_rtnl, MAX_GENL_PAYLOAD_LENGTH};
|
||||||
|
|
||||||
use std::{convert::TryFrom, io};
|
use std::{convert::TryFrom, io};
|
||||||
|
|
||||||
|
@ -216,33 +216,105 @@ fn add_del(iface: &InterfaceName, add: bool) -> io::Result<()> {
|
||||||
|
|
||||||
pub fn apply(builder: &DeviceUpdate, iface: &InterfaceName) -> io::Result<()> {
|
pub fn apply(builder: &DeviceUpdate, iface: &InterfaceName) -> io::Result<()> {
|
||||||
add_del(iface, true)?;
|
add_del(iface, true)?;
|
||||||
let mut nlas = vec![WgDeviceAttrs::IfName(iface.as_str_lossy().to_string())];
|
let mut payload = ApplyPayload::new(iface);
|
||||||
if let Some(Key(k)) = builder.private_key {
|
if let Some(Key(k)) = builder.private_key {
|
||||||
nlas.push(WgDeviceAttrs::PrivateKey(k));
|
payload.push(WgDeviceAttrs::PrivateKey(k));
|
||||||
}
|
}
|
||||||
if let Some(f) = builder.fwmark {
|
if let Some(f) = builder.fwmark {
|
||||||
nlas.push(WgDeviceAttrs::Fwmark(f));
|
payload.push(WgDeviceAttrs::Fwmark(f));
|
||||||
}
|
}
|
||||||
if let Some(f) = builder.listen_port {
|
if let Some(f) = builder.listen_port {
|
||||||
nlas.push(WgDeviceAttrs::ListenPort(f));
|
payload.push(WgDeviceAttrs::ListenPort(f));
|
||||||
}
|
}
|
||||||
if builder.replace_peers {
|
if builder.replace_peers {
|
||||||
nlas.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS));
|
payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS));
|
||||||
}
|
}
|
||||||
let peers: Vec<Vec<_>> = builder
|
|
||||||
.peers
|
builder.peers
|
||||||
.iter()
|
.iter()
|
||||||
.map(PeerConfigBuilder::to_attrs)
|
.for_each(|peer| {
|
||||||
.collect();
|
payload.push_peer(peer.to_attrs())
|
||||||
nlas.push(WgDeviceAttrs::Peers(peers));
|
});
|
||||||
let genlmsg: GenlMessage<Wireguard> = GenlMessage::from_payload(Wireguard {
|
|
||||||
cmd: WireguardCmd::SetDevice,
|
for message in payload.finish() {
|
||||||
nlas,
|
netlink_request_genl(message, Some(NLM_F_REQUEST | NLM_F_ACK))?;
|
||||||
});
|
}
|
||||||
netlink_request_genl(genlmsg, Some(NLM_F_REQUEST | NLM_F_ACK))?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ApplyPayload {
|
||||||
|
iface: String,
|
||||||
|
nlas: Vec<WgDeviceAttrs>,
|
||||||
|
current_buffer_len: usize,
|
||||||
|
messages: Vec<GenlMessage<Wireguard>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ApplyPayload {
|
||||||
|
fn new(iface: &InterfaceName) -> Self {
|
||||||
|
Self {
|
||||||
|
iface: iface.as_str_lossy().to_string(),
|
||||||
|
nlas: vec![],
|
||||||
|
messages: vec![],
|
||||||
|
current_buffer_len: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flush_nlas(&mut self) {
|
||||||
|
// cleanup: clear out any empty peer lists.
|
||||||
|
self.nlas.retain(|nla| !matches!(nla, WgDeviceAttrs::Peers(peers) if peers.len() == 0));
|
||||||
|
|
||||||
|
let name = WgDeviceAttrs::IfName(self.iface.clone());
|
||||||
|
self.current_buffer_len = name.buffer_len();
|
||||||
|
|
||||||
|
if !self.nlas.is_empty() {
|
||||||
|
self.messages.push(GenlMessage::from_payload(Wireguard {
|
||||||
|
cmd: WireguardCmd::SetDevice,
|
||||||
|
nlas: std::mem::replace(&mut self.nlas, vec![name]),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Push a device attribute which will be optimally packed into 1 or more netlink messages
|
||||||
|
pub fn push(&mut self, nla: WgDeviceAttrs) {
|
||||||
|
let nla_buffer_len = nla.buffer_len();
|
||||||
|
if (self.current_buffer_len + nla_buffer_len) > MAX_GENL_PAYLOAD_LENGTH {
|
||||||
|
self.flush_nlas();
|
||||||
|
}
|
||||||
|
self.nlas.push(nla);
|
||||||
|
self.current_buffer_len += nla_buffer_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A helper function to assist in breaking up large peer lists across multiple netlink messages
|
||||||
|
pub fn push_peer(&mut self, peer: Vec<WgPeerAttrs>) {
|
||||||
|
const EMPTY_PEERS: WgDeviceAttrs = WgDeviceAttrs::Peers(vec![]);
|
||||||
|
let mut needs_peer_nla = !self.nlas.iter().any(|nla| matches!(nla, WgDeviceAttrs::Peers(_)));
|
||||||
|
let peer_buffer_len = peer.as_slice().buffer_len() + 4;
|
||||||
|
let additional_buffer_len = peer_buffer_len + if needs_peer_nla { EMPTY_PEERS.buffer_len() } else { 0 };
|
||||||
|
if (self.current_buffer_len + additional_buffer_len) > MAX_GENL_PAYLOAD_LENGTH {
|
||||||
|
self.flush_nlas();
|
||||||
|
needs_peer_nla = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if needs_peer_nla {
|
||||||
|
self.push(EMPTY_PEERS);
|
||||||
|
}
|
||||||
|
let peers_nla = self.nlas.iter_mut().find_map(|nla| {
|
||||||
|
match nla {
|
||||||
|
WgDeviceAttrs::Peers(peers) => Some(peers),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}).expect("WgDeviceAttrs::Peers missing from NLAs when it should exist.");
|
||||||
|
|
||||||
|
peers_nla.push(peer);
|
||||||
|
self.current_buffer_len += peer_buffer_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn finish(mut self) -> Vec<GenlMessage<Wireguard>> {
|
||||||
|
self.flush_nlas();
|
||||||
|
self.messages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get_by_name(name: &InterfaceName) -> Result<Device, io::Error> {
|
pub fn get_by_name(name: &InterfaceName) -> Result<Device, io::Error> {
|
||||||
let genlmsg: GenlMessage<Wireguard> = GenlMessage::from_payload(Wireguard {
|
let genlmsg: GenlMessage<Wireguard> = GenlMessage::from_payload(Wireguard {
|
||||||
cmd: WireguardCmd::GetDevice,
|
cmd: WireguardCmd::GetDevice,
|
||||||
|
@ -265,3 +337,57 @@ pub fn get_by_name(name: &InterfaceName) -> Result<Device, io::Error> {
|
||||||
pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> {
|
pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> {
|
||||||
add_del(iface, false)
|
add_del(iface, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use netlink_request::MAX_NETLINK_BUFFER_LENGTH;
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_simple_payload() {
|
||||||
|
let mut payload = ApplyPayload::new(&InterfaceName::from_str("wg0").unwrap());
|
||||||
|
payload.push(WgDeviceAttrs::PrivateKey([1u8; 32]));
|
||||||
|
payload.push(WgDeviceAttrs::Fwmark(111));
|
||||||
|
payload.push(WgDeviceAttrs::ListenPort(12345));
|
||||||
|
payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS));
|
||||||
|
payload.push_peer(vec![
|
||||||
|
WgPeerAttrs::PublicKey([2u8; 32]),
|
||||||
|
WgPeerAttrs::PersistentKeepalive(25),
|
||||||
|
WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()),
|
||||||
|
WgPeerAttrs::AllowedIps(vec![vec![
|
||||||
|
WgAllowedIpAttrs::Family(AF_INET),
|
||||||
|
WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()),
|
||||||
|
WgAllowedIpAttrs::Cidr(24)
|
||||||
|
]]),
|
||||||
|
]);
|
||||||
|
assert_eq!(payload.finish().len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_massive_payload() {
|
||||||
|
let mut payload = ApplyPayload::new(&InterfaceName::from_str("wg0").unwrap());
|
||||||
|
payload.push(WgDeviceAttrs::PrivateKey([1u8; 32]));
|
||||||
|
payload.push(WgDeviceAttrs::Fwmark(111));
|
||||||
|
payload.push(WgDeviceAttrs::ListenPort(12345));
|
||||||
|
payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS));
|
||||||
|
for _ in 0..10_000 {
|
||||||
|
payload.push_peer(vec![
|
||||||
|
WgPeerAttrs::PublicKey([2u8; 32]),
|
||||||
|
WgPeerAttrs::PersistentKeepalive(25),
|
||||||
|
WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()),
|
||||||
|
WgPeerAttrs::AllowedIps(vec![vec![
|
||||||
|
WgAllowedIpAttrs::Family(AF_INET),
|
||||||
|
WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()),
|
||||||
|
WgAllowedIpAttrs::Cidr(24)
|
||||||
|
]]),
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let messages = payload.finish();
|
||||||
|
assert!(messages.len() > 1);
|
||||||
|
for message in messages {
|
||||||
|
assert!(message.buffer_len() < MAX_NETLINK_BUFFER_LENGTH);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue