wireguard-control: add more checks on peer/NLA sizes

pull/186/head
Jake McGinty 2022-01-10 20:29:41 -06:00
parent a95fa1b03e
commit a21928c30c
2 changed files with 75 additions and 37 deletions

View File

@ -87,7 +87,8 @@ mod linux {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
format!( format!(
"Serialized netlink packet larger than maximum size {}", "Serialized netlink packet ({} bytes) larger than maximum size {}",
req.buffer_len(),
MAX_NETLINK_BUFFER_LENGTH MAX_NETLINK_BUFFER_LENGTH
), ),
)); ));

View File

@ -219,22 +219,23 @@ pub fn apply(builder: &DeviceUpdate, iface: &InterfaceName) -> io::Result<()> {
add_del(iface, true)?; add_del(iface, true)?;
let mut payload = ApplyPayload::new(iface); let mut payload = ApplyPayload::new(iface);
if let Some(Key(k)) = builder.private_key { if let Some(Key(k)) = builder.private_key {
payload.push(WgDeviceAttrs::PrivateKey(k)); payload.push(WgDeviceAttrs::PrivateKey(k))?;
} }
if let Some(f) = builder.fwmark { if let Some(f) = builder.fwmark {
payload.push(WgDeviceAttrs::Fwmark(f)); payload.push(WgDeviceAttrs::Fwmark(f))?;
} }
if let Some(f) = builder.listen_port { if let Some(f) = builder.listen_port {
payload.push(WgDeviceAttrs::ListenPort(f)); payload.push(WgDeviceAttrs::ListenPort(f))?;
} }
if builder.replace_peers { if builder.replace_peers {
payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)); payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS))?;
} }
builder builder
.peers .peers
.iter() .iter()
.for_each(|peer| payload.push_peer(peer.to_attrs())); .map(|peer| payload.push_peer(peer.to_attrs()))
.collect::<Result<Vec<_>, _>>()?;
for message in payload.finish() { for message in payload.finish() {
netlink_request_genl(message, Some(NLM_F_REQUEST | NLM_F_ACK))?; netlink_request_genl(message, Some(NLM_F_REQUEST | NLM_F_ACK))?;
@ -278,17 +279,29 @@ impl ApplyPayload {
} }
/// Push a device attribute which will be optimally packed into 1 or more netlink messages /// Push a device attribute which will be optimally packed into 1 or more netlink messages
pub fn push(&mut self, nla: WgDeviceAttrs) { pub fn push(&mut self, nla: WgDeviceAttrs) -> io::Result<()> {
let nla_buffer_len = nla.buffer_len(); let nla_buffer_len = nla.buffer_len();
if (self.current_buffer_len + nla_buffer_len) > MAX_GENL_PAYLOAD_LENGTH { if (self.current_buffer_len + nla_buffer_len) > MAX_GENL_PAYLOAD_LENGTH {
self.flush_nlas(); self.flush_nlas();
} }
// If the NLA *still* doesn't fit...
if (self.current_buffer_len + nla_buffer_len) > MAX_GENL_PAYLOAD_LENGTH {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!(
"encoded NLA ({} bytes) is too large: {:?}",
nla_buffer_len, nla
),
));
}
self.nlas.push(nla); self.nlas.push(nla);
self.current_buffer_len += nla_buffer_len; self.current_buffer_len += nla_buffer_len;
Ok(())
} }
/// A helper function to assist in breaking up large peer lists across multiple netlink messages /// A helper function to assist in breaking up large peer lists across multiple netlink messages
pub fn push_peer(&mut self, peer: Vec<WgPeerAttrs>) { pub fn push_peer(&mut self, peer: Vec<WgPeerAttrs>) -> io::Result<()> {
const EMPTY_PEERS: WgDeviceAttrs = WgDeviceAttrs::Peers(vec![]); const EMPTY_PEERS: WgDeviceAttrs = WgDeviceAttrs::Peers(vec![]);
let mut needs_peer_nla = !self let mut needs_peer_nla = !self
.nlas .nlas
@ -305,8 +318,20 @@ impl ApplyPayload {
} }
if needs_peer_nla { if needs_peer_nla {
self.push(EMPTY_PEERS); self.push(EMPTY_PEERS)?;
} }
// If the peer *still* doesn't fit...
if (self.current_buffer_len + peer_buffer_len) > MAX_GENL_PAYLOAD_LENGTH {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!(
"encoded peer ({} bytes) is too large: {:?}",
peer_buffer_len, peer
),
));
}
let peers_nla = self let peers_nla = self
.nlas .nlas
.iter_mut() .iter_mut()
@ -318,6 +343,8 @@ impl ApplyPayload {
peers_nla.push(peer); peers_nla.push(peer);
self.current_buffer_len += peer_buffer_len; self.current_buffer_len += peer_buffer_len;
Ok(())
} }
pub fn finish(mut self) -> Vec<GenlMessage<Wireguard>> { pub fn finish(mut self) -> Vec<GenlMessage<Wireguard>> {
@ -358,32 +385,14 @@ mod tests {
#[test] #[test]
fn test_simple_payload() { fn test_simple_payload() {
let mut payload = ApplyPayload::new(&InterfaceName::from_str("wg0").unwrap()); let mut payload = ApplyPayload::new(&InterfaceName::from_str("wg0").unwrap());
payload.push(WgDeviceAttrs::PrivateKey([1u8; 32])); payload.push(WgDeviceAttrs::PrivateKey([1u8; 32])).unwrap();
payload.push(WgDeviceAttrs::Fwmark(111)); payload.push(WgDeviceAttrs::Fwmark(111)).unwrap();
payload.push(WgDeviceAttrs::ListenPort(12345)); payload.push(WgDeviceAttrs::ListenPort(12345)).unwrap();
payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)); payload
payload.push_peer(vec![ .push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS))
WgPeerAttrs::PublicKey([2u8; 32]), .unwrap();
WgPeerAttrs::PersistentKeepalive(25), payload
WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()), .push_peer(vec![
WgPeerAttrs::AllowedIps(vec![vec![
WgAllowedIpAttrs::Family(AF_INET),
WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()),
WgAllowedIpAttrs::Cidr(24),
]]),
]);
assert_eq!(payload.finish().len(), 1);
}
#[test]
fn test_massive_payload() {
let mut payload = ApplyPayload::new(&InterfaceName::from_str("wg0").unwrap());
payload.push(WgDeviceAttrs::PrivateKey([1u8; 32]));
payload.push(WgDeviceAttrs::Fwmark(111));
payload.push(WgDeviceAttrs::ListenPort(12345));
payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS));
for _ in 0..10_000 {
payload.push_peer(vec![
WgPeerAttrs::PublicKey([2u8; 32]), WgPeerAttrs::PublicKey([2u8; 32]),
WgPeerAttrs::PersistentKeepalive(25), WgPeerAttrs::PersistentKeepalive(25),
WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()), WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()),
@ -392,13 +401,41 @@ mod tests {
WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()), WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()),
WgAllowedIpAttrs::Cidr(24), WgAllowedIpAttrs::Cidr(24),
]]), ]]),
]); ])
.unwrap();
assert_eq!(payload.finish().len(), 1);
}
#[test]
fn test_massive_payload() {
let mut payload = ApplyPayload::new(&InterfaceName::from_str("wg0").unwrap());
payload.push(WgDeviceAttrs::PrivateKey([1u8; 32])).unwrap();
payload.push(WgDeviceAttrs::Fwmark(111)).unwrap();
payload.push(WgDeviceAttrs::ListenPort(12345)).unwrap();
payload
.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS))
.unwrap();
for _ in 0..10_000 {
payload
.push_peer(vec![
WgPeerAttrs::PublicKey([2u8; 32]),
WgPeerAttrs::PersistentKeepalive(25),
WgPeerAttrs::PresharedKey([1u8; 32]),
WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()),
WgPeerAttrs::AllowedIps(vec![vec![
WgAllowedIpAttrs::Family(AF_INET),
WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()),
WgAllowedIpAttrs::Cidr(24),
]]),
])
.unwrap();
} }
let messages = payload.finish(); let messages = payload.finish();
assert!(messages.len() > 1); assert!(messages.len() > 1);
for message in messages { for message in messages {
assert!(message.buffer_len() < MAX_NETLINK_BUFFER_LENGTH); assert!(message.buffer_len() <= MAX_NETLINK_BUFFER_LENGTH);
} }
} }
} }