From a21928c30cd900e2afdce35962d00aea4216d48e Mon Sep 17 00:00:00 2001 From: Jake McGinty Date: Mon, 10 Jan 2022 20:29:41 -0600 Subject: [PATCH] wireguard-control: add more checks on peer/NLA sizes --- netlink-request/src/lib.rs | 3 +- wireguard-control/src/backends/kernel.rs | 109 +++++++++++++++-------- 2 files changed, 75 insertions(+), 37 deletions(-) diff --git a/netlink-request/src/lib.rs b/netlink-request/src/lib.rs index a8e0ef8..198c456 100644 --- a/netlink-request/src/lib.rs +++ b/netlink-request/src/lib.rs @@ -87,7 +87,8 @@ mod linux { return Err(io::Error::new( io::ErrorKind::InvalidInput, format!( - "Serialized netlink packet larger than maximum size {}", + "Serialized netlink packet ({} bytes) larger than maximum size {}", + req.buffer_len(), MAX_NETLINK_BUFFER_LENGTH ), )); diff --git a/wireguard-control/src/backends/kernel.rs b/wireguard-control/src/backends/kernel.rs index 711c0da..e46d3fe 100644 --- a/wireguard-control/src/backends/kernel.rs +++ b/wireguard-control/src/backends/kernel.rs @@ -219,22 +219,23 @@ pub fn apply(builder: &DeviceUpdate, iface: &InterfaceName) -> io::Result<()> { add_del(iface, true)?; let mut payload = ApplyPayload::new(iface); if let Some(Key(k)) = builder.private_key { - payload.push(WgDeviceAttrs::PrivateKey(k)); + payload.push(WgDeviceAttrs::PrivateKey(k))?; } if let Some(f) = builder.fwmark { - payload.push(WgDeviceAttrs::Fwmark(f)); + payload.push(WgDeviceAttrs::Fwmark(f))?; } if let Some(f) = builder.listen_port { - payload.push(WgDeviceAttrs::ListenPort(f)); + payload.push(WgDeviceAttrs::ListenPort(f))?; } if builder.replace_peers { - payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)); + payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS))?; } builder .peers .iter() - .for_each(|peer| payload.push_peer(peer.to_attrs())); + .map(|peer| payload.push_peer(peer.to_attrs())) + .collect::, _>>()?; for message in payload.finish() { netlink_request_genl(message, Some(NLM_F_REQUEST | NLM_F_ACK))?; @@ -278,17 +279,29 @@ impl ApplyPayload { } /// Push a device attribute which will be optimally packed into 1 or more netlink messages - pub fn push(&mut self, nla: WgDeviceAttrs) { + pub fn push(&mut self, nla: WgDeviceAttrs) -> io::Result<()> { let nla_buffer_len = nla.buffer_len(); if (self.current_buffer_len + nla_buffer_len) > MAX_GENL_PAYLOAD_LENGTH { self.flush_nlas(); } + + // If the NLA *still* doesn't fit... + if (self.current_buffer_len + nla_buffer_len) > MAX_GENL_PAYLOAD_LENGTH { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!( + "encoded NLA ({} bytes) is too large: {:?}", + nla_buffer_len, nla + ), + )); + } self.nlas.push(nla); self.current_buffer_len += nla_buffer_len; + Ok(()) } /// A helper function to assist in breaking up large peer lists across multiple netlink messages - pub fn push_peer(&mut self, peer: Vec) { + pub fn push_peer(&mut self, peer: Vec) -> io::Result<()> { const EMPTY_PEERS: WgDeviceAttrs = WgDeviceAttrs::Peers(vec![]); let mut needs_peer_nla = !self .nlas @@ -305,8 +318,20 @@ impl ApplyPayload { } if needs_peer_nla { - self.push(EMPTY_PEERS); + self.push(EMPTY_PEERS)?; } + + // If the peer *still* doesn't fit... + if (self.current_buffer_len + peer_buffer_len) > MAX_GENL_PAYLOAD_LENGTH { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!( + "encoded peer ({} bytes) is too large: {:?}", + peer_buffer_len, peer + ), + )); + } + let peers_nla = self .nlas .iter_mut() @@ -318,6 +343,8 @@ impl ApplyPayload { peers_nla.push(peer); self.current_buffer_len += peer_buffer_len; + + Ok(()) } pub fn finish(mut self) -> Vec> { @@ -358,32 +385,14 @@ mod tests { #[test] fn test_simple_payload() { let mut payload = ApplyPayload::new(&InterfaceName::from_str("wg0").unwrap()); - payload.push(WgDeviceAttrs::PrivateKey([1u8; 32])); - payload.push(WgDeviceAttrs::Fwmark(111)); - payload.push(WgDeviceAttrs::ListenPort(12345)); - payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)); - payload.push_peer(vec![ - WgPeerAttrs::PublicKey([2u8; 32]), - WgPeerAttrs::PersistentKeepalive(25), - WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()), - WgPeerAttrs::AllowedIps(vec![vec![ - WgAllowedIpAttrs::Family(AF_INET), - WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()), - WgAllowedIpAttrs::Cidr(24), - ]]), - ]); - assert_eq!(payload.finish().len(), 1); - } - - #[test] - fn test_massive_payload() { - let mut payload = ApplyPayload::new(&InterfaceName::from_str("wg0").unwrap()); - payload.push(WgDeviceAttrs::PrivateKey([1u8; 32])); - payload.push(WgDeviceAttrs::Fwmark(111)); - payload.push(WgDeviceAttrs::ListenPort(12345)); - payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)); - for _ in 0..10_000 { - payload.push_peer(vec![ + payload.push(WgDeviceAttrs::PrivateKey([1u8; 32])).unwrap(); + payload.push(WgDeviceAttrs::Fwmark(111)).unwrap(); + payload.push(WgDeviceAttrs::ListenPort(12345)).unwrap(); + payload + .push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)) + .unwrap(); + payload + .push_peer(vec![ WgPeerAttrs::PublicKey([2u8; 32]), WgPeerAttrs::PersistentKeepalive(25), WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()), @@ -392,13 +401,41 @@ mod tests { WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()), WgAllowedIpAttrs::Cidr(24), ]]), - ]); + ]) + .unwrap(); + assert_eq!(payload.finish().len(), 1); + } + + #[test] + fn test_massive_payload() { + let mut payload = ApplyPayload::new(&InterfaceName::from_str("wg0").unwrap()); + payload.push(WgDeviceAttrs::PrivateKey([1u8; 32])).unwrap(); + payload.push(WgDeviceAttrs::Fwmark(111)).unwrap(); + payload.push(WgDeviceAttrs::ListenPort(12345)).unwrap(); + payload + .push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)) + .unwrap(); + + for _ in 0..10_000 { + payload + .push_peer(vec![ + WgPeerAttrs::PublicKey([2u8; 32]), + WgPeerAttrs::PersistentKeepalive(25), + WgPeerAttrs::PresharedKey([1u8; 32]), + WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()), + WgPeerAttrs::AllowedIps(vec![vec![ + WgAllowedIpAttrs::Family(AF_INET), + WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()), + WgAllowedIpAttrs::Cidr(24), + ]]), + ]) + .unwrap(); } let messages = payload.finish(); assert!(messages.len() > 1); for message in messages { - assert!(message.buffer_len() < MAX_NETLINK_BUFFER_LENGTH); + assert!(message.buffer_len() <= MAX_NETLINK_BUFFER_LENGTH); } } }