diff --git a/netlink-request/src/lib.rs b/netlink-request/src/lib.rs index 91d245e..61af252 100644 --- a/netlink-request/src/lib.rs +++ b/netlink-request/src/lib.rs @@ -1,11 +1,12 @@ #[cfg(target_os = "linux")] mod linux { pub const MAX_NETLINK_BUFFER_LENGTH: usize = 4096; - pub const MAX_GENL_PAYLOAD_LENGTH: usize = MAX_NETLINK_BUFFER_LENGTH - GENL_HDRLEN; + pub const MAX_GENL_PAYLOAD_LENGTH: usize = + MAX_NETLINK_BUFFER_LENGTH - NETLINK_HEADER_LEN - GENL_HDRLEN; use netlink_packet_core::{ - NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable, NLM_F_ACK, - NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST, + NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable, + NETLINK_HEADER_LEN, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST, }; use netlink_packet_generic::{ constants::GENL_HDRLEN, diff --git a/wireguard-control/src/backends/kernel.rs b/wireguard-control/src/backends/kernel.rs index e46d3fe..0ea3ef7 100644 --- a/wireguard-control/src/backends/kernel.rs +++ b/wireguard-control/src/backends/kernel.rs @@ -396,6 +396,7 @@ mod tests { WgPeerAttrs::PublicKey([2u8; 32]), WgPeerAttrs::PersistentKeepalive(25), WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()), + WgPeerAttrs::Flags(WGPEER_F_REPLACE_ALLOWEDIPS), WgPeerAttrs::AllowedIps(vec![vec![ WgAllowedIpAttrs::Family(AF_INET), WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()), @@ -416,26 +417,28 @@ mod tests { .push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)) .unwrap(); - for _ in 0..10_000 { + for i in 0..10_000 { payload .push_peer(vec![ WgPeerAttrs::PublicKey([2u8; 32]), WgPeerAttrs::PersistentKeepalive(25), - WgPeerAttrs::PresharedKey([1u8; 32]), WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()), + WgPeerAttrs::Flags(WGPEER_F_REPLACE_ALLOWEDIPS), WgPeerAttrs::AllowedIps(vec![vec![ WgAllowedIpAttrs::Family(AF_INET), WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()), WgAllowedIpAttrs::Cidr(24), ]]), + WgPeerAttrs::Unspec(vec![1u8; (i % 256) as usize]), ]) .unwrap(); } let messages = payload.finish(); + println!("generated {} messages", messages.len()); assert!(messages.len() > 1); for message in messages { - assert!(message.buffer_len() <= MAX_NETLINK_BUFFER_LENGTH); + assert!(NetlinkMessage::from(message).buffer_len() <= MAX_NETLINK_BUFFER_LENGTH); } } }