wireguard-control: go back to using upstream netlink

pull/192/head
Jake McGinty 2022-02-01 05:40:36 +09:00
parent 061c6539e1
commit 2cb530762c
3 changed files with 39 additions and 25 deletions

17
Cargo.lock generated
View File

@ -636,6 +636,19 @@ dependencies = [
"netlink-packet-utils", "netlink-packet-utils",
] ]
[[package]]
name = "netlink-packet-wireguard"
version = "0.1.1"
source = "git+https://github.com/little-dude/netlink?rev=b2bdd6295209c84ef95f85f66c03b55234d77ad6#b2bdd6295209c84ef95f85f66c03b55234d77ad6"
dependencies = [
"anyhow",
"byteorder",
"libc",
"log",
"netlink-packet-generic",
"netlink-packet-utils",
]
[[package]] [[package]]
name = "netlink-request" name = "netlink-request"
version = "1.5.3" version = "1.5.3"
@ -643,7 +656,7 @@ dependencies = [
"netlink-packet-core", "netlink-packet-core",
"netlink-packet-generic", "netlink-packet-generic",
"netlink-packet-route", "netlink-packet-route",
"netlink-packet-wireguard", "netlink-packet-wireguard 0.1.1 (git+https://github.com/mcginty/netlink?branch=wireguard-fixes)",
"netlink-sys", "netlink-sys",
] ]
@ -1343,7 +1356,7 @@ dependencies = [
"netlink-packet-core", "netlink-packet-core",
"netlink-packet-generic", "netlink-packet-generic",
"netlink-packet-route", "netlink-packet-route",
"netlink-packet-wireguard", "netlink-packet-wireguard 0.1.1 (git+https://github.com/little-dude/netlink?rev=b2bdd6295209c84ef95f85f66c03b55234d77ad6)",
"netlink-request", "netlink-request",
"netlink-sys", "netlink-sys",
"rand_core", "rand_core",

View File

@ -23,4 +23,4 @@ netlink-sys = "0.8"
netlink-packet-core = "0.4" netlink-packet-core = "0.4"
netlink-packet-generic = "0.3" netlink-packet-generic = "0.3"
netlink-packet-route = "0.10" netlink-packet-route = "0.10"
netlink-packet-wireguard = { git = "https://github.com/mcginty/netlink", branch = "wireguard-fixes" } netlink-packet-wireguard = { git = "https://github.com/little-dude/netlink", rev = "b2bdd6295209c84ef95f85f66c03b55234d77ad6" }

View File

@ -18,7 +18,7 @@ use netlink_packet_route::{
use netlink_packet_wireguard::{ use netlink_packet_wireguard::{
self, self,
constants::{WGDEVICE_F_REPLACE_PEERS, WGPEER_F_REMOVE_ME, WGPEER_F_REPLACE_ALLOWEDIPS}, constants::{WGDEVICE_F_REPLACE_PEERS, WGPEER_F_REMOVE_ME, WGPEER_F_REPLACE_ALLOWEDIPS},
nlas::{WgAllowedIpAttrs, WgDeviceAttrs, WgPeerAttrs}, nlas::{WgAllowedIp, WgAllowedIpAttrs, WgDeviceAttrs, WgPeer, WgPeerAttrs},
Wireguard, WireguardCmd, Wireguard, WireguardCmd,
}; };
use netlink_request::{netlink_request_genl, netlink_request_rtnl, MAX_GENL_PAYLOAD_LENGTH}; use netlink_request::{netlink_request_genl, netlink_request_rtnl, MAX_GENL_PAYLOAD_LENGTH};
@ -34,10 +34,10 @@ macro_rules! get_nla_value {
}; };
} }
impl<'a> TryFrom<Vec<WgAllowedIpAttrs>> for AllowedIp { impl<'a> TryFrom<WgAllowedIp> for AllowedIp {
type Error = io::Error; type Error = io::Error;
fn try_from(attrs: Vec<WgAllowedIpAttrs>) -> Result<Self, Self::Error> { fn try_from(attrs: WgAllowedIp) -> Result<Self, Self::Error> {
let address = *get_nla_value!(attrs, WgAllowedIpAttrs, IpAddr) let address = *get_nla_value!(attrs, WgAllowedIpAttrs, IpAddr)
.ok_or_else(|| io::ErrorKind::NotFound)?; .ok_or_else(|| io::ErrorKind::NotFound)?;
let cidr = *get_nla_value!(attrs, WgAllowedIpAttrs, Cidr) let cidr = *get_nla_value!(attrs, WgAllowedIpAttrs, Cidr)
@ -47,8 +47,8 @@ impl<'a> TryFrom<Vec<WgAllowedIpAttrs>> for AllowedIp {
} }
impl AllowedIp { impl AllowedIp {
fn to_attrs(&self) -> Vec<WgAllowedIpAttrs> { fn to_nla(&self) -> WgAllowedIp {
vec![ WgAllowedIp(vec![
WgAllowedIpAttrs::Family(if self.address.is_ipv4() { WgAllowedIpAttrs::Family(if self.address.is_ipv4() {
AF_INET AF_INET
} else { } else {
@ -56,12 +56,12 @@ impl AllowedIp {
}), }),
WgAllowedIpAttrs::IpAddr(self.address), WgAllowedIpAttrs::IpAddr(self.address),
WgAllowedIpAttrs::Cidr(self.cidr), WgAllowedIpAttrs::Cidr(self.cidr),
] ])
} }
} }
impl PeerConfigBuilder { impl PeerConfigBuilder {
fn to_attrs(&self) -> Vec<WgPeerAttrs> { fn to_nla(&self) -> WgPeer {
let mut attrs = vec![WgPeerAttrs::PublicKey(self.public_key.0)]; let mut attrs = vec![WgPeerAttrs::PublicKey(self.public_key.0)];
let mut flags = 0u32; let mut flags = 0u32;
if let Some(endpoint) = self.endpoint { if let Some(endpoint) = self.endpoint {
@ -73,7 +73,7 @@ impl PeerConfigBuilder {
if let Some(i) = self.persistent_keepalive_interval { if let Some(i) = self.persistent_keepalive_interval {
attrs.push(WgPeerAttrs::PersistentKeepalive(i)); attrs.push(WgPeerAttrs::PersistentKeepalive(i));
} }
let allowed_ips: Vec<_> = self.allowed_ips.iter().map(AllowedIp::to_attrs).collect(); let allowed_ips: Vec<_> = self.allowed_ips.iter().map(AllowedIp::to_nla).collect();
attrs.push(WgPeerAttrs::AllowedIps(allowed_ips)); attrs.push(WgPeerAttrs::AllowedIps(allowed_ips));
if self.remove_me { if self.remove_me {
flags |= WGPEER_F_REMOVE_ME; flags |= WGPEER_F_REMOVE_ME;
@ -84,14 +84,14 @@ impl PeerConfigBuilder {
if flags != 0 { if flags != 0 {
attrs.push(WgPeerAttrs::Flags(flags)); attrs.push(WgPeerAttrs::Flags(flags));
} }
attrs WgPeer(attrs)
} }
} }
impl<'a> TryFrom<Vec<WgPeerAttrs>> for PeerInfo { impl<'a> TryFrom<WgPeer> for PeerInfo {
type Error = io::Error; type Error = io::Error;
fn try_from(attrs: Vec<WgPeerAttrs>) -> Result<Self, Self::Error> { fn try_from(attrs: WgPeer) -> Result<Self, Self::Error> {
let public_key = get_nla_value!(attrs, WgPeerAttrs, PublicKey) let public_key = get_nla_value!(attrs, WgPeerAttrs, PublicKey)
.map(|key| Key(*key)) .map(|key| Key(*key))
.ok_or(io::ErrorKind::NotFound)?; .ok_or(io::ErrorKind::NotFound)?;
@ -236,7 +236,7 @@ pub fn apply(builder: &DeviceUpdate, iface: &InterfaceName) -> io::Result<()> {
builder builder
.peers .peers
.iter() .iter()
.map(|peer| payload.push_peer(peer.to_attrs())) .map(|peer| payload.push_peer(peer.to_nla()))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
for message in payload.finish() { for message in payload.finish() {
@ -306,13 +306,13 @@ impl ApplyPayload {
} }
/// A helper function to assist in breaking up large peer lists across multiple netlink messages /// A helper function to assist in breaking up large peer lists across multiple netlink messages
pub fn push_peer(&mut self, peer: Vec<WgPeerAttrs>) -> io::Result<()> { pub fn push_peer(&mut self, peer: WgPeer) -> io::Result<()> {
const EMPTY_PEERS: WgDeviceAttrs = WgDeviceAttrs::Peers(vec![]); const EMPTY_PEERS: WgDeviceAttrs = WgDeviceAttrs::Peers(vec![]);
let mut needs_peer_nla = !self let mut needs_peer_nla = !self
.nlas .nlas
.iter() .iter()
.any(|nla| matches!(nla, WgDeviceAttrs::Peers(_))); .any(|nla| matches!(nla, WgDeviceAttrs::Peers(_)));
let peer_buffer_len = peer.as_slice().buffer_len() + 4; let peer_buffer_len = peer.buffer_len();
let mut additional_buffer_len = peer_buffer_len; let mut additional_buffer_len = peer_buffer_len;
if needs_peer_nla { if needs_peer_nla {
additional_buffer_len += EMPTY_PEERS.buffer_len(); additional_buffer_len += EMPTY_PEERS.buffer_len();
@ -402,6 +402,7 @@ pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use netlink_packet_wireguard::nlas::WgAllowedIp;
use netlink_request::MAX_NETLINK_BUFFER_LENGTH; use netlink_request::MAX_NETLINK_BUFFER_LENGTH;
use std::str::FromStr; use std::str::FromStr;
@ -415,17 +416,17 @@ mod tests {
.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)) .push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS))
.unwrap(); .unwrap();
payload payload
.push_peer(vec![ .push_peer(WgPeer(vec![
WgPeerAttrs::PublicKey([2u8; 32]), WgPeerAttrs::PublicKey([2u8; 32]),
WgPeerAttrs::PersistentKeepalive(25), WgPeerAttrs::PersistentKeepalive(25),
WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()), WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()),
WgPeerAttrs::Flags(WGPEER_F_REPLACE_ALLOWEDIPS), WgPeerAttrs::Flags(WGPEER_F_REPLACE_ALLOWEDIPS),
WgPeerAttrs::AllowedIps(vec![vec![ WgPeerAttrs::AllowedIps(vec![WgAllowedIp(vec![
WgAllowedIpAttrs::Family(AF_INET), WgAllowedIpAttrs::Family(AF_INET),
WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()), WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()),
WgAllowedIpAttrs::Cidr(24), WgAllowedIpAttrs::Cidr(24),
]]), ])]),
]) ]))
.unwrap(); .unwrap();
assert_eq!(payload.finish().len(), 1); assert_eq!(payload.finish().len(), 1);
} }
@ -442,18 +443,18 @@ mod tests {
for i in 0..10_000 { for i in 0..10_000 {
payload payload
.push_peer(vec![ .push_peer(WgPeer(vec![
WgPeerAttrs::PublicKey([2u8; 32]), WgPeerAttrs::PublicKey([2u8; 32]),
WgPeerAttrs::PersistentKeepalive(25), WgPeerAttrs::PersistentKeepalive(25),
WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()), WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()),
WgPeerAttrs::Flags(WGPEER_F_REPLACE_ALLOWEDIPS), WgPeerAttrs::Flags(WGPEER_F_REPLACE_ALLOWEDIPS),
WgPeerAttrs::AllowedIps(vec![vec![ WgPeerAttrs::AllowedIps(vec![WgAllowedIp(vec![
WgAllowedIpAttrs::Family(AF_INET), WgAllowedIpAttrs::Family(AF_INET),
WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()), WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()),
WgAllowedIpAttrs::Cidr(24), WgAllowedIpAttrs::Cidr(24),
]]), ])]),
WgPeerAttrs::Unspec(vec![1u8; (i % 256) as usize]), WgPeerAttrs::Unspec(vec![1u8; (i % 256) as usize]),
]) ]))
.unwrap(); .unwrap();
} }