diff --git a/Cargo.lock b/Cargo.lock index a2b47d5..7c72e44 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -636,6 +636,19 @@ dependencies = [ "netlink-packet-utils", ] +[[package]] +name = "netlink-packet-wireguard" +version = "0.1.1" +source = "git+https://github.com/little-dude/netlink?rev=b2bdd6295209c84ef95f85f66c03b55234d77ad6#b2bdd6295209c84ef95f85f66c03b55234d77ad6" +dependencies = [ + "anyhow", + "byteorder", + "libc", + "log", + "netlink-packet-generic", + "netlink-packet-utils", +] + [[package]] name = "netlink-request" version = "1.5.3" @@ -643,7 +656,7 @@ dependencies = [ "netlink-packet-core", "netlink-packet-generic", "netlink-packet-route", - "netlink-packet-wireguard", + "netlink-packet-wireguard 0.1.1 (git+https://github.com/mcginty/netlink?branch=wireguard-fixes)", "netlink-sys", ] @@ -1343,7 +1356,7 @@ dependencies = [ "netlink-packet-core", "netlink-packet-generic", "netlink-packet-route", - "netlink-packet-wireguard", + "netlink-packet-wireguard 0.1.1 (git+https://github.com/little-dude/netlink?rev=b2bdd6295209c84ef95f85f66c03b55234d77ad6)", "netlink-request", "netlink-sys", "rand_core", diff --git a/wireguard-control/Cargo.toml b/wireguard-control/Cargo.toml index 753abe3..460e75c 100644 --- a/wireguard-control/Cargo.toml +++ b/wireguard-control/Cargo.toml @@ -23,4 +23,4 @@ netlink-sys = "0.8" netlink-packet-core = "0.4" netlink-packet-generic = "0.3" netlink-packet-route = "0.10" -netlink-packet-wireguard = { git = "https://github.com/mcginty/netlink", branch = "wireguard-fixes" } +netlink-packet-wireguard = { git = "https://github.com/little-dude/netlink", rev = "b2bdd6295209c84ef95f85f66c03b55234d77ad6" } diff --git a/wireguard-control/src/backends/kernel.rs b/wireguard-control/src/backends/kernel.rs index 042be91..2e22320 100644 --- a/wireguard-control/src/backends/kernel.rs +++ b/wireguard-control/src/backends/kernel.rs @@ -18,7 +18,7 @@ use netlink_packet_route::{ use netlink_packet_wireguard::{ self, constants::{WGDEVICE_F_REPLACE_PEERS, WGPEER_F_REMOVE_ME, WGPEER_F_REPLACE_ALLOWEDIPS}, - nlas::{WgAllowedIpAttrs, WgDeviceAttrs, WgPeerAttrs}, + nlas::{WgAllowedIp, WgAllowedIpAttrs, WgDeviceAttrs, WgPeer, WgPeerAttrs}, Wireguard, WireguardCmd, }; use netlink_request::{netlink_request_genl, netlink_request_rtnl, MAX_GENL_PAYLOAD_LENGTH}; @@ -34,10 +34,10 @@ macro_rules! get_nla_value { }; } -impl<'a> TryFrom> for AllowedIp { +impl<'a> TryFrom for AllowedIp { type Error = io::Error; - fn try_from(attrs: Vec) -> Result { + fn try_from(attrs: WgAllowedIp) -> Result { let address = *get_nla_value!(attrs, WgAllowedIpAttrs, IpAddr) .ok_or_else(|| io::ErrorKind::NotFound)?; let cidr = *get_nla_value!(attrs, WgAllowedIpAttrs, Cidr) @@ -47,8 +47,8 @@ impl<'a> TryFrom> for AllowedIp { } impl AllowedIp { - fn to_attrs(&self) -> Vec { - vec![ + fn to_nla(&self) -> WgAllowedIp { + WgAllowedIp(vec![ WgAllowedIpAttrs::Family(if self.address.is_ipv4() { AF_INET } else { @@ -56,12 +56,12 @@ impl AllowedIp { }), WgAllowedIpAttrs::IpAddr(self.address), WgAllowedIpAttrs::Cidr(self.cidr), - ] + ]) } } impl PeerConfigBuilder { - fn to_attrs(&self) -> Vec { + fn to_nla(&self) -> WgPeer { let mut attrs = vec![WgPeerAttrs::PublicKey(self.public_key.0)]; let mut flags = 0u32; if let Some(endpoint) = self.endpoint { @@ -73,7 +73,7 @@ impl PeerConfigBuilder { if let Some(i) = self.persistent_keepalive_interval { attrs.push(WgPeerAttrs::PersistentKeepalive(i)); } - let allowed_ips: Vec<_> = self.allowed_ips.iter().map(AllowedIp::to_attrs).collect(); + let allowed_ips: Vec<_> = self.allowed_ips.iter().map(AllowedIp::to_nla).collect(); attrs.push(WgPeerAttrs::AllowedIps(allowed_ips)); if self.remove_me { flags |= WGPEER_F_REMOVE_ME; @@ -84,14 +84,14 @@ impl PeerConfigBuilder { if flags != 0 { attrs.push(WgPeerAttrs::Flags(flags)); } - attrs + WgPeer(attrs) } } -impl<'a> TryFrom> for PeerInfo { +impl<'a> TryFrom for PeerInfo { type Error = io::Error; - fn try_from(attrs: Vec) -> Result { + fn try_from(attrs: WgPeer) -> Result { let public_key = get_nla_value!(attrs, WgPeerAttrs, PublicKey) .map(|key| Key(*key)) .ok_or(io::ErrorKind::NotFound)?; @@ -236,7 +236,7 @@ pub fn apply(builder: &DeviceUpdate, iface: &InterfaceName) -> io::Result<()> { builder .peers .iter() - .map(|peer| payload.push_peer(peer.to_attrs())) + .map(|peer| payload.push_peer(peer.to_nla())) .collect::, _>>()?; for message in payload.finish() { @@ -306,13 +306,13 @@ impl ApplyPayload { } /// A helper function to assist in breaking up large peer lists across multiple netlink messages - pub fn push_peer(&mut self, peer: Vec) -> io::Result<()> { + pub fn push_peer(&mut self, peer: WgPeer) -> io::Result<()> { const EMPTY_PEERS: WgDeviceAttrs = WgDeviceAttrs::Peers(vec![]); let mut needs_peer_nla = !self .nlas .iter() .any(|nla| matches!(nla, WgDeviceAttrs::Peers(_))); - let peer_buffer_len = peer.as_slice().buffer_len() + 4; + let peer_buffer_len = peer.buffer_len(); let mut additional_buffer_len = peer_buffer_len; if needs_peer_nla { additional_buffer_len += EMPTY_PEERS.buffer_len(); @@ -402,6 +402,7 @@ pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> { #[cfg(test)] mod tests { use super::*; + use netlink_packet_wireguard::nlas::WgAllowedIp; use netlink_request::MAX_NETLINK_BUFFER_LENGTH; use std::str::FromStr; @@ -415,17 +416,17 @@ mod tests { .push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)) .unwrap(); payload - .push_peer(vec![ + .push_peer(WgPeer(vec![ WgPeerAttrs::PublicKey([2u8; 32]), WgPeerAttrs::PersistentKeepalive(25), WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()), WgPeerAttrs::Flags(WGPEER_F_REPLACE_ALLOWEDIPS), - WgPeerAttrs::AllowedIps(vec![vec![ + WgPeerAttrs::AllowedIps(vec![WgAllowedIp(vec![ WgAllowedIpAttrs::Family(AF_INET), WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()), WgAllowedIpAttrs::Cidr(24), - ]]), - ]) + ])]), + ])) .unwrap(); assert_eq!(payload.finish().len(), 1); } @@ -442,18 +443,18 @@ mod tests { for i in 0..10_000 { payload - .push_peer(vec![ + .push_peer(WgPeer(vec![ WgPeerAttrs::PublicKey([2u8; 32]), WgPeerAttrs::PersistentKeepalive(25), WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()), WgPeerAttrs::Flags(WGPEER_F_REPLACE_ALLOWEDIPS), - WgPeerAttrs::AllowedIps(vec![vec![ + WgPeerAttrs::AllowedIps(vec![WgAllowedIp(vec![ WgAllowedIpAttrs::Family(AF_INET), WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()), WgAllowedIpAttrs::Cidr(24), - ]]), + ])]), WgPeerAttrs::Unspec(vec![1u8; (i % 256) as usize]), - ]) + ])) .unwrap(); }